Skip to content

Commit

Permalink
[mypyc] Refactor: use primitive op for initializing list item (#17056)
Browse files Browse the repository at this point in the history
Add a new primitive op for initializing list items. Also add support for
primitive ops that steal operands (reference counting wise).

This will also remove most instances of `WORD_SIZE` in irbuild tests,
which were a bit painful, since running tests with `--update-data`
removed these and they had to be manually added back for 32-bit tests to
pass.
  • Loading branch information
JukkaL authored Mar 22, 2024
1 parent 394d17b commit a0a0ada
Show file tree
Hide file tree
Showing 22 changed files with 463 additions and 362 deletions.
8 changes: 8 additions & 0 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,14 @@ def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1
def sources(self) -> list[Value]:
return self.args

def stolen(self) -> list[Value]:
steals = self.desc.steals
if isinstance(steals, list):
assert len(steals) == len(self.args)
return [arg for arg, steal in zip(self.args, steals) if steal]
else:
return [] if not steals else self.sources()

def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_primitive_op(self)

Expand Down
5 changes: 4 additions & 1 deletion mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,10 @@ def visit_primitive_op(self, op: PrimitiveOp) -> str:
type_arg_index += 1

args_str = ", ".join(args)
return self.format("%r = %s %s", op, op.desc.name, args_str)
if op.is_void:
return self.format("%s %s", op.desc.name, args_str)
else:
return self.format("%r = %s %s", op, op.desc.name, args_str)

def visit_truncate(self, op: Truncate) -> str:
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
Expand Down
14 changes: 4 additions & 10 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
PrimitiveOp,
RaiseStandardError,
Register,
SetMem,
Truncate,
TupleGet,
TupleSet,
Expand Down Expand Up @@ -165,7 +164,7 @@
uint8_overflow,
)
from mypyc.primitives.list_ops import list_build_op, list_extend_op, new_list_op
from mypyc.primitives.misc_ops import bool_op, fast_isinstance_op, none_object_op
from mypyc.primitives.misc_ops import bool_op, buf_init_item, fast_isinstance_op, none_object_op
from mypyc.primitives.registry import (
ERR_NEG_INT,
CFunctionDescription,
Expand Down Expand Up @@ -1627,14 +1626,9 @@ def new_list_op(self, values: list[Value], line: int) -> Value:
ob_item_ptr = self.add(GetElementPtr(result_list, PyListObject, "ob_item", line))
ob_item_base = self.add(LoadMem(pointer_rprimitive, ob_item_ptr, line))
for i in range(len(values)):
if i == 0:
item_address = ob_item_base
else:
offset = Integer(PLATFORM_SIZE * i, c_pyssize_t_rprimitive, line)
item_address = self.add(
IntOp(pointer_rprimitive, ob_item_base, offset, IntOp.ADD, line)
)
self.add(SetMem(object_rprimitive, item_address, args[i], line))
self.primitive_op(
buf_init_item, [ob_item_base, Integer(i, c_pyssize_t_rprimitive), args[i]], line
)
self.add(KeepAlive([result_list]))
return result_list

Expand Down
14 changes: 7 additions & 7 deletions mypyc/lower/int_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mypyc.ir.ops import Assign, BasicBlock, Branch, ComparisonOp, Register, Value
from mypyc.ir.rtypes import bool_rprimitive, is_short_int_rprimitive
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
from mypyc.lower.registry import lower_binary_op
from mypyc.lower.registry import lower_primitive_op
from mypyc.primitives.int_ops import int_equal_, int_less_than_
from mypyc.primitives.registry import CFunctionDescription

Expand Down Expand Up @@ -83,31 +83,31 @@ def compare_tagged(self: LowLevelIRBuilder, lhs: Value, rhs: Value, op: str, lin
return result


@lower_binary_op("int_eq")
@lower_primitive_op("int_eq")
def lower_int_eq(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return compare_tagged(builder, args[0], args[1], "==", line)


@lower_binary_op("int_ne")
@lower_primitive_op("int_ne")
def lower_int_ne(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return compare_tagged(builder, args[0], args[1], "!=", line)


@lower_binary_op("int_lt")
@lower_primitive_op("int_lt")
def lower_int_lt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return compare_tagged(builder, args[0], args[1], "<", line)


@lower_binary_op("int_le")
@lower_primitive_op("int_le")
def lower_int_le(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return compare_tagged(builder, args[0], args[1], "<=", line)


@lower_binary_op("int_gt")
@lower_primitive_op("int_gt")
def lower_int_gt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return compare_tagged(builder, args[0], args[1], ">", line)


@lower_binary_op("int_ge")
@lower_primitive_op("int_ge")
def lower_int_ge(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return compare_tagged(builder, args[0], args[1], ">=", line)
34 changes: 34 additions & 0 deletions mypyc/lower/list_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

from mypyc.common import PLATFORM_SIZE
from mypyc.ir.ops import Integer, IntOp, SetMem, Value
from mypyc.ir.rtypes import c_pyssize_t_rprimitive, object_rprimitive, pointer_rprimitive
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
from mypyc.lower.registry import lower_primitive_op


@lower_primitive_op("buf_init_item")
def buf_init_item(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
"""Initialize an item in a buffer of "PyObject *" values at given index.
This can be used to initialize the data buffer of a freshly allocated list
object.
"""
base = args[0]
index_value = args[1]
value = args[2]
assert isinstance(index_value, Integer)
index = index_value.numeric_value()
if index == 0:
ptr = base
else:
ptr = builder.add(
IntOp(
pointer_rprimitive,
base,
Integer(index * PLATFORM_SIZE, c_pyssize_t_rprimitive),
IntOp.ADD,
line,
)
)
return builder.add(SetMem(object_rprimitive, ptr, value, line))
7 changes: 4 additions & 3 deletions mypyc/lower/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
lowering_registry: Final[dict[str, LowerFunc]] = {}


def lower_binary_op(name: str) -> Callable[[LowerFunc], LowerFunc]:
"""Register a handler that generates low-level IR for a primitive binary op."""
def lower_primitive_op(name: str) -> Callable[[LowerFunc], LowerFunc]:
"""Register a handler that generates low-level IR for a primitive op."""

def wrapper(f: LowerFunc) -> LowerFunc:
assert name not in lowering_registry
Expand All @@ -23,4 +23,5 @@ def wrapper(f: LowerFunc) -> LowerFunc:


# Import various modules that set up global state.
import mypyc.lower.int_ops # noqa: F401
import mypyc.lower.int_ops
import mypyc.lower.list_ops # noqa: F401
22 changes: 20 additions & 2 deletions mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,17 @@
int_rprimitive,
object_pointer_rprimitive,
object_rprimitive,
pointer_rprimitive,
str_rprimitive,
void_rtype,
)
from mypyc.primitives.registry import (
ERR_NEG_INT,
custom_op,
custom_primitive_op,
function_op,
load_address_op,
)
from mypyc.primitives.registry import ERR_NEG_INT, custom_op, function_op, load_address_op

# Get the 'bool' type object.
load_address_op(name="builtins.bool", type=object_rprimitive, src="PyBool_Type")
Expand Down Expand Up @@ -232,10 +240,20 @@
)


# register an implementation for a singledispatch function
# Register an implementation for a singledispatch function
register_function = custom_op(
arg_types=[object_rprimitive, object_rprimitive, object_rprimitive],
return_type=object_rprimitive,
c_function_name="CPySingledispatch_RegisterFunction",
error_kind=ERR_MAGIC,
)


# Initialize a PyObject * item in a memory buffer (steal the value)
buf_init_item = custom_primitive_op(
name="buf_init_item",
arg_types=[pointer_rprimitive, c_pyssize_t_rprimitive, object_rprimitive],
return_type=void_rtype,
error_kind=ERR_NEVER,
steals=[False, False, True],
)
35 changes: 35 additions & 0 deletions mypyc/primitives/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,41 @@ def custom_op(
)


def custom_primitive_op(
name: str,
arg_types: list[RType],
return_type: RType,
error_kind: int,
c_function_name: str | None = None,
var_arg_type: RType | None = None,
truncated_type: RType | None = None,
ordering: list[int] | None = None,
extra_int_constants: list[tuple[int, RType]] | None = None,
steals: StealsDescription = False,
is_borrowed: bool = False,
) -> PrimitiveDescription:
"""Define a primitive op that can't be automatically generated based on the AST.
Most arguments are similar to method_op().
"""
if extra_int_constants is None:
extra_int_constants = []
return PrimitiveDescription(
name=name,
arg_types=arg_types,
return_type=return_type,
var_arg_type=var_arg_type,
truncated_type=truncated_type,
c_function_name=c_function_name,
error_kind=error_kind,
steals=steals,
is_borrowed=is_borrowed,
ordering=ordering,
extra_int_constants=extra_int_constants,
priority=0,
)


def unary_op(
name: str,
arg_type: RType,
Expand Down
7 changes: 3 additions & 4 deletions mypyc/test-data/irbuild-any.test
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def f2(a, n, l):
r9, r10 :: bit
r11 :: list
r12 :: object
r13, r14, r15 :: ptr
r13, r14 :: ptr
L0:
r0 = box(int, n)
r1 = PyObject_GetItem(a, r0)
Expand All @@ -123,9 +123,8 @@ L0:
r12 = box(int, n)
r13 = get_element_ptr r11 ob_item :: PyListObject
r14 = load_mem r13 :: ptr*
set_mem r14, a :: builtins.object*
r15 = r14 + WORD_SIZE*1
set_mem r15, r12 :: builtins.object*
buf_init_item r14, 0, a
buf_init_item r14, 1, r12
keep_alive r11
return 1
def f3(a, n):
Expand Down
Loading

0 comments on commit a0a0ada

Please sign in to comment.