diff --git a/tokio-util/src/sync/cancellation_token.rs b/tokio-util/src/sync/cancellation_token.rs index 66fbf1a73e7..32086977348 100644 --- a/tokio-util/src/sync/cancellation_token.rs +++ b/tokio-util/src/sync/cancellation_token.rs @@ -287,6 +287,54 @@ impl CancellationToken { } .await } + + /// Runs a future to completion and returns its result wrapped inside of an `Option` + /// unless the `CancellationToken` is cancelled. In that case the function returns + /// `None` and the future gets dropped. + /// + /// The function takes self by value. + /// + /// # Cancel safety + /// + /// This method is only cancel safe if `fut` is cancel safe. + pub async fn run_until_cancelled_owned(self, fut: F) -> Option + where + F: Future, + { + pin_project! { + /// A Future that is resolved once the corresponding [`CancellationToken`] + /// is cancelled or a given Future gets resolved. It is biased towards the + /// Future completion. + #[must_use = "futures do nothing unless polled"] + struct RunUntilCancelledFuture { + #[pin] + cancellation: WaitForCancellationFutureOwned, + #[pin] + future: F, + } + } + + impl Future for RunUntilCancelledFuture { + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + if let Poll::Ready(res) = this.future.poll(cx) { + Poll::Ready(Some(res)) + } else if this.cancellation.poll(cx).is_ready() { + Poll::Ready(None) + } else { + Poll::Pending + } + } + } + + RunUntilCancelledFuture { + cancellation: self.cancelled_owned(), + future: fut, + } + .await + } } // ===== impl WaitForCancellationFuture ===== diff --git a/tokio-util/tests/sync_cancellation_token.rs b/tokio-util/tests/sync_cancellation_token.rs index db33114a2e3..8c751b07d73 100644 --- a/tokio-util/tests/sync_cancellation_token.rs +++ b/tokio-util/tests/sync_cancellation_token.rs @@ -493,3 +493,51 @@ fn run_until_cancelled_test() { ); } } + +#[test] +fn run_until_cancelled_owned_test() { + let (waker, _) = new_count_waker(); + + { + let token = CancellationToken::new(); + let to_cancel = token.clone(); + + let fut = token.run_until_cancelled_owned(std::future::pending::<()>()); + pin!(fut); + + assert_eq!( + Poll::Pending, + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + to_cancel.cancel(); + + assert_eq!( + Poll::Ready(None), + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + } + + { + let (tx, rx) = oneshot::channel::<()>(); + + let token = CancellationToken::new(); + let fut = token.run_until_cancelled_owned(async move { + rx.await.unwrap(); + 42 + }); + pin!(fut); + + assert_eq!( + Poll::Pending, + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + tx.send(()).unwrap(); + + assert_eq!( + Poll::Ready(Some(42)), + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + } +}