
Conversation

Arkar-Hema
Contributor

Combine Parallel Dense

CombineParallelDense is an optimization pass that merges multiple parallel ONNXGemmOp (Dense / fully connected) operations into a single, more efficient Dense layer. This reduces redundant computation, improves memory efficiency, and enhances hardware utilization.

The pass identifies Dense (Gemm) operations that:

  • Share the same input tensor.
  • Have identical attributes such as alpha, beta, transA and transB (ensuring compatibility).
  • May have different output dimensions (number of neurons) but maintain compatible weight shapes for concatenation.

Let's assume an example input:

  • Input Shape: (1, 512)
  • Dense A: out_features = 256
  • Dense B: out_features = 128
  • Dense C: out_features = 64
  • Attributes: transB = 0, alpha = 1.0, beta = 1.0

Before Optimization (Three Parallel Gemms)

  • Each Gemm performs a full matrix multiplication ((1×512) × (512×N)).
  • Three separate weight and bias tensors are read, and three separate outputs are produced.
    - Memory reads: the full input is read three times (once per Gemm).
  • Post-processing: a Concat(axis=1) merges the three outputs into a single output Y (1×448).

After Optimization (Combined Dense)

  • Total output features: 256 + 128 + 64 = 448
  • All weights are concatenated along the output-channel axis → new weight shape: (512, 448)
  • Biases are concatenated as well → new bias shape: (448)
  • A single ONNXGemmOp computes Y (1×448) directly (see the sketch below)
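
As a sanity check of the arithmetic, here is a minimal numpy sketch of the transformation's math (not the onnx-mlir implementation); shapes follow the example above:

    import numpy as np

    # Shapes from the example above (1x512 input; 256/128/64 output features).
    x = np.random.rand(1, 512).astype(np.float32)
    wa, wb, wc = (np.random.rand(512, n).astype(np.float32) for n in (256, 128, 64))
    ba, bb, bc = (np.random.rand(n).astype(np.float32) for n in (256, 128, 64))

    # Before: three parallel Gemm ops (alpha = beta = 1.0, transB = 0) plus a Concat.
    y_parallel = np.concatenate([x @ wa + ba, x @ wb + bb, x @ wc + bc], axis=1)

    # After: one Gemm whose weight and bias are concatenated along the output-channel axis.
    w_comb = np.concatenate([wa, wb, wc], axis=1)   # (512, 448)
    b_comb = np.concatenate([ba, bb, bc])           # (448,)
    y_combined = x @ w_comb + b_comb                # (1, 448)

    assert np.allclose(y_parallel, y_combined, atol=1e-5)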

Improvement in performance metrics

Latency Improvement: 7-15%
Throughput Improvement: 8-14%
Memory Usage Improvement: 10-12%

@jenkins-droid
Collaborator

Can one of the admins verify this patch?

Signed-off-by: Arkar-Hema <[email protected]>


@tungld
Collaborator

tungld commented Apr 17, 2025

@Arkar-Hema A general question: in what kind of models have you seen this kind of pattern: multiple Gemm ops followed by a Concat op? and also similar patterns you have recently created PRs for? Just curious on how practical it is. Thanks!

@Arkar-Hema
Contributor Author

@Arkar-Hema A general question: in what kind of models have you seen this kind of pattern: multiple Gemm ops followed by a Concat op? and also similar patterns you have recently created PRs for? Just curious on how practical it is. Thanks!

  • Models with the CombineParallelDense pattern (Combine parallel dense Optimization pass in ONNX Dialect #3123):
    These contain multiple Gemm ops, though not always followed by a Concat. I added the Concat condition to the pass so it would still handle those cases gracefully if present. Some models with this pattern include:
  1. Bertsquad-8
  2. Bertsquad-10
  3. Bertsquad-12
  4. FasterRCNN-10

  1. ResNet101-DUC-12
  2. ResNet101-DUC-7
  3. emotion-ferplus models
  4. caffenet models
  5. Densenet models
  6. googlenet models
  7. inception models
  8. rcnn-ilsvrc13 models
  9. resnet models
  10. vgg models

  1. retinanet models
  2. version-RFB-320
  3. version-RFB-640
  4. googlenet models
  5. inception models
  6. resnet models
  7. squeezenet models

@Arkar-Hema
Contributor Author

@tungld could you please verify this patch?

@tungld
Collaborator

tungld commented Apr 22, 2025

@Arkar-Hema thank you for the information!

I have some general comments:

  • I think that when multiple GEMM ops are followed by a concat, the performance should in theory be better. But could you run with multiple input sizes to see how much the performance benefits in practice?
  • When multiple GEMM ops are NOT followed by a concat (the case for the models you listed), you need a split, and I think the split axis is the innermost dimension. I am not sure how slow the split is and whether we can get a speedup or not. Could you do a performance comparison to see if you can achieve a speedup in this case?
  • Are you targeting this optimization for CPU, or is it beneficial for AI accelerators as well, given that AI accelerators may use special data layouts that are not convenient for concat or split?

Thanks.

@Arkar-Hema
Contributor Author

I ran performance benchmarks across a range of input sizes for both the GEMM → Concat and the Combined GEMM → Split cases. Results show that:

  • In the Concat case, the optimization provides a consistent latency improvement of 2-7% and a throughput improvement of 1-5%.
    [benchmark charts]

  • In the cases where it splits, the optimization provides a consistent latency improvement of 1-7% and a throughput improvement of 1-8% (see the sketch below).
    [benchmark charts]

  • I've currently targeted this pass at CPU backends only.
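
To illustrate the Split case numerically, here is a minimal numpy sketch of the math being exercised (not the pass implementation; shapes follow the example in the PR description):

    import numpy as np

    x = np.random.rand(1, 512).astype(np.float32)
    sizes = [256, 128, 64]
    weights = [np.random.rand(512, n).astype(np.float32) for n in sizes]
    biases = [np.random.rand(n).astype(np.float32) for n in sizes]

    # Combined Gemm, then a Split along the innermost (output-channel) axis.
    y_combined = x @ np.concatenate(weights, axis=1) + np.concatenate(biases)
    y_a, y_b, y_c = np.split(y_combined, np.cumsum(sizes)[:-1], axis=1)

    # Each slice matches the corresponding standalone Gemm output.
    for y_part, w, b in zip((y_a, y_b, y_c), weights, biases):
        assert np.allclose(y_part, x @ w + b, atol=1e-5)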

Collaborator

@tungld tungld left a comment


Thanks @Arkar-Hema for the experiments! Did you compile your programs with -O3?

Since this parallel fusion may not work for accelerators, could you create a compile option to enable this if needed, for example -fuse-parallel-onnx-gemm?

I don't think you need to handle the case where there is a concat after multiple gemms. Just emit a split op, then later you can write a simple canonicalization rule for concat to fuse Split -> Concat.
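
For reference, this canonicalization relies on the identity that splitting a tensor and concatenating all of the pieces back along the same axis reproduces the original tensor; a minimal numpy illustration (not onnx-mlir code):

    import numpy as np

    y = np.random.rand(1, 448).astype(np.float32)
    pieces = np.split(y, [256, 384], axis=1)                  # 256/128/64 channels
    assert np.array_equal(np.concatenate(pieces, axis=1), y)  # Concat undoes the Split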

Below are my first-round comments; most of them are for simplifying the code and making it easier to follow. However, the important thing is that you need to check the input C carefully because it is broadcastable.


@Arkar-Hema
Contributor Author

Thanks @Arkar-Hema for the experiments! Did you compile your programs with -O3?

Since this parallel fusion may not work for accelerators, could you create a compile option to enable this if needed, for example -fuse-parallel-onnx-gemm?

I don't think you need to handle the case where there is a concat after multiple gemms. Just emit a split op, then later you can write a simple canonicalization rule for concat to fuse Split -> Concat.

Below are my first-round comments; most of them are for simplifying the code and making it easier to follow. However, the important thing is that you need to check the input C carefully because it is broadcastable.

I have added it. Thanks!

Arkar-Hema added 3 commits May 2, 2025 05:00
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>


Signed-off-by: Arkar-Hema <[email protected]>


@AlexandreEichenberger
Collaborator

@jenkins-droid test this please


Signed-off-by: Arkar-Hema <[email protected]>



Signed-off-by: Arkar-Hema <[email protected]>


@AlexandreEichenberger
Collaborator

@jenkins-droid test this please

@tungld
Collaborator

tungld commented May 12, 2025

Hi @Arkar-Hema, when addressing a comment, could you please provide a brief explanation of how you did so? This will make the review process easier. Thanks!

Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>

@Arkar-Hema Arkar-Hema requested a review from tungld May 19, 2025 03:33
      mlir::cast<ShapedType>(b.getResult().getType()).getShape();
  // Output channels is the last dim
  if (aOutputShape.back() != bOutputShape.back())
    return false;
Collaborator

Both Biases as tensor<1xf32>:
If both biases are of shape tensor<1xf32>, I now check their corresponding Gemm output shapes and ensure their output channels (last dimension) match before considering them compatible. If they differ, the function returns false, as merging them without this check would be invalid.

It does not make sense to me how this can solve the problem. You must check that there is no broadcasting here; for example, the last dim of the output must also be 1:

if (aOutputShape.back() != 1 || bOutputShape.back() != 1)
   return false;

Also, please do add a lit test for this case, to make sure gemm ops are not merged.
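
For intuition, here is a small numpy example (hypothetical shapes) of why two broadcastable tensor<1xf32> biases cannot simply be concatenated when the Gemm outputs have more than one channel:

    import numpy as np

    x = np.random.rand(1, 4).astype(np.float32)
    wa = np.random.rand(4, 3).astype(np.float32)   # Gemm A: 3 output channels
    wb = np.random.rand(4, 2).astype(np.float32)   # Gemm B: 2 output channels
    ba = np.array([0.5], dtype=np.float32)         # bias of shape (1,), broadcast over 3 channels
    bb = np.array([1.5], dtype=np.float32)         # bias of shape (1,), broadcast over 2 channels

    # Original ops: each scalar-like bias broadcasts over its own Gemm's channels.
    y_ref = np.concatenate([x @ wa + ba, x @ wb + bb], axis=1)   # shape (1, 5)

    # Naive merge: the concatenated bias has shape (2,), not (5,), so the original
    # per-Gemm broadcast is lost and the add is not even shape-compatible.
    w_merged = np.concatenate([wa, wb], axis=1)                  # shape (4, 5)
    b_merged = np.concatenate([ba, bb])                          # shape (2,)
    # (x @ w_merged) + b_merged  ->  would raise a broadcasting error here.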

  Type unrankedTensorType = mlir::UnrankedTensorType::get(elementType);
  Type newWeightType = unrankedTensorType;
  Value newWeight =
      create.onnx.concat(newWeightType, weightValues, concatWeightAxis);
Collaborator

Replace newWeightType by unrankedTensorType. It's redundant to define newWeightType.

  }

  Type newBiasType = unrankedTensorType;
  Value newBias = create.onnx.concat(newBiasType, biasValues, 0);
Collaborator

Replace newBiasType by unrankedTensorType. It's redundant to define newBiasType.

  auto aOutputShape =
      mlir::cast<ShapedType>(a.getResult().getType()).getShape();
  auto bOutputShape =
      mlir::cast<ShapedType>(b.getResult().getType()).getShape();
Collaborator

Replace these by:

ArrayRef<int64_t> aOutputShape = getShape(a.getY().getType());
ArrayRef<int64_t> bOutputShape = getShape(b.getY().getType());


  auto newGemm = rewriter.create<ONNXGemmOp>(loc, newOutputType, input,
      newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
      gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());
Collaborator

Please replace this by

Value newGemmOutput = create.onnx.gemm(unrankedTensorType, input,
        newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
        gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());

    biasValues.push_back(gemm.getC());
  } else {
    auto gemmShape =
        mlir::cast<ShapedType>(gemm.getResult().getType()).getShape();
Collaborator

Please use: ArrayRef<int64_t> gemmShape = getShape(gemm.getY().getType());


  ArrayRef<int64_t> splitSizes(splitSizesVec);
  ValueRange splitResults = onnx_mlir::emitSplitByChannels(
      rewriter, loc, newGemm.getResult(), splitSizes, splitAxis);
Collaborator

Please replace this by

    SmallVector<Type, 4> splitTypes(splitSizes.size(), unrankedTensorType);
    ValueRange splitResults = create.onnx.split(
        splitTypes, newGemmOutput, create.onnx.constantInt64(splitSizes), splitAxis);

  auto gemmShape =
      mlir::cast<ShapedType>(gemm.getResult().getType()).getShape();
  Value zeroBias = create.onnx.constant(DenseElementsAttr::get(
      RankedTensorType::get({gemmShape[splitAxis]}, elementType), 0.0));
Collaborator

Check in the areCompatible() function that gemmShape[splitAxis] is a static dimension. Otherwise, creating the constant tensor here fails.

  else if (aCShape[0] != bCShape[0])
    return false;
  }
  return true;
Collaborator

When aC is None, do check that the last dim of aOutput is static. Otherwise, the later code that creates a constant tensor of zeros using the last dim of aOutput will fail.

Check the same thing for bC.

Please add a lit test for the case where aC or bC is None.
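
As a side note on why the static dimension matters, a minimal numpy sketch (hypothetical shapes): a Gemm with C == None behaves like a Gemm with an all-zero bias, but materializing that zero bias requires the output-channel count to be a known constant:

    import numpy as np

    x = np.random.rand(1, 4).astype(np.float32)
    w = np.random.rand(4, 3).astype(np.float32)

    out_channels = w.shape[1]                      # must be static to build the constant
    zero_bias = np.zeros(out_channels, dtype=np.float32)
    assert np.allclose(x @ w, x @ w + zero_bias)   # C == None is equivalent to a zero bias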

      newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
      gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());

  // Check for common ConcatOp
Collaborator

Check this earlier, just after you collect all parallelGemms. The reason is that you have a return failure() here, which may abort the whole rewrite after the new weight, new bias, and new gemm have already been created. Moving this check before creating any new ops keeps the IR clean.
