Skip to content

Commit

Permalink
feat: remove Index AST node (vyperlang#3757)
Browse files Browse the repository at this point in the history
remove Index AST node type, it has been deprecated from the python AST
since python3.9.
  • Loading branch information
tserg authored Feb 5, 2024
1 parent f7f67d0 commit 01ec9a1
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 77 deletions.
8 changes: 3 additions & 5 deletions vyper/ast/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None:
raise CompilerPanic("Mismatch between node and input type while building getter")
if annotation.value.get("id") == "HashMap": # type: ignore
# for a HashMap, split the key/value types and use the key type as the next arg
arg, annotation = annotation.slice.value.elements # type: ignore
arg, annotation = annotation.slice.elements # type: ignore
elif annotation.value.get("id") == "DynArray":
arg = vy_ast.Name(id=type_._id)
annotation = annotation.slice.value.elements[0] # type: ignore
annotation = annotation.slice.elements[0] # type: ignore
else:
# for other types, build an input arg node from the expected type
# and remove the outer `Subscript` from the annotation
Expand All @@ -55,9 +55,7 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None:
input_nodes.append(vy_ast.arg(arg=f"arg{i}", annotation=arg))

# wrap the return statement in a `Subscript`
return_stmt = vy_ast.Subscript(
value=return_stmt, slice=vy_ast.Index(value=vy_ast.Name(id=f"arg{i}"))
)
return_stmt = vy_ast.Subscript(value=return_stmt, slice=vy_ast.Name(id=f"arg{i}"))

# after iterating the input types, the remaining annotation node is our return type
return_node = copy.copy(annotation)
Expand Down
4 changes: 0 additions & 4 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,10 +1227,6 @@ class Subscript(ExprNode):
__slots__ = ("slice", "value")


class Index(VyperNode):
__slots__ = ("value",)


class Assign(Stmt):
"""
An assignment.
Expand Down
5 changes: 1 addition & 4 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,9 @@ class Attribute(VyperNode):
value: VyperNode = ...

class Subscript(VyperNode):
slice: Index = ...
slice: VyperNode = ...
value: VyperNode = ...

class Index(VyperNode):
value: Constant = ...

class Assign(VyperNode): ...

class AnnAssign(VyperNode):
Expand Down
18 changes: 0 additions & 18 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,24 +341,6 @@ def visit_Expr(self, node):

return node

def visit_Subscript(self, node):
"""
Maintain consistency of `Subscript.slice` across python versions.
Starting from python 3.9, the `Index` node type has been deprecated,
and made impossible to instantiate via regular means. Here we do awful
hacky black magic to create an `Index` node. We need our own parser.
"""
self.generic_visit(node)

if not isinstance(node.slice, python_ast.Index):
index = python_ast.Constant(value=node.slice, ast_type="Index")
index.__class__ = python_ast.Index
self.generic_visit(index)
node.slice = index

return node

def visit_Constant(self, node):
"""
Handle `Constant` when using Python >=3.8
Expand Down
6 changes: 3 additions & 3 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,16 +356,16 @@ def parse_Subscript(self):

if isinstance(sub.typ, HashMapT):
# TODO sanity check we are in a self.my_map[i] situation
index = Expr(self.expr.slice.value, self.context).ir_node
index = Expr(self.expr.slice, self.context).ir_node
if isinstance(index.typ, _BytestringT):
# we have to hash the key to get a storage location
index = keccak256_helper(index, self.context)

elif is_array_like(sub.typ):
index = Expr.parse_value_expr(self.expr.slice.value, self.context)
index = Expr.parse_value_expr(self.expr.slice, self.context)

elif is_tuple_like(sub.typ):
index = self.expr.slice.value.n
index = self.expr.slice.n
# note: this check should also happen in get_element_ptr
if not 0 <= index < len(sub.typ.member_types):
raise TypeCheckFailure("unreachable")
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/analysis/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def visit_Call(self, node) -> vy_ast.ExprNode:
return typ._try_fold(node) # type: ignore

def visit_Subscript(self, node) -> vy_ast.ExprNode:
slice_ = node.slice.value.get_folded_value()
slice_ = node.slice.get_folded_value()
value = node.value.get_folded_value()

if not isinstance(value, vy_ast.List):
Expand Down
6 changes: 1 addition & 5 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,10 +647,6 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None:
def visit_Constant(self, node: vy_ast.Constant, typ: VyperType) -> None:
validate_expected_type(node, typ)

def visit_Index(self, node: vy_ast.Index, typ: VyperType) -> None:
validate_expected_type(node.value, typ)
self.visit(node.value, typ)

def visit_List(self, node: vy_ast.List, typ: VyperType) -> None:
assert isinstance(typ, (SArrayT, DArrayT))
for element in node.elements:
Expand Down Expand Up @@ -687,7 +683,7 @@ def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None:
# get the correct type for the index, it might
# not be exactly base_type.key_type
# note: index_type is validated in types_from_Subscript
index_types = get_possible_types_from_node(node.slice.value)
index_types = get_possible_types_from_node(node.slice)
index_type = index_types.pop()

self.visit(node.slice, index_type)
Expand Down
8 changes: 4 additions & 4 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,13 @@ def types_from_Subscript(self, node):
types_list = self.get_possible_types_from_node(node.value)
ret = []
for t in types_list:
t.validate_index_type(node.slice.value)
ret.append(t.get_subscripted_type(node.slice.value))
t.validate_index_type(node.slice)
ret.append(t.get_subscripted_type(node.slice))
return ret

t = self.get_exact_type_from_node(node.value)
t.validate_index_type(node.slice.value)
return [t.get_subscripted_type(node.slice.value)]
t.validate_index_type(node.slice)
return [t.get_subscripted_type(node.slice)]

def types_from_Tuple(self, node):
types_list = [self.get_exact_type_from_node(i) for i in node.elements]
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional["VyperType"]:
raise StructureException(f"{self} is not callable", node)

@classmethod
def get_subscripted_type(self, node: vy_ast.Index) -> None:
def get_subscripted_type(self, node: vy_ast.VyperNode) -> None:
"""
Return the type of a subscript expression, e.g. x[1]
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/types/bytestrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def compare_type(self, other):

@classmethod
def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT":
if not isinstance(node, vy_ast.Subscript) or not isinstance(node.slice, vy_ast.Index):
if not isinstance(node, vy_ast.Subscript):
raise StructureException(
f"Cannot declare {cls._id} type without a maximum length, e.g. {cls._id}[5]", node
)
Expand Down
19 changes: 7 additions & 12 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ def get_subscripted_type(self, node):
def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT":
if (
not isinstance(node, vy_ast.Subscript)
or not isinstance(node.slice, vy_ast.Index)
or not isinstance(node.slice.value, vy_ast.Tuple)
or len(node.slice.value.elements) != 2
or not isinstance(node.slice, vy_ast.Tuple)
or len(node.slice.elements) != 2
):
raise StructureException(
(
Expand All @@ -83,7 +82,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT":
node,
)

k_ast, v_ast = node.slice.value.elements
k_ast, v_ast = node.slice.elements
key_type = type_from_annotation(k_ast, DataLocation.STORAGE)
if not key_type._as_hashmap_key:
raise InvalidType("can only use primitive types as HashMap key!", k_ast)
Expand Down Expand Up @@ -198,7 +197,7 @@ def compare_type(self, other):

@classmethod
def from_annotation(cls, node: vy_ast.Subscript) -> "SArrayT":
if not isinstance(node, vy_ast.Subscript) or not isinstance(node.slice, vy_ast.Index):
if not isinstance(node, vy_ast.Subscript):
raise StructureException(
"Arrays must be defined with base type and length, e.g. bool[5]", node
)
Expand Down Expand Up @@ -280,14 +279,10 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT":
if not isinstance(node, vy_ast.Subscript):
raise StructureException(err_msg, node)

if (
not isinstance(node.slice, vy_ast.Index)
or not isinstance(node.slice.value, vy_ast.Tuple)
or len(node.slice.value.elements) != 2
):
if not isinstance(node.slice, vy_ast.Tuple) or len(node.slice.elements) != 2:
raise StructureException(err_msg, node.slice)

length_node = node.slice.value.elements[1]
length_node = node.slice.elements[1]
if length_node.has_folded_value:
length_node = length_node.get_folded_value()

Expand All @@ -296,7 +291,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT":

length = length_node.value

value_node = node.slice.value.elements[0]
value_node = node.slice.elements[0]
value_type = type_from_annotation(value_node)
if not value_type._as_darray:
raise StructureException(f"Arrays of {value_type} are not allowed", value_node)
Expand Down
34 changes: 15 additions & 19 deletions vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,14 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType:
return typ_


def get_index_value(node: vy_ast.Index) -> int:
def get_index_value(node: vy_ast.VyperNode) -> int:
"""
Return the literal value for a `Subscript` index.
Arguments
---------
node: vy_ast.Index
Vyper ast node from the `slice` member of a Subscript node. Must be an
`Index` object (Vyper does not support `Slice` or `ExtSlice`).
node: vy_ast.VyperNode
Vyper ast node from the `slice` member of a Subscript node.
Returns
-------
Expand All @@ -181,23 +180,20 @@ def get_index_value(node: vy_ast.Index) -> int:
# TODO: revisit this!
from vyper.semantics.analysis.utils import get_possible_types_from_node

value = node.get("value")
if value.has_folded_value:
value = value.get_folded_value()

if not isinstance(value, vy_ast.Int):
if hasattr(node, "value"):
# even though the subscript is an invalid type, first check if it's a valid _something_
# this gives a more accurate error in case of e.g. a typo in a constant variable name
try:
get_possible_types_from_node(node.value)
except StructureException:
# StructureException is a very broad error, better to raise InvalidType in this case
pass
if node.has_folded_value:
node = node.get_folded_value()

if not isinstance(node, vy_ast.Int):
# even though the subscript is an invalid type, first check if it's a valid _something_
# this gives a more accurate error in case of e.g. a typo in a constant variable name
try:
get_possible_types_from_node(node)
except StructureException:
# StructureException is a very broad error, better to raise InvalidType in this case
pass
raise InvalidType("Subscript must be a literal integer", node)

if value.value <= 0:
if node.value <= 0:
raise ArrayIndexException("Subscript must be greater than 0", node)

return value.value
return node.value

0 comments on commit 01ec9a1

Please sign in to comment.