diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 2a35d0f9ba9e..8b3fb891943e 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6623,6 +6623,37 @@ def Torch_AtenMmOp : Torch_Op<"aten.mm", [
   }];
 }
 
+def Torch_Aten_ScaledMmOp : Torch_Op<"aten._scaled_mm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_scaled_mm : (Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, int?, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$mat2,
+    AnyTorchTensorType:$scale_a,
+    AnyTorchTensorType:$scale_b,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchOptionalTensorType:$scale_result,
+    AnyTorchOptionalIntType:$out_dtype,
+    Torch_BoolType:$use_fast_accum
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_ScaledMmOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void Aten_ScaledMmOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def Torch_Aten_IntMmOp : Torch_Op<"aten._int_mm", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 67de6981edc6..abb616fb3200 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -6314,6 +6314,159 @@ LogicalResult AtenKthvalueOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Aten_ScaledMmOp
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): template argument lists below were stripped by an HTML-tag
+// filter in the extracted patch; they are reconstructed from the dtypes
+// exercised by the tests in this same patch — confirm against the upstream
+// commit before applying.
+static bool isScaledMmDataDtype(Type dtype) {
+  return isa<Float8E4M3FNType, Float8E5M2Type, Float8E4M3FNUZType,
+             Float8E5M2FNUZType, Float4E2M1FNType>(dtype);
+}
+
+static bool isScaledMmTensorwiseOrRowwiseScaleDtype(Type dtype) {
+  return dtype.isF32();
+}
+
+static bool isScaledMmBlockwiseScaleDtype(Type dtype) {
+  return isa<Float8E8M0FNUType, Float8E4M3FNType>(dtype);
+}
+
+static int64_t ceilDivPositive(int64_t dividend, int64_t divisor) {
+  return (dividend + divisor - 1) / divisor;
+}
+
+static bool getNumel(ArrayRef<int64_t> sizes, int64_t &numel) {
+  numel = 1;
+  for (int64_t size : sizes) {
+    if (size == kUnknownSize)
+      return false;
+    numel *= size;
+  }
+  return true;
+}
+
+static bool hasShape(ArrayRef<int64_t> sizes, ArrayRef<int64_t> expected) {
+  return sizes.size() == expected.size() && llvm::equal(sizes, expected);
+}
+
+LogicalResult Aten_ScaledMmOp::verify() {
+  auto selfType = cast<BaseTensorType>(getSelf().getType());
+  auto mat2Type = cast<BaseTensorType>(getMat2().getType());
+  auto scaleAType = cast<BaseTensorType>(getScaleA().getType());
+  auto scaleBType = cast<BaseTensorType>(getScaleB().getType());
+
+  if (selfType.hasDtype() && !isScaledMmDataDtype(selfType.getDtype()))
+    return emitOpError("expected self to have an FP8 dtype, but got ")
+           << selfType.getDtype();
+  if (mat2Type.hasDtype() && !isScaledMmDataDtype(mat2Type.getDtype()))
+    return emitOpError("expected mat2 to have an FP8 dtype, but got ")
+           << mat2Type.getDtype();
+
+  if (!selfType.hasSizes() || !mat2Type.hasSizes())
+    return success();
+
+  ArrayRef<int64_t> selfShape = selfType.getSizes();
+  ArrayRef<int64_t> mat2Shape = mat2Type.getSizes();
+  if (selfShape.size() != 2 || mat2Shape.size() != 2)
+    return emitOpError("expected self and mat2 to be rank 2, but got ranks ")
+           << selfShape.size() << " and " << mat2Shape.size();
+
+  int64_t m = selfShape[0];
+  int64_t k = selfShape[1];
+  int64_t mat2K = mat2Shape[0];
+  int64_t n = mat2Shape[1];
+
+  if (k != kUnknownSize && mat2K != kUnknownSize && k != mat2K)
+    return emitOpError("expected self and mat2 contracting dimensions to "
+                       "match, but got ")
+           << k << " and " << mat2K;
+  if (k != kUnknownSize && k % 16 != 0)
+    return emitOpError("expected self contracting dimension to be divisible "
+                       "by 16, but got ")
+           << k;
+  if (mat2K != kUnknownSize && mat2K % 16 != 0)
+    return emitOpError("expected mat2 contracting dimension to be divisible "
+                       "by 16, but got ")
+           << mat2K;
+  if (n != kUnknownSize && n % 16 != 0)
+    return emitOpError("expected mat2 output dimension to be divisible by 16, "
+                       "but got ")
+           << n;
+
+  if (!scaleAType.hasDtype() || !scaleBType.hasDtype() ||
+      !scaleAType.hasSizes() || !scaleBType.hasSizes() ||
+      !selfType.areAllSizesKnown() || !mat2Type.areAllSizesKnown())
+    return success();
+
+  Type scaleADtype = scaleAType.getDtype();
+  Type scaleBDtype = scaleBType.getDtype();
+  ArrayRef<int64_t> scaleAShape = scaleAType.getSizes();
+  ArrayRef<int64_t> scaleBShape = scaleBType.getSizes();
+
+  int64_t scaleANumel;
+  int64_t scaleBNumel;
+  if (!getNumel(scaleAShape, scaleANumel) ||
+      !getNumel(scaleBShape, scaleBNumel))
+    return success();
+
+  if (scaleANumel == 1 || scaleBNumel == 1) {
+    if (scaleANumel != 1 || scaleBNumel != 1)
+      return emitOpError("expected scale_a and scale_b to both be scalar for "
+                         "tensorwise scaling");
+    if (!isScaledMmTensorwiseOrRowwiseScaleDtype(scaleADtype) ||
+        !isScaledMmTensorwiseOrRowwiseScaleDtype(scaleBDtype))
+      return emitOpError(
+          "expected tensorwise scale_a and scale_b to have f32 dtype");
+    return success();
+  }
+
+  if (scaleADtype == scaleBDtype &&
+      isScaledMmBlockwiseScaleDtype(scaleADtype)) {
+    int64_t blockSizeK = isa<Float8E4M3FNType>(scaleADtype) ? 16 : 32;
+    int64_t numKBlocks = ceilDivPositive(k, blockSizeK);
+    int64_t paddedNumKBlocks = ceilDivPositive(numKBlocks, 4) * 4;
+    int64_t expectedScaleANumel =
+        128 * ceilDivPositive(m, 128) * paddedNumKBlocks;
+    int64_t expectedScaleBNumel =
+        128 * ceilDivPositive(n, 128) * paddedNumKBlocks;
+    if (scaleANumel != expectedScaleANumel ||
+        scaleBNumel != expectedScaleBNumel)
+      return emitOpError("invalid blockwise scaling configuration: expected "
+                         "scale_a to have ")
+             << expectedScaleANumel << " elements and scale_b to have "
+             << expectedScaleBNumel << " elements, but got " << scaleANumel
+             << " and " << scaleBNumel;
+    return success();
+  }
+
+  if (!isScaledMmTensorwiseOrRowwiseScaleDtype(scaleADtype) ||
+      !isScaledMmTensorwiseOrRowwiseScaleDtype(scaleBDtype))
+    return emitOpError("expected non-tensorwise, non-blockwise scale_a and "
+                       "scale_b to have f32 dtype");
+
+  if (scaleAShape.size() != 2 || scaleBShape.size() != 2)
+    return emitOpError("expected non-tensorwise scale_a and scale_b to be "
+                       "rank 2, but got ranks ")
+           << scaleAShape.size() << " and " << scaleBShape.size();
+
+  int64_t kBlocks = ceilDivPositive(k, 128);
+  int64_t mBlocks = ceilDivPositive(m, 128);
+  int64_t nBlocks = ceilDivPositive(n, 128);
+  if (hasShape(scaleAShape, {m, 1}) && hasShape(scaleBShape, {1, n}))
+    return success();
+  if (hasShape(scaleAShape, {m, kBlocks}) &&
+      hasShape(scaleBShape, {kBlocks, nBlocks}))
+    return success();
+  if (hasShape(scaleAShape, {m, kBlocks}) &&
+      hasShape(scaleBShape, {kBlocks, n}))
+    return success();
+  if (hasShape(scaleAShape, {mBlocks, kBlocks}) &&
+      hasShape(scaleBShape, {kBlocks, n}))
+    return success();
+
+  return emitOpError("invalid scaling configuration for scale_a and scale_b");
+}
+
 //===----------------------------------------------------------------------===//
 // DtypeCalculateYieldDtypesOp
 //===----------------------------------------------------------------------===//
