Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test value adjustment APIs #60

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@

#![deny(missing_docs, missing_debug_implementations)]

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::env::VarError;
use std::fmt::Debug;
Expand Down Expand Up @@ -818,6 +819,140 @@ macro_rules! fail_point {
($name:expr, $cond:expr, $e:expr) => {{}};
}

#[derive(Clone)]
struct SyncCallback1(Arc<Mutex<dyn FnMut(&mut dyn Any) + Send + Sync + 'static>>);
v01dstar marked this conversation as resolved.
Show resolved Hide resolved

impl Debug for SyncCallback1 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("SyncCallback1()")
}
}

impl PartialEq for SyncCallback1 {
#[allow(clippy::vtable_address_comparisons)]
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}

impl SyncCallback1 {
fn new(f: Box<dyn FnMut(&mut dyn Any) + Send + Sync>) -> SyncCallback1 {
SyncCallback1(Arc::new(Mutex::new(f)))
}

fn run(&mut self, var: &mut dyn Any) {
let callback = &mut self.0.lock().unwrap();
callback(var);
}
}

struct MapEntry {
type_id: TypeId,
cb: SyncCallback1,
}

impl MapEntry {
fn new(type_id: TypeId, cb: Box<dyn FnMut(&mut dyn Any) + Send + Sync>) -> MapEntry {
MapEntry {
type_id: type_id,
cb: SyncCallback1::new(cb),
}
}
}

lazy_static::lazy_static! {
static ref TESTVALUE_REGISTRY: RwLock<HashMap<String, MapEntry>>= Default::default();
}

/// Set a callback
///
/// Dummy doc
pub fn set_callback<S, T, F>(name: S, mut f: F) -> Result<(), String>
where
S: Into<String>,
T: Any,
F: FnMut(&mut T) + Send + Sync + 'static,
{
let mut registry = TESTVALUE_REGISTRY.write().unwrap();
registry.insert(
name.into(),
MapEntry::new(
TypeId::of::<T>(),
Box::new(move |var| {
if let Some(var) = var.downcast_mut::<T>() {
f(var);
} else {
panic!("Type mismtach");
}
}),
),
);
Ok(())
}

/// Set a scoped callback using RAII
#[derive(Debug)]
pub struct ScopedCallback {
name: String,
}

impl ScopedCallback {
/// Creates a RAII instance.
pub fn new<S, T, F>(name: S, f: F) -> Self
where
S: Into<String> + Copy,
T: Any,
F: FnMut(&mut T) + Send + Sync + 'static,
{
set_callback(name.clone(), f).unwrap();
ScopedCallback { name: name.into() }
}
}

impl Drop for ScopedCallback {
fn drop(&mut self) {
let mut registry = TESTVALUE_REGISTRY.write().unwrap();
registry.remove(&self.name);
}
}

#[doc(hidden)]
pub fn internal_adjust<S, T>(name: S, var: &mut T)
where
S: Into<String>,
T: Clone + 'static,
{
let mut registry = TESTVALUE_REGISTRY.write().unwrap();
// Clone the var here, since the argument is required to be 'static.
let mut clone = var.clone();
if let Some(entry) = registry.get_mut(&name.into()) {
if (*entry).type_id != TypeId::of::<T>() {
panic!("Type mismatch");
}
(*entry).cb.run(&mut clone);
}
*var = clone;
}

/// Define a test value adjustment (requires `failpoints` feature).
///
#[macro_export]
#[cfg(feature = "failpoints")]
macro_rules! adjust {
($name:expr, $var:expr) => {{
$crate::internal_adjust($name, $var);
}};
}

/// Define a test value adjustment (disabled, see `failpoints` feature).
///
/// Dummy doc.
#[macro_export]
#[cfg(not(feature = "failpoints"))]
macro_rules! adjust {
($name:expr, $var:expr) => {{}};
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
33 changes: 33 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,36 @@ fn test_list() {
fail::cfg("list", "return").unwrap();
assert!(fail::list().contains(&("list".to_string(), "return".to_string())));
}

#[test]
fn test_value_adjust() {
v01dstar marked this conversation as resolved.
Show resolved Hide resolved
let f = || -> i32 {
let mut var = 1;
fail::adjust!("adjust_var", &mut var);
var
};
assert_eq!(f(), 1);

fail::set_callback("adjust_var", |vari| {
*vari = 2;
})
.unwrap();
assert_eq!(f(), 2);
}

#[test]
fn test_value_adjust_raii() {
let f = || -> i32 {
let mut var = 1;
fail::adjust!("adjust_var1", &mut var);
var
};
{
let _raii = fail::ScopedCallback::new("adjust_var1", |var| {
*var = 2;
});
assert_eq!(f(), 2);
}

assert_eq!(f(), 1);
}