Skip to content

Commit

Permalink
Cil visitor (#220)
Browse files Browse the repository at this point in the history
ir-modifying CIL iterator implementation
  • Loading branch information
ailrst committed Jul 3, 2024
1 parent d2e1624 commit d7213df
Show file tree
Hide file tree
Showing 4 changed files with 437 additions and 0 deletions.
169 changes: 169 additions & 0 deletions src/main/scala/ir/cilvisitor/CILVisitor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package ir.cilvisitor

import ir.*
import scala.collection.mutable.ArrayBuffer

/** A new visitor based off CIL.
*
* Copied from ASLi https://github.com/UQ-PAC/aslp/blob/partial_eval/libASL/asl_visitor.ml copying from George Necula's
* CIL project (https://people.eecs.berkeley.edu/~necula/cil/)
*/

sealed trait VisitAction[T]
case class SkipChildren[T]() extends VisitAction[T]
case class DoChildren[T]() extends VisitAction[T]
case class ChangeTo[T](e: T) extends VisitAction[T]
// changes to e, then visits children of e, then applies f to the result
case class ChangeDoChildrenPost[T](e: T, f: T => T) extends VisitAction[T]


trait CILVisitor:
def vprog(e: Program): VisitAction[Program] = DoChildren()
def vproc(e: Procedure): VisitAction[List[Procedure]] = DoChildren()
def vparams(e: ArrayBuffer[Parameter]): VisitAction[ArrayBuffer[Parameter]] = DoChildren()
def vblock(e: Block): VisitAction[Block] = DoChildren()

def vstmt(e: Statement): VisitAction[List[Statement]] = DoChildren()
def vjump(j: Jump): VisitAction[Jump] = DoChildren()
def vfallthrough(j: Option[GoTo]): VisitAction[Option[GoTo]] = DoChildren()

def vexpr(e: Expr): VisitAction[Expr] = DoChildren()
def vvar(e: Variable): VisitAction[Variable] = DoChildren()
def vmem(e: Memory): VisitAction[Memory] = DoChildren()

def enter_scope(params: ArrayBuffer[Parameter]): Unit = ()
def leave_scope(outparam: ArrayBuffer[Parameter]): Unit = ()


def doVisitList[T](v: CILVisitor, a: VisitAction[List[T]], n: T, continue: (T) => T): List[T] = {
a match {
case SkipChildren() => List(n)
case ChangeTo(z) => z
case DoChildren() => List(continue(n))
case ChangeDoChildrenPost(x, f) => f(x.map(continue(_)))
}
}

def doVisit[T](v: CILVisitor, a: VisitAction[T], n: T, continue: (T) => T): T = {
a match {
case SkipChildren() => n
case DoChildren() => continue(n)
case ChangeTo(z) => z
case ChangeDoChildrenPost(x, f) => f(continue(x))
}
}

class CILVisitorImpl(val v: CILVisitor) {

def visit_parameters(p: ArrayBuffer[Parameter]): ArrayBuffer[Parameter] = {
doVisit(v, v.vparams(p), p, (n) => n)
}

def visit_var(n: Variable): Variable = {
doVisit(v, v.vvar(n), n, (n) => n)
}


def visit_mem(n: Memory): Memory = {
doVisit(v, v.vmem(n), n, (n) => n)
}


def visit_jump(j: Jump): Jump = {
doVisit(v, v.vjump(j), j, (j) => j)
}

def visit_fallthrough(j: Option[GoTo]): Option[GoTo] = {
doVisit(v, v.vfallthrough(j), j, (j) => j)
}

def visit_expr(n: Expr): Expr = {
def continue(n: Expr): Expr = n match {
case n: Literal => n
case MemoryLoad(mem, index, endian, size) => MemoryLoad(visit_mem(mem), visit_expr(index), endian, size)
case Extract(end, start, arg) => Extract(end, start, visit_expr(arg))
case Repeat(repeats, arg) => Repeat(repeats, visit_expr(arg))
case ZeroExtend(bits, arg) => ZeroExtend(bits, visit_expr(arg))
case SignExtend(bits, arg) => SignExtend(bits, visit_expr(arg))
case BinaryExpr(op, arg, arg2) => BinaryExpr(op, visit_expr(arg), visit_expr(arg2))
case UnaryExpr(op, arg) => UnaryExpr(op, visit_expr(arg))
case v: Variable => visit_var(v)
case UninterpretedFunction(n, params, rt) => UninterpretedFunction(n, params.map(visit_expr), rt)
}
doVisit(v, v.vexpr(n), n, continue)
}

def visit_stmt(s: Statement): List[Statement] = {
def continue(n: Statement) = n match {
case m: MemoryAssign => {
m.mem = visit_mem(m.mem)
m.index = visit_expr(m.index)
m.value = visit_expr(m.value)
m
}
case m: Assign => {
m.rhs = visit_expr(m.rhs)
m.lhs = visit_var(m.lhs)
m
}
case s: Assert => {
s.body = visit_expr(s.body)
s
}
case s: Assume => {
s.body = visit_expr(s.body)
s
}
case n: NOP => n
}
doVisitList(v, v.vstmt(s), s, continue)
}

def visit_block(b: Block): Block = {
def continue(b: Block) = {
b.statements.foreach(s => {
val r = visit_stmt(s)
r match {
case Nil => b.statements.remove(s)
case n :: tl =>
b.statements.replace(s, n)
b.statements.insertAllAfter(Some(n), tl)
}
})
b.replaceJump(visit_jump(b.jump))
b.fallthrough = visit_fallthrough(b.fallthrough)
b
}

doVisit(v, v.vblock(b), b, continue)
}

def visit_proc(p: Procedure): List[Procedure] = {
def continue(p: Procedure) = {
p.in = visit_parameters(p.in)
v.enter_scope(p.in)
for (b <- p.blocks) {
p.replaceBlock(b, visit_block(b))
}
p.out = visit_parameters(p.out)
v.leave_scope(p.out)
p
}

doVisitList(v, v.vproc(p), p, continue)
}

def visit_proc(p: Program): Program = {
def continue(p: Program) = {
p.procedures = p.procedures.flatMap(visit_proc)
p
}
doVisit(v, v.vprog(p), p, continue)
}
}

def visit_block(v: CILVisitor, b: Block): Block = CILVisitorImpl(v).visit_block(b)
def visit_proc(v: CILVisitor, b: Procedure): List[Procedure] = CILVisitorImpl(v).visit_proc(b)
def visit_stmt(v: CILVisitor, e: Statement): List[Statement] = CILVisitorImpl(v).visit_stmt(e)
def visit_jump(v: CILVisitor, e: Jump): Jump = CILVisitorImpl(v).visit_jump(e)
def visit_expr(v: CILVisitor, e: Expr): Expr = CILVisitorImpl(v).visit_expr(e)
41 changes: 41 additions & 0 deletions src/main/scala/util/intrusive_list/IntrusiveList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,47 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private (
intrusiveListElement.insertAfter(newElem)
}


/**
* Insert an element after another element in the list.
* @param intrusiveListElement The element in the list to insert after, or None to indicate the beginning.
* @param newElems The elements to insert. Must not be members of any other intrusive list(s).
* @return the last inserted element, or the reference element
*/
def insertAllAfter(intrusiveListElement: Option[T], newElems: Iterable[T]): Option[T] = {
intrusiveListElement match {
case None =>
newElems.toList.reverse.map(prepend).headOption.orElse(intrusiveListElement)
case Some(n) =>
var p = n
for (i <- newElems) {
p = insertAfter(p, i)
}
Some(p)
}
}


/**
* Insert an element before another element in the list.
* @param intrusiveListElement The element in the list to insert before, or None to indicate the end of the list.
* @param newElems The elements to insert. Must not be members of any other intrusive list(s).
* @return the last inserted element, or the reference element
*/
def insertAllBefore(intrusiveListElement: Option[T], newElems: Iterable[T]): Option[T] = {
intrusiveListElement match {
case None =>
newElems.map(append).lastOption.orElse(intrusiveListElement)
case Some(n) =>
var p = n
for (i <- newElems.toList.reverse) {
p = insertBefore(p, i)
}
Some(p)
}
}


/**
* Insert an element before another element in the list.
*
Expand Down
181 changes: 181 additions & 0 deletions src/test/scala/ir/CILVisitorTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package ir

import scala.collection.mutable
import scala.collection.immutable.*
import org.scalatest.funsuite.AnyFunSuite
import util.intrusive_list.*
import ir.dsl.*
import ir.cilvisitor.*

class FindVars extends CILVisitor {
val vars = mutable.ArrayBuffer[Variable]()

override def vvar(v: Variable) = {
vars.append(v)
SkipChildren()
}

def globals = vars.collect { case g: Global =>
g
}
}

def globals(e: Expr): List[Variable] = {
val v = FindVars()
visit_expr(v, e)
v.globals.toList
}

def gamma_v(l: Variable) = LocalVar("Gamma_" + l.name, BoolType)

def gamma_e(e: Expr): Expr = {
globals(e) match {
case Nil => TrueLiteral
case hd :: Nil => hd
case hd :: tl => tl.foldLeft(hd: Expr)((l, r) => BinaryExpr(BoolAND, l, gamma_v(r)))
}
}

class AddGammas extends CILVisitor {

override def vstmt(s: Statement) = {
s match {
case a: Assign => ChangeTo(List(a, Assign(gamma_v(a.lhs), gamma_e(a.rhs))))
case _ => SkipChildren()
}

}
}

class CILVisTest extends AnyFunSuite {

def getRegister(name: String) = Register(name, 64)
test("trace prog") {
val p = prog(
proc("main", block("lmain", goto("lmain1")), block("lmain1", goto("lmain2")), block("lmain2", ret))
)

class BlockTrace extends CILVisitor {
val res = mutable.ArrayBuffer[String]()

override def vblock(b: Block) = {
res.append(b.label)
DoChildren()
}

override def vjump(b: Jump) = {
b match {
case g: GoTo => res.addAll(g.targets.map(t => s"gt_${t.label}").toList)
case r: IndirectCall => res.append("indirect")
case r: DirectCall => res.append("direct")
}
DoChildren()
}
}

val v = BlockTrace()
visit_proc(v, p.procedures.head)
assert(v.res.toList == List("lmain", "gt_lmain1", "lmain1", "gt_lmain2", "lmain2", "indirect"))
}

test("visit exprs") {
val program: Program = prog(
proc(
"main",
block("0x0", Assign(getRegister("R6"), getRegister("R31")), goto("0x1")),
block(
"0x1",
MemoryAssign(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64),
goto("returntarget")
),
block("returntarget", ret)
)
)

class ExprTrace extends CILVisitor {
val res = mutable.ArrayBuffer[String]()

override def vvar(e: Variable) = {
e match {
case Register(n, _) => res.append(n);
case _ => ??? // only reg in source program
}
DoChildren()
}

override def vexpr(e: Expr) = {
e match {
case BinaryExpr(op, l, r) => res.append(op.toString)
case n: Literal => res.append(n.toString)
case _ => ()
}
DoChildren()
}
}

val v = ExprTrace()
visit_proc(v, program.procedures.head)
assert(v.res.toList == List("R31", "R6", "add", "R6", "4bv64", "10bv64"))
}

test("rewrite exprs") {

val program: Program = prog(
proc(
"main",
block("0x0", Assign(getRegister("R6"), getRegister("R31")), goto("0x1")),
block(
"0x1",
MemoryAssign(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64),
goto("returntarget")
),
block("returntarget", ret)
)
)
class VarTrace extends CILVisitor {
val res = mutable.ArrayBuffer[String]()

override def vvar(e: Variable) = { res.append(e.name); SkipChildren() }

}

class RegReplace extends CILVisitor {
val res = mutable.ArrayBuffer[String]()

override def vvar(e: Variable) = {
e match {
case Register(n, sz) => ChangeTo(LocalVar("l" + n, e.getType));
case _ => DoChildren()
}
}

}

class RegReplacePost extends CILVisitor {
val res = mutable.ArrayBuffer[String]()

override def vvar(e: Variable) = {
e match {
case LocalVar(n, _) =>
ChangeDoChildrenPost(LocalVar("e" + n, e.getType), e => { res.append(e.name); e });
case _ => DoChildren()
}
}

}

val v = VarTrace()
visit_proc(v, program.procedures.head)
assert(v.res.toList == List("R31", "R6", "R6"))
visit_proc(RegReplace(), program.procedures.head)
val v2 = VarTrace()
visit_proc(v2, program.procedures.head)
assert(v2.res.toList == List("lR31", "lR6", "lR6"))

val v3 = RegReplacePost()
visit_proc(v3, program.procedures.head)
assert(v3.res.toList == List("elR31", "elR6", "elR6"))

}

}
Loading

0 comments on commit d7213df

Please sign in to comment.