diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 08f1152c5..f595648b4 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -207,6 +207,8 @@ fn try_message(input: TokenStream) -> Result { #default } } + + impl #impl_generics #prost_path::encoding::IsDefault for #ident #ty_generics #where_clause {} }; let expanded = if skip_debug { expanded diff --git a/prost/src/encoding.rs b/prost/src/encoding.rs index 1794b0bfb..92564d8e6 100644 --- a/prost/src/encoding.rs +++ b/prost/src/encoding.rs @@ -97,6 +97,42 @@ impl DecodeContext { } } +/// Trait for checking if a value equals its default. +/// +/// This trait exists because IEEE 754 floats have `-0.0 == 0.0`, but they are +/// distinct bit patterns. For map encoding, we need to preserve `-0.0` values +/// rather than treating them as the default `0.0`. +pub trait IsDefault: PartialEq { + #[inline] + fn is_default(&self, default: &Self) -> bool { + self == default + } +} + +impl IsDefault for f32 { + #[inline] + fn is_default(&self, default: &Self) -> bool { + self.to_bits() == default.to_bits() + } +} + +impl IsDefault for f64 { + #[inline] + fn is_default(&self, default: &Self) -> bool { + self.to_bits() == default.to_bits() + } +} + +impl IsDefault for Option {} +impl IsDefault for Vec {} +impl IsDefault for bool {} +impl IsDefault for i32 {} +impl IsDefault for i64 {} +impl IsDefault for u32 {} +impl IsDefault for u64 {} +impl IsDefault for String {} +impl IsDefault for Bytes {} + pub const MIN_TAG: u32 = 1; pub const MAX_TAG: u32 = (1 << 29) - 1; @@ -966,7 +1002,7 @@ macro_rules! map { buf: &mut B, ) where K: Default + Eq + Hash + Ord, - V: Default + PartialEq, + V: Default + IsDefault, B: BufMut, KE: Fn(u32, &K, &mut B), KL: Fn(u32, &K) -> usize, @@ -1012,7 +1048,7 @@ macro_rules! map { ) -> usize where K: Default + Eq + Hash + Ord, - V: Default + PartialEq, + V: Default + IsDefault, KL: Fn(u32, &K) -> usize, VL: Fn(u32, &V) -> usize, { @@ -1034,7 +1070,7 @@ macro_rules! map { buf: &mut B, ) where K: Default + Eq + Hash + Ord, - V: PartialEq, + V: IsDefault, B: BufMut, KE: Fn(u32, &K, &mut B), KL: Fn(u32, &K) -> usize, @@ -1043,7 +1079,7 @@ macro_rules! map { { for (key, val) in values.iter() { let skip_key = key == &K::default(); - let skip_val = val == val_default; + let skip_val = val.is_default(val_default); let len = (if skip_key { 0 } else { key_encoded_len(1, key) }) + (if skip_val { 0 } else { val_encoded_len(2, val) }); @@ -1111,7 +1147,7 @@ macro_rules! map { ) -> usize where K: Default + Eq + Hash + Ord, - V: PartialEq, + V: IsDefault, KL: Fn(u32, &K) -> usize, VL: Fn(u32, &V) -> usize, { @@ -1123,7 +1159,7 @@ macro_rules! map { 0 } else { key_encoded_len(1, key) - }) + (if val == val_default { + }) + (if val.is_default(val_default) { 0 } else { val_encoded_len(2, val) @@ -1243,6 +1279,45 @@ mod test { Ok(()) } + // Generic function used for testing that a value round trips when put into a map + pub fn check_entry_roundtrip( + t: T, + insert: Insert, + get: Get, + encode: E, + merge: Merge, + equal: EQ, + len: L, + ) -> TestCaseResult + where + T: Debug + Default + PartialEq, + M: Sized + Default, + Insert: Fn(&mut M, i32, T) -> Option, + Get: for<'a> Fn(&'a M, &i32) -> Option<&'a T>, + E: FnOnce(u32, &M, &mut BytesMut), + Merge: Fn(&mut M, &mut Bytes, DecodeContext) -> Result<(), DecodeError>, + EQ: Fn(&T, &T) -> bool, + L: Fn(u32, &M) -> usize, + { + let tag = 1u32; + let mut map = M::default(); + insert(&mut map, 1, t); + let mut buf = BytesMut::with_capacity(len(tag, &map)); + encode(tag, &map, &mut buf); + + let mut decoded = M::default(); + let mut bytes = buf.freeze(); + + while bytes.has_remaining() { + let _ = decode_key(&mut bytes).unwrap(); + merge(&mut decoded, &mut bytes, DecodeContext::default()).unwrap(); + } + let original = get(&map, &1).unwrap(); + let modified = get(&decoded, &1).unwrap(); + prop_assert!(equal(original, modified)); + Ok(()) + } + pub fn check_collection_type( value: T, tag: u32, @@ -1331,12 +1406,16 @@ mod test { #[cfg(feature = "std")] macro_rules! map_tests { (keys: $keys:tt, - vals: $vals:tt) => { + vals: $vals:tt, + floats: $floats:tt, + ) => { mod hash_map { map_tests!(@private HashMap, hash_map, $keys, $vals); + map_tests!(@private HashMap, hash_map, $floats); } mod btree_map { map_tests!(@private BTreeMap, btree_map, $keys, $vals); + map_tests!(@private BTreeMap, btree_map, $floats); } }; @@ -1358,12 +1437,40 @@ mod test { )* }; + (@private $map_type:ident, + $mod_name:ident, + [$(($val_ty:ident, $val_proto:ident, $func:ident)),*]) => { + $( + mod $func { + use std::collections::$map_type; + use proptest::prelude::*; + use crate::encoding::*; + use crate::encoding::test::check_entry_roundtrip; + proptest! { + #[test] + fn $func(v in any::<$val_ty>()) { + check_entry_roundtrip( + v, + $map_type::insert, + $map_type::get, + |tag, m, buf| $mod_name::encode(crate::encoding::int32::encode, crate::encoding::int32::encoded_len, $val_proto::encode, $val_proto::encoded_len, tag, m, buf), + |map, buf, context| $mod_name::merge(crate::encoding::int32::merge, $val_proto::merge, map, buf, context), + |lhs, rhs| $val_ty::to_bits(*lhs) == $val_ty::to_bits(*rhs), + |tag, map| $mod_name::encoded_len(crate::encoding::int32::encoded_len, $val_proto::encoded_len, tag, map) + )?; + } + } + } + )* + }; + (@private $map_type:ident, $mod_name:ident, ($key_ty:ty, $key_proto:ident), - [$(($val_ty:ty, $val_proto:ident)),*]) => { + [$(($val_ty:ty, $val_proto:ident, $another_name:ident)),*]) => { $( proptest! { + #[test] fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) { check_collection_type(values, tag, WireType::LengthDelimited, @@ -1412,22 +1519,23 @@ mod test { (String, string) ], vals: [ - (f32, float), - (f64, double), - (i32, int32), - (i64, int64), - (u32, uint32), - (u64, uint64), - (i32, sint32), - (i64, sint64), - (u32, fixed32), - (u64, fixed64), - (i32, sfixed32), - (i64, sfixed64), - (bool, bool), - (String, string), - (Vec, bytes) - ]); + (f32, float, float2), + (f64, double, double2), + (i32, int32, int322), + (i64, int64, int642), + (u32, uint32, uint322), + (u64, uint64, uint642), + (i32, sint32, sint322), + (i64, sint64, sint642), + (u32, fixed32, fixed322), + (u64, fixed64, fixed642), + (i32, sfixed32, sfixed322), + (i64, sfixed64, sfixed642), + (bool, bool, bool2), + (String, string, string2), + (Vec, bytes, bytes2) + ], + floats: [(f32, float, test_float_roundtrip), (f64, double, test_double_roundtrip)],); #[test] /// `decode_varint` accepts a `Buf`, which can be multiple concatenated buffers. diff --git a/tests/src/generic_derive.rs b/tests/src/generic_derive.rs index 8b283a86a..d43978587 100644 --- a/tests/src/generic_derive.rs +++ b/tests/src/generic_derive.rs @@ -1,4 +1,4 @@ -pub trait CustomType: prost::Message + Default + core::fmt::Debug {} +pub trait CustomType: prost::Message + Default + core::fmt::Debug + PartialEq {} impl CustomType for u64 {} @@ -11,7 +11,7 @@ enum GenericEnum { Number(u64), } -#[derive(Clone, prost::Message)] +#[derive(Clone, prost::Message, PartialEq)] struct GenericMessage { #[prost(message, tag = "1")] data: Option,