diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 1f22acbe2..b5008ba16 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -422,7 +422,14 @@ impl<'b> CodeGenerator<'_, 'b> { let boxed = self .context .should_box_message_field(fq_message_name, &field.descriptor); - let ty = self.resolve_type(&field.descriptor, fq_message_name); + let custom_module_path = self + .context + .get_custom_scalar_module_path(&field.descriptor, fq_message_name); + let ty = self.resolve_type( + &field.descriptor, + fq_message_name, + custom_module_path.as_deref(), + ); debug!( " field: {:?}, type: {:?}, boxed: {}", @@ -440,10 +447,10 @@ impl<'b> CodeGenerator<'_, 'b> { self.push_indent(); self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(&field.descriptor); + let type_tag = self.field_type_tag(&field.descriptor, custom_module_path.as_deref()); self.buf.push_str(&type_tag); - if type_ == Type::Bytes { + if type_ == Type::Bytes && custom_module_path.is_none() { let bytes_type = self .context .bytes_type(fq_message_name, field.descriptor.name()); @@ -545,8 +552,14 @@ impl<'b> CodeGenerator<'_, 'b> { key: &FieldDescriptorProto, value: &FieldDescriptorProto, ) { - let key_ty = self.resolve_type(key, fq_message_name); - let value_ty = self.resolve_type(value, fq_message_name); + let field_name = format!("{}.{}", fq_message_name, field.descriptor.name()); + let custom_module_path_key = self.context.get_custom_scalar_module_path(key, &field_name); + let custom_module_path_value = self + .context + .get_custom_scalar_module_path(value, &field_name); + let key_ty = self.resolve_type(key, fq_message_name, custom_module_path_key.as_deref()); + let value_ty = + self.resolve_type(value, fq_message_name, custom_module_path_value.as_deref()); debug!( " map field: {:?}, key type: {:?}, value type: {:?}", @@ -561,8 +574,8 @@ impl<'b> CodeGenerator<'_, 'b> { let map_type = self .context .map_type(fq_message_name, field.descriptor.name()); - let key_tag = self.field_type_tag(key); - let value_tag = self.map_value_type_tag(value); + let key_tag = self.field_type_tag(key, custom_module_path_key.as_deref()); + let value_tag = self.map_value_type_tag(value, custom_module_path_value.as_deref()); self.buf.push_str(&format!( "#[prost({} = \"{}, {}\", tag = \"{}\")]\n", @@ -659,7 +672,10 @@ impl<'b> CodeGenerator<'_, 'b> { } self.push_indent(); - let ty_tag = self.field_type_tag(&field.descriptor); + let custom_module_path = self + .context + .get_custom_scalar_module_path(&field.descriptor, fq_message_name); + let ty_tag = self.field_type_tag(&field.descriptor, custom_module_path.as_deref()); self.buf.push_str(&format!( "#[prost({}, tag = \"{}\")]\n", ty_tag, @@ -668,7 +684,11 @@ impl<'b> CodeGenerator<'_, 'b> { self.append_field_attributes(&oneof_name, field.descriptor.name()); self.push_indent(); - let ty = self.resolve_type(&field.descriptor, fq_message_name); + let ty = self.resolve_type( + &field.descriptor, + fq_message_name, + custom_module_path.as_deref(), + ); let boxed = self.context.should_box_oneof_field( fq_message_name, @@ -973,7 +993,20 @@ impl<'b> CodeGenerator<'_, 'b> { self.buf.push_str("}\n"); } - fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { + fn resolve_type( + &self, + field: &FieldDescriptorProto, + fq_message_name: &str, + custom_scalar_module_path: Option<&str>, + ) -> String { + if let Some(module_path) = custom_scalar_module_path { + return format!( + "<{} as {}::encoding::CustomScalarInterface>::Type", + module_path, + self.context.prost_path() + ); + } + match field.r#type() { Type::Float => String::from("f32"), Type::Double => String::from("f64"), @@ -1030,7 +1063,14 @@ impl<'b> CodeGenerator<'_, 'b> { .join("::") } - fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { + fn field_type_tag( + &self, + field: &FieldDescriptorProto, + custom_scalar_module_path: Option<&str>, + ) -> Cow<'static, str> { + if let Some(module_path) = custom_scalar_module_path { + return Cow::Owned(format!("custom_scalar({})", module_path)); + } match field.r#type() { Type::Float => Cow::Borrowed("float"), Type::Double => Cow::Borrowed("double"), @@ -1056,13 +1096,21 @@ impl<'b> CodeGenerator<'_, 'b> { } } - fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { + fn map_value_type_tag( + &self, + field: &FieldDescriptorProto, + custom_scalar_module_path: Option<&str>, + ) -> Cow<'static, str> { + if let Some(module_path) = custom_scalar_module_path { + return Cow::Owned(format!("custom_scalar({})", module_path)); + } match field.r#type() { Type::Enum => Cow::Owned(format!( "enumeration({})", self.resolve_ident(field.type_name()) )), - _ => self.field_type_tag(field), + + _ => self.field_type_tag(field, custom_scalar_module_path), } } diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index 8deba5c7a..288941015 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -12,6 +12,7 @@ use log::debug; use log::trace; use prost::Message; +use prost_types::field_descriptor_proto::Type; use prost_types::{FileDescriptorProto, FileDescriptorSet}; use crate::code_generator::CodeGenerator; @@ -36,6 +37,7 @@ pub struct Config { pub(crate) message_attributes: PathMap, pub(crate) enum_attributes: PathMap, pub(crate) field_attributes: PathMap, + pub(crate) custom_scalar: PathMap<(Type, String)>, pub(crate) boxed: PathMap<()>, pub(crate) prost_types: bool, pub(crate) strip_enum_prefix: bool, @@ -375,6 +377,26 @@ impl Config { self } + pub fn custom_scalar( + &mut self, + proto_type: Type, + module_path: M, + paths: I, + ) -> &mut Self + where + I: IntoIterator, + S: AsRef, + M: AsRef, + { + for matcher in paths { + self.custom_scalar.insert( + matcher.as_ref().to_string(), + (proto_type, module_path.as_ref().to_string()), + ); + } + self + } + /// Configures the code generator to use the provided service generator. pub fn service_generator(&mut self, service_generator: Box) -> &mut Self { self.service_generator = Some(service_generator); @@ -1202,6 +1224,7 @@ impl default::Default for Config { message_attributes: PathMap::default(), enum_attributes: PathMap::default(), field_attributes: PathMap::default(), + custom_scalar: PathMap::default(), boxed: PathMap::default(), prost_types: true, strip_enum_prefix: true, diff --git a/prost-build/src/context.rs b/prost-build/src/context.rs index 7fecc8dfc..e169c99a1 100644 --- a/prost-build/src/context.rs +++ b/prost-build/src/context.rs @@ -179,6 +179,26 @@ impl<'a> Context<'a> { false } + pub fn get_custom_scalar_module_path( + &self, + field: &FieldDescriptorProto, + fq_message_name: &str, + ) -> Option { + if matches!(field.r#type(), Type::Message | Type::Group) { + return None; + } + self.config + .custom_scalar + .get_first_field(fq_message_name, field.name()) + .and_then(|(ty, interface)| { + if field.r#type() == *ty { + Some(interface.clone()) + } else { + None + } + }) + } + /// Returns `true` if this message can automatically derive Copy trait. pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool { assert_eq!(".", &fq_message_name[..1]); diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index c5f36e23c..9865e45b2 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -126,9 +126,9 @@ impl Field { /// Returns a statement which encodes the map field. pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { let tag = self.tag; - let key_mod = self.key_ty.module(); - let ke = quote!(#prost_path::encoding::#key_mod::encode); - let kl = quote!(#prost_path::encoding::#key_mod::encoded_len); + let key_mod = self.key_ty.encoding_module(prost_path); + let ke = quote!(#key_mod::encode); + let kl = quote!(#key_mod::encoded_len); let module = self.map_ty.module(); match &self.value_ty { ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { @@ -147,9 +147,9 @@ impl Field { } } ValueTy::Scalar(value_ty) => { - let val_mod = value_ty.module(); - let ve = quote!(#prost_path::encoding::#val_mod::encode); - let vl = quote!(#prost_path::encoding::#val_mod::encoded_len); + let val_mod = value_ty.encoding_module(prost_path); + let ve = quote!(#val_mod::encode); + let vl = quote!(#val_mod::encoded_len); quote! { #prost_path::encoding::#module::encode( #ke, @@ -179,8 +179,8 @@ impl Field { /// Returns an expression which evaluates to the result of merging a decoded key value pair /// into the map. pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { - let key_mod = self.key_ty.module(); - let km = quote!(#prost_path::encoding::#key_mod::merge); + let key_mod = self.key_ty.encoding_module(prost_path); + let km = quote!(#key_mod::merge); let module = self.map_ty.module(); match &self.value_ty { ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { @@ -197,9 +197,9 @@ impl Field { } } ValueTy::Scalar(value_ty) => { - let val_mod = value_ty.module(); - let vm = quote!(#prost_path::encoding::#val_mod::merge); - quote!(#prost_path::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx)) + let val_mod = value_ty.encoding_module(prost_path); + let vm = quote!(#val_mod::merge); + quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx)) } ValueTy::Message => quote! { #prost_path::encoding::#module::merge( @@ -216,8 +216,8 @@ impl Field { /// Returns an expression which evaluates to the encoded length of the map. pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { let tag = self.tag; - let key_mod = self.key_ty.module(); - let kl = quote!(#prost_path::encoding::#key_mod::encoded_len); + let key_mod = self.key_ty.encoding_module(prost_path); + let kl = quote!(#key_mod::encoded_len); let module = self.map_ty.module(); match &self.value_ty { ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { @@ -233,8 +233,8 @@ impl Field { } } ValueTy::Scalar(value_ty) => { - let val_mod = value_ty.module(); - let vl = quote!(#prost_path::encoding::#val_mod::encoded_len); + let val_mod = value_ty.encoding_module(prost_path); + let vl = quote!(#val_mod::encoded_len); quote!(#prost_path::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident)) } ValueTy::Message => quote! { @@ -256,7 +256,7 @@ impl Field { pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option { if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty { let key_ty = self.key_ty.rust_type(prost_path); - let key_ref_ty = self.key_ty.rust_ref_type(); + let key_ref_ty = self.key_ty.rust_ref_type(prost_path); let get = Ident::new(&format!("get_{ident}"), Span::call_site()); let insert = Ident::new(&format!("insert_{ident}"), Span::call_site()); @@ -366,7 +366,8 @@ fn key_ty_from_str(s: &str) -> Result { | scalar::Ty::Sfixed32 | scalar::Ty::Sfixed64 | scalar::Ty::Bool - | scalar::Ty::String => Ok(ty), + | scalar::Ty::String + | scalar::Ty::CustomScalar(_) => Ok(ty), _ => bail!("invalid map key type: {s}"), } } diff --git a/prost-derive/src/field/mod.rs b/prost-derive/src/field/mod.rs index 69af0e014..5f6c0dfc8 100644 --- a/prost-derive/src/field/mod.rs +++ b/prost-derive/src/field/mod.rs @@ -168,7 +168,7 @@ impl Field { pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option { match *self { - Field::Scalar(ref scalar) => scalar.methods(ident), + Field::Scalar(ref scalar) => scalar.methods(prost_path, ident), Field::Map(ref map) => map.methods(prost_path, ident), _ => None, } diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 25596cbc0..bcb637ed8 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -84,6 +84,7 @@ impl Field { Kind::Packed } (Some(Label::Repeated), _, false) => Kind::Repeated, + // TODO support packable custom scalar ? }; Ok(Some(Field { ty, kind, tag })) @@ -106,20 +107,28 @@ impl Field { } pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { - let module = self.ty.module(); + let module = self.ty.encoding_module(prost_path); let encode_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode), Kind::Repeated => quote!(encode_repeated), Kind::Packed => quote!(encode_packed), }; - let encode_fn = quote!(#prost_path::encoding::#module::#encode_fn); + let encode_fn = quote!(#module::#encode_fn); let tag = self.tag; match self.kind { Kind::Plain(ref default) => { - let default = default.typed(); + let default_check = match self.ty { + Ty::CustomScalar(ref path) => { + quote!(!<#path as #prost_path::encoding::CustomScalarInterface>::is_default(&#ident)) + } + _ => { + let default = default.typed(); + quote!(#ident != #default) + } + }; quote! { - if #ident != #default { + if #default_check { #encode_fn(#tag, &#ident, buf); } } @@ -138,12 +147,12 @@ impl Field { /// Returns an expression which evaluates to the result of merging a decoded /// scalar value into the field. pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { - let module = self.ty.module(); + let module = self.ty.encoding_module(prost_path); let merge_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge), Kind::Repeated | Kind::Packed => quote!(merge_repeated), }; - let merge_fn = quote!(#prost_path::encoding::#module::#merge_fn); + let merge_fn = quote!(#module::#merge_fn); match self.kind { Kind::Plain(..) | Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { @@ -160,20 +169,28 @@ impl Field { /// Returns an expression which evaluates to the encoded length of the field. pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream { - let module = self.ty.module(); + let module = self.ty.encoding_module(prost_path); let encoded_len_fn = match self.kind { Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len), Kind::Repeated => quote!(encoded_len_repeated), Kind::Packed => quote!(encoded_len_packed), }; - let encoded_len_fn = quote!(#prost_path::encoding::#module::#encoded_len_fn); + let encoded_len_fn = quote!(#module::#encoded_len_fn); let tag = self.tag; match self.kind { Kind::Plain(ref default) => { - let default = default.typed(); + let default_check = match self.ty { + Ty::CustomScalar(ref path) => { + quote!(!<#path as #prost_path::encoding::CustomScalarInterface>::is_default(&#ident)) + } + _ => { + let default = default.typed(); + quote!(#ident != #default) + } + }; quote! { - if #ident != #default { + if #default_check { #encoded_len_fn(#tag, &#ident) } else { 0 @@ -269,7 +286,7 @@ impl Field { } /// Returns methods to embed in the message. - pub fn methods(&self, ident: &TokenStream) -> Option { + pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option { let mut ident_str = ident.to_string(); if ident_str.starts_with("r#") { ident_str = ident_str.split_off(2); @@ -350,27 +367,41 @@ impl Field { } }) } else if let Kind::Optional(ref default) = self.kind { - let ty = self.ty.rust_ref_type(); - - let match_some = if self.ty.is_numeric() { - quote!(::core::option::Option::Some(val) => val,) + if let Ty::CustomScalar(ref path) = self.ty { + let get_doc = format!( + "Returns the value of `{0}`, or the default value if `{0}` is unset.", + ident_str, + ); + + Some(quote! { + #[doc=#get_doc] + pub fn #get<'x>(&'x self) -> <#path as #prost_path::encoding::CustomScalarInterface>::RefType<'x> { + <#path as #prost_path::encoding::CustomScalarInterface>::get(&self.#ident) + } + }) } else { - quote!(::core::option::Option::Some(ref val) => &val[..],) - }; - - let get_doc = format!( - "Returns the value of `{ident_str}`, or the default value if `{ident_str}` is unset." - ); - - Some(quote! { - #[doc=#get_doc] - pub fn #get(&self) -> #ty { - match self.#ident { - #match_some - ::core::option::Option::None => #default, + let ty = self.ty.rust_ref_type(prost_path); + + let match_some = if self.ty.is_numeric() { + quote!(::core::option::Option::Some(val) => val,) + } else { + quote!(::core::option::Option::Some(ref val) => &val[..],) + }; + + let get_doc = format!( + "Returns the value of `{ident_str}`, or the default value if `{ident_str}` is unset." + ); + + Some(quote! { + #[doc=#get_doc] + pub fn #get(&self) -> #ty { + match self.#ident { + #match_some + ::core::option::Option::None => #default, + } } - } - }) + }) + } } else { None } @@ -396,6 +427,7 @@ pub enum Ty { String, Bytes(BytesTy), Enumeration(Path), + CustomScalar(Path), } #[derive(Clone, Debug, PartialEq, Eq)] @@ -460,6 +492,9 @@ impl Ty { Meta::List(ref meta_list) if meta_list.path.is_ident("enumeration") => { Ty::Enumeration(meta_list.parse_args::()?) } + Meta::List(ref meta_list) if meta_list.path.is_ident("custom_scalar") => { + Ty::CustomScalar(meta_list.parse_args::()?) + } _ => return Ok(None), }; Ok(Some(ty)) @@ -467,6 +502,7 @@ impl Ty { pub fn from_str(s: &str) -> Result { let enumeration_len = "enumeration".len(); + let custom_scalar_len = "custom_scalar".len(); let error = Err(anyhow!("invalid type: {s}")); let ty = match s.trim() { "float" => Ty::Float, @@ -497,6 +533,19 @@ impl Ty { Ty::Enumeration(parse_str::(s[1..s.len() - 1].trim())?) } + s if s.len() > custom_scalar_len && &s[..custom_scalar_len] == "custom_scalar" => { + let s = &s[custom_scalar_len..].trim(); + match s.chars().next() { + Some('(') => (), + _ => return error, + } + match s.chars().next_back() { + Some(')') => (), + _ => return error, + } + + Ty::CustomScalar(parse_str::(s[1..s.len() - 1].trim())?) + } _ => return error, }; Ok(ty) @@ -521,6 +570,8 @@ impl Ty { Ty::String => "string", Ty::Bytes(..) => "bytes", Ty::Enumeration(..) => "enum", + // should not be used + Ty::CustomScalar(..) => "custom_scalar", } } @@ -529,12 +580,15 @@ impl Ty { match self { Ty::String => quote!(#prost_path::alloc::string::String), Ty::Bytes(ty) => ty.rust_type(prost_path), - _ => self.rust_ref_type(), + Ty::CustomScalar(ref path) => { + quote!(<#path as #prost_path::CustomScalarInterface>::Type) + } + _ => self.rust_ref_type(prost_path), } } // TODO: rename to 'ref_type' - pub fn rust_ref_type(&self) -> TokenStream { + pub fn rust_ref_type(&self, prost_path: &Path) -> TokenStream { match *self { Ty::Double => quote!(f64), Ty::Float => quote!(f32), @@ -552,19 +606,31 @@ impl Ty { Ty::String => quote!(&str), Ty::Bytes(..) => quote!(&[u8]), Ty::Enumeration(..) => quote!(i32), + Ty::CustomScalar(ref path) => { + quote!(&<#path as #prost_path::CustomScalarInterface>::Type) + } } } - pub fn module(&self) -> Ident { + pub fn encoding_module(&self, prost_path: &Path) -> TokenStream { match *self { - Ty::Enumeration(..) => Ident::new("int32", Span::call_site()), - _ => Ident::new(self.as_str(), Span::call_site()), + Ty::Enumeration(..) => { + let module = Ident::new("int32", Span::call_site()); + quote!(#prost_path::encoding::#module) + } + Ty::CustomScalar(ref path) => { + quote!(<#path as #prost_path::CustomScalarInterface>) + } + _ => { + let module = Ident::new(self.as_str(), Span::call_site()); + quote!(#prost_path::encoding::#module) + } } } /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). pub fn is_numeric(&self) -> bool { - !matches!(self, Ty::String | Ty::Bytes(..)) + !matches!(self, Ty::String | Ty::Bytes(..) | Ty::CustomScalar(..)) } } @@ -609,6 +675,7 @@ pub enum DefaultValue { Bytes(Vec), Enumeration(TokenStream), Path(Path), + CustomScalar, } impl DefaultValue { @@ -772,6 +839,7 @@ impl DefaultValue { Ty::String => DefaultValue::String(String::new()), Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), + Ty::CustomScalar(_) => DefaultValue::CustomScalar, } } @@ -819,6 +887,9 @@ impl ToTokens for DefaultValue { } DefaultValue::Enumeration(ref value) => value.to_tokens(tokens), DefaultValue::Path(ref value) => value.to_tokens(tokens), + DefaultValue::CustomScalar => { + tokens.append_all(quote!(::core::default::Default::default())) + } } } } diff --git a/prost/src/encoding.rs b/prost/src/encoding.rs index 1794b0bfb..40d9fe951 100644 --- a/prost/src/encoding.rs +++ b/prost/src/encoding.rs @@ -1145,6 +1145,45 @@ pub mod btree_map { map!(BTreeMap); } +pub trait CustomScalarInterface { + type Type: Default + Clone + PartialEq + Eq + core::hash::Hash; + type RefType<'x>; + + fn encoded_len(tag: u32, value: &Self::Type) -> usize; + fn encode(tag: u32, value: &Self::Type, buf: &mut impl BufMut); + + fn encoded_len_repeated(tag: u32, values: &[Self::Type]) -> usize { + values.iter().map(|v| Self::encoded_len(tag, v)).sum() + } + fn encode_repeated(tag: u32, values: &[Self::Type], buf: &mut impl BufMut) { + for value in values { + Self::encode(tag, value, buf); + } + } + + fn merge( + wire_type: WireType, + value: &mut Self::Type, + buf: &mut impl Buf, + _ctx: DecodeContext, + ) -> Result<(), DecodeError>; + fn merge_repeated( + wire_type: WireType, + values: &mut Vec, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let mut value = Self::Type::default(); + Self::merge(wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + } + + fn is_default(value: &Self::Type) -> bool; + fn get<'x>(value: &'x Option) -> Self::RefType<'x>; +} + #[cfg(test)] mod test { #[cfg(not(feature = "std"))] diff --git a/prost/src/lib.rs b/prost/src/lib.rs index 1d6be53f6..394e1543d 100644 --- a/prost/src/lib.rs +++ b/prost/src/lib.rs @@ -17,6 +17,8 @@ mod types; #[doc(hidden)] pub mod encoding; +pub use encoding::CustomScalarInterface; + pub use crate::encoding::length_delimiter::{ decode_length_delimiter, encode_length_delimiter, length_delimiter_len, }; diff --git a/tests-2015/Cargo.toml b/tests-2015/Cargo.toml index 8a622dd27..c02c52aa7 100644 --- a/tests-2015/Cargo.toml +++ b/tests-2015/Cargo.toml @@ -31,3 +31,4 @@ tempfile = "3" cfg-if = "1" env_logger = { version = "0.11", default-features = false } prost-build = { path = "../prost-build" } +prost-types = { path = "../prost-types" } diff --git a/tests-2018/Cargo.toml b/tests-2018/Cargo.toml index 360aec4e8..60bd1b7c7 100644 --- a/tests-2018/Cargo.toml +++ b/tests-2018/Cargo.toml @@ -33,3 +33,4 @@ protobuf = { path = "../protobuf" } cfg-if = "1" env_logger = { version = "0.11", default-features = false } prost-build = { path = "../prost-build" } +prost-types = { path = "../prost-types" } diff --git a/tests-2024/Cargo.toml b/tests-2024/Cargo.toml index 15839ac1d..daa74e269 100644 --- a/tests-2024/Cargo.toml +++ b/tests-2024/Cargo.toml @@ -33,3 +33,4 @@ protobuf = { path = "../protobuf" } cfg-if = "1" env_logger = { version = "0.11", default-features = false } prost-build = { path = "../prost-build" } +prost-types = { path = "../prost-types" } diff --git a/tests-no-std/Cargo.toml b/tests-no-std/Cargo.toml index 1fbe1d468..0a8fcb6a3 100644 --- a/tests-no-std/Cargo.toml +++ b/tests-no-std/Cargo.toml @@ -35,3 +35,4 @@ protobuf = { path = "../protobuf" } cfg-if = "1" env_logger = { version = "0.11", default-features = false } prost-build = { path = "../prost-build" } +prost-types = { path = "../prost-types" } diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 066162eaf..1d0fbc405 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -27,3 +27,4 @@ protobuf = { path = "../protobuf" } cfg-if = "1" env_logger = { version = "0.11", default-features = false } prost-build = { path = "../prost-build" } +prost-types = { path = "../prost-types" } diff --git a/tests/build.rs b/tests/build.rs index c4bdb68bc..7a591676c 100644 --- a/tests/build.rs +++ b/tests/build.rs @@ -177,6 +177,33 @@ fn main() { .compile_protos(&[src.join("boxed_field.proto")], includes) .unwrap(); + prost_build::Config::new() + .btree_map([ + ".custom_scalar.Msg.e", + ".custom_scalar.Msg.f", + ".custom_scalar.Msg.g", + ]) + .custom_scalar( + prost_types::field_descriptor_proto::Type::String, + "crate::custom_scalar::MyStringInterface", + [ + ".custom_scalar.Msg.a", + ".custom_scalar.Msg.b", + ".custom_scalar.Msg.c", + ".custom_scalar.Msg.d", + ".custom_scalar.Msg.e", + ".custom_scalar.Msg.f.value", + ".custom_scalar.Msg.g.key", + ], + ) + .custom_scalar( + prost_types::field_descriptor_proto::Type::Bytes, + "crate::custom_scalar::MyVecInterface", + [".custom_scalar.Msg.h"], + ) + .compile_protos(&[src.join("custom_scalar.proto")], includes) + .unwrap(); + prost_build::Config::new() .compile_protos(&[src.join("oneof_name_conflict.proto")], includes) .unwrap(); diff --git a/tests/src/custom_scalar.proto b/tests/src/custom_scalar.proto new file mode 100644 index 000000000..89aac6f21 --- /dev/null +++ b/tests/src/custom_scalar.proto @@ -0,0 +1,17 @@ + +syntax = "proto3"; + +package custom_scalar; + +message Msg { + string a = 1; + repeated string b = 2; + optional string c = 3; + oneof my_enum { + string d = 4; + }; + map e = 5; + map f = 6; + map g = 7; + bytes h = 8; +} diff --git a/tests/src/custom_scalar.rs b/tests/src/custom_scalar.rs new file mode 100644 index 000000000..ab3dba842 --- /dev/null +++ b/tests/src/custom_scalar.rs @@ -0,0 +1,110 @@ +include!(concat!(env!("OUT_DIR"), "/custom_scalar.rs")); + +use alloc::string::String; +use alloc::vec; +use alloc::vec::Vec; +use prost::Message; + +#[test] +fn test_custom_scalar() { + let msg = Msg { + a: MyString("a".into()), + b: vec![MyString("b".into())], + c: Some(MyString("c".into())), + my_enum: Some(msg::MyEnum::D(MyString("e".into()))), + e: [(MyString("f".into()), MyString("f".into()))] + .iter() + .cloned() + .collect(), + f: [("f".into(), MyString("f".into()))] + .iter() + .cloned() + .collect(), + g: [(MyString("f".into()), "f".into())] + .iter() + .cloned() + .collect(), + h: MyVec(vec![1, 2]), + }; + + let data = msg.encode_to_vec(); + let decoded_msg = Msg::decode(data.as_slice()).unwrap(); + + assert_eq!(msg, decoded_msg); +} + +#[derive(Clone, Default, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] +pub struct MyString(pub String); + +pub struct MyStringInterface; + +impl prost::CustomScalarInterface for MyStringInterface { + type Type = MyString; + type RefType<'x> = &'x str; + + fn encoded_len(tag: u32, value: &Self::Type) -> usize { + ::prost::encoding::string::encoded_len(tag, &value.0) + } + + fn encode(tag: u32, value: &Self::Type, buf: &mut impl prost::bytes::BufMut) { + ::prost::encoding::string::encode(tag, &value.0, buf); + } + + fn merge( + wire_type: prost::encoding::WireType, + value: &mut Self::Type, + buf: &mut impl prost::bytes::Buf, + ctx: prost::encoding::DecodeContext, + ) -> Result<(), prost::DecodeError> { + ::prost::encoding::string::merge(wire_type, &mut value.0, buf, ctx) + } + + fn is_default(value: &Self::Type) -> bool { + value.0.is_empty() + } + + fn get<'x>(value: &'x Option) -> Self::RefType<'x> { + match value { + Some(value) => value.0.as_str(), + None => "", + } + } +} + +#[derive(Clone, Default, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] +pub struct MyVec(pub Vec); + +struct MyVecInterface; + +impl prost::CustomScalarInterface for MyVecInterface { + type Type = MyVec; + type RefType<'x> = &'x [u8]; + + fn encoded_len(tag: u32, value: &Self::Type) -> usize { + ::prost::encoding::bytes::encoded_len(tag, &value.0) + } + + fn encode(tag: u32, value: &Self::Type, buf: &mut impl prost::bytes::BufMut) { + ::prost::encoding::bytes::encode(tag, &value.0, buf); + } + + fn merge( + wire_type: prost::encoding::WireType, + value: &mut Self::Type, + buf: &mut impl prost::bytes::Buf, + ctx: prost::encoding::DecodeContext, + ) -> Result<(), prost::DecodeError> { + ::prost::encoding::bytes::merge(wire_type, &mut value.0, buf, ctx) + } + + fn is_default(value: &Self::Type) -> bool { + value.0.is_empty() + } + + fn get<'x>(value: &'x Option) -> Self::RefType<'x> { + match value { + Some(value) => value.0.as_slice(), + None => &[], + } + } +} diff --git a/tests/src/lib.rs b/tests/src/lib.rs index ab2b1057b..8155625c2 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -52,6 +52,9 @@ mod type_names; #[cfg(test)] mod boxed_field; +#[cfg(test)] +mod custom_scalar; + #[cfg(test)] mod custom_debug;