Skip to content

Commit

Permalink
Update oauth-client
Browse files Browse the repository at this point in the history
  • Loading branch information
sugyan committed Sep 25, 2024
1 parent dc223cb commit 7625577
Show file tree
Hide file tree
Showing 22 changed files with 123 additions and 213 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions atrium-oauth/identity/src/did.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod plc_resolver;
mod web_resolver;

pub use self::common_resolver::{CommonDidResolver, CommonDidResolverConfig};
pub use self::plc_resolver::DEFAULT_PLC_DIRECTORY_URL;
use crate::Resolver;
use atrium_api::did_doc::DidDocument;
use atrium_api::types::string::Did;
Expand Down
5 changes: 1 addition & 4 deletions atrium-oauth/identity/src/did/common_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ where
type Output = DidDocument;

async fn resolve(&self, did: &Self::Input) -> Result<Self::Output> {
match did
.strip_prefix("did:")
.and_then(|s| s.split_once(':').and_then(|(method, _)| Some(method)))
{
match did.strip_prefix("did:").and_then(|s| s.split_once(':').map(|(method, _)| method)) {
Some("plc") => self.plc_resolver.resolve(did).await,
Some("web") => self.web_resolver.resolve(did).await,
_ => Err(Error::UnsupportedDidMethod(did.clone())),
Expand Down
1 change: 0 additions & 1 deletion atrium-oauth/identity/src/did/plc_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use atrium_xrpc::http::{Request, Uri};
use atrium_xrpc::HttpClient;
use std::sync::Arc;

#[allow(dead_code)]
pub const DEFAULT_PLC_DIRECTORY_URL: &str = "https://plc.directory/";

#[derive(Clone, Debug)]
Expand Down
4 changes: 1 addition & 3 deletions atrium-oauth/identity/src/identity_resolver.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use crate::did::DidResolver;
use crate::error::{Error, Result};
use crate::handle::HandleResolver;
use crate::Resolver;
use crate::{did::DidResolver, handle::HandleResolver, Resolver};
use atrium_api::types::string::AtIdentifier;
use serde::{Deserialize, Serialize};

Expand Down
1 change: 1 addition & 0 deletions atrium-oauth/oauth-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ serde_html_form.workspace = true
serde_json.workspace = true
sha2.workspace = true
thiserror.workspace = true
trait-variant.workspace = true

[dev-dependencies]
hickory-resolver.workspace = true
Expand Down
36 changes: 15 additions & 21 deletions atrium-oauth/oauth-client/examples/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use atrium_identity::handle::{DnsTxtResolver, HandleResolverImpl};
use atrium_identity::identity_resolver::{DidResolverConfig, HandleResolverConfig};
use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL};
use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver};
use atrium_oauth_client::store::state::MemoryStateStore;
use atrium_oauth_client::{
AtprotoLocalhostClientMetadata, AuthorizeOptions, OAuthClient, OAuthClientConfig,
OAuthResolverConfig,
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, OAuthClient,
OAuthClientConfig, OAuthResolverConfig,
};
use atrium_xrpc::http::Uri;
use hickory_resolver::TokioAsyncResolver;
Expand All @@ -23,35 +23,32 @@ impl Default for HickoryDnsTxtResolver {
}
}

