diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 773116c7c..bb8662eac 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -215,6 +215,7 @@ impl Endpoint { let addr = socket.local_addr()?; let socket = self.runtime.wrap_udp_socket(socket)?; let mut inner = self.inner.state.lock().unwrap(); + inner.abandoned_socket = Some(inner.socket.clone()); inner.socket = socket; inner.ipv6 = addr.is_ipv6(); @@ -409,6 +410,9 @@ impl EndpointInner { #[derive(Debug)] pub(crate) struct State { socket: Arc, + /// During an active migration, abandoned_socket receives traffic + /// until the first packet arrives on the new socket. + abandoned_socket: Option>, inner: proto::Endpoint, transmit_state: TransmitState, incoming: VecDeque<(proto::Incoming, BytesMut)>, @@ -432,27 +436,31 @@ pub(crate) struct Shared { } impl State { - fn drive_recv<'a>(&'a mut self, cx: &mut Context, now: Instant) -> Result { - self.recv_limiter.start_cycle(); - let mut metas = [RecvMeta::default(); BATCH_SIZE]; - let mut iovs = MaybeUninit::<[IoSliceMut<'a>; BATCH_SIZE]>::uninit(); - self.recv_buf - .chunks_mut(self.recv_buf.len() / BATCH_SIZE) - .enumerate() - .for_each(|(i, buf)| unsafe { - iovs.as_mut_ptr() - .cast::() - .add(i) - .write(IoSliceMut::<'a>::new(buf)); - }); - let mut iovs = unsafe { iovs.assume_init() }; + fn poll_socket( + &mut self, + primary_socket: bool, + cx: &mut Context, + now: Instant, + iovs: &mut [IoSliceMut; BATCH_SIZE], + mut metas: [RecvMeta; BATCH_SIZE], + ) -> ( + bool, // Recevied data + Result, + ) { + let mut received_data = false; + let socket = if primary_socket { + &self.socket + } else { + self.abandoned_socket.as_ref().unwrap() + }; loop { - match self.socket.poll_recv(cx, &mut iovs, &mut metas) { + match socket.poll_recv(cx, iovs, &mut metas) { Poll::Ready(Ok(msgs)) => { self.recv_limiter.record_work(msgs); for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) { let mut data: BytesMut = buf[0..meta.len].into(); while !data.is_empty() { + received_data = true; let buf = data.split_to(meta.stride.min(data.len())); let mut response_buffer = BytesMut::new(); match self.inner.handle( @@ -464,6 +472,9 @@ impl State { &mut response_buffer, ) { Some(DatagramEvent::NewConnection(incoming)) => { + if !primary_socket { + continue; + } if self.incoming.len() < MAX_INCOMING_CONNECTIONS { self.incoming.push_back((incoming, response_buffer)); } else { @@ -482,6 +493,9 @@ impl State { .send(ConnectionEvent::Proto(event)); } Some(DatagramEvent::Response(transmit)) => { + if !primary_socket { + continue; + } self.transmit_state.respond(transmit, response_buffer); } None => {} @@ -490,7 +504,7 @@ impl State { } } Poll::Pending => { - break; + return (received_data, Ok(false)); } // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an // attacker @@ -498,17 +512,51 @@ impl State { continue; } Poll::Ready(Err(e)) => { - return Err(e); + return (received_data, Err(e)); } } if !self.recv_limiter.allow_work() { - self.recv_limiter.finish_cycle(); - return Ok(true); + return (received_data, Ok(true)); } } + } - self.recv_limiter.finish_cycle(); - Ok(false) + fn drive_recv<'a>(&'a mut self, cx: &mut Context, now: Instant) -> Result { + self.recv_limiter.start_cycle(); + let metas = [RecvMeta::default(); BATCH_SIZE]; + let mut iovs = MaybeUninit::<[IoSliceMut<'a>; BATCH_SIZE]>::uninit(); + self.recv_buf + .chunks_mut(self.recv_buf.len() / BATCH_SIZE) + .enumerate() + .for_each(|(i, buf)| unsafe { + iovs.as_mut_ptr() + .cast::() + .add(i) + .write(IoSliceMut::new(buf)); + }); + let mut iovs = unsafe { iovs.assume_init() }; + let (_, mut result) = if self.abandoned_socket.is_some() { + self.poll_socket(false, cx, now, &mut iovs, metas) + } else { + // Ok(false) means recv_limiter lets more work to be done. + (false, Ok(false)) + }; + match result { + Ok(true) => {} + Ok(false) | Err(_) => { + let received_data: bool; + (received_data, result) = self.poll_socket(true, cx, now, &mut iovs, metas); + if received_data { + // Traffic has arrived on self.socket, therefore + // there is no need for the abandoned one anymore. + self.abandoned_socket = None; + } + } + } + if result.is_ok() { + self.recv_limiter.finish_cycle(); + } + result } fn drive_send(&mut self, cx: &mut Context) -> Result { @@ -754,6 +802,7 @@ impl EndpointRef { }, state: Mutex::new(State { socket, + abandoned_socket: None, inner, ipv6, events,