Skip to content

Commit

Permalink
#18 Preprocessing methods to compute auxiliary data structures
Browse files Browse the repository at this point in the history
  • Loading branch information
alexcere committed Nov 25, 2024
1 parent 543b3fe commit af82523
Showing 1 changed file with 192 additions and 0 deletions.
192 changes: 192 additions & 0 deletions src/greedy/greedy_new_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,30 @@
from global_params.types import var_id_T, instr_id_T, instr_JSON_T


def _simplify_graph_to_selected_nodes(graph: nx.DiGraph, selected_nodes: List) -> nx.DiGraph:
    """
    Auxiliary method that returns the transitive reduction of the graph induced by
    restricting the reachability relation of ``graph`` to ``selected_nodes``: an edge
    u -> v is introduced iff v is reachable from u in the original graph.

    Note: assumes the restriction is acyclic (as required by nx.transitive_reduction);
    if u and v reach each other, only the first direction encountered is kept.
    """
    subgraph = nx.DiGraph()
    subgraph.add_nodes_from(selected_nodes)

    selected_set = set(selected_nodes)
    for u in selected_nodes:
        # Hoisted out of the pair loop: one traversal per source node instead of one
        # per (u, v) pair. nx.descendants returns every node reachable from u
        # (excluding u itself), so self-loops cannot arise here.
        reachable = nx.descendants(graph, u)
        for v in reachable & selected_set:
            # Skip relations already considered in the opposite direction
            if not subgraph.has_edge(v, u):
                subgraph.add_edge(u, v)

    # Apply transitive reduction on the dynamically created subgraph, keeping only
    # the minimal edge set that preserves reachability among the selected nodes
    return nx.transitive_reduction(subgraph)


