diff --git a/vyper/venom/passes/load_elimination.py b/vyper/venom/passes/load_elimination.py index 6701b588fe..09e498e536 100644 --- a/vyper/venom/passes/load_elimination.py +++ b/vyper/venom/passes/load_elimination.py @@ -1,7 +1,16 @@ from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis, VarEquivalenceAnalysis from vyper.venom.effects import Effects from vyper.venom.passes.base_pass import IRPass +from vyper.venom.basicblock import IRLiteral +def _conflict(store_opcode: str, k1: IRLiteral, k2: IRLiteral): + ptr1, ptr2 = k1.value, k2.value + # hardcode the size of store opcodes for now. maybe refactor to use + # vyper.evm.address_space + if store_opcode == "mstore": + return abs(ptr1 - ptr2) < 32 + assert store_opcode in ("sstore", "tstore"), "unhandled store opcode" + return abs(ptr1 - ptr2) < 1 class LoadElimination(IRPass): """ @@ -20,31 +29,69 @@ def run_pass(self): self.analyses_cache.invalidate_analysis(LivenessAnalysis) self.analyses_cache.invalidate_analysis(DFGAnalysis) + self.analyses_cache.invalidate_analysis(VarEquivalenceAnalysis) def equivalent(self, op1, op2): return op1 == op2 or self.equivalence.equivalent(op1, op2) + def get_literal(self, op): + if isinstance(op, IRLiteral): + return op + return self.equivalence.get_literal(op) + def _process_bb(self, bb, eff, load_opcode, store_opcode): # not really a lattice even though it is not really inter-basic block; # we may generalize in the future - lattice = () + self._lattice = {} for inst in bb.instructions: - if eff in inst.get_write_effects(): - lattice = () if inst.opcode == store_opcode: # mstore [val, ptr] val, ptr = inst.operands - lattice = (ptr, val) - if inst.opcode == load_opcode: - prev_lattice = lattice + known_ptr: Optional[IRLiteral] = self.get_literal(ptr) + if known_ptr is None: + # flush the lattice + self._lattice = {ptr: val} + else: + # we found a redundant store, eliminate it + existing_val = self._lattice.get(known_ptr) + if self.equivalent(val, existing_val): + inst.opcode = "nop" + inst.output = None + inst.operands = [] + continue + + self._lattice[known_ptr] = val + + # kick out any conflicts + for existing_key in self._lattice.copy().keys(): + if not isinstance(existing_key, IRLiteral): + # flush the whole thing + self._lattice = {known_ptr: val} + break + + if _conflict(store_opcode, known_ptr, existing_key): + del self._lattice[existing_key] + self._lattice[known_ptr] = val + + elif eff in inst.get_write_effects(): + self._lattice = {} + + elif inst.opcode == load_opcode: (ptr,) = inst.operands - lattice = (ptr, inst.output) - if not prev_lattice: - continue - if not self.equivalent(ptr, prev_lattice[0]): - continue - inst.opcode = "store" - inst.operands = [prev_lattice[1]] + known_ptr = self.get_literal(ptr) + if known_ptr is not None: + ptr = known_ptr + + existing_value = self._lattice.get(ptr) + + assert inst.output is not None # help mypy + + # "cache" the value for future load instructions + self._lattice[ptr] = inst.output + + if existing_value is not None: + inst.opcode = "store" + inst.operands = [existing_value]