Skip to content

Commit

Permalink
Update http_client, use generics
Browse files Browse the repository at this point in the history
  • Loading branch information
sugyan committed Aug 23, 2024
1 parent c0b115c commit 8e7b58b
Show file tree
Hide file tree
Showing 14 changed files with 340 additions and 217 deletions.
66 changes: 2 additions & 64 deletions atrium-oauth/oauth-client/src/http_client.rs
Original file line number Diff line number Diff line change
@@ -1,65 +1,3 @@
pub mod dpop;

use atrium_xrpc::HttpClient;
#[cfg(feature = "default-client")]
use reqwest::Client;
use std::sync::{Arc, OnceLock};

static HTTP_CLIENT: OnceLock<Arc<dyn HttpClient + Send + Sync + 'static>> = OnceLock::new();

#[allow(dead_code)]
pub fn set_http_client(
client: impl HttpClient + Send + Sync + 'static,
) -> Result<(), Arc<dyn HttpClient + Send + Sync + 'static>> {
HTTP_CLIENT.set(Arc::new(client))
}

pub fn get_http_client() -> Arc<dyn HttpClient + Send + Sync + 'static> {
HTTP_CLIENT.get_or_init(get_default_client).clone()
}

#[cfg(feature = "default-client")]
fn get_default_client() -> Arc<dyn HttpClient + Send + Sync + 'static> {
Arc::new(ReqwestClient::default())
}

#[cfg(not(feature = "default-client"))]
fn get_default_client() -> Arc<dyn HttpClient + Send + Sync + 'static> {
panic!("no default client available")
}

#[cfg(feature = "default-client")]
struct ReqwestClient {
client: Client,
}

#[cfg(feature = "default-client")]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl HttpClient for ReqwestClient {
async fn send_http(
&self,
request: atrium_xrpc::http::Request<Vec<u8>>,
) -> core::result::Result<
atrium_xrpc::http::Response<Vec<u8>>,
Box<dyn std::error::Error + Send + Sync + 'static>,
> {
let response = self.client.execute(request.try_into()?).await?;
let mut builder = atrium_xrpc::http::Response::builder().status(response.status());
for (k, v) in response.headers() {
builder = builder.header(k, v);
}
builder
.body(response.bytes().await?.to_vec())
.map_err(Into::into)
}
}

#[cfg(feature = "default-client")]
impl Default for ReqwestClient {
fn default() -> Self {
Self {
client: reqwest::Client::new(),
}
}
}
pub mod default;
pub mod dpop;
35 changes: 35 additions & 0 deletions atrium-oauth/oauth-client/src/http_client/default.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use atrium_xrpc::HttpClient;
use reqwest::Client;

