diff --git a/runtime/src/network/tokio.rs b/runtime/src/network/tokio.rs index b42ed2f954..fdcc1cb375 100644 --- a/runtime/src/network/tokio.rs +++ b/runtime/src/network/tokio.rs @@ -29,23 +29,129 @@ impl crate::Sink for Sink { } /// Implementation of [crate::Stream] for the [tokio] runtime. +/// +/// # Buffering +/// +/// This stream uses a read buffer to reduce syscall overhead. +/// Multiple small reads can be satisfied from the buffer without +/// additional network operations. pub struct Stream { read_timeout: Duration, stream: OwnedReadHalf, + /// Internal read buffer, allocated with `read_buffer_size` bytes when the stream is created. + buffer: Vec, + /// Index of the first unread byte in the buffer. + start: usize, + /// Index one past the last valid byte in the buffer (exclusive end). + end: usize, +} + +impl Stream { + /// Returns the number of buffered bytes available (`end - start`). + #[inline] + const fn buffered(&self) -> usize { + self.end - self.start + } + + /// Moves any unread bytes to the front of the buffer so the tail is free for new reads. + #[inline] + fn compact(&mut self) { + if self.start == 0 { + return; + } + + let remaining = self.end - self.start; + if remaining > 0 { + self.buffer.copy_within(self.start..self.end, 0); + } + self.start = 0; + self.end = remaining; + } + + /// Reads at least `min_bytes` more bytes into the internal buffer (compacting it first), under a single `read_timeout` deadline. + /// Returns the number of bytes buffered afterwards (`end - start`), `Error::Timeout` if the deadline passes, or `Error::RecvFailed` on a read error or EOF. 
+ async fn fill_buffer(&mut self, min_bytes: usize) -> Result { + // Compact first to maximize space for reading + self.compact(); + + // Use a single deadline for the entire operation to prevent slow-drip attacks + let deadline = tokio::time::Instant::now() + self.read_timeout; + let target = self.end + min_bytes; + + // Read until at least min_bytes more are buffered; each read may fill up to the end of the buffer (NOTE(review): assumes min_bytes fits in the free space - confirm for any new caller) + while self.end < target { + // Compute the remaining time and check if we've timed out + let remaining_time = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining_time.is_zero() { + return Err(Error::Timeout); + } + + // Wait for the next read, bounded by the time left before the deadline + let bytes_read = timeout( + remaining_time, + self.stream.read(&mut self.buffer[self.end..]), + ) + .await + .map_err(|_| Error::Timeout)? + .map_err(|_| Error::RecvFailed)?; + if bytes_read == 0 { + return Err(Error::RecvFailed); // EOF + } + + self.end += bytes_read; + } + + Ok(self.end - self.start) + } + + /// Copies up to `output.len()` buffered bytes into the front of `output` and advances `start` past them. + /// Returns the number of bytes copied. + #[inline] + fn copy_from_buffer(&mut self, output: &mut [u8]) -> usize { + let to_copy = output.len().min(self.buffered()); + output[..to_copy].copy_from_slice(&self.buffer[self.start..self.start + to_copy]); + self.start += to_copy; + to_copy + } } impl crate::Stream for Stream { async fn recv(&mut self, buf: impl Into + Send) -> Result { let mut buf = buf.into(); - if buf.is_empty() { + let needed = buf.len(); + if needed == 0 { return Ok(buf); } + let mut filled = 0; - // Time out if we take too long to read - timeout(self.read_timeout, self.stream.read_exact(buf.as_mut())) + // First, drain any buffered data + if self.buffered() > 0 { + filled = self.copy_from_buffer(&mut buf.as_mut()[..needed]); + if filled == needed { + return Ok(buf); + } + } + + // Need more data (the internal buffer is now empty). If the remaining request is large (>= buffer length), + // read directly into the output buffer to avoid extra copies. 
+ let remaining = needed - filled; + if remaining >= self.buffer.len() { + // Read directly into output buffer, with a fresh `read_timeout` covering the whole read_exact + timeout( + self.read_timeout, + self.stream.read_exact(&mut buf.as_mut()[filled..]), + ) .await .map_err(|_| Error::Timeout)? .map_err(|_| Error::RecvFailed)?; + return Ok(buf); + } + + // For smaller remaining requests, fill the buffer with at least + // the remaining bytes needed (but opportunistically read more), + // then copy out what we need. + self.fill_buffer(remaining).await?; + self.copy_from_buffer(&mut buf.as_mut()[filled..needed]); Ok(buf) } @@ -83,6 +189,9 @@ impl crate::Listener for Listener { Stream { read_timeout: self.cfg.read_timeout, stream, + buffer: vec![0u8; self.cfg.read_buffer_size], + start: 0, + end: 0, }, )) } @@ -110,6 +219,11 @@ pub struct Config { read_timeout: Duration, /// Write timeout for connections, after which the connection will be closed write_timeout: Duration, + /// Size (in bytes) of the read buffer for batching network reads. + /// + /// A larger buffer reduces syscall overhead by reading more data per call, + /// but uses more memory per connection. A size of 0 makes every recv read directly into the caller's buffer. Defaults to 64 KiB. 
+ read_buffer_size: usize, } #[cfg_attr(feature = "iouring-network", allow(dead_code))] @@ -130,6 +244,11 @@ impl Config { self.write_timeout = write_timeout; self } + /// See [Config] + pub const fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self { + self.read_buffer_size = read_buffer_size; + self + } // Getters /// See [Config] @@ -144,6 +263,10 @@ impl Config { pub const fn write_timeout(&self) -> Duration { self.write_timeout } + /// See [Config] + pub const fn read_buffer_size(&self) -> usize { + self.read_buffer_size + } } impl Default for Config { @@ -152,6 +275,7 @@ impl Default for Config { tcp_nodelay: None, read_timeout: Duration::from_secs(60), write_timeout: Duration::from_secs(30), + read_buffer_size: 64 * 1024, // 64 KiB } } } @@ -213,6 +337,9 @@ impl crate::Network for Network { Stream { read_timeout: self.cfg.read_timeout, stream, + buffer: vec![0u8; self.cfg.read_buffer_size], + start: 0, + end: 0, }, )) } @@ -220,9 +347,12 @@ impl crate::Network for Network { #[cfg(test)] mod tests { - use crate::network::{tests, tokio as TokioNetwork}; + use crate::{ + network::{tests, tokio as TokioNetwork}, + Listener as _, Network as _, Sink as _, Stream as _, + }; use commonware_macros::test_group; - use std::time::Duration; + use std::time::{Duration, Instant}; #[tokio::test] async fn test_trait() { @@ -248,4 +378,46 @@ mod tests { }) .await; } + + #[tokio::test] + async fn test_small_send_read_quickly() { + // Use a long read timeout to ensure we're not just waiting for timeout + let read_timeout = Duration::from_secs(30); + let network = TokioNetwork::Network::from( + TokioNetwork::Config::default() + .with_read_timeout(read_timeout) + .with_write_timeout(Duration::from_secs(5)), + ); + + // Bind a listener + let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap(); + let addr = listener.local_addr().unwrap(); + + // Spawn a task to accept and read + let reader = tokio::spawn(async move { + let (_addr, _sink, mut 
stream) = listener.accept().await.unwrap(); + + // Read a small message (much smaller than the default 64 KiB buffer), timing only the recv call + let start = Instant::now(); + let buf = stream.recv(vec![0u8; 10]).await.unwrap(); + let elapsed = start.elapsed(); + + (buf, elapsed) + }); + + // Connect and send a small message + let (mut sink, _stream) = network.dial(addr).await.unwrap(); + let msg = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + sink.send(msg.clone()).await.unwrap(); + + // Wait for the reader to complete + let (received, elapsed) = reader.await.unwrap(); + + // Verify we got the right data + assert_eq!(received.as_ref(), &msg[..]); + + // Verify recv returned before the read timeout elapsed + // NOTE(review): this only bounds elapsed by the 30s read timeout; it does not assert millisecond latency + assert!(elapsed < read_timeout); + } }