lower torch.aten.isinf to linalg (#2638)
Co-authored-by: Rob Suderman <[email protected]>
renxida and rsuderman committed Dec 29, 2023
1 parent 9fc212e commit 6660a26
Showing 4 changed files with 41 additions and 3 deletions.
11 changes: 9 additions & 2 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -426,6 +426,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
  }
  if (isa<AtenAbsOp>(op))
    return b.create<math::AbsFOp>(loc, payloadArgs[0]);
  if (isa<AtenIsinfOp>(op)) {
    Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
    Value infinity = b.create<arith::ConstantOp>(
        loc, b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
    return createEqual(b, loc, abs.getType(), abs, infinity);
  }
  if (isa<AtenSigmoidOp>(op)) {
    auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
        b, converter, payloadArgs[0], op);
@@ -1343,7 +1349,7 @@ class ConvertElementwiseOp : public ConversionPattern {
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@@ -1992,7 +1998,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenTrilOp,
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenRealOp, AtenImagOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
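
Conceptually, the new payload rewrites `isinf(x)` as `abs(x) == +inf`, so both positive and negative infinity map to true while NaN (whose absolute value is still NaN) compares false. A minimal eager-PyTorch sketch of that identity, for illustration only, not the lowering itself:

```python
import torch

# Illustration of the same identity the linalg payload expresses with
# math.absf, an arith.constant +inf, and a float equality compare.
x = torch.tensor([1.0, float("inf"), float("-inf"), float("nan")])
reference = torch.isinf(x)
emulated = torch.abs(x) == float("inf")  # abs folds -inf onto +inf; NaN != inf
assert torch.equal(reference, emulated)  # tensor([False,  True,  True, False])
```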
4 changes: 4 additions & 0 deletions projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
@@ -39,6 +39,10 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
  return {Shape(at::kBool, self.sizes().vec())};
}
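
The shape function only has to state the metadata contract: `aten.isinf` preserves the input sizes and always yields a boolean tensor. A quick eager-mode sketch of that contract:

```python
import torch

# Sketch of what compute_shape_isinf encodes: same shape, bool dtype.
x = torch.randn(3, 4)
y = torch.isinf(x)
assert y.shape == x.shape
assert y.dtype == torch.bool
```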

std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
const at::Tensor& self, at::IntArrayRef kernel_size,
at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
4 changes: 3 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
@@ -1033,6 +1033,7 @@
"ElementwiseAddScalarIntModule_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseAtenDivIntScalarModule_basic",
"ElementwiseAtenIsinfOpModule_basic",
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseBinaryModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
@@ -1328,6 +1329,8 @@
"SliceWholeTensorModule_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
}) - {
### Test failing in make_fx_tosa but not in tosa

@@ -1489,5 +1492,4 @@
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseIsinfModule_basic",
}
25 changes: 25 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -3385,6 +3385,31 @@ def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(4, 5, high=2).bool())


# ==============================================================================

class ElementwiseAtenIsinfOpModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 5], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.isinf(x)

@register_test_case(module_factory=lambda: ElementwiseAtenIsinfOpModule())
def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils):
    test_input = torch.tensor(
        [
            [1, float('inf'), 2, float('-inf'), float('nan')],
            [1, float('inf'), float('-inf'), float('nan'), 3],
        ]
    )
    module.forward(test_input)
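
For reference, eager PyTorch gives the following result for the test's 2x5 input; the e2e harness compares the compiled module against this behavior. A sketch of the expected values, not harness output:

```python
import torch

# Same 2x5 input the test module receives above.
test_input = torch.tensor([
    [1, float("inf"), 2, float("-inf"), float("nan")],
    [1, float("inf"), float("-inf"), float("nan"), 3],
])
print(torch.ops.aten.isinf(test_input))
# tensor([[False,  True, False,  True, False],
#         [False,  True,  True,  True, False]])
```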


# ==============================================================================


