From a4da8630ba0f56eb04a4ead6c374a850724445ac Mon Sep 17 00:00:00 2001 From: Milos Stankovic <82043364+morph-dev@users.noreply.github.com> Date: Tue, 4 Feb 2025 11:01:38 +0200 Subject: [PATCH] feat!: distinguish peer and peer_id on api level (#136) --- .circleci/config.yml | 3 +- src/cid.rs | 11 +--- src/conn.rs | 40 +++++++------ src/event.rs | 7 ++- src/lib.rs | 1 + src/peer.rs | 96 ++++++++++++++++++++++++++++++ src/socket.rs | 139 ++++++++++++++++++++++++++----------------- src/stream.rs | 17 +++--- src/testutils.rs | 28 ++++++--- src/udp.rs | 18 +++--- tests/socket.rs | 31 ++++++---- tests/stream.rs | 13 ++-- 12 files changed, 275 insertions(+), 129 deletions(-) create mode 100644 src/peer.rs diff --git a/.circleci/config.yml b/.circleci/config.yml index 1eada56..97a9003 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,5 +7,4 @@ workflows: - rust/lint-test-build: clippy_arguments: '--all-targets --all-features -- --deny warnings' release: true - version: 1.71.1 - + version: 1.81.0 diff --git a/src/cid.rs b/src/cid.rs index b4b0dc5..1db52be 100644 --- a/src/cid.rs +++ b/src/cid.rs @@ -1,15 +1,6 @@ -use std::fmt::Debug; -use std::hash::Hash; -use std::net::SocketAddr; - -/// A remote peer. -pub trait ConnectionPeer: Clone + Debug + Eq + Hash + PartialEq + Send + Sync {} - -impl ConnectionPeer for SocketAddr {} - #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] pub struct ConnectionId

{ pub send: u16, pub recv: u16, - pub peer: P, + pub peer_id: P, } diff --git a/src/conn.rs b/src/conn.rs index e8dbcc3..d2336e0 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -8,10 +8,12 @@ use delay_map::HashMapDelay; use futures::StreamExt; use tokio::sync::{mpsc, oneshot, Notify}; -use crate::cid::{ConnectionId, ConnectionPeer}; +use crate::cid::ConnectionId; use crate::congestion; use crate::event::{SocketEvent, StreamEvent}; use crate::packet::{Packet, PacketBuilder, PacketType, SelectiveAck}; +use crate::peer::ConnectionPeer; +use crate::peer::Peer; use crate::recv::ReceiveBuffer; use crate::send::SendBuffer; use crate::sent::{SentPackets, SentPacketsError}; @@ -167,9 +169,10 @@ impl From for congestion::Config { } } -pub struct Connection { +pub struct Connection { state: State, - cid: ConnectionId

, + cid: ConnectionId, + peer: Peer

, config: ConnectionConfig, endpoint: Endpoint, peer_ts_diff: Duration, @@ -185,7 +188,8 @@ pub struct Connection { impl Connection { pub fn new( - cid: ConnectionId

, + cid: ConnectionId, + peer: Peer

, config: ConnectionConfig, syn: Option, connected: oneshot::Sender>, @@ -212,6 +216,7 @@ impl Connection { Self { state: State::Connecting(Some(connected)), cid, + peer, config, endpoint, peer_ts_diff, @@ -232,7 +237,7 @@ impl Connection { mut writes: mpsc::UnboundedReceiver, mut shutdown: oneshot::Receiver<()>, ) -> io::Result<()> { - tracing::debug!("uTP conn starting... {:?}", self.cid.peer); + tracing::debug!("uTP conn starting... {:?}", self.peer); // If we are the initiating endpoint, then send the SYN. If we are the accepting endpoint, // then send the SYN-ACK. @@ -240,7 +245,7 @@ impl Connection { Endpoint::Initiator((syn_seq_num, ..)) => { let syn = self.syn_packet(syn_seq_num); self.socket_events - .send(SocketEvent::Outgoing((syn.clone(), self.cid.peer.clone()))) + .send(SocketEvent::Outgoing((syn.clone(), self.peer.clone()))) .unwrap(); self.unacked .insert_at(syn_seq_num, syn, self.config.initial_timeout); @@ -250,7 +255,7 @@ impl Connection { Endpoint::Acceptor((syn, syn_ack)) => { let state = self.state_packet().unwrap(); self.socket_events - .send(SocketEvent::Outgoing((state, self.cid.peer.clone()))) + .send(SocketEvent::Outgoing((state, self.peer.clone()))) .unwrap(); let recv_buf = ReceiveBuffer::new(syn); @@ -409,7 +414,7 @@ impl Connection { &mut self.unacked, &mut self.socket_events, fin, - &self.cid.peer, + &self.peer, Instant::now(), ); } @@ -441,7 +446,7 @@ impl Connection { &mut self.unacked, &mut self.socket_events, fin, - &self.cid.peer, + &self.peer, Instant::now(), ); } @@ -542,7 +547,7 @@ impl Connection { &mut self.unacked, &mut self.socket_events, packet, - &self.cid.peer, + &self.peer, now, ); seq_num = seq_num.wrapping_add(1); @@ -680,7 +685,7 @@ impl Connection { let packet = self.syn_packet(seq); let _ = self .socket_events - .send(SocketEvent::Outgoing((packet, self.cid.peer.clone()))); + .send(SocketEvent::Outgoing((packet, self.peer.clone()))); } } Endpoint::Acceptor(..) => {} @@ -728,7 +733,7 @@ impl Connection { &mut self.unacked, &mut self.socket_events, packet, - &self.cid.peer, + &self.peer, now, ); } @@ -784,7 +789,7 @@ impl Connection { match packet.packet_type() { PacketType::Syn | PacketType::Fin | PacketType::Data => { if let Some(state) = self.state_packet() { - let event = SocketEvent::Outgoing((state, self.cid.peer.clone())); + let event = SocketEvent::Outgoing((state, self.peer.clone())); if self.socket_events.send(event).is_err() { tracing::warn!("Cannot transmit state packet: socket closed channel"); return; @@ -1156,7 +1161,7 @@ impl Connection { &mut self.unacked, &mut self.socket_events, packet, - &self.cid.peer, + &self.peer, now, ); } @@ -1167,7 +1172,7 @@ impl Connection { unacked: &mut HashMapDelay, socket_events: &mut mpsc::UnboundedSender>, packet: Packet, - dest: &P, + peer: &Peer

, now: Instant, ) { let (payload, len) = if packet.payload().is_empty() { @@ -1189,7 +1194,7 @@ impl Connection { sent_packets.on_transmit(packet.seq_num(), packet.packet_type(), payload, len, now); unacked.insert_at(packet.seq_num(), packet.clone(), sent_packets.timeout()); - let outbound = SocketEvent::Outgoing((packet, dest.clone())); + let outbound = SocketEvent::Outgoing((packet, peer.clone())); if socket_events.send(outbound).is_err() { tracing::warn!("Cannot transmit packet: socket closed channel"); } @@ -1214,12 +1219,13 @@ mod test { let cid = ConnectionId { send: 101, recv: 100, - peer, + peer_id: peer, }; Connection { state: State::Connecting(Some(connected)), cid, + peer: Peer::new(peer), config: ConnectionConfig::default(), endpoint, peer_ts_diff: Duration::from_millis(100), diff --git a/src/event.rs b/src/event.rs index 8399329..43c6b6c 100644 --- a/src/event.rs +++ b/src/event.rs @@ -1,5 +1,6 @@ use crate::cid::ConnectionId; use crate::packet::Packet; +use crate::peer::{ConnectionPeer, Peer}; #[derive(Clone, Debug)] pub enum StreamEvent { @@ -8,7 +9,7 @@ pub enum StreamEvent { } #[derive(Clone, Debug)] -pub enum SocketEvent

{ - Outgoing((Packet, P)), - Shutdown(ConnectionId

), +pub enum SocketEvent { + Outgoing((Packet, Peer

)), + Shutdown(ConnectionId), } diff --git a/src/lib.rs b/src/lib.rs index ca83ee2..4006266 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ pub mod congestion; pub mod conn; pub mod event; pub mod packet; +pub mod peer; pub mod recv; pub mod send; pub mod sent; diff --git a/src/peer.rs b/src/peer.rs new file mode 100644 index 0000000..f668276 --- /dev/null +++ b/src/peer.rs @@ -0,0 +1,96 @@ +use std::fmt::Debug; +use std::hash::Hash; +use std::net::SocketAddr; + +/// A trait that describes remote peer +pub trait ConnectionPeer: Debug + Clone + Send + Sync { + type Id: Debug + Clone + PartialEq + Eq + Hash + Send + Sync; + + /// Returns peer's id + fn id(&self) -> Self::Id; + + /// Consolidates two peers into one. + /// + /// It's possible that we have two instances that represent the same peer (equal `peer_id`), + /// and we need to consolidate them into one. This can happen when [Peer]-s passed with + /// [UtpSocket::accept_with_cid](crate::socket::UtpSocket::accept_with_cid) or + /// [UtpSocket::connect_with_cid](crate::socket::UtpSocket::connect_with_cid), and returned by + /// [AsyncUdpSocket::recv_from](crate::udp::AsyncUdpSocket::recv_from) contain peers (not just + /// `peer_id`). + /// + /// The structure implementing this trait can decide on the exact behavior. Some examples: + /// - If structure is simple (i.e. two peers are the same iff all fields are the same), return + /// either (see implementation for `SocketAddr`) + /// - If we can determine which peer is newer (e.g. using timestamp or version field), return + /// newer peer + /// - If structure behaves more like a key-value map whose values don't change over time, + /// merge key-value pairs from both instances into one + /// + /// Should panic if ids are not matching. + fn consolidate(a: Self, b: Self) -> Self; +} + +impl ConnectionPeer for SocketAddr { + type Id = Self; + + fn id(&self) -> Self::Id { + *self + } + + fn consolidate(a: Self, b: Self) -> Self { + assert!(a == b, "Consolidating non-equal peers"); + a + } +} + +/// Structure that stores peer's id, and maybe peer as well. +#[derive(Debug, Clone)] +pub struct Peer { + id: P::Id, + peer: Option

, +} + +impl Peer

{ + /// Creates new instance that stores peer + pub fn new(peer: P) -> Self { + Self { + id: peer.id(), + peer: Some(peer), + } + } + + /// Creates new instance that only stores peer's id + pub fn new_id(peer_id: P::Id) -> Self { + Self { + id: peer_id, + peer: None, + } + } + + /// Returns peer's id + pub fn id(&self) -> &P::Id { + &self.id + } + + /// Returns optional reference to peer + pub fn peer(&self) -> Option<&P> { + self.peer.as_ref() + } + + /// Consolidates given peer into `Self` whilst consuming it. + /// + /// See [ConnectionPeer::consolidate] for details. + /// + /// Panics if ids are not matching. + pub fn consolidate(&mut self, other: Self) { + assert!(self.id == other.id, "Consolidating with non-equal peer"); + let Some(other_peer) = other.peer else { + return; + }; + + self.peer = match self.peer.take() { + Some(peer) => Some(P::consolidate(peer, other_peer)), + None => Some(other_peer), + }; + } +} diff --git a/src/socket.rs b/src/socket.rs index 06a654d..4ea0766 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -11,20 +11,27 @@ use tokio::net::UdpSocket; use tokio::sync::mpsc::UnboundedSender; use tokio::sync::{mpsc, oneshot}; -use crate::cid::{ConnectionId, ConnectionPeer}; +use crate::cid::ConnectionId; use crate::conn::ConnectionConfig; use crate::event::{SocketEvent, StreamEvent}; use crate::packet::{Packet, PacketBuilder, PacketType}; +use crate::peer::{ConnectionPeer, Peer}; use crate::stream::UtpStream; use crate::udp::AsyncUdpSocket; type ConnChannel = UnboundedSender; -struct Accept

{ +struct Accept { stream: oneshot::Sender>>, config: ConnectionConfig, } +struct AcceptWithCidPeer { + cid: ConnectionId, + peer: Peer

, + accept: Accept

, +} + const MAX_UDP_PAYLOAD_SIZE: usize = u16::MAX as usize; const CID_GENERATION_TRY_WARNING_COUNT: usize = 10; @@ -36,10 +43,10 @@ const CID_GENERATION_TRY_WARNING_COUNT: usize = 10; /// but thee uTP config refactor is currently very low priority. const AWAITING_CONNECTION_TIMEOUT: Duration = Duration::from_secs(20); -pub struct UtpSocket

{ - conns: Arc, ConnChannel>>>, +pub struct UtpSocket { + conns: Arc, ConnChannel>>>, accepts: UnboundedSender>, - accepts_with_cid: UnboundedSender<(Accept

, ConnectionId

)>, + accepts_with_cid: UnboundedSender>, socket_events: UnboundedSender>, } @@ -53,7 +60,7 @@ impl UtpSocket { impl

UtpSocket

where - P: ConnectionPeer + Unpin + 'static, + P: ConnectionPeer + Unpin + 'static, { pub fn with_socket(mut socket: S) -> Self where @@ -62,10 +69,10 @@ where let conns = HashMap::new(); let conns = Arc::new(RwLock::new(conns)); - let mut awaiting: HashMapDelay, Accept

> = + let mut awaiting: HashMapDelay, AcceptWithCidPeer

> = HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT); - let mut incoming_conns: HashMapDelay, Packet> = + let mut incoming_conns: HashMapDelay, (Peer

, Packet)> = HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT); let (socket_event_tx, mut socket_event_rx) = mpsc::unbounded_channel(); @@ -84,18 +91,19 @@ where loop { tokio::select! { biased; - Ok((n, src)) = socket.recv_from(&mut buf) => { + Ok((n, mut peer)) = socket.recv_from(&mut buf) => { + let peer_id = peer.id(); let packet = match Packet::decode(&buf[..n]) { Ok(pkt) => pkt, Err(..) => { - tracing::warn!(?src, "unable to decode uTP packet"); + tracing::warn!(?peer, "unable to decode uTP packet"); continue; } }; - let peer_init_cid = cid_from_packet(&packet, &src, IdType::SendIdPeerInitiated); - let we_init_cid = cid_from_packet(&packet, &src, IdType::SendIdWeInitiated); - let acc_cid = cid_from_packet(&packet, &src, IdType::RecvId); + let peer_init_cid = cid_from_packet::

(&packet, peer_id, IdType::SendIdPeerInitiated); + let we_init_cid = cid_from_packet::

(&packet, peer_id, IdType::SendIdWeInitiated); + let acc_cid = cid_from_packet::

(&packet, peer_id, IdType::RecvId); let mut conns = conns.write().unwrap(); let conn = conns .get(&acc_cid) @@ -107,12 +115,14 @@ where } None => { if std::matches!(packet.packet_type(), PacketType::Syn) { - let cid = cid_from_packet(&packet, &src, IdType::RecvId); + let cid = acc_cid; // If there was an awaiting connection with the CID, then // create a new stream for that connection. Otherwise, add the // connection to the incoming connections. - if let Some(accept) = awaiting.remove(&cid) { + if let Some(accept_with_cid) = awaiting.remove(&cid) { + peer.consolidate(accept_with_cid.peer); + let (connected_tx, connected_rx) = oneshot::channel(); let (events_tx, events_rx) = mpsc::unbounded_channel(); @@ -120,7 +130,8 @@ where let stream = UtpStream::new( cid, - accept.config, + peer, + accept_with_cid.accept.config, Some(packet), socket_event_tx.clone(), events_rx, @@ -128,10 +139,10 @@ where ); tokio::spawn(async move { - Self::await_connected(stream, accept, connected_rx).await + Self::await_connected(stream, accept_with_cid.accept.stream, connected_rx).await }); } else { - incoming_conns.insert(cid, packet); + incoming_conns.insert(cid, (peer, packet)); } } else { tracing::debug!( @@ -151,7 +162,7 @@ where let reset_packet = PacketBuilder::new(PacketType::Reset, packet.conn_id(), crate::time::now_micros(), 100_000, random_seq_num) .build(); - let event = SocketEvent::Outgoing((reset_packet, src.clone())); + let event = SocketEvent::Outgoing((reset_packet, peer)); if socket_event_tx.send(event).is_err() { tracing::warn!("Cannot transmit reset packet: socket closed channel"); return; @@ -161,18 +172,19 @@ where }, } } - Some((accept, cid)) = accepts_with_cid_rx.recv() => { - let Some(syn) = incoming_conns.remove(&cid) else { - awaiting.insert(cid, accept); + Some(accept_with_cid) = accepts_with_cid_rx.recv() => { + let Some((mut peer, syn)) = incoming_conns.remove(&accept_with_cid.cid) else { + awaiting.insert(accept_with_cid.cid.clone(), accept_with_cid); continue; }; - Self::select_accept_helper(cid, syn, conns.clone(), accept, socket_event_tx.clone()); + peer.consolidate(accept_with_cid.peer); + Self::select_accept_helper(accept_with_cid.cid, peer, syn, conns.clone(), accept_with_cid.accept, socket_event_tx.clone()); } Some(accept) = accepts_rx.recv(), if !incoming_conns.is_empty() => { - let (cid, _) = incoming_conns.iter().next().expect("at least one incoming connection"); + let cid = incoming_conns.keys().next().expect("at least one incoming connection"); let cid = cid.clone(); - let packet = incoming_conns.remove(&cid).expect("to delete incoming connection"); - Self::select_accept_helper(cid, packet, conns.clone(), accept, socket_event_tx.clone()); + let (peer, packet) = incoming_conns.remove(&cid).expect("to delete incoming connection"); + Self::select_accept_helper(cid, peer, packet, conns.clone(), accept, socket_event_tx.clone()); } Some(event) = socket_event_rx.recv() => { match event { @@ -195,11 +207,11 @@ where } } } - Some(Ok((cid, accept))) = awaiting.next() => { + Some(Ok((cid, accept_with_cid))) = awaiting.next() => { // accept_with_cid didn't receive an inbound connection within the timeout period // log it and return a timeout error tracing::debug!(%cid.send, %cid.recv, "accept_with_cid timed out"); - let _ = accept + let _ = accept_with_cid.accept .stream .send(Err(io::Error::from(io::ErrorKind::TimedOut))); } @@ -218,14 +230,14 @@ where /// Internal cid generation fn generate_cid( &self, - peer: P, + peer_id: P::Id, is_initiator: bool, event_tx: Option>, - ) -> ConnectionId

{ + ) -> ConnectionId { let mut cid = ConnectionId { send: 0, recv: 0, - peer, + peer_id, }; let mut generation_attempt_count = 0; loop { @@ -251,8 +263,8 @@ where } } - pub fn cid(&self, peer: P, is_initiator: bool) -> ConnectionId

{ - self.generate_cid(peer, is_initiator, None) + pub fn cid(&self, peer_id: P::Id, is_initiator: bool) -> ConnectionId { + self.generate_cid(peer_id, is_initiator, None) } /// Returns the number of connections currently open, both inbound and outbound. @@ -281,16 +293,21 @@ where /// they aren't compatible to use interchangeably in a program pub async fn accept_with_cid( &self, - cid: ConnectionId

, + cid: ConnectionId, + peer: Peer

, config: ConnectionConfig, ) -> io::Result> { let (stream_tx, stream_rx) = oneshot::channel(); - let accept = Accept { - stream: stream_tx, - config, + let accept = AcceptWithCidPeer { + cid, + peer, + accept: Accept { + stream: stream_tx, + config, + }, }; self.accepts_with_cid - .send((accept, cid)) + .send(accept) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?; match stream_rx.await { Ok(stream) => Ok(stream?), @@ -298,13 +315,18 @@ where } } - pub async fn connect(&self, peer: P, config: ConnectionConfig) -> io::Result> { + pub async fn connect( + &self, + peer: Peer

, + config: ConnectionConfig, + ) -> io::Result> { let (connected_tx, connected_rx) = oneshot::channel(); let (events_tx, events_rx) = mpsc::unbounded_channel(); - let cid = self.generate_cid(peer, true, Some(events_tx)); + let cid = self.generate_cid(peer.id().clone(), true, Some(events_tx)); let stream = UtpStream::new( cid, + peer, config, None, self.socket_events.clone(), @@ -321,7 +343,8 @@ where pub async fn connect_with_cid( &self, - cid: ConnectionId

, + cid: ConnectionId, + peer: Peer

, config: ConnectionConfig, ) -> io::Result> { if self.conns.read().unwrap().contains_key(&cid) { @@ -340,6 +363,7 @@ where let stream = UtpStream::new( cid.clone(), + peer, config, None, self.socket_events.clone(), @@ -362,28 +386,27 @@ where async fn await_connected( stream: UtpStream

, - accept: Accept

, + callback: oneshot::Sender>>, connected: oneshot::Receiver>, ) { match connected.await { Ok(Ok(..)) => { - let _ = accept.stream.send(Ok(stream)); + let _ = callback.send(Ok(stream)); } Ok(Err(err)) => { - let _ = accept.stream.send(Err(err)); + let _ = callback.send(Err(err)); } Err(..) => { - let _ = accept - .stream - .send(Err(io::Error::from(io::ErrorKind::ConnectionAborted))); + let _ = callback.send(Err(io::Error::from(io::ErrorKind::ConnectionAborted))); } } } fn select_accept_helper( - cid: ConnectionId

, + cid: ConnectionId, + peer: Peer

, syn: Packet, - conns: Arc, UnboundedSender>>>, + conns: Arc, ConnChannel>>>, accept: Accept

, socket_event_tx: UnboundedSender>, ) { @@ -404,6 +427,7 @@ where let stream = UtpStream::new( cid, + peer, accept.config, Some(syn), socket_event_tx, @@ -411,7 +435,9 @@ where connected_tx, ); - tokio::spawn(async move { Self::await_connected(stream, accept, connected_rx).await }); + tokio::spawn( + async move { Self::await_connected(stream, accept.stream, connected_rx).await }, + ); } } @@ -424,9 +450,10 @@ enum IdType { fn cid_from_packet( packet: &Packet, - src: &P, + peer_id: &P::Id, id_type: IdType, -) -> ConnectionId

{ +) -> ConnectionId { + let peer_id = peer_id.clone(); match id_type { IdType::RecvId => { let (send, recv) = match packet.packet_type() { @@ -438,7 +465,7 @@ fn cid_from_packet( ConnectionId { send, recv, - peer: src.clone(), + peer_id, } } IdType::SendIdWeInitiated => { @@ -446,7 +473,7 @@ fn cid_from_packet( ConnectionId { send, recv, - peer: src.clone(), + peer_id, } } IdType::SendIdPeerInitiated => { @@ -454,13 +481,13 @@ fn cid_from_packet( ConnectionId { send, recv, - peer: src.clone(), + peer_id, } } } } -impl

Drop for UtpSocket

{ +impl Drop for UtpSocket

{ fn drop(&mut self) { for conn in self.conns.read().unwrap().values() { let _ = conn.send(StreamEvent::Shutdown); diff --git a/src/stream.rs b/src/stream.rs index 363f3a8..4311709 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,18 +4,19 @@ use tokio::sync::{mpsc, oneshot}; use tokio::task; use tracing::Instrument; -use crate::cid::{ConnectionId, ConnectionPeer}; +use crate::cid::ConnectionId; use crate::congestion::DEFAULT_MAX_PACKET_SIZE_BYTES; use crate::conn; use crate::event::{SocketEvent, StreamEvent}; use crate::packet::Packet; +use crate::peer::{ConnectionPeer, Peer}; /// The size of the send and receive buffers. // TODO: Make the buffer size configurable. const BUF: usize = 1024 * 1024; -pub struct UtpStream

{ - cid: ConnectionId

, +pub struct UtpStream { + cid: ConnectionId, reads: mpsc::UnboundedReceiver, writes: mpsc::UnboundedSender, shutdown: Option>, @@ -27,7 +28,8 @@ where P: ConnectionPeer + 'static, { pub(crate) fn new( - cid: ConnectionId

, + cid: ConnectionId, + peer: Peer

, config: conn::ConnectionConfig, syn: Option, socket_events: mpsc::UnboundedSender>, @@ -39,6 +41,7 @@ where let (writes_tx, writes_rx) = mpsc::unbounded_channel(); let mut conn = conn::Connection::::new( cid.clone(), + peer, config, syn, connected, @@ -60,7 +63,7 @@ where } } - pub fn cid(&self) -> &ConnectionId

{ + pub fn cid(&self) -> &ConnectionId { &self.cid } @@ -117,7 +120,7 @@ where } } -impl

UtpStream

{ +impl UtpStream

{ // Send signal to the connection event loop to exit, after all outgoing writes have completed. // Public callers should use close() instead. fn shutdown(&mut self) -> io::Result<()> { @@ -130,7 +133,7 @@ impl

UtpStream

{ } } -impl

Drop for UtpStream

{ +impl Drop for UtpStream

{ fn drop(&mut self) { let _ = self.shutdown(); } diff --git a/src/testutils.rs b/src/testutils.rs index 4372d0a..2ea0f00 100644 --- a/src/testutils.rs +++ b/src/testutils.rs @@ -5,7 +5,8 @@ use std::sync::Arc; use async_trait::async_trait; use tokio::sync::mpsc; -use crate::cid::{ConnectionId, ConnectionPeer}; +use crate::cid::ConnectionId; +use crate::peer::{ConnectionPeer, Peer}; use crate::udp::AsyncUdpSocket; /// A mock socket that can be used to simulate a perfect link. @@ -38,8 +39,8 @@ impl AsyncUdpSocket for MockUdpSocket { /// /// Panics if `target` is not equal to `self.only_peer`. This socket is built to support /// exactly two peers communicating with each other, so it will panic if used with more. - async fn send_to(&mut self, buf: &[u8], target: &char) -> io::Result { - if target != &self.only_peer { + async fn send_to(&mut self, buf: &[u8], peer: &Peer) -> io::Result { + if peer.id() != &self.only_peer { panic!("MockUdpSocket only supports sending to one peer"); } if !self.is_up() { @@ -58,7 +59,7 @@ impl AsyncUdpSocket for MockUdpSocket { /// # Panics /// /// Panics if `buf` is smaller than the packet size. - async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, char)> { + async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer)> { let packet = self .inbound .recv() @@ -69,11 +70,22 @@ impl AsyncUdpSocket for MockUdpSocket { } let packet_len = packet.len(); buf[..packet_len].copy_from_slice(&packet[..]); - Ok((packet_len, self.only_peer)) + Ok((packet_len, Peer::new(self.only_peer))) } } -impl ConnectionPeer for char {} +impl ConnectionPeer for char { + type Id = char; + + fn id(&self) -> Self::Id { + *self + } + + fn consolidate(a: Self, b: Self) -> Self { + assert!(a == b, "Consolidating non-equal peers"); + a + } +} fn build_link_pair() -> (MockUdpSocket, MockUdpSocket) { let (peer_a, peer_b): (char, char) = ('A', 'B'); @@ -110,12 +122,12 @@ fn build_connection_id_pair_starting_at( let a_cid = ConnectionId { send: higher_id, recv: lower_id, - peer: socket_a.only_peer, + peer_id: socket_a.only_peer, }; let b_cid = ConnectionId { send: lower_id, recv: higher_id, - peer: socket_b.only_peer, + peer_id: socket_b.only_peer, }; (a_cid, b_cid) } diff --git a/src/udp.rs b/src/udp.rs index 62d2bae..ced0e7a 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -4,25 +4,27 @@ use std::net::SocketAddr; use async_trait::async_trait; use tokio::net::UdpSocket; -use crate::cid::ConnectionPeer; +use crate::peer::{ConnectionPeer, Peer}; /// An abstract representation of an asynchronous UDP socket. #[async_trait] pub trait AsyncUdpSocket: Send + Sync { - /// Attempts to send data on the socket to a given address. + /// Attempts to send data on the socket to a given peer. /// Note that this should return nearly immediately, rather than awaiting something internally. - async fn send_to(&mut self, buf: &[u8], target: &P) -> io::Result; + async fn send_to(&mut self, buf: &[u8], peer: &Peer

) -> io::Result; /// Attempts to receive a single datagram on the socket. - async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, P)>; + async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer

)>; } #[async_trait] impl AsyncUdpSocket for UdpSocket { - async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result { - UdpSocket::send_to(self, buf, target).await + async fn send_to(&mut self, buf: &[u8], peer: &Peer) -> io::Result { + UdpSocket::send_to(self, buf, peer.id()).await } - async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - UdpSocket::recv_from(self, buf).await + async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer)> { + UdpSocket::recv_from(self, buf) + .await + .map(|(len, peer)| (len, Peer::new(peer))) } } diff --git a/tests/socket.rs b/tests/socket.rs index 026d80d..2f9382e 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -1,6 +1,7 @@ use futures::stream::{FuturesUnordered, StreamExt}; use std::net::SocketAddr; use std::sync::Arc; +use utp_rs::peer::Peer; use tokio::task::JoinHandle; use tokio::time::Instant; @@ -115,16 +116,19 @@ async fn initiate_transfer( let recv_cid = cid::ConnectionId { send: initiator_cid, recv: responder_cid, - peer: send_addr, + peer_id: send_addr, }; let send_cid = cid::ConnectionId { send: responder_cid, recv: initiator_cid, - peer: recv_addr, + peer_id: recv_addr, }; let recv_handle = tokio::spawn(async move { - let mut stream = recv.accept_with_cid(recv_cid, conn_config).await.unwrap(); + let mut stream = recv + .accept_with_cid(recv_cid, Peer::new(send_addr), conn_config) + .await + .unwrap(); let mut buf = vec![]; let n = match stream.read_to_eof(&mut buf).await { Ok(num_bytes) => num_bytes, @@ -141,7 +145,10 @@ async fn initiate_transfer( }); let send_handle = tokio::spawn(async move { - let mut stream = send.connect_with_cid(send_cid, conn_config).await.unwrap(); + let mut stream = send + .connect_with_cid(send_cid, Peer::new(recv_addr), conn_config) + .await + .unwrap(); let n = stream.write(data).await.unwrap(); assert_eq!(n, data.len()); @@ -174,18 +181,18 @@ async fn test_socket_reports_two_connections() { let recv_one_cid = cid::ConnectionId { send: 100, recv: 101, - peer: send_addr, + peer_id: send_addr, }; let send_one_cid = cid::ConnectionId { send: 101, recv: 100, - peer: recv_addr, + peer_id: recv_addr, }; let recv_one = Arc::clone(&recv); let recv_one_handle = tokio::spawn(async move { recv_one - .accept_with_cid(recv_one_cid, conn_config) + .accept_with_cid(recv_one_cid, Peer::new(send_addr), conn_config) .await .unwrap() }); @@ -193,7 +200,7 @@ async fn test_socket_reports_two_connections() { let send_one = Arc::clone(&send); let send_one_handle = tokio::spawn(async move { send_one - .connect_with_cid(send_one_cid, conn_config) + .connect_with_cid(send_one_cid, Peer::new(recv_addr), conn_config) .await .unwrap() }); @@ -201,18 +208,18 @@ async fn test_socket_reports_two_connections() { let recv_two_cid = cid::ConnectionId { send: 200, recv: 201, - peer: send_addr, + peer_id: send_addr, }; let send_two_cid = cid::ConnectionId { send: 201, recv: 200, - peer: recv_addr, + peer_id: recv_addr, }; let recv_two = Arc::clone(&recv); let recv_two_handle = tokio::spawn(async move { recv_two - .accept_with_cid(recv_two_cid, conn_config) + .accept_with_cid(recv_two_cid, Peer::new(send_addr), conn_config) .await .unwrap() }); @@ -220,7 +227,7 @@ async fn test_socket_reports_two_connections() { let send_two = Arc::clone(&send); let send_two_handle = tokio::spawn(async move { send_two - .connect_with_cid(send_two_cid, conn_config) + .connect_with_cid(send_two_cid, Peer::new(recv_addr), conn_config) .await .unwrap() }); diff --git a/tests/stream.rs b/tests/stream.rs index 188a00f..719c196 100644 --- a/tests/stream.rs +++ b/tests/stream.rs @@ -6,6 +6,7 @@ use std::time::Duration; use tokio::time::timeout; use utp_rs::conn::{ConnectionConfig, DEFAULT_MAX_IDLE_TIMEOUT}; +use utp_rs::peer::Peer; use utp_rs::socket::UtpSocket; use utp_rs::testutils; @@ -29,7 +30,7 @@ async fn close_is_successful_when_write_completes() { let recv_one = Arc::clone(&recv); let recv_one_handle = tokio::spawn(async move { recv_one - .accept_with_cid(recv_cid, conn_config) + .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config) .await .unwrap() }); @@ -39,7 +40,7 @@ async fn close_is_successful_when_write_completes() { let send_one = Arc::clone(&send); let send_one_handle = tokio::spawn(async move { send_one - .connect_with_cid(send_cid, conn_config) + .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config) .await .unwrap() }); @@ -100,7 +101,7 @@ async fn close_errors_if_all_packets_dropped() { let recv_one = Arc::clone(&recv); let recv_one_handle = tokio::spawn(async move { recv_one - .accept_with_cid(recv_cid, conn_config) + .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config) .await .unwrap() }); @@ -110,7 +111,7 @@ async fn close_errors_if_all_packets_dropped() { let send_one = Arc::clone(&send); let send_one_handle = tokio::spawn(async move { send_one - .connect_with_cid(send_cid, conn_config) + .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config) .await .unwrap() }); @@ -178,7 +179,7 @@ async fn close_succeeds_if_only_fin_ack_dropped() { let recv_one = Arc::clone(&recv); let recv_one_handle = tokio::spawn(async move { recv_one - .accept_with_cid(recv_cid, conn_config) + .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config) .await .unwrap() }); @@ -188,7 +189,7 @@ async fn close_succeeds_if_only_fin_ack_dropped() { let send_one = Arc::clone(&send); let send_one_handle = tokio::spawn(async move { send_one - .connect_with_cid(send_cid, conn_config) + .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config) .await .unwrap() });