Skip to content

Commit

Permalink
Tmp.
Browse files Browse the repository at this point in the history
Custom progressbar.

Allow internal state.

More API.

Clippy.

Remove print statements

Using stream feature for shorter lived progressbars.

Moved muliplexing behind less obvious builder (needs more testing to
showcase pros and cons).
  • Loading branch information
Narsil committed Dec 27, 2024
1 parent 206b344 commit b0a2ed0
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 108 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ num_cpus = { version = "1.15.0", optional = true }
rand = { version = "0.8.5", optional = true }
reqwest = { version = "0.12.2", optional = true, default-features = false, features = [
"json",
"stream",
] }
rustls = { version = "0.23.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
Expand All @@ -41,7 +42,7 @@ native-tls = { version = "0.2.12", optional = true }
[features]
default = ["default-tls", "tokio", "ureq"]
# These features are only relevant when used with the `tokio` feature, but this might change in the future.
default-tls = []
default-tls = ["native-tls"]
native-tls = ["dep:reqwest", "reqwest/default", "dep:native-tls", "ureq/native-tls"]
rustls-tls = ["dep:rustls", "reqwest/rustls-tls"]
tokio = [
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ let _filename = repo.get("config.json").unwrap();

# SSL/TLS

This library uses its dependencies' default TLS implementations which are `rustls` for `ureq` (sync) and `native-tls` (openssl) for `tokio`.
This library uses tokio default TLS implementations which is `native-tls` (openssl) for `tokio`.

If you want control over the TLS backend you can remove the default features and only add the backend you are intending to use.

Expand Down
47 changes: 47 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use indicatif::{ProgressBar, ProgressStyle};
use serde::Deserialize;

/// The asynchronous version of the API
Expand All @@ -8,6 +9,52 @@ pub mod tokio;
#[cfg(feature = "ureq")]
pub mod sync;

/// This trait is used by users of the lib
/// to implement custom behavior during file downloads
pub trait Progress {
/// At the start of the download
/// The size is the total size in bytes of the file.
fn init(&mut self, size: usize, filename: &str);
/// This function is called whenever `size` bytes have been
/// downloaded in the temporary file
fn update(&mut self, size: usize);
/// This is called at the end of the download
fn finish(&mut self);
}

impl Progress for () {
fn init(&mut self, _size: usize, _filename: &str) {}
fn update(&mut self, _size: usize) {}
fn finish(&mut self) {}
}

impl Progress for ProgressBar {
fn init(&mut self, size: usize, filename: &str) {
self.set_length(size as u64);
self.set_style(
ProgressStyle::with_template(
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
)
.unwrap(), // .progress_chars("━ "),
);
let maxlength = 30;
let message = if filename.len() > maxlength {
format!("..{}", &filename[filename.len() - maxlength..])
} else {
filename.to_string()
};
self.set_message(message);
}

fn update(&mut self, size: usize) {
self.inc(size as u64)
}

fn finish(&mut self) {
ProgressBar::finish(self);
}
}

/// Siblings are simplified file descriptions of remote files on the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct Siblings {
Expand Down
137 changes: 85 additions & 52 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use super::RepoInfo;
use crate::api::sync::ApiError::InvalidHeader;
use crate::api::Progress;
use crate::{Cache, Repo, RepoType};
use http::{StatusCode, Uri};
use indicatif::{ProgressBar, ProgressStyle};
use indicatif::ProgressBar;
use rand::Rng;
use std::collections::HashMap;
use std::io::Read;
use std::io::Seek;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
Expand All @@ -26,6 +28,23 @@ const AUTHORIZATION: &str = "Authorization";
type HeaderMap = HashMap<&'static str, String>;
type HeaderName = &'static str;

struct Wrapper<P: Progress, R: Read> {
progress: P,
inner: R,
}

fn wrap_read<P: Progress, R: Read>(inner: R, progress: P) -> Wrapper<P, R> {
Wrapper { inner, progress }
}

impl<P: Progress, R: Read> Read for Wrapper<P, R> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let read = self.inner.read(buf)?;
self.progress.update(read);
Ok(read)
}
}

