From 827cc6e88cc9a2b2809b0387580c68a7200e473c Mon Sep 17 00:00:00 2001 From: Sebastian Urban Date: Sat, 21 Dec 2024 00:37:38 +0100 Subject: [PATCH] Sleep for multi-threaded WASM targets --- remoc/src/exec/js/time.rs | 248 ++++++++++++++++++++++++++------------ 1 file changed, 172 insertions(+), 76 deletions(-) diff --git a/remoc/src/exec/js/time.rs b/remoc/src/exec/js/time.rs index fbb45e6..3321d2d 100644 --- a/remoc/src/exec/js/time.rs +++ b/remoc/src/exec/js/time.rs @@ -2,114 +2,210 @@ #![allow(unsafe_code)] -use js_sys::Function; use std::{ fmt, future::{Future, IntoFuture}, pin::Pin, - sync::{Arc, Mutex}, - task::{Context, Poll, Waker}, + task::{Context, Poll}, time::Duration, }; -use wasm_bindgen::{prelude::*, JsCast}; -use web_sys::{Window, WorkerGlobalScope}; - -/// Future returned by [`sleep`]. -pub struct Sleep { - inner: Arc>, - timeout_id: i32, - _callback: Closure, -} -// Implement Send + Sync for target without threads. -#[cfg(all(target_family = "wasm", not(target_feature = "atomics")))] -unsafe impl Send for Sleep {} -#[cfg(all(target_family = "wasm", not(target_feature = "atomics")))] -unsafe impl Sync for Sleep {} +/// JavaScript sleep wrapper. +mod js { + use js_sys::Function; + use std::{ + cell::RefCell, + fmt, + future::Future, + pin::Pin, + rc::Rc, + task::{Context, Poll, Waker}, + time::Duration, + }; + use wasm_bindgen::{prelude::*, JsCast}; + use web_sys::{Window, WorkerGlobalScope}; -impl fmt::Debug for Sleep { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Sleep").field("timeout_id", &self.timeout_id).finish() + /// JavaScript sleep. + /// + /// This is not Send + Sync since the underlying callback is bound to a + /// JavaScript thread. + pub struct JsSleep { + inner: Rc>, + timeout_id: i32, + _callback: Closure, } -} -#[derive(Default)] -struct SleepInner { - fired: bool, - waker: Option, -} + // Implement Send + Sync for target without threads. + #[cfg(all(target_family = "wasm", not(target_feature = "atomics")))] + unsafe impl Send for JsSleep {} + #[cfg(all(target_family = "wasm", not(target_feature = "atomics")))] + unsafe impl Sync for JsSleep {} + + impl fmt::Debug for JsSleep { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Sleep").field("timeout_id", &self.timeout_id).finish() + } + } -impl Sleep { - fn new(duration: Duration) -> Self { - let inner = Arc::new(Mutex::new(SleepInner::default())); - - let callback = { - let inner = inner.clone(); - Closure::new(move || { - let mut inner = inner.lock().unwrap(); - inner.fired = true; - if let Some(waker) = inner.waker.take() { - waker.wake(); - } - }) - }; - - let timeout = duration.as_millis().try_into().expect("sleep duration overflow"); - let timeout_id = Self::register_timeout(callback.as_ref().unchecked_ref(), timeout); - - Self { inner, timeout_id, _callback: callback } + #[derive(Default)] + struct JsSleepInner { + fired: bool, + waker: Option, } - fn register_timeout(handler: &Function, timeout: i32) -> i32 { - let global = js_sys::global(); + impl JsSleep { + pub(super) fn new(duration: Duration) -> Self { + let inner = Rc::new(RefCell::new(JsSleepInner::default())); - if let Some(window) = global.dyn_ref::() { - window.set_timeout_with_callback_and_timeout_and_arguments_0(handler, timeout).unwrap() - } else if let Some(worker) = global.dyn_ref::() { - worker.set_timeout_with_callback_and_timeout_and_arguments_0(handler, timeout).unwrap() - } else { - panic!("unsupported JavaScript global: {global:?}"); + let callback = { + let inner = inner.clone(); + Closure::new(move || { + let mut inner = inner.borrow_mut(); + inner.fired = true; + if let Some(waker) = inner.waker.take() { + waker.wake(); + } + }) + }; + + let timeout = duration.as_millis().try_into().expect("sleep duration overflow"); + let timeout_id = Self::register_timeout(callback.as_ref().unchecked_ref(), timeout); + + Self { inner, timeout_id, _callback: callback } + } + + fn register_timeout(handler: &Function, timeout: i32) -> i32 { + let global = js_sys::global(); + + if let Some(window) = global.dyn_ref::() { + window.set_timeout_with_callback_and_timeout_and_arguments_0(handler, timeout).unwrap() + } else if let Some(worker) = global.dyn_ref::() { + worker.set_timeout_with_callback_and_timeout_and_arguments_0(handler, timeout).unwrap() + } else { + panic!("unsupported JavaScript global: {global:?}"); + } + } + + fn unregister_timeout(id: i32) { + let global = js_sys::global(); + + if let Some(window) = global.dyn_ref::() { + window.clear_timeout_with_handle(id); + } else if let Some(worker) = global.dyn_ref::() { + worker.clear_timeout_with_handle(id); + } else { + panic!("unsupported JavaScript global: {global:?}"); + } } } - fn unregister_timeout(id: i32) { - let global = js_sys::global(); + impl Future for JsSleep { + type Output = (); + + /// Waits until the sleep duration has elapsed. + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut inner = self.inner.borrow_mut(); + + if inner.fired { + return Poll::Ready(()); + } - if let Some(window) = global.dyn_ref::() { - window.clear_timeout_with_handle(id); - } else if let Some(worker) = global.dyn_ref::() { - worker.clear_timeout_with_handle(id); - } else { - panic!("unsupported JavaScript global: {global:?}"); + inner.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + + impl Drop for JsSleep { + fn drop(&mut self) { + let inner = self.inner.borrow_mut(); + if !inner.fired { + Self::unregister_timeout(self.timeout_id); + } } } } -impl Future for Sleep { - type Output = (); +/// Future for [`sleep`]. +#[cfg(all(target_family = "wasm", not(target_feature = "atomics")))] +pub use js::JsSleep as Sleep; + +/// Thread-safe sleep. +#[cfg(all(target_family = "wasm", target_feature = "atomics"))] +mod threads { + use futures::{ready, FutureExt}; + use std::{ + fmt, + future::Future, + pin::Pin, + sync::LazyLock, + task::{Context, Poll}, + time::Duration, + }; + use tokio::sync::{mpsc, oneshot}; + use wasm_bindgen_futures::spawn_thread; - /// Waits until the sleep duration has elapsed. - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let mut inner = self.inner.lock().unwrap(); + use super::js::JsSleep; - if inner.fired { - return Poll::Ready(()); + struct SleepReq { + duration: Duration, + wake_tx: oneshot::Sender<()>, + } + + static SLEEP_TX: LazyLock> = LazyLock::new(|| { + let (sleep_tx, mut sleep_rx) = mpsc::unbounded_channel::(); + spawn_thread(move || async move { + while let Some(SleepReq { duration, mut wake_tx }) = sleep_rx.recv().await { + wasm_bindgen_futures::spawn_local(async move { + tokio::select! { + () = JsSleep::new(duration) => { + let _ = wake_tx.send(()); + }, + () = wake_tx.closed() => (), + } + }); + } + }); + sleep_tx + }); + + /// Thread-safe sleep wrapper. + pub struct Sleep(oneshot::Receiver<()>); + + impl fmt::Debug for Sleep { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("Sleep").field(&self.0).finish() } + } - inner.waker = Some(cx.waker().clone()); - Poll::Pending + impl Sleep { + pub(super) fn new(duration: Duration) -> Self { + let (wake_tx, wake_rx) = oneshot::channel(); + SLEEP_TX.send(SleepReq { duration, wake_tx }).expect("sleep thread failed"); + Self(wake_rx) + } } -} -impl Drop for Sleep { - fn drop(&mut self) { - let inner = self.inner.lock().unwrap(); - if !inner.fired { - Self::unregister_timeout(self.timeout_id); + impl Future for Sleep { + type Output = (); + + /// Waits until the sleep duration has elapsed. + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + ready!(self.0.poll_unpin(cx)).expect("sleep thread failed"); + Poll::Ready(()) + } + } + + impl Drop for Sleep { + fn drop(&mut self) { + // empty } } } +/// Future for [`sleep`]. +#[cfg(all(target_family = "wasm", target_feature = "atomics"))] +pub use threads::Sleep; + /// Waits until `duration` has elapsed. pub fn sleep(duration: Duration) -> Sleep { Sleep::new(duration)