From 98e838a890191b9250ad33741a1c121a9591caa3 Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Fri, 18 Oct 2024 13:02:03 -0700 Subject: [PATCH] [mlir] Do not bufferize parallel_insert_slice dest to read for full slices (#112761) In the insert_slice bufferization interface implementation, the destination tensor is not considered read if the full tensor is overwritten by the slice. This PR adds the same check for tensor.parallel_insert_slice. Adds two new StaticValueUtils: - `areAllConstantIntValue` checks whether every `OpFoldResult` in an array is a constant integer equal to a passed `int64_t` value. - `areConstantIntValues` checks whether every `OpFoldResult` in an array is a constant integer equal to the corresponding entry of a passed array of `int64_t` values. fixes https://github.com/llvm/llvm-project/issues/112435 --------- Signed-off-by: Max Dawkins --- .../mlir/Dialect/Utils/StaticValueUtils.h | 6 +++ .../BufferizableOpInterfaceImpl.cpp | 54 +++++++++---------- .../Transforms/PackAndUnpackPatterns.cpp | 5 -- mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 15 +++++- .../Dialect/Tensor/one-shot-bufferize.mlir | 15 ++++++ 5 files changed, 62 insertions(+), 33 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index ba4f084d3efd1..4d7aa1ae17fdb 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -92,6 +92,12 @@ getConstantIntValues(ArrayRef ofrs); /// Return true if `ofr` is constant integer equal to `value`. bool isConstantIntValue(OpFoldResult ofr, int64_t value); +/// Return true if all of `ofrs` are constant integers equal to `value`. +bool areAllConstantIntValue(ArrayRef ofrs, int64_t value); +/// Return true if all of `ofrs` are constant integers equal to the +/// corresponding value in `values`. 
+bool areConstantIntValues(ArrayRef ofrs, + ArrayRef values); /// Return true if ofr1 and ofr2 are the same integer constant attribute /// values or the same SSA value. Ignore integer bitwitdh and type mismatch diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 87464ccb71720..c2b8614148bf2 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" @@ -636,6 +637,28 @@ struct InsertOpInterface } }; +template +static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp, + OpOperand &opOperand) { + // The source is always read. + if (opOperand == insertSliceOp.getSourceMutable()) + return true; + + // For the destination, it depends... + assert(opOperand == insertSliceOp.getDestMutable() && "expected dest"); + + // Dest is not read if it is entirely overwritten. E.g.: + // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32> + bool allOffsetsZero = + llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex); + RankedTensorType destType = insertSliceOp.getDestType(); + bool sizesMatchDestSizes = + areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape()); + bool allStridesOne = + areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1); + return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne); +} + /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under /// certain circumstances, this op can also be a no-op. 
/// @@ -646,32 +669,8 @@ struct InsertSliceOpInterface tensor::InsertSliceOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - auto insertSliceOp = cast(op); - RankedTensorType destType = insertSliceOp.getDestType(); - - // The source is always read. - if (opOperand == insertSliceOp.getSourceMutable()) - return true; - - // For the destination, it depends... - assert(opOperand == insertSliceOp.getDestMutable() && "expected dest"); - - // Dest is not read if it is entirely overwritten. E.g.: - // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32> - bool allOffsetsZero = - llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) { - return isConstantIntValue(ofr, 0); - }); - bool sizesMatchDestSizes = llvm::all_of( - llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) { - return getConstantIntValue(it.value()) == - destType.getDimSize(it.index()); - }); - bool allStridesOne = - llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) { - return isConstantIntValue(ofr, 1); - }); - return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne); + return insertSliceOpRequiresRead(cast(op), + opOperand); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -931,7 +930,8 @@ struct ParallelInsertSliceOpInterface bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return true; + return insertSliceOpRequiresRead(cast(op), + opOperand); } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp index 995486c87771a..3566714c6529e 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp @@ -16,11 +16,6 @@ namespace mlir { namespace tensor { namespace { -static bool 
areAllConstantIntValue(ArrayRef ofrs, int64_t value) { - return llvm::all_of( - ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); }); -} - /// Returns the number of shape sizes that is either dynamic or greater than 1. static int64_t getNumGtOneDims(ArrayRef shape) { return llvm::count_if( diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 547d120404aba..3eb6215a7a0b9 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" namespace mlir { @@ -131,12 +132,24 @@ getConstantIntValues(ArrayRef ofrs) { return res; } -/// Return true if `ofr` is constant integer equal to `value`. bool isConstantIntValue(OpFoldResult ofr, int64_t value) { auto val = getConstantIntValue(ofr); return val && *val == value; } +bool areAllConstantIntValue(ArrayRef ofrs, int64_t value) { + return llvm::all_of( + ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); }); +} + +bool areConstantIntValues(ArrayRef ofrs, + ArrayRef values) { + if (ofrs.size() != values.size()) + return false; + std::optional> constOfrs = getConstantIntValues(ofrs); + return constOfrs && llvm::equal(constOfrs.value(), values); +} + /// Return true if ofr1 and ofr2 are the same integer constant attribute values /// or the same SSA value. 
/// Ignore integer bitwidth and type mismatch that come from the fact there is diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index e2169fe1404c8..dc4306b8316ab 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso // ----- +// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place +// CHECK-NOT: memref.alloc() +func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) { + %fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32> + } + } {mapping = [#gpu.thread]} + return %3 : tensor<2xf32> +} + +// ----- + // This test case could bufferize in-place with a better analysis. However, it // is simpler to let the canonicalizer fold away the tensor.insert_slice.