From dc223cb1cfdfb28829e59aedf5b59fdb33d7a943 Mon Sep 17 00:00:00 2001 From: sugyan Date: Wed, 25 Sep 2024 22:25:12 +0900 Subject: [PATCH] Update atrium-identity --- Cargo.lock | 1 + atrium-oauth/identity/Cargo.toml | 5 +- atrium-oauth/identity/src/did.rs | 2 - .../identity/src/did/base_resolver.rs | 41 ------ .../identity/src/did/common_resolver.rs | 49 ++++--- atrium-oauth/identity/src/did/plc_resolver.rs | 17 +-- atrium-oauth/identity/src/did/web_resolver.rs | 7 +- atrium-oauth/identity/src/handle.rs | 75 ----------- .../identity/src/handle/appview_resolver.rs | 14 +- .../identity/src/handle/atproto_resolver.rs | 42 +++--- .../identity/src/handle/dns_resolver.rs | 64 +++------ .../src/handle/doh_dns_txt_resolver.rs | 14 +- .../src/handle/well_known_resolver.rs | 9 +- .../identity/src/identity_resolver.rs | 114 +++++----------- atrium-oauth/identity/src/resolver.rs | 122 ++++++++++-------- .../identity/src/resolver/cache_impl/moka.rs | 6 +- .../identity/src/resolver/cache_impl/wasm.rs | 15 +-- .../identity/src/resolver/cached_resolver.rs | 44 +++---- .../src/resolver/throttled_resolver.rs | 33 +++-- 19 files changed, 230 insertions(+), 444 deletions(-) delete mode 100644 atrium-oauth/identity/src/did/base_resolver.rs diff --git a/Cargo.lock b/Cargo.lock index e08fd698..f7ba5f88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -231,6 +231,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "trait-variant", "wasm-bindgen-test", "web-time", ] diff --git a/atrium-oauth/identity/Cargo.toml b/atrium-oauth/identity/Cargo.toml index d80f45a6..55a0b15b 100644 --- a/atrium-oauth/identity/Cargo.toml +++ b/atrium-oauth/identity/Cargo.toml @@ -17,13 +17,13 @@ keywords = ["atproto", "bluesky", "identity"] atrium-api = { workspace = true, default-features = false } atrium-xrpc.workspace = true dashmap.workspace = true -futures.workspace = true hickory-proto = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } serde_html_form.workspace = true serde_json.workspace = true thiserror.workspace = true tokio = { workspace = true, default-features = false, features = ["sync"] } +trait-variant.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dependencies] moka = { workspace = true, features = ["future"] } @@ -32,6 +32,9 @@ moka = { workspace = true, features = ["future"] } lru.workspace = true web-time.workspace = true +[dev-dependencies] +futures.workspace = true + [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] } diff --git a/atrium-oauth/identity/src/did.rs b/atrium-oauth/identity/src/did.rs index a0c7b150..5d433d6c 100644 --- a/atrium-oauth/identity/src/did.rs +++ b/atrium-oauth/identity/src/did.rs @@ -1,10 +1,8 @@ -mod base_resolver; mod common_resolver; mod plc_resolver; mod web_resolver; pub use self::common_resolver::{CommonDidResolver, CommonDidResolverConfig}; -pub(crate) use self::plc_resolver::DEFAULT_PLC_DIRECTORY_URL; use crate::Resolver; use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; diff --git a/atrium-oauth/identity/src/did/base_resolver.rs b/atrium-oauth/identity/src/did/base_resolver.rs deleted file mode 100644 index 09812320..00000000 --- a/atrium-oauth/identity/src/did/base_resolver.rs +++ /dev/null @@ -1,41 +0,0 @@ -use super::DidResolver; -use crate::error::{Error, Result}; -use crate::Resolver; -use async_trait::async_trait; -use atrium_api::did_doc::DidDocument; -use atrium_api::types::string::Did; -use std::sync::Arc; - -pub enum Method { - Plc, - Web, -} - -pub trait BaseResolver { - fn get_resolver(&self, method: Method) -> Arc; -} - -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Resolver for T -where - T: BaseResolver + Send + Sync + 'static, -{ - type Input = Did; - type Output = DidDocument; - - async fn resolve(&self, did: &Did) -> Result { - match did.strip_prefix("did:").and_then(|s| { - s.split_once(':').and_then(|(method, _)| match method { - "plc" => Some(Method::Plc), - "web" => Some(Method::Web), - _ => None, - }) - }) { - Some(method) => self.get_resolver(method).resolve(did).await, - None => Err(Error::UnsupportedDidMethod(did.clone())), - } - } -} - -impl DidResolver for T where T: BaseResolver + Send + Sync + 'static {} diff --git a/atrium-oauth/identity/src/did/common_resolver.rs b/atrium-oauth/identity/src/did/common_resolver.rs index a8dbca15..5ceb61a9 100644 --- a/atrium-oauth/identity/src/did/common_resolver.rs +++ b/atrium-oauth/identity/src/did/common_resolver.rs @@ -1,8 +1,11 @@ -use super::base_resolver::{BaseResolver, Method}; +use atrium_api::did_doc::DidDocument; +use atrium_api::types::string::Did; + use super::plc_resolver::{PlcDidResolver, PlcDidResolverConfig}; use super::web_resolver::{WebDidResolver, WebDidResolverConfig}; use super::DidResolver; -use crate::error::Result; +use crate::error::{Error, Result}; +use crate::Resolver; use std::sync::Arc; #[derive(Clone, Debug)] @@ -12,33 +15,47 @@ pub struct CommonDidResolverConfig { } pub struct CommonDidResolver { - plc_resolver: Arc>, - web_resolver: Arc>, + plc_resolver: PlcDidResolver, + web_resolver: WebDidResolver, } impl CommonDidResolver { - pub fn new(config: CommonDidResolverConfig) -> Result { - Ok(Self { - plc_resolver: Arc::new(PlcDidResolver::new(PlcDidResolverConfig { + pub fn new(config: CommonDidResolverConfig) -> Self { + Self { + plc_resolver: PlcDidResolver::new(PlcDidResolverConfig { plc_directory_url: config.plc_directory_url, http_client: config.http_client.clone(), - })?), - web_resolver: Arc::new(WebDidResolver::new(WebDidResolverConfig { + }), + web_resolver: WebDidResolver::new(WebDidResolverConfig { http_client: config.http_client, - })), - }) + }), + } } } -impl BaseResolver for CommonDidResolver +impl Resolver for CommonDidResolver where PlcDidResolver: DidResolver + Send + Sync + 'static, WebDidResolver: DidResolver + Send + Sync + 'static, { - fn get_resolver(&self, method: Method) -> Arc { - match method { - Method::Plc => self.plc_resolver.clone(), - Method::Web => self.web_resolver.clone(), + type Input = Did; + type Output = DidDocument; + + async fn resolve(&self, did: &Self::Input) -> Result { + match did + .strip_prefix("did:") + .and_then(|s| s.split_once(':').and_then(|(method, _)| Some(method))) + { + Some("plc") => self.plc_resolver.resolve(did).await, + Some("web") => self.web_resolver.resolve(did).await, + _ => Err(Error::UnsupportedDidMethod(did.clone())), } } } + +impl DidResolver for CommonDidResolver +where + PlcDidResolver: DidResolver + Send + Sync + 'static, + WebDidResolver: DidResolver + Send + Sync + 'static, +{ +} diff --git a/atrium-oauth/identity/src/did/plc_resolver.rs b/atrium-oauth/identity/src/did/plc_resolver.rs index 4419336b..f30c8ae6 100644 --- a/atrium-oauth/identity/src/did/plc_resolver.rs +++ b/atrium-oauth/identity/src/did/plc_resolver.rs @@ -1,7 +1,6 @@ use super::DidResolver; use crate::error::{Error, Result}; use crate::Resolver; -use async_trait::async_trait; use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; use atrium_xrpc::http::uri::Builder; @@ -9,7 +8,8 @@ use atrium_xrpc::http::{Request, Uri}; use atrium_xrpc::HttpClient; use std::sync::Arc; -pub(crate) const DEFAULT_PLC_DIRECTORY_URL: &str = "https://plc.directory/"; +#[allow(dead_code)] +pub const DEFAULT_PLC_DIRECTORY_URL: &str = "https://plc.directory/"; #[derive(Clone, Debug)] pub struct PlcDidResolverConfig { @@ -18,21 +18,16 @@ pub struct PlcDidResolverConfig { } pub struct PlcDidResolver { - plc_directory_url: Uri, + plc_directory_url: String, http_client: Arc, } impl PlcDidResolver { - pub fn new(config: PlcDidResolverConfig) -> Result { - Ok(Self { - plc_directory_url: config.plc_directory_url.parse()?, - http_client: config.http_client, - }) + pub fn new(config: PlcDidResolverConfig) -> Self { + Self { plc_directory_url: config.plc_directory_url, http_client: config.http_client } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl Resolver for PlcDidResolver where T: HttpClient + Send + Sync + 'static, @@ -41,7 +36,7 @@ where type Output = DidDocument; async fn resolve(&self, did: &Self::Input) -> Result { - let uri = Builder::from(self.plc_directory_url.clone()) + let uri = Builder::from(self.plc_directory_url.parse::()?) .path_and_query(format!("/{}", did.as_str())) .build()?; let res = self diff --git a/atrium-oauth/identity/src/did/web_resolver.rs b/atrium-oauth/identity/src/did/web_resolver.rs index 33d08685..582bdd00 100644 --- a/atrium-oauth/identity/src/did/web_resolver.rs +++ b/atrium-oauth/identity/src/did/web_resolver.rs @@ -1,7 +1,6 @@ use super::DidResolver; use crate::error::{Error, Result}; use crate::Resolver; -use async_trait::async_trait; use atrium_api::did_doc::DidDocument; use atrium_api::types::string::Did; use atrium_xrpc::http::{header::ACCEPT, Request, Uri}; @@ -21,14 +20,10 @@ pub struct WebDidResolver { impl WebDidResolver { pub fn new(config: WebDidResolverConfig) -> Self { - Self { - http_client: config.http_client, - } + Self { http_client: config.http_client } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl Resolver for WebDidResolver where T: HttpClient + Send + Sync + 'static, diff --git a/atrium-oauth/identity/src/handle.rs b/atrium-oauth/identity/src/handle.rs index afdce86f..2ae285dd 100644 --- a/atrium-oauth/identity/src/handle.rs +++ b/atrium-oauth/identity/src/handle.rs @@ -8,85 +8,10 @@ mod well_known_resolver; pub use self::appview_resolver::{AppViewHandleResolver, AppViewHandleResolverConfig}; pub use self::atproto_resolver::{AtprotoHandleResolver, AtprotoHandleResolverConfig}; pub use self::dns_resolver::DnsTxtResolver; -use self::dns_resolver::DynamicDnsTxtResolver; #[cfg(feature = "doh-handle-resolver")] pub use self::doh_dns_txt_resolver::{DohDnsTxtResolver, DohDnsTxtResolverConfig}; pub use self::well_known_resolver::{WellKnownHandleResolver, WellKnownHandleResolverConfig}; -use crate::error::{Error, Result}; use crate::Resolver; -use async_trait::async_trait; use atrium_api::types::string::{Did, Handle}; -use atrium_xrpc::HttpClient; -use std::sync::Arc; pub trait HandleResolver: Resolver {} - -pub struct DynamicHandleResolver { - resolver: Arc, -} - -impl DynamicHandleResolver { - pub fn new(resolver: Arc) -> Self { - Self { resolver } - } -} - -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Resolver for DynamicHandleResolver { - type Input = Handle; - type Output = Did; - - async fn resolve(&self, handle: &Self::Input) -> Result { - self.resolver.resolve(handle).await - } -} - -impl HandleResolver for DynamicHandleResolver {} - -#[derive(Clone)] -pub enum HandleResolverImpl { - Atproto(Arc), - AppView(String), -} - -impl std::fmt::Debug for HandleResolverImpl { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - HandleResolverImpl::Atproto(_) => write!(f, "Atproto"), - HandleResolverImpl::AppView(url) => write!(f, "AppView({url})"), - } - } -} - -#[derive(Clone, Debug)] -pub struct HandleResolverConfig { - pub r#impl: HandleResolverImpl, - pub http_client: Arc, -} - -impl TryFrom> for DynamicHandleResolver -where - T: HttpClient + Send + Sync + 'static, -{ - type Error = Error; - - fn try_from(config: HandleResolverConfig) -> Result { - Ok(Self { - resolver: match config.r#impl { - HandleResolverImpl::Atproto(dns_txt_resolver) => { - Arc::new(AtprotoHandleResolver::new(AtprotoHandleResolverConfig { - dns_txt_resolver: DynamicDnsTxtResolver::new(dns_txt_resolver), - http_client: config.http_client, - })?) - } - HandleResolverImpl::AppView(service) => { - Arc::new(AppViewHandleResolver::new(AppViewHandleResolverConfig { - service_url: service, - http_client: config.http_client, - })?) - } - }, - }) - } -} diff --git a/atrium-oauth/identity/src/handle/appview_resolver.rs b/atrium-oauth/identity/src/handle/appview_resolver.rs index 47bcf5e9..90255a35 100644 --- a/atrium-oauth/identity/src/handle/appview_resolver.rs +++ b/atrium-oauth/identity/src/handle/appview_resolver.rs @@ -1,7 +1,6 @@ use super::HandleResolver; use crate::error::{Error, Result}; use crate::Resolver; -use async_trait::async_trait; use atrium_api::com::atproto::identity::resolve_handle; use atrium_api::types::string::{Did, Handle}; use atrium_xrpc::http::uri::Builder; @@ -16,21 +15,16 @@ pub struct AppViewHandleResolverConfig { } pub struct AppViewHandleResolver { - service_url: Uri, + service_url: String, http_client: Arc, } impl AppViewHandleResolver { - pub fn new(config: AppViewHandleResolverConfig) -> Result { - Ok(Self { - service_url: config.service_url.parse()?, - http_client: config.http_client, - }) + pub fn new(config: AppViewHandleResolverConfig) -> Self { + Self { service_url: config.service_url, http_client: config.http_client } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl Resolver for AppViewHandleResolver where T: HttpClient + Send + Sync + 'static, @@ -39,7 +33,7 @@ where type Output = Did; async fn resolve(&self, handle: &Self::Input) -> Result { - let uri = Builder::from(self.service_url.clone()) + let uri = Builder::from(self.service_url.parse::()?) .path_and_query(format!( "/xrpc/com.atproto.identity.resolveHandle?{}", serde_html_form::to_string(resolve_handle::ParametersData { diff --git a/atrium-oauth/identity/src/handle/atproto_resolver.rs b/atrium-oauth/identity/src/handle/atproto_resolver.rs index e8f9c974..25ec54b4 100644 --- a/atrium-oauth/identity/src/handle/atproto_resolver.rs +++ b/atrium-oauth/identity/src/handle/atproto_resolver.rs @@ -3,10 +3,8 @@ use super::well_known_resolver::{WellKnownHandleResolver, WellKnownHandleResolve use super::HandleResolver; use crate::error::Result; use crate::Resolver; -use async_trait::async_trait; use atrium_api::types::string::{Did, Handle}; use atrium_xrpc::HttpClient; -use futures::future::select_ok; use std::sync::Arc; #[derive(Clone, Debug)] @@ -15,40 +13,46 @@ pub struct AtprotoHandleResolverConfig { pub http_client: Arc, } -pub struct AtprotoHandleResolver { - dns: DnsHandleResolver, +pub struct AtprotoHandleResolver { + dns: DnsHandleResolver, http: WellKnownHandleResolver, } -impl AtprotoHandleResolver { - pub fn new(config: AtprotoHandleResolverConfig) -> Result - where - R: DnsTxtResolver + Send + Sync + 'static, - { - Ok(Self { +impl AtprotoHandleResolver { + pub fn new(config: AtprotoHandleResolverConfig) -> Self { + Self { dns: DnsHandleResolver::new(DnsHandleResolverConfig { dns_txt_resolver: config.dns_txt_resolver, - })?, + }), http: WellKnownHandleResolver::new(WellKnownHandleResolverConfig { http_client: config.http_client, - })?, - }) + }), + } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Resolver for AtprotoHandleResolver +impl Resolver for AtprotoHandleResolver where + R: DnsTxtResolver + Send + Sync + 'static, T: HttpClient + Send + Sync + 'static, { type Input = Handle; type Output = Did; async fn resolve(&self, handle: &Self::Input) -> Result { - let (did, _) = select_ok([self.dns.resolve(handle), self.http.resolve(handle)]).await?; - Ok(did) + let d_fut = self.dns.resolve(handle); + let h_fut = self.http.resolve(handle); + if let Ok(did) = d_fut.await { + Ok(did) + } else { + h_fut.await + } } } -impl HandleResolver for AtprotoHandleResolver where T: HttpClient + Send + Sync + 'static {} +impl HandleResolver for AtprotoHandleResolver +where + R: DnsTxtResolver + Send + Sync + 'static, + T: HttpClient + Send + Sync + 'static, +{ +} diff --git a/atrium-oauth/identity/src/handle/dns_resolver.rs b/atrium-oauth/identity/src/handle/dns_resolver.rs index e8ae7262..7cdc6a92 100644 --- a/atrium-oauth/identity/src/handle/dns_resolver.rs +++ b/atrium-oauth/identity/src/handle/dns_resolver.rs @@ -1,66 +1,44 @@ use super::HandleResolver; use crate::error::{Error, Result}; use crate::Resolver; -use async_trait::async_trait; use atrium_api::types::string::{Did, Handle}; -use std::sync::Arc; +use std::future::Future; const SUBDOMAIN: &str = "_atproto"; const PREFIX: &str = "did="; -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] pub trait DnsTxtResolver { - async fn resolve( + fn resolve( &self, query: &str, - ) -> core::result::Result, Box>; -} - -pub struct DynamicDnsTxtResolver { - resolver: Arc, -} - -impl DynamicDnsTxtResolver { - pub fn new(resolver: Arc) -> Self { - Self { resolver } - } -} - -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl DnsTxtResolver for DynamicDnsTxtResolver { - async fn resolve( - &self, - query: &str, - ) -> core::result::Result, Box> { - self.resolver.resolve(query).await - } + ) -> impl Future< + Output = core::result::Result< + Vec, + Box, + >, + >; } #[derive(Clone, Debug)] -pub struct DnsHandleResolverConfig { - pub dns_txt_resolver: T, +pub struct DnsHandleResolverConfig { + pub dns_txt_resolver: R, } -pub struct DnsHandleResolver { - dns_txt_resolver: Arc, +pub struct DnsHandleResolver { + dns_txt_resolver: R, } -impl DnsHandleResolver { - pub fn new(config: DnsHandleResolverConfig) -> Result - where - T: DnsTxtResolver + Send + Sync + 'static, - { - Ok(Self { - dns_txt_resolver: Arc::new(config.dns_txt_resolver), - }) +impl DnsHandleResolver { + pub fn new(config: DnsHandleResolverConfig) -> Self { + Self { dns_txt_resolver: config.dns_txt_resolver } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Resolver for DnsHandleResolver { +impl Resolver for DnsHandleResolver +where + R: DnsTxtResolver + Send + Sync + 'static, +{ type Input = Handle; type Output = Did; @@ -79,4 +57,4 @@ impl Resolver for DnsHandleResolver { } } -impl HandleResolver for DnsHandleResolver {} +impl HandleResolver for DnsHandleResolver where R: DnsTxtResolver + Send + Sync + 'static {} diff --git a/atrium-oauth/identity/src/handle/doh_dns_txt_resolver.rs b/atrium-oauth/identity/src/handle/doh_dns_txt_resolver.rs index a7304bde..e2b00a78 100644 --- a/atrium-oauth/identity/src/handle/doh_dns_txt_resolver.rs +++ b/atrium-oauth/identity/src/handle/doh_dns_txt_resolver.rs @@ -1,6 +1,5 @@ use super::DnsTxtResolver; -use async_trait::async_trait; -use atrium_xrpc::http::{uri::InvalidUri, StatusCode, Uri}; +use atrium_xrpc::http::StatusCode; use atrium_xrpc::HttpClient; use hickory_proto::op::{Message, Query}; use hickory_proto::rr::{RData, RecordType}; @@ -22,22 +21,17 @@ pub struct DohDnsTxtResolverConfig { } pub struct DohDnsTxtResolver { - service_url: Uri, + service_url: String, http_client: Arc, } impl DohDnsTxtResolver { #[allow(dead_code)] - pub fn new(config: DohDnsTxtResolverConfig) -> core::result::Result { - Ok(Self { - service_url: config.service_url.parse()?, - http_client: config.http_client, - }) + pub fn new(config: DohDnsTxtResolverConfig) -> Self { + Self { service_url: config.service_url, http_client: config.http_client } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl DnsTxtResolver for DohDnsTxtResolver where T: HttpClient + Send + Sync + 'static, diff --git a/atrium-oauth/identity/src/handle/well_known_resolver.rs b/atrium-oauth/identity/src/handle/well_known_resolver.rs index 228bb86a..9f04b2b7 100644 --- a/atrium-oauth/identity/src/handle/well_known_resolver.rs +++ b/atrium-oauth/identity/src/handle/well_known_resolver.rs @@ -1,7 +1,6 @@ use super::HandleResolver; use crate::error::{Error, Result}; use crate::Resolver; -use async_trait::async_trait; use atrium_api::types::string::{Did, Handle}; use atrium_xrpc::http::Request; use atrium_xrpc::HttpClient; @@ -19,15 +18,11 @@ pub struct WellKnownHandleResolver { } impl WellKnownHandleResolver { - pub fn new(config: WellKnownHandleResolverConfig) -> Result { - Ok(Self { - http_client: config.http_client, - }) + pub fn new(config: WellKnownHandleResolverConfig) -> Self { + Self { http_client: config.http_client } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl Resolver for WellKnownHandleResolver where T: HttpClient + Send + Sync + 'static, diff --git a/atrium-oauth/identity/src/identity_resolver.rs b/atrium-oauth/identity/src/identity_resolver.rs index 3701beed..3dcb3646 100644 --- a/atrium-oauth/identity/src/identity_resolver.rs +++ b/atrium-oauth/identity/src/identity_resolver.rs @@ -1,14 +1,9 @@ -use crate::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}; +use crate::did::DidResolver; use crate::error::{Error, Result}; -use crate::handle::{DynamicHandleResolver, HandleResolverImpl}; -use crate::resolver::CachedResolverConfig; +use crate::handle::HandleResolver; use crate::Resolver; -use async_trait::async_trait; use atrium_api::types::string::AtIdentifier; -use atrium_xrpc::HttpClient; use serde::{Deserialize, Serialize}; -use std::marker::PhantomData; -use std::sync::Arc; #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct ResolvedIdentity { @@ -17,106 +12,55 @@ pub struct ResolvedIdentity { } #[derive(Clone, Debug)] -pub struct DidResolverConfig { - pub plc_directory_url: String, - pub cache: CachedResolverConfig, +pub struct IdentityResolverConfig { + pub did_resolver: D, + pub handle_resolver: H, } -impl Default for DidResolverConfig { - fn default() -> Self { - Self { - plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), - cache: CachedResolverConfig::default(), - } - } -} - -#[derive(Clone, Debug)] -pub struct HandleResolverConfig { - pub r#impl: HandleResolverImpl, - pub cache: CachedResolverConfig, -} - -#[derive(Clone, Debug)] -pub struct IdentityResolverConfig { - pub did: DidResolverConfig, - pub handle: HandleResolverConfig, - pub http_client: Arc, -} - -pub struct IdentityResolver, H = DynamicHandleResolver> { +pub struct IdentityResolver { did_resolver: D, handle_resolver: H, - _phantom: PhantomData, -} - -impl IdentityResolver { - pub fn new(config: IdentityResolverConfig) -> Result - where - T: HttpClient + Send + Sync + 'static, - { - Ok(Self::from(( - CommonDidResolver::new(CommonDidResolverConfig { - plc_directory_url: config.did.plc_directory_url, - http_client: config.http_client.clone(), - })?, - DynamicHandleResolver::try_from(super::handle::HandleResolverConfig { - r#impl: config.handle.r#impl, - http_client: config.http_client, - })?, - ))) - } } -impl From<(D, H)> for IdentityResolver { - fn from((did_resolver, handle_resolver): (D, H)) -> Self { - Self { - did_resolver, - handle_resolver, - _phantom: PhantomData, - } +impl IdentityResolver { + pub fn new(config: IdentityResolverConfig) -> Self { + Self { did_resolver: config.did_resolver, handle_resolver: config.handle_resolver } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Resolver for IdentityResolver +impl Resolver for IdentityResolver where - T: HttpClient + Send + Sync + 'static, + D: DidResolver + Send + Sync + 'static, + H: HandleResolver + Send + Sync + 'static, { type Input = str; type Output = ResolvedIdentity; async fn resolve(&self, input: &Self::Input) -> Result { - let document = match input - .parse::() - .map_err(|e| Error::AtIdentifier(e.to_string()))? - { - AtIdentifier::Did(did) => self.did_resolver.resolve(&did).await?, - AtIdentifier::Handle(handle) => { - let did = self.handle_resolver.resolve(&handle).await?; - let document = self.did_resolver.resolve(&did).await?; - if let Some(aka) = &document.also_known_as { - if !aka.contains(&format!("at://{}", handle.as_str())) { - return Err(Error::DidDocument(format!( - "did document for `{}` does not include the handle `{}`", - did.as_str(), - handle.as_str() - ))); + let document = + match input.parse::().map_err(|e| Error::AtIdentifier(e.to_string()))? { + AtIdentifier::Did(did) => self.did_resolver.resolve(&did).await?, + AtIdentifier::Handle(handle) => { + let did = self.handle_resolver.resolve(&handle).await?; + let document = self.did_resolver.resolve(&did).await?; + if let Some(aka) = &document.also_known_as { + if !aka.contains(&format!("at://{}", handle.as_str())) { + return Err(Error::DidDocument(format!( + "did document for `{}` does not include the handle `{}`", + did.as_str(), + handle.as_str() + ))); + } } + document } - document - } - }; + }; let Some(service) = document.get_pds_endpoint() else { return Err(Error::DidDocument(format!( "no valid `AtprotoPersonalDataServer` service found in `{}`", document.id ))); }; - Ok(ResolvedIdentity { - did: document.id, - pds: service, - }) + Ok(ResolvedIdentity { did: document.id, pds: service }) } } diff --git a/atrium-oauth/identity/src/resolver.rs b/atrium-oauth/identity/src/resolver.rs index 0ccef83a..5cfdff90 100644 --- a/atrium-oauth/identity/src/resolver.rs +++ b/atrium-oauth/identity/src/resolver.rs @@ -2,18 +2,56 @@ mod cache_impl; mod cached_resolver; mod throttled_resolver; -pub use self::cached_resolver::{CachedResolverConfig, MaybeCachedResolver}; +pub use self::cached_resolver::{CachedResolver, CachedResolverConfig}; pub use self::throttled_resolver::ThrottledResolver; pub use crate::error::Result; -use async_trait::async_trait; +use std::future::Future; +use std::hash::Hash; -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] pub trait Resolver { type Input: ?Sized; type Output; - async fn resolve(&self, input: &Self::Input) -> Result; + fn resolve(&self, input: &Self::Input) -> impl Future>; +} + +pub trait Cacheable +where + Self: Sized + Resolver, + Self::Input: Sized, +{ + fn cached(self, config: CachedResolverConfig) -> CachedResolver; +} + +impl Cacheable for R +where + R: Sized + Resolver, + R::Input: Sized + Hash + Eq + Send + Sync + 'static, + R::Output: Clone + Send + Sync + 'static, +{ + fn cached(self, config: CachedResolverConfig) -> CachedResolver { + CachedResolver::new(self, config) + } +} + +pub trait Throttleable +where + Self: Sized + Resolver, + Self::Input: Sized, +{ + fn throttled(self) -> ThrottledResolver; +} + +impl Throttleable for R +where + R: Sized + Resolver, + R::Input: Clone + Hash + Eq + Send + Sync + 'static, + R::Output: Clone + Send + Sync + 'static, +{ + fn throttled(self) -> ThrottledResolver { + ThrottledResolver::new(self) + } } #[cfg(test)] @@ -42,8 +80,6 @@ mod tests { counts: Arc>>, } - #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] - #[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl Resolver for MockResolver { type Input = String; type Output = String; @@ -75,7 +111,7 @@ mod tests { #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] async fn test_no_cached() { let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = MaybeCachedResolver::new(mock_resolver(counts.clone()), None); + let resolver = mock_resolver(counts.clone()); for (input, expected) in [ ("k1", Some("v1")), ("k2", Some("v2")), @@ -93,13 +129,9 @@ mod tests { } assert_eq!( *counts.read().await, - [ - (String::from("k1"), 3), - (String::from("k2"), 2), - (String::from("k3"), 2), - ] - .into_iter() - .collect() + [(String::from("k1"), 3), (String::from("k2"), 2), (String::from("k3"), 2),] + .into_iter() + .collect() ); } @@ -107,8 +139,7 @@ mod tests { #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] async fn test_cached() { let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = - MaybeCachedResolver::new(mock_resolver(counts.clone()), Some(Default::default())); + let resolver = mock_resolver(counts.clone()).cached(Default::default()); for (input, expected) in [ ("k1", Some("v1")), ("k2", Some("v2")), @@ -126,13 +157,9 @@ mod tests { } assert_eq!( *counts.read().await, - [ - (String::from("k1"), 1), - (String::from("k2"), 1), - (String::from("k3"), 2), - ] - .into_iter() - .collect() + [(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 2),] + .into_iter() + .collect() ); } @@ -140,13 +167,8 @@ mod tests { #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] async fn test_cached_with_max_capacity() { let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = MaybeCachedResolver::new( - mock_resolver(counts.clone()), - Some(CachedResolverConfig { - max_capacity: Some(1), - ..Default::default() - }), - ); + let resolver = mock_resolver(counts.clone()) + .cached(CachedResolverConfig { max_capacity: Some(1), ..Default::default() }); for (input, expected) in [ ("k1", Some("v1")), ("k2", Some("v2")), @@ -164,13 +186,9 @@ mod tests { } assert_eq!( *counts.read().await, - [ - (String::from("k1"), 2), - (String::from("k2"), 1), - (String::from("k3"), 2), - ] - .into_iter() - .collect() + [(String::from("k1"), 2), (String::from("k2"), 1), (String::from("k3"), 2),] + .into_iter() + .collect() ); } @@ -178,13 +196,10 @@ mod tests { #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] async fn test_cached_with_time_to_live() { let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = MaybeCachedResolver::new( - mock_resolver(counts.clone()), - Some(CachedResolverConfig { - time_to_live: Some(Duration::from_millis(10)), - ..Default::default() - }), - ); + let resolver = mock_resolver(counts.clone()).cached(CachedResolverConfig { + time_to_live: Some(Duration::from_millis(10)), + ..Default::default() + }); for _ in 0..10 { let result = resolver.resolve(&String::from("k1")).await; assert_eq!(result.expect("failed to resolve"), "v1"); @@ -194,17 +209,14 @@ mod tests { let result = resolver.resolve(&String::from("k1")).await; assert_eq!(result.expect("failed to resolve"), "v1"); } - assert_eq!( - *counts.read().await, - [(String::from("k1"), 2)].into_iter().collect() - ); + assert_eq!(*counts.read().await, [(String::from("k1"), 2)].into_iter().collect()); } #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] async fn test_throttled() { let counts = Arc::new(RwLock::new(HashMap::new())); - let resolver = Arc::new(ThrottledResolver::new(mock_resolver(counts.clone()))); + let resolver = Arc::new(mock_resolver(counts.clone()).throttled()); let mut handles = Vec::new(); for (input, expected) in [ @@ -227,13 +239,9 @@ mod tests { } assert_eq!( *counts.read().await, - [ - (String::from("k1"), 1), - (String::from("k2"), 1), - (String::from("k3"), 1), - ] - .into_iter() - .collect() + [(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 1),] + .into_iter() + .collect() ); } } diff --git a/atrium-oauth/identity/src/resolver/cache_impl/moka.rs b/atrium-oauth/identity/src/resolver/cache_impl/moka.rs index 0c70d11e..f35fa3a8 100644 --- a/atrium-oauth/identity/src/resolver/cache_impl/moka.rs +++ b/atrium-oauth/identity/src/resolver/cache_impl/moka.rs @@ -1,5 +1,4 @@ use super::super::cached_resolver::{Cache as CacheTrait, CachedResolverConfig}; -use async_trait::async_trait; use moka::{future::Cache, policy::EvictionPolicy}; use std::collections::hash_map::RandomState; use std::hash::Hash; @@ -8,7 +7,6 @@ pub struct MokaCache { inner: Cache, } -#[async_trait] impl CacheTrait for MokaCache where I: Hash + Eq + Send + Sync + 'static, @@ -25,9 +23,7 @@ where if let Some(time_to_live) = config.time_to_live { builder = builder.time_to_live(time_to_live); } - Self { - inner: builder.build(), - } + Self { inner: builder.build() } } async fn get(&self, key: &Self::Input) -> Option { self.inner.run_pending_tasks().await; diff --git a/atrium-oauth/identity/src/resolver/cache_impl/wasm.rs b/atrium-oauth/identity/src/resolver/cache_impl/wasm.rs index bbac9535..8af03932 100644 --- a/atrium-oauth/identity/src/resolver/cache_impl/wasm.rs +++ b/atrium-oauth/identity/src/resolver/cache_impl/wasm.rs @@ -1,5 +1,4 @@ use super::super::cached_resolver::{Cache as CacheTrait, CachedResolverConfig}; -use async_trait::async_trait; use lru::LruCache; use std::collections::HashMap; use std::hash::Hash; @@ -57,7 +56,6 @@ pub struct WasmCache { expiration: Option, } -#[async_trait(?Send)] impl CacheTrait for WasmCache where I: Hash + Eq + Send + Sync + 'static, @@ -75,10 +73,7 @@ where } else { Store::HashMap(HashMap::new()) }; - Self { - inner: Arc::new(Mutex::new(store)), - expiration: config.time_to_live, - } + Self { inner: Arc::new(Mutex::new(store)), expiration: config.time_to_live } } async fn get(&self, key: &Self::Input) -> Option { let mut cache = self.inner.lock().await; @@ -95,12 +90,6 @@ where } } async fn set(&self, key: Self::Input, value: Self::Output) { - self.inner.lock().await.set( - key, - ValueWithInstant { - value, - instant: Instant::now(), - }, - ); + self.inner.lock().await.set(key, ValueWithInstant { value, instant: Instant::now() }); } } diff --git a/atrium-oauth/identity/src/resolver/cached_resolver.rs b/atrium-oauth/identity/src/resolver/cached_resolver.rs index c7b8ce80..79a38295 100644 --- a/atrium-oauth/identity/src/resolver/cached_resolver.rs +++ b/atrium-oauth/identity/src/resolver/cached_resolver.rs @@ -1,13 +1,11 @@ use super::cache_impl::CacheImpl; use crate::error::Result; use crate::Resolver; -use async_trait::async_trait; use std::fmt::Debug; use std::hash::Hash; use std::time::Duration; -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] pub(crate) trait Cache { type Input: Hash + Eq + Sync + 'static; type Output: Clone + Sync + 'static; @@ -23,47 +21,41 @@ pub struct CachedResolverConfig { pub time_to_live: Option, } -pub struct MaybeCachedResolver +pub struct CachedResolver where - R: Resolver, + R: Resolver, + R::Input: Sized, { resolver: R, - cache: Option>, + cache: CacheImpl, } -impl MaybeCachedResolver +impl CachedResolver where - R: Resolver, - I: Hash + Eq + Send + Sync + 'static, - O: Clone + Send + Sync + 'static, + R: Resolver, + R::Input: Sized + Hash + Eq + Send + Sync + 'static, + R::Output: Clone + Send + Sync + 'static, { - pub fn new(resolver: R, config: Option) -> Self { - let cache = config.map(CacheImpl::new); - Self { resolver, cache } + pub fn new(resolver: R, config: CachedResolverConfig) -> Self { + Self { resolver, cache: CacheImpl::new(config) } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Resolver for MaybeCachedResolver +impl Resolver for CachedResolver where - R: Resolver + Send + Sync + 'static, - I: Clone + Hash + Eq + Send + Sync + 'static + Debug, - O: Clone + Send + Sync + 'static, + R: Resolver + Send + Sync + 'static, + R::Input: Clone + Hash + Eq + Send + Sync + 'static + Debug, + R::Output: Clone + Send + Sync + 'static, { type Input = R::Input; type Output = R::Output; async fn resolve(&self, input: &Self::Input) -> Result { - if let Some(cache) = &self.cache { - if let Some(output) = cache.get(input).await { - return Ok(output); - } + if let Some(output) = self.cache.get(input).await { + return Ok(output); } let output = self.resolver.resolve(input).await?; - if let Some(cache) = &self.cache { - cache.set(input.clone(), output.clone()).await; - } + self.cache.set(input.clone(), output.clone()).await; Ok(output) } } diff --git a/atrium-oauth/identity/src/resolver/throttled_resolver.rs b/atrium-oauth/identity/src/resolver/throttled_resolver.rs index 71570a02..195473f0 100644 --- a/atrium-oauth/identity/src/resolver/throttled_resolver.rs +++ b/atrium-oauth/identity/src/resolver/throttled_resolver.rs @@ -1,38 +1,37 @@ use super::Resolver; use crate::error::{Error, Result}; -use async_trait::async_trait; use dashmap::{DashMap, Entry}; use std::hash::Hash; -use std::{fmt::Debug, sync::Arc}; +use std::sync::Arc; use tokio::sync::broadcast::{channel, Sender}; use tokio::sync::Mutex; -type SharedSender = Arc>>; +type SharedSender = Arc>>>; -pub struct ThrottledResolver { +pub struct ThrottledResolver +where + R: Resolver, + R::Input: Sized, +{ resolver: R, - senders: Arc>>>, + senders: Arc>>, } -impl ThrottledResolver +impl ThrottledResolver where - I: Clone + Hash + Eq + Send + Sync + 'static + Debug, + R: Resolver, + R::Input: Clone + Hash + Eq + Send + Sync + 'static, { pub fn new(resolver: R) -> Self { - Self { - resolver, - senders: Arc::new(DashMap::new()), - } + Self { resolver, senders: Arc::new(DashMap::new()) } } } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Resolver for ThrottledResolver +impl Resolver for ThrottledResolver where - R: Resolver + Send + Sync + 'static, - I: Clone + Hash + Eq + Send + Sync + 'static, - O: Clone + Send + Sync + 'static, + R: Resolver + Send + Sync + 'static, + R::Input: Clone + Hash + Eq + Send + Sync + 'static, + R::Output: Clone + Send + Sync + 'static, { type Input = R::Input; type Output = R::Output;