Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
#default
}
}

impl #impl_generics #prost_path::encoding::IsDefault for #ident #ty_generics #where_clause {}
};
let expanded = if skip_debug {
expanded
Expand Down
156 changes: 132 additions & 24 deletions prost/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,42 @@ impl DecodeContext {
}
}

/// Trait for checking if a value equals its default.
///
/// This trait exists because IEEE 754 floats have `-0.0 == 0.0`, but they are
/// distinct bit patterns. For map encoding, we need to preserve `-0.0` values
/// rather than treating them as the default `0.0`.
pub trait IsDefault: PartialEq {
#[inline]
fn is_default(&self, default: &Self) -> bool {
self == default
}
}

impl IsDefault for f32 {
#[inline]
fn is_default(&self, default: &Self) -> bool {
self.to_bits() == default.to_bits()
}
}

impl IsDefault for f64 {
#[inline]
fn is_default(&self, default: &Self) -> bool {
self.to_bits() == default.to_bits()
}
}

impl<T: PartialEq> IsDefault for Option<T> {}
impl<T: PartialEq> IsDefault for Vec<T> {}
impl IsDefault for bool {}
impl IsDefault for i32 {}
impl IsDefault for i64 {}
impl IsDefault for u32 {}
impl IsDefault for u64 {}
impl IsDefault for String {}
impl IsDefault for Bytes {}

pub const MIN_TAG: u32 = 1;
pub const MAX_TAG: u32 = (1 << 29) - 1;