def idx_wrt_cstack(idx: int, cstack: List, fstack: List) -> int:
"""
Given a position w.r.t fstack, returns the corresponding position w.r.t cstack
Expand Down Expand Up @@ -127,3 +151,171 @@ def top_stack(self) -> Optional[var_id_T]:

    def __repr__(self):
        # Debug representation: delegate to the underlying stack contents
        return str(self.stack)


class SMSgreedy:
    """
    Preprocessing stage for the greedy stack-layout algorithm.

    Takes a problem description in JSON format (user instructions, source/target
    stacks, variables and memory/storage dependencies) and builds the auxiliary
    structures the greedy algorithm relies on: var/instruction lookup maps, use
    counts, direct and indirect dependency graphs (restricted to the relevant
    operations and transitively reduced), the set of values each relevant
    operation needs, and which topmost stack elements can be reused.
    """

    def __init__(self, json_format, debug_mode: bool = False):
        # Raw problem description, as produced upstream
        self._user_instr: List[instr_JSON_T] = json_format['user_instrs']
        self._initial_stack: List[var_id_T] = json_format['src_ws']
        self._final_stack: List[var_id_T] = json_format['tgt_ws']
        self._vars: List[var_id_T] = json_format['vars']
        self._deps: List[Tuple[var_id_T, var_id_T]] = json_format['dependencies']
        self.debug_mode = debug_mode

        # Note: we assume function invocations might have several variables in 'outpt_sk',
        # so a single instruction can appear as producer of multiple stack variables
        self._var2instr = {var: ins for ins in self._user_instr for var in ins['outpt_sk']}
        self._id2instr = {ins['id']: ins for ins in self._user_instr}
        self._var2id = {var: ins['id'] for ins in self._user_instr for var in ins['outpt_sk']}
        self._var2pos_stack = self._compute_var2pos(self._final_stack)

        self._var_total_uses = self._compute_var_total_uses()
        direct_g, indirect_g = self._compute_dependency_graph()

        # Restrict both dependency graphs to the relevant operations, preserving
        # reachability, and keep only their transitive reductions
        self._relevant_ops = self.select_ops(direct_g)
        self._indirect_g = _simplify_graph_to_selected_nodes(indirect_g, self._relevant_ops)
        self._direct_g = _simplify_graph_to_selected_nodes(direct_g, self._relevant_ops)

        # We determine which elements must be computed in order to compute certain instruction
        self._values_used = {}
        for instr_id in self._relevant_ops:
            self._compute_values_used(self._id2instr[instr_id], self._relevant_ops, self._values_used)

        # Determine which topmost elements can be reused in the graph
        self._top_can_be_used = {}
        for instr in self._user_instr:
            self._compute_top_can_used(instr, self._top_can_be_used)

        # We need to compute the sub graph over the full dependency graph, as edges could be lost if we use the
        # transitive reduction instead. Hence, we need to compute the transitive_closure of the graph
        self._trans_sub_graph = nx.transitive_reduction(nx.DiGraph([*self._direct_g.edges, *self._indirect_g.edges]))
        # Ensure isolated relevant operations still appear as nodes
        for node in self._relevant_ops:
            self._trans_sub_graph.add_node(node)

    def _compute_var_total_uses(self) -> Dict[var_id_T, int]:
        """
        Computes how many times each var appears either in the final stack or as a subterm
        for other terms.
        """
        var_uses = defaultdict(lambda: 0)

        # Count vars in the final stack (a var may occupy several positions)
        for var_stack in self._final_stack:
            var_uses[var_stack] += 1

        # Count vars as input of other instrs
        for instr_id, instr in self._id2instr.items():
            for subterm_var in instr['inpt_sk']:
                var_uses[subterm_var] += 1

        return var_uses

    def _compute_var2pos(self, var_list: List[var_id_T]) -> Dict[var_id_T, List[int]]:
        """
        Dict that links each stack variable that appears in a var list to the
        list of positions it occupies
        """
        var2pos = defaultdict(lambda: [])

        for i, stack_var in enumerate(var_list):
            var2pos[stack_var].append(i)

        return var2pos

    def _compute_dependency_graph(self) -> Tuple[nx.DiGraph, nx.DiGraph]:
        """
        We generate two dependency graphs: one for direct relations (i.e. one term embedded into another)
        and other with the dependencies due to memory/storage accesses
        """
        direct_graph = nx.DiGraph()
        indirect_graph = nx.DiGraph()

        for instr in self._user_instr:
            instr_id = instr['id']
            direct_graph.add_node(instr_id)
            indirect_graph.add_node(instr_id)

            for stack_elem in instr['inpt_sk']:
                # This means the stack element corresponds to another uninterpreted instruction:
                # edge goes from producer to consumer
                if stack_elem in self._var2instr:
                    direct_graph.add_edge(self._var2id[stack_elem], instr_id)

        # We need to consider also the order given by the tuples (memory/storage accesses)
        for id1, id2 in self._deps:
            indirect_graph.add_edge(id1, id2)

        return direct_graph, indirect_graph

    def select_ops(self, direct_g: nx.DiGraph) -> List[instr_id_T]:
        """
        Selects which operations are considered in the algorithm. We consider mem operations (excluding loads with no
        dependencies) and computations that are not subterms
        """
        # Ids that participate in at least one memory/storage dependency
        dep_ids = set(elem for dep in self._deps for elem in dep)

        # Relevant operations correspond to memory operations (STORE in all cases, LOADs and KECCAKs if they have
        # some kind of dependency) and operations that are not used elsewhere. The idea here is that we want
        # to consider the maximal elements to compute, as reusing computations is easier this way
        relevant_operations = [instr["id"] for instr in self._user_instr if
                               any(instr_name in instr["disasm"] for instr_name in ["STORE"])
                               or (any(load_instr in instr["disasm"] for load_instr in ["LOAD", "KECCAK"])
                                   and instr["id"] in dep_ids)
                               or direct_g.out_degree(instr["id"]) == 0]
        return relevant_operations

    def _compute_top_can_used(self, instr: instr_JSON_T, top_can_be_used: Dict[var_id_T, Set[var_id_T]]) -> Set[var_id_T]:
        """
        Computes for each instruction if the topmost element of the stack can be reused directly
        at some point. It considers commutative operations

        Results are memoized in ``top_can_be_used`` (keyed by instruction id).
        """
        # Memoized result: return it directly if this instruction was already processed
        reused_elements = top_can_be_used.get(instr["id"], None)
        if reused_elements is not None:
            return reused_elements

        current_uses = set()
        comm = instr["commutative"]
        first_element = True
        # Inputs are traversed from the top of the stack downwards
        for stack_var in reversed(instr["inpt_sk"]):
            # We only consider the first element if the operation is not commutative, or both elements otherwise
            if comm or first_element:
                instr_bef = self._var2instr.get(stack_var, None)
                if instr_bef is not None:
                    instr_bef_id = instr_bef["id"]
                    if instr_bef_id not in top_can_be_used:
                        # Recurse into the producer of this stack element (memoized)
                        current_uses.update(self._compute_top_can_used(instr_bef, top_can_be_used))
                    else:
                        current_uses.update(top_can_be_used[instr_bef_id])
                # Add only instructions that are relevant to our context
                # NOTE(review): nesting reconstructed — presumably the element itself is
                # always reusable here, whether computed or an initial stack var; confirm
                current_uses.add(stack_var)
            else:
                break
            first_element = False

        top_can_be_used[instr["id"]] = current_uses
        return current_uses

    def _compute_values_used(self, instr: instr_JSON_T, relevant_ops: List[instr_id_T],
                             value_uses: Dict[var_id_T, Set[var_id_T]]) -> Set[var_id_T]:
        """
        For a given instruction, determines which stack elements must be computed

        Results are memoized in ``value_uses`` (keyed by instruction id).
        """
        # Memoized result: return it directly if this instruction was already processed
        values_used = value_uses.get(instr["id"], None)
        if values_used is not None:
            return values_used

        current_uses = set()
        for stack_var in instr["inpt_sk"]:
            instr_bef = self._var2instr.get(stack_var, None)
            if instr_bef is not None:
                instr_bef_id = instr_bef["id"]
                if instr_bef_id not in value_uses:
                    # Recurse into the producer of this stack element (memoized)
                    current_uses.update(self._compute_values_used(instr_bef, relevant_ops, value_uses))
                else:
                    current_uses.update(value_uses[instr_bef_id])
                # Add only instructions that are relevant to our context
                if instr_bef_id in relevant_ops:
                    current_uses.add(stack_var)
            else:
                # No producer: the element comes from the initial stack and is always needed
                current_uses.add(stack_var)
        value_uses[instr["id"]] = current_uses
        return current_uses

0 comments on commit af82523

Please sign in to comment.