diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp
index ed6e135c2ed4..c5f735c42679 100644
--- a/lib/Conversion/TorchToLinalg/Reduction.cpp
+++ b/lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -27,6 +27,93 @@ using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
 namespace {
+// Reshapes `input` to the fully static `targetType` using tensor.reshape. The
+// target extents are emitted as i64 constants and packed into a rank-1 shape
+// tensor with tensor.from_elements.
+static Value emitStaticReshape(ConversionPatternRewriter &rewriter,
+                               Location loc, Value input,
+                               RankedTensorType targetType) {
+  SmallVector<Value> extents;
+  for (int64_t extent : targetType.getShape()) {
+    extents.push_back(arith::ConstantOp::create(
+        rewriter, loc, rewriter.getI64IntegerAttr(extent)));
+  }
+  auto shapeType =
+      RankedTensorType::get({targetType.getRank()}, rewriter.getI64Type());
+  Value shapeTensor =
+      tensor::FromElementsOp::create(rewriter, loc, shapeType, extents);
+  return tensor::ReshapeOp::create(rewriter, loc, targetType, input,
+                                   shapeTensor)
+      .getResult();
+}
+
+// Emits an inclusive prefix sum along the middle dimension of a statically
+// shaped rank-3 tensor laid out as [outer, scan, inner]. Every step adds the
+// running tensor to a copy of itself shifted by `offset` positions along the
+// scan dimension (zero-filled at the front); `offset` doubles from 1 until it
+// reaches the scan extent. Working on the rank-3 form keeps this independent
+// of the rank of the original aten.cumsum operand.
+static Value emitInclusiveScanByPowersOfTwo(Value running,
+                                            ConversionPatternRewriter &rewriter,
+                                            Location loc) {
+  auto runningType = cast<RankedTensorType>(running.getType());
+  SmallVector<int64_t> runningShape =
+      makeShapeTorchCompatible(runningType.getShape());
+  int64_t scanDimSize = runningShape[1];
+  // A scan over at most one element is the identity, so emit no ops at all.
+  if (scanDimSize <= 1)
+    return running;
+
+  Type elementType = runningType.getElementType();
+  Value zero = arith::ConstantOp::create(rewriter, loc,
+                                         rewriter.getZeroAttr(elementType));
+
+  // Loop-invariant operands: the slice always starts at the origin with unit
+  // strides and the original extents, and nothing is padded at the high end.
+  SmallVector<OpFoldResult> zeroOffsets(3, rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> sliceSizes = {
+      rewriter.getIndexAttr(runningShape[0]),
+      rewriter.getIndexAttr(scanDimSize),
+      rewriter.getIndexAttr(runningShape[2])};
+  SmallVector<OpFoldResult> sliceStrides(3, rewriter.getIndexAttr(1));
+  SmallVector<int64_t> highPad = {0, 0, 0};
+
+  for (int64_t offset = 1; offset < scanDimSize; offset <<= 1) {
+    // Prepend `offset` zeros along the scan dimension...
+    SmallVector<int64_t> lowPad = {0, offset, 0};
+    Type paddedType =
+        tensor::PadOp::inferResultType(runningType, lowPad, highPad);
+    SmallVector<OpFoldResult> lowPadValues = {rewriter.getIndexAttr(0),
+                                              rewriter.getIndexAttr(offset),
+                                              rewriter.getIndexAttr(0)};
+    Value padded = tensor::PadOp::create(rewriter, loc, paddedType, running,
+                                         lowPadValues, zeroOffsets, zero);
+
+    // ...and keep the leading window of the original size, which gives
+    // shifted[i] = running[i - offset] for i >= offset and zero elsewhere.
+    Value shifted =
+        tensor::ExtractSliceOp::create(rewriter, loc, runningType, padded,
+                                       zeroOffsets, sliceSizes, sliceStrides);
+
+    running = torch_to_linalg::createElementwiseLinalgGeneric(
+        rewriter, loc, {running, shifted}, elementType,
+        [&](OpBuilder &builder, Location bodyLoc, ValueRange payloadArgs) {
+          Value sum;
+          if (isa<mlir::FloatType>(elementType))
+            sum = arith::AddFOp::create(builder, bodyLoc, payloadArgs[0],
+                                        payloadArgs[1]);
+          else if (isa<mlir::IntegerType>(elementType))
+            sum = arith::AddIOp::create(builder, bodyLoc, payloadArgs[0],
+                                        payloadArgs[1]);
+          else
+            llvm_unreachable("unsupported cumsum element type");
+          linalg::YieldOp::create(builder, bodyLoc, sum);
+        });
+  }
+
+  return running;
+}
+
 // Aten max.dim (min.dim) lowering represents the MaxDimOp (MinDimOp) as an
 // linalg.indexed_generic op, producing two output buffers.
 //
@@ -812,6 +899,76 @@ class ConvertReductionOp : public ConversionPattern {
 };
 } // namespace
 
+namespace {
+// Lowers aten.cumsum for statically shaped tensors with a constant `dim`. The
+// operand is converted to the result element type, reshaped to
+// [outer, scan, inner] around `dim`, scanned along the middle dimension, and
+// reshaped back to the converted result type.
+class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenCumsumOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = adaptor.getSelf();
+    auto selfType = dyn_cast<RankedTensorType>(self.getType());
+    if (!selfType || !selfType.hasStaticShape())
+      return rewriter.notifyMatchFailure(op,
+                                         "only static tensor shapes supported");
+
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(op, "dim must be constant");
+    dim = toPositiveDim(dim, selfType.getRank());
+    if (!isValidDim(dim, selfType.getRank()))
+      return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+
+    // convertType yields a null Type on failure and plain dyn_cast asserts on
+    // null, so the null-tolerant cast is needed for the check below to fire.
+    auto resultType = dyn_cast_if_present<RankedTensorType>(
+        getTypeConverter()->convertType(op.getType()));
+    if (!resultType || !resultType.hasStaticShape())
+      return rewriter.notifyMatchFailure(op, "expected static ranked result");
+
+    Type resultElementType = resultType.getElementType();
+    if (!isa<mlir::FloatType, mlir::IntegerType>(resultElementType))
+      return rewriter.notifyMatchFailure(
+          op, "only floating point and integer element types supported");
+
+    // Accumulate in the result element type.
+    // NOTE(review): only the builtin element types are visible here; confirm
+    // that unsigned torch dtypes are widened correctly by this conversion.
+    if (selfType.getElementType() != resultElementType)
+      self = torch_to_linalg::convertTensorToElementType(rewriter, loc, self,
+                                                         resultElementType);
+
+    // Fold the dimensions before `dim` into `outer` and those after it into
+    // `inner`; an empty range contributes an extent of 1.
+    SmallVector<int64_t> inputShape =
+        makeShapeTorchCompatible(selfType.getShape());
+    int64_t outer = 1;
+    for (int64_t i = 0; i < dim; ++i)
+      outer *= inputShape[i];
+    int64_t inner = 1;
+    for (int64_t i = dim + 1, e = inputShape.size(); i < e; ++i)
+      inner *= inputShape[i];
+
+    SmallVector<int64_t> scanShape = {outer, inputShape[dim], inner};
+    auto scanType = RankedTensorType::get(makeShapeLLVMCompatible(scanShape),
+                                          resultElementType);
+
+    Value running = emitStaticReshape(rewriter, loc, self, scanType);
+    running = emitInclusiveScanByPowersOfTwo(running, rewriter, loc);
+    Value result = emitStaticReshape(rewriter, loc, running, resultType);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, bool allowNonFinites) {
@@ -837,5 +994,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenCumsumOp>();
   patterns.add(typeConverter, context, allowNonFinites);
+  patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
 }
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index 25a8cdc9f056..454b1e5050f8 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -1107,6 +1107,25 @@ func.func @torch.ops.aten.anydim$basic(%arg0: tensor<1x16x26x26xi1>) -> !torch.v
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.cumsum$to_builtin_user
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.reshape
+// CHECK: tensor.pad
+// CHECK: tensor.extract_slice
+// CHECK: linalg.generic
+// CHECK: arith.addf
+// CHECK: tensor.reshape
+// CHECK-NOT: torch.aten.cumsum
+func.func @torch.aten.cumsum$to_builtin_user(%arg0: !torch.vtensor<[2,3],f32>) -> tensor<2x3xf32> {
+  %dim = torch.constant.int 1
+  %none = torch.constant.none
+  %0 = torch.aten.cumsum %arg0, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32>
+  %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
+  return %1 : tensor<2x3xf32>
+}
+
+// -----
+
 // Per PyTorch docs, torch.cat allows "a 1-D empty tensor with size (0,)"
 // alongside operands of any rank. The linalg lowering must skip these.
 // CHECK-LABEL: func.func @torch.aten.cat$rank1_empty