diff --git a/Cargo.toml b/Cargo.toml index 52244e8..57f2ef2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,14 @@ exclude = ["/.github/*", "/.travis.yml", "/appveyor.yml"] log = { version = "0.4", features = ["std"] } once_cell = "1.9.0" rand = "0.8" +tokio = { version = "1.32", features = ["sync"], optional = true } + +[dev-dependencies] +tokio = { version = "1.32", features = ["sync", "rt-multi-thread", "time", "macros"] } [features] failpoints = [] +async = ["tokio"] [package.metadata.docs.rs] all-features = true diff --git a/src/lib.rs b/src/lib.rs index f23cc44..cf270fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -282,6 +282,8 @@ enum Task { Delay(u64), /// Call callback function. Callback(SyncCallback), + #[cfg(feature = "async")] + CallbackAsync(async_imp::AsyncCallback), } #[derive(Debug)] @@ -433,6 +435,8 @@ impl FromStr for Action { struct FailPoint { pause: Mutex, pause_notifier: Condvar, + #[cfg(feature = "async")] + async_pause_notify: tokio::sync::Notify, actions: RwLock>, actions_str: RwLock, } @@ -443,6 +447,8 @@ impl FailPoint { FailPoint { pause: Mutex::new(false), pause_notifier: Condvar::new(), + #[cfg(feature = "async")] + async_pause_notify: tokio::sync::Notify::new(), actions: RwLock::default(), actions_str: RwLock::default(), } @@ -450,6 +456,8 @@ impl FailPoint { fn set_actions(&self, actions_str: &str, actions: Vec) { loop { + #[cfg(feature = "async")] + self.async_pause_notify.notify_waiters(); // TODO: maybe busy waiting here. match self.actions.try_write() { Err(TryLockError::WouldBlock) => {} @@ -509,6 +517,10 @@ impl FailPoint { Task::Callback(f) => { f.run(); } + #[cfg(feature = "async")] + Task::CallbackAsync(_) => panic!( + "to use async callback, please enable `async` feature and use `async_fail_point`" + ), } None } @@ -852,6 +864,179 @@ macro_rules! fail_point { ($name:expr, $cond:expr, $e:expr) => {{}}; } +#[cfg(feature = "async")] +mod async_imp { + use super::*; + type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; + + #[derive(Clone)] + pub(crate) struct AsyncCallback( + Arc BoxFuture<'static, ()> + Send + Sync + 'static>, + ); + + impl Debug for AsyncCallback { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("AsyncCallback()") + } + } + + impl PartialEq for AsyncCallback { + #[allow(clippy::vtable_address_comparisons)] + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } + } + + impl AsyncCallback { + fn new(f: impl Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static) -> AsyncCallback { + AsyncCallback(Arc::new(f)) + } + + async fn run(&self) { + let callback = &self.0; + callback().await; + } + } + + /// `fail_point` but with support for async callback and pause. + #[macro_export] + #[cfg(feature = "failpoints")] + macro_rules! async_fail_point { + ($name:expr) => {{ + $crate::async_eval($name, |_| { + panic!("Return is not supported for the fail point \"{}\"", $name); + }) + .await; + }}; + ($name:expr, $e:expr) => {{ + if let Some(res) = $crate::async_eval($name, $e).await { + return res; + } + }}; + ($name:expr, $cond:expr, $e:expr) => {{ + if $cond { + $crate::async_fail_point!($name, $e); + } + }}; + } + + /// Define an async fail point (disabled, see `failpoints` feature). + #[macro_export] + #[cfg(not(feature = "failpoints"))] + macro_rules! async_fail_point { + ($name:expr, $e:expr) => {{}}; + ($name:expr) => {{}}; + ($name:expr, $cond:expr, $e:expr) => {{}}; + } + + /// Configures an async callback to be triggered at the specified + /// failpoint. If the failpoint is not implemented using + /// `async_fail_point`, the execution will raise an exception. + pub fn cfg_async_callback(name: S, f: F) -> Result<(), String> + where + S: Into, + F: Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static, + { + let mut registry = REGISTRY.registry.write().unwrap(); + let p = registry + .entry(name.into()) + .or_insert_with(|| Arc::new(FailPoint::new())); + let action = Action::from_async_callback(f); + let actions = vec![action]; + p.set_actions("callback", actions); + Ok(()) + } + + #[doc(hidden)] + pub async fn async_eval) -> R>(name: &str, f: F) -> Option { + let p = { + let registry = REGISTRY.registry.read().unwrap(); + match registry.get(name) { + None => return None, + Some(p) => p.clone(), + } + }; + p.async_eval(name).await.map(f) + } + + impl Action { + fn from_async_callback( + f: impl Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static, + ) -> Action { + let task = Task::CallbackAsync(AsyncCallback::new(f)); + Action { + task, + freq: 1.0, + count: None, + } + } + } + + impl FailPoint { + #[cfg_attr(feature = "cargo-clippy", allow(clippy::option_option))] + async fn async_eval(&self, name: &str) -> Option> { + let task = { + let task = self + .actions + .read() + .unwrap() + .iter() + .filter_map(Action::get_task) + .next(); + match task { + Some(Task::Pause) => { + // let n = self.async_pause_notify.clone(); + self.async_pause_notify.notified().await; + return None; + } + Some(t) => t, + None => return None, + } + }; + + match task { + Task::Off => {} + Task::Return(s) => return Some(s), + Task::Sleep(t) => { + let not = Arc::new(tokio::sync::Notify::new()); + let not_for_thread = not.clone(); + let handle = std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(t)); + not_for_thread.notify_waiters(); + }); + not.notified().await; + handle.join().unwrap(); + } + Task::Panic(msg) => match msg { + Some(ref msg) => panic!("{}", msg), + None => panic!("failpoint {} panic", name), + }, + Task::Print(msg) => match msg { + Some(ref msg) => log::info!("{}", msg), + None => log::info!("failpoint {} executed.", name), + }, + Task::Pause => unreachable!(), + Task::Yield => thread::yield_now(), + Task::Delay(t) => { + let timer = Instant::now(); + let timeout = Duration::from_millis(t); + while timer.elapsed() < timeout {} + } + Task::Callback(f) => { + f.run(); + } + Task::CallbackAsync(f) => { + f.run().await; + } + } + None + } + } +} + +#[cfg(feature = "async")] +pub use async_imp::*; + #[cfg(test)] mod tests { use super::*; @@ -1062,4 +1247,60 @@ mod tests { assert_eq!(rx.recv_timeout(Duration::from_millis(500)).unwrap(), 0); assert_eq!(f1(), 0); } + + #[cfg(feature = "async")] + #[cfg_attr(not(feature = "failpoints"), ignore)] + #[tokio::test] + async fn test_async_failpoints() { + let f1 = async { + async_fail_point!("async_cb"); + }; + let f2 = async { + async_fail_point!("async_cb"); + }; + + let counter = Arc::new(AtomicUsize::new(0)); + let counter2 = counter.clone(); + cfg_async_callback("async_cb", move || { + counter2.fetch_add(1, Ordering::SeqCst); + Box::pin(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + }) + }) + .unwrap(); + f1.await; + f2.await; + assert_eq!(2, counter.load(Ordering::SeqCst)); + + cfg("async_pause", "pause").unwrap(); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let handle = tokio::spawn(async move { + async_fail_point!("async_pause"); + tx.send(()).await.unwrap(); + }); + tokio::time::timeout(Duration::from_millis(500), rx.recv()) + .await + .unwrap_err(); + remove("async_pause"); + tokio::time::timeout(Duration::from_millis(500), rx.recv()) + .await + .unwrap(); + handle.await.unwrap(); + + cfg("async_sleep", "sleep(500)").unwrap(); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let handle = tokio::spawn(async move { + tx.send(()).await.unwrap(); + async_fail_point!("async_sleep"); + tx.send(()).await.unwrap(); + }); + rx.recv().await.unwrap(); + tokio::time::timeout(Duration::from_millis(300), rx.recv()) + .await + .unwrap_err(); + tokio::time::timeout(Duration::from_millis(300), rx.recv()) + .await + .unwrap(); + handle.await.unwrap(); + } }