Skip to content

Commit

Permalink
Implement Statement.matches and Expression.matches. (#236)
Browse files Browse the repository at this point in the history
* Implement Statement.matches and Expression.matches.

* Remove VirtualVariable.__eq__.

* Lint code.
  • Loading branch information
ltfish authored Sep 13, 2024
1 parent 52cf227 commit f15007b
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 13 deletions.
145 changes: 135 additions & 10 deletions ailment/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
claripy = None

from .tagged_object import TaggedObject
from .utils import get_bits, stable_hash, is_none_or_likeable
from .utils import get_bits, stable_hash, is_none_or_likeable, is_none_or_matchable

if TYPE_CHECKING:
from .statement import Statement
Expand Down Expand Up @@ -44,6 +44,9 @@ def __eq__(self, other):
def likes(self, atom): # pylint:disable=unused-argument,no-self-use
raise NotImplementedError()

def matches(self, atom): # pylint:disable=unused-argument,no-self-use
return NotImplementedError()

def replace(self, old_expr, new_expr):
if self is old_expr:
r = True
Expand Down Expand Up @@ -115,6 +118,7 @@ def likes(self, other):
and self.bits == other.bits
)

matches = likes
__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -153,6 +157,7 @@ def __str__(self):
def likes(self, other):
return type(self) is type(other) and self.tmp_idx == other.tmp_idx and self.bits == other.bits

matches = likes
__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -181,6 +186,8 @@ def size(self):
def likes(self, atom):
return type(self) is type(atom) and self.reg_offset == atom.reg_offset and self.bits == atom.bits

matches = likes

def __repr__(self):
return str(self)

Expand Down Expand Up @@ -263,6 +270,15 @@ def stack_offset(self) -> int | None:
return None

def likes(self, atom):
return (
isinstance(atom, VirtualVariable)
and self.varid == atom.varid
and self.bits == atom.bits
and self.category == atom.category
and self.oident == atom.oident
)

def matches(self, atom):
return (
isinstance(atom, VirtualVariable)
and self.bits == atom.bits
Expand Down Expand Up @@ -328,6 +344,15 @@ def verbose_op(self) -> str:
return "Phi"

def likes(self, atom) -> bool:
if isinstance(atom, Phi) and self.bits == atom.bits:
self_src_and_vvarids = {(src, vvar.varid if vvar is not None else None) for src, vvar in self.src_and_vvars}
other_src_and_vvarids = {
(src, vvar.varid if vvar is not None else None) for src, vvar in atom.src_and_vvars
}
return self_src_and_vvarids == other_src_and_vvarids
return False

def matches(self, atom) -> bool:
if isinstance(atom, Phi) and self.bits == atom.bits:
if len(self.src_and_vvars) != len(atom.src_and_vvars):
return False
Expand All @@ -342,7 +367,7 @@ def likes(self, atom) -> bool:
and other_vvar is not None
or self_vvar is not None
and other_vvar is None
or not self_vvar.likes(other_vvar)
or not self_vvar.matches(other_vvar)
):
return False
return True
Expand Down Expand Up @@ -424,7 +449,18 @@ def __repr__(self):

def likes(self, other):
return (
type(other) is UnaryOp and self.op == other.op and self.bits == other.bits and self.operand == other.operand
type(other) is UnaryOp
and self.op == other.op
and self.bits == other.bits
and self.operand.likes(other.operand)
)

def matches(self, atom):
return (
type(atom) is UnaryOp
and self.op == atom.op
and self.bits == atom.bits
and self.operand.matches(atom.operand)
)

__hash__ = TaggedObject.__hash__
Expand Down Expand Up @@ -518,6 +554,19 @@ def likes(self, other):
and self.rounding_mode == other.rounding_mode
)

