diff --git a/Cargo.toml b/Cargo.toml index 93561fd1..a3045a5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,11 +17,15 @@ async-trait = "0.1.80" chrono = { version = "0.4.38", features = ["serde"] } derivative = "2.2.0" envy = "0.4.2" +futures = "0.3.30" getset = "0.1.2" itertools = "0.12.1" kanal = "0.1.0-pre8" lazy_static = "1.4.0" nanoid = "0.4.0" +notify = { version = "6.1.1", default-features = false, features = [ + "macos_kqueue", +] } openidconnect = "3.5.0" paste = "1.0.14" prost-types = "0.12.4" @@ -34,6 +38,7 @@ serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.115" serde-this-or-that = { git = "https://github.com/Randoooom/serde-this-or-that.git", branch = "feature/bool-signed" } thiserror = "1.0.58" +toml = "0.8.13" tonic = "0.11.0" tonic-health = "0.11.0" tonic-reflection = "0.11.0" diff --git a/src/config.rs b/src/config.rs index fbb356d8..56649a40 100644 --- a/src/config.rs +++ b/src/config.rs @@ -21,7 +21,12 @@ * */ -use crate::prelude::*; +use rbatis::executor::Executor; + +use crate::{ + database::schema::feedback::{Field, FieldOptions, FieldType, Prompt, Target}, + prelude::*, +}; lazy_static! { pub static ref CONFIG: Config = envy::from_env::().unwrap(); @@ -37,6 +42,7 @@ pub struct Config { oidc_discovery_url: String, #[serde(default = "default_oidc_audience")] oidc_audience: String, + config_path: Option, } #[inline] @@ -48,3 +54,155 @@ fn default_global_rate_limit() -> u64 { fn default_oidc_audience() -> String { "feedback-fusion".to_owned() } + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct InstanceConfig { + targets: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct TargetConfig { + id: String, + name: String, + description: Option, + prompts: Option>, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct PromptConfig { + id: String, + title: String, + description: String, + #[serde(default)] + active: bool, + fields: Option>, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct FieldConfig { + id: String, + title: String, + description: Option, + field_type: FieldType, + options: FieldOptions, +} + +#[instrument(skip_all)] +pub async fn sync_config(connection: &DatabaseConnection) -> Result<()> { + // as this function can only be called when the notify watch agent got created we can use + // unwrap here + let config_path = CONFIG.config_path.as_ref().unwrap(); + let content = tokio::fs::read_to_string(config_path) + .await + .map_err(|error| { + FeedbackFusionError::ConfigurationError(format!( + "Error while reading config: {}", + error + )) + })?; + + // parse the config + let config: InstanceConfig = toml::from_str(content.as_str()).map_err(|error| { + FeedbackFusionError::ConfigurationError(format!("Error while reading config: {}", error)) + })?; + info!("Sucessfully parsed config"); + + // start a database transaction + let transaction = connection.acquire_begin().await?; + let transaction = transaction.defer_async(|mut tx| async move { + if !tx.done { + let _ = tx.rollback().await; + } + }); + + for target in config.targets.into_iter() { + sync_target(target, &transaction).await?; + } + + Ok(()) +} + +macro_rules! update_otherwise_create { + ($transaction:ident, $path:path, $data:ident, $($field:ident $(,)?)* => $target:ident = $source:ident) => { + paste!{ + let result = $path::select_by_id($transaction, $data.id.as_str()).await?; + + if let Some(mut data) = result { + $( + data.[]($data.$field); + )* + + $path::update_by_column($transaction, &data, "id").await?; + } else { + $path::insert( + $transaction, + &$path::builder() + .id($data.id) + $( + .$field($data.$field) + )* + .$target($source) + .build() + ).await?; + } + } + }; + ($transaction:ident, $path:path, $data:ident, $($field:ident $(,)?)*) => { + paste!{ + let result = $path::select_by_id($transaction, $data.id.as_str()).await?; + + if let Some(mut data) = result { + $( + data.[]($data.$field); + )* + + $path::update_by_column($transaction, &data, "id").await?; + } else { + $path::insert( + $transaction, + &$path::builder() + .id($data.id) + $( + .$field($data.$field) + )* + .build() + ).await?; + } + } + }; +} + +#[instrument(skip_all)] +pub async fn sync_target(target: TargetConfig, transaction: &dyn Executor) -> Result<()> { + let id = target.id.clone(); + update_otherwise_create!(transaction, Target, target, name, description); + + if let Some(prompts) = target.prompts { + for prompt in prompts.into_iter() { + sync_prompt(prompt, transaction, id.clone()).await?; + } + } + + Ok(()) +} + +#[instrument(skip_all)] +pub async fn sync_prompt(prompt: PromptConfig, transaction: &dyn Executor, target: String) -> Result<()> { + let id = prompt.id.clone(); + update_otherwise_create!(transaction, Prompt, prompt, title, description, active => target = target); + + if let Some(fields) = prompt.fields { + for field in fields.into_iter() { + sync_field(field, transaction, id.clone()).await?; + } + } + + Ok(()) +} + +#[instrument(skip_all)] +pub async fn sync_field(field: FieldConfig, transaction: &dyn Executor, prompt: String) -> Result<()> { + update_otherwise_create!(transaction, Field, field, title, description, field_type, options => prompt = prompt); + + Ok(()) +} diff --git a/src/database/mssql.sql b/src/database/mssql.sql index 700bae8e..27ed9c50 100644 --- a/src/database/mssql.sql +++ b/src/database/mssql.sql @@ -30,7 +30,7 @@ BEGIN title VARCHAR(32) NOT NULL, description VARCHAR(255), prompt VARCHAR(32) NOT NULL REFERENCES prompt(id), - type VARCHAR(32) NOT NULL, + field_type VARCHAR(32) NOT NULL, options NVARCHAR(MAX) NOT NULL, updated_at DATETIME, created_at DATETIME diff --git a/src/database/mysql.sql b/src/database/mysql.sql index c29e05fd..0815f6e8 100644 --- a/src/database/mysql.sql +++ b/src/database/mysql.sql @@ -22,7 +22,7 @@ CREATE TABLE IF NOT EXISTS field ( title VARCHAR(32) NOT NULL, description VARCHAR(255), prompt VARCHAR(32) NOT NULL, - type VARCHAR(32) NOT NULL, + field_type VARCHAR(32) NOT NULL, options TEXT NOT NULL, updated_at TIMESTAMP(3), created_at TIMESTAMP(3), diff --git a/src/database/postgres.sql b/src/database/postgres.sql index 8ac5c762..92e1803e 100644 --- a/src/database/postgres.sql +++ b/src/database/postgres.sql @@ -21,7 +21,7 @@ CREATE TABLE IF NOT EXISTS field ( title VARCHAR(32) NOT NULL, description VARCHAR(255), prompt VARCHAR(32) REFERENCES prompt(id) NOT NULL, - type VARCHAR(32) NOT NULL, + field_type VARCHAR(32) NOT NULL, options TEXT NOT NULL, updated_at TIMESTAMP, created_at TIMESTAMP diff --git a/src/database/schema/feedback/prompt.rs b/src/database/schema/feedback/prompt.rs index 44291523..0b5421f6 100644 --- a/src/database/schema/feedback/prompt.rs +++ b/src/database/schema/feedback/prompt.rs @@ -166,7 +166,7 @@ pub struct Field { #[validate(length(max = 255))] description: Option, prompt: String, - r#type: FieldType, + field_type: FieldType, #[serde(serialize_with = "serialize_options", deserialize_with = "deserialize_options")] options: FieldOptions, #[builder(default_code = r#"DateTime::utc()"#)] @@ -184,7 +184,7 @@ impl From for feedback_fusion_common::proto::Field { title: val.title, description: val.description, prompt: val.prompt, - field_type: val.r#type.into(), + field_type: val.field_type.into(), options: Some(val.options.into()), updated_at: Some(date_time_to_timestamp(val.updated_at)), created_at: Some(date_time_to_timestamp(val.created_at)), @@ -201,7 +201,7 @@ impl TryInto for feedback_fusion_common::proto::Field { title: self.title, description: self.description, prompt: self.prompt, - r#type: self.field_type.try_into()?, + field_type: self.field_type.try_into()?, options: self.options.unwrap().try_into()?, updated_at: to_date_time!(self.updated_at), created_at: to_date_time!(self.created_at), diff --git a/src/main.rs b/src/main.rs index efaa6ffd..5125bba6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,8 @@ */ #![allow(clippy::too_many_arguments)] +use std::{path::PathBuf, str::FromStr}; + use crate::{ prelude::*, services::v1::{FeedbackFusionV1Context, PublicFeedbackFusionV1Context}, @@ -30,6 +32,8 @@ use feedback_fusion_common::proto::{ feedback_fusion_v1_server::FeedbackFusionV1Server, public_feedback_fusion_v1_server::PublicFeedbackFusionV1Server, }; +use futures::stream::StreamExt; +use notify::{RecommendedWatcher, Watcher}; use tonic::transport::Server; use tonic_web::GrpcWebLayer; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -53,12 +57,62 @@ async fn main() { lazy_static::initialize(&CONFIG); lazy_static::initialize(&DATABASE_CONFIG); - let (sender, receiver) = kanal::oneshot_async::<()>(); - // connect to the database let connection = DATABASE_CONFIG.connect().await.unwrap(); let connection = DatabaseConnection::from(connection); + // start config file watcher + if let Some(config_path) = CONFIG.config_path().as_ref() { + let connection = connection.clone(); + info!("CONFIG_PATH present, starting watcher"); + // initial load + match config::sync_config(&connection).await { + Ok(_) => info!("Config reloaded"), + Err(error) => error!("Error occurred while syncinc config: {error}"), + }; + + tokio::spawn(async move { + let (sender, receiver) = kanal::bounded_async(1); + + let mut watcher = RecommendedWatcher::new( + move |response| { + let sender = sender.clone(); + tokio::spawn(async move { sender.send(response).await.unwrap() }); + }, + notify::Config::default(), + ) + .unwrap(); + + watcher + .watch( + &PathBuf::from_str(config_path.as_str()).unwrap(), + notify::RecursiveMode::NonRecursive, + ) + .unwrap(); + info!("Watching for changes at {config_path}"); + + let mut stream = receiver.stream(); + while let Some(response) = stream.next().await { + match response { + Ok(_) => { + let span = info_span!("ConfigReload"); + let _ = span.enter(); + + match config::sync_config(&connection).await { + Ok(_) => info!("Config reloaded"), + Err(error) => error!("Error occurred while syncinc config: {error}"), + } + } + Err(error) => error!("Error occurred during watch: {error}"), + } + } + + Ok::<(), FeedbackFusionError>(()) + }); + } + + let (sender, receiver) = kanal::oneshot_async::<()>(); + tokio::spawn(async move { let (mut health_reporter, health_service) = tonic_health::server::health_reporter(); health_reporter diff --git a/src/services/v1/field.rs b/src/services/v1/field.rs index 5ee889a4..170d3dd1 100644 --- a/src/services/v1/field.rs +++ b/src/services/v1/field.rs @@ -40,7 +40,7 @@ pub async fn create_field( // build the field let field = Field::builder() - .r#type(Into::::into(data.field_type())) + .field_type(Into::::into(data.field_type())) .title(data.title) .description(data.description) .options(TryInto::::try_into(data.options.unwrap())?)