Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6623,6 +6623,36 @@ def Torch_AtenMmOp : Torch_Op<"aten.mm", [
}];
}

// Op definition for `aten::_scaled_mm`, PyTorch's scaled matrix multiply
// (used for FP8 matmuls, where per-tensor or blockwise scales accompany the
// low-precision operands).
// NOTE(review): this lives in GeneratedTorchOps.td — it is emitted by the
// torch-mlir op-generation script from the ATen schema shown in `summary`
// below. Regenerate via the update script rather than hand-editing, so the
// file stays in sync with the generator.
def Torch_Aten_ScaledMmOp : Torch_Op<"aten._scaled_mm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_scaled_mm : (Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, int?, bool) -> (Tensor)`";
let arguments = (ins
// Operand order mirrors the ATen schema exactly:
// self @ mat2, with scale_a / scale_b applied to the respective operands.
AnyTorchTensorType:$self,
AnyTorchTensorType:$mat2,
AnyTorchTensorType:$scale_a,
AnyTorchTensorType:$scale_b,
// Optional additive bias and optional scale for the result tensor.
AnyTorchOptionalTensorType:$bias,
AnyTorchOptionalTensorType:$scale_result,
// Optional dtype for the output; presumably falls back to self's dtype
// when absent (see the dtype fn in AbstractInterpLibrary) — confirm.
AnyTorchOptionalIntType:$out_dtype,
Torch_BoolType:$use_fast_accum
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_ScaledMmOp::parse(OpAsmParser &parser, OperationState &result) {
// 8 operands, 1 result — matches the schema's arity.
return parseDefaultTorchOp(parser, result, 8, 1);
}
void Aten_ScaledMmOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}