def matches(self, other):
return (
type(other) is Convert
and self.from_bits == other.from_bits
and self.to_bits == other.to_bits
and self.bits == other.bits
and self.is_signed == other.is_signed
and self.operand.matches(other.operand)
and self.from_type == other.from_type
and self.to_type == other.to_type
and self.rounding_mode == other.rounding_mode
)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -613,7 +662,17 @@ def likes(self, other):
and self.from_type == other.from_type
and self.to_bits == other.to_bits
and self.to_type == other.to_type
and self.operand == other.operand
and self.operand.likes(other.operand)
)

def matches(self, other):
return (
type(other) is Reinterpret
and self.from_bits == other.from_bits
and self.from_type == other.from_type
and self.to_bits == other.to_bits
and self.to_type == other.to_type
and self.operand.matches(other.operand)
)

__hash__ = TaggedObject.__hash__
Expand Down Expand Up @@ -800,6 +859,17 @@ def likes(self, other):
and self.rounding_mode == other.rounding_mode
)

def matches(self, other):
return (
type(other) is BinaryOp
and self.op == other.op
and self.bits == other.bits
and self.signed == other.signed
and is_none_or_matchable(self.operands, other.operands, is_list=True)
and self.floating_point == other.floating_point
and self.rounding_mode == other.rounding_mode
)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -936,7 +1006,15 @@ def likes(self, other):
type(other) is TernaryOp
and self.op == other.op
and self.bits == other.bits
and self.operands == other.operands
and is_none_or_likeable(self.operands, other.operands, is_list=True)
)

def matches(self, other):
return (
type(other) is TernaryOp
and self.op == other.op
and self.bits == other.bits
and is_none_or_matchable(self.operands, other.operands, is_list=True)
)

__hash__ = TaggedObject.__hash__
Expand Down Expand Up @@ -1061,7 +1139,6 @@ def replace(self, old_expr, new_expr):
def _likes_addr(self, other_addr):
if hasattr(self.addr, "likes") and hasattr(other_addr, "likes"):
return self.addr.likes(other_addr)

return self.addr == other_addr

def likes(self, other):
Expand All @@ -1074,6 +1151,21 @@ def likes(self, other):
and self.alt == other.alt
)

def _matches_addr(self, other_addr):
if hasattr(self.addr, "matches") and hasattr(other_addr, "matches"):
return self.addr.matches(other_addr)
return self.addr == other_addr

def matches(self, other):
return (
type(other) is Load
and self._matches_addr(other.addr)
and self.size == other.size
and self.endness == other.endness
and self.guard == other.guard
and self.alt == other.alt
)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -1130,9 +1222,18 @@ def __str__(self):
def likes(self, atom):
return (
type(atom) is ITE
and self.cond == atom.cond
and self.iffalse == atom.iffalse
and self.iftrue == atom.iftrue
and self.cond.likes(atom.cond)
and self.iffalse.likes(atom.iffalse)
and self.iftrue.likes(atom.iftrue)
and self.bits == atom.bits
)

def matches(self, atom):
return (
type(atom) is ITE
and self.cond.matches(atom.cond)
and self.iffalse.matches(atom.iffalse)
and self.iftrue.matches(atom.iftrue)
and self.bits == atom.bits
)

Expand Down Expand Up @@ -1199,6 +1300,7 @@ def __init__(self, idx, dirty_expr, bits=None, **kwargs):
def likes(self, other):
return type(other) is DirtyExpression and other.dirty_expr == self.dirty_expr

matches = likes
__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -1254,6 +1356,15 @@ def likes(self, other):
and all(op1.likes(op2) for op1, op2 in zip(other.operands, self.operands))
)

def matches(self, other):
return (
type(other) is VEXCCallExpression
and other.cee_name == self.cee_name
and len(self.operands) == len(other.operands)
and self.bits == other.bits
and all(op1.matches(op2) for op1, op2 in zip(other.operands, self.operands))
)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -1315,7 +1426,20 @@ def _hash_core(self):
return stable_hash((MultiStatementExpression,) + tuple(self.stmts) + (self.expr,))

