From d94f54819efa7e28b07f4ed4da73d38ff51d6a43 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 21 May 2024 16:52:48 -0400 Subject: [PATCH] wip: implement methods for structs test contract: ``` struct Foo: user: address balance: uint256 def decrement_balance(self, amount: uint256): self.balance -= amount @external def foo(f: Foo): s: Foo = f s.decrement_balance(1) ``` --- vyper/codegen/expr.py | 17 ++++++------- vyper/codegen/module.py | 2 ++ vyper/codegen/self_call.py | 13 ++++++++-- vyper/semantics/analysis/module.py | 7 +++--- vyper/semantics/types/function.py | 21 ++++++++++++++++ vyper/semantics/types/subscriptable.py | 6 +++-- vyper/semantics/types/user.py | 33 +++++++++++++++----------- 7 files changed, 70 insertions(+), 29 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 49c0714110..1ce1c645aa 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -677,21 +677,22 @@ def parse_Call(self): return arg_ir if isinstance(func_t, MemberFunctionT): - darray = Expr(self.expr.func.value, self.context).ir_node - assert isinstance(darray.typ, DArrayT) + ptr = Expr(self.expr.func.value, self.context).ir_node + + if isinstance(ptr.typ, StructT): + return self_call.ir_for_self_call(self.expr, self.context, ptr=ptr) + + assert isinstance(ptr.typ, DArrayT) args = [Expr(x, self.context).ir_node for x in self.expr.args] if self.expr.func.attr == "pop": # TODO consider moving this to builtins - darray = Expr(self.expr.func.value, self.context).ir_node assert len(self.expr.args) == 0 return_item = not self.is_stmt - return pop_dyn_array(darray, return_popped_item=return_item) + return pop_dyn_array(ptr, return_popped_item=return_item) elif self.expr.func.attr == "append": (arg,) = args - check_assign( - dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) - ) - return append_dyn_array(darray, arg) + check_assign(dummy_node_for_type(ptr.typ.value_type), dummy_node_for_type(arg.typ)) + return append_dyn_array(ptr, arg) assert isinstance(func_t, ContractFunctionT) assert func_t.is_internal or func_t.is_constructor diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 1844569138..da35f25709 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -31,6 +31,8 @@ def _runtime_reachable_functions(module_t, id_generator): for fn_t in ret: id_generator.ensure_id(fn_t) + print("ENTER", ret) + return ret diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index 1114dd46cc..dc15a1b5a2 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -1,7 +1,10 @@ +from typing import Optional + from vyper.codegen.core import _freshname, eval_once_check, make_setter from vyper.codegen.ir_node import IRnode from vyper.evm.address_space import MEMORY from vyper.exceptions import StateAccessViolation +from vyper.semantics.types.function import MemberFunctionT from vyper.semantics.types.subscriptable import TupleT @@ -20,7 +23,7 @@ def _align_kwargs(func_t, args_ir): return [i.default_value for i in unprovided_kwargs] -def ir_for_self_call(stmt_expr, context): +def ir_for_self_call(stmt_expr, context, ptr: Optional[IRnode] = None): from vyper.codegen.expr import Expr # TODO rethink this circular import # ** Internal Call ** @@ -39,7 +42,10 @@ def ir_for_self_call(stmt_expr, context): default_vals_ir = [Expr(x, context).ir_node for x in default_vals] args_ir = pos_args_ir + default_vals_ir - assert len(args_ir) == len(func_t.arguments) + if isinstance(func_t, MemberFunctionT): + assert len(args_ir) == len(func_t.arg_types) + else: + assert len(args_ir) == len(func_t.arguments) args_tuple_t = TupleT([x.typ for x in args_ir]) args_as_tuple = IRnode.from_list(["multi"] + [x for x in args_ir], typ=args_tuple_t) @@ -89,6 +95,9 @@ def ir_for_self_call(stmt_expr, context): copy_args = make_setter(args_dst, args_as_tuple) goto_op = ["goto", func_t._ir_info.internal_function_label(context.is_ctor_context)] + if ptr is not None: + goto_op += [ptr] + # pass return buffer to subroutine if return_buffer is not None: goto_op += [return_buffer] diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 63eafdbaf4..c7cf4e3b42 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -127,15 +127,16 @@ def _analyze_call_graph(module_ast: vy_ast.Module): for call in function_calls: try: call_t = get_exact_type_from_node(call.func) - except VyperException: + except VyperException as e: # there is a problem getting the call type. this might be # an issue, but it will be handled properly later. right now # we just want to be able to construct the call graph. + print("ENTER", e, call) continue - if isinstance(call_t, ContractFunctionT) and ( + if isinstance(call_t, MemberFunctionT) or (isinstance(call_t, ContractFunctionT) and ( call_t.is_internal or call_t.is_constructor - ): + )): fn_t.called_functions.add(call_t) for func in function_defs: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 7eab0958a6..85489d1027 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -848,6 +848,15 @@ def __init__( self.return_type = return_type self.is_modifying = is_modifying + self._ir_info = None + + @classmethod + def from_FunctionDef(cls, structname: str, funcdef: vy_ast.FunctionDef): + args = funcdef.args.args[1:] + argtypes = [type_from_annotation(arg.annotation) for arg in args] + return_type = _parse_return_type(funcdef) + return cls(structname, funcdef.name, argtypes, return_type, True) + @property def modifiability(self): return Modifiability.MODIFIABLE if self.is_modifying else Modifiability.RUNTIME_CONSTANT @@ -856,6 +865,18 @@ def modifiability(self): def _id(self): return self.name + @property + def n_positional_args(self): + return len(self.arg_types) + + @property + def n_total_args(self): + return self.n_positional_args + + @property + def keyword_args(self): + return [] + def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 5144952be8..077a630af4 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -236,8 +236,10 @@ def __init__(self, value_type: VyperType, length: int) -> None: from vyper.semantics.types.function import MemberFunctionT - self.add_member("append", MemberFunctionT(self, "append", [self.value_type], None, True)) - self.add_member("pop", MemberFunctionT(self, "pop", [], self.value_type, True)) + self.add_member( + "append", MemberFunctionT(self._id, "append", [self.value_type], None, True) + ) + self.add_member("pop", MemberFunctionT(self._id, "pop", [], self.value_type, True)) def __repr__(self): return f"DynArray[{self.value_type}, {self.length}]" diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index a6ee646e62..25d10c2d94 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -19,6 +19,7 @@ from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType +from vyper.semantics.types.function import MemberFunctionT from vyper.semantics.types.subscriptable import HashMapT from vyper.semantics.types.utils import type_from_abi, type_from_annotation from vyper.utils import keccak256 @@ -324,14 +325,11 @@ def tuple_keys(self): return [k for (k, _v) in self.tuple_items()] def tuple_items(self): - return list(self.members.items()) + return list(self.member_types.items()) @cached_property def member_types(self): - """ - Alias to match TupleT API without shadowing `members` on TupleT - """ - return self.members + return {k: v for k, v in self.members.items() if not isinstance(v, MemberFunctionT)} @classmethod def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": @@ -351,23 +349,30 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": struct_name = base_node.name members: dict[str, VyperType] = {} for node in base_node.body: - if not isinstance(node, vy_ast.AnnAssign): + if not isinstance(node, (vy_ast.AnnAssign, vy_ast.FunctionDef)): raise StructureException( - "Struct declarations can only contain variable definitions", node + "Struct declarations can only contain variable or function definitions", node ) - if node.value is not None: - raise StructureException("Cannot assign a value during struct declaration", node) - if not isinstance(node.target, vy_ast.Name): - raise StructureException("Invalid syntax for struct member name", node.target) - member_name = node.target.id + if isinstance(node, vy_ast.AnnAssign): + if node.value is not None: + raise StructureException( + "Cannot assign a value during struct declaration", node + ) + if not isinstance(node.target, vy_ast.Name): + raise StructureException("Invalid syntax for struct member name", node.target) + member_name = node.target.id + typ = type_from_annotation(node.annotation) + else: + member_name = node.name + typ = MemberFunctionT.from_FunctionDef(struct_name, node) if member_name in members: # TODO: add prev_decl raise NamespaceCollision( - f"struct member '{member_name}' has already been declared", node.value + f"struct member '{member_name}' has already been declared", node ) - members[member_name] = type_from_annotation(node.annotation) + members[member_name] = typ return cls(struct_name, members, ast_def=base_node)