Skip to content

Commit

Permalink
Add CallbackHandlerWaker
Browse files Browse the repository at this point in the history
  • Loading branch information
mzohreva committed Nov 21, 2020
1 parent 39f1de5 commit e381dca
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 5 deletions.
47 changes: 42 additions & 5 deletions async-usercalls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ impl AsyncUsercallProvider {
let callbacks = Mutex::new(HashMap::new());
let (callback_tx, callback_rx) = mpmc::unbounded();
let provider = Self { core, callback_tx };
let waker = CallbackHandlerWaker::new();
let handler = CallbackHandler {
return_rx,
callbacks,
callback_rx,
waker,
};
(provider, handler)
}
Expand All @@ -91,22 +93,56 @@ impl AsyncUsercallProvider {
}
}

#[derive(Clone)]
pub struct CallbackHandlerWaker {
rx: mpmc::Receiver<()>,
tx: mpmc::Sender<()>,
}

impl CallbackHandlerWaker {
fn new() -> Self {
let (tx, rx) = mpmc::bounded(1);
Self { tx, rx }
}

/// Interrupts the currently running or a future call to the related
/// CallbackHandler's `poll()`.
pub fn wake(&self) {
let _ = self.tx.try_send(());
}

/// Clears the effect of a previous call to `self.wake()` that is not yet
/// observed by `CallbackHandler::poll()`.
pub fn clear(&self) {
let _ = self.rx.try_recv();
}
}

pub struct CallbackHandler {
return_rx: mpmc::Receiver<Identified<Return>>,
callbacks: Mutex<HashMap<u64, Callback>>,
// This is used so that threads sending usercalls don't have to take the lock.
callback_rx: mpmc::Receiver<(u64, Callback)>,
waker: CallbackHandlerWaker,
}

impl CallbackHandler {
// Returns an object that can be used to interrupt a blocked `self.poll()`.
pub fn waker(&self) -> CallbackHandlerWaker {
self.waker.clone()
}

#[inline]
fn recv_returns(&self, timeout: Option<Duration>, returns: &mut [Identified<Return>]) -> usize {
let first = match timeout {
None => self.return_rx.recv().ok(),
Some(timeout) => match self.return_rx.recv_timeout(timeout) {
Ok(val) => Some(val),
Err(mpmc::RecvTimeoutError::Disconnected) => None,
Err(mpmc::RecvTimeoutError::Timeout) => return 0,
None => mpmc::select! {
recv(self.return_rx) -> res => res.ok(),
recv(self.waker.rx) -> _res => return 0,
},
Some(timeout) => mpmc::select! {
recv(self.return_rx) -> res => res.ok(),
recv(self.waker.rx) -> _res => return 0,
default(timeout) => return 0,
},
}
.expect("return channel closed unexpectedly");
Expand All @@ -122,6 +158,7 @@ impl CallbackHandler {
/// functions. If `timeout` is `None`, it will block execution until at
/// least one return is received, otherwise it will block until there is a
/// return or timeout is elapsed. Returns the number of executed callbacks.
/// This can be interrupted using `CallbackHandlerWaker::wake()`.
pub fn poll(&self, timeout: Option<Duration>) -> usize {
// 1. wait for returns
let mut returns = [Identified {
Expand Down
19 changes: 19 additions & 0 deletions async-usercalls/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,25 @@ fn read_buffer_basic() {
assert_eq!(&buf, b"hello\0\0\0");
}

#[test]
fn callback_handler_waker() {
let (_provider, handler) = AsyncUsercallProvider::new();
let waker = handler.waker();
let (tx, rx) = mpmc::bounded(1);
let h = thread::spawn(move || {
let n1 = handler.poll(None);
tx.send(()).unwrap();
let n2 = handler.poll(Some(Duration::from_secs(3)));
tx.send(()).unwrap();
n1 + n2
});
for _ in 0..2 {
waker.wake();
rx.recv().unwrap();
}
assert_eq!(h.join().unwrap(), 0);
}

#[test]
#[ignore]
fn echo() {
Expand Down

0 comments on commit e381dca

Please sign in to comment.