Skip to content

Commit 0b76713

Browse files
committed
Fix unsoundness of peek_ahead in iter.rs
1 parent 380f130 commit 0b76713

File tree

2 files changed

+39
-28
lines changed

2 files changed

+39
-28
lines changed

src/iter.rs

+27-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use core::convert::TryInto;
21
use core::convert::TryFrom;
2+
use core::convert::TryInto;
33

44
#[allow(missing_docs)]
55
pub struct Bytes<'a> {
@@ -41,19 +41,27 @@ impl<'a> Bytes<'a> {
4141
}
4242
}
4343

44-
#[inline]
45-
pub fn peek_ahead(&self, n: usize) -> Option<u8> {
46-
// SAFETY: obtain a potentially OOB pointer that is later compared against the `self.end`
47-
// pointer.
48-
let ptr = self.cursor.wrapping_add(n);
49-
if ptr < self.end {
50-
// SAFETY: bounds checked pointer dereference is safe
51-
Some(unsafe { *ptr })
44+
/// Peek at byte `n` ahead of cursor
45+
///
46+
/// # Safety
47+
///
48+
/// Caller must ensure that `n <= self.len()`, otherwise `self.cursor.add(n)` is UB.
49+
/// That means there are at least `n-1` bytes between `self.cursor` and `self.end`
50+
/// and `self.cursor.add(n)` is either `self.end` or points to a valid byte.
51+
#[inline]
52+
pub unsafe fn peek_ahead(&self, n: usize) -> Option<u8> {
53+
debug_assert!(n <= self.len());
54+
// SAFETY: by preconditions
55+
let p = unsafe { self.cursor.add(n) };
56+
if p < self.end {
57+
// SAFETY: by preconditions, if this is not `self.end`,
58+
// then it is safe to dereference
59+
Some(unsafe { *p })
5260
} else {
5361
None
5462
}
5563
}
56-
64+
5765
#[inline]
5866
pub fn peek_n<'b: 'a, U: TryFrom<&'a [u8]>>(&'b self, n: usize) -> Option<U> {
5967
// TODO: once we bump MSRC, use const generics to allow only [u8; N] reads
@@ -65,7 +73,7 @@ impl<'a> Bytes<'a> {
6573
/// Advance by 1, equivalent to calling `advance(1)`.
6674
///
6775
/// # Safety
68-
///
76+
///
6977
/// Caller must ensure that Bytes hasn't been advanced/bumped by more than [`Bytes::len()`].
7078
#[inline]
7179
pub unsafe fn bump(&mut self) {
@@ -75,7 +83,7 @@ impl<'a> Bytes<'a> {
7583
/// Advance cursor by `n`
7684
///
7785
/// # Safety
78-
///
86+
///
7987
/// Caller must ensure that Bytes hasn't been advanced/bumped by more than [`Bytes::len()`].
8088
#[inline]
8189
pub unsafe fn advance(&mut self, n: usize) {
@@ -104,7 +112,7 @@ impl<'a> Bytes<'a> {
104112
// TODO: this is an anti-pattern, should be removed
105113
/// Deprecated. Do not use!
106114
/// # Safety
107-
///
115+
///
108116
/// Caller must ensure that `skip` is at most the number of advances (i.e., `bytes.advance(3)`
109117
/// implies a skip of at most 3).
110118
#[inline]
@@ -114,21 +122,21 @@ impl<'a> Bytes<'a> {
114122
self.commit();
115123
head
116124
}
117-
125+
118126
#[inline]
119127
pub fn commit(&mut self) {
120128
self.start = self.cursor
121129
}
122130

123131
/// # Safety
124-
///
132+
///
125133
/// see [`Bytes::advance`] safety comment.
126134
#[inline]
127135
pub unsafe fn advance_and_commit(&mut self, n: usize) {
128136
self.advance(n);
129137
self.commit();
130138
}
131-
139+
132140
#[inline]
133141
pub fn as_ptr(&self) -> *const u8 {
134142
self.cursor
@@ -138,14 +146,14 @@ impl<'a> Bytes<'a> {
138146
pub fn start(&self) -> *const u8 {
139147
self.start
140148
}
141-
149+
142150
#[inline]
143151
pub fn end(&self) -> *const u8 {
144152
self.end
145153
}
146-
154+
147155
/// # Safety
148-
///
156+
///
149157
/// Must ensure invariant `bytes.start() <= ptr && ptr <= bytes.end()`.
150158
#[inline]
151159
pub unsafe fn set_cursor(&mut self, ptr: *const u8) {

src/lib.rs

+12-9
Original file line numberDiff line numberDiff line change
@@ -849,20 +849,23 @@ pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> {
849849
const POST: [u8; 4] = *b"POST";
850850
match bytes.peek_n::<[u8; 4]>(4) {
851851
Some(GET) => {
852-
// SAFETY: matched the ASCII string and boundary checked
852+
// SAFETY: we matched "GET " which has 4 bytes and is ASCII
853853
let method = unsafe {
854-
bytes.advance(4);
855-
let buf = bytes.slice_skip(1);
856-
str::from_utf8_unchecked(buf)
854+
bytes.advance(4); // advance cursor past "GET "
855+
str::from_utf8_unchecked(bytes.slice_skip(1)) // "GET" without space
857856
};
858857
Ok(Status::Complete(method))
859858
}
860-
Some(POST) if bytes.peek_ahead(4) == Some(b' ') => {
861-
// SAFETY: matched the ASCII string and boundary checked
859+
// SAFETY:
860+
// If `bytes.peek_n...` returns a Some([u8; 4]),
861+
// then we are assured that `bytes` contains at least 4 bytes.
862+
// Thus `bytes.len() >= 4`,
863+
// and it is safe to peek at byte 4 with `bytes.peek_ahead(4)`.
864+
Some(POST) if unsafe { bytes.peek_ahead(4) } == Some(b' ') => {
865+
// SAFETY: we matched "POST " which has 5 bytes
862866
let method = unsafe {
863-
bytes.advance(5);
864-
let buf = bytes.slice_skip(1);
865-
str::from_utf8_unchecked(buf)
867+
bytes.advance(5); // advance cursor past "POST "
868+
str::from_utf8_unchecked(bytes.slice_skip(1)) // "POST" without space
866869
};
867870
Ok(Status::Complete(method))
868871
}

0 commit comments

Comments
 (0)