diff --git a/jomini_derive/src/lib.rs b/jomini_derive/src/lib.rs index bd356a2..aeb1a75 100644 --- a/jomini_derive/src/lib.rs +++ b/jomini_derive/src/lib.rs @@ -78,6 +78,32 @@ fn ungroup(mut ty: &Type) -> &Type { /// Another attribute unique to jomini is `#[jomini(take_last)]` which will take the last occurence /// of a field. Helpful when a field is duplicated accidentally. /// +/// ## Non-serde attributes +/// The following are the extra attributes implemented by `JominiDeserialize` +/// that do not map directly to serde attributes. +/// +/// - `#[jomini(duplicated)]` +/// +/// This attribute requires that the field be a `Vec` field. Each appearance +/// of this attribute will push a new entry onto the vector. +/// +/// - `#[jomini(take_last)]` +/// +/// This attribute will discard any prior value of the field when it is +/// encountered. It is mutually exclusive with `#[jomini(duplicated)]`. +/// +/// - `#[jomini(token=value)]` +/// +/// Annoytating all fields with this enables the binary deserializer to map `u16` tokens directly to +/// struct fields, bypassing intermediate string lookups and comparisons. +/// +/// - `#[jomini(collect_with="path")]` +/// +/// This method will call the given function for every unknown entry in the +/// structure. The function should have signature +/// `fn(&mut T, &str, &mut A) -> Result<(), A::Error> where A: MapAccess`, +/// where `T` is the type of the field. +/// /// ## The Why /// /// Serde's `Deserialize` implementation will raise an error if a field occurs more than once in @@ -145,10 +171,13 @@ fn derive_impl(dinput: DeriveInput) -> Result { take_last: bool, default: DefaultFallback, deserialize_with: Option, + collect_with: Option, token: Option, borrow: bool, } + let mut unknown_field = None; + let mut field_attrs = Vec::new(); for f in &named_fields.named { let field_span = f.ident.as_ref().map_or(span, |id| id.span()); @@ -162,6 +191,7 @@ fn derive_impl(dinput: DeriveInput) -> Result { let mut take_last = false; let mut default = DefaultFallback::No; let mut deserialize_with = None; + let mut collect_with = None; let mut alias = None; let mut token = None; let mut borrow = false; @@ -197,6 +227,9 @@ fn derive_impl(dinput: DeriveInput) -> Result { } else if meta.path.is_ident("deserialize_with") { let lit: LitStr = meta.value()?.parse()?; deserialize_with = Some(lit.parse()?); + } else if meta.path.is_ident("collect_with") { + let lit: LitStr = meta.value()?.parse()?; + collect_with = Some(lit.parse()?); } else if meta.path.is_ident("alias") { let lit: LitStr = meta.value()?.parse()?; alias = Some(lit.value()); @@ -226,6 +259,7 @@ fn derive_impl(dinput: DeriveInput) -> Result { take_last, default, deserialize_with, + collect_with, token, borrow, }; @@ -237,6 +271,13 @@ fn derive_impl(dinput: DeriveInput) -> Result { )); } + if attr.deserialize_with.is_some() && attr.collect_with.is_some() { + return Err(Error::new( + field_span, + "Cannot have both deserialize_with and collect_with attributes on a field", + )); + } + // Validate borrow attribute usage if attr.borrow { // Check if the type is suitable for borrowing @@ -255,6 +296,15 @@ fn derive_impl(dinput: DeriveInput) -> Result { } } + if attr.collect_with.is_some() { + if unknown_field.is_some() { + return Err(Error::new( + field_span, + "Only one collect_with field allowed per struct", + )); + } + unknown_field = Some(field_attrs.len()); + } field_attrs.push(attr); } @@ -367,7 +417,7 @@ fn derive_impl(dinput: DeriveInput) -> Result { let builder_init = field_attrs.iter().map(|f| { let name = &f.ident; let x = &f.typ; - if !f.duplicated { + if !f.duplicated && f.collect_with.is_none() { let field_name_opt = format_ident!("{}_opt", name); quote! { let mut #field_name_opt : ::std::option::Option<#x> = None } } else { @@ -375,7 +425,7 @@ fn derive_impl(dinput: DeriveInput) -> Result { } }); - let builder_fields = field_attrs.iter().map(|f| { + let builder_fields = field_attrs.iter().filter(|f| f.collect_with.is_none()).map(|f| { let name = &f.ident; let x = &f.typ; let name_str = &f.display; @@ -455,7 +505,25 @@ fn derive_impl(dinput: DeriveInput) -> Result { } }).collect::>>()?; - let field_extract = field_attrs.iter().filter(|x| !x.duplicated).map(|f| { + let default_arm = if let Some(default_idx) = unknown_field { + let f = &field_attrs[default_idx]; + let collector = f.collect_with.as_ref().unwrap(); + let name = &f.ident; + quote! { + __Field::__unknown_str(s) => { #collector(&mut #name, s.as_ref(), &mut __map)?; }, + _ => { ::serde::de::MapAccess::next_value::<::serde::de::IgnoredAny>(&mut __map)?; } + } + } else { + quote! { _ => { ::serde::de::MapAccess::next_value::<::serde::de::IgnoredAny>(&mut __map)?; } } + }; + + let unknown_owned_str = if unknown_field.is_some() { + quote! { Ok(__Field::__unknown_str(::std::borrow::Cow::Owned(__value.into()))) } + } else { + quote! { Ok(__Field::__ignore) } + }; + + let field_extract = field_attrs.iter().filter(|x| !x.duplicated && x.collect_with.is_none()).map(|f| { let name = &f.ident; let field_name_opt = format_ident!("{}_opt", name); let name_str = &f.display; @@ -491,7 +559,7 @@ fn derive_impl(dinput: DeriveInput) -> Result { quote! { #match_arm => Ok(#field_ident) } - }); + }).collect::>(); let field_enum_token_match = field_attrs.iter().filter_map(|f| { f.token.map(|token| { @@ -538,14 +606,15 @@ fn derive_impl(dinput: DeriveInput) -> Result { fn deserialize<__D>(__deserializer: __D) -> ::std::result::Result where __D: ::serde::Deserializer<'de> { #[allow(non_camel_case_types)] - enum __Field { + enum __Field<'de> { #(#field_enums),* , + __unknown_str(::std::borrow::Cow<'de, str>), __ignore, }; struct __FieldVisitor; impl<'de> ::serde::de::Visitor<'de> for __FieldVisitor { - type Value = __Field; + type Value = __Field<'de>; fn expecting( &self, __formatter: &mut ::std::fmt::Formatter, @@ -561,7 +630,19 @@ fn derive_impl(dinput: DeriveInput) -> Result { { match __value { #(#field_enum_match),* , - _ => Ok(__Field::__ignore), + _ => #unknown_owned_str, + } + } + fn visit_borrowed_str<__E>( + self, + __value: &'de str, + ) -> ::std::result::Result + where + __E: ::serde::de::Error, + { + match __value { + #(#field_enum_match),* , + _ => Ok(__Field::__unknown_str(__value.into())), } } fn visit_u16<__E>( @@ -578,7 +659,7 @@ fn derive_impl(dinput: DeriveInput) -> Result { } } - impl<'de> serde::Deserialize<'de> for __Field { + impl<'de> serde::Deserialize<'de> for __Field<'de> { #[inline] fn deserialize<__D>( __deserializer: __D, @@ -614,7 +695,7 @@ fn derive_impl(dinput: DeriveInput) -> Result { while let Some(__key) = ::serde::de::MapAccess::next_key::<__Field>(&mut __map)? { match __key { #(#builder_fields),* , - _ => { ::serde::de::MapAccess::next_value::<::serde::de::IgnoredAny>(&mut __map)?; } + #default_arm } } diff --git a/jomini_derive/tests/18-collect-with.rs b/jomini_derive/tests/18-collect-with.rs new file mode 100644 index 0000000..162b5d6 --- /dev/null +++ b/jomini_derive/tests/18-collect-with.rs @@ -0,0 +1,61 @@ +use jomini_derive::JominiDeserialize; +use serde::{de}; + +#[derive(JominiDeserialize)] +pub struct Model { + #[jomini(collect_with = "add_country_node")] + countries: Vec<(String, u16)>, + first: u16, + fourth: u16, + #[jomini(duplicated)] + core: Vec, + names: Vec, +} + +fn add_country_node<'de, A: de::MapAccess<'de>>( + countries: &mut Vec<(String, u16)>, + key: &str, + map: &mut A, +) -> Result<(), A::Error> { + if key.len() <= 3 { + countries.push((key.into(), map.next_value()?)); + } else { + map.next_value::()?; + } + Ok(()) +} + +#[test] +fn test_deserialize_with() { + let data = r#" + { + "first": 1, + "core": 10, + "fourth": 2, + "core": 20, + "names": [ "CCC", "DDD" ], + "TAG": 10, + "NOTATAG": 40, + "MEE": 5 + }"#; + + let m: Model = serde_json::from_str(data).unwrap(); + assert_eq!(m.first, 1); + assert_eq!(m.fourth, 2); + assert_eq!(m.core, vec![10, 20]); + assert_eq!(m.names, vec!["CCC".to_string(), "DDD".to_string()]); + assert_eq!( + m.countries, + vec![("TAG".to_string(), 10), ("MEE".to_string(), 5)] + ); + + let m2: Model = serde_json::from_reader(data.as_bytes()).unwrap(); + assert_eq!(m2.first, 1); + assert_eq!(m2.fourth, 2); + assert_eq!(m2.core, vec![10, 20]); + assert_eq!(m2.names, vec!["CCC".to_string(), "DDD".to_string()]); + assert_eq!( + m2.countries, + vec![("TAG".to_string(), 10), ("MEE".to_string(), 5)] + ); +}