Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,11 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
pm.addPass(zlow::createZLowRewritePass());
// Late generation of code for stick/unstick, needed to be after a
// ZLowRewrite pass.
if (!nnpaDisableCompilerStickUnstick)
pm.addPass(zlow::createZLowStickExpansionPass(enableParallel));
bool expansion = !nnpaDisableCompilerStickUnstick;
bool allocNormalization = isCompatibleWithNNPALevel(NNPALevel::M15);
if (expansion || allocNormalization)
pm.addPass(zlow::createZLowStickOptimizationPass(
expansion, allocNormalization, enableParallel));
pm.addPass(mlir::createCanonicalizerPass());
// Normalize MemRefs.
normalizeMemRefsPasses(pm);
Expand Down
9 changes: 7 additions & 2 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,9 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
convertZTensorToMemRefType(*op->result_type_begin());

// Allocate a buffer for the result MemRef.
// Since march=arch15, HW Stick/Unstick may be more efficient with 4k
// allocated pages for CPU data as well, or ensure that allocated data here
// is 4K aligned.
Value alloc = nullptr;
if (isNHWCLayout(layout)) {
if (!nnpaDisableCompilerStickUnstick) {
Expand All @@ -836,8 +839,10 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
}
}
if (alloc == nullptr)
alloc = insertAllocForZMemRef(
zMemRefType, shapeHelper.getOutputDims(), op, rewriter);
// Memory for output (which is not in stick format) so use normal
// alignment.
alloc = insertAllocForZMemRef(zMemRefType, shapeHelper.getOutputDims(),
op, rewriter, MemRefBuilder::defaultAlign);

