Skip to content

Commit

Permalink
vulkan: optimize coopmat2 dequant functions (#10855)
Browse files Browse the repository at this point in the history
Change the code to do 16b loads when possible and extract the appropriate
component late, so the code is effectively decoding a pair of elements and
then selecting one. This can allow more commoning to happen in the compiler
when neighboring elements are loaded.
  • Loading branch information
jeffbolznv authored Dec 21, 2024
1 parent e34c5af commit a91a413
Showing 1 changed file with 45 additions and 25 deletions.
70 changes: 45 additions & 25 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1];
uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
qs >>= shift;
qs &= 0xF;
qs &= 0x0F0F;
qs = unpack8(qs)[idx & 1];
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
return ret;
}
Expand Down Expand Up @@ -152,15 +153,17 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
block_q4_K block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
block_q4_K_packed16 block;
};

float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
const uint idx = coordInBlock[1];
const uint iqs = idx;

const uint n = iqs / 64; // 0,1,2,3
const uint b = (iqs % 64) / 32; // 0,1
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
const uint qsi = n * 32 + (iqs % 32); // 0..127

const f16vec2 loadd = bl.block.d;

Expand All @@ -184,9 +187,11 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);

uint32_t dmask = 0xF << (b * 4);
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs)[idx & 1];

float16_t ret = d * float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) - m;
float16_t ret = d * float16_t(qs) - m;

return ret;
}
Expand All @@ -195,18 +200,19 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
block_q5_K block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
block_q5_K_packed16 block;
};

float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
const uint idx = coordInBlock[1];
const uint iqs = idx;

const uint n = iqs / 64; // 0,1,2,3
const uint b = (iqs % 64) / 32; // 0,1
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
const uint qsi = n * 32 + (iqs % 32); // 0..127
const uint qhi = (iqs % 32); // 0..31

const uint8_t hm = uint8_t(1 << (iqs / 32));
const uint32_t hm = 0x0101 << is;

const f16vec2 loadd = bl.block.d;

Expand All @@ -230,9 +236,15 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);

uint32_t dmask = 0xF << (b * 4);
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = qh & hm;
qh = unpack8(qh)[idx & 1];

float16_t ret = d * (float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi ] & hm) != 0 ? 16 : 0)) - m;
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs)[idx & 1];

float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m;

return ret;
}
Expand All @@ -241,22 +253,30 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_
block_q6_K block;
};

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
block_q6_K_packed16 block;
};

float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
const uint idx = coordInBlock[1];
const uint iqs = idx;

const uint n = iqs / 128; // 0,1
const uint b = (iqs % 128) / 64; // 0,1
const uint is_b = (iqs % 32) / 16; // 0,1
const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6
const uint is = 8 * n + qhshift + is_b; // 0..15
const uint qsi = n * 64 + (iqs % 64); // 0..127
const uint qhi = n * 32 + (iqs % 32); // 0..63
const uint b = (idx & 0x40) >> 6; // 0,1
const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
const uint is = (idx & 0xF0) >> 4; // 0..15

const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);

float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi ] >> qhshift) & 3) << 4)) - 32);
uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
ql = (ql >> (b * 4)) & 0x0F0F;

uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qh = ((qh >> qhshift) & 0x0303) << 4;

int q = unpack8(ql | qh)[idx & 1];

float16_t ret = dscale * float16_t(q - 32);

return ret;
}
Expand Down

0 comments on commit a91a413

Please sign in to comment.