-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(ampd): retry event handling with timeout and max attemmpts (#251)
* feat(ampd): retry event handling with timeout and max attemmpts * add tests * address comments * improve tests * address comments * improve comments --------- Co-authored-by: Sammy Liu <[email protected]>
- Loading branch information
1 parent
e7a2301
commit 96eadd5
Showing
4 changed files
with
215 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
use futures::{Future, FutureExt}; | ||
use std::pin::Pin; | ||
use std::task::{Context, Poll}; | ||
use std::time::Duration; | ||
|
||
use tokio::time; | ||
|
||
pub fn with_retry<F, Fut, R, Err>( | ||
future: F, | ||
policy: RetryPolicy, | ||
) -> impl Future<Output = Result<R, Err>> | ||
where | ||
F: Fn() -> Fut, | ||
Fut: Future<Output = Result<R, Err>>, | ||
{ | ||
RetriableFuture::new(future, policy) | ||
} | ||
|
||
pub enum RetryPolicy { | ||
RepeatConstant { sleep: Duration, max_attempts: u64 }, | ||
} | ||
|
||
struct RetriableFuture<F, Fut, R, Err> | ||
where | ||
F: Fn() -> Fut, | ||
Fut: Future<Output = Result<R, Err>>, | ||
{ | ||
future: F, | ||
inner: Pin<Box<Fut>>, | ||
policy: RetryPolicy, | ||
err_count: u64, | ||
} | ||
|
||
impl<F, Fut, R, Err> Unpin for RetriableFuture<F, Fut, R, Err> | ||
where | ||
F: Fn() -> Fut, | ||
Fut: Future<Output = Result<R, Err>>, | ||
{ | ||
} | ||
|
||
impl<F, Fut, R, Err> RetriableFuture<F, Fut, R, Err> | ||
where | ||
F: Fn() -> Fut, | ||
Fut: Future<Output = Result<R, Err>>, | ||
{ | ||
fn new(get_future: F, policy: RetryPolicy) -> Self { | ||
let future = get_future(); | ||
|
||
Self { | ||
future: get_future, | ||
inner: Box::pin(future), | ||
policy, | ||
err_count: 0, | ||
} | ||
} | ||
|
||
fn handle_err( | ||
mut self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
error: Err, | ||
) -> Poll<Result<R, Err>> { | ||
self.err_count += 1; | ||
|
||
match self.policy { | ||
RetryPolicy::RepeatConstant { | ||
sleep, | ||
max_attempts, | ||
} => { | ||
if self.err_count >= max_attempts { | ||
return Poll::Ready(Err(error)); | ||
} | ||
|
||
self.inner = Box::pin((self.future)()); | ||
|
||
let waker = cx.waker().clone(); | ||
tokio::spawn(time::sleep(sleep).then(|_| async { | ||
waker.wake(); | ||
})); | ||
|
||
Poll::Pending | ||
} | ||
} | ||
} | ||
} | ||
|
||
impl<F, Fut, R, Err> Future for RetriableFuture<F, Fut, R, Err> | ||
where | ||
F: Fn() -> Fut, | ||
Fut: Future<Output = Result<R, Err>>, | ||
{ | ||
type Output = Result<R, Err>; | ||
|
||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||
match self.inner.as_mut().poll(cx) { | ||
Poll::Pending => Poll::Pending, | ||
Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)), | ||
Poll::Ready(Err(error)) => self.handle_err(cx, error), | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use std::{future, sync::Mutex}; | ||
|
||
use tokio::time::Instant; | ||
|
||
use super::*; | ||
|
||
#[tokio::test] | ||
async fn should_return_ok_when_the_internal_future_returns_ok_immediately() { | ||
let fut = with_retry( | ||
|| future::ready(Ok::<(), ()>(())), | ||
RetryPolicy::RepeatConstant { | ||
sleep: Duration::from_secs(1), | ||
max_attempts: 3, | ||
}, | ||
); | ||
let start = Instant::now(); | ||
|
||
assert!(fut.await.is_ok()); | ||
assert!(start.elapsed() < Duration::from_secs(1)); | ||
} | ||
|
||
#[tokio::test(start_paused = true)] | ||
async fn should_return_ok_when_the_internal_future_returns_ok_eventually() { | ||
let max_attempts = 3; | ||
let count = Mutex::new(0); | ||
let fut = with_retry( | ||
|| async { | ||
*count.lock().unwrap() += 1; | ||
time::sleep(Duration::from_secs(1)).await; | ||
|
||
if *count.lock().unwrap() < max_attempts - 1 { | ||
Err::<(), ()>(()) | ||
} else { | ||
Ok::<(), ()>(()) | ||
} | ||
}, | ||
RetryPolicy::RepeatConstant { | ||
sleep: Duration::from_secs(1), | ||
max_attempts, | ||
}, | ||
); | ||
let start = Instant::now(); | ||
|
||
assert!(fut.await.is_ok()); | ||
assert!(start.elapsed() >= Duration::from_secs(3)); | ||
assert!(start.elapsed() < Duration::from_secs(4)); | ||
} | ||
|
||
#[tokio::test(start_paused = true)] | ||
async fn should_return_error_when_the_internal_future_returns_error_after_max_attempts() { | ||
let fut = with_retry( | ||
|| future::ready(Err::<(), ()>(())), | ||
RetryPolicy::RepeatConstant { | ||
sleep: Duration::from_secs(1), | ||
max_attempts: 3, | ||
}, | ||
); | ||
let start = Instant::now(); | ||
|
||
assert!(fut.await.is_err()); | ||
assert!(start.elapsed() >= Duration::from_secs(2)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
pub mod future; | ||
pub mod task; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters