From a0c01aebc92677de20e70a269926f6ad761ce48c Mon Sep 17 00:00:00 2001
From: driftluo
Date: Wed, 4 Mar 2026 14:26:39 +0800
Subject: [PATCH 1/2] fix: fix yamux stream handle split

---
 yamux/src/stream.rs | 349 +++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 342 insertions(+), 7 deletions(-)

diff --git a/yamux/src/stream.rs b/yamux/src/stream.rs
index 7cdbb407..06dd04ab 100644
--- a/yamux/src/stream.rs
+++ b/yamux/src/stream.rs
@@ -301,10 +301,43 @@ impl StreamHandle {
         Ok(has_new_frame)
     }
 
-    fn recv_frames_wake(&mut self, cx: &mut Context) -> Result<(), Error> {
+    fn try_recv_frames(&mut self) -> Result<bool, Error> {
+        let mut has_new_frame = false;
+        loop {
+            match self.state {
+                StreamState::RemoteClosing => {
+                    return Err(Error::SubStreamRemoteClosing);
+                }
+                StreamState::Reset | StreamState::Closed => {
+                    return Err(Error::SessionShutdown);
+                }
+                _ => {}
+            }
+
+            if self.frame_receiver.is_terminated() {
+                self.state = StreamState::RemoteClosing;
+                return Err(Error::SubStreamRemoteClosing);
+            }
+
+            match self.frame_receiver.try_next() {
+                Ok(Some(frame)) => {
+                    self.handle_frame(frame)?;
+                    has_new_frame = true;
+                }
+                Ok(None) => {
+                    self.state = StreamState::RemoteClosing;
+                    return Err(Error::SubStreamRemoteClosing);
+                }
+                Err(_) => break,
+            }
+        }
+        Ok(has_new_frame)
+    }
+
+    fn recv_frames_wake(&mut self) -> Result<(), Error> {
         let buf_len = self.read_buf.len();
         let state = self.state;
-        match self.recv_frames(cx) {
+        match self.try_recv_frames() {
             Ok(should_wake_read) => {
                 // if state change to RemoteClosing, wake read
                 // if read buf len change, wake read
@@ -363,6 +396,7 @@ impl StreamHandle {
             return Poll::Ready(Ok(0));
         }
 
+        self.readable_wake = Some(cx.waker().clone());
         if let Err(Error::UnexpectedFlag | Error::RecvWindowExceeded | Error::InvalidMsgType) =
             self.recv_frames(cx)
         {
@@ -376,7 +410,6 @@ impl StreamHandle {
         }
 
         if self.read_buf.is_empty() {
-            self.readable_wake = Some(cx.waker().clone());
             return Poll::Pending;
         }
 
@@ -420,6 +453,7 @@ impl AsyncRead for StreamHandle {
             return Poll::Ready(Ok(()));
         }
 
+        self.readable_wake = Some(cx.waker().clone());
        if let Err(Error::UnexpectedFlag | Error::RecvWindowExceeded | Error::InvalidMsgType) =
             self.recv_frames(cx)
         {
@@ -433,7 +467,6 @@ impl AsyncRead for StreamHandle {
         }
 
         if self.read_buf.is_empty() {
-            self.readable_wake = Some(cx.waker().clone());
             return Poll::Pending;
         }
 
@@ -504,7 +537,7 @@ impl AsyncWrite for StreamHandle {
         // When the session receives a window update frame, it can update the state of the stream.
         // In the implementation here, we try not to share state between the session and the stream.
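+        // `frame_receiver` (a futures bounded mpsc receiver) stores at most one
+        // receiver-side waker: whichever task polled it last wins. The write
+        // path below therefore drains it with the non-registering
+        // `try_recv_frames()`, leaving the waker that `poll_read` registered
+        // untouched.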
         if let Err(Error::UnexpectedFlag | Error::RecvWindowExceeded | Error::InvalidMsgType) =
-            self.recv_frames_wake(cx)
+            self.recv_frames_wake()
         {
             // read flag error or read data error
             self.send_go_away();
@@ -632,9 +665,31 @@ mod test {
     use futures::{
         SinkExt, StreamExt,
         channel::mpsc::{channel, unbounded},
+        task::{ArcWake, waker_ref},
     };
-    use std::io::ErrorKind;
-    use tokio::io::{AsyncReadExt, AsyncWriteExt};
+    use std::{
+        io::ErrorKind,
+        pin::Pin,
+        sync::{
+            Arc,
+            atomic::{AtomicBool, Ordering},
+        },
+        task::{Context, Poll},
+    };
+    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
+
+    #[derive(Default)]
+    struct FlagWaker(AtomicBool);
+    impl ArcWake for FlagWaker {
+        fn wake_by_ref(arc_self: &Arc<Self>) {
+            arc_self.0.store(true, Ordering::SeqCst);
+        }
+    }
+    impl FlagWaker {
+        fn woken(&self) -> bool {
+            self.0.load(Ordering::SeqCst)
+        }
+    }
 
     #[test]
     fn test_drop() {
@@ -940,4 +995,284 @@ mod test {
             assert_eq!(stream.read_buf.capacity(), 4);
         });
     }
+
+    // Regression test for:
+    // `poll_write` calling `recv_frames(write_cx)` which polls `frame_receiver`
+    // with the **write** task's Context and thereby overwrites the read task's waker
+    // that was stored there by `poll_read`. Once the write_waker is stale (write
+    // task finished), any incoming data frame wakes nobody → read side hangs.
+    //
+    // This test is fully deterministic: it uses custom flag-wakers and manually drives
+    // `poll_read` / `poll_write`, so there is no reliance on tokio scheduler ordering.
+    //
+    // Failure scenario (old buggy code):
+    //   1. poll_read(read_cx)   → recv_frames(read_cx)  → read_waker stored in frame_receiver
+    //   2. poll_write(write_cx) → recv_frames(write_cx) → write_waker stored in frame_receiver
+    //      (OVERWRITES read_waker)
+    //   3. Data frame injected  → frame_receiver wakes write_waker (stale / already done)
+    //      → read_waker is NEVER notified → read side hangs forever.
+    //
+    // With the fix (try_recv_frames / try_next):
+    //   Step 2 does NOT touch frame_receiver's stored waker.
+    //   Step 3 correctly wakes read_waker.
+    #[test]
+    fn test_write_side_does_not_overwrite_read_waker() {
+        let rt = rt();
+        rt.block_on(async {
+            let (mut frame_sender, frame_receiver) = channel(128);
+            let (unbound_sender, _unbound_receiver) = unbounded();
+            let mut stream = StreamHandle::new(
+                1,
+                unbound_sender,
+                frame_receiver,
+                StreamState::Init,
+                INITIAL_STREAM_WINDOW,
+            );
+
+            let read_fw = Arc::new(FlagWaker::default());
+            let write_fw = Arc::new(FlagWaker::default());
+            let read_waker_ref = waker_ref(&read_fw);
+            let write_waker_ref = waker_ref(&write_fw);
+            let mut read_cx = Context::from_waker(&read_waker_ref);
+            let mut write_cx = Context::from_waker(&write_waker_ref);
+
+            // Step 1: poll_read → parks → read_waker registered in frame_receiver.
+            let mut buf = vec![0u8; 32];
+            let mut rbuf = ReadBuf::new(&mut buf);
+            let r = Pin::new(&mut stream).poll_read(&mut read_cx, &mut rbuf);
+            assert!(
+                r.is_pending(),
+                "poll_read must return Pending (no data yet)"
+            );
+
+            // Step 2: poll_write (send_window > 0, succeeds immediately).
+            // OLD buggy code: recv_frames_wake(cx) → recv_frames(write_cx)
+            //                 → frame_receiver.poll_next(write_cx) → OVERWRITES read_waker
+            //                 with write_waker.
+            // NEW fixed code: recv_frames_wake() → try_recv_frames() → try_next()
+            //                 → does NOT touch frame_receiver's stored waker at all.
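+            // poll_write is driven with write_cx, never read_cx; keeping the two
+            // tasks' wakers distinguishable is exactly what the FlagWaker pair
+            // above is for.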
+ let r = Pin::new(&mut stream).poll_write(&mut write_cx, b"ping"); + assert!( + matches!(r, Poll::Ready(Ok(4))), + "poll_write must succeed (send_window has capacity)" + ); + + // Step 3: inject an incoming data frame. + // The mpsc channel calls wake() synchronously on the stored waker when an + // item is enqueued while the receiver is waiting. + // + // OLD (bug): write_waker was stored last → write_fw.woken() == true, + // read_fw.woken() == false → read side would hang. + // NEW (fix): read_waker was stored last → read_fw.woken() == true. + let frame = Frame::new_data(Flags::from(Flag::Syn), 1, BytesMut::from("data")); + frame_sender.send(frame).await.unwrap(); + + assert!( + read_fw.woken(), + "BUG REPRODUCED: read_waker was overwritten by write side; \ + incoming data frame woke write_waker instead of read_waker. \ + The read side would hang forever." + ); + assert!( + !write_fw.woken(), + "write_waker must NOT be stored in frame_receiver \ + (only the read side should register its waker there)" + ); + }); + } + + // Regression test: when send_window == 0 (write side blocked), a window-update + // frame from the remote must travel via the READ path—not the write path—to + // unblock the write side. + // + // Correct flow (with fix): + // window_update arrives → frame_receiver wakes read_waker (read path owns it) + // → poll_read processes handle_window_update → send_window increases + // → writeable_wake.wake() → write_waker notified → write side can proceed. + // + // Buggy flow (old code): + // recv_frames(write_cx) stored write_waker in frame_receiver, overwriting read_waker. + // window_update arrives → write_waker notified → write task itself drains the frame, + // which accidentally works for the write side, BUT the read task's waker is now gone. + // Any subsequent DATA frame would silently wake the stale write_waker → read hangs. + // + // This test is fully deterministic via custom flag-wakers and manual polling. + #[test] + fn test_window_update_wakes_write_via_read_path() { + let rt = rt(); + rt.block_on(async { + let (mut frame_sender, frame_receiver) = channel(128); + let (unbound_sender, _unbound_receiver) = unbounded(); + let mut stream = StreamHandle::new( + 1, + unbound_sender, + frame_receiver, + StreamState::Init, + INITIAL_STREAM_WINDOW, + ); + + // Exhaust the send window so poll_write will park. + stream.send_window = 0; + + let read_fw = Arc::new(FlagWaker::default()); + let write_fw = Arc::new(FlagWaker::default()); + let read_waker_ref = waker_ref(&read_fw); + let write_waker_ref = waker_ref(&write_fw); + let mut read_cx = Context::from_waker(&read_waker_ref); + let mut write_cx = Context::from_waker(&write_waker_ref); + + // Step 1: poll_read → parks → read_waker registered in frame_receiver, + // readable_wake = read_waker. + let mut buf = vec![0u8; 32]; + let mut rbuf = ReadBuf::new(&mut buf); + assert!( + Pin::new(&mut stream) + .poll_read(&mut read_cx, &mut rbuf) + .is_pending() + ); + + // Step 2: poll_write → send_window == 0 → parks. + // OLD: recv_frames(write_cx) first OVERWRITES frame_receiver's waker with + // write_waker. Then send_window==0 → writeable_wake = write_waker. + // NEW: try_recv_frames() does NOT touch frame_receiver waker. + // send_window==0 → writeable_wake = write_waker. + assert!( + Pin::new(&mut stream) + .poll_write(&mut write_cx, b"ping") + .is_pending() + ); + + // Step 3: inject a window-update frame (simulates remote granting more window). + // This synchronously wakes whoever is registered in frame_receiver. 
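+            // (`send().await` cannot park here: the channel has capacity 128 and
+            // is otherwise empty, so the frame is enqueued, and the stored waker
+            // fired, synchronously.)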
+ // OLD (bug): write_waker → write_fw.woken() == true BEFORE we poll_read. + // But frame_receiver now holds write_waker (stale once write is done). + // NEW (fix): read_waker → read_fw.woken() == true. + let wu = Frame::new_window_update(Flags::default(), 1, 65535); + frame_sender.send(wu).await.unwrap(); + + // With the fix, the read_waker must be the one notified. + assert!( + read_fw.woken(), + "BUG: read_waker was overwritten; window-update woke write_waker instead. \ + The read path cannot process the window-update → write side stays stuck." + ); + assert!( + !write_fw.woken(), + "write_waker must not be in frame_receiver; it should only be in writeable_wake" + ); + + // Step 4: simulate the read task re-polling after being woken. + // poll_read processes the window-update frame via handle_window_update, + // which increases send_window and calls writeable_wake.wake(). + let mut rbuf2 = ReadBuf::new(&mut buf); + // poll_read will drain the window-update frame and internally call + // writeable_wake.wake(), which notifies write_fw. + // (The window-update has no data so read returns Pending again.) + let _ = Pin::new(&mut stream).poll_read(&mut read_cx, &mut rbuf2); + + // After handle_window_update → writeable_wake.wake(), the write task + // (write_waker) must now be notified so it can retry and succeed. + assert!( + write_fw.woken(), + "write_waker must be notified via writeable_wake after \ + the read path processes the window-update frame" + ); + + // Step 5: poll_write again now that send_window > 0. + let r = Pin::new(&mut stream).poll_write(&mut write_cx, b"ping"); + assert!( + matches!(r, Poll::Ready(Ok(4))), + "poll_write must now succeed after window was restored" + ); + }); + } + + // Verifies that when `poll_write` calls `try_recv_frames()` and intercepts an + // incoming DATA frame (i.e. the read buffer grows), it proactively wakes the + // parked read task via `readable_wake`. + // + // Motivation: + // The write path uses `try_recv_frames()` to drain `frame_receiver` non-blockingly + // before attempting to write. When it finds data frames, those frames accumulate + // in `read_buf` but the read task is still parked waiting on `frame_receiver`. + // Since `try_recv_frames()` does NOT register any waker in `frame_receiver`, the + // read task will never receive a wakeup from the channel itself. Therefore + // `recv_frames_wake` must explicitly call `readable_wake.wake()` whenever the + // read buffer grows. Without this, the read side would silently stall even though + // data has already arrived and is sitting in `read_buf`. + // + // Note: this is a pure positive-behavior test of the `readable_wake` notification + // path in `recv_frames_wake`. It is orthogonal to the waker-overwrite regression + // tests above: when a frame is already in the channel, `poll_next` returns `Ready` + // immediately without storing a waker, so the overwrite bug does not apply here. + // Both the buggy and fixed implementations correctly satisfy this assertion. + // + // Test sequence (fully deterministic, no async scheduler involvement): + // 1. poll_read(read_cx) → Pending, readable_wake = read_waker. + // 2. Pre-queue a data frame in frame_sender (already available synchronously). + // 3. poll_write(write_cx) → try_recv_frames() drains the data frame → + // read_buf grows → recv_frames_wake detects buf change → + // readable_wake.take().wake() → read_fw.woken() == true. 
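+    // The buf-growth check in `recv_frames_wake` is roughly the following
+    // (simplified sketch; the real method also wakes the reader when the
+    // state changes to RemoteClosing):
+    //
+    //     let buf_len = self.read_buf.len();
+    //     let _ = self.try_recv_frames();
+    //     if self.read_buf.len() != buf_len {
+    //         if let Some(waker) = self.readable_wake.take() {
+    //             waker.wake();
+    //         }
+    //     }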
+ #[test] + fn test_poll_write_wakes_read_when_data_frame_intercepted() { + let rt = rt(); + rt.block_on(async { + let (mut frame_sender, frame_receiver) = channel(128); + let (unbound_sender, _unbound_receiver) = unbounded(); + let mut stream = StreamHandle::new( + 1, + unbound_sender, + frame_receiver, + StreamState::Init, + INITIAL_STREAM_WINDOW, + ); + + let read_fw = Arc::new(FlagWaker::default()); + let write_fw = Arc::new(FlagWaker::default()); + let read_waker_ref = waker_ref(&read_fw); + let write_waker_ref = waker_ref(&write_fw); + let mut read_cx = Context::from_waker(&read_waker_ref); + let mut write_cx = Context::from_waker(&write_waker_ref); + + // Step 1: poll_read parks → readable_wake = read_waker. + let mut buf = vec![0u8; 32]; + let mut rbuf = ReadBuf::new(&mut buf); + assert!( + Pin::new(&mut stream) + .poll_read(&mut read_cx, &mut rbuf) + .is_pending(), + "poll_read must return Pending (no data yet)" + ); + assert!(!read_fw.woken(), "read_waker must not be woken yet"); + + // Step 2: pre-queue a data frame so it is ready for synchronous delivery. + let frame = Frame::new_data(Flags::from(Flag::Syn), 1, BytesMut::from("hello")); + frame_sender + .try_send(frame) + .expect("channel must accept frame"); + + // Step 3: poll_write → recv_frames_wake → try_recv_frames() drains the data + // frame synchronously → read_buf grows from 0 to 1 → buf_len check triggers + // → readable_wake.take().wake() → read_fw.woken() == true. + let r = Pin::new(&mut stream).poll_write(&mut write_cx, b"ping"); + assert!( + matches!(r, Poll::Ready(Ok(4))), + "poll_write must succeed (send_window has capacity)" + ); + + assert!( + read_fw.woken(), + "poll_write must wake the read task after intercepting a data frame via \ + try_recv_frames(); without this the read side silently stalls even though \ + data is already sitting in read_buf" + ); + + // Sanity: the data frame really did land in read_buf. 
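+            // (As the assertion below implies, read_buf is counted in queued
+            // frames, not payload bytes: one "hello" frame gives len == 1.)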
+            assert_eq!(
+                stream.read_buf.len(),
+                1,
+                "data frame must be in read_buf after try_recv_frames()"
+            );
+        });
+    }
 }

From 8622684b61bb37c70e8f97a7e2b03cad6e0a6541 Mon Sep 17 00:00:00 2001
From: driftluo
Date: Wed, 4 Mar 2026 16:34:11 +0800
Subject: [PATCH 2/2] test: test on tentacle

---
 tentacle/Cargo.toml                        |   2 +-
 tentacle/tests/test_spawn_first_message.rs | 306 +++++++++++++++++++++
 yamux/src/stream.rs                        |  14 +-
 3 files changed, 314 insertions(+), 8 deletions(-)
 create mode 100644 tentacle/tests/test_spawn_first_message.rs

diff --git a/tentacle/Cargo.toml b/tentacle/Cargo.toml
index af1c4230..eca95110 100644
--- a/tentacle/Cargo.toml
+++ b/tentacle/Cargo.toml
@@ -48,7 +48,7 @@ tokio-rustls = { version = "0.26.0", optional = true }
 [target.'cfg(not(target_family = "wasm"))'.dependencies]
 # rand 0.8 not support wasm32
 rand = "0.8"
-socket2 = { version = "0.5.0", features = ["all"] }
+socket2 = { version = "0.6.0", features = ["all"] }
 fast-socks5 = "0.10.0"
 
 [target.'cfg(target_family = "wasm")'.dependencies]
diff --git a/tentacle/tests/test_spawn_first_message.rs b/tentacle/tests/test_spawn_first_message.rs
new file mode 100644
index 00000000..87d8e349
--- /dev/null
+++ b/tentacle/tests/test_spawn_first_message.rs
@@ -0,0 +1,306 @@
+#![cfg(feature = "unstable")]
+
+use futures::StreamExt;
+use std::{
+    sync::{
+        Arc,
+        atomic::{AtomicUsize, Ordering},
+        mpsc::channel,
+    },
+    thread,
+    time::Duration,
+};
+use tentacle::{
+    SubstreamReadPart,
+    builder::{MetaBuilder, ServiceBuilder},
+    context::SessionContext,
+    multiaddr::Multiaddr,
+    secio::SecioKeyPair,
+    service::{ProtocolMeta, Service, ServiceAsyncControl, TargetProtocol},
+    traits::{ProtocolSpawn, ServiceHandle},
+};
+
+// --- Helpers ---
+
+fn create_service<F>(secio: bool, meta: ProtocolMeta, shandle: F) -> Service<F>
+where
+    F: ServiceHandle + Unpin + 'static,
+{
+    let builder = ServiceBuilder::default().insert_protocol(meta);
+    if secio {
+        builder
+            .handshake_type(SecioKeyPair::secp256k1_generated().into())
+            .build(shandle)
+    } else {
+        builder.build(shandle)
+    }
+}
+
+/// Run a listener + dialer pair.
+///
+/// Each probe is responsible for calling `control.shutdown()` on its own
+/// service after finishing its work. When a side shuts down and the TCP
+/// connection closes, the remote side's inner service detects the
+/// disconnect, exits its event loop, and `service.run()` returns
+/// naturally (no explicit shutdown needed on that side).
+///
+/// A hard timeout (`Duration`) on `service.run()` prevents infinite hangs.
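+///
+/// A typical invocation, mirroring the tests below (the two metas are
+/// placeholders built with `MetaBuilder`):
+///
+/// ```ignore
+/// run_pair(true, listener_meta, dialer_meta, Duration::from_secs(15));
+/// ```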
+fn run_pair(
+    secio: bool,
+    listener_meta: ProtocolMeta,
+    dialer_meta: ProtocolMeta,
+    timeout: Duration,
+) {
+    let (addr_sender, addr_receiver) = channel::<Multiaddr>();
+
+    let listener_thread = thread::spawn(move || {
+        let rt = tokio::runtime::Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .unwrap();
+        let mut service = create_service(secio, listener_meta, ());
+        rt.block_on(async move {
+            let listen_addr = service
+                .listen("/ip4/127.0.0.1/tcp/0".parse().unwrap())
+                .await
+                .unwrap();
+            addr_sender.send(listen_addr).unwrap();
+            let _ignore = tokio::time::timeout(timeout, service.run()).await;
+        });
+    });
+
+    let dialer_thread = thread::spawn(move || {
+        let rt = tokio::runtime::Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .unwrap();
+        let mut service = create_service(secio, dialer_meta, ());
+        rt.block_on(async move {
+            let listen_addr = addr_receiver.recv().unwrap();
+            service
+                .dial(listen_addr, TargetProtocol::Single(1.into()))
+                .await
+                .unwrap();
+            let _ignore = tokio::time::timeout(timeout, service.run()).await;
+        });
+    });
+
+    listener_thread.join().unwrap();
+    dialer_thread.join().unwrap();
+}
+
+// ========================================================================
+// Test 1: First message exchange
+//
+// Both sides send a single "init" message immediately on spawn and verify
+// they receive it. This tests that the very first message written by one
+// side is correctly forwarded through the read channel to the other
+// side's SubstreamReadPart — the exact path where the yamux waker-
+// overwrite bug manifested.
+//
+// Each probe calls `control.shutdown()` after receiving its init message.
+// The remote side detects the connection drop and its `service.run()`
+// returns naturally through the event loop.
+// ========================================================================
+
+#[derive(Clone)]
+struct FirstMessageProbe {
+    received: Arc<AtomicUsize>,
+}
+
+impl ProtocolSpawn for FirstMessageProbe {
+    fn spawn(
+        &self,
+        context: Arc<SessionContext>,
+        control: &ServiceAsyncControl,
+        mut read_part: SubstreamReadPart,
+    ) {
+        let session_id = context.id;
+        let proto_id = read_part.protocol_id();
+
+        // Send one "init" message to the remote side.
+        let send_control = control.clone();
+        tokio::spawn(async move {
+            let _ignore = send_control
+                .send_message_to(session_id, proto_id, b"init".to_vec().into())
+                .await;
+        });
+
+        // Read the "init" message, then shut down.
+        let received = self.received.clone();
+        let shutdown_control = control.clone();
+        tokio::spawn(async move {
+            let _ignore = tokio::time::timeout(Duration::from_secs(10), async {
+                while let Some(Ok(data)) = read_part.next().await {
+                    if data.as_ref() == b"init" {
+                        received.fetch_add(1, Ordering::SeqCst);
+                        break;
+                    }
+                }
+            })
+            .await;
+            // Brief delay so the remote side also has time to receive *our*
+            // init before we tear down the connection.
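+            // (A best-effort grace period, not a synchronization guarantee; the
+            // multi-message probe below replaces it with an explicit
+            // `total_done` handshake.)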
+            tokio::time::sleep(Duration::from_secs(1)).await;
+            let _ignore = shutdown_control.shutdown().await;
+        });
+    }
+}
+
+fn run_first_message_test(secio: bool, iterations: usize) {
+    for _i in 0..iterations {
+        let listener_received = Arc::new(AtomicUsize::new(0));
+        let dialer_received = Arc::new(AtomicUsize::new(0));
+
+        let meta_listener = MetaBuilder::new()
+            .id(1.into())
+            .protocol_spawn(FirstMessageProbe {
+                received: listener_received.clone(),
+            })
+            .build();
+
+        let meta_dialer = MetaBuilder::new()
+            .id(1.into())
+            .protocol_spawn(FirstMessageProbe {
+                received: dialer_received.clone(),
+            })
+            .build();
+
+        run_pair(secio, meta_listener, meta_dialer, Duration::from_secs(15));
+
+        let lr = listener_received.load(Ordering::SeqCst);
+        let dr = dialer_received.load(Ordering::SeqCst);
+        assert_eq!(lr, 1, "listener did not receive first message");
+        assert_eq!(dr, 1, "dialer did not receive first message");
+    }
+}
+
+#[test]
+fn test_spawn_first_message_with_no_secio() {
+    run_first_message_test(false, 3);
+}
+
+#[test]
+fn test_spawn_first_message_with_secio() {
+    run_first_message_test(true, 3);
+}
+
+// ========================================================================
+// Test 2: Multi-message bidirectional exchange
+//
+// Both sides send N messages and read until they've received all N.
+// This tests sustained bidirectional data flow through the spawn model,
+// exercising the channel-based read forwarding under load.
+// ========================================================================
+
+const MULTI_MSG_COUNT: usize = 100;
+
+#[derive(Clone)]
+struct MultiMessageProbe {
+    received: Arc<AtomicUsize>,
+    total_done: Arc<AtomicUsize>,
+}
+
+impl ProtocolSpawn for MultiMessageProbe {
+    fn spawn(
+        &self,
+        context: Arc<SessionContext>,
+        control: &ServiceAsyncControl,
+        mut read_part: SubstreamReadPart,
+    ) {
+        let session_id = context.id;
+        let proto_id = read_part.protocol_id();
+
+        // Send N messages
+        let send_control = control.clone();
+        tokio::spawn(async move {
+            for i in 0..MULTI_MSG_COUNT {
+                let msg = format!("msg-{i}");
+                if let Err(_e) = send_control
+                    .send_message_to(session_id, proto_id, msg.into_bytes().into())
+                    .await
+                {
+                    break;
+                }
+            }
+        });
+
+        // Receive N messages, wait for peer, then shut down.
+        let received = self.received.clone();
+        let total_done = self.total_done.clone();
+        let shutdown_control = control.clone();
+        tokio::spawn(async move {
+            let _ignore = tokio::time::timeout(Duration::from_secs(30), async {
+                let mut count = 0usize;
+                while let Some(Ok(_data)) = read_part.next().await {
+                    count += 1;
+                    received.fetch_add(1, Ordering::SeqCst);
+                    if count >= MULTI_MSG_COUNT {
+                        break;
+                    }
+                }
+            })
+            .await;
+
+            // Signal this side is done receiving.
+            total_done.fetch_add(1, Ordering::SeqCst);
+
+            // Wait for the other side to also finish (max 10s).
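+            // 200 polls at 50 ms each gives a 10 s upper bound before this side
+            // shuts down regardless.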
+ for _ in 0..200 { + if total_done.load(Ordering::SeqCst) >= 2 { + break; + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + + let _ignore = shutdown_control.shutdown().await; + }); + } +} + +fn run_multi_message_test(secio: bool, iterations: usize) { + for _i in 0..iterations { + let listener_received = Arc::new(AtomicUsize::new(0)); + let dialer_received = Arc::new(AtomicUsize::new(0)); + let total_done = Arc::new(AtomicUsize::new(0)); + + let meta_listener = MetaBuilder::new() + .id(1.into()) + .protocol_spawn(MultiMessageProbe { + received: listener_received.clone(), + total_done: total_done.clone(), + }) + .build(); + + let meta_dialer = MetaBuilder::new() + .id(1.into()) + .protocol_spawn(MultiMessageProbe { + received: dialer_received.clone(), + total_done: total_done.clone(), + }) + .build(); + + run_pair(secio, meta_listener, meta_dialer, Duration::from_secs(60)); + + let lr = listener_received.load(Ordering::SeqCst); + let dr = dialer_received.load(Ordering::SeqCst); + assert_eq!( + lr, MULTI_MSG_COUNT, + "listener received {lr}/{MULTI_MSG_COUNT} messages", + ); + assert_eq!( + dr, MULTI_MSG_COUNT, + "dialer received {dr}/{MULTI_MSG_COUNT} messages", + ); + } +} + +#[test] +fn test_spawn_multi_message_with_no_secio() { + run_multi_message_test(false, 3); +} + +#[test] +fn test_spawn_multi_message_with_secio() { + run_multi_message_test(true, 3); +} diff --git a/yamux/src/stream.rs b/yamux/src/stream.rs index 06dd04ab..34fddbe4 100644 --- a/yamux/src/stream.rs +++ b/yamux/src/stream.rs @@ -319,16 +319,16 @@ impl StreamHandle { return Err(Error::SubStreamRemoteClosing); } - match self.frame_receiver.try_next() { - Ok(Some(frame)) => { + match self.frame_receiver.try_recv() { + Ok(frame) => { self.handle_frame(frame)?; has_new_frame = true; } - Ok(None) => { + Err(futures::channel::mpsc::TryRecvError::Closed) => { self.state = StreamState::RemoteClosing; return Err(Error::SubStreamRemoteClosing); } - Err(_) => break, + Err(futures::channel::mpsc::TryRecvError::Empty) => break, } } Ok(has_new_frame) @@ -923,8 +923,8 @@ mod test { let jh = tokio::spawn(tokio::time::timeout(std::time::Duration::from_secs(4), async move { loop { - match unbound_receiver.try_next() { - Ok(Some(ref event)) if matches!(event, StreamEvent::Frame(frame) if frame.length() == TEXT.len() as u32) => break, + match unbound_receiver.try_recv() { + Ok(ref event) if matches!(event, StreamEvent::Frame(frame) if frame.length() == TEXT.len() as u32) => break, Err(_) => (), _ => panic!("must be frame with written text"), } @@ -1168,7 +1168,7 @@ mod test { // poll_read will drain the window-update frame and internally call // writeable_wake.wake(), which notifies write_fw. // (The window-update has no data so read returns Pending again.) - let _ = Pin::new(&mut stream).poll_read(&mut read_cx, &mut rbuf2); + let _ignore = Pin::new(&mut stream).poll_read(&mut read_cx, &mut rbuf2); // After handle_window_update → writeable_wake.wake(), the write task // (write_waker) must now be notified so it can retry and succeed.