Skip to content

Commit

Permalink
[mypyc] Implement lowering pass and add primitives for int (in)equali…
Browse files Browse the repository at this point in the history
…ty (#17027)

Add a new `PrimitiveOp` op which can be transformed into lower-level ops
in a lowering pass after reference counting op insertion pass.

Higher-level ops in IR make it easier to implement various
optimizations, and the output of irbuild test cases will be more compact
and readable.

Implement the lowering pass. Currently it's pretty minimal, and I will
add additional primitives and the direct transformation of various
primitives to `CallC` ops in follow-up PRs. Currently primitives that
map to C calls generate `CallC` ops in the main irbuild pass, but the
long-term goal is to only/mostly generate `PrimitiveOp`s instead of
`CallC` ops during the main irbuild pass.

Also implement primitives for tagged integer equality and inequality as
examples.

Lowering of primitives is implemented using decorated handler functions
in `mypyc.lower` that are found based on the name of the primitive. The
name has no other significance, though it's also used in pretty-printed
IR output.

Work on mypyc/mypyc#854. The issue describes the motivation in more
detail.
  • Loading branch information
JukkaL committed Mar 16, 2024
1 parent 31dc503 commit c591c89
Show file tree
Hide file tree
Showing 32 changed files with 772 additions and 483 deletions.
4 changes: 4 additions & 0 deletions mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
RegisterOp,
Return,
Expand Down Expand Up @@ -234,6 +235,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]:
def visit_call_c(self, op: CallC) -> GenAndKill[T]:
return self.visit_register_op(op)

def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill[T]:
return self.visit_register_op(op)

def visit_truncate(self, op: Truncate) -> GenAndKill[T]:
return self.visit_register_op(op)

Expand Down
4 changes: 4 additions & 0 deletions mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -381,6 +382,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
def visit_call_c(self, op: CallC) -> None:
pass

def visit_primitive_op(self, op: PrimitiveOp) -> None:
pass

def visit_truncate(self, op: Truncate) -> None:
pass

Expand Down
4 changes: 4 additions & 0 deletions mypyc/analysis/selfleaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
LoadStatic,
MethodCall,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
RegisterOp,
Expand Down Expand Up @@ -149,6 +150,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill:
def visit_call_c(self, op: CallC) -> GenAndKill:
return self.check_register_op(op)

def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill:
return self.check_register_op(op)

def visit_truncate(self, op: Truncate) -> GenAndKill:
return CLEAN

Expand Down
6 changes: 6 additions & 0 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -629,6 +630,11 @@ def visit_call_c(self, op: CallC) -> None:
args = ", ".join(self.reg(arg) for arg in op.args)
self.emitter.emit_line(f"{dest}{op.function_name}({args});")

def visit_primitive_op(self, op: PrimitiveOp) -> None:
raise RuntimeError(
f"unexpected PrimitiveOp {op.desc.name}: they must be lowered before codegen"
)

def visit_truncate(self, op: Truncate) -> None:
dest = self.reg(op)
value = self.reg(op.src)
Expand Down
10 changes: 7 additions & 3 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from mypyc.transform.copy_propagation import do_copy_propagation
from mypyc.transform.exceptions import insert_exception_handling
from mypyc.transform.flag_elimination import do_flag_elimination
from mypyc.transform.lower import lower_ir
from mypyc.transform.refcount import insert_ref_count_opcodes
from mypyc.transform.uninit import insert_uninit_checks

