Skip to content

Commit 1d41f7b

Browse files
authored
Rework AtenEmptyStridedOp checks (#2537)
Now using Value instead of Ints. Trades compile failure for a runtime assert
1 parent 4199fef commit 1d41f7b

File tree

5 files changed

+195
-129
lines changed

5 files changed

+195
-129
lines changed

e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,7 @@
764764
"NewEmptyModuleNonDefaultIntDtype_basic",
765765
"NewEmptyStridedModuleDefaultDtype_basic",
766766
"EmptyStridedModule_basic",
767+
"EmptyStridedSizeIntStrideModule_basic",
767768
"PermuteModule_basic",
768769
"PermuteNegativeIndexModule_basic",
769770
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
@@ -1440,6 +1441,7 @@
14401441
"UniformStaticShapeModule_basic",
14411442
"AtenEmbeddingBagStaticModule_basic",
14421443
"EmptyStridedModule_basic",
1444+
"EmptyStridedSizeIntStrideModule_basic",
14431445
"ElementwiseBitwiseAndScalarInt64Module_basic",
14441446
"ElementwiseBitwiseAndScalarInt32Module_basic",
14451447
"ElementwiseBitwiseAndScalarInt8Module_basic",

include/torch-mlir/Dialect/Torch/Utils/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) {
104104
return isAssumingStrictSymbolicShapes(builder.getBlock());
105105
}
106106

107+
// Helper function for AtenEmptyStrided and friends that checks if the stride
108+
// values are default or not. Throws a runtime assert if not.
109+
LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter,
110+
Value opSize, Value opStride,
111+
Location loc);
112+
107113
} // namespace Torch
108114
} // namespace torch
109115
} // namespace mlir

0 commit comments

Comments
 (0)