From 8b78440be17196d32abf40962a908c602ddee56c Mon Sep 17 00:00:00 2001 From: Sebastian Urban Date: Tue, 2 Apr 2024 23:54:50 +0200 Subject: [PATCH] rch::bin: support forwarding Support automatic forwarding of binary channels. --- remoc/src/rch/bin/mod.rs | 25 ++++-- remoc/src/rch/bin/receiver.rs | 86 +++++++++++++----- remoc/src/rch/bin/sender.rs | 86 +++++++++++++----- remoc/src/rch/interlock.rs | 7 ++ remoc/src/rch/mod.rs | 4 +- remoc/tests/rch/bin.rs | 159 ++++++++++++++++++++++++++++++++++ 6 files changed, 315 insertions(+), 52 deletions(-) diff --git a/remoc/src/rch/bin/mod.rs b/remoc/src/rch/bin/mod.rs index 7a47a0c..27842f1 100644 --- a/remoc/src/rch/bin/mod.rs +++ b/remoc/src/rch/bin/mod.rs @@ -1,8 +1,9 @@ //! A channel that exchanges binary data with a remote endpoint. //! -//! Allow low-overhead exchange of binary data. -//! One end of the channel must be local while the other end must be remote. -//! Forwarding is not supported. +//! Allows low-overhead exchange of binary data. +//! +//! At least one end of the channel must be remote. +//! Forwarding, i.e. both channel ends on remote endpoints, is supported. //! //! If the sole use is to transfer a large binary object into one direction, //! consider using a [lazy blob](crate::robj::lazy_blob) instead. @@ -25,9 +26,21 @@ use super::interlock::{Interlock, Location}; pub fn channel() -> (Sender, Receiver) { let (sender_tx, sender_rx) = tokio::sync::mpsc::unbounded_channel(); let (receiver_tx, receiver_rx) = tokio::sync::mpsc::unbounded_channel(); - let interlock = Arc::new(Mutex::new(Interlock { sender: Location::Local, receiver: Location::Local })); + let interlock = Arc::new(Mutex::new(Interlock::new())); - let sender = Sender { sender: None, sender_rx, receiver_tx: Some(receiver_tx), interlock: interlock.clone() }; - let receiver = Receiver { receiver: None, sender_tx: Some(sender_tx), receiver_rx, interlock }; + let sender = Sender { + sender: None, + sender_rx, + receiver_tx: Some(receiver_tx), + interlock: interlock.clone(), + successor_tx: std::sync::Mutex::new(None), + }; + let receiver = Receiver { + receiver: None, + sender_tx: Some(sender_tx), + receiver_rx, + interlock, + successor_tx: std::sync::Mutex::new(None), + }; (sender, receiver) } diff --git a/remoc/src/rch/bin/receiver.rs b/remoc/src/rch/bin/receiver.rs index 9f68993..9e66358 100644 --- a/remoc/src/rch/bin/receiver.rs +++ b/remoc/src/rch/bin/receiver.rs @@ -1,7 +1,7 @@ use futures::FutureExt; -use serde::{ser, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use std::{ - fmt, + fmt, mem, sync::{Arc, Mutex}, }; @@ -20,6 +20,7 @@ pub struct Receiver { pub(super) sender_tx: Option>>, pub(super) receiver_rx: tokio::sync::mpsc::UnboundedReceiver>, pub(super) interlock: Arc>, + pub(super) successor_tx: std::sync::Mutex>>, } impl fmt::Debug for Receiver { @@ -53,7 +54,17 @@ impl Receiver { /// to the remote endpoint. pub async fn into_inner(mut self) -> Result { self.connect().await; - self.receiver.unwrap() + self.receiver.take().unwrap() + } + + /// Forward data. + async fn forward(successor_rx: tokio::sync::oneshot::Receiver, tx: super::Sender) { + let Ok(rx) = successor_rx.await else { return }; + let Ok(mut rx) = rx.into_inner().await else { return }; + let Ok(mut tx) = tx.into_inner().await else { return }; + if let Err(err) = chmux::forward(&mut rx, &mut tx).await { + tracing::debug!("forwarding binary channel failed: {err}"); + } } } @@ -63,34 +74,48 @@ impl Serialize for Receiver { where S: serde::Serializer, { - let sender_tx = - self.sender_tx.clone().ok_or_else(|| ser::Error::custom("cannot forward received receiver"))?; - + let sender_tx = self.sender_tx.clone(); let interlock_confirm = { let mut interlock = self.interlock.lock().unwrap(); - if !interlock.sender.check_local() { - return Err(ser::Error::custom("cannot send receiver because sender has been sent")); + if interlock.sender.check_local() { + Some(interlock.sender.start_send()) + } else { + None } - interlock.sender.start_send() }; - let port = PortSerializer::connect(|connect| { - async move { - let _ = interlock_confirm.send(()); + match (sender_tx, interlock_confirm) { + // Local-remote connection. + (Some(sender_tx), Some(interlock_confirm)) => { + let port = PortSerializer::connect(|connect| { + async move { + let _ = interlock_confirm.send(()); - match connect.await { - Ok((raw_tx, _)) => { - let _ = sender_tx.send(Ok(raw_tx)); + match connect.await { + Ok((raw_tx, _)) => { + let _ = sender_tx.send(Ok(raw_tx)); + } + Err(err) => { + let _ = sender_tx.send(Err(ConnectError::Connect(err))); + } + } } - Err(err) => { - let _ = sender_tx.send(Err(ConnectError::Connect(err))); - } - } + .boxed() + })?; + + TransportedReceiver { port }.serialize(serializer) } - .boxed() - })?; - TransportedReceiver { port }.serialize(serializer) + // Forwarding. + _ => { + let (successor_tx, successor_rx) = tokio::sync::oneshot::channel(); + *self.successor_tx.lock().unwrap() = Some(successor_tx); + let (tx, rx) = super::channel(); + PortSerializer::spawn(Self::forward(successor_rx, tx))?; + + rx.serialize(serializer) + } + } } } @@ -122,6 +147,23 @@ impl<'de> Deserialize<'de> for Receiver { sender_tx: None, receiver_rx, interlock: Arc::new(Mutex::new(Interlock { sender: Location::Remote, receiver: Location::Local })), + successor_tx: std::sync::Mutex::new(None), }) } } + +impl Drop for Receiver { + fn drop(&mut self) { + let successor_tx = self.successor_tx.lock().unwrap().take(); + if let Some(successor_tx) = successor_tx { + let dummy = Self { + receiver: None, + sender_tx: None, + receiver_rx: tokio::sync::mpsc::unbounded_channel().1, + interlock: Arc::new(Mutex::new(Interlock::new())), + successor_tx: std::sync::Mutex::new(None), + }; + let _ = successor_tx.send(mem::replace(self, dummy)); + } + } +} diff --git a/remoc/src/rch/bin/sender.rs b/remoc/src/rch/bin/sender.rs index 25d962e..c2ac891 100644 --- a/remoc/src/rch/bin/sender.rs +++ b/remoc/src/rch/bin/sender.rs @@ -1,7 +1,7 @@ use futures::FutureExt; -use serde::{ser, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use std::{ - fmt, + fmt, mem, sync::{Arc, Mutex}, }; @@ -20,6 +20,7 @@ pub struct Sender { pub(super) sender_rx: tokio::sync::mpsc::UnboundedReceiver>, pub(super) receiver_tx: Option>>, pub(super) interlock: Arc>, + pub(super) successor_tx: std::sync::Mutex>>, } impl fmt::Debug for Sender { @@ -53,7 +54,17 @@ impl Sender { /// to the remote endpoint. pub async fn into_inner(mut self) -> Result { self.connect().await; - self.sender.unwrap() + self.sender.take().unwrap() + } + + /// Forward data. + async fn forward(successor_rx: tokio::sync::oneshot::Receiver, rx: super::Receiver) { + let Ok(tx) = successor_rx.await else { return }; + let Ok(mut tx) = tx.into_inner().await else { return }; + let Ok(mut rx) = rx.into_inner().await else { return }; + if let Err(err) = chmux::forward(&mut rx, &mut tx).await { + tracing::debug!("forwarding binary channel failed: {err}"); + } } } @@ -63,34 +74,48 @@ impl Serialize for Sender { where S: serde::Serializer, { - let receiver_tx = - self.receiver_tx.clone().ok_or_else(|| ser::Error::custom("cannot forward received sender"))?; - + let receiver_tx = self.receiver_tx.clone(); let interlock_confirm = { let mut interlock = self.interlock.lock().unwrap(); - if !interlock.receiver.check_local() { - return Err(ser::Error::custom("cannot send sender because receiver has been sent")); + if interlock.receiver.check_local() { + Some(interlock.receiver.start_send()) + } else { + None } - interlock.receiver.start_send() }; - let port = PortSerializer::connect(|connect| { - async move { - let _ = interlock_confirm.send(()); + match (receiver_tx, interlock_confirm) { + // Local-remote connection. + (Some(receiver_tx), Some(interlock_confirm)) => { + let port = PortSerializer::connect(|connect| { + async move { + let _ = interlock_confirm.send(()); - match connect.await { - Ok((_, raw_rx)) => { - let _ = receiver_tx.send(Ok(raw_rx)); + match connect.await { + Ok((_, raw_rx)) => { + let _ = receiver_tx.send(Ok(raw_rx)); + } + Err(err) => { + let _ = receiver_tx.send(Err(ConnectError::Connect(err))); + } + } } - Err(err) => { - let _ = receiver_tx.send(Err(ConnectError::Connect(err))); - } - } + .boxed() + })?; + + TransportedSender { port }.serialize(serializer) } - .boxed() - })?; - TransportedSender { port }.serialize(serializer) + // Forwarding. + _ => { + let (successor_tx, successor_rx) = tokio::sync::oneshot::channel(); + *self.successor_tx.lock().unwrap() = Some(successor_tx); + let (tx, rx) = super::channel(); + PortSerializer::spawn(Self::forward(successor_rx, rx))?; + + tx.serialize(serializer) + } + } } } @@ -122,6 +147,23 @@ impl<'de> Deserialize<'de> for Sender { sender_rx, receiver_tx: None, interlock: Arc::new(Mutex::new(Interlock { sender: Location::Local, receiver: Location::Remote })), + successor_tx: std::sync::Mutex::new(None), }) } } + +impl Drop for Sender { + fn drop(&mut self) { + let successor_tx = self.successor_tx.lock().unwrap().take(); + if let Some(successor_tx) = successor_tx { + let dummy = Self { + sender: None, + sender_rx: tokio::sync::mpsc::unbounded_channel().1, + receiver_tx: None, + interlock: Arc::new(Mutex::new(Interlock::new())), + successor_tx: std::sync::Mutex::new(None), + }; + let _ = successor_tx.send(mem::replace(self, dummy)); + } + } +} diff --git a/remoc/src/rch/interlock.rs b/remoc/src/rch/interlock.rs index 3d373d3..f2ed9e4 100644 --- a/remoc/src/rch/interlock.rs +++ b/remoc/src/rch/interlock.rs @@ -4,6 +4,13 @@ pub(crate) struct Interlock { pub receiver: Location, } +impl Interlock { + /// Creates a new interlock with local sender and receiver locations. + pub fn new() -> Self { + Self { sender: Location::Local, receiver: Location::Local } + } +} + /// Location of a sender or receiver. pub(crate) enum Location { Local, diff --git a/remoc/src/rch/mod.rs b/remoc/src/rch/mod.rs index fd46b40..665964e 100644 --- a/remoc/src/rch/mod.rs +++ b/remoc/src/rch/mod.rs @@ -36,8 +36,8 @@ //! A [binary channel](bin) can be used to exchange binary data over a channel. //! It skips serialization and deserialization and thus is more efficient for binary data, //! especially when using text codecs such as JSON. -//! However, it does not support forwarding and exactly one half of it must be on a remote -//! endpoint. +//! It does support forwarding. +//! However, at least one half of it must be on a remote endpoint. //! //! # Acknowledgements and connection latency //! diff --git a/remoc/tests/rch/bin.rs b/remoc/tests/rch/bin.rs index 51598eb..91554d3 100644 --- a/remoc/tests/rch/bin.rs +++ b/remoc/tests/rch/bin.rs @@ -72,3 +72,162 @@ async fn loopback() { reply_task.await.unwrap(); } + +#[tokio::test] +async fn forward() { + crate::init(); + let ((mut a_tx, _), (_, mut b_rx)) = loop_channel::<(bin::Sender, bin::Receiver)>().await; + let ((mut c_tx, _), (_, mut d_rx)) = loop_channel::<(bin::Sender, bin::Receiver)>().await; + + println!("Sending remote bin channel sender and receiver"); + let (tx1, rx1) = bin::channel(); + let (tx2, rx2) = bin::channel(); + a_tx.send((tx1, rx2)).await.unwrap(); + + println!("Receiving remote bin channel sender and receiver"); + let (tx1, rx2) = b_rx.recv().await.unwrap().unwrap(); + + println!("Forwarding remote bin channel sender and receiver"); + c_tx.send((tx1, rx2)).await.unwrap(); + + println!("Receiving forwarded remote bin channel sender and receiver"); + let (tx1, rx2) = d_rx.recv().await.unwrap().unwrap(); + + let reply_task = tokio::spawn(async move { + let mut rx1 = rx1.into_inner().await.unwrap(); + let mut tx2 = tx2.into_inner().await.unwrap(); + + loop { + match rx1.recv_any().await.unwrap() { + Some(Received::Data(data)) => { + println!("Echoing data of length {}", data.remaining()); + tx2.send(data.into()).await.unwrap(); + } + Some(Received::Chunks) => { + println!("Echoing big data stream"); + let mut i = 0; + let mut cs = tx2.send_chunks(); + while let Some(chunk) = rx1.recv_chunk().await.unwrap() { + println!("Echoing chunk {} of size {}", i, chunk.len()); + cs = cs.send(chunk).await.unwrap(); + i += 1; + } + cs.finish().await.unwrap(); + } + Some(_) => (), + None => break, + } + } + }); + + let mut tx1 = tx1.into_inner().await.unwrap(); + let mut rx2 = rx2.into_inner().await.unwrap(); + + rx2.set_max_data_size(1_000_000); + + let mut rng = rand::thread_rng(); + for i in 1..100 { + let size = if i % 2 == 0 { rng.gen_range(0..1_000_000) } else { 1024 }; + let mut data = vec![0u8; size]; + rng.fill_bytes(&mut data); + let data = Bytes::from(data); + + println!("Sending message of length {}", data.len()); + let (send, recv) = join!(tx1.send(data.clone()), rx2.recv()); + send.unwrap(); + let data_recv = recv.unwrap().unwrap(); + println!("Received reply of length {}", data_recv.remaining()); + let data_recv = Bytes::from(data_recv); + assert_eq!(data, data_recv, "mismatched echo reply"); + } + drop(tx1); + + if rx2.recv().await.unwrap().is_some() { + panic!("received data after close"); + } + + reply_task.await.unwrap(); +} + +#[tokio::test] +async fn double_forward() { + crate::init(); + let ((mut a_tx, _), (_, mut b_rx)) = loop_channel::<(bin::Sender, bin::Receiver)>().await; + let ((mut c_tx, _), (_, mut d_rx)) = loop_channel::<(bin::Sender, bin::Receiver)>().await; + let ((mut e_tx, _), (_, mut f_rx)) = loop_channel::<(bin::Sender, bin::Receiver)>().await; + + println!("Sending remote bin channel sender and receiver"); + let (tx1, rx1) = bin::channel(); + let (tx2, rx2) = bin::channel(); + a_tx.send((tx1, rx2)).await.unwrap(); + + println!("Receiving remote bin channel sender and receiver"); + let (tx1, rx2) = b_rx.recv().await.unwrap().unwrap(); + + println!("Forwarding remote bin channel sender and receiver"); + c_tx.send((tx1, rx2)).await.unwrap(); + + println!("Receiving forwarded remote bin channel sender and receiver"); + let (tx1, rx2) = d_rx.recv().await.unwrap().unwrap(); + + println!("Forwarding remote bin channel sender and receiver again"); + e_tx.send((tx1, rx2)).await.unwrap(); + + println!("Receiving forwarded remote bin channel sender and receiver again"); + let (tx1, rx2) = f_rx.recv().await.unwrap().unwrap(); + + let reply_task = tokio::spawn(async move { + let mut rx1 = rx1.into_inner().await.unwrap(); + let mut tx2 = tx2.into_inner().await.unwrap(); + + loop { + match rx1.recv_any().await.unwrap() { + Some(Received::Data(data)) => { + println!("Echoing data of length {}", data.remaining()); + tx2.send(data.into()).await.unwrap(); + } + Some(Received::Chunks) => { + println!("Echoing big data stream"); + let mut i = 0; + let mut cs = tx2.send_chunks(); + while let Some(chunk) = rx1.recv_chunk().await.unwrap() { + println!("Echoing chunk {} of size {}", i, chunk.len()); + cs = cs.send(chunk).await.unwrap(); + i += 1; + } + cs.finish().await.unwrap(); + } + Some(_) => (), + None => break, + } + } + }); + + let mut tx1 = tx1.into_inner().await.unwrap(); + let mut rx2 = rx2.into_inner().await.unwrap(); + + rx2.set_max_data_size(1_000_000); + + let mut rng = rand::thread_rng(); + for i in 1..100 { + let size = if i % 2 == 0 { rng.gen_range(0..1_000_000) } else { 1024 }; + let mut data = vec![0u8; size]; + rng.fill_bytes(&mut data); + let data = Bytes::from(data); + + println!("Sending message of length {}", data.len()); + let (send, recv) = join!(tx1.send(data.clone()), rx2.recv()); + send.unwrap(); + let data_recv = recv.unwrap().unwrap(); + println!("Received reply of length {}", data_recv.remaining()); + let data_recv = Bytes::from(data_recv); + assert_eq!(data, data_recv, "mismatched echo reply"); + } + drop(tx1); + + if rx2.recv().await.unwrap().is_some() { + panic!("received data after close"); + } + + reply_task.await.unwrap(); +}