Skip to content

Commit

Permalink
[ruby] Fix structure for ForEach loops in Ruby (#4984)
Browse files Browse the repository at this point in the history
* [ruby] Changed handling of ForEach loops

* [ruby] fixed double _astIn reference to identifier

* [ruby] fix failing tests
  • Loading branch information
AndreiDreyer authored Oct 2, 2024
1 parent 0d75bd1 commit bc002cf
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
WhenClause,
WhileExpression
}
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.nodes.NewBlock
import io.shiftleft.codepropertygraph.generated.nodes.{
NewBlock,
NewFieldIdentifier,
NewIdentifier,
NewLiteral,
NewLocal
}

trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

Expand Down Expand Up @@ -106,11 +113,101 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo
}

private def astForForExpression(node: ForExpression): Ast = {
val forEachNode = controlStructureNode(node, ControlStructureTypes.FOR, code(node))
val doBodyAst = astsForStatement(node.doBlock)
val iteratorNode = astForExpression(node.forVariable)
val iterableNode = astForExpression(node.iterableVariable)
Ast(forEachNode).withChild(iteratorNode).withChild(iterableNode).withChildren(doBodyAst)
val forEachNode = controlStructureNode(node, ControlStructureTypes.FOR, code(node))

def collectionAst = astForExpression(node.iterableVariable)
val collectionNode = node.iterableVariable

val iterIdentifier =
identifierNode(
node = node.forVariable,
name = node.forVariable.span.text,
code = node.forVariable.span.text,
typeFullName = Defines.Any
)
val iterVarLocal = NewLocal().name(node.forVariable.span.text).code(node.forVariable.span.text)
scope.addToScope(node.forVariable.span.text, iterVarLocal)

val idxName = "_idx_"
val idxLocal = NewLocal().name(idxName).code(idxName).typeFullName(Defines.getBuiltInType(Defines.Integer))
val idxIdenAtAssign = identifierNode(
node = collectionNode,
name = idxName,
code = idxName,
typeFullName = Defines.getBuiltInType(Defines.Integer)
)

val idxAssignment =
callNode(node, s"$idxName = 0", Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH)
val idxAssignmentArgs =
List(Ast(idxIdenAtAssign), Ast(NewLiteral().code("0").typeFullName(Defines.getBuiltInType(Defines.Integer))))
val idxAssignmentAst = callAst(idxAssignment, idxAssignmentArgs)

val idxIdAtCond = idxIdenAtAssign.copy
val collectionCountAccess = callNode(
node,
s"${node.iterableVariable.span.text}.length",
Operators.fieldAccess,
Operators.fieldAccess,
DispatchTypes.STATIC_DISPATCH
)
val fieldAccessAst = callAst(
collectionCountAccess,
collectionAst :: Ast(NewFieldIdentifier().canonicalName("length").code("length")) :: Nil
)

val idxLt = callNode(
node,
s"$idxName < ${node.iterableVariable.span.text}.length",
Operators.lessThan,
Operators.lessThan,
DispatchTypes.STATIC_DISPATCH
)
val idxLtArgs = List(Ast(idxIdAtCond), fieldAccessAst)
val ltCallCond = callAst(idxLt, idxLtArgs)

val idxIdAtCollAccess = idxIdenAtAssign.copy
val collectionIdxAccess = callNode(
node,
s"${node.iterableVariable.span.text}[$idxName++]",
Operators.indexAccess,
Operators.indexAccess,
DispatchTypes.STATIC_DISPATCH
)
val postIncrAst = callAst(
callNode(node, s"$idxName++", Operators.postIncrement, Operators.postIncrement, DispatchTypes.STATIC_DISPATCH),
Ast(idxIdAtCollAccess) :: Nil
)

val indexAccessAst = callAst(collectionIdxAccess, collectionAst :: postIncrAst :: Nil)
val iteratorAssignmentNode = callNode(
node,
s"${node.forVariable.span.text} = ${node.iterableVariable.span.text}[$idxName++]",
Operators.assignment,
Operators.assignment,
DispatchTypes.STATIC_DISPATCH
)
val iteratorAssignmentArgs = List(Ast(iterIdentifier), indexAccessAst)
val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs)
val doBodyAst = astsForStatement(node.doBlock)

val locals = Ast(idxLocal)
.withRefEdge(idxIdenAtAssign, idxLocal)
.withRefEdge(idxIdAtCond, idxLocal)
.withRefEdge(idxIdAtCollAccess, idxLocal) :: Ast(iterVarLocal).withRefEdge(iterIdentifier, iterVarLocal) :: Nil

val conditionAsts = ltCallCond :: Nil
val initAsts = idxAssignmentAst :: Nil
val updateAsts = iteratorAssignmentAst :: Nil

forAst(
forNode = forEachNode,
locals = locals,
initAsts = initAsts,
conditionAsts = conditionAsts,
updateAsts = updateAsts,
bodyAsts = doBodyAst
)
}