#[async_trait::async_trait]
impl DnsTxtResolver for HickoryDnsTxtResolver {
async fn resolve(
&self,
query: &str,
) -> core::result::Result<Vec<String>, Box<dyn std::error::Error + Send + Sync + 'static>> {
Ok(self
.resolver
.txt_lookup(query)
.await?
.iter()
.map(|txt| txt.to_string())
.collect())
Ok(self.resolver.txt_lookup(query).await?.iter().map(|txt| txt.to_string()).collect())
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let http_client = Arc::new(DefaultHttpClient::default());
let config = OAuthClientConfig {
client_metadata: AtprotoLocalhostClientMetadata {
redirect_uris: vec!["http://127.0.0.1".to_string()],
},
keys: None,
resolver: OAuthResolverConfig {
did: DidResolverConfig::default(),
handle: HandleResolverConfig {
r#impl: HandleResolverImpl::Atproto(Arc::new(HickoryDnsTxtResolver::default())),
cache: Default::default(),
},
did_resolver: CommonDidResolver::new(CommonDidResolverConfig {
plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(),
http_client: http_client.clone(),
}),
handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig {
dns_txt_resolver: HickoryDnsTxtResolver::default(),
http_client: http_client.clone(),
}),
authorization_server_metadata: Default::default(),
protected_resource_metadata: Default::default(),
},
Expand Down Expand Up @@ -81,10 +78,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let uri = url.trim().parse::<Uri>()?;
let params = serde_html_form::from_str(uri.query().unwrap())?;
println!(
"{}",
serde_json::to_string_pretty(&client.callback(params).await?)?
);
println!("{}", serde_json::to_string_pretty(&client.callback(params).await?)?);

Ok(())
}
6 changes: 1 addition & 5 deletions atrium-oauth/oauth-client/src/atproto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,7 @@ impl TryIntoOAuthClientMetadata for AtprotoClientMetadata {
token_endpoint_auth_method: Some(self.token_endpoint_auth_method.into()),
grant_types: Some(self.grant_types.into_iter().map(|v| v.into()).collect()),
scope: Some(
self.scopes
.into_iter()
.map(|v| v.into())
.collect::<Vec<String>>()
.join(" "),
self.scopes.into_iter().map(|v| v.into()).collect::<Vec<String>>().join(" "),
),
dpop_bound_access_tokens: Some(true),
jwks_uri,
Expand Down
10 changes: 2 additions & 8 deletions atrium-oauth/oauth-client/src/http_client/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ pub struct DefaultHttpClient {
client: Client,
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl HttpClient for DefaultHttpClient {
async fn send_http(
&self,
Expand All @@ -20,16 +18,12 @@ impl HttpClient for DefaultHttpClient {
for (k, v) in response.headers() {
builder = builder.header(k, v);
}
builder
.body(response.bytes().await?.to_vec())
.map_err(Into::into)
builder.body(response.bytes().await?.to_vec()).map_err(Into::into)
}
}

impl Default for DefaultHttpClient {
fn default() -> Self {
Self {
client: reqwest::Client::new(),
}
Self { client: reqwest::Client::new() }
}
}
16 changes: 3 additions & 13 deletions atrium-oauth/oauth-client/src/http_client/dpop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,7 @@ impl<T> DpopClient<T> {
}
}
let nonces = MemorySimpleStore::<String, String>::default();
Ok(Self {
inner: http_client,
key,
iss,
nonces,
})
Ok(Self { inner: http_client, key, iss, nonces })
}
fn build_proof(&self, htm: String, htu: String, nonce: Option<String>) -> Result<String> {
match crypto::Key::try_from(&self.key).map_err(Error::JwkCrypto)? {
Expand Down Expand Up @@ -120,8 +115,6 @@ impl<T> DpopClient<T> {
}
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl<T> HttpClient for DpopClient<T>
where
T: HttpClient + Send + Sync + 'static,
Expand All @@ -141,11 +134,8 @@ where
request.headers_mut().insert("DPoP", init_proof.parse()?);
let response = self.inner.send_http(request.clone()).await?;

let next_nonce = response
.headers()
.get("DPoP-Nonce")
.and_then(|v| v.to_str().ok())
.map(String::from);
let next_nonce =
response.headers().get("DPoP-Nonce").and_then(|v| v.to_str().ok()).map(String::from);
match &next_nonce {
Some(s) if next_nonce != init_nonce => {
// Store the fresh nonce for future requests
Expand Down
5 changes: 1 addition & 4 deletions atrium-oauth/oauth-client/src/jose/jwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ pub struct PublicClaims {

impl From<RegisteredClaims> for Claims {
fn from(registered: RegisteredClaims) -> Self {
Self {
registered,
public: PublicClaims::default(),
}
Self { registered, public: PublicClaims::default() }
}
}

Expand Down
5 changes: 1 addition & 4 deletions atrium-oauth/oauth-client/src/jose/signing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,5 @@ where
let header = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header)?);
let payload = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims)?);
let signature: Signature<_> = key.sign(format!("{header}.{payload}").as_bytes());
Ok(format!(
"{header}.{payload}.{}",
URL_SAFE_NO_PAD.encode(signature.to_bytes())
))
Ok(format!("{header}.{payload}.{}", URL_SAFE_NO_PAD.encode(signature.to_bytes())))
}
5 changes: 2 additions & 3 deletions atrium-oauth/oauth-client/src/keyset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ pub type Result<T> = core::result::Result<T, Error>;
pub struct Keyset(Vec<Jwk>);

impl Keyset {
const PREFERRED_SIGNING_ALGORITHMS: [&'static str; 9] = [
"EdDSA", "ES256K", "ES256", "PS256", "PS384", "PS512", "HS256", "HS384", "HS512",
];
const PREFERRED_SIGNING_ALGORITHMS: [&'static str; 9] =
["EdDSA", "ES256K", "ES256", "PS256", "PS384", "PS512", "HS256", "HS384", "HS512"];
pub fn public_jwks(&self) -> JwkSet {
let mut keys = Vec::with_capacity(self.0.len());
for mut key in self.0.clone() {
Expand Down
Loading

0 comments on commit 7625577

Please sign in to comment.