@@ -2525,6 +2525,60 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
25252525 return success ();
25262526}
25272527
2528+ template <>
2529+ LogicalResult ConvertAtenOp<AtenUnflattenIntOp>::matchAndRewrite(
2530+ AtenUnflattenIntOp op, OpAdaptor adaptor,
2531+ ConversionPatternRewriter &rewriter) const {
2532+
2533+ // Not a ranked tensor type
2534+ auto selfType = adaptor.getSelf ().getType ().dyn_cast <RankedTensorType>();
2535+ if (!selfType || !selfType.hasStaticShape ())
2536+ return rewriter.notifyMatchFailure (
2537+ op,
2538+ " Only ranked tensor types with static shapes are currently supported" );
2539+
2540+ int64_t selfRank = selfType.getRank ();
2541+ int64_t dim;
2542+
2543+ if (!matchPattern (op.getDim (), m_TorchConstantInt (&dim)))
2544+ return rewriter.notifyMatchFailure (op, " dim must be a Scalar constant" );
2545+
2546+ SmallVector<int64_t > sizes;
2547+ if (!matchPattern (op.getSizes (), m_TorchListOfConstantInts (sizes)))
2548+ return rewriter.notifyMatchFailure (
2549+ op, " Only constant sizes are currently supported" );
2550+
2551+ if (selfRank > 0 && !isValidDim (dim, selfRank))
2552+ return rewriter.notifyMatchFailure (op, " dim is statically invalid" );
2553+
2554+ SmallVector<int64_t > newShape;
2555+ for (auto s :
2556+ llvm::enumerate (makeShapeTorchCompatible (selfType.getShape ()))) {
2557+ int64_t idx = s.index ();
2558+ if (idx < dim || idx > dim) {
2559+ newShape.push_back (s.value ());
2560+ } else {
2561+ auto sum = 1 ;
2562+ for (auto newDims : sizes) {
2563+ newShape.push_back (newDims);
2564+ sum *= newDims;
2565+ }
2566+ if (sum != s.value ())
2567+ return rewriter.notifyMatchFailure (op,
2568+ " sizes mismatch with original dim" );
2569+ }
2570+ }
2571+
2572+ auto newType = RankedTensorType::get (makeShapeLLVMCompatible (newShape),
2573+ selfType.getElementType ());
2574+
2575+ rewriter.replaceOpWithNewOp <tosa::ReshapeOp>(
2576+ op, getTypeConverter ()->convertType (newType), adaptor.getSelf (),
2577+ rewriter.getDenseI64ArrayAttr (newShape));
2578+
2579+ return success ();
2580+ }
2581+
25282582template <>
25292583LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
25302584 AtenPermuteOp op, OpAdaptor adaptor,
@@ -5050,6 +5104,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
50505104 INSERT_ATENOP_PATTERN (AtenBatchNormOp);
50515105 INSERT_ATENOP_PATTERN (AtenNativeLayerNormOp);
50525106 INSERT_ATENOP_PATTERN (AtenFlattenUsingIntsOp);
5107+ INSERT_ATENOP_PATTERN (AtenUnflattenIntOp);
50535108 INSERT_ATENOP_PATTERN (AtenPermuteOp);
50545109 INSERT_ATENOP_PATTERN (AtenLog2Op);
50555110 INSERT_ATENOP_PATTERN (AtenThresholdOp);
0 commit comments