diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs
index 6eb07d8366..2720ee5455 100644
--- a/runtime/src/lib.rs
+++ b/runtime/src/lib.rs
@@ -471,6 +471,10 @@ pub trait Listener: Sync + Send + 'static {
 /// messages over a network connection.
 pub trait Sink: Sync + Send + 'static {
     /// Send a message to the sink.
+    ///
+    /// # Warning
+    ///
+    /// If the sink returns an error, part of the message may still be delivered.
     fn send(
         &mut self,
         msg: impl Into<StableBuf> + Send,
@@ -482,6 +486,10 @@ pub trait Stream: Sync + Send + 'static {
     /// Receive a message from the stream, storing it in the given buffer.
     /// Reads exactly the number of bytes that fit in the buffer.
+    ///
+    /// # Warning
+    ///
+    /// If the stream returns an error, partially read data may be discarded.
     fn recv(
         &mut self,
         buf: impl Into<StableBuf> + Send,
diff --git a/runtime/src/network/tokio.rs b/runtime/src/network/tokio.rs
index b42ed2f954..683c2ceca0 100644
--- a/runtime/src/network/tokio.rs
+++ b/runtime/src/network/tokio.rs
@@ -2,7 +2,7 @@ use crate::Error;
 use commonware_utils::StableBuf;
 use std::{net::SocketAddr, time::Duration};
 use tokio::{
-    io::{AsyncReadExt as _, AsyncWriteExt as _},
+    io::{AsyncReadExt as _, AsyncWriteExt as _, BufReader},
     net::{
         tcp::{OwnedReadHalf, OwnedWriteHalf},
         TcpListener, TcpStream,
@@ -29,9 +29,12 @@ impl crate::Sink for Sink {
 }
 
 /// Implementation of [crate::Stream] for the [tokio] runtime.
+///
+/// Uses a [`BufReader`] to reduce syscall overhead. Multiple small reads
+/// can be satisfied from the buffer without additional network operations.
 pub struct Stream {
     read_timeout: Duration,
-    stream: OwnedReadHalf,
+    stream: BufReader<OwnedReadHalf>,
 }
 
 impl crate::Stream for Stream {
@@ -82,7 +85,7 @@ impl crate::Listener for Listener {
             },
             Stream {
                 read_timeout: self.cfg.read_timeout,
-                stream,
+                stream: BufReader::with_capacity(self.cfg.read_buffer_size, stream),
             },
         ))
     }
@@ -110,6 +113,11 @@ pub struct Config {
     read_timeout: Duration,
     /// Write timeout for connections, after which the connection will be closed
     write_timeout: Duration,
+    /// Size of the read buffer for batching network reads.
+    ///
+    /// A larger buffer reduces syscall overhead by reading more data per call,
+    /// but uses more memory per connection. Defaults to 64 KB.
+    read_buffer_size: usize,
 }
 
 #[cfg_attr(feature = "iouring-network", allow(dead_code))]
@@ -130,6 +138,11 @@ impl Config {
         self.write_timeout = write_timeout;
         self
     }
+    /// See [Config]
+    pub const fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self {
+        self.read_buffer_size = read_buffer_size;
+        self
+    }
 
     // Getters
     /// See [Config]
@@ -144,6 +157,10 @@ impl Config {
     pub const fn write_timeout(&self) -> Duration {
         self.write_timeout
     }
+    /// See [Config]
+    pub const fn read_buffer_size(&self) -> usize {
+        self.read_buffer_size
+    }
 }
 
 impl Default for Config {
@@ -152,6 +169,7 @@ impl Default for Config {
             tcp_nodelay: None,
             read_timeout: Duration::from_secs(60),
             write_timeout: Duration::from_secs(30),
+            read_buffer_size: 64 * 1024, // 64 KB
         }
     }
 }
@@ -212,7 +230,7 @@ impl crate::Network for Network {
             },
             Stream {
                 read_timeout: self.cfg.read_timeout,
-                stream,
+                stream: BufReader::with_capacity(self.cfg.read_buffer_size, stream),
             },
         ))
     }
@@ -220,9 +238,12 @@
 }
 
 #[cfg(test)]
 mod tests {
-    use crate::network::{tests, tokio as TokioNetwork};
+    use crate::{
+        network::{tests, tokio as TokioNetwork},
+        Listener as _, Network as _, Sink as _, Stream as _,
+    };
     use commonware_macros::test_group;
-    use std::time::Duration;
+    use std::time::{Duration, Instant};
 
     #[tokio::test]
     async fn test_trait() {
@@ -248,4 +269,85 @@ mod tests {
         })
         .await;
     }
+
+    #[tokio::test]
+    async fn test_small_send_read_quickly() {
+        // Use a long read timeout to ensure we're not just waiting for timeout
+        let read_timeout = Duration::from_secs(30);
+        let network = TokioNetwork::Network::from(
+            TokioNetwork::Config::default()
+                .with_read_timeout(read_timeout)
+                .with_write_timeout(Duration::from_secs(5)),
+        );
+
+        // Bind a listener
+        let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        // Spawn a task to accept and read
+        let reader = tokio::spawn(async move {
+            let (_addr, _sink, mut stream) = listener.accept().await.unwrap();
+
+            // Read a small message (much smaller than the 64KB buffer)
+            let start = Instant::now();
+            let buf = stream.recv(vec![0u8; 10]).await.unwrap();
+            let elapsed = start.elapsed();
+
+            (buf, elapsed)
+        });
+
+        // Connect and send a small message
+        let (mut sink, _stream) = network.dial(addr).await.unwrap();
+        let msg = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10];
+        sink.send(msg.clone()).await.unwrap();
+
+        // Wait for the reader to complete
+        let (received, elapsed) = reader.await.unwrap();
+
+        // Verify we got the right data
+        assert_eq!(received.as_ref(), &msg[..]);
+
+        // Verify it completed quickly (well under the read timeout)
+        // Should complete in milliseconds, not seconds
+        assert!(elapsed < read_timeout);
+    }
+
+    #[tokio::test]
+    async fn test_read_timeout_with_partial_data() {
+        // Use a short read timeout to make the test fast
+        let read_timeout = Duration::from_millis(100);
+        let network = TokioNetwork::Network::from(
+            TokioNetwork::Config::default()
+                .with_read_timeout(read_timeout)
+                .with_write_timeout(Duration::from_secs(5)),
+        );
+
+        // Bind a listener
+        let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        let reader = tokio::spawn(async move {
+            let (_addr, _sink, mut stream) = listener.accept().await.unwrap();
+
+            // Try to read 100 bytes, but only 5 will be sent
+            let start = Instant::now();
+            let result = stream.recv(vec![0u8; 100]).await;
+            let elapsed = start.elapsed();
+
+            (result, elapsed)
+        });
+
+        // Connect and send only partial data
+        let (mut sink, _stream) = network.dial(addr).await.unwrap();
+        sink.send(vec![1u8, 2, 3, 4, 5]).await.unwrap();
+
+        // Wait for the reader to complete
+        let (result, elapsed) = reader.await.unwrap();
+        assert!(matches!(result, Err(crate::Error::Timeout)));
+
+        // Verify the timeout occurred around the expected time
+        assert!(elapsed >= read_timeout);
+        // Allow some margin for timing variance
+        assert!(elapsed < read_timeout * 2);
+    }
 }