Skip to content

Commit

Permalink
Sleep for multi-threaded WASM targets
Browse files Browse the repository at this point in the history
  • Loading branch information
surban committed Dec 20, 2024
1 parent 655844a commit 827cc6e
Showing 1 changed file with 172 additions and 76 deletions.
248 changes: 172 additions & 76 deletions remoc/src/exec/js/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,114 +2,210 @@
#![allow(unsafe_code)]

use js_sys::Function;
use std::{
fmt,
future::{Future, IntoFuture},
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
task::{Context, Poll},
time::Duration,
};
use wasm_bindgen::{prelude::*, JsCast};
use web_sys::{Window, WorkerGlobalScope};

/// Future returned by [`sleep`].
pub struct Sleep {
inner: Arc<Mutex<SleepInner>>,
timeout_id: i32,
_callback: Closure<dyn FnMut()>,
}

// Implement Send + Sync for target without threads.
#[cfg(all(target_family = "wasm", not(target_feature = "atomics")))]
unsafe impl Send for Sleep {}
#[cfg(all(target_family = "wasm", not(target_feature = "atomics")))]
unsafe impl Sync for Sleep {}
/// JavaScript sleep wrapper.
mod js {
use js_sys::Function;
use std::{
cell::RefCell,
fmt,
future::Future,
pin::Pin,
rc::Rc,
task::{Context, Poll, Waker},
time::Duration,
};
use wasm_bindgen::{prelude::*, JsCast};
use web_sys::{Window, WorkerGlobalScope};

impl fmt::Debug for Sleep {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Sleep").field("timeout_id", &self.timeout_id).finish()
/// JavaScript sleep.
///
/// This is not Send + Sync since the underlying callback is bound to a
/// JavaScript thread.
pub struct JsSleep {
inner: Rc<RefCell<JsSleepInner>>,
timeout_id: i32,
_callback: Closure<dyn FnMut()>,
}
}

#[derive(Default)]
struct SleepInner {
fired: bool,
waker: Option<Waker>,
}
// Implement Send + Sync for target without threads.
#[cfg(all(target_family = "wasm", not(target_feature = "atomics")))]
unsafe impl Send for JsSleep {}
#[cfg(all(target_family = "wasm", not(target_feature = "atomics")))]
unsafe impl Sync for JsSleep {}

impl fmt::Debug for JsSleep {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Sleep").field("timeout_id", &self.timeout_id).finish()
}
}

