4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2936,6 +2936,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );

CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
#endif

@@ -3055,6 +3057,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );

CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
#endif

4 changes: 2 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {

const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
const vec2 d = vec2(data_a[a_offset + ib].d);
const vec2 dm = vec2(data_a[a_offset + ib].dm);

return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
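The rename from d to dm above matches the Q2_K block layout: each super-block carries a two-component half-precision pair, dm = (d, dmin), where d scales the 4-bit group scales and dmin scales the 4-bit group minimums. Below is a minimal scalar sketch of the same dequantization, not part of the patch, assuming the usual ggml super-block layout (256 values, 16 scale/min bytes, 2-bit quants) and using plain floats in place of the fp16 pair; names are illustrative, the real definitions live in ggml-common.h.

#include <cstdint>

// Simplified Q2_K super-block (illustrative; ggml packs dm as an fp16 pair).
struct block_q2_K_ref {
    uint8_t scales[16]; // low nibble: group scale, high nibble: group min
    uint8_t qs[64];     // 256 2-bit quants, four per byte
    float   d;          // dm.x: super-block scale
    float   dmin;       // dm.y: super-block min
};

// Dequantize value idx (0..255): y = d * scale * q - dmin * min,
// mirroring the shader's dm.x * (sc & 0xF) * q - dm.y * (sc >> 4).
inline float dequant_q2_k(const block_q2_K_ref &b, int idx) {
    const int shift  = 2 * ((idx % 128) / 32);         // which 2-bit lane: 0, 2, 4, 6
    const int qbyte  = 32 * (idx / 128) + (idx % 32);  // byte holding the quant
    const int q      = (b.qs[qbyte] >> shift) & 3;
    const uint8_t sc = b.scales[idx / 16];              // one scale/min byte per 16 values
    return b.d * float(sc & 0xF) * float(q) - b.dmin * float(sc >> 4);
}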
4 changes: 2 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
@@ -108,7 +108,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d;
const f16vec2 dm = bl.block.dm;
const uint idx = coordInBlock[1];

const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -119,7 +119,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
qs = unpack8(qs)[idx & 1];

const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
return ret;
}

4 changes: 2 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
@@ -24,8 +24,8 @@ void main() {
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];

FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
6 changes: 3 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -41,9 +41,9 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));

vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
vec2 dm = vec2(data_a[ib0 + i].dm);
const FLOAT_TYPE dall = FLOAT_TYPE(dm.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(dm.y);

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
8 changes: 4 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127

const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6

const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const uint scales = data_a[ib].scales[scalesi];
const vec2 d = vec2(data_a[ib].d);
const vec2 dm = vec2(data_a[ib].dm);

const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);

buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
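Besides the dm rename, this hunk changes how the two quant bytes are fetched: instead of two 8-bit reads from data_a[ib].qs at a byte index (0, 2, ..., 30), the shader now does one 16-bit read from the packed16 view at a word index (0..15) and splits it with unpack8. A small sketch of the equivalence, not part of the patch, assuming little-endian packing of the qs bytes into 16-bit words; the function names are illustrative only.

#include <cstdint>
#include <utility>

// Old path: two consecutive byte loads.
inline std::pair<uint32_t, uint32_t> load_pair_bytes(const uint8_t *qs, int byte_idx) {
    return { qs[byte_idx], qs[byte_idx + 1] };
}

// New path: one 16-bit load, split into low/high byte (unpack8 equivalent).
inline std::pair<uint32_t, uint32_t> load_pair_packed16(const uint16_t *qs16, int word_idx) {
    const uint16_t w = qs16[word_idx];
    return { uint32_t(w & 0xFF), uint32_t(w >> 8) };
}

// With byte_idx == 2 * word_idx both functions return the same pair, which is why
// qsi drops from a byte index (0, 2, ..., 30) to a word index (0..15).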
92 changes: 30 additions & 62 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -24,7 +24,10 @@

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
@@ -78,28 +81,23 @@ layout (constant_id = 10) const uint WARP = 32;

#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 4 + 4)
#else
#define SHMEM_STRIDE (BK / 4 + 1)
#endif

shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
#define MMQ_SHMEM

#ifndef COOPMAT
#if QUANT_AUXF == 1
shared FLOAT_TYPE buf_a_dm[BM];
#else
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
#endif
#endif
#include "mul_mmq_shmem_types.glsl"

shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
#ifndef COOPMAT
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif
// Shared memory cache
shared block_a_cache buf_a[BM];
shared block_b_cache buf_b[BN];
// Register cache
block_a_cache cache_a[WMITER * TM];
block_b_cache cache_b[TN];

#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16

// TODO: Recheck if this can work with mul_mat_id
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
#endif // MUL_MAT_ID
@@ -222,9 +220,6 @@ void main() {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
int32_t cache_a_qs[WMITER * TM * BK / 4];

int32_t cache_b_qs[TN * BK / 4];

ACC_TYPE sums[WMITER * TM * WNITER * TN];

@@ -233,34 +228,13 @@
}
#endif

#if QUANT_AUXF == 1
FLOAT_TYPE cache_a_dm[WMITER * TM];
#else
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
#endif

FLOAT_TYPE_VEC2 cache_b_ds[TN];

for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
const uint iqs = loadr_a;

if (iqs == 0) {
#if QUANT_AUXF == 1
buf_a_dm[buf_ib] = get_d(ib);
#else
buf_a_dm[buf_ib] = get_dm(ib);
#endif
}
#if QUANT_R == 1
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
#else
const i32vec2 vals = repack(ib, iqs);
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
block_a_to_shmem(buf_ib, ib, iqs);
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
#ifdef MUL_MAT_ID
@@ -279,13 +253,13 @@ void main() {
const uint buf_ib = loadc_b + l;

if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
}

barrier();
@@ -328,20 +302,19 @@ void main() {
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
}
const uint reg_ib = wsir * TM + cr;
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;

block_a_to_registers(reg_ib, buf_ib);
}
}

[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
cache_b_ds[cc] = buf_b_ds[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
cache_b[cc].ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b[cc].qs[iqs] = buf_b[ib].qs[iqs];
}
}

@@ -350,13 +323,8 @@ void main() {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
int32_t q_sum = 0;
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
cache_b_qs[cc * (BK / 4) + idx_k]);
}

sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
sums[sums_idx] += mmq_dot_product(cache_a_idx, cc);
}
}
}
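The bulk of this file's change replaces the hand-rolled shared-memory arrays (buf_a_qs, buf_a_dm, buf_b_qs, buf_b_ds) and register caches with per-type block_a_cache and block_b_cache structs pulled in from the new mul_mmq_shmem_types.glsl, and hides the inner loop behind block_a_to_shmem, block_a_to_registers, and mmq_dot_product, presumably so each quantization type (including the newly added Q2_K) can supply its own cache layout and dot product. For reference, a hedged sketch of the arithmetic the removed generic path computed for one A-block and one B-block per K-slice, not part of the patch; dot_packed_4x8 stands in for the GLSL dotPacked4x8EXT builtin, BK = 32 is assumed, and the final combine uses the usual q8_1 identity, sum((d_a*q_a - m_a) * d_b*q_b) = d_a*d_b*sum(q_a*q_b) - m_a*(d_b*sum(q_b)), which may differ in detail from the shader's mul_q8_1 helper.

#include <cstdint>

// Software stand-in for dotPacked4x8EXT: dot product of four signed 8-bit lanes
// packed into each int32.
static int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t sum = 0;
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = int8_t((a >> (8 * i)) & 0xFF);
        const int8_t bi = int8_t((b >> (8 * i)) & 0xFF);
        sum += int32_t(ai) * int32_t(bi);
    }
    return sum;
}

// One K-slice of BK = 32 quants, stored as BK/4 = 8 packed int32 words per side.
// dm = {d_a, m_a} from the A block, ds = {d_b, d_b * sum(q_b)} from the q8_1 B block.
static float mmq_dot_product_ref(const int32_t a_qs[8], const int32_t b_qs[8],
                                 const float dm[2], const float ds[2]) {
    int32_t q_sum = 0;
    for (int k = 0; k < 8; ++k) {
        q_sum += dot_packed_4x8(a_qs[k], b_qs[k]);  // the removed dotPacked4x8EXT loop
    }
    // the removed mul_q8_1(q_sum, dm, ds, 1) step, under the q8_1 identity above
    return dm[0] * ds[0] * float(q_sum) - dm[1] * ds[1];
}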