diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index 755b5eabd16..d7e9121c6e3 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -107,7 +107,61 @@ struct Inner { #[derive(Debug)] enum State { Idle(Option), - Busy(JoinHandle<(Operation, Buf)>), + Busy(JoinHandleInner<(Operation, Buf)>), +} + +#[derive(Debug)] +enum JoinHandleInner { + Blocking(JoinHandle), + #[cfg(all( + tokio_unstable, + feature = "io-uring", + feature = "rt", + feature = "fs", + target_os = "linux" + ))] + Async(BoxedOp), +} + +cfg_io_uring! { + struct BoxedOp(Pin + Send + Sync + 'static>>); + + impl std::fmt::Debug for BoxedOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // format of BoxedFuture(T::type_name()) + f.debug_tuple("BoxedFuture") + .field(&std::any::type_name::()) + .finish() + } + } + + impl Future for BoxedOp { + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0.as_mut().poll(cx) + } + } +} + +impl Future for JoinHandleInner<(Operation, Buf)> { + type Output = io::Result<(Operation, Buf)>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut() { + JoinHandleInner::Blocking(ref mut jh) => Pin::new(jh) + .poll(cx) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "background task failed")), + #[cfg(all( + tokio_unstable, + feature = "io-uring", + feature = "rt", + feature = "fs", + target_os = "linux" + ))] + JoinHandleInner::Async(ref mut jh) => Pin::new(jh).poll(cx).map(Ok), + } + } } #[derive(Debug)] @@ -399,7 +453,7 @@ impl File { let std = self.std.clone(); - inner.state = State::Busy(spawn_blocking(move || { + inner.state = State::Busy(JoinHandleInner::Blocking(spawn_blocking(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| std.set_len(size)) } else { @@ -409,7 +463,7 @@ impl File { // Return the result as a seek (Operation::Seek(res), buf) - })); + }))); let (op, buf) = match inner.state { State::Idle(_) => unreachable!(), @@ -613,13 +667,14 @@ impl AsyncRead for File { let std = me.std.clone(); let max_buf_size = cmp::min(dst.remaining(), me.max_buf_size); - inner.state = State::Busy(spawn_blocking(move || { - // SAFETY: the `Read` implementation of `std` does not - // read from the buffer it is borrowing and correctly - // reports the length of the data written into the buffer. - let res = unsafe { buf.read_from(&mut &*std, max_buf_size) }; - (Operation::Read(res), buf) - })); + inner.state = + State::Busy(JoinHandleInner::Blocking(spawn_blocking(move || { + // SAFETY: the `Read` implementation of `std` does not + // read from the buffer it is borrowing and correctly + // reports the length of the data written into the buffer. + let res = unsafe { buf.read_from(&mut &*std, max_buf_size) }; + (Operation::Read(res), buf) + }))); } State::Busy(ref mut rx) => { let (op, mut buf) = ready!(Pin::new(rx).poll(cx))?; @@ -685,10 +740,10 @@ impl AsyncSeek for File { let std = me.std.clone(); - inner.state = State::Busy(spawn_blocking(move || { + inner.state = State::Busy(JoinHandleInner::Blocking(spawn_blocking(move || { let res = (&*std).seek(pos); (Operation::Seek(res), buf) - })); + }))); Ok(()) } } @@ -753,20 +808,90 @@ impl AsyncWrite for File { let n = buf.copy_from(src, me.max_buf_size); let std = me.std.clone(); - let blocking_task_join_handle = spawn_mandatory_blocking(move || { - let res = if let Some(seek) = seek { - (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) - } else { - buf.write_to(&mut &*std) - }; + #[allow(unused_mut)] + let mut data = Some((std, buf)); + + let mut task_join_handle = None; + + #[cfg(all( + tokio_unstable, + feature = "io-uring", + feature = "rt", + feature = "fs", + target_os = "linux" + ))] + { + use crate::runtime::Handle; + + // Handle not present in some tests? + if let Ok(handle) = Handle::try_current() { + if handle.inner.driver().io().check_and_init()? { + task_join_handle = { + use crate::{io::uring::utils::ArcFd, runtime::driver::op::Op}; + + let (std, mut buf) = data.take().unwrap(); + if let Some(seek) = seek { + // we do std seek before a write, so we can always use u64::MAX (current cursor) for the file offset + // seeking only modifies kernel metadata and does not block, so we can do it here + (&*std).seek(seek).map_err(|e| { + io::Error::new( + e.kind(), + format!("failed to seek before write: {e}"), + ) + })?; + } + + let mut fd: ArcFd = std; + let handle = BoxedOp(Box::pin(async move { + loop { + let op = Op::write_at(fd, buf, u64::MAX); + let (r, _buf, _fd) = op.await; + buf = _buf; + fd = _fd; + match r { + Ok(0) => { + break ( + Operation::Write(Err( + io::ErrorKind::WriteZero.into(), + )), + buf, + ); + } + Ok(_) if buf.is_empty() => { + break (Operation::Write(Ok(())), buf); + } + Ok(_) => continue, // more to write + Err(e) => break (Operation::Write(Err(e)), buf), + } + } + })); + + Some(JoinHandleInner::Async(handle)) + }; + } + } + } - (Operation::Write(res), buf) - }) - .ok_or_else(|| { - io::Error::new(io::ErrorKind::Other, "background task failed") - })?; + if let Some((std, mut buf)) = data { + task_join_handle = { + let handle = spawn_mandatory_blocking(move || { + let res = if let Some(seek) = seek { + (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) + } else { + buf.write_to(&mut &*std) + }; + + (Operation::Write(res), buf) + }) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "background task failed") + })?; + + Some(JoinHandleInner::Blocking(handle)) + }; + } - inner.state = State::Busy(blocking_task_join_handle); + inner.state = State::Busy(task_join_handle.unwrap()); return Poll::Ready(Ok(n)); } @@ -824,20 +949,88 @@ impl AsyncWrite for File { let n = buf.copy_from_bufs(bufs, me.max_buf_size); let std = me.std.clone(); - let blocking_task_join_handle = spawn_mandatory_blocking(move || { - let res = if let Some(seek) = seek { - (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) - } else { - buf.write_to(&mut &*std) - }; + #[allow(unused_mut)] + let mut data = Some((std, buf)); + + let mut task_join_handle = None; + + #[cfg(all( + tokio_unstable, + feature = "io-uring", + feature = "rt", + feature = "fs", + target_os = "linux" + ))] + { + use crate::runtime::Handle; + + // Handle not present in some tests? + if let Ok(handle) = Handle::try_current() { + if handle.inner.driver().io().check_and_init()? { + task_join_handle = { + use crate::{io::uring::utils::ArcFd, runtime::driver::op::Op}; + + let (std, mut buf) = data.take().unwrap(); + if let Some(seek) = seek { + // we do std seek before a write, so we can always use u64::MAX (current cursor) for the file offset + // seeking only modifies kernel metadata and does not block, so we can do it here + (&*std).seek(seek).map_err(|e| { + io::Error::new( + e.kind(), + format!("failed to seek before write: {e}"), + ) + })?; + } + + let mut fd: ArcFd = std; + let handle = BoxedOp(Box::pin(async move { + loop { + let op = Op::write_at(fd, buf, u64::MAX); + let (r, _buf, _fd) = op.await; + buf = _buf; + fd = _fd; + match r { + Ok(0) => { + break ( + Operation::Write(Err( + io::ErrorKind::WriteZero.into(), + )), + buf, + ); + } + Ok(_) if buf.is_empty() => { + break (Operation::Write(Ok(())), buf); + } + Ok(_) => continue, // more to write + Err(e) => break (Operation::Write(Err(e)), buf), + } + } + })); + + Some(JoinHandleInner::Async(handle)) + }; + } + } + } - (Operation::Write(res), buf) - }) - .ok_or_else(|| { - io::Error::new(io::ErrorKind::Other, "background task failed") - })?; + if let Some((std, mut buf)) = data { + task_join_handle = Some(JoinHandleInner::Blocking( + spawn_mandatory_blocking(move || { + let res = if let Some(seek) = seek { + (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) + } else { + buf.write_to(&mut &*std) + }; + + (Operation::Write(res), buf) + }) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "background task failed") + })?, + )); + } - inner.state = State::Busy(blocking_task_join_handle); + inner.state = State::Busy(task_join_handle.unwrap()); return Poll::Ready(Ok(n)); } diff --git a/tokio/src/fs/write.rs b/tokio/src/fs/write.rs index c70a1978811..443905f563b 100644 --- a/tokio/src/fs/write.rs +++ b/tokio/src/fs/write.rs @@ -1,3 +1,11 @@ +#[cfg(all( + tokio_unstable, + feature = "io-uring", + feature = "rt", + feature = "fs", + target_os = "linux" +))] +use crate::io::blocking; use crate::{fs::asyncify, util::as_ref::OwnedBuf}; use std::{io, path::Path}; @@ -25,7 +33,6 @@ use std::{io, path::Path}; /// ``` pub async fn write(path: impl AsRef, contents: impl AsRef<[u8]>) -> io::Result<()> { let path = path.as_ref(); - let contents = crate::util::as_ref::upgrade(contents); #[cfg(all( tokio_unstable, @@ -38,10 +45,15 @@ pub async fn write(path: impl AsRef, contents: impl AsRef<[u8]>) -> io::Re let handle = crate::runtime::Handle::current(); let driver_handle = handle.inner.driver().io(); if driver_handle.check_and_init()? { - return write_uring(path, contents).await; + use crate::io::blocking; + + let mut buf = blocking::Buf::with_capacity(contents.as_ref().len()); + buf.copy_from(contents.as_ref(), contents.as_ref().len()); + return write_uring(path, buf).await; } } + let contents = crate::util::as_ref::upgrade(contents); write_spawn_blocking(path, contents).await } @@ -52,9 +64,9 @@ pub async fn write(path: impl AsRef, contents: impl AsRef<[u8]>) -> io::Re feature = "fs", target_os = "linux" ))] -async fn write_uring(path: &Path, mut buf: OwnedBuf) -> io::Result<()> { - use crate::{fs::OpenOptions, runtime::driver::op::Op}; - use std::os::fd::OwnedFd; +async fn write_uring(path: &Path, mut buf: blocking::Buf) -> io::Result<()> { + use crate::{fs::OpenOptions, io::uring::utils::ArcFd, runtime::driver::op::Op}; + use std::sync::Arc; let file = OpenOptions::new() .write(true) @@ -63,16 +75,14 @@ async fn write_uring(path: &Path, mut buf: OwnedBuf) -> io::Result<()> { .open(path) .await?; - let mut fd: OwnedFd = file - .try_into_std() - .expect("unexpected in-flight operation detected") - .into(); + let mut fd: ArcFd = Arc::new( + file.try_into_std() + .expect("unexpected in-flight operation detected"), + ); - let total: usize = buf.as_ref().len(); - let mut buf_offset: usize = 0; let mut file_offset: u64 = 0; - while buf_offset < total { - let (n, _buf, _fd) = Op::write_at(fd, buf, buf_offset, file_offset)?.await; + while !buf.is_empty() { + let (n, _buf, _fd) = Op::write_at(fd, buf, file_offset).await; // TODO: handle EINT here let n = n?; if n == 0 { @@ -81,7 +91,6 @@ async fn write_uring(path: &Path, mut buf: OwnedBuf) -> io::Result<()> { buf = _buf; fd = _fd; - buf_offset += n as usize; file_offset += n as u64; } diff --git a/tokio/src/io/blocking.rs b/tokio/src/io/blocking.rs index 1af5065456d..5c9307d5614 100644 --- a/tokio/src/io/blocking.rs +++ b/tokio/src/io/blocking.rs @@ -234,6 +234,25 @@ impl Buf { &self.buf[self.pos..] } + #[cfg(all( + tokio_unstable, + feature = "io-uring", + feature = "rt", + feature = "fs", + target_os = "linux" + ))] + pub(crate) fn advance(&mut self, n: usize) { + if n > self.len() { + panic!("advance past end of buffer"); + } + + self.pos += n; + if self.pos == self.buf.len() { + self.buf.truncate(0); + self.pos = 0; + } + } + /// # Safety /// /// `rd` must not read from the buffer `read` is borrowing and must correctly diff --git a/tokio/src/io/uring/utils.rs b/tokio/src/io/uring/utils.rs index e30e7a5ddc4..7b731c8ea46 100644 --- a/tokio/src/io/uring/utils.rs +++ b/tokio/src/io/uring/utils.rs @@ -1,6 +1,10 @@ +use std::os::fd::AsRawFd; use std::os::unix::ffi::OsStrExt; +use std::sync::Arc; use std::{ffi::CString, io, path::Path}; +pub(crate) type ArcFd = Arc; + pub(crate) fn cstr(p: &Path) -> io::Result { Ok(CString::new(p.as_os_str().as_bytes())?) } diff --git a/tokio/src/io/uring/write.rs b/tokio/src/io/uring/write.rs index 7341f7622da..14c61ce2ebe 100644 --- a/tokio/src/io/uring/write.rs +++ b/tokio/src/io/uring/write.rs @@ -1,19 +1,31 @@ +use crate::io::blocking; +use crate::io::uring::utils::ArcFd; use crate::runtime::driver::op::{CancelData, Cancellable, Completable, CqeResult, Op}; -use crate::util::as_ref::OwnedBuf; use io_uring::{opcode, types}; use std::io::{self, Error}; -use std::os::fd::{AsRawFd, OwnedFd}; -#[derive(Debug)] pub(crate) struct Write { - buf: OwnedBuf, - fd: OwnedFd, + buf: blocking::Buf, + fd: ArcFd, +} + +impl std::fmt::Debug for Write { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Write") + .field("buf_len", &self.buf.len()) + .field("fd", &self.fd.as_raw_fd()) + .finish() + } } impl Completable for Write { - type Output = (io::Result, OwnedBuf, OwnedFd); - fn complete(self, cqe: CqeResult) -> Self::Output { + type Output = (io::Result, blocking::Buf, ArcFd); + fn complete(mut self, cqe: CqeResult) -> Self::Output { + if let Ok(n) = cqe.result.as_ref() { + self.buf.advance(*n as usize); + } + (cqe.result, self.buf, self.fd) } @@ -31,17 +43,12 @@ impl Cancellable for Write { impl Op { /// Issue a write that starts at `buf_offset` within `buf` and writes some bytes /// into `file` at `file_offset`. - pub(crate) fn write_at( - fd: OwnedFd, - buf: OwnedBuf, - buf_offset: usize, - file_offset: u64, - ) -> io::Result { + pub(crate) fn write_at(fd: ArcFd, buf: blocking::Buf, file_offset: u64) -> Self { // There is a cap on how many bytes we can write in a single uring write operation. // ref: https://github.com/axboe/liburing/discussions/497 - let len = u32::try_from(buf.as_ref().len() - buf_offset).unwrap_or(u32::MAX); + let len = u32::try_from(buf.len()).unwrap_or(u32::MAX); - let ptr = buf.as_ref()[buf_offset..buf_offset + len as usize].as_ptr(); + let ptr = buf.bytes().as_ptr(); let sqe = opcode::Write::new(types::Fd(fd.as_raw_fd()), ptr, len) .offset(file_offset) @@ -49,7 +56,6 @@ impl Op { // SAFETY: parameters of the entry, such as `fd` and `buf`, are valid // until this operation completes. - let op = unsafe { Op::new(sqe, Write { buf, fd }) }; - Ok(op) + unsafe { Op::new(sqe, Write { buf, fd }) } } } diff --git a/tokio/tests/fs_uring.rs b/tokio/tests/fs_uring.rs index cd0d207d278..e914cdf6112 100644 --- a/tokio/tests/fs_uring.rs +++ b/tokio/tests/fs_uring.rs @@ -14,6 +14,9 @@ use std::task::Poll; use std::time::Duration; use std::{future::poll_fn, path::PathBuf}; use tempfile::NamedTempFile; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncSeekExt; +use tokio::io::AsyncWriteExt; use tokio::{ fs::OpenOptions, runtime::{Builder, Runtime}, @@ -145,6 +148,41 @@ async fn cancel_op_future() { assert!(res.is_cancelled()); } +#[tokio::test] +async fn test_file_write() { + let (_tmp_file, path): (Vec, Vec) = create_tmp_files(1); + + let mut file = OpenOptions::new().write(true).open(&path[0]).await.unwrap(); + + let data = b"hello io_uring"; + file.write_all(data).await.unwrap(); +} + +#[tokio::test] +async fn test_file_write_seek() { + let (_tmp_file, path): (Vec, Vec) = create_tmp_files(1); + + let mut file = OpenOptions::new() + .write(true) + .read(true) + .open(&path[0]) + .await + .unwrap(); + + let data = b"hello uring"; + file.write_all(data).await.unwrap(); + + file.seek(std::io::SeekFrom::Start(6)).await.unwrap(); + + let data2 = b"world"; + file.write_all(data2).await.unwrap(); + + let mut content = vec![0u8; 11]; + file.seek(std::io::SeekFrom::Start(0)).await.unwrap(); + file.read_exact(&mut content).await.unwrap(); + assert_eq!(&content, b"hello world"); +} + fn create_tmp_files(num_files: usize) -> (Vec, Vec) { let mut files = Vec::with_capacity(num_files); for _ in 0..num_files {