
Conversation

@Rohanjames1997 (Contributor) commented Dec 19, 2025

Description

This PR adds a BF16 (bfloat16) pointwise convolution kernel for ARM64 NCHWc format, leveraging the existing SBGEMM infrastructure. When the mlas.enable_gemm_fastmath_arm64_bfloat16 session option is enabled on supported ARM64 Linux hardware, Pointwise Conv is rerouted to use this BF16 implementation. This is an opt-in feature, similar to how BF16 matmul is opt-in.

Added a bool ZeroMode field to MLAS_SBGEMM_DATA_PARAMS (default true for backward compatibility) to enable per-batch control over output accumulation. This mirrors the beta parameter in FP32's MlasGemmBatch and is required for Pointwise convolutions with >128 input channels, where multiple GEMM calls must accumulate into the same output buffer.

Motivation and Context

The existing mlas.enable_gemm_fastmath_arm64_bfloat16 session option accelerates MatMul operations on ARM64 processors with BF16 support, but convolution operations did not benefit from this optimization. Pointwise convolutions (1x1 kernels) are essentially batched matrix multiplications.

This change extends the BF16 fastmath optimization to pointwise NCHWc convolutions, reusing the same session option. The implementation mirrors the FP32 pointwise kernel structure while delegating the actual computation to SBGEMM, ensuring correctness and maintainability.

Performance improvement

Measured a 15-20% throughput gain on MobileNet inference on an AWS Graviton4 instance (~16.5% in the runs below).

Before (FP32)

./build/Linux/Release/onnxruntime_perf_test -C "mlas.enable_gemm_fastmath_arm64_bfloat16|0" -x 32 -I -m times -r 2000 ~/scripts/mobilenet.onnx

Number of inferences per second: 559.154

After (BF16)

./build/Linux/Release/onnxruntime_perf_test -C "mlas.enable_gemm_fastmath_arm64_bfloat16|1" -x 32 -I -m times -r 2000 ~/scripts/mobilenet.onnx

Number of inferences per second: 651.221

@Rohanjames1997 (Contributor, Author):

@hariharans29 another PR that's up your alley.

Can you request a preliminary review from Copilot & run CI?

Thanks!

@hariharans29 requested a review from Copilot, December 21, 2025 05:23
@hariharans29 (Member):

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines:

Azure Pipelines successfully started running 4 pipeline(s).

Copilot AI (Contributor) left a comment:

Pull request overview

This PR extends BF16 (bfloat16) precision optimization support to pointwise (1x1) NCHWc convolutions on ARM64 Linux platforms. The implementation leverages the existing SBGEMM infrastructure and the mlas.enable_gemm_fastmath_arm64_bfloat16 session option, delivering a reported 15-20% performance improvement on Mobilenet inference.

Key changes:

  • Adds BF16 pointwise convolution kernel (MlasConvPointwiseBf16KernelNeon) that delegates computation to SBGEMM
  • Introduces ZeroMode field to MLAS_SBGEMM_DATA_PARAMS to enable accumulation control across multiple GEMM calls
  • Routes pointwise convolutions to BF16 implementation when fastmath mode is enabled on supported hardware

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.

Summary per file:

  • onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp: new BF16 pointwise convolution kernel implementation using SBGEMM batch operations
  • onnxruntime/core/mlas/inc/mlas.h: adds the UseBf16 parameter to the MlasNchwcConv API and the ZeroMode field to MLAS_SBGEMM_DATA_PARAMS
  • onnxruntime/core/mlas/lib/sbgemm.h: propagates the ZeroMode parameter through SBGEMM packed/non-packed operations
  • onnxruntime/core/mlas/lib/snchwc.cpp: adds the UseBf16 parameter and conditional BF16 kernel selection logic
  • onnxruntime/core/mlas/lib/mlasi.h: declares MlasConvPointwiseBf16KernelNeon and adds ConvPointwiseBf16Kernel to the platform struct
  • onnxruntime/core/mlas/lib/platform.cpp: initializes the BF16 kernel pointer in ARM64 NEON platform initialization
  • onnxruntime/contrib_ops/cpu/nchwc_ops.h: adds fastmath mode detection in the constructor and a member variable
  • onnxruntime/contrib_ops/cpu/nchwc_ops.cc: passes the BF16 flag to MlasNchwcConv based on session options
  • cmake/onnxruntime_mlas.cmake: adds the new source file with ARM BF16 compilation flags


