From 83701e5588e10fdf6968598540f934a9731c07bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Wed, 25 Oct 2023 19:03:22 +0200 Subject: [PATCH] feat: Support generic types in entry points --- .../generic_contract/src/contract.rs | 4 ++ sylvia-derive/src/interfaces.rs | 2 +- sylvia-derive/src/lib.rs | 5 +- sylvia-derive/src/message.rs | 49 ++++++++++++------- sylvia-derive/src/parser.rs | 45 ++++++++++++++--- 5 files changed, 78 insertions(+), 27 deletions(-) diff --git a/examples/contracts/generic_contract/src/contract.rs b/examples/contracts/generic_contract/src/contract.rs index 4fffc93e..eef976ea 100644 --- a/examples/contracts/generic_contract/src/contract.rs +++ b/examples/contracts/generic_contract/src/contract.rs @@ -6,6 +6,9 @@ use sylvia::types::{ }; use sylvia::{contract, schemars}; +#[cfg(not(feature = "library"))] +use sylvia::entry_points; + pub struct GenericContract( std::marker::PhantomData<( InstantiateParam, @@ -16,6 +19,7 @@ pub struct GenericContract, ); +#[cfg_attr(not(feature = "library"), entry_points(generics))] #[contract] #[messages(cw1 as Cw1: custom(msg))] #[messages(generic as Generic: custom(msg))] diff --git a/sylvia-derive/src/interfaces.rs b/sylvia-derive/src/interfaces.rs index 0e75edd1..fbfe9718 100644 --- a/sylvia-derive/src/interfaces.rs +++ b/sylvia-derive/src/interfaces.rs @@ -159,7 +159,7 @@ impl Interfaces { quote! {} }; - let type_name = msg_ty.as_accessor_name(); + let type_name = msg_ty.as_accessor_name(false); quote! { <#module ::sv::Api #generics as #sylvia ::types::InterfaceApi> :: #type_name :: response_schemas_impl() } diff --git a/sylvia-derive/src/lib.rs b/sylvia-derive/src/lib.rs index a002c044..a5c6ca8e 100644 --- a/sylvia-derive/src/lib.rs +++ b/sylvia-derive/src/lib.rs @@ -258,9 +258,10 @@ pub fn entry_points(attr: TokenStream, item: TokenStream) -> TokenStream { #[cfg(not(tarpaulin_include))] fn entry_points_impl(attr: TokenStream2, item: TokenStream2) -> TokenStream2 { - fn inner(_attr: TokenStream2, item: TokenStream2) -> syn::Result { + fn inner(attr: TokenStream2, item: TokenStream2) -> syn::Result { + let attrs: parser::EntryPointArgs = parse2(attr)?; let input: ItemImpl = parse2(item)?; - let expanded = EntryPoints::new(&input).emit(); + let expanded = EntryPoints::new(&input, attrs).emit(); Ok(quote! { #input diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index e3ae3834..fc161615 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -3,7 +3,7 @@ use crate::crate_module; use crate::interfaces::Interfaces; use crate::parser::{ parse_associated_custom_type, parse_struct_message, ContractErrorAttr, ContractMessageAttr, - Custom, MsgAttr, MsgType, OverrideEntryPoints, + Custom, EntryPointArgs, MsgAttr, MsgType, OverrideEntryPoints, }; use crate::strip_generics::StripGenerics; use crate::utils::{ @@ -16,11 +16,12 @@ use proc_macro_error::emit_error; use quote::{quote, ToTokens}; use syn::fold::Fold; use syn::parse::{Parse, Parser}; +use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::visit::Visit; use syn::{ - parse_quote, Attribute, GenericParam, Ident, ItemImpl, ItemTrait, Pat, PatType, Path, - ReturnType, Signature, TraitItem, Type, WhereClause, WherePredicate, + parse_quote, Attribute, GenericArgument, GenericParam, Ident, ItemImpl, ItemTrait, Pat, + PatType, Path, ReturnType, Signature, Token, TraitItem, Type, WhereClause, WherePredicate, }; /// Representation of single struct message @@ -747,7 +748,7 @@ impl<'a> MsgVariant<'a> { let bracketed_generics = emit_bracketed_generics(generics); let interface_enum = quote! { < #module sv::Api #bracketed_generics as #sylvia ::types::InterfaceApi> }; - let type_name = msg_ty.as_accessor_name(); + let type_name = msg_ty.as_accessor_name(false); let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); match msg_ty { @@ -790,7 +791,7 @@ impl<'a> MsgVariant<'a> { } = self; let params = fields.iter().map(|field| field.emit_method_field()); - let type_name = msg_ty.as_accessor_name(); + let type_name = msg_ty.as_accessor_name(false); let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); match msg_ty { @@ -1023,12 +1024,9 @@ where custom_query: &Type, name: &Type, error: &Type, + contract_generics: &Option>, ) -> TokenStream { - let Self { - used_generics, - msg_ty, - .. - } = self; + let Self { msg_ty, .. } = self; let sylvia = crate_module(); let resp_type = match msg_ty { @@ -1038,16 +1036,19 @@ where let params = msg_ty.emit_ctx_params(custom_query); let values = msg_ty.emit_ctx_values(); let ep_name = msg_ty.emit_ep_name(); - let msg_name = msg_ty.emit_msg_name(true); - let bracketed_generics = emit_bracketed_generics(used_generics); + let bracketed_generics = match &contract_generics { + Some(generics) => quote! { ::< #generics > }, + None => quote! {}, + }; + let associated_name = msg_ty.as_accessor_name(true); quote! { #[#sylvia ::cw_std::entry_point] pub fn #ep_name ( #params , - msg: sv:: #msg_name #bracketed_generics, + msg: < #name < #contract_generics > as #sylvia ::types::ContractApi> :: #associated_name, ) -> Result<#resp_type, #error> { - msg.dispatch(&#name ::new() , ( #values )).map_err(Into::into) + msg.dispatch(&#name #bracketed_generics ::new() , ( #values )).map_err(Into::into) } } } @@ -1608,10 +1609,11 @@ pub struct EntryPoints<'a> { override_entry_points: OverrideEntryPoints, generics: Vec<&'a GenericParam>, where_clause: &'a Option, + attrs: EntryPointArgs, } impl<'a> EntryPoints<'a> { - pub fn new(source: &'a ItemImpl) -> Self { + pub fn new(source: &'a ItemImpl, attrs: EntryPointArgs) -> Self { let sylvia = crate_module(); let name = StripGenerics.fold_type(*source.self_ty.clone()); let override_entry_points = OverrideEntryPoints::new(&source.attrs); @@ -1643,6 +1645,7 @@ impl<'a> EntryPoints<'a> { override_entry_points, generics, where_clause, + attrs, } } @@ -1655,6 +1658,7 @@ impl<'a> EntryPoints<'a> { override_entry_points, generics, where_clause, + attrs, } = self; let sylvia = crate_module(); @@ -1683,6 +1687,10 @@ impl<'a> EntryPoints<'a> { .iter() .map(|variant| variant.function_name.clone()) .next(); + let contract_generics = match &attrs.generics { + Some(generics) => quote! { ::< #generics > }, + None => quote! {}, + }; #[cfg(not(tarpaulin_include))] { @@ -1696,6 +1704,7 @@ impl<'a> EntryPoints<'a> { &custom_query, name, error, + &attrs.generics, ), }, ); @@ -1706,7 +1715,13 @@ impl<'a> EntryPoints<'a> { let migrate = if migrate_not_overridden && migrate_variants.get_only_variant().is_some() { - migrate_variants.emit_default_entry_point(&custom_msg, &custom_query, name, error) + migrate_variants.emit_default_entry_point( + &custom_msg, + &custom_query, + name, + error, + &attrs.generics, + ) } else { quote! {} }; @@ -1722,7 +1737,7 @@ impl<'a> EntryPoints<'a> { env: #sylvia ::cw_std::Env, msg: #sylvia ::cw_std::Reply, ) -> Result<#sylvia ::cw_std::Response < #custom_msg >, #error> { - #name ::new(). #reply((deps, env).into(), msg).map_err(Into::into) + #name #contract_generics ::new(). #reply((deps, env).into(), msg).map_err(Into::into) } }, _ => quote! {}, diff --git a/sylvia-derive/src/parser.rs b/sylvia-derive/src/parser.rs index 836ed6df..1e58d9f3 100644 --- a/sylvia-derive/src/parser.rs +++ b/sylvia-derive/src/parser.rs @@ -13,9 +13,10 @@ use syn::{ use crate::crate_module; use crate::strip_generics::StripGenerics; -/// Parser arguments for `contract` macro +/// Parsed arguments for `contract` macro pub struct ContractArgs { - /// Module name wrapping generated messages, by default no additional module is created + /// Module in which contract impl block is defined. + /// Used only while implementing `Interface` on `Contract`. pub module: Option, } @@ -46,6 +47,31 @@ impl Parse for ContractArgs { } } +/// Parsed arguments for `entry_points` macro +pub struct EntryPointArgs { + /// Types used in place of contracts generics. + pub generics: Option>, +} + +impl Parse for EntryPointArgs { + fn parse(input: ParseStream) -> Result { + if input.is_empty() { + return Ok(Self { generics: None }); + } + + let path: Path = input.parse()?; + + let generics = match path.segments.last() { + Some(segment) if segment.ident == "generics" => Some(extract_generics_from_path(&path)), + _ => return Err(Error::new(path.span(), "Expected `generics`")), + }; + + let _: Nothing = input.parse()?; + + Ok(Self { generics }) + } +} + /// Type of message to be generated #[derive(PartialEq, Eq, Debug, Clone, Copy)] pub enum MsgType { @@ -158,11 +184,16 @@ impl MsgType { } } - pub fn as_accessor_name(&self) -> Option { + pub fn as_accessor_name(&self, is_wrapper: bool) -> Option { match self { + MsgType::Exec if is_wrapper => Some(parse_quote! { ContractExec }), + MsgType::Query if is_wrapper => Some(parse_quote! { ContractQuery }), + MsgType::Instantiate => Some(parse_quote! { Instantiate }), MsgType::Exec => Some(parse_quote! { Exec }), MsgType::Query => Some(parse_quote! { Query }), - _ => None, + MsgType::Migrate => Some(parse_quote! { Migrate }), + MsgType::Sudo => Some(parse_quote! { Sudo }), + MsgType::Reply => Some(parse_quote! { Reply }), } } } @@ -291,7 +322,7 @@ fn interface_has_custom(content: ParseStream) -> Result { Ok(customs) } -fn extract_generics_from_path(module: &mut Path) -> Punctuated { +fn extract_generics_from_path(module: &Path) -> Punctuated { let generics = module.segments.last().map(|segment| { match segment.arguments.clone(){ PathArguments::AngleBracketed(generics) => { @@ -318,8 +349,8 @@ impl Parse for ContractMessageAttr { let content; parenthesized!(content in input); - let mut module = content.parse()?; - let generics = extract_generics_from_path(&mut module); + let module = content.parse()?; + let generics = extract_generics_from_path(&module); let module = StripGenerics.fold_path(module); let _: Token![as] = content.parse()?;