Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e1c28f3
Introducing a BF16 optimized Pointwise Conv kernel
Rohanjames1997 Nov 12, 2025
4a814a2
Simplify the Fp32 Pointwise Kernel to use GEMM
Rohanjames1997 Nov 12, 2025
e2197cd
Simplify the Fp32 Depthwise kernel
Rohanjames1997 Nov 17, 2025
57c1c73
Simplify the Fp32 Depthwise Conv kernel
Rohanjames1997 Nov 26, 2025
998267d
Make fp32 Depthwise Conv branchless
Rohanjames1997 Nov 26, 2025
a560ac8
Add a BF16 Depthwise kernel. Yet to optimize perf
Rohanjames1997 Nov 28, 2025
eb59c07
Wiring for the BF16 kernels
Rohanjames1997 Dec 1, 2025
9a55be9
Make fp32 Depthwise Conv branchless
Rohanjames1997 Nov 26, 2025
4113c91
Make MlasConvFloatKernelNeonImpl branchless
Rohanjames1997 Dec 1, 2025
06e05a7
Remove redundant code
Rohanjames1997 Dec 1, 2025
f94a66a
Merge two loops
Rohanjames1997 Dec 1, 2025
7cb2b44
Make ReLU branchless in Fp32 Pointwise Conv
Rohanjames1997 Dec 2, 2025
98af370
Hoist allocation outside the loop
Rohanjames1997 Dec 3, 2025
3b7dd2f
Refactor the data validation step for Depthwise
Rohanjames1997 Dec 4, 2025
741df51
Sequential memory access for depthwise
Rohanjames1997 Dec 5, 2025
8f0e39c
Inline the function call
Rohanjames1997 Dec 5, 2025
5b82b17
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
Rohanjames1997 Dec 15, 2025
c5a769d
Fix the GEMM implementation and expand the test coverage.
Rohanjames1997 Dec 10, 2025
a0a5bd5
Fix segfault in ConvNoBiasAddFusion
Rohanjames1997 Dec 10, 2025
7ea4926
Eliminate potential segfault
Rohanjames1997 Dec 10, 2025
abe4e21
Copilot's suggestion for boundary checks
Rohanjames1997 Dec 11, 2025
a6b2e10
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
Rohanjames1997 Dec 15, 2025
2e9bed9
Merge branch 'bf16_conv' of https://github.com/Rohanjames1997/onnxrun…
Rohanjames1997 Dec 15, 2025
8d76b80
disable depthwise temp
Rohanjames1997 Dec 16, 2025
6cbbd25
Extend MlasSBGemmBatch to accept ZeroMode
Rohanjames1997 Dec 19, 2025
8eb4f5a
Merge branch 'microsoft:main' into bf16_conv
Rohanjames1997 Dec 19, 2025
b63aac1
Hacks for NCHWC Conv
Rohanjames1997 Dec 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
Expand All @@ -511,6 +512,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1955,6 +1955,7 @@ struct MLAS_SBGEMM_DATA_PARAMS {
const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */
};

/**
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -964,9 +964,13 @@ extern "C" {
#endif
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon;
MLAS_CONV_FLOAT_KERNEL MlasConvNchwBf16KernelNeon;
MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon;
MLAS_CONV_FLOAT_KERNEL MlasConvNchwcBf16KernelNeon;
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon;
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseBf16KernelNeon;
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon;
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseBf16KernelNeon;
MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon;
MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon;
MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon;
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,9 +567,13 @@ Return Value:

#if defined(MLAS_USE_ARM_NEON_NCHWC)
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon;
this->ConvNchwFloatKernel = MlasConvNchwBf16KernelNeon;
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
this->ConvNchwcFloatKernel = MlasConvNchwcBf16KernelNeon;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
// this->ConvDepthwiseFloatKernel = MlasConvDepthwiseBf16KernelNeon;
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
Comment on lines 569 to 575
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate assignment to ConvNchwFloatKernel. The second assignment overwrites the first. If the intent is to replace the float kernel with the BF16 kernel, remove the first assignment. Otherwise, clarify the initialization logic.

Suggested change
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon;
this->ConvNchwFloatKernel = MlasConvNchwBf16KernelNeon;
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
this->ConvNchwcFloatKernel = MlasConvNchwcBf16KernelNeon;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
// this->ConvDepthwiseFloatKernel = MlasConvDepthwiseBf16KernelNeon;
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
this->ConvNchwFloatKernel = MlasConvNchwBf16KernelNeon;
this->ConvNchwcFloatKernel = MlasConvNchwcBf16KernelNeon;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
// this->ConvDepthwiseFloatKernel = MlasConvDepthwiseBf16KernelNeon;

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate assignment to ConvPointwiseFloatKernel. The second assignment overwrites the first. If the intent is to replace the float kernel with the BF16 kernel, remove the first assignment. Otherwise, clarify the initialization logic.

Suggested change
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;

Copilot uses AI. Check for mistakes.
this->ConvPointwiseFloatKernel = MlasConvPointwiseBf16KernelNeon;
Comment on lines +570 to +576
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate assignment to ConvNchwcFloatKernel. The second assignment overwrites the first. If the intent is to replace the float kernel with the BF16 kernel, remove the first assignment. Otherwise, clarify the initialization logic.

Suggested change
this->ConvNchwFloatKernel = MlasConvNchwBf16KernelNeon;
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
this->ConvNchwcFloatKernel = MlasConvNchwcBf16KernelNeon;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
// this->ConvDepthwiseFloatKernel = MlasConvDepthwiseBf16KernelNeon;
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
this->ConvPointwiseFloatKernel = MlasConvPointwiseBf16KernelNeon;
// this->ConvNchwFloatKernel = MlasConvNchwBf16KernelNeon;
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
// this->ConvNchwcFloatKernel = MlasConvNchwcBf16KernelNeon;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
// this->ConvDepthwiseFloatKernel = MlasConvDepthwiseBf16KernelNeon;
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
// this->ConvPointwiseFloatKernel = MlasConvPointwiseBf16KernelNeon;

Copilot uses AI. Check for mistakes.
this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon;
this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon;
this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon;
Expand Down
Loading
Loading