
Commit e728248

[Torch Dialect] support aten.glu (#2531)
1 parent b0f39ac
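In short: the op is registered in torch_ods_gen.py (from which the ODS definition in GeneratedTorchOps.td is generated), given shape and dtype transfer functions for the abstract interpretation library, decomposed into narrow/sigmoid/mul in DecomposeComplexOps.cpp, marked illegal in LowerToBackendContract.cpp so the decomposition always runs, and covered by a new GluStaticModule e2e test.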

File tree: 8 files changed, 136 additions, 0 deletions

e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions

@@ -963,6 +963,7 @@
     "ElementwiseMaximumIntModule_basic",
     "ElementwiseMaxOtherIntModule_basic",
     "ElementwiseMaxOtherModule_basic",
+    "GluStaticModule_basic",
     "ViewDoubleMergeStaticModule_basic",
     "ViewCollapseOnesMiddleModule_basic",
     "ViewFiveTestStaticModule_basic",

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions

@@ -4189,6 +4189,30 @@ def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [
   }];
 }
 
+def Torch_AtenGluOp : Torch_Op<"aten.glu", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::glu : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenGluOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenGluOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
   AllowsTypeRefinement,
   HasValueSemantics,
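For reference, `aten::glu` (gated linear unit) splits its input into two halves along `dim` and gates the first half with the sigmoid of the second. A minimal PyTorch sketch of the semantics (not part of the commit; assumes a stock PyTorch install):

import torch
import torch.nn.functional as F

x = torch.randn(3, 24, 5)
a, b = x.chunk(2, dim=1)            # two halves of size 12 along dim 1
reference = a * torch.sigmoid(b)    # GLU(a, b) = a * sigmoid(b)
assert torch.allclose(F.glu(x, dim=1), reference)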

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 37 additions & 0 deletions

@@ -6382,6 +6382,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.glu\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: glu's dim size must be a multiple of 2\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.int) {\n"
+"      %13 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %14 = torch.aten.add.int %arg1, %13 : !torch.int, !torch.int -> !torch.int\n"
+"      torch.prim.If.yield %14 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %arg1 : !torch.int\n"
+"    }\n"
+"    %2 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.remainder.int %2, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"    %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.aten.slice.t %arg0, %none, %1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
+"    %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"    %8 = torch.prim.ListConstruct %7 : (!torch.int) -> !torch.list<int>\n"
+"    %9 = torch.aten.add.t %5, %8 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
+"    %10 = torch.aten.add.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"    %11 = torch.aten.slice.t %arg0, %10, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
+"    %12 = torch.aten.add.t %9, %11 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
+"    return %12 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -8863,6 +8896,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.glu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.scatter_reduce.two\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 43 additions & 0 deletions

@@ -361,6 +361,48 @@ class DecomposeAtenNarrowTensorOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenGluOp : public OpRewritePattern<AtenGluOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenGluOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value dim = op.getDim();
+
+    auto outputTy = op.getType().dyn_cast<Torch::ValueTensorType>();
+    if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected output type to have sizes and dtype");
+    }
+
+    Value zero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
+    Value two =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
+
+    // The op is only defined when the size of `dim` is even.
+    Value remainder = rewriter.create<AtenRemainderIntOp>(loc, dimSize, two);
+    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, zero);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr("AtenGluOp's dim size must be a multiple of 2"));
+
+    // Split `self` into halves a and b along `dim`, then compute a ⊗ σ(b).
+    Value splitLength = rewriter.create<AtenFloordivIntOp>(loc, dimSize, two);
+    Value a = rewriter.create<AtenNarrowOp>(loc, outputTy, self, dim, zero,
+                                            splitLength);
+    Value b = rewriter.create<AtenNarrowOp>(loc, outputTy, self, dim,
+                                            splitLength, splitLength);
+    Value sigmoidB = rewriter.create<AtenSigmoidOp>(loc, outputTy, b);
+    Value result = rewriter.create<AtenMulTensorOp>(loc, outputTy, a, sigmoidB);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenZeroOp
     : public OpRewritePattern<AtenZeroOp> {

@@ -5289,6 +5331,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenGluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
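The pattern above is, in effect, the following PyTorch-level rewrite. This sketch (not part of the commit; assumes a stock PyTorch install, helper name hypothetical) mirrors it step for step and checks the result against torch.nn.functional.glu:

import torch
import torch.nn.functional as F

def glu_decomposed(x: torch.Tensor, dim: int) -> torch.Tensor:
    dim_size = x.size(dim)
    # Mirrors the RuntimeAssertOp: the halved dimension must be even.
    assert dim_size % 2 == 0, "AtenGluOp's dim size must be a multiple of 2"
    split_length = dim_size // 2
    a = x.narrow(dim, 0, split_length)             # first AtenNarrowOp
    b = x.narrow(dim, split_length, split_length)  # second AtenNarrowOp
    return a * torch.sigmoid(b)                    # AtenMulTensorOp(a, AtenSigmoidOp(b))

x = torch.randn(3, 24, 5)
assert torch.allclose(glu_decomposed(x, 1), F.glu(x, dim=1))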

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions

@@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenHardsigmoidOp>();
   target.addIllegalOp<AtenRelu6Op>();
   target.addIllegalOp<AtenEluOp>();
+  target.addIllegalOp<AtenGluOp>();
   target.addIllegalOp<AtenHardswishOp>();
   target.addIllegalOp<AtenSoftplusOp>();
   target.addIllegalOp<AtenSiluOp>();

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py

Lines changed: 11 additions & 0 deletions

@@ -167,6 +167,12 @@ def aten〇relu6〡shape(self: List[int]) -> List[int]:
 def aten〇round〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
+    if dim < 0:
+        dim += len(self)
+    assert self[dim] % 2 == 0, "glu's dim size must be a multiple of 2"
+    return self[:dim] + [self[dim] // 2] + self[dim+1:]
+
 def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]:
     return upstream_shape_functions.unary(self)

@@ -1932,6 +1938,11 @@ def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(100,)], dim=0))
+def aten〇glu〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype), "sum") for dtype in _SORTED_TORCH_TYPES])
 def aten〇scatter_reduce〇two〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], reduce: str, include_self: bool = True) -> int:
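As a quick sanity check of the shape rule: glu halves the chosen dimension and leaves the others alone, after normalizing negative dims. A hypothetical standalone restatement (without the library's 〇/〡 naming, so it runs as plain Python):

def glu_shape(shape, dim=-1):
    # Same logic as aten〇glu〡shape above.
    if dim < 0:
        dim += len(shape)
    assert shape[dim] % 2 == 0, "glu's dim size must be a multiple of 2"
    return shape[:dim] + [shape[dim] // 2] + shape[dim + 1:]

assert glu_shape([3, 24, 5], 1) == [3, 12, 5]    # positive dim
assert glu_shape([3, 24, 5], -2) == [3, 12, 5]   # negative dim counts from the end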

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions

@@ -354,6 +354,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::view_as_complex : (Tensor) -> (Tensor)")
     emit("aten::view_as_real : (Tensor) -> (Tensor)")
     emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)")
+    emit("aten::glu : (Tensor, int) -> (Tensor)")
 
     # Ops with dynamic number of outputs
     emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")

python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 18 additions & 0 deletions

@@ -3685,3 +3685,21 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt8Module())
 def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int8))
+
+# ==============================================================================
+
+class GluStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 24, 5], torch.float32, True)
+    ])
+    def forward(self, x):
+        return torch.ops.aten.glu(x, dim=1)
+
+@register_test_case(module_factory=lambda: GluStaticModule())
+def GluStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 24, 5))
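A quick eager check of the expected static shape, outside the e2e harness (a sketch; assumes a stock PyTorch install):

import torch

out = torch.ops.aten.glu(torch.rand(3, 24, 5), dim=1)
assert list(out.shape) == [3, 12, 5]   # dim 1 halved from 24 to 12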
