Skip to content

Commit

Permalink
Resumable downloads. (#84)
Browse files Browse the repository at this point in the history
* Resumable downloads.

* Clippy.

* Proof of resumability through manual corruption.

* Speeding up downloads (fewer writes) + more accurate estimates (moving
window)

* Remove unwrap.
  • Loading branch information
Narsil authored Dec 30, 2024
1 parent 57c58af commit 41d49d6
Show file tree
Hide file tree
Showing 4 changed files with 387 additions and 55 deletions.
55 changes: 51 additions & 4 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use indicatif::{ProgressBar, ProgressStyle};
use std::{collections::VecDeque, time::Duration};

use indicatif::{style::ProgressTracker, HumanBytes, ProgressBar, ProgressStyle};
use serde::Deserialize;

/// The asynchronous version of the API
Expand Down Expand Up @@ -35,9 +37,9 @@ impl Progress for ProgressBar {
self.set_length(size as u64);
self.set_style(
ProgressStyle::with_template(
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
)
.unwrap(), // .progress_chars("━ "),
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec_smoothed} ({eta})",
).unwrap().with_key("bytes_per_sec_smoothed", MovingAvgRate::default())
,
);
let maxlength = 30;
let message = if filename.len() > maxlength {
Expand Down Expand Up @@ -73,3 +75,48 @@ pub struct RepoInfo {
/// The commit sha of the repo.
pub sha: String,
}

/// Progress tracker that derives a transfer rate from a short window of
/// recent `(time, position)` samples instead of the whole download history.
///
/// The window is maintained by the `ProgressTracker::tick` implementation
/// and rendered by `ProgressTracker::write`.
#[derive(Clone, Default)]
struct MovingAvgRate {
/// Recorded `(timestamp, position)` pairs; new samples are pushed at the
/// back and expired ones are popped from the front.
samples: VecDeque<(std::time::Instant, u64)>,
}

/// Renders a smoothed "bytes per second" value computed over (at most) the
/// last second of progress samples.
impl ProgressTracker for MovingAvgRate {
    /// Returns a boxed copy of this tracker, including its current samples.
    fn clone_box(&self) -> Box<dyn ProgressTracker> {
        Box::new(self.clone())
    }

    /// Records the current position and drops samples that left the window.
    ///
    /// A new sample is stored only if more than 20ms elapsed since the
    /// previous one; samples older than one second (relative to `now`) are
    /// discarded.
    fn tick(&mut self, state: &indicatif::ProgressState, now: std::time::Instant) {
        // Sample at most every 20ms. `saturating_duration_since` yields zero
        // instead of panicking should `now` ever be earlier than a stored
        // timestamp.
        let sample_due = self.samples.back().map_or(true, |(prev, _)| {
            now.saturating_duration_since(*prev) > Duration::from_millis(20)
        });
        if sample_due {
            self.samples.push_back((now, state.pos()));
        }

        // Keep only the samples taken during the last second.
        while let Some((oldest, _)) = self.samples.front() {
            if now.saturating_duration_since(*oldest) > Duration::from_secs(1) {
                self.samples.pop_front();
            } else {
                break;
            }
        }
    }

    /// Forgets every recorded sample.
    fn reset(&mut self, _state: &indicatif::ProgressState, _now: std::time::Instant) {
        self.samples.clear();
    }

    /// Writes the rate as `"<human bytes>/s"`, or `"-"` when fewer than two
    /// samples are available or no time elapsed between them.
    ///
    /// # Panics
    ///
    /// Panics if writing to `w` fails (the trait method cannot return an error).
    fn write(&self, _state: &indicatif::ProgressState, w: &mut dyn std::fmt::Write) {
        let rate = match (self.samples.front(), self.samples.back()) {
            (Some((t0, p0)), Some((t1, p1))) if self.samples.len() > 1 => {
                let elapsed = t1.saturating_duration_since(*t0).as_secs_f64();
                // The position may move backwards (e.g. an explicit
                // `set_position`) without `reset` being called; saturate
                // rather than underflow.
                let transferred = p1.saturating_sub(*p0);
                if elapsed > 0.0 {
                    Some((transferred as f64 / elapsed) as u64)
                } else {
                    None
                }
            }
            _ => None,
        };

        match rate {
            Some(rate) => write!(w, "{}/s", HumanBytes(rate)).unwrap(),
            None => write!(w, "-").unwrap(),
        }
    }
}
112 changes: 106 additions & 6 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ const AUTHORIZATION: &str = "Authorization";
type HeaderMap = HashMap<&'static str, String>;
type HeaderName = &'static str;

/// Extension given to the partially downloaded file of the sync API, so that
/// an interrupted download can be resumed from what is already on disk.
// NOTE(review): the identifier is misspelt ("EXTENTION") but is referenced
// elsewhere in this file, so it is left untouched here.
// NOTE(review): `PathBuf::set_extension` inserts its own `.` separator, so the
// leading dot in this value presumably produces file names ending in `..part`
// — confirm this is intended before changing it (the `resume` test hard-codes
// the same literal).
const EXTENTION: &str = ".part";

struct Wrapper<'a, P: Progress, R: Read> {
progress: &'a mut P,
inner: R,
Expand Down Expand Up @@ -104,6 +107,10 @@ pub enum ApiError {
#[error("Native tls: {0}")]
#[cfg(feature = "native-tls")]
Native(#[from] native_tls::Error),

/// The part file is corrupted
#[error("Invalid part file - corrupted file")]
InvalidResume,
}

/// Helper to create [`Api`] with all the options.
Expand Down Expand Up @@ -436,15 +443,26 @@ impl Api {
url: &str,
size: usize,
mut progress: P,
tmp_path: PathBuf,
filename: &str,
) -> Result<PathBuf, ApiError> {
progress.init(size, filename);
let filepath = self.cache.temp_path();
let filepath = tmp_path;

// Create the file and set everything properly
let mut file = std::fs::File::create(&filepath)?;

let mut res = self.download_from(url, 0u64, size, &mut file, filename, &mut progress);
let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) {
Ok(f) => f,
Err(_) => std::fs::File::create(&filepath)?,
};

// In case of resume.
let start = file.metadata()?.len();
if start > size as u64 {
return Err(ApiError::InvalidResume);
}

let mut res = self.download_from(url, start, size, &mut file, filename, &mut progress);
if self.max_retries > 0 {
let mut i = 0;
while let Err(dlerr) = res {
Expand Down Expand Up @@ -631,9 +649,11 @@ impl ApiRepo {
.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

let tmp_filename = self
.api
.download_tempfile(&url, metadata.size, progress, filename)?;
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENTION);
let tmp_filename =
self.api
.download_tempfile(&url, metadata.size, progress, tmp_path, filename)?;

std::fs::rename(tmp_filename, &blob_path)?;
let mut pointer_path = self
Expand Down Expand Up @@ -704,6 +724,7 @@ mod tests {
use rand::{distributions::Alphanumeric, Rng};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::io::{Seek, SeekFrom, Write};

struct TempDir {
path: PathBuf,
Expand Down Expand Up @@ -756,6 +777,85 @@ mod tests {
assert_eq!(cache_path, downloaded_path);
}

/// Checks that a download picks up an existing `.part` file instead of
/// starting over.
///
/// Phase 1 truncates a previously downloaded blob, moves it to the part-file
/// location and verifies the next download restores the expected sha256.
/// Phase 2 additionally corrupts one byte of the truncated file and expects
/// the final sha256 to be the (known) corrupted one, which can only happen if
/// the existing bytes were kept.
///
/// Requires network access to fetch `config.json` from
/// `julien-c/dummy-unknown`.
#[test]
fn resume() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();

// Reference download: establishes the expected content hash.
let model_id = "julien-c/dummy-unknown".to_string();
let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);

// Phase 1: simulate an interrupted download.
// `canonicalize` resolves `downloaded_path` to the underlying file
// (presumably the pointer path is a symlink to the blob — verify).
let blob = std::fs::canonicalize(&downloaded_path).unwrap();
let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
let size = file.metadata().unwrap().len();
// Keep a random fraction of the file.
let truncate: f32 = rand::random();
let new_size = (size as f32 * truncate) as u64;
file.set_len(new_size).unwrap();
// Move the truncated blob to where the downloader looks for a part file.
// NOTE(review): this literal duplicates the `EXTENTION` constant; the two
// must stay in sync.
let mut blob_part = blob.clone();
blob_part.set_extension(".part");
std::fs::rename(blob, &blob_part).unwrap();
// Remove the pointer so the file has to be downloaded again.
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
assert_eq!(content.len() as u64, new_size);
let val = Sha256::digest(content);
// We modified the sha.
assert!(
val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
// Downloading again must complete the file and restore the original hash.
let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
assert_eq!(downloaded_path, new_downloaded_path);
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);

// Here we prove the previous part was correctly resuming by purposefully corrupting the
// file.
let blob = std::fs::canonicalize(&downloaded_path).unwrap();
let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
let size = file.metadata().unwrap().len();
// Not random for consistent sha corruption
let truncate: f32 = 0.5;
let new_size = (size as f32 * truncate) as u64;
// Truncating
file.set_len(new_size).unwrap();
// Corrupting by changing a single byte.
file.seek(SeekFrom::Start(new_size - 1)).unwrap();
file.write_all(&[0]).unwrap();

// Same hand-over as in phase 1: truncated blob becomes the part file and
// the pointer is removed.
let mut blob_part = blob.clone();
blob_part.set_extension(".part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
assert_eq!(content.len() as u64, new_size);
let val = Sha256::digest(content);
// We modified the sha.
assert!(
val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
// If the download resumed, the corrupted byte survives and the final hash
// is the fixed "corrupted" value below rather than the original one.
let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
println!("Sha {val:#x}");
assert_eq!(downloaded_path, new_downloaded_path);
assert_eq!(
val[..],
// Corrupted sha
hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
);
}

#[test]
fn simple_with_retries() {
let tmp = TempDir::new();
Expand Down
Loading

0 comments on commit 41d49d6

Please sign in to comment.