Skip to content

Commit

Permalink
[ruby_ast_gen] Handling for Singleton & Anon Classes (#5006)
Browse files Browse the repository at this point in the history
* Improved runner stability and the handling of class fields

* [ruby_ast_gen] Handling for Singleton & Anon Classes
Additionally, adds handling for range operators, and lowers hash arguments in calls to named arguments so that they stay consistent with the current ANTLR interpretation of these args.

* Roll back AstGenRunner changes

* Remove line for diff
  • Loading branch information
DavidBakerEffendi authored Oct 17, 2024
1 parent 3a432d8 commit 009ac15
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ object RubyIntermediateAst {
extends RubyExpression(span)
with RubyStatement {

/** Wraps this node's `body` in a single-element [[StatementList]] that carries this node's own `span`. */
def toStatementList: StatementList = StatementList(List(body))(span)

def toMethodDeclaration(name: String, parameters: Option[List[RubyExpression]]): MethodDeclaration =
parameters match {
case Some(givenParameters) => MethodDeclaration(name, givenParameters, body)(span)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@ import better.files.File
import io.joern.rubysrc2cpg.Config
import io.joern.x2cpg.astgen.AstGenRunner.{AstGenProgramMetaData, executableDir}
import io.joern.x2cpg.astgen.AstGenRunnerBase
import io.joern.x2cpg.utils.ExternalCommand
import org.jruby.{Ruby, RubyHash, RubyInstanceConfig, RubyRuntimeAdapter}
import org.jruby.javasupport.JavaEmbedUtils
import org.jruby.RubyInstanceConfig
import org.slf4j.LoggerFactory

import java.io.{ByteArrayOutputStream, PrintStream}
import java.io.File.separator
import java.nio.file.{Files, Paths}
import java.io.{ByteArrayOutputStream, PrintStream}
import scala.collection.mutable
import scala.jdk.CollectionConverters.*
import scala.util.{Failure, Success, Try}
Expand Down Expand Up @@ -78,6 +75,7 @@ class RubyAstGenRunner(config: Config) extends AstGenRunnerBase(config) {
config.setEnvironment(Map("GEM_PATH" -> gemPath, "GEM_FILE" -> gemPath).asJava)
config.setHasShebangLine(true)
config.setScriptFileName(mainScript)
config.setHardExit(false)

try {
org.jruby.Main(config).run(Array.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ object ParserKeys {
val Condition = "condition"
val ElseClause = "else_clause"
val ElseBranch = "else_branch"
val End = "end"
val ExecList = "exec_list"
val ExecVar = "exec_var"
val FilePath = "file_path"
Expand All @@ -37,6 +38,7 @@ object ParserKeys {
val Right = "right"
val Rhs = "rhs"
val Statement = "statement"
val Start = "start"
val SuperClass = "superclass"
val ThenBranch = "then_branch"
val Type = "type"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
package io.joern.rubysrc2cpg.parser

import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
AllowedTypeDeclarationChild,
ClassFieldIdentifier,
MemberAccess,
MethodDeclaration,
RubyExpression,
RubyFieldIdentifier,
SelfIdentifier,
SimpleIdentifier,
SingleAssignment,
StatementList,
StaticLiteral,
TextSpan,
TypeDeclBodyCall
}
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import upickle.core.*
import upickle.default.*

Expand Down Expand Up @@ -44,26 +50,76 @@ object RubyJsonHelpers {

}

/** Creates a [[StaticLiteral]] typed as the built-in `NilClass`.
  *
  * @param span
  *   the text span attached to the created literal.
  */
protected def nilLiteral(span: TextSpan): StaticLiteral = {
  val nilType = getBuiltInType(Defines.NilClass)
  StaticLiteral(nilType)(span)
}

def createClassBodyAndFields(
obj: ujson.Obj
)(implicit visit: ujson.Value => RubyExpression): (StatementList, List[RubyExpression & RubyFieldIdentifier]) = {

def createBodyMethod(fieldStatements: List[ujson.Obj]): MethodDeclaration = {
MethodDeclaration(
Defines.TypeDeclBody,
Nil,
StatementList(fieldStatements.map(visit))(obj.toTextSpan.spanStart(s"(...)"))
)(obj.toTextSpan.spanStart(s"def <body>; (...); end"))
/** Builds the synthetic `Defines.TypeDeclBody` method that hosts the given statements.
  *
  * Bare identifiers and field identifiers are turned into `<field> = nil` assignments, assignments to a plain
  * identifier have their left-hand side replaced by a [[ClassFieldIdentifier]], and everything else is passed
  * through unchanged.
  *
  * @param fieldStatements
  *   the statements to place inside the body method.
  * @return
  *   a parameterless method declaration wrapping the rewritten statements.
  */
def bodyMethod(fieldStatements: List[RubyExpression]): MethodDeclaration = {

  // Span for a synthetic nil-initialiser, anchored at the start of the field's span.
  // NOTE(review): this interpolates the span object itself rather than its text; confirm that the string form of
  // TextSpan is the intended source text for the generated assignment.
  def nilAssignmentSpan(fieldSpan: TextSpan) = fieldSpan.spanStart(s"$fieldSpan = nil")

  val loweredStatements = fieldStatements.map {
    case identifier: SimpleIdentifier =>
      SingleAssignment(ClassFieldIdentifier()(identifier.span), "=", nilLiteral(identifier.span))(
        nilAssignmentSpan(identifier.span)
      )
    case declaredField: RubyFieldIdentifier =>
      SingleAssignment(declaredField, "=", nilLiteral(declaredField.span))(nilAssignmentSpan(declaredField.span))
    case fieldAssignment @ SingleAssignment(_: RubyFieldIdentifier, _, _) =>
      fieldAssignment
    case identifierAssignment @ SingleAssignment(target: SimpleIdentifier, _, _) =>
      identifierAssignment.copy(lhs = ClassFieldIdentifier()(target.span))(identifierAssignment.span)
    case unrelated =>
      unrelated
  }

  val statementsSpan = obj.toTextSpan.spanStart("(...)")
  val methodSpan     = obj.toTextSpan.spanStart("def <body>; (...); end")
  MethodDeclaration(Defines.TypeDeclBody, Nil, StatementList(loweredStatements)(statementsSpan))(methodSpan)
}

val bodyMethod = createBodyMethod(Nil)
/** Decides whether a direct child of a class or module is treated as field-related.
  *
  * @param expr
  *   An expression that is a direct child to a class or module.
  * @return
  *   true for single assignments, simple identifiers and field identifiers, false for anything else.
  */
def isFieldStmt(expr: RubyExpression): Boolean = expr match {
  case _: SingleAssignment | _: SimpleIdentifier | _: RubyFieldIdentifier => true
  case _                                                                  => false
}

/** Extracts the field referred to by an expression, if there is one.
  *
  * Simple identifiers (bare, or on the left-hand side of a single assignment) are re-created as a
  * [[ClassFieldIdentifier]] with the identifier's span; field identifiers are returned as they are.
  *
  * @param expr
  *   An expression that is a direct child to a class or module.
  * @return
  *   the field, or None when the expression is neither an identifier nor an assignment to one.
  */
def getField(expr: RubyExpression): Option[RubyExpression & RubyFieldIdentifier] = expr match {
  case identifier: SimpleIdentifier                        => Some(ClassFieldIdentifier()(identifier.span))
  case declaredField: RubyFieldIdentifier                  => Some(declaredField)
  case SingleAssignment(target: RubyFieldIdentifier, _, _) => Some(target)
  case SingleAssignment(target: SimpleIdentifier, _, _)    => Some(ClassFieldIdentifier()(target.span))
  case _                                                   => None
}

obj.visitOption(ParserKeys.Body) match {
case Some(stmtList @ StatementList(expression :: Nil)) if expression.isInstanceOf[AllowedTypeDeclarationChild] =>
(stmtList, Nil)
case Some(stmtList @ StatementList(expression :: Nil)) if isFieldStmt(expression) =>
(StatementList(bodyMethod(expression :: Nil) :: Nil)(stmtList.span), getField(expression).toList)
case Some(stmtList: StatementList) =>
val body = stmtList.copy(statements = bodyMethod +: stmtList.statements)(stmtList.span)
(body, Nil)
case Some(expression) => (StatementList(bodyMethod :: expression :: Nil)(obj.toTextSpan), Nil)
case None => (StatementList(bodyMethod :: Nil)(obj.toTextSpan.spanStart("<empty>")), Nil)
val (fieldStmts, otherStmts) = stmtList.statements.partition(isFieldStmt)
val (typeDeclStmts, bodyStmts) = otherStmts.partition(_.isInstanceOf[AllowedTypeDeclarationChild])
val body = stmtList.copy(statements = bodyMethod(fieldStmts ++ bodyStmts) +: typeDeclStmts)(stmtList.span)
val fields = fieldStmts.flatMap(getField)
(body, fields)
case Some(expression) if isFieldStmt(expression) || !expression.isInstanceOf[AllowedTypeDeclarationChild] =>
(StatementList(bodyMethod(expression :: Nil) :: Nil)(obj.toTextSpan), getField(expression).toList)
case Some(expression) =>
(StatementList(bodyMethod(Nil) :: expression :: Nil)(obj.toTextSpan), Nil)
case None => (StatementList(bodyMethod(Nil) :: Nil)(obj.toTextSpan.spanStart("<empty>")), Nil)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ class RubyJsonToNodeCreator(
val body = visit(obj(ParserKeys.Body))
val block = Block(parameters, body)(obj.toTextSpan)
visit(obj(ParserKeys.CallName)) match {
case simpleCall: RubyCall => simpleCall.withBlock(block)
case classNew: ObjectInstantiation if classNew.target.text == "Class.new" =>
AnonymousClassDeclaration(freshClassName(obj.toTextSpan), None, block.toStatementList)(obj.toTextSpan)
case simpleCall: RubyCall =>
simpleCall.withBlock(block)
case x =>
logger.warn(s"Unexpected call type used for block ${x.getClass}, ignoring block")
x
Expand Down Expand Up @@ -263,7 +266,12 @@ class RubyJsonToNodeCreator(

private def visitExclusiveFlipFlop(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitExclusiveRange(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts an exclusive range node into a [[RangeExpression]].
  *
  * The bounds are read from the `ParserKeys.Start` and `ParserKeys.End` entries. The operator is created as
  * `RangeOperator(true)` with the span text `...` (the flag presumably marks the range as exclusive).
  */
private def visitExclusiveRange(obj: Obj): RubyExpression = {
  val lowerBound    = visit(obj(ParserKeys.Start))
  val upperBound    = visit(obj(ParserKeys.End))
  val rangeOperator = RangeOperator(true)(obj.toTextSpan.spanStart("..."))
  RangeExpression(lowerBound, upperBound, rangeOperator)(obj.toTextSpan)
}

private def visitExecutableString(obj: Obj): RubyExpression = {
val callName =
Expand All @@ -274,6 +282,11 @@ class RubyJsonToNodeCreator(

private def visitFalse(obj: Obj): RubyExpression = StaticLiteral(getBuiltInType(Defines.FalseClass))(obj.toTextSpan)

/** Converts a field declaration call into a [[FieldsDeclaration]] over the visited `ParserKeys.Arguments` entries. */
private def visitFieldDeclaration(obj: Obj): RubyExpression =
  FieldsDeclaration(obj.visitArray(ParserKeys.Arguments))(obj.toTextSpan)

private def visitFindPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitFloat(obj: Obj): RubyExpression = StaticLiteral(getBuiltInType(Defines.Float))(obj.toTextSpan)
Expand Down Expand Up @@ -326,7 +339,12 @@ class RubyJsonToNodeCreator(

private def visitInclusiveFlipFlop(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitInclusiveRange(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts an inclusive range node into a [[RangeExpression]].
  *
  * The bounds are read from the `ParserKeys.Start` and `ParserKeys.End` entries. The operator is created as
  * `RangeOperator(false)` with the span text `..` (the flag presumably marks the range as not exclusive).
  */
private def visitInclusiveRange(obj: Obj): RubyExpression = {
  val lowerBound    = visit(obj(ParserKeys.Start))
  val upperBound    = visit(obj(ParserKeys.End))
  val rangeOperator = RangeOperator(false)(obj.toTextSpan.spanStart(".."))
  RangeExpression(lowerBound, upperBound, rangeOperator)(obj.toTextSpan)
}

private def visitInPattern(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

Expand Down Expand Up @@ -425,7 +443,11 @@ class RubyJsonToNodeCreator(

private def visitOrAssign(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitPair(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts a pair node into an [[Association]] of its visited `ParserKeys.Key` and `ParserKeys.Value` entries. */
private def visitPair(obj: Obj): RubyExpression = {
  val pairKey   = visit(obj(ParserKeys.Key))
  val pairValue = visit(obj(ParserKeys.Value))
  Association(pairKey, pairValue)(obj.toTextSpan)
}

private def visitPostExpression(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

Expand Down Expand Up @@ -484,14 +506,20 @@ class RubyJsonToNodeCreator(
case "new" => visitObjectInstantiation(obj)
case "raise" => visitRaise(obj)
case "include" => visitInclude(obj)
case "attr_reader" | "attr_writer" | "attr_accessor" => visitFieldDeclaration(obj)
case requireLike if ImportCallNames.contains(requireLike) => visitRequireLike(obj)
case _ if BinaryOperators.isBinaryOperatorName(callName) =>
val lhs = visit(obj(ParserKeys.Receiver))
val rhs = obj.visitArray(ParserKeys.Arguments).head
BinaryExpression(lhs, callName, rhs)(obj.toTextSpan)
case _ =>
val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName))
val arguments = obj.visitArray(ParserKeys.Arguments)
val target = SimpleIdentifier()(obj.toTextSpan.spanStart(callName))
val argumentArr = obj.visitArray(ParserKeys.Arguments)
val arguments = argumentArr.zipWithIndex.flatMap {
case (hashLiteral: HashLiteral, idx) =>
hashLiteral.elements // a hash is likely named arguments
case (x, _) => x :: Nil
}
if (obj.contains(ParserKeys.Receiver)) {
val base = visit(obj(ParserKeys.Receiver))
MemberCall(base, ".", callName, arguments)(obj.toTextSpan)
Expand All @@ -503,12 +531,26 @@ class RubyJsonToNodeCreator(

private def visitShadowArg(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))

private def visitSingletonMethodDefinition(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts a singleton method definition node into a [[SingletonMethodDeclaration]].
  *
  * The base is visited as an expression and the name is read as a JSON string. Parameters are taken from the
  * `ParserKeys.Children` entries of the `ParserKeys.Arguments` node. A missing body is replaced by an empty
  * statement list whose span text is `<empty>`.
  */
private def visitSingletonMethodDefinition(obj: Obj): RubyExpression = {
  val receiver   = visit(obj(ParserKeys.Base))
  val methodName = obj(ParserKeys.Name).str
  // The arguments entry is cast to a JSON object; any other JSON value fails here with a ClassCastException.
  val argumentsNode = obj(ParserKeys.Arguments).asInstanceOf[ujson.Obj]
  val methodParams  = argumentsNode.visitArray(ParserKeys.Children)
  val methodBody = obj.visitOption(ParserKeys.Body) match {
    case Some(parsedBody) => parsedBody
    case None             => StatementList(Nil)(obj.toTextSpan.spanStart("<empty>"))
  }
  SingletonMethodDeclaration(receiver, methodName, methodParams, methodBody)(obj.toTextSpan)
}

private def visitSingletonClassDefinition(obj: Obj): RubyExpression = defaultResult(Option(obj.toTextSpan))
/** Converts a singleton class definition node into a [[SingletonClassDeclaration]].
  *
  * The name is visited as an expression and the super class is optional. A missing body is replaced by an empty
  * statement list whose span text is `<empty>`. The body member call is created from the text of the visited name.
  */
private def visitSingletonClassDefinition(obj: Obj): RubyExpression = {
  val className  = visit(obj(ParserKeys.Name))
  val superClass = obj.visitOption(ParserKeys.SuperClass)
  val classBody = obj.visitOption(ParserKeys.Body) match {
    case Some(parsedBody) => parsedBody
    case None             => StatementList(Nil)(obj.toTextSpan.spanStart("<empty>"))
  }
  val memberCall = createBodyMemberCall(className.text, obj.toTextSpan)
  SingletonClassDeclaration(
    name = className,
    baseClass = superClass,
    body = classBody,
    bodyMemberCall = Option(memberCall)
  )(obj.toTextSpan)
}

/** Builds a [[SingleAssignment]] with the `=` operator from the `ParserKeys.Lhs` and `ParserKeys.Rhs` entries of
  * `obj`, spanning the whole node.
  */
private def visitSingleAssignment(obj: Obj): RubyExpression = {
// NOTE(review): `lhs` is bound twice in a row below. This text looks like a rendered diff, so the first binding
// (visiting the lhs node) is presumably the removed line and the second (reading the lhs as a JSON string and
// wrapping it in a SimpleIdentifier) the added one. Confirm against the real file; only one can remain.
val lhs = visit(obj(ParserKeys.Lhs))
val lhs = SimpleIdentifier()(obj.toTextSpan.spanStart(obj(ParserKeys.Lhs).str))
val rhs = visit(obj(ParserKeys.Rhs))
SingleAssignment(lhs, "=", rhs)(obj.toTextSpan)
}
Expand Down
Loading

0 comments on commit 009ac15

Please sign in to comment.