Skip to content

Commit

Permalink
fix[lang]: fix certain varinfo comparisons (#4164)
Browse files Browse the repository at this point in the history
for `VarInfo`s which are declared in memory, the `VarInfo`
initialization is missing `decl_node`, and therefore different
variables with the same type get detected as overlapping in loop
iterator modification detection. this commit properly initializes
memory `VarInfo`s with the appropriate `decl_node`.

---------

Co-authored-by: cyberthirst <[email protected]>
  • Loading branch information
charles-cooper and cyberthirst authored Aug 7, 2024
1 parent 9b322d6 commit 07ddea6
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 7 deletions.
33 changes: 33 additions & 0 deletions tests/functional/codegen/features/iteration/test_for_in_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,3 +897,36 @@ def foo():
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot modify loop variable `queue`"


def test_iterator_modification_memory(get_contract):
code = """
@external
def foo() -> DynArray[uint256, 10]:
# check VarInfos are distinguished by decl_node when they have same type
alreadyDone: DynArray[uint256, 10] = []
_assets: DynArray[uint256, 10] = [1, 2, 3, 4, 3, 2, 1]
for a: uint256 in _assets:
if a in alreadyDone:
continue
alreadyDone.append(a)
return alreadyDone
"""
c = get_contract(code)
assert c.foo() == [1, 2, 3, 4]


def test_iterator_modification_func_arg(get_contract):
code = """
@internal
def boo(a: DynArray[uint256, 12] = [], b: DynArray[uint256, 12] = []) -> DynArray[uint256, 12]:
for i: uint256 in a:
b.append(i)
return b
@external
def foo() -> DynArray[uint256, 12]:
return self.boo([1, 2, 3])
"""
c = get_contract(code)
assert c.foo() == [1, 2, 3]
24 changes: 21 additions & 3 deletions tests/unit/ast/test_ast_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,13 @@ def qux2():
{
"annotation": {"ast_type": "Name", "id": "uint256"},
"ast_type": "AnnAssign",
"target": {"ast_type": "Name", "id": "x"},
"target": {
"ast_type": "Name",
"id": "x",
"variable_reads": [
{"name": "x", "decl_node": {"node_id": 15, "source_id": 0}, "access_path": []}
],
},
"value": {
"ast_type": "Attribute",
"attr": "counter",
Expand Down Expand Up @@ -1300,7 +1306,13 @@ def qux2():
{
"annotation": {"ast_type": "Name", "id": "uint256"},
"ast_type": "AnnAssign",
"target": {"ast_type": "Name", "id": "x"},
"target": {
"ast_type": "Name",
"id": "x",
"variable_reads": [
{"name": "x", "decl_node": {"node_id": 35, "source_id": 0}, "access_path": []}
],
},
"value": {
"ast_type": "Attribute",
"attr": "counter",
Expand All @@ -1317,7 +1329,13 @@ def qux2():
{
"annotation": {"ast_type": "Name", "id": "uint256"},
"ast_type": "AnnAssign",
"target": {"ast_type": "Name", "id": "y"},
"target": {
"ast_type": "Name",
"id": "y",
"variable_reads": [
{"name": "y", "decl_node": {"node_id": 44, "source_id": 0}, "access_path": []}
],
},
"value": {
"ast_type": "Attribute",
"attr": "counter",
Expand Down
5 changes: 4 additions & 1 deletion vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,10 @@ def to_dict(self):
# map SUBSCRIPT_ACCESS to `"$subscript_access"` (which is an identifier
# which can't be constructed by the user)
path = ["$subscript_access" if s is self.SUBSCRIPT_ACCESS else s for s in self.path]
varname = var.decl_node.target.id
if isinstance(var.decl_node, vy_ast.arg):
varname = var.decl_node.arg
else:
varname = var.decl_node.target.id

decl_node = var.decl_node.get_id_dict()
ret = {"name": varname, "decl_node": decl_node, "access_path": path}
Expand Down
6 changes: 3 additions & 3 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def analyze(self):

for arg in self.func.arguments:
self.namespace[arg.name] = VarInfo(
arg.typ, location=location, modifiability=modifiability
arg.typ, location=location, modifiability=modifiability, decl_node=arg.ast_source
)

for node in self.fn_node.body:
Expand Down Expand Up @@ -363,7 +363,7 @@ def visit_AnnAssign(self, node):
# validate the value before adding it to the namespace
self.expr_visitor.visit(node.value, typ)

self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY)
self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY, decl_node=node)

self.expr_visitor.visit(node.target, typ)

Expand Down Expand Up @@ -575,7 +575,7 @@ def visit_For(self, node):
target_name = node.target.target.id
# maybe we should introduce a new Modifiability: LOOP_VARIABLE
self.namespace[target_name] = VarInfo(
target_type, modifiability=Modifiability.RUNTIME_CONSTANT
target_type, modifiability=Modifiability.RUNTIME_CONSTANT, decl_node=node.target
)

self.expr_visitor.visit(node.target.target, target_type)
Expand Down

0 comments on commit 07ddea6

Please sign in to comment.