From 458025976654c7eb78bc33ce7a18bae06a05a821 Mon Sep 17 00:00:00 2001 From: Martin Larralde Date: Sun, 3 Mar 2024 22:23:22 +0100 Subject: [PATCH] Implement blanket implementation for `Cow` wrapper --- Cargo.toml | 4 + README.md | 16 +-- src/derive.rs | 32 +++-- src/types/cow.rs | 109 ++++++++++++++++++ src/types/mod.rs | 4 + tests/derive_cow/mod.rs | 7 ++ tests/derive_cow/successes/assoc_function.rs | 21 ++++ .../successes/assoc_function_rettype.rs | 21 ++++ tests/derive_cow/successes/assoc_type.rs | 33 ++++++ tests/derive_cow/successes/receiver_ref.rs | 24 ++++ tests/derive_cow/successes/trait_generics.rs | 30 +++++ .../successes/where_clause_assoc_fn.rs | 36 ++++++ 12 files changed, 322 insertions(+), 15 deletions(-) create mode 100644 src/types/cow.rs create mode 100644 tests/derive_cow/mod.rs create mode 100644 tests/derive_cow/successes/assoc_function.rs create mode 100644 tests/derive_cow/successes/assoc_function_rettype.rs create mode 100644 tests/derive_cow/successes/assoc_type.rs create mode 100644 tests/derive_cow/successes/receiver_ref.rs create mode 100644 tests/derive_cow/successes/trait_generics.rs create mode 100644 tests/derive_cow/successes/where_clause_assoc_fn.rs diff --git a/Cargo.toml b/Cargo.toml index 1d474fe..bc68ba8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,3 +68,7 @@ harness = false name = "derive_ref" path = "tests/derive_ref/mod.rs" harness = false +[[test]] +name = "derive_cow" +path = "tests/derive_cow/mod.rs" +harness = false diff --git a/README.md b/README.md index 83a59a4..6c98cf1 100644 --- a/README.md +++ b/README.md @@ -48,13 +48,15 @@ provided the trait methods fit the constraints for that derive, such as only declaring methods with `&self` of `&mut self` as their receiver. The following derives are available: -| Derive | Impl block | `fn (&self)` | `fn (&mut self)` | `fn (self)` | -|--------|--------------------------------------------|--------------|------------------|-------------| -| Ref | `impl Trait for &T` | ✔️ | | | -| Rc | `impl Trait for Rc` | ✔️ | | | -| Arc | `impl Trait for Arc` | ✔️ | | | -| Mut | `impl Trait for &mut T` | ✔️ | ✔️ | | -| Box | `impl Trait for Box` | ✔️ | ✔️ | ✔️ | +| Derive | Impl block | `fn (&self)` | `fn (&mut self)` | `fn (self)` | +|--------|---------------------------------------------------------|--------------|------------------|-------------| +| Ref | `impl Trait for &T` | ✔️ | | | +| Rc | `impl Trait for Rc` | ✔️ | | | +| Arc | `impl Trait for Arc` | ✔️ | | | +| Mut | `impl Trait for &mut T` | ✔️ | ✔️ | | +| Box¹ | `impl Trait for Box` | ✔️ | ✔️ | | +| Box² | `impl Trait for Box` | ✔️ | ✔️ | ✔️ | +| Cow | `impl Trait for Cow<_, T>` | ✔️ | | | For instance, with our own version of `std::fmt::Write`, we can provide an implementation for `Box` and `&mut impl Write`: diff --git a/src/derive.rs b/src/derive.rs index 52ee6b1..f8630d6 100644 --- a/src/derive.rs +++ b/src/derive.rs @@ -1,4 +1,5 @@ use syn::parse_quote; +use syn::punctuated::Punctuated; use syn::spanned::Spanned; use crate::utils::deref_expr; @@ -24,6 +25,9 @@ pub trait WrapperType { /// The receivers allowed for this wrapper type. const RECEIVERS: &'static [Receiver]; + /// Additional types to add to the generic type bound. + const BOUNDS: &'static [&'static str] = &[]; + /// Wrap the given identifier into the wrapper type. fn wrap(ty: &syn::Ident) -> syn::Type; @@ -113,16 +117,28 @@ pub trait WrapperType { } // Add generic type for the type we are creating ourselves - if sized { - impl_generics.params.push(syn::GenericParam::Type( - parse_quote!(#generic_type: #trait_ident #trait_generic_names), - )); - } else { - impl_generics.params.push(syn::GenericParam::Type( - parse_quote!(#generic_type: #trait_ident #trait_generic_names + ?Sized), - )); + let span = generic_type.span(); + let mut bounds: Punctuated<_, _> = parse_quote!(#trait_ident #trait_generic_names); + if !sized { + bounds.push(parse_quote!(?Sized)); + } + for bound in Self::BOUNDS { + let bound_ident = syn::Ident::new(bound, span); + bounds.push(parse_quote!(#bound_ident)) } + // Add the type wrapper in the wrapper type for the generic. + impl_generics + .params + .push(syn::GenericParam::Type(syn::TypeParam { + attrs: Vec::new(), + ident: generic_type.clone(), + colon_token: Some(syn::Token![:](span)), + bounds, + eq_token: None, + default: None, + })); + Ok(parse_quote!( #[automatically_derived] impl #impl_generics #trait_ident #trait_generic_names for #wrapper_type #where_clause { diff --git a/src/types/cow.rs b/src/types/cow.rs new file mode 100644 index 0000000..c1173c3 --- /dev/null +++ b/src/types/cow.rs @@ -0,0 +1,109 @@ +use syn::parse_quote; + +use crate::derive::Receiver; +use crate::derive::WrapperType; + +struct CowType; + +impl WrapperType for CowType { + const NAME: &'static str = "Cow"; + const RECEIVERS: &'static [Receiver] = &[Receiver::Ref]; + const BOUNDS: &'static [&'static str] = &["ToOwned"]; + fn wrap(ty: &syn::Ident) -> syn::Type { + parse_quote!(std::borrow::Cow<'_, #ty>) + } +} + +pub fn derive(trait_: &syn::ItemTrait) -> syn::Result { + CowType::derive(trait_) +} + +#[cfg(test)] +mod tests { + mod derive { + + use syn::parse_quote; + + #[test] + fn empty() { + let trait_ = parse_quote!( + trait MyTrait {} + ); + let derived = super::super::derive(&trait_).unwrap(); + assert_eq!( + derived, + parse_quote!( + #[automatically_derived] + impl MyTrait for std::borrow::Cow<'_, MT> {} + ) + ); + } + + #[test] + fn receiver_ref() { + let trait_ = parse_quote!( + trait Trait { + fn my_method(&self); + } + ); + assert_eq!( + super::super::derive(&trait_).unwrap(), + parse_quote!( + #[automatically_derived] + impl Trait for std::borrow::Cow<'_, T> { + #[inline] + fn my_method(&self) { + (*(*self)).my_method() + } + } + ) + ); + } + + #[test] + fn receiver_mut() { + let trait_ = parse_quote!( + trait Trait { + fn my_method(&mut self); + } + ); + assert!(super::super::derive(&trait_).is_err()); + } + + #[test] + fn receiver_self() { + let trait_ = parse_quote!( + trait MyTrait { + fn my_method(self); + } + ); + assert!(super::super::derive(&trait_).is_err()); + } + + #[test] + fn receiver_arbitrary() { + let trait_ = parse_quote!( + trait Trait { + fn my_method(self: Box); + } + ); + assert!(super::super::derive(&trait_).is_err()); + } + + #[test] + fn generics() { + let trait_ = parse_quote!( + trait MyTrait {} + ); + let derived = super::super::derive(&trait_).unwrap(); + + assert_eq!( + derived, + parse_quote!( + #[automatically_derived] + impl + ?Sized + ToOwned> MyTrait for std::borrow::Cow<'_, MT> {} + ) + ); + } + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index d5c0bd5..cda44e9 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,5 +1,6 @@ mod arc; mod r#box; +mod cow; mod r#mut; mod rc; mod r#ref; @@ -9,6 +10,7 @@ mod r#ref; #[derive(Debug, PartialEq, Eq, Hash)] pub enum Type { Box, + Cow, Ref, Mut, Rc, @@ -19,6 +21,7 @@ impl Type { pub fn from_str(s: &str) -> Option { match s { "Box" => Some(Type::Box), + "Cow" => Some(Type::Cow), "Ref" => Some(Type::Ref), "Mut" => Some(Type::Mut), "Rc" => Some(Type::Rc), @@ -36,6 +39,7 @@ impl Type { pub fn defer_trait_methods(&self, trait_: &syn::ItemTrait) -> syn::Result { match self { Type::Box => self::r#box::derive(trait_), + Type::Cow => self::cow::derive(trait_), Type::Ref => self::r#ref::derive(trait_), Type::Mut => self::r#mut::derive(trait_), Type::Rc => self::rc::derive(trait_), diff --git a/tests/derive_cow/mod.rs b/tests/derive_cow/mod.rs new file mode 100644 index 0000000..985f5f4 --- /dev/null +++ b/tests/derive_cow/mod.rs @@ -0,0 +1,7 @@ +extern crate trybuild; + +fn main() { + let t = trybuild::TestCases::new(); + t.compile_fail(file!().replace("mod.rs", "fails/*.rs")); + t.pass(file!().replace("mod.rs", "successes/*.rs")); +} diff --git a/tests/derive_cow/successes/assoc_function.rs b/tests/derive_cow/successes/assoc_function.rs new file mode 100644 index 0000000..c8e3ef1 --- /dev/null +++ b/tests/derive_cow/successes/assoc_function.rs @@ -0,0 +1,21 @@ +use blanket::blanket; +use impls::impls; + +use std::borrow::Cow; + +#[blanket(derive(Cow))] +pub trait StaticChecker { + fn check(); +} + +#[derive(Default, Clone)] +struct NoOpChecker; + +impl StaticChecker for NoOpChecker { + fn check() {} +} + +fn main() { + assert!(impls!( NoOpChecker: StaticChecker)); + assert!(impls!(Cow: StaticChecker)); +} diff --git a/tests/derive_cow/successes/assoc_function_rettype.rs b/tests/derive_cow/successes/assoc_function_rettype.rs new file mode 100644 index 0000000..668cf85 --- /dev/null +++ b/tests/derive_cow/successes/assoc_function_rettype.rs @@ -0,0 +1,21 @@ +use blanket::blanket; +use impls::impls; + +use std::borrow::Cow; + +#[blanket(derive(Cow))] +pub trait StaticChecker { + fn check() -> Result<(), String>; +} + +#[derive(Default, Clone)] +struct NoOpChecker; + +impl StaticChecker for NoOpChecker { + fn check() -> Result<(), String> { Ok(()) } +} + +fn main() { + assert!(impls!( NoOpChecker: StaticChecker)); + assert!(impls!(Cow: StaticChecker)); +} diff --git a/tests/derive_cow/successes/assoc_type.rs b/tests/derive_cow/successes/assoc_type.rs new file mode 100644 index 0000000..cd72e7d --- /dev/null +++ b/tests/derive_cow/successes/assoc_type.rs @@ -0,0 +1,33 @@ +use std::borrow::Cow; +use std::sync::Arc; +use std::sync::RwLock; + +use blanket::blanket; +use impls::impls; + +#[blanket(derive(Cow))] +pub trait Counter { + type Return: Clone; // <- verify this + fn increment(&self) -> Self::Return; +} + +#[derive(Default, Clone)] +struct AtomicCounter { + count: Arc>, +} + +impl Counter for AtomicCounter { + // Generate something like `type Return = ::Return;`. + type Return = u8; + fn increment(&self) -> u8 { + let mut guard = self.count.try_write().unwrap(); + let out = *guard; + *guard += 1; + out + } +} + +fn main() { + assert!(impls!(AtomicCounter: Counter)); + assert!(impls!(Cow: Counter)); +} diff --git a/tests/derive_cow/successes/receiver_ref.rs b/tests/derive_cow/successes/receiver_ref.rs new file mode 100644 index 0000000..25df1cf --- /dev/null +++ b/tests/derive_cow/successes/receiver_ref.rs @@ -0,0 +1,24 @@ +extern crate blanket; +extern crate impls; + +use std::borrow::Cow; + +use blanket::blanket; +use impls::impls; + +#[blanket(derive(Cow))] +pub trait Counter { + fn count(&self); +} + +#[derive(Default, Clone)] +struct AtomicCounter {} + +impl Counter for AtomicCounter { + fn count(&self) {} +} + +fn main() { + assert!(impls!(AtomicCounter: Counter)); + assert!(impls!(Cow: Counter)); +} diff --git a/tests/derive_cow/successes/trait_generics.rs b/tests/derive_cow/successes/trait_generics.rs new file mode 100644 index 0000000..19e22c3 --- /dev/null +++ b/tests/derive_cow/successes/trait_generics.rs @@ -0,0 +1,30 @@ +extern crate blanket; +extern crate impls; + +use std::borrow::Cow; + +use blanket::blanket; +use impls::impls; + +#[blanket(derive(Cow))] +pub trait AsRef2 { + fn as_ref2(&self) -> &T; +} + +#[derive(Default, Clone)] +struct Owner { + owned: T, +} + +impl AsRef2 for Owner { + fn as_ref2(&self) -> &T { + &self.owned + } +} + +fn main() { + assert!(impls!(Owner: AsRef2)); + assert!(impls!(Cow>: AsRef2)); + assert!(impls!(Owner: AsRef2)); + assert!(impls!(Cow>: AsRef2)); +} diff --git a/tests/derive_cow/successes/where_clause_assoc_fn.rs b/tests/derive_cow/successes/where_clause_assoc_fn.rs new file mode 100644 index 0000000..f5c5ffc --- /dev/null +++ b/tests/derive_cow/successes/where_clause_assoc_fn.rs @@ -0,0 +1,36 @@ +use std::borrow::Cow; +use std::sync::Arc; +use std::sync::RwLock; + +use blanket::blanket; +use impls::impls; + +#[blanket(derive(Cow))] +pub trait Counter +where + T: Clone, +{ + fn increment(&self, t: T); + + fn super_helpful_helper(&self, t: T) + { + self.increment(t.clone()) + } +} + +#[derive(Default, Clone)] +struct AtomicCounter { + count: Arc>, +} + +impl Counter for AtomicCounter { + fn increment(&self, value: u8) { + let mut guard = self.count.try_write().unwrap(); + *guard += value; + } +} + +fn main() { + assert!(impls!(AtomicCounter: Counter)); + assert!(impls!(Cow: Counter)); +}