diff --git a/tokio/src/sync/oneshot.rs b/tokio/src/sync/oneshot.rs index 9e8c3fcb7f7..ab29b3e3edd 100644 --- a/tokio/src/sync/oneshot.rs +++ b/tokio/src/sync/oneshot.rs @@ -1072,7 +1072,14 @@ impl Receiver { impl Drop for Receiver { fn drop(&mut self) { if let Some(inner) = self.inner.as_ref() { - inner.close(); + let state = inner.close(); + + if state.is_complete() { + // SAFETY: we have ensured that the `VALUE_SENT` bit has been set, + // so only the receiver can access the value. + drop(unsafe { inner.consume_value() }); + } + #[cfg(all(tokio_unstable, feature = "tracing"))] self.resource_span.in_scope(|| { tracing::trace!( @@ -1202,7 +1209,7 @@ impl Inner { } /// Called by `Receiver` to indicate that the value will never be received. - fn close(&self) { + fn close(&self) -> State { let prev = State::set_closed(&self.state); if prev.is_tx_task_set() && !prev.is_complete() { @@ -1210,6 +1217,8 @@ impl Inner { self.tx_task.with_task(Waker::wake_by_ref); } } + + prev } /// Consumes the value. This function does not check `state`. @@ -1248,6 +1257,15 @@ impl Drop for Inner { self.tx_task.drop_task(); } } + + // SAFETY: we have `&mut self`, and therefore we have + // exclusive access to the value. + unsafe { + // Note: the assertion holds because if the value has been sent by sender, + // we must ensure that the value must have been consumed by the receiver before + // dropping the `Inner`. + debug_assert!(self.consume_value().is_none()); + } } } diff --git a/tokio/src/sync/tests/loom_oneshot.rs b/tokio/src/sync/tests/loom_oneshot.rs index c5f79720794..717edcfd2a3 100644 --- a/tokio/src/sync/tests/loom_oneshot.rs +++ b/tokio/src/sync/tests/loom_oneshot.rs @@ -138,3 +138,50 @@ fn changing_tx_task() { } }); } + +#[test] +fn checking_tx_send_ok_not_drop() { + use std::borrow::Borrow; + use std::cell::Cell; + + loom::thread_local! { + static IS_RX: Cell = Cell::new(true); + } + + struct Msg; + + impl Drop for Msg { + fn drop(&mut self) { + IS_RX.with(|is_rx: &Cell<_>| { + // On `tx.send(msg)` returning `Err(msg)`, + // we call `std::mem::forget(msg)`, so that + // `drop` is not expected to be called in the + // tx thread. + assert!(is_rx.get()); + }); + } + } + + let mut builder = loom::model::Builder::new(); + builder.preemption_bound = Some(2); + + builder.check(|| { + let (tx, rx) = oneshot::channel(); + + // tx thread + let tx_thread_join_handle = thread::spawn(move || { + // Ensure that `Msg::drop` in this thread will see is_rx == false + IS_RX.with(|is_rx: &Cell<_>| { + is_rx.set(false); + }); + if let Err(msg) = tx.send(Msg) { + std::mem::forget(msg); + } + }); + + // main thread is the rx thread + drop(rx); + + tx_thread_join_handle.join().unwrap(); + }); +}