Expand Down Expand Up @@ -966,7 +1002,7 @@ macro_rules! map {
buf: &mut B,
) where
K: Default + Eq + Hash + Ord,
V: Default + PartialEq,
V: Default + IsDefault,
B: BufMut,
KE: Fn(u32, &K, &mut B),
KL: Fn(u32, &K) -> usize,
Expand Down Expand Up @@ -1012,7 +1048,7 @@ macro_rules! map {
) -> usize
where
K: Default + Eq + Hash + Ord,
V: Default + PartialEq,
V: Default + IsDefault,
KL: Fn(u32, &K) -> usize,
VL: Fn(u32, &V) -> usize,
{
Expand All @@ -1034,7 +1070,7 @@ macro_rules! map {
buf: &mut B,
) where
K: Default + Eq + Hash + Ord,
V: PartialEq,
V: IsDefault,
B: BufMut,
KE: Fn(u32, &K, &mut B),
KL: Fn(u32, &K) -> usize,
Expand All @@ -1043,7 +1079,7 @@ macro_rules! map {
{
for (key, val) in values.iter() {
let skip_key = key == &K::default();
let skip_val = val == val_default;
let skip_val = val.is_default(val_default);

let len = (if skip_key { 0 } else { key_encoded_len(1, key) })
+ (if skip_val { 0 } else { val_encoded_len(2, val) });
Expand Down Expand Up @@ -1111,7 +1147,7 @@ macro_rules! map {
) -> usize
where
K: Default + Eq + Hash + Ord,
V: PartialEq,
V: IsDefault,
KL: Fn(u32, &K) -> usize,
VL: Fn(u32, &V) -> usize,
{
Expand All @@ -1123,7 +1159,7 @@ macro_rules! map {
0
} else {
key_encoded_len(1, key)
}) + (if val == val_default {
}) + (if val.is_default(val_default) {
0
} else {
val_encoded_len(2, val)
Expand Down Expand Up @@ -1243,6 +1279,45 @@ mod test {
Ok(())
}

// Generic function used for testing that a value round trips when put into a map
pub fn check_entry_roundtrip<T, M, Insert, Get, E, Merge, EQ, L>(
t: T,
insert: Insert,
get: Get,
encode: E,
merge: Merge,
equal: EQ,
len: L,
) -> TestCaseResult
where
T: Debug + Default + PartialEq,
M: Sized + Default,
Insert: Fn(&mut M, i32, T) -> Option<T>,
Get: for<'a> Fn(&'a M, &i32) -> Option<&'a T>,
E: FnOnce(u32, &M, &mut BytesMut),
Merge: Fn(&mut M, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
EQ: Fn(&T, &T) -> bool,
L: Fn(u32, &M) -> usize,
{
let tag = 1u32;
let mut map = M::default();
insert(&mut map, 1, t);
let mut buf = BytesMut::with_capacity(len(tag, &map));
encode(tag, &map, &mut buf);

let mut decoded = M::default();
let mut bytes = buf.freeze();

while bytes.has_remaining() {
let _ = decode_key(&mut bytes).unwrap();
merge(&mut decoded, &mut bytes, DecodeContext::default()).unwrap();
}
let original = get(&map, &1).unwrap();
let modified = get(&decoded, &1).unwrap();
prop_assert!(equal(original, modified));
Ok(())
}

pub fn check_collection_type<T, B, E, M, L>(
value: T,
tag: u32,
Expand Down Expand Up @@ -1331,12 +1406,16 @@ mod test {
#[cfg(feature = "std")]
macro_rules! map_tests {
(keys: $keys:tt,
vals: $vals:tt) => {
vals: $vals:tt,
floats: $floats:tt,
) => {
mod hash_map {
map_tests!(@private HashMap, hash_map, $keys, $vals);
map_tests!(@private HashMap, hash_map, $floats);
}
mod btree_map {
map_tests!(@private BTreeMap, btree_map, $keys, $vals);
map_tests!(@private BTreeMap, btree_map, $floats);
}
};

Expand All @@ -1358,12 +1437,40 @@ mod test {
)*
};

(@private $map_type:ident,
$mod_name:ident,
[$(($val_ty:ident, $val_proto:ident, $func:ident)),*]) => {
$(
mod $func {
use std::collections::$map_type;
use proptest::prelude::*;
use crate::encoding::*;
use crate::encoding::test::check_entry_roundtrip;
proptest! {
#[test]
fn $func(v in any::<$val_ty>()) {
check_entry_roundtrip(
v,
$map_type::insert,
$map_type::get,
|tag, m, buf| $mod_name::encode(crate::encoding::int32::encode, crate::encoding::int32::encoded_len, $val_proto::encode, $val_proto::encoded_len, tag, m, buf),
|map, buf, context| $mod_name::merge(crate::encoding::int32::merge, $val_proto::merge, map, buf, context),
|lhs, rhs| $val_ty::to_bits(*lhs) == $val_ty::to_bits(*rhs),
|tag, map| $mod_name::encoded_len(crate::encoding::int32::encoded_len, $val_proto::encoded_len, tag, map)
)?;
}
}
}
)*
};

(@private $map_type:ident,
$mod_name:ident,
($key_ty:ty, $key_proto:ident),
[$(($val_ty:ty, $val_proto:ident)),*]) => {
[$(($val_ty:ty, $val_proto:ident, $another_name:ident)),*]) => {
$(
proptest! {

#[test]
fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) {
check_collection_type(values, tag, WireType::LengthDelimited,
Expand Down Expand Up @@ -1412,22 +1519,23 @@ mod test {
(String, string)
],
vals: [
(f32, float),
(f64, double),
(i32, int32),
(i64, int64),
(u32, uint32),
(u64, uint64),
(i32, sint32),
(i64, sint64),
(u32, fixed32),
(u64, fixed64),
(i32, sfixed32),
(i64, sfixed64),
(bool, bool),
(String, string),
(Vec<u8>, bytes)
]);
(f32, float, float2),
(f64, double, double2),
(i32, int32, int322),
(i64, int64, int642),
(u32, uint32, uint322),
(u64, uint64, uint642),
(i32, sint32, sint322),
(i64, sint64, sint642),
(u32, fixed32, fixed322),
(u64, fixed64, fixed642),
(i32, sfixed32, sfixed322),
(i64, sfixed64, sfixed642),
(bool, bool, bool2),
(String, string, string2),
(Vec<u8>, bytes, bytes2)
],
floats: [(f32, float, test_float_roundtrip), (f64, double, test_double_roundtrip)],);

#[test]
/// `decode_varint` accepts a `Buf`, which can be multiple concatenated buffers.
Expand Down
4 changes: 2 additions & 2 deletions tests/src/generic_derive.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub trait CustomType: prost::Message + Default + core::fmt::Debug {}
pub trait CustomType: prost::Message + Default + core::fmt::Debug + PartialEq {}

impl CustomType for u64 {}

Expand All @@ -11,7 +11,7 @@ enum GenericEnum<A: CustomType> {
Number(u64),
}

#[derive(Clone, prost::Message)]
#[derive(Clone, prost::Message, PartialEq)]
struct GenericMessage<A: CustomType> {
#[prost(message, tag = "1")]
data: Option<A>,
Expand Down
Loading