diff --git a/Cargo.lock b/Cargo.lock index 9909dfd3cc..0e54fb3d6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1592,6 +1592,7 @@ name = "commonware-stream-fuzz" version = "0.0.63" dependencies = [ "arbitrary", + "bytes", "chacha20poly1305", "commonware-codec", "commonware-cryptography", diff --git a/examples/sync/src/bin/server.rs b/examples/sync/src/bin/server.rs index 0ab3cb669e..bf5f9eaece 100644 --- a/examples/sync/src/bin/server.rs +++ b/examples/sync/src/bin/server.rs @@ -336,8 +336,8 @@ where outgoing = response_receiver.next() => { if let Some(response) = outgoing { // We have a response to send to the client. - let response_data = response.encode().to_vec(); - if let Err(err) = send_frame(&mut sink, &response_data, MAX_MESSAGE_SIZE).await { + let response_data = response.encode(); + if let Err(err) = send_frame(&mut sink, response_data, MAX_MESSAGE_SIZE).await { info!(client_addr = %client_addr, ?err, "send failed (client likely disconnected)"); state.error_counter.inc(); return Ok(()); diff --git a/examples/sync/src/net/io.rs b/examples/sync/src/net/io.rs index f881f61499..1d378bcce9 100644 --- a/examples/sync/src/net/io.rs +++ b/examples/sync/src/net/io.rs @@ -45,8 +45,8 @@ async fn run_loop( Some(Request { request, response_tx }) => { let request_id = request.request_id(); pending_requests.insert(request_id, response_tx); - let data = request.encode().to_vec(); - if let Err(e) = send_frame(&mut sink, &data, MAX_MESSAGE_SIZE).await { + let data = request.encode(); + if let Err(e) = send_frame(&mut sink, data, MAX_MESSAGE_SIZE).await { if let Some(sender) = pending_requests.remove(&request_id) { let _ = sender.send(Err(Error::Network(e))); } diff --git a/p2p/src/simulated/network.rs b/p2p/src/simulated/network.rs index b3ca9d622c..e33cbad2cf 100644 --- a/p2p/src/simulated/network.rs +++ b/p2p/src/simulated/network.rs @@ -1055,7 +1055,9 @@ impl Link { context.with_label("link").spawn(move |context| async move { // Dial the peer and handshake by sending it the dialer's public key let (mut sink, _) = context.dial(socket).await.unwrap(); - if let Err(err) = send_frame(&mut sink, &dialer, max_size).await { + if let Err(err) = + send_frame(&mut sink, Bytes::from_owner(dialer.clone()), max_size).await + { error!(?err, "failed to send public key to listener"); return; } @@ -1070,7 +1072,7 @@ impl Link { data.extend_from_slice(&channel.to_be_bytes()); data.extend_from_slice(&message); let data = data.freeze(); - let _ = send_frame(&mut sink, &data, max_size).await; + let _ = send_frame(&mut sink, data, max_size).await; // Bump received messages metric received_messages diff --git a/runtime/src/iouring/mod.rs b/runtime/src/iouring/mod.rs index 0027c3a808..af5641a87f 100644 --- a/runtime/src/iouring/mod.rs +++ b/runtime/src/iouring/mod.rs @@ -57,6 +57,7 @@ //! 3. If `shutdown_timeout` is configured, abandons remaining operations after the timeout //! 4. Cleans up and exits +use bytes::Buf; use commonware_utils::StableBuf; use futures::{ channel::{mpsc, oneshot}, @@ -78,13 +79,17 @@ const TIMEOUT_WORK_ID: u64 = u64::MAX; /// Active operations keyed by their work id. /// /// Each entry keeps the caller's oneshot sender, the `StableBuf` that must stay -/// alive until the kernel finishes touching it, and when op_timeout is enabled, -/// the boxed `Timespec` used when we link in an IOSQE_IO_LINK timeout. +/// alive until the kernel finishes touching it, an optional boxed `Buf` that must +/// remain valid and is returned to the caller, an optional keepalive that must +/// stay alive but is not returned, and when op_timeout is enabled, the boxed +/// `Timespec` used when we link in an IOSQE_IO_LINK timeout. type Waiters = HashMap< u64, ( - oneshot::Sender<(i32, Option)>, + oneshot::Sender<(i32, Option, Option>)>, Option, + Option>, + Option>, Option>, ), >; @@ -205,14 +210,22 @@ pub struct Op { /// The submission queue entry to be submitted to the ring. /// Its user data field will be overwritten. Users shouldn't rely on it. pub work: SqueueEntry, - /// Sends the result of the operation and `buffer`. - pub sender: oneshot::Sender<(i32, Option)>, + /// Sends the result of the operation, the `buffer`, and the `buf`. + pub sender: oneshot::Sender<(i32, Option, Option>)>, /// The buffer used for the operation, if any. /// E.g. For read, this is the buffer being read into. /// If None, the operation doesn't use a buffer (e.g. a sync operation). /// We hold the buffer here so it's guaranteed to live until the operation /// completes, preventing write-after-free issues. pub buffer: Option, + /// A boxed `Buf` to keep alive for the duration of the operation and return + /// to the caller. This is useful for vectored I/O where the original buffer + /// must remain valid and the caller needs it back to call `advance()`. + pub buf: Option>, + /// Additional data to keep alive for the duration of the operation but not + /// returned to the caller. Useful for things like iovec arrays that the kernel + /// references during the operation. + pub keepalive: Option>, } // Returns false iff we received a shutdown timeout @@ -235,8 +248,9 @@ fn handle_cqe(waiters: &mut Waiters, cqe: CqueueEntry, cfg: &Config) { result }; - let (result_sender, buffer, _) = waiters.remove(&work_id).expect("missing sender"); - let _ = result_sender.send((result, buffer)); + let (result_sender, buffer, buf, _keepalive, _timespec) = + waiters.remove(&work_id).expect("missing sender"); + let _ = result_sender.send((result, buffer, buf)); } } } @@ -296,6 +310,8 @@ pub(crate) async fn run(cfg: Config, metrics: Arc, mut receiver: mpsc:: mut work, sender, buffer, + buf, + keepalive, } = op; // Assign a unique id @@ -357,7 +373,7 @@ pub(crate) async fn run(cfg: Config, metrics: Arc, mut receiver: mpsc:: }; // We'll send the result of this operation to `sender`. - waiters.insert(work_id, (sender, buffer, timespec)); + waiters.insert(work_id, (sender, buffer, buf, keepalive, timespec)); } // Submit and wait for at least 1 item to be in the completion queue. @@ -480,6 +496,8 @@ mod tests { work: recv, sender: recv_tx, buffer: Some(buf.into()), + buf: None, + keepalive: None, }) .await .expect("failed to send work"); @@ -498,15 +516,17 @@ mod tests { work: write, sender: write_tx, buffer: Some(msg.into()), + buf: None, + keepalive: None, }) .await .expect("failed to send work"); // Wait for the read and write operations to complete. if should_succeed { - let (result, _) = recv_rx.await.expect("failed to receive result"); + let (result, _, _) = recv_rx.await.expect("failed to receive result"); assert!(result > 0, "recv failed: {result}"); - let (result, _) = write_rx.await.expect("failed to receive result"); + let (result, _, _) = write_rx.await.expect("failed to receive result"); assert!(result > 0, "write failed: {result}"); } else { let _ = recv_rx.await; @@ -567,11 +587,13 @@ mod tests { work, sender: tx, buffer: Some(buf.into()), + buf: None, + keepalive: None, }) .await .expect("failed to send work"); // Wait for the timeout - let (result, _) = rx.await.expect("failed to receive result"); + let (result, _, _) = rx.await.expect("failed to receive result"); assert_eq!(result, -libc::ETIMEDOUT); drop(submitter); handle.await.unwrap(); @@ -597,6 +619,8 @@ mod tests { work: timeout, sender: tx, buffer: None, + buf: None, + keepalive: None, }) .await .unwrap(); @@ -605,7 +629,7 @@ mod tests { drop(submitter); // Wait for the operation `timeout` to fire. - let (result, _) = rx.await.unwrap(); + let (result, _, _) = rx.await.unwrap(); assert_eq!(result, -libc::ETIME); handle.await.unwrap(); } @@ -630,6 +654,8 @@ mod tests { work: timeout, sender: tx, buffer: None, + buf: None, + keepalive: None, }) .await .unwrap(); @@ -642,8 +668,7 @@ mod tests { // The event loop should shut down before the `timeout` fires, // dropping `tx` and causing `rx` to return Canceled. - let err = rx.await.unwrap_err(); - assert!(matches!(err, Canceled { .. })); + assert!(matches!(rx.await, Err(Canceled { .. }))); handle.await.unwrap(); } @@ -673,6 +698,8 @@ mod tests { work: nop, sender: tx, buffer: None, + buf: None, + keepalive: None, }) .await .unwrap(); @@ -681,7 +708,7 @@ mod tests { // All NOPs should complete successfully for rx in rxs { - let (res, _) = rx.await.unwrap(); + let (res, _, _) = rx.await.unwrap(); assert_eq!(res, 0, "NOP op failed: {res}"); } @@ -711,12 +738,14 @@ mod tests { work: opcode::Nop::new().build(), sender: tx, buffer: None, + buf: None, + keepalive: None, }) .await .unwrap(); // Verify it completes successfully - let (result, _) = rx.await.unwrap(); + let (result, _, _) = rx.await.unwrap(); assert_eq!(result, 0); // Clean shutdown diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index e2f4956919..f00745ceed 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -22,6 +22,7 @@ html_favicon_url = "https://commonware.xyz/favicon.ico" )] +use bytes::Buf; use commonware_macros::select; use commonware_utils::StableBuf; use prometheus_client::registry::Metric; @@ -436,10 +437,17 @@ pub trait Listener: Sync + Send + 'static { /// Interface that any runtime must implement to send /// messages over a network connection. pub trait Sink: Sync + Send + 'static { - /// Send a message to the sink. + /// Send messages to the sink using vectored I/O. + /// + /// All buffers are written in order as if they were concatenated + /// into a single contiguous message. The implementation guarantees + /// that either all bytes are written or an error is returned. + /// + /// Implementations restrict the maximum number of buffers that can be + /// written at once to `16`. fn send( &mut self, - msg: impl Into + Send, + bufs: impl Buf + Send + 'static, ) -> impl Future> + Send; } @@ -545,7 +553,6 @@ pub trait Blob: Clone + Send + Sync + 'static { mod tests { use super::*; use crate::telemetry::traces::collector::TraceStorage; - use bytes::Bytes; use commonware_macros::{select, test_collect_traces}; use futures::{ channel::{mpsc, oneshot}, @@ -2600,7 +2607,7 @@ mod tests { let request = format!( "GET /metrics HTTP/1.1\r\nHost: {address}\r\nConnection: close\r\n\r\n" ); - sink.send(Bytes::from(request).to_vec()).await.unwrap(); + sink.send(bytes::Bytes::from(request)).await.unwrap(); // Read and verify the HTTP status line let status_line = read_line(&mut stream).await.unwrap(); diff --git a/runtime/src/mocks.rs b/runtime/src/mocks.rs index 05fa0a27ca..6753de43b4 100644 --- a/runtime/src/mocks.rs +++ b/runtime/src/mocks.rs @@ -1,7 +1,7 @@ //! A mock implementation of a channel that implements the Sink and Stream traits. use crate::{Error, Sink as SinkTrait, Stream as StreamTrait}; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use commonware_utils::StableBuf; use futures::channel::oneshot; use std::{ @@ -50,8 +50,7 @@ pub struct Sink { } impl SinkTrait for Sink { - async fn send(&mut self, msg: impl Into + Send) -> Result<(), Error> { - let msg = msg.into(); + async fn send(&mut self, mut buf: impl Buf + Send + 'static) -> Result<(), Error> { let (os_send, data) = { let mut channel = self.channel.lock().unwrap(); @@ -60,8 +59,15 @@ impl SinkTrait for Sink { return Err(Error::Closed); } + // Reserve memory for the upcoming write. + let total_size = buf.remaining(); + let current_len = channel.buffer.len(); + channel.buffer.resize(total_size + current_len, 0); + // Add the data to the buffer. - channel.buffer.extend(msg.as_ref()); + buf.copy_to_slice( + &mut channel.buffer.make_contiguous()[current_len..current_len + total_size], + ); // If there is a waiter and the buffer is large enough, // return the waiter (while clearing the waiter field). @@ -153,11 +159,11 @@ mod tests { #[test] fn test_send_recv() { let (mut sink, mut stream) = Channel::init(); - let data = b"hello world".to_vec(); + let data = b"hello world"; let executor = deterministic::Runner::default(); executor.start(|_| async move { - sink.send(data.clone()).await.unwrap(); + sink.send(Bytes::from_static(data)).await.unwrap(); let buf = stream.recv(vec![0; data.len()]).await.unwrap(); assert_eq!(buf.as_ref(), data); }); @@ -166,13 +172,11 @@ mod tests { #[test] fn test_send_recv_partial_multiple() { let (mut sink, mut stream) = Channel::init(); - let data = b"hello".to_vec(); - let data2 = b" world".to_vec(); let executor = deterministic::Runner::default(); executor.start(|_| async move { - sink.send(data).await.unwrap(); - sink.send(data2).await.unwrap(); + sink.send(Bytes::from_static(b"hello")).await.unwrap(); + sink.send(Bytes::from_static(b" world")).await.unwrap(); let buf = stream.recv(vec![0; 5]).await.unwrap(); assert_eq!(buf.as_ref(), b"hello"); let buf = stream.recv(buf).await.unwrap(); @@ -191,7 +195,7 @@ mod tests { executor.start(|_| async move { let (buf, _) = futures::try_join!(stream.recv(vec![0; data.len()]), async { sleep(Duration::from_millis(50)); - sink.send(data.to_vec()).await + sink.send(data.as_ref()).await }) .unwrap(); assert_eq!(buf.as_ref(), data); @@ -237,7 +241,7 @@ mod tests { let executor = deterministic::Runner::default(); executor.start(|context| async move { // Send some bytes - assert!(sink.send(b"7 bytes".to_vec()).await.is_ok()); + assert!(sink.send(b"7 bytes".as_slice()).await.is_ok()); // Spawn a task to initiate recv's where the first one will succeed and then will drop. let handle = context.clone().spawn(|_| async move { @@ -253,7 +257,7 @@ mod tests { assert!(matches!(handle.await, Err(Error::Closed))); // Try to send a message. The stream is dropped, so this should fail. - let result = sink.send(b"hello world".to_vec()).await; + let result = sink.send(b"hello world".as_slice()).await; assert!(matches!(result, Err(Error::Closed))); }); } @@ -265,7 +269,7 @@ mod tests { let executor = deterministic::Runner::default(); executor.start(|_| async move { - let result = sink.send(b"hello world".to_vec()).await; + let result = sink.send(b"hello world".as_slice()).await; assert!(matches!(result, Err(Error::Closed))); }); } diff --git a/runtime/src/network/audited.rs b/runtime/src/network/audited.rs index ae78eee4e8..690ab9d0e7 100644 --- a/runtime/src/network/audited.rs +++ b/runtime/src/network/audited.rs @@ -1,4 +1,5 @@ use crate::{deterministic::Auditor, Error, SinkOf, StreamOf}; +use bytes::Buf; use commonware_utils::StableBuf; use sha2::Digest; use std::{net::SocketAddr, sync::Arc}; @@ -11,14 +12,14 @@ pub struct Sink { } impl crate::Sink for Sink { - async fn send(&mut self, data: impl Into + Send) -> Result<(), Error> { - let data = data.into(); + async fn send(&mut self, mut buf: impl Buf + Send + 'static) -> Result<(), Error> { + let bytes = buf.copy_to_bytes(buf.remaining()); self.auditor.event(b"send_attempt", |hasher| { hasher.update(self.remote_addr.to_string().as_bytes()); - hasher.update(data.as_ref()); + hasher.update(&bytes); }); - self.inner.send(data).await.inspect_err(|e| { + self.inner.send(bytes).await.inspect_err(|e| { self.auditor.event(b"send_failure", |hasher| { hasher.update(self.remote_addr.to_string().as_bytes()); hasher.update(e.to_string().as_bytes()); @@ -261,7 +262,7 @@ mod tests { assert_eq!(buf.as_ref(), CLIENT_MSG.as_bytes()); // Send response - sink.send(Vec::from(SERVER_MSG)).await.unwrap(); + sink.send(SERVER_MSG.as_bytes()).await.unwrap(); }); server_handles.push(handle); } @@ -275,7 +276,7 @@ mod tests { let (mut sink, mut stream) = network.dial(listener_addr).await.unwrap(); // Send data to server - sink.send(Vec::from(CLIENT_MSG)).await.unwrap(); + sink.send(CLIENT_MSG.as_bytes()).await.unwrap(); // Receive response let buf = stream.recv(vec![0; SERVER_MSG.len()]).await.unwrap(); diff --git a/runtime/src/network/deterministic.rs b/runtime/src/network/deterministic.rs index 60634c297a..d42c9265c4 100644 --- a/runtime/src/network/deterministic.rs +++ b/runtime/src/network/deterministic.rs @@ -1,4 +1,6 @@ -use crate::{mocks, Error, StableBuf}; +use crate::{mocks, Error}; +use bytes::Buf; +use commonware_utils::StableBuf; use futures::{channel::mpsc, SinkExt as _, StreamExt as _}; use std::{ collections::HashMap, @@ -16,8 +18,8 @@ pub struct Sink { } impl crate::Sink for Sink { - async fn send(&mut self, msg: impl Into + Send) -> Result<(), Error> { - self.sender.send(msg).await.map_err(|_| Error::SendFailed) + async fn send(&mut self, buf: impl Buf + Send + 'static) -> Result<(), Error> { + self.sender.send(buf).await.map_err(|_| Error::SendFailed) } } diff --git a/runtime/src/network/iouring.rs b/runtime/src/network/iouring.rs index 1b4099f6f6..4aaf3e96e7 100644 --- a/runtime/src/network/iouring.rs +++ b/runtime/src/network/iouring.rs @@ -23,6 +23,7 @@ //! This implementation is only available on Linux systems that support io_uring. use crate::iouring::{self, should_retry}; +use bytes::Buf; use commonware_utils::StableBuf; use futures::{ channel::{mpsc, oneshot}, @@ -32,6 +33,7 @@ use futures::{ use io_uring::types::Fd; use prometheus_client::registry::Registry; use std::{ + io::IoSlice, net::SocketAddr, os::fd::{AsRawFd, OwnedFd}, sync::Arc, @@ -39,6 +41,8 @@ use std::{ use tokio::net::{TcpListener, TcpStream}; use tracing::warn; +const IOV_MAX: usize = 16; + #[derive(Clone, Debug, Default)] pub struct Config { /// If Some, explicitly sets TCP_NODELAY on the socket. @@ -210,6 +214,15 @@ impl crate::Listener for Listener { } } +/// Wrapper to make iovec array Send. +/// +/// SAFETY: The iovecs point into a buffer that is kept alive in `Op.buf`, +/// so the pointers remain valid for the duration of the operation. The iovecs +/// are only used by the kernel during the io_uring operation and are not +/// accessed from multiple threads. +struct SendIovecs([libc::iovec; IOV_MAX]); +unsafe impl Send for SendIovecs {} + /// Implementation of [crate::Sink] for an io-uring [Network]. pub struct Sink { fd: Arc, @@ -224,47 +237,50 @@ impl Sink { } impl crate::Sink for Sink { - async fn send(&mut self, msg: impl Into + Send) -> Result<(), crate::Error> { - let mut msg = msg.into(); - let mut bytes_sent = 0; - let msg_len = msg.len(); - - while bytes_sent < msg_len { - // Figure out how much is left to send and where to send from. - // - // SAFETY: `msg` is a `StableBuf` guaranteeing the memory won't move. - // `bytes_sent` is always < `msg_len` due to the loop condition, so - // `add(bytes_sent)` stays within bounds and `msg_len - bytes_sent` - // correctly represents the remaining valid bytes. - let remaining = unsafe { - std::slice::from_raw_parts( - msg.as_mut_ptr().add(bytes_sent) as *const u8, - msg_len - bytes_sent, - ) + async fn send(&mut self, buf: impl Buf + Send + 'static) -> Result<(), crate::Error> { + // Box the buf upfront so we can work with a consistent type throughout. + let mut buf: Box = Box::new(buf); + + while buf.has_remaining() { + // Collect chunks from the buffer into IoSlices. + let mut io_slices: [IoSlice<'_>; IOV_MAX] = std::array::from_fn(|_| IoSlice::new(&[])); + let n = buf.chunks_vectored(&mut io_slices); + + // SAFETY: `libc::iovec` and `IoSlice` have the same memory layout on Unix, + // as guaranteed by `IoSlice`'s documentation. + let iovecs: [libc::iovec; IOV_MAX] = unsafe { + std::mem::transmute::<[IoSlice<'_>; IOV_MAX], [libc::iovec; IOV_MAX]>(io_slices) }; - // Create the io_uring send operation - let op = io_uring::opcode::Send::new( - self.as_raw_fd(), - remaining.as_ptr(), - remaining.len() as u32, - ) - .build(); + // Box the iovecs FIRST to get a stable heap address, then take the pointer. + // This ensures the pointer remains valid after we move the box into keepalive. + let iovecs_box: Box = Box::new(SendIovecs(iovecs)); + let iovecs_ptr = iovecs_box.0.as_ptr(); - // Submit the operation to the io_uring event loop + // Create the io_uring writev operation. + let op = io_uring::opcode::Writev::new(self.as_raw_fd(), iovecs_ptr, n as u32).build(); + + // Submit the operation to the io_uring event loop. + // The buf is stored in `buf` and will be returned to us. + // The iovecs are stored in `keepalive` to keep them alive during the operation. let (tx, rx) = oneshot::channel(); self.submitter .send(crate::iouring::Op { work: op, sender: tx, - buffer: Some(msg), + buffer: None, + buf: Some(buf), + keepalive: Some(iovecs_box), }) .await .map_err(|_| crate::Error::SendFailed)?; // Wait for the operation to complete - let (result, got_msg) = rx.await.map_err(|_| crate::Error::SendFailed)?; - msg = got_msg.unwrap(); + let (result, _, got_buf) = rx.await.map_err(|_| crate::Error::SendFailed)?; + + // Get the buf back from the result + buf = got_buf.expect("buf should be returned"); + if should_retry(result) { continue; } @@ -274,8 +290,8 @@ impl crate::Sink for Sink { return Err(crate::Error::SendFailed); } - // Mark bytes as sent. - bytes_sent += result as usize; + // Advance the buffer by the number of bytes written. + buf.advance(result as usize); } Ok(()) } @@ -328,12 +344,14 @@ impl crate::Stream for Stream { work: op, sender: tx, buffer: Some(buf), + buf: None, + keepalive: None, }) .await .map_err(|_| crate::Error::RecvFailed)?; // Wait for the operation to complete - let (result, got_buf) = rx.await.map_err(|_| crate::Error::RecvFailed)?; + let (result, got_buf, _) = rx.await.map_err(|_| crate::Error::RecvFailed)?; buf = got_buf.unwrap(); if should_retry(result) { continue; diff --git a/runtime/src/network/metered.rs b/runtime/src/network/metered.rs index 4e17117f40..90ce21e9e1 100644 --- a/runtime/src/network/metered.rs +++ b/runtime/src/network/metered.rs @@ -1,4 +1,5 @@ use crate::{SinkOf, StreamOf}; +use bytes::Buf; use commonware_utils::StableBuf; use prometheus_client::{metrics::counter::Counter, registry::Registry}; use std::{net::SocketAddr, sync::Arc}; @@ -55,10 +56,9 @@ pub struct Sink { } impl crate::Sink for Sink { - async fn send(&mut self, data: impl Into + Send) -> Result<(), crate::Error> { - let data = data.into(); - let len = data.len(); - self.inner.send(data).await?; + async fn send(&mut self, buf: impl Buf + Send + 'static) -> Result<(), crate::Error> { + let len = buf.remaining(); + self.inner.send(buf).await?; self.metrics.outbound_bandwidth.inc_by(len as u64); Ok(()) } @@ -218,7 +218,7 @@ mod tests { let server = tokio::spawn(async move { let (_, mut sink, mut stream) = listener.accept().await.unwrap(); let buf = stream.recv(vec![0; MSG_SIZE as usize]).await.unwrap(); - sink.send(buf).await.unwrap(); + sink.send(bytes::Bytes::from(buf)).await.unwrap(); }); // Send and receive data as client @@ -226,7 +226,10 @@ mod tests { // Send fixed-size data and receive response let msg = vec![42u8; MSG_SIZE as usize]; - client_sink.send(msg.clone()).await.unwrap(); + client_sink + .send(bytes::Bytes::copy_from_slice(&msg)) + .await + .unwrap(); let response = client_stream .recv(vec![0; MSG_SIZE as usize]) diff --git a/runtime/src/network/mod.rs b/runtime/src/network/mod.rs index b20a1a8d50..f3b14efc26 100644 --- a/runtime/src/network/mod.rs +++ b/runtime/src/network/mod.rs @@ -11,6 +11,7 @@ pub(crate) mod iouring; #[cfg(test)] mod tests { use crate::{Listener, Sink, Stream}; + use bytes::Bytes; use futures::join; use std::net::SocketAddr; @@ -50,7 +51,7 @@ mod tests { .await .expect("Failed to receive"); assert_eq!(read.as_ref(), CLIENT_SEND_DATA.as_bytes()); - sink.send(Vec::from(SERVER_SEND_DATA)) + sink.send(Bytes::from_static(SERVER_SEND_DATA.as_bytes())) .await .expect("Failed to send"); }); @@ -63,7 +64,7 @@ mod tests { .await .expect("Failed to dial server"); - sink.send(Vec::from(CLIENT_SEND_DATA)) + sink.send(Bytes::from_static(CLIENT_SEND_DATA.as_bytes())) .await .expect("Failed to send data"); @@ -103,7 +104,7 @@ mod tests { .expect("Failed to receive"); assert_eq!(read.as_ref(), CLIENT_SEND_DATA.as_bytes()); - sink.send(Vec::from(SERVER_SEND_DATA)) + sink.send(Bytes::from_static(SERVER_SEND_DATA.as_bytes())) .await .expect("Failed to send"); } @@ -119,7 +120,7 @@ mod tests { .expect("Failed to dial server"); // Send a message to the server - sink.send(Vec::from(CLIENT_SEND_DATA)) + sink.send(Bytes::from_static(CLIENT_SEND_DATA.as_bytes())) .await .expect("Failed to send data"); @@ -160,7 +161,9 @@ mod tests { .recv(vec![0; CHUNK_SIZE]) .await .expect("Failed to receive chunk"); - sink.send(read).await.expect("Failed to send chunk"); + sink.send(Bytes::from(read)) + .await + .expect("Failed to send chunk"); } }); @@ -173,18 +176,18 @@ mod tests { .expect("Failed to dial server"); // Create a pattern of data - let pattern = (0..CHUNK_SIZE).map(|i| (i % 256) as u8).collect::>(); + let pattern: Vec = (0..CHUNK_SIZE).map(|i| (i % 256) as u8).collect(); // Send and verify data in chunks for _ in 0..NUM_CHUNKS { - sink.send(pattern.clone()) + sink.send(Bytes::copy_from_slice(&pattern)) .await .expect("Failed to send chunk"); let read = stream .recv(vec![0; CHUNK_SIZE]) .await .expect("Failed to receive chunk"); - assert_eq!(read.as_ref(), pattern); + assert_eq!(read.as_ref(), &pattern[..]); } }); @@ -241,7 +244,7 @@ mod tests { tokio::spawn(async move { for _ in 0..NUM_MESSAGES { let data = stream.recv(vec![0; MESSAGE_SIZE]).await.unwrap(); - sink.send(data).await.unwrap(); + sink.send(Bytes::from(data)).await.unwrap(); } }); } @@ -255,9 +258,9 @@ mod tests { let (mut sink, mut stream) = network.dial(addr).await.unwrap(); let payload = vec![42u8; MESSAGE_SIZE]; for _ in 0..NUM_MESSAGES { - sink.send(payload.clone()).await.unwrap(); + sink.send(Bytes::copy_from_slice(&payload)).await.unwrap(); let echo = stream.recv(vec![0; MESSAGE_SIZE]).await.unwrap(); - assert_eq!(echo.as_ref(), payload); + assert_eq!(echo.as_ref(), &payload[..]); } })); } diff --git a/runtime/src/network/tokio.rs b/runtime/src/network/tokio.rs index b42ed2f954..5200e66332 100644 --- a/runtime/src/network/tokio.rs +++ b/runtime/src/network/tokio.rs @@ -1,4 +1,5 @@ use crate::Error; +use bytes::Buf; use commonware_utils::StableBuf; use std::{net::SocketAddr, time::Duration}; use tokio::{ @@ -18,9 +19,9 @@ pub struct Sink { } impl crate::Sink for Sink { - async fn send(&mut self, msg: impl Into + Send) -> Result<(), Error> { + async fn send(&mut self, mut buf: impl Buf + Send + 'static) -> Result<(), Error> { // Time out if we take too long to write - timeout(self.write_timeout, self.sink.write_all(msg.into().as_ref())) + timeout(self.write_timeout, self.sink.write_all_buf(&mut buf)) .await .map_err(|_| Error::Timeout)? .map_err(|_| Error::SendFailed)?; diff --git a/runtime/src/storage/iouring.rs b/runtime/src/storage/iouring.rs index f32e81ed75..808414fa54 100644 --- a/runtime/src/storage/iouring.rs +++ b/runtime/src/storage/iouring.rs @@ -270,12 +270,14 @@ impl crate::Blob for Blob { work: op, sender, buffer: Some(buf), + buf: None, + keepalive: None, }) .await .map_err(|_| Error::ReadFailed)?; // Wait for the result - let (result, got_buf) = receiver.await.map_err(|_| Error::ReadFailed)?; + let (result, got_buf, _) = receiver.await.map_err(|_| Error::ReadFailed)?; buf = got_buf.unwrap(); if should_retry(result) { continue; @@ -326,12 +328,14 @@ impl crate::Blob for Blob { work: op, sender, buffer: Some(buf), + buf: None, + keepalive: None, }) .await .map_err(|_| Error::WriteFailed)?; // Wait for the result - let (return_value, got_buf) = receiver.await.map_err(|_| Error::WriteFailed)?; + let (return_value, got_buf, _) = receiver.await.map_err(|_| Error::WriteFailed)?; buf = got_buf.unwrap(); if should_retry(return_value) { continue; @@ -366,6 +370,8 @@ impl crate::Blob for Blob { work: op, sender, buffer: None, + buf: None, + keepalive: None, }) .await .map_err(|_| { @@ -377,7 +383,7 @@ impl crate::Blob for Blob { })?; // Wait for the result - let (return_value, _) = receiver.await.map_err(|_| { + let (return_value, _, _) = receiver.await.map_err(|_| { Error::BlobSyncFailed( self.partition.clone(), hex(&self.name), diff --git a/stream/fuzz/Cargo.toml b/stream/fuzz/Cargo.toml index 5b2e374486..1421744dea 100644 --- a/stream/fuzz/Cargo.toml +++ b/stream/fuzz/Cargo.toml @@ -10,6 +10,7 @@ cargo-fuzz = true [dependencies] arbitrary = { workspace = true, features = ["derive"] } +bytes.workspace = true chacha20poly1305 = { workspace = true, default-features = false, features = ["std", "getrandom"] } commonware-codec.workspace = true commonware-cryptography.workspace = true diff --git a/stream/fuzz/fuzz_targets/e2e.rs b/stream/fuzz/fuzz_targets/e2e.rs index 8732a83837..a5efed49cb 100644 --- a/stream/fuzz/fuzz_targets/e2e.rs +++ b/stream/fuzz/fuzz_targets/e2e.rs @@ -1,5 +1,6 @@ #![no_main] +use bytes::Bytes; use commonware_cryptography::{ed25519::PrivateKey, Signer}; use commonware_runtime::{deterministic, mocks, Handle, Runner as _, Spawner}; use commonware_stream::{ @@ -136,7 +137,7 @@ fn fuzz(input: FuzzInput) { let mut corruption_i = 0; let announce = recv_frame(&mut adversary_d_stream, MAX_MESSAGE_SIZE).await?; - send_frame(&mut adversary_d_sink, &announce, MAX_MESSAGE_SIZE).await?; + send_frame(&mut adversary_d_sink, announce, MAX_MESSAGE_SIZE).await?; let mut m1 = recv_frame(&mut adversary_d_stream, MAX_MESSAGE_SIZE) .await? @@ -147,7 +148,7 @@ fn fuzz(input: FuzzInput) { corruption_i += 1; } } - send_frame(&mut adversary_d_sink, &m1, MAX_MESSAGE_SIZE).await?; + send_frame(&mut adversary_d_sink, Bytes::from(m1), MAX_MESSAGE_SIZE).await?; let mut m2 = recv_frame(&mut adversary_l_stream, MAX_MESSAGE_SIZE) .await? @@ -158,7 +159,7 @@ fn fuzz(input: FuzzInput) { corruption_i += 1; } } - send_frame(&mut adversary_l_sink, &m2, MAX_MESSAGE_SIZE).await?; + send_frame(&mut adversary_l_sink, Bytes::from(m2), MAX_MESSAGE_SIZE).await?; let mut m3 = recv_frame(&mut adversary_d_stream, MAX_MESSAGE_SIZE) .await? @@ -171,7 +172,7 @@ fn fuzz(input: FuzzInput) { } let sent_corrupted_data = setup_corruption.iter().take(corruption_i).any(|x| *x != 0); - send_frame(&mut adversary_d_sink, &m3, MAX_MESSAGE_SIZE).await?; + send_frame(&mut adversary_d_sink, Bytes::from(m3), MAX_MESSAGE_SIZE).await?; Ok(( sent_corrupted_data, adversary_d_stream, @@ -238,7 +239,7 @@ fn fuzz(input: FuzzInput) { }; sender.send(&data).await.unwrap(); let frame = recv_frame(a_in, MAX_MESSAGE_SIZE).await.unwrap(); - send_frame(a_out, &frame, MAX_MESSAGE_SIZE).await.unwrap(); + send_frame(a_out, frame, MAX_MESSAGE_SIZE).await.unwrap(); let data2 = receiver.recv().await.unwrap(); assert_eq!(data, data2, "expected data to match"); } @@ -264,7 +265,9 @@ fn fuzz(input: FuzzInput) { }; sender.send(&[]).await.unwrap(); let _ = recv_frame(a_in, MAX_MESSAGE_SIZE).await.unwrap(); - send_frame(a_out, &data, MAX_MESSAGE_SIZE).await.unwrap(); + send_frame(a_out, Bytes::from(data), MAX_MESSAGE_SIZE) + .await + .unwrap(); let res = receiver.recv().await; assert!(res.is_err()); } diff --git a/stream/src/lib.rs b/stream/src/lib.rs index 6af6ed2e3c..ce292a3ad6 100644 --- a/stream/src/lib.rs +++ b/stream/src/lib.rs @@ -174,7 +174,7 @@ pub async fn dial( let inner_routine = async move { send_frame( &mut sink, - config.signing_key.public_key().encode().as_ref(), + Bytes::from(config.signing_key.public_key().encode()), config.max_message_size, ) .await?; @@ -190,13 +190,13 @@ pub async fn dial( peer, ), ); - send_frame(&mut sink, &syn.encode(), config.max_message_size).await?; + send_frame(&mut sink, syn.encode(), config.max_message_size).await?; let syn_ack_bytes = recv_frame(&mut stream, config.max_message_size).await?; let syn_ack = SynAck::::decode(syn_ack_bytes)?; let (ack, send, recv) = dial_end(state, syn_ack)?; - send_frame(&mut sink, &ack.encode(), config.max_message_size).await?; + send_frame(&mut sink, ack.encode(), config.max_message_size).await?; Ok(( Sender { @@ -257,7 +257,7 @@ pub async fn listen< ), msg1, )?; - send_frame(&mut sink, &syn_ack.encode(), config.max_message_size).await?; + send_frame(&mut sink, syn_ack.encode(), config.max_message_size).await?; let ack_bytes = recv_frame(&mut stream, config.max_message_size).await?; let ack = Ack::decode(ack_bytes)?; @@ -298,7 +298,7 @@ impl Sender { let c = self.cipher.send(msg)?; send_frame( &mut self.sink, - &c, + Bytes::from(c), self.max_message_size + CIPHERTEXT_OVERHEAD, ) .await?; diff --git a/stream/src/utils/codec.rs b/stream/src/utils/codec.rs index b268d2e87c..322a4ac198 100644 --- a/stream/src/utils/codec.rs +++ b/stream/src/utils/codec.rs @@ -1,26 +1,25 @@ use crate::Error; -use bytes::{BufMut as _, Bytes, BytesMut}; +use bytes::{Buf, Bytes}; use commonware_runtime::{Sink, Stream}; /// Sends data to the sink with a 4-byte length prefix. /// Returns an error if the message is too large or the stream is closed. pub async fn send_frame( sink: &mut S, - buf: &[u8], + buf: impl Buf + Send + 'static, max_message_size: usize, ) -> Result<(), Error> { // Validate frame size - let n = buf.len(); + let n = buf.remaining(); if n > max_message_size { return Err(Error::SendTooLarge(n)); } - // Prefix `buf` with its length and send it - let mut prefixed_buf = BytesMut::with_capacity(4 + buf.len()); let len: u32 = n.try_into().map_err(|_| Error::SendTooLarge(n))?; - prefixed_buf.put_u32(len); - prefixed_buf.extend_from_slice(buf); - sink.send(prefixed_buf).await.map_err(Error::SendFailed) + let len_bytes = len.to_be_bytes(); + sink.send(Bytes::from_owner(len_bytes).chain(buf)) + .await + .map_err(Error::SendFailed) } /// Receives data from the stream with a 4-byte length prefix. @@ -46,6 +45,7 @@ pub async fn recv_frame( #[cfg(test)] mod tests { use super::*; + use bytes::{BufMut, BytesMut}; use commonware_runtime::{deterministic, mocks, Runner}; use rand::Rng; @@ -60,7 +60,7 @@ mod tests { let mut buf = [0u8; MAX_MESSAGE_SIZE]; context.fill(&mut buf); - let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await; + let result = send_frame(&mut sink, Bytes::from_owner(buf), MAX_MESSAGE_SIZE).await; assert!(result.is_ok()); let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap(); @@ -81,9 +81,9 @@ mod tests { context.fill(&mut buf2); // Send two messages of different sizes - let result = send_frame(&mut sink, &buf1, MAX_MESSAGE_SIZE).await; + let result = send_frame(&mut sink, Bytes::from_owner(buf1), MAX_MESSAGE_SIZE).await; assert!(result.is_ok()); - let result = send_frame(&mut sink, &buf2, MAX_MESSAGE_SIZE).await; + let result = send_frame(&mut sink, Bytes::from_owner(buf2), MAX_MESSAGE_SIZE).await; assert!(result.is_ok()); // Read both messages in order @@ -105,7 +105,7 @@ mod tests { let mut buf = [0u8; MAX_MESSAGE_SIZE]; context.fill(&mut buf); - let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE).await; + let result = send_frame(&mut sink, Bytes::from_owner(buf), MAX_MESSAGE_SIZE).await; assert!(result.is_ok()); // Do the reading manually without using recv_frame @@ -126,7 +126,7 @@ mod tests { let mut buf = [0u8; MAX_MESSAGE_SIZE]; context.fill(&mut buf); - let result = send_frame(&mut sink, &buf, MAX_MESSAGE_SIZE - 1).await; + let result = send_frame(&mut sink, Bytes::from_owner(buf), MAX_MESSAGE_SIZE - 1).await; assert!(matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE)); }); } @@ -141,9 +141,8 @@ mod tests { let mut msg = [0u8; MAX_MESSAGE_SIZE]; context.fill(&mut msg); - let mut buf = BytesMut::with_capacity(4 + msg.len()); - buf.put_u32(MAX_MESSAGE_SIZE as u32); - buf.extend_from_slice(&msg); + let buf = Bytes::from_owner((MAX_MESSAGE_SIZE as u32).to_be_bytes()) + .chain(Bytes::from_owner(msg)); sink.send(buf).await.unwrap(); let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap(); @@ -159,8 +158,7 @@ mod tests { let executor = deterministic::Runner::default(); executor.start(|_| async move { // Manually insert a frame that gives MAX_MESSAGE_SIZE as the size - let mut buf = BytesMut::with_capacity(4); - buf.put_u32(MAX_MESSAGE_SIZE as u32); + let buf = Bytes::from_owner((MAX_MESSAGE_SIZE as u32).to_be_bytes()); sink.send(buf).await.unwrap(); let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;