diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index d0dcfaf28..6de61f185 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -3,6 +3,7 @@ use std::{ future::Future, io, io::IoSliceMut, + mem, net::{SocketAddr, SocketAddrV6}, pin::Pin, str, @@ -215,7 +216,7 @@ impl Endpoint { let addr = socket.local_addr()?; let socket = self.runtime.wrap_udp_socket(socket)?; let mut inner = self.inner.state.lock().unwrap(); - inner.socket = socket; + inner.abandoned_socket = Some(mem::replace(&mut inner.socket, socket)); inner.ipv6 = addr.is_ipv6(); // Generate some activity so peers notice the rebind @@ -419,6 +420,9 @@ impl EndpointInner { pub(crate) struct State { socket: Arc, inner: proto::Endpoint, + /// During an active migration, abandoned_socket receives traffic + /// until the first packet arrives on the new socket. + abandoned_socket: Option>, recv_state: RecvState, driver: Option, ipv6: bool, @@ -439,11 +443,27 @@ pub(crate) struct Shared { impl State { fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result { self.recv_state.recv_limiter.start_cycle(); - let poll_res = self - .recv_state - .poll_socket(cx, &mut self.inner, &*self.socket, now)?; + let mut poll_res = PollResult::default(); + if let Some(socket) = &self.abandoned_socket { + poll_res = self + .recv_state + .poll_socket(cx, &mut self.inner, &**socket, now); + }; + if !poll_res.keep_going { + poll_res = self + .recv_state + .poll_socket(cx, &mut self.inner, &*self.socket, now); + if poll_res.received_connection_packet { + // Traffic has arrived on self.socket, therefore + // there is no need for the abandoned one anymore. + self.abandoned_socket = None; + } + } self.recv_state.recv_limiter.finish_cycle(); - Ok(poll_res) + match poll_res.error { + None => Ok(poll_res.keep_going), + Some(err) => Err(err), + } } fn drive_send(&mut self, cx: &mut Context) -> Result { @@ -689,6 +709,7 @@ impl EndpointRef { state: Mutex::new(State { socket, inner, + abandoned_socket: None, ipv6, events, driver: None, @@ -731,6 +752,7 @@ impl std::ops::Deref for EndpointRef { &self.0 } } + /// State directly involved in handling incoming packets #[derive(Debug)] struct RecvState { @@ -772,7 +794,8 @@ impl RecvState { endpoint: &mut proto::Endpoint, socket: &dyn AsyncUdpSocket, now: Instant, - ) -> Result { + ) -> PollResult { + let mut received_connection_packet = false; let mut metas = [RecvMeta::default(); BATCH_SIZE]; let mut iovs: [IoSliceMut; BATCH_SIZE] = { let mut bufs = self @@ -813,6 +836,7 @@ impl RecvState { } Some(DatagramEvent::ConnectionEvent(handle, event)) => { // Ignoring errors from dropped connections that haven't yet been cleaned up + received_connection_packet = true; let _ = self .connections .senders @@ -829,7 +853,11 @@ impl RecvState { } } Poll::Pending => { - break; + return PollResult { + received_connection_packet, + keep_going: false, + error: None, + }; } // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an // attacker @@ -837,14 +865,27 @@ impl RecvState { continue; } Poll::Ready(Err(e)) => { - return Err(e); + return PollResult { + received_connection_packet, + keep_going: false, + error: Some(e), + }; } } if !self.recv_limiter.allow_work() { - return Ok(true); + return PollResult { + received_connection_packet, + keep_going: true, + error: None, + }; } } - - Ok(false) } } + +#[derive(Default)] +struct PollResult { + received_connection_packet: bool, + keep_going: bool, + error: Option, +}