Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support generics on messages attribute in main contract call #238

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 53 additions & 21 deletions sylvia-derive/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@ pub struct Interfaces {
}

impl Interfaces {
fn merge_module_with_name(message_attr: &ContractMessageAttr, name: &syn::Ident) -> syn::Ident {
// ContractMessageAttr will fail to parse empty `#[messsages()]` attribute so we can safely unwrap here
let syn::PathSegment { ident, .. } = &message_attr.module.segments.last().unwrap();
let module_name = ident.to_string().to_case(Case::UpperCamel);
syn::Ident::new(&format!("{}{}", module_name, name), name.span())
}

pub fn new(source: &ItemImpl) -> Self {
let interfaces: Vec<_> = source
.attrs
Expand Down Expand Up @@ -90,11 +83,19 @@ impl Interfaces {
.iter()
.map(|interface| {
let ContractMessageAttr {
module, variant, ..
module,
variant,
generics,
..
} = interface;
let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};

let interface_enum =
quote! { <#module ::InterfaceTypes as #sylvia ::types::InterfaceMessages> };
quote! { <#module ::InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> };
if msg_ty == &MsgType::Query {
quote! { #variant ( #interface_enum :: Query) }
} else {
Expand All @@ -104,28 +105,46 @@ impl Interfaces {
.collect()
}

pub fn emit_messages_call(&self, msg_name: &Ident) -> Vec<TokenStream> {
pub fn emit_messages_call(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let enum_name = Self::merge_module_with_name(interface, msg_name);
let module = &interface.module;
quote! { &#module :: #enum_name :: messages()}
let ContractMessageAttr {
module, generics, ..
} = interface;
let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};
let type_name = msg_ty.as_accessor_name();
quote! {
&<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: messages()
}
})
.collect()
}

pub fn emit_deserialization_attempts(&self, msg_name: &Ident) -> Vec<TokenStream> {
pub fn emit_deserialization_attempts(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let ContractMessageAttr {
module, variant, ..
module, variant, generics, ..
} = interface;
let enum_name = Self::merge_module_with_name(interface, msg_name);
let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};

let type_name = msg_ty.as_accessor_name();
quote! {
let msgs = &#module :: #enum_name ::messages();
let msgs = &<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: messages();
if msgs.into_iter().any(|msg| msg == &recv_msg_name) {
match val.deserialize_into() {
Ok(msg) => return Ok(Self:: #variant (msg)),
Expand All @@ -137,13 +156,26 @@ impl Interfaces {
.collect()
}

pub fn emit_response_schemas_calls(&self, msg_name: &Ident) -> Vec<TokenStream> {
pub fn emit_response_schemas_calls(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let enum_name = Self::merge_module_with_name(interface, msg_name);
let module = &interface.module;
quote! { #module :: #enum_name :: response_schemas_impl()}
let ContractMessageAttr {
module, generics, ..
} = interface;

let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};

let type_name = msg_ty.as_accessor_name();
quote! {
<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: response_schemas_impl()
}
})
.collect()
}
Expand Down
6 changes: 3 additions & 3 deletions sylvia-derive/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ impl<'a> GlueMessage<'a> {

let msg_name = quote! {#contract ( #name)};
let mut messages_call_on_all_variants: Vec<TokenStream> =
interfaces.emit_messages_call(name);
interfaces.emit_messages_call(msg_ty);
messages_call_on_all_variants.push(quote! {&#name :: messages()});

let variants_cnt = messages_call_on_all_variants.len();
Expand Down Expand Up @@ -1004,7 +1004,7 @@ impl<'a> GlueMessage<'a> {

let dispatch_arm = quote! {#enum_name :: #contract (msg) => msg.dispatch(contract, ctx)};

let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(name);
let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(msg_ty);

#[cfg(not(tarpaulin_include))]
let contract_deserialization_attempt = quote! {
Expand All @@ -1020,7 +1020,7 @@ impl<'a> GlueMessage<'a> {
let ctx_type = msg_ty.emit_ctx_type(&custom.query_or_default());
let ret_type = msg_ty.emit_result_type(&custom.msg_or_default(), error);

let mut response_schemas_calls = interfaces.emit_response_schemas_calls(name);
let mut response_schemas_calls = interfaces.emit_response_schemas_calls(msg_ty);
response_schemas_calls.push(quote! {#name :: response_schemas_impl()});

let response_schemas = match name.to_string().as_str() {
Expand Down
34 changes: 31 additions & 3 deletions sylvia-derive/src/parser.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use proc_macro2::{Punct, TokenStream};
use proc_macro_error::emit_error;
use quote::quote;
use syn::fold::Fold;
use syn::parse::{Error, Nothing, Parse, ParseBuffer, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
parenthesized, parse_quote, Attribute, Ident, ImplItem, ImplItemMethod, ItemImpl, ItemTrait,
Path, Result, Token, TraitItem, Type,
parenthesized, parse_quote, Attribute, GenericArgument, Ident, ImplItem, ImplItemMethod,
ItemImpl, ItemTrait, Path, PathArguments, Result, Token, TraitItem, Type,
};

use crate::crate_module;
use crate::strip_generics::StripGenerics;

/// Parser arguments for `contract` macro
pub struct ContractArgs {
Expand Down Expand Up @@ -248,6 +251,7 @@ pub struct ContractMessageAttr {
pub module: Path,
pub variant: Ident,
pub customs: Customs,
pub generics: Punctuated<GenericArgument, Token![,]>,
}

fn interface_has_custom(content: ParseStream) -> Result<Customs> {
Expand Down Expand Up @@ -285,14 +289,36 @@ fn interface_has_custom(content: ParseStream) -> Result<Customs> {
Ok(customs)
}

fn extract_generics_from_path(module: &mut Path) -> Punctuated<GenericArgument, Token![,]> {
let generics = module.segments.last().map(|segment| {
match segment.arguments.clone(){
PathArguments::AngleBracketed(generics) => {
generics.args
},
PathArguments::None => Default::default(),
PathArguments::Parenthesized(generics) => {
emit_error!(
generics.span(), "Found paranthesis wrapping generics in `messages` attribute.";
note = "Expected `messages` attribute to be in form `#[messages(Path<generics> as Type)]`"
);
Default::default()
}
}
}).unwrap_or_default();

generics
}

#[cfg(not(tarpaulin_include))]
// False negative. It is being called in closure
impl Parse for ContractMessageAttr {
fn parse(input: ParseStream) -> Result<Self> {
let content;
parenthesized!(content in input);

let module = content.parse()?;
let mut module = content.parse()?;
let generics = extract_generics_from_path(&mut module);
let module = StripGenerics.fold_path(module);

let _: Token![as] = content.parse()?;
let variant = content.parse()?;
Expand All @@ -310,6 +336,7 @@ impl Parse for ContractMessageAttr {
module,
variant,
customs,
generics,
})
}
}
Expand Down Expand Up @@ -474,6 +501,7 @@ impl OverrideEntryPoint {
entry_point,
msg_name,
msg_type,
..
} = self;

let sylvia = crate_module();
Expand Down
20 changes: 14 additions & 6 deletions sylvia/tests/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,21 @@ pub mod cw1_contract {
use sylvia::types::InstantiateCtx;
use sylvia_derive::contract;

use crate::{ExternalMsg, ExternalQuery};

pub struct Cw1Contract;

#[contract]
#[messages(crate::cw1<ExternalMsg, ExternalMsg, ExternalQuery> as Cw1)]
/// Required if interface returns generic `Response`
#[sv::custom(msg=ExternalMsg)]
impl Cw1Contract {
pub const fn new() -> Self {
Self
}

#[msg(instantiate)]
pub fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult<Response> {
pub fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult<Response<ExternalMsg>> {
Ok(Response::new())
}
}
Expand Down Expand Up @@ -91,12 +96,11 @@ impl cosmwasm_std::CustomQuery for ExternalQuery {}

#[cfg(all(test, feature = "mt"))]
mod tests {
use crate::cw1::{InterfaceTypes, Querier as Cw1Querier};
use crate::{ExternalMsg, ExternalQuery};
use cosmwasm_std::{testing::mock_dependencies, Addr, CosmosMsg, Empty, QuerierWrapper};

use crate::{cw1::Querier, ExternalMsg, ExternalQuery};

use crate::cw1::InterfaceTypes;
use sylvia::types::InterfaceMessages;

#[test]
fn construct_messages() {
let contract = Addr::unchecked("contract");
Expand All @@ -110,9 +114,13 @@ mod tests {
let querier: QuerierWrapper<ExternalQuery> = QuerierWrapper::new(&deps.querier);

let cw1_querier = crate::cw1::BoundQuerier::borrowed(&contract, &querier);
let _: Result<ExternalQuery, _> = Querier::some_query(&cw1_querier, ExternalMsg {});
let _: Result<ExternalQuery, _> =
crate::cw1::Querier::some_query(&cw1_querier, ExternalMsg {});
let _: Result<ExternalQuery, _> = cw1_querier.some_query(ExternalMsg {});

let contract_querier = crate::cw1_contract::BoundQuerier::borrowed(&contract, &querier);
let _: Result<ExternalQuery, _> = contract_querier.some_query(ExternalMsg {});

// Construct messages with Interface extension
let _ =
<InterfaceTypes<ExternalMsg, _, ExternalQuery> as InterfaceMessages>::Query::some_query(
Expand Down