pub struct DefaultHttpClient {
client: Client,
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl HttpClient for DefaultHttpClient {
async fn send_http(
&self,
request: atrium_xrpc::http::Request<Vec<u8>>,
) -> core::result::Result<
atrium_xrpc::http::Response<Vec<u8>>,
Box<dyn std::error::Error + Send + Sync + 'static>,
> {
let response = self.client.execute(request.try_into()?).await?;
let mut builder = atrium_xrpc::http::Response::builder().status(response.status());
for (k, v) in response.headers() {
builder = builder.header(k, v);
}
builder
.body(response.bytes().await?.to_vec())
.map_err(Into::into)
}
}

impl Default for DefaultHttpClient {
fn default() -> Self {
Self {
client: reqwest::Client::new(),
}
}
}
30 changes: 22 additions & 8 deletions atrium-oauth/oauth-client/src/http_client/dpop.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use super::get_http_client;
use crate::store::memory::MemorySimpleStore;
use crate::store::SimpleStore;
use crate::utils::get_random_values;
Expand All @@ -18,6 +17,7 @@ use elliptic_curve::{
};
use rand::rngs::ThreadRng;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;

#[derive(Error, Debug)]
Expand Down Expand Up @@ -61,17 +61,23 @@ struct JwtClaims {
nonce: Option<String>,
}

pub struct DpopClient<S = MemorySimpleStore<String, String>>
pub struct DpopClient<T, S = MemorySimpleStore<String, String>>
where
S: SimpleStore<String, String>,
{
inner: Arc<T>,
key: JwkEcKey,
iss: String,
nonces: S,
}

impl DpopClient {
pub fn new(key: JwkEcKey, iss: String, supported_algs: Option<Vec<String>>) -> Result<Self> {
impl<T> DpopClient<T> {
pub fn new(
key: JwkEcKey,
iss: String,
supported_algs: Option<Vec<String>>,
http_client: Arc<T>,
) -> Result<Self> {
if let Some(algs) = supported_algs {
let alg = String::from(match key.crv() {
k256::Secp256k1::CRV => "ES256K",
Expand All @@ -83,7 +89,12 @@ impl DpopClient {
}
}
let nonces = MemorySimpleStore::<String, String>::default();
Ok(Self { key, iss, nonces })
Ok(Self {
inner: http_client,
key,
iss,
nonces,
})
}
fn build_proof(&self, htm: String, htu: String, nonce: Option<String>) -> Result<String> {
Ok(match self.key.crv() {
Expand Down Expand Up @@ -155,7 +166,10 @@ impl DpopClient {

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl HttpClient for DpopClient {
impl<T> HttpClient for DpopClient<T>
where
T: HttpClient + Send + Sync + 'static,
{
async fn send_http(
&self,
mut request: Request<Vec<u8>>,
Expand All @@ -169,7 +183,7 @@ impl HttpClient for DpopClient {
let init_nonce = self.nonces.get(&nonce_key).await?;
let init_proof = self.build_proof(htm.clone(), htu.clone(), init_nonce.clone())?;
request.headers_mut().insert("DPoP", init_proof.parse()?);
let response = get_http_client().send_http(request.clone()).await?;
let response = self.inner.send_http(request.clone()).await?;

let next_nonce = response
.headers()
Expand All @@ -193,7 +207,7 @@ impl HttpClient for DpopClient {
}
let next_proof = self.build_proof(htm, htu, next_nonce)?;
request.headers_mut().insert("DPoP", next_proof.parse()?);
let response = get_http_client().send_http(request).await?;
let response = self.inner.send_http(request).await?;
Ok(response)
}
}
111 changes: 94 additions & 17 deletions atrium-oauth/oauth-client/src/oauth_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::types::{
OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters,
};
use crate::utils::get_random_values;
use atrium_xrpc::HttpClient;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use elliptic_curve::JwkEcKey;
Expand All @@ -18,6 +19,7 @@ use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::sync::Arc;

#[cfg(feature = "default-client")]
pub struct OAuthClientConfig<S>
where
S: StateStore,
Expand All @@ -31,36 +33,100 @@ where
pub plc_directory_url: Option<String>,
}

pub struct OAuthClient<S>
#[cfg(not(feature = "default-client"))]
pub struct OAuthClientConfig<S, T>
where
S: StateStore,
T: HttpClient + Send + Sync + 'static,
{
resolver: Arc<OAuthResolver>,
// Config
pub client_metadata: ClientMetadata,
// Stores
pub state_store: S,
// Services
pub handle_resolver: HandleResolverConfig,
pub plc_directory_url: Option<String>,
// Others
pub http_client: T,
}

pub struct OAuthClient<S, T>
where
S: StateStore,
T: HttpClient + Send + Sync + 'static,
{
resolver: Arc<OAuthResolver<T>>,
client_metadata: OAuthClientMetadata,
state_store: S,
http_client: Arc<T>,
}

impl<S> OAuthClient<S>
#[cfg(feature = "default-client")]
impl<S> OAuthClient<S, crate::http_client::default::DefaultHttpClient>
where
S: StateStore,
{
pub fn new(config: OAuthClientConfig<S>) -> Result<Self> {
// TODO: validate client metadata
let client_metadata = config.client_metadata.validate()?;
let http_client = Arc::new(crate::http_client::default::DefaultHttpClient::default());
Ok(Self {
resolver: Arc::new(OAuthResolver::new(
IdentityResolver::new(
Arc::new(
CommonResolver::new(CommonResolverConfig {
plc_directory_url: config.plc_directory_url,
http_client: http_client.clone(),
})
.map_err(|e| Error::Resolver(crate::resolver::Error::DidResolver(e)))?,
),
handle_resolver(config.handle_resolver, http_client.clone()),
),
http_client.clone(),
)),
client_metadata,
state_store: config.state_store,
http_client,
})
}
}

#[cfg(not(feature = "default-client"))]
impl<S, T> OAuthClient<S, T>
where
S: StateStore,
T: HttpClient + Send + Sync + 'static,
{
pub fn new(config: OAuthClientConfig<S, T>) -> Result<Self> {
// TODO: validate client metadata
let client_metadata = config.client_metadata.validate()?;
let http_client = Arc::new(config.http_client);
Ok(Self {
resolver: Arc::new(OAuthResolver::new(IdentityResolver::new(
Arc::new(
CommonResolver::new(CommonResolverConfig {
plc_directory_url: config.plc_directory_url,
})
.map_err(|e| Error::Resolver(crate::resolver::Error::DidResolver(e)))?,
resolver: Arc::new(OAuthResolver::new(
IdentityResolver::new(
Arc::new(
CommonResolver::new(CommonResolverConfig {
plc_directory_url: config.plc_directory_url,
http_client: http_client.clone(),
})
.map_err(|e| Error::Resolver(crate::resolver::Error::DidResolver(e)))?,
),
handle_resolver(config.handle_resolver, http_client.clone()),
),
Self::handle_resolver(config.handle_resolver),
))),
http_client.clone(),
)),
client_metadata,
state_store: config.state_store,
http_client,
})
}
}

impl<S, T> OAuthClient<S, T>
where
S: StateStore,
T: HttpClient + Send + Sync + 'static,
{
pub async fn authorize(&mut self, input: impl AsRef<str>) -> Result<String> {
let redirect_uri = {
// TODO: use options.redirect_uri
Expand Down Expand Up @@ -112,6 +178,7 @@ where
metadata.clone(),
self.client_metadata.clone(),
self.resolver.clone(),
self.http_client.clone(),
)?;
let par_response = server
.request::<OAuthPusehedAuthorizationRequestResponse>(
Expand Down Expand Up @@ -186,18 +253,13 @@ where
metadata.clone(),
self.client_metadata.clone(),
self.resolver.clone(),
self.http_client.clone(),
)?;
let token_set = server.exchange_code(&params.code, &state.verifier).await?;
// TODO: verify id_token?
println!("{token_set:?}",);
Ok(())
}
fn handle_resolver(handle_resolver_config: HandleResolverConfig) -> Arc<dyn HandleResolver> {
match handle_resolver_config {
HandleResolverConfig::AppView(uri) => Arc::new(AppViewResolver::new(uri)),
HandleResolverConfig::Service(service) => service,
}
}
fn generate_key(mut algs: Vec<String>) -> Option<JwkEcKey> {
// 256K > ES (256 > 384 > 512) > PS (256 > 384 > 512) > RS (256 > 384 > 512) > other (in original order)
fn compare_algos(a: &String, b: &String) -> std::cmp::Ordering {
Expand Down Expand Up @@ -242,3 +304,18 @@ where
URL_SAFE_NO_PAD.encode(get_random_values::<_, 16>(&mut ThreadRng::default()))
}
}

fn handle_resolver<T>(
handle_resolver_config: HandleResolverConfig,
http_client: Arc<T>,
) -> Arc<dyn HandleResolver>
where
T: HttpClient + Send + Sync + 'static,
{
match handle_resolver_config {
HandleResolverConfig::AppView(uri) => {
Arc::new(AppViewResolver::new(uri, http_client.clone()))

Check warning on line 317 in atrium-oauth/oauth-client/src/oauth_client.rs

View workflow job for this annotation

GitHub Actions / clippy

[clippy] atrium-oauth/oauth-client/src/oauth_client.rs#L317

warning: redundant clone --> atrium-oauth/oauth-client/src/oauth_client.rs:317:59 | 317 | Arc::new(AppViewResolver::new(uri, http_client.clone())) | ^^^^^^^^ help: remove this | note: this value is dropped without further use --> atrium-oauth/oauth-client/src/oauth_client.rs:317:48 | 317 | Arc::new(AppViewResolver::new(uri, http_client.clone())) | ^^^^^^^^^^^ = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#redundant_clone = note: `#[warn(clippy::redundant_clone)]` on by default
Raw output
atrium-oauth/oauth-client/src/oauth_client.rs:317:59:w:warning: redundant clone
   --> atrium-oauth/oauth-client/src/oauth_client.rs:317:59
    |
317 |             Arc::new(AppViewResolver::new(uri, http_client.clone()))
    |                                                           ^^^^^^^^ help: remove this
    |
note: this value is dropped without further use
   --> atrium-oauth/oauth-client/src/oauth_client.rs:317:48
    |
317 |             Arc::new(AppViewResolver::new(uri, http_client.clone()))
    |                                                ^^^^^^^^^^^
    = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#redundant_clone
    = note: `#[warn(clippy::redundant_clone)]` on by default


__END__
}
HandleResolverConfig::Service(service) => service,
}
}
Loading

0 comments on commit 8e7b58b

Please sign in to comment.