Skip to content

Commit f01f822

Browse files
committed
feat: Add support for generics in interface
1 parent e6131e4 commit f01f822

File tree

10 files changed

+187
-93
lines changed

10 files changed

+187
-93
lines changed

sylvia-derive/src/check_generics.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use syn::visit::Visit;
22
use syn::GenericParam;
33

4+
#[derive(Debug)]
45
pub struct CheckGenerics<'g> {
56
generics: &'g [&'g GenericParam],
67
used: Vec<&'g GenericParam>,

sylvia-derive/src/input.rs

+24-16
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,14 @@ impl<'a> TraitInput<'a> {
6262
let messages = self.emit_messages();
6363
let multitest_helpers = self.emit_helpers();
6464
let remote = Remote::new(&Interfaces::default()).emit();
65-
let querier = MsgVariants::new(self.item.as_variants(), &self.generics).emit_querier();
65+
66+
let querier = MsgVariants::new(
67+
self.item.as_variants(),
68+
MsgType::Query,
69+
&self.generics,
70+
&self.item.generics.where_clause,
71+
)
72+
.emit_querier();
6673

6774
#[cfg(not(tarpaulin_include))]
6875
{
@@ -159,22 +166,26 @@ impl<'a> ImplInput<'a> {
159166
quote! {}
160167
};
161168

162-
let interfaces = Interfaces::new(self.item);
163-
let variants = MsgVariants::new(self.item.as_variants(), &self.generics);
169+
let unbonded_generics = &vec![];
170+
let variants = MsgVariants::new(
171+
self.item.as_variants(),
172+
MsgType::Query,
173+
unbonded_generics,
174+
&None,
175+
);
164176

165177
match is_trait {
166-
true => self.process_interface(&interfaces, variants, multitest_helpers),
167-
false => self.process_contract(&interfaces, variants, multitest_helpers),
178+
true => self.process_interface(variants, multitest_helpers),
179+
false => self.process_contract(variants, multitest_helpers),
168180
}
169181
}
170182

171183
fn process_interface(
172184
&self,
173-
interfaces: &Interfaces,
174185
variants: MsgVariants<'a>,
175186
multitest_helpers: TokenStream,
176187
) -> TokenStream {
177-
let querier_bound_for_impl = self.emit_querier_for_bound_impl(interfaces, variants);
188+
let querier_bound_for_impl = self.emit_querier_for_bound_impl(variants);
178189

179190
#[cfg(not(tarpaulin_include))]
180191
quote! {
@@ -186,14 +197,14 @@ impl<'a> ImplInput<'a> {
186197

187198
fn process_contract(
188199
&self,
189-
interfaces: &Interfaces,
190200
variants: MsgVariants<'a>,
191201
multitest_helpers: TokenStream,
192202
) -> TokenStream {
193203
let messages = self.emit_messages();
194-
let remote = Remote::new(interfaces).emit();
204+
let remote = Remote::new(&self.interfaces).emit();
205+
195206
let querier = variants.emit_querier();
196-
let querier_from_impl = interfaces.emit_querier_from_impl();
207+
let querier_from_impl = self.interfaces.emit_querier_from_impl();
197208

198209
#[cfg(not(tarpaulin_include))]
199210
{
@@ -268,12 +279,9 @@ impl<'a> ImplInput<'a> {
268279
.emit()
269280
}
270281

271-
fn emit_querier_for_bound_impl(
272-
&self,
273-
interfaces: &Interfaces,
274-
variants: MsgVariants<'a>,
275-
) -> TokenStream {
276-
let trait_module = interfaces
282+
fn emit_querier_for_bound_impl(&self, variants: MsgVariants<'a>) -> TokenStream {
283+
let trait_module = self
284+
.interfaces
277285
.interfaces()
278286
.first()
279287
.map(|interface| &interface.module);

sylvia-derive/src/message.rs

+91-37
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ impl<'a> EnumMessage<'a> {
303303
#[allow(clippy::derive_partial_eq_without_eq)]
304304
#[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema, cosmwasm_schema::QueryResponses)]
305305
#[serde(rename_all="snake_case")]
306-
pub enum #unique_enum_name #generics #where_clause {
306+
pub enum #unique_enum_name #generics {
307307
#(#variants,)*
308308
}
309309
pub type #name #generics = #unique_enum_name #generics;
@@ -314,7 +314,7 @@ impl<'a> EnumMessage<'a> {
314314
#[allow(clippy::derive_partial_eq_without_eq)]
315315
#[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)]
316316
#[serde(rename_all="snake_case")]
317-
pub enum #unique_enum_name #generics #where_clause {
317+
pub enum #unique_enum_name #generics {
318318
#(#variants,)*
319319
}
320320
pub type #name #generics = #unique_enum_name #generics;
@@ -506,7 +506,9 @@ impl<'a> MsgVariant<'a> {
506506

507507
let return_type = if let MsgAttr::Query { resp_type } = msg_attr {
508508
match resp_type {
509-
Some(resp_type) => quote! {#resp_type},
509+
Some(resp_type) => {
510+
quote! {#resp_type}
511+
}
510512
None => {
511513
let return_type = extract_return_type(&sig.output);
512514
quote! {#return_type}
@@ -667,12 +669,20 @@ impl<'a> MsgVariant<'a> {
667669
}
668670
}
669671

670-
pub struct MsgVariants<'a>(Vec<MsgVariant<'a>>);
672+
pub struct MsgVariants<'a> {
673+
variants: Vec<MsgVariant<'a>>,
674+
unbonded_generics: Vec<&'a GenericParam>,
675+
where_clause: Option<WhereClause>,
676+
}
671677

672678
impl<'a> MsgVariants<'a> {
673-
pub fn new(source: VariantDescs<'a>, generics: &[&'a GenericParam]) -> Self {
679+
pub fn new(
680+
source: VariantDescs<'a>,
681+
msg_type: MsgType,
682+
generics: &'a Vec<&'a GenericParam>,
683+
unfiltered_where_clause: &'a Option<WhereClause>,
684+
) -> Self {
674685
let mut generics_checker = CheckGenerics::new(generics);
675-
676686
let variants: Vec<_> = source
677687
.filter_map(|variant_desc| {
678688
let msg_attr = variant_desc.attr_msg()?;
@@ -684,19 +694,49 @@ impl<'a> MsgVariants<'a> {
684694
}
685695
};
686696

697+
if attr.msg_type() != msg_type {
698+
return None;
699+
}
700+
687701
Some(MsgVariant::new(
688702
variant_desc.into_sig(),
689703
&mut generics_checker,
690704
attr,
691705
))
692706
})
693707
.collect();
694-
Self(variants)
708+
709+
let (unbonded_generics, _) = generics_checker.used_unused();
710+
let wheres = filter_wheres(
711+
unfiltered_where_clause,
712+
generics.as_slice(),
713+
&unbonded_generics,
714+
);
715+
let where_clause = if !wheres.is_empty() {
716+
Some(parse_quote! { where #(#wheres),* })
717+
} else {
718+
None
719+
};
720+
721+
Self {
722+
variants,
723+
unbonded_generics,
724+
where_clause,
725+
}
726+
}
727+
728+
pub fn variants(&self) -> &Vec<MsgVariant<'a>> {
729+
&self.variants
695730
}
696731

697732
pub fn emit_querier(&self) -> TokenStream {
698733
let sylvia = crate_module();
699-
let variants = &self.0;
734+
let Self {
735+
variants,
736+
unbonded_generics,
737+
where_clause,
738+
..
739+
} = self;
700740

701741
let methods_impl = variants
702742
.iter()
@@ -708,6 +748,12 @@ impl<'a> MsgVariants<'a> {
708748
.filter(|variant| variant.msg_type == MsgType::Query)
709749
.map(MsgVariant::emit_querier_declaration);
710750

751+
let querier = if !unbonded_generics.is_empty() {
752+
quote! { Querier < #(#unbonded_generics,)* > }
753+
} else {
754+
quote! { Querier }
755+
};
756+
711757
#[cfg(not(tarpaulin_include))]
712758
{
713759
quote! {
@@ -730,12 +776,11 @@ impl<'a> MsgVariants<'a> {
730776
}
731777
}
732778

733-
impl <'a, C: #sylvia ::cw_std::CustomQuery> Querier for BoundQuerier<'a, C> {
779+
impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#unbonded_generics,)*> #querier for BoundQuerier<'a, C> #where_clause {
734780
#(#methods_impl)*
735781
}
736782

737-
738-
pub trait Querier {
783+
pub trait #querier {
739784
#(#methods_declaration)*
740785
}
741786
}
@@ -748,24 +793,33 @@ impl<'a> MsgVariants<'a> {
748793
contract_module: Option<&Path>,
749794
) -> TokenStream {
750795
let sylvia = crate_module();
751-
let variants = &self.0;
796+
let Self {
797+
variants,
798+
unbonded_generics,
799+
where_clause,
800+
..
801+
} = self;
752802

753803
let methods_impl = variants
754804
.iter()
755805
.filter(|variant| variant.msg_type == MsgType::Query)
756806
.map(|variant| variant.emit_querier_impl(trait_module));
757807

758-
let querier = trait_module
808+
let mut querier = trait_module
759809
.map(|module| quote! { #module ::Querier })
760810
.unwrap_or_else(|| quote! { Querier });
761811
let bound_querier = contract_module
762812
.map(|module| quote! { #module ::BoundQuerier})
763813
.unwrap_or_else(|| quote! { BoundQuerier });
764814

815+
if !unbonded_generics.is_empty() {
816+
querier = quote! { #querier < #(#unbonded_generics,)* > };
817+
}
818+
765819
#[cfg(not(tarpaulin_include))]
766820
{
767821
quote! {
768-
impl <'a, C: #sylvia ::cw_std::CustomQuery> #querier for #bound_querier<'a, C> {
822+
impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#unbonded_generics,)*> #querier for #bound_querier<'a, C> #where_clause {
769823
#(#methods_impl)*
770824
}
771825
}
@@ -886,7 +940,7 @@ impl<'a> GlueMessage<'a> {
886940
interfaces,
887941
} = self;
888942
let contract = StripGenerics.fold_type((*contract).clone());
889-
let contract_name = Ident::new(&format!("Contract{}", name), name.span());
943+
let enum_name = Ident::new(&format!("Contract{}", name), name.span());
890944

891945
let variants = interfaces.emit_glue_message_variants(msg_ty, name);
892946

@@ -916,15 +970,15 @@ impl<'a> GlueMessage<'a> {
916970

917971
match (msg_ty, customs.has_msg) {
918972
(MsgType::Exec, true) => quote! {
919-
#contract_name :: #variant(msg) => #sylvia ::into_response::IntoResponse::into_response(msg.dispatch(contract, Into::into( #ctx ))?)
973+
#enum_name:: #variant(msg) => #sylvia ::into_response::IntoResponse::into_response(msg.dispatch(contract, Into::into( #ctx ))?)
920974
},
921975
_ => quote! {
922-
#contract_name :: #variant(msg) => msg.dispatch(contract, Into::into( #ctx ))
976+
#enum_name :: #variant(msg) => msg.dispatch(contract, Into::into( #ctx ))
923977
},
924978
}
925979
});
926980

927-
let dispatch_arm = quote! {#contract_name :: #contract (msg) =>msg.dispatch(contract, ctx)};
981+
let dispatch_arm = quote! {#enum_name :: #contract (msg) => msg.dispatch(contract, ctx)};
928982

929983
let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(name);
930984

@@ -951,7 +1005,7 @@ impl<'a> GlueMessage<'a> {
9511005
{
9521006
quote! {
9531007
#[cfg(not(target_arch = "wasm32"))]
954-
impl cosmwasm_schema::QueryResponses for #contract_name {
1008+
impl #sylvia ::cw_schema::QueryResponses for #enum_name {
9551009
fn response_schemas_impl() -> std::collections::BTreeMap<String, #sylvia ::schemars::schema::RootSchema> {
9561010
let responses = [#(#response_schemas_calls),*];
9571011
responses.into_iter().flatten().collect()
@@ -971,12 +1025,12 @@ impl<'a> GlueMessage<'a> {
9711025
#[allow(clippy::derive_partial_eq_without_eq)]
9721026
#[derive(#sylvia ::serde::Serialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)]
9731027
#[serde(rename_all="snake_case", untagged)]
974-
pub enum #contract_name {
1028+
pub enum #enum_name {
9751029
#(#variants,)*
9761030
#msg_name
9771031
}
9781032

979-
impl #contract_name {
1033+
impl #enum_name {
9801034
pub fn dispatch(
9811035
self,
9821036
contract: &#contract,
@@ -996,7 +1050,7 @@ impl<'a> GlueMessage<'a> {
9961050

9971051
#response_schemas
9981052

999-
impl<'de> serde::Deserialize<'de> for #contract_name {
1053+
impl<'de> serde::Deserialize<'de> for #enum_name {
10001054
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
10011055
where D: serde::Deserializer<'de>,
10021056
{
@@ -1043,7 +1097,8 @@ pub struct EntryPoints<'a> {
10431097
error: Type,
10441098
custom: Custom<'a>,
10451099
override_entry_points: OverrideEntryPoints,
1046-
variants: MsgVariants<'a>,
1100+
has_migrate: bool,
1101+
reply: Option<Ident>,
10471102
}
10481103

10491104
impl<'a> EntryPoints<'a> {
@@ -1067,17 +1122,24 @@ impl<'a> EntryPoints<'a> {
10671122
)
10681123
.unwrap_or_else(|| parse_quote! { #sylvia ::cw_std::StdError });
10691124

1070-
let generics: Vec<_> = source.generics.params.iter().collect();
1125+
let has_migrate = !MsgVariants::new(source.as_variants(), MsgType::Migrate, &vec![], &None)
1126+
.variants()
1127+
.is_empty();
10711128

1072-
let variants = MsgVariants::new(source.as_variants(), &generics);
1129+
let reply = MsgVariants::new(source.as_variants(), MsgType::Reply, &vec![], &None)
1130+
.variants()
1131+
.iter()
1132+
.map(|variant| variant.function_name.clone())
1133+
.next();
10731134
let custom = Custom::new(&source.attrs);
10741135

10751136
Self {
10761137
name,
10771138
error,
10781139
custom,
10791140
override_entry_points,
1080-
variants,
1141+
has_migrate,
1142+
reply,
10811143
}
10821144
}
10831145

@@ -1087,17 +1149,13 @@ impl<'a> EntryPoints<'a> {
10871149
error,
10881150
custom,
10891151
override_entry_points,
1090-
variants,
1152+
has_migrate,
1153+
reply,
10911154
} = self;
10921155
let sylvia = crate_module();
10931156

10941157
let custom_msg = custom.msg_or_default();
10951158
let custom_query = custom.query_or_default();
1096-
let reply = variants
1097-
.0
1098-
.iter()
1099-
.find(|variant| variant.msg_type == MsgType::Reply)
1100-
.map(|variant| variant.function_name.clone());
11011159

11021160
#[cfg(not(tarpaulin_include))]
11031161
{
@@ -1119,12 +1177,8 @@ impl<'a> EntryPoints<'a> {
11191177
let migrate_not_overridden = override_entry_points
11201178
.get_entry_point(MsgType::Migrate)
11211179
.is_none();
1122-
let migrate_msg_defined = variants
1123-
.0
1124-
.iter()
1125-
.any(|variant| variant.msg_type == MsgType::Migrate);
11261180

1127-
let migrate = if migrate_not_overridden && migrate_msg_defined {
1181+
let migrate = if migrate_not_overridden && *has_migrate {
11281182
OverrideEntryPoint::emit_default_entry_point(
11291183
&custom_msg,
11301184
&custom_query,

0 commit comments

Comments
 (0)