51 commits
5484560  init code structure for matmul 2 bits (liqunfu, Jan 30, 2025)
8c1cfe1  add and pass q4dq tests for q2bit - rename file and test name later (liqunfu, Jan 31, 2025)
f6f22e3  some fixes (liqunfu, Jan 31, 2025)
3e1a951  add apis to neon and other avxs (liqunfu, Feb 3, 2025)
0130061  fix neon build (liqunfu, Feb 3, 2025)
b4aad01  disable 2bit test (liqunfu, Feb 3, 2025)
ff531cb  2 bit quantize to support model builder (liqunfu, Mar 7, 2025)
6849ea2  Merge remote-tracking branch 'msft/main' into carzh/bitnet-reverse-la… (carzh, Jul 16, 2025)
e85431e  fix compile errors (carzh, Jul 17, 2025)
9642740  resolve build failure update (carzh, Jul 18, 2025)
892222a  2 bits check (HectorSVC, Jul 23, 2025)
07b7f3f  fixed bug causing int8 tests to fail (Jul 25, 2025)
5fb2edd  Merge remote-tracking branch 'origin/main' into carzh/bitnet-reverse-… (carzh, Aug 7, 2025)
493ebd1  lintrunner (carzh, Aug 7, 2025)
b4b143f  prepack wip -- not prepacking b data because dispatch to check for ml… (carzh, Aug 13, 2025)
534b8e6  fixed dispatch issue, added acc level 4 tests, and now running into a… (carzh, Aug 15, 2025)
70d6588  deep sigh (Sep 2, 2025)
ad2572b  builds somehow (Sep 4, 2025)
b312815  update (Sep 10, 2025)
bfeac34  udpate (Sep 16, 2025)
a5de108  Implement Pre Packing of qweight for tmac (vraspar, Oct 1, 2025)
7ff8218  Implement Pre packing for Scales and zero points (vraspar, Oct 6, 2025)
6d8e8ec  Transform zero points before interleaving (vraspar, Oct 6, 2025)
5d19daf  Initial implementation of tmac kernel config (vraspar, Oct 7, 2025)
c600056  Move pre packing scales and zp code to qlutgemm and use tmac_params (vraspar, Oct 8, 2025)
5cf99e6  update (Oct 13, 2025)
f9a9b47  bug fixes (Oct 16, 2025)
5687e5e  Fix bug in scale unpacking (vraspar, Oct 21, 2025)
6f08418  Fix issues with TMAC GEMM kernels and remove hard coded variables (vraspar, Oct 28, 2025)
6191aad  Fix bug in LUT table generation (vraspar, Oct 31, 2025)
f2de776  Fix casting issue (vraspar, Nov 10, 2025)
9ef6d75  add session option and clean up (vraspar, Nov 13, 2025)
59c0055  Refactor QNBit GEMM Implementation for AVX2 (vraspar, Dec 1, 2025)
457cfa3  Refactor dispatch (vraspar, Dec 2, 2025)
bdb2982  Add test cases (vraspar, Dec 2, 2025)
289e53e  rewrite test_sqlutgemm.cpp (vraspar, Dec 10, 2025)
fabae08  Add more robust checking before using LUT kernels (vraspar, Dec 11, 2025)
5d8a6ee  Merge remote-tracking branch 'origin/main' into vraspar/lut-gemm (vraspar, Dec 16, 2025)
b1fcda1  revert graph_transform_test.cc (vraspar, Dec 16, 2025)
3eb22b0  Clean up: revert unchanged files (vraspar, Dec 16, 2025)
f61c3d8  Apply linting and clean up (vraspar, Dec 22, 2025)
bebcb64  Add headers, update binding, and general clean up + linting (vraspar, Dec 24, 2025)
6a2e822  Fix zero point test cases (vraspar, Dec 24, 2025)
a19b2f6  Refactor ComputeBPackedLUT to remove unused parameters and simplify f… (vraspar, Jan 2, 2026)
26678b2  Merge remote-tracking branch 'origin/main' into vraspar/lut-gemm (vraspar, Jan 2, 2026)
e5f80cb  Fix compiler warnings (vraspar, Jan 3, 2026)
b518ce9  Improve error handling in TMACComputeGemm_avx2 for batch size and sca… (vraspar, Jan 3, 2026)
f94e51e  Apply feedback and use PrePacking (vraspar, Jan 8, 2026)
7b708ad  update platform.cpp (vraspar, Jan 8, 2026)
58e93ec  use MLAS_THROW_EX for qlutgemm.cpp (vraspar, Jan 9, 2026)
469cde7  Add LUT GEMM 2-bit tests and fix Python quantization reference implem… (Jan 12, 2026)
6 changes: 6 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -45,6 +45,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
${MLAS_SRC_DIR}/qnbitgemm.h
${MLAS_SRC_DIR}/qnbitgemm.cpp
${MLAS_SRC_DIR}/qlutgemm.h
${MLAS_SRC_DIR}/qlutgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
@@ -209,6 +211,8 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
@@ -693,6 +697,8 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h
${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
@@ -374,6 +374,12 @@ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFil
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

// Use LUT (Lookup Table) based GEMM for quantized models when available.
// Option values:
// - "0": Do not use LUT based GEMM. [DEFAULT]
// - "1": Use LUT based GEMM when available.
static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm";

// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
// Refer to MatMulNBits op schema for more details.
// If not provided, default is 4.
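To illustrate the new option, here is a minimal, hedged sketch (not part of this diff) of how an application could opt in through the public C++ API; the key string matches kOrtSessionOptionsMlasLutGemm above, and the model path is a placeholder.

// Minimal sketch: enabling LUT-based GEMM via a session config entry.
// Assumes an ONNX Runtime C++ API installation; the model path is a placeholder.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "lut_gemm_demo"};
  Ort::SessionOptions session_options;
  // Same key as kOrtSessionOptionsMlasLutGemm; "0" (the default) keeps LUT GEMM disabled.
  session_options.AddConfigEntry("mlas.use_lut_gemm", "1");
  Ort::Session session{env, "model_with_matmulnbits.onnx", session_options};
  return 0;
}

Even with the option set, the MatMulNBits kernel only takes the LUT path when MlasIsLutGemmAvailable() accepts the given N/K/bits/block_size combination, as shown in the matmul_nbits.cc changes below.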
139 changes: 127 additions & 12 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -4,6 +4,7 @@
#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

#include <cstdint>
#include <memory>
#include <type_traits>

#include "core/common/common.h"
@@ -15,7 +16,10 @@
#include "core/mlas/inc/mlas_q4.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "contrib_ops/cpu/quantization/matmul_nbits_helper.h"
#include "core/platform/threadpool.h"
#include "core/util/thread_utils.h"

namespace onnxruntime {
namespace contrib {
@@ -100,6 +104,11 @@
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()},
has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()},
prefer_lut_gemm_{info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasLutGemm) == "1" &&
MlasIsLutGemmAvailable(narrow<size_t>(info.GetAttr<int64_t>("N")),
narrow<size_t>(info.GetAttr<int64_t>("K")),
narrow<size_t>(info.GetAttr<int64_t>("bits")),
narrow<size_t>(info.GetAttr<int64_t>("block_size")))},
compute_type_{GetComputeType<T1>(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
const auto& node = info.node();
auto input_defs = node.InputDefs();
@@ -135,6 +144,7 @@
const bool has_g_idx_;
const bool has_bias_;
bool scales_are_packed_{false};
const bool prefer_lut_gemm_{false};
const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_;
bool has_unquantized_zero_point_{false};
const bool column_wise_quant_{true};
@@ -167,6 +177,11 @@
AllocatorPtr& allocator,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const;

Status ComputeBPackedLUT(const Tensor* a,
Tensor* y,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const;
};

template <typename T1>
@@ -179,22 +194,76 @@
return Status::OK();
}

if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && !prefer_lut_gemm_) {
return Status::OK();
}

// Create a temporary threadpool for parallel packing
// This is used during model load time to speed up weight prepacking
[Review comment, Contributor]
What is the overhead like for creating a new threadpool in each call to PrePack()?
I wonder if we should make an existing threadpool available to this code. Perhaps we could pass in the threadpool from SessionState; something to consider, maybe for a future PR.

[Reply, Contributor Author]
I agree, passing a thread pool to PrePack would be clean. I am planning a second PR to improve the prepacking logic in general, and I will include this along with it. :)

std::unique_ptr<concurrency::ThreadPool> temp_threadpool;
concurrency::ThreadPool* threadpool_ptr = nullptr;

// Only create threadpool for LUT GEMM path which can benefit from parallel packing
// TODO: Consider extending threadpool usage to non-LUT path (CompInt8) with appropriate tests

[CI annotation (GitHub Actions / Optional Lint C++), cpplint on line 207: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo]]
if (prefer_lut_gemm_) {
OrtThreadPoolParams tpo;
tpo.thread_pool_size = Env::Default().GetNumPhysicalCpuCores();
tpo.allow_spinning = false; // Don't spin during model load
tpo.auto_set_affinity = false;

temp_threadpool = concurrency::CreateThreadPool(
&Env::Default(),
tpo,
concurrency::ThreadPoolType::INTRA_OP);

threadpool_ptr = temp_threadpool.get();
}

if (input_idx == InputIndex::B) {
const Tensor* scales = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales);

packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_);
if (packed_b_size_ == 0) {
return Status::OK();
if (prefer_lut_gemm_) {
MlasInitLutGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_);

packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_);
if (packed_b_size_ == 0) {
return Status::OK();
}

packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);

const float* scales_ptr = scales ? scales->Data<float>() : nullptr;
const uint8_t* zp_ptr = nullptr;
if (scales_ptr != nullptr && has_zp_input_) {
const Tensor* zero_points = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points);
zp_ptr = zero_points ? zero_points->Data<uint8_t>() : nullptr;
}

MlasLutGemmPack(
N_, K_, nbits_, block_size_, has_zp_input_,
static_cast<const std::byte*>(tensor.DataRaw()),
scales_ptr,
zp_ptr,
static_cast<std::byte*>(packed_b_.get()),
threadpool_ptr);

if (prepacked_weights != nullptr) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
} else {
packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_);
if (packed_b_size_ == 0) {
return Status::OK();
}
auto qptr = tensor.DataRaw();
auto scale_ptr = scales ? scales->DataRaw() : nullptr;
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr,
has_zp_input_, nullptr, threadpool_ptr);
[Review comment, Member]
IIUC, the usage of the threadpool in the existing non-LUT path seems like a new addition. Is that intentional (and does it come with appropriate tests)?

[Reply, Contributor Author]
Initially, I thought the tests in test_sqnbitgemm.cpp should suffice since they already test it with a thread pool. I applied changes to only use the thread pool for the LUT path now.
Once we add tests, I think it might be beneficial to use the thread pool for prepacking in the other paths as well.

}
auto qptr = tensor.DataRaw();
auto scale_ptr = scales ? scales->DataRaw() : nullptr;
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr,
has_zp_input_, nullptr, nullptr);
is_packed = true;
} else if (compute_type_ == SQNBIT_CompInt8) {
// Packing scales and zero points
@@ -230,8 +299,30 @@
is_packed = true;
}
#endif // MLAS_TARGET_ARM64
} else if (prefer_lut_gemm_) {
// Pack scales/zero_points for LUT GEMM if B was already packed but scales weren't available then
if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
auto scales_ptr = tensor.Data<float>();
const uint8_t* zp_ptr = nullptr;
if (has_zp_input_) {
const Tensor* zero_points = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points);
zp_ptr = zero_points ? zero_points->Data<uint8_t>() : nullptr;
}
// Pack only scales (QuantBData is nullptr)
MlasLutGemmPack(
N_, K_, nbits_, block_size_, has_zp_input_,
nullptr, // QuantBData already packed
scales_ptr,
zp_ptr,
static_cast<std::byte*>(packed_b_.get()),
nullptr); // No threadpool needed for scales only
is_packed = false; // scales tensor can be released but not "packed" in the ORT sense
}
}

// Threadpool will be automatically destroyed when temp_threadpool goes out of scope

return Status::OK();
}

@@ -307,14 +398,34 @@
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;

if (input_idx == 1) {
used_shared_buffers = true;
if (input_idx == InputIndex::B && !prepacked_buffers.empty()) {
packed_b_ = std::move(prepacked_buffers[0]);
used_shared_buffers = true;

if (prefer_lut_gemm_) {
MlasInitLutGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_);
packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_);
}
}

return Status::OK();
}

template <typename T1>
Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a,
Tensor* y,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const {
const auto* a_data = a->Data<T1>();
auto* y_data = y->MutableData<T1>();
const int M = static_cast<int>(helper.M());
const int N = static_cast<int>(helper.N());
const int K = static_cast<int>(helper.K());

MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool);
return Status::OK();
}

template <typename T1>
Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
const Tensor* scales,
@@ -740,7 +851,7 @@
// If B is prepacked, B would have been removed from the context
const bool is_b_prepacked = packed_b_size_ > 0;
const Tensor* b = is_b_prepacked ? nullptr : ctx->Input<Tensor>(InputIndex::B);
const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input<Tensor>(InputIndex::scales);
const Tensor* scales = (scales_are_packed_ || (prefer_lut_gemm_ && packed_b_)) ? nullptr : ctx->Input<Tensor>(InputIndex::scales);
const Tensor* zero_points = ctx->Input<Tensor>(InputIndex::zero_points);
const Tensor* reorder_idx = ctx->Input<Tensor>(InputIndex::g_idx);
const Tensor* bias = ctx->Input<Tensor>(InputIndex::bias);
@@ -774,6 +885,10 @@
// If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while
// MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch()
// with B directly too.
if (prefer_lut_gemm_) {
return ComputeBPackedLUT(a, y, thread_pool, helper);
}

if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper);
}
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
@@ -73,7 +73,7 @@ void Dequantize4BitsKernelReOrder(
}
}

template <typename inputT, typename zeroT>
template <typename inputT, typename zeroT, int qbits>
void DequantizeBlockwise(
inputT* output, // dequantized output
const uint8_t* quant_data, // quantized input
@@ -102,17 +102,17 @@ void DequantizeBlockwise(
});
}

template void DequantizeBlockwise<float, uint8_t>(
template void DequantizeBlockwise<float, uint8_t, 4>(
float* output, const uint8_t* quant_data, const float* scales_data,
const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise<float, float>(
template void DequantizeBlockwise<float, float, 4>(
float* output, const uint8_t* quant_data, const float* scales_data,
const float* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise<float, MLFloat16>(
template void DequantizeBlockwise<float, MLFloat16, 4>(
float* output, const uint8_t* quant_data, const float* scales_data,
const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
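For context, the snippet below is a minimal usage sketch (not part of the diff) of the 4-bit explicit instantiation declared above; the helper function name is made up, and the buffer layout is assumed to follow the existing MatMulNBits column-wise blockwise convention.

// Usage sketch for the <float, uint8_t, 4> instantiation of DequantizeBlockwise above.
// quant_data/scales/zero_points are assumed to be laid out as MatMulNBits expects;
// a null thread pool runs the dequantization single-threaded.
#include <cstdint>
#include <vector>
#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

std::vector<float> DequantizeExample(const uint8_t* quant_data, const float* scales,
                                     const uint8_t* zero_points, int32_t block_size,
                                     int32_t K, int32_t N) {
  std::vector<float> dequantized(static_cast<size_t>(K) * static_cast<size_t>(N));
  onnxruntime::contrib::DequantizeBlockwise<float, uint8_t, 4>(
      dequantized.data(), quant_data, scales, zero_points,
      /*reorder_idx=*/nullptr, block_size,
      /*columnwise=*/true, K, N, /*thread_pool=*/nullptr);
  return dequantized;
}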
@@ -6,7 +6,7 @@
namespace onnxruntime {
namespace contrib {

template <typename inputT, typename zeroT>
template <typename inputT, typename zeroT, int qbits = 4>
void DequantizeBlockwise(
inputT* output, // dequantized output
const uint8_t* quant_data, // quantized input
Expand Down