diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index c0ab948b41fff..83305539c2853 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -45,6 +45,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp ${MLAS_SRC_DIR}/qnbitgemm.h ${MLAS_SRC_DIR}/qnbitgemm.cpp + ${MLAS_SRC_DIR}/qlutgemm.h + ${MLAS_SRC_DIR}/qlutgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/cast.cpp @@ -209,6 +211,8 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp @@ -693,6 +697,8 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 64a434e2fe301..c4b6821e65b37 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -374,6 +374,12 @@ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFil // - "1": Gemm FastMath mode is enabled. static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; +// Use LUT (Lookup Table) based GEMM for quantized models when available. +// Option values: +// - "0": Do not use LUT based GEMM. [DEFAULT] +// - "1": Use LUT based GEMM when available. +static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm"; + // When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. // Refer to MatMulNBits op schema for more details. // If not provided, default is 4. 
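A minimal usage sketch (not part of the patch) of how a caller could opt in to the new LUT-based GEMM path through the session option added above; the model path and environment tag are placeholders, and the option key is the kOrtSessionOptionsMlasLutGemm value defined in this header:

    // Sketch: enable MLAS LUT GEMM for a session via the C++ API.
    // "model_q2.onnx" is a placeholder; the LUT path is only taken when
    // MlasIsLutGemmAvailable() accepts the MatMulNBits shapes in the model.
    #include <onnxruntime_cxx_api.h>

    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "lut_gemm_demo");
    Ort::SessionOptions session_options;
    session_options.AddConfigEntry("mlas.use_lut_gemm", "1");  // kOrtSessionOptionsMlasLutGemm
    Ort::Session session(env, ORT_TSTR("model_q2.onnx"), session_options);
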
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index fc0b7e40c628b..53e95bd8c5627 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -4,6 +4,7 @@ #include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" #include +#include #include #include "core/common/common.h" @@ -15,7 +16,10 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" +#include "core/platform/threadpool.h" +#include "core/util/thread_utils.h" namespace onnxruntime { namespace contrib { @@ -100,6 +104,11 @@ class MatMulNBits final : public OpKernel { nbits_{narrow(info.GetAttr("bits"))}, has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()}, has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()}, + prefer_lut_gemm_{info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasLutGemm) == "1" && + MlasIsLutGemmAvailable(narrow(info.GetAttr("N")), + narrow(info.GetAttr("K")), + narrow(info.GetAttr("bits")), + narrow(info.GetAttr("block_size")))}, compute_type_{GetComputeType(nbits_, block_size_, info.GetAttr("accuracy_level"))} { const auto& node = info.node(); auto input_defs = node.InputDefs(); @@ -135,6 +144,7 @@ class MatMulNBits final : public OpKernel { const bool has_g_idx_; const bool has_bias_; bool scales_are_packed_{false}; + const bool prefer_lut_gemm_{false}; const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_; bool has_unquantized_zero_point_{false}; const bool column_wise_quant_{true}; @@ -167,6 +177,11 @@ class MatMulNBits final : public OpKernel { AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, const MatMulComputeHelper& helper) const; + + Status ComputeBPackedLUT(const Tensor* a, + Tensor* y, + concurrency::ThreadPool* thread_pool, + const MatMulComputeHelper& helper) const; }; template @@ -179,22 +194,76 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All return Status::OK(); } - if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && !prefer_lut_gemm_) { return Status::OK(); } + + // Create a temporary threadpool for parallel packing + // This is used during model load time to speed up weight prepacking + std::unique_ptr temp_threadpool; + concurrency::ThreadPool* threadpool_ptr = nullptr; + + // Only create threadpool for LUT GEMM path which can benefit from parallel packing + // TODO: Consider extending threadpool usage to non-LUT path (CompInt8) with appropriate tests + if (prefer_lut_gemm_) { + OrtThreadPoolParams tpo; + tpo.thread_pool_size = Env::Default().GetNumPhysicalCpuCores(); + tpo.allow_spinning = false; // Don't spin during model load + tpo.auto_set_affinity = false; + + temp_threadpool = concurrency::CreateThreadPool( + &Env::Default(), + tpo, + concurrency::ThreadPoolType::INTRA_OP); + + threadpool_ptr = temp_threadpool.get(); + } + if (input_idx == InputIndex::B) { const Tensor* scales = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales); - packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_); - if (packed_b_size_ == 0) { - 
return Status::OK(); + if (prefer_lut_gemm_) { + MlasInitLutGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_); + + packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_); + if (packed_b_size_ == 0) { + return Status::OK(); + } + + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + + const float* scales_ptr = scales ? scales->Data() : nullptr; + const uint8_t* zp_ptr = nullptr; + if (scales_ptr != nullptr && has_zp_input_) { + const Tensor* zero_points = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points); + zp_ptr = zero_points ? zero_points->Data() : nullptr; + } + + MlasLutGemmPack( + N_, K_, nbits_, block_size_, has_zp_input_, + static_cast(tensor.DataRaw()), + scales_ptr, + zp_ptr, + static_cast(packed_b_.get()), + threadpool_ptr); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + } else { + packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_); + if (packed_b_size_ == 0) { + return Status::OK(); + } + auto qptr = tensor.DataRaw(); + auto scale_ptr = scales ? scales->DataRaw() : nullptr; + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, + has_zp_input_, nullptr, threadpool_ptr); } - auto qptr = tensor.DataRaw(); - auto scale_ptr = scales ? scales->DataRaw() : nullptr; - packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, - has_zp_input_, nullptr, nullptr); is_packed = true; } else if (compute_type_ == SQNBIT_CompInt8) { // Packing scales and zero points @@ -230,8 +299,30 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All is_packed = true; } #endif // MLAS_TARGET_ARM64 + } else if (prefer_lut_gemm_) { + // Pack scales/zero_points for LUT GEMM if B was already packed but scales weren't available then + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + auto scales_ptr = tensor.Data(); + const uint8_t* zp_ptr = nullptr; + if (has_zp_input_) { + const Tensor* zero_points = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points); + zp_ptr = zero_points ? 
zero_points->Data() : nullptr; + } + // Pack only scales (QuantBData is nullptr) + MlasLutGemmPack( + N_, K_, nbits_, block_size_, has_zp_input_, + nullptr, // QuantBData already packed + scales_ptr, + zp_ptr, + static_cast(packed_b_.get()), + nullptr); // No threadpool needed for scales only + is_packed = false; // scales tensor can be released but not "packed" in the ORT sense + } } + // Threadpool will be automatically destroyed when temp_threadpool goes out of scope + return Status::OK(); } @@ -307,14 +398,34 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; - if (input_idx == 1) { - used_shared_buffers = true; + if (input_idx == InputIndex::B && !prepacked_buffers.empty()) { packed_b_ = std::move(prepacked_buffers[0]); + used_shared_buffers = true; + + if (prefer_lut_gemm_) { + MlasInitLutGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_); + packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_); + } } return Status::OK(); } +template +Status MatMulNBits::ComputeBPackedLUT(const Tensor* a, + Tensor* y, + concurrency::ThreadPool* thread_pool, + const MatMulComputeHelper& helper) const { + const auto* a_data = a->Data(); + auto* y_data = y->MutableData(); + const int M = static_cast(helper.M()); + const int N = static_cast(helper.N()); + const int K = static_cast(helper.K()); + + MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool); + return Status::OK(); +} + template Status MatMulNBits::ComputeBPacked(const Tensor* a, const Tensor* scales, @@ -740,7 +851,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // If B is prepacked, B would have been removed from the context const bool is_b_prepacked = packed_b_size_ > 0; const Tensor* b = is_b_prepacked ? nullptr : ctx->Input(InputIndex::B); - const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input(InputIndex::scales); + const Tensor* scales = (scales_are_packed_ || (prefer_lut_gemm_ && packed_b_)) ? nullptr : ctx->Input(InputIndex::scales); const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* reorder_idx = ctx->Input(InputIndex::g_idx); const Tensor* bias = ctx->Input(InputIndex::bias); @@ -774,6 +885,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch() // with B directly too. 
+ if (prefer_lut_gemm_) { + return ComputeBPackedLUT(a, y, thread_pool, helper); + } + if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index b7b839a4f366b..e9ef220a2187e 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -73,7 +73,7 @@ void Dequantize4BitsKernelReOrder( } } -template +template void DequantizeBlockwise( inputT* output, // dequantized output const uint8_t* quant_data, // quantized input @@ -102,17 +102,17 @@ void DequantizeBlockwise( }); } -template void DequantizeBlockwise( +template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); -template void DequantizeBlockwise( +template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const float* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); -template void DequantizeBlockwise( +template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h index 5061ac5c800a6..be77ec03d006b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -6,7 +6,7 @@ namespace onnxruntime { namespace contrib { -template +template void DequantizeBlockwise( inputT* output, // dequantized output const uint8_t* quant_data, // quantized input diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index fc3c0b6016ced..39df8cf4e9a34 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -27,11 +27,11 @@ Module Name: * @brief Define compute types of block quantization, in order of decreasing accuracy. 
*/ typedef enum { - SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */ - HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */ - BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ - SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ - HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ + SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */ + HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */ + BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ + SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ + HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ } MLAS_QNBIT_GEMM_COMPUTE_TYPE; /** @@ -41,13 +41,13 @@ typedef enum { */ template struct MLAS_QNBIT_GEMM_DATA_PARAMS { - const T* A = nullptr; ///< address of A (float32/16 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) - const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data - const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const T* A = nullptr; ///< address of A (float32/16 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) + const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data + const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block /// /// Address of scale * accumulate(quant - zp), one per block, where `scale`, `quant`, `zp` are respectively @@ -58,10 +58,10 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS { /// This input is to be used only when A is quantized to uint8. /// const T* BlkUnsignedQuantAZeroPointCorrection = nullptr; - - const T* Bias = nullptr; ///< optional address of Bias, vector size N - T* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C + + const T* Bias = nullptr; ///< optional address of Bias, vector size N + T* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; @@ -232,3 +232,124 @@ MlasQNBitGemmScalesPacked( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool HasZeroPoint ); + +/** + * @brief Determines whether the Lut (Lookup Table) GEMM optimization path is available. + * + * @param[in] N column size of matrix B + * @param[in] K row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 2 means 2 bit ints) + * @param[in] BlkLen number of quantized values per block + * @return true if Lut GEMM is available for the given parameters + */ +bool MLASCALL +MlasIsLutGemmAvailable( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen +); + +/** + * @brief Initializes kernel configuration for Lut GEMM. 
+ * + * @param[in] M row size of output matrix + * @param[in] N column size of matrix B + * @param[in] nbits quantized value bit width + * @param[in] block_size number of quantized values per block + * @param[in] has_zero_point whether zero points are provided + */ +void MLASCALL +MlasInitLutGemmKernelConfig( + size_t M, + size_t N, + size_t nbits, + size_t block_size, + bool has_zero_point +); + +/** + * @brief Clears the cached LUT GEMM kernel configuration. + * Call this when the model dimensions change or to reset state between operations. + * Primarily used in testing scenarios to ensure clean state between test runs. + */ +void MLASCALL +MlasClearLutGemmKernelConfig(); + +/** + * @brief Gets the total size in bytes of the prepacked buffer for Lut GEMM. + * This buffer contains packed quantized B data followed by packed scales and zero points. + * + * @param[in] N column size of matrix B + * @param[in] K row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 2 means 2 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided + * @return Total size in bytes of the prepacked buffer + */ +size_t MLASCALL +MlasLutGemmPackedSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint +); + +/** + * @brief Packs quantized B data and/or scales/zero points into a buffer for Lut GEMM. + * If QuantBScale is nullptr, only packs B data. If QuantBData is nullptr, only packs scales. + * + * @param[in] N column size of matrix B + * @param[in] K row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 2 means 2 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided + * @param[in] QuantBData quantized B data (nullptr to skip B packing) + * @param[in] QuantBScale quantized B scales (nullptr to skip scale packing) + * @param[in] QuantBZeroPoint quantized B zero points (nullptr if HasZeroPoint is false) + * @param[out] PackedBuf output buffer (must be at least MlasLutGemmPackedSize bytes) + * @param[in] ThreadPool thread pool for parallel packing + */ +void MLASCALL +MlasLutGemmPack( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + const std::byte* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + std::byte* PackedBuf, + MLAS_THREADPOOL* ThreadPool +); + +/** + * @brief Executes TMAC compute using Lut (Lookup Table) based GEMM. + * + * This function handles generating the look up tables and accumulating the matmul results. + * Results will be stored in C. 
+ * + * @param[in] A activation matrix + * @param[in] BlkLen number of quantized values per block + * @param[in] PackedBuf packed buffer containing weights and scales/zp (from MlasLutGemmPack) + * @param[out] C output matrix + * @param[in] K inner dimension + * @param[in] M batch size (number of rows in activation) + * @param[in] N column size of matrix B + * @param[in] HasZeroPoint whether zero points are provided + * @param[in] threadpool thread pool for parallel computation + */ +void MLASCALL +MlasLutGemm( + const void* A, + size_t BlkLen, + const void* PackedBuf, + void* C, + size_t K, + size_t M, + size_t N, + bool HasZeroPoint, + MLAS_THREADPOOL* threadpool +); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index ad62cccbfb9c7..02565b4af7900 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1238,6 +1238,10 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx; +struct MLAS_QNBIT_LUT_GEMM_DISPATCH; + +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2; + // // Rotary embedding dispatch structure. // @@ -1443,6 +1447,7 @@ struct MLAS_PLATFORM { const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr}; + const MLAS_QNBIT_LUT_GEMM_DISPATCH* LutGenKernel{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 528e71bcffed1..fb7144df85083 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -418,6 +418,8 @@ Return Value: this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; this->RopeDispatch = &MlasRopeDispatchAvx2; + // TODO(vraspar): check if this really goes here or if there are other platform reqs that we need to fulfill + this->LutGenKernel = &MlasLutGenKernelAvx2; // // Check if the processor supports Hybrid core architecture. diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index c543770ee22d8..fbbf4005ae4a5 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -545,7 +545,7 @@ struct BlockwiseQuantizer { } } - for (int32_t j = c; j < c_end; ++j) { + for (int32_t j = c; j < c_end; ++j) { // this does not work if j runs more then 1 because zp_bytes is indexed by i. 
const int32_t meta_c = j / QuantBlk::kColumn; for (int32_t i = r; i < r_end; i += kPackSize) { for (int l = 0; l < kPackSize && i + l < r_end; l++) { @@ -656,19 +656,35 @@ struct BlockwiseQuantizer { * @tparam signed_quant quantized type is signed */ template -struct BlockwiseQDQQuantizer; - -template -struct BlockwiseQDQQuantizer { +struct BlockwiseQDQQuantizer { static MLAS_FORCEINLINE uint8_t GetElem(uint8_t val, int32_t idx) { - return (val >> (idx << 2)) & 0xF; + if constexpr (qbits == 2) { + return (val >> (idx << 1)) & 0x3; + } else if constexpr (qbits == 4) { + return (val >> (idx << 2)) & 0xF; + } } static MLAS_FORCEINLINE uint8_t SetElem(uint8_t val, int32_t idx, uint8_t dst) { - auto shift = idx << 2; - return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + if constexpr (qbits == 2) { + auto shift = idx << 1; + return ((val & 0x3) << shift) | (dst & (~(0x3 << shift))); + } else if constexpr (qbits == 4) { + auto shift = idx << 2; + return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + } + } + + template + static MLAS_FORCEINLINE uint8_t Pack(uint8_t v0, uint8_t v1, uint8_t v2, uint8_t v3) + { + if constexpr (add2) { + return ((v0 & 0x3) ^ 2) | (((v1 & 0x3) ^ 2) << 2) | (((v2 & 0x3) ^ 2) << 4) | (((v3 & 0x3) ^ 2) << 6); + } else { + return (v0 & 0x3) | ((v1 & 0x3) << 2) | ((v2 & 0x3) << 4) | ((v3 & 0x3) << 6); + } } template @@ -1436,8 +1452,7 @@ MlasBlockwiseQuantMetaShape( int& meta_cols ); -template -void +template void MlasBlockwiseQuantMetaShape( int block_size, bool columnwise, @@ -1445,7 +1460,7 @@ MlasBlockwiseQuantMetaShape( int columns, int& meta_rows, int& meta_cols - ); +); template void @@ -1513,8 +1528,7 @@ MlasBlockwiseQuantizedShape( int& q_cols ); -template -void +template void MlasBlockwiseQuantizedShape( int block_size, bool columnwise, @@ -1524,7 +1538,7 @@ MlasBlockwiseQuantizedShape( int& q_cols ); - template +template void MlasBlockwiseQuantizedShape( int block_size, @@ -2016,6 +2030,19 @@ MlasQDQQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template bool +MlasQDQQuantizeBlockwise( + const float* src, + float* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template bool MlasQDQQuantizeBlockwise( const MLAS_FP16* src, @@ -2029,6 +2056,19 @@ MlasQDQQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template bool +MlasQDQQuantizeBlockwise( + const MLAS_FP16* src, + MLAS_FP16* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template void MlasQDQTransposeBlockwiseQuantized( @@ -2055,6 +2095,36 @@ MlasQDQTransposeBlockwiseQuantized( } } +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template void MlasQDQTransposeBlockwiseQuantized( const uint8_t* src_weights, diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp 
b/onnxruntime/core/mlas/lib/qlutgemm.cpp new file mode 100644 index 0000000000000..f029e539f02a1 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -0,0 +1,686 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qlutgemm.cpp + +Abstract: + + This module implements kernel functions for generating lookup tables (LUT) + and computing matrix multiplication for the T-MAC GEMM optimization strategy. + + It provides functionality to pack quantized weight data, compute LUT scales + and biases, and perform efficient quantized GEMM operations using lookup + table based computation. + +--*/ +#include "qlutgemm.h" + +#include +#include +#include +#include +#include +#include + +/** T-MAC GEMM kernel Config */ +static std::unordered_map tmac_kernel_configs; + +const MlasTMACKernelParams& +MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point) +{ + std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0"); + if (tmac_kernel_configs.count(key)) { + return tmac_kernel_configs[key]; + } + MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized"); +} + +void MLASCALL +MlasClearLutGemmKernelConfig() +{ + tmac_kernel_configs.clear(); +} + +void MLASCALL +MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point) +{ + std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0"); + if (tmac_kernel_configs.count(key)) { + return; + } + + MlasTMACKernelParams params; + params.g = 4; + params.ngroups_per_elem = 8 / params.g; + params.simd_n_in = 16; + params.simd_n_out = 8; + params.chunk_n = 8; + + params.bits = nbits; + params.q_group_size = block_size; + + if (block_size % 64 == 0) { + params.act_group_size = 64; + } else if (block_size % 32 == 0) { + params.act_group_size = 32; + } else { + // throw error + MLAS_THROW_EX(std::runtime_error, "Unsupported activation group size"); + } + params.actk = params.act_group_size / params.g; + + // search space + std::vector bms; + if (nbits == 1 || nbits == 2 || nbits == 4) { + bms = {256, 512, 1024, 2048, 320, 640, 1280}; + } else if (nbits == 3) { + bms = {192, 384, 576, 758}; + } + + std::vector kfactors = {8, 16}; + + // TODO(vraspar): add profile based policy + size_t threads = static_cast(std::thread::hardware_concurrency()); + + float smallest_penalty = 1e9f; + params.bm = bms[0]; + for (size_t bm : bms) { + if (M % (bm / nbits) != 0 || bm % nbits != 0) { + continue; + } + size_t num_tiles = M / (bm / nbits); + size_t num_groups = (num_tiles + threads - 1) / threads; + float penalty = 0.1f * static_cast(num_groups) + + (static_cast(num_groups) - 1.0f * static_cast(num_tiles) / static_cast(threads)) / + static_cast(num_groups); + if (penalty < smallest_penalty) { + smallest_penalty = penalty; + params.bm = bm; + } + } + + size_t largest_kfactor = 0; + params.kfactor = kfactors[0]; + for (size_t kfactor : kfactors) { + if ((kfactor < params.actk) || (kfactor * params.g > params.q_group_size)) { + continue; + } + if (kfactor > largest_kfactor) { + largest_kfactor = kfactor; + params.kfactor = kfactor; + } + } + + params.n_tiles_num = M * params.bits / params.bm; + params.has_scale = true; // TODO(vraspar): TMAC supports only scale for now + params.has_zero_point = 
has_zero_point; + params.one_scale = false; // TODO(vraspar): support one scale case for bitnet + + tmac_kernel_configs[key] = params; + return; +} + +// Internal helper: calculates packed quantized B data size +static size_t +LutGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint +) +{ + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + const size_t PackedQuantBDataSize = (N * BlkBitWidth) * (K / tmac_params.g / tmac_params.ngroups_per_elem); + return PackedQuantBDataSize; +} + +// Internal helper: packs quantized B data +static void +LutGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + // decompose W into w1,... w_bits create temp buffer buf2 of size N * bits * (K/g) + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + const size_t bits = tmac_params.bits; + const size_t g = tmac_params.g; + const size_t ngroups_per_elem = tmac_params.ngroups_per_elem; + const size_t simd_n_in = tmac_params.simd_n_in; + const size_t simd_n_out = tmac_params.simd_n_out; + const size_t bm = tmac_params.bm; + const size_t kfactor = tmac_params.kfactor; + + assert(BlkLen % g == 0); + assert((BlkLen / g) % kfactor == 0); + + const size_t mgroup = ngroups_per_elem * simd_n_in; // 32 + assert(bm % mgroup == 0); + assert(bm % bits == 0); + + std::unique_ptr buf(new uint8_t[N * bits * (K / g)]); + memset(buf.get(), 0, N * bits * (K / g)); + + const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + for (size_t ik = 0; ik < K; ++ik) { + size_t idx = (im * K + ik); + size_t num_elem_per_byte = 8 / bits; + size_t elem_idx = idx % num_elem_per_byte; + + uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits); + + for (size_t ib = 0; ib < bits; ++ib) { + size_t new_ik = ik / g; + size_t shft_left = ik % g; + buf[im * bits * K / g + ib * K / g + new_ik] += static_cast(((v >> ib) & 1) << shft_left); + } + } + } + ); + + // Now buf contains the bit planes grouped by g along K + // Next, we need to do a multi-reshape/transpose into the final layout + + const size_t c0_fac2 = K / g; + const size_t c0_fac1 = simd_n_out * c0_fac2; + const size_t c0_fac0 = bits * c0_fac1; + + const size_t c1_nb2 = K / g; + const size_t c1_nb1 = simd_n_in * c1_nb2; + const size_t c1_nb0 = ngroups_per_elem * c1_nb1; + const size_t c1_fac2 = K / g; + const size_t c1_fac1 = ngroups_per_elem * c1_fac2; + const size_t c1_fac0 = simd_n_in * c1_fac1; + + const size_t c2_nb4 = kfactor; + const size_t c2_nb3 = K / g / kfactor * c2_nb4; + const size_t c2_nb2 = ngroups_per_elem * c2_nb3; + const size_t c2_nb1 = simd_n_in * c2_nb2; + const size_t c2_nb0 = bm / mgroup * c2_nb1; + const size_t c2_fac3 = simd_n_in * ngroups_per_elem; + const size_t c2_fac2 = kfactor * c2_fac3; + const size_t c2_fac1 = bm / mgroup * c2_fac2; + const size_t c2_fac0 = K / g / kfactor * c2_fac1; + + const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem); + memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed? 
+ + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + for (size_t ib = 0; ib < bits; ib++) { + for (size_t ik = 0; ik < K / g; ik++) { + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + size_t new_im = im / simd_n_out; + size_t new_isno = im % simd_n_out; + size_t new_ib = ib; + size_t new_ik = ik; + size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; + + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + new_im = new_idx / c1_nb0; + size_t new_ing = (new_idx % c1_nb0) / c1_nb1; + size_t new_isni = (new_idx % c1_nb1) / c1_nb2; + new_ik = (new_idx % c1_nb2); + new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; + + // # 0 1 2 3 4 5 + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + new_im = new_idx / c2_nb0; + size_t new_ibm = (new_idx % c2_nb0) / c2_nb1; + new_isni = (new_idx % c2_nb1) / c2_nb2; + new_ing = (new_idx % c2_nb2) / c2_nb3; + new_ik = (new_idx % c2_nb3) / c2_nb4; + size_t new_ikf = (new_idx % c2_nb4); + new_idx = new_im * c2_fac0 + + new_ik * c2_fac1 + + new_ibm * c2_fac2 + + new_ikf * c2_fac3 + + new_isni * ngroups_per_elem + + new_ing; + new_idx = new_idx / ngroups_per_elem; + size_t buf_idx = im * bits * K / g + ib * K / g + ik; + uint8_t buf_val = buf[buf_idx]; + + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + PackedQuantBDataBegin[new_idx] = static_cast( + static_cast(PackedQuantBDataBegin[new_idx]) + + (buf_val << (new_ing * g)) + ); + } + } + } + ); +} + +// Internal helper: calculates packed scales and zero points size in floats +static size_t +LutPackScalesAndZeroPointsSize( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint +) +{ + // TODO(vraspar): support one scale case + if (HasZeroPoint) { + return N * K / BlkLen * 2; + } else { + return N * K / BlkLen; + } +} + +// Internal helper: packs scales and zero points +static void +LutPackScalesAndZeroPoints( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + float* PackedQuantBZPBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint +) +{ + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + const size_t bits = tmac_params.bits; + const size_t simd_n_out = tmac_params.simd_n_out; + const size_t bm = tmac_params.bm; + const size_t num_elem_per_byte = 8 / bits; + + // ZP array is column-major packed, with per-column alignment to byte boundary + const size_t row_blks = K / BlkLen; // number of blocks per column + const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte; + + for (size_t im = 0; im < N; im += 1) { + for (size_t ik = 0; ik < K; ik += BlkLen) { + size_t idx = (im * K + ik) / BlkLen; // linear block index for scale (scale is NOT packed) + float scale = QuantBScale[idx]; + float zp = 0.0f; + if (HasZeroPoint) { + size_t blk_in_col = ik / BlkLen; // block index within column + size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte; + size_t elem_idx = blk_in_col % num_elem_per_byte; + uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & ((1 << bits) - 1); + + // The LUT kernel assumes weights are centered around the midpoint (2 for 2-bit). + // Thus, need to correct for the actual ZP relative to the midpoint. 
+ + int midpoint = 1 << (bits - 1); // 2 for 2-bit + zp = static_cast(static_cast(v) - midpoint) * scale; + } + + // TODO(vraspar): fix when k < BlkLen and nb1 is 0 + size_t nb1 = K / BlkLen; + size_t nb0 = bm / bits * nb1; + + size_t new_im, new_ibm, new_ik; + if (nb1 == 0) { + new_im = 0; + new_ibm = 0; + new_ik = 0; + + } else { + new_im = idx / nb0; + new_ibm = (idx % nb0) / nb1; + new_ik = (idx % nb1); + } + + if (HasZeroPoint) { + size_t new_isimd = new_ibm % simd_n_out; + size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out; + size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; + + PackedQuantBZPBegin[new_idx_scale] = scale; + PackedQuantBZPBegin[new_idx_zero] = zp; + } else { + size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm; + PackedQuantBZPBegin[new_idx] = scale; + } + } + } +} + +// Internal helper: calculates the offset to scales in the packed buffer +static size_t +LutGemmPackedScalesOffset( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint +) +{ + constexpr size_t kAlignment = 64; // Cache line alignment + size_t packed_b_size = LutGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + return ((packed_b_size + kAlignment - 1) / kAlignment) * kAlignment; +} + +size_t MLASCALL +MlasLutGemmPackedSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint +) +{ + // Get packed B size (aligned) + size_t aligned_b_size = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + + // Get packed scales/zp size (in floats, convert to bytes) + size_t packed_scales_count = LutPackScalesAndZeroPointsSize(N, K, BlkLen, HasZeroPoint); + size_t packed_scales_bytes = packed_scales_count * sizeof(float); + + return aligned_b_size + packed_scales_bytes; +} + +void MLASCALL +MlasLutGemmPack( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + const std::byte* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + std::byte* PackedBuf, + MLAS_THREADPOOL* ThreadPool +) +{ + // Pack B data if provided + if (QuantBData != nullptr) { + LutGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, HasZeroPoint, QuantBData, PackedBuf, ThreadPool); + } + + // Pack scales/zero points if scales are provided + if (QuantBScale != nullptr) { + size_t scales_offset = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + float* scales_dest = reinterpret_cast(PackedBuf + scales_offset); + LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint); + } +} + +bool MLASCALL +MlasIsLutGemmAvailable( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen +) +{ + const auto* lut_kernel = GetMlasPlatform().LutGenKernel; + if (lut_kernel == nullptr || lut_kernel->GenerateLUT == nullptr || lut_kernel->ComputeGemm == nullptr) { + return false; + } + + // currently only 2-bit is supported + if (BlkBitWidth != 2 || BlkLen == 0 || (BlkLen % 32) != 0) { + return false; + } + + if (K % 32 != 0) { + return false; + } + + size_t n_div = 0; + switch (BlkBitWidth) { + case 1: + n_div = 256; + break; + case 2: + n_div = 128; + break; + case 3: + n_div = 64; + break; + case 4: + n_div = 32; + break; + default: + return false; + } + + if (N % n_div != 0) { + return false; + } + return true; +} + +size_t 
+CalculateLutBufferSize(size_t n, size_t k, size_t m, const MlasTMACKernelParams& tmac_params)
+{
+    MLAS_UNREFERENCED_PARAMETER(n);
+    constexpr size_t kAllocAlignment = 64;
+    const size_t lut_scales_size = k / tmac_params.act_group_size;
+
+    size_t wsize = k * m * 4 * sizeof(int8_t);         // 4 bytes per k element for 2-bit LUT
+    wsize += lut_scales_size * m * 2 * sizeof(float);  // scales + biases
+
+    wsize = ((wsize - 1) / kAllocAlignment + 1) * kAllocAlignment;
+
+    // TODO(vraspar): add temp buffer for FP16
+    return wsize;
+}
+
+void MLASCALL
+MlasLutGemm(
+    const void* A,
+    size_t BlkLen,
+    const void* PackedBuf,  // Packed buffer containing weights followed by scales/zp
+    void* C,
+    size_t K,
+    size_t M,  // batch size (number of rows in activation)
+    size_t N,
+    bool HasZeroPoint,
+    MLAS_THREADPOOL* threadpool
+)
+{
+    // adapted from ggml_backend_tmac_mul_mat
+    const auto* Dispatch = GetMlasPlatform().LutGenKernel;
+    // This should be ensured by calling MlasIsLutGemmAvailable() before MlasLutGemm()
+    assert(Dispatch && Dispatch->GenerateLUT && "TMAC not supported in this configuration.");
+
+    // Calculate scales offset from packed buffer
+    // TODO(vraspar): support other bitwidths
+    constexpr size_t BlkBitWidth = 2;
+    size_t scales_offset = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint);
+    const auto* QuantBData = PackedBuf;
+    const auto* QuantBScale = reinterpret_cast<const float*>(
+        static_cast<const std::byte*>(PackedBuf) + scales_offset
+    );
+
+    /** TODO(vraspar): The biases_float and scales_float values don't make sense
+     * FP16
+     * QLUT: K(ne10) x M(ne11) x 4 bytes
+     * Scales: lut_scales_size * M * 2 bytes
+     * Biases: lut_scales_size * M * 2 bytes
+     * Needs an FP16 conversion buffer: max(K, N) * M * 2 bytes
+     *
+     * FP32
+     * QLUT: K x M x 4 bytes
+     * Scales: lut_scales_size * M * 4 bytes
+     * Biases: lut_scales_size * M * 4 bytes
+     *
+     * Currently only FP32 is supported; FP16 support will be added later and requires a conversion buffer.
+     *
+     * LUT buffer for FP32: K * M * 4 * sizeof(uint8_t) bytes + lut_scales_size * M * 2 * sizeof(float) bytes + alignment
+     *
+     */
+
+    // n_tiles_num = m * bits / bm;
+
+    // TODO(vraspar): support other bitwidths
+    const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, 2, BlkLen, HasZeroPoint);
+    const size_t lut_scales_size = K / tmac_params.act_group_size;
+    size_t lut_buffer_size = CalculateLutBufferSize(N, K, M, tmac_params);
+
+    // make buffer of lut_buffer_size bytes
+    // TODO(vraspar): other way to do it
+    auto lut_buffer = std::make_unique<std::byte[]>(lut_buffer_size);
+
+    int8_t* qlut = reinterpret_cast<int8_t*>(lut_buffer.get());
+    float* lut_scales = reinterpret_cast<float*>(qlut + K * M * 4);                  // after lut
+    float* lut_biases = reinterpret_cast<float*>(lut_scales + lut_scales_size * M);  // after scales
+
+    const auto* a_float = reinterpret_cast<const float*>(A);  // Activation data
+
+    // const int num_groups = static_cast<int>(K / BlkLen);
+
+    // Parallelize over M (batch dimension)
+    // Each iteration processes one row of the activation matrix
+    // TODO(vraspar): Ideally we have to do block parallelism here
+
+    MlasTrySimpleParallel(
+        threadpool,
+        static_cast<ptrdiff_t>(M),
+        [&](ptrdiff_t ine11) {
+            const size_t row_offset = static_cast<size_t>(ine11) * K;
+            const size_t lut_offset = static_cast<size_t>(ine11) * K * 4;  // 4 bytes per K element for 2-bit LUT
+            const size_t scale_bias_offset = static_cast<size_t>(ine11) * lut_scales_size;
+
+            // Call the dispatch function for this row
+            // ggml_tmac_mul_mat_task_init
+            Dispatch->GenerateLUT(
+                const_cast<float*>(a_float + row_offset),  // Input activation for this row
+                qlut + lut_offset,                         // 
Output LUT for this row + lut_scales + scale_bias_offset, // Scales for this row + lut_biases + scale_bias_offset, // Biases for this row + M, + K, + N, + tmac_params.act_group_size + ); + } + ); + + // all relevant LUT's have been generated + // equivalent of lut_mul_mat's ggml_backend_tmac_mul_mat function ggml_barrier line + + const size_t n_tiles_num = tmac_params.n_tiles_num; + assert(N % n_tiles_num == 0); + + const size_t bits = tmac_params.bits; + + // Pre-calculate sizes for offset calculations + const size_t w_size = N * K * bits / 8; + const size_t w_chunk_size = w_size / n_tiles_num; + + // TODO: fix the below 4 + // Matrix multiplication: Output[N×M] = QuantBData[N×K] × Weights[K×M] + const size_t OutputRows = N; // Number of output features + const size_t OutputCols = M; // Batch size + + const size_t ChunkSize0 = N / n_tiles_num; + const size_t ChunkSize1 = tmac_params.chunk_n; // process one batch item at a time + + // In llama.cpp terminology (note the swap!): + // ne0 = M (output features, called "n" in llama.cpp) + // ne1 = N (batch size, called "m" in llama.cpp) + + // Calculate number of chunks in each dimension + const size_t nchunk0 = (OutputRows + ChunkSize0 - 1) / ChunkSize0; // Should equal NumTiles + const size_t nchunk1 = (OutputCols + ChunkSize1 - 1) / ChunkSize1; + const size_t total_chunks = nchunk0 * nchunk1; + + // TODO(vraspar): support one_scale case + // Determine weight-scale layout. These should be provided by the caller or inferred from the packed weights. + // For now we default to per-group symmetric quantization (no zero-point, not one-scale). + + const size_t scales_size_total = LutPackScalesAndZeroPointsSize( + static_cast(N), + static_cast(K), + BlkLen, + tmac_params.has_zero_point + ); + + // Per-tile scales size = total scales size divided evenly across tiles. + // If one_scale is true we do not advance the scales pointer per tile, so set per tile size to 0 + size_t scales_size_per_tile = 0; + + if (scales_size_total % n_tiles_num != 0) { + // Sanity: scales should partition evenly across tiles. If they don't, choose floor division + // and document that callers must layout scales accordingly. + // Prefer to error loudly in debug builds. 
+        fprintf(stderr, "Warning: scales_size_total=%zu is not divisible by n_tiles_num=%zu; using floor division.\n", scales_size_total, n_tiles_num);
+    }
+    scales_size_per_tile = scales_size_total / n_tiles_num;
+
+    // Note: when one_scale == true, callers should pass a pointer to a single scale value (scales_offset=0 will be used)
+
+    // Cast to appropriate types
+    const auto* packed_weights = reinterpret_cast<const uint8_t*>(QuantBData);
+    float* act_output = reinterpret_cast<float*>(C);
+
+    // Parallelize over the 2D chunk grid
+    MlasTrySimpleParallel(
+        threadpool,
+        total_chunks,
+        [&](ptrdiff_t current_chunk) {
+            // Decompose linear chunk index into 2D coordinates
+            const size_t ith0 = current_chunk % nchunk0;  // Chunk in dimension 0 (output rows)
+            const size_t ith1 = current_chunk / nchunk0;  // Chunk in dimension 1 (batch)
+
+            // Calculate ranges for this chunk
+            const size_t ir0_start = ChunkSize0 * ith0;
+            const size_t ir0_end = std::min(ir0_start + ChunkSize0, OutputRows);
+
+            const size_t ir1_start = ChunkSize1 * ith1;
+            const size_t ir1_end = std::min(ir1_start + ChunkSize1, OutputCols);
+
+            // Process all tiles in dimension 0 for this chunk
+            for (size_t ichunk0 = ir0_start / ChunkSize0; ichunk0 < ir0_end / ChunkSize0; ichunk0++) {
+                // Calculate weight offsets
+                const size_t w_offset = ichunk0 * w_chunk_size;
+                const size_t scales_offset = ichunk0 * scales_size_per_tile;
+
+                // Process all batch items in this chunk
+                for (size_t ine11 = ir1_start; ine11 < ir1_end; ine11++) {
+                    // Calculate LUT offsets for this batch item
+                    const size_t qlut_offset = K * ine11 * 4;
+                    const size_t lut_scales_offset = lut_scales_size * ine11;
+
+                    // Calculate output offset
+                    const size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0;
+
+                    // Call the dispatch function to compute this tile
+                    // Note M and N are swapped in TMAC terminology
+                    Dispatch->ComputeGemm(
+                        packed_weights + w_offset,        // Weight tile
+                        QuantBScale + scales_offset,      // Weight scales for this tile
+                        qlut + qlut_offset,               // LUT for this batch row
+                        lut_scales + lut_scales_offset,   // LUT scales
+                        lut_biases + lut_scales_offset,   // LUT biases
+                        act_output + dst_offset,          // Output location
+                        static_cast<int>(K),              // K dimension
+                        static_cast<int>(N),              // N dimension
+                        static_cast<int>(1),              // M dimension (processing one batch item at a time)
+                        BlkLen,                           // Weight quantization group size
+                        HasZeroPoint                      // Whether zero points are used
+                    );
+                }
+            }
+        }
+    );
+}
diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h
new file mode 100644
index 0000000000000..ef4d01a2c5809
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/qlutgemm.h
@@ -0,0 +1,84 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    qlutgemm.h
+
+Abstract:
+
+    This module includes kernel function prototypes and helper functions for
+    implementing LUT-based GEMM.
+--*/
+
+#pragma once
+
+#include "mlas_qnbit.h"
+#include "mlasi.h"
+
+/**
+ * @brief Parameters for TMAC kernel
+ */
+struct MlasTMACKernelParams {
+    size_t g;
+    size_t ngroups_per_elem;
+    size_t q_group_size;
+    size_t act_group_size;
+
+    size_t kfactor;
+    size_t bits;
+    size_t actk;
+    size_t bm;
+    size_t simd_n_in;
+    size_t simd_n_out;
+    size_t chunk_n;
+    size_t n_tiles_num;
+
+    bool has_scale;
+    bool has_zero_point;
+    bool one_scale;
+};
+
+const MlasTMACKernelParams&
+MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point);
+
+typedef void(MLAS_QNBIT_GEMM_LUT_GEN)(
+    const float* b,
+    int8_t* qlut,
+    float* lut_scales,
+    float* lut_biases,
+    size_t M,
+    size_t K,
+    size_t N,
+    size_t act_group_size
+);
+
+typedef void(MLAS_QNBIT_LUT_GEMM_COMPUTE)(
+    const uint8_t* weights,
+    const float* scales,
+    const int8_t* LUT,
+    const float* LUT_Scales,
+    const float* LUT_Biases,
+    float* C,
+    int K,
+    int M,  // batch size (number of rows in activation)
+    int N,
+    size_t BlkLen,
+    bool HasZeroPoint
+);
+
+//
+// Kernel dispatch structure.
+//
+// NOTE: This name must match the forward declaration in mlasi.h:
+//   struct MLAS_QNBIT_LUT_GEMM_DISPATCH;
+// Keep it minimal for now; extend with function pointers as kernels are added.
+struct MLAS_QNBIT_LUT_GEMM_DISPATCH {
+    // Function pointers for LUT generation and LUT-based GEMM computation.
+    MLAS_QNBIT_GEMM_LUT_GEN* GenerateLUT = nullptr;
+
+    MLAS_QNBIT_LUT_GEMM_COMPUTE* ComputeGemm = nullptr;
+};
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
new file mode 100644
index 0000000000000..b54f051ca1504
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
@@ -0,0 +1,671 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    sqnbitgemm_lut_kernel_avx2.cpp
+
+Abstract:
+
+    This module implements x64 AVX2 kernel functions for LUT-based quantized
+    n-bit integer matrix multiplication.
+
+    It provides optimized AVX2 implementations for lookup table generation,
+    GEMM computation, and related operations on quantized weight and activation
+    matrices.
+ +--*/ + +#include +#include +#include +// AVX2 intrinsics +#include + +#include "qlutgemm.h" +#include "qnbitgemm.h" +#include "sqnbitgemm_q8_block.h" + +static inline float +_mm256_addv_ps(const __m256 v) +{ + __m128 res = _mm256_extractf128_ps(v, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(v)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// Conditional pragma unroll for compiler compatibility +#if defined(__INTEL_COMPILER) || defined(__clang__) +#define PRAGMA_UNROLL _Pragma("unroll") +#else +#define PRAGMA_UNROLL +#endif + +// Helper macros for extracting and widening vectors +#define extract_low_epi8_epi16(v) _mm256_cvtepi8_epi16(_mm256_castsi256_si128(v)) +#define extract_high_epi8_epi16(v) _mm256_cvtepi8_epi16(_mm256_extracti128_si256(v, 1)) +#define extract_low_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_castsi256_si128(v)) +#define extract_high_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_extracti128_si256(v, 1)) + +// Template classes for accumulation +template +struct SignedHalvingAdder { + SignedHalvingAdder adder; + __m256i lhs = _mm256_setzero_si256(); + + inline void push(__m256i v, int k) + { + if (k < N / 2) { + adder.push(v, k); + if (k == N / 2 - 1) { + lhs = adder.get(); + } + } else { + adder.push(v, k - N / 2); + if (k == N - 1) { + lhs = _mm256_avg_epu8(lhs, adder.get()); + } + } + } + + inline __m256i get() + { + return lhs; + } + + inline __m256i get_low() + { + return extract_low_epi8_epi16(lhs); + } + + inline __m256i get_high() + { + return extract_high_epi8_epi16(lhs); + } +}; + +template <> +struct SignedHalvingAdder<2> { + __m256i lhs = _mm256_setzero_si256(); + + inline void push(__m256i v, int k) + { + if (k == 0) { + lhs = v; + } else { + lhs = _mm256_avg_epu8(lhs, v); + } + } + + inline __m256i get() + { + return lhs; + } + + inline __m256i get_low() + { + return extract_low_epi8_epi16(lhs); + } + + inline __m256i get_high() + { + return extract_high_epi8_epi16(lhs); + } +}; + +template +struct SignedWideningAdder { + __m256i lhs_low = _mm256_setzero_si256(); + __m256i lhs_high = _mm256_setzero_si256(); + + inline void push(__m256i v, int k) + { + if (k == 0) { + lhs_low = extract_low_epi8_epi16(v); + lhs_high = extract_high_epi8_epi16(v); + } else { + lhs_low = _mm256_add_epi16(lhs_low, extract_low_epi8_epi16(v)); + lhs_high = _mm256_add_epi16(lhs_high, extract_high_epi8_epi16(v)); + } + } + + inline __m256i get_low() + { + return lhs_low; + } + + inline __m256i get_high() + { + return lhs_high; + } +}; + +template +using SignedAdder = typename std::conditional, SignedWideningAdder>::type; + +// Template for computing log2 at compile time +template +struct mylog2 { + enum { + value = 1 + mylog2::value + }; +}; + +template <> +struct mylog2<0> { + enum { + value = -1 + }; +}; + +// Template for computing bias scale at compile time +template +constexpr int +get_bias_scale() +{ + // The bias scale will be added to the first bit + // 15 = (1/2 + 1 + 2 + 4) / (1/2) + // 7 = (1/2 + 1 + 2) / (1/2) + // 3 = (1/2 + 1) / (1/2) + // 1 = (1/2) / (1/2) + // if constexpr (bits == 4) { + // return 15; + // } else if constexpr (bits == 3) { + // return 7; + // } else if constexpr (bits == 2) { + // return 3; + // } else if constexpr (bits == 1) { + // return 1; + // } else { + // return 0; + // } + return 3; +} + +void +partial_max_g4_int8_k8(float* lut_scales, const float* b) +{ + // TODO(vraspar): add support for arm neon + const __m256i vec_bi = _mm256_set_epi32(112, 96, 
80, 64, 48, 32, 16, 0); + __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1); + __m256 vec_b3 = _mm256_i32gather_ps(b + 3, vec_bi, 1); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + __m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0); + __m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1); + __m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2); + __m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3); + __m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3)); + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + float scales = _mm_cvtss_f32(max4) / 127; + *lut_scales = std::max(*lut_scales, scales); +} + +// Current implementation requires (K * 4) == act_group_size and K >= 8 +// s0 = -1, s1 = 1 +// TODO: loop K +inline void +lut_ctor_g4_int8_impl( + int32_t act_k, + int8_t* qlut, + const float* b, + float* lut_scales, + float* lut_biases +) +{ + __m256 vec_lut[16]; + float biases = 0.0; + const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); + float scales = *lut_scales; + float t_scales = scales ? 1.0f / scales : 0.0f; + + for (int k = 0; k < act_k / 32; ++k) { + __m256 vec_b0 = _mm256_i32gather_ps(b + k * 32 + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + k * 32 + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + k * 32 + 2, vec_bi, 1); + __m256 vec_b3 = _mm256_i32gather_ps(b + k * 32 + 3, vec_bi, 1); + + PRAGMA_UNROLL + for (int g = 1; g < 16; g += 2) { + vec_lut[g] = vec_b0; + if (g & 0b0010) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b1); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b1); + } + if (g & 0b0100) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b2); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b2); + } + if (g & 0b1000) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b3); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b3); + } + } + PRAGMA_UNROLL + for (int g = 0; g < 16; g += 2) { + // vec_lut[g] = -vec_lut[15 - g]; + const __m256 neg_mask = _mm256_set1_ps(-0.0f); // all lanes have sign bit set + vec_lut[g] = _mm256_xor_ps(vec_lut[15 - g], neg_mask); + } + + biases += _mm256_addv_ps(vec_lut[0]); + + PRAGMA_UNROLL + for (int g = 0; g < 16; ++g) { + vec_lut[g] = _mm256_mul_ps(vec_lut[g], _mm256_set1_ps(t_scales)); + } + + __m256i vec_qlut[4]; + const __m256i shuf = _mm256_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + PRAGMA_UNROLL + for (int g = 0; g < 4; g += 1) { + __m256i i0 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 0], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i1 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 1], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i2 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 2], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i3 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 3], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + i0 = _mm256_packs_epi32(i0, i1); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32(i2, i3); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16(i0, i2); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 
18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + vec_qlut[g] = _mm256_shuffle_epi8(i0, shuf); // 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31 + } + + int32_t* qlut_i32 = reinterpret_cast(qlut); + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 0 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 0); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 1 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 1); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 2 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 2); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 3 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 3); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 4 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 4); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 5 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 5); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 6 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 6); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 7 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 7); + } + } + + *lut_scales = scales; + *lut_biases = biases; +} + +// based on lut_ctor_g4_int8_impl +void +GenerateLUT_avx2( + const float* b, + int8_t* qlut, + float* lut_scales, + float* lut_biases, + size_t M, + size_t K, + size_t N, + size_t act_group_size +) +{ + (void)M; // silence unused parameter warning + (void)N; // silence unused parameter warning + // TODO: handle bitnet here + const int32_t kk_outer_max = static_cast(K / act_group_size); + const int32_t ags_div32 = static_cast(act_group_size / 32); + + for (int32_t kk_outer = 0; kk_outer < kk_outer_max; ++kk_outer) { + // compute partial max - directly reset scale to 0.0 + lut_scales[kk_outer] = 0.0f; // partial max reset + for (int32_t k_outer = 0; k_outer < ags_div32; ++k_outer) { + partial_max_g4_int8_k8(&lut_scales[kk_outer], &b[(kk_outer * act_group_size) + (k_outer * 32)]); + } + } + + for (int32_t k_outer_1 = 0; k_outer_1 < kk_outer_max; ++k_outer_1) { + lut_ctor_g4_int8_impl(static_cast(act_group_size), (&(qlut[(k_outer_1 * act_group_size * 4)])), (&(b[(k_outer_1 * act_group_size)])), (&(lut_scales[k_outer_1])), (&(lut_biases[k_outer_1]))); + } +} + +inline void +tbl_g4_int8_float_gather_bit2_impl(int32_t m, float* C_global, float* CBits, float* C) +{ + constexpr int32_t bits = 2; + + int32_t m_c_outer_max = m / 32; + for (int32_t m_c_outer = 0; m_c_outer < m_c_outer_max; ++m_c_outer) { + int32_t cse_var_2 = (m_c_outer * 32 * bits); + int32_t cse_var_1 = (m_c_outer * 32); + PRAGMA_UNROLL + for (int32_t m_c_inner = 0; m_c_inner < 32; ++m_c_inner) { + int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8); + int32_t bit_offset_1 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 8; + C_global[cse_var_1 + m_c_inner] = (CBits[cse_var_2 + bit_offset_0] * (float)5.000000e-01f) + (CBits[cse_var_2 + bit_offset_1]); + } + } + + for (int32_t m_inner_outer = 0; m_inner_outer < m_c_outer_max; ++m_inner_outer) { + PRAGMA_UNROLL + for (int32_t m_inner = 0; m_inner < 32; ++m_inner) { + int offset = m_inner_outer * 32 + m_inner; + C[offset] = C_global[offset]; + } + } +} + +// When FastAggregation is enabled, FastAggregationK = ActK +// zero_points is merged into scales to maintain API +template +inline int32_t +tbl_g4_int8_float_update_impl(int32_t m, float* c, const 
int8_t* lut, const uint8_t* a, const float* scales, const float* lut_scales, const float* lut_biases) +{ + const __m128i vec_mask = _mm_set1_epi8(0x0f); + __m128i vec_lut[K]; + + PRAGMA_UNROLL + for (int k = 0; k < K; k++) { + vec_lut[k] = _mm_loadu_si128(reinterpret_cast(lut + k * 16)); + } + + SignedAdder adder; + for (int i = 0; i < m / 2; i += 16) { + __m256 vec_c0 = _mm256_setzero_ps(); + __m256 vec_c1 = _mm256_setzero_ps(); + __m256 vec_c2 = _mm256_setzero_ps(); + __m256 vec_c3 = _mm256_setzero_ps(); + + float partial_sum = -0.0f; + PRAGMA_UNROLL + for (int kk = 0; kk < K; kk += ActK) { + PRAGMA_UNROLL + for (int k = 0; k < ActK; k++) { + // (M // bm, KK / K / 4, bm / 16 / 2, K * 16) + __m128i vec_as = _mm_loadu_si128(reinterpret_cast(a + i * K + (kk + k) * 16)); + __m128i vec_a_bot = _mm_and_si128(vec_as, vec_mask); + __m128i vec_a_top = _mm_and_si128(_mm_srli_epi16(vec_as, 4), vec_mask); + + __m256i vec_lut_ = _mm256_set_m128i(vec_lut[kk + k], vec_lut[kk + k]); + __m256i vec_a = _mm256_set_m128i(vec_a_top, vec_a_bot); + __m256i vec_v = _mm256_shuffle_epi8(vec_lut_, vec_a); + adder.push(vec_v, k); + } + + __m256 vec_v_low_low = _mm256_cvtepi32_ps(extract_low_epi16_epi32(adder.get_low())); + __m256 vec_v_low_high = _mm256_cvtepi32_ps(extract_high_epi16_epi32(adder.get_low())); + __m256 vec_v_high_low = _mm256_cvtepi32_ps(extract_low_epi16_epi32(adder.get_high())); + __m256 vec_v_high_high = _mm256_cvtepi32_ps(extract_high_epi16_epi32(adder.get_high())); + + float lut_s = lut_scales[kk / ActK]; + float lut_b = lut_biases[kk / ActK]; + + partial_sum += lut_b; + + if (FastAggregation) { + lut_s = lut_s * ActK; + lut_b -= lut_s * (mylog2::value / 4 * get_bias_scale()); + } + +#define lut_fma(vs, ib) \ + ((ib) % Bits) ? (_mm256_mul_ps((vs), _mm256_set1_ps(lut_s))) \ + : (_mm256_fmadd_ps((vs), _mm256_set1_ps(lut_s), _mm256_set1_ps(lut_b))) + if (kk == 0) { + vec_c0 = lut_fma(vec_v_low_low, (i / 4)); + vec_c1 = lut_fma(vec_v_low_high, (i / 4 + 1)); + vec_c2 = lut_fma(vec_v_high_low, (i / 4 + 2)); + vec_c3 = lut_fma(vec_v_high_high, (i / 4 + 3)); + } else { + vec_c0 = _mm256_add_ps(vec_c0, lut_fma(vec_v_low_low, (i / 4))); + vec_c1 = _mm256_add_ps(vec_c1, lut_fma(vec_v_low_high, (i / 4 + 1))); + vec_c2 = _mm256_add_ps(vec_c2, lut_fma(vec_v_high_low, (i / 4 + 2))); + vec_c3 = _mm256_add_ps(vec_c3, lut_fma(vec_v_high_high, (i / 4 + 3))); + } +#undef lut_fma + } + + if (ZeroPoint) { + __m256 vec_s0 = _mm256_loadu_ps(scales + ((i / 4) / Bits) * 16); + __m256 vec_s1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 16); + __m256 vec_s2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 16); + __m256 vec_s3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 16); + vec_c0 = _mm256_fmadd_ps(vec_c0, vec_s0, _mm256_loadu_ps(c + i * 2)); + vec_c1 = _mm256_fmadd_ps(vec_c1, vec_s1, _mm256_loadu_ps(c + i * 2 + 8)); + vec_c2 = _mm256_fmadd_ps(vec_c2, vec_s2, _mm256_loadu_ps(c + i * 2 + 16)); + vec_c3 = _mm256_fmadd_ps(vec_c3, vec_s3, _mm256_loadu_ps(c + i * 2 + 24)); + __m256 vec_z0 = _mm256_loadu_ps(scales + ((i / 4) / Bits) * 16 + 8); + __m256 vec_z1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 16 + 8); + __m256 vec_z2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 16 + 8); + __m256 vec_z3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 16 + 8); + partial_sum *= 2; +#define add_zero(cs, zs, ib) \ + ((ib) % Bits) ? 
((cs)) \ + : (_mm256_fmadd_ps((zs), _mm256_set1_ps(partial_sum), (cs))) + _mm256_storeu_ps(c + i * 2, add_zero(vec_c0, vec_z0, (i / 4))); + _mm256_storeu_ps(c + i * 2 + 8, add_zero(vec_c1, vec_z1, (i / 4 + 1))); + _mm256_storeu_ps(c + i * 2 + 16, add_zero(vec_c2, vec_z2, (i / 4 + 2))); + _mm256_storeu_ps(c + i * 2 + 24, add_zero(vec_c3, vec_z3, (i / 4 + 3))); +#undef add_zero + } else if (OneScale) { + float single_scale = scales[0]; + __m256 vec_s = _mm256_set1_ps(single_scale); + _mm256_storeu_ps(c + i * 2, _mm256_fmadd_ps(vec_c0, vec_s, _mm256_loadu_ps(c + i * 2))); + _mm256_storeu_ps(c + i * 2 + 8, _mm256_fmadd_ps(vec_c1, vec_s, _mm256_loadu_ps(c + i * 2 + 8))); + _mm256_storeu_ps(c + i * 2 + 16, _mm256_fmadd_ps(vec_c2, vec_s, _mm256_loadu_ps(c + i * 2 + 16))); + _mm256_storeu_ps(c + i * 2 + 24, _mm256_fmadd_ps(vec_c3, vec_s, _mm256_loadu_ps(c + i * 2 + 24))); + } else { + __m256 vec_s0 = _mm256_loadu_ps(scales + ((i / 4) / Bits) * 8); + __m256 vec_s1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 8); + __m256 vec_s2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 8); + __m256 vec_s3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 8); + _mm256_storeu_ps(c + i * 2, _mm256_fmadd_ps(vec_c0, vec_s0, _mm256_loadu_ps(c + i * 2))); + _mm256_storeu_ps(c + i * 2 + 8, _mm256_fmadd_ps(vec_c1, vec_s1, _mm256_loadu_ps(c + i * 2 + 8))); + _mm256_storeu_ps(c + i * 2 + 16, _mm256_fmadd_ps(vec_c2, vec_s2, _mm256_loadu_ps(c + i * 2 + 16))); + _mm256_storeu_ps(c + i * 2 + 24, _mm256_fmadd_ps(vec_c3, vec_s3, _mm256_loadu_ps(c + i * 2 + 24))); + } + } + + return 0; +} + +int32_t +tbl_int32_reset(int32_t m, int32_t* c) +{ + memset(c, 0, m * sizeof(int32_t)); + return 0; +} + +// based on qgemm_lut_int8_g4 +// Simplified version with hardcoded configuration for 2-bit quantization +void +TMACComputeGemm_avx2( + const uint8_t* A, // Quantized packed weights + const float* Scales, // Weight scales (and optionally zero-points) + const int8_t* LUT, // Pre-computed quantized lookup table + const float* LUT_Scales, // LUT scales from activation quantization + const float* LUT_Biases, // LUT biases from activation quantization + float* C, // Output buffer + int K, + int M, + int N, + size_t BlkLen, // Weight quantization group size (q_group_size) + bool HasZeroPoint +) +{ + // Validate batch size + if (N != 1) { + MLAS_THROW_EX(std::runtime_error, "N > 1 is not supported yet"); + } + + // get kernel config + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(M, K, 2, BlkLen, HasZeroPoint); + + // ==================== CONFIGURATION ==================== + // Fixed parameters for this kernel implementation + bool has_zero_point = tmac_params.has_zero_point; // Whether weights have zero-points (interleaved with scales) + bool one_scale = tmac_params.one_scale; // Whether using single global scale for all weights + + const int32_t bits = static_cast(tmac_params.bits); // 2-bit quantization + const int32_t g = static_cast(tmac_params.g); // Packing group size + const int32_t ngroups_per_elem = static_cast(tmac_params.ngroups_per_elem); // 8 / g = 2 + const int32_t kfactor = static_cast(tmac_params.kfactor); // K-dimension blocking factor + + const bool has_scale = tmac_params.has_scale; // Always use weight scales + + // Parameters derived from inputs + const int32_t q_group_size = static_cast(tmac_params.q_group_size); // Weight quant group size + const int32_t act_group_size = static_cast(tmac_params.act_group_size); // Activation group size (same as weight) + const int32_t actk = 
static_cast<int32_t>(tmac_params.actk);  // CRITICAL: = 16 for BlkLen=64, NOT BlkLen!
+
+    const int32_t bm = static_cast<int32_t>(tmac_params.bm);
+    int32_t m = bm / bits;
+
+    // Validate configuration
+    assert(bm % bits == 0);
+    assert(K % (kfactor * g) == 0);
+    assert(BlkLen % g == 0);
+
+    // ==================== ALLOCATE BUFFERS ====================
+    // Use float for now (can be changed to _Float16 if needed)
+
+    float* CBits = new float[bm];
+    float* C_global = new float[m];
+
+    // Reset accumulator buffer to zero
+    tbl_int32_reset(bm * sizeof(float) / sizeof(int32_t), reinterpret_cast<int32_t*>(CBits));
+
+    // ==================== CALCULATE LOOP PARAMETERS ====================
+    const int32_t k_outer_max = K / (kfactor * g);
+    const int32_t scale_gs = q_group_size / (kfactor * g);
+
+    // Calculate bit shift for scale indexing
+    int32_t scale_idx_shfr = 0;
+    if (scale_gs == 1) {
+        scale_idx_shfr = 0;
+    } else if (scale_gs == 2) {
+        scale_idx_shfr = 1;
+    } else if (scale_gs == 4) {
+        scale_idx_shfr = 2;
+    } else if (scale_gs == 8) {
+        scale_idx_shfr = 3;
+    } else {
+        MLAS_THROW_EX(std::runtime_error,
+                      ("Unsupported scale_gs=" + std::to_string(scale_gs) +
+                       " (q_group_size=" + std::to_string(q_group_size) +
+                       ", kfactor=" + std::to_string(kfactor) +
+                       ", g=" + std::to_string(g) + "). Expected {1,2,4,8}.").c_str());
+    }
+
+    // ==================== MAIN COMPUTATION LOOP ====================
+    for (int32_t k_outer = 0; k_outer < k_outer_max; k_outer++) {
+        // Calculate pointers for this K-outer iteration
+        const uint8_t* a = A + k_outer * bm * kfactor / ngroups_per_elem;
+
+        // Calculate scales pointer based on configuration
+        const float* scales = one_scale ? reinterpret_cast<const float*>(Scales) :  // Single global scale
+                              (has_zero_point ? 
reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m * 2 : // Scale + zero_point pairs + reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m); // Scales only + + // Calculate LUT pointers + const int8_t* lut = reinterpret_cast(LUT) + k_outer * kfactor * (1 << g); // 2^g = 16 for g=4 + const float* lut_scales = reinterpret_cast(LUT_Scales) + + (k_outer * kfactor * g / act_group_size); + const float* lut_biases = reinterpret_cast(LUT_Biases) + + (k_outer * kfactor * g / act_group_size); + + // Select appropriate kernel template based on configuration + // For standard 2-bit, kfactor=16, BlkLen=64: actk = 64/4 = 16 + if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } + // actk == 8 variants (for BlkLen=32) + else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } + // kfactor == 8 variants + else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else { + // No matching kernel template found + MLAS_THROW_EX(std::runtime_error, "No matching kernel found for T-MAC GEMM"); + } + } + + // ==================== GATHER RESULTS ==================== + // Gather bit-plane results into final output + // Only support 2-bit in this implementation + // TODO(vraspar): extend to other bit-widths + tbl_g4_int8_float_gather_bit2_impl(m, C_global, CBits, C); + + // ==================== CLEANUP ==================== + delete[] C_global; + delete[] CBits; +} + +// Kernel dispatch structure definition. 
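+// Expected call sequence from the dispatch consumer (a sketch based on the signatures above;
+// the buffer names are illustrative and allocation/threading is owned by the caller):
+//   GenerateLUT_avx2(activation, qlut, lut_scales, lut_biases, M, K, N, act_group_size);
+//   TMACComputeGemm_avx2(packed_weights, weight_scales, qlut, lut_scales, lut_biases,
+//                        output, K, M, N, BlkLen, has_zero_point);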
+ +const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2 = []() { + MLAS_QNBIT_LUT_GEMM_DISPATCH d; + d.GenerateLUT = GenerateLUT_avx2; + d.ComputeGemm = TMACComputeGemm_avx2; + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h new file mode 100644 index 0000000000000..e66eec6fd67ea --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h @@ -0,0 +1,43 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_lut_kernel_avx2.h + +Abstract: + + This module implements x64 AVX2 kernel functions for LUT-based n-bit + quantized integer matrix multiplication. +--*/ + +#pragma once +#include "qnbitgemm.h" + +void +GenerateLUT_avx2( + int32_t group_size, + int8_t lut, + const float* b, + float* scales, + float* biases, + int K +); + +void +TMACComputeGemm_avx2( + const void* A, + const void* a_scales, + const void* LUT, + const void* LUT_Scales, + const void* LUT_Biases, + void* C, + int bm, + int K, + int M, + int N, + size_t BlkLen +); diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 5a72ecb6849c3..fe30cc3f51d85 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -66,12 +66,12 @@ void QuantizeMatMulNBitsBlockwise( tp.get()); } -template -bool QuantizeQDQMatMul4BitsBlockwise( - py::array_t dst, // shape: [K, N / 2] - py::array_t src, // shape: [K, N] - py::array_t scale, // shape: [block_per_K, N] - py::array_t zero_points, // shape: [block_per_K, N / 2] +template +bool QuantizeQDQMatMulNBitsBlockwise( + py::array_t dst, + py::array_t src, + py::array_t scale, + py::array_t zero_points, int32_t quant_block_size, int32_t N, int32_t K, @@ -85,7 +85,7 @@ bool QuantizeQDQMatMul4BitsBlockwise( py::buffer_info scale_buf = scale.request(); py::buffer_info zp_buf = zero_points.request(); - return MlasQDQQuantizeBlockwise( + return MlasQDQQuantizeBlockwise( reinterpret_cast(src_buf.ptr), reinterpret_cast(scale_buf.ptr), is_symmetric ? 
nullptr : reinterpret_cast(zp_buf.ptr), @@ -97,6 +97,19 @@ bool QuantizeQDQMatMul4BitsBlockwise( tp.get()); } +template +bool QuantizeQDQMatMul4BitsBlockwise( + py::array_t dst, // shape: [K, N / 2] + py::array_t src, // shape: [K, N] + py::array_t scale, // shape: [block_per_K, N] + py::array_t zero_points, // shape: [block_per_K, N / 2] + int32_t quant_block_size, + int32_t N, + int32_t K, + bool is_symmetric) { + return QuantizeQDQMatMulNBitsBlockwise(dst, src, scale, zero_points, quant_block_size, N, K, is_symmetric); +} + template void QuantizeMatMulBnb4Blockwise( py::array_t dst, @@ -134,6 +147,8 @@ void CreateQuantPybindModule(py::module& m) { m.def("quantize_matmul_8bits", &QuantizeMatMulNBitsBlockwise); m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); + m.def("quantize_qdq_matmul_2bits", &QuantizeQDQMatMulNBitsBlockwise); + m.def("quantize_qdq_matmul_2bits", &QuantizeQDQMatMulNBitsBlockwise); m.def("quantize_qdq_matmul_4bits", &QuantizeQDQMatMul4BitsBlockwise); m.def("quantize_qdq_matmul_4bits", &QuantizeQDQMatMul4BitsBlockwise); } diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index cfdce9479843c..1837c1f4ffa7f 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -15,6 +15,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/mlas/inc/mlas.h" #include "core/session/inference_session.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/unittest_util/framework_test_utils.h" @@ -249,45 +250,200 @@ void TestMatMul2BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) { } // namespace -template -struct TypedTestParams { - static constexpr int batch_size = BatchSize; - static constexpr int M = MVal; - static constexpr int N = NVal; - static constexpr int K = KVal; -}; +template +void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size, + bool has_zero_point, float abs_error = 0.15f, float rel_error = 0.05f) { + if (K % 32 != 0 || N % 128 != 0 || block_size % 32 != 0) { + GTEST_SKIP() << "LUT GEMM requires K multiple of 32, N multiple of 128, block_size multiple of 32"; + } -using TestTypes = ::testing::Types< - TypedTestParams<1, 1, 16, 16>, - TypedTestParams<1, 2, 16, 16>, - TypedTestParams<1, 32, 16, 16>, - TypedTestParams<1, 32, 32, 16>, - TypedTestParams<1, 32, 16, 128>, - TypedTestParams<1, 288, 16, 16>, - TypedTestParams<4, 1, 16, 16>, - TypedTestParams<4, 2, 16, 16>, - TypedTestParams<4, 32, 16, 16>, - TypedTestParams<4, 32, 32, 16>, - TypedTestParams<4, 32, 16, 128>, - TypedTestParams<4, 288, 16, 16>>; - -template -class MatMulNBits : public ::testing::Test { - public: - static constexpr int batch_size = T::batch_size; - static constexpr int M = T::M; - static constexpr int N = T::N; - static constexpr int K = T::K; -}; + if (!MlasIsLutGemmAvailable(static_cast(N), static_cast(K), 2, static_cast(block_size))) { + GTEST_SKIP() << "LUT GEMM not available on this platform"; + } + + RandomValueGenerator random{1234}; + std::vector input0_fp32_vals(random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f)); + std::vector input1_fp32_vals(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape(static_cast(block_size), /* columnwise */ true, + static_cast(K), static_cast(N), + q_rows, q_cols); + + 
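+  // Size the quantized weight, scale, and (optional) zero-point buffers before running the
+  // blockwise quantizer; the zero-point buffer is only consumed when has_zero_point is set.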
size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes(static_cast(block_size), /* columnwise */ true, + static_cast(K), static_cast(N), + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + std::vector input1_vals(q_data_size_in_bytes); + std::vector scales(q_scale_size); + std::vector zp(q_zp_size_in_bytes); + + auto& ortenv = **ort_env.get(); + onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); + + MlasQuantizeBlockwise( + input1_vals.data(), + scales.data(), + has_zero_point ? zp.data() : nullptr, + input1_fp32_vals.data(), + static_cast(block_size), + true, + static_cast(K), + static_cast(N), + static_cast(N), + tp); + + // Dequantize for reference computation + MlasDequantizeBlockwise( + input1_fp32_vals.data(), + input1_vals.data(), + scales.data(), + has_zero_point ? zp.data() : nullptr, + static_cast(block_size), + true, + static_cast(K), + static_cast(N), + tp); -TYPED_TEST_SUITE(MatMulNBits, TestTypes); + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_fp32_vals[m * K + k] * input1_fp32_vals[n * K + k]; + } + expected_vals[m * N + n] = sum; + } + } + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", static_cast(0)); + + if constexpr (std::is_same::value) { + test.AddInput("A", {M, K}, input0_fp32_vals, false); + } + + int64_t k_blocks = (K + block_size - 1) / block_size; + test.AddInput("B", {q_cols, k_blocks, q_rows / k_blocks}, input1_vals, true); + + if constexpr (std::is_same::value) { + test.AddInput("scales", {N, static_cast(q_scale_size) / N}, scales, true); + } + + if (has_zero_point) { + test.AddInput("zero_points", {N, static_cast(q_zp_size_in_bytes) / N}, zp, true); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + + if constexpr (std::is_same::value) { + test.AddOutput("Y", {M, N}, expected_vals); + } + + test.SetOutputAbsErr("Y", abs_error); + test.SetOutputRelErr("Y", rel_error); + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsMlasLutGemm, "1")); + + test.Config(so) + .ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_128x128) { + TestMatMul2BitsLutGemm(1, 128, 128, 32, false); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x128) { + TestMatMul2BitsLutGemm(1, 128, 128, 32, true); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256) { + TestMatMul2BitsLutGemm(1, 256, 256, 32, false); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256) { + TestMatMul2BitsLutGemm(1, 256, 256, 32, true); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256_BlkLen64) { + TestMatMul2BitsLutGemm(1, 256, 256, 64, false); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256_BlkLen64) { + TestMatMul2BitsLutGemm(1, 256, 256, 64, true); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_128x256_BlkLen128) { + TestMatMul2BitsLutGemm(1, 128, 256, 128, false); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x256_BlkLen128) { + TestMatMul2BitsLutGemm(1, 128, 256, 128, true); +} + +// Batch tests (M > 1) +TEST(MatMulNBitsLutGemm, 
Float32_2Bits_Symmetric_Batch32_128x128) { + TestMatMul2BitsLutGemm(32, 128, 128, 32, false); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256) { + TestMatMul2BitsLutGemm(32, 256, 256, 32, true); +} -TYPED_TEST(MatMulNBits, Float32_2Bits_Accuracy0) { - TestMatMul2BitsTyped(); +TEST(MatMul2Bits, Float32_2b_Accuracy0) { + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); } -TYPED_TEST(MatMulNBits, Float32_2Bits_Accuracy4) { - TestMatMul2BitsTyped(); +TEST(MatMul2Bits, Float32_2b_Accuracy4) { + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); } } // namespace test diff --git a/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp new file mode 100644 index 0000000000000..12ec5ec78f599 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp @@ -0,0 +1,240 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sqlutgemm.cpp + +Abstract: + + Tests for MLAS LUT-based n-bit GEMM (TMAC/LUT path) for 2-bit.. + +--*/ + +#include "test_util.h" +#include "mlas_qnbit.h" +#include "mlas_q4.h" + +// Generic template to future-proof for different bit widths; instantiate with 2 for now. +template +class MlasSQLutGemmTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + + MatrixGuardBuffer BufferQuantBData; + MatrixGuardBuffer BufferQuantBScale; + MatrixGuardBuffer BufferQuantBZeroPoint; + MatrixGuardBuffer BufferDequantizedB; + MatrixGuardBuffer BufferPackedB; // Single buffer for packed weights and scales + + void CallReferenceGemm(size_t M, + size_t N, + size_t K, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C) { + float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N); + MlasDequantizeBlockwise( + DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), GetMlasThreadPool()); + + // Note: DequantizedBData is in column major layout. + + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const float* a = A + m * K; + const float* b = DequantizedBData + n * K; + float* c = C + (m * N) + n; + float sum = 0.0f; + for (size_t k = 0; k < K; k++) { + sum += (*a) * (*b); + b += 1; + a += 1; + } + *c = sum; + } + } + } + + public: + void Test(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) { + MLAS_THREADPOOL* tp = WithThreadpool ? 
GetMlasThreadPool() : nullptr; + + // Clear config cache to ensure fresh config for each test case + MlasClearLutGemmKernelConfig(); + + const float* A = BufferA.GetBuffer(K * M); + const float* B = BufferB.GetBuffer(N * K); + float* C = BufferC.GetBuffer(N * M, true); + float* CReference = BufferCReference.GetBuffer(N * M, true); + + // quantize B + uint8_t* QuantBData = nullptr; + float* QuantBScale = nullptr; + uint8_t* QuantBZeroPoint = nullptr; + + { + size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; + MlasBlockwiseQuantizedBufferSizes(BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + + QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes); + QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize); + if (!Symmetric) { + QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes); + } + + MlasQuantizeBlockwise(QuantBData, QuantBScale, QuantBZeroPoint, + B, BlkLen, + /* columnwise */ true, + static_cast(K), static_cast(N), + static_cast(N), + GetMlasThreadPool()); + } + + MlasInitLutGemmKernelConfig(N, K, BlkBitWidth, BlkLen, !Symmetric); + + // Use unified packing - single buffer for weights and scales/zp + size_t PackedBufSize = MlasLutGemmPackedSize(N, K, BlkBitWidth, BlkLen, !Symmetric); + std::byte* PackedBuf = BufferPackedB.GetBuffer(PackedBufSize); + + MlasLutGemmPack( + N, + K, + BlkBitWidth, + BlkLen, + !Symmetric, + reinterpret_cast(QuantBData), + QuantBScale, + QuantBZeroPoint, + PackedBuf, + tp); + + MlasLutGemm( + A, + BlkLen, + PackedBuf, + C, + static_cast(K), + static_cast(M), + static_cast(N), + !Symmetric, + tp); + + // Reference computation + CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, CReference); + + size_t f = 0; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + ASSERT_TRUE(CloseEnough(C[f], CReference[f])) + << "Expected: " << CReference[f] << " Actual: " << C[f] << "@[" << m << "x" << n << "], " + << "M=" << M << ", N=" << N << ", K=" << K; + } + } + } + + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SQLutGemm") + + "BlkBitWidth" + std::to_string(BlkBitWidth) + + "BlkLen" + std::to_string(BlkLen); + return suite_name.c_str(); + } +}; + +// Fixture to register parameterized tests quickly +template +class SQLutGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit SQLutGemmShortExecuteTest(size_t M, size_t N, size_t K, + bool WithThreadpool, bool Symmetric) + : M_(M), + N_(N), + K_(K), + WithThreadpool_(WithThreadpool), + Symmetric_(Symmetric) { + } + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test( + M_, N_, K_, WithThreadpool_, Symmetric_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) { + if (!MlasIsLutGemmAvailable(N, K, BlkBitWidth, BlkLen)) { + return 0; + } + + if (M < BlkLen || N < BlkLen) { + return 0; + } + + std::stringstream ss; + ss << (WithThreadpool ? "Threaded" : "SingleThread") + << "/isSymmetric" << Symmetric + << "/M" << M << "xN" << N << "xK" << K; + + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSQLutGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. 
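+        // The factory lambda captures the shape by value; googletest invokes it to build a
+        // fresh fixture instance when this registered test is run.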
+        [=]() -> MlasTestFixture<MlasSQLutGemmTest<BlkBitWidth, BlkLen>>* {
+          return new SQLutGemmShortExecuteTest<BlkBitWidth, BlkLen>(
+              M, N, K, WithThreadpool, Symmetric);
+        });
+
+    return 1;
+  }
+
+  static size_t RegisterShortExecuteTests() {
+    size_t count = 0;
+    for (bool with_threadpool : {true}) {
+      for (bool symmetric : {true, false}) {  // Test both symmetric and asymmetric
+        for (size_t b = 256; b < 320; b += 32) {
+          count += RegisterSingleTest(b, b, b, with_threadpool, symmetric);
+        }
+
+        count += RegisterSingleTest(64, 128, 128, with_threadpool, symmetric);
+        count += RegisterSingleTest(128, 256, 256, with_threadpool, symmetric);
+      }
+    }
+    return count;
+  }
+
+ private:
+  size_t M_, N_, K_;
+  bool WithThreadpool_, Symmetric_;
+};
+
+static size_t SQLutGemmRegisterAllShortExecuteTests() {
+  size_t count = 0;
+  count += SQLutGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests();
+  count += SQLutGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests();
+  count += SQLutGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests();
+  count += SQLutGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests();
+  return count;
+}
+
+static UNUSED_VARIABLE bool added_to_main = AddTestRegister(
+    [](bool is_short_execute) -> size_t {
+      if (is_short_execute) {
+        return SQLutGemmRegisterAllShortExecuteTests();
+      }
+      return 0;
+    });
diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_2bits.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_2bits.py
new file mode 100644
index 0000000000000..a7a130654407a
--- /dev/null
+++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_2bits.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# -------------------------------------------------------------------------- + +import unittest + +import numpy as np +import numpy.typing as npt + + +def dequantize_blockwise_2bits(quant_values, scale, zero_point, valid_len): + blob_size = quant_values.shape[0] + block_size = blob_size * 4 + + quant_float = np.zeros((block_size), dtype=scale.dtype) + for b in range(blob_size): + v = quant_values[b] + quant_float[4 * b] = ((v & 0x3) - zero_point) * scale if 4 * b < valid_len else 0.0 + quant_float[4 * b + 1] = (((v >> 2) & 0x3) - zero_point) * scale if 4 * b + 1 < valid_len else 0.0 + quant_float[4 * b + 2] = (((v >> 4) & 0x3) - zero_point) * scale if 4 * b + 2 < valid_len else 0.0 + quant_float[4 * b + 3] = (((v >> 6) & 0x3) - zero_point) * scale if 4 * b + 3 < valid_len else 0.0 + return quant_float + + +def quantize_blockwise_2bits_ref(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool): + if len(matrix_float.shape) != 2: + raise ValueError("Current int2 block quantization only supports 2D tensors!") + rows, cols = matrix_float.shape + + blob_size = block_size // 4 + k_blocks = (rows + block_size - 1) // block_size + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + matrix_float_padded = matrix_float + if pad_len > 0: + matrix_float_padded = np.pad(matrix_float, ((0, pad_len), (0, 0)), "constant") + + packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") + scales = np.zeros((cols, k_blocks), dtype=matrix_float_padded.dtype) + zero_point = np.full((cols, (k_blocks + 3) // 4), 0xAA, dtype="uint8") + + matrix_float_padded = np.transpose(matrix_float_padded) + for n in range(cols): + for k_id in range(0, rows, block_size): + if is_symmetric: + amax_idx = np.argmax(np.abs(matrix_float_padded[n, k_id : k_id + block_size])) + bmax = np.float32(matrix_float_padded[n, k_id + amax_idx]) + scale = bmax / (-2.0) + zp = 2 + else: + vmin = np.min(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) + vmax = np.max(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) + vmin = min(vmin, 0.0) + vmax = max(vmax, 0.0) + scale = (vmax - vmin) / ((1 << 2) - 1) + zero_point_fp = vmin + if scale != 0.0: + zero_point_fp = 0.0 - vmin / scale + zp = min(3, max(0, round(zero_point_fp))) + + reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 + block_idx = k_id // block_size + scales[n, block_idx] = scale + zp_pair = zero_point[n, block_idx // 4] + zp_idx = block_idx % 4 + zp_masks = [0xFC, 0xF3, 0xCF, 0x3F] + zero_point[n, block_idx // 4] = (zp_pair & zp_masks[zp_idx]) | (zp << (zp_idx * 2)) + + blk_int0 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + blk_int1 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + blk_int2 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 2 : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + blk_int3 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 3 : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + packed[n, block_idx] = np.bitwise_or( + np.bitwise_or(blk_int0, np.left_shift(blk_int1, 2)), + np.bitwise_or(np.left_shift(blk_int2, 4), np.left_shift(blk_int3, 6)), + ) + + return (packed, scales, zero_point) + + +def quantize_blockwise_2bits_target(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool): + if len(matrix_float.shape) 
!= 2: + raise ValueError("Current int2 block quantization only supports 2D tensors!") + rows, cols = matrix_float.shape + + k_blocks = (rows + block_size - 1) // block_size + packed = np.zeros((cols, k_blocks, block_size // 4), dtype="uint8") + scales = np.zeros((cols, k_blocks), dtype=matrix_float.dtype) + zero_point = np.full((cols, (k_blocks + 3) // 4), 0xAA, dtype="uint8") + from onnxruntime.capi._pybind_state import quantize_matmul_2bits # noqa: PLC0415 + + quantize_matmul_2bits(packed, matrix_float, scales, zero_point, block_size, cols, rows, is_symmetric) + return (packed, scales, zero_point) + + +class TestQuantizeBlockwise2Bits(unittest.TestCase): + def test_quantize_blockwise_2bits(self): + for rows, cols in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: + for block_size in [16, 32, 64, 128]: + for type in [np.float32, np.float16]: + for is_symmetric in [True, False]: + matrix_float = np.random.rand(rows, cols).astype(type) + quant_value_ref, scales_ref, zero_point_ref = quantize_blockwise_2bits_ref( + matrix_float, block_size, is_symmetric + ) + quant_value, scales, zero_point = quantize_blockwise_2bits_target( + matrix_float, block_size, is_symmetric + ) + assert np.allclose(scales_ref, scales) + assert np.allclose(zero_point_ref, zero_point) + for c in range(quant_value_ref.shape[0]): + for k in range(quant_value_ref.shape[1]): + zp_shift = (k % 4) * 2 + assert np.allclose( + dequantize_blockwise_2bits( + quant_value_ref[c, k], + scales_ref[c, k], + (zero_point_ref[c, k // 4] >> zp_shift) & 0x3, + min(block_size, rows - k * block_size), + ), + dequantize_blockwise_2bits( + quant_value[c, k], + scales[c, k], + (zero_point[c, k // 4] >> zp_shift) & 0x3, + min(block_size, rows - k * block_size), + ), + atol=1.2 * abs(scales[c, k]), + ) + + +if __name__ == "__main__": + unittest.main()