// Emit a ZLow operation.
rewriter.create<ZLowUnstickOp>(loc, input, alloc, layout);
Expand Down
2 changes: 1 addition & 1 deletion src/Accelerators/NNPA/NNPAAccelerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void NNPAAccelerator::registerPasses(int optLevel) const {
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return onnx_mlir::zlow::createZLowStickExpansionPass();
return onnx_mlir::zlow::createZLowStickOptimizationPass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
Expand Down
3 changes: 2 additions & 1 deletion src/Accelerators/NNPA/Pass/NNPAPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ namespace zlow {
std::unique_ptr<mlir::Pass> createZLowRewritePass();

/// Add pass for rewriting ZLow ops.
std::unique_ptr<mlir::Pass> createZLowStickExpansionPass(
std::unique_ptr<mlir::Pass> createZLowStickOptimizationPass(
bool enableStickExpansion = true, bool enableAllocNormalization = false,
bool enableParallel = false);

/// Add pass for rewriting ZLow ops.
Expand Down
161 changes: 131 additions & 30 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@

//===--- ZLowStickExpansion.cpp - ZLow Stick/Unstick Expansion Patterns ---===//
//
// Copyright 2024 The IBM Research Authors.
// Copyright 2024-2025 The IBM Research Authors.
//
// =============================================================================
//
// This pass implements optimizations for ZLow operations, by substituting calls
// to stick / unstick with explict code to perform the transformation, when
// applicable.
//
// This pass also boost the alignment of zlow.stick inputs to 4k when stick are
// sent to NNPA for march=M15.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/Debug.h"
Expand All @@ -26,6 +29,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp"
#include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp"
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
Expand All @@ -50,20 +54,61 @@

using namespace mlir;

// Ensure that alignment is a multiple of gAlignment.
static LogicalResult setAlignOfAlloc(memref::AllocOp allocOp) {
auto alignmentAttr = allocOp.getAlignment();
int64_t align = alignmentAttr ? alignmentAttr.value() : 1;
if (align != 0 && align % gAlignment == 0) {
LLVM_DEBUG(llvm::dbgs() << " alloc is properly aligned\n");
return failure();
}
align = align < 1 ? 1 : align; // Avoid negative number.
align = ((align + gAlignment - 1) / gAlignment) * gAlignment;
LLVM_DEBUG(llvm::dbgs() << " alloc not aligned ->" << align << "\n");
::std::optional<uint64_t> attrValue(align);
allocOp.setAlignment(attrValue);
return success();
}

namespace onnx_mlir {
namespace zlow {

using MDBuilder = MultiDialectBuilder<IndexExprBuilderForKrnl, KrnlBuilder,
MathBuilder, MemRefBuilder, VectorBuilder, AffineBuilder, SCFBuilder>;

// Resize the alignment of the allocate that is the output to Unstick Op, which
// is beneficial for unstick operations that executes on the NNPA accelerator of
// march = arch15.
class UnstickOutputAllocPattern : public OpRewritePattern<ZLowUnstickOp> {
public:
UnstickOutputAllocPattern(MLIRContext *context)
: OpRewritePattern<ZLowUnstickOp>(context, 1) {}

using OpRewritePattern<ZLowUnstickOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
ZLowUnstickOp unstickOp, PatternRewriter &rewriter) const override {
return resizeAllocOfUnstickOutput(unstickOp);
}

// Shared functions with other classes in this file.
static LogicalResult resizeAllocOfUnstickOutput(ZLowUnstickOp unstickOp) {
memref::AllocOp allocOfOutputOp =
unstickOp.getOut().getDefiningOp<memref::AllocOp>();
assert(allocOfOutputOp && "unstick output should always be allocated");
return setAlignOfAlloc(allocOfOutputOp);
}
};

/// Expand unstick operation to compiler generated code for suitable patterns,
/// aka all but the 1D and 2DS data layouts at this time.
class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
public:
UnstickExpansionPattern(MLIRContext *context, bool enableParallelism = false)
UnstickExpansionPattern(MLIRContext *context, bool enableAllocNormalization,
bool enableParallelism)
: OpRewritePattern<ZLowUnstickOp>(context, 1),
enableAllocNormalization(enableAllocNormalization),
enableParallel(enableParallelism) {}

bool enableAllocNormalization = true;
bool enableParallel = true;

using OpRewritePattern<ZLowUnstickOp>::OpRewritePattern;
Expand All @@ -82,7 +127,10 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
layout.getValue().equals_insensitive("NHWC")) {
return generateUnstickCodeNoBuffer(rewriter, unstickOp);
}
// Otherwise, we don't replace and keep the zdnn call.
// Otherwise, we don't replace and keep the zdnn call. Normalize alloc if
// necessary.
if (enableAllocNormalization)
return UnstickOutputAllocPattern::resizeAllocOfUnstickOutput(unstickOp);
return failure();
}

Expand Down Expand Up @@ -127,14 +175,43 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
}
};

/// Expand stick operation to compiler generated code for suitable patterns, aka
/// all but the 1D and 2DS data layouts at this time.
// Resize the alignment of the allocate that is the input to Stick Op, which is
// beneficial for stick operations that executes on the NNPA accelerator of
// march = arch15.
class StickInputAllocPattern : public OpRewritePattern<ZLowStickOp> {
public:
StickInputAllocPattern(MLIRContext *context)
: OpRewritePattern<ZLowStickOp>(context, 1) {}

using OpRewritePattern<ZLowStickOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
ZLowStickOp stickOp, PatternRewriter &rewriter) const override {
return resizeAllocOfStickInput(stickOp);
}

// Shared functions with other classes in this file.
static LogicalResult resizeAllocOfStickInput(ZLowStickOp stickOp) {
memref::AllocOp allocOfXOp =
stickOp.getX().getDefiningOp<memref::AllocOp>();
if (!allocOfXOp) {
LLVM_DEBUG(llvm::dbgs() << " stick input had no alloc (parameter)\n");
return failure();
}
return setAlignOfAlloc(allocOfXOp);
}
};

/// Expand stick operation to compiler generated code for suitable patterns,
/// aka all but the 1D and 2DS data layouts at this time.
class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
public:
StickExpansionPattern(MLIRContext *context, bool enableParallelism = false)
StickExpansionPattern(MLIRContext *context, bool enableAllocNormalization,
bool enableParallelism)
: OpRewritePattern<ZLowStickOp>(context, 1),
enableAllocNormalization(enableAllocNormalization),
enableParallel(enableParallelism) {}

