Skip to content

Commit

Permalink
feat: Support generic interface implemented on contract
Browse files Browse the repository at this point in the history
  • Loading branch information
jawoznia committed Oct 18, 2023
1 parent eb685d4 commit f046d9e
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 201 deletions.
68 changes: 50 additions & 18 deletions sylvia-derive/src/check_generics.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,57 @@
use syn::visit::Visit;
use syn::GenericParam;
use syn::{parse_quote, GenericArgument, GenericParam, Type};

pub trait GetPath {
fn get_path(&self) -> Option<syn::Path>;
}

impl GetPath for GenericParam {
fn get_path(&self) -> Option<syn::Path> {
match self {
GenericParam::Type(ty) => {
let ident = &ty.ident;
Some(parse_quote! { #ident })
}
_ => None,
}
}
}

impl GetPath for GenericArgument {
fn get_path(&self) -> Option<syn::Path> {
match self {
GenericArgument::Type(Type::Path(path)) => {
let path = &path.path;
Some(parse_quote! { #path })
}
_ => None,
}
}
}

#[derive(Debug)]
pub struct CheckGenerics<'g> {
generics: &'g [&'g GenericParam],
used: Vec<&'g GenericParam>,
pub struct CheckGenerics<'g, Generic> {
generics: &'g [&'g Generic],
used: Vec<&'g Generic>,
}

impl<'g> CheckGenerics<'g> {
pub fn new(generics: &'g [&'g GenericParam]) -> Self {
impl<'g, Generic> CheckGenerics<'g, Generic>
where
Generic: GetPath + PartialEq,
{
pub fn new(generics: &'g [&'g Generic]) -> Self {
Self {
generics,
used: vec![],
}
}

pub fn used(self) -> Vec<&'g GenericParam> {
pub fn used(self) -> Vec<&'g Generic> {
self.used
}

/// Returns split between used and unused generics
pub fn used_unused(self) -> (Vec<&'g GenericParam>, Vec<&'g GenericParam>) {
pub fn used_unused(self) -> (Vec<&'g Generic>, Vec<&'g Generic>) {
let unused = self
.generics
.iter()
Expand All @@ -32,17 +63,18 @@ impl<'g> CheckGenerics<'g> {
}
}

impl<'ast, 'g> Visit<'ast> for CheckGenerics<'g> {
impl<'ast, 'g, Generic> Visit<'ast> for CheckGenerics<'g, Generic>
where
Generic: GetPath + PartialEq,
{
fn visit_path(&mut self, p: &'ast syn::Path) {
if let Some(p) = p.get_ident() {
if let Some(gen) = self
.generics
.iter()
.find(|gen| matches!(gen, GenericParam::Type(ty) if ty.ident == *p))
{
if !self.used.contains(gen) {
self.used.push(gen);
}
if let Some(gen) = self
.generics
.iter()
.find(|gen| gen.get_path().as_ref() == Some(p))
{
if !self.used.contains(gen) {
self.used.push(gen);
}
}

Expand Down
71 changes: 49 additions & 22 deletions sylvia-derive/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use proc_macro_error::emit_error;
use quote::quote;
use syn::parse::{Parse, Parser};
use syn::spanned::Spanned;
use syn::{parse_quote, GenericParam, Ident, ItemImpl, ItemTrait, TraitItem, Type};
use syn::{
parse_quote, GenericArgument, GenericParam, Ident, ItemImpl, ItemTrait, PathArguments,
TraitItem, Type,
};

use crate::crate_module;
use crate::interfaces::Interfaces;
Expand Down Expand Up @@ -156,42 +159,48 @@ impl<'a> ImplInput<'a> {
}

pub fn process(&self) -> TokenStream {
let is_trait = self.item.trait_.is_some();
let Self {
item,
generics,
error,
custom,
override_entry_points,
interfaces,
..
} = self;
let is_trait = item.trait_.is_some();
let multitest_helpers = if cfg!(feature = "mt") {
let interface_generics = self.extract_generic_argument();
MultitestHelpers::new(
self.item,
item,
is_trait,
&self.error,
&self.generics,
&self.custom,
&self.override_entry_points,
&self.interfaces,
error,
&interface_generics,
custom,
override_entry_points,
interfaces,
)
.emit()
} else {
quote! {}
};

let unbonded_generics = &vec![];
let where_clause = &item.generics.where_clause;
let variants = MsgVariants::new(
self.item.as_variants(),
MsgType::Query,
unbonded_generics,
&None,
generics,
where_clause,
);

match is_trait {
true => self.process_interface(variants, multitest_helpers),
true => self.process_interface(multitest_helpers),
false => self.process_contract(variants, multitest_helpers),
}
}

fn process_interface(
&self,
variants: MsgVariants<'a>,
multitest_helpers: TokenStream,
) -> TokenStream {
let querier_bound_for_impl = self.emit_querier_for_bound_impl(variants);
fn process_interface(&self, multitest_helpers: TokenStream) -> TokenStream {
let querier_bound_for_impl = self.emit_querier_for_bound_impl();

#[cfg(not(tarpaulin_include))]
quote! {
Expand All @@ -203,7 +212,7 @@ impl<'a> ImplInput<'a> {

fn process_contract(
&self,
variants: MsgVariants<'a>,
variants: MsgVariants<'a, GenericParam>,
multitest_helpers: TokenStream,
) -> TokenStream {
let messages = self.emit_messages();
Expand Down Expand Up @@ -285,13 +294,31 @@ impl<'a> ImplInput<'a> {
.emit()
}

fn emit_querier_for_bound_impl(&self, variants: MsgVariants<'a>) -> TokenStream {
/// This method should only be called for trait impl block
fn extract_generic_argument(&self) -> Vec<&GenericArgument> {
let interface_generics = &self.item.trait_.as_ref();
let args = match interface_generics {
Some((_, path, _)) => path.segments.last().map(|segment| &segment.arguments),
None => None,
};

match args {
Some(PathArguments::AngleBracketed(args)) => {
args.args.pairs().map(|pair| *pair.value()).collect()
}
_ => vec![],
}
}

fn emit_querier_for_bound_impl(&self) -> TokenStream {
let trait_module = self
.interfaces
.interfaces()
.first()
.get_only_interface()
.map(|interface| &interface.module);
let contract_module = self.attributes.module.as_ref();
let generics = self.extract_generic_argument();

let variants = MsgVariants::new(self.item.as_variants(), MsgType::Query, &generics, &None);

variants.emit_querier_for_bound_impl(trait_module, contract_module)
}
Expand Down
48 changes: 30 additions & 18 deletions sylvia-derive/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,30 +83,23 @@ impl Interfaces {
.collect()
}

pub fn emit_glue_message_variants(
&self,
msg_ty: &MsgType,
msg_name: &Ident,
) -> Vec<TokenStream> {
pub fn emit_glue_message_variants(&self, msg_ty: &MsgType) -> Vec<TokenStream> {
let sylvia = crate_module();

self.interfaces
.iter()
.map(|interface| {
let ContractMessageAttr {
module,
exec_generic_params,
query_generic_params,
variant,
..
module, variant, ..
} = interface;

let generics = match msg_ty {
MsgType::Exec => exec_generic_params.as_slice(),
MsgType::Query => query_generic_params.as_slice(),
_ => &[],
};

let enum_name = Self::merge_module_with_name(interface, msg_name);
quote! { #variant(#module :: #enum_name<#(#generics,)*>) }
let interface_enum =
quote! { <#module ::InterfaceTypes as #sylvia ::types::InterfaceMessages> };
if msg_ty == &MsgType::Query {
quote! { #variant ( #interface_enum :: Query) }
} else {
quote! { #variant ( #interface_enum :: Exec)}
}
})
.collect()
}
Expand Down Expand Up @@ -158,4 +151,23 @@ impl Interfaces {
pub fn as_modules(&self) -> impl Iterator<Item = &Path> {
self.interfaces.iter().map(|interface| &interface.module)
}

pub fn get_only_interface(&self) -> Option<&ContractMessageAttr> {
let interfaces = &self.interfaces;
match interfaces.len() {
0 => None,
1 => Some(&interfaces[0]),
_ => {
let first = &interfaces[0];
for redefined in &interfaces[1..] {
emit_error!(
redefined.module, "The attribute `messages` is redefined";
note = first.module.span() => "Previous definition of the attribute `messsages`";
note = "Only one `messages` attribute can exist on an interface implementation on contract"
);
}
None
}
}
}
}
Loading

0 comments on commit f046d9e

Please sign in to comment.