Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7aa505a
Combine parallel dense optimization pass
Arkar-Hema Apr 16, 2025
997d9e5
Clang format modified
Arkar-Hema Apr 16, 2025
15343bf
Clang format modified
Arkar-Hema Apr 16, 2025
4fad7ea
Added the unit test for the pass
Arkar-Hema Apr 16, 2025
d4d2fda
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 1, 2025
e6d8d6c
Updated test case, added compiler flag, and builder for gemm
Arkar-Hema May 2, 2025
7b357e2
Clang format fix
Arkar-Hema May 2, 2025
ab7a2aa
Clang fix
Arkar-Hema May 2, 2025
2b466ca
Added compiler option
Arkar-Hema May 2, 2025
6aff5ab
Added compiler option in test case
Arkar-Hema May 2, 2025
d8611f5
Test case update
Arkar-Hema May 2, 2025
20cab0c
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 2, 2025
3bf58c0
Merge branch 'main' into combine_parallel_dense
Arkar-Hema May 5, 2025
d07e896
Added lit test for dynamic shapes
Arkar-Hema May 8, 2025
2266d90
Clang format fix
Arkar-Hema May 8, 2025
8882476
Added unrankedtype for outputtype
Arkar-Hema May 8, 2025
c8d2946
Added ranked type for output type
Arkar-Hema May 8, 2025
dd42652
Clang format fix
Arkar-Hema May 8, 2025
ca09e94
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 8, 2025
3f66539
Updated output type
Arkar-Hema May 9, 2025
2f0f113
Updated Compatible function
Arkar-Hema May 13, 2025
c2b3728
clang fix
Arkar-Hema May 13, 2025
81df6d9
Resolved conflicts
Arkar-Hema May 15, 2025
5f132a7
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 16, 2025
92ff3d4
Merge branch 'main' into combine_parallel_dense
chentong319 Sep 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions src/Dialect/ONNX/Transforms/Recompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -38,6 +40,52 @@

using namespace mlir;

namespace onnx_mlir {
// Splits `input` along a static `axis` into multiple outputs whose extents
// along that axis are given by `splitSizes`, using the ONNX Split operation.
//
// Requirements (enforced via assert):
//   - `axis` is in [0, rank) and the corresponding dimension is static.
//   - The split sizes sum to the dimension size along `axis`.
//   - `input` is produced by an op (block arguments have no defining op).
// The Split is inserted right after the op that defines `input`.
ValueRange emitSplitByChannels(PatternRewriter &rewriter, Location loc,
    Value input, ArrayRef<int64_t> splitSizes, int64_t axis) {

  onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(rewriter, loc);
  ShapedType inputType = mlir::cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  ArrayRef<int64_t> inputShape = inputType.getShape();

  // Ensure the axis is within bounds and is a static dimension.
  assert(axis >= 0 && axis < static_cast<int64_t>(inputShape.size()) &&
         "Axis out of bounds for input shape.");
  assert(!inputType.isDynamicDim(axis) &&
         "Channel dimension for input tensor must be static.");

  // Validate split sizes. Accumulate into int64_t: a plain `0` literal would
  // accumulate into int and could overflow for large channel counts.
  int64_t totalChannels = inputShape[axis];
  int64_t sumSplitSizes = std::accumulate(
      splitSizes.begin(), splitSizes.end(), static_cast<int64_t>(0));
  assert(totalChannels == sumSplitSizes &&
         "Split sizes must sum up to the total number of elements along the "
         "axis.");

  // Create the split-sizes constant consumed by ONNXSplitOp.
  Value splitConstant = create.onnx.constantInt64(splitSizes);

  // Create output types for each split part: same shape as the input except
  // along `axis`.
  SmallVector<Type, 4> resultTypes;
  for (int64_t size : splitSizes) {
    SmallVector<int64_t> splitShape(inputShape.begin(), inputShape.end());
    splitShape[axis] = size;
    resultTypes.push_back(RankedTensorType::get(splitShape, elementType));
  }

  // Insert the Split right after the producer of `input`. Guard against
  // block arguments, which have no defining op.
  Operation *definingOp = input.getDefiningOp();
  assert(definingOp && "emitSplitByChannels expects an op-produced value.");
  rewriter.setInsertionPointAfter(definingOp);

  // Perform Split Operation.
  return create.onnx.split(ArrayRef(resultTypes), input, splitConstant, axis);
}

} // namespace onnx_mlir

namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
// #include "src/Dialect/ONNX/Transforms/ONNXRecompose.inc"
Expand Down Expand Up @@ -602,6 +650,145 @@ struct RecomposeQLinearMatMulFromQuantizeLinearPattern
}
};

// Fuses N parallel ONNXGemmOps that share the same A operand and identical
// attributes into a single Gemm whose weight/bias are the per-Gemm weights
// and biases concatenated along the output-feature axis. If all results feed
// one common Concat along that axis, the Concat is folded away; otherwise the
// merged result is split back into the original pieces.
struct CombineParallelDensePattern : public OpRewritePattern<ONNXGemmOp> {
  using OpRewritePattern<ONNXGemmOp>::OpRewritePattern;

  // True when an optional ONNX operand is absent: optional operands are
  // materialized either as a null Value or as a "none"-typed value, so a
  // plain truthiness test on the Value is not sufficient.
  static bool isAbsentBias(Value v) {
    return !v || mlir::isa<NoneType>(v.getType());
  }

  // Helper function to check if two Gemms are mergeable: identical
  // alpha/beta/transA/transB attributes and the same reduction (K) extent on
  // the weight operand. Both weights must be ranked to compare shapes.
  static bool areCompatible(ONNXGemmOp a, ONNXGemmOp b) {
    if (!onnx_mlir::isRankedShapedType(a.getB().getType()) ||
        !onnx_mlir::isRankedShapedType(b.getB().getType()))
      return false;
    return a.getAlpha() == b.getAlpha() && a.getBeta() == b.getBeta() &&
           a.getTransA() == b.getTransA() && a.getTransB() == b.getTransB() &&
           mlir::cast<ShapedType>(a.getB().getType())
                   .getShape()[a.getTransB() ? 1 : 0] ==
               mlir::cast<ShapedType>(b.getB().getType())
                   .getShape()[b.getTransB() ? 1 : 0];
  }

  // If every user of every Gemm in `gemms` is one single ONNXConcatOp whose
  // operands are exactly the results of `gemms` concatenated along
  // `featureAxis`, returns that Concat and permutes `gemms` into the Concat's
  // operand order (the merged weights must follow that order to preserve
  // numerics). Otherwise returns null and leaves `gemms` untouched.
  static ONNXConcatOp findCommonConcat(
      SmallVectorImpl<ONNXGemmOp> &gemms, int64_t featureAxis) {
    ONNXConcatOp common = nullptr;
    for (ONNXGemmOp gemm : gemms) {
      for (Operation *user : gemm.getResult().getUsers()) {
        auto concat = dyn_cast<ONNXConcatOp>(user);
        if (!concat || (common && concat != common))
          return nullptr;
        common = concat;
      }
    }
    if (!common || common.getAxis() != featureAxis ||
        static_cast<size_t>(common->getNumOperands()) != gemms.size())
      return nullptr;
    // Match each Concat operand to a distinct parallel Gemm; any operand
    // that is not one of the parallel results blocks the fusion.
    SmallVector<ONNXGemmOp> ordered;
    SmallVector<bool> used(gemms.size(), false);
    for (Value operand : common->getOperands()) {
      bool matched = false;
      for (size_t i = 0; i < gemms.size(); ++i) {
        if (!used[i] && gemms[i].getResult() == operand) {
          used[i] = true;
          ordered.push_back(gemms[i]);
          matched = true;
          break;
        }
      }
      if (!matched)
        return nullptr;
    }
    gemms.assign(ordered.begin(), ordered.end());
    return common;
  }