Expand Down Expand Up @@ -235,6 +236,8 @@ def compile_scc_to_ir(
insert_exception_handling(fn)
# Insert refcount handling.
insert_ref_count_opcodes(fn)
# Switch to lower abstraction level IR.
lower_ir(fn, compiler_options)
# Perform optimizations.
do_copy_propagation(fn, compiler_options)
do_flag_elimination(fn, compiler_options)
Expand Down Expand Up @@ -423,10 +426,11 @@ def compile_modules_to_c(
)

modules = compile_modules_to_ir(result, mapper, compiler_options, errors)
ctext = compile_ir_to_c(groups, modules, result, mapper, compiler_options)
if errors.num_errors > 0:
return {}, []

if errors.num_errors == 0:
write_cache(modules, result, group_map, ctext)
ctext = compile_ir_to_c(groups, modules, result, mapper, compiler_options)
write_cache(modules, result, group_map, ctext)

return modules, [ctext[name] for _, name in groups]

Expand Down
79 changes: 78 additions & 1 deletion mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,78 @@ def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_method_call(self)


class PrimitiveDescription:
"""Description of a primitive op.
Primitives get lowered into lower-level ops before code generation.
If c_function_name is provided, a primitive will be lowered into a CallC op.
Otherwise custom logic will need to be implemented to transform the
primitive into lower-level ops.
"""

def __init__(
self,
name: str,
arg_types: list[RType],
return_type: RType, # TODO: What about generic?
var_arg_type: RType | None,
truncated_type: RType | None,
c_function_name: str | None,
error_kind: int,
steals: StealsDescription,
is_borrowed: bool,
ordering: list[int] | None,
extra_int_constants: list[tuple[int, RType]],
priority: int,
) -> None:
# Each primitive much have a distinct name, but otherwise they are arbitrary.
self.name: Final = name
self.arg_types: Final = arg_types
self.return_type: Final = return_type
self.var_arg_type: Final = var_arg_type
self.truncated_type: Final = truncated_type
# If non-None, this will map to a call of a C helper function; if None,
# there must be a custom handler function that gets invoked during the lowering
# pass to generate low-level IR for the primitive (in the mypyc.lower package)
self.c_function_name: Final = c_function_name
self.error_kind: Final = error_kind
self.steals: Final = steals
self.is_borrowed: Final = is_borrowed
self.ordering: Final = ordering
self.extra_int_constants: Final = extra_int_constants
self.priority: Final = priority

def __repr__(self) -> str:
return f"<PrimitiveDescription {self.name}>"


class PrimitiveOp(RegisterOp):
"""A higher-level primitive operation.
Some of these have special compiler support. These will be lowered
(transformed) into lower-level IR ops before code generation, and after
reference counting op insertion. Others will be transformed into CallC
ops.
Tagged integer equality is a typical primitive op with non-trivial
lowering. It gets transformed into a tag check, followed by different
code paths for short and long representations.
"""

def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1) -> None:
self.args = args
self.type = desc.return_type
self.error_kind = desc.error_kind
self.desc = desc

def sources(self) -> list[Value]:
return self.args

def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_primitive_op(self)


class LoadErrorValue(RegisterOp):
"""Load an error value.
Expand Down Expand Up @@ -1446,7 +1518,8 @@ class Unborrow(RegisterOp):

error_kind = ERR_NEVER

def __init__(self, src: Value) -> None:
def __init__(self, src: Value, line: int = -1) -> None:
super().__init__(line)
assert src.is_borrowed
self.src = src
self.type = src.type
Expand Down Expand Up @@ -1555,6 +1628,10 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> T:
def visit_call_c(self, op: CallC) -> T:
raise NotImplementedError

@abstractmethod
def visit_primitive_op(self, op: PrimitiveOp) -> T:
raise NotImplementedError

@abstractmethod
def visit_truncate(self, op: Truncate) -> T:
raise NotImplementedError
Expand Down
17 changes: 17 additions & 0 deletions mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -217,6 +218,22 @@ def visit_call_c(self, op: CallC) -> str:
else:
return self.format("%r = %s(%s)", op, op.function_name, args_str)

def visit_primitive_op(self, op: PrimitiveOp) -> str:
args = []
arg_index = 0
type_arg_index = 0
for arg_type in zip(op.desc.arg_types):
if arg_type:
args.append(self.format("%r", op.args[arg_index]))
arg_index += 1
else:
assert op.type_args
args.append(self.format("%r", op.type_args[type_arg_index]))
type_arg_index += 1

args_str = ", ".join(args)
return self.format("%r = %s %s ", op, op.desc.name, args_str)

def visit_truncate(self, op: Truncate) -> str:
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)

Expand Down
7 changes: 6 additions & 1 deletion mypyc/irbuild/ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def maybe_process_conditional_comparison(
self.add_bool_branch(reg, true, false)
else:
# "left op right" for two tagged integers
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
if op in ("==", "!="):
reg = self.builder.binary_op(left, right, op, e.line)
self.flush_keep_alives()
self.add_bool_branch(reg, true, false)
else:
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
return True


Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
set_literal = precompute_set_literal(builder, e.operands[1])
if set_literal is not None:
lhs = e.operands[0]
result = builder.builder.call_c(
result = builder.builder.primitive_op(
set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive
)
if first_op == "not in":
Expand All @@ -778,7 +778,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
borrow_left = is_borrow_friendly_expr(builder, right_expr)
left = builder.accept(left_expr, can_borrow=borrow_left)
right = builder.accept(right_expr, can_borrow=True)
return builder.compare_tagged(left, right, first_op, e.line)
return builder.binary_op(left, right, first_op, e.line)

# TODO: Don't produce an expression when used in conditional context
# All of the trickiness here is due to support for chained conditionals
Expand Down
Loading

0 comments on commit c591c89

Please sign in to comment.