Skip to content

Commit

Permalink
[mypyc] Refactor: move tagged int related code to mypyc.lower.int_ops (
Browse files Browse the repository at this point in the history
  • Loading branch information
JukkaL authored Mar 20, 2024
1 parent afdd9d5 commit 952c616
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 83 deletions.
5 changes: 2 additions & 3 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@
from mypyc.primitives.bytes_ops import bytes_slice_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op
from mypyc.primitives.generic_ops import iter_op
from mypyc.primitives.int_ops import int_comparison_op_mapping
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
from mypyc.primitives.registry import CFunctionDescription, builtin_names
Expand Down Expand Up @@ -814,7 +813,7 @@ def translate_is_none(builder: IRBuilder, expr: Expression, negated: bool) -> Va
def transform_basic_comparison(
builder: IRBuilder, op: str, left: Value, right: Value, line: int
) -> Value:
if is_fixed_width_rtype(left.type) and op in int_comparison_op_mapping:
if is_fixed_width_rtype(left.type) and op in ComparisonOp.signed_ops:
if right.type == left.type:
if left.type.is_signed:
op_id = ComparisonOp.signed_ops[op]
Expand All @@ -831,7 +830,7 @@ def transform_basic_comparison(
)
elif (
is_fixed_width_rtype(right.type)
and op in int_comparison_op_mapping
and op in ComparisonOp.signed_ops
and isinstance(left, Integer)
):
if right.type.is_signed:
Expand Down
45 changes: 0 additions & 45 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@
int64_divide_op,
int64_mod_op,
int64_to_int_op,
int_comparison_op_mapping,
int_to_int32_op,
int_to_int64_op,
ssize_t_to_int_op,
Expand Down Expand Up @@ -1413,50 +1412,6 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -
check = self.comparison_op(bitwise_and, zero, op, line)
return check

def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two tagged integers using given operator (value context)."""
# generate fast binary logic ops on short ints
if (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type)) and op in (
"==",
"!=",
):
quick = True
else:
quick = is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type)
if quick:
return self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
result = Register(bool_rprimitive)
short_int_block, int_block, out = BasicBlock(), BasicBlock(), BasicBlock()
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
if op in ("==", "!="):
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
else:
# for non-equality logical ops (less/greater than, etc.), need to check both sides
short_lhs = BasicBlock()
self.add(Branch(check_lhs, int_block, short_lhs, Branch.BOOL))
self.activate_block(short_lhs)
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
self.activate_block(int_block)
if swap_op:
args = [rhs, lhs]
else:
args = [lhs, rhs]
call = self.call_c(c_func_desc, args, line)
if negate_result:
# TODO: introduce UnaryIntOp?
call_result = self.unary_op(call, "not", line)
else:
call_result = call
self.add(Assign(result, call_result, line))
self.goto(out)
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Assign(result, eq, line))
self.goto_and_activate(out)
return result

def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two strings"""
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)
Expand Down
92 changes: 85 additions & 7 deletions mypyc/lower/int_ops.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,113 @@
"""Convert tagged int primitive ops to lower-level ops."""

from __future__ import annotations

from mypyc.ir.ops import Value
from typing import NamedTuple

from mypyc.ir.ops import Assign, BasicBlock, Branch, ComparisonOp, Register, Value
from mypyc.ir.rtypes import bool_rprimitive, is_short_int_rprimitive
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
from mypyc.lower.registry import lower_binary_op
from mypyc.primitives.int_ops import int_equal_, int_less_than_
from mypyc.primitives.registry import CFunctionDescription


# Description for building int comparison ops
#
# Fields:
# binary_op_variant: identify which IntOp to use when operands are short integers
# c_func_description: the C function to call when operands are tagged integers
# c_func_negated: whether to negate the C function call's result
# c_func_swap_operands: whether to swap lhs and rhs when call the function
class IntComparisonOpDescription(NamedTuple):
binary_op_variant: int
c_func_description: CFunctionDescription
c_func_negated: bool
c_func_swap_operands: bool


# Provide mapping from textual op to short int's op variant and boxed int's description.
# Note that these are not complete implementations and require extra IR.
int_comparison_op_mapping: dict[str, IntComparisonOpDescription] = {
"==": IntComparisonOpDescription(ComparisonOp.EQ, int_equal_, False, False),
"!=": IntComparisonOpDescription(ComparisonOp.NEQ, int_equal_, True, False),
"<": IntComparisonOpDescription(ComparisonOp.SLT, int_less_than_, False, False),
"<=": IntComparisonOpDescription(ComparisonOp.SLE, int_less_than_, True, True),
">": IntComparisonOpDescription(ComparisonOp.SGT, int_less_than_, False, True),
">=": IntComparisonOpDescription(ComparisonOp.SGE, int_less_than_, True, False),
}


