From c777545d66e4c052683a37c145773cfc2a553c07 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 31 Dec 2024 12:38:38 +0100 Subject: [PATCH 1/4] Better lock error. --- src/api/sync.rs | 9 +++++---- src/api/tokio.rs | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/api/sync.rs b/src/api/sync.rs index d1a27b6..219178b 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -101,7 +101,7 @@ fn lock_file(path: PathBuf) -> Result { } } } - let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition)?; + let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(lock.clone()))?; Ok(Handle { path, _file }) } @@ -148,9 +148,10 @@ pub enum ApiError { #[error("Invalid part file - corrupted file")] InvalidResume, - /// Lock acquisition - #[error("Unable to acquire lock")] - LockAcquisition, + /// We failed to acquire lock for file `f`. Meaning + /// Someone else is writing/downloading said file + #[error("Lock acquisition failed: {0}")] + LockAcquisition(PathBuf), } /// Helper to create [`Api`] with all the options. diff --git a/src/api/tokio.rs b/src/api/tokio.rs index b9c6c0b..a3c7335 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -95,7 +95,7 @@ async fn lock_file(path: PathBuf) -> Result { } } } - let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition)?; + let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(lock.clone()))?; Ok(Handle { path, _file }) } @@ -148,8 +148,10 @@ pub enum ApiError { #[error("Join: {0}")] Join(#[from] JoinError), - #[error("Lock acquisition failed")] - LockAcquisition, + /// We failed to acquire lock for file `f`. Meaning + /// Someone else is writing/downloading said file + #[error("Lock acquisition failed: {0}")] + LockAcquisition(PathBuf), } /// Helper to create [`Api`] with all the options. From fde1e29c041bc1b6afff30e8d66924d8c5a0f95f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 31 Dec 2024 17:19:05 +0100 Subject: [PATCH 2/4] Fixing the extension names + locking mecanism --- src/api/sync.rs | 34 +++++++++++++++++++--------------- src/api/tokio.rs | 18 +++++++++--------- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/src/api/sync.rs b/src/api/sync.rs index 219178b..20f98e0 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -30,7 +30,7 @@ type HeaderMap = HashMap<&'static str, String>; type HeaderName = &'static str; /// Specific name for the sync part of the resumable file -const EXTENSION: &str = ".part"; +const EXTENSION: &str = "part"; struct Wrapper<'a, P: Progress, R: Read> { progress: &'a mut P, @@ -70,6 +70,7 @@ impl HeaderAgent { } } +#[derive(Debug)] struct Handle { _file: std::fs::File, path: PathBuf, @@ -81,27 +82,27 @@ impl Drop for Handle { } fn lock_file(path: PathBuf) -> Result { - let mut lock = path.clone(); - lock.set_extension(".lock"); + let mut path = path.clone(); + path.set_extension("lock"); let mut n = 0; - while lock.exists() { + while path.exists() { n += 1; if n > 0 {} } let mut lock_handle = None; for i in 0..30 { - match std::fs::File::create(lock.clone()) { + match std::fs::File::create(path.clone()) { Ok(handle) => lock_handle = Some(handle), Err(_err) => { if i == 0 { - eprintln!("Waiting for lock"); + log::warn!("Waiting for lock {path:?}"); } std::thread::sleep(Duration::from_secs(1)); } } } - let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(lock.clone()))?; + let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?; Ok(Handle { path, _file }) } @@ -491,7 +492,7 @@ impl Api { let filepath = tmp_path; // Create the file and set everything properly - let lock = lock_file(filepath.clone())?; + let lock = lock_file(filepath.clone()).expect("lock"); let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) { Ok(f) => f, @@ -694,11 +695,12 @@ impl ApiRepo { let mut tmp_path = blob_path.clone(); tmp_path.set_extension(EXTENSION); - let tmp_filename = - self.api - .download_tempfile(&url, metadata.size, progress, tmp_path, filename)?; + let tmp_filename = self + .api + .download_tempfile(&url, metadata.size, progress, tmp_path, filename) + .expect("downloaded"); - std::fs::rename(tmp_filename, &blob_path)?; + std::fs::rename(tmp_filename, &blob_path).expect("rename"); let mut pointer_path = self .api .cache @@ -707,7 +709,7 @@ impl ApiRepo { pointer_path.push(filename); std::fs::create_dir_all(pointer_path.parent().unwrap()).ok(); - symlink_or_rename(&blob_path, &pointer_path)?; + symlink_or_rename(&blob_path, &pointer_path).expect("rename"); self.api .cache .repo(self.repo.clone()) @@ -847,7 +849,7 @@ mod tests { let new_size = (size as f32 * truncate) as u64; file.set_len(new_size).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".part"); + blob_part.set_extension("part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap(); @@ -880,7 +882,7 @@ mod tests { file.write_all(&[0]).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".part"); + blob_part.set_extension("part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap(); @@ -893,6 +895,8 @@ mod tests { let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap(); let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap()); assert_eq!(downloaded_path, new_downloaded_path); + println!("{new_downloaded_path:?}"); + println!("Corrupted {val:#x}"); assert_eq!( val[..], // Corrupted sha diff --git a/src/api/tokio.rs b/src/api/tokio.rs index a3c7335..38d814c 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -30,7 +30,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); /// Current name (used in user-agent) const NAME: &str = env!("CARGO_PKG_NAME"); -const EXTENSION: &str = ".sync.part"; +const EXTENSION: &str = "sync.part"; /// This trait is used by users of the lib /// to implement custom behavior during file downloads @@ -75,17 +75,17 @@ impl Drop for Handle { } async fn lock_file(path: PathBuf) -> Result { - let mut lock = path.clone(); - lock.set_extension(".lock"); + let mut path = path.clone(); + path.set_extension("lock"); let mut n = 0; - while lock.exists() { + while path.exists() { n += 1; if n > 0 {} } let mut lock_handle = None; for i in 0..30 { - match tokio::fs::File::create(lock.clone()).await { + match tokio::fs::File::create(path.clone()).await { Ok(handle) => lock_handle = Some(handle), Err(_err) => { if i == 0 { @@ -95,7 +95,7 @@ async fn lock_file(path: PathBuf) -> Result { } } } - let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(lock.clone()))?; + let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?; Ok(Handle { path, _file }) } @@ -941,7 +941,7 @@ mod tests { let new_size = (size as f32 * truncate) as u64; file.set_len(new_size).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".sync.part"); + blob_part.set_extension("sync.part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap(); @@ -977,7 +977,7 @@ mod tests { file.write_all(&new_size.to_le_bytes()).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".sync.part"); + blob_part.set_extension("sync.part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap(); @@ -1019,7 +1019,7 @@ mod tests { file.write_all(&[0]).unwrap(); let mut blob_part = blob.clone(); - blob_part.set_extension(".sync.part"); + blob_part.set_extension("sync.part"); std::fs::rename(blob, &blob_part).unwrap(); std::fs::remove_file(&downloaded_path).unwrap(); let content = std::fs::read(&*blob_part).unwrap(); From fe46891a34e4218a952818688dd72022d657c45d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 31 Dec 2024 18:14:02 +0100 Subject: [PATCH 3/4] Lockfile tests. --- src/api/sync.rs | 67 ++++++++++++++++++++++++++++++++++------------- src/api/tokio.rs | 68 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 101 insertions(+), 34 deletions(-) diff --git a/src/api/sync.rs b/src/api/sync.rs index 20f98e0..cf54583 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -81,20 +81,17 @@ impl Drop for Handle { } } -fn lock_file(path: PathBuf) -> Result { - let mut path = path.clone(); +fn lock_file(mut path: PathBuf) -> Result { path.set_extension("lock"); - let mut n = 0; - while path.exists() { - n += 1; - if n > 0 {} - } let mut lock_handle = None; for i in 0..30 { - match std::fs::File::create(path.clone()) { - Ok(handle) => lock_handle = Some(handle), - Err(_err) => { + match std::fs::File::create_new(path.clone()) { + Ok(handle) => { + lock_handle = Some(handle); + break; + } + _ => { if i == 0 { log::warn!("Waiting for lock {path:?}"); } @@ -492,7 +489,6 @@ impl Api { let filepath = tmp_path; // Create the file and set everything properly - let lock = lock_file(filepath.clone()).expect("lock"); let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) { Ok(f) => f, @@ -521,7 +517,6 @@ impl Api { } } res?; - drop(lock); Ok(filepath) } @@ -693,14 +688,16 @@ impl ApiRepo { .blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; + let lock = lock_file(blob_path.clone())?; let mut tmp_path = blob_path.clone(); tmp_path.set_extension(EXTENSION); - let tmp_filename = self - .api - .download_tempfile(&url, metadata.size, progress, tmp_path, filename) - .expect("downloaded"); + let tmp_filename = + self.api + .download_tempfile(&url, metadata.size, progress, tmp_path, filename)?; + + std::fs::rename(tmp_filename, &blob_path)?; + drop(lock); - std::fs::rename(tmp_filename, &blob_path).expect("rename"); let mut pointer_path = self .api .cache @@ -709,7 +706,7 @@ impl ApiRepo { pointer_path.push(filename); std::fs::create_dir_all(pointer_path.parent().unwrap()).ok(); - symlink_or_rename(&blob_path, &pointer_path).expect("rename"); + symlink_or_rename(&blob_path, &pointer_path)?; self.api .cache .repo(self.repo.clone()) @@ -904,6 +901,40 @@ mod tests { ); } + #[test] + fn locking() { + use std::sync::{Arc, Mutex}; + let tmp = Arc::new(Mutex::new(TempDir::new())); + + let mut handles = vec![]; + for _ in 0..5 { + let tmp2 = tmp.clone(); + let f = std::thread::spawn(move || { + // 0..256ms sleep to randomize potential clashes + std::thread::sleep(Duration::from_millis(rand::random::().into())); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp2.lock().unwrap().path.clone()) + .build() + .unwrap(); + + let model_id = "julien-c/dummy-unknown".to_string(); + let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap(); + downloaded_path + }); + handles.push(f); + } + while let Some(handle) = handles.pop() { + let downloaded_path = handle.join().unwrap(); + assert!(downloaded_path.exists()); + let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); + assert_eq!( + val[..], + hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + } + } + #[test] fn simple_with_retries() { let tmp = TempDir::new(); diff --git a/src/api/tokio.rs b/src/api/tokio.rs index 38d814c..7147c5e 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -74,24 +74,21 @@ impl Drop for Handle { } } -async fn lock_file(path: PathBuf) -> Result { - let mut path = path.clone(); +async fn lock_file(mut path: PathBuf) -> Result { path.set_extension("lock"); - let mut n = 0; - while path.exists() { - n += 1; - if n > 0 {} - } let mut lock_handle = None; for i in 0..30 { - match tokio::fs::File::create(path.clone()).await { - Ok(handle) => lock_handle = Some(handle), + match tokio::fs::File::create_new(path.clone()).await { + Ok(handle) => { + lock_handle = Some(handle); + break; + } Err(_err) => { if i == 0 { - eprintln!("Waiting for lock"); + log::warn!("Waiting for lock {path:?}"); } - std::thread::sleep(Duration::from_secs(1)); + tokio::time::sleep(Duration::from_secs(1)).await; } } } @@ -567,7 +564,6 @@ impl ApiRepo { // Create the file and set everything properly const N_BYTES: usize = size_of::(); - let lock = lock_file(filename.clone()); let start = match tokio::fs::OpenOptions::new() .read(true) .open(&filename) @@ -576,7 +572,7 @@ impl ApiRepo { Ok(mut f) => { let len = f.metadata().await?.len(); if len == (length + N_BYTES) as u64 { - f.seek(SeekFrom::Start(length as u64)).await.unwrap(); + f.seek(SeekFrom::Start(length as u64)).await?; let mut buf = [0u8; N_BYTES]; let n = f.read(buf.as_mut_slice()).await?; if n == N_BYTES { @@ -683,7 +679,6 @@ impl ApiRepo { .set_len(length as u64) .await?; progressbar.finish().await; - drop(lock); Ok(filename) } @@ -803,15 +798,16 @@ impl ApiRepo { let blob_path = cache.blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; + let lock = lock_file(blob_path.clone()).await; progress.init(metadata.size, filename).await; let mut tmp_path = blob_path.clone(); tmp_path.set_extension(EXTENSION); let tmp_filename = self .download_tempfile(&url, metadata.size, tmp_path, progress) - .await - .unwrap(); + .await?; tokio::fs::rename(&tmp_filename, &blob_path).await?; + drop(lock); let mut pointer_path = cache.pointer_path(&metadata.commit_hash); pointer_path.push(filename); @@ -911,6 +907,46 @@ mod tests { assert_eq!(cache_path, downloaded_path); } + #[tokio::test] + async fn locking() { + use std::sync::Arc; + use tokio::sync::Mutex; + use tokio::task::JoinSet; + let tmp = Arc::new(Mutex::new(TempDir::new())); + + let mut handles = JoinSet::new(); + for _ in 0..4 { + let tmp2 = tmp.clone(); + handles.spawn(async move { + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp2.lock().await.path.clone()) + .build() + .unwrap(); + + // 0..256ms sleep to randomize potential clashes + let millis: u64 = rand::random::().into(); + tokio::time::sleep(Duration::from_millis(millis)).await; + let model_id = "julien-c/dummy-unknown".to_string(); + let downloaded_path = api + .model(model_id.clone()) + .download("config.json") + .await + .unwrap(); + downloaded_path + }); + } + while let Some(handle) = handles.join_next().await { + let downloaded_path = handle.unwrap(); + assert!(downloaded_path.exists()); + let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); + assert_eq!( + val[..], + hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") + ); + } + } + #[tokio::test] async fn resume() { let tmp = TempDir::new(); From 441bf49f28c699921fb6c2c43af99a7c2710ad08 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 31 Dec 2024 18:17:46 +0100 Subject: [PATCH 4/4] Clippy. --- src/api/sync.rs | 3 +-- src/api/tokio.rs | 8 +++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/api/sync.rs b/src/api/sync.rs index cf54583..cd2c4b4 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -919,8 +919,7 @@ mod tests { .unwrap(); let model_id = "julien-c/dummy-unknown".to_string(); - let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap(); - downloaded_path + api.model(model_id.clone()).download("config.json").unwrap() }); handles.push(f); } diff --git a/src/api/tokio.rs b/src/api/tokio.rs index 7147c5e..0677f2f 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -915,7 +915,7 @@ mod tests { let tmp = Arc::new(Mutex::new(TempDir::new())); let mut handles = JoinSet::new(); - for _ in 0..4 { + for _ in 0..5 { let tmp2 = tmp.clone(); handles.spawn(async move { let api = ApiBuilder::new() @@ -928,12 +928,10 @@ mod tests { let millis: u64 = rand::random::().into(); tokio::time::sleep(Duration::from_millis(millis)).await; let model_id = "julien-c/dummy-unknown".to_string(); - let downloaded_path = api - .model(model_id.clone()) + api.model(model_id.clone()) .download("config.json") .await - .unwrap(); - downloaded_path + .unwrap() }); } while let Some(handle) = handles.join_next().await {