bool enableAllocNormalization = true;
bool enableParallel = true;

using OpRewritePattern<ZLowStickOp>::OpRewritePattern;
Expand All @@ -146,22 +223,25 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {

// Generic way to handle all formats listed below.
// Did not add the HWCK as this is typically for constants and want to
// preserve the high level constant propagation of constant values into the
// Convolution filters.
// preserve the high level constant propagation of constant values into
// the Convolution filters.
if (layout.getValue().equals_insensitive("4D") ||
layout.getValue().equals_insensitive("3D") ||
layout.getValue().equals_insensitive("2D") ||
layout.getValue().equals_insensitive("3DS") ||
layout.getValue().equals_insensitive("NHWC")) {
return generateStickCodeNoBuffer(rewriter, stickOp);
}
// Otherwise, we don't replace and keep the zdnn call.
// Otherwise, we don't replace and keep the zdnn call. Normalize alloc if
// necessary.
if (enableAllocNormalization)
return StickInputAllocPattern::resizeAllocOfStickInput(stickOp);
return failure();
}

// Version without buffer, more like zdnn.
// The only requirement for this code to generate the proper code is that E1
// is been sticked by 64.
// The only requirement for this code to generate the proper code is that
// E1 is been sticked by 64.
LogicalResult generateStickCodeNoBuffer(
PatternRewriter &rewriter, ZLowStickOp stickOp) const {
Operation *op = stickOp.getOperation();
Expand Down Expand Up @@ -215,8 +295,8 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
// If outputDims[E1] is constant and < 64, then T1 is 1 (ok), and we can
// iterate over fewer values in the SIMD loop.
IndexExpr simdLoopUB = lit64;
// Unrolling of SIMD loop: tried 2 and 8, 4 was best. Max is a const as we
// allocate array of that max size.
// Unrolling of SIMD loop: tried 2 and 8, 4 was best. Max is a const as
// we allocate array of that max size.
const int64_t maxUnrollVL = 4;
int64_t unrollVL = maxUnrollVL;
if (outputDims[E1].isLiteral()) {
Expand Down Expand Up @@ -244,10 +324,10 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
loopDefs, lbs, ubs, 0, rank, {}, /*min iter for going parallel*/ 8);
}

