Skip to content

Interface extension trait #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sylvia-derive/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use syn::{parse_quote, GenericParam, Ident, ItemImpl, ItemTrait, TraitItem, Type

use crate::crate_module;
use crate::interfaces::Interfaces;
use crate::message::{ContractEnumMessage, EnumMessage, GlueMessage, MsgVariants, StructMessage};
use crate::message::{
ContractEnumMessage, EnumMessage, GlueMessage, InterfaceMessages, MsgVariants, StructMessage,
};
use crate::multitest::{MultitestHelpers, TraitMultitestHelpers};
use crate::parser::{ContractArgs, ContractErrorAttr, Custom, MsgType, OverrideEntryPoints};
use crate::remote::Remote;
Expand Down Expand Up @@ -71,6 +73,8 @@ impl<'a> TraitInput<'a> {
)
.emit_querier();

let interface_messages = InterfaceMessages::new(self.item, &self.generics).emit();

#[cfg(not(tarpaulin_include))]
{
quote! {
Expand All @@ -81,6 +85,8 @@ impl<'a> TraitInput<'a> {
#remote

#querier

#interface_messages
}
}
}
Expand Down
111 changes: 88 additions & 23 deletions sylvia-derive/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::parser::{
};
use crate::strip_generics::StripGenerics;
use crate::utils::{
as_where_clause, brace_generics, extract_return_type, filter_wheres, process_fields,
as_where_clause, emit_bracketed_generics, extract_return_type, filter_wheres, process_fields,
};
use crate::variant_descs::{AsVariantDescs, VariantDescs};
use convert_case::{Case, Casing};
Expand Down Expand Up @@ -114,8 +114,8 @@ impl<'a> StructMessage<'a> {
let fields = fields.iter().map(MsgField::emit);

let where_clause = as_where_clause(wheres);
let generics = brace_generics(generics);
let unused_generics = brace_generics(unused_generics);
let generics = emit_bracketed_generics(generics);
let unused_generics = emit_bracketed_generics(unused_generics);

#[cfg(not(tarpaulin_include))]
{
Expand Down Expand Up @@ -264,7 +264,7 @@ impl<'a> EnumMessage<'a> {
let ctx_type = msg_ty.emit_ctx_type(query_type);
let dispatch_type = msg_ty.emit_result_type(resp_type, &parse_quote!(C::Error));

let all_generics = brace_generics(all_generics);
let all_generics = emit_bracketed_generics(all_generics);
let phantom = if generics.is_empty() {
quote! {}
} else if MsgType::Query == *msg_ty {
Expand All @@ -289,7 +289,7 @@ impl<'a> EnumMessage<'a> {
}
};

let generics = brace_generics(generics);
let generics = emit_bracketed_generics(generics);

let unique_enum_name = Ident::new(&format!("{}{}", trait_name, name), name.span());

Expand Down Expand Up @@ -683,14 +683,14 @@ impl<'a> MsgVariant<'a> {
pub struct MsgVariants<'a> {
variants: Vec<MsgVariant<'a>>,
unbonded_generics: Vec<&'a GenericParam>,
where_clause: Option<WhereClause>,
where_predicates: Vec<&'a WherePredicate>,
}

impl<'a> MsgVariants<'a> {
pub fn new(
source: VariantDescs<'a>,
msg_type: MsgType,
generics: &'a Vec<&'a GenericParam>,
generics: &'a [&'a GenericParam],
unfiltered_where_clause: &'a Option<WhereClause>,
) -> Self {
let mut generics_checker = CheckGenerics::new(generics);
Expand Down Expand Up @@ -718,21 +718,21 @@ impl<'a> MsgVariants<'a> {
.collect();

let (unbonded_generics, _) = generics_checker.used_unused();
let wheres = filter_wheres(
unfiltered_where_clause,
generics.as_slice(),
&unbonded_generics,
);
let where_clause = if !wheres.is_empty() {
Some(parse_quote! { where #(#wheres),* })
} else {
None
};
let where_predicates = filter_wheres(unfiltered_where_clause, generics, &unbonded_generics);

Self {
variants,
unbonded_generics,
where_clause,
where_predicates,
}
}

pub fn where_clause(&self) -> Option<WhereClause> {
let where_predicates = &self.where_predicates;
if !where_predicates.is_empty() {
Some(parse_quote! { where #(#where_predicates),* })
} else {
None
}
}

Expand All @@ -745,9 +745,9 @@ impl<'a> MsgVariants<'a> {
let Self {
variants,
unbonded_generics,
where_clause,
..
} = self;
let where_clause = self.where_clause();

let methods_impl = variants
.iter()
Expand All @@ -759,7 +759,7 @@ impl<'a> MsgVariants<'a> {
.filter(|variant| variant.msg_type == MsgType::Query)
.map(MsgVariant::emit_querier_declaration);

let braced_generics = brace_generics(unbonded_generics);
let braced_generics = emit_bracketed_generics(unbonded_generics);
let querier = quote! { Querier #braced_generics };

#[cfg(not(tarpaulin_include))]
Expand Down Expand Up @@ -804,9 +804,9 @@ impl<'a> MsgVariants<'a> {
let Self {
variants,
unbonded_generics,
where_clause,
..
} = self;
let where_clause = self.where_clause();

let methods_impl = variants
.iter()
Expand Down Expand Up @@ -1100,6 +1100,71 @@ impl<'a> GlueMessage<'a> {
}
}

pub struct InterfaceMessages<'a> {
exec_variants: MsgVariants<'a>,
query_variants: MsgVariants<'a>,
generics: &'a [&'a GenericParam],
}

impl<'a> InterfaceMessages<'a> {
pub fn new(source: &'a ItemTrait, generics: &'a [&'a GenericParam]) -> Self {
let exec_variants = MsgVariants::new(
source.as_variants(),
MsgType::Exec,
generics,
&source.generics.where_clause,
);

let query_variants = MsgVariants::new(
source.as_variants(),
MsgType::Query,
generics,
&source.generics.where_clause,
);

Self {
exec_variants,
query_variants,
generics,
}
}

pub fn emit(&self) -> TokenStream {
let sylvia = crate_module();
let Self {
exec_variants,
query_variants,
generics,
} = self;

let exec_generics = &exec_variants.unbonded_generics;
let query_generics = &query_variants.unbonded_generics;

let bracket_generics = emit_bracketed_generics(generics);
let exec_bracketed_generics = emit_bracketed_generics(exec_generics);
let query_bracketed_generics = emit_bracketed_generics(query_generics);

let phantom = if !generics.is_empty() {
quote! {
_phantom: std::marker::PhantomData<( #(#generics,)* )>,
}
} else {
quote! {}
};

quote! {
pub struct InterfaceTypes #bracket_generics {
#phantom
}

impl #bracket_generics #sylvia ::types::InterfaceMessages for InterfaceTypes #bracket_generics {
type Exec = ExecMsg #exec_bracketed_generics;
type Query = QueryMsg #query_bracketed_generics;
}
}
}
}

pub struct EntryPoints<'a> {
name: Type,
error: Type,
Expand Down Expand Up @@ -1130,11 +1195,11 @@ impl<'a> EntryPoints<'a> {
)
.unwrap_or_else(|| parse_quote! { #sylvia ::cw_std::StdError });

let has_migrate = !MsgVariants::new(source.as_variants(), MsgType::Migrate, &vec![], &None)
let has_migrate = !MsgVariants::new(source.as_variants(), MsgType::Migrate, &[], &None)
.variants()
.is_empty();

let reply = MsgVariants::new(source.as_variants(), MsgType::Reply, &vec![], &None)
let reply = MsgVariants::new(source.as_variants(), MsgType::Reply, &[], &None)
.variants()
.iter()
.map(|variant| variant.function_name.clone())
Expand Down
2 changes: 1 addition & 1 deletion sylvia-derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub fn as_where_clause(where_predicates: &[&WherePredicate]) -> Option<WhereClau
}
}

pub fn brace_generics(unbonded_generics: &[&GenericParam]) -> TokenStream {
pub fn emit_bracketed_generics(unbonded_generics: &[&GenericParam]) -> TokenStream {
match unbonded_generics.is_empty() {
true => quote! {},
false => quote! { < #(#unbonded_generics,)* > },
Expand Down
7 changes: 7 additions & 0 deletions sylvia/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,10 @@ impl<'a, C: CustomQuery> From<(Deps<'a, C>, Env)> for QueryCtx<'a, C> {
}

pub trait CustomMsg: cosmwasm_std::CustomMsg + DeserializeOwned {}

impl<T> CustomMsg for T where T: cosmwasm_std::CustomMsg + DeserializeOwned {}

pub trait InterfaceMessages {
type Exec;
type Query;
}
13 changes: 12 additions & 1 deletion sylvia/tests/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ pub mod cw1 {
#[cw_serde]
pub struct ExternalMsg;
impl cosmwasm_std::CustomMsg for ExternalMsg {}
impl sylvia::types::CustomMsg for ExternalMsg {}

#[cw_serde]
pub struct ExternalQuery;
Expand All @@ -44,6 +43,8 @@ mod tests {

use crate::{cw1::Querier, ExternalMsg, ExternalQuery};

use crate::cw1::InterfaceTypes;
use sylvia::types::InterfaceMessages;
#[test]
fn construct_messages() {
let contract = Addr::unchecked("contract");
Expand All @@ -59,5 +60,15 @@ mod tests {
let cw1_querier = crate::cw1::BoundQuerier::borrowed(&contract, &querier);
let _: Result<ExternalQuery, _> = Querier::some_query(&cw1_querier, ExternalMsg {});
let _: Result<ExternalQuery, _> = cw1_querier.some_query(ExternalMsg {});

// Construct messages with Interface extension
let _ =
<InterfaceTypes<ExternalMsg, _, ExternalQuery> as InterfaceMessages>::Query::some_query(
ExternalMsg {},
);
let _=
<InterfaceTypes<_, ExternalMsg, cosmwasm_std::Empty> as InterfaceMessages>::Exec::execute(vec![
CosmosMsg::Custom(ExternalMsg {}),
]);
}
}