Combine parallel dense Optimization pass in ONNX Dialect #3123
Conversation
@Arkar-Hema A general question: in what kind of models have you seen this pattern (multiple Gemm ops followed by a Concat op), and also the similar patterns you have recently created PRs for? Just curious how practical it is. Thanks!
@tungld could you please verify this patch?
@Arkar-Hema thank you for the information! I have some general comments:
Thanks.
Thanks @Arkar-Hema for the experiments! Did you compile your programs with -O3?
Since this parallel fusion may not work for accelerators, could you create a compile option to enable it when needed, for example `-fuse-parallel-onnx-gemm`?
I don't think you need to handle the case where there is a concat after multiple gemms. Just emit a split op, then later you can write a simple canonicalization rule for concat to fuse `Split -> Concat` (a rough sketch follows below).
Below are my first-round comments; most of them are about simplifying the code and making it easier to follow. The important thing, however, is that you need to check the input C carefully because it's broadcastable.
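For reference, such a canonicalization could look roughly like the sketch below. The accessor names (`getInputs()`, `getAxis()`, `getInput()`) and the header path are assumptions about the generated ONNX op definitions, not code taken from this PR:

```cpp
#include "mlir/IR/PatternMatch.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

// Sketch only: fold Concat(Split(x)) back into x when the Concat simply
// re-assembles all Split results in their original order along the same axis
// (axes assumed to be already normalized to non-negative values).
struct FuseSplitConcatPattern : public OpRewritePattern<ONNXConcatOp> {
  using OpRewritePattern<ONNXConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXConcatOp concatOp, PatternRewriter &rewriter) const override {
    ValueRange inputs = concatOp.getInputs();
    if (inputs.empty())
      return failure();
    // Every Concat input must be a result of one and the same Split op,
    // used in order and exactly once.
    auto splitOp = inputs.front().getDefiningOp<ONNXSplitOp>();
    if (!splitOp || splitOp->getNumResults() != inputs.size())
      return failure();
    for (unsigned i = 0, e = inputs.size(); i < e; ++i)
      if (inputs[i] != splitOp->getResult(i))
        return failure();
    // Axes must match, otherwise Concat(Split(x)) is not the identity.
    if (concatOp.getAxis() != splitOp.getAxis())
      return failure();
    rewriter.replaceOp(concatOp, splitOp.getInput());
    return success();
  }
};
```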
I have added it, thanks.
Hi @Arkar-Hema, when addressing a comment, could you please provide a brief explanation of how you did so? This will make the review process easier. Thanks!
```cpp
auto bOutputShape =
    mlir::cast<ShapedType>(b.getResult().getType()).getShape();
// Output channels is the last dim
if (aOutputShape.back() != bOutputShape.back())
  return false;
```
Both Biases as tensor<1xf32>:
If both biases are of shape tensor<1xf32>, I now check their corresponding Gemm output shapes and ensure their output channels (last dimension) match before considering them compatible. If they differ, the function returns false, as merging them without this check would be invalid.
It does not make sense to me how this can solve the problem. You must check that there is no broadcasting here, i.e. the last dim of the output must also be 1, for example:
```cpp
if (aOutputShape.back() != 1 || bOutputShape.back() != 1)
  return false;
```
Also, please do add a lit test for this case, to make sure the gemm ops are not merged.
```cpp
Type unrankedTensorType = mlir::UnrankedTensorType::get(elementType);
Type newWeightType = unrankedTensorType;
Value newWeight =
    create.onnx.concat(newWeightType, weightValues, concatWeightAxis);
```
Replace `newWeightType` by `unrankedTensorType`. It's redundant to define `newWeightType`.
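That is, the call would simply become (sketch):
```cpp
Value newWeight =
    create.onnx.concat(unrankedTensorType, weightValues, concatWeightAxis);
```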
```cpp
Type newBiasType = unrankedTensorType;
Value newBias = create.onnx.concat(newBiasType, biasValues, 0);
```
Replace `newBiasType` by `unrankedTensorType`. It's redundant to define `newBiasType`.
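Likewise, this reduces to (sketch):
```cpp
Value newBias = create.onnx.concat(unrankedTensorType, biasValues, 0);
```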
```cpp
auto aOutputShape =
    mlir::cast<ShapedType>(a.getResult().getType()).getShape();
auto bOutputShape =
    mlir::cast<ShapedType>(b.getResult().getType()).getShape();
```
Replace these by:
```cpp
ArrayRef<int64_t> aOutputShape = getShape(a.getY().getType());
ArrayRef<int64_t> bOutputShape = getShape(b.getY().getType());
```
```cpp
auto newGemm = rewriter.create<ONNXGemmOp>(loc, newOutputType, input,
    newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
    gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());
```
Please replace this by:
```cpp
Value newGemmOutput = create.onnx.gemm(unrankedTensorType, input,
    newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
    gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());
```
```cpp
    biasValues.push_back(gemm.getC());
  } else {
    auto gemmShape =
        mlir::cast<ShapedType>(gemm.getResult().getType()).getShape();
```
Please use:
```cpp
ArrayRef<int64_t> gemmShape = getShape(gemm.getY().getType());
```
```cpp
ArrayRef<int64_t> splitSizes(splitSizesVec);
ValueRange splitResults = onnx_mlir::emitSplitByChannels(
    rewriter, loc, newGemm.getResult(), splitSizes, splitAxis);
```
Please replace this by:
```cpp
SmallVector<Type, 4> splitTypes(splitSizes.size(), unrankedTensorType);
ValueRange splitResults = create.onnx.split(
    splitTypes, newGemmOutput, create.onnx.constantInt64(splitSizes), splitAxis);
```
```cpp
auto gemmShape =
    mlir::cast<ShapedType>(gemm.getResult().getType()).getShape();
Value zeroBias = create.onnx.constant(DenseElementsAttr::get(
    RankedTensorType::get({gemmShape[splitAxis]}, elementType), 0.0));
```
Check if `gemmShape[splitAxis]` is a static dimension in the `areCompatible()` function. Otherwise, it fails to create a constant tensor here.
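For example, a guard along these lines (a sketch; the variable names are taken from the PR and may differ slightly in `areCompatible()`):
```cpp
// The dimension that will be split on must be static; otherwise the
// zero-bias constant of shape {gemmShape[splitAxis]} cannot be created.
if (ShapedType::isDynamic(gemmShape[splitAxis]))
  return false;
```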
```cpp
    else if (aCShape[0] != bCShape[0])
      return false;
  }
  return true;
```
When aC is None, do check that the last dim of aOutput is static. Otherwise, it fails in the later code where a constant tensor of zeros is created using the last dim of aOutput.
Check the same thing for bC.
Please add a lit test for the case where aC or bC is None.
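One possible form of these checks, assuming the `isNoneValue` helper used elsewhere in onnx-mlir and the shape variables already in this function (sketch):
```cpp
// A zero bias of shape {lastDim} is materialized later when C is None,
// so that dimension must be static for both gemms.
if (isNoneValue(a.getC()) && ShapedType::isDynamic(aOutputShape.back()))
  return false;
if (isNoneValue(b.getC()) && ShapedType::isDynamic(bOutputShape.back()))
  return false;
```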
```cpp
    newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
    gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());

// Check for common ConcatOp
```
Check this earlier, just after you collect all `parallelGemms`. The reason is that you have a `return failure()` here, which may abort the whole rewriting after the new weight, new bias, and new gemm have already been created. Moving this check earlier, before creating any new ops, would keep the IR clean.
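Roughly (the `getCommonConcatOp` helper is hypothetical, only to illustrate the ordering):
```cpp
// Bail out before any new IR is created, so a failed match leaves the
// function untouched.
ONNXConcatOp commonConcat = getCommonConcatOp(parallelGemms); // hypothetical helper
if (!commonConcat)
  return failure();
// ... only now build newWeight, newBias, and the merged Gemm ...
```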
Combine Parallel Dense
CombineParallelDense is an optimization pass designed to merge multiple parallel ONNXGemmOp (Dense/Fully Connected) operations into a single, more efficient Dense layer. This optimization reduces redundant computations, improves memory efficiency, and enhances hardware utilization.
The pass identifies Dense (Gemm) operations that:
- share the same input tensor,
- use the same alpha, beta, transA, and transB attributes,
- have compatible (static, matching) bias and output-channel shapes, and
- feed their results into a common ConcatOp.
Let's assume an input case with three parallel Gemms sharing the same input:
Before Optimization (Three Parallel Gemms)
- Memory reads: 3× the full input (one for each Gemm)
After Optimization (Combined Dense)
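Conceptually, a sketch of the transformation, assuming the parallel Gemms share the same input X, use identical alpha/beta/transA/transB, and (for simplicity) transA = transB = 0:

$$
Y_i = X W_i + b_i \ (i = 1, 2, 3)
\;\;\Longrightarrow\;\;
Y = X\,[\,W_1 \;\; W_2 \;\; W_3\,] + [\,b_1 \;\; b_2 \;\; b_3\,],
\qquad
(Y_1, Y_2, Y_3) = \operatorname{Split}(Y)
$$

- Memory reads: the shared input is read once for the merged Gemm; the three small matmuls become one larger matmul, and the per-branch outputs are recovered with a single Split along the output-channel axis (with transB = 1 the weights would be concatenated along the row axis instead).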
Improvement in performance metrics:
- Latency improvement: 7-15%
- Throughput improvement: 8-14%
- Memory usage improvement: 10-12%