diff --git a/tests/ast_utils.py b/tests/ast_utils.py new file mode 100644 index 0000000000..e4be35adb2 --- /dev/null +++ b/tests/ast_utils.py @@ -0,0 +1,25 @@ +from vyper.ast.nodes import VyperNode + + +def deepequals(node: VyperNode, other: VyperNode): + # checks two nodes are recursively equal, ignoring metadata + # like line info. + if not isinstance(other, type(node)): + return False + + if isinstance(node, list): + if len(node) != len(other): + return False + return all(deepequals(a, b) for a, b in zip(node, other)) + + if not isinstance(node, VyperNode): + return node == other + + if getattr(node, "node_id", None) != getattr(other, "node_id", None): + return False + for field_name in (i for i in node.get_fields() if i not in VyperNode.__slots__): + lhs = getattr(node, field_name, None) + rhs = getattr(other, field_name, None) + if not deepequals(lhs, rhs): + return False + return True diff --git a/tests/unit/ast/nodes/test_binary.py b/tests/unit/ast/nodes/test_binary.py index d7662bc4bb..4bebe0abc2 100644 --- a/tests/unit/ast/nodes/test_binary.py +++ b/tests/unit/ast/nodes/test_binary.py @@ -1,5 +1,6 @@ import pytest +from tests.ast_utils import deepequals from vyper import ast as vy_ast from vyper.exceptions import SyntaxException @@ -18,7 +19,7 @@ def x(): """ ) - assert expected == mutated + assert deepequals(expected, mutated) def test_binary_length(): diff --git a/tests/unit/ast/nodes/test_compare_nodes.py b/tests/unit/ast/nodes/test_compare_nodes.py index 164cd3d371..d228e40bd1 100644 --- a/tests/unit/ast/nodes/test_compare_nodes.py +++ b/tests/unit/ast/nodes/test_compare_nodes.py @@ -1,3 +1,4 @@ +from tests.ast_utils import deepequals from vyper import ast as vy_ast @@ -6,21 +7,21 @@ def test_compare_different_node_clases(): left = vyper_ast.body[0].target right = vyper_ast.body[0].value - assert left != right + assert not deepequals(left, right) def test_compare_different_nodes_same_class(): vyper_ast = vy_ast.parse_to_ast("[1, 2]") left, right = vyper_ast.body[0].value.elements - assert left != right + assert not deepequals(left, right) def test_compare_different_nodes_same_value(): vyper_ast = vy_ast.parse_to_ast("[1, 1]") left, right = vyper_ast.body[0].value.elements - assert left != right + assert not deepequals(left, right) def test_compare_similar_node(): @@ -28,11 +29,11 @@ def test_compare_similar_node(): left = vy_ast.Int(value=1) right = vy_ast.Int(value=1) - assert left == right + assert deepequals(left, right) def test_compare_same_node(): vyper_ast = vy_ast.parse_to_ast("42") node = vyper_ast.body[0].value - assert node == node + assert deepequals(node, node) diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 196b1e24e6..cfad0795bc 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -1,6 +1,7 @@ import copy import json +from tests.ast_utils import deepequals from vyper import compiler from vyper.ast.nodes import NODE_SRC_ATTRIBUTES from vyper.ast.parse import parse_to_ast @@ -138,7 +139,7 @@ def test() -> int128: new_dict = json.loads(out_json) new_ast = dict_to_ast(new_dict) - assert new_ast == original_ast + assert deepequals(new_ast, original_ast) # strip source annotations like lineno, we don't care for inspecting diff --git a/tests/unit/ast/test_parser.py b/tests/unit/ast/test_parser.py index e0bfcbc2ef..96df6cf245 100644 --- a/tests/unit/ast/test_parser.py +++ b/tests/unit/ast/test_parser.py @@ -1,3 +1,4 @@ +from tests.ast_utils import deepequals from vyper.ast.parse import parse_to_ast @@ -12,7 +13,7 @@ def test() -> int128: ast1 = parse_to_ast(code) ast2 = parse_to_ast("\n \n" + code + "\n\n") - assert ast1 == ast2 + assert deepequals(ast1, ast2) def test_ast_unequal(): @@ -32,4 +33,4 @@ def test() -> int128: ast1 = parse_to_ast(code1) ast2 = parse_to_ast(code2) - assert ast1 != ast2 + assert not deepequals(ast1, ast2) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index ccc80947e4..3c8feec786 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -331,26 +331,10 @@ def get_fields(cls) -> set: slot_fields = [x for i in cls.__mro__ for x in getattr(i, "__slots__", [])] return set(i for i in slot_fields if not i.startswith("_")) - def __hash__(self): - values = [getattr(self, i, None) for i in VyperNode._public_slots] - return hash(tuple(values)) - def __deepcopy__(self, memo): # default implementation of deepcopy is a hotspot return pickle.loads(pickle.dumps(self)) - def __eq__(self, other): - # CMC 2024-03-03 I'm not sure it makes much sense to compare AST - # nodes, especially if they come from other modules - if not isinstance(other, type(self)): - return False - if getattr(other, "node_id", None) != getattr(self, "node_id", None): - return False - for field_name in (i for i in self.get_fields() if i not in VyperNode.__slots__): - if getattr(self, field_name, None) != getattr(other, field_name, None): - return False - return True - def __repr__(self): cls = type(self) class_repr = f"{cls.__module__}.{cls.__qualname__}" diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py index 148205f5f8..7f02bce79d 100644 --- a/vyper/semantics/analysis/imports.py +++ b/vyper/semantics/analysis/imports.py @@ -59,8 +59,7 @@ def push_path(self, module_ast: vy_ast.Module) -> None: def pop_path(self, expected: vy_ast.Module) -> None: popped = self._path.pop() - if expected != popped: - raise CompilerPanic("unreachable") + assert expected is popped, "unreachable" self._imports.pop() @contextlib.contextmanager @@ -78,7 +77,7 @@ def __init__(self, input_bundle: InputBundle, graph: _ImportGraph): self.graph = graph self._ast_of: dict[int, vy_ast.Module] = {} - self.seen: set[int] = set() + self.seen: set[vy_ast.Module] = set() self._integrity_sum = None @@ -103,7 +102,7 @@ def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module): return sha256sum("".join(acc)) def _resolve_imports_r(self, module_ast: vy_ast.Module): - if id(module_ast) in self.seen: + if module_ast in self.seen: return with self.graph.enter_path(module_ast): for node in module_ast.body: @@ -112,7 +111,8 @@ def _resolve_imports_r(self, module_ast: vy_ast.Module): self._handle_Import(node) elif isinstance(node, vy_ast.ImportFrom): self._handle_ImportFrom(node) - self.seen.add(id(module_ast)) + + self.seen.add(module_ast) def _handle_Import(self, node: vy_ast.Import): # import x.y[name] as y[alias]