diff --git a/Cargo.lock b/Cargo.lock index 9909dfd3cc..e3ac122639 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1581,6 +1581,7 @@ dependencies = [ "criterion", "futures", "rand 0.8.5", + "rand_chacha 0.3.1", "rand_core 0.6.4", "thiserror 2.0.17", "x25519-dalek", diff --git a/stream/Cargo.toml b/stream/Cargo.toml index c1c2a88280..da3e4d8d35 100644 --- a/stream/Cargo.toml +++ b/stream/Cargo.toml @@ -30,6 +30,13 @@ zeroize.workspace = true [dev-dependencies] criterion.workspace = true +rand.workspace = true +rand_chacha.workspace = true [lib] bench = false + +[[bench]] +name = "stream" +harness = false +path = "src/benches/bench.rs" diff --git a/stream/src/benches/bench.rs b/stream/src/benches/bench.rs new file mode 100644 index 0000000000..2708e85596 --- /dev/null +++ b/stream/src/benches/bench.rs @@ -0,0 +1,7 @@ +//! Benchmarks for the stream crate. + +use criterion::criterion_main; + +mod send_frame; + +criterion_main!(send_frame::benches); diff --git a/stream/src/benches/send_frame.rs b/stream/src/benches/send_frame.rs new file mode 100644 index 0000000000..ac28d22d8b --- /dev/null +++ b/stream/src/benches/send_frame.rs @@ -0,0 +1,114 @@ +use commonware_runtime::{benchmarks::tokio, mocks, Stream as _}; +use commonware_stream::utils::codec::{send_frame, BufferedSender}; +use criterion::{criterion_group, Criterion, Throughput}; +use rand::{Rng, RngCore, SeedableRng as _}; +use rand_chacha::ChaCha8Rng; +use std::time::{Duration, Instant}; + +/// Maximum message size for benchmarks. 
+const MAX_MESSAGE_SIZE: usize = 2usize.pow(17);
+
+fn generate_message_sizes(
+    rng: &mut ChaCha8Rng,
+    count: usize,
+    min: usize,
+    max: usize,
+) -> Vec { // NOTE(review): element type looks stripped from this diff; presumably Vec<usize> — confirm
+    (0..count).map(|_| rng.gen_range(min..=max)).collect()
+}
+
+fn generate_messages(rng: &mut ChaCha8Rng, sizes: &[usize]) -> Vec> { // NOTE(review): garbled; presumably Vec<Vec<u8>> — confirm
+    sizes
+        .iter()
+        .map(|&size| {
+            let mut data = vec![0u8; size];
+            rng.fill_bytes(&mut data);
+            data
+        })
+        .collect()
+}
+
+fn bench_send_frame(c: &mut Criterion) {
+    let runner = tokio::Runner::default();
+
+    // Traffic patterns to benchmark, as (min_size, max_size, count) tuples
+    let patterns = [
+        (32, 256, 5000),     // Small control messages
+        (1024, 65536, 5000), // Large data messages
+        (64, 8192, 5000),    // Typical mix
+    ];
+
+    for (min_size, max_size, count) in patterns {
+        let mut rng = ChaCha8Rng::seed_from_u64(42);
+        let sizes = generate_message_sizes(&mut rng, count, min_size, max_size);
+        let messages = generate_messages(&mut rng, &sizes);
+        let total_bytes: usize = sizes.iter().sum();
+
+        let mut group = c.benchmark_group(module_path!());
+        group.throughput(Throughput::Bytes(total_bytes as u64));
+
+        let bench_name = move |method: &str| {
+            format!("{method}/num_messages={count} min_size={min_size} max_size={max_size}",)
+        };
+        group.bench_function(bench_name("unbuffered_sender"), |b| {
+            b.to_async(&runner).iter_custom(|iters| {
+                let messages = messages.clone();
+                async move {
+                    let mut duration = Duration::ZERO;
+
+                    for _ in 0..iters {
+                        let (mut sink, mut stream) = mocks::Channel::init();
+
+                        let start = Instant::now();
+                        for msg in messages.iter() {
+                            send_frame(&mut sink, msg, MAX_MESSAGE_SIZE).await.unwrap();
+                        }
+                        duration += start.elapsed();
+
+                        // Drain each frame (4-byte prefix + payload) outside the timed region
+                        for msg in messages.iter() {
+                            let _ = stream.recv(vec![0u8; 4 + msg.len()]).await;
+                        }
+                    }
+
+                    duration
+                }
+            });
+        });
+
+        group.bench_function(bench_name("buffered_sender"), |b| {
+            b.to_async(&runner).iter_custom(|iters| {
+                let messages = messages.clone();
+                async move {
+                    let mut duration = Duration::ZERO;
+
+                    for _ in 0..iters {
+                        let (sink, mut stream) = mocks::Channel::init();
+                        let mut sender = BufferedSender::new(sink, MAX_MESSAGE_SIZE);
+
+                        let start = Instant::now();
+                        for msg in messages.iter() {
+                            sender.send_frame(msg).await.unwrap();
+                        }
+                        duration += start.elapsed();
+
+                        // Drain each frame (4-byte prefix + payload) outside the timed region
+                        for msg in messages.iter() {
+                            let _ = stream.recv(vec![0u8; 4 + msg.len()]).await;
+                        }
+                    }
+
+                    duration
+                }
+            });
+        });
+
+        group.finish();
+    }
+}
+
+criterion_group! {
+    name = benches;
+    config = Criterion::default().sample_size(20);
+    targets = bench_send_frame
+}
diff --git a/stream/src/lib.rs b/stream/src/lib.rs
index 6af6ed2e3c..bfe8a76f61 100644
--- a/stream/src/lib.rs
+++ b/stream/src/lib.rs
@@ -62,7 +62,7 @@
 
 pub mod utils;
 
-use crate::utils::codec::{recv_frame, send_frame};
+use crate::utils::codec::{recv_frame, send_frame, BufferedSender};
 use bytes::Bytes;
 use commonware_codec::{DecodeExt, Encode as _, Error as CodecError};
 use commonware_cryptography::{
@@ -201,8 +201,7 @@ pub async fn dial(
     Ok((
         Sender {
             cipher: send,
-            sink,
-            max_message_size: config.max_message_size,
+            sender: BufferedSender::new(sink, config.max_message_size + CIPHERTEXT_OVERHEAD),
         },
         Receiver {
             cipher: recv,
@@ -268,8 +267,7 @@ pub async fn listen<
         peer,
         Sender {
             cipher: send,
-            sink,
-            max_message_size: config.max_message_size,
+            sender: BufferedSender::new(sink, config.max_message_size + CIPHERTEXT_OVERHEAD),
         },
         Receiver {
             cipher: recv,
@@ -288,20 +286,14 @@ pub async fn listen<
 /// Sends encrypted messages to a peer.
 pub struct Sender {
     cipher: SendCipher,
-    sink: O,
-    max_message_size: usize,
+    sender: BufferedSender,
 }
 
 impl Sender {
     /// Encrypts and sends a message to the peer.
     pub async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
         let c = self.cipher.send(msg)?;
-        send_frame(
-            &mut self.sink,
-            &c,
-            self.max_message_size + CIPHERTEXT_OVERHEAD,
-        )
-        .await?;
+        self.sender.send_frame(&c).await?;
         Ok(())
     }
 }
diff --git a/stream/src/utils/codec.rs b/stream/src/utils/codec.rs
index b268d2e87c..ac5e8746b2 100644
--- a/stream/src/utils/codec.rs
+++ b/stream/src/utils/codec.rs
@@ -23,6 +23,79 @@ pub async fn send_frame(
     sink.send(prefixed_buf).await.map_err(Error::SendFailed)
 }
 
+/// A frame sender that assembles each frame in a reusable buffer.
+///
+/// Instead of allocating a new buffer for each send as [`send_frame`] does,
+/// this struct keeps a `BytesMut` scratch buffer (`send_buf`) across sends.
+/// NOTE(review): frames leave via `split()`; confirm how much capacity is actually reused.
+pub struct BufferedSender { // NOTE(review): generic list looks stripped from this diff (`sink` is typed `S`) — confirm
+    sink: S,
+    max_message_size: usize, // largest payload length (prefix excluded) that send_frame accepts
+    /// Reusable buffer for the length-prefixed frame.
+    /// Grows to accommodate the largest message sent.
+    send_buf: BytesMut,
+}
+
+impl BufferedSender {
+    /// Creates a new `BufferedSender` wrapping the given sink.
+    pub fn new(sink: S, max_message_size: usize) -> Self {
+        Self {
+            sink,
+            max_message_size,
+            // Initial capacity of 1024 bytes; send_frame reserves more as frames require
+            send_buf: BytesMut::with_capacity(1024),
+        }
+    }
+
+    /// Returns a reference to the underlying sink.
+    pub const fn sink(&self) -> &S {
+        &self.sink
+    }
+
+    /// Returns a mutable reference to the underlying sink.
+    pub const fn sink_mut(&mut self) -> &mut S {
+        &mut self.sink
+    }
+
+    /// Consumes the sender and returns the underlying sink.
+    pub fn into_inner(self) -> S {
+        self.sink
+    }
+
+    /// Sends `buf` as one frame: a 4-byte length prefix followed by the payload.
+    /// Returns `Error::SendTooLarge` if `buf` exceeds the limit, `Error::SendFailed` if the sink fails.
+    pub async fn send_frame(&mut self, buf: &[u8]) -> Result<(), Error> {
+        // Reject payloads over the configured limit before touching the buffer.
+        let n = buf.len();
+        if n > self.max_message_size {
+            return Err(Error::SendTooLarge(n));
+        }
+
+        // The prefix is a u32, so a length that does not fit is also too large.
+        let len: u32 = n.try_into().map_err(|_| Error::SendTooLarge(n))?;
+        let frame_size = 4 + n;
+
+        // Start from an empty buffer; clear() leaves the capacity untouched.
+        self.send_buf.clear();
+
+        // Reserve room for two frames so split() below leaves spare capacity for
+        // the next send. reserve() takes the bytes needed beyond len() (0 here),
+        // so pass the whole target; passing only the shortfall against capacity()
+        // under-reserves and is a no-op whenever capacity() >= frame_size.
+        let target_capacity = frame_size * 2;
+        self.send_buf.reserve(target_capacity);
+
+        // Build the frame: 4-byte length prefix followed by the payload.
+        self.send_buf.put_u32(len);
+        self.send_buf.extend_from_slice(buf);
+
+        // Split off the filled part to send; the unused tail of the reservation
+        // stays in self.send_buf for the next frame.
+        let to_send = self.send_buf.split();
+        self.sink.send(to_send).await.map_err(Error::SendFailed)
+    }
+}
+
 /// Receives data from the stream with a 4-byte length prefix.
 /// Returns an error if the message is too large or the stream is closed.
 pub async fn recv_frame(
@@ -188,4 +261,71 @@ mod tests {
             assert!(matches!(&result, Err(Error::RecvFailed(_))));
         });
     }
+
+    #[test]
+    fn test_buffered_sender_single_message() {
+        let (sink, mut stream) = mocks::Channel::init();
+
+        let executor = deterministic::Runner::default();
+        executor.start(|mut context| async move {
+            let mut buf = [0u8; MAX_MESSAGE_SIZE]; // payload exactly at the limit
+            context.fill(&mut buf);
+
+            let mut sender = super::BufferedSender::new(sink, MAX_MESSAGE_SIZE);
+            let result = sender.send_frame(&buf).await;
+            assert!(result.is_ok());
+
+            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
+            assert_eq!(data.len(), buf.len());
+            assert_eq!(data, Bytes::from(buf.to_vec()));
+        });
+    }
+
+    #[test]
+    fn test_buffered_sender_multiple_messages() {
+        let (sink, mut stream) = mocks::Channel::init();
+
+        let executor = deterministic::Runner::default();
+        executor.start(|mut context| async move {
+            let mut buf1 = [0u8; MAX_MESSAGE_SIZE];
+            let mut buf2 = [0u8; MAX_MESSAGE_SIZE / 2];
+            let mut buf3 = [0u8; 64];
+            context.fill(&mut buf1);
+            context.fill(&mut buf2);
+            context.fill(&mut buf3);
+
+            let mut sender = super::BufferedSender::new(sink, MAX_MESSAGE_SIZE);
+
+            sender.send_frame(&buf1).await.unwrap(); // largest payload first, then smaller ones
+            sender.send_frame(&buf2).await.unwrap();
+            sender.send_frame(&buf3).await.unwrap();
+
+            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
+            assert_eq!(data.len(), buf1.len());
+            assert_eq!(data, Bytes::from(buf1.to_vec()));
+
+            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
+            assert_eq!(data.len(), buf2.len());
+            assert_eq!(data, Bytes::from(buf2.to_vec()));
+
+            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
+            assert_eq!(data.len(), buf3.len());
+            assert_eq!(data, Bytes::from(buf3.to_vec()));
+        });
+    }
+
+    #[test]
+    fn test_buffered_sender_too_large() {
+        let (sink, _stream) = mocks::Channel::init();
+
+        let executor = deterministic::Runner::default();
+        executor.start(|mut context| async move {
+            let mut buf = [0u8; MAX_MESSAGE_SIZE];
+            context.fill(&mut buf);
+
+            let mut sender = super::BufferedSender::new(sink, MAX_MESSAGE_SIZE - 1);
+            let result = sender.send_frame(&buf).await; // payload is one byte over the limit
+            assert!(matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE));
+        });
+    }
 }