Commit 60effce

[Dtype Function] fix aten.div.Tensor_mode's dtype function (#2555)
1 parent: ad18219

3 files changed (+104, -69 lines)
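In brief, as the diffs below show: the dtype function for aten.div.Tensor_mode previously ignored rounding_mode and behaved exactly like aten.div.Tensor. It now delegates to aten.floor_divide's dtype function when rounding_mode == "floor" (which also rejects complex operands and a bool result) and keeps the promoted dtype when rounding_mode == "trunc". To allow that call, aten〇floor_divide〡dtype is moved above aten〇div〇Tensor_mode〡dtype in the generator, and the serialized abstract interpretation library is regenerated accordingly.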

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp (83 additions, 53 deletions)

@@ -9524,36 +9524,43 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " return %2 : !torch.int\n"
 " }\n"
-" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
-" %false = torch.constant.bool false\n"
-" %int6 = torch.constant.int 6\n"
-" %true = torch.constant.bool true\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+" %str = torch.constant.str \"AssertionError: Result dtype for aten.floor_divide bool\"\n"
+" %int11 = torch.constant.int 11\n"
+" %str_0 = torch.constant.str \"AssertionError: `other` cannot be complex\"\n"
+" %none = torch.constant.none\n"
+" %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
 " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
-" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
-" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n"
-" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
-" torch.prim.If.yield %true : !torch.bool\n"
+" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
+" torch.prim.If %3 -> () {\n"
+" torch.prim.If.yield\n"
 " } else {\n"
-" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n"
-" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
-" %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If.yield %10 : !torch.bool\n"
-" } else {\n"
-" torch.prim.If.yield %false : !torch.bool\n"
-" }\n"
-" torch.prim.If.yield %9 : !torch.bool\n"
+" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
 " }\n"
-" %7 = torch.prim.If %6 -> (!torch.int) {\n"
-" torch.prim.If.yield %4 : !torch.int\n"
+" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
+" torch.prim.If %5 -> () {\n"
+" torch.prim.If.yield\n"
 " } else {\n"
-" torch.prim.If.yield %int6 : !torch.int\n"
+" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
 " }\n"
-" return %7 : !torch.int\n"
+" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+" %9 = torch.aten.ne.int %8, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %9 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" return %8 : !torch.int\n"
 " }\n"
-" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor_mode\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<str>) -> !torch.int {\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 " %false = torch.constant.bool false\n"
 " %int6 = torch.constant.int 6\n"
 " %true = torch.constant.bool true\n"
@@ -9582,41 +9589,64 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " return %7 : !torch.int\n"
 " }\n"
-" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
-" %str = torch.constant.str \"AssertionError: Result dtype for aten.floor_divide bool\"\n"
-" %int11 = torch.constant.int 11\n"
-" %str_0 = torch.constant.str \"AssertionError: `other` cannot be complex\"\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor_mode\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<str>) -> !torch.int {\n"
+" %str = torch.constant.str \"trunc\"\n"
+" %int6 = torch.constant.int 6\n"
+" %true = torch.constant.bool true\n"
+" %false = torch.constant.bool false\n"
+" %str_0 = torch.constant.str \"floor\"\n"
 " %none = torch.constant.none\n"
-" %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
-" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
-" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
-" torch.prim.If %3 -> () {\n"
-" torch.prim.If.yield\n"
-" } else {\n"
-" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
-" }\n"
-" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
-" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
-" torch.prim.If %5 -> () {\n"
-" torch.prim.If.yield\n"
+" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
+" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
+" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
+" %4 = torch.aten.eq.str %3, %str_0 : !torch.str, !torch.str -> !torch.bool\n"
+" torch.prim.If.yield %4 : !torch.bool\n"
 " } else {\n"
-" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
+" torch.prim.If.yield %false : !torch.bool\n"
 " }\n"
-" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
-" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
-" %9 = torch.aten.ne.int %8, %int11 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If %9 -> () {\n"
-" torch.prim.If.yield\n"
+" %2 = torch.prim.If %1 -> (!torch.int) {\n"
+" %3 = func.call @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0, %arg1) : (!torch.tuple<int, int>, !torch.tuple<int, int>) -> !torch.int\n"
+" torch.prim.If.yield %3 : !torch.int\n"
 " } else {\n"
-" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
+" %3:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %4:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %5 = torch.prim.ListConstruct %4#0, %3#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+" %6 = torch.prim.ListConstruct %4#1, %3#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%7) : (!torch.int) -> !torch.bool\n"
+" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
+" torch.prim.If.yield %true : !torch.bool\n"
+" } else {\n"
+" %12 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n"
+" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
+" %14 = torch.aten.ne.int %7, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %14 : !torch.bool\n"
+" } else {\n"
+" torch.prim.If.yield %false : !torch.bool\n"
+" }\n"
+" torch.prim.If.yield %13 : !torch.bool\n"
+" }\n"
+" %10 = torch.prim.If %9 -> (!torch.bool) {\n"
+" torch.prim.If.yield %true : !torch.bool\n"
+" } else {\n"
+" %12 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
+" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
+" %14 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
+" %15 = torch.aten.eq.str %14, %str : !torch.str, !torch.str -> !torch.bool\n"
+" torch.prim.If.yield %15 : !torch.bool\n"
+" } else {\n"
+" torch.prim.If.yield %false : !torch.bool\n"
+" }\n"
+" torch.prim.If.yield %13 : !torch.bool\n"
+" }\n"
+" %11 = torch.prim.If %10 -> (!torch.int) {\n"
+" torch.prim.If.yield %7 : !torch.int\n"
+" } else {\n"
+" torch.prim.If.yield %int6 : !torch.int\n"
+" }\n"
+" torch.prim.If.yield %11 : !torch.int\n"
 " }\n"
-" return %8 : !torch.int\n"
+" return %2 : !torch.int\n"
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py (19 additions, 14 deletions)

@@ -2412,6 +2412,18 @@ def aten〇bmm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[
     self_priority = get_priority_of_dtype(self_dtype)
     return mat2_dtype if mat2_priority < self_priority else self_dtype
 
+@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool}))
+def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
+    other_rank, other_dtype = other_rank_dtype
+    self_rank, self_dtype = self_rank_dtype
+    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
+    assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
+    ranks: List[Optional[int]] = [self_rank, other_rank]
+    dtypes = [self_dtype, other_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    assert promoted_dtype != torch.bool, "Result dtype for aten.floor_divide bool"
+    return promoted_dtype
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇div〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     other_rank, other_dtype = other_rank_dtype
@@ -2425,31 +2437,24 @@ def aten〇div〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty
     else:
         return torch.float32
 
-@check_dtype_function(_check_two_tensor_op(rounding_mode=None))
+@check_dtype_function(_check_two_tensor_op(rounding_mode=None) +
+                      _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool}, rounding_mode="floor") +
+                      _check_two_tensor_op(rounding_mode="trunc"))
 def aten〇div〇Tensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rounding_mode: Optional[str]) -> int:
+    if rounding_mode is not None and rounding_mode == "floor":
+        return aten〇floor_divide〡dtype(self_rank_dtype, other_rank_dtype)
     other_rank, other_dtype = other_rank_dtype
     self_rank, self_dtype = self_rank_dtype
     ranks: List[Optional[int]] = [self_rank, other_rank]
     dtypes = [self_dtype, other_dtype]
     promoted_dtype = promote_dtypes(ranks, dtypes)
     if is_complex_dtype(promoted_dtype) or \
-       (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32):
+       (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32) or \
+       (rounding_mode is not None and rounding_mode == "trunc"):
         return promoted_dtype
     else:
         return torch.float32
 
-@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool}))
-def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
-    other_rank, other_dtype = other_rank_dtype
-    self_rank, self_dtype = self_rank_dtype
-    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
-    assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
-    ranks: List[Optional[int]] = [self_rank, other_rank]
-    dtypes = [self_dtype, other_dtype]
-    promoted_dtype = promote_dtypes(ranks, dtypes)
-    assert promoted_dtype != torch.bool, "Result dtype for aten.floor_divide bool"
-    return promoted_dtype
-
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
     # Different width
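For reference, the dtype behavior the updated function encodes can be observed directly in PyTorch; a minimal illustration (standard torch.div semantics, independent of torch-mlir):

    import torch

    a = torch.tensor([7, -7], dtype=torch.int64)
    b = torch.tensor([2, 2], dtype=torch.int64)

    # True division always promotes integer inputs to a float dtype,
    # which is why the rounding_mode=None path returns torch.float32.
    print(torch.div(a, b).dtype)                         # torch.float32
    # "floor" follows aten.floor_divide and keeps the promoted integer dtype.
    print(torch.div(a, b, rounding_mode="floor").dtype)  # torch.int64
    # "trunc" likewise returns the promoted dtype unchanged.
    print(torch.div(a, b, rounding_mode="trunc").dtype)  # torch.int64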

projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py (2 additions, 2 deletions)

@@ -216,9 +216,9 @@ def _get_fn_and_golden_results(f, invocation: List[Invocation]):
     # Check for error behavior.
     if invocation.is_expected_to_raise_exception():
         if fn_error is None and op_error is None:
-            _report(f, invocation, f"Expected to raise an exception, but neither {fn_type} function n or op raised an exception")
+            _report(f, invocation, f"Expected to raise an exception, but neither {fn_type} function or op raised an exception")
         if fn_error is None:
-            _report(f, invocation, f"Op raised error {op_error!r}, but shape function did not.")
+            _report(f, invocation, f"Op raised error {op_error!r}, but shape/dtype function did not.")
         if op_error is None:
             _report(f, invocation, f"{fn_type} function raised error {fn_error!r}, but op did not.")
     else:
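The reworded messages reflect that this framework checks both shape and dtype functions (fn_type names which one is under test). For context, a condensed paraphrase of the surrounding logic, which this hunk only partially shows (not verbatim source):

    # fn_error / op_error hold the error (or None) from running the
    # shape/dtype function and the real op on the same invocation.
    if invocation.is_expected_to_raise_exception():
        if fn_error is None and op_error is None:
            ...  # expected an exception, but neither side raised one
        if fn_error is None:
            ...  # the op raised, but the shape/dtype function did not
        if op_error is None:
            ...  # the shape/dtype function raised, but the op did not
    else:
        ...  # no error expected: compare the function's result to the op's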
