Skip to content

Commit

Permalink
Add option to force generic algorithms on x86
Browse files Browse the repository at this point in the history
Option is named onnxruntime_FORCE_GENERIC_ALGORITHMS
  • Loading branch information
AlekseiNikiforovIBM committed Nov 21, 2024
1 parent 369d7bf commit f4ffe62
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 3 deletions.
5 changes: 5 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ cmake_dependent_option(MSVC_Z7_OVERRIDE "replacing /Zi and /ZI with /Z7 when usi

option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF)
option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF)
option(onnxruntime_FORCE_GENERIC_ALGORITHMS "Disable optimized arch-specific algorithms. Use only for testing and debugging generic algorithms." OFF)

# ENABLE_TRAINING includes all training functionality
# The following 2 entry points
Expand Down Expand Up @@ -971,6 +972,10 @@ if (onnxruntime_USE_LOCK_FREE_QUEUE)
add_compile_definitions(USE_LOCK_FREE_QUEUE)
endif()

# Propagate the CMake option into the C/C++ preprocessor so that source files
# (mlasi.h, platform.cpp, qgemm.h, sgemm.cpp) can gate out arch-specific
# kernels when the generic algorithms are being tested/debugged.
if (onnxruntime_FORCE_GENERIC_ALGORITHMS)
add_compile_definitions(FORCE_GENERIC_ALGORITHMS)
endif()

if (onnxruntime_ENABLE_LAZY_TENSOR)
# To support LazyTensor, ORT needs to call Python function from C/C++.
# so onnxruntime_ENABLE_PYTHON is required.
Expand Down
7 changes: 7 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,13 @@ endif()
# Choose the MLAS platform sources. When no arch-specific source set was
# selected (and this is not a multi-arch build), fall back to the portable
# scalar implementations only.
if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET)
file(GLOB_RECURSE mlas_platform_srcs
"${MLAS_SRC_DIR}/scalar/*.cpp")
elseif (onnxruntime_FORCE_GENERIC_ALGORITHMS)
# NOTE(review): this branch APPENDS the scalar sources to the already-chosen
# arch-specific ones rather than replacing them — presumably the generic
# kernels must coexist with the arch-specific objects so the preprocessor
# (FORCE_GENERIC_ALGORITHMS) can select them; confirm no duplicate symbols.
file(GLOB_RECURSE mlas_platform_srcs_generic
"${MLAS_SRC_DIR}/scalar/*.cpp")
set(mlas_platform_srcs
${mlas_platform_srcs}
${mlas_platform_srcs_generic}
)
endif()
target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs})
endif()
Expand Down
20 changes: 20 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,22 @@ size_t
bool ZeroMode
);

#ifdef FORCE_GENERIC_ALGORITHMS

//
// Signature of the generic (portable C++) single-precision GEMM kernels.
// Unlike MLAS_GEMM_FLOAT_KERNEL above there is no ZeroMode parameter: the
// zero-vs-accumulate behavior is chosen by which kernel is declared with
// this type (MlasSgemmKernelZero / MlasSgemmKernelAdd below).
//

typedef
size_t
(MLASCALL MLAS_GEMM_FLOAT_KERNEL_GENERIC)(
    const float* A,
    const float* B,
    float* C,
    size_t CountK,
    size_t CountM,
    size_t CountN,
    size_t lda,
    size_t ldc,
    float alpha
    );

#endif

#else

#if defined(__aarch64__) && defined(__linux__)
Expand Down Expand Up @@ -733,6 +749,10 @@ extern "C" {
#if defined(MLAS_TARGET_AMD64_IX86)
MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelSse;
MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx;
#ifdef FORCE_GENERIC_ALGORITHMS
MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelZero;
MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelAdd;
#endif
#if defined(MLAS_TARGET_AMD64)
MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelFma3;
MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx512F;
Expand Down
13 changes: 12 additions & 1 deletion onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,11 @@ Return Value:
this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel;
this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel;
#ifndef __APPLE__
#ifndef FORCE_GENERIC_ALGORITHMS
        this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse;
#else // FORCE_GENERIC_ALGORITHMS
        // Leave the cast kernel unset so the generic path is taken instead.
        // NOTE(review): presumes every caller null-checks this pointer
        // before dispatching through it — confirm.
        this->CastF16ToF32Kernel = nullptr;
#endif // FORCE_GENERIC_ALGORITHMS
#endif // __APPLE__

this->NchwcBlockSize = 8;
Expand All @@ -308,8 +312,11 @@ Return Value:
//
// Check if the processor supports SSE 4.1 instructions.
//

#ifndef FORCE_GENERIC_ALGORITHMS
    // CPUID.1:ECX bit 19 reports SSE 4.1 support.
    if ((Cpuid1[2] & 0x80000) != 0) {
#else // FORCE_GENERIC_ALGORITHMS
    // Generic algorithms forced: unconditionally skip the SSE 4.1 dispatch.
    if (false) {
#endif // FORCE_GENERIC_ALGORITHMS
        this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchSse41;
    }

Expand All @@ -319,7 +326,11 @@ Return Value:
// Check if the processor supports the AVX and OSXSAVE features.
//

#ifndef FORCE_GENERIC_ALGORITHMS
    // CPUID.1:ECX bits 27 (OSXSAVE) and 28 (AVX) must both be set.
    if ((Cpuid1[2] & 0x18000000) == 0x18000000) {
#else // FORCE_GENERIC_ALGORITHMS
    // Generic algorithms forced: skip all AVX-and-above dispatch selection.
    if (false) {
#endif // FORCE_GENERIC_ALGORITHMS

//
// Check if the operating system supports saving SSE and AVX states.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/mlas/lib/qgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ MlasGemmQuantGetDispatch(
{
const MLAS_GEMM_QUANT_DISPATCH* GemmQuantDispatch = &MlasGemmQuantDispatchDefault;

#if !defined(FORCE_GENERIC_ALGORITHMS)
#if defined(MLAS_TARGET_AMD64_IX86)
if (AIsSigned) {
GemmQuantDispatch =
Expand Down Expand Up @@ -901,6 +902,7 @@ MlasGemmQuantGetDispatch(
BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch;
}
#endif
#endif // !defined(FORCE_GENERIC_ALGORITHMS)

if (nullptr == GemmQuantDispatch) {
std::stringstream ss;
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/mlas/lib/sgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,7 @@ Return Value:

size_t RowsHandled;

#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
#if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS)
RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);
#else
if (ZeroMode) {
Expand Down Expand Up @@ -1158,6 +1158,7 @@ Return Value:

if (M == 1 && TransA == CblasNoTrans && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) {

#if !defined(FORCE_GENERIC_ALGORITHMS)
#if defined(MLAS_TARGET_AMD64)

MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine;
Expand All @@ -1181,6 +1182,7 @@ Return Value:
}

#endif
#endif // !defined(FORCE_GENERIC_ALGORITHMS)

}

Expand All @@ -1193,7 +1195,7 @@ Return Value:

if (N == 1 && ldb == 1 && ldc == 1 && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) {

#if defined(MLAS_TARGET_AMD64)
#if defined(MLAS_TARGET_AMD64) && !defined(FORCE_GENERIC_ALGORITHMS)

MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine;

Expand Down

0 comments on commit f4ffe62

Please sign in to comment.