From 8804781b911d0abe3dc256161ee5f75b6e4fed81 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 27 Dec 2024 14:46:22 +0100 Subject: [PATCH] Hotfix the progress calls for `sync`. Co-Authored-By: neo773 Dummy co-author based on work in https://github.com/huggingface/hf-hub/pull/77/commits/437dcf4049233f4a0e0cfc4e1c1dbf5ed2c57760 about rustls-tls feature. --- src/api/sync.rs | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/api/sync.rs b/src/api/sync.rs index a0f96b2..ff191d3 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -414,18 +414,20 @@ impl Api { }) } - fn download_tempfile( + fn download_tempfile( &self, url: &str, + size: usize, mut progress: P, filename: &str, ) -> Result { + progress.init(size, filename); let filepath = self.cache.temp_path(); // Create the file and set everything properly let mut file = std::fs::File::create(&filepath)?; - let mut res = self.download_from(url, 0u64, &mut file, filename, &mut progress); + let mut res = self.download_from(url, 0u64, size, &mut file, filename, &mut progress); if self.max_retries > 0 { let mut i = 0; while let Err(dlerr) = res { @@ -433,7 +435,7 @@ impl Api { std::thread::sleep(std::time::Duration::from_millis(wait_time as u64)); let current = file.stream_position()?; - res = self.download_from(url, current, &mut file, filename, &mut progress); + res = self.download_from(url, current, size, &mut file, filename, &mut progress); i += 1; if i > self.max_retries { return Err(ApiError::TooManyRetries(dlerr.into())); @@ -448,6 +450,7 @@ impl Api { &self, url: &str, current: u64, + size: usize, file: &mut std::fs::File, filename: &str, progress: &mut P, @@ -463,7 +466,7 @@ impl Api { .call() .map_err(Box::new)?; let reader = response.into_reader(); - progress.init(0, filename); + progress.init(size, filename); progress.update(current as usize); let mut reader = Box::new(wrap_read(reader, progress)); std::io::copy(&mut reader, file)?; @@ -596,10 +599,10 @@ impl ApiRepo { /// let progress = MyProgress{current: 0, total: 0}; /// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).unwrap(); /// ``` - pub fn download_with_progress( + pub fn download_with_progress( &self, filename: &str, - mut progress: P, + progress: P, ) -> Result { let url = self.url(filename); let metadata = self.api.metadata(&url)?; @@ -611,12 +614,11 @@ impl ApiRepo { .blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; - progress.init(metadata.size, filename); - - let tmp_filename = self.api.download_tempfile(&url, progress, filename)?; + let tmp_filename = self + .api + .download_tempfile(&url, metadata.size, progress, filename)?; std::fs::rename(tmp_filename, &blob_path)?; - let mut pointer_path = self .api .cache