Skip to content

Don't write to locals if they are never read #502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: trunk
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions library/_compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,5 +1486,89 @@ async def foo():
# TODO(emacs): Test with (multiple context managers)


@pyro_only
class OptStoreFastTests(unittest.TestCase):
def test_store_with_load_not_replaced(self):
source = """
def foo():
_ = 123
return _
"""
func = compile_function(source, "foo")
self.assertEqual(
dis(func.__code__),
"""\
LOAD_CONST 123
STORE_FAST_REVERSE _
LOAD_FAST_REVERSE_UNCHECKED _
RETURN_VALUE
""",
)
self.assertEqual(func(), 123)

def test_store_with_no_load_replaced_with_pop_top(self):
source = """
def foo():
_ = 123
return 456
"""
func = compile_function(source, "foo")
self.assertEqual(
dis(func.__code__),
"""\
LOAD_CONST 123
POP_TOP
LOAD_CONST 456
RETURN_VALUE
""",
)
self.assertEqual(func(), 456)

def test_store_in_loop_replaced_with_pop_top(self):
source = """
def foo(x):
for _ in x:
pass
"""
func = compile_function(source, "foo")
self.assertEqual(
dis(func.__code__),
"""\
LOAD_FAST_REVERSE_UNCHECKED x
GET_ITER
FOR_ITER 4
POP_TOP
JUMP_ABSOLUTE 4
LOAD_CONST None
RETURN_VALUE
""",
)
self.assertEqual(func(()), None)

def test_multiple_store_no_read_replaced_with_pop_top(self):
source = """
def foo():
x = 123
y = 456
z = 789
return x
"""
func = compile_function(source, "foo")
self.assertEqual(
dis(func.__code__),
"""\
LOAD_CONST 123
STORE_FAST_REVERSE x
LOAD_CONST 456
POP_TOP
LOAD_CONST 789
POP_TOP
LOAD_FAST_REVERSE_UNCHECKED x
RETURN_VALUE
""",
)
self.assertEqual(func(), 123)


if __name__ == "__main__":
unittest.main()
20 changes: 20 additions & 0 deletions library/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,25 @@ def visitBinOp(self, node: ast.BinOp) -> ast.expr:
class PyroFlowGraph(PyFlowGraph38):
opcode = opcodepyro.opcode

def optimizeStoreFast(self):
if "locals" in self.varnames or "locals" in self.names:
# A bit of a hack: if someone is using locals(), we shouldn't mess
# with them.
return
used = set()
for block in self.getBlocksInOrder():
for instr in block.getInstructions():
if instr.opname == "LOAD_FAST" or instr.opname == "DELETE_FAST":
used.add(instr.oparg)
# We never read from or delete the local, so we can replace all stores
# to it with POP_TOP.
for block in self.getBlocksInOrder():
for instr in block.getInstructions():
if instr.opname == "STORE_FAST" and instr.oparg not in used:
instr.opname = "POP_TOP"
instr.oparg = 0
instr.ioparg = 0

def optimizeLoadFast(self):
blocks = self.getBlocksInOrder()
preds = tuple(set() for i in range(self.block_count))
Expand Down Expand Up @@ -297,6 +316,7 @@ def process_one_block(block, modify=False):
self.entry.insts = deletes + self.entry.insts

def getCode(self):
self.optimizeStoreFast()
self.optimizeLoadFast()
return super().getCode()

Expand Down