diff --git a/Cargo.lock b/Cargo.lock index d073f8d3..445a7163 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1077,6 +1077,21 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "compact_str" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644" +dependencies = [ + "castaway", + "cfg-if", + "itoa 1.0.14", + "rustversion", + "ryu", + "serde", + "static_assertions", +] + [[package]] name = "condtype" version = "1.3.0" @@ -2290,7 +2305,7 @@ version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a3233677ea1554a48235d81bb59d2a41654969a8e29a1316c48105fd1701693" dependencies = [ - "compact_str", + "compact_str 0.7.1", "garde_derive", "idna", "once_cell", @@ -3981,6 +3996,7 @@ version = "0.0.1-pre.6" dependencies = [ "base64-simd", "bytes", + "compact_str 0.8.0", "divan", "headers", "http", diff --git a/Cargo.toml b/Cargo.toml index 63156da8..dc93536e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,6 +112,7 @@ chrono = { version = "0.4.39", default-features = false } clap = { version = "4.5.23", features = ["derive", "wrap_help"] } color-eyre = "0.6.3" colored_json = "5.0.0" +compact_str = { version = "0.8.0", features = ["serde"] } const_format = "0.2.34" const-oid = { version = "0.9.6", features = ["db"] } cookie = { version = "0.18.1", features = ["percent-encode"] } diff --git a/lib/komainu/Cargo.toml b/lib/komainu/Cargo.toml index 394f130c..99268ab1 100644 --- a/lib/komainu/Cargo.toml +++ b/lib/komainu/Cargo.toml @@ -16,6 +16,7 @@ harness = false [dependencies] base64-simd.workspace = true bytes.workspace = true +compact_str.workspace = true http.workspace = true memchr.workspace = true serde.workspace = true diff --git a/lib/komainu/src/code_grant.rs b/lib/komainu/src/code_grant.rs index 5347d41a..910b662b 100644 --- a/lib/komainu/src/code_grant.rs +++ b/lib/komainu/src/code_grant.rs @@ -1,13 +1,8 @@ use crate::{ - error::Error, flow::pkce, params::ParamStorage, AuthInstruction, Client, ClientExtractor, -}; -use std::{ - borrow::{Borrow, Cow}, - collections::HashSet, - future::Future, - ops::Deref, - str::FromStr, + error::Error, flow::pkce, params::ParamStorage, primitive::Scopes, AuthInstruction, Client, + ClientExtractor, }; +use std::{borrow::Cow, future::Future, ops::Deref, str::FromStr}; use strum::{AsRefStr, Display}; use thiserror::Error; @@ -88,15 +83,11 @@ where return Err(GrantError::AccessDenied); } - let request_scopes = scope.split_whitespace().collect::>(); - let client_scopes = client - .scopes - .iter() - .map(Borrow::borrow) - .collect::>(); + let request_scopes = scope.split_whitespace().collect::(); + let client_scopes = client.scopes.iter().map(Deref::deref).collect::(); - if !request_scopes.is_subset(&client_scopes) { - debug!(?client_id, "scopes aren't a subset"); + if !client_scopes.can_perform(&request_scopes) { + debug!(?client_id, "client can't issue the requested scopes"); return Err(GrantError::AccessDenied); } diff --git a/lib/komainu/src/flow/authorization.rs b/lib/komainu/src/flow/authorization.rs index 18b73878..88e939a3 100644 --- a/lib/komainu/src/flow/authorization.rs +++ b/lib/komainu/src/flow/authorization.rs @@ -23,8 +23,8 @@ pub trait Issuer { pub async fn perform( req: http::Request, client_extractor: CE, - token_issuer: I, -) -> Result, flow::Error> + token_issuer: &I, +) -> Result, flow::Error> where CE: ClientExtractor, I: Issuer, @@ -70,15 +70,5 @@ where pkce.verify(code_verifier)?; } - let token = token_issuer.issue_token(&authorization).await?; - let body = sonic_rs::to_vec(&token).unwrap(); - - debug!("token successfully issued. building response"); - - let response = http::Response::builder() - .status(http::StatusCode::OK) - .body(body.into()) - .unwrap(); - - Ok(response) + token_issuer.issue_token(&authorization).await } diff --git a/lib/komainu/src/flow/refresh.rs b/lib/komainu/src/flow/refresh.rs index 658b450c..6c4426b7 100644 --- a/lib/komainu/src/flow/refresh.rs +++ b/lib/komainu/src/flow/refresh.rs @@ -19,8 +19,8 @@ pub trait Issuer { pub async fn perform( req: http::Request, client_extractor: CE, - token_issuer: I, -) -> Result, flow::Error> + token_issuer: &I, +) -> Result, flow::Error> where CE: ClientExtractor, I: Issuer, @@ -46,15 +46,5 @@ where .extract(client_id, Some(client_secret)) .await?; - let token = token_issuer.issue_token(&client, refresh_token).await?; - let body = sonic_rs::to_vec(&token).unwrap(); - - debug!("token successfully issued. building response"); - - let response = http::Response::builder() - .status(http::StatusCode::OK) - .body(body.into()) - .unwrap(); - - Ok(response) + token_issuer.issue_token(&client, refresh_token).await } diff --git a/lib/komainu/src/lib.rs b/lib/komainu/src/lib.rs index 67a22a88..34392c2e 100644 --- a/lib/komainu/src/lib.rs +++ b/lib/komainu/src/lib.rs @@ -10,6 +10,7 @@ pub mod error; pub mod extract; pub mod flow; pub mod params; +pub mod primitive; pub struct Authorization<'a> { pub code: Cow<'a, str>, diff --git a/lib/komainu/src/primitive/mod.rs b/lib/komainu/src/primitive/mod.rs new file mode 100644 index 00000000..ba2c0930 --- /dev/null +++ b/lib/komainu/src/primitive/mod.rs @@ -0,0 +1,3 @@ +mod scope; + +pub use self::scope::Scopes; diff --git a/lib/komainu/src/primitive/scope.rs b/lib/komainu/src/primitive/scope.rs new file mode 100644 index 00000000..cf58090b --- /dev/null +++ b/lib/komainu/src/primitive/scope.rs @@ -0,0 +1,82 @@ +use compact_str::CompactString; +use serde::Deserialize; +use std::{ + collections::{hash_set, HashSet}, + convert::Infallible, + str::FromStr, +}; + +#[derive(Default, Deserialize)] +#[serde(transparent)] +pub struct Scopes { + inner: HashSet, +} + +impl FromStr for Scopes { + type Err = Infallible; + + fn from_str(s: &str) -> Result { + Ok(s.split_whitespace().collect()) + } +} + +impl Scopes { + #[inline] + #[must_use] + pub fn new() -> Self { + Self::default() + } + + #[inline] + pub fn insert(&mut self, item: Item) + where + Item: Into, + { + self.inner.insert(item.into()); + } + + /// Determine whether `self` can be accessed by `resource` + /// + /// This implies that `resource` is equal to or a superset of `self` + #[inline] + #[must_use] + pub fn can_be_accessed_by(&self, resource: &Self) -> bool { + resource.inner.is_superset(&self.inner) + } + + /// Determine whether `self` is allowed to perform an action + /// for which you at least need `resource` scope + #[inline] + #[must_use] + pub fn can_perform(&self, resource: &Self) -> bool { + self.inner.is_superset(&resource.inner) + } + + #[inline] + pub fn iter(&self) -> impl Iterator { + self.inner.iter().map(CompactString::as_str) + } +} + +impl FromIterator for Scopes +where + Item: Into, +{ + #[inline] + fn from_iter>(iter: T) -> Self { + let mut collection = Self::new(); + for item in iter { + collection.insert(item.into()); + } + collection + } +} + +impl IntoIterator for Scopes { + type Item = CompactString; + type IntoIter = hash_set::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} diff --git a/lib/komainu/tests/scope.rs b/lib/komainu/tests/scope.rs new file mode 100644 index 00000000..da0ea6f1 --- /dev/null +++ b/lib/komainu/tests/scope.rs @@ -0,0 +1,47 @@ +use komainu::primitive::Scopes; +use rstest::rstest; + +#[rstest] +#[case("read", "read write")] +#[case("read write", "read write")] +#[case("read write follow", "read write follow push")] +fn can_perform(#[case] request: &str, #[case] client: &str) { + let request: Scopes = request.parse().unwrap(); + let client: Scopes = client.parse().unwrap(); + + assert!(client.can_perform(&request)); +} + +#[rstest] +#[case("read write", "read")] +#[case("read follow", "write")] +#[case("write push", "read")] +fn cant_perform(#[case] request: &str, #[case] client: &str) { + let request: Scopes = request.parse().unwrap(); + let client: Scopes = client.parse().unwrap(); + + assert!(!client.can_perform(&request)); +} + +#[rstest] +#[case("read", "read write")] +#[case("read", "read")] +#[case("follow", "read follow")] +#[case("write follow", "follow write")] +fn can_access(#[case] endpoint: &str, #[case] client: &str) { + let endpoint: Scopes = endpoint.parse().unwrap(); + let client: Scopes = client.parse().unwrap(); + + assert!(endpoint.can_be_accessed_by(&client)); +} + +#[rstest] +#[case("read write", "write")] +#[case("follow", "read write")] +#[case("write follow", "read follow")] +fn cant_access(#[case] endpoint: &str, #[case] client: &str) { + let endpoint: Scopes = endpoint.parse().unwrap(); + let client: Scopes = client.parse().unwrap(); + + assert!(!endpoint.can_be_accessed_by(&client)); +}