[Torch Dialect] support AtenArangeStartOutOp in ReduceOpVariants like AtenBernoulli_FloatOp (#2563)

This fixes cases like: `%2110 = torch.aten.arange.start_out %int1, %int1517, %int1, %2109 : !torch.int, !torch.int, !torch.int, !torch.tensor -> !torch.tensor`.
`aten.arange.start_out` does not have value semantics either, which means `%2110` is an alias for `%2109`.
So I decompose it into `aten.arange.start_step` + `torch.overwrite.tensor.contents`.
The extra decomposition logic is there to handle cases like view and dtype cast, which I added as e2e tests.
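
For context, a minimal eager-mode sketch (not part of this change; tensor names are illustrative) of the aliasing behavior the decomposition has to preserve:

```python
import torch

x = torch.zeros(12, dtype=torch.int64)
y = torch.arange(start=0, end=12, out=x)
# `y` is not a fresh tensor: it aliases `x`, so both now hold 0..11.
assert y.data_ptr() == x.data_ptr()
```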
qingyunqu committed Nov 16, 2023
1 parent dad1f01 commit facbe5d
Showing 3 changed files with 110 additions and 8 deletions.
62 changes: 54 additions & 8 deletions lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -191,26 +191,71 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
// Reduce ops that do not have value semantics but for which no corresponding
// variant without a trailing underscore exists.
namespace {

// int(ceil((end - start) / step))
Value calculateArangeResultNumElements(PatternRewriter &rewriter, Location loc,
                                       Value start, Value end, Value step) {
  Value sub = rewriter.create<AtenSubOp>(
      loc, Torch::NumberType::get(rewriter.getContext()), end, start);
  Value div = rewriter.create<AtenDivOp>(loc, sub, step);
  return rewriter.create<AtenCeilFloatOp>(loc, div);
}

class ReduceNonValueSemanticOps : public RewritePattern {
public:
  ReduceNonValueSemanticOps(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    MLIRContext *ctx = op->getContext();
    if (isa<AtenBernoulli_FloatOp>(op)) {
      Operation *newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
          loc, op->getResultTypes(), op->getOperands());
      auto tensor =
          rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
      createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    } else if (auto arangeOutOp = dyn_cast<AtenArangeStartOutOp>(op)) {
      Value start = arangeOutOp.getStart();
      Value end = arangeOutOp.getEnd();
      Value step = arangeOutOp.getStep();
      Value out = arangeOutOp.getOut();

      // `overwrite.tensor.contents` cannot change the tensor shape, so the
      // `out` tensor must have the same number of elements as the result
      // tensor. That means we do not support code like:
      //   `x = torch.randn(12)`
      //   `y = torch.arange(13, out=x)`
      Value resultNumElements =
          calculateArangeResultNumElements(rewriter, loc, start, end, step);
      Value outNumElements = rewriter.create<AtenNumelOp>(loc, out);
      Value eqOrNot =
          rewriter.create<AtenEqIntOp>(loc, resultNumElements, outNumElements);
      rewriter.create<RuntimeAssertOp>(
          loc, eqOrNot,
          rewriter.getStringAttr("`out` tensor should have the same "
                                 "num_elements as the result tensor"));

      auto dtype = rewriter.create<PrimDtypeOp>(loc, out);
      auto device = rewriter.create<PrimDeviceOp>(loc, out);
      auto shape = rewriter.create<AtenSizeOp>(
          loc, Torch::ListType::get(Torch::IntType::get(ctx)), out);
      auto none = rewriter.create<ConstantNoneOp>(loc);
      Value newArange = rewriter.create<AtenArangeStartStepOp>(
          loc, arangeOutOp.getResult().getType(), start, end, step, dtype,
          /*layout=*/none, device, /*pin_memory=*/none);
      Value reshape = rewriter.create<AtenReshapeOp>(
          loc, arangeOutOp.getResult().getType(), newArange, shape);

      auto vtensor = rewriter.create<CopyToValueTensorOp>(loc, reshape);
      createOverwriteTensorContents(rewriter, loc, vtensor, out);
      rewriter.replaceOp(arangeOutOp, out);
      return success();
    } else {
      return failure();
    }
  }
};
} // namespace
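
For readers less familiar with the Torch dialect, here is a hedged Python sketch of what the rewrite above amounts to at the PyTorch level; the function name is made up, and `Tensor.copy_` stands in for `overwrite.tensor.contents`:

```python
import math

import torch

def arange_start_out_decomposed(start, end, step, out):
    # Mirrors calculateArangeResultNumElements: int(ceil((end - start) / step)).
    num_elements = int(math.ceil((end - start) / step))
    # Runtime assert: overwriting contents cannot change the shape of `out`.
    assert num_elements == out.numel(), (
        "`out` tensor should have the same num_elements as the result tensor")
    # Value-semantic arange, with dtype and device taken from `out`...
    result = torch.arange(start, end, step, dtype=out.dtype, device=out.device)
    # ...reshaped to `out`'s shape and written back into `out`'s storage.
    out.copy_(result.reshape(out.shape))
    return out
```

The reshape is what makes the view test case below work: the freshly created rank-1 arange is reshaped to `out`'s shape before overwriting, so `out` keeps its original shape.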
@@ -309,6 +354,7 @@ struct ReduceOpVariantsPass
    ConversionTarget target(*context);
    target.addIllegalOp<NonValueTensorLiteralOp>();
    target.addIllegalOp<AtenBernoulli_FloatOp>();
    target.addIllegalOp<AtenArangeStartOutOp>();
    target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
                                             Operation *op) {
      if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
6 changes: 6 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -302,6 +302,9 @@

    # ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
    "ThresholdBackward2dMixedModule_basic",

    # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
    "ArangeStartOutViewModule_basic",
}

if torch_version_for_comparison() < version.parse("2.1.0.dev"):
@@ -1303,6 +1306,8 @@
"AtenEyeModuleFalsePinMemory_basic",
"AtenEyeModuleFloat2D_basic",
"MeanModule_basic",
"ArangeStartOutModule_basic",
"ArangeStartOutViewModule_basic",
}

MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@@ -1372,6 +1377,7 @@
"_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic",
"AddIntModule_basic",
"ArangeStartOutViewModule_basic",
"AtenIntBoolOpModule_basic",
"BernoulliTensorModule_basic",
"BincountMinlengthModule_basic",
50 changes: 50 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
@@ -248,3 +248,53 @@ def forward(self):
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
    module.forward()

# ==============================================================================

class ArangeStartOutModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([12], torch.int64, True),
    ])
    def forward(self, x):
        return torch.arange(start=0, end=12, out=x)

@register_test_case(module_factory=lambda: ArangeStartOutModule())
def ArangeStartOutModule_basic(module, tu: TestUtils):
    module.forward(torch.zeros(12).to(torch.int64))

class ArangeStartOutViewModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4], torch.int64, True),
    ])
    def forward(self, x):
        return torch.arange(start=1, end=13, out=x)

@register_test_case(module_factory=lambda: ArangeStartOutViewModule())
def ArangeStartOutViewModule_basic(module, tu: TestUtils):
    module.forward(torch.zeros(3, 4).to(torch.int64))
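
An eager-mode sanity check of the view case (illustrative only; the `[3, 4]` golden shape matches the xfail comment in xfail_sets.py above):

```python
import torch

x = torch.zeros(3, 4, dtype=torch.int64)
y = torch.arange(start=1, end=13, out=x)
# The 12 values land in `x`'s existing storage and the returned tensor keeps
# the [3, 4] shape, which is why the decomposition needs the reshape.
print(y.shape)  # torch.Size([3, 4])
```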

class ArangeStartOutDtypeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([12], torch.int64, True),
    ])
    def forward(self, x):
        return torch.arange(start=1.1, end=13.1, out=x)

@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
    module.forward(torch.zeros(12).to(torch.int64))
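
And a similarly hedged check of the dtype case: `out` is int64 while the range bounds are floats, so the decomposition has to take the dtype from `out`:

```python
import torch

x = torch.zeros(12, dtype=torch.int64)
y = torch.arange(start=1.1, end=13.1, out=x)
# The result dtype is dictated by `out`, not by the float-valued bounds.
print(y.dtype)  # torch.int64
```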