/// Simple wrapper over [`ureq::Agent`] to include default headers
#[derive(Clone, Debug)]
pub struct HeaderAgent {
Expand Down Expand Up @@ -92,7 +111,6 @@ pub enum ApiError {
pub struct ApiBuilder {
endpoint: String,
cache: Cache,
url_template: String,
token: Option<String>,
max_retries: usize,
progress: bool,
Expand Down Expand Up @@ -128,12 +146,10 @@ impl ApiBuilder {
let max_retries = 0;
let progress = true;

let endpoint =
std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_owned());
let endpoint = "https://huggingface.co".to_string();

Self {
endpoint,
url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
cache,
token,
max_retries,
Expand Down Expand Up @@ -197,10 +213,8 @@ impl ApiBuilder {

Ok(Api {
endpoint: self.endpoint,
url_template: self.url_template,
cache: self.cache,
client,

no_redirect_client,
max_retries: self.max_retries,
progress: self.progress,
Expand All @@ -216,12 +230,10 @@ struct Metadata {
}

/// The actual Api used to interacto with the hub.
/// You can inspect repos with [`Api::info`]
/// or download files with [`Api::download`]
/// Use any repo with [`Api::repo`]
#[derive(Clone, Debug)]
pub struct Api {
endpoint: String,
url_template: String,
cache: Cache,
client: HeaderAgent,
no_redirect_client: HeaderAgent,
Expand Down Expand Up @@ -402,10 +414,10 @@ impl Api {
})
}

fn download_tempfile(
fn download_tempfile<P: Progress + Sync + Send>(
&self,
url: &str,
progressbar: Option<ProgressBar>,
progress: P,
) -> Result<PathBuf, ApiError> {
let filename = self.cache.temp_path();

Expand Down Expand Up @@ -434,16 +446,16 @@ impl Api {

let response = self.client.get(url).call().map_err(Box::new)?;

let mut reader = response.into_reader();
if let Some(p) = &progressbar {
reader = Box::new(p.wrap_read(reader));
}
let reader = response.into_reader();
//if let Some(p) = &progressbar {
let mut reader = Box::new(wrap_read(reader, progress));
//}

std::io::copy(&mut reader, &mut file)?;

if let Some(p) = progressbar {
p.finish();
}
//if let Some(p) = progressbar {
reader.progress.finish();
//}
Ok(filename)
}

Expand Down Expand Up @@ -506,6 +518,8 @@ impl Api {
}

/// Shorthand for accessing things within a particular repo
/// You can inspect repos with [`ApiRepo::info`]
/// or download files with [`ApiRepo::download`]
#[derive(Debug)]
pub struct ApiRepo {
api: Api,
Expand Down Expand Up @@ -541,12 +555,8 @@ impl ApiRepo {
pub fn url(&self, filename: &str) -> String {
let endpoint = &self.api.endpoint;
let revision = &self.repo.url_revision();
self.api
.url_template
.replace("{endpoint}", endpoint)
.replace("{repo_id}", &self.repo.url())
.replace("{revision}", revision)
.replace("{filename}", filename)
let repo_id = self.repo.url();
format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}")
}

/// This will attempt the fetch the file locally first, then [`Api.download`]
Expand All @@ -563,16 +573,40 @@ impl ApiRepo {
}
}

/// Downloads a remote file (if not already present) into the cache directory
/// to be used locally.
/// This functions require internet access to verify if new versions of the file
/// exist, even if a file is already on disk at location.
/// This function is used to download a file with a custom progress function.
/// It uses the [`Progress`] trait and can be used in more complex use
/// cases like downloading a showing progress in a UI.
/// ```no_run
/// # use hf_hub::api::sync::Api;
/// # use hf_hub::api::{sync::Api, Progress};
/// struct MyProgress{
/// current: usize,
/// total: usize
/// }
///
/// impl Progress for MyProgress{
/// fn init(&mut self, size: usize, _filename: &str){
/// self.total = size;
/// self.current = 0;
/// }
///
/// fn update(&mut self, size: usize){
/// self.current += size;
/// println!("{}/{}", self.current, self.total)
/// }
///
/// fn finish(&mut self){
/// println!("Done !");
/// }
/// }
/// let api = Api::new().unwrap();
/// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap();
/// let progress = MyProgress{current: 0, total: 0};
/// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).unwrap();
/// ```
pub fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
pub fn download_with_progress<P: Progress + Send + Sync>(
&self,
filename: &str,
mut progress: P,
) -> Result<PathBuf, ApiError> {
let url = self.url(filename);
let metadata = self.api.metadata(&url)?;

Expand All @@ -583,27 +617,9 @@ impl ApiRepo {
.blob_path(&metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

let progressbar = if self.api.progress {
let progress = ProgressBar::new(metadata.size as u64);
progress.set_style(
ProgressStyle::with_template(
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
)
.unwrap(), // .progress_chars("━ "),
);
let maxlength = 30;
let message = if filename.len() > maxlength {
format!("..{}", &filename[filename.len() - maxlength..])
} else {
filename.to_string()
};
progress.set_message(message);
Some(progress)
} else {
None
};
progress.init(metadata.size, filename);

let tmp_filename = self.api.download_tempfile(&url, progressbar)?;
let tmp_filename = self.api.download_tempfile(&url, progress)?;

std::fs::rename(tmp_filename, &blob_path)?;

Expand All @@ -624,6 +640,23 @@ impl ApiRepo {
Ok(pointer_path)
}

/// Downloads a remote file (if not already present) into the cache directory
/// to be used locally.
/// This functions require internet access to verify if new versions of the file
/// exist, even if a file is already on disk at location.
/// ```no_run
/// # use hf_hub::api::sync::Api;
/// let api = Api::new().unwrap();
/// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap();
/// ```
pub fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
if self.api.progress {
self.download_with_progress(filename, ProgressBar::new(0))
} else {
self.download_with_progress(filename, ())
}
}

/// Get information about the Repo
/// ```
/// use hf_hub::{api::sync::Api};
Expand Down
Loading

0 comments on commit b0a2ed0

Please sign in to comment.