Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InSpectre gadget v0.1 #17

Merged
merged 10 commits into from
Jan 29, 2024
Merged
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion analyzer/analysis/rangeAnalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_ast_ranges(constraints, ast : claripy.BV):

# We calculate the min and max once
s = claripy.Solver(timeout=global_config["Z3Timeout"])
s.constraints = constraints
s.constraints = constraints.copy()
ast_min = s.min(ast)
ast_max = s.max(ast)

Expand Down
34 changes: 25 additions & 9 deletions analyzer/analysis/range_strategies/find_constraints_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
# autopep8: off
from ...shared.config import *
from ...shared.utils import *
from ...shared.logger import *
# autopep8: on

l = get_logger("FindConstraintsBounds")

class RangeStrategyFindConstraintsBounds(RangeStrategy):
infer_isolated_strategy : RangeStrategyInferIsolated
Expand Down Expand Up @@ -48,30 +50,45 @@ def find_range(self, constraints, ast : claripy.ast.bv.BVS,
if ast_min == ast_max:
return self.infer_isolated_strategy.find_range([], ast, ast_min, ast_max)

# print(f"min:{ast_min} max:{ast_max}")

try:
sat_ranges = _find_sat_distribution(constraints, ast, ast_min, ast_max, stride=1)
except claripy.ClaripyZ3Error as e:
# timeout
return None

# --------- One non-satisfiable range

if sat_ranges != None and len(sat_ranges) == 1:
# It is a full range, we can treat it as isolated
return self.infer_isolated_strategy.find_range([], ast, sat_ranges[0][0], sat_ranges[0][1])
# We have one non-satisfiable range, so we can try to treat it as
# isolated

r = self.infer_isolated_strategy.find_range([], ast, sat_ranges[0][0], sat_ranges[0][1])

if r != None and sat_ranges[0][0] > sat_ranges[0][1]:
# The range wraps around, we have to be sure that the AST range is
# a simple strided range, otherwise we get two separate disjoint ranges
# which we cannot describe in our range (e.g., [ast != 0xf, ast <= 0xffff])
if r.and_mask != None or r.or_mask != None or \
ast_max != ((1 << ast.size()) - 1 - (r.stride - 1)) or ast_min != 0:

# We have a complex range thus fail (e.g., masking is performed)
return None

return r

# --------- Signed range
if ast_min == 0 and ast_max == (2**ast.size()) - 1:
s = claripy.Solver(timeout=global_config["Z3Timeout"])
new_min = (1 << (ast.size() - 1))
s.constraints = constraints + [(ast > new_min)]
s.constraints = constraints + [(ast >= new_min)]
upper_ast_min = s.min(ast)
upper_ast_max = s.max(ast)
s = claripy.Solver(timeout=global_config["Z3Timeout"])
s.constraints = constraints + [ast <= new_min]
s.constraints = constraints + [ast < new_min]
lower_ast_min = s.min(ast)
lower_ast_max = s.max(ast)

# print(f" new_min:{hex(new_min)} upper_min: {hex(upper_ast_min)} upper_max: {hex(upper_ast_max)} lower_min: {hex(lower_ast_min)} lower_max: {hex(lower_ast_max)}")

if lower_ast_min == 0 and upper_ast_max == (2**ast.size()) - 1:
# treat this as a single range that wraps around ( min > max )
Expand All @@ -87,7 +104,7 @@ def find_range(self, constraints, ast : claripy.ast.bv.BVS,
return self.infer_isolated_strategy.find_range([], ast, upper_ast_min, lower_ast_max)

# --------- Can't solve this
print(f"Cant' solve range: {ast} ({constraints})")
l.warning(f"Cant' solve range: {ast} ({constraints})")

# TODO: If there is only one SAT range, we may still be able to treat
# it as isolated and adjust the min and max.
Expand Down Expand Up @@ -122,8 +139,7 @@ def _find_sat_distribution(constraints, ast, start, end, stride = 1):
return [(start, end)]

# Range with a "hole"
# print(f"Range: {[(value+1, value-1)]}")
return [value+1, value-1]
return [(value+1, value-1)]

# TODO: Double check the validity of the code below
return None
Expand Down
2 changes: 1 addition & 1 deletion analyzer/analysis/range_strategies/find_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def find_range(self, constraints, ast : claripy.ast.bv.BVS,

entropy, and_mask, or_mask = _find_entropy(s, ast, ast_max)

return range_complex(ast_min, ast_max, False, entropy, and_mask, or_mask)
return range_complex(ast_min, ast_max, ast.size(), False, entropy, and_mask, or_mask)

def _find_entropy(s : claripy.Solver, ast : claripy.BV, ast_max : int):

Expand Down
68 changes: 55 additions & 13 deletions analyzer/analysis/range_strategies/infer_isolated.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def find_range(self, constraints, ast : claripy.ast.bv.BVS,
ast_min : int = None, ast_max : int = None):

# We only support isolated ranges

if constraints:
return None

Expand All @@ -34,34 +33,73 @@ def find_range(self, constraints, ast : claripy.ast.bv.BVS,
ast_max = s.max(ast)

if ast.depth == 1:
return range_simple(ast_min, ast_max, 1, True)
return range_simple(ast_min, ast_max, ast.size(), 1, True)

range_map = get_range_map_from_ast(ast)

if range_map.unknown:
return None

# We try an extra optimization: separating the concrete value
# from addition. This covers cases like:
# 0xffffffff81000000 + <BV32 AST >
# get_range_map_from_ast() cannot handle 'overflows', so we
# split the AST and add the concrete value manually to the range.
if ast.op == '__add__' and any(not arg.symbolic for arg in ast.args):

concrete_value = next(arg for arg in ast.args if not arg.symbolic).args[0]
sub_ast = sum([arg for arg in ast.args if arg.symbolic])

range_map = get_range_map_from_ast(sub_ast)

if range_map.unknown:
return None

range_map = range_map.switch_to_stride_mode()

if range_map.unknown:
return None

s = claripy.Solver(timeout=global_config["Z3Timeout"])
sub_ast_min = s.min(sub_ast)
sub_ast_max = s.max(sub_ast)

isolated_ast_min = sub_ast_min + concrete_value
isolated_ast_max = sub_ast_max + concrete_value

# handle overflows
isolated_ast_min &= (1 << ast.size()) - 1
isolated_ast_max &= (1 << ast.size()) - 1

if isolated_ast_min - isolated_ast_max == range_map.stride:
isolated_ast_min = s.min(ast)
isolated_ast_max = s.max(ast)

# incorporate non-isolated min and max, only adjust if they are
# tighter
# Note: Conditions hold for both normal and disjoint ranges
ast_min = ast_min if isolated_ast_min < ast_min else isolated_ast_min
ast_max = ast_max if isolated_ast_max > ast_max else isolated_ast_max

else:
return None


if range_map.stride_mode:
return range_simple(ast_min, ast_max, range_map.stride, isolated=True)
return range_simple(ast_min, ast_max, ast.size(), range_map.stride, isolated=True)

else:
return range_complex(ast_min, ast_max, True, None, range_map.and_mask, range_map.or_mask, True)
return range_complex(ast_min, ast_max, ast.size(), True, None, range_map.and_mask, range_map.or_mask, True)


def is_linear_mask(and_mask, or_mask):

mask = and_mask & ~or_mask

# print(bin(mask))

highest_bit = mask.bit_length()
lowest_bit = (mask & -mask).bit_length() - 1

stride_mask = (2 ** highest_bit - 1) & ~(2 ** lowest_bit - 1)

# print(bin(stride_mask))

return mask == stride_mask

@dataclass
Expand Down Expand Up @@ -374,8 +412,10 @@ def op_add(ast, range_maps):

for idx, map in enumerate(range_maps):
if not map:
assert(not concrete_ast)
concrete_ast = ast.args[idx]
if concrete_ast != None:
concrete_ast += ast.args[idx]
else:
concrete_ast = ast.args[idx]

elif map.is_full_range(ast.length):
return map
Expand Down Expand Up @@ -415,8 +455,10 @@ def op_mul(ast, range_maps):

for idx, map in enumerate(range_maps):
if not map:
assert(not concrete_ast)
concrete_ast = ast.args[idx]
if concrete_ast != None:
concrete_ast = concrete_ast * ast.args[idx]
else:
concrete_ast = ast.args[idx]

elif map.is_full_range(ast.length):
return map
Expand Down
2 changes: 1 addition & 1 deletion analyzer/analysis/range_strategies/small_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _list_to_stride_range(numbers : list):
if numbers[i] + stride != numbers[i+1]:
return None

return AstRange(min=numbers[0] , max=numbers[-1],
return AstRange(min=numbers[0] , max=numbers[-1], ast_size=0,
exact=True, isolated=True,
intervals=[Interval(numbers[0], numbers[-1], stride)])

Expand Down
5 changes: 5 additions & 0 deletions analyzer/analysis/tfpAnalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def analyse(t: TaintedFunctionPointer):
if s[0] == t.reg:
new_t.expr = s[1].expr

s = claripy.Solver(timeout=global_config["Z3Timeout"])
if not s.satisfiable(extra_constraints=[x[1] for x in new_t.constraints]):
# Skipping.. this combination of constraints is not satisfiable
continue

tfps.append(new_t)

# Analyse tfps
Expand Down
18 changes: 11 additions & 7 deletions analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def analyse_gadget(proj, gadget_address, name, config, csv_filename, tfp_csv_fil
"""

# Step 1. Analyze the code snippet with angr.
l.info(f"Analyzing gadget at address {hex(gadget_address)} {name} ...")
l.info(f"Analyzing gadget at address {hex(gadget_address)}...")
s = Scanner()
s.run(proj, gadget_address, config)

Expand Down Expand Up @@ -211,6 +211,9 @@ def analyse_gadget(proj, gadget_address, name, config, csv_filename, tfp_csv_fil
all_tfps = []
for c in s.calls:
all_tfps.extend(tfpAnalysis.analyse(c))

if all_tfps: l.info(f"Extracted {len(all_tfps)} tfps.")

for tfp in all_tfps:
tfp.uuid = str(uuid.uuid4())
tfp.name = name
Expand Down Expand Up @@ -248,19 +251,20 @@ def run(binary, config_file, base_address, gadgets, cache_project, csv_filename=
# Simplify how symbols get printed.
claripy.ast.base._unique_names = False

# Prepare angr project.
l.info("Loading angr project...")
config = load_config(config_file)
proj = load_angr_project(binary, base_address, cache_project)

l.info("Removing non-writable memory...")
remove_memory_sections(proj)

if global_config["LogLevel"] == 0:
disable_logging()
elif global_config["LogLevel"] == 1:
disable_logging(keep_main=True)

# Prepare angr project.
l.info("Loading angr project...")
proj = load_angr_project(binary, base_address, cache_project)

l.info("Removing non-writable memory...")
remove_memory_sections(proj)

# Run the Analyzer.
# TODO: Parallelize.
for g in gadgets:
Expand Down
Loading