Skip to content

Commit

Permalink
Socket rebind: drain old socket
Browse files Browse the repository at this point in the history
During a planned/active connection migration allow the client to
receive trafic via the old, abandoned socket until the first packet
arrives on the socket.
  • Loading branch information
nemethf committed Apr 8, 2024
1 parent 1f8611d commit 1ae0e4b
Showing 1 changed file with 72 additions and 11 deletions.
83 changes: 72 additions & 11 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ impl Endpoint {
let addr = socket.local_addr()?;
let socket = self.runtime.wrap_udp_socket(socket)?;
let mut inner = self.inner.state.lock().unwrap();
inner.abandoned_socket = Some(inner.socket.clone());
inner.socket = socket;
inner.ipv6 = addr.is_ipv6();

Expand Down Expand Up @@ -409,6 +410,9 @@ impl EndpointInner {
#[derive(Debug)]
pub(crate) struct State {
socket: Arc<dyn AsyncUdpSocket>,
/// During an active migration, abandoned_socket receives traffic
/// until the first packet arrives on the new socket.
abandoned_socket: Option<Arc<dyn AsyncUdpSocket>>,
inner: proto::Endpoint,
transmit_state: TransmitState,
incoming: VecDeque<(proto::Incoming, BytesMut)>,
Expand Down Expand Up @@ -443,20 +447,54 @@ impl State {
iovs.as_mut_ptr()
.cast::<IoSliceMut>()
.add(i)
.write(IoSliceMut::<'a>::new(buf));
.write(IoSliceMut::new(buf));
});
let mut iovs = unsafe { iovs.assume_init() };
let rcx = &mut ReceiveContext {
cx,
now,
iovs: &mut iovs,
metas: &mut metas,
};
let mut poll_res = PollResult::default();
if let Some(socket) = &self.abandoned_socket {
poll_res = self.poll_socket(&socket.clone(), rcx);
};
if !poll_res.keep_going {
poll_res = self.poll_socket(&self.socket.clone(), rcx);
if poll_res.received_data {
// Traffic has arrived on self.socket, therefore
// there is no need for the abandoned one anymore.
self.abandoned_socket = None;
}
}
match poll_res.error {
None => {
self.recv_limiter.finish_cycle();
Ok(poll_res.keep_going)
}
Some(err) => Err(err)
}
}

fn poll_socket(
&mut self,
socket: &Arc<dyn AsyncUdpSocket>,
rc: &mut ReceiveContext,
) -> PollResult {
let mut received_data = false;
loop {
match self.socket.poll_recv(cx, &mut iovs, &mut metas) {
match socket.poll_recv(rc.cx, rc.iovs, rc.metas) {
Poll::Ready(Ok(msgs)) => {
self.recv_limiter.record_work(msgs);
for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
for (meta, buf) in rc.metas.iter().zip(rc.iovs.iter()).take(msgs) {
let mut data: BytesMut = buf[0..meta.len].into();
while !data.is_empty() {
received_data = true;
let buf = data.split_to(meta.stride.min(data.len()));
let mut response_buffer = BytesMut::new();
match self.inner.handle(
now,
rc.now,
meta.addr,
meta.dst_ip,
meta.ecn.map(proto_ecn),
Expand Down Expand Up @@ -490,25 +528,33 @@ impl State {
}
}
Poll::Pending => {
break;
return PollResult {
received_data,
keep_going: false,
error: None,
};
}
// Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
// attacker
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
continue;
}
Poll::Ready(Err(e)) => {
return Err(e);
return PollResult {
received_data,
keep_going: false,
error: Some(e),
};
}
}
if !self.recv_limiter.allow_work() {
self.recv_limiter.finish_cycle();
return Ok(true);
return PollResult {
received_data,
keep_going: true,
error: None,
};
}
}

self.recv_limiter.finish_cycle();
Ok(false)
}

fn drive_send(&mut self, cx: &mut Context) -> Result<bool, io::Error> {
Expand Down Expand Up @@ -754,6 +800,7 @@ impl EndpointRef {
},
state: Mutex::new(State {
socket,
abandoned_socket: None,
inner,
ipv6,
events,
Expand Down Expand Up @@ -805,3 +852,17 @@ impl std::ops::Deref for EndpointRef {
&self.0
}
}

struct ReceiveContext<'a, 'b: 'c, 'c> {
cx: &'c mut Context<'b>,
now: Instant,
iovs: &'a mut [IoSliceMut<'a>; BATCH_SIZE],
metas: &'c mut [RecvMeta; BATCH_SIZE],
}

#[derive(Default)]
struct PollResult {
received_data: bool,
keep_going: bool,
error: Option<io::Error>,
}

0 comments on commit 1ae0e4b

Please sign in to comment.