diff --git
a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 155e6ca8c106..e2be43c991e6 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8407,6 +8407,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten._scaled_mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.optional<list<int>>, %arg5: !torch.optional<list<int>>, %arg6: !torch.optional<int>, %arg7: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._int_mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -15544,6 +15548,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %6 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten._scaled_mm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.optional<tuple<int, int>>, %arg6: !torch.optional<int>, %arg7: !torch.bool) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.__isnot__ %arg6, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      %3 = torch.prim.unchecked_cast %arg6 : !torch.optional<int> -> !torch.int\n"
+"      torch.prim.If.yield %3 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %0#1 : !torch.int\n"
+"    }\n"
+"    return %2 :
!torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._int_mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int3 = torch.constant.int 3\n" " %none = torch.constant.none\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 724ad18044d5..53d15c127188 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -908,6 +908,46 @@ def aten〇mv〡shape(self: List[int], vec: List[int]) -> List[int]: def aten〇mm〡shape(self: List[int], mat2: List[int]) -> List[int]: return upstream_shape_functions.mm(self, mat2) +@check_shape_function([ + Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=torch.bfloat16, + ), + Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e5m2, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=None, + ), + Invocation( + TensorOfShape(256, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 64, dtype=torch.float8_e4m3fn, stride=(1, 128)), + TensorOfShape(1024, dtype=torch.float8_e8m0fnu), + TensorOfShape(512, dtype=torch.float8_e8m0fnu), + out_dtype=torch.bfloat16, + ), + Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 64, dtype=torch.float8_e4m3fn, stride=(1, 128)), + TensorOfShape(128, 1), + TensorOfShape(1, 64), + out_dtype=torch.bfloat16, + ), + Invocation( + TensorOfShape(128, 128, dtype=torch.float4_e2m1fn_x2), + TensorOfShape(128, 128, dtype=torch.float4_e2m1fn_x2, stride=(1, 128)), + TensorOfShape(1024, 
dtype=torch.float8_e8m0fnu), + TensorOfShape(1024, dtype=torch.float8_e8m0fnu), + out_dtype=torch.bfloat16, + ), +]) +def aten〇_scaled_mm〡shape(self: List[int], mat2: List[int], scale_a: List[int], scale_b: List[int], bias: Optional[List[int]] = None, scale_result: Optional[List[int]] = None, out_dtype: Optional[int] = None, use_fast_accum: bool = False) -> List[int]: + return upstream_shape_functions.mm(self, mat2) + def aten〇_int_mm〡shape(self: List[int], mat2: List[int]) -> List[int]: return upstream_shape_functions.mm(self, mat2) @@ -4682,6 +4722,56 @@ def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[i dtypes = [self_dtype, mat2_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function([ + Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=torch.bfloat16, + ), + Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 128, dtype=torch.float8_e5m2, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=None, + ), + Invocation( + TensorOfShape(256, 128, dtype=torch.float8_e4m3fn), + TensorOfShape(128, 64, dtype=torch.float8_e4m3fn, stride=(1, 128)), + TensorOfShape(1024, dtype=torch.float8_e8m0fnu), + TensorOfShape(512, dtype=torch.float8_e8m0fnu), + out_dtype=torch.bfloat16, + ), + Invocation( + TensorOfShape(128, 128, dtype=torch.float8_e4m3fnuz), + TensorOfShape(128, 128, dtype=torch.float8_e5m2fnuz, stride=(1, 128)), + ZeroDTensorWithDtype(torch.float32), + ZeroDTensorWithDtype(torch.float32), + out_dtype=torch.float16, + ), + Invocation( + TensorOfShape(128, 64, dtype=torch.float4_e2m1fn_x2), + TensorOfShape(64, 128, dtype=torch.float4_e2m1fn_x2, stride=(1, 64)), + TensorOfShape(512, dtype=torch.float8_e8m0fnu), + TensorOfShape(512, dtype=torch.float8_e8m0fnu), + 
out_dtype=torch.bfloat16, + ), + Invocation( + TensorOfShape(128, 128, dtype=torch.float4_e2m1fn_x2), + TensorOfShape(128, 128, dtype=torch.float4_e2m1fn_x2, stride=(1, 128)), + TensorOfShape(1024, dtype=torch.float8_e8m0fnu), + TensorOfShape(1024, dtype=torch.float8_e8m0fnu), + out_dtype=torch.bfloat16, + ), +]) +def aten〇_scaled_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], scale_a_rank_dtype: Tuple[int, int], scale_b_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, scale_result_rank_dtype: Optional[Tuple[int, int]] = None, out_dtype: Optional[int] = None, use_fast_accum: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + if out_dtype is not None: + return out_dtype + return self_dtype + def aten〇_int_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype mat2_rank, mat2_dtype = mat2_rank_dtype @@ -6369,4 +6459,3 @@ def _create_argparse() -> argparse.ArgumentParser: if __name__ == "__main__": main(_create_argparse().parse_args()) - diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py index cbeb38d66365..6080052df291 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/testing_framework.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from typing import Any, List, Iterable, Optional, Callable +from typing import Any, List, Iterable, Optional, Callable, Tuple import torch from torch import Tensor @@ -59,7 +59,9 @@ class TensorOfShape: this special treatment. This class also tracks a dtype of the tensor, since some ops require a - specific dtype. + specific dtype. 
When a stride is provided, it is only used to construct the + real tensor for testing the upstream op; torch-mlir does not import stride + metadata. """ def __init__( @@ -67,14 +69,19 @@ def __init__( *shape: int, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + stride: Optional[Tuple[int, ...]] = None, ): self.shape = list(shape) self.dtype = dtype self.device = "meta" if device is None else device + self.stride = stride def __repr__(self): args_str = ", ".join(repr(x) for x in self.shape) - return f"TensorOfShape({args_str}, dtype={self.dtype}, device={self.device})" + kwargs = f"dtype={self.dtype}, device={self.device}" + if self.stride is not None: + kwargs += f", stride={self.stride}" + return f"TensorOfShape({args_str}, {kwargs})" def LongTensorOfShape(*args, **kwargs): @@ -158,7 +165,22 @@ def to_dtype_function_args(self): def to_real_op_args(self): """Gets positional arguments appropriate for the real op.""" - tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype).to(o.device) + + def initialize_tensor(t): + try: + return t.fill_(1) + except RuntimeError as e: + if t.dtype == torch.float4_e2m1fn_x2 and "not implemented" in str(e): + return t + raise + + tensor_transformer = lambda o: ( + initialize_tensor( + torch.empty_strided(o.shape, o.stride, dtype=o.dtype, device=o.device) + ) + if o.stride is not None + else initialize_tensor(torch.empty(o.shape, dtype=o.dtype).to(o.device)) + ) return _recursively_transform_tensor_args(self.args, tensor_transformer) def __repr__(self) -> str: diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 0fc4fd3e0de9..63a0f50f21a4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -574,6 +574,10 @@ def emit_with_mutating_variants(key, **kwargs): # 
Non-elementwise tensor compute ops emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)") emit("aten::mm : (Tensor, Tensor) -> (Tensor)") + emit( + "aten::_scaled_mm : (Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, int?, bool) -> (Tensor)", + has_verifier=True, + ) emit("aten::_int_mm : (Tensor, Tensor) -> (Tensor)") emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::matmul : (Tensor, Tensor) -> (Tensor)") diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index c863e93fa5fa..d8f6dd53ee86 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -403,3 +403,33 @@ func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) - torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> return %arg0 : !torch.vtensor<[?],f32> } + +// ----- + +func.func @torch.aten._scaled_mm$invalid_blockwise_scale_numel( + %arg0: !torch.vtensor<[128,128],f8E4M3FN>, + %arg1: !torch.vtensor<[128,128],f8E4M3FN>, + %arg2: !torch.vtensor<[512],f8E4M3FN>, + %arg3: !torch.vtensor<[512],f8E4M3FN>) -> !torch.vtensor<[128,128],bf16> { + %false = torch.constant.bool false + %int15 = torch.constant.int 15 + %none = torch.constant.none + // expected-error @+1 {{'torch.aten._scaled_mm' op invalid blockwise scaling configuration: expected scale_a to have 1024 elements and scale_b to have 1024 elements, but got 512 and 512}} + %0 = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3, %none, %none, %int15, %false : !torch.vtensor<[128,128],f8E4M3FN>, !torch.vtensor<[128,128],f8E4M3FN>, !torch.vtensor<[512],f8E4M3FN>, !torch.vtensor<[512],f8E4M3FN>, !torch.none, !torch.none, !torch.int, !torch.bool -> !torch.vtensor<[128,128],bf16> + return %0 : !torch.vtensor<[128,128],bf16> +} + +// ----- + +func.func @torch.aten._scaled_mm$invalid_mixed_scale_dtype( + %arg0: !torch.vtensor<[128,128],f8E4M3FN>, + %arg1: !torch.vtensor<[128,128],f8E4M3FN>, + %arg2: 
!torch.vtensor<[512],f8E8M0FNU>, + %arg3: !torch.vtensor<[512],f32>) -> !torch.vtensor<[128,128],bf16> { + %false = torch.constant.bool false + %int15 = torch.constant.int 15 + %none = torch.constant.none + // expected-error @+1 {{'torch.aten._scaled_mm' op expected non-tensorwise, non-blockwise scale_a and scale_b to have f32 dtype}} + %0 = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3, %none, %none, %int15, %false : !torch.vtensor<[128,128],f8E4M3FN>, !torch.vtensor<[128,128],f8E4M3FN>, !torch.vtensor<[512],f8E8M0FNU>, !torch.vtensor<[512],f32>, !torch.none, !torch.none, !torch.int, !torch.bool -> !torch.vtensor<[128,128],bf16> + return %0 : !torch.vtensor<[128,128],bf16> +} diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index f00290593d1c..89d76d2b520b 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -238,6 +238,146 @@ def forward(self): print(m) +@run +# CHECK-LABEL: test_import_scaled_mm_per_tensor +# CHECK: func.func @test_import_scaled_mm_per_tensor(%arg0: !torch.vtensor<[128,128],f8E4M3FN>, %arg1: !torch.vtensor<[128,128],f8E4M3FN>, %arg2: !torch.vtensor<[],f32>, %arg3: !torch.vtensor<[],f32>) -> !torch.vtensor<[128,128],bf16> +# CHECK: %[[NONE:.+]] = torch.constant.none +# CHECK: %[[NONE_0:.+]] = torch.constant.none +# CHECK: %[[INT15:.+]] = torch.constant.int 15 +# CHECK: %[[FALSE:.+]] = torch.constant.bool false +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3, %[[NONE]], %[[NONE_0]], %[[INT15]], %[[FALSE]] +# CHECK: return %[[MM]] +def test_import_scaled_mm_per_tensor(): + class Basic(nn.Module): + def forward(self, a, b, a_scale, b_scale): + return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + a_scale = torch.tensor(1.0, dtype=torch.float32) + b_scale = 
torch.tensor(1.0, dtype=torch.float32) + + m = fx.export_and_import( + Basic(), + a, + b, + a_scale, + b_scale, + func_name="test_import_scaled_mm_per_tensor", + ) + print(m) + + +@run +# CHECK-LABEL: test_import_scaled_mm_per_tensor_e5m2 +# CHECK: func.func @test_import_scaled_mm_per_tensor_e5m2(%arg0: !torch.vtensor<[128,128],f8E5M2>, %arg1: !torch.vtensor<[128,128],f8E5M2>, %arg2: !torch.vtensor<[],f32>, %arg3: !torch.vtensor<[],f32>) -> !torch.vtensor<[128,128],bf16> +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3 +def test_import_scaled_mm_per_tensor_e5m2(): + class Basic(nn.Module): + def forward(self, a, b, a_scale, b_scale): + return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + a_scale = torch.tensor(1.0, dtype=torch.float32) + b_scale = torch.tensor(1.0, dtype=torch.float32) + + m = fx.export_and_import( + Basic(), + a, + b, + a_scale, + b_scale, + func_name="test_import_scaled_mm_per_tensor_e5m2", + ) + print(m) + + +@run +# CHECK-LABEL: test_import_scaled_mm_out_dtype_none +# CHECK: func.func @test_import_scaled_mm_out_dtype_none(%arg0: !torch.vtensor<[128,128],f8E4M3FN>, %arg1: !torch.vtensor<[128,128],f8E5M2>, %arg2: !torch.vtensor<[],f32>, %arg3: !torch.vtensor<[],f32>) -> !torch.vtensor<[128,128],f8E4M3FN> +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3 +# CHECK: return %[[MM]] +def test_import_scaled_mm_out_dtype_none(): + class Basic(nn.Module): + def forward(self, a, b, a_scale, b_scale): + return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=None) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + a_scale = torch.tensor(1.0, dtype=torch.float32) + b_scale = torch.tensor(1.0, dtype=torch.float32) + + m = fx.export_and_import( + 
Basic(), + a, + b, + a_scale, + b_scale, + func_name="test_import_scaled_mm_out_dtype_none", + ) + print(m) + + +@run +# CHECK-LABEL: test_import_scaled_mm_block_scaled_fp8 +# CHECK: func.func @test_import_scaled_mm_block_scaled_fp8(%arg0: !torch.vtensor<[128,128],f8E4M3FN>, %arg1: !torch.vtensor<[128,128],f8E4M3FN>, %arg2: !torch.vtensor<[512],f8E8M0FNU>, %arg3: !torch.vtensor<[512],f8E8M0FNU>) -> !torch.vtensor<[128,128],bf16> +# CHECK: %[[NONE:.+]] = torch.constant.none +# CHECK: %[[NONE_0:.+]] = torch.constant.none +# CHECK: %[[INT15:.+]] = torch.constant.int 15 +# CHECK: %[[FALSE:.+]] = torch.constant.bool false +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3, %[[NONE]], %[[NONE_0]], %[[INT15]], %[[FALSE]] +# CHECK: return %[[MM]] +def test_import_scaled_mm_block_scaled_fp8(): + class Basic(nn.Module): + def forward(self, a, b, a_scale_block, b_scale_block): + return torch._scaled_mm( + a, b, a_scale_block, b_scale_block, out_dtype=torch.bfloat16 + ) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e4m3fn) + a_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + b_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + + m = fx.export_and_import( + Basic(), + a, + b, + a_scale_block, + b_scale_block, + func_name="test_import_scaled_mm_block_scaled_fp8", + ) + print(m) + + +@run +# CHECK-LABEL: test_import_scaled_mm_block_scaled_fp8_e5m2 +# CHECK: func.func @test_import_scaled_mm_block_scaled_fp8_e5m2(%arg0: !torch.vtensor<[128,128],f8E5M2>, %arg1: !torch.vtensor<[128,128],f8E5M2>, %arg2: !torch.vtensor<[512],f8E8M0FNU>, %arg3: !torch.vtensor<[512],f8E8M0FNU>) -> !torch.vtensor<[128,128],bf16> +# CHECK: %[[MM:.+]] = torch.aten._scaled_mm %arg0, %arg1, %arg2, %arg3 +def test_import_scaled_mm_block_scaled_fp8_e5m2(): + class Basic(nn.Module): + def forward(self, a, b, a_scale_block, b_scale_block): + return torch._scaled_mm( + a, 
b, a_scale_block, b_scale_block, out_dtype=torch.bfloat16 + ) + + a = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + b = torch.ones((128, 128), dtype=torch.float32).to(torch.float8_e5m2) + a_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + b_scale_block = torch.zeros((512,), dtype=torch.float8_e8m0fnu) + + m = fx.export_and_import( + Basic(), + a, + b, + a_scale_block, + b_scale_block, + func_name="test_import_scaled_mm_block_scaled_fp8_e5m2", + ) + print(m) + + @run # CHECK-LABEL: test_while_loop_two_returns # Check that helper functions are emitted first