2 changes: 2 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -496,6 +496,7 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
@@ -511,6 +512,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
9 changes: 8 additions & 1 deletion onnxruntime/contrib_ops/cpu/nchwc_ops.cc
@@ -202,6 +202,12 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
}
}

#if defined(__aarch64__) && defined(__linux__)
bool use_bf16 = use_fastmath_mode_;
#else
bool use_bf16 = false;
#endif

MlasNchwcConv(
X_shape.GetDims().data(),
kernel_shape.data(),
@@ -216,7 +222,8 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
y_data.data(),
&activation_,
Sum == nullptr,
context->GetOperatorThreadPool());
context->GetOperatorThreadPool(),
use_bf16);

return Status::OK();
}
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cpu/nchwc_ops.h
@@ -7,6 +7,7 @@
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/conv_attributes.h"
#include "core/providers/cpu/nn/pool.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "contrib_ops/cpu/fused_activation.h"

namespace onnxruntime {
@@ -43,6 +44,10 @@ class NchwcConv final : public OpKernel {
public:
NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
#if defined(__aarch64__) && defined(__linux__)
auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16);
use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported();
#endif
}

Status Compute(OpKernelContext* context) const override;
@@ -51,6 +56,9 @@
ConvAttributes conv_attrs_;

MLAS_ACTIVATION activation_;
#if defined(__aarch64__) && defined(__linux__)
bool use_fastmath_mode_{false};
#endif
};

class NchwcPoolBase : public PoolBase {
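For context, the fastmath path is opted into through the existing session-option key rather than a new operator attribute: the constructor above reads kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 and additionally requires MlasBf16AccelerationSupported(). A minimal sketch of how a caller might turn the mode on with the C++ API follows; the helper function name and the surrounding setup are illustrative only, while the config-key constant comes from onnxruntime_session_options_config_keys.h.

    #include "onnxruntime_cxx_api.h"
    #include "onnxruntime_session_options_config_keys.h"

    // Illustrative only: build session options that enable the ARM64 bfloat16
    // fastmath mode checked by NchwcConv's constructor.
    Ort::SessionOptions MakeFastMathSessionOptions() {
        Ort::SessionOptions session_options;
        // "1" requests the mode; on aarch64 Linux with BF16 hardware support,
        // NchwcConv then routes pointwise convolutions through the BF16 kernel.
        session_options.AddConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1");
        return session_options;
    }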
4 changes: 3 additions & 1 deletion onnxruntime/core/mlas/inc/mlas.h
@@ -1232,7 +1232,8 @@ MlasNchwcConv(
float* Output,
const MLAS_ACTIVATION* Activation,
bool ZeroMode,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
bool UseBf16 = false
);

void
@@ -1955,6 +1956,7 @@ struct MLAS_SBGEMM_DATA_PARAMS {
const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr;
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */
};

/**
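The new ZeroMode field lets several batch entries target the same C buffer, with only the first entry overwriting it; this is what the pointwise BF16 kernel added below relies on. A hedged sketch of that usage pattern, assuming an ARM64 build where MlasSBGemmBatch is available; the function name, sizes, and leading dimensions are illustrative only.

    // Sketch (not part of this PR): accumulate two products into one output,
    // C = A0*B0 + A1*B1, using the new ZeroMode field. Row-major fp32 inputs.
    void SketchAccumulate(const float* A0, const float* B0,
                          const float* A1, const float* B1,
                          float* C, size_t M, size_t N, size_t K)
    {
        MLAS_SBGEMM_DATA_PARAMS params[2];
        const float* As[2] = {A0, A1};
        const float* Bs[2] = {B0, B1};
        for (size_t i = 0; i < 2; i++) {
            params[i].A = As[i];
            params[i].B = Bs[i];
            params[i].C = C;               // both entries write the same buffer
            params[i].lda = K;
            params[i].ldb = N;
            params[i].ldc = N;
            params[i].Bias = nullptr;
            params[i].AIsfp32 = true;      // convert fp32 operands to bf16 internally
            params[i].BIsfp32 = true;
            params[i].ZeroMode = (i == 0); // first entry overwrites C, second accumulates
        }
        // A null thread pool keeps the entries sequential, so the ZeroMode
        // entry is processed before the accumulating one.
        MlasSBGemmBatch(M, N, K, 2, params, nullptr);
    }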
6 changes: 6 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -967,6 +967,9 @@ extern "C" {
MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon;
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon;
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon;
#if defined(__linux__)
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseBf16KernelNeon;
#endif
MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon;
MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon;
MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon;
@@ -1368,6 +1371,9 @@ struct MLAS_PLATFORM {
MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel;
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel;
MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel;
#if defined(__linux__)
MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseBf16Kernel;
#endif
MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount];
uint32_t NchwcBlockSize;
#endif
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -570,6 +570,9 @@ Return Value:
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
#if defined(__linux__)
this->ConvPointwiseBf16Kernel = MlasConvPointwiseBf16KernelNeon;
#endif
this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon;
this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon;
this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon;
110 changes: 110 additions & 0 deletions onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp
@@ -0,0 +1,110 @@
/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

sbconv_kernel_neon.cpp

Abstract:

This module implements bfloat16 precision convolution kernels for ARM NEON.

--*/

#if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__)

#include "mlasi.h"
#include "sconv.h"

constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE;

//
// BF16 Pointwise (1x1) Convolution Kernel using SBGEMM.
//
void MLASCALL
MlasConvPointwiseBf16KernelNeon(
const float* Input,
const float* Filter,
float* Output,
size_t StrideWidth,
size_t InputChannels,
size_t FilterCount,
size_t InputStride,
size_t FilterStride,
size_t OutputStride,
size_t OutputCount,
const float* Bias,
unsigned KernelFlags
)
{
const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0;
const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0;
const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0;

const size_t StrideWidthElements = StrideWidth / sizeof(float);
const size_t InputStrideElements = InputStride / sizeof(float);
const size_t FilterStrideElements = FilterStride / sizeof(float);
const size_t OutputStrideElements = OutputStride / sizeof(float);

// SBGEMM only adds bias when ZeroMode=true. When accumulating (ZeroMode=false),
// pre-add bias to existing output before the GEMM operations.
if (BiasAddition && AccumulateOutput) {
for (size_t f = 0; f < FilterCount; f++) {
float* output = Output + f * OutputStrideElements;
const float32x4_t b0 = MlasLoadFloat32x4(&Bias[f * BlockSize]);
const float32x4_t b1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]);
const float32x4_t b2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]);
const float32x4_t b3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]);
for (size_t i = 0; i < OutputCount; i++) {
MlasStoreFloat32x4(&output[i * BlockSize], MlasAddFloat32x4(b0, MlasLoadFloat32x4(&output[i * BlockSize])));
MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasAddFloat32x4(b1, MlasLoadFloat32x4(&output[i * BlockSize + 4])));
MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasAddFloat32x4(b2, MlasLoadFloat32x4(&output[i * BlockSize + 8])));
MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasAddFloat32x4(b3, MlasLoadFloat32x4(&output[i * BlockSize + 12])));
}
}
}

// Build SBGEMM params for all (filter, input_channel) combinations.
// FilterCount <= 4, InputChannels <= 8, so max 32 elements.
// Bias is set on all elements but SBGEMM only uses it when ZeroMode=true.
MLAS_SBGEMM_DATA_PARAMS gemm_params[32];

size_t idx = 0;
for (size_t f = 0; f < FilterCount; f++) {
const float* filter = Filter + f * FilterStrideElements;
float* output = Output + f * OutputStrideElements;
for (size_t ic = 0; ic < InputChannels; ic++, idx++) {
gemm_params[idx].A = Input + ic * InputStrideElements;
gemm_params[idx].B = filter + ic * BlockSize * BlockSize;
gemm_params[idx].C = output;
gemm_params[idx].lda = StrideWidthElements;
gemm_params[idx].ldb = BlockSize;
gemm_params[idx].ldc = BlockSize;
gemm_params[idx].Bias = BiasAddition ? (Bias + f * BlockSize) : nullptr;
gemm_params[idx].AIsfp32 = true;
gemm_params[idx].BIsfp32 = true;
gemm_params[idx].ZeroMode = (ic == 0) && !AccumulateOutput;
gemm_params[idx].OutputProcessor = nullptr;
}
}

MlasSBGemmBatch(OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr);

if (ReluActivation) {
const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);
for (size_t f = 0; f < FilterCount; f++) {
float* output = Output + f * OutputStrideElements;
for (size_t i = 0; i < OutputCount; i++) {
MlasStoreFloat32x4(&output[i * BlockSize], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize]), ZeroVector));
MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 4]), ZeroVector));
MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 8]), ZeroVector));
MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 12]), ZeroVector));
}
}
}
}

#endif
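As a cross-check on the kernel above, here is a plain scalar restatement of what it computes once the SBGEMM calls, the bias handling, and the ReLU pass are combined. This is an fp32 reference for illustration only (the real kernel rounds operands through bfloat16, so results differ slightly); strides are in elements, matching the *_Elements values computed in the kernel.

    #include <algorithm>
    #include <cstddef>

    // fp32 scalar reference of the BF16 pointwise (1x1) NCHWc convolution.
    void PointwiseConvReference(
        const float* Input, const float* Filter, float* Output,
        size_t StrideWidthElements, size_t InputChannels, size_t FilterCount,
        size_t InputStrideElements, size_t FilterStrideElements,
        size_t OutputStrideElements, size_t OutputCount, const float* Bias,
        bool AccumulateOutput, bool BiasAddition, bool ReluActivation,
        size_t BlockSize)
    {
        for (size_t f = 0; f < FilterCount; f++) {
            const float* filter = Filter + f * FilterStrideElements;
            float* output = Output + f * OutputStrideElements;
            for (size_t i = 0; i < OutputCount; i++) {
                for (size_t n = 0; n < BlockSize; n++) {
                    float acc = AccumulateOutput ? output[i * BlockSize + n] : 0.0f;
                    if (BiasAddition) {
                        acc += Bias[f * BlockSize + n];
                    }
                    // Sum over input-channel blocks: each block contributes a
                    // (OutputCount x BlockSize) * (BlockSize x BlockSize) GEMM.
                    for (size_t ic = 0; ic < InputChannels; ic++) {
                        const float* a = Input + ic * InputStrideElements + i * StrideWidthElements;
                        const float* b = filter + ic * BlockSize * BlockSize;
                        for (size_t k = 0; k < BlockSize; k++) {
                            acc += a[k] * b[k * BlockSize + n];
                        }
                    }
                    output[i * BlockSize + n] = ReluActivation ? std::max(acc, 0.0f) : acc;
                }
            }
        }
    }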
13 changes: 7 additions & 6 deletions onnxruntime/core/mlas/lib/sbgemm.h
@@ -112,7 +112,7 @@ MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK,

template <typename KernelType>
MLAS_FORCEINLINE void
MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor)
MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode)
{
constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides;
size_t PackedStrideN = Strides.N;
@@ -131,7 +131,7 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size
//
size_t CountK;
for (size_t k = 0; k < K; k += CountK) {
bool ZeroMode = (k == 0);
bool ZeroMode = (k == 0) && InitialZeroMode;
CountK = std::min(K - k, PackedStrideK);

const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN;
@@ -148,7 +148,7 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size

template <typename KernelType>
void
MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor)
MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode)
{
//
// Compute the strides to step through slices of the input matrices.
@@ -201,7 +201,7 @@ MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_
const float* pbias =
((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart

bool ZeroMode = (k == 0);
bool ZeroMode = (k == 0) && InitialZeroMode;
MlasSBGemmKernel<KernelType>(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode);
}
if (PostProcessor != nullptr) {
@@ -249,16 +249,17 @@ MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN,
const float* A = (const float*)DataParams->A + RangeStartM * lda;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;
const float* bias = DataParams->Bias;
const bool zeroMode = DataParams->ZeroMode;

if (!DataParams->BIsfp32) {
MlasSBGemmPackedOperation<KernelType>(
RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A,
lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor
lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode
);
} else {
const size_t ldb = DataParams->ldb;
const float* B = (const float*)DataParams->B + RangeStartN;
MlasSBGemmNonPackedOperation<KernelType>(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor);
MlasSBGemmNonPackedOperation<KernelType>(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode);
}
}

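The net effect of threading the caller's ZeroMode through the K loop can be summarized as follows (a restatement of the changed lines above, not new behavior):

    InitialZeroMode = true  (previous behavior):   C = Bias + sum_k A_k * B_k
    InitialZeroMode = false (new, conv kernel):    C = C_existing + sum_k A_k * B_k

where A_k and B_k are the K-dimension slices processed per iteration. Because the bias is only passed to the inner kernel on a slice where ZeroMode is true, InitialZeroMode = false also means SBGEMM never applies the bias, so the caller must fold it in itself, which is exactly what the bias pre-add loop in sbconv_kernel_neon.cpp does when accumulating.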
12 changes: 11 additions & 1 deletion onnxruntime/core/mlas/lib/snchwc.cpp
@@ -53,6 +53,7 @@ struct MLAS_NCHWC_CONV_WORK_BLOCK : MLAS_NCHWC_WORK_BLOCK
float* Output;
size_t GroupCount;
bool ZeroMode;
bool UseBf16;
};

//
@@ -881,6 +882,11 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM

#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC))
MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel;
#if defined(MLAS_TARGET_ARM64) && defined(__linux__)
if (WorkBlock->UseBf16) {
Kernel = GetMlasPlatform().ConvPointwiseBf16Kernel;
}
#endif
#else
MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel;
#endif
@@ -1224,7 +1230,8 @@ MlasNchwcConv(
float* Output,
const MLAS_ACTIVATION* Activation,
bool ZeroMode,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
bool UseBf16
)
/*++

@@ -1269,6 +1276,8 @@ Routine Description:
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.

UseBf16 - Supplies true to use BF16 for convolutions on supported platforms.

Return Value:

None.
@@ -1288,6 +1297,7 @@ Return Value:
WorkBlock.Bias = Bias;
WorkBlock.Activation = Activation;
WorkBlock.ZeroMode = ZeroMode;
WorkBlock.UseBf16 = UseBf16;

//
// Capture the generic shape parameters to the work block.