Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 177 additions & 5 deletions runtime/src/network/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,129 @@ impl crate::Sink for Sink {
}

/// Implementation of [crate::Stream] for the [tokio] runtime.
///
/// # Buffering
///
/// This stream uses a read buffer to reduce syscall overhead.
/// Multiple small reads can be satisfied from the buffer without
/// additional network operations.
pub struct Stream {
read_timeout: Duration,
stream: OwnedReadHalf,
/// Internal buffer for batching reads.
buffer: Vec<u8>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use BytesMut here along with OwnedReadHalf::read_buf, to align with #2558. Would also get rid of the need to manually track start/end.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#2595 may be a better solution.

/// Start position of valid data in the buffer.
start: usize,
/// End position of valid data in the buffer (exclusive).
end: usize,
}

impl Stream {
/// Number of bytes held in the read buffer that have not yet been handed to a caller.
///
/// This is the width of the `start..end` window over the internal buffer.
#[inline]
const fn buffered(&self) -> usize {
    let (head, tail) = (self.start, self.end);
    tail - head
}

/// Moves any remaining data to the front of the buffer.
#[inline]
fn compact(&mut self) {
if self.start == 0 {
return;
}

let remaining = self.end - self.start;
if remaining > 0 {
self.buffer.copy_within(self.start..self.end, 0);
Copy link
Contributor

@patrick-ogrady patrick-ogrady Dec 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we ever actually need to do this. Whenever we need to fill the buffer, there is nothing of interest left in it, so I think the buffered read can simply start at 0.

}
self.start = 0;
self.end = remaining;
}

/// Reads at least `min_bytes` into the internal buffer, up to available capacity.
/// Returns the total number of bytes read, or an error.
async fn fill_buffer(&mut self, min_bytes: usize) -> Result<usize, Error> {
// Compact first to maximize space for reading
self.compact();

// Use a single deadline for the entire operation to prevent slow-drip attacks
let deadline = tokio::time::Instant::now() + self.read_timeout;
let target = self.end + min_bytes;

// Read at least min_bytes more, up to buffer capacity
while self.end < target {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's preventing me from calling fill_buffer with min_bytes >> self.buffer.len()? It looks like this would loop forever in that event (since self.end will always be less than target).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(If this should never happen, perhaps document the invariant with an assertion check on the min_bytes input)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a case we check before starting the buffered read.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I see where min_bytes is computed on L138, but it feels like you should still assert that it's not too big within this function, in case another use case pops up and someone gets it wrong...

// Compute the remaining time and check if we've timed out
let remaining_time = deadline.saturating_duration_since(tokio::time::Instant::now());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only real ~gotcha on this PR! We don't want to let an adversary force us to sit in this loop indefinitely by feeding us 1 byte at a time.

if remaining_time.is_zero() {
return Err(Error::Timeout);
}

// Read up to the remaining time
let bytes_read = timeout(
remaining_time,
self.stream.read(&mut self.buffer[self.end..]),
)
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::RecvFailed)?;
if bytes_read == 0 {
return Err(Error::RecvFailed); // EOF
}

self.end += bytes_read;
}

Ok(self.end - self.start)
}

/// Hands out buffered bytes to the caller.
///
/// Copies as many buffered bytes as fit into the front of `output`, then advances
/// `start` past them so they are not returned again. Returns the number of bytes
/// copied, which is the smaller of `output.len()` and the buffered byte count.
#[inline]
fn copy_from_buffer(&mut self, output: &mut [u8]) -> usize {
    let count = self.buffered().min(output.len());
    let stop = self.start + count;

    // Only the first `count` bytes of `output` are written; the tail is left untouched.
    let (dst, _) = output.split_at_mut(count);
    dst.copy_from_slice(&self.buffer[self.start..stop]);

    self.start = stop;
    count
}
}

