Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support generics on messages attribute in main contract call #238

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 53 additions & 21 deletions sylvia-derive/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@ pub struct Interfaces {
}

impl Interfaces {
fn merge_module_with_name(message_attr: &ContractMessageAttr, name: &syn::Ident) -> syn::Ident {
// ContractMessageAttr will fail to parse empty `#[messsages()]` attribute so we can safely unwrap here
let syn::PathSegment { ident, .. } = &message_attr.module.segments.last().unwrap();
let module_name = ident.to_string().to_case(Case::UpperCamel);
syn::Ident::new(&format!("{}{}", module_name, name), name.span())
}

pub fn new(source: &ItemImpl) -> Self {
let interfaces: Vec<_> = source
.attrs
Expand Down Expand Up @@ -90,11 +83,19 @@ impl Interfaces {
.iter()
.map(|interface| {
let ContractMessageAttr {
module, variant, ..
module,
variant,
generics,
..
} = interface;
let generics = if !generics.is_empty() {
quote! { < #generics > }
jawoznia marked this conversation as resolved.
Show resolved Hide resolved
} else {
quote! {}
};

let interface_enum =
quote! { <#module ::InterfaceTypes as #sylvia ::types::InterfaceMessages> };
quote! { <#module ::InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> };
if msg_ty == &MsgType::Query {
quote! { #variant ( #interface_enum :: Query) }
} else {
Expand All @@ -104,28 +105,46 @@ impl Interfaces {
.collect()
}

pub fn emit_messages_call(&self, msg_name: &Ident) -> Vec<TokenStream> {
pub fn emit_messages_call(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let enum_name = Self::merge_module_with_name(interface, msg_name);
let module = &interface.module;
quote! { &#module :: #enum_name :: messages()}
let ContractMessageAttr {
module, generics, ..
} = interface;
let generics = if !generics.is_empty() {
quote! { < #generics > }
jawoznia marked this conversation as resolved.
Show resolved Hide resolved
} else {
quote! {}
};
let type_name = msg_ty.as_accessor_name();
quote! {
&<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: messages()
}
})
.collect()
}

pub fn emit_deserialization_attempts(&self, msg_name: &Ident) -> Vec<TokenStream> {
pub fn emit_deserialization_attempts(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let ContractMessageAttr {
module, variant, ..
module, variant, generics, ..
} = interface;
let enum_name = Self::merge_module_with_name(interface, msg_name);
let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};

let type_name = msg_ty.as_accessor_name();
quote! {
let msgs = &#module :: #enum_name ::messages();
let msgs = &<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: messages();
if msgs.into_iter().any(|msg| msg == &recv_msg_name) {
match val.deserialize_into() {
Ok(msg) => return Ok(Self:: #variant (msg)),
Expand All @@ -137,13 +156,26 @@ impl Interfaces {
.collect()
}

pub fn emit_response_schemas_calls(&self, msg_name: &Ident) -> Vec<TokenStream> {
pub fn emit_response_schemas_calls(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let enum_name = Self::merge_module_with_name(interface, msg_name);
let module = &interface.module;
quote! { #module :: #enum_name :: response_schemas_impl()}
let ContractMessageAttr {
module, generics, ..
} = interface;

let generics = if !generics.is_empty() {
quote! { < #generics > }
} else {
quote! {}
};

let type_name = msg_ty.as_accessor_name();
quote! {
<#module :: InterfaceTypes #generics as #sylvia ::types::InterfaceMessages> :: #type_name :: response_schemas_impl()
}
})
.collect()
}
Expand Down
6 changes: 3 additions & 3 deletions sylvia-derive/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ impl<'a> GlueMessage<'a> {

let msg_name = quote! {#contract ( #name)};
let mut messages_call_on_all_variants: Vec<TokenStream> =
interfaces.emit_messages_call(name);
interfaces.emit_messages_call(msg_ty);
messages_call_on_all_variants.push(quote! {&#name :: messages()});

let variants_cnt = messages_call_on_all_variants.len();
Expand Down Expand Up @@ -1004,7 +1004,7 @@ impl<'a> GlueMessage<'a> {

let dispatch_arm = quote! {#enum_name :: #contract (msg) => msg.dispatch(contract, ctx)};

let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(name);
let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(msg_ty);

#[cfg(not(tarpaulin_include))]
let contract_deserialization_attempt = quote! {
Expand All @@ -1020,7 +1020,7 @@ impl<'a> GlueMessage<'a> {
let ctx_type = msg_ty.emit_ctx_type(&custom.query_or_default());
let ret_type = msg_ty.emit_result_type(&custom.msg_or_default(), error);

let mut response_schemas_calls = interfaces.emit_response_schemas_calls(name);
let mut response_schemas_calls = interfaces.emit_response_schemas_calls(msg_ty);
response_schemas_calls.push(quote! {#name :: response_schemas_impl()});

let response_schemas = match name.to_string().as_str() {
Expand Down
34 changes: 31 additions & 3 deletions sylvia-derive/src/parser.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use proc_macro2::{Punct, TokenStream};
use proc_macro_error::emit_error;
use quote::quote;
use syn::fold::Fold;
use syn::parse::{Error, Nothing, Parse, ParseBuffer, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
parenthesized, parse_quote, Attribute, Ident, ImplItem, ImplItemMethod, ItemImpl, ItemTrait,
Path, Result, Token, TraitItem, Type,
parenthesized, parse_quote, Attribute, GenericArgument, Ident, ImplItem, ImplItemMethod,
ItemImpl, ItemTrait, Path, PathArguments, Result, Token, TraitItem, Type,
};

use crate::crate_module;
use crate::strip_generics::StripGenerics;

/// Parser arguments for `contract` macro
pub struct ContractArgs {
Expand Down Expand Up @@ -248,6 +251,7 @@
pub module: Path,
pub variant: Ident,
pub customs: Customs,
pub generics: Punctuated<GenericArgument, Token![,]>,
}

fn interface_has_custom(content: ParseStream) -> Result<Customs> {
Expand Down Expand Up @@ -285,14 +289,36 @@
Ok(customs)
}

fn extract_generics_from_path(module: &mut Path) -> Punctuated<GenericArgument, Token![,]> {
let generics = module.segments.last().map(|segment| {
match segment.arguments.clone(){
PathArguments::AngleBracketed(generics) => {
generics.args
},
PathArguments::None => Default::default(),
PathArguments::Parenthesized(generics) => {
emit_error!(
generics.span(), "Found paranthesis wrapping generics in `messages` attribute.";
note = "Expected `messages` attribute to be in form `#[messages(Path<generics> as Type)]`"

Check warning on line 302 in sylvia-derive/src/parser.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/parser.rs#L299-L302

Added lines #L299 - L302 were not covered by tests
);
Default::default()

Check warning on line 304 in sylvia-derive/src/parser.rs

View check run for this annotation

Codecov / codecov/patch

sylvia-derive/src/parser.rs#L304

Added line #L304 was not covered by tests
}
}
}).unwrap_or_default();

generics
}

#[cfg(not(tarpaulin_include))]
// False negative. It is being called in closure
impl Parse for ContractMessageAttr {
fn parse(input: ParseStream) -> Result<Self> {
let content;
parenthesized!(content in input);

let module = content.parse()?;
let mut module = content.parse()?;
let generics = extract_generics_from_path(&mut module);
let module = StripGenerics.fold_path(module);

let _: Token![as] = content.parse()?;
let variant = content.parse()?;
Expand All @@ -310,6 +336,7 @@
module,
variant,
customs,
generics,
})
}
}
Expand Down Expand Up @@ -474,6 +501,7 @@
entry_point,
msg_name,
msg_type,
..
} = self;

let sylvia = crate_module();
Expand Down
20 changes: 14 additions & 6 deletions sylvia/tests/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,21 @@
use sylvia::types::InstantiateCtx;
use sylvia_derive::contract;

use crate::{ExternalMsg, ExternalQuery};

pub struct Cw1Contract;

#[contract]
#[messages(crate::cw1<ExternalMsg, ExternalMsg, ExternalQuery> as Cw1)]
/// Required if interface returns generic `Response`
#[sv::custom(msg=ExternalMsg)]
impl Cw1Contract {
pub const fn new() -> Self {
Self
}

#[msg(instantiate)]
pub fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult<Response> {
pub fn instantiate(&self, _ctx: InstantiateCtx) -> StdResult<Response<ExternalMsg>> {

Check warning on line 51 in sylvia/tests/generics.rs

View check run for this annotation

Codecov / codecov/patch

sylvia/tests/generics.rs#L51

Added line #L51 was not covered by tests
Ok(Response::new())
}
}
Expand Down Expand Up @@ -91,12 +96,11 @@

#[cfg(all(test, feature = "mt"))]
mod tests {
use crate::cw1::{InterfaceTypes, Querier as Cw1Querier};
use crate::{ExternalMsg, ExternalQuery};
use cosmwasm_std::{testing::mock_dependencies, Addr, CosmosMsg, Empty, QuerierWrapper};

use crate::{cw1::Querier, ExternalMsg, ExternalQuery};

use crate::cw1::InterfaceTypes;
use sylvia::types::InterfaceMessages;

#[test]
fn construct_messages() {
let contract = Addr::unchecked("contract");
Expand All @@ -110,9 +114,13 @@
let querier: QuerierWrapper<ExternalQuery> = QuerierWrapper::new(&deps.querier);

let cw1_querier = crate::cw1::BoundQuerier::borrowed(&contract, &querier);
let _: Result<ExternalQuery, _> = Querier::some_query(&cw1_querier, ExternalMsg {});
let _: Result<ExternalQuery, _> =
crate::cw1::Querier::some_query(&cw1_querier, ExternalMsg {});
let _: Result<ExternalQuery, _> = cw1_querier.some_query(ExternalMsg {});

let contract_querier = crate::cw1_contract::BoundQuerier::borrowed(&contract, &querier);
let _: Result<ExternalQuery, _> = contract_querier.some_query(ExternalMsg {});

// Construct messages with Interface extension
let _ =
<InterfaceTypes<ExternalMsg, _, ExternalQuery> as InterfaceMessages>::Query::some_query(
Expand Down
Loading