def likes(self, other):
return type(self) is type(other) and self.stmts == other.stmts and self.expr == other.expr
return (
type(self) is type(other)
and len(self.stmts) == len(other.stmts)
and all(s_stmt.likes(o_stmt) for s_stmt, o_stmt in zip(self.stmts, other.stmts))
and self.expr.likes(other.expr)
)

def matches(self, atom):
return (
type(self) is type(atom)
and len(self.stmts) == len(atom.stmts)
and all(s_stmt.matches(o_stmt) for s_stmt, o_stmt in zip(self.stmts, atom.stmts))
and self.expr.matches(atom.expr)
)

def __repr__(self):
return f"MultiStatementExpression({self.stmts}, {self.expr})"
Expand Down Expand Up @@ -1408,6 +1532,7 @@ def likes(self, other):
and self.offset == other.offset
)

matches = likes
__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down
48 changes: 47 additions & 1 deletion ailment/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
except ImportError:
claripy = None

from .utils import stable_hash, is_none_or_likeable
from .utils import stable_hash, is_none_or_likeable, is_none_or_matchable
from .tagged_object import TaggedObject
from .expression import Expression

Expand Down Expand Up @@ -35,6 +35,12 @@ def eq(self, expr0, expr1): # pylint:disable=no-self-use
return expr0 is expr1
return expr0 == expr1

def likes(self, atom): # pylint:disable=unused-argument,no-self-use
raise NotImplementedError()

def matches(self, atom): # pylint:disable=unused-argument,no-self-use
return NotImplementedError()


class Assignment(Statement):
"""
Expand All @@ -58,6 +64,9 @@ def __eq__(self, other):
def likes(self, other):
return type(other) is Assignment and self.dst.likes(other.dst) and self.src.likes(other.src)

def matches(self, other):
return type(other) is Assignment and self.dst.matches(other.dst) and self.src.matches(other.src)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -138,6 +147,16 @@ def likes(self, other):
and self.endness == other.endness
)

def matches(self, other):
return (
type(other) is Store
and self.addr.matches(other.addr)
and self.data.matches(other.data)
and self.size == other.size
and self.guard == other.guard
and self.endness == other.endness
)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -236,6 +255,9 @@ def __eq__(self, other):
def likes(self, other):
return type(other) is Jump and is_none_or_likeable(self.target, other.target)

def matches(self, other):
return type(other) is Jump and is_none_or_matchable(self.target, other.target)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -321,6 +343,14 @@ def likes(self, other):
and is_none_or_likeable(self.false_target, other.false_target)
)

def matches(self, other):
return (
type(other) is ConditionalJump
and self.condition.matches(other.condition)
and is_none_or_matchable(self.true_target, other.true_target)
and is_none_or_matchable(self.false_target, other.false_target)
)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -458,6 +488,17 @@ def likes(self, other):
and is_none_or_likeable(self.fp_ret_expr, other.fp_ret_expr)
)

def matches(self, other):
return (
type(other) is Call
and is_none_or_matchable(self.target, other.target)
and self.calling_convention == other.calling_convention
and self.prototype == other.prototype
and is_none_or_matchable(self.args, other.args, is_list=True)
and is_none_or_matchable(self.ret_expr, other.ret_expr)
and is_none_or_matchable(self.fp_ret_expr, other.fp_ret_expr)
)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -593,6 +634,9 @@ def __eq__(self, other):
def likes(self, other):
return type(other) is Return and is_none_or_likeable(self.ret_exprs, other.ret_exprs, is_list=True)

def matches(self, other):
return type(other) is Return and is_none_or_matchable(self.ret_exprs, other.ret_exprs, is_list=True)

__hash__ = TaggedObject.__hash__

def _hash_core(self):
Expand Down Expand Up @@ -685,6 +729,8 @@ def __init__(self, idx, name: str, ins_addr: int, block_idx: int | None = None,
def likes(self, other: "Label"):
return isinstance(other, Label)

matches = likes

def _hash_core(self):
return stable_hash(
(
Expand Down
Loading

0 comments on commit f15007b

Please sign in to comment.