@Rohanjames1997 (Contributor, Author) commented Dec 22, 2025

Thanks @hariharans29 !
Looks like the failures are due to inconsistent ifdefs(?). I'm looking into it.
Do let me know if you have ideas too, but I may need you to rerun CI a few times more after I push fixes.

@Rohanjames1997 (Contributor, Author):

Since SBGEMM is compiled only on Linux, I have enabled this BF16 pointwise Conv kernel only on Linux as well.

Can the CI be run again? 🤞

@aviralagrawal:

Exciting stuff. Looking forward to seeing this merged.

@Rohanjames1997 (Contributor, Author):

Thanks @aviralagrawal!

@hariharans29 gentle reminder 🙂

@hariharans29 (Member):

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines:

Azure Pipelines successfully started running 4 pipeline(s).

@hariharans29 (Member):

> Thanks @aviralagrawal!
>
> @hariharans29 gentle reminder 🙂

Hi - sorry, I was OOF. Kicked off CI and will review this PR soon. Thanks!

@Rohanjames1997 (Contributor, Author) left a comment:

Thanks for the review!
I've replied to the comments. I will wait for your reply before adding any more commits.

Meanwhile I am looking at the current CI failure.

NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
  ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
#if defined(__aarch64__) && defined(__linux__)
  auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16);
(Member) commented:

Does this need a default value?

@Rohanjames1997 (Contributor, Author) replied Jan 14, 2026:

Probably not.

It is assigned a value inside onnxruntime_session_options_config_keys.h

static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

In turn, enable_gemm_fastmath_arm64_bfloat16 is controlled via a runtime flag (and defaults to false).

(Member) replied:

Unless I am missing something, GetConfigEntry() returns a std::optional which may be empty, and that needs to be accounted for. GetConfigOrDefault may be a better way of handling that.

https://github.com/microsoft/onnxruntime/blob/9659a858808654ffb6a34a77016fb735fdf5d44f/onnxruntime/core/framework/config_options.h?plain=1#L27C1-L27C91

@Rohanjames1997 (Contributor, Author) replied Jan 15, 2026:

I see. I reused the logic from line 34 in matmul.h.

My question is: wouldn't this line run without fail (and hence always have a value)?

@Rohanjames1997 (Contributor, Author):

Thanks for the reviews @hariharans29 & @yuslepukhin!

Could you run CI again please?

Meanwhile I will try to figure out how to add tests for SBConv specifically.

@hariharans29 (Member):

> Thanks for the reviews @hariharans29 & @yuslepukhin!
>
> Could you run CI again please?
>
> Meanwhile I will try to figure out how to add tests for SBConv specifically.

Thanks @Rohanjames1997! Sorry - I will respond to your comments later today. Running CI now.

@hariharans29 (Member):

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines:

Azure Pipelines successfully started running 4 pipeline(s).

@hariharans29 (Member):

You can ignore the failing CUDA checks. They just need #27020 to go in.

@Rohanjames1997 (Contributor, Author):

Phew... Thanks 😅

@hariharans29 (Member):

Can you rebase as well? You may need some changes from main for some checks to pass.

@Rohanjames1997 (Contributor, Author):

Done. Ready for CI when you are. Feel free to merge if the CI passes 🤞

@hariharans29 (Member):

> Done. Ready for CI when you are. Feel free to merge if the CI passes 🤞

Thank you! I will take one final look this evening and merge should everything look fine. Thanks again for the contribution!

@hariharans29 (Member):

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines:

Azure Pipelines successfully started running 4 pipeline(s).

@hariharans29 (Member) commented Jan 16, 2026:

Left a couple of comments - can you please address them when you get a chance? The rest of the code looks fine! Thank you again!

@hariharans29 (Member):

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines:

Azure Pipelines successfully started running 4 pipeline(s).
