Skip to content

Commit

Permalink
Merge pull request #62 from OneLiteFeatherNET/feat/configfile
Browse files Browse the repository at this point in the history
feat: add config file format and watcher
  • Loading branch information
Randoooom committed Jun 2, 2024
2 parents 513c572 + 7e8b0ab commit 9e1ec9d
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 10 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ async-trait = "0.1.80"
chrono = { version = "0.4.38", features = ["serde"] }
derivative = "2.2.0"
envy = "0.4.2"
futures = "0.3.30"
getset = "0.1.2"
itertools = "0.12.1"
kanal = "0.1.0-pre8"
lazy_static = "1.4.0"
nanoid = "0.4.0"
notify = { version = "6.1.1", default-features = false, features = [
"macos_kqueue",
] }
openidconnect = "3.5.0"
paste = "1.0.14"
prost-types = "0.12.4"
Expand All @@ -34,6 +38,7 @@ serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.115"
serde-this-or-that = { git = "https://github.com/Randoooom/serde-this-or-that.git", branch = "feature/bool-signed" }
thiserror = "1.0.58"
toml = "0.8.13"
tonic = "0.11.0"
tonic-health = "0.11.0"
tonic-reflection = "0.11.0"
Expand Down
160 changes: 159 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
*
*/

use crate::prelude::*;
use rbatis::executor::Executor;

use crate::{
database::schema::feedback::{Field, FieldOptions, FieldType, Prompt, Target},
prelude::*,
};

