Skip to content

Commit

Permalink
feat: Support generic types in entry points
Browse files Browse the repository at this point in the history
  • Loading branch information
jawoznia committed Nov 13, 2023
1 parent ea75996 commit 972c97b
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 27 deletions.
4 changes: 4 additions & 0 deletions examples/contracts/generic_contract/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ use sylvia::types::{
};
use sylvia::{contract, schemars};

#[cfg(not(feature = "library"))]
use sylvia::entry_points;

pub struct GenericContract<InstantiateParam, ExecParam, QueryParam, MigrateParam, RetType>(
std::marker::PhantomData<(
InstantiateParam,
Expand All @@ -16,6 +19,7 @@ pub struct GenericContract<InstantiateParam, ExecParam, QueryParam, MigrateParam
)>,
);

#[cfg_attr(not(feature = "library"), entry_points(generics<SvCustomMsg, SvCustomMsg, SvCustomMsg, sylvia::types::SvCustomMsg, SvCustomMsg>))]
#[contract]
#[messages(cw1 as Cw1: custom(msg))]
#[messages(generic<SvCustomMsg, SvCustomMsg, sylvia::types::SvCustomMsg> as Generic: custom(msg))]
Expand Down
4 changes: 4 additions & 0 deletions examples/contracts/generic_iface_on_contract/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ use cosmwasm_std::{Response, StdResult};
use sylvia::types::{InstantiateCtx, SvCustomMsg};
use sylvia::{contract, schemars};

#[cfg(not(feature = "library"))]
use sylvia::entry_points;

pub struct NonGenericContract;

#[cfg_attr(not(feature = "library"), entry_points)]
#[contract]
#[messages(generic<SvCustomMsg, sylvia::types::SvCustomMsg, SvCustomMsg> as Generic: custom(msg))]
#[messages(custom_and_generic<SvCustomMsg, SvCustomMsg, sylvia::types::SvCustomMsg> as CustomAndGeneric)]
Expand Down
2 changes: 1 addition & 1 deletion sylvia-derive/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl Interfaces {
quote! {}
};

