Commits (25)
7aa505a
Combine parallel dense optimization pass
Arkar-Hema Apr 16, 2025
997d9e5
Clang format modified
Arkar-Hema Apr 16, 2025
15343bf
Clang format modified
Arkar-Hema Apr 16, 2025
4fad7ea
Added the unit test for the pass
Arkar-Hema Apr 16, 2025
d4d2fda
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 1, 2025
e6d8d6c
Updated test case, added compiler flag, and builder for gemm
Arkar-Hema May 2, 2025
7b357e2
Clang format fix
Arkar-Hema May 2, 2025
ab7a2aa
Clang fix
Arkar-Hema May 2, 2025
2b466ca
Added compiler option
Arkar-Hema May 2, 2025
6aff5ab
Added compiler option in test case
Arkar-Hema May 2, 2025
d8611f5
Test case updation
Arkar-Hema May 2, 2025
20cab0c
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 2, 2025
3bf58c0
Merge branch 'main' into combine_parallel_dense
Arkar-Hema May 5, 2025
d07e896
Added lit test for dynamic shapes
Arkar-Hema May 8, 2025
2266d90
Clang format fix
Arkar-Hema May 8, 2025
8882476
Added unrankedtype for outputtype
Arkar-Hema May 8, 2025
c8d2946
Added ranked type for output type
Arkar-Hema May 8, 2025
dd42652
Clang format fix
Arkar-Hema May 8, 2025
ca09e94
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 8, 2025
3f66539
Updated output type
Arkar-Hema May 9, 2025
2f0f113
Updated Compatible function
Arkar-Hema May 13, 2025
c2b3728
clang fix
Arkar-Hema May 13, 2025
81df6d9
Resolved conflicts
Arkar-Hema May 15, 2025
5f132a7
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 16, 2025
92ff3d4
Merge branch 'main' into combine_parallel_dense
chentong319 Sep 24, 2025
8 changes: 8 additions & 0 deletions src/Compiler/CompilerOptions.cpp
@@ -96,6 +96,7 @@ OptReport optReport; // onnx-mlir only
bool useOldBufferization; // onnx-mlir only
bool enableTiming; // onnx-mlir only
bool enableBoundCheck; // onnx-mlir only
bool fuseParallelOnnxGemm; // onnx-mlir only
bool split_input_file; // onnx-mlir-opt only
bool verify_diagnostics; // onnx-mlir-opt only
bool verify_passes; // onnx-mlir-opt only
@@ -721,6 +722,13 @@ static llvm::cl::opt<bool, true> enable_bound_check("enable-bound-check",
llvm::cl::location(enableBoundCheck), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> fuse_parallel_onnx_gemm(
"fuse-parallel-onnx-gemm",
llvm::cl::desc("Enable Combine parallel dense layers (default=false)."),
llvm::cl::location(
fuseParallelOnnxGemm), // Link directly to the existing variable
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));
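
For reference, the flag is off by default and is enabled on the command line. Mirroring the RUN line of the new lit test, a typical invocation that exercises the rewrite would be (the input file name here is illustrative):

onnx-mlir --useOnnxModelTypes=false --fuse-parallel-onnx-gemm --EmitONNXIR --printIR model.mlir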

/*
How to use the optional optimization for testing.

1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
@@ -143,6 +143,7 @@ extern bool useOldBufferization; // onnx-mlir only
extern bool enableTiming; // onnx-mlir only
extern bool enableBoundCheck; // onnx-mlir only
extern bool debugTestCompilerOpt; // onnx-mlir only
extern bool fuseParallelOnnxGemm; // onnx-mlir only

extern bool split_input_file; // onnx-mlir-opt only
extern bool verify_diagnostics; // onnx-mlir-opt only
7 changes: 7 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.cpp
@@ -175,6 +175,13 @@ Value OnnxBuilder::gelu(Value input, StringAttr approximateAttr) const {
toTensor(input.getType()), input, approximateAttr);
}

Value OnnxBuilder::gemm(Type Y, Value A, Value B, Value C, FloatAttr alpha,
FloatAttr beta, IntegerAttr transA, IntegerAttr transB) const {

return createOpAndInferShapes<ONNXGemmOp>(
toTensor(Y), A, B, C, alpha, beta, transA, transB);
}
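
As a usage sketch (not part of this diff): the new helper could replace the direct rewriter.create<ONNXGemmOp> call in CombineParallelDensePattern further below, with all names taken from that pattern:

// Hedged sketch; newOutputType, input, newWeight, newBias, and gemmOp1 are
// the values built in CombineParallelDensePattern (see Recompose.cpp below).
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(rewriter, loc);
Value combined = create.onnx.gemm(newOutputType, input, newWeight, newBias,
    gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(), gemmOp1.getTransAAttr(),
    gemmOp1.getTransBAttr());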

// ONNXLayerNormalizationOp, version with one output only (Y).
Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
Value bias, int64_t axis, FloatAttr epsilon) const {
7 changes: 6 additions & 1 deletion src/Dialect/ONNX/DialectBuilder.hpp
@@ -98,6 +98,11 @@ struct OnnxBuilder : DialectBuilder {
// ONNXGeluOp
mlir::Value gelu(mlir::Value input, mlir::StringAttr approximateAttr) const;

// ONNXGemmOp
mlir::Value gemm(mlir::Type Y, mlir::Value A, mlir::Value B, mlir::Value C,
mlir::FloatAttr alpha, mlir::FloatAttr beta, mlir::IntegerAttr transA,
mlir::IntegerAttr transB) const;

// ONNXLayerNormalizationOp, version with one output only (Y).
mlir::Value layerNorm(mlir::Type outputType, mlir::Value input,
mlir::Value scale, mlir::Value bias, int64_t axis,
@@ -118,7 +123,7 @@
mlir::Value scale, mlir::Value bias, int64_t axis,
mlir::FloatAttr epsilon) const;

// ONNXMatMulOp or ONNXGemmOp
// ONNXMatMulOp
mlir::Value matmul(
mlir::Type Y, mlir::Value A, mlir::Value B, bool useGemm = false) const;

209 changes: 209 additions & 0 deletions src/Dialect/ONNX/Transforms/Recompose.cpp
@@ -20,12 +20,15 @@
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
@@ -38,6 +41,52 @@

using namespace mlir;

namespace onnx_mlir {
// Splits a tensor along a static axis into multiple outputs, based on the
// specified channel sizes, using the ONNX Split operation.
ValueRange emitSplitByChannels(PatternRewriter &rewriter, Location loc,
Value input, ArrayRef<int64_t> splitSizes, int64_t axis) {

onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(rewriter, loc);
ShapedType inputType = mlir::cast<ShapedType>(input.getType());
Type elementType = inputType.getElementType();
ArrayRef<int64_t> inputShape = inputType.getShape();

// Ensure the axis is within bounds and is a static dimension
assert(axis < static_cast<int64_t>(inputShape.size()) && axis >= 0 &&
"Axis out of bounds for input shape.");

assert(!inputType.isDynamicDim(axis) &&
"Channel dimension for input tensor must be static.");
// Validate split sizes
int64_t totalChannels = inputShape[axis];
int64_t sumSplitSizes =
std::accumulate(splitSizes.begin(), splitSizes.end(), int64_t(0));

assert(totalChannels == sumSplitSizes &&
"Split sizes must sum to the size of the input along the split "
"axis.");

// Create Split Constant
Value splitConstant = create.onnx.constantInt64(splitSizes);

// Create output types for each split part
SmallVector<Type, 4> resultTypes;
for (int64_t size : splitSizes) {
SmallVector<int64_t> splitShape(inputShape.begin(), inputShape.end());
splitShape[axis] = size;
resultTypes.push_back(RankedTensorType::get(splitShape, elementType));
}
rewriter.setInsertionPointAfter(input.getDefiningOp());
// Perform Split Operation
ValueRange results =
create.onnx.split(ArrayRef(resultTypes), input, splitConstant, axis);

return results;
}

} // namespace onnx_mlir
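
A usage sketch for the helper above (shapes are illustrative): splitting a tensor<1x12xf32> value evenly along axis 1 yields three tensor<1x4xf32> results.

// Hedged sketch; 'combined' is assumed to be a Value of type tensor<1x12xf32>
// and 'rewriter'/'loc' come from the enclosing rewrite pattern.
SmallVector<int64_t, 4> sizes = {4, 4, 4};
ValueRange parts =
    onnx_mlir::emitSplitByChannels(rewriter, loc, combined, sizes, /*axis=*/1);
// parts[0], parts[1], and parts[2] each have type tensor<1x4xf32>.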

namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
// #include "src/Dialect/ONNX/Transforms/ONNXRecompose.inc"
@@ -602,6 +651,163 @@ struct RecomposeQLinearMatMulFromQuantizeLinearPattern
}
};

struct CombineParallelDensePattern : public OpRewritePattern<ONNXGemmOp> {
using OpRewritePattern<ONNXGemmOp>::OpRewritePattern;

// Helper function to check whether two Gemm ops can be merged.
static bool areCompatible(ONNXGemmOp a, ONNXGemmOp b) {
if (a.getAlpha() != b.getAlpha() || a.getBeta() != b.getBeta() ||
a.getTransA() != b.getTransA() || a.getTransB() != b.getTransB())
return false;

auto aBShape = mlir::cast<ShapedType>(a.getB().getType()).getShape();
auto bBShape = mlir::cast<ShapedType>(b.getB().getType()).getShape();
int64_t axis = a.getTransB() ? 1 : 0;
if (aBShape[axis] != bBShape[axis])
return false;

// Check C compatibility — only allow None or 1D
Value aC = a.getC();
Value bC = b.getC();
if (!onnx_mlir::isNoneValue(aC) && !onnx_mlir::isNoneValue(bC)) {
auto aCShape = mlir::cast<ShapedType>(aC.getType()).getShape();
auto bCShape = mlir::cast<ShapedType>(bC.getType()).getShape();
if (aCShape.size() != 1 || bCShape.size() != 1)
return false;
}
return true;
}
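
For example, with transA = transB = 0 and matching alpha/beta, two Gemms whose B operands have shapes 4x3 and 4x5 are compatible (they share the reduction size 4 on axis 0), whereas B shapes of 4x3 and 5x3 are not. A bias (C) may be absent on either op, but any bias that is present must be 1-D.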

LogicalResult matchAndRewrite(
ONNXGemmOp gemmOp1, PatternRewriter &rewriter) const final {
Value input = gemmOp1.getA();
if (!onnx_mlir::isRankedShapedType(input.getType()) ||
!mlir::cast<ShapedType>(input.getType()).hasStaticShape())
return failure();

SmallVector<ONNXGemmOp> parallelGemms = {gemmOp1};

for (auto user : input.getUsers()) {
ONNXGemmOp currentGemm = dyn_cast<ONNXGemmOp>(user);
if (currentGemm && currentGemm != gemmOp1 &&
areCompatible(gemmOp1, currentGemm)) {
parallelGemms.push_back(currentGemm);
}
}
if (parallelGemms.size() < 2)
return failure();

Location loc = gemmOp1.getLoc();
ShapedType inputType = mlir::cast<ShapedType>(input.getType());
Type elementType = inputType.getElementType();
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
rewriter, loc);

// Weights are concatenated along the output dimension of B (which depends
// on transB); Gemm outputs are always concatenated along axis 1.
int64_t concatAxis = gemmOp1.getTransB() ? 0 : 1;
int64_t Axis = 1;

// Concatenate weights
SmallVector<Value> weightValues;

for (auto gemm : parallelGemms) {
weightValues.push_back(gemm.getB());
}

Type newWeightType = mlir::UnrankedTensorType::get(elementType);
Value newWeight =
create.onnx.concat(newWeightType, weightValues, concatAxis);

// Concatenate biases (create zero constants for missing biases)
SmallVector<Value> biasValues;
for (auto gemm : parallelGemms) {
if (!onnx_mlir::isNoneValue(gemm.getC())) {
biasValues.push_back(gemm.getC());
} else {
auto biasShape =
mlir::cast<ShapedType>(gemm.getResult().getType()).getShape();
Value zeroBias = create.onnx.constant(DenseElementsAttr::get(
RankedTensorType::get({biasShape[Axis]}, elementType), 0.0));
biasValues.push_back(zeroBias);
}
}

Type newBiasType = mlir::UnrankedTensorType::get(elementType);
Value newBias = create.onnx.concat(newBiasType, biasValues, 0);

// Create combined Gemm operation
SmallVector<int64_t, 2> newOutputShape(
mlir::cast<ShapedType>(parallelGemms[0].getResult().getType())
.getShape());

// Sum output channels from parallel gemms
int64_t totalOutputChannels = 0;
for (auto gemm : parallelGemms) {
int64_t outCh =
mlir::cast<ShapedType>(gemm.getResult().getType()).getShape()[Axis];
totalOutputChannels += outCh;
}
newOutputShape[Axis] = totalOutputChannels;
auto newOutputType = RankedTensorType::get(newOutputShape, elementType);

auto newGemm = rewriter.create<ONNXGemmOp>(loc, newOutputType, input,
newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());

// Check whether all parallel Gemm results feed a single common ConcatOp.
ONNXConcatOp commonConcatOp = nullptr;
for (auto gemm : parallelGemms) {
for (auto user : gemm.getResult().getUsers()) {
if (auto concatOp = dyn_cast<ONNXConcatOp>(user)) {
if (!commonConcatOp) {
commonConcatOp = concatOp;
}
if (concatOp != commonConcatOp) {
commonConcatOp = nullptr;
break;
}
} else {
commonConcatOp = nullptr;
break;
}
}
if (!commonConcatOp) {
break;
}
}

if (commonConcatOp) {
commonConcatOp.getResult().replaceAllUsesWith(newGemm.getResult());
rewriter.eraseOp(commonConcatOp);
} else {
SmallVector<int64_t, 4> splitSizesVec;
for (auto gemm : parallelGemms) {
int64_t outputChannels =
mlir::cast<ShapedType>(gemm.getResult().getType()).getShape()[Axis];
splitSizesVec.push_back(outputChannels);
}

ArrayRef<int64_t> splitSizes(splitSizesVec);
ValueRange splitResults = onnx_mlir::emitSplitByChannels(
rewriter, loc, newGemm.getResult(), splitSizes, Axis);

for (size_t i = 0; i < parallelGemms.size(); ++i) {
parallelGemms[i].replaceAllUsesWith(splitResults[i]);
}
}

for (auto gemm : parallelGemms) {
if (gemm.getResult().use_empty()) {
rewriter.eraseOp(gemm);
}
}

return success();
}
};
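
In summary: when every parallel Gemm feeds one and the same Concat, the pattern folds the whole subgraph into the single combined Gemm; otherwise it inserts a Split after the combined Gemm so each original consumer still receives its own slice. Both shapes of the rewrite are exercised by the lit tests added below.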

struct RecomposeONNXToONNXPass
: public PassWrapper<RecomposeONNXToONNXPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RecomposeONNXToONNXPass)
@@ -682,6 +888,9 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
patterns.insert<RecomposeGeluFromMulPattern>(context);
patterns.insert<RecomposeLayerNormFromMulPattern>(context);
patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
if (fuseParallelOnnxGemm) {
patterns.insert<CombineParallelDensePattern>(context);
}
}

/*!
61 changes: 61 additions & 0 deletions test/mlir/onnx/onnx_recompose_combine_parallel_dense.mlir
@@ -0,0 +1,61 @@
// RUN: onnx-mlir --useOnnxModelTypes=false --fuse-parallel-onnx-gemm --EmitONNXIR --printIR %s | FileCheck %s

func.func @test_gemm_concat_simple(%arg0: tensor<1x4xf32>) -> tensor<1x6xf32> {
%0 = onnx.Constant dense<5.5>: tensor<4x3xf32>
%1 = onnx.Constant dense<0.2> : tensor<3xf32>
%2 = onnx.Constant dense<4.5>: tensor<4x3xf32>
%3 = onnx.Constant dense<0.5> : tensor<3xf32>
%4 = "onnx.Gemm"(%arg0, %0, %1) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_1", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%5 = "onnx.Gemm"(%arg0, %2, %3) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_2", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%6 = "onnx.Concat"(%4, %5) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x6xf32>
return %6 : tensor<1x6xf32>

// CHECK-LABEL: func @test_gemm_concat_simple
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4xf32>) -> tensor<1x6xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<4x6xf32>

// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<6xf32>

// CHECK: [[VAR_2_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]])
// CHECK-SAME: : (tensor<1x4xf32>, tensor<4x6xf32>, tensor<6xf32>) -> tensor<1x6xf32>
// CHECK-NEXT: return [[VAR_2_]] : tensor<1x6xf32>

}

func.func @test_combine_gemm_split(%arg0: tensor<1x4xf32>) -> tensor<1x12xf32> {
%0 = onnx.Constant dense<1.6> : tensor<4x3xf32>
%1 = onnx.Constant dense<2.7> : tensor<4x3xf32>
%2 = onnx.Constant dense<3.7> : tensor<4x3xf32>
%3 = onnx.Constant dense<4.6> : tensor<4x3xf32>
%4 = onnx.Constant dense<0.1> : tensor<3xf32>
%5 = onnx.Constant dense<0.9> : tensor<3xf32>
%6 = onnx.Constant dense<0.2> : tensor<3xf32>
%7 = onnx.Constant dense<0.8> : tensor<3xf32>
%8 = "onnx.Gemm"(%arg0, %0, %4) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_1", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%9 = "onnx.Gemm"(%arg0, %1, %5) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_2", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%10 = "onnx.Gemm"(%arg0, %2, %6) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_3", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%11 = "onnx.Gemm"(%arg0, %3, %7) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_4", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%12 = "onnx.Relu"(%8) {onnx_node_name = "ReLU_1"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%13 = "onnx.Sigmoid"(%9) {onnx_node_name = "Sigmoid_2"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%14 = "onnx.Tanh"(%10) {onnx_node_name = "Tanh_3"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%15 = "onnx.LeakyRelu"(%11) {alpha = 0.00999999977 : f32, onnx_node_name = "LeakyReLU_4"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%16 = "onnx.Concat"(%12, %13, %14, %15) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x12xf32>
return %16 : tensor<1x12xf32>

// CHECK-LABEL: func @test_combine_gemm_split
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4xf32>) -> tensor<1x12xf32> {
// CHECK: [[CONST_SPLIT_:%.+]] = onnx.Constant dense<3> : tensor<4xi64>
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<4x12xf32>
// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<12xf32>
// CHECK: [[GEMM_OUT_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]])
// CHECK-SAME: : (tensor<1x4xf32>, tensor<4x12xf32>, tensor<12xf32>) -> tensor<1x12xf32>
// CHECK: [[VAR_2_:[^ ]+]]:4 = "onnx.Split"([[GEMM_OUT_]], [[CONST_SPLIT_]]) {axis = 1 : si64, onnx_node_name = "onnx.Split_2"} : (tensor<1x12xf32>, tensor<4xi64>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
// CHECK: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_2_]]#0) {onnx_node_name = "ReLU_1"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_4_:%.+]] = "onnx.Sigmoid"([[VAR_2_]]#3) {onnx_node_name = "Sigmoid_2"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_5_:%.+]] = "onnx.Tanh"([[VAR_2_]]#2) {onnx_node_name = "Tanh_3"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_6_:%.+]] = "onnx.LeakyRelu"([[VAR_2_]]#1) {alpha = 0.00999999977 : f32, onnx_node_name = "LeakyReLU_4"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[FINAL_OUT:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_4_]], [[VAR_5_]], [[VAR_6_]]) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x12xf32>
// CHECK: return [[FINAL_OUT]] : tensor<1x12xf32>


}