diff --git a/Cargo.lock b/Cargo.lock index fdd4278e3ac..fd6bfac273c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5814,6 +5814,7 @@ dependencies = [ name = "nym-gateway-requests" version = "0.1.0" dependencies = [ + "anyhow", "bs58", "futures", "generic-array 0.14.7", @@ -5826,6 +5827,7 @@ dependencies = [ "nym-sphinx", "nym-statistics-common", "nym-task", + "nym-test-utils", "rand 0.8.5", "serde", "serde_json", @@ -6503,6 +6505,7 @@ dependencies = [ "futures", "nym-crypto", "nym-noise-keys", + "nym-test-utils", "pin-project", "rand_chacha 0.3.1", "sha2 0.10.9", @@ -7052,6 +7055,16 @@ dependencies = [ "wasmtimer", ] +[[package]] +name = "nym-test-utils" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures", + "rand_chacha 0.3.1", + "tokio", +] + [[package]] name = "nym-ticketbooks-merkle" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 3c6143e7257..ab1266ba269 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,7 +92,7 @@ members = [ "common/socks5/requests", "common/statistics", "common/store-cipher", - "common/task", + "common/task", "common/test-utils", "common/ticketbooks-merkle", "common/topology", "common/tun", diff --git a/common/gateway-requests/Cargo.toml b/common/gateway-requests/Cargo.toml index c39c5530dda..83042ead140 100644 --- a/common/gateway-requests/Cargo.toml +++ b/common/gateway-requests/Cargo.toml @@ -47,4 +47,7 @@ workspace = true default-features = false [dev-dependencies] +anyhow = { workspace = true } nym-compact-ecash = { path = "../nym_offline_compact_ecash" } # we need specific imports in tests +nym-test-utils = { path = "../test-utils" } +tokio = { workspace = true, features = ["full"] } diff --git a/common/gateway-requests/src/registration/handshake/mod.rs b/common/gateway-requests/src/registration/handshake/mod.rs index 53373f5c308..f6ef8a4de41 100644 --- a/common/gateway-requests/src/registration/handshake/mod.rs +++ b/common/gateway-requests/src/registration/handshake/mod.rs @@ -109,3 +109,85 @@ GATEWAY -> CLIENT DONE(status) */ + +#[cfg(test)] +mod tests { + use super::*; + use crate::ClientControlRequest; + use futures::StreamExt; + use nym_test_utils::helpers::u64_seeded_rng; + use nym_test_utils::mocks::stream_sink::mock_streams; + use nym_test_utils::traits::{Leak, Timeboxed, TimeboxedSpawnable}; + use tokio::join; + use tungstenite::Message; + + #[tokio::test] + async fn basic_handshake() -> anyhow::Result<()> { + use anyhow::Context as _; + + // solve the lifetime issue by just leaking the contents of the boxes + // which is perfectly fine in test + let client_rng = u64_seeded_rng(42).leak(); + let gateway_rng = u64_seeded_rng(69).leak(); + + let client_keys = ed25519::KeyPair::new(client_rng).leak(); + let gateway_keys = ed25519::KeyPair::new(gateway_rng).leak(); + + let (client_ws, gateway_ws) = mock_streams::(); + + // we need streams that return Result + let client_ws = client_ws.map(Ok); + let gateway_ws = gateway_ws.map(Ok); + + let client_ws = client_ws.leak(); + let gateway_ws = gateway_ws.leak(); + + let handshake_client = client_handshake( + client_rng, + client_ws, + client_keys, + *gateway_keys.public_key(), + false, + true, + TaskClient::dummy(), + ); + + let client_fut = handshake_client.spawn_timeboxed(); + + // we need to receive the first message so that it could be propagated to the gateway side of the handshake + let ClientControlRequest::RegisterHandshakeInitRequest { + protocol_version: _, + data, + } = (gateway_ws.next()) + .timeboxed() + .await + .context("timeout")? + .context("no message!")?? + .into_text()? + .parse::()? + else { + panic!("bad message") + }; + + let init_msg = data; + + let handshake_gateway = gateway_handshake( + gateway_rng, + gateway_ws, + gateway_keys, + init_msg, + TaskClient::dummy(), + ); + + let gateway_fut = handshake_gateway.spawn_timeboxed(); + let (client, gateway) = join!(client_fut, gateway_fut); + + let client_key = client???; + let gateway_key = gateway???; + + // ensure the created keys are the same + assert_eq!(client_key, gateway_key); + + Ok(()) + } +} diff --git a/common/nymnoise/Cargo.toml b/common/nymnoise/Cargo.toml index 4b29e88eed8..63050814ef7 100644 --- a/common/nymnoise/Cargo.toml +++ b/common/nymnoise/Cargo.toml @@ -28,6 +28,7 @@ anyhow = { workspace = true } tokio = { workspace = true, features = ["full"] } rand_chacha = { workspace = true } nym-crypto = { path = "../crypto", features = ["rand"] } +nym-test-utils = { path = "../test-utils" } [lints] diff --git a/common/nymnoise/src/stream/mod.rs b/common/nymnoise/src/stream/mod.rs index 048f50157c7..e860ef46e02 100644 --- a/common/nymnoise/src/stream/mod.rs +++ b/common/nymnoise/src/stream/mod.rs @@ -411,122 +411,21 @@ where mod tests { use super::*; use nym_crypto::asymmetric::x25519; - use rand_chacha::rand_core::SeedableRng; - use std::io::Error; - use std::mem; + use nym_test_utils::helpers::deterministic_rng; + use nym_test_utils::mocks::async_read_write::mock_io_streams; + use nym_test_utils::traits::{Timeboxed, TimeboxedSpawnable}; use std::sync::Arc; - use std::task::{Context, Waker}; - use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::join; - use tokio::sync::Mutex; - use tokio::time::timeout; - - fn mock_streams() -> (MockStream, MockStream) { - let ch1 = Arc::new(Mutex::new(Default::default())); - let ch2 = Arc::new(Mutex::new(Default::default())); - - ( - MockStream { - inner: MockStreamInner { - tx: ch1.clone(), - rx: ch2.clone(), - }, - }, - MockStream { - inner: MockStreamInner { tx: ch2, rx: ch1 }, - }, - ) - } - - struct MockStream { - inner: MockStreamInner, - } - - #[allow(dead_code)] - impl MockStream { - fn unchecked_tx_data(&self) -> Vec { - self.inner.tx.try_lock().unwrap().data.clone() - } - - fn unchecked_rx_data(&self) -> Vec { - self.inner.rx.try_lock().unwrap().data.clone() - } - } - - struct MockStreamInner { - tx: Arc>, - rx: Arc>, - } - - #[derive(Default)] - struct DataWrapper { - data: Vec, - waker: Option, - } - - impl AsyncRead for MockStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let mut inner = self.inner.rx.try_lock().unwrap(); - let data = mem::take(&mut inner.data); - if data.is_empty() { - inner.waker = Some(cx.waker().clone()); - return Poll::Pending; - } - - if let Some(waker) = inner.waker.take() { - waker.wake(); - } - - buf.put_slice(&data); - Poll::Ready(Ok(())) - } - } - - impl AsyncWrite for MockStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let mut inner = self.inner.tx.try_lock().unwrap(); - let len = buf.len(); - - if !inner.data.is_empty() { - assert!(inner.waker.is_none()); - inner.waker = Some(cx.waker().clone()); - return Poll::Pending; - } - - inner.data.extend_from_slice(buf); - if let Some(waker) = inner.waker.take() { - waker.wake(); - } - Poll::Ready(Ok(len)) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - } #[tokio::test] async fn noise_handshake() -> anyhow::Result<()> { - let dummy_seed = [42u8; 32]; - let mut rng = rand_chacha::ChaCha20Rng::from_seed(dummy_seed); + let mut rng = deterministic_rng(); let initiator_keys = Arc::new(x25519::KeyPair::new(&mut rng)); let responder_keys = Arc::new(x25519::KeyPair::new(&mut rng)); - let (initiator_stream, responder_stream) = mock_streams(); + let (initiator_stream, responder_stream) = mock_io_streams(); let psk = generate_psk(*responder_keys.public_key(), NoiseVersion::V1)?; let pattern = NoisePattern::default(); @@ -547,14 +446,8 @@ mod tests { *responder_keys.public_key(), ); - let initiator_fut = - tokio::spawn( - async move { timeout(Duration::from_millis(200), stream_initiator).await }, - ); - let responder_fut = - tokio::spawn( - async move { timeout(Duration::from_millis(200), stream_responder).await }, - ); + let initiator_fut = stream_initiator.spawn_timeboxed(); + let responder_fut = stream_responder.spawn_timeboxed(); let (initiator, responder) = join!(initiator_fut, responder_fut); @@ -563,14 +456,13 @@ mod tests { let msg = b"hello there"; // if noise was successful we should be able to write a proper message across - timeout(Duration::from_millis(200), initiator.write_all(msg)).await??; - + initiator.write_all(msg).timeboxed().await??; initiator.inner_stream.flush().await?; let inner_buf = initiator.inner_stream.get_ref().unchecked_tx_data(); let mut buf = [0u8; 11]; - timeout(Duration::from_millis(200), responder.read(&mut buf)).await??; + responder.read(&mut buf).timeboxed().await??; assert_eq!(&buf[..], msg); diff --git a/common/test-utils/Cargo.toml b/common/test-utils/Cargo.toml new file mode 100644 index 00000000000..8c937797022 --- /dev/null +++ b/common/test-utils/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "nym-test-utils" +version = "0.1.0" +authors.workspace = true +repository.workspace = true +homepage.workspace = true +documentation.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true +readme.workspace = true + +[dependencies] +anyhow = { workspace = true } +futures = { workspace = true } +rand_chacha = { workspace = true } +tokio = { workspace = true, features = ["sync", "time", "rt"] } + +[dev-dependencies] +tokio = { workspace = true, features = ["full"] } + +[lints] +workspace = true diff --git a/common/test-utils/src/helpers.rs b/common/test-utils/src/helpers.rs new file mode 100644 index 00000000000..26dd19f3b1f --- /dev/null +++ b/common/test-utils/src/helpers.rs @@ -0,0 +1,33 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +use crate::traits::Timeboxed; +use rand_chacha::rand_core::SeedableRng; +use rand_chacha::ChaCha20Rng; +use std::future::Future; +use tokio::task::JoinHandle; +use tokio::time::error::Elapsed; + +pub fn leak(val: T) -> &'static mut T { + Box::leak(Box::new(val)) +} + +pub fn spawn_timeboxed(fut: F) -> JoinHandle> +where + F: Future + Send + 'static, + ::Output: Send, +{ + tokio::spawn(async move { fut.timeboxed().await }) +} + +pub fn deterministic_rng() -> ChaCha20Rng { + seeded_rng([42u8; 32]) +} + +pub fn seeded_rng(seed: [u8; 32]) -> ChaCha20Rng { + ChaCha20Rng::from_seed(seed) +} + +pub fn u64_seeded_rng(seed: u64) -> ChaCha20Rng { + ChaCha20Rng::seed_from_u64(seed) +} diff --git a/common/test-utils/src/lib.rs b/common/test-utils/src/lib.rs new file mode 100644 index 00000000000..b1818473cc7 --- /dev/null +++ b/common/test-utils/src/lib.rs @@ -0,0 +1,6 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +pub mod helpers; +pub mod mocks; +pub mod traits; diff --git a/common/test-utils/src/mocks/async_read_write.rs b/common/test-utils/src/mocks/async_read_write.rs new file mode 100644 index 00000000000..859c8158821 --- /dev/null +++ b/common/test-utils/src/mocks/async_read_write.rs @@ -0,0 +1,161 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +use crate::mocks::shared::InnerWrapper; +use futures::ready; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +// sending buffer of the first stream is the receiving buffer of the second stream +// and vice versa +pub fn mock_io_streams() -> (MockIOStream, MockIOStream) { + let ch1 = MockIOStream::default(); + let ch2 = ch1.make_connection(); + + (ch1, ch2) +} + +#[derive(Default)] +pub struct MockIOStream { + // messages to send + tx: InnerWrapper>, + + // messages to receive + rx: InnerWrapper>, +} + +impl MockIOStream { + fn make_connection(&self) -> Self { + MockIOStream { + tx: self.rx.cloned_buffer(), + rx: self.tx.cloned_buffer(), + } + } + + // unwrap in test code is fine + #[allow(clippy::unwrap_used)] + pub fn unchecked_tx_data(&self) -> Vec { + self.tx.buffer.try_lock().unwrap().content.clone() + } + + // unwrap in test code is fine + #[allow(clippy::unwrap_used)] + pub fn unchecked_rx_data(&self) -> Vec { + self.rx.buffer.try_lock().unwrap().content.clone() + } +} + +impl AsyncRead for MockIOStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + ready!(Pin::new(&mut self.rx).poll_guard_ready(cx)); + + // SAFETY: guard is ready + #[allow(clippy::unwrap_used)] + let guard = self.rx.guard().unwrap(); + + let data = guard.take_content(); + if data.is_empty() { + // nothing to retrieve - store the waiter so that the sender could trigger it + guard.waker = Some(cx.waker().clone()); + + // drop the guard so that the sender could actually put messages in + self.rx.transition_to_idle(); + return Poll::Pending; + } + + // if let Some(waker) = guard.waker.take() { + // waker.wake(); + // } + + self.rx.transition_to_idle(); + + buf.put_slice(&data); + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for MockIOStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // wait until we transition to the locked state + ready!(Pin::new(&mut self.tx).poll_guard_ready(cx)); + + // SAFETY: guard is ready + #[allow(clippy::unwrap_used)] + let guard = self.tx.guard().unwrap(); + + let len = buf.len(); + guard.content.extend_from_slice(buf); + + // TODO: if we wanted the behaviour of always reading everything before writing anything extra + // if !guard.content.is_empty() { + // // sanity check + // assert!(guard.waker.is_none()); + // guard.waker = Some(cx.waker().clone()); + // self.tx.transition_to_idle(); + // return Poll::Pending; + // } + + Poll::Ready(Ok(len)) + } + + fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let Some(guard) = self.tx.guard() else { + return Poll::Ready(Err(io::Error::other( + "invalid lock state to send/flush messages", + ))); + }; + + if let Some(waker) = guard.waker.take() { + // notify the receiver if it was waiting for messages + waker.wake(); + } + + // release the guard + self.tx.transition_to_idle(); + + Poll::Ready(Ok(())) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + // make sure our guard is always dropped on close + self.tx.transition_to_idle(); + + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[tokio::test] + async fn basic() { + let (mut stream1, mut stream2) = mock_io_streams(); + stream1.write_all(&[1, 2, 3, 4, 5]).await.unwrap(); + stream1.flush().await.unwrap(); + + let mut buf = [0u8; 5]; + let read = stream2.read(&mut buf).await.unwrap(); + assert_eq!(read, 5); + assert_eq!(&buf[0..5], &[1, 2, 3, 4, 5]); + + let mut buf = [0u8; 5]; + stream2.write_all(&[6, 7, 8, 9, 10]).await.unwrap(); + stream2.flush().await.unwrap(); + + let read = stream1.read(&mut buf).await.unwrap(); + assert_eq!(read, 5); + assert_eq!(&buf[0..5], &[6, 7, 8, 9, 10]); + } +} diff --git a/common/test-utils/src/mocks/mod.rs b/common/test-utils/src/mocks/mod.rs new file mode 100644 index 00000000000..9962160d2a7 --- /dev/null +++ b/common/test-utils/src/mocks/mod.rs @@ -0,0 +1,6 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +pub mod async_read_write; +mod shared; +pub mod stream_sink; diff --git a/common/test-utils/src/mocks/shared.rs b/common/test-utils/src/mocks/shared.rs new file mode 100644 index 00000000000..c760cd9b5d8 --- /dev/null +++ b/common/test-utils/src/mocks/shared.rs @@ -0,0 +1,109 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +use futures::future::BoxFuture; +use futures::{ready, FutureExt}; +use std::mem; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; +use tokio::sync::{Mutex, OwnedMutexGuard}; + +#[derive(Default)] +pub(crate) struct InnerWrapper { + pub(crate) buffer: Arc>>, + lock_state: LockState, +} + +impl InnerWrapper { + pub(crate) fn clone_buffer(&self) -> Arc>> { + Arc::clone(&self.buffer) + } + + pub(crate) fn cloned_buffer(&self) -> Self { + assert!(matches!(self.lock_state, LockState::Idle)); + InnerWrapper { + buffer: self.clone_buffer(), + lock_state: LockState::Idle, + } + } + + // NOTE: it's responsibility of the caller to ensure the guard is released and state transitions to idle! + pub(crate) fn poll_guard_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + match &mut self.lock_state { + LockState::Idle => { + // 1. first try to obtain the guard without locking + let Ok(guard) = self.buffer.clone().try_lock_owned() else { + // 2. if that fails, create the future for obtaining it + self.lock_state = + LockState::TryingToLock(self.buffer.clone().lock_owned().boxed()); + return Poll::Pending; + }; + + // correctly transition to locked state and poll ourselves again + self.lock_state = LockState::Locked(guard); + cx.waker().wake_by_ref(); + Poll::Ready(()) + } + + LockState::TryingToLock(lock_fut) => { + // see if the guard future has resolved, if so, transition to locked state and schedule for another poll + let guard = ready!(lock_fut.as_mut().poll(cx)); + self.lock_state = LockState::Locked(guard); + cx.waker().wake_by_ref(); + Poll::Pending + } + + LockState::Locked(_) => Poll::Ready(()), + } + } + + pub(crate) fn guard(&mut self) -> Option<&mut OwnedMutexGuard>> { + match &mut self.lock_state { + LockState::Locked(guard) => Some(guard), + _ => None, + } + } + + pub(crate) fn transition_to_idle(&mut self) { + self.lock_state = LockState::Idle + } +} + +#[derive(Default)] +pub(crate) enum LockState { + // We haven’t started locking yet + #[default] + Idle, + + // Waiting for the mutex lock future to resolve + TryingToLock(BoxFuture<'static, OwnedMutexGuard>>), + + // We hold the mutex guard + Locked(OwnedMutexGuard>), +} + +#[derive(Default)] +pub struct ContentWrapper { + pub(crate) content: T, + pub(crate) waker: Option, +} + +impl ContentWrapper { + pub fn into_content(self) -> T { + self.content + } + + pub fn content(&self) -> &T { + &self.content + } + + pub(crate) fn take_content(&mut self) -> T + where + T: Default, + { + mem::take(&mut self.content) + } +} + +impl LockState {} diff --git a/common/test-utils/src/mocks/stream_sink.rs b/common/test-utils/src/mocks/stream_sink.rs new file mode 100644 index 00000000000..0cd866092e9 --- /dev/null +++ b/common/test-utils/src/mocks/stream_sink.rs @@ -0,0 +1,181 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +use crate::mocks::shared::{ContentWrapper, InnerWrapper}; +use anyhow::{anyhow, bail}; +use futures::{ready, Sink, Stream}; +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::Mutex; + +// sending buffer of the first stream is the receiving buffer of the second stream +// and vice versa +pub fn mock_streams() -> (MockStream, MockStream) +where + T: Send, +{ + let ch1 = MockStream::default(); + let ch2 = ch1.make_connection(); + + (ch1, ch2) +} + +pub struct MockStream { + // messages to send + tx: InnerWrapper>, + + // messages to receive + rx: InnerWrapper>, +} + +impl MockStream { + pub fn clone_tx_buffer(&self) -> Arc>>> + where + T: Send, + { + self.tx.clone_buffer() + } + + pub fn clone_rx_buffer(&self) -> Arc>>> + where + T: Send, + { + self.rx.clone_buffer() + } + + fn make_connection(&self) -> Self + where + T: Send, + { + MockStream { + tx: self.rx.cloned_buffer(), + rx: self.tx.cloned_buffer(), + } + } +} + +impl Default for MockStream { + fn default() -> Self { + MockStream { + tx: InnerWrapper::default(), + rx: InnerWrapper::default(), + } + } +} + +impl Stream for MockStream +where + T: Send, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(Pin::new(&mut self.rx).poll_guard_ready(cx)); + + // SAFETY: guard is ready + #[allow(clippy::unwrap_used)] + let guard = self.rx.guard().unwrap(); + + let Some(next) = guard.content.pop_front() else { + // nothing to retrieve - store the waiter so that the sender could trigger it + guard.waker = Some(cx.waker().clone()); + + // drop the guard so that the sender could actually put messages in + self.rx.transition_to_idle(); + return Poll::Pending; + }; + + // there are more messages buffered waiting for us to retrieve + // keep the guard! + if !guard.content.is_empty() { + cx.waker().wake_by_ref(); + } else { + // no more messages, drop the guard + self.rx.transition_to_idle(); + } + + Poll::Ready(Some(next)) + } + + fn size_hint(&self) -> (usize, Option) { + // that's just a minor optimisation, so don't sweat about it too much, + // if we can obtain the mutex, give precise information, otherwise return default values + let Ok(guard) = self.rx.buffer.try_lock() else { + return (0, None); + }; + let items = guard.content.len(); + (items, Some(items)) + } +} + +impl Sink for MockStream +where + T: Send, +{ + type Error = anyhow::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // wait until we transition to the locked state + ready!(Pin::new(&mut self.tx).poll_guard_ready(cx)); + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let Some(guard) = self.tx.guard() else { + bail!("invalid lock state to send messages"); + }; + guard.content.push_back(item); + + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let Some(guard) = self.tx.guard() else { + return Poll::Ready(Err(anyhow!("invalid lock state to send/flush messages"))); + }; + + if let Some(waker) = guard.waker.take() { + // notify the receiver if it was waiting for messages + waker.wake(); + } + + // release the guard + self.tx.transition_to_idle(); + + Poll::Ready(Ok(())) + } + + fn poll_close( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + // make sure our guard is always dropped on close + self.tx.transition_to_idle(); + + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::{SinkExt, StreamExt}; + + #[tokio::test] + async fn basic() { + let (mut stream1, mut stream2) = mock_streams(); + stream1.send("foomp").await.unwrap(); + + let received = stream2.next().await.unwrap(); + assert_eq!(received, "foomp"); + + stream2.send("bar").await.unwrap(); + let received = stream1.next().await.unwrap(); + assert_eq!(received, "bar"); + } +} diff --git a/common/test-utils/src/traits.rs b/common/test-utils/src/traits.rs new file mode 100644 index 00000000000..9b95dbd6bf5 --- /dev/null +++ b/common/test-utils/src/traits.rs @@ -0,0 +1,57 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +use crate::helpers::{leak, spawn_timeboxed}; +use std::future::{Future, IntoFuture}; +use std::time::Duration; +use tokio::task::JoinHandle; +use tokio::time::error::Elapsed; + +// a helper trait for use in tests to easily convert `T` into `&'static mut T` +pub trait Leak { + fn leak(self) -> &'static mut Self; +} + +impl Leak for T { + fn leak(self) -> &'static mut T { + leak(self) + } +} + +// those are internal testing traits so we're not concerned about auto traits +#[allow(async_fn_in_trait)] +pub trait Timeboxed: IntoFuture + Sized { + async fn timeboxed(self) -> Result { + self.execute_with_deadline(Duration::from_millis(200)).await + } + + async fn execute_with_deadline(self, timeout: Duration) -> Result { + tokio::time::timeout(timeout, self).await + } +} + +impl Timeboxed for T where T: IntoFuture + Sized {} + +// those are internal testing traits so we're not concerned about auto traits +#[allow(async_fn_in_trait)] +pub trait Spawnable: Future + Sized + Send + 'static { + fn spawn(self) -> JoinHandle + where + ::Output: Send + 'static, + { + tokio::spawn(self) + } +} + +impl Spawnable for T where T: Future + Sized + Send + 'static {} + +pub trait TimeboxedSpawnable: Timeboxed + Spawnable { + fn spawn_timeboxed(self) -> JoinHandle::Output, Elapsed>> + where + ::Output: Send, + { + spawn_timeboxed(self) + } +} + +impl TimeboxedSpawnable for T where T: Spawnable + Future + Send {}