let type_name = msg_ty.as_accessor_name();
let type_name = msg_ty.as_accessor_name(false);
quote! {
<#module ::sv::Api #generics as #sylvia ::types::InterfaceApi> :: #type_name :: response_schemas_impl()
}
Expand Down
5 changes: 3 additions & 2 deletions sylvia-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,10 @@ pub fn entry_points(attr: TokenStream, item: TokenStream) -> TokenStream {

#[cfg(not(tarpaulin_include))]
fn entry_points_impl(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
fn inner(_attr: TokenStream2, item: TokenStream2) -> syn::Result<TokenStream2> {
fn inner(attr: TokenStream2, item: TokenStream2) -> syn::Result<TokenStream2> {
let attrs: parser::EntryPointArgs = parse2(attr)?;
let input: ItemImpl = parse2(item)?;
let expanded = EntryPoints::new(&input).emit();
let expanded = EntryPoints::new(&input, attrs).emit();

Ok(quote! {
#input
Expand Down
49 changes: 32 additions & 17 deletions sylvia-derive/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::crate_module;
use crate::interfaces::Interfaces;
use crate::parser::{
parse_associated_custom_type, parse_struct_message, ContractErrorAttr, ContractMessageAttr,
Custom, MsgAttr, MsgType, OverrideEntryPoints,
Custom, EntryPointArgs, MsgAttr, MsgType, OverrideEntryPoints,
};
use crate::strip_generics::StripGenerics;
use crate::utils::{
Expand All @@ -16,11 +16,12 @@ use proc_macro_error::emit_error;
use quote::{quote, ToTokens};
use syn::fold::Fold;
use syn::parse::{Parse, Parser};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::visit::Visit;
use syn::{
parse_quote, Attribute, GenericParam, Ident, ItemImpl, ItemTrait, Pat, PatType, Path,
ReturnType, Signature, TraitItem, Type, WhereClause, WherePredicate,
parse_quote, Attribute, GenericArgument, GenericParam, Ident, ItemImpl, ItemTrait, Pat,
PatType, Path, ReturnType, Signature, Token, TraitItem, Type, WhereClause, WherePredicate,
};

/// Representation of single struct message
Expand Down Expand Up @@ -747,7 +748,7 @@ impl<'a> MsgVariant<'a> {
let bracketed_generics = emit_bracketed_generics(generics);
let interface_enum =
quote! { < #module sv::Api #bracketed_generics as #sylvia ::types::InterfaceApi> };
let type_name = msg_ty.as_accessor_name();
let type_name = msg_ty.as_accessor_name(false);
let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span());

match msg_ty {
Expand Down Expand Up @@ -790,7 +791,7 @@ impl<'a> MsgVariant<'a> {
} = self;

let params = fields.iter().map(|field| field.emit_method_field());
let type_name = msg_ty.as_accessor_name();
let type_name = msg_ty.as_accessor_name(false);
let name = Ident::new(&name.to_string().to_case(Case::Snake), name.span());

match msg_ty {
Expand Down Expand Up @@ -1023,12 +1024,9 @@ where
custom_query: &Type,
name: &Type,
error: &Type,
contract_generics: &Option<Punctuated<GenericArgument, Token![,]>>,
) -> TokenStream {
let Self {
used_generics,
msg_ty,
..
} = self;
let Self { msg_ty, .. } = self;
let sylvia = crate_module();

let resp_type = match msg_ty {
Expand All @@ -1038,16 +1036,19 @@ where
let params = msg_ty.emit_ctx_params(custom_query);
let values = msg_ty.emit_ctx_values();
let ep_name = msg_ty.emit_ep_name();
let msg_name = msg_ty.emit_msg_name(true);
let bracketed_generics = emit_bracketed_generics(used_generics);
let bracketed_generics = match &contract_generics {
Some(generics) => quote! { ::< #generics > },
None => quote! {},
};
let associated_name = msg_ty.as_accessor_name(true);

quote! {
#[#sylvia ::cw_std::entry_point]
pub fn #ep_name (
#params ,
msg: sv:: #msg_name #bracketed_generics,
msg: < #name < #contract_generics > as #sylvia ::types::ContractApi> :: #associated_name,
) -> Result<#resp_type, #error> {
msg.dispatch(&#name ::new() , ( #values )).map_err(Into::into)
msg.dispatch(&#name #bracketed_generics ::new() , ( #values )).map_err(Into::into)
}
}
}
Expand Down Expand Up @@ -1608,10 +1609,11 @@ pub struct EntryPoints<'a> {
override_entry_points: OverrideEntryPoints,
generics: Vec<&'a GenericParam>,
where_clause: &'a Option<WhereClause>,
attrs: EntryPointArgs,
}

impl<'a> EntryPoints<'a> {
pub fn new(source: &'a ItemImpl) -> Self {
pub fn new(source: &'a ItemImpl, attrs: EntryPointArgs) -> Self {
let sylvia = crate_module();
let name = StripGenerics.fold_type(*source.self_ty.clone());
let override_entry_points = OverrideEntryPoints::new(&source.attrs);
Expand Down Expand Up @@ -1643,6 +1645,7 @@ impl<'a> EntryPoints<'a> {
override_entry_points,
generics,
where_clause,
attrs,
}
}

Expand All @@ -1655,6 +1658,7 @@ impl<'a> EntryPoints<'a> {
override_entry_points,
generics,
where_clause,
attrs,
} = self;
let sylvia = crate_module();

Expand Down Expand Up @@ -1683,6 +1687,10 @@ impl<'a> EntryPoints<'a> {
.iter()
.map(|variant| variant.function_name.clone())
.next();
let contract_generics = match &attrs.generics {
Some(generics) => quote! { ::< #generics > },
None => quote! {},
};

#[cfg(not(tarpaulin_include))]
{
Expand All @@ -1696,6 +1704,7 @@ impl<'a> EntryPoints<'a> {
&custom_query,
name,
error,
&attrs.generics,
),
},
);
Expand All @@ -1706,7 +1715,13 @@ impl<'a> EntryPoints<'a> {

let migrate = if migrate_not_overridden && migrate_variants.get_only_variant().is_some()
{
migrate_variants.emit_default_entry_point(&custom_msg, &custom_query, name, error)
migrate_variants.emit_default_entry_point(
&custom_msg,
&custom_query,
name,
error,
&attrs.generics,
)
} else {
quote! {}
};
Expand All @@ -1722,7 +1737,7 @@ impl<'a> EntryPoints<'a> {
env: #sylvia ::cw_std::Env,
msg: #sylvia ::cw_std::Reply,
) -> Result<#sylvia ::cw_std::Response < #custom_msg >, #error> {
#name ::new(). #reply((deps, env).into(), msg).map_err(Into::into)
#name #contract_generics ::new(). #reply((deps, env).into(), msg).map_err(Into::into)
}
},
_ => quote! {},
Expand Down
45 changes: 38 additions & 7 deletions sylvia-derive/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ use syn::{
use crate::crate_module;
use crate::strip_generics::StripGenerics;

/// Parser arguments for `contract` macro
/// Parsed arguments for `contract` macro
pub struct ContractArgs {
/// Module name wrapping generated messages, by default no additional module is created
/// Module in which contract impl block is defined.
/// Used only while implementing `Interface` on `Contract`.
pub module: Option<Path>,
}

Expand Down Expand Up @@ -46,6 +47,31 @@ impl Parse for ContractArgs {
}
}

/// Parsed arguments for `entry_points` macro
pub struct EntryPointArgs {
/// Types used in place of contracts generics.
pub generics: Option<Punctuated<GenericArgument, Token![,]>>,
}

impl Parse for EntryPointArgs {
fn parse(input: ParseStream) -> Result<Self> {
if input.is_empty() {
return Ok(Self { generics: None });
}

let path: Path = input.parse()?;

let generics = match path.segments.last() {
Some(segment) if segment.ident == "generics" => Some(extract_generics_from_path(&path)),
_ => return Err(Error::new(path.span(), "Expected `generics`")),
};

let _: Nothing = input.parse()?;

Ok(Self { generics })
}
}

/// Type of message to be generated
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum MsgType {
Expand Down Expand Up @@ -158,11 +184,16 @@ impl MsgType {
}
}

pub fn as_accessor_name(&self) -> Option<Type> {
pub fn as_accessor_name(&self, is_wrapper: bool) -> Option<Type> {
match self {
MsgType::Exec if is_wrapper => Some(parse_quote! { ContractExec }),
MsgType::Query if is_wrapper => Some(parse_quote! { ContractQuery }),
MsgType::Instantiate => Some(parse_quote! { Instantiate }),
MsgType::Exec => Some(parse_quote! { Exec }),
MsgType::Query => Some(parse_quote! { Query }),
_ => None,
MsgType::Migrate => Some(parse_quote! { Migrate }),
MsgType::Sudo => Some(parse_quote! { Sudo }),
MsgType::Reply => Some(parse_quote! { Reply }),
}
}
}
Expand Down Expand Up @@ -291,7 +322,7 @@ fn interface_has_custom(content: ParseStream) -> Result<Customs> {
Ok(customs)
}

fn extract_generics_from_path(module: &mut Path) -> Punctuated<GenericArgument, Token![,]> {
fn extract_generics_from_path(module: &Path) -> Punctuated<GenericArgument, Token![,]> {
let generics = module.segments.last().map(|segment| {
match segment.arguments.clone(){
PathArguments::AngleBracketed(generics) => {
Expand All @@ -318,8 +349,8 @@ impl Parse for ContractMessageAttr {
let content;
parenthesized!(content in input);

let mut module = content.parse()?;
let generics = extract_generics_from_path(&mut module);
let module = content.parse()?;
let generics = extract_generics_from_path(&module);
let module = StripGenerics.fold_path(module);

let _: Token![as] = content.parse()?;
Expand Down

0 comments on commit 972c97b

Please sign in to comment.