Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions wincode/src/schema/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use {
TypeMeta,
},
core::{
cell::{Cell, RefCell},
marker::PhantomData,
mem::{self, transmute, MaybeUninit},
net::{IpAddr, Ipv4Addr, Ipv6Addr},
Expand Down Expand Up @@ -1823,3 +1824,114 @@ impl_nonzero!(
NonZeroI128 => i128,
NonZeroIsize => isize,
);

unsafe impl<C: ConfigCore, T> SchemaWrite<C> for Cell<T>
where
T: SchemaWrite<C>,
T::Src: Copy,
{
type Src = Cell<T::Src>;

const TYPE_META: TypeMeta = const {
match T::TYPE_META {
TypeMeta::Static { size, .. } => TypeMeta::Static {
size,
zero_copy: false,
},
TypeMeta::Dynamic => TypeMeta::Dynamic,
}
};

#[inline]
fn size_of(src: &Self::Src) -> WriteResult<usize> {
T::size_of(&src.get())
}

#[inline]
fn write(writer: impl Writer, value: &Self::Src) -> WriteResult<()> {
let val = &value.get();
T::write(writer, val)?;
Ok(())
}
}

unsafe impl<'de, C: ConfigCore, T> SchemaRead<'de, C> for Cell<T>
where
T: SchemaRead<'de, C>,
{
type Dst = Cell<T::Dst>;

const TYPE_META: TypeMeta = const {
match T::TYPE_META {
TypeMeta::Static { size, .. } => TypeMeta::Static {
size,
zero_copy: false,
},
TypeMeta::Dynamic => TypeMeta::Dynamic,
}
};

#[inline(always)]
fn read(reader: impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
let value = T::get(reader)?;
dst.write(Cell::new(value));
Ok(())
}
}

unsafe impl<T, C: ConfigCore> SchemaWrite<C> for RefCell<T>
where
T: SchemaWrite<C>,
{
type Src = RefCell<T::Src>;

const TYPE_META: TypeMeta = const {
match T::TYPE_META {
TypeMeta::Static { size, .. } => TypeMeta::Static {
size,
zero_copy: false,
},
TypeMeta::Dynamic => TypeMeta::Dynamic,
}
};

#[inline]
fn size_of(src: &Self::Src) -> WriteResult<usize> {
let borrowed = src
.try_borrow()
.map_err(|_| crate::error::WriteError::Custom("RefCell already borrowed mutably"))?;
T::size_of(&*borrowed)
}

#[inline]
fn write(writer: impl Writer, src: &Self::Src) -> WriteResult<()> {
let borrowed = src
.try_borrow()
.map_err(|_| crate::error::WriteError::Custom("RefCell already borrowed mutably"))?;
T::write(writer, &*borrowed)
}
}

unsafe impl<'de, T, C: ConfigCore> SchemaRead<'de, C> for RefCell<T>
where
T: SchemaRead<'de, C>,
{
type Dst = RefCell<T::Dst>;

const TYPE_META: TypeMeta = const {
match T::TYPE_META {
TypeMeta::Static { size, .. } => TypeMeta::Static {
size,
zero_copy: false,
},
TypeMeta::Dynamic => TypeMeta::Dynamic,
}
};

#[inline]
fn read(reader: impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
let val = T::get(reader)?;
dst.write(RefCell::new(val));
Ok(())
}
}
162 changes: 161 additions & 1 deletion wincode/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ mod tests {
proptest::prelude::*,
std::{
alloc::Layout,
cell::Cell,
cell::{Cell, RefCell},
collections::{BinaryHeap, VecDeque},
mem::MaybeUninit,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
Expand Down Expand Up @@ -3534,6 +3534,166 @@ mod tests {
);
}

#[test]
fn test_cell_basic() {
proptest!(proptest_cfg(), |(value: (u32, f32, i64, f64, bool, u8))| {
let value = (
Cell::new(value.0),
Cell::new(value.1),
Cell::new(value.2),
Cell::new(value.3),
Cell::new(value.4),
Cell::new(value.5),
);

let serialized = serialize(&value).unwrap();
let bincode_serialized = bincode::serialize(&value).unwrap();
prop_assert_eq!(&serialized, &bincode_serialized);

type T = (
Cell<u32>,
Cell<f32>,
Cell<i64>,
Cell<f64>,
Cell<bool>,
Cell<u8>,
);

let deserialized: T = deserialize(&serialized).unwrap();
let bincode_deserialized: T = bincode::deserialize(&bincode_serialized).unwrap();

prop_assert_eq!(deserialized, bincode_deserialized);
});
}

#[test]
fn test_cell_char_bincode_equivalence() {
proptest!(proptest_cfg(), |(cell: Cell<char>)| {
let bincode_serialized = bincode::serialize(&cell).unwrap();
let schema_serialized = serialize(&cell).unwrap();
prop_assert_eq!(&bincode_serialized, &schema_serialized);

let bincode_deserialized: Cell<char> = bincode::deserialize(&bincode_serialized).unwrap();
let schema_deserialized: Cell<char> = deserialize(&schema_serialized).unwrap();

prop_assert_eq!(bincode_deserialized, schema_deserialized);
});
}

#[test]
fn test_cell_arrays_bincode_equivalence() {
proptest!(proptest_cfg(), |(cell: Cell<[u32; 4]>)| {

let bincode_serialized = bincode::serialize(&cell).unwrap();
let schema_serialized = serialize(&cell).unwrap();
prop_assert_eq!(&bincode_serialized, &schema_serialized);

let bincode_deserialized: Cell<[u32; 4]> = bincode::deserialize(&bincode_serialized).unwrap();
let schema_deserialized: Cell<[u32; 4]> = deserialize(&schema_serialized).unwrap();
prop_assert_eq!(schema_deserialized, bincode_deserialized);
});
}

#[test]
fn test_refcell_basic() {
proptest!(proptest_cfg(), |(value: (RefCell<u32>, RefCell<f32>, RefCell<i64>, RefCell<f64>, RefCell<bool>, RefCell<u8>))| {
let serialized = serialize(&value).unwrap();
let bincode_serialized = bincode::serialize(&value).unwrap();
prop_assert_eq!(&serialized, &bincode_serialized);

type T = (RefCell<u32>, RefCell<f32>, RefCell<i64>, RefCell<f64>, RefCell<bool>, RefCell<u8>);
let deserialized: T = deserialize(&serialized).unwrap();
let bincode_deserialized: T = bincode::deserialize(&bincode_serialized).unwrap();
prop_assert_eq!(deserialized, bincode_deserialized);
});
}

#[test]
fn test_refcell_nested() {
use std::cell::RefCell;

type Nested = RefCell<Vec<u64>>;

proptest!(proptest_cfg(), |(vec in proptest::collection::vec(any::<u64>(), 0..=5))| {
let value = RefCell::new(vec);
let serialized = serialize(&value).unwrap();
let bincode_serialized = bincode::serialize(&value).unwrap();
prop_assert_eq!(&serialized, &bincode_serialized);

let deserialized: Nested = deserialize(&serialized).unwrap();
let bincode_deserialized: Nested = bincode::deserialize(&bincode_serialized).unwrap();
prop_assert_eq!(&*deserialized.borrow(), &*bincode_deserialized.borrow());
});
}

#[test]
fn test_refcell_with_struct() {
use std::cell::RefCell;

#[derive(
SchemaWrite,
SchemaRead,
Debug,
PartialEq,
Eq,
serde::Serialize,
serde::Deserialize,
Clone,
proptest_derive::Arbitrary,
)]
#[wincode(internal)]
struct SimpleData {
id: u32,
count: u64,
}

type SimpleRefCell = RefCell<SimpleData>;

proptest!(proptest_cfg(), |(data: SimpleData)| {
let value = RefCell::new(data);
let serialized = serialize(&value).unwrap();
let bincode_serialized = bincode::serialize(&value).unwrap();
prop_assert_eq!(&serialized, &bincode_serialized);

let deserialized: SimpleRefCell = deserialize(&serialized).unwrap();
let bincode_deserialized: SimpleRefCell = bincode::deserialize(&bincode_serialized).unwrap();
prop_assert_eq!(&*deserialized.borrow(), &*bincode_deserialized.borrow());
});
}

#[test]
fn test_refcell_type_meta_dynamic() {
assert!(matches!(
<RefCell<String> as SchemaRead<DefaultConfig>>::TYPE_META,
TypeMeta::Dynamic
));
proptest!(proptest_cfg(), |(value: u64)| {
let value = RefCell::new(value);

let serialized = serialize(&value).unwrap();
let bincode_serialized = bincode::serialize(&value).unwrap();
prop_assert_eq!(&serialized, &serialized);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
prop_assert_eq!(&serialized, &serialized);
prop_assert_eq!(&serialized, &bincode_serialized);


let deserialized: RefCell<u64> = deserialize(&serialized).unwrap();
let bincode_deserialized: RefCell<u64> = bincode::deserialize(&bincode_serialized).unwrap();

let deser = deserialized.borrow();
let bincode = bincode_deserialized.borrow();
prop_assert_eq!(&*deser, &*bincode);
});
}

#[test]
fn test_refcell_borrow_error() {
let refcell = RefCell::new(42u32);

// Borrow mutably to cause serialization to fail
let _mut_borrow = refcell.borrow_mut();

let result = serialize(&refcell);
assert!(result.is_err());
}

#[test]
fn test_byte_order_configuration() {
let c = Configuration::default().with_big_endian();
Expand Down