Skip to content

Commit

Permalink
Implement lowering of torch.aten.exponential (#2680)
Browse files Browse the repository at this point in the history
#2646

Decompose aten.exponential() into: -exp(1-x)/lambda
  • Loading branch information
godot73 committed Dec 28, 2023
1 parent d560698 commit 8e389ff
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 0 deletions.
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4739,6 +4739,31 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
}];
}

def Torch_AtenExponentialOp : Torch_Op<"aten.exponential", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::exponential : (Tensor, float, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$lambd,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenExponentialOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenExponentialOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
7 changes: 7 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7580,6 +7580,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.exponential\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
Expand Down Expand Up @@ -9382,6 +9385,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.exponential\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
Expand Down
46 changes: 46 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3562,6 +3562,51 @@ class DecomposeAtenBernoulliTensorOp
};
} // namespace

namespace {
// Decompose exponential() to do inverse transform sampling.
// - https://en.wikipedia.org/wiki/Inverse_transform_sampling
// With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus,
// exponential() = - ln(1 - uniform(0, 1)) / lambda.
class DecomposeAtenExponentialOp : public OpRewritePattern<AtenExponentialOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenExponentialOp op,
PatternRewriter &rewriter) const override {
if (!op.getGenerator().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");

Location loc = op.getLoc();
Type resultType = op.getType();

// Create a uniform random op with low and high set to 0.0 and 1.0,
// respectively.
Value none = rewriter.create<ConstantNoneOp>(loc);
Value zero =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
Value x = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensor,
/*from=*/zero, /*to=*/one,
/*generator=*/none);

Value negX = rewriter.create<AtenNegOp>(loc, resultType, x);
Value oneMinusX =
rewriter.create<AtenAddScalarOp>(loc, resultType, negX, one,
/*alpha=*/one);
Value lnOneMinusX = rewriter.create<AtenLogOp>(loc, resultType, oneMinusX);
Value negLambda = rewriter.create<AtenNegFloatOp>(loc, op.getLambd());
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, resultType, lnOneMinusX,
negLambda);
return success();
}
};
} // namespace

namespace {
template <typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
Expand Down Expand Up @@ -6410,6 +6455,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExponentialOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
target.addIllegalOp<AtenBernoulliPOp>();
target.addIllegalOp<AtenBernoulliTensorOp>();
target.addIllegalOp<AtenExponentialOp>();
target.addIllegalOp<AtenZeroOp>();
target.addIllegalOp<AtenEyeOp>();
target.addIllegalOp<AtenEyeMOp>();
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,7 @@
"CeilFloatModule_basic",
"DivFloatModule_basic",
"EqIntModule_basic",
"ExponentialModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,9 @@ def aten〇copy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa
def aten〇uniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]:
return self

def aten〇exponential〡shape(self: List[int], lambd: float = 1., generator: Any = None) -> List[int]:
return self

def aten〇rand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return size

Expand Down Expand Up @@ -2267,6 +2270,10 @@ def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0.,
self_rank, self_dtype = self_rank_dtype
return self_dtype

def aten〇exponential〡dtype(self_rank_dtype: Tuple[int, int], lambd: float = 1., generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function([Invocation([1]),
Invocation([1], dtype=torch.float16),
Invocation([1], dtype=torch.complex64)])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")
Expand Down
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,29 @@ def UniformNoCorrelationModule_basic(module, tu: TestUtils):

# ==============================================================================

class ExponentialModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, x):
a = torch.ops.aten.exponential(x, 3.0)
mean = torch.mean(a)
std = torch.std(a)
return mean, std


@register_test_case(module_factory=lambda: ExponentialModule())
def ExponentialModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(512, 512, 16).double())

# ==============================================================================

class BernoulliModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 8e389ff

Please sign in to comment.