2 changes: 2 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -502,6 +502,7 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
@@ -517,6 +518,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
9 changes: 8 additions & 1 deletion onnxruntime/contrib_ops/cpu/nchwc_ops.cc
@@ -202,6 +202,12 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
}
}

#if defined(__aarch64__) && defined(__linux__)
const bool use_bf16 = use_fastmath_mode_;
#else
const bool use_bf16 = false;
#endif

MlasNchwcConv(
X_shape.GetDims().data(),
kernel_shape.data(),
@@ -216,7 +222,8 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
y_data.data(),
&activation_,
Sum == nullptr,
context->GetOperatorThreadPool());
context->GetOperatorThreadPool(),
use_bf16);

return Status::OK();
}
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cpu/nchwc_ops.h
@@ -7,6 +7,7 @@
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/cpu/nn/pool.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "contrib_ops/cpu/fused_activation.h"

namespace onnxruntime {
@@ -43,6 +44,10 @@ class NchwcConv final : public OpKernel {
public:
NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
#if defined(__aarch64__) && defined(__linux__)
auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16);
Member: Does this need a default value?

Rohanjames1997 (Contributor, Author), Jan 14, 2026: Probably not.

It is assigned a value inside onnxruntime_session_options_config_keys.h:

static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

In turn, enable_gemm_fastmath_arm64_bfloat16 is controlled via a runtime flag (and defaults to false).

Member: Unless I am missing something, GetConfigEntry() returns a std::optional which may not have a value, and that needs to be accounted for. GetConfigOrDefault may be a better way of handling that.

https://github.com/microsoft/onnxruntime/blob/9659a858808654ffb6a34a77016fb735fdf5d44f/onnxruntime/core/framework/config_options.h?plain=1#L27C1-L27C91

Rohanjames1997 (Contributor, Author), Jan 15, 2026: I see, I reused the logic from line 34 in matmul.h.

My doubt is: wouldn't this line run without fail, and hence always have a value?

use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported();
#endif
}

Status Compute(OpKernelContext* context) const override;
@@ -51,6 +56,9 @@ class NchwcConv final : public OpKernel {
ConvAttributes conv_attrs_;

MLAS_ACTIVATION activation_;
#if defined(__aarch64__) && defined(__linux__)
bool use_fastmath_mode_{false};
#endif
};

class NchwcPoolBase : public PoolBase {
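For reference, a minimal sketch of the GetConfigOrDefault alternative raised in the review thread above, assuming the same constructor context as this diff and the ConfigOptions::GetConfigOrDefault helper linked by the reviewer; a default of "0" keeps fastmath mode disabled unless the session option is explicitly set to "1":

// Sketch only: same constructor as above, but a missing config entry falls back
// to "0" instead of yielding an empty std::optional.
NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
  ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
#if defined(__aarch64__) && defined(__linux__)
  const std::string config_ops = info.GetConfigOptions().GetConfigOrDefault(
      kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0");
  use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported();
#endif
}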
4 changes: 3 additions & 1 deletion onnxruntime/core/mlas/inc/mlas.h
@@ -1242,7 +1242,8 @@ MlasNchwcConv(
float* Output,
const MLAS_ACTIVATION* Activation,
bool ZeroMode,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
bool UseBf16 = false
);

void
@@ -1965,6 +1966,7 @@ struct MLAS_SBGEMM_DATA_PARAMS {
const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */
};

/**
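A minimal sketch, not taken from the PR, of how the new ZeroMode field lets several SBGEMM calls accumulate into one output tile, mirroring the pointwise BF16 kernel added below: the first slice writes C, later slices add to it. The 16x16 sizes and the InputSlice/FilterSlice/OutputTile buffers are placeholder assumptions, and MlasSBGemmBatch is assumed to be declared alongside MLAS_SBGEMM_DATA_PARAMS.

// Two K-slices accumulated into a shared 16x16 fp32 output tile.
float InputSlice[2][16 * 16] = {};   // placeholder fp32 inputs
float FilterSlice[2][16 * 16] = {};  // placeholder fp32 filters
float OutputTile[16 * 16] = {};      // shared fp32 output

MLAS_SBGEMM_DATA_PARAMS params[2];
for (size_t s = 0; s < 2; ++s) {
    params[s].A = InputSlice[s];
    params[s].B = FilterSlice[s];
    params[s].C = OutputTile;
    params[s].lda = 16;
    params[s].ldb = 16;
    params[s].ldc = 16;
    params[s].Bias = nullptr;
    params[s].AIsfp32 = true;        // convert A from fp32 to bf16 internally
    params[s].BIsfp32 = true;        // convert B from fp32 to bf16 internally
    params[s].ZeroMode = (s == 0);   // C = A*B on the first slice, C += A*B afterwards
}
MlasSBGemmBatch(/*M*/ 16, /*N*/ 16, /*K*/ 16, /*BatchN*/ 2, params, /*ThreadPool*/ nullptr);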
6 changes: 6 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -967,6 +967,9 @@ extern "C" {
MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon;
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon;
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon;
#if defined(__aarch64__) && defined(__linux__)
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseBf16KernelNeon;
#endif
MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon;
MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon;
MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon;
@@ -1372,6 +1375,9 @@ struct MLAS_PLATFORM {
MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel;
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel;
MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel;
#if defined(__aarch64__) && defined(__linux__)
MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseBf16Kernel;
#endif
MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount];
uint32_t NchwcBlockSize;
#endif
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -575,6 +575,9 @@ Return Value:
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
#if defined(__aarch64__) && defined(__linux__)
this->ConvPointwiseBf16Kernel = MlasConvPointwiseBf16KernelNeon;
#endif
this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon;
this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon;
this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon;
110 changes: 110 additions & 0 deletions onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp
@@ -0,0 +1,110 @@
/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

sbconv_kernel_neon.cpp

Abstract:

This module implements bfloat16 precision convolution kernels for ARM NEON.

--*/

#if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__)

#include "mlasi.h"
#include "sconv.h"

constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE;

//
// BF16 Pointwise (1x1) Convolution Kernel using SBGEMM.
//
void MLASCALL
MlasConvPointwiseBf16KernelNeon(
const float* Input,
const float* Filter,
float* Output,
size_t StrideWidth,
size_t InputChannels,
size_t FilterCount,
size_t InputStride,
size_t FilterStride,
size_t OutputStride,
size_t OutputCount,
const float* Bias,
unsigned KernelFlags
)
{
const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;

const size_t StrideWidthElements = StrideWidth / sizeof(float);
const size_t InputStrideElements = InputStride / sizeof(float);
const size_t FilterStrideElements = FilterStride / sizeof(float);
const size_t OutputStrideElements = OutputStride / sizeof(float);

// SBGEMM only adds bias when ZeroMode=true. When accumulating (ZeroMode=false),
// pre-add bias to existing output before the GEMM operations.
if (BiasAddition && AccumulateOutput) {
for (size_t f = 0; f < FilterCount; f++) {
float* output = Output + f * OutputStrideElements;
const float32x4_t b0 = MlasLoadFloat32x4(&Bias[f * BlockSize]);
const float32x4_t b1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]);
const float32x4_t b2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]);
const float32x4_t b3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]);
for (size_t i = 0; i < OutputCount; i++) {
MlasStoreFloat32x4(&output[i * BlockSize], MlasAddFloat32x4(b0, MlasLoadFloat32x4(&output[i * BlockSize])));
MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasAddFloat32x4(b1, MlasLoadFloat32x4(&output[i * BlockSize + 4])));
MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasAddFloat32x4(b2, MlasLoadFloat32x4(&output[i * BlockSize + 8])));
MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasAddFloat32x4(b3, MlasLoadFloat32x4(&output[i * BlockSize + 12])));
}
}
}

// Build SBGEMM params for all (filter, input_channel) combinations.
// FilterCount <= 4, InputChannels <= 8, so max 32 elements.
// Bias is set on all elements but SBGEMM only uses it when ZeroMode=true.
MLAS_SBGEMM_DATA_PARAMS gemm_params[32];

size_t idx = 0;
for (size_t f = 0; f < FilterCount; f++) {
const float* filter = Filter + f * FilterStrideElements;
float* output = Output + f * OutputStrideElements;
for (size_t ic = 0; ic < InputChannels; ic++, idx++) {
gemm_params[idx].A = Input + ic * InputStrideElements;
gemm_params[idx].B = filter + ic * BlockSize * BlockSize;
gemm_params[idx].C = output;
gemm_params[idx].lda = StrideWidthElements;
gemm_params[idx].ldb = BlockSize;
gemm_params[idx].ldc = BlockSize;
gemm_params[idx].Bias = BiasAddition ? (Bias + f * BlockSize) : nullptr;
gemm_params[idx].AIsfp32 = true;
gemm_params[idx].BIsfp32 = true;
gemm_params[idx].ZeroMode = (ic == 0) && !AccumulateOutput;
gemm_params[idx].OutputProcessor = nullptr;
}
}

MlasSBGemmBatch(OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr);

if (ReluActivation) {
const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
for (size_t f = 0; f < FilterCount; f++) {
float* output = Output + f * OutputStrideElements;
for (size_t i = 0; i < OutputCount; i++) {
MlasStoreFloat32x4(&output[i * BlockSize], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize]), ZeroVector));
MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 4]), ZeroVector));
MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 8]), ZeroVector));
MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 12]), ZeroVector));
}
}
}
}

#endif
13 changes: 7 additions & 6 deletions onnxruntime/core/mlas/lib/sbgemm.h
@@ -112,7 +112,7 @@ MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK,

template <typename KernelType>
MLAS_FORCEINLINE void
MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor)
MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode)
{
constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides;
size_t PackedStrideN = Strides.N;
@@ -131,7 +131,7 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size
//
size_t CountK;
for (size_t k = 0; k < K; k += CountK) {
bool ZeroMode = (k == 0);
bool ZeroMode = (k == 0) && InitialZeroMode;
CountK = std::min(K - k, PackedStrideK);

const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN;
@@ -148,7 +148,7 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size

template <typename KernelType>
void
MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor)
MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode)
{
//
// Compute the strides to step through slices of the input matrices.
@@ -201,7 +201,7 @@ MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_
const float* pbias =
((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart

bool ZeroMode = (k == 0);
bool ZeroMode = (k == 0) && InitialZeroMode;
MlasSBGemmKernel<KernelType>(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode);
}
if (PostProcessor != nullptr) {
@@ -249,16 +249,17 @@ MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN,
const float* A = (const float*)DataParams->A + RangeStartM * lda;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
const float* bias = DataParams->Bias;
const bool zeroMode = DataParams->ZeroMode;

if (!DataParams->BIsfp32) {
MlasSBGemmPackedOperation<KernelType>(
RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A,
lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor
lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode
);
} else {
const size_t ldb = DataParams->ldb;
const float* B = (const float*)DataParams->B + RangeStartN;
MlasSBGemmNonPackedOperation<KernelType>(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor);
MlasSBGemmNonPackedOperation<KernelType>(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode);
}
}

12 changes: 11 additions & 1 deletion onnxruntime/core/mlas/lib/snchwc.cpp
@@ -53,6 +53,7 @@ struct MLAS_NCHWC_CONV_WORK_BLOCK : MLAS_NCHWC_WORK_BLOCK
float* Output;
size_t GroupCount;
bool ZeroMode;
bool UseBf16;
};

//
@@ -881,6 +882,11 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM

#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC))
MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel;
#if defined(__aarch64__) && defined(__linux__)
if (WorkBlock->UseBf16) {
Kernel = GetMlasPlatform().ConvPointwiseBf16Kernel;
}
#endif
#else
MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel;
#endif
@@ -1224,7 +1230,8 @@ MlasNchwcConv(
float* Output,
const MLAS_ACTIVATION* Activation,
bool ZeroMode,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
const bool UseBf16
)
/*++

@@ -1269,6 +1276,8 @@ Routine Description:
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.

UseBf16 - Supplies true to use BF16 for convolutions on supported platforms.

Return Value:

None.
@@ -1288,6 +1297,7 @@ Return Value:
WorkBlock.Bias = Bias;
WorkBlock.Activation = Activation;
WorkBlock.ZeroMode = ZeroMode;
WorkBlock.UseBf16 = UseBf16;

//
// Capture the generic shape parameters to the work block.
16 changes: 16 additions & 0 deletions onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp
@@ -12,6 +12,14 @@ static size_t Conv2dNchwcRegistLongExecute() {
if (GetMlasThreadPool() != nullptr) {
count += MlasLongExecuteTests<MlasNchwcConv2DTest<true>>::RegisterLongExecute();
}
#if defined(__aarch64__) && defined(__linux__)
if (MlasBf16AccelerationSupported()) {
count += MlasLongExecuteTests<MlasNchwcConv2DBf16Test<false>>::RegisterLongExecute();
if (GetMlasThreadPool() != nullptr) {
count += MlasLongExecuteTests<MlasNchwcConv2DBf16Test<true>>::RegisterLongExecute();
}
}
#endif
}

return count;
@@ -25,6 +33,14 @@ static size_t Conv2dNchwcRegistShortExecute() {
if (GetMlasThreadPool() != nullptr) {
count += Conv2dShortExecuteTest<MlasNchwcConv2DTest<true>>::RegisterShortExecuteTests();
}
#if defined(__aarch64__) && defined(__linux__)
if (MlasBf16AccelerationSupported()) {
count += Conv2dShortExecuteTest<MlasNchwcConv2DBf16Test<false>>::RegisterShortExecuteTests();
if (GetMlasThreadPool() != nullptr) {
count += Conv2dShortExecuteTest<MlasNchwcConv2DBf16Test<true>>::RegisterShortExecuteTests();
}
}
#endif
}

return count;