Skip to content

Commit

Permalink
Hotfix the progress calls for sync.
Browse files Browse the repository at this point in the history
Co-Authored-By: neo773 <[email protected]>

Dummy co-author based on work in 437dcf4
about rustls-tls feature.
  • Loading branch information
Narsil committed Dec 27, 2024
1 parent a4e366a commit 8804781
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,26 +414,28 @@ impl Api {
})
}

fn download_tempfile<P: Progress + Sync + Send>(
fn download_tempfile<P: Progress>(
&self,
url: &str,
size: usize,
mut progress: P,
filename: &str,
) -> Result<PathBuf, ApiError> {
progress.init(size, filename);
let filepath = self.cache.temp_path();

// Create the file and set everything properly
let mut file = std::fs::File::create(&filepath)?;

let mut res = self.download_from(url, 0u64, &mut file, filename, &mut progress);
let mut res = self.download_from(url, 0u64, size, &mut file, filename, &mut progress);
if self.max_retries > 0 {
let mut i = 0;
while let Err(dlerr) = res {
let wait_time = exponential_backoff(300, i, 10_000);
std::thread::sleep(std::time::Duration::from_millis(wait_time as u64));

let current = file.stream_position()?;
res = self.download_from(url, current, &mut file, filename, &mut progress);
res = self.download_from(url, current, size, &mut file, filename, &mut progress);
i += 1;
if i > self.max_retries {
return Err(ApiError::TooManyRetries(dlerr.into()));
Expand All @@ -448,6 +450,7 @@ impl Api {
&self,
url: &str,
current: u64,
size: usize,
file: &mut std::fs::File,
filename: &str,
progress: &mut P,
Expand All @@ -463,7 +466,7 @@ impl Api {
.call()
.map_err(Box::new)?;
let reader = response.into_reader();
progress.init(0, filename);
progress.init(size, filename);
progress.update(current as usize);
let mut reader = Box::new(wrap_read(reader, progress));
std::io::copy(&mut reader, file)?;
Expand Down Expand Up @@ -596,10 +599,10 @@ impl ApiRepo {
/// let progress = MyProgress{current: 0, total: 0};
/// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).unwrap();
/// ```
pub fn download_with_progress<P: Progress + Send + Sync>(
pub fn download_with_progress<P: Progress>(
&self,
filename: &str,
mut progress: P,
progress: P,
) -> Result<PathBuf, ApiError> {
let url = self.url(filename);
let metadata = self.api.metadata(&url)?;
Expand All @@ -611,12 +614,11 @@ impl ApiRepo {
.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

progress.init(metadata.size, filename);

let tmp_filename = self.api.download_tempfile(&url, progress, filename)?;
let tmp_filename = self
.api
.download_tempfile(&url, metadata.size, progress, filename)?;

std::fs::rename(tmp_filename, &blob_path)?;

let mut pointer_path = self
.api
.cache
Expand Down

0 comments on commit 8804781

Please sign in to comment.