-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ir-modifying CIL iterator implementation
- Loading branch information
Showing
4 changed files
with
437 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
package ir.cilvisitor | ||
|
||
import ir.* | ||
import scala.collection.mutable.ArrayBuffer | ||
|
||
/** A new visitor based off CIL. | ||
* | ||
* Copied from ASLi https://github.com/UQ-PAC/aslp/blob/partial_eval/libASL/asl_visitor.ml copying from George Necula's | ||
* CIL project (https://people.eecs.berkeley.edu/~necula/cil/) | ||
*/ | ||
|
||
sealed trait VisitAction[T] | ||
case class SkipChildren[T]() extends VisitAction[T] | ||
case class DoChildren[T]() extends VisitAction[T] | ||
case class ChangeTo[T](e: T) extends VisitAction[T] | ||
// changes to e, then visits children of e, then applies f to the result | ||
case class ChangeDoChildrenPost[T](e: T, f: T => T) extends VisitAction[T] | ||
|
||
|
||
trait CILVisitor: | ||
def vprog(e: Program): VisitAction[Program] = DoChildren() | ||
def vproc(e: Procedure): VisitAction[List[Procedure]] = DoChildren() | ||
def vparams(e: ArrayBuffer[Parameter]): VisitAction[ArrayBuffer[Parameter]] = DoChildren() | ||
def vblock(e: Block): VisitAction[Block] = DoChildren() | ||
|
||
def vstmt(e: Statement): VisitAction[List[Statement]] = DoChildren() | ||
def vjump(j: Jump): VisitAction[Jump] = DoChildren() | ||
def vfallthrough(j: Option[GoTo]): VisitAction[Option[GoTo]] = DoChildren() | ||
|
||
def vexpr(e: Expr): VisitAction[Expr] = DoChildren() | ||
def vvar(e: Variable): VisitAction[Variable] = DoChildren() | ||
def vmem(e: Memory): VisitAction[Memory] = DoChildren() | ||
|
||
def enter_scope(params: ArrayBuffer[Parameter]): Unit = () | ||
def leave_scope(outparam: ArrayBuffer[Parameter]): Unit = () | ||
|
||
|
||
def doVisitList[T](v: CILVisitor, a: VisitAction[List[T]], n: T, continue: (T) => T): List[T] = { | ||
a match { | ||
case SkipChildren() => List(n) | ||
case ChangeTo(z) => z | ||
case DoChildren() => List(continue(n)) | ||
case ChangeDoChildrenPost(x, f) => f(x.map(continue(_))) | ||
} | ||
} | ||
|
||
def doVisit[T](v: CILVisitor, a: VisitAction[T], n: T, continue: (T) => T): T = { | ||
a match { | ||
case SkipChildren() => n | ||
case DoChildren() => continue(n) | ||
case ChangeTo(z) => z | ||
case ChangeDoChildrenPost(x, f) => f(continue(x)) | ||
} | ||
} | ||
|
||
class CILVisitorImpl(val v: CILVisitor) { | ||
|
||
def visit_parameters(p: ArrayBuffer[Parameter]): ArrayBuffer[Parameter] = { | ||
doVisit(v, v.vparams(p), p, (n) => n) | ||
} | ||
|
||
def visit_var(n: Variable): Variable = { | ||
doVisit(v, v.vvar(n), n, (n) => n) | ||
} | ||
|
||
|
||
def visit_mem(n: Memory): Memory = { | ||
doVisit(v, v.vmem(n), n, (n) => n) | ||
} | ||
|
||
|
||
def visit_jump(j: Jump): Jump = { | ||
doVisit(v, v.vjump(j), j, (j) => j) | ||
} | ||
|
||
def visit_fallthrough(j: Option[GoTo]): Option[GoTo] = { | ||
doVisit(v, v.vfallthrough(j), j, (j) => j) | ||
} | ||
|
||
def visit_expr(n: Expr): Expr = { | ||
def continue(n: Expr): Expr = n match { | ||
case n: Literal => n | ||
case MemoryLoad(mem, index, endian, size) => MemoryLoad(visit_mem(mem), visit_expr(index), endian, size) | ||
case Extract(end, start, arg) => Extract(end, start, visit_expr(arg)) | ||
case Repeat(repeats, arg) => Repeat(repeats, visit_expr(arg)) | ||
case ZeroExtend(bits, arg) => ZeroExtend(bits, visit_expr(arg)) | ||
case SignExtend(bits, arg) => SignExtend(bits, visit_expr(arg)) | ||
case BinaryExpr(op, arg, arg2) => BinaryExpr(op, visit_expr(arg), visit_expr(arg2)) | ||
case UnaryExpr(op, arg) => UnaryExpr(op, visit_expr(arg)) | ||
case v: Variable => visit_var(v) | ||
case UninterpretedFunction(n, params, rt) => UninterpretedFunction(n, params.map(visit_expr), rt) | ||
} | ||
doVisit(v, v.vexpr(n), n, continue) | ||
} | ||
|
||
def visit_stmt(s: Statement): List[Statement] = { | ||
def continue(n: Statement) = n match { | ||
case m: MemoryAssign => { | ||
m.mem = visit_mem(m.mem) | ||
m.index = visit_expr(m.index) | ||
m.value = visit_expr(m.value) | ||
m | ||
} | ||
case m: Assign => { | ||
m.rhs = visit_expr(m.rhs) | ||
m.lhs = visit_var(m.lhs) | ||
m | ||
} | ||
case s: Assert => { | ||
s.body = visit_expr(s.body) | ||
s | ||
} | ||
case s: Assume => { | ||
s.body = visit_expr(s.body) | ||
s | ||
} | ||
case n: NOP => n | ||
} | ||
doVisitList(v, v.vstmt(s), s, continue) | ||
} | ||
|
||
def visit_block(b: Block): Block = { | ||
def continue(b: Block) = { | ||
b.statements.foreach(s => { | ||
val r = visit_stmt(s) | ||
r match { | ||
case Nil => b.statements.remove(s) | ||
case n :: tl => | ||
b.statements.replace(s, n) | ||
b.statements.insertAllAfter(Some(n), tl) | ||
} | ||
}) | ||
b.replaceJump(visit_jump(b.jump)) | ||
b.fallthrough = visit_fallthrough(b.fallthrough) | ||
b | ||
} | ||
|
||
doVisit(v, v.vblock(b), b, continue) | ||
} | ||
|
||
def visit_proc(p: Procedure): List[Procedure] = { | ||
def continue(p: Procedure) = { | ||
p.in = visit_parameters(p.in) | ||
v.enter_scope(p.in) | ||
for (b <- p.blocks) { | ||
p.replaceBlock(b, visit_block(b)) | ||
} | ||
p.out = visit_parameters(p.out) | ||
v.leave_scope(p.out) | ||
p | ||
} | ||
|
||
doVisitList(v, v.vproc(p), p, continue) | ||
} | ||
|
||
def visit_proc(p: Program): Program = { | ||
def continue(p: Program) = { | ||
p.procedures = p.procedures.flatMap(visit_proc) | ||
p | ||
} | ||
doVisit(v, v.vprog(p), p, continue) | ||
} | ||
} | ||
|
||
def visit_block(v: CILVisitor, b: Block): Block = CILVisitorImpl(v).visit_block(b) | ||
def visit_proc(v: CILVisitor, b: Procedure): List[Procedure] = CILVisitorImpl(v).visit_proc(b) | ||
def visit_stmt(v: CILVisitor, e: Statement): List[Statement] = CILVisitorImpl(v).visit_stmt(e) | ||
def visit_jump(v: CILVisitor, e: Jump): Jump = CILVisitorImpl(v).visit_jump(e) | ||
def visit_expr(v: CILVisitor, e: Expr): Expr = CILVisitorImpl(v).visit_expr(e) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
package ir | ||
|
||
import scala.collection.mutable | ||
import scala.collection.immutable.* | ||
import org.scalatest.funsuite.AnyFunSuite | ||
import util.intrusive_list.* | ||
import ir.dsl.* | ||
import ir.cilvisitor.* | ||
|
||
class FindVars extends CILVisitor { | ||
val vars = mutable.ArrayBuffer[Variable]() | ||
|
||
override def vvar(v: Variable) = { | ||
vars.append(v) | ||
SkipChildren() | ||
} | ||
|
||
def globals = vars.collect { case g: Global => | ||
g | ||
} | ||
} | ||
|
||
def globals(e: Expr): List[Variable] = { | ||
val v = FindVars() | ||
visit_expr(v, e) | ||
v.globals.toList | ||
} | ||
|
||
def gamma_v(l: Variable) = LocalVar("Gamma_" + l.name, BoolType) | ||
|
||
def gamma_e(e: Expr): Expr = { | ||
globals(e) match { | ||
case Nil => TrueLiteral | ||
case hd :: Nil => hd | ||
case hd :: tl => tl.foldLeft(hd: Expr)((l, r) => BinaryExpr(BoolAND, l, gamma_v(r))) | ||
} | ||
} | ||
|
||
class AddGammas extends CILVisitor { | ||
|
||
override def vstmt(s: Statement) = { | ||
s match { | ||
case a: Assign => ChangeTo(List(a, Assign(gamma_v(a.lhs), gamma_e(a.rhs)))) | ||
case _ => SkipChildren() | ||
} | ||
|
||
} | ||
} | ||
|
||
class CILVisTest extends AnyFunSuite { | ||
|
||
def getRegister(name: String) = Register(name, 64) | ||
test("trace prog") { | ||
val p = prog( | ||
proc("main", block("lmain", goto("lmain1")), block("lmain1", goto("lmain2")), block("lmain2", ret)) | ||
) | ||
|
||
class BlockTrace extends CILVisitor { | ||
val res = mutable.ArrayBuffer[String]() | ||
|
||
override def vblock(b: Block) = { | ||
res.append(b.label) | ||
DoChildren() | ||
} | ||
|
||
override def vjump(b: Jump) = { | ||
b match { | ||
case g: GoTo => res.addAll(g.targets.map(t => s"gt_${t.label}").toList) | ||
case r: IndirectCall => res.append("indirect") | ||
case r: DirectCall => res.append("direct") | ||
} | ||
DoChildren() | ||
} | ||
} | ||
|
||
val v = BlockTrace() | ||
visit_proc(v, p.procedures.head) | ||
assert(v.res.toList == List("lmain", "gt_lmain1", "lmain1", "gt_lmain2", "lmain2", "indirect")) | ||
} | ||
|
||
test("visit exprs") { | ||
val program: Program = prog( | ||
proc( | ||
"main", | ||
block("0x0", Assign(getRegister("R6"), getRegister("R31")), goto("0x1")), | ||
block( | ||
"0x1", | ||
MemoryAssign(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64), | ||
goto("returntarget") | ||
), | ||
block("returntarget", ret) | ||
) | ||
) | ||
|
||
class ExprTrace extends CILVisitor { | ||
val res = mutable.ArrayBuffer[String]() | ||
|
||
override def vvar(e: Variable) = { | ||
e match { | ||
case Register(n, _) => res.append(n); | ||
case _ => ??? // only reg in source program | ||
} | ||
DoChildren() | ||
} | ||
|
||
override def vexpr(e: Expr) = { | ||
e match { | ||
case BinaryExpr(op, l, r) => res.append(op.toString) | ||
case n: Literal => res.append(n.toString) | ||
case _ => () | ||
} | ||
DoChildren() | ||
} | ||
} | ||
|
||
val v = ExprTrace() | ||
visit_proc(v, program.procedures.head) | ||
assert(v.res.toList == List("R31", "R6", "add", "R6", "4bv64", "10bv64")) | ||
} | ||
|
||
test("rewrite exprs") { | ||
|
||
val program: Program = prog( | ||
proc( | ||
"main", | ||
block("0x0", Assign(getRegister("R6"), getRegister("R31")), goto("0x1")), | ||
block( | ||
"0x1", | ||
MemoryAssign(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64), | ||
goto("returntarget") | ||
), | ||
block("returntarget", ret) | ||
) | ||
) | ||
class VarTrace extends CILVisitor { | ||
val res = mutable.ArrayBuffer[String]() | ||
|
||
override def vvar(e: Variable) = { res.append(e.name); SkipChildren() } | ||
|
||
} | ||
|
||
class RegReplace extends CILVisitor { | ||
val res = mutable.ArrayBuffer[String]() | ||
|
||
override def vvar(e: Variable) = { | ||
e match { | ||
case Register(n, sz) => ChangeTo(LocalVar("l" + n, e.getType)); | ||
case _ => DoChildren() | ||
} | ||
} | ||
|
||
} | ||
|
||
class RegReplacePost extends CILVisitor { | ||
val res = mutable.ArrayBuffer[String]() | ||
|
||
override def vvar(e: Variable) = { | ||
e match { | ||
case LocalVar(n, _) => | ||
ChangeDoChildrenPost(LocalVar("e" + n, e.getType), e => { res.append(e.name); e }); | ||
case _ => DoChildren() | ||
} | ||
} | ||
|
||
} | ||
|
||
val v = VarTrace() | ||
visit_proc(v, program.procedures.head) | ||
assert(v.res.toList == List("R31", "R6", "R6")) | ||
visit_proc(RegReplace(), program.procedures.head) | ||
val v2 = VarTrace() | ||
visit_proc(v2, program.procedures.head) | ||
assert(v2.res.toList == List("lR31", "lR6", "lR6")) | ||
|
||
val v3 = RegReplacePost() | ||
visit_proc(v3, program.procedures.head) | ||
assert(v3.res.toList == List("elR31", "elR6", "elR6")) | ||
|
||
} | ||
|
||
} |
Oops, something went wrong.