Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ pub trait Listener: Sync + Send + 'static {
/// messages over a network connection.
pub trait Sink: Sync + Send + 'static {
/// Send a message to the sink.
///
/// # Warning
///
/// If the sink returns an error, part of the message may still be delivered.
fn send(
&mut self,
msg: impl Into<StableBuf> + Send,
Expand All @@ -482,6 +486,10 @@ pub trait Sink: Sync + Send + 'static {
pub trait Stream: Sync + Send + 'static {
/// Receive a message from the stream, storing it in the given buffer.
/// Reads exactly the number of bytes that fit in the buffer.
///
/// # Warning
///
/// If the stream returns an error, partially read data may be discarded.
fn recv(
&mut self,
buf: impl Into<StableBuf> + Send,
Expand Down
114 changes: 108 additions & 6 deletions runtime/src/network/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::Error;
use commonware_utils::StableBuf;
use std::{net::SocketAddr, time::Duration};
use tokio::{
io::{AsyncReadExt as _, AsyncWriteExt as _},
io::{AsyncReadExt as _, AsyncWriteExt as _, BufReader},
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpListener, TcpStream,
Expand All @@ -29,9 +29,12 @@ impl crate::Sink for Sink {
}

/// Implementation of [crate::Stream] for the [tokio] runtime.
///
/// Uses a [`BufReader`] to reduce syscall overhead. Multiple small reads
/// can be satisfied from the buffer without additional network operations.
pub struct Stream {
read_timeout: Duration,
stream: OwnedReadHalf,
stream: BufReader<OwnedReadHalf>,
}

impl crate::Stream for Stream {
Expand Down Expand Up @@ -82,7 +85,7 @@ impl crate::Listener for Listener {
},
Stream {
read_timeout: self.cfg.read_timeout,
stream,
stream: BufReader::with_capacity(self.cfg.read_buffer_size, stream),
},
))
}
Expand Down Expand Up @@ -110,6 +113,11 @@ pub struct Config {
read_timeout: Duration,
/// Write timeout for connections, after which the connection will be closed
write_timeout: Duration,
/// Size of the read buffer for batching network reads.
///
/// A larger buffer reduces syscall overhead by reading more data per call,
/// but uses more memory per connection. Defaults to 64 KB.
read_buffer_size: usize,
}

#[cfg_attr(feature = "iouring-network", allow(dead_code))]
Expand All @@ -130,6 +138,11 @@ impl Config {
self.write_timeout = write_timeout;
self
}
/// See [Config]
pub const fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self {
self.read_buffer_size = read_buffer_size;
self
}

// Getters
/// See [Config]
Expand All @@ -144,6 +157,10 @@ impl Config {
pub const fn write_timeout(&self) -> Duration {
self.write_timeout
}
/// See [Config]
pub const fn read_buffer_size(&self) -> usize {
self.read_buffer_size
}
}

impl Default for Config {
Expand All @@ -152,6 +169,7 @@ impl Default for Config {
tcp_nodelay: None,
read_timeout: Duration::from_secs(60),
write_timeout: Duration::from_secs(30),
read_buffer_size: 64 * 1024, // 64 KB
}
}
}
Expand Down Expand Up @@ -212,17 +230,20 @@ impl crate::Network for Network {
},
Stream {
read_timeout: self.cfg.read_timeout,
stream,
stream: BufReader::with_capacity(self.cfg.read_buffer_size, stream),
},
))
}
}

#[cfg(test)]
mod tests {
use crate::network::{tests, tokio as TokioNetwork};
use crate::{
network::{tests, tokio as TokioNetwork},
Listener as _, Network as _, Sink as _, Stream as _,
};
use commonware_macros::test_group;
use std::time::Duration;
use std::time::{Duration, Instant};

#[tokio::test]
async fn test_trait() {
Expand All @@ -248,4 +269,85 @@ mod tests {
})
.await;
}

#[tokio::test]
async fn test_small_send_read_quickly() {
// Use a long read timeout to ensure we're not just waiting for timeout
let read_timeout = Duration::from_secs(30);
let network = TokioNetwork::Network::from(
TokioNetwork::Config::default()
.with_read_timeout(read_timeout)
.with_write_timeout(Duration::from_secs(5)),
);

// Bind a listener
let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
let addr = listener.local_addr().unwrap();

// Spawn a task to accept and read
let reader = tokio::spawn(async move {
let (_addr, _sink, mut stream) = listener.accept().await.unwrap();

// Read a small message (much smaller than the 64KB buffer)
let start = Instant::now();
let buf = stream.recv(vec![0u8; 10]).await.unwrap();
let elapsed = start.elapsed();

(buf, elapsed)
});

// Connect and send a small message
let (mut sink, _stream) = network.dial(addr).await.unwrap();
let msg = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10];
sink.send(msg.clone()).await.unwrap();

// Wait for the reader to complete
let (received, elapsed) = reader.await.unwrap();

// Verify we got the right data
assert_eq!(received.as_ref(), &msg[..]);

// Verify it completed quickly (well under the read timeout)
// Should complete in milliseconds, not seconds
assert!(elapsed < read_timeout);
}

#[tokio::test]
async fn test_read_timeout_with_partial_data() {
// Use a short read timeout to make the test fast
let read_timeout = Duration::from_millis(100);
let network = TokioNetwork::Network::from(
TokioNetwork::Config::default()
.with_read_timeout(read_timeout)
.with_write_timeout(Duration::from_secs(5)),
);

// Bind a listener
let mut listener = network.bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
let addr = listener.local_addr().unwrap();

let reader = tokio::spawn(async move {
let (_addr, _sink, mut stream) = listener.accept().await.unwrap();

// Try to read 100 bytes, but only 5 will be sent
let start = Instant::now();
let result = stream.recv(vec![0u8; 100]).await;
let elapsed = start.elapsed();

(result, elapsed)
});

// Connect and send only partial data
let (mut sink, _stream) = network.dial(addr).await.unwrap();
sink.send(vec![1u8, 2, 3, 4, 5]).await.unwrap();

// Wait for the reader to complete
let (result, elapsed) = reader.await.unwrap();
assert!(matches!(result, Err(crate::Error::Timeout)));

// Verify the timeout occurred around the expected time
assert!(elapsed >= read_timeout);
// Allow some margin for timing variance
assert!(elapsed < read_timeout * 2);
}
}
Loading