From 1ae0e4bf9135afad66d13eb12db0005b72e4dba2 Mon Sep 17 00:00:00 2001 From: Felician Nemeth Date: Tue, 2 Apr 2024 11:33:20 +0200 Subject: [PATCH] Socket rebind: drain old socket During a planned/active connection migration allow the client to receive trafic via the old, abandoned socket until the first packet arrives on the socket. --- quinn/src/endpoint.rs | 83 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 72 insertions(+), 11 deletions(-) diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 773116c7c..88e617201 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -215,6 +215,7 @@ impl Endpoint { let addr = socket.local_addr()?; let socket = self.runtime.wrap_udp_socket(socket)?; let mut inner = self.inner.state.lock().unwrap(); + inner.abandoned_socket = Some(inner.socket.clone()); inner.socket = socket; inner.ipv6 = addr.is_ipv6(); @@ -409,6 +410,9 @@ impl EndpointInner { #[derive(Debug)] pub(crate) struct State { socket: Arc, + /// During an active migration, abandoned_socket receives traffic + /// until the first packet arrives on the new socket. + abandoned_socket: Option>, inner: proto::Endpoint, transmit_state: TransmitState, incoming: VecDeque<(proto::Incoming, BytesMut)>, @@ -443,20 +447,54 @@ impl State { iovs.as_mut_ptr() .cast::() .add(i) - .write(IoSliceMut::<'a>::new(buf)); + .write(IoSliceMut::new(buf)); }); let mut iovs = unsafe { iovs.assume_init() }; + let rcx = &mut ReceiveContext { + cx, + now, + iovs: &mut iovs, + metas: &mut metas, + }; + let mut poll_res = PollResult::default(); + if let Some(socket) = &self.abandoned_socket { + poll_res = self.poll_socket(&socket.clone(), rcx); + }; + if !poll_res.keep_going { + poll_res = self.poll_socket(&self.socket.clone(), rcx); + if poll_res.received_data { + // Traffic has arrived on self.socket, therefore + // there is no need for the abandoned one anymore. + self.abandoned_socket = None; + } + } + match poll_res.error { + None => { + self.recv_limiter.finish_cycle(); + Ok(poll_res.keep_going) + } + Some(err) => Err(err) + } + } + + fn poll_socket( + &mut self, + socket: &Arc, + rc: &mut ReceiveContext, + ) -> PollResult { + let mut received_data = false; loop { - match self.socket.poll_recv(cx, &mut iovs, &mut metas) { + match socket.poll_recv(rc.cx, rc.iovs, rc.metas) { Poll::Ready(Ok(msgs)) => { self.recv_limiter.record_work(msgs); - for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) { + for (meta, buf) in rc.metas.iter().zip(rc.iovs.iter()).take(msgs) { let mut data: BytesMut = buf[0..meta.len].into(); while !data.is_empty() { + received_data = true; let buf = data.split_to(meta.stride.min(data.len())); let mut response_buffer = BytesMut::new(); match self.inner.handle( - now, + rc.now, meta.addr, meta.dst_ip, meta.ecn.map(proto_ecn), @@ -490,7 +528,11 @@ impl State { } } Poll::Pending => { - break; + return PollResult { + received_data, + keep_going: false, + error: None, + }; } // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an // attacker @@ -498,17 +540,21 @@ impl State { continue; } Poll::Ready(Err(e)) => { - return Err(e); + return PollResult { + received_data, + keep_going: false, + error: Some(e), + }; } } if !self.recv_limiter.allow_work() { - self.recv_limiter.finish_cycle(); - return Ok(true); + return PollResult { + received_data, + keep_going: true, + error: None, + }; } } - - self.recv_limiter.finish_cycle(); - Ok(false) } fn drive_send(&mut self, cx: &mut Context) -> Result { @@ -754,6 +800,7 @@ impl EndpointRef { }, state: Mutex::new(State { socket, + abandoned_socket: None, inner, ipv6, events, @@ -805,3 +852,17 @@ impl std::ops::Deref for EndpointRef { &self.0 } } + +struct ReceiveContext<'a, 'b: 'c, 'c> { + cx: &'c mut Context<'b>, + now: Instant, + iovs: &'a mut [IoSliceMut<'a>; BATCH_SIZE], + metas: &'c mut [RecvMeta; BATCH_SIZE], +} + +#[derive(Default)] +struct PollResult { + received_data: bool, + keep_going: bool, + error: Option, +}