Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: allow module intrinsic interface call #4090

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
23 changes: 23 additions & 0 deletions tests/functional/codegen/modules/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,26 @@ def __init__():
# call `c.__default__()`
env.message_call(c.address)
assert c.counter() == 6


def test_inline_interface_export(make_input_bundle, get_contract):
lib1 = """
interface IAsset:
def asset() -> address: view

implements: IAsset

@external
@view
def asset() -> address:
return self
"""
main = """
import lib1

exports: lib1.IAsset
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
c = get_contract(main, input_bundle=input_bundle)

assert c.asset() == c.address
31 changes: 30 additions & 1 deletion tests/functional/codegen/modules/test_interface_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,40 @@ def foo() -> bool:
# check that this typechecks both directions
a: lib1.IERC20 = IERC20(msg.sender)
b: lib2.IERC20 = IERC20(msg.sender)
c: IERC20 = lib1.IERC20(msg.sender) # allowed in call position

# return the equality so we can sanity check it
return a == b
return a == b and b == c
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})
c = get_contract(main, input_bundle=input_bundle)

assert c.foo() is True


def test_intrinsic_interface(get_contract, make_input_bundle):
lib = """
@external
@view
def foo() -> uint256:
# detect self call
if msg.sender == self:
return 4
else:
return 5
"""
main = """
import lib

exports: lib.__interface__

@external
@view
def bar() -> uint256:
return staticcall lib.__interface__(self).foo()
"""
input_bundle = make_input_bundle({"lib.vy": lib})
c = get_contract(main, input_bundle=input_bundle)

assert c.foo() == 5
assert c.bar() == 4
22 changes: 22 additions & 0 deletions tests/functional/syntax/modules/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,28 @@ def do_xyz():
assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!"


def test_no_export_unimplemented_inline_interface(make_input_bundle):
lib1 = """
interface ifoo:
def do_xyz(): nonpayable

# technically implements ifoo, but missing `implements: ifoo`

@external
def do_xyz():
pass
"""
main = """
import lib1

exports: lib1.ifoo
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(InterfaceViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!"


def test_export_selector_conflict(make_input_bundle):
ifoo = """
@external
Expand Down
10 changes: 6 additions & 4 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
)
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace
from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT
from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT, is_type_t
from vyper.semantics.types.function import ContractFunctionT
from vyper.semantics.types.module import ModuleT
from vyper.semantics.types.utils import type_from_annotation
Expand Down Expand Up @@ -547,7 +547,9 @@ def visit_ExportsDecl(self, node):
elif isinstance(info.typ, ContractFunctionT):
# regular function
funcs = [info.typ]
elif isinstance(info.typ, InterfaceT):
elif is_type_t(info.typ, InterfaceT):
interface_t = info.typ.typedef

if not isinstance(item, vy_ast.Attribute):
raise StructureException(
"invalid export",
Expand All @@ -558,7 +560,7 @@ def visit_ExportsDecl(self, node):
if module_info is None:
raise StructureException("not a valid module!", item.value)

if info.typ not in module_info.typ.implemented_interfaces:
if interface_t not in module_info.typ.implemented_interfaces:
iface_str = item.node_source_code
module_str = item.value.node_source_code
msg = f"requested `{iface_str}` but `{module_str}`"
Expand All @@ -569,7 +571,7 @@ def visit_ExportsDecl(self, node):
# find the specific implementation of the function in the module
funcs = [
module_exposed_fns[fn.name]
for fn in info.typ.functions.values()
for fn in interface_t.functions.values()
if fn.is_external
]
else:
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _raise_invalid_reference(name, node):
try:
s = t.get_member(name, node)

if isinstance(s, (VyperType, TYPE_T)):
if isinstance(s, VyperType):
# ex. foo.bar(). bar() is a ContractFunctionT
return [s]

Expand Down
14 changes: 11 additions & 3 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,16 +323,24 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None):

for i in self.import_stmts:
import_info = i._metadata["import_info"]
self.add_member(import_info.alias, import_info.typ)

if hasattr(import_info.typ, "module_t"):
self._helper.add_member(import_info.alias, TYPE_T(import_info.typ))
module_info = import_info.typ
# get_expr_info uses ModuleInfo
self.add_member(import_info.alias, module_info)
# type_from_annotation uses TYPE_T
self._helper.add_member(import_info.alias, TYPE_T(module_info.module_t))
else: # interfaces
assert isinstance(import_info, InterfaceT)
self.add_member(import_info.alias, TYPE_T(import_info.typ))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add an assert for it being an interface?
might be useful in the future should we allow eg selective imports

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


for name, interface_t in self.interfaces.items():
# can access interfaces in type position
self._helper.add_member(name, TYPE_T(interface_t))

self.add_member("__interface__", self.interface)
# can use module.__interface__ in call position
self.add_member("__interface__", TYPE_T(self.interface))
self._helper.add_member("__interface__", TYPE_T(self.interface))

# __eq__ is very strict on ModuleT - object equality! this is because we
# don't want to reason about where a module came from (i.e. input bundle,
Expand Down
Loading