From 04e0edd8998fa897811eb169280c040338fecaa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Fri, 6 Oct 2023 17:42:31 +0200 Subject: [PATCH] feat: Support generics on `messages` attribute in main `contract` macro call --- sylvia-derive/src/interfaces.rs | 74 +++++++++++++++++++++++---------- sylvia-derive/src/message.rs | 6 +-- sylvia-derive/src/parser.rs | 34 +++++++++++++-- sylvia/tests/generics.rs | 20 ++++++--- 4 files changed, 101 insertions(+), 33 deletions(-) diff --git a/sylvia-derive/src/interfaces.rs b/sylvia-derive/src/interfaces.rs index f1d2a909..09559c70 100644 --- a/sylvia-derive/src/interfaces.rs +++ b/sylvia-derive/src/interfaces.rs @@ -15,13 +15,6 @@ pub struct Interfaces { } impl Interfaces { - fn merge_module_with_name(message_attr: &ContractMessageAttr, name: &syn::Ident) -> syn::Ident { - // ContractMessageAttr will fail to parse empty `#[messsages()]` attribute so we can safely unwrap here - let syn::PathSegment { ident, .. } = &message_attr.module.segments.last().unwrap(); - let module_name = ident.to_string().to_case(Case::UpperCamel); - syn::Ident::new(&format!("{}{}", module_name, name), name.span()) - } - pub fn new(source: &ItemImpl) -> Self { let interfaces: Vec<_> = source .attrs @@ -90,11 +83,19 @@ impl Interfaces { .iter() .map(|interface| { let ContractMessageAttr { - module, variant, .. + module, + variant, + generics, + .. } = interface; + let generics = if !generics.is_empty() { + quote! { < #generics > } + } else { + quote! {} + }; let interface_enum = - quote! { <#module ::InterfaceTypes as #sylvia ::types::InterfaceMessages> }; + quote! { <#module ::InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> }; if msg_ty == &MsgType::Query { quote! { #variant ( #interface_enum :: Query) } } else { @@ -104,28 +105,46 @@ impl Interfaces { .collect() } - pub fn emit_messages_call(&self, msg_name: &Ident) -> Vec { + pub fn emit_messages_call(&self, msg_ty: &MsgType) -> Vec { + let sylvia = crate_module(); + self.interfaces .iter() .map(|interface| { - let enum_name = Self::merge_module_with_name(interface, msg_name); - let module = &interface.module; - quote! { &#module :: #enum_name :: messages()} + let ContractMessageAttr { + module, generics, .. + } = interface; + let generics = if !generics.is_empty() { + quote! { < #generics > } + } else { + quote! {} + }; + let type_name = msg_ty.as_accessor_name(); + quote! { + &<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: messages() + } }) .collect() } - pub fn emit_deserialization_attempts(&self, msg_name: &Ident) -> Vec { + pub fn emit_deserialization_attempts(&self, msg_ty: &MsgType) -> Vec { + let sylvia = crate_module(); + self.interfaces .iter() .map(|interface| { let ContractMessageAttr { - module, variant, .. + module, variant, generics, .. } = interface; - let enum_name = Self::merge_module_with_name(interface, msg_name); + let generics = if !generics.is_empty() { + quote! { < #generics > } + } else { + quote! {} + }; + let type_name = msg_ty.as_accessor_name(); quote! { - let msgs = &#module :: #enum_name ::messages(); + let msgs = &<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: messages(); if msgs.into_iter().any(|msg| msg == &recv_msg_name) { match val.deserialize_into() { Ok(msg) => return Ok(Self:: #variant (msg)), @@ -137,13 +156,26 @@ impl Interfaces { .collect() } - pub fn emit_response_schemas_calls(&self, msg_name: &Ident) -> Vec { + pub fn emit_response_schemas_calls(&self, msg_ty: &MsgType) -> Vec { + let sylvia = crate_module(); + self.interfaces .iter() .map(|interface| { - let enum_name = Self::merge_module_with_name(interface, msg_name); - let module = &interface.module; - quote! { #module :: #enum_name :: response_schemas_impl()} + let ContractMessageAttr { + module, generics, .. + } = interface; + + let generics = if !generics.is_empty() { + quote! { < #generics > } + } else { + quote! {} + }; + + let type_name = msg_ty.as_accessor_name(); + quote! { + <#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: response_schemas_impl() + } }) .collect() } diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 4398f6e8..926e9b74 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -968,7 +968,7 @@ impl<'a> GlueMessage<'a> { let msg_name = quote! {#contract ( #name)}; let mut messages_call_on_all_variants: Vec = - interfaces.emit_messages_call(name); + interfaces.emit_messages_call(msg_ty); messages_call_on_all_variants.push(quote! {&#name :: messages()}); let variants_cnt = messages_call_on_all_variants.len(); @@ -1002,7 +1002,7 @@ impl<'a> GlueMessage<'a> { let dispatch_arm = quote! {#enum_name :: #contract (msg) => msg.dispatch(contract, ctx)}; - let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(name); + let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(msg_ty); #[cfg(not(tarpaulin_include))] let contract_deserialization_attempt = quote! { @@ -1018,7 +1018,7 @@ impl<'a> GlueMessage<'a> { let ctx_type = msg_ty.emit_ctx_type(&custom.query_or_default()); let ret_type = msg_ty.emit_result_type(&custom.msg_or_default(), error); - let mut response_schemas_calls = interfaces.emit_response_schemas_calls(name); + let mut response_schemas_calls = interfaces.emit_response_schemas_calls(msg_ty); response_schemas_calls.push(quote! {#name :: response_schemas_impl()}); let response_schemas = match name.to_string().as_str() { diff --git a/sylvia-derive/src/parser.rs b/sylvia-derive/src/parser.rs index 0993ca5b..fbbf67dc 100644 --- a/sylvia-derive/src/parser.rs +++ b/sylvia-derive/src/parser.rs @@ -1,14 +1,17 @@ use proc_macro2::{Punct, TokenStream}; use proc_macro_error::emit_error; use quote::quote; +use syn::fold::Fold; use syn::parse::{Error, Nothing, Parse, ParseBuffer, ParseStream, Parser}; +use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{ - parenthesized, parse_quote, Attribute, Ident, ImplItem, ImplItemMethod, ItemImpl, ItemTrait, - Path, Result, Token, TraitItem, Type, + parenthesized, parse_quote, Attribute, GenericArgument, Ident, ImplItem, ImplItemMethod, + ItemImpl, ItemTrait, Path, PathArguments, Result, Token, TraitItem, Type, }; use crate::crate_module; +use crate::strip_generics::StripGenerics; /// Parser arguments for `contract` macro pub struct ContractArgs { @@ -248,6 +251,7 @@ pub struct ContractMessageAttr { pub module: Path, pub variant: Ident, pub customs: Customs, + pub generics: Punctuated, } fn interface_has_custom(content: ParseStream) -> Result { @@ -285,6 +289,26 @@ fn interface_has_custom(content: ParseStream) -> Result { Ok(customs) } +fn extract_generics_from_path(module: &mut Path) -> Punctuated { + let generics = module.segments.last().map(|segment| { + match segment.arguments.clone(){ + PathArguments::AngleBracketed(generics) => { + generics.args + }, + PathArguments::None => Default::default(), + PathArguments::Parenthesized(generics) => { + emit_error!( + generics.span(), "Found paranthesis wrapping generics in `messages` attribute."; + note = "Expected `messages` attribute to be in form `#[messages(Path as Type)]`" + ); + Default::default() + } + } + }).unwrap_or_default(); + + generics +} + #[cfg(not(tarpaulin_include))] // False negative. It is being called in closure impl Parse for ContractMessageAttr { @@ -292,7 +316,9 @@ impl Parse for ContractMessageAttr { let content; parenthesized!(content in input); - let module = content.parse()?; + let mut module = content.parse()?; + let generics = extract_generics_from_path(&mut module); + let module = StripGenerics.fold_path(module); let _: Token![as] = content.parse()?; let variant = content.parse()?; @@ -310,6 +336,7 @@ impl Parse for ContractMessageAttr { module, variant, customs, + generics, }) } } @@ -474,6 +501,7 @@ impl OverrideEntryPoint { entry_point, msg_name, msg_type, + .. } = self; let sylvia = crate_module(); diff --git a/sylvia/tests/generics.rs b/sylvia/tests/generics.rs index ef73535b..3d63fac0 100644 --- a/sylvia/tests/generics.rs +++ b/sylvia/tests/generics.rs @@ -34,16 +34,21 @@ pub mod cw1_contract { use sylvia::types::InstantiateCtx; use sylvia_derive::contract; + use crate::{ExternalMsg, ExternalQuery}; + pub struct Cw1Contract; #[contract] + #[messages(crate::cw1 as Cw1)] + /// Required if interface returns generic `Response` + #[sv::custom(msg=ExternalMsg)] impl Cw1Contract { pub const fn new() -> Self { Self } #[msg(instantiate)] - pub fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult { + pub fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult> { Ok(Response::new()) } } @@ -91,12 +96,11 @@ impl cosmwasm_std::CustomQuery for ExternalQuery {} #[cfg(all(test, feature = "mt"))] mod tests { + use crate::cw1::{InterfaceTypes, Querier as Cw1Querier}; + use crate::{ExternalMsg, ExternalQuery}; use cosmwasm_std::{testing::mock_dependencies, Addr, CosmosMsg, Empty, QuerierWrapper}; - - use crate::{cw1::Querier, ExternalMsg, ExternalQuery}; - - use crate::cw1::InterfaceTypes; use sylvia::types::InterfaceMessages; + #[test] fn construct_messages() { let contract = Addr::unchecked("contract"); @@ -110,9 +114,13 @@ mod tests { let querier: QuerierWrapper = QuerierWrapper::new(&deps.querier); let cw1_querier = crate::cw1::BoundQuerier::borrowed(&contract, &querier); - let _: Result = Querier::some_query(&cw1_querier, ExternalMsg {}); + let _: Result = + crate::cw1::Querier::some_query(&cw1_querier, ExternalMsg {}); let _: Result = cw1_querier.some_query(ExternalMsg {}); + let contract_querier = crate::cw1_contract::BoundQuerier::borrowed(&contract, &querier); + let _: Result = contract_querier.some_query(ExternalMsg {}); + // Construct messages with Interface extension let _ = as InterfaceMessages>::Query::some_query(