lazy_static! {
pub static ref CONFIG: Config = envy::from_env::<Config>().unwrap();
Expand All @@ -37,6 +42,7 @@ pub struct Config {
oidc_discovery_url: String,
#[serde(default = "default_oidc_audience")]
oidc_audience: String,
config_path: Option<String>,
}

#[inline]
Expand All @@ -48,3 +54,155 @@ fn default_global_rate_limit() -> u64 {
fn default_oidc_audience() -> String {
"feedback-fusion".to_owned()
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct InstanceConfig {
targets: Vec<TargetConfig>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct TargetConfig {
id: String,
name: String,
description: Option<String>,
prompts: Option<Vec<PromptConfig>>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct PromptConfig {
id: String,
title: String,
description: String,
#[serde(default)]
active: bool,
fields: Option<Vec<FieldConfig>>,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct FieldConfig {
id: String,
title: String,
description: Option<String>,
field_type: FieldType,
options: FieldOptions,
}

#[instrument(skip_all)]
pub async fn sync_config(connection: &DatabaseConnection) -> Result<()> {
// as this function can only be called when the notify watch agent got created we can use
// unwrap here
let config_path = CONFIG.config_path.as_ref().unwrap();
let content = tokio::fs::read_to_string(config_path)
.await
.map_err(|error| {
FeedbackFusionError::ConfigurationError(format!(
"Error while reading config: {}",
error
))
})?;

// parse the config
let config: InstanceConfig = toml::from_str(content.as_str()).map_err(|error| {
FeedbackFusionError::ConfigurationError(format!("Error while reading config: {}", error))
})?;
info!("Sucessfully parsed config");

// start a database transaction
let transaction = connection.acquire_begin().await?;
let transaction = transaction.defer_async(|mut tx| async move {
if !tx.done {
let _ = tx.rollback().await;
}
});

for target in config.targets.into_iter() {
sync_target(target, &transaction).await?;
}

Ok(())
}

macro_rules! update_otherwise_create {
($transaction:ident, $path:path, $data:ident, $($field:ident $(,)?)* => $target:ident = $source:ident) => {
paste!{
let result = $path::select_by_id($transaction, $data.id.as_str()).await?;

if let Some(mut data) = result {
$(
data.[<set_ $field>]($data.$field);
)*

$path::update_by_column($transaction, &data, "id").await?;
} else {
$path::insert(
$transaction,
&$path::builder()
.id($data.id)
$(
.$field($data.$field)
)*
.$target($source)
.build()
).await?;
}
}
};
($transaction:ident, $path:path, $data:ident, $($field:ident $(,)?)*) => {
paste!{
let result = $path::select_by_id($transaction, $data.id.as_str()).await?;

if let Some(mut data) = result {
$(
data.[<set_ $field>]($data.$field);
)*

$path::update_by_column($transaction, &data, "id").await?;
} else {
$path::insert(
$transaction,
&$path::builder()
.id($data.id)
$(
.$field($data.$field)
)*
.build()
).await?;
}
}
};
}

#[instrument(skip_all)]
pub async fn sync_target(target: TargetConfig, transaction: &dyn Executor) -> Result<()> {
let id = target.id.clone();
update_otherwise_create!(transaction, Target, target, name, description);

if let Some(prompts) = target.prompts {
for prompt in prompts.into_iter() {
sync_prompt(prompt, transaction, id.clone()).await?;
}
}

Ok(())
}

#[instrument(skip_all)]
pub async fn sync_prompt(prompt: PromptConfig, transaction: &dyn Executor, target: String) -> Result<()> {
let id = prompt.id.clone();
update_otherwise_create!(transaction, Prompt, prompt, title, description, active => target = target);

if let Some(fields) = prompt.fields {
for field in fields.into_iter() {
sync_field(field, transaction, id.clone()).await?;
}
}

Ok(())
}

#[instrument(skip_all)]
pub async fn sync_field(field: FieldConfig, transaction: &dyn Executor, prompt: String) -> Result<()> {
update_otherwise_create!(transaction, Field, field, title, description, field_type, options => prompt = prompt);

Ok(())
}
2 changes: 1 addition & 1 deletion src/database/mssql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ BEGIN
title VARCHAR(32) NOT NULL,
description VARCHAR(255),
prompt VARCHAR(32) NOT NULL REFERENCES prompt(id),
type VARCHAR(32) NOT NULL,
field_type VARCHAR(32) NOT NULL,
options NVARCHAR(MAX) NOT NULL,
updated_at DATETIME,
created_at DATETIME
Expand Down
2 changes: 1 addition & 1 deletion src/database/mysql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ CREATE TABLE IF NOT EXISTS field (
title VARCHAR(32) NOT NULL,
description VARCHAR(255),
prompt VARCHAR(32) NOT NULL,
type VARCHAR(32) NOT NULL,
field_type VARCHAR(32) NOT NULL,
options TEXT NOT NULL,
updated_at TIMESTAMP(3),
created_at TIMESTAMP(3),
Expand Down
2 changes: 1 addition & 1 deletion src/database/postgres.sql
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ CREATE TABLE IF NOT EXISTS field (
title VARCHAR(32) NOT NULL,
description VARCHAR(255),
prompt VARCHAR(32) REFERENCES prompt(id) NOT NULL,
type VARCHAR(32) NOT NULL,
field_type VARCHAR(32) NOT NULL,
options TEXT NOT NULL,
updated_at TIMESTAMP,
created_at TIMESTAMP
Expand Down
6 changes: 3 additions & 3 deletions src/database/schema/feedback/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ pub struct Field {
#[validate(length(max = 255))]
description: Option<String>,
prompt: String,
r#type: FieldType,
field_type: FieldType,
#[serde(serialize_with = "serialize_options", deserialize_with = "deserialize_options")]
options: FieldOptions,
#[builder(default_code = r#"DateTime::utc()"#)]
Expand All @@ -184,7 +184,7 @@ impl From<Field> for feedback_fusion_common::proto::Field {
title: val.title,
description: val.description,
prompt: val.prompt,
field_type: val.r#type.into(),
field_type: val.field_type.into(),
options: Some(val.options.into()),
updated_at: Some(date_time_to_timestamp(val.updated_at)),
created_at: Some(date_time_to_timestamp(val.created_at)),
Expand All @@ -201,7 +201,7 @@ impl TryInto<Field> for feedback_fusion_common::proto::Field {
title: self.title,
description: self.description,
prompt: self.prompt,
r#type: self.field_type.try_into()?,
field_type: self.field_type.try_into()?,
options: self.options.unwrap().try_into()?,
updated_at: to_date_time!(self.updated_at),
created_at: to_date_time!(self.created_at),
Expand Down
58 changes: 56 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
*/
#![allow(clippy::too_many_arguments)]

use std::{path::PathBuf, str::FromStr};

use crate::{
prelude::*,
services::v1::{FeedbackFusionV1Context, PublicFeedbackFusionV1Context},
Expand All @@ -30,6 +32,8 @@ use feedback_fusion_common::proto::{
feedback_fusion_v1_server::FeedbackFusionV1Server,
public_feedback_fusion_v1_server::PublicFeedbackFusionV1Server,
};
use futures::stream::StreamExt;
use notify::{RecommendedWatcher, Watcher};
use tonic::transport::Server;
use tonic_web::GrpcWebLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
Expand All @@ -53,12 +57,62 @@ async fn main() {
lazy_static::initialize(&CONFIG);
lazy_static::initialize(&DATABASE_CONFIG);

let (sender, receiver) = kanal::oneshot_async::<()>();

// connect to the database
let connection = DATABASE_CONFIG.connect().await.unwrap();
let connection = DatabaseConnection::from(connection);

// start config file watcher
if let Some(config_path) = CONFIG.config_path().as_ref() {
let connection = connection.clone();
info!("CONFIG_PATH present, starting watcher");
// initial load
match config::sync_config(&connection).await {
Ok(_) => info!("Config reloaded"),
Err(error) => error!("Error occurred while syncinc config: {error}"),
};

tokio::spawn(async move {
let (sender, receiver) = kanal::bounded_async(1);

let mut watcher = RecommendedWatcher::new(
move |response| {
let sender = sender.clone();
tokio::spawn(async move { sender.send(response).await.unwrap() });
},
notify::Config::default(),
)
.unwrap();

watcher
.watch(
&PathBuf::from_str(config_path.as_str()).unwrap(),
notify::RecursiveMode::NonRecursive,
)
.unwrap();
info!("Watching for changes at {config_path}");

let mut stream = receiver.stream();
while let Some(response) = stream.next().await {
match response {
Ok(_) => {
let span = info_span!("ConfigReload");
let _ = span.enter();

match config::sync_config(&connection).await {
Ok(_) => info!("Config reloaded"),
Err(error) => error!("Error occurred while syncinc config: {error}"),
}
}
Err(error) => error!("Error occurred during watch: {error}"),
}
}

Ok::<(), FeedbackFusionError>(())
});
}

let (sender, receiver) = kanal::oneshot_async::<()>();

tokio::spawn(async move {
let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
health_reporter
Expand Down
2 changes: 1 addition & 1 deletion src/services/v1/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub async fn create_field(

// build the field
let field = Field::builder()
.r#type(Into::<FieldType>::into(data.field_type()))
.field_type(Into::<FieldType>::into(data.field_type()))
.title(data.title)
.description(data.description)
.options(TryInto::<FieldOptions>::try_into(data.options.unwrap())?)
Expand Down

0 comments on commit 9e1ec9d

Please sign in to comment.