Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better lock error. #87

Merged
merged 4 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 56 additions & 21 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type HeaderMap = HashMap<&'static str, String>;
type HeaderName = &'static str;

/// Specific name for the sync part of the resumable file
const EXTENSION: &str = ".part";
const EXTENSION: &str = "part";

struct Wrapper<'a, P: Progress, R: Read> {
progress: &'a mut P,
Expand Down Expand Up @@ -70,6 +70,7 @@ impl HeaderAgent {
}
}

#[derive(Debug)]
struct Handle {
_file: std::fs::File,
path: PathBuf,
Expand All @@ -80,28 +81,25 @@ impl Drop for Handle {
}
}

fn lock_file(path: PathBuf) -> Result<Handle, ApiError> {
let mut lock = path.clone();
lock.set_extension(".lock");
fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
path.set_extension("lock");

let mut n = 0;
while lock.exists() {
n += 1;
if n > 0 {}
}
let mut lock_handle = None;
for i in 0..30 {
match std::fs::File::create(lock.clone()) {
Ok(handle) => lock_handle = Some(handle),
Err(_err) => {
match std::fs::File::create_new(path.clone()) {
Ok(handle) => {
lock_handle = Some(handle);
break;
}
_ => {
if i == 0 {
eprintln!("Waiting for lock");
log::warn!("Waiting for lock {path:?}");
}
std::thread::sleep(Duration::from_secs(1));
}
}
}
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition)?;
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?;
Ok(Handle { path, _file })
}

Expand Down Expand Up @@ -148,9 +146,10 @@ pub enum ApiError {
#[error("Invalid part file - corrupted file")]
InvalidResume,

/// Lock acquisition
#[error("Unable to acquire lock")]
LockAcquisition,
/// Failed to acquire the lock for the file at the contained path, which
/// usually means someone else is currently writing/downloading that file.
#[error("Lock acquisition failed: {0}")]
LockAcquisition(PathBuf),
}

/// Helper to create [`Api`] with all the options.
Expand Down Expand Up @@ -490,7 +489,6 @@ impl Api {
let filepath = tmp_path;

// Create the file and set everything properly
let lock = lock_file(filepath.clone())?;

let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) {
Ok(f) => f,
Expand Down Expand Up @@ -519,7 +517,6 @@ impl Api {
}
}
res?;
drop(lock);
Ok(filepath)
}

Expand Down Expand Up @@ -691,13 +688,16 @@ impl ApiRepo {
.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

let lock = lock_file(blob_path.clone())?;
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENSION);
let tmp_filename =
self.api
.download_tempfile(&url, metadata.size, progress, tmp_path, filename)?;

std::fs::rename(tmp_filename, &blob_path)?;
drop(lock);

