diff --git a/Cargo.toml b/Cargo.toml index 17be5eb9..7a4b49bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,6 @@ license = "MIT" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -argon2 = "0.5.2" async-trait = "0.1.74" axum = "0.6.20" chrono = { version = "0.4.31", features = ["serde"] } @@ -27,7 +26,6 @@ serde = { version = "1.0.193", features = ["derive"] } serde_json = "1.0.108" thiserror = "1.0.50" tokio = { version = "1.33.0", features = ["full"] } -totp-rs = { version = "5.4.0", features = ["qr", "gen_secret"] } tower = { version = "0.4.13", features = ["limit", "buffer"] } tower-http = { version = "0.4.4", features = ["trace"] } tracing = "0.1.39" @@ -37,6 +35,12 @@ utoipa = { version = "4.1.0", features = ["yaml", "chrono"] } validator = { version = "0.16", features = ["derive"] } version-compare = "0.1.1" +[dev-dependencies] +rand = "0.8.5" +reqwest = { version = "0.11.23", features = ["json"] } +test-log = "0.2.14" +openidconnect = "3.4.0" + [features] default = ["all-databases"] @@ -45,3 +49,5 @@ all-databases = ["postgres", "mysql"] postgres = ["rbdc-pg"] mysql = ["rbdc-mysql"] +test = [] + diff --git a/Makefile.toml b/Makefile.toml index 1a0a4ea7..165d05d1 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -35,12 +35,26 @@ script = "docker compose -f testing/oidc-mock/docker-compose.yaml up -d" script = "docker run --name postgres -e POSTGRES_PASSWORD=password -e POSTGRES_USERNAME=postgres -p 5150:5432 -d postgres && sleep 1" [tasks.postgres_tests] -env = { postgres_username = "postgres", postgres_password = "password", postgres_endpoint = "localhost:5150", postgres_database = "postgres", "oidc_discovery_url" = "http://localhost:5151/" } +env = { DATABASE = "POSTGRES", POSTGRES_USERNAME = "postgres", POSTGRES_PASSWORD = "password", POSTGRES_ENDPOINT = "localhost:5150", POSTGRES_DATABASE = "postgres", "OIDC_DISCOVERY_URL" = "http://localhost:5151", OIDC_CLIENT_ID = "client", OIDC_CLIENT_SECRET = "secret", RUST_LOG = "debug" } command = "cargo" -args = ["test", "--features", "postgres"] +args = [ + "test", + "--no-default-features", + "--features", + "postgres,test", + "--test", + "http_tests", + "--", + "--nocapture", + "--test-threads=1" +] [tasks.postgres] -run_task = { name = ["oidc-server-mock", "postgres_database", "postgres_tests"], fork = true, cleanup_task = "postgres_cleanup"} +run_task = { name = [ + "oidc-server-mock", + "postgres_database", + "postgres_tests", +], fork = true, cleanup_task = "postgres_cleanup" } [tasks.postgres_cleanup] script = "docker kill postgres;docker rm postgres;docker kill oidc-server-mock;docker rm oidc-server-mock" @@ -53,7 +67,12 @@ run_task = { name = ["postgres"] } [tasks.docs_lint] dependencies = ["docs_generate"] command = "npx" -args = ["redocly", "lint", "--skip-rule=no-empty-servers", "target/openapi.yaml"] +args = [ + "redocly", + "lint", + "--skip-rule=no-empty-servers", + "target/openapi.yaml", +] [tasks.docs_build] dependencies = ["docs_lint"] diff --git a/build.rs b/build.rs new file mode 100644 index 00000000..2770ed24 --- /dev/null +++ b/build.rs @@ -0,0 +1,4 @@ +// just placeholder to get OUT_DIR + +fn main() {} + diff --git a/src/database/drop.sql b/src/database/drop.sql index c4a17593..41e3d151 100644 --- a/src/database/drop.sql +++ b/src/database/drop.sql @@ -1,4 +1,6 @@ -DROP TABLE IF EXISTS field; -DROP TABLE IF EXISTS prompt; -DROP TABLE IF EXISTS target; +DROP TABLE IF EXISTS feedback_prompt_field_response; +DROP TABLE IF EXISTS feedback_prompt_response; +DROP TABLE IF EXISTS feedback_prompt_field; +DROP TABLE IF EXISTS feedback_prompt; +DROP TABLE IF EXISTS feedback_target; diff --git a/src/database/mod.rs b/src/database/mod.rs index ac6846c8..979431d4 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -52,7 +52,7 @@ macro_rules! database_configuration { #[derive(Debug, Clone)] pub enum DatabaseConfiguration { $( - #[cfg(feature = "" $ident:lower)] + #[cfg(feature = $scheme)] $ident($config), )* } @@ -61,8 +61,8 @@ macro_rules! database_configuration { #[inline(always)] pub fn extract() -> crate::error::Result { $( - #[cfg(feature = "" $ident:lower)] - if let Ok(config) = envy::prefixed(stringify!([<$ident:lower _>]).trim()).from_env::<$config>() { + #[cfg(feature = $scheme)] + if let Ok(config) = envy::prefixed(stringify!([<$ident:upper _>])).from_env::<$config>() { return Ok(Self::$ident(config)); } )* @@ -77,12 +77,12 @@ macro_rules! database_configuration { match self { $( - #[cfg(feature = "" $ident:lower)] + #[cfg(feature = $scheme)] Self::$ident(config) => { let url = config.to_url($scheme); connection.init($driver {}, url.as_str())?; - #[cfg(test)] + #[cfg(feature = "test")] connection.exec(include_str!("drop.sql"), vec![]).await?; // perform migrations @@ -145,3 +145,33 @@ macro_rules! database_request { }}; } +/// rbatis doesnt convert the LIMIT statements for postgres and mssql therefore we need a wrapper +/// REF: https://rbatis.github.io/rbatis.io/#/v4/?id=macros-select-page +#[macro_export] +macro_rules! impl_select_page_wrapper { + ($table:path {}) => { + impl_select_page_wrapper!($table{select_page() => ""}); + }; + ($table:path {$ident:ident ($($arg:ident: $ty:ty $(,)?)*) => $expr:expr}) => { + paste!{ + impl_select_page!($table {$ident($($arg: $ty,)* limit_sql: &str) => $expr}); + + impl $table { + pub async fn [<$ident _wrapper>](executor: &dyn rbatis::executor::Executor, page_request: &dyn rbatis::IPageRequest, $($arg: $ty,)*) -> std::result::Result, rbatis::rbdc::Error> { + + use std::ops::Deref; + let limit = page_request.page_size(); + let offset = page_request.offset(); + + match $crate::DATABASE_CONFIG.deref() { + #[cfg(feature = "postgres")] + $crate::database::DatabaseConfiguration::Postgres(_) => Self::$ident(executor, page_request, $($arg,)* format!(" LIMIT {} OFFSET {} ", limit, offset).as_str()).await, + #[allow(unreachable_patterns)] + _ => Self::$ident(executor, page_request, $($arg,)* format!(" LIMIT {},{} ", limit, offset).as_str()).await + } + } + } + } + + } +} diff --git a/src/database/postgres.sql b/src/database/postgres.sql index 7920e78f..731ee50a 100644 --- a/src/database/postgres.sql +++ b/src/database/postgres.sql @@ -1,4 +1,4 @@ -CREATE TABLE IF NOT EXISTS target ( +CREATE TABLE IF NOT EXISTS feedback_target ( id VARCHAR(32) UNIQUE NOT NULL, name VARCHAR(32) NOT NULL, description VARCHAR(255), @@ -6,22 +6,36 @@ CREATE TABLE IF NOT EXISTS target ( created_at TIMESTAMP ); -CREATE TABLE IF NOT EXISTS prompt ( +CREATE TABLE IF NOT EXISTS feedback_prompt ( id VARCHAR(32) UNIQUE NOT NULL, title VARCHAR(32) NOT NULL, - target VARCHAR(32) REFERENCES target(id) NOT NULL, + target VARCHAR(32) REFERENCES feedback_target(id) NOT NULL, active BOOLEAN NOT NULL, updated_at TIMESTAMP, created_at TIMESTAMP ); -CREATE TABLE IF NOT EXISTS field ( +CREATE TABLE IF NOT EXISTS feedback_prompt_field ( id VARCHAR(32) UNIQUE NOT NULL, title VARCHAR(255) NOT NULL, - prompt VARCHAR(32) REFERENCES prompt(id) NOT NULL, + prompt VARCHAR(32) REFERENCES feedback_prompt(id) NOT NULL, type VARCHAR(32) NOT NULL, - options BPCHAR NOT NULL, + options JSON NOT NULL, updated_at TIMESTAMP, created_at TIMESTAMP ); +CREATE TABLE IF NOT EXISTS feedback_prompt_response ( + id VARCHAR(32) UNIQUE NOT NULL, + prompt VARCHAR(32) REFERENCES feedback_prompt(id) NOT NULL, + created_at TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS feedback_prompt_field_response ( + id VARCHAR(32) UNIQUE NOT NULL, + response VARCHAR(32) REFERENCES feedback_prompt_response(id) NOT NULL, + field VARCHAR(32) REFERENCES feedback_prompt_field(id) NOT NULL, + data JSON NOT NULL +); + + diff --git a/src/database/schema/feedback/input.rs b/src/database/schema/feedback/input.rs index 565cb263..0e661257 100644 --- a/src/database/schema/feedback/input.rs +++ b/src/database/schema/feedback/input.rs @@ -20,11 +20,27 @@ //DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, //OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +use crate::prelude::*; +use rbatis::rbdc::{DateTime, JsonV}; + +use super::FeedbackPromptInputType; + #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, ToSchema)] -#[serde(tag = "type")] +#[serde(untagged)] +#[serde(rename_all = "lowercase")] pub enum FeedbackPromptInputOptions { Text(TextOptions), - Rating(RatingOptions) + Rating(RatingOptions), +} + +// TODO: gen with macro +impl PartialEq for FeedbackPromptInputType { + fn eq(&self, other: &FeedbackPromptInputOptions) -> bool { + match self { + Self::Text => matches!(other, FeedbackPromptInputOptions::Text(_)), + Self::Rating => matches!(other, FeedbackPromptInputOptions::Rating(_)), + } + } } #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, TypedBuilder, ToSchema, Validate)] @@ -33,7 +49,7 @@ pub struct TextOptions { #[validate(length(max = 255))] description: String, #[validate(length(max = 255))] - placeholder: String + placeholder: String, } #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, TypedBuilder, ToSchema, Validate)] @@ -41,6 +57,68 @@ pub struct TextOptions { pub struct RatingOptions { #[validate(length(max = 255))] description: String, - max: u8 + max: u8, +} + +#[derive( + Deserialize, Serialize, Clone, Derivative, Debug, Getters, MutGetters, TypedBuilder, ToSchema, +)] +#[derivative(PartialEq)] +#[get = "pub"] +#[get_mut = "pub"] +#[builder(field_defaults(setter(into)))] +pub struct FeedbackPromptResponse { + #[builder(default_code = r#"nanoid::nanoid!()"#)] + id: String, + prompt: String, + #[derivative(PartialEq = "ignore")] + #[builder(default)] + created_at: DateTime, } +crud!(FeedbackPromptResponse {}); +impl_select_page_wrapper!(FeedbackPromptResponse {select_page_by_prompt(prompt: &str) => "WHERE prompt = #{prompt}"}); + +#[derive( + Deserialize, Serialize, Clone, PartialEq, Debug, Getters, MutGetters, TypedBuilder, ToSchema, +)] +#[get = "pub"] +#[get_mut = "pub"] +#[builder(field_defaults(setter(into)))] +pub struct FeedbackPromptFieldResponse { + #[builder(default_code = r#"nanoid::nanoid!()"#)] + id: String, + response: String, + field: String, + #[schema(value_type = FeedbackPromptFieldData)] + data: JsonV, +} + +crud!(FeedbackPromptFieldResponse {}); + +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, ToSchema)] +#[serde(untagged)] +pub enum FeedbackPromptFieldData { + Text(TextResponse), + Rating(RatingResponse), +} + +// TODO: use macro +impl PartialEq for FeedbackPromptInputType { + fn eq(&self, other: &FeedbackPromptFieldData) -> bool { + match self { + Self::Text => matches!(other, FeedbackPromptFieldData::Text(_)), + Self::Rating => matches!(other, FeedbackPromptFieldData::Rating(_)), + } + } +} + +#[derive(Deserialize, Serialize, Clone, Debug, ToSchema, PartialEq)] +pub struct TextResponse { + data: String, +} + +#[derive(Deserialize, Serialize, Clone, Debug, ToSchema, PartialEq)] +pub struct RatingResponse { + data: u8, +} diff --git a/src/database/schema/feedback/prompt.rs b/src/database/schema/feedback/prompt.rs index ecba4f23..91827506 100644 --- a/src/database/schema/feedback/prompt.rs +++ b/src/database/schema/feedback/prompt.rs @@ -20,7 +20,8 @@ //DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, //OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -use rbatis::rbdc::DateTime; +use crate::prelude::*; +use rbatis::rbdc::{DateTime, JsonV}; use super::input::FeedbackPromptInputOptions; @@ -31,14 +32,14 @@ use super::input::FeedbackPromptInputOptions; Derivative, Debug, Getters, - MutGetters, + Setters, TypedBuilder, ToSchema, Validate, )] #[derivative(PartialEq)] #[get = "pub"] -#[get_mut = "pub"] +#[set = "pub"] #[builder(field_defaults(setter(into)))] pub struct FeedbackPrompt { #[builder(default_code = r#"nanoid::nanoid!()"#)] @@ -57,9 +58,11 @@ pub struct FeedbackPrompt { } crud!(FeedbackPrompt {}); -impl_select_page!(FeedbackPrompt {select_page_by_target(target: &str) => "`WHERE target = #{target}`"}); +impl_select!(FeedbackPrompt {select_by_id(id: &str) -> Option => "`WHERE id = #{id} LIMIT 1`"}); +impl_select_page_wrapper!(FeedbackPrompt {select_page_by_target(target: &str) => "`WHERE target = #{target}`"}); #[derive(Deserialize, Serialize, Debug, Clone, PartialEq, ToSchema)] +#[serde(rename_all = "lowercase")] pub enum FeedbackPromptInputType { Text, Rating, @@ -72,14 +75,14 @@ pub enum FeedbackPromptInputType { Derivative, Debug, Getters, - MutGetters, + Setters, TypedBuilder, ToSchema, Validate, )] #[derivative(PartialEq)] #[get = "pub"] -#[get_mut = "pub"] +#[set = "pub"] #[builder(field_defaults(setter(into)))] pub struct FeedbackPromptField { #[builder(default_code = r#"nanoid::nanoid!()"#)] @@ -88,7 +91,8 @@ pub struct FeedbackPromptField { title: String, prompt: String, r#type: FeedbackPromptInputType, - options: FeedbackPromptInputOptions, + #[schema(value_type = FeedbackPromptInputOptions)] + options: JsonV, #[builder(default)] #[derivative(PartialEq = "ignore")] updated_at: DateTime, @@ -98,4 +102,5 @@ pub struct FeedbackPromptField { } crud!(FeedbackPromptField {}); -impl_select_page!(FeedbackPromptField {select_page_by_prompt(prompt: &str) => "`WHERE prompt = #{prompt}`"}); +impl_select!(FeedbackPromptField {select_by_id(id: &str) -> Option => "`WHERE id = #{id} LIMIT 1`"}); +impl_select_page_wrapper!(FeedbackPromptField {select_page_by_prompt(prompt: &str) => "`WHERE prompt = #{prompt}`"}); diff --git a/src/database/schema/feedback/target.rs b/src/database/schema/feedback/target.rs index ca371255..1d0aa85b 100644 --- a/src/database/schema/feedback/target.rs +++ b/src/database/schema/feedback/target.rs @@ -21,11 +21,12 @@ //OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. use rbatis::rbdc::DateTime; +use crate::prelude::*; -#[derive(Deserialize, Serialize, Clone, Derivative, Debug, Getters, MutGetters, TypedBuilder, ToSchema, Validate)] +#[derive(Deserialize, Serialize, Clone, Derivative, Debug, Getters, Setters, TypedBuilder, ToSchema, Validate)] #[derivative(PartialEq)] #[get = "pub"] -#[get_mut = "pub"] +#[set = "pub"] #[builder(field_defaults(setter(into)))] pub struct FeedbackTarget { #[builder(default_code = r#"nanoid::nanoid!()"#)] @@ -44,5 +45,4 @@ pub struct FeedbackTarget { crud!(FeedbackTarget {}); impl_select!(FeedbackTarget {select_by_id(id: &str) -> Option => "`WHERE id = #{id} LIMIT 1`"}); -impl_select_page!(FeedbackTarget {select_page(query: &str) => "`WHERE name LIKE '%#{query}%' ORDER BY created_at DESC`"}); - +impl_select_page_wrapper!(FeedbackTarget {select_page(query: &str) => "`WHERE name ILIKE COALESCE('%' || NULLIF(#{query}, '') || '%', '%%')`"}); diff --git a/src/docs.rs b/src/docs.rs index 1f010a1d..d278e8b7 100644 --- a/src/docs.rs +++ b/src/docs.rs @@ -20,12 +20,21 @@ //DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, //OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -use crate::{database::schema, routes}; +use crate::{ + database::schema::feedback::*, + routes::v1::{prompt::*, response::*, *}, +}; use std::{fs, path::Path}; use utoipa::{OpenApi, ToSchema}; #[derive(ToSchema)] -pub struct PageResult ToSchema<'a>> { +#[aliases( + FeedbackTargetPage = Page, + FeedbackPromptPage = Page, + FeedbackPromptFieldPage = Page + +)] +pub struct Page ToSchema<'a>> { records: Vec, total: u64, page_no: u64, @@ -35,25 +44,53 @@ pub fn generate() { #[derive(OpenApi)] #[openapi( paths( - routes::feedback::post_target, - routes::feedback::get_targets, - routes::feedback::prompt::post_prompt, - routes::feedback::prompt::get_prompts, - routes::feedback::prompt::put_prompt, - routes::feedback::prompt::delete_prompt, - routes::feedback::prompt::post_field, - routes::feedback::prompt::put_field, - routes::feedback::prompt::get_fields, - routes::feedback::prompt::delete_field, + post_target, + get_targets, + put_target, + delete_target, + post_prompt, + get_prompts, + put_prompt, + delete_prompt, + post_field, + put_field, + get_fields, + delete_field, + post_response, + get_responses ), components( schemas( - schema::feedback::FeedbackTarget, - routes::feedback::CreateFeedbackTargetRequest, + FeedbackTarget, + PutFeedbackTargetRequest, + FeedbackPrompt, + PutFeedbackPromptRequest, + FeedbackPromptField, + PutFeedbackPromptFieldRequest, + FeedbackPromptInputType, + FeedbackPromptField, + FeedbackPromptInputOptions, + TextOptions, + RatingOptions, + FeedbackPromptResponse, + FeedbackPromptFieldResponse, + FeedbackPromptFieldData, + CreateFeedbackTargetRequest, + CreateFeedbackPromptRequest, + CreateFeedbackPromptFieldRequest, + FeedbackTargetPage, + FeedbackPromptPage, + FeedbackPromptFieldPage, + GetFeedbackPromptResponsesResponse, + SubmitFeedbackPromptResponseRequest ) ), tags( - (name = "FeedbackTarget") + (name = "FeedbackTarget"), + (name = "FeedbackTargetPrompt"), + (name = "FeedbackTargetPromptField"), + (name = "FeedbackTargetPromptResponse"), + (name = "FeedbackPromptResponse") ) )] struct OpenApiSpecification; diff --git a/src/main.rs b/src/main.rs index 6b1b1587..f6fedef6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -71,15 +71,16 @@ async fn main() { #[cfg(not(feature = "docs"))] { - // init config - lazy_static::initialize(&CONFIG); - lazy_static::initialize(&DATABASE_CONFIG); - // init the tracing subscriber with the `RUST_LOG` env filter tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::from_default_env()) .with(tracing_subscriber::fmt::layer()) .init(); + debug!("{:?}", std::env::vars()); + + // init config + lazy_static::initialize(&CONFIG); + lazy_static::initialize(&DATABASE_CONFIG); let (sender, receiver) = kanal::oneshot_async::<()>(); let address = SocketAddr::from(([0, 0, 0, 0], 8000)); @@ -97,6 +98,7 @@ async fn main() { .await .unwrap(); }); + info!("Listening for incoming requests"); match tokio::signal::ctrl_c().await { Ok(()) => {} @@ -110,10 +112,10 @@ async fn main() { } } -async fn router(connection: DatabaseConnection) -> Router { +pub(crate) async fn router(connection: DatabaseConnection) -> Router { let state = FeedbackFusionState::new(connection); - Router::new().nest("/", routes::router(state).await).layer( + routes::router(state).await.layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|error: BoxError| async move { ( @@ -133,18 +135,13 @@ async fn router(connection: DatabaseConnection) -> Router { pub mod prelude { pub use crate::{ - config::*, - database::DatabaseConnection, - database_request, - error::*, - oidc_layer, - routes::*, - state::FeedbackFusionState, - CONFIG, DATABASE_CONFIG, + config::*, database::DatabaseConnection, database_request, error::*, oidc_layer, routes::*, + impl_select_page_wrapper, state::FeedbackFusionState, CONFIG, DATABASE_CONFIG, }; pub use axum::{ extract::{Json, Query, State}, routing::*, Router, }; + pub use rbatis::{rbdc::JsonV, plugin::page::Page, IPageRequest}; } diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 92abda7e..1ea72af5 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -21,25 +21,31 @@ //OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. use crate::prelude::*; -use rbatis::sql::PageRequest; -pub mod feedback; +use rbatis::plugin::page::PageRequest; + +pub mod v1; mod oidc; pub async fn router(state: FeedbackFusionState) -> Router { - Router::new().nest("/feedback", feedback::router(state).await) + Router::new().nest("/v1", v1::router(state).await) } #[derive(Debug, Clone, Deserialize, IntoParams)] +#[into_params(parameter_in = Query)] pub struct SearchQuery { #[serde(default)] + #[param(nullable)] query: String, } #[derive(Debug, Clone, Deserialize, IntoParams)] +#[into_params(parameter_in = Query)] pub struct Pagination { + #[param(default = 1)] #[serde(default = "page")] page: usize, + #[param(default = 20)] #[serde(default = "page_size")] page_size: usize, } @@ -56,4 +62,10 @@ impl Pagination { pub fn request(self) -> PageRequest { PageRequest::new(self.page as u64, self.page_size as u64) } + + pub fn eval(&self) -> (u64, u64) { + let request = self.clone().request(); + + (request.page_size(), request.offset()) + } } diff --git a/src/routes/feedback/mod.rs b/src/routes/v1/mod.rs similarity index 63% rename from src/routes/feedback/mod.rs rename to src/routes/v1/mod.rs index db47cd34..105ffa20 100644 --- a/src/routes/feedback/mod.rs +++ b/src/routes/v1/mod.rs @@ -21,28 +21,34 @@ //OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. use axum::{extract::Path, http::StatusCode}; -use rbatis::sql::Page; use validator::Validate; use crate::{database::schema::feedback::FeedbackTarget, prelude::*}; pub mod prompt; +pub mod response; pub async fn router(state: FeedbackFusionState) -> Router { Router::new() .route( "/target", - post(post_target) - .put(put_target) - .get(get_targets) - .get(get_target) + post(post_target).get(get_targets).layer(oidc_layer!()), + ) + .route( + "/target/:target", + get(get_target) .delete(delete_target) + .put(put_target) .layer(oidc_layer!()), ) .nest( "/target/:target/prompt", prompt::router(state.clone()).await, ) + .nest( + "/target/:target/prompt/:prompt/response", + response::router(state.clone()).await, + ) .with_state(state) } @@ -53,8 +59,8 @@ pub struct CreateFeedbackTargetRequest { description: Option, } -/// POST /feedback/target -#[utoipa::path(post, path = "/feedback/target", request_body = CreateFeedbackTargetRequest, tag = "FeedbackTarget", responses( +/// POST /v1/target +#[utoipa::path(post, path = "/v1/target", request_body = CreateFeedbackTargetRequest, tag = "FeedbackTarget", responses( (status = 201, description = "Target created", body = FeedbackTarget) ))] pub async fn post_target( @@ -77,9 +83,9 @@ pub async fn post_target( Ok((StatusCode::CREATED, Json(target))) } -/// GET /feedback/target -#[utoipa::path(get, path = "/feedback/target", params(SearchQuery, Pagination ), tag = "FeedbackTarget", responses( - (status = 200, description = "Page of Targets", body = PageResult) +/// GET /v1/target +#[utoipa::path(get, path = "/v1/target", params(SearchQuery, Pagination ), tag = "FeedbackTarget", responses( + (status = 200, description = "Page of Targets", body = FeedbackTargetPage) ))] pub async fn get_targets( State(state): State, @@ -90,24 +96,29 @@ pub async fn get_targets( // fetch the Page let page = database_request!( - FeedbackTarget::select_page(connection, &pagination.request(), search.query.as_str()) - .await? + FeedbackTarget::select_page_wrapper( + connection, + &pagination.request(), + search.query.as_str() + ) + .await? ); Ok(Json(page)) } -/// GET /feedback/target/:id -#[utoipa::path(get, path = "/feedback/target/:id", tag = "FeedbackTarget", responses( +/// GET /v1/target/:target +#[utoipa::path(get, path = "/v1/target/:id", tag = "FeedbackTarget", responses( (status = 200, description = "Target", body = FeedbackTarget), (status = 400, description = "Target not found") ))] pub async fn get_target( State(state): State, - Path(id): Path, + Path(target): Path, ) -> Result> { let connection = state.connection(); - let target = database_request!(FeedbackTarget::select_by_id(connection, id.as_str()).await?); + let target = + database_request!(FeedbackTarget::select_by_id(connection, target.as_str()).await?); match target { Some(target) => Ok(Json(target)), None => Err(FeedbackFusionError::BadRequest( @@ -116,32 +127,48 @@ pub async fn get_target( } } -/// PUT /feedback/target -#[utoipa::path(put, path = "/feedback/target", request_body = FeedbackTarget, tag = "FeedbackTarget", responses( +#[derive(Clone, Debug, Deserialize, ToSchema, Validate)] +pub struct PutFeedbackTargetRequest { + #[validate(length(max = 255))] + name: Option, + #[validate(length(max = 255))] + description: Option, +} + +/// PUT /v1/target/:target +#[utoipa::path(put, path = "/v1/target/:target", request_body = PutFeedbackTargetRequest, tag = "FeedbackTarget", responses( (status = 200, description = "Updated", body = FeedbackTarget) ))] pub async fn put_target( State(state): State, - Json(target): Json, + Path(target): Path, + Json(data): Json, ) -> Result> { - let connection = state.connection(); + data.validate()?; - target.validate()?; + let mut target = database_request!(FeedbackTarget::select_by_id( + state.connection(), + target.as_str() + ) + .await? + .ok_or(FeedbackFusionError::BadRequest("not found".to_owned()))?); + target.set_name(data.name.unwrap_or(target.name().clone())); + target.set_description(data.description.or(target.description().clone())); - database_request!(FeedbackTarget::update_by_column(connection, &target, "id").await?); + database_request!(FeedbackTarget::update_by_column(state.connection(), &target, "id").await?); Ok(Json(target)) } -/// DELETE /feedback/target/:id -#[utoipa::path(delete, path = "/feedback/target/:id", tag = "FeedbackTarget", responses( +/// DELETE /v1/target/:target +#[utoipa::path(delete, path = "/v1/target/:target", tag = "FeedbackTarget", responses( (status = 200, description = "Deleted") ))] pub async fn delete_target( State(state): State, - Path(id): Path, + Path(target): Path, ) -> Result { let connection = state.connection(); - database_request!(FeedbackTarget::delete_by_column(connection, "id", id.as_str()).await?); + database_request!(FeedbackTarget::delete_by_column(connection, "id", target.as_str()).await?); Ok(StatusCode::OK) } diff --git a/src/routes/feedback/prompt.rs b/src/routes/v1/prompt.rs similarity index 57% rename from src/routes/feedback/prompt.rs rename to src/routes/v1/prompt.rs index ae57349a..4e3f6e31 100644 --- a/src/routes/feedback/prompt.rs +++ b/src/routes/v1/prompt.rs @@ -1,4 +1,3 @@ -//SPDX-FileCopyrightText: 2023 OneLiteFeatherNet //SPDX-License-Identifier: MIT //MIT License @@ -21,7 +20,6 @@ //OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. use axum::{extract::Path, http::StatusCode}; -use rbatis::sql::Page; use validator::Validate; use crate::{ @@ -33,13 +31,10 @@ use crate::{ pub async fn router(state: FeedbackFusionState) -> Router { Router::new() - .route("/", post(post_prompt).get(get_prompts).put(put_prompt)) - .route("/:prompt", delete(delete_prompt)) - .route( - "/:prompt/field", - post(post_field).put(put_field).get(get_fields), - ) - .route("/:prompt/field/:field", delete(delete_field)) + .route("/", post(post_prompt).get(get_prompts)) + .route("/:prompt", delete(delete_prompt).put(put_prompt)) + .route("/:prompt/field", post(post_field).get(get_fields)) + .route("/:prompt/field/:field", delete(delete_field).put(put_field)) .layer(oidc_layer!()) .with_state(state) } @@ -52,10 +47,10 @@ pub struct CreateFeedbackPromptRequest { active: bool, } -/// POST /feedback/target/:target/prompt -#[utoipa::path(post, path = "/feedback/target/:target/prompt", request_body = CreateFeedbackPromptRequest, responses( +/// POST /v1/target/:target/prompt +#[utoipa::path(post, path = "/v1/target/:target/prompt", request_body = CreateFeedbackPromptRequest, responses( (status = 201, body = FeedbackPrompt) -))] +), tag = "FeedbackTargetPrompt")] pub async fn post_prompt( State(state): State, Path(target): Path, @@ -74,17 +69,17 @@ pub async fn post_prompt( Ok((StatusCode::CREATED, Json(prompt))) } -/// GET /feedback/target/:target/prompt -#[utoipa::path(get, path = "/feedback/target/:target/prompt", params(Pagination), responses( - (status = 200, body = Page) -))] +/// GET /v1/target/:target/prompt +#[utoipa::path(get, path = "/v1/target/:target/prompt", params(Pagination), responses( + (status = 200, body = FeedbackPromptPage) +), tag = "FeedbackTargetPrompt")] pub async fn get_prompts( State(state): State, Query(pagination): Query, Path(target): Path, ) -> Result>> { let prompts = database_request!( - FeedbackPrompt::select_page_by_target( + FeedbackPrompt::select_page_by_target_wrapper( state.connection(), &pagination.request(), target.as_str(), @@ -95,24 +90,41 @@ pub async fn get_prompts( Ok(Json(prompts)) } -/// PUT /feedback/target/:target/prompt -#[utoipa::path(put, path = "/feedback/target/:target/prompt", request_body = FeedbackPrompt, responses( +#[derive(Deserialize, Debug, Clone, ToSchema, Validate)] +pub struct PutFeedbackPromptRequest { + #[validate(length(max = 255))] + title: Option, + active: Option, +} + +/// PUT /v1/target/:target/prompt/:prompt +#[utoipa::path(put, path = "/v1/target/:target/prompt/:prompt", request_body = PutFeedbackPromptRequest, responses( (status = 200, body = FeedbackPrompt) -))] +), tag = "FeedbackTargetPrompt")] pub async fn put_prompt( State(state): State, - Json(prompt): Json, + Path((_, prompt)): Path<(String, String)>, + Json(data): Json, ) -> Result> { - prompt.validate()?; + data.validate()?; + let mut prompt = database_request!(FeedbackPrompt::select_by_id( + state.connection(), + prompt.as_str() + ) + .await? + .ok_or(FeedbackFusionError::BadRequest("not found".to_owned()))?); + + prompt.set_title(data.title.unwrap_or(prompt.title().clone())); + prompt.set_active(data.active.unwrap_or(*prompt.active())); database_request!(FeedbackPrompt::update_by_column(state.connection(), &prompt, "id").await?); Ok(Json(prompt)) } -/// DELETE /feedback/target/:target/prompt/:prompt -#[utoipa::path(delete, path = "/feedback/target/:target/prompt/:prompt", responses( +/// DELETE /v1/target/:target/prompt/:prompt +#[utoipa::path(delete, path = "/v1/target/:target/prompt/:prompt", responses( (status = 200, description = "Deleted") -))] +), tag = "FeedbackTargetPrompt")] pub async fn delete_prompt( State(state): State, Path((_, prompt)): Path<(String, String)>, @@ -125,27 +137,33 @@ pub async fn delete_prompt( #[derive(Debug, Clone, ToSchema, Deserialize, Validate)] pub struct CreateFeedbackPromptFieldRequest { + #[validate(length(max = 255))] title: String, r#type: FeedbackPromptInputType, options: FeedbackPromptInputOptions, } -/// POST /feedback/target/:target/prompt/:prompt/field -#[utoipa::path(post, path = "/feedback/target/:target/prompt/:prompt/field", request_body = CreateFeedbackPromptFieldRequest, responses( +/// POST /v1/target/:target/prompt/:prompt/field +#[utoipa::path(post, path = "/v1/target/:target/prompt/:prompt/field", request_body = CreateFeedbackPromptFieldRequest, responses( (status = 201, description = "Created", body = FeedbackPromptField) -))] +), tag = "FeedbackTargetPromptField")] pub async fn post_field( State(state): State, Path((_, prompt)): Path<(String, String)>, Json(data): Json, ) -> Result<(StatusCode, Json)> { data.validate()?; + // validate type and enum + if !data.r#type.eq(&data.options) { + return Err(FeedbackFusionError::BadRequest("type does not match".to_owned())); + }; + // build the field let field = FeedbackPromptField::builder() .title(data.title) .r#type(data.r#type) - .options(data.options) + .options(JsonV(data.options)) .prompt(prompt) .build(); database_request!(FeedbackPromptField::insert(state.connection(), &field).await?); @@ -153,17 +171,17 @@ pub async fn post_field( Ok((StatusCode::CREATED, Json(field))) } -/// GET /feedback/target/:target/prompt/:prompt/field -#[utoipa::path(get, path = "/feedback/target/:target/prompt/:prompt/field", params(Pagination), responses( - (status = 200, body = Page) -))] +/// GET /v1/target/:target/prompt/:prompt/field +#[utoipa::path(get, path = "/v1/target/:target/prompt/:prompt/field", params(Pagination), responses( + (status = 200, body = FeedbackPromptFieldPage) +), tag = "FeedbackTargetPromptField")] pub async fn get_fields( State(state): State, Query(pagination): Query, Path((_, prompt)): Path<(String, String)>, ) -> Result>> { let page = database_request!( - FeedbackPromptField::select_page_by_prompt( + FeedbackPromptField::select_page_by_prompt_wrapper( state.connection(), &pagination.request(), prompt.as_str() @@ -174,26 +192,53 @@ pub async fn get_fields( Ok(Json(page)) } -/// PUT /feedback/target/:target/prompt/:prompt/field -#[utoipa::path(put, path = "/feedback/target/:target/prompt/:prompt/field", request_body = FeedbackPromptField, responses( +#[derive(Debug, Clone, Deserialize, Validate, ToSchema)] +pub struct PutFeedbackPromptFieldRequest { + #[validate(length(max = 255))] + title: Option, + options: Option, +} + +/// PUT /v1/target/:target/prompt/:prompt/field/:field +#[utoipa::path(put, path = "/v1/target/:target/prompt/:prompt/field/:field", request_body = PutFeedbackPromptFieldRequest, responses( (status = 200, body = FeedbackPromptField, description = "updated") -))] +), tag = "FeedbackTargetPromptField")] pub async fn put_field( State(state): State, - Json(data): Json, + Path((_, _, field)): Path<(String, String, String)>, + Json(data): Json, ) -> Result> { data.validate()?; + + let mut field = database_request!(FeedbackPromptField::select_by_id( + state.connection(), + field.as_str() + ) + .await? + .ok_or(FeedbackFusionError::BadRequest("not found".to_owned()))?); + // validate type and enum + if data.options.as_ref().is_some_and(|options|!field.r#type().eq(options)) { + return Err(FeedbackFusionError::BadRequest("type does not match".to_owned())); + }; + + + field.set_title(data.title.unwrap_or(field.title().to_string())); + if let Some(options) = data.options { + if field.r#type().eq(&options) { + field.set_options(JsonV(options)); + } + } database_request!( - FeedbackPromptField::update_by_column(state.connection(), &data, "id").await? + FeedbackPromptField::update_by_column(state.connection(), &field, "id").await? ); - Ok(Json(data)) + Ok(Json(field)) } -/// DELETE /feedback/target/:target/prompt/:prompt/field/:field -#[utoipa::path(delete, path = "/feedback/target/:target/prompt/:prompt/field/:field", responses( +/// DELETE /v1/target/:target/prompt/:prompt/field/:field +#[utoipa::path(delete, path = "/v1/target/:target/prompt/:prompt/field/:field", responses( (status = 200, description = "Deleted") -))] +), tag = "FeedbackTargetPromptField")] pub async fn delete_field( State(state): State, Path((_, _, field)): Path<(String, String, String)>, diff --git a/src/routes/v1/response.rs b/src/routes/v1/response.rs new file mode 100644 index 00000000..c68de8e6 --- /dev/null +++ b/src/routes/v1/response.rs @@ -0,0 +1,167 @@ +//SPDX-FileCopyrightText: 2023 OneLiteFeatherNet +//SPDX-License-Identifier: MIT + +//MIT License + +// Copyright (c) 2023 OneLiteFeatherNet + +//Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +//associated documentation files (the "Software"), to deal in the Software without restriction, +//including without limitation the rights to use, copy, modify, merge, publish, distribute, +//sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is +//furnished to do so, subject to the following conditions: + +//The above copyright notice and this permission notice (including the next paragraph) shall be +//included in all copies or substantial portions of the Software. + +//THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +//NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +//NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +//DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +//OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +use std::collections::HashMap; + +use crate::{ + database::schema::feedback::{ + FeedbackPromptField, FeedbackPromptFieldData, FeedbackPromptFieldResponse, + FeedbackPromptResponse, + }, + prelude::*, +}; + +use axum::{extract::Path, http::StatusCode}; +use rbatis::rbatis_codegen::IntoSql; + +pub async fn router(state: FeedbackFusionState) -> Router { + Router::new() + .route("/", post(post_response)) + .route("/", get(get_responses).layer(oidc_layer!())) + .with_state(state) +} + +#[derive(Deserialize, Clone, Debug, ToSchema)] +pub struct SubmitFeedbackPromptResponseRequest { + responses: HashMap, +} + +/// POST /v1/target/:target/prompt/:prompt/response +#[utoipa::path(post, path = "/v1/target/:target/prompt/:prompt/response", request_body = SubmitFeedbackPromptResponseRequest, responses( + (status = 200, description = "Created", body = FeedbackPromptResponse) +), tag = "FeedbackPromptResponse")] +pub async fn post_response( + State(state): State, + Path((_, prompt)): Path<(String, String)>, + Json(data): Json, +) -> Result<(StatusCode, Json)> { + // start transaction + let mut transaction = state.connection().acquire_begin().await?; + // fetch the fields of the prompt + let fields = database_request!( + FeedbackPromptField::select_by_column(&transaction, "prompt", prompt.as_str()).await? + ); + // as we can assume a prompt has to have at least 1 field we can throw the 400 here + if fields.is_empty() { + return Err(FeedbackFusionError::BadRequest("invalid prompt".to_owned())); + } + + // insert the response dataprompt + let response = FeedbackPromptResponse::builder().prompt(prompt).build(); + database_request!(FeedbackPromptResponse::insert(&transaction, &response).await?); + + // transform the hashmap into a field data vec + let data = data + .responses + .into_iter() + .filter_map(|(field, value)| { + // validate the type of field and response + if fields + .iter() + .any(|f| field.eq(f.id()) && f.r#type().eq(&value)) + { + Some( + FeedbackPromptFieldResponse::builder() + .response(response.id().as_str()) + .field(field) + .data(JsonV(value)) + .build(), + ) + } else { + None + } + }) + .collect::>(); + // insert them as batch + database_request!( + FeedbackPromptFieldResponse::insert_batch(&transaction, data.as_slice(), data.len() as u64) + .await? + ); + + // commit the transaction + transaction.commit().await?; + + Ok((StatusCode::CREATED, Json(response))) +} + +pub type GetFeedbackPromptResponsesResponse = HashMap>; +#[derive(Deserialize, Debug, Clone)] +struct DatabaseResult { + result: JsonV, +} + +#[py_sql( + "`SELECT jsonb_object_agg(response, rows) AS RESULT FROM (` + `SELECT response, ` + `jsonb_agg(jsonb_build_object('id', id, 'response', response, 'field', field, 'data', data)) AS ROWS ` + `FROM feedback_prompt_field_response ` + ` WHERE response IN ` + ${responses.sql()} + ` GROUP BY response) subquery`" +)] +async fn group_field_responses( + rb: &dyn rbatis::executor::Executor, + responses: &[String], +) -> rbatis::Result { + impled!() +} + +/// GET /v1/target/:target/prompt/:prompt/response +#[utoipa::path(get, path = "/v1/target/:target/prompt/:prompt/response", params(Pagination), responses( + (status = 200, body = GetFeedbackPromptResponsesResponse) +), tag = "FeedbackPromptResponse")] +pub async fn get_responses( + State(state): State, + Path((_, prompt)): Path<(String, String)>, + Query(pagination): Query, +) -> Result> { + // select a page of responses + let responses = database_request!( + FeedbackPromptResponse::select_page_by_prompt_wrapper( + state.connection(), + &pagination.request(), + prompt.as_str() + ) + .await? + ); + + let records = if responses.total > 0 { + database_request!( + group_field_responses( + state.connection(), + responses + .records + .iter() + .map(|response| response.id().clone()) + .collect::>() + .as_slice(), + ) + .await? + .result + .0 + ) + } else { + HashMap::new() + }; + + Ok(Json(records)) +} diff --git a/testing/oidc-mock/clients-config.json b/testing/oidc-mock/clients-config.json index 6258b29c..b242ba8e 100644 --- a/testing/oidc-mock/clients-config.json +++ b/testing/oidc-mock/clients-config.json @@ -4,7 +4,7 @@ "ClientSecrets": ["secret"], "Description": "mock client", "AllowedGrantTypes": ["client_credentials"], - "AllowedScopes": ["openid"], + "AllowedScopes": ["api"], "ClientClaimsPrefix": "" } ] diff --git a/testing/oidc-mock/docker-compose.yaml b/testing/oidc-mock/docker-compose.yaml index 9f4b5809..3d4aab0a 100644 --- a/testing/oidc-mock/docker-compose.yaml +++ b/testing/oidc-mock/docker-compose.yaml @@ -26,6 +26,8 @@ services: { "AutomaticRedirectAfterSignOut": true } + API_SCOPES_INLINE: | + - Name: api USERS_CONFIGURATION_INLINE: | [ { diff --git a/tests/common.rs b/tests/common.rs new file mode 100644 index 00000000..51b6ef1d --- /dev/null +++ b/tests/common.rs @@ -0,0 +1,119 @@ +//SPDX-FileCopyrightText: 2023 OneLiteFeatherNet +//SPDX-License-Identifier: MIT + +//MIT License + +// Copyright (c) 2023 OneLiteFeatherNet + +//Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +//associated documentation files (the "Software"), to deal in the Software without restriction, +//including without limitation the rights to use, copy, modify, merge, publish, distribute, +//sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is +//furnished to do so, subject to the following conditions: + +//The above copyright notice and this permission notice (including the next paragraph) shall be +//included in all copies or substantial portions of the Software. + +//THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +//NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +//NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +//DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +//OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +use openidconnect::{ + core::{CoreClient, CoreProviderMetadata}, + reqwest::async_http_client, + ClientId, ClientSecret, IssuerUrl, OAuth2TokenResponse, Scope, +}; +use reqwest::{ + header::{HeaderMap, HeaderValue}, + Client, +}; +use std::{ + fs::File, + path::Path, + process::{Child, Command, Stdio}, +}; +use tracing::debug; + +pub const HTTP_ENDPOINT: &'static str = "http://localhost:8000"; + +pub struct BackendServer(Child); + +impl Drop for BackendServer { + fn drop(&mut self) { + let _ = self.0.kill(); + } +} + +pub fn run_server() -> BackendServer { + // construct the executable path + let mut path = std::env::current_exe().unwrap(); + assert!(path.pop()); + assert!(path.pop()); + path = path.join(env!("CARGO_PKG_NAME")); + + // prepare the command + let mut command = Command::new(path); + let seed = rand::random::(); + let stdout = Stdio::from(File::create(Path::new(env!("OUT_DIR")).join(format!("{}stdout", seed))).unwrap()); + let stderr = Stdio::from(File::create(Path::new(env!("OUT_DIR")).join(format!("{}stderr", seed))).unwrap()); + debug!("OUT={} SEED={}", env!("OUT_DIR"), seed); + + command.stdin(Stdio::piped()); + command.stdout(stdout); + command.stderr(stderr); + + command.env_clear(); + let database = env!("DATABASE"); + let mut env = vec!["_USERNAME", "_PASSWORD", "_ENDPOINT", "_DATABASE"] + .into_iter() + .map(|s| format!("{}{}", database, s)) + .collect::>(); + env.push("OIDC_DISCOVERY_URL".to_owned()); + + for key in env.iter() { + if let Ok(value) = std::env::var(key) { + debug!("{:?}: {:?}", key, value); + command.env(key, value); + } + } + command.env("RUST_LOG", "DEBUG"); + + let child = command.spawn().unwrap(); + std::thread::sleep(std::time::Duration::from_secs(1)); + + BackendServer(child) +} + +pub async fn authenticate() -> String { + let issuer = IssuerUrl::new(env!("OIDC_DISCOVERY_URL").to_owned()).unwrap(); + let metadata = CoreProviderMetadata::discover_async(issuer, async_http_client) + .await + .unwrap(); + let client = CoreClient::from_provider_metadata( + metadata, + ClientId::new(env!("OIDC_CLIENT_ID").to_owned()), + Some(ClientSecret::new(env!("OIDC_CLIENT_SECRET").to_owned())), + ); + + let token_response = client + .exchange_client_credentials() + .add_scope(Scope::new("api".to_owned())) + .request_async(async_http_client) + .await + .unwrap(); + + token_response.access_token().secret().clone() +} + +pub async fn client() -> Client { + let access_token = authenticate().await; + + let mut headers = HeaderMap::new(); + headers.insert( + "Authorization", + HeaderValue::from_str(format!("Bearer {}", access_token).as_str()).unwrap(), + ); + Client::builder().default_headers(headers).build().unwrap() +} diff --git a/tests/http_tests.rs b/tests/http_tests.rs new file mode 100644 index 00000000..f0dd162f --- /dev/null +++ b/tests/http_tests.rs @@ -0,0 +1,670 @@ +//SPDX-FileCopyrightText: 2023 OneLiteFeatherNet +//SPDX-License-Identifier: MIT + +//MIT License + +// Copyright (c) 2023 OneLiteFeatherNet + +//Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +//associated documentation files (the "Software"), to deal in the Software without restriction, +//including without limitation the rights to use, copy, modify, merge, publish, distribute, +//sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is +//furnished to do so, subject to the following conditions: + +//The above copyright notice and this permission notice (including the next paragraph) shall be +//included in all copies or substantial portions of the Software. + +//THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +//NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +//NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +//DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +//OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +use common::*; +use rbatis::sql::Page; +use reqwest::StatusCode; +use serde::Deserialize; +use test_log::test; + +mod common; + +#[derive(Debug, Clone, Deserialize, PartialEq)] +struct TargetResponse { + id: String, + name: String, +} + +#[derive(Debug, Clone, Deserialize, PartialEq)] +struct PromptResponse { + id: String, + target: String, + active: bool, + title: String, +} + +#[derive(Debug, Clone, Deserialize, PartialEq)] +struct FieldResponse { + id: String, + prompt: String, + title: String, +} + +#[test(tokio::test)] +async fn test_target_endpoints() { + let _server = run_server(); + let client = client().await; + + // test auth + { + let response = reqwest::Client::default() + .post(format!("{}/v1/target", HTTP_ENDPOINT)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + let response = reqwest::Client::default() + .get(format!("{}/v1/target", HTTP_ENDPOINT)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + let response = reqwest::Client::default() + .put(format!("{}/v1/target", HTTP_ENDPOINT)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + let response = reqwest::Client::default() + .delete(format!("{}/v1/target/awdawd", HTTP_ENDPOINT)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + } + + // test creation + let target = { + let response = client + .post(format!("{}/v1/target", HTTP_ENDPOINT)) + .json(&serde_json::json!({ + "name": "Name", + "description": "Description" + })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::CREATED, response.status()); + + let data = response.json::().await; + assert!(data.is_ok()); + + data.unwrap() + }; + + // test get by id endpoint + { + let response = client + .get(format!("{}/v1/target/{}", HTTP_ENDPOINT, &target.id)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let data = response.json::().await.unwrap(); + assert_eq!(&target, &data); + } + + // test get page endpoint + { + let response = client + .get(format!("{}/v1/target", HTTP_ENDPOINT)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let data = response.json::>().await.unwrap(); + assert_eq!(1, data.records.len()); + let first = data.records.first().unwrap(); + assert_eq!(&target, first); + } + + // test put endpoint + { + let response = client + .put(format!("{}/v1/target/{}", HTTP_ENDPOINT, &target.id)) + .json(&serde_json::json!({ + "name": "updated" + })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let response = client + .get(format!("{}/v1/target/{}", HTTP_ENDPOINT, &target.id)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let data = response.json::().await.unwrap(); + assert_eq!("updated", data.name.as_str()); + } + + // test delete endpoint + { + let response = client + .delete(format!("{}/v1/target/{}", HTTP_ENDPOINT, &target.id)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let response = client + .get(format!("{}/v1/target/{}", HTTP_ENDPOINT, &target.id)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::BAD_REQUEST, response.status()); + } +} + +#[test(tokio::test)] +async fn test_prompt_endpoints() { + let _server = run_server(); + let client = client().await; + // prepare dependencies + let response = client + .post(format!("{}/v1/target", HTTP_ENDPOINT)) + .json(&serde_json::json!({ + "name": "Name" + })) + .send() + .await + .unwrap(); + let target = response.json::().await.unwrap(); + + // test auth + { + let response = reqwest::Client::default() + .post(format!("{}/v1/target/test/prompt", HTTP_ENDPOINT)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + let response = reqwest::Client::default() + .get(format!("{}/v1/target/test/prompt", HTTP_ENDPOINT)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + let response = reqwest::Client::default() + .put(format!("{}/v1/target/test/prompt/test", HTTP_ENDPOINT)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + let response = reqwest::Client::default() + .delete(format!("{}/v1/target/test/prompt/test", HTTP_ENDPOINT)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + } + + // test post + let prompt = { + let response = client + .post(format!("{}/v1/target/{}/prompt", HTTP_ENDPOINT, &target.id)) + .json(&serde_json::json!({ + "title": "title" + })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::CREATED, response.status()); + + let data = response.json::().await; + assert!(data.is_ok()); + + data.unwrap() + }; + + // test get + { + let response = client + .get(format!("{}/v1/target/{}/prompt", HTTP_ENDPOINT, &target.id)) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let data = response.json::>().await.unwrap(); + assert_eq!(1, data.records.len()); + assert_eq!(&prompt, data.records.first().unwrap()); + } + + // test put + { + let response = client + .put(format!( + "{}/v1/target/{}/prompt/{}", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .json(&serde_json::json!({ + "active": false + })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let response = client + .get(format!("{}/v1/target/{}/prompt", HTTP_ENDPOINT, &target.id)) + .send() + .await + .unwrap(); + assert_eq!( + false, + response + .json::>() + .await + .unwrap() + .records + .first() + .unwrap() + .active + ); + } + + // test delete + { + let response = client + .delete(format!( + "{}/v1/target/{}/prompt/{}", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let response = client + .get(format!("{}/v1/target/{}/prompt", HTTP_ENDPOINT, &target.id)) + .send() + .await + .unwrap(); + assert_eq!( + 0, + response + .json::>() + .await + .unwrap() + .records + .len() + ); + } +} + +#[test(tokio::test)] +async fn test_prompt_field_endpoints() { + let _server = run_server(); + let client = client().await; + + // prepare dependencies + let (target, prompt) = { + let target = { + let response = client + .post(format!("{}/v1/target", HTTP_ENDPOINT)) + .json(&serde_json::json!({ + "name": "Name" + })) + .send() + .await + .unwrap(); + response.json::().await.unwrap() + }; + + let prompt = { + let response = client + .post(format!("{}/v1/target/{}/prompt", HTTP_ENDPOINT, &target.id)) + .json(&serde_json::json!({ + "title": "title" + })) + .send() + .await + .unwrap(); + response.json::().await.unwrap() + }; + + (target, prompt) + }; + + // test auth + { + let response = reqwest::Client::default() + .post(format!( + "{}/v1/target/test/prompt/test/field", + HTTP_ENDPOINT + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + let response = reqwest::Client::default() + .get(format!( + "{}/v1/target/test/prompt/test/field", + HTTP_ENDPOINT + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + let response = reqwest::Client::default() + .put(format!( + "{}/v1/target/test/prompt/test/field/test", + HTTP_ENDPOINT + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + + let response = reqwest::Client::default() + .delete(format!( + "{}/v1/target/test/prompt/test/field/test", + HTTP_ENDPOINT + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + } + + // test post + let field = { + // test wrong type + let response = client + .post(format!( + "{}/v1/target/{}/prompt/{}/field", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .json(&serde_json::json!({ + "title": "test", + "type": "text", + "options": {"max": 5, "description": "hell yea"} + })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::BAD_REQUEST, response.status()); + + // test insert + let response = client + .post(format!( + "{}/v1/target/{}/prompt/{}/field", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .json(&serde_json::json!({ + "title": "Test", + "type": "text", + "options": { + "placeholder": "placeholder", + "description": "description", + } + })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::CREATED, response.status()); + + let field = response.json::().await.unwrap(); + assert_eq!(&prompt.id, &field.prompt); + field + }; + + // test get + { + let response = client + .get(format!( + "{}/v1/target/{}/prompt/{}/field", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let data = response.json::>().await.unwrap(); + assert_eq!(1, data.total); + assert_eq!(&&field, data.records.first().as_ref().unwrap()); + } + + // test put + { + // test put invalid options + let response = client + .put(format!( + "{}/v1/target/{}/prompt/{}/field/{}", + HTTP_ENDPOINT, &target.id, &prompt.id, &field.id + )) + .json(&serde_json::json!({ + "options": {"max": 5, "description": "test"} + })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::BAD_REQUEST, response.status()); + + // test put title + let response = client + .put(format!( + "{}/v1/target/{}/prompt/{}/field/{}", + HTTP_ENDPOINT, &target.id, &prompt.id, &field.id + )) + .json(&serde_json::json!({ "title": "Updated" })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let response = client + .get(format!( + "{}/v1/target/{}/prompt/{}/field", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .send() + .await + .unwrap(); + let data = response.json::>().await.unwrap(); + assert_eq!("Updated", data.records.first().unwrap().title.as_str()); + } + + // test delete + { + let response = client + .delete(format!( + "{}/v1/target/{}/prompt/{}/field/{}", + HTTP_ENDPOINT, &target.id, &prompt.id, &field.id + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + + let response = client + .get(format!( + "{}/v1/target/{}/prompt/{}/field", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .send() + .await + .unwrap(); + let page = response.json::>().await.unwrap(); + assert_eq!(0, page.total); + } +} + +#[test(tokio::test)] +async fn test_response_endpoints() { + let _server = run_server(); + let client = client().await; + + let (target, prompt) = { + let target = { + let response = client + .post(format!("{}/v1/target", HTTP_ENDPOINT)) + .json(&serde_json::json!({ + "name": "Name" + })) + .send() + .await + .unwrap(); + response.json::().await.unwrap() + }; + + let prompt = { + let response = client + .post(format!("{}/v1/target/{}/prompt", HTTP_ENDPOINT, &target.id)) + .json(&serde_json::json!({ + "title": "title" + })) + .send() + .await + .unwrap(); + response.json::().await.unwrap() + }; + + (target, prompt) + }; + + // test auth + { + let response = reqwest::Client::default() + .post(format!( + "{}/v1/target/{}/prompt/{}/response", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNSUPPORTED_MEDIA_TYPE, response.status()); + + let response = reqwest::Client::default() + .get(format!( + "{}/v1/target/{}/prompt/{}/response", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::UNAUTHORIZED, response.status()); + } + + // test post response on empty prompt and non existing prompt + { + let response = client + .post(format!( + "{}/v1/target/{}/prompt/test/response", + HTTP_ENDPOINT, &target.id + )) + .json(&serde_json::json!({ "responses": {} })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::BAD_REQUEST, response.status()); + + let response = client + .post(format!( + "{}/v1/target/{}/prompt/{}/response", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .json(&serde_json::json!({ "responses": {} })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::BAD_REQUEST, response.status()); + } + + // create testing fields + let (text_field, rating_field) = { + let response = client + .post(format!( + "{}/v1/target/{}/prompt/{}/field", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .json(&serde_json::json!({ + "title": "Test", + "type": "text", + "options": { + "placeholder": "placeholder", + "description": "description", + } + })) + .send() + .await + .unwrap(); + let text_field = response.json::().await.unwrap(); + + let response = client + .post(format!( + "{}/v1/target/{}/prompt/{}/field", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .json(&serde_json::json!({ + "title": "Test", + "type": "rating", + "options": { + "max": 10, + "description": "description", + } + })) + .send() + .await + .unwrap(); + let rating_field = response.json::().await.unwrap(); + + (text_field, rating_field) + }; + + // test post response + { + let response = client + .post(format!( + "{}/v1/target/{}/prompt/{}/response", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .json(&serde_json::json!({ + "responses": { + &text_field.id: {"data": "Yea"}, + &rating_field.id: {"data": 5} + } + })) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::CREATED, response.status()); + } + + // test get + { + let response = client + .get(format!( + "{}/v1/target/{}/prompt/{}/response", + HTTP_ENDPOINT, &target.id, &prompt.id + )) + .send() + .await + .unwrap(); + assert_eq!(StatusCode::OK, response.status()); + } +}