mlas/arm64: add NEON conv asm kernels and tune NCHWC kernel selection #27099
Conversation
Signed-off-by: Milos Puzovic <[email protected]>
Interesting contribution - thank you! A few questions -
Hi @aviralagrawal, thank you very much for your prompt feedback.
Compared to the direct GEMM implementation of pointwise convolution, the asm kernel computes the 1x1 conv directly:
As usual there are trade-offs, so the direct GEMM path can be faster in a few cases: when the output count is small, the asm kernel drops to its single-output path, which has less ILP and cannot reuse filter loads; non-unit strides produce non-contiguous output regions, which is why the heuristics check stride width and height; and for very large K/M, GEMM blocking can make better use of caches than a fixed 4-output tile. This is best illustrated by extracting the pointwise convolutions from the MobileNet model we ran: on average the asm implementation is 1.07x faster, and the significant speed-ups occur when the number of channels is high and K/M are small (in the image those are the H and W dimensions). In convolution-heavy networks the dominant convolutions are the ones with a large number of channels and low height and width, so we see visible performance improvements, as the optimisations in this PR are weighted in that direction.
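To make the filter-reuse argument concrete, here is a minimal C++ sketch (purely illustrative, not the actual MLAS kernel; the function name, layout, and tile size are assumptions) of a direct pointwise convolution with a 4-output tile and a single-output tail path:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of a direct 1x1 (pointwise) convolution. Processing
// four outputs per iteration lets each loaded filter value feed four
// accumulators, which is the reuse the 4-output tile exploits; the
// single-output tail has no such reuse and less ILP.
static void PointwiseConvDirect(const float* input,   // [C][HW]
                                const float* filter,  // [M][C]
                                float* output,        // [M][HW]
                                size_t C, size_t M, size_t HW) {
    for (size_t m = 0; m < M; ++m) {
        size_t x = 0;
        for (; x + 4 <= HW; x += 4) {  // 4-output tile: each filter load used 4x
            float acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
            for (size_t c = 0; c < C; ++c) {
                const float f = filter[m * C + c];  // one load, four uses
                acc0 += f * input[c * HW + x + 0];
                acc1 += f * input[c * HW + x + 1];
                acc2 += f * input[c * HW + x + 2];
                acc3 += f * input[c * HW + x + 3];
            }
            output[m * HW + x + 0] = acc0;
            output[m * HW + x + 1] = acc1;
            output[m * HW + x + 2] = acc2;
            output[m * HW + x + 3] = acc3;
        }
        for (; x < HW; ++x) {  // single-output tail: no filter-load reuse
            float acc = 0;
            for (size_t c = 0; c < C; ++c)
                acc += filter[m * C + c] * input[c * HW + x];
            output[m * HW + x] = acc;
        }
    }
}
```

When `HW` is small the loop spends all its time in the tail, which is the regime where the GEMM path described above tends to win.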
For benchmarking we used the model from: https://github.com/onnx/models/blob/main/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx
Thanks @milpuz01 for the detailed description & comment! A couple questions from my side:
Hi @Rohanjames1997, thank you very much for your comments.
No particular reason, mostly because the focus for this PR was on the MobileNet model and lack of bandwidth. Thank you for sharing the model where
Yes, I think that is a great idea, and it would be interesting to hear from @hariharans29 too what other testing we should do to try to make these kernels the default. As you can see above, this change is not going to accelerate all possible pointwise convolutions, but on average it shows improvements, so if we could agree on a set of performance targets we could use them to drive the decision. Also, thank you for your code review comments; I will address them in a separate commit.
Unfortunately, I don't have a comprehensive list of performance targets to be met to make the feature default. Since the performance testing may not include all possible Conv shapes, I would like to err on the side of caution and at least provide a one-release heads-up to users before considering making the feature default. I would also encourage you to open a discussion to solicit feedback from other ORT users on ARM about whether they see speed-ups for their models with this feature. It would provide greater confidence and a strong data point for turning it on by default. Thanks for this contribution, we will review it shortly!
Signed-off-by: Milos Puzovic <[email protected]>
Thanks @hariharans29. I agree with erring on the side of caution. If this PR goes through and makes it into the main release, is it possible to add a note that we would like to make
Thanks @milpuz01. The PR should go through in main eventually but I don't think it will go in 1.24.0 unfortunately as the release branch is cut and the bar to take in new code at this point is critical bug fixes and urgent customer asks only. I will try to take this in for 1.24.1 when it happens and sure I will add a note about considering making it default in one of the future releases, but ultimately, as discussed in the comment #27099 (comment), I expect the NchwcFloatKernel needs optimizations before considering that.
Pull request overview
Adds new AArch64 NEON assembly micro-kernels for NCHW, depthwise NCHWc, and pointwise NCHWc convolution, integrates them into the MLAS build, and updates NCHWc kernel-selection heuristics to prefer the asm kernels in selected shapes.
Changes:
- Add new AArch64 `.S` convolution micro-kernels (NCHW, depthwise NCHWc, pointwise NCHWc) and wire them into the MLAS build.
- Update ARM64 platform init and NCHWc execution heuristics to select asm kernels for pointwise (stride-1, larger tiles) and depthwise (wider outputs).
- Remove the old intrinsics wrapper for the NCHW float kernel in the NCHWc NEON source file.
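As an illustration of the kind of shape-based gating described above, here is a minimal C++ sketch of a selector; the struct fields, function name, and thresholds are assumptions for illustration only, not the actual logic in snchwc.cpp:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical convolution-shape descriptor (illustrative only).
struct ConvShape {
    size_t StrideH, StrideW;
    size_t OutputCount;  // spatial outputs per row
    size_t K, M;         // reduction depth and output channels
};

// Sketch of the gating idea: route to the asm pointwise kernel only for
// stride-1 shapes with enough outputs to fill the 4-output tile, and fall
// back to the GEMM-style path when K/M are large enough that GEMM cache
// blocking wins. The threshold value is an assumption, not the real one.
static bool UsePointwiseAsmKernel(const ConvShape& s) {
    if (s.StrideH != 1 || s.StrideW != 1) return false;  // non-unit stride: GEMM path
    if (s.OutputCount < 4) return false;                 // tail-only: no filter reuse
    const size_t kLargeGemm = 4096;                      // assumed cutoff
    if (s.K >= kLargeGemm && s.M >= kLargeGemm) return false;  // blocking wins
    return true;
}
```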
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| cmake/onnxruntime_mlas.cmake | Adds new AArch64 asm sources to the ARM NEON NCHWc MLAS build setup. |
| onnxruntime/core/mlas/lib/snchwc.cpp | Adds ARM64 heuristics to prefer asm depthwise/pointwise kernels in “safe” cases. |
| onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp | Removes the old NCHW float kernel wrapper implementation from the NCHWc NEON source file. |
| onnxruntime/core/mlas/lib/platform.cpp | Switches ARM64 NCHW conv kernel default to asm; updates commentary around kernel choices. |
| onnxruntime/core/mlas/lib/mlasi.h | Declares new asm kernel entry points for ARM64 NEON NCHWc. |
| onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S | Adds new NCHW convolution asm micro-kernel. |
| onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S | Adds new depthwise NCHWc asm micro-kernel (fast/slow path for padding). |
| onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S | Adds new pointwise NCHWc asm micro-kernel (multi-output reuse). |
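The "fast/slow path for padding" split noted for SconvDepthwiseKernelNeon.S can be illustrated with a scalar C++ sketch (single channel, 3x3 filter, pad=1; purely illustrative, not the actual kernel, which operates on NCHWc channel blocks):

```cpp
#include <cassert>
#include <cstddef>

// Interior outputs, whose receptive field lies fully inside the input, take
// a branch-free "fast" path; border outputs take a "slow" path that tests
// each tap against the input bounds (zero padding).
static void DepthwiseConv3x3(const float* input, size_t H, size_t W,
                             const float* filter /* [3][3] */,
                             float* output /* H x W, pad=1, stride=1 */) {
    for (size_t oy = 0; oy < H; ++oy) {
        for (size_t ox = 0; ox < W; ++ox) {
            float acc = 0;
            const bool interior = oy >= 1 && oy + 1 < H && ox >= 1 && ox + 1 < W;
            if (interior) {
                // Fast path: no bounds checks inside the 3x3 loop.
                for (int ky = -1; ky <= 1; ++ky)
                    for (int kx = -1; kx <= 1; ++kx)
                        acc += filter[(ky + 1) * 3 + (kx + 1)] *
                               input[(oy + ky) * W + (ox + kx)];
            } else {
                // Slow path: skip taps that fall into the zero padding.
                for (int ky = -1; ky <= 1; ++ky)
                    for (int kx = -1; kx <= 1; ++kx) {
                        const long iy = (long)oy + ky, ix = (long)ox + kx;
                        if (iy < 0 || iy >= (long)H || ix < 0 || ix >= (long)W)
                            continue;
                        acc += filter[(ky + 1) * 3 + (kx + 1)] * input[iy * W + ix];
                    }
            }
            output[oy * W + ox] = acc;
        }
    }
}
```

Keeping the bounds checks out of the interior loop is what makes the fast path profitable for the wider outputs the heuristics target.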
    // prologue -------------------------------------------------------------
    stp x19,x20,[sp,#-.LFrame_SavedRegs]!
    stp x21,x22,[sp,#.LFrame_x21_x22]
    stp x23,x24,[sp,#.LFrame_x23_x24]
    stp x25,x26,[sp,#.LFrame_x25_x26]
Copilot AI · Jan 26, 2026
MlasConvNchwFloatKernelNeonAsm uses v8–v15 (e.g., in CONV2_NOPAD/POSTPROCESS paths) but the prologue only saves x19–x28. Per this repo’s AArch64 assembly convention/ABI notes (e.g., aarch64/ConvSymU8KernelNeon.S saves d8–d15), this function must preserve callee-saved FP/SIMD regs (at least d8–d15) when it clobbers them. Please save/restore d8–d15 (or refactor to only use caller-saved v registers) to avoid corrupting caller state.
    this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon;
    // Prefer the hand written micro-kernel for the NCHW convolution path. It
    // offers a tighter schedule and a specialised two-output inner loop that
    // reduces pressure on the memory system compared
Copilot AI · Jan 26, 2026
The comment is incomplete (ends with “compared”) and doesn’t form a full sentence. Please finish the sentence (e.g., “compared to the intrinsics kernel”) so the rationale is clear.
Suggested change:

    - // reduces pressure on the memory system compared
    + // reduces pressure on the memory system compared to the intrinsics-based implementation.
    // reduces memory traffic. The AArch64 assembly kernel is picked up by
    // heuristics in platform.cpp to avoid regressions on small convolutions.
Copilot AI · Jan 26, 2026
This comment says the AArch64 pointwise asm kernel is selected by “heuristics in platform.cpp”, but the selection logic is in snchwc.cpp. Please update the comment to point to the correct location (or move the heuristic into platform dispatch if that’s the intent) to avoid confusion for future maintainers.
Suggested change:

    - // reduces memory traffic. The AArch64 assembly kernel is picked up by
    - // heuristics in platform.cpp to avoid regressions on small convolutions.
    + // reduces memory traffic. The AArch64 assembly kernel is selected by
    + // heuristics in snchwc.cpp to avoid regressions on small convolutions.
    // Compute the base pointers for this filter block.
    madd x16,x14,x8,x2      // output pointer for this filter
    madd x17,x14,x7,x1      // filter pointer for this filter
    add x18,x10,x14,lsl #6  // bias pointer (if used)
Copilot AI · Jan 26, 2026
This kernel uses x18 as a general-purpose register (e.g., bias pointer at line 139). On some AArch64 ABIs x18 is reserved as the platform register (the repo even notes this in aarch64/HalfGemmKernelNeon.S:22). Since setup_arm_neon_nchwc() doesn’t exclude these sources on Apple, enabling onnxruntime_USE_ARM_NEON_NCHWC on such platforms could clobber x18 and break TLS/platform state. Please avoid using x18 (pick another temp) or add platform-specific guarding in CMake so this file isn’t built where x18 is reserved.
Suggested change:

    - add x18,x10,x14,lsl #6 // bias pointer (if used)
    + add x21,x10,x14,lsl #6 // bias pointer (if used)

Overview
This PR adds ARM64 NEON assembly micro‑kernels for NCHW, depthwise, and pointwise convolution, wires them into the MLAS build, and adds shape‑based selection heuristics for NCHWC depthwise/pointwise to favor the asm kernels in safe cases (stride‑1 pointwise; wider depthwise outputs). The BF16 path is unchanged.
Key changes
Performance
Numbers below are expressed as multipliers vs the non‑NCHWC baseline (same model and perf_test settings):
- Baseline (no `--enable_arm_neon_nchwc`)
- With `--enable_arm_neon_nchwc` (no asm additions/heuristics)
- With this PR (asm kernels + heuristics)
Testing
    ./build.sh --config Release --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync --skip_tests --enable_pybind --build_wheel --enable_arm_neon_nchwc

    OMP_NUM_THREADS=8 ./build/Linux/Release/onnxruntime_perf_test -I -m times -r 1000 --x 8 ~/mobilenetv2-7.onnx