diff --git a/src/derive.rs b/src/derive.rs new file mode 100644 index 0000000..52ee6b1 --- /dev/null +++ b/src/derive.rs @@ -0,0 +1,175 @@ +use syn::parse_quote; +use syn::spanned::Spanned; + +use crate::utils::deref_expr; +use crate::utils::generics_declaration_to_generics; +use crate::utils::signature_to_associated_function_call; +use crate::utils::signature_to_method_call; +use crate::utils::trait_to_generic_ident; + +/// The different receivers supported on a method. +#[derive(Debug, PartialEq)] +pub enum Receiver { + Arbitrary, + Ref, + Mut, + Owned, +} + +/// A marker trait for types wrapping a single other type. +pub trait WrapperType { + /// A short name for the type being wrapper. + const NAME: &'static str; + + /// The receivers allowed for this wrapper type. + const RECEIVERS: &'static [Receiver]; + + /// Wrap the given identifier into the wrapper type. + fn wrap(ty: &syn::Ident) -> syn::Type; + + /// Check that the given receiver is supported for the wrapper type. + fn check_receiver(r: &syn::Receiver) -> syn::Result<()> { + let receivers = Self::RECEIVERS; + let err = if r.colon_token.is_some() && !receivers.contains(&Receiver::Arbitrary) { + Some(format!( + "cannot derive `{}` for a trait declaring methods with arbitrary receiver types", + Self::NAME + )) + } else if r.mutability.is_some() && !receivers.contains(&Receiver::Mut) { + Some(format!( + "cannot derive `{}` for a trait declaring `&mut self` methods", + Self::NAME + )) + } else if r.reference.is_none() && !receivers.contains(&Receiver::Owned) { + Some(format!( + "cannot derive `{}` for a trait declaring `self` methods", + Self::NAME + )) + } else { + None + }; + if let Some(msg) = err { + Err(syn::Error::new(r.span(), msg)) + } else { + Ok(()) + } + } + + /// Generate the derived implementation for the given trait. + fn derive(trait_: &syn::ItemTrait) -> syn::Result { + // build an identifier for the generic type used for the implementation + let trait_ident = &trait_.ident; + let generic_type = trait_to_generic_ident(&trait_); + let wrapper_type = Self::wrap(&generic_type); + + // build the generics for the impl block: + // we use the same generics as the trait itself, plus + // a generic type that implements the trait for which we provide the + // blanket implementation + let trait_generics = &trait_.generics; + let where_clause = &trait_.generics.where_clause; + let mut impl_generics = trait_generics.clone(); + + // we must however remove the generic type bounds, to avoid repeating them + let mut trait_generic_names = trait_generics.clone(); + trait_generic_names.params = generics_declaration_to_generics(&trait_generics.params)?; + + // build the methods + let mut methods: Vec = Vec::new(); + let mut assoc_types: Vec = Vec::new(); + for item in trait_.items.iter() { + if let syn::TraitItem::Fn(ref m) = item { + methods.push(Self::derive_method( + m, + &trait_ident, + &generic_type, + &trait_generic_names, + )?) + } + + if let syn::TraitItem::Type(t) = item { + let t_ident = &t.ident; + let attrs = &t.attrs; + + let t_generics = &t.generics; + let where_clause = &t.generics.where_clause; + let mut t_generic_names = t_generics.clone(); + t_generic_names.params = generics_declaration_to_generics(&t_generics.params)?; + + let item = parse_quote!( #(#attrs)* type #t_ident #t_generics = <#generic_type as #trait_ident #trait_generic_names>::#t_ident #t_generic_names #where_clause ; ); + assoc_types.push(item); + } + } + + // check if any method has a `Self` receiver, which would mean we cannot + // relax the `Sized` trait requirement + let mut sized = false; + for item in trait_.items.iter() { + if let syn::TraitItem::Fn(ref m) = item { + if let Some(r) = m.sig.receiver() { + sized |= r.reference.is_none(); + } + } + } + + // Add generic type for the type we are creating ourselves + if sized { + impl_generics.params.push(syn::GenericParam::Type( + parse_quote!(#generic_type: #trait_ident #trait_generic_names), + )); + } else { + impl_generics.params.push(syn::GenericParam::Type( + parse_quote!(#generic_type: #trait_ident #trait_generic_names + ?Sized), + )); + } + + Ok(parse_quote!( + #[automatically_derived] + impl #impl_generics #trait_ident #trait_generic_names for #wrapper_type #where_clause { + #(#assoc_types)* + #(#methods)* + } + )) + } + + /// Generate the derived implementation for a single method of a trait. + fn derive_method( + m: &syn::TraitItemFn, + trait_ident: &syn::Ident, + generic_type: &syn::Ident, + trait_generic_names: &syn::Generics, + ) -> syn::Result { + let mut call: syn::Expr = if let Some(r) = m.sig.receiver() { + Self::check_receiver(r)?; + let mut call = signature_to_method_call(&m.sig)?; + if r.reference.is_some() { + call.receiver = Box::new(deref_expr(deref_expr(*call.receiver))); + } else { + call.receiver = Box::new(deref_expr(*call.receiver)); + } + call.into() + } else { + let call = signature_to_associated_function_call( + &m.sig, + &trait_ident, + &generic_type, + &trait_generic_names, + )?; + call.into() + }; + + if let Some(async_) = m.sig.asyncness { + let span = async_.span(); + call = syn::ExprAwait { + attrs: Vec::new(), + base: Box::new(call), + dot_token: syn::Token![.](span), + await_token: syn::Token![await](span), + } + .into(); + } + + let signature = &m.sig; + Ok(syn::parse_quote!(#[inline] #signature { #call })) + } +} diff --git a/src/derive/mod.rs b/src/derive/mod.rs deleted file mode 100644 index 884960d..0000000 --- a/src/derive/mod.rs +++ /dev/null @@ -1,52 +0,0 @@ -mod arc; -mod r#box; -mod r#mut; -mod rc; -mod r#ref; - -// --------------------------------------------------------------------------- - -#[derive(Debug, PartialEq, Eq, Hash)] -pub enum Derive { - Box, - Ref, - Mut, - Rc, - Arc, -} - -impl Derive { - pub fn from_str(s: &str) -> Option { - match s { - "Box" => Some(Derive::Box), - "Ref" => Some(Derive::Ref), - "Mut" => Some(Derive::Mut), - "Rc" => Some(Derive::Rc), - "Arc" => Some(Derive::Arc), - _ => None, - } - } - - pub fn from_path(p: &syn::Path) -> Option { - p.segments - .first() - .and_then(|s| Self::from_str(&s.ident.to_string())) - } - - pub fn defer_trait_methods(&self, trait_: &syn::ItemTrait) -> syn::Result { - match self { - Derive::Box => self::r#box::derive(trait_), - Derive::Ref => self::r#ref::derive(trait_), - Derive::Mut => self::r#mut::derive(trait_), - Derive::Rc => self::rc::derive(trait_), - Derive::Arc => self::arc::derive(trait_), - } - } -} - -// --------------------------------------------------------------------------- - -/// A marker trait for types wrapping a single other type. -trait WrapperType { - fn wrap(ty: &syn::Ident) -> syn::Type; -} diff --git a/src/items.rs b/src/items.rs deleted file mode 100644 index 17c71fc..0000000 --- a/src/items.rs +++ /dev/null @@ -1,139 +0,0 @@ -use syn::parse_quote; -use syn::spanned::Spanned; - -use crate::utils::deref_expr; -use crate::utils::generics_declaration_to_generics; -use crate::utils::signature_to_associated_function_call; -use crate::utils::signature_to_method_call; -use crate::utils::trait_to_generic_ident; - -/// Derive the delegate function for an `impl` block. -pub fn derive_impl_item_fn( - m: &syn::TraitItemFn, - trait_ident: &syn::Ident, - generic_type: &syn::Ident, - trait_generic_names: &syn::Generics, - check_receiver: F, -) -> syn::Result -where - F: Fn(&syn::Receiver) -> syn::Result<()>, -{ - let mut call: syn::Expr = if let Some(r) = m.sig.receiver() { - check_receiver(r)?; - let mut call = signature_to_method_call(&m.sig)?; - if r.reference.is_some() { - call.receiver = Box::new(deref_expr(deref_expr(*call.receiver))); - } else { - call.receiver = Box::new(deref_expr(*call.receiver)); - } - call.into() - } else { - let call = signature_to_associated_function_call( - &m.sig, - &trait_ident, - &generic_type, - &trait_generic_names, - )?; - call.into() - }; - - if let Some(async_) = m.sig.asyncness { - let span = async_.span(); - call = syn::ExprAwait { - attrs: Vec::new(), - base: Box::new(call), - dot_token: syn::Token![.](span), - await_token: syn::Token![await](span), - } - .into(); - } - - let signature = &m.sig; - Ok(syn::parse_quote!(#[inline] #signature { #call })) -} - -/// Derive the implementation for -pub fn derive_impl( - trait_: &syn::ItemTrait, - check_receiver: F, - generate_wrapper_type: G, -) -> syn::Result -where - F: Fn(&syn::Receiver) -> syn::Result<()>, - G: Fn(&syn::Ident) -> syn::Type, -{ - // build an identifier for the generic type used for the implementation - let trait_ident = &trait_.ident; - let generic_type = trait_to_generic_ident(&trait_); - let wrapper_type = generate_wrapper_type(&generic_type); - - // build the generics for the impl block: - // we use the same generics as the trait itself, plus - // a generic type that implements the trait for which we provide the - // blanket implementation - let trait_generics = &trait_.generics; - let where_clause = &trait_.generics.where_clause; - let mut impl_generics = trait_generics.clone(); - - // we must however remove the generic type bounds, to avoid repeating them - let mut trait_generic_names = trait_generics.clone(); - trait_generic_names.params = generics_declaration_to_generics(&trait_generics.params)?; - - // build the methods - let mut methods: Vec = Vec::new(); - let mut assoc_types: Vec = Vec::new(); - for item in trait_.items.iter() { - if let syn::TraitItem::Fn(ref m) = item { - methods.push(derive_impl_item_fn( - m, - &trait_ident, - &generic_type, - &trait_generic_names, - &check_receiver, - )?) - } - - if let syn::TraitItem::Type(t) = item { - let t_ident = &t.ident; - let attrs = &t.attrs; - - let t_generics = &t.generics; - let where_clause = &t.generics.where_clause; - let mut t_generic_names = t_generics.clone(); - t_generic_names.params = generics_declaration_to_generics(&t_generics.params)?; - - let item = parse_quote!( #(#attrs)* type #t_ident #t_generics = <#generic_type as #trait_ident #trait_generic_names>::#t_ident #t_generic_names #where_clause ; ); - assoc_types.push(item); - } - } - - // check if any method has a `Self` receiver, which would mean we cannot - // relax the `Sized` trait requirement - let mut sized = false; - for item in trait_.items.iter() { - if let syn::TraitItem::Fn(ref m) = item { - if let Some(r) = m.sig.receiver() { - sized |= r.reference.is_none(); - } - } - } - - // Add generic type for the type we are creating ourselves - if sized { - impl_generics.params.push(syn::GenericParam::Type( - parse_quote!(#generic_type: #trait_ident #trait_generic_names), - )); - } else { - impl_generics.params.push(syn::GenericParam::Type( - parse_quote!(#generic_type: #trait_ident #trait_generic_names + ?Sized), - )); - } - - Ok(parse_quote!( - #[automatically_derived] - impl #impl_generics #trait_ident #trait_generic_names for #wrapper_type #where_clause { - #(#assoc_types)* - #(#methods)* - } - )) -} diff --git a/src/lib.rs b/src/lib.rs index bfec6c3..046efd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,14 +14,14 @@ use syn::{parse_macro_input, punctuated::Punctuated, spanned::Spanned}; mod default; mod derive; -mod items; +mod types; mod utils; // --------------------------------------------------------------------------- struct Args { default: Option, - derives: HashSet, + derives: HashSet, } impl Args { @@ -35,7 +35,7 @@ impl Args { Punctuated::::parse_separated_nonempty, )?; for pair in types.into_pairs() { - if let Some(d) = derive::Derive::from_path(pair.value()) { + if let Some(d) = types::Type::from_path(pair.value()) { derives.insert(d); } else { return Err(syn::Error::new( diff --git a/src/derive/arc.rs b/src/types/arc.rs similarity index 91% rename from src/derive/arc.rs rename to src/types/arc.rs index fef0469..8c572a7 100644 --- a/src/derive/arc.rs +++ b/src/types/arc.rs @@ -1,38 +1,20 @@ use syn::parse_quote; -use syn::spanned::Spanned; -use super::WrapperType; -use crate::items::derive_impl; +use crate::derive::Receiver; +use crate::derive::WrapperType; -pub struct ArcType; +struct ArcType; impl WrapperType for ArcType { + const NAME: &'static str = "Arc"; + const RECEIVERS: &'static [Receiver] = &[Receiver::Ref]; fn wrap(ty: &syn::Ident) -> syn::Type { parse_quote!(std::sync::Arc<#ty>) } } pub fn derive(trait_: &syn::ItemTrait) -> syn::Result { - derive_impl( - trait_, - |r| { - let err = if r.colon_token.is_some() { - Some("cannot derive `Arc` for a trait declaring methods with arbitrary receiver types") - } else if r.mutability.is_some() { - Some("cannot derive `Arc` for a trait declaring `&mut self` methods") - } else if r.reference.is_none() { - Some("cannot derive `Arc` for a trait declaring `self` methods") - } else { - None - }; - if let Some(msg) = err { - Err(syn::Error::new(r.span(), msg)) - } else { - Ok(()) - } - }, - ArcType::wrap, - ) + ArcType::derive(trait_) } #[cfg(test)] diff --git a/src/derive/box.rs b/src/types/box.rs similarity index 94% rename from src/derive/box.rs rename to src/types/box.rs index 3768116..d34a987 100644 --- a/src/derive/box.rs +++ b/src/types/box.rs @@ -1,34 +1,20 @@ -use syn::{parse_quote, spanned::Spanned}; +use syn::parse_quote; -use crate::items::derive_impl; - -use super::WrapperType; +use crate::derive::Receiver; +use crate::derive::WrapperType; pub struct BoxType; impl WrapperType for BoxType { + const NAME: &'static str = "Box"; + const RECEIVERS: &'static [Receiver] = &[Receiver::Ref, Receiver::Mut, Receiver::Owned]; fn wrap(ty: &syn::Ident) -> syn::Type { parse_quote!(std::boxed::Box<#ty>) } } pub fn derive(trait_: &syn::ItemTrait) -> syn::Result { - derive_impl( - trait_, - |r| { - let err = if r.colon_token.is_some() { - Some("cannot derive `Box` for a trait declaring methods with arbitrary receiver types") - } else { - None - }; - if let Some(msg) = err { - Err(syn::Error::new(r.span(), msg)) - } else { - Ok(()) - } - }, - BoxType::wrap, - ) + BoxType::derive(trait_) } #[cfg(test)] diff --git a/src/types/mod.rs b/src/types/mod.rs new file mode 100644 index 0000000..d5c0bd5 --- /dev/null +++ b/src/types/mod.rs @@ -0,0 +1,45 @@ +mod arc; +mod r#box; +mod r#mut; +mod rc; +mod r#ref; + +// --------------------------------------------------------------------------- + +#[derive(Debug, PartialEq, Eq, Hash)] +pub enum Type { + Box, + Ref, + Mut, + Rc, + Arc, +} + +impl Type { + pub fn from_str(s: &str) -> Option { + match s { + "Box" => Some(Type::Box), + "Ref" => Some(Type::Ref), + "Mut" => Some(Type::Mut), + "Rc" => Some(Type::Rc), + "Arc" => Some(Type::Arc), + _ => None, + } + } + + pub fn from_path(p: &syn::Path) -> Option { + p.segments + .first() + .and_then(|s| Self::from_str(&s.ident.to_string())) + } + + pub fn defer_trait_methods(&self, trait_: &syn::ItemTrait) -> syn::Result { + match self { + Type::Box => self::r#box::derive(trait_), + Type::Ref => self::r#ref::derive(trait_), + Type::Mut => self::r#mut::derive(trait_), + Type::Rc => self::rc::derive(trait_), + Type::Arc => self::arc::derive(trait_), + } + } +} diff --git a/src/derive/mut.rs b/src/types/mut.rs similarity index 92% rename from src/derive/mut.rs rename to src/types/mut.rs index 5eceff0..49cadb4 100644 --- a/src/derive/mut.rs +++ b/src/types/mut.rs @@ -1,26 +1,20 @@ -use syn::{parse_quote, spanned::Spanned}; +use syn::parse_quote; -use crate::items::derive_impl; +use crate::derive::Receiver; +use crate::derive::WrapperType; + +struct MutType; + +impl WrapperType for MutType { + const NAME: &'static str = "Mut"; + const RECEIVERS: &'static [Receiver] = &[Receiver::Ref, Receiver::Mut]; + fn wrap(ty: &syn::Ident) -> syn::Type { + parse_quote!(&mut #ty) + } +} pub fn derive(trait_: &syn::ItemTrait) -> syn::Result { - derive_impl( - trait_, - |r| { - let err = if r.colon_token.is_some() { - Some("cannot derive `Mut` for a trait declaring methods with arbitrary receiver types") - } else if r.reference.is_none() { - Some("cannot derive `Mut` for a trait declaring `self` methods") - } else { - None - }; - if let Some(msg) = err { - Err(syn::Error::new(r.span(), msg)) - } else { - Ok(()) - } - }, - |generic_type| parse_quote!(&mut #generic_type), - ) + MutType::derive(trait_) } #[cfg(test)] diff --git a/src/derive/rc.rs b/src/types/rc.rs similarity index 91% rename from src/derive/rc.rs rename to src/types/rc.rs index bd11f59..b448db1 100644 --- a/src/derive/rc.rs +++ b/src/types/rc.rs @@ -1,28 +1,20 @@ -use syn::{parse_quote, spanned::Spanned}; +use syn::parse_quote; -use crate::items::derive_impl; +use crate::derive::Receiver; +use crate::derive::WrapperType; + +struct RcType; + +impl WrapperType for RcType { + const NAME: &'static str = "Rc"; + const RECEIVERS: &'static [Receiver] = &[Receiver::Ref]; + fn wrap(ty: &syn::Ident) -> syn::Type { + parse_quote!(std::rc::Rc<#ty>) + } +} pub fn derive(trait_: &syn::ItemTrait) -> syn::Result { - derive_impl( - trait_, - |r| { - let err = if r.colon_token.is_some() { - Some("cannot derive `Rc` for a trait declaring methods with arbitrary receiver types") - } else if r.mutability.is_some() { - Some("cannot derive `Rc` for a trait declaring `&mut self` methods") - } else if r.reference.is_none() { - Some("cannot derive `Rc` for a trait declaring `self` methods") - } else { - None - }; - if let Some(msg) = err { - Err(syn::Error::new(r.span(), msg)) - } else { - Ok(()) - } - }, - |generic_type| parse_quote!(std::rc::Rc<#generic_type>), - ) + RcType::derive(trait_) } #[cfg(test)] diff --git a/src/derive/ref.rs b/src/types/ref.rs similarity index 92% rename from src/derive/ref.rs rename to src/types/ref.rs index 1888e9b..3a73bea 100644 --- a/src/derive/ref.rs +++ b/src/types/ref.rs @@ -1,28 +1,20 @@ -use syn::{parse_quote, spanned::Spanned}; +use syn::parse_quote; -use crate::items::derive_impl; +use crate::derive::Receiver; +use crate::derive::WrapperType; + +struct RefType; + +impl WrapperType for RefType { + const NAME: &'static str = "Ref"; + const RECEIVERS: &'static [Receiver] = &[Receiver::Ref]; + fn wrap(ty: &syn::Ident) -> syn::Type { + parse_quote!(&#ty) + } +} pub fn derive(trait_: &syn::ItemTrait) -> syn::Result { - derive_impl( - trait_, - |r| { - let err = if r.colon_token.is_some() { - Some("cannot derive `Ref` for a trait declaring methods with arbitrary receiver types") - } else if r.mutability.is_some() { - Some("cannot derive `Ref` for a trait declaring `&mut self` methods") - } else if r.reference.is_none() { - Some("cannot derive `Ref` for a trait declaring `self` methods") - } else { - None - }; - if let Some(msg) = err { - Err(syn::Error::new(r.span(), msg)) - } else { - Ok(()) - } - }, - |generic_type| parse_quote!(&#generic_type), - ) + RefType::derive(trait_) } #[cfg(test)]