diff --git a/Cargo.lock b/Cargo.lock index c707253d7e21..8b2df3bfa728 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5026,6 +5026,7 @@ dependencies = [ "sha2", "thiserror", "tokio", + "tokio-util", "tracing", "url", "uv-client", diff --git a/crates/uv-fs/src/lib.rs b/crates/uv-fs/src/lib.rs index 1d9c4fef16ec..5c69e5a962b3 100644 --- a/crates/uv-fs/src/lib.rs +++ b/crates/uv-fs/src/lib.rs @@ -1,7 +1,11 @@ use fs2::FileExt; use std::fmt::Display; +use std::io; use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::task::{Context, Poll}; use tempfile::NamedTempFile; +use tokio::io::{AsyncRead, ReadBuf}; use tracing::{debug, error, info, trace, warn}; pub use crate::path::*; @@ -387,3 +391,32 @@ impl Drop for LockedFile { } } } + +/// An asynchronous reader that reports progress as bytes are read. +pub struct ProgressReader { + reader: Reader, + callback: Callback, +} + +impl ProgressReader { + /// Create a new [`ProgressReader`] that wraps another reader. + pub fn new(reader: Reader, callback: Callback) -> Self { + Self { reader, callback } + } +} + +impl AsyncRead + for ProgressReader +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.as_mut().reader) + .poll_read(cx, buf) + .map_ok(|()| { + (self.callback)(buf.filled().len()); + }) + } +} diff --git a/crates/uv-publish/Cargo.toml b/crates/uv-publish/Cargo.toml index 6e186b978dd3..20224a4a2bc6 100644 --- a/crates/uv-publish/Cargo.toml +++ b/crates/uv-publish/Cargo.toml @@ -31,6 +31,7 @@ serde_json = { workspace = true } sha2 = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } +tokio-util = { workspace = true , features = ["io"] } tracing = { workspace = true } url = { workspace = true } diff --git a/crates/uv-publish/src/lib.rs b/crates/uv-publish/src/lib.rs index 76b00b942a4b..cbe0e9c34e2d 100644 --- a/crates/uv-publish/src/lib.rs +++ b/crates/uv-publish/src/lib.rs @@ -15,13 +15,15 @@ use serde::Deserialize; use sha2::{Digest, Sha256}; use std::io::BufReader; use std::path::{Path, PathBuf}; +use std::sync::Arc; use std::{fmt, io}; use thiserror::Error; use tokio::io::AsyncReadExt; +use tokio_util::io::ReaderStream; use tracing::{debug, enabled, trace, Level}; use url::Url; use uv_client::BaseClient; -use uv_fs::Simplified; +use uv_fs::{ProgressReader, Simplified}; use uv_metadata::read_metadata_async_seek; #[derive(Error, Debug)] @@ -79,6 +81,13 @@ pub enum PublishSendError { RedirectError(Url), } +pub trait Reporter: Send + Sync + 'static { + fn on_progress(&self, name: &str, id: usize); + fn on_download_start(&self, name: &str, size: Option) -> usize; + fn on_download_progress(&self, id: usize, inc: u64); + fn on_download_complete(&self); +} + impl PublishSendError { /// Extract `code` from the PyPI json error response, if any. /// @@ -212,6 +221,7 @@ pub async fn upload( client: &BaseClient, username: Option<&str>, password: Option<&str>, + reporter: Arc, ) -> Result { let form_metadata = form_metadata(file, filename) .await @@ -224,6 +234,7 @@ pub async fn upload( username, password, form_metadata, + reporter, ) .await .map_err(|err| PublishError::PublishPrepare(file.to_path_buf(), Box::new(err)))?; @@ -396,18 +407,23 @@ async fn build_request( username: Option<&str>, password: Option<&str>, form_metadata: Vec<(&'static str, String)>, + reporter: Arc, ) -> Result { let mut form = reqwest::multipart::Form::new(); for (key, value) in form_metadata { form = form.text(key, value); } - let file: tokio::fs::File = fs_err::tokio::File::open(file).await?.into(); - let file_reader = Body::from(file); - form = form.part( - "content", - Part::stream(file_reader).file_name(filename.to_string()), - ); + let file = fs_err::tokio::File::open(file).await?; + let idx = reporter.on_download_start(&filename.to_string(), Some(file.metadata().await?.len())); + let reader = ProgressReader::new(file, move |read| { + reporter.on_download_progress(idx, read as u64); + }); + // Stream wrapping puts a static lifetime requirement on the reader (so the request doesn't have + // a lifetime) -> callback needs to be static -> reporter reference needs to be Arc'd. + let file_reader = Body::wrap_stream(ReaderStream::new(reader)); + let part = Part::stream(file_reader).file_name(filename.to_string()); + form = form.part("content", part); let url = if let Some(username) = username { if password.is_none() { @@ -525,14 +541,26 @@ async fn handle_response(registry: &Url, response: Response) -> Result) -> usize { + 0 + } + fn on_download_progress(&self, _id: usize, _inc: u64) {} + fn on_download_complete(&self) {} + } + /// Snapshot the data we send for an upload request for a source distribution. #[tokio::test] async fn upload_request_source_dist() { @@ -602,6 +630,7 @@ mod tests { Some("ferris"), Some("F3RR!S"), form_metadata, + Arc::new(DummyReporter), ) .await .unwrap(); @@ -744,6 +773,7 @@ mod tests { Some("ferris"), Some("F3RR!S"), form_metadata, + Arc::new(DummyReporter), ) .await .unwrap(); diff --git a/crates/uv/src/commands/publish.rs b/crates/uv/src/commands/publish.rs index 84b5be894446..b675721ce61b 100644 --- a/crates/uv/src/commands/publish.rs +++ b/crates/uv/src/commands/publish.rs @@ -1,8 +1,10 @@ +use crate::commands::reporters::PublishReporter; use crate::commands::{human_readable_bytes, ExitStatus}; use crate::printer::Printer; use anyhow::{bail, Result}; use owo_colors::OwoColorize; use std::fmt::Write; +use std::sync::Arc; use tracing::info; use url::Url; use uv_client::{BaseClientBuilder, Connectivity}; @@ -51,6 +53,7 @@ pub(crate) async fn publish( "Uploading".bold().green(), format!("({bytes:.1}{unit})").dimmed() )?; + let reporter = PublishReporter::single(printer); let uploaded = upload( &file, &filename, @@ -58,6 +61,8 @@ pub(crate) async fn publish( &client, username.as_deref(), password.as_deref(), + // Needs to be an `Arc` because the reqwest `Body` static lifetime requirement + Arc::new(reporter), ) .await?; // Filename and/or URL are already attached, if applicable. info!("Upload succeeded"); diff --git a/crates/uv/src/commands/reporters.rs b/crates/uv/src/commands/reporters.rs index 6a037765d00c..d81880aead8d 100644 --- a/crates/uv/src/commands/reporters.rs +++ b/crates/uv/src/commands/reporters.rs @@ -143,9 +143,10 @@ impl ProgressReporter { ); if size.is_some() { + // We're using binary bytes to match `human_readable_bytes`. progress.set_style( ProgressStyle::with_template( - "{msg:10.dim} {bar:30.green/dim} {decimal_bytes:>7}/{decimal_total_bytes:7}", + "{msg:10.dim} {bar:30.green/dim} {binary_bytes:>7}/{binary_total_bytes:7}", ) .unwrap() .progress_chars("--"), @@ -485,6 +486,48 @@ impl uv_python::downloads::Reporter for PythonDownloadReporter { } } +#[derive(Debug)] +pub(crate) struct PublishReporter { + reporter: ProgressReporter, +} + +impl PublishReporter { + /// Initialize a [`PublishReporter`] for a single upload. + pub(crate) fn single(printer: Printer) -> Self { + Self::new(printer, 1) + } + + /// Initialize a [`PublishReporter`] for multiple uploads. + pub(crate) fn new(printer: Printer, length: u64) -> Self { + let multi_progress = MultiProgress::with_draw_target(printer.target()); + let root = multi_progress.add(ProgressBar::with_draw_target( + Some(length), + printer.target(), + )); + let reporter = ProgressReporter::new(root, multi_progress, printer); + Self { reporter } + } +} + +impl uv_publish::Reporter for PublishReporter { + fn on_progress(&self, _name: &str, id: usize) { + self.reporter.on_download_complete(id); + } + + fn on_download_start(&self, name: &str, size: Option) -> usize { + self.reporter.on_download_start(name.to_string(), size) + } + + fn on_download_progress(&self, id: usize, inc: u64) { + self.reporter.on_download_progress(id, inc); + } + + fn on_download_complete(&self) { + self.reporter.root.set_message(""); + self.reporter.root.finish_and_clear(); + } +} + /// Like [`std::fmt::Display`], but with colors. trait ColorDisplay { fn to_color_string(&self) -> String;