diff --git a/solidity_parser/parser.py b/solidity_parser/parser.py index 10fb82f..e19652c 100644 --- a/solidity_parser/parser.py +++ b/solidity_parser/parser.py @@ -128,6 +128,7 @@ def visitTypeDefinition(self, ctx): return Node(ctx=ctx, type="TypeDefinition", typeKeyword=ctx.TypeKeyword().getText(), + name=ctx.identifier().getText(), elementaryTypeName=self.visit(ctx.elementaryTypeName())) @@ -467,7 +468,9 @@ def visitModifierDefinition(self, ctx): type='ModifierDefinition', name=ctx.identifier().getText(), parameters=parameters, - body=self.visit(ctx.block())) + body=self.visit(ctx.block()), + isVirtual=ctx.VirtualKeyword() is not None, + isOverride=ctx.overrideSpecifier() is not None) def visitStatement(self, ctx): return self.visit(ctx.getChild(0)) @@ -485,6 +488,13 @@ def visitRevertStatement(self, ctx): type='RevertStatement', functionCall=self.visit(ctx.functionCall())) + def _index_of_child(self, ctx, child): + for i in range(0, len(ctx.children)): + if ctx.getChild(i).getText() == child: + return i + + return None + def visitExpression(self, ctx): children_length = len(ctx.children) @@ -599,12 +609,28 @@ def visitExpression(self, ctx): arguments=args, names=names) - if ctx.getChild(1).getText() == '[' and ctx.getChild(3).getText() == ']': + if (ctx.getChild(1).getText() == '[' and + ctx.getChild(2).getText() != ':' and + ctx.getChild(3).getText() == ']'): return Node(ctx=ctx, type='IndexAccess', base=self.visit(ctx.getChild(0)), index=self.visit(ctx.getChild(2))) + if ctx.getChild(1).getText() == '{' and ctx.getChild(3).getText() == '}': + args = [] + names = [] + + for nameValue in ctx.nameValueList().nameValue(): + args.append(self.visit(nameValue.expression())) + names.append(nameValue.identifier().getText()) + + return Node(ctx=ctx, + type='FunctionCallOptions', + expression=self.visit(ctx.getChild(0)), + arguments=args, + names=names) + elif children_length == 5: # ternary if ctx.getChild(1).getText() == '?' and ctx.getChild(3).getText() == ':': @@ -614,6 +640,30 @@ def visitExpression(self, ctx): TrueExpression=self.visit(ctx.getChild(2)), FalseExpression=self.visit(ctx.getChild(4))) + if 4 <= children_length <= 6 and ctx.getChild(1).getText() == '[': + left_bracket_index = self._index_of_child(ctx, '[') + colon_index = self._index_of_child(ctx, ':') + right_bracket_index = self._index_of_child(ctx, ']') + + if (left_bracket_index == 1 and + left_bracket_index < colon_index <= left_bracket_index + 2 and + colon_index < right_bracket_index <= colon_index + 2 and + right_bracket_index == children_length - 1): + indexLower = None + indexUpper = None + + if colon_index == left_bracket_index + 2: + indexLower = self.visit(ctx.getChild(left_bracket_index + 1)) + + if right_bracket_index == colon_index + 2: + indexUpper = self.visit(ctx.getChild(colon_index + 1)) + + return Node(ctx=ctx, + type='IndexRangeAccess', + base=self.visit(ctx.getChild(0)), + indexLower=indexLower, + indexUpper=indexUpper) + return self.visit(list(ctx.getChildren())) @@ -640,6 +690,10 @@ def visitStateVariableDeclaration(self, ctx): if ctx.ConstantKeyword(0): isDeclaredConst = True + isDeclaredImmutable = False + if ctx.ImmutableKeyword(0): + isDeclaredImmutable = True + decl = self._createNode( ctx=ctx, type='VariableDeclaration', @@ -649,6 +703,7 @@ def visitStateVariableDeclaration(self, ctx): visibility=visibility, isStateVar=True, isDeclaredConst=isDeclaredConst, + isDeclaredImmutable=isDeclaredImmutable, isIndexed=False) return Node(ctx=ctx, @@ -662,13 +717,15 @@ def visitForStatement(self, ctx): if conditionExpression: conditionExpression = conditionExpression.expression + loopExpression = Node(ctx=ctx, + type='ExpressionStatement', + expression=self.visit(ctx.expression())) if ctx.expression() else None + return Node(ctx=ctx, type='ForStatement', initExpression=self.visit(ctx.simpleStatement()), conditionExpression=conditionExpression, - loopExpression=Node(ctx=ctx, - type='ExpressionStatement', - expression=self.visit(ctx.expression())), + loopExpression=loopExpression, body=self.visit(ctx.statement()) ) @@ -741,16 +798,22 @@ def visitIdentifierList(self, ctx: SolidityParser.IdentifierListContext): def visitVariableDeclarationList(self, ctx: SolidityParser.VariableDeclarationListContext): result = [] for decl in self._mapCommasToNulls(ctx.children): - if decl == None: - return None + if decl is None: + result.append(None) + else: + storageLocation = None - result.append(self._createNode(ctx=ctx, - type='VariableDeclaration', - name=decl.identifier().getText(), - typeName=self.visit(decl.typeName()), - isStateVar=False, - isIndexed=False, - decl=decl)) + if decl.storageLocation(): + storageLocation = decl.storageLocation().getText() + + result.append(self._createNode(ctx=ctx, + type='VariableDeclaration', + name=decl.identifier().getText(), + typeName=self.visit(decl.typeName()), + storageLocation=storageLocation, + isStateVar=False, + isIndexed=False, + decl=decl)) return result @@ -845,9 +908,16 @@ def visitAssemblyExpression(self, ctx): return self.visit(ctx.getChild(0)) def visitAssemblyMember(self, ctx): + identifier = ctx.identifier() + + if isinstance(identifier, list): + name = [n.getText() for n in identifier] + else: + name = identifier.getText() + return Node(ctx=ctx, type='AssemblyMember', - name=ctx.identifier().getText()) + name=name) def visitAssemblyCall(self, ctx): functionName = ctx.getChild(0).getText() @@ -935,8 +1005,10 @@ def visitAssemblyAssignment(self, ctx): if names.identifier(): names = [self.visit(names.identifier())] - else: + elif names.assemblyIdentifierList(): names = self.visit(names.assemblyIdentifierList().identifier()) + else: + names = self.visit(names.assemblyMember()) return Node(ctx=ctx, type='AssemblyAssignment', @@ -1016,7 +1088,17 @@ def visitUserDefinedTypename(self, ctx): name=ctx.getText()) def visitReturnStatement(self, ctx): - return self.visit(ctx.expression()) + return Node(ctx=ctx, + type="ReturnStatement", + expression=self.visit(ctx.expression())) + + def visitBreakStatement(self, ctx): + return Node(ctx=ctx, + type="BreakStatement") + + def visitContinueStatement(self, ctx): + return Node(ctx=ctx, + type="ContinueStatement") def visitTerminal(self, ctx): return ctx.getText()