diff --git a/wincode/src/schema/impls.rs b/wincode/src/schema/impls.rs index fd432284..c993ceaa 100644 --- a/wincode/src/schema/impls.rs +++ b/wincode/src/schema/impls.rs @@ -44,6 +44,7 @@ use { schema::{size_of_elem_iter, write_elem_iter_prealloc_check}, }, alloc::{ + borrow::Cow, boxed::Box, collections::{BTreeMap, BTreeSet, BinaryHeap, LinkedList, VecDeque}, rc::Rc, @@ -1013,6 +1014,22 @@ unsafe impl<'de, C: Config> SchemaRead<'de, C> for &'de str { } } +#[cfg(feature = "alloc")] +unsafe impl<'a, C: Config> SchemaWrite for Cow<'a, str> { + type Src = Self; + + #[inline] + fn size_of(src: &Self::Src) -> WriteResult { + >::size_of(src.as_ref()) + } + + #[inline] + fn write(writer: impl Writer, src: &Self::Src) -> WriteResult<()> { + C::LengthEncoding::prealloc_check::(src.len())?; + >::write(writer, src.as_ref()) + } +} + #[cfg(feature = "alloc")] unsafe impl<'de, C: Config> SchemaRead<'de, C> for String { type Dst = String; @@ -1034,6 +1051,34 @@ unsafe impl<'de, C: Config> SchemaRead<'de, C> for String { } } +#[cfg(feature = "alloc")] +unsafe impl<'de, C: Config> SchemaRead<'de, C> for Cow<'de, str> { + type Dst = Cow<'de, str>; + + #[inline] + fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit) -> ReadResult<()> { + let len = C::LengthEncoding::read_prealloc_check::(reader.by_ref())?; + match reader.take_borrowed(len) { + Ok(bytes) => { + let string = core::str::from_utf8(bytes).map_err(invalid_utf8_encoding)?; + dst.write(Cow::Borrowed(string)); + Ok(()) + } + Err(crate::io::ReadError::UnsupportedBorrow(_)) => { + let mut bytes = Vec::with_capacity(len); + reader.copy_into_slice(bytes.spare_capacity_mut())?; + // SAFETY: `copy_into_slice` ensures we fill the entire `bytes.spare_capacity_mut()` slice. + unsafe { bytes.set_len(len) }; + let string = String::from_utf8(bytes) + .map_err(|err| invalid_utf8_encoding(err.utf8_error()))?; + dst.write(Cow::Owned(string)); + Ok(()) + } + Err(err) => Err(err.into()), + } + } +} + /// Implement `SchemaWrite` and `SchemaRead` for types that may be iterated over sequentially. /// /// Generally this should only be used on types for which we cannot provide an optimized implementation, diff --git a/wincode/src/schema/mod.rs b/wincode/src/schema/mod.rs index 75539483..e6737509 100644 --- a/wincode/src/schema/mod.rs +++ b/wincode/src/schema/mod.rs @@ -511,7 +511,7 @@ mod tests { config::{self, Config, Configuration, DefaultConfig}, containers, deserialize, deserialize_mut, error::{self, invalid_tag_encoding}, - io::{Reader, Writer}, + io::{BorrowKind, Reader, Writer}, len::{BincodeLen, FixIntLen}, pod_wrapper, proptest_config::proptest_cfg, @@ -522,6 +522,7 @@ mod tests { proptest::prelude::*, std::{ alloc::Layout, + borrow::Cow, cell::Cell, collections::{BinaryHeap, HashMap, HashSet, VecDeque}, hash::{BuildHasher, Hasher}, @@ -3041,6 +3042,96 @@ mod tests { prop_assert_eq!(&str, &schema_deserialized); } + #[test] + fn test_cow_str(str in any::()) { + let cow = Cow::::Owned(str.clone()); + let bincode_serialized = bincode::serialize(&cow).unwrap(); + let schema_serialized = serialize(&cow).unwrap(); + prop_assert_eq!(&bincode_serialized, &schema_serialized); + + let bincode_deserialized: Cow<'_, str> = bincode::deserialize(&bincode_serialized).unwrap(); + let schema_deserialized: Cow<'_, str> = deserialize(&schema_serialized).unwrap(); + prop_assert_eq!(&cow, &bincode_deserialized); + prop_assert_eq!(&cow, &schema_deserialized); + } + + #[test] + fn test_cow_str_prefers_borrowed_when_supported(str in any::()) { + let cow = Cow::::Owned(str.clone()); + let serialized = serialize(&cow).unwrap(); + let deserialized: Cow<'_, str> = deserialize(&serialized).unwrap(); + + prop_assert!(matches!(deserialized, Cow::Borrowed(_))); + prop_assert_eq!(deserialized.as_ref(), str); + } + + #[test] + fn test_cow_str_falls_back_to_owned_when_borrow_is_unsupported( + str in any::(), + ) { + struct NoBorrowReader<'a> { + inner: &'a [u8], + } + + impl<'a> Reader<'a> for NoBorrowReader<'a> { + fn peek_array(&mut self) -> crate::io::ReadResult<&[u8; N]> { + let Some(src) = self.inner.get(..N) else { + return Err(crate::io::read_size_limit(N)); + }; + // SAFETY: `src` is exactly `N` bytes. + Ok(unsafe { &*src.as_ptr().cast::<[u8; N]>() }) + } + + fn copy_into_slice( + &mut self, + dst: &mut [MaybeUninit], + ) -> crate::io::ReadResult<()> { + let len = dst.len(); + let Some(src) = self.inner.get(..len) else { + return Err(crate::io::read_size_limit(len)); + }; + // SAFETY: `dst` points to `len` writable bytes and does not overlap `src`. + unsafe { + ptr::copy_nonoverlapping( + src.as_ptr(), + dst.as_mut_ptr().cast::(), + len, + ) + }; + self.inner = &self.inner[len..]; + Ok(()) + } + + fn take_array(&mut self) -> crate::io::ReadResult<[u8; N]> { + let Some((src, rest)) = self.inner.split_first_chunk() else { + return Err(crate::io::read_size_limit(N)); + }; + self.inner = rest; + Ok(*src) + } + + fn take_borrowed(&mut self, _len: usize) -> crate::io::ReadResult<&'a [u8]> { + Err(crate::io::ReadError::UnsupportedBorrow(BorrowKind::Backing)) + } + + unsafe fn consume_unchecked(&mut self, amt: usize) { + self.inner = unsafe { self.inner.get_unchecked(amt..) }; + } + + fn consume(&mut self, amt: usize) { + self.inner = self.inner.get(amt..).unwrap_or_default(); + } + } + + let cow = Cow::::Owned(str.clone()); + let serialized = serialize(&cow).unwrap(); + let reader = NoBorrowReader { inner: &serialized }; + let deserialized = as SchemaRead>::get(reader).unwrap(); + + prop_assert!(matches!(deserialized, Cow::Owned(_))); + prop_assert_eq!(deserialized.as_ref(), str); + } + #[test] fn test_struct_zero_copy(val in any::()) { let bincode_serialized = bincode::serialize(&val).unwrap();