diff --git a/nativelink-config/examples/stores-config.json5 b/nativelink-config/examples/stores-config.json5 index b7c711260..4fe27c981 100644 --- a/nativelink-config/examples/stores-config.json5 +++ b/nativelink-config/examples/stores-config.json5 @@ -253,6 +253,8 @@ "endpoints": [ {"address": "grpc://${CAS_ENDPOINT:-127.0.0.1}:50051"} ], + "connections_per_endpoint": "5", + "rpc_timeout_s": "5m", "store_type": "ac" } }, diff --git a/nativelink-config/src/schedulers.rs b/nativelink-config/src/schedulers.rs index 0f52a9e8b..36b267c47 100644 --- a/nativelink-config/src/schedulers.rs +++ b/nativelink-config/src/schedulers.rs @@ -196,12 +196,12 @@ pub struct GrpcSpec { /// Limit the number of simultaneous upstream requests to this many. A /// value of zero is treated as unlimited. If the limit is reached the /// request is queued. - #[serde(default)] + #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")] pub max_concurrent_requests: usize, /// The number of connections to make to each specified endpoint to balance /// the load over multiple TCP connections. Default 1. - #[serde(default)] + #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")] pub connections_per_endpoint: usize, } diff --git a/nativelink-config/src/serde_utils.rs b/nativelink-config/src/serde_utils.rs index e9c6f81c9..16bd69644 100644 --- a/nativelink-config/src/serde_utils.rs +++ b/nativelink-config/src/serde_utils.rs @@ -152,6 +152,43 @@ pub fn convert_string_with_shellexpand<'de, D: Deserializer<'de>>( Ok((*(shellexpand::env(&value).map_err(de::Error::custom)?)).to_string()) } +pub fn convert_boolean_with_shellexpand<'de, D, T>(deserializer: D) -> Result +where + D: Deserializer<'de>, + T: TryFrom, + >::Error: fmt::Display, +{ + struct BooleanExpandVisitor>(PhantomData); + + impl Visitor<'_> for BooleanExpandVisitor + where + T: TryFrom, + >::Error: fmt::Display, + { + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a boolean or a shell-expandable string that is a boolean") + } + + fn visit_bool(self, v: bool) -> Result { + T::try_from(v).map_err(de::Error::custom) + } + + fn visit_str(self, v: &str) -> Result { + if v.is_empty() { + return Err(de::Error::custom("empty string is not a valid number")); + } + let expanded = shellexpand::env(v).map_err(de::Error::custom)?; + let s = expanded.as_ref().trim().to_lowercase(); + let parsed = s.parse::().map_err(de::Error::custom)?; + T::try_from(parsed).map_err(de::Error::custom) + } + } + + deserializer.deserialize_any(BooleanExpandVisitor::(PhantomData)) +} + /// Same as `convert_string_with_shellexpand`, but supports `Vec`. /// /// # Errors @@ -249,6 +286,86 @@ where deserializer.deserialize_any(DataSizeVisitor::(PhantomData)) } +/// # Errors +/// +/// Will return `Err` if deserialization fails. +pub fn convert_optional_data_size_with_shellexpand<'de, D, T>( + deserializer: D, +) -> Result, D::Error> +where + D: Deserializer<'de>, + T: TryFrom, + >::Error: fmt::Display, +{ + struct DataSizeVisitor>(PhantomData); + + impl<'de, T> Visitor<'de> for DataSizeVisitor + where + T: TryFrom, + >::Error: fmt::Display, + { + type Value = Option; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an optional number of bytes as an integer, or a string with a data size format (e.g., \"1GB\", \"500MB\", \"1.5TB\")") + } + + fn visit_none(self) -> Result { + Ok(None) + } + + fn visit_unit(self) -> Result { + Ok(None) + } + + fn visit_some>( + self, + deserializer: D2, + ) -> Result { + deserializer.deserialize_any(self) + } + + fn visit_u64(self, v: u64) -> Result { + T::try_from(u128::from(v)) + .map(Some) + .map_err(de::Error::custom) + } + + fn visit_i64(self, v: i64) -> Result { + if v < 0 { + return Err(de::Error::custom("Negative data size is not allowed")); + } + let v_u128 = u128::try_from(v).map_err(de::Error::custom)?; + T::try_from(v_u128).map(Some).map_err(de::Error::custom) + } + + fn visit_u128(self, v: u128) -> Result { + T::try_from(v).map(Some).map_err(de::Error::custom) + } + + fn visit_i128(self, v: i128) -> Result { + if v < 0 { + return Err(de::Error::custom("Negative data size is not allowed")); + } + let v_u128 = u128::try_from(v).map_err(de::Error::custom)?; + T::try_from(v_u128).map(Some).map_err(de::Error::custom) + } + + fn visit_str(self, v: &str) -> Result { + let expanded = shellexpand::env(v).map_err(de::Error::custom)?; + let s = expanded.as_ref().trim(); + if v.is_empty() { + return Err(de::Error::custom("Missing value in a size field")); + } + let byte_size = Byte::parse_str(s, true).map_err(de::Error::custom)?; + let bytes = byte_size.as_u128(); + T::try_from(bytes).map(Some).map_err(de::Error::custom) + } + } + + deserializer.deserialize_option(DataSizeVisitor::(PhantomData)) +} + /// # Errors /// /// Will return `Err` if deserialization fails. diff --git a/nativelink-config/src/stores.rs b/nativelink-config/src/stores.rs index 6c7a925e7..59ecb7afa 100644 --- a/nativelink-config/src/stores.rs +++ b/nativelink-config/src/stores.rs @@ -19,8 +19,9 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use crate::serde_utils::{ - convert_data_size_with_shellexpand, convert_duration_with_shellexpand, - convert_numeric_with_shellexpand, convert_optional_numeric_with_shellexpand, + convert_boolean_with_shellexpand, convert_data_size_with_shellexpand, + convert_duration_with_shellexpand, convert_numeric_with_shellexpand, + convert_optional_data_size_with_shellexpand, convert_optional_numeric_with_shellexpand, convert_optional_string_with_shellexpand, convert_string_with_shellexpand, convert_vec_string_with_shellexpand, }; @@ -472,6 +473,8 @@ pub enum StoreSpec { /// "endpoints": [ /// {"address": "grpc://${CAS_ENDPOINT:-127.0.0.1}:50051"} /// ], + /// "connections_per_endpoint": "5", + /// "rpc_timeout_s": "5m", /// "store_type": "ac" /// } /// ``` @@ -542,6 +545,7 @@ pub struct ShardConfig { /// all the store's weights divided by the individual store's weight. /// /// Default: 1 + #[serde(deserialize_with = "convert_optional_numeric_with_shellexpand")] pub weight: Option, } @@ -618,7 +622,7 @@ pub struct FilesystemSpec { /// runtime. /// A value of 0 means unlimited (no concurrency limit). /// Default: 0 - #[serde(default)] + #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")] pub max_concurrent_writes: usize, } @@ -632,7 +636,7 @@ pub struct ExperimentalOntapS3Spec { pub vserver_name: String, #[serde(deserialize_with = "convert_string_with_shellexpand")] pub bucket: String, - #[serde(default)] + #[serde(default, deserialize_with = "convert_optional_string_with_shellexpand")] pub root_certificates: Option, /// Common retry and upload configuration @@ -786,7 +790,7 @@ pub struct VerifySpec { /// an upload of data. /// /// This should be set to false for AC, but true for CAS stores. - #[serde(default)] + #[serde(default, deserialize_with = "convert_boolean_with_shellexpand")] pub verify_size: bool, /// If the data should be hashed and verify that the key matches the @@ -794,7 +798,7 @@ pub struct VerifySpec { /// request and if not set will use the global default. /// /// This should be set to false for AC, but true for CAS stores. - #[serde(default)] + #[serde(default, deserialize_with = "convert_boolean_with_shellexpand")] pub verify_hash: bool, } @@ -930,6 +934,10 @@ pub struct ExperimentalGcsSpec { /// Chunk size for resumable uploads. /// /// Default: 2MB + #[serde( + default, + deserialize_with = "convert_optional_data_size_with_shellexpand" + )] pub resumable_chunk_size: Option, /// Common retry and upload configuration @@ -937,17 +945,17 @@ pub struct ExperimentalGcsSpec { pub common: CommonObjectSpec, /// Error if authentication was not found. - #[serde(default)] + #[serde(default, deserialize_with = "convert_boolean_with_shellexpand")] pub authentication_required: bool, /// Connection timeout in milliseconds. /// Default: 3000 - #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")] + #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] pub connection_timeout_s: u64, /// Read timeout in milliseconds. /// Default: 3000 - #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")] + #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] pub read_timeout_s: u64, } @@ -981,17 +989,26 @@ pub struct CommonObjectSpec { /// upload will be aborted and the client will likely receive an error. /// /// Default: 5MB. + #[serde( + default, + deserialize_with = "convert_optional_data_size_with_shellexpand" + )] pub max_retry_buffer_per_request: Option, /// Maximum number of concurrent `UploadPart` requests per `MultipartUpload`. /// /// Default: 10. + /// + #[serde( + default, + deserialize_with = "convert_optional_numeric_with_shellexpand" + )] pub multipart_max_concurrent_uploads: Option, /// Allow unencrypted HTTP connections. Only use this for local testing. /// /// Default: false - #[serde(default)] + #[serde(default, deserialize_with = "convert_boolean_with_shellexpand")] pub insecure_allow_http: bool, /// Disable http/2 connections and only use http/1.1. Default client @@ -1001,7 +1018,7 @@ pub struct CommonObjectSpec { /// underlying network environment, S3, or GCS API servers specify otherwise. /// /// Default: false - #[serde(default)] + #[serde(default, deserialize_with = "convert_boolean_with_shellexpand")] pub disable_http2: bool, } @@ -1050,29 +1067,33 @@ pub struct GrpcEndpoint { /// The TLS configuration to use to connect to the endpoint (if grpcs). pub tls_config: Option, /// The maximum concurrency to allow on this endpoint. + #[serde( + default, + deserialize_with = "convert_optional_numeric_with_shellexpand" + )] pub concurrency_limit: Option, /// Timeout for establishing a TCP connection to the endpoint (seconds). /// If not set or 0, defaults to 30 seconds. - #[serde(default)] + #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] pub connect_timeout_s: u64, /// TCP keepalive interval (seconds). Sends TCP keepalive probes at this /// interval to detect dead connections at the OS level. /// If not set or 0, defaults to 30 seconds. - #[serde(default)] + #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] pub tcp_keepalive_s: u64, /// HTTP/2 keepalive interval (seconds). Sends HTTP/2 PING frames at this /// interval to detect dead connections at the application level. /// If not set or 0, defaults to 30 seconds. - #[serde(default)] + #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] pub http2_keepalive_interval_s: u64, /// HTTP/2 keepalive timeout (seconds). If a PING response is not received /// within this duration, the connection is considered dead. /// If not set or 0, defaults to 20 seconds. - #[serde(default)] + #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] pub http2_keepalive_timeout_s: u64, } @@ -1096,12 +1117,12 @@ pub struct GrpcSpec { /// Limit the number of simultaneous upstream requests to this many. A /// value of zero is treated as unlimited. If the limit is reached the /// request is queued. - #[serde(default)] + #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")] pub max_concurrent_requests: usize, /// The number of connections to make to each specified endpoint to balance /// the load over multiple TCP connections. Default 1. - #[serde(default)] + #[serde(default, deserialize_with = "convert_numeric_with_shellexpand")] pub connections_per_endpoint: usize, /// Maximum time (seconds) allowed for a single RPC request (e.g. a @@ -1109,7 +1130,7 @@ pub struct GrpcSpec { /// individual RPCs from hanging forever on dead connections. /// /// Default: 120 (seconds) - #[serde(default)] + #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] pub rpc_timeout_s: u64, } @@ -1175,7 +1196,7 @@ pub struct RedisSpec { /// organize your data according to the shared prefix. /// /// Default: (Empty String / No Prefix) - #[serde(default)] + #[serde(default, deserialize_with = "convert_string_with_shellexpand")] pub key_prefix: String, /// Set the mode Redis is operating in. @@ -1396,7 +1417,7 @@ pub struct ExperimentalMongoSpec { /// Enable `MongoDB` change streams for real-time updates. /// Required for scheduler subscriptions. /// Default: false - #[serde(default)] + #[serde(default, deserialize_with = "convert_boolean_with_shellexpand")] pub enable_change_streams: bool, /// Write concern 'w' parameter. diff --git a/nativelink-config/tests/deserialization_test.rs b/nativelink-config/tests/deserialization_test.rs index e19e8f1b6..6ee384d33 100644 --- a/nativelink-config/tests/deserialization_test.rs +++ b/nativelink-config/tests/deserialization_test.rs @@ -13,7 +13,8 @@ // limitations under the License. use nativelink_config::serde_utils::{ - convert_data_size_with_shellexpand, convert_duration_with_shellexpand, + convert_boolean_with_shellexpand, convert_data_size_with_shellexpand, + convert_duration_with_shellexpand, convert_optional_data_size_with_shellexpand, convert_optional_numeric_with_shellexpand, convert_optional_string_with_shellexpand, }; use serde::Deserialize; @@ -30,6 +31,15 @@ struct DataSizeEntity { data_size: usize, } +#[derive(Deserialize, Debug)] +struct OptionalDataSizeEntity { + #[serde( + default, + deserialize_with = "convert_optional_data_size_with_shellexpand" + )] + data_size: Option, +} + #[derive(Deserialize, Debug)] struct OptionalNumericEntity { #[serde( @@ -45,6 +55,12 @@ struct OptionalStringEntity { value: Option, } +#[derive(Deserialize, Debug)] +struct BoolEntity { + #[serde(default, deserialize_with = "convert_boolean_with_shellexpand")] + value: bool, +} + mod duration_tests { use pretty_assertions::assert_eq; @@ -289,6 +305,21 @@ mod optional_values_tests { } } + #[test] + fn test_optional_datasize_values() { + let examples = [ + (r#"{"data_size": null}"#, None), + (r#"{"data_size": 42}"#, Some(42)), + (r"{}", None), // Missing field + (r#"{"data_size": "20K"}"#, Some(20000)), + ]; + + for (input, expected) in examples { + let deserialized: OptionalDataSizeEntity = serde_json5::from_str(input).unwrap(); + assert_eq!(deserialized.data_size, expected); + } + } + #[test] fn test_mixed_optional_values() { #[derive(Deserialize)] @@ -331,8 +362,34 @@ mod optional_values_tests { } } +mod boolean_tests { + use crate::BoolEntity; + + #[test] + fn test_bool_parsing() { + let examples = [ + // Standard value + (r#"{"value": true}"#, true), + (r#"{"value": false}"#, false), + // Stringy values + (r#"{"value": "true"}"#, true), + (r#"{"value": "false"}"#, false), + // Stringy values with odd cases + (r#"{"value": "TRue"}"#, true), + (r#"{"value": "faLSE"}"#, false), + ]; + + for (input, expected) in examples { + let deserialized: BoolEntity = + serde_json5::from_str(input).unwrap_or_else(|_| panic!("Failed on '{input}'")); + assert_eq!(deserialized.value, expected, "{input}"); + } + } +} + mod shellexpand_tests { use pretty_assertions::assert_eq; + use serde_json5::Location; use super::*; @@ -347,6 +404,8 @@ mod shellexpand_tests { std::env::set_var("TEST_NUMBER", "42"); std::env::set_var("TEST_VAR", "test_value"); std::env::set_var("EMPTY_VAR", ""); + std::env::set_var("TEST_GOOD_BOOL", "true"); + std::env::set_var("TEST_BAD_BOOL", "wibble"); }; // Test duration with environment variable @@ -359,6 +418,11 @@ mod shellexpand_tests { serde_json5::from_str::(r#"{"data_size": "${TEST_SIZE}"}"#).unwrap(); assert_eq!(size_result.data_size, 1_000_000_000); + let size_result = + serde_json5::from_str::(r#"{"data_size": "${TEST_SIZE}"}"#) + .unwrap(); + assert_eq!(size_result.data_size, Some(1_000_000_000)); + // Test optional numeric with environment variable let numeric_result = serde_json5::from_str::(r#"{"value": "${TEST_NUMBER}"}"#) @@ -384,5 +448,22 @@ mod shellexpand_tests { .to_string() .contains("environment variable not found") ); + + let good_bool_results = + serde_json5::from_str::(r#"{"value": "${TEST_GOOD_BOOL}"}"#).unwrap(); + assert!(good_bool_results.value); + + let bad_bool_results = + serde_json5::from_str::(r#"{"value": "${TEST_BAD_BOOL}"}"#).unwrap_err(); + assert_eq!( + bad_bool_results, + serde_json5::Error::Message { + msg: "provided string was not `true` or `false`".into(), + location: Some(Location { + line: 1, + column: 11 + }) + } + ); } }