impl Sleep {
fn new(duration: Duration) -> Self {
let inner = Arc::new(Mutex::new(SleepInner::default()));

let callback = {
let inner = inner.clone();
Closure::new(move || {
let mut inner = inner.lock().unwrap();
inner.fired = true;
if let Some(waker) = inner.waker.take() {
waker.wake();
}
})
};

let timeout = duration.as_millis().try_into().expect("sleep duration overflow");
let timeout_id = Self::register_timeout(callback.as_ref().unchecked_ref(), timeout);

Self { inner, timeout_id, _callback: callback }
#[derive(Default)]
struct JsSleepInner {
fired: bool,
waker: Option<Waker>,
}

fn register_timeout(handler: &Function, timeout: i32) -> i32 {
let global = js_sys::global();
impl JsSleep {
pub(super) fn new(duration: Duration) -> Self {
let inner = Rc::new(RefCell::new(JsSleepInner::default()));

if let Some(window) = global.dyn_ref::<Window>() {
window.set_timeout_with_callback_and_timeout_and_arguments_0(handler, timeout).unwrap()
} else if let Some(worker) = global.dyn_ref::<WorkerGlobalScope>() {
worker.set_timeout_with_callback_and_timeout_and_arguments_0(handler, timeout).unwrap()
} else {
panic!("unsupported JavaScript global: {global:?}");
let callback = {
let inner = inner.clone();
Closure::new(move || {
let mut inner = inner.borrow_mut();
inner.fired = true;
if let Some(waker) = inner.waker.take() {
waker.wake();
}
})
};

let timeout = duration.as_millis().try_into().expect("sleep duration overflow");
let timeout_id = Self::register_timeout(callback.as_ref().unchecked_ref(), timeout);

Self { inner, timeout_id, _callback: callback }
}

fn register_timeout(handler: &Function, timeout: i32) -> i32 {
let global = js_sys::global();

if let Some(window) = global.dyn_ref::<Window>() {
window.set_timeout_with_callback_and_timeout_and_arguments_0(handler, timeout).unwrap()
} else if let Some(worker) = global.dyn_ref::<WorkerGlobalScope>() {
worker.set_timeout_with_callback_and_timeout_and_arguments_0(handler, timeout).unwrap()
} else {
panic!("unsupported JavaScript global: {global:?}");
}
}

fn unregister_timeout(id: i32) {
let global = js_sys::global();

if let Some(window) = global.dyn_ref::<Window>() {
window.clear_timeout_with_handle(id);
} else if let Some(worker) = global.dyn_ref::<WorkerGlobalScope>() {
worker.clear_timeout_with_handle(id);
} else {
panic!("unsupported JavaScript global: {global:?}");
}
}
}

fn unregister_timeout(id: i32) {
let global = js_sys::global();
impl Future for JsSleep {
type Output = ();

/// Waits until the sleep duration has elapsed.
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut inner = self.inner.borrow_mut();

if inner.fired {
return Poll::Ready(());
}

if let Some(window) = global.dyn_ref::<Window>() {
window.clear_timeout_with_handle(id);
} else if let Some(worker) = global.dyn_ref::<WorkerGlobalScope>() {
worker.clear_timeout_with_handle(id);
} else {
panic!("unsupported JavaScript global: {global:?}");
inner.waker = Some(cx.waker().clone());
Poll::Pending
}
}

impl Drop for JsSleep {
fn drop(&mut self) {
let inner = self.inner.borrow_mut();
if !inner.fired {
Self::unregister_timeout(self.timeout_id);
}
}
}
}

impl Future for Sleep {
type Output = ();
/// Future for [`sleep`].
#[cfg(all(target_family = "wasm", not(target_feature = "atomics")))]
pub use js::JsSleep as Sleep;

/// Thread-safe sleep.
#[cfg(all(target_family = "wasm", target_feature = "atomics"))]
mod threads {
use futures::{ready, FutureExt};
use std::{
fmt,
future::Future,
pin::Pin,
sync::LazyLock,
task::{Context, Poll},
time::Duration,
};
use tokio::sync::{mpsc, oneshot};
use wasm_bindgen_futures::spawn_thread;

/// Waits until the sleep duration has elapsed.
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut inner = self.inner.lock().unwrap();
use super::js::JsSleep;

if inner.fired {
return Poll::Ready(());
struct SleepReq {
duration: Duration,
wake_tx: oneshot::Sender<()>,
}

static SLEEP_TX: LazyLock<mpsc::UnboundedSender<SleepReq>> = LazyLock::new(|| {
let (sleep_tx, mut sleep_rx) = mpsc::unbounded_channel::<SleepReq>();
spawn_thread(move || async move {
while let Some(SleepReq { duration, mut wake_tx }) = sleep_rx.recv().await {
wasm_bindgen_futures::spawn_local(async move {
tokio::select! {
() = JsSleep::new(duration) => {
let _ = wake_tx.send(());
},
() = wake_tx.closed() => (),
}
});
}
});
sleep_tx
});

/// Thread-safe sleep wrapper.
pub struct Sleep(oneshot::Receiver<()>);

impl fmt::Debug for Sleep {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("Sleep").field(&self.0).finish()
}
}

inner.waker = Some(cx.waker().clone());
Poll::Pending
impl Sleep {
pub(super) fn new(duration: Duration) -> Self {
let (wake_tx, wake_rx) = oneshot::channel();
SLEEP_TX.send(SleepReq { duration, wake_tx }).expect("sleep thread failed");
Self(wake_rx)
}
}
}

impl Drop for Sleep {
fn drop(&mut self) {
let inner = self.inner.lock().unwrap();
if !inner.fired {
Self::unregister_timeout(self.timeout_id);
impl Future for Sleep {
type Output = ();

/// Waits until the sleep duration has elapsed.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
ready!(self.0.poll_unpin(cx)).expect("sleep thread failed");
Poll::Ready(())
}
}

impl Drop for Sleep {
fn drop(&mut self) {
// empty
}
}
}

/// Future for [`sleep`].
#[cfg(all(target_family = "wasm", target_feature = "atomics"))]
pub use threads::Sleep;

/// Waits until `duration` has elapsed.
pub fn sleep(duration: Duration) -> Sleep {
Sleep::new(duration)
Expand Down

0 comments on commit 827cc6e

Please sign in to comment.