diff --git a/Cargo.lock b/Cargo.lock index f7ba5f88..fb95f17e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -259,6 +259,7 @@ dependencies = [ "sha2", "thiserror", "tokio", + "trait-variant", ] [[package]] diff --git a/atrium-oauth/identity/src/did.rs b/atrium-oauth/identity/src/did.rs index 5d433d6c..79621721 100644 --- a/atrium-oauth/identity/src/did.rs +++ b/atrium-oauth/identity/src/did.rs @@ -3,6 +3,7 @@ mod plc_resolver; mod web_resolver; pub use self::common_resolver::{CommonDidResolver, CommonDidResolverConfig}; +pub use self::plc_resolver::DEFAULT_PLC_DIRECTORY_URL; use crate::Resolver; use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; diff --git a/atrium-oauth/identity/src/did/common_resolver.rs b/atrium-oauth/identity/src/did/common_resolver.rs index 5ceb61a9..5d0934b2 100644 --- a/atrium-oauth/identity/src/did/common_resolver.rs +++ b/atrium-oauth/identity/src/did/common_resolver.rs @@ -42,10 +42,7 @@ where type Output = DidDocument; async fn resolve(&self, did: &Self::Input) -> Result { - match did - .strip_prefix("did:") - .and_then(|s| s.split_once(':').and_then(|(method, _)| Some(method))) - { + match did.strip_prefix("did:").and_then(|s| s.split_once(':').map(|(method, _)| method)) { Some("plc") => self.plc_resolver.resolve(did).await, Some("web") => self.web_resolver.resolve(did).await, _ => Err(Error::UnsupportedDidMethod(did.clone())), diff --git a/atrium-oauth/identity/src/did/plc_resolver.rs b/atrium-oauth/identity/src/did/plc_resolver.rs index f30c8ae6..5f8dc1e7 100644 --- a/atrium-oauth/identity/src/did/plc_resolver.rs +++ b/atrium-oauth/identity/src/did/plc_resolver.rs @@ -8,7 +8,6 @@ use atrium_xrpc::http::{Request, Uri}; use atrium_xrpc::HttpClient; use std::sync::Arc; -#[allow(dead_code)] pub const DEFAULT_PLC_DIRECTORY_URL: &str = "https://plc.directory/"; #[derive(Clone, Debug)] diff --git a/atrium-oauth/identity/src/identity_resolver.rs b/atrium-oauth/identity/src/identity_resolver.rs index 3dcb3646..e8244bce 100644 --- a/atrium-oauth/identity/src/identity_resolver.rs +++ b/atrium-oauth/identity/src/identity_resolver.rs @@ -1,7 +1,5 @@ -use crate::did::DidResolver; use crate::error::{Error, Result}; -use crate::handle::HandleResolver; -use crate::Resolver; +use crate::{did::DidResolver, handle::HandleResolver, Resolver}; use atrium_api::types::string::AtIdentifier; use serde::{Deserialize, Serialize}; diff --git a/atrium-oauth/oauth-client/Cargo.toml b/atrium-oauth/oauth-client/Cargo.toml index f74c7e5e..99a0f3db 100644 --- a/atrium-oauth/oauth-client/Cargo.toml +++ b/atrium-oauth/oauth-client/Cargo.toml @@ -31,6 +31,7 @@ serde_html_form.workspace = true serde_json.workspace = true sha2.workspace = true thiserror.workspace = true +trait-variant.workspace = true [dev-dependencies] hickory-resolver.workspace = true diff --git a/atrium-oauth/oauth-client/examples/main.rs b/atrium-oauth/oauth-client/examples/main.rs index 11683d47..40a91a9e 100644 --- a/atrium-oauth/oauth-client/examples/main.rs +++ b/atrium-oauth/oauth-client/examples/main.rs @@ -1,9 +1,9 @@ -use atrium_identity::handle::{DnsTxtResolver, HandleResolverImpl}; -use atrium_identity::identity_resolver::{DidResolverConfig, HandleResolverConfig}; +use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}; +use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver}; use atrium_oauth_client::store::state::MemoryStateStore; use atrium_oauth_client::{ - AtprotoLocalhostClientMetadata, AuthorizeOptions, OAuthClient, OAuthClientConfig, - OAuthResolverConfig, + AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, OAuthClient, + OAuthClientConfig, OAuthResolverConfig, }; use atrium_xrpc::http::Uri; use hickory_resolver::TokioAsyncResolver; @@ -23,35 +23,32 @@ impl Default for HickoryDnsTxtResolver { } } -#[async_trait::async_trait] impl DnsTxtResolver for HickoryDnsTxtResolver { async fn resolve( &self, query: &str, ) -> core::result::Result, Box> { - Ok(self - .resolver - .txt_lookup(query) - .await? - .iter() - .map(|txt| txt.to_string()) - .collect()) + Ok(self.resolver.txt_lookup(query).await?.iter().map(|txt| txt.to_string()).collect()) } } #[tokio::main] async fn main() -> Result<(), Box> { + let http_client = Arc::new(DefaultHttpClient::default()); let config = OAuthClientConfig { client_metadata: AtprotoLocalhostClientMetadata { redirect_uris: vec!["http://127.0.0.1".to_string()], }, keys: None, resolver: OAuthResolverConfig { - did: DidResolverConfig::default(), - handle: HandleResolverConfig { - r#impl: HandleResolverImpl::Atproto(Arc::new(HickoryDnsTxtResolver::default())), - cache: Default::default(), - }, + did_resolver: CommonDidResolver::new(CommonDidResolverConfig { + plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), + http_client: http_client.clone(), + }), + handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { + dns_txt_resolver: HickoryDnsTxtResolver::default(), + http_client: http_client.clone(), + }), authorization_server_metadata: Default::default(), protected_resource_metadata: Default::default(), }, @@ -81,10 +78,7 @@ async fn main() -> Result<(), Box> { let uri = url.trim().parse::()?; let params = serde_html_form::from_str(uri.query().unwrap())?; - println!( - "{}", - serde_json::to_string_pretty(&client.callback(params).await?)? - ); + println!("{}", serde_json::to_string_pretty(&client.callback(params).await?)?); Ok(()) } diff --git a/atrium-oauth/oauth-client/src/atproto.rs b/atrium-oauth/oauth-client/src/atproto.rs index c7212afd..94bf4c56 100644 --- a/atrium-oauth/oauth-client/src/atproto.rs +++ b/atrium-oauth/oauth-client/src/atproto.rs @@ -151,11 +151,7 @@ impl TryIntoOAuthClientMetadata for AtprotoClientMetadata { token_endpoint_auth_method: Some(self.token_endpoint_auth_method.into()), grant_types: Some(self.grant_types.into_iter().map(|v| v.into()).collect()), scope: Some( - self.scopes - .into_iter() - .map(|v| v.into()) - .collect::>() - .join(" "), + self.scopes.into_iter().map(|v| v.into()).collect::>().join(" "), ), dpop_bound_access_tokens: Some(true), jwks_uri, diff --git a/atrium-oauth/oauth-client/src/http_client/default.rs b/atrium-oauth/oauth-client/src/http_client/default.rs index 6ce52e2d..8408d5e5 100644 --- a/atrium-oauth/oauth-client/src/http_client/default.rs +++ b/atrium-oauth/oauth-client/src/http_client/default.rs @@ -5,8 +5,6 @@ pub struct DefaultHttpClient { client: Client, } -#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] impl HttpClient for DefaultHttpClient { async fn send_http( &self, @@ -20,16 +18,12 @@ impl HttpClient for DefaultHttpClient { for (k, v) in response.headers() { builder = builder.header(k, v); } - builder - .body(response.bytes().await?.to_vec()) - .map_err(Into::into) + builder.body(response.bytes().await?.to_vec()).map_err(Into::into) } } impl Default for DefaultHttpClient { fn default() -> Self { - Self { - client: reqwest::Client::new(), - } + Self { client: reqwest::Client::new() } } } diff --git a/atrium-oauth/oauth-client/src/http_client/dpop.rs b/atrium-oauth/oauth-client/src/http_client/dpop.rs index 62b7fc52..489fc3e8 100644 --- a/atrium-oauth/oauth-client/src/http_client/dpop.rs +++ b/atrium-oauth/oauth-client/src/http_client/dpop.rs @@ -66,12 +66,7 @@ impl DpopClient { } } let nonces = MemorySimpleStore::::default(); - Ok(Self { - inner: http_client, - key, - iss, - nonces, - }) + Ok(Self { inner: http_client, key, iss, nonces }) } fn build_proof(&self, htm: String, htu: String, nonce: Option) -> Result { match crypto::Key::try_from(&self.key).map_err(Error::JwkCrypto)? { @@ -120,8 +115,6 @@ impl DpopClient { } } -#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] impl HttpClient for DpopClient where T: HttpClient + Send + Sync + 'static, @@ -141,11 +134,8 @@ where request.headers_mut().insert("DPoP", init_proof.parse()?); let response = self.inner.send_http(request.clone()).await?; - let next_nonce = response - .headers() - .get("DPoP-Nonce") - .and_then(|v| v.to_str().ok()) - .map(String::from); + let next_nonce = + response.headers().get("DPoP-Nonce").and_then(|v| v.to_str().ok()).map(String::from); match &next_nonce { Some(s) if next_nonce != init_nonce => { // Store the fresh nonce for future requests diff --git a/atrium-oauth/oauth-client/src/jose/jwt.rs b/atrium-oauth/oauth-client/src/jose/jwt.rs index f0f6f997..87621942 100644 --- a/atrium-oauth/oauth-client/src/jose/jwt.rs +++ b/atrium-oauth/oauth-client/src/jose/jwt.rs @@ -43,10 +43,7 @@ pub struct PublicClaims { impl From for Claims { fn from(registered: RegisteredClaims) -> Self { - Self { - registered, - public: PublicClaims::default(), - } + Self { registered, public: PublicClaims::default() } } } diff --git a/atrium-oauth/oauth-client/src/jose/signing.rs b/atrium-oauth/oauth-client/src/jose/signing.rs index 709072c4..22a98166 100644 --- a/atrium-oauth/oauth-client/src/jose/signing.rs +++ b/atrium-oauth/oauth-client/src/jose/signing.rs @@ -24,8 +24,5 @@ where let header = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header)?); let payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims)?); let signature: Signature<_> = key.sign(format!("{header}.{payload}").as_bytes()); - Ok(format!( - "{header}.{payload}.{}", - URL_SAFE_NO_PAD.encode(signature.to_bytes()) - )) + Ok(format!("{header}.{payload}.{}", URL_SAFE_NO_PAD.encode(signature.to_bytes()))) } diff --git a/atrium-oauth/oauth-client/src/keyset.rs b/atrium-oauth/oauth-client/src/keyset.rs index 7bf91cc6..b6728f9e 100644 --- a/atrium-oauth/oauth-client/src/keyset.rs +++ b/atrium-oauth/oauth-client/src/keyset.rs @@ -31,9 +31,8 @@ pub type Result = core::result::Result; pub struct Keyset(Vec); impl Keyset { - const PREFERRED_SIGNING_ALGORITHMS: [&'static str; 9] = [ - "EdDSA", "ES256K", "ES256", "PS256", "PS384", "PS512", "HS256", "HS384", "HS512", - ]; + const PREFERRED_SIGNING_ALGORITHMS: [&'static str; 9] = + ["EdDSA", "ES256K", "ES256", "PS256", "PS384", "PS512", "HS256", "HS384", "HS512"]; pub fn public_jwks(&self) -> JwkSet { let mut keys = Vec::with_capacity(self.0.len()); for mut key in self.0.clone() { diff --git a/atrium-oauth/oauth-client/src/oauth_client.rs b/atrium-oauth/oauth-client/src/oauth_client.rs index 2262b514..ca1534a3 100644 --- a/atrium-oauth/oauth-client/src/oauth_client.rs +++ b/atrium-oauth/oauth-client/src/oauth_client.rs @@ -11,7 +11,7 @@ use crate::types::{ TryIntoOAuthClientMetadata, }; use crate::utils::{compare_algos, generate_key, generate_nonce, get_random_values}; -use atrium_identity::Resolver; +use atrium_identity::{did::DidResolver, handle::HandleResolver, Resolver}; use atrium_xrpc::HttpClient; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; @@ -22,7 +22,7 @@ use sha2::{Digest, Sha256}; use std::sync::Arc; #[cfg(feature = "default-client")] -pub struct OAuthClientConfig +pub struct OAuthClientConfig where M: TryIntoOAuthClientMetadata, { @@ -32,11 +32,11 @@ where // Stores pub state_store: S, // Services - pub resolver: OAuthResolverConfig, + pub resolver: OAuthResolverConfig, } #[cfg(not(feature = "default-client"))] -pub struct OAuthClientConfig +pub struct OAuthClientConfig where M: TryIntoOAuthClientMetadata, { @@ -46,57 +46,53 @@ where // Stores pub state_store: S, // Services - pub resolver: OAuthResolverConfig, + pub resolver: OAuthResolverConfig, // Others pub http_client: T, } #[cfg(feature = "default-client")] -pub struct OAuthClient +pub struct OAuthClient where S: StateStore, T: HttpClient + Send + Sync + 'static, { pub client_metadata: OAuthClientMetadata, keyset: Option, - resolver: Arc>, + resolver: Arc>, state_store: S, http_client: Arc, } #[cfg(not(feature = "default-client"))] -pub struct OAuthClient +pub struct OAuthClient where S: StateStore, T: HttpClient + Send + Sync + 'static, { pub client_metadata: OAuthClientMetadata, keyset: Option, - resolver: Arc>, + resolver: Arc>, state_store: S, http_client: Arc, } #[cfg(feature = "default-client")] -impl OAuthClient +impl OAuthClient where S: StateStore, { - pub fn new(config: OAuthClientConfig) -> Result + pub fn new(config: OAuthClientConfig) -> Result where M: TryIntoOAuthClientMetadata, { - let keyset = if let Some(keys) = config.keys { - Some(keys.try_into()?) - } else { - None - }; + let keyset = if let Some(keys) = config.keys { Some(keys.try_into()?) } else { None }; let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?; let http_client = Arc::new(crate::http_client::default::DefaultHttpClient::default()); Ok(Self { client_metadata, keyset, - resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())?), + resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())), state_store: config.state_store, http_client, }) @@ -104,42 +100,37 @@ where } #[cfg(not(feature = "default-client"))] -impl OAuthClient +impl OAuthClient where S: StateStore, T: HttpClient + Send + Sync + 'static, { - pub fn new(config: OAuthClientConfig) -> Result + pub fn new(config: OAuthClientConfig) -> Result where M: TryIntoOAuthClientMetadata, { - let keyset = if let Some(keys) = config.keys { - Some(keys.try_into()?) - } else { - None - }; + let keyset = if let Some(keys) = config.keys { Some(keys.try_into()?) } else { None }; let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?; let http_client = Arc::new(config.http_client); Ok(Self { client_metadata, keyset, - resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())?), + resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())), state_store: config.state_store, http_client, }) } } -impl OAuthClient +impl OAuthClient where S: StateStore, + D: DidResolver + Send + Sync + 'static, + H: HandleResolver + Send + Sync + 'static, T: HttpClient + Send + Sync + 'static, { pub fn jwks(&self) -> JwkSet { - self.keyset - .as_ref() - .map(|keyset| keyset.public_jwks()) - .unwrap_or_default() + self.keyset.as_ref().map(|keyset| keyset.public_jwks()).unwrap_or_default() } pub async fn authorize( &self, @@ -169,11 +160,7 @@ where .set(state.clone(), state_data) .await .map_err(|e| Error::StateStore(Box::new(e)))?; - let login_hint = if identity.is_some() { - Some(input.as_ref().into()) - } else { - None - }; + let login_hint = if identity.is_some() { Some(input.as_ref().into()) } else { None }; let parameters = PushedAuthorizationRequestParameters { response_type: AuthorizationResponseType::Code, redirect_uri, @@ -213,9 +200,7 @@ where }) .unwrap()) } else if metadata.require_pushed_authorization_requests == Some(true) { - Err(Error::Authorize( - "server requires PAR but no endpoint is available".into(), - )) + Err(Error::Authorize("server requires PAR but no endpoint is available".into())) } else { // now "the use of PAR is *mandatory* for all clients" // https://github.com/bluesky-social/proposals/tree/main/0004-oauth#framework @@ -227,26 +212,15 @@ where return Err(Error::Callback("missing `state` parameter".into())); }; - let Some(state) = self - .state_store - .get(&state_key) - .await - .map_err(|e| Error::StateStore(Box::new(e)))? + let Some(state) = + self.state_store.get(&state_key).await.map_err(|e| Error::StateStore(Box::new(e)))? else { - return Err(Error::Callback(format!( - "unknown authorization state: {state_key}" - ))); + return Err(Error::Callback(format!("unknown authorization state: {state_key}"))); }; // Prevent any kind of replay - self.state_store - .del(&state_key) - .await - .map_err(|e| Error::StateStore(Box::new(e)))?; + self.state_store.del(&state_key).await.map_err(|e| Error::StateStore(Box::new(e)))?; - let metadata = self - .resolver - .get_authorization_server_metadata(&state.iss) - .await?; + let metadata = self.resolver.get_authorization_server_metadata(&state.iss).await?; // https://datatracker.ietf.org/doc/html/rfc9207#section-2.4 if let Some(iss) = params.iss { if iss != metadata.issuer { @@ -272,10 +246,8 @@ where Ok(token_set) } fn generate_dpop_key(metadata: &OAuthAuthorizationServerMetadata) -> Option { - let mut algs = metadata - .dpop_signing_alg_values_supported - .clone() - .unwrap_or(vec![FALLBACK_ALG.into()]); + let mut algs = + metadata.dpop_signing_alg_values_supported.clone().unwrap_or(vec![FALLBACK_ALG.into()]); algs.sort_by(compare_algos); generate_key(&algs) } @@ -285,9 +257,6 @@ where URL_SAFE_NO_PAD.encode(get_random_values::<_, 32>(&mut ThreadRng::default())); let mut hasher = Sha256::new(); hasher.update(verifier.as_bytes()); - ( - URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes())), - verifier, - ) + (URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes())), verifier) } } diff --git a/atrium-oauth/oauth-client/src/resolver.rs b/atrium-oauth/oauth-client/src/resolver.rs index 6dab73f6..3c195166 100644 --- a/atrium-oauth/oauth-client/src/resolver.rs +++ b/atrium-oauth/oauth-client/src/resolver.rs @@ -4,101 +4,109 @@ mod oauth_protected_resource_resolver; use self::oauth_authorization_server_resolver::DefaultOAuthAuthorizationServerResolver; use self::oauth_protected_resource_resolver::DefaultOAuthProtectedResourceResolver; use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata}; -use async_trait::async_trait; use atrium_identity::identity_resolver::{ - DidResolverConfig, HandleResolverConfig, IdentityResolver, IdentityResolverConfig, - ResolvedIdentity, + IdentityResolver, IdentityResolverConfig, ResolvedIdentity, }; -use atrium_identity::resolver::{CachedResolverConfig, MaybeCachedResolver}; -use atrium_identity::{Error, Resolver, Result}; +use atrium_identity::resolver::{CachedResolver, CachedResolverConfig}; +use atrium_identity::{did::DidResolver, handle::HandleResolver, Resolver}; +use atrium_identity::{Error, Result}; use atrium_xrpc::HttpClient; +use std::marker::PhantomData; use std::sync::Arc; use std::time::Duration; #[derive(Clone, Debug)] pub struct OAuthAuthorizationServerMetadataResolverConfig { - pub cache: Option, + pub cache: CachedResolverConfig, } impl Default for OAuthAuthorizationServerMetadataResolverConfig { fn default() -> Self { Self { - cache: Some(CachedResolverConfig { + cache: CachedResolverConfig { max_capacity: Some(100), time_to_live: Some(Duration::from_secs(60)), - }), + }, } } } #[derive(Clone, Debug)] pub struct OAuthProtectedResourceMetadataResolverConfig { - pub cache: Option, + pub cache: CachedResolverConfig, } impl Default for OAuthProtectedResourceMetadataResolverConfig { fn default() -> Self { Self { - cache: Some(CachedResolverConfig { + cache: CachedResolverConfig { max_capacity: Some(100), time_to_live: Some(Duration::from_secs(60)), - }), + }, } } } #[derive(Clone, Debug)] -pub struct OAuthResolverConfig { - pub did: DidResolverConfig, - pub handle: HandleResolverConfig, +pub struct OAuthResolverConfig { + pub did_resolver: D, + pub handle_resolver: H, pub authorization_server_metadata: OAuthAuthorizationServerMetadataResolverConfig, pub protected_resource_metadata: OAuthProtectedResourceMetadataResolverConfig, } pub struct OAuthResolver< T, + D, + H, PRR = DefaultOAuthProtectedResourceResolver, ASR = DefaultOAuthAuthorizationServerResolver, > where PRR: Resolver, ASR: Resolver, { - identity_resolver: IdentityResolver, - protected_resource_resolver: MaybeCachedResolver, - authorization_server_resolver: - MaybeCachedResolver, + identity_resolver: IdentityResolver, + protected_resource_resolver: CachedResolver, + authorization_server_resolver: CachedResolver, + _phantom: PhantomData, } -impl OAuthResolver +impl OAuthResolver where T: HttpClient + Send + Sync + 'static, { - pub fn new(config: OAuthResolverConfig, http_client: Arc) -> Result { - let protected_resource_resolver = MaybeCachedResolver::new( + pub fn new(config: OAuthResolverConfig, http_client: Arc) -> Self { + let protected_resource_resolver = CachedResolver::new( DefaultOAuthProtectedResourceResolver::new(http_client.clone()), config.authorization_server_metadata.cache, ); - let authorization_server_resolver = MaybeCachedResolver::new( + let authorization_server_resolver = CachedResolver::new( DefaultOAuthAuthorizationServerResolver::new(http_client.clone()), config.protected_resource_metadata.cache, ); - Ok(Self { + Self { identity_resolver: IdentityResolver::new(IdentityResolverConfig { - did: config.did, - handle: config.handle, - http_client, - })?, + did_resolver: config.did_resolver, + handle_resolver: config.handle_resolver, + }), protected_resource_resolver, authorization_server_resolver, - }) + _phantom: PhantomData, + } } +} + +impl OAuthResolver +where + T: HttpClient + Send + Sync + 'static, + D: DidResolver + Send + Sync + 'static, + H: HandleResolver + Send + Sync + 'static, +{ pub async fn get_authorization_server_metadata( &self, issuer: impl AsRef, ) -> Result { - self.authorization_server_resolver - .resolve(&issuer.as_ref().to_string()) - .await + self.authorization_server_resolver.resolve(&issuer.as_ref().to_string()).await } async fn resolve_from_service(&self, input: &str) -> Result { // Assume first that input is a PDS URL (as required by ATPROTO) @@ -120,10 +128,7 @@ where &self, pds: &str, ) -> Result { - let rs_metadata = self - .protected_resource_resolver - .resolve(&pds.to_string()) - .await?; + let rs_metadata = self.protected_resource_resolver.resolve(&pds.to_string()).await?; // ATPROTO requires one, and only one, authorization server entry // > That document MUST contain a single item in the authorization_servers array. // https://github.com/bluesky-social/proposals/tree/main/0004-oauth#server-metadata @@ -167,11 +172,11 @@ where } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Resolver for OAuthResolver +impl Resolver for OAuthResolver where T: HttpClient + Send + Sync + 'static, + D: DidResolver + Send + Sync + 'static, + H: HandleResolver + Send + Sync + 'static, { type Input = str; type Output = (OAuthAuthorizationServerMetadata, Option); diff --git a/atrium-oauth/oauth-client/src/resolver/oauth_authorization_server_resolver.rs b/atrium-oauth/oauth-client/src/resolver/oauth_authorization_server_resolver.rs index 964da942..e38428fe 100644 --- a/atrium-oauth/oauth-client/src/resolver/oauth_authorization_server_resolver.rs +++ b/atrium-oauth/oauth-client/src/resolver/oauth_authorization_server_resolver.rs @@ -1,5 +1,4 @@ use crate::types::OAuthAuthorizationServerMetadata; -use async_trait::async_trait; use atrium_identity::{Error, Resolver, Result}; use atrium_xrpc::http::uri::Builder; use atrium_xrpc::http::{Request, StatusCode, Uri}; @@ -16,8 +15,6 @@ impl DefaultOAuthAuthorizationServerResolver { } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl Resolver for DefaultOAuthAuthorizationServerResolver where T: HttpClient + Send + Sync + 'static, diff --git a/atrium-oauth/oauth-client/src/resolver/oauth_protected_resource_resolver.rs b/atrium-oauth/oauth-client/src/resolver/oauth_protected_resource_resolver.rs index 2ff3919a..98c2ea7a 100644 --- a/atrium-oauth/oauth-client/src/resolver/oauth_protected_resource_resolver.rs +++ b/atrium-oauth/oauth-client/src/resolver/oauth_protected_resource_resolver.rs @@ -1,5 +1,4 @@ use crate::types::OAuthProtectedResourceMetadata; -use async_trait::async_trait; use atrium_identity::{Error, Resolver, Result}; use atrium_xrpc::http::uri::Builder; use atrium_xrpc::http::{Request, StatusCode, Uri}; @@ -16,8 +15,6 @@ impl DefaultOAuthProtectedResourceResolver { } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl Resolver for DefaultOAuthProtectedResourceResolver where T: HttpClient + Send + Sync + 'static, diff --git a/atrium-oauth/oauth-client/src/server_agent.rs b/atrium-oauth/oauth-client/src/server_agent.rs index 1d3fb292..2a05beff 100644 --- a/atrium-oauth/oauth-client/src/server_agent.rs +++ b/atrium-oauth/oauth-client/src/server_agent.rs @@ -9,6 +9,7 @@ use crate::types::{ }; use crate::utils::{compare_algos, generate_nonce}; use atrium_api::types::string::Datetime; +use atrium_identity::{did::DidResolver, handle::HandleResolver}; use atrium_xrpc::http::{Method, Request, StatusCode}; use atrium_xrpc::HttpClient; use chrono::{TimeDelta, Utc}; @@ -92,26 +93,28 @@ where parameters: T, } -pub struct OAuthServerAgent +pub struct OAuthServerAgent where T: HttpClient + Send + Sync + 'static, { server_metadata: OAuthAuthorizationServerMetadata, client_metadata: OAuthClientMetadata, dpop_client: DpopClient, - resolver: Arc>, + resolver: Arc>, keyset: Option, } -impl OAuthServerAgent +impl OAuthServerAgent where T: HttpClient + Send + Sync + 'static, + D: DidResolver + Send + Sync + 'static, + H: HandleResolver + Send + Sync + 'static, { pub fn new( dpop_key: Key, server_metadata: OAuthAuthorizationServerMetadata, client_metadata: OAuthClientMetadata, - resolver: Arc>, + resolver: Arc>, http_client: Arc, keyset: Option, ) -> Result { @@ -121,13 +124,7 @@ where http_client, &server_metadata.token_endpoint_auth_signing_alg_values_supported, )?; - Ok(Self { - server_metadata, - client_metadata, - dpop_client, - resolver, - keyset, - }) + Ok(Self { server_metadata, client_metadata, dpop_client, resolver, keyset }) } /** * VERY IMPORTANT ! Always call this to process token responses. @@ -192,18 +189,11 @@ where .method(Method::POST) .header("Content-Type", "application/x-www-form-urlencoded") .body(body.into_bytes())?; - let res = self - .dpop_client - .send_http(req) - .await - .map_err(Error::HttpClient)?; + let res = self.dpop_client.send_http(req).await.map_err(Error::HttpClient)?; if res.status() == request.expected_status() { Ok(serde_json::from_slice(res.body())?) } else if res.status().is_client_error() { - Err(Error::HttpStatusWithBody( - res.status(), - serde_json::from_slice(res.body())?, - )) + Err(Error::HttpStatusWithBody(res.status(), serde_json::from_slice(res.body())?)) } else { Err(Error::HttpStatus(res.status())) } @@ -279,10 +269,9 @@ where OAuthRequest::Token(_) => Some(&self.server_metadata.token_endpoint), OAuthRequest::Revocation => self.server_metadata.revocation_endpoint.as_ref(), OAuthRequest::Introspection => self.server_metadata.introspection_endpoint.as_ref(), - OAuthRequest::PushedAuthorizationRequest(_) => self - .server_metadata - .pushed_authorization_request_endpoint - .as_ref(), + OAuthRequest::PushedAuthorizationRequest(_) => { + self.server_metadata.pushed_authorization_request_endpoint.as_ref() + } } } } diff --git a/atrium-oauth/oauth-client/src/store.rs b/atrium-oauth/oauth-client/src/store.rs index c54dc6be..c88ac8a9 100644 --- a/atrium-oauth/oauth-client/src/store.rs +++ b/atrium-oauth/oauth-client/src/store.rs @@ -1,12 +1,10 @@ pub mod memory; pub mod state; -use async_trait::async_trait; use std::error::Error; use std::hash::Hash; -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] pub trait SimpleStore where K: Eq + Hash, diff --git a/atrium-oauth/oauth-client/src/store/memory.rs b/atrium-oauth/oauth-client/src/store/memory.rs index ea24ee9f..c43c557d 100644 --- a/atrium-oauth/oauth-client/src/store/memory.rs +++ b/atrium-oauth/oauth-client/src/store/memory.rs @@ -1,5 +1,4 @@ use super::SimpleStore; -use async_trait::async_trait; use std::collections::HashMap; use std::fmt::Debug; use std::hash::Hash; @@ -17,14 +16,10 @@ pub struct MemorySimpleStore { impl Default for MemorySimpleStore { fn default() -> Self { - Self { - store: Arc::new(Mutex::new(HashMap::new())), - } + Self { store: Arc::new(Mutex::new(HashMap::new())) } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl SimpleStore for MemorySimpleStore where K: Debug + Eq + Hash + Send + Sync + 'static, diff --git a/atrium-oauth/oauth-client/src/types.rs b/atrium-oauth/oauth-client/src/types.rs index 82823909..45ef9bdb 100644 --- a/atrium-oauth/oauth-client/src/types.rs +++ b/atrium-oauth/oauth-client/src/types.rs @@ -42,11 +42,7 @@ pub struct AuthorizeOptions { impl Default for AuthorizeOptions { fn default() -> Self { - Self { - redirect_uri: None, - scopes: Some(vec![String::from("atproto")]), - prompt: None, - } + Self { redirect_uri: None, scopes: Some(vec![String::from("atproto")]), prompt: None } } } diff --git a/atrium-oauth/oauth-client/src/utils.rs b/atrium-oauth/oauth-client/src/utils.rs index e921593d..951fcf64 100644 --- a/atrium-oauth/oauth-client/src/utils.rs +++ b/atrium-oauth/oauth-client/src/utils.rs @@ -10,9 +10,9 @@ pub fn generate_key(allowed_algos: &[String]) -> Option { #[allow(clippy::single_match)] match alg.as_str() { "ES256" => { - return Some(Key::from(&crypto::Key::from( - SecretKey::::random(&mut ThreadRng::default()), - ))); + return Some(Key::from(&crypto::Key::from(SecretKey::::random( + &mut ThreadRng::default(), + )))); } _ => { // TODO: Implement other algorithms?