// Compute max tiles. It is actually not easy to compute the max number of
// tiles. Since we don't allocate, it is just a "view", we only need to
// index by the "tile size", it is sufficient to assume 2 or more. Tiles are
// 64 elements.
// Compute max tiles. It is actually not easy to compute the max number
// of tiles. Since we don't allocate, it is just a "view", we only need
// to index by the "tile size", it is sufficient to assume 2 or more.
// Tiles are 64 elements.
IndexExpr T = LitIE(2);
DimsExpr reallocTileDims = {T, lit64};
Value allocAsTx64 = create.mem.reinterpretCast(alloc, reallocTileDims);
Expand Down Expand Up @@ -333,37 +413,58 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
/*!
* Function pass that optimizes ZLowIR.
*/
class ZLowStickExpansionPass
: public PassWrapper<ZLowStickExpansionPass, OperationPass<func::FuncOp>> {
class ZLowStickOptimizationPass : public PassWrapper<ZLowStickOptimizationPass,
OperationPass<func::FuncOp>> {

public:
ZLowStickExpansionPass(bool enableParallel)
: PassWrapper<ZLowStickExpansionPass, OperationPass<func::FuncOp>>(),
enableParallel(enableParallel) {}
ZLowStickOptimizationPass() = default;
ZLowStickOptimizationPass(const ZLowStickOptimizationPass &pass)
: PassWrapper<ZLowStickOptimizationPass, OperationPass<func::FuncOp>>() {}
ZLowStickOptimizationPass(bool enableStickExpansion,
bool enableAllocNormalization, bool enableParallel) {
this->enableStickExpansion = enableStickExpansion;
this->enableAllocNormalization = enableAllocNormalization;
this->enableParallel = enableParallel;
}

bool enableParallel;
Option<bool> enableStickExpansion{*this, "enable-stick-expansion",
llvm::cl::desc("Enable stick expansion"), llvm::cl::init(true)};
Option<bool> enableAllocNormalization{*this, "enable-alloc-normalization",
llvm::cl::desc("enable allocation alignment normalization"),
llvm::cl::init(false)};
Option<bool> enableParallel{*this, "enable-parallel",
llvm::cl::desc("Enable parallelization"), llvm::cl::init(false)};

StringRef getArgument() const override { return "zlow-stick-expansion"; }

StringRef getDescription() const override {
return "ZLow Stick/Unstick Ops expansion pass.";
return "ZLow Stick/Unstick Ops optimization pass.";
}

void runOnOperation() override {
Operation *function = getOperation();

ConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
patterns.insert<StickExpansionPattern>(&getContext(), enableParallel);
patterns.insert<UnstickExpansionPattern>(&getContext(), enableParallel);

if (enableStickExpansion) {
// Stick expansion also performs normalization to 4k of allocs.
patterns.insert<StickExpansionPattern>(
&getContext(), enableAllocNormalization, enableParallel);
patterns.insert<UnstickExpansionPattern>(
&getContext(), enableAllocNormalization, enableParallel);
} else {
patterns.insert<UnstickOutputAllocPattern>(&getContext());
patterns.insert<StickInputAllocPattern>(&getContext());
}
if (failed(applyPatternsGreedily(function, std::move(patterns))))
return signalPassFailure();
}
};

std::unique_ptr<Pass> createZLowStickExpansionPass(bool enableParallel) {
return std::make_unique<ZLowStickExpansionPass>(enableParallel);
std::unique_ptr<Pass> createZLowStickOptimizationPass(bool enableStickExpansion,
bool enableAllocNormalization, bool enableParallel) {
return std::make_unique<ZLowStickOptimizationPass>(
enableStickExpansion, enableAllocNormalization, enableParallel);
}

} // namespace zlow
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: onnx-mlir-opt --march=z17 --maccel=NNPA --zlow-stick-expansion="enable-stick-expansion=false enable-alloc-normalization=true" %s -split-input-file | FileCheck %s

// -----

// No alloc normalization possible for input arguments
#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>
func.func @test_no_normalization(%arg0: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> {
%alloc = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map>
"zlow.stick"(%arg0, %alloc) {layout = "3DS"} : (memref<16x8x128xf32>, memref<16x8x128xf16, #map>) -> ()
return %alloc : memref<16x8x128xf16, #map>

// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>
// CHECK-LABEL: func.func @test_no_normalization
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> {
// CHECK: [[RES_:%.+]] = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map>
// CHECK: }
}
// -----

// Here the value being stickified is from an alloc memref in the model, so normalize to 4k.

#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>
func.func @test_normalization(%arg0: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> {
%alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32>
affine.for %arg1 = 0 to 16 {
affine.for %arg2 = 0 to 8 {
affine.for %arg3 = 0 to 128 {
%0 = affine.load %arg0[%arg1, %arg2, %arg3] : memref<16x8x128xf32>
%1 = math.sin %0 : f32
affine.store %1, %alloc[%arg1, %arg2, %arg3] : memref<16x8x128xf32>
}
}
}
%alloc1 = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map>
"zlow.stick"(%alloc, %alloc1) {layout = "3DS", no_saturation = -1 : si64} : (memref<16x8x128xf32>, memref<16x8x128xf16, #map>) -> ()
return %alloc1 : memref<16x8x128xf16, #map>

// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>
// CHECK-LABEL: func.func @test_normalization
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> {
// CHECK: [[RES_:%.+]] = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf32>
// CHECK: [[RES_1_:%.+]] = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map>
}
Loading