Skip to content

Commit

Permalink
Socket rebind: drain old socket
Browse files Browse the repository at this point in the history
During a planned/active connection migration allow the client to
receive trafic via the old, abandoned socket until the first packet
arrives on the socket.
  • Loading branch information
nemethf committed Apr 4, 2024
1 parent 1f8611d commit aa30d49
Showing 1 changed file with 70 additions and 21 deletions.
91 changes: 70 additions & 21 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ impl Endpoint {
let addr = socket.local_addr()?;
let socket = self.runtime.wrap_udp_socket(socket)?;
let mut inner = self.inner.state.lock().unwrap();
inner.abandoned_socket = Some(inner.socket.clone());
inner.socket = socket;
inner.ipv6 = addr.is_ipv6();

Expand Down Expand Up @@ -409,6 +410,9 @@ impl EndpointInner {
#[derive(Debug)]
pub(crate) struct State {
socket: Arc<dyn AsyncUdpSocket>,
/// During an active migration, abandoned_socket receives traffic
/// until the first packet arrives on the new socket.
abandoned_socket: Option<Arc<dyn AsyncUdpSocket>>,
inner: proto::Endpoint,
transmit_state: TransmitState,
incoming: VecDeque<(proto::Incoming, BytesMut)>,
Expand All @@ -432,27 +436,31 @@ pub(crate) struct Shared {
}

impl State {
fn drive_recv<'a>(&'a mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
self.recv_limiter.start_cycle();
let mut metas = [RecvMeta::default(); BATCH_SIZE];
let mut iovs = MaybeUninit::<[IoSliceMut<'a>; BATCH_SIZE]>::uninit();
self.recv_buf
.chunks_mut(self.recv_buf.len() / BATCH_SIZE)
.enumerate()
.for_each(|(i, buf)| unsafe {
iovs.as_mut_ptr()
.cast::<IoSliceMut>()
.add(i)
.write(IoSliceMut::<'a>::new(buf));
});
let mut iovs = unsafe { iovs.assume_init() };
fn poll_socket(
&mut self,
primary_socket: bool,
cx: &mut Context,
now: Instant,
iovs: &mut [IoSliceMut; BATCH_SIZE],
mut metas: [RecvMeta; BATCH_SIZE],
) -> (
bool, // Recevied data
Result<bool, io::Error>,
) {
let mut received_data = false;
let socket = if primary_socket {
&self.socket
} else {
self.abandoned_socket.as_ref().unwrap()
};
loop {
match self.socket.poll_recv(cx, &mut iovs, &mut metas) {
match socket.poll_recv(cx, iovs, &mut metas) {
Poll::Ready(Ok(msgs)) => {
self.recv_limiter.record_work(msgs);
for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
let mut data: BytesMut = buf[0..meta.len].into();
while !data.is_empty() {
received_data = true;
let buf = data.split_to(meta.stride.min(data.len()));
let mut response_buffer = BytesMut::new();
match self.inner.handle(
Expand All @@ -464,6 +472,9 @@ impl State {
&mut response_buffer,
) {
Some(DatagramEvent::NewConnection(incoming)) => {
if !primary_socket {
continue;
}
if self.incoming.len() < MAX_INCOMING_CONNECTIONS {
self.incoming.push_back((incoming, response_buffer));
} else {
Expand All @@ -482,6 +493,9 @@ impl State {
.send(ConnectionEvent::Proto(event));
}
Some(DatagramEvent::Response(transmit)) => {
if !primary_socket {
continue;
}
self.transmit_state.respond(transmit, response_buffer);
}
None => {}
Expand All @@ -490,25 +504,59 @@ impl State {
}
}
Poll::Pending => {
break;
return (received_data, Ok(false));
}
// Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
// attacker
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
continue;
}
Poll::Ready(Err(e)) => {
return Err(e);
return (received_data, Err(e));
}
}
if !self.recv_limiter.allow_work() {
self.recv_limiter.finish_cycle();
return Ok(true);
return (received_data, Ok(true));
}
}
}

self.recv_limiter.finish_cycle();
Ok(false)
fn drive_recv<'a>(&'a mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
self.recv_limiter.start_cycle();
let metas = [RecvMeta::default(); BATCH_SIZE];
let mut iovs = MaybeUninit::<[IoSliceMut<'a>; BATCH_SIZE]>::uninit();
self.recv_buf
.chunks_mut(self.recv_buf.len() / BATCH_SIZE)
.enumerate()
.for_each(|(i, buf)| unsafe {
iovs.as_mut_ptr()
.cast::<IoSliceMut>()
.add(i)
.write(IoSliceMut::new(buf));
});
let mut iovs = unsafe { iovs.assume_init() };
let (_, mut result) = if self.abandoned_socket.is_some() {
self.poll_socket(false, cx, now, &mut iovs, metas)
} else {
// Ok(false) means recv_limiter lets more work to be done.
(false, Ok(false))
};
match result {
Ok(true) => {}
Ok(false) | Err(_) => {
let received_data: bool;
(received_data, result) = self.poll_socket(true, cx, now, &mut iovs, metas);
if received_data {
// Traffic has arrived on self.socket, therefore
// there is no need for the abandoned one anymore.
self.abandoned_socket = None;
}
}
}
if result.is_ok() {
self.recv_limiter.finish_cycle();
}
result
}

fn drive_send(&mut self, cx: &mut Context) -> Result<bool, io::Error> {
Expand Down Expand Up @@ -754,6 +802,7 @@ impl EndpointRef {
},
state: Mutex::new(State {
socket,
abandoned_socket: None,
inner,
ipv6,
events,
Expand Down

0 comments on commit aa30d49

Please sign in to comment.