impl crate::Stream for Stream {
async fn recv(&mut self, buf: impl Into<StableBuf> + Send) -> Result<StableBuf, Error> {
let mut buf = buf.into();
if buf.is_empty() {
let needed = buf.len();
if needed == 0 {
return Ok(buf);
}
let mut filled = 0;

// Time out if we take too long to read
timeout(self.read_timeout, self.stream.read_exact(buf.as_mut()))
// First, drain any buffered data
Copy link
Contributor

@patrick-ogrady patrick-ogrady Dec 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tokio has some built-ins but AFAICT nothing quite right

if self.buffered() > 0 {
filled = self.copy_from_buffer(&mut buf.as_mut()[..needed]);
if filled == needed {
return Ok(buf);
}
}

// Need more data. If the remaining request is large (>= buffer capacity),
// read directly into the output buffer to avoid extra copies.
let remaining = needed - filled;
if remaining >= self.buffer.len() {
// Read directly into output buffer
timeout(
self.read_timeout,
self.stream.read_exact(&mut buf.as_mut()[filled..]),
)
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::RecvFailed)?;
return Ok(buf);
}

// For smaller remaining requests, fill the buffer with at least
// the remaining bytes needed (but opportunistically read more),
// then copy out what we need.
self.fill_buffer(remaining).await?;
self.copy_from_buffer(&mut buf.as_mut()[filled..needed]);

Ok(buf)
}
Expand Down Expand Up @@ -83,6 +189,9 @@ impl crate::Listener for Listener {
Stream {
read_timeout: self.cfg.read_timeout,
stream,
buffer: vec![0u8; self.cfg.read_buffer_size],
start: 0,
end: 0,
},
))
}
Expand Down Expand Up @@ -110,6 +219,11 @@ pub struct Config {
read_timeout: Duration,
/// Write timeout for connections, after which the connection will be closed
write_timeout: Duration,
/// Size of the read buffer for batching network reads.
///
/// A larger buffer reduces syscall overhead by reading more data per call,
/// but uses more memory per connection. Defaults to 64 KB.
read_buffer_size: usize,
}

#[cfg_attr(feature = "iouring-network", allow(dead_code))]
Expand All @@ -130,6 +244,11 @@ impl Config {
self.write_timeout = write_timeout;
self
}
/// Sets the read buffer size (in bytes) and returns the updated configuration.
///
/// See [Config]
pub const fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self {
// Builder-style setter: `self` is taken by value, updated, and handed back.
self.read_buffer_size = read_buffer_size;
self
}

// Getters
/// See [Config]
Expand All @@ -144,6 +263,10 @@ impl Config {
pub const fn write_timeout(&self) -> Duration {
self.write_timeout
}
/// Returns the configured read buffer size in bytes.
///
/// See [Config]
pub const fn read_buffer_size(&self) -> usize {
self.read_buffer_size
}
}

impl Default for Config {
Expand All @@ -152,6 +275,7 @@ impl Default for Config {
tcp_nodelay: None,
read_timeout: Duration::from_secs(60),
write_timeout: Duration::from_secs(30),
read_buffer_size: 64 * 1024, // 64 KB
}
}
}
Expand Down Expand Up @@ -213,16 +337,22 @@ impl crate::Network for Network {
Stream {
read_timeout: self.cfg.read_timeout,
stream,
buffer: vec![0u8; self.cfg.read_buffer_size],
start: 0,
end: 0,
},
))
}
}

#[cfg(test)]
mod tests {
use crate::network::{tests, tokio as TokioNetwork};
use crate::{
network::{tests, tokio as TokioNetwork},
Listener as _, Network as _, Sink as _, Stream as _,
};
use commonware_macros::test_group;
use std::time::Duration;
use std::time::{Duration, Instant};

#[tokio::test]
async fn test_trait() {
Expand All @@ -248,4 +378,46 @@ mod tests {
})
.await;
}

#[tokio::test]
async fn test_small_send_read_quickly() {
    // A generous read timeout makes a stalled read easy to tell apart from a prompt one.
    let read_timeout = Duration::from_secs(30);
    let cfg = TokioNetwork::Config::default()
        .with_read_timeout(read_timeout)
        .with_write_timeout(Duration::from_secs(5));
    let network = TokioNetwork::Network::from(cfg);

    // Listen on an ephemeral local port.
    let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
    let addr = listener.local_addr().unwrap();

    // Accept one connection in the background and time a 10-byte read,
    // which is far smaller than the default 64 KB read buffer.
    let reader = tokio::spawn(async move {
        let (_addr, _sink, mut stream) = listener.accept().await.unwrap();

        let began = Instant::now();
        let buf = stream.recv(vec![0u8; 10]).await.unwrap();
        let elapsed = began.elapsed();

        (buf, elapsed)
    });

    // Dial the listener and push exactly the bytes the reader is waiting for.
    let (mut sink, _stream) = network.dial(addr).await.unwrap();
    let msg: Vec<u8> = (1..=10).collect();
    sink.send(msg.clone()).await.unwrap();

    // Collect the reader's result.
    let (received, elapsed) = reader.await.unwrap();

    // The payload must arrive intact.
    assert_eq!(received.as_ref(), &msg[..]);

    // The read must finish before the read timeout elapses, i.e. it returned
    // because data arrived rather than because the timeout fired.
    assert!(elapsed < read_timeout);
}
}
Loading