From 4b2648beb0f549eeda967982ec391b1579a485e2 Mon Sep 17 00:00:00 2001 From: v01dstar Date: Thu, 6 Jan 2022 00:11:00 -0800 Subject: [PATCH 1/3] Add test value adjustment APIs Signed-off-by: v01dstar --- src/lib.rs | 135 +++++++++++++++++++++++++++++++++++++++++++++++++ tests/tests.rs | 33 ++++++++++++ 2 files changed, 168 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 7f2bac1..d347c7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -225,6 +225,7 @@ #![deny(missing_docs, missing_debug_implementations)] +use std::any::{Any, TypeId}; use std::collections::HashMap; use std::env::VarError; use std::fmt::Debug; @@ -818,6 +819,140 @@ macro_rules! fail_point { ($name:expr, $cond:expr, $e:expr) => {{}}; } +#[derive(Clone)] +struct SyncCallback1(Arc>); + +impl Debug for SyncCallback1 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("SyncCallback1()") + } +} + +impl PartialEq for SyncCallback1 { + #[allow(clippy::vtable_address_comparisons)] + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl SyncCallback1 { + fn new(f: Box) -> SyncCallback1 { + SyncCallback1(Arc::new(Mutex::new(f))) + } + + fn run(&mut self, var: &mut dyn Any) { + let callback = &mut self.0.lock().unwrap(); + callback(var); + } +} + +struct MapEntry { + type_id: TypeId, + cb: SyncCallback1, +} + +impl MapEntry { + fn new(type_id: TypeId, cb: Box) -> MapEntry { + MapEntry { + type_id: type_id, + cb: SyncCallback1::new(cb), + } + } +} + +lazy_static::lazy_static! { + static ref TESTVALUE_REGISTRY: RwLock>= Default::default(); +} + +/// Set a callback +/// +/// Dummy doc +pub fn set_callback(name: S, mut f: F) -> Result<(), String> +where + S: Into, + T: Any, + F: FnMut(&mut T) + Send + Sync + 'static, +{ + let mut registry = TESTVALUE_REGISTRY.write().unwrap(); + registry.insert( + name.into(), + MapEntry::new( + TypeId::of::(), + Box::new(move |var| { + if let Some(var) = var.downcast_mut::() { + f(var); + } else { + panic!("Type mismtach"); + } + }), + ), + ); + Ok(()) +} + +/// Set a scoped callback using RAII +#[derive(Debug)] +pub struct ScopedCallback { + name: String, +} + +impl ScopedCallback { + /// Creates a RAII instance. + pub fn new(name: S, f: F) -> Self + where + S: Into + Copy, + T: Any, + F: FnMut(&mut T) + Send + Sync + 'static, + { + set_callback(name.clone(), f).unwrap(); + ScopedCallback { name: name.into() } + } +} + +impl Drop for ScopedCallback { + fn drop(&mut self) { + let mut registry = TESTVALUE_REGISTRY.write().unwrap(); + registry.remove(&self.name); + } +} + +#[doc(hidden)] +pub fn internal_adjust(name: S, var: &mut T) +where + S: Into, + T: Clone + 'static, +{ + let mut registry = TESTVALUE_REGISTRY.write().unwrap(); + // Clone the var here, since the argument is required to be 'static. + let mut clone = var.clone(); + if let Some(entry) = registry.get_mut(&name.into()) { + if (*entry).type_id != TypeId::of::() { + panic!("Type mismatch"); + } + (*entry).cb.run(&mut clone); + } + *var = clone; +} + +/// Define a test value adjustment (requires `failpoints` feature). +/// +#[macro_export] +#[cfg(feature = "failpoints")] +macro_rules! adjust { + ($name:expr, $var:expr) => {{ + $crate::internal_adjust($name, $var); + }}; +} + +/// Define a test value adjustment (disabled, see `failpoints` feature). +/// +/// Dummy doc. +#[macro_export] +#[cfg(not(feature = "failpoints"))] +macro_rules! adjust { + ($name:expr, $var:expr) => {{}}; +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/tests.rs b/tests/tests.rs index fecb53c..706541e 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -211,3 +211,36 @@ fn test_list() { fail::cfg("list", "return").unwrap(); assert!(fail::list().contains(&("list".to_string(), "return".to_string()))); } + +#[test] +fn test_value_adjust() { + let f = || -> i32 { + let mut var = 1; + fail::adjust!("adjust_var", &mut var); + var + }; + assert_eq!(f(), 1); + + fail::set_callback("adjust_var", |vari| { + *vari = 2; + }) + .unwrap(); + assert_eq!(f(), 2); +} + +#[test] +fn test_value_adjust_raii() { + let f = || -> i32 { + let mut var = 1; + fail::adjust!("adjust_var1", &mut var); + var + }; + { + let _raii = fail::ScopedCallback::new("adjust_var1", |var| { + *var = 2; + }); + assert_eq!(f(), 2); + } + + assert_eq!(f(), 1); +} From eb581272b73abb2fb531b3b41364191d3bafbdf5 Mon Sep 17 00:00:00 2001 From: v01dstar Date: Thu, 6 Jan 2022 00:22:50 -0800 Subject: [PATCH 2/3] Add docs Signed-off-by: v01dstar --- src/lib.rs | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d347c7e..b7e48f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -864,9 +864,27 @@ lazy_static::lazy_static! { static ref TESTVALUE_REGISTRY: RwLock>= Default::default(); } -/// Set a callback +/// Set the callback for a test value adjustment. +/// +/// Usage: +/// +/// ```rust +/// fn production_code() { +/// ... +/// let mut var = SomeVar(); +/// adjust("adjust_this_var", &mut var); +/// ... +/// } +/// +/// fn test_code() { +/// ... +/// let _raii = ScopedCallback::new("adjust_this_var", |var| { +/// *var = SomeNewValue(); +/// }); +/// ... +/// } +/// ``` /// -/// Dummy doc pub fn set_callback(name: S, mut f: F) -> Result<(), String> where S: Into, @@ -935,7 +953,6 @@ where } /// Define a test value adjustment (requires `failpoints` feature). -/// #[macro_export] #[cfg(feature = "failpoints")] macro_rules! adjust { @@ -945,8 +962,6 @@ macro_rules! adjust { } /// Define a test value adjustment (disabled, see `failpoints` feature). -/// -/// Dummy doc. #[macro_export] #[cfg(not(feature = "failpoints"))] macro_rules! adjust { From 89871054470b4ff4f6f4df9eedfe1c216135d5ba Mon Sep 17 00:00:00 2001 From: v01dstar Date: Tue, 11 Jan 2022 23:46:36 -0800 Subject: [PATCH 3/3] Change function name Signed-off-by: v01dstar --- src/lib.rs | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b7e48f9..aa1c140 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -820,24 +820,24 @@ macro_rules! fail_point { } #[derive(Clone)] -struct SyncCallback1(Arc>); +struct SyncMutCallback1(Arc>); -impl Debug for SyncCallback1 { +impl Debug for SyncMutCallback1 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("SyncCallback1()") + f.write_str("SyncMutCallback1()") } } -impl PartialEq for SyncCallback1 { +impl PartialEq for SyncMutCallback1 { #[allow(clippy::vtable_address_comparisons)] fn eq(&self, other: &Self) -> bool { Arc::ptr_eq(&self.0, &other.0) } } -impl SyncCallback1 { - fn new(f: Box) -> SyncCallback1 { - SyncCallback1(Arc::new(Mutex::new(f))) +impl SyncMutCallback1 { + fn new(f: Box) -> SyncMutCallback1 { + SyncMutCallback1(Arc::new(Mutex::new(f))) } fn run(&mut self, var: &mut dyn Any) { @@ -848,14 +848,14 @@ impl SyncCallback1 { struct MapEntry { type_id: TypeId, - cb: SyncCallback1, + cb: SyncMutCallback1, } impl MapEntry { fn new(type_id: TypeId, cb: Box) -> MapEntry { MapEntry { type_id: type_id, - cb: SyncCallback1::new(cb), + cb: SyncMutCallback1::new(cb), } } } @@ -869,19 +869,17 @@ lazy_static::lazy_static! { /// Usage: /// /// ```rust +/// use fail::{adjust, ScopedCallback}; +/// /// fn production_code() { -/// ... -/// let mut var = SomeVar(); -/// adjust("adjust_this_var", &mut var); -/// ... +/// let mut var = 1; +/// adjust!("adjust_this_var", &mut var); /// } /// /// fn test_code() { -/// ... /// let _raii = ScopedCallback::new("adjust_this_var", |var| { -/// *var = SomeNewValue(); +/// *var = 2; /// }); -/// ... /// } /// ``` ///