-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Introducing BF16 Pointwise NCHWc Convolution for Arm64 #26838
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@hariharans29 another PR that's up your alley. Can you request a preliminary review from Copilot & run CI? Thanks! |
|
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline |
|
Azure Pipelines successfully started running 4 pipeline(s). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR extends BF16 (bfloat16) precision optimization support to pointwise (1x1) NCHWc convolutions on ARM64 Linux platforms. The implementation leverages the existing SBGEMM infrastructure and the mlas.enable_gemm_fastmath_arm64_bfloat16 session option, delivering a reported 15-20% performance improvement on Mobilenet inference.
Key changes:
- Adds BF16 pointwise convolution kernel (
MlasConvPointwiseBf16KernelNeon) that delegates computation to SBGEMM - Introduces
ZeroModefield toMLAS_SBGEMM_DATA_PARAMSto enable accumulation control across multiple GEMM calls - Routes pointwise convolutions to BF16 implementation when fastmath mode is enabled on supported hardware
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp |
New BF16 pointwise convolution kernel implementation using SBGEMM batch operations |
onnxruntime/core/mlas/inc/mlas.h |
Adds UseBf16 parameter to MlasNchwcConv API and ZeroMode field to MLAS_SBGEMM_DATA_PARAMS |
onnxruntime/core/mlas/lib/sbgemm.h |
Propagates ZeroMode parameter through SBGEMM packed/non-packed operations |
onnxruntime/core/mlas/lib/snchwc.cpp |
Adds UseBf16 parameter and conditional BF16 kernel selection logic |
onnxruntime/core/mlas/lib/mlasi.h |
Declares MlasConvPointwiseBf16KernelNeon and adds ConvPointwiseBf16Kernel to platform struct |
onnxruntime/core/mlas/lib/platform.cpp |
Initializes BF16 kernel pointer in ARM64 NEON platform initialization |
onnxruntime/contrib_ops/cpu/nchwc_ops.h |
Adds fastmath mode detection in constructor and member variable |
onnxruntime/contrib_ops/cpu/nchwc_ops.cc |
Passes BF16 flag to MlasNchwcConv based on session options |
cmake/onnxruntime_mlas.cmake |
Adds new source file with ARM BF16 compilation flags |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Thanks @hariharans29 ! |
|
Since SBGemm is compiled only on linux, I have enabled this BF16 Pointwise Conv kernel only on linux as well. Can the CI be run again? 🤞 |
|
Exciting stuff. Looking forward to seeing this merged. |
|
Thanks @aviralagrawal! @hariharans29 gentle reminder 🙂 |
|
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline |
|
Azure Pipelines successfully started running 4 pipeline(s). |
Hi - Sorry , I was OOF. Kicked off CI and will review this PR soon. Thanks! |
Rohanjames1997
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review!
I've replied to the comments. I will wait for your reply before adding any more commits.
Meanwhile I am looking at the current CI failure.
| NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { | ||
| ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); | ||
| #if defined(__aarch64__) && defined(__linux__) | ||
| auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably not.
It is assigned a value inside onnxruntime_session_options_config_keys.h
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
In turn,enable_gemm_fastmath_arm64_bfloat16 is controlled via a runtime flag (and defaults to false).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unless I am missing something GetConfigEntry() returns std::optional which may be not there, and that needs to be accounted for. GetConfigOrDefault may be a better way of handling that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
Thanks for the reviews @hariharans29 & @yuslepukhin! Could you run CI again please? Meanwhile I will try to figure out how to add tests for SBConv specifically. |
Thanks @Rohanjames1997 ! Sorry - I will respond to your comments later today. Running CI now. |
|
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline |
|
Azure Pipelines successfully started running 4 pipeline(s). |
|
You can ignore the failing CUDA checks. They just need #27020 to go in. |
|
Phew... Thanks 😅 |
|
Can you rebase as well ? You may need some changes in main for some checks to pass |
|
Done. Ready for CI when you are. Feel free to merge if the CI passes 🤞 |
Thank you ! I will take one final look this evening and merge should everything look fine. Thanks again for the contribution ! |
|
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline |
|
Azure Pipelines successfully started running 4 pipeline(s). |
|
Left behind a couple of comments - can you please address them when you get a chance ? The rest of the code looks fine ! Thank you again ! |
|
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline |
|
Azure Pipelines successfully started running 4 pipeline(s). |
Description
This PR adds a BF16 (bfloat16) pointwise convolution kernel for ARM64 NCHWc format, leveraging the existing SBGEMM infrastructure. When the
mlas.enable_gemm_fastmath_arm64_bfloat16session option is enabled on supported ARM64 Linux hardware, Pointwise Conv is rerouted to use this BF16 implementation. This is an opt-in feature, similar to how BF16 matmul is opt-in.Added a bool ZeroMode field to
MLAS_SBGEMM_DATA_PARAMS(defaulttruefor backward compatibility) to enable per-batch control over output accumulation. This mirrors the beta parameter in FP32'sMlasGemmBatchand is required for Pointwise convolutions with >128 input channels, where multiple GEMM calls must accumulate into the same output buffer.Motivation and Context
The existing
mlas.enable_gemm_fastmath_arm64_bfloat16session option accelerates MatMul operations on ARM64 processors with BF16 support, but convolution operations did not benefit from this optimization. Pointwise convolutions (1x1 kernels) are essentially batched matrix multiplications.This change extends the BF16 fastmath optimization to pointwise NCHWc convolutions, reusing the same session option. The implementation mirrors the FP32 pointwise kernel structure while delegating the actual computation to SBGEMM, ensuring correctness and maintainability.
Performance improvement
Measured a 15-20% gain on Mobilenet inference on an AWS Graviton4 instance.
Before (FP32)
After (BF16)