diff --git a/async-usercalls/src/lib.rs b/async-usercalls/src/lib.rs index 0c51d29a..ce58b8ad 100644 --- a/async-usercalls/src/lib.rs +++ b/async-usercalls/src/lib.rs @@ -67,10 +67,12 @@ impl AsyncUsercallProvider { let callbacks = Mutex::new(HashMap::new()); let (callback_tx, callback_rx) = mpmc::unbounded(); let provider = Self { core, callback_tx }; + let waker = CallbackHandlerWaker::new(); let handler = CallbackHandler { return_rx, callbacks, callback_rx, + waker, }; (provider, handler) } @@ -91,22 +93,56 @@ impl AsyncUsercallProvider { } } +#[derive(Clone)] +pub struct CallbackHandlerWaker { + rx: mpmc::Receiver<()>, + tx: mpmc::Sender<()>, +} + +impl CallbackHandlerWaker { + fn new() -> Self { + let (tx, rx) = mpmc::bounded(1); + Self { tx, rx } + } + + /// Interrupts the currently running or a future call to the related + /// CallbackHandler's `poll()`. + pub fn wake(&self) { + let _ = self.tx.try_send(()); + } + + /// Clears the effect of a previous call to `self.wake()` that is not yet + /// observed by `CallbackHandler::poll()`. + pub fn clear(&self) { + let _ = self.rx.try_recv(); + } +} + pub struct CallbackHandler { return_rx: mpmc::Receiver>, callbacks: Mutex>, // This is used so that threads sending usercalls don't have to take the lock. callback_rx: mpmc::Receiver<(u64, Callback)>, + waker: CallbackHandlerWaker, } impl CallbackHandler { + // Returns an object that can be used to interrupt a blocked `self.poll()`. + pub fn waker(&self) -> CallbackHandlerWaker { + self.waker.clone() + } + #[inline] fn recv_returns(&self, timeout: Option, returns: &mut [Identified]) -> usize { let first = match timeout { - None => self.return_rx.recv().ok(), - Some(timeout) => match self.return_rx.recv_timeout(timeout) { - Ok(val) => Some(val), - Err(mpmc::RecvTimeoutError::Disconnected) => None, - Err(mpmc::RecvTimeoutError::Timeout) => return 0, + None => mpmc::select! { + recv(self.return_rx) -> res => res.ok(), + recv(self.waker.rx) -> _res => return 0, + }, + Some(timeout) => mpmc::select! { + recv(self.return_rx) -> res => res.ok(), + recv(self.waker.rx) -> _res => return 0, + default(timeout) => return 0, }, } .expect("return channel closed unexpectedly"); @@ -122,6 +158,7 @@ impl CallbackHandler { /// functions. If `timeout` is `None`, it will block execution until at /// least one return is received, otherwise it will block until there is a /// return or timeout is elapsed. Returns the number of executed callbacks. + /// This can be interrupted using `CallbackHandlerWaker::wake()`. pub fn poll(&self, timeout: Option) -> usize { // 1. wait for returns let mut returns = [Identified { diff --git a/async-usercalls/src/tests.rs b/async-usercalls/src/tests.rs index 2bdc473b..ff838c48 100644 --- a/async-usercalls/src/tests.rs +++ b/async-usercalls/src/tests.rs @@ -251,6 +251,25 @@ fn read_buffer_basic() { assert_eq!(&buf, b"hello\0\0\0"); } +#[test] +fn callback_handler_waker() { + let (_provider, handler) = AsyncUsercallProvider::new(); + let waker = handler.waker(); + let (tx, rx) = mpmc::bounded(1); + let h = thread::spawn(move || { + let n1 = handler.poll(None); + tx.send(()).unwrap(); + let n2 = handler.poll(Some(Duration::from_secs(3))); + tx.send(()).unwrap(); + n1 + n2 + }); + for _ in 0..2 { + waker.wake(); + rx.recv().unwrap(); + } + assert_eq!(h.join().unwrap(), 0); +} + #[test] #[ignore] fn echo() {