From a19e3f21ca0c0360dfdd2438f4031b76df583351 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Thu, 31 Jul 2025 08:50:41 -0400 Subject: [PATCH 1/6] added test and normalization Signed-off-by: Alexandre Eichenberger --- .../NNPA/Compiler/NNPACompilerUtils.cpp | 7 +- .../Conversion/ZHighToZLow/ZHighToZLow.cpp | 5 +- src/Accelerators/NNPA/NNPAAccelerator.cpp | 2 +- src/Accelerators/NNPA/Pass/NNPAPasses.hpp | 3 +- .../Transform/ZLow/ZLowStickExpansion.cpp | 125 ++++++++++++++---- 5 files changed, 110 insertions(+), 32 deletions(-) diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index 07e38a57da..5e97478e0c 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -276,8 +276,11 @@ void addPassesNNPA(mlir::OwningOpRef &module, pm.addPass(zlow::createZLowRewritePass()); // Late generation of code for stick/unstick, needed to be after a // ZLowRewrite pass. - if (!nnpaDisableCompilerStickUnstick) - pm.addPass(zlow::createZLowStickExpansionPass(enableParallel)); + bool expansion = !nnpaDisableCompilerStickUnstick; + bool allocNormalization = isCompatibleWithNNPALevel(NNPALevel::M15); + if (expansion || allocNormalization) + pm.addPass(zlow::createZLowStickOptimizationPass( + expansion, allocNormalization, enableParallel)); pm.addPass(mlir::createCanonicalizerPass()); // Normalize MemRefs. normalizeMemRefsPasses(pm); diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp index ed533d8682..f75141b51b 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp @@ -812,6 +812,9 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern { convertZTensorToMemRefType(*op->result_type_begin()); // Allocate a buffer for the result MemRef. + // Since march=arch15, HW Stick/Unstick may be more efficient with 4k + // allocated pages for CPU data as well, or ensure that allocated data here + // is 4K aligned. Value alloc = nullptr; if (isNHWCLayout(layout)) { if (!nnpaDisableCompilerStickUnstick) { @@ -828,7 +831,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern { alloc = create.mem.alignedAlloc( MemRefType::get({shape[0], shape[2], shape[3], shape[1]}, resType.getElementType()), - dimList); + dimList, gAlignment); } else { // Otherwise, we can directly stickify from NCHW. // Set pre-transformed layout to NCHW. diff --git a/src/Accelerators/NNPA/NNPAAccelerator.cpp b/src/Accelerators/NNPA/NNPAAccelerator.cpp index 46cc0e12b3..2bca57c3e6 100644 --- a/src/Accelerators/NNPA/NNPAAccelerator.cpp +++ b/src/Accelerators/NNPA/NNPAAccelerator.cpp @@ -126,7 +126,7 @@ void NNPAAccelerator::registerPasses(int optLevel) const { }); mlir::registerPass([]() -> std::unique_ptr { - return onnx_mlir::zlow::createZLowStickExpansionPass(); + return onnx_mlir::zlow::createZLowStickOptimizationPass(); }); mlir::registerPass([]() -> std::unique_ptr { diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp index 1c8d6b7012..91cdb6f756 100644 --- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp +++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp @@ -68,7 +68,8 @@ namespace zlow { std::unique_ptr createZLowRewritePass(); /// Add pass for rewriting ZLow ops. -std::unique_ptr createZLowStickExpansionPass( +std::unique_ptr createZLowStickOptimizationPass( + bool enableStickExpansion = true, bool enableAllocNormalization = false, bool enableParallel = false); /// Add pass for rewriting ZLow ops. diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp index d2198d7d3d..d6b6231b90 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp @@ -4,7 +4,7 @@ //===--- ZLowStickExpansion.cpp - ZLow Stick/Unstick Expansion Patterns ---===// // -// Copyright 2024 The IBM Research Authors. +// Copyright 2024-2025 The IBM Research Authors. // // ============================================================================= // @@ -12,6 +12,9 @@ // to stick / unstick with explict code to perform the transformation, when // applicable. // +// This pass also boost the alignment of zlow.stick inputs to 4k when stick are +// sent to NNPA for march=M15. +// //===----------------------------------------------------------------------===// #include "llvm/Support/Debug.h" @@ -26,6 +29,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp" +#include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp" #include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp" #include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" @@ -125,14 +129,57 @@ class UnstickExpansionPattern : public OpRewritePattern { } }; -/// Expand stick operation to compiler generated code for suitable patterns, aka -/// all but the 1D and 2DS data layouts at this time. +// Resize the alignment of the allocate that is the input to Stick Op, which is +// beneficial for stick operations that executes on the NNPA accelerator of +// march = arch15. +class StickInputAllocPattern : public OpRewritePattern { +public: + StickInputAllocPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite( + ZLowStickOp stickOp, PatternRewriter &rewriter) const override { + return resizeAllocOfStickInput(stickOp); + } + + static LogicalResult resizeAllocOfStickInput(ZLowStickOp stickOp) { + fprintf(stderr, "hi alex, looking at defining op of stick\n"); + stickOp.dump(); + memref::AllocOp allocOfXOp = + stickOp.getX().getDefiningOp(); + if (!allocOfXOp) { + fprintf(stderr, "hi alex, do not have an alloc\n"); + return failure(); + } + fprintf(stderr, "hi alex, has alloc\n"); + allocOfXOp.dump(); + auto alignmentAttr = allocOfXOp.getAlignment(); + int64_t intAlign = alignmentAttr ? alignmentAttr.value() : 1; + fprintf(stderr, "alignment attribute is %d\n", (int)intAlign); + if (intAlign >= gAlignment) { + fprintf(stderr, "hi alex, alloc already good as is\n"); + return failure(); + } + fprintf(stderr, "hi alex, increase alignment to 4k\n"); + ::std::optional attrValue(gAlignment); + allocOfXOp.setAlignment(attrValue); + + return success(); + } +}; + +/// Expand stick operation to compiler generated code for suitable patterns, +/// aka all but the 1D and 2DS data layouts at this time. class StickExpansionPattern : public OpRewritePattern { public: - StickExpansionPattern(MLIRContext *context, bool enableParallelism = false) + StickExpansionPattern(MLIRContext *context, bool enableAllocNormalization, + bool enableParallelism) : OpRewritePattern(context, 1), + enableAllocNormalization(enableAllocNormalization), enableParallel(enableParallelism) {} + bool enableAllocNormalization = true; bool enableParallel = true; using OpRewritePattern::OpRewritePattern; @@ -143,8 +190,8 @@ class StickExpansionPattern : public OpRewritePattern { // Generic way to handle all formats listed below. // Did not add the HWCK as this is typically for constants and want to - // preserve the high level constant propagation of constant values into the - // Convolution filters. + // preserve the high level constant propagation of constant values into + // the Convolution filters. if (layout.getValue().equals_insensitive("4D") || layout.getValue().equals_insensitive("3D") || layout.getValue().equals_insensitive("2D") || @@ -152,13 +199,16 @@ class StickExpansionPattern : public OpRewritePattern { layout.getValue().equals_insensitive("NHWC")) { return generateStickCodeNoBuffer(rewriter, stickOp); } - // Otherwise, we don't replace and keep the zdnn call. + // Otherwise, we don't replace and keep the zdnn call. Normalize alloc if + // necessary. + if (enableAllocNormalization) + return StickInputAllocPattern::resizeAllocOfStickInput(stickOp); return failure(); } // Version without buffer, more like zdnn. - // The only requirement for this code to generate the proper code is that E1 - // is been sticked by 64. + // The only requirement for this code to generate the proper code is that + // E1 is been sticked by 64. LogicalResult generateStickCodeNoBuffer( PatternRewriter &rewriter, ZLowStickOp stickOp) const { Operation *op = stickOp.getOperation(); @@ -212,8 +262,8 @@ class StickExpansionPattern : public OpRewritePattern { // If outputDims[E1] is constant and < 64, then T1 is 1 (ok), and we can // iterate over fewer values in the SIMD loop. IndexExpr simdLoopUB = lit64; - // Unrolling of SIMD loop: tried 2 and 8, 4 was best. Max is a const as we - // allocate array of that max size. + // Unrolling of SIMD loop: tried 2 and 8, 4 was best. Max is a const as + // we allocate array of that max size. const int64_t maxUnrollVL = 4; int64_t unrollVL = maxUnrollVL; if (outputDims[E1].isLiteral()) { @@ -248,10 +298,10 @@ class StickExpansionPattern : public OpRewritePattern { } } - // Compute max tiles. It is actually not easy to compute the max number of - // tiles. Since we don't allocate, it is just a "view", we only need to - // index by the "tile size", it is sufficient to assume 2 or more. Tiles are - // 64 elements. + // Compute max tiles. It is actually not easy to compute the max number + // of tiles. Since we don't allocate, it is just a "view", we only need + // to index by the "tile size", it is sufficient to assume 2 or more. + // Tiles are 64 elements. IndexExpr T = LitIE(2); DimsExpr reallocTileDims = {T, lit64}; Value allocAsTx64 = create.mem.reinterpretCast(alloc, reallocTileDims); @@ -337,20 +387,32 @@ class StickExpansionPattern : public OpRewritePattern { /*! * Function pass that optimizes ZLowIR. */ -class ZLowStickExpansionPass - : public PassWrapper> { +class ZLowStickOptimizationPass : public PassWrapper> { public: - ZLowStickExpansionPass(bool enableParallel) - : PassWrapper>(), - enableParallel(enableParallel) {} + ZLowStickOptimizationPass() = default; + ZLowStickOptimizationPass(const ZLowStickOptimizationPass &pass) + : PassWrapper>() {} + ZLowStickOptimizationPass(bool enableStickExpansion, + bool enableAllocNormalization, bool enableParallel) { + this->enableStickExpansion = enableStickExpansion; + this->enableAllocNormalization = enableAllocNormalization; + this->enableParallel = enableParallel; + } - bool enableParallel; + Option enableStickExpansion{*this, "enable-stick-expansion", + llvm::cl::desc("Enable stick expansion"), llvm::cl::init(true)}; + Option enableAllocNormalization{*this, "enable-alloc-normalization", + llvm::cl::desc("enable allocation alignment normalization"), + llvm::cl::init(false)}; + Option enableParallel{*this, "enable-parallel", + llvm::cl::desc("Enable parallelization"), llvm::cl::init(false)}; StringRef getArgument() const override { return "zlow-stick-expansion"; } StringRef getDescription() const override { - return "ZLow Stick/Unstick Ops expansion pass."; + return "ZLow Stick/Unstick Ops optimization pass."; } void runOnOperation() override { @@ -358,16 +420,25 @@ class ZLowStickExpansionPass ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext(), enableParallel); - patterns.insert(&getContext(), enableParallel); - + if (enableStickExpansion) { + // Stick expansion also performs normalization to 4k of allocs. + patterns.insert( + &getContext(), enableAllocNormalization, enableParallel); + patterns.insert(&getContext(), enableParallel); + } else { + // No need to normalize unstick alloc output pattern as its already 4k + // aligned. + patterns.insert(&getContext()); + } if (failed(applyPatternsGreedily(function, std::move(patterns)))) return signalPassFailure(); } }; -std::unique_ptr createZLowStickExpansionPass(bool enableParallel) { - return std::make_unique(enableParallel); +std::unique_ptr createZLowStickOptimizationPass(bool enableStickExpansion, + bool enableAllocNormalization, bool enableParallel) { + return std::make_unique( + enableStickExpansion, enableAllocNormalization, enableParallel); } } // namespace zlow From 005cc62b22fbf6cf1adeb070ac6cb987d68cb9eb Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Thu, 31 Jul 2025 09:03:21 -0400 Subject: [PATCH 2/6] fix Signed-off-by: Alexandre Eichenberger --- .../zlow-stick-alloc-normalization.mlir | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 test/mlir/accelerators/nnpa/transform/zlow-stick-alloc-normalization.mlir diff --git a/test/mlir/accelerators/nnpa/transform/zlow-stick-alloc-normalization.mlir b/test/mlir/accelerators/nnpa/transform/zlow-stick-alloc-normalization.mlir new file mode 100644 index 0000000000..5520481ae5 --- /dev/null +++ b/test/mlir/accelerators/nnpa/transform/zlow-stick-alloc-normalization.mlir @@ -0,0 +1,45 @@ +// RUN: onnx-mlir-opt --march=z17 --maccel=NNPA --zlow-stick-expansion="enable-stick-expansion=false enable-alloc-normalization=true" %s -split-input-file | FileCheck %s + +// ----- + +// No alloc normalization possible for input arguments +#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +func.func @test_no_normalization(%arg0: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { + %alloc = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map> + "zlow.stick"(%arg0, %alloc) {layout = "3DS"} : (memref<16x8x128xf32>, memref<16x8x128xf16, #map>) -> () + return %alloc : memref<16x8x128xf16, #map> + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func.func @test_no_normalization +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { +// CHECK: [[RES_:%.+]] = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map> +// CHECK: } +} +// ----- + +// Here the value being stickified is from an alloc memref in the model, so normalize to 4k. + +#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +func.func @test_normalization(%arg0: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + affine.for %arg1 = 0 to 16 { + affine.for %arg2 = 0 to 8 { + affine.for %arg3 = 0 to 128 { + %0 = affine.load %arg0[%arg1, %arg2, %arg3] : memref<16x8x128xf32> + %1 = math.sin %0 : f32 + affine.store %1, %alloc[%arg1, %arg2, %arg3] : memref<16x8x128xf32> + } + } + } + %alloc1 = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map> + "zlow.stick"(%alloc, %alloc1) {layout = "3DS", no_saturation = -1 : si64} : (memref<16x8x128xf32>, memref<16x8x128xf16, #map>) -> () + return %alloc1 : memref<16x8x128xf16, #map> + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func.func @test_normalization +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { +// CHECK: [[RES_:%.+]] = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf32> +// CHECK: [[RES_1_:%.+]] = memref.alloc() {alignment = 4096 : i64} : memref<16x8x128xf16, #map> +} From ec5cf39839795241f63e7c872fcd3003ff5df037 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Thu, 31 Jul 2025 10:55:05 -0400 Subject: [PATCH 3/6] add debug info Signed-off-by: Alexandre Eichenberger --- .../NNPA/Transform/ZLow/ZLowStickExpansion.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp index d6b6231b90..97084de4a3 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp @@ -143,28 +143,25 @@ class StickInputAllocPattern : public OpRewritePattern { return resizeAllocOfStickInput(stickOp); } + // Shared functions with other classes in this file. static LogicalResult resizeAllocOfStickInput(ZLowStickOp stickOp) { - fprintf(stderr, "hi alex, looking at defining op of stick\n"); - stickOp.dump(); memref::AllocOp allocOfXOp = stickOp.getX().getDefiningOp(); if (!allocOfXOp) { - fprintf(stderr, "hi alex, do not have an alloc\n"); + LLVM_DEBUG(llvm::dbgs() << " stick input had no alloc (parameter)\n"); return failure(); } - fprintf(stderr, "hi alex, has alloc\n"); - allocOfXOp.dump(); auto alignmentAttr = allocOfXOp.getAlignment(); int64_t intAlign = alignmentAttr ? alignmentAttr.value() : 1; fprintf(stderr, "alignment attribute is %d\n", (int)intAlign); if (intAlign >= gAlignment) { - fprintf(stderr, "hi alex, alloc already good as is\n"); + LLVM_DEBUG(llvm::dbgs() << " stick input is properly aligned\n"); return failure(); } - fprintf(stderr, "hi alex, increase alignment to 4k\n"); + LLVM_DEBUG( + llvm::dbgs() << " stick input alignment is too small; fix it\n"); ::std::optional attrValue(gAlignment); allocOfXOp.setAlignment(attrValue); - return success(); } }; From bbfb808a23a905d8bb24b716c938e014c02da846 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Thu, 31 Jul 2025 11:33:19 -0400 Subject: [PATCH 4/6] do the right thing for unstick, alloc to 4k only when using zDNN unstick Signed-off-by: Alexandre Eichenberger --- .../Conversion/ZHighToZLow/ZHighToZLow.cpp | 8 +-- .../Transform/ZLow/ZLowStickExpansion.cpp | 50 ++++++++++++++++--- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp index f75141b51b..e88ed04560 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp @@ -831,7 +831,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern { alloc = create.mem.alignedAlloc( MemRefType::get({shape[0], shape[2], shape[3], shape[1]}, resType.getElementType()), - dimList, gAlignment); + dimList); } else { // Otherwise, we can directly stickify from NCHW. // Set pre-transformed layout to NCHW. @@ -839,8 +839,10 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern { } } if (alloc == nullptr) - alloc = insertAllocForZMemRef( - zMemRefType, shapeHelper.getOutputDims(), op, rewriter); + // Memory for output (which is not in stick format) so use normal + // alignment. + alloc = insertAllocForZMemRef(zMemRefType, shapeHelper.getOutputDims(), + op, rewriter, MemRefBuilder::defaultAlign); // Emit a ZLow operation. rewriter.create(loc, input, alloc, layout); diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp index 97084de4a3..d8f23f28f2 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp @@ -58,14 +58,49 @@ namespace zlow { using MDBuilder = MultiDialectBuilder; +// Resize the alignment of the allocate that is the output to Unstick Op, which +// is beneficial for unstick operations that executes on the NNPA accelerator of +// march = arch15. +class UnstickOutputAllocPattern : public OpRewritePattern { +public: + UnstickOutputAllocPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite( + ZLowUnstickOp unstickOp, PatternRewriter &rewriter) const override { + return resizeAllocOfUnstickOutput(unstickOp); + } + + // Shared functions with other classes in this file. + static LogicalResult resizeAllocOfUnstickOutput(ZLowUnstickOp unstickOp) { + memref::AllocOp allocOfXOp = + unstickOp.getOut().getDefiningOp(); + assert(allocOfXOp && "unstick output should always be allocated"); + auto alignmentAttr = allocOfXOp.getAlignment(); + int64_t intAlign = alignmentAttr ? alignmentAttr.value() : 1; + if (intAlign >= gAlignment) { + LLVM_DEBUG(llvm::dbgs() << " stick input is properly aligned\n"); + return failure(); + } + LLVM_DEBUG( + llvm::dbgs() << " stick input alignment is too small; fix it\n"); + ::std::optional attrValue(gAlignment); + allocOfXOp.setAlignment(attrValue); + return success(); + } +}; + /// Expand unstick operation to compiler generated code for suitable patterns, /// aka all but the 1D and 2DS data layouts at this time. class UnstickExpansionPattern : public OpRewritePattern { public: - UnstickExpansionPattern(MLIRContext *context, bool enableParallelism = false) + UnstickExpansionPattern(MLIRContext *context, bool enableAllocNormalization, + bool enableParallelism) : OpRewritePattern(context, 1), + enableAllocNormalization(enableAllocNormalization), enableParallel(enableParallelism) {} - + bool enableAllocNormalization = true; bool enableParallel = true; using OpRewritePattern::OpRewritePattern; @@ -84,7 +119,10 @@ class UnstickExpansionPattern : public OpRewritePattern { layout.getValue().equals_insensitive("NHWC")) { return generateUnstickCodeNoBuffer(rewriter, unstickOp); } - // Otherwise, we don't replace and keep the zdnn call. + // Otherwise, we don't replace and keep the zdnn call. Normalize alloc if + // necessary. + if (enableAllocNormalization) + return UnstickOutputAllocPattern::resizeAllocOfUnstickOutput(unstickOp); return failure(); } @@ -421,10 +459,10 @@ class ZLowStickOptimizationPass : public PassWrapper( &getContext(), enableAllocNormalization, enableParallel); - patterns.insert(&getContext(), enableParallel); + patterns.insert( + &getContext(), enableAllocNormalization, enableParallel); } else { - // No need to normalize unstick alloc output pattern as its already 4k - // aligned. + patterns.insert(&getContext()); patterns.insert(&getContext()); } if (failed(applyPatternsGreedily(function, std::move(patterns)))) From ad4a2967b263f634d2aa38a948d6009c8aecf1bd Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Mon, 4 Aug 2025 10:18:30 -0400 Subject: [PATCH 5/6] response to comments Signed-off-by: Alexandre Eichenberger --- .../Transform/ZLow/ZLowStickExpansion.cpp | 44 ++++++++----------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp index d8f23f28f2..ca2c50ba0f 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp @@ -52,6 +52,21 @@ using namespace mlir; +// Ensure that alignment is a multiple of gAlignment. +static LogicalResult setAlignOfAlloc(memref::AllocOp allocOp) { + auto alignmentAttr = allocOp.getAlignment(); + int64_t align = alignmentAttr ? alignmentAttr.value() : 1; + if (align != 0 && align % gAlignment == 0) { + LLVM_DEBUG(llvm::dbgs() << " alloc is properly aligned\n"); + return failure(); + } + align = ((align + gAlignment - 1) / gAlignment) * gAlignment; + LLVM_DEBUG(llvm::dbgs() << " alloc not aligned ->" << align << "\n"); + ::std::optional attrValue(align); + allocOp.setAlignment(attrValue); + return success(); +} + namespace onnx_mlir { namespace zlow { @@ -74,20 +89,10 @@ class UnstickOutputAllocPattern : public OpRewritePattern { // Shared functions with other classes in this file. static LogicalResult resizeAllocOfUnstickOutput(ZLowUnstickOp unstickOp) { - memref::AllocOp allocOfXOp = + memref::AllocOp allocOfOutputOp = unstickOp.getOut().getDefiningOp(); - assert(allocOfXOp && "unstick output should always be allocated"); - auto alignmentAttr = allocOfXOp.getAlignment(); - int64_t intAlign = alignmentAttr ? alignmentAttr.value() : 1; - if (intAlign >= gAlignment) { - LLVM_DEBUG(llvm::dbgs() << " stick input is properly aligned\n"); - return failure(); - } - LLVM_DEBUG( - llvm::dbgs() << " stick input alignment is too small; fix it\n"); - ::std::optional attrValue(gAlignment); - allocOfXOp.setAlignment(attrValue); - return success(); + assert(allocOfOutputOp && "unstick output should always be allocated"); + return setAlignOfAlloc(allocOfOutputOp); } }; @@ -189,18 +194,7 @@ class StickInputAllocPattern : public OpRewritePattern { LLVM_DEBUG(llvm::dbgs() << " stick input had no alloc (parameter)\n"); return failure(); } - auto alignmentAttr = allocOfXOp.getAlignment(); - int64_t intAlign = alignmentAttr ? alignmentAttr.value() : 1; - fprintf(stderr, "alignment attribute is %d\n", (int)intAlign); - if (intAlign >= gAlignment) { - LLVM_DEBUG(llvm::dbgs() << " stick input is properly aligned\n"); - return failure(); - } - LLVM_DEBUG( - llvm::dbgs() << " stick input alignment is too small; fix it\n"); - ::std::optional attrValue(gAlignment); - allocOfXOp.setAlignment(attrValue); - return success(); + return setAlignOfAlloc(allocOfXOp); } }; From 3c34027db2549f6d930235d61b8d2dbc6564a69a Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Mon, 4 Aug 2025 10:23:23 -0400 Subject: [PATCH 6/6] abundance of caution Signed-off-by: Alexandre Eichenberger --- src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp index ca2c50ba0f..5042f128aa 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp @@ -60,6 +60,7 @@ static LogicalResult setAlignOfAlloc(memref::AllocOp allocOp) { LLVM_DEBUG(llvm::dbgs() << " alloc is properly aligned\n"); return failure(); } + align = align < 1 ? 1 : align; // Avoid negative number. align = ((align + gAlignment - 1) / gAlignment) * gAlignment; LLVM_DEBUG(llvm::dbgs() << " alloc not aligned ->" << align << "\n"); ::std::optional attrValue(align);