Implement new experimental lookup-based matrix multiplication method (TMAC) #26695
base: main
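This PR adds an experimental lookup-table (LUT) based GEMM path, in the spirit of T-MAC, to the CPU MatMulNBits kernel, gated behind a session config option and new MLAS entry points (MlasIsLutGemmAvailable, MlasLutGemmPack, MlasLutGemm). The core idea of lookup-based low-bit matmul: a small group of low-bit weights can take only a few distinct bit patterns, so for each activation row the partial dot products for all patterns can be precomputed once and the inner loop becomes table lookups instead of multiply-adds. Below is a minimal conceptual sketch for 1-bit (sign) weights in groups of 4; all names are illustrative and this is not the MLAS implementation.

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Toy lookup-based matmul for 1-bit (sign) weights, illustrating the T-MAC idea.
// Y = A * W^T, where A is M x K float and W is N x K with weights in {-1, +1},
// bit-packed row-major into Wbits (bit = 1 means +1, bit = 0 means -1).
// Assumes K is a multiple of 4 so every lookup group stays nibble-aligned.
std::vector<float> LutMatmul1Bit(const std::vector<float>& A,
                                 const std::vector<uint8_t>& Wbits,
                                 int M, int N, int K) {
  constexpr int g = 4;  // weights per lookup group -> 2^4 = 16 patterns
  std::vector<float> Y(static_cast<size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    // Precompute, once per activation row, the partial dot product of every
    // 4-bit weight pattern with the matching 4 activations.
    std::vector<std::array<float, 16>> lut(K / g);
    for (int kg = 0; kg < K / g; ++kg) {
      for (int pattern = 0; pattern < 16; ++pattern) {
        float s = 0.0f;
        for (int j = 0; j < g; ++j) {
          const float a = A[static_cast<size_t>(m) * K + kg * g + j];
          s += ((pattern >> j) & 1) ? a : -a;  // bit 1 -> +1 weight, bit 0 -> -1
        }
        lut[kg][pattern] = s;
      }
    }
    // Each output element now costs K/4 table lookups instead of K multiply-adds.
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int kg = 0; kg < K / g; ++kg) {
        const size_t bit = static_cast<size_t>(n) * K + static_cast<size_t>(kg) * g;
        const int pattern = (Wbits[bit / 8] >> (bit % 8)) & 0xF;
        acc += lut[kg][pattern];
      }
      Y[static_cast<size_t>(m) * N + n] = acc;
    }
  }
  return Y;
}
```

The LUT build costs 16 partial sums per group of 4 activations, which amortizes as soon as N is larger than a few rows; the production kernel additionally handles multi-bit weights, block scales, and zero points.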
onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

@@ -4,6 +4,7 @@
 #include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

 #include <cstdint>
+#include <memory>
 #include <type_traits>

 #include "core/common/common.h"

@@ -15,7 +16,10 @@
 #include "core/mlas/inc/mlas_q4.h"
 #include "core/providers/cpu/math/matmul_helper.h"
 #include "core/providers/common.h"
+#include "core/session/onnxruntime_session_options_config_keys.h"
 #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h"
+#include "core/platform/threadpool.h"
+#include "core/util/thread_utils.h"

 namespace onnxruntime {
 namespace contrib {

@@ -100,6 +104,11 @@
       nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
       has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()},
       has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()},
+      prefer_lut_gemm_{info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasLutGemm) == "1" &&
+                       MlasIsLutGemmAvailable(narrow<size_t>(info.GetAttr<int64_t>("N")),
+                                              narrow<size_t>(info.GetAttr<int64_t>("K")),
+                                              narrow<size_t>(info.GetAttr<int64_t>("bits")),
+                                              narrow<size_t>(info.GetAttr<int64_t>("block_size")))},
       compute_type_{GetComputeType<T1>(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
   const auto& node = info.node();
   auto input_defs = node.InputDefs();

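Since prefer_lut_gemm_ is resolved once at kernel construction from the session config, opting in is a single config entry when building the session. A minimal sketch, assuming the key added by this PR keeps the usual "1" convention shown in the constructor above; the model path is a placeholder:

```cpp
#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "lut_gemm_demo"};
  Ort::SessionOptions so;
  // Opt in to the experimental MLAS LUT GEMM path for MatMulNBits.
  // Each kernel still checks MlasIsLutGemmAvailable(N, K, bits, block_size),
  // so unsupported shapes fall back to the existing paths.
  so.AddConfigEntry(kOrtSessionOptionsMlasLutGemm, "1");
  Ort::Session session{env, ORT_TSTR("model_with_matmul_nbits.onnx"), so};
  return 0;
}
```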
@@ -135,6 +144,7 @@
   const bool has_g_idx_;
   const bool has_bias_;
   bool scales_are_packed_{false};
+  const bool prefer_lut_gemm_{false};
   const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_;
   bool has_unquantized_zero_point_{false};
   const bool column_wise_quant_{true};

@@ -167,6 +177,11 @@
                          AllocatorPtr& allocator,
                          concurrency::ThreadPool* thread_pool,
                          const MatMulComputeHelper& helper) const;
+
+  Status ComputeBPackedLUT(const Tensor* a,
+                           Tensor* y,
+                           concurrency::ThreadPool* thread_pool,
+                           const MatMulComputeHelper& helper) const;
 };

 template <typename T1>

@@ -179,22 +194,76 @@
     return Status::OK();
   }

-  if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
+  if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && !prefer_lut_gemm_) {
     return Status::OK();
   }

+  // Create a temporary threadpool for parallel packing.
+  // This is used during model load time to speed up weight prepacking.
[Review thread]
Contributor: What is the overhead like for creating a new threadpool in each call here? I wonder if we should make an existing threadpool available to this code; perhaps we can pass in the threadpool from SessionState. Something to consider, and maybe for a future PR.
Author (Contributor): I agree; passing an existing thread pool in here would be the better approach. Worth considering in a future PR.
+  std::unique_ptr<concurrency::ThreadPool> temp_threadpool;
+  concurrency::ThreadPool* threadpool_ptr = nullptr;
+
+  // Only create threadpool for LUT GEMM path which can benefit from parallel packing.
+  // TODO: Consider extending threadpool usage to non-LUT path (CompInt8) with appropriate tests.
+  if (prefer_lut_gemm_) {
+    OrtThreadPoolParams tpo;
+    tpo.thread_pool_size = Env::Default().GetNumPhysicalCpuCores();
+    tpo.allow_spinning = false;  // Don't spin during model load
+    tpo.auto_set_affinity = false;
+
+    temp_threadpool = concurrency::CreateThreadPool(
+        &Env::Default(),
+        tpo,
+        concurrency::ThreadPoolType::INTRA_OP);
+
+    threadpool_ptr = temp_threadpool.get();
+  }

   if (input_idx == InputIndex::B) {
     const Tensor* scales = nullptr;
     OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales);

-    packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_);
-    if (packed_b_size_ == 0) {
-      return Status::OK();
+    if (prefer_lut_gemm_) {
+      MlasInitLutGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_);
+
+      packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_);
+      if (packed_b_size_ == 0) {
+        return Status::OK();
+      }
+
+      packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
+
+      const float* scales_ptr = scales ? scales->Data<float>() : nullptr;
+      const uint8_t* zp_ptr = nullptr;
+      if (scales_ptr != nullptr && has_zp_input_) {
+        const Tensor* zero_points = nullptr;
+        OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points);
+        zp_ptr = zero_points ? zero_points->Data<uint8_t>() : nullptr;
+      }
+
+      MlasLutGemmPack(
+          N_, K_, nbits_, block_size_, has_zp_input_,
+          static_cast<const std::byte*>(tensor.DataRaw()),
+          scales_ptr,
+          zp_ptr,
+          static_cast<std::byte*>(packed_b_.get()),
+          threadpool_ptr);
+
+      if (prepacked_weights != nullptr) {
+        prepacked_weights->buffers_.push_back(std::move(packed_b_));
+        prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+      }
+    } else {
+      packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_);
+      if (packed_b_size_ == 0) {
+        return Status::OK();
+      }
+      auto qptr = tensor.DataRaw();
+      auto scale_ptr = scales ? scales->DataRaw() : nullptr;
+      packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
+      MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr,
+                                  has_zp_input_, nullptr, threadpool_ptr);
[Review thread]
Member: IIUC, the usage of a threadpool in the existing non-LUT path seems like a new addition; is that intentional (and does it come with appropriate tests)?
Author (Contributor): Initially, I thought the existing tests covered this path. Once we add tests, I think it might be beneficial to use the thread pool for prepacking on the other paths as well.
+    }
-    auto qptr = tensor.DataRaw();
-    auto scale_ptr = scales ? scales->DataRaw() : nullptr;
-    packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
-    MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr,
-                                has_zp_input_, nullptr, nullptr);
     is_packed = true;
   } else if (compute_type_ == SQNBIT_CompInt8) {
     // Packing scales and zero points

@@ -230,8 +299,30 @@
     is_packed = true;
   }
 #endif  // MLAS_TARGET_ARM64
+  } else if (prefer_lut_gemm_) {
+    // Pack scales/zero_points for LUT GEMM if B was already packed but scales weren't available then.
+    if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
+      auto scales_ptr = tensor.Data<float>();
+      const uint8_t* zp_ptr = nullptr;
+      if (has_zp_input_) {
+        const Tensor* zero_points = nullptr;
+        OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points);
+        zp_ptr = zero_points ? zero_points->Data<uint8_t>() : nullptr;
+      }
+      // Pack only scales (QuantBData is nullptr).
+      MlasLutGemmPack(
+          N_, K_, nbits_, block_size_, has_zp_input_,
+          nullptr,  // QuantBData already packed
+          scales_ptr,
+          zp_ptr,
+          static_cast<std::byte*>(packed_b_.get()),
+          nullptr);  // No threadpool needed for scales only
+      is_packed = false;  // scales tensor can be released but not "packed" in the ORT sense
+    }
   }

+  // Threadpool will be automatically destroyed when temp_threadpool goes out of scope.

   return Status::OK();
 }

@@ -307,14 +398,34 @@
                                                 /*out*/ bool& used_shared_buffers) {
   used_shared_buffers = false;

-  if (input_idx == 1) {
-    used_shared_buffers = true;
+  if (input_idx == InputIndex::B && !prepacked_buffers.empty()) {
     packed_b_ = std::move(prepacked_buffers[0]);
+    used_shared_buffers = true;
+
+    if (prefer_lut_gemm_) {
+      MlasInitLutGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_);
+      packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_);
+    }
   }

   return Status::OK();
 }

+template <typename T1>
+Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a,
+                                          Tensor* y,
+                                          concurrency::ThreadPool* thread_pool,
+                                          const MatMulComputeHelper& helper) const {
+  const auto* a_data = a->Data<T1>();
+  auto* y_data = y->MutableData<T1>();
+  const int M = static_cast<int>(helper.M());
+  const int N = static_cast<int>(helper.N());
+  const int K = static_cast<int>(helper.K());
+
+  MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool);
+  return Status::OK();
+}
+
 template <typename T1>
 Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
                                        const Tensor* scales,

@@ -740,7 +851,7 @@
   // If B is prepacked, B would have been removed from the context
   const bool is_b_prepacked = packed_b_size_ > 0;
   const Tensor* b = is_b_prepacked ? nullptr : ctx->Input<Tensor>(InputIndex::B);
-  const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input<Tensor>(InputIndex::scales);
+  const Tensor* scales = (scales_are_packed_ || (prefer_lut_gemm_ && packed_b_)) ? nullptr : ctx->Input<Tensor>(InputIndex::scales);
   const Tensor* zero_points = ctx->Input<Tensor>(InputIndex::zero_points);
   const Tensor* reorder_idx = ctx->Input<Tensor>(InputIndex::g_idx);
   const Tensor* bias = ctx->Input<Tensor>(InputIndex::bias);

@@ -774,6 +885,10 @@
   // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while
   // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch()
   // with B directly too.
+  if (prefer_lut_gemm_) {
+    return ComputeBPackedLUT(a, y, thread_pool, helper);
+  }
+
   if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
     return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper);
   }
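For orientation, here is the complete LUT lifecycle the kernel now follows, gathered into one place: configure, query packed size, pack B, run the GEMM. This is a sketch reconstructed from the calls in this diff; the helper name, header locations, and error message are illustrative, not part of the PR.

```cpp
#include "core/common/common.h"
#include "core/framework/allocator.h"
#include "core/mlas/inc/mlas.h"  // assumed location of the new LUT GEMM declarations
#include "core/platform/threadpool.h"

namespace ort = onnxruntime;

ort::common::Status RunLutGemmOnce(size_t N, size_t K, size_t nbits, size_t block_size,
                                   bool has_zp,
                                   const std::byte* quant_b,    // quantized weight data
                                   const float* scales,
                                   const uint8_t* zero_points,  // may be nullptr
                                   const float* a, float* y, int M,
                                   ort::AllocatorPtr alloc,
                                   ort::concurrency::ThreadPool* tp) {
  // One-time kernel configuration for this (N, K, nbits, block_size) shape.
  MlasInitLutGemmKernelConfig(N, K, nbits, block_size, has_zp);

  const size_t packed_size = MlasLutGemmPackedSize(N, K, nbits, block_size, has_zp);
  ORT_RETURN_IF(packed_size == 0, "LUT GEMM does not support this configuration");

  // Pack weights (plus scales/zero points) into the LUT-friendly layout.
  auto packed_b = ort::IAllocator::MakeUniquePtr<void>(alloc, packed_size, true);
  MlasLutGemmPack(N, K, nbits, block_size, has_zp,
                  quant_b, scales, zero_points,
                  static_cast<std::byte*>(packed_b.get()), tp);

  // A is M x K activations, Y is M x N output.
  MlasLutGemm(a, block_size, packed_b.get(), y,
              static_cast<int>(K), M, static_cast<int>(N), has_zp, tp);
  return ort::common::Status::OK();
}
```

In the kernel itself these steps are split across PrePack (or UseSharedPrePackedBuffers) and ComputeBPackedLUT, which is why the config call appears in both places above.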