  LogicalResult matchAndRewrite(
      ONNXGemmOp gemmOp1, PatternRewriter &rewriter) const final {
    Value input = gemmOp1.getA();
    if (!onnx_mlir::isRankedShapedType(input.getType()) ||
        !mlir::cast<ShapedType>(input.getType()).hasStaticShape())
      return failure();
    if (!onnx_mlir::isRankedShapedType(gemmOp1.getResult().getType()))
      return failure();

    // Collect every compatible Gemm that consumes `input` as its A operand.
    // Checking getA() explicitly is required: getUsers() also yields ops
    // that consume `input` as their B or C operand.
    SmallVector<ONNXGemmOp> parallelGemms = {gemmOp1};
    for (Operation *user : input.getUsers()) {
      ONNXGemmOp currentGemm = dyn_cast<ONNXGemmOp>(user);
      if (currentGemm && currentGemm != gemmOp1 &&
          currentGemm.getA() == input && areCompatible(gemmOp1, currentGemm))
        parallelGemms.push_back(currentGemm);
    }
    if (parallelGemms.size() < 2)
      return failure();

    // Weight axis holding the output features (N) and the matching axis of
    // the Gemm result (the last axis of the 2-D output).
    auto firstWeightType =
        mlir::cast<ShapedType>(parallelGemms[0].getB().getType());
    int64_t concatAxis =
        gemmOp1.getTransB() ? 0 : firstWeightType.getRank() - 1;
    int64_t featureAxis = firstWeightType.getRank() - 1;

    // When all results flow into one common Concat, reorder the Gemms to the
    // Concat's operand order so the fused output needs no reshuffling.
    ONNXConcatOp commonConcatOp = findCommonConcat(parallelGemms, featureAxis);

    Location loc = gemmOp1.getLoc();
    ShapedType inputType = mlir::cast<ShapedType>(input.getType());
    Type elementType = inputType.getElementType();
    onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
        rewriter, loc);

    // Gather weights; every output-feature extent must be static so the
    // merged weight/bias types and the split sizes can be computed.
    SmallVector<Value> weightValues;
    SmallVector<int64_t> outputFeatureSizes;
    SmallVector<int64_t> weightDims(firstWeightType.getShape());
    int64_t totalOutputFeatures = 0;
    for (ONNXGemmOp gemm : parallelGemms) {
      ShapedType weightType = mlir::cast<ShapedType>(gemm.getB().getType());
      if (weightType.isDynamicDim(concatAxis))
        return failure();
      weightValues.push_back(gemm.getB());
      outputFeatureSizes.push_back(weightType.getShape()[concatAxis]);
      totalOutputFeatures += outputFeatureSizes.back();
    }

    weightDims[concatAxis] = totalOutputFeatures;
    Type newWeightType = RankedTensorType::get(weightDims, elementType);
    Value newWeight =
        create.onnx.concat(newWeightType, weightValues, concatAxis);

    // Concatenate biases. A missing bias contributes zeros sized to that
    // Gemm's OWN output features (not the merged total, which would make
    // the concatenated bias too long).
    SmallVector<Value> biasValues;
    for (size_t i = 0; i < parallelGemms.size(); ++i) {
      Value bias = parallelGemms[i].getC();
      if (!isAbsentBias(bias)) {
        biasValues.push_back(bias);
        continue;
      }
      auto biasType =
          RankedTensorType::get({outputFeatureSizes[i]}, elementType);
      // getZeroAttr builds a zero splat of the correct element type
      // (a raw double splat would mismatch non-f64 element types).
      Value zeroBias = create.onnx.constant(rewriter.getZeroAttr(biasType));
      biasValues.push_back(zeroBias);
    }

    SmallVector<int64_t> newBiasShape = {totalOutputFeatures};
    Type newBiasType = RankedTensorType::get(newBiasShape, elementType);
    Value newBias = create.onnx.concat(newBiasType, biasValues, 0);

    // Create the combined Gemm; all attributes are shared by compatibility.
    auto outputShape =
        mlir::cast<ShapedType>(gemmOp1.getResult().getType()).getShape().vec();
    outputShape[featureAxis] = totalOutputFeatures;
    auto newOutputType = RankedTensorType::get(outputShape, elementType);

    auto newGemm = rewriter.create<ONNXGemmOp>(loc, newOutputType, input,
        newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
        gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());

