Skip to content

Commit

Permalink
[onnx] Add support for onnx.sinh (#2643)
Browse files Browse the repository at this point in the history
Adds a lowering from `onnx.sinh` to `aten.sinh`. This includes adding
the `aten.sinh` operator.
  • Loading branch information
rsuderman committed Dec 16, 2023
1 parent b3e9420 commit 6188869
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 1 deletion.
45 changes: 45 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,51 @@ def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [
}];
}

def Torch_AtenSinhOp : Torch_Op<"aten.sinh", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sinh : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSinhOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenSinhOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenSinh_Op : Torch_Op<"aten.sinh_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::sinh_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSinh_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenSinh_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenSgnOp : Torch_Op<"aten.sgn", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
14 changes: 13 additions & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();

rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(
binder.op, resultType, operand);
return success();
});

patterns.onOp("Sinh", 9,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();

rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
binder.op, resultType, operand);
return success();
});

patterns.onOp(
"Transpose", 13,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::selu : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::sign : (Tensor) -> (Tensor)",
"aten::sinh : (Tensor) -> (Tensor)",
"aten::sgn : (Tensor) -> (Tensor)",
"aten::hardsigmoid : (Tensor) -> (Tensor)",
"aten::hardswish : (Tensor) -> (Tensor)",
Expand Down
9 changes: 9 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,15 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,

// -----

// CHECK-LABEL: func.func @test_sinh
func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
// CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}

// -----

// CHECK-LABEL: func.func @test_transpose_default
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
Expand Down

0 comments on commit 6188869

Please sign in to comment.