178 changes: 105 additions & 73 deletions onnxruntime/core/mlas/lib/qlutgemm.cpp
@@ -25,33 +25,53 @@ Module Name:
#include <memory>
#include <string>
#include <thread>
#include <mutex>
#include <unordered_map>

/** T-MAC GEMM kernel Config */
/**
* Global cache for T-MAC kernel parameters, indexed by configuration.
* This map and its associated mutex ensure thread-safe parameter management
* across concurrent MLAS calls.
*/
static std::unordered_map<std::string, struct MlasTMACKernelParams> tmac_kernel_configs;
static std::mutex tmac_kernel_configs_mutex;

const MlasTMACKernelParams&
static std::string
GetTmacKey(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
{
// Generate a unique cache key based on the GEMM and quantization configuration.
return std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" +
std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
}
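
For illustration, a key produced by this helper for a hypothetical 2-bit, 4096x4096 problem with 64-wide quantization blocks and no zero point:

```cpp
// Hypothetical example of the key format above:
// GetTmacKey(4096, 4096, 2, 64, false) returns "4096_4096_2_64_0"
```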

MlasTMACKernelParams
MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
{
std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
if (tmac_kernel_configs.count(key)) {
return tmac_kernel_configs[key];
std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point);
std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
auto it = tmac_kernel_configs.find(key);
if (it != tmac_kernel_configs.end()) {
return it->second;
}
MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized");
MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized for key: " + key);
}

void MLASCALL
MlasClearLutGemmKernelConfig()
{
std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
tmac_kernel_configs.clear();
}

void MLASCALL
MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
{
std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
if (tmac_kernel_configs.count(key)) {
return;
std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point);
{
std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
if (tmac_kernel_configs.find(key) != tmac_kernel_configs.end()) {
return;
}
}

MlasTMACKernelParams params;
@@ -121,7 +141,10 @@ MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size,
params.has_zero_point = has_zero_point;
params.one_scale = false; // TODO(vraspar): support one scale case for bitnet

tmac_kernel_configs[key] = params;
{
std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
tmac_kernel_configs[key] = params;
}
return;
}
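
Taken together, the init/get pair gives callers a simple lifecycle. A minimal caller-side sketch (not part of this PR; dimensions are illustrative). Note that the init path intentionally drops the lock between the existence check and the insert; the duplicated work two racing threads might do is benign because the computed parameters are deterministic for a given key:

```cpp
#include "qlutgemm.h"

void ConfigureAndQueryTmac()
{
    // Mirrors the call in MlasLutGemm below: N (output features) is passed
    // as the major dimension, K as the reduction dimension.
    const size_t N = 4096, K = 4096, nbits = 2, blk_len = 64;
    const bool has_zp = false;

    // Idempotent: returns early if this configuration is already cached.
    MlasInitLutGemmKernelConfig(N, K, nbits, blk_len, has_zp);

    // Returns a copy, so the result stays valid even if another thread
    // calls MlasClearLutGemmKernelConfig() afterwards.
    MlasTMACKernelParams params =
        MlasGetLutGemmKernelParams(N, K, nbits, blk_len, has_zp);
    (void)params;
}
```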

@@ -222,53 +245,52 @@ LutGemmPackQuantBData(
const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem);
memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed?

MlasTrySimpleParallel(
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
size_t im = static_cast<size_t>(tid);
for (size_t ib = 0; ib < bits; ib++) {
for (size_t ik = 0; ik < K / g; ik++) {
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
size_t new_im = im / simd_n_out;
size_t new_isno = im % simd_n_out;
size_t new_ib = ib;
size_t new_ik = ik;
size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;

// w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
new_im = new_idx / c1_nb0;
size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
new_ik = (new_idx % c1_nb2);
new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;

// # 0 1 2 3 4 5
// w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
new_im = new_idx / c2_nb0;
size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
new_isni = (new_idx % c2_nb1) / c2_nb2;
new_ing = (new_idx % c2_nb2) / c2_nb3;
new_ik = (new_idx % c2_nb3) / c2_nb4;
size_t new_ikf = (new_idx % c2_nb4);
new_idx = new_im * c2_fac0 +
new_ik * c2_fac1 +
new_ibm * c2_fac2 +
new_ikf * c2_fac3 +
new_isni * ngroups_per_elem +
new_ing;
new_idx = new_idx / ngroups_per_elem;
size_t buf_idx = im * bits * K / g + ib * K / g + ik;
uint8_t buf_val = buf[buf_idx];

// w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
static_cast<unsigned>(PackedQuantBDataBegin[new_idx]) +
(buf_val << (new_ing * g))
);
}
// NOTE: The second packing loop is intentionally serialized to avoid data races.
// T-MAC packs multiple output features (N) into a single byte if ngroups_per_elem > 1.
// Parallelizing this across N would lead to concurrent bit-plane updates on the same memory location.
for (size_t im = 0; im < Iterations; im++) {
for (size_t ib = 0; ib < bits; ib++) {
for (size_t ik = 0; ik < K / g; ik++) {
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
size_t new_im = im / simd_n_out;
size_t new_isno = im % simd_n_out;
size_t new_ib = ib;
size_t new_ik = ik;
size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;

// w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
new_im = new_idx / c1_nb0;
size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
new_ik = (new_idx % c1_nb2);
new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;

// # 0 1 2 3 4 5
// w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
new_im = new_idx / c2_nb0;
size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
new_isni = (new_idx % c2_nb1) / c2_nb2;
new_ing = (new_idx % c2_nb2) / c2_nb3;
new_ik = (new_idx % c2_nb3) / c2_nb4;
size_t new_ikf = (new_idx % c2_nb4);
new_idx = new_im * c2_fac0 +
new_ik * c2_fac1 +
new_ibm * c2_fac2 +
new_ikf * c2_fac3 +
new_isni * ngroups_per_elem +
new_ing;
new_idx = new_idx / ngroups_per_elem;
size_t buf_idx = im * bits * K / g + ib * K / g + ik;
uint8_t buf_val = buf[buf_idx];

// w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
static_cast<unsigned>(PackedQuantBDataBegin[new_idx]) +
(buf_val << (new_ing * g))
);
}
}
);
}
}
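
A minimal sketch (hypothetical g and ngroups_per_elem values) of the hazard the serialization comment above guards against: when ngroups_per_elem > 1, two source elements map to the same packed byte, so the accumulate into PackedQuantBDataBegin[new_idx] is a read-modify-write on shared memory:

```cpp
#include <cstdint>

// With g = 4 and ngroups_per_elem = 2, group 0 occupies bits [0,4) and
// group 1 bits [4,8) of the same byte. If two threads performed these two
// updates concurrently, one read-modify-write could overwrite the other.
void PackTwoGroupsIntoOneByte(uint8_t g0, uint8_t g1, uint8_t& packed)
{
    packed = 0;
    packed = static_cast<uint8_t>(packed + (g0 << (0 * 4)));  // new_ing == 0
    packed = static_cast<uint8_t>(packed + (g1 << (1 * 4)));  // new_ing == 1
}
```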

// Internal helper: calculates packed scales and zero points size in floats
@@ -472,16 +494,15 @@ size_t
CalculateLutBufferSize(size_t n, size_t k, size_t m, const MlasTMACKernelParams& tmac_params)
{
MLAS_UNREFERENCED_PARAMETER(n);
constexpr size_t kAllockAligment = 64;
const size_t lut_scales_size = k / tmac_params.act_group_size;

size_t wsize = k * m * 4 * sizeof(int8_t); // 4 bytes per k element for 2-bit LUT
wsize += lut_scales_size * m * 2 * sizeof(float); // scales + biases

wsize = ((wsize - 1) / kAllockAligment + 1) * kAllockAligment;
// The AVX2 kernel (g=4) expects 16 entries (16 bytes) per group of 4 activations.
// This effectively requires 4 bytes per activation in the K dimension.
size_t lut_size_bytes = m * k * 4;
size_t scales_size_bytes = m * lut_scales_size * sizeof(float);
size_t biases_size_bytes = m * lut_scales_size * sizeof(float);

// TODO(vrapar): add temp buffer for FP16
return wsize;
return lut_size_bytes + scales_size_bytes + biases_size_bytes + 256; // + alignment/safety padding
}
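
A worked instance of the formula above, with hypothetical dimensions (K = 4096, M = 1, act_group_size = 64):

```cpp
constexpr size_t kK = 4096, kM = 1, kActGroupSize = 64;               // illustrative
constexpr size_t kLutScales = kK / kActGroupSize;                     // 64 groups
constexpr size_t kLutBytes  = kM * kK * 4;                            // 16384 bytes
constexpr size_t kScaleOrBiasBytes = kM * kLutScales * sizeof(float); // 256 bytes
static_assert(kLutBytes + 2 * kScaleOrBiasBytes + 256 == 17152,
              "buffer = LUT + scales + biases + padding");
```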

void MLASCALL
@@ -532,17 +553,23 @@ MlasLutGemm(
// n_tiles_num = m * bits / bm;

// TODO(vraspar): support other bitwidths
// For T-MAC, kernel properties (bm, n_tiles_num) are primarily driven by the number of output features (N).
// Initialization during packing (LutGemmPackQuantBDataSize) uses N as the major dimension,
// so we must match that here to ensure consistent weight tiling.
MlasInitLutGemmKernelConfig(N, K, 2, BlkLen, HasZeroPoint);
const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, 2, BlkLen, HasZeroPoint);
const size_t lut_scales_size = K / tmac_params.act_group_size;
const size_t lut_size_bytes = static_cast<size_t>(M) * static_cast<size_t>(K) * 4;
size_t lut_buffer_size = CalculateLutBufferSize(N, K, M, tmac_params);

// make buffer of lut_buffer_size bytes
// TODO(vraspar): other way to do it
auto lut_buffer = std::make_unique<int8_t[]>(lut_buffer_size);
memset(lut_buffer.get(), 0, lut_buffer_size);

int8_t* qlut = reinterpret_cast<int8_t*>(lut_buffer.get());
float* lut_scales = reinterpret_cast<float*>(qlut + K * M * 4); // after lut
float* lut_biases = reinterpret_cast<float*>(lut_scales + lut_scales_size * M); // after scales
float* lut_scales = reinterpret_cast<float*>(qlut + lut_size_bytes); // after lut
float* lut_biases = reinterpret_cast<float*>(lut_scales + lut_scales_size * M); // after scales

const auto* a_float = reinterpret_cast<const float*>(A); // Activation data

@@ -558,11 +585,12 @@ MlasLutGemm(

for (size_t ine11 = 0; ine11 < static_cast<size_t>(M); ine11++) {
const size_t row_offset = ine11 * K;
const size_t lut_offset = ine11 * K * 4; // 4 bytes per K element for 2-bit LUT
// Call the LUT generation kernel for this activation row.
// We use a 4-byte stride (per activation) for the LUT entries to satisfy
// the memory layout requirements of the computation kernel.
const size_t lut_offset = ine11 * K * 4;
const size_t scale_bias_offset = ine11 * lut_scales_size;

// Call the dispatch function for this row
// ggml_tmac_mul_mat_task_init
Dispatch->GenerateLUT(
const_cast<float*>(a_float + row_offset), // Input activation for this row
qlut + lut_offset, // Output LUT for this row
Expand All @@ -571,7 +599,8 @@ MlasLutGemm(
M,
K,
N,
tmac_params.act_group_size
tmac_params.act_group_size,
tmac_params.act_group_size * 4
);
}
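
The per-row offsets used in this loop imply a simple buffer layout; a pair of hypothetical helpers making the arithmetic explicit:

```cpp
#include <cstddef>

// Row i of the quantized LUT spans qlut[i * K * 4, (i + 1) * K * 4), and its
// per-group scale/bias entries start at index i * (K / act_group_size).
inline size_t RowLutOffset(size_t row, size_t K) { return row * K * 4; }
inline size_t RowScaleOffset(size_t row, size_t K, size_t act_group_size)
{
    return row * (K / act_group_size);
}
```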

@@ -657,15 +686,17 @@ MlasLutGemm(

// Process all batch items in this chunk
for (size_t ine11 = ir1_start; ine11 < ir1_end; ine11++) {
// Calculate LUT offsets for this batch item
// Calculate LUT offsets with 4-byte stride (per activation) for consistent access.
const size_t qlut_offset = K * ine11 * 4;
const size_t lut_scales_offset = lut_scales_size * ine11;

// Calculate output offset
const size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0;

// Call the dispatch function to compute this tile
// Note M and N are swapped in TMAC terminology
// Call the dispatch function to compute this tile.
// We pass one batch item at a time (M=1) and ChunkSize0 output features.
// TotalN is passed specifically to allow the kernel to find the correct
// parameters (bm, tiles) used during weight packing.
Dispatch->ComputeGemm(
packed_weights + w_offset, // Weight tile
QuantBScale + scales_offset, // Weight scales for this tile
@@ -674,8 +705,9 @@ MlasLutGemm(
lut_biases + lut_scales_offset, // LUT biases
act_output + dst_offset, // Output location
static_cast<int>(K), // K dimension
static_cast<int>(N), // N dimension
static_cast<int>(1), // M dimension (processing one batch item at a time)
static_cast<int>(1), // M dimension (batch size = 1)
static_cast<int>(ir0_end - ir0_start), // N dimension (output features in chunk)
static_cast<int>(N), // TotalN (total output features in weights)
BlkLen, // Weight quantization group size
HasZeroPoint // Whether zero points are used
);
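
To make the chunking arithmetic above concrete, a hypothetical instance of the output-offset math (all values illustrative): with OutputRows = 4096 and ChunkSize0 = 256, chunk ichunk0 = 3 of batch row ine11 = 2 covers features [768, 1024):

```cpp
#include <cstddef>

// Mirrors dst_offset above.
inline size_t DstOffset(size_t output_rows, size_t row,
                        size_t chunk, size_t chunk_size)
{
    return output_rows * row + chunk * chunk_size;
}
// DstOffset(4096, 2, 3, 256) == 8960, i.e. row 2, features starting at 768.
```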
18 changes: 12 additions & 6 deletions onnxruntime/core/mlas/lib/qlutgemm.h
@@ -42,7 +42,11 @@ struct MlasTMACKernelParams {
bool one_scale;
};

const MlasTMACKernelParams&
/**
* Retrieves the T-MAC kernel configuration for a given GEMM problem.
* Returns the parameters by value to ensure thread-safety across concurrent calls.
*/
MlasTMACKernelParams
MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point);

typedef void(MLAS_QNBIT_GEMM_LUT_GEN)(
@@ -53,19 +57,21 @@ typedef void(MLAS_QNBIT_GEMM_LUT_GEN)(
size_t M,
size_t K,
size_t N,
size_t act_group_size
size_t act_group_size,
size_t lut_stride // Stride (in bytes) between consecutive LUT entries along the batch dimension.
);
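
For reference, the single call site in this PR derives the new argument from the activation group size (a hedged reading: 4 LUT bytes per activation, hence act_group_size * 4 bytes per activation group):

```cpp
// As passed in MlasLutGemm (per this PR's call site):
// GenerateLUT(..., tmac_params.act_group_size,
//                  tmac_params.act_group_size * 4 /* lut_stride */);
```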

typedef void(MLAS_QNBIT_LUT_GEMM_COMPUTE)(
const uint8_t* weights,
const float* scales,
const uint8_t* A,
const float* Scales,
const int8_t* LUT,
const float* LUT_Scales,
const float* LUT_Biases,
float* C,
int K,
int M, // batch size (number of rows in activation)
int N,
int M, // Batch size (current activation rows).
int N, // Number of output features to compute in this tile/chunk.
int TotalN, // Total number of output features in the weights (used for parameter mapping).
size_t BlkLen,
bool HasZeroPoint
);
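
A hypothetical stub matching the updated MLAS_QNBIT_LUT_GEMM_COMPUTE signature, showing how N (features in this tile) and TotalN (features at packing time) play different roles; the name and body are illustrative only:

```cpp
void ExampleLutGemmCompute(
    const uint8_t* A, const float* Scales,
    const int8_t* LUT, const float* LUT_Scales, const float* LUT_Biases,
    float* C, int K, int M, int N, int TotalN,
    size_t BlkLen, bool HasZeroPoint)
{
    // Tile geometry (bm, n_tiles_num) would be looked up from parameters
    // keyed on TotalN -- the dimension used when the weights were packed --
    // while the compute loops would iterate over the N features requested
    // for this chunk.
    (void)A; (void)Scales; (void)LUT; (void)LUT_Scales; (void)LUT_Biases;
    (void)C; (void)K; (void)M; (void)N; (void)TotalN;
    (void)BlkLen; (void)HasZeroPoint;
}
```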