From b0a2ed0e7a2206300aee9dfe881a427616baa9da Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 25 Dec 2024 12:27:37 +0100 Subject: [PATCH] Tmp. Custom progressbar. Allow internal state. More API. Clippy. Remove print statements Using stream feature for shorter lived progressbars. Moved muliplexing behind less obvious builder (needs more testing to showcase pros and cons). --- Cargo.toml | 3 +- README.md | 2 +- src/api/mod.rs | 47 +++++++++++ src/api/sync.rs | 137 +++++++++++++++++++------------ src/api/tokio.rs | 208 +++++++++++++++++++++++++++++++++++------------ 5 files changed, 289 insertions(+), 108 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c24367e..fce8734 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ num_cpus = { version = "1.15.0", optional = true } rand = { version = "0.8.5", optional = true } reqwest = { version = "0.12.2", optional = true, default-features = false, features = [ "json", + "stream", ] } rustls = { version = "0.23.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } @@ -41,7 +42,7 @@ native-tls = { version = "0.2.12", optional = true } [features] default = ["default-tls", "tokio", "ureq"] # These features are only relevant when used with the `tokio` feature, but this might change in the future. -default-tls = [] +default-tls = ["native-tls"] native-tls = ["dep:reqwest", "reqwest/default", "dep:native-tls", "ureq/native-tls"] rustls-tls = ["dep:rustls", "reqwest/rustls-tls"] tokio = [ diff --git a/README.md b/README.md index 6876c46..6d0b0d9 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ let _filename = repo.get("config.json").unwrap(); # SSL/TLS -This library uses its dependencies' default TLS implementations which are `rustls` for `ureq` (sync) and `native-tls` (openssl) for `tokio`. +This library uses tokio default TLS implementations which is `native-tls` (openssl) for `tokio`. If you want control over the TLS backend you can remove the default features and only add the backend you are intending to use. diff --git a/src/api/mod.rs b/src/api/mod.rs index ef738ca..a5bc6a4 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,4 @@ +use indicatif::{ProgressBar, ProgressStyle}; use serde::Deserialize; /// The asynchronous version of the API @@ -8,6 +9,52 @@ pub mod tokio; #[cfg(feature = "ureq")] pub mod sync; +/// This trait is used by users of the lib +/// to implement custom behavior during file downloads +pub trait Progress { + /// At the start of the download + /// The size is the total size in bytes of the file. + fn init(&mut self, size: usize, filename: &str); + /// This function is called whenever `size` bytes have been + /// downloaded in the temporary file + fn update(&mut self, size: usize); + /// This is called at the end of the download + fn finish(&mut self); +} + +impl Progress for () { + fn init(&mut self, _size: usize, _filename: &str) {} + fn update(&mut self, _size: usize) {} + fn finish(&mut self) {} +} + +impl Progress for ProgressBar { + fn init(&mut self, size: usize, filename: &str) { + self.set_length(size as u64); + self.set_style( + ProgressStyle::with_template( + "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", + ) + .unwrap(), // .progress_chars("━ "), + ); + let maxlength = 30; + let message = if filename.len() > maxlength { + format!("..{}", &filename[filename.len() - maxlength..]) + } else { + filename.to_string() + }; + self.set_message(message); + } + + fn update(&mut self, size: usize) { + self.inc(size as u64) + } + + fn finish(&mut self) { + ProgressBar::finish(self); + } +} + /// Siblings are simplified file descriptions of remote files on the hub #[derive(Debug, Clone, Deserialize, PartialEq)] pub struct Siblings { diff --git a/src/api/sync.rs b/src/api/sync.rs index 31d627f..bbf6ebb 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -1,10 +1,12 @@ use super::RepoInfo; use crate::api::sync::ApiError::InvalidHeader; +use crate::api::Progress; use crate::{Cache, Repo, RepoType}; use http::{StatusCode, Uri}; -use indicatif::{ProgressBar, ProgressStyle}; +use indicatif::ProgressBar; use rand::Rng; use std::collections::HashMap; +use std::io::Read; use std::io::Seek; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; @@ -26,6 +28,23 @@ const AUTHORIZATION: &str = "Authorization"; type HeaderMap = HashMap<&'static str, String>; type HeaderName = &'static str; +struct Wrapper { + progress: P, + inner: R, +} + +fn wrap_read(inner: R, progress: P) -> Wrapper { + Wrapper { inner, progress } +} + +impl Read for Wrapper { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let read = self.inner.read(buf)?; + self.progress.update(read); + Ok(read) + } +} + /// Simple wrapper over [`ureq::Agent`] to include default headers #[derive(Clone, Debug)] pub struct HeaderAgent { @@ -92,7 +111,6 @@ pub enum ApiError { pub struct ApiBuilder { endpoint: String, cache: Cache, - url_template: String, token: Option, max_retries: usize, progress: bool, @@ -128,12 +146,10 @@ impl ApiBuilder { let max_retries = 0; let progress = true; - let endpoint = - std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_owned()); + let endpoint = "https://huggingface.co".to_string(); Self { endpoint, - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), cache, token, max_retries, @@ -197,10 +213,8 @@ impl ApiBuilder { Ok(Api { endpoint: self.endpoint, - url_template: self.url_template, cache: self.cache, client, - no_redirect_client, max_retries: self.max_retries, progress: self.progress, @@ -216,12 +230,10 @@ struct Metadata { } /// The actual Api used to interacto with the hub. -/// You can inspect repos with [`Api::info`] -/// or download files with [`Api::download`] +/// Use any repo with [`Api::repo`] #[derive(Clone, Debug)] pub struct Api { endpoint: String, - url_template: String, cache: Cache, client: HeaderAgent, no_redirect_client: HeaderAgent, @@ -402,10 +414,10 @@ impl Api { }) } - fn download_tempfile( + fn download_tempfile( &self, url: &str, - progressbar: Option, + progress: P, ) -> Result { let filename = self.cache.temp_path(); @@ -434,16 +446,16 @@ impl Api { let response = self.client.get(url).call().map_err(Box::new)?; - let mut reader = response.into_reader(); - if let Some(p) = &progressbar { - reader = Box::new(p.wrap_read(reader)); - } + let reader = response.into_reader(); + //if let Some(p) = &progressbar { + let mut reader = Box::new(wrap_read(reader, progress)); + //} std::io::copy(&mut reader, &mut file)?; - if let Some(p) = progressbar { - p.finish(); - } + //if let Some(p) = progressbar { + reader.progress.finish(); + //} Ok(filename) } @@ -506,6 +518,8 @@ impl Api { } /// Shorthand for accessing things within a particular repo +/// You can inspect repos with [`ApiRepo::info`] +/// or download files with [`ApiRepo::download`] #[derive(Debug)] pub struct ApiRepo { api: Api, @@ -541,12 +555,8 @@ impl ApiRepo { pub fn url(&self, filename: &str) -> String { let endpoint = &self.api.endpoint; let revision = &self.repo.url_revision(); - self.api - .url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &self.repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) + let repo_id = self.repo.url(); + format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}") } /// This will attempt the fetch the file locally first, then [`Api.download`] @@ -563,16 +573,40 @@ impl ApiRepo { } } - /// Downloads a remote file (if not already present) into the cache directory - /// to be used locally. - /// This functions require internet access to verify if new versions of the file - /// exist, even if a file is already on disk at location. + /// This function is used to download a file with a custom progress function. + /// It uses the [`Progress`] trait and can be used in more complex use + /// cases like downloading a showing progress in a UI. /// ```no_run - /// # use hf_hub::api::sync::Api; + /// # use hf_hub::api::{sync::Api, Progress}; + /// struct MyProgress{ + /// current: usize, + /// total: usize + /// } + /// + /// impl Progress for MyProgress{ + /// fn init(&mut self, size: usize, _filename: &str){ + /// self.total = size; + /// self.current = 0; + /// } + /// + /// fn update(&mut self, size: usize){ + /// self.current += size; + /// println!("{}/{}", self.current, self.total) + /// } + /// + /// fn finish(&mut self){ + /// println!("Done !"); + /// } + /// } /// let api = Api::new().unwrap(); - /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap(); + /// let progress = MyProgress{current: 0, total: 0}; + /// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).unwrap(); /// ``` - pub fn download(&self, filename: &str) -> Result { + pub fn download_with_progress( + &self, + filename: &str, + mut progress: P, + ) -> Result { let url = self.url(filename); let metadata = self.api.metadata(&url)?; @@ -583,27 +617,9 @@ impl ApiRepo { .blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; - let progressbar = if self.api.progress { - let progress = ProgressBar::new(metadata.size as u64); - progress.set_style( - ProgressStyle::with_template( - "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", - ) - .unwrap(), // .progress_chars("━ "), - ); - let maxlength = 30; - let message = if filename.len() > maxlength { - format!("..{}", &filename[filename.len() - maxlength..]) - } else { - filename.to_string() - }; - progress.set_message(message); - Some(progress) - } else { - None - }; + progress.init(metadata.size, filename); - let tmp_filename = self.api.download_tempfile(&url, progressbar)?; + let tmp_filename = self.api.download_tempfile(&url, progress)?; std::fs::rename(tmp_filename, &blob_path)?; @@ -624,6 +640,23 @@ impl ApiRepo { Ok(pointer_path) } + /// Downloads a remote file (if not already present) into the cache directory + /// to be used locally. + /// This functions require internet access to verify if new versions of the file + /// exist, even if a file is already on disk at location. + /// ```no_run + /// # use hf_hub::api::sync::Api; + /// let api = Api::new().unwrap(); + /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap(); + /// ``` + pub fn download(&self, filename: &str) -> Result { + if self.api.progress { + self.download_with_progress(filename, ProgressBar::new(0)) + } else { + self.download_with_progress(filename, ()) + } + } + /// Get information about the Repo /// ``` /// use hf_hub::{api::sync::Api}; diff --git a/src/api/tokio.rs b/src/api/tokio.rs index b3f4959..15d06f9 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -1,6 +1,8 @@ +use super::Progress as SyncProgress; use super::RepoInfo; use crate::{Cache, Repo, RepoType}; -use indicatif::{ProgressBar, ProgressStyle}; +use futures::StreamExt; +use indicatif::ProgressBar; use rand::Rng; use reqwest::{ header::{ @@ -22,6 +24,38 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); /// Current name (used in user-agent) const NAME: &str = env!("CARGO_PKG_NAME"); +/// This trait is used by users of the lib +/// to implement custom behavior during file downloads +pub trait Progress { + /// At the start of the download + /// The size is the total size in bytes of the file. + fn init(&mut self, size: usize, filename: &str) + -> impl std::future::Future + Send; + /// This function is called whenever `size` bytes have been + /// downloaded in the temporary file + fn update(&mut self, size: usize) -> impl std::future::Future + Send; + /// This is called at the end of the download + fn finish(&mut self) -> impl std::future::Future + Send; +} + +impl Progress for ProgressBar { + async fn init(&mut self, size: usize, filename: &str) { + ::init(self, size, filename); + } + async fn finish(&mut self) { + ::finish(self); + } + async fn update(&mut self, size: usize) { + ::update(self, size); + } +} + +impl Progress for () { + async fn init(&mut self, _size: usize, _filename: &str) {} + async fn finish(&mut self) {} + async fn update(&mut self, _size: usize) {} +} + #[derive(Debug, Error)] /// All errors the API can throw pub enum ApiError { @@ -74,10 +108,9 @@ pub enum ApiError { pub struct ApiBuilder { endpoint: String, cache: Cache, - url_template: String, token: Option, max_files: usize, - chunk_size: usize, + chunk_size: Option, parallel_failures: usize, max_retries: usize, progress: bool, @@ -100,6 +133,23 @@ impl ApiBuilder { Self::from_cache(cache) } + /// High CPU download + /// + /// This may cause issues on regular desktops as it will saturate + /// CPUs by multiplexing the downloads. + /// However on high CPU machines on the cloud, this may help + /// saturate the bandwidth (>500MB/s) better. + /// ``` + /// use hf_hub::api::tokio::ApiBuilder; + /// let api = ApiBuilder::high().build().unwrap(); + /// ``` + pub fn high() -> Self { + let cache = Cache::default(); + Self::from_cache(cache) + .with_max_files(num_cpus::get()) + .with_chunk_size(Some(10_000_000)) + } + /// From a given cache /// ``` /// use hf_hub::{api::tokio::ApiBuilder, Cache}; @@ -114,11 +164,11 @@ impl ApiBuilder { Self { endpoint: "https://huggingface.co".to_string(), - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), cache, token, - max_files: num_cpus::get(), - chunk_size: 10_000_000, + max_files: 1, + // chunk_size: 10_000_000, + chunk_size: None, parallel_failures: 0, max_retries: 0, progress, @@ -130,7 +180,7 @@ impl ApiBuilder { self.progress = progress; self } - + /// Changes the endpoint of the API. Default is `https://huggingface.co`. pub fn with_endpoint(mut self, endpoint: String) -> Self { self.endpoint = endpoint; @@ -149,6 +199,18 @@ impl ApiBuilder { self } + /// Sets the number of open files + pub fn with_max_files(mut self, max_files: usize) -> Self { + self.max_files = max_files; + self + } + + /// Sets the size of each chunk + pub fn with_chunk_size(mut self, chunk_size: Option) -> Self { + self.chunk_size = chunk_size; + self + } + fn build_headers(&self) -> Result { let mut headers = HeaderMap::new(); let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); @@ -192,7 +254,6 @@ impl ApiBuilder { .build()?; Ok(Api { endpoint: self.endpoint, - url_template: self.url_template, cache: self.cache, client, relative_redirect_client, @@ -213,17 +274,15 @@ struct Metadata { } /// The actual Api used to interact with the hub. -/// You can inspect repos with [`Api::info`] -/// or download files with [`Api::download`] +/// Use any repo with [`Api::repo`] #[derive(Clone, Debug)] pub struct Api { endpoint: String, - url_template: String, cache: Cache, client: Client, relative_redirect_client: Client, max_files: usize, - chunk_size: usize, + chunk_size: Option, parallel_failures: usize, max_retries: usize, progress: bool, @@ -400,6 +459,8 @@ impl Api { } /// Shorthand for accessing things within a particular repo +/// You can inspect repos with [`ApiRepo::info`] +/// or download files with [`ApiRepo::download`] #[derive(Debug)] pub struct ApiRepo { api: Api, @@ -423,19 +484,15 @@ impl ApiRepo { pub fn url(&self, filename: &str) -> String { let endpoint = &self.api.endpoint; let revision = &self.repo.url_revision(); - self.api - .url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &self.repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) + let repo_id = self.repo.url(); + format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}") } - async fn download_tempfile( + async fn download_tempfile<'a, P: Progress + Clone + Send + Sync + 'static>( &self, url: &str, length: usize, - progressbar: Option, + mut progressbar: P, ) -> Result { let mut handles = vec![]; let semaphore = Arc::new(Semaphore::new(self.api.max_files)); @@ -448,7 +505,7 @@ impl ApiRepo { .set_len(length as u64) .await?; - let chunk_size = self.api.chunk_size; + let chunk_size = self.api.chunk_size.unwrap_or(length); for start in (0..length).step_by(chunk_size) { let url = url.to_string(); let filename = filename.clone(); @@ -461,7 +518,9 @@ impl ApiRepo { let parallel_failures_semaphore = parallel_failures_semaphore.clone(); let progress = progressbar.clone(); handles.push(tokio::spawn(async move { - let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; + let mut chunk = + Self::download_chunk(&client, &url, &filename, start, stop, progress.clone()) + .await; let mut i = 0; if parallel_failures > 0 { while let Err(dlerr) = chunk { @@ -472,7 +531,15 @@ impl ApiRepo { tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64)) .await; - chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; + chunk = Self::download_chunk( + &client, + &url, + &filename, + start, + stop, + progress.clone(), + ) + .await; i += 1; if i > max_retries { return Err(ApiError::TooManyRetries(dlerr.into())); @@ -481,9 +548,9 @@ impl ApiRepo { } } drop(permit); - if let Some(p) = progress { - p.inc((stop - start) as u64); - } + // if let Some(p) = progress { + // progress.update(stop - start).await; + // } chunk })); } @@ -493,19 +560,21 @@ impl ApiRepo { futures::future::join_all(handles).await; let results: Result<(), ApiError> = results.into_iter().flatten().collect(); results?; - if let Some(p) = progressbar { - p.finish(); - } + progressbar.finish().await; Ok(filename) } - async fn download_chunk( + async fn download_chunk

( client: &reqwest::Client, url: &str, filename: &PathBuf, start: usize, stop: usize, - ) -> Result<(), ApiError> { + mut progress: P, + ) -> Result<(), ApiError> + where + P: Progress, + { // Process each socket concurrently. let range = format!("bytes={start}-{stop}"); let mut file = tokio::fs::OpenOptions::new() @@ -519,8 +588,12 @@ impl ApiRepo { .send() .await? .error_for_status()?; - let content = response.bytes().await?; - file.write_all(&content).await?; + let mut byte_stream = response.bytes_stream(); + while let Some(next) = byte_stream.next().await { + let next = next?; + file.write_all(&next).await?; + progress.update(next.len()).await; + } Ok(()) } @@ -552,6 +625,52 @@ impl ApiRepo { /// # }) /// ``` pub async fn download(&self, filename: &str) -> Result { + if self.api.progress { + self.download_with_progress(filename, ProgressBar::new(0)) + .await + } else { + self.download_with_progress(filename, ()).await + } + } + + /// This function is used to download a file with a custom progress function. + /// It uses the [`Progress`] trait and can be used in more complex use + /// cases like downloading a showing progress in a UI. + /// ```no_run + /// use hf_hub::api::tokio::{Api, Progress}; + /// + /// #[derive(Clone)] + /// struct MyProgress{ + /// current: usize, + /// total: usize + /// } + /// + /// impl Progress for MyProgress{ + /// async fn init(&mut self, size: usize, _filename: &str){ + /// self.total = size; + /// self.current = 0; + /// } + /// + /// async fn update(&mut self, size: usize){ + /// self.current += size; + /// println!("{}/{}", self.current, self.total) + /// } + /// + /// async fn finish(&mut self){ + /// println!("Done !"); + /// } + /// } + /// # tokio_test::block_on(async { + /// let api = Api::new().unwrap(); + /// let progress = MyProgress{ current: 0, total : 0}; + /// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).await.unwrap(); + /// # }) + /// ``` + pub async fn download_with_progress( + &self, + filename: &str, + mut progress: P, + ) -> Result { let url = self.url(filename); let metadata = self.api.metadata(&url).await?; let cache = self.api.cache.repo(self.repo.clone()); @@ -559,28 +678,9 @@ impl ApiRepo { let blob_path = cache.blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; - let progressbar = if self.api.progress { - let progress = ProgressBar::new(metadata.size as u64); - progress.set_style( - ProgressStyle::with_template( - "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", - ) - .unwrap(), // .progress_chars("━ "), - ); - let maxlength = 30; - let message = if filename.len() > maxlength { - format!("..{}", &filename[filename.len() - maxlength..]) - } else { - filename.to_string() - }; - progress.set_message(message); - Some(progress) - } else { - None - }; - + progress.init(metadata.size, filename).await; let tmp_filename = self - .download_tempfile(&url, metadata.size, progressbar) + .download_tempfile(&url, metadata.size, progress) .await?; tokio::fs::rename(&tmp_filename, &blob_path).await?;