diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp
index cb099c2409a44..32c72342b4803 100644
--- a/onnxruntime/core/mlas/lib/qlutgemm.cpp
+++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp
@@ -25,33 +25,53 @@ Module Name:
 #include
 #include
 #include
+#include <mutex>
 #include
 
-/** T-MAC GEMM kernel Config */
+/**
+ * Global cache for T-MAC kernel parameters, indexed by configuration.
+ * This map and its associated mutex ensure thread-safe parameter management
+ * across concurrent MLAS calls.
+ */
 static std::unordered_map<std::string, MlasTMACKernelParams> tmac_kernel_configs;
+static std::mutex tmac_kernel_configs_mutex;
 
-const MlasTMACKernelParams&
+static std::string
+GetTmacKey(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
+{
+    // Generate a unique cache key based on the GEMM and quantization configuration.
+    return std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" +
+           std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
+}
+
+MlasTMACKernelParams
 MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
 {
-    std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
-    if (tmac_kernel_configs.count(key)) {
-        return tmac_kernel_configs[key];
+    std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point);
+    std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
+    auto it = tmac_kernel_configs.find(key);
+    if (it != tmac_kernel_configs.end()) {
+        return it->second;
     }
-    MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized");
+    MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized for key: " + key);
 }
 
 void MLASCALL
 MlasClearLutGemmKernelConfig()
 {
+    std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
     tmac_kernel_configs.clear();
 }
 
 void MLASCALL
 MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
 {
-    std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
-    if (tmac_kernel_configs.count(key)) {
-        return;
+    std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point);
+    {
+        std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
+        if (tmac_kernel_configs.find(key) != tmac_kernel_configs.end()) {
+            return;
+        }
     }
 
     MlasTMACKernelParams params;
@@ -121,7 +141,10 @@ MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size,
     params.has_zero_point = has_zero_point;
     params.one_scale = false;  // TODO(vraspar): support one scale case for bitnet
 
-    tmac_kernel_configs[key] = params;
+    {
+        std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
+        tmac_kernel_configs[key] = params;
+    }
     return;
 }
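As a standalone illustration of the caching pattern above — a mutex-guarded map keyed by the stringified configuration, with the getter returning a copy so no caller ever holds a reference into the map while another thread inserts — here is a minimal sketch. The names (KernelParams, g_configs, MakeKey, InitConfig, GetConfig) are placeholders for this sketch only, not MLAS symbols:

#include <cstddef>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct KernelParams { size_t bm; size_t kfactor; };  // stand-in for MlasTMACKernelParams

static std::unordered_map<std::string, KernelParams> g_configs;
static std::mutex g_configs_mutex;

static std::string MakeKey(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zp) {
    return std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" +
           std::to_string(block_size) + "_" + (has_zp ? "1" : "0");
}

void InitConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zp) {
    KernelParams params{256, 16};  // placeholder values; the real code derives these from the problem shape
    std::lock_guard<std::mutex> lock(g_configs_mutex);
    g_configs.emplace(MakeKey(M, N, nbits, block_size, has_zp), params);  // no-op if already present
}

// Returning by value gives the caller a private copy, so later insertions
// (which may rehash the map) cannot invalidate what the caller holds.
KernelParams GetConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zp) {
    std::lock_guard<std::mutex> lock(g_configs_mutex);
    auto it = g_configs.find(MakeKey(M, N, nbits, block_size, has_zp));
    if (it == g_configs.end()) {
        throw std::runtime_error("kernel parameters not initialized");
    }
    return it->second;
}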
@@ -222,53 +245,52 @@ LutGemmPackQuantBData(
     const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem);
     memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize);  // TODO: is this needed?
 
-    MlasTrySimpleParallel(
-        ThreadPool, Iterations,
-        [&](ptrdiff_t tid) {
-            size_t im = static_cast<size_t>(tid);
-            for (size_t ib = 0; ib < bits; ib++) {
-                for (size_t ik = 0; ik < K / g; ik++) {
-                    // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
-                    size_t new_im = im / simd_n_out;
-                    size_t new_isno = im % simd_n_out;
-                    size_t new_ib = ib;
-                    size_t new_ik = ik;
-                    size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;
-
-                    // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
-                    new_im = new_idx / c1_nb0;
-                    size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
-                    size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
-                    new_ik = (new_idx % c1_nb2);
-                    new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;
-
-                    // #             0        1             2          3                  4                   5
-                    // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
-                    new_im = new_idx / c2_nb0;
-                    size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
-                    new_isni = (new_idx % c2_nb1) / c2_nb2;
-                    new_ing = (new_idx % c2_nb2) / c2_nb3;
-                    new_ik = (new_idx % c2_nb3) / c2_nb4;
-                    size_t new_ikf = (new_idx % c2_nb4);
-                    new_idx = new_im * c2_fac0 +
-                              new_ik * c2_fac1 +
-                              new_ibm * c2_fac2 +
-                              new_ikf * c2_fac3 +
-                              new_isni * ngroups_per_elem +
-                              new_ing;
-                    new_idx = new_idx / ngroups_per_elem;
-                    size_t buf_idx = im * bits * K / g + ib * K / g + ik;
-                    uint8_t buf_val = buf[buf_idx];
-
-                    // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
-                    PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
-                        static_cast<uint8_t>(PackedQuantBDataBegin[new_idx]) +
-                        (buf_val << (new_ing * g))
-                    );
-                }
+    // NOTE: The second packing loop is intentionally serialized to avoid data races.
+    // T-MAC packs multiple output features (N) into a single byte if ngroups_per_elem > 1.
+    // Parallelizing this across N would lead to concurrent bit-plane updates on the same memory location.
+    for (size_t im = 0; im < Iterations; im++) {
+        for (size_t ib = 0; ib < bits; ib++) {
+            for (size_t ik = 0; ik < K / g; ik++) {
+                // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
+                size_t new_im = im / simd_n_out;
+                size_t new_isno = im % simd_n_out;
+                size_t new_ib = ib;
+                size_t new_ik = ik;
+                size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;
+
+                // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
+                new_im = new_idx / c1_nb0;
+                size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
+                size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
+                new_ik = (new_idx % c1_nb2);
+                new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;
+
+                // #             0        1             2          3                  4                   5
+                // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
+                new_im = new_idx / c2_nb0;
+                size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
+                new_isni = (new_idx % c2_nb1) / c2_nb2;
+                new_ing = (new_idx % c2_nb2) / c2_nb3;
+                new_ik = (new_idx % c2_nb3) / c2_nb4;
+                size_t new_ikf = (new_idx % c2_nb4);
+                new_idx = new_im * c2_fac0 +
+                          new_ik * c2_fac1 +
+                          new_ibm * c2_fac2 +
+                          new_ikf * c2_fac3 +
+                          new_isni * ngroups_per_elem +
+                          new_ing;
+                new_idx = new_idx / ngroups_per_elem;
+                size_t buf_idx = im * bits * K / g + ib * K / g + ik;
+                uint8_t buf_val = buf[buf_idx];
+
+                // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
+                PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
+                    static_cast<uint8_t>(PackedQuantBDataBegin[new_idx]) +
+                    (buf_val << (new_ing * g))
+                );
             }
         }
-    );
+    }
 }
 
 // Internal helper: calculates packed scales and zero points size in floats
@@ -472,16 +494,15 @@ size_t
 CalculateLutBufferSize(size_t n, size_t k, size_t m, const MlasTMACKernelParams& tmac_params)
 {
     MLAS_UNREFERENCED_PARAMETER(n);
-    constexpr size_t kAllockAligment = 64;
     const size_t lut_scales_size = k / tmac_params.act_group_size;
 
-    size_t wsize = k * m * 4 * sizeof(int8_t);         // 4 bytes per k element for 2-bit LUT
-    wsize += lut_scales_size * m * 2 * sizeof(float);  // scales + biases
-
-    wsize = ((wsize - 1) / kAllockAligment + 1) * kAllockAligment;
+    // The AVX2 kernel (g=4) expects 16 entries (16 bytes) per group of 4 activations.
+    // This effectively requires 4 bytes per activation in the K dimension.
+    size_t lut_size_bytes = m * k * 4;
+    size_t scales_size_bytes = m * lut_scales_size * sizeof(float);
+    size_t biases_size_bytes = m * lut_scales_size * sizeof(float);
 
-    // TODO(vrapar): add temp buffer for FP16
-    return wsize;
+    return lut_size_bytes + scales_size_bytes + biases_size_bytes + 256;  // plus alignment/safety padding
 }
 
 void MLASCALL
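The new size computation implies a simple three-region layout: m * k * 4 bytes of quantized LUT entries, followed by one scale float and one bias float per (row, activation group) pair. A sketch of how such a buffer can be carved under that assumption — LutWorkspace and CarveLutWorkspace are hypothetical names; MlasLutGemm below does the equivalent pointer arithmetic inline:

#include <cstddef>
#include <cstdint>

struct LutWorkspace {
    int8_t* qlut;        // m * k * 4 bytes of quantized LUT entries
    float*  lut_scales;  // m * (k / act_group_size) floats
    float*  lut_biases;  // m * (k / act_group_size) floats
};

// Hypothetical helper mirroring CalculateLutBufferSize and the carving done in MlasLutGemm.
inline LutWorkspace CarveLutWorkspace(std::byte* base, size_t m, size_t k, size_t act_group_size) {
    const size_t lut_scales_size = k / act_group_size;  // activation groups per row
    const size_t lut_size_bytes = m * k * 4;

    LutWorkspace ws;
    ws.qlut = reinterpret_cast<int8_t*>(base);
    ws.lut_scales = reinterpret_cast<float*>(base + lut_size_bytes);
    ws.lut_biases = ws.lut_scales + m * lut_scales_size;
    return ws;
}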
@@ -532,17 +553,23 @@ MlasLutGemm(
     // n_tiles_num = m * bits / bm;
 
     // TODO(vraspar): support other bitwidths
+    // For T-MAC, kernel properties (bm, n_tiles_num) are primarily driven by the number of output features (N).
+    // Initialization during packing (LutGemmPackQuantBDataSize) uses N as the major dimension,
+    // so we must match that here to ensure consistent weight tiling.
+    MlasInitLutGemmKernelConfig(N, K, 2, BlkLen, HasZeroPoint);
     const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, 2, BlkLen, HasZeroPoint);
     const size_t lut_scales_size = K / tmac_params.act_group_size;
+    const size_t lut_size_bytes = static_cast<size_t>(M) * static_cast<size_t>(K) * 4;
 
     size_t lut_buffer_size = CalculateLutBufferSize(N, K, M, tmac_params);
 
     // make buffer of lut_buffer_size bytes
     // TODO(vraspar): other way to do it
     auto lut_buffer = std::make_unique<std::byte[]>(lut_buffer_size);
+    memset(lut_buffer.get(), 0, lut_buffer_size);
 
     int8_t* qlut = reinterpret_cast<int8_t*>(lut_buffer.get());
-    float* lut_scales = reinterpret_cast<float*>(qlut + K * M * 4);                  // after lut
-    float* lut_biases = reinterpret_cast<float*>(lut_scales + lut_scales_size * M);  // after scales
+    float* lut_scales = reinterpret_cast<float*>(qlut + lut_size_bytes);             // after lut
+    float* lut_biases = reinterpret_cast<float*>(lut_scales + lut_scales_size * M);  // after scales
 
     const auto* a_float = reinterpret_cast<const float*>(A);  // Activation data
@@ -558,11 +585,12 @@ MlasLutGemm(
     for (size_t ine11 = 0; ine11 < static_cast<size_t>(M); ine11++) {
         const size_t row_offset = ine11 * K;
 
-        const size_t lut_offset = ine11 * K * 4;  // 4 bytes per K element for 2-bit LUT
+        // Call the LUT generation kernel for this activation row.
+        // We use a 4-byte stride (per activation) for the LUT entries to satisfy
+        // the memory layout requirements of the computation kernel.
+        const size_t lut_offset = ine11 * K * 4;
         const size_t scale_bias_offset = ine11 * lut_scales_size;
 
-        // Call the dispatch function for this row
-        // ggml_tmac_mul_mat_task_init
         Dispatch->GenerateLUT(
             const_cast<float*>(a_float + row_offset),  // Input activation for this row
             qlut + lut_offset,                         // Output LUT for this row
@@ -571,7 +599,8 @@ MlasLutGemm(
             M,
             K,
             N,
-            tmac_params.act_group_size
+            tmac_params.act_group_size,
+            tmac_params.act_group_size * 4
         );
     }
 
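For concreteness, a worked example of the per-row offsets used above, with illustrative sizes that are not taken from the patch (K = 4096, act_group_size = 64):

#include <cstddef>

// Each group of 4 activations produces 2^4 = 16 int8 LUT entries (16 bytes),
// i.e. 4 bytes per activation, so one activation group spans act_group_size * 4 bytes.
constexpr size_t kK = 4096;            // illustrative
constexpr size_t kActGroupSize = 64;   // illustrative
constexpr size_t kLutStrideBytes = kActGroupSize * 4;  // 256 bytes per activation group
static_assert(kLutStrideBytes == (kActGroupSize / 4) * 16, "16 bytes per group of 4 activations");

// Offsets for activation row 3, matching lut_offset and scale_bias_offset above.
constexpr size_t kRow = 3;
constexpr size_t kLutOffset = kRow * kK * 4;                      // 49152 bytes into qlut
constexpr size_t kScaleBiasOffset = kRow * (kK / kActGroupSize);  // 192 floats into lut_scales
static_assert(kLutOffset == 49152 && kScaleBiasOffset == 192, "per-row offsets as computed in MlasLutGemm");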
@@ -657,15 +686,17 @@ MlasLutGemm(
 
             // Process all batch items in this chunk
             for (size_t ine11 = ir1_start; ine11 < ir1_end; ine11++) {
-                // Calculate LUT offsets for this batch item
+                // Calculate LUT offsets with 4-byte stride (per activation) for consistent access.
                 const size_t qlut_offset = K * ine11 * 4;
                 const size_t lut_scales_offset = lut_scales_size * ine11;
 
                 // Calculate output offset
                 const size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0;
 
-                // Call the dispatch function to compute this tile
-                // Note M and N are swapped in TMAC terminology
+                // Call the dispatch function to compute this tile.
+                // We pass one batch item at a time (M=1) and ChunkSize0 output features.
+                // TotalN is passed specifically to allow the kernel to find the correct
+                // parameters (bm, tiles) used during weight packing.
                 Dispatch->ComputeGemm(
                     packed_weights + w_offset,       // Weight tile
                     QuantBScale + scales_offset,     // Weight scales for this tile
@@ -674,8 +705,9 @@ MlasLutGemm(
                     lut_biases + lut_scales_offset,  // LUT biases
                     act_output + dst_offset,         // Output location
                     static_cast<int>(K),             // K dimension
-                    static_cast<int>(N),             // N dimension
-                    static_cast<int>(1),             // M dimension (processing one batch item at a time)
+                    static_cast<int>(1),                    // M dimension (batch size = 1)
+                    static_cast<int>(ir0_end - ir0_start),  // N dimension (output features in chunk)
+                    static_cast<int>(N),                    // TotalN (total output features in weights)
                     BlkLen,                          // Weight quantization group size
                     HasZeroPoint                     // Whether zero points are used
                 );
diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h
index ef4d01a2c5809..0a733199ea2e8 100644
--- a/onnxruntime/core/mlas/lib/qlutgemm.h
+++ b/onnxruntime/core/mlas/lib/qlutgemm.h
@@ -42,7 +42,11 @@ struct MlasTMACKernelParams {
     bool one_scale;
 };
 
-const MlasTMACKernelParams&
+/**
+ * Retrieves the T-MAC kernel configuration for a given GEMM problem.
+ * Returns the parameters by value to ensure thread-safety across concurrent calls.
+ */
+MlasTMACKernelParams
 MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point);
 
 typedef void(MLAS_QNBIT_GEMM_LUT_GEN)(
@@ -53,19 +57,21 @@ typedef void(MLAS_QNBIT_GEMM_LUT_GEN)(
     size_t M,
     size_t K,
     size_t N,
-    size_t act_group_size
+    size_t act_group_size,
+    size_t lut_stride  // Stride (in bytes) between the LUT blocks of consecutive activation groups.
 );
 
 typedef void(MLAS_QNBIT_LUT_GEMM_COMPUTE)(
-    const uint8_t* weights,
-    const float* scales,
+    const uint8_t* A,
+    const float* Scales,
     const int8_t* LUT,
     const float* LUT_Scales,
     const float* LUT_Biases,
     float* C,
     int K,
-    int M,  // batch size (number of rows in activation)
-    int N,
+    int M,       // Batch size (current activation rows).
+    int N,       // Number of output features to compute in this tile/chunk.
+    int TotalN,  // Total number of output features in the weights (used for parameter mapping).
     size_t BlkLen,
     bool HasZeroPoint
 );
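To summarize the revised compute signature: M is the number of activation rows handled per call (currently always 1), N bounds how many output features the call writes, and TotalN identifies the configuration registered when the weights were packed. A hypothetical skeleton showing the intended use of each argument — this is not the real kernel (that is TMACComputeGemm_avx2 below), only an illustration of the contract:

#include "qlutgemm.h"  // assumed include for MlasTMACKernelParams / MlasGetLutGemmKernelParams

static void
ComputeTileSketch(const int8_t* LUT, float* C, int K, int M, int N, int TotalN,
                  size_t BlkLen, bool HasZeroPoint)
{
    // TotalN (not the per-chunk N) keys the parameter lookup, matching what was
    // registered at weight-packing time.
    const MlasTMACKernelParams tmac_params =
        MlasGetLutGemmKernelParams(TotalN, K, 2, BlkLen, HasZeroPoint);

    // N only bounds how many output features this particular call produces.
    for (int n = 0; n < N; ++n) {
        C[n] = 0.0f;  // ... accumulate LUT lookups for output feature n ...
    }

    (void)LUT;
    (void)M;  // the AVX2 path handles one activation row per call (M == 1)
    (void)tmac_params;
}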
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
index a89993d4515b8..7e4df13423be2 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
@@ -361,7 +361,8 @@ GenerateLUT_avx2(
     size_t M,
     size_t K,
     size_t N,
-    size_t act_group_size
+    size_t act_group_size,
+    size_t lut_stride
 )
 {
     (void)M;  // silence unused parameter warning
@@ -379,7 +380,9 @@ GenerateLUT_avx2(
     }
 
     for (int32_t k_outer_1 = 0; k_outer_1 < kk_outer_max; ++k_outer_1) {
-        lut_ctor_g4_int8_impl(static_cast<int32_t>(act_group_size), (&(qlut[(k_outer_1 * act_group_size * 4)])), (&(b[(k_outer_1 * act_group_size)])), (&(lut_scales[k_outer_1])), (&(lut_biases[k_outer_1])));
+        // Use the explicit lut_stride provided by the dispatch/caller to ensure
+        // consistent memory layout between construction and compute paths.
+        lut_ctor_g4_int8_impl(static_cast<int32_t>(act_group_size), (&(qlut[(k_outer_1 * lut_stride)])), (&(b[(k_outer_1 * act_group_size)])), (&(lut_scales[k_outer_1])), (&(lut_biases[k_outer_1])));
     }
 }
 
@@ -400,6 +403,20 @@ tbl_g4_int8_float_gather_bit2_impl(int32_t m, float* C_global, float* CBits, flo
         }
     }
 
+    // Handle tail cases where m is not a multiple of 32.
+    // This ensures C_global is fully initialized for all m elements.
+    int32_t m_tail = m % 32;
+    if (m_tail > 0) {
+        int32_t m_c_outer = m_c_outer_max;
+        int32_t cse_var_2 = (m_c_outer * 32 * bits);
+        int32_t cse_var_1 = (m_c_outer * 32);
+        for (int32_t m_c_inner = 0; m_c_inner < m_tail; ++m_c_inner) {
+            int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8);
+            int32_t bit_offset_1 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 8;
+            C_global[cse_var_1 + m_c_inner] = (CBits[cse_var_2 + bit_offset_0] * (float)5.000000e-01f) + (CBits[cse_var_2 + bit_offset_1]);
+        }
+    }
+
     for (int32_t m_inner_outer = 0; m_inner_outer < m_c_outer_max; ++m_inner_outer) {
         PRAGMA_UNROLL
         for (int32_t m_inner = 0; m_inner < 32; ++m_inner) {
@@ -407,6 +424,17 @@ tbl_g4_int8_float_gather_bit2_impl(int32_t m, float* C_global, float* CBits, flo
             C[offset] = C_global[offset];
         }
     }
+
+    // Transfer the remaining tail results from C_global to the final output matrix C.
+    // This is necessary when m is not a multiple of 32, ensuring all output features
+    // are correctly written to the destination buffer.
+    if (m_tail > 0) {
+        int offset_base = m_c_outer_max * 32;
+        for (int32_t m_inner = 0; m_inner < m_tail; ++m_inner) {
+            int offset = offset_base + m_inner;
+            C[offset] = C_global[offset];
+        }
+    }
 }
 
 // When FastAggregation is enabled, FastAggregationK = ActK
@@ -451,8 +479,8 @@ tbl_g4_int8_float_update_impl(int32_t m, float* c, const int8_t* lut, const uint
         __m256 vec_v_high_low = _mm256_cvtepi32_ps(extract_low_epi16_epi32(adder.get_high()));
         __m256 vec_v_high_high = _mm256_cvtepi32_ps(extract_high_epi16_epi32(adder.get_high()));
 
-        float lut_s = lut_scales[kk / ActK];
-        float lut_b = lut_biases[kk / ActK];
+        float lut_s = lut_scales[kk / (ActK * 4)];
+        float lut_b = lut_biases[kk / (ActK * 4)];
 
         partial_sum += lut_b;
 
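The index change above follows from the 4-bytes-per-activation LUT layout: assuming kk advances in LUT bytes, a group of ActK activations spans ActK * 4 bytes, so the owning scale/bias entry is kk / (ActK * 4), not kk / ActK. A quick check with illustrative numbers (ActK = 16 is an example value, not taken from the patch):

#include <cstdint>

constexpr int32_t kActK = 16;                  // activations per scale/bias group (illustrative)
constexpr int32_t kBytesPerGroup = kActK * 4;  // 4 LUT bytes per activation -> 64 bytes per group
constexpr int32_t kkExample = 128;             // byte offset into the tile's LUT
static_assert(kkExample / kBytesPerGroup == 2, "offset 128 belongs to the third activation group");
// The old expression kk / ActK would yield 8 here, i.e. a scale entry four groups too far along.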
@@ -542,17 +570,20 @@ TMACComputeGemm_avx2(
     int K,
     int M,
     int N,
+    int TotalN,
     size_t BlkLen,  // Weight quantization group size (q_group_size)
     bool HasZeroPoint
 )
 {
-    // Validate batch size
-    if (N != 1) {
-        MLAS_THROW_EX(std::runtime_error, "N > 1 is not supported yet");
+    // Validate batch size (M)
+    // For now, TMAC AVX2 kernel processes one batch row at a time.
+    if (M != 1) {
+        MLAS_THROW_EX(std::runtime_error, "M > 1 is not supported yet in TMAC AVX2 kernel");
     }
 
-    // get kernel config
-    const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(M, K, 2, BlkLen, HasZeroPoint);
+    // get kernel config using the total output features (TotalN)
+    // This matches the parameters used during weight packing.
+    const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(TotalN, K, 2, BlkLen, HasZeroPoint);
 
     // ==================== CONFIGURATION ====================
     // Fixed parameters for this kernel implementation
@@ -572,7 +603,11 @@ TMACComputeGemm_avx2(
     const int32_t actk = static_cast<int32_t>(tmac_params.actk);  // CRITICAL: = 16 for BlkLen=64, NOT BlkLen!
     const int32_t bm = static_cast<int32_t>(tmac_params.bm);
 
-    int32_t m = bm / bits;
+    // m is the number of output features this kernel tile produces.
+    // We clamp m by N (the number of features in the current chunk) to ensure
+    // we don't read or write past the tile boundary during the gather phase.
+    int32_t m_full = bm / bits;
+    int32_t m = std::min(m_full, N);
 
     // Validate configuration
     assert(bm % bits == 0);
@@ -590,8 +625,9 @@ TMACComputeGemm_avx2(
     float* CBits = new float[bm];
     float* C_global = new float[m];
 
-    // Reset accumulator buffer to zero
-    tbl_int32_reset(bm * sizeof(float) / sizeof(int32_t), reinterpret_cast<int32_t*>(CBits));
+    // Explicitly zero-initialize accumulation buffers to ensure determinism.
+    memset(CBits, 0, bm * sizeof(float));
+    memset(C_global, 0, m * sizeof(float));
 
     // ==================== CALCULATE LOOP PARAMETERS ====================
     const int32_t k_outer_max = K / (kfactor * g);
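Finally, a worked example of the clamp on m introduced in TMACComputeGemm_avx2 above. The values are illustrative (bits = 2 matches the 2-bit path; bm = 256 is an assumed configuration): a tile is sized for bm / bits output features, but the last chunk of a matrix may hold fewer, so the gather and the output writes must stop at the chunk boundary.

#include <algorithm>
#include <cstdint>

constexpr int32_t kBits = 2;
constexpr int32_t kBm = 256;                       // assumed tile height from the kernel config
constexpr int32_t kMFull = kBm / kBits;            // 128 output features per full tile
constexpr int32_t kNChunk = 96;                    // features remaining in the final chunk
constexpr int32_t kM = std::min(kMFull, kNChunk);  // what the kernel actually produces
static_assert(kM == 96, "the gather phase must not run past the chunk boundary");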