def compare_tagged(self: LowLevelIRBuilder, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two tagged integers using given operator (value context)."""
# generate fast binary logic ops on short ints
if (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type)) and op in (
"==",
"!=",
):
quick = True
else:
quick = is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type)
if quick:
return self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
result = Register(bool_rprimitive)
short_int_block, int_block, out = BasicBlock(), BasicBlock(), BasicBlock()
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
if op in ("==", "!="):
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
else:
# for non-equality logical ops (less/greater than, etc.), need to check both sides
short_lhs = BasicBlock()
self.add(Branch(check_lhs, int_block, short_lhs, Branch.BOOL))
self.activate_block(short_lhs)
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
self.activate_block(int_block)
if swap_op:
args = [rhs, lhs]
else:
args = [lhs, rhs]
call = self.call_c(c_func_desc, args, line)
if negate_result:
# TODO: introduce UnaryIntOp?
call_result = self.unary_op(call, "not", line)
else:
call_result = call
self.add(Assign(result, call_result, line))
self.goto(out)
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Assign(result, eq, line))
self.goto_and_activate(out)
return result


@lower_binary_op("int_eq")
def lower_int_eq(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "==", line)
return compare_tagged(builder, args[0], args[1], "==", line)


@lower_binary_op("int_ne")
def lower_int_ne(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "!=", line)
return compare_tagged(builder, args[0], args[1], "!=", line)


@lower_binary_op("int_lt")
def lower_int_lt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "<", line)
return compare_tagged(builder, args[0], args[1], "<", line)


@lower_binary_op("int_le")
def lower_int_le(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], "<=", line)
return compare_tagged(builder, args[0], args[1], "<=", line)


@lower_binary_op("int_gt")
def lower_int_gt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], ">", line)
return compare_tagged(builder, args[0], args[1], ">", line)


@lower_binary_op("int_ge")
def lower_int_ge(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
return builder.compare_tagged(args[0], args[1], ">=", line)
return compare_tagged(builder, args[0], args[1], ">=", line)
28 changes: 0 additions & 28 deletions mypyc/primitives/int_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@

from __future__ import annotations

from typing import NamedTuple

from mypyc.ir.ops import (
ERR_ALWAYS,
ERR_MAGIC,
ERR_MAGIC_OVERLAPPING,
ERR_NEVER,
ComparisonOp,
PrimitiveDescription,
)
from mypyc.ir.rtypes import (
Expand Down Expand Up @@ -196,20 +193,6 @@ def int_unary_op(name: str, c_function_name: str) -> CFunctionDescription:
# Primitives related to integer comparison operations:


# Description for building int comparison ops
#
# Fields:
# binary_op_variant: identify which IntOp to use when operands are short integers
# c_func_description: the C function to call when operands are tagged integers
# c_func_negated: whether to negate the C function call's result
# c_func_swap_operands: whether to swap lhs and rhs when call the function
class IntComparisonOpDescription(NamedTuple):
binary_op_variant: int
c_func_description: CFunctionDescription
c_func_negated: bool
c_func_swap_operands: bool


# Equals operation on two boxed tagged integers
int_equal_ = custom_op(
arg_types=[int_rprimitive, int_rprimitive],
Expand All @@ -226,17 +209,6 @@ class IntComparisonOpDescription(NamedTuple):
error_kind=ERR_NEVER,
)

# Provide mapping from textual op to short int's op variant and boxed int's description.
# Note that these are not complete implementations and require extra IR.
int_comparison_op_mapping: dict[str, IntComparisonOpDescription] = {
"==": IntComparisonOpDescription(ComparisonOp.EQ, int_equal_, False, False),
"!=": IntComparisonOpDescription(ComparisonOp.NEQ, int_equal_, True, False),
"<": IntComparisonOpDescription(ComparisonOp.SLT, int_less_than_, False, False),
"<=": IntComparisonOpDescription(ComparisonOp.SLE, int_less_than_, True, True),
">": IntComparisonOpDescription(ComparisonOp.SGT, int_less_than_, False, True),
">=": IntComparisonOpDescription(ComparisonOp.SGE, int_less_than_, True, False),
}

int64_divide_op = custom_op(
arg_types=[int64_rprimitive, int64_rprimitive],
return_type=int64_rprimitive,
Expand Down

0 comments on commit 952c616

Please sign in to comment.