Skip to content

Commit

Permalink
SQNBitGemm - move workspace size calculation functions to hardware-sp…
Browse files Browse the repository at this point in the history
…ecific implementations (#20757)

The workspace usage may be hardware-specific. Moving away from a common workspace size calculation allows more flexibility in the hardware-specific implementations.
  • Loading branch information
edgchen1 authored May 22, 2024
1 parent d4fe4b5 commit a39f886
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 87 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
${MLAS_SRC_DIR}/sqnbitgemm.h
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
)

target_sources(onnxruntime_mlas PRIVATE
Expand Down
74 changes: 39 additions & 35 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Module Name:
--*/

#include "sqnbitgemm.h"
#include "sqnbitgemm_q8_block.h"

#include <cassert>

Expand Down Expand Up @@ -92,53 +93,58 @@ namespace
{

size_t
SQNBitGemmWorkspaceAlignment(SQNBitGemmVariant Variant)
SQNBitGemmPerGemmWorkspaceSize(
size_t M,
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompInt8: {
return Q8BlkAlignment();
}
default: {
return 1;
}
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
if (Dispatch == nullptr) {
return 0;
}

if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) {
return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType);
}

return 0;
}

size_t
SQNBitGemmPerGemmWorkspaceSize(
SQNBitGemmVariant Variant,
size_t M,
size_t N,
size_t K,
size_t BlkLen
SQNBitGemmPerGemmWorkspaceAlignment(
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
MLAS_UNREFERENCED_PARAMETER(N);
const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch;
if (Dispatch == nullptr) {
return 1;
}

switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompInt8: {
// workspace buffer is used for block quantization of A to int8
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen);
return PerGemmWorkspaceSize;
}
default: {
return 0;
}
if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) {
return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType);
}

return 1;
}

size_t
SQNBitGemmPerGemmWorkspaceStride(
SQNBitGemmVariant Variant,
size_t M,
size_t N,
size_t K,
size_t BlkLen
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
const auto Size = SQNBitGemmPerGemmWorkspaceSize(Variant, M, N, K, BlkLen);
const auto Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType);
const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
return MlasDivRoundup(Size, Alignment) * Alignment;
}

Expand All @@ -155,14 +161,12 @@ MlasSQNBitGemmBatchWorkspaceSize(
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType);

const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType);
if (PerGemmWorkspaceStride == 0) {
return 0;
}

const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);

const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride;

Expand Down Expand Up @@ -574,14 +578,14 @@ MlasSQNBitGemmBatch(
// Ensure `Workspace` has correct alignment.
//
if (Workspace != nullptr) {
const size_t Alignment = SQNBitGemmWorkspaceAlignment(Variant);
const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(Workspace);
Workspace = reinterpret_cast<void*>(
(WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))
);
}

const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen);
const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType);

if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
InitializeWorkspaceOperation != nullptr) {
Expand Down
89 changes: 37 additions & 52 deletions onnxruntime/core/mlas/lib/sqnbitgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ Module Name:

#pragma once

#include <cassert>

#include "mlas_qnbit.h"
#include "mlasi.h"

Expand All @@ -44,56 +42,6 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount)
}
}

//
// Quantized int8 block helpers.
//

// Read-only access to the float stored at the very start of a quantized int8
// block (the block's scale, going by the accessor name). The int8 payload
// begins sizeof(float) bytes after this pointer.
MLAS_FORCEINLINE
const float&
Q8BlkScale(const std::byte* BlkPtr)
{
    const auto* ScalePtr = reinterpret_cast<const float*>(BlkPtr);
    return *ScalePtr;
}

// Mutable access to the float stored at the very start of a quantized int8
// block (the block's scale, going by the accessor name), so that it can be
// written in place.
MLAS_FORCEINLINE
float&
Q8BlkScale(std::byte* BlkPtr)
{
    auto* ScalePtr = reinterpret_cast<float*>(BlkPtr);
    return *ScalePtr;
}

// Read-only pointer to the int8 values of a quantized block. The values start
// immediately after the leading float, i.e. sizeof(float) bytes into the block.
MLAS_FORCEINLINE
const int8_t*
Q8BlkData(const std::byte* BlkPtr)
{
    const std::byte* DataBegin = BlkPtr + sizeof(float);
    return reinterpret_cast<const int8_t*>(DataBegin);
}

// Writable pointer to the int8 values of a quantized block. The values start
// immediately after the leading float, i.e. sizeof(float) bytes into the block.
MLAS_FORCEINLINE
int8_t*
Q8BlkData(std::byte* BlkPtr)
{
    std::byte* DataBegin = BlkPtr + sizeof(float);
    return reinterpret_cast<int8_t*>(DataBegin);
}

