diff --git a/tokio/src/io/blocking.rs b/tokio/src/io/blocking.rs index f189136b52e..ebc32c5c6ef 100644 --- a/tokio/src/io/blocking.rs +++ b/tokio/src/io/blocking.rs @@ -21,6 +21,7 @@ pub(crate) struct Blocking { pub(crate) struct Buf { buf: Vec, pos: usize, + init_len: usize, } pub(crate) const DEFAULT_MAX_BUF_SIZE: usize = 2 * 1024 * 1024; @@ -190,6 +191,7 @@ impl Buf { Buf { buf: Vec::with_capacity(n), pos: 0, + init_len: 0, } } @@ -220,6 +222,7 @@ impl Buf { let n = cmp::min(src.len(), max_buf_size); self.buf.extend_from_slice(&src[..n]); + self.init_len = cmp::max(self.init_len, self.buf.len()); n } @@ -236,6 +239,27 @@ impl Buf { self.buf.reserve(len - self.buf.len()); } + if self.init_len < len { + debug_assert!( + self.init_len < self.buf.capacity(), + "init_len of Vec is bigger than the capacity" + ); + debug_assert!( + len <= self.buf.capacity(), + "uninit area of Vec is bigger than the capacity" + ); + + let uninit_len = len - self.init_len; + // SAFETY: the area is within the allocation of the Vec + unsafe { + self.buf + .as_mut_ptr() + .add(self.init_len) + .write_bytes(0, uninit_len); + } + } + + // SAFETY: `len` is within the capacity and is init unsafe { self.buf.set_len(len); } @@ -287,6 +311,7 @@ cfg_fs! { self.buf.extend_from_slice(&buf[..len]); rem -= len; } + self.init_len = cmp::max(self.init_len, self.buf.len()); max_buf_size - rem } diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index dc2c4309e66..7a5b4c6b346 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -293,3 +293,6 @@ cfg_io_blocking! { pub(crate) use crate::blocking::JoinHandle as Blocking; } } + +#[cfg(test)] +mod tests; diff --git a/tokio/src/io/tests.rs b/tokio/src/io/tests.rs new file mode 100644 index 00000000000..81813f5cc2f --- /dev/null +++ b/tokio/src/io/tests.rs @@ -0,0 +1,39 @@ +cfg_io_blocking! { + use std::{ + io::{self, Read}, + mem::MaybeUninit, + }; + + use crate::io::ReadBuf; + use super::blocking::Buf; + + const PAYLOAD: &[u8] = b"hello test!"; + + // Have miri check that `Buf::ensure_capacity_for` initializes its length. + #[test] + fn buf_ensure_capacity_for_len_is_init() { + const MAX_BUF_SIZE: usize = 128; + + let mut buf = Buf::with_capacity(0); + + let mut dst = [MaybeUninit::uninit(); 64]; + buf.ensure_capacity_for(&ReadBuf::uninit(&mut dst), MAX_BUF_SIZE); + miri_assert_init(buf.bytes()); + let res = buf.read_from(&mut EnsureInitReader(PAYLOAD)).unwrap(); + assert_eq!(res, PAYLOAD.len()); + assert_eq!(buf.bytes(), PAYLOAD); + } + + struct EnsureInitReader(R); + + impl Read for EnsureInitReader { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + miri_assert_init(buf); + self.0.read(buf) + } + } + + fn miri_assert_init(buf: &[u8]) { + for &_b in buf {} + } +}