Resumable downloads. #84
Changes from all commits: 13d5f0f, 7692515, d5eb805, 6e7668a, d0895cc
Nix dev shell:

```diff
@@ -22,6 +22,8 @@
       default = pkgs.mkShell {
         buildInputs = with pkgs; [
           rustup
+          pkg-config
+          openssl
         ];
       };
```
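The two new dev-shell inputs are presumably there for the TLS stack: with the `native-tls` feature (visible in the hunks below), `openssl-sys` links against OpenSSL and typically locates it through `pkg-config` at build time.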
Sync API:

```diff
@@ -28,6 +28,9 @@ const AUTHORIZATION: &str = "Authorization";
 type HeaderMap = HashMap<&'static str, String>;
 type HeaderName = &'static str;
 
+/// Specific name for the sync part of the resumable file
+const EXTENTION: &str = ".part";
+
 struct Wrapper<'a, P: Progress, R: Read> {
     progress: &'a mut P,
     inner: R,
```
```diff
@@ -104,6 +107,10 @@ pub enum ApiError {
     #[error("Native tls: {0}")]
     #[cfg(feature = "native-tls")]
     Native(#[from] native_tls::Error),
+
+    /// The part file is corrupted
+    #[error("Invalid part file - corrupted file")]
+    InvalidResume,
 }
 
 /// Helper to create [`Api`] with all the options.
```
```diff
@@ -419,15 +426,26 @@ impl Api {
         url: &str,
         size: usize,
         mut progress: P,
+        tmp_path: PathBuf,
         filename: &str,
     ) -> Result<PathBuf, ApiError> {
         progress.init(size, filename);
-        let filepath = self.cache.temp_path();
+        let filepath = tmp_path;
 
         // Create the file and set everything properly
-        let mut file = std::fs::File::create(&filepath)?;
-
-        let mut res = self.download_from(url, 0u64, size, &mut file, filename, &mut progress);
+        let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) {
+            Ok(f) => f,
+            Err(_) => std::fs::File::create(&filepath)?,
+        };
+
+        // In case of resume.
+        let start = file.metadata()?.len();
+        if start > size as u64 {
+            return Err(ApiError::InvalidResume);
+        }
+
+        let mut res = self.download_from(url, start, size, &mut file, filename, &mut progress);
         if self.max_retries > 0 {
             let mut i = 0;
             while let Err(dlerr) = res {
```
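The sync resume strategy is append-only: the length of the existing part file says how many bytes were already fetched, and `download_from` restarts the ranged request from that offset. A minimal standalone sketch of the same idea, assuming a hypothetical `fetch_range` helper in place of the crate's HTTP layer:

```rust
use std::io::Read;

/// Hypothetical stand-in for the crate's HTTP layer: issues a GET with a
/// `Range: bytes={start}-` header and returns the body as a reader.
fn fetch_range(_url: &str, _start: u64) -> std::io::Result<Box<dyn Read>> {
    unimplemented!("stand-in for download_from")
}

fn resume_download(url: &str, part_path: &std::path::Path, size: u64) -> std::io::Result<()> {
    // Append mode guarantees new bytes always land at the end of the part
    // file, so an interrupted run can never clobber already-fetched data.
    let mut file = std::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(part_path)?;
    let start = file.metadata()?.len();
    if start > size {
        // Part file is larger than the remote file: corrupted,
        // same situation as ApiError::InvalidResume above.
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "invalid part file",
        ));
    }
    let mut body = fetch_range(url, start)?;
    std::io::copy(&mut body, &mut file)?;
    Ok(())
}
```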
```diff
@@ -614,9 +632,11 @@ impl ApiRepo {
             .blob_path(&metadata.etag);
         std::fs::create_dir_all(blob_path.parent().unwrap())?;
 
-        let tmp_filename = self
-            .api
-            .download_tempfile(&url, metadata.size, progress, filename)?;
+        let mut tmp_path = blob_path.clone();
+        tmp_path.set_extension(EXTENTION);
+        let tmp_filename =
+            self.api
+                .download_tempfile(&url, metadata.size, progress, tmp_path, filename)?;
 
         std::fs::rename(tmp_filename, &blob_path)?;
         let mut pointer_path = self
```
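Deriving `tmp_path` from the blob path (the etag) is what makes resuming possible at all: the old `self.cache.temp_path()` presumably produced a fresh random path on every call, so a second invocation could never find the bytes left behind by an interrupted one. The `.part` file now lives at a stable location next to its final blob.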
```diff
@@ -687,6 +707,7 @@ mod tests {
     use rand::{distributions::Alphanumeric, Rng};
    use serde_json::{json, Value};
     use sha2::{Digest, Sha256};
+    use std::io::{Seek, SeekFrom, Write};
 
     struct TempDir {
         path: PathBuf,
```
```diff
@@ -739,6 +760,85 @@ mod tests {
         assert_eq!(cache_path, downloaded_path);
     }
 
+    #[test]
+    fn resume() {
+        let tmp = TempDir::new();
+        let api = ApiBuilder::new()
+            .with_progress(false)
+            .with_cache_dir(tmp.path.clone())
+            .build()
+            .unwrap();
+
+        let model_id = "julien-c/dummy-unknown".to_string();
+        let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
+        assert!(downloaded_path.exists());
+        let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
+        assert_eq!(
+            val[..],
+            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+
+        let blob = std::fs::canonicalize(&downloaded_path).unwrap();
+        let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
+        let size = file.metadata().unwrap().len();
+        let truncate: f32 = rand::random();
+        let new_size = (size as f32 * truncate) as u64;
+        file.set_len(new_size).unwrap();
+        let mut blob_part = blob.clone();
+        blob_part.set_extension(".part");
+        std::fs::rename(blob, &blob_part).unwrap();
+        std::fs::remove_file(&downloaded_path).unwrap();
+        let content = std::fs::read(&*blob_part).unwrap();
+        assert_eq!(content.len() as u64, new_size);
+        let val = Sha256::digest(content);
+        // We modified the sha.
+        assert!(
+            val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+        let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
+        let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
+        assert_eq!(downloaded_path, new_downloaded_path);
+        assert_eq!(
+            val[..],
+            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+
+        // Here we prove the previous part was correctly resuming by purposefully corrupting the
+        // file.
+        let blob = std::fs::canonicalize(&downloaded_path).unwrap();
+        let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
+        let size = file.metadata().unwrap().len();
+        // Not random for consistent sha corruption
+        let truncate: f32 = 0.5;
+        let new_size = (size as f32 * truncate) as u64;
+        // Truncating
+        file.set_len(new_size).unwrap();
+        // Corrupting by changing a single byte.
+        file.seek(SeekFrom::Start(new_size - 1)).unwrap();
+        file.write_all(&[0]).unwrap();
+
+        let mut blob_part = blob.clone();
+        blob_part.set_extension(".part");
+        std::fs::rename(blob, &blob_part).unwrap();
+        std::fs::remove_file(&downloaded_path).unwrap();
+        let content = std::fs::read(&*blob_part).unwrap();
+        assert_eq!(content.len() as u64, new_size);
+        let val = Sha256::digest(content);
+        // We modified the sha.
+        assert!(
+            val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+        let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
+        let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
+        println!("Sha {val:#x}");
+        assert_eq!(downloaded_path, new_downloaded_path);
+        assert_eq!(
+            val[..],
+            // Corrupted sha
+            hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
+        );
+    }
+
     #[test]
     fn simple_with_retries() {
         let tmp = TempDir::new();
```

A review thread was attached to the hardcoded `blob_part.set_extension(".part")` above:

Reviewer: Use `EXTENTION`?

Author: It's done on purpose to have it hardcoded in the test. I don't want my test to change when library code changes.
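Note what the two halves of this test establish. The first truncates the blob at a random point, renames it to `.part`, and checks that the re-download still yields the original sha, i.e. the appended tail lines up exactly with the truncated prefix. The second half flips a byte inside the part file and then asserts the *corrupted* sha: the corruption surviving into the final blob proves the second download really resumed from the part file rather than starting over.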
Async (tokio) API:

```diff
@@ -1,6 +1,7 @@
 use super::Progress as SyncProgress;
 use super::RepoInfo;
 use crate::{Cache, Repo, RepoType};
+use futures::stream::FuturesUnordered;
 use futures::StreamExt;
 use indicatif::ProgressBar;
 use rand::Rng;
```
```diff
@@ -12,18 +13,24 @@ use reqwest::{
     redirect::Policy,
     Client, Error as ReqwestError, RequestBuilder,
 };
+use std::cmp::Reverse;
+use std::collections::BinaryHeap;
 use std::num::ParseIntError;
 use std::path::{Component, Path, PathBuf};
 use std::sync::Arc;
 use thiserror::Error;
+use tokio::io::AsyncReadExt;
 use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
 use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
+use tokio::task::JoinError;
 
 /// Current version (used in user-agent)
 const VERSION: &str = env!("CARGO_PKG_VERSION");
 /// Current name (used in user-agent)
 const NAME: &str = env!("CARGO_PKG_NAME");
 
+const EXTENTION: &str = ".sync.part";
+
 /// This trait is used by users of the lib
 /// to implement custom behavior during file downloads
 pub trait Progress {
```
```diff
@@ -101,6 +108,9 @@ pub enum ApiError {
     // /// Semaphore cannot be acquired
     // #[error("Invalid Response: {0:?}")]
     // InvalidResponse(Response),
+    /// Join failed
+    #[error("Join: {0}")]
+    Join(#[from] JoinError),
 }
 
 /// Helper to create [`Api`] with all the options.
```
```diff
@@ -167,8 +177,8 @@ impl ApiBuilder {
             cache,
             token,
             max_files: 1,
-            // chunk_size: 10_000_000,
-            chunk_size: None,
+            // We need to have some chunk size for things to be able to resume.
+            chunk_size: Some(10_000_000),
             parallel_failures: 0,
             max_retries: 0,
             progress,
```
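Resume granularity is the chunk: a chunk either completes and gets committed, or is re-fetched in full, so leaving `chunk_size: None` (one request for the whole file) would leave nothing to resume from. As a worked example of the ranges the loop in the next hunk produces (mirroring its `min(start + chunk_size - 1, length)` computation), for a hypothetical 25 MB file with the new 10 MB default:

```rust
fn main() {
    // Mirrors the range computation in download_tempfile below.
    let length: usize = 25_000_000; // hypothetical file size
    let chunk_size: usize = 10_000_000; // the new default
    let ranges: Vec<(usize, usize)> = (0..length)
        .step_by(chunk_size)
        .map(|start| (start, std::cmp::min(start + chunk_size - 1, length)))
        .collect();
    // Each pair becomes an inclusive `Range: bytes={start}-{stop}` header:
    // [(0, 9999999), (10000000, 19999999), (20000000, 25000000)]
    println!("{ranges:?}");
}
```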
```diff
@@ -492,32 +502,61 @@ impl ApiRepo {
         &self,
         url: &str,
         length: usize,
+        filename: PathBuf,
         mut progressbar: P,
     ) -> Result<PathBuf, ApiError> {
-        let mut handles = vec![];
         let semaphore = Arc::new(Semaphore::new(self.api.max_files));
         let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures));
-        let filename = self.api.cache.temp_path();
 
         // Create the file and set everything properly
-        tokio::fs::File::create(&filename)
-            .await?
-            .set_len(length as u64)
-            .await?;
+        const N_BYTES: usize = size_of::<u64>();
+        let start = match tokio::fs::OpenOptions::new()
+            .read(true)
+            .open(&filename)
+            .await
+        {
+            Ok(mut f) => {
+                let len = f.metadata().await?.len();
+                if len == (length + N_BYTES) as u64 {
+                    f.seek(SeekFrom::Start(length as u64)).await.unwrap();
+                    let mut buf = [0u8; N_BYTES];
+                    let n = f.read(buf.as_mut_slice()).await?;
+                    if n == N_BYTES {
+                        let committed = u64::from_le_bytes(buf);
+                        committed as usize
+                    } else {
+                        0
+                    }
+                } else {
+                    0
+                }
+            }
+            Err(_err) => {
+                tokio::fs::File::create(&filename)
+                    .await?
+                    .set_len((length + N_BYTES) as u64)
+                    .await?;
+                0
+            }
+        };
+        progressbar.update(start).await;
 
         let chunk_size = self.api.chunk_size.unwrap_or(length);
-        for start in (0..length).step_by(chunk_size) {
+        let n_chunks = length / chunk_size;
+        let mut handles = Vec::with_capacity(n_chunks);
+        for start in (start..length).step_by(chunk_size) {
             let url = url.to_string();
             let filename = filename.clone();
             let client = self.api.client.clone();
 
             let stop = std::cmp::min(start + chunk_size - 1, length);
-            let permit = semaphore.clone().acquire_owned().await?;
+            let permit = semaphore.clone();
             let parallel_failures = self.api.parallel_failures;
             let max_retries = self.api.max_retries;
             let parallel_failures_semaphore = parallel_failures_semaphore.clone();
             let progress = progressbar.clone();
             handles.push(tokio::spawn(async move {
+                let permit = permit.acquire_owned().await?;
                 let mut chunk =
                     Self::download_chunk(&client, &url, &filename, start, stop, progress.clone())
                         .await;
```
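All of the async resume bookkeeping lives in the part file itself: it is preallocated at `length + 8` bytes, and the trailing 8 bytes hold a little-endian `u64` recording how many contiguous bytes from offset 0 are fully committed. On open, the marker is trusted only when the file has exactly that layout; anything else falls back to `start = 0` or a freshly created file. A sketch of the marker read using blocking `std::fs` (the hunk above uses `tokio::fs`; same logic):

```rust
use std::io::{Read, Seek, SeekFrom};

const N_BYTES: usize = std::mem::size_of::<u64>(); // 8-byte trailer

/// Read the committed-bytes marker from a part file.
/// Returns 0 whenever the marker cannot be trusted.
fn committed_offset(path: &std::path::Path, length: u64) -> std::io::Result<u64> {
    let mut f = std::fs::File::open(path)?;
    // Only trust the marker when the file has the exact expected layout:
    // `length` payload bytes followed by the u64 trailer.
    if f.metadata()?.len() != length + N_BYTES as u64 {
        return Ok(0);
    }
    f.seek(SeekFrom::Start(length))?;
    let mut buf = [0u8; N_BYTES];
    // Like the code above, tolerate a short read by restarting from 0.
    if f.read(&mut buf)? != N_BYTES {
        return Ok(0);
    }
    Ok(u64::from_le_bytes(buf))
}
```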
```diff
@@ -548,18 +587,43 @@ impl ApiRepo {
                     }
                 }
                 drop(permit);
-                // if let Some(p) = progress {
-                //     progress.update(stop - start).await;
-                // }
                 chunk
             }));
         }
 
-        // Output the chained result
-        let results: Vec<Result<Result<(), ApiError>, tokio::task::JoinError>> =
-            futures::future::join_all(handles).await;
-        let results: Result<(), ApiError> = results.into_iter().flatten().collect();
-        results?;
+        let mut futures: FuturesUnordered<_> = handles.into_iter().collect();
+        let mut temporaries = BinaryHeap::new();
+        let mut committed: u64 = start as u64;
+        while let Some(chunk) = futures.next().await {
+            let chunk = chunk?;
+            let (start, stop) = chunk?;
+            temporaries.push(Reverse((start, stop)));
+
+            let mut modified = false;
+            while let Some(Reverse((min, max))) = temporaries.pop() {
+                if min as u64 == committed {
+                    committed = max as u64 + 1;
+                    modified = true;
+                } else {
+                    temporaries.push(Reverse((min, max)));
+                    break;
+                }
+            }
+            if modified {
+                let mut file = tokio::fs::OpenOptions::new()
+                    .write(true)
+                    .open(&filename)
+                    .await?;
+                file.seek(SeekFrom::Start(length as u64)).await?;
+                file.write_all(&committed.to_le_bytes()).await?;
+            }
+        }
+        tokio::fs::OpenOptions::new()
+            .write(true)
+            .open(&filename)
+            .await?
+            .set_len(length as u64)
+            .await?;
         progressbar.finish().await;
         Ok(filename)
     }
```
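Chunks run concurrently and finish out of order, but the marker must never claim bytes past the lowest unfinished offset. The min-heap (`BinaryHeap` of `Reverse` tuples) parks finished ranges until they touch the committed frontier, then folds all contiguous ones in at once. A standalone trace of that coalescing, with hypothetical chunk ranges:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn main() {
    // Chunks complete out of order; the frontier only advances over a
    // contiguous prefix. Reverse turns the max-heap into a min-heap by start.
    let finished = [(20u64, 29u64), (0, 9), (30, 39), (10, 19)];
    let mut temporaries = BinaryHeap::new();
    let mut committed: u64 = 0;
    for (start, stop) in finished {
        temporaries.push(Reverse((start, stop)));
        while let Some(Reverse((min, max))) = temporaries.pop() {
            if min == committed {
                committed = max + 1; // range is adjacent: fold it in
            } else {
                temporaries.push(Reverse((min, max))); // park it for later
                break;
            }
        }
        println!("committed = {committed}");
        // Prints 0 (20..29 parked), 10 (0..9 folded), 10 (30..39 parked), 40.
    }
}
```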
```diff
@@ -571,30 +635,32 @@ impl ApiRepo {
         start: usize,
         stop: usize,
         mut progress: P,
-    ) -> Result<(), ApiError>
+    ) -> Result<(usize, usize), ApiError>
     where
         P: Progress,
     {
         // Process each socket concurrently.
         let range = format!("bytes={start}-{stop}");
-        let mut file = tokio::fs::OpenOptions::new()
-            .write(true)
-            .open(filename)
-            .await?;
-        file.seek(SeekFrom::Start(start as u64)).await?;
         let response = client
             .get(url)
             .header(RANGE, range)
             .send()
             .await?
             .error_for_status()?;
         let mut byte_stream = response.bytes_stream();
+        let mut buf: Vec<u8> = Vec::with_capacity(stop - start);
         while let Some(next) = byte_stream.next().await {
             let next = next?;
-            file.write_all(&next).await?;
+            buf.extend(&next);
             progress.update(next.len()).await;
         }
-        Ok(())
+        let mut file = tokio::fs::OpenOptions::new()
+            .write(true)
+            .open(filename)
+            .await?;
+        file.seek(SeekFrom::Start(start as u64)).await?;
+        file.write_all(&buf).await?;
+        Ok((start, stop))
     }
 
     /// This will attempt the fetch the file locally first, then [`Api.download`]
```
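Buffering the whole chunk in `buf` and writing it to disk only after the byte stream has completed is what keeps the commit marker honest: a chunk that fails mid-stream (and is retried by the caller) never leaves partial bytes in the file. The tradeoff is up to `chunk_size` (10 MB by default) of extra memory per concurrently running chunk.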
```diff
@@ -679,9 +745,12 @@ impl ApiRepo {
         std::fs::create_dir_all(blob_path.parent().unwrap())?;
 
         progress.init(metadata.size, filename).await;
+        let mut tmp_path = blob_path.clone();
+        tmp_path.set_extension(EXTENTION);
         let tmp_filename = self
-            .download_tempfile(&url, metadata.size, progress)
-            .await?;
+            .download_tempfile(&url, metadata.size, tmp_path, progress)
+            .await
+            .unwrap();
 
         tokio::fs::rename(&tmp_filename, &blob_path).await?;
```
```diff
@@ -734,6 +803,7 @@ mod tests {
     use rand::distributions::Alphanumeric;
     use serde_json::{json, Value};
     use sha2::{Digest, Sha256};
+    use std::io::{Seek, Write};
 
     struct TempDir {
         path: PathBuf,
```
```diff
@@ -782,6 +852,138 @@ mod tests {
         assert_eq!(cache_path, downloaded_path);
     }
 
+    #[tokio::test]
+    async fn resume() {
+        let tmp = TempDir::new();
+        let api = ApiBuilder::new()
+            .with_progress(false)
+            .with_cache_dir(tmp.path.clone())
+            .build()
+            .unwrap();
+        let model_id = "julien-c/dummy-unknown".to_string();
+        let downloaded_path = api
+            .model(model_id.clone())
+            .download("config.json")
+            .await
+            .unwrap();
+        assert!(downloaded_path.exists());
+        let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
+        assert_eq!(
+            val[..],
+            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+
+        // This actually sets the file to a trashed version of the part file, full redownload will
+        // ensue
+        let blob = std::fs::canonicalize(&downloaded_path).unwrap();
+        let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
+        let size = file.metadata().unwrap().len();
+        let truncate: f32 = rand::random();
+        let new_size = (size as f32 * truncate) as u64;
+        file.set_len(new_size).unwrap();
+        let mut blob_part = blob.clone();
+        blob_part.set_extension(".sync.part");
+        std::fs::rename(blob, &blob_part).unwrap();
+        std::fs::remove_file(&downloaded_path).unwrap();
+        let content = std::fs::read(&*blob_part).unwrap();
+        assert_eq!(content.len() as u64, new_size);
+        let val = Sha256::digest(content);
+        // We modified the sha.
+        assert!(
+            val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+        let new_downloaded_path = api
+            .model(model_id.clone())
+            .download("config.json")
+            .await
+            .unwrap();
+        let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
+        assert_eq!(downloaded_path, new_downloaded_path);
+        assert_eq!(
+            val[..],
+            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+
+        // Now this is a valid partial download file
+        let blob = std::fs::canonicalize(&downloaded_path).unwrap();
+        let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
+        let size = file.metadata().unwrap().len();
+        let truncate: f32 = rand::random();
+        let new_size = (size as f32 * truncate) as u64;
+        // Truncating
+        file.set_len(new_size).unwrap();
+        let total_size = size + size_of::<u64>() as u64;
+        file.set_len(total_size).unwrap();
+        file.seek(SeekFrom::Start(size)).unwrap();
+        file.write_all(&new_size.to_le_bytes()).unwrap();
+
+        let mut blob_part = blob.clone();
+        blob_part.set_extension(".sync.part");
+        std::fs::rename(blob, &blob_part).unwrap();
+        std::fs::remove_file(&downloaded_path).unwrap();
+        let content = std::fs::read(&*blob_part).unwrap();
+        assert_eq!(content.len() as u64, total_size);
+        let val = Sha256::digest(content);
+        // We modified the sha.
+        assert!(
+            val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+        let new_downloaded_path = api
+            .model(model_id.clone())
+            .download("config.json")
+            .await
+            .unwrap();
+        let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
+        assert_eq!(downloaded_path, new_downloaded_path);
+        assert_eq!(
+            val[..],
+            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+
+        // Here we prove the previous part was correctly resuming by purposefully corrupting the
+        // file.
+        let blob = std::fs::canonicalize(&downloaded_path).unwrap();
+        let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
+        let size = file.metadata().unwrap().len();
+        // Not random for consistent sha corruption
+        let truncate: f32 = 0.5;
+        let new_size = (size as f32 * truncate) as u64;
+        // Truncating
+        file.set_len(new_size).unwrap();
+        let total_size = size + size_of::<u64>() as u64;
+        file.set_len(total_size).unwrap();
+        file.seek(SeekFrom::Start(size)).unwrap();
+        file.write_all(&new_size.to_le_bytes()).unwrap();
+
+        // Corrupting by changing a single byte.
+        file.seek(SeekFrom::Start(new_size - 1)).unwrap();
+        file.write_all(&[0]).unwrap();
+
+        let mut blob_part = blob.clone();
+        blob_part.set_extension(".sync.part");
+        std::fs::rename(blob, &blob_part).unwrap();
+        std::fs::remove_file(&downloaded_path).unwrap();
+        let content = std::fs::read(&*blob_part).unwrap();
+        assert_eq!(content.len() as u64, total_size);
+        let val = Sha256::digest(content);
+        // We modified the sha.
+        assert!(
+            val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
+        );
+        let new_downloaded_path = api
+            .model(model_id.clone())
+            .download("config.json")
+            .await
+            .unwrap();
+        let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
+        assert_eq!(downloaded_path, new_downloaded_path);
+        assert_eq!(
+            val[..],
+            // Corrupted sha
+            hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
+        );
+    }
+
     #[tokio::test]
     async fn revision() {
         let tmp = TempDir::new();
```
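The async test walks the same file through three scenarios: a truncated part file without a valid trailer (length mismatch, so a full redownload), a well-formed part file with a trailer (a true resume, verified against the original sha), and finally a well-formed trailer over a corrupted payload, where the pinned corrupted sha proves the download resumed from the part file instead of starting over.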