diff --git a/Cargo.lock b/Cargo.lock index 56369f17..bb89e55d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -868,6 +868,8 @@ dependencies = [ "clap", "config_derive", "convert_case", + "serde", + "serde-value", "serde_json", "serde_yaml", "thiserror", @@ -2482,6 +2484,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "ordered-float" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7940cf2ca942593318d07fcf2596cdca60a85c9e7fab408a5e21a4f9dcd40d87" +dependencies = [ + "num-traits", +] + [[package]] name = "overload" version = "0.1.1" @@ -3301,6 +3312,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-value" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" +dependencies = [ + "ordered-float", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.166" diff --git a/backend/api/src/config.rs b/backend/api/src/config.rs index 7a9928c3..25e44e95 100644 --- a/backend/api/src/config.rs +++ b/backend/api/src/config.rs @@ -2,9 +2,10 @@ use std::net::SocketAddr; use anyhow::Result; use common::config::{LoggingConfig, RedisConfig, RmqConfig, TlsConfig}; +use config::KeyTree; -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] /// The API is the backend for the Scuffle service pub struct AppConfig { /// The path to the config file @@ -38,11 +39,11 @@ pub struct AppConfig { pub redis: RedisConfig, } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct ApiConfig { /// Bind address for the API - #[config(from_str)] + #[config(tree = "KeyTree::String")] pub bind_address: SocketAddr, /// If we should use TLS for the API server @@ -58,8 +59,8 @@ impl Default for ApiConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct DatabaseConfig { /// The database URL to use pub uri: String, @@ -73,8 +74,8 @@ impl Default for DatabaseConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct TurnstileConfig { /// The Cloudflare Turnstile site key to use pub secret_key: String, @@ -92,8 +93,8 @@ impl Default for TurnstileConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct JwtConfig { /// JWT secret pub secret: String, @@ -111,11 +112,11 @@ impl Default for JwtConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct GrpcConfig { /// Bind address for the GRPC server - #[config(from_str)] + #[config(tree = "KeyTree::String")] pub bind_address: SocketAddr, /// If we should use TLS for the gRPC server diff --git a/common/Cargo.toml b/common/Cargo.toml index 4b34ab1f..46901f14 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -6,14 +6,14 @@ authors = ["Scuffle "] description = "Scuffle Common Library" [features] -logging = ["dep:log", "dep:tracing", "dep:tracing-log", "dep:tracing-subscriber", "dep:arc-swap", "dep:anyhow", "dep:once_cell", "dep:thiserror"] +logging = ["dep:log", "dep:tracing", "dep:tracing-log", "dep:tracing-subscriber", "dep:arc-swap", "dep:anyhow", "dep:once_cell", "dep:thiserror", "dep:serde"] rmq = ["dep:lapin", "dep:arc-swap", "dep:anyhow", "dep:futures", "dep:tracing", "dep:tokio", "dep:async-stream", "prelude"] grpc = ["dep:tonic", "dep:anyhow", "dep:async-trait", "dep:futures", "dep:http", "dep:tower", "dep:trust-dns-resolver", "dep:tracing"] context = ["dep:tokio", "dep:tokio-util"] prelude = ["dep:tokio"] signal = [] macros = [] -config = ["dep:config", "logging"] +config = ["dep:config", "dep:serde", "logging"] default = ["logging", "rmq", "grpc", "context", "prelude", "signal", "macros", "config"] diff --git a/common/src/config.rs b/common/src/config.rs index df863a17..52cbd544 100644 --- a/common/src/config.rs +++ b/common/src/config.rs @@ -1,12 +1,15 @@ use anyhow::Result; use crate::logging; +use config::KeyTree; -#[derive(Debug, Clone, Default, PartialEq, config::Config)] +#[derive( + Debug, Clone, Default, PartialEq, config::Config, serde::Deserialize, serde::Serialize, +)] +#[serde(default)] pub struct TlsConfig { /// Domain name to use for TLS /// Only used for gRPC TLS connections - #[config(default)] pub domain: Option, /// The path to the TLS certificate @@ -19,14 +22,14 @@ pub struct TlsConfig { pub ca_cert: String, } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct LoggingConfig { /// The log level to use, this is a tracing env filter pub level: String, /// What logging mode we should use - #[config(from_str)] + #[config(tree = "KeyTree::String")] pub mode: logging::Mode, } @@ -39,8 +42,8 @@ impl Default for LoggingConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct RedisConfig { /// The address of the Redis server pub addresses: Vec, @@ -64,15 +67,15 @@ pub struct RedisConfig { pub sentinel: Option, } -#[derive(Debug, Clone, PartialEq, config::Config)] -// #[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct RedisSentinelConfig { /// The master group name pub service_name: String, } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct RmqConfig { /// The URI to use for connecting to RabbitMQ pub uri: String, @@ -115,12 +118,12 @@ pub fn parse( let mut builder = config::ConfigBuilder::new(); if enable_cli { - builder.add_source_with_priority(config::sources::CliSource::new(), 3); + builder.add_source_with_priority(config::sources::CliSource::new()?, 3); } - builder.add_source_with_priority(config::sources::EnvSource::new("SCUF", "_"), 2); + builder.add_source_with_priority(config::sources::EnvSource::with_prefix("SCUF")?, 2); - let key = builder.try_parse_key::("config_file")?; + let key = builder.parse_key::>("config_file")?; let key_provided = key.is_some(); diff --git a/common/src/logging.rs b/common/src/logging.rs index b6560391..f146da55 100644 --- a/common/src/logging.rs +++ b/common/src/logging.rs @@ -5,7 +5,7 @@ use tracing_subscriber::{prelude::*, reload::Handle, EnvFilter}; static RELOAD_HANDLE: OnceCell> = OnceCell::new(); -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub enum Mode { Default, Json, diff --git a/config/config/Cargo.toml b/config/config/Cargo.toml index b2829288..ab59e5a4 100644 --- a/config/config/Cargo.toml +++ b/config/config/Cargo.toml @@ -14,7 +14,8 @@ serde_yaml = "0" toml = "0" clap = { version = "4", features = ["cargo", "string"] } convert_case = "0" - +serde = { version = "1", features = ["derive"] } +serde-value = "0" tracing = { version = "0" } # Derive macro diff --git a/config/config/example/derive.rs b/config/config/example/derive.rs index a02aefcf..98862cc6 100644 --- a/config/config/example/derive.rs +++ b/config/config/example/derive.rs @@ -1,36 +1,20 @@ //! Run with: `cargo run --example derive` //! Look at the generated code with: `cargo expand --example derive` -use std::net::SocketAddr; -use std::str::FromStr; - use config::{sources, ConfigBuilder}; type TypeAlias = bool; -#[derive(config::Config, Debug, PartialEq)] -#[config(default)] +#[derive(config::Config, Debug, PartialEq, serde::Deserialize, serde::Serialize, Default)] +#[serde(default)] struct AppConfig { enabled: TypeAlias, logging: LoggingConfig, - optional: Option<()>, - #[config(from_str = "SocketAddr::from_str")] - bind_address: Option, -} - -impl Default for AppConfig { - fn default() -> Self { - Self { - enabled: false, - logging: LoggingConfig::default(), - optional: None, - bind_address: Some("127.0.0.1:5000".parse().unwrap()), - } - } + optional: Option, } -#[derive(config::Config, Debug, PartialEq)] -#[config(default)] +#[derive(config::Config, Debug, PartialEq, serde::Deserialize, serde::Serialize)] +#[serde(default)] struct LoggingConfig { level: String, json: bool, @@ -47,7 +31,7 @@ impl Default for LoggingConfig { fn main() { let mut builder: ConfigBuilder = ConfigBuilder::new(); - builder.add_source(sources::CliSource::new()); + builder.add_source(sources::CliSource::new().unwrap()); match builder.build() { Ok(config) => println!("{:?}", config), diff --git a/config/config/src/key.rs b/config/config/src/key.rs index d94a0313..44ecba42 100644 --- a/config/config/src/key.rs +++ b/config/config/src/key.rs @@ -1,4 +1,4 @@ -use std::fmt::Display; +use std::{collections::BTreeMap, fmt::Display}; #[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] pub struct KeyPath(Vec); @@ -8,7 +8,11 @@ impl KeyPath { Self::default() } - pub fn child(&self, name: &str) -> Self { + pub fn child(&self) -> Option<&str> { + self.0.last().map(|s| s.as_str()) + } + + pub fn push_child(&self, name: &str) -> Self { let mut path = self.clone(); path.0.push(name.to_string()); path @@ -23,9 +27,16 @@ impl KeyPath { self } - pub fn push_root(mut self, root: &str) -> Self { + pub fn take_child(&mut self) -> Option { + self.0.pop() + } + + pub fn push_root(&mut self, root: &str) { self.0.insert(0, root.to_string()); - self + } + + pub fn get_inner(&self) -> &[String] { + &self.0 } } @@ -62,11 +73,74 @@ impl From<&str> for KeyPath { #[derive(Debug, Clone, PartialEq, Eq)] pub enum KeyType { String, - Integer, - Float, - Boolean, + I8, + I16, + I32, + I64, + U8, + U16, + U32, + U64, + F32, + F64, + Bool, + Unit, + Option(Box), + Map(BTreeMap), + Seq(Box), +} + +#[derive(Debug, Clone)] +pub enum KeyTree { + String, + I8, + I16, + I32, + I64, + U8, + U16, + U32, + U64, + F32, + F64, + Bool, Unit, - Array(Box), + Option(Box), + Map(BTreeMap), + Seq(Box), +} + +impl KeyTree { + pub fn key_type(&self) -> KeyType { + match self { + Self::String => KeyType::String, + Self::I8 => KeyType::I8, + Self::I16 => KeyType::I16, + Self::I32 => KeyType::I32, + Self::I64 => KeyType::I64, + Self::U8 => KeyType::U8, + Self::U16 => KeyType::U16, + Self::U32 => KeyType::U32, + Self::U64 => KeyType::U64, + Self::F32 => KeyType::F32, + Self::F64 => KeyType::F64, + Self::Bool => KeyType::Bool, + Self::Unit => KeyType::Unit, + Self::Option(key) => KeyType::Option(Box::new(key.key_type())), + Self::Map(map) => { + KeyType::Map(map.iter().map(|(k, v)| (k.clone(), v.key_type())).collect()) + } + Self::Seq(key) => KeyType::Seq(Box::new(key.key_type())), + } + } + + pub fn into_option(self) -> Self { + Self::Option(Box::new(self)) + } + + pub fn into_seq(self) -> Self { + Self::Seq(Box::new(self)) + } } /// A key has @@ -74,18 +148,16 @@ pub enum KeyType { /// - type #[derive(Debug, Clone)] pub struct Key { - path: KeyPath, - key_type: KeyType, + tree: KeyTree, skip_cli: bool, skip_env: bool, comment: Option<&'static str>, } impl Key { - pub fn new(path: KeyPath, key_type: KeyType) -> Self { + pub fn new(tree: KeyTree) -> Self { Self { - path, - key_type, + tree, skip_cli: false, skip_env: false, comment: None, @@ -119,11 +191,11 @@ impl Key { self.comment } - pub fn path(&self) -> &KeyPath { - &self.path + pub fn tree(&self) -> &KeyTree { + &self.tree } - pub fn key_type(&self) -> &KeyType { - &self.key_type + pub fn key_type(&self) -> KeyType { + self.tree.key_type() } } diff --git a/config/config/src/lib.rs b/config/config/src/lib.rs index 11101a8b..1a2390e0 100644 --- a/config/config/src/lib.rs +++ b/config/config/src/lib.rs @@ -1,5 +1,6 @@ //! TODO: write docs +use std::collections::{BTreeMap, HashSet}; use std::io; use std::num::{ParseFloatError, ParseIntError}; use std::ops::Deref; @@ -11,12 +12,9 @@ pub use config_derive::Config; mod tests; mod key; -mod primitives; pub mod sources; -mod value; - pub use key::*; -pub use value::*; +use serde_value::Value; pub type Result = std::result::Result; @@ -43,6 +41,8 @@ pub enum ConfigError { NotAPrimitive, #[error("can't call on a primitive")] IsAPrimitive, + #[error("cannot deserialize into a struct: {0}")] + Struct(#[from] serde_value::DeserializerError), } #[derive(Debug, thiserror::Error)] @@ -65,15 +65,17 @@ pub enum ParseError { Float(#[from] ParseFloatError), #[error("{0}")] Bool(#[from] ParseBoolError), + #[error("unsupported type: {1:?} for key: {0}")] + UnsupportedType(KeyPath, KeyType), } pub trait Source { - fn get_key(&self, key: &Key) -> Result>; + fn get_key(&self, path: KeyPath) -> Result>; } /// You don't need to implemented this trait manually. /// Please use the provided derive macro to generate the implementation. -pub trait Config +pub trait Config: Sized + 'static + serde::de::DeserializeOwned + serde::ser::Serialize where Self: Sized, { @@ -82,21 +84,7 @@ where const VERSION: Option<&'static str> = None; const AUTHOR: Option<&'static str> = None; - // Functions for config structs - - /// Builds `Self` from given values. - fn build(values: Option) -> Result; - - /// Returns a `Vec` of all keys. - fn keys(root: KeyPath) -> Result>; - - // Functions for primitive types - - /// Returns the [`KeyType`](KeyType) if this is a primitive type. - fn primitive() -> Option; - - /// Tries to build `Self` from an optional [`Value`](Value). - fn from_value(value: Option) -> Result; + fn tree() -> KeyTree; } struct SourceHolder { @@ -132,6 +120,39 @@ impl Default for ConfigBuilder { } } +fn merge(first: Value, second: Value) -> Value { + let Value::Map(mut first) = first else { + return first; + }; + + let Value::Map(mut second) = second else { + return Value::Map(first); + }; + + let mut merged = BTreeMap::new(); + // Get all unique keys from both maps + let keys = first + .keys() + .chain(second.keys()) + .cloned() + .collect::>(); + for key in keys { + let first = first.remove(&key); + let second = second.remove(&key); + + let value = match (first, second) { + (Some(first), Some(second)) => merge(first, second), + (Some(first), None) => first, + (None, Some(second)) => second, + (None, None) => unreachable!(), + }; + + merged.insert(key, value); + } + + Value::Map(merged) +} + impl ConfigBuilder { pub fn new() -> Self { Self { @@ -159,39 +180,24 @@ impl ConfigBuilder { } /// Get a single key. - pub fn get_key(&self, key: &Key) -> Result> { - for source in &self.sources { - if let Some(value) = source.get_key(key)? { - return Ok(Some(value)); - } - } - - Ok(None) - } - - pub fn try_parse_key(&self, path: impl Into) -> Result> { - match self.parse_key(path) { - Ok(value) => Ok(Some(value)), - Err(ConfigError::MissingKey) => Ok(None), - r => r.map(Some), - } + pub fn get_key(&self, path: impl Into) -> Result> { + let key_path = path.into(); + + Ok(self + .sources + .iter() + .map(|s| s.get_key(key_path.clone())) + .collect::>>()? + .into_iter() + .flatten() + .reduce(merge)) } /// Parse a single key. pub fn parse_key(&self, path: impl Into) -> Result { - if let Some(kt) = T::primitive() { - let value = self.get_key(&Key::new(path.into(), kt))?; - T::from_value(value) - } else { - let mut values: ValueMap = ValueMap::new(); - let keys = T::keys(KeyPath::new())?; - for key in &keys { - if let Some(value) = self.get_key(key)? { - values.insert(key.path().clone(), value); - } - } - T::build(Some(values)) - } + Ok(T::deserialize( + self.get_key(path)?.unwrap_or(Value::Option(None)), + )?) } /// This function iterates all added sources and builds a config for all `C::keys`. @@ -201,34 +207,56 @@ impl ConfigBuilder { } } -pub mod __internal { - use crate::ConfigError; - - pub fn parse( - input: Option<&str>, - func: impl Fn(&str) -> Result, - ) -> Result - where - F: Into, - { - match input { - Some(s) => match func(s) { - Ok(r) => Ok(r.into()), - Err(err) => Err(ConfigError::FromStr(Box::new(err))), - }, - None => { - // None values are tricky - let boxed: Box = Box::new(Option::::None); - - let boxed_type_id = boxed.type_id(); - let o_type_id = std::any::TypeId::of::(); - - if boxed_type_id != o_type_id { - return Err(ConfigError::MissingKey); - } - - Ok(*boxed.downcast::().unwrap()) +macro_rules! impl_config { + ($ty:ty, $kt:expr) => { + impl Config for $ty { + fn tree() -> KeyTree { + $kt } } + }; +} + +impl_config!(bool, KeyTree::Bool); +impl_config!(f32, KeyTree::F32); +impl_config!(f64, KeyTree::F64); +impl_config!(i8, KeyTree::I8); +impl_config!(i16, KeyTree::I16); +impl_config!(i32, KeyTree::I32); +impl_config!(i64, KeyTree::I64); +impl_config!(u8, KeyTree::U8); +impl_config!(u16, KeyTree::U16); +impl_config!(u32, KeyTree::U32); +impl_config!(u64, KeyTree::U64); +impl_config!(String, KeyTree::String); +impl_config!((), KeyTree::Unit); + +impl Config for Option { + fn tree() -> KeyTree { + KeyTree::Option(Box::new(C::tree())) + } +} + +impl Config for Vec { + fn tree() -> KeyTree { + KeyTree::Seq(Box::new(C::tree())) + } +} + +impl Config for isize { + fn tree() -> KeyTree { + #[cfg(target_pointer_width = "32")] + return KeyTree::I32; + #[cfg(target_pointer_width = "64")] + return KeyTree::I64; + } +} + +impl Config for usize { + fn tree() -> KeyTree { + #[cfg(target_pointer_width = "32")] + return KeyTree::U32; + #[cfg(target_pointer_width = "64")] + return KeyTree::U64; } } diff --git a/config/config/src/primitives.rs b/config/config/src/primitives.rs deleted file mode 100644 index ac6ed0af..00000000 --- a/config/config/src/primitives.rs +++ /dev/null @@ -1,129 +0,0 @@ -use crate::{Config, ConfigError, Key, KeyPath, KeyType, Result, Value, ValueMap}; - -macro_rules! impl_primitive { - ($obj:ident, $kt:expr, $value:ident => $fv:block) => { - impl Config for $obj { - fn build(_values: Option) -> Result { - Err(ConfigError::IsAPrimitive) - } - - fn keys(_: KeyPath) -> Result> { - Err(ConfigError::IsAPrimitive) - } - - fn primitive() -> Option { - Some($kt) - } - - fn from_value($value: Option) -> Result { - $fv - } - } - }; - ($($obj:ident),+; $kt:expr, $value:ident => $fv:block) => { - $(impl_primitive!($obj, $kt, $value => $fv);)+ - }; -} - -impl_primitive!(String, KeyType::String, value => { - let value = value.ok_or(ConfigError::MissingKey)?; - match value { - Value::String(s) => Ok(s), - _ => Err(ConfigError::TypeMismatch { path: None, expected: KeyType::String, got: value }), - } -}); - -// TODO: This is probably not what we want -// Since all integers are parsed as i64 it won't always work to just cast them -impl_primitive!(i128, i64, i32, i16, i8, u128, u64, u32, u16, u8, usize; KeyType::Integer, value => { - let value = value.ok_or(ConfigError::MissingKey)?; - match value { - Value::Integer(i) => Ok(i as Self), - _ => Err(ConfigError::TypeMismatch { path: None, expected: KeyType::Integer, got: value }), - } -}); - -impl_primitive!(f64, f32; KeyType::Float, value => { - let value = value.ok_or(ConfigError::MissingKey)?; - match value { - Value::Float(f) => Ok(f as Self), - _ => Err(ConfigError::TypeMismatch { path: None, expected: KeyType::Float, got: value }), - } -}); - -impl_primitive!(bool, KeyType::Boolean, value => { - let value = value.ok_or(ConfigError::MissingKey)?; - match value { - Value::Boolean(b) => Ok(b), - _ => Err(ConfigError::TypeMismatch { path: None, expected: KeyType::Boolean, got: value }), - } -}); - -type Unit = (); - -impl_primitive!(Unit, KeyType::Unit, value => { - let value = value.ok_or(ConfigError::MissingKey)?; - match value { - Value::Unit => Ok(()), - _ => Err(ConfigError::TypeMismatch { path: None, expected: KeyType::Unit, got: value }), - } -}); - -// Currently only supports arrays of primitives -impl Config for Vec { - fn build(_values: Option) -> Result { - Err(ConfigError::IsAPrimitive) - } - - fn keys(root: KeyPath) -> Result> { - T::keys(root) - } - - fn primitive() -> Option { - Some(KeyType::Array(Box::new( - T::primitive().unwrap_or(KeyType::Unit), - ))) - } - - fn from_value(value: Option) -> Result { - let value = value.ok_or(ConfigError::MissingKey)?; - if let Some(kt) = T::primitive() { - // When the contained type is a primitive - match value { - Value::Array(a) => a.into_iter().map(|v| T::from_value(Some(v))).collect(), - _ => Err(ConfigError::TypeMismatch { - path: None, - expected: KeyType::Array(Box::new(kt)), - got: value, - }), - } - } else { - // When the contained type isn't a primitive we can't build the vector - Err(ConfigError::NotAPrimitive) - } - } -} - -impl Config for Option { - fn build(values: Option) -> Result { - if values.is_none() { - return Ok(None); - } - Ok(Some(T::build(values)?)) - } - - fn keys(root: KeyPath) -> Result> { - T::keys(root) - } - - fn primitive() -> Option { - T::primitive() - } - - fn from_value(value: Option) -> Result { - match value { - Some(Value::Null) | None => Ok(None), - Some(value) => Ok(Some(T::from_value(Some(value))?)), - } - } -} diff --git a/config/config/src/sources/cli.rs b/config/config/src/sources/cli.rs index c25d3d95..b3381632 100644 --- a/config/config/src/sources/cli.rs +++ b/config/config/src/sources/cli.rs @@ -1,12 +1,162 @@ use std::marker::PhantomData; use clap::{command, Arg, ArgAction, ArgMatches, Command}; +use convert_case::{Case, Casing}; -use crate::{Config, Key, KeyPath, KeyType, Result, Source, Value}; +use crate::{Config, KeyPath, KeyTree, ParseError, Result, Source, Value}; -use convert_case::{Case, Casing}; +use super::utils; + +fn extend_cmd( + cmd: Command, + tree: &KeyTree, + arg: Option, + path: &KeyPath, + sequenced: bool, +) -> Result<(Option, Command)> { + Ok(match tree { + KeyTree::String => ( + Some(arg.unwrap().value_parser(clap::value_parser!(String))), + cmd, + ), + KeyTree::I8 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(i8))), + cmd, + ), + KeyTree::I16 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(i16))), + cmd, + ), + KeyTree::I32 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(i32))), + cmd, + ), + KeyTree::I64 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(i64))), + cmd, + ), + KeyTree::U8 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(u8))), + cmd, + ), + KeyTree::U16 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(u16))), + cmd, + ), + KeyTree::U32 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(u32))), + cmd, + ), + KeyTree::U64 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(u64))), + cmd, + ), + KeyTree::F32 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(f32))), + cmd, + ), + KeyTree::F64 => ( + Some(arg.unwrap().value_parser(clap::value_parser!(f64))), + cmd, + ), + KeyTree::Bool => ( + Some( + arg.unwrap() + .value_parser(clap::value_parser!(bool)) + .default_missing_value("true") + .num_args(0..=1), + ), + cmd, + ), + KeyTree::Unit => { + if sequenced { + ( + Some( + arg.unwrap() + .value_parser(clap::value_parser!(bool)) + .default_missing_value("true") + .num_args(0), + ), + cmd, + ) + } else { + ( + Some( + arg.unwrap() + .action(ArgAction::SetTrue) + .num_args(0) + .require_equals(false), + ), + cmd, + ) + } + } + KeyTree::Seq(t) => { + if sequenced { + return Err(ParseError::UnsupportedType(path.clone(), t.key_type()).into()); + } else { + let (arg, cmd) = extend_cmd(cmd, t, arg, path, true)?; + let Some(arg) = arg else { + return Err(ParseError::UnsupportedType(path.clone(), t.key_type()).into()) + }; + + let num_args = arg.get_num_args().unwrap(); + (Some(arg.num_args(num_args.min_values()..)), cmd) + } + } + KeyTree::Option(t) => { + if sequenced { + return Err(ParseError::UnsupportedType(path.clone(), t.key_type()).into()); + } + + let (arg, cmd) = extend_cmd(cmd, t, arg, path, false)?; + if let Some(arg) = arg { + let num_args = arg.get_num_args().unwrap(); + (Some(arg.num_args(0..=num_args.max_values())), cmd) + } else { + (None, cmd) + } + } + KeyTree::Map(map) => { + if sequenced { + return Err(ParseError::UnsupportedType(path.clone(), tree.key_type()).into()); + } + + let mut cmd = cmd; + + for (child_path, key) in map { + if key.skip_cli() { + continue; + } + + let path = path.push_child(child_path); + + let arg = Arg::new(path.to_string()) + .long( + path.get_inner() + .iter() + .map(|v| v.to_case(Case::Kebab)) + .collect::>() + .join("."), + ) + .num_args(1) + .required(false); + + let (arg, mut new_cmd) = extend_cmd(cmd, key.tree(), Some(arg), &path, false)?; + + if let Some(arg) = arg { + new_cmd = new_cmd.arg(arg); + } + + cmd = new_cmd; + } + + (None, cmd) + } + }) +} -pub fn generate_command() -> Command { +pub fn generate_command() -> Result { // Generate clap Command let mut command = command!(); @@ -46,50 +196,18 @@ pub fn generate_command() -> Command { command = command.help_template(template); - for key in C::keys(KeyPath::new()).unwrap() { - if key.skip_cli() { - continue; - } + let map = match C::tree() { + KeyTree::Map(map) => map, + r => return Err(ParseError::UnsupportedType(KeyPath::new(), r.key_type()).into()), + }; - let mut arg = Arg::new(key.path()) - .long(key.path().to_string().to_case(Case::Kebab)) - .num_args(1) - .required(false); - - if let Some(comment) = key.comment() { - arg = arg.long_help(comment); - } - - match key.key_type() { - KeyType::String => arg = arg.value_parser(clap::value_parser!(String)), - KeyType::Integer => arg = arg.value_parser(clap::value_parser!(i64)), - KeyType::Float => arg = arg.value_parser(clap::value_parser!(f64)), - KeyType::Boolean => { - arg = arg - .value_parser(clap::value_parser!(bool)) - .default_missing_value("true") - .num_args(0..=1) - } - KeyType::Unit => { - arg = arg - .action(ArgAction::SetTrue) - .num_args(0) - .require_equals(false) - } - KeyType::Array(t) => match t.as_ref() { - KeyType::String => { - arg = arg.value_parser(clap::value_parser!(String)).num_args(1..) - } - KeyType::Integer => arg = arg.value_parser(clap::value_parser!(i64)).num_args(1..), - KeyType::Float => arg = arg.value_parser(clap::value_parser!(f64)).num_args(1..), - KeyType::Boolean => arg = arg.value_parser(clap::value_parser!(bool)).num_args(1..), - r => unimplemented!("Arrays of {:?} are currently not supported in CLI", r), - }, - } + let (arg, mut command) = extend_cmd(command, &KeyTree::Map(map), None, &KeyPath::new(), false)?; + if let Some(arg) = arg { command = command.arg(arg); } - command + + Ok(command) } impl From<&KeyPath> for clap::Id { @@ -99,86 +217,229 @@ impl From<&KeyPath> for clap::Id { } pub struct CliSource { - matches: ArgMatches, + value: Value, _phantom: PhantomData, } -impl Default for CliSource { - fn default() -> Self { - Self::new() - } -} +fn matches_to_value( + path: KeyPath, + tree: &KeyTree, + matches: &ArgMatches, + sequenced: bool, +) -> Result> { + let id = path.to_string(); -impl CliSource { - pub fn new() -> Self { - Self { - matches: generate_command::().get_matches(), - _phantom: PhantomData, + match tree { + KeyTree::Bool => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::Bool(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::Bool(*s))) + } } - } + KeyTree::String => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| { + s.into_iter() + .map(|s| Value::String(s.to_string())) + .collect() + }) + .map(Value::Seq)) + } else { + Ok(matches + .get_one::(&id) + .map(|s| Value::String(s.to_string()))) + } + } + KeyTree::I8 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::I8(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::I8(*s))) + } + } + KeyTree::I16 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::I16(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::I16(*s))) + } + } + KeyTree::I32 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::I32(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::I32(*s))) + } + } + KeyTree::I64 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::I64(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::I64(*s))) + } + } + KeyTree::U8 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::U8(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::U8(*s))) + } + } + KeyTree::U16 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::U16(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::U16(*s))) + } + } + KeyTree::U32 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::U32(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::U32(*s))) + } + } + KeyTree::U64 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::U64(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::U64(*s))) + } + } + KeyTree::F32 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::F32(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::F32(*s))) + } + } + KeyTree::F64 => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|s| Value::F64(*s)).collect()) + .map(Value::Seq)) + } else { + Ok(matches.get_one::(&id).map(|s| Value::F64(*s))) + } + } + KeyTree::Unit => { + if sequenced { + Ok(matches + .get_many::(&id) + .map(|s| s.into_iter().map(|_| Value::Unit).collect()) + .map(Value::Seq)) + } else { + Ok(if matches.get_flag(&id) { + Some(Value::Unit) + } else { + None + }) + } + } + KeyTree::Seq(t) => { + if sequenced { + return Err(ParseError::UnsupportedType(path, t.key_type()).into()); + } - pub fn with_matches(matches: ArgMatches) -> Self { - Self { - matches, - _phantom: PhantomData, + matches_to_value(path, t, matches, true) } - } + KeyTree::Option(t) => { + if sequenced { + return Err(ParseError::UnsupportedType(path, t.key_type()).into()); + } - fn get_arg(&self, key: &Key) -> Option { - if key.skip_cli() { - return None; + let value = matches_to_value(path, t, matches, false)?; + if value.is_none() + && matches + .try_get_raw(&id) + .map(|v| v.is_some()) + .unwrap_or_default() + { + Ok(Some(Value::Option(None))) + } else { + Ok(value) + } } + KeyTree::Map(map) => { + if sequenced { + return Err(ParseError::UnsupportedType(path, tree.key_type()).into()); + } - let id = key.path().to_string(); + let mut hashmap = std::collections::BTreeMap::new(); - if !self.matches.contains_id(&id) { - tracing::debug!(key = key.path().to_string(), "Key not present in CLI"); - return None; - } + for (child_path, key) in map { + if key.skip_cli() { + continue; + } - match key.key_type() { - KeyType::String => self - .matches - .get_one::(&id) - .map(|s| Value::String(s.clone())), - KeyType::Integer => self.matches.get_one::(&id).map(|i| Value::Integer(*i)), - KeyType::Float => self.matches.get_one::(&id).map(|f| Value::Float(*f)), - KeyType::Boolean => self - .matches - .get_one::(&id) - .map(|b| Value::Boolean(*b)), - KeyType::Unit => { - if self.matches.get_flag(&id) { - Some(Value::Unit) - } else { - None + let path = path.push_child(child_path); + + let value = matches_to_value(path, key.tree(), matches, false)?; + + if let Some(value) = value { + hashmap.insert(Value::String(child_path.to_string()), value); } } - KeyType::Array(t) => match t.as_ref() { - KeyType::String => self - .matches - .get_many::(&id) - .map(|v| Value::Array(v.map(|s| Value::String(s.to_string())).collect())), - KeyType::Integer => self - .matches - .get_many::(&id) - .map(|v| Value::Array(v.map(|s| Value::Integer(*s)).collect())), - KeyType::Float => self - .matches - .get_many::(&id) - .map(|v| Value::Array(v.map(|s| Value::Float(*s)).collect())), - KeyType::Boolean => self - .matches - .get_many::(&id) - .map(|v| Value::Array(v.map(|s| Value::Boolean(*s)).collect())), - _ => unimplemented!("Arrays of arrays are currently not supported in CLI"), - }, + + if hashmap.is_empty() && path.root().is_some() { + Ok(None) + } else { + Ok(Some(Value::Map(hashmap))) + } } } } +impl CliSource { + pub fn new() -> Result { + Self::with_matches(generate_command::()?.get_matches()) + } + + pub fn with_matches(matches: ArgMatches) -> Result { + Ok(Self { + value: matches_to_value(KeyPath::new(), &C::tree(), &matches, false)? + .unwrap_or(Value::Option(None)), + _phantom: PhantomData, + }) + } +} + impl Source for CliSource { - fn get_key(&self, key: &Key) -> Result> { - Ok(self.get_arg(key)) + fn get_key(&self, path: KeyPath) -> Result> { + utils::get_key(&self.value, path) } } diff --git a/config/config/src/sources/env.rs b/config/config/src/sources/env.rs index b3b988db..59567eb9 100644 --- a/config/config/src/sources/env.rs +++ b/config/config/src/sources/env.rs @@ -1,87 +1,148 @@ -use std::marker::PhantomData; +use std::{collections::BTreeMap, marker::PhantomData}; -use crate::{Config, Key, KeyType, ParseError, Result, Source, Value}; +use crate::{Config, KeyPath, KeyTree, KeyType, ParseError, Result, Source, Value}; + +use super::utils; pub struct EnvSource { - prefix: Option, - joiner: String, + value: Value, _phantom: PhantomData, } -impl Default for EnvSource { - fn default() -> Self { - Self { - prefix: None, - joiner: "_".to_string(), +impl EnvSource { + pub fn new() -> Result { + Self::with_joiner(None, "_") + } + + pub fn with_prefix(prefix: &str) -> Result { + Self::with_joiner(Some(prefix), "_") + } + + pub fn with_joiner(prefix: Option<&str>, joiner: &str) -> Result { + Ok(Self { _phantom: PhantomData, - } + value: extract_keys( + &C::tree(), + prefix, + prefix + .map(|p| KeyPath::new().push_child(p)) + .unwrap_or_default(), + joiner, + false, + false, + )? + .unwrap_or(Value::Option(None)), + }) } } -impl EnvSource { - pub fn new(prefix: &str, joiner: &str) -> Self { - Self { - prefix: Some(prefix.to_string()), - joiner: joiner.to_string(), - _phantom: PhantomData, +fn extract_keys( + tree: &KeyTree, + prefix: Option<&str>, + path: KeyPath, + joiner: &str, + seq: bool, + optional: bool, +) -> Result> { + match tree { + KeyTree::Bool + | KeyTree::F32 + | KeyTree::F64 + | KeyTree::I8 + | KeyTree::I16 + | KeyTree::I32 + | KeyTree::I64 + | KeyTree::String + | KeyTree::U8 + | KeyTree::U16 + | KeyTree::U32 + | KeyTree::U64 + | KeyTree::Unit => { + let name = path + .get_inner() + .iter() + .map(|s| s.to_uppercase()) + .collect::>() + .join(joiner); + + // Parse to requested type + let Ok(value) = std::env::var(name) else { + return Ok(None); + }; + + if optional && value.is_empty() { + return Ok(Some(Value::Option(None))); + } + + Ok(Some(parse_to_value(tree.key_type(), &value)?)) } - } + KeyTree::Map(map) => { + if seq { + return Err(ParseError::UnsupportedType(path, tree.key_type()).into()); + } + + let result = map + .iter() + .map(|(child_path, key)| { + extract_keys( + key.tree(), + prefix, + path.push_child(child_path), + joiner, + false, + false, + ) + .map(|value| value.map(|value| (Value::String(child_path.clone()), value))) + }) + .collect::>>()? + .into_iter() + .flatten() + .collect::>(); - pub fn with_prefix(prefix: &str) -> Self { - Self { - prefix: Some(prefix.to_string()), - ..Default::default() + if result.is_empty() && path.get_inner().len() != prefix.is_some() as usize { + Ok(None) + } else { + Ok(Some(Value::Map(result))) + } } - } + KeyTree::Option(tree) => { + if seq { + return Err(ParseError::UnsupportedType(path, tree.key_type()).into()); + } + + extract_keys(tree, prefix, path, joiner, seq, true) + } + KeyTree::Seq(tree) => { + if seq { + return Err(ParseError::UnsupportedType(path, tree.key_type()).into()); + } - pub fn with_joiner(joiner: &str) -> Self { - Self { - joiner: joiner.to_string(), - ..Default::default() + extract_keys(tree, prefix, path, joiner, true, false) } } } -fn parse_to_value(key_type: &KeyType, s: &str) -> Result { +fn parse_to_value(key_type: KeyType, s: &str) -> Result { match key_type { KeyType::String => Ok(Value::String(s.to_string())), - KeyType::Integer => Ok(Value::Integer( - s.parse::().map_err(ParseError::Integer)?, - )), - KeyType::Float => Ok(Value::Float(s.parse::().map_err(ParseError::Float)?)), - KeyType::Boolean => Ok(Value::Boolean(s.parse::().map_err(ParseError::Bool)?)), - KeyType::Array(key_type) => { - let mut vec = vec![]; - // We split on "," here which means elements can't contain "," - // TODO: Determine if this is a problem - for element in s.split(',') { - vec.push(parse_to_value(key_type, element)?); - } - Ok(Value::Array(vec)) - } + KeyType::I64 => Ok(Value::I64(s.parse::().map_err(ParseError::Integer)?)), + KeyType::U64 => Ok(Value::U64(s.parse::().map_err(ParseError::Integer)?)), + KeyType::I32 => Ok(Value::I32(s.parse::().map_err(ParseError::Integer)?)), + KeyType::U32 => Ok(Value::U32(s.parse::().map_err(ParseError::Integer)?)), + KeyType::I16 => Ok(Value::I16(s.parse::().map_err(ParseError::Integer)?)), + KeyType::U16 => Ok(Value::U16(s.parse::().map_err(ParseError::Integer)?)), + KeyType::I8 => Ok(Value::I8(s.parse::().map_err(ParseError::Integer)?)), + KeyType::U8 => Ok(Value::U8(s.parse::().map_err(ParseError::Integer)?)), + KeyType::F32 => Ok(Value::F32(s.parse::().map_err(ParseError::Float)?)), + KeyType::F64 => Ok(Value::F64(s.parse::().map_err(ParseError::Float)?)), + KeyType::Bool => Ok(Value::Bool(s.parse::().map_err(ParseError::Bool)?)), KeyType::Unit => Ok(Value::Unit), + _ => unreachable!(), } } impl Source for EnvSource { - fn get_key(&self, key: &Key) -> Result> { - if key.skip_env() { - return Ok(None); - } - - let mut path = key.path().clone(); - if let Some(prefix) = &self.prefix { - path = path.push_root(prefix); - } - let joined = path - .into_iter() - .reduce(|a, b| format!("{a}{}{b}", self.joiner)) - .unwrap_or_default() - .to_uppercase(); - // Parse to requested type - match std::env::var(joined) { - Ok(var) => parse_to_value(key.key_type(), &var).map(Some), - Err(_) => Ok(None), - } + fn get_key(&self, path: KeyPath) -> Result> { + utils::get_key(&self.value, path) } } diff --git a/config/config/src/sources/file/json.rs b/config/config/src/sources/file/json.rs index f8c410b5..1a661c6e 100644 --- a/config/config/src/sources/file/json.rs +++ b/config/config/src/sources/file/json.rs @@ -1,26 +1,25 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; -use crate::Value; +use serde_value::Value; -impl From for Value { - fn from(value: serde_json::Value) -> Self { - match value { - serde_json::Value::String(s) => Value::String(s), - serde_json::Value::Number(n) => n - .as_i64() - .map(Value::Integer) - .unwrap_or(Value::Float(n.as_f64().expect("unsupported value type"))), - serde_json::Value::Bool(b) => Value::Boolean(b), - serde_json::Value::Array(a) => Value::Array(a.into_iter().map(|i| i.into()).collect()), - serde_json::Value::Object(map) => { - // Is there a better way than iterating over each entry? Probably not - let mut hashmap = HashMap::with_capacity(map.len()); - for (k, v) in map { - hashmap.insert(k, v.into()); - } - Value::Map(hashmap) +pub fn convert_value(value: serde_json::Value) -> Value { + match value { + serde_json::Value::String(s) => Value::String(s), + serde_json::Value::Number(n) => n + .as_i64() + .map(Value::I64) + .or_else(|| n.as_u64().map(Value::U64)) + .unwrap_or_else(|| Value::F64(n.as_f64().unwrap())), + serde_json::Value::Bool(b) => Value::Bool(b), + serde_json::Value::Array(a) => Value::Seq(a.into_iter().map(convert_value).collect()), + serde_json::Value::Object(map) => { + // Is there a better way than iterating over each entry? Probably not + let mut hashmap = BTreeMap::new(); + for (k, v) in map { + hashmap.insert(Value::String(k), convert_value(v)); } - serde_json::Value::Null => Value::Null, + Value::Map(hashmap) } + serde_json::Value::Null => Value::Option(None), } } diff --git a/config/config/src/sources/file/mod.rs b/config/config/src/sources/file/mod.rs index 4b7f34e6..53c322c4 100644 --- a/config/config/src/sources/file/mod.rs +++ b/config/config/src/sources/file/mod.rs @@ -5,7 +5,9 @@ use std::{ path::{Path, PathBuf}, }; -use crate::{Config, ConfigError, FileError, Key, KeyType, Result, Source, Value}; +use crate::{Config, ConfigError, FileError, KeyPath, Result, Source, Value}; + +use super::utils; mod json; mod toml; @@ -53,7 +55,7 @@ impl FileSource { let content: serde_json::Value = serde_json::from_reader(reader).map_err(FileError::Json)?; Ok(Self { - content: content.into(), + content: json::convert_value(content), _phantom: PhantomData, path: None, }) @@ -63,7 +65,7 @@ impl FileSource { let content = io::read_to_string(reader).map_err(FileError::Io)?; let value: ::toml::Value = ::toml::from_str(&content).map_err(FileError::Toml)?; Ok(Self { - content: value.into(), + content: toml::convert_value(value), _phantom: PhantomData, path: None, }) @@ -73,7 +75,7 @@ impl FileSource { let content: serde_yaml::Value = serde_yaml::from_reader(reader).map_err(FileError::Yaml)?; Ok(Self { - content: content.into(), + content: yaml::convert_value(content).unwrap_or_else(|| Value::Map(Default::default())), _phantom: PhantomData, path: None, }) @@ -81,39 +83,7 @@ impl FileSource { } impl Source for FileSource { - fn get_key(&self, key: &Key) -> Result> { - let mut current = &self.content; - for segment in key.path() { - let Value::Map(map) = current else { - // Trying to access a field on a non-map type - // I'm not sure if we should return an error here - return Ok(None); - }; - let Some(value) = map.get(segment) else { - return Ok(None); - }; - current = value; - } - // Check if value has right type - // Feels ugly - let type_match = match current { - Value::String(_) => *key.key_type() == KeyType::String, - Value::Integer(_) => *key.key_type() == KeyType::Integer, - Value::Float(_) => *key.key_type() == KeyType::Float, - Value::Boolean(_) => *key.key_type() == KeyType::Boolean, - Value::Array(_) => matches!(key.key_type(), KeyType::Array(_)), - Value::Map(_) => true, - Value::Null => true, - Value::Unit => *key.key_type() == KeyType::Unit, - }; - if type_match { - Ok(Some(current.clone())) - } else { - Err(ConfigError::TypeMismatch { - path: Some(key.path().clone()), - expected: key.key_type().clone(), - got: current.clone(), - }) - } + fn get_key(&self, path: KeyPath) -> Result> { + utils::get_key(&self.content, path) } } diff --git a/config/config/src/sources/file/toml.rs b/config/config/src/sources/file/toml.rs index 1fc844d4..ab41c668 100644 --- a/config/config/src/sources/file/toml.rs +++ b/config/config/src/sources/file/toml.rs @@ -1,24 +1,22 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; -use crate::Value; +use serde_value::Value; -impl From for Value { - fn from(value: toml::Value) -> Self { - match value { - toml::Value::String(s) => Value::String(s), - toml::Value::Integer(i) => Value::Integer(i), - toml::Value::Float(f) => Value::Float(f), - toml::Value::Boolean(b) => Value::Boolean(b), - toml::Value::Array(a) => Value::Array(a.into_iter().map(|i| i.into()).collect()), - toml::Value::Table(map) => { - // Is there a better way than iterating over each entry? Probably not - let mut hashmap = HashMap::with_capacity(map.len()); - for (k, v) in map { - hashmap.insert(k, v.into()); - } - Value::Map(hashmap) +pub fn convert_value(value: toml::Value) -> Value { + match value { + toml::Value::String(s) => Value::String(s), + toml::Value::Integer(i) => Value::I64(i), + toml::Value::Float(f) => Value::F64(f), + toml::Value::Boolean(b) => Value::Bool(b), + toml::Value::Array(a) => Value::Seq(a.into_iter().map(convert_value).collect()), + toml::Value::Table(map) => { + // Is there a better way than iterating over each entry? Probably not + let mut hashmap = BTreeMap::new(); + for (k, v) in map { + hashmap.insert(Value::String(k), convert_value(v)); } - _ => panic!("unsupported value type"), + Value::Map(hashmap) } + _ => panic!("unsupported value type"), } } diff --git a/config/config/src/sources/file/yaml.rs b/config/config/src/sources/file/yaml.rs index f58a954c..37aba83c 100644 --- a/config/config/src/sources/file/yaml.rs +++ b/config/config/src/sources/file/yaml.rs @@ -1,32 +1,35 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; use crate::Value; -impl From for Value { - fn from(value: serde_yaml::Value) -> Self { - match value { - serde_yaml::Value::String(s) => Value::String(s), - serde_yaml::Value::Number(n) => n - .as_i64() - .map(Value::Integer) - .unwrap_or(Value::Float(n.as_f64().expect("unsupported value type"))), - serde_yaml::Value::Bool(b) => Value::Boolean(b), - serde_yaml::Value::Sequence(s) => { - Value::Array(s.into_iter().map(|i| i.into()).collect()) +pub fn convert_value(value: serde_yaml::Value) -> Option { + match value { + serde_yaml::Value::String(s) => Some(Value::String(s)), + serde_yaml::Value::Number(n) => Some( + n.as_i64() + .map(Value::I64) + .or_else(|| n.as_u64().map(Value::U64)) + .unwrap_or_else(|| Value::F64(n.as_f64().unwrap())), + ), + serde_yaml::Value::Bool(b) => Some(Value::Bool(b)), + serde_yaml::Value::Sequence(a) => Some(Value::Seq( + a.into_iter().filter_map(convert_value).collect(), + )), + serde_yaml::Value::Mapping(map) => { + // Is there a better way than iterating over each entry? Probably not + let mut hashmap = BTreeMap::new(); + for (k, v) in map { + let Some(key) = convert_value(k) else { + continue; + }; + let Some(value) = convert_value(v) else { + continue; + }; + hashmap.insert(key, value); } - serde_yaml::Value::Mapping(map) => { - // Is there a better way than iterating over each entry? Probably not - let mut hashmap = HashMap::with_capacity(map.len()); - for (k, v) in map { - // HUH How can keys in a yaml map be not strings? - if let serde_yaml::Value::String(k) = k { - hashmap.insert(k, v.into()); - } - } - Value::Map(hashmap) - } - serde_yaml::Value::Null => Value::Null, - _ => panic!("unsupported value type"), + Some(Value::Map(hashmap)) } + serde_yaml::Value::Null => Some(Value::Option(None)), + serde_yaml::Value::Tagged(_) => None, } } diff --git a/config/config/src/sources/mod.rs b/config/config/src/sources/mod.rs index 1d3034d6..0636eb1f 100644 --- a/config/config/src/sources/mod.rs +++ b/config/config/src/sources/mod.rs @@ -2,6 +2,8 @@ pub mod cli; pub mod env; pub mod file; +mod utils; + pub use cli::CliSource; pub use env::EnvSource; pub use file::FileSource; diff --git a/config/config/src/sources/utils.rs b/config/config/src/sources/utils.rs new file mode 100644 index 00000000..e98a771e --- /dev/null +++ b/config/config/src/sources/utils.rs @@ -0,0 +1,20 @@ +use serde_value::Value; + +use crate::{KeyPath, Result}; + +pub fn get_key(mut current: &Value, path: KeyPath) -> Result> { + for segment in path.clone() { + let Value::Map(map) = current else { + // Trying to access a field on a non-map type + // I'm not sure if we should return an error here + panic!("Trying to access a field on a non-map type: {}, {:#?}", path, current); + }; + let Some(value) = map.get(&Value::String(segment)) else { + return Ok(None); + }; + + current = value; + } + + Ok(Some(current.clone())) +} diff --git a/config/config/src/tests.rs b/config/config/src/tests.rs index 406431c6..68e9bd51 100644 --- a/config/config/src/tests.rs +++ b/config/config/src/tests.rs @@ -2,10 +2,9 @@ //! This has to do with the way the macro generates code. It refers to items in the config crate with `::config` //! which is not available in the config crate itself. -use crate::{ - sources, Config, ConfigBuilder, ConfigError, Key, KeyPath, KeyType, Result, Source, Value, - ValueMap, -}; +use std::collections::BTreeMap; + +use crate::{sources, Config, ConfigBuilder, Key, KeyTree, Source, Value}; fn clear_env() { for (key, _) in std::env::vars() { @@ -15,134 +14,110 @@ fn clear_env() { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, serde::Deserialize, serde::Serialize)] struct DummyConfig { enabled: bool, logging: LoggingConfig, } -#[derive(Debug, PartialEq)] +// Can be generated with Config derive macro +impl Config for DummyConfig { + fn tree() -> KeyTree { + let mut keys = BTreeMap::new(); + + keys.insert("enabled".to_string(), Key::new(bool::tree())); + keys.insert("logging".to_string(), Key::new(LoggingConfig::tree())); + + KeyTree::Map(keys) + } +} + +#[derive(Debug, PartialEq, serde::Deserialize, serde::Serialize)] struct LoggingConfig { level: String, json: bool, } -// Can be generated with Config derive macro -impl Config for DummyConfig { - fn build(values: Option) -> Result { - let mut values = values.ok_or(ConfigError::MissingKey)?; - if let (Value::Boolean(enabled), Value::String(level), Value::Boolean(json)) = ( - values.remove(&KeyPath::from("enabled")).unwrap(), - values.remove(&KeyPath::from("logging.level")).unwrap(), - values.remove(&KeyPath::from("logging.json")).unwrap(), - ) { - Ok(Self { - enabled, - logging: LoggingConfig { level, json }, - }) - } else { - unimplemented!("wrong keys") - } - } - - fn keys(_path: KeyPath) -> Result> { - Ok(vec![ - Key::new("enabled".into(), KeyType::Boolean), - Key::new("logging.level".into(), KeyType::String), - Key::new("logging.json".into(), KeyType::Boolean), - ]) - } +impl Config for LoggingConfig { + fn tree() -> KeyTree { + let mut keys = BTreeMap::new(); - fn primitive() -> Option { - None - } + keys.insert("level".to_string(), Key::new(String::tree())); + keys.insert("json".to_string(), Key::new(bool::tree())); - fn from_value(_: Option) -> Result { - Err(ConfigError::NotAPrimitive) + KeyTree::Map(keys) } } #[test] fn env() { - let key = Key::new("test.key".into(), KeyType::String); - // With custom prefix and default joiner clear_env(); - std::env::set_var("SCUF_TEST_KEY", "TEST_VALUE"); - let config = sources::EnvSource::::with_prefix("SCUF"); + std::env::set_var("SCUF_ENABLED", "true"); + let config = sources::EnvSource::::with_prefix("SCUF").unwrap(); assert_eq!( - config.get_key(&key).unwrap().unwrap(), - Value::String("TEST_VALUE".to_string()) + config.get_key("enabled".into()).unwrap().unwrap(), + Value::Bool(true), ); // With no prefix and custom joiner clear_env(); - std::env::set_var("TEST__KEY", "TEST_VALUE"); - let config = sources::EnvSource::::with_joiner("__"); + std::env::set_var("LOGGING__JSON", "false"); + let config = sources::EnvSource::::with_joiner(None, "__").unwrap(); assert_eq!( - config.get_key(&key).unwrap().unwrap(), - Value::String("TEST_VALUE".to_string()) + config.get_key("logging.json".into()).unwrap().unwrap(), + Value::Bool(false), ); // With custom prefix and custom joiner clear_env(); - std::env::set_var("SCUF-TEST-KEY", "TEST_VALUE"); - let config = sources::EnvSource::::new("SCUF", "-"); + std::env::set_var("LOGGING_JSON", "true"); + let config = sources::EnvSource::::new().unwrap(); assert_eq!( - config.get_key(&key).unwrap().unwrap(), - Value::String("TEST_VALUE".to_string()) + config.get_key("logging.json".into()).unwrap().unwrap(), + Value::Bool(true), ); } #[test] fn file() { - let key = Key::new("test.key".into(), KeyType::String); let data: &[u8] = br#" [test] key = "test_value" "#; let config = sources::FileSource::::toml(data).unwrap(); assert_eq!( - config.get_key(&key).unwrap().unwrap(), + config.get_key("test.key".into()).unwrap().unwrap(), Value::String("test_value".to_string()) ); - assert_eq!( - config - .get_key(&Key::new("test.not_defined".into(), KeyType::String)) - .unwrap(), - None - ); + assert_eq!(config.get_key("test.not_defined".into()).unwrap(), None); } #[test] fn cli() { - let matches = sources::cli::generate_command::().get_matches_from(vec![ - "cli_test", - "--enabled", - "true", - "--logging.level", - "INFO", - "--logging.json", - "false", - ]); - let cli = sources::CliSource::::with_matches(matches); + let matches = sources::cli::generate_command::() + .unwrap() + .get_matches_from(vec![ + "cli_test", + "--enabled", + "true", + "--logging.level", + "INFO", + "--logging.json", + "false", + ]); + let cli = sources::CliSource::::with_matches(matches).unwrap(); assert_eq!( - cli.get_key(&Key::new("enabled".into(), KeyType::Boolean)) - .unwrap() - .unwrap(), - Value::Boolean(true), + cli.get_key("enabled".into()).unwrap().unwrap(), + Value::Bool(true), ); assert_eq!( - cli.get_key(&Key::new("logging.level".into(), KeyType::String)) - .unwrap() - .unwrap(), + cli.get_key("logging.level".into()).unwrap().unwrap(), Value::String("INFO".to_string()), ); assert_eq!( - cli.get_key(&Key::new("logging.json".into(), KeyType::Boolean)) - .unwrap() - .unwrap(), - Value::Boolean(false), + cli.get_key("logging.json".into()).unwrap().unwrap(), + Value::Bool(false), ); } diff --git a/config/config/src/value.rs b/config/config/src/value.rs deleted file mode 100644 index 2987c5ab..00000000 --- a/config/config/src/value.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::collections::HashMap; - -use crate::KeyPath; - -pub type ValueMap = HashMap; - -#[derive(Debug, Clone, PartialEq)] -pub enum Value { - String(String), - Integer(i64), - Float(f64), - Boolean(bool), - Array(Vec), - Map(HashMap), - Null, - Unit, -} - -pub fn extract_sub_values(values: &ValueMap, path: &KeyPath) -> Option { - let mut sub_values = HashMap::new(); - for (key_path, value) in values { - if key_path.root() == Some(&path.to_string()) { - let sub_path = key_path.clone(); - sub_values.insert(sub_path.drop_root(), value.clone()); - } - } - if sub_values.is_empty() { - None - } else { - Some(sub_values) - } -} diff --git a/config/config_derive/src/lib.rs b/config/config_derive/src/lib.rs index 844ac99e..0626468f 100644 --- a/config/config_derive/src/lib.rs +++ b/config/config_derive/src/lib.rs @@ -1,6 +1,6 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{DataStruct, Expr, Type, TypePath}; +use syn::{DataStruct, Type, TypePath}; #[proc_macro_derive(Config, attributes(config))] pub fn derive_answer_fn(tokens: TokenStream) -> TokenStream { @@ -21,30 +21,6 @@ fn impl_config(ast: &syn::DeriveInput) -> syn::Result { let attributes = get_attributes(&ast.attrs)?; - let struct_default_attr = match attributes.iter().find_map(|a| { - if let Attr::Default(e) = a { - Some(e) - } else { - None - } - }) { - Some(DefaultAttr::Default) => Some(DefaultAttr::Default), - Some(DefaultAttr::Expr(e)) => { - return Err(syn::Error::new_spanned( - e, - "Struct default must be non expression", - )) - } - None => None, - }; - - if attributes.iter().any(|a| matches!(a, Attr::FromStr(_))) { - return Err(syn::Error::new_spanned( - ast, - "Structs from_str is not supported", - )); - } - let struct_env_attr = attributes .iter() @@ -56,8 +32,7 @@ fn impl_config(ast: &syn::DeriveInput) -> syn::Result { .find_map(|a| if let Attr::Cli(e) = a { Some(e) } else { None }); let mut keys_init = vec![]; - let mut builder_init = vec![]; - let mut builder_keys = vec![]; + for field in fields.named.iter() { let Some(ident) = &field.ident else { return Err(syn::Error::new_spanned(field, "Only named fields are supported")); @@ -72,25 +47,6 @@ fn impl_config(ast: &syn::DeriveInput) -> syn::Result { let attributes = get_attributes(&field.attrs)?; - let field_default_attr = attributes - .iter() - .find_map(|a| { - if let Attr::Default(e) = a { - Some(e) - } else { - None - } - }) - .or(struct_default_attr.as_ref()); - - let field_from_str_attr = attributes.iter().find_map(|a| { - if let Attr::FromStr(e) = a { - Some(e) - } else { - None - } - }); - let field_env_attr = attributes .iter() .find_map(|a| if let Attr::Env(e) = a { Some(e) } else { None }) @@ -101,12 +57,24 @@ fn impl_config(ast: &syn::DeriveInput) -> syn::Result { .find_map(|a| if let Attr::Cli(e) = a { Some(e) } else { None }) .or(struct_cli_attr); + let type_attr = attributes + .iter() + .find_map(|a| { + if let Attr::KeyType(e) = a { + Some(e) + } else { + None + } + }) + .map(|t| quote! { #t }) + .unwrap_or_else(|| quote! { <#path as ::config::Config>::tree() }); + let add_attrs = { let env_attr = if let Some(env_attr) = field_env_attr { if env_attr.skip { - quote! { let key = key.with_skip_env(true); } + quote! { let key = key.with_skip_env(); } } else { - quote! { let key = key.with_skip_env(false); } + quote! {} } } else { quote! {} @@ -116,108 +84,25 @@ fn impl_config(ast: &syn::DeriveInput) -> syn::Result { if cli_attr.skip { quote! { let key = key.with_skip_cli(true); } } else { - quote! { let key = key.with_skip_cli(false); } + quote! {} } } else { quote! {} }; quote! { + let key = ::config::Key::new(#type_attr); #env_attr #cli_attr - let key = key.with_comment(#comment); + key.with_comment(#comment) } }; - let default_value = field_default_attr.map(|d| match d { - DefaultAttr::Default => quote! { ::default().#ident }, - DefaultAttr::Expr(expr) => quote! { #expr }, - }); - - if let Some(from_str) = field_from_str_attr { - keys_init.push(quote! { - { - let key = ::config::Key::new(root.child(stringify!(#ident)), ::config::KeyType::String); - #add_attrs - keys.push(key); - } + keys_init.push(quote! { + keys.insert(stringify!(#ident).to_string(), { + #add_attrs }); - - let from_str_expr = match from_str { - FromStrAttr::Default => quote! { - if let Some(value) = value { - <#path as ::std::str::FromStr>::from_str(value).map_err(|e| ::config::ConfigError::FromStr(Box::new(e))) - } else { - Err(::config::ConfigError::MissingKey) - } - }, - FromStrAttr::Expr(expr) => quote! { - ::config::__internal::parse(value, #expr) - }, - }; - - let missing_key = match default_value { - Some(expr) => quote! { Ok(#expr) }, - None => quote! { - let value = None; - #from_str_expr - }, - }; - - builder_init.push(quote! { - let #ident = match values.remove(&::std::convert::Into::<::config::KeyPath>::into(stringify!(#ident))) { - Some(::config::Value::String(value)) => { - let value = Some(value.as_str()); - #from_str_expr - }, - Some(v) => Err(::config::ConfigError::TypeMismatch { path: None, expected: ::config::KeyType::String, got: v }), - None => { #missing_key }, - }?; - }); - } else { - keys_init.push(quote! { - { - if let Some(kt) = <#path as ::config::Config>::primitive() { - // The field is a primitive type - let key = ::config::Key::new(root.child(stringify!(#ident)), kt); - #add_attrs - keys.push(key); - } else { - // The field is a sub struct that can be built itself - let mut sub_keys = <#path as ::config::Config>::keys(root.child(stringify!(#ident)))?; - keys.append(&mut sub_keys); - } - } - }); - - let mut build_value_primitive = - quote! { <#path as ::config::Config>::from_value(value)? }; - let mut build_value_struct = quote! { <#path as ::config::Config>::build(value)? }; - - if let Some(expr) = &default_value { - build_value_primitive = - quote! { if value.is_some() { #build_value_primitive } else { #expr } }; - build_value_struct = - quote! { if value.is_some() { #build_value_struct } else { #expr } }; - } - - builder_init.push(quote! { - let #ident = { - if <#path as ::config::Config>::primitive().is_some() { - let value = values.remove(&::std::convert::Into::<::config::KeyPath>::into(stringify!(#ident))); - #build_value_primitive - } else { - let value = ::config::extract_sub_values( - &values, - &::std::convert::Into::<::config::KeyPath>::into(stringify!(#ident)) - ); - #build_value_struct - } - }; - }); - } - - builder_keys.push(quote! { #ident, }); + }); } let name = &ast.ident; @@ -228,43 +113,16 @@ fn impl_config(ast: &syn::DeriveInput) -> syn::Result { const VERSION: Option<&'static str> = option_env!("CARGO_PKG_VERSION"); const AUTHOR : Option<&'static str> = option_env!("CARGO_PKG_AUTHORS"); - fn keys(mut root: ::config::KeyPath) -> ::config::Result<::std::vec::Vec<::config::Key>> { - let mut keys = ::std::vec![]; - #(#keys_init)* - Ok(keys) - } - - fn build(values: ::std::option::Option<::config::ValueMap>) -> ::config::Result { - let mut values = values.ok_or(::config::ConfigError::MissingKey)?; - #(#builder_init)* - Ok(Self { - #(#builder_keys)* - }) - } + fn tree() -> ::config::KeyTree { + let mut keys = ::std::collections::BTreeMap::new(); - fn primitive() -> ::std::option::Option<::config::KeyType> { - None - } + #(#keys_init)* - fn from_value(_: ::std::option::Option<::config::Value>) -> ::config::Result { - Err(::config::ConfigError::NotAPrimitive) + ::config::KeyTree::Map(keys) } } - }.into()) -} - -enum DefaultAttr { - // Use the default value from the struct - Default, - // Use the given expression - Expr(Expr), -} - -enum FromStrAttr { - // Use the std::str::FromStr implementation - Default, - // Use the given expression - Expr(Expr), + } + .into()) } struct EnvAttr { @@ -278,8 +136,7 @@ struct CliAttr { } enum Attr { - Default(DefaultAttr), - FromStr(FromStrAttr), + KeyType(syn::Expr), Env(EnvAttr), Cli(CliAttr), } @@ -300,36 +157,19 @@ fn get_attributes(attrs: &[syn::Attribute]) -> syn::Result> { .map(|meta| { match meta { syn::Meta::Path(path) => { - if path.is_ident("default") { - Ok(Attr::Default(DefaultAttr::Default)) - } else if path.is_ident("from_str") { - Ok(Attr::FromStr(FromStrAttr::Default)) - } else { - Err(syn::Error::new_spanned(path, "Unknown attribute")) - } + Err(syn::Error::new_spanned(path, "Unknown attribute")) } syn::Meta::NameValue(syn::MetaNameValue { path, value, .. }) => { - if path.is_ident("default") { - // Try see if the value is a string literal - match value { - syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit), - .. - }) => Ok(Attr::Default(DefaultAttr::Expr(syn::parse_str( - &lit.value(), - )?))), - expr => Ok(Attr::Default(DefaultAttr::Expr(expr))), - } - } else if path.is_ident("from_str") { + if path.is_ident("tree") { // Try see if the value is a string literal match value { syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit), .. - }) => Ok(Attr::FromStr(FromStrAttr::Expr(syn::parse_str( + }) => Ok(Attr::KeyType(syn::parse_str( &lit.value(), - )?))), - expr => Ok(Attr::FromStr(FromStrAttr::Expr(expr))), + )?)), + expr => Ok(Attr::KeyType(expr)), } } else { Err(syn::Error::new_spanned(path, "Unknown attribute")) diff --git a/video/edge/src/config.rs b/video/edge/src/config.rs index f1c8f670..143a9a16 100644 --- a/video/edge/src/config.rs +++ b/video/edge/src/config.rs @@ -2,12 +2,13 @@ use std::net::SocketAddr; use anyhow::Result; use common::config::{LoggingConfig, RedisConfig, TlsConfig}; +use config::KeyTree; -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct EdgeConfig { /// Bind Address - #[config(from_str)] + #[config(tree = "KeyTree::String")] pub bind_address: SocketAddr, /// If we should use TLS @@ -23,11 +24,11 @@ impl Default for EdgeConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct GrpcConfig { /// The bind address for the gRPC server - #[config(from_str)] + #[config(tree = "KeyTree::String")] pub bind_address: SocketAddr, /// If we should use TLS for the gRPC server @@ -43,8 +44,8 @@ impl Default for GrpcConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct AppConfig { /// Name of this instance pub name: String, diff --git a/video/ingest/src/config.rs b/video/ingest/src/config.rs index 3c1c7f9a..578623cb 100644 --- a/video/ingest/src/config.rs +++ b/video/ingest/src/config.rs @@ -2,12 +2,13 @@ use std::net::SocketAddr; use anyhow::Result; use common::config::{LoggingConfig, RmqConfig, TlsConfig}; +use config::KeyTree; -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct RtmpConfig { /// The bind address for the RTMP server - #[config(from_str)] + #[config(tree = "KeyTree::String")] pub bind_address: SocketAddr, /// If we should use TLS for the RTMP server @@ -23,11 +24,11 @@ impl Default for RtmpConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct GrpcConfig { /// The bind address for the gRPC server - #[config(from_str)] + #[config(tree = "KeyTree::String")] pub bind_address: SocketAddr, /// Advertising address for the gRPC server @@ -47,8 +48,8 @@ impl Default for GrpcConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct ApiConfig { /// The bind address for the API server pub addresses: Vec, @@ -70,8 +71,8 @@ impl Default for ApiConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct TranscoderConfig { pub events_subject: String, } @@ -84,8 +85,8 @@ impl Default for TranscoderConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct AppConfig { /// Name of this instance pub name: String, diff --git a/video/transcoder/src/config.rs b/video/transcoder/src/config.rs index 592113c6..a90c9f39 100644 --- a/video/transcoder/src/config.rs +++ b/video/transcoder/src/config.rs @@ -2,35 +2,13 @@ use std::net::SocketAddr; use anyhow::Result; use common::config::{LoggingConfig, RedisConfig, RmqConfig, TlsConfig}; +use config::KeyTree; -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] -pub struct ApiConfig { - /// The bind address for the API server - pub addresses: Vec, - - /// Resolve interval in seconds (0 to disable) - pub resolve_interval: u64, - - /// If we should use TLS for the API server - pub tls: Option, -} - -impl Default for ApiConfig { - fn default() -> Self { - Self { - addresses: vec!["localhost:50051".to_string()], - resolve_interval: 30, - tls: None, - } - } -} - -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct GrpcConfig { /// The bind address for the gRPC server - #[config(from_str)] + #[config(tree = "KeyTree::String")] pub bind_address: SocketAddr, /// If we should use TLS for the gRPC server @@ -46,15 +24,17 @@ impl Default for GrpcConfig { } } -#[derive(Debug, Clone, Default, PartialEq, config::Config)] -#[config(default)] +#[derive( + Debug, Clone, Default, PartialEq, config::Config, serde::Deserialize, serde::Serialize, +)] +#[serde(default)] pub struct IngestConfig { /// If we should use TLS for the API server pub tls: Option, } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct TranscoderConfig { /// The direcory to create unix sockets in pub socket_dir: String, @@ -79,8 +59,8 @@ impl Default for TranscoderConfig { } } -#[derive(Debug, Clone, PartialEq, config::Config)] -#[config(default)] +#[derive(Debug, Clone, PartialEq, config::Config, serde::Deserialize, serde::Serialize)] +#[serde(default)] pub struct AppConfig { /// Name of this instance pub name: String, @@ -91,9 +71,6 @@ pub struct AppConfig { /// The log level to use, this is a tracing env filter pub logging: LoggingConfig, - /// API client configuration - pub api: ApiConfig, - /// gRPC server configuration pub grpc: GrpcConfig, @@ -112,7 +89,6 @@ impl Default for AppConfig { Self { name: "scuffle-transcoder".to_string(), config_file: Some("config".to_string()), - api: ApiConfig::default(), grpc: GrpcConfig::default(), logging: LoggingConfig::default(), rmq: RmqConfig::default(), diff --git a/video/transcoder/src/tests/config.rs b/video/transcoder/src/tests/config.rs index 4b9f6502..43c2526a 100644 --- a/video/transcoder/src/tests/config.rs +++ b/video/transcoder/src/tests/config.rs @@ -53,12 +53,6 @@ fn test_parse_file() { r#" [logging] level = "ingest=debug" - -[api] -addresses = [ - "test", - "test2" -] "#, ) .expect("Failed to write config file"); @@ -71,7 +65,6 @@ addresses = [ let config = AppConfig::parse().expect("Failed to parse config"); assert_eq!(config.logging.level, "ingest=debug"); - assert_eq!(config.api.addresses, vec!["test", "test2"]); assert_eq!( config.config_file, Some( @@ -99,12 +92,6 @@ level = "ingest=debug" [transcoder] socket_dir = "/tmp" - -[api] -addresses = [ - "test", - "test2" -] "#, ) .expect("Failed to write config file"); @@ -119,7 +106,6 @@ addresses = [ assert_eq!(config.logging.level, "ingest=info"); assert_eq!(config.transcoder.socket_dir, "/tmp".to_string()); - assert_eq!(config.api.addresses, vec!["test", "test2"]); assert_eq!( config.config_file, Some(