Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use relative redirects #36

Merged
merged 2 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,23 @@ It aims to be compatible with [huggingface_hub](https://github.com/huggingface/h

[dependencies]
dirs = "5.0.1"
rand = {version = "0.8.5", optional = true}
http = { version = "1.0.0", optional = true }
rand = { version = "0.8.5", optional = true }
reqwest = { version = "0.11.18", optional = true, features = ["json"] }
serde = { version = "1.0.171", features = ["derive"] , optional = true}
serde_json = {version="1.0.103", optional = true}
indicatif = { version = "0.17.5", optional = true}
num_cpus = { version = "1.15.0", optional = true}
serde = { version = "1.0.171", features = ["derive"], optional = true }
serde_json = { version = "1.0.103", optional = true }
indicatif = { version = "0.17.5", optional = true }
num_cpus = { version = "1.15.0", optional = true }
tokio = { version = "1.29.1", optional = true, features = ["fs", "macros"] }
futures = { version = "0.3.28", optional = true}
futures = { version = "0.3.28", optional = true }
thiserror = { version = "1.0.43", optional = true }
ureq = { version = "2.8.0", optional = true, features = ["native-tls", "json", "socks-proxy"] }
native-tls = { version = "0.2.11", optional = true }
log = "0.4.19"

[features]
default = ["online"]
online = ["dep:ureq", "dep:native-tls", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:thiserror"]
online = ["dep:ureq", "dep:native-tls", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:thiserror", "dep:http"]
tokio = ["dep:reqwest", "dep:tokio", "tokio/rt-multi-thread", "dep:futures", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:num_cpus", "dep:thiserror"]

[dev-dependencies]
Expand Down
91 changes: 77 additions & 14 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
use super::RepoInfo;
use crate::api::sync::ApiError::InvalidHeader;
use crate::{Cache, Repo, RepoType};
use http::{StatusCode, Uri};
use indicatif::{ProgressBar, ProgressStyle};
use std::collections::HashMap;
// use reqwest::{
// blocking::Agent,
// header::{
// HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION,
// CONTENT_RANGE, LOCATION, RANGE, USER_AGENT,
// },
// redirect::Policy,
// Error as ReqwestError,
// };
use super::RepoInfo;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use std::str::FromStr;
use thiserror::Error;
use ureq::{Agent, Request};

Expand All @@ -26,6 +20,7 @@ const CONTENT_RANGE: &str = "Content-Range";
const LOCATION: &str = "Location";
const USER_AGENT: &str = "User-Agent";
const AUTHORIZATION: &str = "Authorization";

type HeaderMap = HashMap<&'static str, String>;
type HeaderName = &'static str;

Expand Down Expand Up @@ -270,12 +265,57 @@ impl Api {
}

fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
let response = self
let mut response = self
.no_redirect_client
.get(url)
.set(RANGE, "bytes=0-0")
.call()
.map_err(Box::new)?;

// Closure to check if status code is a redirection
let should_redirect = |status_code: u16| {
matches!(
StatusCode::from_u16(status_code).unwrap(),
StatusCode::MOVED_PERMANENTLY
| StatusCode::FOUND
| StatusCode::SEE_OTHER
| StatusCode::TEMPORARY_REDIRECT
| StatusCode::PERMANENT_REDIRECT
)
};

// Follow redirects until `host.is_some()` i.e. only follow relative redirects
// See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411
let response = loop {
// Check if redirect
if should_redirect(response.status()) {
// Get redirect location
if let Some(location) = response.header("Location") {
// Parse location
let uri = Uri::from_str(location).map_err(|_| InvalidHeader("location"))?;

// Check if relative i.e. host is none
if uri.host().is_none() {
// Merge relative path with url
let mut parts = Uri::from_str(url).unwrap().into_parts();
parts.path_and_query = uri.into_parts().path_and_query;
// Final uri
let redirect_uri = Uri::from_parts(parts).unwrap();

// Follow redirect
response = self
.no_redirect_client
.get(&redirect_uri.to_string())
.set(RANGE, "bytes=0-0")
.call()
.map_err(Box::new)?;
continue;
}
};
}
break response;
};

// let headers = response.headers();
let header_commit = "x-repo-commit";
let header_linked_etag = "x-linked-etag";
Expand Down Expand Up @@ -459,7 +499,7 @@ impl ApiRepo {
ProgressStyle::with_template(
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
)
.unwrap(), // .progress_chars("━ "),
.unwrap(), // .progress_chars("━ "),
);
let maxlength = 30;
let message = if filename.len() > maxlength {
Expand Down Expand Up @@ -605,6 +645,28 @@ mod tests {
)
}

#[test]
fn models() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(
"BAAI/bGe-reRanker-Base".to_string(),
RepoType::Model,
"refs/pr/5".to_string(),
);
let downloaded_path = api.repo(repo).download("tokenizer.json").unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("9EB652AC4E40CC093272BBBE0F55D521CF67570060227109B5CDC20945A4489E")
)
}

#[test]
fn info() {
let tmp = TempDir::new();
Expand Down Expand Up @@ -703,9 +765,9 @@ mod tests {
},
Siblings {
rfilename: "wikitext-2-v1/validation/index.duckdb".to_string()
}
},
],
sha: "f23dc6c07c427c9908f56bdb9829b0a767578ee5".to_string()
sha: "3acdf8c72a4dd61d76f34d7b54ee2a5b088ea3b1".to_string(),
}
)
}
Expand Down Expand Up @@ -738,6 +800,7 @@ mod tests {
"_id": "621ffdc136468d709f17ddb4",
"author": "mcpotato",
"config": {},
"createdAt": "2022-03-02T23:29:05.000Z",
"disabled": false,
"downloads": 0,
"gated": false,
Expand Down
60 changes: 51 additions & 9 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,17 +159,36 @@ impl ApiBuilder {
pub fn build(self) -> Result<Api, ApiError> {
let headers = self.build_headers()?;
let client = Client::builder().default_headers(headers.clone()).build()?;
let no_redirect_client = Client::builder()
.redirect(Policy::none())

// Policy: only follow relative redirects
// See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411
let relative_redirect_policy = Policy::custom(|attempt| {
// Follow redirects up to a maximum of 10.
if attempt.previous().len() > 10 {
return attempt.error("too many redirects");
}

if let Some(last) = attempt.previous().last() {
// If the url is not relative
if last.make_relative(attempt.url()).is_none() {
return attempt.stop();
}
}

// Follow redirect
attempt.follow()
});

let relative_redirect_client = Client::builder()
.redirect(relative_redirect_policy)
.default_headers(headers)
.build()?;
Ok(Api {
endpoint: self.endpoint,
url_template: self.url_template,
cache: self.cache,
client,

no_redirect_client,
relative_redirect_client,
max_files: self.max_files,
chunk_size: self.chunk_size,
parallel_failures: self.parallel_failures,
Expand All @@ -195,7 +214,7 @@ pub struct Api {
url_template: String,
cache: Cache,
client: Client,
no_redirect_client: Client,
relative_redirect_client: Client,
max_files: usize,
chunk_size: usize,
parallel_failures: usize,
Expand Down Expand Up @@ -277,7 +296,7 @@ impl Api {

async fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
let response = self
.no_redirect_client
.relative_redirect_client
.get(url)
.header(RANGE, "bytes=0-0")
.send()
Expand Down Expand Up @@ -536,7 +555,7 @@ impl ApiRepo {
ProgressStyle::with_template(
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
)
.unwrap(), // .progress_chars("━ "),
.unwrap(), // .progress_chars("━ "),
);
let maxlength = 30;
let message = if filename.len() > maxlength {
Expand Down Expand Up @@ -706,6 +725,28 @@ mod tests {
)
}

#[tokio::test]
async fn models() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(
"BAAI/bGe-reRanker-Base".to_string(),
RepoType::Model,
"refs/pr/5".to_string(),
);
let downloaded_path = api.repo(repo).download("tokenizer.json").await.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("9EB652AC4E40CC093272BBBE0F55D521CF67570060227109B5CDC20945A4489E")
)
}

#[tokio::test]
async fn info() {
let tmp = TempDir::new();
Expand Down Expand Up @@ -804,9 +845,9 @@ mod tests {
},
Siblings {
rfilename: "wikitext-2-v1/validation/index.duckdb".to_string()
}
},
],
sha: "f23dc6c07c427c9908f56bdb9829b0a767578ee5".to_string()
sha: "3acdf8c72a4dd61d76f34d7b54ee2a5b088ea3b1".to_string(),
}
)
}
Expand Down Expand Up @@ -841,6 +882,7 @@ mod tests {
"_id": "621ffdc136468d709f17ddb4",
"author": "mcpotato",
"config": {},
"createdAt": "2022-03-02T23:29:05.000Z",
"disabled": false,
"downloads": 0,
"gated": false,
Expand Down