diff --git a/quinn-proto/src/connection/send_buffer.rs b/quinn-proto/src/connection/send_buffer.rs index 73559bc2..aca92dc9 100644 --- a/quinn-proto/src/connection/send_buffer.rs +++ b/quinn-proto/src/connection/send_buffer.rs @@ -1,6 +1,6 @@ use std::{collections::VecDeque, ops::Range}; -use bytes::{Buf, BufMut, Bytes}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use crate::{VarInt, range_set::ArrayRangeSet}; @@ -35,16 +35,23 @@ pub(super) struct SendBuffer { retransmits: ArrayRangeSet, } +/// Maximum number of bytes to combine into a single segment +/// +/// Any segment larger than this will be stored as-is, possibly triggering a flush of the buffer. +const MAX_COMBINE: usize = 1452; + /// This is where the data of the send buffer lives. It supports appending at the end, /// removing from the front, and retrieving data by range. #[derive(Default, Debug)] struct SendBufferData { /// Start offset of the buffered data offset: u64, + /// Total size of [`Self::segments`] and [`Self::last_segment`] + len: usize, /// Buffered data segments segments: VecDeque, - /// Total size of `buffered_segments` - len: usize, + /// Last segment, possibly empty + last_segment: BytesMut, } impl SendBufferData { @@ -62,7 +69,19 @@ impl SendBufferData { /// Append data to the end of the buffer fn append(&mut self, data: Bytes) { self.len += data.len(); - self.segments.push_back(data); + if data.len() > MAX_COMBINE { + // use in place + if !self.last_segment.is_empty() { + self.segments.push_back(self.last_segment.split().freeze()); + } + self.segments.push_back(data); + } else { + // copy + if self.last_segment.len() + data.len() > MAX_COMBINE && !self.last_segment.is_empty() { + self.segments.push_back(self.last_segment.split().freeze()); + } + self.last_segment.extend_from_slice(&data); + } } /// Discard data from the front of the buffer @@ -73,8 +92,10 @@ impl SendBufferData { self.len -= n; self.offset += n as u64; while n > 0 { - let front = self.segments.front_mut().expect("Expected buffered data"); - + // segments is empty, which leaves only last_segment + let Some(front) = self.segments.front_mut() else { + break; + }; if front.len() <= n { // Remove the whole front segment n -= front.len(); @@ -85,11 +106,24 @@ impl SendBufferData { n = 0; } } + // the rest has to be in the last segment + self.last_segment.advance(n); + // shrink segments if we have a lot of unused capacity if self.segments.len() * 4 < self.segments.capacity() { self.segments.shrink_to_fit(); } } + /// Iterator over all segments in order + /// + /// Concatenates `segments` and `last_segment` so they can be handled uniformly + fn segments_iter(&self) -> impl Iterator { + self.segments + .iter() + .map(|x| x.as_ref()) + .chain(std::iter::once(self.last_segment.as_ref())) + } + /// Returns data which is associated with a range /// /// Requesting a range outside of the buffered data will panic. @@ -105,7 +139,7 @@ impl SendBufferData { end: (offsets.end - self.offset) as usize, }; let mut segment_offset = 0; - for segment in self.segments.iter() { + for segment in self.segments_iter() { if offsets.start >= segment_offset && offsets.start < segment_offset + segment.len() { let start = offsets.start - segment_offset; let end = offsets.end - segment_offset; @@ -129,7 +163,7 @@ impl SendBufferData { end: (offsets.end - self.offset) as usize, }; let mut segment_offset = 0; - for segment in self.segments.iter() { + for segment in self.segments_iter() { // intersect segment range with requested range let start = segment_offset.max(offsets.start); let end = (segment_offset + segment.len()).min(offsets.end); @@ -148,8 +182,8 @@ impl SendBufferData { #[cfg(test)] fn to_vec(&self) -> Vec { let mut result = Vec::with_capacity(self.len); - for segment in self.segments.iter() { - result.extend_from_slice(&segment[..]); + for segment in self.segments_iter() { + result.extend_from_slice(segment); } result } @@ -415,48 +449,63 @@ mod tests { ); } + /// tests that large segments are copied as-is in the SendBuffer #[test] - fn multiple_segments() { - let mut buf = SendBuffer::new(); - const MSG: &[u8] = b"Hello, world!"; - const MSG_LEN: u64 = MSG.len() as u64; - - const SEG1: &[u8] = b"He"; - buf.write(SEG1.into()); - const SEG2: &[u8] = b"llo,"; - buf.write(SEG2.into()); - const SEG3: &[u8] = b" w"; - buf.write(SEG3.into()); - const SEG4: &[u8] = b"o"; - buf.write(SEG4.into()); - const SEG5: &[u8] = b"rld!"; - buf.write(SEG5.into()); - - assert_eq!(aggregate_unacked(&buf), MSG); - - assert_eq!(buf.poll_transmit(16), (0..8, true)); - assert_eq!(buf.get(0..5), SEG1); - assert_eq!(buf.get(2..8), SEG2); - assert_eq!(buf.get(6..8), SEG3); - - assert_eq!(buf.poll_transmit(16), (8..MSG_LEN, true)); - assert_eq!(buf.get(8..MSG_LEN), SEG4); - assert_eq!(buf.get(9..MSG_LEN), SEG5); + fn multiple_large_segments() { + // this must be bigger than MAX_COMBINE so we don't get writes coalesced. + const N: usize = 2000; + const K: u64 = N as u64; + fn dup(data: &[u8]) -> Bytes { + let mut buf = BytesMut::with_capacity(data.len() * N); + for c in data { + for _ in 0..N { + buf.put_u8(*c); + } + } + buf.freeze() + } - assert_eq!(buf.poll_transmit(42), (MSG_LEN..MSG_LEN, true)); + fn same(a: &[u8], b: &[u8]) -> bool { + // surprisingly, eq also checks the fat pointer metadata aka length + std::ptr::eq(a.as_ptr(), b.as_ptr()) + } + let mut buf = SendBuffer::new(); + let msg: Bytes = dup(b"Hello, world!"); + let msg_len: u64 = msg.len() as u64; + + let seg1: Bytes = dup(b"He"); + buf.write(seg1.clone()); + let seg2: Bytes = dup(b"llo,"); + buf.write(seg2.clone()); + let seg3: Bytes = dup(b" w"); + buf.write(seg3.clone()); + let seg4: Bytes = dup(b"o"); + buf.write(seg4.clone()); + let seg5: Bytes = dup(b"rld!"); + buf.write(seg5.clone()); + assert_eq!(aggregate_unacked(&buf), msg); + // Check that the segments were stored as-is + assert!(same(buf.get(0..5 * K), &seg1)); + assert!(same(buf.get(2 * K..8 * K), &seg2)); + assert!(same(buf.get(6 * K..8 * K), &seg3)); + assert!(same(buf.get(8 * 2000..msg_len), &seg4)); + assert!(same(buf.get(9 * 2000..msg_len), &seg5)); // Now drain the segments - buf.ack(0..1); - assert_eq!(aggregate_unacked(&buf), &MSG[1..]); - buf.ack(0..3); - assert_eq!(aggregate_unacked(&buf), &MSG[3..]); - buf.ack(3..5); - assert_eq!(aggregate_unacked(&buf), &MSG[5..]); - buf.ack(7..9); - assert_eq!(aggregate_unacked(&buf), &MSG[5..]); - buf.ack(4..7); - assert_eq!(aggregate_unacked(&buf), &MSG[9..]); - buf.ack(0..MSG_LEN); + buf.ack(0..K); + assert_eq!(aggregate_unacked(&buf), &msg[N..]); + buf.ack(0..3 * K); + assert_eq!(aggregate_unacked(&buf), &msg[3 * N..]); + buf.ack(3 * K..5 * K); + assert_eq!(aggregate_unacked(&buf), &msg[5 * N..]); + // ack with gap, doesn't free anything + buf.ack(7 * K..9 * K); + assert_eq!(aggregate_unacked(&buf), &msg[5 * N..]); + // fill the gap, free up to 9 K + buf.ack(4 * K..7 * K); + assert_eq!(aggregate_unacked(&buf), &msg[9 * N..]); + // ack all + buf.ack(0..msg_len); assert_eq!(aggregate_unacked(&buf), &[] as &[u8]); }