Skip to content

Commit 03b6edc

Browse files
harsh-nodsilvasean
authored and committed
Add where, gt, bucketize and reshape ops to Torch dialect
This patch adds the where, gt, bucketize and reshape ops to the Torch dialect. These ops are present in the histogram calibration module. TEST: Successfully lowers ops to Torch dialect in histogram module.
1 parent cfc8de3 commit 03b6edc

File tree

2 files changed

+101
-17
lines changed

2 files changed

+101
-17
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 96 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -512,6 +512,36 @@ def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [
512512
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
513513
}
514514

515+
def Torch_AtenGtTensorOp : Torch_Op<"aten.gt.Tensor", [
516+
AllowsTypeRefinement,
517+
HasValueSemantics
518+
]> {
519+
let summary = "Generated op for `aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)`";
520+
let arguments = (ins
521+
AnyTorchTensorType:$self,
522+
AnyTorchTensorType:$other
523+
);
524+
let results = (outs
525+
AnyTorchTensorType:$result
526+
);
527+
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
528+
}
529+
530+
def Torch_AtenGt_TensorOp : Torch_Op<"aten.gt_.Tensor", [
531+
IsTrailingUnderscoreInplaceVariant,
532+
AllowsTypeRefinement
533+
]> {
534+
let summary = "Generated op for `aten::gt_.Tensor : (Tensor, Tensor) -> (Tensor)`";
535+
let arguments = (ins
536+
AnyTorchTensorType:$self,
537+
AnyTorchTensorType:$other
538+
);
539+
let results = (outs
540+
AnyTorchTensorType:$result
541+
);
542+
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
543+
}
544+
515545
def Torch_AtenNeTensorOp : Torch_Op<"aten.ne.Tensor", [
516546
AllowsTypeRefinement,
517547
HasValueSemantics
@@ -1071,22 +1101,6 @@ def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
10711101
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
10721102
}
10731103

1074-
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
1075-
AllowsTypeRefinement,
1076-
HasValueSemantics
1077-
]> {
1078-
let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
1079-
let arguments = (ins
1080-
AnyTorchTensorType:$condition,
1081-
AnyTorchTensorType:$self,
1082-
AnyTorchTensorType:$other
1083-
);
1084-
let results = (outs
1085-
AnyTorchTensorType:$result
1086-
);
1087-
let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
1088-
}
1089-
10901104
def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
10911105
AllowsTypeRefinement,
10921106
HasValueSemantics
@@ -1942,6 +1956,23 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
19421956
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
19431957
}
19441958

1959+
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
1960+
AllowsTypeRefinement,
1961+
HasValueSemantics
1962+
]> {
1963+
let summary = "Generated op for `aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)`";
1964+
let arguments = (ins
1965+
AnyTorchTensorType:$self,
1966+
AnyTorchTensorType:$boundaries,
1967+
Torch_BoolType:$out_int32,
1968+
Torch_BoolType:$right
1969+
);
1970+
let results = (outs
1971+
AnyTorchTensorType:$result
1972+
);
1973+
let assemblyFormat = "$self `,` $boundaries `,` $out_int32 `,` $right attr-dict `:` type($self) `,` type($boundaries) `,` type($out_int32) `,` type($right) `->` type($result)";
1974+
}
1975+
19451976
def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
19461977
AllowsTypeRefinement
19471978
]> {
@@ -2002,6 +2033,25 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
20022033
let assemblyFormat = "$weight `,` $indices `,` $padding_idx `,` $scale_grad_by_freq `,` $sparse attr-dict `:` type($weight) `,` type($indices) `,` type($padding_idx) `,` type($scale_grad_by_freq) `,` type($sparse) `->` type($result)";
20032034
}
20042035

2036+
def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
2037+
AllowsTypeRefinement,
2038+
HasValueSemantics
2039+
]> {
2040+
let summary = "Generated op for `aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
2041+
let arguments = (ins
2042+
AnyTorchTensorType:$self,
2043+
TorchOptionalIntType:$dtype,
2044+
TorchOptionalIntType:$layout,
2045+
TorchOptionalDeviceType:$device,
2046+
TorchOptionalBoolType:$pin_memory,
2047+
TorchOptionalIntType:$memory_format
2048+
);
2049+
let results = (outs
2050+
AnyTorchTensorType:$result
2051+
);
2052+
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
2053+
}
2054+
20052055
def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
20062056
AllowsTypeRefinement,
20072057
HasValueSemantics
@@ -2139,6 +2189,20 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
21392189
let assemblyFormat = "$self `,` $repeats attr-dict `:` type($self) `,` type($repeats) `->` type($result)";
21402190
}
21412191

2192+
def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
2193+
AllowsTypeRefinement
2194+
]> {
2195+
let summary = "Generated op for `aten::reshape : (Tensor, int[]) -> (Tensor)`";
2196+
let arguments = (ins
2197+
AnyTorchTensorType:$self,
2198+
TorchIntListType:$shape
2199+
);
2200+
let results = (outs
2201+
AnyTorchTensorType:$result
2202+
);
2203+
let assemblyFormat = "$self `,` $shape attr-dict `:` type($self) `,` type($shape) `->` type($result)";
2204+
}
2205+
21422206
def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
21432207
AllowsTypeRefinement
21442208
]> {
@@ -2312,6 +2376,22 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
23122376
let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)";
23132377
}
23142378

2379+
def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
2380+
AllowsTypeRefinement,
2381+
HasValueSemantics
2382+
]> {
2383+
let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
2384+
let arguments = (ins
2385+
AnyTorchTensorType:$condition,
2386+
AnyTorchTensorType:$self,
2387+
AnyTorchTensorType:$other
2388+
);
2389+
let results = (outs
2390+
AnyTorchTensorType:$result
2391+
);
2392+
let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
2393+
}
2394+
23152395
def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
23162396
AllowsTypeRefinement
23172397
]> {

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -454,6 +454,7 @@ def emit_with_mutating_variants(key, **kwargs):
454454
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
455455
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
456456
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
457+
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
457458
"aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)",
458459
"aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
459460
"aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
@@ -479,7 +480,6 @@ def emit_with_mutating_variants(key, **kwargs):
479480
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
480481
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
481482
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
482-
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
483483
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
484484
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
485485
emit("aten::gelu : (Tensor) -> (Tensor)")
@@ -550,10 +550,12 @@ def emit_with_mutating_variants(key, **kwargs):
550550
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
551551
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
552552
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
553+
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
553554
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
554555
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
555556
emit("aten::detach : (Tensor) -> (Tensor)")
556557
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
558+
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
557559
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
558560
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
559561
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
@@ -563,6 +565,7 @@ def emit_with_mutating_variants(key, **kwargs):
563565
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
564566
emit("aten::numel : (Tensor) -> (int)")
565567
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
568+
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
566569
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
567570
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
568571
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
@@ -574,6 +577,7 @@ def emit_with_mutating_variants(key, **kwargs):
574577
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
575578
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
576579
emit("aten::view : (Tensor, int[]) -> (Tensor)")
580+
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
577581
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
578582
emit("aten::len.Tensor : (Tensor) -> (int)")
579583
emit("aten::cpu : (Tensor) -> (Tensor)")

0 commit comments

Comments (0)