Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions wincode/src/schema/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use {
schema::{size_of_elem_iter, write_elem_iter_prealloc_check},
},
alloc::{
borrow::Cow,
boxed::Box,
collections::{BTreeMap, BTreeSet, BinaryHeap, LinkedList, VecDeque},
rc::Rc,
Expand Down Expand Up @@ -1013,6 +1014,22 @@ unsafe impl<'de, C: Config> SchemaRead<'de, C> for &'de str {
}
}

#[cfg(feature = "alloc")]
unsafe impl<'a, C: Config> SchemaWrite<C> for Cow<'a, str> {
type Src = Self;

#[inline]
fn size_of(src: &Self::Src) -> WriteResult<usize> {
<str as SchemaWrite<C>>::size_of(src.as_ref())
}

#[inline]
fn write(writer: impl Writer, src: &Self::Src) -> WriteResult<()> {
C::LengthEncoding::prealloc_check::<u8>(src.len())?;
<str as SchemaWrite<C>>::write(writer, src.as_ref())
}
}

#[cfg(feature = "alloc")]
unsafe impl<'de, C: Config> SchemaRead<'de, C> for String {
type Dst = String;
Expand All @@ -1034,6 +1051,34 @@ unsafe impl<'de, C: Config> SchemaRead<'de, C> for String {
}
}

#[cfg(feature = "alloc")]
unsafe impl<'de, C: Config> SchemaRead<'de, C> for Cow<'de, str> {
type Dst = Cow<'de, str>;

#[inline]
fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
let len = C::LengthEncoding::read_prealloc_check::<u8>(reader.by_ref())?;
match reader.take_borrowed(len) {
Ok(bytes) => {
let string = core::str::from_utf8(bytes).map_err(invalid_utf8_encoding)?;
dst.write(Cow::Borrowed(string));
Ok(())
}
Err(crate::io::ReadError::UnsupportedBorrow(_)) => {
let mut bytes = Vec::with_capacity(len);
reader.copy_into_slice(bytes.spare_capacity_mut())?;
// SAFETY: `copy_into_slice` ensures we fill the entire `bytes.spare_capacity_mut()` slice.
unsafe { bytes.set_len(len) };
let string = String::from_utf8(bytes)
.map_err(|err| invalid_utf8_encoding(err.utf8_error()))?;
dst.write(Cow::Owned(string));
Ok(())
}
Err(err) => Err(err.into()),
}
}
}

/// Implement `SchemaWrite` and `SchemaRead` for types that may be iterated over sequentially.
///
/// Generally this should only be used on types for which we cannot provide an optimized implementation,
Expand Down
93 changes: 92 additions & 1 deletion wincode/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ mod tests {
config::{self, Config, Configuration, DefaultConfig},
containers, deserialize, deserialize_mut,
error::{self, invalid_tag_encoding},
io::{Reader, Writer},
io::{BorrowKind, Reader, Writer},
len::{BincodeLen, FixIntLen},
pod_wrapper,
proptest_config::proptest_cfg,
Expand All @@ -522,6 +522,7 @@ mod tests {
proptest::prelude::*,
std::{
alloc::Layout,
borrow::Cow,
cell::Cell,
collections::{BinaryHeap, HashMap, HashSet, VecDeque},
hash::{BuildHasher, Hasher},
Expand Down Expand Up @@ -3041,6 +3042,96 @@ mod tests {
prop_assert_eq!(&str, &schema_deserialized);
}

#[test]
fn test_cow_str(str in any::<String>()) {
let cow = Cow::<str>::Owned(str.clone());
let bincode_serialized = bincode::serialize(&cow).unwrap();
let schema_serialized = serialize(&cow).unwrap();
prop_assert_eq!(&bincode_serialized, &schema_serialized);

let bincode_deserialized: Cow<'_, str> = bincode::deserialize(&bincode_serialized).unwrap();
let schema_deserialized: Cow<'_, str> = deserialize(&schema_serialized).unwrap();
prop_assert_eq!(&cow, &bincode_deserialized);
prop_assert_eq!(&cow, &schema_deserialized);
}

#[test]
fn test_cow_str_prefers_borrowed_when_supported(str in any::<String>()) {
let cow = Cow::<str>::Owned(str.clone());
let serialized = serialize(&cow).unwrap();
let deserialized: Cow<'_, str> = deserialize(&serialized).unwrap();

prop_assert!(matches!(deserialized, Cow::Borrowed(_)));
prop_assert_eq!(deserialized.as_ref(), str);
}

#[test]
fn test_cow_str_falls_back_to_owned_when_borrow_is_unsupported(
str in any::<String>(),
) {
struct NoBorrowReader<'a> {
inner: &'a [u8],
}

impl<'a> Reader<'a> for NoBorrowReader<'a> {
fn peek_array<const N: usize>(&mut self) -> crate::io::ReadResult<&[u8; N]> {
let Some(src) = self.inner.get(..N) else {
return Err(crate::io::read_size_limit(N));
};
// SAFETY: `src` is exactly `N` bytes.
Ok(unsafe { &*src.as_ptr().cast::<[u8; N]>() })
}

fn copy_into_slice(
&mut self,
dst: &mut [MaybeUninit<u8>],
) -> crate::io::ReadResult<()> {
let len = dst.len();
let Some(src) = self.inner.get(..len) else {
return Err(crate::io::read_size_limit(len));
};
// SAFETY: `dst` points to `len` writable bytes and does not overlap `src`.
unsafe {
ptr::copy_nonoverlapping(
src.as_ptr(),
dst.as_mut_ptr().cast::<u8>(),
len,
)
};
self.inner = &self.inner[len..];
Ok(())
}

fn take_array<const N: usize>(&mut self) -> crate::io::ReadResult<[u8; N]> {
let Some((src, rest)) = self.inner.split_first_chunk() else {
return Err(crate::io::read_size_limit(N));
};
self.inner = rest;
Ok(*src)
}

fn take_borrowed(&mut self, _len: usize) -> crate::io::ReadResult<&'a [u8]> {
Err(crate::io::ReadError::UnsupportedBorrow(BorrowKind::Backing))
}

unsafe fn consume_unchecked(&mut self, amt: usize) {
self.inner = unsafe { self.inner.get_unchecked(amt..) };
}

fn consume(&mut self, amt: usize) {
self.inner = self.inner.get(amt..).unwrap_or_default();
}
}

let cow = Cow::<str>::Owned(str.clone());
let serialized = serialize(&cow).unwrap();
let reader = NoBorrowReader { inner: &serialized };
let deserialized = <Cow<'_, str> as SchemaRead<DefaultConfig>>::get(reader).unwrap();

prop_assert!(matches!(deserialized, Cow::Owned(_)));
prop_assert_eq!(deserialized.as_ref(), str);
}

#[test]
fn test_struct_zero_copy(val in any::<StructZeroCopy>()) {
let bincode_serialized = bincode::serialize(&val).unwrap();
Expand Down