Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[lang]: fix importing of flag types #3871

Merged
merged 6 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions tests/functional/codegen/modules/test_flag_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
def test_import_flag_types(make_input_bundle, get_contract):
lib1 = """
import lib2

flag Roles:
ADMIN
USER

enum Roles2:
ADMIN
USER

role: Roles
role2: Roles2
role3: lib2.Roles3
"""
lib2 = """
flag Roles3:
ADMIN
USER
NOBODY
"""
contract = """
import lib1

initializes: lib1

@external
def bar(r: lib1.Roles, r2: lib1.Roles2, r3: lib1.lib2.Roles3) -> bool:
lib1.role = r
lib1.role2 = r2
lib1.role3 = r3
assert lib1.role == lib1.Roles.ADMIN
assert lib1.role2 == lib1.Roles2.USER
assert lib1.role3 == lib1.lib2.Roles3.NOBODY
return True
"""

input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})
c = get_contract(contract, input_bundle=input_bundle)
assert c.bar(1, 2, 4) is True
2 changes: 1 addition & 1 deletion vyper/ast/grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ tuple_def: "(" ( NAME | array_def | dyn_array_def | tuple_def ) ( "," ( NAME | a
// NOTE: Map takes a basic type and maps to another type (can be non-basic, including maps)
_MAP: "HashMap"
map_def: _MAP "[" ( NAME | array_def ) "," type "]"
imported_type: NAME "." NAME
imported_type: NAME ("." NAME)+
type: ( NAME | imported_type | array_def | tuple_def | map_def | dyn_array_def )

// Structs can be composed of 1+ basic types or other custom_types
Expand Down
9 changes: 3 additions & 6 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,9 @@ def parse_Name(self):
def parse_Attribute(self):
typ = self.expr._metadata["type"]

# MyFlag.foo
if (
isinstance(typ, FlagT)
and isinstance(self.expr.value, vy_ast.Name)
and typ.name == self.expr.value.id
):
# check if we have a flag constant, e.g.
# [lib1].MyFlag.FOO
if isinstance(typ, FlagT) and is_type_t(self.expr.value._metadata["type"], FlagT):
# 0, 1, 2, .. 255
flag_id = typ._flag_members[self.expr.attr]
value = 2**flag_id # 0 => 0001, 1 => 0010, 2 => 0100, etc.
Expand Down
1 change: 1 addition & 0 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ def _validate_self_namespace():

def visit_FlagDef(self, node):
obj = FlagT.from_FlagDef(node)
node._metadata["flag_type"] = obj
self.namespace[node.name] = obj

def visit_EventDef(self, node):
Expand Down
8 changes: 8 additions & 0 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None):
# add the type of the event so it can be used in call position
self.add_member(e.name, TYPE_T(e._metadata["event_type"])) # type: ignore

for f in self.flag_defs:
self.add_member(f.name, TYPE_T(f._metadata["flag_type"]))
self._helper.add_member(f.name, TYPE_T(f._metadata["flag_type"]))

for s in self.struct_defs:
# add the type of the struct so it can be used in call position
self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore
Expand Down Expand Up @@ -347,6 +351,10 @@ def function_defs(self):
def event_defs(self):
return self._module.get_children(vy_ast.EventDef)

@cached_property
def flag_defs(self):
return self._module.get_children(vy_ast.FlagDef)

@property
def struct_defs(self):
return self._module.get_children(vy_ast.StructDef)
Expand Down
Loading