Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resumable downloads. #84

Merged
merged 5 commits into from
Dec 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
@@ -22,6 +22,8 @@
default = pkgs.mkShell {
buildInputs = with pkgs; [
rustup
pkg-config
openssl
];
};

55 changes: 51 additions & 4 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use indicatif::{ProgressBar, ProgressStyle};
use std::{collections::VecDeque, time::Duration};

use indicatif::{style::ProgressTracker, HumanBytes, ProgressBar, ProgressStyle};
use serde::Deserialize;

/// The asynchronous version of the API
@@ -33,9 +35,9 @@ impl Progress for ProgressBar {
self.set_length(size as u64);
self.set_style(
ProgressStyle::with_template(
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
)
.unwrap(), // .progress_chars("━ "),
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec_smoothed} ({eta})",
).unwrap().with_key("bytes_per_sec_smoothed", MovingAvgRate::default())
,
);
let maxlength = 30;
let message = if filename.len() > maxlength {
@@ -71,3 +73,48 @@ pub struct RepoInfo {
/// The commit sha of the repo.
pub sha: String,
}

/// Progress-bar rate tracker that reports throughput as a moving average
/// over (roughly) the last second of samples, instead of indicatif's
/// default whole-download average.
#[derive(Clone, Default)]
struct MovingAvgRate {
    // (sample time, bytes downloaded so far) pairs, oldest first.
    // Only samples from the trailing ~1s window are retained by `tick`.
    samples: VecDeque<(std::time::Instant, u64)>,
}

impl ProgressTracker for MovingAvgRate {
    fn clone_box(&self) -> Box<dyn ProgressTracker> {
        Box::new(self.clone())
    }

    fn tick(&mut self, state: &indicatif::ProgressState, now: std::time::Instant) {
        // Sample at most every 20ms to bound the deque size and smooth jitter.
        if self
            .samples
            .back()
            .map_or(true, |(prev, _)| (now - *prev) > Duration::from_millis(20))
        {
            self.samples.push_back((now, state.pos()));
        }

        // Drop samples older than one second so the average stays "moving".
        while let Some(first) = self.samples.front() {
            if now - first.0 > Duration::from_secs(1) {
                self.samples.pop_front();
            } else {
                break;
            }
        }
    }

    fn reset(&mut self, _state: &indicatif::ProgressState, _now: std::time::Instant) {
        // Forget all history; the next window starts from scratch.
        self.samples = Default::default();
    }

    fn write(&self, _state: &indicatif::ProgressState, w: &mut dyn std::fmt::Write) {
        match (self.samples.front(), self.samples.back()) {
            // Need at least two samples to compute a rate.
            (Some((t0, p0)), Some((t1, p1))) if self.samples.len() > 1 => {
                // `.max(1)` guards the division against a zero interval, and
                // `saturating_sub` ensures a position that moved backwards
                // (e.g. around a reset) cannot panic in debug builds.
                let elapsed_ms = (*t1 - *t0).as_millis().max(1);
                let rate = (p1.saturating_sub(*p0) as f64 * 1000f64 / elapsed_ms as f64) as u64;
                write!(w, "{}/s", HumanBytes(rate)).unwrap()
            }
            // Not enough data yet: show a placeholder.
            _ => write!(w, "-").unwrap(),
        }
    }
}
112 changes: 106 additions & 6 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
@@ -28,6 +28,9 @@ const AUTHORIZATION: &str = "Authorization";
type HeaderMap = HashMap<&'static str, String>;
type HeaderName = &'static str;

/// File extension appended to the blob path while a synchronous download is
/// in progress; a file with this extension is a resumable partial download.
/// NOTE(review): "EXTENTION" is a typo for "EXTENSION", but renaming it is a
/// crate-wide change touching every call site — flagged, not fixed here.
const EXTENTION: &str = ".part";

struct Wrapper<'a, P: Progress, R: Read> {
progress: &'a mut P,
inner: R,
@@ -104,6 +107,10 @@ pub enum ApiError {
#[error("Native tls: {0}")]
#[cfg(feature = "native-tls")]
Native(#[from] native_tls::Error),

/// The part file is corrupted
#[error("Invalid part file - corrupted file")]
InvalidResume,
}

/// Helper to create [`Api`] with all the options.
@@ -419,15 +426,26 @@ impl Api {
url: &str,
size: usize,
mut progress: P,
tmp_path: PathBuf,
filename: &str,
) -> Result<PathBuf, ApiError> {
progress.init(size, filename);
let filepath = self.cache.temp_path();
let filepath = tmp_path;

// Create the file and set everything properly
let mut file = std::fs::File::create(&filepath)?;

let mut res = self.download_from(url, 0u64, size, &mut file, filename, &mut progress);
let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) {
let mut part_file = match std::fs::OpenOptions::new().append(true).open(&filepath) {

Ok(f) => f,
Err(_) => std::fs::File::create(&filepath)?,
};

// In case of resume.
let start = file.metadata()?.len();
if start > size as u64 {
return Err(ApiError::InvalidResume);
}

let mut res = self.download_from(url, start, size, &mut file, filename, &mut progress);
if self.max_retries > 0 {
let mut i = 0;
while let Err(dlerr) = res {
@@ -614,9 +632,11 @@ impl ApiRepo {
.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

let tmp_filename = self
.api
.download_tempfile(&url, metadata.size, progress, filename)?;
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENTION);
let tmp_filename =
self.api
.download_tempfile(&url, metadata.size, progress, tmp_path, filename)?;

std::fs::rename(tmp_filename, &blob_path)?;
let mut pointer_path = self
@@ -687,6 +707,7 @@ mod tests {
use rand::{distributions::Alphanumeric, Rng};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::io::{Seek, SeekFrom, Write};

struct TempDir {
path: PathBuf,
@@ -739,6 +760,85 @@ mod tests {
assert_eq!(cache_path, downloaded_path);
}

/// End-to-end test of the sync resume logic, in two phases:
/// 1. a truncated blob is staged as a `.part` file and the re-download must
///    resume from it and reproduce the pristine content;
/// 2. a deliberately corrupted `.part` file proves phase 1 really resumed,
///    because the corruption survives into the final blob's sha.
///
/// NOTE(review): this test hits the network (downloads
/// julien-c/dummy-unknown/config.json) and is not hermetic.
#[test]
fn resume() {
    // Fresh cache directory so the first download starts from nothing.
    let tmp = TempDir::new();
    let api = ApiBuilder::new()
        .with_progress(false)
        .with_cache_dir(tmp.path.clone())
        .build()
        .unwrap();

    let model_id = "julien-c/dummy-unknown".to_string();
    let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
    assert!(downloaded_path.exists());
    // Known-good sha256 of the pristine config.json.
    let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
    assert_eq!(
        val[..],
        hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );

    // Simulate an interrupted download: truncate the blob at a random point
    // and move it where the resume logic looks for partial files.
    let blob = std::fs::canonicalize(&downloaded_path).unwrap();
    let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
    let size = file.metadata().unwrap().len();
    let truncate: f32 = rand::random();
    let new_size = (size as f32 * truncate) as u64;
    file.set_len(new_size).unwrap();
    let mut blob_part = blob.clone();
    // ".part" is hardcoded on purpose (instead of using EXTENTION) so this
    // test fails if the library-side constant ever changes.
    blob_part.set_extension(".part");
    std::fs::rename(blob, &blob_part).unwrap();
    std::fs::remove_file(&downloaded_path).unwrap();
    let content = std::fs::read(&*blob_part).unwrap();
    assert_eq!(content.len() as u64, new_size);
    let val = Sha256::digest(content);
    // We modified the sha.
    assert!(
        val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );
    // Re-download: must resume from the part file and still reproduce the
    // pristine content.
    let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
    let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
    assert_eq!(downloaded_path, new_downloaded_path);
    assert_eq!(
        val[..],
        hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );

    // Here we prove the previous part was correctly resuming by purposefully
    // corrupting the file: if the downloader resumes (rather than restarting
    // from scratch), the flipped byte below survives into the final blob.
    let blob = std::fs::canonicalize(&downloaded_path).unwrap();
    let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
    let size = file.metadata().unwrap().len();
    // Not random for consistent sha corruption
    let truncate: f32 = 0.5;
    let new_size = (size as f32 * truncate) as u64;
    // Truncating
    file.set_len(new_size).unwrap();
    // Corrupting by changing a single byte.
    file.seek(SeekFrom::Start(new_size - 1)).unwrap();
    file.write_all(&[0]).unwrap();

    let mut blob_part = blob.clone();
    blob_part.set_extension(".part");
    std::fs::rename(blob, &blob_part).unwrap();
    std::fs::remove_file(&downloaded_path).unwrap();
    let content = std::fs::read(&*blob_part).unwrap();
    assert_eq!(content.len() as u64, new_size);
    let val = Sha256::digest(content);
    // We modified the sha.
    assert!(
        val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );
    let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
    let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
    println!("Sha {val:#x}");
    assert_eq!(downloaded_path, new_downloaded_path);
    assert_eq!(
        val[..],
        // Corrupted sha
        hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
    );
}

#[test]
fn simple_with_retries() {
let tmp = TempDir::new();
258 changes: 230 additions & 28 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::Progress as SyncProgress;
use super::RepoInfo;
use crate::{Cache, Repo, RepoType};
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use indicatif::ProgressBar;
use rand::Rng;
@@ -12,18 +13,24 @@ use reqwest::{
redirect::Policy,
Client, Error as ReqwestError, RequestBuilder,
};
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::AsyncReadExt;
use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
use tokio::task::JoinError;

/// Current version (used in user-agent)
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Current name (used in user-agent)
const NAME: &str = env!("CARGO_PKG_NAME");

const EXTENTION: &str = ".sync.part";

/// This trait is used by users of the lib
/// to implement custom behavior during file downloads
pub trait Progress {
@@ -101,6 +108,9 @@ pub enum ApiError {
// /// Semaphore cannot be acquired
// #[error("Invalid Response: {0:?}")]
// InvalidResponse(Response),
/// Join failed
#[error("Join: {0}")]
Join(#[from] JoinError),
}

/// Helper to create [`Api`] with all the options.
@@ -167,8 +177,8 @@ impl ApiBuilder {
cache,
token,
max_files: 1,
// chunk_size: 10_000_000,
chunk_size: None,
// We need to have some chunk size for things to be able to resume.
chunk_size: Some(10_000_000),
parallel_failures: 0,
max_retries: 0,
progress,
@@ -492,32 +502,61 @@ impl ApiRepo {
&self,
url: &str,
length: usize,
filename: PathBuf,
mut progressbar: P,
) -> Result<PathBuf, ApiError> {
let mut handles = vec![];
let semaphore = Arc::new(Semaphore::new(self.api.max_files));
let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures));
let filename = self.api.cache.temp_path();

// Create the file and set everything properly
tokio::fs::File::create(&filename)
.await?
.set_len(length as u64)
.await?;
const N_BYTES: usize = size_of::<u64>();
let start = match tokio::fs::OpenOptions::new()
.read(true)
.open(&filename)
.await
{
Ok(mut f) => {
let len = f.metadata().await?.len();
if len == (length + N_BYTES) as u64 {
f.seek(SeekFrom::Start(length as u64)).await.unwrap();
let mut buf = [0u8; N_BYTES];
let n = f.read(buf.as_mut_slice()).await?;
if n == N_BYTES {
let committed = u64::from_le_bytes(buf);
committed as usize
} else {
0
}
} else {
0
}
}
Err(_err) => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Err(_err) => {
Err(_) => {

tokio::fs::File::create(&filename)
.await?
.set_len((length + N_BYTES) as u64)
.await?;
0
}
};
progressbar.update(start).await;

let chunk_size = self.api.chunk_size.unwrap_or(length);
for start in (0..length).step_by(chunk_size) {
let n_chunks = length / chunk_size;
let mut handles = Vec::with_capacity(n_chunks);
for start in (start..length).step_by(chunk_size) {
let url = url.to_string();
let filename = filename.clone();
let client = self.api.client.clone();

let stop = std::cmp::min(start + chunk_size - 1, length);
let permit = semaphore.clone().acquire_owned().await?;
let permit = semaphore.clone();
let parallel_failures = self.api.parallel_failures;
let max_retries = self.api.max_retries;
let parallel_failures_semaphore = parallel_failures_semaphore.clone();
let progress = progressbar.clone();
handles.push(tokio::spawn(async move {
let permit = permit.acquire_owned().await?;
let mut chunk =
Self::download_chunk(&client, &url, &filename, start, stop, progress.clone())
.await;
@@ -548,18 +587,43 @@ impl ApiRepo {
}
}
drop(permit);
// if let Some(p) = progress {
// progress.update(stop - start).await;
// }
chunk
}));
}

// Output the chained result
let results: Vec<Result<Result<(), ApiError>, tokio::task::JoinError>> =
futures::future::join_all(handles).await;
let results: Result<(), ApiError> = results.into_iter().flatten().collect();
results?;
let mut futures: FuturesUnordered<_> = handles.into_iter().collect();
let mut temporaries = BinaryHeap::new();
let mut committed: u64 = start as u64;
while let Some(chunk) = futures.next().await {
let chunk = chunk?;
let (start, stop) = chunk?;
temporaries.push(Reverse((start, stop)));

let mut modified = false;
while let Some(Reverse((min, max))) = temporaries.pop() {
if min as u64 == committed {
committed = max as u64 + 1;
modified = true;
} else {
temporaries.push(Reverse((min, max)));
break;
}
}
if modified {
let mut file = tokio::fs::OpenOptions::new()
.write(true)
.open(&filename)
.await?;
file.seek(SeekFrom::Start(length as u64)).await?;
file.write_all(&committed.to_le_bytes()).await?;
}
}
tokio::fs::OpenOptions::new()
.write(true)
.open(&filename)
.await?
.set_len(length as u64)
.await?;
progressbar.finish().await;
Ok(filename)
}
@@ -571,30 +635,32 @@ impl ApiRepo {
start: usize,
stop: usize,
mut progress: P,
) -> Result<(), ApiError>
) -> Result<(usize, usize), ApiError>
where
P: Progress,
{
// Process each socket concurrently.
let range = format!("bytes={start}-{stop}");
let mut file = tokio::fs::OpenOptions::new()
.write(true)
.open(filename)
.await?;
file.seek(SeekFrom::Start(start as u64)).await?;
let response = client
.get(url)
.header(RANGE, range)
.send()
.await?
.error_for_status()?;
let mut byte_stream = response.bytes_stream();
let mut buf: Vec<u8> = Vec::with_capacity(stop - start);
while let Some(next) = byte_stream.next().await {
let next = next?;
file.write_all(&next).await?;
buf.extend(&next);
progress.update(next.len()).await;
}
Ok(())
let mut file = tokio::fs::OpenOptions::new()
.write(true)
.open(filename)
.await?;
file.seek(SeekFrom::Start(start as u64)).await?;
file.write_all(&buf).await?;
Ok((start, stop))
}

/// This will attempt the fetch the file locally first, then [`Api.download`]
@@ -679,9 +745,12 @@ impl ApiRepo {
std::fs::create_dir_all(blob_path.parent().unwrap())?;

progress.init(metadata.size, filename).await;
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENTION);
let tmp_filename = self
.download_tempfile(&url, metadata.size, progress)
.await?;
.download_tempfile(&url, metadata.size, tmp_path, progress)
.await
.unwrap();

tokio::fs::rename(&tmp_filename, &blob_path).await?;

@@ -734,6 +803,7 @@ mod tests {
use rand::distributions::Alphanumeric;
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::io::{Seek, Write};

struct TempDir {
path: PathBuf,
@@ -782,6 +852,138 @@ mod tests {
assert_eq!(cache_path, downloaded_path);
}

/// End-to-end test of the tokio resume logic, in three phases:
/// 1. an invalid part file (no committed-offset footer) forces a full
///    redownload;
/// 2. a valid part file (padded to `size + 8` with a little-endian u64
///    footer holding the committed offset) is resumed from;
/// 3. a corrupted-but-valid part file proves phase 2 really resumed,
///    because the corruption survives into the final blob's sha.
///
/// NOTE(review): this test hits the network (downloads
/// julien-c/dummy-unknown/config.json) and is not hermetic.
#[tokio::test]
async fn resume() {
    // Fresh cache directory so the first download starts from nothing.
    let tmp = TempDir::new();
    let api = ApiBuilder::new()
        .with_progress(false)
        .with_cache_dir(tmp.path.clone())
        .build()
        .unwrap();
    let model_id = "julien-c/dummy-unknown".to_string();
    let downloaded_path = api
        .model(model_id.clone())
        .download("config.json")
        .await
        .unwrap();
    assert!(downloaded_path.exists());
    // Known-good sha256 of the pristine config.json.
    let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
    assert_eq!(
        val[..],
        hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );

    // This actually sets the file to a trashed version of the part file, full
    // redownload will ensue: the truncated file does not match the expected
    // `length + 8` footer layout, so the resume check falls back to offset 0.
    let blob = std::fs::canonicalize(&downloaded_path).unwrap();
    let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
    let size = file.metadata().unwrap().len();
    let truncate: f32 = rand::random();
    let new_size = (size as f32 * truncate) as u64;
    file.set_len(new_size).unwrap();
    let mut blob_part = blob.clone();
    // ".sync.part" is hardcoded on purpose (instead of using EXTENTION) so
    // this test fails if the library-side constant ever changes.
    blob_part.set_extension(".sync.part");
    std::fs::rename(blob, &blob_part).unwrap();
    std::fs::remove_file(&downloaded_path).unwrap();
    let content = std::fs::read(&*blob_part).unwrap();
    assert_eq!(content.len() as u64, new_size);
    let val = Sha256::digest(content);
    // We modified the sha.
    assert!(
        val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );
    let new_downloaded_path = api
        .model(model_id.clone())
        .download("config.json")
        .await
        .unwrap();
    let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
    assert_eq!(downloaded_path, new_downloaded_path);
    assert_eq!(
        val[..],
        hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );

    // Now this is a valid partial download file: truncate to `new_size`, pad
    // back out to `size + 8`, and write `new_size` into the trailing u64
    // footer — exactly the layout the downloader's resume check expects.
    let blob = std::fs::canonicalize(&downloaded_path).unwrap();
    let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
    let size = file.metadata().unwrap().len();
    let truncate: f32 = rand::random();
    let new_size = (size as f32 * truncate) as u64;
    // Truncating
    file.set_len(new_size).unwrap();
    let total_size = size + size_of::<u64>() as u64;
    file.set_len(total_size).unwrap();
    file.seek(SeekFrom::Start(size)).unwrap();
    file.write_all(&new_size.to_le_bytes()).unwrap();

    let mut blob_part = blob.clone();
    blob_part.set_extension(".sync.part");
    std::fs::rename(blob, &blob_part).unwrap();
    std::fs::remove_file(&downloaded_path).unwrap();
    let content = std::fs::read(&*blob_part).unwrap();
    assert_eq!(content.len() as u64, total_size);
    let val = Sha256::digest(content);
    // We modified the sha.
    assert!(
        val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );
    // Resumed download must still reproduce the pristine content.
    let new_downloaded_path = api
        .model(model_id.clone())
        .download("config.json")
        .await
        .unwrap();
    let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
    assert_eq!(downloaded_path, new_downloaded_path);
    assert_eq!(
        val[..],
        hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );

    // Here we prove the previous part was correctly resuming by purposefully
    // corrupting the file: if the downloader resumes (rather than restarting),
    // the flipped byte below survives into the final blob and changes its sha.
    let blob = std::fs::canonicalize(&downloaded_path).unwrap();
    let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
    let size = file.metadata().unwrap().len();
    // Not random for consistent sha corruption
    let truncate: f32 = 0.5;
    let new_size = (size as f32 * truncate) as u64;
    // Truncating
    file.set_len(new_size).unwrap();
    let total_size = size + size_of::<u64>() as u64;
    file.set_len(total_size).unwrap();
    file.seek(SeekFrom::Start(size)).unwrap();
    file.write_all(&new_size.to_le_bytes()).unwrap();

    // Corrupting by changing a single byte.
    file.seek(SeekFrom::Start(new_size - 1)).unwrap();
    file.write_all(&[0]).unwrap();

    let mut blob_part = blob.clone();
    blob_part.set_extension(".sync.part");
    std::fs::rename(blob, &blob_part).unwrap();
    std::fs::remove_file(&downloaded_path).unwrap();
    let content = std::fs::read(&*blob_part).unwrap();
    assert_eq!(content.len() as u64, total_size);
    let val = Sha256::digest(content);
    // We modified the sha.
    assert!(
        val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
    );
    let new_downloaded_path = api
        .model(model_id.clone())
        .download("config.json")
        .await
        .unwrap();
    let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
    assert_eq!(downloaded_path, new_downloaded_path);
    assert_eq!(
        val[..],
        // Corrupted sha
        hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
    );
}

#[tokio::test]
async fn revision() {
let tmp = TempDir::new();
17 changes: 0 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -4,8 +4,6 @@
not(feature = "ureq"),
doc = "Documentation is meant to be compiled with default features (at least ureq)"
)]
#[cfg(any(feature = "tokio", feature = "ureq"))]
use rand::{distributions::Alphanumeric, Rng};
use std::io::Write;
use std::path::PathBuf;

@@ -109,21 +107,6 @@ impl Cache {
pub fn space(&self, model_id: String) -> CacheRepo {
self.repo(Repo::new(model_id, RepoType::Space))
}

#[cfg(any(feature = "tokio", feature = "ureq"))]
pub(crate) fn temp_path(&self) -> PathBuf {
let mut path = self.path().clone();
path.push("tmp");
std::fs::create_dir_all(&path).ok();

let s: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(7)
.map(char::from)
.collect();
path.push(s);
path.to_path_buf()
}
}

/// Shorthand for accessing things within a particular repo