From 205dce6029bed302f354c0bde5d8c5804f214051 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Sat, 2 Mar 2024 17:47:16 -0500 Subject: [PATCH] [mlir][linalg] Add a folder for transpose(fill) -> fill (#83623) This is similar to the existing folder for a linalg.copy. Transposing a filled tensor is the same as filling the destination of the transpose. --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 18 +++++++++++++++++- mlir/test/Dialect/Linalg/canonicalize.mlir | 14 ++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 919f5130e1760..6954eee93efd1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -815,6 +815,22 @@ struct FoldFillWithCopy : OpRewritePattern { } }; +/// Fold fill with transpose. +struct FoldFillWithTranspose : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + if (auto fillOp = transposeOp.getInput().getDefiningOp()) { + rewriter.replaceOpWithNewOp( + transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(), + transposeOp.getDpsInitOperand(0)->get()); + return success(); + } + return failure(); + } +}; + } // namespace void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -823,7 +839,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, .add, FoldFillWithTensorReshape, - FoldInsertPadIntoFill>(context); + FoldInsertPadIntoFill, FoldFillWithTranspose>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 206d7e9f1ce8d..19cea6c2066c9 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -993,6 +993,20 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor, %arg1 : tenso // ----- +// CHECK-LABEL: func @canonicalize_fill_to_transpose_input( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) +// CHECK: %[[ZERO:.+]] = arith.constant 0.0 +// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor) +func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor + %transpose = linalg.transpose ins(%fill : tensor) outs(%arg1 : tensor) permutation = [1, 0] + return %transpose : tensor +} + +// ----- + // CHECK-LABEL: func @broadcast_same_shape( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)