    if (commonConcatOp &&
        commonConcatOp.getResult().getType() == newOutputType) {
      // The Concat merely reassembles the merged result: forward it.
      commonConcatOp.getResult().replaceAllUsesWith(newGemm.getResult());
      rewriter.eraseOp(commonConcatOp);
    } else {
      // Results are consumed separately (or the Concat's type does not
      // match): split the merged output back into the original pieces.
      ValueRange splitResults = onnx_mlir::emitSplitByChannels(
          rewriter, loc, newGemm.getResult(), outputFeatureSizes, featureAxis);
      for (size_t i = 0; i < parallelGemms.size(); ++i)
        parallelGemms[i].replaceAllUsesWith(splitResults[i]);
    }

    for (ONNXGemmOp gemm : parallelGemms)
      rewriter.eraseOp(gemm);

    return success();
  }
};

struct RecomposeONNXToONNXPass
: public PassWrapper<RecomposeONNXToONNXPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RecomposeONNXToONNXPass)
Expand Down Expand Up @@ -682,6 +869,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
patterns.insert<RecomposeGeluFromMulPattern>(context);
patterns.insert<RecomposeLayerNormFromMulPattern>(context);
patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
patterns.insert<CombineParallelDensePattern>(context);
}

/*!
Expand Down
101 changes: 101 additions & 0 deletions test/mlir/onnx/onnx_recompose_combine_parallel_dense.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// RUN: onnx-mlir --useOnnxModelTypes=false --EmitONNXIR --printIR %s | FileCheck %s

// Two parallel Gemms consume the same input %arg0 and their results feed a
// single Concat along axis 1; the recompose pass is expected to fuse them
// into ONE Gemm with concatenated weights/biases and fold the Concat away.
func.func @test_gemm_concat_simple(%arg0: tensor<1x4xf32>) -> tensor<1x6xf32> {
%0 = onnx.Constant dense<[[6.033820e-01, 0.874853491, 0.840596497],
[0.0872995406, 0.490965605, 0.450427264],
[0.750424325, 0.274208099, 0.977319359],
[0.0853121132, 9.420610e-01, 0.892422915]]> : tensor<4x3xf32>
%1 = onnx.Constant dense<[0.626507699, 0.101028912, 0.774093985]> : tensor<3xf32>
%2 = onnx.Constant dense<[[0.845248579, 0.0606110133, 0.115944877],
[0.674885928, 0.550753951, 0.25179252],
[0.331635177, 0.910293042, 9.552980e-01],
[0.119107425, 7.870370e-01, 0.439898729]]> : tensor<4x3xf32>
%3 = onnx.Constant dense<[0.243570983, 0.976932287, 0.137448117]> : tensor<3xf32>
%4 = "onnx.Gemm"(%arg0, %0, %1) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_1", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%5 = "onnx.Gemm"(%arg0, %2, %3) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_2", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%6 = "onnx.Concat"(%4, %5) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x6xf32>
return %6 : tensor<1x6xf32>

// Expected fused form: weights concatenated column-wise (4x3 + 4x3 -> 4x6),
// biases concatenated (3 + 3 -> 6), no Concat op remaining.
// CHECK-LABEL: func @test_gemm_concat_simple
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4xf32>) -> tensor<1x6xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{\[\[6.033820e-01, 0.874853491, 0.840596497, 0.845248579, 0.0606110133, 0.115944877\], \[0.0872995406, 0.490965605, 0.450427264, 0.674885928, 0.550753951, 0.25179252\], \[0.750424325, 0.274208099, 0.977319359, 0.331635177, 0.910293042, 9.552980e-01\], \[0.0853121132, 9.420610e-01, 0.892422915, 0.119107425, 7.870370e-01, 0.439898729\]\]}}> : tensor<4x6xf32>

// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{\[0.626507699, 0.101028912, 0.774093985, 0.243570983, 0.976932287, 0.137448117\]}}> : tensor<6xf32>

// CHECK: [[VAR_2_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]])
// CHECK-SAME: : (tensor<1x4xf32>, tensor<4x6xf32>, tensor<6xf32>) -> tensor<1x6xf32>
// CHECK-NEXT: return [[VAR_2_]] : tensor<1x6xf32>

}

// Six parallel Gemms on the same input feeding one Concat along axis 1: the
// pass must fuse all six into a single Gemm (4x18 weight, 18-wide bias).
// The fused constant contents are not pinned here (dense<{{.*}}>) — only the
// fused types and the disappearance of the Concat are checked.
func.func @test_gemm_concat_complex(%arg0: tensor<1x4xf32>) -> tensor<1x18xf32> {
%0 = onnx.Constant dense<[[0.204779208, 0.695178091, 0.239361823], [0.996994256, 0.601588786, 0.190346241], [0.842002928, 0.739568233, 0.994108259], [0.905652821, 0.834119677, 0.303750187]]> : tensor<4x3xf32>
%1 = onnx.Constant dense<[[0.793336808, 0.967174768, 0.98079878], [0.761894762, 0.102106638, 0.039635919], [0.00603901641, 0.923491775, 0.357523948], [0.696550369, 0.308858335, 0.0805873647]]> : tensor<4x3xf32>
%2 = onnx.Constant dense<[[0.544325054, 0.151464358, 0.934764087], [0.478074521, 0.161221609, 0.71641761], [0.50913018, 0.756769299, 0.904207945], [0.0835523381, 0.918578445, 0.835795641]]> : tensor<4x3xf32>
%3 = onnx.Constant dense<[[0.41472131, 0.492292702, 0.088731639], [0.903954088, 0.128603399, 0.769681036], [0.953823149, 0.836306929, 0.9627828], [0.800210654, 0.308792889, 0.314317614]]> : tensor<4x3xf32>
%4 = onnx.Constant dense<[[0.12443202, 0.226671219, 0.148676723], [0.616570889, 0.962450921, 0.134999171], [0.184063375, 0.764316678, 0.414653629], [0.0643175319, 0.148418352, 0.596157073]]> : tensor<4x3xf32>
%5 = onnx.Constant dense<[[0.391361624, 0.664259791, 0.618797242], [0.672276973, 0.0329957306, 0.00447194278], [0.732442378, 0.597825587, 0.0171195511], [0.568968296, 0.778787076, 0.921517431]]> : tensor<4x3xf32>
%6 = onnx.Constant dense<[0.276767612, 0.952775657, 0.301255673]> : tensor<3xf32>
%7 = onnx.Constant dense<[0.889294981, 0.491430521, 0.142108783]> : tensor<3xf32>
%8 = onnx.Constant dense<[0.790298938, 0.401669294, 0.446535289]> : tensor<3xf32>
%9 = onnx.Constant dense<[0.3797189, 0.496988833, 0.511586726]> : tensor<3xf32>
%10 = onnx.Constant dense<[0.721806407, 0.0192602724, 0.322999328]> : tensor<3xf32>
%11 = onnx.Constant dense<[0.969116449, 4.448790e-01, 0.668284774]> : tensor<3xf32>
%12 = "onnx.Gemm"(%arg0, %0, %6) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_1", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%13 = "onnx.Gemm"(%arg0, %1, %7) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_2", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%14 = "onnx.Gemm"(%arg0, %2, %8) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_3", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%15 = "onnx.Gemm"(%arg0, %3, %9) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_4", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%16 = "onnx.Gemm"(%arg0, %4, %10) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_5", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%17 = "onnx.Gemm"(%arg0, %5, %11) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_6", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%18 = "onnx.Concat"(%12, %13, %14, %15, %16, %17) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x18xf32>
return %18 : tensor<1x18xf32>

// CHECK-LABEL: func @test_gemm_concat_complex
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4xf32>) -> tensor<1x18xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<4x18xf32>

// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<18xf32>

// CHECK: [[VAR_2_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]])
// CHECK-SAME: : (tensor<1x4xf32>, tensor<4x18xf32>, tensor<18xf32>) -> tensor<1x18xf32>
// CHECK-NEXT: return [[VAR_2_]] : tensor<1x18xf32>

}

// Four parallel Gemms whose results are consumed by DIFFERENT activations
// (no common Concat directly after the Gemms): the pass must still fuse the
// Gemms and then re-split the fused output (Split with sizes [3,3,3,3]) so
// each activation receives its original slice.
func.func @test_combine_gemm_split(%arg0: tensor<1x4xf32>) -> tensor<1x12xf32> {
%0 = onnx.Constant dense<[[0.199878812, 0.849797964, 0.269263595], [0.146060213, 0.146481737, 0.573383629], [5.496260e-01, 0.930284262, 0.296700984], [0.888540446, 0.329749823, 0.0487339608]]> : tensor<4x3xf32>
%1 = onnx.Constant dense<[[0.512602746, 0.841705561, 3.472580e-01], [0.985034883, 0.372110397, 0.676640093], [0.366143614, 0.211020753, 0.24549152], [0.7849949, 0.798389971, 0.759396135]]> : tensor<4x3xf32>
%2 = onnx.Constant dense<[[0.0379290208, 0.745854259, 0.249491423], [0.207114503, 0.768784403, 0.183352739], [0.546739817, 0.7326473, 0.610019266], [0.843589544, 0.0109933764, 0.56139493]]> : tensor<4x3xf32>
%3 = onnx.Constant dense<[[0.672199249, 0.756824672, 0.38623023], [0.668579399, 0.284004182, 0.229134396], [0.647052705, 0.809947431, 0.899343073], [0.0700130314, 0.520019472, 0.210815623]]> : tensor<4x3xf32>
%4 = onnx.Constant dense<[0.613018572, 0.517307281, 0.902812659]> : tensor<3xf32>
%5 = onnx.Constant dense<[0.352589607, 0.578843653, 0.101251811]> : tensor<3xf32>
%6 = onnx.Constant dense<[0.930565953, 0.390370637, 0.524582207]> : tensor<3xf32>
%7 = onnx.Constant dense<[0.812823832, 0.946865141, 0.834036648]> : tensor<3xf32>
%8 = "onnx.Gemm"(%arg0, %0, %4) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_1", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%9 = "onnx.Gemm"(%arg0, %1, %5) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_2", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%10 = "onnx.Gemm"(%arg0, %2, %6) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_3", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%11 = "onnx.Gemm"(%arg0, %3, %7) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_4", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%12 = "onnx.Relu"(%8) {onnx_node_name = "ReLU_1"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%13 = "onnx.Sigmoid"(%9) {onnx_node_name = "Sigmoid_2"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%14 = "onnx.Tanh"(%10) {onnx_node_name = "Tanh_3"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%15 = "onnx.LeakyRelu"(%11) {alpha = 0.00999999977 : f32, onnx_node_name = "LeakyReLU_4"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%16 = "onnx.Concat"(%12, %13, %14, %15) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x12xf32>
return %16 : tensor<1x12xf32>

// NOTE(review): the Split result indices bound below (#0, #3, #2, #1) encode
// the pattern's use-list iteration order, which is not obviously guaranteed
// to be stable across MLIR versions — TODO confirm before relying on this
// exact mapping.
// CHECK-LABEL: func @test_combine_gemm_split
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4xf32>) -> tensor<1x12xf32> {
// CHECK: [[CONST_SPLIT_:%.+]] = onnx.Constant dense<3> : tensor<4xi64>
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<4x12xf32>
// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<12xf32>
// CHECK: [[GEMM_OUT_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]])
// CHECK-SAME: : (tensor<1x4xf32>, tensor<4x12xf32>, tensor<12xf32>) -> tensor<1x12xf32>
// CHECK: [[VAR_2_:[^ ]+]]:4 = "onnx.Split"([[GEMM_OUT_]], [[CONST_SPLIT_]]) {axis = 1 : si64, onnx_node_name = "onnx.Split_3"} : (tensor<1x12xf32>, tensor<4xi64>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
// CHECK: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_2_]]#0) {onnx_node_name = "ReLU_1"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_4_:%.+]] = "onnx.Sigmoid"([[VAR_2_]]#3) {onnx_node_name = "Sigmoid_2"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_5_:%.+]] = "onnx.Tanh"([[VAR_2_]]#2) {onnx_node_name = "Tanh_3"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_6_:%.+]] = "onnx.LeakyRelu"([[VAR_2_]]#1) {alpha = 0.00999999977 : f32, onnx_node_name = "LeakyReLU_4"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[FINAL_OUT:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_4_]], [[VAR_5_]], [[VAR_6_]]) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x12xf32>
// CHECK: return [[FINAL_OUT]] : tensor<1x12xf32>


}
Loading