diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 0035bd53188b..9466bc2cea79 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -58,6 +58,7 @@ from mypyc.options import CompilerOptions from mypyc.transform.copy_propagation import do_copy_propagation from mypyc.transform.exceptions import insert_exception_handling +from mypyc.transform.flag_elimination import do_flag_elimination from mypyc.transform.refcount import insert_ref_count_opcodes from mypyc.transform.uninit import insert_uninit_checks @@ -234,8 +235,9 @@ def compile_scc_to_ir( insert_exception_handling(fn) # Insert refcount handling. insert_ref_count_opcodes(fn) - # Perform copy propagation optimization. + # Perform optimizations. do_copy_propagation(fn, compiler_options) + do_flag_elimination(fn, compiler_options) return modules diff --git a/mypyc/test-data/opt-flag-elimination.test b/mypyc/test-data/opt-flag-elimination.test new file mode 100644 index 000000000000..f047a87dc3fa --- /dev/null +++ b/mypyc/test-data/opt-flag-elimination.test @@ -0,0 +1,300 @@ +-- Test cases for "flag elimination" optimization. Used to optimize away +-- registers that are always used immediately after assignment as branch conditions. + +[case testFlagEliminationSimple] +def c() -> bool: + return True +def d() -> bool: + return True + +def f(x: bool) -> int: + if x: + b = c() + else: + b = d() + if b: + return 1 + else: + return 2 +[out] +def c(): +L0: + return 1 +def d(): +L0: + return 1 +def f(x): + x, r0, r1 :: bool +L0: + if x goto L1 else goto L2 :: bool +L1: + r0 = c() + if r0 goto L4 else goto L5 :: bool +L2: + r1 = d() + if r1 goto L4 else goto L5 :: bool +L3: + unreachable +L4: + return 2 +L5: + return 4 + +[case testFlagEliminationOneAssignment] +def c() -> bool: + return True + +def f(x: bool) -> int: + # Not applied here + b = c() + if b: + return 1 + else: + return 2 +[out] +def c(): +L0: + return 1 +def f(x): + x, r0, b :: bool +L0: + r0 = c() + b = r0 + if b goto L1 else goto L2 :: bool +L1: + return 2 +L2: + return 4 + +[case testFlagEliminationThreeCases] +def c(x: int) -> bool: + return True + +def f(x: bool, y: bool) -> int: + if x: + b = c(1) + elif y: + b = c(2) + else: + b = c(3) + if b: + return 1 + else: + return 2 +[out] +def c(x): + x :: int +L0: + return 1 +def f(x, y): + x, y, r0, r1, r2 :: bool +L0: + if x goto L1 else goto L2 :: bool +L1: + r0 = c(2) + if r0 goto L6 else goto L7 :: bool +L2: + if y goto L3 else goto L4 :: bool +L3: + r1 = c(4) + if r1 goto L6 else goto L7 :: bool +L4: + r2 = c(6) + if r2 goto L6 else goto L7 :: bool +L5: + unreachable +L6: + return 2 +L7: + return 4 + +[case testFlagEliminationAssignmentNotLastOp] +def f(x: bool) -> int: + y = 0 + if x: + b = True + y = 1 + else: + b = False + if b: + return 1 + else: + return 2 +[out] +def f(x): + x :: bool + y :: int + b :: bool +L0: + y = 0 + if x goto L1 else goto L2 :: bool +L1: + b = 1 + y = 2 + goto L3 +L2: + b = 0 +L3: + if b goto L4 else goto L5 :: bool +L4: + return 2 +L5: + return 4 + +[case testFlagEliminationAssignmentNoDirectGoto] +def f(x: bool) -> int: + if x: + b = True + else: + b = False + if x: + if b: + return 1 + else: + return 2 + return 4 +[out] +def f(x): + x, b :: bool +L0: + if x goto L1 else goto L2 :: bool +L1: + b = 1 + goto L3 +L2: + b = 0 +L3: + if x goto L4 else goto L7 :: bool +L4: + if b goto L5 else goto L6 :: bool +L5: + return 2 +L6: + return 4 +L7: + return 8 + +[case testFlagEliminationBranchNotNextOpAfterGoto] +def f(x: bool) -> int: + if x: + b = True + else: + b = False + y = 1 # Prevents the optimization + if b: + return 1 + else: + return 2 +[out] +def f(x): + x, b :: bool + y :: int +L0: + if x goto L1 else goto L2 :: bool +L1: + b = 1 + goto L3 +L2: + b = 0 +L3: + y = 2 + if b goto L4 else goto L5 :: bool +L4: + return 2 +L5: + return 4 + +[case testFlagEliminationFlagReadTwice] +def f(x: bool) -> bool: + if x: + b = True + else: + b = False + if b: + return b # Prevents the optimization + else: + return False +[out] +def f(x): + x, b :: bool +L0: + if x goto L1 else goto L2 :: bool +L1: + b = 1 + goto L3 +L2: + b = 0 +L3: + if b goto L4 else goto L5 :: bool +L4: + return b +L5: + return 0 + +[case testFlagEliminationArgumentNotEligible] +def f(x: bool, b: bool) -> bool: + if x: + b = True + else: + b = False + if b: + return True + else: + return False +[out] +def f(x, b): + x, b :: bool +L0: + if x goto L1 else goto L2 :: bool +L1: + b = 1 + goto L3 +L2: + b = 0 +L3: + if b goto L4 else goto L5 :: bool +L4: + return 1 +L5: + return 0 + +[case testFlagEliminationFlagNotAlwaysDefined] +def f(x: bool, y: bool) -> bool: + if x: + b = True + elif y: + b = False + else: + bb = False # b not assigned here -> can't optimize + if b: + return True + else: + return False +[out] +def f(x, y): + x, y, r0, b, bb, r1 :: bool +L0: + r0 = :: bool + b = r0 + if x goto L1 else goto L2 :: bool +L1: + b = 1 + goto L5 +L2: + if y goto L3 else goto L4 :: bool +L3: + b = 0 + goto L5 +L4: + bb = 0 +L5: + if is_error(b) goto L6 else goto L7 +L6: + r1 = raise UnboundLocalError('local variable "b" referenced before assignment') + unreachable +L7: + if b goto L8 else goto L9 :: bool +L8: + return 1 +L9: + return 0 diff --git a/mypyc/test/test_copy_propagation.py b/mypyc/test/test_optimizations.py similarity index 62% rename from mypyc/test/test_copy_propagation.py rename to mypyc/test/test_optimizations.py index c729e3d186c3..3f1f46ac1dd7 100644 --- a/mypyc/test/test_copy_propagation.py +++ b/mypyc/test/test_optimizations.py @@ -1,4 +1,4 @@ -"""Runner for copy propagation optimization tests.""" +"""Runner for IR optimization tests.""" from __future__ import annotations @@ -8,6 +8,7 @@ from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase from mypyc.common import TOP_LEVEL_NAME +from mypyc.ir.func_ir import FuncIR from mypyc.ir.pprint import format_func from mypyc.options import CompilerOptions from mypyc.test.testutil import ( @@ -19,13 +20,16 @@ use_custom_builtins, ) from mypyc.transform.copy_propagation import do_copy_propagation +from mypyc.transform.flag_elimination import do_flag_elimination from mypyc.transform.uninit import insert_uninit_checks -files = ["opt-copy-propagation.test"] +class OptimizationSuite(MypycDataSuite): + """Base class for IR optimization test suites. + + To use this, add a base class and define "files" and "do_optimizations". + """ -class TestCopyPropagation(MypycDataSuite): - files = files base_path = test_temp_dir def run_case(self, testcase: DataDrivenTestCase) -> None: @@ -41,7 +45,24 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"): continue insert_uninit_checks(fn) - do_copy_propagation(fn, CompilerOptions()) + self.do_optimizations(fn) actual.extend(format_func(fn)) assert_test_output(testcase, actual, "Invalid source code output", expected_output) + + def do_optimizations(self, fn: FuncIR) -> None: + raise NotImplementedError + + +class TestCopyPropagation(OptimizationSuite): + files = ["opt-copy-propagation.test"] + + def do_optimizations(self, fn: FuncIR) -> None: + do_copy_propagation(fn, CompilerOptions()) + + +class TestFlagElimination(OptimizationSuite): + files = ["opt-flag-elimination.test"] + + def do_optimizations(self, fn: FuncIR) -> None: + do_flag_elimination(fn, CompilerOptions()) diff --git a/mypyc/transform/flag_elimination.py b/mypyc/transform/flag_elimination.py new file mode 100644 index 000000000000..605e5bc46ae4 --- /dev/null +++ b/mypyc/transform/flag_elimination.py @@ -0,0 +1,108 @@ +"""Bool register elimination optimization. + +Example input: + + L1: + r0 = f() + b = r0 + goto L3 + L2: + r1 = g() + b = r1 + goto L3 + L3: + if b goto L4 else goto L5 + +The register b is redundant and we replace the assignments with two copies of +the branch in L3: + + L1: + r0 = f() + if r0 goto L4 else goto L5 + L2: + r1 = g() + if r1 goto L4 else goto L5 + +This helps generate simpler IR for tagged integers comparisons, for example. +""" + +from __future__ import annotations + +from mypyc.ir.func_ir import FuncIR +from mypyc.ir.ops import Assign, BasicBlock, Branch, Goto, Register, Unreachable +from mypyc.irbuild.ll_builder import LowLevelIRBuilder +from mypyc.options import CompilerOptions +from mypyc.transform.ir_transform import IRTransform + + +def do_flag_elimination(fn: FuncIR, options: CompilerOptions) -> None: + # Find registers that are used exactly once as source, and in a branch. + counts: dict[Register, int] = {} + branches: dict[Register, Branch] = {} + labels: dict[Register, BasicBlock] = {} + for block in fn.blocks: + for i, op in enumerate(block.ops): + for src in op.sources(): + if isinstance(src, Register): + counts[src] = counts.get(src, 0) + 1 + if i == 0 and isinstance(op, Branch) and isinstance(op.value, Register): + branches[op.value] = op + labels[op.value] = block + + # Based on these we can find the candidate registers. + candidates: set[Register] = { + r for r in branches if counts.get(r, 0) == 1 and r not in fn.arg_regs + } + + # Remove candidates with invalid assignments. + for block in fn.blocks: + for i, op in enumerate(block.ops): + if isinstance(op, Assign) and op.dest in candidates: + next_op = block.ops[i + 1] + if not (isinstance(next_op, Goto) and next_op.label is labels[op.dest]): + # Not right + candidates.remove(op.dest) + + builder = LowLevelIRBuilder(None, options) + transform = FlagEliminationTransform( + builder, {x: y for x, y in branches.items() if x in candidates} + ) + transform.transform_blocks(fn.blocks) + fn.blocks = builder.blocks + + +class FlagEliminationTransform(IRTransform): + def __init__(self, builder: LowLevelIRBuilder, branch_map: dict[Register, Branch]) -> None: + super().__init__(builder) + self.branch_map = branch_map + self.branches = set(branch_map.values()) + + def visit_assign(self, op: Assign) -> None: + old_branch = self.branch_map.get(op.dest) + if old_branch: + # Replace assignment with a copy of the old branch, which is in a + # separate basic block. The old branch will be deletecd in visit_branch. + new_branch = Branch( + op.src, + old_branch.true, + old_branch.false, + old_branch.op, + old_branch.line, + rare=old_branch.rare, + ) + new_branch.negated = old_branch.negated + new_branch.traceback_entry = old_branch.traceback_entry + self.add(new_branch) + else: + self.add(op) + + def visit_goto(self, op: Goto) -> None: + # This is a no-op if basic block already terminated + self.builder.goto(op.label) + + def visit_branch(self, op: Branch) -> None: + if op in self.branches: + # This branch is optimized away + self.add(Unreachable()) + else: + self.add(op) diff --git a/mypyc/transform/ir_transform.py b/mypyc/transform/ir_transform.py index 1bcfc8fb5feb..254fe3f7771d 100644 --- a/mypyc/transform/ir_transform.py +++ b/mypyc/transform/ir_transform.py @@ -101,17 +101,17 @@ def transform_blocks(self, blocks: list[BasicBlock]) -> None: def add(self, op: Op) -> Value: return self.builder.add(op) - def visit_goto(self, op: Goto) -> Value: - return self.add(op) + def visit_goto(self, op: Goto) -> None: + self.add(op) - def visit_branch(self, op: Branch) -> Value: - return self.add(op) + def visit_branch(self, op: Branch) -> None: + self.add(op) - def visit_return(self, op: Return) -> Value: - return self.add(op) + def visit_return(self, op: Return) -> None: + self.add(op) - def visit_unreachable(self, op: Unreachable) -> Value: - return self.add(op) + def visit_unreachable(self, op: Unreachable) -> None: + self.add(op) def visit_assign(self, op: Assign) -> Value | None: return self.add(op)