Skip to content

Commit

Permalink
use once cell to store registry
Browse files Browse the repository at this point in the history
Signed-off-by: Xinye <[email protected]>
  • Loading branch information
Xinye committed Oct 25, 2023
1 parent eb63946 commit f1a9306
Showing 1 changed file with 51 additions and 22 deletions.
73 changes: 51 additions & 22 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,10 +535,11 @@ struct FailPointRegistry {
registry: RwLock<Registry>,
}

use once_cell::sync::Lazy;
use once_cell::sync::{Lazy, OnceCell};

static REGISTRY: Lazy<FailPointRegistry> = Lazy::new(FailPointRegistry::default);
static SCENARIO: Lazy<Mutex<&'static FailPointRegistry>> = Lazy::new(|| Mutex::new(&REGISTRY));
static REGISTRY: OnceCell<FailPointRegistry> = OnceCell::new();
static SCENARIO: Lazy<Mutex<&'static FailPointRegistry>> =
Lazy::new(|| Mutex::new(REGISTRY.get_or_init(Default::default)));

/// Test scenario with configured fail points.
#[derive(Debug)]
Expand Down Expand Up @@ -636,7 +637,11 @@ pub const fn has_failpoints() -> bool {
///
/// Return a vector of `(name, actions)` pairs.
pub fn list() -> Vec<(String, String)> {
let registry = REGISTRY.registry.read().unwrap();
let registry = if let Some(r) = REGISTRY.get() {
r.registry.read().unwrap()
} else {
return Vec::new();
};
registry
.iter()
.map(|(name, fp)| (name.to_string(), fp.actions_str.read().unwrap().clone()))
Expand All @@ -645,8 +650,13 @@ pub fn list() -> Vec<(String, String)> {

#[doc(hidden)]
pub fn eval<R, F: FnOnce(Option<String>) -> R>(name: &str, f: F) -> Option<R> {
let registry = if let Some(r) = REGISTRY.get() {
&r.registry
} else {
return None;
};
let p = {
let registry = REGISTRY.registry.read().unwrap();
let registry = registry.read().unwrap();
match registry.get(name) {
None => return None,
Some(p) => p.clone(),
Expand Down Expand Up @@ -686,7 +696,11 @@ pub fn eval<R, F: FnOnce(Option<String>) -> R>(name: &str, f: F) -> Option<R> {
/// A call to `cfg` with a particular fail point name overwrites any existing actions for
/// that fail point, including those set via the `FAILPOINTS` environment variable.
pub fn cfg<S: Into<String>>(name: S, actions: &str) -> Result<(), String> {
let mut registry = REGISTRY.registry.write().unwrap();
let mut registry = REGISTRY
.get_or_init(Default::default)
.registry
.write()
.unwrap();
set(&mut registry, name.into(), actions)
}

Expand All @@ -699,7 +713,11 @@ where
S: Into<String>,
F: Fn() + Send + Sync + 'static,
{
let mut registry = REGISTRY.registry.write().unwrap();
let mut registry = REGISTRY
.get_or_init(Default::default)
.registry
.write()
.unwrap();
let p = registry
.entry(name.into())
.or_insert_with(|| Arc::new(FailPoint::new()));
Expand All @@ -713,7 +731,11 @@ where
///
/// If the fail point doesn't exist, nothing will happen.
pub fn remove<S: AsRef<str>>(name: S) {
let mut registry = REGISTRY.registry.write().unwrap();
let mut registry = if let Some(r) = REGISTRY.get() {
r.registry.write().unwrap()
} else {
return;
};
if let Some(p) = registry.remove(name.as_ref()) {
// wake up all pause failpoint.
p.set_actions("", vec![]);
Expand Down Expand Up @@ -937,7 +959,11 @@ mod async_imp {
S: Into<String>,
F: Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static,
{
let mut registry = REGISTRY.registry.write().unwrap();
let mut registry = REGISTRY
.get_or_init(Default::default)
.registry
.write()
.unwrap();
let p = registry
.entry(name.into())
.or_insert_with(|| Arc::new(FailPoint::new()));
Expand All @@ -949,8 +975,13 @@ mod async_imp {

#[doc(hidden)]
pub async fn async_eval<R, F: FnOnce(Option<String>) -> R>(name: &str, f: F) -> Option<R> {
let registry = if let Some(r) = REGISTRY.get() {
&r.registry
} else {
return None;
};
let p = {
let registry = REGISTRY.registry.read().unwrap();
let registry = registry.read().unwrap();
match registry.get(name) {
None => return None,
Some(p) => p.clone(),
Expand Down Expand Up @@ -1017,7 +1048,7 @@ mod async_imp {
},
Task::Pause => unreachable!(),
Task::Yield => thread::yield_now(),
Task::Delay(_) => {
Task::Delay(t) => {
let timer = Instant::now();
let timeout = Duration::from_millis(t);
while timer.elapsed() < timeout {}
Expand Down Expand Up @@ -1251,19 +1282,17 @@ mod tests {
#[cfg(feature = "async")]
#[cfg_attr(not(feature = "failpoints"), ignore)]
#[tokio::test]
async fn test_async_failpoint() {
use std::time::Duration;

async fn test_async_failpoints() {
let f1 = async {
async_fail_point!("cb");
async_fail_point!("async_cb");
};
let f2 = async {
async_fail_point!("cb");
async_fail_point!("async_cb");
};

let counter = Arc::new(AtomicUsize::new(0));
let counter2 = counter.clone();
cfg_async_callback("cb", move || {
cfg_async_callback("async_cb", move || {
counter2.fetch_add(1, Ordering::SeqCst);
Box::pin(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
Expand All @@ -1274,26 +1303,26 @@ mod tests {
f2.await;
assert_eq!(2, counter.load(Ordering::SeqCst));

cfg("pause", "pause").unwrap();
cfg("async_pause", "pause").unwrap();
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let handle = tokio::spawn(async move {
async_fail_point!("pause");
async_fail_point!("async_pause");
tx.send(()).await.unwrap();
});
tokio::time::timeout(Duration::from_millis(500), rx.recv())
.await
.unwrap_err();
remove("pause");
remove("async_pause");
tokio::time::timeout(Duration::from_millis(500), rx.recv())
.await
.unwrap();
handle.await.unwrap();

cfg("sleep", "sleep(500)").unwrap();
cfg("async_sleep", "sleep(500)").unwrap();
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let handle = tokio::spawn(async move {
tx.send(()).await.unwrap();
async_fail_point!("sleep");
async_fail_point!("async_sleep");
tx.send(()).await.unwrap();
});
rx.recv().await.unwrap();
Expand Down

0 comments on commit f1a9306

Please sign in to comment.