diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index e35bec9dbc..55148d0137 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -13,6 +13,7 @@ CompilerPanic, ImmutableViolation, OverflowException, + StackTooDeep, StateAccessViolation, StaticAssertionException, TypeMismatch, @@ -290,6 +291,7 @@ def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: assert c.test_array(2, 7, 1, 8) == -5454 +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_four_d_array_accessor(get_contract): four_d_array_accessor = """ @external diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index 26cd16ed32..21a40182f0 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -7,7 +7,7 @@ from tests.utils import check_precompile_asserts, decimal_to_int from vyper.compiler.settings import OptimizationLevel from vyper.evm.opcodes import version_check -from vyper.exceptions import ArrayIndexException, OverflowException, TypeMismatch +from vyper.exceptions import ArrayIndexException, OverflowException, StackTooDeep, TypeMismatch def _map_nested(f, xs): @@ -193,6 +193,7 @@ def test_array(x: int128, y: int128, z: int128, w: int128) -> int128: assert c.test_array(2, 7, 1, 8) == -5454 +@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_four_d_array_accessor(get_contract): four_d_array_accessor = """ @external diff --git a/vyper/venom/passes/load_elimination.py b/vyper/venom/passes/load_elimination.py index 6701b588fe..c5e4a20a1e 100644 --- a/vyper/venom/passes/load_elimination.py +++ b/vyper/venom/passes/load_elimination.py @@ -1,8 +1,21 @@ +from typing import Optional + from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis, VarEquivalenceAnalysis +from vyper.venom.basicblock import IRLiteral from vyper.venom.effects import Effects from vyper.venom.passes.base_pass import IRPass +def _conflict(store_opcode: str, k1: IRLiteral, k2: IRLiteral): + ptr1, ptr2 = k1.value, k2.value + # hardcode the size of store opcodes for now. maybe refactor to use + # vyper.evm.address_space + if store_opcode == "mstore": + return abs(ptr1 - ptr2) < 32 + assert store_opcode in ("sstore", "tstore"), "unhandled store opcode" + return abs(ptr1 - ptr2) < 1 + + class LoadElimination(IRPass): """ Eliminate sloads, mloads and tloads @@ -11,40 +24,76 @@ class LoadElimination(IRPass): # should this be renamed to EffectsElimination? def run_pass(self): - self.equivalence = self.analyses_cache.request_analysis(VarEquivalenceAnalysis) - for bb in self.function.get_basic_blocks(): self._process_bb(bb, Effects.MEMORY, "mload", "mstore") self._process_bb(bb, Effects.TRANSIENT, "tload", "tstore") self._process_bb(bb, Effects.STORAGE, "sload", "sstore") + self._process_bb(bb, None, "dload", None) self.analyses_cache.invalidate_analysis(LivenessAnalysis) self.analyses_cache.invalidate_analysis(DFGAnalysis) + self.analyses_cache.invalidate_analysis(VarEquivalenceAnalysis) def equivalent(self, op1, op2): - return op1 == op2 or self.equivalence.equivalent(op1, op2) + return op1 == op2 + + def get_literal(self, op): + if isinstance(op, IRLiteral): + return op + return None def _process_bb(self, bb, eff, load_opcode, store_opcode): # not really a lattice even though it is not really inter-basic block; # we may generalize in the future - lattice = () + self._lattice = {} for inst in bb.instructions: - if eff in inst.get_write_effects(): - lattice = () - if inst.opcode == store_opcode: # mstore [val, ptr] val, ptr = inst.operands - lattice = (ptr, val) - if inst.opcode == load_opcode: - prev_lattice = lattice + known_ptr: Optional[IRLiteral] = self.get_literal(ptr) + if known_ptr is None: + # flush the lattice + self._lattice = {ptr: val} + else: + # we found a redundant store, eliminate it + existing_val = self._lattice.get(known_ptr) + if self.equivalent(val, existing_val): + inst.opcode = "nop" + inst.output = None + inst.operands = [] + continue + + self._lattice[known_ptr] = val + + # kick out any conflicts + for existing_key in self._lattice.copy().keys(): + if not isinstance(existing_key, IRLiteral): + # flush the whole thing + self._lattice = {known_ptr: val} + break + + if _conflict(store_opcode, known_ptr, existing_key): + del self._lattice[existing_key] + self._lattice[known_ptr] = val + + elif eff is not None and eff in inst.get_write_effects(): + self._lattice = {} + + elif inst.opcode == load_opcode: (ptr,) = inst.operands - lattice = (ptr, inst.output) - if not prev_lattice: - continue - if not self.equivalent(ptr, prev_lattice[0]): - continue - inst.opcode = "store" - inst.operands = [prev_lattice[1]] + known_ptr = self.get_literal(ptr) + if known_ptr is not None: + ptr = known_ptr + + existing_value = self._lattice.get(ptr) + + assert inst.output is not None # help mypy + + # "cache" the value for future load instructions + self._lattice[ptr] = inst.output + + if existing_value is not None: + inst.opcode = "store" + inst.operands = [existing_value]