diff --git a/src/api/sync.rs b/src/api/sync.rs index d1a27b6..cd2c4b4 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -30,7 +30,7 @@ type HeaderMap = HashMap<&'static str, String>; type HeaderName = &'static str; /// Specific name for the sync part of the resumable file -const EXTENSION: &str = ".part"; +const EXTENSION: &str = "part"; struct Wrapper<'a, P: Progress, R: Read> { progress: &'a mut P, @@ -70,6 +70,7 @@ impl HeaderAgent { } } +#[derive(Debug)] struct Handle { _file: std::fs::File, path: PathBuf, @@ -80,28 +81,25 @@ impl Drop for Handle { } } -fn lock_file(path: PathBuf) -> Result { - let mut lock = path.clone(); - lock.set_extension(".lock"); +fn lock_file(mut path: PathBuf) -> Result { + path.set_extension("lock"); - let mut n = 0; - while lock.exists() { - n += 1; - if n > 0 {} - } let mut lock_handle = None; for i in 0..30 { - match std::fs::File::create(lock.clone()) { - Ok(handle) => lock_handle = Some(handle), - Err(_err) => { + match std::fs::File::create_new(path.clone()) { + Ok(handle) => { + lock_handle = Some(handle); + break; + } + _ => { if i == 0 { - eprintln!("Waiting for lock"); + log::warn!("Waiting for lock {path:?}"); } std::thread::sleep(Duration::from_secs(1)); } } } - let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition)?; + let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?; Ok(Handle { path, _file }) } @@ -148,9 +146,10 @@ pub enum ApiError { #[error("Invalid part file - corrupted file")] InvalidResume, - /// Lock acquisition - #[error("Unable to acquire lock")] - LockAcquisition, + /// We failed to acquire lock for file `f`. Meaning + /// Someone else is writing/downloading said file + #[error("Lock acquisition failed: {0}")] + LockAcquisition(PathBuf), } /// Helper to create [`Api`] with all the options. @@ -490,7 +489,6 @@ impl Api { let filepath = tmp_path; // Create the file and set everything properly - let lock = lock_file(filepath.clone())?; let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) { Ok(f) => f, @@ -519,7 +517,6 @@ impl Api { } } res?; - drop(lock); Ok(filepath) } @@ -691,6 +688,7 @@ impl ApiRepo { .blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; + let lock = lock_file(blob_path.clone())?; let mut tmp_path = blob_path.clone(); tmp_path.set_extension(EXTENSION); let tmp_filename = @@ -698,6 +696,8 @@ impl ApiRepo { .download_tempfile(&url, metadata.size, progress, tmp_path, filename)?; std::fs::rename(tmp_filename, &blob_path)?; + drop(lock); + let mut pointer_path = self .api .cache @@ -846,7 +846,7 @@ mod tests { let new_size = (size as f32 * truncate) as u64; file.set_len(new_size).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".part"); + blob_part.set_extension("part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap(); @@ -879,7 +879,7 @@ mod tests { file.write_all(&[0]).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".part"); + blob_part.set_extension("part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap(); @@ -892,6 +892,8 @@ mod tests { let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap(); let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap()); assert_eq!(downloaded_path, new_downloaded_path); + println!("{new_downloaded_path:?}"); + println!("Corrupted {val:#x}"); assert_eq!( val[..], // Corrupted sha @@ -899,6 +901,39 @@ mod tests { ); } + #[test] + fn locking() { + use std::sync::{Arc, Mutex}; + let tmp = Arc::new(Mutex::new(TempDir::new())); + + let mut handles = vec![]; + for _ in 0..5 { + let tmp2 = tmp.clone(); + let f = std::thread::spawn(move || { + // 0..256ms sleep to randomize potential clashes + std::thread::sleep(Duration::from_millis(rand::random::().into())); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp2.lock().unwrap().path.clone()) + .build() + .unwrap(); + + let model_id = "julien-c/dummy-unknown".to_string(); + api.model(model_id.clone()).download("config.json").unwrap() + }); + handles.push(f); + } + while let Some(handle) = handles.pop() { + let downloaded_path = handle.join().unwrap(); + assert!(downloaded_path.exists()); + let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); + assert_eq!( + val[..], + hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + } + } + #[test] fn simple_with_retries() { let tmp = TempDir::new(); diff --git a/src/api/tokio.rs b/src/api/tokio.rs index b9c6c0b..0677f2f 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -30,7 +30,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); /// Current name (used in user-agent) const NAME: &str = env!("CARGO_PKG_NAME"); -const EXTENSION: &str = ".sync.part"; +const EXTENSION: &str = "sync.part"; /// This trait is used by users of the lib /// to implement custom behavior during file downloads @@ -74,28 +74,25 @@ impl Drop for Handle { } } -async fn lock_file(path: PathBuf) -> Result { - let mut lock = path.clone(); - lock.set_extension(".lock"); +async fn lock_file(mut path: PathBuf) -> Result { + path.set_extension("lock"); - let mut n = 0; - while lock.exists() { - n += 1; - if n > 0 {} - } let mut lock_handle = None; for i in 0..30 { - match tokio::fs::File::create(lock.clone()).await { - Ok(handle) => lock_handle = Some(handle), + match tokio::fs::File::create_new(path.clone()).await { + Ok(handle) => { + lock_handle = Some(handle); + break; + } Err(_err) => { if i == 0 { - eprintln!("Waiting for lock"); + log::warn!("Waiting for lock {path:?}"); } - std::thread::sleep(Duration::from_secs(1)); + tokio::time::sleep(Duration::from_secs(1)).await; } } } - let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition)?; + let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?; Ok(Handle { path, _file }) } @@ -148,8 +145,10 @@ pub enum ApiError { #[error("Join: {0}")] Join(#[from] JoinError), - #[error("Lock acquisition failed")] - LockAcquisition, + /// We failed to acquire lock for file `f`. Meaning + /// Someone else is writing/downloading said file + #[error("Lock acquisition failed: {0}")] + LockAcquisition(PathBuf), } /// Helper to create [`Api`] with all the options. @@ -565,7 +564,6 @@ impl ApiRepo { // Create the file and set everything properly const N_BYTES: usize = size_of::(); - let lock = lock_file(filename.clone()); let start = match tokio::fs::OpenOptions::new() .read(true) .open(&filename) @@ -574,7 +572,7 @@ impl ApiRepo { Ok(mut f) => { let len = f.metadata().await?.len(); if len == (length + N_BYTES) as u64 { - f.seek(SeekFrom::Start(length as u64)).await.unwrap(); + f.seek(SeekFrom::Start(length as u64)).await?; let mut buf = [0u8; N_BYTES]; let n = f.read(buf.as_mut_slice()).await?; if n == N_BYTES { @@ -681,7 +679,6 @@ impl ApiRepo { .set_len(length as u64) .await?; progressbar.finish().await; - drop(lock); Ok(filename) } @@ -801,15 +798,16 @@ impl ApiRepo { let blob_path = cache.blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; + let lock = lock_file(blob_path.clone()).await; progress.init(metadata.size, filename).await; let mut tmp_path = blob_path.clone(); tmp_path.set_extension(EXTENSION); let tmp_filename = self .download_tempfile(&url, metadata.size, tmp_path, progress) - .await - .unwrap(); + .await?; tokio::fs::rename(&tmp_filename, &blob_path).await?; + drop(lock); let mut pointer_path = cache.pointer_path(&metadata.commit_hash); pointer_path.push(filename); @@ -909,6 +907,44 @@ mod tests { assert_eq!(cache_path, downloaded_path); } + #[tokio::test] + async fn locking() { + use std::sync::Arc; + use tokio::sync::Mutex; + use tokio::task::JoinSet; + let tmp = Arc::new(Mutex::new(TempDir::new())); + + let mut handles = JoinSet::new(); + for _ in 0..5 { + let tmp2 = tmp.clone(); + handles.spawn(async move { + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp2.lock().await.path.clone()) + .build() + .unwrap(); + + // 0..256ms sleep to randomize potential clashes + let millis: u64 = rand::random::().into(); + tokio::time::sleep(Duration::from_millis(millis)).await; + let model_id = "julien-c/dummy-unknown".to_string(); + api.model(model_id.clone()) + .download("config.json") + .await + .unwrap() + }); + } + while let Some(handle) = handles.join_next().await { + let downloaded_path = handle.unwrap(); + assert!(downloaded_path.exists()); + let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); + assert_eq!( + val[..], + hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + } + } + #[tokio::test] async fn resume() { let tmp = TempDir::new(); @@ -939,7 +975,7 @@ mod tests { let new_size = (size as f32 * truncate) as u64; file.set_len(new_size).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".sync.part"); + blob_part.set_extension("sync.part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap(); @@ -975,7 +1011,7 @@ mod tests { file.write_all(&new_size.to_le_bytes()).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".sync.part"); + blob_part.set_extension("sync.part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap(); @@ -1017,7 +1053,7 @@ mod tests { file.write_all(&[0]).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".sync.part"); + blob_part.set_extension("sync.part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap();