Skip to content

Commit

Permalink
Merge pull request #23 from OneLiteFeatherNET/feat/prompt-fetch
Browse files Browse the repository at this point in the history
Feat/prompt fetch
  • Loading branch information
Randoooom committed Jan 1, 2024
2 parents 863f4c5 + 5b9bcae commit 926f859
Show file tree
Hide file tree
Showing 15 changed files with 322 additions and 110 deletions.
6 changes: 6 additions & 0 deletions .redocly.lint-ignore.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.
# See https://redoc.ly/docs/cli/ for more information.
target/openapi.yaml:
spec:
- >-
#/components/schemas/PutFeedbackPromptFieldRequest/properties/options/nullable
8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@ license = "MIT"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
aliri = "0.6.2"
aliri_axum = "0.3.0"
aliri_clock = "0.1.4"
aliri_oauth2 = "0.10.0"
aliri_tower = "0.5.0"
async-trait = "0.1.74"
axum = "0.6.20"
chrono = { version = "0.4.31", features = ["serde"] }
derivative = "2.2.0"
envy = "0.4.2"
getset = "0.1.2"
jwt-authorizer = "0.13.0"
kanal = "0.1.0-pre8"
lazy_static = "1.4.0"
nanoid = "0.4.0"
openidconnect = "3.4.0"
paste = "1.0.14"
rbatis = "4.4.20"
rbdc-pg = { version = "4.4.19", optional = true }
Expand All @@ -39,7 +44,6 @@ version-compare = "0.1.1"
rand = "0.8.5"
reqwest = { version = "0.11.23", features = ["json"] }
test-log = "0.2.14"
openidconnect = "3.4.0"

[features]
default = ["all-databases"]
Expand Down
5 changes: 3 additions & 2 deletions Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ script = "docker compose -f testing/oidc-mock/docker-compose.yaml up -d"
script = "docker run --name postgres -e POSTGRES_PASSWORD=password -e POSTGRES_USERNAME=postgres -p 5150:5432 -d postgres && sleep 1"

[tasks.postgres_tests]
env = { DATABASE = "POSTGRES", POSTGRES_USERNAME = "postgres", POSTGRES_PASSWORD = "password", POSTGRES_ENDPOINT = "localhost:5150", POSTGRES_DATABASE = "postgres", "OIDC_DISCOVERY_URL" = "http://localhost:5151", OIDC_CLIENT_ID = "client", OIDC_CLIENT_SECRET = "secret", RUST_LOG = "debug" }
env = { DATABASE = "POSTGRES", POSTGRES_USERNAME = "postgres", POSTGRES_PASSWORD = "password", POSTGRES_ENDPOINT = "localhost:5150", POSTGRES_DATABASE = "postgres", "OIDC_DISCOVERY_URL" = "http://localhost:5151", OIDC_CLIENT_ID = "client", OIDC_CLIENT_SECRET = "secret", RUST_LOG = "DEBUG", OIDC_SCOPE = "api:feedback-fusion" }
command = "cargo"
args = [
"test",
Expand All @@ -46,7 +46,7 @@ args = [
"http_tests",
"--",
"--nocapture",
"--test-threads=1"
"--test-threads=1",
]

[tasks.postgres]
Expand All @@ -71,6 +71,7 @@ args = [
"redocly",
"lint",
"--skip-rule=no-empty-servers",
"--skip-rule=info-license-url",
"target/openapi.yaml",
]

Expand Down
20 changes: 20 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,29 @@ pub struct Config {
#[serde(default = "default_global_rate_limit")]
global_rate_limit: u64,
oidc_discovery_url: String,
#[serde(default = "default_oidc_scope_admin")]
oidc_scope_admin: String,
#[serde(default = "default_oidc_scope_write")]
oidc_scope_write: String,
#[serde(default = "default_oidc_scope_read")]
oidc_scope_read: String,
#[serde(default = "default_oidc_audience")]
oidc_audience: String
}

#[inline]
fn default_global_rate_limit() -> u64 {
10
}
#[inline]
fn default_oidc_scope_admin() -> String { "api:feedback-fusion".to_owned() }

#[inline]
fn default_oidc_scope_write() -> String { "feedback-fusion:write".to_owned() }

#[inline]
fn default_oidc_scope_read() -> String { "feedback-fusion:read".to_owned() }

#[inline]
fn default_oidc_audience() -> String { "feedback-fusion".to_owned() }

35 changes: 25 additions & 10 deletions src/docs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{
routes::v1::{prompt::*, response::*, *},
};
use std::{fs, path::Path};
use utoipa::{OpenApi, ToSchema};
use utoipa::{OpenApi, ToSchema, Modify, openapi::security::{SecurityScheme, OpenIdConnect}};

#[derive(ToSchema)]
#[aliases(
Expand Down Expand Up @@ -57,7 +57,8 @@ pub fn generate() {
get_fields,
delete_field,
post_response,
get_responses
get_responses,
fetch
),
components(
schemas(
Expand All @@ -81,19 +82,33 @@ pub fn generate() {
FeedbackTargetPage,
FeedbackPromptPage,
FeedbackPromptFieldPage,
GetFeedbackPromptResponsesResponse,
SubmitFeedbackPromptResponseRequest
GetFeedbackPromptResponsesResponseWrapper,
SubmitFeedbackPromptResponseRequest,
TextResponse,
RatingResponse
)
),
tags(
(name = "FeedbackTarget"),
(name = "FeedbackTargetPrompt"),
(name = "FeedbackTargetPromptField"),
(name = "FeedbackTargetPromptResponse"),
(name = "FeedbackPromptResponse")
)
(name = "FeedbackTarget", description = "A Target contains multiple prompts and is therefore used in order to manage multiple projects with the same instance."),
(name = "FeedbackTargetPrompt", description = "A Prompt contains multiple fields and collects the feedback for your project."),
(name = "FeedbackTargetPromptField", description = "A Field is a input prompt for the clients visiting ur website and rating ur project or whatever."),
(name = "FeedbackPromptResponse", description = "Collect responses from clients and manage / view them"),
),
modifiers(&Security)
)]
struct OpenApiSpecification;
struct Security;

