From facbe5d96be48d323b29caee8374922617da767d Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Fri, 17 Nov 2023 00:51:55 +0800
Subject: [PATCH] =?UTF-8?q?[Torch=20Dialect]=20support=20AtenArangeStartOu?=
 =?UTF-8?q?tOp=20in=20ReduceOpVariants=20like=E2=80=A6=20(#2563)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

… AtenBernoulli_FloatOp

It fixes cases like:
`%2110 = torch.aten.arange.start_out %int1, %int1517, %int1, %2109 : !torch.int, !torch.int, !torch.int, !torch.tensor -> !torch.tensor`.
`aten.arange.start_out` also doesn't have value semantics, which means `%2110`
is an alias of `%2109`. So I decompose it into `aten.arange.start_step` plus
`torch.overwrite.tensor.contents`. The somewhat involved decomposition logic is
there to handle cases like view (a differently shaped `out`) and dtype cast,
which I cover in the added e2e tests.
---
 .../Torch/Transforms/ReduceOpVariants.cpp     | 62 ++++++++++++++++---
 projects/pt1/e2e_testing/xfail_sets.py        |  6 ++
 .../torch_mlir_e2e_test/test_suite/arange.py  | 50 +++++++++++++++
 3 files changed, 110 insertions(+), 8 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
index 53b1205107e3..a4b02cf9e17f 100644
--- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
+++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -191,6 +191,16 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
 // Reduce Ops without value semantics but the corresponding without trailing
 // underscore variant doesn't exist.
 namespace {
+
+// int(ceil((end - start) / step))
+Value calculateArangeResultNumElements(PatternRewriter &rewriter, Location loc,
+                                       Value start, Value end, Value step) {
+  Value sub = rewriter.create<AtenSubOp>(
+      loc, Torch::NumberType::get(rewriter.getContext()), end, start);
+  Value div = rewriter.create<AtenDivOp>(loc, sub, step);
+  return rewriter.create<AtenCeilFloatOp>(loc, div);
+}
+
 class ReduceNonValueSemanticOps : public RewritePattern {
 public:
   ReduceNonValueSemanticOps(MLIRContext *context)
@@ -198,19 +208,54 @@ class ReduceNonValueSemanticOps : public RewritePattern {
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    Operation *newOp;
+    MLIRContext *ctx = op->getContext();
     if (isa<AtenBernoulli_FloatOp>(op)) {
-      newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
+      Operation *newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
           loc, op->getResultTypes(), op->getOperands());
+      auto tensor =
+          rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
+      createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
+      rewriter.replaceOp(op, op->getOperand(0));
+      return success();
+    } else if (auto arangeOutOp = dyn_cast<AtenArangeStartOutOp>(op)) {
+      Value start = arangeOutOp.getStart();
+      Value end = arangeOutOp.getEnd();
+      Value step = arangeOutOp.getStep();
+      Value out = arangeOutOp.getOut();
+
+      // `overwrite.tensor.contents` cannot change the tensor shape,
+      // so the `out` tensor must have the same num_elements as the result tensor.
+      // It means that we don't support code like:
+      // `x = torch.randn(12)`
+      // `y = torch.arange(13, out=x)`
+      Value resultNumElements =
+          calculateArangeResultNumElements(rewriter, loc, start, end, step);
+      Value outNumElements = rewriter.create<AtenNumelOp>(loc, out);
+      Value eqOrNot =
+          rewriter.create<AtenEqIntOp>(loc, resultNumElements, outNumElements);
+      rewriter.create<RuntimeAssertOp>(
+          loc, eqOrNot,
+          rewriter.getStringAttr("`out` tensor should have the same "
+                                 "num_elements as the result tensor"));
+
+      auto dtype = rewriter.create<PrimDtypeOp>(loc, out);
+      auto device = rewriter.create<PrimDeviceOp>(loc, out);
+      auto shape = rewriter.create<AtenSizeOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(ctx)), out);
+      auto none = rewriter.create<ConstantNoneOp>(loc);
+      Value newArange = rewriter.create<AtenArangeStartStepOp>(
+          loc, arangeOutOp.getResult().getType(), start, end, step, dtype,
+          /*layout=*/none, device, /*pin_memory=*/none);
+      Value reshape = rewriter.create<AtenReshapeOp>(
+          loc, arangeOutOp.getResult().getType(), newArange, shape);
+
+      auto vtensor = rewriter.create<CopyToValueTensorOp>(loc, reshape);
+      createOverwriteTensorContents(rewriter, loc, vtensor, out);
+      rewriter.replaceOp(arangeOutOp, out);
+      return success();
     } else {
       return failure();
     }
-
-    auto tensor =
-        rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
-    createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
-    rewriter.replaceOp(op, op->getOperand(0));
-    return success();
   }
 };
 } // namespace
@@ -309,6 +354,7 @@ struct ReduceOpVariantsPass
     ConversionTarget target(*context);
     target.addIllegalOp<NonValueTensorLiteralOp>();
     target.addIllegalOp<AtenBernoulli_FloatOp>();
+    target.addIllegalOp<AtenArangeStartOutOp>();
     target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
                                              Operation *op) {
       if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 2d26a3687a85..7d2a891d85c4 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -302,6 +302,9 @@
 
     # ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
     "ThresholdBackward2dMixedModule_basic",
+
+    # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
+    "ArangeStartOutViewModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.1.0.dev"):
@@ -1303,6 +1306,8 @@
     "AtenEyeModuleFalsePinMemory_basic",
     "AtenEyeModuleFloat2D_basic",
     "MeanModule_basic",
+    "ArangeStartOutModule_basic",
+    "ArangeStartOutViewModule_basic",
 }
 
 MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@@ -1372,6 +1377,7 @@
     "_ConvolutionDeprecated2DCudnnModule_basic",
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "AddIntModule_basic",
+    "ArangeStartOutViewModule_basic",
     "AtenIntBoolOpModule_basic",
     "BernoulliTensorModule_basic",
     "BincountMinlengthModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
index d7ca3b6e2bac..8237d2601711 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
@@ -248,3 +248,53 @@ def forward(self):
 @register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
 def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
     module.forward()
+
+# ==============================================================================
+
+class ArangeStartOutModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([12], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.arange(start=0, end=12, out=x)
+
+@register_test_case(module_factory=lambda: ArangeStartOutModule())
+def ArangeStartOutModule_basic(module, tu: TestUtils):
+    module.forward(torch.zeros(12).to(torch.int64))
+
+class ArangeStartOutViewModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.arange(start=1, end=13, out=x)
+
+@register_test_case(module_factory=lambda: ArangeStartOutViewModule())
+def ArangeStartOutViewModule_basic(module, tu: TestUtils):
+    module.forward(torch.zeros(3, 4).to(torch.int64))
+
+class ArangeStartOutDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([12], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.arange(start=1.1, end=13.1, out=x)
+
+@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
+def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
+    module.forward(torch.zeros(12).to(torch.int64))
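
Note (illustrative, not part of the diff): the decomposition above mirrors PyTorch's eager behaviour of `torch.arange(..., out=...)`, where `out` is filled in place (it keeps its own shape when the element count matches and its own dtype) and the returned value aliases `out`. A minimal eager-mode sketch of that behaviour, covering the same view and dtype-cast cases as the new e2e tests:

```python
import torch

# "View" case: same number of elements but a different shape. `out` keeps its
# (3, 4) shape and the result aliases `out` instead of being a fresh tensor.
x = torch.zeros(3, 4, dtype=torch.int64)
y = torch.arange(start=1, end=13, out=x)
assert y.data_ptr() == x.data_ptr()
assert x.flatten().tolist() == list(range(1, 13))

# Dtype-cast case: a float range is written into an int64 `out` tensor.
z = torch.zeros(12, dtype=torch.int64)
w = torch.arange(start=1.1, end=13.1, out=z)
assert w.dtype == torch.int64
assert w.tolist() == list(range(1, 13))
```

Because the result is only an alias of `out`, the lowering cannot simply create a new value tensor; it must overwrite the contents of `out`, which is why the pattern asserts that `out` has the same number of elements as the computed arange result.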