From ebc5dd9cd86a0085c533f8fa82cd328c685bbb69 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Thu, 2 Jan 2025 10:54:38 +0100 Subject: [PATCH] io: extend `Buf` length only after having read into it --- tokio/src/fs/file.rs | 8 ++++-- tokio/src/io/blocking.rs | 48 ++++++++++++++++++++++-------------- tokio/src/io/stderr.rs | 6 ++++- tokio/src/io/stdin.rs | 6 ++++- tokio/src/io/stdout.rs | 6 ++++- tokio/src/process/windows.rs | 6 ++++- 6 files changed, 56 insertions(+), 24 deletions(-) diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index 63dd8af3e98..7847066404d 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -7,6 +7,7 @@ use crate::io::blocking::{Buf, DEFAULT_MAX_BUF_SIZE}; use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use crate::sync::Mutex; +use std::cmp; use std::fmt; use std::fs::{Metadata, Permissions}; use std::future::Future; @@ -600,11 +601,14 @@ impl AsyncRead for File { return Poll::Ready(Ok(())); } - buf.ensure_capacity_for(dst, me.max_buf_size); let std = me.std.clone(); + let max_buf_size = cmp::min(dst.remaining(), me.max_buf_size); inner.state = State::Busy(spawn_blocking(move || { - let res = buf.read_from(&mut &*std); + // SAFETY: the `Read` implementation of `std` does not + // read from the buffer it is borrowing and correctly + // reports the length of the data written into the buffer. + let res = unsafe { buf.read_from(&mut &*std, max_buf_size) }; (Operation::Read(res), buf) })); } diff --git a/tokio/src/io/blocking.rs b/tokio/src/io/blocking.rs index f189136b52e..1af5065456d 100644 --- a/tokio/src/io/blocking.rs +++ b/tokio/src/io/blocking.rs @@ -5,6 +5,7 @@ use std::cmp; use std::future::Future; use std::io; use std::io::prelude::*; +use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{ready, Context, Poll}; @@ -33,8 +34,13 @@ enum State { cfg_io_blocking! 
{ impl<T> Blocking<T> { + /// # Safety + /// + /// The `Read` implementation of `inner` must never read from the buffer + /// it is borrowing and must correctly report the length of the data + /// written into the buffer. #[cfg_attr(feature = "fs", allow(dead_code))] - pub(crate) fn new(inner: T) -> Blocking<T> { + pub(crate) unsafe fn new(inner: T) -> Blocking<T> { Blocking { inner: Some(inner), state: State::Idle(Some(Buf::with_capacity(0))), @@ -64,11 +70,12 @@ where return Poll::Ready(Ok(())); } - buf.ensure_capacity_for(dst, DEFAULT_MAX_BUF_SIZE); let mut inner = self.inner.take().unwrap(); + let max_buf_size = cmp::min(dst.remaining(), DEFAULT_MAX_BUF_SIZE); self.state = State::Busy(sys::run(move || { - let res = buf.read_from(&mut inner); + // SAFETY: the requirements are satisfied by `Blocking::new`. + let res = unsafe { buf.read_from(&mut inner, max_buf_size) }; (res, buf, inner) })); } @@ -227,25 +234,30 @@ impl Buf { &self.buf[self.pos..] } - pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>, max_buf_size: usize) { + /// # Safety + /// + /// `rd` must not read from the buffer `read` is borrowing and must correctly + /// report the length of the data written into the buffer. + pub(crate) unsafe fn read_from<T: Read>( + &mut self, + rd: &mut T, + max_buf_size: usize, + ) -> io::Result<usize> { assert!(self.is_empty()); + self.buf.reserve(max_buf_size); - let len = cmp::min(bytes.remaining(), max_buf_size); - - if self.buf.len() < len { - self.buf.reserve(len - self.buf.len()); - } - - unsafe { - self.buf.set_len(len); - } - } - - pub(crate) fn read_from<T: Read>(&mut self, rd: &mut T) -> io::Result<usize> { - let res = uninterruptibly!(rd.read(&mut self.buf)); + let buf = &mut self.buf.spare_capacity_mut()[..max_buf_size]; + // SAFETY: The memory may be uninitialized, but `rd.read` will only write to the buffer. 
+ let buf = unsafe { &mut *(buf as *mut [MaybeUninit<u8>] as *mut [u8]) }; + let res = uninterruptibly!(rd.read(buf)); if let Ok(n) = res { - self.buf.truncate(n); + // SAFETY: the caller promises that `rd.read` initializes + // a section of `buf` and correctly reports that length. + // The `self.is_empty()` assertion verifies that `n` + // equals the length of the `buf` capacity that was written + // to (and that `buf` isn't being shrunk). + unsafe { self.buf.set_len(n) } } else { self.buf.clear(); } diff --git a/tokio/src/io/stderr.rs b/tokio/src/io/stderr.rs index e55cb1628fb..0988e2d9da0 100644 --- a/tokio/src/io/stderr.rs +++ b/tokio/src/io/stderr.rs @@ -67,8 +67,12 @@ cfg_io_std! { /// ``` pub fn stderr() -> Stderr { let std = io::stderr(); + // SAFETY: The `Read` implementation of `std` does not read from the + // buffer it is borrowing and correctly reports the length of the data + // written into the buffer. + let blocking = unsafe { Blocking::new(std) }; Stderr { - std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)), + std: SplitByUtf8BoundaryIfWindows::new(blocking), } } } diff --git a/tokio/src/io/stdin.rs b/tokio/src/io/stdin.rs index 640cb0b7236..877c48b30fb 100644 --- a/tokio/src/io/stdin.rs +++ b/tokio/src/io/stdin.rs @@ -42,8 +42,12 @@ cfg_io_std! { /// user input and use blocking IO directly in that thread. pub fn stdin() -> Stdin { let std = io::stdin(); + // SAFETY: The `Read` implementation of `std` does not read from the + // buffer it is borrowing and correctly reports the length of the data + // written into the buffer. + let std = unsafe { Blocking::new(std) }; Stdin { - std: Blocking::new(std), + std, } } } diff --git a/tokio/src/io/stdout.rs b/tokio/src/io/stdout.rs index b4621469202..f46ca0f05c4 100644 --- a/tokio/src/io/stdout.rs +++ b/tokio/src/io/stdout.rs @@ -116,8 +116,12 @@ cfg_io_std! 
{ /// ``` pub fn stdout() -> Stdout { let std = io::stdout(); + // SAFETY: The `Read` implementation of `std` does not read from the + // buffer it is borrowing and correctly reports the length of the data + // written into the buffer. + let blocking = unsafe { Blocking::new(std) }; Stdout { - std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)), + std: SplitByUtf8BoundaryIfWindows::new(blocking), } } } diff --git a/tokio/src/process/windows.rs b/tokio/src/process/windows.rs index 792a9c9610b..db3c15790ce 100644 --- a/tokio/src/process/windows.rs +++ b/tokio/src/process/windows.rs @@ -242,7 +242,11 @@ where use std::os::windows::prelude::FromRawHandle; let raw = Arc::new(unsafe { StdFile::from_raw_handle(io.into_raw_handle()) }); - let io = Blocking::new(ArcFile(raw.clone())); + let io = ArcFile(raw.clone()); + // SAFETY: the `Read` implementation of `io` does not + // read from the buffer it is borrowing and correctly + // reports the length of the data written into the buffer. + let io = unsafe { Blocking::new(io) }; Ok(ChildStdio { raw, io }) }