From c3ea58aca406911eb4d409cdbfc76683393442b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 17 Nov 2024 09:09:55 +0100 Subject: [PATCH] CUDA: remove DMMV, consolidate F16 mult mat vec (#10318) --- Makefile | 57 -- docs/build.md | 11 - ggml/CMakeLists.txt | 5 - ggml/src/ggml-cuda/CMakeLists.txt | 13 - ggml/src/ggml-cuda/dmmv.cu | 699 ----------------------- ggml/src/ggml-cuda/ggml-cuda.cu | 204 +------ ggml/src/ggml-cuda/mmv.cu | 223 ++++++++ ggml/src/ggml-cuda/{dmmv.cuh => mmv.cuh} | 16 +- ggml/src/ggml-hip/CMakeLists.txt | 7 - ggml/src/ggml-musa/CMakeLists.txt | 11 - 10 files changed, 246 insertions(+), 1000 deletions(-) delete mode 100644 ggml/src/ggml-cuda/dmmv.cu create mode 100644 ggml/src/ggml-cuda/mmv.cu rename ggml/src/ggml-cuda/{dmmv.cuh => mmv.cuh} (55%) diff --git a/Makefile b/Makefile index fecf1f693e195..bd0bad160a1f2 100644 --- a/Makefile +++ b/Makefile @@ -635,10 +635,6 @@ else ifndef CUDA_POWER_ARCH MK_NVCCFLAGS += -arch=native endif # CUDA_DOCKER_ARCH -ifdef GGML_CUDA_FORCE_DMMV - MK_NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV -endif # GGML_CUDA_FORCE_DMMV - ifdef GGML_CUDA_FORCE_MMQ MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ endif # GGML_CUDA_FORCE_MMQ @@ -647,20 +643,6 @@ ifdef GGML_CUDA_FORCE_CUBLAS MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS endif # GGML_CUDA_FORCE_CUBLAS -ifdef GGML_CUDA_DMMV_X - MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) -else - MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=32 -endif # GGML_CUDA_DMMV_X - -ifdef GGML_CUDA_MMV_Y - MK_NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) -else ifdef GGML_CUDA_DMMV_Y - MK_NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_DMMV_Y) # for backwards compatibility -else - MK_NVCCFLAGS += -DGGML_CUDA_MMV_Y=1 -endif # GGML_CUDA_MMV_Y - ifdef GGML_CUDA_F16 MK_NVCCFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_F16 @@ -669,12 +651,6 @@ ifdef GGML_CUDA_DMMV_F16 MK_NVCCFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_DMMV_F16 -ifdef GGML_CUDA_KQUANTS_ITER - MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) -else - MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 -endif - ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) else @@ -783,10 +759,6 @@ ifdef GGML_HIPBLAS AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch) endif - GGML_CUDA_DMMV_X ?= 32 - GGML_CUDA_MMV_Y ?= 1 - GGML_CUDA_KQUANTS_ITER ?= 2 - MK_CPPFLAGS += -DGGML_USE_HIP -DGGML_USE_CUDA ifdef GGML_HIP_UMA @@ -800,13 +772,6 @@ endif # GGML_HIP_UMA HIPCC ?= $(CCACHE) $(ROCM_PATH)/bin/hipcc HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS)) - HIPFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) - HIPFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) - HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) - -ifdef GGML_CUDA_FORCE_DMMV - HIPFLAGS += -DGGML_CUDA_FORCE_DMMV -endif # GGML_CUDA_FORCE_DMMV ifdef GGML_CUDA_FORCE_MMQ HIPFLAGS += -DGGML_CUDA_FORCE_MMQ @@ -869,10 +834,6 @@ ifdef GGML_MUSA MUSAFLAGS += $(addprefix --cuda-gpu-arch=, $(MTGPU_TARGETS)) -ifdef GGML_CUDA_FORCE_DMMV - MUSAFLAGS += -DGGML_CUDA_FORCE_DMMV -endif # GGML_CUDA_FORCE_DMMV - ifdef GGML_CUDA_FORCE_MMQ MUSAFLAGS += -DGGML_CUDA_FORCE_MMQ endif # GGML_CUDA_FORCE_MMQ @@ -881,18 +842,6 @@ ifdef GGML_CUDA_FORCE_CUBLAS MUSAFLAGS += -DGGML_CUDA_FORCE_CUBLAS endif # GGML_CUDA_FORCE_CUBLAS -ifdef GGML_CUDA_DMMV_X - MUSAFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) -else - MUSAFLAGS += -DGGML_CUDA_DMMV_X=32 -endif # GGML_CUDA_DMMV_X - -ifdef GGML_CUDA_MMV_Y - MUSAFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) -else - MUSAFLAGS += -DGGML_CUDA_MMV_Y=1 -endif # GGML_CUDA_MMV_Y - ifdef GGML_CUDA_F16 MUSAFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_F16 @@ -901,12 +850,6 @@ ifdef GGML_CUDA_DMMV_F16 MUSAFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_DMMV_F16 -ifdef GGML_CUDA_KQUANTS_ITER - MUSAFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) -else - MUSAFLAGS += -DK_QUANTS_PER_ITERATION=2 -endif - ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE MUSAFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) else diff --git a/docs/build.md b/docs/build.md index 811bbb409c250..359952b30ff93 100644 --- a/docs/build.md +++ b/docs/build.md @@ -186,13 +186,9 @@ The following compilation options are also available to tweak performance: | Option | Legal values | Default | Description | |-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| GGML_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | -| GGML_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | -| GGML_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | | GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | -| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | | GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. | @@ -268,13 +264,6 @@ You can download it from your Linux distro's package manager or from here: [ROCm The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used. If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3. -The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above): - -| Option | Legal values | Default | Description | -|------------------------|------------------------|---------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| GGML_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the HIP dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | -| GGML_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the HIP mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | -| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per HIP thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | ### Vulkan diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index fd94998264547..a82818d607f5e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -128,14 +128,9 @@ option(GGML_LLAMAFILE "ggml: use LLAMAFILE" option(GGML_CUDA "ggml: use CUDA" OFF) option(GGML_MUSA "ggml: use MUSA" OFF) -option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF) option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF) option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF) -set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels") -set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels") option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) -set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING - "ggml: iters./thread per block for Q2_K/Q6_K") set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "ggml: max. batch size for using peer access") option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 860552f3a96df..3dde0f3666187 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -54,21 +54,12 @@ if (CUDAToolkit_FOUND) target_link_libraries(ggml-cuda PRIVATE ggml-base) target_include_directories(ggml-cuda PRIVATE . ..) - # TODO: change the definitions to this target only - - add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) if (GGML_CUDA_GRAPHS) add_compile_definitions(GGML_CUDA_USE_GRAPHS) endif() - if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() @@ -81,10 +72,6 @@ if (CUDAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_VMM) endif() - if (DEFINED GGML_CUDA_DMMV_Y) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility - endif() - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_F16) endif() diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu deleted file mode 100644 index 00e21b5d77e3c..0000000000000 --- a/ggml/src/ggml-cuda/dmmv.cu +++ /dev/null @@ -1,699 +0,0 @@ -#include "dmmv.cuh" -#include "dequantize.cuh" -#include "convert.cuh" - -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 2 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - -static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q2_K * x = (const block_q2_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 16/K_QUANTS_PER_ITERATION; - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 - - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int s_offset = 8*im; - const int y_offset = 128*im + l0; - - uint32_t aux[4]; - const uint8_t * d = (const uint8_t *)aux; - const uint8_t * m = (const uint8_t *)(aux + 2); - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); - aux[0] = a[0] & 0x0f0f0f0f; - aux[1] = a[1] & 0x0f0f0f0f; - aux[2] = (a[0] >> 4) & 0x0f0f0f0f; - aux[3] = (a[1] >> 4) & 0x0f0f0f0f; - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) - + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) - + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) - + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) - + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) - + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) - + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) - +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); - sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] - + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; - - } - tmp += dall * sum1 - dmin * sum2; - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q3_K * x = (const block_q3_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop - const int step = 16/K_QUANTS_PER_ITERATION; - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0....15 or 0...7 - - const uint8_t m = 1 << (4*im); - - const int l0 = n*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int y_offset = 128*im + l0; - - uint16_t utmp[4]; - const int8_t * s = (const int8_t *)utmp; - - const uint16_t s_shift = 4*im; - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - const uint8_t * h = x[i].hmask + l0; - - const uint16_t * a = (const uint16_t *)x[i].scales; - utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); - utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); - utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); - utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); - - const float d = x[i].d; - - float sum = 0; - for (int l = 0; l < n; ++l) { - sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) - + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) - + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) - + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); - sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) - + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) - + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) - + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); - } - tmp += d * sum; - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q4_K * x = (const block_q4_K *)vx + ib0; - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 - - const int il = tid/step; // 0...3 - const int ir = tid - step*il; // 0...7 or 0...3 - const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - -#if K_QUANTS_PER_ITERATION == 2 - uint32_t q32[4]; - const uint8_t * q4 = (const uint8_t *)q32; -#else - uint16_t q16[4]; - const uint8_t * q4 = (const uint8_t *)q16; -#endif - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - -#if K_QUANTS_PER_ITERATION == 2 - const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); - const uint32_t * q2 = q1 + 16; - - q32[0] = q1[0] & 0x0f0f0f0f; - q32[1] = q1[0] & 0xf0f0f0f0; - q32[2] = q2[0] & 0x0f0f0f0f; - q32[3] = q2[0] & 0xf0f0f0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 4; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; - s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#else - const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); - const uint16_t * q2 = q1 + 32; - - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[0] & 0xf0f0; - q16[2] = q2[0] & 0x0f0f; - q16[3] = q2[0] & 0xf0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 2; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; - s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#endif - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (tid == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { - - const int row = blockIdx.x; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q5_K * x = (const block_q5_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/2; // 0...15 - const int ix = threadIdx.x%2; - - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 2; - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - const uint8_t hm1 = 1 << (2*im); - const uint8_t hm2 = hm1 << 4; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - - uint16_t q16[8]; - const uint8_t * q4 = (const uint8_t *)q16; - - for (int i = ix; i < num_blocks_per_row; i += 2) { - - const uint8_t * ql1 = x[i].qs + q_offset; - const uint8_t * qh = x[i].qh + l0; - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - - float4 sum = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - const uint16_t * q1 = (const uint16_t *)ql1; - const uint16_t * q2 = q1 + 32; - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[8] & 0x0f0f; - q16[2] = (q1[0] >> 4) & 0x0f0f; - q16[3] = (q1[8] >> 4) & 0x0f0f; - q16[4] = q2[0] & 0x0f0f; - q16[5] = q2[8] & 0x0f0f; - q16[6] = (q2[0] >> 4) & 0x0f0f; - q16[7] = (q2[8] >> 4) & 0x0f0f; - for (int l = 0; l < n; ++l) { - sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) - + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0)); - sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) - + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0)); - sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) - + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0)); - sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) - + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0)); - smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] - + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; - } - tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q6_K * x = (const block_q6_K *)vx + ib0; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 - -#if K_QUANTS_PER_ITERATION == 1 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 - const int is = 0; -#else - const int l0 = 4 * in; // 0, 4, 8, ..., 28 - const int is = in / 4; -#endif - const int ql_offset = 64*im + l0; - const int qh_offset = 32*im + l0; - const int s_offset = 8*im + is; - const int y_offset = 128*im + l0; - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * ql = x[i].ql + ql_offset; - const uint8_t * qh = x[i].qh + qh_offset; - const int8_t * s = x[i].scales + s_offset; - - const float d = x[i].d; - -#if K_QUANTS_PER_ITERATION == 1 - float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) - + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) - + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) - + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) - + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) - + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) - + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) - +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); - tmp += sum; -#else - float sum = 0; - for (int l = 0; l < 4; ++l) { - sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) - + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) - + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) - + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); - } - tmp += sum; -#endif - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (tid == 0) { - dst[row] = tmp; - } -} - -static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ - const half * x = (const half *) vx; - // load 2 halfs into register in a single instruction - const half2 x_reg = *((half2 *) &(x[ib + iqs])); - // automatic half -> float type cast if dfloat == float - v.x = __low2float(x_reg); - v.y = __high2float(x_reg); -} - -static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 : - type == GGML_TYPE_Q4_1 ? dequantize_q4_1 : - type == GGML_TYPE_Q5_0 ? dequantize_q5_0 : - type == GGML_TYPE_Q5_1 ? dequantize_q5_1 : - type == GGML_TYPE_Q8_0 ? dequantize_q8_0 : - type == GGML_TYPE_F16 ? convert_f16 : - nullptr; -} - -template -static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { - constexpr int qk = ggml_cuda_type_traits::qk; // quantized weights per x block - constexpr int qr = ggml_cuda_type_traits::qr; // number of quantized weights per data value in x block - constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type); - - const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; - - if (row >= nrows) { - return; - } - - const int tid = threadIdx.x; - - const int iter_stride = 2*GGML_CUDA_DMMV_X; - const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter - const int y_offset = qr == 1 ? 1 : qk/2; - -// partial sum for each thread -#ifdef GGML_CUDA_F16 - half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics -#else - float tmp = 0.0f; -#endif // GGML_CUDA_F16 - - for (int i = 0; i < ncols; i += iter_stride) { - const int col = i + vals_per_iter*tid; - const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index - const int iqs = (col%qk)/qr; // x quant index - const int iybs = col - col%qk; // y block start index - -// processing >2 values per i iter is faster for fast GPUs -#pragma unroll - for (int j = 0; j < vals_per_iter; j += 2) { - // process 2 vals per j iter - - // dequantize - // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val - dfloat2 v; - dequantize_kernel(vx, ib, iqs + j/qr, v); - - // matrix multiplication - // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 -#ifdef GGML_CUDA_F16 - if ( y_offset == 1 ) { - // load 2 dfloats into register in a single instruction - const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr])); - tmp += __hmul2(v, y_reg); - } - else { - tmp += __hmul2(v, { - y[iybs + iqs + j/qr + 0], - y[iybs + iqs + j/qr + y_offset] - }); - } -#else - if ( y_offset == 1 ) { - // load 2 dfloats into register in a single instruction - const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr])); - tmp += v.x * y_reg.x; - tmp += v.y * y_reg.y; - } - else { - tmp += v.x * y[iybs + iqs + j/qr + 0]; - tmp += v.y * y[iybs + iqs + j/qr + y_offset]; - } -#endif // GGML_CUDA_F16 - } - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (tid == 0) { -#ifdef GGML_CUDA_F16 - dst[row] = tmp.x + tmp.y; -#else - dst[row] = tmp; -#endif // GGML_CUDA_F16 - } -} - -static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 1, 1); - dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); -} - -static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); -} - -static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -void ggml_cuda_op_dequantize_mul_mat_vec( - ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, cudaStream_t stream) { - GGML_UNUSED(ctx); - const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics -#ifdef GGML_CUDA_F16 - ggml_cuda_pool_alloc src1_dfloat_a(ctx.pool()); - half * src1_dfloat = nullptr; // dfloat == half - - bool src1_convert_f16 = - src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || - src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; - - if (src1_convert_f16) { - src1_dfloat = src1_dfloat_a.alloc(ne00); - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); - GGML_ASSERT(to_fp16_cuda != nullptr); - to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream); - } -#else - const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion -#endif // GGML_CUDA_F16 - - switch (src0->type) { - case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_F16: - convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - default: - GGML_ABORT("fatal error"); - break; - } - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddq_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); -} - -bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) { - return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 || - src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 || - src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K || - src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K || - src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K || - src0_type == GGML_TYPE_F16; -} diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 07f043328a864..ef56e944d01ab 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -16,11 +16,11 @@ #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" -#include "ggml-cuda/dmmv.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" +#include "ggml-cuda/mmv.cuh" #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" @@ -1020,114 +1020,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)( #define MUL_MAT_SRC1_COL_STRIDE 128 -static __global__ void mul_mat_p021_f16_f32( - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / (nchannels_y / nchannels_x); - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - // x is transposed and permuted - const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; - const float xi = __half2float(x[ix]); - - const int row_y = col_x; - - // y is not transposed but permuted - const int iy = channel*nrows_y + row_y; - - tmp += xi * y[iy]; - } - - // dst is not transposed and not permuted - const int idst = channel*nrows_dst + row_dst; - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / channel_x_divisor; - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - const int idst = channel*nrows_dst + row_dst; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - const int row_y = col_x; - - const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; - const int iy = channel*nrows_y + row_y; - - const float xi = __half2float(x[ix]); - - tmp += xi * y[iy]; - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static void ggml_mul_mat_p021_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, - const int nchannels_x, const int nchannels_y, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); -} - -static void ggml_mul_mat_vec_nc_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, - const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); -} - static cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { @@ -1654,58 +1546,6 @@ static void ggml_cuda_op_mul_mat( } } -static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation - GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t ne12 = src1->ne[2]; - - cudaStream_t main_stream = ctx.stream(); - - void * src0_ddq = src0->data; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; - - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); -} - -static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(!ggml_is_permuted(src0)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t nb01 = src0->nb[1]; - const int64_t nb02 = src0->nb[2]; - - const int64_t ne12 = src1->ne[2]; - - cudaStream_t main_stream = ctx.stream(); - - void * src0_ddq = src0->data; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; - - const int64_t row_stride_x = nb01 / sizeof(half); - const int64_t channel_stride_x = nb02 / sizeof(half); - - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); -} - static __global__ void k_compute_batched_ptrs( const half * src0_as_f16, const half * src1_as_f16, char * dst, const void ** ptrs_src, void ** ptrs_dst, @@ -1879,21 +1719,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); - bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type) + bool use_mul_mat_vec = src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - bool use_mul_mat_q = ggml_is_quantized(src0->type) + bool use_mul_mat_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - // if mmvq is available it's a better choice than dmmv: -#ifndef GGML_CUDA_FORCE_DMMV - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; -#endif // GGML_CUDA_FORCE_DMMV - - bool any_gpus_with_slow_fp16 = false; + bool any_gpus_with_slow_fp16 = false; + bool any_gpus_without_fp16_mma = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; @@ -1904,14 +1740,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - const int cc = ggml_cuda_info().devices[id].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); } } else { - const int cc = ggml_cuda_info().devices[ctx.device].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); } // debug helpers @@ -1922,18 +1760,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // FP32 precision KQ single-batch for batch size 1 without FlashAttention - ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); - } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { - // FP32 precision KQV single-batch for batch size 1 without FlashAttention - ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); + if (!split && src0->type == GGML_TYPE_F16 && src1->ne[1] == 1 && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); - } else if (use_dequantize_mul_mat_vec) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); + } else if (use_mul_mat_vec) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu new file mode 100644 index 0000000000000..cfe91f4283fa6 --- /dev/null +++ b/ggml/src/ggml-cuda/mmv.cu @@ -0,0 +1,223 @@ +#include "common.cuh" +#include "mmv.cuh" + +template +static __global__ void mul_mat_vec( + const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { + const int64_t row = blockIdx.x; + const int64_t channel = blockIdx.z; + const int tid = threadIdx.x; + + x += (channel/channel_ratio)*stride_channel_x + row*stride_row; + y += channel *stride_channel_y; + dst += channel *stride_channel_dst; + + const half2 * x2 = (const half2 *) x; + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + float * buf_iw = (float *) data_mmv; + + if (block_size > WARP_SIZE) { + if (tid < WARP_SIZE) { + buf_iw[tid] = 0.0f; + } + __syncthreads(); + } + + float sumf; + + if (std::is_same::value) { + sumf = 0.0f; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = __half22float2(x2[col2]); + const float2 tmpy = y2[col2]; + sumf += tmpx.x * tmpy.x; + sumf += tmpx.y * tmpy.y; + } + } else { +#ifdef FP16_AVAILABLE + half2 sumh2 = make_half2(0.0f, 0.0f); + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmp = y2[col2]; + sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); + } + + sumf = __low2float(sumh2) + __high2float(sumh2); +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE + } + + sumf = warp_reduce_sum(sumf); + + if (block_size > WARP_SIZE) { + buf_iw[tid/WARP_SIZE] = sumf; + __syncthreads(); + if (tid > WARP_SIZE) { + return; + } + sumf = buf_iw[tid]; + sumf = warp_reduce_sum(sumf); + } + + if (tid != 0) { + return; + } + + dst[row] = sumf; +} + +template +static void launch_mul_mat_vec_cuda( + const half * x, const float * y, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + cudaStream_t stream) { + GGML_ASSERT(ncols % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(nchannels_y % nchannels_x == 0); + const int64_t channel_ratio = nchannels_y / nchannels_x; + + int64_t block_size_best = WARP_SIZE; + int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) { + const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); + if (niter < niter_best) { + niter_best = niter; + block_size_best = block_size; + } + } + + const int smem = WARP_SIZE*sizeof(float); + const dim3 block_nums(nrows, 1, nchannels_y); + const dim3 block_dims(block_size_best, 1, 1); + switch (block_size_best) { + case 32: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 64: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 96: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 128: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 160: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 192: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 224: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 256: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +static void mul_mat_vec_cuda( + const half * x, const float * y, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + enum ggml_prec prec, cudaStream_t stream) { + switch (prec) { + case GGML_PREC_DEFAULT: { + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + stride_channel_x, stride_channel_y, stride_channel_dst, stream); + } break; + case GGML_PREC_F32: { + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + stride_channel_x, stride_channel_y, stride_channel_dst, stream); + } break; + } +} + +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + GGML_ASSERT(src1->ne[1] == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + const half * src0_d = (const half *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + GGML_ASSERT(dst->ne[2] == ne12); + + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(src1->ne[3] == 1); + GGML_ASSERT( dst->ne[3] == 1); + + const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type); + const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type); + const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type); + const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type); + + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); +} + +void ggml_cuda_op_mul_mat_vec( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + GGML_ASSERT(src1_ncols == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + + // ggml_cuda_op provides single, contiguous matrices + const int64_t stride_row = ne00; + const int64_t nchannels_x = 1; + const int64_t nchannels_y = 1; + const int64_t channel_stride_x = 0; + const int64_t channel_stride_y = 0; + const int64_t channel_stride_dst = 0; + + mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + + GGML_UNUSED(ctx); + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_ncols); + GGML_UNUSED(src1_padded_row_size); +} diff --git a/ggml/src/ggml-cuda/dmmv.cuh b/ggml/src/ggml-cuda/mmv.cuh similarity index 55% rename from ggml/src/ggml-cuda/dmmv.cuh rename to ggml/src/ggml-cuda/mmv.cuh index e727eb97f6aad..78a1cd4a6906a 100644 --- a/ggml/src/ggml-cuda/dmmv.cuh +++ b/ggml/src/ggml-cuda/mmv.cuh @@ -1,20 +1,12 @@ #include "common.cuh" -// dmmv = dequantize_mul_mat_vec +// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available +#define MMV_MAX_ROWS 512 -// TODO: remove this? -#ifndef GGML_CUDA_DMMV_X -#define GGML_CUDA_DMMV_X 32 -#endif +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); -#ifndef GGML_CUDA_MMV_Y -#define GGML_CUDA_MMV_Y 1 -#endif - -void ggml_cuda_op_dequantize_mul_mat_vec( +void ggml_cuda_op_mul_mat_vec( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); - -bool ggml_cuda_dmmv_type_supported(ggml_type src0_type); diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 5ed186ded16be..fccf8eb8440b8 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -75,18 +75,11 @@ target_include_directories(ggml-hip PRIVATE . ..) target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) add_compile_definitions(GGML_USE_HIP) -add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) -add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) -add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) if (GGML_HIP_UMA) add_compile_definitions(GGML_HIP_UMA) endif() -if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) -endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 8edc75cc5a9c7..f3c0136920540 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -58,19 +58,12 @@ if (MUSAToolkit_FOUND) target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) add_compile_definitions(GGML_USE_MUSA) - add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) if (GGML_CUDA_GRAPHS) add_compile_definitions(GGML_CUDA_USE_GRAPHS) endif() - if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() @@ -83,10 +76,6 @@ if (MUSAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_VMM) endif() - if (DEFINED GGML_CUDA_DMMV_Y) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility - endif() - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_F16) endif()