Skip to content

Commit

Permalink
Better lock error. (#87)
Browse files Browse the repository at this point in the history
* Better lock error.

* Fixing the extension names + locking mecanism

* Lockfile tests.

* Clippy.
  • Loading branch information
Narsil authored Dec 31, 2024
1 parent 6e5123e commit ddffd9d
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 45 deletions.
77 changes: 56 additions & 21 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type HeaderMap = HashMap<&'static str, String>;
type HeaderName = &'static str;

/// Specific name for the sync part of the resumable file
const EXTENSION: &str = ".part";
const EXTENSION: &str = "part";

struct Wrapper<'a, P: Progress, R: Read> {
progress: &'a mut P,
Expand Down Expand Up @@ -70,6 +70,7 @@ impl HeaderAgent {
}
}

#[derive(Debug)]
struct Handle {
_file: std::fs::File,
path: PathBuf,
Expand All @@ -80,28 +81,25 @@ impl Drop for Handle {
}
}

fn lock_file(path: PathBuf) -> Result<Handle, ApiError> {
let mut lock = path.clone();
lock.set_extension(".lock");
fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
path.set_extension("lock");

let mut n = 0;
while lock.exists() {
n += 1;
if n > 0 {}
}
let mut lock_handle = None;
for i in 0..30 {
match std::fs::File::create(lock.clone()) {
Ok(handle) => lock_handle = Some(handle),
Err(_err) => {
match std::fs::File::create_new(path.clone()) {
Ok(handle) => {
lock_handle = Some(handle);
break;
}
_ => {
if i == 0 {
eprintln!("Waiting for lock");
log::warn!("Waiting for lock {path:?}");
}
std::thread::sleep(Duration::from_secs(1));
}
}
}
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition)?;
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?;
Ok(Handle { path, _file })
}

Expand Down Expand Up @@ -148,9 +146,10 @@ pub enum ApiError {
#[error("Invalid part file - corrupted file")]
InvalidResume,

/// Lock acquisition
#[error("Unable to acquire lock")]
LockAcquisition,
/// We failed to acquire lock for file `f`. Meaning
/// Someone else is writing/downloading said file
#[error("Lock acquisition failed: {0}")]
LockAcquisition(PathBuf),
}

/// Helper to create [`Api`] with all the options.
Expand Down Expand Up @@ -490,7 +489,6 @@ impl Api {
let filepath = tmp_path;

// Create the file and set everything properly
let lock = lock_file(filepath.clone())?;

let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) {
Ok(f) => f,
Expand Down Expand Up @@ -519,7 +517,6 @@ impl Api {
}
}
res?;
drop(lock);
Ok(filepath)
}

Expand Down Expand Up @@ -691,13 +688,16 @@ impl ApiRepo {
.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

let lock = lock_file(blob_path.clone())?;
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENSION);
let tmp_filename =
self.api
.download_tempfile(&url, metadata.size, progress, tmp_path, filename)?;

std::fs::rename(tmp_filename, &blob_path)?;
drop(lock);

