@@ -6382,6 +6382,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
63826382" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
63836383" return %0 : !torch.list<int>\n"
63846384" }\n"
6385+ " func.func @\"__torch_mlir_shape_fn.aten.glu\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
6386+ " %none = torch.constant.none\n"
6387+ " %str = torch.constant.str \"AssertionError: glu's dim size must be multiply of 2\"\n"
6388+ " %int0 = torch.constant.int 0\n"
6389+ " %int2 = torch.constant.int 2\n"
6390+ " %int1 = torch.constant.int 1\n"
6391+ " %0 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
6392+ " %1 = torch.prim.If %0 -> (!torch.int) {\n"
6393+ " %13 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
6394+ " %14 = torch.aten.add.int %arg1, %13 : !torch.int, !torch.int -> !torch.int\n"
6395+ " torch.prim.If.yield %14 : !torch.int\n"
6396+ " } else {\n"
6397+ " torch.prim.If.yield %arg1 : !torch.int\n"
6398+ " }\n"
6399+ " %2 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
6400+ " %3 = torch.aten.remainder.int %2, %int2 : !torch.int, !torch.int -> !torch.int\n"
6401+ " %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
6402+ " torch.prim.If %4 -> () {\n"
6403+ " torch.prim.If.yield\n"
6404+ " } else {\n"
6405+ " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
6406+ " torch.prim.If.yield\n"
6407+ " }\n"
6408+ " %5 = torch.aten.slice.t %arg0, %none, %1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
6409+ " %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
6410+ " %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n"
6411+ " %8 = torch.prim.ListConstruct %7 : (!torch.int) -> !torch.list<int>\n"
6412+ " %9 = torch.aten.add.t %5, %8 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
6413+ " %10 = torch.aten.add.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n"
6414+ " %11 = torch.aten.slice.t %arg0, %10, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
6415+ " %12 = torch.aten.add.t %9, %11 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
6416+ " return %12 : !torch.list<int>\n"
6417+ " }\n"
63856418" func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
63866419" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
63876420" return %0 : !torch.list<int>\n"
@@ -8863,6 +8896,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
88638896" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
88648897" return %0#1 : !torch.int\n"
88658898" }\n"
8899+ " func.func @\"__torch_mlir_dtype_fn.aten.glu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
8900+ " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
8901+ " return %0#1 : !torch.int\n"
8902+ " }\n"
88668903" func.func @\"__torch_mlir_dtype_fn.aten.scatter_reduce.two\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.int {\n"
88678904" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
88688905" return %0#1 : !torch.int\n"
0 commit comments