impl Modify for Security {
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
if let Some(components) = openapi.components.as_mut() {
components.add_security_scheme(
"oidc",
SecurityScheme::OpenIdConnect(OpenIdConnect::new("https://your-oidc-provider.tld"))
)
}
}
}

let destination = Path::new("./target").join("openapi.yaml");
// write the spec file
Expand Down
6 changes: 6 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ pub enum FeedbackFusionError {
DatabaseError(#[from] rbatis::Error),
#[error("unauthorized")]
Unauthorized,
#[error("{0}")]
Forbidden(String),
}

impl From<ValidationErrors> for FeedbackFusionError {
Expand Down Expand Up @@ -71,6 +73,10 @@ impl IntoResponse for FeedbackFusionError {
StatusCode::UNAUTHORIZED,
Json(FeedbackFusionErrorResponse::from("Unauthorized".to_owned())),
),
FeedbackFusionError::Forbidden(error) => (
StatusCode::FORBIDDEN,
Json(FeedbackFusionErrorResponse::from(error)),
),
_ => {
error!("Error occurred while processing request: {:?}", self);

Expand Down
15 changes: 10 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ async fn main() {
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer())
.init();
debug!("{:?}", std::env::vars());

// init config
lazy_static::initialize(&CONFIG);
Expand Down Expand Up @@ -129,19 +128,25 @@ pub(crate) async fn router(connection: DatabaseConnection) -> Router {
*CONFIG.global_rate_limit(),
Duration::from_secs(1),
))
.layer(TraceLayer::new_for_http()),
.layer(TraceLayer::new_for_http())
)
}

pub mod prelude {
pub use crate::{
config::*, database::DatabaseConnection, database_request, error::*, oidc_layer, routes::*,
impl_select_page_wrapper, state::FeedbackFusionState, CONFIG, DATABASE_CONFIG,
config::*,
database::DatabaseConnection,
database_request,
error::*,
impl_select_page_wrapper,
routes::{oidc::*, *},
state::FeedbackFusionState,
CONFIG, DATABASE_CONFIG,
};
pub use axum::{
extract::{Json, Query, State},
routing::*,
Router,
};
pub use rbatis::{rbdc::JsonV, plugin::page::Page, IPageRequest};
pub use rbatis::{plugin::page::Page, rbdc::JsonV, IPageRequest};
}
15 changes: 13 additions & 2 deletions src/routes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,24 @@

use crate::prelude::*;

use aliri_tower::Oauth2Authorizer;
use rbatis::plugin::page::PageRequest;

pub mod oidc;
pub mod v1;
mod oidc;

pub async fn router(state: FeedbackFusionState) -> Router {
Router::new().nest("/v1", v1::router(state).await)
let (authorized, unauthorized) = v1::router(state).await;

// build the authority
let authority = oidc::authority().await.unwrap();
let authorizer = Oauth2Authorizer::new()
.with_claims::<OIDCClaims>()
.with_terse_error_handler();

Router::new()
.nest("/v1", authorized.layer(authorizer.jwt_layer(authority)))
.nest("/v1", unauthorized)
}

