From d7213dfeb136c35659584995cc6a7147b7f79a71 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 3 Jul 2024 11:01:54 +1000 Subject: [PATCH] Cil visitor (#220) ir-modifying CIL iterator implementation --- src/main/scala/ir/cilvisitor/CILVisitor.scala | 169 ++++++++++++++++ .../util/intrusive_list/IntrusiveList.scala | 41 ++++ src/test/scala/ir/CILVisitorTest.scala | 181 ++++++++++++++++++ .../IntrusiveListPublicInterfaceTest.scala | 46 +++++ 4 files changed, 437 insertions(+) create mode 100644 src/main/scala/ir/cilvisitor/CILVisitor.scala create mode 100644 src/test/scala/ir/CILVisitorTest.scala diff --git a/src/main/scala/ir/cilvisitor/CILVisitor.scala b/src/main/scala/ir/cilvisitor/CILVisitor.scala new file mode 100644 index 000000000..a405d4fa1 --- /dev/null +++ b/src/main/scala/ir/cilvisitor/CILVisitor.scala @@ -0,0 +1,169 @@ +package ir.cilvisitor + +import ir.* +import scala.collection.mutable.ArrayBuffer + +/** A new visitor based off CIL. + * + * Copied from ASLi https://github.com/UQ-PAC/aslp/blob/partial_eval/libASL/asl_visitor.ml copying from George Necula's + * CIL project (https://people.eecs.berkeley.edu/~necula/cil/) + */ + +sealed trait VisitAction[T] +case class SkipChildren[T]() extends VisitAction[T] +case class DoChildren[T]() extends VisitAction[T] +case class ChangeTo[T](e: T) extends VisitAction[T] +// changes to e, then visits children of e, then applies f to the result +case class ChangeDoChildrenPost[T](e: T, f: T => T) extends VisitAction[T] + + +trait CILVisitor: + def vprog(e: Program): VisitAction[Program] = DoChildren() + def vproc(e: Procedure): VisitAction[List[Procedure]] = DoChildren() + def vparams(e: ArrayBuffer[Parameter]): VisitAction[ArrayBuffer[Parameter]] = DoChildren() + def vblock(e: Block): VisitAction[Block] = DoChildren() + + def vstmt(e: Statement): VisitAction[List[Statement]] = DoChildren() + def vjump(j: Jump): VisitAction[Jump] = DoChildren() + def vfallthrough(j: Option[GoTo]): VisitAction[Option[GoTo]] = DoChildren() + + def vexpr(e: Expr): VisitAction[Expr] = DoChildren() + def vvar(e: Variable): VisitAction[Variable] = DoChildren() + def vmem(e: Memory): VisitAction[Memory] = DoChildren() + + def enter_scope(params: ArrayBuffer[Parameter]): Unit = () + def leave_scope(outparam: ArrayBuffer[Parameter]): Unit = () + + +def doVisitList[T](v: CILVisitor, a: VisitAction[List[T]], n: T, continue: (T) => T): List[T] = { + a match { + case SkipChildren() => List(n) + case ChangeTo(z) => z + case DoChildren() => List(continue(n)) + case ChangeDoChildrenPost(x, f) => f(x.map(continue(_))) + } +} + +def doVisit[T](v: CILVisitor, a: VisitAction[T], n: T, continue: (T) => T): T = { + a match { + case SkipChildren() => n + case DoChildren() => continue(n) + case ChangeTo(z) => z + case ChangeDoChildrenPost(x, f) => f(continue(x)) + } +} + +class CILVisitorImpl(val v: CILVisitor) { + + def visit_parameters(p: ArrayBuffer[Parameter]): ArrayBuffer[Parameter] = { + doVisit(v, v.vparams(p), p, (n) => n) + } + + def visit_var(n: Variable): Variable = { + doVisit(v, v.vvar(n), n, (n) => n) + } + + + def visit_mem(n: Memory): Memory = { + doVisit(v, v.vmem(n), n, (n) => n) + } + + + def visit_jump(j: Jump): Jump = { + doVisit(v, v.vjump(j), j, (j) => j) + } + + def visit_fallthrough(j: Option[GoTo]): Option[GoTo] = { + doVisit(v, v.vfallthrough(j), j, (j) => j) + } + + def visit_expr(n: Expr): Expr = { + def continue(n: Expr): Expr = n match { + case n: Literal => n + case MemoryLoad(mem, index, endian, size) => MemoryLoad(visit_mem(mem), visit_expr(index), endian, size) + case Extract(end, start, arg) => Extract(end, start, visit_expr(arg)) + case Repeat(repeats, arg) => Repeat(repeats, visit_expr(arg)) + case ZeroExtend(bits, arg) => ZeroExtend(bits, visit_expr(arg)) + case SignExtend(bits, arg) => SignExtend(bits, visit_expr(arg)) + case BinaryExpr(op, arg, arg2) => BinaryExpr(op, visit_expr(arg), visit_expr(arg2)) + case UnaryExpr(op, arg) => UnaryExpr(op, visit_expr(arg)) + case v: Variable => visit_var(v) + case UninterpretedFunction(n, params, rt) => UninterpretedFunction(n, params.map(visit_expr), rt) + } + doVisit(v, v.vexpr(n), n, continue) + } + + def visit_stmt(s: Statement): List[Statement] = { + def continue(n: Statement) = n match { + case m: MemoryAssign => { + m.mem = visit_mem(m.mem) + m.index = visit_expr(m.index) + m.value = visit_expr(m.value) + m + } + case m: Assign => { + m.rhs = visit_expr(m.rhs) + m.lhs = visit_var(m.lhs) + m + } + case s: Assert => { + s.body = visit_expr(s.body) + s + } + case s: Assume => { + s.body = visit_expr(s.body) + s + } + case n: NOP => n + } + doVisitList(v, v.vstmt(s), s, continue) + } + + def visit_block(b: Block): Block = { + def continue(b: Block) = { + b.statements.foreach(s => { + val r = visit_stmt(s) + r match { + case Nil => b.statements.remove(s) + case n :: tl => + b.statements.replace(s, n) + b.statements.insertAllAfter(Some(n), tl) + } + }) + b.replaceJump(visit_jump(b.jump)) + b.fallthrough = visit_fallthrough(b.fallthrough) + b + } + + doVisit(v, v.vblock(b), b, continue) + } + + def visit_proc(p: Procedure): List[Procedure] = { + def continue(p: Procedure) = { + p.in = visit_parameters(p.in) + v.enter_scope(p.in) + for (b <- p.blocks) { + p.replaceBlock(b, visit_block(b)) + } + p.out = visit_parameters(p.out) + v.leave_scope(p.out) + p + } + + doVisitList(v, v.vproc(p), p, continue) + } + + def visit_proc(p: Program): Program = { + def continue(p: Program) = { + p.procedures = p.procedures.flatMap(visit_proc) + p + } + doVisit(v, v.vprog(p), p, continue) + } +} + +def visit_block(v: CILVisitor, b: Block): Block = CILVisitorImpl(v).visit_block(b) +def visit_proc(v: CILVisitor, b: Procedure): List[Procedure] = CILVisitorImpl(v).visit_proc(b) +def visit_stmt(v: CILVisitor, e: Statement): List[Statement] = CILVisitorImpl(v).visit_stmt(e) +def visit_jump(v: CILVisitor, e: Jump): Jump = CILVisitorImpl(v).visit_jump(e) +def visit_expr(v: CILVisitor, e: Expr): Expr = CILVisitorImpl(v).visit_expr(e) diff --git a/src/main/scala/util/intrusive_list/IntrusiveList.scala b/src/main/scala/util/intrusive_list/IntrusiveList.scala index a6c5f0b93..9ba9e14e5 100644 --- a/src/main/scala/util/intrusive_list/IntrusiveList.scala +++ b/src/main/scala/util/intrusive_list/IntrusiveList.scala @@ -298,6 +298,47 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private ( intrusiveListElement.insertAfter(newElem) } + + /** + * Insert an element after another element in the list. + * @param intrusiveListElement The element in the list to insert after, or None to indicate the beginning. + * @param newElems The elements to insert. Must not be members of any other intrusive list(s). + * @return the last inserted element, or the reference element + */ + def insertAllAfter(intrusiveListElement: Option[T], newElems: Iterable[T]): Option[T] = { + intrusiveListElement match { + case None => + newElems.toList.reverse.map(prepend).headOption.orElse(intrusiveListElement) + case Some(n) => + var p = n + for (i <- newElems) { + p = insertAfter(p, i) + } + Some(p) + } + } + + + /** + * Insert an element before another element in the list. + * @param intrusiveListElement The element in the list to insert before, or None to indicate the end of the list. + * @param newElems The elements to insert. Must not be members of any other intrusive list(s). + * @return the last inserted element, or the reference element + */ + def insertAllBefore(intrusiveListElement: Option[T], newElems: Iterable[T]): Option[T] = { + intrusiveListElement match { + case None => + newElems.map(append).lastOption.orElse(intrusiveListElement) + case Some(n) => + var p = n + for (i <- newElems.toList.reverse) { + p = insertBefore(p, i) + } + Some(p) + } + } + + /** * Insert an element before another element in the list. * diff --git a/src/test/scala/ir/CILVisitorTest.scala b/src/test/scala/ir/CILVisitorTest.scala new file mode 100644 index 000000000..9574f65c9 --- /dev/null +++ b/src/test/scala/ir/CILVisitorTest.scala @@ -0,0 +1,181 @@ +package ir + +import scala.collection.mutable +import scala.collection.immutable.* +import org.scalatest.funsuite.AnyFunSuite +import util.intrusive_list.* +import ir.dsl.* +import ir.cilvisitor.* + +class FindVars extends CILVisitor { + val vars = mutable.ArrayBuffer[Variable]() + + override def vvar(v: Variable) = { + vars.append(v) + SkipChildren() + } + + def globals = vars.collect { case g: Global => + g + } +} + +def globals(e: Expr): List[Variable] = { + val v = FindVars() + visit_expr(v, e) + v.globals.toList +} + +def gamma_v(l: Variable) = LocalVar("Gamma_" + l.name, BoolType) + +def gamma_e(e: Expr): Expr = { + globals(e) match { + case Nil => TrueLiteral + case hd :: Nil => hd + case hd :: tl => tl.foldLeft(hd: Expr)((l, r) => BinaryExpr(BoolAND, l, gamma_v(r))) + } +} + +class AddGammas extends CILVisitor { + + override def vstmt(s: Statement) = { + s match { + case a: Assign => ChangeTo(List(a, Assign(gamma_v(a.lhs), gamma_e(a.rhs)))) + case _ => SkipChildren() + } + + } +} + +class CILVisTest extends AnyFunSuite { + + def getRegister(name: String) = Register(name, 64) + test("trace prog") { + val p = prog( + proc("main", block("lmain", goto("lmain1")), block("lmain1", goto("lmain2")), block("lmain2", ret)) + ) + + class BlockTrace extends CILVisitor { + val res = mutable.ArrayBuffer[String]() + + override def vblock(b: Block) = { + res.append(b.label) + DoChildren() + } + + override def vjump(b: Jump) = { + b match { + case g: GoTo => res.addAll(g.targets.map(t => s"gt_${t.label}").toList) + case r: IndirectCall => res.append("indirect") + case r: DirectCall => res.append("direct") + } + DoChildren() + } + } + + val v = BlockTrace() + visit_proc(v, p.procedures.head) + assert(v.res.toList == List("lmain", "gt_lmain1", "lmain1", "gt_lmain2", "lmain2", "indirect")) + } + + test("visit exprs") { + val program: Program = prog( + proc( + "main", + block("0x0", Assign(getRegister("R6"), getRegister("R31")), goto("0x1")), + block( + "0x1", + MemoryAssign(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64), + goto("returntarget") + ), + block("returntarget", ret) + ) + ) + + class ExprTrace extends CILVisitor { + val res = mutable.ArrayBuffer[String]() + + override def vvar(e: Variable) = { + e match { + case Register(n, _) => res.append(n); + case _ => ??? // only reg in source program + } + DoChildren() + } + + override def vexpr(e: Expr) = { + e match { + case BinaryExpr(op, l, r) => res.append(op.toString) + case n: Literal => res.append(n.toString) + case _ => () + } + DoChildren() + } + } + + val v = ExprTrace() + visit_proc(v, program.procedures.head) + assert(v.res.toList == List("R31", "R6", "add", "R6", "4bv64", "10bv64")) + } + + test("rewrite exprs") { + + val program: Program = prog( + proc( + "main", + block("0x0", Assign(getRegister("R6"), getRegister("R31")), goto("0x1")), + block( + "0x1", + MemoryAssign(mem, BinaryExpr(BVADD, getRegister("R6"), bv64(4)), bv64(10), Endian.LittleEndian, 64), + goto("returntarget") + ), + block("returntarget", ret) + ) + ) + class VarTrace extends CILVisitor { + val res = mutable.ArrayBuffer[String]() + + override def vvar(e: Variable) = { res.append(e.name); SkipChildren() } + + } + + class RegReplace extends CILVisitor { + val res = mutable.ArrayBuffer[String]() + + override def vvar(e: Variable) = { + e match { + case Register(n, sz) => ChangeTo(LocalVar("l" + n, e.getType)); + case _ => DoChildren() + } + } + + } + + class RegReplacePost extends CILVisitor { + val res = mutable.ArrayBuffer[String]() + + override def vvar(e: Variable) = { + e match { + case LocalVar(n, _) => + ChangeDoChildrenPost(LocalVar("e" + n, e.getType), e => { res.append(e.name); e }); + case _ => DoChildren() + } + } + + } + + val v = VarTrace() + visit_proc(v, program.procedures.head) + assert(v.res.toList == List("R31", "R6", "R6")) + visit_proc(RegReplace(), program.procedures.head) + val v2 = VarTrace() + visit_proc(v2, program.procedures.head) + assert(v2.res.toList == List("lR31", "lR6", "lR6")) + + val v3 = RegReplacePost() + visit_proc(v3, program.procedures.head) + assert(v3.res.toList == List("elR31", "elR6", "elR6")) + + } + +} diff --git a/src/test/scala/util/intrusive_list/IntrusiveListPublicInterfaceTest.scala b/src/test/scala/util/intrusive_list/IntrusiveListPublicInterfaceTest.scala index f9f89867f..7e13908f6 100644 --- a/src/test/scala/util/intrusive_list/IntrusiveListPublicInterfaceTest.scala +++ b/src/test/scala/util/intrusive_list/IntrusiveListPublicInterfaceTest.scala @@ -194,4 +194,50 @@ class IntrusiveListPublicInterfaceTest extends AnyFunSuite { } + test("insertAllAfter") { + val x = IntrusiveList[Elem]() + + x.append(Elem(9)) + val first = Elem(10) + val f = x.append(first) + x.append(Elem(13)) + // 9 10 13 + assert(x.toList.map(_.t) == List(9, 10, 13)) + + val n = Elem(225) + val toInsert = List(Elem(11), Elem(12), n) + + val r = x.insertAllAfter(Some(first), toInsert) + assert(r.get eq n) + assert(x.toList.map(_.t) == List(9, 10, 11, 12, 225, 13)) + + val l = Range(1, 4).map(x => Elem(x)) + val rr = x.insertAllAfter(None, l) + assert(x.toList.map(_.t) == List(1, 2, 3, 9, 10, 11, 12, 225, 13)) + + } + + + test("insertAllBefore") { + val x = IntrusiveList[Elem]() + + x.append(Elem(9)) + val first = Elem(10) + val f = x.append(first) + x.append(Elem(13)) + // 9 10 13 + assert(x.toList.map(_.t) == List(9, 10, 13)) + + val n = Elem(11) + val toInsert = List(n, Elem(12), Elem(255)) + + val r = x.insertAllBefore(Some(first), toInsert) + assert(r.get eq n) + assert(x.toList.map(_.t) == List(9,11, 12, 255, 10, 13)) + + val l = Range(1, 4).map(x => Elem(x)) + val rr = x.insertAllBefore(None, l) + assert(x.toList.map(_.t) == List(9, 11, 12, 255,10, 13, 1, 2, 3)) + } + }