let mut pointer_path = self
.api
.cache
Expand Down Expand Up @@ -846,7 +846,7 @@ mod tests {
let new_size = (size as f32 * truncate) as u64;
file.set_len(new_size).unwrap();
let mut blob_part = blob.clone();
blob_part.set_extension(".part");
blob_part.set_extension("part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand Down Expand Up @@ -879,7 +879,7 @@ mod tests {
file.write_all(&[0]).unwrap();

let mut blob_part = blob.clone();
blob_part.set_extension(".part");
blob_part.set_extension("part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand All @@ -892,13 +892,48 @@ mod tests {
let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
assert_eq!(downloaded_path, new_downloaded_path);
println!("{new_downloaded_path:?}");
println!("Corrupted {val:#x}");
assert_eq!(
val[..],
// Corrupted sha
hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
);
}

#[test]
fn locking() {
use std::sync::{Arc, Mutex};
let tmp = Arc::new(Mutex::new(TempDir::new()));

let mut handles = vec![];
for _ in 0..5 {
let tmp2 = tmp.clone();
let f = std::thread::spawn(move || {
// 0..256ms sleep to randomize potential clashes
std::thread::sleep(Duration::from_millis(rand::random::<u8>().into()));
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp2.lock().unwrap().path.clone())
.build()
.unwrap();

let model_id = "julien-c/dummy-unknown".to_string();
api.model(model_id.clone()).download("config.json").unwrap()
});
handles.push(f);
}
while let Some(handle) = handles.pop() {
let downloaded_path = handle.join().unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
}
}

#[test]
fn simple_with_retries() {
let tmp = TempDir::new();
Expand Down
84 changes: 60 additions & 24 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Current name (used in user-agent)
const NAME: &str = env!("CARGO_PKG_NAME");

const EXTENSION: &str = ".sync.part";
const EXTENSION: &str = "sync.part";

/// This trait is used by users of the lib
/// to implement custom behavior during file downloads
Expand Down Expand Up @@ -74,28 +74,25 @@ impl Drop for Handle {
}
}

async fn lock_file(path: PathBuf) -> Result<Handle, ApiError> {
let mut lock = path.clone();
lock.set_extension(".lock");
async fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
path.set_extension("lock");

let mut n = 0;
while lock.exists() {
n += 1;
if n > 0 {}
}
let mut lock_handle = None;
for i in 0..30 {
match tokio::fs::File::create(lock.clone()).await {
Ok(handle) => lock_handle = Some(handle),
match tokio::fs::File::create_new(path.clone()).await {
Ok(handle) => {
lock_handle = Some(handle);
break;
}
Err(_err) => {
if i == 0 {
eprintln!("Waiting for lock");
log::warn!("Waiting for lock {path:?}");
}
std::thread::sleep(Duration::from_secs(1));
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition)?;
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?;
Ok(Handle { path, _file })
}

Expand Down Expand Up @@ -148,8 +145,10 @@ pub enum ApiError {
#[error("Join: {0}")]
Join(#[from] JoinError),

#[error("Lock acquisition failed")]
LockAcquisition,
/// We failed to acquire lock for file `f`. Meaning
/// Someone else is writing/downloading said file
#[error("Lock acquisition failed: {0}")]
LockAcquisition(PathBuf),
}

/// Helper to create [`Api`] with all the options.
Expand Down Expand Up @@ -565,7 +564,6 @@ impl ApiRepo {
// Create the file and set everything properly
const N_BYTES: usize = size_of::<u64>();

let lock = lock_file(filename.clone());
let start = match tokio::fs::OpenOptions::new()
.read(true)
.open(&filename)
Expand All @@ -574,7 +572,7 @@ impl ApiRepo {
Ok(mut f) => {
let len = f.metadata().await?.len();
if len == (length + N_BYTES) as u64 {
f.seek(SeekFrom::Start(length as u64)).await.unwrap();
f.seek(SeekFrom::Start(length as u64)).await?;
let mut buf = [0u8; N_BYTES];
let n = f.read(buf.as_mut_slice()).await?;
if n == N_BYTES {
Expand Down Expand Up @@ -681,7 +679,6 @@ impl ApiRepo {
.set_len(length as u64)
.await?;
progressbar.finish().await;
drop(lock);
Ok(filename)
}

Expand Down Expand Up @@ -801,15 +798,16 @@ impl ApiRepo {
let blob_path = cache.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

let lock = lock_file(blob_path.clone()).await;
progress.init(metadata.size, filename).await;
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENSION);
let tmp_filename = self
.download_tempfile(&url, metadata.size, tmp_path, progress)
.await
.unwrap();
.await?;

tokio::fs::rename(&tmp_filename, &blob_path).await?;
drop(lock);

let mut pointer_path = cache.pointer_path(&metadata.commit_hash);
pointer_path.push(filename);
Expand Down Expand Up @@ -909,6 +907,44 @@ mod tests {
assert_eq!(cache_path, downloaded_path);
}

#[tokio::test]
async fn locking() {
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
let tmp = Arc::new(Mutex::new(TempDir::new()));

let mut handles = JoinSet::new();
for _ in 0..5 {
let tmp2 = tmp.clone();
handles.spawn(async move {
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp2.lock().await.path.clone())
.build()
.unwrap();

// 0..256ms sleep to randomize potential clashes
let millis: u64 = rand::random::<u8>().into();
tokio::time::sleep(Duration::from_millis(millis)).await;
let model_id = "julien-c/dummy-unknown".to_string();
api.model(model_id.clone())
.download("config.json")
.await
.unwrap()
});
}
while let Some(handle) = handles.join_next().await {
let downloaded_path = handle.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
}
}

#[tokio::test]
async fn resume() {
let tmp = TempDir::new();
Expand Down Expand Up @@ -939,7 +975,7 @@ mod tests {
let new_size = (size as f32 * truncate) as u64;
file.set_len(new_size).unwrap();
let mut blob_part = blob.clone();
blob_part.set_extension(".sync.part");
blob_part.set_extension("sync.part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand Down Expand Up @@ -975,7 +1011,7 @@ mod tests {
file.write_all(&new_size.to_le_bytes()).unwrap();

let mut blob_part = blob.clone();
blob_part.set_extension(".sync.part");
blob_part.set_extension("sync.part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand Down Expand Up @@ -1017,7 +1053,7 @@ mod tests {
file.write_all(&[0]).unwrap();

let mut blob_part = blob.clone();
blob_part.set_extension(".sync.part");
blob_part.set_extension("sync.part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand Down

0 comments on commit ddffd9d

Please sign in to comment.