From 8e389ff2ffac781648721696c716f141048c38c9 Mon Sep 17 00:00:00 2001
From: Sungsoon Cho <1025787+godot73@users.noreply.github.com>
Date: Wed, 27 Dec 2023 20:33:18 -0800
Subject: [PATCH] Implement lowering of torch.aten.exponential (#2680)

https://github.com/llvm/torch-mlir/issues/2646

Decompose aten.exponential() into: -ln(1-x)/lambda, where x is drawn
from uniform(0, 1).
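
As a sanity check of that sampling identity (an illustrative eager-mode
sketch, not part of the patch; the sample count and tolerances are
arbitrary):

    import torch

    # Inverse transform sampling: if u ~ Uniform(0, 1), then
    # x = -ln(1 - u) / lambd follows Exponential(lambd).
    lambd = 3.0
    u = torch.rand(1_000_000, dtype=torch.float64)
    x = -torch.log(1.0 - u) / lambd

    # Exponential(lambd) has mean = std = 1/lambd.
    assert abs(x.mean().item() - 1.0 / lambd) < 1e-2
    assert abs(x.std().item() - 1.0 / lambd) < 1e-2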
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 25 ++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  7 +++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 46 +++++++++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |  1 +
 projects/pt1/e2e_testing/xfail_sets.py        |  1 +
 .../build_tools/abstract_interp_lib_gen.py    |  7 +++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../torch_mlir_e2e_test/test_suite/rng.py     | 23 ++++++++++
 8 files changed, 111 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 6013f6da3cfc..16eb5565bedd 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4739,6 +4739,31 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
   }];
 }
 
+def Torch_AtenExponentialOp : Torch_Op<"aten.exponential", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::exponential : (Tensor, float, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$lambd,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenExponentialOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenExponentialOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 1031f4aa7e53..25e83899bc1d 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7580,6 +7580,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.exponential\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -9382,6 +9385,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.exponential\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
 "    %int6 = torch.constant.int 6\n"
 "    %none = torch.constant.none\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index d8b8639e0a75..63fa66ccc31e 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3562,6 +3562,51 @@ class DecomposeAtenBernoulliTensorOp
 };
 } // namespace
 
+namespace {
+// Decompose exponential() to do inverse transform sampling.
+// - https://en.wikipedia.org/wiki/Inverse_transform_sampling
+// With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus,
+// exponential() = - ln(1 - uniform(0, 1)) / lambda.
+class DecomposeAtenExponentialOp : public OpRewritePattern<AtenExponentialOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenExponentialOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getGenerator().getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "The generator has to be None because only global default "
+              "generator is supported");
+
+    Location loc = op.getLoc();
+    Type resultType = op.getType();
+
+    // Create a uniform random op with low and high set to 0.0 and 1.0,
+    // respectively.
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value zero =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
+    Value one =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value emptyTensor = rewriter.create<AtenFullLikeOp>(
+        loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none,
+        /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+    Value x = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensor,
+                                             /*from=*/zero, /*to=*/one,
+                                             /*generator=*/none);
+
+    Value negX = rewriter.create<AtenNegOp>(loc, resultType, x);
+    Value oneMinusX =
+        rewriter.create<AtenAddScalarOp>(loc, resultType, negX, one,
+                                         /*alpha=*/one);
+    Value lnOneMinusX = rewriter.create<AtenLogOp>(loc, resultType, oneMinusX);
+    Value negLambda = rewriter.create<AtenNegFloatOp>(loc, op.getLambd());
+    rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, resultType, lnOneMinusX,
+                                                 negLambda);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename T>
 class DecomposeAtenAddCLikeOp : public OpRewritePattern<T> {
@@ -6410,6 +6455,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenExponentialOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
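
For readers tracing DecomposeAtenExponentialOp above, the emitted op
sequence corresponds to the following eager-mode sketch (an illustration
only, not part of the patch; torch.full_like and Tensor.uniform_ stand in
for AtenFullLikeOp and AtenUniformOp):

    import torch

    def decomposed_exponential(self: torch.Tensor, lambd: float) -> torch.Tensor:
        # AtenFullLikeOp + AtenUniformOp: u ~ Uniform(0, 1), same shape/dtype as self.
        u = torch.full_like(self, 0.0).uniform_(0.0, 1.0)
        neg_u = -u                               # AtenNegOp
        one_minus_u = neg_u + 1.0                # AtenAddScalarOp, alpha = 1
        ln_one_minus_u = torch.log(one_minus_u)  # AtenLogOp
        return ln_one_minus_u / -lambd           # AtenDivScalarOp by negated lambd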
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 79f64ef32fbf..933140d3013d 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
   target.addIllegalOp<AtenBernoulliPOp>();
   target.addIllegalOp<AtenBernoulliTensorOp>();
+  target.addIllegalOp<AtenExponentialOp>();
   target.addIllegalOp<AtenZeroOp>();
   target.addIllegalOp<AtenEyeOp>();
   target.addIllegalOp<AtenEyeMOp>();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index c70b01b47819..6f683a43c34f 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1397,6 +1397,7 @@
     "CeilFloatModule_basic",
     "DivFloatModule_basic",
     "EqIntModule_basic",
+    "ExponentialModule_basic",
     "GeFloatIntModule_basic",
     "GeFloatModule_basic",
     "GeIntModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 338f5e97e100..2e6094a6fa20 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -831,6 +831,9 @@ def aten〇copy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa
 def aten〇uniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]:
     return self
 
+def aten〇exponential〡shape(self: List[int], lambd: float = 1., generator: Any = None) -> List[int]:
+    return self
+
 def aten〇rand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return size
 
@@ -2267,6 +2270,10 @@ def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0.,
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇exponential〡dtype(self_rank_dtype: Tuple[int, int], lambd: float = 1., generator: Any = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function([Invocation([1]),
                        Invocation([1], dtype=torch.float16),
                        Invocation([1], dtype=torch.complex64)])
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index efee6c852eb4..fb458f6a5d91 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -378,6 +378,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
     emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
     emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
+    emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)")
     emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
     emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
index 1baa462462f1..dedd2b398bd4 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -157,6 +157,29 @@ def UniformNoCorrelationModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ExponentialModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        a = torch.ops.aten.exponential(x, 3.0)
+        mean = torch.mean(a)
+        std = torch.std(a)
+        return mean, std
+
+
+@register_test_case(module_factory=lambda: ExponentialModule())
+def ExponentialModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(512, 512, 16).double())
+
+# ==============================================================================
+
 class BernoulliModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
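
The new ExponentialModule e2e test checks distribution statistics rather
than exact values: Exponential(lambd) has mean = std = 1/lambd, so with
lambd = 3.0 both returned values should settle near 1/3 over the
512*512*16 samples. A quick eager-mode reference for those expected
statistics (illustrative only, using PyTorch's built-in sampler):

    import torch

    # In-place sampler for the same distribution, Exponential(3.0).
    x = torch.empty(512, 512, 16, dtype=torch.float64).exponential_(3.0)
    print(x.mean().item(), x.std().item())  # both come out near 1/3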