def Torch_Aten_IntMmOp : Torch_Op<"aten._int_mm", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
16 changes: 8 additions & 8 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,27 +297,27 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isBF16()) ||
(src.isF32() && dest.isF16()) ||
(src.isF32() && isa<Float8E4M3Type>(dest)) ||
(src.isF32() && isa<Float8E4M3FNType>(dest)) ||
(src.isF32() && isa<Float8E5M2Type>(dest)) ||
// f16 -> *
(src.isF16() && dest.isInteger(32)) ||
(src.isF16() && dest.isInteger(16)) ||
(src.isF16() && dest.isInteger(8)) ||
(src.isF16() && dest.isBF16()) ||
(src.isF16() && dest.isF32()) ||
(src.isF16() && isa<Float8E4M3Type>(dest)) ||
(src.isF16() && isa<Float8E4M3FNType>(dest)) ||
(src.isF16() && isa<Float8E5M2Type>(dest)) ||
// bf16 -> *
(src.isBF16() && dest.isInteger(32)) ||
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isF32()) ||
(src.isBF16() && isa<Float8E4M3Type>(dest)) ||
(src.isBF16() && isa<Float8E4M3FNType>(dest)) ||
(src.isBF16() && isa<Float8E5M2Type>(dest)) ||
// fp8e4m3 -> *
(isa<Float8E4M3Type>(src) && dest.isBF16()) ||
(isa<Float8E4M3Type>(src) && dest.isF32()) ||
(isa<Float8E4M3Type>(src) && dest.isF16()) ||
(isa<Float8E4M3FNType>(src) && dest.isBF16()) ||
(isa<Float8E4M3FNType>(src) && dest.isF32()) ||
(isa<Float8E4M3FNType>(src) && dest.isF16()) ||
// fp8e5m2 -> *
(isa<Float8E5M2Type>(src) && dest.isBF16()) ||
(isa<Float8E5M2Type>(src) && dest.isF32()) ||
Expand Down Expand Up @@ -514,8 +514,8 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
outputElemTy.isInteger(48)) {
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
} else if ((isa<Float8E4M3Type>(inputElemTy) &&
isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
} else if ((isa<Float8E4M3FNType>(inputElemTy) &&
isa<Float8E4M3FNType>(weightElemTy) && outputElemTy.isF16()) ||
(isa<Float8E5M2Type>(inputElemTy) &&
isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
accType = mlir::TypeAttr::get(rewriter.getF16Type());
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ static bool isValidTorchDtype(Type dtype) {
// Builtin floating point types.
if (isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(dtype))
return true;
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E8M0FNUType>(dtype))
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E8M0FNUType,
Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType>(dtype))
return true;

if (isa<Torch::StringType>(dtype))
Expand Down
206 changes: 206 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8407,6 +8407,138 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._scaled_mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.optional<list<int>>, %arg5: !torch.optional<list<int>>, %arg6: !torch.optional<int>, %arg7: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" %1 = call @__torch__._check_scaled_mm_scale_shapes(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.none\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @__torch__._check_scaled_mm_scale_shapes(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.none {\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int1 = torch.constant.int 1\n"
" %int16 = torch.constant.int 16\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %0 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.remainder.int %0, %int16 : !torch.int, !torch.int -> !torch.int\n"
" %2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.aten.remainder.int %3, %int16 : !torch.int, !torch.int -> !torch.int\n"
" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.remainder.int %6, %int16 : !torch.int, !torch.int -> !torch.int\n"
" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = call @__torch__._numel(%arg2) : (!torch.list<int>) -> !torch.int\n"
" %10 = call @__torch__._numel(%arg3) : (!torch.list<int>) -> !torch.int\n"
" %11 = torch.aten.eq.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.bool) {\n"
" %13 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %12 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %13 = torch.aten.ne.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %14 = torch.prim.If %13 -> (!torch.bool) {\n"
" %18 = torch.aten.ne.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %18 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %14 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %15 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
" %16 = torch.aten.eq.int %15, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
" %18 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
" %19 = torch.aten.eq.int %18, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %19 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %17 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %18 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %19 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = func.call @__torch__._scaled_mm_block_scale_numel(%18, %19) : (!torch.int, !torch.int) -> !torch.int\n"
" %21 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %22 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %23 = func.call @__torch__._scaled_mm_block_scale_numel(%21, %22) : (!torch.int, !torch.int) -> !torch.int\n"
" %24 = torch.aten.eq.int %9, %20 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %24 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %25 = torch.aten.eq.int %10, %23 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %25 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" }\n"
" return %none : !torch.none\n"
" }\n"
" func.func @__torch__._numel(%arg0: !torch.list<int>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.prim.Loop %0, %true, init(%int1) {\n"
" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n"
" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.mul.int %arg2, %2 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop.condition %true, iter(%3 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @__torch__._scaled_mm_block_scale_numel(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int128 = torch.constant.int 128\n"
" %int32 = torch.constant.int 32\n"
" %0 = call @__torch__._round_up_to_multiple(%arg0, %int128) : (!torch.int, !torch.int) -> !torch.int\n"
" %1 = call @__torch__._round_up_to_multiple(%arg1, %int128) : (!torch.int, !torch.int) -> !torch.int\n"
" %2 = torch.aten.mul.int %0, %1 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.aten.floordiv.int %2, %int32 : !torch.int, !torch.int -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @__torch__._round_up_to_multiple(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %1 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %2 = torch.aten.floordiv.int %1, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.aten.mul.int %2, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._int_mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -15544,6 +15676,80 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._scaled_mm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.optional<tuple<int, int>>, %arg6: !torch.optional<int>, %arg7: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3:2 = torch.prim.TupleUnpack %arg3 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %4 = call @__torch__._check_scaled_mm_dtypes(%0#1, %1#1, %2#1, %3#1) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.none\n"
" %5 = torch.aten.__isnot__ %arg6, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.int) {\n"
" %7 = torch.prim.unchecked_cast %arg6 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %7 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %0#1 : !torch.int\n"
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @__torch__._check_scaled_mm_dtypes(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.none {\n"
" %int44 = torch.constant.int 44\n"
" %true = torch.constant.bool true\n"
" %false = torch.constant.bool false\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0 = call @__torch__._is_scaled_mm_data_dtype(%arg0) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %0 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %1 = call @__torch__._is_scaled_mm_data_dtype(%arg1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.eq.int %arg2, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" %5 = torch.aten.eq.int %arg3, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %5 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %5 = torch.aten.eq.int %arg2, %int44 : !torch.int, !torch.int -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
" %7 = torch.aten.eq.int %arg3, %int44 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %6 : !torch.bool\n"
" }\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %none : !torch.none\n"
" }\n"
" func.func @__torch__._is_scaled_mm_data_dtype(%arg0: !torch.int) -> !torch.bool {\n"
" %int25 = torch.constant.int 25\n"
" %int23 = torch.constant.int 23\n"
" %int26 = torch.constant.int 26\n"
" %int24 = torch.constant.int 24\n"
" %int45 = torch.constant.int 45\n"
" %0 = torch.prim.ListConstruct %int45, %int24, %int26, %int23, %int25 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" return %1 : !torch.bool\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._int_mm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int3 = torch.constant.int 3\n"
" %none = torch.constant.none\n"
Expand Down
Loading
Loading