diff --git a/Cargo.toml b/Cargo.toml index 9806502..cbbb8f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,14 +23,15 @@ It aims to be compatible with [huggingface_hub](https://github.com/huggingface/h [dependencies] dirs = "5.0.1" -rand = {version = "0.8.5", optional = true} +http = { version = "1.0.0", optional = true } +rand = { version = "0.8.5", optional = true } reqwest = { version = "0.11.18", optional = true, features = ["json"] } -serde = { version = "1.0.171", features = ["derive"] , optional = true} -serde_json = {version="1.0.103", optional = true} -indicatif = { version = "0.17.5", optional = true} -num_cpus = { version = "1.15.0", optional = true} +serde = { version = "1.0.171", features = ["derive"], optional = true } +serde_json = { version = "1.0.103", optional = true } +indicatif = { version = "0.17.5", optional = true } +num_cpus = { version = "1.15.0", optional = true } tokio = { version = "1.29.1", optional = true, features = ["fs", "macros"] } -futures = { version = "0.3.28", optional = true} +futures = { version = "0.3.28", optional = true } thiserror = { version = "1.0.43", optional = true } ureq = { version = "2.8.0", optional = true, features = ["native-tls", "json", "socks-proxy"] } native-tls = { version = "0.2.11", optional = true } @@ -38,7 +39,7 @@ log = "0.4.19" [features] default = ["online"] -online = ["dep:ureq", "dep:native-tls", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:thiserror"] +online = ["dep:ureq", "dep:native-tls", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:thiserror", "dep:http"] tokio = ["dep:reqwest", "dep:tokio", "tokio/rt-multi-thread", "dep:futures", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:num_cpus", "dep:thiserror"] [dev-dependencies] diff --git a/src/api/sync.rs b/src/api/sync.rs index 78e1f5d..bd70d64 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -1,18 +1,12 @@ +use super::RepoInfo; +use crate::api::sync::ApiError::InvalidHeader; use crate::{Cache, Repo, RepoType}; +use http::{StatusCode, Uri}; use indicatif::{ProgressBar, ProgressStyle}; use std::collections::HashMap; -// use reqwest::{ -// blocking::Agent, -// header::{ -// HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION, -// CONTENT_RANGE, LOCATION, RANGE, USER_AGENT, -// }, -// redirect::Policy, -// Error as ReqwestError, -// }; -use super::RepoInfo; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; +use std::str::FromStr; use thiserror::Error; use ureq::{Agent, Request}; @@ -26,6 +20,7 @@ const CONTENT_RANGE: &str = "Content-Range"; const LOCATION: &str = "Location"; const USER_AGENT: &str = "User-Agent"; const AUTHORIZATION: &str = "Authorization"; + type HeaderMap = HashMap<&'static str, String>; type HeaderName = &'static str; @@ -270,12 +265,57 @@ impl Api { } fn metadata(&self, url: &str) -> Result { - let response = self + let mut response = self .no_redirect_client .get(url) .set(RANGE, "bytes=0-0") .call() .map_err(Box::new)?; + + // Closure to check if status code is a redirection + let should_redirect = |status_code: u16| { + matches!( + StatusCode::from_u16(status_code).unwrap(), + StatusCode::MOVED_PERMANENTLY + | StatusCode::FOUND + | StatusCode::SEE_OTHER + | StatusCode::TEMPORARY_REDIRECT + | StatusCode::PERMANENT_REDIRECT + ) + }; + + // Follow redirects until `host.is_some()` i.e. only follow relative redirects + // See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411 + let response = loop { + // Check if redirect + if should_redirect(response.status()) { + // Get redirect location + if let Some(location) = response.header("Location") { + // Parse location + let uri = Uri::from_str(location).map_err(|_| InvalidHeader("location"))?; + + // Check if relative i.e. host is none + if uri.host().is_none() { + // Merge relative path with url + let mut parts = Uri::from_str(url).unwrap().into_parts(); + parts.path_and_query = uri.into_parts().path_and_query; + // Final uri + let redirect_uri = Uri::from_parts(parts).unwrap(); + + // Follow redirect + response = self + .no_redirect_client + .get(&redirect_uri.to_string()) + .set(RANGE, "bytes=0-0") + .call() + .map_err(Box::new)?; + continue; + } + }; + } + break response; + }; + // let headers = response.headers(); let header_commit = "x-repo-commit"; let header_linked_etag = "x-linked-etag"; @@ -459,7 +499,7 @@ impl ApiRepo { ProgressStyle::with_template( "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", ) - .unwrap(), // .progress_chars("━ "), + .unwrap(), // .progress_chars("━ "), ); let maxlength = 30; let message = if filename.len() > maxlength { @@ -605,6 +645,28 @@ mod tests { ) } + #[test] + fn models() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp.path.clone()) + .build() + .unwrap(); + let repo = Repo::with_revision( + "BAAI/bGe-reRanker-Base".to_string(), + RepoType::Model, + "refs/pr/5".to_string(), + ); + let downloaded_path = api.repo(repo).download("tokenizer.json").unwrap(); + assert!(downloaded_path.exists()); + let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); + assert_eq!( + val[..], + hex!("9EB652AC4E40CC093272BBBE0F55D521CF67570060227109B5CDC20945A4489E") + ) + } + #[test] fn info() { let tmp = TempDir::new(); @@ -703,9 +765,9 @@ mod tests { }, Siblings { rfilename: "wikitext-2-v1/validation/index.duckdb".to_string() - } + }, ], - sha: "f23dc6c07c427c9908f56bdb9829b0a767578ee5".to_string() + sha: "3acdf8c72a4dd61d76f34d7b54ee2a5b088ea3b1".to_string(), } ) } @@ -738,6 +800,7 @@ mod tests { "_id": "621ffdc136468d709f17ddb4", "author": "mcpotato", "config": {}, + "createdAt": "2022-03-02T23:29:05.000Z", "disabled": false, "downloads": 0, "gated": false, diff --git a/src/api/tokio.rs b/src/api/tokio.rs index 9d6a772..97e1be0 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -159,8 +159,28 @@ impl ApiBuilder { pub fn build(self) -> Result { let headers = self.build_headers()?; let client = Client::builder().default_headers(headers.clone()).build()?; - let no_redirect_client = Client::builder() - .redirect(Policy::none()) + + // Policy: only follow relative redirects + // See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411 + let relative_redirect_policy = Policy::custom(|attempt| { + // Follow redirects up to a maximum of 10. + if attempt.previous().len() > 10 { + return attempt.error("too many redirects"); + } + + if let Some(last) = attempt.previous().last() { + // If the url is not relative + if last.make_relative(attempt.url()).is_none() { + return attempt.stop(); + } + } + + // Follow redirect + attempt.follow() + }); + + let relative_redirect_client = Client::builder() + .redirect(relative_redirect_policy) .default_headers(headers) .build()?; Ok(Api { @@ -168,8 +188,7 @@ impl ApiBuilder { url_template: self.url_template, cache: self.cache, client, - - no_redirect_client, + relative_redirect_client, max_files: self.max_files, chunk_size: self.chunk_size, parallel_failures: self.parallel_failures, @@ -195,7 +214,7 @@ pub struct Api { url_template: String, cache: Cache, client: Client, - no_redirect_client: Client, + relative_redirect_client: Client, max_files: usize, chunk_size: usize, parallel_failures: usize, @@ -277,7 +296,7 @@ impl Api { async fn metadata(&self, url: &str) -> Result { let response = self - .no_redirect_client + .relative_redirect_client .get(url) .header(RANGE, "bytes=0-0") .send() @@ -536,7 +555,7 @@ impl ApiRepo { ProgressStyle::with_template( "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", ) - .unwrap(), // .progress_chars("━ "), + .unwrap(), // .progress_chars("━ "), ); let maxlength = 30; let message = if filename.len() > maxlength { @@ -706,6 +725,28 @@ mod tests { ) } + #[tokio::test] + async fn models() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp.path.clone()) + .build() + .unwrap(); + let repo = Repo::with_revision( + "BAAI/bGe-reRanker-Base".to_string(), + RepoType::Model, + "refs/pr/5".to_string(), + ); + let downloaded_path = api.repo(repo).download("tokenizer.json").await.unwrap(); + assert!(downloaded_path.exists()); + let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); + assert_eq!( + val[..], + hex!("9EB652AC4E40CC093272BBBE0F55D521CF67570060227109B5CDC20945A4489E") + ) + } + #[tokio::test] async fn info() { let tmp = TempDir::new(); @@ -804,9 +845,9 @@ mod tests { }, Siblings { rfilename: "wikitext-2-v1/validation/index.duckdb".to_string() - } + }, ], - sha: "f23dc6c07c427c9908f56bdb9829b0a767578ee5".to_string() + sha: "3acdf8c72a4dd61d76f34d7b54ee2a5b088ea3b1".to_string(), } ) } @@ -841,6 +882,7 @@ mod tests { "_id": "621ffdc136468d709f17ddb4", "author": "mcpotato", "config": {}, + "createdAt": "2022-03-02T23:29:05.000Z", "disabled": false, "downloads": 0, "gated": false,