Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tcp] Handle int inputs in sqrt #2467

Merged
merged 1 commit into from
Sep 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class Tcp_UnaryElementwiseOp<string mnemonic, list<Trait> traits = []> :
Tcp_Op<mnemonic, !listconcat(traits, [
Pure,
Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultElementType])> {
SameOperandsAndResultShape])> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since all elementwise ops except tcp.sqrt adhere to the SameOperandsAndResultElementType trait, is it better to leave this as-is, and let Tcp_SqrtOp just be a Tcp_Op that uses a relaxed set of traits?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is still useful to be able to identify all the unary elementwise ops using the common base class type.

If there are more such ops in future, we could add another base type that can be common to all of them.

}

class Tcp_BinaryElementwiseOp<string mnemonic, list<Trait> traits = []> :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ include "torch-mlir-dialects/Dialect/Tcp/IR/TcpEnums.td"

include "mlir/IR/OpBase.td"

def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh"> {
def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh", [SameOperandsAndResultElementType]> {
let summary = "Computes tanh of input, elementwise";

let description = [{
Expand All @@ -33,7 +33,7 @@ def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp"> {
def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp", [SameOperandsAndResultElementType]> {
let summary = "Clamps input tensor to the given min and/or max";

let description = [{
Expand Down Expand Up @@ -65,7 +65,7 @@ def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp"> {
let hasVerifier = 1;
}

def Tcp_SigmoidOp : Tcp_UnaryElementwiseOp<"sigmoid"> {
def Tcp_SigmoidOp : Tcp_UnaryElementwiseOp<"sigmoid", [SameOperandsAndResultElementType]> {
let summary = "Computes sigmoid of input, elementwise";

let description = [{
Expand Down Expand Up @@ -312,19 +312,19 @@ def Tcp_IsolatedGroupOp : Tcp_Op<"isolated_group", [
let hasVerifier = 1;
}

def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt", [SameOperandsAndResultElementType]> {
def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt"> {
let summary = "Computes square root of input, elementwise";

let description = [{
Computes elementwise square root of the input tensor.
}];

let arguments = (ins
Tcp_FloatOrComplexTensor:$in
Tcp_FloatOrIntTensor:$in
);

let results = (outs
Tcp_FloatOrComplexTensor:$out
Tcp_FloatTensor:$out
);

let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
Expand All @@ -351,7 +351,7 @@ def Tcp_ConcatOp : Tcp_Op<"concat", [SameOperandsAndResultElementType]> {
let hasVerifier = 1;
}

def Tcp_CeilOp : Tcp_UnaryElementwiseOp<"ceil"> {
def Tcp_CeilOp : Tcp_UnaryElementwiseOp<"ceil", [SameOperandsAndResultElementType]> {
let summary = "Computes ceil of input, elementwise";

let description = [{
Expand All @@ -369,7 +369,7 @@ def Tcp_CeilOp : Tcp_UnaryElementwiseOp<"ceil"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_FloorOp : Tcp_UnaryElementwiseOp<"floor"> {
def Tcp_FloorOp : Tcp_UnaryElementwiseOp<"floor", [SameOperandsAndResultElementType]> {
let summary = "Computes floor of input, elementwise";

let description = [{
Expand All @@ -387,7 +387,7 @@ def Tcp_FloorOp : Tcp_UnaryElementwiseOp<"floor"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_CosOp : Tcp_UnaryElementwiseOp<"cos"> {
def Tcp_CosOp : Tcp_UnaryElementwiseOp<"cos", [SameOperandsAndResultElementType]> {
let summary = "Computes cosine of input, elementwise";

let description = [{
Expand All @@ -405,7 +405,7 @@ def Tcp_CosOp : Tcp_UnaryElementwiseOp<"cos"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_SinOp : Tcp_UnaryElementwiseOp<"sin"> {
def Tcp_SinOp : Tcp_UnaryElementwiseOp<"sin", [SameOperandsAndResultElementType]> {
let summary = "Computes sine of input, elementwise";

let description = [{
Expand All @@ -423,7 +423,7 @@ def Tcp_SinOp : Tcp_UnaryElementwiseOp<"sin"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_AbsOp : Tcp_UnaryElementwiseOp<"abs"> {
def Tcp_AbsOp : Tcp_UnaryElementwiseOp<"abs", [SameOperandsAndResultElementType]> {
let summary = "Computes absolute of input, elementwise";

let description = [{
Expand All @@ -441,7 +441,7 @@ def Tcp_AbsOp : Tcp_UnaryElementwiseOp<"abs"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_LogOp : Tcp_UnaryElementwiseOp<"log"> {
def Tcp_LogOp : Tcp_UnaryElementwiseOp<"log", [SameOperandsAndResultElementType]> {
let summary = "Computes natural logarithm of input, elementwise";

let description = [{
Expand All @@ -459,7 +459,7 @@ def Tcp_LogOp : Tcp_UnaryElementwiseOp<"log"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_NegOp : Tcp_UnaryElementwiseOp<"neg"> {
def Tcp_NegOp : Tcp_UnaryElementwiseOp<"neg", [SameOperandsAndResultElementType]> {
let summary = "Computes the negation of input, elementwise";

let description = [{
Expand All @@ -477,7 +477,7 @@ def Tcp_NegOp : Tcp_UnaryElementwiseOp<"neg"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_AtanOp : Tcp_UnaryElementwiseOp<"atan"> {
def Tcp_AtanOp : Tcp_UnaryElementwiseOp<"atan", [SameOperandsAndResultElementType]> {
let summary = "Computes the arcus tangent value of input, elementwise";

let description = [{
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
return rewriter.notifyMatchFailure(
catOp, "aten.cat operands must be a list of tensors");

SmallVector tensorInputs = getTypeConvertedValues(
rewriter, catOp->getLoc(), getTypeConverter(), inputs);
auto tensorInputs = getTypeConvertedValues(rewriter, catOp->getLoc(),
getTypeConverter(), inputs);

int64_t dim;
if (!matchPattern(catOp.getDim(), m_TorchConstantInt(&dim)))
Expand Down
60 changes: 34 additions & 26 deletions lib/Conversion/TorchToTcp/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,14 @@ class ConvertAtenReluOp : public OpConversionPattern<AtenReluOp> {
}
};

class ConvertAtenAbsOp : public OpConversionPattern<AtenAbsOp> {
template <typename AtenOpT, typename TcpOpT>
class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenAbsOp>::OpConversionPattern;
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

LogicalResult
matchAndRewrite(AtenAbsOp op, OpAdaptor adaptor,
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
Expand All @@ -464,13 +466,18 @@ class ConvertAtenAbsOp : public OpConversionPattern<AtenAbsOp> {
return rewriter.notifyMatchFailure(
op, "Abs input tensor must have integer or floating-point datatype");

rewriter.replaceOpWithNewOp<tcp::AbsOp>(op, inputType, input);
RankedTensorType resultType =
OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();

rewriter.replaceOpWithNewOp<TcpOpT>(op, resultType, input);
return success();
}
};

template <typename AtenOpT, typename TcpOpT>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
class ConvertAtenUnaryFpOnlyOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
Expand Down Expand Up @@ -680,37 +687,38 @@ void torch_to_tcp::populateElementwisePatternsAndLegality(

target.addIllegalOp<AtenCeilOp>();
target.addIllegalOp<AtenFloorOp>();
target.addIllegalOp<AtenSqrtOp>();
target.addIllegalOp<AtenSigmoidOp>();
target.addIllegalOp<AtenTanhOp>();
target.addIllegalOp<AtenSinOp>();
target.addIllegalOp<AtenCosOp>();
target.addIllegalOp<AtenLogOp>();
target.addIllegalOp<AtenNegOp>();
target.addIllegalOp<AtenAtanOp>();
patterns.add<ConvertAtenUnaryOp<AtenFloorOp, tcp::FloorOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenCeilOp, tcp::CeilOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenSqrtOp, tcp::SqrtOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenSigmoidOp, tcp::SigmoidOp>>(typeConverter,
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenFloorOp, tcp::FloorOp>>(
typeConverter, context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenCeilOp, tcp::CeilOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenSigmoidOp, tcp::SigmoidOp>>(
typeConverter, context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenTanhOp, tcp::TanhOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenSinOp, tcp::SinOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenCosOp, tcp::CosOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenLogOp, tcp::LogOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenNegOp, tcp::NegOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenAtanOp, tcp::AtanOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenTanhOp, tcp::TanhOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenSinOp, tcp::SinOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenCosOp, tcp::CosOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenLogOp, tcp::LogOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenNegOp, tcp::NegOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenAtanOp, tcp::AtanOp>>(typeConverter,
context);

target.addIllegalOp<AtenAbsOp>();
patterns.add<ConvertAtenAbsOp>(typeConverter, context);
target.addIllegalOp<AtenSqrtOp>();
patterns.add<ConvertAtenUnaryIntOrFpOp<AtenAbsOp, tcp::AbsOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryIntOrFpOp<AtenSqrtOp, tcp::SqrtOp>>(
typeConverter, context);

target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/TorchToTcp/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,19 @@ func.func @torch.aten.sqrt(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[

// -----

// CHECK-LABEL: func.func @torch.aten.sqrt_int(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T1:.*]] = tcp.sqrt %[[T0]] : tensor<?x?xi32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.sqrt_int(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.sqrt %arg0 : !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.ceil(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
Expand Down