diff --git a/src/lib.rs b/src/lib.rs index 7f2bac1..aa1c140 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -225,6 +225,7 @@ #![deny(missing_docs, missing_debug_implementations)] +use std::any::{Any, TypeId}; use std::collections::HashMap; use std::env::VarError; use std::fmt::Debug; @@ -818,6 +819,153 @@ macro_rules! fail_point { ($name:expr, $cond:expr, $e:expr) => {{}}; } +#[derive(Clone)] +struct SyncMutCallback1(Arc>); + +impl Debug for SyncMutCallback1 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("SyncMutCallback1()") + } +} + +impl PartialEq for SyncMutCallback1 { + #[allow(clippy::vtable_address_comparisons)] + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl SyncMutCallback1 { + fn new(f: Box) -> SyncMutCallback1 { + SyncMutCallback1(Arc::new(Mutex::new(f))) + } + + fn run(&mut self, var: &mut dyn Any) { + let callback = &mut self.0.lock().unwrap(); + callback(var); + } +} + +struct MapEntry { + type_id: TypeId, + cb: SyncMutCallback1, +} + +impl MapEntry { + fn new(type_id: TypeId, cb: Box) -> MapEntry { + MapEntry { + type_id: type_id, + cb: SyncMutCallback1::new(cb), + } + } +} + +lazy_static::lazy_static! { + static ref TESTVALUE_REGISTRY: RwLock>= Default::default(); +} + +/// Set the callback for a test value adjustment. +/// +/// Usage: +/// +/// ```rust +/// use fail::{adjust, ScopedCallback}; +/// +/// fn production_code() { +/// let mut var = 1; +/// adjust!("adjust_this_var", &mut var); +/// } +/// +/// fn test_code() { +/// let _raii = ScopedCallback::new("adjust_this_var", |var| { +/// *var = 2; +/// }); +/// } +/// ``` +/// +pub fn set_callback(name: S, mut f: F) -> Result<(), String> +where + S: Into, + T: Any, + F: FnMut(&mut T) + Send + Sync + 'static, +{ + let mut registry = TESTVALUE_REGISTRY.write().unwrap(); + registry.insert( + name.into(), + MapEntry::new( + TypeId::of::(), + Box::new(move |var| { + if let Some(var) = var.downcast_mut::() { + f(var); + } else { + panic!("Type mismtach"); + } + }), + ), + ); + Ok(()) +} + +/// Set a scoped callback using RAII +#[derive(Debug)] +pub struct ScopedCallback { + name: String, +} + +impl ScopedCallback { + /// Creates a RAII instance. + pub fn new(name: S, f: F) -> Self + where + S: Into + Copy, + T: Any, + F: FnMut(&mut T) + Send + Sync + 'static, + { + set_callback(name.clone(), f).unwrap(); + ScopedCallback { name: name.into() } + } +} + +impl Drop for ScopedCallback { + fn drop(&mut self) { + let mut registry = TESTVALUE_REGISTRY.write().unwrap(); + registry.remove(&self.name); + } +} + +#[doc(hidden)] +pub fn internal_adjust(name: S, var: &mut T) +where + S: Into, + T: Clone + 'static, +{ + let mut registry = TESTVALUE_REGISTRY.write().unwrap(); + // Clone the var here, since the argument is required to be 'static. + let mut clone = var.clone(); + if let Some(entry) = registry.get_mut(&name.into()) { + if (*entry).type_id != TypeId::of::() { + panic!("Type mismatch"); + } + (*entry).cb.run(&mut clone); + } + *var = clone; +} + +/// Define a test value adjustment (requires `failpoints` feature). +#[macro_export] +#[cfg(feature = "failpoints")] +macro_rules! adjust { + ($name:expr, $var:expr) => {{ + $crate::internal_adjust($name, $var); + }}; +} + +/// Define a test value adjustment (disabled, see `failpoints` feature). +#[macro_export] +#[cfg(not(feature = "failpoints"))] +macro_rules! adjust { + ($name:expr, $var:expr) => {{}}; +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/tests.rs b/tests/tests.rs index fecb53c..706541e 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -211,3 +211,36 @@ fn test_list() { fail::cfg("list", "return").unwrap(); assert!(fail::list().contains(&("list".to_string(), "return".to_string()))); } + +#[test] +fn test_value_adjust() { + let f = || -> i32 { + let mut var = 1; + fail::adjust!("adjust_var", &mut var); + var + }; + assert_eq!(f(), 1); + + fail::set_callback("adjust_var", |vari| { + *vari = 2; + }) + .unwrap(); + assert_eq!(f(), 2); +} + +#[test] +fn test_value_adjust_raii() { + let f = || -> i32 { + let mut var = 1; + fail::adjust!("adjust_var1", &mut var); + var + }; + { + let _raii = fail::ScopedCallback::new("adjust_var1", |var| { + *var = 2; + }); + assert_eq!(f(), 2); + } + + assert_eq!(f(), 1); +}