diff --git a/hydroflow/src/util/tcp.rs b/hydroflow/src/util/tcp.rs index e1c766aa01c1..9d93d4961bf0 100644 --- a/hydroflow/src/util/tcp.rs +++ b/hydroflow/src/util/tcp.rs @@ -1,16 +1,15 @@ #![cfg(not(target_arch = "wasm32"))] -use std::cell::RefCell; use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::HashMap; use std::net::SocketAddr; -use std::pin::pin; -use std::rc::Rc; use futures::{SinkExt, StreamExt}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::{TcpListener, TcpSocket, TcpStream}; +use tokio::select; use tokio::task::spawn_local; +use tokio_stream::StreamMap; use tokio_util::codec::{ BytesCodec, Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec, LinesCodec, }; @@ -73,6 +72,7 @@ pub type TcpFramedSink = Sender<(T, SocketAddr)>; pub type TcpFramedStream = Receiver::Item, SocketAddr), ::Error>>; +// TODO(mingwei): this temporary code should be replaced with a properly thought out networking system. /// Create a listening tcp socket, and then as new connections come in, receive their data and forward it to a queue. pub async fn bind_tcp>( endpoint: SocketAddr, @@ -82,60 +82,64 @@ pub async fn bind_tcp> let bound_endpoint = listener.local_addr()?; - let (tx_egress, mut rx_egress) = unsync_channel(None); - let (tx_ingress, rx_ingress) = unsync_channel(None); - - let clients = Rc::new(RefCell::new(HashMap::new())); - - spawn_local({ - let clients = clients.clone(); - - async move { - while let Some((payload, addr)) = rx_egress.next().await { - let client = clients.borrow_mut().remove(&addr); - - if let Some(mut sender) = client { - let _ = SinkExt::send(&mut sender, payload).await; - clients.borrow_mut().insert(addr, sender); - } - } - } - }); + let (send_egress, mut recv_egress) = unsync_channel::<(T, SocketAddr)>(None); + let (send_ingres, recv_ingres) = unsync_channel(None); spawn_local(async move { + let send_ingress = send_ingres; + let mut peers_send = HashMap::new(); + let mut peers_recv = StreamMap::new(); + loop { - let (stream, peer_addr) = if let Ok((stream, _)) = listener.accept().await { - if let Ok(peer_addr) = stream.peer_addr() { - (stream, peer_addr) - } else { - continue; + // Calling methods in a loop, futures must be cancel-safe. + select! { + biased; + // Accept new clients. + new_peer = listener.accept() => { + let Ok((stream, _addr)) = new_peer else { + continue; + }; + let Ok(peer_addr) = stream.peer_addr() else { + continue; + }; + let (peer_send, peer_recv) = tcp_framed(stream, codec.clone()); + + // TODO: Using peer_addr here as the key is a little bit sketchy. + // It's possible that a peer could send a message, disconnect, then another peer connects from the + // same IP address (and the same src port), and then the response could be sent to that new client. + // This can be solved by using monotonically increasing IDs for each new peer, but would break the + // similarity with the UDP versions of this function. + peers_send.insert(peer_addr, peer_send); + peers_recv.insert(peer_addr, peer_recv); } - } else { - continue; - }; - - let mut tx_ingress = tx_ingress.clone(); - - let (send, recv) = tcp_framed(stream, codec.clone()); - - // TODO: Using peer_addr here as the key is a little bit sketchy. - // It's possible that a client could send a message, disconnect, then another client connects from the same IP address (and the same src port), and then the response could be sent to that new client. - // This can be solved by using monotonically increasing IDs for each new client, but would break the similarity with the UDP versions of this function. - clients.borrow_mut().insert(peer_addr, send); - - spawn_local({ - let clients = clients.clone(); - async move { - let mapped = recv.map(|x| Ok(x.map(|x| (x, peer_addr)))); - let _ = tx_ingress.send_all(&mut pin!(mapped)).await; - - clients.borrow_mut().remove(&peer_addr); + // Send outgoing messages. + msg_send = recv_egress.next() => { + let Some((payload, peer_addr)) = msg_send else { + continue; + }; + let Some(stream) = peers_send.get_mut(&peer_addr) else { + eprintln!("Dropping message to non-connected peer: {}", peer_addr); + continue; + }; + if let Err(_err) = SinkExt::send(stream, payload).await { + eprintln!("Failed to send message to peer: {}", peer_addr); + }; + } + // Receive incoming messages. + msg_recv = peers_recv.next() => { + let Some((peer_addr, payload_result)) = msg_recv else { + eprintln!("Error receiving message"); + continue; + }; + if let Err(err) = send_ingress.send(payload_result.map(|payload| (payload, peer_addr))).await { + eprintln!("Error passing along received message: {:?}", err); + } } - }); + } } }); - Ok((tx_egress, rx_ingress, bound_endpoint)) + Ok((send_egress, recv_ingres, bound_endpoint)) } /// The inverse of [`bind_tcp`]. @@ -146,34 +150,54 @@ pub async fn bind_tcp> pub fn connect_tcp>( codec: Codec, ) -> (TcpFramedSink, TcpFramedStream) { - let (tx_egress, mut rx_egress) = unsync_channel(None); - let (tx_ingress, rx_ingress) = unsync_channel(None); + let (send_egress, mut recv_egress) = unsync_channel(None); + let (send_ingres, recv_ingres) = unsync_channel(None); spawn_local(async move { - let mut streams = HashMap::new(); - - while let Some((payload, addr)) = rx_egress.next().await { - let stream = match streams.entry(addr) { - Occupied(entry) => entry.into_mut(), - Vacant(entry) => { - let socket = TcpSocket::new_v4().unwrap(); - let stream = socket.connect(addr).await.unwrap(); - - let (send, recv) = tcp_framed(stream, codec.clone()); + let send_ingres = send_ingres; + let mut peers_send = HashMap::new(); + let mut peers_recv = StreamMap::new(); - let mut tx_ingress = tx_ingress.clone(); - spawn_local(async move { - let mapped = recv.map(|x| Ok(x.map(|x| (x, addr)))); - let _ = tx_ingress.send_all(&mut pin!(mapped)).await; - }); - - entry.insert(send) + loop { + // Calling methods in a loop, futures must be cancel-safe. + select! { + biased; + // Send outgoing messages. + msg_send = recv_egress.next() => { + let Some((payload, peer_addr)) = msg_send else { + continue; + }; + + let stream = match peers_send.entry(peer_addr) { + Occupied(entry) => entry.into_mut(), + Vacant(entry) => { + let socket = TcpSocket::new_v4().unwrap(); + let stream = socket.connect(peer_addr).await.unwrap(); + + let (peer_send, peer_recv) = tcp_framed(stream, codec.clone()); + + peers_recv.insert(peer_addr, peer_recv); + entry.insert(peer_send) + } + }; + + if let Err(_err) = stream.send(payload).await { + eprintln!("Failed to send message to peer: {}", peer_addr); + } } - }; - - let _ = stream.send(payload).await; + // Receive incoming messages. + msg_recv = peers_recv.next() => { + let Some((peer_addr, payload_result)) = msg_recv else { + eprintln!("Error receiving message"); + continue; + }; + if let Err(err) = send_ingres.send(payload_result.map(|payload| (payload, peer_addr))).await { + eprintln!("Error passing along received message: {:?}", err); + } + } + } } }); - (tx_egress, rx_ingress) + (send_egress, recv_ingres) }