diff --git a/src/api/sync.rs b/src/api/sync.rs index 84bb58a..31d627f 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -3,7 +3,9 @@ use crate::api::sync::ApiError::InvalidHeader; use crate::{Cache, Repo, RepoType}; use http::{StatusCode, Uri}; use indicatif::{ProgressBar, ProgressStyle}; +use rand::Rng; use std::collections::HashMap; +use std::io::Seek; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; use std::str::FromStr; @@ -92,6 +94,7 @@ pub struct ApiBuilder { cache: Cache, url_template: String, token: Option, + max_retries: usize, progress: bool, } @@ -122,6 +125,7 @@ impl ApiBuilder { pub fn from_cache(cache: Cache) -> Self { let token = cache.token(); + let max_retries = 0; let progress = true; let endpoint = @@ -132,6 +136,7 @@ impl ApiBuilder { url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), cache, token, + max_retries, progress, } } @@ -160,6 +165,12 @@ impl ApiBuilder { self } + /// Sets the number of times the API will retry to download a file + pub fn with_retries(mut self, max_retries: usize) -> Self { + self.max_retries = max_retries; + self + } + fn build_headers(&self) -> HeaderMap { let mut headers = HeaderMap::new(); let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); @@ -191,6 +202,7 @@ impl ApiBuilder { client, no_redirect_client, + max_retries: self.max_retries, progress: self.progress, }) } @@ -213,6 +225,7 @@ pub struct Api { cache: Cache, client: HeaderAgent, no_redirect_client: HeaderAgent, + max_retries: usize, progress: bool, } @@ -270,6 +283,14 @@ fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> { Ok(()) } +fn jitter() -> usize { + rand::thread_rng().gen_range(0..=500) +} + +fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { + (base_wait_time + n.pow(2) + jitter()).min(max) +} + impl Api { /// Creates a default Api, for Api options See [`ApiBuilder`] pub fn new() -> Result { @@ -391,6 +412,26 @@ impl Api { // Create the file and set everything properly let mut file = std::fs::File::create(&filename)?; + if self.max_retries > 0 { + let mut i = 0; + let mut res = self.download_from(url, 0u64, &mut file); + while let Err(dlerr) = res { + let wait_time = exponential_backoff(300, i, 10_000); + std::thread::sleep(std::time::Duration::from_millis(wait_time as u64)); + + res = self.download_from(url, file.stream_position()?, &mut file); + i += 1; + if i > self.max_retries { + return Err(ApiError::TooManyRetries(dlerr.into())); + } + } + res?; + if let Some(p) = progressbar { + p.finish() + } + return Ok(filename); + } + let response = self.client.get(url).call().map_err(Box::new)?; let mut reader = response.into_reader(); @@ -406,6 +447,24 @@ impl Api { Ok(filename) } + fn download_from( + &self, + url: &str, + current: u64, + file: &mut std::fs::File, + ) -> Result<(), ApiError> { + let range = format!("bytes={current}-"); + let response = self + .client + .get(url) + .set(RANGE, &range) + .call() + .map_err(Box::new)?; + let mut reader = response.into_reader(); + std::io::copy(&mut reader, file)?; + Ok(()) + } + /// Creates a new handle [`ApiRepo`] which contains operations /// on a particular [`Repo`] pub fn repo(&self, repo: Repo) -> ApiRepo { @@ -651,6 +710,34 @@ mod tests { assert_eq!(cache_path, downloaded_path); } + #[test] + fn simple_with_retries() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp.path.clone()) + .with_retries(3) + .build() + .unwrap(); + + let model_id = "julien-c/dummy-unknown".to_string(); + let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap(); + assert!(downloaded_path.exists()); + let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); + assert_eq!( + val[..], + hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + + // Make sure the file is now seeable without connection + let cache_path = api + .cache + .repo(Repo::new(model_id, RepoType::Model)) + .get("config.json") + .unwrap(); + assert_eq!(cache_path, downloaded_path); + } + #[test] fn dataset() { let tmp = TempDir::new();