diff --git a/Cargo.toml b/Cargo.toml index c24367e..fce8734 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ num_cpus = { version = "1.15.0", optional = true } rand = { version = "0.8.5", optional = true } reqwest = { version = "0.12.2", optional = true, default-features = false, features = [ "json", + "stream", ] } rustls = { version = "0.23.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } @@ -41,7 +42,7 @@ native-tls = { version = "0.2.12", optional = true } [features] default = ["default-tls", "tokio", "ureq"] # These features are only relevant when used with the `tokio` feature, but this might change in the future. -default-tls = [] +default-tls = ["native-tls"] native-tls = ["dep:reqwest", "reqwest/default", "dep:native-tls", "ureq/native-tls"] rustls-tls = ["dep:rustls", "reqwest/rustls-tls"] tokio = [ diff --git a/README.md b/README.md index 6876c46..6d0b0d9 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ let _filename = repo.get("config.json").unwrap(); # SSL/TLS -This library uses its dependencies' default TLS implementations which are `rustls` for `ureq` (sync) and `native-tls` (openssl) for `tokio`. +This library uses tokio default TLS implementations which is `native-tls` (openssl) for `tokio`. If you want control over the TLS backend you can remove the default features and only add the backend you are intending to use. diff --git a/examples/download.rs b/examples/download.rs index 7aea0c6..c59a663 100644 --- a/examples/download.rs +++ b/examples/download.rs @@ -1,10 +1,9 @@ #[cfg(not(feature = "ureq"))] -#[cfg(not(feature="tokio"))] -fn main() { -} +#[cfg(not(feature = "tokio"))] +fn main() {} #[cfg(feature = "ureq")] -#[cfg(not(feature="tokio"))] +#[cfg(not(feature = "tokio"))] fn main() { let api = hf_hub::api::sync::Api::new().unwrap(); diff --git a/examples/iced/.gitignore b/examples/iced/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/examples/iced/.gitignore @@ -0,0 +1 @@ +/target diff --git a/examples/iced/Cargo.toml b/examples/iced/Cargo.toml new file mode 100644 index 0000000..86bf96b --- /dev/null +++ b/examples/iced/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "iced_hf_hub" +version = "0.1.0" +edition = "2021" + +[dependencies] +iced = { version = "0.13.1", features = ["tokio"] } +hf-hub = { path = "../../", default-features = false, features = ["tokio", "rustls-tls"] } diff --git a/examples/iced/src/main.rs b/examples/iced/src/main.rs new file mode 100644 index 0000000..a9e0d43 --- /dev/null +++ b/examples/iced/src/main.rs @@ -0,0 +1,244 @@ +use hf_hub::api::tokio::{Api, ApiError}; +use iced::futures::{SinkExt, Stream}; +use iced::stream::try_channel; +use iced::task; +use iced::widget::{button, center, column, progress_bar, text, Column}; + +use iced::{Center, Element, Right, Task}; + +#[derive(Debug, Clone)] +pub enum Progress { + Downloading { current: usize, total: usize }, + Finished, +} + +#[derive(Debug, Clone)] +pub enum Error { + Api(String), +} + +impl From for Error { + fn from(value: ApiError) -> Self { + Self::Api(value.to_string()) + } +} + +pub fn main() -> iced::Result { + iced::application("Download Progress - Iced", Example::update, Example::view).run() +} + +#[derive(Debug)] +struct Example { + downloads: Vec, + last_id: usize, +} + +#[derive(Clone)] +struct Prog { + output: iced::futures::channel::mpsc::Sender, + total: usize, +} + +impl hf_hub::api::tokio::Progress for Prog { + async fn update(&mut self, size: usize) { + let _ = self + .output + .send(Progress::Downloading { + current: size, + total: self.total, + }) + .await; + } + async fn finish(&mut self) { + let _ = self.output.send(Progress::Finished).await; + } + + async fn init(&mut self, size: usize, _filename: &str) { + println!("Initiating {size}"); + let _ = self + .output + .send(Progress::Downloading { + current: 0, + total: size, + }) + .await; + self.total = size; + } +} + +pub fn download( + repo: String, + filename: impl AsRef, +) -> impl Stream> { + try_channel(1, move |output| async move { + let prog = Prog { output, total: 0 }; + + let api = Api::new().unwrap().model(repo); + api.download_with_progress(filename.as_ref(), prog).await?; + + Ok(()) + }) +} + +#[derive(Debug, Clone)] +pub enum Message { + Add, + Download(usize), + DownloadProgressed(usize, Result), +} + +impl Example { + fn new() -> Self { + Self { + downloads: vec![Download::new(0)], + last_id: 0, + } + } + + fn update(&mut self, message: Message) -> Task { + match message { + Message::Add => { + self.last_id += 1; + + self.downloads.push(Download::new(self.last_id)); + + Task::none() + } + Message::Download(index) => { + let Some(download) = self.downloads.get_mut(index) else { + return Task::none(); + }; + + let task = download.start(); + + task.map(move |progress| Message::DownloadProgressed(index, progress)) + } + Message::DownloadProgressed(id, progress) => { + if let Some(download) = self.downloads.iter_mut().find(|download| download.id == id) + { + download.progress(progress); + } + + Task::none() + } + } + } + + fn view(&self) -> Element { + let downloads = Column::with_children(self.downloads.iter().map(Download::view)) + .push( + button("Add another download") + .on_press(Message::Add) + .padding(10), + ) + .spacing(20) + .align_x(Right); + + center(downloads).padding(20).into() + } +} + +impl Default for Example { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +struct Download { + id: usize, + state: State, +} + +#[derive(Debug)] +enum State { + Idle, + Downloading { progress: f32, _task: task::Handle }, + Finished, + Errored, +} + +impl Download { + pub fn new(id: usize) -> Self { + Download { + id, + state: State::Idle, + } + } + + pub fn start(&mut self) -> Task> { + match self.state { + State::Idle { .. } | State::Finished { .. } | State::Errored { .. } => { + let (task, handle) = Task::stream(download( + "mattshumer/Reflection-Llama-3.1-70B".to_string(), + "model-00001-of-00162.safetensors", + )) + .abortable(); + + self.state = State::Downloading { + progress: 0.0, + _task: handle.abort_on_drop(), + }; + + task + } + State::Downloading { .. } => Task::none(), + } + } + + pub fn progress(&mut self, new_progress: Result) { + if let State::Downloading { progress, .. } = &mut self.state { + match new_progress { + Ok(Progress::Downloading { current, total }) => { + println!("Status {progress} - {current}"); + let new_progress = current as f32 / total as f32 * 100.0; + println!("New progress {current} {new_progress}"); + *progress += new_progress; + } + Ok(Progress::Finished) => { + self.state = State::Finished; + } + Err(_error) => { + self.state = State::Errored; + } + } + } + } + + pub fn view(&self) -> Element { + let current_progress = match &self.state { + State::Idle { .. } => 0.0, + State::Downloading { progress, .. } => *progress, + State::Finished { .. } => 100.0, + State::Errored { .. } => 0.0, + }; + + let progress_bar = progress_bar(0.0..=100.0, current_progress); + + let control: Element<_> = match &self.state { + State::Idle => button("Start the download!") + .on_press(Message::Download(self.id)) + .into(), + State::Finished => column!["Download finished!", button("Start again")] + .spacing(10) + .align_x(Center) + .into(), + State::Downloading { .. } => text!("Downloading... {current_progress:.2}%").into(), + State::Errored => column![ + "Something went wrong :(", + button("Try again").on_press(Message::Download(self.id)), + ] + .spacing(10) + .align_x(Center) + .into(), + }; + + Column::new() + .spacing(10) + .padding(10) + .align_x(Center) + .push(progress_bar) + .push(control) + .into() + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs index ef738ca..a5bc6a4 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,4 @@ +use indicatif::{ProgressBar, ProgressStyle}; use serde::Deserialize; /// The asynchronous version of the API @@ -8,6 +9,52 @@ pub mod tokio; #[cfg(feature = "ureq")] pub mod sync; +/// This trait is used by users of the lib +/// to implement custom behavior during file downloads +pub trait Progress { + /// At the start of the download + /// The size is the total size in bytes of the file. + fn init(&mut self, size: usize, filename: &str); + /// This function is called whenever `size` bytes have been + /// downloaded in the temporary file + fn update(&mut self, size: usize); + /// This is called at the end of the download + fn finish(&mut self); +} + +impl Progress for () { + fn init(&mut self, _size: usize, _filename: &str) {} + fn update(&mut self, _size: usize) {} + fn finish(&mut self) {} +} + +impl Progress for ProgressBar { + fn init(&mut self, size: usize, filename: &str) { + self.set_length(size as u64); + self.set_style( + ProgressStyle::with_template( + "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", + ) + .unwrap(), // .progress_chars("━ "), + ); + let maxlength = 30; + let message = if filename.len() > maxlength { + format!("..{}", &filename[filename.len() - maxlength..]) + } else { + filename.to_string() + }; + self.set_message(message); + } + + fn update(&mut self, size: usize) { + self.inc(size as u64) + } + + fn finish(&mut self) { + ProgressBar::finish(self); + } +} + /// Siblings are simplified file descriptions of remote files on the hub #[derive(Debug, Clone, Deserialize, PartialEq)] pub struct Siblings { diff --git a/src/api/sync.rs b/src/api/sync.rs index 31d627f..a0f96b2 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -1,10 +1,12 @@ use super::RepoInfo; use crate::api::sync::ApiError::InvalidHeader; +use crate::api::Progress; use crate::{Cache, Repo, RepoType}; use http::{StatusCode, Uri}; -use indicatif::{ProgressBar, ProgressStyle}; +use indicatif::ProgressBar; use rand::Rng; use std::collections::HashMap; +use std::io::Read; use std::io::Seek; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; @@ -26,6 +28,23 @@ const AUTHORIZATION: &str = "Authorization"; type HeaderMap = HashMap<&'static str, String>; type HeaderName = &'static str; +struct Wrapper<'a, P: Progress, R: Read> { + progress: &'a mut P, + inner: R, +} + +fn wrap_read(inner: R, progress: &mut P) -> Wrapper { + Wrapper { inner, progress } +} + +impl Read for Wrapper<'_, P, R> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let read = self.inner.read(buf)?; + self.progress.update(read); + Ok(read) + } +} + /// Simple wrapper over [`ureq::Agent`] to include default headers #[derive(Clone, Debug)] pub struct HeaderAgent { @@ -92,7 +111,6 @@ pub enum ApiError { pub struct ApiBuilder { endpoint: String, cache: Cache, - url_template: String, token: Option, max_retries: usize, progress: bool, @@ -128,12 +146,10 @@ impl ApiBuilder { let max_retries = 0; let progress = true; - let endpoint = - std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_owned()); + let endpoint = "https://huggingface.co".to_string(); Self { endpoint, - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), cache, token, max_retries, @@ -197,10 +213,8 @@ impl ApiBuilder { Ok(Api { endpoint: self.endpoint, - url_template: self.url_template, cache: self.cache, client, - no_redirect_client, max_retries: self.max_retries, progress: self.progress, @@ -216,12 +230,10 @@ struct Metadata { } /// The actual Api used to interacto with the hub. -/// You can inspect repos with [`Api::info`] -/// or download files with [`Api::download`] +/// Use any repo with [`Api::repo`] #[derive(Clone, Debug)] pub struct Api { endpoint: String, - url_template: String, cache: Cache, client: HeaderAgent, no_redirect_client: HeaderAgent, @@ -402,57 +414,47 @@ impl Api { }) } - fn download_tempfile( + fn download_tempfile( &self, url: &str, - progressbar: Option, + mut progress: P, + filename: &str, ) -> Result { - let filename = self.cache.temp_path(); + let filepath = self.cache.temp_path(); // Create the file and set everything properly - let mut file = std::fs::File::create(&filename)?; + let mut file = std::fs::File::create(&filepath)?; + let mut res = self.download_from(url, 0u64, &mut file, filename, &mut progress); if self.max_retries > 0 { let mut i = 0; - let mut res = self.download_from(url, 0u64, &mut file); while let Err(dlerr) = res { let wait_time = exponential_backoff(300, i, 10_000); std::thread::sleep(std::time::Duration::from_millis(wait_time as u64)); - res = self.download_from(url, file.stream_position()?, &mut file); + let current = file.stream_position()?; + res = self.download_from(url, current, &mut file, filename, &mut progress); i += 1; if i > self.max_retries { return Err(ApiError::TooManyRetries(dlerr.into())); } } - res?; - if let Some(p) = progressbar { - p.finish() - } - return Ok(filename); } - - let response = self.client.get(url).call().map_err(Box::new)?; - - let mut reader = response.into_reader(); - if let Some(p) = &progressbar { - reader = Box::new(p.wrap_read(reader)); - } - - std::io::copy(&mut reader, &mut file)?; - - if let Some(p) = progressbar { - p.finish(); - } - Ok(filename) + res?; + Ok(filepath) } - fn download_from( + fn download_from

( &self, url: &str, current: u64, file: &mut std::fs::File, - ) -> Result<(), ApiError> { + filename: &str, + progress: &mut P, + ) -> Result<(), ApiError> + where + P: Progress, + { let range = format!("bytes={current}-"); let response = self .client @@ -460,8 +462,12 @@ impl Api { .set(RANGE, &range) .call() .map_err(Box::new)?; - let mut reader = response.into_reader(); + let reader = response.into_reader(); + progress.init(0, filename); + progress.update(current as usize); + let mut reader = Box::new(wrap_read(reader, progress)); std::io::copy(&mut reader, file)?; + progress.finish(); Ok(()) } @@ -506,6 +512,8 @@ impl Api { } /// Shorthand for accessing things within a particular repo +/// You can inspect repos with [`ApiRepo::info`] +/// or download files with [`ApiRepo::download`] #[derive(Debug)] pub struct ApiRepo { api: Api, @@ -541,12 +549,8 @@ impl ApiRepo { pub fn url(&self, filename: &str) -> String { let endpoint = &self.api.endpoint; let revision = &self.repo.url_revision(); - self.api - .url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &self.repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) + let repo_id = self.repo.url(); + format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}") } /// This will attempt the fetch the file locally first, then [`Api.download`] @@ -563,16 +567,40 @@ impl ApiRepo { } } - /// Downloads a remote file (if not already present) into the cache directory - /// to be used locally. - /// This functions require internet access to verify if new versions of the file - /// exist, even if a file is already on disk at location. + /// This function is used to download a file with a custom progress function. + /// It uses the [`Progress`] trait and can be used in more complex use + /// cases like downloading a showing progress in a UI. /// ```no_run - /// # use hf_hub::api::sync::Api; + /// # use hf_hub::api::{sync::Api, Progress}; + /// struct MyProgress{ + /// current: usize, + /// total: usize + /// } + /// + /// impl Progress for MyProgress{ + /// fn init(&mut self, size: usize, _filename: &str){ + /// self.total = size; + /// self.current = 0; + /// } + /// + /// fn update(&mut self, size: usize){ + /// self.current += size; + /// println!("{}/{}", self.current, self.total) + /// } + /// + /// fn finish(&mut self){ + /// println!("Done !"); + /// } + /// } /// let api = Api::new().unwrap(); - /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap(); + /// let progress = MyProgress{current: 0, total: 0}; + /// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).unwrap(); /// ``` - pub fn download(&self, filename: &str) -> Result { + pub fn download_with_progress( + &self, + filename: &str, + mut progress: P, + ) -> Result { let url = self.url(filename); let metadata = self.api.metadata(&url)?; @@ -583,27 +611,9 @@ impl ApiRepo { .blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; - let progressbar = if self.api.progress { - let progress = ProgressBar::new(metadata.size as u64); - progress.set_style( - ProgressStyle::with_template( - "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", - ) - .unwrap(), // .progress_chars("━ "), - ); - let maxlength = 30; - let message = if filename.len() > maxlength { - format!("..{}", &filename[filename.len() - maxlength..]) - } else { - filename.to_string() - }; - progress.set_message(message); - Some(progress) - } else { - None - }; + progress.init(metadata.size, filename); - let tmp_filename = self.api.download_tempfile(&url, progressbar)?; + let tmp_filename = self.api.download_tempfile(&url, progress, filename)?; std::fs::rename(tmp_filename, &blob_path)?; @@ -624,6 +634,23 @@ impl ApiRepo { Ok(pointer_path) } + /// Downloads a remote file (if not already present) into the cache directory + /// to be used locally. + /// This functions require internet access to verify if new versions of the file + /// exist, even if a file is already on disk at location. + /// ```no_run + /// # use hf_hub::api::sync::Api; + /// let api = Api::new().unwrap(); + /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap(); + /// ``` + pub fn download(&self, filename: &str) -> Result { + if self.api.progress { + self.download_with_progress(filename, ProgressBar::new(0)) + } else { + self.download_with_progress(filename, ()) + } + } + /// Get information about the Repo /// ``` /// use hf_hub::{api::sync::Api}; diff --git a/src/api/tokio.rs b/src/api/tokio.rs index b3f4959..15d06f9 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -1,6 +1,8 @@ +use super::Progress as SyncProgress; use super::RepoInfo; use crate::{Cache, Repo, RepoType}; -use indicatif::{ProgressBar, ProgressStyle}; +use futures::StreamExt; +use indicatif::ProgressBar; use rand::Rng; use reqwest::{ header::{ @@ -22,6 +24,38 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); /// Current name (used in user-agent) const NAME: &str = env!("CARGO_PKG_NAME"); +/// This trait is used by users of the lib +/// to implement custom behavior during file downloads +pub trait Progress { + /// At the start of the download + /// The size is the total size in bytes of the file. + fn init(&mut self, size: usize, filename: &str) + -> impl std::future::Future + Send; + /// This function is called whenever `size` bytes have been + /// downloaded in the temporary file + fn update(&mut self, size: usize) -> impl std::future::Future + Send; + /// This is called at the end of the download + fn finish(&mut self) -> impl std::future::Future + Send; +} + +impl Progress for ProgressBar { + async fn init(&mut self, size: usize, filename: &str) { + ::init(self, size, filename); + } + async fn finish(&mut self) { + ::finish(self); + } + async fn update(&mut self, size: usize) { + ::update(self, size); + } +} + +impl Progress for () { + async fn init(&mut self, _size: usize, _filename: &str) {} + async fn finish(&mut self) {} + async fn update(&mut self, _size: usize) {} +} + #[derive(Debug, Error)] /// All errors the API can throw pub enum ApiError { @@ -74,10 +108,9 @@ pub enum ApiError { pub struct ApiBuilder { endpoint: String, cache: Cache, - url_template: String, token: Option, max_files: usize, - chunk_size: usize, + chunk_size: Option, parallel_failures: usize, max_retries: usize, progress: bool, @@ -100,6 +133,23 @@ impl ApiBuilder { Self::from_cache(cache) } + /// High CPU download + /// + /// This may cause issues on regular desktops as it will saturate + /// CPUs by multiplexing the downloads. + /// However on high CPU machines on the cloud, this may help + /// saturate the bandwidth (>500MB/s) better. + /// ``` + /// use hf_hub::api::tokio::ApiBuilder; + /// let api = ApiBuilder::high().build().unwrap(); + /// ``` + pub fn high() -> Self { + let cache = Cache::default(); + Self::from_cache(cache) + .with_max_files(num_cpus::get()) + .with_chunk_size(Some(10_000_000)) + } + /// From a given cache /// ``` /// use hf_hub::{api::tokio::ApiBuilder, Cache}; @@ -114,11 +164,11 @@ impl ApiBuilder { Self { endpoint: "https://huggingface.co".to_string(), - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), cache, token, - max_files: num_cpus::get(), - chunk_size: 10_000_000, + max_files: 1, + // chunk_size: 10_000_000, + chunk_size: None, parallel_failures: 0, max_retries: 0, progress, @@ -130,7 +180,7 @@ impl ApiBuilder { self.progress = progress; self } - + /// Changes the endpoint of the API. Default is `https://huggingface.co`. pub fn with_endpoint(mut self, endpoint: String) -> Self { self.endpoint = endpoint; @@ -149,6 +199,18 @@ impl ApiBuilder { self } + /// Sets the number of open files + pub fn with_max_files(mut self, max_files: usize) -> Self { + self.max_files = max_files; + self + } + + /// Sets the size of each chunk + pub fn with_chunk_size(mut self, chunk_size: Option) -> Self { + self.chunk_size = chunk_size; + self + } + fn build_headers(&self) -> Result { let mut headers = HeaderMap::new(); let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); @@ -192,7 +254,6 @@ impl ApiBuilder { .build()?; Ok(Api { endpoint: self.endpoint, - url_template: self.url_template, cache: self.cache, client, relative_redirect_client, @@ -213,17 +274,15 @@ struct Metadata { } /// The actual Api used to interact with the hub. -/// You can inspect repos with [`Api::info`] -/// or download files with [`Api::download`] +/// Use any repo with [`Api::repo`] #[derive(Clone, Debug)] pub struct Api { endpoint: String, - url_template: String, cache: Cache, client: Client, relative_redirect_client: Client, max_files: usize, - chunk_size: usize, + chunk_size: Option, parallel_failures: usize, max_retries: usize, progress: bool, @@ -400,6 +459,8 @@ impl Api { } /// Shorthand for accessing things within a particular repo +/// You can inspect repos with [`ApiRepo::info`] +/// or download files with [`ApiRepo::download`] #[derive(Debug)] pub struct ApiRepo { api: Api, @@ -423,19 +484,15 @@ impl ApiRepo { pub fn url(&self, filename: &str) -> String { let endpoint = &self.api.endpoint; let revision = &self.repo.url_revision(); - self.api - .url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &self.repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) + let repo_id = self.repo.url(); + format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}") } - async fn download_tempfile( + async fn download_tempfile<'a, P: Progress + Clone + Send + Sync + 'static>( &self, url: &str, length: usize, - progressbar: Option, + mut progressbar: P, ) -> Result { let mut handles = vec![]; let semaphore = Arc::new(Semaphore::new(self.api.max_files)); @@ -448,7 +505,7 @@ impl ApiRepo { .set_len(length as u64) .await?; - let chunk_size = self.api.chunk_size; + let chunk_size = self.api.chunk_size.unwrap_or(length); for start in (0..length).step_by(chunk_size) { let url = url.to_string(); let filename = filename.clone(); @@ -461,7 +518,9 @@ impl ApiRepo { let parallel_failures_semaphore = parallel_failures_semaphore.clone(); let progress = progressbar.clone(); handles.push(tokio::spawn(async move { - let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; + let mut chunk = + Self::download_chunk(&client, &url, &filename, start, stop, progress.clone()) + .await; let mut i = 0; if parallel_failures > 0 { while let Err(dlerr) = chunk { @@ -472,7 +531,15 @@ impl ApiRepo { tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64)) .await; - chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; + chunk = Self::download_chunk( + &client, + &url, + &filename, + start, + stop, + progress.clone(), + ) + .await; i += 1; if i > max_retries { return Err(ApiError::TooManyRetries(dlerr.into())); @@ -481,9 +548,9 @@ impl ApiRepo { } } drop(permit); - if let Some(p) = progress { - p.inc((stop - start) as u64); - } + // if let Some(p) = progress { + // progress.update(stop - start).await; + // } chunk })); } @@ -493,19 +560,21 @@ impl ApiRepo { futures::future::join_all(handles).await; let results: Result<(), ApiError> = results.into_iter().flatten().collect(); results?; - if let Some(p) = progressbar { - p.finish(); - } + progressbar.finish().await; Ok(filename) } - async fn download_chunk( + async fn download_chunk

( client: &reqwest::Client, url: &str, filename: &PathBuf, start: usize, stop: usize, - ) -> Result<(), ApiError> { + mut progress: P, + ) -> Result<(), ApiError> + where + P: Progress, + { // Process each socket concurrently. let range = format!("bytes={start}-{stop}"); let mut file = tokio::fs::OpenOptions::new() @@ -519,8 +588,12 @@ impl ApiRepo { .send() .await? .error_for_status()?; - let content = response.bytes().await?; - file.write_all(&content).await?; + let mut byte_stream = response.bytes_stream(); + while let Some(next) = byte_stream.next().await { + let next = next?; + file.write_all(&next).await?; + progress.update(next.len()).await; + } Ok(()) } @@ -552,6 +625,52 @@ impl ApiRepo { /// # }) /// ``` pub async fn download(&self, filename: &str) -> Result { + if self.api.progress { + self.download_with_progress(filename, ProgressBar::new(0)) + .await + } else { + self.download_with_progress(filename, ()).await + } + } + + /// This function is used to download a file with a custom progress function. + /// It uses the [`Progress`] trait and can be used in more complex use + /// cases like downloading a showing progress in a UI. + /// ```no_run + /// use hf_hub::api::tokio::{Api, Progress}; + /// + /// #[derive(Clone)] + /// struct MyProgress{ + /// current: usize, + /// total: usize + /// } + /// + /// impl Progress for MyProgress{ + /// async fn init(&mut self, size: usize, _filename: &str){ + /// self.total = size; + /// self.current = 0; + /// } + /// + /// async fn update(&mut self, size: usize){ + /// self.current += size; + /// println!("{}/{}", self.current, self.total) + /// } + /// + /// async fn finish(&mut self){ + /// println!("Done !"); + /// } + /// } + /// # tokio_test::block_on(async { + /// let api = Api::new().unwrap(); + /// let progress = MyProgress{ current: 0, total : 0}; + /// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).await.unwrap(); + /// # }) + /// ``` + pub async fn download_with_progress( + &self, + filename: &str, + mut progress: P, + ) -> Result { let url = self.url(filename); let metadata = self.api.metadata(&url).await?; let cache = self.api.cache.repo(self.repo.clone()); @@ -559,28 +678,9 @@ impl ApiRepo { let blob_path = cache.blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; - let progressbar = if self.api.progress { - let progress = ProgressBar::new(metadata.size as u64); - progress.set_style( - ProgressStyle::with_template( - "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", - ) - .unwrap(), // .progress_chars("━ "), - ); - let maxlength = 30; - let message = if filename.len() > maxlength { - format!("..{}", &filename[filename.len() - maxlength..]) - } else { - filename.to_string() - }; - progress.set_message(message); - Some(progress) - } else { - None - }; - + progress.init(metadata.size, filename).await; let tmp_filename = self - .download_tempfile(&url, metadata.size, progressbar) + .download_tempfile(&url, metadata.size, progress) .await?; tokio::fs::rename(&tmp_filename, &blob_path).await?;