Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test value adjustment APIs #60

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@

#![deny(missing_docs, missing_debug_implementations)]

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::env::VarError;
use std::fmt::Debug;
Expand Down Expand Up @@ -818,6 +819,153 @@ macro_rules! fail_point {
($name:expr, $cond:expr, $e:expr) => {{}};
}

#[derive(Clone)]
struct SyncMutCallback1(Arc<Mutex<dyn FnMut(&mut dyn Any) + Send + Sync + 'static>>);
v01dstar marked this conversation as resolved.
Show resolved Hide resolved

impl Debug for SyncMutCallback1 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("SyncMutCallback1()")
}
}

impl PartialEq for SyncMutCallback1 {
#[allow(clippy::vtable_address_comparisons)]
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}

impl SyncMutCallback1 {
fn new(f: Box<dyn FnMut(&mut dyn Any) + Send + Sync>) -> SyncMutCallback1 {
SyncMutCallback1(Arc::new(Mutex::new(f)))
}

fn run(&mut self, var: &mut dyn Any) {
let callback = &mut self.0.lock().unwrap();
callback(var);
}
}

struct MapEntry {
type_id: TypeId,
cb: SyncMutCallback1,
}

impl MapEntry {
fn new(type_id: TypeId, cb: Box<dyn FnMut(&mut dyn Any) + Send + Sync>) -> MapEntry {
MapEntry {
type_id: type_id,
cb: SyncMutCallback1::new(cb),
}
}
}

lazy_static::lazy_static! {
static ref TESTVALUE_REGISTRY: RwLock<HashMap<String, MapEntry>>= Default::default();
}

/// Set the callback for a test value adjustment.
///
/// Usage:
///
/// ```rust
/// use fail::{adjust, ScopedCallback};
///
/// fn production_code() {
/// let mut var = 1;
/// adjust!("adjust_this_var", &mut var);
/// }
///
/// fn test_code() {
/// let _raii = ScopedCallback::new("adjust_this_var", |var| {
/// *var = 2;
/// });
/// }
/// ```
///
pub fn set_callback<S, T, F>(name: S, mut f: F) -> Result<(), String>
where
S: Into<String>,
T: Any,
F: FnMut(&mut T) + Send + Sync + 'static,
{
let mut registry = TESTVALUE_REGISTRY.write().unwrap();
registry.insert(
name.into(),
MapEntry::new(
TypeId::of::<T>(),
Box::new(move |var| {
if let Some(var) = var.downcast_mut::<T>() {
f(var);
} else {
panic!("Type mismtach");
}
}),
),
);
Ok(())
}

/// Set a scoped callback using RAII
#[derive(Debug)]
pub struct ScopedCallback {
name: String,
}

impl ScopedCallback {
/// Creates a RAII instance.
pub fn new<S, T, F>(name: S, f: F) -> Self
where
S: Into<String> + Copy,
T: Any,
F: FnMut(&mut T) + Send + Sync + 'static,
{
set_callback(name.clone(), f).unwrap();
ScopedCallback { name: name.into() }
}
}

impl Drop for ScopedCallback {
fn drop(&mut self) {
let mut registry = TESTVALUE_REGISTRY.write().unwrap();
registry.remove(&self.name);
}
}

#[doc(hidden)]
pub fn internal_adjust<S, T>(name: S, var: &mut T)
where
S: Into<String>,
T: Clone + 'static,
{
let mut registry = TESTVALUE_REGISTRY.write().unwrap();
// Clone the var here, since the argument is required to be 'static.
let mut clone = var.clone();
if let Some(entry) = registry.get_mut(&name.into()) {
if (*entry).type_id != TypeId::of::<T>() {
panic!("Type mismatch");
}
(*entry).cb.run(&mut clone);
}
*var = clone;
}

/// Define a test value adjustment (requires `failpoints` feature).
#[macro_export]
#[cfg(feature = "failpoints")]
macro_rules! adjust {
($name:expr, $var:expr) => {{
$crate::internal_adjust($name, $var);
}};
}

/// Define a test value adjustment (disabled, see `failpoints` feature).
#[macro_export]
#[cfg(not(feature = "failpoints"))]
macro_rules! adjust {
($name:expr, $var:expr) => {{}};
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
33 changes: 33 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,36 @@ fn test_list() {
fail::cfg("list", "return").unwrap();
assert!(fail::list().contains(&("list".to_string(), "return".to_string())));
}

#[test]
fn test_value_adjust() {
v01dstar marked this conversation as resolved.
Show resolved Hide resolved
let f = || -> i32 {
let mut var = 1;
fail::adjust!("adjust_var", &mut var);
var
};
assert_eq!(f(), 1);

fail::set_callback("adjust_var", |vari| {
*vari = 2;
})
.unwrap();
assert_eq!(f(), 2);
}

#[test]
fn test_value_adjust_raii() {
let f = || -> i32 {
let mut var = 1;
fail::adjust!("adjust_var1", &mut var);
var
};
{
let _raii = fail::ScopedCallback::new("adjust_var1", |var| {
*var = 2;
});
assert_eq!(f(), 2);
}

assert_eq!(f(), 1);
}