Implement new experimental lookup-based matrix multiplication method (TMAC) #26695
Changes from 38 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,10 @@ | |
| #include "core/mlas/inc/mlas_q4.h" | ||
| #include "core/providers/cpu/math/matmul_helper.h" | ||
| #include "core/providers/common.h" | ||
| #include "core/session/onnxruntime_session_options_config_keys.h" | ||
| #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" | ||
| #include "core/platform/threadpool.h" | ||
| #include "core/util/thread_utils.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
|
|
@@ -39,12 +42,14 @@ | |
| Level2, /*!< input fp16, accumulator fp16 */ | ||
| Level3, /*!< input bf16, accumulator fp32 */ | ||
| Level4, /*!< input int8, accumulator int32 */ | ||
| Level5, /*!< input uint8, use the T-MAC LUT-based GEMM approach */ | ||
| } ACCURACY_LEVEL; | ||
|
|
||
| // T: A data type. | ||
| template <typename T> | ||
| MLAS_QNBIT_GEMM_COMPUTE_TYPE | ||
| GetComputeType(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { | ||
|
|
||
| // For Fp32, only accuracy level 1 or 4 makes sense. | ||
| // non-ARM CPU converts Fp16 to Fp32. | ||
| // By converting Fp32 to Fp16, precision becomes worse. And due to the casting, | ||
|
|
@@ -54,6 +59,7 @@ | |
| return SQNBIT_CompInt8; | ||
| } | ||
|
|
||
|
|
||
| return SQNBIT_CompFp32; | ||
| } | ||
|
|
||
|
|
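For readers skimming the diff, here is a minimal sketch of the mapping implied by the ACCURACY_LEVEL enum and the comments above; it ignores the nbits/block_size availability checks that the real `GetComputeType()` also performs, and the helper name is illustrative only:

```cpp
// Minimal sketch, not the actual implementation: for fp32 inputs, only
// accuracy levels 1 (fp32) and 4 (int8) are meaningful, so level 4 selects
// int8 compute and everything else falls back to fp32 compute.
// MLAS_QNBIT_GEMM_COMPUTE_TYPE, SQNBIT_CompFp32 and SQNBIT_CompInt8 come
// from the MLAS headers already included in this file.
inline MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeTypeForFp32Sketch(int64_t accuracy_level) {
  if (accuracy_level == 4) {  // Level4: int8 input, int32 accumulator
    return SQNBIT_CompInt8;
  }
  return SQNBIT_CompFp32;     // other levels fall back to fp32 compute
}
```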
@@ -100,6 +106,7 @@ | |
| nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))}, | ||
| has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()}, | ||
| has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()}, | ||
| prefer_lut_gemm_{info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasLUTGemm) == "1"}, | ||
| compute_type_{GetComputeType<T1>(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} { | ||
| const auto& node = info.node(); | ||
| auto input_defs = node.InputDefs(); | ||
|
|
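For context on how this option reaches users: the constructor above only sets prefer_lut_gemm_ when the session config entry kOrtSessionOptionsMlasLUTGemm is "1". Below is a minimal sketch of opting in from application code; it assumes the new key constant is exported through the public onnxruntime_session_options_config_keys.h header as this diff suggests (the key's string value is not shown here), and the model path is a placeholder:

```cpp
// Sketch only: enabling the experimental MLAS LUT (T-MAC) GEMM path via the
// session option read in the MatMulNBits constructor above.
#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>  // assumed to define kOrtSessionOptionsMlasLUTGemm

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "matmulnbits_lut");
  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionsMlasLUTGemm, "1");  // "1" turns on prefer_lut_gemm_
  Ort::Session session(env, ORT_TSTR("model_with_matmulnbits.onnx"), so);  // placeholder model
  return 0;
}
```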
@@ -116,6 +123,8 @@ | |
| "Only 2b, 4b and 8b quantization is supported for MatMulNBits op, additional bits support is planned."); | ||
| const Tensor* tensor_zero_point = nullptr; | ||
| has_zp_input_ = info.TryGetConstantInput(InputIndex::zero_points, &tensor_zero_point); | ||
| prefer_lut_gemm_ = prefer_lut_gemm_ && MlasIsLUTGemmAvailable(N_, K_, nbits_, block_size_); | ||
| } | ||
|
|
||
| Status Compute(OpKernelContext* context) const override; | ||
|
|
@@ -135,11 +144,14 @@ | |
| const bool has_g_idx_; | ||
| const bool has_bias_; | ||
| bool scales_are_packed_{false}; | ||
| bool prefer_lut_gemm_{false}; | ||
| const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_; | ||
| bool has_unquantized_zero_point_{false}; | ||
| const bool column_wise_quant_{true}; | ||
| IAllocatorUniquePtr<void> packed_b_{}; | ||
| size_t packed_b_size_{0}; | ||
| IAllocatorUniquePtr<float> packed_scales_zp_{}; | ||
| size_t packed_scales_zp_size_{0}; | ||
| IAllocatorUniquePtr<float> scales_fp32_{}; | ||
| IAllocatorUniquePtr<float> bias_fp32_{}; | ||
|
|
||
|
|
@@ -167,6 +179,15 @@ | |
| AllocatorPtr& allocator, | ||
| concurrency::ThreadPool* thread_pool, | ||
| const MatMulComputeHelper& helper) const; | ||
|
|
||
| Status ComputeBPackedLUT(const Tensor* a, | ||
| const Tensor* scales, | ||
| const Tensor* zero_points, | ||
| const Tensor* bias, | ||
| Tensor* y, | ||
| AllocatorPtr& allocator, | ||
| concurrency::ThreadPool* thread_pool, | ||
| const MatMulComputeHelper& helper) const; | ||
| }; | ||
|
|
||
| template <typename T1> | ||
|
|
@@ -175,26 +196,62 @@ | |
| /*out*/ PrePackedWeights* prepacked_weights) { | ||
| ORT_UNUSED_PARAMETER(prepacked_weights); | ||
| is_packed = false; | ||
| if (has_g_idx_ || has_unquantized_zero_point_) { | ||
| // if (has_g_idx_ || has_unquantized_zero_point_) | ||
| // TODO: temporarily relaxed to only check has_g_idx_ so the LUT MatMulNBits path can be tested; restore the has_unquantized_zero_point_ check before finalizing. | ||
|
| if (has_g_idx_) { | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { | ||
| if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && !prefer_lut_gemm_) { | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| // Create a temporary threadpool for parallel packing | ||
| // This is used during model load time to speed up weight prepacking | ||
|
Contributor: What is the overhead like for creating a new threadpool in each call to PrePack()? I wonder if we should make an existing threadpool available to this code; perhaps we can pass in the threadpool from SessionState. Something to consider, and maybe for a future PR.

Contributor (Author): I agree, passing the thread pool to PrePack() would be better; something to consider in a future PR.
||
| std::unique_ptr<concurrency::ThreadPool> temp_threadpool; | ||
|
| concurrency::ThreadPool* threadpool_ptr = nullptr; | ||
|
|
||
| // Only create threadpool for operations that can benefit from it | ||
| if (prefer_lut_gemm_ || compute_type_ == SQNBIT_CompInt8) { | ||
| OrtThreadPoolParams tpo; | ||
| tpo.thread_pool_size = 4; // use a small fixed-size pool for parallel prepacking | ||
| tpo.allow_spinning = false; // Don't spin during model load | ||
| tpo.auto_set_affinity = false; | ||
|
|
||
| temp_threadpool = concurrency::CreateThreadPool( | ||
| &Env::Default(), | ||
| tpo, | ||
| concurrency::ThreadPoolType::INTRA_OP); | ||
|
|
||
| threadpool_ptr = temp_threadpool.get(); | ||
| } | ||
|
|
||
| if (input_idx == InputIndex::B) { | ||
|
|
||
| const Tensor* scales = nullptr; | ||
| OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales); | ||
|
|
||
| packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_); | ||
| if (packed_b_size_ == 0) { | ||
| return Status::OK(); | ||
| if (prefer_lut_gemm_) { | ||
| MlasInitLUTGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_); | ||
| packed_b_size_ = MlasLUTGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_); | ||
| if (packed_b_size_ == 0) { | ||
| return Status::OK(); | ||
| } | ||
| auto qptr = tensor.DataRaw(); | ||
| packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true); | ||
| MlasLUTGemmPackQuantBData(N_, K_, nbits_, block_size_, static_cast<const std::byte*>(qptr), static_cast<std::byte*>(packed_b_.get()), threadpool_ptr); | ||
| } else { | ||
| packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_); | ||
| if (packed_b_size_ == 0) { | ||
| return Status::OK(); | ||
| } | ||
| auto qptr = tensor.DataRaw(); | ||
| auto scale_ptr = scales ? scales->DataRaw() : nullptr; | ||
| packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true); | ||
| MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, | ||
| has_zp_input_, nullptr, threadpool_ptr); | ||
|
Member: IIUC, the usage of a threadpool in the existing non-LUT path seems like a new addition; is that intentional (and does it come with appropriate tests)?

Contributor (Author): Initially, I thought tests in [...]. Once we add tests, I think it might be beneficial to use the thread pool for prepacking on the other paths as well.

Contributor: Closing this comment for now to merge, as discussed offline.
||
|
|
||
| } | ||
| auto qptr = tensor.DataRaw(); | ||
| auto scale_ptr = scales ? scales->DataRaw() : nullptr; | ||
| packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true); | ||
| MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, | ||
| has_zp_input_, nullptr, nullptr); | ||
| is_packed = true; | ||
| } else if (compute_type_ == SQNBIT_CompInt8) { | ||
| // Packing scales and zero points | ||
|
|
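For clarity, the LUT weight-prepacking sequence used in PrePack() above can be read as the standalone sketch below. The Mlas*LUT* calls are the new APIs introduced by this PR, with signatures taken from this diff; the wrapper function itself is illustrative only:

```cpp
// Sketch of the LUT (T-MAC) weight-prepacking sequence, isolated from PrePack().
#include <cstddef>
#include <vector>

std::vector<std::byte> PackQuantBForLut(size_t N, size_t K, size_t nbits, size_t block_size,
                                        bool has_zp, const std::byte* quant_b_data,
                                        onnxruntime::concurrency::ThreadPool* tp) {
  std::vector<std::byte> packed;
  // 1) Configure the LUT kernel for this weight shape.
  MlasInitLUTGemmKernelConfig(N, K, nbits, block_size, has_zp);
  // 2) Query the packed buffer size; 0 means the shape is not supported by the LUT path.
  const size_t packed_size = MlasLUTGemmPackQuantBDataSize(N, K, nbits, block_size, has_zp);
  if (packed_size == 0) return packed;
  packed.resize(packed_size);
  // 3) Repack the raw quantized weights into the LUT-friendly layout (optionally in parallel).
  MlasLUTGemmPackQuantBData(N, K, nbits, block_size, quant_b_data, packed.data(), tp);
  return packed;
}
```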
@@ -230,8 +287,26 @@ | |
| is_packed = true; | ||
| } | ||
| #endif // MLAS_TARGET_ARM64 | ||
| } else if (prefer_lut_gemm_) { | ||
| if (input_idx == InputIndex::scales && packed_b_ != nullptr) { | ||
| auto scales_ptr = tensor.Data<float>(); | ||
| packed_scales_zp_size_ = MlasLUTPackScalesAndZeroPointsSize(N_, K_, block_size_, has_zp_input_); | ||
| packed_scales_zp_ = IAllocator::MakeUniquePtr<float>(alloc, packed_scales_zp_size_, true); | ||
|
|
||
| // TODO(vraspar): improve this logic block | ||
| if (has_zp_input_) { | ||
| const Tensor* zero_points = nullptr; | ||
| OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points); | ||
| auto zero_points_ptr = zero_points->Data<uint8_t>(); | ||
| MlasLUTPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, zero_points_ptr); | ||
| } else { | ||
| MlasLUTPackScalesAndZeroPoints(N_, K_, nbits_, block_size_, has_zp_input_, packed_scales_zp_.get(), scales_ptr, nullptr); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Threadpool will be automatically destroyed when temp_threadpool goes out of scope | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
|
|
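The scales / zero-point packing in the hunk above reduces to the following sketch, which also folds the TODO into a single call site: zero_points is simply nullptr when the node has no zero-point input. Signatures follow the calls in this diff; the helper is illustrative only:

```cpp
// Sketch of the LUT scales / zero-point packing step.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> PackScalesZpForLut(size_t N, size_t K, size_t nbits, size_t block_size,
                                      const float* scales,
                                      const uint8_t* zero_points /* may be nullptr */) {
  const bool has_zp = (zero_points != nullptr);
  std::vector<float> packed(MlasLUTPackScalesAndZeroPointsSize(N, K, block_size, has_zp));
  MlasLUTPackScalesAndZeroPoints(N, K, nbits, block_size, has_zp,
                                 packed.data(), scales, zero_points);
  return packed;
}
```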
@@ -296,7 +371,7 @@ | |
| is_packed = false; | ||
| } | ||
| #endif // MLAS_TARGET_AMD64_IX86 | ||
| } | ||
| } | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
@@ -307,14 +382,38 @@ | |
| /*out*/ bool& used_shared_buffers) { | ||
| used_shared_buffers = false; | ||
|
|
||
| if (input_idx == 1) { | ||
| if (input_idx == 1) { // TODO(vraspar): do we need a shared prepacked buffer for T-MAC? Consider combining packing of weights and scales/zero points into one buffer. | ||
| used_shared_buffers = true; | ||
| packed_b_ = std::move(prepacked_buffers[0]); | ||
| } | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| template <typename T1> | ||
| Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a, | ||
| const Tensor* scales, | ||
| const Tensor* zero_points, | ||
| const Tensor* bias, | ||
| Tensor* y, | ||
| AllocatorPtr& allocator, | ||
| concurrency::ThreadPool* thread_pool, | ||
| const MatMulComputeHelper& helper) const { | ||
| const auto* a_data = a->Data<T1>(); | ||
| const auto* scales_data = scales == nullptr ? nullptr : scales->Data<T1>(); | ||
| const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); | ||
| const auto* bias_data = bias == nullptr ? nullptr : bias->Data<T1>(); | ||
| auto* y_data = y->MutableData<T1>(); | ||
| const size_t batch_count = helper.OutputOffsets().size(); | ||
| const size_t M = static_cast<size_t>(helper.M()); | ||
| const size_t N = static_cast<size_t>(helper.N()); | ||
| const size_t K = static_cast<size_t>(helper.K()); | ||
| // TODO(vraspar): should the GEMM be batched over batch_count here? | ||
| // MlasInitLUTGemmKernelConfig(N, K, nbits_, block_size_, has_zp_input_);  // already called during PrePack | ||
| MlasLUTGemm(a_data, block_size_, packed_b_.get(), packed_scales_zp_.get(), y_data, K, M, N, thread_pool); | ||
| return Status::OK(); | ||
| } | ||
|
|
||
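Regarding the "should the GEMM be batched" TODO in ComputeBPackedLUT: one possible shape for a batched loop is sketched below. This is an illustration only, not part of the PR; it assumes MatMulComputeHelper's LeftOffsets()/OutputOffsets() describe contiguous per-batch slices of A and Y, as in the existing ComputeBPacked path, and it reuses the local names from ComputeBPackedLUT above.

```cpp
// Illustration only: issue one MlasLUTGemm per batch entry, offsetting A and Y
// by the per-batch offsets reported by MatMulComputeHelper.
for (size_t b = 0; b < batch_count; ++b) {
  const T1* a_batch = a_data + helper.LeftOffsets()[b];
  T1* y_batch = y_data + helper.OutputOffsets()[b];
  MlasLUTGemm(a_batch, block_size_, packed_b_.get(), packed_scales_zp_.get(),
              y_batch, K, M, N, thread_pool);
}
```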
| template <typename T1> | ||
| Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a, | ||
| const Tensor* scales, | ||
|
|
@@ -334,6 +433,7 @@ | |
| const size_t M = static_cast<size_t>(helper.M()); | ||
| const size_t N = static_cast<size_t>(helper.N()); | ||
| const size_t K = static_cast<size_t>(helper.K()); | ||
|
|
||
| const size_t lda = helper.Lda(false); | ||
|
|
||
| IAllocatorUniquePtr<std::byte> workspace{}; | ||
|
|
@@ -774,6 +874,10 @@ | |
| // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while | ||
| // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch() | ||
| // with B directly too. | ||
| if (prefer_lut_gemm_) { | ||
| return ComputeBPackedLUT(a, scales, zero_points, bias, y, allocator, thread_pool, helper); | ||
| } | ||
|
|
||
| if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { | ||
| return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper); | ||
| } | ||
|
|
||