diff --git a/examples/Cargo.lock b/examples/Cargo.lock index b850700a..91bd63c0 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -601,57 +601,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "generic_contract" -version = "0.5.0" -dependencies = [ - "anyhow", - "cosmwasm-schema", - "cosmwasm-std", - "custom-and-generic", - "cw-multi-test", - "cw-storage-plus", - "cw-utils", - "cw1", - "generic", - "serde", - "sylvia", -] - -[[package]] -name = "generic_iface_on_contract" -version = "0.5.0" -dependencies = [ - "anyhow", - "cosmwasm-schema", - "cosmwasm-std", - "custom-and-generic", - "cw-multi-test", - "cw-storage-plus", - "cw-utils", - "cw1", - "generic", - "serde", - "sylvia", -] - -[[package]] -name = "generics_forwarded" -version = "0.5.0" -dependencies = [ - "anyhow", - "cosmwasm-schema", - "cosmwasm-std", - "custom-and-generic", - "cw-multi-test", - "cw-storage-plus", - "cw-utils", - "cw1", - "generic", - "serde", - "sylvia", -] - [[package]] name = "getrandom" version = "0.2.11" diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 8286636a..23dcee66 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -16,9 +16,9 @@ members = [ "contracts/cw20-base", "contracts/entry-points-overriding", "contracts/custom", - "contracts/generic_contract", - "contracts/generics_forwarded", - "contracts/generic_iface_on_contract", + # "contracts/generic_contract", + # "contracts/generics_forwarded", + # "contracts/generic_iface_on_contract", ] resolver = "2" diff --git a/examples/interfaces/custom-and-generic/src/lib.rs b/examples/interfaces/custom-and-generic/src/lib.rs index b871bd83..23fc1180 100644 --- a/examples/interfaces/custom-and-generic/src/lib.rs +++ b/examples/interfaces/custom-and-generic/src/lib.rs @@ -1,66 +1,53 @@ -use cosmwasm_std::{CosmosMsg, CustomMsg, Response, StdError}; +use cosmwasm_std::{CosmosMsg, Response, StdError}; -use serde::de::DeserializeOwned; -use serde::Deserialize; -use sylvia::types::{ExecCtx, QueryCtx}; +use sylvia::types::{CustomMsg, CustomQuery, ExecCtx, QueryCtx}; use sylvia::{interface, schemars}; #[interface] #[sv::custom(msg=CustomMsgT, query=CustomQueryT)] -pub trait CustomAndGeneric< - Exec1T, - Exec2T, - Exec3T, - Query1T, - Query2T, - Query3T, - RetT, - CustomMsgT, - CustomQueryT, -> where - for<'msg_de> Exec1T: CustomMsg + Deserialize<'msg_de>, - Exec2T: sylvia::types::CustomMsg, - Exec3T: sylvia::types::CustomMsg, - Query1T: sylvia::types::CustomMsg, - Query2T: sylvia::types::CustomMsg, - Query3T: sylvia::types::CustomMsg, - RetT: CustomMsg + DeserializeOwned, - CustomMsgT: CustomMsg + DeserializeOwned, - CustomQueryT: sylvia::types::CustomQuery + 'static, -{ +pub trait CustomAndGeneric { type Error: From; + type Exec1T: CustomMsg; + type Exec2T: CustomMsg; + type Exec3T: CustomMsg; + type Query1T: CustomMsg; + type Query2T: CustomMsg; + type Query3T: CustomMsg; + type RetT: CustomMsg; + type CustomMsgT: CustomMsg; + type CustomQueryT: CustomQuery + 'static; #[msg(exec)] fn custom_generic_execute_one( &self, - ctx: ExecCtx, - msgs1: Vec>, - msgs2: Vec>, - ) -> Result, Self::Error>; + ctx: ExecCtx, + msgs1: Vec>, + msgs2: Vec>, + ) -> Result, Self::Error>; #[msg(exec)] fn custom_generic_execute_two( &self, - ctx: ExecCtx, - msgs1: Vec>, - msgs2: Vec>, - ) -> Result, Self::Error>; + ctx: ExecCtx, + msgs1: Vec>, + msgs2: Vec>, + ) -> Result, Self::Error>; #[msg(query)] fn custom_generic_query_one( &self, - ctx: QueryCtx, - param1: Query1T, - param2: Query2T, - ) -> Result; + ctx: QueryCtx, + param1: Self::Query1T, + param2: Self::Query2T, + ) -> Result; #[msg(query)] fn custom_generic_query_two( &self, - ctx: QueryCtx, - param1: Query2T, - param2: Query3T, - ) -> Result; + ctx: QueryCtx, + param1: Self::Query2T, + param2: Self::Query3T, + ) -> Result; } #[cfg(test)] @@ -69,6 +56,8 @@ mod tests { use cosmwasm_std::{Addr, CosmosMsg, Empty, QuerierWrapper}; use sylvia::types::{InterfaceApi, SvCustomMsg, SvCustomQuery}; + use crate::sv::Querier; + #[test] fn construct_messages() { let contract = Addr::unchecked("contract"); @@ -96,24 +85,27 @@ mod tests { let deps = mock_dependencies(); let querier_wrapper: QuerierWrapper = QuerierWrapper::new(&deps.querier); - let querier = super::sv::BoundQuerier::borrowed(&contract, &querier_wrapper); + let querier = super::sv::BoundQuerier::< + _, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomQuery, + >::borrowed(&contract, &querier_wrapper); let _: Result = - super::sv::Querier::<_, _, _, SvCustomMsg>::custom_generic_query_one( - &querier, - SvCustomMsg {}, - SvCustomMsg {}, - ); - // let _: Result = - // querier.custom_generic_query_one(SvCustomMsg {}, SvCustomMsg {}); + super::sv::Querier::custom_generic_query_one(&querier, SvCustomMsg {}, SvCustomMsg {}); + let _: Result = + querier.custom_generic_query_one(SvCustomMsg {}, SvCustomMsg {}); + let _: Result = + super::sv::Querier::custom_generic_query_two(&querier, SvCustomMsg {}, SvCustomMsg {}); let _: Result = - super::sv::Querier::::custom_generic_query_two( - &querier, - SvCustomMsg {}, - SvCustomMsg {}, - ); - // let _: Result = - // querier.custom_generic_query_two(SvCustomMsg {}, SvCustomMsg {}); + querier.custom_generic_query_two(SvCustomMsg {}, SvCustomMsg {}); // Construct messages with Interface extension let _ = -where - for<'msg_de> Exec1T: CustomMsg + Deserialize<'msg_de>, - Exec2T: CustomMsg, - Exec3T: CustomMsg, - Query1T: CustomMsg, - Query2T: CustomMsg, - Query3T: CustomMsg, - RetT: CustomMsg, -{ +pub trait Generic { type Error: From; + type Exec1T: CustomMsg; + type Exec2T: CustomMsg; + type Exec3T: CustomMsg; + type Query1T: CustomMsg; + type Query2T: CustomMsg; + type Query3T: CustomMsg; + type RetT: CustomMsg; #[msg(exec)] fn generic_exec_one( &self, ctx: ExecCtx, - msgs1: Vec>, - msgs2: Vec>, + msgs1: Vec>, + msgs2: Vec>, ) -> Result; #[msg(exec)] fn generic_exec_two( &self, ctx: ExecCtx, - msgs1: Vec>, - msgs2: Vec>, + msgs1: Vec>, + msgs2: Vec>, ) -> Result; #[msg(query)] fn generic_query_one( &self, ctx: QueryCtx, - param1: Query1T, - param2: Query2T, - ) -> Result; + param1: Self::Query1T, + param2: Self::Query2T, + ) -> Result; #[msg(query)] fn generic_query_two( &self, ctx: QueryCtx, - param1: Query2T, - param2: Query3T, - ) -> Result; + param1: Self::Query2T, + param2: Self::Query3T, + ) -> Result; } #[cfg(test)] @@ -55,6 +52,8 @@ mod tests { use cosmwasm_std::{testing::mock_dependencies, Addr, CosmosMsg, Empty, QuerierWrapper}; use sylvia::types::{InterfaceApi, SvCustomMsg}; + use crate::sv::Querier; + #[test] fn construct_messages() { let contract = Addr::unchecked("contract"); @@ -81,21 +80,22 @@ mod tests { let deps = mock_dependencies(); let querier_wrapper: QuerierWrapper = QuerierWrapper::new(&deps.querier); - let querier = super::sv::BoundQuerier::borrowed(&contract, &querier_wrapper); + let querier = super::sv::BoundQuerier::< + Empty, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + SvCustomMsg, + >::borrowed(&contract, &querier_wrapper); let _: Result = - super::sv::Querier::<_, _, _, SvCustomMsg>::generic_query_one( - &querier, - SvCustomMsg {}, - SvCustomMsg {}, - ); + super::sv::Querier::generic_query_one(&querier, SvCustomMsg {}, SvCustomMsg {}); let _: Result = - super::sv::Querier::::generic_query_two( - &querier, - SvCustomMsg {}, - SvCustomMsg {}, - ); - // let _: Result = querier.generic_query_one(SvCustomMsg {}, SvCustomMsg {}); - // let _: Result = querier.generic_query_two(SvCustomMsg {}, SvCustomMsg {}); + super::sv::Querier::generic_query_two(&querier, SvCustomMsg {}, SvCustomMsg {}); + let _: Result = querier.generic_query_one(SvCustomMsg {}, SvCustomMsg {}); + let _: Result = querier.generic_query_two(SvCustomMsg {}, SvCustomMsg {}); // Construct messages with Interface extension let _ = (Vec<&'a TraitItemType>); + +impl<'a> AssociatedTypes<'a> { + pub fn new(source: &'a ItemTrait) -> Self { + let associated_types: Vec<_> = source + .items + .iter() + .filter_map(|item| match item { + TraitItem::Type(ty) if !RESERVED_TYPES.contains(&ty.ident.to_string().as_str()) => { + Some(ty) + } + _ => None, + }) + .collect(); + + Self(associated_types) + } + + pub fn as_where_predicates(&self) -> Vec { + self.0 + .iter() + .map(|associated| { + let name = &associated.ident; + let colon = &associated.colon_token; + let bound = &associated.bounds; + parse_quote! { #name #colon #bound } + }) + .collect() + } + + pub fn as_where_clause(&self) -> Option { + let predicates = self.as_where_predicates(); + if !predicates.is_empty() { + parse_quote! { where #(#predicates),* } + } else { + None + } + } + + pub fn as_names(&self) -> Vec<&Ident> { + self.0.iter().map(|associated| &associated.ident).collect() + } + + pub fn as_types_declaration(&self) -> &Vec<&TraitItemType> { + &self.0 + } + + pub fn emit_types_definition(&self) -> Vec { + self.as_names() + .iter() + .map(|name| quote! { type #name = #name; }) + .collect() + } + + pub fn emit_contract_predicate(&self, trait_name: &Ident) -> TokenStream { + let predicate = quote! { ContractT: #trait_name }; + if self.0.is_empty() { + return predicate; + } + + let bounds = self.0.iter().map(|associated| { + let name = &associated.ident; + quote! { #name = #name } + }); + + quote! { + #predicate < #(#bounds,)* > + } + } +} diff --git a/sylvia-derive/src/check_generics.rs b/sylvia-derive/src/check_generics.rs index 35680e94..4eaedd42 100644 --- a/sylvia-derive/src/check_generics.rs +++ b/sylvia-derive/src/check_generics.rs @@ -1,14 +1,15 @@ +use proc_macro2::Ident; use syn::visit::Visit; -use syn::{parse_quote, GenericArgument, GenericParam, Type}; +use syn::{parse_quote, GenericArgument, GenericParam, Path, TraitItemType, Type}; /// Provides method extracting `syn::Path`. /// Inteded to be used with `syn::GenericParam` and `syn::GenericArgument`. pub trait GetPath { - fn get_path(&self) -> Option; + fn get_path(&self) -> Option; } impl GetPath for GenericParam { - fn get_path(&self) -> Option { + fn get_path(&self) -> Option { match self { GenericParam::Type(ty) => { let ident = &ty.ident; @@ -20,7 +21,7 @@ impl GetPath for GenericParam { } impl GetPath for GenericArgument { - fn get_path(&self) -> Option { + fn get_path(&self) -> Option { match self { GenericArgument::Type(Type::Path(path)) => { let path = &path.path; @@ -31,6 +32,19 @@ impl GetPath for GenericArgument { } } +impl GetPath for TraitItemType { + fn get_path(&self) -> Option { + let ident = &self.ident; + Some(parse_quote!(#ident)) + } +} + +impl GetPath for Ident { + fn get_path(&self) -> Option { + Some(parse_quote! { #self }) + } +} + /// Traverses AST tree and checks if generics are used in method signatures. #[derive(Debug)] pub struct CheckGenerics<'g, Generic> { diff --git a/sylvia-derive/src/input.rs b/sylvia-derive/src/input.rs index 189f0a2b..24989503 100644 --- a/sylvia-derive/src/input.rs +++ b/sylvia-derive/src/input.rs @@ -1,8 +1,9 @@ -use proc_macro2::{Span, TokenStream}; +use proc_macro2::TokenStream; use proc_macro_error::emit_error; use quote::quote; use syn::{GenericArgument, GenericParam, Ident, ItemImpl, ItemTrait, PathArguments, TraitItem}; +use crate::associated_types::AssociatedTypes; use crate::interfaces::Interfaces; use crate::message::{ ContractApi, ContractEnumMessage, EnumMessage, GlueMessage, InterfaceApi, MsgVariants, @@ -10,6 +11,7 @@ use crate::message::{ }; use crate::multitest::{MultitestHelpers, TraitMultitestHelpers}; use crate::parser::{ContractArgs, ContractErrorAttr, Custom, MsgType, OverrideEntryPoints}; +use crate::querier::Querier; use crate::remote::Remote; use crate::utils::is_trait; use crate::variant_descs::AsVariantDescs; @@ -17,8 +19,8 @@ use crate::variant_descs::AsVariantDescs; /// Preprocessed `interface` macro input pub struct TraitInput<'a> { item: &'a ItemTrait, - generics: Vec<&'a GenericParam>, custom: Custom<'a>, + associated_types: AssociatedTypes<'a>, } /// Preprocessed `contract` macro input for non-trait impl block @@ -34,9 +36,13 @@ pub struct ImplInput<'a> { impl<'a> TraitInput<'a> { #[cfg(not(tarpaulin_include))] - // This requires invalid implementation which would fail at compile time and making it impossible to test pub fn new(item: &'a ItemTrait) -> Self { - let generics = item.generics.params.iter().collect(); + if !item.generics.params.is_empty() { + emit_error!( + item.ident.span(), "Generics on traits are not supported. Use associated types instead."; + note = "Sylvia interfaces can be implemented only a single time per contract."; + ); + } if !item .items @@ -51,28 +57,31 @@ impl<'a> TraitInput<'a> { } let custom = Custom::new(&item.attrs); + let associated_types = AssociatedTypes::new(item); Self { item, - generics, custom, + associated_types, } } pub fn process(&self) -> TokenStream { + let Self { + associated_types, + item, + custom, + } = self; let messages = self.emit_messages(); let multitest_helpers = self.emit_helpers(); - let remote = Remote::new(&Interfaces::default()).emit(); + let remote = Remote::new(&Interfaces::default(), associated_types).emit(); + let associated_names = associated_types.as_names(); - let querier = MsgVariants::new( - self.item.as_variants(), - MsgType::Query, - &self.generics, - &self.item.generics.where_clause, - ) - .emit_querier(); + let query_variants = + MsgVariants::new(item.as_variants(), MsgType::Query, &associated_names, &None); + let querier = Querier::new(&query_variants, associated_types).emit_trait_querier(); - let interface_messages = InterfaceApi::new(self.item, &self.generics, &self.custom).emit(); + let interface_messages = InterfaceApi::new(item, associated_types, custom).emit(); #[cfg(not(tarpaulin_include))] { @@ -103,8 +112,8 @@ impl<'a> TraitInput<'a> { } fn emit_messages(&self) -> TokenStream { - let exec = self.emit_msg(&Ident::new("ExecMsg", Span::mixed_site()), MsgType::Exec); - let query = self.emit_msg(&Ident::new("QueryMsg", Span::mixed_site()), MsgType::Query); + let exec = self.emit_msg(MsgType::Exec); + let query = self.emit_msg(MsgType::Query); #[cfg(not(tarpaulin_include))] { @@ -116,8 +125,24 @@ impl<'a> TraitInput<'a> { } } - fn emit_msg(&self, name: &Ident, msg_ty: MsgType) -> TokenStream { - EnumMessage::new(name, self.item, msg_ty, &self.generics, &self.custom).emit() + fn emit_msg(&self, msg_ty: MsgType) -> TokenStream { + let where_clause = &self.associated_types.as_where_clause(); + let associated_names = self.associated_types.as_names(); + let variants = MsgVariants::new( + self.item.as_variants(), + msg_ty, + &associated_names, + where_clause, + ); + + EnumMessage::new( + self.item, + msg_ty, + &self.custom, + variants, + &self.associated_types, + ) + .emit() } } @@ -182,7 +207,7 @@ impl<'a> ImplInput<'a> { ) .emit_querier(); let messages = self.emit_messages(); - let remote = Remote::new(&self.interfaces).emit(); + let remote = Remote::new(&self.interfaces, &Default::default()).emit(); let querier_from_impl = self.interfaces.emit_querier_from_impl(); let contract_api = ContractApi::new(item, generics, custom, interfaces).emit(); diff --git a/sylvia-derive/src/lib.rs b/sylvia-derive/src/lib.rs index a5c6ca8e..ebe0a7b0 100644 --- a/sylvia-derive/src/lib.rs +++ b/sylvia-derive/src/lib.rs @@ -6,15 +6,18 @@ use quote::quote; use syn::fold::Fold; use syn::{parse2, parse_quote, ItemImpl, ItemTrait, Path}; +mod associated_types; pub(crate) mod check_generics; mod input; mod interfaces; mod message; mod multitest; mod parser; +mod querier; mod remote; mod strip_generics; mod strip_input; +mod strip_self_path; mod utils; mod variant_descs; diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index a683edaa..c2d8ad9e 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -1,3 +1,4 @@ +use crate::associated_types::AssociatedTypes; use crate::check_generics::{CheckGenerics, GetPath}; use crate::crate_module; use crate::interfaces::Interfaces; @@ -6,6 +7,7 @@ use crate::parser::{ Custom, EntryPointArgs, MsgAttr, MsgType, OverrideEntryPoints, }; use crate::strip_generics::StripGenerics; +use crate::strip_self_path::StripSelfPath; use crate::utils::{ as_where_clause, emit_bracketed_generics, extract_return_type, filter_generics, filter_wheres, process_fields, @@ -23,7 +25,7 @@ use syn::spanned::Spanned; use syn::visit::Visit; use syn::{ parse_quote, Attribute, GenericArgument, GenericParam, Ident, ItemImpl, ItemTrait, Pat, - PatType, Path, ReturnType, Signature, Token, TraitItem, Type, WhereClause, WherePredicate, + PatType, Path, ReturnType, Signature, Token, Type, WhereClause, WherePredicate, }; /// Representation of single struct message @@ -111,7 +113,7 @@ impl<'a> StructMessage<'a> { let fields_names: Vec<_> = fields.iter().map(MsgField::name).collect(); let parameters = fields.iter().map(|field| { let name = field.name; - let ty = field.ty; + let ty = &field.ty; quote! { #name : #ty} }); let fields = fields.iter().map(MsgField::emit); @@ -149,14 +151,9 @@ impl<'a> StructMessage<'a> { /// Representation of single enum message pub struct EnumMessage<'a> { - name: &'a Ident, - trait_name: &'a Ident, - variants: Vec>, - generics: Vec<&'a GenericParam>, - unused_generics: Vec<&'a GenericParam>, - all_generics: &'a [&'a GenericParam], - wheres: Vec<&'a WherePredicate>, - full_where: Option<&'a WhereClause>, + source: &'a ItemTrait, + variants: MsgVariants<'a, Ident>, + associated_types: &'a AssociatedTypes<'a>, msg_ty: MsgType, resp_type: Type, query_type: Type, @@ -164,39 +161,12 @@ pub struct EnumMessage<'a> { impl<'a> EnumMessage<'a> { pub fn new( - name: &'a Ident, source: &'a ItemTrait, - ty: MsgType, - generics: &'a [&'a GenericParam], + msg_ty: MsgType, custom: &'a Custom, + variants: MsgVariants<'a, Ident>, + associated_types: &'a AssociatedTypes<'a>, ) -> Self { - let trait_name = &source.ident; - - let mut generics_checker = CheckGenerics::new(generics); - let variants: Vec<_> = source - .items - .iter() - .filter_map(|item| match item { - TraitItem::Method(method) => { - let msg_attr = method.attrs.iter().find(|attr| attr.path.is_ident("msg"))?; - let attr = match MsgAttr::parse.parse2(msg_attr.tokens.clone()) { - Ok(attr) => attr, - Err(err) => { - emit_error!(method.span(), err); - return None; - } - }; - - if attr == ty { - Some(MsgVariant::new(&method.sig, &mut generics_checker, attr)) - } else { - None - } - } - _ => None, - }) - .collect(); - let associated_exec = parse_associated_custom_type(source, "ExecC"); let associated_query = parse_associated_custom_type(source, "QueryC"); @@ -210,136 +180,79 @@ impl<'a> EnumMessage<'a> { .or(associated_query) .unwrap_or_else(Custom::default_type); - let (used_generics, unused_generics) = generics_checker.used_unused(); - let wheres = filter_wheres(&source.generics.where_clause, generics, &used_generics); - Self { - name, - trait_name, + source, variants, - generics: used_generics, - unused_generics, - all_generics: generics, - wheres, - full_where: source.generics.where_clause.as_ref(), - msg_ty: ty, + associated_types, + msg_ty, resp_type, query_type, } } pub fn emit(&self) -> TokenStream { - let sylvia = crate_module(); - let Self { - name, - trait_name, + source, variants, - generics, - unused_generics, - all_generics, - wheres, - full_where, + associated_types, msg_ty, resp_type, query_type, } = self; - let match_arms = variants.iter().map(|variant| variant.emit_dispatch_leg()); - let mut msgs: Vec = variants - .iter() - .map(|var| var.name.to_string().to_case(Case::Snake)) - .collect(); + let trait_name = &source.ident; + let enum_name = msg_ty.emit_msg_name(false); + let unique_enum_name = + Ident::new(&format!("{}{}", trait_name, enum_name), enum_name.span()); + + let match_arms = variants.emit_dispatch_legs(); + let mut msgs = variants.as_names_snake_cased(); msgs.sort(); let msgs_cnt = msgs.len(); - let variants_constructors = variants.iter().map(MsgVariant::emit_variants_constructors); - let variants = variants.iter().map(MsgVariant::emit); - let where_clause = if !wheres.is_empty() { - quote! { - where #(#wheres,)* - } - } else { - quote! {} - }; + let variants_constructors = variants.emit_constructors(); + let msg_variants = variants.emit(); let ctx_type = msg_ty.emit_ctx_type(query_type); - let dispatch_type = msg_ty.emit_result_type(resp_type, &parse_quote!(C::Error)); + let dispatch_type = msg_ty.emit_result_type(resp_type, &parse_quote!(ContractT::Error)); - let all_generics = emit_bracketed_generics(all_generics); - let phantom = if generics.is_empty() { - quote! {} - } else if MsgType::Query == *msg_ty { - quote! { - #[serde(skip)] - #[returns((#(#generics,)*))] - _Phantom(std::marker::PhantomData<( #(#generics,)* )>), - } - } else { - quote! { - #[serde(skip)] - _Phantom(std::marker::PhantomData<( #(#generics,)* )>), - } - }; - - let match_arms = if !generics.is_empty() { - quote! { - #(#match_arms,)* - _Phantom(_) => Err(#sylvia ::cw_std::StdError::generic_err("Phantom message should not be constructed.")).map_err(Into::into), - } - } else { - quote! { - #(#match_arms,)* - } - }; - - let generics = emit_bracketed_generics(generics); + let used_generics = variants.used_generics(); + let unused_generics = variants.unused_generics(); + let where_predicates = associated_types.as_where_predicates(); + let where_clause = variants.where_clause(); + let contract_predicate = associated_types.emit_contract_predicate(trait_name); - let unique_enum_name = Ident::new(&format!("{}{}", trait_name, name), name.span()); + let phantom_variant = variants.emit_phantom_variant(); + let phatom_match_arm = variants.emit_phantom_match_arm(); + let bracketed_used_generics = emit_bracketed_generics(used_generics); let ep_name = msg_ty.emit_ep_name(); - let messages_fn_name = Ident::new(&format!("{}_messages", ep_name), name.span()); - - #[cfg(not(tarpaulin_include))] - let enum_declaration = match name.to_string().as_str() { - "QueryMsg" => { - quote! { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema, cosmwasm_schema::QueryResponses)] - #[serde(rename_all="snake_case")] - pub enum #unique_enum_name #generics { - #(#variants,)* - #phantom - } - pub type #name #generics = #unique_enum_name #generics; - } - } - _ => { - quote! { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)] - #[serde(rename_all="snake_case")] - pub enum #unique_enum_name #generics { - #(#variants,)* - #phantom - } - pub type #name #generics = #unique_enum_name #generics; - } - } - }; + let messages_fn_name = Ident::new(&format!("{}_messages", ep_name), enum_name.span()); + let derive_call = msg_ty.emit_derive_call(); #[cfg(not(tarpaulin_include))] { quote! { - #enum_declaration - - impl #generics #unique_enum_name #generics #where_clause { - pub fn dispatch(self, contract: &C, ctx: #ctx_type) - -> #dispatch_type #full_where + #[allow(clippy::derive_partial_eq_without_eq)] + #derive_call + #[serde(rename_all="snake_case")] + pub enum #unique_enum_name #bracketed_used_generics { + #(#msg_variants,)* + #phantom_variant + } + pub type #enum_name #bracketed_used_generics = #unique_enum_name #bracketed_used_generics; + + impl #bracketed_used_generics #unique_enum_name #bracketed_used_generics #where_clause { + pub fn dispatch(self, contract: &ContractT, ctx: #ctx_type) + -> #dispatch_type + where + #(#where_predicates,)* + #contract_predicate { use #unique_enum_name::*; match self { - #match_arms + #(#match_arms,)* + #phatom_match_arm } } #(#variants_constructors)* @@ -470,6 +383,7 @@ pub struct MsgVariant<'a> { // `MsgField<'a>` fields: Vec>, return_type: TokenStream, + stripped_return_type: TokenStream, msg_type: MsgType, } @@ -492,20 +406,22 @@ impl<'a> MsgVariant<'a> { let fields = process_fields(sig, generics_checker); let msg_type = msg_attr.msg_type(); - let return_type = if let MsgAttr::Query { resp_type } = msg_attr { + let (return_type, stripped_return_type) = if let MsgAttr::Query { resp_type } = msg_attr { match resp_type { Some(resp_type) => { - generics_checker.visit_path(&parse_quote! { #resp_type }); - quote! {#resp_type} + let resp_type = parse_quote! { #resp_type }; + generics_checker.visit_path(&resp_type); + (quote! { #resp_type }, quote! { #resp_type }) } None => { let return_type = extract_return_type(&sig.output); - generics_checker.visit_path(return_type); - quote! {#return_type} + let stripped_return_type = StripSelfPath.fold_path(return_type.clone()); + generics_checker.visit_path(&stripped_return_type); + (quote! { #return_type }, quote! { #stripped_return_type }) } } } else { - quote! {} + (quote! {}, quote! {}) }; Self { @@ -513,33 +429,32 @@ impl<'a> MsgVariant<'a> { function_name, fields, return_type, + stripped_return_type, msg_type, } } /// Emits message variant pub fn emit(&self) -> TokenStream { - let Self { name, fields, .. } = self; + let Self { + name, + fields, + msg_type, + stripped_return_type, + .. + } = self; let fields = fields.iter().map(MsgField::emit); - let return_type = &self.return_type; + let returns_attribute = match msg_type { + MsgType::Query => quote! { #[returns(#stripped_return_type)] }, + _ => quote! {}, + }; - if self.msg_type == MsgType::Query { - #[cfg(not(tarpaulin_include))] - { - quote! { - #[returns(#return_type)] - #name { - #(#fields,)* - } - } - } - } else { - #[cfg(not(tarpaulin_include))] - { - quote! { - #name { - #(#fields,)* - } + #[cfg(not(tarpaulin_include))] + { + quote! { + #returns_attribute + #name { + #(#fields,)* } } } @@ -597,13 +512,8 @@ impl<'a> MsgVariant<'a> { let method_name = name.to_string().to_case(Case::Snake); let method_name = Ident::new(&method_name, name.span()); - - let parameters = fields.iter().map(|field| { - let name = field.name; - let ty = field.ty; - quote! { #name : #ty} - }); - let arguments = fields.iter().map(|field| field.name); + let parameters = fields.iter().map(MsgField::emit_method_field); + let arguments = fields.iter().map(MsgField::name); quote! { pub fn #method_name( #(#parameters),*) -> Self { @@ -612,10 +522,35 @@ impl<'a> MsgVariant<'a> { } } + pub fn emit_trait_querier_impl(&self, associated_name: &[&Ident]) -> TokenStream { + let sylvia = crate_module(); + let Self { + name, + fields, + return_type, + .. + } = self; + + let parameters = fields.iter().map(MsgField::emit_method_field); + let fields_names = fields.iter().map(MsgField::name); + let variant_name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); + let bracketed_generics = emit_bracketed_generics(associated_name); + + #[cfg(not(tarpaulin_include))] + { + quote! { + fn #variant_name(&self, #(#parameters),*) -> Result< #return_type, #sylvia:: cw_std::StdError> { + let query = ::Query:: #variant_name (#(#fields_names),*); + self.querier().query_wasm_smart(self.contract(), &query) + } + } + } + } + pub fn emit_querier_impl( &self, trait_module: Option<&Path>, - unbonded_generics: &Vec<&Generic>, + generics: &Vec<&Generic>, ) -> TokenStream { let sylvia = crate_module(); let Self { @@ -637,8 +572,8 @@ impl<'a> MsgVariant<'a> { .map(|module| quote! { #module ::sv::QueryMsg }) .unwrap_or_else(|| quote! { QueryMsg }); - let msg = if !unbonded_generics.is_empty() { - quote! { #msg ::< #(#unbonded_generics,)* > } + let msg = if !generics.is_empty() { + quote! { #msg ::< #(#generics,)* > } } else { quote! { #msg } }; @@ -663,7 +598,7 @@ impl<'a> MsgVariant<'a> { .. } = self; - let parameters = fields.iter().map(MsgField::emit_method_field); + let parameters = fields.iter().map(MsgField::emit_querier_method_field); let variant_name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); #[cfg(not(tarpaulin_include))] @@ -689,7 +624,7 @@ impl<'a> MsgVariant<'a> { let Self { name, fields, - return_type, + stripped_return_type, .. } = self; @@ -721,7 +656,7 @@ impl<'a> MsgVariant<'a> { } }, MsgType::Query => quote! { - pub fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { + pub fn #name (&self, #(#params,)* ) -> Result<#stripped_return_type, #error_type> { let msg = #enum_name :: #name ( #(#arguments),* ); (*self.app) @@ -751,7 +686,7 @@ impl<'a> MsgVariant<'a> { let Self { name, fields, - return_type, + stripped_return_type, .. } = self; @@ -773,7 +708,7 @@ impl<'a> MsgVariant<'a> { } }, MsgType::Query => quote! { - fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { + fn #name (&self, #(#params,)* ) -> Result<#stripped_return_type, #error_type> { let msg = #interface_enum :: #type_name :: #name ( #(#arguments),* ); (*self.app) @@ -787,6 +722,10 @@ impl<'a> MsgVariant<'a> { } } + pub fn is_of_type(&self, msg_type: MsgType) -> bool { + self.msg_type == msg_type + } + pub fn emit_proxy_methods_declarations( &self, msg_ty: &MsgType, @@ -798,7 +737,7 @@ impl<'a> MsgVariant<'a> { let Self { name, fields, - return_type, + stripped_return_type, .. } = self; @@ -811,7 +750,7 @@ impl<'a> MsgVariant<'a> { fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #interface_enum :: #type_name, MtApp, #custom_msg>; }, MsgType::Query => quote! { - fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type>; + fn #name (&self, #(#params,)* ) -> Result<#stripped_return_type, #error_type>; }, _ => quote! {}, } @@ -934,6 +873,7 @@ where pub struct BoundQuerier<'a, C: #sylvia ::cw_std::CustomQuery> { contract: &'a #sylvia ::cw_std::Addr, querier: &'a #sylvia ::cw_std::QuerierWrapper<'a, C>, + _phantom: std::marker::PhantomData<()>, } impl<'a, C: #sylvia ::cw_std::CustomQuery> BoundQuerier<'a, C> { @@ -946,7 +886,7 @@ where } pub fn borrowed(contract: &'a #sylvia ::cw_std::Addr, querier: &'a #sylvia ::cw_std::QuerierWrapper<'a, C>) -> Self { - Self {contract, querier} + Self {contract, querier, _phantom: std::marker::PhantomData} } } @@ -1102,6 +1042,17 @@ where .collect() } + pub fn emit_phantom_match_arm(&self) -> TokenStream { + let sylvia = crate_module(); + let Self { used_generics, .. } = self; + if used_generics.is_empty() { + return quote! {}; + } + quote! { + _Phantom(_) => Err(#sylvia ::cw_std::StdError::generic_err("Phantom message should not be constructed.")).map_err(Into::into), + } + } + pub fn emit_dispatch_legs(&self) -> impl Iterator + '_ { self.variants .iter() @@ -1128,6 +1079,29 @@ where pub fn get_only_variant(&self) -> Option<&MsgVariant> { self.variants.first() } + + pub fn emit_phantom_variant(&self) -> TokenStream { + let Self { + msg_ty, + used_generics, + .. + } = self; + + if used_generics.is_empty() { + return quote! {}; + } + + let return_attr = match msg_ty { + MsgType::Query => quote! { #[returns((#(#used_generics,)*))] }, + _ => quote! {}, + }; + + quote! { + #[serde(skip)] + #return_attr + _Phantom(std::marker::PhantomData<( #(#used_generics,)* )>), + } + } } /// Representation of single message variant field @@ -1135,6 +1109,7 @@ where pub struct MsgField<'a> { name: &'a Ident, ty: &'a Type, + stripped_ty: Type, attrs: &'a Vec, } @@ -1172,20 +1147,43 @@ impl<'a> MsgField<'a> { }?; let ty = &item.ty; + let stripped_ty = StripSelfPath.fold_type((*item.ty).clone()); let attrs = &item.attrs; - generics_checker.visit_type(ty); + generics_checker.visit_type(&stripped_ty); - Some(Self { name, ty, attrs }) + Some(Self { + name, + ty, + stripped_ty, + attrs, + }) } /// Emits message field pub fn emit(&self) -> TokenStream { - let Self { name, ty, attrs } = self; + let Self { + name, + stripped_ty, + attrs, + .. + } = self; #[cfg(not(tarpaulin_include))] { quote! { #(#attrs)* + #name: #stripped_ty + } + } + } + + /// Emits message field + pub fn emit_querier_method_field(&self) -> TokenStream { + let Self { name, ty, .. } = self; + + #[cfg(not(tarpaulin_include))] + { + quote! { #name: #ty } } @@ -1193,12 +1191,14 @@ impl<'a> MsgField<'a> { /// Emits method field pub fn emit_method_field(&self) -> TokenStream { - let Self { name, ty, .. } = self; + let Self { + name, stripped_ty, .. + } = self; #[cfg(not(tarpaulin_include))] { quote! { - #name: #ty + #name: #stripped_ty } } } @@ -1567,38 +1567,20 @@ impl<'a> ContractApi<'a> { pub struct InterfaceApi<'a> { source: &'a ItemTrait, - exec_variants: MsgVariants<'a, GenericParam>, - query_variants: MsgVariants<'a, GenericParam>, - generics: &'a [&'a GenericParam], custom: &'a Custom<'a>, + associated_types: &'a AssociatedTypes<'a>, } impl<'a> InterfaceApi<'a> { pub fn new( source: &'a ItemTrait, - generics: &'a [&'a GenericParam], + associated_types: &'a AssociatedTypes<'a>, custom: &'a Custom<'a>, ) -> Self { - let exec_variants = MsgVariants::new( - source.as_variants(), - MsgType::Exec, - generics, - &source.generics.where_clause, - ); - - let query_variants = MsgVariants::new( - source.as_variants(), - MsgType::Query, - generics, - &source.generics.where_clause, - ); - Self { source, - exec_variants, - query_variants, - generics, custom, + associated_types, } } @@ -1606,18 +1588,31 @@ impl<'a> InterfaceApi<'a> { let sylvia = crate_module(); let Self { source, - exec_variants, - query_variants, - generics, custom, + associated_types, } = self; - let where_clause = &source.generics.where_clause; + let generics = associated_types.as_names(); + let exec_variants = MsgVariants::new( + source.as_variants(), + MsgType::Exec, + &generics, + &source.generics.where_clause, + ); + + let query_variants = MsgVariants::new( + source.as_variants(), + MsgType::Query, + &generics, + &source.generics.where_clause, + ); + + let where_clause = &self.associated_types.as_where_clause(); let custom_query = custom.query_or_default(); let exec_generics = &exec_variants.used_generics; let query_generics = &query_variants.used_generics; - let bracket_generics = emit_bracketed_generics(generics); + let bracket_generics = emit_bracketed_generics(&generics); let exec_bracketed_generics = emit_bracketed_generics(exec_generics); let query_bracketed_generics = emit_bracketed_generics(query_generics); @@ -1637,7 +1632,7 @@ impl<'a> InterfaceApi<'a> { impl #bracket_generics #sylvia ::types::InterfaceApi for Api #bracket_generics #where_clause { type Exec = ExecMsg #exec_bracketed_generics; type Query = QueryMsg #query_bracketed_generics; - type Querier<'querier> = BoundQuerier<'querier, #custom_query >; + type Querier<'querier> = BoundQuerier<'querier, #custom_query, #(#generics,)* >; } } } diff --git a/sylvia-derive/src/parser.rs b/sylvia-derive/src/parser.rs index 90128572..f2dd929f 100644 --- a/sylvia-derive/src/parser.rs +++ b/sylvia-derive/src/parser.rs @@ -193,7 +193,7 @@ impl MsgType { } } - pub fn emit_msg_name(&self, is_wrapper: bool) -> Type { + pub fn emit_msg_name(&self, is_wrapper: bool) -> Ident { match self { MsgType::Exec if is_wrapper => parse_quote! { ContractExecMsg }, MsgType::Query if is_wrapper => parse_quote! { ContractQueryMsg }, @@ -233,6 +233,18 @@ impl MsgType { }, } } + + pub fn emit_derive_call(&self) -> TokenStream { + let sylvia = crate_module(); + match self { + MsgType::Query => quote! { + #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema, #sylvia:: cw_schema::QueryResponses)] + }, + _ => quote! { + #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)] + }, + } + } } impl PartialEq for MsgAttr { @@ -482,7 +494,7 @@ pub fn parse_associated_custom_type(source: &ItemTrait, type_name: &str) -> Opti source.items.iter().find_map(|item| match item { TraitItem::Type(ty) if ty.ident == type_name => { let type_name = Ident::new(type_name, ty.span()); - Some(parse_quote! { :: #type_name}) + Some(parse_quote! { :: #type_name}) } _ => None, }) diff --git a/sylvia-derive/src/querier.rs b/sylvia-derive/src/querier.rs new file mode 100644 index 00000000..a4f3cc68 --- /dev/null +++ b/sylvia-derive/src/querier.rs @@ -0,0 +1,92 @@ +use proc_macro2::TokenStream; +use quote::quote; +use quote::ToTokens; + +use crate::associated_types::AssociatedTypes; +use crate::check_generics::GetPath; +use crate::crate_module; +use crate::message::{MsgVariant, MsgVariants}; +use crate::parser::MsgType; + +pub struct Querier<'a, Generic> { + variants: &'a MsgVariants<'a, Generic>, + associated_types: &'a AssociatedTypes<'a>, +} + +impl<'a, Generic> Querier<'a, Generic> +where + Generic: GetPath + PartialEq + ToTokens, +{ + pub fn new( + variants: &'a MsgVariants<'a, Generic>, + associated_types: &'a AssociatedTypes, + ) -> Self { + Self { + variants, + associated_types, + } + } + + pub fn emit_trait_querier(&self) -> TokenStream { + let sylvia = crate_module(); + let Self { + variants, + associated_types, + .. + } = self; + + let methods_impl = variants + .variants() + .iter() + .filter(|variant| variant.is_of_type(MsgType::Query)) + .map(|variant| variant.emit_trait_querier_impl(&associated_types.as_names())); + + let methods_declaration = variants + .variants() + .iter() + .filter(|variant| variant.is_of_type(MsgType::Query)) + .map(MsgVariant::emit_querier_declaration); + + let generics = associated_types.as_names(); + let types_declaration = associated_types.as_types_declaration(); + let types_definition = associated_types.emit_types_definition(); + let where_clause = associated_types.as_where_clause(); + + #[cfg(not(tarpaulin_include))] + { + quote! { + pub struct BoundQuerier<'a, C: #sylvia ::cw_std::CustomQuery, #(#generics,)* > { + contract: &'a #sylvia ::cw_std::Addr, + querier: &'a #sylvia ::cw_std::QuerierWrapper<'a, C>, + _phantom: std::marker::PhantomData<( #(#generics,)* )>, + } + + impl<'a, C: #sylvia ::cw_std::CustomQuery, #(#generics,)* > BoundQuerier<'a, C, #(#generics,)* > #where_clause { + pub fn querier(&self) -> &'a #sylvia ::cw_std::QuerierWrapper<'a, C> { + self.querier + } + + pub fn contract(&self) -> &'a #sylvia ::cw_std::Addr { + self.contract + } + + pub fn borrowed(contract: &'a #sylvia ::cw_std::Addr, querier: &'a #sylvia ::cw_std::QuerierWrapper<'a, C>) -> Self { + Self { contract, querier, _phantom: std::marker::PhantomData } + } + } + + impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#generics,)* > Querier for BoundQuerier<'a, C, #(#generics,)* > #where_clause { + #(#types_definition)* + + #(#methods_impl)* + } + + pub trait Querier { + #(#types_declaration)* + + #(#methods_declaration)* + } + } + } + } +} diff --git a/sylvia-derive/src/remote.rs b/sylvia-derive/src/remote.rs index 76fcdf1a..4097b10f 100644 --- a/sylvia-derive/src/remote.rs +++ b/sylvia-derive/src/remote.rs @@ -1,21 +1,27 @@ use proc_macro2::TokenStream; use quote::quote; +use crate::associated_types::AssociatedTypes; use crate::crate_module; use crate::interfaces::Interfaces; use crate::parser::ContractMessageAttr; pub struct Remote<'a> { interfaces: &'a Interfaces, + associated_types: &'a AssociatedTypes<'a>, } impl<'a> Remote<'a> { - pub fn new(interfaces: &'a Interfaces) -> Self { - Self { interfaces } + pub fn new(interfaces: &'a Interfaces, associated_types: &'a AssociatedTypes<'a>) -> Self { + Self { + interfaces, + associated_types, + } } pub fn emit(&self) -> TokenStream { let sylvia = crate_module(); + let generics = self.associated_types.as_names(); let from_implementations = self.interfaces.interfaces().iter().map(|interface| { let ContractMessageAttr { module, .. } = interface; @@ -31,33 +37,36 @@ impl<'a> Remote<'a> { quote! { #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)] - pub struct Remote<'a>(std::borrow::Cow<'a, #sylvia ::cw_std::Addr>); + pub struct Remote<'a, #(#generics,)* > { + addr: std::borrow::Cow<'a, #sylvia ::cw_std::Addr>, + #[serde(skip)] + _phantom: std::marker::PhantomData<( #(#generics,)* )>, + } - impl<'a> Remote<'a> { + impl<'a, #(#generics,)* > Remote<'a, #(#generics,)* > { pub fn new(addr: #sylvia ::cw_std::Addr) -> Self { - Self(std::borrow::Cow::Owned(addr)) + Self{addr: std::borrow::Cow::Owned(addr), _phantom: std::marker::PhantomData} } pub fn borrowed(addr: &'a #sylvia ::cw_std::Addr) -> Self { - Self(std::borrow::Cow::Borrowed(addr)) + Self{addr: std::borrow::Cow::Borrowed(addr), _phantom: std::marker::PhantomData} } - } - impl<'a> AsRef<#sylvia ::cw_std::Addr> for Remote<'a> { - fn as_ref(&self) -> &#sylvia ::cw_std::Addr { - &self.0 - } - } - - impl Remote<'_> { - pub fn querier<'a, C: #sylvia ::cw_std::CustomQuery>(&'a self, querier: &'a #sylvia ::cw_std::QuerierWrapper<'a, C>) -> BoundQuerier<'a, C> { + pub fn querier(&'a self, querier: &'a #sylvia ::cw_std::QuerierWrapper<'a, C>) -> BoundQuerier<'a, C, #(#generics,)* > { BoundQuerier { - contract: &self.0, + contract: &self.addr, querier, + _phantom: std::marker::PhantomData, } } } + impl<'a, #(#generics,)* > AsRef<#sylvia ::cw_std::Addr> for Remote<'a, #(#generics,)* > { + fn as_ref(&self) -> &#sylvia ::cw_std::Addr { + &self.addr + } + } + #(#from_implementations)* } } diff --git a/sylvia-derive/src/strip_self_path.rs b/sylvia-derive/src/strip_self_path.rs new file mode 100644 index 00000000..b31ba5e5 --- /dev/null +++ b/sylvia-derive/src/strip_self_path.rs @@ -0,0 +1,15 @@ +use syn::fold::Fold; +use syn::Path; + +pub struct StripSelfPath; + +impl Fold for StripSelfPath { + fn fold_path(&mut self, path: Path) -> Path { + let segments = path + .segments + .into_iter() + .filter(|segment| segment.ident != "Self") + .collect(); + syn::fold::fold_path(self, Path { segments, ..path }) + } +} diff --git a/sylvia-derive/src/utils.rs b/sylvia-derive/src/utils.rs index 57bfe295..4b6a1542 100644 --- a/sylvia-derive/src/utils.rs +++ b/sylvia-derive/src/utils.rs @@ -120,7 +120,7 @@ pub fn as_where_clause(where_predicates: &[&WherePredicate]) -> Option(unbonded_generics: &[&Generic]) -> TokenStream { +pub fn emit_bracketed_generics(unbonded_generics: &[impl ToTokens]) -> TokenStream { match unbonded_generics.is_empty() { true => quote! {}, false => quote! { < #(#unbonded_generics,)* > },