// Size in bytes of one quantized int8 block holding BlkLen values: a leading
// float followed by BlkLen int8 entries.
MLAS_FORCEINLINE
constexpr size_t
Q8BlkSize(size_t BlkLen)
{
    constexpr size_t ScaleBytes = sizeof(float);
    const size_t DataBytes = BlkLen * sizeof(int8_t);
    const size_t TotalBytes = ScaleBytes + DataBytes;

    // The float at the front of a block has the strictest alignment
    // requirement. Blocks are laid out back to back, so the block size must be
    // a multiple of that alignment for every block to stay suitably aligned.
    assert(TotalBytes % alignof(float) == 0);

    return TotalBytes;
}

// Alignment in bytes required for a quantized int8 block; this is the
// alignment of float.
MLAS_FORCEINLINE
constexpr size_t
Q8BlkAlignment()
{
    constexpr size_t FloatAlignment = alignof(float);
    return FloatAlignment;
}

//
// Kernel dispatch structure.
//
Expand Down Expand Up @@ -126,6 +74,43 @@ struct MLAS_SQNBIT_GEMM_DISPATCH {

SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;

//
// Workspace size calculation function prototypes.
//

/**
* @brief Gets the required size in bytes of the per-GEMM intermediate workspace.
* Returns a size of zero if no intermediate workspace is needed.
*
* Each hardware-specific dispatch may supply its own calculation.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*
* @return workspace size in bytes for a single GEMM, or zero when none is needed
*/
typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)(
size_t M,
size_t N,
size_t K,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
);

// Optional hook for 4-bit GEMM. When left as nullptr, the generic
// SQNBitGemmPerGemmWorkspaceSize() wrapper in sqnbitgemm.cpp reports a size of 0.
SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr;

/**
* @brief Gets the required byte alignment of the per-GEMM intermediate workspace.
*
* MlasSQNBitGemmBatch() aligns the workspace pointer with the bit mask
* ~(Alignment - 1), so an implementation needs to return a non-zero power of two.
*
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*
* @return required alignment in bytes of the per-GEMM workspace
*/
typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)(
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
);

// Optional hook for 4-bit GEMM. When left as nullptr, the generic
// SQNBitGemmPerGemmWorkspaceAlignment() wrapper in sqnbitgemm.cpp falls back to
// an alignment of 1.
SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr;

//
// CompFp32 kernel function prototypes.
//
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() {
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;

d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;

d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize;
d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData;

d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize;
d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

Expand Down
47 changes: 47 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "sqnbitgemm.h"
#include "sqnbitgemm_q8_block.h"

//
// Quantized B data packing function implementation.
Expand Down Expand Up @@ -99,6 +100,52 @@ SQ4BitGemmPackQuantBData(
);
}

//
// Workspace size calculation function implementation.
//

// Computes the per-GEMM workspace size in bytes for the AVX-family kernels.
//
// Only the CompInt8 compute type uses a workspace: it holds matrix A after
// block quantization to int8, i.e. M rows of ceil(K / BlkLen) blocks, each
// Q8BlkSize(BlkLen) bytes. Every other compute type needs no workspace and
// yields 0. N does not influence the result.
static size_t
SQ4BitGemmPerGemmWorkspaceSize(
    size_t M,
    size_t N,
    size_t K,
    size_t BlkLen,
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
    MLAS_UNREFERENCED_PARAMETER(N);

    if (ComputeType != CompInt8) {
        return 0;
    }

    // Number of quantization blocks needed to cover one row of A.
    const size_t BlocksPerRow = MlasDivRoundup(K, BlkLen);
    return M * BlocksPerRow * Q8BlkSize(BlkLen);
}

// Reports the byte alignment required for the per-GEMM workspace used by the
// AVX-family kernels.
//
// For CompInt8 the workspace must satisfy Q8BlkAlignment(); every other compute
// type gets an alignment of 1. BlkLen does not influence the result.
static size_t
SQ4BitGemmPerGemmWorkspaceAlignment(
    size_t BlkLen,
    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
)
{
    MLAS_UNREFERENCED_PARAMETER(BlkLen);

    return (ComputeType == CompInt8) ? Q8BlkAlignment() : size_t{1};
}

void
Q4BitBlkDequantBForSgemm_CompFp32_avx2(
const size_t BlkLen,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "sqnbitgemm.h"
#include "sqnbitgemm_kernel_avx_common.h"
#include "sqnbitgemm_q8_block.h"

void
SQ4BitGemmM1Kernel_CompInt8_avx2(
Expand Down
Loading

0 comments on commit a39f886

Please sign in to comment.