diff --git a/crates/cli/src/lib.rs b/crates/cli/src/lib.rs index e9d8c9af..2ca894cf 100644 --- a/crates/cli/src/lib.rs +++ b/crates/cli/src/lib.rs @@ -9,9 +9,9 @@ pub struct Args { } pub fn get_config(args: Args) -> eyre::Result { - Ok(if let Some(path) = args.config.as_ref() { - Config::from_file(path)? + if let Some(path) = args.config.as_ref() { + Config::from_file(path) } else { Config::from_env() - }) + } } diff --git a/crates/core/src/config.rs b/crates/core/src/config.rs index 956d66b6..79442e0e 100644 --- a/crates/core/src/config.rs +++ b/crates/core/src/config.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use std::str::FromStr; use ethers::types::Address; -use eyre::{eyre, Result}; +use eyre::{eyre, Context, Result}; use helios::client::{Client, ClientBuilder}; use helios::config::checkpoints; use helios::config::networks::Network; @@ -84,25 +84,37 @@ impl Default for Config { } impl Config { - pub fn from_env() -> Self { - Self { - network: Network::from_str( - &std::env::var("NETWORK").unwrap_or_default(), - ) - .unwrap_or(Network::MAINNET), - eth_execution_rpc: std::env::var("ETH_EXECUTION_RPC") - .unwrap_or_default(), - starknet_rpc: std::env::var("STARKNET_RPC").unwrap_or_default(), - data_dir: PathBuf::from( - std::env::var("DATA_DIR").unwrap_or_default(), - ), - poll_secs: u64::from_str( - &std::env::var("POLL_SECS").unwrap_or_default(), - ) - .unwrap_or(DEFAULT_POLL_SECS), - rpc_addr: rpc_addr(), - fee_token_addr: fee_token_addr(), - } + pub fn from_env() -> Result { + Self::from_vars(|key| std::env::var(key).ok()) + } + + fn from_vars(get: F) -> Result + where + F: Fn(&'static str) -> Option + 'static, + { + let require = |var_key: &'static str| { + get(var_key).ok_or_else(|| eyre!("The \"{}\" env var must be set or a configuration file must be specified", var_key)) + }; + + Ok(Self { + network: Network::from_str(&get("NETWORK").unwrap_or_default()) + .unwrap_or(Network::MAINNET), + eth_execution_rpc: require("ETH_EXECUTION_RPC")?, + starknet_rpc: require("STARKNET_RPC")?, + data_dir: PathBuf::from(get("DATA_DIR").unwrap_or_default()), + poll_secs: u64::from_str(&get("POLL_SECS").unwrap_or_default()) + .unwrap_or(DEFAULT_POLL_SECS), + rpc_addr: match get("RPC_ADDR") { + Some(addr) => SocketAddr::from_str(&addr) + .context("Invalid value for `RPC_ADDR`")?, + None => rpc_addr(), + }, + fee_token_addr: match get("FEE_TOKEN_ADDR") { + Some(addr) => FieldElement::from_hex_be(&addr) + .context("Invalid value for `FEE_TOKEN_ADDR`")?, + None => fee_token_addr(), + }, + }) } pub fn from_file(path: &str) -> Result { @@ -201,3 +213,79 @@ impl Config { .expect("incorrect helios client config") } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::*; + + fn case(vars: &[(&'static str, &'static str)]) -> Result { + let vars: HashMap<&str, &str> = HashMap::from_iter(vars.to_vec()); + Config::from_vars(move |s| vars.get(s).map(|s| s.to_string())) + } + + static MIN_CONFIG_VARS: &[(&'static str, &'static str)] = + &[("ETH_EXECUTION_RPC", "url"), ("STARKNET_RPC", "url")]; + + #[test] + fn test_min_config_requirements() { + assert!(case(&[("ETH_EXECUTION_RPC", "url"),]).is_err()); + assert!(case(&[("STARKNET_RPC", "url"),]).is_err()); + + assert!(case(&MIN_CONFIG_VARS).is_ok()); + } + + #[test] + fn test_unspecified_network_is_mainnet() { + let config = case(&MIN_CONFIG_VARS).unwrap(); + assert_eq!(config.network, Network::MAINNET); + } + + #[test] + fn test_rpc_address_is_validated() { + let result = case(&[ + ("ETH_EXECUTION_RPC", "url"), + ("STARKNET_RPC", "url"), + ("RPC_ADDR", "invalid_value"), + ]); + assert!(result.is_err()); + + let result = case(&[ + ("ETH_EXECUTION_RPC", "url"), + ("STARKNET_RPC", "url"), + ("RPC_ADDR", "127.0.0.1:3333"), + ]); + assert!(result.is_ok()); + assert_eq!( + SocketAddr::from_str("127.0.0.1:3333").unwrap(), + result.unwrap().rpc_addr + ); + + // Default test case + let config = case(&MIN_CONFIG_VARS).unwrap(); + assert_eq!(config.rpc_addr, rpc_addr()); + } + + #[test] + fn test_fee_token_addr_is_validated() { + let result = case(&[ + ("ETH_EXECUTION_RPC", "url"), + ("STARKNET_RPC", "url"), + ("FEE_TOKEN_ADDR", "invalid_value"), + ]); + assert!(result.is_err()); + + let result = case(&[ + ("ETH_EXECUTION_RPC", "url"), + ("STARKNET_RPC", "url"), + ("FEE_TOKEN_ADDR", "1"), + ]); + assert!(result.is_ok()); + assert_eq!(FieldElement::ONE, result.unwrap().fee_token_addr); + + // Default test case + let config = case(&MIN_CONFIG_VARS).unwrap(); + assert_eq!(config.fee_token_addr, fee_token_addr()); + } +}