Skip to content

Commit

Permalink
Since there doesn't seem to be a sane locking dep, let's reinvent our
Browse files Browse the repository at this point in the history
own.
  • Loading branch information
Narsil committed Jan 7, 2025
1 parent f9f6a05 commit 08fc326
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 16 deletions.
13 changes: 11 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,21 @@ rustls = { version = "0.23.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
serde_json = { version = "1", optional = true }
thiserror = { version = "2", optional = true }
tokio = { version = "1.29.1", optional = true, features = ["fs", "macros", "signal"] }
tokio = { version = "1.29.1", optional = true, features = ["fs", "macros"] }
ureq = { version = "2.8.0", optional = true, features = [
"json",
"socks-proxy",
] }
native-tls = { version = "0.2.12", optional = true }
libc = { version = "0.2", optional = true }

[target.'cfg(windows)'.dependencies.windows-sys]
version = "0.59"
features = ["Win32_Foundation", "Win32_Storage_FileSystem"]
optional = true

[target.'cfg(unix)'.dependencies.libc]
version = "0.2"
optional = true

[features]
default = ["default-tls", "tokio", "ureq"]
Expand All @@ -61,6 +69,7 @@ tokio = [
"dep:tokio",
"tokio/rt-multi-thread",
"dep:libc",
"dep:windows-sys",
]
ureq = [
"dep:http",
Expand Down
7 changes: 2 additions & 5 deletions examples/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ fn main() {}
#[cfg(feature = "ureq")]
#[cfg(not(feature = "tokio"))]
fn main() {
let api = hf_hub::api::sync::ApiBuilder::from_env().build().unwrap();
let api = hf_hub::api::sync::Api::new().unwrap();

let _filename = api
.model("meta-llama/Llama-2-7b-hf".to_string())
Expand All @@ -16,10 +16,7 @@ fn main() {
#[cfg(feature = "tokio")]
#[tokio::main]
async fn main() {
let api = hf_hub::api::tokio::ApiBuilder::from_env()
.high()
.build()
.unwrap();
let api = hf_hub::api::tokio::Api::new().unwrap();

let _filename = api
.model("meta-llama/Llama-2-7b-hf".to_string())
Expand Down
65 changes: 60 additions & 5 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::collections::HashMap;
use std::io::Read;
use std::io::Seek;
use std::num::ParseIntError;
use std::os::fd::AsRawFd;
use std::path::{Component, Path, PathBuf};
use std::str::FromStr;
use thiserror::Error;
Expand Down Expand Up @@ -70,28 +69,27 @@ impl HeaderAgent {
}
}

#[derive(Debug)]
struct Handle {
file: std::fs::File,
}

impl Drop for Handle {
fn drop(&mut self) {
unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) };
unlock(&self.file);
}
}

fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
path.set_extension("lock");

let file = std::fs::File::create(path.clone())?;
let mut res = unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) };
let mut res = lock(&file);
for _ in 0..5 {
if res == 0 {
break;
}
std::thread::sleep(std::time::Duration::from_secs(1));
res = unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) };
res = lock(&file);
}
if res != 0 {
Err(ApiError::LockAcquisition(path))
Expand All @@ -100,6 +98,63 @@ fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
}
}

#[cfg(target_family = "unix")]
mod unix {
use std::os::fd::AsRawFd;

pub(crate) fn lock(file: &std::fs::File) -> i32 {
unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) }
}
pub(crate) fn unlock(file: &std::fs::File) -> i32 {
unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) }
}
}
#[cfg(target_family = "unix")]
use unix::{lock, unlock};

#[cfg(target_family = "windows")]
mod windows {
use std::os::windows::io::AsRawHandle;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::Storage::FileSystem::{
LockFileEx, UnlockFile, LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY,
};

pub(crate) fn lock(file: &std::fs::File) -> i32 {
unsafe {
let mut overlapped = mem::zeroed();
let flags = LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY;
LockFileEx(
file.as_raw_handle() as HANDLE,
flags,
0,
!0,
!0,
&mut overlapped,
)
}
}
pub(crate) fn unlock(file: &std::fs::File) -> i32 {
unsafe {
UnlockFile(file.as_raw_handle() as HANDLE, 0, 0, !0, !0);
}
}
}
#[cfg(target_family = "windows")]
use windows::{lock, unlock};

#[cfg(not(any(target_family = "unix", target_family = "windows")))]
mod other {
pub(crate) fn lock(file: &std::fs::File) -> i32 {
0
}
pub(crate) fn unlock(file: &std::fs::File) -> i32 {
0
}
}
#[cfg(not(any(target_family = "unix", target_family = "windows")))]
use other::{lock, unlock};

#[derive(Debug, Error)]
/// All errors the API can throw
pub enum ApiError {
Expand Down
64 changes: 60 additions & 4 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use reqwest::{
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::num::ParseIntError;
use std::os::fd::{AsFd, AsRawFd};
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use thiserror::Error;
Expand Down Expand Up @@ -70,21 +69,21 @@ struct Handle {

impl Drop for Handle {
fn drop(&mut self) {
unsafe { libc::flock(self.file.as_fd().as_raw_fd(), libc::LOCK_UN) };
unlock(&self.file);
}
}

async fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
path.set_extension("lock");

let file = tokio::fs::File::create(path.clone()).await?;
let mut res = unsafe { libc::flock(file.as_fd().as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) };
let mut res = lock(&file);
for _ in 0..5 {
if res == 0 {
break;
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
res = unsafe { libc::flock(file.as_fd().as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) };
res = lock(&file);
}
if res != 0 {
Err(ApiError::LockAcquisition(path))
Expand All @@ -93,6 +92,63 @@ async fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
}
}

#[cfg(target_family = "unix")]
mod unix {
use std::os::fd::AsRawFd;

pub(crate) fn lock(file: &tokio::fs::File) -> i32 {
unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) }
}
pub(crate) fn unlock(file: &tokio::fs::File) -> i32 {
unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) }
}
}
#[cfg(target_family = "unix")]
use unix::{lock, unlock};

#[cfg(target_family = "windows")]
mod windows {
use std::os::windows::io::AsRawHandle;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::Storage::FileSystem::{
LockFileEx, UnlockFile, LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY,
};

pub(crate) fn lock(file: &tokio::fs::File) -> i32 {
unsafe {
let mut overlapped = mem::zeroed();
let flags = LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY;
LockFileEx(
file.as_raw_handle() as HANDLE,
flags,
0,
!0,
!0,
&mut overlapped,
)
}
}
pub(crate) fn unlock(file: &tokio::fs::File) -> i32 {
unsafe {
UnlockFile(file.as_raw_handle() as HANDLE, 0, 0, !0, !0);
}
}
}
#[cfg(target_family = "windows")]
use windows::{lock, unlock};

#[cfg(not(any(target_family = "unix", target_family = "windows")))]
mod other {
pub(crate) fn lock(file: &tokio::fs::File) -> i32 {
0
}
pub(crate) fn unlock(file: &tokio::fs::File) -> i32 {
0
}
}
#[cfg(not(any(target_family = "unix", target_family = "windows")))]
use other::{lock, unlock};

#[derive(Debug, Error)]
/// All errors the API can throw
pub enum ApiError {
Expand Down

0 comments on commit 08fc326

Please sign in to comment.