let mut pointer_path = self
.api
.cache
Expand Down Expand Up @@ -846,7 +846,7 @@ mod tests {
let new_size = (size as f32 * truncate) as u64;
file.set_len(new_size).unwrap();
let mut blob_part = blob.clone();
blob_part.set_extension(".part");
blob_part.set_extension("part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand Down Expand Up @@ -879,7 +879,7 @@ mod tests {
file.write_all(&[0]).unwrap();

let mut blob_part = blob.clone();
blob_part.set_extension(".part");
blob_part.set_extension("part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand All @@ -892,13 +892,48 @@ mod tests {
let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
assert_eq!(downloaded_path, new_downloaded_path);
println!("{new_downloaded_path:?}");
println!("Corrupted {val:#x}");
assert_eq!(
val[..],
// Corrupted sha
hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5")
);
}

/// Five threads download the same file into one shared cache directory at
/// roughly the same time; every one of them must end up with a path that
/// exists and whose content hashes to the expected SHA-256.
#[test]
fn locking() {
    use std::sync::{Arc, Mutex};

    // A single temporary cache directory shared by all workers, so that they
    // all target the same file on disk.
    let tmp = Arc::new(Mutex::new(TempDir::new()));

    let workers: Vec<_> = (0..5)
        .map(|_| {
            let shared_tmp = Arc::clone(&tmp);
            std::thread::spawn(move || {
                // 0..256ms sleep to randomize potential clashes
                let jitter_ms: u64 = rand::random::<u8>().into();
                std::thread::sleep(Duration::from_millis(jitter_ms));

                let api = ApiBuilder::new()
                    .with_progress(false)
                    .with_cache_dir(shared_tmp.lock().unwrap().path.clone())
                    .build()
                    .unwrap();

                api.model("julien-c/dummy-unknown".to_string())
                    .download("config.json")
                    .unwrap()
            })
        })
        .collect();

    // Join the most recently spawned worker first.
    for worker in workers.into_iter().rev() {
        let downloaded_path = worker.join().unwrap();
        assert!(downloaded_path.exists());
        let digest = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
        assert_eq!(
            digest[..],
            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
        );
    }
}

#[test]
fn simple_with_retries() {
let tmp = TempDir::new();
Expand Down
84 changes: 60 additions & 24 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Current name (used in user-agent)
const NAME: &str = env!("CARGO_PKG_NAME");

const EXTENSION: &str = ".sync.part";
const EXTENSION: &str = "sync.part";

/// This trait is used by users of the lib
/// to implement custom behavior during file downloads
Expand Down Expand Up @@ -74,28 +74,25 @@ impl Drop for Handle {
}
}

async fn lock_file(path: PathBuf) -> Result<Handle, ApiError> {
let mut lock = path.clone();
lock.set_extension(".lock");
async fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
path.set_extension("lock");

let mut n = 0;
while lock.exists() {
n += 1;
if n > 0 {}
}
let mut lock_handle = None;
for i in 0..30 {
match tokio::fs::File::create(lock.clone()).await {
Ok(handle) => lock_handle = Some(handle),
match tokio::fs::File::create_new(path.clone()).await {
Ok(handle) => {
lock_handle = Some(handle);
break;
}
Err(_err) => {
if i == 0 {
eprintln!("Waiting for lock");
log::warn!("Waiting for lock {path:?}");
}
std::thread::sleep(Duration::from_secs(1));
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition)?;
let _file = lock_handle.ok_or_else(|| ApiError::LockAcquisition(path.clone()))?;
Ok(Handle { path, _file })
}

Expand Down Expand Up @@ -148,8 +145,10 @@ pub enum ApiError {
#[error("Join: {0}")]
Join(#[from] JoinError),

#[error("Lock acquisition failed")]
LockAcquisition,
/// Failed to acquire the lock for the file at the contained path, which
/// usually means someone else is currently writing/downloading that file.
#[error("Lock acquisition failed: {0}")]
LockAcquisition(PathBuf),
}

/// Helper to create [`Api`] with all the options.
Expand Down Expand Up @@ -565,7 +564,6 @@ impl ApiRepo {
// Create the file and set everything properly
const N_BYTES: usize = size_of::<u64>();

let lock = lock_file(filename.clone());
let start = match tokio::fs::OpenOptions::new()
.read(true)
.open(&filename)
Expand All @@ -574,7 +572,7 @@ impl ApiRepo {
Ok(mut f) => {
let len = f.metadata().await?.len();
if len == (length + N_BYTES) as u64 {
f.seek(SeekFrom::Start(length as u64)).await.unwrap();
f.seek(SeekFrom::Start(length as u64)).await?;
let mut buf = [0u8; N_BYTES];
let n = f.read(buf.as_mut_slice()).await?;
if n == N_BYTES {
Expand Down Expand Up @@ -681,7 +679,6 @@ impl ApiRepo {
.set_len(length as u64)
.await?;
progressbar.finish().await;
drop(lock);
Ok(filename)
}

Expand Down Expand Up @@ -801,15 +798,16 @@ impl ApiRepo {
let blob_path = cache.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

// Propagate a lock-acquisition failure to the caller; without `?` the
// `Result` is merely stored and the download would proceed unprotected.
let lock = lock_file(blob_path.clone()).await?;
progress.init(metadata.size, filename).await;
let mut tmp_path = blob_path.clone();
tmp_path.set_extension(EXTENSION);
let tmp_filename = self
.download_tempfile(&url, metadata.size, tmp_path, progress)
.await
.unwrap();
.await?;

tokio::fs::rename(&tmp_filename, &blob_path).await?;
drop(lock);

let mut pointer_path = cache.pointer_path(&metadata.commit_hash);
pointer_path.push(filename);
Expand Down Expand Up @@ -909,6 +907,44 @@ mod tests {
assert_eq!(cache_path, downloaded_path);
}

#[tokio::test]
async fn locking() {
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
let tmp = Arc::new(Mutex::new(TempDir::new()));

let mut handles = JoinSet::new();
for _ in 0..5 {
let tmp2 = tmp.clone();
handles.spawn(async move {
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp2.lock().await.path.clone())
.build()
.unwrap();

// 0..256ms sleep to randomize potential clashes
let millis: u64 = rand::random::<u8>().into();
tokio::time::sleep(Duration::from_millis(millis)).await;
let model_id = "julien-c/dummy-unknown".to_string();
api.model(model_id.clone())
.download("config.json")
.await
.unwrap()
});
}
while let Some(handle) = handles.join_next().await {
let downloaded_path = handle.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
}
}

#[tokio::test]
async fn resume() {
let tmp = TempDir::new();
Expand Down Expand Up @@ -939,7 +975,7 @@ mod tests {
let new_size = (size as f32 * truncate) as u64;
file.set_len(new_size).unwrap();
let mut blob_part = blob.clone();
blob_part.set_extension(".sync.part");
blob_part.set_extension("sync.part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand Down Expand Up @@ -975,7 +1011,7 @@ mod tests {
file.write_all(&new_size.to_le_bytes()).unwrap();

let mut blob_part = blob.clone();
blob_part.set_extension(".sync.part");
blob_part.set_extension("sync.part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand Down Expand Up @@ -1017,7 +1053,7 @@ mod tests {
file.write_all(&[0]).unwrap();

let mut blob_part = blob.clone();
blob_part.set_extension(".sync.part");
blob_part.set_extension("sync.part");
std::fs::rename(blob, &blob_part).unwrap();
std::fs::remove_file(&downloaded_path).unwrap();
let content = std::fs::read(&*blob_part).unwrap();
Expand Down
Loading