83 changes: 79 additions & 4 deletions lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -109,6 +109,26 @@ LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
return success();
}

// aten.cat of 1D tensors: recurse into each element.
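// For example, cat([literal<[1]>, unsqueeze(NumToTensor(%n), 0)]) yields the
// elements [1, %n], carrying the dynamic scalar %n through symbolically.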
if (auto catOp = value.getDefiningOp<Torch::AtenCatOp>()) {
int64_t catDim;
if (matchPattern(catOp.getDim(), m_TorchConstantInt(&catDim)) &&
catDim == 0) {
SmallVector<Value> tensors;
if (succeeded(getListOperands(catOp.getTensors(), tensors))) {
SmallVector<OpFoldResult> catElements;
if (llvm::all_of(tensors,
[&](Value t) {
return succeeded(getListFromTensor(t, catElements));
}) &&
(int64_t)catElements.size() <= kMaxFold) {
vals.append(catElements.begin(), catElements.end());
return success();
}
}
}
}

// Last supported case: ValueTensorLiteralOp
auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
if (!literalOp)
@@ -357,6 +377,60 @@ class PropagateAtenIndexSelectPattern
};
} // namespace

namespace {
// Fold `aten.select.int(1d_tensor, 0, const_idx)` by extracting the idx-th
// scalar element via getListFromTensor (which handles literals, unsqueeze,
// NumToTensor, cat, etc.).
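// For example (cf. the tests in scalarize-shapes.mlir):
//   %sel = torch.aten.select.int %cat, %int0, %int1
//   %val = torch.aten.item %sel
// becomes a prim.NumToTensor.Scalar of the extracted element, which the
// existing PropagateAtenItemPattern then folds down to the scalar itself.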
class PropagateAtenSelectIntPattern : public OpRewritePattern<AtenSelectIntOp> {
public:
using OpRewritePattern<AtenSelectIntOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSelectIntOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "requires a constant dim");

int64_t idx;
if (!matchPattern(op.getIndex(), m_TorchConstantInt(&idx)))
return rewriter.notifyMatchFailure(op, "requires a constant index");

auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
if (!selfTy.hasSizes() || selfTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "expected 1D input");

int64_t selfRank = selfTy.getSizes().size();
dim = toPositiveDim(dim, selfRank);
if (!isValidDim(dim, selfRank))
return rewriter.notifyMatchFailure(op, "invalid dim");

int64_t dimLength = selfTy.getSizes()[dim];
if (dimLength == kUnknownSize)
return rewriter.notifyMatchFailure(op, "unknown dim length");

idx = toPositiveDim(idx, dimLength);
if (!isValidDim(idx, dimLength))
return rewriter.notifyMatchFailure(op, "invalid index");

SmallVector<OpFoldResult> elements;
if (failed(getListFromTensor(op.getSelf(), elements)) ||
idx >= (int64_t)elements.size())
return rewriter.notifyMatchFailure(op, "cannot decompose source tensor");

SmallVector<Value, 1> materialized;
SmallVector<OpFoldResult, 1> single = {elements[idx]};
if (failed(materializeFolds(b, single, materialized)))
return failure();

auto resultTy = cast<ValueTensorType>(op.getType());
rewriter.replaceOp(
op, PrimNumToTensorScalarOp::create(b, resultTy, materialized.front()));
Collaborator:
IIRC, the PrimNumToTensorScalarOp is expected to have a rank zero result in some places (e.g., some folding patterns). I think the abstract interp function is just return []. If we know the output is always shaped like [1] it might be better to use a full op just in case this causes issues anywhere.
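
A minimal sketch of that alternative (hypothetical, not part of this PR), assuming the generated Torch::AtenFullOp builder with the usual aten::full operand order (size, fill_value, dtype, layout, device, pin_memory):

    // Hypothetical: materialize the [1]-shaped result with aten.full instead,
    // so nothing downstream ever sees a rank-one prim.NumToTensor.Scalar.
    Value one = Torch::ConstantIntOp::create(b, b.getI64IntegerAttr(1));
    Value sizeList = Torch::PrimListConstructOp::create(
        b, Torch::ListType::get(b.getType<Torch::IntType>()), ValueRange{one});
    Value none = Torch::ConstantNoneOp::create(b);
    rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
        op, resultTy, sizeList, materialized.front(),
        /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memory=*/none);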

return success();
}
};
} // namespace

namespace {
// Conversion attempts to handle some common propagatable slice cases, namely
// splatted values, no-op slices, known list of values, or any case where a
Expand Down Expand Up @@ -1507,10 +1581,11 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
// are positive so floor divide should be a sufficient scalar replacement.
patterns.insert<
PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
- PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
- PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
- PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
- PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
+ PropagateAtenSelectIntPattern, PropagateAtenItemPattern,
+ PropagateAtenShapeToTensorPattern, PropagateAtenSliceTensorPattern,
+ PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern,
+ PropagateAtenBroadcastToPattern, PropagateAtenTransposeIntPattern,
+ PropagateAtenToDtypePattern,
PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
78 changes: 78 additions & 0 deletions test/Dialect/Torch/scalarize-shapes.mlir
@@ -709,3 +709,81 @@ func.func @transpose$prop_3d_m1_0(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
%12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list<int>
return %7 : !torch.vtensor<[2,2,2],si64>
}

// -----

// select.int on a cat of constant and dynamic elements: all four selections
// fold, and the dynamic element propagates as the scalar %arg1.
// CHECK-LABEL: @select_int_from_cat_fold
func.func @select_int_from_cat_fold(%arg0: !torch.vtensor<[1,?,2048],f16>, %arg1: !torch.int) -> !torch.vtensor<[?,?,?,?],f16> {
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16
// CHECK-DAG: %[[INT128:.*]] = torch.constant.int 128
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT1]], %arg1, %[[INT16]], %[[INT128]]
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %arg0, %[[LIST]]
// CHECK: return %[[RESULT]]
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%c1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%c16 = torch.vtensor.literal(dense<16> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%c128 = torch.vtensor.literal(dense<128> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%dyn = torch.prim.NumToTensor.Scalar %arg1 : !torch.int -> !torch.vtensor<[],si64>
%dyn_unsq = torch.aten.unsqueeze %dyn, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%list = torch.prim.ListConstruct %c1, %dyn_unsq, %c16, %c128 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%cat = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],si64>
%s0 = torch.aten.select.int %cat, %int0, %int0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%d0 = torch.aten.item %s0 : !torch.vtensor<[1],si64> -> !torch.int
%s1 = torch.aten.select.int %cat, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%d1 = torch.aten.item %s1 : !torch.vtensor<[1],si64> -> !torch.int
%s2 = torch.aten.select.int %cat, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%d2 = torch.aten.item %s2 : !torch.vtensor<[1],si64> -> !torch.int
%s3 = torch.aten.select.int %cat, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%d3 = torch.aten.item %s3 : !torch.vtensor<[1],si64> -> !torch.int
%shape = torch.prim.ListConstruct %d0, %d1, %d2, %d3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%result = torch.aten.reshape %arg0, %shape : !torch.vtensor<[1,?,2048],f16>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f16>
return %result : !torch.vtensor<[?,?,?,?],f16>
}

// -----

// select.int with a negative index selects the last element.
// CHECK-LABEL: @select_int_negative_index
func.func @select_int_negative_index(%arg0: !torch.int) -> !torch.list<int> {
// CHECK-DAG: %[[INT128:.*]] = torch.constant.int 128
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT128]]
// CHECK: return %[[LIST]]
%int0 = torch.constant.int 0
%int_neg1 = torch.constant.int -1
%c1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%c128 = torch.vtensor.literal(dense<128> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%dyn = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
%dyn_unsq = torch.aten.unsqueeze %dyn, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
%list = torch.prim.ListConstruct %c1, %dyn_unsq, %c128 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%cat = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%sel = torch.aten.select.int %cat, %int0, %int_neg1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%result = torch.aten.item %sel : !torch.vtensor<[1],si64> -> !torch.int
%shape = torch.prim.ListConstruct %result : (!torch.int) -> !torch.list<int>
return %shape : !torch.list<int>
}

// -----

// select.int on a cat with a multi-element sub-tensor:
// cat([vtensor<[2]>, vtensor<[1]>]) produces a [3] tensor; selecting index 1
// extracts the second element (42) of the first sub-tensor.
// CHECK-LABEL: @select_int_multi_element_subtensor
func.func @select_int_multi_element_subtensor() -> !torch.list<int> {
// CHECK-DAG: %[[INT42:.*]] = torch.constant.int 42
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT42]]
// CHECK: return %[[LIST]]
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%c = torch.vtensor.literal(dense<[10, 42]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
%c2 = torch.vtensor.literal(dense<99> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%list = torch.prim.ListConstruct %c, %c2 : (!torch.vtensor<[2],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%cat = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
%sel = torch.aten.select.int %cat, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%result = torch.aten.item %sel : !torch.vtensor<[1],si64> -> !torch.int
%shape = torch.prim.ListConstruct %result : (!torch.int) -> !torch.list<int>
return %shape : !torch.list<int>
}