From c4b7aef1c0b971b48d863cf05130c25aa5b841cd Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Fri, 24 Apr 2026 09:35:55 +0200 Subject: [PATCH 1/6] [Torch] add aten._scaled_mm op support and FX import plumbing Add Torch op definitions, abstract interpretation support, and FX importer handling for aten._scaled_mm, including float8_e8m0fnu blocked-scale tensors. Keep this change independent of TOSA legalization. Include frontend export/import coverage for per-tensor and blocked-scale FP8 shapes, plus out_dtype=None dtype inference. Signed-off-by: Cathal Corbett Change-Id: I102732ad725f89477b7f8fb2339d4fe920fa647b --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 30 +++ .../TorchToTosa/TosaLegalizeUtils.cpp | 16 +- lib/Dialect/Torch/IR/TorchTypes.cpp | 4 +- .../Transforms/AbstractInterpLibrary.cpp | 16 ++ .../build_tools/abstract_interp_lib_gen.py | 11 +- .../build_tools/torch_ods_gen.py | 3 + test/python/fx_importer/basic_test.py | 251 ++++++++++++++++++ 7 files changed, 320 insertions(+), 11 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 2a35d0f9ba9e..a719a82e5244 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6623,6 +6623,36 @@ def Torch_AtenMmOp : Torch_Op<"aten.mm", [ }]; } +def Torch_Aten_ScaledMmOp : Torch_Op<"aten._scaled_mm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_scaled_mm : (Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, int?, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mat2, + AnyTorchTensorType:$scale_a, + AnyTorchTensorType:$scale_b, + AnyTorchOptionalTensorType:$bias, + AnyTorchOptionalTensorType:$scale_result, + AnyTorchOptionalIntType:$out_dtype, + Torch_BoolType:$use_fast_accum + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_ScaledMmOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void Aten_ScaledMmOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + def Torch_Aten_IntMmOp : Torch_Op<"aten._int_mm", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 844b1d965b25..24521f51013d 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -297,7 +297,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isBF16()) || (src.isF32() && dest.isF16()) || - (src.isF32() && isa(dest)) || + (src.isF32() && isa(dest)) || (src.isF32() && isa(dest)) || // f16 -> * (src.isF16() && dest.isInteger(32)) || @@ -305,19 +305,19 @@ std::optional getConstTensor(PatternRewriter &rewriter, (src.isF16() && dest.isInteger(8)) || (src.isF16() && dest.isBF16()) || (src.isF16() && dest.isF32()) || - (src.isF16() && isa(dest)) || + (src.isF16() && isa(dest)) || (src.isF16() && isa(dest)) || // bf16 -> * (src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isInteger(16)) || (src.isBF16() && dest.isInteger(8)) || (src.isBF16() && dest.isF32()) || - (src.isBF16() && isa(dest)) || + (src.isBF16() && isa(dest)) || (src.isBF16() && isa(dest)) || // 
fp8e4m3 -> * - (isa(src) && dest.isBF16()) || - (isa(src) && dest.isF32()) || - (isa(src) && dest.isF16()) || + (isa(src) && dest.isBF16()) || + (isa(src) && dest.isF32()) || + (isa(src) && dest.isF16()) || // fp8e5m2 -> * (isa(src) && dest.isBF16()) || (isa(src) && dest.isF32()) || @@ -514,8 +514,8 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) && outputElemTy.isInteger(48)) { accType = mlir::TypeAttr::get(rewriter.getIntegerType(48)); - } else if ((isa(inputElemTy) && - isa(weightElemTy) && outputElemTy.isF16()) || + } else if ((isa(inputElemTy) && + isa(weightElemTy) && outputElemTy.isF16()) || (isa(inputElemTy) && isa(weightElemTy) && outputElemTy.isF16())) { accType = mlir::TypeAttr::get(rewriter.getF16Type()); diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index fd0b50cd3585..c4c88d6fd757 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -191,8 +191,8 @@ static bool isValidTorchDtype(Type dtype) { // Builtin floating point types. if (isa(dtype)) return true; - if (isa(dtype)) + if (isa(dtype)) return true; if (isa(dtype)) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 155e6ca8c106..e2be43c991e6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8407,6 +8407,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten._scaled_mm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.optional, %arg7: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._int_mm\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -15544,6 +15548,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %6 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._scaled_mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.optional, %arg7: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" %3 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._int_mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int3 = torch.constant.int 3\n" " %none = torch.constant.none\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py 
b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 724ad18044d5..c5a4a6bdb8c5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -908,6 +908,9 @@ def aten〇mv〡shape(self: List[int], vec: List[int]) -> List[int]: def aten〇mm〡shape(self: List[int], mat2: List[int]) -> List[int]: return upstream_shape_functions.mm(self, mat2) +def aten〇_scaled_mm〡shape(self: List[int], mat2: List[int], scale_a: List[int], scale_b: List[int], bias: Optional[List[int]] = None, scale_result: Optional[List[int]] = None, out_dtype: Optional[int] = None, use_fast_accum: bool = False) -> List[int]: + return upstream_shape_functions.mm(self, mat2) + def aten〇_int_mm〡shape(self: List[int], mat2: List[int]) -> List[int]: return upstream_shape_functions.mm(self, mat2) @@ -4682,6 +4685,13 @@ def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[i dtypes = [self_dtype, mat2_dtype] return promote_dtypes(ranks, dtypes) +def aten〇_scaled_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], scale_a_rank_dtype: Tuple[int, int], scale_b_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, scale_result_rank_dtype: Optional[Tuple[int, int]] = None, out_dtype: Optional[int] = None, use_fast_accum: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + mat2_rank, mat2_dtype = mat2_rank_dtype + if out_dtype is not None: + return out_dtype + return self_dtype + def aten〇_int_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype mat2_rank, mat2_dtype = mat2_rank_dtype @@ -6369,4 +6379,3 @@ def _create_argparse() -> argparse.ArgumentParser: if __name__ == "__main__": main(_create_argparse().parse_args()) - diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 0fc4fd3e0de9..8eb7164191e5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -574,6 +574,9 @@ def emit_with_mutating_variants(key, **kwargs): # Non-elementwise tensor compute ops emit("aten::linear : (Tensor, Tensor, Tensor?) 
-> (Tensor)") emit("aten::mm : (Tensor, Tensor) -> (Tensor)") + emit( + "aten::_scaled_mm : (Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, int?, bool) -> (Tensor)" + ) emit("aten::_int_mm : (Tensor, Tensor) -> (Tensor)") emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::matmul : (Tensor, Tensor) -> (Tensor)") diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index f00290593d1c..3e8abf0c79b0 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -28,6 +28,21 @@ def run(f): print() +def print_exported_graph_with_tensor_meta(gm): + print(gm.code) + for node in gm.graph.nodes: + tm = node.meta.get("tensor_meta") + val = node.meta.get("val") + if tm is not None: + print( + f"META {node.name} {node.op} {node.target} {tuple(tm.shape)} {tm.dtype}" + ) + elif isinstance(val, torch.Tensor): + print( + f"META {node.name} {node.op} {node.target} {tuple(val.shape)} {val.dtype}" + ) + + @run # CHECK-LABEL: test_import_frozen_exported_program # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> @@ -238,6 +253,242 @@ def forward(self): print(m) +@run +# CHECK-LABEL: test_export_scaled_mm_per_tensor_frontend +# CHECK: def forward(self, a, b, a_scale, b_scale): +# CHECK: torch.ops.aten._scaled_mm.default(a, b, a_scale, b_scale, None, None, torch.bfloat16) +# CHECK: META a placeholder a (128, 128) torch.float8_e4m3fn +# CHECK: META b placeholder b (128, 128) torch.float8_e4m3fn +# CHECK: META a_scale placeholder a_scale () torch.float32 +# CHECK: META b_scale placeholder b_scale () torch.float32 +# CHECK: META _scaled_mm call_function aten._scaled_mm.default (128, 128) torch.bfloat16 +def test_export_scaled_mm_per_tensor_frontend(): + class Basic(nn.Module): + def forward(self, a, b, a_scale, b_scale): + return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + a_scale = torch.tensor(1.0, dtype=torch.float32) + b_scale = torch.tensor(1.0, dtype=torch.float32) + + exported = torch.export.export(Basic(), (a, b, a_scale, b_scale)) + print_exported_graph_with_tensor_meta(exported.module()) + + +@run +# CHECK-LABEL: test_export_scaled_mm_out_dtype_none_frontend +# CHECK: def forward(self, a, b, a_scale, b_scale): +# CHECK: torch.ops.aten._scaled_mm.default(a, b, a_scale, b_scale) +# CHECK: META a placeholder a (128, 128) torch.float8_e4m3fn +# CHECK: META b placeholder b (128, 128) torch.float8_e5m2 +# CHECK: META a_scale placeholder a_scale () torch.float32 +# CHECK: META b_scale placeholder b_scale () torch.float32 +# CHECK: META _scaled_mm call_function aten._scaled_mm.default (128, 128) torch.float8_e4m3fn +def test_export_scaled_mm_out_dtype_none_frontend(): + class Basic(nn.Module): + def forward(self, a, b, a_scale, b_scale): + return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=None) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + a_scale = torch.tensor(1.0, dtype=torch.float32) + b_scale = torch.tensor(1.0, dtype=torch.float32) + + exported = torch.export.export(Basic(), (a, b, a_scale, b_scale)) + print_exported_graph_with_tensor_meta(exported.module()) + + +@run +# CHECK-LABEL: test_export_scaled_mm_block_scaled_fp8_frontend +# CHECK: def forward(self, a, b, 
a_scale_block, b_scale_block): +# CHECK: torch.ops.aten._scaled_mm.default(a, b, a_scale_block, b_scale_block, None, None, torch.bfloat16) +# CHECK: META a placeholder a (256, 128) torch.float8_e4m3fn +# CHECK: META b placeholder b (128, 64) torch.float8_e4m3fn +# CHECK: META a_scale_block placeholder a_scale_block (1024,) torch.float8_e8m0fnu +# CHECK: META b_scale_block placeholder b_scale_block (512,) torch.float8_e8m0fnu +# CHECK: META _scaled_mm call_function aten._scaled_mm.default (256, 64) torch.bfloat16 +def test_export_scaled_mm_block_scaled_fp8_frontend(): + class Basic(nn.Module): + def forward(self, a, b, a_scale_block, b_scale_block): + return torch._scaled_mm( + a, b, a_scale_block, b_scale_block, out_dtype=torch.bfloat16 + ) + + a = torch.ones((256, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 64), dtype=torch.float32).to(torch.float8_e4m3fn) + a_scale_block = torch.zeros((1024,), dtype=torch.float8_e8m0fnu) + b_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + + exported = torch.export.export(Basic(), (a, b, a_scale_block, b_scale_block)) + print_exported_graph_with_tensor_meta(exported.module()) + + +@run +# CHECK-LABEL: test_export_scaled_mm_block_scaled_fp8_ragged_frontend +# CHECK: def forward(self, a, b, a_scale_block, b_scale_block): +# CHECK: torch.ops.aten._scaled_mm.default(a, b, a_scale_block, b_scale_block, None, None, torch.bfloat16) +# CHECK: META a placeholder a (130, 128) torch.float8_e4m3fn +# CHECK: META b placeholder b (128, 67) torch.float8_e4m3fn +# CHECK: META a_scale_block placeholder a_scale_block (1024,) torch.float8_e8m0fnu +# CHECK: META b_scale_block placeholder b_scale_block (512,) torch.float8_e8m0fnu +# CHECK: META _scaled_mm call_function aten._scaled_mm.default (130, 67) torch.bfloat16 +def test_export_scaled_mm_block_scaled_fp8_ragged_frontend(): + class Basic(nn.Module): + def forward(self, a, b, a_scale_block, b_scale_block): + return torch._scaled_mm( + a, b, a_scale_block, b_scale_block, out_dtype=torch.bfloat16 + ) + + a = torch.ones((130, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 67), dtype=torch.float32).to(torch.float8_e4m3fn) + a_scale_block = torch.zeros((1024,), dtype=torch.float8_e8m0fnu) + b_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + + exported = torch.export.export(Basic(), (a, b, a_scale_block, b_scale_block)) + print_exported_graph_with_tensor_meta(exported.module()) + + +@run +# CHECK-LABEL: test_import_scaled_mm_per_tensor +# CHECK: func.func @test_import_scaled_mm_per_tensor(%arg0: !torch.vtensor<[128,128],f8E4M3FN>, %arg1: !torch.vtensor<[128,128],f8E4M3FN>, %arg2: !torch.vtensor<[],f32>, %arg3: !torch.vtensor<[],f32>) -> !torch.vtensor<[128,128],bf16> +# CHECK: %[[NONE:.+]] = torch.constant.none +# CHECK: %[[NONE_0:.+]] = torch.constant.none +# CHECK: %[[INT15:.+]] = torch.constant.int 15 +# CHECK: %[[FALSE:.+]] = torch.constant.bool false +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3, %[[NONE]], %[[NONE_0]], %[[INT15]], %[[FALSE]] +# CHECK: return %[[MM]] +def test_import_scaled_mm_per_tensor(): + class Basic(nn.Module): + def forward(self, a, b, a_scale, b_scale): + return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + a_scale = torch.tensor(1.0, dtype=torch.float32) + b_scale = torch.tensor(1.0, dtype=torch.float32) + + m = 
fx.export_and_import( + Basic(), + a, + b, + a_scale, + b_scale, + func_name="test_import_scaled_mm_per_tensor", + ) + print(m) + + +@run +# CHECK-LABEL: test_import_scaled_mm_per_tensor_e5m2 +# CHECK: func.func @test_import_scaled_mm_per_tensor_e5m2(%arg0: !torch.vtensor<[128,128],f8E5M2>, %arg1: !torch.vtensor<[128,128],f8E5M2>, %arg2: !torch.vtensor<[],f32>, %arg3: !torch.vtensor<[],f32>) -> !torch.vtensor<[128,128],bf16> +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3 +def test_import_scaled_mm_per_tensor_e5m2(): + class Basic(nn.Module): + def forward(self, a, b, a_scale, b_scale): + return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + a_scale = torch.tensor(1.0, dtype=torch.float32) + b_scale = torch.tensor(1.0, dtype=torch.float32) + + m = fx.export_and_import( + Basic(), + a, + b, + a_scale, + b_scale, + func_name="test_import_scaled_mm_per_tensor_e5m2", + ) + print(m) + + +@run +# CHECK-LABEL: test_import_scaled_mm_out_dtype_none +# CHECK: func.func @test_import_scaled_mm_out_dtype_none(%arg0: !torch.vtensor<[128,128],f8E4M3FN>, %arg1: !torch.vtensor<[128,128],f8E5M2>, %arg2: !torch.vtensor<[],f32>, %arg3: !torch.vtensor<[],f32>) -> !torch.vtensor<[128,128],f8E4M3FN> +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3 +# CHECK: return %[[MM]] +def test_import_scaled_mm_out_dtype_none(): + class Basic(nn.Module): + def forward(self, a, b, a_scale, b_scale): + return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=None) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + a_scale = torch.tensor(1.0, dtype=torch.float32) + b_scale = torch.tensor(1.0, dtype=torch.float32) + + m = fx.export_and_import( + Basic(), + a, + b, + a_scale, + b_scale, + func_name="test_import_scaled_mm_out_dtype_none", + ) + print(m) + + +@run +# CHECK-LABEL: test_import_scaled_mm_block_scaled_fp8 +# CHECK: func.func @test_import_scaled_mm_block_scaled_fp8(%arg0: !torch.vtensor<[128,128],f8E4M3FN>, %arg1: !torch.vtensor<[128,128],f8E4M3FN>, %arg2: !torch.vtensor<[512],f8E8M0FNU>, %arg3: !torch.vtensor<[512],f8E8M0FNU>) -> !torch.vtensor<[128,128],bf16> +# CHECK: %[[NONE:.+]] = torch.constant.none +# CHECK: %[[NONE_0:.+]] = torch.constant.none +# CHECK: %[[INT15:.+]] = torch.constant.int 15 +# CHECK: %[[FALSE:.+]] = torch.constant.bool false +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3, %[[NONE]], %[[NONE_0]], %[[INT15]], %[[FALSE]] +# CHECK: return %[[MM]] +def test_import_scaled_mm_block_scaled_fp8(): + class Basic(nn.Module): + def forward(self, a, b, a_scale_block, b_scale_block): + return torch._scaled_mm( + a, b, a_scale_block, b_scale_block, out_dtype=torch.bfloat16 + ) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + a_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + b_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + + m = fx.export_and_import( + Basic(), + a, + b, + a_scale_block, + b_scale_block, + func_name="test_import_scaled_mm_block_scaled_fp8", + ) + print(m) + + +@run +# CHECK-LABEL: test_import_scaled_mm_block_scaled_fp8_e5m2 +# CHECK: func.func @test_import_scaled_mm_block_scaled_fp8_e5m2(%arg0: !torch.vtensor<[128,128],f8E5M2>, 
%arg1: !torch.vtensor<[128,128],f8E5M2>, %arg2: !torch.vtensor<[512],f8E8M0FNU>, %arg3: !torch.vtensor<[512],f8E8M0FNU>) -> !torch.vtensor<[128,128],bf16> +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3 +def test_import_scaled_mm_block_scaled_fp8_e5m2(): + class Basic(nn.Module): + def forward(self, a, b, a_scale_block, b_scale_block): + return torch._scaled_mm( + a, b, a_scale_block, b_scale_block, out_dtype=torch.bfloat16 + ) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + a_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + b_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + + m = fx.export_and_import( + Basic(), + a, + b, + a_scale_block, + b_scale_block, + func_name="test_import_scaled_mm_block_scaled_fp8_e5m2", + ) + print(m) + + @run # CHECK-LABEL: test_while_loop_two_returns # Check that helper functions are emitted first From 8052ff9e634a9cdd112c7e695d6943143e4be0fe Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Wed, 6 May 2026 15:35:24 +0200 Subject: [PATCH 2/6] [Torch] Test _scaled_mm abstract interp Change-Id: I555367f4c8a3725fba75a1389e170bf402ad5eb9 --- .../build_tools/abstract_interp_lib_gen.py | 46 +++++++++++++++++++ .../build_tools/testing_framework.py | 15 ++++-- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index c5a4a6bdb8c5..55c7c6f1ccb6 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -908,6 +908,29 @@ def aten〇mv〡shape(self: List[int], vec: List[int]) -> List[int]: def aten〇mm〡shape(self: List[int], mat2: List[int]) -> List[int]: return upstream_shape_functions.mm(self, mat2) +@check_shape_function([ + Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=torch.bfloat16, + ), + Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e5m2, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=None, + ), + Invocation( + TensorOfShape(256, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 64, dtype=torch.float8_e4m3fn, stride=(1, 128)), + TensorOfShape(1024, dtype=torch.float8_e8m0fnu), + TensorOfShape(512, dtype=torch.float8_e8m0fnu), + out_dtype=torch.bfloat16, + ), +]) def aten〇_scaled_mm〡shape(self: List[int], mat2: List[int], scale_a: List[int], scale_b: List[int], bias: Optional[List[int]] = None, scale_result: Optional[List[int]] = None, out_dtype: Optional[int] = None, use_fast_accum: bool = False) -> List[int]: return upstream_shape_functions.mm(self, mat2) @@ -4685,6 +4708,29 @@ def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[i dtypes = [self_dtype, mat2_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function([ + Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=torch.bfloat16, + ), + 
Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e5m2, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=None, + ), + Invocation( + TensorOfShape(256, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 64, dtype=torch.float8_e4m3fn, stride=(1, 128)), + TensorOfShape(1024, dtype=torch.float8_e8m0fnu), + TensorOfShape(512, dtype=torch.float8_e8m0fnu), + out_dtype=torch.bfloat16, + ), +]) def aten〇_scaled_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], scale_a_rank_dtype: Tuple[int, int], scale_b_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, scale_result_rank_dtype: Optional[Tuple[int, int]] = None, out_dtype: Optional[int] = None, use_fast_accum: bool = False) -> int: self_rank, self_dtype = self_rank_dtype mat2_rank, mat2_dtype = mat2_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py index cbeb38d66365..26150d152bcf 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from typing import Any, List, Iterable, Optional, Callable +from typing import Any, List, Iterable, Optional, Callable, Tuple import torch from torch import Tensor @@ -67,14 +67,19 @@ def __init__( *shape: int, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + stride: Optional[Tuple[int, ...]] = None, ): self.shape = list(shape) self.dtype = dtype self.device = "meta" if device is None else device + self.stride = stride def __repr__(self): args_str = ", ".join(repr(x) for x in self.shape) - return f"TensorOfShape({args_str}, dtype={self.dtype}, device={self.device})" + kwargs = f"dtype={self.dtype}, device={self.device}" + if self.stride is not None: + kwargs += f", stride={self.stride}" + return f"TensorOfShape({args_str}, {kwargs})" def LongTensorOfShape(*args, **kwargs): @@ -158,7 +163,11 @@ def to_dtype_function_args(self): def to_real_op_args(self): """Gets positional arguments appropriate for the real op.""" - tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype).to(o.device) + tensor_transformer = lambda o: ( + torch.empty_strided(o.shape, o.stride, dtype=o.dtype, device=o.device) + if o.stride is not None + else torch.ones(o.shape, dtype=o.dtype).to(o.device) + ) return _recursively_transform_tensor_args(self.args, tensor_transformer) def __repr__(self) -> str: From c99c3c3815ab3d6d4c93b9866d6be188bfc4062a Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Wed, 6 May 2026 15:47:54 +0200 Subject: [PATCH 3/6] [Torch] Tighten _scaled_mm abstract interp Change-Id: I3f62b8d60de4155295ec6f552f68c3d5f21f2c08 --- .../Transforms/AbstractInterpLibrary.cpp | 168 +++++++++++++++++- .../build_tools/abstract_interp_lib_gen.py | 78 +++++++- 2 files changed, 240 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index e2be43c991e6..01ad1d94607f 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8408,7 +8408,112 @@ StringRef 
mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten._scaled_mm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.optional, %arg7: !torch.bool) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int16 = torch.constant.int 16\n" +" %int32 = torch.constant.int 32\n" +" %int127 = torch.constant.int 127\n" +" %int128 = torch.constant.int 128\n" " %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.multiply_integers(%arg2) : (!torch.list) -> !torch.int\n" +" %2 = call @__torch__.torch.jit._shape_functions.multiply_integers(%arg3) : (!torch.list) -> !torch.int\n" +" %3 = torch.aten.eq.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %29 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %29 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %29 = torch.aten.ne.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %29 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %30 = torch.aten.ne.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %30 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %100 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %101 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %102 = torch.aten.eq.int %100, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %103 = torch.prim.If %102 -> (!torch.bool) {\n" +" %104 = torch.aten.eq.int %101, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %104 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %103 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %31 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %32 = torch.aten.remainder.int %31, %int16 : !torch.int, !torch.int -> !torch.int\n" +" %33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %33 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %34 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.remainder.int %34, %int16 : !torch.int, !torch.int -> !torch.int\n" +" %36 = torch.aten.eq.int %35, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %36 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %37 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %38 = torch.aten.add.int %37, %int127 : !torch.int, 
!torch.int -> !torch.int\n" +" %39 = torch.aten.floordiv.int %38, %int128 : !torch.int, !torch.int -> !torch.int\n" +" %40 = torch.aten.mul.int %39, %int128 : !torch.int, !torch.int -> !torch.int\n" +" %41 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %42 = torch.aten.add.int %41, %int127 : !torch.int, !torch.int -> !torch.int\n" +" %43 = torch.aten.floordiv.int %42, %int128 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.mul.int %43, %int128 : !torch.int, !torch.int -> !torch.int\n" +" %45 = torch.aten.mul.int %40, %44 : !torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.floordiv.int %45, %int32 : !torch.int, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.aten.add.int %47, %int127 : !torch.int, !torch.int -> !torch.int\n" +" %49 = torch.aten.floordiv.int %48, %int128 : !torch.int, !torch.int -> !torch.int\n" +" %50 = torch.aten.mul.int %49, %int128 : !torch.int, !torch.int -> !torch.int\n" +" %51 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %52 = torch.aten.add.int %51, %int127 : !torch.int, !torch.int -> !torch.int\n" +" %53 = torch.aten.floordiv.int %52, %int128 : !torch.int, !torch.int -> !torch.int\n" +" %54 = torch.aten.mul.int %53, %int128 : !torch.int, !torch.int -> !torch.int\n" +" %55 = torch.aten.mul.int %50, %54 : !torch.int, !torch.int -> !torch.int\n" +" %56 = torch.aten.floordiv.int %55, %int32 : !torch.int, !torch.int -> !torch.int\n" +" %57 = torch.aten.eq.int %1, %46 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %57 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %58 = torch.aten.eq.int %2, %56 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %58 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" }\n" " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten._int_mm\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" @@ -15549,16 +15654,69 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %6 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._scaled_mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.optional, %arg7: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" " %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int6 = torch.constant.int 6\n" +" %int23 = torch.constant.int 23\n" +" %int24 = torch.constant.int 24\n" +" %int25 = torch.constant.int 25\n" +" %int26 = torch.constant.int 26\n" +" %int44 = torch.constant.int 44\n" +" %int45 = torch.constant.int 45\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.int) {\n" -" %3 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.int\n" -" torch.prim.If.yield %3 : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : 
!torch.tuple -> !torch.int, !torch.int\n" +" %3:2 = torch.prim.TupleUnpack %arg3 : !torch.tuple -> !torch.int, !torch.int\n" +" %4 = torch.prim.ListConstruct %int45, %int24, %int26, %int23, %int25 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = torch.aten.__contains__.int_list %4, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.__contains__.int_list %4, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.eq.int %2#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %23 = torch.aten.eq.int %3#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %23 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.aten.eq.int %2#1, %int44 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" %23 = torch.aten.eq.int %3#1, %int44 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %23 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %11 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" }\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.int) {\n" +" %23 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %23 : !torch.int\n" " } else {\n" " torch.prim.If.yield %0#1 : !torch.int\n" " }\n" -" return %2 : !torch.int\n" +" return %13 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._int_mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int3 = torch.constant.int 3\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 55c7c6f1ccb6..b371e512e6c7 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -908,6 +908,36 @@ def aten〇mv〡shape(self: List[int], vec: List[int]) -> List[int]: def aten〇mm〡shape(self: List[int], mat2: List[int]) -> List[int]: return upstream_shape_functions.mm(self, mat2) +def _numel(sizes: List[int]) -> int: + numel = 1 + for size in sizes: + numel *= size + return numel + +def _round_up_to_multiple(value: int, multiple: int) -> int: + return ((value + multiple - 1) // multiple) * multiple + +def _scaled_mm_block_scale_numel(rows: int, cols: int) -> int: + return _round_up_to_multiple(rows, 128) * _round_up_to_multiple(cols, 128) // 32 + +def _check_scaled_mm_scale_shapes(self: List[int], mat2: List[int], scale_a: List[int], scale_b: List[int]): + scale_a_numel = _numel(scale_a) + scale_b_numel = _numel(scale_b) + + if scale_a_numel == 1 and scale_b_numel == 1: 
+ return + + assert scale_a_numel != 1 and scale_b_numel != 1 + if len(scale_a) == 2 and len(scale_b) == 2: + return + + assert mat2[0] % 16 == 0 + assert mat2[1] % 16 == 0 + expected_scale_a_numel = _scaled_mm_block_scale_numel(self[0], self[1]) + expected_scale_b_numel = _scaled_mm_block_scale_numel(mat2[1], mat2[0]) + assert scale_a_numel == expected_scale_a_numel + assert scale_b_numel == expected_scale_b_numel + @check_shape_function([ Invocation( TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), @@ -930,9 +960,18 @@ def aten〇mm〡shape(self: List[int], mat2: List[int]) -> List[int]: TensorOfShape(512, dtype=torch.float8_e8m0fnu), out_dtype=torch.bfloat16, ), + ErrorInvocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn, stride=(1, 128)), + TensorOfShape(500, dtype=torch.float8_e8m0fnu), + TensorOfShape(512, dtype=torch.float8_e8m0fnu), + out_dtype=torch.bfloat16, + ), ]) def aten〇_scaled_mm〡shape(self: List[int], mat2: List[int], scale_a: List[int], scale_b: List[int], bias: Optional[List[int]] = None, scale_result: Optional[List[int]] = None, out_dtype: Optional[int] = None, use_fast_accum: bool = False) -> List[int]: - return upstream_shape_functions.mm(self, mat2) + result = upstream_shape_functions.mm(self, mat2) + _check_scaled_mm_scale_shapes(self, mat2, scale_a, scale_b) + return result def aten〇_int_mm〡shape(self: List[int], mat2: List[int]) -> List[int]: return upstream_shape_functions.mm(self, mat2) @@ -4708,6 +4747,26 @@ def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[i dtypes = [self_dtype, mat2_dtype] return promote_dtypes(ranks, dtypes) +def _is_scaled_mm_data_dtype(dtype: int) -> bool: + return dtype in [ + torch.float4_e2m1fn_x2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + ] + +def _check_scaled_mm_dtypes(self_dtype: int, mat2_dtype: int, scale_a_dtype: int, scale_b_dtype: int): + assert _is_scaled_mm_data_dtype(self_dtype) + assert _is_scaled_mm_data_dtype(mat2_dtype) + assert ( + scale_a_dtype == torch.float32 + and scale_b_dtype == torch.float32 + ) or ( + scale_a_dtype == torch.float8_e8m0fnu + and scale_b_dtype == torch.float8_e8m0fnu + ) + @check_dtype_function([ Invocation( TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), @@ -4730,10 +4789,27 @@ def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[i TensorOfShape(512, dtype=torch.float8_e8m0fnu), out_dtype=torch.bfloat16, ), + ErrorInvocation( + TensorOfShape(2, 3, dtype=torch.float32), + TensorOfShape(3, 4, dtype=torch.int32, stride=(1, 3)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=None, + ), + ErrorInvocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn, stride=(1, 128)), + TensorOfShape(512, dtype=torch.float8_e8m0fnu), + TensorOfShape(512, dtype=torch.float32), + out_dtype=torch.bfloat16, + ), ]) def aten〇_scaled_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], scale_a_rank_dtype: Tuple[int, int], scale_b_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, scale_result_rank_dtype: Optional[Tuple[int, int]] = None, out_dtype: Optional[int] = None, use_fast_accum: bool = False) -> int: self_rank, self_dtype = self_rank_dtype mat2_rank, mat2_dtype = mat2_rank_dtype + scale_a_rank, scale_a_dtype = scale_a_rank_dtype + scale_b_rank, scale_b_dtype = scale_b_rank_dtype + 
_check_scaled_mm_dtypes(self_dtype, mat2_dtype, scale_a_dtype, scale_b_dtype) if out_dtype is not None: return out_dtype return self_dtype From 12efa5c5bd6775887a120b89c8938bdfb3658f84 Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Wed, 6 May 2026 15:51:18 +0200 Subject: [PATCH 4/6] [FX] Fix _scaled_mm ragged test shape Change-Id: I1c1a386a6395b6416c6ada5d06cc82f87271bd5d --- test/python/fx_importer/basic_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 3e8abf0c79b0..40d9a81c0b5c 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -329,10 +329,10 @@ def forward(self, a, b, a_scale_block, b_scale_block): # CHECK: def forward(self, a, b, a_scale_block, b_scale_block): # CHECK: torch.ops.aten._scaled_mm.default(a, b, a_scale_block, b_scale_block, None, None, torch.bfloat16) # CHECK: META a placeholder a (130, 128) torch.float8_e4m3fn -# CHECK: META b placeholder b (128, 67) torch.float8_e4m3fn +# CHECK: META b placeholder b (128, 80) torch.float8_e4m3fn # CHECK: META a_scale_block placeholder a_scale_block (1024,) torch.float8_e8m0fnu # CHECK: META b_scale_block placeholder b_scale_block (512,) torch.float8_e8m0fnu -# CHECK: META _scaled_mm call_function aten._scaled_mm.default (130, 67) torch.bfloat16 +# CHECK: META _scaled_mm call_function aten._scaled_mm.default (130, 80) torch.bfloat16 def test_export_scaled_mm_block_scaled_fp8_ragged_frontend(): class Basic(nn.Module): def forward(self, a, b, a_scale_block, b_scale_block): @@ -341,7 +341,7 @@ def forward(self, a, b, a_scale_block, b_scale_block): ) a = torch.ones((130, 128), dtype=torch.float32).to(torch.float8_e4m3fn) - b = torch.ones((128, 67), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 80), dtype=torch.float32).to(torch.float8_e4m3fn) a_scale_block = torch.zeros((1024,), dtype=torch.float8_e8m0fnu) b_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) From 98eef4911a4ec4cff495b62a15c3e0050d1b65cf Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Wed, 6 May 2026 15:53:20 +0200 Subject: [PATCH 5/6] [Torch] Check _scaled_mm matrix divisibility Change-Id: I7d4343af324fafe830ff514d79d8c6f03ac200e3 --- .../Transforms/AbstractInterpLibrary.cpp | 27 +++++++++++++++++++ .../build_tools/abstract_interp_lib_gen.py | 20 ++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 01ad1d94607f..fd0fafd90de5 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8422,6 +8422,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " %1 = call @__torch__.torch.jit._shape_functions.multiply_integers(%arg2) : (!torch.list) -> !torch.int\n" " %2 = call @__torch__.torch.jit._shape_functions.multiply_integers(%arg3) : (!torch.list) -> !torch.int\n" +" %200 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %201 = torch.aten.remainder.int %200, %int16 : !torch.int, !torch.int -> !torch.int\n" +" %202 = torch.aten.eq.int %201, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %202 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : 
!torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %203 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %204 = torch.aten.remainder.int %203, %int16 : !torch.int, !torch.int -> !torch.int\n" +" %205 = torch.aten.eq.int %204, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %205 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %206 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %207 = torch.aten.remainder.int %206, %int16 : !torch.int, !torch.int -> !torch.int\n" +" %208 = torch.aten.eq.int %207, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %208 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " %3 = torch.aten.eq.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" " %4 = torch.prim.If %3 -> (!torch.bool) {\n" " %29 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index b371e512e6c7..f6d81866caf8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -921,6 +921,10 @@ def _scaled_mm_block_scale_numel(rows: int, cols: int) -> int: return _round_up_to_multiple(rows, 128) * _round_up_to_multiple(cols, 128) // 32 def _check_scaled_mm_scale_shapes(self: List[int], mat2: List[int], scale_a: List[int], scale_b: List[int]): + assert self[1] % 16 == 0 + assert mat2[0] % 16 == 0 + assert mat2[1] % 16 == 0 + scale_a_numel = _numel(scale_a) scale_b_numel = _numel(scale_b) @@ -931,8 +935,6 @@ def _check_scaled_mm_scale_shapes(self: List[int], mat2: List[int], scale_a: Lis if len(scale_a) == 2 and len(scale_b) == 2: return - assert mat2[0] % 16 == 0 - assert mat2[1] % 16 == 0 expected_scale_a_numel = _scaled_mm_block_scale_numel(self[0], self[1]) expected_scale_b_numel = _scaled_mm_block_scale_numel(mat2[1], mat2[0]) assert scale_a_numel == expected_scale_a_numel @@ -967,6 +969,20 @@ def _check_scaled_mm_scale_shapes(self: List[int], mat2: List[int], scale_a: Lis TensorOfShape(512, dtype=torch.float8_e8m0fnu), out_dtype=torch.bfloat16, ), + ErrorInvocation( + TensorOfShape(128, 100, dtype=torch.float8_e4m3fn), + TensorOfShape(100, 128, dtype=torch.float8_e4m3fn, stride=(1, 100)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=torch.bfloat16, + ), + ErrorInvocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 67, dtype=torch.float8_e4m3fn, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=torch.bfloat16, + ), ]) def aten〇_scaled_mm〡shape(self: List[int], mat2: List[int], scale_a: List[int], scale_b: List[int], bias: Optional[List[int]] = None, scale_result: Optional[List[int]] = None, out_dtype: Optional[int] = None, use_fast_accum: bool = False) -> List[int]: result = upstream_shape_functions.mm(self, mat2) From ada420061982a170f9fcc76a611e7c12ac14c99f Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Thu, 7 May 2026 08:56:39 +0200 Subject: [PATCH 6/6] [Torch] Regenerate _scaled_mm 
abstract interp library Change-Id: I52f0543dc1120e3a8e3d358486c2cafc8a63d1eb --- .../Transforms/AbstractInterpLibrary.cpp | 265 +++++++++--------- 1 file changed, 135 insertions(+), 130 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index fd0fafd90de5..d78bab129545 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8408,140 +8408,136 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten._scaled_mm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.optional, %arg7: !torch.bool) -> !torch.list {\n" -" %true = torch.constant.bool true\n" +" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %1 = call @__torch__._check_scaled_mm_scale_shapes(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.list, !torch.list, !torch.list) -> !torch.none\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__._check_scaled_mm_scale_shapes(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list) -> !torch.none {\n" " %false = torch.constant.bool false\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" -" %int2 = torch.constant.int 2\n" " %int16 = torch.constant.int 16\n" -" %int32 = torch.constant.int 32\n" -" %int127 = torch.constant.int 127\n" -" %int128 = torch.constant.int 128\n" -" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" -" %1 = call @__torch__.torch.jit._shape_functions.multiply_integers(%arg2) : (!torch.list) -> !torch.int\n" -" %2 = call @__torch__.torch.jit._shape_functions.multiply_integers(%arg3) : (!torch.list) -> !torch.int\n" -" %200 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %201 = torch.aten.remainder.int %200, %int16 : !torch.int, !torch.int -> !torch.int\n" -" %202 = torch.aten.eq.int %201, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %202 -> () {\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.remainder.int %0, %int16 : !torch.int, !torch.int -> !torch.int\n" +" %2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %203 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %204 = torch.aten.remainder.int %203, %int16 : !torch.int, !torch.int -> !torch.int\n" -" %205 = torch.aten.eq.int %204, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %205 -> () {\n" +" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.remainder.int %3, %int16 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " 
torch.prim.If.yield\n"
 " }\n"
-" %206 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n"
-" %207 = torch.aten.remainder.int %206, %int16 : !torch.int, !torch.int -> !torch.int\n"
-" %208 = torch.aten.eq.int %207, %int0 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If %208 -> () {\n"
+" %6 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n"
+" %7 = torch.aten.remainder.int %6, %int16 : !torch.int, !torch.int -> !torch.int\n"
+" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %8 -> () {\n"
 " torch.prim.If.yield\n"
 " } else {\n"
 " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 " torch.prim.If.yield\n"
 " }\n"
-" %3 = torch.aten.eq.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
-" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
-" %29 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If.yield %29 : !torch.bool\n"
+" %9 = call @__torch__._numel(%arg2) : (!torch.list) -> !torch.int\n"
+" %10 = call @__torch__._numel(%arg3) : (!torch.list) -> !torch.int\n"
+" %11 = torch.aten.eq.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+" %12 = torch.prim.If %11 -> (!torch.bool) {\n"
+" %13 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %13 : !torch.bool\n"
 " } else {\n"
 " torch.prim.If.yield %false : !torch.bool\n"
 " }\n"
-" torch.prim.If %4 -> () {\n"
+" torch.prim.If %12 -> () {\n"
 " torch.prim.If.yield\n"
 " } else {\n"
-" %29 = torch.aten.ne.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If %29 -> () {\n"
-" torch.prim.If.yield\n"
-" } else {\n"
-" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
-" }\n"
-" %30 = torch.aten.ne.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If %30 -> () {\n"
-" torch.prim.If.yield\n"
-" } else {\n"
-" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
-" }\n"
-" %100 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n"
-" %101 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n"
-" %102 = torch.aten.eq.int %100, %int2 : !torch.int, !torch.int -> !torch.bool\n"
-" %103 = torch.prim.If %102 -> (!torch.bool) {\n"
-" %104 = torch.aten.eq.int %101, %int2 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If.yield %104 : !torch.bool\n"
+" %13 = torch.aten.ne.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+" %14 = torch.prim.If %13 -> (!torch.bool) {\n"
+" %18 = torch.aten.ne.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %18 : !torch.bool\n"
 " } else {\n"
 " torch.prim.If.yield %false : !torch.bool\n"
 " }\n"
-" torch.prim.If %103 -> () {\n"
-" torch.prim.If.yield\n"
-" } else {\n"
-" %31 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n"
-" %32 = torch.aten.remainder.int %31, %int16 : !torch.int, !torch.int -> !torch.int\n"
-" %33 = torch.aten.eq.int %32, %int0 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If %33 -> () {\n"
-" torch.prim.If.yield\n"
-" } else {\n"
-" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
-" }\n"
-" %34 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n"
-" %35 = torch.aten.remainder.int %34, %int16 : !torch.int, !torch.int -> !torch.int\n"
-" %36 = torch.aten.eq.int %35, %int0 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If %36 -> () {\n"
+" torch.prim.If %14 -> () {\n"
 " torch.prim.If.yield\n"
 " } else {\n"
 " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 " torch.prim.If.yield\n"
 " }\n"
-" %37 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n"
-" %38 = torch.aten.add.int %37, %int127 : !torch.int, !torch.int -> !torch.int\n"
-" %39 = torch.aten.floordiv.int %38, %int128 : !torch.int, !torch.int -> !torch.int\n"
-" %40 = torch.aten.mul.int %39, %int128 : !torch.int, !torch.int -> !torch.int\n"
-" %41 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n"
-" %42 = torch.aten.add.int %41, %int127 : !torch.int, !torch.int -> !torch.int\n"
-" %43 = torch.aten.floordiv.int %42, %int128 : !torch.int, !torch.int -> !torch.int\n"
-" %44 = torch.aten.mul.int %43, %int128 : !torch.int, !torch.int -> !torch.int\n"
-" %45 = torch.aten.mul.int %40, %44 : !torch.int, !torch.int -> !torch.int\n"
-" %46 = torch.aten.floordiv.int %45, %int32 : !torch.int, !torch.int -> !torch.int\n"
-" %47 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n"
-" %48 = torch.aten.add.int %47, %int127 : !torch.int, !torch.int -> !torch.int\n"
-" %49 = torch.aten.floordiv.int %48, %int128 : !torch.int, !torch.int -> !torch.int\n"
-" %50 = torch.aten.mul.int %49, %int128 : !torch.int, !torch.int -> !torch.int\n"
-" %51 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n"
-" %52 = torch.aten.add.int %51, %int127 : !torch.int, !torch.int -> !torch.int\n"
-" %53 = torch.aten.floordiv.int %52, %int128 : !torch.int, !torch.int -> !torch.int\n"
-" %54 = torch.aten.mul.int %53, %int128 : !torch.int, !torch.int -> !torch.int\n"
-" %55 = torch.aten.mul.int %50, %54 : !torch.int, !torch.int -> !torch.int\n"
-" %56 = torch.aten.floordiv.int %55, %int32 : !torch.int, !torch.int -> !torch.int\n"
-" %57 = torch.aten.eq.int %1, %46 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If %57 -> () {\n"
-" torch.prim.If.yield\n"
+" %15 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n"
+" %16 = torch.aten.eq.int %15, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
+" %18 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n"
+" %19 = torch.aten.eq.int %18, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %19 : !torch.bool\n"
 " } else {\n"
-" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
+" torch.prim.If.yield %false : !torch.bool\n"
 " }\n"
-" %58 = torch.aten.eq.int %2, %56 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If %58 -> () {\n"
+" torch.prim.If %17 -> () {\n"
 " torch.prim.If.yield\n"
 " } else {\n"
-" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
-" }\n"
+" %18 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n"
+" %19 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n"
+" %20 = func.call @__torch__._scaled_mm_block_scale_numel(%18, %19) : (!torch.int, !torch.int) -> !torch.int\n"
+" %21 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n"
+" %22 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n"
+" %23 = func.call @__torch__._scaled_mm_block_scale_numel(%21, %22) : (!torch.int, !torch.int) -> !torch.int\n"
+" %24 = torch.aten.eq.int %9, %20 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %24 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %25 = torch.aten.eq.int %10, %23 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %25 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
 " torch.prim.If.yield\n"
 " }\n"
 " torch.prim.If.yield\n"
 " }\n"
-" return %0 : !torch.list\n"
+" return %none : !torch.none\n"
+" }\n"
+" func.func @__torch__._numel(%arg0: !torch.list) -> !torch.int {\n"
+" %true = torch.constant.bool true\n"
+" %int1 = torch.constant.int 1\n"
+" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n"
+" %1 = torch.prim.Loop %0, %true, init(%int1) {\n"
+" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n"
+" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n"
+" %3 = torch.aten.mul.int %arg2, %2 : !torch.int, !torch.int -> !torch.int\n"
+" torch.prim.Loop.condition %true, iter(%3 : !torch.int)\n"
+" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
+" return %1 : !torch.int\n"
+" }\n"
+" func.func @__torch__._scaled_mm_block_scale_numel(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
+" %int128 = torch.constant.int 128\n"
+" %int32 = torch.constant.int 32\n"
+" %0 = call @__torch__._round_up_to_multiple(%arg0, %int128) : (!torch.int, !torch.int) -> !torch.int\n"
+" %1 = call @__torch__._round_up_to_multiple(%arg1, %int128) : (!torch.int, !torch.int) -> !torch.int\n"
+" %2 = torch.aten.mul.int %0, %1 : !torch.int, !torch.int -> !torch.int\n"
+" %3 = torch.aten.floordiv.int %2, %int32 : !torch.int, !torch.int -> !torch.int\n"
+" return %3 : !torch.int\n"
+" }\n"
+" func.func @__torch__._round_up_to_multiple(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
+" %int1 = torch.constant.int 1\n"
+" %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+" %1 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n"
+" %2 = torch.aten.floordiv.int %1, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+" %3 = torch.aten.mul.int %2, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+" return %3 : !torch.int\n"
 " }\n"
 " func.func @\"__torch_mlir_shape_fn.aten._int_mm\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n"
@@ -15681,69 +15677,78 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " return %6 : !torch.int\n"
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten._scaled_mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.optional, %arg7: !torch.bool) -> !torch.int {\n"
-" %true = torch.constant.bool true\n"
-" %false = torch.constant.bool false\n"
 " %none = torch.constant.none\n"
-" %str = torch.constant.str \"AssertionError: \"\n"
-" %int6 = torch.constant.int 6\n"
-" %int23 = torch.constant.int 23\n"
-" %int24 = torch.constant.int 24\n"
-" %int25 = torch.constant.int 25\n"
-" %int26 = torch.constant.int 26\n"
-" %int44 = torch.constant.int 44\n"
-" %int45 = torch.constant.int 45\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n"
 " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n"
 " %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n"
 " %3:2 = torch.prim.TupleUnpack %arg3 : !torch.tuple -> !torch.int, !torch.int\n"
-" %4 = torch.prim.ListConstruct %int45, %int24, %int26, %int23, %int25 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n"
-" %5 = torch.aten.__contains__.int_list %4, %0#1 : !torch.list, !torch.int -> !torch.bool\n"
-" torch.prim.If %5 -> () {\n"
+" %4 = call @__torch__._check_scaled_mm_dtypes(%0#1, %1#1, %2#1, %3#1) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.none\n"
+" %5 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n"
+" %6 = torch.prim.If %5 -> (!torch.int) {\n"
+" %7 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.int\n"
+" torch.prim.If.yield %7 : !torch.int\n"
+" } else {\n"
+" torch.prim.If.yield %0#1 : !torch.int\n"
+" }\n"
+" return %6 : !torch.int\n"
+" }\n"
+" func.func @__torch__._check_scaled_mm_dtypes(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.none {\n"
+" %int44 = torch.constant.int 44\n"
+" %true = torch.constant.bool true\n"
+" %false = torch.constant.bool false\n"
+" %int6 = torch.constant.int 6\n"
+" %none = torch.constant.none\n"
+" %str = torch.constant.str \"AssertionError: \"\n"
+" %0 = call @__torch__._is_scaled_mm_data_dtype(%arg0) : (!torch.int) -> !torch.bool\n"
+" torch.prim.If %0 -> () {\n"
 " torch.prim.If.yield\n"
 " } else {\n"
 " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 " torch.prim.If.yield\n"
 " }\n"
-" %6 = torch.aten.__contains__.int_list %4, %1#1 : !torch.list, !torch.int -> !torch.bool\n"
-" torch.prim.If %6 -> () {\n"
+" %1 = call @__torch__._is_scaled_mm_data_dtype(%arg1) : (!torch.int) -> !torch.bool\n"
+" torch.prim.If %1 -> () {\n"
 " torch.prim.If.yield\n"
 " } else {\n"
 " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 " torch.prim.If.yield\n"
 " }\n"
-" %7 = torch.aten.eq.int %2#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
-" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
-" %23 = torch.aten.eq.int %3#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If.yield %23 : !torch.bool\n"
-" } else {\n"
-" torch.prim.If.yield %false : !torch.bool\n"
-" }\n"
-" %9 = torch.aten.eq.int %2#1, %int44 : !torch.int, !torch.int -> !torch.bool\n"
-" %10 = torch.prim.If %9 -> (!torch.bool) {\n"
-" %23 = torch.aten.eq.int %3#1, %int44 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If.yield %23 : !torch.bool\n"
+" %2 = torch.aten.eq.int %arg2, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
+" %5 = torch.aten.eq.int %arg3, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %5 : !torch.bool\n"
 " } else {\n"
 " torch.prim.If.yield %false : !torch.bool\n"
 " }\n"
-" %11 = torch.prim.If %8 -> (!torch.bool) {\n"
+" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
 " torch.prim.If.yield %true : !torch.bool\n"
 " } else {\n"
-" torch.prim.If.yield %10 : !torch.bool\n"
+" %5 = torch.aten.eq.int %arg2, %int44 : !torch.int, !torch.int -> !torch.bool\n"
+" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
+" %7 = torch.aten.eq.int %arg3, %int44 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %7 : !torch.bool\n"
+" } else {\n"
+" torch.prim.If.yield %false : !torch.bool\n"
+" }\n"
+" torch.prim.If.yield %6 : !torch.bool\n"
 " }\n"
-" torch.prim.If %11 -> () {\n"
+" torch.prim.If %4 -> () {\n"
 " torch.prim.If.yield\n"
 " } else {\n"
 " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 " torch.prim.If.yield\n"
 " }\n"
-" %12 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n"
-" %13 = torch.prim.If %12 -> (!torch.int) {\n"
-" %23 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.int\n"
-" torch.prim.If.yield %23 : !torch.int\n"
-" } else {\n"
-" torch.prim.If.yield %0#1 : !torch.int\n"
-" }\n"
-" return %13 : !torch.int\n"
+" return %none : !torch.none\n"
+" }\n"
+" func.func @__torch__._is_scaled_mm_data_dtype(%arg0: !torch.int) -> !torch.bool {\n"
+" %int25 = torch.constant.int 25\n"
+" %int23 = torch.constant.int 23\n"
+" %int26 = torch.constant.int 26\n"
+" %int24 = torch.constant.int 24\n"
+" %int45 = torch.constant.int 45\n"
+" %0 = torch.prim.ListConstruct %int45, %int24, %int26, %int23, %int25 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n"
+" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n"
+" return %1 : !torch.bool\n"
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten._int_mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n"
 " %int3 = torch.constant.int 3\n"