Skip to content

Commit 13dca2a

Browse files
Vectorize load instructions in dmmv f16 CUDA kernel (ggml-org#9816)
* Vectorize load instructions in dmmv f16 CUDA kernel Replaces scalar with vector load instructions, which substantially improves performance on NVIDIA HBM GPUs, e.g. gives a 1.27X overall speedup for Meta-Llama-3-8B-Instruct-F16 BS1 inference evaluation on H100 SXM 80GB HBM3. On GDDR GPUs, there is a slight (1.01X) speedup. * addressed comment * Update ggml/src/ggml-cuda/dmmv.cu Co-authored-by: Johannes Gäßler <[email protected]> --------- Co-authored-by: Johannes Gäßler <[email protected]>
1 parent d4c19c0 commit 13dca2a

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

ggml/src/ggml-cuda/dmmv.cu

+25-9
Original file line number | Diff line number | Diff line change
@@ -416,10 +416,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
416416

417417
// Dequantize two consecutive f16 values starting at x[ib + iqs] into v.
// Uses one vectorized half2 load instead of two scalar half loads, which
// improves memory throughput on HBM GPUs (see commit message).
// NOTE(review): assumes x + ib + iqs is suitably aligned for a half2
// (4-byte) access — holds for the even iqs values this kernel produces,
// but confirm against callers.
static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const half * x = (const half *) vx;

    // Fetch both halfs in a single instruction.
    const half2 packed = *reinterpret_cast<const half2 *>(x + ib + iqs);

    // automatic half -> float type cast if dfloat == float
    v.x = __low2float(packed);
    v.y = __high2float(packed);
}
424425

425426
static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
@@ -476,13 +477,28 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
476477
// matrix multiplication
477478
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
478479
#ifdef GGML_CUDA_F16
479-
tmp += __hmul2(v, {
480-
y[iybs + iqs + j/qr + 0],
481-
y[iybs + iqs + j/qr + y_offset]
482-
});
480+
if ( y_offset == 1 ) {
481+
// load 2 dfloats into register in a single instruction
482+
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
483+
tmp += __hmul2(v, y_reg);
484+
}
485+
else {
486+
tmp += __hmul2(v, {
487+
y[iybs + iqs + j/qr + 0],
488+
y[iybs + iqs + j/qr + y_offset]
489+
});
490+
}
483491
#else
484-
tmp += v.x * y[iybs + iqs + j/qr + 0];
485-
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
492+
if ( y_offset == 1 ) {
493+
// load 2 dfloats into register in a single instruction
494+
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
495+
tmp += v.x * y_reg.x;
496+
tmp += v.y * y_reg.y;
497+
}
498+
else {
499+
tmp += v.x * y[iybs + iqs + j/qr + 0];
500+
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
501+
}
486502
#endif // GGML_CUDA_F16
487503
}
488504
}

0 commit comments

Comments (0)