Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 99 additions & 17 deletions solidity_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def visitTypeDefinition(self, ctx):
return Node(ctx=ctx,
type="TypeDefinition",
typeKeyword=ctx.TypeKeyword().getText(),
name=ctx.identifier().getText(),
elementaryTypeName=self.visit(ctx.elementaryTypeName()))


Expand Down Expand Up @@ -467,7 +468,9 @@ def visitModifierDefinition(self, ctx):
type='ModifierDefinition',
name=ctx.identifier().getText(),
parameters=parameters,
body=self.visit(ctx.block()))
body=self.visit(ctx.block()),
isVirtual=ctx.VirtualKeyword() is not None,
isOverride=ctx.overrideSpecifier() is not None)

def visitStatement(self, ctx):
return self.visit(ctx.getChild(0))
Expand All @@ -485,6 +488,13 @@ def visitRevertStatement(self, ctx):
type='RevertStatement',
functionCall=self.visit(ctx.functionCall()))

def _index_of_child(self, ctx, child):
for i in range(0, len(ctx.children)):
if ctx.getChild(i).getText() == child:
return i

return None

def visitExpression(self, ctx):

children_length = len(ctx.children)
Expand Down Expand Up @@ -599,12 +609,28 @@ def visitExpression(self, ctx):
arguments=args,
names=names)

if ctx.getChild(1).getText() == '[' and ctx.getChild(3).getText() == ']':
if (ctx.getChild(1).getText() == '[' and
ctx.getChild(2).getText() != ':' and
ctx.getChild(3).getText() == ']'):
return Node(ctx=ctx,
type='IndexAccess',
base=self.visit(ctx.getChild(0)),
index=self.visit(ctx.getChild(2)))

if ctx.getChild(1).getText() == '{' and ctx.getChild(3).getText() == '}':
args = []
names = []

for nameValue in ctx.nameValueList().nameValue():
args.append(self.visit(nameValue.expression()))
names.append(nameValue.identifier().getText())

return Node(ctx=ctx,
type='FunctionCallOptions',
expression=self.visit(ctx.getChild(0)),
arguments=args,
names=names)

elif children_length == 5:
# ternary
if ctx.getChild(1).getText() == '?' and ctx.getChild(3).getText() == ':':
Expand All @@ -614,6 +640,30 @@ def visitExpression(self, ctx):
TrueExpression=self.visit(ctx.getChild(2)),
FalseExpression=self.visit(ctx.getChild(4)))

if 4 <= children_length <= 6 and ctx.getChild(1).getText() == '[':
left_bracket_index = self._index_of_child(ctx, '[')
colon_index = self._index_of_child(ctx, ':')
right_bracket_index = self._index_of_child(ctx, ']')

if (left_bracket_index == 1 and
left_bracket_index < colon_index <= left_bracket_index + 2 and
colon_index < right_bracket_index <= colon_index + 2 and
right_bracket_index == children_length - 1):
indexLower = None
indexUpper = None

if colon_index == left_bracket_index + 2:
indexLower = self.visit(ctx.getChild(left_bracket_index + 1))

if right_bracket_index == colon_index + 2:
indexUpper = self.visit(ctx.getChild(colon_index + 1))

return Node(ctx=ctx,
type='IndexRangeAccess',
base=self.visit(ctx.getChild(0)),
indexLower=indexLower,
indexUpper=indexUpper)

return self.visit(list(ctx.getChildren()))


Expand All @@ -640,6 +690,10 @@ def visitStateVariableDeclaration(self, ctx):
if ctx.ConstantKeyword(0):
isDeclaredConst = True

isDeclaredImmutable = False
if ctx.ImmutableKeyword(0):
isDeclaredImmutable = True

decl = self._createNode(
ctx=ctx,
type='VariableDeclaration',
Expand All @@ -649,6 +703,7 @@ def visitStateVariableDeclaration(self, ctx):
visibility=visibility,
isStateVar=True,
isDeclaredConst=isDeclaredConst,
isDeclaredImmutable=isDeclaredImmutable,
isIndexed=False)

return Node(ctx=ctx,
Expand All @@ -662,13 +717,15 @@ def visitForStatement(self, ctx):
if conditionExpression:
conditionExpression = conditionExpression.expression

loopExpression = Node(ctx=ctx,
type='ExpressionStatement',
expression=self.visit(ctx.expression())) if ctx.expression() else None

return Node(ctx=ctx,
type='ForStatement',
initExpression=self.visit(ctx.simpleStatement()),
conditionExpression=conditionExpression,
loopExpression=Node(ctx=ctx,
type='ExpressionStatement',
expression=self.visit(ctx.expression())),
loopExpression=loopExpression,
body=self.visit(ctx.statement())
)

Expand Down Expand Up @@ -741,16 +798,22 @@ def visitIdentifierList(self, ctx: SolidityParser.IdentifierListContext):
def visitVariableDeclarationList(self, ctx: SolidityParser.VariableDeclarationListContext):
result = []
for decl in self._mapCommasToNulls(ctx.children):
if decl == None:
return None
if decl is None:
result.append(None)
else:
storageLocation = None

result.append(self._createNode(ctx=ctx,
type='VariableDeclaration',
name=decl.identifier().getText(),
typeName=self.visit(decl.typeName()),
isStateVar=False,
isIndexed=False,
decl=decl))
if decl.storageLocation():
storageLocation = decl.storageLocation().getText()

result.append(self._createNode(ctx=ctx,
type='VariableDeclaration',
name=decl.identifier().getText(),
typeName=self.visit(decl.typeName()),
storageLocation=storageLocation,
isStateVar=False,
isIndexed=False,
decl=decl))

return result

Expand Down Expand Up @@ -845,9 +908,16 @@ def visitAssemblyExpression(self, ctx):
return self.visit(ctx.getChild(0))

def visitAssemblyMember(self, ctx):
identifier = ctx.identifier()

if isinstance(identifier, list):
name = [n.getText() for n in identifier]
else:
name = identifier.getText()

return Node(ctx=ctx,
type='AssemblyMember',
name=ctx.identifier().getText())
name=name)

def visitAssemblyCall(self, ctx):
functionName = ctx.getChild(0).getText()
Expand Down Expand Up @@ -935,8 +1005,10 @@ def visitAssemblyAssignment(self, ctx):

if names.identifier():
names = [self.visit(names.identifier())]
else:
elif names.assemblyIdentifierList():
names = self.visit(names.assemblyIdentifierList().identifier())
else:
names = self.visit(names.assemblyMember())

return Node(ctx=ctx,
type='AssemblyAssignment',
Expand Down Expand Up @@ -1016,7 +1088,17 @@ def visitUserDefinedTypename(self, ctx):
name=ctx.getText())

def visitReturnStatement(self, ctx):
return self.visit(ctx.expression())
return Node(ctx=ctx,
type="ReturnStatement",
expression=self.visit(ctx.expression()))

def visitBreakStatement(self, ctx):
return Node(ctx=ctx,
type="BreakStatement")

def visitContinueStatement(self, ctx):
return Node(ctx=ctx,
type="ContinueStatement")

def visitTerminal(self, ctx):
return ctx.getText()
Expand Down