Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global improvement for wasm_thread #13

Merged
merged 3 commits into from
Mar 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 188 additions & 48 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@ use std::any::Any;
use std::fmt;
use std::mem;

use std::sync::Mutex;
pub use std::thread::{current, sleep, Result, Thread, ThreadId};
use std::{
marker::PhantomData,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
time::Duration,
};

use wasm_bindgen::prelude::*;
use wasm_bindgen::*;
Expand All @@ -20,6 +30,9 @@ extern "C" {
fn load_module_workers_polyfill();
}

type DefaultBuilder = Mutex<Option<Builder>>;
static DEFAULT_BUILDER: DefaultBuilder = Mutex::new(None);

/// Extracts path of the `wasm_bindgen` generated .js shim script
pub fn get_wasm_bindgen_shim_script_path() -> String {
js_sys::eval(include_str!("script_path.js"))
Expand All @@ -30,42 +43,40 @@ pub fn get_wasm_bindgen_shim_script_path() -> String {

/// Generates worker entry script as URL encoded blob
pub fn get_worker_script(wasm_bindgen_shim_url: Option<String>) -> String {
unsafe {
static mut SCRIPT_URL: Option<String> = None;
static mut SCRIPT_URL: Option<String> = None;

if let Some(url) = SCRIPT_URL.as_ref() {
url.clone()
} else {
// If wasm bindgen shim url is not provided, try to obtain one automatically
let wasm_bindgen_shim_url =
wasm_bindgen_shim_url.unwrap_or_else(get_wasm_bindgen_shim_script_path);

// Generate script from template
let template;
#[cfg(feature = "es_modules")]
{
template = include_str!("web_worker_module.js");
}
#[cfg(not(feature = "es_modules"))]
{
template = include_str!("web_worker.js");
}
let script = template.replace("WASM_BINDGEN_SHIM_URL", &wasm_bindgen_shim_url);

// Create url encoded blob
let arr = js_sys::Array::new();
arr.set(0, JsValue::from_str(&script));
let blob = Blob::new_with_str_sequence(&arr).unwrap();
let url = Url::create_object_url_with_blob(
&blob
.slice_with_f64_and_f64_and_content_type(0.0, blob.size(), "text/javascript")
.unwrap(),
)
.unwrap();
SCRIPT_URL = Some(url.clone());
if let Some(url) = unsafe { SCRIPT_URL.as_ref() } {
url.clone()
} else {
// If wasm bindgen shim url is not provided, try to obtain one automatically
let wasm_bindgen_shim_url =
wasm_bindgen_shim_url.unwrap_or_else(get_wasm_bindgen_shim_script_path);

url
// Generate script from template
let template;
#[cfg(feature = "es_modules")]
{
template = include_str!("web_worker_module.js");
}
#[cfg(not(feature = "es_modules"))]
{
template = include_str!("web_worker.js");
}
let script = template.replace("WASM_BINDGEN_SHIM_URL", &wasm_bindgen_shim_url);

// Create url encoded blob
let arr = js_sys::Array::new();
arr.set(0, JsValue::from_str(&script));
let blob = Blob::new_with_str_sequence(&arr).unwrap();
let url = Url::create_object_url_with_blob(
&blob
.slice_with_f64_and_f64_and_content_type(0.0, blob.size(), "text/javascript")
.unwrap(),
)
.unwrap();
unsafe { SCRIPT_URL = Some(url.clone()) };

url
}
}

Expand Down Expand Up @@ -98,24 +109,24 @@ enum WorkerMessage {
impl WorkerMessage {
pub fn post(self) {
let req = Box::new(self);
unsafe {
js_sys::eval("self")
.unwrap()
.dyn_into::<DedicatedWorkerGlobalScope>()
.unwrap()
.post_message(&JsValue::from(std::mem::transmute::<_, f64>(
Box::into_raw(req) as u64,
)))
.unwrap();
}
let req = unsafe { std::mem::transmute::<_, f64>(Box::into_raw(req) as u64) };

js_sys::eval("self")
.unwrap()
.dyn_into::<DedicatedWorkerGlobalScope>()
.unwrap()
.post_message(&JsValue::from(req))
.unwrap();
}
}

/// Thread factory, which can be used in order to configure the properties of a new thread.
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub struct Builder {
// A name for the thread-to-be, for identification in panic messages
name: Option<String>,
// A prefix for the thread-to-be, for identification in panic messages
prefix: Option<String>,
// The size of the stack for the spawned thread in bytes
stack_size: Option<usize>,
// Url of the `wasm_bindgen` generated shim `.js` script to use as web worker entry point
Expand All @@ -126,7 +137,18 @@ impl Builder {
/// Generates the base configuration for spawning a thread, from which
/// configuration methods can be chained.
pub fn new() -> Builder {
Builder::default()
let default_builder = DEFAULT_BUILDER.lock().unwrap().clone();
default_builder.unwrap_or(Builder::default())
}

pub fn set_default(self) {
*DEFAULT_BUILDER.lock().unwrap() = Some(self);
}

/// Sets the prefix of the thread-to-be.
pub fn prefix(mut self, prefix: String) -> Builder {
self.prefix = Some(prefix);
self
}

/// Names the thread-to-be.
Expand Down Expand Up @@ -158,6 +180,21 @@ impl Builder {
unsafe { self.spawn_unchecked(f) }
}

pub fn spawn_scoped<'scope, 'env, F, T>(
self,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> std::io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(
unsafe { self.spawn_unchecked(f) }?,
PhantomData,
))
}

/// Spawns a new thread without any lifetime restrictions by taking ownership
/// of the `Builder`, and returns an [`io::Result`] to its [`JoinHandle`].
///
Expand Down Expand Up @@ -210,6 +247,7 @@ impl Builder {
unsafe fn spawn_for_context(self, ctx: WebWorkerContext) {
let Builder {
name,
prefix,
wasm_bindgen_shim_url,
..
} = self;
Expand All @@ -219,9 +257,20 @@ impl Builder {

// Todo: figure out how to set stack size
let mut options = WorkerOptions::new();
if let Some(name) = name {
options.name(&name);
}
match (name, prefix) {
(Some(name), Some(prefix)) => {
options.name(&format!("{}:{}", prefix, name));
}
(Some(name), None) => {
options.name(&name);
}
(None, Some(prefix)) => {
let random = (js_sys::Math::random() * 10e10) as u64;
options.name(&format!("{}:{}", prefix, random));
}
(None, None) => {}
};

#[cfg(feature = "es_modules")]
{
load_module_workers_polyfill();
Expand Down Expand Up @@ -328,3 +377,94 @@ where
{
Builder::new().spawn(f).expect("failed to spawn thread")
}

use core::num::NonZeroUsize;
pub fn available_parallelism() -> std::io::Result<NonZeroUsize> {
// TODO: Use [Navigator::hardware_concurrency](https://rustwasm.github.io/wasm-bindgen/api/web_sys/struct.Navigator.html#method.hardware_concurrency)
Ok(NonZeroUsize::new(8).unwrap())
}

pub struct ScopeData {
num_running_threads: AtomicUsize,
a_thread_panicked: AtomicBool,
main_thread: Thread,
}

pub struct Scope<'scope, 'env: 'scope> {
data: Arc<ScopeData>,
/// Invariance over 'scope, to make sure 'scope cannot shrink,
/// which is necessary for soundness.
///
/// Without invariance, this would compile fine but be unsound:
///
/// ```compile_fail,E0373
/// std::thread::scope(|s| {
/// s.spawn(|| {
/// let a = String::from("abcd");
/// s.spawn(|| println!("{a:?}")); // might run after `a` is dropped
/// });
/// });
/// ```
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
}

pub fn scope<'env, F, T>(f: F) -> T
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
// We put the `ScopeData` into an `Arc` so that other threads can finish their
// `decrement_num_running_threads` even after this function returns.
let scope = Scope {
data: Arc::new(ScopeData {
num_running_threads: AtomicUsize::new(0),
main_thread: current(),
a_thread_panicked: AtomicBool::new(false),
}),
env: PhantomData,
scope: PhantomData,
};

// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
let result = catch_unwind(AssertUnwindSafe(|| f(&scope)));

// Wait until all the threads are finished.
while scope.data.num_running_threads.load(Ordering::Acquire) != 0 {
// park();
// TODO: Replaced by a wasm-friendly version of park()
sleep(Duration::from_millis(1));
}

// Throw any panic from `f`, or the return value of `f` if no thread panicked.
match result {
Err(e) => resume_unwind(e),
Ok(_) if scope.data.a_thread_panicked.load(Ordering::Relaxed) => {
panic!("a scoped thread panicked")
}
Ok(result) => result,
}
}

pub struct ScopedJoinHandle<'scope, T>(crate::JoinHandle<T>, PhantomData<&'scope ()>);
impl<'scope, T> ScopedJoinHandle<'scope, T> {
pub fn join(self) -> std::io::Result<T> {
self.0
.join()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, ""))
}
}

pub fn spawn_scoped<'scope, 'env, F, T>(
builder: crate::Builder,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> std::io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(
unsafe { builder.spawn_unchecked(f) }?,
PhantomData,
))
}