#[derive(Debug, Clone, Deserialize, IntoParams)]
Expand Down
119 changes: 106 additions & 13 deletions src/routes/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,110 @@
//DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
//OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

#[macro_export]
macro_rules! oidc_layer {
() => {{
use jwt_authorizer::{Authorizer, IntoLayer, JwtAuthorizer};

// init the oidc authorizer
let authorizer: Authorizer = JwtAuthorizer::from_oidc(CONFIG.oidc_discovery_url())
.build()
.await
.unwrap();

authorizer.into_layer()
}}; // we can add here support for scopes later :)
use std::str::FromStr;

use crate::prelude::*;

use aliri::{
jwa,
jwt::{self, CoreValidator},
};
use aliri_clock::UnixTime;
use aliri_oauth2::{Authority, HasScope, Scope};
use openidconnect::{
core::{CoreJwsSigningAlgorithm, CoreProviderMetadata},
IssuerUrl,
};

pub async fn authority() -> Result<Authority> {
// sadly aliri does not support oidc yet, so we have to do the config stuff manually :(((((
// discover the oidc endpoints
let issuer = IssuerUrl::new(CONFIG.oidc_discovery_url().clone())
.map_err(|_| FeedbackFusionError::ConfigurationError("invalid discovery url".to_owned()))?;
let metadata = CoreProviderMetadata::discover_async(
issuer.clone(),
openidconnect::reqwest::async_http_client,
)
.await
.map_err(|_| FeedbackFusionError::ConfigurationError("invalid oidc endpoint".to_owned()))?;
// extract the jwks
let jwks_url = metadata.jwks_uri().url();
// extract the algorithms
let algorithms = metadata
.id_token_signing_alg_values_supported()
.iter()
.filter_map(|key| match key {
CoreJwsSigningAlgorithm::HmacSha256 => Some(jwa::Algorithm::HS256),
CoreJwsSigningAlgorithm::HmacSha384 => Some(jwa::Algorithm::HS384),
CoreJwsSigningAlgorithm::HmacSha512 => Some(jwa::Algorithm::HS512),
CoreJwsSigningAlgorithm::EcdsaP256Sha256 => Some(jwa::Algorithm::PS256),
CoreJwsSigningAlgorithm::EcdsaP384Sha384 => Some(jwa::Algorithm::PS384),
CoreJwsSigningAlgorithm::EcdsaP521Sha512 => Some(jwa::Algorithm::PS512),
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256 => Some(jwa::Algorithm::RS256),
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha384 => Some(jwa::Algorithm::RS384),
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha512 => Some(jwa::Algorithm::RS512),
_ => None,
})
.collect::<Vec<jwa::Algorithm>>();

// build the validator
let mut validator = CoreValidator::default()
.add_allowed_audience(
jwt::Audience::from_str(CONFIG.oidc_audience().as_str())
.expect("Invalid oidc audience"),
)
.require_issuer(jwt::Issuer::from_str(issuer.as_str()).unwrap());
for algorithm in algorithms {
validator = validator.add_approved_algorithm(algorithm);
};

// build the authority
let authority = Authority::new_from_url(jwks_url.to_string(), validator)
.await
.unwrap();

Ok(authority)
}

#[derive(Debug, Clone, Deserialize)]
pub struct OIDCClaims {
iss: jwt::Issuer,
aud: jwt::Audiences,
nbf: UnixTime,
exp: UnixTime,
scope: Scope,
}

impl jwt::CoreClaims for OIDCClaims {
fn nbf(&self) -> Option<UnixTime> {
Some(self.nbf)
}
fn exp(&self) -> Option<UnixTime> {
Some(self.exp)
}
fn aud(&self) -> &jwt::Audiences {
&self.aud
}
fn iss(&self) -> Option<&jwt::IssuerRef> {
Some(&self.iss)
}
fn sub(&self) -> Option<&jwt::SubjectRef> {
None
}
}

impl HasScope for OIDCClaims {
fn scope(&self) -> &Scope {
&self.scope
}
}

pub mod scope {
aliri_axum::scope_guards! {
type Claims = super::OIDCClaims;

pub scope API = "api:feedback-fusion";
pub scope Read = ["api:feedback-fusion" || "feedback-fusion:read"];
pub scope Write = ["api:feedback-fusion" || "feedback-fusion:write"];
}
}
Loading

0 comments on commit 926f859

Please sign in to comment.