protected def astsForCaseExpression(node: CaseExpression): Seq[Ast] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,12 +413,25 @@ class ControlStructureTests extends RubyCode2CpgFixture {
forEachNode.controlStructureType shouldBe ControlStructureTypes.FOR

inside(forEachNode.astChildren.l) {
case (iteratorNode: Identifier) :: (iterableNode: Identifier) :: (doBody: Block) :: Nil =>
iteratorNode.code shouldBe "i"
iterableNode.code shouldBe "x"
// We use .ast as there will be an implicit return node here
doBody.ast.isCall.code.headOption shouldBe Option("puts x - i")
case _ => fail("No node for iterable found in `for-in` statement")
case (idxLocal: Local) :: (iVarLocal: Local) :: (initAssign: Call) :: (cond: Call) :: (update: Call) :: (forBlock: Block) :: Nil =>
idxLocal.name shouldBe "_idx_"
idxLocal.typeFullName shouldBe Defines.getBuiltInType(Defines.Integer)

iVarLocal.name shouldBe "i"

initAssign.code shouldBe "_idx_ = 0"
initAssign.name shouldBe Operators.assignment
initAssign.methodFullName shouldBe Operators.assignment

cond.code shouldBe "_idx_ < x.length"
cond.name shouldBe Operators.lessThan
cond.methodFullName shouldBe Operators.lessThan

update.code shouldBe "i = x[_idx_++]"
update.name shouldBe Operators.assignment
update.methodFullName shouldBe Operators.assignment

case xs => fail(s"Expected 6 children for `forEachNode`, got [${xs.code.mkString(",")}]")
}

inside(forEachNode.astChildren.isBlock.l) {
Expand All @@ -438,13 +451,25 @@ class ControlStructureTests extends RubyCode2CpgFixture {
forEachNode.controlStructureType shouldBe ControlStructureTypes.FOR

inside(forEachNode.astChildren.l) {
case (iteratorNode: Identifier) :: (iterableNode: Call) :: (doBody: Block) :: Nil =>
iteratorNode.code shouldBe "i"
iterableNode.code shouldBe "1..x"
iterableNode.name shouldBe Operators.range
// We use .ast as there will be an implicit return node here
doBody.ast.isCall.code.headOption shouldBe Option("puts x + i")
case _ => fail("Invalid `for-in` children nodes")
case (idxLocal: Local) :: (iVarLocal: Local) :: (initAssign: Call) :: (cond: Call) :: (update: Call) :: (forBlock: Block) :: Nil =>
idxLocal.name shouldBe "_idx_"
idxLocal.typeFullName shouldBe Defines.getBuiltInType(Defines.Integer)

iVarLocal.name shouldBe "i"

initAssign.code shouldBe "_idx_ = 0"
initAssign.name shouldBe Operators.assignment
initAssign.methodFullName shouldBe Operators.assignment

cond.code shouldBe "_idx_ < 1..x.length"
cond.name shouldBe Operators.lessThan
cond.methodFullName shouldBe Operators.lessThan

update.code shouldBe "i = 1..x[_idx_++]"
update.name shouldBe Operators.assignment
update.methodFullName shouldBe Operators.assignment

case xs => fail(s"Expected 6 children for `forEachNode`, got [${xs.code.mkString(",")}]")
}

case _ => fail("No control structure node found for `for-in`.")
Expand Down Expand Up @@ -650,4 +675,44 @@ class ControlStructureTests extends RubyCode2CpgFixture {
case xs => fail(s"Expected one IF structure, got [${xs.code.mkString(",")}]")
}
}

"ForEach loops" in {
val cpg = code("""
|fibNumbers = [0, 1, 1, 2, 3, 5, 8, 13]
|for num in fibNumbers
| puts num
|end
|""".stripMargin)

inside(cpg.method.isModule.controlStructure.l) {
case forEachNode :: Nil =>
forEachNode.controlStructureType shouldBe ControlStructureTypes.FOR

inside(forEachNode.astChildren.l) {
case (idxLocal: Local) :: (numLocal: Local) :: (initAssign: Call) :: (cond: Call) :: (update: Call) :: (forBlock: Block) :: Nil =>
idxLocal.name shouldBe "_idx_"
idxLocal.typeFullName shouldBe Defines.getBuiltInType(Defines.Integer)

numLocal.name shouldBe "num"

initAssign.code shouldBe "_idx_ = 0"
initAssign.name shouldBe Operators.assignment
initAssign.methodFullName shouldBe Operators.assignment

cond.code shouldBe "_idx_ < fibNumbers.length"
cond.name shouldBe Operators.lessThan
cond.methodFullName shouldBe Operators.lessThan

update.code shouldBe "num = fibNumbers[_idx_++]"
update.name shouldBe Operators.assignment
update.methodFullName shouldBe Operators.assignment

val List(putsCall) = cpg.call.nameExact("puts").l
putsCall.astParent shouldBe forBlock

case xs => fail(s"Expected 6 children for `forEachNode`, got [${xs.code.mkString(",")}]")
}
case xs => fail(s"Expected one node for `forEach` loop, got [${xs.code.mkString(",")}]")
}
}
}

0 comments on commit bc002cf

Please sign in to comment.