From 9279be00a717a79ef9f69f50ae9f2be9d2b3ce46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Wed, 18 Oct 2023 12:29:45 +0200 Subject: [PATCH] feat: Support generics on every message type --- sylvia-derive/src/input.rs | 54 ++++-------- sylvia-derive/src/message.rs | 155 +++++++++++++++++++-------------- sylvia-derive/src/multitest.rs | 69 +++++++++------ sylvia-derive/src/parser.rs | 24 ++--- sylvia/tests/generics.rs | 72 ++++++++++++--- 5 files changed, 216 insertions(+), 158 deletions(-) diff --git a/sylvia-derive/src/input.rs b/sylvia-derive/src/input.rs index 8f9f2400..74061b64 100644 --- a/sylvia-derive/src/input.rs +++ b/sylvia-derive/src/input.rs @@ -185,16 +185,16 @@ impl<'a> ImplInput<'a> { let Self { item, generics, .. } = self; let multitest_helpers = self.emit_multitest_helpers(generics); let where_clause = &item.generics.where_clause; - let variants = MsgVariants::new( + + let querier = MsgVariants::new( self.item.as_variants(), MsgType::Query, generics, where_clause, - ); - - let messages = self.emit_messages(&variants); + ) + .emit_querier(); + let messages = self.emit_messages(); let remote = Remote::new(&self.interfaces).emit(); - let querier = variants.emit_querier(); let querier_from_impl = self.interfaces.emit_querier_from_impl(); #[cfg(not(tarpaulin_include))] @@ -213,23 +213,13 @@ impl<'a> ImplInput<'a> { } } - fn emit_messages(&self, variants: &MsgVariants) -> TokenStream { + fn emit_messages(&self) -> TokenStream { let instantiate = self.emit_struct_msg(MsgType::Instantiate); let migrate = self.emit_struct_msg(MsgType::Migrate); - let exec_impl = - self.emit_enum_msg(&Ident::new("ExecMsg", Span::mixed_site()), MsgType::Exec); - let query_impl = - self.emit_enum_msg(&Ident::new("QueryMsg", Span::mixed_site()), MsgType::Query); - let exec = self.emit_glue_msg( - &Ident::new("ExecMsg", Span::mixed_site()), - MsgType::Exec, - variants, - ); - let query = self.emit_glue_msg( - &Ident::new("QueryMsg", Span::mixed_site()), - MsgType::Query, - variants, - ); + let exec_impl = self.emit_enum_msg(MsgType::Exec); + let query_impl = self.emit_enum_msg(MsgType::Query); + let exec = self.emit_glue_msg(MsgType::Exec); + let query = self.emit_glue_msg(MsgType::Query); #[cfg(not(tarpaulin_include))] { @@ -254,26 +244,16 @@ impl<'a> ImplInput<'a> { .map_or(quote! {}, |msg| msg.emit()) } - fn emit_enum_msg(&self, name: &Ident, msg_ty: MsgType) -> TokenStream { - ContractEnumMessage::new( - name, - self.item, - msg_ty, - &self.generics, - &self.error, - &self.custom, - ) - .emit() + fn emit_enum_msg(&self, msg_ty: MsgType) -> TokenStream { + ContractEnumMessage::new(self.item, msg_ty, &self.generics, &self.error, &self.custom) + .emit() } - fn emit_glue_msg( - &self, - name: &Ident, - msg_ty: MsgType, - variants: &MsgVariants, - ) -> TokenStream { + fn emit_glue_msg(&self, msg_ty: MsgType) -> TokenStream { + let Self { generics, item, .. } = self; + let where_clause = &item.generics.where_clause; + let variants = MsgVariants::new(item.as_variants(), msg_ty, generics, where_clause); GlueMessage::new( - name, self.item, msg_ty, &self.error, diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 443b11c2..a067a7c3 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -348,37 +348,32 @@ impl<'a> EnumMessage<'a> { /// Representation of single enum message pub struct ContractEnumMessage<'a> { - name: &'a Ident, variants: MsgVariants<'a, GenericParam>, msg_ty: MsgType, contract: &'a Type, error: &'a Type, custom: &'a Custom<'a>, + where_clause: &'a Option, } impl<'a> ContractEnumMessage<'a> { pub fn new( - name: &'a Ident, source: &'a ItemImpl, msg_ty: MsgType, generics: &'a [&'a GenericParam], error: &'a Type, custom: &'a Custom, ) -> Self { - let variants = MsgVariants::new( - source.as_variants(), - msg_ty, - generics, - &source.generics.where_clause, - ); + let where_clause = &source.generics.where_clause; + let variants = MsgVariants::new(source.as_variants(), msg_ty, generics, where_clause); Self { - name, variants, msg_ty, contract: &source.self_ty, error, custom, + where_clause, } } @@ -386,18 +381,21 @@ impl<'a> ContractEnumMessage<'a> { let sylvia = crate_module(); let Self { - name, variants, msg_ty, contract, error, custom, + where_clause, + .. } = self; + let enum_name = msg_ty.emit_msg_name(false); let match_arms = variants.emit_dispatch_legs(); - let generic_name = variants.emit_generic_name(name); let unused_generics = variants.unused_generics(); let unused_generics = emit_bracketed_generics(unused_generics); + let used_generics = variants.used_generics(); + let used_generics = emit_bracketed_generics(used_generics); let mut variant_names = variants.as_names_snake_cased(); variant_names.sort(); @@ -419,13 +417,13 @@ impl<'a> ContractEnumMessage<'a> { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema, #derive_query )] #[serde(rename_all="snake_case")] - pub enum #generic_name { + pub enum #enum_name #used_generics { #(#variants,)* } - impl #generic_name { - pub fn dispatch #unused_generics (self, contract: &#contract, ctx: #ctx_type) -> #ret_type { - use #name::*; + impl #used_generics #enum_name #used_generics { + pub fn dispatch #unused_generics (self, contract: &#contract, ctx: #ctx_type) -> #ret_type #where_clause { + use #enum_name::*; match self { #(#match_arms,)* @@ -650,13 +648,17 @@ impl<'a> MsgVariant<'a> { } } - pub fn emit_multitest_proxy_methods( + pub fn emit_multitest_proxy_methods( &self, msg_ty: &MsgType, custom_msg: &Type, mt_app: &Type, error_type: &Type, - ) -> TokenStream { + generics: &[&Generic], + ) -> TokenStream + where + Generic: ToTokens, + { let sylvia = crate_module(); let Self { name, @@ -668,27 +670,33 @@ impl<'a> MsgVariant<'a> { let params = fields.iter().map(|field| field.emit_method_field()); let arguments = fields.iter().map(MsgField::name); let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span()); + let enum_name = msg_ty.emit_msg_name(false); + let enum_name: Type = if !generics.is_empty() { + parse_quote! { #enum_name ::< #(#generics,)* > } + } else { + parse_quote! { #enum_name } + }; match msg_ty { MsgType::Exec => quote! { #[track_caller] - pub fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, ExecMsg, #mt_app, #custom_msg> { - let msg = ExecMsg:: #name ( #(#arguments),* ); + pub fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::<#error_type, #enum_name, #mt_app, #custom_msg> { + let msg = #enum_name :: #name ( #(#arguments),* ); #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app) } }, MsgType::Migrate => quote! { #[track_caller] - pub fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::<#error_type, MigrateMsg, #mt_app, #custom_msg> { - let msg = MigrateMsg::new( #(#arguments),* ); + pub fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::<#error_type, #enum_name, #mt_app, #custom_msg> { + let msg = #enum_name ::new( #(#arguments),* ); #sylvia ::multitest::MigrateProxy::new(&self.contract_addr, msg, &self.app) } }, MsgType::Query => quote! { pub fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> { - let msg = QueryMsg:: #name ( #(#arguments),* ); + let msg = #enum_name :: #name ( #(#arguments),* ); (*self.app) .app() @@ -871,8 +879,13 @@ where &self.unused_generics } - pub fn where_predicates(&'a self) -> &'a [&'a WherePredicate] { - &self.where_predicates + pub fn as_where_clause(&'a self) -> Option { + let where_predicates = &self.where_predicates; + if !where_predicates.is_empty() { + Some(parse_quote!( where #(#where_predicates,)* )) + } else { + None + } } pub fn emit_querier(&self) -> TokenStream { @@ -980,10 +993,11 @@ where } = self; let values = msg_ty.emit_ctx_values(); - let msg_name = msg_ty.emit_msg_name(used_generics.as_slice()); + let msg_name = msg_ty.emit_msg_name(true); + let bracketed_generics = emit_bracketed_generics(used_generics); quote! { - #sylvia ::cw_std::from_slice::< #msg_name >(&msg)? + #sylvia ::cw_std::from_slice::< #msg_name #bracketed_generics >(&msg)? .dispatch(self, ( #values )) .map_err(Into::into) } @@ -1010,13 +1024,14 @@ where let params = msg_ty.emit_ctx_params(custom_query); let values = msg_ty.emit_ctx_values(); let ep_name = msg_ty.emit_ep_name(); - let msg_name = msg_ty.emit_msg_name(used_generics); + let msg_name = msg_ty.emit_msg_name(true); + let bracketed_generics = emit_bracketed_generics(used_generics); quote! { #[#sylvia ::cw_std::entry_point] pub fn #ep_name ( #params , - msg: #msg_name, + msg: #msg_name #bracketed_generics, ) -> Result<#resp_type, #error> { msg.dispatch(&#name ::new() , ( #values )).map_err(Into::into) } @@ -1031,7 +1046,13 @@ where self.variants .iter() .map(|variant| { - variant.emit_multitest_proxy_methods(&self.msg_ty, custom_msg, mt_app, error_type) + variant.emit_multitest_proxy_methods( + &self.msg_ty, + custom_msg, + mt_app, + error_type, + &self.used_generics, + ) }) .collect() } @@ -1100,17 +1121,6 @@ where .map(MsgVariant::emit_variants_constructors) } - pub fn emit_generic_name(&self, name: &Ident) -> TokenStream { - let generics = emit_bracketed_generics(&self.used_generics); - - #[cfg(not(tarpaulin_include))] - { - quote! { - #name #generics - } - } - } - pub fn emit(&self) -> impl Iterator + '_ { self.variants.iter().map(MsgVariant::emit) } @@ -1201,27 +1211,26 @@ impl<'a> MsgField<'a> { /// Glue message is the message composing Exec/Query messages from several traits #[derive(Debug)] pub struct GlueMessage<'a> { - name: &'a Ident, + source: &'a ItemImpl, contract: &'a Type, msg_ty: MsgType, error: &'a Type, custom: &'a Custom<'a>, interfaces: &'a Interfaces, - variants: &'a MsgVariants<'a, GenericParam>, + variants: MsgVariants<'a, GenericParam>, } impl<'a> GlueMessage<'a> { pub fn new( - name: &'a Ident, source: &'a ItemImpl, msg_ty: MsgType, error: &'a Type, custom: &'a Custom, interfaces: &'a Interfaces, - variants: &'a MsgVariants<'a, GenericParam>, + variants: MsgVariants<'a, GenericParam>, ) -> Self { GlueMessage { - name, + source, contract: &source.self_ty, msg_ty, error, @@ -1234,7 +1243,7 @@ impl<'a> GlueMessage<'a> { pub fn emit(&self) -> TokenStream { let sylvia = crate_module(); let Self { - name, + source, contract, msg_ty, error, @@ -1242,19 +1251,28 @@ impl<'a> GlueMessage<'a> { interfaces, variants, } = self; - let contract_name = StripGenerics.fold_type((*contract).clone()); - let enum_name = Ident::new(&format!("Contract{}", name), name.span()); + let used_generics = variants.used_generics(); - let used_generics = emit_bracketed_generics(used_generics); let unused_generics = variants.unused_generics(); + let where_clause = variants.as_where_clause(); + let full_where_clause = &source.generics.where_clause; + + let contract_enum_name = msg_ty.emit_msg_name(true); + let enum_name = msg_ty.emit_msg_name(false); + let contract_name = StripGenerics.fold_type((*contract).clone()); let unused_generics = emit_bracketed_generics(unused_generics); - let where_clause = variants.where_clause(); + let bracketed_used_generics = emit_bracketed_generics(used_generics); let variants = interfaces.emit_glue_message_variants(msg_ty); - let contract_variant = quote! { #contract_name ( #name ) }; + let contract_variant = quote! { #contract_name ( #enum_name #bracketed_used_generics ) }; let mut messages_call = interfaces.emit_messages_call(msg_ty); - messages_call.push(quote! { &#name :: messages() }); + let prefixed_used_generics = if !used_generics.is_empty() { + quote! { :: #bracketed_used_generics } + } else { + quote! {} + }; + messages_call.push(quote! { &#enum_name #prefixed_used_generics :: messages() }); let variants_cnt = messages_call.len(); @@ -1277,22 +1295,22 @@ impl<'a> GlueMessage<'a> { match (msg_ty, customs.has_msg) { (MsgType::Exec, true) => quote! { - #enum_name:: #variant(msg) => #sylvia ::into_response::IntoResponse::into_response(msg.dispatch(contract, Into::into( #ctx ))?) + #contract_enum_name:: #variant(msg) => #sylvia ::into_response::IntoResponse::into_response(msg.dispatch(contract, Into::into( #ctx ))?) }, _ => quote! { - #enum_name :: #variant(msg) => msg.dispatch(contract, Into::into( #ctx )) + #contract_enum_name :: #variant(msg) => msg.dispatch(contract, Into::into( #ctx )) }, } }); let dispatch_arm = - quote! {#enum_name :: #contract_name (msg) => msg.dispatch(contract, ctx)}; + quote! {#contract_enum_name :: #contract_name (msg) => msg.dispatch(contract, ctx)}; let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(msg_ty); #[cfg(not(tarpaulin_include))] let contract_deserialization_attempt = quote! { - let msgs = &#name :: messages(); + let msgs = &#enum_name #prefixed_used_generics :: messages(); if msgs.into_iter().any(|msg| msg == &recv_msg_name) { match val.deserialize_into() { Ok(msg) => return Ok(Self:: #contract_name (msg)), @@ -1305,15 +1323,16 @@ impl<'a> GlueMessage<'a> { let ret_type = msg_ty.emit_result_type(&custom.msg_or_default(), error); let mut response_schemas_calls = interfaces.emit_response_schemas_calls(msg_ty); - response_schemas_calls.push(quote! {#name :: response_schemas_impl()}); + response_schemas_calls + .push(quote! {#enum_name #prefixed_used_generics :: response_schemas_impl()}); - let response_schemas = match name.to_string().as_str() { - "QueryMsg" => { + let response_schemas = match msg_ty { + MsgType::Query => { #[cfg(not(tarpaulin_include))] { quote! { #[cfg(not(target_arch = "wasm32"))] - impl #sylvia ::cw_schema::QueryResponses for #enum_name { + impl #bracketed_used_generics #sylvia ::cw_schema::QueryResponses for #contract_enum_name #bracketed_used_generics #where_clause { fn response_schemas_impl() -> std::collections::BTreeMap { let responses = [#(#response_schemas_calls),*]; responses.into_iter().flatten().collect() @@ -1333,21 +1352,23 @@ impl<'a> GlueMessage<'a> { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(#sylvia ::serde::Serialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)] #[serde(rename_all="snake_case", untagged)] - pub enum #enum_name #used_generics { + pub enum #contract_enum_name #bracketed_used_generics { #(#variants,)* #contract_variant } - impl #used_generics #enum_name #used_generics { - pub fn dispatch #unused_generics #where_clause ( + impl #bracketed_used_generics #contract_enum_name #bracketed_used_generics { + pub fn dispatch #unused_generics ( self, contract: &#contract, ctx: #ctx_type, - ) -> #ret_type { - const _: () = { + ) -> #ret_type #full_where_clause { + const fn assert_no_intersection #bracketed_used_generics () #where_clause { let msgs: [&[&str]; #variants_cnt] = [#(#messages_call),*]; #sylvia ::utils::assert_no_intersection(msgs); - }; + } + + assert_no_intersection #prefixed_used_generics (); match self { #(#dispatch_arms,)* @@ -1358,7 +1379,7 @@ impl<'a> GlueMessage<'a> { #response_schemas - impl<'de> serde::Deserialize<'de> for #enum_name { + impl<'de, #(#used_generics,)* > serde::Deserialize<'de> for #contract_enum_name #bracketed_used_generics #where_clause { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { diff --git a/sylvia-derive/src/multitest.rs b/sylvia-derive/src/multitest.rs index dbec15a1..f958cd5c 100644 --- a/sylvia-derive/src/multitest.rs +++ b/sylvia-derive/src/multitest.rs @@ -38,6 +38,7 @@ pub struct MultitestHelpers<'a, Generics> { is_trait: bool, source: &'a ItemImpl, generics: &'a [&'a Generics], + where_clause: &'a Option, contract_name: &'a Ident, proxy_name: Ident, custom: &'a Custom<'a>, @@ -127,6 +128,7 @@ where is_trait, source, generics, + where_clause, contract_name, proxy_name, custom, @@ -150,6 +152,8 @@ where exec_variants, query_variants, migrate_variants, + generics, + where_clause, .. } = self; let sylvia = crate_module(); @@ -180,6 +184,9 @@ where query_variants.emit_multitest_proxy_methods(&custom_msg, &mt_app, error_type); let migrate_methods = migrate_variants.emit_multitest_proxy_methods(&custom_msg, &mt_app, error_type); + let where_predicates = where_clause + .as_ref() + .map(|where_clause| &where_clause.predicates); let contract_block = self.generate_contract_helpers(); @@ -195,13 +202,14 @@ where #[derive(Derivative)] #[derivative(Debug)] - pub struct #proxy_name <'app, MtApp> { + pub struct #proxy_name <'app, MtApp, #(#generics,)* > { pub contract_addr: #sylvia ::cw_std::Addr, #[derivative(Debug="ignore")] pub app: &'app #sylvia ::multitest::App, + _phantom: std::marker::PhantomData<( #(#generics,)* )>, } - impl<'app, BankT, ApiT, StorageT, CustomT, WasmT, StakingT, DistrT, IbcT, GovT> #proxy_name <'app, #mt_app > + impl<'app, BankT, ApiT, StorageT, CustomT, WasmT, StakingT, DistrT, IbcT, GovT, #(#generics,)* > #proxy_name <'app, #mt_app, #(#generics,)* > where CustomT: #sylvia ::cw_multi_test::Module, CustomT::ExecT: std::fmt::Debug @@ -219,10 +227,11 @@ where DistrT: #sylvia ::cw_multi_test::Distribution, IbcT: #sylvia ::cw_multi_test::Ibc, GovT: #sylvia ::cw_multi_test::Gov, - #mt_app : Executor< #custom_msg > + #mt_app : Executor< #custom_msg >, + #where_predicates { pub fn new(contract_addr: #sylvia ::cw_std::Addr, app: &'app #sylvia ::multitest::App< #mt_app >) -> Self { - #proxy_name{ contract_addr, app } + #proxy_name { contract_addr, app, _phantom: std::marker::PhantomData::default() } } #( #exec_methods )* @@ -231,12 +240,12 @@ where #( #proxy_accessors )* } - impl<'app, BankT, ApiT, StorageT, CustomT, WasmT, StakingT, DistrT, IbcT, GovT> + impl<'app, BankT, ApiT, StorageT, CustomT, WasmT, StakingT, DistrT, IbcT, GovT, #(#generics,)* > From<( #sylvia ::cw_std::Addr, &'app #sylvia ::multitest::App<#mt_app>, )> - for #proxy_name <'app, #mt_app > + for #proxy_name <'app, #mt_app, #(#generics,)* > where CustomT: #sylvia ::cw_multi_test::Module, CustomT::ExecT: std::fmt::Debug @@ -255,9 +264,10 @@ where IbcT: #sylvia ::cw_multi_test::Ibc, GovT: #sylvia ::cw_multi_test::Gov, #mt_app : Executor< #custom_msg >, + #where_predicates { fn from(input: (#sylvia ::cw_std::Addr, &'app #sylvia ::multitest::App< #mt_app >)) - -> #proxy_name<'app, #mt_app > { + -> #proxy_name<'app, #mt_app, #(#generics,)* > { #proxy_name::new(input.0, input.1) } } @@ -378,10 +388,10 @@ where fn generate_contract_helpers(&self) -> TokenStream { let sylvia = crate_module(); let Self { - source, error_type, is_trait, generics, + where_clause, contract_name, proxy_name, instantiate_variants, @@ -404,11 +414,10 @@ where let used_generics = instantiate_variants.used_generics(); let bracketed_used_generics = emit_bracketed_generics(used_generics); - let bracketed_generics = emit_bracketed_generics(generics); - let full_where_clause = &source.generics.where_clause; - let where_predicates = instantiate_variants.where_predicates(); - let where_clause = instantiate_variants.where_clause(); + let where_predicates = where_clause + .as_ref() + .map(|where_clause| &where_clause.predicates); let contract = if !generics.is_empty() { quote! { #contract_name ::< #(#generics,)* > } } else { @@ -446,12 +455,14 @@ where quote! { #impl_contract - pub struct CodeId<'app, MtApp> { + pub struct CodeId<'app, MtApp, #(#generics,)* > { code_id: u64, app: &'app #sylvia ::multitest::App, + _phantom: std::marker::PhantomData<( #(#generics,)* )>, + } - impl<'app, BankT, ApiT, StorageT, CustomT, StakingT, DistrT, IbcT, GovT> CodeId<'app, #mt_app> + impl<'app, BankT, ApiT, StorageT, CustomT, StakingT, DistrT, IbcT, GovT, #(#generics,)* > CodeId<'app, #mt_app, #(#generics,)* > where BankT: #sylvia ::cw_multi_test::Bank, ApiT: #sylvia ::cw_std::Api, @@ -461,23 +472,24 @@ where DistrT: #sylvia ::cw_multi_test::Distribution, IbcT: #sylvia ::cw_multi_test::Ibc, GovT: #sylvia ::cw_multi_test::Gov, + #where_predicates { - pub fn store_code #bracketed_generics (app: &'app #sylvia ::multitest::App< #mt_app >) -> Self #full_where_clause { + pub fn store_code(app: &'app #sylvia ::multitest::App< #mt_app >) -> Self { let code_id = app .app_mut() .store_code(Box::new(#contract ::new())); - Self { code_id, app } + Self { code_id, app, _phantom: std::marker::PhantomData::default() } } pub fn code_id(&self) -> u64 { self.code_id } - pub fn instantiate #bracketed_used_generics ( + pub fn instantiate( &self,#(#fields,)* - ) -> InstantiateProxy<'_, 'app, #mt_app, #(#used_generics,)* > #where_clause { + ) -> InstantiateProxy<'_, 'app, #mt_app, #(#generics,)* > { let msg = #instantiate_msg {#(#fields_names,)*}; - InstantiateProxy::<_, #(#used_generics,)* > { + InstantiateProxy::<_, #(#generics,)* > { code_id: self, funds: &[], label: "Contract", @@ -487,24 +499,24 @@ where } } - pub struct InstantiateProxy<'a, 'app, MtApp, #(#used_generics,)* > { - code_id: &'a CodeId <'app, MtApp>, - funds: &'a [#sylvia ::cw_std::Coin], - label: &'a str, + pub struct InstantiateProxy<'proxy, 'app, MtApp, #(#generics,)* > { + code_id: &'proxy CodeId <'app, MtApp, #(#generics,)* >, + funds: &'proxy [#sylvia ::cw_std::Coin], + label: &'proxy str, admin: Option, msg: InstantiateMsg #bracketed_used_generics, } - impl<'a, 'app, MtApp, #(#used_generics,)* > InstantiateProxy<'a, 'app, MtApp, #(#used_generics,)* > + impl<'proxy, 'app, MtApp, #(#generics,)* > InstantiateProxy<'proxy, 'app, MtApp, #(#generics,)* > where MtApp: Executor< #custom_msg >, - #(#where_predicates,)* + #where_predicates { - pub fn with_funds(self, funds: &'a [#sylvia ::cw_std::Coin]) -> Self { + pub fn with_funds(self, funds: &'proxy [#sylvia ::cw_std::Coin]) -> Self { Self { funds, ..self } } - pub fn with_label(self, label: &'a str) -> Self { + pub fn with_label(self, label: &'proxy str) -> Self { Self { label, ..self } } @@ -514,7 +526,7 @@ where } #[track_caller] - pub fn call(self, sender: &str) -> Result<#proxy_name<'app, MtApp>, #error_type> { + pub fn call(self, sender: &str) -> Result<#proxy_name<'app, MtApp, #(#generics,)* >, #error_type> { (*self.code_id.app) .app_mut() .instantiate_contract( @@ -529,6 +541,7 @@ where .map(|addr| #proxy_name { contract_addr: addr, app: self.code_id.app, + _phantom: std::marker::PhantomData::default(), }) } } diff --git a/sylvia-derive/src/parser.rs b/sylvia-derive/src/parser.rs index 51ca393c..864eb183 100644 --- a/sylvia-derive/src/parser.rs +++ b/sylvia-derive/src/parser.rs @@ -1,6 +1,6 @@ use proc_macro2::{Punct, TokenStream}; use proc_macro_error::emit_error; -use quote::{quote, ToTokens}; +use quote::quote; use syn::fold::Fold; use syn::parse::{Error, Nothing, Parse, ParseBuffer, ParseStream, Parser}; use syn::punctuated::Punctuated; @@ -145,21 +145,15 @@ impl MsgType { } } - pub fn emit_msg_name(&self, generics: &[&Generic]) -> Type - where - Generic: ToTokens, - { - let generics = if !generics.is_empty() { - quote! { ::< #(#generics,)* > } - } else { - quote! {} - }; + pub fn emit_msg_name(&self, is_wrapper: bool) -> Type { match self { - MsgType::Exec => parse_quote! { ContractExecMsg #generics }, - MsgType::Query => parse_quote! { ContractQueryMsg #generics }, - MsgType::Instantiate => parse_quote! { InstantiateMsg #generics }, - MsgType::Migrate => parse_quote! { MigrateMsg #generics }, - MsgType::Reply => parse_quote! { ReplyMsg #generics }, + MsgType::Exec if is_wrapper => parse_quote! { ContractExecMsg }, + MsgType::Query if is_wrapper => parse_quote! { ContractQueryMsg }, + MsgType::Exec => parse_quote! { ExecMsg }, + MsgType::Query => parse_quote! { QueryMsg }, + MsgType::Instantiate => parse_quote! { InstantiateMsg }, + MsgType::Migrate => parse_quote! { MigrateMsg }, + MsgType::Reply => parse_quote! { ReplyMsg }, MsgType::Sudo => todo!(), } } diff --git a/sylvia/tests/generics.rs b/sylvia/tests/generics.rs index a386124e..5b153111 100644 --- a/sylvia/tests/generics.rs +++ b/sylvia/tests/generics.rs @@ -80,26 +80,63 @@ pub mod non_generic { } pub mod generic_contract { - use cosmwasm_std::{CustomQuery, Response, StdResult}; + use cosmwasm_std::{Reply, Response, StdResult}; use serde::de::DeserializeOwned; use serde::Deserialize; - use sylvia::types::{CustomMsg, InstantiateCtx}; + use sylvia::types::{CustomMsg, ExecCtx, InstantiateCtx, MigrateCtx, QueryCtx, ReplyCtx}; use sylvia_derive::contract; - pub struct GenericContract(std::marker::PhantomData<(Msg, QueryRet)>); + pub struct GenericContract( + std::marker::PhantomData<( + InstantiateParam, + ExecParam, + QueryParam, + MigrateParam, + RetType, + )>, + ); #[contract] - impl GenericContract + impl + GenericContract where - for<'msg_de> Msg: CustomMsg + Deserialize<'msg_de> + 'msg_de, - for<'a> QueryRet: CustomQuery + DeserializeOwned + 'a, + for<'msg_de> InstantiateParam: CustomMsg + Deserialize<'msg_de> + 'msg_de, + for<'exec> ExecParam: CustomMsg + DeserializeOwned + 'exec, + for<'exec> QueryParam: CustomMsg + DeserializeOwned + 'exec, + for<'exec> MigrateParam: CustomMsg + DeserializeOwned + 'exec, + for<'ret> RetType: CustomMsg + DeserializeOwned + 'ret, { pub const fn new() -> Self { Self(std::marker::PhantomData) } #[msg(instantiate)] - pub fn instantiate(&self, _ctx: InstantiateCtx, _msg: Msg) -> StdResult { + pub fn instantiate( + &self, + _ctx: InstantiateCtx, + _msg: InstantiateParam, + ) -> StdResult { + Ok(Response::new()) + } + + #[msg(exec)] + pub fn execute(&self, _ctx: ExecCtx, _msg: ExecParam) -> StdResult { + Ok(Response::new()) + } + + #[msg(query)] + pub fn query(&self, _ctx: QueryCtx, _msg: QueryParam) -> StdResult { + Ok(Response::new()) + } + + #[msg(migrate)] + pub fn migrate(&self, _ctx: MigrateCtx, _msg: MigrateParam) -> StdResult { + Ok(Response::new()) + } + + #[allow(dead_code)] + #[msg(reply)] + fn reply(&self, _ctx: ReplyCtx, _reply: Reply) -> StdResult { Ok(Response::new()) } } @@ -327,18 +364,31 @@ mod tests { #[test] fn generic_contract() { + use crate::generic_contract::multitest_utils::CodeId; let app = App::default(); - let code_id = crate::generic_contract::multitest_utils::CodeId::store_code::< + let code_id: CodeId< + cw_multi_test::BasicApp, + ExternalMsg, ExternalMsg, - ExternalQuery, - >(&app); + ExternalMsg, + crate::ExternalMsg, + crate::ExternalMsg, + > = CodeId::store_code(&app); let owner = "owner"; - code_id + let contract = code_id .instantiate(ExternalMsg {}) .with_label("GenericContract") + .with_admin(owner) .call(owner) .unwrap(); + + contract.execute(ExternalMsg).call(owner).unwrap(); + contract.query(ExternalMsg).unwrap(); + contract + .migrate(ExternalMsg) + .call(owner, code_id.code_id()) + .unwrap(); } }