From a750868428868abd437e228ae5cab763ef3dc387 Mon Sep 17 00:00:00 2001 From: AIWintermuteAI <32562299+AIWintermuteAI@users.noreply.github.com> Date: Tue, 16 Apr 2024 19:15:52 +0800 Subject: [PATCH 001/100] readme : add up-to-date repository for Python bindings (#2063) README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 55b811475cb..33570ef02bc 100644 --- a/README.md +++ b/README.md @@ -808,6 +808,7 @@ For more details, see the conversion script [models/convert-pt-to-ggml.py](model - [NickDarvey/whisper](https://github.com/NickDarvey/whisper) - [x] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9) - [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython) + - [AIWintermuteAI/whispercpp](https://github.com/AIWintermuteAI/whispercpp) (Updated fork of aarnphm/whispercpp) - [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11) - [x] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper) - [x] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity) From b0c3cbf2e851cf232e432b590dcc514a689ec028 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 17 Apr 2024 12:23:47 +0300 Subject: [PATCH 002/100] main : pass nullptr when regex is empty (#2070) --- examples/main/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6f490e3e77d..15d8c8a83b6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -1068,7 +1068,7 @@ int main(int argc, char ** argv) { wparams.tdrz_enable = params.tinydiarize; // [TDRZ] - wparams.suppress_regex = params.suppress_regex.c_str(); + wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str(); wparams.initial_prompt = params.prompt.c_str(); From 7f85e1d7fd3c995adc68c808a9bce5486a5ca90a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 24 Apr 2024 14:45:27 +0300 Subject: [PATCH 003/100] whisper : more prominent log message for sub-1s audio (#2065) --- whisper.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper.cpp b/whisper.cpp index b9e1ef2ced1..1a6d889af28 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -5272,7 +5272,7 @@ int whisper_full_with_state( // basically don't process anything that is less than 1.0s // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 if (seek_end < seek_start + (params.speed_up ? 50 : 100)) { - WHISPER_LOG_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10); + WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10); return 0; } From 858452d58dba3acdc3431c9bced2bb8cfd9bf418 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 24 Apr 2024 14:56:30 +0300 Subject: [PATCH 004/100] models : disable old script (#2079) --- models/download-coreml-model.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/models/download-coreml-model.sh b/models/download-coreml-model.sh index 83f2b238e94..405b355ea02 100755 --- a/models/download-coreml-model.sh +++ b/models/download-coreml-model.sh @@ -1,5 +1,8 @@ #!/bin/sh +printf "whisper.cpp: this script hasn't been maintained and is not functional atm\n" +exit 1 + # This script downloads Whisper model files that have already been converted to Core ML format. # This way you don't have to convert them yourself. From 22b6598cc9f1454567efa0d816fdc57637243999 Mon Sep 17 00:00:00 2001 From: goldwaving <77494627+goldwaving@users.noreply.github.com> Date: Sun, 28 Apr 2024 15:06:12 -0230 Subject: [PATCH 005/100] Remove unnecessary memory reallocation in fft (#2080) fft_out needs to be twice the frame_size, not the frame_step. It is resized in fft() anyway, but this change prevents an unnecessary reallocation. n_fft must match the mel filter size, so it is best not to calculate it from the framesize. We only need to get the magnitudes for half the spectrum since the other half is a mirror and not used in the mel filter loop later. --- whisper.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 1a6d889af28..f31309ed3b4 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2900,11 +2900,13 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector int n_samples, int frame_size, int frame_step, int n_threads, const whisper_filters & filters, whisper_mel & mel) { std::vector fft_in(frame_size, 0.0); - std::vector fft_out(2 * frame_step); - // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist - int n_fft = 1 + (frame_size / 2); + std::vector fft_out(2 * frame_size); + int n_fft = filters.n_fft; int i = ith; + // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist + assert( n_fft == 1 + (frame_size / 2) ); + // calculate FFT only when fft_in are not all zero for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { const int offset = i * frame_step; @@ -2923,7 +2925,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. - for (int j = 0; j < frame_size; j++) { + for (int j = 0; j < n_fft; j++) { fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); } From 8fac6455ffeb0a0950a84e790ddb74f7290d33c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Pawe=C5=82czyk?= Date: Sun, 28 Apr 2024 23:54:21 +0200 Subject: [PATCH 006/100] make : change GNU make default CXX from g++ to c++ (#2100) --- Makefile | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/Makefile b/Makefile index 3dd4a630857..b7e5a0e96f0 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,17 @@ ifndef NVCC_VERSION endif endif +# In GNU make default CXX is g++ instead of c++. Let's fix that so that users +# of non-gcc compilers don't have to provide g++ alias or wrapper. +DEFCC := cc +DEFCXX := c++ +ifeq ($(origin CC),default) +CC := $(DEFCC) +endif +ifeq ($(origin CXX),default) +CXX := $(DEFCXX) +endif + CCV := $(shell $(CC) --version | head -n 1) CXXV := $(shell $(CXX) --version | head -n 1) From 58210d6a7634ea1e42e0a2dab611f4a0518731dc Mon Sep 17 00:00:00 2001 From: Pedro Probst Date: Thu, 2 May 2024 18:52:55 -0300 Subject: [PATCH 007/100] examples : fix node compilation (#2115) * node : fix compilation and update examples * node : fix readme * Update addon.node test --- .github/workflows/examples.yml | 2 +- examples/addon.node/CMakeLists.txt | 2 +- examples/addon.node/README.md | 4 ++-- examples/addon.node/__test__/whisper.spec.js | 3 ++- examples/addon.node/index.js | 3 ++- examples/addon.node/package.json | 2 +- 6 files changed, 9 insertions(+), 7 deletions(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index ddaf5e9de5d..808dd18c0b7 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -37,7 +37,7 @@ jobs: run: npm install - name: Compile addon.node - run: npx cmake-js compile -T whisper-addon -B Release + run: npx cmake-js compile -T addon.node -B Release - name: Download test model run: | diff --git a/examples/addon.node/CMakeLists.txt b/examples/addon.node/CMakeLists.txt index aef7839eb77..29cb1a27d07 100644 --- a/examples/addon.node/CMakeLists.txt +++ b/examples/addon.node/CMakeLists.txt @@ -1,4 +1,4 @@ -set(TARGET whisper-addon) +set(TARGET addon.node) # Base settings #================================================================== diff --git a/examples/addon.node/README.md b/examples/addon.node/README.md index bdb1d256bec..16df7d95870 100644 --- a/examples/addon.node/README.md +++ b/examples/addon.node/README.md @@ -14,14 +14,14 @@ npm install Make sure it is in the project root directory and compiled with make-js. ```shell -npx cmake-js compile -T whisper-addon -B Release +npx cmake-js compile -T addon.node -B Release ``` For Electron addon and cmake-js options, you can see [cmake-js](https://github.com/cmake-js/cmake-js) and make very few configuration changes. > Such as appointing special cmake path: > ```shell -> npx cmake-js compile -c 'xxx/cmake' -T whisper-addon -B Release +> npx cmake-js compile -c 'xxx/cmake' -T addon.node -B Release > ``` ## Run diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index d102fe7624e..c0367a8c587 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -1,7 +1,7 @@ const path = require("path"); const { whisper } = require(path.join( __dirname, - "../../../build/Release/whisper-addon" + "../../../build/Release/addon.node" )); const { promisify } = require("util"); @@ -12,6 +12,7 @@ const whisperParamsMock = { model: path.join(__dirname, "../../../models/ggml-base.en.bin"), fname_inp: path.join(__dirname, "../../../samples/jfk.wav"), use_gpu: true, + no_timestamps: false, }; describe("Run whisper.node", () => { diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 3c6429375ab..9156a52de07 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -1,7 +1,7 @@ const path = require("path"); const { whisper } = require(path.join( __dirname, - "../../build/Release/whisper-addon" + "../../build/Release/addon.node" )); const { promisify } = require("util"); @@ -12,6 +12,7 @@ const whisperParams = { model: path.join(__dirname, "../../models/ggml-base.en.bin"), fname_inp: "../../samples/jfk.wav", use_gpu: true, + no_timestamps: false, }; const arguments = process.argv.slice(2); diff --git a/examples/addon.node/package.json b/examples/addon.node/package.json index bf51f0bba9f..50046bf1f56 100644 --- a/examples/addon.node/package.json +++ b/examples/addon.node/package.json @@ -1,5 +1,5 @@ { - "name": "whisper-addon", + "name": "addon.node", "version": "0.0.0", "description": "", "main": "index.js", From f7607560785700b6903b2e05085f9b80931bc5ab Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Wed, 8 May 2024 11:03:21 +0300 Subject: [PATCH 008/100] minor: add CMakeSettings.json to gitignore (#2094) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 43dcdabfb85..295cb74e625 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ .vscode/ .DS_Store .vimspector.json +/CMakeSettings.json build/ build-coreml/ From b6680fab503b3d469d3ffbff71545150896671d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Pawe=C5=82czyk?= Date: Wed, 8 May 2024 17:32:43 +0200 Subject: [PATCH 009/100] build : improve disabling AVX-512 (#2129) * cmake : make WHISPER_NO_AVX512=ON disable all subsets of AVX-512 Previously it happened only for MSVC, but it makes sense to have the same behavior for other compilers too. * make : reorder x86 ISA extensions in chronological order And update compiler flags at the end to ease modifying conditions. * make : support WHISPER_NO_AVX512=1 for disabling all AVX-512 subsets. That way you do not have to override each AVX-512 subset setting individually if it has been turned on during autodetection. --- CMakeLists.txt | 12 ++++----- Makefile | 71 ++++++++++++++++++++++++++++---------------------- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 71cade0795a..b34b3768336 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -514,12 +514,12 @@ else() endif() if(NOT WHISPER_NO_AVX512) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw") - endif() - if(NOT WHISPER_NO_AVX512_VBMI) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512vbmi") - endif() - if(NOT WHISPER_NO_AVX512_VNNI) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512vnni") + if(NOT WHISPER_NO_AVX512_VBMI) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512vbmi") + endif() + if(NOT WHISPER_NO_AVX512_VNNI) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512vnni") + endif() endif() if(NOT WHISPER_NO_FMA) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma") diff --git a/Makefile b/Makefile index b7e5a0e96f0..901fe216035 100644 --- a/Makefile +++ b/Makefile @@ -142,59 +142,68 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64)) CPUINFO_CMD := sysinfo -cpu endif + # x86 ISA extensions (chronological order) ifdef CPUINFO_CMD + SSE3_M := $(shell $(CPUINFO_CMD) | grep -iwE 'PNI|SSE3') + SSSE3_M := $(shell $(CPUINFO_CMD) | grep -iw 'SSSE3') AVX_M := $(shell $(CPUINFO_CMD) | grep -iwE 'AVX|AVX1.0') + F16C_M := $(shell $(CPUINFO_CMD) | grep -iw 'F16C') + FMA_M := $(shell $(CPUINFO_CMD) | grep -iw 'FMA') + AVX2_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX2') + AVX512F_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX512F') + AVX512VBMI_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX512VBMI') + AVX512VNNI_M := $(shell $(CPUINFO_CMD) | grep -iwE 'AVX512_VNNI|AVX512VNNI') + + # AVX-512 has many subsets, so let's make it easy to disable them all + ifneq ($(filter-out 0,$(WHISPER_NO_AVX512)),) + AVX512F_M := + AVX512VBMI_M := + AVX512VNNI_M := + endif + + ifneq (,$(SSE3_M)) + CFLAGS += -msse3 + CXXFLAGS += -msse3 + endif + + ifneq (,$(SSSE3_M)) + CFLAGS += -mssse3 + CXXFLAGS += -mssse3 + endif + ifneq (,$(AVX_M)) CFLAGS += -mavx CXXFLAGS += -mavx endif - AVX2_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX2') + ifneq (,$(F16C_M)) + CFLAGS += -mf16c + CXXFLAGS += -mf16c + endif + + ifneq (,$(FMA_M)) + CFLAGS += -mfma + CXXFLAGS += -mfma + endif + ifneq (,$(AVX2_M)) CFLAGS += -mavx2 CXXFLAGS += -mavx2 endif - AVX512F_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX512F') ifneq (,$(AVX512F_M)) CFLAGS += -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw CXXFLAGS += -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw endif - AVX512VNNI_M := $(shell $(CPUINFO_CMD) | grep -iwE 'AVX512_VNNI|AVX512VNNI') - ifneq (,$(AVX512VNNI_M)) - CFLAGS += -mavx512vnni - CXXFLAGS += -mavx512vnni - endif - - AVX512VBMI_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX512VBMI') ifneq (,$(AVX512VBMI_M)) CFLAGS += -mavx512vbmi CXXFLAGS += -mavx512vbmi endif - FMA_M := $(shell $(CPUINFO_CMD) | grep -iw 'FMA') - ifneq (,$(FMA_M)) - CFLAGS += -mfma - CXXFLAGS += -mfma - endif - - F16C_M := $(shell $(CPUINFO_CMD) | grep -iw 'F16C') - ifneq (,$(F16C_M)) - CFLAGS += -mf16c - CXXFLAGS += -mf16c - endif - - SSE3_M := $(shell $(CPUINFO_CMD) | grep -iwE 'PNI|SSE3') - ifneq (,$(SSE3_M)) - CFLAGS += -msse3 - CXXFLAGS += -msse3 - endif - - SSSE3_M := $(shell $(CPUINFO_CMD) | grep -iw 'SSSE3') - ifneq (,$(SSSE3_M)) - CFLAGS += -mssse3 - CXXFLAGS += -mssse3 + ifneq (,$(AVX512VNNI_M)) + CFLAGS += -mavx512vnni + CXXFLAGS += -mavx512vnni endif endif endif From 73d13ad19a8c9c4da4f405088a85169b1a171e66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Pawe=C5=82czyk?= Date: Wed, 8 May 2024 17:33:43 +0200 Subject: [PATCH 010/100] ggml : expose SSE3 and SSSE3 for MSVC when AVX is available (#2128) --- ggml-impl.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml-impl.h b/ggml-impl.h index e68b728775c..93a4f1a2b72 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -37,9 +37,16 @@ extern "C" { #ifndef __F16C__ #define __F16C__ #endif +#endif + +// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available +#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)) #ifndef __SSE3__ #define __SSE3__ #endif +#ifndef __SSSE3__ +#define __SSSE3__ +#endif #endif // 16-bit float From 8dcefdf4a9118276b85637a916562b6abc6a8e80 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Thu, 25 Apr 2024 17:24:07 +0300 Subject: [PATCH 011/100] build: fix and ignore msvc warnings (ggml/805) --- ggml-backend.c | 4 ++-- ggml-quants.c | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/ggml-backend.c b/ggml-backend.c index 402d86ef3ac..a55967a6f16 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -1178,9 +1178,9 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st static char * fmt_size(size_t size) { static char buffer[128]; if (size >= 1024*1024) { - sprintf(buffer, "%zuM", size/1024/1024); + snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024); } else { - sprintf(buffer, "%zuK", size/1024); + snprintf(buffer, sizeof(buffer), "%zuK", size/1024); } return buffer; } diff --git a/ggml-quants.c b/ggml-quants.c index 32e84434a8c..029511a6074 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -14,6 +14,12 @@ #include // for qsort #include // for GGML_ASSERT +#if defined(_MSC_VER) +// disable "possible loss of data" to avoid warnings for hundreds of casts +// we should just be careful :) +#pragma warning(disable: 4244 4267) +#endif + #ifdef __ARM_NEON // if YCM cannot find , make a symbolic link to it, for example: From 37e6757453d4157bf0588e1f65e31931d3849628 Mon Sep 17 00:00:00 2001 From: Justina Cho Date: Wed, 1 May 2024 14:44:26 -0700 Subject: [PATCH 012/100] feat: implemented sigmoid function (ggml/806) * added sigmoid function * implemented metal kernel for sigmoid * implemented cuda kernel for sigmoid * added sigmoid unary op and incremented count --- ggml-cuda.cu | 4 +++ ggml-cuda/unary.cu | 26 ++++++++++++++++ ggml-cuda/unary.cuh | 3 ++ ggml-metal.m | 15 ++++++++++ ggml-metal.metal | 7 +++++ ggml.c | 73 ++++++++++++++++++++++++++++++++++++++++++++- ggml.h | 9 ++++++ 7 files changed, 136 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index bff8ad9d96e..4a2bbdabfa7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2115,6 +2115,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_RELU: ggml_cuda_op_relu(ctx, dst); break; + case GGML_UNARY_OP_SIGMOID: + ggml_cuda_op_sigmoid(ctx, dst); + break; case GGML_UNARY_OP_HARDSIGMOID: ggml_cuda_op_hardsigmoid(ctx, dst); break; @@ -2355,6 +2358,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: diff --git a/ggml-cuda/unary.cu b/ggml-cuda/unary.cu index 1a7f0946972..ac03d5c6fce 100644 --- a/ggml-cuda/unary.cu +++ b/ggml-cuda/unary.cu @@ -48,6 +48,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) { dst[i] = fmaxf(x[i], 0); } +static __global__ void sigmoid_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = 1.0f / (1.0f + expf(-x[i])); +} + static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -108,6 +117,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ relu_f32<<>>(x, dst, k); } +static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE; + sigmoid_f32<<>>(x, dst, k); +} + static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE; hardsigmoid_f32<<>>(x, dst, k); @@ -188,6 +202,18 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} + void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/ggml-cuda/unary.cuh b/ggml-cuda/unary.cuh index 2002ed98920..a1d07c04fcd 100644 --- a/ggml-cuda/unary.cuh +++ b/ggml-cuda/unary.cuh @@ -4,6 +4,7 @@ #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 +#define CUDA_SIGMOID_BLOCK_SIZE 256 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256 #define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256 @@ -18,6 +19,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-metal.m b/ggml-metal.m index 419d8b9e568..86426e933e0 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -39,6 +39,7 @@ GGML_METAL_KERNEL_TYPE_SCALE_4, GGML_METAL_KERNEL_TYPE_TANH, GGML_METAL_KERNEL_TYPE_RELU, + GGML_METAL_KERNEL_TYPE_SIGMOID, GGML_METAL_KERNEL_TYPE_GELU, GGML_METAL_KERNEL_TYPE_GELU_QUICK, GGML_METAL_KERNEL_TYPE_SILU, @@ -470,6 +471,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); @@ -695,6 +697,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: @@ -1178,6 +1181,18 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SIGMOID: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_GELU: diff --git a/ggml-metal.metal b/ggml-metal.metal index 9a29f57a38c..7f840ab089b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -220,6 +220,13 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_sigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + kernel void kernel_tanh( device const float * src0, device float * dst, diff --git a/ggml.c b/ggml.c index 793b67f4c70..3256dda8a08 100644 --- a/ggml.c +++ b/ggml.c @@ -1763,6 +1763,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } +inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } // TODO: optimize performance inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } @@ -2136,6 +2137,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "TANH", "ELU", "RELU", + "SIGMOID", "GELU", "GELU_QUICK", "SILU", @@ -2143,7 +2145,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSIGMOID", }; -static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12"); +static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -4295,6 +4297,20 @@ struct ggml_tensor * ggml_relu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); } +// ggml_sigmoid + +struct ggml_tensor * ggml_sigmoid( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID); +} + +struct ggml_tensor * ggml_sigmoid_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID); +} + // ggml_leaky_relu struct ggml_tensor * ggml_leaky_relu( @@ -9838,6 +9854,52 @@ static void ggml_compute_forward_relu( } } +// ggml_compute_forward_sigmoid + +static void ggml_compute_forward_sigmoid_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sigmoid_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sigmoid( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sigmoid_f32(params, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_gelu static void ggml_compute_forward_gelu_f32( @@ -15485,6 +15547,10 @@ static void ggml_compute_forward_unary( { ggml_compute_forward_relu(params, dst); } break; + case GGML_UNARY_OP_SIGMOID: + { + ggml_compute_forward_sigmoid(params, dst); + } break; case GGML_UNARY_OP_GELU: { ggml_compute_forward_gelu(params, dst); @@ -17471,6 +17537,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_UNARY_OP_SIGMOID: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_UNARY_OP_GELU: { GGML_ASSERT(false); // TODO: not implemented @@ -18000,6 +18070,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads { diff --git a/ggml.h b/ggml.h index abe3767f224..fbc34f0c9d0 100644 --- a/ggml.h +++ b/ggml.h @@ -511,6 +511,7 @@ extern "C" { GGML_UNARY_OP_TANH, GGML_UNARY_OP_ELU, GGML_UNARY_OP_RELU, + GGML_UNARY_OP_SIGMOID, GGML_UNARY_OP_GELU, GGML_UNARY_OP_GELU_QUICK, GGML_UNARY_OP_SILU, @@ -1055,6 +1056,10 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sigmoid( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_leaky_relu( struct ggml_context * ctx, struct ggml_tensor * a, float negative_slope, bool inplace); @@ -1063,6 +1068,10 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sigmoid_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_gelu( struct ggml_context * ctx, struct ggml_tensor * a); From 60f3713026a76ea6e196bb187df9dcdfb63fc94e Mon Sep 17 00:00:00 2001 From: jiez <373447296@qq.com> Date: Fri, 12 Apr 2024 18:45:06 +0800 Subject: [PATCH 013/100] llama : add gguf_remove_key + remove split meta during quantize (llama/6591) * Remove split metadata when quantize model shards * Find metadata key by enum * Correct loop range for gguf_remove_key and code format * Free kv memory --------- Co-authored-by: z5269887 --- ggml.c | 65 ++++++++++++++++++++++++++++++++++++---------------------- ggml.h | 3 +++ 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/ggml.c b/ggml.c index 3256dda8a08..2c4b8ec4ff4 100644 --- a/ggml.c +++ b/ggml.c @@ -20621,6 +20621,32 @@ static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) { return ok; } +static void gguf_free_kv(struct gguf_kv * kv) { + if (kv->key.data) { + GGML_FREE(kv->key.data); + } + + if (kv->type == GGUF_TYPE_STRING) { + if (kv->value.str.data) { + GGML_FREE(kv->value.str.data); + } + } + + if (kv->type == GGUF_TYPE_ARRAY) { + if (kv->value.arr.data) { + if (kv->value.arr.type == GGUF_TYPE_STRING) { + for (uint64_t j = 0; j < kv->value.arr.n; ++j) { + struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j]; + if (str->data) { + GGML_FREE(str->data); + } + } + } + GGML_FREE(kv->value.arr.data); + } + } +} + struct gguf_context * gguf_init_empty(void) { struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); @@ -20970,31 +20996,7 @@ void gguf_free(struct gguf_context * ctx) { if (ctx->kv) { // free string memory - not great.. for (uint64_t i = 0; i < ctx->header.n_kv; ++i) { - struct gguf_kv * kv = &ctx->kv[i]; - - if (kv->key.data) { - GGML_FREE(kv->key.data); - } - - if (kv->type == GGUF_TYPE_STRING) { - if (kv->value.str.data) { - GGML_FREE(kv->value.str.data); - } - } - - if (kv->type == GGUF_TYPE_ARRAY) { - if (kv->value.arr.data) { - if (kv->value.arr.type == GGUF_TYPE_STRING) { - for (uint64_t j = 0; j < kv->value.arr.n; ++j) { - struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j]; - if (str->data) { - GGML_FREE(str->data); - } - } - } - GGML_FREE(kv->value.arr.data); - } - } + gguf_free_kv(&ctx->kv[i]); } GGML_FREE(ctx->kv); @@ -21219,6 +21221,19 @@ static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) { return n_kv; } +void gguf_remove_key(struct gguf_context * ctx, const char * key) { + const int idx = gguf_find_key(ctx, key); + if (idx >= 0) { + const int n_kv = gguf_get_n_kv(ctx); + gguf_free_kv(&ctx->kv[idx]); + for (int i = idx; i < n_kv-1; ++i) { + ctx->kv[i] = ctx->kv[i+1]; + } + ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv)); + ctx->header.n_kv--; + } +} + void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { const int idx = gguf_get_or_add_key(ctx, key); diff --git a/ggml.h b/ggml.h index fbc34f0c9d0..1a776ca83e4 100644 --- a/ggml.h +++ b/ggml.h @@ -2298,6 +2298,9 @@ extern "C" { GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i); GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int i); + // removes key if it exists + GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key); + // overrides existing values or adds a new one GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); From 00a0947c65e47a7aa9733c22d824aabf0bb18ec0 Mon Sep 17 00:00:00 2001 From: slaren Date: Fri, 12 Apr 2024 18:13:20 +0200 Subject: [PATCH 014/100] metal : unify mul_mv_id kernels (llama/6556) --- ggml-metal.m | 5 + ggml-metal.metal | 1189 ++++++---------------------------------------- ggml.c | 1 - 3 files changed, 140 insertions(+), 1055 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 86426e933e0..7f0f1f1f1ce 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1941,7 +1941,12 @@ static enum ggml_status ggml_metal_graph_compute( { nth0 = 4; nth1 = 16; + #if QK_K == 64 + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; + #else pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + #endif + } break; default: { diff --git a/ggml-metal.metal b/ggml-metal.metal index 7f840ab089b..79cce21ff4f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -864,15 +864,16 @@ void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; @@ -949,7 +950,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -975,7 +976,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1001,7 +1002,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1027,7 +1028,7 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } @@ -1046,6 +1047,7 @@ void kernel_mul_mv_q8_0_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1126,7 +1128,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } #define N_F32_F32 4 @@ -2716,6 +2718,7 @@ void kernel_mul_mv_q2_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2878,7 +2881,7 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } #if QK_K == 256 @@ -2895,6 +2898,7 @@ void kernel_mul_mv_q3_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3053,6 +3057,7 @@ void kernel_mul_mv_q3_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3142,7 +3147,7 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } #if QK_K == 256 @@ -3159,6 +3164,7 @@ void kernel_mul_mv_q4_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3272,6 +3278,7 @@ void kernel_mul_mv_q4_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3380,7 +3387,7 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q5_K_f32_impl( @@ -3396,6 +3403,7 @@ void kernel_mul_mv_q5_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3586,7 +3594,7 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q6_K_f32_impl( @@ -3602,6 +3610,7 @@ void kernel_mul_mv_q6_K_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3720,7 +3729,7 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit @@ -4403,6 +4412,7 @@ void kernel_mul_mv_iq1_s_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4492,6 +4502,7 @@ void kernel_mul_mv_iq1_m_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4600,11 +4611,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, - threadgroup float * shared_values [[threadgroup(0)]], + threadgroup int8_t * shared_values_i8 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; const int nb = ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -4694,11 +4706,11 @@ void kernel_mul_mv_iq4_xs_f32_impl( constant int64_t & ne1, constant uint & r2, constant uint & r3, - threadgroup float * shared_values [[threadgroup(0)]], + threadgroup int8_t * shared_values_i8 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -4801,7 +4813,7 @@ kernel void kernel_mul_mv_iq1_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] @@ -4829,7 +4841,7 @@ kernel void kernel_mul_mv_iq1_m_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] @@ -4853,7 +4865,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( constant int64_t & ne1, constant uint & r2, constant uint & r3, - threadgroup float * shared_values [[threadgroup(0)]], + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4882,7 +4894,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( constant int64_t & ne1, constant uint & r2, constant uint & r3, - threadgroup float * shared_values [[threadgroup(0)]], + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -6029,135 +6041,52 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel // matrix-vector multiplication // -[[host_name("kernel_mul_mv_id_f32_f32")]] -kernel void kernel_mul_mv_id_f32_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_f32_f32_impl( - src0, - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); -} - -[[host_name("kernel_mul_mv_id_f16_f32")]] -kernel void kernel_mul_mv_id_f16_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; +typedef void (kernel_mul_mv_impl_t)( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]); - kernel_mul_mv_f16_f32_impl( - src0, - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); -} +typedef void (kernel_mul_mv2_impl_t)( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]); -[[host_name("kernel_mul_mv_id_q8_0_f32")]] -kernel void kernel_mul_mv_id_q8_0_f32( - device const char * src0s, +template +void mmv_fn( + device const char * src0, device const char * src1, device float * dst, - device const char * ids, - constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -6176,43 +6105,19 @@ kernel void kernel_mul_mv_id_q8_0_f32( constant uint64_t & nb1, constant uint & r2, constant uint & r3, - constant int & idx, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_q8_0_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); + impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); } -[[host_name("kernel_mul_mv_id_q4_0_f32")]] -kernel void kernel_mul_mv_id_q4_0_f32( - device const char * src0s, +template +void mmv_fn( + device const char * src0, device const char * src1, device float * dst, - device const char * ids, - constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -6231,43 +6136,18 @@ kernel void kernel_mul_mv_id_q4_0_f32( constant uint64_t & nb1, constant uint & r2, constant uint & r3, - constant int & idx, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - mul_vec_q_n_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); + impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } -[[host_name("kernel_mul_mv_id_q4_1_f32")]] -kernel void kernel_mul_mv_id_q4_1_f32( - device const char * src0s, +typedef void (mul_mv_impl_fn_t)( + device const char * src0, device const char * src1, device float * dst, - device const char * ids, - constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -6286,38 +6166,14 @@ kernel void kernel_mul_mv_id_q4_1_f32( constant uint64_t & nb1, constant uint & r2, constant uint & r3, - constant int & idx, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - mul_vec_q_n_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} + uint sgitg[[simdgroup_index_in_threadgroup]]); -[[host_name("kernel_mul_mv_id_q5_0_f32")]] -kernel void kernel_mul_mv_id_q5_0_f32( +template +kernel void kernel_mul_mv_id( device const char * src0s, device const char * src1, device float * dst, @@ -6342,6 +6198,7 @@ kernel void kernel_mul_mv_id_q5_0_f32( constant uint & r2, constant uint & r3, constant int & idx, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], @@ -6353,26 +6210,36 @@ kernel void kernel_mul_mv_id_q5_0_f32( const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; device const char * src0 = src0s + id*nb02; - mul_vec_q_n_f32_impl( + impl_fn( src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, + src1 + bid*nb11, + dst + bid*ne0, ne00, ne01, ne02, + nb00, + nb01, + nb02, ne10, + ne11, ne12, + ne13, + nb10, + nb11, + nb12, ne0, ne1, + nb1, r2, r3, + shared_values, tgpig, + tiitg, tiisg, sgitg); } -[[host_name("kernel_mul_mv_id_q5_1_f32")]] -kernel void kernel_mul_mv_id_q5_1_f32( +typedef void (kernel_mul_mv_id_t)( device const char * src0s, device const char * src1, device float * dst, @@ -6397,819 +6264,33 @@ kernel void kernel_mul_mv_id_q5_1_f32( constant uint & r2, constant uint & r3, constant int & idx, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); + uint sgitg[[simdgroup_index_in_threadgroup]]); + +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +#if QK_K != 64 +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +#endif - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - mul_vec_q_n_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q2_K_f32")]] -kernel void kernel_mul_mv_id_q2_K_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_q2_K_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q3_K_f32")]] -kernel void kernel_mul_mv_id_q3_K_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_q3_K_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q4_K_f32")]] -kernel void kernel_mul_mv_id_q4_K_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_q4_K_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_K_f32")]] -kernel void kernel_mul_mv_id_q5_K_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_q5_K_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q6_K_f32")]] -kernel void kernel_mul_mv_id_q6_K_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_q6_K_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] -kernel void kernel_mul_mv_id_iq2_xxs_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_iq2_xxs_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - shared_values, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_iq2_xs_f32")]] -kernel void kernel_mul_mv_id_iq2_xs_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_iq2_xs_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - shared_values, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] -kernel void kernel_mul_mv_id_iq3_xxs_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_iq3_xxs_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - shared_values, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_iq3_s_f32")]] -kernel void kernel_mul_mv_id_iq3_s_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_iq3_s_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - shared_values, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_iq2_s_f32")]] -kernel void kernel_mul_mv_id_iq2_s_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_iq2_s_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - shared_values, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_iq1_s_f32")]] -kernel void kernel_mul_mv_id_iq1_s_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_iq1_s_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_iq1_m_f32")]] -kernel void kernel_mul_mv_id_iq1_m_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_iq1_m_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_iq4_nl_f32")]] -kernel void kernel_mul_mv_id_iq4_nl_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup float * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - - kernel_mul_mv_iq4_nl_f32_impl( - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - shared_values, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_iq4_xs_f32")]] -kernel void kernel_mul_mv_id_iq4_xs_f32( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup float * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; - -#if QK_K == 64 - kernel_mul_mv_iq4_nl_f32_impl( -#else - kernel_mul_mv_iq4_xs_f32_impl( -#endif - src0, - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - shared_values, - tgpig, - tiisg, - sgitg); -} diff --git a/ggml.c b/ggml.c index 2c4b8ec4ff4..ba06665a536 100644 --- a/ggml.c +++ b/ggml.c @@ -11074,7 +11074,6 @@ static void ggml_compute_forward_mul_mat_id( } // initialize matrix_row_counts - GGML_ASSERT(wdata == wdata_src1_end); memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); // group rows by src0 matrix From 66aaf03a7a06e595f43c95ef6e8cbfe846bfde09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 14 Apr 2024 00:21:55 +0200 Subject: [PATCH 015/100] CUDA: fix matrix multiplication logic for tests (llama/6667) --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4a2bbdabfa7..a3bbb920e68 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1946,7 +1946,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { From c1320c1f0c9de3dc7d04b1ccf9fb2a38da8c840f Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Sun, 14 Apr 2024 10:42:29 +0800 Subject: [PATCH 016/100] fix memcpy() crash, add missed cmd in guide, fix softmax (llama/6622) * disable mmap to fix memcpy crash, add missed cmd in guide, fix softmax * refactor to disable mmap for SYCL backend * fix compile error in other os * refactor the solution, use host buf to fix it, instead of disable mmap * keep to support mmap() * use host buff to reduce malloc times * revert to malloc/free solution, for threaad safe --- ggml-sycl.cpp | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 55a1eedb553..86091cfbfbd 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)( #define SYCL_SCALE_BLOCK_SIZE 256 #define SYCL_CLAMP_BLOCK_SIZE 256 #define SYCL_ROPE_BLOCK_SIZE 256 -#define SYCL_SOFT_MAX_BLOCK_SIZE 1024 #define SYCL_ALIBI_BLOCK_SIZE 32 #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32 #define SYCL_QUANTIZE_BLOCK_SIZE 256 @@ -13080,11 +13079,13 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float * const int nrows_y, const float scale, const float max_bias, dpct::queue_ptr stream) { int nth = WARP_SIZE; - while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2; + int max_block_size = g_work_group_size; + while (nth < ncols_x && nth < max_block_size) nth *= 2; + if (nth>max_block_size) nth = max_block_size; + const sycl::range<3> block_dims(1, 1, nth); const sycl::range<3> block_nums(1, 1, nrows_x); const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE); - static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); const uint32_t n_head_kv = nrows_x/nrows_y; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); @@ -13094,6 +13095,12 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float * const size_t local_mem_size = stream->get_device().get_info(); if (n_local_scratch*sizeof(float) < local_mem_size) { + if (ncols_x > max_block_size) { + soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + max_bias, m0, m1, n_head_log2, block_nums, + block_dims, n_local_scratch, stream); + return; + } switch (ncols_x) { case 32: soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, @@ -16814,11 +16821,13 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, const dpct::queue_ptr stream = g_syclStreams[ctx->device][0]; SYCL_CHECK( CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); - + char* host_buf = (char*)malloc(size); + memcpy(host_buf, data, size); SYCL_CHECK( CHECK_TRY_ERROR((*stream) - .memcpy((char *)tensor->data + offset, data, size) + .memcpy((char *)tensor->data + offset, host_buf, size) .wait())); + free(host_buf); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ From 9d6d50d93309cfa9b23c8ead059bde3c1665d941 Mon Sep 17 00:00:00 2001 From: Dave Date: Sun, 14 Apr 2024 07:14:19 -0400 Subject: [PATCH 017/100] Added support for GGML_OP_CLAMP in Metal (llama/6662) * Added support for GGML_OP_CLAMP in Metal * Corrected size --------- Co-authored-by: dave-fl --- ggml-metal.m | 22 ++++++++++++++++++++++ ggml-metal.metal | 9 +++++++++ 2 files changed, 31 insertions(+) diff --git a/ggml-metal.m b/ggml-metal.m index 7f0f1f1f1ce..b43dfc3931d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -37,6 +37,7 @@ GGML_METAL_KERNEL_TYPE_DIV_ROW, GGML_METAL_KERNEL_TYPE_SCALE, GGML_METAL_KERNEL_TYPE_SCALE_4, + GGML_METAL_KERNEL_TYPE_CLAMP, GGML_METAL_KERNEL_TYPE_TANH, GGML_METAL_KERNEL_TYPE_RELU, GGML_METAL_KERNEL_TYPE_SIGMOID, @@ -469,6 +470,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); @@ -716,6 +718,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_SCALE: + case GGML_OP_CLAMP: case GGML_OP_SQR: case GGML_OP_SUM_ROWS: return true; @@ -1157,6 +1160,25 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_CLAMP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { case GGML_UNARY_OP_TANH: diff --git a/ggml-metal.metal b/ggml-metal.metal index 79cce21ff4f..1d05087f4c1 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -213,6 +213,15 @@ kernel void kernel_scale_4( dst[tpig] = src0[tpig] * scale; } +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); +} + kernel void kernel_relu( device const float * src0, device float * dst, From 98c0b77e0cecaa34186e758c025f7b37934d786c Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Mon, 15 Apr 2024 17:12:26 +0800 Subject: [PATCH 018/100] fix mul_mat_id() for new input, make the ut pass (llama/6682) --- ggml-sycl.cpp | 96 +++++++++++++++++++++++++++------------------------ 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 86091cfbfbd..f5bb7da8698 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -15996,73 +15996,76 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) { static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { -#if 0 - ggml_sycl_mul_mat_id_sycl(dst); - // TODO: mmq/mmv support -#endif + GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT && + "mul_mat_id does not support split buffers"); + const ggml_tensor *ids = dst->src[2]; + const dpct::queue_ptr stream = g_syclStreams[g_main_device][0]; - const int64_t nb11 = src1->nb[1]; - const int64_t nb1 = dst->nb[1]; + const size_t nb11 = src1->nb[1]; + const size_t nb1 = dst->nb[1]; - const struct ggml_tensor * ids = src0; - const int32_t id = ((int32_t *) dst->op_params)[0]; - const int32_t n_as = ((int32_t *) dst->op_params)[1]; + const int32_t id = ((int32_t *)dst->op_params)[0]; + const int32_t n_as = src0->ne[2]; std::vector ids_host(ggml_nbytes(ids)); + const char *ids_dev = (const char *)ids->data; - const dpct::queue_ptr stream = g_syclStreams[g_main_device][0]; - - if (ids->backend == GGML_BACKEND_TYPE_GPU) { - const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device]; - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)).wait())); - // SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); - } else { - memcpy(ids_host.data(), ids->data, ggml_nbytes(ids)); - } + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); + SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); - const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra; - const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra; + const ggml_tensor_extra_gpu *src0_extra = + (const ggml_tensor_extra_gpu *)src0->extra; + const ggml_tensor_extra_gpu *src1_extra = + (const ggml_tensor_extra_gpu *)src1->extra; + const ggml_tensor_extra_gpu *dst_extra = + (const ggml_tensor_extra_gpu *)dst->extra; + ggml_tensor_extra_gpu src0_row_extra; ggml_tensor_extra_gpu src1_row_extra; ggml_tensor_extra_gpu dst_row_extra; + ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; ggml_tensor dst_row = *dst; src1_row.backend = GGML_BACKEND_TYPE_GPU; dst_row.backend = GGML_BACKEND_TYPE_GPU; + src0_row.extra = &src0_row_extra; src1_row.extra = &src1_row_extra; dst_row.extra = &dst_row_extra; - char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ? - (char *) src1->data : (char *) src1_extra->data_device[g_main_device]; - char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ? - (char *) dst->data : (char *) dst_extra->data_device[g_main_device]; + char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU + ? (char *)src0->data + : (char *)src0_extra->data_device[g_main_device]; + char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU + ? (char *)src1->data + : (char *)src1_extra->data_device[g_main_device]; + char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU + ? (char *)dst->data + : (char *)dst_extra->data_device[g_main_device]; - if (src1->ne[1] == 1) { - GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU); - GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU); + src0_row.ne[2] = 1; + src0_row.ne[3] = 1; + src0_row.nb[3] = src0->nb[2]; + if (src1->ne[1] == 1) { for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - //int32_t row_id; - //SYCL_CHECK(syclMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), syclMemcpyDeviceToHost, g_syclStreams[g_main_device][0])); - //SYCL_CHECK(syclStreamSynchronize(g_syclStreams[g_main_device][0])); - - const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + const int32_t row_id = + *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] + + id * ids->nb[0]); GGML_ASSERT(row_id >= 0 && row_id < n_as); - const struct ggml_tensor * src0_row = dst->src[row_id + 2]; + src0_row_extra.data_device[g_main_device] = + src0_original + row_id * src0->nb[2]; + src1_row_extra.data_device[g_main_device] = + src1_original + i01 * src1->nb[1]; + dst_row_extra.data_device[g_main_device] = + dst_original + i01 * dst->nb[1]; - src1_row_extra.data_device[g_main_device] = src1_original + i01*src1->nb[1]; - src1_row.data = (char *) src1->data + i01*src1->nb[1]; // TODO why is this set? - - dst_row_extra.data_device[g_main_device] = dst_original + i01*dst->nb[1]; - dst_row.data = (char *) dst->data + i01*dst->nb[1]; // TODO why is this set? - - ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row); + ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row); } } else { sycl_pool_alloc src1_contiguous(sizeof(float)*ggml_nelements(src1)); @@ -16072,8 +16075,6 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, dst_row_extra.data_device[g_main_device] = dst_contiguous.get(); for (int32_t row_id = 0; row_id < n_as; ++row_id) { - const struct ggml_tensor * src0_row = dst->src[row_id + 2]; - int64_t num_src1_rows = 0; for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); @@ -16086,7 +16087,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, SYCL_CHECK(CHECK_TRY_ERROR( stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11, - src1_original + i01 * nb11, nb11).wait())); + src1_original + i01 * nb11, nb11))); num_src1_rows++; } @@ -16094,6 +16095,9 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, continue; } + src0_row_extra.data_device[g_main_device] = + src0_original + row_id * src0->nb[2]; + src1_row.ne[1] = num_src1_rows; dst_row.ne[1] = num_src1_rows; @@ -16105,7 +16109,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; - ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row); + ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row); num_src1_rows = 0; for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { @@ -16119,7 +16123,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy( dst_original + i01 * nb1, - dst_contiguous.get() + num_src1_rows * nb1, nb1).wait())); + dst_contiguous.get() + num_src1_rows * nb1, nb1))); num_src1_rows++; } } From fdb2c8735066a788aadf8ab1f32d21d0812cd7c7 Mon Sep 17 00:00:00 2001 From: Shijie <821898965@qq.com> Date: Tue, 16 Apr 2024 23:40:48 +0800 Subject: [PATCH 019/100] llama : add qwen2moe (llama/6074) * support qwen2moe * fix-review * metal : support unary ops for nelements % 4 != 0 * metal : require contiguousness for float4 unary kernels * metal : require contiguousness for float4 unary kernels (cont) * fix-review * names : for brevity "SHARED_EXP" -> "SHEXP" * llama : reuse build_moe_ffn() * llama : add model type name --------- Co-authored-by: Georgi Gerganov --- ggml-metal.m | 57 +++++++++++++++++++++++++++++++++++------------- ggml-metal.metal | 26 ++++++++++++++++++++++ 2 files changed, 68 insertions(+), 15 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index b43dfc3931d..0ec47febbd2 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -42,8 +42,11 @@ GGML_METAL_KERNEL_TYPE_RELU, GGML_METAL_KERNEL_TYPE_SIGMOID, GGML_METAL_KERNEL_TYPE_GELU, + GGML_METAL_KERNEL_TYPE_GELU_4, GGML_METAL_KERNEL_TYPE_GELU_QUICK, + GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, + GGML_METAL_KERNEL_TYPE_SILU_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX, GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, @@ -475,8 +478,11 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); @@ -1181,6 +1187,9 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + case GGML_UNARY_OP_TANH: { id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; @@ -1219,42 +1228,60 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_GELU_QUICK: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_SILU: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; default: { diff --git a/ggml-metal.metal b/ggml-metal.metal index 1d05087f4c1..d7ae37206f4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -249,6 +249,15 @@ constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -262,6 +271,15 @@ kernel void kernel_gelu( } kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -271,6 +289,14 @@ kernel void kernel_gelu_quick( } kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + ++kernel void kernel_silu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { From 7a4f7d825e7121118300e98a13f38d96047a1460 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Tue, 16 Apr 2024 14:55:30 -0400 Subject: [PATCH 020/100] ggml : add llamafile sgemm (llama/6414) This change upstreams llamafile's cpu matrix multiplication kernels which improve image and prompt evaluation speed. For starters, Q4_0 and Q8_0 weights should go ~40% faster on CPU. The biggest benefits are with data types like f16 / f32, which process prompts 2x faster thus making them faster than quantized data types for prompt evals. This change also introduces bona fide AVX512 support since tinyBLAS is able to exploit the larger register file. For example, on my CPU llama.cpp llava-cli processes an image prompt at 305 tokens/second, using the Q4_K and Q4_0 types, which has always been faster than if we used f16 LLaVA weights, which at HEAD go 188 tokens/second. With this change, f16 LLaVA performance leap frogs to 464 tokens/second. On Intel Core i9-14900K this change improves F16 prompt perf by 5x. For example, using llama.cpp at HEAD with Mistral 7b f16 to process a 215 token prompt will go 13 tok/sec. This change has fixes making it go 52 tok/sec. It's mostly thanks to my vectorized outer product kernels but also because I added support for correctly counting the number of cores on Alderlake, so the default thread count discounts Intel's new efficiency cores. Only Linux right now can count cores. This work was sponsored by Mozilla who's given permission to change the license of this code from Apache 2.0 to MIT. To read more about what's improved, and how it works, see: https://justine.lol/matmul/ --- ggml-impl.h | 2 +- ggml-quants.c | 2 +- ggml.c | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/ggml-impl.h b/ggml-impl.h index 93a4f1a2b72..43eb631e4c1 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -95,7 +95,7 @@ typedef uint16_t ggml_fp16_internal_t; #if defined(_MSC_VER) || defined(__MINGW32__) #include #else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) #if !defined(__riscv) #include #endif diff --git a/ggml-quants.c b/ggml-quants.c index 029511a6074..4be9575e0c1 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -138,7 +138,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { } static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || defined(__AVX512VNNI__) +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); return _mm256_cvtepi32_ps(summed_pairs); diff --git a/ggml.c b/ggml.c index ba06665a536..c5280e718cf 100644 --- a/ggml.c +++ b/ggml.c @@ -4,6 +4,7 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "ggml.h" +#include "sgemm.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -32,6 +33,14 @@ #include #endif +#ifndef GGML_USE_LLAMAFILE +#ifdef __ARM_FEATURE_MATMUL_INT8 +#define GGML_USE_LLAMAFILE 0 +#else +#define GGML_USE_LLAMAFILE 1 +#endif +#endif + #if defined(_MSC_VER) // disable "possible loss of data" to avoid hundreds of casts // we should just be careful :) @@ -10872,6 +10881,28 @@ static void ggml_compute_forward_mul_mat( } #endif +#if GGML_USE_LLAMAFILE + if (nb10 == ggml_type_size(src1->type)) { + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), + (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, + nb01/ggml_type_size(src0->type), + (const char *)src1->data + i12*nb12 + i13*nb13, + nb11/ggml_type_size(src1->type), + (char *)dst->data + i12*nb2 + i13*nb3, + nb1/ggml_type_size(dst->type), + ith, nth, + params->type, + src0->type, + src1->type, + dst->type)) + goto UseGgmlGemm1; + return; + } +UseGgmlGemm1:; +#endif + if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; @@ -10903,6 +10934,29 @@ static void ggml_compute_forward_mul_mat( const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); +#if GGML_USE_LLAMAFILE + if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) { + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), + (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, + nb01/ggml_type_size(src0->type), + (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 + + nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13), + row_size/ggml_type_size(vec_dot_type), + (char *)dst->data + i12*nb2 + i13*nb3, + nb1/ggml_type_size(dst->type), + ith, nth, + params->type, + src0->type, + vec_dot_type, + dst->type)) + goto UseGgmlGemm2; + return; + } +UseGgmlGemm2:; +#endif + const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = ne1*ne12*ne13; // src1 rows From c97796aa0f7534e45e58aa30f5ab569c56b91db3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 16 Apr 2024 23:50:22 +0300 Subject: [PATCH 021/100] ggml : fix llamafile sgemm wdata offsets (llama/6710) ggml-ci --- ggml.c | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/ggml.c b/ggml.c index c5280e718cf..707a1fe4140 100644 --- a/ggml.c +++ b/ggml.c @@ -33,12 +33,8 @@ #include #endif -#ifndef GGML_USE_LLAMAFILE #ifdef __ARM_FEATURE_MATMUL_INT8 -#define GGML_USE_LLAMAFILE 0 -#else -#define GGML_USE_LLAMAFILE 1 -#endif +#undef GGML_USE_LLAMAFILE #endif #if defined(_MSC_VER) @@ -10941,8 +10937,9 @@ UseGgmlGemm1:; if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 + - nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13), + (const char *)wdata + ggml_row_size(vec_dot_type, + nb12/ggml_type_size(src1->type)*i12 + + nb13/ggml_type_size(src1->type)*i13), row_size/ggml_type_size(vec_dot_type), (char *)dst->data + i12*nb2 + i13*nb3, nb1/ggml_type_size(dst->type), From c96b0a938ed330eed197ae930ea21f8b19cd6191 Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 18 Apr 2024 15:18:48 +0200 Subject: [PATCH 022/100] ggml : group all experts in a single ggml_mul_mat_id (llama/6505) * ggml : group all experts in a single ggml_mul_mat_id cuda : improve mmid row copy * cuda : fix bin bcast with non-cont src0 * test-backend-ops : only run all mul mat tests for base types * llama : disable moe offloading with SYCL --------- Co-authored-by: Georgi Gerganov --- ggml-cuda.cu | 179 ++++++--- ggml-cuda/binbcast.cu | 92 +++-- ggml-cuda/convert.cu | 2 + ggml-metal.m | 129 +++---- ggml-metal.metal | 878 +++++++++++++++++++----------------------- ggml-sycl.cpp | 2 +- ggml.c | 123 +++--- ggml.h | 6 +- 8 files changed, 730 insertions(+), 681 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a3bbb920e68..07534370c34 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas( if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 - ggml_cuda_pool_alloc src0_as_f16(ctx.pool()); + ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); GGML_ASSERT(to_fp16_cuda != nullptr); @@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas( } const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get(); - ggml_cuda_pool_alloc src1_as_f16(ctx.pool()); + ggml_cuda_pool_alloc src1_as_f16(ctx.pool(id)); if (src1->type != GGML_TYPE_F16) { const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); GGML_ASSERT(to_fp16_cuda != nullptr); @@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas( to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool(), row_diff*src1_ncols); + ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); const half alpha_f16 = 1.0f; const half beta_f16 = 0.0f; @@ -1960,20 +1960,73 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } } +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous, + int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping, + const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + int64_t ne11, int64_t ne10, + size_t nb11, size_t nb12) { + int32_t iid1 = blockIdx.x; + int32_t id = blockIdx.y; + + const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + + if (row_id_i != i02) { + return; + } + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + __shared__ int src1_row; + if (threadIdx.x == 0) { + src1_row = atomicAdd(cur_src1_row, 1); + row_mapping[src1_row] = {id, iid1}; + } + __syncthreads(); + + const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); + float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); + + for (int i = threadIdx.x; i < ne10; i += blockDim.x) { + src1_row_contiguous[i] = src1_row_original[i]; + } +} + +static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous, + const mmid_row_mapping * __restrict__ row_mapping, + int64_t ne0, + size_t nb1, size_t nb2) { + int32_t i = blockIdx.x; + + const int32_t i1 = row_mapping[i].i1; + const int32_t i2 = row_mapping[i].i2; + + const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); + float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); + + for (int j = threadIdx.x; j < ne0; j += blockDim.x) { + dst_row_original[j] = dst_row_contiguous[j]; + } +} + static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); cudaStream_t stream = ctx.stream(); - const size_t nb11 = src1->nb[1]; - const size_t nb1 = dst->nb[1]; - - const int32_t id = ((int32_t *) dst->op_params)[0]; - const int32_t n_as = src0->ne[2]; + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; @@ -1982,7 +2035,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + ggml_tensor dst_row = *dst; char * src0_original = (char *) src0->data; char * src1_original = (char *) src1->data; @@ -1990,19 +2043,39 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * src0_row.ne[2] = 1; src0_row.ne[3] = 1; - src0_row.nb[3] = src0->nb[2]; + src0_row.nb[3] = nb02; - if (src1->ne[1] == 1) { - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; - GGML_ASSERT(row_id >= 0 && row_id < n_as); + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; - src0_row.data = src0_original + row_id*src0->nb[2]; - src1_row.data = src1_original + i01*src1->nb[1]; - dst_row.data = dst_original + i01*dst->nb[1]; + if (ne12 == 1) { + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = id; + const int64_t i2 = i12; + + src0_row.data = src0_original + i02*nb02; + src1_row.data = src1_original + i11*nb11 + i12*nb12; + dst_row.data = dst_original + i1*nb1 + i2*nb2; + + ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + } } } else { ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); @@ -2011,54 +2084,69 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); - for (int32_t row_id = 0; row_id < n_as; ++row_id) { + for (int64_t i02 = 0; i02 < n_as; i02++) { int64_t num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); - if (row_id_i != row_id) { - continue; - } + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(row_id >= 0 && row_id < n_as); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); - CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11, - nb11, cudaMemcpyDeviceToDevice, stream)); - num_src1_rows++; + if (row_id_i != i02) { + continue; + } + + num_src1_rows++; + } } if (num_src1_rows == 0) { continue; } - src0_row.data = src0_original + row_id*src0->nb[2]; + ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); + CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); - src1_row.ne[1] = num_src1_rows; - dst_row.ne[1] = num_src1_rows; + { + dim3 block_dims(std::min((unsigned int)ne10, 768u)); + dim3 grid_dims(ids->ne[1], n_ids); + k_copy_src1_to_contiguous<<>>( + src1_original, src1_contiguous.get(), + dev_cur_src1_row.get(), dev_row_mapping.get(), + ids_dev, i02, ids->nb[1], ids->nb[0], + ne11, ne10, + nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + } + + src0_row.data = src0_original + i02*nb02; + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); + + src1_row.ne[1] = num_src1_rows; src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); - num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); - - if (row_id_i != row_id) { - continue; - } - - GGML_ASSERT(row_id >= 0 && row_id < n_as); - - CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1, - nb1, cudaMemcpyDeviceToDevice, stream)); - num_src1_rows++; + { + dim3 block_dims(std::min((unsigned int)ne0, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + dst_original, dst_contiguous.get(), + dev_row_mapping.get(), + ne0, + nb1, nb2); + CUDA_CHECK(cudaGetLastError()); } } } @@ -2491,7 +2579,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) { const int min_batch_size = 32; - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); GGML_UNUSED(backend); } diff --git a/ggml-cuda/binbcast.cu b/ggml-cuda/binbcast.cu index 959eaed95c1..19b08b74fb0 100644 --- a/ggml-cuda/binbcast.cu +++ b/ggml-cuda/binbcast.cu @@ -22,6 +22,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, /*int s10,*/ int s11, int s12, int s13) { const int i0s = blockDim.x*blockIdx.x + threadIdx.x; const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); @@ -36,9 +37,9 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst const int i12 = i2 % ne12; const int i13 = i3 % ne13; - const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i_src0; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; const src0_t * src0_row = src0 + i_src0; const src1_t * src1_row = src1 + i_src1; @@ -55,6 +56,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, /*int s10,*/ int s11, int s12, int s13) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -72,9 +74,9 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s const int i12 = i2 % ne12; const int i13 = i3 % ne13; - const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i_src0; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; const src0_t * src0_row = src0 + i_src0; const src1_t * src1_row = src1 + i_src1; @@ -101,10 +103,14 @@ struct bin_bcast_cuda { int nr[4] = { nr0, nr1, nr2, nr3 }; // collapse dimensions until first broadcast dimension - int64_t cne0[] = {ne0, ne1, ne2, ne3}; + int64_t cne[] = {ne0, ne1, ne2, ne3}; + int64_t cne0[] = {ne00, ne01, ne02, ne03}; int64_t cne1[] = {ne10, ne11, ne12, ne13}; - size_t cnb0[] = {nb0, nb1, nb2, nb3}; + + size_t cnb[] = {nb0, nb1, nb2, nb3}; + size_t cnb0[] = {nb00, nb01, nb02, nb03}; size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { cne[0] *= cne[1]; cne[1] = cne[2]; @@ -118,32 +124,47 @@ struct bin_bcast_cuda { cnb[3] *= cne[3]; }; - for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne0); - collapse(cne1); + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } } } + { - int64_t ne0 = cne0[0]; - int64_t ne1 = cne0[1]; - int64_t ne2 = cne0[2]; - int64_t ne3 = cne0[3]; + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; + + //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00); + //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01); + //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); + //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); int64_t ne10 = cne1[0]; int64_t ne11 = cne1[1]; int64_t ne12 = cne1[2]; int64_t ne13 = cne1[3]; - size_t nb0 = cnb0[0]; - size_t nb1 = cnb0[1]; - size_t nb2 = cnb0[2]; - size_t nb3 = cnb0[3]; + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; + + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; size_t nb10 = cnb1[0]; size_t nb11 = cnb1[1]; @@ -160,7 +181,28 @@ struct bin_bcast_cuda { size_t s12 = nb12 / sizeof(src1_t); size_t s13 = nb13 / sizeof(src1_t); + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s00 == 1); GGML_ASSERT(s10 == 1); const int block_size = 128; @@ -179,13 +221,14 @@ struct bin_bcast_cuda { ); if (block_nums.z > 65535) { - // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; k_bin_bcast_unravel<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, /* s0, */ s1, s2, s3, + /* s00, */ s01, s02, s03, /* s10, */ s11, s12, s13); } else { k_bin_bcast<<>>( @@ -193,6 +236,7 @@ struct bin_bcast_cuda { ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, /* s0, */ s1, s2, s3, + /* s00, */ s01, s02, s03, /* s10, */ s11, s12, s13); } } diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index ed4fa274897..b15e3578267 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h vals[ix] = x0[ix]; } + __syncthreads(); + #pragma unroll for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) { if (need_check && i0 + iy + 2*threadIdx.x >= k) { diff --git a/ggml-metal.m b/ggml-metal.m index 0ec47febbd2..fdba0de85bc 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1747,15 +1747,10 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_MUL_MAT_ID: { - //GGML_ASSERT(ne00 == ne10); - //GGML_ASSERT(ne03 == ne13); const int n_as = src0->ne[2]; - // max size of the src1ids array in the kernel shared buffer - GGML_ASSERT(ne11 <= 4096); - // src2 = ids - const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20); + const int64_t ne20 = src2->ne[0]; const int64_t ne21 = src2->ne[1]; const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22); const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23); @@ -1776,15 +1771,13 @@ static enum ggml_status ggml_metal_graph_compute( // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel - int ne11_mm_min = n_as; - - const int idx = ((int32_t *) dst->op_params)[0]; + // ne20 = n_used_experts + // ne21 = n_rows + const int dst_rows = ne20*ne21; + const int dst_rows_min = n_as; - // batch size - GGML_ASSERT(ne21 == ne11); // ? - GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting - const uint r2 = 1; - const uint r3 = 1; + // max size of the rowids array in the kernel shared buffer + GGML_ASSERT(dst_rows <= 2048); // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel @@ -1794,7 +1787,7 @@ static enum ggml_status ggml_metal_graph_compute( // !!! if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && ne00 % 32 == 0 && ne00 >= 64 && - ne11 > ne11_mm_min) { + dst_rows > dst_rows_min) { // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) @@ -1836,26 +1829,26 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:19]; - - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + + [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; @@ -2008,72 +2001,72 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ne00 >= nth0*nth1); } - const int64_t _ne1 = 1; // kernels needs a reference in constant memory - [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:21]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:22]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:23]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + + const int64_t _ne1 = 1; + const int tgz = dst_rows; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { const int mem_size = 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif } else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index d7ae37206f4..7f37c17d668 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -899,16 +899,16 @@ void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; @@ -1073,19 +1073,19 @@ void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -1172,24 +1172,24 @@ void kernel_mul_mv_f32_f32_impl( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F32_F32; @@ -1442,24 +1442,24 @@ void kernel_mul_mv_f16_f32_impl( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_F16_F32; @@ -2744,19 +2744,19 @@ void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -2924,19 +2924,19 @@ void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -3190,19 +3190,19 @@ void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -3429,19 +3429,19 @@ void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; @@ -3636,19 +3636,19 @@ void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -3773,19 +3773,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -3902,19 +3902,19 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4041,19 +4041,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4173,19 +4173,19 @@ void kernel_mul_mv_iq3_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4305,19 +4305,19 @@ void kernel_mul_mv_iq2_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4438,19 +4438,19 @@ void kernel_mul_mv_iq1_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4528,19 +4528,19 @@ void kernel_mul_mv_iq1_m_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -4637,19 +4637,19 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values_i8 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { threadgroup float * shared_values = (threadgroup float *)shared_values_i8; const int nb = ne00/QK4_NL; @@ -4732,19 +4732,20 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values_i8 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; const int nb = ne00/QK_K; const int r0 = tgpig.x; @@ -5686,25 +5687,25 @@ void kernel_mul_mm_impl(device const uchar * src0, } } -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids template void kernel_mul_mm_id_impl( device const uchar * src0, device const uchar * src1, - threadgroup short * src1ids, + threadgroup ushort2 * rowids, device float * dst, constant int64_t & ne00, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant int64_t & ne0, int64_t ne1, - constant uint & r2, - constant uint & r3, + int64_t ne0ne1, threadgroup uchar * shared_memory, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -5715,7 +5716,6 @@ void kernel_mul_mm_id_impl( const uint r0 = tgpig.y; const uint r1 = tgpig.x; - const uint im = tgpig.z; if (r1 * BLOCK_SIZE_N >= ne1) return; @@ -5733,19 +5733,16 @@ void kernel_mul_mm_id_impl( for (int i = 0; i < 8; i++){ c_res[i] = make_filled_simdgroup_matrix(0.f); } - short il = (tiitg % THREAD_PER_ROW); - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); ushort offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] + + nb12 * id[1] + + nb11 * (id[0] % ne11) + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { @@ -5774,11 +5771,11 @@ void kernel_mul_mm_id_impl( for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; @@ -5800,11 +5797,13 @@ void kernel_mul_mm_id_impl( threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0; + device float * C = dst + (BLOCK_SIZE_M * r0); if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; + int joff = jid[0] * ne0 + jid[1] * ne0ne1; + for (int i = 0; i < n_rows; i++) { + *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); } } } @@ -5859,11 +5858,14 @@ kernel void kernel_mul_mm_id( device const uchar * src1, device float * dst, device const uchar * ids, + constant int64_t & nei0, + constant int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, constant uint64_t & nb10, @@ -5872,47 +5874,52 @@ kernel void kernel_mul_mm_id( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - // expert id - const int32_t id = tgpig.z/(ne12*ne13); - device const uchar * src0 = src0s + id*nb02; + const int32_t i02 = tgpig.z; + tgpig.z = 0; - tgpig.z = tgpig.z%(ne12*ne13); + device const uchar * src0 = src0s + i02*nb02; - // row indices of src1 for expert id - threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192); + // row indices + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + // TODO: parallelize this loop int64_t _ne1 = 0; - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { - src1ids[_ne1++] = i1; + for (ushort ii1 = 0; ii1 < nei1; ii1++) { + for (ushort ii0 = 0; ii0 < nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + if (id == i02) { + //if (tiitg == 0) { + rowids[_ne1] = ushort2(ii0, ii1); + //} + _ne1++; + } } } + threadgroup_barrier(mem_flags::mem_threadgroup); + kernel_mul_mm_id_impl( src0, src1, - src1ids, + rowids, dst, ne00, ne02, nb01, nb02, + ne11, ne12, nb10, nb11, nb12, ne0, _ne1, - r2, - r3, + ne0*ne1, shared_memory, tgpig, tiitg, @@ -5973,24 +5980,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_r // matrix-matrix multiplication // -typedef void (mat_mm_t)( - device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar *, - uint3, uint, uint); +typedef decltype(kernel_mul_mm) mat_mm_t; template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; @@ -6022,29 +6012,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m // indirect matrix-matrix multiplication // -typedef void (mat_mm_id_t)( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup uchar *, - uint3, uint, uint); +typedef decltype(kernel_mul_mm_id) mat_mm_id_t; template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; @@ -6080,71 +6048,71 @@ typedef void (kernel_mul_mv_impl_t)( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]); + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg); typedef void (kernel_mul_mv2_impl_t)( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]); + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg); template void mmv_fn( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); } @@ -6153,59 +6121,33 @@ void mmv_fn( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } -typedef void (mul_mv_impl_fn_t)( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]); +typedef decltype(mmv_fn) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( @@ -6213,6 +6155,8 @@ kernel void kernel_mul_mv_id( device const char * src1, device float * dst, device const char * ids, + constant int64_t & nei0, + constant int64_t & nei1, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6230,43 +6174,50 @@ kernel void kernel_mul_mv_id( constant int64_t & ne0, constant int64_t & ne1, constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t bid = tgpig.z/(ne12*ne13); + const int iid1 = tgpig.z/nei0; + const int idx = tgpig.z%nei0; + + tgpig.z = 0; - tgpig.z = tgpig.z%(ne12*ne13); + const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - device const char * src0 = src0s + id*nb02; + const int64_t i11 = idx % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = idx; + const int64_t i2 = i12; + + device const char * src0_cur = src0s + i02*nb02; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; impl_fn( - src0, - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - ne13, - nb10, - nb11, - nb12, - ne0, - ne1, - nb1, - r2, - r3, + /* src0 */ src0_cur, + /* src1 */ src1_cur, + /* dst */ dst_cur, + /* ne00 */ ne00, + /* ne01 */ ne01, + /* ne02 */ 1,//ne02, + /* nb00 */ nb00, + /* nb01 */ nb01, + /* nb02 */ nb02, + /* ne10 */ ne10, + /* ne11 */ 1,//ne11, + /* ne12 */ 1,//ne12, + /* ne13 */ 1,//ne13, + /* nb10 */ nb10, + /* nb11 */ nb11, + /* nb12 */ nb12, + /* ne0 */ ne0, + /* ne1 */ 1,//ne1, + /* nb1 */ nb1, + /* r2 */ 1, + /* r3 */ 1, shared_values, tgpig, tiitg, @@ -6274,36 +6225,7 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef void (kernel_mul_mv_id_t)( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]); +typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index f5bb7da8698..a9b310243f0 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -17752,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) { const int min_batch_size = 32; - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; + return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID; GGML_UNUSED(backend); } diff --git a/ggml.c b/ggml.c index 707a1fe4140..a745104c655 100644 --- a/ggml.c +++ b/ggml.c @@ -4594,21 +4594,32 @@ void ggml_mul_mat_set_prec( // ggml_mul_mat_id -// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed -// this will allow computing all the used experts in a single matrix multiplication +/* + c = ggml_mul_mat_id(ctx, as, b, ids); + + as -> [cols, rows, n_expert] + ids -> [n_experts_used, n_tokens] (i32) + b -> [cols, n_expert_used, n_tokens] + c -> [cols, n_expert_used, n_tokens] + + in b, n_experts_used can be broadcasted to match the n_expert_used of ids + + c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids +*/ struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, struct ggml_tensor * as, - struct ggml_tensor * ids, - int id, - struct ggml_tensor * b) { - + struct ggml_tensor * b, + struct ggml_tensor * ids) { + GGML_ASSERT(!ggml_is_transposed(as)); GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d - GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row - GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]); - GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast bool is_node = false; @@ -4616,11 +4627,9 @@ struct ggml_tensor * ggml_mul_mat_id( is_node = true; } - const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - ggml_set_op_params_i32(result, 0, id); - result->op = GGML_OP_MUL_MAT_ID; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = as; @@ -11071,11 +11080,6 @@ static void ggml_compute_forward_mul_mat_id( enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); GGML_ASSERT(nb10 == ggml_type_size(src1->type)); @@ -11086,22 +11090,21 @@ static void ggml_compute_forward_mul_mat_id( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - // broadcast is not supported with mmid - assert(ne12 == 1); - assert(ne13 == 1); - // row groups - const int id = ggml_get_op_params_i32(dst, 0); - const int n_as = src0->ne[2]; + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert char * wdata_src1_end = (src1->type == vec_dot_type) ? (char *) params->wdata : (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); - int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11] + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; - #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { @@ -11127,13 +11130,18 @@ static void ggml_compute_forward_mul_mat_id( // initialize matrix_row_counts memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] + // group rows by src0 matrix - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + assert(i02 >= 0 && i02 < n_as); - GGML_ASSERT(row_id >= 0 && row_id < n_as); - MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01; - matrix_row_counts[row_id] += 1; + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; + matrix_row_counts[i02] += 1; + } } return; @@ -11151,15 +11159,13 @@ static void ggml_compute_forward_mul_mat_id( continue; } - size_t src0_offset = cur_a*src0->nb[2]; + const char * src0_cur = (const char *) src0->data + cur_a*nb02; const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1*ne12*ne13; // src1 rows - - //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1; // src1 rows // distribute the thread work across the inner or outer loop based on which one is larger @@ -11178,13 +11184,11 @@ static void ggml_compute_forward_mul_mat_id( const int64_t ir110 = dr1*ith1; const int64_t ir111 = MIN(ir110 + dr1, nr1); - //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); - // threads with no work simply yield (not sure if it helps) - if (ir010 >= ir011 || ir110 >= ir111) { - sched_yield(); - continue; - } + //if (ir010 >= ir011 || ir110 >= ir111) { + // sched_yield(); + // continue; + //} // block-tiling attempt const int64_t blck_0 = 16; @@ -11196,20 +11200,16 @@ static void ggml_compute_forward_mul_mat_id( for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix - const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1; - const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1); - const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11); + const int64_t _i12 = ir1; // logical row index for this expert - // broadcast src0 into src1 - //const int64_t i03 = i13/r3; - //const int64_t i02 = i12/r2; + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); + const int id = row_mapping.i1; // selected expert index - const int64_t i1 = i11; - const int64_t i2 = i12; - const int64_t i3 = i13; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 - const char * src0_row = (const char *) src0->data + src0_offset; + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using @@ -11217,25 +11217,26 @@ static void ggml_compute_forward_mul_mat_id( // TODO: this is a bit of a hack, we should probably have a better way to handle this const char * src1_col = (const char *) wdata + (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size - : (i11*nb11 + i12*nb12 + i13*nb13)); + ? (i11 + i12*ne11)*row_size + : (i11*nb11 + i12*nb12)); - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1); + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); } + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); } } } } - #undef MMID_MATRIX_ROW +#undef MMID_MATRIX_ROW } // ggml_compute_forward_out_prod @@ -18583,7 +18584,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const int n_as = src0->ne[2]; cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += n_as * sizeof(int64_t); // matrix_row_counts - cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows + cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows } break; case GGML_OP_OUT_PROD: { @@ -21009,12 +21010,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p ok = ok && cur != NULL; - ggml_set_name(cur, ctx->infos[i].name.data); - if (!ok) { break; } + ggml_set_name(cur, ctx->infos[i].name.data); + // point the data member to the appropriate location in the binary blob using the tensor infos if (!params.no_alloc) { //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file diff --git a/ggml.h b/ggml.h index 1a776ca83e4..6d2c8c566ec 100644 --- a/ggml.h +++ b/ggml.h @@ -1170,13 +1170,11 @@ extern "C" { enum ggml_prec prec); // indirect matrix multiplication - // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b) GGML_API struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, struct ggml_tensor * as, - struct ggml_tensor * ids, - int id, - struct ggml_tensor * b); + struct ggml_tensor * b, + struct ggml_tensor * ids); // A: m columns, n rows, // B: p columns, n rows, From 295968601986c64b8f7ce1aa5f012e309271b3a4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Apr 2024 16:47:57 +0300 Subject: [PATCH 023/100] ggml : fix ggml_backend_cpu_supports_op() for CPY (llama/0) --- ggml-backend.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml-backend.c b/ggml-backend.c index a55967a6f16..d96100b3702 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -822,7 +822,11 @@ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { switch (op->op) { case GGML_OP_CPY: - return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS && op->type != GGML_TYPE_IQ1_S; // missing type_traits.from_float + return + op->type != GGML_TYPE_IQ2_XXS && + op->type != GGML_TYPE_IQ2_XS && + op->type != GGML_TYPE_IQ1_S && + op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float case GGML_OP_MUL_MAT: return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; default: From a6d264f331d3a4cd98188992a48d0f923532009e Mon Sep 17 00:00:00 2001 From: Dave Airlie Date: Tue, 23 Apr 2024 00:05:06 +1000 Subject: [PATCH 024/100] ggml : fix calloc argument ordering. (llama/6820) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Latest gcc complains here: /home/airlied/devel/llama.cpp/ggml-alloc.c: In function ‘ggml_gallocr_new_n’: /home/airlied/devel/llama.cpp/ggml-alloc.c:374:59: warning: ‘calloc’ sizes specified with ‘sizeof’ in the earlier argument and not in the later argument [-Wcalloc-transposed-args] 374 | ggml_gallocr_t galloc = (ggml_gallocr_t)calloc(sizeof(struct ggml_gallocr), 1); | ^~~~~~ /home/airlied/devel/llama.cpp/ggml-alloc.c:374:59: note: earlier argument should specify number of elements, later size of each element and a bunch more. calloc is specified to take nmemb first then size, so realign the code. In a couple of places there was a * x, 1 so I fixed those to use calloc properly. --- ggml-alloc.c | 16 ++++++++-------- ggml-backend.c | 18 +++++++++--------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ggml-alloc.c b/ggml-alloc.c index 7ceafec309d..1fbd376edf4 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -371,16 +371,16 @@ struct ggml_gallocr { }; ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs) { - ggml_gallocr_t galloc = (ggml_gallocr_t)calloc(sizeof(struct ggml_gallocr), 1); + ggml_gallocr_t galloc = (ggml_gallocr_t)calloc(1, sizeof(struct ggml_gallocr)); GGML_ASSERT(galloc != NULL); - galloc->bufts = calloc(sizeof(ggml_backend_buffer_type_t) * n_bufs, 1); + galloc->bufts = calloc(n_bufs, sizeof(ggml_backend_buffer_type_t)); GGML_ASSERT(galloc->bufts != NULL); - galloc->buffers = calloc(sizeof(ggml_backend_buffer_t) * n_bufs, 1); + galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t) * n_bufs); GGML_ASSERT(galloc->buffers != NULL); - galloc->buf_tallocs = calloc(sizeof(struct ggml_dyn_tallocr *) * n_bufs, 1); + galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); GGML_ASSERT(galloc->buf_tallocs != NULL); for (int i = 0; i < n_bufs; i++) { @@ -646,8 +646,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c free(galloc->hash_set.keys); free(galloc->hash_values); galloc->hash_set.size = hash_size; - galloc->hash_set.keys = calloc(sizeof(struct ggml_tensor *), hash_size); - galloc->hash_values = calloc(sizeof(struct hash_node), hash_size); + galloc->hash_set.keys = calloc(hash_size, sizeof(struct ggml_tensor *)); + galloc->hash_values = calloc(hash_size, sizeof(struct hash_node)); GGML_ASSERT(galloc->hash_set.keys != NULL); GGML_ASSERT(galloc->hash_values != NULL); } else { @@ -667,7 +667,7 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c // set the node_allocs from the hash table if (galloc->n_nodes < graph->n_nodes) { free(galloc->node_allocs); - galloc->node_allocs = calloc(sizeof(struct node_alloc), graph->n_nodes); + galloc->node_allocs = calloc(graph->n_nodes, sizeof(struct node_alloc)); GGML_ASSERT(galloc->node_allocs != NULL); } galloc->n_nodes = graph->n_nodes; @@ -697,7 +697,7 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } if (galloc->n_leafs < graph->n_leafs) { free(galloc->leaf_allocs); - galloc->leaf_allocs = calloc(sizeof(galloc->leaf_allocs[0]), graph->n_leafs); + galloc->leaf_allocs = calloc(graph->n_leafs, sizeof(galloc->leaf_allocs[0])); GGML_ASSERT(galloc->leaf_allocs != NULL); } galloc->n_leafs = graph->n_leafs; diff --git a/ggml-backend.c b/ggml-backend.c index d96100b3702..2be7ad591be 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -1725,23 +1725,23 @@ ggml_backend_sched_t ggml_backend_sched_new( GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU - struct ggml_backend_sched * sched = calloc(sizeof(struct ggml_backend_sched), 1); + struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched)); // initialize hash table sched->hash_set = ggml_hash_set_new(graph_size); - sched->tensor_backend_id = calloc(sizeof(sched->tensor_backend_id[0]), sched->hash_set.size); - sched->tensor_copies = calloc(sizeof(sched->tensor_copies[0]), sched->hash_set.size); + sched->tensor_backend_id = calloc(sched->hash_set.size, sizeof(sched->tensor_backend_id[0])); + sched->tensor_copies = calloc(sched->hash_set.size, sizeof(sched->tensor_copies[0])); const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2; - sched->node_backend_ids = calloc(sizeof(sched->node_backend_ids[0]), nodes_size); - sched->leaf_backend_ids = calloc(sizeof(sched->leaf_backend_ids[0]), nodes_size); + sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0])); + sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); sched->n_backends = n_backends; sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1; const int initial_splits_capacity = 16; - sched->splits = calloc(sizeof(sched->splits[0]), initial_splits_capacity); + sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0])); sched->splits_capacity = initial_splits_capacity; for (int b = 0; b < n_backends; b++) { @@ -1972,10 +1972,10 @@ static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_te struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { struct ggml_hash_set hash_set = { /* .size = */ graph->visited_hash_table.size, - /* .keys = */ calloc(sizeof(hash_set.keys[0]), graph->visited_hash_table.size) // NOLINT + /* .keys = */ calloc(graph->visited_hash_table.size, sizeof(hash_set.keys[0])) // NOLINT }; - struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]), hash_set.size); // NOLINT - bool * node_init = calloc(sizeof(node_init[0]), hash_set.size); + struct ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT + bool * node_init = calloc(hash_set.size, sizeof(node_init[0])); struct ggml_init_params params = { /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false), From 6c3971b29babd4a177e9bb7e834917019a66c2cd Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Mon, 22 Apr 2024 15:00:36 -0400 Subject: [PATCH 025/100] llamafile : improve sgemm.cpp (llama/6796) * llamafile : improve sgemm.cpp - Re-enable by default - Fix issue described in #6716 - Make code more abstract, elegant, and maintainable - Faster handling of weirdly shaped `m` an `n` edge cases * Address review comments * Help clang produce fma instructions * Address review comments --- ggml.c | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ggml.c b/ggml.c index a745104c655..b9e2150f16e 100644 --- a/ggml.c +++ b/ggml.c @@ -10887,7 +10887,7 @@ static void ggml_compute_forward_mul_mat( #endif #if GGML_USE_LLAMAFILE - if (nb10 == ggml_type_size(src1->type)) { + if (src1_cont) { for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), @@ -10940,15 +10940,13 @@ UseGgmlGemm1:; const size_t row_size = ggml_row_size(vec_dot_type, ne10); #if GGML_USE_LLAMAFILE - if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) { + if (src1->type != vec_dot_type) { for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - (const char *)wdata + ggml_row_size(vec_dot_type, - nb12/ggml_type_size(src1->type)*i12 + - nb13/ggml_type_size(src1->type)*i13), + (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type), (char *)dst->data + i12*nb2 + i13*nb3, nb1/ggml_type_size(dst->type), From 63fd148d8f553286d94150b202d18023a61bf2b2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 24 Apr 2024 12:00:07 +0300 Subject: [PATCH 026/100] ggml : move 32-bit arm compat in ggml-impl.h (llama/6865) ggml-ci --- ggml-impl.h | 260 ++++++++++++++++++++++++++++++++++++++++++++- ggml-quants.c | 287 -------------------------------------------------- 2 files changed, 256 insertions(+), 291 deletions(-) diff --git a/ggml-impl.h b/ggml-impl.h index 43eb631e4c1..ae27d978999 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -52,7 +52,7 @@ extern "C" { // 16-bit float // on Arm, we use __fp16 // on x86, we use uint16_t -#if defined(__ARM_NEON) && !defined(_MSC_VER) +#if defined(__ARM_NEON) // if YCM cannot find , make a symbolic link to it, for example: // @@ -60,8 +60,262 @@ extern "C" { // #include +#ifdef _MSC_VER + +typedef uint16_t ggml_fp16_internal_t; + +#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } + +#else + typedef __fp16 ggml_fp16_internal_t; +#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } + +#endif // _MSC_VER + +#if !defined(__aarch64__) + +// 32-bit ARM compatibility + +// vaddvq_s16 +// vpaddq_s16 +// vpaddq_s32 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// vcvtnq_s32_f32 +// vzip1_u8 +// vzip2_u8 + +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { + int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline static float vmaxvq_f32(float32x4_t v) { + return + MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + +inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[0]; res[1] = b[0]; + res[2] = a[1]; res[3] = b[1]; + res[4] = a[2]; res[5] = b[2]; + res[6] = a[3]; res[7] = b[3]; + + return res; +} + +inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[4]; res[1] = b[4]; + res[2] = a[5]; res[3] = b[5]; + res[4] = a[6]; res[5] = b[6]; + res[6] = a[7]; res[7] = b[7]; + + return res; +} + +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct ggml_int16x8x2_t { + int16x8_t val[2]; +} ggml_int16x8x2_t; + +inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { + ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct ggml_uint8x16x2_t { + uint8x16_t val[2]; +} ggml_uint8x16x2_t; + +inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { + ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct ggml_uint8x16x4_t { + uint8x16_t val[4]; +} ggml_uint8x16x4_t; + +inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { + ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct ggml_int8x16x2_t { + int8x16_t val[2]; +} ggml_int8x16x2_t; + +inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { + ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct ggml_int8x16x4_t { + int8x16_t val[4]; +} ggml_int8x16x4_t; + +inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { + ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +// NOTE: not tested +inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +// NOTE: not tested +inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { + uint8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +#else + +#define ggml_int16x8x2_t int16x8x2_t +#define ggml_uint8x16x2_t uint8x16x2_t +#define ggml_uint8x16x4_t uint8x16x4_t +#define ggml_int8x16x2_t int8x16x2_t +#define ggml_int8x16x4_t int8x16x4_t + +#define ggml_vld1q_s16_x2 vld1q_s16_x2 +#define ggml_vld1q_u8_x2 vld1q_u8_x2 +#define ggml_vld1q_u8_x4 vld1q_u8_x4 +#define ggml_vld1q_s8_x2 vld1q_s8_x2 +#define ggml_vld1q_s8_x4 vld1q_s8_x4 +#define ggml_vqtbl1q_s8 vqtbl1q_s8 +#define ggml_vqtbl1q_u8 vqtbl1q_u8 + +#endif // !defined(__aarch64__) + +#if !defined(__ARM_FEATURE_DOTPROD) + +inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { + const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); + const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + + return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); +} + +#else + +#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) + +#endif // !defined(__ARM_FEATURE_DOTPROD) + +#endif // defined(__ARM_NEON) + +#if defined(__ARM_NEON) && !defined(__MSC_VER) + #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) @@ -82,8 +336,6 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #else -typedef uint16_t ggml_fp16_internal_t; - #ifdef __wasm_simd128__ #include #else @@ -228,7 +480,7 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #endif // __F16C__ -#endif // __ARM_NEON +#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) // precomputed f32 table for f16 (256 KB) // defined in ggml.c, initialized in ggml_init() diff --git a/ggml-quants.c b/ggml-quants.c index 4be9575e0c1..e4f96e399d4 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -20,41 +20,6 @@ #pragma warning(disable: 4244 4267) #endif -#ifdef __ARM_NEON - -// if YCM cannot find , make a symbolic link to it, for example: -// -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ -// -#include - -#else - -#ifdef __wasm_simd128__ -#include -#else -#if defined(__POWER9_VECTOR__) || defined(__powerpc64__) -#include -#undef bool -#define bool _Bool -#else -#if defined(_MSC_VER) || defined(__MINGW32__) -#include -#else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) -#if !defined(__riscv) -#include -#endif -#endif -#endif -#endif -#endif -#endif - -#ifdef __riscv_v_intrinsic -#include -#endif - #undef MIN #undef MAX @@ -282,258 +247,6 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #endif // __AVX__ || __AVX2__ || __AVX512F__ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -#if defined(__ARM_NEON) - -#ifdef _MSC_VER - -#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } - -#else - -#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } - -#endif - -#if !defined(__aarch64__) - -// 64-bit compatibility - -// vaddvq_s16 -// vpaddq_s16 -// vpaddq_s32 -// vaddvq_s32 -// vaddvq_f32 -// vmaxvq_f32 -// vcvtnq_s32_f32 -// vzip1_u8 -// vzip2_u8 - -inline static int32_t vaddvq_s16(int16x8_t v) { - return - (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + - (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + - (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + - (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); -} - -inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { - int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); - int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); - return vcombine_s16(a0, b0); -} - -inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { - int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); - int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); - return vcombine_s32(a0, b0); -} - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} - -inline static float vaddvq_f32(float32x4_t v) { - return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); -} - -inline static float vmaxvq_f32(float32x4_t v) { - return - MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), - MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); -} - -inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { - int32x4_t res; - - res[0] = roundf(vgetq_lane_f32(v, 0)); - res[1] = roundf(vgetq_lane_f32(v, 1)); - res[2] = roundf(vgetq_lane_f32(v, 2)); - res[3] = roundf(vgetq_lane_f32(v, 3)); - - return res; -} - -inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { - uint8x8_t res; - - res[0] = a[0]; res[1] = b[0]; - res[2] = a[1]; res[3] = b[1]; - res[4] = a[2]; res[5] = b[2]; - res[6] = a[3]; res[7] = b[3]; - - return res; -} - -inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { - uint8x8_t res; - - res[0] = a[4]; res[1] = b[4]; - res[2] = a[5]; res[3] = b[5]; - res[4] = a[6]; res[5] = b[6]; - res[6] = a[7]; res[7] = b[7]; - - return res; -} - -// vld1q_s16_x2 -// vld1q_u8_x2 -// vld1q_u8_x4 -// vld1q_s8_x2 -// vld1q_s8_x4 -// TODO: double-check these work correctly - -typedef struct ggml_int16x8x2_t { - int16x8_t val[2]; -} ggml_int16x8x2_t; - -inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { - ggml_int16x8x2_t res; - - res.val[0] = vld1q_s16(ptr + 0); - res.val[1] = vld1q_s16(ptr + 8); - - return res; -} - -typedef struct ggml_uint8x16x2_t { - uint8x16_t val[2]; -} ggml_uint8x16x2_t; - -inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { - ggml_uint8x16x2_t res; - - res.val[0] = vld1q_u8(ptr + 0); - res.val[1] = vld1q_u8(ptr + 16); - - return res; -} - -typedef struct ggml_uint8x16x4_t { - uint8x16_t val[4]; -} ggml_uint8x16x4_t; - -inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { - ggml_uint8x16x4_t res; - - res.val[0] = vld1q_u8(ptr + 0); - res.val[1] = vld1q_u8(ptr + 16); - res.val[2] = vld1q_u8(ptr + 32); - res.val[3] = vld1q_u8(ptr + 48); - - return res; -} - -typedef struct ggml_int8x16x2_t { - int8x16_t val[2]; -} ggml_int8x16x2_t; - -inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { - ggml_int8x16x2_t res; - - res.val[0] = vld1q_s8(ptr + 0); - res.val[1] = vld1q_s8(ptr + 16); - - return res; -} - -typedef struct ggml_int8x16x4_t { - int8x16_t val[4]; -} ggml_int8x16x4_t; - -inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { - ggml_int8x16x4_t res; - - res.val[0] = vld1q_s8(ptr + 0); - res.val[1] = vld1q_s8(ptr + 16); - res.val[2] = vld1q_s8(ptr + 32); - res.val[3] = vld1q_s8(ptr + 48); - - return res; -} - -// NOTE: not tested -inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { - int8x16_t res; - - res[ 0] = a[b[ 0]]; - res[ 1] = a[b[ 1]]; - res[ 2] = a[b[ 2]]; - res[ 3] = a[b[ 3]]; - res[ 4] = a[b[ 4]]; - res[ 5] = a[b[ 5]]; - res[ 6] = a[b[ 6]]; - res[ 7] = a[b[ 7]]; - res[ 8] = a[b[ 8]]; - res[ 9] = a[b[ 9]]; - res[10] = a[b[10]]; - res[11] = a[b[11]]; - res[12] = a[b[12]]; - res[13] = a[b[13]]; - res[14] = a[b[14]]; - res[15] = a[b[15]]; - - return res; -} - -// NOTE: not tested -inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { - uint8x16_t res; - - res[ 0] = a[b[ 0]]; - res[ 1] = a[b[ 1]]; - res[ 2] = a[b[ 2]]; - res[ 3] = a[b[ 3]]; - res[ 4] = a[b[ 4]]; - res[ 5] = a[b[ 5]]; - res[ 6] = a[b[ 6]]; - res[ 7] = a[b[ 7]]; - res[ 8] = a[b[ 8]]; - res[ 9] = a[b[ 9]]; - res[10] = a[b[10]]; - res[11] = a[b[11]]; - res[12] = a[b[12]]; - res[13] = a[b[13]]; - res[14] = a[b[14]]; - res[15] = a[b[15]]; - - return res; -} - -#else - -#define ggml_int16x8x2_t int16x8x2_t -#define ggml_uint8x16x2_t uint8x16x2_t -#define ggml_uint8x16x4_t uint8x16x4_t -#define ggml_int8x16x2_t int8x16x2_t -#define ggml_int8x16x4_t int8x16x4_t - -#define ggml_vld1q_s16_x2 vld1q_s16_x2 -#define ggml_vld1q_u8_x2 vld1q_u8_x2 -#define ggml_vld1q_u8_x4 vld1q_u8_x4 -#define ggml_vld1q_s8_x2 vld1q_s8_x2 -#define ggml_vld1q_s8_x4 vld1q_s8_x4 -#define ggml_vqtbl1q_s8 vqtbl1q_s8 -#define ggml_vqtbl1q_u8 vqtbl1q_u8 - -#endif - -#if !defined(__ARM_FEATURE_DOTPROD) - -inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { - const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); - const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); - - return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); -} - -#else - -#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) - -#endif - -#endif - #if defined(__ARM_NEON) || defined(__wasm_simd128__) #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) From a15fb5cd798a5e620d6ae184ef178c7cc8c89054 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 15:12:28 +0300 Subject: [PATCH 027/100] ggml : fix MIN / MAX macros (llama/6904) ggml-ci --- ggml-impl.h | 6 ++++++ ggml-quants.c | 6 ------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml-impl.h b/ggml-impl.h index ae27d978999..3cb0a599783 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -11,6 +11,12 @@ #include // memcpy #include // fabsf +#undef MIN +#undef MAX + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + #ifdef __cplusplus extern "C" { #endif diff --git a/ggml-quants.c b/ggml-quants.c index e4f96e399d4..7d54730eecb 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -20,12 +20,6 @@ #pragma warning(disable: 4244 4267) #endif -#undef MIN -#undef MAX - -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - #define UNUSED GGML_UNUSED // some compilers don't provide _mm256_set_m128i, e.g. gcc 7 From 05b17112cfd9bd9077fae52499fb9a972a7c3df9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Apr 2024 15:48:25 +0300 Subject: [PATCH 028/100] ggml : fix redefinition of vaddvq_f32 for 32-bit ARM (llama/6906) --- ggml.c | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/ggml.c b/ggml.c index b9e2150f16e..307947d2803 100644 --- a/ggml.c +++ b/ggml.c @@ -858,18 +858,6 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { // simd mappings // -#if defined(__ARM_NEON) -#if !defined(__aarch64__) - -// 64-bit compatibility - -inline static float vaddvq_f32(float32x4_t v) { - return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); -} - -#endif -#endif - // we define a common set of C macros which map to specific intrinsics based on the current architecture // we then implement the fundamental computation operations below using only these macros // adding support for new architectures requires to define the corresponding SIMD macros From 6f7140f56824a81654fec2f832e29e315f5a6489 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 26 Apr 2024 10:41:53 +0300 Subject: [PATCH 029/100] Merge pull request from GHSA-p5mv-gjc5-mwqv * always use calloc clamp n_kv on failure to read a kv * ggml : alternative ctx->header.n_kv update --------- Co-authored-by: slaren --- ggml.c | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/ggml.c b/ggml.c index 307947d2803..2e929b8d2f7 100644 --- a/ggml.c +++ b/ggml.c @@ -20685,7 +20685,7 @@ static void gguf_free_kv(struct gguf_kv * kv) { } struct gguf_context * gguf_init_empty(void) { - struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context)); memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic)); ctx->header.version = GGUF_VERSION; @@ -20730,7 +20730,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p bool ok = true; - struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context)); // read the header { @@ -20767,9 +20767,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // read the kv pairs { - ctx->kv = GGML_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv)); + const uint64_t n_kv = ctx->header.n_kv; - for (uint64_t i = 0; i < ctx->header.n_kv; ++i) { + // header.n_kv will hold the actual value of pairs that were successfully read in the loop below + ctx->header.n_kv = 0; + ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv)); + + for (uint64_t i = 0; i < n_kv; ++i) { struct gguf_kv * kv = &ctx->kv[i]; //fprintf(stderr, "%s: reading kv %d\n", __func__, i); @@ -20818,7 +20822,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p return NULL; } - kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * gguf_type_size(kv->value.arr.type)); + kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type)); ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset); } break; @@ -20832,7 +20836,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p return NULL; } - kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * sizeof(struct gguf_str)); + kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str)); for (uint64_t j = 0; j < kv->value.arr.n; ++j) { ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset); @@ -20848,6 +20852,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p if (!ok) { break; } + + ctx->header.n_kv++; } if (!ok) { @@ -20860,7 +20866,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // read the tensor infos { - ctx->infos = GGML_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info)); + ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info)); for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { struct gguf_tensor_info * info = &ctx->infos[i]; @@ -20881,6 +20887,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset); ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); + // TODO: return an error instead of crashing with GGML_ASSERT gguf_tensor_info_sanitize(info); if (!ok) { @@ -21362,7 +21369,7 @@ void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_ty ctx->kv[idx].type = GGUF_TYPE_ARRAY; ctx->kv[idx].value.arr.type = type; ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = GGML_MALLOC(n*gguf_type_size(type)); + ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type)); memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type)); } @@ -21372,7 +21379,7 @@ void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** ctx->kv[idx].type = GGUF_TYPE_ARRAY; ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING; ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = GGML_MALLOC(n*sizeof(struct gguf_str)); + ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str)); for (int i = 0; i < n; i++) { struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i]; str->n = strlen(data[i]); @@ -21399,7 +21406,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { case GGUF_TYPE_ARRAY: { if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) { - const char ** data = GGML_MALLOC(src->kv[i].value.arr.n*sizeof(char *)); + const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *)); for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) { data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data; } @@ -21487,7 +21494,7 @@ struct gguf_buf { static struct gguf_buf gguf_buf_init(size_t size) { struct gguf_buf buf = { - /*buf.data =*/ size == 0 ? NULL : GGML_MALLOC(size), + /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size), /*buf.size =*/ size, /*buf.offset =*/ 0, }; From ecfac1e240b7122266f824872afc86d4b9ba07a5 Mon Sep 17 00:00:00 2001 From: slaren Date: Fri, 26 Apr 2024 17:07:42 +0200 Subject: [PATCH 030/100] gguf : fix mismatch between alloc and free functions (llama/6929) --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 2e929b8d2f7..7f637543625 100644 --- a/ggml.c +++ b/ggml.c @@ -21058,7 +21058,7 @@ void gguf_free(struct gguf_context * ctx) { GGML_FREE(ctx->infos); } - GGML_ALIGNED_FREE(ctx); + GGML_FREE(ctx); } const char * gguf_type_name(enum gguf_type type) { From 9d4c8b8aa5d5cc46f63043eaadf674dce1cc7d6a Mon Sep 17 00:00:00 2001 From: slaren Date: Fri, 26 Apr 2024 18:39:58 +0200 Subject: [PATCH 031/100] add basic tensor data validation function (llama/6884) * add basic tensor data validation function * add --check-tensors command line argument tensor validation is disabled by default and can be enabled by adding `--check-tensors` to the command line arguments. quantize always validates tensors. --- ggml-quants.c | 284 ++++++++++++++++++++++++++++++++++++++++++++++++++ ggml.h | 2 + 2 files changed, 286 insertions(+) diff --git a/ggml-quants.c b/ggml-quants.c index 7d54730eecb..15370f1b515 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -12389,3 +12389,287 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) block_iq2_s * restrict y = vy; quantize_row_iq2_s_reference(x, y, k); } + +static bool validate_float(float f, size_t i) { + if (isinf(f)) { + fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i); + return false; + } + + if (isnan(f)) { + fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i); + return false; + } + + return true; +} + +static bool isinf_fp16(ggml_fp16_t f) { + return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0; +} + +static bool isnan_fp16(ggml_fp16_t f) { + return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0; +} + +static bool validate_fp16(ggml_fp16_t f, size_t i) { + if (isinf_fp16(f)) { + fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i); + return false; + } + + if (isnan_fp16(f)) { + fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i); + return false; + } + + return true; +} + +#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_fp16(q[i].d, i)) { \ + return false; \ + } \ + } + +#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \ + return false; \ + } \ + } + +bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) { + if (type < 0 || type >= GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid type %d\n", __func__, type); + return false; + } + + if (nbytes % ggml_type_size(type) != 0) { + fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type); + return false; + } + + const size_t nb = nbytes/ggml_type_size(type); + + switch (type) { + case GGML_TYPE_F16: + { + const ggml_fp16_t * f = (const ggml_fp16_t *) data; + size_t i = 0; +#if defined(__AVX2__) + for (; i + 15 < nb; i += 16) { + __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); + __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00)); + __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00)); + int mask = _mm256_movemask_epi8(cmp); + if (mask) { + for (size_t j = 0; j < 16; ++j) { + if (!validate_fp16(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#elif defined(__ARM_NEON) + for (; i + 7 < nb; i += 8) { + uint16x8_t v = vld1q_u16(f + i); + uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00)); + uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0); + if (mask) { + for (size_t j = 0; j < 8; ++j) { + if (!validate_fp16(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#endif + for (; i < nb; ++i) { + if (!validate_fp16(f[i], i)) { + return false; + } + } + } break; + case GGML_TYPE_F32: + { + const float * f = (const float *) data; + size_t i = 0; +#if defined(__AVX2__) + for (; i + 7 < nb; i += 8) { + __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); + __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000)); + __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000)); + int mask = _mm256_movemask_epi8(cmp); + if (mask) { + for (size_t j = 0; j < 8; ++j) { + if (!validate_float(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#elif defined(__ARM_NEON) + for (; i + 3 < nb; i += 4) { + uint32x4_t v = vld1q_u32((const uint32_t *)f + i); + uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000)); + uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0); + if (mask) { + for (size_t j = 0; j < 4; ++j) { + if (!validate_float(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#endif + for (; i < nb; ++i) { + if (!validate_float(f[i], i)) { + return false; + } + } + } break; + case GGML_TYPE_F64: + { + const double * f = (const double *) data; + for (size_t i = 0; i < nb; ++i) { + if (!validate_float(f[i], i)) { + return false; + } + } + } break; + case GGML_TYPE_Q4_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); + } break; + case GGML_TYPE_Q4_1: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m); + } break; + case GGML_TYPE_Q5_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb); + } break; + case GGML_TYPE_Q5_1: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m); + } break; + case GGML_TYPE_Q8_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb); + } break; + case GGML_TYPE_Q2_K: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin); + } break; + case GGML_TYPE_Q3_K: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb); + } break; + case GGML_TYPE_Q4_K: + { + #ifdef GGML_QKK_64 + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]); + #else + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin); + #endif + } break; + case GGML_TYPE_Q5_K: + { + #ifdef GGML_QKK_64 + VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb); + #else + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin); + #endif + } break; + case GGML_TYPE_Q6_K: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb); + } break; + case GGML_TYPE_Q8_K: + { + const block_q8_K * q = (const block_q8_K *) data; + for (size_t i = 0; i < nb; ++i) { + if (!validate_float(q[i].d, i)) { + return false; + } + } + } break; + case GGML_TYPE_IQ1_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb); + } break; + case GGML_TYPE_IQ1_M: + { + const block_iq1_m * q = (const block_iq1_m *) data; + for (size_t i = 0; i < nb; ++i) { + #if QK_K == 64 + if (!validate_fp16(q[i].d, i)) { + return false; + } + #else + iq1m_scale_t scale; + const uint16_t * sc = (const uint16_t *)q[i].scales; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + if (!validate_fp16(scale.f16, i)) { + return false; + } + #endif + } + } break; + case GGML_TYPE_IQ2_XXS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb); + } break; + case GGML_TYPE_IQ2_XS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb); + } break; + case GGML_TYPE_IQ2_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb); + } break; + case GGML_TYPE_IQ3_XXS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb); + } break; + + case GGML_TYPE_IQ3_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb); + } break; + case GGML_TYPE_IQ4_XS: + #if QK_K != 64 + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb); + } break; + #endif + // with QK_K == 64, iq4_xs is iq4_nl + case GGML_TYPE_IQ4_NL: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_I64: + // nothing to validate + break; + default: + { + fprintf(stderr, "%s: invalid type %d\n", __func__, type); + return false; + } + } + + return true; +} diff --git a/ggml.h b/ggml.h index 6d2c8c566ec..06cafbd78ba 100644 --- a/ggml.h +++ b/ggml.h @@ -763,6 +763,8 @@ extern "C" { // use this to compute the memory overhead of a tensor GGML_API size_t ggml_tensor_overhead(void); + GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes); + // main GGML_API struct ggml_context * ggml_init(struct ggml_init_params params); From f0d3fb4a7ed4604a743c8794e08c7a716109d3a9 Mon Sep 17 00:00:00 2001 From: agray3 Date: Fri, 26 Apr 2024 19:08:30 +0100 Subject: [PATCH 032/100] Reset schedule earlier to allow overlap with ggml graph computation on device (llama/6933) * Reset schedule earlier to allow overlap with graph computation on device --- ggml-backend.c | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml-backend.c b/ggml-backend.c index 2be7ad591be..dd090a583f6 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -1784,12 +1784,14 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { void ggml_backend_sched_reset(ggml_backend_sched_t sched) { // reset state for the next run - size_t hash_size = sched->hash_set.size; - memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT - memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size); - memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size); + if (!sched->is_reset) { + size_t hash_size = sched->hash_set.size; + memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT + memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size); + memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size); - sched->is_reset = true; + sched->is_reset = true; + } sched->is_alloc = false; } From 9ad202bee9c210c8c6e54fcde7975b9dbbaaae9f Mon Sep 17 00:00:00 2001 From: Neo Zhang <14088817+arthw@users.noreply.github.com> Date: Sun, 28 Apr 2024 22:40:31 +0800 Subject: [PATCH 033/100] add device version in device list (llama/6959) Co-authored-by: arthw <> --- ggml-sycl.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index a9b310243f0..2b76b3ebd64 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -13416,11 +13416,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type) version += std::to_string(prop.get_minor_version()); device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), ""); + std::string name = std::string(prop.get_name()); + name = std::regex_replace(name, std::regex("\\(R\\)"), ""); + name = std::regex_replace(name, std::regex("\\(TM\\)"), ""); - fprintf(stderr, "|%2d|%18s|%45s|%10s|%11d|%8d|%7d|%15lu|\n", id, device_type.c_str(), - prop.get_name(), version.c_str(), prop.get_max_compute_units(), + auto global_mem_size = prop.get_global_mem_size()/1000000; + + fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(), + name.c_str(), version.c_str(), prop.get_max_compute_units(), prop.get_max_work_group_size(), prop.get_max_sub_group_size(), - prop.get_global_mem_size()); + global_mem_size, device.get_info().c_str()); } void ggml_backend_sycl_print_sycl_devices() { @@ -13428,9 +13433,10 @@ void ggml_backend_sycl_print_sycl_devices() { int device_count = dpct::dev_mgr::instance().device_count(); std::map DeviceNums; fprintf(stderr, "found %d SYCL devices:\n", device_count); - fprintf(stderr, "| | | |Compute |Max compute|Max work|Max sub| |\n"); - fprintf(stderr, "|ID| Device Type| Name|capability|units |group |group |Global mem size|\n"); - fprintf(stderr, "|--|------------------|---------------------------------------------|----------|-----------|--------|-------|---------------|\n"); + fprintf(stderr, "| | | | |Max | |Max |Global | |\n"); + fprintf(stderr, "| | | | |compute|Max work|sub |mem | |\n"); + fprintf(stderr, "|ID| Device Type| Name|Version|units |group |group|size | Driver version|\n"); + fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n"); for (int id = 0; id < device_count; ++id) { sycl::device device = dpct::dev_mgr::instance().get_device(id); sycl::backend backend = device.get_backend(); From 388c3462a6c38a09e403c4b2b061cd5f154edf52 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 28 Apr 2024 17:36:18 +0200 Subject: [PATCH 034/100] gguf : enforce that tensor names are unique (llama/6905) * not allow adding duplicated tensor name * no duplicated tensor while reading gguf * typo * throw exception inside llama_model_loader Co-authored-by: slaren --------- Co-authored-by: slaren --- ggml.c | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ggml.c b/ggml.c index 7f637543625..3bddcdbf28a 100644 --- a/ggml.c +++ b/ggml.c @@ -20890,6 +20890,14 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // TODO: return an error instead of crashing with GGML_ASSERT gguf_tensor_info_sanitize(info); + // make sure there is no duplicated tensor names + for (uint64_t j = 0; j < i; ++j) { + if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) { + fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data); + ok = false; + } + } + if (!ok) { fprintf(stderr, "%s: failed to read tensor info\n", __func__); fclose(file); @@ -21426,6 +21434,10 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { void gguf_add_tensor( struct gguf_context * ctx, const struct ggml_tensor * tensor) { + if (gguf_find_tensor(ctx, tensor->name) != -1) { + GGML_ASSERT(false && "duplicated tensor name"); + } + const int idx = ctx->header.n_tensors; ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info)); From b574646d754b4abd54c4df544e648f5232fd00be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Sun, 28 Apr 2024 18:38:44 -0400 Subject: [PATCH 035/100] Fix more int overflow during quant (PPL/CUDA). (llama/6563) * Fix more int overflow during quant. * Fix some more int overflow in softmax. * Revert back to int64_t. --- ggml-cuda/convert.cu | 168 +++++++++++++++++++++---------------------- ggml-cuda/softmax.cu | 8 +-- 2 files changed, 88 insertions(+), 88 deletions(-) diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index b15e3578267..75e50c98561 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -5,16 +5,16 @@ template static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { - const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x); + const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x); if (i >= k) { return; } const int64_t ib = i/qk; // block index - const int iqs = (i%qk)/qr; // quant index - const int iybs = i - i%qk; // y block start index - const int y_offset = qr == 1 ? 1 : qk/2; + const int64_t iqs = (i%qk)/qr; // quant index + const int64_t iybs = i - i%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; // dequantize dfloat2 v; @@ -29,7 +29,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h #if __CUDA_ARCH__ >= CC_PASCAL constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE; - const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x; + const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x; const int * x0 = ((int *) vx) + blockIdx.x * nint; half2 * y2 = (half2 *) (y + i0); @@ -73,9 +73,9 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t const int64_t i = blockIdx.x; // assume 32 threads - const int tid = threadIdx.x; - const int il = tid/8; - const int ir = tid%8; + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; const int64_t ib = 8*i + ir; if (ib >= nb32) { return; @@ -101,9 +101,9 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t const int64_t i = blockIdx.x; // assume 32 threads - const int tid = threadIdx.x; - const int il = tid/8; - const int ir = tid%8; + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; const int64_t ib = 8*i + ir; if (ib >= nb32) { return; @@ -127,14 +127,14 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t template static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_q2_K * x = (const block_q2_K *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int n = tid/32; - const int l = tid - 32*n; - const int is = 8*n + l/16; + const int64_t n = tid/32; + const int64_t l = tid - 32*n; + const int64_t is = 8*n + l/16; const uint8_t q = x[i].qs[32*n + l]; dst_t * y = yy + i*QK_K + 128*n; @@ -146,8 +146,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); #else - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 const uint8_t q = x[i].qs[il] >> (2*is); dst_t * y = yy + i*QK_K + 16*is + il; float dall = __low2half(x[i].dm); @@ -161,19 +161,19 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t template static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; #if QK_K == 256 - const int r = threadIdx.x/4; - const int tid = r/2; - const int is0 = r%2; - const int l0 = 16*is0 + 4*(threadIdx.x%4); - const int n = tid / 4; - const int j = tid - 4*n; + const int64_t r = threadIdx.x/4; + const int64_t tid = r/2; + const int64_t is0 = r%2; + const int64_t l0 = 16*is0 + 4*(threadIdx.x%4); + const int64_t n = tid / 4; + const int64_t j = tid - 4*n; uint8_t m = 1 << (4*n + j); - int is = 8*n + 2*j + is0; + int64_t is = 8*n + 2*j + is0; int shift = 2*j; int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : @@ -189,11 +189,11 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); #else - const int tid = threadIdx.x; - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 - const int im = il/8; // 0...1 - const int in = il%8; // 0...7 + const int64_t tid = threadIdx.x; + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 + const int64_t im = il/8; // 0...1 + const int64_t in = il%8; // 0...7 dst_t * y = yy + i*QK_K + 16*is + il; @@ -227,15 +227,15 @@ template static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q4_K * x = (const block_q4_K *) vx; - const int i = blockIdx.x; + const int64_t i = blockIdx.x; #if QK_K == 256 // assume 32 threads - const int tid = threadIdx.x; - const int il = tid/8; - const int ir = tid%8; - const int is = 2*il; - const int n = 4; + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t is = 2*il; + const int64_t n = 4; dst_t * y = yy + i*QK_K + 64*il + n*ir; @@ -254,7 +254,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t y[l +32] = d2 * (q[l] >> 4) - m2; } #else - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; const uint8_t * q = x[i].qs; dst_t * y = yy + i*QK_K; const float d = (float)x[i].dm[0]; @@ -268,14 +268,14 @@ template static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q5_K * x = (const block_q5_K *) vx; - const int i = blockIdx.x; + const int64_t i = blockIdx.x; #if QK_K == 256 // assume 64 threads - this is very slightly better than the one below - const int tid = threadIdx.x; - const int il = tid/16; // il is in 0...3 - const int ir = tid%16; // ir is in 0...15 - const int is = 2*il; // is is in 0...6 + const int64_t tid = threadIdx.x; + const int64_t il = tid/16; // il is in 0...3 + const int64_t ir = tid%16; // ir is in 0...15 + const int64_t is = 2*il; // is is in 0...6 dst_t * y = yy + i*QK_K + 64*il + 2*ir; @@ -298,11 +298,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; #else - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; const uint8_t q = x[i].qs[tid]; - const int im = tid/8; // 0...3 - const int in = tid%8; // 0...7 - const int is = tid/16; // 0 or 1 + const int64_t im = tid/8; // 0...3 + const int64_t in = tid%8; // 0...7 + const int64_t is = tid/16; // 0 or 1 const uint8_t h = x[i].qh[in] >> im; const float d = x[i].d; dst_t * y = yy + i*QK_K + tid; @@ -359,13 +359,13 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t template static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq2_xxs * x = (const block_iq2_xxs *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * q2 = x[i].qs + 4*ib; const uint8_t * aux8 = (const uint8_t *)q2; @@ -383,13 +383,13 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds template static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq2_xs * x = (const block_iq2_xs *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * q2 = x[i].qs + 4*ib; const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); @@ -405,13 +405,13 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst template static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq2_s * x = (const block_iq2_s *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300))); const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; @@ -426,13 +426,13 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq3_xxs * x = (const block_iq3_xxs *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * q3 = x[i].qs + 8*ib; const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; @@ -454,13 +454,13 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq3_s * x = (const block_iq3_s *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * qs = x[i].qs + 8*ib; const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); @@ -480,13 +480,13 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq1_s * x = (const block_iq1_s *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA; const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1); @@ -506,18 +506,18 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq1_m * x = (const block_iq1_m *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * sc = (const uint16_t *)x[i].scales; iq1m_scale_t scale; scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4); + const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4); const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1); const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA; uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; @@ -537,12 +537,12 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL); - const int tid = threadIdx.x; - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 4*il; const uint8_t * q4 = x[ib].qs + 4*il; const float d = (float)x[ib].d; @@ -556,12 +556,12 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst #if QK_K != 64 template static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq4_xs * x = (const block_iq4_xs *)vx; - const int tid = threadIdx.x; - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 4*il; const uint8_t * q4 = x[i].qs + 16*ib + 4*il; const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32); diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 9bda18e581c..fa8f987cf7c 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -28,7 +28,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: - float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols; + float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols; float max_val = -INFINITY; @@ -40,8 +40,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f break; } - const int ix = rowx*ncols + col; - const int iy = rowy*ncols + col; + const int64_t ix = (int64_t)rowx*ncols + col; + const int64_t iy = (int64_t)rowy*ncols + col; const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); @@ -109,7 +109,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f return; } - const int idst = rowx*ncols + col; + const int64_t idst = (int64_t)rowx*ncols + col; dst[idst] = vals[col] * inv_sum; } } From 5167ebdfcaac7c26524d425afe980b5d13be1c46 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Apr 2024 17:55:02 +0300 Subject: [PATCH 036/100] ggml : fix __MSC_VER -> _MSC_VER (llama/6977) ggml-ci --- ggml-impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-impl.h b/ggml-impl.h index 3cb0a599783..c4be87c29e2 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -320,7 +320,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif // defined(__ARM_NEON) -#if defined(__ARM_NEON) && !defined(__MSC_VER) +#if defined(__ARM_NEON) && !defined(_MSC_VER) #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) From 156a33a9904ef1617f08e72c46c2381e21bdfd92 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Apr 2024 12:16:08 +0300 Subject: [PATCH 037/100] ggml : add Flash Attention (llama/5021) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml : add ggml_flash_attn_ext API * ggml : fix GQA support in ggml_flash_attn_ext * ggml : online attention (CPU) * metal : initial implementation * metal : f16 precision * metal : reduce branches * metal : specialize for head size * wip : 8 rows per simd group * wip : 4 rows per simd group * wip : template for rows per warp * metal : parallelize across KV size * metal : parallel reduce across heads * metal : efficient flash_attn_f16 implementation * metal : avoid redundant loads of the attention * metal : scale and mask in matrix form * metal : fix comment * llama : avoid ggml_cast, use F32 query * metal : add parallel reduce version (disabled) * metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments * metal : add tests, fix scaling, support C > 32 * metal : improve precision * ggml : fix f16 mad * metal : minor * metal : support Q > 8 * tests : add ATTN tests * metal : disable buffer allocation logs * tests : more * metal : faster inner loop for C == 32 * metal : fix array initialization * tests : ifdef * ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext * ggml : fix ggml_soft_max mask requirement * cuda : fix soft_max to use correct mask size * cuda : add flash_attn kernel (wip) * metal : optimize softmax for C > 32 * metal : optimize softmax * tests : minor fix * cuda : avoid zeroing fragments * tests : update dims * cuda : fix __hisinf() result check * cuda : avoid warp_reduce for smax * cuda : use int instead of int64_t Noticeably improves performance (thanks to Johannes) * cuda : make loops use the same loop values Thanks Johannes again for the tip * cuda : unroll some of the loops * cuda : avoid __hisinf branches * cuda : use half2 in softmax * cuda : switch to 1 warp for bs > 16 * cuda : speed-up reduce part of the kernel * cuda : unroll Q*K^T loop * cuda : fix -INF block check * cuda : simplify softmax * cuda : fix matrix names * cuda : minor * llama : adapt to F16 KQ_pos * llama : adapt new models to F16 KQ_mask * ggml : fix F16 store (ARM NEON) * llama : fix type of KQ_mask and KQ_pos * ggml : fix CPU soft_max * tests : add hs=256 * cuda : fix build * metal : improve perf via smaller int registers * cuda : adapt soft_max to F16 mask and pos * CUDA: faster FlashAttention, kernel for bs == 1 * 16 cols for Phi-2 * no vec for hs, no hs==256 ncols==32 for Volta * adjust kernel selection logic * 4 warps, 256 stride for all D * no ncols == 64 * Multiple parallel blocks for batch size 1 * fix compile warnings * fix excessive KQ_b loads * fix cmake build * fix KV cache padding, NaN from INFINITY (llama/6438) * llama : flash_attn cparam + fix defrag * server: support flash_attn param * server: bench: enable flash_attn param * CUDA: refactor host code, dyn. par. blocks * fix flash_attn_vec_f16 race condition * flush softmax exp below threshold to 0 * store temp KQ in registers * Calculate KQ as FP32 if KQV has GGML_PREC_F32 * Add __hgt2_mask implementation for CUDA 11 * fix KQ FP32 precision fpr parallel_blocks > 1 * llama-bench : add -fa,--flash-attn arg * metal : add BS=1 kernel for flash attention (llama/6508) * metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel * metal : use F32 attention accumulators * batched-bench : add fattn arg * llama : simplify llama_build_kv_store ggml-ci * llama : adapt build_olmo to changes * ggml : fix arm fp16 store on windows * metal : clean-up * metal : clean-up kernel code * metal : minor * tests : remove benchmarks ggml-ci * ggml : fix avx512 const correctness ggml-ci * ggml : fix soft_max with bias on CPU ggml-ci * common : print --flash-attn in help * ggml : fix num dimensions in ggml_flash_attn_ext * llama : force disable flash attention for incompatible models * ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci * cuda : uint -> uint32_t * cuda : "constexpr dim3" -> "const dim3" ggml-ci * cuda : try to fix __hgt2_mask ggml-ci * ggml : add TODO's for F16/F32 mask/pos support in other backends * llama : replace bool need_kq_pos with use_alibi * llama : prep ALiBi support for BERT models ggml-ci * llama : fix n_batch requirements ggml-ci * cont * server : add help for --flash-attn arg * llama : disable FA for AMD * tests : remove TMP_ATTN_BENCH ggml-ci * llama : support save/load state with FA enabled ggml-ci * ci : add CUDA save-load-state tests ggml-ci * llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci * llama : fix copy-paste errors, add TODO * llama : disallow incompatible states * llama : update llama_state_get_size after v_trans field * metal : remove tmp log * llama : add static reminder for llama_state_get_size * metal : fix max nsg ggml-ci * ci : fix arg order ggml-ci --------- Co-authored-by: Johannes Gäßler Co-authored-by: Pierrick HYMBERT --- ggml-cuda.cu | 6 + ggml-cuda/common.cuh | 40 +- ggml-cuda/fattn.cu | 944 +++++++++++++++++++++++++++++++++++++++++++ ggml-cuda/fattn.cuh | 3 + ggml-cuda/softmax.cu | 46 ++- ggml-kompute.cpp | 7 + ggml-metal.m | 549 +++++++++++++++++-------- ggml-metal.metal | 672 +++++++++++++++++++++++++++++- ggml-sycl.cpp | 6 +- ggml-vulkan.cpp | 5 + ggml.c | 375 ++++++++++++++++- ggml.h | 20 + 12 files changed, 2438 insertions(+), 235 deletions(-) create mode 100644 ggml-cuda/fattn.cu create mode 100644 ggml-cuda/fattn.cuh diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 07534370c34..fa56f9521e4 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -14,6 +14,7 @@ #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/dmmv.cuh" +#include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" @@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { @@ -2293,6 +2295,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_cuda_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -2568,6 +2573,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 481065b2a34..156eba6d1ef 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -142,6 +142,7 @@ #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 +#define CC_AMPERE 800 #define CC_OFFSET_AMD 1000000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) @@ -271,7 +272,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -#ifdef GGML_CUDA_F16 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll @@ -284,7 +284,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { NO_DEVICE_CODE; #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -294,19 +293,26 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -//#else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//} +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} +#if CUDART_VERSION < 12000 +static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { + const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); + const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); + return mask_low | mask_high; +} +#endif // CUDART_VERSION < 12000 #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -391,6 +397,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { } #endif // defined(GGML_USE_HIPBLAS) +#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL + +#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA + // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -404,6 +415,7 @@ struct ggml_cuda_device_info { struct cuda_device_info { int cc; // compute capability + int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu new file mode 100644 index 00000000000..df1e80068b3 --- /dev/null +++ b/ggml-cuda/fattn.cu @@ -0,0 +1,944 @@ +#include "common.cuh" +#include "fattn.cuh" + +#include + +#if FP16_MMA_AVAILABLE +#include +#endif + +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. + +template // D == head size +__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) +static __global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + ne11*ic; + + const int stride_KV = nb11 / sizeof(half); + const int stride_KV2 = nb11 / sizeof(half2); + + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < nwarps*WARP_SIZE); + + __shared__ half KQ[nwarps*WARP_SIZE]; + KQ[tid] = -INFINITY; + half2 * KQ2 = (half2 *) KQ; + + half kqmax = -HALF_MAX_HALF; + half kqsum = 0.0f; + + __shared__ half kqmax_shared[WARP_SIZE]; + __shared__ half kqsum_shared[WARP_SIZE]; + if (threadIdx.y == 0) { + kqmax_shared[threadIdx.x] = -HALF_MAX_HALF; + kqsum_shared[threadIdx.x] = 0.0f; + } + __syncthreads(); + + // Convert Q to half2 and store in registers: + half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE]; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y); + } + + half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. + + const int k_start = parallel_blocks == 1 ? 0 : ip*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + // Calculate KQ tile and keep track of new maximum KQ values: + half kqmax_new = kqmax; +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } + + half2 sum2 = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) { + break; + } + + const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + } + + sum2 = warp_reduce_sum(sum2); + half sum = __low2half(sum2) + __high2half(sum2); + sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f); + kqmax_new = __hmax(kqmax_new, sum); + if (threadIdx.x == 0) { + KQ[i_KQ] = sum; + } + } + + kqmax_new = warp_reduce_max(kqmax_new); + if (threadIdx.x == 0) { + kqmax_shared[threadIdx.y] = kqmax_new; + } + __syncthreads(); + kqmax_new = kqmax_shared[threadIdx.x]; + kqmax_new = warp_reduce_max(kqmax_new); + + const half KQ_max_scale = hexp(kqmax - kqmax_new); + kqmax = kqmax_new; + + const half val = hexp(KQ[tid] - kqmax); + kqsum = kqsum*KQ_max_scale + val; + KQ[tid] = val; + + VKQ *= __half2half2(KQ_max_scale); + + __syncthreads(); + + if (tid < D) { +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } + + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + VKQ += V_k*KQ2[k0/2]; + } + } + + __syncthreads(); + } + + if (tid >= D) { + kqsum = 0.0f; + } + + kqsum = warp_reduce_sum(kqsum); + if (threadIdx.x == 0) { + kqsum_shared[threadIdx.y] = kqsum; + } + __syncthreads(); + kqsum = kqsum_shared[threadIdx.x]; + kqsum = warp_reduce_sum(kqsum); + + if (tid >= D) { + return; + } + + half dst_val = (__low2half(VKQ) + __high2half(VKQ)); + if (parallel_blocks == 1) { + dst_val /= kqsum; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val; + + if (parallel_blocks == 1 || tid != 0) { + return; + } + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum); +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE +} + +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template +__launch_bounds__(nwarps*WARP_SIZE, 1) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; + + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + frag_b Q_b[D/16][ncols/frag_n]; + + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded*kqar; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; + half2 * KQ2 = (half2 *) KQ; + + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + } + } + + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } + } + + __syncthreads(); + + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { + frag_c_KQ KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + } +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + } + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + } + + float KQ_max_new = KQ_max_f[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; + + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } + } + + __syncthreads(); + + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); + } + } + + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); + } + +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + } + } + } + + __syncthreads(); + + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + float dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; + } + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + } + + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; + } + + float2 dst_meta_val; + if (std::is_same::value) { + dst_meta_val.x = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); + } + dst_meta_val.y = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; + } +#else + NO_DEVICE_CODE; +#endif // FP16_MMA_AVAILABLE +} + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst) { +#if FP16_AVAILABLE + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += D * gridDim.y*blockIdx.x; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + __shared__ float2 meta[parallel_blocks]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; + } + + __syncthreads(); + + float kqmax = meta[0].x; +#pragma unroll + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; +#pragma unroll + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + const float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE +} + +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +template void launch_fattn_vec_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_vec_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (parallel_blocks == 1) { + return; + } + + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16_impl( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16; + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream +) { + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + + if (4*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); +} + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F16); + GGML_ASSERT(V->type == GGML_TYPE_F16); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + ggml_cuda_set_device(ctx.device); + + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int32_t precision = KQV->op_params[1]; + + if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + } else { + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + // case 256: + // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + // break; + default: + GGML_ASSERT(false); + break; + } + } + return; + } + + if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; +} diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh new file mode 100644 index 00000000000..ad3ca7a8d8e --- /dev/null +++ b/ggml-cuda/fattn.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index fa8f987cf7c..6ed225999bd 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -1,7 +1,17 @@ #include "softmax.cuh" -template -static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +template +static __device__ __forceinline__ float t2f32(T val) { + return (float) val; +} + +template <> +__device__ float __forceinline__ t2f32(half val) { + return __half2float(val); +} + +template +static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f const int64_t ix = (int64_t)rowx*ncols + col; const int64_t iy = (int64_t)rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f } } -static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +template +static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float * void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const float * src0_d = (const float *)src0->data; - const float * src1_d = src1 ? (const float *)src1->data : nullptr; + const void * src1_d = src1 ? (const void *)src1->data : nullptr; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // positions tensor - float * src2_dd = nullptr; + void * src2_d = nullptr; - ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { - src2_dd = (float *)src2->data; + src2_d = (void *)src2->data; } - soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + + if (use_f16) { + const half * src1_dd = (const half *)src1_d; + const half * src2_dd = (const half *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } else { + const float * src1_dd = (const float *)src1_d; + const float * src2_dd = (const float *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } } diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 407062e6fd4..9a469821d80 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml for (int i = node_start; i < node_end; ++i) { struct ggml_tensor * src0 = gf->nodes[i]->src[0]; struct ggml_tensor * src1 = gf->nodes[i]->src[1]; + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2); struct ggml_tensor * dst = gf->nodes[i]; GGML_ASSERT(dst->data != nullptr); @@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml { float scale; memcpy(&scale, dst->op_params, sizeof(float)); + +#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") + GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); + GGML_ASSERT(src2 == nullptr); + ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); } break; case GGML_OP_DIAG_MASK_INF: diff --git a/ggml-metal.m b/ggml-metal.m index fdba0de85bc..71b8a099b7e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -47,8 +47,10 @@ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, - GGML_METAL_KERNEL_TYPE_SOFT_MAX, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, @@ -178,6 +180,14 @@ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -444,7 +454,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ } /* - GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ (int) kernel->pipeline.threadExecutionWidth); \ */ @@ -460,173 +470,183 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ return NULL; \ } \ } else { \ - GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \ + GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ } // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } [metal_library release]; @@ -746,6 +766,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -1341,20 +1362,33 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); + int nth = 32; // SIMD width id pipeline = nil; + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { nth *= 2; } - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } } else { while (nth < ne00 && nth < 1024) { nth *= 2; } - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } } float scale; @@ -2518,6 +2552,161 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + GGML_ASSERT(src3); + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = nil; + + bool use_vec_kernel = false; + + if (ne01 >= 4 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } else { + use_vec_kernel = true; + + switch (ne00) { + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -2721,10 +2910,13 @@ GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buff UNUSED(buft); } -static void ggml_backend_metal_log_allocated_size(id device) { +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2734,10 +2926,15 @@ static void ggml_backend_metal_log_allocated_size(id device) { GGML_METAL_LOG_INFO("\n"); } } else { - GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -2771,8 +2968,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - ggml_backend_metal_log_allocated_size(device); + //ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } @@ -2859,7 +3055,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2882,7 +3078,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -2891,8 +3088,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, } } - ggml_backend_metal_log_allocated_size(device); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 7f37c17d668..4d710b04fa2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -359,11 +359,12 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } +template kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device const float * src2, - device float * dst, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -382,10 +383,10 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device const float * ppos = src2 != src0 ? src2 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); float slope = 0.0f; @@ -463,11 +464,12 @@ kernel void kernel_soft_max( } } +template kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device const float * src2, - device float * dst, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -486,10 +488,10 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; float slope = 0.0f; @@ -506,7 +508,7 @@ kernel void kernel_soft_max_4( float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -532,7 +534,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -569,6 +571,14 @@ kernel void kernel_soft_max_4( } } +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, @@ -2091,6 +2101,632 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) { + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } + + // zero out lo + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TF + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2/rk2; + const short ik3 = iq3/rk3; + + // v indices + const short iv2 = iq2/rv2; + const short iv3 = iq3/rv3; + + // load the queries from shared memory into local memory + simdgroup_half8x8 mq[D8]; + + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, T); + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + // prepare diagonal scale matrix + simdgroup_float8x8 mscale(scale); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + // mqk = mqk*scale + mask + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + } + } + + // used to detect blocks full of -INF + float smax = -INFINITY; + + // online softmax + { + float ms[Q]; + + for (short j = 0; j < Q; ++j) { + const short p = tiisg; + + const float m = M[j]; + const float s = ss[j*TF + p]; + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TF + p] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*TF + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + // O = diag(ms)*O + { + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + C, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); + + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (short sg = 1; sg < nsg; ++sg) { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; + + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; + + M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; + + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; + + simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + const float S = ss[j*TF + 0]; + + for (short i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + float4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; + + mqk += (float4) (mq[i] * mk); + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask + if (tiisg == 0) { + float4 mm = (float4) mp4[ic/4 + cc]; + mqk = mqk*scale + mm; + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const float m = M; + const float s = ss[p]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 2b76b3ebd64..57fe4ea3d4a 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -14744,7 +14744,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + const ggml_tensor * src2 = dst->src[2]; + +#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -14760,7 +14765,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, float * src2_dd = nullptr; sycl_pool_alloc src2_f; - ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 1736ab7361c..f712cdd5a90 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_SOFT_MAX: +#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_soft_max_f32; } diff --git a/ggml.c b/ggml.c index 3bddcdbf28a..00f3e170a16 100644 --- a/ggml.c +++ b/ggml.c @@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL @@ -1046,7 +1046,7 @@ do { \ // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F // so F16C guard isn't required -#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x))) +#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) @@ -1144,7 +1144,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -2001,6 +2061,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "SSM_CONV", @@ -2027,7 +2088,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2091,6 +2152,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "ssm_conv(x)", @@ -2117,7 +2179,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4575,6 +4637,8 @@ struct ggml_tensor * ggml_mul_mat( void ggml_mul_mat_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; ggml_set_op_params_i32(a, 0, prec_i32); @@ -5413,17 +5477,23 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[0] == a->ne[0]); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (pos) { GGML_ASSERT(ggml_is_vector(pos)); - GGML_ASSERT(pos->type == GGML_TYPE_F32); + GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32); GGML_ASSERT(pos->ne[0] == a->ne[0]); } + if (pos && mask) { + GGML_ASSERT(pos->type == mask->type); + } + if (max_bias > 0.0f) { GGML_ASSERT(pos); } @@ -6232,6 +6302,59 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -12317,7 +12440,7 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t ne11 = src1 ? src1->ne[1] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 @@ -12340,19 +12463,31 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - float * pos = src2 ? (float *) src2->data : src0->data; + ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data; + float * pos_f32 = src2 ? (float *) src2->data : src0->data; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); for (int i1 = ir0; i1 < ir1; i1++) { float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); - if (mp) { - ggml_vec_acc_f32(nc, wp, mp); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += mp_f32[i]; + } + } } // ALiBi bias @@ -12360,8 +12495,14 @@ static void ggml_compute_forward_soft_max_f32( const uint32_t h = (i1/ne01)%ne02; // head const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - for (int i = 0; i < nc; i++) { - wp[i] = wp[i] + slope*pos[i]; + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*pos_f32[i]; + } } } @@ -14631,6 +14772,198 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + if (params->type == GGML_TASK_TYPE_INIT) { + return; + } + + if (params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + // loop over n_batch and n_head + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float S = 0.0f; + float M = -INFINITY; + + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); + + memset(V16, 0, D*sizeof(ggml_fp16_t)); + + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + if (mv == -INFINITY) { + continue; + } + + float s; + + // convert Q to F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + + ggml_vec_dot_f16(D, + &s, 0, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, + Q16, 0, 1); + + s = s*scale + mv; + + const float Mold = M; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); + } + + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); + + S = S*ms + vs; + } + + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; + } + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: + case GGML_PREC_F32: + { + // uses F32 accumulators + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -16442,6 +16775,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor); @@ -17454,6 +17791,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -18231,6 +18569,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -18634,6 +18973,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index 06cafbd78ba..d90ba8ed664 100644 --- a/ggml.h +++ b/ggml.h @@ -475,6 +475,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_SSM_CONV, @@ -1731,6 +1732,25 @@ extern "C" { struct ggml_tensor * v, bool masked); +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, From 6c39ea46b6245762493e5348a2dc0162a511b88a Mon Sep 17 00:00:00 2001 From: Kevin Gibbons Date: Tue, 30 Apr 2024 02:34:50 -0700 Subject: [PATCH 038/100] metal : log more info on error (llama/6987) --- ggml-metal.m | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/ggml-metal.m b/ggml-metal.m index 71b8a099b7e..43752f7295d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2794,6 +2794,45 @@ static enum ggml_status ggml_metal_graph_compute( MTLCommandBufferStatus status = [command_buffer status]; if (status != MTLCommandBufferStatusCompleted) { GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + MTLCommandBufferError error_code = [command_buffer error].code; + switch (error_code) { + case MTLCommandBufferErrorNone: + GGML_METAL_LOG_INFO("no error code reported\n"); + break; + case MTLCommandBufferErrorTimeout: + GGML_METAL_LOG_INFO("timeout\n"); + break; + case MTLCommandBufferErrorPageFault: + GGML_METAL_LOG_INFO("unserviceable page fault\n"); + break; + case MTLCommandBufferErrorOutOfMemory: + GGML_METAL_LOG_INFO("out of memory\n"); + break; + case MTLCommandBufferErrorInvalidResource: + GGML_METAL_LOG_INFO("invalid reference to resource\n"); + break; + case MTLCommandBufferErrorMemoryless: + GGML_METAL_LOG_INFO("GPU ran out of one or more of its internal resources that support memoryless render pass attachments\n"); + break; + case MTLCommandBufferErrorDeviceRemoved: + GGML_METAL_LOG_INFO("device removed\n"); + break; + case MTLCommandBufferErrorStackOverflow: + GGML_METAL_LOG_INFO("kernel function of tile shader used too many stack frames\n"); + break; + case MTLCommandBufferErrorAccessRevoked: + GGML_METAL_LOG_INFO("access to device revoked by system\n"); + break; + case MTLCommandBufferErrorInternal: + GGML_METAL_LOG_INFO("internal error\n"); + break; + default: + GGML_METAL_LOG_INFO("unknown error %lu\n", error_code); + break; + } + } + return GGML_STATUS_FAILED; } } From 1bce67999d6d9f5dee074519273301ef63c6abc4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Apr 2024 15:52:21 +0300 Subject: [PATCH 039/100] metal : remove deprecated error code (llama/7008) --- ggml-metal.m | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 43752f7295d..160d5c399eb 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2815,9 +2815,9 @@ static enum ggml_status ggml_metal_graph_compute( case MTLCommandBufferErrorMemoryless: GGML_METAL_LOG_INFO("GPU ran out of one or more of its internal resources that support memoryless render pass attachments\n"); break; - case MTLCommandBufferErrorDeviceRemoved: - GGML_METAL_LOG_INFO("device removed\n"); - break; + //case MTLCommandBufferErrorDeviceRemoved: + // GGML_METAL_LOG_INFO("device removed\n"); + // break; case MTLCommandBufferErrorStackOverflow: GGML_METAL_LOG_INFO("kernel function of tile shader used too many stack frames\n"); break; From c754494fddf81716f034230b180cc4233a9232d3 Mon Sep 17 00:00:00 2001 From: Kevin Gibbons Date: Tue, 30 Apr 2024 08:14:02 -0700 Subject: [PATCH 040/100] switch to using localizedDescription (llama/7010) --- ggml-metal.m | 38 ++------------------------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 160d5c399eb..ee579a4229b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2795,42 +2795,8 @@ static enum ggml_status ggml_metal_graph_compute( if (status != MTLCommandBufferStatusCompleted) { GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); if (status == MTLCommandBufferStatusError) { - MTLCommandBufferError error_code = [command_buffer error].code; - switch (error_code) { - case MTLCommandBufferErrorNone: - GGML_METAL_LOG_INFO("no error code reported\n"); - break; - case MTLCommandBufferErrorTimeout: - GGML_METAL_LOG_INFO("timeout\n"); - break; - case MTLCommandBufferErrorPageFault: - GGML_METAL_LOG_INFO("unserviceable page fault\n"); - break; - case MTLCommandBufferErrorOutOfMemory: - GGML_METAL_LOG_INFO("out of memory\n"); - break; - case MTLCommandBufferErrorInvalidResource: - GGML_METAL_LOG_INFO("invalid reference to resource\n"); - break; - case MTLCommandBufferErrorMemoryless: - GGML_METAL_LOG_INFO("GPU ran out of one or more of its internal resources that support memoryless render pass attachments\n"); - break; - //case MTLCommandBufferErrorDeviceRemoved: - // GGML_METAL_LOG_INFO("device removed\n"); - // break; - case MTLCommandBufferErrorStackOverflow: - GGML_METAL_LOG_INFO("kernel function of tile shader used too many stack frames\n"); - break; - case MTLCommandBufferErrorAccessRevoked: - GGML_METAL_LOG_INFO("access to device revoked by system\n"); - break; - case MTLCommandBufferErrorInternal: - GGML_METAL_LOG_INFO("internal error\n"); - break; - default: - GGML_METAL_LOG_INFO("unknown error %lu\n", error_code); - break; - } + NSString * error_code = [command_buffer error].localizedDescription; + GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]); } return GGML_STATUS_FAILED; From 11c1df0436d0bccad6148212da4e26db3ae23149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 1 May 2024 14:46:37 +0200 Subject: [PATCH 041/100] CUDA: CUDART < 11.7 workaround for __hmax, __hmax2 (llama/7019) --- ggml-cuda/common.cuh | 45 +++++++++++++++++++++++++++++++++++++++----- ggml-cuda/fattn.cu | 6 +++--- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 156eba6d1ef..b2627b7b4b7 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -137,7 +137,8 @@ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) #define WARP_SIZE 32 -#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) +#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) +#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products @@ -293,20 +294,54 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } +static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + +#if CUDART_VERSION >= CUDART_HMAX + return __hmax(a, b); +#else + return __half2float(a) > __half2float(b) ? a : b; +#endif // CUDART_VERSION >= CUDART_HMAX + +#else + GGML_UNUSED(a); + GGML_UNUSED(b); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX +} +static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + +#if CUDART_VERSION >= CUDART_HMAX + return __hmax2(a, b); +#else + half2 ret; + reinterpret_cast(ret.x) = __low2float(a) > __low2float(b) ? __low2half(a) : __low2half(b); + reinterpret_cast(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b); + return ret; +#endif // CUDART_VERSION >= CUDART_HMAX + +#else + GGML_UNUSED(a); + GGML_UNUSED(b); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX +} + static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { - x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); } return x; #else GGML_UNUSED(x); NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#if CUDART_VERSION < 12000 +#if CUDART_VERSION < CUDART_HMASK static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index df1e80068b3..c8a11d17334 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -116,7 +116,7 @@ static __global__ void flash_attn_vec_ext_f16( sum2 = warp_reduce_sum(sum2); half sum = __low2half(sum2) + __high2half(sum2); sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f); - kqmax_new = __hmax(kqmax_new, sum); + kqmax_new = ggml_cuda_hmax(kqmax_new, sum); if (threadIdx.x == 0) { KQ[i_KQ] = sum; } @@ -416,9 +416,9 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } - KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; KQ_max_scale_h2[j0/nwarps] = h2exp(diff); const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); From 9b84195225480516e77101830f6abd5b35340f32 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 4 May 2024 18:56:22 +0200 Subject: [PATCH 042/100] gguf-split: add --no-tensor-first-split (llama/7072) --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 00f3e170a16..84b76fcb969 100644 --- a/ggml.c +++ b/ggml.c @@ -21210,7 +21210,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p } // read the tensor infos - { + if (ctx->header.n_tensors > 0) { ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info)); for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { From b5521fea1988e8110d9fdc1c17f0f98abff0346a Mon Sep 17 00:00:00 2001 From: William Tambellini Date: Mon, 6 May 2024 11:12:14 -0700 Subject: [PATCH 043/100] Add an option to build without CUDA VMM (llama/7067) Add an option to build ggml cuda without CUDA VMM resolves https://github.com/ggerganov/llama.cpp/issues/6889 https://forums.developer.nvidia.com/t/potential-nvshmem-allocated-memory-performance-issue/275416/4 --- ggml-cuda.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index fa56f9521e4..8739baa2a79 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -113,7 +113,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -259,7 +259,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -356,7 +356,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { #endif // !defined(GGML_USE_HIPBLAS) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } From 1ae1a9cd56fbe5cf5390e35da8e0fdc898589f2b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 8 May 2024 09:14:50 +0300 Subject: [PATCH 044/100] metal : fix unused warning --- ggml-metal.metal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 4d710b04fa2..b67d1882f00 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2182,7 +2182,7 @@ kernel void kernel_flash_attn_ext_f16( const short D4 = D/4; const short D8 = D/8; - const short Q8 = Q/8; + //const short Q8 = Q/8; const short NW = N_SIMDWIDTH; const short SH = (C + Q); // shared memory per simdgroup in (half) From a2ad810118bf31811d66b2379e308d979d584dc3 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Wed, 8 May 2024 02:30:09 -0400 Subject: [PATCH 045/100] ggml : introduce bfloat16 support (llama/6412) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduce bfloat16 support Many models on Hugging Face (e.g. Mistral, TinyLLaMA) use bfloat16 as their canonical floating point format. ┌sign │ │ ┌exponent │ │ │ │ ┌mantissa │ │ │ │┌──┴───┐┌─┴───┐ 0b0000000000000000 brain16 This encoding has the same number of exponent bits as float32. That makes conversion relatively straightforward, even in the absence of hardware support. For example, converting brain16 to binary32 means simply shifting 16 bits to the left. ┌sign │ │ ┌exponent │ │ │ │ ┌mantissa │ │ │ │┌──┴───┐┌─┴───────────────────┐ 0b00000000000000000000000000000000 IEEE binary32 The issue is that converting bf16 to fp16 can result in information loss. Only 13% of bf16 numbers can be precisely represented in fp16 which in practice ends up being 99.71% of Mistral 7b v0.2's weights however there is currently no way other than fp32 to get the others ┌sign │ │ ┌exponent │ │ │ │ ┌mantissa │ │ │ │┌─┴─┐┌─┴──────┐ 0b0000000000000000 IEEE binary16 This change fixes that, by adding a bf16 data type to GGML. Support for CPU inference has been implemented along with optimizations for the AVX2, AVX512, and AVX512BF16 ISAs. Perplexity on Mistral 7b 0.2 improves somewhere around -0.0024 to -0.0046 compared to using fp16 * Remove GGML code that's not needed * Minimize the GGML API surface area for BF16 * Remove bf16 luts * Make the GGML header look nicer * Fix documentation * Apply ggerganov's fixes for test-backend-ops * Add BF16 code for new ggml_validate_row_data() function --- ggml-impl.h | 77 ++++ ggml-metal.m | 2 +- ggml-quants.c | 18 + ggml.c | 1177 +++++++++++++++++++++++++++++++++++++++++++++---- ggml.h | 22 +- 5 files changed, 1200 insertions(+), 96 deletions(-) diff --git a/ggml-impl.h b/ggml-impl.h index c4be87c29e2..59684fa81f0 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -17,6 +17,83 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) +/** + * Converts brain16 to float32. + * + * The bfloat16 floating point format has the following structure: + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌──┴───┐┌─┴───┐ + * 0b0000000000000000 brain16 + * + * Since bf16 has the same number of exponent bits as a 32bit float, + * encoding and decoding numbers becomes relatively straightforward. + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌──┴───┐┌─┴───────────────────┐ + * 0b00000000000000000000000000000000 IEEE binary32 + * + * For comparison, the standard fp16 format has fewer exponent bits. + * + * ┌sign + * │ + * │ ┌exponent + * │ │ + * │ │ ┌mantissa + * │ │ │ + * │┌─┴─┐┌─┴──────┐ + * 0b0000000000000000 IEEE binary16 + * + * @see IEEE 754-2008 + */ +static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + +/** + * Converts float32 to brain16. + * + * This function is binary identical to AMD Zen4 VCVTNEPS2BF16. + * Subnormals shall be flushed to zero, and NANs will be quiet. + * This code should vectorize nicely if using modern compilers. + */ +static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) { + ggml_bf16_t h; + union { + float f; + uint32_t i; + } u; + u.f = s; + if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */ + h.bits = (u.i >> 16) | 64; /* force to quiet */ + return h; + } + if (!(u.i & 0x7f800000)) { /* subnormal */ + h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */ + return h; + } + h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16; + return h; +} + +#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x) +#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x) + #ifdef __cplusplus extern "C" { #endif diff --git a/ggml-metal.m b/ggml-metal.m index ee579a4229b..4d4b9717812 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -806,7 +806,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_DIAG_MASK_INF: case GGML_OP_GET_ROWS: { - return op->ne[3] == 1; + return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1; } default: return false; diff --git a/ggml-quants.c b/ggml-quants.c index 15370f1b515..00334c5feb3 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -12456,6 +12456,24 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte const size_t nb = nbytes/ggml_type_size(type); switch (type) { + case GGML_TYPE_BF16: + { + int nans = 0; + int infs = 0; + const unsigned short * f = (const unsigned short *) data; + for (size_t i = 0; i < nb; ++i) { + nans += (f[i] & 0x7fff) > 0x7f80; + infs += (f[i] & 0x7fff) == 0x7f80; + } + if (nans) { + fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb); + return false; + } + if (infs) { + fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb); + return false; + } + } break; case GGML_TYPE_F16: { const ggml_fp16_t * f = (const ggml_fp16_t *) data; diff --git a/ggml.c b/ggml.c index 84b76fcb969..118d3f541f4 100644 --- a/ggml.c +++ b/ggml.c @@ -322,7 +322,7 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16]; // precomputed f32 table for f16 (256 KB) (ggml-impl.h) float ggml_table_f32_f16[1 << 16]; -const char * ggml_status_to_string(enum ggml_status status) { +GGML_CALL const char * ggml_status_to_string(enum ggml_status status) { switch (status) { case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)"; case GGML_STATUS_FAILED: return "GGML status: error (operation failed)"; @@ -333,16 +333,26 @@ const char * ggml_status_to_string(enum ggml_status status) { return "GGML status: unknown"; } -// note: do not use these inside ggml.c -// these are meant to be used via the ggml.h API float ggml_fp16_to_fp32(ggml_fp16_t x) { +#define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml return GGML_FP16_TO_FP32(x); } ggml_fp16_t ggml_fp32_to_fp16(float x) { +#define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml return GGML_FP32_TO_FP16(x); } +float ggml_bf16_to_fp32(ggml_bf16_t x) { +#define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml + return GGML_BF16_TO_FP32(x); // it just left shifts +} + +ggml_bf16_t ggml_fp32_to_bf16(float x) { +#define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml + return GGML_FP32_TO_BF16(x); +} + void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) { for (int64_t i = 0; i < n; i++) { y[i] = GGML_FP16_TO_FP32(x[i]); @@ -368,6 +378,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) { } } +void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) { + int64_t i = 0; +#if defined(__AVX512F__) + for (; i + 16 <= n; i += 16) { + _mm512_storeu_ps(y + i, + _mm512_castsi512_ps( + _mm512_slli_epi32( + _mm512_cvtepu16_epi32( + _mm256_loadu_si256( + (const __m256i *)(x + i))), + 16))); + } +#elif defined(__AVX2__) + for (; i + 8 <= n; i += 8) { + _mm256_storeu_ps(y + i, + _mm256_castsi256_ps( + _mm256_slli_epi32( + _mm256_cvtepu16_epi32( + _mm_loadu_si128( + (const __m128i *)(x + i))), + 16))); + } +#endif + for (; i < n; i++) { + y[i] = GGML_BF16_TO_FP32(x[i]); + } +} + +void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) { + int i = 0; +#if defined(__AVX512BF16__) + for (; i + 32 <= n; i += 32) { + _mm512_storeu_ps( + (__m512 *)(y + i), + (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16), + _mm512_loadu_ps(x + i))); + } +#endif + for (; i < n; i++) { + y[i] = GGML_FP32_TO_BF16(x[i]); + } +} + bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) { return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0; } @@ -503,6 +556,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc); static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc); +static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc); static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_I8] = { @@ -845,6 +899,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_K), .is_quantized = true, .from_float = quantize_row_q8_K, + }, + [GGML_TYPE_BF16] = { + .type_name = "bf16", + .blck_size = 1, + .type_size = sizeof(ggml_bf16_t), + .is_quantized = false, + .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, + .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, + .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, + .vec_dot_type = GGML_TYPE_BF16, + .nrows = 1, } }; @@ -1480,6 +1546,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } @@ -1498,7 +1566,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * UNUSED(by); UNUSED(bs); -#ifdef GGML_SIMD +#if defined(GGML_SIMD) float sumf = 0.0f; const int np = (n & ~(GGML_F32_STEP - 1)); @@ -1534,6 +1602,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * *s = sumf; } +static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + int i = 0; + ggml_float sumf = 0; + +#if defined(__AVX512BF16__) + __m512 c1 = _mm512_setzero_ps(); + __m512 c2 = _mm512_setzero_ps(); + for (; i + 64 <= n; i += 64) { + c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)), + (__m512bh)_mm512_loadu_ps((const float *)(y + i))); + c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)), + (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32))); + } + sumf += (ggml_float)_mm512_reduce_add_ps(c1); + sumf += (ggml_float)_mm512_reduce_add_ps(c2); + +#elif defined(__AVX512F__) +#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16)) + __m512 c1 = _mm512_setzero_ps(); + __m512 c2 = _mm512_setzero_ps(); + for (; i + 32 <= n; i += 32) { + c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1); + c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2); + } + sumf += (ggml_float)_mm512_reduce_add_ps(c1); + sumf += (ggml_float)_mm512_reduce_add_ps(c2); + +#undef LOAD +#elif defined(__AVX2__) +#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)) + __m256 c1 = _mm256_setzero_ps(); + __m256 c2 = _mm256_setzero_ps(); + __m256 c3 = _mm256_setzero_ps(); + __m256 c4 = _mm256_setzero_ps(); + for (; i + 32 <= n; i += 32) { + c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1); + c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2); + c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3); + c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4); + } + __m128 g; + c1 = _mm256_add_ps(_mm256_add_ps(c1, c3), + _mm256_add_ps(c2, c4)); + g = _mm_add_ps(_mm256_extractf128_ps(c1, 1), + _mm256_castps256_ps128(c1)); + g = _mm_add_ps(g, _mm_movehl_ps(g, g)); + g = _mm_add_ss(g, _mm_movehdup_ps(g)); + sumf += (ggml_float)_mm_cvtss_f32(g); + +#undef LOAD +#endif + + for (; i < n; ++i) { + sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) * + GGML_BF16_TO_FP32(y[i])); + } + *s = sumf; +} + static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -1968,6 +2100,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_ *s = sum; } +inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += GGML_BF16_TO_FP32(x[i]); + } + *s = sum; +} + inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE float max = -INFINITY; @@ -2379,7 +2519,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) { // figure out which node we're on uint current_cpu; int getcpu_ret = 0; -#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) +#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__) getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node); #else // old glibc doesn't have a wrapper for this call. Fall back on direct syscall @@ -2590,6 +2730,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { switch (ftype) { case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; + case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; @@ -2731,15 +2872,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { { const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - ggml_fp16_t ii; for (int i = 0; i < (1 << 16); ++i) { - uint16_t ui = i; - memcpy(&ii, &ui, sizeof(ii)); - const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); + union { + uint16_t u16; + ggml_fp16_t fp16; + } u = {i}; + float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16); ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); - ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); + ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); } const uint64_t t_end = ggml_time_us(); UNUSED(t_end); @@ -3203,6 +3345,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); } } break; + case GGML_TYPE_BF16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value)); + } + } break; case GGML_TYPE_F32: { assert(tensor->nb[0] == sizeof(float)); @@ -3255,6 +3404,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); } } break; + case GGML_TYPE_BF16: + { + assert(tensor->nb[0] == sizeof(ggml_bf16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value)); + } + } break; case GGML_TYPE_F32: { assert(tensor->nb[0] == sizeof(float)); @@ -3322,6 +3478,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); } + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]); + } case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); @@ -3364,6 +3525,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value); + } break; case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); @@ -3387,6 +3553,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i return ((int32_t *) data)[0]; case GGML_TYPE_F16: return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_BF16: + return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]); case GGML_TYPE_F32: return ((float *) data)[0]; default: @@ -3415,6 +3583,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, { ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); } break; + case GGML_TYPE_BF16: + { + ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value); + } break; case GGML_TYPE_F32: { ((float *)(data))[0] = value; @@ -3453,6 +3625,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); } + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]); + } case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); @@ -3495,6 +3672,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value); + } break; case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); @@ -3518,6 +3700,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, return ((int32_t *) data)[0]; case GGML_TYPE_F16: return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_BF16: + return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]); case GGML_TYPE_F32: return ((float *) data)[0]; default: @@ -3546,6 +3730,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, { ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); } break; + case GGML_TYPE_BF16: + { + ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value); + } break; case GGML_TYPE_F32: { ((float *)(data))[0] = value; @@ -3740,7 +3928,11 @@ static struct ggml_tensor * ggml_add_cast_impl( // TODO: support less-strict constraint // GGML_ASSERT(ggml_can_repeat(b, a)); GGML_ASSERT(ggml_can_repeat_rows(b, a)); - GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16 + + // currently only supported for quantized input and f16 + GGML_ASSERT(ggml_is_quantized(a->type) || + a->type == GGML_TYPE_F16 || + a->type == GGML_TYPE_BF16); bool is_node = false; @@ -7231,8 +7423,8 @@ static void ggml_compute_forward_dup_same_cont( ((char *) src0->data + ie0*nb00), (ie1 - ie0) * ggml_type_size(src0->type)); } - } + static void ggml_compute_forward_dup_f16( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -7506,7 +7698,7 @@ static void ggml_compute_forward_dup_f16( } } -static void ggml_compute_forward_dup_f32( +static void ggml_compute_forward_dup_bf16( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -7554,10 +7746,11 @@ static void ggml_compute_forward_dup_f32( return; } + // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + if (ggml_is_contiguous(dst)) { - // TODO: simplify - if (nb00 == sizeof(float)) { - if (dst->type == GGML_TYPE_F32) { + if (nb00 == sizeof(ggml_bf16_t)) { + if (dst->type == GGML_TYPE_BF16) { size_t id = 0; const size_t rs = ne00 * nb00; char * dst_ptr = (char *) dst->data; @@ -7573,8 +7766,43 @@ static void ggml_compute_forward_dup_f32( id += rs * (ne01 - ir1); } } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00])); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } } else if (type_traits[dst->type].from_float) { ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; size_t id = 0; size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); @@ -7584,8 +7812,13 @@ static void ggml_compute_forward_dup_f32( for (int i02 = 0; i02 < ne02; i02++) { id += rs * ir0; for (int i01 = ir0; i01 < ir1; i01++) { - const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - quantize_row_q(src0_ptr, dst_ptr + id, ne00); + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); id += rs; } id += rs * (ne01 - ir1); @@ -7606,7 +7839,25 @@ static void ggml_compute_forward_dup_f32( id += ne00 * ir0; for (int i01 = ir0; i01 < ir1; i01++) { for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_BF16) { + size_t id = 0; + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); dst_ptr[id] = *src0_ptr; id++; @@ -7624,9 +7875,9 @@ static void ggml_compute_forward_dup_f32( id += ne00 * ir0; for (int i01 = ir0; i01 < ir1; i01++) { for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr)); id++; } } @@ -7637,18 +7888,16 @@ static void ggml_compute_forward_dup_f32( GGML_ASSERT(false); // TODO: implement } } - return; } // dst counters - int64_t i10 = 0; int64_t i11 = 0; int64_t i12 = 0; int64_t i13 = 0; - if (dst->type == GGML_TYPE_F32) { + if (dst->type == GGML_TYPE_BF16) { for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { i10 += ne00 * ir0; @@ -7669,7 +7918,59 @@ static void ggml_compute_forward_dup_f32( const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - memcpy(dst_ptr, src0_ptr, sizeof(float)); + memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t)); + + if (++i10 == ne00) { + i10 = 0; + if (++i11 == ne01) { + i11 = 0; + if (++i12 == ne02) { + i12 = 0; + if (++i13 == ne03) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr)); if (++i10 == ne0) { i10 = 0; @@ -7700,7 +8001,7 @@ static void ggml_compute_forward_dup_f32( } } } - } else if (dst->type == GGML_TYPE_F16) { + } else if (dst->type == GGML_TYPE_F32) { for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { i10 += ne00 * ir0; @@ -7721,7 +8022,7 @@ static void ggml_compute_forward_dup_f32( const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr); + *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr); if (++i10 == ne0) { i10 = 0; @@ -7757,31 +8058,27 @@ static void ggml_compute_forward_dup_f32( } } -// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy. -static void ggml_compute_forward_dup_bytes( +static void ggml_compute_forward_dup_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - GGML_ASSERT(src0->type == dst->type); if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } - if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { - ggml_compute_forward_dup_same_cont(params, dst); - return; - } - - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS - const size_t type_size = ggml_type_size(src0->type); const int ith = params->ith; // thread index const int nth = params->nth; // number of threads + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { + ggml_compute_forward_dup_same_cont(params, dst); + return; + } // parallelize by rows const int nr = ne01; @@ -7793,9 +8090,9 @@ static void ggml_compute_forward_dup_bytes( if (src0->type == dst->type && ne00 == ne0 && - nb00 == type_size && nb0 == type_size) { + nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { // copy by rows - const size_t rs = ne00 * type_size; + const size_t rs = ne00*nb00; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ir0; i01 < ir1; i01++) { @@ -7810,41 +8107,366 @@ static void ggml_compute_forward_dup_bytes( } if (ggml_is_contiguous(dst)) { - size_t id = 0; - char * dst_ptr = (char *) dst->data; - const size_t rs = ne00 * type_size; - - if (nb00 == type_size) { - // src0 is contigous on first dimension, copy by rows - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int64_t i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, type_size); + // TODO: simplify + if (nb00 == sizeof(float)) { + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; - id += type_size; + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; } + id += rs * (ne01 - ir1); } - id += rs * (ne01 - ir1); } - } - } + } else if (type_traits[dst->type].from_float) { + ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + quantize_row_q(src0_ptr, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_BF16) { + size_t id = 0; + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(float)); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_BF16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } +} + +// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy. +static void ggml_compute_forward_dup_bytes( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(src0->type == dst->type); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { + ggml_compute_forward_dup_same_cont(params, dst); + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS; + + const size_t type_size = ggml_type_size(src0->type); + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == type_size && nb0 == type_size) { + // copy by rows + const size_t rs = ne00 * type_size; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (ggml_is_contiguous(dst)) { + size_t id = 0; + char * dst_ptr = (char *) dst->data; + const size_t rs = ne00 * type_size; + + if (nb00 == type_size) { + // src0 is contigous on first dimension, copy by rows + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, type_size); + + id += type_size; + } + } + id += rs * (ne01 - ir1); + } + } + } return; } @@ -7925,6 +8547,10 @@ static void ggml_compute_forward_dup( { ggml_compute_forward_dup_f16(params, dst); } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_dup_bf16(params, dst); + } break; case GGML_TYPE_F32: { ggml_compute_forward_dup_f32(params, dst); @@ -8018,17 +8644,96 @@ static void ggml_compute_forward_add_f32( float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; + } + } + } +} + +static void ggml_compute_forward_add_f16_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + if (dst->type == GGML_TYPE_F32) { + GGML_ASSERT( nb0 == sizeof(float)); + } + else { + GGML_ASSERT(dst->type == GGML_TYPE_F16); + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + } + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + if (dst->type == GGML_TYPE_F16) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } else { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + } } } } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } } -static void ggml_compute_forward_add_f16_f32( +static void ggml_compute_forward_add_bf16_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8048,18 +8753,18 @@ static void ggml_compute_forward_add_f16_f32( GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_BF16); GGML_ASSERT(src1->type == GGML_TYPE_F32); if (dst->type == GGML_TYPE_F32) { GGML_ASSERT( nb0 == sizeof(float)); } else { - GGML_ASSERT(dst->type == GGML_TYPE_F16); - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); } - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -8069,19 +8774,19 @@ static void ggml_compute_forward_add_f16_f32( const int ir1 = MIN(ir0 + dr, nr); if (nb10 == sizeof(float)) { - if (dst->type == GGML_TYPE_F16) { + if (dst->type == GGML_TYPE_BF16) { for (int ir = ir0; ir < ir1; ++ir) { // src0, src1 and dst are same shape => same indices const int i3 = ir/(ne2*ne1); const int i2 = (ir - i3*ne2*ne1)/ne1; const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); } } } else { @@ -8092,11 +8797,11 @@ static void ggml_compute_forward_add_f16_f32( const int i1 = (ir - i3*ne2*ne1 - i2*ne1); float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; } } } @@ -8163,6 +8868,62 @@ static void ggml_compute_forward_add_f16_f16( } } +static void ggml_compute_forward_add_bf16_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(ggml_bf16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i])); + } + } + } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } +} + static void ggml_compute_forward_add_q_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8272,6 +9033,18 @@ static void ggml_compute_forward_add( GGML_ASSERT(false); } } break; + case GGML_TYPE_BF16: + { + if (src1->type == GGML_TYPE_BF16) { + ggml_compute_forward_add_bf16_bf16(params, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add_bf16_f32(params, dst); + } + else { + GGML_ASSERT(false); + } + } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8530,6 +9303,110 @@ static void ggml_compute_forward_add1_q_f32( } } +static void ggml_compute_forward_add1_bf16_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1_bf16_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // scalar to add + const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v); + } + } +} + static void ggml_compute_forward_add1( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8554,6 +9431,18 @@ static void ggml_compute_forward_add1( GGML_ASSERT(false); } } break; + case GGML_TYPE_BF16: + { + if (src1->type == GGML_TYPE_BF16) { + ggml_compute_forward_add1_bf16_bf16(params, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add1_bf16_f32(params, dst); + } + else { + GGML_ASSERT(false); + } + } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8682,6 +9571,7 @@ static void ggml_compute_forward_acc( ggml_compute_forward_acc_f32(params, dst); } break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -9203,6 +10093,40 @@ static void ggml_compute_forward_sum_f16( ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); } +static void ggml_compute_forward_sum_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(ggml_bf16_t)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + float sum = 0; + float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_bf16_ggf(ne00, + &row_sum, + (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; + } + } + } + ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum); +} + static void ggml_compute_forward_sum( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -9218,6 +10142,10 @@ static void ggml_compute_forward_sum( { ggml_compute_forward_sum_f16(params, dst); } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_sum_bf16(params, dst); + } break; default: { GGML_ASSERT(false); @@ -9492,6 +10420,7 @@ static void ggml_compute_forward_repeat( switch (src0->type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_I16: { ggml_compute_forward_repeat_f16(params, dst); @@ -11855,6 +12784,7 @@ static void ggml_compute_forward_set( ggml_compute_forward_set_f32(params, dst); } break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -12029,6 +12959,49 @@ static void ggml_compute_forward_get_rows_f16( } } +static void ggml_compute_forward_get_rows_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(ggml_bf16_t)); + assert(ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + ggml_bf16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } +} + static void ggml_compute_forward_get_rows_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -12106,6 +13079,10 @@ static void ggml_compute_forward_get_rows( { ggml_compute_forward_get_rows_f16(params, dst); } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_get_rows_bf16(params, dst); + } break; case GGML_TYPE_F32: case GGML_TYPE_I32: { @@ -12801,6 +13778,7 @@ static void ggml_compute_forward_alibi( { ggml_compute_forward_alibi_f32(params, dst); } break; + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -12890,6 +13868,7 @@ static void ggml_compute_forward_clamp( ggml_compute_forward_clamp_f32(params, dst); } break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15987,6 +16966,7 @@ static void ggml_compute_forward_get_rel_pos( switch (src0->type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: { ggml_compute_forward_get_rel_pos_f16(params, dst); } break; @@ -18856,7 +19836,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa case GGML_OP_CPY: case GGML_OP_DUP: { - if (ggml_is_quantized(node->type)) { + if (ggml_is_quantized(node->type) || + // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 + (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || + (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } } break; @@ -18935,7 +19918,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const int64_t ne10 = node->src[1]->ne[0]; // L const int64_t ne11 = node->src[1]->ne[1]; // Cin - if (node->src[0]->type == GGML_TYPE_F16 && + if ((node->src[0]->type == GGML_TYPE_F16 || + node->src[0]->type == GGML_TYPE_BF16) && node->src[1]->type == GGML_TYPE_F32) { cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02; cur += sizeof(ggml_fp16_t)*ne10*ne11; @@ -18971,6 +19955,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } else if (node->src[1]->type == GGML_TYPE_F16) { cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_BF16) { + cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; case GGML_OP_FLASH_ATTN_EXT: @@ -18987,6 +19974,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } else if (node->src[1]->type == GGML_TYPE_F16) { cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_BF16) { + cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 } } break; case GGML_OP_FLASH_ATTN_BACK: @@ -19000,6 +19990,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } else if (node->src[1]->type == GGML_TYPE_F16) { cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_BF16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; @@ -19776,7 +20769,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { fprintf(fp, "%d", ggml_get_i32_1d(node, j)); } - else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) { + else if (node->type == GGML_TYPE_F32 || + node->type == GGML_TYPE_F16 || + node->type == GGML_TYPE_BF16) { fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j)); } else { @@ -20834,6 +21829,12 @@ size_t ggml_quantize_chunk( ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n); result = n * elemsize; } break; + case GGML_TYPE_BF16: + { + size_t elemsize = sizeof(ggml_bf16_t); + ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n); + result = n * elemsize; + } break; case GGML_TYPE_F32: { size_t elemsize = sizeof(float); diff --git a/ggml.h b/ggml.h index d90ba8ed664..bc9efcf408d 100644 --- a/ggml.h +++ b/ggml.h @@ -326,14 +326,20 @@ extern "C" { // get ggml_status name string GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status); + // ieee 754-2008 half-precision float16 + // todo: make this not an integral type typedef uint16_t ggml_fp16_t; - - // convert FP16 <-> FP32 - GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); - GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); - - GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n); - GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n); + GGML_API float ggml_fp16_to_fp32(ggml_fp16_t); + GGML_API ggml_fp16_t ggml_fp32_to_fp16(float); + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t); + GGML_API void ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t); + + // google brain half-precision bfloat16 + typedef struct { uint16_t bits; } ggml_bf16_t; + GGML_API ggml_bf16_t ggml_fp32_to_bf16(float); + GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16 + GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t); + GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t); struct ggml_object; struct ggml_context; @@ -370,6 +376,7 @@ extern "C" { GGML_TYPE_I64 = 27, GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, + GGML_TYPE_BF16 = 30, GGML_TYPE_COUNT, }; @@ -410,6 +417,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors + GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors }; // available tensor operations: From 69efc39d5c95f2917bf38b85109af15727043a30 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 8 May 2024 22:08:10 +0300 Subject: [PATCH 046/100] metal : use `vm_allocate` instead of `posix_memalign` on macOS (llama/7078) * fix: use `malloc` instead of `posix_memalign` in `ggml-metal.m` to make it not crash Electron proccesses * fix: typo * fix: use `vm_allocate` instead of `posix_memalign` * fix: don't call `newBufferWithBytesNoCopy` with `NULL` when `ggml_metal_host_malloc` returns `NULL` * fix: use `vm_allocate` only on macOS --- ggml-metal.m | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 4d4b9717812..038a5061f9b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -266,11 +266,20 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ static void * ggml_metal_host_malloc(size_t n) { void * data = NULL; + +#if TARGET_OS_OSX + kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); + if (err != KERN_SUCCESS) { + GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); + return NULL; + } +#else const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); if (result != 0) { GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); return NULL; } +#endif return data; } @@ -2855,7 +2864,11 @@ GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_ ggml_backend_metal_free_device(); if (ctx->owned) { +#if TARGET_OS_OSX + vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size); +#else free(ctx->all_data); +#endif } free(ctx); @@ -2959,14 +2972,16 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff ctx->owned = true; ctx->n_buffers = 1; - ctx->buffers[0].data = ctx->all_data; - ctx->buffers[0].size = size; - ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data - length:size_aligned - options:MTLResourceStorageModeShared - deallocator:nil]; + if (ctx->all_data != NULL) { + ctx->buffers[0].data = ctx->all_data; + ctx->buffers[0].size = size; + ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data + length:size_aligned + options:MTLResourceStorageModeShared + deallocator:nil]; + } - if (ctx->buffers[0].metal == nil) { + if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) { GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); free(ctx); ggml_backend_metal_free_device(); From 24f0aa460b1c04becdfdd673b913cce6a4d057b1 Mon Sep 17 00:00:00 2001 From: agray3 Date: Wed, 8 May 2024 21:55:49 +0100 Subject: [PATCH 047/100] Introduction of CUDA Graphs to LLama.cpp (llama/6766) * DRAFT: Introduction of CUDA Graphs to LLama.cpp * FIx issues raised in comments * Tidied to now only use CUDA runtime (not mixed with driver calls) * disable for multi-gpu and batch size > 1 * Disable CUDA graphs for old GPU arch and with env var * added missing CUDA_CHECKs * Addressed comments * further addressed comments * limit to GGML_ALLOW_CUDA_GRAPHS defined in llama.cpp cmake * Added more comprehensive graph node checking * With mechanism to fall back if graph capture fails * Revert "With mechanism to fall back if graph capture fails" This reverts commit eb9f15fb6fcb81384f732c4601a5b25c016a5143. * Fall back if graph capture fails and address other comments * - renamed GGML_ALLOW_CUDA_GRAPHS to GGML_CUDA_USE_GRAPHS - rename env variable to disable CUDA graphs to GGML_CUDA_DISABLE_GRAPHS - updated Makefile build to enable CUDA graphs - removed graph capture failure checking in ggml_cuda_error using a global variable to track this is not thread safe, but I am also not safistied with checking an error by string if this is necessary to workaround some issues with graph capture with eg. cuBLAS, we can pass the ggml_backend_cuda_context to the error checking macro and store the result in the context - fixed several resource leaks - fixed issue with zero node graphs - changed fixed size arrays to vectors - removed the count of number of evaluations before start capturing, and instead changed the capture mode to relaxed - removed the check for multiple devices so that it is still possible to use a single device, instead checks for split buffers to disable cuda graphs with -sm row - changed the op for checking batch size to GGML_OP_ADD, should be more reliable than GGML_OP_SOFT_MAX - code style fixes - things to look into - VRAM usage of the cudaGraphExec_t, if it is significant we may need to make it optional - possibility of using cudaStreamBeginCaptureToGraph to keep track of which ggml graph nodes correspond to which cuda graph nodes * fix build without cuda graphs * remove outdated comment * replace minimum cc value with a constant --------- Co-authored-by: slaren --- ggml-cuda.cu | 300 +++++++++++++++++++++++++++++++++++++++++-- ggml-cuda/clamp.cu | 1 - ggml-cuda/common.cuh | 40 ++++++ ggml-cuda/convert.cu | 4 +- ggml-cuda/cpy.cu | 29 +++++ ggml-cuda/cpy.cuh | 2 + ggml-cuda/mmq.cu | 30 ++--- ggml-cuda/mmvq.cu | 6 +- ggml-cuda/scale.cu | 1 - 9 files changed, 370 insertions(+), 43 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8739baa2a79..ceb66170edd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1647,7 +1647,7 @@ static void ggml_cuda_op_mul_mat( } } -static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation @@ -1670,7 +1670,7 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); } -static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); @@ -2413,32 +2413,304 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } +static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { + graph_node_properties->node_address = node->data; + graph_node_properties->node_op = node->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + graph_node_properties->ne[i] = node->ne[i]; + graph_node_properties->nb[i] = node->nb[i]; + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; + } +} + +static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { + if (node->data != graph_node_properties->node_address && + node->op != GGML_OP_CPY && + node->op != GGML_OP_VIEW) { + return false; + } + + if (node->op != graph_node_properties->node_op) { + return false; + } + + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (node->ne[i] != graph_node_properties->ne[i]) { + return false; + } + if (node->nb[i] != graph_node_properties->nb[i]) { + return false; + } + } + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i] && + node->src[i]->data != graph_node_properties->src_address[i] && + node->op != GGML_OP_CPY && + node->op != GGML_OP_VIEW + ) { + return false; + } + } + return true; +} + GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_cuda_set_device(cuda_ctx->device); - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; +#ifdef USE_CUDA_GRAPH + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; + // Objects required for CUDA Graph + if (cuda_ctx->cuda_graph == nullptr) { + cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); + } + + bool use_cuda_graph = true; + bool cuda_graph_update_required = false; + // pointer to CUDA cpy kernel, which is required to identify + // kernel parameters which need updated in the graph for each token + void * ggml_cuda_cpy_fn_ptr = nullptr; + + if (cuda_ctx->cuda_graph->graph == nullptr) { + if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { + cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__); +#endif + } + } + + // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, + // or previous graph capture failure. + // Also disable for multi-gpu for now. TO DO investigate + if (disable_cuda_graphs_due_to_env + || cuda_ctx->cuda_graph->disable_due_to_gpu_arch + || cuda_ctx->cuda_graph->disable_due_to_too_many_updates + || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { + use_cuda_graph = false; + } + + if (use_cuda_graph) { + if (cuda_ctx->cuda_graph->instance == nullptr) { + cuda_graph_update_required = true; + } + + // Check if the graph size has changed + if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + cuda_graph_update_required = true; + cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + } + + // Loop over nodes in GGML graph to determine if CUDA graph update is required + // and store properties to allow this comparison for the next token + for (int i = 0; i < cgraph->n_nodes; i++) { + bool has_matching_properties = true; + if (!cuda_graph_update_required) { + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + if (!has_matching_properties) { + cuda_graph_update_required = true; + } + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + + // Loop over nodes in GGML graph to obtain info needed for CUDA graph + cuda_ctx->cuda_graph->updated_kernel_arg.clear(); + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) { + use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__); +#endif + } + + if (node->op == GGML_OP_MUL_MAT_ID) { + use_cuda_graph = false; // This node type is not supported by CUDA graph capture +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__); +#endif + } + + if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) { + // disable CUDA graphs for batch size > 1 for now. + // Changes in batch size or context size can cause changes to the grid size of some kernels. + use_cuda_graph = false; +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); +#endif + } + + if (node->op == GGML_OP_CPY) { + // store the copy op parameter which changes with each token. + cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); + if (ggml_cuda_cpy_fn_ptr == nullptr) { + // store a pointer to the copy op CUDA kernel to identify it later + ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + } + } + + if (!use_cuda_graph) { + break; + } + } + + // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. + if (cuda_graph_update_required) { + cuda_ctx->cuda_graph->number_consecutive_updates++; + } else { + cuda_ctx->cuda_graph->number_consecutive_updates = 0; } + if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { + cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); +#endif + } + } + + if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); + } + +#else + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; +#endif // USE_CUDA_GRAPH + + bool graph_evaluated_or_captured = false; + + while (!graph_evaluated_or_captured) { + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. + // With the use of CUDA graphs, the execution will be performed by the graph launch. + if (!use_cuda_graph || cuda_graph_update_required) { + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + #ifndef NDEBUG - assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); - for (int j = 0; j < GGML_MAX_SRC; j++) { - if (node->src[j] != nullptr) { - assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); + } + } +#endif + + bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); + if (!ok) { + fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); } } + +#ifdef USE_CUDA_GRAPH + if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture + if (cuda_ctx->cuda_graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); + cuda_ctx->cuda_graph->graph = nullptr; + } + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + +#if 0 + if (disable_cuda_graphs_due_to_failed_capture) { + use_cuda_graph = false; + cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true; +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__); #endif + } else { + graph_evaluated_or_captured = true; // CUDA graph has been captured + } +#endif + graph_evaluated_or_captured = true; // CUDA graph has been captured + } else { + graph_evaluated_or_captured = true; // ggml graph has been directly evaluated + } + } - bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); - if (!ok) { - fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + if (use_cuda_graph) { + if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } - GGML_ASSERT(ok); + + // Perform update to graph (if required for this token), and change copy parameter (required for every token) + + if (cuda_graph_update_required) { + // Extract nodes from graph + if (cuda_ctx->cuda_graph->num_nodes == 0) { + // First call with null argument gets number of nodes in graph + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); + } + // Subsequent call with non-null argument gets nodes + cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); + cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); + if (cuda_ctx->cuda_graph->num_nodes > 0) { + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); + + // Loop over nodes, and extract kernel parameters from each node + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + cudaGraphNodeType node_type; + CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); + if (node_type == cudaGraphNodeTypeKernel) { + cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime + if (stat == cudaErrorInvalidDeviceFunction) { + // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. + // We don't need to update blas nodes, so clear error and move on. + cudaGetLastError(); + } else { + GGML_ASSERT(stat == cudaSuccess); + } + } + } + } + } + + // One of the arguments to the copy kernel is updated for each token, hence we need to + // replace that argument with the updated value in the CUDA graph + if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured + int k = 0; + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) { + char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); + cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; + CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); + } + } + } + + // Update graph executable + cudaGraphExecUpdateResultInfo result_info; + cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + if (stat == cudaErrorGraphExecUpdateFailure) { +#ifndef NDEBUG + fprintf(stderr, "%s: CUDA graph update failed\n", __func__); +#endif + // The pre-existing graph exec cannot be updated due to violated constraints + // so instead clear error and re-instantiate + cudaGetLastError(); + CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); + cuda_ctx->cuda_graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + } else { + GGML_ASSERT(stat == cudaSuccess); + } + // Launch graph + CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); +#else + graph_evaluated_or_captured = true; +#endif // USE_CUDA_GRAPH } return GGML_STATUS_SUCCESS; diff --git a/ggml-cuda/clamp.cu b/ggml-cuda/clamp.cu index 379ded042d8..8009a3e3d86 100644 --- a/ggml-cuda/clamp.cu +++ b/ggml-cuda/clamp.cu @@ -31,5 +31,4 @@ void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream); - CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index b2627b7b4b7..a4197f11ba7 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #if defined(GGML_USE_HIPBLAS) #include @@ -526,6 +527,43 @@ struct ggml_tensor_extra_gpu { cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs }; + +#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS) +#define USE_CUDA_GRAPH +#endif + +struct ggml_graph_node_properties { + void * node_address; + ggml_op node_op; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; +}; + +struct ggml_cuda_graph { +#ifdef USE_CUDA_GRAPH + ~ggml_cuda_graph() { + if (instance != nullptr) { + CUDA_CHECK(cudaGraphExecDestroy(instance)); + } + if (graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph)); + } + } + cudaGraph_t graph = nullptr; + cudaGraphExec_t instance = nullptr; + size_t num_nodes = 0; + std::vector nodes; + std::vector params; + bool disable_due_to_gpu_arch = false; + bool disable_due_to_too_many_updates = false; + bool disable_due_to_failed_graph_capture = false; + int number_consecutive_updates = 0; + std::vector ggml_graph_properties; + std::vector updated_kernel_arg; +#endif +}; + struct ggml_backend_cuda_context { int device; std::string name; @@ -534,6 +572,8 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; + std::unique_ptr cuda_graph; + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index 75e50c98561..830e2d75661 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -727,7 +727,6 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_ } to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { - int id; switch (type) { case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; @@ -738,8 +737,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { case GGML_TYPE_Q5_1: return dequantize_block_cuda; case GGML_TYPE_Q8_0: - CUDA_CHECK(cudaGetDevice(&id)); - if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) { + if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) { return dequantize_block_q8_0_f16_cuda; } return dequantize_block_cuda; diff --git a/ggml-cuda/cpy.cu b/ggml-cuda/cpy.cu index 16d9c8fffb4..12d741f017d 100644 --- a/ggml-cuda/cpy.cu +++ b/ggml-cuda/cpy.cu @@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; ggml_cuda_cpy(ctx, src0, dst); } + +void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_f32_f16; + } else { + fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ASSERT(false); + } +} + diff --git a/ggml-cuda/cpy.cuh b/ggml-cuda/cpy.cuh index f0b2c453bfe..7961674266e 100644 --- a/ggml-cuda/cpy.cuh +++ b/ggml-cuda/cpy.cuh @@ -5,3 +5,5 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1); diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index 60d6616a860..7948f1b1237 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -1735,8 +1735,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1780,8 +1779,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1825,8 +1823,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1870,8 +1867,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1915,8 +1911,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1960,8 +1955,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -2007,8 +2001,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( #if QK_K == 256 - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -2053,8 +2046,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -2098,8 +2090,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -2143,8 +2134,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; diff --git a/ggml-cuda/mmvq.cu b/ggml-cuda/mmvq.cu index 3965590017b..65cc1bcaad6 100644 --- a/ggml-cuda/mmvq.cu +++ b/ggml-cuda/mmvq.cu @@ -89,8 +89,7 @@ static void mul_mat_vec_q_cuda( GGML_ASSERT(ncols_x % qk == 0); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; @@ -328,8 +327,7 @@ void ggml_cuda_op_mul_mat_vec_q( const int64_t ne0 = dst->ne[0]; - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into diff --git a/ggml-cuda/scale.cu b/ggml-cuda/scale.cu index 6e3617d1cdb..1405e066e86 100644 --- a/ggml-cuda/scale.cu +++ b/ggml-cuda/scale.cu @@ -28,5 +28,4 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&scale, dst->op_params, sizeof(float)); scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream); - CUDA_CHECK(cudaGetLastError()); } From 26c550f77287158d14f3e8b15486cec86ee8a42d Mon Sep 17 00:00:00 2001 From: Albert Jin Date: Thu, 9 May 2024 17:34:37 +0800 Subject: [PATCH 048/100] opencl : alignment size converted from bits to bytes (llama/7090) * opencl alignment size should be converted from bits to bytes Reference: https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_API.html#CL_DEVICE_MEM_BASE_ADDR_ALIGN > Alignment requirement (in bits) for sub-buffer offsets. * Update ggml-opencl.cpp for readability using division instead of shift Co-authored-by: Jared Van Bortel --------- Co-authored-by: Jared Van Bortel --- ggml-opencl.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index b3f8b7eaf0a..880a14958ce 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -2119,6 +2119,7 @@ static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_ if (alignment == (cl_uint)-1) { ggml_cl_init(); clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &alignment, NULL); + alignment /= 8; // bits to bytes } return alignment; From 4be936b88ba64faa027fca89af5e2cfcaa64e926 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 9 May 2024 14:32:02 +0200 Subject: [PATCH 049/100] CUDA: generalize FP16 fattn vec kernel (llama/7061) * CUDA: generalize FP16 fattn vec kernel * disable unsupported head sizes for AMD in test * try AMD fix * fix batch size 2-8 * partially revert changes --- ggml-cuda/common.cuh | 232 +++++++++++++++++--------------- ggml-cuda/fattn.cu | 305 +++++++++++++++++++++++++++++++------------ 2 files changed, 347 insertions(+), 190 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index a4197f11ba7..44e67e040e1 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -234,6 +234,97 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif //GGML_CUDA_F16 +#if defined(GGML_USE_HIPBLAS) +#define __CUDA_ARCH__ 1300 + +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ + defined(__gfx1150__) || defined(__gfx1151__) +#define RDNA3 +#endif + +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) +#define RDNA2 +#endif + +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); +typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); +static __device__ __forceinline__ int __vsubss4(const int a, const int b) { + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); +#if __has_builtin(__builtin_elementwise_sub_sat) + const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); + return reinterpret_cast(c); +#else + int8x4_t c; + int16_t tmp; +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp = va[i] - vb[i]; + if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); + if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); + c[i] = tmp; + } + return reinterpret_cast(c); +#endif // __has_builtin(__builtin_elementwise_sub_sat) +} + +static __device__ __forceinline__ int __vsub4(const int a, const int b) { + return __vsubss4(a, b); +} + +static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0xff : 0x00; + } + return c; +} + +static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { +#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) + c = __builtin_amdgcn_sdot4(a, b, c, false); +#elif defined(RDNA3) + c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); +#elif defined(__gfx1010__) || defined(__gfx900__) + int tmp1; + int tmp2; + asm("\n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + " + : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) + : "v"(a), "v"(b) + ); +#else + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); + c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; +#endif + return c; +} +#endif // defined(GGML_USE_HIPBLAS) + +#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL + +#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA + +static bool fp16_mma_available(const int cc) { + return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; +} + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { @@ -275,16 +366,28 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { } static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if FP16_AVAILABLE + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); - } - return a; + for (int mask = 16; mask > 0; mask >>= 1) { + const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32); + reinterpret_cast(a.x) += __low2half(a_other); + reinterpret_cast(a.y) += __high2half(a_other); + } + return a; #else - GGML_UNUSED(a); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); + } + return a; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + +#else + NO_DEVICE_CODE; + return a; +#endif // FP16_AVAILABLE } static __device__ __forceinline__ float warp_reduce_max(float x) { @@ -296,20 +399,21 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { } static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#if FP16_AVAILABLE -#if CUDART_VERSION >= CUDART_HMAX - return __hmax(a, b); +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX + return __float2half(fmaxf(__half2float(a), __half2float(b))); #else - return __half2float(a) > __half2float(b) ? a : b; -#endif // CUDART_VERSION >= CUDART_HMAX + return __hmax(a, b); +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX #else - GGML_UNUSED(a); - GGML_UNUSED(b); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX + NO_DEVICE_CODE; + GGML_UNUSED(b); + return a; +#endif // FP16_AVAILABLE } + static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -317,8 +421,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal return __hmax2(a, b); #else half2 ret; - reinterpret_cast(ret.x) = __low2float(a) > __low2float(b) ? __low2half(a) : __low2half(b); - reinterpret_cast(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b); + reinterpret_cast(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b))); + reinterpret_cast(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b))); return ret; #endif // CUDART_VERSION >= CUDART_HMAX @@ -326,7 +430,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal GGML_UNUSED(a); GGML_UNUSED(b); NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) } static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { @@ -350,94 +454,6 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half } #endif // CUDART_VERSION < 12000 -#if defined(GGML_USE_HIPBLAS) -#define __CUDA_ARCH__ 1300 - -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ - defined(__gfx1150__) || defined(__gfx1151__) -#define RDNA3 -#endif - -#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ - defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) -#define RDNA2 -#endif - -#ifndef __has_builtin - #define __has_builtin(x) 0 -#endif - -typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); -typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); -static __device__ __forceinline__ int __vsubss4(const int a, const int b) { - const int8x4_t va = reinterpret_cast(a); - const int8x4_t vb = reinterpret_cast(b); -#if __has_builtin(__builtin_elementwise_sub_sat) - const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); - return reinterpret_cast(c); -#else - int8x4_t c; - int16_t tmp; -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp = va[i] - vb[i]; - if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); - if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); - c[i] = tmp; - } - return reinterpret_cast(c); -#endif // __has_builtin(__builtin_elementwise_sub_sat) -} - -static __device__ __forceinline__ int __vsub4(const int a, const int b) { - return __vsubss4(a, b); -} - -static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) { - const uint8x4_t& va = reinterpret_cast(a); - const uint8x4_t& vb = reinterpret_cast(b); - unsigned int c; - uint8x4_t& vc = reinterpret_cast(c); -#pragma unroll - for (int i = 0; i < 4; ++i) { - vc[i] = va[i] == vb[i] ? 0xff : 0x00; - } - return c; -} - -static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { -#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) - c = __builtin_amdgcn_sdot4(a, b, c, false); -#elif defined(RDNA3) - c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); -#elif defined(__gfx1010__) || defined(__gfx900__) - int tmp1; - int tmp2; - asm("\n \ - v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ - v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ - v_add3_u32 %0, %1, %2, %0 \n \ - v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ - v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ - v_add3_u32 %0, %1, %2, %0 \n \ - " - : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) - : "v"(a), "v"(b) - ); -#else - const int8x4_t va = reinterpret_cast(a); - const int8x4_t vb = reinterpret_cast(b); - c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; -#endif - return c; -} -#endif // defined(GGML_USE_HIPBLAS) - -#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ - defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL - -#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA - // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index c8a11d17334..7c486f4829b 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -11,8 +11,10 @@ #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. -template // D == head size -__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -44,55 +46,77 @@ static __global__ void flash_attn_vec_ext_f16( #if FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic); + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic; + const half * maskh = (const half *) mask + ne11*ic0; const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); - constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = D / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < nwarps*WARP_SIZE); + __builtin_assume(tid < D); - __shared__ half KQ[nwarps*WARP_SIZE]; - KQ[tid] = -INFINITY; + __shared__ half KQ[ncols*D]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ[j*D + tid] = -HALF_MAX_HALF; + } half2 * KQ2 = (half2 *) KQ; - half kqmax = -HALF_MAX_HALF; - half kqsum = 0.0f; + half kqmax[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax[j] = -HALF_MAX_HALF; + } + half kqsum[ncols] = {0.0f}; - __shared__ half kqmax_shared[WARP_SIZE]; - __shared__ half kqsum_shared[WARP_SIZE]; - if (threadIdx.y == 0) { - kqmax_shared[threadIdx.x] = -HALF_MAX_HALF; - kqsum_shared[threadIdx.x] = 0.0f; + __shared__ half kqmax_shared[ncols][WARP_SIZE]; + __shared__ half kqsum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF; + kqsum_shared[j][threadIdx.x] = 0.0f; + } } __syncthreads(); // Convert Q to half2 and store in registers: - half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE]; + half2 Q_h2[ncols][D/(2*WARP_SIZE)]; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { - break; - } + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y); + const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i]; + Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + } } - half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. + half2 VKQ[ncols] = {{0.0f, 0.0f}}; - const int k_start = parallel_blocks == 1 ? 0 : ip*D; + const int k_start = parallel_blocks == 1 ? 0 : ip*D; for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { // Calculate KQ tile and keep track of new maximum KQ values: - half kqmax_new = kqmax; + + // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, + // see https://github.com/ggerganov/llama.cpp/pull/7061 . + // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). + half kqmax_new = kqmax[0]; + half kqmax_new_arr[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax_new_arr[j] = kqmax[j]; + } + #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; @@ -101,89 +125,112 @@ static __global__ void flash_attn_vec_ext_f16( break; } - half2 sum2 = make_half2(0.0f, 0.0f); + half2 sum2[ncols] = {{0.0f, 0.0f}}; #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; - if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) { - break; - } const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; - sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE]; + } } - sum2 = warp_reduce_sum(sum2); - half sum = __low2half(sum2) + __high2half(sum2); - sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f); - kqmax_new = ggml_cuda_hmax(kqmax_new, sum); - if (threadIdx.x == 0) { - KQ[i_KQ] = sum; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + sum2[j] = warp_reduce_sum(sum2[j]); + half sum = __low2half(sum2[j]) + __high2half(sum2[j]); + sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + + if (ncols == 1) { + kqmax_new = ggml_cuda_hmax(kqmax_new, sum); + } else { + kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum); + } + + if (threadIdx.x == 0) { + KQ[j*D + i_KQ] = sum; + } } } - kqmax_new = warp_reduce_max(kqmax_new); - if (threadIdx.x == 0) { - kqmax_shared[threadIdx.y] = kqmax_new; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; + + kqmax_new_j = warp_reduce_max(kqmax_new_j); + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = kqmax_new_j; + } } + __syncthreads(); - kqmax_new = kqmax_shared[threadIdx.x]; - kqmax_new = warp_reduce_max(kqmax_new); - const half KQ_max_scale = hexp(kqmax - kqmax_new); - kqmax = kqmax_new; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; - const half val = hexp(KQ[tid] - kqmax); - kqsum = kqsum*KQ_max_scale + val; - KQ[tid] = val; + const half val = hexp(KQ[j*D + tid] - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale + val; + KQ[j*D + tid] = val; - VKQ *= __half2half2(KQ_max_scale); + VKQ[j] *= __half2half2(KQ_max_scale); + } __syncthreads(); - if (tid < D) { #pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } + for (int k0 = 0; k0 < D; k0 += 2) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } - half2 V_k; - reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; - reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; - VKQ += V_k*KQ2[k0/2]; + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; } } __syncthreads(); } - if (tid >= D) { - kqsum = 0.0f; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqsum[j] = warp_reduce_sum(kqsum[j]); + if (threadIdx.x == 0) { + kqsum_shared[j][threadIdx.y] = kqsum[j]; + } } - kqsum = warp_reduce_sum(kqsum); - if (threadIdx.x == 0) { - kqsum_shared[threadIdx.y] = kqsum; - } __syncthreads(); - kqsum = kqsum_shared[threadIdx.x]; - kqsum = warp_reduce_sum(kqsum); - if (tid >= D) { - return; - } +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; + kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); - half dst_val = (__low2half(VKQ) + __high2half(VKQ)); - if (parallel_blocks == 1) { - dst_val /= kqsum; + half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); + if (parallel_blocks == 1) { + dst_val /= kqsum[j_VKQ]; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; } - dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val; - if (parallel_blocks == 1 || tid != 0) { - return; + if (parallel_blocks != 1 && tid != 0) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]); + } } - dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum); #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE @@ -191,7 +238,9 @@ static __global__ void flash_attn_vec_ext_f16( // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -573,7 +622,9 @@ static __global__ void flash_attn_ext_f16( } template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, @@ -642,7 +693,7 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -template void launch_fattn_vec_f16( +template void launch_fattn_vec_f16( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream ) { @@ -656,13 +707,13 @@ template void launch_fattn_vec_f16( constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_dim(WARP_SIZE, nwarps, 1); - const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); const int shmem = 0; float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - flash_attn_vec_ext_f16 + flash_attn_vec_ext_f16 <<>> ( (const char *) Q->data, (const char *) K->data, @@ -783,10 +834,99 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; const int32_t precision = KQV->op_params[1]; + if (!fp16_mma_available(cc)) { + GGML_ASSERT(precision == GGML_PREC_DEFAULT); + GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128."); + + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] == 2) { + constexpr int cols_per_block = 2; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 4) { + constexpr int cols_per_block = 4; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 8) { + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 1; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + if (precision != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { constexpr int cols_per_block = 16; @@ -845,16 +985,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { + constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; switch (Q->ne[0]) { case 64: - launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 128: - launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; case 256: - launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); break; default: GGML_ASSERT(false); From c114b75aeecdb6220f14fcd1d20169995d49acce Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Thu, 9 May 2024 20:39:54 +0200 Subject: [PATCH 050/100] Vulkan Bugfixes and Improvements (llama/7084) * Modify mat mat mul shader for mul_mat_id, modify mat vec mul shaders for single call batch operation * Further work towards MoE, disabled for now * Disable MoE code (not ready yet), fix a number of bugs in shaders and Vulkan code * Add softmax with f16 mask and pos buffer support * Disable mul_mat_id shaders for now * Fix flake8 * Fix validation errors caused by empty buffers on larger batch sizes --- ggml-vulkan.cpp | 1217 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 950 insertions(+), 267 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index f712cdd5a90..95f71897405 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -120,8 +120,16 @@ struct vk_device { vk_matmul_pipeline pipeline_dequant_mul_mat_mat[VK_NUM_TYPES]; + vk_matmul_pipeline pipeline_matmul_id_f32; + vk_matmul_pipeline pipeline_matmul_id_f16; + vk_matmul_pipeline pipeline_matmul_id_f16_f32; + + vk_matmul_pipeline pipeline_dequant_mul_mat_mat_id[VK_NUM_TYPES]; + vk_pipeline pipeline_dequant[VK_NUM_TYPES]; - vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[VK_NUM_TYPES]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[VK_NUM_TYPES]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[VK_NUM_TYPES]; vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; @@ -139,7 +147,7 @@ struct vk_device { vk_pipeline pipeline_silu_f32; vk_pipeline pipeline_relu_f32; vk_pipeline pipeline_diag_mask_inf_f32; - vk_pipeline pipeline_soft_max_f32; + vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_rope_f32, pipeline_rope_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; vk_pipeline pipeline_argsort_f32; @@ -215,6 +223,21 @@ struct vk_submission { typedef std::vector vk_sequence; +struct vk_mat_mat_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t k_split; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t expert_stride_b; uint32_t expert_stride_d; + uint32_t idx; uint32_t nbi1; uint32_t n_as; +}; + +struct vk_mat_vec_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; +}; + struct vk_op_push_constants { uint32_t KX; uint32_t KY; @@ -1003,201 +1026,422 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) { ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared(); ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared(); + /*ctx->device->pipeline_matmul_id_f32 = std::make_shared(); + ctx->device->pipeline_matmul_id_f16_f32 = std::make_shared(); + ctx->device->pipeline_matmul_id_f16 = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared(); + ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared();*/ + if (device->fp16) { - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_0_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_0_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_0_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_len, matmul_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_len, matmul_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_0_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_0_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_0_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + /*ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_aligned_len, matmul_id_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_aligned_len, matmul_id_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_aligned_len, matmul_id_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);*/ } else { - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align); - } - - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32", mul_mat_vec_q4_0_f32_len, mul_mat_vec_q4_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32", mul_mat_vec_q4_1_f32_len, mul_mat_vec_q4_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32", mul_mat_vec_q5_0_f32_len, mul_mat_vec_q5_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32", mul_mat_vec_q5_1_f32_len, mul_mat_vec_q5_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32", mul_mat_vec_q8_0_f32_len, mul_mat_vec_q8_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32", mul_mat_vec_q2_K_f32_len, mul_mat_vec_q2_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32", mul_mat_vec_q3_K_f32_len, mul_mat_vec_q3_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32", mul_mat_vec_q4_K_f32_len, mul_mat_vec_q4_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + /*ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_aligned_fp32_len, matmul_id_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_aligned_fp32_len, matmul_id_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_aligned_fp32_len, matmul_id_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);*/ + } + + // mul mat vec + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32_f32", mul_mat_vec_q2_K_f32_f32_len, mul_mat_vec_q2_K_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32_f32", mul_mat_vec_q3_K_f32_f32_len, mul_mat_vec_q3_K_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32_f32", mul_mat_vec_q4_K_f32_f32_len, mul_mat_vec_q4_K_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32_f32", mul_mat_vec_q5_K_f32_f32_len, mul_mat_vec_q5_K_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32_f32", mul_mat_vec_q6_K_f32_f32_len, mul_mat_vec_q6_K_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f16_f32", mul_mat_vec_q2_K_f16_f32_len, mul_mat_vec_q2_K_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f16_f32", mul_mat_vec_q3_K_f16_f32_len, mul_mat_vec_q3_K_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f16_f32", mul_mat_vec_q4_K_f16_f32_len, mul_mat_vec_q4_K_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f16_f32", mul_mat_vec_q5_K_f16_f32_len, mul_mat_vec_q5_K_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f16_f32", mul_mat_vec_q6_K_f16_f32_len, mul_mat_vec_q6_K_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + /*ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_K_f32", mul_mat_vec_id_q2_K_f32_len, mul_mat_vec_id_q2_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_K_f32", mul_mat_vec_id_q3_K_f32_len, mul_mat_vec_id_q3_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_K_f32", mul_mat_vec_id_q4_K_f32_len, mul_mat_vec_id_q4_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_K_f32", mul_mat_vec_id_q5_K_f32_len, mul_mat_vec_id_q5_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_K_f32", mul_mat_vec_id_q6_K_f32_len, mul_mat_vec_id_q6_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);*/ // dequant shaders ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -1258,6 +1502,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) { ggml_vk_create_pipeline(ctx, ctx->device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -1686,11 +1931,48 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte return ctx->device->pipeline_dequant_mul_mat_mat[src0_type]; } -static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) { +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) { +#ifdef GGML_VULKAN_DEBUG + std::cerr << "ggml_vk_get_mul_mat_mat_id_pipeline()" << std::endl; +#endif + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f32; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16; + } + + GGML_ASSERT(src1_type == GGML_TYPE_F32); + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type]; +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { #ifdef GGML_VULKAN_DEBUG std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl; #endif - switch (type) { + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); + + switch (a_type) { case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -1707,7 +1989,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * return nullptr; } - return ctx->device->pipeline_dequant_mul_mat_vec_f32[type]; + return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type]; } static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { @@ -1913,6 +2195,9 @@ static void ggml_vk_ctx_begin(ggml_backend_vk_context * ctx, vk_context * subctx } static size_t ggml_vk_align_size(size_t width, size_t align) { +#ifdef GGML_VULKAN_DEBUG + std::cerr << "ggml_vk_align_size(" << width << ", " << align << ")" << std::endl; +#endif return CEIL_DIV(width, align) * align; } @@ -2368,11 +2653,13 @@ static uint32_t ggml_vk_guess_split_k(int m, int n, int k) { #ifdef GGML_VULKAN_DEBUG std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")" << std::endl; #endif - if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) { - return 4; - } + // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) { + // return 4; + // } return 1; + + GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k); } static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { @@ -2424,25 +2711,58 @@ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ct #ifdef GGML_VULKAN_DEBUG std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl; #endif - return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, false)->align; + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align; +} + +static void ggml_vk_matmul( + ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t expert_stride_b, uint32_t expert_stride_d, uint32_t idx, uint32_t nbi1, uint32_t n_as) { +#ifdef GGML_VULKAN_DEBUG + std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), c: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << split_k_buffer.buffer->buffer << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl; +#endif + ggml_vk_sync_buffers(subctx); + if (split_k == 1) { + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d, expert_stride_b, expert_stride_d, idx, nbi1, n_as }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); + return; + } + + GGML_ASSERT(batch_stride_d == m * n); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d, expert_stride_b, expert_stride_d, idx, nbi1, n_as }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_vk_sync_buffers(subctx); + const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); } -static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) { +static void ggml_vk_matmul_id( + ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, + vk_subbuffer&& ids, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& a, vk_subbuffer&& split_k_buffer, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t expert_stride_b, uint32_t expert_stride_d, uint32_t idx, uint32_t nbi1, uint32_t n_as) { #ifdef GGML_VULKAN_DEBUG std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), c: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << split_k_buffer.buffer->buffer << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl; #endif ggml_vk_sync_buffers(subctx); if (split_k == 1) { - const std::array pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc.size() * sizeof(uint32_t), pc.data(), { m, n, batch }); + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d, expert_stride_b, expert_stride_d, idx, nbi1, n_as }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { ids, b, d, a }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); return; } GGML_ASSERT(batch_stride_d == m * n); - const std::array pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d }; + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d, expert_stride_b, expert_stride_d, idx, nbi1, n_as }; // Make sure enough workgroups get assigned for split k to work - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1.size() * sizeof(uint32_t), pc1.data(), { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { ids, b, split_k_buffer, a }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); ggml_vk_sync_buffers(subctx); const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); @@ -2557,7 +2877,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su const int d_ne = ne11 * ne01; const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); - const bool aligned = ne10 == kpad; + const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10); @@ -2655,7 +2975,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su } // compute - ggml_vk_matmul(ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21); // NOLINT + ggml_vk_matmul( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21, + 0, 0, 0, 0, 1 + ); // NOLINT if (dst->backend == GGML_BACKEND_TYPE_CPU) { // copy dst to host @@ -2685,8 +3011,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context GGML_ASSERT(ne11 == 1); - const uint64_t nb2 = dst->nb[2]; - const uint64_t nb3 = dst->nb[3]; + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; const uint64_t r2 = ne12 / ne02; const uint64_t r3 = ne13 / ne03; @@ -2718,6 +3046,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context const bool qx_needs_dequant = x_non_contig; const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + const uint64_t x_ne = ne01 * ne00; const uint64_t y_ne = ne11 * ne10; const uint64_t d_ne = ne11 * ne01; @@ -2770,7 +3101,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type); + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type); GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); @@ -2793,43 +3124,25 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); } - for (uint64_t i13 = 0; i13 < ne13; i13++) { - const uint64_t i03 = i13 / r3; - for (uint64_t i12 = 0; i12 < ne12; i12++) { - const uint64_t i02 = i12 / r2; - - const uint64_t it_idx0 = (i03 * ne02 + i02); - const uint64_t it_idx1 = (i13 * ne12 + i12); - const uint64_t x_offset = x_buf_offset + x_sz * it_idx0; - const uint64_t qy_offset = qy_buf_offset + qy_sz * it_idx1; - const uint64_t y_offset = y_buf_offset + y_sz * it_idx1; - const uint64_t d_offset = d_buf_offset + d_sz * it_idx1; - - const uint64_t y_buffer_offset = (y_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; - const uint64_t y_shader_offset = y_offset - y_buffer_offset; - - const uint64_t d_buffer_offset = (d_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; - const uint64_t d_shader_offset = d_offset - d_buffer_offset; - - if (!y_non_contig && qy_needs_dequant) { - const std::vector pc = { (uint32_t)ne11, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(y_ne / 32) }; - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)y_ne, 1, 1}); - } + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; - // compute - const std::array pc = { (uint32_t)ne00, (uint32_t)(y_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type))}; - ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1}); + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } - if (dst->backend == GGML_BACKEND_TYPE_CPU) { - // copy dst to host - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); - ggml_vk_sync_buffers(subctx); - ggml_vk_buffer_read_async(ctx, subctx, d_D, d_offset, d, sizeof(float) * d_ne); - } - } + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } + + // compute + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21), + }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1}); } static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -3011,7 +3324,7 @@ static bool ggml_vk_can_mul_mat(const ggml_tensor * src0, const ggml_tensor * sr ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_TYPE_GPU); } -static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_VULKAN_DEBUG std::cerr << "ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")" << std::endl; #endif @@ -3026,9 +3339,347 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, } } -// static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { -// -// } +/*static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +#ifdef GGML_VULKAN_DEBUG + std::cerr << "ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", backend=" << ids->backend << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl; +#endif + GGML_ASSERT(src0->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint32_t nb11 = src1->nb[1]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + const uint32_t nbi1 = src0->nb[1]; + const uint32_t idx = ((uint32_t *) dst->op_params)[0]; + const uint64_t n_as = ne02; + + GGML_ASSERT(n_as <= 8); + + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; + ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra; + ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra; + + vk_buffer d_Qx; + size_t qx_buf_offset = 0; + vk_buffer d_Qy; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type); + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + + if (mmp == nullptr) { + GGML_ASSERT(false); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); + const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; + + const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10); + + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_buffer d_D = extra->buffer_gpu.lock(); + const uint64_t d_buf_offset = extra->offset; + GGML_ASSERT(d_D != nullptr); + GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = extra_src0->buffer_gpu.lock(); + qx_buf_offset = extra_src0->offset; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = extra_src1->buffer_gpu.lock(); + qy_buf_offset = extra_src1->offset; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + // Allocate descriptor sets + ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + if (split_k > 1) { + ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1); + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + } + if (y_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + ggml_vk_matmul( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21, + nb11 / ggml_type_size(src1->type), ne20, idx, nbi1, n_as + ); // NOLINT +} + +static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_VULKAN_DEBUG + std::cerr << "ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl; +#endif + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne11 == 1); + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t nb22 = dst->nb[2]; + const uint64_t nb23 = dst->nb[3]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; + ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra; + + vk_buffer d_Qx; + size_t qx_buf_offset = 0; + vk_buffer d_Qy; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne11 * ne01; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_buffer d_D = extra->buffer_gpu.lock(); + const uint64_t d_buf_offset = extra->offset; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = extra_src0->buffer_gpu.lock(); + qx_buf_offset = extra_src0->offset; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = extra_src1->buffer_gpu.lock(); + qy_buf_offset = extra_src1->offset; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + // Allocate descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13); + } + ggml_pipeline_allocate_descriptor_sets(ctx, dmmv, ne12 * ne13); + + if (x_non_contig) { + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21), + // 0, 0, 0, 0, 1 + }; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1}); +}*/ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { // guaranteed to be an integer due to the check in ggml_can_repeat @@ -3178,14 +3829,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_SOFT_MAX: -#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16); if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_soft_max_f32; } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_soft_max_f32_f16; + } return nullptr; case GGML_OP_ROPE: { @@ -3233,6 +3885,21 @@ static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) { } } +static bool ggml_vk_op_supports_incontiguous(ggml_op op) { + switch (op) { + case GGML_OP_CPY: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_CLAMP: + return true; + default: + return false; + } +} + template static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) { #ifdef GGML_VULKAN_DEBUG @@ -3284,6 +3951,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c return; } + const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; @@ -3345,7 +4014,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c GGML_ASSERT(d_Z != nullptr); } - if (op == GGML_OP_CPY || op == GGML_OP_GET_ROWS) { + if (op_supports_incontiguous) { x_sz = ggml_nbytes(src0); y_sz = use_src1 ? ggml_nbytes(src1) : 0; d_sz = ggml_nbytes(dst); @@ -3364,7 +4033,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c std::array elements; // Single call if dimension 2 is contiguous - if (op == GGML_OP_CPY || op == GGML_OP_GET_ROWS || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) { + if (op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) { ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1); switch (dst->op) { @@ -3385,7 +4054,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c break; } - if (op != GGML_OP_CPY && op != GGML_OP_GET_ROWS) { + if (!op_supports_incontiguous) { if (x_sz != VK_WHOLE_SIZE) { x_sz *= ne02 * ne03; } @@ -3403,14 +4072,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c if (use_src1) { subbuf_y = { d_Y, y_buf_offset, y_sz }; } else { - subbuf_y = { ctx->prealloc_y, 0, ctx->prealloc_y->size }; + subbuf_y = { d_X, 0, d_X->size }; } vk_subbuffer subbuf_z; if (use_src2) { subbuf_z = { d_Z, z_buf_offset, z_sz }; } else { - subbuf_z = { ctx->prealloc_y, 0, ctx->prealloc_y->size }; + subbuf_z = { d_X, 0, d_X->size }; } ggml_vk_sync_buffers(subctx); @@ -3582,7 +4251,9 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, cons } static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }); + float * op_params = (float *)dst->op_params; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); } static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -3617,7 +4288,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { ncols, - nrows_y, + src1 != nullptr ? nrows_y : (uint32_t)0, src2 != nullptr ? (uint32_t)1 : (uint32_t)0, scale, max_bias, m0, m1, @@ -3834,7 +4505,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); for (size_t i = 0; i < num_it; i++) { ggml_vk_ctx_begin(ctx, subctx); - ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n); + ggml_vk_matmul( + ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), + m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n, 0, 0, 0, 0, 1 + ); ggml_vk_ctx_end(subctx); } @@ -4339,7 +5013,10 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); for (size_t i = 0; i < num_it; i++) { ggml_vk_ctx_begin(ctx, subctx); - ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n); + ggml_vk_matmul( + ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), + m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n, 0, 0, 0, 0, 1 + ); ggml_vk_ctx_end(subctx); } @@ -4590,6 +5267,8 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K); ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K); + ggml_vk_test_matmul(ctx, 512, 512, 100, 32, 100, 1, 2); + ggml_vk_test_matmul(ctx, 128, 512, 512, 2, 100, 1, 0); ggml_vk_test_matmul(ctx, 128, 512, 512, 2, 100, 1, 1); ggml_vk_test_matmul(ctx, 128, 512, 512, 2, 100, 1, 2); @@ -4739,7 +5418,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, bool last_node){ - if (ctx->disable || node->backend != GGML_BACKEND_TYPE_GPU) { + if (ctx->disable || node->backend != GGML_BACKEND_TYPE_GPU || ggml_is_empty(node)) { return; } @@ -5539,7 +6218,7 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen int last_node = cgraph->n_nodes - 1; // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly - while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_TYPE_GPU) { + while (last_node > 0 && (cgraph->nodes[last_node]->backend != GGML_BACKEND_TYPE_GPU || ggml_is_empty(cgraph->nodes[last_node]))) { last_node -= 1; } @@ -5577,6 +6256,8 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen } GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) { + // ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context; + switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -6199,7 +6880,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ tensor_clone = ggml_soft_max(ggml_ctx, src0_clone); } } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { - tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(float *)tensor->op_params); + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_ROPE) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; @@ -6247,6 +6928,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ tensor_clone = ggml_transpose(ggml_ctx, src0_clone); } else if (tensor->op == GGML_OP_GET_ROWS) { tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_ARGSORT) { + tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ASSERT(false); @@ -6280,7 +6963,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ free(src1_buffer); } if (src2 != nullptr) { - free(src1_buffer); + free(src2_buffer); } ggml_free(ggml_ctx); From fe454b8d9e74c3cac70a469922037987e95d85f4 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Fri, 10 May 2024 01:32:15 +0100 Subject: [PATCH 051/100] Minor arithmetic improvement to mmvq wrapper kernel (llama/7172) --- ggml-sycl.cpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 57fe4ea3d4a..79aec4d9f02 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -8330,24 +8330,26 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ const int blocks_per_row = ncols / qk; const int blocks_per_warp = vdr * WARP_SIZE / qi; -// partial sum for each thread + const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block + + // partial sum for each thread float tmp = 0.0f; const block_q_t * x = (const block_q_t *) vx; const block_q8_1 * y = (const block_q8_1 *) vy; - for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row; i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i; // x block index + const int ibx = row * blocks_per_row + i; // x block index - const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx - const int iqs = - vdr * - (item_ct1.get_local_id(2) % - (qi / vdr)); // x block quant index when casting the quants to int + const int iqs = + vdr * + (item_ct1.get_local_id(2) - + i * qi_vdr); // x block quant index when casting the quants to int - tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); + tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); } // sum up partial sums and write back result From 284fac39fbe20716907327dd7af786cc38eb5049 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 May 2024 18:20:10 +0300 Subject: [PATCH 052/100] metal : fix flash attention kernel requirements (llama/7169) * metal : fix flash attention kernel requirements ggml-ci * metal : fix ggml_metal_supports_op ggml-ci --- ggml-metal.m | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 038a5061f9b..26e01e415db 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -635,14 +635,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -775,8 +775,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: - case GGML_OP_FLASH_ATTN_EXT: return true; + case GGML_OP_FLASH_ATTN_EXT: + return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return ctx->support_simdgroup_reduction && From e54329da7b3463a99c07a7b7e0dcece347057018 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 11 May 2024 10:32:41 +0300 Subject: [PATCH 053/100] ggml : full ALiBi support (llama/7192) * ggml : full ALiBi support * ggml : update ggml_soft_max_ext() CUDA, SYCL * ggml : ggml_flash_attn_ext() support ALiBi (CPU) * ggml : ggml_flash_attn_ext() support ALiBi (Metal) * ggml : fix warning * ggml : ggml_flash_attn_ext() support ALiBi (CUDA) ggml-ci * ggml : fix assert message * vulkan : add dev notes * ggml : require mask when using ALiBi ggml-ci * convert : fix convert for refact models --- ggml-cuda.cu | 5 - ggml-cuda/fattn.cu | 72 ++++++++-- ggml-cuda/softmax.cu | 55 +++----- ggml-kompute.cpp | 12 +- ggml-metal.m | 148 ++++++++------------- ggml-metal.metal | 120 +++++++---------- ggml-sycl.cpp | 138 +++---------------- ggml-vulkan.cpp | 6 +- ggml.c | 309 ++++++------------------------------------- ggml.h | 18 +-- 10 files changed, 261 insertions(+), 622 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ceb66170edd..5b6c9091924 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4,7 +4,6 @@ #include "ggml-cuda/common.cuh" #include "ggml-cuda/acc.cuh" -#include "ggml-cuda/alibi.cuh" #include "ggml-cuda/arange.cuh" #include "ggml-cuda/argsort.cuh" #include "ggml-cuda/binbcast.cuh" @@ -2280,9 +2279,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ROPE: ggml_cuda_op_rope(ctx, dst); break; - case GGML_OP_ALIBI: - ggml_cuda_op_alibi(ctx, dst); - break; case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; @@ -2833,7 +2829,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: - case GGML_OP_ALIBI: case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 7c486f4829b..ac5d6672b30 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16( float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, @@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16( const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); + half slopeh = __float2half(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slopeh = __float2half(powf(base, exph)); + } + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); constexpr int nwarps = D / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; @@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { sum2[j] = warp_reduce_sum(sum2[j]); half sum = __low2half(sum2[j]) + __high2half(sum2[j]); - sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { kqmax_new = ggml_cuda_hmax(kqmax_new, sum); @@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16( float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, @@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16( const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); + half slopeh = __float2half(1.0f); + half2 slope2 = make_half2(1.0f, 1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slopeh = __float2half(powf(base, exph)); + slope2 = make_half2(slopeh, slopeh); + } + frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: @@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); } KQ_max_new = warp_reduce_max(KQ_max_new); @@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); @@ -710,8 +744,17 @@ template void launch_fattn_vec_ const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); flash_attn_vec_ext_f16 <<>> ( @@ -720,7 +763,7 @@ template void launch_fattn_vec_ (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, + scale, max_bias, m0, m1, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, @@ -761,8 +804,17 @@ template ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); flash_attn_ext_f16 <<>> ( @@ -771,7 +823,7 @@ template data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, + scale, max_bias, m0, m1, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, @@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - const int32_t precision = KQV->op_params[1]; + const int32_t precision = KQV->op_params[2]; if (!fp16_mma_available(cc)) { GGML_ASSERT(precision == GGML_PREC_DEFAULT); diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 6ed225999bd..ca85285a3f4 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32(half val) { } template -static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -23,16 +23,16 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { const int h = rowx/nrows_y; // head index const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - slope = powf(base, exp); + slope = powf(base, exph); } extern __shared__ float data_soft_max_f32[]; @@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p const int64_t ix = (int64_t)rowx*ncols + col; const int64_t iy = (int64_t)rowy*ncols + col; - const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f); + const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p } template -static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } } void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * src2 = dst->src[2]; const float * src0_d = (const float *)src0->data; const void * src1_d = src1 ? (const void *)src1->data : nullptr; @@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - // positions tensor - void * src2_d = nullptr; - - const bool use_src2 = src2 != nullptr; - - if (use_src2) { - src2_d = (void *)src2->data; - } - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (use_f16) { const half * src1_dd = (const half *)src1_d; - const half * src2_dd = (const half *)src2_d; - soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } else { const float * src1_dd = (const float *)src1_d; - const float * src2_dd = (const float *)src2_d; - soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } } diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 9a469821d80..3f033d58be4 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml case GGML_OP_SOFT_MAX: { float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float max_bias; -#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support") + memcpy(&scale, (float *)dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float)); + +#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); - GGML_ASSERT(src2 == nullptr); + +#pragma message("TODO: add ALiBi support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192") + GGML_ASSERT(max_bias == 0.0f); ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); } break; diff --git a/ggml-metal.m b/ggml-metal.m index 26e01e415db..66c398d54fd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -170,7 +170,6 @@ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_ROPE_F32, GGML_METAL_KERNEL_TYPE_ROPE_F16, - GGML_METAL_KERNEL_TYPE_ALIBI_F32, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, @@ -625,7 +624,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); @@ -762,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_GROUP_NORM: return ctx->support_simdgroup_reduction; case GGML_OP_NORM: - case GGML_OP_ALIBI: case GGML_OP_ROPE: case GGML_OP_IM2COL: return true; @@ -1373,13 +1370,12 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_SOFT_MAX: { GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); int nth = 32; // SIMD width id pipeline = nil; - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { @@ -1410,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_y = src0->ne[1]; - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -1423,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute( } else { [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; } - if (id_src2) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:7]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:9]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:10]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -2241,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; - case GGML_OP_ALIBI: - { - GGML_ASSERT((src0t == GGML_TYPE_F32)); - - const int nth = MIN(1024, ne00); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - [encoder setBytes:&m1 length:sizeof( float) atIndex:19]; - [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; case GGML_OP_ROPE: { GGML_ASSERT(ne10 == ne02); @@ -2581,7 +2529,7 @@ static enum ggml_status ggml_metal_graph_compute( "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - const int64_t ne31 = src3 ? src3->ne[1] : 0; + //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); @@ -2593,7 +2541,16 @@ static enum ggml_status ggml_metal_graph_compute( const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float max_bias; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); id pipeline = nil; @@ -2630,34 +2587,37 @@ static enum ggml_status ggml_metal_graph_compute( } [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; - [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&scale length:sizeof( float) atIndex:26]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:27]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:28]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:29]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml-metal.metal b/ggml-metal.metal index b67d1882f00..f8b07400c9a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -363,7 +363,6 @@ template kernel void kernel_soft_max( device const char * src0, device const char * src1, - device const char * src2, device char * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -385,10 +384,9 @@ kernel void kernel_soft_max( device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; - device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { @@ -404,7 +402,7 @@ kernel void kernel_soft_max( float lmax = -INFINITY; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)); + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block @@ -429,7 +427,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val); + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; } @@ -468,7 +466,6 @@ template kernel void kernel_soft_max_4( device const char * src0, device const char * src1, - device const char * src2, device char * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -490,10 +487,9 @@ kernel void kernel_soft_max_4( device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; - device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - float slope = 0.0f; + float slope = 1.0f; if (max_bias > 0.0f) { const int64_t h = i02; @@ -508,7 +504,7 @@ kernel void kernel_soft_max_4( float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -534,7 +530,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -1602,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4( } } -kernel void kernel_alibi_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & m0, - constant float & m1, - constant int & n_heads_log2_floor, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - const int64_t k = i3*ne3 + i2; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = pow(m0, k + 1); - } else { - m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; - device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - const float src_v = *(device float *)(src_row + i00*nb00); - device float * dst_v = (device float *)(dst_row + i00*nb0); - *dst_v = i00 * m_k + src_v; - } -} - static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -2123,13 +2065,16 @@ typedef void (flash_attn_ext_f16_t)( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2161,13 +2106,16 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2264,6 +2212,19 @@ kernel void kernel_flash_attn_ext_f16( // prepare diagonal scale matrix simdgroup_float8x8 mscale(scale); + // prepare diagonal slope matrix + simdgroup_float8x8 mslope(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const short h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + mslope = simdgroup_float8x8(pow(base, exph)); + } + // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { @@ -2286,9 +2247,10 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } - // mqk = mqk*scale + mask + // mqk = mqk*scale + mask*slope simdgroup_half8x8 mm; simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply(mm, mslope, mm); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_store(mqk, ss + 8*cc, TF, 0, false); @@ -2479,13 +2441,16 @@ kernel void kernel_flash_attn_ext_vec_f16( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2504,6 +2469,18 @@ kernel void kernel_flash_attn_ext_vec_f16( const short T = D + 2*nsg*SH; // shared memory size per query in (half) + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const short h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix @@ -2610,10 +2587,10 @@ kernel void kernel_flash_attn_ext_vec_f16( mqk += simd_shuffle_down(mqk, 2); mqk += simd_shuffle_down(mqk, 1); - // mqk = mqk*scale + mask + // mqk = mqk*scale + mask*slope if (tiisg == 0) { float4 mm = (float4) mp4[ic/4 + cc]; - mqk = mqk*scale + mm; + mqk = mqk*scale + mm*slope; ss4[cc] = mqk; } @@ -2847,7 +2824,8 @@ kernel void kernel_cpy_f32_f16( for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; + // TODO: is there a better way to handle -INFINITY? + dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0]; } } diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 79aec4d9f02..e93d2af631c 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)( #define SYCL_SCALE_BLOCK_SIZE 256 #define SYCL_CLAMP_BLOCK_SIZE 256 #define SYCL_ROPE_BLOCK_SIZE 256 -#define SYCL_ALIBI_BLOCK_SIZE 32 #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32 #define SYCL_QUANTIZE_BLOCK_SIZE 256 #define SYCL_DEQUANTIZE_BLOCK_SIZE 256 @@ -9316,32 +9315,6 @@ static void rope_glm_f32( dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; } -static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, - const int n_heads_log2_floor, const float m0, const float m1, - const sycl::nd_item<3> &item_ct1) { - const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (col >= ncols) { - return; - } - - const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1); - const int i = row*ncols + col; - - const int k = row/k_rows; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = dpct::pow(m0, k + 1); - } else { - m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - dst[i] = col * m_k + x[i]; -} - static void k_sum_rows_f32(const float * x, float * dst, const int ncols, const sycl::nd_item<3> &item_ct1) { const int row = item_ct1.get_group(1); @@ -9443,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con template -static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par, +static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; @@ -9457,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos, const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { @@ -9482,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos, const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f); vals[col] = val; max_val = sycl::max(max_val, val); @@ -12964,20 +12937,6 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows, }); } -static void alibi_f32_sycl(const float *x, float *dst, const int ncols, - const int nrows, const int k_rows, - const int n_heads_log2_floor, const float m0, - const float m1, dpct::queue_ptr stream) { - const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE); - const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE); - const sycl::range<3> block_nums(1, nrows, num_blocks_x); - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - alibi_f32(x, dst, ncols, k_rows, - n_heads_log2_floor, m0, m1, item_ct1); - }); -} - static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, const int nrows, dpct::queue_ptr stream) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -13058,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst, } template -static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par, +static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const size_t n_local_scratch, dpct::queue_ptr stream) { @@ -13068,7 +13027,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - soft_max_f32(x, mask, pos, dst, ncols_par, + soft_max_f32(x, mask, dst, ncols_par, nrows_y, scale, max_bias, m0, m1, n_head_log2, item_ct1, local_buf_acc.get_pointer()); @@ -13076,7 +13035,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl }); } -static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos, +static void soft_max_f32_sycl(const float * x, const float * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, dpct::queue_ptr stream) { @@ -13098,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float * const size_t local_mem_size = stream->get_device().get_info(); if (n_local_scratch*sizeof(float) < local_mem_size) { if (ncols_x > max_block_size) { - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); return; } switch (ncols_x) { case 32: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 64: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 128: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 256: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 512: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 1024: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 2048: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 4096: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; default: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; } } else { - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, WARP_SIZE, stream); } @@ -14562,36 +14521,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne); - const int64_t nrows = ggml_nrows(src0); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - //GGML_ASSERT(ne01 + n_past == ne00); - GGML_ASSERT(n_head == ne02); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream); - - (void) src1; - (void) src1_dd; -} - static void ggml_sycl_op_pool2d(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, @@ -14746,12 +14675,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const ggml_tensor * src2 = dst->src[2]; - -#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support") +#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -14763,25 +14689,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - // positions tensor - float * src2_dd = nullptr; - sycl_pool_alloc src2_f; - - const bool use_src2 = src2 != nullptr; - - if (use_src2) { - const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU; - - if (src2_on_device) { - ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra; - src2_dd = (float *) src2_extra->data_device[g_main_device]; - } else { - src2_dd = src2_f.alloc(ggml_nelements(src2)); - SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream)); - } - } - - soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, + soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream); } @@ -16232,10 +16140,6 @@ static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, g ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope); } -static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi); -} - static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d); } @@ -16612,9 +16516,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_ROPE: func = ggml_sycl_rope; break; - case GGML_OP_ALIBI: - func = ggml_sycl_alibi; - break; case GGML_OP_IM2COL: func = ggml_sycl_im2col; break; @@ -17744,7 +17645,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: - case GGML_OP_ALIBI: case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 95f71897405..b9449be0357 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16); - if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_soft_max_f32; } if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { @@ -4286,6 +4285,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); +#pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192") + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, diff --git a/ggml.c b/ggml.c index 118d3f541f4..75621d3557b 100644 --- a/ggml.c +++ b/ggml.c @@ -2186,7 +2186,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SOFT_MAX_BACK", "ROPE", "ROPE_BACK", - "ALIBI", "CLAMP", "CONV_TRANSPOSE_1D", "IM2COL", @@ -2228,7 +2227,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2277,7 +2276,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "soft_max_back(x)", "rope(x)", "rope_back(x)", - "alibi(x)", "clamp(x)", "conv_transpose_1d(x)", "im2col(x)", @@ -2319,7 +2317,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5662,7 +5660,6 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias, bool inplace) { @@ -5676,18 +5673,8 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask->ne[1] >= a->ne[1]); } - if (pos) { - GGML_ASSERT(ggml_is_vector(pos)); - GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32); - GGML_ASSERT(pos->ne[0] == a->ne[0]); - } - - if (pos && mask) { - GGML_ASSERT(pos->type == mask->type); - } - if (max_bias > 0.0f) { - GGML_ASSERT(pos); + GGML_ASSERT(mask); } bool is_node = false; @@ -5705,7 +5692,6 @@ static struct ggml_tensor * ggml_soft_max_impl( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = mask; - result->src[2] = pos; return result; } @@ -5713,23 +5699,22 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * ggml_soft_max( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false); } struct ggml_tensor * ggml_soft_max_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true); } struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias) { - return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false); + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } // ggml_soft_max_back @@ -5944,37 +5929,6 @@ struct ggml_tensor * ggml_rope_back( return result; } -// ggml_alibi - -struct ggml_tensor * ggml_alibi( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_head, - float bias_max) { - GGML_ASSERT(n_past >= 0); - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - int32_t op_params[3] = { n_past, n_head }; - memcpy(op_params + 2, &bias_max, sizeof(float)); - ggml_set_op_params(result, op_params, sizeof(op_params)); - - result->op = GGML_OP_ALIBI; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - // ggml_clamp struct ggml_tensor * ggml_clamp( @@ -6502,9 +6456,11 @@ struct ggml_tensor * ggml_flash_attn_ext( struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor * mask, - float scale) { + float scale, + float max_bias) { GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) + if (mask) { GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); @@ -6514,6 +6470,10 @@ struct ggml_tensor * ggml_flash_attn_ext( //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); } + if (max_bias > 0.0f) { + GGML_ASSERT(mask); + } + bool is_node = false; if (q->grad || k->grad || v->grad) { @@ -6524,7 +6484,7 @@ struct ggml_tensor * ggml_flash_attn_ext( int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - float params[] = { scale }; + float params[] = { scale, max_bias }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_FLASH_ATTN_EXT; @@ -6544,7 +6504,7 @@ void ggml_flash_attn_ext_set_prec( const int32_t prec_i32 = (int32_t) prec; - ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos + ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second } // ggml_flash_ff @@ -13395,7 +13355,6 @@ static void ggml_compute_forward_soft_max_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src2 = dst->src[2]; assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); @@ -13421,8 +13380,8 @@ static void ggml_compute_forward_soft_max_f32( // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 - const uint32_t n_head_kv = ne02; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv)); + const uint32_t n_head = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -13439,13 +13398,13 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; - // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data; - float * pos_f32 = src2 ? (float *) src2->data : src0->data; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); for (int i1 = ir0; i1 < ir1; i1++) { + // ALiBi + const uint32_t h = (i1/ne01)%ne02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); @@ -13458,27 +13417,11 @@ static void ggml_compute_forward_soft_max_f32( if (mp_f32) { if (use_f16) { for (int i = 0; i < nc; ++i) { - wp[i] += GGML_FP16_TO_FP32(mp_f16[i]); + wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); } } else { for (int i = 0; i < nc; ++i) { - wp[i] += mp_f32[i]; - } - } - } - - // ALiBi bias - if (max_bias > 0.0f) { - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*pos_f32[i]; + wp[i] += slope*mp_f32[i]; } } } @@ -13640,178 +13583,6 @@ static void ggml_compute_forward_soft_max_back( } } -// ggml_compute_forward_alibi - -static void ggml_compute_forward_alibi_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int64_t ne1 = src0->ne[1]; // seq_len_without_past - const int64_t ne2 = src0->ne[2]; // n_head -> this is k - //const int64_t ne3 = src0->ne[3]; // 1 -> bsz - - const int64_t n = ggml_nrows(src0); - const int64_t ne2_ne3 = n/ne1; // ne2*ne3 - - const size_t nb0 = src0->nb[0]; - const size_t nb1 = src0->nb[1]; - const size_t nb2 = src0->nb[2]; - //const int nb3 = src0->nb[3]; - - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(n_head == ne2); - - // add alibi to src0 (KQ_scaled) - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - for (int64_t k = 0; k < ne2_ne3; k++) { - // TODO: k*nb2 or k*nb3 - float m_k; - - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - for (int64_t i = 0; i < ne0; i++) { - for (int64_t j = 0; j < ne1; j++) { - float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); - float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); - pdst[0] = i * m_k + src[0]; - } - } - } -} - -static void ggml_compute_forward_alibi_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int ne1 = src0->ne[1]; // seq_len_without_past - const int ne2 = src0->ne[2]; // n_head -> this is k - //const int ne3 = src0->ne[3]; // 1 -> bsz - - const int n = ggml_nrows(src0); - const int ne2_ne3 = n/ne1; // ne2*ne3 - - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; - //const int nb3 = src0->nb[3]; - - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; - GGML_ASSERT(n_head == ne2); - - // add alibi to src0 (KQ_scaled) - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - for (int k = 0; k < ne2_ne3; k++) { - // TODO: k*nb2 or k*nb3 - float m_k; - - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - for (int i = 0; i < ne0; i++) { - for (int j = 0; j < ne1; j++) { - ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); - float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); - - // we return F32 - pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); - } - } - } -} - -static void ggml_compute_forward_alibi( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_alibi_f16(params, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_alibi_f32(params, dst); - } break; - case GGML_TYPE_BF16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q8_K: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_I64: - case GGML_TYPE_F64: - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - // ggml_compute_forward_clamp static void ggml_compute_forward_clamp_f32( @@ -15825,8 +15596,17 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - float scale = 1.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { @@ -15835,6 +15615,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + const uint32_t h = iq2; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + float S = 0.0f; float M = -INFINITY; @@ -15858,7 +15641,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; } @@ -15929,7 +15712,7 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (dst->op_params[1]) { + switch (dst->op_params[2]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { @@ -17696,10 +17479,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rope_back(params, tensor); } break; - case GGML_OP_ALIBI: - { - ggml_compute_forward_alibi(params, tensor); - } break; case GGML_OP_CLAMP: { ggml_compute_forward_clamp(params, tensor); @@ -18718,10 +18497,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; - case GGML_OP_ALIBI: - { - GGML_ASSERT(false); // TODO: not implemented - } break; case GGML_OP_CLAMP: { GGML_ASSERT(false); // TODO: not implemented @@ -19499,10 +19274,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ { n_tasks = n_threads; } break; - case GGML_OP_ALIBI: - { - n_tasks = 1; //TODO - } break; case GGML_OP_CLAMP: { n_tasks = 1; //TODO diff --git a/ggml.h b/ggml.h index bc9efcf408d..c004a86fcf8 100644 --- a/ggml.h +++ b/ggml.h @@ -468,7 +468,6 @@ extern "C" { GGML_OP_SOFT_MAX_BACK, GGML_OP_ROPE, GGML_OP_ROPE_BACK, - GGML_OP_ALIBI, GGML_OP_CLAMP, GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, @@ -1437,15 +1436,13 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // fused soft_max(a*scale + mask + pos[i]*(ALiBi slope)) + // fused soft_max(a*scale + mask*(ALiBi slope)) // mask is optional - // pos is required when max_bias > 0.0f // max_bias = 0.0f for no ALiBi GGML_API struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias); @@ -1547,16 +1544,6 @@ extern "C" { float xpos_base, bool xpos_down); - // alibi position embedding - // in-place, returns view(a) - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_head, - float bias_max), - "use ggml_soft_max_ext instead (will be removed in Mar 2024)"); - // clamp // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_clamp( @@ -1753,7 +1740,8 @@ extern "C" { struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor * mask, - float scale); + float scale, + float max_bias); GGML_API void ggml_flash_attn_ext_set_prec( struct ggml_tensor * a, From accada542ae866ff615dd1f7ad4a87bd84851e57 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 11 May 2024 16:25:50 +0300 Subject: [PATCH 054/100] ggml : resolve merge (ggml/0) ggml-ci --- examples/common-ggml.cpp | 2 ++ ggml-metal.metal | 6 +++--- ggml.c | 5 ++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index 4ea8e44af02..d8dbc88a01e 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -71,6 +71,7 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_IQ4_NL: case GGML_FTYPE_MOSTLY_IQ4_XS: case GGML_FTYPE_MOSTLY_IQ1_M: + case GGML_FTYPE_MOSTLY_BF16: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -207,6 +208,7 @@ bool ggml_common_quantize_0( case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ1_M: + case GGML_TYPE_BF16: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); diff --git a/ggml-metal.metal b/ggml-metal.metal index f8b07400c9a..7af4e8f9342 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -296,7 +296,7 @@ kernel void kernel_silu( dst[tpig] = x / (1.0f + exp(-x)); } -+kernel void kernel_silu_4( +kernel void kernel_silu_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -2217,7 +2217,7 @@ kernel void kernel_flash_attn_ext_f16( // ALiBi if (max_bias > 0.0f) { - const short h = iq2; + const uint32_t h = iq2; const float base = h < n_head_log2 ? m0 : m1; const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; @@ -2473,7 +2473,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // ALiBi if (max_bias > 0.0f) { - const short h = iq2; + const uint32_t h = iq2; const float base = h < n_head_log2 ? m0 : m1; const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; diff --git a/ggml.c b/ggml.c index 75621d3557b..263073b1c3c 100644 --- a/ggml.c +++ b/ggml.c @@ -4,7 +4,6 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "ggml.h" -#include "sgemm.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -37,6 +36,10 @@ #undef GGML_USE_LLAMAFILE #endif +#ifdef GGML_USE_LLAMAFILE +#include "sgemm.h" +#endif + #if defined(_MSC_VER) // disable "possible loss of data" to avoid hundreds of casts // we should just be careful :) From 91c646c61df048b2a0263b8ccc69c09dbc41a55e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 11 May 2024 16:50:54 +0300 Subject: [PATCH 055/100] ggml : restore sigmoid decl order (ggml/0) --- ggml.c | 30 +++++++++++++++--------------- ggml.h | 8 ++++---- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/ggml.c b/ggml.c index 263073b1c3c..b96a82a4151 100644 --- a/ggml.c +++ b/ggml.c @@ -1951,8 +1951,8 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } -inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } +inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } // TODO: optimize performance inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } @@ -4545,20 +4545,6 @@ struct ggml_tensor * ggml_relu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); } -// ggml_sigmoid - -struct ggml_tensor * ggml_sigmoid( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID); -} - -struct ggml_tensor * ggml_sigmoid_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID); -} - // ggml_leaky_relu struct ggml_tensor * ggml_leaky_relu( @@ -4580,6 +4566,20 @@ struct ggml_tensor * ggml_leaky_relu( return result; } +// ggml_sigmoid + +struct ggml_tensor * ggml_sigmoid( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID); +} + +struct ggml_tensor * ggml_sigmoid_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID); +} + // ggml_gelu struct ggml_tensor * ggml_gelu( diff --git a/ggml.h b/ggml.h index c004a86fcf8..3fe95ed5763 100644 --- a/ggml.h +++ b/ggml.h @@ -1066,10 +1066,6 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - GGML_API struct ggml_tensor * ggml_sigmoid( - struct ggml_context * ctx, - struct ggml_tensor * a); - GGML_API struct ggml_tensor * ggml_leaky_relu( struct ggml_context * ctx, struct ggml_tensor * a, float negative_slope, bool inplace); @@ -1078,6 +1074,10 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sigmoid( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sigmoid_inplace( struct ggml_context * ctx, struct ggml_tensor * a); From 5a863fbe18990e5f2a63e8c96ca0c0bf42a46d2d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 11 May 2024 16:57:53 +0300 Subject: [PATCH 056/100] metal : fix indent (ggml/0) --- ggml-metal.m | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 66c398d54fd..28dec762a8a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1195,24 +1195,24 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_CLAMP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; - float min; - float max; - memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; - const int64_t n = ggml_nelements(dst); + const int64_t n = ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { // we are not taking into account the strides, so for now require contiguous tensors From 40aeeeecc4b8700b2a7e50cbcfa5c5412f2626ab Mon Sep 17 00:00:00 2001 From: Hong Bo PENG Date: Sun, 12 May 2024 17:17:18 +0800 Subject: [PATCH 057/100] ggml : optimize for ppc64le using VSX intrinsics (ggml/784) * optimize for ppc64le using VSX intrinsics * 1. code clean up by removing comments about overflow concern. 2. fix typo in suffix of scaling. * Continue to fix typo in suffix of scaling for QK_K <> 256 --------- Co-authored-by: Georgi Gerganov --- ggml-quants.c | 2169 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 2167 insertions(+), 2 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index 00334c5feb3..f711bd01341 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -241,7 +241,7 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #endif // __AVX__ || __AVX2__ || __AVX512F__ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -#if defined(__ARM_NEON) || defined(__wasm_simd128__) +#if defined(__ARM_NEON) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__) #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) @@ -643,6 +643,38 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) // store result __riscv_vse8_v_i8m1(y[i].qs , vs, vl); } +#elif defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + } #else GGML_UNUSED(nb); // scalar @@ -898,6 +930,46 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); y[i].s = GGML_FP32_TO_FP16(sum*d); } +#elif defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = GGML_FP32_TO_FP16(d); + + vector int accv = vec_splats(0); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + + accv = vec_add(accv, vi[j]); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + + accv = vec_add(accv, vec_sld(accv, accv, 4)); + accv = vec_add(accv, vec_sld(accv, accv, 8)); + y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); + } #else GGML_UNUSED(nb); // scalar @@ -3740,6 +3812,46 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (int i = 0; i < nb; i++) { + __builtin_prefetch(x[i].qs, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); + vector signed char q8y0 = vec_xl( 0, y[i].qs); + vector signed char q8y1 = vec_xl(16, y[i].qs); + + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_sub(q4x0, v8); + q4x1 = vec_sub(q4x1, v8); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + qv0 = vec_add(qv0, qv1); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else // scalar float sumf = 0.0; @@ -3958,6 +4070,46 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r } *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (int i = 0; i < nb; i++) { + __builtin_prefetch(x[i].qs, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].m)); + vector float vys = {GGML_FP16_TO_FP32(y[i].s), 0.0f, 0.0f, 0.0f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); + vector signed char q8y0 = vec_xl( 0, y[i].qs); + vector signed char q8y1 = vec_xl(16, y[i].qs); + + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + qv0 = vec_add(qv0, qv1); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else // scalar float sumf = 0.0; @@ -4243,6 +4395,49 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r } *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(x[i].qs, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[i].qh[0]]), (uint64_t)(table_b2b_1[x[i].qh[1]])}; + vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[i].qh[2]]), (uint64_t)(table_b2b_1[x[i].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); + + vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); + vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[i].qs); + vector signed char q8y1 = vec_xl( 16, y[i].qs); + + vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); + + qv0 = vec_add(qv0, qv1); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else // scalar float sumf = 0.0; @@ -4547,6 +4742,53 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r } *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(x[i].qs, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].m)); + vector float vys = {GGML_FP16_TO_FP32(y[i].s), 0.f, 0.f, 0.f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[i].qh[0]]), (uint64_t)(table_b2b_0[x[i].qh[1]])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[i].qh[2]]), (uint64_t)(table_b2b_0[x[i].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); + + vector signed char q5x0 = vec_or(vec_and(qxs, lowMask), qh0); + vector signed char q5x1 = vec_or(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[i].qs); + vector signed char q8y1 = vec_xl( 16, y[i].qs); + + vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); + + qv0 = vec_add(qv0, qv1); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else // scalar float sumf = 0.0; @@ -4722,6 +4964,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } *s = sumf; +#elif defined(__POWER9_VECTOR__) + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (int i = 0; i < nb; i++) { + __builtin_prefetch(x[i].qs, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char q8x0 = vec_xl( 0, x[i].qs); + vector signed char q8x1 = vec_xl(16, x[i].qs); + vector signed char q8y0 = vec_xl( 0, y[i].qs); + vector signed char q8y1 = vec_xl(16, y[i].qs); + + vector signed short qv0 = vec_mule(q8x0, q8y0); + vector signed short qv1 = vec_mulo(q8x0, q8y0); + vector signed short qv2 = vec_mule(q8x1, q8y1); + vector signed short qv3 = vec_mulo(q8x1, q8y1); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackh(qv1)); + vector signed int vsumi1 = vec_add(vec_unpackl(qv0), vec_unpackl(qv1)); + vector signed int vsumi2 = vec_add(vec_unpackh(qv2), vec_unpackh(qv3)); + vector signed int vsumi3 = vec_add(vec_unpackl(qv2), vec_unpackl(qv3)); + + vsumi0 = vec_add(vsumi0, vsumi2); + vsumi1 = vec_add(vsumi1, vsumi3); + + vsumi0 = vec_add(vsumi0, vsumi1); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else // scalar float sumf = 0.0; @@ -5077,6 +5358,147 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowScaleMask = vec_splats((signed char)0xF); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales); + vector signed char vscales = vec_and(q2xmins, lowScaleMask); + + q2xmins = vec_sr(q2xmins, v4); + vector signed short q2xmins0 = vec_unpackh(q2xmins); + vector signed short q2xmins1 = vec_unpackl(q2xmins); + + vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); + vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); + vector signed int prod2 = vec_mule(q2xmins1, q8ysums1); + vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q2); + vector signed char qxs1 = (vector signed char)vec_xl(16, q2); + q2 += 32; + + vector signed char q2x00 = vec_and(qxs0, lowMask); + vector signed char q2x01 = vec_and(vec_sr(qxs0, v2), lowMask); + vector signed char q2x02 = vec_and(vec_sr(qxs0, v4), lowMask); + vector signed char q2x03 = vec_and(vec_sr(qxs0, v6), lowMask); + vector signed char q2x10 = vec_and(qxs1, lowMask); + vector signed char q2x11 = vec_and(vec_sr(qxs1, v2), lowMask); + vector signed char q2x12 = vec_and(vec_sr(qxs1, v4), lowMask); + vector signed char q2x13 = vec_and(vec_sr(qxs1, v6), lowMask); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed short qv0 = vec_add(vec_mule(q2x00, q8y00), vec_mulo(q2x00, q8y00)); + vector signed short qv1 = vec_add(vec_mule(q2x01, q8y01), vec_mulo(q2x01, q8y01)); + vector signed short qv2 = vec_add(vec_mule(q2x02, q8y02), vec_mulo(q2x02, q8y02)); + vector signed short qv3 = vec_add(vec_mule(q2x03, q8y03), vec_mulo(q2x03, q8y03)); + vector signed short qv4 = vec_add(vec_mule(q2x10, q8y10), vec_mulo(q2x10, q8y10)); + vector signed short qv5 = vec_add(vec_mule(q2x11, q8y11), vec_mulo(q2x11, q8y11)); + vector signed short qv6 = vec_add(vec_mule(q2x12, q8y12), vec_mulo(q2x12, q8y12)); + vector signed short qv7 = vec_add(vec_mule(q2x13, q8y13), vec_mulo(q2x13, q8y13)); + + vector signed short vscales_h = vec_unpackh(vscales); + vector signed short vs0 = vec_splat(vscales_h, 0); + vector signed short vs1 = vec_splat(vscales_h, 1); + vector signed short vs2 = vec_splat(vscales_h, 2); + vector signed short vs3 = vec_splat(vscales_h, 3); + vector signed short vs4 = vec_splat(vscales_h, 4); + vector signed short vs5 = vec_splat(vscales_h, 5); + vector signed short vs6 = vec_splat(vscales_h, 6); + vector signed short vs7 = vec_splat(vscales_h, 7); + vscales = vec_sld(vscales, vscales, 8); + + qv0 = vec_mul(qv0, vs0); + qv1 = vec_mul(qv1, vs2); + qv2 = vec_mul(qv2, vs4); + qv3 = vec_mul(qv3, vs6); + + qv0 = vec_madd(qv4, vs1, qv0); + qv1 = vec_madd(qv5, vs3, qv1); + qv2 = vec_madd(qv6, vs5, qv2); + qv3 = vec_madd(qv7, vs7, qv3); + + vsumi0 = vec_add(vec_unpackh(qv0), vsumi0); + vsumi1 = vec_add(vec_unpackh(qv1), vsumi1); + vsumi2 = vec_add(vec_unpackh(qv2), vsumi2); + vsumi3 = vec_add(vec_unpackh(qv3), vsumi3); + + vsumi4 = vec_add(vec_unpackl(qv0), vsumi4); + vsumi5 = vec_add(vec_unpackl(qv1), vsumi5); + vsumi6 = vec_add(vec_unpackl(qv2), vsumi6); + vsumi7 = vec_add(vec_unpackl(qv3), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + #else float sumf = 0; @@ -5347,6 +5769,87 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowScaleMask = vec_splats((signed char)0xF); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + +#pragma GCC unroll 2 + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(x[i].qs, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl_len(y[i].bsums, 8); + + vector signed char q2xmins = (vector signed char)vec_xl_len(x[i].scales, 4); + vector signed char vscales = vec_and(q2xmins, lowScaleMask); + + q2xmins = vec_sr(q2xmins, v4); + vector signed short q2xmins0 = vec_unpackh((vector signed char)q2xmins); + + vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); + vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs); + vector signed char q2x00 = vec_and(qxs0, lowMask); + vector signed char q2x01 = vec_and(vec_sr(qxs0, v2), lowMask); + vector signed char q2x02 = vec_and(vec_sr(qxs0, v4), lowMask); + vector signed char q2x03 = vec_and(vec_sr(qxs0, v6), lowMask); + + vector signed char q8y00 = vec_xl( 0, y[i].qs); + vector signed char q8y01 = vec_xl( 16, y[i].qs); + vector signed char q8y02 = vec_xl( 32, y[i].qs); + vector signed char q8y03 = vec_xl( 48, y[i].qs); + + vector signed short qv0 = vec_add(vec_mule(q2x00, q8y00), vec_mulo(q2x00, q8y00)); + vector signed short qv1 = vec_add(vec_mule(q2x01, q8y01), vec_mulo(q2x01, q8y01)); + vector signed short qv2 = vec_add(vec_mule(q2x02, q8y02), vec_mulo(q2x02, q8y02)); + vector signed short qv3 = vec_add(vec_mule(q2x03, q8y03), vec_mulo(q2x03, q8y03)); + + vector signed short vscales_h = vec_unpackh(vscales); + vector signed short vs0 = vec_splat(vscales_h, 0); + vector signed short vs1 = vec_splat(vscales_h, 1); + vector signed short vs2 = vec_splat(vscales_h, 2); + vector signed short vs3 = vec_splat(vscales_h, 3); + + vector signed int vsumi0 = vec_add(vec_mule(qv0, vs0), vec_mulo(qv0, vs0)); + vector signed int vsumi1 = vec_add(vec_mule(qv1, vs1), vec_mulo(qv1, vs1)); + vector signed int vsumi2 = vec_add(vec_mule(qv2, vs2), vec_mulo(qv2, vs2)); + vector signed int vsumi3 = vec_add(vec_mule(qv3, vs3), vec_mulo(qv3, vs3)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + #else float sumf = 0; @@ -5841,6 +6344,160 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char v1 = vec_splats((signed char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + uint32_t aux[3]; + uint32_t utmp[4]; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + vector signed char vscales = (vector signed char)vec_xl( 0, utmp); + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); + + vscales = vec_sub(vscales, off); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q3); + vector signed char qxs1 = (vector signed char)vec_xl(16, q3); + q3 += 32; + + //the low 2 bits + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); + vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask); + vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask); + vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask); + vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask); + + //the 3rd bit + vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); + vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2); + vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); + vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2); + vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2); + vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2); + vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2); + vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2); + qxhs0 = vec_sr(qxhs0, v4); + qxhs1 = vec_sr(qxhs1, v4); + + vector signed char q3x00 = vec_sub(qxs00, qxh00); + vector signed char q3x01 = vec_sub(qxs01, qxh01); + vector signed char q3x02 = vec_sub(qxs02, qxh02); + vector signed char q3x03 = vec_sub(qxs03, qxh03); + vector signed char q3x10 = vec_sub(qxs10, qxh10); + vector signed char q3x11 = vec_sub(qxs11, qxh11); + vector signed char q3x12 = vec_sub(qxs12, qxh12); + vector signed char q3x13 = vec_sub(qxs13, qxh13); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed short vscales_h = vec_unpackh(vscales); + vector signed short vs0 = vec_splat(vscales_h, 0); + vector signed short vs1 = vec_splat(vscales_h, 1); + vector signed short vs2 = vec_splat(vscales_h, 2); + vector signed short vs3 = vec_splat(vscales_h, 3); + vector signed short vs4 = vec_splat(vscales_h, 4); + vector signed short vs5 = vec_splat(vscales_h, 5); + vector signed short vs6 = vec_splat(vscales_h, 6); + vector signed short vs7 = vec_splat(vscales_h, 7); + vscales = vec_sld(vscales, vscales, 8); + + vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); + vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); + vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02)); + vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03)); + vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); + vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); + vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); + vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); + + vector signed int vsum0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); + vector signed int vsum1 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2)); + vector signed int vsum2 = vec_add(vec_mule(qv02, vs4), vec_mulo(qv02, vs4)); + vector signed int vsum3 = vec_add(vec_mule(qv03, vs6), vec_mulo(qv03, vs6)); + vector signed int vsum4 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1)); + vector signed int vsum5 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3)); + vector signed int vsum6 = vec_add(vec_mule(qv12, vs5), vec_mulo(qv12, vs5)); + vector signed int vsum7 = vec_add(vec_mule(qv13, vs7), vec_mulo(qv13, vs7)); + + vsumi0 = vec_add(vsum0, vsumi0); + vsumi1 = vec_add(vsum1, vsumi1); + vsumi2 = vec_add(vsum2, vsumi2); + vsumi3 = vec_add(vsum3, vsumi3); + vsumi4 = vec_add(vsum4, vsumi4); + vsumi5 = vec_add(vsum5, vsumi5); + vsumi6 = vec_add(vsum6, vsumi6); + vsumi7 = vec_add(vsum7, vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else // scalar version // This function is written like this so the compiler can manage to vectorize most of it @@ -6207,6 +6864,95 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char v1 = vec_splats((signed char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x8); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + +#pragma GCC unroll 2 + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(x[i].qs, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + vector signed char vscales = (vector signed char)vec_xl_len(scales, 8); + vector signed char qxhs0 = (vector signed char)vec_xl_len(x[i].hmask, 8); + qxhs0 = vec_or(qxhs0, vec_sr(vec_sld(qxhs0, qxhs0, 8), (vector unsigned char)v1)); + + vscales = vec_sub(vscales, off); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs); + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); + vector signed char qxs10 = vec_and(vec_sr(qxs0, v4), lowMask); + vector signed char qxs11 = vec_and(vec_sr(qxs0, v6), lowMask); + + //the 3rd bit + vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); + vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); + vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v4)), v2); + vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v6)), v2); + qxhs0 = vec_sr(qxhs0, v4); + + vector signed char q3x00 = vec_sub(qxs00, qxh00); + vector signed char q3x01 = vec_sub(qxs01, qxh01); + vector signed char q3x10 = vec_sub(qxs10, qxh02); + vector signed char q3x11 = vec_sub(qxs11, qxh03); + + vector signed char q8y00 = vec_xl( 0, y[i].qs); + vector signed char q8y01 = vec_xl( 16, y[i].qs); + vector signed char q8y10 = vec_xl( 32, y[i].qs); + vector signed char q8y11 = vec_xl( 48, y[i].qs); + + vector signed short vscales_h = vec_unpackh(vscales); + vector signed short vs0 = vec_splat(vscales_h, 0); + vector signed short vs1 = vec_splat(vscales_h, 1); + vector signed short vs2 = vec_splat(vscales_h, 2); + vector signed short vs3 = vec_splat(vscales_h, 3); + + vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); + vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); + vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); + vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); + + vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); + vector signed int vsumi1 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1)); + vector signed int vsumi2 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2)); + vector signed int vsumi3 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else int8_t aux8[QK_K]; @@ -6559,6 +7305,142 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + memcpy(utmp, x[i].scales, 12); + + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vector signed char utmps = (vector signed char)vec_xl( 0, utmp); + vector signed short vscales = vec_unpackh(utmps); + vector signed short q4xmins = vec_unpackl(utmps); + vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); + vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins); + + vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q4xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; j+=2) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + vector signed char qxs2 = (vector signed char)vec_xl(32, q4); + vector signed char qxs3 = (vector signed char)vec_xl(48, q4); + q4 += 64; + + vector signed char q4x00 = vec_and(qxs0, lowMask); + vector signed char q4x01 = vec_sr(qxs0, v4); + vector signed char q4x10 = vec_and(qxs1, lowMask); + vector signed char q4x11 = vec_sr(qxs1, v4); + vector signed char q4x20 = vec_and(qxs2, lowMask); + vector signed char q4x21 = vec_sr(qxs2, v4); + vector signed char q4x30 = vec_and(qxs3, lowMask); + vector signed char q4x31 = vec_sr(qxs3, v4); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y20 = vec_xl( 64, q8); + vector signed char q8y30 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed short qv00 = vec_add(vec_mule(q4x00, q8y00), vec_mulo(q4x00, q8y00)); + vector signed short qv01 = vec_add(vec_mule(q4x01, q8y01), vec_mulo(q4x01, q8y01)); + vector signed short qv10 = vec_add(vec_mule(q4x10, q8y10), vec_mulo(q4x10, q8y10)); + vector signed short qv11 = vec_add(vec_mule(q4x11, q8y11), vec_mulo(q4x11, q8y11)); + vector signed short qv20 = vec_add(vec_mule(q4x20, q8y20), vec_mulo(q4x20, q8y20)); + vector signed short qv21 = vec_add(vec_mule(q4x21, q8y21), vec_mulo(q4x21, q8y21)); + vector signed short qv30 = vec_add(vec_mule(q4x30, q8y30), vec_mulo(q4x30, q8y30)); + vector signed short qv31 = vec_add(vec_mule(q4x31, q8y31), vec_mulo(q4x31, q8y31)); + + vector signed short vs0 = vec_splat(vscales, 0); + vector signed short vs1 = vec_splat(vscales, 1); + vector signed short vs2 = vec_splat(vscales, 2); + vector signed short vs3 = vec_splat(vscales, 3); + vscales = vec_sld(vscales, vscales, 8); + + qv00 = vec_add(qv00, qv10); + qv10 = vec_add(qv01, qv11); + qv20 = vec_add(qv20, qv30); + qv30 = vec_add(qv21, qv31); + + vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1); + vsumi2 = vec_add(vec_mule(qv10, vs1), vsumi2); + vsumi3 = vec_add(vec_mulo(qv10, vs1), vsumi3); + vsumi4 = vec_add(vec_mule(qv20, vs2), vsumi4); + vsumi5 = vec_add(vec_mulo(qv20, vs2), vsumi5); + vsumi6 = vec_add(vec_mule(qv30, vs3), vsumi6); + vsumi7 = vec_add(vec_mulo(qv30, vs3), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + #else @@ -6825,6 +7707,87 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + +#pragma GCC unroll 2 + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(x[i].qs, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d[1])); + vector float vyd = vec_splats(y[i].d); + vector float vd= vec_mul(vxd, vyd); + + uint16_t s16[2]; + const uint8_t * scales = (const uint8_t *)s16; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + vector signed char utmps = (vector signed char)vec_xl_len(scales, 4); + vector signed short vscales = (vector signed short)vec_unpackh(utmps); + vector signed short q4xmins0 = vec_mergeh(vscales, vscales); + q4xmins0 = vec_sld(q4xmins0, q4xmins0, 8); + + vector signed short q8ysums0 = vec_xl_len((const int16_t *)(y[i].bsums), 8); + + vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); + vector signed int prod1 = vec_mulo(q4xmins0, q8ysums0); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vd, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vd, vsumf1); + + vd = vec_mul(vyd, vec_splats(GGML_FP16_TO_FP32(x[i].d[0]))); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs); + vector signed char qxs1 = (vector signed char)vec_xl(16, x[i].qs); + vector signed char q4x00 = vec_and(qxs0, lowMask); + vector signed char q4x01 = vec_sr(qxs0, v4); + vector signed char q4x10 = vec_and(qxs1, lowMask); + vector signed char q4x11 = vec_sr(qxs1, v4); + + vector signed char q8y00 = vec_xl( 0, y[i].qs); + vector signed char q8y10 = vec_xl(16, y[i].qs); + vector signed char q8y01 = vec_xl(32, y[i].qs); + vector signed char q8y11 = vec_xl(48, y[i].qs); + + vector signed short qv00 = vec_add(vec_mule(q4x00, q8y00), vec_mulo(q4x00, q8y00)); + vector signed short qv01 = vec_add(vec_mule(q4x01, q8y01), vec_mulo(q4x01, q8y01)); + vector signed short qv10 = vec_add(vec_mule(q4x10, q8y10), vec_mulo(q4x10, q8y10)); + vector signed short qv11 = vec_add(vec_mule(q4x11, q8y11), vec_mulo(q4x11, q8y11)); + + vector signed short vs0 = vec_splat(vscales, 0); + vector signed short vs1 = vec_splat(vscales, 1); + + vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); + vector signed int vsumi1 = vec_add(vec_mule(qv10, vs0), vec_mulo(qv10, vs0)); + vector signed int vsumi2 = vec_add(vec_mule(qv01, vs1), vec_mulo(qv01, vs1)); + vector signed int vsumi3 = vec_add(vec_mule(qv11, vs1), vec_mulo(qv11, vs1)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + #else uint8_t aux8[QK_K]; @@ -7226,6 +8189,130 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf+sums; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v1 = vec_splats((unsigned char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + memcpy(utmp, x[i].scales, 12); + + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed char utmps = (vector signed char)vec_xl( 0, utmp); + vector signed short vscales = vec_unpackh(utmps); + + vector signed short q5xmins = vec_unpackl(utmps); + vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins); + vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins); + + vector signed int prod0 = vec_mule(q5xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q5xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q5, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q5); + vector signed char qxs1 = (vector signed char)vec_xl(16, q5); + q5 += 32; + + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + + vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4); + vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3); + vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4); + vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3); + qxhs0 = vec_sr(qxhs0, v2); + qxhs1 = vec_sr(qxhs1, v2); + + vector signed char q5x00 = vec_or(q5h00, qxs00); + vector signed char q5x01 = vec_or(q5h01, qxs01); + vector signed char q5x10 = vec_or(q5h10, qxs10); + vector signed char q5x11 = vec_or(q5h11, qxs11); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl(16, q8); + vector signed char q8y01 = vec_xl(32, q8); + vector signed char q8y11 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv00 = vec_add(vec_mule(q5x00, q8y00), vec_mulo(q5x00, q8y00)); + vector signed short qv01 = vec_add(vec_mule(q5x01, q8y01), vec_mulo(q5x01, q8y01)); + vector signed short qv10 = vec_add(vec_mule(q5x10, q8y10), vec_mulo(q5x10, q8y10)); + vector signed short qv11 = vec_add(vec_mule(q5x11, q8y11), vec_mulo(q5x11, q8y11)); + + vector signed short vs0 = vec_splat(vscales, 0); + vector signed short vs1 = vec_splat(vscales, 1); + vscales = vec_sld(vscales, vscales, 12); + + qv00 = vec_add(qv00, qv10); + qv01 = vec_add(qv01, qv11); + + vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1); + vsumi2 = vec_add(vec_mule(qv01, vs1), vsumi2); + vsumi3 = vec_add(vec_mulo(qv01, vs1), vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + #else const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -7523,6 +8610,83 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v1 = vec_splats((unsigned char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + +#pragma GCC unroll 2 + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(x[i].qs, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd= vec_mul(vxd, vyd); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs); + vector signed char qxs1 = (vector signed char)vec_xl(16, x[i].qs); + vector signed char qxs00 = (vector signed char)vec_and(qxs0, lowMask); + vector signed char qxs01 = (vector signed char)vec_sr(qxs0, v4); + vector signed char qxs10 = (vector signed char)vec_and(qxs1, lowMask); + vector signed char qxs11 = (vector signed char)vec_sr(qxs1, v4); + + vector signed char qxhs = (vector signed char)vec_xl_len(x[i].qh, 8); + vector signed char qxhs0 = vec_or(qxhs, vec_sr(vec_sld(qxhs, qxhs, 8), v1)); + vector signed char qxhs1 = vec_sr(qxhs0, v2); + vector signed char qxh00 = vec_sl(vec_andc((vector signed char)v1, qxhs0), v4); + vector signed char qxh10 = vec_sl(vec_andc((vector signed char)v1, qxhs1), v4); + vector signed char qxh01 = vec_sl(vec_andc((vector signed char)v1, vec_sr(qxhs0, v4)), v4); + vector signed char qxh11 = vec_sl(vec_andc((vector signed char)v1, vec_sr(qxhs1, v4)), v4); + + vector signed char q5x00 = vec_sub(qxs00, qxh00); + vector signed char q5x10 = vec_sub(qxs10, qxh10); + vector signed char q5x01 = vec_sub(qxs01, qxh01); + vector signed char q5x11 = vec_sub(qxs11, qxh11); + + vector signed char q8y00 = vec_xl( 0, y[i].qs); + vector signed char q8y10 = vec_xl(16, y[i].qs); + vector signed char q8y01 = vec_xl(32, y[i].qs); + vector signed char q8y11 = vec_xl(48, y[i].qs); + + vector signed short qv00 = vec_add(vec_mule(q5x00, q8y00), vec_mulo(q5x00, q8y00)); + vector signed short qv01 = vec_add(vec_mule(q5x01, q8y01), vec_mulo(q5x01, q8y01)); + vector signed short qv10 = vec_add(vec_mule(q5x10, q8y10), vec_mulo(q5x10, q8y10)); + vector signed short qv11 = vec_add(vec_mule(q5x11, q8y11), vec_mulo(q5x11, q8y11)); + + vector signed short vs = (vector signed short)vec_unpackh(vec_xl_len(x[i].scales, 4)); + vector signed short vs0 = vec_splat(vs, 0); + vector signed short vs1 = vec_splat(vs, 1); + vector signed short vs2 = vec_splat(vs, 2); + vector signed short vs3 = vec_splat(vs, 3); + + vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); + vector signed int vsumi1 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1)); + vector signed int vsumi2 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2)); + vector signed int vsumi3 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + #else int8_t aux8[QK_K]; @@ -7953,6 +9117,151 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict qs = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q6, 0, 0); + __builtin_prefetch(qh, 0, 0); + __builtin_prefetch(q8, 0, 0); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q6); + vector signed char qxs1 = (vector signed char)vec_xl(16, q6); + vector signed char qxs2 = (vector signed char)vec_xl(32, q6); + vector signed char qxs3 = (vector signed char)vec_xl(48, q6); + q6 += 64; + + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + vector signed char qxs20 = vec_and(qxs2, lowMask); + vector signed char qxs21 = vec_sr(qxs2, v4); + vector signed char qxs30 = vec_and(qxs3, lowMask); + vector signed char qxs31 = vec_sr(qxs3, v4); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, qh); + qh += 32; + + vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); + vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); + vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4); + vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4); + vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); + vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); + vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4); + vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4); + + vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); + vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); + vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); + vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); + vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off); + vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off); + vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off); + vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y20 = vec_xl( 32, q8); + vector signed char q8y30 = vec_xl( 48, q8); + vector signed char q8y01 = vec_xl( 64, q8); + vector signed char q8y11 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); + vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); + vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20)); + vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30)); + vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); + vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); + vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21)); + vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31)); + + vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8)); + qs += 8; + + vector signed short vs0 = vec_splat(vscales, 0); + vector signed short vs1 = vec_splat(vscales, 1); + vector signed short vs2 = vec_splat(vscales, 2); + vector signed short vs3 = vec_splat(vscales, 3); + vector signed short vs4 = vec_splat(vscales, 4); + vector signed short vs5 = vec_splat(vscales, 5); + vector signed short vs6 = vec_splat(vscales, 6); + vector signed short vs7 = vec_splat(vscales, 7); + + vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1); + vsumi2 = vec_add(vec_mule(qv01, vs4), vsumi2); + vsumi3 = vec_add(vec_mulo(qv01, vs4), vsumi3); + vsumi4 = vec_add(vec_mule(qv10, vs1), vsumi4); + vsumi5 = vec_add(vec_mulo(qv10, vs1), vsumi5); + vsumi6 = vec_add(vec_mule(qv11, vs5), vsumi6); + vsumi7 = vec_add(vec_mulo(qv11, vs5), vsumi7); + + vsumi0 = vec_add(vec_mule(qv20, vs2), vsumi0); + vsumi1 = vec_add(vec_mulo(qv20, vs2), vsumi1); + vsumi2 = vec_add(vec_mule(qv21, vs6), vsumi2); + vsumi3 = vec_add(vec_mulo(qv21, vs6), vsumi3); + vsumi4 = vec_add(vec_mule(qv30, vs3), vsumi4); + vsumi5 = vec_add(vec_mulo(qv30, vs3), vsumi5); + vsumi6 = vec_add(vec_mule(qv31, vs7), vsumi6); + vsumi7 = vec_add(vec_mulo(qv31, vs7), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + #else int8_t aux8[QK_K]; @@ -8259,6 +9568,85 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + +#pragma GCC unroll 2 + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(x[i].ql, 0, 1); + __builtin_prefetch(x[i].qh, 0, 1); + __builtin_prefetch(y[i].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd= vec_mul(vxd, vyd); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].ql); + vector signed char qxs1 = (vector signed char)vec_xl(16, x[i].ql); + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); + + vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); + vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); + vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); + vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); + + vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); + vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); + vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); + vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); + + vector signed char q8y00 = vec_xl( 0, y[i].qs); + vector signed char q8y10 = vec_xl(16, y[i].qs); + vector signed char q8y01 = vec_xl(32, y[i].qs); + vector signed char q8y11 = vec_xl(48, y[i].qs); + + vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); + vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); + vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); + vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); + + vector signed short vs = (vector signed short)vec_unpackh(vec_xl_len(x[i].scales, 4)); + vector signed short vs0 = vec_splat(vs, 0); + vector signed short vs1 = vec_splat(vs, 1); + vector signed short vs2 = vec_splat(vs, 2); + vector signed short vs3 = vec_splat(vs, 3); + + vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); + vector signed int vsumi1 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1)); + vector signed int vsumi2 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2)); + vector signed int vsumi3 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + #else int8_t aux8[QK_K]; @@ -8300,7 +9688,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r #endif -#if defined (__AVX2__) || defined (__ARM_NEON) +#if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -8433,6 +9821,103 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void *s = 0.125f * hsum_float_8(accumf); +#elif defined(__POWER9_VECTOR__) + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + memcpy(aux32, q2, 4*sizeof(uint32_t)); + q2 += 8; + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))}; + + vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = aux32[1] >> 28; + const uint16_t ls1 = aux32[3] >> 28; + + vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); + + vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); + vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); + vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); + vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); + vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); + vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); + vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); + vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); #else uint32_t aux32[2]; @@ -8708,6 +10193,104 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * *s = 0.125f * hsum_float_8(accumf); #endif +#elif defined(__POWER9_VECTOR__) + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + const uint16_t * restrict q2 = x[i].qs; + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))}; + q2 += 8; + + vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_add(vec_mule(qv0, vscales0), vsumi0); + vsumi1 = vec_add(vec_mule(qv1, vscales1), vsumi1); + vsumi2 = vec_add(vec_mule(qv2, vscales2), vsumi2); + vsumi3 = vec_add(vec_mule(qv3, vscales3), vsumi3); + vsumi4 = vec_add(vec_mulo(qv0, vscales0), vsumi4); + vsumi5 = vec_add(vec_mulo(qv1, vscales1), vsumi5); + vsumi6 = vec_add(vec_mulo(qv2, vscales2), vsumi6); + vsumi7 = vec_add(vec_mulo(qv3, vscales3), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); #else float sumf = 0.f; @@ -8908,6 +10491,124 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * *s = 0.125f * hsum_float_8(accumf); +#elif defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + const uint8_t * restrict q2 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))}; + q2 += 8; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0); + vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0); + vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1); + vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2); + vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_add(vec_mule(qv0, vscales0), vsumi0); + vsumi1 = vec_add(vec_mule(qv1, vscales1), vsumi1); + vsumi2 = vec_add(vec_mule(qv2, vscales2), vsumi2); + vsumi3 = vec_add(vec_mule(qv3, vscales3), vsumi3); + vsumi4 = vec_add(vec_mulo(qv0, vscales0), vsumi4); + vsumi5 = vec_add(vec_mulo(qv1, vscales1), vsumi5); + vsumi6 = vec_add(vec_mulo(qv2, vscales2), vsumi6); + vsumi7 = vec_add(vec_mulo(qv3, vscales3), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); #else float sumf = 0; @@ -9052,6 +10753,101 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void *s = 0.25f * hsum_float_8(accumf); +#elif defined(__POWER9_VECTOR__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + const uint8_t * restrict q3 = x[i].qs; + const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4); + const int8_t * restrict q8 = y[i].qs; + +#pragma GCC unroll 1 + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; + vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; + vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; + vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; + q3 += 16; + + vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >> 0) & 127]), (uint64_t)(signs64[(signs[0] >> 7) & 127])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])}; + vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >> 0) & 127]), (uint64_t)(signs64[(signs[1] >> 7) & 127])}; + vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])}; + + vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0); + vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1); + vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2); + vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(signs[0] >> 28); + const uint16_t ls1 = (uint16_t)(signs[1] >> 28); + signs += 2; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); + vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); + vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); + vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); + vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); + vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); + vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); + vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.25f * vec_extract(vsumf0, 0); #else uint32_t aux32; @@ -9279,6 +11075,124 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * *s = hsum_float_8(accumf); +#elif defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].signs); + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)], + iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]}; + vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)], + iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]}; + vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)], + iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]}; + vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)], + iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]}; + q3 += 16; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0); + vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0); + vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1); + vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2); + vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + sc ++; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); + vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); + vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); + vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); + vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); + vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); + vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); + vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else float sumf = 0.f; @@ -9433,6 +11347,113 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; +#elif defined(__POWER9_VECTOR__) + const vector unsigned char v0 = vec_splats((unsigned char)0x0); + const vector unsigned short vsign = vec_splats((unsigned short)0x8000); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi8 = vec_splats((int32_t)0); + + const uint8_t * restrict q1 = x[i].qs; + const uint16_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + const int16_t * restrict qs = y[i].bsums; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q1, 0, 1); + __builtin_prefetch(qh, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))}; + q1 += 8; + + vector signed char q1x0 = (vector signed char)aux64x2_0; + vector signed char q1x1 = (vector signed char)aux64x2_1; + vector signed char q1x2 = (vector signed char)aux64x2_2; + vector signed char q1x3 = (vector signed char)aux64x2_3; + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3)); + + const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7); + const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7); + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + vector signed short vscales = vec_sld(vscales23, vscales01, 8); + + vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); + vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); + vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); + vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); + vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); + vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); + vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); + vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + + vector signed short q8ysums = vec_xl_len(qs, 8); + qs += 4; + q8ysums = vec_mergeh(q8ysums, (vector signed short)v0); + + vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8); + qh += 2; + vector bool short vsel = vec_cmpge(qxh, (vector signed short)v0); + + vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel); + + vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + + vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else float sumf = 0; @@ -9789,6 +11810,51 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * *s = hsum_float_8(_mm256_add_ps(accum1, accum2)); +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + +#pragma GCC unroll 4 + for (int ib = 0; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_perm(values, values, (vector unsigned char)q4x0); + q4x1 = vec_perm(values, values, (vector unsigned char)q4x1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + vector signed int vsumi1 = vec_add(vec_unpackh(qv1), vec_unpackl(qv1)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + } + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else float sumf = 0; for (int ib = 0; ib < nb; ++ib) { @@ -9900,6 +11966,105 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * *s = hsum_float_8(accum); +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + + for (int ibl = 0; ibl < nb; ++ibl) { + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ibl].d)); + vector float vyd = vec_splats(y[ibl].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi4 = vec_splats((int32_t)0); + vector signed int vsumi5 = vec_splats((int32_t)0); + vector signed int vsumi6 = vec_splats((int32_t)0); + vector signed int vsumi7 = vec_splats((int32_t)0); + + uint16_t h = x[ibl].scales_h; + + const uint8_t * restrict q4 = x[ibl].qs; + const uint8_t * restrict sc = x[ibl].scales_l; + const int8_t * restrict q8 = y[ibl].qs; + + for (int ib = 0; ib < QK_K/64; ib ++ ) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + q4 += 32; + + vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask); + vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4); + vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask); + vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4); + + q4x00 = vec_perm(values, values, (vector unsigned char)q4x00); + q4x01 = vec_perm(values, values, (vector unsigned char)q4x01); + q4x10 = vec_perm(values, values, (vector unsigned char)q4x10); + q4x11 = vec_perm(values, values, (vector unsigned char)q4x11); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3)); + + const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32); + const uint16_t ls1 = (uint16_t)(((sc[0] >> 4) | ((h << 2) & 0x30)) - 32); + h >>= 4; + sc ++; + + vector signed short vscales01 = vec_splats((int16_t)ls0); + vector signed short vscales23 = vec_splats((int16_t)ls1); + + vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); + vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); + vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); + vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); + vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); + vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); + vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); + vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); #else float sumf = 0; for (int ibl = 0; ibl < nb; ++ibl) { From fe179ae0cc2e10d7878d0f1862d95eb403d18c4c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 May 2024 19:23:22 +0300 Subject: [PATCH 058/100] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index b4f391d50a9..0096c0b533a 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -4151b0420d7a8f4c3c1b420afa0f62ca441b9cd8 +9149580f5e15fa7510fa3413516fbf517cf2e921 From 3fa7d29876918970246ca4ae1bf49f5f03c8cb6c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 May 2024 20:12:46 +0300 Subject: [PATCH 059/100] talk-llama : sync llama.cpp --- examples/talk-llama/llama.cpp | 3638 ++++++++++++++++++++------ examples/talk-llama/llama.h | 163 +- examples/talk-llama/talk-llama.cpp | 4 +- examples/talk-llama/unicode-data.cpp | 1844 ++++++++----- examples/talk-llama/unicode-data.h | 7 +- examples/talk-llama/unicode.cpp | 639 ++++- examples/talk-llama/unicode.h | 9 +- 7 files changed, 4716 insertions(+), 1588 deletions(-) diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index 21772618487..e91ad7285da 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -75,6 +75,7 @@ #include #include #include +#include #include #include #include @@ -105,8 +106,7 @@ #endif #define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 8 - +#define LLAMA_MAX_EXPERTS 60 // // logging @@ -205,11 +205,14 @@ enum llm_arch { LLM_ARCH_REFACT, LLM_ARCH_BERT, LLM_ARCH_NOMIC_BERT, + LLM_ARCH_JINA_BERT_V2, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, LLM_ARCH_QWEN2, + LLM_ARCH_QWEN2MOE, LLM_ARCH_PHI2, + LLM_ARCH_PHI3, LLM_ARCH_PLAMO, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, @@ -220,39 +223,46 @@ enum llm_arch { LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, + LLM_ARCH_DBRX, + LLM_ARCH_OLMO, LLM_ARCH_UNKNOWN, }; static const std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GROK, "grok" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, - { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_PERSIMMON, "persimmon" }, - { LLM_ARCH_REFACT, "refact" }, - { LLM_ARCH_BERT, "bert" }, - { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, - { LLM_ARCH_BLOOM, "bloom" }, - { LLM_ARCH_STABLELM, "stablelm" }, - { LLM_ARCH_QWEN, "qwen" }, - { LLM_ARCH_QWEN2, "qwen2" }, - { LLM_ARCH_PHI2, "phi2" }, - { LLM_ARCH_PLAMO, "plamo" }, - { LLM_ARCH_CODESHELL, "codeshell" }, - { LLM_ARCH_ORION, "orion" }, - { LLM_ARCH_INTERNLM2, "internlm2" }, - { LLM_ARCH_MINICPM, "minicpm" }, - { LLM_ARCH_GEMMA, "gemma" }, - { LLM_ARCH_STARCODER2, "starcoder2" }, - { LLM_ARCH_MAMBA, "mamba" }, - { LLM_ARCH_XVERSE, "xverse" }, - { LLM_ARCH_COMMAND_R, "command-r" }, - { LLM_ARCH_UNKNOWN, "(unknown)" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_PERSIMMON, "persimmon" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, }; enum llm_kv { @@ -308,6 +318,7 @@ enum llm_kv { LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_PRE, LLM_KV_TOKENIZER_LIST, LLM_KV_TOKENIZER_TOKEN_TYPE, LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, @@ -318,11 +329,17 @@ enum llm_kv { LLM_KV_TOKENIZER_UNK_ID, LLM_KV_TOKENIZER_SEP_ID, LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_CLS_ID, + LLM_KV_TOKENIZER_MASK_ID, LLM_KV_TOKENIZER_ADD_BOS, LLM_KV_TOKENIZER_ADD_EOS, LLM_KV_TOKENIZER_ADD_PREFIX, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_PREFIX_ID, + LLM_KV_TOKENIZER_SUFFIX_ID, + LLM_KV_TOKENIZER_MIDDLE_ID, + LLM_KV_TOKENIZER_EOT_ID, }; static const std::map LLM_KV_NAMES = { @@ -378,6 +395,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, @@ -388,11 +406,17 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, }; struct LLM_KV { @@ -423,6 +447,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_OUT_NORM, LLM_TENSOR_ATTN_ROT_EMBD, LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, @@ -434,6 +459,9 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, @@ -665,6 +693,25 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_JINA_BERT_V2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_BLOOM, { @@ -696,6 +743,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, { @@ -731,6 +780,28 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_QWEN2MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_PHI2, { @@ -747,6 +818,23 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PHI3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_PLAMO, { @@ -926,6 +1014,38 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_DBRX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_OLMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, { @@ -1522,12 +1642,12 @@ struct llama_mlock { }; using llama_mlocks = std::vector>; -static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { +static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); GGML_ASSERT(check == -n_tokens); } else { @@ -1632,17 +1752,17 @@ static size_t llama_get_device_memory(int device) { #if defined(GGML_USE_CUDA) size_t total; size_t free; - ggml_backend_cuda_get_device_memory(device, &total, &free); + ggml_backend_cuda_get_device_memory(device, &free, &total); return free; #elif defined(GGML_USE_SYCL) size_t total; size_t free; - ggml_backend_sycl_get_device_memory(device, &total, &free); + ggml_backend_sycl_get_device_memory(device, &free, &total); return free; #elif defined(GGML_USE_VULKAN) size_t total; size_t free; - ggml_backend_vk_get_device_memory(device, &total, &free); + ggml_backend_vk_get_device_memory(device, &free, &total); return free; #else return 1; @@ -1684,6 +1804,7 @@ enum e_model { MODEL_4B, MODEL_7B, MODEL_8B, + MODEL_12B, MODEL_13B, MODEL_14B, MODEL_15B, @@ -1699,6 +1820,10 @@ enum e_model { MODEL_MEDIUM, MODEL_LARGE, MODEL_XL, + MODEL_A2_7B, + MODEL_8x7B, + MODEL_8x22B, + MODEL_16x12B, }; static const size_t kiB = 1024; @@ -1741,7 +1866,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool need_kq_pos = false; + bool use_alibi = false; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -1831,6 +1956,7 @@ struct llama_cparams { bool embeddings; bool causal_attn; bool offload_kqv; + bool flash_attn; enum llama_pooling_type pooling_type; @@ -1882,6 +2008,12 @@ struct llama_layer { struct ggml_tensor * ffn_down_exps; struct ggml_tensor * ffn_up_exps ; + // ff shared expert (shexp) + struct ggml_tensor * ffn_gate_inp_shexp; + struct ggml_tensor * ffn_gate_shexp; + struct ggml_tensor * ffn_down_shexp; + struct ggml_tensor * ffn_up_shexp; + // ff bias struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_up_b; // b3 @@ -1928,8 +2060,8 @@ struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; bool do_copy = false; - // with recurrent state models, a cell can hold the state for more than one past token - bool recurrent = false; + bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool v_trans = true; // the value tensor is transposed // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_internal also uses it, so it @@ -2006,7 +2138,8 @@ struct llama_vocab { ttype type; }; - enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; std::unordered_map token_to_id; std::vector id_to_token; @@ -2016,20 +2149,22 @@ struct llama_vocab { std::map, int> bpe_ranks; // default LLaMA special tokens - id special_bos_id = 1; - id special_eos_id = 2; - id special_unk_id = 0; - id special_sep_id = -1; - id special_pad_id = -1; + id special_bos_id = 1; + id special_eos_id = 2; + id special_unk_id = 0; + id special_sep_id = -1; + id special_pad_id = -1; + id special_cls_id = -1; + id special_mask_id = -1; int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add. int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add. id linefeed_id = 13; - id special_prefix_id = 32007; - id special_middle_id = 32009; - id special_suffix_id = 32008; - id special_eot_id = 32010; + id special_prefix_id = -1; + id special_suffix_id = -1; + id special_middle_id = -1; + id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token bool add_space_prefix = true; @@ -2177,7 +2312,7 @@ struct llama_context { std::vector output_ids; // map batch token positions to ids of the logits and embd buffers size_t output_size = 0; // capacity (of tokens positions) for the output buffers - int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch + int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch bool logits_all = false; @@ -2203,7 +2338,6 @@ struct llama_context { struct ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_out_ids; // I32 [n_outputs] struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] - struct ggml_tensor * inp_KQ_pos; // F32 [n_kv] struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] @@ -2225,11 +2359,14 @@ struct llama_context { static bool llama_kv_cache_init( struct llama_kv_cache & cache, - const llama_model & model, + const llama_context * ctx, ggml_type type_k, ggml_type type_v, uint32_t kv_size, bool offload) { + const llama_model & model = ctx->model; + const llama_cparams & cparams = ctx->cparams; + const struct llama_hparams & hparams = model.hparams; const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); @@ -2240,8 +2377,9 @@ static bool llama_kv_cache_init( // TODO: find a nicer way to add other recurrent model architectures cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.v_trans = !cparams.flash_attn; - // TODO: support mixed reccurent Transformer architectues + // TODO: support mixed recurrent Transformer architectures // NOTE: (!a || b) is a logical implication (a -> b) GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s()); GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s()); @@ -2452,6 +2590,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { } cache.head = 0; cache.used = 0; + + for (auto & buf : cache.bufs) { + ggml_backend_buffer_clear(buf, 0); + } } static bool llama_kv_cache_seq_rm( @@ -2772,6 +2914,7 @@ namespace GGUFMeta { case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool"; case LLAMA_KV_OVERRIDE_TYPE_INT: return "int"; case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float"; + case LLAMA_KV_OVERRIDE_TYPE_STR: return "str"; } return "unknown"; } @@ -2783,13 +2926,16 @@ namespace GGUFMeta { __func__, override_type_to_str(ovrd->tag), ovrd->key); switch (ovrd->tag) { case LLAMA_KV_OVERRIDE_TYPE_BOOL: { - LLAMA_LOG_INFO("%s\n", ovrd->bool_value ? "true" : "false"); + LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false"); } break; case LLAMA_KV_OVERRIDE_TYPE_INT: { - LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->int_value); + LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64); } break; case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { - LLAMA_LOG_INFO("%.6f\n", ovrd->float_value); + LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_STR: { + LLAMA_LOG_INFO("%s\n", ovrd->val_str); } break; default: // Shouldn't be possible to end up here, but just in case... @@ -2808,7 +2954,7 @@ namespace GGUFMeta { static typename std::enable_if::value, bool>::type try_override(OT & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) { - target = ovrd->bool_value; + target = ovrd->val_bool; return true; } return false; @@ -2818,7 +2964,7 @@ namespace GGUFMeta { static typename std::enable_if::value && std::is_integral::value, bool>::type try_override(OT & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) { - target = ovrd->int_value; + target = ovrd->val_i64; return true; } return false; @@ -2828,7 +2974,7 @@ namespace GGUFMeta { static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) { - target = ovrd->float_value; + target = ovrd->val_f64; return true; } return false; @@ -2837,12 +2983,11 @@ namespace GGUFMeta { template static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override * ovrd) { - (void)target; - (void)ovrd; - if (!ovrd) { return false; } - // Currently, we should never end up here so it would be a bug if we do. - throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n", - ovrd ? ovrd->key : "NULL")); + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) { + target = ovrd->val_str; + return true; + } + return false; } static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) { @@ -2875,6 +3020,7 @@ struct llama_model_loader { size_t n_bytes = 0; bool use_mmap = false; + bool check_tensors; llama_files files; llama_ftype ftype; @@ -2889,9 +3035,13 @@ struct llama_model_loader { ggml_tensor * tensor; - llama_tensor_weight(uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { + llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { const int tensor_idx = gguf_find_tensor(gguf_ctx, name); offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); + + if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) { + throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name)); + } } }; std::vector weights; @@ -2904,7 +3054,7 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) { + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -2930,15 +3080,15 @@ struct llama_model_loader { get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + files.emplace_back(new llama_file(fname.c_str(), "rb")); + contexts.emplace_back(ctx); + // Save tensors data offset of the main file. // For subsidiary files, `meta` tensor data offset must not be used, // so we build a unified tensors index for weights. for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - weights.emplace_back(0, cur->name, meta, cur); + weights.emplace_back(files.back().get(), 0, cur->name, meta, cur); } - files.emplace_back(new llama_file(fname.c_str(), "rb")); - contexts.emplace_back(ctx); - uint16_t n_split = 0; get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); @@ -2972,12 +3122,13 @@ struct llama_model_loader { throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path)); } + files.emplace_back(new llama_file(split_path, "rb")); + contexts.emplace_back(ctx); + // Save tensors data offset info of the shard. for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - weights.emplace_back(idx, cur->name, ctx_gguf, cur); + weights.emplace_back(files.back().get(), idx, cur->name, ctx_gguf, cur); } - files.emplace_back(new llama_file(split_path, "rb")); - contexts.emplace_back(ctx); gguf_free(ctx_gguf); } @@ -3000,9 +3151,17 @@ struct llama_model_loader { fver = (enum llama_fver) gguf_get_version(meta); + std::set tensor_names; for (auto & w : weights) { n_elements += ggml_nelements(w.tensor); n_bytes += ggml_nbytes(w.tensor); + // make sure there is no duplicated tensor names + const std::string name(w.tensor->name); + auto found = tensor_names.find(name); + if (found != tensor_names.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name)); + } + tensor_names.insert(name); } LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", @@ -3036,6 +3195,7 @@ struct llama_model_loader { switch (type_max) { case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; + case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break; case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; @@ -3108,6 +3268,7 @@ struct llama_model_loader { } this->use_mmap = use_mmap; + this->check_tensors = check_tensors; } ~llama_model_loader() { @@ -3187,6 +3348,10 @@ struct llama_model_loader { return nullptr; } + const llama_tensor_weight * get_weight(int i) const { + return get_weight(get_tensor_name(i)); + } + const llama_tensor_weight & require_weight(const char * name) const { const llama_tensor_weight * weight = get_weight(name); if (!weight) { @@ -3362,6 +3527,10 @@ struct llama_model_loader { file->seek(w.offs, SEEK_SET); file->read_raw(cur->data, ggml_nbytes(cur)); } + + if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } } size_t size_done = 0; @@ -3378,6 +3547,8 @@ struct llama_model_loader { GGML_ASSERT(size_data != 0 && "call init_mappings() first"); std::vector> read_buf; + std::vector>> validation_result; + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { const auto * weight = get_weight(ggml_get_name(cur)); if (weight == nullptr) { @@ -3399,37 +3570,66 @@ struct llama_model_loader { if (bufs_mmap.count(weight->idx)) { buf_mmap = bufs_mmap.at(weight->idx); } + uint8_t * data = (uint8_t *) mapping->addr + weight->offs; + + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size)); + })); + } + GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated if (buf_mmap && cur->data == nullptr) { - ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + weight->offs); + ggml_backend_tensor_alloc(buf_mmap, cur, data); if (lmlocks) { const auto & lmlock = lmlocks->at(weight->idx); - lmlock->grow_to(weight->offs + ggml_nbytes(cur)); + lmlock->grow_to(weight->offs + n_size); } auto & mmap_used = mmaps_used[weight->idx]; mmap_used.first = std::min(mmap_used.first, weight->offs); mmap_used.second = std::max(mmap_used.second, weight->offs + n_size); } else { - ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + weight->offs, 0, n_size); + ggml_backend_tensor_set(cur, data, 0, n_size); } } else { GGML_ASSERT(weight->idx < files.size()); const auto & file = files.at(weight->idx); if (ggml_backend_buffer_is_host(cur->buffer)) { file->seek(weight->offs, SEEK_SET); - file->read_raw(cur->data, ggml_nbytes(cur)); + file->read_raw(cur->data, n_size); + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size)); + })); + } } else { - read_buf.resize(ggml_nbytes(cur)); + read_buf.resize(n_size); file->seek(weight->offs, SEEK_SET); - file->read_raw(read_buf.data(), ggml_nbytes(cur)); + file->read_raw(read_buf.data(), n_size); ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } } } size_done += n_size; } + // check validation results + bool validation_failed = false; + for (auto & future : validation_result) { + auto result = future.get(); + if (!result.second) { + LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first)); + validation_failed = true; + } + } + if (validation_failed) { + throw std::runtime_error("found tensors with invalid data"); + } + // check if this is the last call and do final cleanup if (size_done >= size_data) { // unmap offloaded tensors and metadata @@ -3487,6 +3687,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { switch (ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: @@ -3535,6 +3736,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_3B: return "3B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; + case MODEL_12B: return "12B"; case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; case MODEL_15B: return "15B"; @@ -3550,6 +3752,10 @@ static const char * llama_model_type_name(e_model type) { case MODEL_MEDIUM: return "0.4B"; case MODEL_LARGE: return "0.8B"; case MODEL_XL: return "1.5B"; + case MODEL_A2_7B: return "A2.7B"; + case MODEL_8x7B: return "8x7B"; + case MODEL_8x22B: return "8x22B"; + case MODEL_16x12B: return "16x12B"; default: return "?B"; } } @@ -3593,6 +3799,12 @@ static void llm_load_hparams( // get hparams kv ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); + + // everything past this point is not vocab-related + if (hparams.vocab_only) { + return; + } + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); @@ -3664,15 +3876,23 @@ static void llm_load_hparams( { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 22: model.type = e_model::MODEL_1B; break; - case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - case 48: model.type = e_model::MODEL_34B; break; - case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + if (hparams.n_expert == 8) { + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_8x7B; break; + case 56: model.type = e_model::MODEL_8x22B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } else { + switch (hparams.n_layer) { + case 22: model.type = e_model::MODEL_1B; break; + case 26: model.type = e_model::MODEL_3B; break; + case 32: model.type = hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 48: model.type = e_model::MODEL_34B; break; + case 60: model.type = e_model::MODEL_30B; break; + case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } } } break; case LLM_ARCH_MINICPM: @@ -3768,6 +3988,19 @@ static void llm_load_hparams( model.type = e_model::MODEL_335M; break; // bge-large } } break; + case LLM_ARCH_JINA_BERT_V2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + hparams.f_max_alibi_bias = 8.0f; + + switch (hparams.n_layer) { + case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small + case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base + } + } break; case LLM_ARCH_NOMIC_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3814,6 +4047,7 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_3B; break; + case 40: model.type = e_model::MODEL_12B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -3838,10 +4072,28 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN2MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_A2_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_PHI2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_PHI3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_3B; break; @@ -3963,20 +4215,44 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_DBRX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_16x12B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_OLMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + + switch (hparams.n_layer) { + case 22: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 80: model.type = e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } model.ftype = ml.ftype; if (hparams.f_max_alibi_bias > 0.0f) { - hparams.need_kq_pos = true; + hparams.use_alibi = true; } hparams.rope_type = llama_rope_type(&model); } // TODO: This should probably be in llama.h -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false); +static std::vector llama_tokenize_internal( + const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special = false +); static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); static void llm_load_vocab( @@ -3990,39 +4266,92 @@ static void llm_load_vocab( // determine vocab type { - std::string tokenizer_name; + std::string tokenizer_model; + std::string tokenizer_pre; - ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name); + ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); + ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); - if (tokenizer_name == "no_vocab") { + if (tokenizer_model == "no_vocab") { vocab.type = LLAMA_VOCAB_TYPE_NONE; // default special tokens - vocab.special_bos_id = -1; - vocab.special_eos_id = -1; - vocab.special_unk_id = -1; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; - vocab.linefeed_id = -1; + vocab.special_bos_id = -1; + vocab.special_eos_id = -1; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; + vocab.linefeed_id = -1; return; - } else if (tokenizer_name == "llama") { + } else if (tokenizer_model == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; // default special tokens - vocab.special_bos_id = 1; - vocab.special_eos_id = 2; - vocab.special_unk_id = 0; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; + vocab.special_bos_id = 1; + vocab.special_eos_id = 2; + vocab.special_unk_id = 0; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; + + // For Fill-In-the-Middle (FIM)/infill models which where converted + // prior to support of FIM special tokens in GGUF, the following + // will allow those models to continue to work. The general names + // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and + // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once + // new versions of these models have been published. + std::string gen_name; + ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false); + + std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(), + [](unsigned char c){ return std::tolower(c); }); + + if (gen_name.find("code") != std::string::npos) { + if (model.arch == LLM_ARCH_LLAMA) { + vocab.special_prefix_id = 32007; + vocab.special_suffix_id = 32008; + vocab.special_middle_id = 32009; + vocab.special_eot_id = 32010; + } else if (model.arch == LLM_ARCH_GEMMA) { + vocab.special_prefix_id = 67; + vocab.special_suffix_id = 69; + vocab.special_middle_id = 68; + // TODO: this is not EOT, it is "file separator" token, needs fix + // https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572 + //vocab.special_eot_id = 70; + vocab.special_eot_id = 107; + } + } const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); if (add_space_prefix_keyidx != -1) { vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); } // The default value of add_space_prefix is true. - } else if (tokenizer_name == "gpt2") { - vocab.type = LLAMA_VOCAB_TYPE_BPE; + } else if (tokenizer_model == "bert") { + vocab.type = LLAMA_VOCAB_TYPE_WPM; + // default special tokens + vocab.special_bos_id = -1; + vocab.special_eos_id = -1; + vocab.special_unk_id = 100; + vocab.special_sep_id = 102; + vocab.special_pad_id = 0; + vocab.special_cls_id = 101; + vocab.special_mask_id = 103; + vocab.add_space_prefix = false; + } else { + if (tokenizer_model == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; + } else { + LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str()); + LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); + vocab.type = LLAMA_VOCAB_TYPE_SPM; + return; + } // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); if (merges_keyidx == -1) { @@ -4049,26 +4378,74 @@ static void llm_load_vocab( } // default special tokens - vocab.special_bos_id = 11; - vocab.special_eos_id = 11; - vocab.special_unk_id = -1; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; - } else if (tokenizer_name == "bert") { - vocab.type = LLAMA_VOCAB_TYPE_WPM; - - // default special tokens - vocab.special_bos_id = 101; - vocab.special_eos_id = 102; - vocab.special_unk_id = 100; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; - vocab.add_space_prefix = false; + vocab.special_bos_id = 11; + vocab.special_eos_id = 11; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; + } + + // for now, only BPE models have pre-tokenizers + if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { + if (tokenizer_pre.empty()) { + LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__); + LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "default") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "llama3" || + tokenizer_pre == "llama-v3" || + tokenizer_pre == "llama-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + } else if ( + tokenizer_pre == "deepseek-llm") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; + } else if ( + tokenizer_pre == "deepseek-coder") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; + } else if ( + tokenizer_pre == "falcon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; + } else if ( + tokenizer_pre == "mpt") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT; + } else if ( + tokenizer_pre == "starcoder") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; + } else if ( + tokenizer_pre == "gpt-2" || + tokenizer_pre == "jina-es" || + tokenizer_pre == "jina-de") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "refact") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT; + } else if ( + tokenizer_pre == "command-r") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + } else if ( + tokenizer_pre == "qwen2") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; + } else if ( + tokenizer_pre == "olmo") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO; + } else if ( + tokenizer_pre == "dbrx") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX; + } else { + throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); + } } else { - LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); - LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); - - vocab.type = LLAMA_VOCAB_TYPE_SPM; + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } } @@ -4125,12 +4502,19 @@ static void llm_load_vocab( // special tokens { const std::vector> special_token_types = { - { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, - { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, - { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, - { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, - { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, + { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, + { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, + { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id }, + { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id }, + { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id }, + { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id }, + { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id }, + { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, }; + for (const auto & it : special_token_types) { const std::string & key = kv(std::get<0>(it)); int32_t & id = std::get<1>(it); @@ -4145,7 +4529,6 @@ static void llm_load_vocab( } else { id = new_id; } - } // Handle add_bos_token and add_eos_token @@ -4159,6 +4542,28 @@ static void llm_load_vocab( vocab.special_add_eos = int(temp); } } + + // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. + // + // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID + // for now, we apply this workaround to find the EOT token based on its text + if (vocab.special_eot_id == -1) { + for (const auto & t : vocab.token_to_id) { + if ( + // TODO: gemma "" is exported as a normal token, so the following check does not work + // need to fix convert script + //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL && + (t.first == "<|eot_id|>" || + t.first == "<|im_end|>" || + t.first == "<|end|>" || + t.first == "" + ) + ) { + vocab.special_eot_id = t.second; + break; + } + } + } } // build special tokens cache @@ -4321,12 +4726,19 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); // special tokens - if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } - if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } - if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } - if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } - if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } - if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } + if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } + if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } + if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } + if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } + if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } + if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); } + if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); } + + if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } + if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); } + if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); } + if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); } + if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } } // Returns false if cancelled by progress_callback @@ -4344,6 +4756,13 @@ static bool llm_load_tensors( auto & hparams = model.hparams; +#ifdef GGML_USE_SYCL + // disable MoE with SYCL until mul_mat_id is updated + if (hparams.n_expert > 0) { + n_gpu_layers = 0; + } +#endif + model.split_mode = split_mode; model.main_gpu = main_gpu; model.n_gpu_layers = n_gpu_layers; @@ -4441,7 +4860,7 @@ static bool llm_load_tensors( size_t ctx_size = ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output // for moe merged tensors - ctx_size += ggml_tensor_overhead()*hparams.n_expert*n_layer; + ctx_size += ggml_tensor_overhead()*n_layer*3; std::map ctx_map; for (auto & it : buft_layer_count) { @@ -4637,6 +5056,39 @@ static bool llm_load_tensors( layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); } } break; + case LLM_ARCH_DBRX: + { + if (n_expert == 0) { + throw std::runtime_error("DBRX model cannot have zero experts"); + } + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + } + } break; case LLM_ARCH_BAICHUAN: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -4831,6 +5283,50 @@ static bool llm_load_tensors( layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); } } break; + case LLM_ARCH_JINA_BERT_V2: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // word_embeddings + model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); //token_type_embeddings + model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm + model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; // JinaBertLayer + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, false); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, false); + + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, false); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, false); + + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens + layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens + + layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm + layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + + layer.layer_out_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); + } + } break; case LLM_ARCH_BLOOM: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -4951,8 +5447,13 @@ static bool llm_load_tensors( layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false); layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + // optional q and k layernorms, present in StableLM 2 12B + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, false); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, false); + + // optional FFN norm, not present in StableLM 2 12B which uses parallel residual + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, false); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, false); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); @@ -4995,7 +5496,13 @@ static bool llm_load_tensors( // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); + } } for (int i = 0; i < n_layer; ++i) { @@ -5023,16 +5530,14 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; - case LLM_ARCH_PHI2: + case LLM_ARCH_QWEN2MOE: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); - model.output_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}); + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -5041,18 +5546,68 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, false); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, false); + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - if (layer.wqkv == nullptr) { - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + GGML_ASSERT(hparams.n_expert > 0); + GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + auto n_ff_exp = n_ff / hparams.n_expert_used; + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + + // Shared expert branch + layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}); + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff}); + } + } break; + case LLM_ARCH_PHI2: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, false); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, false); + + if (layer.wqkv == nullptr) { + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); @@ -5068,6 +5623,33 @@ static bool llm_load_tensors( layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; + case LLM_ARCH_PHI3: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context* ctx_layer = ctx_for_layer(i); + ggml_context* ctx_split = ctx_for_layer_split(i); + + auto& layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }); + } + } break; case LLM_ARCH_PLAMO: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -5406,11 +5988,47 @@ static bool llm_load_tensors( layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + if (n_layer >= 64){ + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}); + } + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; + case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); @@ -5566,7 +6184,7 @@ static bool llm_load_tensors( // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { try { - llama_model_loader ml(fname, params.use_mmap, params.kv_overrides); + llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides); model.hparams.vocab_only = params.vocab_only; @@ -5604,6 +6222,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam || !( model.ftype == LLAMA_FTYPE_ALL_F32 || model.ftype == LLAMA_FTYPE_MOSTLY_F16 || + model.ftype == LLAMA_FTYPE_MOSTLY_BF16 || model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 ) @@ -5695,37 +6314,47 @@ static struct ggml_tensor * llm_build_inp_embd( static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * k_cur, struct ggml_tensor * v_cur, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, const llm_build_cb & cb, int64_t il) { + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(kv.size == n_ctx); - // compute the transposed [n_tokens, n_embd] V matrix - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); - struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv.v_l[il]), - (kv_head)*ggml_element_size(kv.v_l[il])); + // note: storing RoPE-ed version of K in the KV cache + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + + assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); + + struct ggml_tensor * v_cache_view = nullptr; + + if (cparams.flash_attn) { + v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)); + } else { + // note: the V cache is transposed when not using flash attention + v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + ( n_ctx)*ggml_element_size(kv.v_l[il]), + (kv_head)*ggml_element_size(kv.v_l[il])); + + v_cur = ggml_transpose(ctx, v_cur); + } cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); } static struct ggml_tensor * llm_build_norm( @@ -5774,7 +6403,7 @@ static struct ggml_tensor * llm_build_ffn( llm_ffn_gate_type type_gate, const llm_build_cb & cb, int il) { - struct ggml_tensor * tmp = ggml_mul_mat(ctx, up, cur); + struct ggml_tensor * tmp = up ? ggml_mul_mat(ctx, up, cur) : cur; cb(tmp, "ffn_up", il); if (up_b) { @@ -5851,24 +6480,117 @@ static struct ggml_tensor * llm_build_ffn( return cur; } -// if max_alibi_bias > 0 then apply ALiBi +static struct ggml_tensor * llm_build_moe_ffn( + struct ggml_context * ctx, + struct ggml_tensor * cur, + struct ggml_tensor * gate_inp, + struct ggml_tensor * up_exps, + struct ggml_tensor * gate_exps, + struct ggml_tensor * down_exps, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + const llm_build_cb & cb, + int il) { + int64_t n_embd = cur->ne[0]; + int64_t n_tokens = cur->ne[1]; + + ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + cb(probs, "ffn_moe_probs", il); + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + cb(selected_experts, "ffn_moe_topk", il); + + ggml_tensor * weights = ggml_get_rows(ctx, + ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights", il); + + if (norm_w) { + weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); + + ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens] + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights_norm", il); + + weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens); + } + + cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); + ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); + + switch (type_op) { + case LLM_FFN_SILU: + { + gate = ggml_silu(ctx, gate); + cb(gate, "ffn_moe_silu", il); + } break; + case LLM_FFN_GELU: + { + gate = ggml_gelu(ctx, gate); + cb(gate, "ffn_moe_gelu", il); + } break; + default: + GGML_ASSERT(false); + } + + ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] + cb(par, "ffn_moe_gate_par", il); + + ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + cb(experts, "ffn_moe_down", il); + + experts = ggml_mul(ctx, experts, weights); + + // aggregate experts + ggml_tensor * moe_out = nullptr; + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens, + experts->nb[2], i*experts->nb[1]); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx, moe_out, cur_expert); + } + } + + if (n_expert_used == 1) { + // avoid returning a non-contiguous tensor + moe_out = ggml_cont(ctx, moe_out); + } + + return moe_out; +} + static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, struct ggml_tensor * wo_b, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, - struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t n_kv, float kq_scale, const llm_build_cb & cb, int il) { + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -5886,71 +6608,75 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + struct ggml_tensor * cur; - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + if (cparams.flash_attn) { + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); - //try from phi2 - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); - kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - kq = ggml_scale(ctx, kq, 30); - } + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + } -#if defined(GGML_USE_KOMPUTE) -#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") -#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.f_max_alibi_bias > 0.0f) { - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else -#endif - { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); + kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx, kq, 30); + } + + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); - } - GGML_ASSERT(kv.size == n_ctx); + GGML_ASSERT(kv.size == n_ctx); - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } ggml_build_forward_expand(graph, cur); @@ -5970,6 +6696,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -5978,8 +6705,6 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * v_cur, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, - struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, int32_t n_kv, @@ -5993,12 +6718,12 @@ static struct ggml_tensor * llm_build_kv( ggml_build_forward_expand(graph, k_cur); ggml_build_forward_expand(graph, v_cur); - llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); + llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); + cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, + q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -6040,6 +6765,8 @@ struct llm_build_context { const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_orig_ctx; + const bool flash_attn; + const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -6086,6 +6813,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), + flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -6102,18 +6830,17 @@ struct llm_build_context { ctx0 = ggml_init(params); - lctx.inp_tokens = nullptr; - lctx.inp_embd = nullptr; - lctx.inp_pos = nullptr; + lctx.inp_tokens = nullptr; + lctx.inp_embd = nullptr; + lctx.inp_pos = nullptr; lctx.inp_out_ids = nullptr; lctx.inp_KQ_mask = nullptr; - lctx.inp_KQ_pos = nullptr; lctx.inp_K_shift = nullptr; - lctx.inp_mean = nullptr; - lctx.inp_cls = nullptr; - lctx.inp_s_copy = nullptr; - lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; + lctx.inp_mean = nullptr; + lctx.inp_cls = nullptr; + lctx.inp_s_copy = nullptr; + lctx.inp_s_mask = nullptr; + lctx.inp_s_seq = nullptr; } void free() { @@ -6200,15 +6927,31 @@ struct llm_build_context { ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); - ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, i)); + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; - ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, id)); + if (flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, id)); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); @@ -6238,20 +6981,13 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } else { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return lctx.inp_KQ_mask; - } - - struct ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); - cb(lctx.inp_KQ_pos, "KQ_pos", -1); - ggml_set_input(lctx.inp_KQ_pos); - return lctx.inp_KQ_pos; + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } struct ggml_tensor * build_inp_mean() { @@ -6357,9 +7093,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6394,62 +7130,15 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] - cb(logits, "ffn_moe_logits", il); - - ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] - cb(probs, "ffn_moe_probs", il); - - // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] - cb(selected_experts->src[0], "ffn_moe_argsort", il); - - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); - cb(weights, "ffn_moe_weights", il); - - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] - - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); - cb(weights_sum, "ffn_moe_weights_sum", il); - - weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] - cb(weights, "ffn_moe_weights_norm", il); - - // compute expert outputs - ggml_tensor * moe_out = nullptr; - - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert; - - ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur); - cb(cur_up, "ffn_moe_up", il); - - ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur); - cb(cur_gate, "ffn_moe_gate", il); - - cur_gate = ggml_silu(ctx0, cur_gate); - cb(cur_gate, "ffn_moe_silu", il); - - cur_expert = ggml_mul(ctx0, cur_up, cur_gate); - cb(cur_expert, "ffn_moe_gate_par", il); - - cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd] - cb(cur_expert, "ffn_moe_down", il); - - cur_expert = ggml_mul(ctx0, cur_expert, - ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); - cb(cur_expert, "ffn_moe_weighted", il); - - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - cb(moe_out, "ffn_moe_out", il); - } - } - - cur = moe_out; + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + cb, il); + cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -6499,9 +7188,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -6544,9 +7230,9 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6616,9 +7302,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -6651,9 +7334,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6771,9 +7454,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -6896,9 +7579,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -6928,63 +7611,15 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] - cb(logits, "ffn_moe_logits", il); - - ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] - cb(probs, "ffn_moe_probs", il); - - // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] - cb(selected_experts->src[0], "ffn_moe_argsort", il); - - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); - cb(weights, "ffn_moe_weights", il); - - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] - - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); - cb(weights_sum, "ffn_moe_weights_sum", il); - - weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] - cb(weights, "ffn_moe_weights_norm", il); - - // compute expert outputs - ggml_tensor * moe_out = nullptr; - - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert; - - ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur); - cb(cur_up, "ffn_moe_up", il); - - ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur); - cb(cur_gate, "ffn_moe_gate", il); - - //GeLU - cur_gate = ggml_gelu(ctx0, cur_gate); - cb(cur_gate, "ffn_moe_gelu", il); - - cur_expert = ggml_mul(ctx0, cur_up, cur_gate); - cb(cur_expert, "ffn_moe_gate_par", il); - - cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd] - cb(cur_expert, "ffn_moe_down", il); - - cur_expert = ggml_mul(ctx0, cur_expert, - ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); - cb(cur_expert, "ffn_moe_weighted", il); - - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - cb(moe_out, "ffn_moe_out", il); - } - } - - cur = moe_out; + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_GELU, true, + cb, il); + cb(cur, "ffn_moe_out", il); // Grok // if layer_out_norm is present then apply it before adding the input @@ -6996,7 +7631,6 @@ struct llm_build_context { cb(cur, "layer_out_norm", il); } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -7032,12 +7666,16 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_starcoder() { + struct ggml_cgraph * build_dbrx() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -7050,59 +7688,183 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - struct ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); - cb(pos, "pos_embd", -1); - - inpL = ggml_add(ctx0, inpL, pos); - cb(inpL, "inpL", -1); - for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm cur = llm_build_norm(ctx0, inpL, hparams, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, cb, il); + model.layers[il].attn_norm, NULL, + LLM_NORM, cb, il); cb(cur, "attn_norm", il); // self-attention { + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); + cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(cur, "wqkv_clamped", il); - struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // add the input - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - // FF - { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, cb, il); + // feed-forward network + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].attn_out_norm, NULL, + LLM_NORM, cb, il); + cb(cur, "attn_out_norm", il); + + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + cb, il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_starcoder() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + struct ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, @@ -7285,9 +8047,9 @@ struct llm_build_context { ); cb(Vcur, "Vcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7353,9 +8115,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -7381,9 +8140,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7445,8 +8204,11 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; + struct ggml_tensor * inp_pos = nullptr; - struct ggml_tensor * inp_pos = build_inp_pos(); + if (model.arch != LLM_ARCH_JINA_BERT_V2) { + inp_pos = build_inp_pos(); + } struct ggml_tensor * inp_mean = build_inp_mean(); struct ggml_tensor * inp_cls = build_inp_cls(); @@ -7477,13 +8239,26 @@ struct llm_build_context { struct ggml_tensor * Vcur; // self-attention - if (model.arch == LLM_ARCH_BERT) { + if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); cb(Qcur, "Qcur", il); + if (model.layers[il].attn_q_norm) { + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, cb, il); + } + Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); cb(Kcur, "Kcur", il); + if (model.layers[il].attn_k_norm) { + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, cb, il); + } Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); cb(Vcur, "Vcur", il); @@ -7523,7 +8298,7 @@ struct llm_build_context { struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); cb(kq, "kq", il); - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); @@ -7574,6 +8349,13 @@ struct llm_build_context { model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, cb, il); } else { cur = llm_build_ffn(ctx0, cur, model.layers[il].ffn_up, NULL, @@ -7640,9 +8422,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, @@ -7674,9 +8453,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7741,9 +8520,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - if (model.pos_embd) { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -7805,14 +8581,15 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -7884,7 +8661,7 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; + // norm cur = llm_build_norm(ctx0, inpL, hparams, @@ -7893,6 +8670,8 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "attn_norm", il); + struct ggml_tensor * inpSA = cur; + // self-attention { // compute Q and K and RoPE them @@ -7917,43 +8696,69 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + cb(Kcur, "Kcur", il); + + if (model.layers[il].attn_q_norm) { + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + NULL, + LLM_NORM, cb, il); + cb(Qcur, "Qcur", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + NULL, + LLM_NORM, cb, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - + if (model.layers[il].ffn_norm) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + } else { + // parallel residual + cur = inpSA; + } cur = llm_build_ffn(ctx0, cur, model.layers[il].ffn_up, NULL, model.layers[il].ffn_gate, NULL, @@ -8044,9 +8849,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8143,12 +8948,6 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - // these nodes are added to the graph together so that they are not reordered - // by doing so, the number of splits in the graph is reduced - ggml_build_forward_expand(gf, Qcur); - ggml_build_forward_expand(gf, Kcur); - ggml_build_forward_expand(gf, Vcur); - Qcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, @@ -8163,9 +8962,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8215,16 +9014,17 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_phi2() { + struct ggml_cgraph * build_qwen2moe() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); struct ggml_tensor * cur; - struct ggml_tensor * attn_norm_output; - struct ggml_tensor * ffn_output; struct ggml_tensor * inpL; inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); @@ -8236,79 +9036,222 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); for (int il = 0; il < n_layer; ++il) { - attn_norm_output = llm_build_norm(ctx0, inpL, hparams, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, cb, il); - cb(attn_norm_output, "attn_norm", il); - - // self-attention - { - struct ggml_tensor * Qcur = nullptr; - struct ggml_tensor * Kcur = nullptr; - struct ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); + struct ggml_tensor * inpSA = inpL; - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - } else { - Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); - } + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + // self_attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, - freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - // with phi2, we scale the Q to avoid precision issues - // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 - Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); - cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, - freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // FF + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + cb, il); + cb(cur, "ffn_moe_out", il); + + // FFN shared expert { - ffn_output = llm_build_ffn(ctx0, attn_norm_output, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, + ggml_tensor * cur_gate_inp = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp_shexp, cur); + cb(cur_gate_inp, "ffn_shexp_gate_inp", il); + + // sigmoid + ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp); + cb(cur_gate, "ffn_shexp_gate", il); + + ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up_shexp, NULL, + model.layers[il].ffn_gate_shexp, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, - LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur_ffn, "ffn_shexp", il); + + ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate); + cb(ffn_shexp_out, "ffn_shexp_out", il); + + moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out); + cb(moe_out, "ffn_out", il); + + cur = moe_out; + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_phi2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * attn_norm_output; + struct ggml_tensor * ffn_output; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + attn_norm_output = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(attn_norm_output, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv) { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_custom( + ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + // with phi2, we scale the Q to avoid precision issues + // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 + Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids); + } + + // FF + { + ffn_output = llm_build_ffn(ctx0, attn_norm_output, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); cb(ffn_output, "ffn_out", il); } @@ -8332,12 +9275,140 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, model.output_b); cb(cur, "result_output", -1); + ggml_build_forward_expand(gf, cur); + return gf; + } + + struct ggml_cgraph * build_phi3() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + auto residual = inpL; + + // self-attention + { + struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(attn_norm_output, "attn_norm", il); + + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv) { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + } + else { + Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_custom( + ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor* inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, residual); + residual = cur; + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // FF + // special-case: the up and gate tensors are merged into a single tensor + // TOOD: support into llm_build_ffn + { + struct ggml_tensor* up = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + cb(up, "ffn_up", il); + + auto g = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), 0)); + auto y = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2)); + + y = ggml_mul(ctx0, y, ggml_silu(ctx0, g)); + cb(y, "ffn_gate", il); + + auto down = ggml_mul_mat(ctx0, model.layers[il].ffn_down, y); + cb(down, "ffn_down", il); + + cur = down; + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, residual, cur); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); return gf; } + struct ggml_cgraph * build_plamo() { struct ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -8390,9 +9461,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct ggml_tensor * sa_out = cur; @@ -8493,9 +9564,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8600,9 +9671,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8716,9 +9787,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8833,9 +9904,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8963,9 +10034,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9084,9 +10155,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9203,9 +10274,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9454,6 +10525,31 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } + if (model.layers[il].attn_q_norm) { + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur) * n_embd_head, + ggml_element_size(Qcur) * n_embd_head * n_head, + 0); + cb(Qcur, "Qcur", il); + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + NULL, + LLM_NORM, cb, il); + cb(Qcur, "Qcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + NULL, + LLM_NORM, cb, il); + cb(Kcur, "Kcur", il); + } + Qcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, @@ -9468,9 +10564,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9524,28 +10620,161 @@ struct llm_build_context { return gf; } -}; -static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { - llama_batch dummy; - dummy.n_tokens = 0; + // ref: https://allenai.org/olmo + // based on the original build_llama() function, changes: + // * non-parametric layer norm + // * clamp qkv + // * removed bias + // * removed MoE + struct ggml_cgraph * build_olmo() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; - struct llm_build_context llm(lctx, dummy, cb, false); + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); - llm.init(); + struct ggml_tensor * cur; + struct ggml_tensor * inpL; - struct ggml_cgraph * result = llm.build_defrag(ids); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); - llm.free(); + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); - return result; -} + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); -static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + NULL, NULL, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, nullptr, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + NULL, NULL, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + NULL, NULL, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } +}; + +static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { + llama_batch dummy; + dummy.n_tokens = 0; + + llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; + + struct llm_build_context llm(lctx, dummy, cb, false); + + llm.init(); + + struct ggml_cgraph * result = llm.build_defrag(ids); + + llm.free(); + + return result; +} + +static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { + llama_batch dummy; + dummy.n_tokens = 0; llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; @@ -9649,6 +10878,7 @@ static struct ggml_cgraph * llama_build_graph( result = llm.build_refact(); } break; case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: { result = llm.build_bert(); @@ -9673,10 +10903,18 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen2(); } break; + case LLM_ARCH_QWEN2MOE: + { + result = llm.build_qwen2moe(); + } break; case LLM_ARCH_PHI2: { result = llm.build_phi2(); } break; + case LLM_ARCH_PHI3: + { + result = llm.build_phi3(); + } break; case LLM_ARCH_PLAMO: { result = llm.build_plamo(); @@ -9721,6 +10959,14 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_command_r(); } break; + case LLM_ARCH_DBRX: + { + result = llm.build_dbrx(); + } break; + case LLM_ARCH_OLMO: + { + result = llm.build_olmo(); + } break; default: GGML_ASSERT(false); } @@ -9840,11 +11086,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { f = -INFINITY; } else { - f = 0.0f; + if (hparams.use_alibi) { + f = -fabs(lctx.kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } } data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } } } else { // when using kv cache, the mask needs to match the kv cache size @@ -9863,7 +11119,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float f = -INFINITY; for (int s = 0; s < batch.n_seq_id[i]; ++s) { if (batch.seq_id[i][s] == seq_id) { - f = 0.0f; + if (hparams.use_alibi) { + f = -fabs(batch.pos[i] - batch.pos[j]); + } else { + f = 0.0f; + } break; } } @@ -9879,19 +11139,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (hparams.need_kq_pos) { - const int64_t n_kv = kv_self.n; - - GGML_ASSERT(lctx.inp_KQ_pos); - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer)); - - float * data = (float *) lctx.inp_KQ_pos->data; - - for (int i = 0; i < n_kv; ++i) { - data[i] = float(lctx.kv_self.cells[i].pos); - } - } - if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = batch.n_tokens; @@ -10261,7 +11508,7 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } } @@ -10411,6 +11658,9 @@ static int llama_decode_internal( n_outputs_prev += lctx.n_outputs; } + // set to total number of outputs in the batch, for use in llama_get_logits_ith + lctx.n_outputs = n_outputs; + // wait for the computation to finish (automatically done when obtaining the model output) //llama_synchronize(&lctx); @@ -10426,6 +11676,10 @@ static int llama_decode_internal( } } + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(lctx.sched); + return 0; } @@ -10451,7 +11705,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // each move requires 6*n_layer tensors (see build_defrag) // - source view, destination view, copy operation // - x2 for keys and values - const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -10767,7 +12023,7 @@ static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(llama_is_byte_token(vocab, id)); - const auto& token_data = vocab.id_to_token.at(id); + const auto & token_data = vocab.id_to_token.at(id); switch (llama_vocab_get_type(vocab)) { case LLAMA_VOCAB_TYPE_SPM: { auto buf = token_data.text.substr(3, 2); @@ -10775,7 +12031,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { } case LLAMA_VOCAB_TYPE_BPE: { GGML_ASSERT(false); - return unicode_utf8_to_byte(token_data.text); + return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT? } case LLAMA_VOCAB_TYPE_WPM: { GGML_ASSERT(false); @@ -10997,7 +12253,101 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - auto word_collection = bpe_gpt2_preprocess(text); + bool ignore_merges = false; + + std::vector word_collection; + switch (vocab.type) { + case LLAMA_VOCAB_TYPE_BPE: + switch (vocab.type_pre) { + case LLAMA_VOCAB_PRE_TYPE_LLAMA3: + ignore_merges = true; + word_collection = unicode_regex_split(text, { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + + // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_DBRX: + word_collection = unicode_regex_split(text, { + // same as llama3 + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM: + word_collection = unicode_regex_split(text, { + "[\r\n]", + "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", + "\\s?[!-/:-~!-/:-~‘-‟ -。]+", + "\\s+$", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}+", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: + word_collection = unicode_regex_split(text, { + "[\r\n]", + "\\s?\\p{L}+", + "\\s?\\p{P}+", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_FALCON: + word_collection = unicode_regex_split(text, { + "[\\p{P}\\$\\+<=>\\^~\\|]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "[0-9][0-9][0-9]", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_MPT: + // TODO: MPT pre-tokenization regexes are unknown + // the following are close, but not exact. run the following: + // ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf + GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed"); + word_collection = unicode_regex_split(text, { + "\\s?\\p{L}+", + "\\s?\\p{P}+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_STARCODER: + case LLAMA_VOCAB_PRE_TYPE_REFACT: + case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: + word_collection = unicode_regex_split(text, { + "\\p{N}", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_GPT2: + case LLAMA_VOCAB_PRE_TYPE_OLMO: + word_collection = unicode_regex_split(text, { + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_QWEN2: + word_collection = unicode_regex_split(text, { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }); + break; + default: + // default regex for BPE tokenization pre-processing + word_collection = unicode_regex_split(text, { + "[\\p{P}\\$\\+<=>\\^~\\|]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "\\p{N}+", + "[0-9][0-9][0-9]", + }); + break; + } + break; + default: + GGML_ASSERT(false); + break; + } symbols_final.clear(); @@ -11008,6 +12358,11 @@ struct llm_tokenizer_bpe { int index = 0; size_t offset = 0; + if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + offset = word.size(); + } + while (offset < word.size()) { llm_symbol sym; size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset])); @@ -11054,7 +12409,7 @@ struct llm_tokenizer_bpe { add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol } - // add the fnished tokens to the final list keeping correct order for next and prev + // add the finished tokens to the final list keeping correct order for next and prev for (auto & sym : symbols) { if (sym.n > 0) { sym.prev = final_prev_index; @@ -11124,145 +12479,6 @@ struct llm_tokenizer_bpe { work_queue.push(bigram); } - std::vector bpe_gpt2_preprocess(const std::string & text) { - std::vector bpe_words; - std::vector bpe_encoded_words; - - std::string token = ""; - // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ - bool collecting_numeric = false; - bool collecting_letter = false; - bool collecting_special = false; - bool collecting_whitespace_lookahead = false; - bool collecting = false; - - std::vector text_utf; - text_utf.reserve(text.size()); - bpe_words.reserve(text.size()); - bpe_encoded_words.reserve(text.size()); - - const auto cpts = unicode_cpts_from_utf8(text); - for (size_t i = 0; i < cpts.size(); ++i) - text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i])); - - for (int i = 0; i < (int)text_utf.size(); i++) { - const std::string & utf_char = text_utf[i]; - bool split_condition = false; - int bytes_remain = text_utf.size() - i; - // forward backward lookups - const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; - const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : ""; - - // handling contractions - if (!split_condition && bytes_remain >= 2) { - // 's|'t|'m|'d - if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { - split_condition = true; - } - if (split_condition) { - if (token.size()) { - bpe_words.emplace_back(token); // push previous content as token - } - token = utf_char + utf_char_next; - bpe_words.emplace_back(token); - token = ""; - i++; - continue; - } - } - if (!split_condition && bytes_remain >= 3) { - // 're|'ve|'ll - if (utf_char == "\'" && ( - (utf_char_next == "r" && utf_char_next_next == "e") || - (utf_char_next == "v" && utf_char_next_next == "e") || - (utf_char_next == "l" && utf_char_next_next == "l")) - ) { - split_condition = true; - } - if (split_condition) { - // current token + next token can be defined - if (token.size()) { - bpe_words.emplace_back(token); // push previous content as token - } - token = utf_char + utf_char_next + utf_char_next_next; - bpe_words.emplace_back(token); // the contraction - token = ""; - i += 2; - continue; - } - } - - if (!split_condition && !collecting) { - if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { - collecting_letter = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - collecting_numeric = true; - collecting = true; - } - else if ( - ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || - (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) - ) { - collecting_special = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) { - collecting_whitespace_lookahead = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) { - split_condition = true; - } - } - else if (!split_condition && collecting) { - if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) { - split_condition = true; - } - else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) { - split_condition = true; - } - else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { - split_condition = true; - } - else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - split_condition = true; - } - } - - if (utf_char_next == "") { - split_condition = true; // final - token += utf_char; - } - - if (split_condition) { - if (token.size()) { - bpe_words.emplace_back(token); - } - token = utf_char; - collecting = false; - collecting_letter = false; - collecting_numeric = false; - collecting_special = false; - collecting_whitespace_lookahead = false; - } - else { - token += utf_char; - } - } - - for (std::string & word : bpe_words) { - std::string encoded_token = ""; - for (char & c : word) { - encoded_token += unicode_byte_to_utf8(c); - } - bpe_encoded_words.emplace_back(encoded_token); - } - - return bpe_encoded_words; - } - const llama_vocab & vocab; std::vector symbols; @@ -11323,9 +12539,6 @@ struct llm_tokenizer_wpm { output.push_back(vocab.special_unk_id); } } - - // append eos token - output.push_back(vocab.special_eos_id); } std::vector preprocess(const std::string & text) { @@ -11340,7 +12553,7 @@ struct llm_tokenizer_wpm { continue; } code = unicode_tolower(code); - if (type == CODEPOINT_TYPE_WHITESPACE) { + if (type == CODEPOINT_TYPE_SEPARATOR) { code = ' '; } std::string s = unicode_cpt_to_utf8(code); @@ -11530,30 +12743,28 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< } } -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) { +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) { std::vector output; - - // OG tokenizer behavior: - // - // tokenizer.encode('', add_bos=True) returns [1] - // tokenizer.encode('', add_bos=False) returns [] - - if (bos && vocab.special_bos_id != -1) { - output.push_back(vocab.special_bos_id); - } - - if (raw_text.empty()) { - return output; - } - std::forward_list fragment_buffer; - fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); - if (special) tokenizer_st_partition(vocab, fragment_buffer); + if (!raw_text.empty()) { + fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); + if (parse_special) tokenizer_st_partition(vocab, fragment_buffer); + } switch (vocab.type) { case LLAMA_VOCAB_TYPE_SPM: { + // OG tokenizer behavior: + // + // tokenizer.encode('', add_special_tokens=True) returns [1] + // tokenizer.encode('', add_special_tokens=False) returns [] + + if (add_special && vocab.special_add_bos != 0) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.push_back(vocab.special_bos_id); + } + for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { // without adding this leading whitespace, we do not get the same results as the original tokenizer @@ -11579,9 +12790,19 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(fragment.token); } } + + if (add_special && vocab.special_add_eos == 1) { + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); + } } break; case LLAMA_VOCAB_TYPE_BPE: { + if (add_special && vocab.special_add_bos != 0) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.push_back(vocab.special_bos_id); + } + for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -11595,9 +12816,19 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(fragment.token); } } + + if (add_special && vocab.special_add_eos == 1) { + GGML_ASSERT(vocab.special_add_eos != -1); + output.push_back(vocab.special_eos_id); + } } break; case LLAMA_VOCAB_TYPE_WPM: { + if (add_special) { + GGML_ASSERT(vocab.special_cls_id != -1); + output.push_back(vocab.special_cls_id); + } + for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -11611,6 +12842,11 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(fragment.token); } } + + if (add_special) { + GGML_ASSERT(vocab.special_sep_id != -1); + output.push_back(vocab.special_sep_id); + } } break; case LLAMA_VOCAB_TYPE_NONE: GGML_ASSERT(false); @@ -11777,7 +13013,9 @@ static void llama_grammar_advance_stack( std::vector> & new_stacks) { if (stack.empty()) { - new_stacks.emplace_back(stack); + if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + new_stacks.emplace_back(stack); + } return; } @@ -11814,7 +13052,10 @@ static void llama_grammar_advance_stack( } case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR_NOT: - new_stacks.emplace_back(stack); + if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + // only add the stack if it's not a duplicate of one we already have + new_stacks.emplace_back(stack); + } break; default: // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range @@ -11828,12 +13069,13 @@ static void llama_grammar_advance_stack( // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -std::vector> llama_grammar_accept( +void llama_grammar_accept( const std::vector> & rules, const std::vector> & stacks, - const uint32_t chr) { + const uint32_t chr, + std::vector> & new_stacks) { - std::vector> new_stacks; + new_stacks.clear(); for (const auto & stack : stacks) { if (stack.empty()) { @@ -11852,8 +13094,6 @@ std::vector> llama_grammar_accept( llama_grammar_advance_stack(rules, new_stack, new_stacks); } } - - return new_stacks; } static std::vector llama_grammar_reject_candidates( @@ -11867,6 +13107,7 @@ static std::vector llama_grammar_reject_candidates_for_ const std::vector & candidates) { std::vector rejects; + rejects.reserve(candidates.size()); if (stack.empty()) { for (const auto & tok : candidates) { @@ -11880,6 +13121,8 @@ static std::vector llama_grammar_reject_candidates_for_ const llama_grammar_element * stack_pos = stack.back(); std::vector next_candidates; + next_candidates.reserve(candidates.size()); + for (const auto & tok : candidates) { if (*tok.code_points == 0) { // reached end of full codepoints in token, reject iff it ended in a partial sequence @@ -12465,16 +13708,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c GGML_ASSERT(ctx); const int64_t t_start_sample_us = ggml_time_us(); - bool allow_eos = false; + bool allow_eog = false; for (const auto & stack : grammar->stacks) { if (stack.empty()) { - allow_eos = true; + allow_eog = true; break; } } - const llama_token eos = llama_token_eos(&ctx->model); - std::vector, llama_partial_utf8>> candidates_decoded; candidates_decoded.reserve(candidates->size); std::vector candidates_grammar; @@ -12482,9 +13723,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string piece = llama_token_to_piece(ctx, id); - if (id == eos) { - if (!allow_eos) { + const std::string piece = llama_token_to_piece(ctx, id, false); + + if (llama_token_is_eog(&ctx->model, id)) { + if (!allow_eog) { candidates->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { @@ -12647,7 +13889,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da return result; } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { +llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { GGML_ASSERT(ctx); const int64_t t_start_sample_us = ggml_time_us(); @@ -12660,7 +13902,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra } std::discrete_distribution<> dist(probs.begin(), probs.end()); - auto & rng = ctx->rng; int idx = dist(rng); llama_token result = candidates->data[idx].id; @@ -12670,10 +13911,14 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { + return llama_sample_token_with_rng(ctx, candidates, ctx->rng); +} + void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); - if (token == llama_token_eos(&ctx->model)) { + if (llama_token_is_eog(&ctx->model, token)) { for (const auto & stack : grammar->stacks) { if (stack.empty()) { return; @@ -12682,13 +13927,15 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar GGML_ASSERT(false); } - const std::string piece = llama_token_to_piece(ctx, token); + const std::string piece = llama_token_to_piece(ctx, token, false); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar->partial_utf8); const auto & code_points = decoded.first; + std::vector> tmp_new_stacks; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks); + grammar->stacks = tmp_new_stacks; } grammar->partial_utf8 = decoded.second; GGML_ASSERT(!grammar->stacks.empty()); @@ -12822,6 +14069,11 @@ struct llama_beam_search_data { } llama_logit_info logit_info(ctx); std::vector next_tokens = logit_info.top_k(n_beams); + + // Clear the kv slot so that other beams may try different tokens at this position. The llama_decode() + // call in loop() will conclusively fill in the kv slot once the beams converge at this position. + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + size_t i=0; if (next_beams.size() < n_beams) { for (; next_beams.size() < n_beams ; ++i) { @@ -12991,13 +14243,16 @@ static void llama_tensor_dequantize_internal( if (qtype.to_float == NULL) { throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type))); } - } else if (tensor->type != GGML_TYPE_F16) { + } else if (tensor->type != GGML_TYPE_F16 && + tensor->type != GGML_TYPE_BF16) { throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); } if (nthread < 2) { if (tensor->type == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); + } else if (tensor->type == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements); } else if (ggml_is_quantized(tensor->type)) { qtype.to_float(tensor->data, f32_output, nelements); } else { @@ -13006,7 +14261,14 @@ static void llama_tensor_dequantize_internal( return; } - size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type); + size_t block_size; + if (tensor->type == GGML_TYPE_F16 || + tensor->type == GGML_TYPE_BF16) { + block_size = 1; + } else { + block_size = (size_t)ggml_blck_size(tensor->type); + } + size_t block_size_bytes = ggml_type_size(tensor->type); GGML_ASSERT(nelements % block_size == 0); @@ -13025,6 +14287,8 @@ static void llama_tensor_dequantize_internal( auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { if (typ == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); + } else if (typ == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels); } else { qtype.to_float(inbuf, outbuf, nels); } @@ -13320,21 +14584,27 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n return new_type; } -static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int chunk_size, int nrows, int n_per_row, const float * imatrix, std::vector & workers, const int nthread) { - std::mutex mutex; - int counter = 0; - size_t new_size = 0; +static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { if (nthread < 2) { // single-thread - return ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); + size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); + if (!ggml_validate_row_data(new_type, new_data, new_size)) { + throw std::runtime_error("quantized data validation failed"); + } + return new_size; } - auto compute = [&mutex, &counter, &new_size, new_type, f32_data, new_data, chunk_size, + + std::mutex mutex; + int64_t counter = 0; + size_t new_size = 0; + bool valid = true; + auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size, nrows, n_per_row, imatrix]() { - const int nrows_per_chunk = chunk_size / n_per_row; + const int64_t nrows_per_chunk = chunk_size / n_per_row; size_t local_size = 0; while (true) { std::unique_lock lock(mutex); - int first_row = counter; counter += nrows_per_chunk; + int64_t first_row = counter; counter += nrows_per_chunk; if (first_row >= nrows) { if (local_size > 0) { new_size += local_size; @@ -13342,8 +14612,18 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa break; } lock.unlock(); - const int this_nrow = std::min(nrows - first_row, nrows_per_chunk); - local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix); + const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk); + size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix); + local_size += this_size; + + // validate the quantized data + const size_t row_size = ggml_row_size(new_type, n_per_row); + void * this_data = (char *) new_data + first_row * row_size; + if (!ggml_validate_row_data(new_type, this_data, this_size)) { + std::unique_lock lock(mutex); + valid = false; + break; + } } }; for (int it = 0; it < nthread - 1; ++it) { @@ -13352,6 +14632,9 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa compute(); for (auto & w : workers) { w.join(); } workers.clear(); + if (!valid) { + throw std::runtime_error("quantized data validation failed"); + } return new_size; } @@ -13366,6 +14649,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; // K-quants @@ -13414,7 +14698,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s auto v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, kv_overrides); + llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides); ml.init_mappings(false); // no prefetching llama_model model; @@ -13442,17 +14726,23 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s gguf_set_kv (ctx_out, ml.meta); gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); gguf_set_val_u32(ctx_out, "general.file_type", ftype); + // Remove split metadata + gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_NO).c_str()); + gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str()); + gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str()); if (params->kv_overrides) { const std::vector & overrides = *(const std::vector *)params->kv_overrides; for (auto & o : overrides) { if (o.key[0] == 0) break; if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { - gguf_set_val_f32(ctx_out, o.key, o.float_value); + gguf_set_val_f32(ctx_out, o.key, o.val_f64); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { - gguf_set_val_i32(ctx_out, o.key, o.int_value); + gguf_set_val_i32(ctx_out, o.key, o.val_i64); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { - gguf_set_val_bool(ctx_out, o.key, o.bool_value); + gguf_set_val_bool(ctx_out, o.key, o.val_bool); + } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { + gguf_set_val_str(ctx_out, o.key, o.val_str); } else { LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); } @@ -13465,7 +14755,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const std::string name = ggml_get_name(meta); // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { + if (name.find("attn_v.weight") != std::string::npos || + name.find("attn_qkv.weight") != std::string::npos) { ++qs.n_attention_wv; } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { qs.has_output = true; @@ -13475,7 +14766,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks - GGML_ASSERT(qs.n_attention_wv == (int)model.hparams.n_layer && "n_attention_wv != n_layer is unexpected"); + // + // - qs.n_attention_wv == 0 for Mamba models + // - qs.n_attention_wv == model.hparams.n_layer for Transformer models + // + GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer) && "n_attention_wv is unexpected"); size_t total_size_org = 0; size_t total_size_new = 0; @@ -13489,26 +14784,74 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector> work; std::vector> f32_conv_buf; + uint16_t n_split = 1; + // Assume split index is continuous + if (params->keep_split) { + for (int i = 0; i < ml.n_tensors; ++i) { + n_split = std::max(uint16_t(ml.get_weight(i)->idx+1), n_split); + } + } + std::vector ctx_outs(n_split, NULL); + ctx_outs[0] = ctx_out; + // populate the original tensors so we get an initial meta data for (int i = 0; i < ml.n_tensors; ++i) { - const struct ggml_tensor * meta = ml.get_tensor_meta(i); - gguf_add_tensor(ctx_out, meta); + auto weight = ml.get_weight(i); + uint16_t i_split = params->keep_split ? weight->idx : 0; + struct ggml_tensor * tensor = weight->tensor; + if (ctx_outs[i_split] == NULL) { + ctx_outs[i_split] = gguf_init_empty(); + } + gguf_add_tensor(ctx_outs[i_split], tensor); } - std::ofstream fout(fname_out, std::ios::binary); - fout.exceptions(std::ofstream::failbit); // fail fast on write errors - - const size_t meta_size = gguf_get_meta_size(ctx_out); - - LLAMA_LOG_INFO("%s: meta size = %zu bytes\n", __func__, meta_size); + // Set split info if needed + if (n_split > 1) { + for (size_t i = 0; i < ctx_outs.size(); ++i) { + gguf_set_val_u16(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i); + gguf_set_val_u16(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split); + gguf_set_val_i32(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors); + } + } - // placeholder for the meta data - ::zeros(fout, meta_size); + int cur_split = -1; + std::ofstream fout; + auto close_ofstream = [&]() { + // Write metadata and close file handler + if (fout.is_open()) { + fout.seekp(0); + std::vector data(gguf_get_meta_size(ctx_outs[cur_split])); + gguf_get_meta_data(ctx_outs[cur_split], data.data()); + fout.write((const char *) data.data(), data.size()); + fout.close(); + } + }; + auto new_ofstream = [&](int index) { + cur_split = index; + GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context"); + std::string fname = fname_out; + if (params->keep_split) { + char split_path[PATH_MAX] = {0}; + llama_split_path(split_path, sizeof(split_path), fname_out.c_str(), cur_split, n_split); + fname = std::string(split_path); + } - const auto tn = LLM_TN(model.arch); + fout = std::ofstream(fname, std::ios::binary); + fout.exceptions(std::ofstream::failbit); // fail fast on write errors + const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split]); + // placeholder for the meta data + ::zeros(fout, meta_size); + }; + const auto tn = LLM_TN(model.arch); + new_ofstream(0); for (int i = 0; i < ml.n_tensors; ++i) { - struct ggml_tensor * tensor = ml.get_tensor_meta(i); + auto weight = ml.get_weight(i); + struct ggml_tensor * tensor = weight->tensor; + if (weight->idx != cur_split && params->keep_split) { + close_ofstream(); + new_ofstream(weight->idx); + } const std::string name = ggml_get_name(tensor); @@ -13531,6 +14874,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // quantize only 2D and 3D tensors (experts) quantize &= (ggml_n_dims(tensor) >= 2); + + // do not quantize norm tensors + quantize &= name.find("_norm.weight") == std::string::npos; + quantize &= params->quantize_output_tensor || name != "output.weight"; quantize &= !params->only_copy; @@ -13559,10 +14906,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (!params->pure && ggml_is_quantized(default_type)) { new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); } - else if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { + if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; } - else if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { + if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { new_type = params->output_tensor_type; } @@ -13577,7 +14924,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_size = ggml_nbytes(tensor); LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); } else { - const size_t nelements = ggml_nelements(tensor); + const int64_t nelements = ggml_nelements(tensor); const float * imatrix = nullptr; if (imatrix_data) { @@ -13629,20 +14976,20 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); fflush(stdout); - if (work.size() < nelements * 4) { + if (work.size() < (size_t)nelements * 4) { work.resize(nelements * 4); // upper bound on size } new_data = work.data(); - const int n_per_row = tensor->ne[0]; - const int nrows = tensor->ne[1]; + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows = tensor->ne[1]; - static const int min_chunk_size = 32 * 512; - const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row); + static const int64_t min_chunk_size = 32 * 512; + const int64_t chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row); - const int nelements_matrix = tensor->ne[0] * tensor->ne[1]; - const int nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; - const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; + const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; + const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; + const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; // quantize each expert separately since they have different importance matrices new_size = 0; @@ -13659,26 +15006,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s total_size_new += new_size; // update the gguf meta data as we go - gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); + gguf_set_tensor_type(ctx_outs[cur_split], name.c_str(), new_type); + gguf_set_tensor_data(ctx_outs[cur_split], name.c_str(), new_data, new_size); // write tensor data + padding fout.write((const char *) new_data, new_size); zeros(fout, GGML_PAD(new_size, align) - new_size); } - - // go back to beginning of file and write the updated meta data - { - fout.seekp(0); - std::vector data(gguf_get_meta_size(ctx_out)); - gguf_get_meta_data(ctx_out, data.data()); - fout.write((const char *) data.data(), data.size()); + close_ofstream(); + for (auto & c:ctx_outs) { + gguf_free(c); } - fout.close(); - - gguf_free(ctx_out); - LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); @@ -13722,7 +15061,7 @@ static int llama_apply_lora_from_file_internal( std::unique_ptr ml; if (path_base_model) { LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); - ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ nullptr)); + ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr)); ml->init_mappings(/*prefetch*/ false); // no prefetching } @@ -13981,6 +15320,7 @@ struct llama_model_params llama_model_default_params() { /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, + /*.check_tensors =*/ false, }; #ifdef GGML_USE_METAL @@ -14017,6 +15357,7 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embeddings =*/ false, /*.offload_kqv =*/ true, + /*.flash_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -14034,6 +15375,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.quantize_output_tensor =*/ true, /*.only_copy =*/ false, /*.pure =*/ false, + /*.keep_split =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, }; @@ -14182,6 +15524,7 @@ struct llama_context * llama_new_context_with_model( cparams.defrag_thold = params.defrag_thold; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; + cparams.flash_attn = params.flash_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -14189,12 +15532,20 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; // this is necessary due to kv_self.n being padded later during inference - cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); // with causal attention, the batch size is limited by the context size cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + // ref: https://github.com/ggerganov/llama.cpp/pull/5021 + if (cparams.n_batch < GGML_KQ_MASK_PAD) { + LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); + cparams.n_batch = GGML_KQ_MASK_PAD; + } + + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : @@ -14226,6 +15577,11 @@ struct llama_context * llama_new_context_with_model( } } + if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); + cparams.flash_attn = false; + } + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } @@ -14233,6 +15589,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -14361,7 +15718,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -14514,6 +15871,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_JINA_BERT_V2: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -14527,18 +15885,22 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_MINICPM: case LLM_ARCH_XVERSE: case LLM_ARCH_COMMAND_R: + case LLM_ARCH_OLMO: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 case LLM_ARCH_FALCON: case LLM_ARCH_GROK: + case LLM_ARCH_DBRX: case LLM_ARCH_PERSIMMON: case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2MOE: case LLM_ARCH_PHI2: + case LLM_ARCH_PHI3: case LLM_ARCH_GEMMA: case LLM_ARCH_STARCODER2: return LLAMA_ROPE_TYPE_NEOX; @@ -14552,6 +15914,10 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { return LLAMA_ROPE_TYPE_NONE; } +enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { + return ctx->cparams.pooling_type; +} + int32_t llama_n_vocab(const struct llama_model * model) { return model->hparams.n_vocab; } @@ -14907,9 +16273,33 @@ void llama_kv_cache_update(struct llama_context * ctx) { llama_kv_cache_update_internal(*ctx); } +// deprecated +size_t llama_get_state_size(const struct llama_context * ctx) { + return llama_state_get_size(ctx); +} + +// deprecated +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + return llama_state_get_data(ctx, dst); +} + +// deprecated +size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { + return llama_state_set_data(ctx, src); +} + +// deprecated +bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); +} + +// deprecated +bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + return llama_state_save_file(ctx, path_session, tokens, n_token_count); +} // Returns the *maximum* size of the state -size_t llama_get_state_size(const struct llama_context * ctx) { +size_t llama_state_get_size(const struct llama_context * ctx) { const auto & cparams = ctx->cparams; const auto & hparams = ctx->model.hparams; @@ -14928,6 +16318,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); + const size_t s_v_trans = sizeof(uint32_t); const size_t s_kv = ctx->kv_self.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; @@ -14945,10 +16336,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) { + s_kv_head + s_kv_size + s_kv_used + + s_v_trans + s_kv + s_kv_cells ); + // on session change it is very likely that the state size has changed - so we need to update this function + static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?"); + return s_total; } @@ -14997,15 +16392,17 @@ struct llama_data_file_context : llama_data_context { * file context: * llama_file file("/path", "wb"); * llama_data_file_context data_ctx(&file); - * llama_copy_state_data(ctx, &data_ctx); + * llama_state_get_data(ctx, &data_ctx); * * buffer context: * std::vector buf(max_size, 0); * llama_data_buffer_context data_ctx(&buf.data()); - * llama_copy_state_data(ctx, &data_ctx); + * llama_state_get_data(ctx, &data_ctx); * */ -static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { +static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { + llama_synchronize(ctx); + // copy rng { std::ostringstream rng_ss; @@ -15092,11 +16489,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const uint32_t kv_size = kv_self.size; const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head; const uint32_t kv_used = kv_self.used; + const uint32_t v_trans = kv_self.v_trans ? 1 : 0; data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); data_ctx->write(&kv_head, sizeof(kv_head)); data_ctx->write(&kv_size, sizeof(kv_size)); data_ctx->write(&kv_used, sizeof(kv_used)); + data_ctx->write(&v_trans, sizeof(v_trans)); if (kv_buf_size) { const size_t pre_kv_buf_size = data_ctx->get_size_written(); @@ -15109,7 +16508,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -15149,15 +16548,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat } } -size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { +size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) { llama_data_buffer_context data_ctx(dst); - llama_copy_state_data_internal(ctx, &data_ctx); + llama_state_get_data_internal(ctx, &data_ctx); return data_ctx.get_size_written(); } // Sets the state reading from the specified source address -size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { +size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { + llama_synchronize(ctx); + const uint8_t * inp = src; // set rng @@ -15194,6 +16595,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch); ctx->output_ids[id] = i; } + + ctx->n_outputs = n_outputs; } } @@ -15238,11 +16641,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { uint32_t kv_head; uint32_t kv_size; uint32_t kv_used; + uint32_t v_trans; memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); + memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans); + + GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition if (kv_self.size != kv_size) { // the KV cache needs to be big enough to load all the KV cells from the saved state @@ -15252,6 +16659,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { __func__, kv_head, kv_size, kv_self.size); } + llama_kv_cache_clear(ctx); + if (kv_buf_size) { const size_t pre_kv_buf_size = inp - src; @@ -15263,7 +16672,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -15285,8 +16694,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); } - llama_kv_cache_clear(ctx); - ctx->kv_self.head = kv_head; ctx->kv_self.used = kv_used; @@ -15309,14 +16716,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { } const size_t nread = inp - src; - const size_t max_size = llama_get_state_size(ctx); + const size_t max_size = llama_state_get_size(ctx); GGML_ASSERT(nread <= max_size); return nread; } -static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(path_session, "rb"); // sanity checks @@ -15354,7 +16761,7 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c // restore the context state { const size_t n_state_size_cur = file.size - file.tell(); - const size_t n_state_size_max = llama_get_state_size(ctx); + const size_t n_state_size_max = llama_state_get_size(ctx); if (n_state_size_cur > n_state_size_max) { LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur); @@ -15364,22 +16771,22 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c std::vector state_data(n_state_size_max); file.read_raw(state_data.data(), n_state_size_cur); - llama_set_state_data(ctx, state_data.data()); + llama_state_set_data(ctx, state_data.data()); } return true; } -bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { try { - return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); + return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading session file: %s\n", err.what()); return false; } } -bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { +static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { llama_file file(path_session, "wb"); file.write_u32(LLAMA_SESSION_MAGIC); @@ -15393,11 +16800,479 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi // save the context state using stream saving llama_data_file_context data_ctx(&file); - llama_copy_state_data_internal(ctx, &data_ctx); + llama_state_get_data_internal(ctx, &data_ctx); return true; } +bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error saving session file: %s\n", err.what()); + return false; + } +} + +size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) { + // save the size of size_t as a uint32_t for safety check + const size_t size_t_size_size = sizeof(uint32_t); + + // other values + const size_t s_cell_count_size = sizeof(uint32_t); + const size_t s_layer_count_size = sizeof(uint32_t); + const size_t n_embd_v_gqa_size = sizeof(uint32_t); + + size_t s_cell_count = 0; + size_t s_cell_data_size = 0; + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; + + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + + for (uint32_t i = 0; i < kv_self.size; ++i) { + const auto & cell = kv_self.cells[i]; + if (cell.seq_id.count(seq_id) > 0) { + ++s_cell_count; + s_cell_data_size += sizeof(llama_pos); + } + } + + for (int il = 0; il < (int)n_layer; ++il) { + // types of keys and values + s_cell_data_size += sizeof(int32_t) * 2; + // k_size_row and v_size_el values of layer + s_cell_data_size += sizeof(size_t) * 2; + + // keys + const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + s_cell_data_size += k_size_row * s_cell_count; + + // values (transposed) + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa; + } + + const size_t s_total = ( + size_t_size_size + + s_cell_count_size + + s_layer_count_size + + n_embd_v_gqa_size + + s_cell_data_size + ); + + return s_total; +} + +static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) { + llama_synchronize(ctx); + + const auto & kv_self = ctx->kv_self; + GGML_ASSERT(!kv_self.recurrent); // not implemented + + // Save the size of size_t as a uint32_t for safety check + const uint32_t size_t_size = sizeof(size_t); + data_ctx.write(&size_t_size, sizeof(size_t_size)); + + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id + { + uint32_t cell_range_begin = kv_self.size; + for (uint32_t i = 0; i < kv_self.size; ++i) { + const auto & cell = kv_self.cells[i]; + if (cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == kv_self.size) { + cell_range_begin = i; + } + } + else { + if (cell_range_begin != kv_self.size) { + cell_ranges.push_back({ cell_range_begin, i }); + cell_range_begin = kv_self.size; + } + } + } + if (cell_range_begin != kv_self.size) { + cell_ranges.push_back({ cell_range_begin, kv_self.size }); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + } + + // Write the cell count + data_ctx.write(&cell_count, sizeof(cell_count)); + + const auto & hparams = ctx->model.hparams; + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + + // Write the layer count + data_ctx.write(&n_layer, sizeof(n_layer)); + + // Write n_embd_v_gqa + data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // Iterate the ranges and write all the pos (this is the token position in the prompt) + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = kv_self.cells[i]; + data_ctx.write(&cell.pos, sizeof(cell.pos)); + } + } + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + std::vector tmp_buf; + for (int il = 0; il < (int)n_layer; ++il) { + // Write key type + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + data_ctx.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + data_ctx.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + tmp_buf.resize(range_size * k_size_row); + ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); + } + } + + // TODO: simplify, reduce copy-paste + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + data_ctx.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + tmp_buf.resize(range_size * v_size_row); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); + } + } + } else { + // For the values, they are transposed, so we also need the element size and get the element ranges from each row + const uint32_t kv_size = kv_self.size; + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + data_ctx.write(&v_size_el, sizeof(v_size_el)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + tmp_buf.resize(range_size * v_size_el); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); + } + } + } + } + + return data_ctx.get_size_written(); +} + +size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) { + llama_data_buffer_context data_ctx(dst); + return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); +} + +size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) { + llama_synchronize(ctx); + + auto & kv_self = ctx->kv_self; + GGML_ASSERT(!kv_self.recurrent); // not implemented + + // Wipe the slot + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + + const uint8_t * inp = src; + + // Read size of size_t + uint32_t size_t_size; + memcpy(&size_t_size, inp, sizeof(size_t_size)); + inp += sizeof(size_t_size); + if (size_t_size != sizeof(size_t)) { + LLAMA_LOG_ERROR("%s: size_t size mismatch\n", __func__); + return 0; + } + + // Read the cell count + uint32_t cell_count; + memcpy(&cell_count, inp, sizeof(cell_count)); + inp += sizeof(cell_count); + + // Read the layer count + uint32_t n_layer_ref; + memcpy(&n_layer_ref, inp, sizeof(n_layer_ref)); + inp += sizeof(n_layer_ref); + + // Read n_embd_v_gqa + uint32_t n_embd_v_gqa_ref; + memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref)); + inp += sizeof(n_embd_v_gqa_ref); + + // Sanity check model compatibility + const auto & hparams = ctx->model.hparams; + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + if (n_layer != n_layer_ref) { + LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref); + return 0; + } + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref); + return 0; + } + + // Allocate the new cells for the slot + if (cell_count) { + llama_batch batch = llama_batch_init(cell_count, 0, 1); + batch.n_tokens = cell_count; + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + memcpy(&pos, inp, sizeof(pos)); + inp += sizeof(pos); + + batch.pos[i] = pos; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = dest_seq_id; + } + if (!llama_kv_cache_find_slot(kv_self, batch)) { + llama_batch_free(batch); + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return 0; + } + + // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); + GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); + + // Cleanup + llama_batch_free(batch); + } + + const uint32_t kv_size = kv_self.size; + const uint32_t kv_head = kv_self.head; + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of key + int32_t k_type_i_ref; + memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref)); + inp += sizeof(k_type_i_ref); + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + if (k_type_i != k_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return 0; + } + + // Read row size of key + size_t k_size_row_ref; + memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref)); + inp += sizeof(k_size_row_ref); + const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il); + return 0; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row); + inp += cell_count * k_size_row; + } + } + + // TODO: simplify, reduce copy-paste + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } + + // Read row size of value + size_t v_size_row_ref; + memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref)); + inp += sizeof(v_size_row_ref); + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); + return 0; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row); + inp += cell_count * v_size_row; + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } + + // Read element size of value + size_t v_size_el_ref; + memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); + inp += sizeof(v_size_el_ref); + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); + return 0; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); + inp += cell_count * v_size_el; + } + } + } + } + + const size_t nread = inp - src; + + return nread; +} + +static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); + + file.write_u32(LLAMA_STATE_SEQ_MAGIC); + file.write_u32(LLAMA_STATE_SEQ_VERSION); + + // save the prompt + file.write_u32((uint32_t)n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_file_context data_ctx(&file); + llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + + const size_t res = file.tell(); + GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); + return res; +} + +static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(filepath, "rb"); + + // version checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { + LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); + return 0; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return 0; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t state_size = file.size - file.tell(); + std::vector state_data(state_size); + file.read_raw(state_data.data(), state_size); + const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id); + if (!nread) { + LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); + return 0; + } + GGML_ASSERT(nread <= state_size); + GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); + } + + return file.tell(); +} + +size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what()); + return 0; + } +} + +size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what()); + return 0; + } +} + void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { ctx->cparams.n_threads = n_threads; ctx->cparams.n_threads_batch = n_threads_batch; @@ -15511,23 +17386,31 @@ float * llama_get_logits(struct llama_context * ctx) { } float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + int32_t j = -1; llama_synchronize(ctx); try { if (ctx->logits == nullptr) { throw std::runtime_error("no logits"); } - if ((size_t) i >= ctx->output_ids.size()) { + + if (i < 0) { + j = ctx->n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; } - const int32_t j = ctx->output_ids[i]; if (j < 0) { throw std::runtime_error(format("batch.logits[%d] != true", i)); } - if ((size_t) j >= ctx->output_size) { + if (j >= ctx->n_outputs) { // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size)); + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); } return ctx->logits + j*ctx->model.hparams.n_vocab; @@ -15547,23 +17430,32 @@ float * llama_get_embeddings(struct llama_context * ctx) { } float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { + int32_t j = -1; + llama_synchronize(ctx); try { if (ctx->embd == nullptr) { throw std::runtime_error("no embeddings"); } - if ((size_t) i >= ctx->output_ids.size()) { + + if (i < 0) { + j = ctx->n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; } - const int32_t j = ctx->output_ids[i]; if (j < 0) { throw std::runtime_error(format("batch.logits[%d] != true", i)); } - if ((size_t) j >= ctx->output_size) { + if (j >= ctx->n_outputs) { // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size)); + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); } return ctx->embd + j*ctx->model.hparams.n_embd; @@ -15602,6 +17494,13 @@ llama_token_type llama_token_get_type(const struct llama_model * model, llama_to return model->vocab.id_to_token[token].type; } +bool llama_token_is_eog(const struct llama_model * model, llama_token token) { + return token != -1 && ( + token == llama_token_eos(model) || + token == llama_token_eot(model) + ); +} + llama_token llama_token_bos(const struct llama_model * model) { return model->vocab.special_bos_id; } @@ -15610,6 +17509,14 @@ llama_token llama_token_eos(const struct llama_model * model) { return model->vocab.special_eos_id; } +llama_token llama_token_cls(const struct llama_model * model) { + return model->vocab.special_cls_id; +} + +llama_token llama_token_sep(const struct llama_model * model) { + return model->vocab.special_sep_id; +} + llama_token llama_token_nl(const struct llama_model * model) { return model->vocab.linefeed_id; } @@ -15644,9 +17551,9 @@ int32_t llama_tokenize( int32_t text_len, llama_token * tokens, int32_t n_tokens_max, - bool add_bos, - bool special) { - auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special); + bool add_special, + bool parse_special) { + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); @@ -15662,16 +17569,17 @@ int32_t llama_tokenize( static std::string llama_decode_text(const std::string & text) { std::string decoded_text; - auto unicode_sequences = unicode_cpts_from_utf8(text); - for (auto & unicode_sequence : unicode_sequences) { - decoded_text += unicode_utf8_to_byte(unicode_cpt_to_utf8(unicode_sequence)); + + const auto cpts = unicode_cpts_from_utf8(text); + for (const auto cpt : cpts) { + decoded_text += unicode_utf8_to_byte(unicode_cpt_to_utf8(cpt)); } return decoded_text; } // does not write null-terminator to buf -int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) { +int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) { if (0 <= token && token < llama_n_vocab(model)) { switch (llama_vocab_get_type(model->vocab)) { case LLAMA_VOCAB_TYPE_WPM: @@ -15686,7 +17594,9 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token } memcpy(buf, result.c_str(), result.length()); return result.length(); - } else if (llama_is_user_defined_token(model->vocab, token)) { + } else if ( + (llama_is_user_defined_token(model->vocab, token)) || + (llama_is_control_token (model->vocab, token) && special)) { std::string result = model->vocab.id_to_token[token].text; if (length < (int) result.length()) { return -(int) result.length(); @@ -15699,8 +17609,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token } memcpy(buf, "\xe2\x96\x85", 3); return 3; - } else if (llama_is_control_token(model->vocab, token)) { - ; } else if (llama_is_byte_token(model->vocab, token)) { if (length < 1) { return -1; @@ -15721,15 +17629,15 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token } memcpy(buf, result.c_str(), result.length()); return result.length(); - } else if (llama_is_user_defined_token(model->vocab, token)) { + } else if ( + (llama_is_user_defined_token(model->vocab, token)) || + (llama_is_control_token (model->vocab, token) && special)) { std::string result = model->vocab.id_to_token[token].text; if (length < (int) result.length()) { return -(int) result.length(); } memcpy(buf, result.c_str(), result.length()); return result.length(); - } else if (llama_is_control_token(model->vocab, token)) { - ; } break; } @@ -15912,6 +17820,39 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "### Response:\n"; } + } else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) { + // CohereForAI/c4ai-command-r-plus + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "user") { + ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "assistant") { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } + } + if (add_ass) { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; + } + } else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) { + // Llama 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>"; + } + if (add_ass) { + ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; + } + } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos )) { + // Phi 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << trim(message->content) << "<|end|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } } else { // template not supported return -1; @@ -15995,7 +17936,7 @@ struct llama_timings llama_get_timings(struct llama_context * ctx) { /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, /*.n_sample =*/ std::max(1, ctx->n_sample), - /*.n_p_eval =*/ std::max(1, ctx->n_p_eval), + /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), /*.n_eval =*/ std::max(1, ctx->n_eval), }; @@ -16044,6 +17985,11 @@ const char * llama_print_system_info(void) { s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | "; +#ifdef GGML_USE_LLAMAFILE + s += "LLAMAFILE = 1 | "; +#else + s += "LLAMAFILE = 0 | "; +#endif return s.c_str(); } diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 036b3268533..0b2e708d06d 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -37,9 +37,13 @@ #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' +#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 5 +#define LLAMA_SESSION_VERSION 6 + +#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ +#define LLAMA_STATE_SEQ_VERSION 1 #ifdef __cplusplus extern "C" { @@ -65,6 +69,23 @@ extern "C" { LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece }; + // pre-tokenization types + enum llama_vocab_pre_type { + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10, + LLAMA_VOCAB_PRE_TYPE_OLMO = 11, + LLAMA_VOCAB_PRE_TYPE_DBRX = 12, + }; + // note: these values should be synchronized with ggml_rope // TODO: maybe move this enum to ggml.h (ggml_rope_type) enum llama_rope_type { @@ -118,6 +139,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors + LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -155,7 +177,7 @@ extern "C" { bool sorted; } llama_token_data_array; - typedef bool (*llama_progress_callback)(float progress, void *ctx); + typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_decode // A llama_batch object can contain input about one or many sequences @@ -191,15 +213,19 @@ extern "C" { LLAMA_KV_OVERRIDE_TYPE_INT, LLAMA_KV_OVERRIDE_TYPE_FLOAT, LLAMA_KV_OVERRIDE_TYPE_BOOL, + LLAMA_KV_OVERRIDE_TYPE_STR, }; struct llama_model_kv_override { - char key[128]; enum llama_model_kv_override_type tag; + + char key[128]; + union { - int64_t int_value; - double float_value; - bool bool_value; + int64_t val_i64; + double val_f64; + bool val_bool; + char val_str[128]; }; }; @@ -228,9 +254,10 @@ extern "C" { const struct llama_model_kv_override * kv_overrides; // Keep the booleans together to avoid misalignment during copy-by-value. - bool vocab_only; // only load the vocabulary, no weights - bool use_mmap; // use mmap if possible - bool use_mlock; // force system to keep model in RAM + bool vocab_only; // only load the vocabulary, no weights + bool use_mmap; // use mmap if possible + bool use_mlock; // force system to keep model in RAM + bool check_tensors; // validate model tensor data }; struct llama_context_params { @@ -266,6 +293,7 @@ extern "C" { bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // whether to use flash attention // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -284,6 +312,7 @@ extern "C" { bool quantize_output_tensor; // quantize output.weight bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored bool pure; // quantize all tensors to the default type + bool keep_split; // quantize to the same number of shards void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides } llama_model_quantize_params; @@ -386,8 +415,10 @@ extern "C" { LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + + LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); @@ -518,11 +549,12 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache + // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_cache_clear( struct llama_context * ctx); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) @@ -594,35 +626,93 @@ extern "C" { // Returns the maximum size in bytes of the state (rng, logits, embedding // and kv_cache) - will often be smaller after compacting tokens - LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); + LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx); + LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx), + "use llama_state_get_size instead"); // Copies the state to the specified destination address. // Destination needs to have allocated enough memory. // Returns the number of bytes copied - LLAMA_API size_t llama_copy_state_data( + LLAMA_API size_t llama_state_get_data( struct llama_context * ctx, uint8_t * dst); + LLAMA_API DEPRECATED(size_t llama_copy_state_data( + struct llama_context * ctx, + uint8_t * dst), + "use llama_state_get_data instead"); // Set the state reading from the specified address // Returns the number of bytes read - LLAMA_API size_t llama_set_state_data( + LLAMA_API size_t llama_state_set_data( struct llama_context * ctx, const uint8_t * src); + LLAMA_API DEPRECATED(size_t llama_set_state_data( + struct llama_context * ctx, + const uint8_t * src), + "use llama_state_set_data instead"); // Save/load session file - LLAMA_API bool llama_load_session_file( + LLAMA_API bool llama_state_load_file( struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); + LLAMA_API DEPRECATED(bool llama_load_session_file( + struct llama_context * ctx, + const char * path_session, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out), + "use llama_state_load_file instead"); - LLAMA_API bool llama_save_session_file( + LLAMA_API bool llama_state_save_file( + struct llama_context * ctx, + const char * path_session, + const llama_token * tokens, + size_t n_token_count); + LLAMA_API DEPRECATED(bool llama_save_session_file( struct llama_context * ctx, const char * path_session, + const llama_token * tokens, + size_t n_token_count), + "use llama_state_save_file instead"); + + // Get the exact size needed to copy the KV cache of a single sequence + LLAMA_API size_t llama_state_seq_get_size( + struct llama_context * ctx, + llama_seq_id seq_id); + + // Copy the KV cache of a single sequence into the specified buffer + LLAMA_API size_t llama_state_seq_get_data( + struct llama_context * ctx, + uint8_t * dst, + llama_seq_id seq_id); + + // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence + // Returns: + // - Positive: Ok + // - Zero: Failed to load + LLAMA_API size_t llama_state_seq_set_data( + struct llama_context * ctx, + const uint8_t * src, + llama_seq_id dest_seq_id); + + LLAMA_API size_t llama_state_seq_save_file( + struct llama_context * ctx, + const char * filepath, + llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count); + LLAMA_API size_t llama_state_seq_load_file( + struct llama_context * ctx, + const char * filepath, + llama_seq_id dest_seq_id, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + // // Decoding // @@ -684,8 +774,9 @@ extern "C" { // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); - // Logits for the ith token. Equivalent to: + // Logits for the ith token. For positive indices, Equivalent to: // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab + // Negative indicies can be used to access logits in reverse order, -1 is the last logit. // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); @@ -697,8 +788,9 @@ extern "C" { // Otherwise, returns NULL. LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); - // Get the embeddings for the ith token. Equivalent to: + // Get the embeddings for the ith token. For positive indices, Equivalent to: // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd + // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding. // shape: [n_embd] (1-dimensional) // returns NULL for invalid ids. LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); @@ -718,9 +810,14 @@ extern "C" { LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); + // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) + LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); + // Special tokens LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence + LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification + LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line // Returns -1 if unknown, 1 for true or 0 for false. @@ -729,7 +826,7 @@ extern "C" { // Returns -1 if unknown, 1 for true or 0 for false. LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model); - // codellama infill tokens + // Codellama infill tokens LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix @@ -743,26 +840,28 @@ extern "C" { /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. /// @return Returns the number of tokens on success, no more than n_tokens_max /// @return Returns a negative number on failure - the number of tokens that would have been returned - /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. - /// Does not insert a leading space. + /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated + /// as plaintext. Does not insert a leading space. LLAMA_API int32_t llama_tokenize( const struct llama_model * model, const char * text, int32_t text_len, llama_token * tokens, int32_t n_tokens_max, - bool add_bos, - bool special); + bool add_special, + bool parse_special); // Token Id -> Piece. // Uses the vocabulary in the provided context. // Does not write null terminator to the buffer. // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. + // @param special If true, special tokens are rendered in the output. LLAMA_API int32_t llama_token_to_piece( const struct llama_model * model, llama_token token, char * buf, - int32_t length); + int32_t length, + bool special); /// Apply chat template. Inspired by hf apply_chat_template() on python. /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" @@ -915,7 +1014,7 @@ extern "C" { struct llama_context * ctx, llama_token_data_array * candidates); - /// @details Randomly selects a token from the candidates based on their probabilities. + /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx. LLAMA_API llama_token llama_sample_token( struct llama_context * ctx, llama_token_data_array * candidates); @@ -1002,8 +1101,9 @@ extern "C" { // Internal API to be implemented by llama.cpp and used by tests/benchmarks only #ifdef LLAMA_API_INTERNAL -#include +#include #include +#include struct ggml_tensor; @@ -1030,15 +1130,20 @@ const std::vector> & llama_internal struct llama_context * ctx ); -std::vector> llama_grammar_accept( +void llama_grammar_accept( const std::vector> & rules, const std::vector> & stacks, - const uint32_t chr); + const uint32_t chr, + std::vector> & new_stacks); std::pair, llama_partial_utf8> decode_utf8( const std::string & src, llama_partial_utf8 partial_start); +// Randomly selects a token from the candidates based on their probabilities using given std::mt19937. +// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. +llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng); + #endif // LLAMA_API_INTERNAL #endif // LLAMA_H diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index 4e1c1755f1c..bb8c26d5efd 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -35,10 +35,10 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), false); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), false); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); diff --git a/examples/talk-llama/unicode-data.cpp b/examples/talk-llama/unicode-data.cpp index 22f8b0f0b29..c54175fc3b4 100644 --- a/examples/talk-llama/unicode-data.cpp +++ b/examples/talk-llama/unicode-data.cpp @@ -1,31 +1,50 @@ -#include "unicode-data.h" +#include "unicode-data.h" #include #include #include #include -const std::vector> unicode_ranges_digit = { -{0x00000030, 0x00000039}, {0x000000B2, 0x000000B3}, {0x000000B9, 0x000000B9}, {0x00000660, 0x00000669}, -{0x000006F0, 0x000006F9}, {0x000007C0, 0x000007C9}, {0x00000966, 0x0000096F}, {0x000009E6, 0x000009EF}, -{0x00000A66, 0x00000A6F}, {0x00000AE6, 0x00000AEF}, {0x00000B66, 0x00000B6F}, {0x00000BE6, 0x00000BEF}, -{0x00000C66, 0x00000C6F}, {0x00000CE6, 0x00000CEF}, {0x00000D66, 0x00000D6F}, {0x00000DE6, 0x00000DEF}, -{0x00000E50, 0x00000E59}, {0x00000ED0, 0x00000ED9}, {0x00000F20, 0x00000F29}, {0x00001040, 0x00001049}, -{0x00001090, 0x00001099}, {0x00001369, 0x00001371}, {0x000017E0, 0x000017E9}, {0x00001810, 0x00001819}, -{0x00001946, 0x0000194F}, {0x000019D0, 0x000019DA}, {0x00001A80, 0x00001A89}, {0x00001A90, 0x00001A99}, -{0x00001B50, 0x00001B59}, {0x00001BB0, 0x00001BB9}, {0x00001C40, 0x00001C49}, {0x00001C50, 0x00001C59}, -{0x00002070, 0x00002070}, {0x00002074, 0x00002079}, {0x00002080, 0x00002089}, {0x00002460, 0x00002468}, -{0x00002474, 0x0000247C}, {0x00002488, 0x00002490}, {0x000024EA, 0x000024EA}, {0x000024F5, 0x000024FD}, -{0x000024FF, 0x000024FF}, {0x00002776, 0x0000277E}, {0x00002780, 0x00002788}, {0x0000278A, 0x00002792}, -{0x0000A620, 0x0000A629}, {0x0000A8D0, 0x0000A8D9}, {0x0000A900, 0x0000A909}, {0x0000A9D0, 0x0000A9D9}, -{0x0000A9F0, 0x0000A9F9}, {0x0000AA50, 0x0000AA59}, {0x0000ABF0, 0x0000ABF9}, {0x0000FF10, 0x0000FF19}, -{0x000104A0, 0x000104A9}, {0x00010A40, 0x00010A43}, {0x00010D30, 0x00010D39}, {0x00010E60, 0x00010E68}, -{0x00011052, 0x0001105A}, {0x00011066, 0x0001106F}, {0x000110F0, 0x000110F9}, {0x00011136, 0x0001113F}, -{0x000111D0, 0x000111D9}, {0x000112F0, 0x000112F9}, {0x00011450, 0x00011459}, {0x000114D0, 0x000114D9}, -{0x00011650, 0x00011659}, {0x000116C0, 0x000116C9}, {0x00011730, 0x00011739}, {0x000118E0, 0x000118E9}, -{0x00011950, 0x00011959}, {0x00011C50, 0x00011C59}, {0x00011D50, 0x00011D59}, {0x00011DA0, 0x00011DA9}, -{0x00016A60, 0x00016A69}, {0x00016B50, 0x00016B59}, {0x0001D7CE, 0x0001D7FF}, {0x0001E140, 0x0001E149}, -{0x0001E2F0, 0x0001E2F9}, {0x0001E950, 0x0001E959}, {0x0001F100, 0x0001F10A}, {0x0001FBF0, 0x0001FBF9}, +// generated with scripts/gen-unicode-data.py +// +// TODO: generate unicode_map_nfd + +const std::vector> unicode_ranges_number = { +{0x00000030, 0x00000039}, {0x000000B2, 0x000000B3}, {0x000000B9, 0x000000B9}, {0x000000BC, 0x000000BE}, +{0x00000660, 0x00000669}, {0x000006F0, 0x000006F9}, {0x000007C0, 0x000007C9}, {0x00000966, 0x0000096F}, +{0x000009E6, 0x000009EF}, {0x000009F4, 0x000009F9}, {0x00000A66, 0x00000A6F}, {0x00000AE6, 0x00000AEF}, +{0x00000B66, 0x00000B6F}, {0x00000B72, 0x00000B77}, {0x00000BE6, 0x00000BF2}, {0x00000C66, 0x00000C6F}, +{0x00000C78, 0x00000C7E}, {0x00000CE6, 0x00000CEF}, {0x00000D58, 0x00000D5E}, {0x00000D66, 0x00000D78}, +{0x00000DE6, 0x00000DEF}, {0x00000E50, 0x00000E59}, {0x00000ED0, 0x00000ED9}, {0x00000F20, 0x00000F33}, +{0x00001040, 0x00001049}, {0x00001090, 0x00001099}, {0x00001369, 0x0000137C}, {0x000016EE, 0x000016F0}, +{0x000017E0, 0x000017E9}, {0x000017F0, 0x000017F9}, {0x00001810, 0x00001819}, {0x00001946, 0x0000194F}, +{0x000019D0, 0x000019DA}, {0x00001A80, 0x00001A89}, {0x00001A90, 0x00001A99}, {0x00001B50, 0x00001B59}, +{0x00001BB0, 0x00001BB9}, {0x00001C40, 0x00001C49}, {0x00001C50, 0x00001C59}, {0x00002070, 0x00002070}, +{0x00002074, 0x00002079}, {0x00002080, 0x00002089}, {0x00002150, 0x00002182}, {0x00002185, 0x00002189}, +{0x00002460, 0x0000249B}, {0x000024EA, 0x000024FF}, {0x00002776, 0x00002793}, {0x00002CFD, 0x00002CFD}, +{0x00003007, 0x00003007}, {0x00003021, 0x00003029}, {0x00003038, 0x0000303A}, {0x00003192, 0x00003195}, +{0x00003220, 0x00003229}, {0x00003248, 0x0000324F}, {0x00003251, 0x0000325F}, {0x00003280, 0x00003289}, +{0x000032B1, 0x000032BF}, {0x0000A620, 0x0000A629}, {0x0000A6E6, 0x0000A6EF}, {0x0000A830, 0x0000A835}, +{0x0000A8D0, 0x0000A8D9}, {0x0000A900, 0x0000A909}, {0x0000A9D0, 0x0000A9D9}, {0x0000A9F0, 0x0000A9F9}, +{0x0000AA50, 0x0000AA59}, {0x0000ABF0, 0x0000ABF9}, {0x0000FF10, 0x0000FF19}, {0x00010107, 0x00010133}, +{0x00010140, 0x00010178}, {0x0001018A, 0x0001018B}, {0x000102E1, 0x000102FB}, {0x00010320, 0x00010323}, +{0x00010341, 0x00010341}, {0x0001034A, 0x0001034A}, {0x000103D1, 0x000103D5}, {0x000104A0, 0x000104A9}, +{0x00010858, 0x0001085F}, {0x00010879, 0x0001087F}, {0x000108A7, 0x000108AF}, {0x000108FB, 0x000108FF}, +{0x00010916, 0x0001091B}, {0x000109BC, 0x000109BD}, {0x000109C0, 0x000109CF}, {0x000109D2, 0x000109FF}, +{0x00010A40, 0x00010A48}, {0x00010A7D, 0x00010A7E}, {0x00010A9D, 0x00010A9F}, {0x00010AEB, 0x00010AEF}, +{0x00010B58, 0x00010B5F}, {0x00010B78, 0x00010B7F}, {0x00010BA9, 0x00010BAF}, {0x00010CFA, 0x00010CFF}, +{0x00010D30, 0x00010D39}, {0x00010E60, 0x00010E7E}, {0x00010F1D, 0x00010F26}, {0x00010F51, 0x00010F54}, +{0x00010FC5, 0x00010FCB}, {0x00011052, 0x0001106F}, {0x000110F0, 0x000110F9}, {0x00011136, 0x0001113F}, +{0x000111D0, 0x000111D9}, {0x000111E1, 0x000111F4}, {0x000112F0, 0x000112F9}, {0x00011450, 0x00011459}, +{0x000114D0, 0x000114D9}, {0x00011650, 0x00011659}, {0x000116C0, 0x000116C9}, {0x00011730, 0x0001173B}, +{0x000118E0, 0x000118F2}, {0x00011950, 0x00011959}, {0x00011C50, 0x00011C6C}, {0x00011D50, 0x00011D59}, +{0x00011DA0, 0x00011DA9}, {0x00011F50, 0x00011F59}, {0x00011FC0, 0x00011FD4}, {0x00012400, 0x0001246E}, +{0x00016A60, 0x00016A69}, {0x00016AC0, 0x00016AC9}, {0x00016B50, 0x00016B59}, {0x00016B5B, 0x00016B61}, +{0x00016E80, 0x00016E96}, {0x0001D2C0, 0x0001D2D3}, {0x0001D2E0, 0x0001D2F3}, {0x0001D360, 0x0001D378}, +{0x0001D7CE, 0x0001D7FF}, {0x0001E140, 0x0001E149}, {0x0001E2F0, 0x0001E2F9}, {0x0001E4F0, 0x0001E4F9}, +{0x0001E8C7, 0x0001E8CF}, {0x0001E950, 0x0001E959}, {0x0001EC71, 0x0001ECAB}, {0x0001ECAD, 0x0001ECAF}, +{0x0001ECB1, 0x0001ECB4}, {0x0001ED01, 0x0001ED2D}, {0x0001ED2F, 0x0001ED3D}, {0x0001F100, 0x0001F10C}, +{0x0001FBF0, 0x0001FBF9}, }; const std::vector> unicode_ranges_letter = { @@ -41,73 +60,73 @@ const std::vector> unicode_ranges_letter = { {0x00000710, 0x00000710}, {0x00000712, 0x0000072F}, {0x0000074D, 0x000007A5}, {0x000007B1, 0x000007B1}, {0x000007CA, 0x000007EA}, {0x000007F4, 0x000007F5}, {0x000007FA, 0x000007FA}, {0x00000800, 0x00000815}, {0x0000081A, 0x0000081A}, {0x00000824, 0x00000824}, {0x00000828, 0x00000828}, {0x00000840, 0x00000858}, -{0x00000860, 0x0000086A}, {0x000008A0, 0x000008B4}, {0x000008B6, 0x000008C7}, {0x00000904, 0x00000939}, -{0x0000093D, 0x0000093D}, {0x00000950, 0x00000950}, {0x00000958, 0x00000961}, {0x00000971, 0x00000980}, -{0x00000985, 0x0000098C}, {0x0000098F, 0x00000990}, {0x00000993, 0x000009A8}, {0x000009AA, 0x000009B0}, -{0x000009B2, 0x000009B2}, {0x000009B6, 0x000009B9}, {0x000009BD, 0x000009BD}, {0x000009CE, 0x000009CE}, -{0x000009DC, 0x000009DD}, {0x000009DF, 0x000009E1}, {0x000009F0, 0x000009F1}, {0x000009FC, 0x000009FC}, -{0x00000A05, 0x00000A0A}, {0x00000A0F, 0x00000A10}, {0x00000A13, 0x00000A28}, {0x00000A2A, 0x00000A30}, -{0x00000A32, 0x00000A33}, {0x00000A35, 0x00000A36}, {0x00000A38, 0x00000A39}, {0x00000A59, 0x00000A5C}, -{0x00000A5E, 0x00000A5E}, {0x00000A72, 0x00000A74}, {0x00000A85, 0x00000A8D}, {0x00000A8F, 0x00000A91}, -{0x00000A93, 0x00000AA8}, {0x00000AAA, 0x00000AB0}, {0x00000AB2, 0x00000AB3}, {0x00000AB5, 0x00000AB9}, -{0x00000ABD, 0x00000ABD}, {0x00000AD0, 0x00000AD0}, {0x00000AE0, 0x00000AE1}, {0x00000AF9, 0x00000AF9}, -{0x00000B05, 0x00000B0C}, {0x00000B0F, 0x00000B10}, {0x00000B13, 0x00000B28}, {0x00000B2A, 0x00000B30}, -{0x00000B32, 0x00000B33}, {0x00000B35, 0x00000B39}, {0x00000B3D, 0x00000B3D}, {0x00000B5C, 0x00000B5D}, -{0x00000B5F, 0x00000B61}, {0x00000B71, 0x00000B71}, {0x00000B83, 0x00000B83}, {0x00000B85, 0x00000B8A}, -{0x00000B8E, 0x00000B90}, {0x00000B92, 0x00000B95}, {0x00000B99, 0x00000B9A}, {0x00000B9C, 0x00000B9C}, -{0x00000B9E, 0x00000B9F}, {0x00000BA3, 0x00000BA4}, {0x00000BA8, 0x00000BAA}, {0x00000BAE, 0x00000BB9}, -{0x00000BD0, 0x00000BD0}, {0x00000C05, 0x00000C0C}, {0x00000C0E, 0x00000C10}, {0x00000C12, 0x00000C28}, -{0x00000C2A, 0x00000C39}, {0x00000C3D, 0x00000C3D}, {0x00000C58, 0x00000C5A}, {0x00000C60, 0x00000C61}, -{0x00000C80, 0x00000C80}, {0x00000C85, 0x00000C8C}, {0x00000C8E, 0x00000C90}, {0x00000C92, 0x00000CA8}, -{0x00000CAA, 0x00000CB3}, {0x00000CB5, 0x00000CB9}, {0x00000CBD, 0x00000CBD}, {0x00000CDE, 0x00000CDE}, -{0x00000CE0, 0x00000CE1}, {0x00000CF1, 0x00000CF2}, {0x00000D04, 0x00000D0C}, {0x00000D0E, 0x00000D10}, -{0x00000D12, 0x00000D3A}, {0x00000D3D, 0x00000D3D}, {0x00000D4E, 0x00000D4E}, {0x00000D54, 0x00000D56}, -{0x00000D5F, 0x00000D61}, {0x00000D7A, 0x00000D7F}, {0x00000D85, 0x00000D96}, {0x00000D9A, 0x00000DB1}, -{0x00000DB3, 0x00000DBB}, {0x00000DBD, 0x00000DBD}, {0x00000DC0, 0x00000DC6}, {0x00000E01, 0x00000E30}, -{0x00000E32, 0x00000E33}, {0x00000E40, 0x00000E46}, {0x00000E81, 0x00000E82}, {0x00000E84, 0x00000E84}, -{0x00000E86, 0x00000E8A}, {0x00000E8C, 0x00000EA3}, {0x00000EA5, 0x00000EA5}, {0x00000EA7, 0x00000EB0}, -{0x00000EB2, 0x00000EB3}, {0x00000EBD, 0x00000EBD}, {0x00000EC0, 0x00000EC4}, {0x00000EC6, 0x00000EC6}, -{0x00000EDC, 0x00000EDF}, {0x00000F00, 0x00000F00}, {0x00000F40, 0x00000F47}, {0x00000F49, 0x00000F6C}, -{0x00000F88, 0x00000F8C}, {0x00001000, 0x0000102A}, {0x0000103F, 0x0000103F}, {0x00001050, 0x00001055}, -{0x0000105A, 0x0000105D}, {0x00001061, 0x00001061}, {0x00001065, 0x00001066}, {0x0000106E, 0x00001070}, -{0x00001075, 0x00001081}, {0x0000108E, 0x0000108E}, {0x000010A0, 0x000010C5}, {0x000010C7, 0x000010C7}, -{0x000010CD, 0x000010CD}, {0x000010D0, 0x000010FA}, {0x000010FC, 0x00001248}, {0x0000124A, 0x0000124D}, -{0x00001250, 0x00001256}, {0x00001258, 0x00001258}, {0x0000125A, 0x0000125D}, {0x00001260, 0x00001288}, -{0x0000128A, 0x0000128D}, {0x00001290, 0x000012B0}, {0x000012B2, 0x000012B5}, {0x000012B8, 0x000012BE}, -{0x000012C0, 0x000012C0}, {0x000012C2, 0x000012C5}, {0x000012C8, 0x000012D6}, {0x000012D8, 0x00001310}, -{0x00001312, 0x00001315}, {0x00001318, 0x0000135A}, {0x00001380, 0x0000138F}, {0x000013A0, 0x000013F5}, -{0x000013F8, 0x000013FD}, {0x00001401, 0x0000166C}, {0x0000166F, 0x0000167F}, {0x00001681, 0x0000169A}, -{0x000016A0, 0x000016EA}, {0x000016F1, 0x000016F8}, {0x00001700, 0x0000170C}, {0x0000170E, 0x00001711}, -{0x00001720, 0x00001731}, {0x00001740, 0x00001751}, {0x00001760, 0x0000176C}, {0x0000176E, 0x00001770}, -{0x00001780, 0x000017B3}, {0x000017D7, 0x000017D7}, {0x000017DC, 0x000017DC}, {0x00001820, 0x00001878}, -{0x00001880, 0x00001884}, {0x00001887, 0x000018A8}, {0x000018AA, 0x000018AA}, {0x000018B0, 0x000018F5}, -{0x00001900, 0x0000191E}, {0x00001950, 0x0000196D}, {0x00001970, 0x00001974}, {0x00001980, 0x000019AB}, -{0x000019B0, 0x000019C9}, {0x00001A00, 0x00001A16}, {0x00001A20, 0x00001A54}, {0x00001AA7, 0x00001AA7}, -{0x00001B05, 0x00001B33}, {0x00001B45, 0x00001B4B}, {0x00001B83, 0x00001BA0}, {0x00001BAE, 0x00001BAF}, -{0x00001BBA, 0x00001BE5}, {0x00001C00, 0x00001C23}, {0x00001C4D, 0x00001C4F}, {0x00001C5A, 0x00001C7D}, -{0x00001C80, 0x00001C88}, {0x00001C90, 0x00001CBA}, {0x00001CBD, 0x00001CBF}, {0x00001CE9, 0x00001CEC}, -{0x00001CEE, 0x00001CF3}, {0x00001CF5, 0x00001CF6}, {0x00001CFA, 0x00001CFA}, {0x00001D00, 0x00001DBF}, -{0x00001E00, 0x00001F15}, {0x00001F18, 0x00001F1D}, {0x00001F20, 0x00001F45}, {0x00001F48, 0x00001F4D}, -{0x00001F50, 0x00001F57}, {0x00001F59, 0x00001F59}, {0x00001F5B, 0x00001F5B}, {0x00001F5D, 0x00001F5D}, -{0x00001F5F, 0x00001F7D}, {0x00001F80, 0x00001FB4}, {0x00001FB6, 0x00001FBC}, {0x00001FBE, 0x00001FBE}, -{0x00001FC2, 0x00001FC4}, {0x00001FC6, 0x00001FCC}, {0x00001FD0, 0x00001FD3}, {0x00001FD6, 0x00001FDB}, -{0x00001FE0, 0x00001FEC}, {0x00001FF2, 0x00001FF4}, {0x00001FF6, 0x00001FFC}, {0x00002071, 0x00002071}, -{0x0000207F, 0x0000207F}, {0x00002090, 0x0000209C}, {0x00002102, 0x00002102}, {0x00002107, 0x00002107}, -{0x0000210A, 0x00002113}, {0x00002115, 0x00002115}, {0x00002119, 0x0000211D}, {0x00002124, 0x00002124}, -{0x00002126, 0x00002126}, {0x00002128, 0x00002128}, {0x0000212A, 0x0000212D}, {0x0000212F, 0x00002139}, -{0x0000213C, 0x0000213F}, {0x00002145, 0x00002149}, {0x0000214E, 0x0000214E}, {0x00002183, 0x00002184}, -{0x00002C00, 0x00002C2E}, {0x00002C30, 0x00002C5E}, {0x00002C60, 0x00002CE4}, {0x00002CEB, 0x00002CEE}, -{0x00002CF2, 0x00002CF3}, {0x00002D00, 0x00002D25}, {0x00002D27, 0x00002D27}, {0x00002D2D, 0x00002D2D}, -{0x00002D30, 0x00002D67}, {0x00002D6F, 0x00002D6F}, {0x00002D80, 0x00002D96}, {0x00002DA0, 0x00002DA6}, -{0x00002DA8, 0x00002DAE}, {0x00002DB0, 0x00002DB6}, {0x00002DB8, 0x00002DBE}, {0x00002DC0, 0x00002DC6}, -{0x00002DC8, 0x00002DCE}, {0x00002DD0, 0x00002DD6}, {0x00002DD8, 0x00002DDE}, {0x00002E2F, 0x00002E2F}, -{0x00003005, 0x00003006}, {0x00003031, 0x00003035}, {0x0000303B, 0x0000303C}, {0x00003041, 0x00003096}, -{0x0000309D, 0x0000309F}, {0x000030A1, 0x000030FA}, {0x000030FC, 0x000030FF}, {0x00003105, 0x0000312F}, -{0x00003131, 0x0000318E}, {0x000031A0, 0x000031BF}, {0x000031F0, 0x000031FF}, {0x00003400, 0x00004DBF}, -{0x00004E00, 0x00009FFC}, {0x0000A000, 0x0000A48C}, {0x0000A4D0, 0x0000A4FD}, {0x0000A500, 0x0000A60C}, -{0x0000A610, 0x0000A61F}, {0x0000A62A, 0x0000A62B}, {0x0000A640, 0x0000A66E}, {0x0000A67F, 0x0000A69D}, -{0x0000A6A0, 0x0000A6E5}, {0x0000A717, 0x0000A71F}, {0x0000A722, 0x0000A788}, {0x0000A78B, 0x0000A7BF}, -{0x0000A7C2, 0x0000A7CA}, {0x0000A7F5, 0x0000A801}, {0x0000A803, 0x0000A805}, {0x0000A807, 0x0000A80A}, +{0x00000860, 0x0000086A}, {0x00000870, 0x00000887}, {0x00000889, 0x0000088E}, {0x000008A0, 0x000008C9}, +{0x00000904, 0x00000939}, {0x0000093D, 0x0000093D}, {0x00000950, 0x00000950}, {0x00000958, 0x00000961}, +{0x00000971, 0x00000980}, {0x00000985, 0x0000098C}, {0x0000098F, 0x00000990}, {0x00000993, 0x000009A8}, +{0x000009AA, 0x000009B0}, {0x000009B2, 0x000009B2}, {0x000009B6, 0x000009B9}, {0x000009BD, 0x000009BD}, +{0x000009CE, 0x000009CE}, {0x000009DC, 0x000009DD}, {0x000009DF, 0x000009E1}, {0x000009F0, 0x000009F1}, +{0x000009FC, 0x000009FC}, {0x00000A05, 0x00000A0A}, {0x00000A0F, 0x00000A10}, {0x00000A13, 0x00000A28}, +{0x00000A2A, 0x00000A30}, {0x00000A32, 0x00000A33}, {0x00000A35, 0x00000A36}, {0x00000A38, 0x00000A39}, +{0x00000A59, 0x00000A5C}, {0x00000A5E, 0x00000A5E}, {0x00000A72, 0x00000A74}, {0x00000A85, 0x00000A8D}, +{0x00000A8F, 0x00000A91}, {0x00000A93, 0x00000AA8}, {0x00000AAA, 0x00000AB0}, {0x00000AB2, 0x00000AB3}, +{0x00000AB5, 0x00000AB9}, {0x00000ABD, 0x00000ABD}, {0x00000AD0, 0x00000AD0}, {0x00000AE0, 0x00000AE1}, +{0x00000AF9, 0x00000AF9}, {0x00000B05, 0x00000B0C}, {0x00000B0F, 0x00000B10}, {0x00000B13, 0x00000B28}, +{0x00000B2A, 0x00000B30}, {0x00000B32, 0x00000B33}, {0x00000B35, 0x00000B39}, {0x00000B3D, 0x00000B3D}, +{0x00000B5C, 0x00000B5D}, {0x00000B5F, 0x00000B61}, {0x00000B71, 0x00000B71}, {0x00000B83, 0x00000B83}, +{0x00000B85, 0x00000B8A}, {0x00000B8E, 0x00000B90}, {0x00000B92, 0x00000B95}, {0x00000B99, 0x00000B9A}, +{0x00000B9C, 0x00000B9C}, {0x00000B9E, 0x00000B9F}, {0x00000BA3, 0x00000BA4}, {0x00000BA8, 0x00000BAA}, +{0x00000BAE, 0x00000BB9}, {0x00000BD0, 0x00000BD0}, {0x00000C05, 0x00000C0C}, {0x00000C0E, 0x00000C10}, +{0x00000C12, 0x00000C28}, {0x00000C2A, 0x00000C39}, {0x00000C3D, 0x00000C3D}, {0x00000C58, 0x00000C5A}, +{0x00000C5D, 0x00000C5D}, {0x00000C60, 0x00000C61}, {0x00000C80, 0x00000C80}, {0x00000C85, 0x00000C8C}, +{0x00000C8E, 0x00000C90}, {0x00000C92, 0x00000CA8}, {0x00000CAA, 0x00000CB3}, {0x00000CB5, 0x00000CB9}, +{0x00000CBD, 0x00000CBD}, {0x00000CDD, 0x00000CDE}, {0x00000CE0, 0x00000CE1}, {0x00000CF1, 0x00000CF2}, +{0x00000D04, 0x00000D0C}, {0x00000D0E, 0x00000D10}, {0x00000D12, 0x00000D3A}, {0x00000D3D, 0x00000D3D}, +{0x00000D4E, 0x00000D4E}, {0x00000D54, 0x00000D56}, {0x00000D5F, 0x00000D61}, {0x00000D7A, 0x00000D7F}, +{0x00000D85, 0x00000D96}, {0x00000D9A, 0x00000DB1}, {0x00000DB3, 0x00000DBB}, {0x00000DBD, 0x00000DBD}, +{0x00000DC0, 0x00000DC6}, {0x00000E01, 0x00000E30}, {0x00000E32, 0x00000E33}, {0x00000E40, 0x00000E46}, +{0x00000E81, 0x00000E82}, {0x00000E84, 0x00000E84}, {0x00000E86, 0x00000E8A}, {0x00000E8C, 0x00000EA3}, +{0x00000EA5, 0x00000EA5}, {0x00000EA7, 0x00000EB0}, {0x00000EB2, 0x00000EB3}, {0x00000EBD, 0x00000EBD}, +{0x00000EC0, 0x00000EC4}, {0x00000EC6, 0x00000EC6}, {0x00000EDC, 0x00000EDF}, {0x00000F00, 0x00000F00}, +{0x00000F40, 0x00000F47}, {0x00000F49, 0x00000F6C}, {0x00000F88, 0x00000F8C}, {0x00001000, 0x0000102A}, +{0x0000103F, 0x0000103F}, {0x00001050, 0x00001055}, {0x0000105A, 0x0000105D}, {0x00001061, 0x00001061}, +{0x00001065, 0x00001066}, {0x0000106E, 0x00001070}, {0x00001075, 0x00001081}, {0x0000108E, 0x0000108E}, +{0x000010A0, 0x000010C5}, {0x000010C7, 0x000010C7}, {0x000010CD, 0x000010CD}, {0x000010D0, 0x000010FA}, +{0x000010FC, 0x00001248}, {0x0000124A, 0x0000124D}, {0x00001250, 0x00001256}, {0x00001258, 0x00001258}, +{0x0000125A, 0x0000125D}, {0x00001260, 0x00001288}, {0x0000128A, 0x0000128D}, {0x00001290, 0x000012B0}, +{0x000012B2, 0x000012B5}, {0x000012B8, 0x000012BE}, {0x000012C0, 0x000012C0}, {0x000012C2, 0x000012C5}, +{0x000012C8, 0x000012D6}, {0x000012D8, 0x00001310}, {0x00001312, 0x00001315}, {0x00001318, 0x0000135A}, +{0x00001380, 0x0000138F}, {0x000013A0, 0x000013F5}, {0x000013F8, 0x000013FD}, {0x00001401, 0x0000166C}, +{0x0000166F, 0x0000167F}, {0x00001681, 0x0000169A}, {0x000016A0, 0x000016EA}, {0x000016F1, 0x000016F8}, +{0x00001700, 0x00001711}, {0x0000171F, 0x00001731}, {0x00001740, 0x00001751}, {0x00001760, 0x0000176C}, +{0x0000176E, 0x00001770}, {0x00001780, 0x000017B3}, {0x000017D7, 0x000017D7}, {0x000017DC, 0x000017DC}, +{0x00001820, 0x00001878}, {0x00001880, 0x00001884}, {0x00001887, 0x000018A8}, {0x000018AA, 0x000018AA}, +{0x000018B0, 0x000018F5}, {0x00001900, 0x0000191E}, {0x00001950, 0x0000196D}, {0x00001970, 0x00001974}, +{0x00001980, 0x000019AB}, {0x000019B0, 0x000019C9}, {0x00001A00, 0x00001A16}, {0x00001A20, 0x00001A54}, +{0x00001AA7, 0x00001AA7}, {0x00001B05, 0x00001B33}, {0x00001B45, 0x00001B4C}, {0x00001B83, 0x00001BA0}, +{0x00001BAE, 0x00001BAF}, {0x00001BBA, 0x00001BE5}, {0x00001C00, 0x00001C23}, {0x00001C4D, 0x00001C4F}, +{0x00001C5A, 0x00001C7D}, {0x00001C80, 0x00001C88}, {0x00001C90, 0x00001CBA}, {0x00001CBD, 0x00001CBF}, +{0x00001CE9, 0x00001CEC}, {0x00001CEE, 0x00001CF3}, {0x00001CF5, 0x00001CF6}, {0x00001CFA, 0x00001CFA}, +{0x00001D00, 0x00001DBF}, {0x00001E00, 0x00001F15}, {0x00001F18, 0x00001F1D}, {0x00001F20, 0x00001F45}, +{0x00001F48, 0x00001F4D}, {0x00001F50, 0x00001F57}, {0x00001F59, 0x00001F59}, {0x00001F5B, 0x00001F5B}, +{0x00001F5D, 0x00001F5D}, {0x00001F5F, 0x00001F7D}, {0x00001F80, 0x00001FB4}, {0x00001FB6, 0x00001FBC}, +{0x00001FBE, 0x00001FBE}, {0x00001FC2, 0x00001FC4}, {0x00001FC6, 0x00001FCC}, {0x00001FD0, 0x00001FD3}, +{0x00001FD6, 0x00001FDB}, {0x00001FE0, 0x00001FEC}, {0x00001FF2, 0x00001FF4}, {0x00001FF6, 0x00001FFC}, +{0x00002071, 0x00002071}, {0x0000207F, 0x0000207F}, {0x00002090, 0x0000209C}, {0x00002102, 0x00002102}, +{0x00002107, 0x00002107}, {0x0000210A, 0x00002113}, {0x00002115, 0x00002115}, {0x00002119, 0x0000211D}, +{0x00002124, 0x00002124}, {0x00002126, 0x00002126}, {0x00002128, 0x00002128}, {0x0000212A, 0x0000212D}, +{0x0000212F, 0x00002139}, {0x0000213C, 0x0000213F}, {0x00002145, 0x00002149}, {0x0000214E, 0x0000214E}, +{0x00002183, 0x00002184}, {0x00002C00, 0x00002CE4}, {0x00002CEB, 0x00002CEE}, {0x00002CF2, 0x00002CF3}, +{0x00002D00, 0x00002D25}, {0x00002D27, 0x00002D27}, {0x00002D2D, 0x00002D2D}, {0x00002D30, 0x00002D67}, +{0x00002D6F, 0x00002D6F}, {0x00002D80, 0x00002D96}, {0x00002DA0, 0x00002DA6}, {0x00002DA8, 0x00002DAE}, +{0x00002DB0, 0x00002DB6}, {0x00002DB8, 0x00002DBE}, {0x00002DC0, 0x00002DC6}, {0x00002DC8, 0x00002DCE}, +{0x00002DD0, 0x00002DD6}, {0x00002DD8, 0x00002DDE}, {0x00002E2F, 0x00002E2F}, {0x00003005, 0x00003006}, +{0x00003031, 0x00003035}, {0x0000303B, 0x0000303C}, {0x00003041, 0x00003096}, {0x0000309D, 0x0000309F}, +{0x000030A1, 0x000030FA}, {0x000030FC, 0x000030FF}, {0x00003105, 0x0000312F}, {0x00003131, 0x0000318E}, +{0x000031A0, 0x000031BF}, {0x000031F0, 0x000031FF}, {0x00003400, 0x00004DBF}, {0x00004E00, 0x0000A48C}, +{0x0000A4D0, 0x0000A4FD}, {0x0000A500, 0x0000A60C}, {0x0000A610, 0x0000A61F}, {0x0000A62A, 0x0000A62B}, +{0x0000A640, 0x0000A66E}, {0x0000A67F, 0x0000A69D}, {0x0000A6A0, 0x0000A6E5}, {0x0000A717, 0x0000A71F}, +{0x0000A722, 0x0000A788}, {0x0000A78B, 0x0000A7CA}, {0x0000A7D0, 0x0000A7D1}, {0x0000A7D3, 0x0000A7D3}, +{0x0000A7D5, 0x0000A7D9}, {0x0000A7F2, 0x0000A801}, {0x0000A803, 0x0000A805}, {0x0000A807, 0x0000A80A}, {0x0000A80C, 0x0000A822}, {0x0000A840, 0x0000A873}, {0x0000A882, 0x0000A8B3}, {0x0000A8F2, 0x0000A8F7}, {0x0000A8FB, 0x0000A8FB}, {0x0000A8FD, 0x0000A8FE}, {0x0000A90A, 0x0000A925}, {0x0000A930, 0x0000A946}, {0x0000A960, 0x0000A97C}, {0x0000A984, 0x0000A9B2}, {0x0000A9CF, 0x0000A9CF}, {0x0000A9E0, 0x0000A9E4}, @@ -129,51 +148,60 @@ const std::vector> unicode_ranges_letter = { {0x000102A0, 0x000102D0}, {0x00010300, 0x0001031F}, {0x0001032D, 0x00010340}, {0x00010342, 0x00010349}, {0x00010350, 0x00010375}, {0x00010380, 0x0001039D}, {0x000103A0, 0x000103C3}, {0x000103C8, 0x000103CF}, {0x00010400, 0x0001049D}, {0x000104B0, 0x000104D3}, {0x000104D8, 0x000104FB}, {0x00010500, 0x00010527}, -{0x00010530, 0x00010563}, {0x00010600, 0x00010736}, {0x00010740, 0x00010755}, {0x00010760, 0x00010767}, -{0x00010800, 0x00010805}, {0x00010808, 0x00010808}, {0x0001080A, 0x00010835}, {0x00010837, 0x00010838}, -{0x0001083C, 0x0001083C}, {0x0001083F, 0x00010855}, {0x00010860, 0x00010876}, {0x00010880, 0x0001089E}, -{0x000108E0, 0x000108F2}, {0x000108F4, 0x000108F5}, {0x00010900, 0x00010915}, {0x00010920, 0x00010939}, -{0x00010980, 0x000109B7}, {0x000109BE, 0x000109BF}, {0x00010A00, 0x00010A00}, {0x00010A10, 0x00010A13}, -{0x00010A15, 0x00010A17}, {0x00010A19, 0x00010A35}, {0x00010A60, 0x00010A7C}, {0x00010A80, 0x00010A9C}, -{0x00010AC0, 0x00010AC7}, {0x00010AC9, 0x00010AE4}, {0x00010B00, 0x00010B35}, {0x00010B40, 0x00010B55}, -{0x00010B60, 0x00010B72}, {0x00010B80, 0x00010B91}, {0x00010C00, 0x00010C48}, {0x00010C80, 0x00010CB2}, -{0x00010CC0, 0x00010CF2}, {0x00010D00, 0x00010D23}, {0x00010E80, 0x00010EA9}, {0x00010EB0, 0x00010EB1}, -{0x00010F00, 0x00010F1C}, {0x00010F27, 0x00010F27}, {0x00010F30, 0x00010F45}, {0x00010FB0, 0x00010FC4}, -{0x00010FE0, 0x00010FF6}, {0x00011003, 0x00011037}, {0x00011083, 0x000110AF}, {0x000110D0, 0x000110E8}, -{0x00011103, 0x00011126}, {0x00011144, 0x00011144}, {0x00011147, 0x00011147}, {0x00011150, 0x00011172}, -{0x00011176, 0x00011176}, {0x00011183, 0x000111B2}, {0x000111C1, 0x000111C4}, {0x000111DA, 0x000111DA}, -{0x000111DC, 0x000111DC}, {0x00011200, 0x00011211}, {0x00011213, 0x0001122B}, {0x00011280, 0x00011286}, -{0x00011288, 0x00011288}, {0x0001128A, 0x0001128D}, {0x0001128F, 0x0001129D}, {0x0001129F, 0x000112A8}, -{0x000112B0, 0x000112DE}, {0x00011305, 0x0001130C}, {0x0001130F, 0x00011310}, {0x00011313, 0x00011328}, -{0x0001132A, 0x00011330}, {0x00011332, 0x00011333}, {0x00011335, 0x00011339}, {0x0001133D, 0x0001133D}, -{0x00011350, 0x00011350}, {0x0001135D, 0x00011361}, {0x00011400, 0x00011434}, {0x00011447, 0x0001144A}, -{0x0001145F, 0x00011461}, {0x00011480, 0x000114AF}, {0x000114C4, 0x000114C5}, {0x000114C7, 0x000114C7}, -{0x00011580, 0x000115AE}, {0x000115D8, 0x000115DB}, {0x00011600, 0x0001162F}, {0x00011644, 0x00011644}, -{0x00011680, 0x000116AA}, {0x000116B8, 0x000116B8}, {0x00011700, 0x0001171A}, {0x00011800, 0x0001182B}, +{0x00010530, 0x00010563}, {0x00010570, 0x0001057A}, {0x0001057C, 0x0001058A}, {0x0001058C, 0x00010592}, +{0x00010594, 0x00010595}, {0x00010597, 0x000105A1}, {0x000105A3, 0x000105B1}, {0x000105B3, 0x000105B9}, +{0x000105BB, 0x000105BC}, {0x00010600, 0x00010736}, {0x00010740, 0x00010755}, {0x00010760, 0x00010767}, +{0x00010780, 0x00010785}, {0x00010787, 0x000107B0}, {0x000107B2, 0x000107BA}, {0x00010800, 0x00010805}, +{0x00010808, 0x00010808}, {0x0001080A, 0x00010835}, {0x00010837, 0x00010838}, {0x0001083C, 0x0001083C}, +{0x0001083F, 0x00010855}, {0x00010860, 0x00010876}, {0x00010880, 0x0001089E}, {0x000108E0, 0x000108F2}, +{0x000108F4, 0x000108F5}, {0x00010900, 0x00010915}, {0x00010920, 0x00010939}, {0x00010980, 0x000109B7}, +{0x000109BE, 0x000109BF}, {0x00010A00, 0x00010A00}, {0x00010A10, 0x00010A13}, {0x00010A15, 0x00010A17}, +{0x00010A19, 0x00010A35}, {0x00010A60, 0x00010A7C}, {0x00010A80, 0x00010A9C}, {0x00010AC0, 0x00010AC7}, +{0x00010AC9, 0x00010AE4}, {0x00010B00, 0x00010B35}, {0x00010B40, 0x00010B55}, {0x00010B60, 0x00010B72}, +{0x00010B80, 0x00010B91}, {0x00010C00, 0x00010C48}, {0x00010C80, 0x00010CB2}, {0x00010CC0, 0x00010CF2}, +{0x00010D00, 0x00010D23}, {0x00010E80, 0x00010EA9}, {0x00010EB0, 0x00010EB1}, {0x00010F00, 0x00010F1C}, +{0x00010F27, 0x00010F27}, {0x00010F30, 0x00010F45}, {0x00010F70, 0x00010F81}, {0x00010FB0, 0x00010FC4}, +{0x00010FE0, 0x00010FF6}, {0x00011003, 0x00011037}, {0x00011071, 0x00011072}, {0x00011075, 0x00011075}, +{0x00011083, 0x000110AF}, {0x000110D0, 0x000110E8}, {0x00011103, 0x00011126}, {0x00011144, 0x00011144}, +{0x00011147, 0x00011147}, {0x00011150, 0x00011172}, {0x00011176, 0x00011176}, {0x00011183, 0x000111B2}, +{0x000111C1, 0x000111C4}, {0x000111DA, 0x000111DA}, {0x000111DC, 0x000111DC}, {0x00011200, 0x00011211}, +{0x00011213, 0x0001122B}, {0x0001123F, 0x00011240}, {0x00011280, 0x00011286}, {0x00011288, 0x00011288}, +{0x0001128A, 0x0001128D}, {0x0001128F, 0x0001129D}, {0x0001129F, 0x000112A8}, {0x000112B0, 0x000112DE}, +{0x00011305, 0x0001130C}, {0x0001130F, 0x00011310}, {0x00011313, 0x00011328}, {0x0001132A, 0x00011330}, +{0x00011332, 0x00011333}, {0x00011335, 0x00011339}, {0x0001133D, 0x0001133D}, {0x00011350, 0x00011350}, +{0x0001135D, 0x00011361}, {0x00011400, 0x00011434}, {0x00011447, 0x0001144A}, {0x0001145F, 0x00011461}, +{0x00011480, 0x000114AF}, {0x000114C4, 0x000114C5}, {0x000114C7, 0x000114C7}, {0x00011580, 0x000115AE}, +{0x000115D8, 0x000115DB}, {0x00011600, 0x0001162F}, {0x00011644, 0x00011644}, {0x00011680, 0x000116AA}, +{0x000116B8, 0x000116B8}, {0x00011700, 0x0001171A}, {0x00011740, 0x00011746}, {0x00011800, 0x0001182B}, {0x000118A0, 0x000118DF}, {0x000118FF, 0x00011906}, {0x00011909, 0x00011909}, {0x0001190C, 0x00011913}, {0x00011915, 0x00011916}, {0x00011918, 0x0001192F}, {0x0001193F, 0x0001193F}, {0x00011941, 0x00011941}, {0x000119A0, 0x000119A7}, {0x000119AA, 0x000119D0}, {0x000119E1, 0x000119E1}, {0x000119E3, 0x000119E3}, {0x00011A00, 0x00011A00}, {0x00011A0B, 0x00011A32}, {0x00011A3A, 0x00011A3A}, {0x00011A50, 0x00011A50}, -{0x00011A5C, 0x00011A89}, {0x00011A9D, 0x00011A9D}, {0x00011AC0, 0x00011AF8}, {0x00011C00, 0x00011C08}, +{0x00011A5C, 0x00011A89}, {0x00011A9D, 0x00011A9D}, {0x00011AB0, 0x00011AF8}, {0x00011C00, 0x00011C08}, {0x00011C0A, 0x00011C2E}, {0x00011C40, 0x00011C40}, {0x00011C72, 0x00011C8F}, {0x00011D00, 0x00011D06}, {0x00011D08, 0x00011D09}, {0x00011D0B, 0x00011D30}, {0x00011D46, 0x00011D46}, {0x00011D60, 0x00011D65}, {0x00011D67, 0x00011D68}, {0x00011D6A, 0x00011D89}, {0x00011D98, 0x00011D98}, {0x00011EE0, 0x00011EF2}, -{0x00011FB0, 0x00011FB0}, {0x00012000, 0x00012399}, {0x00012480, 0x00012543}, {0x00013000, 0x0001342E}, -{0x00014400, 0x00014646}, {0x00016800, 0x00016A38}, {0x00016A40, 0x00016A5E}, {0x00016AD0, 0x00016AED}, -{0x00016B00, 0x00016B2F}, {0x00016B40, 0x00016B43}, {0x00016B63, 0x00016B77}, {0x00016B7D, 0x00016B8F}, -{0x00016E40, 0x00016E7F}, {0x00016F00, 0x00016F4A}, {0x00016F50, 0x00016F50}, {0x00016F93, 0x00016F9F}, -{0x00016FE0, 0x00016FE1}, {0x00016FE3, 0x00016FE3}, {0x00017000, 0x000187F7}, {0x00018800, 0x00018CD5}, -{0x00018D00, 0x00018D08}, {0x0001B000, 0x0001B11E}, {0x0001B150, 0x0001B152}, {0x0001B164, 0x0001B167}, -{0x0001B170, 0x0001B2FB}, {0x0001BC00, 0x0001BC6A}, {0x0001BC70, 0x0001BC7C}, {0x0001BC80, 0x0001BC88}, -{0x0001BC90, 0x0001BC99}, {0x0001D400, 0x0001D454}, {0x0001D456, 0x0001D49C}, {0x0001D49E, 0x0001D49F}, -{0x0001D4A2, 0x0001D4A2}, {0x0001D4A5, 0x0001D4A6}, {0x0001D4A9, 0x0001D4AC}, {0x0001D4AE, 0x0001D4B9}, -{0x0001D4BB, 0x0001D4BB}, {0x0001D4BD, 0x0001D4C3}, {0x0001D4C5, 0x0001D505}, {0x0001D507, 0x0001D50A}, -{0x0001D50D, 0x0001D514}, {0x0001D516, 0x0001D51C}, {0x0001D51E, 0x0001D539}, {0x0001D53B, 0x0001D53E}, -{0x0001D540, 0x0001D544}, {0x0001D546, 0x0001D546}, {0x0001D54A, 0x0001D550}, {0x0001D552, 0x0001D6A5}, -{0x0001D6A8, 0x0001D6C0}, {0x0001D6C2, 0x0001D6DA}, {0x0001D6DC, 0x0001D6FA}, {0x0001D6FC, 0x0001D714}, -{0x0001D716, 0x0001D734}, {0x0001D736, 0x0001D74E}, {0x0001D750, 0x0001D76E}, {0x0001D770, 0x0001D788}, -{0x0001D78A, 0x0001D7A8}, {0x0001D7AA, 0x0001D7C2}, {0x0001D7C4, 0x0001D7CB}, {0x0001E100, 0x0001E12C}, -{0x0001E137, 0x0001E13D}, {0x0001E14E, 0x0001E14E}, {0x0001E2C0, 0x0001E2EB}, {0x0001E800, 0x0001E8C4}, +{0x00011F02, 0x00011F02}, {0x00011F04, 0x00011F10}, {0x00011F12, 0x00011F33}, {0x00011FB0, 0x00011FB0}, +{0x00012000, 0x00012399}, {0x00012480, 0x00012543}, {0x00012F90, 0x00012FF0}, {0x00013000, 0x0001342F}, +{0x00013441, 0x00013446}, {0x00014400, 0x00014646}, {0x00016800, 0x00016A38}, {0x00016A40, 0x00016A5E}, +{0x00016A70, 0x00016ABE}, {0x00016AD0, 0x00016AED}, {0x00016B00, 0x00016B2F}, {0x00016B40, 0x00016B43}, +{0x00016B63, 0x00016B77}, {0x00016B7D, 0x00016B8F}, {0x00016E40, 0x00016E7F}, {0x00016F00, 0x00016F4A}, +{0x00016F50, 0x00016F50}, {0x00016F93, 0x00016F9F}, {0x00016FE0, 0x00016FE1}, {0x00016FE3, 0x00016FE3}, +{0x00017000, 0x000187F7}, {0x00018800, 0x00018CD5}, {0x00018D00, 0x00018D08}, {0x0001AFF0, 0x0001AFF3}, +{0x0001AFF5, 0x0001AFFB}, {0x0001AFFD, 0x0001AFFE}, {0x0001B000, 0x0001B122}, {0x0001B132, 0x0001B132}, +{0x0001B150, 0x0001B152}, {0x0001B155, 0x0001B155}, {0x0001B164, 0x0001B167}, {0x0001B170, 0x0001B2FB}, +{0x0001BC00, 0x0001BC6A}, {0x0001BC70, 0x0001BC7C}, {0x0001BC80, 0x0001BC88}, {0x0001BC90, 0x0001BC99}, +{0x0001D400, 0x0001D454}, {0x0001D456, 0x0001D49C}, {0x0001D49E, 0x0001D49F}, {0x0001D4A2, 0x0001D4A2}, +{0x0001D4A5, 0x0001D4A6}, {0x0001D4A9, 0x0001D4AC}, {0x0001D4AE, 0x0001D4B9}, {0x0001D4BB, 0x0001D4BB}, +{0x0001D4BD, 0x0001D4C3}, {0x0001D4C5, 0x0001D505}, {0x0001D507, 0x0001D50A}, {0x0001D50D, 0x0001D514}, +{0x0001D516, 0x0001D51C}, {0x0001D51E, 0x0001D539}, {0x0001D53B, 0x0001D53E}, {0x0001D540, 0x0001D544}, +{0x0001D546, 0x0001D546}, {0x0001D54A, 0x0001D550}, {0x0001D552, 0x0001D6A5}, {0x0001D6A8, 0x0001D6C0}, +{0x0001D6C2, 0x0001D6DA}, {0x0001D6DC, 0x0001D6FA}, {0x0001D6FC, 0x0001D714}, {0x0001D716, 0x0001D734}, +{0x0001D736, 0x0001D74E}, {0x0001D750, 0x0001D76E}, {0x0001D770, 0x0001D788}, {0x0001D78A, 0x0001D7A8}, +{0x0001D7AA, 0x0001D7C2}, {0x0001D7C4, 0x0001D7CB}, {0x0001DF00, 0x0001DF1E}, {0x0001DF25, 0x0001DF2A}, +{0x0001E030, 0x0001E06D}, {0x0001E100, 0x0001E12C}, {0x0001E137, 0x0001E13D}, {0x0001E14E, 0x0001E14E}, +{0x0001E290, 0x0001E2AD}, {0x0001E2C0, 0x0001E2EB}, {0x0001E4D0, 0x0001E4EB}, {0x0001E7E0, 0x0001E7E6}, +{0x0001E7E8, 0x0001E7EB}, {0x0001E7ED, 0x0001E7EE}, {0x0001E7F0, 0x0001E7FE}, {0x0001E800, 0x0001E8C4}, {0x0001E900, 0x0001E943}, {0x0001E94B, 0x0001E94B}, {0x0001EE00, 0x0001EE03}, {0x0001EE05, 0x0001EE1F}, {0x0001EE21, 0x0001EE22}, {0x0001EE24, 0x0001EE24}, {0x0001EE27, 0x0001EE27}, {0x0001EE29, 0x0001EE32}, {0x0001EE34, 0x0001EE37}, {0x0001EE39, 0x0001EE39}, {0x0001EE3B, 0x0001EE3B}, {0x0001EE42, 0x0001EE42}, @@ -182,13 +210,18 @@ const std::vector> unicode_ranges_letter = { {0x0001EE5B, 0x0001EE5B}, {0x0001EE5D, 0x0001EE5D}, {0x0001EE5F, 0x0001EE5F}, {0x0001EE61, 0x0001EE62}, {0x0001EE64, 0x0001EE64}, {0x0001EE67, 0x0001EE6A}, {0x0001EE6C, 0x0001EE72}, {0x0001EE74, 0x0001EE77}, {0x0001EE79, 0x0001EE7C}, {0x0001EE7E, 0x0001EE7E}, {0x0001EE80, 0x0001EE89}, {0x0001EE8B, 0x0001EE9B}, -{0x0001EEA1, 0x0001EEA3}, {0x0001EEA5, 0x0001EEA9}, {0x0001EEAB, 0x0001EEBB}, {0x00020000, 0x0002A6DD}, -{0x0002A700, 0x0002B734}, {0x0002B740, 0x0002B81D}, {0x0002B820, 0x0002CEA1}, {0x0002CEB0, 0x0002EBE0}, -{0x0002F800, 0x0002FA1D}, {0x00030000, 0x0003134A}, +{0x0001EEA1, 0x0001EEA3}, {0x0001EEA5, 0x0001EEA9}, {0x0001EEAB, 0x0001EEBB}, {0x00020000, 0x0002A6DF}, +{0x0002A700, 0x0002B739}, {0x0002B740, 0x0002B81D}, {0x0002B820, 0x0002CEA1}, {0x0002CEB0, 0x0002EBE0}, +{0x0002EBF0, 0x0002EE5D}, {0x0002F800, 0x0002FA1D}, {0x00030000, 0x0003134A}, {0x00031350, 0x000323AF}, +}; + +const std::vector> unicode_ranges_separator = { +{0x00000020, 0x00000020}, {0x000000A0, 0x000000A0}, {0x00001680, 0x00001680}, {0x00002000, 0x0000200A}, +{0x00002028, 0x00002029}, {0x0000202F, 0x0000202F}, {0x0000205F, 0x0000205F}, {0x00003000, 0x00003000}, }; const std::vector> unicode_ranges_whitespace = { -{0x00000009, 0x0000000D}, {0x0000001C, 0x00000020}, {0x00000085, 0x00000085}, {0x000000A0, 0x000000A0}, +{0x00000009, 0x0000000D}, {0x00000020, 0x00000020}, {0x00000085, 0x00000085}, {0x000000A0, 0x000000A0}, {0x00001680, 0x00001680}, {0x00002000, 0x0000200A}, {0x00002028, 0x00002029}, {0x0000202F, 0x0000202F}, {0x0000205F, 0x0000205F}, {0x00003000, 0x00003000}, }; @@ -200,72 +233,77 @@ const std::vector> unicode_ranges_accent_mark = { {0x000006E7, 0x000006E8}, {0x000006EA, 0x000006ED}, {0x00000711, 0x00000711}, {0x00000730, 0x0000074A}, {0x000007A6, 0x000007B0}, {0x000007EB, 0x000007F3}, {0x000007FD, 0x000007FD}, {0x00000816, 0x00000819}, {0x0000081B, 0x00000823}, {0x00000825, 0x00000827}, {0x00000829, 0x0000082D}, {0x00000859, 0x0000085B}, -{0x000008D3, 0x000008E1}, {0x000008E3, 0x00000903}, {0x0000093A, 0x0000093C}, {0x0000093E, 0x0000094F}, -{0x00000951, 0x00000957}, {0x00000962, 0x00000963}, {0x00000981, 0x00000983}, {0x000009BC, 0x000009BC}, -{0x000009BE, 0x000009C4}, {0x000009C7, 0x000009C8}, {0x000009CB, 0x000009CD}, {0x000009D7, 0x000009D7}, -{0x000009E2, 0x000009E3}, {0x000009FE, 0x000009FE}, {0x00000A01, 0x00000A03}, {0x00000A3C, 0x00000A3C}, -{0x00000A3E, 0x00000A42}, {0x00000A47, 0x00000A48}, {0x00000A4B, 0x00000A4D}, {0x00000A51, 0x00000A51}, -{0x00000A70, 0x00000A71}, {0x00000A75, 0x00000A75}, {0x00000A81, 0x00000A83}, {0x00000ABC, 0x00000ABC}, -{0x00000ABE, 0x00000AC5}, {0x00000AC7, 0x00000AC9}, {0x00000ACB, 0x00000ACD}, {0x00000AE2, 0x00000AE3}, -{0x00000AFA, 0x00000AFF}, {0x00000B01, 0x00000B03}, {0x00000B3C, 0x00000B3C}, {0x00000B3E, 0x00000B44}, -{0x00000B47, 0x00000B48}, {0x00000B4B, 0x00000B4D}, {0x00000B55, 0x00000B57}, {0x00000B62, 0x00000B63}, -{0x00000B82, 0x00000B82}, {0x00000BBE, 0x00000BC2}, {0x00000BC6, 0x00000BC8}, {0x00000BCA, 0x00000BCD}, -{0x00000BD7, 0x00000BD7}, {0x00000C00, 0x00000C04}, {0x00000C3E, 0x00000C44}, {0x00000C46, 0x00000C48}, -{0x00000C4A, 0x00000C4D}, {0x00000C55, 0x00000C56}, {0x00000C62, 0x00000C63}, {0x00000C81, 0x00000C83}, -{0x00000CBC, 0x00000CBC}, {0x00000CBE, 0x00000CC4}, {0x00000CC6, 0x00000CC8}, {0x00000CCA, 0x00000CCD}, -{0x00000CD5, 0x00000CD6}, {0x00000CE2, 0x00000CE3}, {0x00000D00, 0x00000D03}, {0x00000D3B, 0x00000D3C}, -{0x00000D3E, 0x00000D44}, {0x00000D46, 0x00000D48}, {0x00000D4A, 0x00000D4D}, {0x00000D57, 0x00000D57}, -{0x00000D62, 0x00000D63}, {0x00000D81, 0x00000D83}, {0x00000DCA, 0x00000DCA}, {0x00000DCF, 0x00000DD4}, -{0x00000DD6, 0x00000DD6}, {0x00000DD8, 0x00000DDF}, {0x00000DF2, 0x00000DF3}, {0x00000E31, 0x00000E31}, -{0x00000E34, 0x00000E3A}, {0x00000E47, 0x00000E4E}, {0x00000EB1, 0x00000EB1}, {0x00000EB4, 0x00000EBC}, -{0x00000EC8, 0x00000ECD}, {0x00000F18, 0x00000F19}, {0x00000F35, 0x00000F35}, {0x00000F37, 0x00000F37}, -{0x00000F39, 0x00000F39}, {0x00000F3E, 0x00000F3F}, {0x00000F71, 0x00000F84}, {0x00000F86, 0x00000F87}, -{0x00000F8D, 0x00000F97}, {0x00000F99, 0x00000FBC}, {0x00000FC6, 0x00000FC6}, {0x0000102B, 0x0000103E}, -{0x00001056, 0x00001059}, {0x0000105E, 0x00001060}, {0x00001062, 0x00001064}, {0x00001067, 0x0000106D}, -{0x00001071, 0x00001074}, {0x00001082, 0x0000108D}, {0x0000108F, 0x0000108F}, {0x0000109A, 0x0000109D}, -{0x0000135D, 0x0000135F}, {0x00001712, 0x00001714}, {0x00001732, 0x00001734}, {0x00001752, 0x00001753}, -{0x00001772, 0x00001773}, {0x000017B4, 0x000017D3}, {0x000017DD, 0x000017DD}, {0x0000180B, 0x0000180D}, +{0x00000898, 0x0000089F}, {0x000008CA, 0x000008E1}, {0x000008E3, 0x00000903}, {0x0000093A, 0x0000093C}, +{0x0000093E, 0x0000094F}, {0x00000951, 0x00000957}, {0x00000962, 0x00000963}, {0x00000981, 0x00000983}, +{0x000009BC, 0x000009BC}, {0x000009BE, 0x000009C4}, {0x000009C7, 0x000009C8}, {0x000009CB, 0x000009CD}, +{0x000009D7, 0x000009D7}, {0x000009E2, 0x000009E3}, {0x000009FE, 0x000009FE}, {0x00000A01, 0x00000A03}, +{0x00000A3C, 0x00000A3C}, {0x00000A3E, 0x00000A42}, {0x00000A47, 0x00000A48}, {0x00000A4B, 0x00000A4D}, +{0x00000A51, 0x00000A51}, {0x00000A70, 0x00000A71}, {0x00000A75, 0x00000A75}, {0x00000A81, 0x00000A83}, +{0x00000ABC, 0x00000ABC}, {0x00000ABE, 0x00000AC5}, {0x00000AC7, 0x00000AC9}, {0x00000ACB, 0x00000ACD}, +{0x00000AE2, 0x00000AE3}, {0x00000AFA, 0x00000AFF}, {0x00000B01, 0x00000B03}, {0x00000B3C, 0x00000B3C}, +{0x00000B3E, 0x00000B44}, {0x00000B47, 0x00000B48}, {0x00000B4B, 0x00000B4D}, {0x00000B55, 0x00000B57}, +{0x00000B62, 0x00000B63}, {0x00000B82, 0x00000B82}, {0x00000BBE, 0x00000BC2}, {0x00000BC6, 0x00000BC8}, +{0x00000BCA, 0x00000BCD}, {0x00000BD7, 0x00000BD7}, {0x00000C00, 0x00000C04}, {0x00000C3C, 0x00000C3C}, +{0x00000C3E, 0x00000C44}, {0x00000C46, 0x00000C48}, {0x00000C4A, 0x00000C4D}, {0x00000C55, 0x00000C56}, +{0x00000C62, 0x00000C63}, {0x00000C81, 0x00000C83}, {0x00000CBC, 0x00000CBC}, {0x00000CBE, 0x00000CC4}, +{0x00000CC6, 0x00000CC8}, {0x00000CCA, 0x00000CCD}, {0x00000CD5, 0x00000CD6}, {0x00000CE2, 0x00000CE3}, +{0x00000CF3, 0x00000CF3}, {0x00000D00, 0x00000D03}, {0x00000D3B, 0x00000D3C}, {0x00000D3E, 0x00000D44}, +{0x00000D46, 0x00000D48}, {0x00000D4A, 0x00000D4D}, {0x00000D57, 0x00000D57}, {0x00000D62, 0x00000D63}, +{0x00000D81, 0x00000D83}, {0x00000DCA, 0x00000DCA}, {0x00000DCF, 0x00000DD4}, {0x00000DD6, 0x00000DD6}, +{0x00000DD8, 0x00000DDF}, {0x00000DF2, 0x00000DF3}, {0x00000E31, 0x00000E31}, {0x00000E34, 0x00000E3A}, +{0x00000E47, 0x00000E4E}, {0x00000EB1, 0x00000EB1}, {0x00000EB4, 0x00000EBC}, {0x00000EC8, 0x00000ECE}, +{0x00000F18, 0x00000F19}, {0x00000F35, 0x00000F35}, {0x00000F37, 0x00000F37}, {0x00000F39, 0x00000F39}, +{0x00000F3E, 0x00000F3F}, {0x00000F71, 0x00000F84}, {0x00000F86, 0x00000F87}, {0x00000F8D, 0x00000F97}, +{0x00000F99, 0x00000FBC}, {0x00000FC6, 0x00000FC6}, {0x0000102B, 0x0000103E}, {0x00001056, 0x00001059}, +{0x0000105E, 0x00001060}, {0x00001062, 0x00001064}, {0x00001067, 0x0000106D}, {0x00001071, 0x00001074}, +{0x00001082, 0x0000108D}, {0x0000108F, 0x0000108F}, {0x0000109A, 0x0000109D}, {0x0000135D, 0x0000135F}, +{0x00001712, 0x00001715}, {0x00001732, 0x00001734}, {0x00001752, 0x00001753}, {0x00001772, 0x00001773}, +{0x000017B4, 0x000017D3}, {0x000017DD, 0x000017DD}, {0x0000180B, 0x0000180D}, {0x0000180F, 0x0000180F}, {0x00001885, 0x00001886}, {0x000018A9, 0x000018A9}, {0x00001920, 0x0000192B}, {0x00001930, 0x0000193B}, {0x00001A17, 0x00001A1B}, {0x00001A55, 0x00001A5E}, {0x00001A60, 0x00001A7C}, {0x00001A7F, 0x00001A7F}, -{0x00001AB0, 0x00001AC0}, {0x00001B00, 0x00001B04}, {0x00001B34, 0x00001B44}, {0x00001B6B, 0x00001B73}, +{0x00001AB0, 0x00001ACE}, {0x00001B00, 0x00001B04}, {0x00001B34, 0x00001B44}, {0x00001B6B, 0x00001B73}, {0x00001B80, 0x00001B82}, {0x00001BA1, 0x00001BAD}, {0x00001BE6, 0x00001BF3}, {0x00001C24, 0x00001C37}, {0x00001CD0, 0x00001CD2}, {0x00001CD4, 0x00001CE8}, {0x00001CED, 0x00001CED}, {0x00001CF4, 0x00001CF4}, -{0x00001CF7, 0x00001CF9}, {0x00001DC0, 0x00001DF9}, {0x00001DFB, 0x00001DFF}, {0x000020D0, 0x000020F0}, -{0x00002CEF, 0x00002CF1}, {0x00002D7F, 0x00002D7F}, {0x00002DE0, 0x00002DFF}, {0x0000302A, 0x0000302F}, -{0x00003099, 0x0000309A}, {0x0000A66F, 0x0000A672}, {0x0000A674, 0x0000A67D}, {0x0000A69E, 0x0000A69F}, -{0x0000A6F0, 0x0000A6F1}, {0x0000A802, 0x0000A802}, {0x0000A806, 0x0000A806}, {0x0000A80B, 0x0000A80B}, -{0x0000A823, 0x0000A827}, {0x0000A82C, 0x0000A82C}, {0x0000A880, 0x0000A881}, {0x0000A8B4, 0x0000A8C5}, -{0x0000A8E0, 0x0000A8F1}, {0x0000A8FF, 0x0000A8FF}, {0x0000A926, 0x0000A92D}, {0x0000A947, 0x0000A953}, -{0x0000A980, 0x0000A983}, {0x0000A9B3, 0x0000A9C0}, {0x0000A9E5, 0x0000A9E5}, {0x0000AA29, 0x0000AA36}, -{0x0000AA43, 0x0000AA43}, {0x0000AA4C, 0x0000AA4D}, {0x0000AA7B, 0x0000AA7D}, {0x0000AAB0, 0x0000AAB0}, -{0x0000AAB2, 0x0000AAB4}, {0x0000AAB7, 0x0000AAB8}, {0x0000AABE, 0x0000AABF}, {0x0000AAC1, 0x0000AAC1}, -{0x0000AAEB, 0x0000AAEF}, {0x0000AAF5, 0x0000AAF6}, {0x0000ABE3, 0x0000ABEA}, {0x0000ABEC, 0x0000ABED}, -{0x0000FB1E, 0x0000FB1E}, {0x0000FE00, 0x0000FE0F}, {0x0000FE20, 0x0000FE2F}, {0x000101FD, 0x000101FD}, -{0x000102E0, 0x000102E0}, {0x00010376, 0x0001037A}, {0x00010A01, 0x00010A03}, {0x00010A05, 0x00010A06}, -{0x00010A0C, 0x00010A0F}, {0x00010A38, 0x00010A3A}, {0x00010A3F, 0x00010A3F}, {0x00010AE5, 0x00010AE6}, -{0x00010D24, 0x00010D27}, {0x00010EAB, 0x00010EAC}, {0x00010F46, 0x00010F50}, {0x00011000, 0x00011002}, -{0x00011038, 0x00011046}, {0x0001107F, 0x00011082}, {0x000110B0, 0x000110BA}, {0x00011100, 0x00011102}, +{0x00001CF7, 0x00001CF9}, {0x00001DC0, 0x00001DFF}, {0x000020D0, 0x000020F0}, {0x00002CEF, 0x00002CF1}, +{0x00002D7F, 0x00002D7F}, {0x00002DE0, 0x00002DFF}, {0x0000302A, 0x0000302F}, {0x00003099, 0x0000309A}, +{0x0000A66F, 0x0000A672}, {0x0000A674, 0x0000A67D}, {0x0000A69E, 0x0000A69F}, {0x0000A6F0, 0x0000A6F1}, +{0x0000A802, 0x0000A802}, {0x0000A806, 0x0000A806}, {0x0000A80B, 0x0000A80B}, {0x0000A823, 0x0000A827}, +{0x0000A82C, 0x0000A82C}, {0x0000A880, 0x0000A881}, {0x0000A8B4, 0x0000A8C5}, {0x0000A8E0, 0x0000A8F1}, +{0x0000A8FF, 0x0000A8FF}, {0x0000A926, 0x0000A92D}, {0x0000A947, 0x0000A953}, {0x0000A980, 0x0000A983}, +{0x0000A9B3, 0x0000A9C0}, {0x0000A9E5, 0x0000A9E5}, {0x0000AA29, 0x0000AA36}, {0x0000AA43, 0x0000AA43}, +{0x0000AA4C, 0x0000AA4D}, {0x0000AA7B, 0x0000AA7D}, {0x0000AAB0, 0x0000AAB0}, {0x0000AAB2, 0x0000AAB4}, +{0x0000AAB7, 0x0000AAB8}, {0x0000AABE, 0x0000AABF}, {0x0000AAC1, 0x0000AAC1}, {0x0000AAEB, 0x0000AAEF}, +{0x0000AAF5, 0x0000AAF6}, {0x0000ABE3, 0x0000ABEA}, {0x0000ABEC, 0x0000ABED}, {0x0000FB1E, 0x0000FB1E}, +{0x0000FE00, 0x0000FE0F}, {0x0000FE20, 0x0000FE2F}, {0x000101FD, 0x000101FD}, {0x000102E0, 0x000102E0}, +{0x00010376, 0x0001037A}, {0x00010A01, 0x00010A03}, {0x00010A05, 0x00010A06}, {0x00010A0C, 0x00010A0F}, +{0x00010A38, 0x00010A3A}, {0x00010A3F, 0x00010A3F}, {0x00010AE5, 0x00010AE6}, {0x00010D24, 0x00010D27}, +{0x00010EAB, 0x00010EAC}, {0x00010EFD, 0x00010EFF}, {0x00010F46, 0x00010F50}, {0x00010F82, 0x00010F85}, +{0x00011000, 0x00011002}, {0x00011038, 0x00011046}, {0x00011070, 0x00011070}, {0x00011073, 0x00011074}, +{0x0001107F, 0x00011082}, {0x000110B0, 0x000110BA}, {0x000110C2, 0x000110C2}, {0x00011100, 0x00011102}, {0x00011127, 0x00011134}, {0x00011145, 0x00011146}, {0x00011173, 0x00011173}, {0x00011180, 0x00011182}, {0x000111B3, 0x000111C0}, {0x000111C9, 0x000111CC}, {0x000111CE, 0x000111CF}, {0x0001122C, 0x00011237}, -{0x0001123E, 0x0001123E}, {0x000112DF, 0x000112EA}, {0x00011300, 0x00011303}, {0x0001133B, 0x0001133C}, -{0x0001133E, 0x00011344}, {0x00011347, 0x00011348}, {0x0001134B, 0x0001134D}, {0x00011357, 0x00011357}, -{0x00011362, 0x00011363}, {0x00011366, 0x0001136C}, {0x00011370, 0x00011374}, {0x00011435, 0x00011446}, -{0x0001145E, 0x0001145E}, {0x000114B0, 0x000114C3}, {0x000115AF, 0x000115B5}, {0x000115B8, 0x000115C0}, -{0x000115DC, 0x000115DD}, {0x00011630, 0x00011640}, {0x000116AB, 0x000116B7}, {0x0001171D, 0x0001172B}, -{0x0001182C, 0x0001183A}, {0x00011930, 0x00011935}, {0x00011937, 0x00011938}, {0x0001193B, 0x0001193E}, -{0x00011940, 0x00011940}, {0x00011942, 0x00011943}, {0x000119D1, 0x000119D7}, {0x000119DA, 0x000119E0}, -{0x000119E4, 0x000119E4}, {0x00011A01, 0x00011A0A}, {0x00011A33, 0x00011A39}, {0x00011A3B, 0x00011A3E}, -{0x00011A47, 0x00011A47}, {0x00011A51, 0x00011A5B}, {0x00011A8A, 0x00011A99}, {0x00011C2F, 0x00011C36}, -{0x00011C38, 0x00011C3F}, {0x00011C92, 0x00011CA7}, {0x00011CA9, 0x00011CB6}, {0x00011D31, 0x00011D36}, -{0x00011D3A, 0x00011D3A}, {0x00011D3C, 0x00011D3D}, {0x00011D3F, 0x00011D45}, {0x00011D47, 0x00011D47}, -{0x00011D8A, 0x00011D8E}, {0x00011D90, 0x00011D91}, {0x00011D93, 0x00011D97}, {0x00011EF3, 0x00011EF6}, -{0x00016AF0, 0x00016AF4}, {0x00016B30, 0x00016B36}, {0x00016F4F, 0x00016F4F}, {0x00016F51, 0x00016F87}, -{0x00016F8F, 0x00016F92}, {0x00016FE4, 0x00016FE4}, {0x00016FF0, 0x00016FF1}, {0x0001BC9D, 0x0001BC9E}, -{0x0001D165, 0x0001D169}, {0x0001D16D, 0x0001D172}, {0x0001D17B, 0x0001D182}, {0x0001D185, 0x0001D18B}, -{0x0001D1AA, 0x0001D1AD}, {0x0001D242, 0x0001D244}, {0x0001DA00, 0x0001DA36}, {0x0001DA3B, 0x0001DA6C}, -{0x0001DA75, 0x0001DA75}, {0x0001DA84, 0x0001DA84}, {0x0001DA9B, 0x0001DA9F}, {0x0001DAA1, 0x0001DAAF}, -{0x0001E000, 0x0001E006}, {0x0001E008, 0x0001E018}, {0x0001E01B, 0x0001E021}, {0x0001E023, 0x0001E024}, -{0x0001E026, 0x0001E02A}, {0x0001E130, 0x0001E136}, {0x0001E2EC, 0x0001E2EF}, {0x0001E8D0, 0x0001E8D6}, +{0x0001123E, 0x0001123E}, {0x00011241, 0x00011241}, {0x000112DF, 0x000112EA}, {0x00011300, 0x00011303}, +{0x0001133B, 0x0001133C}, {0x0001133E, 0x00011344}, {0x00011347, 0x00011348}, {0x0001134B, 0x0001134D}, +{0x00011357, 0x00011357}, {0x00011362, 0x00011363}, {0x00011366, 0x0001136C}, {0x00011370, 0x00011374}, +{0x00011435, 0x00011446}, {0x0001145E, 0x0001145E}, {0x000114B0, 0x000114C3}, {0x000115AF, 0x000115B5}, +{0x000115B8, 0x000115C0}, {0x000115DC, 0x000115DD}, {0x00011630, 0x00011640}, {0x000116AB, 0x000116B7}, +{0x0001171D, 0x0001172B}, {0x0001182C, 0x0001183A}, {0x00011930, 0x00011935}, {0x00011937, 0x00011938}, +{0x0001193B, 0x0001193E}, {0x00011940, 0x00011940}, {0x00011942, 0x00011943}, {0x000119D1, 0x000119D7}, +{0x000119DA, 0x000119E0}, {0x000119E4, 0x000119E4}, {0x00011A01, 0x00011A0A}, {0x00011A33, 0x00011A39}, +{0x00011A3B, 0x00011A3E}, {0x00011A47, 0x00011A47}, {0x00011A51, 0x00011A5B}, {0x00011A8A, 0x00011A99}, +{0x00011C2F, 0x00011C36}, {0x00011C38, 0x00011C3F}, {0x00011C92, 0x00011CA7}, {0x00011CA9, 0x00011CB6}, +{0x00011D31, 0x00011D36}, {0x00011D3A, 0x00011D3A}, {0x00011D3C, 0x00011D3D}, {0x00011D3F, 0x00011D45}, +{0x00011D47, 0x00011D47}, {0x00011D8A, 0x00011D8E}, {0x00011D90, 0x00011D91}, {0x00011D93, 0x00011D97}, +{0x00011EF3, 0x00011EF6}, {0x00011F00, 0x00011F01}, {0x00011F03, 0x00011F03}, {0x00011F34, 0x00011F3A}, +{0x00011F3E, 0x00011F42}, {0x00013440, 0x00013440}, {0x00013447, 0x00013455}, {0x00016AF0, 0x00016AF4}, +{0x00016B30, 0x00016B36}, {0x00016F4F, 0x00016F4F}, {0x00016F51, 0x00016F87}, {0x00016F8F, 0x00016F92}, +{0x00016FE4, 0x00016FE4}, {0x00016FF0, 0x00016FF1}, {0x0001BC9D, 0x0001BC9E}, {0x0001CF00, 0x0001CF2D}, +{0x0001CF30, 0x0001CF46}, {0x0001D165, 0x0001D169}, {0x0001D16D, 0x0001D172}, {0x0001D17B, 0x0001D182}, +{0x0001D185, 0x0001D18B}, {0x0001D1AA, 0x0001D1AD}, {0x0001D242, 0x0001D244}, {0x0001DA00, 0x0001DA36}, +{0x0001DA3B, 0x0001DA6C}, {0x0001DA75, 0x0001DA75}, {0x0001DA84, 0x0001DA84}, {0x0001DA9B, 0x0001DA9F}, +{0x0001DAA1, 0x0001DAAF}, {0x0001E000, 0x0001E006}, {0x0001E008, 0x0001E018}, {0x0001E01B, 0x0001E021}, +{0x0001E023, 0x0001E024}, {0x0001E026, 0x0001E02A}, {0x0001E08F, 0x0001E08F}, {0x0001E130, 0x0001E136}, +{0x0001E2AE, 0x0001E2AE}, {0x0001E2EC, 0x0001E2EF}, {0x0001E4EC, 0x0001E4EF}, {0x0001E8D0, 0x0001E8D6}, {0x0001E944, 0x0001E94A}, {0x000E0100, 0x000E01EF}, }; @@ -276,7 +314,7 @@ const std::vector> unicode_ranges_punctuation = { {0x000000B6, 0x000000B7}, {0x000000BB, 0x000000BB}, {0x000000BF, 0x000000BF}, {0x0000037E, 0x0000037E}, {0x00000387, 0x00000387}, {0x0000055A, 0x0000055F}, {0x00000589, 0x0000058A}, {0x000005BE, 0x000005BE}, {0x000005C0, 0x000005C0}, {0x000005C3, 0x000005C3}, {0x000005C6, 0x000005C6}, {0x000005F3, 0x000005F4}, -{0x00000609, 0x0000060A}, {0x0000060C, 0x0000060D}, {0x0000061B, 0x0000061B}, {0x0000061E, 0x0000061F}, +{0x00000609, 0x0000060A}, {0x0000060C, 0x0000060D}, {0x0000061B, 0x0000061B}, {0x0000061D, 0x0000061F}, {0x0000066A, 0x0000066D}, {0x000006D4, 0x000006D4}, {0x00000700, 0x0000070D}, {0x000007F7, 0x000007F9}, {0x00000830, 0x0000083E}, {0x0000085E, 0x0000085E}, {0x00000964, 0x00000965}, {0x00000970, 0x00000970}, {0x000009FD, 0x000009FD}, {0x00000A76, 0x00000A76}, {0x00000AF0, 0x00000AF0}, {0x00000C77, 0x00000C77}, @@ -286,37 +324,38 @@ const std::vector> unicode_ranges_punctuation = { {0x00001360, 0x00001368}, {0x00001400, 0x00001400}, {0x0000166E, 0x0000166E}, {0x0000169B, 0x0000169C}, {0x000016EB, 0x000016ED}, {0x00001735, 0x00001736}, {0x000017D4, 0x000017D6}, {0x000017D8, 0x000017DA}, {0x00001800, 0x0000180A}, {0x00001944, 0x00001945}, {0x00001A1E, 0x00001A1F}, {0x00001AA0, 0x00001AA6}, -{0x00001AA8, 0x00001AAD}, {0x00001B5A, 0x00001B60}, {0x00001BFC, 0x00001BFF}, {0x00001C3B, 0x00001C3F}, -{0x00001C7E, 0x00001C7F}, {0x00001CC0, 0x00001CC7}, {0x00001CD3, 0x00001CD3}, {0x00002010, 0x00002027}, -{0x00002030, 0x00002043}, {0x00002045, 0x00002051}, {0x00002053, 0x0000205E}, {0x0000207D, 0x0000207E}, -{0x0000208D, 0x0000208E}, {0x00002308, 0x0000230B}, {0x00002329, 0x0000232A}, {0x00002768, 0x00002775}, -{0x000027C5, 0x000027C6}, {0x000027E6, 0x000027EF}, {0x00002983, 0x00002998}, {0x000029D8, 0x000029DB}, -{0x000029FC, 0x000029FD}, {0x00002CF9, 0x00002CFC}, {0x00002CFE, 0x00002CFF}, {0x00002D70, 0x00002D70}, -{0x00002E00, 0x00002E2E}, {0x00002E30, 0x00002E4F}, {0x00002E52, 0x00002E52}, {0x00003001, 0x00003003}, -{0x00003008, 0x00003011}, {0x00003014, 0x0000301F}, {0x00003030, 0x00003030}, {0x0000303D, 0x0000303D}, -{0x000030A0, 0x000030A0}, {0x000030FB, 0x000030FB}, {0x0000A4FE, 0x0000A4FF}, {0x0000A60D, 0x0000A60F}, -{0x0000A673, 0x0000A673}, {0x0000A67E, 0x0000A67E}, {0x0000A6F2, 0x0000A6F7}, {0x0000A874, 0x0000A877}, -{0x0000A8CE, 0x0000A8CF}, {0x0000A8F8, 0x0000A8FA}, {0x0000A8FC, 0x0000A8FC}, {0x0000A92E, 0x0000A92F}, -{0x0000A95F, 0x0000A95F}, {0x0000A9C1, 0x0000A9CD}, {0x0000A9DE, 0x0000A9DF}, {0x0000AA5C, 0x0000AA5F}, -{0x0000AADE, 0x0000AADF}, {0x0000AAF0, 0x0000AAF1}, {0x0000ABEB, 0x0000ABEB}, {0x0000FD3E, 0x0000FD3F}, -{0x0000FE10, 0x0000FE19}, {0x0000FE30, 0x0000FE52}, {0x0000FE54, 0x0000FE61}, {0x0000FE63, 0x0000FE63}, -{0x0000FE68, 0x0000FE68}, {0x0000FE6A, 0x0000FE6B}, {0x0000FF01, 0x0000FF03}, {0x0000FF05, 0x0000FF0A}, -{0x0000FF0C, 0x0000FF0F}, {0x0000FF1A, 0x0000FF1B}, {0x0000FF1F, 0x0000FF20}, {0x0000FF3B, 0x0000FF3D}, -{0x0000FF3F, 0x0000FF3F}, {0x0000FF5B, 0x0000FF5B}, {0x0000FF5D, 0x0000FF5D}, {0x0000FF5F, 0x0000FF65}, -{0x00010100, 0x00010102}, {0x0001039F, 0x0001039F}, {0x000103D0, 0x000103D0}, {0x0001056F, 0x0001056F}, -{0x00010857, 0x00010857}, {0x0001091F, 0x0001091F}, {0x0001093F, 0x0001093F}, {0x00010A50, 0x00010A58}, -{0x00010A7F, 0x00010A7F}, {0x00010AF0, 0x00010AF6}, {0x00010B39, 0x00010B3F}, {0x00010B99, 0x00010B9C}, -{0x00010EAD, 0x00010EAD}, {0x00010F55, 0x00010F59}, {0x00011047, 0x0001104D}, {0x000110BB, 0x000110BC}, -{0x000110BE, 0x000110C1}, {0x00011140, 0x00011143}, {0x00011174, 0x00011175}, {0x000111C5, 0x000111C8}, -{0x000111CD, 0x000111CD}, {0x000111DB, 0x000111DB}, {0x000111DD, 0x000111DF}, {0x00011238, 0x0001123D}, -{0x000112A9, 0x000112A9}, {0x0001144B, 0x0001144F}, {0x0001145A, 0x0001145B}, {0x0001145D, 0x0001145D}, -{0x000114C6, 0x000114C6}, {0x000115C1, 0x000115D7}, {0x00011641, 0x00011643}, {0x00011660, 0x0001166C}, -{0x0001173C, 0x0001173E}, {0x0001183B, 0x0001183B}, {0x00011944, 0x00011946}, {0x000119E2, 0x000119E2}, -{0x00011A3F, 0x00011A46}, {0x00011A9A, 0x00011A9C}, {0x00011A9E, 0x00011AA2}, {0x00011C41, 0x00011C45}, -{0x00011C70, 0x00011C71}, {0x00011EF7, 0x00011EF8}, {0x00011FFF, 0x00011FFF}, {0x00012470, 0x00012474}, -{0x00016A6E, 0x00016A6F}, {0x00016AF5, 0x00016AF5}, {0x00016B37, 0x00016B3B}, {0x00016B44, 0x00016B44}, -{0x00016E97, 0x00016E9A}, {0x00016FE2, 0x00016FE2}, {0x0001BC9F, 0x0001BC9F}, {0x0001DA87, 0x0001DA8B}, -{0x0001E95E, 0x0001E95F}, +{0x00001AA8, 0x00001AAD}, {0x00001B5A, 0x00001B60}, {0x00001B7D, 0x00001B7E}, {0x00001BFC, 0x00001BFF}, +{0x00001C3B, 0x00001C3F}, {0x00001C7E, 0x00001C7F}, {0x00001CC0, 0x00001CC7}, {0x00001CD3, 0x00001CD3}, +{0x00002010, 0x00002027}, {0x00002030, 0x00002043}, {0x00002045, 0x00002051}, {0x00002053, 0x0000205E}, +{0x0000207D, 0x0000207E}, {0x0000208D, 0x0000208E}, {0x00002308, 0x0000230B}, {0x00002329, 0x0000232A}, +{0x00002768, 0x00002775}, {0x000027C5, 0x000027C6}, {0x000027E6, 0x000027EF}, {0x00002983, 0x00002998}, +{0x000029D8, 0x000029DB}, {0x000029FC, 0x000029FD}, {0x00002CF9, 0x00002CFC}, {0x00002CFE, 0x00002CFF}, +{0x00002D70, 0x00002D70}, {0x00002E00, 0x00002E2E}, {0x00002E30, 0x00002E4F}, {0x00002E52, 0x00002E5D}, +{0x00003001, 0x00003003}, {0x00003008, 0x00003011}, {0x00003014, 0x0000301F}, {0x00003030, 0x00003030}, +{0x0000303D, 0x0000303D}, {0x000030A0, 0x000030A0}, {0x000030FB, 0x000030FB}, {0x0000A4FE, 0x0000A4FF}, +{0x0000A60D, 0x0000A60F}, {0x0000A673, 0x0000A673}, {0x0000A67E, 0x0000A67E}, {0x0000A6F2, 0x0000A6F7}, +{0x0000A874, 0x0000A877}, {0x0000A8CE, 0x0000A8CF}, {0x0000A8F8, 0x0000A8FA}, {0x0000A8FC, 0x0000A8FC}, +{0x0000A92E, 0x0000A92F}, {0x0000A95F, 0x0000A95F}, {0x0000A9C1, 0x0000A9CD}, {0x0000A9DE, 0x0000A9DF}, +{0x0000AA5C, 0x0000AA5F}, {0x0000AADE, 0x0000AADF}, {0x0000AAF0, 0x0000AAF1}, {0x0000ABEB, 0x0000ABEB}, +{0x0000FD3E, 0x0000FD3F}, {0x0000FE10, 0x0000FE19}, {0x0000FE30, 0x0000FE52}, {0x0000FE54, 0x0000FE61}, +{0x0000FE63, 0x0000FE63}, {0x0000FE68, 0x0000FE68}, {0x0000FE6A, 0x0000FE6B}, {0x0000FF01, 0x0000FF03}, +{0x0000FF05, 0x0000FF0A}, {0x0000FF0C, 0x0000FF0F}, {0x0000FF1A, 0x0000FF1B}, {0x0000FF1F, 0x0000FF20}, +{0x0000FF3B, 0x0000FF3D}, {0x0000FF3F, 0x0000FF3F}, {0x0000FF5B, 0x0000FF5B}, {0x0000FF5D, 0x0000FF5D}, +{0x0000FF5F, 0x0000FF65}, {0x00010100, 0x00010102}, {0x0001039F, 0x0001039F}, {0x000103D0, 0x000103D0}, +{0x0001056F, 0x0001056F}, {0x00010857, 0x00010857}, {0x0001091F, 0x0001091F}, {0x0001093F, 0x0001093F}, +{0x00010A50, 0x00010A58}, {0x00010A7F, 0x00010A7F}, {0x00010AF0, 0x00010AF6}, {0x00010B39, 0x00010B3F}, +{0x00010B99, 0x00010B9C}, {0x00010EAD, 0x00010EAD}, {0x00010F55, 0x00010F59}, {0x00010F86, 0x00010F89}, +{0x00011047, 0x0001104D}, {0x000110BB, 0x000110BC}, {0x000110BE, 0x000110C1}, {0x00011140, 0x00011143}, +{0x00011174, 0x00011175}, {0x000111C5, 0x000111C8}, {0x000111CD, 0x000111CD}, {0x000111DB, 0x000111DB}, +{0x000111DD, 0x000111DF}, {0x00011238, 0x0001123D}, {0x000112A9, 0x000112A9}, {0x0001144B, 0x0001144F}, +{0x0001145A, 0x0001145B}, {0x0001145D, 0x0001145D}, {0x000114C6, 0x000114C6}, {0x000115C1, 0x000115D7}, +{0x00011641, 0x00011643}, {0x00011660, 0x0001166C}, {0x000116B9, 0x000116B9}, {0x0001173C, 0x0001173E}, +{0x0001183B, 0x0001183B}, {0x00011944, 0x00011946}, {0x000119E2, 0x000119E2}, {0x00011A3F, 0x00011A46}, +{0x00011A9A, 0x00011A9C}, {0x00011A9E, 0x00011AA2}, {0x00011B00, 0x00011B09}, {0x00011C41, 0x00011C45}, +{0x00011C70, 0x00011C71}, {0x00011EF7, 0x00011EF8}, {0x00011F43, 0x00011F4F}, {0x00011FFF, 0x00011FFF}, +{0x00012470, 0x00012474}, {0x00012FF1, 0x00012FF2}, {0x00016A6E, 0x00016A6F}, {0x00016AF5, 0x00016AF5}, +{0x00016B37, 0x00016B3B}, {0x00016B44, 0x00016B44}, {0x00016E97, 0x00016E9A}, {0x00016FE2, 0x00016FE2}, +{0x0001BC9F, 0x0001BC9F}, {0x0001DA87, 0x0001DA8B}, {0x0001E95E, 0x0001E95F}, }; const std::vector> unicode_ranges_symbol = { @@ -328,170 +367,172 @@ const std::vector> unicode_ranges_symbol = { {0x00000375, 0x00000375}, {0x00000384, 0x00000385}, {0x000003F6, 0x000003F6}, {0x00000482, 0x00000482}, {0x0000058D, 0x0000058F}, {0x00000606, 0x00000608}, {0x0000060B, 0x0000060B}, {0x0000060E, 0x0000060F}, {0x000006DE, 0x000006DE}, {0x000006E9, 0x000006E9}, {0x000006FD, 0x000006FE}, {0x000007F6, 0x000007F6}, -{0x000007FE, 0x000007FF}, {0x000009F2, 0x000009F3}, {0x000009FA, 0x000009FB}, {0x00000AF1, 0x00000AF1}, -{0x00000B70, 0x00000B70}, {0x00000BF3, 0x00000BFA}, {0x00000C7F, 0x00000C7F}, {0x00000D4F, 0x00000D4F}, -{0x00000D79, 0x00000D79}, {0x00000E3F, 0x00000E3F}, {0x00000F01, 0x00000F03}, {0x00000F13, 0x00000F13}, -{0x00000F15, 0x00000F17}, {0x00000F1A, 0x00000F1F}, {0x00000F34, 0x00000F34}, {0x00000F36, 0x00000F36}, -{0x00000F38, 0x00000F38}, {0x00000FBE, 0x00000FC5}, {0x00000FC7, 0x00000FCC}, {0x00000FCE, 0x00000FCF}, -{0x00000FD5, 0x00000FD8}, {0x0000109E, 0x0000109F}, {0x00001390, 0x00001399}, {0x0000166D, 0x0000166D}, -{0x000017DB, 0x000017DB}, {0x00001940, 0x00001940}, {0x000019DE, 0x000019FF}, {0x00001B61, 0x00001B6A}, -{0x00001B74, 0x00001B7C}, {0x00001FBD, 0x00001FBD}, {0x00001FBF, 0x00001FC1}, {0x00001FCD, 0x00001FCF}, -{0x00001FDD, 0x00001FDF}, {0x00001FED, 0x00001FEF}, {0x00001FFD, 0x00001FFE}, {0x00002044, 0x00002044}, -{0x00002052, 0x00002052}, {0x0000207A, 0x0000207C}, {0x0000208A, 0x0000208C}, {0x000020A0, 0x000020BF}, -{0x00002100, 0x00002101}, {0x00002103, 0x00002106}, {0x00002108, 0x00002109}, {0x00002114, 0x00002114}, -{0x00002116, 0x00002118}, {0x0000211E, 0x00002123}, {0x00002125, 0x00002125}, {0x00002127, 0x00002127}, -{0x00002129, 0x00002129}, {0x0000212E, 0x0000212E}, {0x0000213A, 0x0000213B}, {0x00002140, 0x00002144}, -{0x0000214A, 0x0000214D}, {0x0000214F, 0x0000214F}, {0x0000218A, 0x0000218B}, {0x00002190, 0x00002307}, -{0x0000230C, 0x00002328}, {0x0000232B, 0x00002426}, {0x00002440, 0x0000244A}, {0x0000249C, 0x000024E9}, -{0x00002500, 0x00002767}, {0x00002794, 0x000027C4}, {0x000027C7, 0x000027E5}, {0x000027F0, 0x00002982}, -{0x00002999, 0x000029D7}, {0x000029DC, 0x000029FB}, {0x000029FE, 0x00002B73}, {0x00002B76, 0x00002B95}, -{0x00002B97, 0x00002BFF}, {0x00002CE5, 0x00002CEA}, {0x00002E50, 0x00002E51}, {0x00002E80, 0x00002E99}, -{0x00002E9B, 0x00002EF3}, {0x00002F00, 0x00002FD5}, {0x00002FF0, 0x00002FFB}, {0x00003004, 0x00003004}, -{0x00003012, 0x00003013}, {0x00003020, 0x00003020}, {0x00003036, 0x00003037}, {0x0000303E, 0x0000303F}, -{0x0000309B, 0x0000309C}, {0x00003190, 0x00003191}, {0x00003196, 0x0000319F}, {0x000031C0, 0x000031E3}, -{0x00003200, 0x0000321E}, {0x0000322A, 0x00003247}, {0x00003250, 0x00003250}, {0x00003260, 0x0000327F}, -{0x0000328A, 0x000032B0}, {0x000032C0, 0x000033FF}, {0x00004DC0, 0x00004DFF}, {0x0000A490, 0x0000A4C6}, -{0x0000A700, 0x0000A716}, {0x0000A720, 0x0000A721}, {0x0000A789, 0x0000A78A}, {0x0000A828, 0x0000A82B}, -{0x0000A836, 0x0000A839}, {0x0000AA77, 0x0000AA79}, {0x0000AB5B, 0x0000AB5B}, {0x0000AB6A, 0x0000AB6B}, -{0x0000FB29, 0x0000FB29}, {0x0000FBB2, 0x0000FBC1}, {0x0000FDFC, 0x0000FDFD}, {0x0000FE62, 0x0000FE62}, +{0x000007FE, 0x000007FF}, {0x00000888, 0x00000888}, {0x000009F2, 0x000009F3}, {0x000009FA, 0x000009FB}, +{0x00000AF1, 0x00000AF1}, {0x00000B70, 0x00000B70}, {0x00000BF3, 0x00000BFA}, {0x00000C7F, 0x00000C7F}, +{0x00000D4F, 0x00000D4F}, {0x00000D79, 0x00000D79}, {0x00000E3F, 0x00000E3F}, {0x00000F01, 0x00000F03}, +{0x00000F13, 0x00000F13}, {0x00000F15, 0x00000F17}, {0x00000F1A, 0x00000F1F}, {0x00000F34, 0x00000F34}, +{0x00000F36, 0x00000F36}, {0x00000F38, 0x00000F38}, {0x00000FBE, 0x00000FC5}, {0x00000FC7, 0x00000FCC}, +{0x00000FCE, 0x00000FCF}, {0x00000FD5, 0x00000FD8}, {0x0000109E, 0x0000109F}, {0x00001390, 0x00001399}, +{0x0000166D, 0x0000166D}, {0x000017DB, 0x000017DB}, {0x00001940, 0x00001940}, {0x000019DE, 0x000019FF}, +{0x00001B61, 0x00001B6A}, {0x00001B74, 0x00001B7C}, {0x00001FBD, 0x00001FBD}, {0x00001FBF, 0x00001FC1}, +{0x00001FCD, 0x00001FCF}, {0x00001FDD, 0x00001FDF}, {0x00001FED, 0x00001FEF}, {0x00001FFD, 0x00001FFE}, +{0x00002044, 0x00002044}, {0x00002052, 0x00002052}, {0x0000207A, 0x0000207C}, {0x0000208A, 0x0000208C}, +{0x000020A0, 0x000020C0}, {0x00002100, 0x00002101}, {0x00002103, 0x00002106}, {0x00002108, 0x00002109}, +{0x00002114, 0x00002114}, {0x00002116, 0x00002118}, {0x0000211E, 0x00002123}, {0x00002125, 0x00002125}, +{0x00002127, 0x00002127}, {0x00002129, 0x00002129}, {0x0000212E, 0x0000212E}, {0x0000213A, 0x0000213B}, +{0x00002140, 0x00002144}, {0x0000214A, 0x0000214D}, {0x0000214F, 0x0000214F}, {0x0000218A, 0x0000218B}, +{0x00002190, 0x00002307}, {0x0000230C, 0x00002328}, {0x0000232B, 0x00002426}, {0x00002440, 0x0000244A}, +{0x0000249C, 0x000024E9}, {0x00002500, 0x00002767}, {0x00002794, 0x000027C4}, {0x000027C7, 0x000027E5}, +{0x000027F0, 0x00002982}, {0x00002999, 0x000029D7}, {0x000029DC, 0x000029FB}, {0x000029FE, 0x00002B73}, +{0x00002B76, 0x00002B95}, {0x00002B97, 0x00002BFF}, {0x00002CE5, 0x00002CEA}, {0x00002E50, 0x00002E51}, +{0x00002E80, 0x00002E99}, {0x00002E9B, 0x00002EF3}, {0x00002F00, 0x00002FD5}, {0x00002FF0, 0x00002FFF}, +{0x00003004, 0x00003004}, {0x00003012, 0x00003013}, {0x00003020, 0x00003020}, {0x00003036, 0x00003037}, +{0x0000303E, 0x0000303F}, {0x0000309B, 0x0000309C}, {0x00003190, 0x00003191}, {0x00003196, 0x0000319F}, +{0x000031C0, 0x000031E3}, {0x000031EF, 0x000031EF}, {0x00003200, 0x0000321E}, {0x0000322A, 0x00003247}, +{0x00003250, 0x00003250}, {0x00003260, 0x0000327F}, {0x0000328A, 0x000032B0}, {0x000032C0, 0x000033FF}, +{0x00004DC0, 0x00004DFF}, {0x0000A490, 0x0000A4C6}, {0x0000A700, 0x0000A716}, {0x0000A720, 0x0000A721}, +{0x0000A789, 0x0000A78A}, {0x0000A828, 0x0000A82B}, {0x0000A836, 0x0000A839}, {0x0000AA77, 0x0000AA79}, +{0x0000AB5B, 0x0000AB5B}, {0x0000AB6A, 0x0000AB6B}, {0x0000FB29, 0x0000FB29}, {0x0000FBB2, 0x0000FBC2}, +{0x0000FD40, 0x0000FD4F}, {0x0000FDCF, 0x0000FDCF}, {0x0000FDFC, 0x0000FDFF}, {0x0000FE62, 0x0000FE62}, {0x0000FE64, 0x0000FE66}, {0x0000FE69, 0x0000FE69}, {0x0000FF04, 0x0000FF04}, {0x0000FF0B, 0x0000FF0B}, {0x0000FF1C, 0x0000FF1E}, {0x0000FF3E, 0x0000FF3E}, {0x0000FF40, 0x0000FF40}, {0x0000FF5C, 0x0000FF5C}, {0x0000FF5E, 0x0000FF5E}, {0x0000FFE0, 0x0000FFE6}, {0x0000FFE8, 0x0000FFEE}, {0x0000FFFC, 0x0000FFFD}, {0x00010137, 0x0001013F}, {0x00010179, 0x00010189}, {0x0001018C, 0x0001018E}, {0x00010190, 0x0001019C}, {0x000101A0, 0x000101A0}, {0x000101D0, 0x000101FC}, {0x00010877, 0x00010878}, {0x00010AC8, 0x00010AC8}, {0x0001173F, 0x0001173F}, {0x00011FD5, 0x00011FF1}, {0x00016B3C, 0x00016B3F}, {0x00016B45, 0x00016B45}, -{0x0001BC9C, 0x0001BC9C}, {0x0001D000, 0x0001D0F5}, {0x0001D100, 0x0001D126}, {0x0001D129, 0x0001D164}, -{0x0001D16A, 0x0001D16C}, {0x0001D183, 0x0001D184}, {0x0001D18C, 0x0001D1A9}, {0x0001D1AE, 0x0001D1E8}, -{0x0001D200, 0x0001D241}, {0x0001D245, 0x0001D245}, {0x0001D300, 0x0001D356}, {0x0001D6C1, 0x0001D6C1}, -{0x0001D6DB, 0x0001D6DB}, {0x0001D6FB, 0x0001D6FB}, {0x0001D715, 0x0001D715}, {0x0001D735, 0x0001D735}, -{0x0001D74F, 0x0001D74F}, {0x0001D76F, 0x0001D76F}, {0x0001D789, 0x0001D789}, {0x0001D7A9, 0x0001D7A9}, -{0x0001D7C3, 0x0001D7C3}, {0x0001D800, 0x0001D9FF}, {0x0001DA37, 0x0001DA3A}, {0x0001DA6D, 0x0001DA74}, -{0x0001DA76, 0x0001DA83}, {0x0001DA85, 0x0001DA86}, {0x0001E14F, 0x0001E14F}, {0x0001E2FF, 0x0001E2FF}, -{0x0001ECAC, 0x0001ECAC}, {0x0001ECB0, 0x0001ECB0}, {0x0001ED2E, 0x0001ED2E}, {0x0001EEF0, 0x0001EEF1}, -{0x0001F000, 0x0001F02B}, {0x0001F030, 0x0001F093}, {0x0001F0A0, 0x0001F0AE}, {0x0001F0B1, 0x0001F0BF}, -{0x0001F0C1, 0x0001F0CF}, {0x0001F0D1, 0x0001F0F5}, {0x0001F10D, 0x0001F1AD}, {0x0001F1E6, 0x0001F202}, -{0x0001F210, 0x0001F23B}, {0x0001F240, 0x0001F248}, {0x0001F250, 0x0001F251}, {0x0001F260, 0x0001F265}, -{0x0001F300, 0x0001F6D7}, {0x0001F6E0, 0x0001F6EC}, {0x0001F6F0, 0x0001F6FC}, {0x0001F700, 0x0001F773}, -{0x0001F780, 0x0001F7D8}, {0x0001F7E0, 0x0001F7EB}, {0x0001F800, 0x0001F80B}, {0x0001F810, 0x0001F847}, -{0x0001F850, 0x0001F859}, {0x0001F860, 0x0001F887}, {0x0001F890, 0x0001F8AD}, {0x0001F8B0, 0x0001F8B1}, -{0x0001F900, 0x0001F978}, {0x0001F97A, 0x0001F9CB}, {0x0001F9CD, 0x0001FA53}, {0x0001FA60, 0x0001FA6D}, -{0x0001FA70, 0x0001FA74}, {0x0001FA78, 0x0001FA7A}, {0x0001FA80, 0x0001FA86}, {0x0001FA90, 0x0001FAA8}, -{0x0001FAB0, 0x0001FAB6}, {0x0001FAC0, 0x0001FAC2}, {0x0001FAD0, 0x0001FAD6}, {0x0001FB00, 0x0001FB92}, +{0x0001BC9C, 0x0001BC9C}, {0x0001CF50, 0x0001CFC3}, {0x0001D000, 0x0001D0F5}, {0x0001D100, 0x0001D126}, +{0x0001D129, 0x0001D164}, {0x0001D16A, 0x0001D16C}, {0x0001D183, 0x0001D184}, {0x0001D18C, 0x0001D1A9}, +{0x0001D1AE, 0x0001D1EA}, {0x0001D200, 0x0001D241}, {0x0001D245, 0x0001D245}, {0x0001D300, 0x0001D356}, +{0x0001D6C1, 0x0001D6C1}, {0x0001D6DB, 0x0001D6DB}, {0x0001D6FB, 0x0001D6FB}, {0x0001D715, 0x0001D715}, +{0x0001D735, 0x0001D735}, {0x0001D74F, 0x0001D74F}, {0x0001D76F, 0x0001D76F}, {0x0001D789, 0x0001D789}, +{0x0001D7A9, 0x0001D7A9}, {0x0001D7C3, 0x0001D7C3}, {0x0001D800, 0x0001D9FF}, {0x0001DA37, 0x0001DA3A}, +{0x0001DA6D, 0x0001DA74}, {0x0001DA76, 0x0001DA83}, {0x0001DA85, 0x0001DA86}, {0x0001E14F, 0x0001E14F}, +{0x0001E2FF, 0x0001E2FF}, {0x0001ECAC, 0x0001ECAC}, {0x0001ECB0, 0x0001ECB0}, {0x0001ED2E, 0x0001ED2E}, +{0x0001EEF0, 0x0001EEF1}, {0x0001F000, 0x0001F02B}, {0x0001F030, 0x0001F093}, {0x0001F0A0, 0x0001F0AE}, +{0x0001F0B1, 0x0001F0BF}, {0x0001F0C1, 0x0001F0CF}, {0x0001F0D1, 0x0001F0F5}, {0x0001F10D, 0x0001F1AD}, +{0x0001F1E6, 0x0001F202}, {0x0001F210, 0x0001F23B}, {0x0001F240, 0x0001F248}, {0x0001F250, 0x0001F251}, +{0x0001F260, 0x0001F265}, {0x0001F300, 0x0001F6D7}, {0x0001F6DC, 0x0001F6EC}, {0x0001F6F0, 0x0001F6FC}, +{0x0001F700, 0x0001F776}, {0x0001F77B, 0x0001F7D9}, {0x0001F7E0, 0x0001F7EB}, {0x0001F7F0, 0x0001F7F0}, +{0x0001F800, 0x0001F80B}, {0x0001F810, 0x0001F847}, {0x0001F850, 0x0001F859}, {0x0001F860, 0x0001F887}, +{0x0001F890, 0x0001F8AD}, {0x0001F8B0, 0x0001F8B1}, {0x0001F900, 0x0001FA53}, {0x0001FA60, 0x0001FA6D}, +{0x0001FA70, 0x0001FA7C}, {0x0001FA80, 0x0001FA88}, {0x0001FA90, 0x0001FABD}, {0x0001FABF, 0x0001FAC5}, +{0x0001FACE, 0x0001FADB}, {0x0001FAE0, 0x0001FAE8}, {0x0001FAF0, 0x0001FAF8}, {0x0001FB00, 0x0001FB92}, {0x0001FB94, 0x0001FBCA}, }; const std::vector> unicode_ranges_control = { -{0x00000000, 0x00000008}, {0x0000000E, 0x0000001B}, {0x0000007F, 0x00000084}, {0x00000086, 0x0000009F}, -{0x000000AD, 0x000000AD}, {0x00000378, 0x00000379}, {0x00000380, 0x00000383}, {0x0000038B, 0x0000038B}, -{0x0000038D, 0x0000038D}, {0x000003A2, 0x000003A2}, {0x00000530, 0x00000530}, {0x00000557, 0x00000558}, -{0x0000058B, 0x0000058C}, {0x00000590, 0x00000590}, {0x000005C8, 0x000005CF}, {0x000005EB, 0x000005EE}, -{0x000005F5, 0x00000605}, {0x0000061C, 0x0000061D}, {0x000006DD, 0x000006DD}, {0x0000070E, 0x0000070F}, -{0x0000074B, 0x0000074C}, {0x000007B2, 0x000007BF}, {0x000007FB, 0x000007FC}, {0x0000082E, 0x0000082F}, -{0x0000083F, 0x0000083F}, {0x0000085C, 0x0000085D}, {0x0000085F, 0x0000085F}, {0x0000086B, 0x0000089F}, -{0x000008B5, 0x000008B5}, {0x000008C8, 0x000008D2}, {0x000008E2, 0x000008E2}, {0x00000984, 0x00000984}, -{0x0000098D, 0x0000098E}, {0x00000991, 0x00000992}, {0x000009A9, 0x000009A9}, {0x000009B1, 0x000009B1}, -{0x000009B3, 0x000009B5}, {0x000009BA, 0x000009BB}, {0x000009C5, 0x000009C6}, {0x000009C9, 0x000009CA}, -{0x000009CF, 0x000009D6}, {0x000009D8, 0x000009DB}, {0x000009DE, 0x000009DE}, {0x000009E4, 0x000009E5}, -{0x000009FF, 0x00000A00}, {0x00000A04, 0x00000A04}, {0x00000A0B, 0x00000A0E}, {0x00000A11, 0x00000A12}, -{0x00000A29, 0x00000A29}, {0x00000A31, 0x00000A31}, {0x00000A34, 0x00000A34}, {0x00000A37, 0x00000A37}, -{0x00000A3A, 0x00000A3B}, {0x00000A3D, 0x00000A3D}, {0x00000A43, 0x00000A46}, {0x00000A49, 0x00000A4A}, -{0x00000A4E, 0x00000A50}, {0x00000A52, 0x00000A58}, {0x00000A5D, 0x00000A5D}, {0x00000A5F, 0x00000A65}, -{0x00000A77, 0x00000A80}, {0x00000A84, 0x00000A84}, {0x00000A8E, 0x00000A8E}, {0x00000A92, 0x00000A92}, -{0x00000AA9, 0x00000AA9}, {0x00000AB1, 0x00000AB1}, {0x00000AB4, 0x00000AB4}, {0x00000ABA, 0x00000ABB}, -{0x00000AC6, 0x00000AC6}, {0x00000ACA, 0x00000ACA}, {0x00000ACE, 0x00000ACF}, {0x00000AD1, 0x00000ADF}, -{0x00000AE4, 0x00000AE5}, {0x00000AF2, 0x00000AF8}, {0x00000B00, 0x00000B00}, {0x00000B04, 0x00000B04}, -{0x00000B0D, 0x00000B0E}, {0x00000B11, 0x00000B12}, {0x00000B29, 0x00000B29}, {0x00000B31, 0x00000B31}, -{0x00000B34, 0x00000B34}, {0x00000B3A, 0x00000B3B}, {0x00000B45, 0x00000B46}, {0x00000B49, 0x00000B4A}, -{0x00000B4E, 0x00000B54}, {0x00000B58, 0x00000B5B}, {0x00000B5E, 0x00000B5E}, {0x00000B64, 0x00000B65}, -{0x00000B78, 0x00000B81}, {0x00000B84, 0x00000B84}, {0x00000B8B, 0x00000B8D}, {0x00000B91, 0x00000B91}, -{0x00000B96, 0x00000B98}, {0x00000B9B, 0x00000B9B}, {0x00000B9D, 0x00000B9D}, {0x00000BA0, 0x00000BA2}, -{0x00000BA5, 0x00000BA7}, {0x00000BAB, 0x00000BAD}, {0x00000BBA, 0x00000BBD}, {0x00000BC3, 0x00000BC5}, -{0x00000BC9, 0x00000BC9}, {0x00000BCE, 0x00000BCF}, {0x00000BD1, 0x00000BD6}, {0x00000BD8, 0x00000BE5}, -{0x00000BFB, 0x00000BFF}, {0x00000C0D, 0x00000C0D}, {0x00000C11, 0x00000C11}, {0x00000C29, 0x00000C29}, -{0x00000C3A, 0x00000C3C}, {0x00000C45, 0x00000C45}, {0x00000C49, 0x00000C49}, {0x00000C4E, 0x00000C54}, -{0x00000C57, 0x00000C57}, {0x00000C5B, 0x00000C5F}, {0x00000C64, 0x00000C65}, {0x00000C70, 0x00000C76}, -{0x00000C8D, 0x00000C8D}, {0x00000C91, 0x00000C91}, {0x00000CA9, 0x00000CA9}, {0x00000CB4, 0x00000CB4}, -{0x00000CBA, 0x00000CBB}, {0x00000CC5, 0x00000CC5}, {0x00000CC9, 0x00000CC9}, {0x00000CCE, 0x00000CD4}, -{0x00000CD7, 0x00000CDD}, {0x00000CDF, 0x00000CDF}, {0x00000CE4, 0x00000CE5}, {0x00000CF0, 0x00000CF0}, -{0x00000CF3, 0x00000CFF}, {0x00000D0D, 0x00000D0D}, {0x00000D11, 0x00000D11}, {0x00000D45, 0x00000D45}, -{0x00000D49, 0x00000D49}, {0x00000D50, 0x00000D53}, {0x00000D64, 0x00000D65}, {0x00000D80, 0x00000D80}, -{0x00000D84, 0x00000D84}, {0x00000D97, 0x00000D99}, {0x00000DB2, 0x00000DB2}, {0x00000DBC, 0x00000DBC}, -{0x00000DBE, 0x00000DBF}, {0x00000DC7, 0x00000DC9}, {0x00000DCB, 0x00000DCE}, {0x00000DD5, 0x00000DD5}, -{0x00000DD7, 0x00000DD7}, {0x00000DE0, 0x00000DE5}, {0x00000DF0, 0x00000DF1}, {0x00000DF5, 0x00000E00}, -{0x00000E3B, 0x00000E3E}, {0x00000E5C, 0x00000E80}, {0x00000E83, 0x00000E83}, {0x00000E85, 0x00000E85}, -{0x00000E8B, 0x00000E8B}, {0x00000EA4, 0x00000EA4}, {0x00000EA6, 0x00000EA6}, {0x00000EBE, 0x00000EBF}, -{0x00000EC5, 0x00000EC5}, {0x00000EC7, 0x00000EC7}, {0x00000ECE, 0x00000ECF}, {0x00000EDA, 0x00000EDB}, -{0x00000EE0, 0x00000EFF}, {0x00000F48, 0x00000F48}, {0x00000F6D, 0x00000F70}, {0x00000F98, 0x00000F98}, -{0x00000FBD, 0x00000FBD}, {0x00000FCD, 0x00000FCD}, {0x00000FDB, 0x00000FFF}, {0x000010C6, 0x000010C6}, -{0x000010C8, 0x000010CC}, {0x000010CE, 0x000010CF}, {0x00001249, 0x00001249}, {0x0000124E, 0x0000124F}, -{0x00001257, 0x00001257}, {0x00001259, 0x00001259}, {0x0000125E, 0x0000125F}, {0x00001289, 0x00001289}, -{0x0000128E, 0x0000128F}, {0x000012B1, 0x000012B1}, {0x000012B6, 0x000012B7}, {0x000012BF, 0x000012BF}, -{0x000012C1, 0x000012C1}, {0x000012C6, 0x000012C7}, {0x000012D7, 0x000012D7}, {0x00001311, 0x00001311}, -{0x00001316, 0x00001317}, {0x0000135B, 0x0000135C}, {0x0000137D, 0x0000137F}, {0x0000139A, 0x0000139F}, -{0x000013F6, 0x000013F7}, {0x000013FE, 0x000013FF}, {0x0000169D, 0x0000169F}, {0x000016F9, 0x000016FF}, -{0x0000170D, 0x0000170D}, {0x00001715, 0x0000171F}, {0x00001737, 0x0000173F}, {0x00001754, 0x0000175F}, -{0x0000176D, 0x0000176D}, {0x00001771, 0x00001771}, {0x00001774, 0x0000177F}, {0x000017DE, 0x000017DF}, -{0x000017EA, 0x000017EF}, {0x000017FA, 0x000017FF}, {0x0000180E, 0x0000180F}, {0x0000181A, 0x0000181F}, -{0x00001879, 0x0000187F}, {0x000018AB, 0x000018AF}, {0x000018F6, 0x000018FF}, {0x0000191F, 0x0000191F}, -{0x0000192C, 0x0000192F}, {0x0000193C, 0x0000193F}, {0x00001941, 0x00001943}, {0x0000196E, 0x0000196F}, -{0x00001975, 0x0000197F}, {0x000019AC, 0x000019AF}, {0x000019CA, 0x000019CF}, {0x000019DB, 0x000019DD}, -{0x00001A1C, 0x00001A1D}, {0x00001A5F, 0x00001A5F}, {0x00001A7D, 0x00001A7E}, {0x00001A8A, 0x00001A8F}, -{0x00001A9A, 0x00001A9F}, {0x00001AAE, 0x00001AAF}, {0x00001AC1, 0x00001AFF}, {0x00001B4C, 0x00001B4F}, -{0x00001B7D, 0x00001B7F}, {0x00001BF4, 0x00001BFB}, {0x00001C38, 0x00001C3A}, {0x00001C4A, 0x00001C4C}, -{0x00001C89, 0x00001C8F}, {0x00001CBB, 0x00001CBC}, {0x00001CC8, 0x00001CCF}, {0x00001CFB, 0x00001CFF}, -{0x00001DFA, 0x00001DFA}, {0x00001F16, 0x00001F17}, {0x00001F1E, 0x00001F1F}, {0x00001F46, 0x00001F47}, +{0x00000000, 0x0000001F}, {0x0000007F, 0x0000009F}, {0x000000AD, 0x000000AD}, {0x00000378, 0x00000379}, +{0x00000380, 0x00000383}, {0x0000038B, 0x0000038B}, {0x0000038D, 0x0000038D}, {0x000003A2, 0x000003A2}, +{0x00000530, 0x00000530}, {0x00000557, 0x00000558}, {0x0000058B, 0x0000058C}, {0x00000590, 0x00000590}, +{0x000005C8, 0x000005CF}, {0x000005EB, 0x000005EE}, {0x000005F5, 0x00000605}, {0x0000061C, 0x0000061C}, +{0x000006DD, 0x000006DD}, {0x0000070E, 0x0000070F}, {0x0000074B, 0x0000074C}, {0x000007B2, 0x000007BF}, +{0x000007FB, 0x000007FC}, {0x0000082E, 0x0000082F}, {0x0000083F, 0x0000083F}, {0x0000085C, 0x0000085D}, +{0x0000085F, 0x0000085F}, {0x0000086B, 0x0000086F}, {0x0000088F, 0x00000897}, {0x000008E2, 0x000008E2}, +{0x00000984, 0x00000984}, {0x0000098D, 0x0000098E}, {0x00000991, 0x00000992}, {0x000009A9, 0x000009A9}, +{0x000009B1, 0x000009B1}, {0x000009B3, 0x000009B5}, {0x000009BA, 0x000009BB}, {0x000009C5, 0x000009C6}, +{0x000009C9, 0x000009CA}, {0x000009CF, 0x000009D6}, {0x000009D8, 0x000009DB}, {0x000009DE, 0x000009DE}, +{0x000009E4, 0x000009E5}, {0x000009FF, 0x00000A00}, {0x00000A04, 0x00000A04}, {0x00000A0B, 0x00000A0E}, +{0x00000A11, 0x00000A12}, {0x00000A29, 0x00000A29}, {0x00000A31, 0x00000A31}, {0x00000A34, 0x00000A34}, +{0x00000A37, 0x00000A37}, {0x00000A3A, 0x00000A3B}, {0x00000A3D, 0x00000A3D}, {0x00000A43, 0x00000A46}, +{0x00000A49, 0x00000A4A}, {0x00000A4E, 0x00000A50}, {0x00000A52, 0x00000A58}, {0x00000A5D, 0x00000A5D}, +{0x00000A5F, 0x00000A65}, {0x00000A77, 0x00000A80}, {0x00000A84, 0x00000A84}, {0x00000A8E, 0x00000A8E}, +{0x00000A92, 0x00000A92}, {0x00000AA9, 0x00000AA9}, {0x00000AB1, 0x00000AB1}, {0x00000AB4, 0x00000AB4}, +{0x00000ABA, 0x00000ABB}, {0x00000AC6, 0x00000AC6}, {0x00000ACA, 0x00000ACA}, {0x00000ACE, 0x00000ACF}, +{0x00000AD1, 0x00000ADF}, {0x00000AE4, 0x00000AE5}, {0x00000AF2, 0x00000AF8}, {0x00000B00, 0x00000B00}, +{0x00000B04, 0x00000B04}, {0x00000B0D, 0x00000B0E}, {0x00000B11, 0x00000B12}, {0x00000B29, 0x00000B29}, +{0x00000B31, 0x00000B31}, {0x00000B34, 0x00000B34}, {0x00000B3A, 0x00000B3B}, {0x00000B45, 0x00000B46}, +{0x00000B49, 0x00000B4A}, {0x00000B4E, 0x00000B54}, {0x00000B58, 0x00000B5B}, {0x00000B5E, 0x00000B5E}, +{0x00000B64, 0x00000B65}, {0x00000B78, 0x00000B81}, {0x00000B84, 0x00000B84}, {0x00000B8B, 0x00000B8D}, +{0x00000B91, 0x00000B91}, {0x00000B96, 0x00000B98}, {0x00000B9B, 0x00000B9B}, {0x00000B9D, 0x00000B9D}, +{0x00000BA0, 0x00000BA2}, {0x00000BA5, 0x00000BA7}, {0x00000BAB, 0x00000BAD}, {0x00000BBA, 0x00000BBD}, +{0x00000BC3, 0x00000BC5}, {0x00000BC9, 0x00000BC9}, {0x00000BCE, 0x00000BCF}, {0x00000BD1, 0x00000BD6}, +{0x00000BD8, 0x00000BE5}, {0x00000BFB, 0x00000BFF}, {0x00000C0D, 0x00000C0D}, {0x00000C11, 0x00000C11}, +{0x00000C29, 0x00000C29}, {0x00000C3A, 0x00000C3B}, {0x00000C45, 0x00000C45}, {0x00000C49, 0x00000C49}, +{0x00000C4E, 0x00000C54}, {0x00000C57, 0x00000C57}, {0x00000C5B, 0x00000C5C}, {0x00000C5E, 0x00000C5F}, +{0x00000C64, 0x00000C65}, {0x00000C70, 0x00000C76}, {0x00000C8D, 0x00000C8D}, {0x00000C91, 0x00000C91}, +{0x00000CA9, 0x00000CA9}, {0x00000CB4, 0x00000CB4}, {0x00000CBA, 0x00000CBB}, {0x00000CC5, 0x00000CC5}, +{0x00000CC9, 0x00000CC9}, {0x00000CCE, 0x00000CD4}, {0x00000CD7, 0x00000CDC}, {0x00000CDF, 0x00000CDF}, +{0x00000CE4, 0x00000CE5}, {0x00000CF0, 0x00000CF0}, {0x00000CF4, 0x00000CFF}, {0x00000D0D, 0x00000D0D}, +{0x00000D11, 0x00000D11}, {0x00000D45, 0x00000D45}, {0x00000D49, 0x00000D49}, {0x00000D50, 0x00000D53}, +{0x00000D64, 0x00000D65}, {0x00000D80, 0x00000D80}, {0x00000D84, 0x00000D84}, {0x00000D97, 0x00000D99}, +{0x00000DB2, 0x00000DB2}, {0x00000DBC, 0x00000DBC}, {0x00000DBE, 0x00000DBF}, {0x00000DC7, 0x00000DC9}, +{0x00000DCB, 0x00000DCE}, {0x00000DD5, 0x00000DD5}, {0x00000DD7, 0x00000DD7}, {0x00000DE0, 0x00000DE5}, +{0x00000DF0, 0x00000DF1}, {0x00000DF5, 0x00000E00}, {0x00000E3B, 0x00000E3E}, {0x00000E5C, 0x00000E80}, +{0x00000E83, 0x00000E83}, {0x00000E85, 0x00000E85}, {0x00000E8B, 0x00000E8B}, {0x00000EA4, 0x00000EA4}, +{0x00000EA6, 0x00000EA6}, {0x00000EBE, 0x00000EBF}, {0x00000EC5, 0x00000EC5}, {0x00000EC7, 0x00000EC7}, +{0x00000ECF, 0x00000ECF}, {0x00000EDA, 0x00000EDB}, {0x00000EE0, 0x00000EFF}, {0x00000F48, 0x00000F48}, +{0x00000F6D, 0x00000F70}, {0x00000F98, 0x00000F98}, {0x00000FBD, 0x00000FBD}, {0x00000FCD, 0x00000FCD}, +{0x00000FDB, 0x00000FFF}, {0x000010C6, 0x000010C6}, {0x000010C8, 0x000010CC}, {0x000010CE, 0x000010CF}, +{0x00001249, 0x00001249}, {0x0000124E, 0x0000124F}, {0x00001257, 0x00001257}, {0x00001259, 0x00001259}, +{0x0000125E, 0x0000125F}, {0x00001289, 0x00001289}, {0x0000128E, 0x0000128F}, {0x000012B1, 0x000012B1}, +{0x000012B6, 0x000012B7}, {0x000012BF, 0x000012BF}, {0x000012C1, 0x000012C1}, {0x000012C6, 0x000012C7}, +{0x000012D7, 0x000012D7}, {0x00001311, 0x00001311}, {0x00001316, 0x00001317}, {0x0000135B, 0x0000135C}, +{0x0000137D, 0x0000137F}, {0x0000139A, 0x0000139F}, {0x000013F6, 0x000013F7}, {0x000013FE, 0x000013FF}, +{0x0000169D, 0x0000169F}, {0x000016F9, 0x000016FF}, {0x00001716, 0x0000171E}, {0x00001737, 0x0000173F}, +{0x00001754, 0x0000175F}, {0x0000176D, 0x0000176D}, {0x00001771, 0x00001771}, {0x00001774, 0x0000177F}, +{0x000017DE, 0x000017DF}, {0x000017EA, 0x000017EF}, {0x000017FA, 0x000017FF}, {0x0000180E, 0x0000180E}, +{0x0000181A, 0x0000181F}, {0x00001879, 0x0000187F}, {0x000018AB, 0x000018AF}, {0x000018F6, 0x000018FF}, +{0x0000191F, 0x0000191F}, {0x0000192C, 0x0000192F}, {0x0000193C, 0x0000193F}, {0x00001941, 0x00001943}, +{0x0000196E, 0x0000196F}, {0x00001975, 0x0000197F}, {0x000019AC, 0x000019AF}, {0x000019CA, 0x000019CF}, +{0x000019DB, 0x000019DD}, {0x00001A1C, 0x00001A1D}, {0x00001A5F, 0x00001A5F}, {0x00001A7D, 0x00001A7E}, +{0x00001A8A, 0x00001A8F}, {0x00001A9A, 0x00001A9F}, {0x00001AAE, 0x00001AAF}, {0x00001ACF, 0x00001AFF}, +{0x00001B4D, 0x00001B4F}, {0x00001B7F, 0x00001B7F}, {0x00001BF4, 0x00001BFB}, {0x00001C38, 0x00001C3A}, +{0x00001C4A, 0x00001C4C}, {0x00001C89, 0x00001C8F}, {0x00001CBB, 0x00001CBC}, {0x00001CC8, 0x00001CCF}, +{0x00001CFB, 0x00001CFF}, {0x00001F16, 0x00001F17}, {0x00001F1E, 0x00001F1F}, {0x00001F46, 0x00001F47}, {0x00001F4E, 0x00001F4F}, {0x00001F58, 0x00001F58}, {0x00001F5A, 0x00001F5A}, {0x00001F5C, 0x00001F5C}, {0x00001F5E, 0x00001F5E}, {0x00001F7E, 0x00001F7F}, {0x00001FB5, 0x00001FB5}, {0x00001FC5, 0x00001FC5}, {0x00001FD4, 0x00001FD5}, {0x00001FDC, 0x00001FDC}, {0x00001FF0, 0x00001FF1}, {0x00001FF5, 0x00001FF5}, {0x00001FFF, 0x00001FFF}, {0x0000200B, 0x0000200F}, {0x0000202A, 0x0000202E}, {0x00002060, 0x0000206F}, -{0x00002072, 0x00002073}, {0x0000208F, 0x0000208F}, {0x0000209D, 0x0000209F}, {0x000020C0, 0x000020CF}, +{0x00002072, 0x00002073}, {0x0000208F, 0x0000208F}, {0x0000209D, 0x0000209F}, {0x000020C1, 0x000020CF}, {0x000020F1, 0x000020FF}, {0x0000218C, 0x0000218F}, {0x00002427, 0x0000243F}, {0x0000244B, 0x0000245F}, -{0x00002B74, 0x00002B75}, {0x00002B96, 0x00002B96}, {0x00002C2F, 0x00002C2F}, {0x00002C5F, 0x00002C5F}, -{0x00002CF4, 0x00002CF8}, {0x00002D26, 0x00002D26}, {0x00002D28, 0x00002D2C}, {0x00002D2E, 0x00002D2F}, -{0x00002D68, 0x00002D6E}, {0x00002D71, 0x00002D7E}, {0x00002D97, 0x00002D9F}, {0x00002DA7, 0x00002DA7}, -{0x00002DAF, 0x00002DAF}, {0x00002DB7, 0x00002DB7}, {0x00002DBF, 0x00002DBF}, {0x00002DC7, 0x00002DC7}, -{0x00002DCF, 0x00002DCF}, {0x00002DD7, 0x00002DD7}, {0x00002DDF, 0x00002DDF}, {0x00002E53, 0x00002E7F}, -{0x00002E9A, 0x00002E9A}, {0x00002EF4, 0x00002EFF}, {0x00002FD6, 0x00002FEF}, {0x00002FFC, 0x00002FFF}, -{0x00003040, 0x00003040}, {0x00003097, 0x00003098}, {0x00003100, 0x00003104}, {0x00003130, 0x00003130}, -{0x0000318F, 0x0000318F}, {0x000031E4, 0x000031EF}, {0x0000321F, 0x0000321F}, {0x00009FFD, 0x00009FFF}, +{0x00002B74, 0x00002B75}, {0x00002B96, 0x00002B96}, {0x00002CF4, 0x00002CF8}, {0x00002D26, 0x00002D26}, +{0x00002D28, 0x00002D2C}, {0x00002D2E, 0x00002D2F}, {0x00002D68, 0x00002D6E}, {0x00002D71, 0x00002D7E}, +{0x00002D97, 0x00002D9F}, {0x00002DA7, 0x00002DA7}, {0x00002DAF, 0x00002DAF}, {0x00002DB7, 0x00002DB7}, +{0x00002DBF, 0x00002DBF}, {0x00002DC7, 0x00002DC7}, {0x00002DCF, 0x00002DCF}, {0x00002DD7, 0x00002DD7}, +{0x00002DDF, 0x00002DDF}, {0x00002E5E, 0x00002E7F}, {0x00002E9A, 0x00002E9A}, {0x00002EF4, 0x00002EFF}, +{0x00002FD6, 0x00002FEF}, {0x00003040, 0x00003040}, {0x00003097, 0x00003098}, {0x00003100, 0x00003104}, +{0x00003130, 0x00003130}, {0x0000318F, 0x0000318F}, {0x000031E4, 0x000031EE}, {0x0000321F, 0x0000321F}, {0x0000A48D, 0x0000A48F}, {0x0000A4C7, 0x0000A4CF}, {0x0000A62C, 0x0000A63F}, {0x0000A6F8, 0x0000A6FF}, -{0x0000A7C0, 0x0000A7C1}, {0x0000A7CB, 0x0000A7F4}, {0x0000A82D, 0x0000A82F}, {0x0000A83A, 0x0000A83F}, -{0x0000A878, 0x0000A87F}, {0x0000A8C6, 0x0000A8CD}, {0x0000A8DA, 0x0000A8DF}, {0x0000A954, 0x0000A95E}, -{0x0000A97D, 0x0000A97F}, {0x0000A9CE, 0x0000A9CE}, {0x0000A9DA, 0x0000A9DD}, {0x0000A9FF, 0x0000A9FF}, -{0x0000AA37, 0x0000AA3F}, {0x0000AA4E, 0x0000AA4F}, {0x0000AA5A, 0x0000AA5B}, {0x0000AAC3, 0x0000AADA}, -{0x0000AAF7, 0x0000AB00}, {0x0000AB07, 0x0000AB08}, {0x0000AB0F, 0x0000AB10}, {0x0000AB17, 0x0000AB1F}, -{0x0000AB27, 0x0000AB27}, {0x0000AB2F, 0x0000AB2F}, {0x0000AB6C, 0x0000AB6F}, {0x0000ABEE, 0x0000ABEF}, -{0x0000ABFA, 0x0000ABFF}, {0x0000D7A4, 0x0000D7AF}, {0x0000D7C7, 0x0000D7CA}, {0x0000D7FC, 0x0000F8FF}, -{0x0000FA6E, 0x0000FA6F}, {0x0000FADA, 0x0000FAFF}, {0x0000FB07, 0x0000FB12}, {0x0000FB18, 0x0000FB1C}, -{0x0000FB37, 0x0000FB37}, {0x0000FB3D, 0x0000FB3D}, {0x0000FB3F, 0x0000FB3F}, {0x0000FB42, 0x0000FB42}, -{0x0000FB45, 0x0000FB45}, {0x0000FBC2, 0x0000FBD2}, {0x0000FD40, 0x0000FD4F}, {0x0000FD90, 0x0000FD91}, -{0x0000FDC8, 0x0000FDEF}, {0x0000FDFE, 0x0000FDFF}, {0x0000FE1A, 0x0000FE1F}, {0x0000FE53, 0x0000FE53}, -{0x0000FE67, 0x0000FE67}, {0x0000FE6C, 0x0000FE6F}, {0x0000FE75, 0x0000FE75}, {0x0000FEFD, 0x0000FF00}, -{0x0000FFBF, 0x0000FFC1}, {0x0000FFC8, 0x0000FFC9}, {0x0000FFD0, 0x0000FFD1}, {0x0000FFD8, 0x0000FFD9}, -{0x0000FFDD, 0x0000FFDF}, {0x0000FFE7, 0x0000FFE7}, {0x0000FFEF, 0x0000FFFB}, {0x0000FFFE, 0x0000FFFF}, -{0x0001000C, 0x0001000C}, {0x00010027, 0x00010027}, {0x0001003B, 0x0001003B}, {0x0001003E, 0x0001003E}, -{0x0001004E, 0x0001004F}, {0x0001005E, 0x0001007F}, {0x000100FB, 0x000100FF}, {0x00010103, 0x00010106}, -{0x00010134, 0x00010136}, {0x0001018F, 0x0001018F}, {0x0001019D, 0x0001019F}, {0x000101A1, 0x000101CF}, -{0x000101FE, 0x0001027F}, {0x0001029D, 0x0001029F}, {0x000102D1, 0x000102DF}, {0x000102FC, 0x000102FF}, -{0x00010324, 0x0001032C}, {0x0001034B, 0x0001034F}, {0x0001037B, 0x0001037F}, {0x0001039E, 0x0001039E}, -{0x000103C4, 0x000103C7}, {0x000103D6, 0x000103FF}, {0x0001049E, 0x0001049F}, {0x000104AA, 0x000104AF}, -{0x000104D4, 0x000104D7}, {0x000104FC, 0x000104FF}, {0x00010528, 0x0001052F}, {0x00010564, 0x0001056E}, -{0x00010570, 0x000105FF}, {0x00010737, 0x0001073F}, {0x00010756, 0x0001075F}, {0x00010768, 0x000107FF}, -{0x00010806, 0x00010807}, {0x00010809, 0x00010809}, {0x00010836, 0x00010836}, {0x00010839, 0x0001083B}, -{0x0001083D, 0x0001083E}, {0x00010856, 0x00010856}, {0x0001089F, 0x000108A6}, {0x000108B0, 0x000108DF}, -{0x000108F3, 0x000108F3}, {0x000108F6, 0x000108FA}, {0x0001091C, 0x0001091E}, {0x0001093A, 0x0001093E}, -{0x00010940, 0x0001097F}, {0x000109B8, 0x000109BB}, {0x000109D0, 0x000109D1}, {0x00010A04, 0x00010A04}, -{0x00010A07, 0x00010A0B}, {0x00010A14, 0x00010A14}, {0x00010A18, 0x00010A18}, {0x00010A36, 0x00010A37}, -{0x00010A3B, 0x00010A3E}, {0x00010A49, 0x00010A4F}, {0x00010A59, 0x00010A5F}, {0x00010AA0, 0x00010ABF}, -{0x00010AE7, 0x00010AEA}, {0x00010AF7, 0x00010AFF}, {0x00010B36, 0x00010B38}, {0x00010B56, 0x00010B57}, -{0x00010B73, 0x00010B77}, {0x00010B92, 0x00010B98}, {0x00010B9D, 0x00010BA8}, {0x00010BB0, 0x00010BFF}, -{0x00010C49, 0x00010C7F}, {0x00010CB3, 0x00010CBF}, {0x00010CF3, 0x00010CF9}, {0x00010D28, 0x00010D2F}, -{0x00010D3A, 0x00010E5F}, {0x00010E7F, 0x00010E7F}, {0x00010EAA, 0x00010EAA}, {0x00010EAE, 0x00010EAF}, -{0x00010EB2, 0x00010EFF}, {0x00010F28, 0x00010F2F}, {0x00010F5A, 0x00010FAF}, {0x00010FCC, 0x00010FDF}, -{0x00010FF7, 0x00010FFF}, {0x0001104E, 0x00011051}, {0x00011070, 0x0001107E}, {0x000110BD, 0x000110BD}, -{0x000110C2, 0x000110CF}, {0x000110E9, 0x000110EF}, {0x000110FA, 0x000110FF}, {0x00011135, 0x00011135}, +{0x0000A7CB, 0x0000A7CF}, {0x0000A7D2, 0x0000A7D2}, {0x0000A7D4, 0x0000A7D4}, {0x0000A7DA, 0x0000A7F1}, +{0x0000A82D, 0x0000A82F}, {0x0000A83A, 0x0000A83F}, {0x0000A878, 0x0000A87F}, {0x0000A8C6, 0x0000A8CD}, +{0x0000A8DA, 0x0000A8DF}, {0x0000A954, 0x0000A95E}, {0x0000A97D, 0x0000A97F}, {0x0000A9CE, 0x0000A9CE}, +{0x0000A9DA, 0x0000A9DD}, {0x0000A9FF, 0x0000A9FF}, {0x0000AA37, 0x0000AA3F}, {0x0000AA4E, 0x0000AA4F}, +{0x0000AA5A, 0x0000AA5B}, {0x0000AAC3, 0x0000AADA}, {0x0000AAF7, 0x0000AB00}, {0x0000AB07, 0x0000AB08}, +{0x0000AB0F, 0x0000AB10}, {0x0000AB17, 0x0000AB1F}, {0x0000AB27, 0x0000AB27}, {0x0000AB2F, 0x0000AB2F}, +{0x0000AB6C, 0x0000AB6F}, {0x0000ABEE, 0x0000ABEF}, {0x0000ABFA, 0x0000ABFF}, {0x0000D7A4, 0x0000D7AF}, +{0x0000D7C7, 0x0000D7CA}, {0x0000D7FC, 0x0000F8FF}, {0x0000FA6E, 0x0000FA6F}, {0x0000FADA, 0x0000FAFF}, +{0x0000FB07, 0x0000FB12}, {0x0000FB18, 0x0000FB1C}, {0x0000FB37, 0x0000FB37}, {0x0000FB3D, 0x0000FB3D}, +{0x0000FB3F, 0x0000FB3F}, {0x0000FB42, 0x0000FB42}, {0x0000FB45, 0x0000FB45}, {0x0000FBC3, 0x0000FBD2}, +{0x0000FD90, 0x0000FD91}, {0x0000FDC8, 0x0000FDCE}, {0x0000FDD0, 0x0000FDEF}, {0x0000FE1A, 0x0000FE1F}, +{0x0000FE53, 0x0000FE53}, {0x0000FE67, 0x0000FE67}, {0x0000FE6C, 0x0000FE6F}, {0x0000FE75, 0x0000FE75}, +{0x0000FEFD, 0x0000FF00}, {0x0000FFBF, 0x0000FFC1}, {0x0000FFC8, 0x0000FFC9}, {0x0000FFD0, 0x0000FFD1}, +{0x0000FFD8, 0x0000FFD9}, {0x0000FFDD, 0x0000FFDF}, {0x0000FFE7, 0x0000FFE7}, {0x0000FFEF, 0x0000FFFB}, +{0x0000FFFE, 0x0000FFFF}, {0x0001000C, 0x0001000C}, {0x00010027, 0x00010027}, {0x0001003B, 0x0001003B}, +{0x0001003E, 0x0001003E}, {0x0001004E, 0x0001004F}, {0x0001005E, 0x0001007F}, {0x000100FB, 0x000100FF}, +{0x00010103, 0x00010106}, {0x00010134, 0x00010136}, {0x0001018F, 0x0001018F}, {0x0001019D, 0x0001019F}, +{0x000101A1, 0x000101CF}, {0x000101FE, 0x0001027F}, {0x0001029D, 0x0001029F}, {0x000102D1, 0x000102DF}, +{0x000102FC, 0x000102FF}, {0x00010324, 0x0001032C}, {0x0001034B, 0x0001034F}, {0x0001037B, 0x0001037F}, +{0x0001039E, 0x0001039E}, {0x000103C4, 0x000103C7}, {0x000103D6, 0x000103FF}, {0x0001049E, 0x0001049F}, +{0x000104AA, 0x000104AF}, {0x000104D4, 0x000104D7}, {0x000104FC, 0x000104FF}, {0x00010528, 0x0001052F}, +{0x00010564, 0x0001056E}, {0x0001057B, 0x0001057B}, {0x0001058B, 0x0001058B}, {0x00010593, 0x00010593}, +{0x00010596, 0x00010596}, {0x000105A2, 0x000105A2}, {0x000105B2, 0x000105B2}, {0x000105BA, 0x000105BA}, +{0x000105BD, 0x000105FF}, {0x00010737, 0x0001073F}, {0x00010756, 0x0001075F}, {0x00010768, 0x0001077F}, +{0x00010786, 0x00010786}, {0x000107B1, 0x000107B1}, {0x000107BB, 0x000107FF}, {0x00010806, 0x00010807}, +{0x00010809, 0x00010809}, {0x00010836, 0x00010836}, {0x00010839, 0x0001083B}, {0x0001083D, 0x0001083E}, +{0x00010856, 0x00010856}, {0x0001089F, 0x000108A6}, {0x000108B0, 0x000108DF}, {0x000108F3, 0x000108F3}, +{0x000108F6, 0x000108FA}, {0x0001091C, 0x0001091E}, {0x0001093A, 0x0001093E}, {0x00010940, 0x0001097F}, +{0x000109B8, 0x000109BB}, {0x000109D0, 0x000109D1}, {0x00010A04, 0x00010A04}, {0x00010A07, 0x00010A0B}, +{0x00010A14, 0x00010A14}, {0x00010A18, 0x00010A18}, {0x00010A36, 0x00010A37}, {0x00010A3B, 0x00010A3E}, +{0x00010A49, 0x00010A4F}, {0x00010A59, 0x00010A5F}, {0x00010AA0, 0x00010ABF}, {0x00010AE7, 0x00010AEA}, +{0x00010AF7, 0x00010AFF}, {0x00010B36, 0x00010B38}, {0x00010B56, 0x00010B57}, {0x00010B73, 0x00010B77}, +{0x00010B92, 0x00010B98}, {0x00010B9D, 0x00010BA8}, {0x00010BB0, 0x00010BFF}, {0x00010C49, 0x00010C7F}, +{0x00010CB3, 0x00010CBF}, {0x00010CF3, 0x00010CF9}, {0x00010D28, 0x00010D2F}, {0x00010D3A, 0x00010E5F}, +{0x00010E7F, 0x00010E7F}, {0x00010EAA, 0x00010EAA}, {0x00010EAE, 0x00010EAF}, {0x00010EB2, 0x00010EFC}, +{0x00010F28, 0x00010F2F}, {0x00010F5A, 0x00010F6F}, {0x00010F8A, 0x00010FAF}, {0x00010FCC, 0x00010FDF}, +{0x00010FF7, 0x00010FFF}, {0x0001104E, 0x00011051}, {0x00011076, 0x0001107E}, {0x000110BD, 0x000110BD}, +{0x000110C3, 0x000110CF}, {0x000110E9, 0x000110EF}, {0x000110FA, 0x000110FF}, {0x00011135, 0x00011135}, {0x00011148, 0x0001114F}, {0x00011177, 0x0001117F}, {0x000111E0, 0x000111E0}, {0x000111F5, 0x000111FF}, -{0x00011212, 0x00011212}, {0x0001123F, 0x0001127F}, {0x00011287, 0x00011287}, {0x00011289, 0x00011289}, +{0x00011212, 0x00011212}, {0x00011242, 0x0001127F}, {0x00011287, 0x00011287}, {0x00011289, 0x00011289}, {0x0001128E, 0x0001128E}, {0x0001129E, 0x0001129E}, {0x000112AA, 0x000112AF}, {0x000112EB, 0x000112EF}, {0x000112FA, 0x000112FF}, {0x00011304, 0x00011304}, {0x0001130D, 0x0001130E}, {0x00011311, 0x00011312}, {0x00011329, 0x00011329}, {0x00011331, 0x00011331}, {0x00011334, 0x00011334}, {0x0001133A, 0x0001133A}, @@ -499,59 +540,792 @@ const std::vector> unicode_ranges_control = { {0x00011358, 0x0001135C}, {0x00011364, 0x00011365}, {0x0001136D, 0x0001136F}, {0x00011375, 0x000113FF}, {0x0001145C, 0x0001145C}, {0x00011462, 0x0001147F}, {0x000114C8, 0x000114CF}, {0x000114DA, 0x0001157F}, {0x000115B6, 0x000115B7}, {0x000115DE, 0x000115FF}, {0x00011645, 0x0001164F}, {0x0001165A, 0x0001165F}, -{0x0001166D, 0x0001167F}, {0x000116B9, 0x000116BF}, {0x000116CA, 0x000116FF}, {0x0001171B, 0x0001171C}, -{0x0001172C, 0x0001172F}, {0x00011740, 0x000117FF}, {0x0001183C, 0x0001189F}, {0x000118F3, 0x000118FE}, +{0x0001166D, 0x0001167F}, {0x000116BA, 0x000116BF}, {0x000116CA, 0x000116FF}, {0x0001171B, 0x0001171C}, +{0x0001172C, 0x0001172F}, {0x00011747, 0x000117FF}, {0x0001183C, 0x0001189F}, {0x000118F3, 0x000118FE}, {0x00011907, 0x00011908}, {0x0001190A, 0x0001190B}, {0x00011914, 0x00011914}, {0x00011917, 0x00011917}, {0x00011936, 0x00011936}, {0x00011939, 0x0001193A}, {0x00011947, 0x0001194F}, {0x0001195A, 0x0001199F}, {0x000119A8, 0x000119A9}, {0x000119D8, 0x000119D9}, {0x000119E5, 0x000119FF}, {0x00011A48, 0x00011A4F}, -{0x00011AA3, 0x00011ABF}, {0x00011AF9, 0x00011BFF}, {0x00011C09, 0x00011C09}, {0x00011C37, 0x00011C37}, -{0x00011C46, 0x00011C4F}, {0x00011C6D, 0x00011C6F}, {0x00011C90, 0x00011C91}, {0x00011CA8, 0x00011CA8}, -{0x00011CB7, 0x00011CFF}, {0x00011D07, 0x00011D07}, {0x00011D0A, 0x00011D0A}, {0x00011D37, 0x00011D39}, -{0x00011D3B, 0x00011D3B}, {0x00011D3E, 0x00011D3E}, {0x00011D48, 0x00011D4F}, {0x00011D5A, 0x00011D5F}, -{0x00011D66, 0x00011D66}, {0x00011D69, 0x00011D69}, {0x00011D8F, 0x00011D8F}, {0x00011D92, 0x00011D92}, -{0x00011D99, 0x00011D9F}, {0x00011DAA, 0x00011EDF}, {0x00011EF9, 0x00011FAF}, {0x00011FB1, 0x00011FBF}, +{0x00011AA3, 0x00011AAF}, {0x00011AF9, 0x00011AFF}, {0x00011B0A, 0x00011BFF}, {0x00011C09, 0x00011C09}, +{0x00011C37, 0x00011C37}, {0x00011C46, 0x00011C4F}, {0x00011C6D, 0x00011C6F}, {0x00011C90, 0x00011C91}, +{0x00011CA8, 0x00011CA8}, {0x00011CB7, 0x00011CFF}, {0x00011D07, 0x00011D07}, {0x00011D0A, 0x00011D0A}, +{0x00011D37, 0x00011D39}, {0x00011D3B, 0x00011D3B}, {0x00011D3E, 0x00011D3E}, {0x00011D48, 0x00011D4F}, +{0x00011D5A, 0x00011D5F}, {0x00011D66, 0x00011D66}, {0x00011D69, 0x00011D69}, {0x00011D8F, 0x00011D8F}, +{0x00011D92, 0x00011D92}, {0x00011D99, 0x00011D9F}, {0x00011DAA, 0x00011EDF}, {0x00011EF9, 0x00011EFF}, +{0x00011F11, 0x00011F11}, {0x00011F3B, 0x00011F3D}, {0x00011F5A, 0x00011FAF}, {0x00011FB1, 0x00011FBF}, {0x00011FF2, 0x00011FFE}, {0x0001239A, 0x000123FF}, {0x0001246F, 0x0001246F}, {0x00012475, 0x0001247F}, -{0x00012544, 0x00012FFF}, {0x0001342F, 0x000143FF}, {0x00014647, 0x000167FF}, {0x00016A39, 0x00016A3F}, -{0x00016A5F, 0x00016A5F}, {0x00016A6A, 0x00016A6D}, {0x00016A70, 0x00016ACF}, {0x00016AEE, 0x00016AEF}, -{0x00016AF6, 0x00016AFF}, {0x00016B46, 0x00016B4F}, {0x00016B5A, 0x00016B5A}, {0x00016B62, 0x00016B62}, -{0x00016B78, 0x00016B7C}, {0x00016B90, 0x00016E3F}, {0x00016E9B, 0x00016EFF}, {0x00016F4B, 0x00016F4E}, -{0x00016F88, 0x00016F8E}, {0x00016FA0, 0x00016FDF}, {0x00016FE5, 0x00016FEF}, {0x00016FF2, 0x00016FFF}, -{0x000187F8, 0x000187FF}, {0x00018CD6, 0x00018CFF}, {0x00018D09, 0x0001AFFF}, {0x0001B11F, 0x0001B14F}, -{0x0001B153, 0x0001B163}, {0x0001B168, 0x0001B16F}, {0x0001B2FC, 0x0001BBFF}, {0x0001BC6B, 0x0001BC6F}, -{0x0001BC7D, 0x0001BC7F}, {0x0001BC89, 0x0001BC8F}, {0x0001BC9A, 0x0001BC9B}, {0x0001BCA0, 0x0001CFFF}, -{0x0001D0F6, 0x0001D0FF}, {0x0001D127, 0x0001D128}, {0x0001D173, 0x0001D17A}, {0x0001D1E9, 0x0001D1FF}, -{0x0001D246, 0x0001D2DF}, {0x0001D2F4, 0x0001D2FF}, {0x0001D357, 0x0001D35F}, {0x0001D379, 0x0001D3FF}, +{0x00012544, 0x00012F8F}, {0x00012FF3, 0x00012FFF}, {0x00013430, 0x0001343F}, {0x00013456, 0x000143FF}, +{0x00014647, 0x000167FF}, {0x00016A39, 0x00016A3F}, {0x00016A5F, 0x00016A5F}, {0x00016A6A, 0x00016A6D}, +{0x00016ABF, 0x00016ABF}, {0x00016ACA, 0x00016ACF}, {0x00016AEE, 0x00016AEF}, {0x00016AF6, 0x00016AFF}, +{0x00016B46, 0x00016B4F}, {0x00016B5A, 0x00016B5A}, {0x00016B62, 0x00016B62}, {0x00016B78, 0x00016B7C}, +{0x00016B90, 0x00016E3F}, {0x00016E9B, 0x00016EFF}, {0x00016F4B, 0x00016F4E}, {0x00016F88, 0x00016F8E}, +{0x00016FA0, 0x00016FDF}, {0x00016FE5, 0x00016FEF}, {0x00016FF2, 0x00016FFF}, {0x000187F8, 0x000187FF}, +{0x00018CD6, 0x00018CFF}, {0x00018D09, 0x0001AFEF}, {0x0001AFF4, 0x0001AFF4}, {0x0001AFFC, 0x0001AFFC}, +{0x0001AFFF, 0x0001AFFF}, {0x0001B123, 0x0001B131}, {0x0001B133, 0x0001B14F}, {0x0001B153, 0x0001B154}, +{0x0001B156, 0x0001B163}, {0x0001B168, 0x0001B16F}, {0x0001B2FC, 0x0001BBFF}, {0x0001BC6B, 0x0001BC6F}, +{0x0001BC7D, 0x0001BC7F}, {0x0001BC89, 0x0001BC8F}, {0x0001BC9A, 0x0001BC9B}, {0x0001BCA0, 0x0001CEFF}, +{0x0001CF2E, 0x0001CF2F}, {0x0001CF47, 0x0001CF4F}, {0x0001CFC4, 0x0001CFFF}, {0x0001D0F6, 0x0001D0FF}, +{0x0001D127, 0x0001D128}, {0x0001D173, 0x0001D17A}, {0x0001D1EB, 0x0001D1FF}, {0x0001D246, 0x0001D2BF}, +{0x0001D2D4, 0x0001D2DF}, {0x0001D2F4, 0x0001D2FF}, {0x0001D357, 0x0001D35F}, {0x0001D379, 0x0001D3FF}, {0x0001D455, 0x0001D455}, {0x0001D49D, 0x0001D49D}, {0x0001D4A0, 0x0001D4A1}, {0x0001D4A3, 0x0001D4A4}, {0x0001D4A7, 0x0001D4A8}, {0x0001D4AD, 0x0001D4AD}, {0x0001D4BA, 0x0001D4BA}, {0x0001D4BC, 0x0001D4BC}, {0x0001D4C4, 0x0001D4C4}, {0x0001D506, 0x0001D506}, {0x0001D50B, 0x0001D50C}, {0x0001D515, 0x0001D515}, {0x0001D51D, 0x0001D51D}, {0x0001D53A, 0x0001D53A}, {0x0001D53F, 0x0001D53F}, {0x0001D545, 0x0001D545}, {0x0001D547, 0x0001D549}, {0x0001D551, 0x0001D551}, {0x0001D6A6, 0x0001D6A7}, {0x0001D7CC, 0x0001D7CD}, -{0x0001DA8C, 0x0001DA9A}, {0x0001DAA0, 0x0001DAA0}, {0x0001DAB0, 0x0001DFFF}, {0x0001E007, 0x0001E007}, -{0x0001E019, 0x0001E01A}, {0x0001E022, 0x0001E022}, {0x0001E025, 0x0001E025}, {0x0001E02B, 0x0001E0FF}, -{0x0001E12D, 0x0001E12F}, {0x0001E13E, 0x0001E13F}, {0x0001E14A, 0x0001E14D}, {0x0001E150, 0x0001E2BF}, -{0x0001E2FA, 0x0001E2FE}, {0x0001E300, 0x0001E7FF}, {0x0001E8C5, 0x0001E8C6}, {0x0001E8D7, 0x0001E8FF}, -{0x0001E94C, 0x0001E94F}, {0x0001E95A, 0x0001E95D}, {0x0001E960, 0x0001EC70}, {0x0001ECB5, 0x0001ED00}, -{0x0001ED3E, 0x0001EDFF}, {0x0001EE04, 0x0001EE04}, {0x0001EE20, 0x0001EE20}, {0x0001EE23, 0x0001EE23}, -{0x0001EE25, 0x0001EE26}, {0x0001EE28, 0x0001EE28}, {0x0001EE33, 0x0001EE33}, {0x0001EE38, 0x0001EE38}, -{0x0001EE3A, 0x0001EE3A}, {0x0001EE3C, 0x0001EE41}, {0x0001EE43, 0x0001EE46}, {0x0001EE48, 0x0001EE48}, -{0x0001EE4A, 0x0001EE4A}, {0x0001EE4C, 0x0001EE4C}, {0x0001EE50, 0x0001EE50}, {0x0001EE53, 0x0001EE53}, -{0x0001EE55, 0x0001EE56}, {0x0001EE58, 0x0001EE58}, {0x0001EE5A, 0x0001EE5A}, {0x0001EE5C, 0x0001EE5C}, -{0x0001EE5E, 0x0001EE5E}, {0x0001EE60, 0x0001EE60}, {0x0001EE63, 0x0001EE63}, {0x0001EE65, 0x0001EE66}, -{0x0001EE6B, 0x0001EE6B}, {0x0001EE73, 0x0001EE73}, {0x0001EE78, 0x0001EE78}, {0x0001EE7D, 0x0001EE7D}, -{0x0001EE7F, 0x0001EE7F}, {0x0001EE8A, 0x0001EE8A}, {0x0001EE9C, 0x0001EEA0}, {0x0001EEA4, 0x0001EEA4}, -{0x0001EEAA, 0x0001EEAA}, {0x0001EEBC, 0x0001EEEF}, {0x0001EEF2, 0x0001EFFF}, {0x0001F02C, 0x0001F02F}, -{0x0001F094, 0x0001F09F}, {0x0001F0AF, 0x0001F0B0}, {0x0001F0C0, 0x0001F0C0}, {0x0001F0D0, 0x0001F0D0}, -{0x0001F0F6, 0x0001F0FF}, {0x0001F1AE, 0x0001F1E5}, {0x0001F203, 0x0001F20F}, {0x0001F23C, 0x0001F23F}, -{0x0001F249, 0x0001F24F}, {0x0001F252, 0x0001F25F}, {0x0001F266, 0x0001F2FF}, {0x0001F6D8, 0x0001F6DF}, -{0x0001F6ED, 0x0001F6EF}, {0x0001F6FD, 0x0001F6FF}, {0x0001F774, 0x0001F77F}, {0x0001F7D9, 0x0001F7DF}, -{0x0001F7EC, 0x0001F7FF}, {0x0001F80C, 0x0001F80F}, {0x0001F848, 0x0001F84F}, {0x0001F85A, 0x0001F85F}, -{0x0001F888, 0x0001F88F}, {0x0001F8AE, 0x0001F8AF}, {0x0001F8B2, 0x0001F8FF}, {0x0001F979, 0x0001F979}, -{0x0001F9CC, 0x0001F9CC}, {0x0001FA54, 0x0001FA5F}, {0x0001FA6E, 0x0001FA6F}, {0x0001FA75, 0x0001FA77}, -{0x0001FA7B, 0x0001FA7F}, {0x0001FA87, 0x0001FA8F}, {0x0001FAA9, 0x0001FAAF}, {0x0001FAB7, 0x0001FABF}, -{0x0001FAC3, 0x0001FACF}, {0x0001FAD7, 0x0001FAFF}, {0x0001FB93, 0x0001FB93}, {0x0001FBCB, 0x0001FBEF}, -{0x0001FBFA, 0x0001FFFF}, {0x0002A6DE, 0x0002A6FF}, {0x0002B735, 0x0002B73F}, {0x0002B81E, 0x0002B81F}, -{0x0002CEA2, 0x0002CEAF}, {0x0002EBE1, 0x0002F7FF}, {0x0002FA1E, 0x0002FFFF}, {0x0003134B, 0x000E00FF}, -{0x000E01F0, 0x0010FFFF}, +{0x0001DA8C, 0x0001DA9A}, {0x0001DAA0, 0x0001DAA0}, {0x0001DAB0, 0x0001DEFF}, {0x0001DF1F, 0x0001DF24}, +{0x0001DF2B, 0x0001DFFF}, {0x0001E007, 0x0001E007}, {0x0001E019, 0x0001E01A}, {0x0001E022, 0x0001E022}, +{0x0001E025, 0x0001E025}, {0x0001E02B, 0x0001E02F}, {0x0001E06E, 0x0001E08E}, {0x0001E090, 0x0001E0FF}, +{0x0001E12D, 0x0001E12F}, {0x0001E13E, 0x0001E13F}, {0x0001E14A, 0x0001E14D}, {0x0001E150, 0x0001E28F}, +{0x0001E2AF, 0x0001E2BF}, {0x0001E2FA, 0x0001E2FE}, {0x0001E300, 0x0001E4CF}, {0x0001E4FA, 0x0001E7DF}, +{0x0001E7E7, 0x0001E7E7}, {0x0001E7EC, 0x0001E7EC}, {0x0001E7EF, 0x0001E7EF}, {0x0001E7FF, 0x0001E7FF}, +{0x0001E8C5, 0x0001E8C6}, {0x0001E8D7, 0x0001E8FF}, {0x0001E94C, 0x0001E94F}, {0x0001E95A, 0x0001E95D}, +{0x0001E960, 0x0001EC70}, {0x0001ECB5, 0x0001ED00}, {0x0001ED3E, 0x0001EDFF}, {0x0001EE04, 0x0001EE04}, +{0x0001EE20, 0x0001EE20}, {0x0001EE23, 0x0001EE23}, {0x0001EE25, 0x0001EE26}, {0x0001EE28, 0x0001EE28}, +{0x0001EE33, 0x0001EE33}, {0x0001EE38, 0x0001EE38}, {0x0001EE3A, 0x0001EE3A}, {0x0001EE3C, 0x0001EE41}, +{0x0001EE43, 0x0001EE46}, {0x0001EE48, 0x0001EE48}, {0x0001EE4A, 0x0001EE4A}, {0x0001EE4C, 0x0001EE4C}, +{0x0001EE50, 0x0001EE50}, {0x0001EE53, 0x0001EE53}, {0x0001EE55, 0x0001EE56}, {0x0001EE58, 0x0001EE58}, +{0x0001EE5A, 0x0001EE5A}, {0x0001EE5C, 0x0001EE5C}, {0x0001EE5E, 0x0001EE5E}, {0x0001EE60, 0x0001EE60}, +{0x0001EE63, 0x0001EE63}, {0x0001EE65, 0x0001EE66}, {0x0001EE6B, 0x0001EE6B}, {0x0001EE73, 0x0001EE73}, +{0x0001EE78, 0x0001EE78}, {0x0001EE7D, 0x0001EE7D}, {0x0001EE7F, 0x0001EE7F}, {0x0001EE8A, 0x0001EE8A}, +{0x0001EE9C, 0x0001EEA0}, {0x0001EEA4, 0x0001EEA4}, {0x0001EEAA, 0x0001EEAA}, {0x0001EEBC, 0x0001EEEF}, +{0x0001EEF2, 0x0001EFFF}, {0x0001F02C, 0x0001F02F}, {0x0001F094, 0x0001F09F}, {0x0001F0AF, 0x0001F0B0}, +{0x0001F0C0, 0x0001F0C0}, {0x0001F0D0, 0x0001F0D0}, {0x0001F0F6, 0x0001F0FF}, {0x0001F1AE, 0x0001F1E5}, +{0x0001F203, 0x0001F20F}, {0x0001F23C, 0x0001F23F}, {0x0001F249, 0x0001F24F}, {0x0001F252, 0x0001F25F}, +{0x0001F266, 0x0001F2FF}, {0x0001F6D8, 0x0001F6DB}, {0x0001F6ED, 0x0001F6EF}, {0x0001F6FD, 0x0001F6FF}, +{0x0001F777, 0x0001F77A}, {0x0001F7DA, 0x0001F7DF}, {0x0001F7EC, 0x0001F7EF}, {0x0001F7F1, 0x0001F7FF}, +{0x0001F80C, 0x0001F80F}, {0x0001F848, 0x0001F84F}, {0x0001F85A, 0x0001F85F}, {0x0001F888, 0x0001F88F}, +{0x0001F8AE, 0x0001F8AF}, {0x0001F8B2, 0x0001F8FF}, {0x0001FA54, 0x0001FA5F}, {0x0001FA6E, 0x0001FA6F}, +{0x0001FA7D, 0x0001FA7F}, {0x0001FA89, 0x0001FA8F}, {0x0001FABE, 0x0001FABE}, {0x0001FAC6, 0x0001FACD}, +{0x0001FADC, 0x0001FADF}, {0x0001FAE9, 0x0001FAEF}, {0x0001FAF9, 0x0001FAFF}, {0x0001FB93, 0x0001FB93}, +{0x0001FBCB, 0x0001FBEF}, {0x0001FBFA, 0x0001FFFF}, {0x0002A6E0, 0x0002A6FF}, {0x0002B73A, 0x0002B73F}, +{0x0002B81E, 0x0002B81F}, {0x0002CEA2, 0x0002CEAF}, {0x0002EBE1, 0x0002EBEF}, {0x0002EE5E, 0x0002F7FF}, +{0x0002FA1E, 0x0002FFFF}, {0x0003134B, 0x0003134F}, {0x000323B0, 0x000E00FF}, {0x000E01F0, 0x0010FFFF}, +}; + +const std::map unicode_map_lowercase = { +{0x00000041, 0x00000061}, {0x00000042, 0x00000062}, {0x00000043, 0x00000063}, {0x00000044, 0x00000064}, +{0x00000045, 0x00000065}, {0x00000046, 0x00000066}, {0x00000047, 0x00000067}, {0x00000048, 0x00000068}, +{0x00000049, 0x00000069}, {0x0000004A, 0x0000006A}, {0x0000004B, 0x0000006B}, {0x0000004C, 0x0000006C}, +{0x0000004D, 0x0000006D}, {0x0000004E, 0x0000006E}, {0x0000004F, 0x0000006F}, {0x00000050, 0x00000070}, +{0x00000051, 0x00000071}, {0x00000052, 0x00000072}, {0x00000053, 0x00000073}, {0x00000054, 0x00000074}, +{0x00000055, 0x00000075}, {0x00000056, 0x00000076}, {0x00000057, 0x00000077}, {0x00000058, 0x00000078}, +{0x00000059, 0x00000079}, {0x0000005A, 0x0000007A}, {0x000000C0, 0x000000E0}, {0x000000C1, 0x000000E1}, +{0x000000C2, 0x000000E2}, {0x000000C3, 0x000000E3}, {0x000000C4, 0x000000E4}, {0x000000C5, 0x000000E5}, +{0x000000C6, 0x000000E6}, {0x000000C7, 0x000000E7}, {0x000000C8, 0x000000E8}, {0x000000C9, 0x000000E9}, +{0x000000CA, 0x000000EA}, {0x000000CB, 0x000000EB}, {0x000000CC, 0x000000EC}, {0x000000CD, 0x000000ED}, +{0x000000CE, 0x000000EE}, {0x000000CF, 0x000000EF}, {0x000000D0, 0x000000F0}, {0x000000D1, 0x000000F1}, +{0x000000D2, 0x000000F2}, {0x000000D3, 0x000000F3}, {0x000000D4, 0x000000F4}, {0x000000D5, 0x000000F5}, +{0x000000D6, 0x000000F6}, {0x000000D8, 0x000000F8}, {0x000000D9, 0x000000F9}, {0x000000DA, 0x000000FA}, +{0x000000DB, 0x000000FB}, {0x000000DC, 0x000000FC}, {0x000000DD, 0x000000FD}, {0x000000DE, 0x000000FE}, +{0x00000100, 0x00000101}, {0x00000102, 0x00000103}, {0x00000104, 0x00000105}, {0x00000106, 0x00000107}, +{0x00000108, 0x00000109}, {0x0000010A, 0x0000010B}, {0x0000010C, 0x0000010D}, {0x0000010E, 0x0000010F}, +{0x00000110, 0x00000111}, {0x00000112, 0x00000113}, {0x00000114, 0x00000115}, {0x00000116, 0x00000117}, +{0x00000118, 0x00000119}, {0x0000011A, 0x0000011B}, {0x0000011C, 0x0000011D}, {0x0000011E, 0x0000011F}, +{0x00000120, 0x00000121}, {0x00000122, 0x00000123}, {0x00000124, 0x00000125}, {0x00000126, 0x00000127}, +{0x00000128, 0x00000129}, {0x0000012A, 0x0000012B}, {0x0000012C, 0x0000012D}, {0x0000012E, 0x0000012F}, +{0x00000130, 0x00000069}, {0x00000132, 0x00000133}, {0x00000134, 0x00000135}, {0x00000136, 0x00000137}, +{0x00000139, 0x0000013A}, {0x0000013B, 0x0000013C}, {0x0000013D, 0x0000013E}, {0x0000013F, 0x00000140}, +{0x00000141, 0x00000142}, {0x00000143, 0x00000144}, {0x00000145, 0x00000146}, {0x00000147, 0x00000148}, +{0x0000014A, 0x0000014B}, {0x0000014C, 0x0000014D}, {0x0000014E, 0x0000014F}, {0x00000150, 0x00000151}, +{0x00000152, 0x00000153}, {0x00000154, 0x00000155}, {0x00000156, 0x00000157}, {0x00000158, 0x00000159}, +{0x0000015A, 0x0000015B}, {0x0000015C, 0x0000015D}, {0x0000015E, 0x0000015F}, {0x00000160, 0x00000161}, +{0x00000162, 0x00000163}, {0x00000164, 0x00000165}, {0x00000166, 0x00000167}, {0x00000168, 0x00000169}, +{0x0000016A, 0x0000016B}, {0x0000016C, 0x0000016D}, {0x0000016E, 0x0000016F}, {0x00000170, 0x00000171}, +{0x00000172, 0x00000173}, {0x00000174, 0x00000175}, {0x00000176, 0x00000177}, {0x00000178, 0x000000FF}, +{0x00000179, 0x0000017A}, {0x0000017B, 0x0000017C}, {0x0000017D, 0x0000017E}, {0x00000181, 0x00000253}, +{0x00000182, 0x00000183}, {0x00000184, 0x00000185}, {0x00000186, 0x00000254}, {0x00000187, 0x00000188}, +{0x00000189, 0x00000256}, {0x0000018A, 0x00000257}, {0x0000018B, 0x0000018C}, {0x0000018E, 0x000001DD}, +{0x0000018F, 0x00000259}, {0x00000190, 0x0000025B}, {0x00000191, 0x00000192}, {0x00000193, 0x00000260}, +{0x00000194, 0x00000263}, {0x00000196, 0x00000269}, {0x00000197, 0x00000268}, {0x00000198, 0x00000199}, +{0x0000019C, 0x0000026F}, {0x0000019D, 0x00000272}, {0x0000019F, 0x00000275}, {0x000001A0, 0x000001A1}, +{0x000001A2, 0x000001A3}, {0x000001A4, 0x000001A5}, {0x000001A6, 0x00000280}, {0x000001A7, 0x000001A8}, +{0x000001A9, 0x00000283}, {0x000001AC, 0x000001AD}, {0x000001AE, 0x00000288}, {0x000001AF, 0x000001B0}, +{0x000001B1, 0x0000028A}, {0x000001B2, 0x0000028B}, {0x000001B3, 0x000001B4}, {0x000001B5, 0x000001B6}, +{0x000001B7, 0x00000292}, {0x000001B8, 0x000001B9}, {0x000001BC, 0x000001BD}, {0x000001C4, 0x000001C6}, +{0x000001C5, 0x000001C6}, {0x000001C7, 0x000001C9}, {0x000001C8, 0x000001C9}, {0x000001CA, 0x000001CC}, +{0x000001CB, 0x000001CC}, {0x000001CD, 0x000001CE}, {0x000001CF, 0x000001D0}, {0x000001D1, 0x000001D2}, +{0x000001D3, 0x000001D4}, {0x000001D5, 0x000001D6}, {0x000001D7, 0x000001D8}, {0x000001D9, 0x000001DA}, +{0x000001DB, 0x000001DC}, {0x000001DE, 0x000001DF}, {0x000001E0, 0x000001E1}, {0x000001E2, 0x000001E3}, +{0x000001E4, 0x000001E5}, {0x000001E6, 0x000001E7}, {0x000001E8, 0x000001E9}, {0x000001EA, 0x000001EB}, +{0x000001EC, 0x000001ED}, {0x000001EE, 0x000001EF}, {0x000001F1, 0x000001F3}, {0x000001F2, 0x000001F3}, +{0x000001F4, 0x000001F5}, {0x000001F6, 0x00000195}, {0x000001F7, 0x000001BF}, {0x000001F8, 0x000001F9}, +{0x000001FA, 0x000001FB}, {0x000001FC, 0x000001FD}, {0x000001FE, 0x000001FF}, {0x00000200, 0x00000201}, +{0x00000202, 0x00000203}, {0x00000204, 0x00000205}, {0x00000206, 0x00000207}, {0x00000208, 0x00000209}, +{0x0000020A, 0x0000020B}, {0x0000020C, 0x0000020D}, {0x0000020E, 0x0000020F}, {0x00000210, 0x00000211}, +{0x00000212, 0x00000213}, {0x00000214, 0x00000215}, {0x00000216, 0x00000217}, {0x00000218, 0x00000219}, +{0x0000021A, 0x0000021B}, {0x0000021C, 0x0000021D}, {0x0000021E, 0x0000021F}, {0x00000220, 0x0000019E}, +{0x00000222, 0x00000223}, {0x00000224, 0x00000225}, {0x00000226, 0x00000227}, {0x00000228, 0x00000229}, +{0x0000022A, 0x0000022B}, {0x0000022C, 0x0000022D}, {0x0000022E, 0x0000022F}, {0x00000230, 0x00000231}, +{0x00000232, 0x00000233}, {0x0000023A, 0x00002C65}, {0x0000023B, 0x0000023C}, {0x0000023D, 0x0000019A}, +{0x0000023E, 0x00002C66}, {0x00000241, 0x00000242}, {0x00000243, 0x00000180}, {0x00000244, 0x00000289}, +{0x00000245, 0x0000028C}, {0x00000246, 0x00000247}, {0x00000248, 0x00000249}, {0x0000024A, 0x0000024B}, +{0x0000024C, 0x0000024D}, {0x0000024E, 0x0000024F}, {0x00000370, 0x00000371}, {0x00000372, 0x00000373}, +{0x00000376, 0x00000377}, {0x0000037F, 0x000003F3}, {0x00000386, 0x000003AC}, {0x00000388, 0x000003AD}, +{0x00000389, 0x000003AE}, {0x0000038A, 0x000003AF}, {0x0000038C, 0x000003CC}, {0x0000038E, 0x000003CD}, +{0x0000038F, 0x000003CE}, {0x00000391, 0x000003B1}, {0x00000392, 0x000003B2}, {0x00000393, 0x000003B3}, +{0x00000394, 0x000003B4}, {0x00000395, 0x000003B5}, {0x00000396, 0x000003B6}, {0x00000397, 0x000003B7}, +{0x00000398, 0x000003B8}, {0x00000399, 0x000003B9}, {0x0000039A, 0x000003BA}, {0x0000039B, 0x000003BB}, +{0x0000039C, 0x000003BC}, {0x0000039D, 0x000003BD}, {0x0000039E, 0x000003BE}, {0x0000039F, 0x000003BF}, +{0x000003A0, 0x000003C0}, {0x000003A1, 0x000003C1}, {0x000003A3, 0x000003C3}, {0x000003A4, 0x000003C4}, +{0x000003A5, 0x000003C5}, {0x000003A6, 0x000003C6}, {0x000003A7, 0x000003C7}, {0x000003A8, 0x000003C8}, +{0x000003A9, 0x000003C9}, {0x000003AA, 0x000003CA}, {0x000003AB, 0x000003CB}, {0x000003CF, 0x000003D7}, +{0x000003D8, 0x000003D9}, {0x000003DA, 0x000003DB}, {0x000003DC, 0x000003DD}, {0x000003DE, 0x000003DF}, +{0x000003E0, 0x000003E1}, {0x000003E2, 0x000003E3}, {0x000003E4, 0x000003E5}, {0x000003E6, 0x000003E7}, +{0x000003E8, 0x000003E9}, {0x000003EA, 0x000003EB}, {0x000003EC, 0x000003ED}, {0x000003EE, 0x000003EF}, +{0x000003F4, 0x000003B8}, {0x000003F7, 0x000003F8}, {0x000003F9, 0x000003F2}, {0x000003FA, 0x000003FB}, +{0x000003FD, 0x0000037B}, {0x000003FE, 0x0000037C}, {0x000003FF, 0x0000037D}, {0x00000400, 0x00000450}, +{0x00000401, 0x00000451}, {0x00000402, 0x00000452}, {0x00000403, 0x00000453}, {0x00000404, 0x00000454}, +{0x00000405, 0x00000455}, {0x00000406, 0x00000456}, {0x00000407, 0x00000457}, {0x00000408, 0x00000458}, +{0x00000409, 0x00000459}, {0x0000040A, 0x0000045A}, {0x0000040B, 0x0000045B}, {0x0000040C, 0x0000045C}, +{0x0000040D, 0x0000045D}, {0x0000040E, 0x0000045E}, {0x0000040F, 0x0000045F}, {0x00000410, 0x00000430}, +{0x00000411, 0x00000431}, {0x00000412, 0x00000432}, {0x00000413, 0x00000433}, {0x00000414, 0x00000434}, +{0x00000415, 0x00000435}, {0x00000416, 0x00000436}, {0x00000417, 0x00000437}, {0x00000418, 0x00000438}, +{0x00000419, 0x00000439}, {0x0000041A, 0x0000043A}, {0x0000041B, 0x0000043B}, {0x0000041C, 0x0000043C}, +{0x0000041D, 0x0000043D}, {0x0000041E, 0x0000043E}, {0x0000041F, 0x0000043F}, {0x00000420, 0x00000440}, +{0x00000421, 0x00000441}, {0x00000422, 0x00000442}, {0x00000423, 0x00000443}, {0x00000424, 0x00000444}, +{0x00000425, 0x00000445}, {0x00000426, 0x00000446}, {0x00000427, 0x00000447}, {0x00000428, 0x00000448}, +{0x00000429, 0x00000449}, {0x0000042A, 0x0000044A}, {0x0000042B, 0x0000044B}, {0x0000042C, 0x0000044C}, +{0x0000042D, 0x0000044D}, {0x0000042E, 0x0000044E}, {0x0000042F, 0x0000044F}, {0x00000460, 0x00000461}, +{0x00000462, 0x00000463}, {0x00000464, 0x00000465}, {0x00000466, 0x00000467}, {0x00000468, 0x00000469}, +{0x0000046A, 0x0000046B}, {0x0000046C, 0x0000046D}, {0x0000046E, 0x0000046F}, {0x00000470, 0x00000471}, +{0x00000472, 0x00000473}, {0x00000474, 0x00000475}, {0x00000476, 0x00000477}, {0x00000478, 0x00000479}, +{0x0000047A, 0x0000047B}, {0x0000047C, 0x0000047D}, {0x0000047E, 0x0000047F}, {0x00000480, 0x00000481}, +{0x0000048A, 0x0000048B}, {0x0000048C, 0x0000048D}, {0x0000048E, 0x0000048F}, {0x00000490, 0x00000491}, +{0x00000492, 0x00000493}, {0x00000494, 0x00000495}, {0x00000496, 0x00000497}, {0x00000498, 0x00000499}, +{0x0000049A, 0x0000049B}, {0x0000049C, 0x0000049D}, {0x0000049E, 0x0000049F}, {0x000004A0, 0x000004A1}, +{0x000004A2, 0x000004A3}, {0x000004A4, 0x000004A5}, {0x000004A6, 0x000004A7}, {0x000004A8, 0x000004A9}, +{0x000004AA, 0x000004AB}, {0x000004AC, 0x000004AD}, {0x000004AE, 0x000004AF}, {0x000004B0, 0x000004B1}, +{0x000004B2, 0x000004B3}, {0x000004B4, 0x000004B5}, {0x000004B6, 0x000004B7}, {0x000004B8, 0x000004B9}, +{0x000004BA, 0x000004BB}, {0x000004BC, 0x000004BD}, {0x000004BE, 0x000004BF}, {0x000004C0, 0x000004CF}, +{0x000004C1, 0x000004C2}, {0x000004C3, 0x000004C4}, {0x000004C5, 0x000004C6}, {0x000004C7, 0x000004C8}, +{0x000004C9, 0x000004CA}, {0x000004CB, 0x000004CC}, {0x000004CD, 0x000004CE}, {0x000004D0, 0x000004D1}, +{0x000004D2, 0x000004D3}, {0x000004D4, 0x000004D5}, {0x000004D6, 0x000004D7}, {0x000004D8, 0x000004D9}, +{0x000004DA, 0x000004DB}, {0x000004DC, 0x000004DD}, {0x000004DE, 0x000004DF}, {0x000004E0, 0x000004E1}, +{0x000004E2, 0x000004E3}, {0x000004E4, 0x000004E5}, {0x000004E6, 0x000004E7}, {0x000004E8, 0x000004E9}, +{0x000004EA, 0x000004EB}, {0x000004EC, 0x000004ED}, {0x000004EE, 0x000004EF}, {0x000004F0, 0x000004F1}, +{0x000004F2, 0x000004F3}, {0x000004F4, 0x000004F5}, {0x000004F6, 0x000004F7}, {0x000004F8, 0x000004F9}, +{0x000004FA, 0x000004FB}, {0x000004FC, 0x000004FD}, {0x000004FE, 0x000004FF}, {0x00000500, 0x00000501}, +{0x00000502, 0x00000503}, {0x00000504, 0x00000505}, {0x00000506, 0x00000507}, {0x00000508, 0x00000509}, +{0x0000050A, 0x0000050B}, {0x0000050C, 0x0000050D}, {0x0000050E, 0x0000050F}, {0x00000510, 0x00000511}, +{0x00000512, 0x00000513}, {0x00000514, 0x00000515}, {0x00000516, 0x00000517}, {0x00000518, 0x00000519}, +{0x0000051A, 0x0000051B}, {0x0000051C, 0x0000051D}, {0x0000051E, 0x0000051F}, {0x00000520, 0x00000521}, +{0x00000522, 0x00000523}, {0x00000524, 0x00000525}, {0x00000526, 0x00000527}, {0x00000528, 0x00000529}, +{0x0000052A, 0x0000052B}, {0x0000052C, 0x0000052D}, {0x0000052E, 0x0000052F}, {0x00000531, 0x00000561}, +{0x00000532, 0x00000562}, {0x00000533, 0x00000563}, {0x00000534, 0x00000564}, {0x00000535, 0x00000565}, +{0x00000536, 0x00000566}, {0x00000537, 0x00000567}, {0x00000538, 0x00000568}, {0x00000539, 0x00000569}, +{0x0000053A, 0x0000056A}, {0x0000053B, 0x0000056B}, {0x0000053C, 0x0000056C}, {0x0000053D, 0x0000056D}, +{0x0000053E, 0x0000056E}, {0x0000053F, 0x0000056F}, {0x00000540, 0x00000570}, {0x00000541, 0x00000571}, +{0x00000542, 0x00000572}, {0x00000543, 0x00000573}, {0x00000544, 0x00000574}, {0x00000545, 0x00000575}, +{0x00000546, 0x00000576}, {0x00000547, 0x00000577}, {0x00000548, 0x00000578}, {0x00000549, 0x00000579}, +{0x0000054A, 0x0000057A}, {0x0000054B, 0x0000057B}, {0x0000054C, 0x0000057C}, {0x0000054D, 0x0000057D}, +{0x0000054E, 0x0000057E}, {0x0000054F, 0x0000057F}, {0x00000550, 0x00000580}, {0x00000551, 0x00000581}, +{0x00000552, 0x00000582}, {0x00000553, 0x00000583}, {0x00000554, 0x00000584}, {0x00000555, 0x00000585}, +{0x00000556, 0x00000586}, {0x000010A0, 0x00002D00}, {0x000010A1, 0x00002D01}, {0x000010A2, 0x00002D02}, +{0x000010A3, 0x00002D03}, {0x000010A4, 0x00002D04}, {0x000010A5, 0x00002D05}, {0x000010A6, 0x00002D06}, +{0x000010A7, 0x00002D07}, {0x000010A8, 0x00002D08}, {0x000010A9, 0x00002D09}, {0x000010AA, 0x00002D0A}, +{0x000010AB, 0x00002D0B}, {0x000010AC, 0x00002D0C}, {0x000010AD, 0x00002D0D}, {0x000010AE, 0x00002D0E}, +{0x000010AF, 0x00002D0F}, {0x000010B0, 0x00002D10}, {0x000010B1, 0x00002D11}, {0x000010B2, 0x00002D12}, +{0x000010B3, 0x00002D13}, {0x000010B4, 0x00002D14}, {0x000010B5, 0x00002D15}, {0x000010B6, 0x00002D16}, +{0x000010B7, 0x00002D17}, {0x000010B8, 0x00002D18}, {0x000010B9, 0x00002D19}, {0x000010BA, 0x00002D1A}, +{0x000010BB, 0x00002D1B}, {0x000010BC, 0x00002D1C}, {0x000010BD, 0x00002D1D}, {0x000010BE, 0x00002D1E}, +{0x000010BF, 0x00002D1F}, {0x000010C0, 0x00002D20}, {0x000010C1, 0x00002D21}, {0x000010C2, 0x00002D22}, +{0x000010C3, 0x00002D23}, {0x000010C4, 0x00002D24}, {0x000010C5, 0x00002D25}, {0x000010C7, 0x00002D27}, +{0x000010CD, 0x00002D2D}, {0x000013A0, 0x0000AB70}, {0x000013A1, 0x0000AB71}, {0x000013A2, 0x0000AB72}, +{0x000013A3, 0x0000AB73}, {0x000013A4, 0x0000AB74}, {0x000013A5, 0x0000AB75}, {0x000013A6, 0x0000AB76}, +{0x000013A7, 0x0000AB77}, {0x000013A8, 0x0000AB78}, {0x000013A9, 0x0000AB79}, {0x000013AA, 0x0000AB7A}, +{0x000013AB, 0x0000AB7B}, {0x000013AC, 0x0000AB7C}, {0x000013AD, 0x0000AB7D}, {0x000013AE, 0x0000AB7E}, +{0x000013AF, 0x0000AB7F}, {0x000013B0, 0x0000AB80}, {0x000013B1, 0x0000AB81}, {0x000013B2, 0x0000AB82}, +{0x000013B3, 0x0000AB83}, {0x000013B4, 0x0000AB84}, {0x000013B5, 0x0000AB85}, {0x000013B6, 0x0000AB86}, +{0x000013B7, 0x0000AB87}, {0x000013B8, 0x0000AB88}, {0x000013B9, 0x0000AB89}, {0x000013BA, 0x0000AB8A}, +{0x000013BB, 0x0000AB8B}, {0x000013BC, 0x0000AB8C}, {0x000013BD, 0x0000AB8D}, {0x000013BE, 0x0000AB8E}, +{0x000013BF, 0x0000AB8F}, {0x000013C0, 0x0000AB90}, {0x000013C1, 0x0000AB91}, {0x000013C2, 0x0000AB92}, +{0x000013C3, 0x0000AB93}, {0x000013C4, 0x0000AB94}, {0x000013C5, 0x0000AB95}, {0x000013C6, 0x0000AB96}, +{0x000013C7, 0x0000AB97}, {0x000013C8, 0x0000AB98}, {0x000013C9, 0x0000AB99}, {0x000013CA, 0x0000AB9A}, +{0x000013CB, 0x0000AB9B}, {0x000013CC, 0x0000AB9C}, {0x000013CD, 0x0000AB9D}, {0x000013CE, 0x0000AB9E}, +{0x000013CF, 0x0000AB9F}, {0x000013D0, 0x0000ABA0}, {0x000013D1, 0x0000ABA1}, {0x000013D2, 0x0000ABA2}, +{0x000013D3, 0x0000ABA3}, {0x000013D4, 0x0000ABA4}, {0x000013D5, 0x0000ABA5}, {0x000013D6, 0x0000ABA6}, +{0x000013D7, 0x0000ABA7}, {0x000013D8, 0x0000ABA8}, {0x000013D9, 0x0000ABA9}, {0x000013DA, 0x0000ABAA}, +{0x000013DB, 0x0000ABAB}, {0x000013DC, 0x0000ABAC}, {0x000013DD, 0x0000ABAD}, {0x000013DE, 0x0000ABAE}, +{0x000013DF, 0x0000ABAF}, {0x000013E0, 0x0000ABB0}, {0x000013E1, 0x0000ABB1}, {0x000013E2, 0x0000ABB2}, +{0x000013E3, 0x0000ABB3}, {0x000013E4, 0x0000ABB4}, {0x000013E5, 0x0000ABB5}, {0x000013E6, 0x0000ABB6}, +{0x000013E7, 0x0000ABB7}, {0x000013E8, 0x0000ABB8}, {0x000013E9, 0x0000ABB9}, {0x000013EA, 0x0000ABBA}, +{0x000013EB, 0x0000ABBB}, {0x000013EC, 0x0000ABBC}, {0x000013ED, 0x0000ABBD}, {0x000013EE, 0x0000ABBE}, +{0x000013EF, 0x0000ABBF}, {0x000013F0, 0x000013F8}, {0x000013F1, 0x000013F9}, {0x000013F2, 0x000013FA}, +{0x000013F3, 0x000013FB}, {0x000013F4, 0x000013FC}, {0x000013F5, 0x000013FD}, {0x00001C90, 0x000010D0}, +{0x00001C91, 0x000010D1}, {0x00001C92, 0x000010D2}, {0x00001C93, 0x000010D3}, {0x00001C94, 0x000010D4}, +{0x00001C95, 0x000010D5}, {0x00001C96, 0x000010D6}, {0x00001C97, 0x000010D7}, {0x00001C98, 0x000010D8}, +{0x00001C99, 0x000010D9}, {0x00001C9A, 0x000010DA}, {0x00001C9B, 0x000010DB}, {0x00001C9C, 0x000010DC}, +{0x00001C9D, 0x000010DD}, {0x00001C9E, 0x000010DE}, {0x00001C9F, 0x000010DF}, {0x00001CA0, 0x000010E0}, +{0x00001CA1, 0x000010E1}, {0x00001CA2, 0x000010E2}, {0x00001CA3, 0x000010E3}, {0x00001CA4, 0x000010E4}, +{0x00001CA5, 0x000010E5}, {0x00001CA6, 0x000010E6}, {0x00001CA7, 0x000010E7}, {0x00001CA8, 0x000010E8}, +{0x00001CA9, 0x000010E9}, {0x00001CAA, 0x000010EA}, {0x00001CAB, 0x000010EB}, {0x00001CAC, 0x000010EC}, +{0x00001CAD, 0x000010ED}, {0x00001CAE, 0x000010EE}, {0x00001CAF, 0x000010EF}, {0x00001CB0, 0x000010F0}, +{0x00001CB1, 0x000010F1}, {0x00001CB2, 0x000010F2}, {0x00001CB3, 0x000010F3}, {0x00001CB4, 0x000010F4}, +{0x00001CB5, 0x000010F5}, {0x00001CB6, 0x000010F6}, {0x00001CB7, 0x000010F7}, {0x00001CB8, 0x000010F8}, +{0x00001CB9, 0x000010F9}, {0x00001CBA, 0x000010FA}, {0x00001CBD, 0x000010FD}, {0x00001CBE, 0x000010FE}, +{0x00001CBF, 0x000010FF}, {0x00001E00, 0x00001E01}, {0x00001E02, 0x00001E03}, {0x00001E04, 0x00001E05}, +{0x00001E06, 0x00001E07}, {0x00001E08, 0x00001E09}, {0x00001E0A, 0x00001E0B}, {0x00001E0C, 0x00001E0D}, +{0x00001E0E, 0x00001E0F}, {0x00001E10, 0x00001E11}, {0x00001E12, 0x00001E13}, {0x00001E14, 0x00001E15}, +{0x00001E16, 0x00001E17}, {0x00001E18, 0x00001E19}, {0x00001E1A, 0x00001E1B}, {0x00001E1C, 0x00001E1D}, +{0x00001E1E, 0x00001E1F}, {0x00001E20, 0x00001E21}, {0x00001E22, 0x00001E23}, {0x00001E24, 0x00001E25}, +{0x00001E26, 0x00001E27}, {0x00001E28, 0x00001E29}, {0x00001E2A, 0x00001E2B}, {0x00001E2C, 0x00001E2D}, +{0x00001E2E, 0x00001E2F}, {0x00001E30, 0x00001E31}, {0x00001E32, 0x00001E33}, {0x00001E34, 0x00001E35}, +{0x00001E36, 0x00001E37}, {0x00001E38, 0x00001E39}, {0x00001E3A, 0x00001E3B}, {0x00001E3C, 0x00001E3D}, +{0x00001E3E, 0x00001E3F}, {0x00001E40, 0x00001E41}, {0x00001E42, 0x00001E43}, {0x00001E44, 0x00001E45}, +{0x00001E46, 0x00001E47}, {0x00001E48, 0x00001E49}, {0x00001E4A, 0x00001E4B}, {0x00001E4C, 0x00001E4D}, +{0x00001E4E, 0x00001E4F}, {0x00001E50, 0x00001E51}, {0x00001E52, 0x00001E53}, {0x00001E54, 0x00001E55}, +{0x00001E56, 0x00001E57}, {0x00001E58, 0x00001E59}, {0x00001E5A, 0x00001E5B}, {0x00001E5C, 0x00001E5D}, +{0x00001E5E, 0x00001E5F}, {0x00001E60, 0x00001E61}, {0x00001E62, 0x00001E63}, {0x00001E64, 0x00001E65}, +{0x00001E66, 0x00001E67}, {0x00001E68, 0x00001E69}, {0x00001E6A, 0x00001E6B}, {0x00001E6C, 0x00001E6D}, +{0x00001E6E, 0x00001E6F}, {0x00001E70, 0x00001E71}, {0x00001E72, 0x00001E73}, {0x00001E74, 0x00001E75}, +{0x00001E76, 0x00001E77}, {0x00001E78, 0x00001E79}, {0x00001E7A, 0x00001E7B}, {0x00001E7C, 0x00001E7D}, +{0x00001E7E, 0x00001E7F}, {0x00001E80, 0x00001E81}, {0x00001E82, 0x00001E83}, {0x00001E84, 0x00001E85}, +{0x00001E86, 0x00001E87}, {0x00001E88, 0x00001E89}, {0x00001E8A, 0x00001E8B}, {0x00001E8C, 0x00001E8D}, +{0x00001E8E, 0x00001E8F}, {0x00001E90, 0x00001E91}, {0x00001E92, 0x00001E93}, {0x00001E94, 0x00001E95}, +{0x00001E9E, 0x000000DF}, {0x00001EA0, 0x00001EA1}, {0x00001EA2, 0x00001EA3}, {0x00001EA4, 0x00001EA5}, +{0x00001EA6, 0x00001EA7}, {0x00001EA8, 0x00001EA9}, {0x00001EAA, 0x00001EAB}, {0x00001EAC, 0x00001EAD}, +{0x00001EAE, 0x00001EAF}, {0x00001EB0, 0x00001EB1}, {0x00001EB2, 0x00001EB3}, {0x00001EB4, 0x00001EB5}, +{0x00001EB6, 0x00001EB7}, {0x00001EB8, 0x00001EB9}, {0x00001EBA, 0x00001EBB}, {0x00001EBC, 0x00001EBD}, +{0x00001EBE, 0x00001EBF}, {0x00001EC0, 0x00001EC1}, {0x00001EC2, 0x00001EC3}, {0x00001EC4, 0x00001EC5}, +{0x00001EC6, 0x00001EC7}, {0x00001EC8, 0x00001EC9}, {0x00001ECA, 0x00001ECB}, {0x00001ECC, 0x00001ECD}, +{0x00001ECE, 0x00001ECF}, {0x00001ED0, 0x00001ED1}, {0x00001ED2, 0x00001ED3}, {0x00001ED4, 0x00001ED5}, +{0x00001ED6, 0x00001ED7}, {0x00001ED8, 0x00001ED9}, {0x00001EDA, 0x00001EDB}, {0x00001EDC, 0x00001EDD}, +{0x00001EDE, 0x00001EDF}, {0x00001EE0, 0x00001EE1}, {0x00001EE2, 0x00001EE3}, {0x00001EE4, 0x00001EE5}, +{0x00001EE6, 0x00001EE7}, {0x00001EE8, 0x00001EE9}, {0x00001EEA, 0x00001EEB}, {0x00001EEC, 0x00001EED}, +{0x00001EEE, 0x00001EEF}, {0x00001EF0, 0x00001EF1}, {0x00001EF2, 0x00001EF3}, {0x00001EF4, 0x00001EF5}, +{0x00001EF6, 0x00001EF7}, {0x00001EF8, 0x00001EF9}, {0x00001EFA, 0x00001EFB}, {0x00001EFC, 0x00001EFD}, +{0x00001EFE, 0x00001EFF}, {0x00001F08, 0x00001F00}, {0x00001F09, 0x00001F01}, {0x00001F0A, 0x00001F02}, +{0x00001F0B, 0x00001F03}, {0x00001F0C, 0x00001F04}, {0x00001F0D, 0x00001F05}, {0x00001F0E, 0x00001F06}, +{0x00001F0F, 0x00001F07}, {0x00001F18, 0x00001F10}, {0x00001F19, 0x00001F11}, {0x00001F1A, 0x00001F12}, +{0x00001F1B, 0x00001F13}, {0x00001F1C, 0x00001F14}, {0x00001F1D, 0x00001F15}, {0x00001F28, 0x00001F20}, +{0x00001F29, 0x00001F21}, {0x00001F2A, 0x00001F22}, {0x00001F2B, 0x00001F23}, {0x00001F2C, 0x00001F24}, +{0x00001F2D, 0x00001F25}, {0x00001F2E, 0x00001F26}, {0x00001F2F, 0x00001F27}, {0x00001F38, 0x00001F30}, +{0x00001F39, 0x00001F31}, {0x00001F3A, 0x00001F32}, {0x00001F3B, 0x00001F33}, {0x00001F3C, 0x00001F34}, +{0x00001F3D, 0x00001F35}, {0x00001F3E, 0x00001F36}, {0x00001F3F, 0x00001F37}, {0x00001F48, 0x00001F40}, +{0x00001F49, 0x00001F41}, {0x00001F4A, 0x00001F42}, {0x00001F4B, 0x00001F43}, {0x00001F4C, 0x00001F44}, +{0x00001F4D, 0x00001F45}, {0x00001F59, 0x00001F51}, {0x00001F5B, 0x00001F53}, {0x00001F5D, 0x00001F55}, +{0x00001F5F, 0x00001F57}, {0x00001F68, 0x00001F60}, {0x00001F69, 0x00001F61}, {0x00001F6A, 0x00001F62}, +{0x00001F6B, 0x00001F63}, {0x00001F6C, 0x00001F64}, {0x00001F6D, 0x00001F65}, {0x00001F6E, 0x00001F66}, +{0x00001F6F, 0x00001F67}, {0x00001F88, 0x00001F80}, {0x00001F89, 0x00001F81}, {0x00001F8A, 0x00001F82}, +{0x00001F8B, 0x00001F83}, {0x00001F8C, 0x00001F84}, {0x00001F8D, 0x00001F85}, {0x00001F8E, 0x00001F86}, +{0x00001F8F, 0x00001F87}, {0x00001F98, 0x00001F90}, {0x00001F99, 0x00001F91}, {0x00001F9A, 0x00001F92}, +{0x00001F9B, 0x00001F93}, {0x00001F9C, 0x00001F94}, {0x00001F9D, 0x00001F95}, {0x00001F9E, 0x00001F96}, +{0x00001F9F, 0x00001F97}, {0x00001FA8, 0x00001FA0}, {0x00001FA9, 0x00001FA1}, {0x00001FAA, 0x00001FA2}, +{0x00001FAB, 0x00001FA3}, {0x00001FAC, 0x00001FA4}, {0x00001FAD, 0x00001FA5}, {0x00001FAE, 0x00001FA6}, +{0x00001FAF, 0x00001FA7}, {0x00001FB8, 0x00001FB0}, {0x00001FB9, 0x00001FB1}, {0x00001FBA, 0x00001F70}, +{0x00001FBB, 0x00001F71}, {0x00001FBC, 0x00001FB3}, {0x00001FC8, 0x00001F72}, {0x00001FC9, 0x00001F73}, +{0x00001FCA, 0x00001F74}, {0x00001FCB, 0x00001F75}, {0x00001FCC, 0x00001FC3}, {0x00001FD8, 0x00001FD0}, +{0x00001FD9, 0x00001FD1}, {0x00001FDA, 0x00001F76}, {0x00001FDB, 0x00001F77}, {0x00001FE8, 0x00001FE0}, +{0x00001FE9, 0x00001FE1}, {0x00001FEA, 0x00001F7A}, {0x00001FEB, 0x00001F7B}, {0x00001FEC, 0x00001FE5}, +{0x00001FF8, 0x00001F78}, {0x00001FF9, 0x00001F79}, {0x00001FFA, 0x00001F7C}, {0x00001FFB, 0x00001F7D}, +{0x00001FFC, 0x00001FF3}, {0x00002126, 0x000003C9}, {0x0000212A, 0x0000006B}, {0x0000212B, 0x000000E5}, +{0x00002132, 0x0000214E}, {0x00002160, 0x00002170}, {0x00002161, 0x00002171}, {0x00002162, 0x00002172}, +{0x00002163, 0x00002173}, {0x00002164, 0x00002174}, {0x00002165, 0x00002175}, {0x00002166, 0x00002176}, +{0x00002167, 0x00002177}, {0x00002168, 0x00002178}, {0x00002169, 0x00002179}, {0x0000216A, 0x0000217A}, +{0x0000216B, 0x0000217B}, {0x0000216C, 0x0000217C}, {0x0000216D, 0x0000217D}, {0x0000216E, 0x0000217E}, +{0x0000216F, 0x0000217F}, {0x00002183, 0x00002184}, {0x000024B6, 0x000024D0}, {0x000024B7, 0x000024D1}, +{0x000024B8, 0x000024D2}, {0x000024B9, 0x000024D3}, {0x000024BA, 0x000024D4}, {0x000024BB, 0x000024D5}, +{0x000024BC, 0x000024D6}, {0x000024BD, 0x000024D7}, {0x000024BE, 0x000024D8}, {0x000024BF, 0x000024D9}, +{0x000024C0, 0x000024DA}, {0x000024C1, 0x000024DB}, {0x000024C2, 0x000024DC}, {0x000024C3, 0x000024DD}, +{0x000024C4, 0x000024DE}, {0x000024C5, 0x000024DF}, {0x000024C6, 0x000024E0}, {0x000024C7, 0x000024E1}, +{0x000024C8, 0x000024E2}, {0x000024C9, 0x000024E3}, {0x000024CA, 0x000024E4}, {0x000024CB, 0x000024E5}, +{0x000024CC, 0x000024E6}, {0x000024CD, 0x000024E7}, {0x000024CE, 0x000024E8}, {0x000024CF, 0x000024E9}, +{0x00002C00, 0x00002C30}, {0x00002C01, 0x00002C31}, {0x00002C02, 0x00002C32}, {0x00002C03, 0x00002C33}, +{0x00002C04, 0x00002C34}, {0x00002C05, 0x00002C35}, {0x00002C06, 0x00002C36}, {0x00002C07, 0x00002C37}, +{0x00002C08, 0x00002C38}, {0x00002C09, 0x00002C39}, {0x00002C0A, 0x00002C3A}, {0x00002C0B, 0x00002C3B}, +{0x00002C0C, 0x00002C3C}, {0x00002C0D, 0x00002C3D}, {0x00002C0E, 0x00002C3E}, {0x00002C0F, 0x00002C3F}, +{0x00002C10, 0x00002C40}, {0x00002C11, 0x00002C41}, {0x00002C12, 0x00002C42}, {0x00002C13, 0x00002C43}, +{0x00002C14, 0x00002C44}, {0x00002C15, 0x00002C45}, {0x00002C16, 0x00002C46}, {0x00002C17, 0x00002C47}, +{0x00002C18, 0x00002C48}, {0x00002C19, 0x00002C49}, {0x00002C1A, 0x00002C4A}, {0x00002C1B, 0x00002C4B}, +{0x00002C1C, 0x00002C4C}, {0x00002C1D, 0x00002C4D}, {0x00002C1E, 0x00002C4E}, {0x00002C1F, 0x00002C4F}, +{0x00002C20, 0x00002C50}, {0x00002C21, 0x00002C51}, {0x00002C22, 0x00002C52}, {0x00002C23, 0x00002C53}, +{0x00002C24, 0x00002C54}, {0x00002C25, 0x00002C55}, {0x00002C26, 0x00002C56}, {0x00002C27, 0x00002C57}, +{0x00002C28, 0x00002C58}, {0x00002C29, 0x00002C59}, {0x00002C2A, 0x00002C5A}, {0x00002C2B, 0x00002C5B}, +{0x00002C2C, 0x00002C5C}, {0x00002C2D, 0x00002C5D}, {0x00002C2E, 0x00002C5E}, {0x00002C60, 0x00002C61}, +{0x00002C62, 0x0000026B}, {0x00002C63, 0x00001D7D}, {0x00002C64, 0x0000027D}, {0x00002C67, 0x00002C68}, +{0x00002C69, 0x00002C6A}, {0x00002C6B, 0x00002C6C}, {0x00002C6D, 0x00000251}, {0x00002C6E, 0x00000271}, +{0x00002C6F, 0x00000250}, {0x00002C70, 0x00000252}, {0x00002C72, 0x00002C73}, {0x00002C75, 0x00002C76}, +{0x00002C7E, 0x0000023F}, {0x00002C7F, 0x00000240}, {0x00002C80, 0x00002C81}, {0x00002C82, 0x00002C83}, +{0x00002C84, 0x00002C85}, {0x00002C86, 0x00002C87}, {0x00002C88, 0x00002C89}, {0x00002C8A, 0x00002C8B}, +{0x00002C8C, 0x00002C8D}, {0x00002C8E, 0x00002C8F}, {0x00002C90, 0x00002C91}, {0x00002C92, 0x00002C93}, +{0x00002C94, 0x00002C95}, {0x00002C96, 0x00002C97}, {0x00002C98, 0x00002C99}, {0x00002C9A, 0x00002C9B}, +{0x00002C9C, 0x00002C9D}, {0x00002C9E, 0x00002C9F}, {0x00002CA0, 0x00002CA1}, {0x00002CA2, 0x00002CA3}, +{0x00002CA4, 0x00002CA5}, {0x00002CA6, 0x00002CA7}, {0x00002CA8, 0x00002CA9}, {0x00002CAA, 0x00002CAB}, +{0x00002CAC, 0x00002CAD}, {0x00002CAE, 0x00002CAF}, {0x00002CB0, 0x00002CB1}, {0x00002CB2, 0x00002CB3}, +{0x00002CB4, 0x00002CB5}, {0x00002CB6, 0x00002CB7}, {0x00002CB8, 0x00002CB9}, {0x00002CBA, 0x00002CBB}, +{0x00002CBC, 0x00002CBD}, {0x00002CBE, 0x00002CBF}, {0x00002CC0, 0x00002CC1}, {0x00002CC2, 0x00002CC3}, +{0x00002CC4, 0x00002CC5}, {0x00002CC6, 0x00002CC7}, {0x00002CC8, 0x00002CC9}, {0x00002CCA, 0x00002CCB}, +{0x00002CCC, 0x00002CCD}, {0x00002CCE, 0x00002CCF}, {0x00002CD0, 0x00002CD1}, {0x00002CD2, 0x00002CD3}, +{0x00002CD4, 0x00002CD5}, {0x00002CD6, 0x00002CD7}, {0x00002CD8, 0x00002CD9}, {0x00002CDA, 0x00002CDB}, +{0x00002CDC, 0x00002CDD}, {0x00002CDE, 0x00002CDF}, {0x00002CE0, 0x00002CE1}, {0x00002CE2, 0x00002CE3}, +{0x00002CEB, 0x00002CEC}, {0x00002CED, 0x00002CEE}, {0x00002CF2, 0x00002CF3}, {0x0000A640, 0x0000A641}, +{0x0000A642, 0x0000A643}, {0x0000A644, 0x0000A645}, {0x0000A646, 0x0000A647}, {0x0000A648, 0x0000A649}, +{0x0000A64A, 0x0000A64B}, {0x0000A64C, 0x0000A64D}, {0x0000A64E, 0x0000A64F}, {0x0000A650, 0x0000A651}, +{0x0000A652, 0x0000A653}, {0x0000A654, 0x0000A655}, {0x0000A656, 0x0000A657}, {0x0000A658, 0x0000A659}, +{0x0000A65A, 0x0000A65B}, {0x0000A65C, 0x0000A65D}, {0x0000A65E, 0x0000A65F}, {0x0000A660, 0x0000A661}, +{0x0000A662, 0x0000A663}, {0x0000A664, 0x0000A665}, {0x0000A666, 0x0000A667}, {0x0000A668, 0x0000A669}, +{0x0000A66A, 0x0000A66B}, {0x0000A66C, 0x0000A66D}, {0x0000A680, 0x0000A681}, {0x0000A682, 0x0000A683}, +{0x0000A684, 0x0000A685}, {0x0000A686, 0x0000A687}, {0x0000A688, 0x0000A689}, {0x0000A68A, 0x0000A68B}, +{0x0000A68C, 0x0000A68D}, {0x0000A68E, 0x0000A68F}, {0x0000A690, 0x0000A691}, {0x0000A692, 0x0000A693}, +{0x0000A694, 0x0000A695}, {0x0000A696, 0x0000A697}, {0x0000A698, 0x0000A699}, {0x0000A69A, 0x0000A69B}, +{0x0000A722, 0x0000A723}, {0x0000A724, 0x0000A725}, {0x0000A726, 0x0000A727}, {0x0000A728, 0x0000A729}, +{0x0000A72A, 0x0000A72B}, {0x0000A72C, 0x0000A72D}, {0x0000A72E, 0x0000A72F}, {0x0000A732, 0x0000A733}, +{0x0000A734, 0x0000A735}, {0x0000A736, 0x0000A737}, {0x0000A738, 0x0000A739}, {0x0000A73A, 0x0000A73B}, +{0x0000A73C, 0x0000A73D}, {0x0000A73E, 0x0000A73F}, {0x0000A740, 0x0000A741}, {0x0000A742, 0x0000A743}, +{0x0000A744, 0x0000A745}, {0x0000A746, 0x0000A747}, {0x0000A748, 0x0000A749}, {0x0000A74A, 0x0000A74B}, +{0x0000A74C, 0x0000A74D}, {0x0000A74E, 0x0000A74F}, {0x0000A750, 0x0000A751}, {0x0000A752, 0x0000A753}, +{0x0000A754, 0x0000A755}, {0x0000A756, 0x0000A757}, {0x0000A758, 0x0000A759}, {0x0000A75A, 0x0000A75B}, +{0x0000A75C, 0x0000A75D}, {0x0000A75E, 0x0000A75F}, {0x0000A760, 0x0000A761}, {0x0000A762, 0x0000A763}, +{0x0000A764, 0x0000A765}, {0x0000A766, 0x0000A767}, {0x0000A768, 0x0000A769}, {0x0000A76A, 0x0000A76B}, +{0x0000A76C, 0x0000A76D}, {0x0000A76E, 0x0000A76F}, {0x0000A779, 0x0000A77A}, {0x0000A77B, 0x0000A77C}, +{0x0000A77D, 0x00001D79}, {0x0000A77E, 0x0000A77F}, {0x0000A780, 0x0000A781}, {0x0000A782, 0x0000A783}, +{0x0000A784, 0x0000A785}, {0x0000A786, 0x0000A787}, {0x0000A78B, 0x0000A78C}, {0x0000A78D, 0x00000265}, +{0x0000A790, 0x0000A791}, {0x0000A792, 0x0000A793}, {0x0000A796, 0x0000A797}, {0x0000A798, 0x0000A799}, +{0x0000A79A, 0x0000A79B}, {0x0000A79C, 0x0000A79D}, {0x0000A79E, 0x0000A79F}, {0x0000A7A0, 0x0000A7A1}, +{0x0000A7A2, 0x0000A7A3}, {0x0000A7A4, 0x0000A7A5}, {0x0000A7A6, 0x0000A7A7}, {0x0000A7A8, 0x0000A7A9}, +{0x0000A7AA, 0x00000266}, {0x0000A7AB, 0x0000025C}, {0x0000A7AC, 0x00000261}, {0x0000A7AD, 0x0000026C}, +{0x0000A7AE, 0x0000026A}, {0x0000A7B0, 0x0000029E}, {0x0000A7B1, 0x00000287}, {0x0000A7B2, 0x0000029D}, +{0x0000A7B3, 0x0000AB53}, {0x0000A7B4, 0x0000A7B5}, {0x0000A7B6, 0x0000A7B7}, {0x0000A7B8, 0x0000A7B9}, +{0x0000A7BA, 0x0000A7BB}, {0x0000A7BC, 0x0000A7BD}, {0x0000A7BE, 0x0000A7BF}, {0x0000A7C2, 0x0000A7C3}, +{0x0000A7C4, 0x0000A794}, {0x0000A7C5, 0x00000282}, {0x0000A7C6, 0x00001D8E}, {0x0000A7C7, 0x0000A7C8}, +{0x0000A7C9, 0x0000A7CA}, {0x0000A7F5, 0x0000A7F6}, {0x0000FF21, 0x0000FF41}, {0x0000FF22, 0x0000FF42}, +{0x0000FF23, 0x0000FF43}, {0x0000FF24, 0x0000FF44}, {0x0000FF25, 0x0000FF45}, {0x0000FF26, 0x0000FF46}, +{0x0000FF27, 0x0000FF47}, {0x0000FF28, 0x0000FF48}, {0x0000FF29, 0x0000FF49}, {0x0000FF2A, 0x0000FF4A}, +{0x0000FF2B, 0x0000FF4B}, {0x0000FF2C, 0x0000FF4C}, {0x0000FF2D, 0x0000FF4D}, {0x0000FF2E, 0x0000FF4E}, +{0x0000FF2F, 0x0000FF4F}, {0x0000FF30, 0x0000FF50}, {0x0000FF31, 0x0000FF51}, {0x0000FF32, 0x0000FF52}, +{0x0000FF33, 0x0000FF53}, {0x0000FF34, 0x0000FF54}, {0x0000FF35, 0x0000FF55}, {0x0000FF36, 0x0000FF56}, +{0x0000FF37, 0x0000FF57}, {0x0000FF38, 0x0000FF58}, {0x0000FF39, 0x0000FF59}, {0x0000FF3A, 0x0000FF5A}, +{0x00010400, 0x00010428}, {0x00010401, 0x00010429}, {0x00010402, 0x0001042A}, {0x00010403, 0x0001042B}, +{0x00010404, 0x0001042C}, {0x00010405, 0x0001042D}, {0x00010406, 0x0001042E}, {0x00010407, 0x0001042F}, +{0x00010408, 0x00010430}, {0x00010409, 0x00010431}, {0x0001040A, 0x00010432}, {0x0001040B, 0x00010433}, +{0x0001040C, 0x00010434}, {0x0001040D, 0x00010435}, {0x0001040E, 0x00010436}, {0x0001040F, 0x00010437}, +{0x00010410, 0x00010438}, {0x00010411, 0x00010439}, {0x00010412, 0x0001043A}, {0x00010413, 0x0001043B}, +{0x00010414, 0x0001043C}, {0x00010415, 0x0001043D}, {0x00010416, 0x0001043E}, {0x00010417, 0x0001043F}, +{0x00010418, 0x00010440}, {0x00010419, 0x00010441}, {0x0001041A, 0x00010442}, {0x0001041B, 0x00010443}, +{0x0001041C, 0x00010444}, {0x0001041D, 0x00010445}, {0x0001041E, 0x00010446}, {0x0001041F, 0x00010447}, +{0x00010420, 0x00010448}, {0x00010421, 0x00010449}, {0x00010422, 0x0001044A}, {0x00010423, 0x0001044B}, +{0x00010424, 0x0001044C}, {0x00010425, 0x0001044D}, {0x00010426, 0x0001044E}, {0x00010427, 0x0001044F}, +{0x000104B0, 0x000104D8}, {0x000104B1, 0x000104D9}, {0x000104B2, 0x000104DA}, {0x000104B3, 0x000104DB}, +{0x000104B4, 0x000104DC}, {0x000104B5, 0x000104DD}, {0x000104B6, 0x000104DE}, {0x000104B7, 0x000104DF}, +{0x000104B8, 0x000104E0}, {0x000104B9, 0x000104E1}, {0x000104BA, 0x000104E2}, {0x000104BB, 0x000104E3}, +{0x000104BC, 0x000104E4}, {0x000104BD, 0x000104E5}, {0x000104BE, 0x000104E6}, {0x000104BF, 0x000104E7}, +{0x000104C0, 0x000104E8}, {0x000104C1, 0x000104E9}, {0x000104C2, 0x000104EA}, {0x000104C3, 0x000104EB}, +{0x000104C4, 0x000104EC}, {0x000104C5, 0x000104ED}, {0x000104C6, 0x000104EE}, {0x000104C7, 0x000104EF}, +{0x000104C8, 0x000104F0}, {0x000104C9, 0x000104F1}, {0x000104CA, 0x000104F2}, {0x000104CB, 0x000104F3}, +{0x000104CC, 0x000104F4}, {0x000104CD, 0x000104F5}, {0x000104CE, 0x000104F6}, {0x000104CF, 0x000104F7}, +{0x000104D0, 0x000104F8}, {0x000104D1, 0x000104F9}, {0x000104D2, 0x000104FA}, {0x000104D3, 0x000104FB}, +{0x00010C80, 0x00010CC0}, {0x00010C81, 0x00010CC1}, {0x00010C82, 0x00010CC2}, {0x00010C83, 0x00010CC3}, +{0x00010C84, 0x00010CC4}, {0x00010C85, 0x00010CC5}, {0x00010C86, 0x00010CC6}, {0x00010C87, 0x00010CC7}, +{0x00010C88, 0x00010CC8}, {0x00010C89, 0x00010CC9}, {0x00010C8A, 0x00010CCA}, {0x00010C8B, 0x00010CCB}, +{0x00010C8C, 0x00010CCC}, {0x00010C8D, 0x00010CCD}, {0x00010C8E, 0x00010CCE}, {0x00010C8F, 0x00010CCF}, +{0x00010C90, 0x00010CD0}, {0x00010C91, 0x00010CD1}, {0x00010C92, 0x00010CD2}, {0x00010C93, 0x00010CD3}, +{0x00010C94, 0x00010CD4}, {0x00010C95, 0x00010CD5}, {0x00010C96, 0x00010CD6}, {0x00010C97, 0x00010CD7}, +{0x00010C98, 0x00010CD8}, {0x00010C99, 0x00010CD9}, {0x00010C9A, 0x00010CDA}, {0x00010C9B, 0x00010CDB}, +{0x00010C9C, 0x00010CDC}, {0x00010C9D, 0x00010CDD}, {0x00010C9E, 0x00010CDE}, {0x00010C9F, 0x00010CDF}, +{0x00010CA0, 0x00010CE0}, {0x00010CA1, 0x00010CE1}, {0x00010CA2, 0x00010CE2}, {0x00010CA3, 0x00010CE3}, +{0x00010CA4, 0x00010CE4}, {0x00010CA5, 0x00010CE5}, {0x00010CA6, 0x00010CE6}, {0x00010CA7, 0x00010CE7}, +{0x00010CA8, 0x00010CE8}, {0x00010CA9, 0x00010CE9}, {0x00010CAA, 0x00010CEA}, {0x00010CAB, 0x00010CEB}, +{0x00010CAC, 0x00010CEC}, {0x00010CAD, 0x00010CED}, {0x00010CAE, 0x00010CEE}, {0x00010CAF, 0x00010CEF}, +{0x00010CB0, 0x00010CF0}, {0x00010CB1, 0x00010CF1}, {0x00010CB2, 0x00010CF2}, {0x000118A0, 0x000118C0}, +{0x000118A1, 0x000118C1}, {0x000118A2, 0x000118C2}, {0x000118A3, 0x000118C3}, {0x000118A4, 0x000118C4}, +{0x000118A5, 0x000118C5}, {0x000118A6, 0x000118C6}, {0x000118A7, 0x000118C7}, {0x000118A8, 0x000118C8}, +{0x000118A9, 0x000118C9}, {0x000118AA, 0x000118CA}, {0x000118AB, 0x000118CB}, {0x000118AC, 0x000118CC}, +{0x000118AD, 0x000118CD}, {0x000118AE, 0x000118CE}, {0x000118AF, 0x000118CF}, {0x000118B0, 0x000118D0}, +{0x000118B1, 0x000118D1}, {0x000118B2, 0x000118D2}, {0x000118B3, 0x000118D3}, {0x000118B4, 0x000118D4}, +{0x000118B5, 0x000118D5}, {0x000118B6, 0x000118D6}, {0x000118B7, 0x000118D7}, {0x000118B8, 0x000118D8}, +{0x000118B9, 0x000118D9}, {0x000118BA, 0x000118DA}, {0x000118BB, 0x000118DB}, {0x000118BC, 0x000118DC}, +{0x000118BD, 0x000118DD}, {0x000118BE, 0x000118DE}, {0x000118BF, 0x000118DF}, {0x00016E40, 0x00016E60}, +{0x00016E41, 0x00016E61}, {0x00016E42, 0x00016E62}, {0x00016E43, 0x00016E63}, {0x00016E44, 0x00016E64}, +{0x00016E45, 0x00016E65}, {0x00016E46, 0x00016E66}, {0x00016E47, 0x00016E67}, {0x00016E48, 0x00016E68}, +{0x00016E49, 0x00016E69}, {0x00016E4A, 0x00016E6A}, {0x00016E4B, 0x00016E6B}, {0x00016E4C, 0x00016E6C}, +{0x00016E4D, 0x00016E6D}, {0x00016E4E, 0x00016E6E}, {0x00016E4F, 0x00016E6F}, {0x00016E50, 0x00016E70}, +{0x00016E51, 0x00016E71}, {0x00016E52, 0x00016E72}, {0x00016E53, 0x00016E73}, {0x00016E54, 0x00016E74}, +{0x00016E55, 0x00016E75}, {0x00016E56, 0x00016E76}, {0x00016E57, 0x00016E77}, {0x00016E58, 0x00016E78}, +{0x00016E59, 0x00016E79}, {0x00016E5A, 0x00016E7A}, {0x00016E5B, 0x00016E7B}, {0x00016E5C, 0x00016E7C}, +{0x00016E5D, 0x00016E7D}, {0x00016E5E, 0x00016E7E}, {0x00016E5F, 0x00016E7F}, {0x0001E900, 0x0001E922}, +{0x0001E901, 0x0001E923}, {0x0001E902, 0x0001E924}, {0x0001E903, 0x0001E925}, {0x0001E904, 0x0001E926}, +{0x0001E905, 0x0001E927}, {0x0001E906, 0x0001E928}, {0x0001E907, 0x0001E929}, {0x0001E908, 0x0001E92A}, +{0x0001E909, 0x0001E92B}, {0x0001E90A, 0x0001E92C}, {0x0001E90B, 0x0001E92D}, {0x0001E90C, 0x0001E92E}, +{0x0001E90D, 0x0001E92F}, {0x0001E90E, 0x0001E930}, {0x0001E90F, 0x0001E931}, {0x0001E910, 0x0001E932}, +{0x0001E911, 0x0001E933}, {0x0001E912, 0x0001E934}, {0x0001E913, 0x0001E935}, {0x0001E914, 0x0001E936}, +{0x0001E915, 0x0001E937}, {0x0001E916, 0x0001E938}, {0x0001E917, 0x0001E939}, {0x0001E918, 0x0001E93A}, +{0x0001E919, 0x0001E93B}, {0x0001E91A, 0x0001E93C}, {0x0001E91B, 0x0001E93D}, {0x0001E91C, 0x0001E93E}, +{0x0001E91D, 0x0001E93F}, {0x0001E91E, 0x0001E940}, {0x0001E91F, 0x0001E941}, {0x0001E920, 0x0001E942}, +{0x0001E921, 0x0001E943}, +}; + +const std::map unicode_map_uppercase = { +{0x00000061, 0x00000041}, {0x00000062, 0x00000042}, {0x00000063, 0x00000043}, {0x00000064, 0x00000044}, +{0x00000065, 0x00000045}, {0x00000066, 0x00000046}, {0x00000067, 0x00000047}, {0x00000068, 0x00000048}, +{0x00000069, 0x00000049}, {0x0000006A, 0x0000004A}, {0x0000006B, 0x0000004B}, {0x0000006C, 0x0000004C}, +{0x0000006D, 0x0000004D}, {0x0000006E, 0x0000004E}, {0x0000006F, 0x0000004F}, {0x00000070, 0x00000050}, +{0x00000071, 0x00000051}, {0x00000072, 0x00000052}, {0x00000073, 0x00000053}, {0x00000074, 0x00000054}, +{0x00000075, 0x00000055}, {0x00000076, 0x00000056}, {0x00000077, 0x00000057}, {0x00000078, 0x00000058}, +{0x00000079, 0x00000059}, {0x0000007A, 0x0000005A}, {0x000000B5, 0x0000039C}, {0x000000DF, 0x00000053}, +{0x000000E0, 0x000000C0}, {0x000000E1, 0x000000C1}, {0x000000E2, 0x000000C2}, {0x000000E3, 0x000000C3}, +{0x000000E4, 0x000000C4}, {0x000000E5, 0x000000C5}, {0x000000E6, 0x000000C6}, {0x000000E7, 0x000000C7}, +{0x000000E8, 0x000000C8}, {0x000000E9, 0x000000C9}, {0x000000EA, 0x000000CA}, {0x000000EB, 0x000000CB}, +{0x000000EC, 0x000000CC}, {0x000000ED, 0x000000CD}, {0x000000EE, 0x000000CE}, {0x000000EF, 0x000000CF}, +{0x000000F0, 0x000000D0}, {0x000000F1, 0x000000D1}, {0x000000F2, 0x000000D2}, {0x000000F3, 0x000000D3}, +{0x000000F4, 0x000000D4}, {0x000000F5, 0x000000D5}, {0x000000F6, 0x000000D6}, {0x000000F8, 0x000000D8}, +{0x000000F9, 0x000000D9}, {0x000000FA, 0x000000DA}, {0x000000FB, 0x000000DB}, {0x000000FC, 0x000000DC}, +{0x000000FD, 0x000000DD}, {0x000000FE, 0x000000DE}, {0x000000FF, 0x00000178}, {0x00000101, 0x00000100}, +{0x00000103, 0x00000102}, {0x00000105, 0x00000104}, {0x00000107, 0x00000106}, {0x00000109, 0x00000108}, +{0x0000010B, 0x0000010A}, {0x0000010D, 0x0000010C}, {0x0000010F, 0x0000010E}, {0x00000111, 0x00000110}, +{0x00000113, 0x00000112}, {0x00000115, 0x00000114}, {0x00000117, 0x00000116}, {0x00000119, 0x00000118}, +{0x0000011B, 0x0000011A}, {0x0000011D, 0x0000011C}, {0x0000011F, 0x0000011E}, {0x00000121, 0x00000120}, +{0x00000123, 0x00000122}, {0x00000125, 0x00000124}, {0x00000127, 0x00000126}, {0x00000129, 0x00000128}, +{0x0000012B, 0x0000012A}, {0x0000012D, 0x0000012C}, {0x0000012F, 0x0000012E}, {0x00000131, 0x00000049}, +{0x00000133, 0x00000132}, {0x00000135, 0x00000134}, {0x00000137, 0x00000136}, {0x0000013A, 0x00000139}, +{0x0000013C, 0x0000013B}, {0x0000013E, 0x0000013D}, {0x00000140, 0x0000013F}, {0x00000142, 0x00000141}, +{0x00000144, 0x00000143}, {0x00000146, 0x00000145}, {0x00000148, 0x00000147}, {0x00000149, 0x000002BC}, +{0x0000014B, 0x0000014A}, {0x0000014D, 0x0000014C}, {0x0000014F, 0x0000014E}, {0x00000151, 0x00000150}, +{0x00000153, 0x00000152}, {0x00000155, 0x00000154}, {0x00000157, 0x00000156}, {0x00000159, 0x00000158}, +{0x0000015B, 0x0000015A}, {0x0000015D, 0x0000015C}, {0x0000015F, 0x0000015E}, {0x00000161, 0x00000160}, +{0x00000163, 0x00000162}, {0x00000165, 0x00000164}, {0x00000167, 0x00000166}, {0x00000169, 0x00000168}, +{0x0000016B, 0x0000016A}, {0x0000016D, 0x0000016C}, {0x0000016F, 0x0000016E}, {0x00000171, 0x00000170}, +{0x00000173, 0x00000172}, {0x00000175, 0x00000174}, {0x00000177, 0x00000176}, {0x0000017A, 0x00000179}, +{0x0000017C, 0x0000017B}, {0x0000017E, 0x0000017D}, {0x0000017F, 0x00000053}, {0x00000180, 0x00000243}, +{0x00000183, 0x00000182}, {0x00000185, 0x00000184}, {0x00000188, 0x00000187}, {0x0000018C, 0x0000018B}, +{0x00000192, 0x00000191}, {0x00000195, 0x000001F6}, {0x00000199, 0x00000198}, {0x0000019A, 0x0000023D}, +{0x0000019E, 0x00000220}, {0x000001A1, 0x000001A0}, {0x000001A3, 0x000001A2}, {0x000001A5, 0x000001A4}, +{0x000001A8, 0x000001A7}, {0x000001AD, 0x000001AC}, {0x000001B0, 0x000001AF}, {0x000001B4, 0x000001B3}, +{0x000001B6, 0x000001B5}, {0x000001B9, 0x000001B8}, {0x000001BD, 0x000001BC}, {0x000001BF, 0x000001F7}, +{0x000001C5, 0x000001C4}, {0x000001C6, 0x000001C4}, {0x000001C8, 0x000001C7}, {0x000001C9, 0x000001C7}, +{0x000001CB, 0x000001CA}, {0x000001CC, 0x000001CA}, {0x000001CE, 0x000001CD}, {0x000001D0, 0x000001CF}, +{0x000001D2, 0x000001D1}, {0x000001D4, 0x000001D3}, {0x000001D6, 0x000001D5}, {0x000001D8, 0x000001D7}, +{0x000001DA, 0x000001D9}, {0x000001DC, 0x000001DB}, {0x000001DD, 0x0000018E}, {0x000001DF, 0x000001DE}, +{0x000001E1, 0x000001E0}, {0x000001E3, 0x000001E2}, {0x000001E5, 0x000001E4}, {0x000001E7, 0x000001E6}, +{0x000001E9, 0x000001E8}, {0x000001EB, 0x000001EA}, {0x000001ED, 0x000001EC}, {0x000001EF, 0x000001EE}, +{0x000001F0, 0x0000004A}, {0x000001F2, 0x000001F1}, {0x000001F3, 0x000001F1}, {0x000001F5, 0x000001F4}, +{0x000001F9, 0x000001F8}, {0x000001FB, 0x000001FA}, {0x000001FD, 0x000001FC}, {0x000001FF, 0x000001FE}, +{0x00000201, 0x00000200}, {0x00000203, 0x00000202}, {0x00000205, 0x00000204}, {0x00000207, 0x00000206}, +{0x00000209, 0x00000208}, {0x0000020B, 0x0000020A}, {0x0000020D, 0x0000020C}, {0x0000020F, 0x0000020E}, +{0x00000211, 0x00000210}, {0x00000213, 0x00000212}, {0x00000215, 0x00000214}, {0x00000217, 0x00000216}, +{0x00000219, 0x00000218}, {0x0000021B, 0x0000021A}, {0x0000021D, 0x0000021C}, {0x0000021F, 0x0000021E}, +{0x00000223, 0x00000222}, {0x00000225, 0x00000224}, {0x00000227, 0x00000226}, {0x00000229, 0x00000228}, +{0x0000022B, 0x0000022A}, {0x0000022D, 0x0000022C}, {0x0000022F, 0x0000022E}, {0x00000231, 0x00000230}, +{0x00000233, 0x00000232}, {0x0000023C, 0x0000023B}, {0x0000023F, 0x00002C7E}, {0x00000240, 0x00002C7F}, +{0x00000242, 0x00000241}, {0x00000247, 0x00000246}, {0x00000249, 0x00000248}, {0x0000024B, 0x0000024A}, +{0x0000024D, 0x0000024C}, {0x0000024F, 0x0000024E}, {0x00000250, 0x00002C6F}, {0x00000251, 0x00002C6D}, +{0x00000252, 0x00002C70}, {0x00000253, 0x00000181}, {0x00000254, 0x00000186}, {0x00000256, 0x00000189}, +{0x00000257, 0x0000018A}, {0x00000259, 0x0000018F}, {0x0000025B, 0x00000190}, {0x0000025C, 0x0000A7AB}, +{0x00000260, 0x00000193}, {0x00000261, 0x0000A7AC}, {0x00000263, 0x00000194}, {0x00000265, 0x0000A78D}, +{0x00000266, 0x0000A7AA}, {0x00000268, 0x00000197}, {0x00000269, 0x00000196}, {0x0000026A, 0x0000A7AE}, +{0x0000026B, 0x00002C62}, {0x0000026C, 0x0000A7AD}, {0x0000026F, 0x0000019C}, {0x00000271, 0x00002C6E}, +{0x00000272, 0x0000019D}, {0x00000275, 0x0000019F}, {0x0000027D, 0x00002C64}, {0x00000280, 0x000001A6}, +{0x00000282, 0x0000A7C5}, {0x00000283, 0x000001A9}, {0x00000287, 0x0000A7B1}, {0x00000288, 0x000001AE}, +{0x00000289, 0x00000244}, {0x0000028A, 0x000001B1}, {0x0000028B, 0x000001B2}, {0x0000028C, 0x00000245}, +{0x00000292, 0x000001B7}, {0x0000029D, 0x0000A7B2}, {0x0000029E, 0x0000A7B0}, {0x00000345, 0x00000399}, +{0x00000371, 0x00000370}, {0x00000373, 0x00000372}, {0x00000377, 0x00000376}, {0x0000037B, 0x000003FD}, +{0x0000037C, 0x000003FE}, {0x0000037D, 0x000003FF}, {0x00000390, 0x00000399}, {0x000003AC, 0x00000386}, +{0x000003AD, 0x00000388}, {0x000003AE, 0x00000389}, {0x000003AF, 0x0000038A}, {0x000003B0, 0x000003A5}, +{0x000003B1, 0x00000391}, {0x000003B2, 0x00000392}, {0x000003B3, 0x00000393}, {0x000003B4, 0x00000394}, +{0x000003B5, 0x00000395}, {0x000003B6, 0x00000396}, {0x000003B7, 0x00000397}, {0x000003B8, 0x00000398}, +{0x000003B9, 0x00000399}, {0x000003BA, 0x0000039A}, {0x000003BB, 0x0000039B}, {0x000003BC, 0x0000039C}, +{0x000003BD, 0x0000039D}, {0x000003BE, 0x0000039E}, {0x000003BF, 0x0000039F}, {0x000003C0, 0x000003A0}, +{0x000003C1, 0x000003A1}, {0x000003C2, 0x000003A3}, {0x000003C3, 0x000003A3}, {0x000003C4, 0x000003A4}, +{0x000003C5, 0x000003A5}, {0x000003C6, 0x000003A6}, {0x000003C7, 0x000003A7}, {0x000003C8, 0x000003A8}, +{0x000003C9, 0x000003A9}, {0x000003CA, 0x000003AA}, {0x000003CB, 0x000003AB}, {0x000003CC, 0x0000038C}, +{0x000003CD, 0x0000038E}, {0x000003CE, 0x0000038F}, {0x000003D0, 0x00000392}, {0x000003D1, 0x00000398}, +{0x000003D5, 0x000003A6}, {0x000003D6, 0x000003A0}, {0x000003D7, 0x000003CF}, {0x000003D9, 0x000003D8}, +{0x000003DB, 0x000003DA}, {0x000003DD, 0x000003DC}, {0x000003DF, 0x000003DE}, {0x000003E1, 0x000003E0}, +{0x000003E3, 0x000003E2}, {0x000003E5, 0x000003E4}, {0x000003E7, 0x000003E6}, {0x000003E9, 0x000003E8}, +{0x000003EB, 0x000003EA}, {0x000003ED, 0x000003EC}, {0x000003EF, 0x000003EE}, {0x000003F0, 0x0000039A}, +{0x000003F1, 0x000003A1}, {0x000003F2, 0x000003F9}, {0x000003F3, 0x0000037F}, {0x000003F5, 0x00000395}, +{0x000003F8, 0x000003F7}, {0x000003FB, 0x000003FA}, {0x00000430, 0x00000410}, {0x00000431, 0x00000411}, +{0x00000432, 0x00000412}, {0x00000433, 0x00000413}, {0x00000434, 0x00000414}, {0x00000435, 0x00000415}, +{0x00000436, 0x00000416}, {0x00000437, 0x00000417}, {0x00000438, 0x00000418}, {0x00000439, 0x00000419}, +{0x0000043A, 0x0000041A}, {0x0000043B, 0x0000041B}, {0x0000043C, 0x0000041C}, {0x0000043D, 0x0000041D}, +{0x0000043E, 0x0000041E}, {0x0000043F, 0x0000041F}, {0x00000440, 0x00000420}, {0x00000441, 0x00000421}, +{0x00000442, 0x00000422}, {0x00000443, 0x00000423}, {0x00000444, 0x00000424}, {0x00000445, 0x00000425}, +{0x00000446, 0x00000426}, {0x00000447, 0x00000427}, {0x00000448, 0x00000428}, {0x00000449, 0x00000429}, +{0x0000044A, 0x0000042A}, {0x0000044B, 0x0000042B}, {0x0000044C, 0x0000042C}, {0x0000044D, 0x0000042D}, +{0x0000044E, 0x0000042E}, {0x0000044F, 0x0000042F}, {0x00000450, 0x00000400}, {0x00000451, 0x00000401}, +{0x00000452, 0x00000402}, {0x00000453, 0x00000403}, {0x00000454, 0x00000404}, {0x00000455, 0x00000405}, +{0x00000456, 0x00000406}, {0x00000457, 0x00000407}, {0x00000458, 0x00000408}, {0x00000459, 0x00000409}, +{0x0000045A, 0x0000040A}, {0x0000045B, 0x0000040B}, {0x0000045C, 0x0000040C}, {0x0000045D, 0x0000040D}, +{0x0000045E, 0x0000040E}, {0x0000045F, 0x0000040F}, {0x00000461, 0x00000460}, {0x00000463, 0x00000462}, +{0x00000465, 0x00000464}, {0x00000467, 0x00000466}, {0x00000469, 0x00000468}, {0x0000046B, 0x0000046A}, +{0x0000046D, 0x0000046C}, {0x0000046F, 0x0000046E}, {0x00000471, 0x00000470}, {0x00000473, 0x00000472}, +{0x00000475, 0x00000474}, {0x00000477, 0x00000476}, {0x00000479, 0x00000478}, {0x0000047B, 0x0000047A}, +{0x0000047D, 0x0000047C}, {0x0000047F, 0x0000047E}, {0x00000481, 0x00000480}, {0x0000048B, 0x0000048A}, +{0x0000048D, 0x0000048C}, {0x0000048F, 0x0000048E}, {0x00000491, 0x00000490}, {0x00000493, 0x00000492}, +{0x00000495, 0x00000494}, {0x00000497, 0x00000496}, {0x00000499, 0x00000498}, {0x0000049B, 0x0000049A}, +{0x0000049D, 0x0000049C}, {0x0000049F, 0x0000049E}, {0x000004A1, 0x000004A0}, {0x000004A3, 0x000004A2}, +{0x000004A5, 0x000004A4}, {0x000004A7, 0x000004A6}, {0x000004A9, 0x000004A8}, {0x000004AB, 0x000004AA}, +{0x000004AD, 0x000004AC}, {0x000004AF, 0x000004AE}, {0x000004B1, 0x000004B0}, {0x000004B3, 0x000004B2}, +{0x000004B5, 0x000004B4}, {0x000004B7, 0x000004B6}, {0x000004B9, 0x000004B8}, {0x000004BB, 0x000004BA}, +{0x000004BD, 0x000004BC}, {0x000004BF, 0x000004BE}, {0x000004C2, 0x000004C1}, {0x000004C4, 0x000004C3}, +{0x000004C6, 0x000004C5}, {0x000004C8, 0x000004C7}, {0x000004CA, 0x000004C9}, {0x000004CC, 0x000004CB}, +{0x000004CE, 0x000004CD}, {0x000004CF, 0x000004C0}, {0x000004D1, 0x000004D0}, {0x000004D3, 0x000004D2}, +{0x000004D5, 0x000004D4}, {0x000004D7, 0x000004D6}, {0x000004D9, 0x000004D8}, {0x000004DB, 0x000004DA}, +{0x000004DD, 0x000004DC}, {0x000004DF, 0x000004DE}, {0x000004E1, 0x000004E0}, {0x000004E3, 0x000004E2}, +{0x000004E5, 0x000004E4}, {0x000004E7, 0x000004E6}, {0x000004E9, 0x000004E8}, {0x000004EB, 0x000004EA}, +{0x000004ED, 0x000004EC}, {0x000004EF, 0x000004EE}, {0x000004F1, 0x000004F0}, {0x000004F3, 0x000004F2}, +{0x000004F5, 0x000004F4}, {0x000004F7, 0x000004F6}, {0x000004F9, 0x000004F8}, {0x000004FB, 0x000004FA}, +{0x000004FD, 0x000004FC}, {0x000004FF, 0x000004FE}, {0x00000501, 0x00000500}, {0x00000503, 0x00000502}, +{0x00000505, 0x00000504}, {0x00000507, 0x00000506}, {0x00000509, 0x00000508}, {0x0000050B, 0x0000050A}, +{0x0000050D, 0x0000050C}, {0x0000050F, 0x0000050E}, {0x00000511, 0x00000510}, {0x00000513, 0x00000512}, +{0x00000515, 0x00000514}, {0x00000517, 0x00000516}, {0x00000519, 0x00000518}, {0x0000051B, 0x0000051A}, +{0x0000051D, 0x0000051C}, {0x0000051F, 0x0000051E}, {0x00000521, 0x00000520}, {0x00000523, 0x00000522}, +{0x00000525, 0x00000524}, {0x00000527, 0x00000526}, {0x00000529, 0x00000528}, {0x0000052B, 0x0000052A}, +{0x0000052D, 0x0000052C}, {0x0000052F, 0x0000052E}, {0x00000561, 0x00000531}, {0x00000562, 0x00000532}, +{0x00000563, 0x00000533}, {0x00000564, 0x00000534}, {0x00000565, 0x00000535}, {0x00000566, 0x00000536}, +{0x00000567, 0x00000537}, {0x00000568, 0x00000538}, {0x00000569, 0x00000539}, {0x0000056A, 0x0000053A}, +{0x0000056B, 0x0000053B}, {0x0000056C, 0x0000053C}, {0x0000056D, 0x0000053D}, {0x0000056E, 0x0000053E}, +{0x0000056F, 0x0000053F}, {0x00000570, 0x00000540}, {0x00000571, 0x00000541}, {0x00000572, 0x00000542}, +{0x00000573, 0x00000543}, {0x00000574, 0x00000544}, {0x00000575, 0x00000545}, {0x00000576, 0x00000546}, +{0x00000577, 0x00000547}, {0x00000578, 0x00000548}, {0x00000579, 0x00000549}, {0x0000057A, 0x0000054A}, +{0x0000057B, 0x0000054B}, {0x0000057C, 0x0000054C}, {0x0000057D, 0x0000054D}, {0x0000057E, 0x0000054E}, +{0x0000057F, 0x0000054F}, {0x00000580, 0x00000550}, {0x00000581, 0x00000551}, {0x00000582, 0x00000552}, +{0x00000583, 0x00000553}, {0x00000584, 0x00000554}, {0x00000585, 0x00000555}, {0x00000586, 0x00000556}, +{0x00000587, 0x00000535}, {0x000010D0, 0x00001C90}, {0x000010D1, 0x00001C91}, {0x000010D2, 0x00001C92}, +{0x000010D3, 0x00001C93}, {0x000010D4, 0x00001C94}, {0x000010D5, 0x00001C95}, {0x000010D6, 0x00001C96}, +{0x000010D7, 0x00001C97}, {0x000010D8, 0x00001C98}, {0x000010D9, 0x00001C99}, {0x000010DA, 0x00001C9A}, +{0x000010DB, 0x00001C9B}, {0x000010DC, 0x00001C9C}, {0x000010DD, 0x00001C9D}, {0x000010DE, 0x00001C9E}, +{0x000010DF, 0x00001C9F}, {0x000010E0, 0x00001CA0}, {0x000010E1, 0x00001CA1}, {0x000010E2, 0x00001CA2}, +{0x000010E3, 0x00001CA3}, {0x000010E4, 0x00001CA4}, {0x000010E5, 0x00001CA5}, {0x000010E6, 0x00001CA6}, +{0x000010E7, 0x00001CA7}, {0x000010E8, 0x00001CA8}, {0x000010E9, 0x00001CA9}, {0x000010EA, 0x00001CAA}, +{0x000010EB, 0x00001CAB}, {0x000010EC, 0x00001CAC}, {0x000010ED, 0x00001CAD}, {0x000010EE, 0x00001CAE}, +{0x000010EF, 0x00001CAF}, {0x000010F0, 0x00001CB0}, {0x000010F1, 0x00001CB1}, {0x000010F2, 0x00001CB2}, +{0x000010F3, 0x00001CB3}, {0x000010F4, 0x00001CB4}, {0x000010F5, 0x00001CB5}, {0x000010F6, 0x00001CB6}, +{0x000010F7, 0x00001CB7}, {0x000010F8, 0x00001CB8}, {0x000010F9, 0x00001CB9}, {0x000010FA, 0x00001CBA}, +{0x000010FD, 0x00001CBD}, {0x000010FE, 0x00001CBE}, {0x000010FF, 0x00001CBF}, {0x000013F8, 0x000013F0}, +{0x000013F9, 0x000013F1}, {0x000013FA, 0x000013F2}, {0x000013FB, 0x000013F3}, {0x000013FC, 0x000013F4}, +{0x000013FD, 0x000013F5}, {0x00001C80, 0x00000412}, {0x00001C81, 0x00000414}, {0x00001C82, 0x0000041E}, +{0x00001C83, 0x00000421}, {0x00001C84, 0x00000422}, {0x00001C85, 0x00000422}, {0x00001C86, 0x0000042A}, +{0x00001C87, 0x00000462}, {0x00001C88, 0x0000A64A}, {0x00001D79, 0x0000A77D}, {0x00001D7D, 0x00002C63}, +{0x00001D8E, 0x0000A7C6}, {0x00001E01, 0x00001E00}, {0x00001E03, 0x00001E02}, {0x00001E05, 0x00001E04}, +{0x00001E07, 0x00001E06}, {0x00001E09, 0x00001E08}, {0x00001E0B, 0x00001E0A}, {0x00001E0D, 0x00001E0C}, +{0x00001E0F, 0x00001E0E}, {0x00001E11, 0x00001E10}, {0x00001E13, 0x00001E12}, {0x00001E15, 0x00001E14}, +{0x00001E17, 0x00001E16}, {0x00001E19, 0x00001E18}, {0x00001E1B, 0x00001E1A}, {0x00001E1D, 0x00001E1C}, +{0x00001E1F, 0x00001E1E}, {0x00001E21, 0x00001E20}, {0x00001E23, 0x00001E22}, {0x00001E25, 0x00001E24}, +{0x00001E27, 0x00001E26}, {0x00001E29, 0x00001E28}, {0x00001E2B, 0x00001E2A}, {0x00001E2D, 0x00001E2C}, +{0x00001E2F, 0x00001E2E}, {0x00001E31, 0x00001E30}, {0x00001E33, 0x00001E32}, {0x00001E35, 0x00001E34}, +{0x00001E37, 0x00001E36}, {0x00001E39, 0x00001E38}, {0x00001E3B, 0x00001E3A}, {0x00001E3D, 0x00001E3C}, +{0x00001E3F, 0x00001E3E}, {0x00001E41, 0x00001E40}, {0x00001E43, 0x00001E42}, {0x00001E45, 0x00001E44}, +{0x00001E47, 0x00001E46}, {0x00001E49, 0x00001E48}, {0x00001E4B, 0x00001E4A}, {0x00001E4D, 0x00001E4C}, +{0x00001E4F, 0x00001E4E}, {0x00001E51, 0x00001E50}, {0x00001E53, 0x00001E52}, {0x00001E55, 0x00001E54}, +{0x00001E57, 0x00001E56}, {0x00001E59, 0x00001E58}, {0x00001E5B, 0x00001E5A}, {0x00001E5D, 0x00001E5C}, +{0x00001E5F, 0x00001E5E}, {0x00001E61, 0x00001E60}, {0x00001E63, 0x00001E62}, {0x00001E65, 0x00001E64}, +{0x00001E67, 0x00001E66}, {0x00001E69, 0x00001E68}, {0x00001E6B, 0x00001E6A}, {0x00001E6D, 0x00001E6C}, +{0x00001E6F, 0x00001E6E}, {0x00001E71, 0x00001E70}, {0x00001E73, 0x00001E72}, {0x00001E75, 0x00001E74}, +{0x00001E77, 0x00001E76}, {0x00001E79, 0x00001E78}, {0x00001E7B, 0x00001E7A}, {0x00001E7D, 0x00001E7C}, +{0x00001E7F, 0x00001E7E}, {0x00001E81, 0x00001E80}, {0x00001E83, 0x00001E82}, {0x00001E85, 0x00001E84}, +{0x00001E87, 0x00001E86}, {0x00001E89, 0x00001E88}, {0x00001E8B, 0x00001E8A}, {0x00001E8D, 0x00001E8C}, +{0x00001E8F, 0x00001E8E}, {0x00001E91, 0x00001E90}, {0x00001E93, 0x00001E92}, {0x00001E95, 0x00001E94}, +{0x00001E96, 0x00000048}, {0x00001E97, 0x00000054}, {0x00001E98, 0x00000057}, {0x00001E99, 0x00000059}, +{0x00001E9A, 0x00000041}, {0x00001E9B, 0x00001E60}, {0x00001EA1, 0x00001EA0}, {0x00001EA3, 0x00001EA2}, +{0x00001EA5, 0x00001EA4}, {0x00001EA7, 0x00001EA6}, {0x00001EA9, 0x00001EA8}, {0x00001EAB, 0x00001EAA}, +{0x00001EAD, 0x00001EAC}, {0x00001EAF, 0x00001EAE}, {0x00001EB1, 0x00001EB0}, {0x00001EB3, 0x00001EB2}, +{0x00001EB5, 0x00001EB4}, {0x00001EB7, 0x00001EB6}, {0x00001EB9, 0x00001EB8}, {0x00001EBB, 0x00001EBA}, +{0x00001EBD, 0x00001EBC}, {0x00001EBF, 0x00001EBE}, {0x00001EC1, 0x00001EC0}, {0x00001EC3, 0x00001EC2}, +{0x00001EC5, 0x00001EC4}, {0x00001EC7, 0x00001EC6}, {0x00001EC9, 0x00001EC8}, {0x00001ECB, 0x00001ECA}, +{0x00001ECD, 0x00001ECC}, {0x00001ECF, 0x00001ECE}, {0x00001ED1, 0x00001ED0}, {0x00001ED3, 0x00001ED2}, +{0x00001ED5, 0x00001ED4}, {0x00001ED7, 0x00001ED6}, {0x00001ED9, 0x00001ED8}, {0x00001EDB, 0x00001EDA}, +{0x00001EDD, 0x00001EDC}, {0x00001EDF, 0x00001EDE}, {0x00001EE1, 0x00001EE0}, {0x00001EE3, 0x00001EE2}, +{0x00001EE5, 0x00001EE4}, {0x00001EE7, 0x00001EE6}, {0x00001EE9, 0x00001EE8}, {0x00001EEB, 0x00001EEA}, +{0x00001EED, 0x00001EEC}, {0x00001EEF, 0x00001EEE}, {0x00001EF1, 0x00001EF0}, {0x00001EF3, 0x00001EF2}, +{0x00001EF5, 0x00001EF4}, {0x00001EF7, 0x00001EF6}, {0x00001EF9, 0x00001EF8}, {0x00001EFB, 0x00001EFA}, +{0x00001EFD, 0x00001EFC}, {0x00001EFF, 0x00001EFE}, {0x00001F00, 0x00001F08}, {0x00001F01, 0x00001F09}, +{0x00001F02, 0x00001F0A}, {0x00001F03, 0x00001F0B}, {0x00001F04, 0x00001F0C}, {0x00001F05, 0x00001F0D}, +{0x00001F06, 0x00001F0E}, {0x00001F07, 0x00001F0F}, {0x00001F10, 0x00001F18}, {0x00001F11, 0x00001F19}, +{0x00001F12, 0x00001F1A}, {0x00001F13, 0x00001F1B}, {0x00001F14, 0x00001F1C}, {0x00001F15, 0x00001F1D}, +{0x00001F20, 0x00001F28}, {0x00001F21, 0x00001F29}, {0x00001F22, 0x00001F2A}, {0x00001F23, 0x00001F2B}, +{0x00001F24, 0x00001F2C}, {0x00001F25, 0x00001F2D}, {0x00001F26, 0x00001F2E}, {0x00001F27, 0x00001F2F}, +{0x00001F30, 0x00001F38}, {0x00001F31, 0x00001F39}, {0x00001F32, 0x00001F3A}, {0x00001F33, 0x00001F3B}, +{0x00001F34, 0x00001F3C}, {0x00001F35, 0x00001F3D}, {0x00001F36, 0x00001F3E}, {0x00001F37, 0x00001F3F}, +{0x00001F40, 0x00001F48}, {0x00001F41, 0x00001F49}, {0x00001F42, 0x00001F4A}, {0x00001F43, 0x00001F4B}, +{0x00001F44, 0x00001F4C}, {0x00001F45, 0x00001F4D}, {0x00001F50, 0x000003A5}, {0x00001F51, 0x00001F59}, +{0x00001F52, 0x000003A5}, {0x00001F53, 0x00001F5B}, {0x00001F54, 0x000003A5}, {0x00001F55, 0x00001F5D}, +{0x00001F56, 0x000003A5}, {0x00001F57, 0x00001F5F}, {0x00001F60, 0x00001F68}, {0x00001F61, 0x00001F69}, +{0x00001F62, 0x00001F6A}, {0x00001F63, 0x00001F6B}, {0x00001F64, 0x00001F6C}, {0x00001F65, 0x00001F6D}, +{0x00001F66, 0x00001F6E}, {0x00001F67, 0x00001F6F}, {0x00001F70, 0x00001FBA}, {0x00001F71, 0x00001FBB}, +{0x00001F72, 0x00001FC8}, {0x00001F73, 0x00001FC9}, {0x00001F74, 0x00001FCA}, {0x00001F75, 0x00001FCB}, +{0x00001F76, 0x00001FDA}, {0x00001F77, 0x00001FDB}, {0x00001F78, 0x00001FF8}, {0x00001F79, 0x00001FF9}, +{0x00001F7A, 0x00001FEA}, {0x00001F7B, 0x00001FEB}, {0x00001F7C, 0x00001FFA}, {0x00001F7D, 0x00001FFB}, +{0x00001F80, 0x00001F08}, {0x00001F81, 0x00001F09}, {0x00001F82, 0x00001F0A}, {0x00001F83, 0x00001F0B}, +{0x00001F84, 0x00001F0C}, {0x00001F85, 0x00001F0D}, {0x00001F86, 0x00001F0E}, {0x00001F87, 0x00001F0F}, +{0x00001F88, 0x00001F08}, {0x00001F89, 0x00001F09}, {0x00001F8A, 0x00001F0A}, {0x00001F8B, 0x00001F0B}, +{0x00001F8C, 0x00001F0C}, {0x00001F8D, 0x00001F0D}, {0x00001F8E, 0x00001F0E}, {0x00001F8F, 0x00001F0F}, +{0x00001F90, 0x00001F28}, {0x00001F91, 0x00001F29}, {0x00001F92, 0x00001F2A}, {0x00001F93, 0x00001F2B}, +{0x00001F94, 0x00001F2C}, {0x00001F95, 0x00001F2D}, {0x00001F96, 0x00001F2E}, {0x00001F97, 0x00001F2F}, +{0x00001F98, 0x00001F28}, {0x00001F99, 0x00001F29}, {0x00001F9A, 0x00001F2A}, {0x00001F9B, 0x00001F2B}, +{0x00001F9C, 0x00001F2C}, {0x00001F9D, 0x00001F2D}, {0x00001F9E, 0x00001F2E}, {0x00001F9F, 0x00001F2F}, +{0x00001FA0, 0x00001F68}, {0x00001FA1, 0x00001F69}, {0x00001FA2, 0x00001F6A}, {0x00001FA3, 0x00001F6B}, +{0x00001FA4, 0x00001F6C}, {0x00001FA5, 0x00001F6D}, {0x00001FA6, 0x00001F6E}, {0x00001FA7, 0x00001F6F}, +{0x00001FA8, 0x00001F68}, {0x00001FA9, 0x00001F69}, {0x00001FAA, 0x00001F6A}, {0x00001FAB, 0x00001F6B}, +{0x00001FAC, 0x00001F6C}, {0x00001FAD, 0x00001F6D}, {0x00001FAE, 0x00001F6E}, {0x00001FAF, 0x00001F6F}, +{0x00001FB0, 0x00001FB8}, {0x00001FB1, 0x00001FB9}, {0x00001FB2, 0x00001FBA}, {0x00001FB3, 0x00000391}, +{0x00001FB4, 0x00000386}, {0x00001FB6, 0x00000391}, {0x00001FB7, 0x00000391}, {0x00001FBC, 0x00000391}, +{0x00001FBE, 0x00000399}, {0x00001FC2, 0x00001FCA}, {0x00001FC3, 0x00000397}, {0x00001FC4, 0x00000389}, +{0x00001FC6, 0x00000397}, {0x00001FC7, 0x00000397}, {0x00001FCC, 0x00000397}, {0x00001FD0, 0x00001FD8}, +{0x00001FD1, 0x00001FD9}, {0x00001FD2, 0x00000399}, {0x00001FD3, 0x00000399}, {0x00001FD6, 0x00000399}, +{0x00001FD7, 0x00000399}, {0x00001FE0, 0x00001FE8}, {0x00001FE1, 0x00001FE9}, {0x00001FE2, 0x000003A5}, +{0x00001FE3, 0x000003A5}, {0x00001FE4, 0x000003A1}, {0x00001FE5, 0x00001FEC}, {0x00001FE6, 0x000003A5}, +{0x00001FE7, 0x000003A5}, {0x00001FF2, 0x00001FFA}, {0x00001FF3, 0x000003A9}, {0x00001FF4, 0x0000038F}, +{0x00001FF6, 0x000003A9}, {0x00001FF7, 0x000003A9}, {0x00001FFC, 0x000003A9}, {0x0000214E, 0x00002132}, +{0x00002170, 0x00002160}, {0x00002171, 0x00002161}, {0x00002172, 0x00002162}, {0x00002173, 0x00002163}, +{0x00002174, 0x00002164}, {0x00002175, 0x00002165}, {0x00002176, 0x00002166}, {0x00002177, 0x00002167}, +{0x00002178, 0x00002168}, {0x00002179, 0x00002169}, {0x0000217A, 0x0000216A}, {0x0000217B, 0x0000216B}, +{0x0000217C, 0x0000216C}, {0x0000217D, 0x0000216D}, {0x0000217E, 0x0000216E}, {0x0000217F, 0x0000216F}, +{0x00002184, 0x00002183}, {0x000024D0, 0x000024B6}, {0x000024D1, 0x000024B7}, {0x000024D2, 0x000024B8}, +{0x000024D3, 0x000024B9}, {0x000024D4, 0x000024BA}, {0x000024D5, 0x000024BB}, {0x000024D6, 0x000024BC}, +{0x000024D7, 0x000024BD}, {0x000024D8, 0x000024BE}, {0x000024D9, 0x000024BF}, {0x000024DA, 0x000024C0}, +{0x000024DB, 0x000024C1}, {0x000024DC, 0x000024C2}, {0x000024DD, 0x000024C3}, {0x000024DE, 0x000024C4}, +{0x000024DF, 0x000024C5}, {0x000024E0, 0x000024C6}, {0x000024E1, 0x000024C7}, {0x000024E2, 0x000024C8}, +{0x000024E3, 0x000024C9}, {0x000024E4, 0x000024CA}, {0x000024E5, 0x000024CB}, {0x000024E6, 0x000024CC}, +{0x000024E7, 0x000024CD}, {0x000024E8, 0x000024CE}, {0x000024E9, 0x000024CF}, {0x00002C30, 0x00002C00}, +{0x00002C31, 0x00002C01}, {0x00002C32, 0x00002C02}, {0x00002C33, 0x00002C03}, {0x00002C34, 0x00002C04}, +{0x00002C35, 0x00002C05}, {0x00002C36, 0x00002C06}, {0x00002C37, 0x00002C07}, {0x00002C38, 0x00002C08}, +{0x00002C39, 0x00002C09}, {0x00002C3A, 0x00002C0A}, {0x00002C3B, 0x00002C0B}, {0x00002C3C, 0x00002C0C}, +{0x00002C3D, 0x00002C0D}, {0x00002C3E, 0x00002C0E}, {0x00002C3F, 0x00002C0F}, {0x00002C40, 0x00002C10}, +{0x00002C41, 0x00002C11}, {0x00002C42, 0x00002C12}, {0x00002C43, 0x00002C13}, {0x00002C44, 0x00002C14}, +{0x00002C45, 0x00002C15}, {0x00002C46, 0x00002C16}, {0x00002C47, 0x00002C17}, {0x00002C48, 0x00002C18}, +{0x00002C49, 0x00002C19}, {0x00002C4A, 0x00002C1A}, {0x00002C4B, 0x00002C1B}, {0x00002C4C, 0x00002C1C}, +{0x00002C4D, 0x00002C1D}, {0x00002C4E, 0x00002C1E}, {0x00002C4F, 0x00002C1F}, {0x00002C50, 0x00002C20}, +{0x00002C51, 0x00002C21}, {0x00002C52, 0x00002C22}, {0x00002C53, 0x00002C23}, {0x00002C54, 0x00002C24}, +{0x00002C55, 0x00002C25}, {0x00002C56, 0x00002C26}, {0x00002C57, 0x00002C27}, {0x00002C58, 0x00002C28}, +{0x00002C59, 0x00002C29}, {0x00002C5A, 0x00002C2A}, {0x00002C5B, 0x00002C2B}, {0x00002C5C, 0x00002C2C}, +{0x00002C5D, 0x00002C2D}, {0x00002C5E, 0x00002C2E}, {0x00002C61, 0x00002C60}, {0x00002C65, 0x0000023A}, +{0x00002C66, 0x0000023E}, {0x00002C68, 0x00002C67}, {0x00002C6A, 0x00002C69}, {0x00002C6C, 0x00002C6B}, +{0x00002C73, 0x00002C72}, {0x00002C76, 0x00002C75}, {0x00002C81, 0x00002C80}, {0x00002C83, 0x00002C82}, +{0x00002C85, 0x00002C84}, {0x00002C87, 0x00002C86}, {0x00002C89, 0x00002C88}, {0x00002C8B, 0x00002C8A}, +{0x00002C8D, 0x00002C8C}, {0x00002C8F, 0x00002C8E}, {0x00002C91, 0x00002C90}, {0x00002C93, 0x00002C92}, +{0x00002C95, 0x00002C94}, {0x00002C97, 0x00002C96}, {0x00002C99, 0x00002C98}, {0x00002C9B, 0x00002C9A}, +{0x00002C9D, 0x00002C9C}, {0x00002C9F, 0x00002C9E}, {0x00002CA1, 0x00002CA0}, {0x00002CA3, 0x00002CA2}, +{0x00002CA5, 0x00002CA4}, {0x00002CA7, 0x00002CA6}, {0x00002CA9, 0x00002CA8}, {0x00002CAB, 0x00002CAA}, +{0x00002CAD, 0x00002CAC}, {0x00002CAF, 0x00002CAE}, {0x00002CB1, 0x00002CB0}, {0x00002CB3, 0x00002CB2}, +{0x00002CB5, 0x00002CB4}, {0x00002CB7, 0x00002CB6}, {0x00002CB9, 0x00002CB8}, {0x00002CBB, 0x00002CBA}, +{0x00002CBD, 0x00002CBC}, {0x00002CBF, 0x00002CBE}, {0x00002CC1, 0x00002CC0}, {0x00002CC3, 0x00002CC2}, +{0x00002CC5, 0x00002CC4}, {0x00002CC7, 0x00002CC6}, {0x00002CC9, 0x00002CC8}, {0x00002CCB, 0x00002CCA}, +{0x00002CCD, 0x00002CCC}, {0x00002CCF, 0x00002CCE}, {0x00002CD1, 0x00002CD0}, {0x00002CD3, 0x00002CD2}, +{0x00002CD5, 0x00002CD4}, {0x00002CD7, 0x00002CD6}, {0x00002CD9, 0x00002CD8}, {0x00002CDB, 0x00002CDA}, +{0x00002CDD, 0x00002CDC}, {0x00002CDF, 0x00002CDE}, {0x00002CE1, 0x00002CE0}, {0x00002CE3, 0x00002CE2}, +{0x00002CEC, 0x00002CEB}, {0x00002CEE, 0x00002CED}, {0x00002CF3, 0x00002CF2}, {0x00002D00, 0x000010A0}, +{0x00002D01, 0x000010A1}, {0x00002D02, 0x000010A2}, {0x00002D03, 0x000010A3}, {0x00002D04, 0x000010A4}, +{0x00002D05, 0x000010A5}, {0x00002D06, 0x000010A6}, {0x00002D07, 0x000010A7}, {0x00002D08, 0x000010A8}, +{0x00002D09, 0x000010A9}, {0x00002D0A, 0x000010AA}, {0x00002D0B, 0x000010AB}, {0x00002D0C, 0x000010AC}, +{0x00002D0D, 0x000010AD}, {0x00002D0E, 0x000010AE}, {0x00002D0F, 0x000010AF}, {0x00002D10, 0x000010B0}, +{0x00002D11, 0x000010B1}, {0x00002D12, 0x000010B2}, {0x00002D13, 0x000010B3}, {0x00002D14, 0x000010B4}, +{0x00002D15, 0x000010B5}, {0x00002D16, 0x000010B6}, {0x00002D17, 0x000010B7}, {0x00002D18, 0x000010B8}, +{0x00002D19, 0x000010B9}, {0x00002D1A, 0x000010BA}, {0x00002D1B, 0x000010BB}, {0x00002D1C, 0x000010BC}, +{0x00002D1D, 0x000010BD}, {0x00002D1E, 0x000010BE}, {0x00002D1F, 0x000010BF}, {0x00002D20, 0x000010C0}, +{0x00002D21, 0x000010C1}, {0x00002D22, 0x000010C2}, {0x00002D23, 0x000010C3}, {0x00002D24, 0x000010C4}, +{0x00002D25, 0x000010C5}, {0x00002D27, 0x000010C7}, {0x00002D2D, 0x000010CD}, {0x0000A641, 0x0000A640}, +{0x0000A643, 0x0000A642}, {0x0000A645, 0x0000A644}, {0x0000A647, 0x0000A646}, {0x0000A649, 0x0000A648}, +{0x0000A64B, 0x0000A64A}, {0x0000A64D, 0x0000A64C}, {0x0000A64F, 0x0000A64E}, {0x0000A651, 0x0000A650}, +{0x0000A653, 0x0000A652}, {0x0000A655, 0x0000A654}, {0x0000A657, 0x0000A656}, {0x0000A659, 0x0000A658}, +{0x0000A65B, 0x0000A65A}, {0x0000A65D, 0x0000A65C}, {0x0000A65F, 0x0000A65E}, {0x0000A661, 0x0000A660}, +{0x0000A663, 0x0000A662}, {0x0000A665, 0x0000A664}, {0x0000A667, 0x0000A666}, {0x0000A669, 0x0000A668}, +{0x0000A66B, 0x0000A66A}, {0x0000A66D, 0x0000A66C}, {0x0000A681, 0x0000A680}, {0x0000A683, 0x0000A682}, +{0x0000A685, 0x0000A684}, {0x0000A687, 0x0000A686}, {0x0000A689, 0x0000A688}, {0x0000A68B, 0x0000A68A}, +{0x0000A68D, 0x0000A68C}, {0x0000A68F, 0x0000A68E}, {0x0000A691, 0x0000A690}, {0x0000A693, 0x0000A692}, +{0x0000A695, 0x0000A694}, {0x0000A697, 0x0000A696}, {0x0000A699, 0x0000A698}, {0x0000A69B, 0x0000A69A}, +{0x0000A723, 0x0000A722}, {0x0000A725, 0x0000A724}, {0x0000A727, 0x0000A726}, {0x0000A729, 0x0000A728}, +{0x0000A72B, 0x0000A72A}, {0x0000A72D, 0x0000A72C}, {0x0000A72F, 0x0000A72E}, {0x0000A733, 0x0000A732}, +{0x0000A735, 0x0000A734}, {0x0000A737, 0x0000A736}, {0x0000A739, 0x0000A738}, {0x0000A73B, 0x0000A73A}, +{0x0000A73D, 0x0000A73C}, {0x0000A73F, 0x0000A73E}, {0x0000A741, 0x0000A740}, {0x0000A743, 0x0000A742}, +{0x0000A745, 0x0000A744}, {0x0000A747, 0x0000A746}, {0x0000A749, 0x0000A748}, {0x0000A74B, 0x0000A74A}, +{0x0000A74D, 0x0000A74C}, {0x0000A74F, 0x0000A74E}, {0x0000A751, 0x0000A750}, {0x0000A753, 0x0000A752}, +{0x0000A755, 0x0000A754}, {0x0000A757, 0x0000A756}, {0x0000A759, 0x0000A758}, {0x0000A75B, 0x0000A75A}, +{0x0000A75D, 0x0000A75C}, {0x0000A75F, 0x0000A75E}, {0x0000A761, 0x0000A760}, {0x0000A763, 0x0000A762}, +{0x0000A765, 0x0000A764}, {0x0000A767, 0x0000A766}, {0x0000A769, 0x0000A768}, {0x0000A76B, 0x0000A76A}, +{0x0000A76D, 0x0000A76C}, {0x0000A76F, 0x0000A76E}, {0x0000A77A, 0x0000A779}, {0x0000A77C, 0x0000A77B}, +{0x0000A77F, 0x0000A77E}, {0x0000A781, 0x0000A780}, {0x0000A783, 0x0000A782}, {0x0000A785, 0x0000A784}, +{0x0000A787, 0x0000A786}, {0x0000A78C, 0x0000A78B}, {0x0000A791, 0x0000A790}, {0x0000A793, 0x0000A792}, +{0x0000A794, 0x0000A7C4}, {0x0000A797, 0x0000A796}, {0x0000A799, 0x0000A798}, {0x0000A79B, 0x0000A79A}, +{0x0000A79D, 0x0000A79C}, {0x0000A79F, 0x0000A79E}, {0x0000A7A1, 0x0000A7A0}, {0x0000A7A3, 0x0000A7A2}, +{0x0000A7A5, 0x0000A7A4}, {0x0000A7A7, 0x0000A7A6}, {0x0000A7A9, 0x0000A7A8}, {0x0000A7B5, 0x0000A7B4}, +{0x0000A7B7, 0x0000A7B6}, {0x0000A7B9, 0x0000A7B8}, {0x0000A7BB, 0x0000A7BA}, {0x0000A7BD, 0x0000A7BC}, +{0x0000A7BF, 0x0000A7BE}, {0x0000A7C3, 0x0000A7C2}, {0x0000A7C8, 0x0000A7C7}, {0x0000A7CA, 0x0000A7C9}, +{0x0000A7F6, 0x0000A7F5}, {0x0000AB53, 0x0000A7B3}, {0x0000AB70, 0x000013A0}, {0x0000AB71, 0x000013A1}, +{0x0000AB72, 0x000013A2}, {0x0000AB73, 0x000013A3}, {0x0000AB74, 0x000013A4}, {0x0000AB75, 0x000013A5}, +{0x0000AB76, 0x000013A6}, {0x0000AB77, 0x000013A7}, {0x0000AB78, 0x000013A8}, {0x0000AB79, 0x000013A9}, +{0x0000AB7A, 0x000013AA}, {0x0000AB7B, 0x000013AB}, {0x0000AB7C, 0x000013AC}, {0x0000AB7D, 0x000013AD}, +{0x0000AB7E, 0x000013AE}, {0x0000AB7F, 0x000013AF}, {0x0000AB80, 0x000013B0}, {0x0000AB81, 0x000013B1}, +{0x0000AB82, 0x000013B2}, {0x0000AB83, 0x000013B3}, {0x0000AB84, 0x000013B4}, {0x0000AB85, 0x000013B5}, +{0x0000AB86, 0x000013B6}, {0x0000AB87, 0x000013B7}, {0x0000AB88, 0x000013B8}, {0x0000AB89, 0x000013B9}, +{0x0000AB8A, 0x000013BA}, {0x0000AB8B, 0x000013BB}, {0x0000AB8C, 0x000013BC}, {0x0000AB8D, 0x000013BD}, +{0x0000AB8E, 0x000013BE}, {0x0000AB8F, 0x000013BF}, {0x0000AB90, 0x000013C0}, {0x0000AB91, 0x000013C1}, +{0x0000AB92, 0x000013C2}, {0x0000AB93, 0x000013C3}, {0x0000AB94, 0x000013C4}, {0x0000AB95, 0x000013C5}, +{0x0000AB96, 0x000013C6}, {0x0000AB97, 0x000013C7}, {0x0000AB98, 0x000013C8}, {0x0000AB99, 0x000013C9}, +{0x0000AB9A, 0x000013CA}, {0x0000AB9B, 0x000013CB}, {0x0000AB9C, 0x000013CC}, {0x0000AB9D, 0x000013CD}, +{0x0000AB9E, 0x000013CE}, {0x0000AB9F, 0x000013CF}, {0x0000ABA0, 0x000013D0}, {0x0000ABA1, 0x000013D1}, +{0x0000ABA2, 0x000013D2}, {0x0000ABA3, 0x000013D3}, {0x0000ABA4, 0x000013D4}, {0x0000ABA5, 0x000013D5}, +{0x0000ABA6, 0x000013D6}, {0x0000ABA7, 0x000013D7}, {0x0000ABA8, 0x000013D8}, {0x0000ABA9, 0x000013D9}, +{0x0000ABAA, 0x000013DA}, {0x0000ABAB, 0x000013DB}, {0x0000ABAC, 0x000013DC}, {0x0000ABAD, 0x000013DD}, +{0x0000ABAE, 0x000013DE}, {0x0000ABAF, 0x000013DF}, {0x0000ABB0, 0x000013E0}, {0x0000ABB1, 0x000013E1}, +{0x0000ABB2, 0x000013E2}, {0x0000ABB3, 0x000013E3}, {0x0000ABB4, 0x000013E4}, {0x0000ABB5, 0x000013E5}, +{0x0000ABB6, 0x000013E6}, {0x0000ABB7, 0x000013E7}, {0x0000ABB8, 0x000013E8}, {0x0000ABB9, 0x000013E9}, +{0x0000ABBA, 0x000013EA}, {0x0000ABBB, 0x000013EB}, {0x0000ABBC, 0x000013EC}, {0x0000ABBD, 0x000013ED}, +{0x0000ABBE, 0x000013EE}, {0x0000ABBF, 0x000013EF}, {0x0000FB00, 0x00000046}, {0x0000FB01, 0x00000046}, +{0x0000FB02, 0x00000046}, {0x0000FB03, 0x00000046}, {0x0000FB04, 0x00000046}, {0x0000FB05, 0x00000053}, +{0x0000FB06, 0x00000053}, {0x0000FB13, 0x00000544}, {0x0000FB14, 0x00000544}, {0x0000FB15, 0x00000544}, +{0x0000FB16, 0x0000054E}, {0x0000FB17, 0x00000544}, {0x0000FF41, 0x0000FF21}, {0x0000FF42, 0x0000FF22}, +{0x0000FF43, 0x0000FF23}, {0x0000FF44, 0x0000FF24}, {0x0000FF45, 0x0000FF25}, {0x0000FF46, 0x0000FF26}, +{0x0000FF47, 0x0000FF27}, {0x0000FF48, 0x0000FF28}, {0x0000FF49, 0x0000FF29}, {0x0000FF4A, 0x0000FF2A}, +{0x0000FF4B, 0x0000FF2B}, {0x0000FF4C, 0x0000FF2C}, {0x0000FF4D, 0x0000FF2D}, {0x0000FF4E, 0x0000FF2E}, +{0x0000FF4F, 0x0000FF2F}, {0x0000FF50, 0x0000FF30}, {0x0000FF51, 0x0000FF31}, {0x0000FF52, 0x0000FF32}, +{0x0000FF53, 0x0000FF33}, {0x0000FF54, 0x0000FF34}, {0x0000FF55, 0x0000FF35}, {0x0000FF56, 0x0000FF36}, +{0x0000FF57, 0x0000FF37}, {0x0000FF58, 0x0000FF38}, {0x0000FF59, 0x0000FF39}, {0x0000FF5A, 0x0000FF3A}, +{0x00010428, 0x00010400}, {0x00010429, 0x00010401}, {0x0001042A, 0x00010402}, {0x0001042B, 0x00010403}, +{0x0001042C, 0x00010404}, {0x0001042D, 0x00010405}, {0x0001042E, 0x00010406}, {0x0001042F, 0x00010407}, +{0x00010430, 0x00010408}, {0x00010431, 0x00010409}, {0x00010432, 0x0001040A}, {0x00010433, 0x0001040B}, +{0x00010434, 0x0001040C}, {0x00010435, 0x0001040D}, {0x00010436, 0x0001040E}, {0x00010437, 0x0001040F}, +{0x00010438, 0x00010410}, {0x00010439, 0x00010411}, {0x0001043A, 0x00010412}, {0x0001043B, 0x00010413}, +{0x0001043C, 0x00010414}, {0x0001043D, 0x00010415}, {0x0001043E, 0x00010416}, {0x0001043F, 0x00010417}, +{0x00010440, 0x00010418}, {0x00010441, 0x00010419}, {0x00010442, 0x0001041A}, {0x00010443, 0x0001041B}, +{0x00010444, 0x0001041C}, {0x00010445, 0x0001041D}, {0x00010446, 0x0001041E}, {0x00010447, 0x0001041F}, +{0x00010448, 0x00010420}, {0x00010449, 0x00010421}, {0x0001044A, 0x00010422}, {0x0001044B, 0x00010423}, +{0x0001044C, 0x00010424}, {0x0001044D, 0x00010425}, {0x0001044E, 0x00010426}, {0x0001044F, 0x00010427}, +{0x000104D8, 0x000104B0}, {0x000104D9, 0x000104B1}, {0x000104DA, 0x000104B2}, {0x000104DB, 0x000104B3}, +{0x000104DC, 0x000104B4}, {0x000104DD, 0x000104B5}, {0x000104DE, 0x000104B6}, {0x000104DF, 0x000104B7}, +{0x000104E0, 0x000104B8}, {0x000104E1, 0x000104B9}, {0x000104E2, 0x000104BA}, {0x000104E3, 0x000104BB}, +{0x000104E4, 0x000104BC}, {0x000104E5, 0x000104BD}, {0x000104E6, 0x000104BE}, {0x000104E7, 0x000104BF}, +{0x000104E8, 0x000104C0}, {0x000104E9, 0x000104C1}, {0x000104EA, 0x000104C2}, {0x000104EB, 0x000104C3}, +{0x000104EC, 0x000104C4}, {0x000104ED, 0x000104C5}, {0x000104EE, 0x000104C6}, {0x000104EF, 0x000104C7}, +{0x000104F0, 0x000104C8}, {0x000104F1, 0x000104C9}, {0x000104F2, 0x000104CA}, {0x000104F3, 0x000104CB}, +{0x000104F4, 0x000104CC}, {0x000104F5, 0x000104CD}, {0x000104F6, 0x000104CE}, {0x000104F7, 0x000104CF}, +{0x000104F8, 0x000104D0}, {0x000104F9, 0x000104D1}, {0x000104FA, 0x000104D2}, {0x000104FB, 0x000104D3}, +{0x00010CC0, 0x00010C80}, {0x00010CC1, 0x00010C81}, {0x00010CC2, 0x00010C82}, {0x00010CC3, 0x00010C83}, +{0x00010CC4, 0x00010C84}, {0x00010CC5, 0x00010C85}, {0x00010CC6, 0x00010C86}, {0x00010CC7, 0x00010C87}, +{0x00010CC8, 0x00010C88}, {0x00010CC9, 0x00010C89}, {0x00010CCA, 0x00010C8A}, {0x00010CCB, 0x00010C8B}, +{0x00010CCC, 0x00010C8C}, {0x00010CCD, 0x00010C8D}, {0x00010CCE, 0x00010C8E}, {0x00010CCF, 0x00010C8F}, +{0x00010CD0, 0x00010C90}, {0x00010CD1, 0x00010C91}, {0x00010CD2, 0x00010C92}, {0x00010CD3, 0x00010C93}, +{0x00010CD4, 0x00010C94}, {0x00010CD5, 0x00010C95}, {0x00010CD6, 0x00010C96}, {0x00010CD7, 0x00010C97}, +{0x00010CD8, 0x00010C98}, {0x00010CD9, 0x00010C99}, {0x00010CDA, 0x00010C9A}, {0x00010CDB, 0x00010C9B}, +{0x00010CDC, 0x00010C9C}, {0x00010CDD, 0x00010C9D}, {0x00010CDE, 0x00010C9E}, {0x00010CDF, 0x00010C9F}, +{0x00010CE0, 0x00010CA0}, {0x00010CE1, 0x00010CA1}, {0x00010CE2, 0x00010CA2}, {0x00010CE3, 0x00010CA3}, +{0x00010CE4, 0x00010CA4}, {0x00010CE5, 0x00010CA5}, {0x00010CE6, 0x00010CA6}, {0x00010CE7, 0x00010CA7}, +{0x00010CE8, 0x00010CA8}, {0x00010CE9, 0x00010CA9}, {0x00010CEA, 0x00010CAA}, {0x00010CEB, 0x00010CAB}, +{0x00010CEC, 0x00010CAC}, {0x00010CED, 0x00010CAD}, {0x00010CEE, 0x00010CAE}, {0x00010CEF, 0x00010CAF}, +{0x00010CF0, 0x00010CB0}, {0x00010CF1, 0x00010CB1}, {0x00010CF2, 0x00010CB2}, {0x000118C0, 0x000118A0}, +{0x000118C1, 0x000118A1}, {0x000118C2, 0x000118A2}, {0x000118C3, 0x000118A3}, {0x000118C4, 0x000118A4}, +{0x000118C5, 0x000118A5}, {0x000118C6, 0x000118A6}, {0x000118C7, 0x000118A7}, {0x000118C8, 0x000118A8}, +{0x000118C9, 0x000118A9}, {0x000118CA, 0x000118AA}, {0x000118CB, 0x000118AB}, {0x000118CC, 0x000118AC}, +{0x000118CD, 0x000118AD}, {0x000118CE, 0x000118AE}, {0x000118CF, 0x000118AF}, {0x000118D0, 0x000118B0}, +{0x000118D1, 0x000118B1}, {0x000118D2, 0x000118B2}, {0x000118D3, 0x000118B3}, {0x000118D4, 0x000118B4}, +{0x000118D5, 0x000118B5}, {0x000118D6, 0x000118B6}, {0x000118D7, 0x000118B7}, {0x000118D8, 0x000118B8}, +{0x000118D9, 0x000118B9}, {0x000118DA, 0x000118BA}, {0x000118DB, 0x000118BB}, {0x000118DC, 0x000118BC}, +{0x000118DD, 0x000118BD}, {0x000118DE, 0x000118BE}, {0x000118DF, 0x000118BF}, {0x00016E60, 0x00016E40}, +{0x00016E61, 0x00016E41}, {0x00016E62, 0x00016E42}, {0x00016E63, 0x00016E43}, {0x00016E64, 0x00016E44}, +{0x00016E65, 0x00016E45}, {0x00016E66, 0x00016E46}, {0x00016E67, 0x00016E47}, {0x00016E68, 0x00016E48}, +{0x00016E69, 0x00016E49}, {0x00016E6A, 0x00016E4A}, {0x00016E6B, 0x00016E4B}, {0x00016E6C, 0x00016E4C}, +{0x00016E6D, 0x00016E4D}, {0x00016E6E, 0x00016E4E}, {0x00016E6F, 0x00016E4F}, {0x00016E70, 0x00016E50}, +{0x00016E71, 0x00016E51}, {0x00016E72, 0x00016E52}, {0x00016E73, 0x00016E53}, {0x00016E74, 0x00016E54}, +{0x00016E75, 0x00016E55}, {0x00016E76, 0x00016E56}, {0x00016E77, 0x00016E57}, {0x00016E78, 0x00016E58}, +{0x00016E79, 0x00016E59}, {0x00016E7A, 0x00016E5A}, {0x00016E7B, 0x00016E5B}, {0x00016E7C, 0x00016E5C}, +{0x00016E7D, 0x00016E5D}, {0x00016E7E, 0x00016E5E}, {0x00016E7F, 0x00016E5F}, {0x0001E922, 0x0001E900}, +{0x0001E923, 0x0001E901}, {0x0001E924, 0x0001E902}, {0x0001E925, 0x0001E903}, {0x0001E926, 0x0001E904}, +{0x0001E927, 0x0001E905}, {0x0001E928, 0x0001E906}, {0x0001E929, 0x0001E907}, {0x0001E92A, 0x0001E908}, +{0x0001E92B, 0x0001E909}, {0x0001E92C, 0x0001E90A}, {0x0001E92D, 0x0001E90B}, {0x0001E92E, 0x0001E90C}, +{0x0001E92F, 0x0001E90D}, {0x0001E930, 0x0001E90E}, {0x0001E931, 0x0001E90F}, {0x0001E932, 0x0001E910}, +{0x0001E933, 0x0001E911}, {0x0001E934, 0x0001E912}, {0x0001E935, 0x0001E913}, {0x0001E936, 0x0001E914}, +{0x0001E937, 0x0001E915}, {0x0001E938, 0x0001E916}, {0x0001E939, 0x0001E917}, {0x0001E93A, 0x0001E918}, +{0x0001E93B, 0x0001E919}, {0x0001E93C, 0x0001E91A}, {0x0001E93D, 0x0001E91B}, {0x0001E93E, 0x0001E91C}, +{0x0001E93F, 0x0001E91D}, {0x0001E940, 0x0001E91E}, {0x0001E941, 0x0001E91F}, {0x0001E942, 0x0001E920}, +{0x0001E943, 0x0001E921}, }; const std::multimap unicode_map_nfd = { @@ -1407,245 +2181,3 @@ const std::multimap unicode_map_nfd = { {0x0002FA16, 0x00004D56}, {0x0002FA17, 0x00009EF9}, {0x0002FA18, 0x00009EFE}, {0x0002FA19, 0x00009F05}, {0x0002FA1A, 0x00009F0F}, {0x0002FA1B, 0x00009F16}, {0x0002FA1D, 0x0002A600}, }; - -const std::map unicode_map_lowercase = { -{0x00041, 0x00061}, {0x00042, 0x00062}, {0x00043, 0x00063}, {0x00044, 0x00064}, {0x00045, 0x00065}, {0x00046, 0x00066}, -{0x00047, 0x00067}, {0x00048, 0x00068}, {0x00049, 0x00069}, {0x0004A, 0x0006A}, {0x0004B, 0x0006B}, {0x0004C, 0x0006C}, -{0x0004D, 0x0006D}, {0x0004E, 0x0006E}, {0x0004F, 0x0006F}, {0x00050, 0x00070}, {0x00051, 0x00071}, {0x00052, 0x00072}, -{0x00053, 0x00073}, {0x00054, 0x00074}, {0x00055, 0x00075}, {0x00056, 0x00076}, {0x00057, 0x00077}, {0x00058, 0x00078}, -{0x00059, 0x00079}, {0x0005A, 0x0007A}, {0x000C0, 0x000E0}, {0x000C1, 0x000E1}, {0x000C2, 0x000E2}, {0x000C3, 0x000E3}, -{0x000C4, 0x000E4}, {0x000C5, 0x000E5}, {0x000C6, 0x000E6}, {0x000C7, 0x000E7}, {0x000C8, 0x000E8}, {0x000C9, 0x000E9}, -{0x000CA, 0x000EA}, {0x000CB, 0x000EB}, {0x000CC, 0x000EC}, {0x000CD, 0x000ED}, {0x000CE, 0x000EE}, {0x000CF, 0x000EF}, -{0x000D0, 0x000F0}, {0x000D1, 0x000F1}, {0x000D2, 0x000F2}, {0x000D3, 0x000F3}, {0x000D4, 0x000F4}, {0x000D5, 0x000F5}, -{0x000D6, 0x000F6}, {0x000D8, 0x000F8}, {0x000D9, 0x000F9}, {0x000DA, 0x000FA}, {0x000DB, 0x000FB}, {0x000DC, 0x000FC}, -{0x000DD, 0x000FD}, {0x000DE, 0x000FE}, {0x00100, 0x00101}, {0x00102, 0x00103}, {0x00104, 0x00105}, {0x00106, 0x00107}, -{0x00108, 0x00109}, {0x0010A, 0x0010B}, {0x0010C, 0x0010D}, {0x0010E, 0x0010F}, {0x00110, 0x00111}, {0x00112, 0x00113}, -{0x00114, 0x00115}, {0x00116, 0x00117}, {0x00118, 0x00119}, {0x0011A, 0x0011B}, {0x0011C, 0x0011D}, {0x0011E, 0x0011F}, -{0x00120, 0x00121}, {0x00122, 0x00123}, {0x00124, 0x00125}, {0x00126, 0x00127}, {0x00128, 0x00129}, {0x0012A, 0x0012B}, -{0x0012C, 0x0012D}, {0x0012E, 0x0012F}, {0x00130, 0x00069}, {0x00132, 0x00133}, {0x00134, 0x00135}, {0x00136, 0x00137}, -{0x00139, 0x0013A}, {0x0013B, 0x0013C}, {0x0013D, 0x0013E}, {0x0013F, 0x00140}, {0x00141, 0x00142}, {0x00143, 0x00144}, -{0x00145, 0x00146}, {0x00147, 0x00148}, {0x0014A, 0x0014B}, {0x0014C, 0x0014D}, {0x0014E, 0x0014F}, {0x00150, 0x00151}, -{0x00152, 0x00153}, {0x00154, 0x00155}, {0x00156, 0x00157}, {0x00158, 0x00159}, {0x0015A, 0x0015B}, {0x0015C, 0x0015D}, -{0x0015E, 0x0015F}, {0x00160, 0x00161}, {0x00162, 0x00163}, {0x00164, 0x00165}, {0x00166, 0x00167}, {0x00168, 0x00169}, -{0x0016A, 0x0016B}, {0x0016C, 0x0016D}, {0x0016E, 0x0016F}, {0x00170, 0x00171}, {0x00172, 0x00173}, {0x00174, 0x00175}, -{0x00176, 0x00177}, {0x00178, 0x000FF}, {0x00179, 0x0017A}, {0x0017B, 0x0017C}, {0x0017D, 0x0017E}, {0x00181, 0x00253}, -{0x00182, 0x00183}, {0x00184, 0x00185}, {0x00186, 0x00254}, {0x00187, 0x00188}, {0x00189, 0x00256}, {0x0018A, 0x00257}, -{0x0018B, 0x0018C}, {0x0018E, 0x001DD}, {0x0018F, 0x00259}, {0x00190, 0x0025B}, {0x00191, 0x00192}, {0x00193, 0x00260}, -{0x00194, 0x00263}, {0x00196, 0x00269}, {0x00197, 0x00268}, {0x00198, 0x00199}, {0x0019C, 0x0026F}, {0x0019D, 0x00272}, -{0x0019F, 0x00275}, {0x001A0, 0x001A1}, {0x001A2, 0x001A3}, {0x001A4, 0x001A5}, {0x001A6, 0x00280}, {0x001A7, 0x001A8}, -{0x001A9, 0x00283}, {0x001AC, 0x001AD}, {0x001AE, 0x00288}, {0x001AF, 0x001B0}, {0x001B1, 0x0028A}, {0x001B2, 0x0028B}, -{0x001B3, 0x001B4}, {0x001B5, 0x001B6}, {0x001B7, 0x00292}, {0x001B8, 0x001B9}, {0x001BC, 0x001BD}, {0x001C4, 0x001C6}, -{0x001C5, 0x001C6}, {0x001C7, 0x001C9}, {0x001C8, 0x001C9}, {0x001CA, 0x001CC}, {0x001CB, 0x001CC}, {0x001CD, 0x001CE}, -{0x001CF, 0x001D0}, {0x001D1, 0x001D2}, {0x001D3, 0x001D4}, {0x001D5, 0x001D6}, {0x001D7, 0x001D8}, {0x001D9, 0x001DA}, -{0x001DB, 0x001DC}, {0x001DE, 0x001DF}, {0x001E0, 0x001E1}, {0x001E2, 0x001E3}, {0x001E4, 0x001E5}, {0x001E6, 0x001E7}, -{0x001E8, 0x001E9}, {0x001EA, 0x001EB}, {0x001EC, 0x001ED}, {0x001EE, 0x001EF}, {0x001F1, 0x001F3}, {0x001F2, 0x001F3}, -{0x001F4, 0x001F5}, {0x001F6, 0x00195}, {0x001F7, 0x001BF}, {0x001F8, 0x001F9}, {0x001FA, 0x001FB}, {0x001FC, 0x001FD}, -{0x001FE, 0x001FF}, {0x00200, 0x00201}, {0x00202, 0x00203}, {0x00204, 0x00205}, {0x00206, 0x00207}, {0x00208, 0x00209}, -{0x0020A, 0x0020B}, {0x0020C, 0x0020D}, {0x0020E, 0x0020F}, {0x00210, 0x00211}, {0x00212, 0x00213}, {0x00214, 0x00215}, -{0x00216, 0x00217}, {0x00218, 0x00219}, {0x0021A, 0x0021B}, {0x0021C, 0x0021D}, {0x0021E, 0x0021F}, {0x00220, 0x0019E}, -{0x00222, 0x00223}, {0x00224, 0x00225}, {0x00226, 0x00227}, {0x00228, 0x00229}, {0x0022A, 0x0022B}, {0x0022C, 0x0022D}, -{0x0022E, 0x0022F}, {0x00230, 0x00231}, {0x00232, 0x00233}, {0x0023A, 0x02C65}, {0x0023B, 0x0023C}, {0x0023D, 0x0019A}, -{0x0023E, 0x02C66}, {0x00241, 0x00242}, {0x00243, 0x00180}, {0x00244, 0x00289}, {0x00245, 0x0028C}, {0x00246, 0x00247}, -{0x00248, 0x00249}, {0x0024A, 0x0024B}, {0x0024C, 0x0024D}, {0x0024E, 0x0024F}, {0x00370, 0x00371}, {0x00372, 0x00373}, -{0x00376, 0x00377}, {0x0037F, 0x003F3}, {0x00386, 0x003AC}, {0x00388, 0x003AD}, {0x00389, 0x003AE}, {0x0038A, 0x003AF}, -{0x0038C, 0x003CC}, {0x0038E, 0x003CD}, {0x0038F, 0x003CE}, {0x00391, 0x003B1}, {0x00392, 0x003B2}, {0x00393, 0x003B3}, -{0x00394, 0x003B4}, {0x00395, 0x003B5}, {0x00396, 0x003B6}, {0x00397, 0x003B7}, {0x00398, 0x003B8}, {0x00399, 0x003B9}, -{0x0039A, 0x003BA}, {0x0039B, 0x003BB}, {0x0039C, 0x003BC}, {0x0039D, 0x003BD}, {0x0039E, 0x003BE}, {0x0039F, 0x003BF}, -{0x003A0, 0x003C0}, {0x003A1, 0x003C1}, {0x003A3, 0x003C3}, {0x003A4, 0x003C4}, {0x003A5, 0x003C5}, {0x003A6, 0x003C6}, -{0x003A7, 0x003C7}, {0x003A8, 0x003C8}, {0x003A9, 0x003C9}, {0x003AA, 0x003CA}, {0x003AB, 0x003CB}, {0x003CF, 0x003D7}, -{0x003D8, 0x003D9}, {0x003DA, 0x003DB}, {0x003DC, 0x003DD}, {0x003DE, 0x003DF}, {0x003E0, 0x003E1}, {0x003E2, 0x003E3}, -{0x003E4, 0x003E5}, {0x003E6, 0x003E7}, {0x003E8, 0x003E9}, {0x003EA, 0x003EB}, {0x003EC, 0x003ED}, {0x003EE, 0x003EF}, -{0x003F4, 0x003B8}, {0x003F7, 0x003F8}, {0x003F9, 0x003F2}, {0x003FA, 0x003FB}, {0x003FD, 0x0037B}, {0x003FE, 0x0037C}, -{0x003FF, 0x0037D}, {0x00400, 0x00450}, {0x00401, 0x00451}, {0x00402, 0x00452}, {0x00403, 0x00453}, {0x00404, 0x00454}, -{0x00405, 0x00455}, {0x00406, 0x00456}, {0x00407, 0x00457}, {0x00408, 0x00458}, {0x00409, 0x00459}, {0x0040A, 0x0045A}, -{0x0040B, 0x0045B}, {0x0040C, 0x0045C}, {0x0040D, 0x0045D}, {0x0040E, 0x0045E}, {0x0040F, 0x0045F}, {0x00410, 0x00430}, -{0x00411, 0x00431}, {0x00412, 0x00432}, {0x00413, 0x00433}, {0x00414, 0x00434}, {0x00415, 0x00435}, {0x00416, 0x00436}, -{0x00417, 0x00437}, {0x00418, 0x00438}, {0x00419, 0x00439}, {0x0041A, 0x0043A}, {0x0041B, 0x0043B}, {0x0041C, 0x0043C}, -{0x0041D, 0x0043D}, {0x0041E, 0x0043E}, {0x0041F, 0x0043F}, {0x00420, 0x00440}, {0x00421, 0x00441}, {0x00422, 0x00442}, -{0x00423, 0x00443}, {0x00424, 0x00444}, {0x00425, 0x00445}, {0x00426, 0x00446}, {0x00427, 0x00447}, {0x00428, 0x00448}, -{0x00429, 0x00449}, {0x0042A, 0x0044A}, {0x0042B, 0x0044B}, {0x0042C, 0x0044C}, {0x0042D, 0x0044D}, {0x0042E, 0x0044E}, -{0x0042F, 0x0044F}, {0x00460, 0x00461}, {0x00462, 0x00463}, {0x00464, 0x00465}, {0x00466, 0x00467}, {0x00468, 0x00469}, -{0x0046A, 0x0046B}, {0x0046C, 0x0046D}, {0x0046E, 0x0046F}, {0x00470, 0x00471}, {0x00472, 0x00473}, {0x00474, 0x00475}, -{0x00476, 0x00477}, {0x00478, 0x00479}, {0x0047A, 0x0047B}, {0x0047C, 0x0047D}, {0x0047E, 0x0047F}, {0x00480, 0x00481}, -{0x0048A, 0x0048B}, {0x0048C, 0x0048D}, {0x0048E, 0x0048F}, {0x00490, 0x00491}, {0x00492, 0x00493}, {0x00494, 0x00495}, -{0x00496, 0x00497}, {0x00498, 0x00499}, {0x0049A, 0x0049B}, {0x0049C, 0x0049D}, {0x0049E, 0x0049F}, {0x004A0, 0x004A1}, -{0x004A2, 0x004A3}, {0x004A4, 0x004A5}, {0x004A6, 0x004A7}, {0x004A8, 0x004A9}, {0x004AA, 0x004AB}, {0x004AC, 0x004AD}, -{0x004AE, 0x004AF}, {0x004B0, 0x004B1}, {0x004B2, 0x004B3}, {0x004B4, 0x004B5}, {0x004B6, 0x004B7}, {0x004B8, 0x004B9}, -{0x004BA, 0x004BB}, {0x004BC, 0x004BD}, {0x004BE, 0x004BF}, {0x004C0, 0x004CF}, {0x004C1, 0x004C2}, {0x004C3, 0x004C4}, -{0x004C5, 0x004C6}, {0x004C7, 0x004C8}, {0x004C9, 0x004CA}, {0x004CB, 0x004CC}, {0x004CD, 0x004CE}, {0x004D0, 0x004D1}, -{0x004D2, 0x004D3}, {0x004D4, 0x004D5}, {0x004D6, 0x004D7}, {0x004D8, 0x004D9}, {0x004DA, 0x004DB}, {0x004DC, 0x004DD}, -{0x004DE, 0x004DF}, {0x004E0, 0x004E1}, {0x004E2, 0x004E3}, {0x004E4, 0x004E5}, {0x004E6, 0x004E7}, {0x004E8, 0x004E9}, -{0x004EA, 0x004EB}, {0x004EC, 0x004ED}, {0x004EE, 0x004EF}, {0x004F0, 0x004F1}, {0x004F2, 0x004F3}, {0x004F4, 0x004F5}, -{0x004F6, 0x004F7}, {0x004F8, 0x004F9}, {0x004FA, 0x004FB}, {0x004FC, 0x004FD}, {0x004FE, 0x004FF}, {0x00500, 0x00501}, -{0x00502, 0x00503}, {0x00504, 0x00505}, {0x00506, 0x00507}, {0x00508, 0x00509}, {0x0050A, 0x0050B}, {0x0050C, 0x0050D}, -{0x0050E, 0x0050F}, {0x00510, 0x00511}, {0x00512, 0x00513}, {0x00514, 0x00515}, {0x00516, 0x00517}, {0x00518, 0x00519}, -{0x0051A, 0x0051B}, {0x0051C, 0x0051D}, {0x0051E, 0x0051F}, {0x00520, 0x00521}, {0x00522, 0x00523}, {0x00524, 0x00525}, -{0x00526, 0x00527}, {0x00528, 0x00529}, {0x0052A, 0x0052B}, {0x0052C, 0x0052D}, {0x0052E, 0x0052F}, {0x00531, 0x00561}, -{0x00532, 0x00562}, {0x00533, 0x00563}, {0x00534, 0x00564}, {0x00535, 0x00565}, {0x00536, 0x00566}, {0x00537, 0x00567}, -{0x00538, 0x00568}, {0x00539, 0x00569}, {0x0053A, 0x0056A}, {0x0053B, 0x0056B}, {0x0053C, 0x0056C}, {0x0053D, 0x0056D}, -{0x0053E, 0x0056E}, {0x0053F, 0x0056F}, {0x00540, 0x00570}, {0x00541, 0x00571}, {0x00542, 0x00572}, {0x00543, 0x00573}, -{0x00544, 0x00574}, {0x00545, 0x00575}, {0x00546, 0x00576}, {0x00547, 0x00577}, {0x00548, 0x00578}, {0x00549, 0x00579}, -{0x0054A, 0x0057A}, {0x0054B, 0x0057B}, {0x0054C, 0x0057C}, {0x0054D, 0x0057D}, {0x0054E, 0x0057E}, {0x0054F, 0x0057F}, -{0x00550, 0x00580}, {0x00551, 0x00581}, {0x00552, 0x00582}, {0x00553, 0x00583}, {0x00554, 0x00584}, {0x00555, 0x00585}, -{0x00556, 0x00586}, {0x010A0, 0x02D00}, {0x010A1, 0x02D01}, {0x010A2, 0x02D02}, {0x010A3, 0x02D03}, {0x010A4, 0x02D04}, -{0x010A5, 0x02D05}, {0x010A6, 0x02D06}, {0x010A7, 0x02D07}, {0x010A8, 0x02D08}, {0x010A9, 0x02D09}, {0x010AA, 0x02D0A}, -{0x010AB, 0x02D0B}, {0x010AC, 0x02D0C}, {0x010AD, 0x02D0D}, {0x010AE, 0x02D0E}, {0x010AF, 0x02D0F}, {0x010B0, 0x02D10}, -{0x010B1, 0x02D11}, {0x010B2, 0x02D12}, {0x010B3, 0x02D13}, {0x010B4, 0x02D14}, {0x010B5, 0x02D15}, {0x010B6, 0x02D16}, -{0x010B7, 0x02D17}, {0x010B8, 0x02D18}, {0x010B9, 0x02D19}, {0x010BA, 0x02D1A}, {0x010BB, 0x02D1B}, {0x010BC, 0x02D1C}, -{0x010BD, 0x02D1D}, {0x010BE, 0x02D1E}, {0x010BF, 0x02D1F}, {0x010C0, 0x02D20}, {0x010C1, 0x02D21}, {0x010C2, 0x02D22}, -{0x010C3, 0x02D23}, {0x010C4, 0x02D24}, {0x010C5, 0x02D25}, {0x010C7, 0x02D27}, {0x010CD, 0x02D2D}, {0x013A0, 0x0AB70}, -{0x013A1, 0x0AB71}, {0x013A2, 0x0AB72}, {0x013A3, 0x0AB73}, {0x013A4, 0x0AB74}, {0x013A5, 0x0AB75}, {0x013A6, 0x0AB76}, -{0x013A7, 0x0AB77}, {0x013A8, 0x0AB78}, {0x013A9, 0x0AB79}, {0x013AA, 0x0AB7A}, {0x013AB, 0x0AB7B}, {0x013AC, 0x0AB7C}, -{0x013AD, 0x0AB7D}, {0x013AE, 0x0AB7E}, {0x013AF, 0x0AB7F}, {0x013B0, 0x0AB80}, {0x013B1, 0x0AB81}, {0x013B2, 0x0AB82}, -{0x013B3, 0x0AB83}, {0x013B4, 0x0AB84}, {0x013B5, 0x0AB85}, {0x013B6, 0x0AB86}, {0x013B7, 0x0AB87}, {0x013B8, 0x0AB88}, -{0x013B9, 0x0AB89}, {0x013BA, 0x0AB8A}, {0x013BB, 0x0AB8B}, {0x013BC, 0x0AB8C}, {0x013BD, 0x0AB8D}, {0x013BE, 0x0AB8E}, -{0x013BF, 0x0AB8F}, {0x013C0, 0x0AB90}, {0x013C1, 0x0AB91}, {0x013C2, 0x0AB92}, {0x013C3, 0x0AB93}, {0x013C4, 0x0AB94}, -{0x013C5, 0x0AB95}, {0x013C6, 0x0AB96}, {0x013C7, 0x0AB97}, {0x013C8, 0x0AB98}, {0x013C9, 0x0AB99}, {0x013CA, 0x0AB9A}, -{0x013CB, 0x0AB9B}, {0x013CC, 0x0AB9C}, {0x013CD, 0x0AB9D}, {0x013CE, 0x0AB9E}, {0x013CF, 0x0AB9F}, {0x013D0, 0x0ABA0}, -{0x013D1, 0x0ABA1}, {0x013D2, 0x0ABA2}, {0x013D3, 0x0ABA3}, {0x013D4, 0x0ABA4}, {0x013D5, 0x0ABA5}, {0x013D6, 0x0ABA6}, -{0x013D7, 0x0ABA7}, {0x013D8, 0x0ABA8}, {0x013D9, 0x0ABA9}, {0x013DA, 0x0ABAA}, {0x013DB, 0x0ABAB}, {0x013DC, 0x0ABAC}, -{0x013DD, 0x0ABAD}, {0x013DE, 0x0ABAE}, {0x013DF, 0x0ABAF}, {0x013E0, 0x0ABB0}, {0x013E1, 0x0ABB1}, {0x013E2, 0x0ABB2}, -{0x013E3, 0x0ABB3}, {0x013E4, 0x0ABB4}, {0x013E5, 0x0ABB5}, {0x013E6, 0x0ABB6}, {0x013E7, 0x0ABB7}, {0x013E8, 0x0ABB8}, -{0x013E9, 0x0ABB9}, {0x013EA, 0x0ABBA}, {0x013EB, 0x0ABBB}, {0x013EC, 0x0ABBC}, {0x013ED, 0x0ABBD}, {0x013EE, 0x0ABBE}, -{0x013EF, 0x0ABBF}, {0x013F0, 0x013F8}, {0x013F1, 0x013F9}, {0x013F2, 0x013FA}, {0x013F3, 0x013FB}, {0x013F4, 0x013FC}, -{0x013F5, 0x013FD}, {0x01C90, 0x010D0}, {0x01C91, 0x010D1}, {0x01C92, 0x010D2}, {0x01C93, 0x010D3}, {0x01C94, 0x010D4}, -{0x01C95, 0x010D5}, {0x01C96, 0x010D6}, {0x01C97, 0x010D7}, {0x01C98, 0x010D8}, {0x01C99, 0x010D9}, {0x01C9A, 0x010DA}, -{0x01C9B, 0x010DB}, {0x01C9C, 0x010DC}, {0x01C9D, 0x010DD}, {0x01C9E, 0x010DE}, {0x01C9F, 0x010DF}, {0x01CA0, 0x010E0}, -{0x01CA1, 0x010E1}, {0x01CA2, 0x010E2}, {0x01CA3, 0x010E3}, {0x01CA4, 0x010E4}, {0x01CA5, 0x010E5}, {0x01CA6, 0x010E6}, -{0x01CA7, 0x010E7}, {0x01CA8, 0x010E8}, {0x01CA9, 0x010E9}, {0x01CAA, 0x010EA}, {0x01CAB, 0x010EB}, {0x01CAC, 0x010EC}, -{0x01CAD, 0x010ED}, {0x01CAE, 0x010EE}, {0x01CAF, 0x010EF}, {0x01CB0, 0x010F0}, {0x01CB1, 0x010F1}, {0x01CB2, 0x010F2}, -{0x01CB3, 0x010F3}, {0x01CB4, 0x010F4}, {0x01CB5, 0x010F5}, {0x01CB6, 0x010F6}, {0x01CB7, 0x010F7}, {0x01CB8, 0x010F8}, -{0x01CB9, 0x010F9}, {0x01CBA, 0x010FA}, {0x01CBD, 0x010FD}, {0x01CBE, 0x010FE}, {0x01CBF, 0x010FF}, {0x01E00, 0x01E01}, -{0x01E02, 0x01E03}, {0x01E04, 0x01E05}, {0x01E06, 0x01E07}, {0x01E08, 0x01E09}, {0x01E0A, 0x01E0B}, {0x01E0C, 0x01E0D}, -{0x01E0E, 0x01E0F}, {0x01E10, 0x01E11}, {0x01E12, 0x01E13}, {0x01E14, 0x01E15}, {0x01E16, 0x01E17}, {0x01E18, 0x01E19}, -{0x01E1A, 0x01E1B}, {0x01E1C, 0x01E1D}, {0x01E1E, 0x01E1F}, {0x01E20, 0x01E21}, {0x01E22, 0x01E23}, {0x01E24, 0x01E25}, -{0x01E26, 0x01E27}, {0x01E28, 0x01E29}, {0x01E2A, 0x01E2B}, {0x01E2C, 0x01E2D}, {0x01E2E, 0x01E2F}, {0x01E30, 0x01E31}, -{0x01E32, 0x01E33}, {0x01E34, 0x01E35}, {0x01E36, 0x01E37}, {0x01E38, 0x01E39}, {0x01E3A, 0x01E3B}, {0x01E3C, 0x01E3D}, -{0x01E3E, 0x01E3F}, {0x01E40, 0x01E41}, {0x01E42, 0x01E43}, {0x01E44, 0x01E45}, {0x01E46, 0x01E47}, {0x01E48, 0x01E49}, -{0x01E4A, 0x01E4B}, {0x01E4C, 0x01E4D}, {0x01E4E, 0x01E4F}, {0x01E50, 0x01E51}, {0x01E52, 0x01E53}, {0x01E54, 0x01E55}, -{0x01E56, 0x01E57}, {0x01E58, 0x01E59}, {0x01E5A, 0x01E5B}, {0x01E5C, 0x01E5D}, {0x01E5E, 0x01E5F}, {0x01E60, 0x01E61}, -{0x01E62, 0x01E63}, {0x01E64, 0x01E65}, {0x01E66, 0x01E67}, {0x01E68, 0x01E69}, {0x01E6A, 0x01E6B}, {0x01E6C, 0x01E6D}, -{0x01E6E, 0x01E6F}, {0x01E70, 0x01E71}, {0x01E72, 0x01E73}, {0x01E74, 0x01E75}, {0x01E76, 0x01E77}, {0x01E78, 0x01E79}, -{0x01E7A, 0x01E7B}, {0x01E7C, 0x01E7D}, {0x01E7E, 0x01E7F}, {0x01E80, 0x01E81}, {0x01E82, 0x01E83}, {0x01E84, 0x01E85}, -{0x01E86, 0x01E87}, {0x01E88, 0x01E89}, {0x01E8A, 0x01E8B}, {0x01E8C, 0x01E8D}, {0x01E8E, 0x01E8F}, {0x01E90, 0x01E91}, -{0x01E92, 0x01E93}, {0x01E94, 0x01E95}, {0x01E9E, 0x000DF}, {0x01EA0, 0x01EA1}, {0x01EA2, 0x01EA3}, {0x01EA4, 0x01EA5}, -{0x01EA6, 0x01EA7}, {0x01EA8, 0x01EA9}, {0x01EAA, 0x01EAB}, {0x01EAC, 0x01EAD}, {0x01EAE, 0x01EAF}, {0x01EB0, 0x01EB1}, -{0x01EB2, 0x01EB3}, {0x01EB4, 0x01EB5}, {0x01EB6, 0x01EB7}, {0x01EB8, 0x01EB9}, {0x01EBA, 0x01EBB}, {0x01EBC, 0x01EBD}, -{0x01EBE, 0x01EBF}, {0x01EC0, 0x01EC1}, {0x01EC2, 0x01EC3}, {0x01EC4, 0x01EC5}, {0x01EC6, 0x01EC7}, {0x01EC8, 0x01EC9}, -{0x01ECA, 0x01ECB}, {0x01ECC, 0x01ECD}, {0x01ECE, 0x01ECF}, {0x01ED0, 0x01ED1}, {0x01ED2, 0x01ED3}, {0x01ED4, 0x01ED5}, -{0x01ED6, 0x01ED7}, {0x01ED8, 0x01ED9}, {0x01EDA, 0x01EDB}, {0x01EDC, 0x01EDD}, {0x01EDE, 0x01EDF}, {0x01EE0, 0x01EE1}, -{0x01EE2, 0x01EE3}, {0x01EE4, 0x01EE5}, {0x01EE6, 0x01EE7}, {0x01EE8, 0x01EE9}, {0x01EEA, 0x01EEB}, {0x01EEC, 0x01EED}, -{0x01EEE, 0x01EEF}, {0x01EF0, 0x01EF1}, {0x01EF2, 0x01EF3}, {0x01EF4, 0x01EF5}, {0x01EF6, 0x01EF7}, {0x01EF8, 0x01EF9}, -{0x01EFA, 0x01EFB}, {0x01EFC, 0x01EFD}, {0x01EFE, 0x01EFF}, {0x01F08, 0x01F00}, {0x01F09, 0x01F01}, {0x01F0A, 0x01F02}, -{0x01F0B, 0x01F03}, {0x01F0C, 0x01F04}, {0x01F0D, 0x01F05}, {0x01F0E, 0x01F06}, {0x01F0F, 0x01F07}, {0x01F18, 0x01F10}, -{0x01F19, 0x01F11}, {0x01F1A, 0x01F12}, {0x01F1B, 0x01F13}, {0x01F1C, 0x01F14}, {0x01F1D, 0x01F15}, {0x01F28, 0x01F20}, -{0x01F29, 0x01F21}, {0x01F2A, 0x01F22}, {0x01F2B, 0x01F23}, {0x01F2C, 0x01F24}, {0x01F2D, 0x01F25}, {0x01F2E, 0x01F26}, -{0x01F2F, 0x01F27}, {0x01F38, 0x01F30}, {0x01F39, 0x01F31}, {0x01F3A, 0x01F32}, {0x01F3B, 0x01F33}, {0x01F3C, 0x01F34}, -{0x01F3D, 0x01F35}, {0x01F3E, 0x01F36}, {0x01F3F, 0x01F37}, {0x01F48, 0x01F40}, {0x01F49, 0x01F41}, {0x01F4A, 0x01F42}, -{0x01F4B, 0x01F43}, {0x01F4C, 0x01F44}, {0x01F4D, 0x01F45}, {0x01F59, 0x01F51}, {0x01F5B, 0x01F53}, {0x01F5D, 0x01F55}, -{0x01F5F, 0x01F57}, {0x01F68, 0x01F60}, {0x01F69, 0x01F61}, {0x01F6A, 0x01F62}, {0x01F6B, 0x01F63}, {0x01F6C, 0x01F64}, -{0x01F6D, 0x01F65}, {0x01F6E, 0x01F66}, {0x01F6F, 0x01F67}, {0x01F88, 0x01F80}, {0x01F89, 0x01F81}, {0x01F8A, 0x01F82}, -{0x01F8B, 0x01F83}, {0x01F8C, 0x01F84}, {0x01F8D, 0x01F85}, {0x01F8E, 0x01F86}, {0x01F8F, 0x01F87}, {0x01F98, 0x01F90}, -{0x01F99, 0x01F91}, {0x01F9A, 0x01F92}, {0x01F9B, 0x01F93}, {0x01F9C, 0x01F94}, {0x01F9D, 0x01F95}, {0x01F9E, 0x01F96}, -{0x01F9F, 0x01F97}, {0x01FA8, 0x01FA0}, {0x01FA9, 0x01FA1}, {0x01FAA, 0x01FA2}, {0x01FAB, 0x01FA3}, {0x01FAC, 0x01FA4}, -{0x01FAD, 0x01FA5}, {0x01FAE, 0x01FA6}, {0x01FAF, 0x01FA7}, {0x01FB8, 0x01FB0}, {0x01FB9, 0x01FB1}, {0x01FBA, 0x01F70}, -{0x01FBB, 0x01F71}, {0x01FBC, 0x01FB3}, {0x01FC8, 0x01F72}, {0x01FC9, 0x01F73}, {0x01FCA, 0x01F74}, {0x01FCB, 0x01F75}, -{0x01FCC, 0x01FC3}, {0x01FD8, 0x01FD0}, {0x01FD9, 0x01FD1}, {0x01FDA, 0x01F76}, {0x01FDB, 0x01F77}, {0x01FE8, 0x01FE0}, -{0x01FE9, 0x01FE1}, {0x01FEA, 0x01F7A}, {0x01FEB, 0x01F7B}, {0x01FEC, 0x01FE5}, {0x01FF8, 0x01F78}, {0x01FF9, 0x01F79}, -{0x01FFA, 0x01F7C}, {0x01FFB, 0x01F7D}, {0x01FFC, 0x01FF3}, {0x02126, 0x003C9}, {0x0212A, 0x0006B}, {0x0212B, 0x000E5}, -{0x02132, 0x0214E}, {0x02160, 0x02170}, {0x02161, 0x02171}, {0x02162, 0x02172}, {0x02163, 0x02173}, {0x02164, 0x02174}, -{0x02165, 0x02175}, {0x02166, 0x02176}, {0x02167, 0x02177}, {0x02168, 0x02178}, {0x02169, 0x02179}, {0x0216A, 0x0217A}, -{0x0216B, 0x0217B}, {0x0216C, 0x0217C}, {0x0216D, 0x0217D}, {0x0216E, 0x0217E}, {0x0216F, 0x0217F}, {0x02183, 0x02184}, -{0x024B6, 0x024D0}, {0x024B7, 0x024D1}, {0x024B8, 0x024D2}, {0x024B9, 0x024D3}, {0x024BA, 0x024D4}, {0x024BB, 0x024D5}, -{0x024BC, 0x024D6}, {0x024BD, 0x024D7}, {0x024BE, 0x024D8}, {0x024BF, 0x024D9}, {0x024C0, 0x024DA}, {0x024C1, 0x024DB}, -{0x024C2, 0x024DC}, {0x024C3, 0x024DD}, {0x024C4, 0x024DE}, {0x024C5, 0x024DF}, {0x024C6, 0x024E0}, {0x024C7, 0x024E1}, -{0x024C8, 0x024E2}, {0x024C9, 0x024E3}, {0x024CA, 0x024E4}, {0x024CB, 0x024E5}, {0x024CC, 0x024E6}, {0x024CD, 0x024E7}, -{0x024CE, 0x024E8}, {0x024CF, 0x024E9}, {0x02C00, 0x02C30}, {0x02C01, 0x02C31}, {0x02C02, 0x02C32}, {0x02C03, 0x02C33}, -{0x02C04, 0x02C34}, {0x02C05, 0x02C35}, {0x02C06, 0x02C36}, {0x02C07, 0x02C37}, {0x02C08, 0x02C38}, {0x02C09, 0x02C39}, -{0x02C0A, 0x02C3A}, {0x02C0B, 0x02C3B}, {0x02C0C, 0x02C3C}, {0x02C0D, 0x02C3D}, {0x02C0E, 0x02C3E}, {0x02C0F, 0x02C3F}, -{0x02C10, 0x02C40}, {0x02C11, 0x02C41}, {0x02C12, 0x02C42}, {0x02C13, 0x02C43}, {0x02C14, 0x02C44}, {0x02C15, 0x02C45}, -{0x02C16, 0x02C46}, {0x02C17, 0x02C47}, {0x02C18, 0x02C48}, {0x02C19, 0x02C49}, {0x02C1A, 0x02C4A}, {0x02C1B, 0x02C4B}, -{0x02C1C, 0x02C4C}, {0x02C1D, 0x02C4D}, {0x02C1E, 0x02C4E}, {0x02C1F, 0x02C4F}, {0x02C20, 0x02C50}, {0x02C21, 0x02C51}, -{0x02C22, 0x02C52}, {0x02C23, 0x02C53}, {0x02C24, 0x02C54}, {0x02C25, 0x02C55}, {0x02C26, 0x02C56}, {0x02C27, 0x02C57}, -{0x02C28, 0x02C58}, {0x02C29, 0x02C59}, {0x02C2A, 0x02C5A}, {0x02C2B, 0x02C5B}, {0x02C2C, 0x02C5C}, {0x02C2D, 0x02C5D}, -{0x02C2E, 0x02C5E}, {0x02C2F, 0x02C5F}, {0x02C60, 0x02C61}, {0x02C62, 0x0026B}, {0x02C63, 0x01D7D}, {0x02C64, 0x0027D}, -{0x02C67, 0x02C68}, {0x02C69, 0x02C6A}, {0x02C6B, 0x02C6C}, {0x02C6D, 0x00251}, {0x02C6E, 0x00271}, {0x02C6F, 0x00250}, -{0x02C70, 0x00252}, {0x02C72, 0x02C73}, {0x02C75, 0x02C76}, {0x02C7E, 0x0023F}, {0x02C7F, 0x00240}, {0x02C80, 0x02C81}, -{0x02C82, 0x02C83}, {0x02C84, 0x02C85}, {0x02C86, 0x02C87}, {0x02C88, 0x02C89}, {0x02C8A, 0x02C8B}, {0x02C8C, 0x02C8D}, -{0x02C8E, 0x02C8F}, {0x02C90, 0x02C91}, {0x02C92, 0x02C93}, {0x02C94, 0x02C95}, {0x02C96, 0x02C97}, {0x02C98, 0x02C99}, -{0x02C9A, 0x02C9B}, {0x02C9C, 0x02C9D}, {0x02C9E, 0x02C9F}, {0x02CA0, 0x02CA1}, {0x02CA2, 0x02CA3}, {0x02CA4, 0x02CA5}, -{0x02CA6, 0x02CA7}, {0x02CA8, 0x02CA9}, {0x02CAA, 0x02CAB}, {0x02CAC, 0x02CAD}, {0x02CAE, 0x02CAF}, {0x02CB0, 0x02CB1}, -{0x02CB2, 0x02CB3}, {0x02CB4, 0x02CB5}, {0x02CB6, 0x02CB7}, {0x02CB8, 0x02CB9}, {0x02CBA, 0x02CBB}, {0x02CBC, 0x02CBD}, -{0x02CBE, 0x02CBF}, {0x02CC0, 0x02CC1}, {0x02CC2, 0x02CC3}, {0x02CC4, 0x02CC5}, {0x02CC6, 0x02CC7}, {0x02CC8, 0x02CC9}, -{0x02CCA, 0x02CCB}, {0x02CCC, 0x02CCD}, {0x02CCE, 0x02CCF}, {0x02CD0, 0x02CD1}, {0x02CD2, 0x02CD3}, {0x02CD4, 0x02CD5}, -{0x02CD6, 0x02CD7}, {0x02CD8, 0x02CD9}, {0x02CDA, 0x02CDB}, {0x02CDC, 0x02CDD}, {0x02CDE, 0x02CDF}, {0x02CE0, 0x02CE1}, -{0x02CE2, 0x02CE3}, {0x02CEB, 0x02CEC}, {0x02CED, 0x02CEE}, {0x02CF2, 0x02CF3}, {0x0A640, 0x0A641}, {0x0A642, 0x0A643}, -{0x0A644, 0x0A645}, {0x0A646, 0x0A647}, {0x0A648, 0x0A649}, {0x0A64A, 0x0A64B}, {0x0A64C, 0x0A64D}, {0x0A64E, 0x0A64F}, -{0x0A650, 0x0A651}, {0x0A652, 0x0A653}, {0x0A654, 0x0A655}, {0x0A656, 0x0A657}, {0x0A658, 0x0A659}, {0x0A65A, 0x0A65B}, -{0x0A65C, 0x0A65D}, {0x0A65E, 0x0A65F}, {0x0A660, 0x0A661}, {0x0A662, 0x0A663}, {0x0A664, 0x0A665}, {0x0A666, 0x0A667}, -{0x0A668, 0x0A669}, {0x0A66A, 0x0A66B}, {0x0A66C, 0x0A66D}, {0x0A680, 0x0A681}, {0x0A682, 0x0A683}, {0x0A684, 0x0A685}, -{0x0A686, 0x0A687}, {0x0A688, 0x0A689}, {0x0A68A, 0x0A68B}, {0x0A68C, 0x0A68D}, {0x0A68E, 0x0A68F}, {0x0A690, 0x0A691}, -{0x0A692, 0x0A693}, {0x0A694, 0x0A695}, {0x0A696, 0x0A697}, {0x0A698, 0x0A699}, {0x0A69A, 0x0A69B}, {0x0A722, 0x0A723}, -{0x0A724, 0x0A725}, {0x0A726, 0x0A727}, {0x0A728, 0x0A729}, {0x0A72A, 0x0A72B}, {0x0A72C, 0x0A72D}, {0x0A72E, 0x0A72F}, -{0x0A732, 0x0A733}, {0x0A734, 0x0A735}, {0x0A736, 0x0A737}, {0x0A738, 0x0A739}, {0x0A73A, 0x0A73B}, {0x0A73C, 0x0A73D}, -{0x0A73E, 0x0A73F}, {0x0A740, 0x0A741}, {0x0A742, 0x0A743}, {0x0A744, 0x0A745}, {0x0A746, 0x0A747}, {0x0A748, 0x0A749}, -{0x0A74A, 0x0A74B}, {0x0A74C, 0x0A74D}, {0x0A74E, 0x0A74F}, {0x0A750, 0x0A751}, {0x0A752, 0x0A753}, {0x0A754, 0x0A755}, -{0x0A756, 0x0A757}, {0x0A758, 0x0A759}, {0x0A75A, 0x0A75B}, {0x0A75C, 0x0A75D}, {0x0A75E, 0x0A75F}, {0x0A760, 0x0A761}, -{0x0A762, 0x0A763}, {0x0A764, 0x0A765}, {0x0A766, 0x0A767}, {0x0A768, 0x0A769}, {0x0A76A, 0x0A76B}, {0x0A76C, 0x0A76D}, -{0x0A76E, 0x0A76F}, {0x0A779, 0x0A77A}, {0x0A77B, 0x0A77C}, {0x0A77D, 0x01D79}, {0x0A77E, 0x0A77F}, {0x0A780, 0x0A781}, -{0x0A782, 0x0A783}, {0x0A784, 0x0A785}, {0x0A786, 0x0A787}, {0x0A78B, 0x0A78C}, {0x0A78D, 0x00265}, {0x0A790, 0x0A791}, -{0x0A792, 0x0A793}, {0x0A796, 0x0A797}, {0x0A798, 0x0A799}, {0x0A79A, 0x0A79B}, {0x0A79C, 0x0A79D}, {0x0A79E, 0x0A79F}, -{0x0A7A0, 0x0A7A1}, {0x0A7A2, 0x0A7A3}, {0x0A7A4, 0x0A7A5}, {0x0A7A6, 0x0A7A7}, {0x0A7A8, 0x0A7A9}, {0x0A7AA, 0x00266}, -{0x0A7AB, 0x0025C}, {0x0A7AC, 0x00261}, {0x0A7AD, 0x0026C}, {0x0A7AE, 0x0026A}, {0x0A7B0, 0x0029E}, {0x0A7B1, 0x00287}, -{0x0A7B2, 0x0029D}, {0x0A7B3, 0x0AB53}, {0x0A7B4, 0x0A7B5}, {0x0A7B6, 0x0A7B7}, {0x0A7B8, 0x0A7B9}, {0x0A7BA, 0x0A7BB}, -{0x0A7BC, 0x0A7BD}, {0x0A7BE, 0x0A7BF}, {0x0A7C0, 0x0A7C1}, {0x0A7C2, 0x0A7C3}, {0x0A7C4, 0x0A794}, {0x0A7C5, 0x00282}, -{0x0A7C6, 0x01D8E}, {0x0A7C7, 0x0A7C8}, {0x0A7C9, 0x0A7CA}, {0x0A7D0, 0x0A7D1}, {0x0A7D6, 0x0A7D7}, {0x0A7D8, 0x0A7D9}, -{0x0A7F5, 0x0A7F6}, {0x0FF21, 0x0FF41}, {0x0FF22, 0x0FF42}, {0x0FF23, 0x0FF43}, {0x0FF24, 0x0FF44}, {0x0FF25, 0x0FF45}, -{0x0FF26, 0x0FF46}, {0x0FF27, 0x0FF47}, {0x0FF28, 0x0FF48}, {0x0FF29, 0x0FF49}, {0x0FF2A, 0x0FF4A}, {0x0FF2B, 0x0FF4B}, -{0x0FF2C, 0x0FF4C}, {0x0FF2D, 0x0FF4D}, {0x0FF2E, 0x0FF4E}, {0x0FF2F, 0x0FF4F}, {0x0FF30, 0x0FF50}, {0x0FF31, 0x0FF51}, -{0x0FF32, 0x0FF52}, {0x0FF33, 0x0FF53}, {0x0FF34, 0x0FF54}, {0x0FF35, 0x0FF55}, {0x0FF36, 0x0FF56}, {0x0FF37, 0x0FF57}, -{0x0FF38, 0x0FF58}, {0x0FF39, 0x0FF59}, {0x0FF3A, 0x0FF5A}, {0x10400, 0x10428}, {0x10401, 0x10429}, {0x10402, 0x1042A}, -{0x10403, 0x1042B}, {0x10404, 0x1042C}, {0x10405, 0x1042D}, {0x10406, 0x1042E}, {0x10407, 0x1042F}, {0x10408, 0x10430}, -{0x10409, 0x10431}, {0x1040A, 0x10432}, {0x1040B, 0x10433}, {0x1040C, 0x10434}, {0x1040D, 0x10435}, {0x1040E, 0x10436}, -{0x1040F, 0x10437}, {0x10410, 0x10438}, {0x10411, 0x10439}, {0x10412, 0x1043A}, {0x10413, 0x1043B}, {0x10414, 0x1043C}, -{0x10415, 0x1043D}, {0x10416, 0x1043E}, {0x10417, 0x1043F}, {0x10418, 0x10440}, {0x10419, 0x10441}, {0x1041A, 0x10442}, -{0x1041B, 0x10443}, {0x1041C, 0x10444}, {0x1041D, 0x10445}, {0x1041E, 0x10446}, {0x1041F, 0x10447}, {0x10420, 0x10448}, -{0x10421, 0x10449}, {0x10422, 0x1044A}, {0x10423, 0x1044B}, {0x10424, 0x1044C}, {0x10425, 0x1044D}, {0x10426, 0x1044E}, -{0x10427, 0x1044F}, {0x104B0, 0x104D8}, {0x104B1, 0x104D9}, {0x104B2, 0x104DA}, {0x104B3, 0x104DB}, {0x104B4, 0x104DC}, -{0x104B5, 0x104DD}, {0x104B6, 0x104DE}, {0x104B7, 0x104DF}, {0x104B8, 0x104E0}, {0x104B9, 0x104E1}, {0x104BA, 0x104E2}, -{0x104BB, 0x104E3}, {0x104BC, 0x104E4}, {0x104BD, 0x104E5}, {0x104BE, 0x104E6}, {0x104BF, 0x104E7}, {0x104C0, 0x104E8}, -{0x104C1, 0x104E9}, {0x104C2, 0x104EA}, {0x104C3, 0x104EB}, {0x104C4, 0x104EC}, {0x104C5, 0x104ED}, {0x104C6, 0x104EE}, -{0x104C7, 0x104EF}, {0x104C8, 0x104F0}, {0x104C9, 0x104F1}, {0x104CA, 0x104F2}, {0x104CB, 0x104F3}, {0x104CC, 0x104F4}, -{0x104CD, 0x104F5}, {0x104CE, 0x104F6}, {0x104CF, 0x104F7}, {0x104D0, 0x104F8}, {0x104D1, 0x104F9}, {0x104D2, 0x104FA}, -{0x104D3, 0x104FB}, {0x10570, 0x10597}, {0x10571, 0x10598}, {0x10572, 0x10599}, {0x10573, 0x1059A}, {0x10574, 0x1059B}, -{0x10575, 0x1059C}, {0x10576, 0x1059D}, {0x10577, 0x1059E}, {0x10578, 0x1059F}, {0x10579, 0x105A0}, {0x1057A, 0x105A1}, -{0x1057C, 0x105A3}, {0x1057D, 0x105A4}, {0x1057E, 0x105A5}, {0x1057F, 0x105A6}, {0x10580, 0x105A7}, {0x10581, 0x105A8}, -{0x10582, 0x105A9}, {0x10583, 0x105AA}, {0x10584, 0x105AB}, {0x10585, 0x105AC}, {0x10586, 0x105AD}, {0x10587, 0x105AE}, -{0x10588, 0x105AF}, {0x10589, 0x105B0}, {0x1058A, 0x105B1}, {0x1058C, 0x105B3}, {0x1058D, 0x105B4}, {0x1058E, 0x105B5}, -{0x1058F, 0x105B6}, {0x10590, 0x105B7}, {0x10591, 0x105B8}, {0x10592, 0x105B9}, {0x10594, 0x105BB}, {0x10595, 0x105BC}, -{0x10C80, 0x10CC0}, {0x10C81, 0x10CC1}, {0x10C82, 0x10CC2}, {0x10C83, 0x10CC3}, {0x10C84, 0x10CC4}, {0x10C85, 0x10CC5}, -{0x10C86, 0x10CC6}, {0x10C87, 0x10CC7}, {0x10C88, 0x10CC8}, {0x10C89, 0x10CC9}, {0x10C8A, 0x10CCA}, {0x10C8B, 0x10CCB}, -{0x10C8C, 0x10CCC}, {0x10C8D, 0x10CCD}, {0x10C8E, 0x10CCE}, {0x10C8F, 0x10CCF}, {0x10C90, 0x10CD0}, {0x10C91, 0x10CD1}, -{0x10C92, 0x10CD2}, {0x10C93, 0x10CD3}, {0x10C94, 0x10CD4}, {0x10C95, 0x10CD5}, {0x10C96, 0x10CD6}, {0x10C97, 0x10CD7}, -{0x10C98, 0x10CD8}, {0x10C99, 0x10CD9}, {0x10C9A, 0x10CDA}, {0x10C9B, 0x10CDB}, {0x10C9C, 0x10CDC}, {0x10C9D, 0x10CDD}, -{0x10C9E, 0x10CDE}, {0x10C9F, 0x10CDF}, {0x10CA0, 0x10CE0}, {0x10CA1, 0x10CE1}, {0x10CA2, 0x10CE2}, {0x10CA3, 0x10CE3}, -{0x10CA4, 0x10CE4}, {0x10CA5, 0x10CE5}, {0x10CA6, 0x10CE6}, {0x10CA7, 0x10CE7}, {0x10CA8, 0x10CE8}, {0x10CA9, 0x10CE9}, -{0x10CAA, 0x10CEA}, {0x10CAB, 0x10CEB}, {0x10CAC, 0x10CEC}, {0x10CAD, 0x10CED}, {0x10CAE, 0x10CEE}, {0x10CAF, 0x10CEF}, -{0x10CB0, 0x10CF0}, {0x10CB1, 0x10CF1}, {0x10CB2, 0x10CF2}, {0x118A0, 0x118C0}, {0x118A1, 0x118C1}, {0x118A2, 0x118C2}, -{0x118A3, 0x118C3}, {0x118A4, 0x118C4}, {0x118A5, 0x118C5}, {0x118A6, 0x118C6}, {0x118A7, 0x118C7}, {0x118A8, 0x118C8}, -{0x118A9, 0x118C9}, {0x118AA, 0x118CA}, {0x118AB, 0x118CB}, {0x118AC, 0x118CC}, {0x118AD, 0x118CD}, {0x118AE, 0x118CE}, -{0x118AF, 0x118CF}, {0x118B0, 0x118D0}, {0x118B1, 0x118D1}, {0x118B2, 0x118D2}, {0x118B3, 0x118D3}, {0x118B4, 0x118D4}, -{0x118B5, 0x118D5}, {0x118B6, 0x118D6}, {0x118B7, 0x118D7}, {0x118B8, 0x118D8}, {0x118B9, 0x118D9}, {0x118BA, 0x118DA}, -{0x118BB, 0x118DB}, {0x118BC, 0x118DC}, {0x118BD, 0x118DD}, {0x118BE, 0x118DE}, {0x118BF, 0x118DF}, {0x16E40, 0x16E60}, -{0x16E41, 0x16E61}, {0x16E42, 0x16E62}, {0x16E43, 0x16E63}, {0x16E44, 0x16E64}, {0x16E45, 0x16E65}, {0x16E46, 0x16E66}, -{0x16E47, 0x16E67}, {0x16E48, 0x16E68}, {0x16E49, 0x16E69}, {0x16E4A, 0x16E6A}, {0x16E4B, 0x16E6B}, {0x16E4C, 0x16E6C}, -{0x16E4D, 0x16E6D}, {0x16E4E, 0x16E6E}, {0x16E4F, 0x16E6F}, {0x16E50, 0x16E70}, {0x16E51, 0x16E71}, {0x16E52, 0x16E72}, -{0x16E53, 0x16E73}, {0x16E54, 0x16E74}, {0x16E55, 0x16E75}, {0x16E56, 0x16E76}, {0x16E57, 0x16E77}, {0x16E58, 0x16E78}, -{0x16E59, 0x16E79}, {0x16E5A, 0x16E7A}, {0x16E5B, 0x16E7B}, {0x16E5C, 0x16E7C}, {0x16E5D, 0x16E7D}, {0x16E5E, 0x16E7E}, -{0x16E5F, 0x16E7F}, {0x1E900, 0x1E922}, {0x1E901, 0x1E923}, {0x1E902, 0x1E924}, {0x1E903, 0x1E925}, {0x1E904, 0x1E926}, -{0x1E905, 0x1E927}, {0x1E906, 0x1E928}, {0x1E907, 0x1E929}, {0x1E908, 0x1E92A}, {0x1E909, 0x1E92B}, {0x1E90A, 0x1E92C}, -{0x1E90B, 0x1E92D}, {0x1E90C, 0x1E92E}, {0x1E90D, 0x1E92F}, {0x1E90E, 0x1E930}, {0x1E90F, 0x1E931}, {0x1E910, 0x1E932}, -{0x1E911, 0x1E933}, {0x1E912, 0x1E934}, {0x1E913, 0x1E935}, {0x1E914, 0x1E936}, {0x1E915, 0x1E937}, {0x1E916, 0x1E938}, -{0x1E917, 0x1E939}, {0x1E918, 0x1E93A}, {0x1E919, 0x1E93B}, {0x1E91A, 0x1E93C}, {0x1E91B, 0x1E93D}, {0x1E91C, 0x1E93E}, -{0x1E91D, 0x1E93F}, {0x1E91E, 0x1E940}, {0x1E91F, 0x1E941}, {0x1E920, 0x1E942}, {0x1E921, 0x1E943}, -}; diff --git a/examples/talk-llama/unicode-data.h b/examples/talk-llama/unicode-data.h index b99500b8f3a..3cccf206854 100644 --- a/examples/talk-llama/unicode-data.h +++ b/examples/talk-llama/unicode-data.h @@ -5,12 +5,13 @@ #include #include -extern const std::vector> unicode_ranges_digit; +extern const std::vector> unicode_ranges_number; extern const std::vector> unicode_ranges_letter; +extern const std::vector> unicode_ranges_separator; extern const std::vector> unicode_ranges_whitespace; extern const std::vector> unicode_ranges_accent_mark; extern const std::vector> unicode_ranges_punctuation; extern const std::vector> unicode_ranges_symbol; extern const std::vector> unicode_ranges_control; -extern const std::multimap unicode_map_nfd; -extern const std::map unicode_map_lowercase; +extern const std::multimap unicode_map_nfd; +extern const std::map unicode_map_lowercase; diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index df8c5f58134..ca03c49d39c 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -5,11 +5,15 @@ #include #include #include +#include #include #include #include +#include #include #include +#include +#include static std::string unicode_cpts_to_utf8(const std::vector & cps) { std::string result; @@ -53,23 +57,22 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) offset += 4; return result; } - throw std::invalid_argument("invalid string"); + throw std::invalid_argument("failed to convert utf8 to codepoint"); } -static std::vector unicode_cpt_to_utf16(uint32_t cp) { - std::vector result; - if (/* 0x0000 <= cp && */ cp <= 0xffff) { - result.emplace_back(cp); - } - else if (0x10000 <= cp && cp <= 0x10ffff) { - result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); - result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); - } - else { - throw std::invalid_argument("invalid cpt"); - } - return result; -} +//static std::vector unicode_cpt_to_utf16(uint32_t cp) { +// std::vector result; +// if (/* 0x0000 <= cp && */ cp <= 0xffff) { +// result.emplace_back(cp); +// return result; +// } +// if (0x10000 <= cp && cp <= 0x10ffff) { +// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); +// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); +// return result; +// } +// throw std::invalid_argument("failed to convert codepoint to utf16"); +//} //static std::vector unicode_cpts_to_utf16(const std::vector & cps) { // std::vector result; @@ -80,56 +83,56 @@ static std::vector unicode_cpt_to_utf16(uint32_t cp) { // return result; //} -static uint32_t cpt_from_utf16(const std::vector & utf16, size_t & offset) { - assert(offset < utf16.size()); - if (((utf16[0] >> 10) << 10) != 0xd800) { - auto result = utf16[offset + 0]; - offset += 1; - return result; - } - - if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) { - throw std::invalid_argument("invalid character"); - } - - auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); - offset += 2; - return result; -} +//static uint32_t unicode_cpt_from_utf16(const std::vector & utf16, size_t & offset) { +// assert(offset < utf16.size()); +// if (((utf16[0] >> 10) << 10) != 0xd800) { +// auto result = utf16[offset + 0]; +// offset += 1; +// return result; +// } +// +// if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) { +// throw std::invalid_argument("invalid character"); +// } +// +// auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); +// offset += 2; +// return result; +//} //static std::vector unicode_cpts_from_utf16(const std::vector & utf16) { // std::vector result; // size_t offset = 0; // while (offset < utf16.size()) { -// result.push_back(cpt_from_utf16(utf16, offset)); +// result.push_back(unicode_cpt_from_utf16(utf16, offset)); // } // return result; //} static std::unordered_map unicode_cpt_type_map() { std::unordered_map cpt_types; - for (auto p : unicode_ranges_digit) { - for (auto i = p.first; i <= p.second; ++ i) { - cpt_types[i] = CODEPOINT_TYPE_DIGIT; + for (auto p : unicode_ranges_number) { + for (auto i = p.first; i <= p.second; ++i) { + cpt_types[i] = CODEPOINT_TYPE_NUMBER; } } for (auto p : unicode_ranges_letter) { - for (auto i = p.first; i <= p.second; ++ i) { + for (auto i = p.first; i <= p.second; ++i) { cpt_types[i] = CODEPOINT_TYPE_LETTER; } } - for (auto p : unicode_ranges_whitespace) { - for (auto i = p.first; i <= p.second; ++ i) { - cpt_types[i] = CODEPOINT_TYPE_WHITESPACE; + for (auto p : unicode_ranges_separator) { + for (auto i = p.first; i <= p.second; ++i) { + cpt_types[i] = CODEPOINT_TYPE_SEPARATOR; } } for (auto p : unicode_ranges_accent_mark) { - for (auto i = p.first; i <= p.second; ++ i) { + for (auto i = p.first; i <= p.second; ++i) { cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK; } } for (auto p : unicode_ranges_punctuation) { - for (auto i = p.first; i <= p.second; ++ i) { + for (auto i = p.first; i <= p.second; ++i) { cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION; } } @@ -139,7 +142,7 @@ static std::unordered_map unicode_cpt_type_map() { } } for (auto p : unicode_ranges_control) { - for (auto i = p.first; i <= p.second; ++ i) { + for (auto i = p.first; i <= p.second; ++i) { cpt_types[i] = CODEPOINT_TYPE_CONTROL; } } @@ -194,34 +197,395 @@ static std::unordered_map unicode_utf8_to_byte_map() { return map; } +static inline std::wstring unicode_wstring_from_utf8(const std::string & s) { + std::wstring_convert> conv; + return conv.from_bytes(s); +} + +static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) { + std::vector bpe_encoded_words; + for (const auto & word : bpe_words) { + std::string text_utf; + auto utf_word = unicode_cpts_from_utf8(word); + for (size_t i = 0; i < utf_word.size(); ++i) { + text_utf += unicode_cpt_to_utf8(utf_word[i]); + } + + std::string encoded_token; + for (char & c : text_utf) { + encoded_token += unicode_byte_to_utf8(c); + } + bpe_encoded_words.emplace_back(encoded_token); + } + return bpe_encoded_words; +} + +// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ +static std::vector unicode_regex_split_custom_gpt2(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + auto _get_cpt = [&] (const size_t pos) -> char32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; + }; + + auto _get_cpt_type = [&] (const size_t pos) -> int { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + //if (len > 0) { + // std::string s = ""; + // for(size_t p = end-len; p < end; p++) + // s += unicode_cpt_to_utf8(cpts[p]); + // printf(">>> '%s'\n", s.c_str()); + //} + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const char32_t cpt = _get_cpt(pos); + const int cpt_type = _get_cpt_type(pos); + + // regex: 's|'t|'re|'ve|'m|'ll|'d + if (cpt == '\'' && pos+1 < offset_end) { + char32_t cpt_next = _get_cpt(pos+1); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + char32_t cpt_next_next = _get_cpt(pos+2); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); + int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); + // regex: ?\p{L}+ + if (cpt2_type == CODEPOINT_TYPE_LETTER) { + pos += (cpt == ' '); + while (cpt2_type == CODEPOINT_TYPE_LETTER) { + cpt2_type = _get_cpt_type(++pos); + } + _add_token(pos); + continue; + } + // regex: ?\p{N}+ + if (cpt2_type == CODEPOINT_TYPE_NUMBER) { + pos += (cpt == ' '); + while (cpt2_type == CODEPOINT_TYPE_NUMBER) { + cpt2_type = _get_cpt_type(++pos); + } + _add_token(pos); + continue; + } + // regex: ?[^\s\p{L}\p{N}]+ + if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { + pos += (cpt == ' '); + while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { + cpt2_type = _get_cpt_type(++pos); + cpt2 = _get_cpt(pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) { + num_whitespaces++; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + +// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" +static std::vector unicode_regex_split_custom_llama3(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + auto _get_cpt = [&] (const size_t pos) -> char32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; + }; + + auto _get_cpt_type = [&] (const size_t pos) -> int { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + //if (len > 0) { + // std::string s = ""; + // for(size_t p = end-len; p < end; p++) + // s += unicode_cpt_to_utf8(cpts[p]); + // printf(">>> '%s'\n", s.c_str()); + //} + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const char32_t cpt = _get_cpt(pos); + const int cpt_type = _get_cpt_type(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + char32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct? + if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) { + if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) { // one or more letters + pos++; + while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N}{1,3} + if (cpt_type == CODEPOINT_TYPE_NUMBER) { + size_t ini = pos; + while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) { + if (++pos - ini >= 3 ) { + _add_token(pos); + ini = pos; + } + } + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{N}]+[\r\n]* + char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); + int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); + if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { + pos += (cpt == ' '); + while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { + cpt2_type = _get_cpt_type(++pos); + cpt2 = _get_cpt(pos); + } + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) { + char32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + +// use std::wregex to split the text +static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) { + std::wregex expr(regex_expr); + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + size_t start = 0; + for (auto offset : offsets) { + std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr); + std::wcregex_iterator end; + + int64_t start_idx = 0; + while (it != end) { + std::wcmatch match = *it; + if (match.position() > start_idx) { + bpe_offsets.emplace_back(match.position() - start_idx); + } + bpe_offsets.emplace_back(match.length()); + start_idx = match.position() + match.length(); + ++it; + } + + if (start_idx < (int64_t) offset) { + bpe_offsets.emplace_back(offset - start_idx); + } + start += offset; + } + + return bpe_offsets; +} + +// use std::regex to split the text +static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { + std::regex expr(regex_expr); + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + size_t start = 0; + for (auto offset : offsets) { + std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr); + std::cregex_iterator end; + + int64_t start_idx = 0; + while (it != end) { + std::cmatch match = *it; + if (match.position() > start_idx) { + bpe_offsets.emplace_back(match.position() - start_idx); + } + bpe_offsets.emplace_back(match.length()); + start_idx = match.position() + match.length(); + ++it; + } + + if (start_idx < (int64_t) offset) { + bpe_offsets.emplace_back(offset - start_idx); + } + start += offset; + } + + return bpe_offsets; +} + +static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { + std::vector bpe_offsets; + + if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") { + bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets); + } else if ( + regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" || + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + + bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); + } + + return bpe_offsets; +} + // // interface // std::string unicode_cpt_to_utf8(uint32_t cp) { std::string result; + if (/* 0x00 <= cp && */ cp <= 0x7f) { result.push_back(cp); + return result; } - else if (0x80 <= cp && cp <= 0x7ff) { + if (0x80 <= cp && cp <= 0x7ff) { result.push_back(0xc0 | ((cp >> 6) & 0x1f)); result.push_back(0x80 | (cp & 0x3f)); + return result; } - else if (0x800 <= cp && cp <= 0xffff) { + if (0x800 <= cp && cp <= 0xffff) { result.push_back(0xe0 | ((cp >> 12) & 0x0f)); result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | (cp & 0x3f)); + return result; } - else if (0x10000 <= cp && cp <= 0x10ffff) { + if (0x10000 <= cp && cp <= 0x10ffff) { result.push_back(0xf0 | ((cp >> 18) & 0x07)); result.push_back(0x80 | ((cp >> 12) & 0x3f)); result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | (cp & 0x3f)); + return result; } - else { - throw std::invalid_argument("invalid codepoint"); - } - return result; + + throw std::invalid_argument("invalid codepoint"); } std::vector unicode_cpts_normalize_nfd(const std::vector & cpts) { @@ -261,6 +625,19 @@ int unicode_cpt_type(const std::string & utf8) { return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset)); } +bool unicode_cpt_is_whitespace(uint32_t cp) { + static const std::unordered_set is_whitespace = [] { + std::unordered_set is_whitespace; + for (auto p : unicode_ranges_whitespace) { + for (auto i = p.first; i <= p.second; ++i) { + is_whitespace.insert(i); + } + } + return is_whitespace; + }(); + return (bool)is_whitespace.count(cp); +} + std::string unicode_byte_to_utf8(uint8_t byte) { static std::unordered_map map = unicode_byte_to_utf8_map(); return map.at(byte); @@ -275,3 +652,167 @@ char32_t unicode_tolower(char32_t cp) { auto it = unicode_map_lowercase.find(cp); return it == unicode_map_lowercase.end() ? cp : it->second; } + +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { + // unicode categories + static const std::map k_ucat_enum = { + { "\\p{N}", CODEPOINT_TYPE_NUMBER }, + { "\\p{L}", CODEPOINT_TYPE_LETTER }, + { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION }, + }; + + static const std::map k_ucat_cpt = { + { CODEPOINT_TYPE_NUMBER, 0xD1 }, + { CODEPOINT_TYPE_LETTER, 0xD2 }, + { CODEPOINT_TYPE_PUNCTUATION, 0xD3 }, + }; + + static const std::map k_ucat_map = { + { CODEPOINT_TYPE_NUMBER, "\x30-\x39" }, // 0-9 + { CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z + { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} + }; + + // compute collapsed codepoints only if needed by at least one regex + bool need_collapse = false; + for (auto & regex_expr : regex_exprs) { + // search for unicode categories + for (const auto & ucat : k_ucat_enum) { + if (std::string::npos != regex_expr.find(ucat.first)) { + need_collapse = true; + break; + } + } + } + + const auto cpts = unicode_cpts_from_utf8(text); + + // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte + // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 + std::string text_collapsed; + if (need_collapse) { + // collapse all unicode categories + text_collapsed.resize(cpts.size()); + + for (size_t i = 0; i < cpts.size(); ++i) { + // keep single-byte codepoints as is + if (cpts[i] < 128) { + text_collapsed[i] = cpts[i]; + continue; + } + + const int cpt_type = unicode_cpt_type(cpts[i]); + + if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) { + text_collapsed[i] = k_ucat_cpt.at(cpt_type); + } else { + text_collapsed[i] = (char) 0xD0; // fallback + } + } + } + + std::vector bpe_offsets = { cpts.size() }; + + for (auto & regex_expr : regex_exprs) { + // first, see if we have an efficient custom regex implementation + auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets); + + if (!tmp.empty()) { + bpe_offsets = std::move(tmp); + continue; + } + + // fallback to general-purpose std::regex / std::wregex + try { + // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category + // with the corresponding collapsed representation + bool use_collapsed = false; + for (auto & ucat : k_ucat_enum) { + if (std::string::npos != regex_expr.find(ucat.first)) { + use_collapsed = true; + break; + } + } + + if (use_collapsed) { + // sanity-check that the original regex does not contain any non-ASCII characters + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); + for (size_t i = 0; i < cpts_regex.size(); ++i) { + if (cpts_regex[i] >= 128) { + throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); + } + } + + // generate a collapsed representation of the regex + std::string regex_expr_collapsed; + + // track if we are inside [], because nested [] are not allowed + bool inside = false; + for (size_t i = 0; i < regex_expr.size(); ++i) { + if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) { + regex_expr_collapsed += '['; + inside = true; + continue; + } + + if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') { + regex_expr_collapsed += ']'; + inside = false; + continue; + } + + if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && + regex_expr[i + 1] == 'p' && + regex_expr[i + 2] == '{' && + regex_expr[i + 4] == '}') { + const std::string pat = regex_expr.substr(i, 5); + if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { + if (!inside) { + regex_expr_collapsed += '['; + } + regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); + regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); + if (!inside) { + regex_expr_collapsed += ']'; + } + i += 4; + continue; + } + } + + regex_expr_collapsed += regex_expr[i]; + } + + //printf("text_collapsed: %s\n", text_collapsed.c_str()); + //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str()); + bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); + } else { + // no unicode category used, we can use std::wregex directly + const std::wstring wtext = unicode_wstring_from_utf8(text); + const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); + + //printf("text: %s\n", text.c_str()); + //printf("regex_expr: %s\n", regex_expr.c_str()); + bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets); + } + } catch (std::regex_error & e) { + fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str()); + fprintf(stderr, "Regex error: %s\n", e.what()); + throw std::runtime_error("Failed to process regex"); + } + } + + std::vector bpe_words; + bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size + + size_t start = 0; + for (size_t & offset : bpe_offsets) { + bpe_words.emplace_back(); + for (size_t i = start; i < start + offset; ++i) { + bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); + } + start += offset; + } + + return unicode_byte_encoding_process(bpe_words); +} diff --git a/examples/talk-llama/unicode.h b/examples/talk-llama/unicode.h index 6a0be393a46..d6a14d470bf 100644 --- a/examples/talk-llama/unicode.h +++ b/examples/talk-llama/unicode.h @@ -5,9 +5,9 @@ #include #define CODEPOINT_TYPE_UNIDENTIFIED 0 -#define CODEPOINT_TYPE_DIGIT 1 +#define CODEPOINT_TYPE_NUMBER 1 #define CODEPOINT_TYPE_LETTER 2 -#define CODEPOINT_TYPE_WHITESPACE 3 +#define CODEPOINT_TYPE_SEPARATOR 3 #define CODEPOINT_TYPE_ACCENT_MARK 4 #define CODEPOINT_TYPE_PUNCTUATION 5 #define CODEPOINT_TYPE_SYMBOL 6 @@ -21,8 +21,11 @@ std::vector unicode_cpts_normalize_nfd(const std::vector & c int unicode_cpt_type(uint32_t cp); int unicode_cpt_type(const std::string & utf8); +bool unicode_cpt_is_whitespace(uint32_t cp); + std::string unicode_byte_to_utf8(uint8_t byte); uint8_t unicode_utf8_to_byte(const std::string & utf8); -// simple tolower that only implements one-to-one mapping, not one-to-many char32_t unicode_tolower(char32_t cp); + +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs); From fbeb80b5f0d955ef1ba986fe1f2ee7c660c31f4f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 May 2024 20:55:57 +0300 Subject: [PATCH 060/100] ggml : remove oboslete alibi code (skipme) (#0) --- ggml-cuda/alibi.cu | 63 --------------------------------------------- ggml-cuda/alibi.cuh | 5 ---- 2 files changed, 68 deletions(-) delete mode 100644 ggml-cuda/alibi.cu delete mode 100644 ggml-cuda/alibi.cuh diff --git a/ggml-cuda/alibi.cu b/ggml-cuda/alibi.cu deleted file mode 100644 index 6c7f1fd9562..00000000000 --- a/ggml-cuda/alibi.cu +++ /dev/null @@ -1,63 +0,0 @@ -#include "alibi.cuh" - -static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, - const int n_heads_log2_floor, const float m0, const float m1) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - - if (col >= ncols) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - - const int k = row/k_rows; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - dst[i] = col * m_k + x[i]; -} - -static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, - const int k_rows, const int n_heads_log2_floor, const float m0, - const float m1, cudaStream_t stream) { - const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1); - const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE); - const dim3 block_nums(num_blocks_x, nrows, 1); - alibi_f32<<>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); -} - -void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t nrows = ggml_nrows(src0); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - //GGML_ASSERT(ne01 + n_past == ne00); - GGML_ASSERT(n_head == ne02); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream); -} diff --git a/ggml-cuda/alibi.cuh b/ggml-cuda/alibi.cuh deleted file mode 100644 index 630adfc7f63..00000000000 --- a/ggml-cuda/alibi.cuh +++ /dev/null @@ -1,5 +0,0 @@ -#include "common.cuh" - -#define CUDA_ALIBI_BLOCK_SIZE 32 - -void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 9506267ce52e85e08e6c89a8cc7cee7090bf26b0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 May 2024 20:36:31 +0300 Subject: [PATCH 061/100] ggml : try fix ppc64 (#0) --- ggml-quants.c | 2 +- ggml.c | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml-quants.c b/ggml-quants.c index f711bd01341..9e62a3f3261 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -11425,7 +11425,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8); qh += 2; - vector bool short vsel = vec_cmpge(qxh, (vector signed short)v0); + vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0); vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel); diff --git a/ggml.c b/ggml.c index b96a82a4151..d443a9b42ce 100644 --- a/ggml.c +++ b/ggml.c @@ -1306,6 +1306,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO #define GGML_F16_VEC_SET1 GGML_F32x4_SET1 #define GGML_F16_VEC_FMA GGML_F32x4_FMA +#define GGML_F16_VEC_ADD GGML_F32x4_ADD +#define GGML_F16_VEC_MUL GGML_F32x4_MUL #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE // Use vec_xl, not vec_ld, in case the load address is not aligned. #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ From 2c81e6fd51ee31ac3b9c3e9f4aa44780ff431b5f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 May 2024 10:41:33 +0300 Subject: [PATCH 062/100] whisper : remove old flash attn code (#0) --- whisper.cpp | 42 +++++------------------------------------- 1 file changed, 5 insertions(+), 37 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index f31309ed3b4..f3daf5b6d60 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -147,7 +147,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text } \ } while (0) -//#define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 8 #define WHISPER_MAX_NODES 4096 @@ -1951,32 +1950,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // ------ -#ifdef WHISPER_USE_FLASH_ATTN - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)); - - struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); -#else struct ggml_tensor * Q = ggml_permute(ctx0, ggml_cpy(ctx0, @@ -1994,9 +1967,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); struct ggml_tensor * V = ggml_cpy(ctx0, @@ -2009,7 +1980,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); -#endif + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); cur = ggml_cpy(ctx0, @@ -2323,6 +2294,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_set_name(KQ_mask, "KQ_mask"); ggml_set_input(KQ_mask); + struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); + // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -2406,12 +2379,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask_f16, 1.0f, 0.0f); struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, From e6acaf9d91734aa9803a6eae3a992e6ce914d4e0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 May 2024 11:01:07 +0300 Subject: [PATCH 063/100] metal : tune soft_max number of threads (#0) --- ggml-metal.m | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 28dec762a8a..bfa352c3a9a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1378,7 +1378,7 @@ static enum ggml_status ggml_metal_graph_compute( const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (ne00%4 == 0) { - while (nth < ne00/4 && nth < 256) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } if (use_f16) { @@ -1387,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; } } else { - while (nth < ne00 && nth < 1024) { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } if (use_f16) { From 7705dc52da6fde89008dd172a400ae6885353241 Mon Sep 17 00:00:00 2001 From: mashizora <30516315+mashizora@users.noreply.github.com> Date: Mon, 13 May 2024 16:55:32 +0800 Subject: [PATCH 064/100] main : fix double quote escaping in csv output (#2090) --- examples/main/main.cpp | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 15d8c8a83b6..6a3db73d87a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -471,6 +471,38 @@ char *escape_double_quotes_and_backslashes(const char *str) { return escaped; } +// double quote should be escaped by another double quote. (rfc4180) +char *escape_double_quotes_in_csv(const char *str) { + if (str == NULL) { + return NULL; + } + + size_t escaped_length = strlen(str) + 1; + + for (size_t i = 0; str[i] != '\0'; i++) { + if (str[i] == '"') { + escaped_length++; + } + } + + char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed + if (escaped == NULL) { + return NULL; + } + + size_t pos = 0; + for (size_t i = 0; str[i] != '\0'; i++) { + if (str[i] == '"') { + escaped[pos++] = '"'; + } + escaped[pos++] = str[i]; + } + + // no need to set zero due to calloc() being used prior + + return escaped; +} + bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { std::ofstream fout(fname); if (!fout.is_open()) { @@ -492,7 +524,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_ const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - char * text_escaped = escape_double_quotes_and_backslashes(text); + char * text_escaped = escape_double_quotes_in_csv(text); //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. fout << 10 * t0 << "," << 10 * t1 << ","; From b6bbce4ae98c32fa6b71038d3f97dbeea29acb55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xingchen=20Song=28=E5=AE=8B=E6=98=9F=E8=BE=B0=29?= Date: Mon, 13 May 2024 19:29:39 +0800 Subject: [PATCH 065/100] cmake : fix json INTERFACE library (#2069) --- examples/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 104482f2133..3b493e3db7e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -58,8 +58,8 @@ if (WHISPER_SDL2) endif() # add json lib -add_library(json_cpp INTERFACE json.hpp) -set_target_properties(json_cpp PROPERTIES FOLDER "libs") +add_library(json_cpp INTERFACE) +target_include_directories(json_cpp INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) # examples From e93081f83f20a9142ba9502ef45bf57fb1ad0a62 Mon Sep 17 00:00:00 2001 From: zhangjixiong Date: Mon, 13 May 2024 19:30:03 +0800 Subject: [PATCH 066/100] whisper.android : update example, add field to print timestamp (#2072) --- .../ui/main/MainScreenViewModel.kt | 2 +- .../java/com/whispercpp/whisper/LibWhisper.kt | 27 +++++++++++++++++-- .../lib/src/main/jni/whisper/jni.c | 16 +++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt b/examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt index d614ce3338e..845b023a3fb 100644 --- a/examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt +++ b/examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt @@ -145,7 +145,7 @@ class MainScreenViewModel(private val application: Application) : ViewModel() { val start = System.currentTimeMillis() val text = whisperContext?.transcribeData(data) val elapsed = System.currentTimeMillis() - start - printMessage("Done ($elapsed ms): $text\n") + printMessage("Done ($elapsed ms): \n$text\n") } catch (e: Exception) { Log.w(LOG_TAG, e) printMessage("${e.localizedMessage}\n") diff --git a/examples/whisper.android/lib/src/main/java/com/whispercpp/whisper/LibWhisper.kt b/examples/whisper.android/lib/src/main/java/com/whispercpp/whisper/LibWhisper.kt index 513202fa689..37ae0e9dd4a 100644 --- a/examples/whisper.android/lib/src/main/java/com/whispercpp/whisper/LibWhisper.kt +++ b/examples/whisper.android/lib/src/main/java/com/whispercpp/whisper/LibWhisper.kt @@ -16,7 +16,7 @@ class WhisperContext private constructor(private var ptr: Long) { Executors.newSingleThreadExecutor().asCoroutineDispatcher() ) - suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) { + suspend fun transcribeData(data: FloatArray, printTimestamp: Boolean = true): String = withContext(scope.coroutineContext) { require(ptr != 0L) val numThreads = WhisperCpuConfig.preferredThreadCount Log.d(LOG_TAG, "Selecting $numThreads threads") @@ -24,7 +24,13 @@ class WhisperContext private constructor(private var ptr: Long) { val textCount = WhisperLib.getTextSegmentCount(ptr) return@withContext buildString { for (i in 0 until textCount) { - append(WhisperLib.getTextSegment(ptr, i)) + if (printTimestamp) { + val textTimestamp = "[${toTimestamp(WhisperLib.getTextSegmentT0(ptr, i))} --> ${toTimestamp(WhisperLib.getTextSegmentT1(ptr, i))}]" + val textSegment = WhisperLib.getTextSegment(ptr, i) + append("$textTimestamp: $textSegment\n") + } else { + append(WhisperLib.getTextSegment(ptr, i)) + } } } } @@ -131,12 +137,29 @@ private class WhisperLib { external fun fullTranscribe(contextPtr: Long, numThreads: Int, audioData: FloatArray) external fun getTextSegmentCount(contextPtr: Long): Int external fun getTextSegment(contextPtr: Long, index: Int): String + external fun getTextSegmentT0(contextPtr: Long, index: Int): Long + external fun getTextSegmentT1(contextPtr: Long, index: Int): Long external fun getSystemInfo(): String external fun benchMemcpy(nthread: Int): String external fun benchGgmlMulMat(nthread: Int): String } } +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +private fun toTimestamp(t: Long, comma: Boolean = false): String { + var msec = t * 10 + val hr = msec / (1000 * 60 * 60) + msec -= hr * (1000 * 60 * 60) + val min = msec / (1000 * 60) + msec -= min * (1000 * 60) + val sec = msec / 1000 + msec -= sec * 1000 + + val delimiter = if (comma) "," else "." + return String.format("%02d:%02d:%02d%s%03d", hr, min, sec, delimiter, msec) +} + private fun isArmEabiV7a(): Boolean { return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a") } diff --git a/examples/whisper.android/lib/src/main/jni/whisper/jni.c b/examples/whisper.android/lib/src/main/jni/whisper/jni.c index 7f9d724617d..da54c8140dc 100644 --- a/examples/whisper.android/lib/src/main/jni/whisper/jni.c +++ b/examples/whisper.android/lib/src/main/jni/whisper/jni.c @@ -212,6 +212,22 @@ Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegment( return string; } +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegmentT0( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + return whisper_full_get_segment_t0(context, index); +} + +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegmentT1( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + return whisper_full_get_segment_t1(context, index); +} + JNIEXPORT jstring JNICALL Java_com_whispercpp_whisper_WhisperLib_00024Companion_getSystemInfo( JNIEnv *env, jobject thiz From 2b434c449ef091db93b2b644df8b3a2912632d77 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 May 2024 14:43:43 +0300 Subject: [PATCH 067/100] whisper : switch back to F32 mask (#0) --- whisper.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index f3daf5b6d60..bdcf3de40e2 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2294,8 +2294,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_set_name(KQ_mask, "KQ_mask"); ggml_set_input(KQ_mask); - struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); - // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -2379,7 +2377,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask_f16, 1.0f, 0.0f); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, @@ -2873,8 +2871,8 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector int i = ith; // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist - assert( n_fft == 1 + (frame_size / 2) ); - + assert(n_fft == 1 + (frame_size / 2)); + // calculate FFT only when fft_in are not all zero for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { const int offset = i * frame_step; From f141b2b938ebf56384dbbe04e595e1b5a9d1104c Mon Sep 17 00:00:00 2001 From: Daniel Ziegenberg Date: Mon, 13 May 2024 13:59:44 +0200 Subject: [PATCH 068/100] main : add options for temperature control (#2088) Add two options: ``` -tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1 -tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1 ``` The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit. Signed-off-by: Daniel Ziegenberg --- examples/main/main.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6a3db73d87a..bb1931869d3 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -44,6 +44,8 @@ struct whisper_params { float entropy_thold = 2.40f; float logprob_thold = -1.00f; float grammar_penalty = 100.0f; + float temperature = 0.0f; + float temperature_inc = 0.2f; bool speed_up = false; bool debug_mode = false; @@ -133,6 +135,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); } + else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); } // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } @@ -198,6 +202,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); + fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); @@ -1107,7 +1113,9 @@ int main(int argc, char ** argv) { wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; - wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc; + wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc; + wparams.temperature = params.temperature; + wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; From 0bb05b113dc6996255f685906ade0b8c60caac76 Mon Sep 17 00:00:00 2001 From: Daniel Ziegenberg Date: Mon, 13 May 2024 14:00:19 +0200 Subject: [PATCH 069/100] main : dont print timings with --no-prints (#2108) Signed-off-by: Daniel Ziegenberg --- examples/main/main.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index bb1931869d3..d11c1c3f81b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -1233,7 +1233,9 @@ int main(int argc, char ** argv) { } } - whisper_print_timings(ctx); + if (!params.no_prints) { + whisper_print_timings(ctx); + } whisper_free(ctx); return 0; From 1da5edcde074e5446ed5ad5c21fa6672b0c70bd6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 May 2024 15:09:35 +0300 Subject: [PATCH 070/100] cmake : fix metal embed sources path (#2110) --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b34b3768336..1017e53c306 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,8 +173,8 @@ if (APPLE) enable_language(ASM) set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_EMBED_LIBRARY) - set(METALLIB_SOURCE "${CMAKE_SOURCE_DIR}/ggml-metal.metal") - set(COMMON_HEADER "${CMAKE_SOURCE_DIR}/ggml-common.h") + set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + set(COMMON_HEADER "${CMAKE_CURRENT_SOURCE_DIR}/ggml-common.h") file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") set(EMBED_METALLIB_ASSEMBLY "${CMAKE_BINARY_DIR}/autogenerated/ggml-embed-metallib.s") From 17fa62d3d35e7b021b822adcd370bb620e97e282 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mark=20Karpel=C3=A8s?= Date: Mon, 13 May 2024 21:13:19 +0900 Subject: [PATCH 071/100] js : remove un-needed request header from fetchRemote (#2119) --- examples/helpers.js | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/helpers.js b/examples/helpers.js index 23f18ab5697..423bada60ba 100644 --- a/examples/helpers.js +++ b/examples/helpers.js @@ -34,9 +34,6 @@ async function fetchRemote(url, cbProgress, cbPrint) { url, { method: 'GET', - headers: { - 'Content-Type': 'application/octet-stream', - }, } ); From 30f73109b8a321da828d13f568c4f679ec32de9b Mon Sep 17 00:00:00 2001 From: valVk Date: Mon, 13 May 2024 15:15:43 +0300 Subject: [PATCH 072/100] node : add additional params (#2000) * Add additional params to addon.node * Add comma_in_time as parameter * Fix tests --- examples/addon.node/__test__/whisper.spec.js | 3 +++ examples/addon.node/addon.cpp | 21 ++++++++++++++++---- examples/addon.node/index.js | 8 ++++++-- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index c0367a8c587..2f264fd3af5 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -12,6 +12,9 @@ const whisperParamsMock = { model: path.join(__dirname, "../../../models/ggml-base.en.bin"), fname_inp: path.join(__dirname, "../../../samples/jfk.wav"), use_gpu: true, + no_prints: true, + comma_in_time: false, + translate: true, no_timestamps: false, }; diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 8988f9edc1d..85576311ca9 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -36,7 +36,9 @@ struct whisper_params { bool print_colors = false; bool print_progress = false; bool no_timestamps = false; + bool no_prints = false; bool use_gpu = true; + bool comma_in_time = true; std::string language = "en"; std::string prompt; @@ -120,7 +122,14 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper } } +void cb_log_disable(enum ggml_log_level, const char *, void *) {} + int run(whisper_params ¶ms, std::vector> &result) { + + if (params.no_prints) { + whisper_log_set(cb_log_disable, NULL); + } + if (params.fname_inp.empty()) { fprintf(stderr, "error: no input files specified\n"); return 2; @@ -155,14 +164,14 @@ int run(whisper_params ¶ms, std::vector> &result) { } // print system information - { + if (!params.no_prints) { fprintf(stderr, "\n"); fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info()); } // print some info about the processing - { + if (!params.no_prints) { fprintf(stderr, "\n"); if (!whisper_is_multilingual(ctx)) { if (params.language != "en" || params.translate) { @@ -248,8 +257,8 @@ int run(whisper_params ¶ms, std::vector> &result) { const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - result[i].emplace_back(to_timestamp(t0, true)); - result[i].emplace_back(to_timestamp(t1, true)); + result[i].emplace_back(to_timestamp(t0, params.comma_in_time)); + result[i].emplace_back(to_timestamp(t1, params.comma_in_time)); result[i].emplace_back(text); } @@ -300,13 +309,17 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { std::string model = whisper_params.Get("model").As(); std::string input = whisper_params.Get("fname_inp").As(); bool use_gpu = whisper_params.Get("use_gpu").As(); + bool no_prints = whisper_params.Get("no_prints").As(); bool no_timestamps = whisper_params.Get("no_timestamps").As(); + bool comma_in_time = whisper_params.Get("comma_in_time").As(); params.language = language; params.model = model; params.fname_inp.emplace_back(input); params.use_gpu = use_gpu; + params.no_prints = no_prints; params.no_timestamps = no_timestamps; + params.comma_in_time = comma_in_time; Napi::Function callback = info[1].As(); Worker* worker = new Worker(callback, params); diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 9156a52de07..90bd6fc2ff4 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -10,8 +10,11 @@ const whisperAsync = promisify(whisper); const whisperParams = { language: "en", model: path.join(__dirname, "../../models/ggml-base.en.bin"), - fname_inp: "../../samples/jfk.wav", + fname_inp: path.join(__dirname, "../../samples/jfk.wav"), use_gpu: true, + no_prints: true, + comma_in_time: false, + translate: true, no_timestamps: false, }; @@ -34,5 +37,6 @@ for (const key in params) { console.log("whisperParams =", whisperParams); whisperAsync(whisperParams).then((result) => { - console.log(`Result from whisper: ${result}`); + console.log(); + console.log(result); }); From 2ced6f07422ccdce4340bd28966d530b31d245a4 Mon Sep 17 00:00:00 2001 From: aldorof Date: Mon, 13 May 2024 08:18:43 -0400 Subject: [PATCH 073/100] cmake : fix HIP/ROCm build (#2102) --- CMakeLists.txt | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1017e53c306..cdffbcaa1c0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -370,16 +370,18 @@ if (WHISPER_HIPBLAS) if (${hipblas_FOUND} AND ${hip_FOUND}) message(STATUS "HIP and hipBLAS found") + set(GGML_HEADERS_ROCM "ggml-cuda.h") + + file(GLOB GGML_SOURCES_ROCM "ggml-cuda/*.cu") + list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu") + add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA) - add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) - set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON) - set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) - target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) + set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) if (WHISPER_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") endif() - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} hip::device PUBLIC hip::host roc::rocblas roc::hipblas) else() message(FATAL_ERROR "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") endif() @@ -647,12 +649,15 @@ add_library(${TARGET} ${GGML_SOURCES_METAL} ${GGML_SOURCES_CUDA} ${GGML_SOURCES_OPENCL} - ${GGML_SOURCES_SYCL} - ${GGML_HEADERS_SYCL} + ${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL} + ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM} whisper.h whisper.cpp ) +include_directories ( + . +) # Set the version numbers set_target_properties(whisper PROPERTIES VERSION ${PROJECT_VERSION} From 3928dbd2068e99eeb4baf36f7affc4acd1e31602 Mon Sep 17 00:00:00 2001 From: Pedro Probst Date: Mon, 13 May 2024 09:22:23 -0300 Subject: [PATCH 074/100] node : add audio_ctx and audio buffer params (#2123) * node : add audio_ctx param * node : support passing audio buffer directly * node : parse audio_ctx in index.js --------- Co-authored-by: Georgi Gerganov --- examples/addon.node/__test__/whisper.spec.js | 1 + examples/addon.node/addon.cpp | 47 ++++++++++++++++---- examples/addon.node/index.js | 9 +++- 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index 2f264fd3af5..9ba86b62985 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -16,6 +16,7 @@ const whisperParamsMock = { comma_in_time: false, translate: true, no_timestamps: false, + audio_ctx: 0, }; describe("Run whisper.node", () => { diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 85576311ca9..8125e5dda4c 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -19,6 +19,7 @@ struct whisper_params { int32_t max_len = 0; int32_t best_of = 5; int32_t beam_size = -1; + int32_t audio_ctx = 0; float word_thold = 0.01f; float entropy_thold = 2.4f; @@ -46,6 +47,8 @@ struct whisper_params { std::vector fname_inp = {}; std::vector fname_out = {}; + + std::vector pcmf32 = {}; // mono-channel F32 PCM }; struct whisper_print_user_data { @@ -125,13 +128,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper void cb_log_disable(enum ggml_log_level, const char *, void *) {} int run(whisper_params ¶ms, std::vector> &result) { - if (params.no_prints) { whisper_log_set(cb_log_disable, NULL); } - if (params.fname_inp.empty()) { - fprintf(stderr, "error: no input files specified\n"); + if (params.fname_inp.empty() && params.pcmf32.empty()) { + fprintf(stderr, "error: no input files or audio buffer specified\n"); return 2; } @@ -151,6 +153,14 @@ int run(whisper_params ¶ms, std::vector> &result) { return 3; } + // if params.pcmf32 is provided, set params.fname_inp to "buffer" + // this is simpler than further modifications in the code + if (!params.pcmf32.empty()) { + fprintf(stderr, "info: using audio buffer as input\n"); + params.fname_inp.clear(); + params.fname_inp.emplace_back("buffer"); + } + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; @@ -158,9 +168,14 @@ int run(whisper_params ¶ms, std::vector> &result) { std::vector pcmf32; // mono-channel F32 PCM std::vector> pcmf32s; // stereo-channel F32 PCM - if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { - fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); - continue; + // read the input audio file if params.pcmf32 is not provided + if (params.pcmf32.empty()) { + if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { + fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); + continue; + } + } else { + pcmf32 = params.pcmf32; } // print system information @@ -180,12 +195,13 @@ int run(whisper_params ¶ms, std::vector> &result) { fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); } } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, params.n_processors, params.language.c_str(), params.translate ? "translate" : "transcribe", - params.no_timestamps ? 0 : 1); + params.no_timestamps ? 0 : 1, + params.audio_ctx); fprintf(stderr, "\n"); } @@ -212,6 +228,7 @@ int run(whisper_params ¶ms, std::vector> &result) { wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; @@ -311,14 +328,28 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { bool use_gpu = whisper_params.Get("use_gpu").As(); bool no_prints = whisper_params.Get("no_prints").As(); bool no_timestamps = whisper_params.Get("no_timestamps").As(); + int32_t audio_ctx = whisper_params.Get("audio_ctx").As(); bool comma_in_time = whisper_params.Get("comma_in_time").As(); + Napi::Value pcmf32Value = whisper_params.Get("pcmf32"); + std::vector pcmf32_vec; + if (pcmf32Value.IsTypedArray()) { + Napi::Float32Array pcmf32 = pcmf32Value.As(); + size_t length = pcmf32.ElementLength(); + pcmf32_vec.reserve(length); + for (size_t i = 0; i < length; i++) { + pcmf32_vec.push_back(pcmf32[i]); + } + } + params.language = language; params.model = model; params.fname_inp.emplace_back(input); params.use_gpu = use_gpu; params.no_prints = no_prints; params.no_timestamps = no_timestamps; + params.audio_ctx = audio_ctx; + params.pcmf32 = pcmf32_vec; params.comma_in_time = comma_in_time; Napi::Function callback = info[1].As(); diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 90bd6fc2ff4..09b33c54024 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -16,13 +16,20 @@ const whisperParams = { comma_in_time: false, translate: true, no_timestamps: false, + audio_ctx: 0, }; const arguments = process.argv.slice(2); const params = Object.fromEntries( arguments.reduce((pre, item) => { if (item.startsWith("--")) { - return [...pre, item.slice(2).split("=")]; + const [key, value] = item.slice(2).split("="); + if (key === "audio_ctx") { + whisperParams[key] = parseInt(value); + } else { + whisperParams[key] = value; + } + return pre; } return pre; }, []) From 4ef8d9f44eb402c528ab6d990ab50a9f4f666347 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 May 2024 15:33:46 +0300 Subject: [PATCH 075/100] server : return utf-8 (#2138) --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index a98e156c26b..e3b96698228 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -831,7 +831,7 @@ int main(int argc, char ** argv) { if (params.response_format == text_format) { std::string results = output_str(ctx, params, pcmf32s); - res.set_content(results.c_str(), "text/html"); + res.set_content(results.c_str(), "text/html; charset=utf-8"); } else if (params.response_format == srt_format) { From d8356a1cc2b95218a808425b50734007fd13aa00 Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Tue, 14 May 2024 09:43:41 +0300 Subject: [PATCH 076/100] whisper : fix model path encoding in windows (#2086) * fix: model path encoding in windows * fix: convert model path to wide string only for MSVC compiler --- whisper.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/whisper.cpp b/whisper.cpp index bdcf3de40e2..ff4223daf42 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -3362,8 +3363,14 @@ struct whisper_context_params whisper_context_default_params() { struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) { WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); - +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. + std::wstring_convert> converter; + std::wstring path_model_wide = converter.from_bytes(path_model); + auto fin = std::ifstream(path_model_wide, std::ios::binary); +#else auto fin = std::ifstream(path_model, std::ios::binary); +#endif if (!fin) { WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); return nullptr; From 130f43e4b87d17ba9d1c68234e26d1180f4bb9a1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 May 2024 19:15:35 +0300 Subject: [PATCH 077/100] scripts : sync ggml-rpc --- scripts/sync-ggml-am.sh | 4 ++++ scripts/sync-ggml.sh | 2 ++ 2 files changed, 6 insertions(+) diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index 70ff16d0df2..54c243a15db 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -117,6 +117,8 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then # src/ggml-opencl.h -> ggml-opencl.h # src/ggml-quants.c -> ggml-quants.c # src/ggml-quants.h -> ggml-quants.h + # src/ggml-rpc.cpp -> ggml-rpc.cpp + # src/ggml-rpc.h -> ggml-rpc.h # src/ggml-sycl.cpp -> ggml-sycl.cpp # src/ggml-sycl.h -> ggml-sycl.h # src/ggml-vulkan.cpp -> ggml-vulkan.cpp @@ -160,6 +162,8 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then -e 's/src\/ggml-opencl\.h/ggml-opencl.h/g' \ -e 's/src\/ggml-quants\.c/ggml-quants.c/g' \ -e 's/src\/ggml-quants\.h/ggml-quants.h/g' \ + -e 's/src\/ggml-rpc\.cpp/ggml-rpc.cpp/g' \ + -e 's/src\/ggml-rpc\.h/ggml-rpc.h/g' \ -e 's/src\/ggml-sycl\.cpp/ggml-sycl.cpp/g' \ -e 's/src\/ggml-sycl\.h/ggml-sycl.h/g' \ -e 's/src\/ggml-vulkan\.cpp/ggml-vulkan.cpp/g' \ diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index 2efffcd213c..1b0f2045cf3 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -20,6 +20,8 @@ cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h cp -rpv ../ggml/src/ggml-quants.c ./ggml-quants.c cp -rpv ../ggml/src/ggml-quants.h ./ggml-quants.h +cp -rpv ../ggml/src/ggml-rpc.cpp ./ggml-rpc.cpp +cp -rpv ../ggml/src/ggml-rpc.h ./ggml-rpc.h cp -rpv ../ggml/src/ggml-sycl.cpp ./ggml-sycl.cpp cp -rpv ../ggml/src/ggml-sycl.h ./ggml-sycl.h cp -rpv ../ggml/src/ggml-vulkan.cpp ./ggml-vulkan.cpp From e57e95eb0d3bdba42bbf057c888f6ff819a5f59b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 12 May 2024 19:40:45 +0200 Subject: [PATCH 078/100] CUDA: add FP32 FlashAttention vector kernel (llama/7188) * CUDA: add FP32 FlashAttention vector kernel * fixup! CUDA: add FP32 FlashAttention vector kernel * fixup! fixup! CUDA: add FP32 FlashAttention vector kernel * fixup! fixup! fixup! CUDA: add FP32 FlashAttention vector kernel --- ggml-cuda.cu | 11 +- ggml-cuda/common.cuh | 4 + ggml-cuda/fattn-common.cuh | 47 ++++ ggml-cuda/fattn-vec-f16.cu | 430 +++++++++++++++++++++++++++++++++ ggml-cuda/fattn-vec-f16.cuh | 5 + ggml-cuda/fattn-vec-f32.cu | 384 +++++++++++++++++++++++++++++ ggml-cuda/fattn-vec-f32.cuh | 3 + ggml-cuda/fattn.cu | 468 ++---------------------------------- 8 files changed, 898 insertions(+), 454 deletions(-) create mode 100644 ggml-cuda/fattn-common.cuh create mode 100644 ggml-cuda/fattn-vec-f16.cu create mode 100644 ggml-cuda/fattn-vec-f16.cuh create mode 100644 ggml-cuda/fattn-vec-f32.cu create mode 100644 ggml-cuda/fattn-vec-f32.cuh diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5b6c9091924..75a2ad48087 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2713,6 +2713,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -2840,8 +2841,16 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: - case GGML_OP_FLASH_ATTN_EXT: return true; + case GGML_OP_FLASH_ATTN_EXT: +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128; +#else + if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) { + return true; + } + return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) default: return false; } diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 44e67e040e1..b6f0bc36a4f 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -321,6 +321,10 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { #define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +static bool fast_fp16_available(const int cc) { + return cc >= CC_PASCAL && cc != 610; +} + static bool fp16_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; } diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh new file mode 100644 index 00000000000..33f640691ad --- /dev/null +++ b/ggml-cuda/fattn-common.cuh @@ -0,0 +1,47 @@ +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. + +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst) { + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += D * gridDim.y*blockIdx.x; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + __shared__ float2 meta[parallel_blocks]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; + } + + __syncthreads(); + + float kqmax = meta[0].x; +#pragma unroll + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; +#pragma unroll + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + const float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; +} diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu new file mode 100644 index 00000000000..cbf5f7835f8 --- /dev/null +++ b/ggml-cuda/fattn-vec-f16.cu @@ -0,0 +1,430 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-vec-f16.cuh" + +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + ne11*ic0; + + const int stride_KV = nb11 / sizeof(half); + const int stride_KV2 = nb11 / sizeof(half2); + + half slopeh = __float2half(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slopeh = __float2half(powf(base, exph)); + } + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = D / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half KQ[ncols*D]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ[j*D + tid] = -HALF_MAX_HALF; + } + half2 * KQ2 = (half2 *) KQ; + + half kqmax[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax[j] = -HALF_MAX_HALF; + } + half kqsum[ncols] = {0.0f}; + + __shared__ half kqmax_shared[ncols][WARP_SIZE]; + __shared__ half kqsum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF; + kqsum_shared[j][threadIdx.x] = 0.0f; + } + } + __syncthreads(); + + // Convert Q to half2 and store in registers: + half2 Q_h2[ncols][D/(2*WARP_SIZE)]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i]; + Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + } + } + + half2 VKQ[ncols] = {{0.0f, 0.0f}}; + + const int k_start = parallel_blocks == 1 ? 0 : ip*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + // Calculate KQ tile and keep track of new maximum KQ values: + + // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, + // see https://github.com/ggerganov/llama.cpp/pull/7061 . + // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). + half kqmax_new = kqmax[0]; + half kqmax_new_arr[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax_new_arr[j] = kqmax[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } + + half2 sum2[ncols] = {{0.0f, 0.0f}}; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE]; + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + sum2[j] = warp_reduce_sum(sum2[j]); + half sum = __low2half(sum2[j]) + __high2half(sum2[j]); + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + + if (ncols == 1) { + kqmax_new = ggml_cuda_hmax(kqmax_new, sum); + } else { + kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum); + } + + if (threadIdx.x == 0) { + KQ[j*D + i_KQ] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; + + kqmax_new_j = warp_reduce_max(kqmax_new_j); + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = kqmax_new_j; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const half val = hexp(KQ[j*D + tid] - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale + val; + KQ[j*D + tid] = val; + + VKQ[j] *= __half2half2(KQ_max_scale); + } + + __syncthreads(); + +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } + + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqsum[j] = warp_reduce_sum(kqsum[j]); + if (threadIdx.x == 0) { + kqsum_shared[j][threadIdx.y] = kqsum[j]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; + kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); + + half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); + if (parallel_blocks == 1) { + dst_val /= kqsum[j_VKQ]; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + } + + if (parallel_blocks != 1 && tid != 0) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]); + } + } +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE +} + +template void launch_fattn_vec_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + flash_attn_vec_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (parallel_blocks == 1) { + return; + } + + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + const int32_t precision = KQV->op_params[2]; + GGML_ASSERT(precision == GGML_PREC_DEFAULT); + + constexpr int cols_per_block = 1; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } +} + +void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + const int32_t precision = KQV->op_params[2]; + GGML_ASSERT(precision == GGML_PREC_DEFAULT); + GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128."); + + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] == 2) { + constexpr int cols_per_block = 2; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 4) { + constexpr int cols_per_block = 4; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 8) { + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 1; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } +} diff --git a/ggml-cuda/fattn-vec-f16.cuh b/ggml-cuda/fattn-vec-f16.cuh new file mode 100644 index 00000000000..c7023610ab2 --- /dev/null +++ b/ggml-cuda/fattn-vec-f16.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu new file mode 100644 index 00000000000..40c336ce332 --- /dev/null +++ b/ggml-cuda/fattn-vec-f32.cu @@ -0,0 +1,384 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-vec-f32.cuh" + +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_vec_ext_f32( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + ne11*ic0; + + const int stride_KV = nb11 / sizeof(half); + const int stride_KV2 = nb11 / sizeof(half2); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = powf(base, exph); + } + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = D / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ float KQ[ncols*D]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ[j*D + tid] = -FLT_MAX/2.0f; + } + + float kqmax[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax[j] = -FLT_MAX/2.0f; + } + float kqsum[ncols] = {0.0f}; + + __shared__ float kqmax_shared[ncols][WARP_SIZE]; + __shared__ float kqsum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f; + kqsum_shared[j][threadIdx.x] = 0.0f; + } + } + __syncthreads(); + + // Convert Q to half2 and store in registers: + float2 Q_h2[ncols][D/(2*WARP_SIZE)]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_h2[j][i0/WARP_SIZE] = Q_f2[j*(nb01/sizeof(float2)) + i]; + Q_h2[j][i0/WARP_SIZE].x *= scale; + Q_h2[j][i0/WARP_SIZE].y *= scale; + } + } + + float VKQ[ncols] = {0.0f}; + + const int k_start = parallel_blocks == 1 ? 0 : ip*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + // Calculate KQ tile and keep track of new maximum KQ values: + + float kqmax_new_arr[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax_new_arr[j] = kqmax[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } + + float sum[ncols] = {0.0f}; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + sum[j] += __low2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].x; + sum[j] += __high2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].y; + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + sum[j] = warp_reduce_sum(sum[j]); + sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + + kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]); + + if (threadIdx.x == 0) { + KQ[j*D + i_KQ] = sum[j]; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float kqmax_new_j = kqmax_new_arr[j]; + + kqmax_new_j = warp_reduce_max(kqmax_new_j); + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = kqmax_new_j; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const float val = expf(KQ[j*D + tid] - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale + val; + KQ[j*D + tid] = val; + + VKQ[j] *= KQ_max_scale; + } + + __syncthreads(); + +#pragma unroll + for (int k = 0; k < D; ++k) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) { + break; + } + + const float V_ki = __half2float(V_h[(k_VKQ_0 + k)*stride_KV + tid]); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j] += V_ki*KQ[j*D + k]; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqsum[j] = warp_reduce_sum(kqsum[j]); + if (threadIdx.x == 0) { + kqsum_shared[j][threadIdx.y] = kqsum[j]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; + kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); + + float dst_val = VKQ[j_VKQ]; + if (parallel_blocks == 1) { + dst_val /= kqsum[j_VKQ]; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + } + + if (parallel_blocks != 1 && tid != 0) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]); + } + } +} + +template void launch_fattn_vec_f32( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + flash_attn_vec_ext_f32 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (parallel_blocks == 1) { + return; + } + + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128."); + + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] == 2) { + constexpr int cols_per_block = 2; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 4) { + constexpr int cols_per_block = 4; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 8) { + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 1; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } +} diff --git a/ggml-cuda/fattn-vec-f32.cuh b/ggml-cuda/fattn-vec-f32.cuh new file mode 100644 index 00000000000..614d54ae392 --- /dev/null +++ b/ggml-cuda/fattn-vec-f32.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index ac5d6672b30..419f8e752a7 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,4 +1,7 @@ #include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-vec-f16.cuh" +#include "fattn-vec-f32.cuh" #include "fattn.cuh" #include @@ -7,251 +10,6 @@ #include #endif -#define FATTN_KQ_STRIDE 256 -#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. -#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. - -template // D == head size -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void flash_attn_vec_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { -#if FP16_AVAILABLE - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. - - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; - - const int stride_KV = nb11 / sizeof(half); - const int stride_KV2 = nb11 / sizeof(half2); - - half slopeh = __float2half(1.0f); - - // ALiBi - if (max_bias > 0.0f) { - const int h = blockIdx.y; - - const float base = h < n_head_log2 ? m0 : m1; - const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slopeh = __float2half(powf(base, exph)); - } - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); - - __shared__ half KQ[ncols*D]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -HALF_MAX_HALF; - } - half2 * KQ2 = (half2 *) KQ; - - half kqmax[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax[j] = -HALF_MAX_HALF; - } - half kqsum[ncols] = {0.0f}; - - __shared__ half kqmax_shared[ncols][WARP_SIZE]; - __shared__ half kqsum_shared[ncols][WARP_SIZE]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.y == 0) { - kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF; - kqsum_shared[j][threadIdx.x] = 0.0f; - } - } - __syncthreads(); - - // Convert Q to half2 and store in registers: - half2 Q_h2[ncols][D/(2*WARP_SIZE)]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i]; - Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } - } - - half2 VKQ[ncols] = {{0.0f, 0.0f}}; - - const int k_start = parallel_blocks == 1 ? 0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { - // Calculate KQ tile and keep track of new maximum KQ values: - - // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, - // see https://github.com/ggerganov/llama.cpp/pull/7061 . - // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). - half kqmax_new = kqmax[0]; - half kqmax_new_arr[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax_new_arr[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { - break; - } - - half2 sum2[ncols] = {{0.0f, 0.0f}}; -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE]; - } - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - sum2[j] = warp_reduce_sum(sum2[j]); - half sum = __low2half(sum2[j]) + __high2half(sum2[j]); - sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); - - if (ncols == 1) { - kqmax_new = ggml_cuda_hmax(kqmax_new, sum); - } else { - kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum); - } - - if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; - } - } - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; - - kqmax_new_j = warp_reduce_max(kqmax_new_j); - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = kqmax_new_j; - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const half val = hexp(KQ[j*D + tid] - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; - - VKQ[j] *= __half2half2(KQ_max_scale); - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } - - half2 V_k; - reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; - reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; - } - } - - __syncthreads(); - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqsum[j] = warp_reduce_sum(kqsum[j]); - if (threadIdx.x == 0) { - kqsum_shared[j][threadIdx.y] = kqsum[j]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; - kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); - - half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); - if (parallel_blocks == 1) { - dst_val /= kqsum[j_VKQ]; - } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; - } - - if (parallel_blocks != 1 && tid != 0) { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]); - } - } -#else - NO_DEVICE_CODE; -#endif // FP16_AVAILABLE -} - // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -655,54 +413,6 @@ static __global__ void flash_attn_ext_f16( #endif // FP16_MMA_AVAILABLE } -template // D == head size -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const float2 * __restrict__ VKQ_meta, - float * __restrict__ dst) { -#if FP16_AVAILABLE - VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; - dst += D * gridDim.y*blockIdx.x; - - const int tid = threadIdx.x; - __builtin_assume(tid < D); - - __shared__ float2 meta[parallel_blocks]; - if (tid < 2*parallel_blocks) { - ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; - } - - __syncthreads(); - - float kqmax = meta[0].x; -#pragma unroll - for (int l = 1; l < parallel_blocks; ++l) { - kqmax = max(kqmax, meta[l].x); - } - - float VKQ_numerator = 0.0f; - float VKQ_denominator = 0.0f; -#pragma unroll - for (int l = 0; l < parallel_blocks; ++l) { - const float diff = meta[l].x - kqmax; - const float KQ_max_scale = expf(diff); - const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); - *((uint32_t *) &KQ_max_scale) &= ftz_mask; - - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; - VKQ_denominator += KQ_max_scale * meta[l].y; - } - - dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; -#else - NO_DEVICE_CODE; -#endif // FP16_AVAILABLE -} - constexpr int get_max_power_of_2(int x) { return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; } @@ -727,66 +437,6 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -template void launch_fattn_vec_f16( - const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, - ggml_cuda_pool & pool, cudaStream_t main_stream -) { - ggml_cuda_pool_alloc dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); - - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } - - constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; - const dim3 block_dim(WARP_SIZE, nwarps, 1); - const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); - const int shmem = 0; - - float scale = 1.0f; - float max_bias = 0.0f; - - memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); - - const uint32_t n_head = Q->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - flash_attn_vec_ext_f16 - <<>> ( - (const char *) Q->data, - (const char *) K->data, - (const char *) V->data, - mask ? ((const char *) mask->data) : nullptr, - parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, max_bias, m0, m1, n_head_log2, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - CUDA_CHECK(cudaGetLastError()); - - if (parallel_blocks == 1) { - return; - } - - const dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); - const int shmem_combine = 0; - - flash_attn_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); - CUDA_CHECK(cudaGetLastError()); -} - template void launch_fattn_f16_impl( const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, ggml_cuda_pool & pool, cudaStream_t main_stream @@ -891,95 +541,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int32_t precision = KQV->op_params[2]; - if (!fp16_mma_available(cc)) { - GGML_ASSERT(precision == GGML_PREC_DEFAULT); - GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128."); - - if (Q->ne[1] == 1) { - constexpr int cols_per_block = 1; - constexpr int parallel_blocks = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - case 128: - launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - default: - GGML_ASSERT(false); - break; - } - return; - } - - if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - constexpr int parallel_blocks = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - case 128: - launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - default: - GGML_ASSERT(false); - break; - } - return; - } - - if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - constexpr int parallel_blocks = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - case 128: - launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - default: - GGML_ASSERT(false); - break; - } - return; - } - - if (Q->ne[1] <= 8) { - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - case 128: - launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - default: - GGML_ASSERT(false); - break; - } - return; - } + if (!fast_fp16_available(cc)) { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + return; + } - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 1; - switch (Q->ne[0]) { - case 64: - launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - case 128: - launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - default: - GGML_ASSERT(false); - break; - } + if (!fp16_mma_available(cc)) { + ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst); return; } if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] == 1 && (Q->ne[0] == 64 || Q->ne[0] == 128)) { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + return; + } + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { constexpr int cols_per_block = 16; constexpr int nwarps = 4; @@ -1037,22 +614,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { - constexpr int cols_per_block = 1; - constexpr int parallel_blocks = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - case 128: - launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - case 256: - launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); - break; - default: - GGML_ASSERT(false); - break; - } + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); return; } From 8e7c22fbdbc27c1c72abd192720e330e7f6361a9 Mon Sep 17 00:00:00 2001 From: Neo Zhang <14088817+arthw@users.noreply.github.com> Date: Mon, 13 May 2024 18:11:26 +0800 Subject: [PATCH 079/100] rm wait() (llama/7233) --- ggml-sycl.cpp | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index e93d2af631c..724070eb910 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -15564,26 +15564,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; -#if 0 - // use syclGemmEx - { - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - int i03 = i13 / r3; - int i02 = i12 / r2; - - SYCL_CHECK( - syclGemmEx(g_sycl_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , SYCL_R_16F, nb01/sizeof(half), - (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, SYCL_R_16F, nb11/sizeof(float), - beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } - } - } -#else if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( @@ -15595,7 +15575,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, nb11 / nb10, nb12 / nb10, beta, (char *)dst_t, cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type))); - g_sycl_handles[g_main_device]->wait(); } else { const int ne23 = ne12*ne13; @@ -15626,7 +15605,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, nb02, nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); }); - }).wait(); + }); } SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans, @@ -15637,9 +15616,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, dpct::library_data_t::real_half, nb11 / nb10, beta, (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type))); - g_sycl_handles[g_main_device]->wait(); } -#endif if (no_mixed_dtypes) { const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); From c451080c8b0e2080f2ca887047ef381b94523e14 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 14 May 2024 14:27:19 +0300 Subject: [PATCH 080/100] ggml : add RPC backend (llama/6829) * ggml : add RPC backend The RPC backend proxies all operations to a remote server which runs a regular backend (CPU, CUDA, Metal, etc). * set TCP_NODELAY * add CI workflows * Address review comments * fix warning * implement llama_max_devices() for RPC * Address review comments * Address review comments * wrap sockfd into a struct * implement get_alignment and get_max_size * add get_device_memory * fix warning * win32 support * add README * readme : trim trailing whitespace * Address review comments * win32 fix * Address review comments * fix compile warnings on macos --- ggml-rpc.cpp | 1023 ++++++++++++++++++++++++++++++++++++++++++++++++++ ggml-rpc.h | 24 ++ 2 files changed, 1047 insertions(+) create mode 100644 ggml-rpc.cpp create mode 100644 ggml-rpc.h diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp new file mode 100644 index 00000000000..efeacb29767 --- /dev/null +++ b/ggml-rpc.cpp @@ -0,0 +1,1023 @@ +#include "ggml-rpc.h" +#include "ggml.h" +#include "ggml-backend-impl.h" + +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +# include +# include +# include +# include +# include +#endif +#include + +#define UNUSED GGML_UNUSED + +#define GGML_DEBUG 1 +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#ifdef _WIN32 +typedef SOCKET sockfd_t; +using ssize_t = __int64; +#else +typedef int sockfd_t; +#endif + +// cross-platform socket +struct socket_t { + sockfd_t fd; + socket_t(sockfd_t fd) : fd(fd) {} + ~socket_t() { +#ifdef _WIN32 + closesocket(this->fd); +#else + close(this->fd); +#endif + } +}; + +// ggml_tensor is serialized into rpc_tensor +struct rpc_tensor { + uint64_t id; + uint32_t type; + uint64_t buffer; + uint32_t ne[GGML_MAX_DIMS]; + uint32_t nb[GGML_MAX_DIMS]; + uint32_t op; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + int32_t flags; + uint64_t src[GGML_MAX_SRC]; + uint64_t view_src; + uint64_t view_offs; + uint64_t data; + char name[GGML_MAX_NAME]; +}; + +// RPC commands +enum rpc_cmd { + ALLOC_BUFFER = 0, + GET_ALIGNMENT, + GET_MAX_SIZE, + BUFFER_GET_BASE, + FREE_BUFFER, + BUFFER_CLEAR, + SET_TENSOR, + GET_TENSOR, + COPY_TENSOR, + GRAPH_COMPUTE, + GET_DEVICE_MEMORY, +}; + +// RPC data structures + +static ggml_guid_t ggml_backend_rpc_guid() { + static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03}; + return &guid; +} + +struct ggml_backend_rpc_buffer_type_context { + std::shared_ptr sock; + std::string name; + size_t alignment; + size_t max_size; +}; + +struct ggml_backend_rpc_context { + std::string endpoint; + std::string name; + std::shared_ptr sock; + ggml_backend_buffer_type_t buft; +}; + +struct ggml_backend_rpc_buffer_context { + std::shared_ptr sock; + std::unordered_map base_cache; + uint64_t remote_ptr; + std::string name; +}; + +// RPC helper functions + +static std::shared_ptr make_socket(sockfd_t fd) { +#ifdef _WIN32 + if (fd == INVALID_SOCKET) { + return nullptr; + } +#else + if (fd < 0) { + return nullptr; + } +#endif + return std::make_shared(fd); +} + +static bool set_no_delay(sockfd_t sockfd) { + int flag = 1; + // set TCP_NODELAY to disable Nagle's algorithm + int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); + return ret >= 0; +} + +static std::shared_ptr socket_connect(const char * host, int port) { + struct sockaddr_in addr; + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + auto sock_ptr = make_socket(sockfd); + if (sock_ptr == nullptr) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + fprintf(stderr, "Failed to set TCP_NODELAY\n"); + return nullptr; + } + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + fprintf(stderr, "Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return sock_ptr; +} + +static std::shared_ptr socket_accept(sockfd_t srv_sockfd) { + auto client_socket_fd = accept(srv_sockfd, NULL, NULL); + auto client_socket = make_socket(client_socket_fd); + if (client_socket == nullptr) { + return nullptr; + } + if (!set_no_delay(client_socket_fd)) { + fprintf(stderr, "Failed to set TCP_NODELAY\n"); + return nullptr; + } + return client_socket; +} + +static std::shared_ptr create_server_socket(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + auto sock = make_socket(sockfd); + if (sock == nullptr) { + return nullptr; + } + + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return sock; +} + +static bool send_data(sockfd_t sockfd, const void * data, size_t size) { + size_t bytes_sent = 0; + while (bytes_sent < size) { + ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0); + if (n < 0) { + return false; + } + bytes_sent += n; + } + return true; +} + +static bool recv_data(sockfd_t sockfd, void * data, size_t size) { + size_t bytes_recv = 0; + while (bytes_recv < size) { + ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0); + if (n <= 0) { + return false; + } + bytes_recv += n; + } + return true; +} + +static bool parse_endpoint(const char * endpoint, std::string & host, int & port) { + std::string str(endpoint); + size_t pos = str.find(':'); + if (pos == std::string::npos) { + return false; + } + host = str.substr(0, pos); + port = std::stoi(str.substr(pos + 1)); + return true; +} + +// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | +// RPC response: | response_size (8 bytes) | response_data (response_size bytes) | +static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const std::vector & input, std::vector & output) { + uint8_t cmd_byte = cmd; + if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { + return false; + } + uint64_t input_size = input.size(); + if (!send_data(sock->fd, &input_size, sizeof(input_size))) { + return false; + } + if (!send_data(sock->fd, input.data(), input.size())) { + return false; + } + uint64_t output_size; + if (!recv_data(sock->fd, &output_size, sizeof(output_size))) { + return false; + } + if (output_size == 0) { + output.clear(); + return true; + } + output.resize(output_size); + if (!recv_data(sock->fd, output.data(), output_size)) { + return false; + } + return true; +} + +// RPC client-side implementation + +GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + return ctx->name.c_str(); +} + +GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + // input serialization format: | remote_ptr (8 bytes) | + std::vector input(sizeof(uint64_t), 0); + uint64_t remote_ptr = ctx->remote_ptr; + memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); + std::vector output; + bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.empty()); + delete ctx; +} + +GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) { + return ctx->base_cache[buffer]; + } + // input serialization format: | remote_ptr (8 bytes) | + std::vector input(sizeof(uint64_t), 0); + uint64_t remote_ptr = ctx->remote_ptr; + memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); + std::vector output; + bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.size() == sizeof(uint64_t)); + // output serialization format: | base_ptr (8 bytes) | + uint64_t base_ptr; + memcpy(&base_ptr, output.data(), sizeof(base_ptr)); + void * base = reinterpret_cast(base_ptr); + ctx->base_cache[buffer] = base; + return base; +} + +static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { + rpc_tensor result; + result.id = reinterpret_cast(tensor); + result.type = tensor->type; + if (tensor->buffer) { + ggml_backend_buffer_t buffer = tensor->buffer; + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + result.buffer = ctx->remote_ptr; + } else { + result.buffer = 0; + } + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result.ne[i] = tensor->ne[i]; + result.nb[i] = tensor->nb[i]; + } + result.op = tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result.op_params[i] = tensor->op_params[i]; + } + result.flags = tensor->flags; + for (uint32_t i = 0; i < GGML_MAX_SRC; i++) { + result.src[i] = reinterpret_cast(tensor->src[i]); + } + result.view_src = reinterpret_cast(tensor->view_src); + result.view_offs = tensor->view_offs; + result.data = reinterpret_cast(tensor->data); + snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name); + return result; +} + +static ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) { + ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, + tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result->nb[i] = tensor->nb[i]; + } + result->buffer = reinterpret_cast(tensor->buffer); + result->op = (ggml_op) tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result->op_params[i] = tensor->op_params[i]; + } + result->flags = tensor->flags; + result->data = reinterpret_cast(tensor->data); + ggml_set_name(result, tensor->name); + return result; +} + +GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + UNUSED(buffer); + if (ggml_is_quantized(tensor->type)) { + // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized + GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor"); + } +} + +GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) | + size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size; + std::vector input(input_size, 0); + rpc_tensor rpc_tensor = serialize_tensor(tensor); + memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); + memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); + memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); + std::vector output; + bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output); + GGML_ASSERT(status); +} + +GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) | + int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t); + std::vector input(input_size, 0); + rpc_tensor rpc_tensor = serialize_tensor(tensor); + memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); + memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); + memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size)); + std::vector output; + bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.size() == size); + // output serialization format: | data (size bytes) | + memcpy(data, output.data(), size); +} + +GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + // check if src and dst are on the same server + ggml_backend_buffer_t src_buffer = src->buffer; + ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; + ggml_backend_buffer_t dst_buffer = dst->buffer; + ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; + if (src_ctx->sock != dst_ctx->sock) { + return false; + } + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + // input serialization format: | rpc_tensor src | rpc_tensor dst | + int input_size = 2*sizeof(rpc_tensor); + std::vector input(input_size, 0); + rpc_tensor rpc_src = serialize_tensor(src); + rpc_tensor rpc_dst = serialize_tensor(dst); + memcpy(input.data(), &rpc_src, sizeof(rpc_src)); + memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst)); + std::vector output; + bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output); + GGML_ASSERT(status); + // output serialization format: | result (1 byte) | + GGML_ASSERT(output.size() == 1); + return output[0]; +} + +GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + // serialization format: | bufptr (8 bytes) | value (1 byte) | + int input_size = sizeof(uint64_t) + sizeof(uint8_t); + std::vector input(input_size, 0); + memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr)); + memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value)); + std::vector output; + bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output); + GGML_ASSERT(status); +} + +static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { + /* .get_name = */ ggml_backend_rpc_buffer_get_name, + /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer, + /* .get_base = */ ggml_backend_rpc_buffer_get_base, + /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor, + /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, + /* .clear = */ ggml_backend_rpc_buffer_clear, + /* .reset = */ NULL, +}; + +GGML_CALL static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->name.c_str(); +} + +GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + // input serialization format: | size (8 bytes) | + int input_size = sizeof(uint64_t); + std::vector input(input_size, 0); + memcpy(input.data(), &size, sizeof(size)); + std::vector output; + bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); + // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | + uint64_t remote_ptr; + memcpy(&remote_ptr, output.data(), sizeof(remote_ptr)); + size_t remote_size; + memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size)); + + ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, + ggml_backend_rpc_buffer_interface, + new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"}, + remote_size); + + return buffer; +} + +static size_t get_alignment(const std::shared_ptr & sock) { + // input serialization format: | 0 bytes | + std::vector input; + std::vector output; + bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.size() == sizeof(uint64_t)); + // output serialization format: | alignment (8 bytes) | + uint64_t alignment; + memcpy(&alignment, output.data(), sizeof(alignment)); + return alignment; +} + +GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->alignment; +} + +static size_t get_max_size(const std::shared_ptr & sock) { + // input serialization format: | 0 bytes | + std::vector input; + std::vector output; + bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.size() == sizeof(uint64_t)); + // output serialization format: | max_size (8 bytes) | + uint64_t max_size; + memcpy(&max_size, output.data(), sizeof(max_size)); + return max_size; +} + +GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->max_size; +} + +GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + UNUSED(buft); + return ggml_nbytes(tensor); +} + +GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + if (!ggml_backend_is_rpc(backend)) { + return false; + } + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + return buft_ctx->sock == rpc_ctx->sock; +} + +static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { + /* .get_name = */ ggml_backend_rpc_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_rpc_get_max_size, + /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size, + /* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend, + /* .is_host = */ NULL, +}; + + +GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + + return rpc_ctx->name.c_str(); +} + +GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context; + delete buft_ctx; + delete rpc_ctx->buft; + delete rpc_ctx; + delete backend; +} + +GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { + ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; + return ctx->buft; +} + +GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { + UNUSED(backend); + // this is no-op because we don't have any async operations +} + +static void add_tensor(ggml_tensor * tensor, std::vector & tensors, std::unordered_set & visited) { + if (tensor == nullptr) { + return; + } + if (visited.find(tensor) != visited.end()) { + return; + } + visited.insert(tensor); + for (int i = 0; i < GGML_MAX_SRC; i++) { + add_tensor(tensor->src[i], tensors, visited); + } + add_tensor(tensor->view_src, tensors, visited); + tensors.push_back(serialize_tensor(tensor)); +} + +static void serialize_graph(const ggml_cgraph * cgraph, std::vector & output) { + uint32_t n_nodes = cgraph->n_nodes; + std::vector tensors; + std::unordered_set visited; + for (uint32_t i = 0; i < n_nodes; i++) { + add_tensor(cgraph->nodes[i], tensors, visited); + } + // serialization format: + // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + uint32_t n_tensors = tensors.size(); + int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); + output.resize(output_size, 0); + memcpy(output.data(), &n_nodes, sizeof(n_nodes)); + uint64_t * out_nodes = (uint64_t *)(output.data() + sizeof(n_nodes)); + for (uint32_t i = 0; i < n_nodes; i++) { + out_nodes[i] = reinterpret_cast(cgraph->nodes[i]); + } + uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); + *out_ntensors = n_tensors; + rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); + memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); +} + +GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + std::vector input; + serialize_graph(cgraph, input); + std::vector output; + bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.size() == 1); + return (enum ggml_status)output[0]; +} + +GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) { + UNUSED(backend); + UNUSED(op); + GGML_ASSERT(false && "not implemented"); + return false; +} + +static ggml_backend_i ggml_backend_rpc_interface = { + /* .get_name = */ ggml_backend_rpc_name, + /* .free = */ ggml_backend_rpc_free, + /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ ggml_backend_rpc_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_rpc_graph_compute, + /* .supports_op = */ ggml_backend_rpc_supports_op, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static std::unordered_map instances; + +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { + ggml_backend_t backend = ggml_backend_rpc_init(endpoint); + return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr; +} + +GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { + std::string endpoint_str(endpoint); + if (instances.find(endpoint_str) != instances.end()) { + return instances[endpoint_str]; + } +#ifdef _WIN32 + { + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return nullptr; + } + } +#endif + GGML_PRINT_DEBUG("Connecting to %s\n", endpoint); + std::string host; + int port; + if (!parse_endpoint(endpoint, host, port)) { + return nullptr; + } + auto sock = socket_connect(host.c_str(), port); + if (sock == nullptr) { + return nullptr; + } + size_t alignment = get_alignment(sock); + size_t max_size = get_max_size(sock); + ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { + /* .sock = */ sock, + /* .name = */ "RPC" + std::to_string(sock->fd), + /* .alignment = */ alignment, + /* .max_size = */ max_size + }; + + ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { + /* .iface = */ ggml_backend_rpc_buffer_type_interface, + /* .context = */ buft_ctx + }; + + ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { + /* .endpoint = */ endpoint, + /* .name = */ "RPC" + std::to_string(sock->fd), + /* .sock = */ sock, + /* .buft = */ buft + }; + + instances[endpoint] = new ggml_backend { + /* .guid = */ ggml_backend_rpc_guid(), + /* .interface = */ ggml_backend_rpc_interface, + /* .context = */ ctx + }; + + return instances[endpoint]; +} + +GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); +} + +static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { + // input serialization format: | 0 bytes | + std::vector input; + std::vector output; + bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); + // output serialization format: | free (8 bytes) | total (8 bytes) | + uint64_t free_mem; + memcpy(&free_mem, output.data(), sizeof(free_mem)); + uint64_t total_mem; + memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem)); + *free = free_mem; + *total = total_mem; +} + +GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { + ggml_backend_t backend = ggml_backend_rpc_init(endpoint); + if (backend == nullptr) { + *free = 0; + *total = 0; + return; + } + ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; + get_device_memory(ctx->sock, free, total); +} + +// RPC server-side implementation + +static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector & input, std::vector & output) { + // input serialization format: | size (8 bytes) | + uint64_t size; + memcpy(&size, input.data(), sizeof(size)); + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size); + uint64_t remote_ptr = reinterpret_cast(buffer); + uint64_t remote_size = buffer->size; + GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size); + // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | + output.resize(2*sizeof(uint64_t), 0); + memcpy(output.data(), &remote_ptr, sizeof(remote_ptr)); + memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size)); +} + +static void rpc_get_alignment(ggml_backend_t backend, std::vector & output) { + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + size_t alignment = ggml_backend_buft_get_alignment(buft); + GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment); + // output serialization format: | alignment (8 bytes) | + output.resize(sizeof(uint64_t), 0); + memcpy(output.data(), &alignment, sizeof(alignment)); +} + +static void rpc_get_max_size(ggml_backend_t backend, std::vector & output) { + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + size_t max_size = ggml_backend_buft_get_max_size(buft); + GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size); + // output serialization format: | max_size (8 bytes) | + output.resize(sizeof(uint64_t), 0); + memcpy(output.data(), &max_size, sizeof(max_size)); +} + +static void rpc_buffer_get_base(const std::vector & input, std::vector & output) { + // input serialization format: | remote_ptr (8 bytes) | + uint64_t remote_ptr; + memcpy(&remote_ptr, input.data(), sizeof(remote_ptr)); + GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr); + ggml_backend_buffer_t buffer = reinterpret_cast(remote_ptr); + void * base = ggml_backend_buffer_get_base(buffer); + // output serialization format: | base_ptr (8 bytes) | + uint64_t base_ptr = reinterpret_cast(base); + output.resize(sizeof(uint64_t), 0); + memcpy(output.data(), &base_ptr, sizeof(base_ptr)); +} + +static void rpc_free_buffer(const std::vector & input) { + // input serialization format: | remote_ptr (8 bytes) | + uint64_t remote_ptr; + memcpy(&remote_ptr, input.data(), sizeof(remote_ptr)); + GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr); + ggml_backend_buffer_t buffer = reinterpret_cast(remote_ptr); + ggml_backend_buffer_free(buffer); +} + +static void rpc_buffer_clear(const std::vector & input) { + // input serialization format: | remote_ptr (8 bytes) | value (1 byte) | + uint64_t remote_ptr; + memcpy(&remote_ptr, input.data(), sizeof(remote_ptr)); + uint8_t value; + memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value)); + GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value); + ggml_backend_buffer_t buffer = reinterpret_cast(remote_ptr); + ggml_backend_buffer_clear(buffer, value); +} + +static void rpc_set_tensor(const std::vector & input) { + // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) | + const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); + uint64_t offset; + memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); + size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset); + + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); + GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); + const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset); + ggml_backend_tensor_set(tensor, data, offset, size); + ggml_free(ctx); +} + +static void rpc_get_tensor(const std::vector & input, std::vector & output) { + // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) | + const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); + uint64_t offset; + memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); + uint64_t size; + memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size)); + + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); + GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); + // output serialization format: | data (size bytes) | + output.resize(size, 0); + ggml_backend_tensor_get(tensor, output.data(), offset, size); + ggml_free(ctx); +} + +static void rpc_copy_tensor(const std::vector & input, std::vector & output) { + // serialization format: | rpc_tensor src | rpc_tensor dst | + const rpc_tensor * rpc_src = (const rpc_tensor *)input.data(); + const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src)); + + struct ggml_init_params params { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * src = deserialize_tensor(ctx, rpc_src); + ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst); + GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer); + bool result = ggml_backend_buffer_copy_tensor(src, dst); + // output serialization format: | result (1 byte) | + output.resize(1, 0); + output[0] = result; + ggml_free(ctx); +} + +static struct ggml_tensor * create_node(uint64_t id, + struct ggml_context * ctx, + const std::unordered_map & tensor_ptrs, + std::unordered_map & tensor_map) { + if (id == 0) { + return nullptr; + } + if (tensor_map.find(id) != tensor_map.end()) { + return tensor_map[id]; + } + const rpc_tensor * tensor = tensor_ptrs.at(id); + struct ggml_tensor * result = deserialize_tensor(ctx, tensor); + tensor_map[id] = result; + for (int i = 0; i < GGML_MAX_SRC; i++) { + result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map); + } + result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map); + result->view_offs = tensor->view_offs; + return result; +} + +static void rpc_graph_compute(ggml_backend_t backend, const std::vector & input, std::vector & output) { + // serialization format: + // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + uint32_t n_nodes; + memcpy(&n_nodes, input.data(), sizeof(n_nodes)); + const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes)); + uint32_t n_tensors; + memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors)); + const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors)); + GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors); + + static size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx = ggml_init(params); + struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); + graph->n_nodes = n_nodes; + std::unordered_map tensor_ptrs; + for (uint32_t i = 0; i < n_tensors; i++) { + tensor_ptrs[tensors[i].id] = &tensors[i]; + } + std::unordered_map tensor_map; + for (uint32_t i = 0; i < n_nodes; i++) { + graph->nodes[i] = create_node(nodes[i], ctx, tensor_ptrs, tensor_map); + } + ggml_status status = ggml_backend_graph_compute(backend, graph); + // output serialization format: | status (1 byte) | + output.resize(1, 0); + output[0] = status; + ggml_free(ctx); +} + +static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) { + while (true) { + uint8_t cmd; + if (!recv_data(sockfd, &cmd, 1)) { + break; + } + std::vector input; + std::vector output; + uint64_t input_size; + if (!recv_data(sockfd, &input_size, sizeof(input_size))) { + break; + } + input.resize(input_size); + if (!recv_data(sockfd, input.data(), input_size)) { + break; + } + switch (cmd) { + case ALLOC_BUFFER: { + rpc_alloc_buffer(backend, input, output); + break; + } + case GET_ALIGNMENT: { + rpc_get_alignment(backend, output); + break; + } + case GET_MAX_SIZE: { + rpc_get_max_size(backend, output); + break; + } + case BUFFER_GET_BASE: { + rpc_buffer_get_base(input, output); + break; + } + case FREE_BUFFER: { + rpc_free_buffer(input); + break; + } + case BUFFER_CLEAR: { + rpc_buffer_clear(input); + break; + } + case SET_TENSOR: { + rpc_set_tensor(input); + break; + } + case GET_TENSOR: { + rpc_get_tensor(input, output); + break; + } + case COPY_TENSOR: { + rpc_copy_tensor(input, output); + break; + } + case GRAPH_COMPUTE: { + rpc_graph_compute(backend, input, output); + break; + } + case GET_DEVICE_MEMORY: { + // output serialization format: | free (8 bytes) | total (8 bytes) | + output.resize(2*sizeof(uint64_t), 0); + memcpy(output.data(), &free_mem, sizeof(free_mem)); + memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem)); + break; + } + default: { + fprintf(stderr, "Unknown command: %d\n", cmd); + return; + } + } + uint64_t output_size = output.size(); + if (!send_data(sockfd, &output_size, sizeof(output_size))) { + break; + } + if (!send_data(sockfd, output.data(), output_size)) { + break; + } + } +} + +void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) { + std::string host; + int port; + if (!parse_endpoint(endpoint, host, port)) { + return; + } +#ifdef _WIN32 + { + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + fprintf(stderr, "WSAStartup failed: %d\n", res); + return; + } + } +#endif + auto server_socket = create_server_socket(host.c_str(), port); + if (server_socket == nullptr) { + fprintf(stderr, "Failed to create server socket\n"); + return; + } + while (true) { + auto client_socket = socket_accept(server_socket->fd); + if (client_socket == nullptr) { + fprintf(stderr, "Failed to accept client connection\n"); + return; + } + printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); + rpc_serve_client(backend, client_socket->fd, free_mem, total_mem); + printf("Client connection closed\n"); + } +#ifdef _WIN32 + WSACleanup(); +#endif +} diff --git a/ggml-rpc.h b/ggml-rpc.h new file mode 100644 index 00000000000..aa144832a6e --- /dev/null +++ b/ggml-rpc.h @@ -0,0 +1,24 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_RPC_MAX_SERVERS 16 + +// backend API +GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint); +GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend); + +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); + +GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); + +GGML_API GGML_CALL void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem); + +#ifdef __cplusplus +} +#endif From 1056ad762cdc15141e4eba3db5bd6e83eaa4b28f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 May 2024 19:09:30 +0300 Subject: [PATCH 081/100] metal : support FA without mask + add asserts (llama/7278) * ggml : fa without mask + add asserts ggml-ci * metal : support non-contiguous KV ggml-ci --- ggml-metal.m | 69 ++++++++++++++++++++++++++---------------------- ggml-metal.metal | 53 ++++++++++++++----------------------- ggml.c | 10 +++++++ ggml.h | 3 ++- 4 files changed, 70 insertions(+), 65 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index bfa352c3a9a..390a1cd7890 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2512,13 +2512,14 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_FLASH_ATTN_EXT: { - GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne11 % 32 == 0); + GGML_ASSERT(src0->type == GGML_TYPE_F32); - struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + GGML_ASSERT(ggml_are_same_shape (src1, src2)); - GGML_ASSERT(ggml_are_same_shape(src1, src2)); - GGML_ASSERT(src3); + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; size_t offs_src3 = 0; @@ -2528,6 +2529,11 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); @@ -2590,34 +2596,35 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + } [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25]; - [encoder setBytes:&scale length:sizeof( float) atIndex:26]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:27]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:28]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:29]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&scale length:sizeof( float) atIndex:23]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml-metal.metal b/ggml-metal.metal index 7af4e8f9342..57fdf564e17 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2049,27 +2049,24 @@ typedef void (flash_attn_ext_f16_t)( device const char * v, device const char * mask, device float * dst, - constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, constant uint64_t & nb31, - constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, - constant int64_t & ne3, constant float & scale, constant float & max_bias, constant float & m0, @@ -2090,27 +2087,24 @@ kernel void kernel_flash_attn_ext_f16( device const char * v, device const char * mask, device float * dst, - constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, constant uint64_t & nb31, - constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, - constant int64_t & ne3, constant float & scale, constant float & max_bias, constant float & m0, @@ -2180,10 +2174,6 @@ kernel void kernel_flash_attn_ext_f16( const short ne22 = ne12; const short ne23 = ne13; - const uint nb21 = nb11; - const uint nb22 = nb12; - const uint nb23 = nb13; - // broadcast const short rk2 = ne02/ne12; const short rk3 = ne03/ne13; @@ -2247,11 +2237,16 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } - // mqk = mqk*scale + mask*slope - simdgroup_half8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); - simdgroup_multiply(mm, mslope, mm); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + if (mask != q) { + // mqk = mqk*scale + mask*slope + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply(mm, mslope, mm); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + } else { + // mqk = mqk*scale + simdgroup_multiply(mqk, mscale, mqk); + } simdgroup_store(mqk, ss + 8*cc, TF, 0, false); } @@ -2425,27 +2420,24 @@ kernel void kernel_flash_attn_ext_vec_f16( device const char * v, device const char * mask, device float * dst, - constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, constant uint64_t & nb31, - constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, - constant int64_t & ne3, constant float & scale, constant float & max_bias, constant float & m0, @@ -2521,10 +2513,6 @@ kernel void kernel_flash_attn_ext_vec_f16( const short ne22 = ne12; const short ne23 = ne13; - const uint nb21 = nb11; - const uint nb22 = nb12; - const uint nb23 = nb13; - // broadcast const short rk2 = ne02/ne12; const short rk3 = ne03/ne13; @@ -2589,8 +2577,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask*slope if (tiisg == 0) { - float4 mm = (float4) mp4[ic/4 + cc]; - mqk = mqk*scale + mm*slope; + mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); ss4[cc] = mqk; } diff --git a/ggml.c b/ggml.c index d443a9b42ce..03b609dddce 100644 --- a/ggml.c +++ b/ggml.c @@ -2824,6 +2824,16 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor (t0->ne[3] == t1->ne[3] ); } +bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->nb[0] == t1->nb[0] ) && + (t0->nb[1] == t1->nb[1] ) && + (t0->nb[2] == t1->nb[2] ) && + (t0->nb[3] == t1->nb[3] ); +} + // check if t1 can be represented as a repeatition of t0 static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); diff --git a/ggml.h b/ggml.h index 3fe95ed5763..25f4f73a8d9 100644 --- a/ggml.h +++ b/ggml.h @@ -766,7 +766,8 @@ extern "C" { GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars - GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1); + GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); + GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); // use this to compute the memory overhead of a tensor GGML_API size_t ggml_tensor_overhead(void); From f56b8305c4f5760b5612a93305ed57aef082bfa5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 May 2024 19:16:32 +0300 Subject: [PATCH 082/100] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 0096c0b533a..35eef4660fd 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -9149580f5e15fa7510fa3413516fbf517cf2e921 +e87c0557b012350005269c49e1c2b5a8631da59a From 9d5771ae43d7fc7cca9d31dd924b13a29144e476 Mon Sep 17 00:00:00 2001 From: petterreinholdtsen Date: Tue, 14 May 2024 20:32:41 +0200 Subject: [PATCH 083/100] talk-llama : reject runs without required arguments (#2153) * Extended talk-llama example to reject runs without required arguments. Print warning and exit if models are not specified on the command line. * Update examples/talk-llama/talk-llama.cpp * Update examples/talk-llama/talk-llama.cpp --------- Co-authored-by: Georgi Gerganov --- examples/talk-llama/talk-llama.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index bb8c26d5efd..838d6f56357 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -288,6 +288,10 @@ int main(int argc, char ** argv) { cparams.use_gpu = params.use_gpu; struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams); + if (!ctx_wsp) { + fprintf(stderr, "No whisper.cpp model specified. Please provide using -mw \n"); + return 1; + } // llama init @@ -301,6 +305,10 @@ int main(int argc, char ** argv) { } struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams); + if (!model_llama) { + fprintf(stderr, "No llama.cpp model specified. Please provide using -ml \n"); + return 1; + } llama_context_params lcparams = llama_context_default_params(); From 7094ea5e750266e16c16c7aecac8fc03294ecaa3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 15 May 2024 09:38:19 +0300 Subject: [PATCH 084/100] whisper : use flash attention (#2152) * whisper : use flash attention in the encoder * whisper : add kv_pad * whisper : remove extra backend instance (huh?) * whisper : use FA for cross-attention * whisper : use FA for self-attention * whisper : simplify encoder FA * whisper : add flash_attn runtime parameter * scripts : add bench log * scripts : add M1 Pro bench log --- examples/bench/bench.cpp | 17 +- examples/command/command.cpp | 7 +- examples/lsp/lsp.cpp | 8 +- examples/main/main.cpp | 9 +- examples/server/server.cpp | 7 +- examples/stream/stream.cpp | 7 +- examples/talk-llama/talk-llama.cpp | 9 +- examples/talk/talk.cpp | 7 +- examples/wchess/wchess.cmd/wchess.cmd.cpp | 7 +- scripts/bench-all-gg.txt | 298 +++++++++++++++ scripts/bench-all.sh | 25 +- whisper.cpp | 429 ++++++++++++++-------- whisper.h | 1 + 13 files changed, 658 insertions(+), 173 deletions(-) create mode 100644 scripts/bench-all-gg.txt diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index b77621ac884..cac9385c82f 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -12,7 +12,8 @@ struct whisper_params { std::string model = "models/ggml-base.en.bin"; - bool use_gpu = true; + bool use_gpu = true; + bool flash_attn = false; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -25,10 +26,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -49,6 +51,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " %-7s 0 - whisper\n", ""); fprintf(stderr, " %-7s 1 - memcpy\n", ""); fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", ""); @@ -59,7 +62,9 @@ int whisper_bench_full(const whisper_params & params) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); diff --git a/examples/command/command.cpp b/examples/command/command.cpp index ec749d60247..cd6cc023994 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -44,6 +44,7 @@ struct whisper_params { bool print_energy = false; bool no_timestamps = true; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -80,6 +81,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } @@ -118,6 +120,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); @@ -696,7 +699,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); diff --git a/examples/lsp/lsp.cpp b/examples/lsp/lsp.cpp index e5f8360f83d..3df54266a25 100644 --- a/examples/lsp/lsp.cpp +++ b/examples/lsp/lsp.cpp @@ -31,6 +31,7 @@ struct whisper_params { bool print_special = false; bool print_energy = false; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -74,6 +75,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else { @@ -105,6 +107,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, "\n"); @@ -436,7 +439,10 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); // init audio diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d11c1c3f81b..45eb17fe7f3 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -70,6 +70,7 @@ struct whisper_params { bool no_timestamps = false; bool log_score = false; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string prompt; @@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } else if ( arg == "--grammar") { params.grammar = argv[++i]; } else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; } else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } @@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); @@ -977,7 +980,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { cparams.dtw_token_timestamps = true; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e3b96698228..c78b3026e18 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -75,6 +75,7 @@ struct whisper_params { bool print_progress = false; bool no_timestamps = false; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string prompt = ""; @@ -178,6 +179,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } // server params else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--host") { sparams.hostname = argv[++i]; } @@ -502,7 +504,10 @@ int main(int argc, char ** argv) { } // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; + if (!params.dtw.empty()) { cparams.dtw_token_timestamps = true; cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE; diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index b82e379dc61..60c1b0894e4 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -36,6 +36,7 @@ struct whisper_params { bool tinydiarize = false; bool save_audio = false; // save audio to wav file bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); @@ -109,6 +111,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false"); fprintf(stderr, "\n"); } @@ -153,7 +156,9 @@ int main(int argc, char ** argv) { } struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index 838d6f56357..4aab62b9a6f 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -66,6 +66,7 @@ struct whisper_params { bool no_timestamps = true; bool verbose_prompt = false; bool use_gpu = true; + bool flash_attn = false; std::string person = "Georgi"; std::string bot_name = "LLaMA"; @@ -105,6 +106,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; } else if (arg == "--session") { params.path_session = argv[++i]; } @@ -123,7 +125,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { } } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -154,6 +155,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str()); fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str()); fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str()); @@ -285,7 +287,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams); if (!ctx_wsp) { @@ -316,6 +320,7 @@ int main(int argc, char ** argv) { lcparams.n_ctx = 2048; lcparams.seed = 1; lcparams.n_threads = params.n_threads; + lcparams.flash_attn = params.flash_attn; struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams); diff --git a/examples/talk/talk.cpp b/examples/talk/talk.cpp index c1c6f8ba0b2..3e34e5724ff 100644 --- a/examples/talk/talk.cpp +++ b/examples/talk/talk.cpp @@ -32,6 +32,7 @@ struct whisper_params { bool print_energy = false; bool no_timestamps = true; bool use_gpu = true; + bool flash_attn = false; std::string person = "Santa"; std::string language = "en"; @@ -64,6 +65,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; } @@ -99,6 +101,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str()); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str()); @@ -188,7 +191,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams); diff --git a/examples/wchess/wchess.cmd/wchess.cmd.cpp b/examples/wchess/wchess.cmd/wchess.cmd.cpp index f66b1765f5b..09e53f13172 100644 --- a/examples/wchess/wchess.cmd/wchess.cmd.cpp +++ b/examples/wchess/wchess.cmd/wchess.cmd.cpp @@ -32,6 +32,7 @@ struct whisper_params { bool print_energy = false; bool no_timestamps = true; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -61,6 +62,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during decoding\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); @@ -92,6 +94,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } @@ -183,7 +186,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (!ctx) { diff --git a/scripts/bench-all-gg.txt b/scripts/bench-all-gg.txt new file mode 100644 index 00000000000..6fd5605a2bd --- /dev/null +++ b/scripts/bench-all-gg.txt @@ -0,0 +1,298 @@ +## M1 Pro + +make -j && ./scripts/bench-all.sh 8 + +Running memcpy benchmark + +memcpy: 39.10 GB/s (heat-up) +memcpy: 44.75 GB/s ( 1 thread) +memcpy: 44.78 GB/s ( 1 thread) +memcpy: 44.97 GB/s ( 2 thread) +memcpy: 48.04 GB/s ( 3 thread) +memcpy: 50.55 GB/s ( 4 thread) +memcpy: 55.20 GB/s ( 5 thread) +memcpy: 65.60 GB/s ( 6 thread) +memcpy: 70.64 GB/s ( 7 thread) +memcpy: 73.34 GB/s ( 8 thread) +sum: -5120002535.000000 + + +make -j && ./scripts/bench-all.sh 1 0 0 + +Running ggml_mul_mat benchmark with 1 threads + + 64 x 64: Q4_0 237.1 GFLOPS (128 runs) | Q4_1 168.6 GFLOPS (128 runs) + 64 x 64: Q5_0 136.4 GFLOPS (128 runs) | Q5_1 135.6 GFLOPS (128 runs) | Q8_0 243.1 GFLOPS (128 runs) + 64 x 64: F16 140.4 GFLOPS (128 runs) | F32 316.6 GFLOPS (128 runs) + 128 x 128: Q4_0 496.6 GFLOPS (128 runs) | Q4_1 348.6 GFLOPS (128 runs) + 128 x 128: Q5_0 273.2 GFLOPS (128 runs) | Q5_1 274.1 GFLOPS (128 runs) | Q8_0 505.1 GFLOPS (128 runs) + 128 x 128: F16 300.4 GFLOPS (128 runs) | F32 653.9 GFLOPS (128 runs) + 256 x 256: Q4_0 791.7 GFLOPS (128 runs) | Q4_1 615.3 GFLOPS (128 runs) + 256 x 256: Q5_0 651.0 GFLOPS (128 runs) | Q5_1 674.7 GFLOPS (128 runs) | Q8_0 803.1 GFLOPS (128 runs) + 256 x 256: F16 869.6 GFLOPS (128 runs) | F32 957.2 GFLOPS (128 runs) + 512 x 512: Q4_0 973.3 GFLOPS (128 runs) | Q4_1 897.9 GFLOPS (128 runs) + 512 x 512: Q5_0 1078.8 GFLOPS (128 runs) | Q5_1 998.4 GFLOPS (128 runs) | Q8_0 752.4 GFLOPS (128 runs) + 512 x 512: F16 892.5 GFLOPS (128 runs) | F32 1399.6 GFLOPS (128 runs) +1024 x 1024: Q4_0 1402.7 GFLOPS (128 runs) | Q4_1 1218.5 GFLOPS (128 runs) +1024 x 1024: Q5_0 1444.8 GFLOPS (128 runs) | Q5_1 1444.7 GFLOPS (128 runs) | Q8_0 1395.7 GFLOPS (128 runs) +1024 x 1024: F16 1524.1 GFLOPS (128 runs) | F32 1726.6 GFLOPS (128 runs) +2048 x 2048: Q4_0 1479.4 GFLOPS ( 87 runs) | Q4_1 1378.5 GFLOPS ( 81 runs) +2048 x 2048: Q5_0 1454.6 GFLOPS ( 85 runs) | Q5_1 1462.9 GFLOPS ( 86 runs) | Q8_0 1483.2 GFLOPS ( 87 runs) +2048 x 2048: F16 1488.0 GFLOPS ( 87 runs) | F32 1538.2 GFLOPS ( 90 runs) +4096 x 4096: Q4_0 1509.7 GFLOPS ( 11 runs) | Q4_1 1433.0 GFLOPS ( 11 runs) +4096 x 4096: Q5_0 1422.4 GFLOPS ( 11 runs) | Q5_1 1437.0 GFLOPS ( 11 runs) | Q8_0 1523.0 GFLOPS ( 12 runs) +4096 x 4096: F16 1551.3 GFLOPS ( 12 runs) | F32 1451.0 GFLOPS ( 11 runs) + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M1 Pro | METAL | tiny | 1 | 0 | 39.21 | 1.74 | 0.61 | 0.04 | 22c96b4 | +| M1 Pro | METAL | base | 1 | 0 | 70.76 | 2.60 | 0.93 | 0.06 | 22c96b4 | +| M1 Pro | METAL | small | 1 | 0 | 217.28 | 6.42 | 2.14 | 0.17 | 22c96b4 | +| M1 Pro | METAL | medium | 1 | 0 | 596.74 | 14.43 | 4.75 | 0.45 | 22c96b4 | + + +make -j && ./scripts/bench-all.sh 1 1 1 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M1 Pro | METAL | tiny | 1 | 1 | 30.77 | 1.59 | 0.54 | 0.03 | 22c96b4 | +| M1 Pro | METAL | base | 1 | 1 | 60.42 | 2.29 | 0.81 | 0.05 | 22c96b4 | +| M1 Pro | METAL | small | 1 | 1 | 183.82 | 5.12 | 1.81 | 0.14 | 22c96b4 | +| M1 Pro | METAL | medium | 1 | 1 | 517.92 | 11.60 | 4.01 | 0.38 | 22c96b4 | + + +## M2 Ultra + +make -j && ./scripts/bench-all.sh 8 + +Running memcpy benchmark + +memcpy: 46.58 GB/s (heat-up) +memcpy: 54.16 GB/s ( 1 thread) +memcpy: 54.23 GB/s ( 1 thread) +memcpy: 99.63 GB/s ( 2 thread) +memcpy: 140.59 GB/s ( 3 thread) +memcpy: 176.52 GB/s ( 4 thread) +memcpy: 158.90 GB/s ( 5 thread) +memcpy: 163.00 GB/s ( 6 thread) +memcpy: 189.69 GB/s ( 7 thread) +memcpy: 197.15 GB/s ( 8 thread) +sum: -5120002007.000000 + + +make -j && ./scripts/bench-all.sh 1 + +Running ggml_mul_mat benchmark with 1 threads + + 64 x 64: Q4_0 245.8 GFLOPS (128 runs) | Q4_1 168.6 GFLOPS (128 runs) + 64 x 64: Q5_0 115.7 GFLOPS (128 runs) | Q5_1 125.9 GFLOPS (128 runs) | Q8_0 215.8 GFLOPS (128 runs) + 64 x 64: F16 139.5 GFLOPS (128 runs) | F32 337.2 GFLOPS (128 runs) + 128 x 128: Q4_0 494.8 GFLOPS (128 runs) | Q4_1 350.4 GFLOPS (128 runs) + 128 x 128: Q5_0 257.1 GFLOPS (128 runs) | Q5_1 261.4 GFLOPS (128 runs) | Q8_0 509.4 GFLOPS (128 runs) + 128 x 128: F16 302.3 GFLOPS (128 runs) | F32 672.8 GFLOPS (128 runs) + 256 x 256: Q4_0 795.7 GFLOPS (128 runs) | Q4_1 663.7 GFLOPS (128 runs) + 256 x 256: Q5_0 737.8 GFLOPS (128 runs) | Q5_1 757.6 GFLOPS (128 runs) | Q8_0 827.7 GFLOPS (128 runs) + 256 x 256: F16 872.6 GFLOPS (128 runs) | F32 956.3 GFLOPS (128 runs) + 512 x 512: Q4_0 1188.0 GFLOPS (128 runs) | Q4_1 1085.0 GFLOPS (128 runs) + 512 x 512: Q5_0 1421.1 GFLOPS (128 runs) | Q5_1 1454.9 GFLOPS (128 runs) | Q8_0 1191.4 GFLOPS (128 runs) + 512 x 512: F16 1577.4 GFLOPS (128 runs) | F32 1982.0 GFLOPS (128 runs) +1024 x 1024: Q4_0 2342.6 GFLOPS (128 runs) | Q4_1 1955.8 GFLOPS (128 runs) +1024 x 1024: Q5_0 2306.7 GFLOPS (128 runs) | Q5_1 2217.0 GFLOPS (128 runs) | Q8_0 2230.7 GFLOPS (128 runs) +1024 x 1024: F16 2593.8 GFLOPS (128 runs) | F32 3269.0 GFLOPS (128 runs) +2048 x 2048: Q4_0 3735.7 GFLOPS (128 runs) | Q4_1 3205.3 GFLOPS (128 runs) +2048 x 2048: Q5_0 3584.5 GFLOPS (128 runs) | Q5_1 3621.7 GFLOPS (128 runs) | Q8_0 3622.3 GFLOPS (128 runs) +2048 x 2048: F16 3763.6 GFLOPS (128 runs) | F32 4153.3 GFLOPS (128 runs) +4096 x 4096: Q4_0 3891.1 GFLOPS ( 29 runs) | Q4_1 3554.0 GFLOPS ( 26 runs) +4096 x 4096: Q5_0 3753.1 GFLOPS ( 28 runs) | Q5_1 3750.1 GFLOPS ( 28 runs) | Q8_0 3768.5 GFLOPS ( 28 runs) +4096 x 4096: F16 3864.2 GFLOPS ( 29 runs) | F32 3970.5 GFLOPS ( 29 runs) + + +make -j && ./scripts/bench-all.sh 1 1 0 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M2 ULTRA | METAL | tiny | 1 | 0 | 12.32 | 1.35 | 0.49 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 11.65 | 1.30 | 0.51 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 12.08 | 1.30 | 0.51 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | base | 1 | 0 | 17.58 | 1.90 | 0.76 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 18.89 | 1.86 | 0.79 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 20.69 | 1.88 | 0.79 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | small | 1 | 0 | 49.32 | 3.85 | 1.71 | 0.05 | 22c96b4 | +| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 54.91 | 3.81 | 1.82 | 0.06 | 22c96b4 | +| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 54.92 | 3.81 | 1.79 | 0.06 | 22c96b4 | +| M2 ULTRA | METAL | medium | 1 | 0 | 134.34 | 8.04 | 3.82 | 0.13 | 22c96b4 | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 151.68 | 7.59 | 4.07 | 0.14 | 22c96b4 | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 151.58 | 7.67 | 4.07 | 0.14 | 22c96b4 | +| M2 ULTRA | METAL | medium-dis | 1 | 0 | 120.82 | 1.07 | 0.41 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | large-v2 | 1 | 0 | 235.63 | 12.27 | 5.85 | 0.22 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 273.38 | 11.17 | 6.40 | 0.26 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 272.44 | 11.32 | 6.29 | 0.26 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 212.51 | 1.20 | 0.47 | 0.02 | 22c96b4 | + + +make -j && ./scripts/bench-all.sh 1 1 1 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M2 ULTRA | METAL | tiny | 1 | 1 | 9.07 | 1.33 | 0.45 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 9.74 | 1.33 | 0.47 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 8.93 | 1.31 | 0.46 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | base | 1 | 1 | 15.75 | 1.87 | 0.71 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 17.04 | 1.83 | 0.74 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 17.17 | 1.83 | 0.74 | 0.02 | 22c96b4 | +| M2 ULTRA | METAL | small | 1 | 1 | 42.33 | 3.64 | 1.60 | 0.05 | 22c96b4 | +| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 47.61 | 3.63 | 1.70 | 0.05 | 22c96b4 | +| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 47.70 | 3.66 | 1.68 | 0.05 | 22c96b4 | +| M2 ULTRA | METAL | medium | 1 | 1 | 114.42 | 7.53 | 3.55 | 0.11 | 22c96b4 | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 132.63 | 7.02 | 3.77 | 0.13 | 22c96b4 | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 132.28 | 7.10 | 3.76 | 0.13 | 22c96b4 | +| M2 ULTRA | METAL | medium-dis | 1 | 1 | 102.34 | 1.01 | 0.42 | 0.01 | 22c96b4 | +| M2 ULTRA | METAL | large-v2 | 1 | 1 | 203.01 | 11.03 | 5.45 | 0.20 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 240.05 | 10.18 | 5.98 | 0.23 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 239.22 | 10.23 | 5.87 | 0.23 | 22c96b4 | +| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 181.14 | 1.14 | 0.48 | 0.02 | 22c96b4 | + + + +## Ryzen 9 5950X + RTX 2060 + +make -j && ./scripts/bench-all.sh 8 0 0 + +Running memcpy benchmark + +memcpy: 12.36 GB/s (heat-up) +memcpy: 12.33 GB/s ( 1 thread) +memcpy: 12.38 GB/s ( 1 thread) +memcpy: 14.48 GB/s ( 2 thread) +memcpy: 15.00 GB/s ( 3 thread) +memcpy: 14.77 GB/s ( 4 thread) +memcpy: 14.60 GB/s ( 5 thread) +memcpy: 14.57 GB/s ( 6 thread) +memcpy: 14.34 GB/s ( 7 thread) +memcpy: 14.40 GB/s ( 8 thread) +sum: -5119998076.000000 + +Running ggml_mul_mat benchmark with 8 threads + + 64 x 64: Q4_0 3.1 GFLOPS (128 runs) | Q4_1 3.1 GFLOPS (128 runs) + 64 x 64: Q5_0 3.0 GFLOPS (128 runs) | Q5_1 2.9 GFLOPS (128 runs) | Q8_0 3.1 GFLOPS (128 runs) + 64 x 64: F16 3.0 GFLOPS (128 runs) | F32 3.0 GFLOPS (128 runs) + 128 x 128: Q4_0 21.1 GFLOPS (128 runs) | Q4_1 20.3 GFLOPS (128 runs) + 128 x 128: Q5_0 20.6 GFLOPS (128 runs) | Q5_1 20.4 GFLOPS (128 runs) | Q8_0 22.1 GFLOPS (128 runs) + 128 x 128: F16 21.7 GFLOPS (128 runs) | F32 21.7 GFLOPS (128 runs) + 256 x 256: Q4_0 105.7 GFLOPS (128 runs) | Q4_1 94.4 GFLOPS (128 runs) + 256 x 256: Q5_0 94.8 GFLOPS (128 runs) | Q5_1 87.5 GFLOPS (128 runs) | Q8_0 107.2 GFLOPS (128 runs) + 256 x 256: F16 95.1 GFLOPS (128 runs) | F32 94.3 GFLOPS (128 runs) + 512 x 512: Q4_0 214.7 GFLOPS (128 runs) | Q4_1 189.8 GFLOPS (128 runs) + 512 x 512: Q5_0 187.7 GFLOPS (128 runs) | Q5_1 176.2 GFLOPS (128 runs) | Q8_0 252.2 GFLOPS (128 runs) + 512 x 512: F16 220.8 GFLOPS (128 runs) | F32 218.3 GFLOPS (128 runs) +1024 x 1024: Q4_0 333.7 GFLOPS (128 runs) | Q4_1 305.8 GFLOPS (128 runs) +1024 x 1024: Q5_0 283.2 GFLOPS (128 runs) | Q5_1 268.2 GFLOPS (125 runs) | Q8_0 394.1 GFLOPS (128 runs) +1024 x 1024: F16 355.0 GFLOPS (128 runs) | F32 313.0 GFLOPS (128 runs) +2048 x 2048: Q4_0 395.0 GFLOPS ( 23 runs) | Q4_1 380.6 GFLOPS ( 23 runs) +2048 x 2048: Q5_0 336.6 GFLOPS ( 20 runs) | Q5_1 318.4 GFLOPS ( 19 runs) | Q8_0 482.6 GFLOPS ( 29 runs) +2048 x 2048: F16 424.5 GFLOPS ( 25 runs) | F32 337.7 GFLOPS ( 20 runs) +4096 x 4096: Q4_0 412.8 GFLOPS ( 4 runs) | Q4_1 405.1 GFLOPS ( 3 runs) +4096 x 4096: Q5_0 346.0 GFLOPS ( 3 runs) | Q5_1 334.6 GFLOPS ( 3 runs) | Q8_0 502.6 GFLOPS ( 4 runs) +4096 x 4096: F16 412.5 GFLOPS ( 4 runs) | F32 274.0 GFLOPS ( 3 runs) + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| Ryzen 9 5950X | AVX2 | tiny | 8 | 0 | 195.29 | 1.57 | 0.51 | 0.26 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | tiny-q5_0 | 8 | 0 | 213.33 | 1.10 | 0.50 | 0.30 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | tiny-q5_1 | 8 | 0 | 219.38 | 1.18 | 0.53 | 0.32 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | base | 8 | 0 | 424.85 | 3.71 | 1.03 | 0.46 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | base-q5_0 | 8 | 0 | 473.61 | 1.81 | 0.82 | 0.52 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | base-q5_1 | 8 | 0 | 484.14 | 1.92 | 0.85 | 0.56 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | small | 8 | 0 | 1458.32 | 12.66 | 3.09 | 1.26 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | small-q5_0 | 8 | 0 | 1673.22 | 6.42 | 2.18 | 1.45 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | small-q5_1 | 8 | 0 | 1724.78 | 6.72 | 2.32 | 1.52 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | medium | 8 | 0 | 4333.87 | 36.80 | 8.56 | 3.37 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | medium-q5_0 | 8 | 0 | 5194.09 | 19.21 | 5.71 | 3.97 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | medium-q5_1 | 8 | 0 | 5450.39 | 20.01 | 5.99 | 4.17 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | medium-dis | 8 | 0 | 3995.19 | 5.08 | 1.21 | 0.55 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | large-v2 | 8 | 0 | 8056.16 | 69.74 | 16.11 | 6.13 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | large-v2-q5_0 | 8 | 0 | 9799.58 | 35.16 | 10.49 | 7.28 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | large-v2-q5_1 | 8 | 0 | ms | 36.74 | 11.02 | 7.65 | 22c96b4 | +| Ryzen 9 5950X | AVX2 | large-v2-dis | 8 | 0 | 7490.03 | 7.40 | 1.70 | 0.72 | 22c96b4 | + + +WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0 + +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| RTX 2060 | AVX2 CUDA | tiny | 8 | 0 | 12.54 | 0.93 | 0.29 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | tiny-q5_0 | 8 | 0 | 12.73 | 0.98 | 0.24 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | tiny-q5_1 | 8 | 0 | 12.72 | 0.99 | 0.24 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base | 8 | 0 | 24.14 | 1.28 | 0.41 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base-q5_0 | 8 | 0 | 24.58 | 1.38 | 0.35 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base-q5_1 | 8 | 0 | 24.58 | 1.37 | 0.35 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small | 8 | 0 | 74.70 | 2.91 | 0.84 | 0.07 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small-q5_0 | 8 | 0 | 76.12 | 2.84 | 0.77 | 0.08 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small-q5_1 | 8 | 0 | 76.14 | 2.84 | 0.76 | 0.08 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium | 8 | 0 | 200.69 | 6.46 | 1.83 | 0.17 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-q5_0 | 8 | 0 | 204.80 | 5.90 | 1.65 | 0.19 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-q5_1 | 8 | 0 | 205.61 | 5.85 | 1.61 | 0.19 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-dis | 8 | 0 | 186.17 | 0.86 | 0.24 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2 | 8 | 0 | 347.22 | 10.36 | 2.82 | 0.29 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-q5_0 | 8 | 0 | 357.06 | 8.81 | 2.58 | 0.34 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-q5_1 | 8 | 0 | 356.97 | 8.62 | 2.49 | 0.33 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-dis | 8 | 0 | 318.05 | 1.03 | 0.34 | 0.04 | 22c96b4 | + + +WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1 + +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| RTX 2060 | AVX2 CUDA | tiny | 8 | 1 | 7.21 | 0.76 | 0.29 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | tiny-q5_0 | 8 | 1 | 7.42 | 0.82 | 0.18 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | tiny-q5_1 | 8 | 1 | 7.38 | 0.82 | 0.18 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base | 8 | 1 | 13.49 | 1.04 | 0.36 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base-q5_0 | 8 | 1 | 13.94 | 1.13 | 0.26 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | base-q5_1 | 8 | 1 | 13.94 | 1.14 | 0.26 | 0.03 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small | 8 | 1 | 42.81 | 2.33 | 0.69 | 0.05 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small-q5_0 | 8 | 1 | 44.43 | 2.25 | 0.59 | 0.06 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | small-q5_1 | 8 | 1 | 44.11 | 2.24 | 0.58 | 0.06 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium | 8 | 1 | 115.47 | 5.17 | 1.45 | 0.11 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-q5_0 | 8 | 1 | 120.37 | 4.63 | 1.25 | 0.13 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-q5_1 | 8 | 1 | 120.28 | 4.55 | 1.21 | 0.13 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | medium-dis | 8 | 1 | 101.69 | 0.75 | 0.20 | 0.02 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2 | 8 | 1 | 205.67 | 8.49 | 2.19 | 0.18 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-q5_0 | 8 | 1 | 214.07 | 6.88 | 1.94 | 0.22 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-q5_1 | 8 | 1 | 213.98 | 6.70 | 1.86 | 0.22 | 22c96b4 | +| RTX 2060 | AVX2 CUDA | large-v2-dis | 8 | 1 | 176.71 | 0.91 | 0.31 | 0.03 | 22c96b4 | + + + + +# V100 + +WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0 + +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| V100 | AVX2 CUDA | tiny | 1 | 0 | 6.21 | 1.11 | 0.30 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | tiny-q5_1 | 1 | 0 | 5.97 | 1.10 | 0.26 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | base | 1 | 0 | 10.95 | 1.47 | 0.42 | 0.03 | 22c96b4 | +| V100 | AVX2 CUDA | base-q5_1 | 1 | 0 | 11.13 | 1.53 | 0.36 | 0.03 | 22c96b4 | +| V100 | AVX2 CUDA | small | 1 | 0 | 31.57 | 2.96 | 0.84 | 0.05 | 22c96b4 | +| V100 | AVX2 CUDA | small-q5_1 | 1 | 0 | 32.19 | 3.14 | 0.75 | 0.05 | 22c96b4 | +| V100 | AVX2 CUDA | medium | 1 | 0 | 85.88 | 6.49 | 1.80 | 0.10 | 22c96b4 | +| V100 | AVX2 CUDA | medium-q5_0 | 1 | 0 | 87.53 | 5.82 | 1.37 | 0.10 | 22c96b4 | +| V100 | AVX2 CUDA | large-v2 | 1 | 0 | 142.23 | 8.92 | 2.62 | 0.15 | 22c96b4 | + + +WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1 + +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| V100 | AVX2 CUDA | tiny | 1 | 1 | 3.96 | 0.82 | 0.24 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | tiny-q5_1 | 1 | 1 | 4.05 | 0.85 | 0.18 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | base | 1 | 1 | 7.21 | 1.16 | 0.36 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | base-q5_1 | 1 | 1 | 7.39 | 1.21 | 0.26 | 0.02 | 22c96b4 | +| V100 | AVX2 CUDA | small | 1 | 1 | 19.81 | 2.41 | 0.71 | 0.04 | 22c96b4 | +| V100 | AVX2 CUDA | small-q5_1 | 1 | 1 | 20.50 | 2.31 | 0.51 | 0.04 | 22c96b4 | +| V100 | AVX2 CUDA | medium | 1 | 1 | 56.02 | 4.89 | 1.44 | 0.07 | 22c96b4 | +| V100 | AVX2 CUDA | medium-q5_0 | 1 | 1 | 57.85 | 4.73 | 1.09 | 0.08 | 22c96b4 | +| V100 | AVX2 CUDA | large-v2 | 1 | 1 | 92.73 | 7.18 | 2.14 | 0.10 | 22c96b4 | + diff --git a/scripts/bench-all.sh b/scripts/bench-all.sh index 6939dafaca0..8a857c67b6c 100755 --- a/scripts/bench-all.sh +++ b/scripts/bench-all.sh @@ -2,7 +2,7 @@ # Helper script to run the bench tool on all models and print the results in share-able format -printf "Usage: ./bench.sh [n_threads] [encoder-only]\n" +printf "Usage: ./bench.sh [n_threads] [encoder-only] [flash-attn]\n" if [ -z "$1" ]; then n_threads=4 @@ -11,12 +11,19 @@ else fi encoder_only=0 -if [ -z "$2" ]; then +if [ -z "$2" ] || [ "$2" -eq 0 ]; then encoder_only=0 else encoder_only=$2 fi +fattn="" +if [ -z "$3" ] || [ "$3" -eq 0 ]; then + fattn="" +else + fattn="-fa" +fi + models=( \ "tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \ "base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \ @@ -44,13 +51,19 @@ if [ "$encoder_only" -eq 0 ]; then printf "\n" fi -printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit" -printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" +if [ "$fattn" == "-fa" ]; then + fattn_i=1 +else + fattn_i=0 +fi + +printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "FA" "Enc." "Dec." "Bch5" "PP" "Commit" +printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" for model in "${models[@]}"; do # actual run # store stderr output in a variable in order to parse it later - output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1) + output=$(./bench -m ./models/ggml-$model.bin -t $n_threads $fattn 2>&1) ret=$? # parse the output: @@ -95,6 +108,6 @@ for model in "${models[@]}"; do commit=$(git rev-parse --short HEAD) if [ $ret -eq 0 ]; then - printf "| | | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit" + printf "| | | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$fattn_i" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit" fi done diff --git a/whisper.cpp b/whisper.cpp index ff4223daf42..84aec8238cd 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -809,14 +809,15 @@ struct whisper_state { // shared between all decoders whisper_kv_cache kv_cross; + // padded buffer for flash-attention + whisper_kv_cache kv_pad; + whisper_mel mel; whisper_batch batch; whisper_decoder decoders[WHISPER_MAX_DECODERS]; - ggml_backend_t backend = nullptr; - // ggml-alloc: // - stores meta info about the intermediate tensors into the `meta` buffers // - stores the actual tensor data into the `data` buffers @@ -902,14 +903,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) { } static bool kv_cache_init( - const struct whisper_hparams & hparams, struct whisper_kv_cache & cache, ggml_backend_t backend, ggml_type wtype, + int64_t n_text_state, + int64_t n_text_layer, int n_ctx) { - const int64_t n_text_state = hparams.n_text_state; - const int64_t n_text_layer = hparams.n_text_layer; - const int64_t n_mem = n_text_layer*n_ctx; const int64_t n_elements = n_text_state*n_mem; @@ -941,6 +940,8 @@ static bool kv_cache_init( return false; } + ggml_backend_buffer_clear(cache.buffer, 0); + return true; } @@ -1068,6 +1069,26 @@ static void whisper_kv_cache_seq_cp( } } +static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) { + if (!wctx.params.flash_attn) { + return 1u; + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(wctx.backend)) { + return 32u; + } +#endif + +#ifdef GGML_USE_CUDA + if (ggml_backend_is_cuda(wctx.backend)) { + return 256u; + } +#endif + + return 1u; +} + // [EXPERIMENTAL] Token-level timestamps with DTW static bool aheads_masks_init( const whisper_context_params & cparams, @@ -1872,6 +1893,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder( const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; + const int n_state_head = n_state/n_head; + + auto & kv_pad = wstate.kv_pad; + + WHISPER_ASSERT(!!kv_pad.ctx); + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + struct ggml_init_params params = { /*.mem_size =*/ wstate.alloc_encode.meta.size(), /*.mem_buffer =*/ wstate.alloc_encode.meta.data(), @@ -1884,7 +1913,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv); - const float KQscale = 1.0f/sqrtf(float(n_state)/n_head); + const float KQscale = 1.0f/sqrtf(float(n_state_head)); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ -1934,14 +1963,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder( Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b); - //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25)); + //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25)); // note: no bias for Key struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); - //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25)); + //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25)); struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.attn_v_w, @@ -1955,38 +1984,61 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ggml_permute(ctx0, ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)), 0, 2, 1, 3); - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + if (wctx.params.flash_attn) { + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0))); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_pad.k, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.k)*n_state, + ggml_element_size(kv_pad.k)*n_state_head, + 0); - struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head) - ); + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_pad.v, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.v)*n_state, + ggml_element_size(kv_pad.v)*n_state_head, + 0); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx); + } else { + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + } } // projection @@ -2085,6 +2137,10 @@ static struct ggml_cgraph * whisper_build_graph_cross( const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; + const int n_state_head = n_state/n_head; + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + struct ggml_init_params params = { /*.mem_size =*/ wstate.alloc_cross.meta.size(), /*.mem_buffer =*/ wstate.alloc_cross.meta.data(), @@ -2097,18 +2153,18 @@ static struct ggml_cgraph * whisper_build_graph_cross( struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc); - const float Kscale = pow(float(n_state) / n_head, -0.25); + const float Kscale = pow(float(n_state_head), -0.25); for (int il = 0; il < model.hparams.n_text_layer; ++il) { auto & layer = model.layers_decoder[il]; - struct ggml_tensor* Kcross = ggml_mul_mat(ctx0, + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, layer.cross_attn_k_w, cur); Kcross = ggml_scale(ctx0, Kcross, Kscale); - struct ggml_tensor* Vcross = ggml_mul_mat(ctx0, + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, layer.cross_attn_v_w, cur); @@ -2116,15 +2172,25 @@ static struct ggml_cgraph * whisper_build_graph_cross( Vcross, layer.cross_attn_v_b); - Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + struct ggml_tensor * k; + struct ggml_tensor * v; - struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, - n_state*n_ctx, - (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + if (wctx.params.flash_attn) { + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad)); - struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, - ( n_ctx)*ggml_element_size(wstate.kv_cross.v), - (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad)); + } else { + Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + + v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, + ( n_ctx)*ggml_element_size(wstate.kv_cross.v), + (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v)); @@ -2195,7 +2261,7 @@ static bool whisper_encode_internal( } if (!whisper_encode_external(wstate)) { - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } else { @@ -2218,7 +2284,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } @@ -2234,7 +2300,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } @@ -2263,11 +2329,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder( const int n_head = hparams.n_text_head; const int n_layer = hparams.n_text_layer; + const int n_state_head = n_state/n_head; + const int n_tokens = batch.n_tokens; const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; + const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256); + + const int32_t n_kv = worst_case ? n_ctx : kv_self.n; + const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); @@ -2289,12 +2359,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_set_name(position, "position"); ggml_set_input(position); - const float KQscale = pow(float(n_state)/n_head, -0.25); + const float KQscale = pow(float(n_state_head), -0.25); - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1); ggml_set_name(KQ_mask, "KQ_mask"); ggml_set_input(KQ_mask); + struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); + // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -2350,12 +2422,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder( Vcur, layer.attn_v_b); - Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); + struct ggml_tensor * k; + struct ggml_tensor * v; + + if (wctx.params.flash_attn) { + k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, + (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + + v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state, + (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head)); + } else { + Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v)); + k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, + (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + + v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v)); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); @@ -2365,35 +2450,48 @@ static struct ggml_cgraph * whisper_build_graph_decoder( struct ggml_tensor * Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens), + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), 0, 2, 1, 3); struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_state/n_head, n_kv, n_head, + n_state_head, n_kv, n_head, ggml_element_size(kv_self.k)*n_state, - ggml_element_size(kv_self.k)*n_state/n_head, + ggml_element_size(kv_self.k)*n_state_head, ggml_element_size(kv_self.k)*n_state*n_ctx*il); - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + if (wctx.params.flash_attn) { + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_state_head, n_kv, n_head, + ggml_element_size(kv_self.v)*n_state, + ggml_element_size(kv_self.v)*n_state_head, + ggml_element_size(kv_self.v)*n_state*n_ctx*il); + + cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } else { + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_state/n_head, n_head, - n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_state/n_head, - n_ctx*ggml_element_size(kv_self.v)*n_state*il); + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_state_head, n_head, + n_ctx*ggml_element_size(kv_self.v), + n_ctx*ggml_element_size(kv_self.v)*n_state_head, + n_ctx*ggml_element_size(kv_self.v)*n_state*il); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } } // projection @@ -2432,80 +2530,77 @@ static struct ggml_cgraph * whisper_build_graph_decoder( Qcur, layer.cross_attn_q_b); - Qcur = ggml_scale(ctx0, Qcur, KQscale); - - // Kcross is already scaled - struct ggml_tensor * Kcross = - ggml_view_3d(ctx0, wstate.kv_cross.k, - n_state/n_head, n_audio_ctx, n_head, - ggml_element_size(wstate.kv_cross.k)*n_state, - ggml_element_size(wstate.kv_cross.k)*n_state/n_head, - ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il); - - //struct ggml_tensor * Vcross = - // ggml_reshape_3d(ctx0, - // ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state), - // n_state/n_head, n_head, n_audio_ctx); - - //struct ggml_tensor * V_trans = - // ggml_cpy(ctx0, - // ggml_permute(ctx0, Vcross, 1, 2, 0, 3), - // ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head)); - - struct ggml_tensor * V = - ggml_view_3d(ctx0, wstate.kv_cross.v, - n_audio_ctx, n_state/n_head, n_head, - n_audio_ctx*ggml_element_size(wstate.kv_cross.v), - n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head, - n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il); - - // ------ - struct ggml_tensor * Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens), + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), 0, 2, 1, 3); - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); - - //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctx0, - // KQ, - // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) - // ); + if (wctx.params.flash_attn) { + struct ggml_tensor * Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.k)*n_state, + ggml_element_size(wstate.kv_cross.k)*n_state_head, + ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il); - // no masking for cross-attention - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + struct ggml_tensor * Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.v)*n_state, + ggml_element_size(wstate.kv_cross.v)*n_state_head, + ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); + cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f); - // [EXPERIMENTAL] Token-level timestamps with DTW - if (wctx.params.dtw_token_timestamps) { - if (wstate.aheads_masks.m[il] != nullptr) { - struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]); - aheads_KQs = ggml_transpose(ctx0, aheads_KQs); - aheads_KQs = ggml_cont(ctx0, aheads_KQs); - aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs); - aheads_KQs = ggml_transpose(ctx0, aheads_KQs); - aheads_KQs = ggml_cont(ctx0, aheads_KQs); - aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]); - if (aheads_cross_QKs == NULL) { - aheads_cross_QKs = aheads_KQs; - } else { - aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } else { + struct ggml_tensor * Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx, n_head, + ggml_element_size(wstate.kv_cross.k)*n_state, + ggml_element_size(wstate.kv_cross.k)*n_state_head, + ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il); + + struct ggml_tensor * Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_audio_ctx, n_state_head, n_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v), + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il); + + // ------ + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + // [EXPERIMENTAL] Token-level timestamps with DTW + if (wctx.params.dtw_token_timestamps) { + if (wstate.aheads_masks.m[il] != nullptr) { + struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]); + aheads_KQs = ggml_transpose(ctx0, aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs); + aheads_KQs = ggml_transpose(ctx0, aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]); + if (aheads_cross_QKs == NULL) { + aheads_cross_QKs = aheads_KQs; + } else { + aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs); + } } } - } - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - // cur = KQV_merged.contiguous().view(n_state, n_tokens) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } } // projection @@ -2638,7 +2733,9 @@ static bool whisper_decode_internal( return false; } - kv_self.n = whisper_kv_cache_cell_max(kv_self); + const uint32_t pad = whisper_kv_cache_get_padding(wctx); + kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad))); + //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]); } @@ -2672,9 +2769,10 @@ static bool whisper_decode_internal( struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask"); auto & kv_self = wstate.kv_self; - const int32_t n_kv = kv_self.n; - wstate.inp_mask.resize(n_kv*n_tokens); + const int32_t n_kv = kv_self.n; + + wstate.inp_mask.resize(ggml_nelements(KQ_mask)); float * data = wstate.inp_mask.data(); memset(data, 0, ggml_nbytes(KQ_mask)); @@ -2690,6 +2788,12 @@ static bool whisper_decode_internal( } } } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } } ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float)); @@ -2697,7 +2801,7 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } @@ -3144,18 +3248,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { whisper_state * state = new whisper_state; - state->backend = whisper_backend_init(ctx->params); - if (!state->backend) { - WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__); - whisper_free_state(state); - return nullptr; - } - // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx // in theory, there can be a case where this is not enough, but in practice it should always be enough const int factor = 3; - if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) { + if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); whisper_free_state(state); return nullptr; @@ -3166,7 +3266,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6); } - if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) { + if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); whisper_free_state(state); return nullptr; @@ -3177,6 +3280,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); } + if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype, + ctx->model.hparams.n_audio_state, + 1, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); + whisper_free_state(state); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v); + WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6); + } + // [EXPERIMENTAL] Token-level timestamps with DTW if (ctx->params.dtw_token_timestamps) { if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) { @@ -3347,6 +3464,7 @@ int whisper_ctx_init_openvino_encoder( struct whisper_context_params whisper_context_default_params() { struct whisper_context_params result = { /*.use_gpu =*/ true, + /*.flash_attn =*/ false, /*.gpu_device =*/ 0, /*.dtw_token_timestamps =*/ false, @@ -3445,6 +3563,16 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) { ggml_time_init(); + if (params.flash_attn && params.dtw_token_timestamps) { + WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__); + params.dtw_token_timestamps = false; + } + + WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); + WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn); + WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); + WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps); + whisper_context * ctx = new whisper_context; ctx->params = params; @@ -3533,6 +3661,7 @@ void whisper_free_state(struct whisper_state * state) { if (state) { kv_cache_free(state->kv_self); kv_cache_free(state->kv_cross); + kv_cache_free(state->kv_pad); #ifdef WHISPER_USE_COREML if (state->ctx_coreml != nullptr) { @@ -3555,8 +3684,6 @@ void whisper_free_state(struct whisper_state * state) { ggml_gallocr_free(state->alloc_cross.alloc); ggml_gallocr_free(state->alloc_decode.alloc); - ggml_backend_free(state->backend); - // [EXPERIMENTAL] Token-level timestamps with DTW aheads_masks_free(state->aheads_masks); diff --git a/whisper.h b/whisper.h index 6a875d3bbb9..9c7c58d874b 100644 --- a/whisper.h +++ b/whisper.h @@ -113,6 +113,7 @@ extern "C" { struct whisper_context_params { bool use_gpu; + bool flash_attn; int gpu_device; // CUDA device // [EXPERIMENTAL] Token-level timestamps with DTW From 08981d1bacbe494ff1c943af6c577c669a2d9f4d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 15 May 2024 09:59:48 +0300 Subject: [PATCH 085/100] release : v1.6.0 --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/ios | 2 +- bindings/javascript/package.json | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cdffbcaa1c0..588aa61cd11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required (VERSION 3.5) # Allow for the creation of solution folders. set_property(GLOBAL PROPERTY USE_FOLDERS ON) -project(whisper.cpp VERSION 1.5.5) +project(whisper.cpp VERSION 1.6.0) set(SOVERSION 1) # Add path to modules diff --git a/README.md b/README.md index 33570ef02bc..0c34e8dffce 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.5.5](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.5) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) +Stable: [v1.6.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/ios b/bindings/ios index 0c6cfa58a2c..5cfcfb0801b 160000 --- a/bindings/ios +++ b/bindings/ios @@ -1 +1 @@ -Subproject commit 0c6cfa58a2c7384f567a5680459a0deb79224881 +Subproject commit 5cfcfb0801be756d8347822b472e4b5e343f403f diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index f64d975663e..354d0ce903c 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.5.5", + "version": "1.6.0", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 4798be1f9a8e9bb4aaf05884e852902274235fdc Mon Sep 17 00:00:00 2001 From: Tamotsu Takahashi Date: Sun, 19 May 2024 17:49:26 +0900 Subject: [PATCH 086/100] ci: Update build.yml to suppress warnings about node.js versions (#2166) * Update actions to suppress warnings about old node.js https://github.blog/changelog/2023-09-22-github-actions-transitioning-from-node-16-to-node-20/ * Update actions/upload-artifact, specify android cmdline-tools-version * Use java 20 gradle 8.1 complains against 21 https://docs.gradle.org/current/userguide/compatibility.html --- .github/workflows/build.yml | 92 ++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2d75fc31466..e9bf9c28292 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,10 +15,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Build ${{ matrix.arch }} run: | @@ -36,7 +36,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Dependencies run: | @@ -53,10 +53,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Build - uses: cross-platform-actions/action@v0.15.0 + uses: cross-platform-actions/action@v0.24.0 with: operating_system: freebsd version: '13.2' @@ -77,10 +77,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Build ${{ matrix.arch }} run: | @@ -105,10 +105,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Build ${{ matrix.arch }} run: | @@ -133,10 +133,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Build ${{ matrix.arch }} run: | @@ -165,7 +165,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: add oneAPI to apt shell: bash @@ -189,7 +189,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Build id: cmake_build @@ -215,7 +215,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: add oneAPI to apt shell: bash @@ -239,7 +239,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Build id: cmake_build @@ -262,7 +262,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup ${{ matrix.sys }} uses: msys2/setup-msys2@v2 @@ -328,10 +328,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v1 + uses: microsoft/setup-msbuild@v2 - name: Fetch SDL2 and set SDL2_DIR if: matrix.sdl2 == 'ON' @@ -356,14 +356,14 @@ jobs: run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - name: Upload dll - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ matrix.jnaPath }}_whisper.dll path: build/bin/${{ matrix.build }}/whisper.dll - name: Upload binaries if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: whisper-bin-${{ matrix.arch }} path: build/bin/${{ matrix.build }} @@ -392,10 +392,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v1 + uses: microsoft/setup-msbuild@v2 - name: Fetch OpenBLAS if: matrix.blas == 'ON' @@ -453,7 +453,7 @@ jobs: - name: Upload binaries if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: whisper-blas${{ matrix.clblast == 'ON' && '-clblast' || ''}}-bin-${{ matrix.arch }} path: build/bin/${{ matrix.build }} @@ -476,14 +476,14 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v1 + uses: microsoft/setup-msbuild@v2 - name: Install CUDA Toolkit id: cuda-toolkit - uses: Jimver/cuda-toolkit@v0.2.11 + uses: Jimver/cuda-toolkit@v0.2.15 with: cuda: '${{ matrix.cuda-toolkit }}' @@ -519,7 +519,7 @@ jobs: - name: Upload binaries if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }} path: build/bin/${{ matrix.build }} @@ -533,10 +533,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v12 + uses: mymindstorm/setup-emsdk@v14 - name: Verify run: emcc -v @@ -555,7 +555,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Configure run: | @@ -573,24 +573,24 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: path: whisper - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: ggerganov/ggml path: ggml - name: Install Java - uses: actions/setup-java@v3 + uses: actions/setup-java@v4 with: distribution: zulu - java-version: 17 + java-version: 21 - name: Setup Android SDK - uses: android-actions/setup-android@v2 + uses: android-actions/setup-android@v3 - name: Build run: | @@ -608,20 +608,19 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: set up JDK 11 - uses: actions/setup-java@v3 + uses: actions/setup-java@v4 with: java-version: '11' distribution: 'temurin' cache: gradle - name: Setup Android SDK - uses: android-actions/setup-android@v2 + uses: android-actions/setup-android@v3 with: - api-level: 30 - build-tools-version: 30.0.3 + cmdline-tools-version: 9.0 - name: Build run: | @@ -633,15 +632,16 @@ jobs: needs: [ 'windows' ] runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Java - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 17 + distribution: zulu + java-version: 20 - name: Download Windows lib - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: win32-x86-64_whisper.dll path: bindings/java/build/generated/resources/main/win32-x86-64 @@ -654,7 +654,7 @@ jobs: ./gradlew build - name: Upload jar - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: whispercpp.jar path: bindings/java/build/libs/whispercpp-*.jar @@ -676,7 +676,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Test quantize run: | From adee3f9c1faec890eb0c5f3f6f2f73597a8b3962 Mon Sep 17 00:00:00 2001 From: Pedro Probst Date: Mon, 20 May 2024 03:08:48 -0300 Subject: [PATCH 087/100] node : add flash_attn param (#2170) --- examples/addon.node/__test__/whisper.spec.js | 1 + examples/addon.node/addon.cpp | 4 ++++ examples/addon.node/index.js | 1 + 3 files changed, 6 insertions(+) diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index 9ba86b62985..1ee888a1e00 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -12,6 +12,7 @@ const whisperParamsMock = { model: path.join(__dirname, "../../../models/ggml-base.en.bin"), fname_inp: path.join(__dirname, "../../../samples/jfk.wav"), use_gpu: true, + flash_attn: false, no_prints: true, comma_in_time: false, translate: true, diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 8125e5dda4c..53bf1abb5a3 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -39,6 +39,7 @@ struct whisper_params { bool no_timestamps = false; bool no_prints = false; bool use_gpu = true; + bool flash_attn = false; bool comma_in_time = true; std::string language = "en"; @@ -146,6 +147,7 @@ int run(whisper_params ¶ms, std::vector> &result) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (ctx == nullptr) { @@ -326,6 +328,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { std::string model = whisper_params.Get("model").As(); std::string input = whisper_params.Get("fname_inp").As(); bool use_gpu = whisper_params.Get("use_gpu").As(); + bool flash_attn = whisper_params.Get("flash_attn").As(); bool no_prints = whisper_params.Get("no_prints").As(); bool no_timestamps = whisper_params.Get("no_timestamps").As(); int32_t audio_ctx = whisper_params.Get("audio_ctx").As(); @@ -346,6 +349,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { params.model = model; params.fname_inp.emplace_back(input); params.use_gpu = use_gpu; + params.flash_attn = flash_attn; params.no_prints = no_prints; params.no_timestamps = no_timestamps; params.audio_ctx = audio_ctx; diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 09b33c54024..643ee756452 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -12,6 +12,7 @@ const whisperParams = { model: path.join(__dirname, "../../models/ggml-base.en.bin"), fname_inp: path.join(__dirname, "../../samples/jfk.wav"), use_gpu: true, + flash_attn: false, no_prints: true, comma_in_time: false, translate: true, From 1b51fdf170714dcdd8fb9cfd02dcee684aac6150 Mon Sep 17 00:00:00 2001 From: William Tambellini Date: Tue, 21 May 2024 08:31:41 -0700 Subject: [PATCH 088/100] examples : add support for decoding input with ffmpeg (Linux) (#2133) - search for ffmpeg libs/headers at cmake time - added ffmpeg-transcode.cpp into libcommon if ffmpeg on - hooked ffmpeg trancoding in common read_wav(...) - passed test: ./main -m ggml-base.en.bin -f samples/jfk.mp3 --- CMakeLists.txt | 24 +++ cmake/FindFFmpeg.cmake | 163 ++++++++++++++++ examples/CMakeLists.txt | 5 + examples/common.cpp | 18 +- examples/ffmpeg-transcode.cpp | 350 ++++++++++++++++++++++++++++++++++ examples/main/CMakeLists.txt | 2 +- samples/.gitignore | 3 + samples/jfk.mp3 | Bin 0 -> 76447 bytes tests/CMakeLists.txt | 11 ++ 9 files changed, 574 insertions(+), 2 deletions(-) create mode 100644 cmake/FindFFmpeg.cmake create mode 100644 examples/ffmpeg-transcode.cpp create mode 100644 samples/jfk.mp3 diff --git a/CMakeLists.txt b/CMakeLists.txt index 588aa61cd11..3eb12c10783 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,6 +59,10 @@ option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDA option(WHISPER_SDL2 "whisper: support for libSDL2" OFF) +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF) +endif() + option(WHISPER_NO_AVX "whisper: disable AVX" OFF) option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF) option(WHISPER_NO_AVX512 "whisper: disable AVX512" ON) @@ -125,6 +129,26 @@ else() set(CMAKE_CXX_STANDARD 11) endif() +if (WHISPER_FFMPEG) + # As of cmake 3.27, there is no official cmake support for FindFFmpeg. + # Consequnelty we added a FindFFmpeg.cmake script the cmake subfolder: + # whisper.cpp does not need the full ffmpeg libs, just AVFORMAT AVCODEC AVUTIL SWRESAMPLE + # libswresample performs highly optimized audio resampling, rematrixing and sample format conversion operations + # libavcodec provides a generic encoding/decoding framework and contains multiple decoders and encoders for audio, video and subtitle streams, and several bitstream filters. + # libavformat provides a generic framework for multiplexing and demultiplexing (muxing and demuxing) audio, video and subtitle streams. + find_package(FFmpeg REQUIRED) + if (NOT ${FFMPEG_FOUND}) + message(FATAL_ERROR "Cannot find ffmpeg libs/headers") + endif() + message(STATUS "Found ffmpeg libs: ${FFMPEG_LIBRARIES}") + message(STATUS "Found ffmpeg headers in: ${FFMPEG_INCLUDE_DIRS}") + message(STATUS "ffmpeg definitions: ${FFMPEG_DEFINITIONS}") + message(STATUS "Found avformat ${AVFORMAT_VERSION}") + include_directories(${FFMPEG_INCLUDE_DIRS}) + add_compile_definitions(WHISPER_FFMPEG) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${FFMPEG_LIBRARIES}) +endif() + # on APPLE if (APPLE) # include Accelerate framework diff --git a/cmake/FindFFmpeg.cmake b/cmake/FindFFmpeg.cmake new file mode 100644 index 00000000000..19dc751605e --- /dev/null +++ b/cmake/FindFFmpeg.cmake @@ -0,0 +1,163 @@ +# From +# https://github.com/snikulov/cmake-modules/blob/master/FindFFmpeg.cmake +# +# vim: ts=2 sw=2 +# - Try to find the required ffmpeg components(default: AVFORMAT, AVUTIL, AVCODEC) +# +# Once done this will define +# FFMPEG_FOUND - System has the all required components. +# FFMPEG_INCLUDE_DIRS - Include directory necessary for using the required components headers. +# FFMPEG_LIBRARIES - Link these to use the required ffmpeg components. +# FFMPEG_DEFINITIONS - Compiler switches required for using the required ffmpeg components. +# +# For each of the components it will additionally set. +# - AVCODEC +# - AVDEVICE +# - AVFORMAT +# - AVFILTER +# - AVUTIL +# - POSTPROC +# - SWSCALE +# the following variables will be defined +# _FOUND - System has +# _INCLUDE_DIRS - Include directory necessary for using the headers +# _LIBRARIES - Link these to use +# _DEFINITIONS - Compiler switches required for using +# _VERSION - The components version +# +# Copyright (c) 2006, Matthias Kretz, +# Copyright (c) 2008, Alexander Neundorf, +# Copyright (c) 2011, Michael Jansen, +# +# Redistribution and use is allowed according to the terms of the BSD license. +# For details see the accompanying COPYING-CMAKE-SCRIPTS file. + +include(FindPackageHandleStandardArgs) + +# The default components were taken from a survey over other FindFFMPEG.cmake files +if (NOT FFmpeg_FIND_COMPONENTS) + set(FFmpeg_FIND_COMPONENTS AVFORMAT AVCODEC AVUTIL SWRESAMPLE) +endif() + +# +### Macro: set_component_found +# +# Marks the given component as found if both *_LIBRARIES AND *_INCLUDE_DIRS is present. +# +macro(set_component_found _component ) + if (${_component}_LIBRARIES AND ${_component}_INCLUDE_DIRS) + message(DEBUG " - ${_component} found.") + set(${_component}_FOUND TRUE) + else () + message(DEBUG " - ${_component} not found.") + endif () +endmacro() + +# +### Macro: find_component +# +# Checks for the given component by invoking pkgconfig and then looking up the libraries and +# include directories. +# +macro(find_component _component _pkgconfig _library _header) + + if (NOT WIN32) + # use pkg-config to get the directories and then use these values + # in the FIND_PATH() and FIND_LIBRARY() calls + find_package(PkgConfig) + if (PKG_CONFIG_FOUND) + pkg_check_modules(PC_${_component} ${_pkgconfig}) + message(STATUS "Pkgconfig found: ${PC_${_component}_INCLUDEDIR}") + message(STATUS "Pkgconfig found: ${PC_${_component}_INCLUDE_DIRS}") + message(STATUS "${PC_${_component}_CFLAGS}") + endif () + endif (NOT WIN32) + + + find_path(${_component}_INCLUDE_DIRS ${_header} + HINTS + ${PC_${_component}_INCLUDEDIR} + ${PC_${_component}_INCLUDE_DIRS} + PATH_SUFFIXES + ffmpeg + ) + + # CMake's default is to search first for shared libraries and then for static libraries. + # Todo later: add option to prefer static libs over dynamic: + find_library(${_component}_LIBRARIES NAMES ${_library} lib${_library}.a + HINTS + ${PC_${_component}_LIBDIR} + ${PC_${_component}_LIBRARY_DIRS} + ) + + set(${_component}_DEFINITIONS ${PC_${_component}_CFLAGS_OTHER} CACHE STRING "The ${_component} CFLAGS.") + set(${_component}_VERSION ${PC_${_component}_VERSION} CACHE STRING "The ${_component} version number.") + + set_component_found(${_component}) + + mark_as_advanced( + ${_component}_INCLUDE_DIRS + ${_component}_LIBRARIES + ${_component}_DEFINITIONS + ${_component}_VERSION) + +endmacro() + + +# Check for cached results. If there are skip the costly part. +if (NOT FFMPEG_LIBRARIES) + + # Check for all possible component. + find_component(AVCODEC libavcodec avcodec libavcodec/avcodec.h) + find_component(AVFORMAT libavformat avformat libavformat/avformat.h) + find_component(AVDEVICE libavdevice avdevice libavdevice/avdevice.h) + #find_component(AVRESAMPLE libavresample avresample libavresample/avresample.h) # old name for swresample + find_component(AVUTIL libavutil avutil libavutil/avutil.h) + find_component(AVFILTER libavfilter avfilter libavfilter/avfilter.h) + find_component(SWSCALE libswscale swscale libswscale/swscale.h) + find_component(POSTPROC libpostproc postproc libpostproc/postprocess.h) + find_component(SWRESAMPLE libswresample swresample libswresample/swresample.h) + + # Check if the required components were found and add their stuff to the FFMPEG_* vars. + foreach (_component ${FFmpeg_FIND_COMPONENTS}) + if (${_component}_FOUND) + # message(STATUS "Required component ${_component} present.") + set(FFMPEG_LIBRARIES ${FFMPEG_LIBRARIES} ${${_component}_LIBRARIES}) + set(FFMPEG_DEFINITIONS ${FFMPEG_DEFINITIONS} ${${_component}_DEFINITIONS}) + list(APPEND FFMPEG_INCLUDE_DIRS ${${_component}_INCLUDE_DIRS}) + else () + # message(STATUS "Required component ${_component} missing.") + endif () + endforeach () + + # Build the include path with duplicates removed. + if (FFMPEG_INCLUDE_DIRS) + list(REMOVE_DUPLICATES FFMPEG_INCLUDE_DIRS) + endif () + + # cache the vars. + set(FFMPEG_INCLUDE_DIRS ${FFMPEG_INCLUDE_DIRS} CACHE STRING "The FFmpeg include directories." FORCE) + set(FFMPEG_LIBRARIES ${FFMPEG_LIBRARIES} CACHE STRING "The FFmpeg libraries." FORCE) + set(FFMPEG_DEFINITIONS ${FFMPEG_DEFINITIONS} CACHE STRING "The FFmpeg cflags." FORCE) + + mark_as_advanced(FFMPEG_INCLUDE_DIRS + FFMPEG_LIBRARIES + FFMPEG_DEFINITIONS) + +endif () + +# Now set the noncached _FOUND vars for the components. +# whisper.cpp does not need SWSCALE +foreach (_component AVCODEC AVDEVICE AVFORMAT AVRESAMPLE AVUTIL POSTPROCESS) + set_component_found(${_component}) +endforeach () + +# Compile the list of required vars +set(_FFmpeg_REQUIRED_VARS FFMPEG_LIBRARIES FFMPEG_INCLUDE_DIRS) +foreach (_component ${FFmpeg_FIND_COMPONENTS}) + list(APPEND _FFmpeg_REQUIRED_VARS ${_component}_LIBRARIES ${_component}_INCLUDE_DIRS) +endforeach () + +# Give a nice error message if some of the required vars are missing. +find_package_handle_standard_args(FFmpeg DEFAULT_MSG ${_FFmpeg_REQUIRED_VARS}) + diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 3b493e3db7e..24678e1c6ac 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -22,6 +22,10 @@ endif() set(TARGET common) +if (WHISPER_FFMPEG) + set(COMMON_SOURCES_FFMPEG ffmpeg-transcode.cpp) +endif() + add_library(${TARGET} STATIC common.h common.cpp @@ -29,6 +33,7 @@ add_library(${TARGET} STATIC common-ggml.cpp grammar-parser.h grammar-parser.cpp + ${COMMON_SOURCES_FFMPEG} ) include(DefaultTargetOptions) diff --git a/examples/common.cpp b/examples/common.cpp index 2c0cdf082ed..25a0272cf08 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -24,6 +24,11 @@ #include #endif +#ifdef WHISPER_FFMPEG +// as implemented in ffmpeg_trancode.cpp only embedded in common lib if whisper built with ffmpeg support +extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data); +#endif + // Function to check if the next argument exists std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) { if (i + 1 < argc && argv[i + 1][0] != '-') { @@ -637,7 +642,7 @@ bool is_wav_buffer(const std::string buf) { bool read_wav(const std::string & fname, std::vector& pcmf32, std::vector>& pcmf32s, bool stereo) { drwav wav; - std::vector wav_data; // used for pipe input from stdin + std::vector wav_data; // used for pipe input from stdin or ffmpeg decoding output if (fname == "-") { { @@ -670,8 +675,19 @@ bool read_wav(const std::string & fname, std::vector& pcmf32, std::vector } } else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) { +#if defined(WHISPER_FFMPEG) + if (ffmpeg_decode_audio(fname, wav_data) != 0) { + fprintf(stderr, "error: failed to ffmpeg decode '%s' \n", fname.c_str()); + return false; + } + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to read wav data as wav \n"); + return false; + } +#else fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); return false; +#endif } if (wav.channels != 1 && wav.channels != 2) { diff --git a/examples/ffmpeg-transcode.cpp b/examples/ffmpeg-transcode.cpp new file mode 100644 index 00000000000..910cdf5700b --- /dev/null +++ b/examples/ffmpeg-transcode.cpp @@ -0,0 +1,350 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +/* + * transcode.c - convert audio file to WAVE + * + * Copyright (C) 2019 Andrew Clayton + * Copyright (C) 2024 William Tambellini + */ + +// Just for conveninent C++ API +#include +#include + +// C +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +extern "C" { +#include +#include +#include +#include +} + +typedef uint64_t u64; +typedef int64_t s64; +typedef uint32_t u32; +typedef int32_t s32; +typedef uint16_t u16; +typedef int16_t s16; +typedef uint8_t u8; +typedef int8_t s8; + +#define WAVE_SAMPLE_RATE 16000 +#define AVIO_CTX_BUF_SZ 4096 + +static const char* ffmpegLog = getenv("FFMPEG_LOG"); +// Todo: add __FILE__ __LINE__ +#define LOG(...) \ + do { if (ffmpegLog) fprintf(stderr, __VA_ARGS__); } while(0) // C99 + +/* + * WAVE file header based on definition from + * https://gist.github.com/Jon-Schneider/8b7c53d27a7a13346a643dac9c19d34f + * + * We must ensure this structure doesn't have any holes or + * padding so we can just map it straight to the WAVE data. + */ +struct wave_hdr { + /* RIFF Header: "RIFF" */ + char riff_header[4]; + /* size of audio data + sizeof(struct wave_hdr) - 8 */ + int wav_size; + /* "WAVE" */ + char wav_header[4]; + + /* Format Header */ + /* "fmt " (includes trailing space) */ + char fmt_header[4]; + /* Should be 16 for PCM */ + int fmt_chunk_size; + /* Should be 1 for PCM. 3 for IEEE Float */ + s16 audio_format; + s16 num_channels; + int sample_rate; + /* + * Number of bytes per second + * sample_rate * num_channels * bit_depth/8 + */ + int byte_rate; + /* num_channels * bytes per sample */ + s16 sample_alignment; + /* bits per sample */ + s16 bit_depth; + + /* Data Header */ + /* "data" */ + char data_header[4]; + /* + * size of audio + * number of samples * num_channels * bit_depth/8 + */ + int data_bytes; +} __attribute__((__packed__)); + +struct audio_buffer { + u8 *ptr; + int size; /* size left in the buffer */ +}; + +static void set_wave_hdr(wave_hdr& wh, size_t size) { + memcpy(&wh.riff_header, "RIFF", 4); + wh.wav_size = size + sizeof(struct wave_hdr) - 8; + memcpy(&wh.wav_header, "WAVE", 4); + memcpy(&wh.fmt_header, "fmt ", 4); + wh.fmt_chunk_size = 16; + wh.audio_format = 1; + wh.num_channels = 1; + wh.sample_rate = WAVE_SAMPLE_RATE; + wh.sample_alignment = 2; + wh.bit_depth = 16; + wh.byte_rate = wh.sample_rate * wh.sample_alignment; + memcpy(&wh.data_header, "data", 4); + wh.data_bytes = size; +} + +static void write_wave_hdr(int fd, size_t size) { + struct wave_hdr wh; + set_wave_hdr(wh, size); + write(fd, &wh, sizeof(struct wave_hdr)); +} + +static int map_file(int fd, u8 **ptr, size_t *size) +{ + struct stat sb; + + fstat(fd, &sb); + *size = sb.st_size; + + *ptr = (u8*)mmap(NULL, *size, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); + if (*ptr == MAP_FAILED) { + perror("mmap"); + return -1; + } + + return 0; +} + +static int read_packet(void *opaque, u8 *buf, int buf_size) +{ + struct audio_buffer *audio_buf = (audio_buffer*)opaque; + + buf_size = FFMIN(buf_size, audio_buf->size); + + /* copy internal buffer data to buf */ + memcpy(buf, audio_buf->ptr, buf_size); + audio_buf->ptr += buf_size; + audio_buf->size -= buf_size; + + return buf_size; +} + +static void convert_frame(struct SwrContext *swr, AVCodecContext *codec, + AVFrame *frame, s16 **data, int *size, bool flush) +{ + int nr_samples; + s64 delay; + u8 *buffer; + + delay = swr_get_delay(swr, codec->sample_rate); + nr_samples = av_rescale_rnd(delay + frame->nb_samples, + WAVE_SAMPLE_RATE, codec->sample_rate, + AV_ROUND_UP); + av_samples_alloc(&buffer, NULL, 1, nr_samples, AV_SAMPLE_FMT_S16, 0); + + /* + * !flush is used to check if we are flushing any remaining + * conversion buffers... + */ + nr_samples = swr_convert(swr, &buffer, nr_samples, + !flush ? (const u8 **)frame->data : NULL, + !flush ? frame->nb_samples : 0); + + *data = (s16*)realloc(*data, (*size + nr_samples) * sizeof(s16)); + memcpy(*data + *size, buffer, nr_samples * sizeof(s16)); + *size += nr_samples; + av_freep(&buffer); +} + +static bool is_audio_stream(const AVStream *stream) +{ + if (stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) + return true; + + return false; +} + +// Return non zero on error, 0 on success +// audio_buffer: input memory +// data: decoded output audio data (wav file) +// size: size of output data +static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size) +{ + LOG("decode_audio: input size: %d\n", audio_buf->size); + AVFormatContext *fmt_ctx; + AVIOContext *avio_ctx; + AVStream *stream; + AVCodecContext *codec; + AVPacket packet; + AVFrame *frame; + struct SwrContext *swr; + u8 *avio_ctx_buffer; + unsigned int i; + int stream_index = -1; + int err; + const size_t errbuffsize = 1024; + char errbuff[errbuffsize]; + + av_register_all(); // from avformat. Still a must-have call for ffmpeg v3! (can be skipped for later versions) + + fmt_ctx = avformat_alloc_context(); + avio_ctx_buffer = (u8*)av_malloc(AVIO_CTX_BUF_SZ); + LOG("Creating an avio context: AVIO_CTX_BUF_SZ=%d\n", AVIO_CTX_BUF_SZ); + avio_ctx = avio_alloc_context(avio_ctx_buffer, AVIO_CTX_BUF_SZ, 0, audio_buf, &read_packet, NULL, NULL); + fmt_ctx->pb = avio_ctx; + + // open the input stream and read header + err = avformat_open_input(&fmt_ctx, NULL, NULL, NULL); + if (err) { + LOG("Could not read audio buffer: %d: %s\n", err, av_make_error_string(errbuff, errbuffsize, err)); + return err; + } + + err = avformat_find_stream_info(fmt_ctx, NULL); + if (err < 0) { + LOG("Could not retrieve stream info from audio buffer: %d\n", err); + return err; + } + + for (i = 0; i < fmt_ctx->nb_streams; i++) { + if (is_audio_stream(fmt_ctx->streams[i])) { + stream_index = i; + break; + } + } + + if (stream_index == -1) { + LOG("Could not retrieve audio stream from buffer\n"); + return -1; + } + + stream = fmt_ctx->streams[stream_index]; + codec = avcodec_alloc_context3( + avcodec_find_decoder(stream->codecpar->codec_id)); + avcodec_parameters_to_context(codec, stream->codecpar); + err = avcodec_open2(codec, avcodec_find_decoder(codec->codec_id), + NULL); + if (err) { + LOG("Failed to open decoder for stream #%d in audio buffer\n", stream_index); + return err; + } + + /* prepare resampler */ + swr = swr_alloc(); + + av_opt_set_int(swr, "in_channel_count", codec->channels, 0); + av_opt_set_int(swr, "out_channel_count", 1, 0); + av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0); + av_opt_set_int(swr, "out_channel_layout", AV_CH_LAYOUT_MONO, 0); + av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); + av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); + av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); + av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); + + swr_init(swr); + if (!swr_is_initialized(swr)) { + LOG("Resampler has not been properly initialized\n"); + return -1; + } + + av_init_packet(&packet); + frame = av_frame_alloc(); + if (!frame) { + LOG("Error allocating the frame\n"); + return -1; + } + + /* iterate through frames */ + *data = NULL; + *size = 0; + while (av_read_frame(fmt_ctx, &packet) >= 0) { + avcodec_send_packet(codec, &packet); + + err = avcodec_receive_frame(codec, frame); + if (err == AVERROR(EAGAIN)) + continue; + + convert_frame(swr, codec, frame, data, size, false); + } + /* Flush any remaining conversion buffers... */ + convert_frame(swr, codec, frame, data, size, true); + + av_frame_free(&frame); + swr_free(&swr); + //avio_context_free(); // todo? + avcodec_close(codec); + avformat_close_input(&fmt_ctx); + avformat_free_context(fmt_ctx); + + if (avio_ctx) { + av_freep(&avio_ctx->buffer); + av_freep(&avio_ctx); + } + + return 0; +} + +// in mem decoding/conversion/resampling: +// ifname: input file path +// owav_data: in mem wav file. Can be forwarded as it to whisper/drwav +// return 0 on success +int ffmpeg_decode_audio(const std::string &ifname, std::vector& owav_data) { + LOG("ffmpeg_decode_audio: %s\n", ifname.c_str()); + int ifd = open(ifname.c_str(), O_RDONLY); + if (ifd == -1) { + fprintf(stderr, "Couldn't open input file %s\n", ifname.c_str()); + return -1; + } + u8 *ibuf = NULL; + size_t ibuf_size; + int err = map_file(ifd, &ibuf, &ibuf_size); + if (err) { + LOG("Couldn't map input file %s\n", ifname.c_str()); + return err; + } + LOG("Mapped input file: %x size: %d\n", ibuf, ibuf_size); + struct audio_buffer inaudio_buf; + inaudio_buf.ptr = ibuf; + inaudio_buf.size = ibuf_size; + + s16 *odata=NULL; + int osize=0; + + err = decode_audio(&inaudio_buf, &odata, &osize); + LOG("decode_audio returned %d \n", err); + if (err != 0) { + LOG("decode_audio failed\n"); + return err; + } + LOG("decode_audio output size: %d\n", osize); + + wave_hdr wh; + const size_t outdatasize = osize * sizeof(s16); + set_wave_hdr(wh, outdatasize); + owav_data.resize(sizeof(wave_hdr) + outdatasize); + // header: + memcpy(owav_data.data(), &wh, sizeof(wave_hdr)); + // the data: + memcpy(owav_data.data() + sizeof(wave_hdr), odata, osize* sizeof(s16)); + + return 0; +} diff --git a/examples/main/CMakeLists.txt b/examples/main/CMakeLists.txt index 1bb16f58214..1e66e4b5cc8 100644 --- a/examples/main/CMakeLists.txt +++ b/examples/main/CMakeLists.txt @@ -3,4 +3,4 @@ add_executable(${TARGET} main.cpp) include(DefaultTargetOptions) -target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common whisper ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) diff --git a/samples/.gitignore b/samples/.gitignore index 72e8ffc0db8..e084659df25 100644 --- a/samples/.gitignore +++ b/samples/.gitignore @@ -1 +1,4 @@ * +!jfk.wave +!jfk.mp3 + diff --git a/samples/jfk.mp3 b/samples/jfk.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..fbfa1d9897365f05adf4f624c3210cc4fd374a44 GIT binary patch literal 76447 zcmce7_g7P2@a;_{K!8vUND0-@L$6BcT|$+rp%)P;N>P;1yYwm|ND&Z}-b4vix^zXU z7EnY$X&(^CgYSE9t@jtav(`=Oe3fMfxXhz9_H zQqj^gv7$M+FnmJ7q7pb6IRzz^%U9Gjv~=_gOf0V1+B*>4Ja69e@%6tG7<@036cHVl zkerg9nfth~xTL)D#mn0IhNhO*_RijSgCk=TQ!{g)7ni?&|FOQg{daf&zoV10|05Ww z8*9s)TLUHhA4>pMga8E33xJl-?cjqx>iWMo{y)D{c7l5WSTGn51_1B{%Ktl_-o7M! z^#1q?5O%`1*K$rT z{5MG{1bZYsnzDoKons88*W$sWKqyUVT2dly7=@yC$eP&yvn<=QNEa2?D+Wd5hEod- z;Cf;hYVr45;sJRSg;#Opjkkhe>8q3Ylw$|VPHyj*StgsZL%~PjA4!9(R~<`2jaPV( z!DUP!3C#@=*a_v75>lu-ZYD-`~RSY`~iOi(;EM9Kb>gUu7b#R4L z=XGcE=|3YL6~wqw7`BdK_?gmlw~d0=Q)CHikXd0Lk4(3)(*;nEEDmbMko8{A^>6Mz zq4%mCm7!;rKn_Qo{CtC*nSzA7%O^tZaIaY!`ipT`xnil^om~fujkwM3LO0_}QU+x2YDmRulc0`S8>7ub1ji zKfe)tWg^s|roh!Cm9ZlJjI<`$y4Yr@q#MP9`u(H8R#HW{#rUC=iEWIHWT$%8#n&Qj zUJGI~wsucH2MHGs^tQjOD_@&`7o?JNb1nPwP=EFJ>>D`(4cD&xerwSg{p)j`&)CLL z?bytS*lmjB*3QKPx#MBFfTmkP5{svf2j{W6s9Lfgc;h1(L}^mF-x8uE~D*f>0pIT6n zg*sd5RZu4EP9LJjgY47SoeOy@~?+(_8@Peu2c#0+@P*UX}m!E*0TXuZP5rc z4;RhoOn4g~ze}9)-)MO+6_}AKfr^i{ddXs@GSxLPIHXKNM|#pjKljnzVIZ?()O62F z_3-fFClv?rWVPCnFE!lhgpAadzbXDzo3`75;wapSLzYNBIkJQT=_|>ffX1HZNl>=%brM zVl7`Q{k%9Y_vS(g+DwJTYb5Q>gMVLj-d^tftc;0I&Z^nsRvq=<-`!Z4UpkOi)n1N1 z`1^Kz^n1o(S_p36`r#h(L3~mh-B8``Pm|K4lT#hvx;ch_LLRJna>4?P^)%3@$iSV* zs<7Wnm0Kl^W~@CG$EIycJo@492` zqF1l|H^53o(7qf)&wO|(()UM{EO9Ns$~p%3H-_f&7tU(l+z%}p=QjVsM*fZ44~b3g zGvM~mzw@R4_ScCa*|c|l!i?qto{;oHm9%~Ltxp1a(`~^wn?wDD>L>ox?q$q}l-!Jw zdr2qB-9%wNE;G6Fw0HQKg~Q#Zi!CF*-_4EhYkWFkb>HT(6ag61B5bf<+TJcHl@FwU zk>w%VIZ=T$V#;q9SXB`GhlSMRuz!jj318|Mr7%pa^mNpqA#{ocV;M(xQ)PKa*Yu!# z?vQ%5o=D))sh`&i%8p*|HkNiX>VpX$Wp!k>PhbsZc$qA_*}85)$g+vrcVB*PMtg95fQ;n{aG5)m&+1>#(8c&lEcKm-cGs z96o7k&*qtvM+T?^ey%VLu+Qc{D5CnSanxC}krKR~M_Y7Z>}WQ+yNv$vfk?*Zp^dtO z+u5o-Jbqs$gnafj8>asoCaPTfaibzom-YFRoQ3XdbPgq*kl7a|>KrZbnX=9xs=39n z!Y7{_3^jfFU-QRsB>(h&>$9rbAhlnbg+KrR6;~P`EE7x;N))N>M~}FMmt#;N@j?O? zK|r*NMz>MYwDIU8i}8y|(9A*Jf!<=-o^i4F4O&xo3O=bTj%o~iPTTFf^rL&6r5le6 zhTyEx(l{rPaQG$aZgTj?XusTW#{|cpIpL;}>S!zv4g

P~0>~dJsUCB3wX4Mgyp* zI6feuXfCo#J(_fw2m*zd+Qfmn(wA3?fvyPNJS?IKY0Wg!%$cMNN=BO@^(B)=IFeX2 z>EL=vR~fHbBaDiWMuPO+;RqZ*4lOQT8SShX)gd*VC1vDKxA|u&!G+chfu?W!g!T34!GA<2FUFtd$j6GDq3P!0M z&syRa)Oja}OkN2mn-aKinbgIMfL>oV`ufkPm5DU4QY5*S__x{#C9LoKBILN-EuA_^ z>!!o+ME}ZLdXMaCKH0xcIoJ!LKa&S%<5YI)cB#ARYhUL|#u!nZmz{Sp`YvxD(;U6J z_%y-gSo*NtYudbV_V-ueOvl6nuLQSyg%J$#^UoV(tQc)?xH6O(v-yV?rFr8q>Iq-s z^RyV_fiGM=AiSQU04P_7n;kBnw_U&y?1~P?NHeo&X`%+dURy3nTb#1Hd!pxcb4h9-U6~8L8C*SC(Y6hH7qbH0 zh!}4o5${-G;Rv)7$9tW8iOwN#Cv>1QmXoNbO8mp^yiM;q80Cn$>K3l&5})9qbd|Ex z+FR{q0k@}G_yHoCJ-h8I;GBO9br^Z?mn;&V!Y8!XPvlH^94;HDG4icHZg2CyPh_7k}2 z7gDN@I8IXtmkxcB$zik>C{7bG!N*D0^U6y270Pwb%iQ+3cCqt+=I#9D2Bq`E7A5Re z52f>n((`obz1r!i;@Onmc>j*_ujBP$(B|;J!?0hJ(#pQ44^~J^ z;Iw5B6cQ{8G6ItXT#!aQU`&HS0a+EPPL?Jrf?Y`lI3CzndedJxVNeEy#EcUMiiqG` z@o>{9Tuz=KO-OtMZ?1HbFP4W)@*o-!bA^bAdZ&-yT=bxxgMnyS8bG4WS5KXvt{=jz z?F?u5$^cmGh4Bm<&lE)JM&cs=6md9uBGf@7V$U+Zm-yp72#@=UWauV!MLXgIe54E3 zWgLU$5|5=x6kKB)00;@Pm)}D6o3pu@pv#4n$|GoEF2t^KfWJl5#dk|RBDL0Dr6-Pz zS|^!F1Wf88_;wc^ z_oRxEfeLbjSU<0gIv&r$ddV_+7)QY3yq$5a3=S^jZ@(h}5@;FKC8gL`)W#y+MEZen z{K4tSIPHWoqY;b`7g#0bBcL~=%#xH6WN=~tkR*hI>)6NYDHu0hJ@m|^9WhFeRNS62i7ZH3 zK2cHE7h<$_AyUWNF0}YL8VR#Nt-fTzM|{1B*mv4lo)AZU(^ME+ZjbuN4_3ug7KajX zFK!VW0qFTtw3)sZ?)1e#5+O%wm(7Fz`H?tjSeNeTUfmucdSLIDK7h#Z(3l_zWK%!ks*Z4vy4L$i3m zdqJcOs{BzG8!E72`;;3`v&M~Wz#4$R@`S$cFKsh z3!j=kCP38{&=6>l)91}P?{k8n4%83NA>byIpwsh zd3artyU1QbS`>zkQw*7IvH3O~;r_8d-N&mj+04q9e3*dEHv--E+)7m7B{qoxkMb@L zW-{;!K=E;U7B;?!!|Z4#9`wMMH%`cW4lmeGI4iZO>hK0mj2xb&+?)$sOuJgfU=b6t z{fbUIFH$X|K0AQ(zLO=pk%~ynNevy$r16>`Y502hI2GuL^!He!nE<09$B#WNg4xnJ zhiyhL=XvvJw=hpUe=(|885UeHlm0u@?R@Y0Z+2zA#NJ@zw| z${XLh`AgmWzA%@e^*rkJ2g`em3zKh|*(&P=%HFK;s-5xbf3H4>@D#w$YqV)4$lGA} zXU4v0pSOqY#_vjFI&-Ktw0p{dJF@g*LT9CRUU~%mQ{+KzH}9HgZY!{}5)UTjAZyTy zaT8gisvCnqQmPoH66B_L-RYSdJ&lXkySLs>s~RkD1lG}V)+r*FNs>KN=sj&jSyipZ zW#t&-7ax^0X4W%|YR?pS9Cs@vbgsURoI7eI#7oiX1`g@~X3r)o&kKd|{S-10q!68~)cEt#f3b zmgsaxNv}3G&qM#AP!#stwH)ci=)9QVhqPv0!MYdUpL6Dt<^6IC_+sJVSj2vupJrBk zpoAS`{NIn^vfve3jbTTqd<=V-4QF(;R>7ri4`UkB%Mc}g?UZNZot6EmO$C3}c}U}a zm%_=Qy#99W@~IUDW;xdD6`%9@rw%&|$uT`}*0;|0Z(F~BDiwE_)K<52JCz9N+;|HW z(cKRXC^bI-?sPI5V6Ov*Lf=7Q&CQ=Q zzK8Hrb97n%f*ZR>YuKj-)s9BC7j{L$Kl=k1f)be`?Q&`BrO55|MDe2EP$3XsqG0bPiKBix#}f-?$8YXd>(Y z5{w2TnEx6Iz=}akrY)LOi}ntU3C402&*ELay)5$l)W8>eR5 zjLEhQCuX#FMg{!bUOvA6E;1~L@~i(%*yk+0EFf^w@jos>Fk6FTFID>w3{gCJ&vDul z+OhatJ+RH&p*5S64!FP2T5v>t@$uq!yy*}&(lJb0@}ga$9`W5h)aeBoJ-ew`^R&!3 z^^1s?M82_|G?&O|#I7a+bz3x1=biJQ?e@bAg|d5-54h94B;4l0B&&#&RM!6tFjXfpN#bl2G5A5y`GK4DCw#qx13O?HBQmS}&M2 zN$ucs|1{xy+P&dD@9at6@jMB8W$0HqyF!}vniik0Ut?pz7`&NtxydGtj!Ab~=J^Ut zO2qsePp%8ST6=GXZ@K&5y)WPLL=Glu-rTq=9R1V!>BM(M&aUUG((FK(4BJDS|A?0` zwU>!d5<577#$_P;B&Ax!u(2YKx*?*skDwR}!jPkq(xP2%eUZ+2ki_Sd-$^Ic6|#5` za-IKZE8lT@xOGkH+Ab5_jYnWjW;Pg>&2IL^ED^0~C~?t@xFIFO`BU?><6Hf~pOwd@ z4+SUL%`MFDnux_Kz{uOuI*QL7eKC%nDMc9F@YEI^%=m*wb|0GTkM19@%e1le@z2Q4 zSk=CJagVKq##+hs-jQ;c&Q1iYs4|C1{$tjwlfwdMeZRQgJ}p`LVEEL`=u+@KiLs1I zJ>_(0$GpZ0=clf+sDGus8wnPQizZK?b2h3}TY0W`8rwho`DMGt#h}S>xxC@m($d1+ zhA|Nc;%Y@or68)EyUKAuEei-Osf)91$bdS@X0t?OI&ie@!fPx=Y_4 z+K7wuvL~9o0IKJ?_l*s2uv)FB)+2U3Cf}xb9osRoZsLZ~jGx zLX&}VcCb3>BZgKKg%pC0C|X>tk3qJwM#6~|^h5wfpb;#>sNqnO(u)7&IAE%&%To6I z#Bg!~brKASAqknWQAdD%J&45id(Iw{aOx@JS5>}kysq(Zq$+X106E>*10{hl%P|gC z9D@`%EHdUshBTH9W18piT<`=Hhi=^#9_R-#Bzvzg!C-i@Yme8a*<Rh|8_42}n#Khj->~Ujbn|#tLr2xWqTp}j%7kZCulQe?a|iI8y26(pKlwd$>)@_2Ga~xafHJE< zh}FdALHt{%hQl?F7!eT;>lwN4cyvj5KIe71lB-p5BEW;pD*B-5e5PK~7cwBmOa<+y z2`72XMMLl<0J${0`fD!0LQU<&NFWRJZkuR*Y3S`DTBkz9{(@1^aU%vu5(6la?5z0m z_MZWiNDhKrq=%t#A2@vi2S=`Qz~E}-I(Gzhl%QRpZb7fSUrxfM-AF0Una9X)0zF+9 zJ4~xvJRy6sP&%2?i8Q>F=3vnFf7%vdg0vv9JnUnenA$PVnx^qq&lKhSPw|sB7YX%{ z%KGx>r0W>GO&)bF8k7K5X$bTzN#QV{k7K_l6`n)VxbXdaz0VjoygiHO_cn6k&_m&0=qbK~tJ z@bW-h(=D~r+kbX>-i}(?iH}}p=~Xhpt;+sht-UflY{rwz&F)kPkfCRN=K|$&rfb^J zDsLki2%yOg3nz);Xuu@F|JXpfBr?kXNeq}&E-a+;$2V3!GZOH|)&8DkqKs5W!eSxH z*_sBlrbUj6I3$y==TDEx86G4xGARWhYp#16-n2_MSaK!EP}A*eF2>*QEJ#w3kwQm) z!k};FM&wR@r_Ay7y8rw^L~@wUy39(KHC8fnc2Ox?f}2N+6arsb3V$CXFOUswa%|yw|cvHz&N>*?!o3RYZW0uAUsokPE&2Px;cDOvwW(7wJ#DDo<`)7 z`#KTNqW&@$s>1<)mTM1la&XZSqWX6Hb{;}P8dMKzL84tR)|i&ncz9kxmT}^;WQy07 z3=5B!`sQ#_wFG3yQD^4|m(*>j_W;P!Y4-Jt8*fEYi#-R*n9-#0lB-P$qVKP)5Y<=i ztcQ68<%1 zzvtaodx&mkBW{6PPj?EoXjznffAXZOo#^vWG?vahkn&mn#oeI`&)#scwArvCU|}O2 zsrv3y62TS-4_bO6dfur*BEqdBGP?C@#P0|^4ngCfDQsc>$&~1;L)^v==mLM1dAkS@ zFxeNtI0h@i~!)8O12hwur%54Y8>E94aJi!H=p+e%eDrJ{LZE7AlOE(SjLU5 zn~{4x>#&0VdTgfanga`}pWfoTFTQ`G{m5hT{zv;6F?>sVZzonL6dICk{wjLwwfa4? zg~VDvgZPZR4>O&lPsIcfO}sAdpQ-z8ipAtwU^B~p^^HCD2)=$}FSTF00%m;X*VdR) zI{UtyYyMBFNlc;To0b-XqbEb$51m6OloP%6@(kOmYNtS{joGv+64gT+{G*HfR)0vW z79#+Fz-us}`Uz6Raf2L35{N+MwL&rtMe-^c2P}>tSR%re!Z9n6%SwUyqu2@r9)z}1 z&f{~-7n6?#y;AiNl*}AR0)a;8=!|rMNHE-(W0oa7#;otcPZG&pe6@*f)o;JiMqW%c zO3?bS;P-mI4iOY&c=xjL>-09e$>CWvnn~czEx{3Mx7+on016Nhmav$(U-45|%&I+q zJ$2~&JlS3NZ%hqLhekP6l%p@)XI&EEL+9lGT&jFV*6VoptptD+HvD!T8?$OIY-iYE z77uP%tRkF2kNh(kB9+4C$YK3Zs5ip|g>poq-zQ3Oc_-wQI}y-TB=uYWDlL?Y<-8RF zM;rlOQvc8eI@}(QI*pIPAV>)9#hKrI$AaKAPXbcDk_V3GIk9jYCU{V5hZdv>xPE)% zquz32sU2QaBqQrBx4Pz7{PRx?lVgVywxFhQa2`6q3#ikvfNAQE>fbNY{QOex?w5jU zskdotFI1ZMXjT~j)Kg0pqm?vg*MEO)Rsn5`vN3U0Ln5_C7F!xQT ziX2##yJ1KnyF9HWe>WU$46p{Vy0sI6Dq(euS7PVByk*T8i9PB}()_bSS<+o9TQ6h? zdm`KJ88C0D*qw|1p-9>B==~+X^ERsg>#*gd$Q8}89oTawxC=Id^{V`RNW3f%t0J1s z3~qt;&_w29|KvntS~oytrUJPD%8|Lkm;XtIAsQ3q;mkf99p~lKQDZO{zdqX1nt&Cj zFuuu+k9YESatr#!8s4O{SjQ0}v*ug2dn8}?ST{cID0@FqR0RovA|%3p&T!rLo>Zn4 z<&mstqc=C&mC1=0)wXt`{8d|$c9{89F*ZO+dElby7Bn}dYruJ=*h?E0rFJ3xx%78^ z?YIm$A3pCtVg%b+B^Wc|nHrgNrja~)KS}hm#lb%hqfxlDkyXhl#Z1sZW-?9@wo_us z426!^y2%g_z6d@@&Tdov&0%37BP$i@!QDZOHOG7wuipcv5hxW7{iIvNuI1jzdh|Bi z%3Gwa$Yc^S3Lyf;aG;jWX?mA4pmCkWw(pk&br`Eb82{YqvFSBi4&FG;*jXYDiz?Zc z>-q0G|Ly3xXW2X&Xkgp!fI@os`gPgstt%lluV#C>gBnt5=qRH~-Q9ZG_bLBQ{A;ib z5z=dBiN=F+YYOuv7NK7!Qi9g??#^74xI%l5Yh78Eh6Mlm-jC95?>^a$s=Ym4fh9*Q z`|TY_FmLDnkOm`cXsAX$E6DOj4Mdv(ww_IlM9w-Q9`5`&AnBeSF4#Iw!75;74G@!% z^E3MJ@~ceb)5Bl>2HLkxO9W_BT_ckr?X<`0A{d7Its^pk9#_A{cOMPDq5-LNEoM*rIFlg8 zh`2~5hN0Ed3*bcw*-LmNAGWFerAsVd3ZcuM!XKq=!fudN}wEx zgzf9P;-op6zc`e|z_T5eHhljw!7pA(Hr*XmNYdiaEGEBPoC#fQ9F}rQZo2?dcb#<& z{PyF&a)Iurr)Fza(|oAXgdtk3m9{4)S2f7p$pX$s*xbupPv#p30Fk7-IKLkc^Vj^0 z-+}b=72N=7Lg!L;^Ir9ybK%;_IY*7E8SdYqU(NCxo6o%ND7h4t8jjb)s2s0RD5UMJ zvt&inJ-F1Aac{cd=)KgZiU+2aR&n*_EN%}i!M7#hJe;kl zEgW#fGwmmVipi*o0at#43%Y_>tK&$f5kn@<-)D_Lf z`tr>Il6qGRG%Z{cMHJgQc2hjnOTzzW&4AooLzGLhSXt}fK zh(I%?ek_KJzwIf)Ry*_cLB{0EedEjU1Z3AYBeFdw1gTekpU2yc^V~mQ3?84J>BG;- zT4$lStZ1cV=D+EahBed=#5P(P=0V4SBG`{$ug>F9`Ep9Bff|KU71P8nZ8GI>l-mBC zs#q|I@LJrXJYxK6vd(7Y@V^Ec$k51jIayAS5eh~tLPX;TbddmPvO|$UKOGN^0}-?r ziLD+)9D)~lHRsNV7BbbByv=o|#vDRsBza&3a$5(vBV5_<(ZP{b_LcxFyrSwC&YB4q zERUWLs+oU!=u zxmw!)w@6(6w8Aq&gS`XF{P^;4?WrTcK@}4!Bu}CA!Bmcx^&~^NW(>;8LT}R7$r-6T z&kyZksCKgnTQ%iQ5s&O(G=Li`6}0Q3>Qy%X_2*I~Qb?ROsIFEx?K^LDmmdD9WKsl4 zk<}GW?iy7fYyS16DbKoilER_cy?V-V-6swc+?l1>th zdFaYU-!{q$|EBCur)cm^upI1#jzAP6bGrow@wn$#u56Agr-kw_0p|X$iZeaSvY^MYiVlFgk~3oG8Eq`}DBR*L2Kwe1db7denXrinmGORTThj`#o}2 znve;+HZG-ay{~&%_=@uPaKAo}_hU_y>c_N`k@&G+O>NJzGv4_49IAsi3`hVV(~cdV zvj2lA`whvpQF>aC@8Rcb`-?)SV+a(hXNc~~xGd`r*khVCdR@qf@PD?P|_r4^|gVxH$6fbI|&sjpT9!=^Eeml>UO@ zi!Q^8I3lFf>v~?gD_MB``Sapo^pB;rlg>*i3g31qT$FTv#U0(3Vtf<6aUiLd+A0!w zbLai<4#sCZU?DtGg_>a3>@=)j+b+NY_w9h__-4%v1ipFRmmGlu6-%8PTfc`n45MhS8mbQ$Iotf&BVvM3hFq< z$p2nA%D&sy-F)XK4etX1|IG}kCjt-rJ+sRe%Km+_igf+vU;6fPYWI6+Xb=Bdpk@}w zyI`cs5m@QNCo%nRw;w1koP4x{4k~LOzkgo%?$s2#-5!nTVSC_Y!eHdf6teDT4vE`l z+O}=hcU*u@zg9+1=>#mjS=T!n@gLLWN?$I}M@R9MXL8x{I7DYsA#!~`?{U9U=dM$@ zFhO(oJkH@F750ZqO0nL-8JDd;^+rV43k2qBsghJf zYH>=*VpxXw50eoQTK?({t&)`yl)u?uzZsy{1^sxQ95Q|p4x>n0Tuwe5cz?P;zv%p; z;9g=JwR35rYVE3g9}r@C;2p;%#sLD%!FZ?RTn^;P6rbYU|0ZPOjed8Z5MYsDI5aXXgjIys#F&iNW#G}2r>@1^((iLs$faLYM= zp0;K9xsKV~r3%fYIj;PaWnfsA65a>u zGc!h$)ny=&BVZ~V7IxQ9#E3}ElFLSZ(=DQSFP@ew|IW z{4yL`dS4;6Az&)GG-nT+tX_f6r+*e&UE!$Hoc_yckVGk%0?=t_NlMvSUTj<~S^d6ABR+R()m9h^LMS zcLP!7A~_ui>Zvbu0wiMK(WLbvG>wfDj)Z6F!JGT`GK9v{3&t>Fq{liL91)m-cMui2 z(d$@G$y13TPQr8rKkR}YNaqHP4k!22JGS(0b|VJ##Le60jC0^vIt;FC>)xeJ6WtZb zVRj}atAz%mSqG2xhq8Fy?p6J=nZVu8r#QeAEO3vM^;!OBRbzO$p#OOs)~dZ;eW0_- z;A<7*@|qxcs|RcQNcb%}Ks52*^Y43gc3#}u9m8KY9aH``{d~s+zo$kiR6w8AJ2{k7 zWl4z1?WYn-J9Z{t-8KU%AHTf+fGSqY6y?Jg^}7N9D1kfRa6CR?w6DIT%L2y<@vgdg zt4)WK7lr70){veDdn7D>g0*MCfWd%k8P0<^2j6Cd-4%W@U7l>wUfyr>FUt9y0Lzt^ zHMIu|BmrS8_yc}x_|T_BPeT(;a68R;GVgcn_~H|LY0gl@wp71d(iX|AO7sm?;HxQP5|jeH zV!7OL*KM!&L#Fb+lk5kXOeKQYWv?;xra%f;Vvxj&Xr8NPUvG(ZbDH;Ez4#{K2lWqa zdcRlB=pwtsoy@ThJ{HF@@vB{>b?7v4sN>~bOgC|#e`qP5Hf)Cft*bWCbOu*(>q}j0 zYWKf$pog%EB>8uB4$c+lJQZ|rz^GCCgsAfKvZ}pXH zbfW*?r|P^B;bL{~!b|F^%Y2k8JKkhkEPTE*D%JgHY5I7SZeYQrNp9|LeC6)LpyzGb z7r1j%Em&S{e$6(Yt7D1{_)qGgi@V_aJyFq*M*TMB$~UfEiyP|2r)?_U+gh*z(I=1p zM)I5;y+3(l!YQ+RDiE1gq$QnP_@GnsgYCg0hME-5>pORpO&r@F;{#JV98qZE+1{2% z?~+`@WK{Xr_q~(j#& zhL)kC*={h-Gd2o@#wk@9MuQ+Ii?``-W`nU%xU}iO<7a!8pSy4Pc713z+3NT`gl;r- zHxw?ZA869Q$>6J$eb}aXC!>dXerE1qA?F#9Abx$yH5%$fs|SzO7ujU}zGbcuC?2*h zC79MA&}#0<FCSoawE!25{Lt0omGLW8#m}VauwMc9q(SG zA>}2@`+oq?-vp>R7!1VmP7w$XSdIq8^0ggHf)pJ>>P^$Nry>LRZtcE5H95FA(fh}z zgKd<=4ZFtivHgGY^Kr(Gqf%|Al)xW+c`TAU$@eJ;xp}jn1Fyoum?W#H1dz-D^#5X@ zicB6t;ul0Yj*<~Gm_rPIQOEW1Rpoc=mH;~(W#W_k(5QBJE<3EB!24otF!0Z0-fl#F2( z04$wq#|*3ku;EB{Dk+@nbG8SM>!^fJ9Cf7baBTY+sfhF4)w|Uciz45+;uxQrLKF2V z?&X0hN5PK*hiltGVp?km7$o)W!ND6XZZ^&Y^Cu^Sp8Zs&lf~I}P5SS2nWmZUuU=bT zHd2=z%j^6%U8MrITG(0*?Prgh4EXnmrEJf2OZ?$4p|8@c*mzV}*q;A}3|Ae6lAB{P zE^~>p?kM>ki@UUQwzZsT8)|lGdx=s~_`as6GB)*9W!m~V_A!}h`Y8BnG^bz|&24?& z@RdJADi$PD{phba&ZG~YMt1Ai&7eIRNqxQ0kS$b{i?h0zu|opjyv>^_R}8{A8>h*v zcCcc&5Kgo@%nxZfWodC@Jy5s5ba67-6U3R7hkQm?+5LZOMs&&a1AnFIre?41S(wC? zI#4Rzp0K0>wa72}TuCWKSP0SA8_MtZ2JHLys#3H5qt9ZNVY_Ij5a#&Y1g(0-S+!)) z_Sjrje%n@y%*{ug3U?289^Ifk=nnAh=Qb*So^s`||3Ym6YRu z{2$-wN4CCjb=wj<9E$XUy=SV?SiTT#c~*CO^!`?SJQ)mMwCX4lT@9z@n3gUWpKiK5 zNq3BdqsycN`s>oDh4h&cK_B%}?ZuFi`bqTgE@B2_(cPHn?%9YL8aDKQ_r`B{KMbJK zBEyVWpiA1CNIhM|Bc`)hq>}RI=Vp}UFz{4~ z*29HvKT;x=c=a|<-_2<~S-tpbx4orv=ov#k?4m#UWa>P3&)`eX_TN|EbFJ(BLDi)c zJ<{K8ALTjzOFFh+&|TWoOg??OEAt~nT`X^{_@zDDg?oaiwPq^=64hz@rglH6R4;$0 z#3OCecS_z$qvsk;H@Z4lDYVasWmTl-i>N?H|J~LH-(k<}SuMDEsftLwsWj9@=+9=? zLgZ<;bf}HG^3L&LVu8Tc3@4n%6YG6>R1%E9@@x ztL7vYjl^HMZg=_gU8)$%L<0Y%Hp+o@!R108ia+7^8+OO{|NXT_|4~LD^fW14u(3s3OyhFUIDsZI{c;_aLX8v zyK;tlGqxuus=Pgq`R_^rpW@0FO_FrZ)W{98o?#ddfnYl6dR6&5?80=FClS?6Q*d)d zkD3%t`{|nMIQsYsvA8(T2!xR?5|3p9KKy!(Y=a9+koBo(b+6@7T8V^U!jMxD=SfX3 zzXy#}deY_yUXPO#RQNdU??qZ}g?L8LYIl0tfzGn5W)?Mbk@cHLw~Ui4q-rBEg0DBN z$p-VPb?x>!i(AgWGClg>YHw!mkw9j=l~)6{D_Bu*PB@5VxO1>gW3xOz*=}u9`eXGk zhNHbQ_4<@`Rgv4%4G)z^f`+6y7Qat#W%8Eru0x|5${IeKYHs|*4ZnB0a=gX0VBQ39EGC)1;NdR%7RMRZ-;sn@isZW zF?ru}$u;cRqlS|Mc%z-g_xqbuzp~Tn#BFLWm%oELxj%blT6F_=^&*pYS@H}6lim2U z{N0I|SqFW3rQ6lPojZ2bJDm42&Yo5(Gn~kZ=G%#{XpAo{xT5{ymnd}Rtp9ViSE6h< z8LhKAb^>WxyYS-YU2#3nR8MwPri#f~^OVM-QCh!?tR32}#RnzJ4QL>Z1U7yk0jXD7 z#c?NkZM=-BvCrC{1zLw*d2aHb<)|`8!h%`Tmf{>hj)8ax@>wDXzYYzau5yCkaXZ|} zk#?zIkzz3_kp?WrM)XgcU#`=#>5*4HLspjJ#6d}>ns<3?r0b~7f@5|E`gOZek z|Iwl&@`;N{I-e&>PnJga10UK~suq3llS!;1iJts252-Pzy2t{hHkal~(8Q_uhB-9w zrN1M2Nm0pZ#p6pDsN%z)^FbQ$B#p8=OqE_&A|fID;hZ6>IRcLFo%W@C#nU9L42p!9 zKrj-VlZqYj!`UT14Y*14Ku&R%@MYk{8tSisG?#1}x0iqF;T~Wd>yl-f#uEI#?@!yz zH>2xcf9R3RQk(Po`rkL5>+T7Gjc=WVem4bWXZ(#H@9|MpH>OD@3?fW$H-I30&I3x^#o~2dfB-UcyoyYU?6;Ar3 z@>N=2*iCBM7-okXhg9QZZc>H)wpfzyXqW}(o+BW%dmvn-PNI#u{%ru~YkGws*RJ&) zGtNc(C3!zF=_ zoYZalt&;QkCt4p3prII%9FtMd$f4Q`l@~Q`>C%6upaakA&|$N_SqRR3tfA4*f6HSg zljf0BH`!zU?XTBWL3V#24Pd%hw*O?fg1Sw!m&Yv@u zuMK5&MU?p<=Uq#C^RiU;75Hxdru{o_j{?5p(L|YMXiEmGNA!sJFIb^{CC%z0fTd+- zv6tS`_ZPP3VP~E3y!9BTO>HPyQByMDwDqgy5_ z2^_coqZ~W`=nraQWiXebd^W`PS1o*$$(_08+;Bf@YUfEFa%wLY>;HFm7bcVE3dtR1?pV}ByugEl4eW|zn1f?#z1OXy`%c#iIS5_&O z+8noTXj-n7TQfP1UjI4h5+eDe700Q;ME&0*waAiU;(UN;4z$)sutc<^IPG74-p#NX zr}??Oh)aZQT)(@dgoKk|+BNIGre1HY`iE~4g*`h)t!B7H@84Q2kNHQhm!);_+4X~8 zZKBG60G!=}2fFmrq@&C-mqUT(-%H=X1FE?TJBQGH=*Fd%;hzUmqqI3kd$$Hz#j1S; zUm!U*keh0?_u0*Mt;E7_cAevzm)vC|vGq_{@Qbo%#D-&{jo1-$g;uz!OT2CW?eRy4 z96xZi*mqwHq`R*EB)D`~B^lGJKuenLmah5XGaxDxEoMO{v3{qg#_zhyumL)t%(S@; zyY3^cR}CHKjoq*2VwkkcF~9NKyRnSnWn$+|kJQsEB~NvDlVw;HL&|1TE>zcF{IP7e zpK+g>7Od-7Px7GK4E&dNiD|54ZD{S3jd-eq0DY%=)?sbYzZGD|DXgyOVa zGZY&xu9qkLZy$V9d&p-*SO9(NDhm*iRVdY{r&8XqIU>) zlvy~w;|ljUWdRm1gqNk;0fwzTlja~62n-k5ifcdmmP;&3Ge^$d>8MD zp%*K9pKt!0@_f4=Di*?Fc){&!*4(wWVSNeWtS3!S>p4eF%Vms(mQQLAa_V}a_ycaY zyt?4pWs9@yo1*iyxaf@3jl$duPhu#36jmTO&e(wI>}ZsvHWCzD(|mRvq|IST{RQKR zlb^`wO^vwC7kQJBcdXacvJWCqC3nlJE4+F^gI;%xvEkigCJ!kw+9Kz%_7Wu=wiF%U zrVRbIOe#!rj5JufFYolI@g>36eWUuu^4ZBEiuL=B#KGg!FvG2$l_#Oaj7v*$dtB{j zha&v9zS)qbr7TKwH*P=4%usss=)Va`*B74KdvPSD?fYDM|D{g%tCd?`bN><&TnX7Y zwSrV!U8vhR2ry!P+qASwsOmb&mblv@@%Q{W#Ena(^l{DibQ!mZJp?C@RYYuHh`;79 z7Oe&8fRW4#mD2bw>+oxyOw74Nd?}}*g4!QF;?gg@mtIM(3_I02+EZfx7`|L!e*D+v zS7u1}{mLw+4PE{uVq@q{fm1p>86D)~(^byR{FCh6ig9;HbR1iD6tzfun$b9lD*1Co zO|w8*ncIUnxjJ2FamJa}z2jJwvG~CI@V$b%796VNig;&!<4u8MSC8ylgZrOfoeY<1 zN;eF1%st2$9>qsBeM3fR3w05T81eJ5oPNAEtRMV}JU$6eKKZ0F*lx$Rr=Nyse6jPg zc;V}({{p8JwIAFw$EJ-mnC(}*$*+%U+4>KMZ1r`Ag6{EEn&+{Q;5=U+wC#1|Hw&Rk?Za>W@9%&6mwAo;=ZhkzKXrISRFTguk(sl8~TjX;1h%f9a1@ zs+(z%lWl|6O2$8lSx4&5weyL@`X_0N7YwYCv-M+KDiUwEhmucYej{t%Oge~6h-8Uw zDXDJ1W!WindF&+5*$nl;XCQAlO&2`X#0dP+ka+h%-STbw&Kr^OTcwjGb$7pAc@=Ha zX~sVF&)r#Btwec8y8e^qs~qk#vQnO1%T&pf3cU|O6;(U}*Lz$kq=?X1S?wA#Ny>(@ zHH)Sm+rA0;nfW3D8O4l#f5P#OQ!nqg^{w&zw=&3A_ixx-j7L5dg(n7!zX7{d%u6A_ zt8etPo#t)Cxie2=ogCclcjcN6BjUdN?I-mK%=G~1j~Ojp!0tWkWXdw`szIB$Kh z4;hXYen0y546b?d`o1#ew{)O(YX0S?^1GjFuECf8>l$_5n*VCCg88CIVix+(^Rtk^k7^ z$*etE!Zp4oPLvrovP$<7R1YJ}@Z5MFOM@lM4$Hv-b7rS8FUfFil20x<0BzvYTTo8y z3=)EiATskZSZDH+J;__`j*`w&-nVS?u0E)nzTYT&)9srghKQ?jBNnd7<;v6qs@97i zUXjzvJaTvMM{_aQr3j0H#Fu>}tY4=nOQw!LObTzA8?01#XnzZR_Vu7eXgLv&eo?$+ z;E4G)*?-tHeDjh1sF&2`E#^O6V9b?~xllo)L;eowMAt3%R#x*m{U5w>SO46ZeD!Y0 z>@!9E%qfTe+Qx#{ue4woka>{kE!C8*T999G7?5hVzVmPN%VPS~xTYRt9lB|FD7afX zrYfi%GylT2nDS+Tn&?~p*zDZ2ty9svAY_iQXU0ZkbWGa0Wm*_acpkIo5ckMTcsN8# zim3dYc#Q3dp>1ed8|Mqv%(nWe z`rH7sjKJ#>o9)SS9sHNHWmLZp@wfhGCg8Fd+L6tlBK>%A{(4|{%RW=rlRV93yZ*1p z`lz?Rjw0_mU7pX#pPTmJ-}31TS9^Km{^4QW+v{I{z8A2p>6cilU)Y>)mdbSU`Rbpb zEPvtFapz4J{!nlxE(ys*6t}8>gL_&3OPyCjc#oE)k~2QMnV*^h4g7w0VPuFu&)JYz zWbQM0dTngU-ft(IEII(~QzK>mGx{%xVeXBIb^CwT)`)LEZ?xV3cbj_$)YfTB%3yxt zQ*ec9Uo5v4E7S;EU$&CzF1@c zcVnTvK0ny_I(vsPD~*PW8>4i6H+NM}RL6yDLgoxk$-?j5Xz(>p>G#T|Yhc#<98jYn z8p$(t3884>naP~zJ37gfn9+i?9O5zfI45u#^D6p59QimhDOst$*71o{i|J!Ru6-Z% z_Tx!DYMWuWbWDa9YP4SN;anw;`5I>H9Ql)7t&bVE5C3P!&pC6U9fKW7qM|_MfW^Ku zMve10qwVL@F(+LoSiXRZzYc`pV%3S~%P;cq)+f#dIrr?X$QqL>g3~2(ZS-{z`5rgN zn7_Cyw&~*Q1z6~fMbyLvF6EbR$8TGX_`kAHR-R4CpShj`n~n~Dj;=dD_ffK5g84M& zkCJoj$t6)nUAKJI>?b<{LmmHk?XY}l{n+gq^s@h4af+rX&!I!&_L0e#j|ob{59>k- z3nwF8kG)LgS42llrn|iTR?ea)1M>5w?hAi1W%oJah!+s>B7o2x;!p2Bx2;Ux#@UzF zI8plo#m)Y6A|Bo~sz{Kr>}KV@%o8$4bAhF2WRIjq6mj;;OI#u9wCCeKai>rR3;h@M z#l9+7>~vOKwH)ShlKWtz|D+8*l79MJ-jxXWPPrPH6e#>iAY})YKHRtd*o4a#PBQC> zt9-@zI8MVK9-M4=l>NMCk2ntr>`WLL86~|lS=qgPaz^{b@`tythouaHzHTrbo7xiy z)i*XyX1@Gl{NBpm_r~{0Q$y`NzISD$H}4t<9~z&WH$H^d+bP-qkmKdo8!PhW?UWm} z)NlX#yhG00wgh{M;XA3Q_EQZ1DgV@|dP7yfNfgL{HCg46@!ByM)sJg)Y|7p-MoR7? zM82|+v{!+FA3M5}Zj;V2IL8dn{yDzgpccC?s3A2{`RKA{^y`T#QEiD6Zn3pDt96l% zi2579JKk>abVH>r>(oy;uF4wB;4j)#DUlyu;prEf-Y;i{YCix#(mdoAxF>_eTFm8|4sP5^AWnO>Pm-{;$BXZMJ4_`WPMWhBObJ_TY zrA|ml_i;``dK*$YPp%o5jLbO2a8k;AlYD#!9({9OZ`PN|Zo5?Wt!c_-W9zGyW=Rc* z@PL=avD^WD*RsD>_3Vr<-%B2kEJ7B}%w<($e;})(oL{|d8j)ChH`g_#-(bD_^r4R7 zUCoV^KX#MFA$%N5tR?=Nt|Ru}e}}c4IhuBwwDu8&lUPeXzHL|3oR{zoW|HJA2N0z= z6<{+CMQvAhW~k0fT!oaECI7jzEh0FtzIbO?ZKh37=@(=zOdgUa92B3%{95wD`%^y< zAsMq3Xj5SQAj;t@Dzj?iIfuHnbH49o@dQ+f@f)Mko*k3<@?hBgTzSrdlRw(-yum-j z>R9^4^$wQiLOd>W8u!kRIc61K7}9!Y5R?CTwm5#tuPhmze3LiS6!$Ucgrzs^;lunJ zsxtT=@?#<%7jSp0y}ucXD-9Z6@qikS*Aw>m&PpJ7?&!`d8KH7Nq8`ZU)-PmRBMUD_ z&bfjL$Cp~8hD^R)_+M|cA>|Q6?`@Lo5pDLmBwplguf)u$uokDryO$)ihP}{xsl6r& zJ69!stK!_(bnQJZ#GQS_&oXjBkny&;=E}oggHfLseka8g|9Z+B&KslF_C2I)T*{sR zxhwhc=}|`aj2+<)I9rYsZnYk4DHp|&K4n%B=9*z%HEkbKpC$J$4W1JBlWfp8D2YvG zT;qyVq8tXAO8r?glq5P>c0sOj*F01{9~C@&B`7rUjjG@{eV*22;G6iHnB=@Zqv9s% z_FePH^2M^ol}EwROLk~{F))1+oIzwwnIe3qi9kuusR>|99TTC``5@S@Qtn-teIBF3kODfPzADTeD3 z_l#0sL?$VoWh|*zUsvAWTVua>)=Lj%MjYrPLyU!N?UHRiw@IKj?Ae39P z_ZzbK+{IBMOGxRuVpyjFqVYZA{^0Ptz5!LYgU5}n4XZJZrx;F2K75?q9%Z!rX+qcT z$vMrN!oVaOENdb8ohUg4v{aIvmC(viaMyt(lswZPv9`vU`K=(Q~uR22-!SV;?-KjS}PXxMuoE_8A`upp` z3FYqH@mJAY)B5XuxA%|C!gkVrB_;B9+h%I(QBopne@dQjDSx+- zW&d=<&paD27G!!Wt}N7t-3!XTK3!#3jeA;ZB{VuQg78WJ76#ORfYn-=jIP<;m{ES0 zqY8x1u}prKgd6gj4p1Pu?#Kl1H~PGkk=^F{(y+B+dQU>H^AZ73r}w*d&NkIjjW>K^ z!M;*_d(f`NQT46Ws$ZYw&B+2_=C70dkGNZH z`2|rAF1e2vhOIQzNS0!9Z+Uz>>|tGwT|Yq=_0jAfReyZ?UtjsQy96uqc|_O?9exDs zmVet}ZmPq-;+QF3`xdK;uosnC+ABvo7tg*i_s(9J1Z_AS$~nMxtK4mGSwx@4{9ZCN zkmfye49*jG*M8{-PU~m_WKa0m|#P<-=)VJ+om2(wlwBZ zSC(C}qSWf|>t>PBU0#1SIj??UAN32@e-SgsTi92+xxc9Y`;4NlsXu|SU^D+jF!8Fn zS?C#7=Kh>#H=N%BT_{xU#~QUHIbVhD1Sq^dG@LKczD5#5L?8qecA1+T)!$2WlbWQH z1K-tWC&?xZvFi5N=vrN}I?emG)doy1&GNCjf57in#O-^JHROgB{hMGq zuTiy*o7)@+qC2NHA!*gqGc2J?$VieH!Y;VF1^uI?a;}X$J{N1i6#XQBxN6q@PRBiT zQ5FGUT=_QdWO9A*tEngMrGvt_=vdC616#F<{$|8)*>mUYj;|GSPH%in>8)${nO&*4 zRT=m0e|eWuK2A8T;qRBemW+q=Lq}^##P+2D=y36IwXB4=jUhF)&QY$FtIK6>^0DqE zSX`BObP6BPGU#?9-UvNqc-q)QK_{lFh|vhBe&b%#sK$pai+P zn->p#Vt!k%yyq1XVwowPYd`w#gG23I31`Z>#d6)Fp%((f1CgA(=1-krO+W&U;+aCO zp}SvE7$Xo4xP#io7QKU$hJsK3*0{E3j{RN)hbenH?yZdG1mq TeBsR_I!(_|*f2 zH}^)H^vo!61y<(bPfrr^{&d!P!mwz!O6Q;j=}MKI^TlffiVCRB4=ZExlQiac{>yCz z^5>1$24k{W%BmlU)$8JHN~0F1F|SB@=p@u`a2zMHY;_$=<4m}{DD zqx|e2*OOPGF_lVN9ig59)-^KwEdB(*gAjG$m4453k((Lto#A%C zEPa*i@j5*{?`1AibNbkPD=TZft&^%wI=|2(~N97Viw=z#5I=|ptu2} zl1l<_@7Erc3@xri}hh_H$d+)jC58&n0aF3$twwY|l$;`^jH=?yPz5LbGX! zlQQGa$mhxjid{lZcF}5vC|wK?!sru(5g$T87!^GQlC5v%F=SPHb`cy{Ekr;(O4lBo za^|n(JQ7o`Irsw4qtU@~sJH%LP_GCI4;XgS*I=VO_EN+Je^boPSqBpC1r-h`1GsBt zz}$>jZha`v+1cW0gGODb}7X5x0Q-gsr3`&HB^?hm}rVnMyCkcz< zQ=CS<)XioorqrsjqY^GjX0ny7CcEH4A{pSsAqH7Nm9*fagSvH_(L_}M(g6(z*==Ak zryS9)H=QF&DqIE7Sokrw#T|6q6dT2>RSYmFL7IE72w5Y{1Jz04t@i)APxOL+Wn)Mz zbPeQ&V&HI$(aI2Ohb`0bYefus2Sa#hM115GF}PVGYw?{aqP*KY_-B^GyW%$i5u0@~ zmSd%f|4nc21~|G+Zii5{w+P~A8Aj-XZpUA{T+<@{JqsFNOmQVm_cSi}9`k*=vNN$O ze|#-ZaL9CY(f`r?wz_ZYKDWP=@|(trQDq3=?O&M5q6EM>gefxU3>$-V)Et!H$&I6# z!0~7Wx9Kf+OpZI*3SSGsJJ40A?Ve4|?jNFkpF!mLh>87@Qm{&}BG6(0$aSVD(P)zg zf+S;lDWYJ^B7DFNg)1^S^q6)eClLGN0VMO=P8Bqy%oz;U1bGHSh630X%Yp~-86<~B zSU)XZI6)UU@J*=`H&_u5fm8V7lsL!xB*ZX?@n%a!Kf2i)0eXAo)E}u4zug{Ov%Cr@ zn}bx0pLbB&n-`ITok5Zzg2IY*;xt{N7YYjTD4QXUiQJTqA4`NGkYx;d5{OS-EY;Sx>X-hI99u z6V~FlIH+tSw>Y>76#;_M+^Nx|WCC8Eje#?6ixsH{FpbBvQO^L^hz@DUGzu}D-;M8p z9Y~xu&V=$MHr7$a7`iB^#R5H)bjA0%LeJc_&hjo~JF6nW0m4JrgL97a#d-D)Gnf%{ z-0YB&9DO_ZRdxsr&??Sf4JiY;k--oqR42X!GQ>cuz*Aw~h`2rXB0&fn*5UL|V3{c1 z1hmO0BIuSOS`F`A^n3x>G$jGtX$*;Mgl`~#Ef6Dx$$_8x#ZBGj>NWxVGS3KwXZR~| z!5*Bg;!^Z8*SGu#!=!=4*icW|y)0byC>Ffqv-Rm$;!>6JCY*`myvnc+Js1bTA_qYO z?x*6Ja~T{Phef&+|I|E;Bn*?`BA6mc%&NSiR!+Yu%#Gml(-*XhC)tY zrf)&Wk_!0{nG|mhXd)w`nx8q58?=FKgt1cn#Xwfd+#DV}@D`qtP_`6XK%%YWum_9f zI)|CEQ2dW|4k)#=$k<^un_xZQ#K9cFp~ z*MEKiJV2@YUQDFn@c=LZZ;f$-;Av)pATHK84U9Jfo&e?~auU#}S`osu7BP`Zqlm#c zdK*V>>RVjZb@4?~09Zm_k-CLIU&cT*GlfkL<8X1@b^`UX;f)2B2JNfm1!6C?K1P9T zfp85BKg5G;e+cu2Asdu|gfb}-SYQCN2yr}=Zk7SS;A6xj#e41Wo6uD(*QN4RFdu*m zfCpkSz+6z;BSpM~GXK6j#u4DbfKodLjN4PDi>DL>B~g89_`g>-=Yc;xCfTct<@#YoD!Y2t_5sI$TV3KJKlOB?1Tm8uO-C7HkiR4gD6m^i~=JQ#>wt;J}Z zMGndE@dnW!07wTOaZciDJq~#lPW^3CDCk%aD9=x_V@?N!PsFYe0RR>uH+A{OX@&6uhsd`)$2!- zFfYX9L0+D8&u*;T#XmhKFG~}h{;H4fKS|O&JlQbQ@e<05>+L6jZ!Rx3}SUlosNSM z_TlVcf<-NW7_5Wg0OBz)K3L&e20$1o1lFJ%6aoNpBoGCPRRIARXyR@Fu1zt9%lJh> zZ%+JWyf=td903-=U&Ty8xPW^`Fs27WcmZX3{%o#BMT`?*9fb42II$`-W?G;i%51Bn z%5GQ1#`52(hZu>m(4JwaEOZs-5XOLMj_x(oIj|Y~lh*+9srwn00RsVZF{0QoCLpo1 z?msZ*usKfDI4jVIfIKha)F)6m*R18zca0^*Z6X1O~jlmY$ z+)~EdD-HA9(LC4lsn5MZ1S1x*5f@aFi?Ldex_JxQ8|O;_W9$&@j4&qd@dzH65eO#B z0W0OB=zw$qa47xq1lv!fuEvcGxRv>E457b-(x;Z5?ba35pTzF(ZaaAWar^ygW##Xy zx~lMyspk7g^PZFcdd4Dl@FhQ~+X}b7ca#Y2Djj<{!{pp_7M_IOeR*x=c;nA*=i1Yz z=)DX-a0U-(6$2fDfoN#d*;z7<3q%XUlQ@RhX*TX;NHrTuPmXnysmvD*-4w+C#TP;e zI896-v|olxMcw$KGAAa7?&ul{XjJEc^t0pK>CTebK%5CC8WhjfOj5h(lr4!#z+465 zpj_3cMrh3eOq=Go(=jpErd$9$<@1NcuxEaiNNWtjU0D$Y2QZH19J} zVZ0`gUzZy0WpGkMc>m;wRTVMFm*;QEBH8X)6Q}vo0XI>sR2IMpfyRK8PO8fQ%2EaL zmXbBBsA7yw3A9cefFWM53_z=<>sHO;sdt6>D;xQS0mx0(BqoM(zQjt!qzBf-e#c~8 z(^4b86nx#ivJn!^T}RW1EshMvKrxCLua%^gWPn)RX+f(5tO#5fKh2JJsP#l^ciIld_Iy=n-*@N=*X+8Ei^6npnu){{kDPp(uk>0LY6Sze4D8E^$xDltG9%Fy?czu`q;iSSh`{BJml0L6p?Jpg+& zrOV8rz&$nzus&$j!qkJAb_qfhbpjLiXv5mDZ4p#2+Rgy0g$FPtLa+dLUqLT|9^G9a z9=b}w4D!x#`~<400_Ip&Suhs?ig*l0i7B2n;}S;s^wJQ)s4{_BmCh-G&pu$t5Cj7_ z3AY)j8ej(~`#6kdK?6W7puJ>Tm?f+Rh6XVk0a4;OkO&UI2x2Cfup*g&(s6KRWLutY zDhZDqPD>_{VMsk-=Ku0Odzb4QJ@1j_CN{`~l-}erL@r&Ct|DfU2y1=Wl{$$uS{&HB zNcTv^!T=~ZyTS!dj1=Gp41Z4Pk_4C0_Xm z5ohI^AM(hS;h;Rjp=FR6d2Y=j3kxrg1PZAtyVhyMv+L|hMhS6o5BQl=8PZh9cJ%-r z;1yyPa9JFLWI*ASMD9TojvRkn_ITDZb#gMsCmRB0i374P#Jo>hF@B#f7qIcjD{uQm zx$b$;hFOoTX4l`?lRtEo!kdT9ZO=51UY=dKHCDK+b2k4%$uH;oQ_kPMzwf!Ff(mwv zUI715nXS#ca1+Kk3;@)^#wo0fuuRrol(IGx_#qU+RbGY$T+_ww%;l;;fr{#$w-J~! zBa92Hfj)rL1c&*z5!`(<*g@Au@D6YUN8f6wCdPybBmSCQ%}yc;o8A=JN5G&MM8Otf zVr=}Z6eKV&nIx(iFVdUVMITUM3FdbMy+?S zZ9N|Cf`>wzORmJ(^bG@k0OR{r;dA`PE$je8)&K|_Q?(I*OvRs`cZQwBdFi^p*TXm@ ziiv{Lm|%ohpqRa1b@ca<#B{a-rh#QJyZW7_k$l(E>lN;PhQnqy@-Av zV**^u#Nm`?3GG& zrhMp+J;@35wA?ar`84r$movCft@r8_gN0?Ux!mt6GAs^JW7 zlQjtyfiHpX!Tn%E9*iOlr0cXN(D~7W@)Iy-Lh;1bLI_fga9%gU5K5U3Mc=y;mt+jG zR$^0V#IvAtjdb-Z6j5>CMHtdRtT{aVN(C+Y_t$Tqw;S)LXoU4of(c_Kp7EPnW*+_3 z8ZvfJw1BR49<$w3+PY}7v&e*ZZI}O5Zt*%~wq`)x?1z@Buo-w~w-jr9v%5*F?(Dgc zo8lJDx~{h?9)9+|hK?=bxgI2QI-5FeW`0P3yywVdsXsUOed&x{H#Eq+;J(2|`{QZ- z*7TqCLX^pnchwWl7w<@K|4?adm;ZO9{%9z4jQ`c=bqzMCMw124Juu*z9s_IV@75p3 zp(p>H_WX85yRR>6Zf$1y8P@~OvAs#KQj;!g5;JsRSVf}+>Gs$c{1c@y%oy=SIrL(K zs+a|)h1+gQ-UZ?Ulu6zh;*{h_>ZReU&)VV2hO7xOG5w2J9r+(EIZ7D zW($=5Gp(4C@rvqpj+ed>9Q5dw6w<_4cc@3-KHj{jgO4oPe>KQ|ymMn`W5z|!z7qtJ z8wl_sznXpR$tiVH@d2jr<;~5T%G`114Bb*9)YJQ*@}Rt=$Gi_{=YE#H@HVdb?-ZXd zEciGYGB-HEb5Z5Zn6qQ`(0;;oK5a+7iT`GLd7g6$eG|00b<0YMn|ZK?6}>GK9A>i> zeVZcizf{Cr)kwa_!C+Uk%=U*1=*XcRwrPHXHb_ zBP|Mudj=_(gn8S#A$cNVoQLoJfAVX##^JM&lkdy4wrgm}s+KI1y$bWgDUDeZ#R8GaCL@tq8wIOuL(j*XU8 ziHtol(zWGZ`0{wmBPhiF=}f+P)8t+0!rEv4UG{LEzdEaLsvWgot`-y+QS9@$xc3xC z0fA)Sx+{{*-^IULykN+KYg02y{Hg}8m>SKzAF|4eEtF;-qaL5O{ z^yCCrIZFSYd}(ucJd(x@da&TYgLptj$~D=%G%Bb;)l z>(X)Ur0P4cpTikvLW|z!V)4;#=_ffkhC)xCmB4Hb&!#!P{_v*_IyPy#uFmJOvGBu9 zSiBfJ^!wAJgc}{z%i|>S+e^x?uBoDt%M*+V*Cyr?HY>6ZN#BNVCG}2)Xm;M}`S|rB zqS0e2x%@r*2bl~)!q^?zxcciI8Nouyk`zg=8Z*{fk@muxkahu)?q^;OxwXu)4SHW| z508DFU210B6Ww)^-`6=h664(jZXix${)iqMC4Hm63t))@O(;MqEinadWQ+N<-ANt} zRky#g3x72)(v}s(f^nAM_o{-m|fs^REI}wN7LHg!VX0 z4!Xm{bDKhfsGhhEOM^}Q46x~jc{25NwTlMx)@qB-BI8TaN}2S%4fRrvQXi!@{mQ?o zE`eMO=MgW|kdY=jRraeqC|)@UPl?U%I8<8(mo6yk$!lE>$URav&pKWR#v=LhPE1l- zF#l|T6bJv=^e~coi%F-NM;5^iMOhY+WqqKrV)n%SNyXdi;iE=AKa2@ekjS7wCp}(h=8fQP9@vtB^M--$-Lc6bMp3g@stH&M+ci!Couei9frqN4-Y-mE}Dq;~o z8iGHav+TX%{b9K3>G>MDwoxe3N8uBdS&oNY42eBdQwN#=Bae)`Of-MASt|^~P+lO9dU@gBo=iQ#{#~}O z*YjJiX!c0cH@q-Pvr)A$9ck;AF0E$YVXg>+PQzrDHr@PBt#>_t^Zar6 z-tgGk`AJyC6w^3qp$a4Cd}+9k@VeiJ0O*LViYfhP;@awX;q5smu*kQOD|i>+w4I~B zDClxBd%G|XrPcgJ_09aI*7{Wgw~!S7wo{zfW#hjQUX2>7im)eWEC1ns>iROYKq9g6 zoq0su&R}?3qDe8!)X%;+6jPeZW2!9T{iiGK&m$jDmg0e;=YBx**PfIVWOKVrga|A= z_x(dqv1N3V4{1*sIk{VH6`ifFLbQ@an1W zR^63(8Hl}q-0hyCSvV6dOrnIbAV}X(yTO}BDu@Jqd@# z1GJ#hZR+6QzFyP)#362ww5yQDD#ME*W*`y|y5_bj$^cwYL~NVL4J3e7zN(Hy3;cG^ zx?B&v$w}}LX3@T%s7B=WZD%kXVfnB$iEu=s!a?!!?*glJM}ktS#*}D+3=X#3>l(AN z^s^`3r_1^$d!W&#b>`;+GeBjOmkpmm~#iahN9ddzP;z zmfnll!k^H8A2yT|0i?eElj_2g>FLWTzAe&%;f$2FrGZ1A!iU0P^uy}OQc_Ix6Oo(XFKQBj-h`T)?m_7w~YJvi~!P%%%@yziwPb3z!MC>FLx(vEeu`bkTa$P(W&H|z@y@4@e1L+g2 z6xHG3@tx=iR>_R%&Ry6Y((j>$Q%>elxcPwr;FnmY1WdFKO6z2YvCv4gVO9!=2BmRh zV6Fo@15hCfJE5UWfL6xO0tgP~rE-FLNn5-~XXqL+@kun;`c==G5VqgaMuBji6#x z(}w;JO3Ts}f`dSB4jX>OvM9n#L1|)zFp`$x<%UDT^|%$?Dp@|CPx*O#PXt|8Ip=K% z1SV55g1J!~1>~4O6lb?ZqB(bd7z08Pg@>v5`H7k_@yjs*1c7iUMhV}j0mZ~&YmtNN z>%4KvX)T$NUt&+{a7IAW01Yd*lcTyM9Or}1nV#d|i;CHhG4-IHI|u(=Tdr+0car(H z^!f4gk-Vl?O4Inqp&0Ck(cpu@?2fLcP#pkBvH|i#0-u#0 za*t7`0>dcZi$ZaWD{8S}pq!GZ33a~lJDBl70fU4MZV>=u9Kg5IM=#H3?*Q-k(`(Vn zW0+lqg(?h4p}+r4G!o>g)+Bo?G#GUo_y|9!7A6>RfsI&5)l3Z4NFGT}8OZ`M$xbOL zEr!1P~QMyG>{ts*kVZ8{O} zz=DRhPmqnkjM$LHii)4EAk#+ z;Po$;kSQiEw<5CqciL@%mxhj?abDskzfHONbjj|8eCQ>0vG5+D$ zj|$oVp>SK>vQ`ZN1?%LGjwv5Z6G^=!KxR64rU&dtKt(1IoQQnl7pO2ESVp{VQ``sZ zJQQbZNMq(>HySSe08f{YX;0$vN3t;eqLn07wL|=z9#6$Sz(T(#~mJs$^RNiC-4CASE zCvL!bd~=xLT|C0;S_C~_Dq?XYmL3>X1;~lTmT@)? z(7=vUz+nU-J7`_(jUsZ`*FcydZXiR_g^L>I4L@^9jFlCp0#h>K){|xswn%f<;h)rVRkda`t~Te-{3JRYbKrp2n34BLn&WyN~qfd#Q0P)EL`rqzsV4= z1(<0uAHbwac_SzYR>E8f<$|av(elJ(?3MMF;RxNP z1R_Xp;W{RmOvA?ur&EMM@^M8n1u-vmtrMiL&L6a;x?|BQ5Q?SFiG!Wy)XDK(J)EG} zNzeb1{okAO4R>FEib)xi(>CVnneaGTp13>}>0J8cmj_o--&eOU(F<1X90|OxvA5jAY@m?4x~q_;9mn?%&i#Yjv2`q zskY-}H1*~Zr5?zGaFQa z4D*#mY&fi;-bOIch$~s~o12_&vXWQxFg05?sQJ^(lB)tG>SKgMSJkxL!LJjWJe4D( zTV1g~P7sdl-)i*KH_!#&%7Oa#TGWfancs~q{Mn-EUlk$GUc15LIA(1IxRl*1=S3!@ z_|~pd<&fzjARDajo-U6uRwS4&&L9yWOtsT+pVo~BJ~W_0+%vv7QaeRS^tu6yCl2&W zvcLkBA)Mog9A)O#Ibyb3F`I&{>hN@eAn;7#vm)QhxuH+<5vKmO3zasezZy#TkZqmVl9OEUlqK2o)oc z3llp;B>z0KJ@XA`N73{bpCw;7+PHUw&0Mi8c%p7-^faeipVuvana@1xM@UP{)qETq z&z8Ye#EKbutv2n^L}*U#SGQaj*;r5BlW9tgK_=ftrA=9z=oland4qQjDp zZLVKN;g*1l=)Si_uQ;0&EZ`F1;(GpxjASr_oLtW8T~s&x!HiMb zpObcH4^T1;O^?BBg_6%=tV8i4jDQR<($O>@`N9B9>RlCE>qn5e&hh#~x{MH_dXIjx z3iCKPj%KJ8s6H26h)c;?oZx2Te%T~%^YL;SBMM``2oDDMewgKPLCr&6@+dvdlb%I5 z-D&Pen1+B889-whig)G8(9FWmPW!qw^HZGPn^VMOwcA{t2-I)uTJ0Q%S1q5kA%-&g z$~pVWXPMG8)&~s?YGh@EfWvY9{>HrcA0Mfuvf)0>$+<&p|6HH<92;?&)ng6{ifumU zzO=xfXSoqFz;?Yxo8<$aSMz`67v#dJIceOZu4cFXncMJQ9`(CiyK2)WZwh*8E;Dy! z+FCjWW%lS?g=;@t)+1=f*i<`J^jy}hyzn$EeC%Uc6VrF~nDs9M1@k}LTow7N#WT@G zQG&@bxYvBCde(cwIugQ91kD~JdI>pN}%+j#!Dt)rpszVU+P>2^im{x*~5)}(vcBIwkP|NF0hTGT!VxmdYKzM)#nbX zaJ?Wy`k7L7xC6z{&1A$Zi*NYTK%1YGgkuzokjRfBvVhIFD!XN?#*f(--W6F}-j13# zv&!63+eqL!@A|Ut5D`w+)Fbsh393Mj->Pw?dCqaGF)ciAALRW0fpKWqW0Jn}=MARs z10xHTUef}7=U-u34_HuHtY$baE?Uz$X~m=eWDnkIwccD2HhdFAOwKKg?uIW;I#(BW zQ9n3k5DDj~HVvqD$CJ;+0R_QB4y3kT0YDBq6r4w#47}nwxHru zR$HKuuC>U}BD>>`PbME<9WY2-AG}jJ?0qJTh#sQOGS}Ww4*nE#_T)=!>_JTapPN^v zj_b^BI`VU!ap_s^f0UuF_KlOx!E5U~?B}Q$ON7UNU%m?Q=rY7SyehOO;2_rj)jWnlNYzos7I@$pCXVXpXZu`SR~gMp7i81YVe6OQ=GlCncqJ{b z^yN$b3k^i$clCTndyzc86zTypY;*(M=;b99GziOrUU)1wqZumKf#V-SU-p@^@ymRR zvY@J7LR9fwVyL%`8>p0tTDIj1Y=H+=$q%%A)i{ei%*_Bk?eKqC`!a4u>s(EDoruWf zx|m)8otw(AJT&**_19!AcorpSo-|FR%q~7fqL|q>{>LYvvmwbwcbG7hZ!{;=DCt@` zm(*M8?9?0xY2HM^)HqV6j~4CFfT0-05@#@!_w0NsAg80V{QFX)GZ->2>_}o|FaP6N zv6AT7Sf6MwQ?JQi=u)s$bN(Tttr84CyKg-9S#Ebls9YwjBT?*V z8UIhiiyj?Ku!>CbgTl`k@~Rclj-MrJLc6Z*8C1zc2`{+%`CB#)l)Se@*d)Bm;nEJc zN~q*#o`a!)48MUqb{eE$T^N2uUARRXMK(X zZ!vtGPJG0>PPqBH*v{@XgLGWZNGg~2lXwij$(_}C+nCn$l6=OPSjoc{!uCq4k3GZq3J0FVlB*0ZfaZd;Uaf zg5zFrGw(Z1d#m4v6p4Crr^E*~x;efW*quY4;b5#0Yq=p$Ajnj$d@i@yM`W=Lf?P8W z^e!ca{w^3|AZiCxp@Z&^j^!nd8_iujAg%J1mtRxpzITe@q~wc;}a0T&VdCAaO8L0ZbHpN!UomQJcPc^napWaP#bwhh@)^xv) zjqAU%a;FQP_e!(`nDQ--ea;JP=dAg4K4I!&%ZJj^qp7tX4hd{2vYZA;$@{n3P<)OO zF)cmVi7tOP<}GNF+{F^l~I7!f{B5J*cW?^9hlZf&MmIf`Cc$`@~qO zKdLXU`d>|u2t6hr+<_F&A1{jrP_llGji}G}Qst$-$dj7d$+N9L z|Lx~%9_je}Nc02Ly}D6>41i5nXlD!|B}k8MSl%k3h86BbhWrIA#AsZ9V4gY z^S0QfR@{I`Nk&{I^2MIqE05siemWU$=%2SSbei`s60?)ZSMR`F1#vedW!RgtV;1_H z+OD!~h{eLGrOK}C`C+p~=C93-Z!ZlVgDy&yDh)q*opj^mhfJhG^ITUz)1nw9(I8PFU*c0e^=bYVD#;Dyx8Hh01l)i=t=1Yw#l9brR7e&63 z)?Ht2dvrA8T?ay!)SU&y@bx7c#B=MMXRbM!^Bnr*R)=amDSC3T(y&0?<7dyw6-i~e zq|_69;h1f;*3r^x*6xReGI~QbS9mh{z*qgXr$z_qChmcbsZ9i**#{SE3!`}E1uDx( zUZL>9hvOFVCACtmQ4f*B)WY3QLFDgyV!h99Yj1m;56P5unmNUAQ;IbrO(%*;=-V2T zN7%_>j8AToV6~k8*cHj?O5Lf{Oi8W564R3kq#*IL5Inmm(=J>k)=He>nd{e3RjzJ+ z-E#Qmr`jcie0OF(Jl!MWljxtP_|%#c$J=*!W~9og_ROwt=PiEi#9 zbd_;YKVNrYm!)CpC0%;yB_)+^q#Nl*8WfONx{;QYM!G`@=?>`z=@bM7{3*gd%m3L| z-_O39xpQac&b{ZJ6FBlTA-btQ2yd~}@voI*#-c=HWbLbDxqVs9@?x(^MD|>sZf)ME z_C((AtIQVDr>#)Rd4a0i-`|J>9p5m6IZWK9K5)c9=oC$f#Y+swc5^CkOm*B<6~+!t zGT$t8^(;+OonskR+x$16eW6zKa}+>CtFzg;&*BWVn&T-UF(p5Q_#FqVT@B`0c56@i zz3A2(u$519FP5Eo_5z4Om&jdmW#sR$oM%_~A({IsRnc3hWrt4wv)ILvjnLSmoiEsb z)jjrQdpc~*eNaAp#i~5Wv|Q8vCOD0cb`PqV)IbPYWip}h|$8d zdup$qrjqPI^GDAeuj9+K&a3csex5;qt^33IzlkMDW!L&aJ=eiE7$m(jv`>odRVW>W znm#IL#u0f8=HTVBq@JXN?Z@_~%SN#DYA58;LPrXe(E8>^gR?Boym1Tx>SJnU?#la) zdrq-Mb=L_Sl0rWMmdD<7*W0-^wm*cl4NX#BhkIL(XD8hF89S+Qrj~$rE2nKoCHeT< zwgfsCy}7TF{wN9We2F}#>(S|Z{`_9>zqU)LUsW%r@n@cIH^rV~tIDbv48hJUT=a>f zDwHh36}^{#th9jdh&(*N{T)D0YXwI_V$6d)Mm1t=<&-Gw5dlXyNQovQqCY&#%JJF8 z)$RZ%nDBqT%BpIRB+{?`5*>y#p3D35m%-yzuK1@j7#1n!)Rq(;=L3INWqBxZ=2{nYfB16GAxM&cWC;qbXyd_1=doECI z-0P}oJk!?9KIJ}2XN#5SNJB<_h1+#14MYZa&x^gMc05ue?oSTJNYL1P2_>uPfRPv} zvgE-ICAp^xAc@OsYtHXYqWR?>Kjmt3=nB0z4Z(O(CUSf(T9XsPA1cYmg(IvU!pq17 z;5U-SLmxs8xX&ZG@X*dv(Dy(zQh+wB@E@Vmr)++wR399c}oEL`+^J{4V#ll>th{(j;l zxOMgO3(~XnOgHo4%$wL#r;}OqhdhC?a8C?y8PdH8(wQ z-dU4+qfHhC(?zCV8i^ei4^CxLmLOV{2ro=VHy6qlk?h|;UasRoJ(A5-%$a+RqU1Fh zE;s-hd4sT4CMY=-!emXi!(q{jJbDJSs<&FxX;^K#im3xRy3?yvx!h5-LvcBi!#eM# zO7ejSh2Z3*ejW}iTno2&rnF%O$C8(Uh1ee~)HjRNwa4BR^9V@VEZTr5Z>s+Tp`OsJWd zN_>hCgom*FCrvDb^L?;{7r%6LuYLgNu@jx&pDb}b+4khk$xhe4(c*h&P2a2OY0r~& z@^>}sC7Kd6A0stqw3g~QN~;lP6DMq^NC0>F!MAG@h23+HLQ3y0zm&$}qqJJDZ8t8F zYBxjX>O9pl90F}+Lb)9=jMG9<5;E&iE5STZRvw#an%#kX%GpI7aI$KWXTqYMbVpIO zo2PA=woX-($Q+l7N__lVTA}5qfvrh8oz)Q^iss)}#^&a}bofbASKs-k)a-K3W_|EeKeuBK=PpdE3&Y;#KWJajFIvWSq=*9ZCRGUPI%+*tH z#gYr+>YW;To9p--E{SuPP8E)+#7uL2l7>we2uCujDG;|c$)rj}mI|Y})k94+9RyWW z&;cA_^&P{<(|V;KdWwJD76v{08`?M5&grQSIoI!C2Y6MVfp##*ctL%fs=dtf2PV<& zU%ZSw3UyQp|NX2S%tyRH%g7O6t0yzRj_Yu98daCgO@4-fflznE6Yz0X=6^gBR^V!n zJPb+{6f~uA1;(!4SR{!UrVNQ~=5Os*4IKt2D#uBnRY^{x!>H!Be@d-Fs3&5E{m8{Ax6 z3y|WpYbRsx&KRqg01;kS`Ks7OA&E`Z@SC@AofH(-`3us62|hB5&spnoyB4=yd}m_M zM)x!LAv$|QA-%32%FiVGFnK0e*YA3t9|$P{-ee92uv|uUz7S-FRj|9|s|*Q&h&N^y zIY;*-fJ7o2s z{dag!Z2)c`$!Cy8QVioLRJA9ZhcfJ!hJjz<_C!krd`R*i(C7Bb%A7zdsJ1t zP#}p1-|{cy0y@(Fd*A!i>(`6qiP$M5#-`WHtBf99ylig@tqD4I+=3zI6j7|tojwio zpN>a|dW>PMky$Wt>_}<}pw7h-b7?@EItH&i`FdRdA5dntt2Cdiat@$Rznzpp<(it&yrrXvnNxT^MI;htrptjZc3He;U- z;=1M<1tmAXU*aJUxmxP@+wJ_RPaio$ za(5$N_bYDHJv(SNl0H)Y^sl^F-On)hAbIe|Q66b`UdguA9cAw5;f2Y|?CTBR^e5-j zdIUQ!(uJsFDdO1<#NNm)mv^nn{I|(N%zmg;dF*_s9sEPJKc@PllMckqltVl@WAt0u znq|iMelLd8@aFn$i4^pUHEpk@0plU=en`>BZb-(uAd5UZiQaefx@9T{r?wN8w~9@# zZ^Lfdxc5jY)5N4(y2Wt0i-}SIk7M44%iTWq@ZK9UV$Y49Y+rEacfoG#w&KzWTjB7_U#fV=aT&HX&#JdCs@ERJdfxPYZ{Xq$HhtJnq9x}QkA=+BJew~HMcx_oC{(VdI4X4q=nGOc75v9fLp#$ zO_o=-YE0>GwLj<(6X~(8T1}~k@WqeRoq>rI~CD&KZli{ z?2J`}S`d#~@w*muobVX0R=dUAqW#%rJ)dccX+4NjwRYiMt8uza+d6*&JL|dK4%Y>% zxP@x)7kcf|?A!2~{US~tC`<@p3@mvyxlG;AB&4J8P^jZ6Qd zKmxG-(6)zc#qMd+b6dFxLGf4RQ*3NZWL@4QX!x3yQPOU6fw=#h6irT6?pW}Y~#0+ z*XlPu;-@`7MGH-k_^j}Yty(1p#L%cT6S;PyP~`1x3w*-dd3utVE0}uY*Q< z$p@}BdDcQnWFqZ|N7dR|o-N%wfK4TYQxYhOu$CT{FILOAv;l&}%Uvq|$1icfdG>&O z=!@2WhSwRtgVCZ`I7mxp5^4ULnfNxaRj<9s%^h-MUXuGx|0g?>fK}3!V4SRnNAwM; zYthw3pHtCJsLu4^1WiTpwlQIvxsbo!Wl<&51pkuhf-;tPJ~0dMUu2dNMo3`6(?RaG zBw?LU%FLz4pqI4qQ4wQu#4{(Lc7`fj2ygbQ zC@fsoOHoc~T7I;+Wks_}{^XLcP!;#Yb{aK70t{282tYT$L{=l8kVh=u!0*$W)TwY8 z1T**(F3V+>*gER_+vvNL#b)gDgk2f?YfxV5n_|f3Vv;dg+!zQ%0F;t8`HGdFnKN7! znd?QuA1c}Y&x%S;|_P`y%O^$AD-^rhUBu zT|o43=+LjnIP+A|r7h^iA-;Y**Kn4$-i2QNZYgZm78AeV4=pJ_j=ne!0LBFC2pMs7 zFO|>df&QX&S<|0!&?*D)Q?cV{`ker2_8k#Nyz(~#q_9zD9L#6SA2=Or!m>>dHuO=s& zH3_HP{vjaW>7_- zNKv&iDIpL4#L<(FVLC6tMpc)Am=+7$IuW{H2n>SGXT}WfsYhTw)*^k^z{K*4yqK9- za8I(AGTZ*NCnOC^EacbiO}QnpYTWY%6FPpS2kZ|w8aftlEZ+K;DhDrdp@Q%}u3JU9 zqbnV%Wp$x=ML6#anCM`1`gTZ+yHDE!RNo6l3SiQqk%LU5HaKA~v9J+1EFfC`UPP9M zwO4Nf8JkQ40B0y%5SUFk$0s^GxvT0;f+=qqjvY=(Pl5@wEK>PW?g5EZVq^`KRHVoT zTsC>yL^aD=3{N zp{Zb1>NplkJR8C+7#I+XWMWak9opqVf+{lrs>xfQo>tH!97`?qOfy}{v1EkWN{6WQ z&Z~_4AK!z6>i@ugvnJn)Udm%K?@(ua46v7`WAfdbp)I8o`#z z;$hIkGcI(z2$`My?~tHOOX=OEC2~%lUF|f7DVbSrisl_F29wPbRz@kU)(oWp*1w5G zMRyCr^YN>5klf1y+lw4KWXLB;JN*KOw@zP=iwt4$&f*mnXCEDx8_Z{42JQTZ3`>6c z27Xla;Gj3ivv2l#DJeqpk}u@JG|*zoViR2&MeDn0mEo?Nw8uEVO)5Ir?Uxa|7shye z@0^L7q3t|BtFMWrsg{)rm>~jNVpXjOh;b$5sVt1YS0yb@mJJk$AdN$Mfc4%riD8IU z&ci%b<9&l(j&uJ07A*kqD%eo<1-mH?7Q6ZY`@ z7wp!9PbcO|sl#SGAimIlTBT8#BHoLNUjq-MNcL14&jclJI|CD5U9r5YCv9A&KkR%f zwNTZHiRQycyq8H3%HjjHu~UBp=@ zmKaIJ2E#wK>Sc$PWVuE$8PZsJ19d~}7rJ?IVE-D09N8YV5k@111(F=D?1~6QLJ)}d zNv^UYW_@m4TxwK(o(Bv^EA2|v4k+8bpkKF|9%o;vAbya{Lpkf-mp-#!UPG+J*A0;0 zX|r9-D*{z9Wl$mEOcmf59w&d$P~gw`17EYsrb-H0emM*d{@8QTnu;^6>6dc9X7GsR z7~7xVD#$igf@2z`-~Vk#{V@BnHYOz@nD|7Fe)a;*6X-YiLrM~Dp9$jqYA&D758{C5 zT$*pIUgL4|!2$B)H#jKVh)HP$8XQpW9xg;MHSLM8EHl>=N_KJpejE^_e&DmBnApIP z!(`Nple}XOVmuV#ZE(q8ocJ`OuRH=y_=HA8QDcUcl0<~kU=~$TL3?t8E_z@lj_WX{ohyp_!J$;p!E(gOLL{>X!fYYEFv!8M%?x?#oJthT;I} zKoRgNYsR5%lZk92Dz4!gg1VsXU{c|32R+XD8(LpUtO{E7?Bk~np5sWKO4tIvjEeL; zv^2plw;#eyQRzb5%d3s zJ+%b}W4!Px{IDRT^QeCXXrL3Z*yIQ~q!pckFNY=n=9(6(LSl6r*#)##S)W0r!Qr?- zLK!VBrq(roJvtb05{H1`6)C}UXy^n<+%{YqCQkxY=ceJnpw}va4g=p|!(fpb$}#|) zD4Pm%tA-brj@(kpTK-OxD&CrAlWjP`UpAH<8@Q(yqQg8^m!Oa4~=&Ozak?bKj`&e@XZk@ z4`KEY@W90v6r?;iIuIDE#D9bJ?6_(8_BOXj&o(8)pm10TR2rX65!+}Sy)Y7S;KW^L zdcJn@)gdH9g9r0^kmsmp^FPgy%Xgv}VZ9`)5`Bz}1iYtSavsCRCp!@@exam>x19e3 z*QHT0M$rt#KK*?YA{UKTuS~+_S$W($3z8C5Xk8&qvE)Br>b#FFw>>vqU04DO?$>$` zNCm|TqWrA%YOOFdKgM~^zL1|nplJlOM6dxm=G<(YNBcgy|heQ#PfB{tq5@1=; zr!BdFOV{rKGGf$%I#EvqTJ{ao)nq}y{mFT_q*nHHa8C)P80UZGsX>KOQBdgWXA)+; zvdP^eL?z2i<%mL3b^`4rVUJ{8jq! z$%1G95o#$556Qc>64X;SZh7(RLJhA-&rPkB`9m}f(kp@|K$zBePVw97R>*}HX6yT- zZy11ZtmOV9hHv+ebDxtGJVfhBP0&02M^uo7vUDG{;K%;weieoK2^bB%(-U0s_|;|i zrWT3a?O^?q!K@YpV-T{?8}I$mTk*H7urCGKcL>1`4`V|!h`p4VP@(KVnuuJ>|{<8q@{8{0S=D6Znn6C^~wUN7LxRQ^4lvB?U5H`LJ3$AaK~*R z05tdkQH++hC~Ux7trmD+#Nx6s;iN`HresQ)@G z1y?~|mrCADK7B<2bhr22YLleq#7I=kO9R_if;$xeEwfsD9L@$ZG1lI~IlIL)ir1bk zoS2+8nAz4akc|rAsnU&D3Yx`a5UJ)PO^(x~)&cif)4;LBWJ#_V91NAss0vW6t&`2* z&Fs4GLC_KgzdX9-{`b5_cw3A8pD$iAzLh|6b!^-y+?BH7c1>tdW90i`&KKjOvSonh9n3}xU17eL^LB z9m2sHrU(wlI|Zu@Qz;u4jWWPCY3&(nR3IlQ1G+(S!EqXfV$Zq)h0ktZ!By9UhRL{B z-n7-7lkNXT%heLM@!!0?DdzcT`rI)pR5hGnEtNW!oIi{0!87I`N_Nr}PbrD8Nj^K| zY!q@`8<5?{_f5&d85ysEA~25VQ2ANgYxB~!(#x*)JL0zj9SgYURc(0;z=s6ah;{^1 zMSZZoC5y**YKX%|8(Js40`&P`E`@9TFwy{>OEGWqV@MRIANjfaJyv)LZ--YJ5_3rx z>OMnMX2TOE%l)4>t4BQNFQ80lnftW?6>Uz`N$qEM+Q3bVW8Qtx8bY*1V>F}ZjKTG7 zox(Uu(jzcfg)Apg^@yWj44u{Hus716q*9>~ARxeT1u@Uar-JP-eJ9eY%N>y3^IKXE z>Ia^bVnbmpaYus6Ur~QJ$MiPMCJu?6=W4W*IS4M)GB#&OrUr{t|F>SEXn$@Si`;j{ z`TTt5a;j~N3nNr$Rik+ep3sjwt*>>_3xQh7 z3Ho!sH0LH!k|3p54c&TeSvhoQq(uXx$aSF>BI)49DmdV?vFXgnfn;V%Yk@Bz+a3h;o52IrR3KH%UW zkU$~=AYm?(STh4VS~fQnIcX*|Jg5wcB}vdCKeAJI`jkoVO?Qfb{w?Kq0(oXBp<7^F zk)%lY!^JaRV3PnW!0ublJX^BA5wI$PnWOc!hi+DSc{0F9%Y_+{_xIJav#*DHay6%w zLbkGu2^jpL;XrJ9pfSfR)FU>^BLz88+QPa8UKkF8<(MaZl!zzVu&R+|jLoa-f0J1)Uh%XTq<8T*B{H*PK}keCfcP8K?GN zsdY+oU-&IPYjFEYa80?ua*NJcp*gs$2`B@>Ta7sGZ}&(mz4!n+3JiG0ub#N^b0vNq z|6MWIWMN`V@It?bHgUq~+kcx<%APBkcF+Vh{)_?7PT!qR!37-h*8Khq_(_SbF~BZ0!(G7_+jl)IOh1+h1SNERIQ`PJLOJx=+uU!Ku|I8x zzO3P)B_ouim`6utpiX^{H=H9!l^3uZj8Vi+s+j55(p!?I>dI;&e;r`zz@bThqTCC! zdVk#>kaD>4^{2a<`(eY%3L#k37tQTR&l#fpHM%%tRmy zNHs*=>U{g#nhGNjtD`wKa4yofN`k*?{aLLw-YMu@8&LIqqEAT#&w~RlCxACdRah74 zC>MY~V~#~f?xU@W@fEGR?Ks%SXUhBK#l)-=?*h3rP(u(S0$I3+8LAFZ6mZls0$^Yl z&iOu5BkwP^VJ>_JSQakm-QLv&1AtggRmMZZSeTq1Dg>NZ2*Ca*#-d*pr|3yXUvG-l z4qUyzR}c$B2TN7nvdt$-p%DvHolQt!b!LujfR3n_vBBdYS8aj?`jo%VY-(tphhm~F zv=L8btq!Ld2jb^f4b0o_G^i|>9U)h#f_cSN)WBw?4~9T7-DUq7>;w6ZGnq*TGyK#{DeVg2GX@NwOH z`O?sF!aTlf=bhW8a@@bsj^VL2hrB314(Fo_juS@G*OICg zDPN95CVvh^z9kS-DpU8o=8Mq1Ty-R*<=(6;E@=NzS*&WRIT>}V2= zbdAoR8CfcJog$V@XmyN!B#epd^wW+~^|ZenPQj{cC!J5z&7nnJm|a(K`*Oi*>k+;> zF*lmiG%-tiwECBzU`UVYRaVsO>yRMvfW*v0*(UvOZ!7&LkFGM1RB!K~?tWpXdHhzX z8q@_VG}L0EZ=P|wDK4Q5y6`2BUh=vP1ws@;lkc7b$P zFh+#(^(+<=aCBEzYRui&m?a8&oY#Ty9@-DbBBz)M-R)<5Dn&*+jI`|8*T3gz6SPJl!T7edoNxWVoS0L~EB_#o!#(h>7nFFh*G=i>LjnKVWBf4X=p;@`n&4eR84bi`01n$Jt%B2$qpSxG{< z?=xkB7qoh4{x=Owur11^g49Ad`0ZOMBjl+7NJOvylDl-#ShNgAB$epzC`&jeCMc6+ zux*7DD;ny^i4ElXj3#vwsaxL8$brM`knp(uYHwa$PoyaQ4#LAj%Q6|Cd%VVB=9FAg zlE*n{po~%o!zuSO42$I%^CwT*1z-mw_QOZHAVLiWNz|0`h)Kw zRj8xY$EHm-&#rSbB5t$pk@Txrn>;wTDd|o1f6wtf29}ByVSq~g=d6@ZUEE$ixVN#@ zy8|dGBGSn@Tcd(h5u{8VvsY|P| z4Lh{!O znpA@X#~TmP!|{t>io2eckzzTpD}Rh4spl{ruk6$B7et*w_2Jbzumk3Udo;GT}zws0-(;Z&!5Nly zl;L-^0t!i<<;NKH`hdXEG^_z*%R=3!hzj+7JbTe0eQT7q3E#k>{P(qJ9)k~CGw5@kG%y0iWrK1l$zlrH?p zZ;#4fw`=;STCM8r*>Rz5nT5;M5QQVc8g332D{Mr6LzKQ0$FpXi|^LkYYM+HYALltR0U*Eb- zj?ZaO&tR%TIyBX+KyVS3Ldx$0yKW1-W@+RlW6G~vgsU00RJlrB9ZKSF^ne`x&3!RG z9X8KCoD1x!JBwQfgnR(@P5x?vzY*l!YhX$vy zK?G)&hh^H2cgndih{zAy#1HVxOi4ga4hIk6jnLjUCKy&8mL~wn!%5#^d9nqA-Yc?+ zvM?}`$;#u!xe!@}G4|Ftk%jGI3}MxiL=`*7g8+Qm5Ds6cIBAyZEDp9-gYphCHOmMu zm<&&dT^7+W7BjqJ9Z)i;$E#y+d8qZuku@q1V$Dc!|%0P@7wXvJO0cz ztlR`IIf|1UQynqWOv?CD%QpW-7Ws2*6|45{pKo_-ZgP7GJ7uw0p$&k`4Om|!fh}SJ z92JfooBj94dpUk5^uEX%+&AQ=v}Dg+0j$LiUY+p=(fslVI!K{SPn%B;+Y+q`z>&87 z%t$Kx*ZetTeVy?6cP4O1>D=A7JJbd)h)2OX%<`!J<@5vUhU)E_T!u3r*d!0NgfiK@ z2!2w()wrnstX<$vqB(#1&F?p_n?CdWQJ84EOA0U;NqLtwwyflUes(30av z4gpFm5m1}m#z=GwLxHKQ1A&fOimH~E9jk!;AmBKDiiDV13t?Jh{3IYKEq~TzFDxPv z0Zm5tl3Oa=GGRhLA+4lU5)|3OwX?b5OgMlj@B~#Bj*ex2Q7&ba4g~;n7!#PK%$XlW zN`lk-Xy;w4=GRU~x!Zt)CWQ3HAb@E`po9XQGRKvNAyeEGk;gfrlxRh24BS~UTuVT& zsNzcmOsfwj0Kp@rDB{!(s#9Hdq0Duz^Lmn*i~!%UG1J?t-kMG5F7Z-fXKk>BrS>1= z02CV3apfk6MP?nOqQ*bO`JBetp#6O$b|E=`6dSGTi24)nLJ!k5>RiS?h26yCJc7z_l z^2#dNJ?edN-oh~jA&9W>Be!FrP3B3p?l2cAKDK;Vm<(ni6vU2zs8?iQF)D>>i#fsQ ziI5QEXq9qaEKM(Y_uAT&bje^0J zS8N10P*G(|Z5;+OtZ%l@Nx}v(kTB?Q7^7v${_K+YsGL$)@Ae$bwsA|>h2T$SbpXKJ zkn#&@j^g4;X;jq?>hA<9B(f@5>Qp;2BD&}2rwR1}PZ*~b0XC`Rf0Cl0WFQ`9;T}fe zt}_z%J7u9K_A7ElN?fR;mz8DXOvH|s7(b*J7FuX~IYTgj?hO|ux8on>&Wf;v;Av%v zs@ooTt}~PyaAfu>;&4XnTh5S#g1~xDpm~zW{jbenf&NiQf;yBK4xoaHk`~s8EdUVGNay!+syU|v6@Ehg7(l3GU#dVMFq}=rWgZTN^s(aA!3hfX z$ijdLmY=v^ z5Ry~klXHtU{?c_5WS#pRCWkp-Sp@PIRMMlYpPhUsh<)$FO}Cd#-%TPjpEvxns#@8ubV4Wd`zEp2-4khWqV?e)(L<^Jo7lxF}JPF9gL; z`{JEGTk*vc3?f{JXo*~>gWwAW>e!!(@l(V!MldPoW(hzOdr+oRZJzX(pTO` z7}!V+U`l62l*JnJT0Wn$X&3AF=v-w`GLog%i~Goa1&yIovF@h3>ImooP1+p`lpPQO`%&`)bT56zX?Y)1_j+E|b@ z*6R8j&ON+DC5B+Ohhn2?+`pkj9-jTChk88x4juw9;ys3Ej2%?3^bC zOE?W{kCeQWa+j1`_-kgRNp)psM3v*#<|Z}|aRN*c$qcr{95~o6@}}$%ZnlSBiMj*3 zCg`uhUGaqi7yeqoAkS5f-RE`s~Bu&ooHs3>fE z0@7Av7fkuzXf27X!71GY$^gT>8JNM}30MHN{_P%n)Sk<{S@&}j>uf2g4f1|wz_ZKI zbEIj+oNZUurnS|}nfUMLpImgg5sRO>v-3hLZp1#g2jcY6j;;UvZ;^%fp{d7)uk~N; z*MqUGTMG@&@=OPvj3(t%Z5OZ

%fOaPbyTQv%HY2-7b)qPG}C#WoggynqK z>2-u%GC5`y!a(w)lf&=lB`cV03)$3#^aVtsydB}@_JIjRP!MBY7!;bQCtJkv+H%hd z)49q%vR?xXz_e0OG-0#<*ei`z{=^4D9b>~|p)DM63?w#>jFnce-IwW&ra~hp7U)Wd70ec9sr zgu0fePQOrpS3Autuh^>?kb*_=As6=L&Bb>XsK3LD|6HbLjK6Zid7sqRJpJ15znNW} zyJ%u`;kTF?^kcy6$ezRd(EC3fPcZ5;ejSSHbvgm>IdwuxYQ#H$qKV%Ohv!w3CyG1LSp@d&vIrziAeqUrq<(lK{KaWq88Lt^X9YnD5S>^H0 z{lqI=VituMMHKa_{$~6&ry@{pYsz-o$&{Ntbpz%GFf(ZCU^6!Jkq_}z;dQ5G?*H6~ z&Xx4S@Q2Yp{i2T{rz2XDc`Hdg|L;Kd+q%)qZ@+T0QfE8=E@C+a9<2Me_SAl89LaLI zdpyTFp4;LfvDbSsHeY_dM4^75-qau6n+I7eewWC9hx+mSF&shNt!GWG{6M)bj})NZ zKK^los=qk-{QK#@L$O~dl%H{R9qNZ+kX#`|_C%vT)_#E(=rBgCn!OmJiF z2dJt+<&~{-X0M{-Y4XS3%Z)oPCxE0uF&SR3MV;0CEFhQl?on=FKw-cC{NTL6d zCEYT9mR4cxLQDGjw3b}lqjXLJM?ENUJ7KXktK^H!%>Mo#o;ul00IankTHsE)e;iiX zBUip>pN5@$O)xD6zc<9g3w2@B8SDy9H+S9SoBhaf`gW&_5kb_CZ1NC4uzSnQ?tC$dFP}Dy0;FoVN@lp<}^1Y+ZO~6xZ z2TC(PFXHr(H~zUmG$rfN9- zW0B4!`L5^R9m$NcT3hM^na>?^E(tvHThm>1z1G|-Oo?f@T=3^7%;3_d(tU-GYW|uC zAv|4K;}6-0Kq5PRE~eN6_t{=T(w6ur71oH^orjNI_A@6q)?WiGpSCClRZXs4- zC6y=+r7d-1c4%!-=6Lol5f4yoKs}YYBsFUmF2^3HOv^7SQbu~0@^mv@%vq=X`#|1B z{wU#4ylKyh{Zr9%4Yl(NIm+mwm9)sQU_-G__MCrbFVaMFZS6gV{6W#3jn3!Xg0AqB z&`YU34qub^DPf)?OGnXRS}U>`NJlTxt3#C+`10B|*q-X0Wu7Of5Q>C|#yzqp1RGf< z`cf#=rjuX+O4!rD;@dw%LxDxf&AB|ZCZONoq|gQtnh!*4gN5nehWuWJFy0T7ceq}l z6Ln=s_()ZjygiWgh5U{}GXV&#vp~X64#vY)akxiF%h;)%YLyL5B{7i#3T@fAB>Fl9 zuRXJ4Pvl(PRi_3k%qo44zkWB;lbbh`i<3&~K3b`&8DRf1TiK`JJ zB9W{gxgQwNZz#-0ZdJF5df;?4-y%41`t$y4!4unm@fFXTXd3DIGmsc66ZgvjANGC? zmzT#7y)7{Isaq7h%&}!6R0q1fZO)QDe@*;VJlNnbHBUnIjz!4hn3p)nw$MCy;8nQ_ zjhFceIMnmJ^R~pSypnTSp&#U-x1n$BNsTRkuj1UA8Hj{bXGX0wQAaAhd+j8vTf>8^ zl6-RHo-jo+ny~oB77{5M&oIAfs{_5jkt6xdDp=O$p8xf^nUvQo>h4oqxVDg>&IlS% zg-IdN84eUP)TaYsQh==VwE(7I(g|fb1~@z=0fRH%UYG)*|8!5d&-a5vq~?0?{P}}^ zqX<2?tA<~G@^Q>7Y^F_#27Qrw2$L%x?(O38l&CN;QX!pZchJ>wdZoWJ*ZibCMgy*zL%@oxDGOS+!qJp%!iR~6Icz&fjcyK?s5j9 zCq~%c9YyD(C+TDq)cRo7Q4nrtU1@s}kQEH~!}6?0AK7GA*0?;(KBDp@tY4Aa976i8 z_T^cp- z@44&T4;#y}yO!B@{tzAoaL>HA@eXEs%22^MGUrhJdA@4)f=5mvsPBBd6>X1mKa*^}jR0>&(!2C-OcC1n$pS+(wKGTh0S3pt zj{p6|Tht?n`x?vs_EzhPsa_bS*!uf zZ6GJ#LRhHf3x~-_xPZYOQb&9fNrD_Q<|9$OYVfmId5TYgb6KHp{S zeEAsvY{({1qTg|bbQa%jIL&|G9ie*u-nCP8CNeOR21(95eEadap!&zxtdTi-at~DA z&)Ao=`kExpoJausugm^^-Wxp)Z)%?lX#xlcU=giIKuPWxiy4Lu;N|p`X^LiSt_+8| z(*JiYdafq$RcgmP)s*H9{YcVlZH4ErOM{s%rGJsXQ)5r7YiD>3j2|tE{tF<$w5Tc> z&MajW-?*3&NE+n(>`Y`PUp$4Cs_xlW*I544$Ix`FG>Pk6l(thOvm=EnXf8}}*!n4D z`e#1ca|QkTZmxoMXMQ!{+WV~-)k;>K_>i;_GRGOYc~B5b7u3-+Q5K)y`L4M&gk~cG zh58+AFC+hI*uCfSh=iPDwU&&jmW#`pr<4WWiq4NY4(^NQQHClm;nM-wugJS>W`{o2 z?s$5XxE z#`pMkks}Umkq%@T4;DyQf&DWv^ngR-S^Ny5fP zDYmhQ!xUgo9zGO?D2finip-K>3-YgYbhJw*Mkw|V6ZnPW5cEuK<~Fy!nAA%kSYECc ztfX?^_{jF_p0prdky&RWm=#mSBfw7U?f4>hr|tCU7< z!n64U{wi=#rCL3{-0K->VgZ2BAgiSy$D2a>H6?A>)f-tPGfv&tnU8>IWRa1{A_I^g zEN+RANen;06#hSsuEL?|wu^3Kj4=j`9!TeaA)`ao(cO)pw19wgi88v8cBFI((k;?G z>3mI6N6~6K9FSz%&`#jIBb576vvs~nx89?JL!tCTWepEQ*(FAi4TKX! zRvG9V(MRA=&PR^IOj993C14h}@bE-jW+It-BNP`tnUw@8#Y+N#6UFQ?wHGhGpjl_Q zjFi z=&eqVk$_c#7}ywV_Pv$8akBM_q>9K%7&o+}_~($ABPFtW!RHc={}^T^@a|OR5I)I~ zBWS{YXVDX+(hNzXpZ0!U@vo? zjstzL`EXz#O(2+5E~h2D9O!LAE^<0ij_!Cn2@fCP33oraEgHsrV^Hnyt1*5z&g!4P zDrfpIc15+0UK$=HI4P;eEJlY|ua*6Ha^Fv`ttQU0&sr8GTSeM07;3XVWup(i57!ep6+JMg zTBFa}c>cA1N!1i#re-&*?naZ9urwfh@Zb89$k+P+T6v~8gu|{D_heQ-H&Q8O4n+|4 z5U+O&B_UCeLsgU2{#r}eQz=L%bqxtSgHjbu{($3;mKS5#WOC~a2n86akRST%ce`mP zR(M?0_@h2|oJBSj8Sh`u`d_~1Ksm^L^V>+HKY>pAa;<&XB6Pz_2xX8bz6F)#n@bl|0E$BrjN743^cX8FfiK{5~%$uk;PSdgk%}q9cVw1~XN|uJ43>dEb~Q$i}c= zn2#!drijx)sZi>`x-nHO{=aYj8dxY98yIT19p6NJxmC!E@D`9UB4Qp&5`(42MX2-# zrl-+YEHohHF<4Xi67s0AA)wcn@$mlb`ygbQ0e}Rcn6PlU7z^OcafoRtnbxsyuD~D1 zI?W}1EG=|$*`6@^5o_LaTaSRlirYSNTgoSQpK|zTH1^|1bx>fJ|8UpYP}SO<_~D+TC)m2rP(P?qT%=G%Npz(+|EDuj#Q625jj6+2<@tOSb=Gg48^ty zB+fP`-=h1~Sjmv!|DKFN*8v#$6n;1)2A~A6Gm!N`b1GS?X!(6E^Tz}w&6db&YmM>D zhO#}hg2;*lh_mDL7gnUrx-13#JGz@-wD7tUtlocAdotvDs!zf0jp8DDML@o(c6 zNGf`B>(!$`RRBii0gADxcCn&gEbLnlqrOo6r<0ama-(cC zq`Lr-!SIvVH2&iA2bH>4jDaB~PROIr-M-Ka1i}StYTIjzhA43R!2~|WD@w1QRcX#4kF^xu`}9$ zLhqe&19_84&E&9saq3pk;f*m`#dwB;2fIMmdz72MJh@0z_y&GIt7h11RsQ+t@n1pu z$#%oNU0YV@7ZweZ*q2w;F>UtA$zl{Q)QFQ*!?fPsGw)ZGC$)Ei?I&+aS^w-}i7mqC zTu|2mA1jMq$NZtZFL9T6W7^cse39j^iZUJjeMu&J=*SvU17q(<^Dc* zU&MoHZNG>msX(nu?6JNiWN-@<1U?P`61NF%xS#>*Bdhn!7%-NVw$Z)pdglCybRzDE zJ?h2bxt+VYn9Yj|WIn!RBS(Y-O_x^i>NRarctaFt7R7LL=3dF(tpK6Uu3xMbN>8kO zZSOd7xfEeV6ZCZ%^6q9xs4iS(FI(Imjs`c1#Oz-_mUHnomm?5CdG7<&PTPE@KAgsu zzI(Ojc8lo|!$YIW`RkaUmb$SN3fZ7GMfX4NvPdvU==~VG`R|@y&a-#OGX5FK;dIBj z0($=c@QgGsg%m7?pDQWGFecOvA7Ye}FeEqqCQ@AqU9y+8w^Y1K>6hOINZm}=Rc^SI zB?1ssp|u6jJOnwXpBY;baYTyWG?i)ff1Y{VXH%&();`L~dBXeW`PU=WcYg6oI9!K7F0q*CP-d_4hD(@`FOdJBV?ow>aE`J*m1&ULXR$>tegX6wx+l=(^z)L{b zZC~CP>gyV9SE$=Fj&U~8u#tyby~)|A;6GICKnZOwO7br6U07m?relsWHE(XHRle^$ zy|2Z@vczS+T3bN1x|2DNMKD=#)Rl&h-Rvz~kpJFW5A z{irA`W%FT4EN?4z{-d(Z-(E4(=L6od0$1%cH^XbPj{*>TgSvkkn@t#!8*0XCtXXqPpmm*R zgek2o6v&EBh(t||fFR+rlr)g8=ULkRG}p)Oa1cfCBmi; zC^UC(+dgs5Bg^m2E1Nw}G;fsC4977)aRq5y!9%g29<-3)Eq3ByuF7D)@?>Zbd}y!;Zllsy#{y}>Vi&x8+%k@Lo15Rg zM7{buR$VNAL-QP#`#@R0T&xrK^d`rjvA~a3H~?j4dH0@%_i5(s<+-i!u`fSb8tP6; ztp+t>3JlpOEsNcAS=k|oEN-Ef)@R$UHfw~&fK28sDjP!*ECqB@;mz24IpR4X+PuNP zai7Thq}Ic#Goayn+;?M5FyXFBchw4y@0EQ5Myht=p@*N%hhOqmPR){A?=>k)c6Wa~ zaB6?NxpYV7kHqVvwfwU%-DsvC{`z$^p+6QKsaq6yYLLY?uGqs8Xj39J15cX!@@gK4ql+L_6f!sE~f= z-1k&A7OKQdG$upO7}WBlFnGt7xcg1yPe5_MjV~O8H{Kc>SI|R9FsQZ1M z$6;>D+Tqe}IbMIm86k<2II{sVQ=a~&n4yv7R@UM{QZ&87e?wl*qKK9*fvVPVH?D^s zaxxPp9;Z1SKMc+7*chMC_=@}(&dGnw_>WoVBulyGdohK40B_z65Mz2|YKPy9=8MLm z6|v%jh=$Ku?`R}NmxP{qcYi$j_z#|ZGvxB&ADs|3OaIY1iPD;MD4R#;*T# za($l};_c_1dh{;(_XsL2nq1a1Yk!ofJnE~n^)MPH+`)Hk^OQiw!2dZ_?3+~mQn-21 zEikCK7OvJ-v8Cwb#aA6D&Ua}%bfm3wyWcN7^_4sWQ(zC$vweAZy)3W?6J zzFA7d<`8>&DDL&Nx52RrmI`QlOH@m)mRfm+s!UUMZ5jVhuGSK{oujH@fY`q{zdGKH9UaiiF7 z{wR}p2Q`BdvZzNlWcYnPFPVl;V4`?WiD@QhV1NWL4XTg?2RQ4;NZ0& zMjv)GHsNnonKP>9&F(OU1$A_<2=<=X=xDUETUpDyD;)YFDn_ z#(N!wQp&mUAdwqUQ5NO)lV=a*AMgs*I^H+wJ<-XpW~cOUX*_ToFW@{UJqZ``kQ4pa z`hD1bvB5WcA;EDiK*w|=2wybooGW!dfAFuJ#+&um-P;2C1y({sx5Y-|EBD1_w*Gqu zX2;`Gxl9a~Q58pfDll|D`z~>>TVv=5iehL z9kV;OBf+%PO_1v3@xiRrh5ZKthw#Tq(EZl!`Ya!G0S6*6zho(2(lk)}5x3s^9a$8z6)9;d* znBF-j&vWum6P368Ci(8c$f4N+PiZQQoF67hTfrZU!#rkV))0#jwqn^CbNfbNVVNhg z#_r#IqwKeV*iF`NTUY;4ppkVgL0 zilpvRd(E7uSwT@az_CVswj%S!oroEYa1waS+ngjMmdee*Oeilm&P=uH{SV(quJ39R zh&w{W^ToIRGr2H0R!X&i&!Kka1^XzaxO#=o`dFk3Y+E*a5<)KrRgo5pikv9TiOogW;)hjzc8Ouu;d?&E3ior;n5 z=k7;RHUT~Fj`RbpKQG6P{tUavvt=!GYl4Grx@SgM>n{yo%Hz!3mLKr%T%k9&rlR#m zta`rS=u4a?+rMet6Ms`_6m?>eztCMP{3&}5@O66gax2S*?k$i3`nfZ$}jXH)Fu{Vprkn#fZZEG{zAvL9dX%&(I;0C8EX(7JBKxsSz znv0rEVx?w?R$sIKV;ZW>RH2Q^WCoLUw)R5m+v)y%{wi=Qjpwheg^_x?{fd zt@$q2_A3=NSPwfi>yNz4zK2MML&*7%qOtb0|7#li!6y<&*@eW= zN6i@MbWyY3OnmwpMkbj-j?+GD%r%voTsZe=wI>#GG%NP1Qq;X2kA<3F|2=(9()DBC z*l6pa$SF%=O{<7?;Kk7|t49W`35!x}k}~!WG$t!|gN=HsooXG$pdV|<&)?7fXYo(U zZMwTYztUg4_i@4B?ktT-qd!gY7j`uy=WGq1=t+dAn8@HcC#`&Prj&RCKhtgKyQaOCJ0fM9sU zMidWZU7h@{6Ys9_$t8xOcEcHW~Se!EHxk2YkMLyhnNAJLg{okiTherL#+-Q zhc09vzW`lT6?0;>M+&h>90JalD5j{$!+k-OtCsqd>AO_D8Lt$ zHJmVJ;xPLKW5qxg4*)2)Bf%grP+O?O{=B+69E4rRSwt~WMIm(5=)_X1=x>PGTEs`# zHHsy}cuOOqii3ufc%U)L&_S+d04xc?02L@rohg%*y)Qw6`|rsgrcs5f?~87aET7g# zpI)6T6B|=MvMn~I!Y(ZzvpYQhMYZ*3`zq|u@=;)W7`2b>qrBxOZ<;Poe-T6;J^B3U z^BvzmZynzFn_pdJ-`M{}WW2h&AR-unvBh=2#Ybaj6w?e_cHCKMUEN%44P2vv#O_T* zxQj_rHPkjk?hWiKSLY4lKB^8oDV~=lx|CCb21cN#BaxPfjg1lkpdu9cT~B)F39WTB0v3~kqzyCRtOKzK8YQ<>wUR9o zoUpN%`E|c;OykWp=IA*aZDOG45K)C|hCS68K&c=CEK`TYOhos^Ew3d5-r#=&+-B}= z49z4n&jbCdkOoA_r})hE2Uk}++E?2|$b`c84gJer1+>l8d|2bDpWmB-q}UKzRwa;+ zW=!7<#G)YKh&jUx|Hi`P)6cw@;@boOM(d% z6O#n;)4GVR6I-#VB!ee`M4fV5Ajwe;qmgnYDjY4Z&yO~>PW^>axy>O+t~9Jkq_gEm zQGx(J|1$~oa;x71c-aC-`N5nhTQ!guFv5!g&k}wfZhISwJ05M&f(ruPASn1LY$3LH z$hkSiVZdVql%Qy!GQJ=IrvOU!7cnN0;k*XEkmyi=OoloFbwUgkwzKO9eG$lhM@qP> z;&Ckvpvcpf5#mhEX<_Hbi{n|wKmnKzAW_#60E_|SLM4$gTn?(QA7lr3z!f00TsV;{O_0Pe@LPdh|}o$Ml$6FcMM#59TiT6u=1( zgU}Dcjw3M5WXj+fb}HB@38$k-anx^VWDm1l<%s~xr(DI@Dkw7O8LseDW`u+;x)haq zn)qIVQ)^IJi^~BzB(5;r;Esb(&{bS@@gLL)v>8^&ki}qP6vJ2Is7b&vDAow2_-1~j z;0ikxXdT8~!$YP?!Um62OBv))wU{jY#pncN!KhFG1uC)Xp{}{QMwhP2Chj2uX2VCUGNR^ z7<2pWTB_2XTMQL)lSqf|B3%qQ2k%}sx&VsU%3CfRc7L4}1H9a!$}D|qy%o- z5yfK-m&ct5kYkc66I@6L?#z)?OR6JJ38=J;5e*)9IFDIh#YYuqh67(>6{Dc8c>dV5 z*ho>lz4BUopESNACOFPE311x{hxJRs0V8IDb4+6`V`RmURDiVzSphsnL_9(i4+R-U z6wxqq(~_qn?eHW~#SsW9q%etV_!tn6j9?BgVnhEWfg4Im7Tlmz4YqDr+32n~{uOdvr0Twt;g z!6JP@KMs};8iYO5CNhtC?rom;b zKSMBv%~63OC25HdgJbn`ZcZ{{)smjJQNaxi$CLd0Og!0p2miEh9Bp`(&m!G1@nK0u$Kx@tGB5`M{!T52F_oAMCIE0b2(%x} z79*Sb0p=pEf$R{BWHV8;T}w)WAklbEAsC1r48)b*h$kQ{B-UgXb|s8S)l+L{fe;u8 zGq1E=YfhaeBq<^f>j?%UvAoc95B4@5m#X2rVoFR7fA3@Mu?Xc@$I%t3weS^!@Ud$d zf776!?U)c%{`##Q#^<|*wt89W0w)a^m09}^gJBh-SSQVwx+>!0E%6=PGm@kt@)qVL zWkX^(F}4Y~xClK!3=I{T*aKX8xv_50G@F`q9%9Y3~JTonAV}DVXMm zd<<#}BV#U$9NYi95yG?5{;fXbZFqmIRgY`s(cZ$VDm-R5rsrPMdEQN67_cM74H_K@ z^fG`#?G(R9MT8Ib+e%~?mjA(^q6UiSr~%R8wteBzgWDXuxgO7)mkh^Kw@G2h5By3b zq+oT$ZAwPGZJhm9d?k5>G&q$stq(+@s0d&kh-ijT66i@p26XUXDqBtq0|O8bK2^z1 zOpeO}%Zgde?~5xHr>E3Pk{DlcC3qy?kO49djKx$E@%3tH$1PQ46i^Iz|D3Mipu*-V zi9u_rV8A$(R4|ZJJDGn+Va2ZrsAqyg%J?mhSSE{%zc1Dv2 z$|+BGmmh9zrb(WBv|n0Reufr|i3U|aIy}x&X4>qnPpu-6XMWvFU<+Bv+PA(`5Yzim z^w8j^e;*xKK&|s=s@e80Pp_E0*Wl>tFXmj-Gc8 zmDcHazV_ViY~-xj13bN%bo&*T)Q1vH28o~>KdW8a_jGz!j$+f!_ zZzY|TYMWFV5L`GnY`Yfwy#-2tW^G+{m57AmB=70_!qo%DsLm(GI4}xF9*E^+7qB58 z9=D!z63cNyE2+Jk086e80flNYU4_AIEb|;W+bUCR_jDQDy3Qz@wZo3;`C`0^EFzUc zHcP{iwQ{rtY=&{ev_D{RZrLe&3uWY~xooX23I1Ohe~Sdo3*CQ{nl$h=*mw|n-gYnM zNUk;F|M{2#FD#VEH8f^^XJ0^DgDQ##Ca2}S&FV*A{I7TVlJCiaOZl%bZGVMg*OIfA zYDUZjAWTtH_Jg^v%LgOFZ6AXl4XrLByouInM|2&7B`HWtJ_GncTA89nZSkg~cF$7! z+qBys!tP}XzR=yeMfngFx#H1wf+jG7KG+k!S2X##-*YRMJ>h*~X0l(r9BEEs{rv6o z(vI%`=8OJ~tc{g%IA_=FymUcmzdq@z)%@o7Db8`_6{(AnQiGFl!g!|%m?^C#W>x;D zjY!?q!iL)`8*2*_Xbi&wrdD z6;#-Aw41UP6ytW3s*n?=*CNg7nyfpM4B60u-~*5A`RSgt4pk^*KzsaSG8Q8uKZx0J z8>SDwTzh2l#rl1LegpZ9h3fvU$VnV`kD;<@6=E4r`OaiRjN;Yt9i-j zY36A?p2pDqW2N+g+iRkx$JH-e<+CmJr>02(MQf}PA6(Daq!A8*Va|bWgSu~umH%7S zE%(hmROYYyRjbzSc!TsxubNpAaKLSEd4QgAPf9EIdaYAlZ^FoZT~ogoyu}}_Hwpjk zi>;6@^$S-b@w6#VJJ9tM4hHMZ-PSJIDw{ir4h6)iNb&|`I9RgkcS4Zhw+vYZ3uutl zupWi@>swt$HsZeawmNJQhPXzk0DJq7zb67$u@_f#%NGT*CYN4vGf4pHV`p*Nz_6fa z{N?(MP%h1sFG8h$_VA+SI71$9iHA@fX#@6gtz}x_*KHK2tdxtDK%C28TG|vKPW+af zA?@ja5XI&oH}=}*kNm!EaZ@4cm?h>ik$A57*=nG-qt;y!(8y3z*+W(CRVVR8(XXZr zPitr0qMefs`c)ntJ`-tSJyZ^^aR1odtzh?1Gjy;irY__U;bQU=ZW3`qATCj zRKk3iRCYEkmtNIOo^x<5`&?eNUq1LP`jbz=&+RU+c38*&VQyg&bj9%G<%fj6MBY9n zB4&4fHARE2?Y?aTl3dNPh=Qp?Uquqh$87l9cNnLrm~n0M-p%)IIQ~I0HpxKiD7ZtZ zVhOCA`iu8*pY4@}1neDH3QDjwW|;ItlpP&^<#3Lf+K-{s{QTvp`|w{SUl{3#jno=Q z+ZNm#-Ibb~RlE|-Ba81L`^sO26Wpy7nfT!BEmG-YH%b}XtFivqauB*V8m+m-AeAY1 zv#}uL#}`t}gag~tp;uOOLT#6V>3AXRWw9Ni#JyGS4eOh<{))$kP(3c2Ue6mX2ID$3NbQ{K^%b+OYasARsr^kead1P6$VO<%AYb6#_A}*s6*V z&k>G%ga@R<1RB7FM+0%pYM3z=Mj%-nf@a4-Fv_kh#!D&zY%0e7FH=apC8& z13BS19FkTP$_rKY#&L30lF}~hI9l#)iyuCO?9>eelwiwPo^w(_x4Nc#!7gz$kWLPEo#e#1?u}x)hLGk z_BINgJY3bTz5881%RF0ARxpEhkJg&cu`o+UWxzn-CY$Z!SL4cv#8E!84X=b_`jQW~ zZ@Y>QfYMVPMl$?nF)TV^e?v0klIA97q>8?dq`>VyeKh7byWbt^ov;BL6Vw#?aX9Wx z!%fQ(1>@wm%UB)X8_S(6{YOdUTai!LoV8f(GXRJ!0>kbby_Y35YL8aU*j2T$Fc!A2 zj(C2^dglS2VHTNt{qAgF-Zwqn9#5=NppO^*GVuLomA^Y3dBj_fQo>$SwBWteWkgwI z?xfBXjkN_Fl2xMS7fkY3S9b|j&!BUJPx`(vG}KmOQrm0N?crkH!pre;KQb-!tBptr z(nuBONlMJ3l&fUHeL74+J5Dp{f>9V<5+#f>32=4_0>)~Pfx-(+WWX~$f9lcedEvn9 z&Z*-VRmlbS;~cxm+Hb!+%$xVXv+T-+!g$UGlmS2%YYHLnfZcqJV4w?|Ch(;`-Bm2j zey;vt&td1gaS%mF^50$-V!a$Rv?t%)<_8&^yd42}cF}#a4O%uq`#pv`-GLP7f(7rX z3b3_mZbir7O8HqDJH;2$UDKltE#?hEOBNOO4*Z2CMeeo3bDsO5b&a_b`)aSolK?Cmhn#c=oP@71N6uiM0v zCvOVtn^tH`u6;iw?yD1hAKtatYIF%fhMU?lQ4*65@Y(pp3|GG$6X$^tf00B@O7Kc0 zKbPr7st_TjPX%3Ml-P5qnO!vFSRp7brvhY6RhAU*bk}qJHz4#)_kHW_!k3z88Lc00 zJVcS(`UT+fOj+uh=`gmerPSxXUW(kc=G5*+LoBBvdq_U5Dht|(Ci+pIZ>WORCP}w^Xw;Pe)3^8?TQayYM3Vd145`Lu1|Ii>R%(I)D@?-%VIXVUHLK5E z)E`Zn0jMM5aNC&_U3rgFyQ!Zz>O0}d1C@OO-XQ}%?mkiWOm7BFwmL!Otn*QX3dlj1ubR4;ik9wlGMy9v|%U~`$1>Q2FR&)pqL^8 z8Ob?gKaW49sg4<$u~W!kWQmBzCW71=+qS>>y0S)(Y}5_bU^epUFbgS< zIwArr4n=}!m@^Zm#0;PUgjr1p2mNqDpGu8}GiOXTE#$Tm8g(&l~m$uM`w*$qD7!k9+0fLm$>8SGwBj z)`Nqfa9f_U=wzD4gBDYNMXodHx~thvE{-+6BqH#dm`liK(ngKEpRD3BC<)&g5Q$r+ ze)-E>k?}(ITa8ep8|bqlYsSQS!qgj+^IRR)Dw6xLH2F}&_{!Zx&lQsX#lFM4HGuq{95PkSGxZgPY^ zuD{})T@;xDzd4F=&^x$V3tYV5fBU!c`!$cmds5>)deRwr!W&j~OnJK2=r6HbqP zfntFF->K!i*Tr!&c7$9bK2q+)pm`8QqN3{$3{uu-1_X=%j-WU_%^pCb8tRs?y{xa? z-+|xF-tRqrD^h-m0jnXsTtmB(zMXc>bB#qOZ5WxqRsL;K@3xrF|L47AYyF3;G0$g3 zLdRgfi$Jf6Sf)YmSXPZctKEy$SGU!DdNXSa-kF(Qd8l>yhV@?bpbxHI8LoZ{EjhoR z5SMUNp!q4G`re7XExGHRw4_FxY)Fz68hn93 z*Nq@RcKs+T9u-F*Kq&l;P{Qbhsl>1SmUR!6Iy2$a*4P?S?r7JWUP5|KW8*;E5E5{` zB)I*qp3|m#?YM8b)fxl!Q_5VH^15DqZyt2QzMPVoMxGr%S?Z>FYKx{9-z&miocFCU zIb*`eRueQKyQ;8U|NTmZ*()B7^_M@r@5RN-XbV1u$}^CZ#V627_bieTC4Xq4Z?uS5 zMPkh;f9o0Xf1fk~c%D+f0b>7?nnfVv47Ouc;y-V9QPJwy5wVej`hDV=Uj}$XL{d?5 z1CD@*QjHrAJU3m9TO+W>f{GGz8rI;>d+I%{Z9*(MEXy}lxI1eF;FWAy-ya1;t9WR3 z7AuA3Mk=Td=IhS+Py7i5uSLn>D3L4fep6NBKJE2hAC}sqJ2IX&?fp5B4UJhpUb&Z( zr*pcMO-H*&uR%L963B9k*=uZF>}$b~%U)yc(C-VoSMpyUau9z1H0%`AKKXpH8z1}Z zGJ}=*p6(*CXGP>K6i@VwR(n5f`uNNTrtybId4si2>a-^ARCDhrOEKr9_&r;eF4;R_ zP6P5pB%J%btS6l=Bm+|pr)zdmp%)Sl0W*MfUD{Kj`>%_pFQ zuWZ_3_3v|-PG!Qbe*W}Kyyc&G_IUQDaZSqNK9ln%hY{3#?_@OHqA$x}79__!fi8(FNFTn#y9I*!?t7lP1idrOjEo|q6Ylj@iVS%A~;8lbG zxw!Zx+uZ3 zc}%Q1H#0u+z$Ka?6P}LGV@j^js^;ZOZ*J0ScYJYAepbw}cGOV)<5`Vr){Lq~zC*yO z`-I9~L9eEtr0%{EYNg+0JaeJP`5#|R!cuNw_>^eGM&_kdRI46wM--T_v*>>LEeZgk zqGI9`;)w^-A(2-2mOh)Y0mv|rKsZNWcLf)=NiLUB13%=&NZ9T3)L2@dh~~ z&Ap9p1I*rtR{G*)BQgsk9(HkT_=uQvVS7fom6I2NhQlLs@k3?P#k;}Da&@OY*w)0S z%JVJ1N@vvl+=t(W!>vprEbRr!bOz_*RVspiC15EhSOe{Ua{V&>8@ai(dtF};jFPB) zWb(#<7Vj&@i)2OKKlDi8)7Mw~vfcmRqqdKZ#!F7~OVyvVn<*z6UUVC|{QXdJ@UM!N zCAKU8>7mfG67Cb@U68z(IX%0Ot?J}2cv5E+ldXR@rT9Z|wy8_y&DnV;xo)}tn%}KO zYX{k}_jbhoS2At#ubt9R^|Y6{xV_M5Pvj_Yq^1u;WRZ#hVEe1D2p&;1Bib<0PDDTv z4OuI#OaB2YXnYSA2+n_gKfV8;me8nuFmpQ94TW#bw&IeMV~;pd7vLTn5Rh zWPsr@fU=p{z(=96+TB{o=p{e0%=;Db&X2h1l-r+~v7)u?zc?|LRviv6$f7~QplP^R zeA*q{xPzZbHEkV%<{l%@nKQjYo%sjZGYDq$97;JJDx;4GTe3g*Ond8EaQEHpq~Y)u zAz#bhTq!yynF5OAQXYQ$}|3)zL#qCwizw={In+Pflr9kI3z^u*DGEd~oBPoe=O zA>8-O245<2zH?~+<4nqFB8sGrkbq)SOcerA6=&3QYocX{+g{v2I%Gg4kefMIwy8^kE?e5I01~ebz$hBDD*&R&hgS#Tg^+u5}6ZL%%fmWJe~$6eU~#gS_{PstZc~B)eqn+&|NW3S^@S2BkSv0CA}4sWEP(F9du!HG zTPyM`&tjicn+UP}cD)>N1|=KMSY3^LCia~FJil%{U;n}`4zJMuf!$Ys@8&!Xo4j@T zFy%pQ?;pYC(_eS5E{^tp5t@oKe6sd{|AOrtlgC&|RX1qm-W_Oz$iNBF1272OjZI9e z_Jan5eLX_81`KCrkkuX zICB?=_?2bKaMFNutjQkghLlndHQXeDjYP$fDW7p*FyP@0I+`jF-1V6d2?mcG0NT=E ziTxJ*P>(?p#1XP>KHjH@sStRPYAv+{@kV|L`tv}v1GC}Nnz}wh&)UxlSR1#D*u1sq zt{2HFVbm8kd1B~YUbdE+rY`a1-I;BR7y6dltG%l_2-V)_Nt?gHFT*ZhvrP>Dsp$QE z{W(v6j(*S!|D}wGk+2s{`abv^?bdYmjSBW$x5f-ou|;~*ZS%y$KZ?eb z3L#wk+HT|Dq7!At6F!MJg(qghb6gt?zRV0U z`yL#Wlafb987j`eYrY8?RS#PUyU#tFv=6)Cb5PeJC*YDQd{>{3RC|%LZo=JxtIJX{ zzPctb!y)U8&sEf)R&_>OS>=y*IWpIBATt0pt6yPTW?fGEgp{?QA(EG;P=lHY73j zBHTsewv~KJDg8%fiees-${T#nJ(!t0HXRb`S?LVBUCJ-@^(YTtdysr|&H7W$q*o*H zX17&qjcoTQU1vn1Fx_IKD15Lx4Oc^!Jg4jQD$r|B%RM8OMog%Jqt2Im>|T!tnP*n- z^X^$ABDJMKSP|wxXG2tb`UWND9zkllEuJ^n*$SzFH7Ef;dQOppt@pdwW2L7^# z`Z;+jQ_2rfZG{TESm1z-K3W7OgN%$vUQW{b?JgC-61$cbgc zfq>&`o@Y~OOq?7-t@vUTy7B0t>DDE3Lnglw@rzG5M~or%kN51y$~ZfBuhrk<%Kw zK0&%iw=VS6f!Qi7jH;FI@{md8W)LO)IxieVh2r70i~N5sDmi!<#6LOn4{zlBVaEb$ z5+6?6$2KfTLPC#38m>LbTY6UWh;1rMs)O<>f>n($&wH}Hm0x`j#?!dnZDy3_wqrt$ zaO}iqzO6v#Jm==s(3Z5ij||;Wj#}*tB5T72PjZaVs{Su%zZvZqMt@@Vpfi)|TD>iJMbC zg&Cenu8#k}7+KOE&}7>X2e5Kflk4LL@GBzNbQg;HKr|b17E}({@7}gfl8Q(qvZj7- z--hOq@l?`8s=4pN>!i9p%(&FNW@ipf4C;$o#kO{z5esECj~FD*3y+a+;@=T6T}KOs zg;U@G;F}MYRxKypY(aeN04v+3AR9R!e8mV}qm7Q~sVv*J!xEo&Pk$%G~Ng z=T^M?qk;L1e9L=mSCj9=?yx>3&7-?2YIfa|UY&B&xVWHLL>e?V+0nInE;1$N8`9qP z=mQ*G$NaST>CfoxjmhB?^R0kE^y~%(!CAw7{|;=8F_)zgZ$Q5207#e>$BvBP=%kqY z3{85oq*#<+u<7y0&24}y;$}};?ReiVEhUF^40ARNjhAiK6uy=icJ+Fg^dp(}(_roW zOYok+@y0ECO9}(w{6Z-<*Fc?lW4YnKB8Qh~rrP$%jn!c)-N0-!6}d?QG(%eTxcE|}Jlqx$?dd>$bCg#3;1@8Yz%FJfn-9g|j*mQi_=#2&wo}J(1E=p2w+=r`x?VmfT z`P3#dCK5Z|?UUDjO^K^e5$msESL7H=?P88M38ns(Lu@4`g`3NFc-MGY#onNd(S$#>4}j$r6K z+KJK;g;yWxbrpdbe57{VoNzrI7G9Ot5?>sS-i)h1%*i&Z71eE7m=w|-8{QZ6ae6}!eY;!5I&g5v`$F|TA2@LFQFXq zZ&aw4gmhZtX$eE8fv$ps8vTMVDBQj~nZsCEc(C7{Gh)~!til(0eTZ%J{Dp{Zkxpj! zVhNVoib=8Mqg6kzCziY=Z+PUEO%x-PMwEG2^EbS9s!F<>KfX+pSn`$NPAFI_EhM8i z z-B~w8)qQV&h8TwK4(W~=Lb@A0cmbTL2`b> z_y6Y&Jm(dhbz<+e*Sgl{vh^{t$cQn+_rF_-+|p-e#o=Oz-w0NDRq{qawb(%7d1Xcw zMVcF8fI*yY;N9ZGJ)`&^dVlk?jcLayMe@BAA0rZJ<{Tq;73kKUQfRbCgO(njBTmGE$i$48)a;y2mEhCrsI$8VEX*gW}XM&;YGPZucb zjGtTc?8KAT3zZSxmozwpzWGf$F_bfyG93%S*iV`ykw{ZBzPha&QEXh;So7uPO8Usi zSX$sJuvYGa_y7l!6@3)MmLP6*uK;ILb3Z$$!*J#>K5Ll&La6>4PBN$~KROB~u*aI3 z*#_Yi$`riSrosB+Jiu_^0fvI>y3BdUyL{2hw_{STCL5dssG7{w+{LtKCtkbO=YJF3^T^v9 zHtu{g^4hsEm!yjW)k@o>M@i(QlGSbBOcY7X3Y|!gT z=io8MeRXDDjNTf{W1gQ(z`+VfMlH|)i*eR3= zJ6ZWS{7P57=)GE2z;K2(zBD=DZ7 zL6f+PQXkH1xzeHnoxG7OCpWP(G25LJG5|D~lJn+$tDjm0aGHrW;zo))e2LWG9f?6i)`ODq9@pUW zKfJG92^^XSzmj1%Nf)*U$v--&g(>AMTjdx>5Evoox8yw4q9RHcmq9w!_oH-6N}*}t z8$??o?bG)wCbl8_!`8|w@|Rd;gPjG#w0K=2vWwo7F$Z^ylW0c(#NQCD-$6s$uiw} zn+#JRoC^6V4M>*9DJkrawb@);$l&emeO^ojabx2u9F}~5-@L?0BbLFYd|3=N9v{4z zyYX3Ni)s~C-oh*B2&Gc1vPL+07LA-U?E<1?uvEeiNs9zqz5Zd9u*1P?*DSU( z&Tnnd$9&VcBjR4_FGCv(*K&lck+e|@Lp)Covw*2~9~?lU!!E-yTwVB-Vi|ZtGYm zZ< zqGj>)U*3ro+!cS*)l9$17t<7*{a+rg#YBVAM;9=S3Ej3FE^?`C`%Z86FZ_Pfe&Cq; zua%aLc11THI8LyCT0TVut401Xt^sqQd=ud-*pU=PZj(3Sv4v9>*E)#rPHdh zl?nfO3&T9iwD{EN8fl&dgvR9b<7ZBz!rvqPyBOZZ z4lpH)4zAp~T{BJ@Oivj--|L6_vOm(M)9fBi`&Wr|WNFse-MDh{5s9d`iw6iuu0Fp{ zV9ASgF4fExdTIGQBHx54@=e*#T90YVCFaBD>w5SZr86?~Mp4XG6i3Q$cy%^S_7I5q z0$sIc7u;4pR`uT-PaW<~uGb2=dcOJ5N41N_zzzBM!cGoA)<)ZlqqGpBs%vBy>y8m|E)hb)mmOHa_v56d?+6w(6*TBc4v1 zpIK_ESMxg=1=3fL*^kg9j_Ksb@UuyPb!5_$iJ1^b{rcoi_Ll8JuQ#jySL+rTTFu>^ z)YD!p8MVxL{CSG^RK`vc+raoKIX5y#z>_yl9e>`9ZJIsAA?T}GVs1$x)=T+@C@Q?K zj+^*8)B+R71}{ZkoG!mcHBjmla8YEd&&3$_0RQo1jjPmTW|ekukm*+0{`vh82!T<^ zh{}~3eUh41EBer17q7=pKhYjUADW-v&ZK}DjV-)fVfevLfHwCh`^Rxh!&z(Ev1c6> zDWh?-wO8vW9e^bjnTvSXFP4S=U|38jO*4Q%6c#vA;+3eORzgxB1(nqcW~z3r;q6cc zYvF4B{LIF7(|Q+un9R@-Pd!hbS|l!ZhBd4Mhas6xmD@>qP*41uFtjfcT)9EQ+Uu#z zf!b6+ogF1gruLHAdqFEC&6W;_PWnDriP5$lPi`dQa$P6aS<3y|Hy&EB@^3v|O(!Lw zGWsSLLUvJKWVznXr|>a6>zse(iNQsIo_h5T7oOAKn136|(;HL_Ozfmb_R{Gdm6J~~ zMp3fuTGw9vfx72C?Pl)|kMFu9n)=Z_Gse~PAI3r^D9F_j=Kr9w z-^79?vp%a+T2-WVlTyXW;2oTl7nC;V|Ip2j1t3H&mg~nZ^IPFS#vRriPO`HcqdqPG zg@&=X+-8Xfdg&h4!yUbYNzA$X_Ys^^cjISQx=e1h>>2x~H(9`3-A0QbCaF$0mkJe= z_ZL-s`)$1vj>UP>vXrAHKG}?&tOV9(3b+di*rzE$Q)i?MM*{{ zPv(?Op1%Up?k@C_!cqVMqu@q{kx>Lyx{s1Y7&!mtEqoRhbs8cOuj*i>VH5XWiKnqG zQ(Be%Ew`RoT~|p?#De4DNnaY*A_ev|Crig=k`kB&jn8*JtcOT_8%*71H#yKa7b`OR zy?|AA=jCkhE9J_0d%{RCnHt!Qmkd_G3lt$Xf~wRAkeRxWjSB?-aCVRWUf?n3NT$+1 z<*O2-RvW5t#z?8S2;MBj z0?App;W*N`!*nJZ(-}|PGC^%$bQ8RtnW;dg*?WVi3VUc8I{`LZKG%OX&}^{7uy{vZ zMza}E@KC6&Zk!#6pNwql5pcEz5s{~o!N}ng#SjRsN+O{CEaXE=U|ps#Wao}1f1m@I z2~umxbjf81w4#UeyMbzm%9(16rIURFnJ~TDgyBuC9E|2KhGtK=Bg2L_z-FSHjRM4(Y0<>rgF8KbcoxMRj@=o?5x3whK(ZIB(Wgy z(gWX1Ukns?WNdBQW1FPKuJ2qU!)JT?tl6^4Em7TRMuq3|Ty`TDt}m^vxlnaC)(8%) zFbPLCBM`A#``Elmc3d+r&>JVKJ)#Y#@VGzp>*U5RE&Kzh#GF!3#wIj*$z%9R%m`N{ zy+)OQ^ZV)DjXwXUi;$vaG`eXgHvMliH?f8nL5Etn4t%Qjm@n$l--Bjg`JsYRk`twk zTh_zcJ16|jJO(tV=Y?SfxloD*fYZ)KnS*y;cQcZ12uLcGC%FC48pPtN zE@JK8FUl+^RFg}#%2kb`XfH1{5;h;nmQotUh5-K)Ti;JKzt!-|GTDEjh_<`_6^F?y z2My9>&UAJHl`6d0OoY)yb|DT6A9bS13VurDaIexO0$EvQ4cbpK)}4JRqIBl$wNu@p zo5nm0r2Z&OzG^xALr{(+rfLZEFL3C;1ewU;Gkn>51UTC?M`6Euq5e75>lp@xCA4V7Go4KwpCt`!sbSt*7<|CQ$@!KC~7u)Kv=s)ZiOA4GAjr|beOx>z9wF2xY zEz@)2(y`b$Z@=^vA=#mWzfZVE0zqc328h!c4ye9to??ZT%3^LPYNHF(t3mVR!VX)3 z^rweaXP_%=M||pEiNTPhYtv9yjNnMnj9C8Jxo^FyFlEynce&Tj`-7D!QzNFQUoU%> z$>Kt+d6*00GQlS10fz(nE1`Db9F+t6x8bUuZUt5nd`?_EVJrKwozktp4_^!S&##Kw zocHH8a4B5Yc8CNzZxm^ITJ#4tW{9b3L?e)Yg~eS>Ot7dKwp$W%#5HxgZqc#Zk27!3 z9asM7(}9;6ke%B-NQvX^A-@W)@RiQPcQ_3#+Kts1!b`36n+Rr7`mg&$9INh#j%0-* zOXikjcr(WxXcO$QSxOtI>b`9}Mm`)_X7al*IG*if#A$K8#vj*B(oD>!u8bq9NSunu zJ%Uo4>ao*|@p@o&$83t&-gIis_B{2fD3pejv0gIIaJ#>vBIF-@`w*tb;mh!Rvd&ZT z+(2R0)fG%iOVeh0P%mqKlZjBWws z#UF+d|JmS!`Z9bbRKnPlj1SsC1Kmfo>(%cw{qZ5LSxEVHtj0icWoqkRo4OQE3TlZ# zyhtrosl)!PT~BOvK8Jq}_wI`xy^kB7QRgj1_+5JmX&--m{f6GFn~BxJ{d0)jfpp;v zb<$9~o_qkYcDWSa(utURH^rUHHB%&Gk>JzQh?AfB1_GrysyS~L*}*Q}r)yu?c^3!y ziSnEQPN?@O<&3lR?$X^-qxK_F7V0G~yHz$E3e_qmIW5O4>2<`(~ z1+A)2zWPzU6MNo9b42@ym0sRV|I~oFx<#BXIFswZawgL%vf&}F6{zkKX5G*nL0*fD zZ7KCA3#xWRt{U)+wdFRt-&1h?($;J(X_p>L@6s3v!<&3LqhH(?(_kGFFJR4>Pa+xG zH^N{>;ijUCY7;hsn$~k9#uQ}bY^GiCujYOsDYU=o`leDkv2Zfo?XCCGcz_ooSoywg zimaE#FzPs<`I7K8GMqx2TMtmF_#vJ|z6YHzwd! zFWVH`lUnnBu3w15TlfLn7fr~p$IhhOjp^%Q zSb2Q9GBGrRhJV5Og1lN`o2bX3G;_uQ<;trpPRgPZk82$~g$R>swWCr-S-P=0MoXsZ zwK9{O!fo`lZ^QyoL)@ZqouQH+rdbjaCr)X#UCxIzJ6}W`e1C-8zW!ZvcOKkm-iwmi zXt|ROWCdaY@L|c;{YvdDA?r|h{y;PLigJ~DHA&{_dSrIkTuerxDY&`bK%#4@utMgjPNl&O;|Qi0)h%V3hNM{p`$yry;jK-;nJ6d zMVA~RVqDJn5vhT=CGz}F6c_L}`dj6%wtVM4w6IZ$sw{^w*@VCU=5PBe)MlIC^buAx zkp?}2s+`ZiG^!js9t^}0VU*4O7Eq#h< zY8MfruR(s@OnF1ZI&|UQ<8OQatXKv|BBh#~rk@7&wW8y9K3hW@!LReDF>habPHG!7PO>~7(NVp(k8P0^b; z6Ta6J^8E)wyLsSbd~m%Ts*i?gZ7(`I4LZ2eZF9F)ibhr zT&A-6z3Y98l66B-lcO^w zE8m;S4xcQee;q2)eiqQ3Xt207L3@WkJf9}<08p&KN-)p|y`{J0d5Av!q!8$hF(N8@ z{^DK%DslEP;?wz^fUW0@LY`FyfI#p1?e9b~8?zsFTKwP01}{jLhBxo37Ohl}MXzYK zapIEpGsq}oV(}=5>dA8okR2GW5IIV|C#Z@6enAx&fnz{C?UI>j7u)@rOpg>|aT>fM z1|2Sb?wd?A=M?Tr+wxo!H4s|tzn0&5-Y18rGS`}(`>*6Hc;uZ*E)o#!f4Zt$ zyGR#bIVkm`=7rDKEA|PSmrr9@MHEz+i1xzFRn3)s7)3zPcXIkR7N-7#ziho#cQNMz zj8;v^BrVlEB!1X7eI{GbgBb^AcCQ2^cS21r%$sU|x42>FRdT zXroeTtJlhN>xtzeG>R{RWivI{r3Gu=UFDoRxV}{UcWnLXpL#9y-Q823e_h)M6gmNwgY|U7i zG3z$o`{wO`xoc)Mm){T{)_@-(a1D=Q={a6tAVUV#_$`gs3md+?P-XA9T z?u1x`2w{~B%r}PN>G5OdJv_t4oPl9iQYQ@o_Wo^~?7kP~Lb@FFkGGwh7xEUfzWbc6 z0j`z3^p^eTIK_upbLKrf09T|%2*olaW)e&*Cet#x1FrE^`C_g?8^=Ur(CIf05iJK# zZFf)Og!Y*-M>daU`(ynz^!nA^XUETKipD3C=1Qq5Y-~)dKi(L&3Z_&Dc7mk4E}caV zLFPFg^{?0qR#R3qf~HS}yeRv5w~~a>Bpd>CzI-YV`_2sH8-`sxsGdsVW+k^zM>qoR z!84?b`HJ+d8|ExdA^*eH#3oEvV{O%u|6FEe8awl8zTt3bYS|7xtgk`)QANs6Vo{RD zIUn1$hU}rh0-LV>QdOE;)qD1@-Z2|zOw)&AAI;tfP75_J|1d|%C;ivqgtFgo7=KGT z8^4XIdT4OecILr)eJ`OKNEg}Su4l_C1q@w8n#QB{1}d?>u6t9Ldr$H{ejHYbd*lV{ z2qM-h=_HAe#I*=EoV&#Wa-i7a>iOM8CsdXrO1sZ;yOStRC4wH}S`e%cN&tn#4x1v7 zIdn?6yu1*Z`B!v=(vhpEB^InVe4*OvPi8dPzE5mMlubGdz<8;k&L5lxiP-hbdFQ{y z#Bq;V9exTt@!$OC=s`8)knr*SVLDCF_Ln=iZ>N;)c$#}$K6tJYs|Qej)tD$nq$rJ4 zCdVJDBr)!Z<;oJ@7hAR5b<)kbsE=G0GVVOR6|Oxt#}uWbX*D$?cQ^h%G$Os|?xbQ{ z>x2w42`+nA=wd5m*zE zn7H;t1*$Ol&s0J*AGe-K&lHu0j%c#EU);wT0LE~c3{LJa3xc&sk(a_Xr1^a*pH8z7 z>PH$b0cGa19+}`2+aKk80?gr!m{5x-*SHgX!*+>dUH-u?#xMcmgxt!-@+H7=bZ2Tr zzyeCWRXgWg@jM#dY8^itk8g#SABCnaTz$^BfKC7U{E2Dz*Uam8HXd9%zvQ2Jjecpk zYxcGa&^R=~#v|pU>T=r?(aIcgdP<SdBa+zcMLy(JJ1~puW1oR4PvER zG@7|LUb;NWdK~QhBeicr`_Rfl)lUBaFZ~d~F`*vPfET}LSD&5&KlF6PmFCUwruJa1rkG_$InZKUe8+$pO z$9k=3#-y%g*{bH8@^EHckrKOygC>$^M5O|npQZ7o9)Rq`e-Ec($J6?S&RWemcsB>lTmRcVKsI8k|ny=nW3Q}Hq zj=iAJ_Ljq6PPZyPv3zToh^T`_)Nyms&X8=7Eee38gdg{Q_NG=OtzC<2GU`C?q;1bq z^8{F(#(gLdPT!t%JK~`)>xZcJe8kTrF2JP(M+kZs5|h6-jh>El63AprN-p=<{$TWl zKBD4Z)1MFZ(;qA8)Qx1EA}nJjR&j6+D6j}hUxl1JeGNgQ?a;Ry;t6TrjrmD$gy__4 zqpV^cV&9dGz6b-Mbf*`#p*iCcmAtqSj#-IoBuwLD@pu4bHdZ6LbJoaxCk)=_4=bl( z*8xcu9E6ACiRHHd9wotk+@y!ED=a}TO0vZ-#JV+2w%~9H(s2Sv`H!$kgc2dPAX$G) zuB?Qf3t{K${9XYj;Nk#-r-;gMbrr-E=?1{!5wg0^`p( z4!C^#1>=Ate|(&wo5bATjI`J{-Vi*0l<~vEh6QdF=V)nG=Oil<*_}sp5e`qWz2zjMyQOB~)}G zQJayLY&>(%RZ4*7#xuq7uN{0jIrJhN_TN>c0|0AE_@e-OC3_XXAb>Qhgj$ioxFl%* zorOPs&B@0F{l&dC<%a%dH3`1LTE=*Y=VueFfK5q;1>3=BhosMpp16w7c|0l;6LP+1~T zfiG+R1|(sHgt;nZbn5Jv=au!b5?=crn|ju5h252A&68$VkD_&D2!`rl%6jTz!cYjL Y?*G>deNH!g09?VM0wMu`|3{Gj2W%$_b^rhX literal 0 HcmV?d00001 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 5366f848b09..295bc48cf53 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -74,3 +74,14 @@ add_test(NAME ${TEST_TARGET} -m ${PROJECT_SOURCE_DIR}/models/for-tests-ggml-large.bin -f ${PROJECT_SOURCE_DIR}/samples/jfk.wav) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "large") + +if (WHISPER_FFMPEG) + set(TEST_TARGET test-main-tiny-mp3) + # Check with reviewers: any way to check the output transcription via ctest (diff, ...)? + add_test(NAME ${TEST_TARGET} + COMMAND $ + -m ${PROJECT_SOURCE_DIR}/models/for-tests-ggml-tiny.en.bin + -f ${PROJECT_SOURCE_DIR}/samples/jfk.mp3) + set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3") +endif() + From c10db6ea2883a4f77440fa8caeb296a0e351a58c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 May 2024 18:44:37 +0300 Subject: [PATCH 089/100] release : v1.6.1 --- CMakeLists.txt | 2 +- bindings/ios | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3eb12c10783..541be8a5d57 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required (VERSION 3.5) # Allow for the creation of solution folders. set_property(GLOBAL PROPERTY USE_FOLDERS ON) -project(whisper.cpp VERSION 1.6.0) +project(whisper.cpp VERSION 1.6.1) set(SOVERSION 1) # Add path to modules diff --git a/bindings/ios b/bindings/ios index 5cfcfb0801b..9a32de38144 160000 --- a/bindings/ios +++ b/bindings/ios @@ -1 +1 @@ -Subproject commit 5cfcfb0801be756d8347822b472e4b5e343f403f +Subproject commit 9a32de3814477ad2e598d4a550fcab4b23a9c576 diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 354d0ce903c..da6a9efdc6c 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.6.0", + "version": "1.6.1", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 22d46b7ba4620e2db1281e210d0186863cffcec0 Mon Sep 17 00:00:00 2001 From: Todd Date: Wed, 22 May 2024 16:02:52 -0400 Subject: [PATCH 090/100] ruby : update bindings (#2154) * update library files * update whispercpp * not needed for gem --- bindings/ruby/Rakefile | 12 + bindings/ruby/ext/ggml-backend-impl.h | 118 +- bindings/ruby/ext/ggml-backend.c | 2205 ++++- bindings/ruby/ext/ggml-backend.h | 217 +- bindings/ruby/ext/ggml-common.h | 1853 ++++ bindings/ruby/ext/ggml-cuda.h | 43 + bindings/ruby/ext/ggml-impl.h | 51 +- bindings/ruby/ext/ggml-kompute.h | 46 + bindings/ruby/ext/ggml-metal.h | 66 + bindings/ruby/ext/ggml-opencl.h | 36 + bindings/ruby/ext/ggml-quants.c | 12622 +++++++++++++++++------- bindings/ruby/ext/ggml-quants.h | 333 +- bindings/ruby/ext/ggml-sycl.h | 49 + bindings/ruby/ext/ggml-vulkan.h | 29 + bindings/ruby/whispercpp.gemspec | 28 + 15 files changed, 13247 insertions(+), 4461 deletions(-) create mode 100644 bindings/ruby/Rakefile create mode 100644 bindings/ruby/ext/ggml-common.h create mode 100644 bindings/ruby/ext/ggml-cuda.h create mode 100644 bindings/ruby/ext/ggml-kompute.h create mode 100644 bindings/ruby/ext/ggml-metal.h create mode 100644 bindings/ruby/ext/ggml-opencl.h create mode 100644 bindings/ruby/ext/ggml-sycl.h create mode 100644 bindings/ruby/ext/ggml-vulkan.h create mode 100644 bindings/ruby/whispercpp.gemspec diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile new file mode 100644 index 00000000000..354d8ef2547 --- /dev/null +++ b/bindings/ruby/Rakefile @@ -0,0 +1,12 @@ +require 'rake/clean' + require 'rubygems/package' + +desc 'Build gem' +task :package do + spec_source = File.read File.join(File.dirname(__FILE__),'whispercpp.gemspec') + spec = nil + # see: http://gist.github.com/16215 + Thread.new { spec = eval("#{spec_source}") }.join + spec.validate + Gem::Package.build(spec) +end diff --git a/bindings/ruby/ext/ggml-backend-impl.h b/bindings/ruby/ext/ggml-backend-impl.h index 31788cd6baa..f121e1de420 100644 --- a/bindings/ruby/ext/ggml-backend-impl.h +++ b/bindings/ruby/ext/ggml-backend-impl.h @@ -12,31 +12,63 @@ extern "C" { // Backend buffer // + // buffer type + typedef void * ggml_backend_buffer_type_context_t; + + struct ggml_backend_buffer_type_i { + const char * (*GGML_CALL get_name) (ggml_backend_buffer_type_t buft); + ggml_backend_buffer_t (*GGML_CALL alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size); + size_t (*GGML_CALL get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment + size_t (*GGML_CALL get_max_size) (ggml_backend_buffer_type_t buft); // allocation max size + size_t (*GGML_CALL get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding + bool (*GGML_CALL supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend + // check if tensor data is in host memory + // should be equivalent to supports_backend(buft, ggml_backend_cpu_init()) + bool (*GGML_CALL is_host) (ggml_backend_buffer_type_t buft); + }; + + struct ggml_backend_buffer_type { + struct ggml_backend_buffer_type_i iface; + ggml_backend_buffer_type_context_t context; + }; + + // buffer typedef void * ggml_backend_buffer_context_t; struct ggml_backend_buffer_i { - void (*free_buffer) (ggml_backend_buffer_t buffer); - void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer - size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback - void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback - void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback + const char * (*GGML_CALL get_name) (ggml_backend_buffer_t buffer); + void (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer); + void * (*GGML_CALL get_base) (ggml_backend_buffer_t buffer); + void (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + void (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer + void (*GGML_CALL clear) (ggml_backend_buffer_t buffer, uint8_t value); + void (*GGML_CALL reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras }; struct ggml_backend_buffer { - struct ggml_backend_buffer_i iface; - - ggml_backend_t backend; + struct ggml_backend_buffer_i iface; + ggml_backend_buffer_type_t buft; ggml_backend_buffer_context_t context; - size_t size; + enum ggml_backend_buffer_usage usage; }; - GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( - struct ggml_backend * backend, + GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init( + ggml_backend_buffer_type_t buft, struct ggml_backend_buffer_i iface, ggml_backend_buffer_context_t context, size_t size); + // do not use directly, use ggml_backend_tensor_copy instead + bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst); + + // buffer that contains a collection of buffers + GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers); + GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer); + GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + // // Backend // @@ -44,44 +76,66 @@ extern "C" { typedef void * ggml_backend_context_t; struct ggml_backend_i { - const char * (*get_name)(ggml_backend_t backend); + const char * (*GGML_CALL get_name)(ggml_backend_t backend); - void (*free)(ggml_backend_t backend); + void (*GGML_CALL free)(ggml_backend_t backend); // buffer allocation - ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size); + ggml_backend_buffer_type_t (*GGML_CALL get_default_buffer_type)(ggml_backend_t backend); - // get buffer alignment - size_t (*get_alignment)(ggml_backend_t backend); + // (optional) asynchronous tensor data access + void (*GGML_CALL set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); - // tensor data access - // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize - void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - void (*synchronize) (ggml_backend_t backend); + // (optional) complete all pending operations + void (*GGML_CALL synchronize)(ggml_backend_t backend); - // (optional) copy tensor between different backends, allow for single-copy tranfers - void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); - void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); + // compute graph with a plan (not used currently) + ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph); + void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); // compute graph with a plan - ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph); - void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); - void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); - - // compute graph without a plan - bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph); + enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); + // compute graph without a plan (async) + enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph); // check if the backend supports an operation - bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); + bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); + + // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer + // these should be expensive operations with large batch sizes that may benefit from running on this backend + // even if the weight has to be copied from the CPU temporarily + bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op); + + // (optional) event synchronization + ggml_backend_event_t (*GGML_CALL event_new) (ggml_backend_t backend); + void (*GGML_CALL event_free) (ggml_backend_event_t event); + void (*GGML_CALL event_record) (ggml_backend_event_t event); + void (*GGML_CALL event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + void (*GGML_CALL event_synchronize) (ggml_backend_event_t event); }; struct ggml_backend { - struct ggml_backend_i iface; + ggml_guid_t guid; + struct ggml_backend_i iface; ggml_backend_context_t context; }; + struct ggml_backend_event { + ggml_backend_t backend; + void * context; + }; + + // + // Backend registry + // + + typedef ggml_backend_t (*GGML_CALL ggml_backend_init_fn)(const char * params, void * user_data); + + GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data); + #ifdef __cplusplus } #endif diff --git a/bindings/ruby/ext/ggml-backend.c b/bindings/ruby/ext/ggml-backend.c index 128e33ce630..402d86ef3ac 100644 --- a/bindings/ruby/ext/ggml-backend.c +++ b/bindings/ruby/ext/ggml-backend.c @@ -9,31 +9,76 @@ #include #include -#define UNUSED GGML_UNUSED #define MAX(a, b) ((a) > (b) ? (a) : (b)) +// backend buffer type + +const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name(buft); +} + +GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return buft->iface.alloc_buffer(buft, size); +} + +size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) { + return buft->iface.get_alignment(buft); +} + +size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { + // get_max_size is optional, defaults to SIZE_MAX + if (buft->iface.get_max_size) { + return buft->iface.get_max_size(buft); + } + return SIZE_MAX; +} + +GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) { + // get_alloc_size is optional, defaults to ggml_nbytes + if (buft->iface.get_alloc_size) { + size_t size = buft->iface.get_alloc_size(buft, tensor); + assert(size >= ggml_nbytes(tensor)); + return size; + } + return ggml_nbytes(tensor); +} + +bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return buft->iface.supports_backend(buft, backend); +} + +bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { + if (buft->iface.is_host) { + return buft->iface.is_host(buft); + } + return false; +} + // backend buffer -ggml_backend_buffer_t ggml_backend_buffer_init( - struct ggml_backend * backend, +GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init( + ggml_backend_buffer_type_t buft, struct ggml_backend_buffer_i iface, ggml_backend_buffer_context_t context, size_t size) { ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer)); - GGML_ASSERT(iface.get_base != NULL); - (*buffer) = (struct ggml_backend_buffer) { /* .interface = */ iface, - /* .backend = */ backend, + /* .buft = */ buft, /* .context = */ context, /* .size = */ size, + /* .usage = */ GGML_BACKEND_BUFFER_USAGE_ANY }; return buffer; } +const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) { + return buffer->iface.get_name(buffer); +} + void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { if (buffer == NULL) { return; @@ -45,10 +90,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { free(buffer); } -size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) { - return ggml_backend_get_alignment(buffer->backend); -} - size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { return buffer->size; } @@ -61,32 +102,67 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { return base; } +GGML_CALL void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + // init_tensor is optional + if (buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } +} + +size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) { + return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer)); +} + +size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) { + return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer)); +} + size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - // get_alloc_size is optional, defaults to ggml_nbytes - if (buffer->iface.get_alloc_size) { - return buffer->iface.get_alloc_size(buffer, tensor); + return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor); +} + +void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + buffer->iface.clear(buffer, value); +} + +bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) { + return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer)); +} + +void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + buffer->usage = usage; + + // FIXME: add a generic callback to the buffer interface + if (ggml_backend_buffer_is_multi_buffer(buffer)) { + ggml_backend_multi_buffer_set_usage(buffer, usage); } - return ggml_nbytes(tensor); } -void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - // init_tensor is optional - if (buffer->iface.init_tensor) { - buffer->iface.init_tensor(buffer, tensor); +ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) { + return buffer->buft; +} + +void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) { + if (buffer->iface.reset) { + buffer->iface.reset(buffer); } } -void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - // free_tensor is optional - if (buffer->iface.free_tensor) { - buffer->iface.free_tensor(buffer, tensor); +bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer; + if (dst_buf->iface.cpy_tensor) { + return src->buffer->iface.cpy_tensor(dst_buf, src, dst); } + return false; } // backend -ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) { - return tensor->buffer ? tensor->buffer->backend : NULL; +ggml_guid_t ggml_backend_guid(ggml_backend_t backend) { + if (backend == NULL) { + return NULL; + } + return backend->guid; } const char * ggml_backend_name(ggml_backend_t backend) { @@ -104,59 +180,105 @@ void ggml_backend_free(ggml_backend_t backend) { backend->iface.free(backend); } +ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) { + return backend->iface.get_default_buffer_type(backend); +} + ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) { - return backend->iface.alloc_buffer(backend, size); + return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size); } size_t ggml_backend_get_alignment(ggml_backend_t backend) { - return backend->iface.get_alignment(backend); + return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend)); +} + +size_t ggml_backend_get_max_size(ggml_backend_t backend) { + return ggml_backend_buft_get_max_size(ggml_backend_get_default_buffer_type(backend)); } -void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); +void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + if (backend->iface.set_tensor_async == NULL) { + ggml_backend_tensor_set(tensor, data, offset, size); + } else { + backend->iface.set_tensor_async(backend, tensor, data, offset, size); + } } -void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); +void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + if (backend->iface.get_tensor_async == NULL) { + ggml_backend_tensor_get(tensor, data, offset, size); + } else { + backend->iface.get_tensor_async(backend, tensor, data, offset, size); + } } -void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_t backend = ggml_get_backend(tensor); +GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(backend != NULL && "tensor backend not set"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - backend->iface.set_tensor_async(backend, tensor, data, offset, size); - backend->iface.synchronize(backend); + if (!size) { + return; + } + + buf->iface.set_tensor(buf, tensor, data, offset, size); } -void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_backend_t backend = ggml_get_backend(tensor); +GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(backend != NULL && "tensor backend not set"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - backend->iface.get_tensor_async(backend, tensor, data, offset, size); - backend->iface.synchronize(backend); + if (!size) { + return; + } + + buf->iface.get_tensor(buf, tensor, data, offset, size); } void ggml_backend_synchronize(ggml_backend_t backend) { + if (backend->iface.synchronize == NULL) { + return; + } + backend->iface.synchronize(backend); } ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend->iface.graph_plan_create != NULL); + return backend->iface.graph_plan_create(backend, cgraph); } void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend->iface.graph_plan_free != NULL); + backend->iface.graph_plan_free(backend, plan); } -void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - backend->iface.graph_plan_compute(backend, plan); +enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend->iface.graph_plan_compute != NULL); + + return backend->iface.graph_plan_compute(backend, plan); +} + +enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph); + ggml_backend_synchronize(backend); + return err; } -bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { +enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) { return backend->iface.graph_compute(backend, cgraph); } @@ -164,6 +286,13 @@ bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * return backend->iface.supports_op(backend, op); } +bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) { + if (backend->iface.offload_op != NULL) { + return backend->iface.offload_op(backend, op); + } + return false; +} + // backend copy static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { @@ -182,27 +311,20 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml } void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { - //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]); - //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]); GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); - // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src)); - if (src == dst) { return; } - // TODO: allow backends to support copy to/from same backend - - if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) { - ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst); - } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) { - ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst); - } else { - // shouldn't be hit when copying from/to CPU - #ifndef NDEBUG - fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend)); - #endif + if (ggml_backend_buffer_is_host(src->buffer)) { + ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src)); + } else if (ggml_backend_buffer_is_host(dst->buffer)) { + ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); + } else if (!ggml_backend_buffer_copy_tensor(src, dst)) { +#ifndef NDEBUG + fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); +#endif size_t nbytes = ggml_nbytes(src); void * data = malloc(nbytes); ggml_backend_tensor_get(src, data, 0, nbytes); @@ -211,318 +333,846 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } } -// backend CPU +void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); -struct ggml_backend_cpu_context { - int n_threads; - void * work_data; - size_t work_size; -}; + if (src == dst) { + return; + } -static const char * ggml_backend_cpu_name(ggml_backend_t backend) { - return "CPU"; + if (backend_dst->iface.cpy_tensor_async != NULL) { + if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) { + return; + } + } - UNUSED(backend); + // an async copy would normally happen after all the queued operations on both backends are completed + // sync src, set_async dst + if (ggml_backend_buffer_is_host(src->buffer)) { + ggml_backend_synchronize(backend_src); + ggml_backend_tensor_set_async(backend_dst, dst, src->data, 0, ggml_nbytes(src)); + } else { + ggml_backend_synchronize(backend_src); + ggml_backend_tensor_copy(src, dst); + ggml_backend_synchronize(backend_dst); + } } -static void ggml_backend_cpu_free(ggml_backend_t backend) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - free(cpu_ctx->work_data); - free(cpu_ctx); - free(backend); +// events + +ggml_backend_event_t ggml_backend_event_new(ggml_backend_t backend) { + if (backend->iface.event_new == NULL) { + return NULL; + } + return backend->iface.event_new(backend); +} + +void ggml_backend_event_free(ggml_backend_event_t event) { + if (event == NULL) { + return; + } + event->backend->iface.event_free(event); } -static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { - return (void *)buffer->context; +void ggml_backend_event_record(ggml_backend_event_t event) { + GGML_ASSERT(event->backend->iface.event_record != NULL); + + event->backend->iface.event_record(event); } -static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { - free(buffer->context); - UNUSED(buffer); +void ggml_backend_event_synchronize(ggml_backend_event_t event) { + GGML_ASSERT(event->backend->iface.event_synchronize != NULL); + + event->backend->iface.event_synchronize(event); } -static struct ggml_backend_buffer_i cpu_backend_buffer_i = { - /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, - /* .get_base = */ ggml_backend_cpu_buffer_get_base, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .init_tensor = */ NULL, // no initialization required - /* .free_tensor = */ NULL, // no cleanup required -}; +void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_ASSERT(backend->iface.event_wait != NULL); -// for buffers from ptr, free is not called -static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { - /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed - /* .get_base = */ ggml_backend_cpu_buffer_get_base, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .init_tensor = */ NULL, - /* .free_tensor = */ NULL, + backend->iface.event_wait(backend, event); +} + +// backend registry + +#define GGML_REG_MAX_BACKENDS 16 + +struct ggml_backend_reg { + char name[128]; + ggml_backend_init_fn init_fn; + ggml_backend_buffer_type_t default_buffer_type; + void * user_data; }; -static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512 +static struct ggml_backend_reg ggml_backend_registry[GGML_REG_MAX_BACKENDS]; +static size_t ggml_backend_registry_count = 0; -static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) { - size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned - void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC? +GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data); + +GGML_CALL static void ggml_backend_registry_init(void) { + static bool initialized = false; - GGML_ASSERT(data != NULL && "failed to allocate buffer"); + if (initialized) { + return; + } + + initialized = true; + + ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL); + + // add forward decls here to avoid including the backend headers +#ifdef GGML_USE_CUDA + extern GGML_CALL void ggml_backend_cuda_reg_devices(void); + ggml_backend_cuda_reg_devices(); +#endif + +#ifdef GGML_USE_SYCL + extern void ggml_backend_sycl_reg_devices(void); + ggml_backend_sycl_reg_devices(); +#endif - return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size); +#ifdef GGML_USE_METAL + extern GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); + extern GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); + ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL); +#endif + +#ifdef GGML_USE_VULKAN + extern GGML_CALL int ggml_backend_vk_reg_devices(void); + ggml_backend_vk_reg_devices(); +#endif + +#ifdef GGML_USE_KOMPUTE + extern GGML_CALL void ggml_backend_kompute_reg_devices(void); + ggml_backend_kompute_reg_devices(); +#endif } -static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) { - return TENSOR_ALIGNMENT; - UNUSED(backend); +GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) { + GGML_ASSERT(ggml_backend_registry_count < GGML_REG_MAX_BACKENDS); + + size_t id = ggml_backend_registry_count; + + ggml_backend_registry[id] = (struct ggml_backend_reg) { + /* .name = */ {0}, + /* .fn = */ init_fn, + /* .default_buffer_type = */ default_buffer_type, + /* .user_data = */ user_data, + }; + + snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name); + +#ifndef NDEBUG + fprintf(stderr, "%s: registered backend %s\n", __func__, name); +#endif + + ggml_backend_registry_count++; } -static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); +size_t ggml_backend_reg_get_count(void) { + ggml_backend_registry_init(); - memcpy((char *)tensor->data + offset, data, size); + return ggml_backend_registry_count; +} + +size_t ggml_backend_reg_find_by_name(const char * name) { + ggml_backend_registry_init(); + + for (size_t i = 0; i < ggml_backend_registry_count; i++) { + // TODO: case insensitive in a portable way + if (strcmp(ggml_backend_registry[i].name, name) == 0) { + return i; + } + } - UNUSED(backend); + // not found + return SIZE_MAX; } -static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); +// init from backend:params string +ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) { + ggml_backend_registry_init(); - memcpy(data, (const char *)tensor->data + offset, size); + const char * params = strchr(backend_str, ':'); + char backend_name[128]; + if (params == NULL) { + snprintf(backend_name, sizeof(backend_name), "%s", backend_str); + params = ""; + } else { + snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str); + params++; + } + + size_t backend_i = ggml_backend_reg_find_by_name(backend_name); - UNUSED(backend); + if (backend_i == SIZE_MAX) { + fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name); + return NULL; + } + + return ggml_backend_reg_init_backend(backend_i, params); } -static void ggml_backend_cpu_synchronize(ggml_backend_t backend) { - UNUSED(backend); +const char * ggml_backend_reg_get_name(size_t i) { + ggml_backend_registry_init(); + + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_registry[i].name; } -static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { - ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); +ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) { + ggml_backend_registry_init(); - UNUSED(backend); + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data); } -static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { - ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src)); +ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) { + ggml_backend_registry_init(); - UNUSED(backend); + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_registry[i].default_buffer_type; } -struct ggml_backend_plan_cpu { - struct ggml_cplan cplan; - struct ggml_cgraph cgraph; -}; +ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) { + ggml_backend_registry_init(); -static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size); +} - struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); +// backend CPU - cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); - cpu_plan->cgraph = *cgraph; +static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment - if (cpu_plan->cplan.work_size > 0) { - cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); - } +GGML_CALL static const char * ggml_backend_cpu_buffer_name(ggml_backend_buffer_t buffer) { + return "CPU"; - return cpu_plan; + GGML_UNUSED(buffer); } -static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; +GGML_CALL static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { + uintptr_t data = (uintptr_t)buffer->context; - free(cpu_plan->cplan.work_data); - free(cpu_plan); + // align the buffer + if (data % TENSOR_ALIGNMENT != 0) { + data = GGML_PAD(data, TENSOR_ALIGNMENT); + } - UNUSED(backend); + return (void *)data; } -static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; +GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + free(buffer->context); +} - ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); +GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + memcpy((char *)tensor->data + offset, data, size); - UNUSED(backend); + GGML_UNUSED(buffer); } -static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; +GGML_CALL static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + memcpy(data, (const char *)tensor->data + offset, size); - struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + GGML_UNUSED(buffer); +} - if (cpu_ctx->work_size < cplan.work_size) { - // TODO: may be faster to free and use malloc to avoid the copy - cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size); - cpu_ctx->work_size = cplan.work_size; +GGML_CALL static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; } + return false; - cplan.work_data = cpu_ctx->work_data; - - ggml_graph_compute(cgraph, &cplan); + GGML_UNUSED(buffer); } -static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - return true; - UNUSED(backend); - UNUSED(op); +GGML_CALL static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); } -static struct ggml_backend_i cpu_backend_i = { - /* .get_name = */ ggml_backend_cpu_name, - /* .free = */ ggml_backend_cpu_free, - /* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_get_alignment, - /* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async, - /* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async, - /* .synchronize = */ ggml_backend_cpu_synchronize, - /* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from, - /* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to, - /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, - /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, - /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, - /* .graph_compute = */ ggml_backend_cpu_graph_compute, - /* .supports_op = */ ggml_backend_cpu_supports_op, +static struct ggml_backend_buffer_i cpu_backend_buffer_i = { + /* .get_name = */ ggml_backend_cpu_buffer_name, + /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, }; -ggml_backend_t ggml_backend_cpu_init(void) { - struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); - - ctx->n_threads = GGML_DEFAULT_N_THREADS; - ctx->work_data = NULL; - ctx->work_size = 0; +// for buffers from ptr, free is not called +static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { + /* .get_name = */ ggml_backend_cpu_buffer_name, + /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; - ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); +GGML_CALL static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU"; - *cpu_backend = (struct ggml_backend) { - /* .interface = */ cpu_backend_i, - /* .context = */ ctx - }; - return cpu_backend; + GGML_UNUSED(buft); } -bool ggml_backend_is_cpu(ggml_backend_t backend) { - return backend->iface.get_name == ggml_backend_cpu_name; +GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned + void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h) + if (data == NULL) { + fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size); + return NULL; + } + + return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size); } -void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { - GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); +GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; - struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; - ctx->n_threads = n_threads; + GGML_UNUSED(buft); } -ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) { - return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size); +GGML_CALL static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return ggml_backend_is_cpu(backend); + + GGML_UNUSED(buft); } -// scheduler +GGML_CALL static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; -#define GGML_MAX_BACKENDS 4 -#define GGML_MAX_SPLITS 256 -#define GGML_MAX_SPLIT_INPUTS 16 + GGML_UNUSED(buft); +} -struct ggml_backend_sched_split { - ggml_tallocr_t tallocr; - int i_start; - int i_end; - struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS]; - int n_inputs; - struct ggml_cgraph * graph; -}; +GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend, + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .context = */ NULL, + }; -struct ggml_backend_sched { - int n_backends; - ggml_backend_t backends[GGML_MAX_BACKENDS]; - ggml_tallocr_t tallocs[GGML_MAX_BACKENDS]; + return &ggml_backend_cpu_buffer_type; +} - ggml_gallocr_t galloc; +#ifdef GGML_USE_CPU_HBM - struct ggml_hash_set hash_set; - ggml_tallocr_t * node_talloc; // [hash_set.size] - struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS] +// buffer type HBM - struct ggml_cgraph * graph; - struct ggml_backend_sched_split splits[GGML_MAX_SPLITS]; - int n_splits; +#include - struct ggml_context * ctx; +GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_HBM"; - // align context_buffer to GGML_MEM_ALIGN - #ifdef _MSC_VER - __declspec(align(GGML_MEM_ALIGN)) - #else - __attribute__((aligned(GGML_MEM_ALIGN))) - #endif - char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)]; -}; + GGML_UNUSED(buft); +} -#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node) -#define node_allocr(node) sched->node_talloc[hash_id(node)] +GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_get_name(ggml_backend_buffer_t buf) { + return "CPU_HBM"; -static bool ggml_is_view_op(enum ggml_op op) { - return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE; + GGML_UNUSED(buf); } -// returns the priority of the backend, lower is better -static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) { - for (int i = 0; i < sched->n_backends; i++) { - if (sched->backends[i] == backend) { - return i; - } - } - return INT_MAX; +GGML_CALL static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) { + hbw_free(buffer->context); } -static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) { - for (int i = 0; i < sched->n_backends; i++) { - if (sched->tallocs[i] == allocr) { - return i; - } +GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + //void * ptr = hbw_malloc(size); + void * ptr; + int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size); + if (result != 0) { + fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size); + return NULL; } - return INT_MAX; -} + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.get_name = ggml_backend_cpu_hbm_buffer_get_name; + buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer; + + return buffer; +} + +ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend, + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_buffer_type_hbm; +} +#endif + +struct ggml_backend_cpu_context { + int n_threads; + void * work_data; + size_t work_size; + + ggml_abort_callback abort_callback; + void * abort_callback_data; +}; + +GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) { + return "CPU"; + + GGML_UNUSED(backend); +} + +GGML_CALL static void ggml_backend_cpu_free(ggml_backend_t backend) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + free(cpu_ctx->work_data); + free(cpu_ctx); + free(backend); +} + +GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) { + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(backend); +} + +struct ggml_backend_plan_cpu { + struct ggml_cplan cplan; + struct ggml_cgraph cgraph; +}; + +GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); + + cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + cpu_plan->cgraph = *cgraph; // FIXME: deep copy + + if (cpu_plan->cplan.work_size > 0) { + cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); + if (cpu_plan->cplan.work_data == NULL) { + free(cpu_plan); + return NULL; + } + } + + cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; + cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return cpu_plan; +} + +GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + free(cpu_plan->cplan.work_data); + free(cpu_plan); + + GGML_UNUSED(backend); +} + +GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); + + GGML_UNUSED(backend); +} + +GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + + if (cpu_ctx->work_size < cplan.work_size) { + free(cpu_ctx->work_data); + cpu_ctx->work_data = malloc(cplan.work_size); + if (cpu_ctx->work_data == NULL) { + cpu_ctx->work_size = 0; + return GGML_STATUS_ALLOC_FAILED; + } + cpu_ctx->work_size = cplan.work_size; + } + cplan.work_data = cpu_ctx->work_data; + + cplan.abort_callback = cpu_ctx->abort_callback; + cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return ggml_graph_compute(cgraph, &cplan); +} + +GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + switch (op->op) { + case GGML_OP_CPY: + return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS && op->type != GGML_TYPE_IQ1_S; // missing type_traits.from_float + case GGML_OP_MUL_MAT: + return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; + default: + return true; + } + + GGML_UNUSED(backend); +} + +static struct ggml_backend_i cpu_backend_i = { + /* .get_name = */ ggml_backend_cpu_name, + /* .free = */ ggml_backend_cpu_free, + /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, + /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cpu_graph_compute, + /* .supports_op = */ ggml_backend_cpu_supports_op, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static ggml_guid_t ggml_backend_cpu_guid(void) { + static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 }; + return &guid; +} + +ggml_backend_t ggml_backend_cpu_init(void) { + struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); + if (ctx == NULL) { + return NULL; + } + + ctx->n_threads = GGML_DEFAULT_N_THREADS; + ctx->work_data = NULL; + ctx->work_size = 0; + ctx->abort_callback = NULL; + ctx->abort_callback_data = NULL; + + ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); + if (cpu_backend == NULL) { + free(ctx); + return NULL; + } + + *cpu_backend = (struct ggml_backend) { + /* .guid = */ ggml_backend_cpu_guid(), + /* .interface = */ cpu_backend_i, + /* .context = */ ctx + }; + return cpu_backend; +} + +GGML_CALL bool ggml_backend_is_cpu(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid()); +} + +void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->n_threads = n_threads; +} + +void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = abort_callback_data; +} + +GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { + GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); + return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size); +} + +GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) { + return ggml_backend_cpu_init(); + + GGML_UNUSED(params); + GGML_UNUSED(user_data); +} + +// multi-buffer buffer + +struct ggml_backend_multi_buffer_context { + ggml_backend_buffer_t * buffers; + size_t n_buffers; +}; + +typedef struct ggml_backend_multi_buffer_context * ggml_backend_multi_buffer_context_t; + +GGML_CALL static const char * ggml_backend_multi_buffer_get_name(ggml_backend_buffer_t buffer) { + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; + + return ctx->buffers[0]->iface.get_name(ctx->buffers[0]); +} + +GGML_CALL static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; + for (size_t i = 0; i < ctx->n_buffers; i++) { + ggml_backend_buffer_free(ctx->buffers[i]); + } + + free(ctx->buffers); + free(ctx); +} + +GGML_CALL static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; + for (size_t i = 0; i < ctx->n_buffers; i++) { + ggml_backend_buffer_clear(ctx->buffers[i], value); + } +} + +static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(void) { + static struct ggml_backend_buffer_i multi_backend_buffer_i = { + /* .get_name = */ ggml_backend_multi_buffer_get_name, + /* .free_buffer = */ ggml_backend_multi_buffer_free_buffer, + /* .get_base = */ NULL, + /* .init_tensor = */ NULL, + /* .set_tensor = */ NULL, + /* .get_tensor = */ NULL, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_multi_buffer_clear, + /* .reset = */ NULL, + }; + + return multi_backend_buffer_i; +} + +GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) { + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) malloc(sizeof(struct ggml_backend_multi_buffer_context)); + ctx->n_buffers = n_buffers; + ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t)); + + GGML_ASSERT(ctx->buffers != NULL); + + size_t total_size = 0; + for (size_t i = 0; i < n_buffers; i++) { + ctx->buffers[i] = buffers[i]; + total_size += ggml_backend_buffer_get_size(buffers[i]); + } + + return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_context_interface(), ctx, total_size); +} + +GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) { + return buffer->iface.get_name == ggml_backend_multi_buffer_get_name; +} + +GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer)); + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; + for (size_t i = 0; i < ctx->n_buffers; i++) { + ggml_backend_buffer_set_usage(ctx->buffers[i], usage); + } +} + +// creates a copy of the tensor with the same memory layout +static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) { + struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor); + for (int i = 0; i < GGML_MAX_DIMS; i++) { + dup->nb[i] = tensor->nb[i]; + } + return dup; +} + +static bool ggml_is_view_op(enum ggml_op op) { + return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE; +} + +// scheduler + +#ifndef GGML_SCHED_MAX_BACKENDS +#define GGML_SCHED_MAX_BACKENDS 16 +#endif + +#ifndef GGML_SCHED_MAX_SPLITS +#define GGML_SCHED_MAX_SPLITS 2048 +#endif + +#ifndef GGML_SCHED_MAX_SPLIT_INPUTS +#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC +#endif + +#ifndef GGML_SCHED_MAX_COPIES +#define GGML_SCHED_MAX_COPIES 4 +#endif + +struct ggml_backend_sched_split { + int backend_id; + int i_start; + int i_end; + struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS]; + int n_inputs; + // graph view of this split + struct ggml_cgraph graph; +}; + +struct ggml_backend_sched { + bool is_reset; // true if the scheduler has been reset since the last graph split + bool is_alloc; + + int n_backends; + + ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS]; + ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS]; + ggml_gallocr_t galloc; + + // hash keys of the nodes in the graph + struct ggml_hash_set hash_set; + // hash values + int * tensor_backend_id; + struct ggml_tensor * (* tensor_copies)[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES]; + + int * node_backend_ids; // [graph_size] + int * leaf_backend_ids; // [graph_size] + + // copy of the graph with modified inputs + struct ggml_cgraph * graph; + + // graph splits + struct ggml_backend_sched_split * splits; + int n_splits; + int splits_capacity; + + // pipeline parallelism support + int n_copies; + int cur_copy; + ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES]; + struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS]; + int n_graph_inputs; + + struct ggml_context * ctx; + + ggml_backend_sched_eval_callback callback_eval; + void * callback_eval_user_data; + + // align context_buffer to GGML_MEM_ALIGN +#ifdef _MSC_VER + __declspec(align(GGML_MEM_ALIGN)) +#else + __attribute__((aligned(GGML_MEM_ALIGN))) +#endif + char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)]; +}; + +#define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor) +#define tensor_backend_id(tensor) sched->tensor_backend_id[hash_id(tensor)] + +// returns the priority of the backend, lower id is higher priority +static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) { + for (int i = 0; i < sched->n_backends; i++) { + if (sched->backends[i] == backend) { + return i; + } + } + return -1; +} + +static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor) { + ggml_backend_buffer_t buffer = tensor->buffer; + if (buffer == NULL) { + return -1; + } + + // find highest prio backend that supports the buffer type + for (int i = 0; i < sched->n_backends; i++) { + if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) { + return i; + } + } + + fprintf(stderr, "%s: error: no backend supports buffer type %s used in tensor %s\n", + __func__, ggml_backend_buffer_name(buffer), tensor->name); + GGML_ASSERT(false); + + return -1; +} + +#if 0 +static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only +#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__) +#define GET_CAUSE(node) causes[hash_id(node)] +#else +#define SET_CAUSE(node, ...) +#define GET_CAUSE(node) "" +#endif // returns the backend that should be used for the node based on the current locations -char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove -static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) { - // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there - // ie. kv cache updates - // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend. - // dst - ggml_backend_t cur_backend = ggml_get_backend(node); - if (cur_backend != NULL) { - sprintf(causes[hash_id(node)], "1.dst"); - return cur_backend; +static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) { + // TODO: use supports_op to check if the backend supports the op + + // assign pre-allocated nodes to their backend + int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor); + if (cur_backend_id != -1) { + SET_CAUSE(tensor, "1.dst"); + return cur_backend_id; } // view_src - if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) { - sprintf(causes[hash_id(node)], "1.vsrc"); - return ggml_get_backend(node->view_src); + if (tensor->view_src != NULL) { + cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src); + if (cur_backend_id != -1) { + SET_CAUSE(tensor, "1.vsrc"); + return cur_backend_id; + } } - // src - int cur_prio = INT_MAX; - size_t cur_size = 0; + // graph input + if (tensor->flags & GGML_TENSOR_FLAG_INPUT) { + cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU) + SET_CAUSE(tensor, "1.inp"); + return cur_backend_id; + } + // assign nodes that use weights to the backend of the weights + // operations with weights are preferably run on the same backend as the weights for (int i = 0; i < GGML_MAX_SRC; i++) { - const struct ggml_tensor * src = node->src[i]; + const struct ggml_tensor * src = tensor->src[i]; if (src == NULL) { - break; + continue; } - ggml_backend_t src_backend = ggml_get_backend(src); - if (src_backend != NULL) { - int src_prio = sched_backend_prio(sched, src_backend); - size_t src_size = ggml_nbytes(src); - if (src_prio < cur_prio && src_size >= cur_size) { - cur_prio = src_prio; - cur_size = src_size; - cur_backend = src_backend; - sprintf(causes[hash_id(node)], "1.src%d", i); + if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src); + // check if a backend with higher prio wants to offload the op + if (src_backend_id == sched->n_backends - 1) { + for (int b = 0; b < src_backend_id; b++) { + if (ggml_backend_offload_op(sched->backends[b], tensor)) { + SET_CAUSE(tensor, "1.off"); + return b; + } + } } + SET_CAUSE(tensor, "1.wgt%d", i); + return src_backend_id; } } - return cur_backend; + + return -1; } static char * fmt_size(size_t size) { @@ -535,14 +1185,16 @@ static char * fmt_size(size_t size) { return buffer; } -static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { +static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { int cur_split = 0; for (int i = 0; i < graph->n_nodes; i++) { if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) { - ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend; - fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs); + ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id]; + fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), + sched->splits[cur_split].n_inputs); for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) { - fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j]))); + fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, + fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j]))); } fprintf(stderr, "\n"); cur_split++; @@ -551,341 +1203,558 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra if (ggml_is_view_op(node->op)) { continue; } - ggml_tallocr_t node_allocr = node_allocr(node); - ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL; - fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]); + ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); + fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name, + fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node)); for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { - break; + continue; } - ggml_tallocr_t src_allocr = node_allocr(src); - ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL; - fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]); + ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src); + fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name, + fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src)); } fprintf(stderr, "\n"); } } -// creates a copy of the tensor with the same memory layout -static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) { - struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor); - for (int i = 0; i < GGML_MAX_DIMS; i++) { - dup->nb[i] = tensor->nb[i]; - } - return dup; -} +//#define DEBUG_PASS1 +//#define DEBUG_PASS2 +//#define DEBUG_PASS3 +//#define DEBUG_PASS4 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend -// TODO: merge passes -static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { - // reset state - size_t hash_size = sched->hash_set.size; - memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); - memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size); - memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size); +static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + // reset splits sched->n_splits = 0; + sched->n_graph_inputs = 0; + sched->is_reset = false; struct ggml_init_params params = { - /*.mem_size = */ sizeof(sched->context_buffer), - /*.mem_buffer = */ sched->context_buffer, - /*.no_alloc = */ true + /* .mem_size = */ sizeof(sched->context_buffer), + /* .mem_buffer = */ sched->context_buffer, + /* .no_alloc = */ true }; - if (sched->ctx != NULL) { - ggml_free(sched->ctx); - } + ggml_free(sched->ctx); sched->ctx = ggml_init(params); + if (sched->ctx == NULL) { + fprintf(stderr, "%s: failed to initialize context\n", __func__); + GGML_ASSERT(false); + } - // pass 1: assign backends to ops with allocated inputs + // pass 1: assign backends to ops with pre-allocated inputs for (int i = 0; i < graph->n_leafs; i++) { struct ggml_tensor * leaf = graph->leafs[i]; - if (node_allocr(leaf) != NULL) { + int * leaf_backend_id = &tensor_backend_id(leaf); + if (*leaf_backend_id != -1) { // do not overwrite user assignments continue; } - ggml_backend_t leaf_backend = ggml_get_backend(leaf); - if (leaf_backend == NULL && leaf->view_src != NULL) { - leaf_backend = ggml_get_backend(leaf->view_src); - } - if (leaf_backend != NULL) { - node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend); - } + *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf); } for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; - if (node_allocr(node) != NULL) { + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { // do not overwrite user assignments continue; } - ggml_backend_t node_backend = sched_backend_from_cur(sched, node); - if (node_backend != NULL) { - node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend); + *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node); + // src + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + int * src_backend_id = &tensor_backend_id(src); + if (*src_backend_id == -1) { + *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src); + } } } - //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#ifdef DEBUG_PASS1 + fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); +#endif - // pass 2: assign backends to ops from current assignments - // TODO: - // - reuse sched_backend_from_cur - for (int i = 0; i < graph->n_nodes; i++) { - struct ggml_tensor * node = graph->nodes[i]; - ggml_tallocr_t node_allocr = node_allocr(node); - if (node_allocr == NULL) { - int cur_prio = INT_MAX; - size_t cur_size = 0; - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * src = node->src[j]; - if (src == NULL) { - break; + // pass 2: expand current backend assignments + // assign the same backend to adjacent nodes + // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend) + // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops + + + // pass 2.2 expand gpu down + { + int cur_backend_id = -1; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + if (*node_backend_id == sched->n_backends - 1) { + // skip cpu (lowest prio backend) + cur_backend_id = -1; + } else { + cur_backend_id = *node_backend_id; } - ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr != NULL) { - int src_prio = sched_allocr_prio(sched, src_allocr); - size_t src_size = ggml_nbytes(src); - if (src_prio < cur_prio && src_size >= cur_size) { - cur_prio = src_prio; - cur_size = src_size; - node_allocr = src_allocr; - sprintf(causes[hash_id(node)], "2.src%d", j); - } + } else { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.2"); + } + } + } + // pass 2.1 expand gpu up + { + int cur_backend_id = -1; + for (int i = graph->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + if (*node_backend_id == sched->n_backends - 1) { + // skip cpu (lowest prio backend) + cur_backend_id = -1; + } else { + cur_backend_id = *node_backend_id; } + } else { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.1"); } - if (node_allocr != NULL) { - node_allocr(node) = node_allocr; + } + } + // pass 2.4 expand rest down + { + int cur_backend_id = -1; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + cur_backend_id = *node_backend_id; + } else { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.4"); + } + } + } + // pass 2.3 expand rest up + { + int cur_backend_id = -1; + for (int i = graph->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + cur_backend_id = *node_backend_id; + } else { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.3"); } } } - //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); - // pass 3: assign backends to remaining src from dst (should only be leafs) +#ifdef DEBUG_PASS2 + fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); +#endif + + // pass 3: assign backends to remaining src from dst and view_src for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; - ggml_tallocr_t node_allocr = node_allocr(node); + int * cur_backend_id = &tensor_backend_id(node); + if (node->view_src != NULL && *cur_backend_id == -1) { + *cur_backend_id = tensor_backend_id(node->view_src); + SET_CAUSE(node, "3.vsrc"); + } for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { - break; + continue; } - ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr == NULL) { - node_allocr(src) = node_allocr; + int * src_backend_id = &tensor_backend_id(src); + if (*src_backend_id == -1) { + if (src->view_src != NULL) { + // views are always on the same backend as the source + *src_backend_id = tensor_backend_id(src->view_src); + SET_CAUSE(src, "3.vsrc"); + } else { + *src_backend_id = *cur_backend_id; + SET_CAUSE(src, "3.cur"); + } } } } - //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#ifdef DEBUG_PASS3 + fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); +#endif // pass 4: split graph, find tensors that need to be copied - // TODO: - // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost - // find first backend - int cur_split = 0; - for (int i = 0; i < graph->n_nodes; i++) { - struct ggml_tensor * node = graph->nodes[i]; - if (node->view_src == NULL) { - sched->splits[0].tallocr = node_allocr(node); - break; - } - } - sched->splits[0].i_start = 0; - sched->splits[0].n_inputs = 0; - memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK - ggml_tallocr_t cur_allocr = sched->splits[0].tallocr; - size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr); - for (int i = 0; i < graph->n_nodes; i++) { - struct ggml_tensor * node = graph->nodes[i]; - - if (ggml_is_view_op(node->op)) { - continue; + { + int i_split = 0; + struct ggml_backend_sched_split * split = &sched->splits[0]; + // find the backend of the first split, skipping view ops + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (!ggml_is_view_op(node->op)) { + split->backend_id = tensor_backend_id(node); + break; + } } + split->i_start = 0; + split->n_inputs = 0; + memset(split->inputs, 0, sizeof(split->inputs)); //HACK + int cur_backend_id = split->backend_id; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + + if (ggml_is_view_op(node->op)) { + continue; + } - ggml_tallocr_t node_allocr = node_allocr(node); + const int node_backend_id = tensor_backend_id(node); - if (node_allocr != cur_allocr) { - sched->splits[cur_split].i_end = i; - cur_split++; - GGML_ASSERT(cur_split < GGML_MAX_SPLITS); - sched->splits[cur_split].tallocr = node_allocr; - sched->splits[cur_split].i_start = i; - sched->splits[cur_split].n_inputs = 0; - memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK - cur_allocr = node_allocr; - cur_backend_id = sched_allocr_prio(sched, cur_allocr); - } + GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now - // find inputs that are not on the same backend - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * src = node->src[j]; - if (src == NULL) { - break; + // check if we should start a new split based on the sources of the current node + bool need_new_split = false; + if (node_backend_id == cur_backend_id && split->n_inputs > 0) { + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + // check if a weight is on a different backend + // by starting a new split, the memory of the previously offloaded weights can be reused + if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + int src_backend_id = tensor_backend_id(src); + if (src_backend_id != -1 && src_backend_id != cur_backend_id) { + need_new_split = true; + break; + } + } + // check if the split has too many inputs + if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) { + const size_t id = hash_id(src); + int src_backend_id = sched->tensor_backend_id[id]; + if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) { + //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name); + need_new_split = true; + break; + } + } + } } - ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr != node_allocr) { - int n_inputs = sched->splits[cur_split].n_inputs++; - GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS); - sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src; - - // create copies - size_t id = hash_id(src); - if (sched->node_copies[id][cur_backend_id] == NULL) { - struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); - sched->node_copies[id][cur_backend_id] = tensor_copy; - node_allocr(tensor_copy) = cur_allocr; - ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend; - ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name); + + if (node_backend_id != cur_backend_id || need_new_split) { + split->i_end = i; + i_split++; + if (i_split >= sched->splits_capacity) { + sched->splits_capacity *= 2; + sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); + GGML_ASSERT(sched->splits != NULL); } - node->src[j] = sched->node_copies[id][cur_backend_id]; + GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS); + split = &sched->splits[i_split]; + split->backend_id = node_backend_id; + split->i_start = i; + split->n_inputs = 0; + cur_backend_id = node_backend_id; } - } - } - sched->splits[cur_split].i_end = graph->n_nodes; - sched->n_splits = cur_split + 1; - //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout); + // find inputs that are not on the same backend + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } -#if 1 - // sanity check: all sources should have the same backend as the node - for (int i = 0; i < graph->n_nodes; i++) { - struct ggml_tensor * node = graph->nodes[i]; - ggml_tallocr_t node_allocr = node_allocr(node); - if (node_allocr == NULL) { - fprintf(stderr, "!!!!!!! %s has no backend\n", node->name); - } - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * src = node->src[j]; - if (src == NULL) { - break; - } - ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now - fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n", - node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL", - j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL"); + const int src_backend_id = tensor_backend_id(src); + assert(src_backend_id != -1); // all inputs should be assigned by now + + if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) { + size_t id = hash_id(src); + if (sched->tensor_copies[id][src_backend_id][0] == NULL) { + ggml_backend_t backend = sched->backends[src_backend_id]; + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * tensor_copy; + if (c == sched->cur_copy) { + tensor_copy = src; // use the original tensor as the current copy + } else { + tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); + ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); + } + if (sched->n_copies > 1) { + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor + } + sched->tensor_copies[id][src_backend_id][c] = tensor_copy; + SET_CAUSE(tensor_copy, "4.cpy"); + } + int n_graph_inputs = sched->n_graph_inputs++; + GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS); + sched->graph_inputs[n_graph_inputs] = src; + } + } + + if (src_backend_id != node_backend_id) { + // create a copy of the input in the split's backend + const size_t id = hash_id(src); + if (sched->tensor_copies[id][cur_backend_id][0] == NULL) { + ggml_backend_t backend = sched->backends[cur_backend_id]; + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); + ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); + if (sched->n_copies > 1) { + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor + } + sched->tensor_copies[id][cur_backend_id][c] = tensor_copy; + SET_CAUSE(tensor_copy, "4.cpy"); + } + int n_inputs = split->n_inputs++; + GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS); + split->inputs[n_inputs] = src; + } + node->src[j] = sched->tensor_copies[id][cur_backend_id][sched->cur_copy]; + } } } + split->i_end = graph->n_nodes; + sched->n_splits = i_split + 1; } +#ifdef DEBUG_PASS4 + fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); #endif // create copies of the graph for each split - // FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way - struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false); + // TODO: avoid this copy + struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2, false); for (int i = 0; i < sched->n_splits; i++) { struct ggml_backend_sched_split * split = &sched->splits[i]; - split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end); + split->graph = ggml_graph_view(graph, split->i_start, split->i_end); // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split for (int j = 0; j < split->n_inputs; j++) { + assert(graph_copy->size > (graph_copy->n_nodes + 1)); + struct ggml_tensor * input = split->inputs[j]; - struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)]; - input_cpy->src[0] = input; + const size_t input_id = hash_id(input); + struct ggml_tensor * input_cpy = sched->tensor_copies[input_id][split->backend_id][sched->cur_copy]; + + // add a dependency to the input source so that it is not freed before the copy is done + struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input); + input_dep->src[0] = input; + sched->node_backend_ids[graph_copy->n_nodes] = sched->tensor_backend_id[input_id]; + graph_copy->nodes[graph_copy->n_nodes++] = input_dep; + + // add a dependency to the input copy so that it is allocated at the start of the split + sched->node_backend_ids[graph_copy->n_nodes] = split->backend_id; graph_copy->nodes[graph_copy->n_nodes++] = input_cpy; } for (int j = split->i_start; j < split->i_end; j++) { + assert(graph_copy->size > graph_copy->n_nodes); + sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]); graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j]; } } + + if (sched->n_copies > 1) { + // add input copies as leafs so that they are allocated first + for (int i = 0; i < sched->n_graph_inputs; i++) { + struct ggml_tensor * input = sched->graph_inputs[i]; + size_t id = hash_id(input); + int backend_id = tensor_backend_id(input); + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c]; + sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; + graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; + } + } + + for (int i = 0; i < sched->n_splits; i++) { + struct ggml_backend_sched_split * split = &sched->splits[i]; + int backend_id = split->backend_id; + for (int j = 0; j < split->n_inputs; j++) { + struct ggml_tensor * input = split->inputs[j]; + size_t id = hash_id(input); + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c]; + sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; + graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; + } + } + } + } + + // add leafs from the original graph + for (int i = 0; i < graph->n_leafs; i++) { + struct ggml_tensor * leaf = graph->leafs[i]; + sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf); + graph_copy->leafs[graph_copy->n_leafs++] = leaf; + } + sched->graph = graph_copy; } -static void sched_alloc_splits(ggml_backend_sched_t sched) { - ggml_gallocr_alloc_graph_n( - sched->galloc, - sched->graph, - sched->hash_set, - sched->node_talloc); -} +static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { + // allocate graph + if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) { + // the re-allocation may cause the split inputs to be moved to a different address + ggml_backend_sched_synchronize(sched); +#ifndef NDEBUG + fprintf(stderr, "%s: failed to allocate graph, reserving\n", __func__); +#endif + ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); + if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) { + fprintf(stderr, "%s: failed to allocate graph\n", __func__); + return false; + } + } -static void sched_compute_splits(ggml_backend_sched_t sched) { - uint64_t copy_us[GGML_MAX_BACKENDS] = {0}; - uint64_t compute_us[GGML_MAX_BACKENDS] = {0}; + return true; +} +static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { struct ggml_backend_sched_split * splits = sched->splits; for (int i = 0; i < sched->n_splits; i++) { struct ggml_backend_sched_split * split = &splits[i]; - ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend; - int split_backend_id = sched_backend_prio(sched, split_backend); + int split_backend_id = split->backend_id; + ggml_backend_t split_backend = sched->backends[split_backend_id]; // copy the input tensors to the split backend - uint64_t copy_start_us = ggml_time_us(); for (int j = 0; j < split->n_inputs; j++) { - struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)]; - if (split->inputs[j]->buffer == NULL) { - if (split->inputs[j]->view_src == NULL) { - fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name); - exit(1); + ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]); + struct ggml_tensor * input = split->inputs[j]; + struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy]; + + if (input->flags & GGML_TENSOR_FLAG_INPUT) { + // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); } - struct ggml_tensor * view = split->inputs[j]; - view->backend = view->view_src->backend; - view->buffer = view->view_src->buffer; - view->data = (char *)view->view_src->data + view->view_offs; - ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view); - } - if (input_cpy->buffer == NULL) { - fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name); - exit(1); + ggml_backend_tensor_copy(input, input_cpy); + } else { + // wait for the split backend to finish using the input before overwriting it + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); + } + ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); } - GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend); - GGML_ASSERT(input_cpy->buffer->backend == split_backend); - ggml_backend_tensor_copy(split->inputs[j], input_cpy); } - // ggml_backend_synchronize(split_backend); - int64_t copy_end_us = ggml_time_us(); - copy_us[split_backend_id] += copy_end_us - copy_start_us; -#if 0 - char split_filename[GGML_MAX_NAME]; - snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend)); - ggml_graph_dump_dot(split->graph, NULL, split_filename); -#endif + if (!sched->callback_eval) { + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); + if (ec != GGML_STATUS_SUCCESS) { + return ec; + } + } else { + // similar to ggml_backend_compare_graph_backend + for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { + struct ggml_tensor * t = split->graph.nodes[j0]; - uint64_t compute_start_us = ggml_time_us(); - ggml_backend_graph_compute(split_backend, split->graph); - // ggml_backend_synchronize(split_backend); - uint64_t compute_end_us = ggml_time_us(); - compute_us[split_backend_id] += compute_end_us - compute_start_us; - } + // check if the user needs data from this node + bool need = sched->callback_eval(t, true, sched->callback_eval_user_data); -#if 0 - // per-backend timings - fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits); - for (int i = 0; i < sched->n_backends; i++) { - if (copy_us[i] > 0 || compute_us[i] > 0) { - fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]); + int j1 = j0; + + // determine the range [j0, j1] of nodes that can be computed together + while (!need && j1 < split->graph.n_nodes - 1) { + t = split->graph.nodes[++j1]; + need = sched->callback_eval(t, true, sched->callback_eval_user_data); + } + + struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); + + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv); + if (ec != GGML_STATUS_SUCCESS) { + return ec; + } + + // TODO: pass backend to the callback, then the user can decide if they want to synchronize + ggml_backend_synchronize(split_backend); + + if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) { + break; + } + + j0 = j1; + } } - } -#endif -} -static void sched_reset(ggml_backend_sched_t sched) { - for (int i = 0; i < sched->n_backends; i++) { - ggml_tallocr_reset(sched->tallocs[i]); + // record the event of this copy + if (split->n_inputs > 0) { + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]); + } + } } + + sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies; + + return GGML_STATUS_SUCCESS; } -ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) { - GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS); +ggml_backend_sched_t ggml_backend_sched_new( + ggml_backend_t * backends, + ggml_backend_buffer_type_t * bufts, + int n_backends, + size_t graph_size, + bool parallel) { + GGML_ASSERT(n_backends > 0); + GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); + GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU + + struct ggml_backend_sched * sched = calloc(sizeof(struct ggml_backend_sched), 1); - struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched)); - memset(sched, 0, sizeof(struct ggml_backend_sched)); + // initialize hash table + sched->hash_set = ggml_hash_set_new(graph_size); + sched->tensor_backend_id = calloc(sizeof(sched->tensor_backend_id[0]), sched->hash_set.size); + sched->tensor_copies = calloc(sizeof(sched->tensor_copies[0]), sched->hash_set.size); - fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024); + const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2; + sched->node_backend_ids = calloc(sizeof(sched->node_backend_ids[0]), nodes_size); + sched->leaf_backend_ids = calloc(sizeof(sched->leaf_backend_ids[0]), nodes_size); sched->n_backends = n_backends; - for (int i = 0; i < n_backends; i++) { - sched->backends[i] = backends[i]; - } - sched->galloc = ggml_gallocr_new(); + sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1; - // init measure allocs for each backend - for (int i = 0; i < n_backends; i++) { - sched->tallocs[i] = ggml_tallocr_new_measure_from_backend(backends[i]); + const int initial_splits_capacity = 16; + sched->splits = calloc(sizeof(sched->splits[0]), initial_splits_capacity); + sched->splits_capacity = initial_splits_capacity; + + for (int b = 0; b < n_backends; b++) { + sched->backends[b] = backends[b]; + sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]); + GGML_ASSERT(ggml_backend_buft_supports_backend(sched->bufts[b], backends[b])); + if (sched->n_copies > 1) { + for (int c = 0; c < sched->n_copies; c++) { + sched->events[b][c] = ggml_backend_event_new(backends[b]); + } + } } + sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); + + ggml_backend_sched_reset(sched); + return sched; } @@ -893,58 +1762,334 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { if (sched == NULL) { return; } - for (int i = 0; i < sched->n_backends; i++) { - ggml_tallocr_free(sched->tallocs[i]); + for (int b = 0; b < sched->n_backends; b++) { + for (int c = 0; c < sched->n_copies; c++) { + ggml_backend_event_free(sched->events[b][c]); + } } ggml_gallocr_free(sched->galloc); + ggml_free(sched->ctx); + free(sched->splits); free(sched->hash_set.keys); - free(sched->node_talloc); - free(sched->node_copies); + free(sched->tensor_backend_id); + free(sched->tensor_copies); + free(sched->node_backend_ids); + free(sched->leaf_backend_ids); free(sched); } -void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { - // initialize hash tables - size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS; - sched->hash_set.size = hash_size; - sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size); - sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size); - sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size); +void ggml_backend_sched_reset(ggml_backend_sched_t sched) { + // reset state for the next run + size_t hash_size = sched->hash_set.size; + memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT + memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size); + memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size); + + sched->is_reset = true; + sched->is_alloc = false; +} + +bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { + GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes); - sched_split_graph(sched, measure_graph); - sched_alloc_splits(sched); + ggml_backend_sched_split_graph(sched, measure_graph); - // allocate buffers and reset allocators - for (int i = 0; i < sched->n_backends; i++) { - size_t size = ggml_tallocr_max_size(sched->tallocs[i]); - ggml_tallocr_free(sched->tallocs[i]); - sched->tallocs[i] = ggml_tallocr_new_from_backend(sched->backends[i], size); + // TODO: extract this to a separate function + if (!ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { + return false; + } + + ggml_backend_sched_reset(sched); + ggml_backend_sched_synchronize(sched); + + return true; +} + +bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes); + + ggml_backend_sched_split_graph(sched, graph); + + if (!ggml_backend_sched_alloc_splits(sched)) { + return false; + } + + sched->is_alloc = true; + + return true; +} + +enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph); + ggml_backend_sched_synchronize(sched); + return err; +} + +enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + if (!sched->is_reset && !sched->is_alloc) { + ggml_backend_sched_reset(sched); + } + + if (!sched->is_alloc) { + if (!ggml_backend_sched_alloc_graph(sched, graph)) { + return GGML_STATUS_ALLOC_FAILED; + } } - sched_reset(sched); + return ggml_backend_sched_compute_splits(sched); +} + +void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { + for (int i = 0; i < sched->n_backends; i++) { + ggml_backend_synchronize(sched->backends[i]); + } } -void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { - GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS); +void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { + sched->callback_eval = callback; + sched->callback_eval_user_data = user_data; +} - sched_split_graph(sched, graph); - sched_alloc_splits(sched); - sched_compute_splits(sched); - sched_reset(sched); +int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) { + return sched->n_splits; } -ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) { - int backend_index = sched_backend_prio(sched, backend); - return sched->tallocs[backend_index]; +int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) { + return sched->n_copies; } -ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) { - int backend_index = sched_backend_prio(sched, backend); - return ggml_tallocr_get_buffer(sched->tallocs[backend_index]); +size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + + return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } -void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { - int backend_index = sched_backend_prio(sched, backend); +void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); - node_allocr(node) = sched->tallocs[backend_index]; + tensor_backend_id(node) = backend_index; +} + +ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) { + int backend_index = tensor_backend_id(node); + if (backend_index == -1) { + return NULL; + } + return sched->backends[backend_index]; +} + +// utils + +void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->buffer == NULL); + GGML_ASSERT(tensor->view_src != NULL); + GGML_ASSERT(tensor->view_src->buffer != NULL); + GGML_ASSERT(tensor->view_src->data != NULL); + + tensor->buffer = buffer; + tensor->data = (char *)tensor->view_src->data + tensor->view_offs; + tensor->backend = tensor->view_src->backend; + ggml_backend_buffer_init_tensor(buffer, tensor); +} + +void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { + GGML_ASSERT(tensor->buffer == NULL); + GGML_ASSERT(tensor->data == NULL); + GGML_ASSERT(tensor->view_src == NULL); + GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer)); + GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= + (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); + + tensor->buffer = buffer; + tensor->data = addr; + ggml_backend_buffer_init_tensor(buffer, tensor); +} + +static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, + struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) { + + GGML_ASSERT(src != NULL); + GGML_ASSERT(src->data && "graph must be allocated"); + + size_t id = ggml_hash_insert(hash_set, src); + if (id == GGML_HASHTABLE_ALREADY_EXISTS) { + return node_copies[ggml_hash_find(hash_set, src)]; + } + + struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src); + if (src->view_src != NULL) { + dst->view_src = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src); + dst->view_offs = src->view_offs; + } + dst->op = src->op; + memcpy(dst->op_params, src->op_params, sizeof(dst->op_params)); + ggml_set_name(dst, src->name); + + // copy src + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * s = src->src[i]; + if (s == NULL) { + continue; + } + dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s); + } + + node_copies[id] = dst; + return dst; +} + +static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) { + size_t id = ggml_hash_find(hash_set, src); + if (node_init[id]) { + return; + } + node_init[id] = true; + + struct ggml_tensor * dst = node_copies[id]; + if (dst->view_src != NULL) { + graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src); + ggml_backend_view_init(dst->view_src->buffer, dst); + } + else { + ggml_backend_tensor_copy(src, dst); + } + + // init src + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * s = src->src[i]; + if (s == NULL) { + continue; + } + graph_copy_init_tensor(hash_set, node_copies, node_init, s); + } +} + +struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { + struct ggml_hash_set hash_set = { + /* .size = */ graph->visited_hash_table.size, + /* .keys = */ calloc(sizeof(hash_set.keys[0]), graph->visited_hash_table.size) // NOLINT + }; + struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]), hash_set.size); // NOLINT + bool * node_init = calloc(sizeof(node_init[0]), hash_set.size); + + struct ggml_init_params params = { + /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false), + /* .mem_buffer = */ NULL, + /* .no_alloc = */ true + }; + + struct ggml_context * ctx_allocated = ggml_init(params); + struct ggml_context * ctx_unallocated = ggml_init(params); + + if (ctx_allocated == NULL || ctx_unallocated == NULL) { + fprintf(stderr, "failed to allocate context for graph copy\n"); + free(hash_set.keys); + free(node_copies); + free(node_init); + ggml_free(ctx_allocated); + ggml_free(ctx_unallocated); + return (struct ggml_backend_graph_copy) { + /* .buffer = */ NULL, + /* .ctx_allocated = */ NULL, + /* .ctx_unallocated = */ NULL, + /* .graph = */ NULL, + }; + } + + // dup nodes + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node); + } + + // allocate nodes + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend); + if (buffer == NULL) { + fprintf(stderr, "failed to allocate buffer for graph copy\n"); + free(hash_set.keys); + free(node_copies); + free(node_init); + ggml_free(ctx_allocated); + ggml_free(ctx_unallocated); + return (struct ggml_backend_graph_copy) { + /* .buffer = */ NULL, + /* .ctx_allocated = */ NULL, + /* .ctx_unallocated = */ NULL, + /* .graph = */ NULL, + }; + } + + //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024); + + // copy data and init views + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + graph_copy_init_tensor(hash_set, node_copies, node_init, node); + } + + // build graph copy + struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false); + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)]; + graph_copy->nodes[i] = node_copy; + } + graph_copy->n_nodes = graph->n_nodes; + + free(hash_set.keys); + free(node_copies); + free(node_init); + + return (struct ggml_backend_graph_copy) { + /* .buffer = */ buffer, + /* .ctx_allocated = */ ctx_allocated, + /* .ctx_unallocated = */ ctx_unallocated, + /* .graph = */ graph_copy, + }; +} + +void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) { + ggml_backend_buffer_free(copy.buffer); + ggml_free(copy.ctx_allocated); + ggml_free(copy.ctx_unallocated); +} + +bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) { + struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); + if (copy.buffer == NULL) { + return false; + } + + struct ggml_cgraph * g1 = graph; + struct ggml_cgraph * g2 = copy.graph; + + assert(g1->n_nodes == g2->n_nodes); + + for (int i = 0; i < g1->n_nodes; i++) { + //printf("eval %d/%d\n", i, g1->n_nodes); + struct ggml_tensor * t1 = g1->nodes[i]; + struct ggml_tensor * t2 = g2->nodes[i]; + + assert(t1->op == t2->op && ggml_are_same_layout(t1, t2)); + + struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); + struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); + + ggml_backend_graph_compute(backend1, &g1v); + ggml_backend_graph_compute(backend2, &g2v); + + if (ggml_is_view_op(t1->op)) { + continue; + } + + // compare results, calculate rms etc + if (!callback(i, t1, t2, user_data)) { + break; + } + } + + ggml_backend_graph_copy_free(copy); + + return true; } diff --git a/bindings/ruby/ext/ggml-backend.h b/bindings/ruby/ext/ggml-backend.h index 793a0a9d65a..744b6a77457 100644 --- a/bindings/ruby/ext/ggml-backend.h +++ b/bindings/ruby/ext/ggml-backend.h @@ -7,69 +7,123 @@ extern "C" { #endif + typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t; + typedef struct ggml_backend_buffer * ggml_backend_buffer_t; + typedef struct ggml_backend_event * ggml_backend_event_t; + typedef struct ggml_backend * ggml_backend_t; + typedef void * ggml_backend_graph_plan_t; + // // Backend buffer // - struct ggml_backend_buffer; - typedef struct ggml_backend_buffer * ggml_backend_buffer_t; - - // backend buffer functions - GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); - GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + // buffer type + GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft); + GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); + GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); + GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); + GGML_API GGML_CALL size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); + GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend); + GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft); + + // buffer + enum ggml_backend_buffer_usage { + GGML_BACKEND_BUFFER_USAGE_ANY = 0, + GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1, + }; + + GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); + GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); + GGML_API GGML_CALL void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value); + GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); // // Backend // - struct ggml_backend; - typedef struct ggml_backend * ggml_backend_t; - typedef void * ggml_backend_graph_plan_t; - - GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor); - + GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend); GGML_API const char * ggml_backend_name(ggml_backend_t backend); GGML_API void ggml_backend_free(ggml_backend_t backend); - GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); + GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend); + GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); + GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); + GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend); - GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); + GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - - GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); GGML_API void ggml_backend_synchronize(ggml_backend_t backend); - GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); - GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); - GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan); - GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); - GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op); + GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op); + GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op); // tensor copy between different backends GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + // asynchronous copy + // the copy is performed after all the currently queued operations in backend_src + // backend_dst will wait for the copy to complete before performing other operations + // automatic fallback to sync copy if async is not supported + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst); + + // events + GGML_API ggml_backend_event_t ggml_backend_event_new (ggml_backend_t backend); + GGML_API void ggml_backend_event_free (ggml_backend_event_t event); + GGML_API void ggml_backend_event_record (ggml_backend_event_t event); + GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event); + GGML_API void ggml_backend_event_wait (ggml_backend_t backend, ggml_backend_event_t event); // wait async on event + // // CPU backend // GGML_API ggml_backend_t ggml_backend_cpu_init(void); - GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend); - GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); + GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend); + GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); + GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); // Create a backend buffer from an existing pointer - GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size); + GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); + + GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void); + +#ifdef GGML_USE_CPU_HBM + GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void); +#endif + + // + // Backend registry + // + + // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way + GGML_API size_t ggml_backend_reg_get_count(void); + GGML_API size_t ggml_backend_reg_find_by_name(const char * name); + GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params] + GGML_API const char * ggml_backend_reg_get_name(size_t i); + GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific + GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i); + GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size); // // Backend scheduler @@ -83,53 +137,96 @@ extern "C" { /* Example usage: - sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends); - // sched is initialized with measure allocators and cannot be used until allocated with a measure graph + // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned + // preferrably to run on the same backend as the buffer + ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - // initialize buffers from a measure graph - measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed + sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false); - // in build_graph: - build_graph(...) { - // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer) - alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu); - ggml_allocr_alloc(alloc_cpu, tensor); + // initialize buffers from a max size graph (optional) + reserve_graph = build_graph(sched, max_batch_size); - // manually assigning nodes to a backend (optional, shouldn't be needed in most cases) - struct ggml_tensor * node = ggml_mul_mat(ctx, ...); - ggml_backend_sched_set_node_backend(sched, node, backend_gpu); - } + // manually assign nodes to a backend (optional, should not be needed in most cases) + struct ggml_tensor * node = ggml_mul_mat(ctx, ...); + ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu); - // allocate backend buffers from measure graph - ggml_backend_sched_init_measure(sched, measure_graph); - - // the scheduler is now ready to compute graphs + ggml_backend_sched_reserve(sched, reserve_graph); // compute graph = build_graph(sched); ggml_backend_sched_graph_compute(sched, graph); + + // if there are graph inputs: + ggml_backend_sched_reset(sched); + ggml_backend_sched_alloc_graph(sched, graph); + ggml_backend_tensor_set(input_tensor, ...); + ggml_backend_sched_graph_compute(sched, graph); + } */ struct ggml_backend_sched; typedef struct ggml_backend_sched * ggml_backend_sched_t; - // Initialize a backend scheduler - GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends); + // when ask == true, the scheduler wants to know if the user wants to observe this node + // this allows the scheduler to batch nodes together in order to evaluate them in a single call + // + // when ask == false, the scheduler is passing the node tensor to the user for observation + // if the user returns false, the scheduler will cancel the graph compute + // + typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); - GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); + // Initialize a backend scheduler + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); + GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph - GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); + GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); + + // Get the number of splits of the last graph + GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched); + GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched); + + GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + + GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); + GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); + + // Allocate and compute graph on the backend scheduler + GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched); + + // Reset all assignments and allocators - must be called before changing the node backends + GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); + + // Set a callback to be called for each resulting node during graph compute + GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); + + // + // Utils + // + + struct ggml_backend_graph_copy { + ggml_backend_buffer_t buffer; + struct ggml_context * ctx_allocated; + struct ggml_context * ctx_unallocated; + struct ggml_cgraph * graph; + }; + + // Copy a graph to a different backend + GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph); + GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy); + + typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); - GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend); - GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend); + // Compare the output of two backends + GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data); - GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); + // Tensor initialization + GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); + GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - // Allocate a graph on the backend scheduler - GGML_API void ggml_backend_sched_graph_compute( - ggml_backend_sched_t sched, - struct ggml_cgraph * graph); #ifdef __cplusplus } diff --git a/bindings/ruby/ext/ggml-common.h b/bindings/ruby/ext/ggml-common.h new file mode 100644 index 00000000000..43c7978a098 --- /dev/null +++ b/bindings/ruby/ext/ggml-common.h @@ -0,0 +1,1853 @@ +#ifndef GGML_COMMON_DECL + +#if defined(GGML_COMMON_DECL_C) +#include + +typedef uint16_t ggml_half; +typedef uint32_t ggml_half2; + +#define GGML_COMMON_AGGR + +#define GGML_COMMON_DECL +#elif defined(GGML_COMMON_DECL_METAL) +#include + +typedef half ggml_half; +typedef half2 ggml_half2; + +#define GGML_COMMON_AGGR + +#define GGML_COMMON_DECL +#elif defined(GGML_COMMON_DECL_CUDA) +#include +#include + +typedef half ggml_half; +typedef half2 ggml_half2; + +#define GGML_COMMON_AGGR data + +#define GGML_COMMON_DECL +#elif defined(GGML_COMMON_DECL_HIP) +#include +#include + +typedef half ggml_half; +typedef half2 ggml_half2; + +#define GGML_COMMON_AGGR data + +#define GGML_COMMON_DECL +#elif defined(GGML_COMMON_DECL_SYCL) +#include +#include + +typedef sycl::half ggml_half; +typedef sycl::half2 ggml_half2; + +#define GGML_COMMON_AGGR data + +#define GGML_COMMON_DECL +#endif + +#if defined(GGML_COMMON_DECL) + +#ifndef __cplusplus +#ifndef static_assert +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) +#define static_assert(cond, msg) _Static_assert(cond, msg) +#else +#define static_assert(cond, msg) struct global_scope_noop_trick +#endif +#endif +#endif // __cplusplus + +// QK = number of values after dequantization +// QK_K = super-block size + +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else +#define QK_K 256 +#define K_SCALE_SIZE 12 +#endif // GGML_QKK_64 + +#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL) +// QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization + +#define QI4_0 (QK4_0 / (4 * QR4_0)) +#define QR4_0 2 + +#define QI4_1 (QK4_1 / (4 * QR4_1)) +#define QR4_1 2 + +#define QI5_0 (QK5_0 / (4 * QR5_0)) +#define QR5_0 2 + +#define QI5_1 (QK5_1 / (4 * QR5_1)) +#define QR5_1 2 + +#define QI8_0 (QK8_0 / (4 * QR8_0)) +#define QR8_0 1 + +#define QI8_1 (QK8_1 / (4 * QR8_1)) +#define QR8_1 1 + +#define QI2_K (QK_K / (4*QR2_K)) +#define QR2_K 4 + +#define QI3_K (QK_K / (4*QR3_K)) +#define QR3_K 4 + +#define QI4_K (QK_K / (4*QR4_K)) +#define QR4_K 2 + +#define QI5_K (QK_K / (4*QR5_K)) +#define QR5_K 2 + +#define QI6_K (QK_K / (4*QR6_K)) +#define QR6_K 2 + +#define QI2_XXS (QK_K / (4*QR2_XXS)) +#define QR2_XXS 8 + +#define QI2_XS (QK_K / (4*QR2_XS)) +#define QR2_XS 8 + +#define QI2_S (QK_K / (4*QR2_S)) +#define QR2_S 8 + +#define QI3_XXS (QK_K / (4*QR3_XXS)) +#define QR3_XXS 8 + +#define QI3_XS (QK_K / (4*QR3_XS)) +#define QR3_XS 8 + +#define QI1_S (QK_K / (4*QR1_S)) +#define QR1_S 8 + +#define QI4_NL (QK4_NL / (4*QR4_NL)) +#define QR4_NL 2 + +#if QK_K == 64 +#define QI4_XS QI4_NL +#define QR4_XS QR4_NL +#else +#define QI4_XS (QK_K / (4*QR4_XS)) +#define QR4_XS 8 +#endif + +#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP + +#define QK4_0 32 +typedef struct { + ggml_half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + union { + struct { + ggml_half d; // delta + ggml_half m; // min + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +typedef struct { + ggml_half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + union { + struct { + ggml_half d; // delta + ggml_half m; // min + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +typedef struct { + ggml_half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +typedef struct { + union { + struct { + ggml_half d; // delta + ggml_half s; // d * sum(qs[i]) + } GGML_COMMON_AGGR; + ggml_half2 ds; + }; + int8_t qs[QK8_1]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding"); + +// +// Super-block quantization structures +// + +// 2-bit quantization +// weight is represented as x = a * q + b +// 16 blocks of 16 elements each +// Effectively 2.625 bits per weight +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + union { + struct { + ggml_half d; // super-block scale for quantized scales + ggml_half dmin; // super-block scale for quantized mins + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +// 3-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 3.4375 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[2]; + ggml_half d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); +#else +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[12]; // scales, quantized with 6 bits + ggml_half d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); +#endif + +// 4-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_half d[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else +typedef struct { + union { + struct { + ggml_half d; // super-block scale for quantized scales + ggml_half dmin; // super-block scale for quantized mins + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); +#endif + +// 5-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 5.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_half d; // super-block scale + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_half) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + union { + struct { + ggml_half d; // super-block scale for quantized scales + ggml_half dmin; // super-block scale for quantized mins + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + ggml_half d; // super-block scale +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); + +// This is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_K; +static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); + +// (Almost) "true" 2-bit quantization. +// Due to the need to use blocks as per ggml design, it ends up using +// 2.0625 bpw because of the 16-bit scale for each block of 256. +typedef struct { + ggml_half d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); + +// 2.3125 bpw quants +typedef struct { + ggml_half d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +static_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); + +// 2.5625 bpw quants +typedef struct { + ggml_half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; +} block_iq2_s; +static_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding"); + +// (Almost) "true" 3-bit quantization. +// Due to the need to use blocks as per ggml design, it ends up using +// 3.0625 bpw because of the 16-bit scale for each block of 256. +typedef struct { + ggml_half d; + uint8_t qs[3*QK_K/8]; +} block_iq3_xxs; +static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); + +// 3.4375 bpw +#if QK_K == 64 +#define IQ3S_N_SCALE 2 +#else +#define IQ3S_N_SCALE QK_K/64 +#endif +typedef struct { + ggml_half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t signs[QK_K/8]; + uint8_t scales[IQ3S_N_SCALE]; +} block_iq3_s; +static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding"); + +typedef struct { + ggml_half d; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; +} block_iq1_s; +static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); + +// 1.75 bpw +typedef struct { + uint8_t qs[QK_K/8]; // grid index, low 8 bits + uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8) +#if QK_K == 64 + ggml_half d; +#endif + uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64) +} block_iq1_m; +#if QK_K == 64 +static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32 + sizeof(ggml_half), "wrong iq1_m block size/padding"); +#else +static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding"); +#endif + +// Used by IQ1_M quants +typedef union { + ggml_half f16; + uint16_t u16; +} iq1m_scale_t; + +// Non-linear quants +#define QK4_NL 32 +typedef struct { + ggml_half d; + uint8_t qs[QK4_NL/2]; +} block_iq4_nl; +static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding"); + +#if QK_K == 64 +#define block_iq4_xs block_iq4_nl +#else +typedef struct { + ggml_half d; + uint16_t scales_h; + uint8_t scales_l[QK_K/64]; + uint8_t qs[QK_K/2]; +} block_iq4_xs; +static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +#endif + +#endif // GGML_COMMON_DECL +#endif // GGML_COMMON_DECL + +//////////////////////////////////////////////////////////////////////////////// + +#ifndef GGML_COMMON_IMPL + +#if defined(GGML_COMMON_IMPL_C) +#include + +#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { +#define GGML_TABLE_END() }; + +#define GGML_COMMON_IMPL +#elif defined(GGML_COMMON_IMPL_METAL) +#include + +#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = { +#define GGML_TABLE_END() }; + +#define GGML_COMMON_IMPL +#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) +#include + +#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = { +#define GGML_TABLE_END() }; + +#define GGML_COMMON_IMPL +#elif defined(GGML_COMMON_IMPL_SYCL) + +#include + +#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { +#define GGML_TABLE_END() }; + +#define GGML_COMMON_IMPL +#endif + +#if defined(GGML_COMMON_IMPL) + +GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8) + 1, 2, 4, 8, 16, 32, 64, 128 +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128) + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +GGML_TABLE_END() + +//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +GGML_TABLE_BEGIN(uint64_t, ksigns64, 128) + 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff, + 0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff, + 0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff, + 0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff, + 0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff, + 0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff, + 0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff, + 0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff, + 0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff, + 0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff, + 0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff, + 0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff, + 0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff, + 0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff, + 0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff, + 0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff, + 0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff, + 0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff, + 0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff, + 0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff, + 0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff, + 0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff, + 0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff, + 0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff, + 0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff, + 0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff, + 0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff, + 0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff, + 0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff, + 0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff, + 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff, + 0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff, +GGML_TABLE_END() +//#endif + + +GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint64_t, iq2xs_grid, 512) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint64_t, iq2s_grid, 1024) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, + 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, + 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808, + 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908, + 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b, + 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, + 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, + 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19, + 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819, + 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919, + 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, + 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, + 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908, + 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908, + 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, + 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919, + 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b, + 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, + 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908, + 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, + 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, + 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08, + 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, + 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819, + 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, + 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, + 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b, + 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908, + 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08, + 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, + 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08, + 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819, + 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, + 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, + 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b, + 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919, + 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808, + 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, + 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, + 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919, + 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808, + 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819, + 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, + 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, + 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, + 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, + 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919, + 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, + 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919, + 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b, + 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819, + 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919, + 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, + 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, + 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908, + 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b, + 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908, + 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, + 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, + 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819, + 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819, + 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808, + 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, + 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, + 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819, + 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808, + 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819, + 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, + 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, + 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19, + 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08, + 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b, + 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, + 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808, + 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819, + 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908, + 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, + 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, + 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819, + 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908, + 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08, + 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, + 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, + 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, + 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19, + 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, + 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, + 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908, + 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808, + 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808, + 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, + 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, + 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08, + 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08, + 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908, + 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, + 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, + 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819, + 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908, + 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08, + 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, + 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, + 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808, + 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819, + 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808, + 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, + 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b, + 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, + 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, + 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b, + 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808, + 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, + 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19, + 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819, + 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08, + 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, + 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, + 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b, + 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b, + 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919, + 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808, + 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, + 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908, + 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08, + 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08, + 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, + 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919, + 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908, + 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b, + 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908, + 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, + 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, + 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08, + 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819, + 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808, + 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, + 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919, + 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808, + 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808, + 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08, + 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819, + 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, + 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808, + 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819, + 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919, + 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, + 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, + 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908, + 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808, + 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908, + 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, + 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, + 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b, + 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908, + 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b, + 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, + 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, + 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908, + 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b, + 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908, + 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, + 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, + 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919, + 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808, + 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19, + 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, + 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919, + 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808, + 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819, + 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908, + 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, + 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, + 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808, + 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b, + 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919, + 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, + 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b, + 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808, + 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919, + 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b, + 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, + 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, + 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808, + 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b, + 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908, + 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, + 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, + 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808, + 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908, + 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808, + 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, + 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, + 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908, + 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808, + 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808, + 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, + 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b, + 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808, + 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819, + 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, + 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, + 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08, + 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908, + 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919, + 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, + 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, + 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808, + 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819, + 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, + 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, + 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808, + 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808, + 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808, + 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, + 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, + 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908, + 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08, + 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819, + 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, + 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, + 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819, + 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908, + 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819, + 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, + 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, + 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b, + 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908, + 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808, + 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, + 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, + 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819, + 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808, + 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b, + 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, + 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, + 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, + 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b, + 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b, + 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, + 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, + 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819, + 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908, + 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808, + 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint32_t, iq3xxs_grid, 256) + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +GGML_TABLE_END() + +#define NGRID_IQ1S 2048 +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f +#if defined(GGML_COMMON_IMPL_C) +GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S) + 0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff, + 0xffffffffffff0101, 0xffffffffff00ff00, 0xffffffffff000000, 0xffffffffff01ffff, + 0xffffffffff01ff01, 0xffffffffff0101ff, 0xffffffffff010101, 0xffffffff00ff0000, + 0xffffffff0000ff00, 0xffffffff000000ff, 0xffffffff00000001, 0xffffffff00010000, + 0xffffffff01ffffff, 0xffffffff01ffff01, 0xffffffff01ff01ff, 0xffffffff01ff0101, + 0xffffffff01000000, 0xffffffff0101ffff, 0xffffffff0101ff01, 0xffffffff010101ff, + 0xffffffff01010101, 0xffffff00ffff00ff, 0xffffff00ffff0000, 0xffffff00ff00ff00, + 0xffffff00ff0000ff, 0xffffff00ff000001, 0xffffff00ff000100, 0xffffff00ff000101, + 0xffffff00ff010000, 0xffffff0000ffff00, 0xffffff0000ff0001, 0xffffff0000ff0100, + 0xffffff000000ff01, 0xffffff0000000000, 0xffffff0000000101, 0xffffff000001ff00, + 0xffffff00000100ff, 0xffffff0000010001, 0xffffff00000101ff, 0xffffff0001ff0000, + 0xffffff000100ff00, 0xffffff00010000ff, 0xffffff0001000001, 0xffffff0001010000, + 0xffffff01ffffffff, 0xffffff01ffffff01, 0xffffff01ffff01ff, 0xffffff01ffff0101, + 0xffffff01ff000000, 0xffffff01ff01ffff, 0xffffff01ff01ff01, 0xffffff01ff0101ff, + 0xffffff01ff010101, 0xffffff0100ff0000, 0xffffff010000ff00, 0xffffff0100000100, + 0xffffff01000100ff, 0xffffff0100010100, 0xffffff0101ffffff, 0xffffff0101ffff01, + 0xffffff0101ff01ff, 0xffffff0101ff0101, 0xffffff010100ff00, 0xffffff0101000000, + 0xffffff0101000100, 0xffffff010101ffff, 0xffffff010101ff01, 0xffffff01010101ff, + 0xffffff0101010101, 0xffff00ffff00ff00, 0xffff00ffff0000ff, 0xffff00ffff000001, + 0xffff00ffff010000, 0xffff00ff00ffff00, 0xffff00ff00ff0100, 0xffff00ff00000000, + 0xffff00ff00000101, 0xffff00ff000100ff, 0xffff00ff00010000, 0xffff00ff0100ff00, + 0xffff00ff01000100, 0xffff00ff01010000, 0xffff0000ffffff00, 0xffff0000ffff00ff, + 0xffff0000ffff0000, 0xffff0000ffff0001, 0xffff0000ff000000, 0xffff0000ff0001ff, + 0xffff0000ff000101, 0xffff0000ff010100, 0xffff000000ffffff, 0xffff000000ff0000, + 0xffff000000ff0101, 0xffff00000000ffff, 0xffff00000000ff00, 0xffff0000000000ff, + 0xffff000000000000, 0xffff000000000001, 0xffff000000000100, 0xffff00000001ffff, + 0xffff00000001ff01, 0xffff000000010000, 0xffff0000000101ff, 0xffff000000010101, + 0xffff000001ffff00, 0xffff00000100ff00, 0xffff000001000000, 0xffff0000010001ff, + 0xffff000001000101, 0xffff00000101ff00, 0xffff0000010100ff, 0xffff000001010000, + 0xffff000001010001, 0xffff000001010100, 0xffff0001ff0000ff, 0xffff0001ff000100, + 0xffff000100ffff00, 0xffff000100ff00ff, 0xffff00010000ffff, 0xffff00010000ff01, + 0xffff000100000000, 0xffff0001000001ff, 0xffff00010001ffff, 0xffff00010001ff00, + 0xffff000100010001, 0xffff000100010100, 0xffff000101ff0000, 0xffff00010100ff00, + 0xffff0001010000ff, 0xffff000101000100, 0xffff01ffffffffff, 0xffff01ffffffff01, + 0xffff01ffffff01ff, 0xffff01ffffff0101, 0xffff01ffff000000, 0xffff01ffff01ffff, + 0xffff01ffff01ff01, 0xffff01ffff0101ff, 0xffff01ffff010101, 0xffff01ff00ff0000, + 0xffff01ff0000ff00, 0xffff01ff00000001, 0xffff01ff00010000, 0xffff01ff01ffffff, + 0xffff01ff01ffff01, 0xffff01ff01ff01ff, 0xffff01ff01ff0101, 0xffff01ff01000000, + 0xffff01ff0101ffff, 0xffff01ff0101ff01, 0xffff01ff010101ff, 0xffff01ff01010101, + 0xffff0100ffff0000, 0xffff0100ff00ff00, 0xffff0100ff0000ff, 0xffff0100ff000100, + 0xffff0100ff0100ff, 0xffff0100ff010000, 0xffff010000ffff00, 0xffff01000000ffff, + 0xffff01000000ff00, 0xffff010000000000, 0xffff01000001ff00, 0xffff0100000100ff, + 0xffff010000010100, 0xffff01000100ff00, 0xffff0100010000ff, 0xffff010001000001, + 0xffff010001000100, 0xffff010001010000, 0xffff0101ffffffff, 0xffff0101ffffff01, + 0xffff0101ffff01ff, 0xffff0101ffff0101, 0xffff0101ff000000, 0xffff0101ff01ffff, + 0xffff0101ff01ff01, 0xffff0101ff0101ff, 0xffff0101ff010101, 0xffff010100ff0000, + 0xffff01010000ff00, 0xffff010100000100, 0xffff01010001ff00, 0xffff010100010000, + 0xffff010101ffffff, 0xffff010101ffff01, 0xffff010101ff0000, 0xffff010101ff01ff, + 0xffff010101ff0101, 0xffff010101000000, 0xffff01010101ffff, 0xffff01010101ff01, + 0xffff0101010101ff, 0xffff010101010101, 0xff00ffffff00ffff, 0xff00ffffff00ff00, + 0xff00ffffff0000ff, 0xff00ffffff000100, 0xff00ffffff0100ff, 0xff00ffffff010000, + 0xff00ffff00ffff00, 0xff00ffff00ff00ff, 0xff00ffff0000ffff, 0xff00ffff00000000, + 0xff00ffff000001ff, 0xff00ffff0001ff00, 0xff00ffff000100ff, 0xff00ffff00010000, + 0xff00ffff00010100, 0xff00ffff0100ff00, 0xff00ffff010000ff, 0xff00ffff01000001, + 0xff00ffff0101ff00, 0xff00ffff01010000, 0xff00ff00ffffff00, 0xff00ff00ffff00ff, + 0xff00ff00ffff0001, 0xff00ff00ffff0100, 0xff00ff00ff00ffff, 0xff00ff00ff00ff01, + 0xff00ff00ff000000, 0xff00ff00ff0001ff, 0xff00ff00ff01ff00, 0xff00ff00ff0100ff, + 0xff00ff00ff010100, 0xff00ff0000ff0000, 0xff00ff0000ff0101, 0xff00ff000000ffff, + 0xff00ff000000ff00, 0xff00ff000000ff01, 0xff00ff00000000ff, 0xff00ff0000000000, + 0xff00ff0000000001, 0xff00ff0000000100, 0xff00ff000001ffff, 0xff00ff0000010000, + 0xff00ff0001ff00ff, 0xff00ff000100ff01, 0xff00ff0001000000, 0xff00ff000101ff00, + 0xff00ff00010100ff, 0xff00ff01ff00ff00, 0xff00ff01ff0000ff, 0xff00ff01ff000001, + 0xff00ff01ff010000, 0xff00ff0100ffffff, 0xff00ff0100ff0001, 0xff00ff0100ff0100, + 0xff00ff010000ff01, 0xff00ff0100000000, 0xff00ff01000001ff, 0xff00ff0100000101, + 0xff00ff01000100ff, 0xff00ff0100010001, 0xff00ff0101ff0000, 0xff00ff010100ff00, + 0xff00ff01010000ff, 0xff00ff0101000001, 0xff00ff0101010000, 0xff0000ffffffff00, + 0xff0000ffffff0001, 0xff0000ffffff0100, 0xff0000ffff0000ff, 0xff0000ffff000000, + 0xff0000ffff0001ff, 0xff0000ffff000100, 0xff0000ffff01ff00, 0xff0000ffff010001, + 0xff0000ff00ffff00, 0xff0000ff00ff0000, 0xff0000ff00ff0001, 0xff0000ff00ff01ff, + 0xff0000ff00ff0101, 0xff0000ff0000ff00, 0xff0000ff000000ff, 0xff0000ff00000000, + 0xff0000ff00000001, 0xff0000ff00000100, 0xff0000ff0001ff01, 0xff0000ff00010000, + 0xff0000ff000101ff, 0xff0000ff01ff00ff, 0xff0000ff01ff0100, 0xff0000ff0100ffff, + 0xff0000ff010000ff, 0xff0000ff01000000, 0xff0000ff010001ff, 0xff0000ff01000100, + 0xff0000ff01000101, 0xff0000ff0101ff00, 0xff0000ff010100ff, 0xff0000ff01010000, + 0xff0000ff01010100, 0xff000000ffffff01, 0xff000000ffff0000, 0xff000000ffff0101, + 0xff000000ff00ff00, 0xff000000ff0000ff, 0xff000000ff000000, 0xff000000ff000001, + 0xff000000ff000100, 0xff000000ff01ffff, 0xff000000ff01ff01, 0xff000000ff010000, + 0xff000000ff0101ff, 0xff000000ff010101, 0xff00000000ffff00, 0xff00000000ff00ff, + 0xff00000000ff0000, 0xff00000000ff0001, 0xff0000000000ff00, 0xff0000000000ff01, + 0xff000000000000ff, 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, + 0xff00000000000101, 0xff0000000001ff00, 0xff000000000100ff, 0xff00000000010000, + 0xff00000000010001, 0xff00000000010100, 0xff00000001ffffff, 0xff00000001ffff01, + 0xff00000001ff00ff, 0xff00000001ff0000, 0xff00000001ff01ff, 0xff00000001ff0101, + 0xff0000000100ffff, 0xff0000000100ff00, 0xff000000010000ff, 0xff00000001000000, + 0xff00000001000001, 0xff00000001000100, 0xff00000001000101, 0xff0000000101ffff, + 0xff0000000101ff01, 0xff00000001010000, 0xff000001ffffff00, 0xff000001ffff00ff, + 0xff000001ffff0000, 0xff000001ffff0001, 0xff000001ff000000, 0xff000001ff000001, + 0xff000001ff0001ff, 0xff000001ff000101, 0xff000001ff01ff00, 0xff000001ff010001, + 0xff00000100ffffff, 0xff00000100ffff01, 0xff00000100ff00ff, 0xff00000100ff0000, + 0xff00000100ff01ff, 0xff00000100ff0101, 0xff0000010000ff00, 0xff00000100000000, + 0xff00000100000001, 0xff000001000001ff, 0xff00000100000100, 0xff0000010001ff00, + 0xff000001000100ff, 0xff00000100010000, 0xff000001000101ff, 0xff00000100010100, + 0xff00000100010101, 0xff00000101ff0001, 0xff00000101ff0101, 0xff0000010100ff01, + 0xff00000101000000, 0xff000001010100ff, 0xff00000101010100, 0xff0001ffff00ff00, + 0xff0001ffff000001, 0xff0001ffff010000, 0xff0001ff00ffff00, 0xff0001ff00ff00ff, + 0xff0001ff00ff0001, 0xff0001ff00ff0100, 0xff0001ff0000ffff, 0xff0001ff00000000, + 0xff0001ff000001ff, 0xff0001ff00000101, 0xff0001ff0001ffff, 0xff0001ff0001ff00, + 0xff0001ff000100ff, 0xff0001ff00010001, 0xff0001ff00010100, 0xff0001ff01ff0000, + 0xff0001ff0100ff00, 0xff0001ff010000ff, 0xff0001ff01010000, 0xff000100ff00ffff, + 0xff000100ff00ff01, 0xff000100ff000000, 0xff000100ff000101, 0xff000100ff01ff00, + 0xff000100ff010000, 0xff00010000ffff01, 0xff00010000ff00ff, 0xff00010000ff0000, + 0xff00010000ff01ff, 0xff0001000000ff00, 0xff000100000000ff, 0xff00010000000000, + 0xff00010000000001, 0xff00010000000100, 0xff00010000000101, 0xff0001000001ffff, + 0xff00010000010000, 0xff00010000010101, 0xff00010001ff0100, 0xff0001000100ff00, + 0xff0001000100ff01, 0xff00010001000000, 0xff000100010001ff, 0xff0001000101ff00, + 0xff00010001010001, 0xff00010001010100, 0xff000101ffff0100, 0xff000101ff000001, + 0xff000101ff0100ff, 0xff000101ff010001, 0xff00010100ff00ff, 0xff00010100ff0001, + 0xff00010100ff0100, 0xff0001010000ffff, 0xff0001010000ff01, 0xff00010100000000, + 0xff000101000001ff, 0xff0001010001ff00, 0xff00010100010001, 0xff00010100010100, + 0xff00010101ff0000, 0xff0001010100ff00, 0xff00010101000001, 0xff00010101000101, + 0xff01ffffffffffff, 0xff01ffffffffff01, 0xff01ffffffff01ff, 0xff01ffffffff0101, + 0xff01ffffff000000, 0xff01ffffff01ffff, 0xff01ffffff01ff01, 0xff01ffffff010000, + 0xff01ffffff0101ff, 0xff01ffffff010101, 0xff01ffff00ff0000, 0xff01ffff0000ff00, + 0xff01ffff00000100, 0xff01ffff0001ff00, 0xff01ffff00010000, 0xff01ffff01ffffff, + 0xff01ffff01ffff01, 0xff01ffff01ff01ff, 0xff01ffff01ff0101, 0xff01ffff01000000, + 0xff01ffff0101ffff, 0xff01ffff0101ff01, 0xff01ffff01010000, 0xff01ffff010101ff, + 0xff01ffff01010101, 0xff01ff00ffff0000, 0xff01ff00ff00ff00, 0xff01ff00ff0000ff, + 0xff01ff00ff000100, 0xff01ff00ff010000, 0xff01ff0000ffff01, 0xff01ff0000ff00ff, + 0xff01ff0000ff0100, 0xff01ff0000000000, 0xff01ff00000001ff, 0xff01ff0000000101, + 0xff01ff000001ff00, 0xff01ff00000100ff, 0xff01ff0000010000, 0xff01ff0000010001, + 0xff01ff0001ff0000, 0xff01ff000100ffff, 0xff01ff0001000001, 0xff01ff0001000100, + 0xff01ff0001010000, 0xff01ff01ffffff00, 0xff01ff01ffff01ff, 0xff01ff01ffff0101, + 0xff01ff01ff00ff00, 0xff01ff01ff000000, 0xff01ff01ff01ffff, 0xff01ff01ff01ff01, + 0xff01ff01ff0101ff, 0xff01ff01ff010101, 0xff01ff0100ff0000, 0xff01ff010000ff00, + 0xff01ff0100000001, 0xff01ff0100000100, 0xff01ff0100010000, 0xff01ff0101ffff00, + 0xff01ff0101ff01ff, 0xff01ff0101ff0101, 0xff01ff010100ff00, 0xff01ff0101000000, + 0xff01ff010101ffff, 0xff01ff010101ff01, 0xff01ff01010101ff, 0xff01ff0101010101, + 0xff0100ffffff0000, 0xff0100ffff0000ff, 0xff0100ffff000001, 0xff0100ffff000100, + 0xff0100ffff010000, 0xff0100ff00ff00ff, 0xff0100ff00ff0000, 0xff0100ff00ff0001, + 0xff0100ff00ff0100, 0xff0100ff0000ff01, 0xff0100ff00000000, 0xff0100ff000001ff, + 0xff0100ff00000101, 0xff0100ff00010001, 0xff0100ff01ff0000, 0xff0100ff0100ff00, + 0xff0100ff010000ff, 0xff0100ff01000100, 0xff0100ff0101ff00, 0xff0100ff01010000, + 0xff010000ffff0100, 0xff010000ff000000, 0xff010000ff01ff00, 0xff010000ff010100, + 0xff01000000ffffff, 0xff01000000ff0000, 0xff01000000ff01ff, 0xff0100000000ff00, + 0xff010000000000ff, 0xff01000000000000, 0xff01000000000100, 0xff0100000001ff01, + 0xff01000000010000, 0xff010000000101ff, 0xff01000001ff0100, 0xff0100000100ffff, + 0xff010000010000ff, 0xff01000001000000, 0xff010000010001ff, 0xff01000001000101, + 0xff0100000101ff00, 0xff010000010100ff, 0xff01000001010001, 0xff01000001010100, + 0xff010001ffff0000, 0xff010001ff00ffff, 0xff010001ff00ff01, 0xff010001ff000100, + 0xff010001ff010000, 0xff01000100ffff00, 0xff01000100ff0100, 0xff01000100000000, + 0xff0100010001ffff, 0xff0100010001ff00, 0xff01000100010100, 0xff01000101ff00ff, + 0xff01000101ff0001, 0xff0100010100ffff, 0xff01000101000101, 0xff0101ffffffffff, + 0xff0101ffffffff01, 0xff0101ffffff01ff, 0xff0101ffffff0101, 0xff0101ffff000000, + 0xff0101ffff01ffff, 0xff0101ffff01ff01, 0xff0101ffff0101ff, 0xff0101ffff010101, + 0xff0101ff00ff0000, 0xff0101ff0000ff00, 0xff0101ff000000ff, 0xff0101ff00010000, + 0xff0101ff01ffffff, 0xff0101ff01ffff01, 0xff0101ff01ff01ff, 0xff0101ff01ff0101, + 0xff0101ff0101ffff, 0xff0101ff0101ff01, 0xff0101ff010101ff, 0xff0101ff01010101, + 0xff010100ffff0100, 0xff010100ff00ff00, 0xff010100ff0000ff, 0xff010100ff000100, + 0xff010100ff010000, 0xff01010000ff0001, 0xff01010000ff0100, 0xff0101000000ff01, + 0xff01010000000000, 0xff0101000001ff00, 0xff010100000100ff, 0xff01010000010001, + 0xff01010000010100, 0xff01010001ff0000, 0xff0101000100ffff, 0xff01010001000001, + 0xff01010001000100, 0xff010100010100ff, 0xff01010001010000, 0xff010101ffffffff, + 0xff010101ffffff01, 0xff010101ffff01ff, 0xff010101ffff0101, 0xff010101ff01ffff, + 0xff010101ff01ff01, 0xff010101ff0101ff, 0xff010101ff010101, 0xff01010100ff0000, + 0xff0101010000ff00, 0xff01010100000001, 0xff01010100000100, 0xff01010100010000, + 0xff01010101ffffff, 0xff01010101ffff01, 0xff01010101ff01ff, 0xff01010101ff0101, + 0xff01010101000000, 0xff0101010101ffff, 0xff0101010101ff01, 0xff010101010101ff, + 0xff01010101010101, 0x00ffffffffff0000, 0x00ffffffff00ff00, 0x00ffffffff000001, + 0x00ffffffff010000, 0x00ffffff00ff0100, 0x00ffffff0000ff01, 0x00ffffff00000000, + 0x00ffffff000001ff, 0x00ffffff00000101, 0x00ffffff0001ff00, 0x00ffffff000100ff, + 0x00ffffff00010001, 0x00ffffff010000ff, 0x00ffffff01000100, 0x00ffffff0101ff00, + 0x00ffffff01010001, 0x00ffff00ffffffff, 0x00ffff00ffffff00, 0x00ffff00ffff00ff, + 0x00ffff00ffff0001, 0x00ffff00ffff0100, 0x00ffff00ff00ff01, 0x00ffff00ff000000, + 0x00ffff00ff000001, 0x00ffff00ff0001ff, 0x00ffff00ff000101, 0x00ffff00ff01ff00, + 0x00ffff00ff010001, 0x00ffff00ff010100, 0x00ffff0000ff0000, 0x00ffff0000ff01ff, + 0x00ffff0000ff0101, 0x00ffff000000ff00, 0x00ffff00000000ff, 0x00ffff0000000000, + 0x00ffff0000000001, 0x00ffff0000000100, 0x00ffff0000000101, 0x00ffff0000010000, + 0x00ffff00000101ff, 0x00ffff0000010101, 0x00ffff0001ffff00, 0x00ffff0001ff00ff, + 0x00ffff0001ff0001, 0x00ffff000100ffff, 0x00ffff000100ff01, 0x00ffff0001000000, + 0x00ffff000101ffff, 0x00ffff000101ff00, 0x00ffff000101ff01, 0x00ffff01ffff0000, + 0x00ffff01ff00ff00, 0x00ffff01ff0000ff, 0x00ffff01ff000001, 0x00ffff01ff010000, + 0x00ffff0100ffff00, 0x00ffff010000ff01, 0x00ffff0100000000, 0x00ffff0100000101, + 0x00ffff01000100ff, 0x00ffff0100010100, 0x00ffff0101ff0100, 0x00ffff01010000ff, + 0x00ffff0101010000, 0x00ff00ffffffff00, 0x00ff00ffff000000, 0x00ff00ffff000100, + 0x00ff00ffff010100, 0x00ff00ff00ff0000, 0x00ff00ff00ff01ff, 0x00ff00ff00ff0101, + 0x00ff00ff0000ff00, 0x00ff00ff000000ff, 0x00ff00ff00000000, 0x00ff00ff00000001, + 0x00ff00ff0001ff00, 0x00ff00ff0001ff01, 0x00ff00ff00010000, 0x00ff00ff000101ff, + 0x00ff00ff00010101, 0x00ff00ff01ffff00, 0x00ff00ff01ff0001, 0x00ff00ff01ff0100, + 0x00ff00ff0100ffff, 0x00ff00ff0100ff01, 0x00ff00ff01000000, 0x00ff00ff0101ffff, + 0x00ff00ff0101ff00, 0x00ff00ff01010100, 0x00ff0000ffffff00, 0x00ff0000ffffff01, + 0x00ff0000ffff0000, 0x00ff0000ffff0101, 0x00ff0000ff00ff00, 0x00ff0000ff0000ff, + 0x00ff0000ff000000, 0x00ff0000ff000001, 0x00ff0000ff000100, 0x00ff0000ff01ffff, + 0x00ff0000ff010000, 0x00ff0000ff010101, 0x00ff000000ffff00, 0x00ff000000ff00ff, + 0x00ff000000ff0000, 0x00ff000000ff0001, 0x00ff000000ff0100, 0x00ff00000000ffff, + 0x00ff00000000ff00, 0x00ff0000000000ff, 0x00ff000000000000, 0x00ff000000000001, + 0x00ff0000000001ff, 0x00ff000000000100, 0x00ff00000001ff00, 0x00ff0000000100ff, + 0x00ff000000010000, 0x00ff000000010001, 0x00ff000000010100, 0x00ff000001ffff01, + 0x00ff000001ff00ff, 0x00ff000001ff0000, 0x00ff000001ff01ff, 0x00ff00000100ff00, + 0x00ff0000010000ff, 0x00ff000001000000, 0x00ff000001000001, 0x00ff000001000100, + 0x00ff000001000101, 0x00ff000001010000, 0x00ff0000010101ff, 0x00ff000001010101, + 0x00ff0001ffffff00, 0x00ff0001ffff0000, 0x00ff0001ffff0100, 0x00ff0001ff0000ff, + 0x00ff0001ff000000, 0x00ff0001ff0001ff, 0x00ff0001ff000101, 0x00ff0001ff01ff00, + 0x00ff0001ff0100ff, 0x00ff0001ff010100, 0x00ff000100ffffff, 0x00ff000100ffff01, + 0x00ff000100ff0000, 0x00ff000100ff01ff, 0x00ff00010000ffff, 0x00ff00010000ff00, + 0x00ff00010000ff01, 0x00ff000100000000, 0x00ff000100000001, 0x00ff000100000100, + 0x00ff00010001ff01, 0x00ff000100010000, 0x00ff0001000101ff, 0x00ff000101ffff00, + 0x00ff000101ff0000, 0x00ff000101ff0101, 0x00ff0001010000ff, 0x00ff000101000000, + 0x00ff00010101ff00, 0x00ff0001010100ff, 0x00ff000101010001, 0x00ff01ffffff0000, + 0x00ff01ffff00ff00, 0x00ff01ffff000000, 0x00ff01ffff000101, 0x00ff01ffff010000, + 0x00ff01ff00ffff01, 0x00ff01ff00ff0100, 0x00ff01ff0000ffff, 0x00ff01ff00000000, + 0x00ff01ff000001ff, 0x00ff01ff0001ff00, 0x00ff01ff000100ff, 0x00ff01ff00010001, + 0x00ff01ff00010100, 0x00ff01ff01ff0000, 0x00ff01ff0100ff00, 0x00ff01ff010000ff, + 0x00ff01ff01000001, 0x00ff01ff01000100, 0x00ff01ff01010000, 0x00ff0100ffffff00, + 0x00ff0100ffff0000, 0x00ff0100ffff0001, 0x00ff0100ffff0101, 0x00ff0100ff00ffff, + 0x00ff0100ff0000ff, 0x00ff0100ff000000, 0x00ff0100ff0001ff, 0x00ff0100ff01ff00, + 0x00ff0100ff0100ff, 0x00ff0100ff010001, 0x00ff010000ffffff, 0x00ff010000ff0000, + 0x00ff010000ff0101, 0x00ff01000000ff00, 0x00ff01000000ff01, 0x00ff0100000000ff, + 0x00ff010000000000, 0x00ff010000000001, 0x00ff010000000100, 0x00ff01000001ffff, + 0x00ff01000001ff01, 0x00ff010000010000, 0x00ff010000010001, 0x00ff010000010101, + 0x00ff010001ff0001, 0x00ff010001ff0100, 0x00ff01000100ff01, 0x00ff010001000000, + 0x00ff010001000001, 0x00ff0100010001ff, 0x00ff01000101ff00, 0x00ff0100010100ff, + 0x00ff010001010001, 0x00ff010001010100, 0x00ff0101ff000001, 0x00ff010100ff00ff, + 0x00ff010100ff0001, 0x00ff010100ff0100, 0x00ff010100000000, 0x00ff0101000001ff, + 0x00ff010100000101, 0x00ff0101000100ff, 0x00ff010100010100, 0x00ff0101010000ff, + 0x00ff010101010000, 0x0000ffffffffff00, 0x0000ffffffff00ff, 0x0000ffffffff0000, + 0x0000ffffffff0001, 0x0000ffffffff0100, 0x0000ffffff00ff01, 0x0000ffffff000000, + 0x0000ffffff000101, 0x0000ffffff01ff00, 0x0000ffffff0100ff, 0x0000ffffff010100, + 0x0000ffff00ffffff, 0x0000ffff00ff0000, 0x0000ffff00ff01ff, 0x0000ffff0000ff00, + 0x0000ffff000000ff, 0x0000ffff00000000, 0x0000ffff00000001, 0x0000ffff00000100, + 0x0000ffff00010000, 0x0000ffff000101ff, 0x0000ffff01ff0001, 0x0000ffff01ff0100, + 0x0000ffff01000000, 0x0000ffff010001ff, 0x0000ffff0101ffff, 0x0000ffff0101ff00, + 0x0000ffff01010001, 0x0000ffff01010100, 0x0000ff00ffff0000, 0x0000ff00ffff01ff, + 0x0000ff00ffff0100, 0x0000ff00ffff0101, 0x0000ff00ff00ff00, 0x0000ff00ff0000ff, + 0x0000ff00ff000000, 0x0000ff00ff000001, 0x0000ff00ff0001ff, 0x0000ff00ff000100, + 0x0000ff00ff01ffff, 0x0000ff00ff010000, 0x0000ff00ff010001, 0x0000ff00ff0101ff, + 0x0000ff00ff010101, 0x0000ff0000ffff00, 0x0000ff0000ff00ff, 0x0000ff0000ff0000, + 0x0000ff0000ff0001, 0x0000ff0000ff0100, 0x0000ff000000ffff, 0x0000ff000000ff00, + 0x0000ff000000ff01, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, + 0x0000ff00000001ff, 0x0000ff0000000100, 0x0000ff0000000101, 0x0000ff000001ff00, + 0x0000ff00000100ff, 0x0000ff0000010000, 0x0000ff0000010001, 0x0000ff0000010100, + 0x0000ff0001ffff01, 0x0000ff0001ff0000, 0x0000ff000100ff00, 0x0000ff00010000ff, + 0x0000ff0001000000, 0x0000ff0001000001, 0x0000ff0001000100, 0x0000ff000101ffff, + 0x0000ff0001010000, 0x0000ff0001010101, 0x0000ff01ffffff00, 0x0000ff01ffff0001, + 0x0000ff01ff00ff01, 0x0000ff01ff000000, 0x0000ff01ff000101, 0x0000ff01ff01ff00, + 0x0000ff01ff0100ff, 0x0000ff0100ffff01, 0x0000ff0100ff0000, 0x0000ff0100ff0101, + 0x0000ff010000ff00, 0x0000ff01000000ff, 0x0000ff0100000000, 0x0000ff0100000001, + 0x0000ff0100000100, 0x0000ff010001ff01, 0x0000ff0100010000, 0x0000ff0101ff0000, + 0x0000ff010100ffff, 0x0000ff010100ff01, 0x0000ff0101000000, 0x0000ff0101000100, + 0x0000ff0101000101, 0x0000ff01010100ff, 0x000000ffffff00ff, 0x000000ffffff0000, + 0x000000ffff00ff00, 0x000000ffff0000ff, 0x000000ffff000000, 0x000000ffff000001, + 0x000000ffff0001ff, 0x000000ffff000100, 0x000000ffff01ff00, 0x000000ffff010000, + 0x000000ffff0101ff, 0x000000ffff010101, 0x000000ff00ffff00, 0x000000ff00ff00ff, + 0x000000ff00ff0000, 0x000000ff00ff0001, 0x000000ff00ff0100, 0x000000ff00ff0101, + 0x000000ff0000ffff, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, + 0x000000ff00000001, 0x000000ff000001ff, 0x000000ff00000100, 0x000000ff00000101, + 0x000000ff0001ff00, 0x000000ff0001ff01, 0x000000ff000100ff, 0x000000ff00010000, + 0x000000ff00010001, 0x000000ff00010100, 0x000000ff01ffffff, 0x000000ff01ff01ff, + 0x000000ff01ff0101, 0x000000ff0100ff00, 0x000000ff010000ff, 0x000000ff01000000, + 0x000000ff01000001, 0x000000ff01000100, 0x000000ff0101ff00, 0x000000ff010100ff, + 0x000000ff01010000, 0x000000ff01010101, 0x00000000ffffff00, 0x00000000ffffff01, + 0x00000000ffff00ff, 0x00000000ffff0000, 0x00000000ffff0001, 0x00000000ffff0100, + 0x00000000ff00ffff, 0x00000000ff00ff00, 0x00000000ff00ff01, 0x00000000ff0000ff, + 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff000101, + 0x00000000ff01ff00, 0x00000000ff0100ff, 0x00000000ff010000, 0x00000000ff010001, + 0x00000000ff010100, 0x0000000000ffffff, 0x0000000000ffff00, 0x0000000000ffff01, + 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, 0x0000000000ff01ff, + 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, + 0x00000000000000ff, 0x0000000000000000, 0x0000000000000001, 0x00000000000001ff, + 0x0000000000000100, 0x0000000000000101, 0x000000000001ffff, 0x000000000001ff00, + 0x00000000000100ff, 0x0000000000010000, 0x0000000000010001, 0x00000000000101ff, + 0x0000000000010100, 0x0000000000010101, 0x0000000001ffff00, 0x0000000001ff00ff, + 0x0000000001ff0000, 0x0000000001ff0100, 0x0000000001ff0101, 0x000000000100ffff, + 0x000000000100ff00, 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, + 0x00000000010001ff, 0x0000000001000100, 0x000000000101ff00, 0x00000000010100ff, + 0x0000000001010000, 0x0000000001010001, 0x0000000001010100, 0x00000001ffffffff, + 0x00000001ffffff00, 0x00000001ffffff01, 0x00000001ffff00ff, 0x00000001ffff0001, + 0x00000001ffff01ff, 0x00000001ffff0100, 0x00000001ff00ff00, 0x00000001ff0000ff, + 0x00000001ff000000, 0x00000001ff0001ff, 0x00000001ff000100, 0x00000001ff01ffff, + 0x00000001ff01ff00, 0x00000001ff01ff01, 0x00000001ff0100ff, 0x00000001ff010000, + 0x00000001ff010001, 0x00000001ff0101ff, 0x00000001ff010100, 0x0000000100ffff00, + 0x0000000100ff0000, 0x0000000100ff0001, 0x0000000100ff01ff, 0x0000000100ff0100, + 0x0000000100ff0101, 0x000000010000ffff, 0x000000010000ff00, 0x000000010000ff01, + 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, 0x00000001000001ff, + 0x0000000100000100, 0x0000000100000101, 0x000000010001ff00, 0x00000001000100ff, + 0x0000000100010000, 0x0000000100010100, 0x0000000101ffff01, 0x0000000101ff0000, + 0x0000000101ff0001, 0x0000000101ff01ff, 0x0000000101ff0100, 0x0000000101ff0101, + 0x000000010100ff00, 0x0000000101000000, 0x0000000101000101, 0x000000010101ff01, + 0x0000000101010000, 0x0000000101010001, 0x00000001010101ff, 0x0000000101010100, + 0x000001ffffff00ff, 0x000001ffffff0000, 0x000001ffffff0001, 0x000001ffffff0100, + 0x000001ffff00ffff, 0x000001ffff000000, 0x000001ffff0001ff, 0x000001ffff01ff00, + 0x000001ffff010101, 0x000001ff00ff0000, 0x000001ff00ff01ff, 0x000001ff00ff0101, + 0x000001ff0000ff00, 0x000001ff000000ff, 0x000001ff00000000, 0x000001ff00000001, + 0x000001ff000001ff, 0x000001ff00000100, 0x000001ff0001ffff, 0x000001ff0001ff01, + 0x000001ff000100ff, 0x000001ff00010000, 0x000001ff01ffff01, 0x000001ff01ff0100, + 0x000001ff0100ffff, 0x000001ff0100ff01, 0x000001ff01000000, 0x000001ff010001ff, + 0x000001ff0101ff00, 0x000001ff01010100, 0x00000100ffffff00, 0x00000100ffffff01, + 0x00000100ffff0000, 0x00000100ffff0101, 0x00000100ff00ff00, 0x00000100ff0000ff, + 0x00000100ff000000, 0x00000100ff000001, 0x00000100ff000100, 0x00000100ff010000, + 0x0000010000ffff00, 0x0000010000ff00ff, 0x0000010000ff0000, 0x0000010000ff0001, + 0x0000010000ff0100, 0x000001000000ffff, 0x000001000000ff00, 0x000001000000ff01, + 0x00000100000000ff, 0x0000010000000000, 0x0000010000000001, 0x00000100000001ff, + 0x0000010000000100, 0x0000010000000101, 0x000001000001ff00, 0x00000100000100ff, + 0x0000010000010000, 0x0000010000010001, 0x0000010000010100, 0x0000010001ffff00, + 0x0000010001ff0000, 0x0000010001ff0100, 0x000001000100ff00, 0x00000100010000ff, + 0x0000010001000000, 0x0000010001000001, 0x00000100010001ff, 0x0000010001000100, + 0x0000010001010000, 0x00000101ffff00ff, 0x00000101ffff01ff, 0x00000101ff000000, + 0x00000101ff000101, 0x00000101ff01ffff, 0x00000101ff010000, 0x00000101ff010001, + 0x00000101ff010100, 0x0000010100ff0000, 0x0000010100ff01ff, 0x0000010100ff0100, + 0x000001010000ff00, 0x0000010100000000, 0x0000010100000001, 0x00000101000001ff, + 0x0000010100000100, 0x000001010001ff01, 0x0000010100010000, 0x00000101000101ff, + 0x0000010100010101, 0x0000010101ffff00, 0x0000010101ff0101, 0x000001010100ff01, + 0x0000010101000000, 0x0000010101000001, 0x00000101010001ff, 0x0000010101000101, + 0x000001010101ff00, 0x0001ffffffff0000, 0x0001ffffff0000ff, 0x0001ffffff000001, + 0x0001ffffff000100, 0x0001ffffff010000, 0x0001ffff00ff00ff, 0x0001ffff0000ffff, + 0x0001ffff00000000, 0x0001ffff00000001, 0x0001ffff000001ff, 0x0001ffff00000101, + 0x0001ffff0001ff00, 0x0001ffff000100ff, 0x0001ffff00010001, 0x0001ffff00010100, + 0x0001ffff01ffff00, 0x0001ffff01000001, 0x0001ffff01010000, 0x0001ff00ffffff00, + 0x0001ff00ffff00ff, 0x0001ff00ffff0001, 0x0001ff00ffff0100, 0x0001ff00ff00ff01, + 0x0001ff00ff000000, 0x0001ff00ff01ff00, 0x0001ff00ff01ff01, 0x0001ff00ff010001, + 0x0001ff00ff010100, 0x0001ff0000ff0000, 0x0001ff0000ff0100, 0x0001ff000000ff00, + 0x0001ff0000000000, 0x0001ff0000000001, 0x0001ff0000000100, 0x0001ff0000010000, + 0x0001ff0000010001, 0x0001ff0000010101, 0x0001ff0001ff00ff, 0x0001ff0001ff0101, + 0x0001ff000100ff01, 0x0001ff0001000000, 0x0001ff000101ff00, 0x0001ff0001010001, + 0x0001ff0001010100, 0x0001ff01ff00ff00, 0x0001ff01ff000001, 0x0001ff01ff000100, + 0x0001ff0100ffffff, 0x0001ff0100ffff00, 0x0001ff0100ff0001, 0x0001ff0100000000, + 0x0001ff0100000001, 0x0001ff01000001ff, 0x0001ff010001ffff, 0x0001ff0101ff0000, + 0x0001ff010100ff00, 0x0001ff0101000001, 0x0001ff0101010000, 0x000100ffff00ff00, + 0x000100ffff00ff01, 0x000100ffff000000, 0x000100ffff000001, 0x000100ffff000101, + 0x000100ffff01ff00, 0x000100ffff010001, 0x000100ffff010100, 0x000100ff00ffffff, + 0x000100ff00ffff01, 0x000100ff00ff0000, 0x000100ff00ff01ff, 0x000100ff00ff0101, + 0x000100ff0000ff00, 0x000100ff000000ff, 0x000100ff00000000, 0x000100ff00000001, + 0x000100ff00000100, 0x000100ff00000101, 0x000100ff0001ffff, 0x000100ff0001ff01, + 0x000100ff00010000, 0x000100ff01ff00ff, 0x000100ff01ff0000, 0x000100ff01ff0100, + 0x000100ff0100ffff, 0x000100ff0100ff01, 0x000100ff010000ff, 0x000100ff01000000, + 0x000100ff01000001, 0x000100ff010001ff, 0x000100ff01000101, 0x000100ff0101ff00, + 0x000100ff010100ff, 0x000100ff01010100, 0x00010000ffff0000, 0x00010000ffff01ff, + 0x00010000ffff0101, 0x00010000ff00ff00, 0x00010000ff000000, 0x00010000ff000001, + 0x00010000ff000100, 0x0001000000ff00ff, 0x0001000000ff0000, 0x0001000000ff0001, + 0x0001000000ff0100, 0x000100000000ffff, 0x000100000000ff00, 0x00010000000000ff, + 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, 0x000100000001ff00, + 0x00010000000100ff, 0x0001000000010000, 0x0001000000010001, 0x0001000000010100, + 0x0001000001ff0001, 0x0001000001ff0100, 0x0001000001ff0101, 0x000100000100ff00, + 0x0001000001000000, 0x0001000001000001, 0x0001000001000100, 0x0001000001000101, + 0x000100000101ff01, 0x0001000001010000, 0x0001000001010001, 0x00010000010101ff, + 0x00010001ffffff01, 0x00010001ffff0100, 0x00010001ff000000, 0x00010001ff01ffff, + 0x00010001ff010001, 0x00010001ff0101ff, 0x00010001ff010100, 0x0001000100ffffff, + 0x0001000100ff0000, 0x0001000100ff01ff, 0x0001000100ff0101, 0x000100010000ff00, + 0x00010001000000ff, 0x0001000100000000, 0x0001000100000001, 0x00010001000001ff, + 0x0001000100000101, 0x000100010001ffff, 0x0001000100010000, 0x00010001000101ff, + 0x0001000101ffffff, 0x0001000101ffff01, 0x0001000101ff0000, 0x0001000101ff0101, + 0x00010001010000ff, 0x0001000101000001, 0x00010001010001ff, 0x0001000101000100, + 0x000100010101ffff, 0x00010001010100ff, 0x0001000101010001, 0x0001000101010101, + 0x000101ffff000001, 0x000101ffff000100, 0x000101ffff010000, 0x000101ff00ffff00, + 0x000101ff0000ff01, 0x000101ff00000000, 0x000101ff00000101, 0x000101ff0001ff00, + 0x000101ff00010100, 0x000101ff01ff0000, 0x000101ff0100ff00, 0x000101ff010001ff, + 0x000101ff01010001, 0x00010100ffffff00, 0x00010100ffff00ff, 0x00010100ff00ffff, + 0x00010100ff000000, 0x00010100ff01ff00, 0x00010100ff0100ff, 0x00010100ff010001, + 0x00010100ff010100, 0x0001010000ffffff, 0x0001010000ffff00, 0x0001010000ff0000, + 0x0001010000ff0001, 0x0001010000ff01ff, 0x000101000000ff00, 0x00010100000000ff, + 0x0001010000000000, 0x0001010000000001, 0x0001010000000100, 0x000101000001ffff, + 0x0001010000010000, 0x0001010000010101, 0x0001010001ffff01, 0x0001010001ff00ff, + 0x0001010001ff0101, 0x0001010001000000, 0x000101000101ff00, 0x00010100010100ff, + 0x0001010001010000, 0x0001010001010100, 0x00010101ff00ff00, 0x00010101ff000001, + 0x00010101ff0001ff, 0x0001010100ffff00, 0x0001010100ff00ff, 0x0001010100ff0100, + 0x000101010000ffff, 0x0001010100000000, 0x00010101000001ff, 0x0001010100000101, + 0x00010101000100ff, 0x0001010100010000, 0x0001010100010100, 0x0001010101ff0001, + 0x00010101010000ff, 0x00010101010001ff, 0x0001010101000101, 0x0001010101010001, + 0x01ffffffffffffff, 0x01ffffffffffff01, 0x01ffffffffff01ff, 0x01ffffffffff0101, + 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, 0x01ffffffff010101, + 0x01ffffff00ff0000, 0x01ffffff0000ffff, 0x01ffffff0000ff00, 0x01ffffff000000ff, + 0x01ffffff00000001, 0x01ffffff00000100, 0x01ffffff00010000, 0x01ffffff01ffffff, + 0x01ffffff01ffff01, 0x01ffffff01ff01ff, 0x01ffffff01ff0101, 0x01ffffff01000000, + 0x01ffffff0101ffff, 0x01ffffff0101ff01, 0x01ffffff010101ff, 0x01ffffff01010101, + 0x01ffff00ffff0000, 0x01ffff00ff00ff00, 0x01ffff00ff0000ff, 0x01ffff00ff000001, + 0x01ffff00ff000100, 0x01ffff00ff010000, 0x01ffff0000ffff00, 0x01ffff0000ff00ff, + 0x01ffff0000ff0100, 0x01ffff000000ffff, 0x01ffff000000ff01, 0x01ffff0000000000, + 0x01ffff0000000001, 0x01ffff00000001ff, 0x01ffff0000000100, 0x01ffff00000100ff, + 0x01ffff0000010001, 0x01ffff0000010100, 0x01ffff0001ff0000, 0x01ffff0001ff0100, + 0x01ffff00010000ff, 0x01ffff0001000001, 0x01ffff0001000100, 0x01ffff0001010000, + 0x01ffff01ffffffff, 0x01ffff01ffffff01, 0x01ffff01ffff01ff, 0x01ffff01ffff0101, + 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff01ff01, 0x01ffff01ff0101ff, + 0x01ffff01ff010101, 0x01ffff010000ff00, 0x01ffff01000000ff, 0x01ffff0100000100, + 0x01ffff0100010000, 0x01ffff0101ffffff, 0x01ffff0101ffff01, 0x01ffff0101ff01ff, + 0x01ffff0101ff0101, 0x01ffff0101000000, 0x01ffff010101ffff, 0x01ffff010101ff01, + 0x01ffff01010101ff, 0x01ffff0101010101, 0x01ff00ffff0000ff, 0x01ff00ffff000100, + 0x01ff00ff00ffff00, 0x01ff00ff00ff00ff, 0x01ff00ff0000ff00, 0x01ff00ff00000000, + 0x01ff00ff00000101, 0x01ff00ff0001ff00, 0x01ff00ff000100ff, 0x01ff00ff00010100, + 0x01ff00ff010000ff, 0x01ff00ff01000100, 0x01ff0000ffffff00, 0x01ff0000ffff0100, + 0x01ff0000ff00ff01, 0x01ff0000ff000000, 0x01ff0000ff000101, 0x01ff0000ff010001, + 0x01ff0000ff010100, 0x01ff000000ffffff, 0x01ff000000ffff00, 0x01ff000000ff0000, + 0x01ff000000ff01ff, 0x01ff00000000ff00, 0x01ff0000000000ff, 0x01ff000000000000, + 0x01ff000000000001, 0x01ff000000000100, 0x01ff000000000101, 0x01ff000000010000, + 0x01ff000000010001, 0x01ff0000000101ff, 0x01ff000000010101, 0x01ff000001ffff00, + 0x01ff000001ff00ff, 0x01ff000001ff0001, 0x01ff000001ff0100, 0x01ff00000100ffff, + 0x01ff00000100ff01, 0x01ff000001000000, 0x01ff0000010001ff, 0x01ff000001010001, + 0x01ff0001ff00ff00, 0x01ff0001ff000001, 0x01ff0001ff000100, 0x01ff0001ff010000, + 0x01ff000100ffff00, 0x01ff000100ff00ff, 0x01ff000100ff0100, 0x01ff000100ff0101, + 0x01ff00010000ffff, 0x01ff000100000000, 0x01ff000100000100, 0x01ff000100000101, + 0x01ff00010001ff00, 0x01ff000100010001, 0x01ff000100010101, 0x01ff000101ff0000, + 0x01ff00010100ff00, 0x01ff000101000101, 0x01ff0001010100ff, 0x01ff01ffffffffff, + 0x01ff01ffffffff01, 0x01ff01ffffff01ff, 0x01ff01ffffff0101, 0x01ff01ffff000000, + 0x01ff01ffff01ffff, 0x01ff01ffff01ff01, 0x01ff01ffff0101ff, 0x01ff01ffff010101, + 0x01ff01ff00ffff00, 0x01ff01ff00ff0000, 0x01ff01ff0000ff00, 0x01ff01ff000000ff, + 0x01ff01ff00000100, 0x01ff01ff00010000, 0x01ff01ff00010100, 0x01ff01ff01ffffff, + 0x01ff01ff01ffff01, 0x01ff01ff01ff01ff, 0x01ff01ff01ff0101, 0x01ff01ff01000000, + 0x01ff01ff0101ffff, 0x01ff01ff0101ff01, 0x01ff01ff010101ff, 0x01ff01ff01010101, + 0x01ff0100ffff0000, 0x01ff0100ffff0001, 0x01ff0100ff00ff00, 0x01ff0100ff0000ff, + 0x01ff0100ff000001, 0x01ff0100ff010000, 0x01ff010000ffff00, 0x01ff010000ff00ff, + 0x01ff010000ff0001, 0x01ff010000ff0100, 0x01ff01000000ffff, 0x01ff01000000ff01, + 0x01ff010000000000, 0x01ff010000000101, 0x01ff01000001ff00, 0x01ff0100000100ff, + 0x01ff010001ff0000, 0x01ff010001000001, 0x01ff010001000100, 0x01ff010001010000, + 0x01ff0101ffffffff, 0x01ff0101ffffff01, 0x01ff0101ffff01ff, 0x01ff0101ffff0101, + 0x01ff0101ff000000, 0x01ff0101ff01ffff, 0x01ff0101ff01ff01, 0x01ff0101ff0101ff, + 0x01ff0101ff010101, 0x01ff010100ff0000, 0x01ff01010000ff00, 0x01ff0101000000ff, + 0x01ff010100000001, 0x01ff010101ffffff, 0x01ff010101ffff01, 0x01ff010101ff01ff, + 0x01ff010101ff0101, 0x01ff010101000000, 0x01ff01010101ffff, 0x01ff01010101ff01, + 0x01ff0101010101ff, 0x01ff010101010101, 0x0100ffffffff0000, 0x0100ffffff00ff00, + 0x0100ffffff000001, 0x0100ffffff0001ff, 0x0100ffffff000100, 0x0100ffffff010000, + 0x0100ffff00ffff00, 0x0100ffff00ff0001, 0x0100ffff00ff0100, 0x0100ffff00000000, + 0x0100ffff000001ff, 0x0100ffff00000101, 0x0100ffff00010100, 0x0100ffff00010101, + 0x0100ffff01ff0000, 0x0100ffff0100ff00, 0x0100ffff010000ff, 0x0100ffff01000001, + 0x0100ffff01000100, 0x0100ffff01010000, 0x0100ff00ffffff00, 0x0100ff00ffff00ff, + 0x0100ff00ffff0001, 0x0100ff00ffff0100, 0x0100ff00ff00ffff, 0x0100ff00ff000000, + 0x0100ff00ff0001ff, 0x0100ff00ff000101, 0x0100ff00ff01ff00, 0x0100ff00ff0100ff, + 0x0100ff00ff010001, 0x0100ff00ff010100, 0x0100ff0000ffffff, 0x0100ff0000ff0000, + 0x0100ff000000ffff, 0x0100ff000000ff00, 0x0100ff00000000ff, 0x0100ff0000000000, + 0x0100ff0000000001, 0x0100ff0000000100, 0x0100ff000001ff01, 0x0100ff0000010000, + 0x0100ff0001ff00ff, 0x0100ff0001ff0001, 0x0100ff000100ff01, 0x0100ff0001000000, + 0x0100ff00010001ff, 0x0100ff000101ff00, 0x0100ff00010100ff, 0x0100ff0001010001, + 0x0100ff0001010100, 0x0100ff01ffff0000, 0x0100ff01ff00ff00, 0x0100ff01ff0000ff, + 0x0100ff01ff000100, 0x0100ff01ff010000, 0x0100ff0100ff00ff, 0x0100ff0100ff0001, + 0x0100ff0100ff0100, 0x0100ff010000ffff, 0x0100ff010000ff01, 0x0100ff0100000000, + 0x0100ff01000001ff, 0x0100ff0100010001, 0x0100ff0100010100, 0x0100ff0101ff0000, + 0x0100ff01010000ff, 0x0100ff0101000001, 0x0100ff0101010100, 0x010000ffffffff00, + 0x010000ffffff00ff, 0x010000ffffff0001, 0x010000ffff00ffff, 0x010000ffff000000, + 0x010000ffff0001ff, 0x010000ffff010001, 0x010000ff00ffffff, 0x010000ff00ff0101, + 0x010000ff0000ff00, 0x010000ff000000ff, 0x010000ff00000000, 0x010000ff00000001, + 0x010000ff000001ff, 0x010000ff00000100, 0x010000ff0001ffff, 0x010000ff0001ff00, + 0x010000ff0001ff01, 0x010000ff00010000, 0x010000ff01ff00ff, 0x010000ff01ff0001, + 0x010000ff0100ff01, 0x010000ff010000ff, 0x010000ff01000000, 0x010000ff010001ff, + 0x010000ff0101ff00, 0x010000ff01010100, 0x01000000ffffffff, 0x01000000ffff0000, + 0x01000000ffff01ff, 0x01000000ffff0101, 0x01000000ff00ffff, 0x01000000ff00ff00, + 0x01000000ff0000ff, 0x01000000ff000000, 0x01000000ff000001, 0x01000000ff000100, + 0x01000000ff01ff00, 0x01000000ff010000, 0x01000000ff010100, 0x01000000ff010101, + 0x0100000000ffff00, 0x0100000000ff00ff, 0x0100000000ff0000, 0x0100000000ff0001, + 0x0100000000ff0100, 0x010000000000ffff, 0x010000000000ff00, 0x010000000000ff01, + 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, 0x01000000000001ff, + 0x0100000000000100, 0x0100000000000101, 0x010000000001ff00, 0x01000000000100ff, + 0x0100000000010000, 0x0100000000010001, 0x0100000000010100, 0x0100000001ffff00, + 0x0100000001ff0000, 0x0100000001ff01ff, 0x010000000100ff00, 0x010000000100ff01, + 0x01000000010000ff, 0x0100000001000000, 0x0100000001000001, 0x0100000001000100, + 0x0100000001000101, 0x010000000101ffff, 0x010000000101ff01, 0x0100000001010000, + 0x01000000010101ff, 0x0100000001010101, 0x01000001ffffff00, 0x01000001ffff00ff, + 0x01000001ff00ffff, 0x01000001ff000000, 0x01000001ff000100, 0x01000001ff01ffff, + 0x01000001ff010001, 0x01000001ff010100, 0x0100000100ff0000, 0x0100000100ff01ff, + 0x0100000100ff0100, 0x010000010000ff00, 0x010000010000ff01, 0x0100000100000000, + 0x0100000100000001, 0x0100000100000100, 0x0100000100010000, 0x01000001000101ff, + 0x0100000101ffff01, 0x0100000101ff00ff, 0x0100000101ff0100, 0x0100000101ff0101, + 0x010000010100ff01, 0x01000001010000ff, 0x0100000101000000, 0x01000001010100ff, + 0x0100000101010001, 0x0100000101010100, 0x010001ffffff0000, 0x010001ffff000001, + 0x010001ffff000100, 0x010001ffff010000, 0x010001ff00ffff00, 0x010001ff00ff0001, + 0x010001ff0000ffff, 0x010001ff0000ff01, 0x010001ff00000000, 0x010001ff00000001, + 0x010001ff00000101, 0x010001ff000100ff, 0x010001ff00010000, 0x010001ff01ff0000, + 0x010001ff0100ff00, 0x010001ff01000001, 0x010001ff01000100, 0x010001ff01010000, + 0x01000100ffff00ff, 0x01000100ffff0001, 0x01000100ffff0100, 0x01000100ff00ffff, + 0x01000100ff00ff01, 0x01000100ff000000, 0x01000100ff0001ff, 0x01000100ff000101, + 0x01000100ff01ffff, 0x01000100ff01ff00, 0x01000100ff0100ff, 0x01000100ff010001, + 0x0100010000ffffff, 0x0100010000ffff01, 0x0100010000ff0000, 0x0100010000ff01ff, + 0x0100010000ff0101, 0x010001000000ff00, 0x01000100000000ff, 0x0100010000000000, + 0x0100010000000001, 0x0100010000000100, 0x010001000001ff01, 0x0100010000010000, + 0x0100010000010001, 0x0100010000010101, 0x0100010001ffff00, 0x0100010001ff00ff, + 0x010001000100ffff, 0x010001000100ff01, 0x0100010001000000, 0x0100010001000101, + 0x010001000101ff00, 0x0100010001010001, 0x01000101ffff0000, 0x01000101ff000000, + 0x01000101ff010000, 0x0100010100ff00ff, 0x0100010100ff0001, 0x0100010100ff0100, + 0x010001010000ffff, 0x0100010100000000, 0x01000101000001ff, 0x010001010001ff00, + 0x0100010101ff0000, 0x010001010100ff00, 0x01000101010000ff, 0x0100010101000000, + 0x0100010101000001, 0x0101ffffffffffff, 0x0101ffffffffff01, 0x0101ffffffff01ff, + 0x0101ffffffff0101, 0x0101ffffff000000, 0x0101ffffff01ffff, 0x0101ffffff01ff01, + 0x0101ffffff0101ff, 0x0101ffffff010101, 0x0101ffff00ff0000, 0x0101ffff0000ff00, + 0x0101ffff000000ff, 0x0101ffff00000001, 0x0101ffff00000100, 0x0101ffff01ffffff, + 0x0101ffff01ffff01, 0x0101ffff01ff01ff, 0x0101ffff01ff0101, 0x0101ffff01000000, + 0x0101ffff0101ffff, 0x0101ffff0101ff01, 0x0101ffff010101ff, 0x0101ffff01010101, + 0x0101ff00ffff0000, 0x0101ff00ffff0100, 0x0101ff00ff00ff00, 0x0101ff00ff0000ff, + 0x0101ff00ff000001, 0x0101ff00ff000100, 0x0101ff00ff000101, 0x0101ff0000ff0001, + 0x0101ff0000ff0100, 0x0101ff000000ff00, 0x0101ff0000000000, 0x0101ff00000001ff, + 0x0101ff0000000101, 0x0101ff000001ff00, 0x0101ff00000100ff, 0x0101ff0001ff0000, + 0x0101ff000100ffff, 0x0101ff000100ff01, 0x0101ff0001000001, 0x0101ff0001000100, + 0x0101ff01ffffff01, 0x0101ff01ffff01ff, 0x0101ff01ffff0101, 0x0101ff01ff00ffff, + 0x0101ff01ff000100, 0x0101ff01ff01ff01, 0x0101ff01ff0101ff, 0x0101ff01ff010101, + 0x0101ff0100ff0000, 0x0101ff010000ff00, 0x0101ff0100000001, 0x0101ff0100000100, + 0x0101ff0100010000, 0x0101ff0101ffffff, 0x0101ff0101ffff01, 0x0101ff0101ff01ff, + 0x0101ff0101ff0101, 0x0101ff0101000000, 0x0101ff010101ffff, 0x0101ff010101ff01, + 0x0101ff01010101ff, 0x0101ff0101010101, 0x010100ffff000100, 0x010100ffff010000, + 0x010100ff00ffff00, 0x010100ff00ff00ff, 0x010100ff0000ffff, 0x010100ff000000ff, + 0x010100ff00000000, 0x010100ff000001ff, 0x010100ff00000101, 0x010100ff0001ff00, + 0x010100ff00010000, 0x010100ff00010001, 0x010100ff000101ff, 0x010100ff00010100, + 0x010100ff01ff0000, 0x01010000ffff0001, 0x01010000ffff0100, 0x01010000ff00ffff, + 0x01010000ff00ff01, 0x01010000ff000000, 0x01010000ff0001ff, 0x01010000ff010001, + 0x01010000ff010100, 0x0101000000ffff01, 0x0101000000ff0000, 0x010100000000ff00, + 0x01010000000000ff, 0x0101000000000000, 0x0101000000000001, 0x0101000000000100, + 0x0101000000010000, 0x0101000000010101, 0x0101000001ffff00, 0x0101000001ff00ff, + 0x0101000001ff0000, 0x0101000001ff0001, 0x0101000001ff0100, 0x010100000100ff01, + 0x0101000001000000, 0x01010000010001ff, 0x01010001ffff0000, 0x01010001ff00ff00, + 0x01010001ff000001, 0x01010001ff000101, 0x01010001ff01ff00, 0x01010001ff010000, + 0x0101000100ff00ff, 0x0101000100ff0001, 0x0101000100ff0101, 0x010100010000ff01, + 0x0101000100000000, 0x0101000100000001, 0x01010001000001ff, 0x010100010001ffff, + 0x010100010001ff01, 0x0101000101ff0001, 0x010100010100ffff, 0x0101000101000000, + 0x0101000101000001, 0x0101000101000100, 0x010100010101ff00, 0x01010001010100ff, + 0x0101000101010001, 0x010101ffffffffff, 0x010101ffffffff01, 0x010101ffffff01ff, + 0x010101ffffff0101, 0x010101ffff01ffff, 0x010101ffff01ff01, 0x010101ffff0101ff, + 0x010101ffff010101, 0x010101ff0000ff00, 0x010101ff000000ff, 0x010101ff00000001, + 0x010101ff00000100, 0x010101ff01ffffff, 0x010101ff01ffff01, 0x010101ff01ff01ff, + 0x010101ff01ff0101, 0x010101ff01000000, 0x010101ff0101ffff, 0x010101ff0101ff01, + 0x010101ff010101ff, 0x010101ff01010101, 0x01010100ffff0000, 0x01010100ff0000ff, + 0x01010100ff000100, 0x01010100ff01ff00, 0x01010100ff010000, 0x0101010000ffff00, + 0x010101000000ffff, 0x0101010000000000, 0x0101010000000101, 0x010101000001ff00, + 0x0101010000010001, 0x0101010000010100, 0x010101000100ffff, 0x0101010001000001, + 0x01010101ffffffff, 0x01010101ffffff01, 0x01010101ffff01ff, 0x01010101ffff0101, + 0x01010101ff01ffff, 0x01010101ff01ff01, 0x01010101ff0101ff, 0x01010101ff010101, + 0x010101010000ff00, 0x01010101000000ff, 0x0101010100000001, 0x0101010101ffffff, + 0x0101010101ffff01, 0x0101010101ff01ff, 0x0101010101ff0101, 0x0101010101000000, + 0x010101010101ffff, 0x010101010101ff01, 0x01010101010101ff, 0x0101010101010101, +GGML_TABLE_END() +#else +GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S) + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, +GGML_TABLE_END() +#endif + +#endif // GGML_COMMON_IMPL +#endif // GGML_COMMON_IMPL diff --git a/bindings/ruby/ext/ggml-cuda.h b/bindings/ruby/ext/ggml-cuda.h new file mode 100644 index 00000000000..5eb4af40f4d --- /dev/null +++ b/bindings/ruby/ext/ggml-cuda.h @@ -0,0 +1,43 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_HIPBLAS +#define GGML_CUDA_NAME "ROCm" +#define GGML_CUBLAS_NAME "hipBLAS" +#else +#define GGML_CUDA_NAME "CUDA" +#define GGML_CUBLAS_NAME "cuBLAS" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_CUDA_MAX_DEVICES 16 + +// backend API +GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device); + +GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend); + +// device buffer +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); + +// split tensor buffer that splits matrices by rows across multiple devices +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split); + +// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void); + +GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void); +GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size); +GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total); + +GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size); +GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer); + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ggml-impl.h b/bindings/ruby/ext/ggml-impl.h index d88f261449f..93a4f1a2b72 100644 --- a/bindings/ruby/ext/ggml-impl.h +++ b/bindings/ruby/ext/ggml-impl.h @@ -5,6 +5,7 @@ // GGML internal header #include +#include // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/ #include #include #include // memcpy @@ -18,6 +19,7 @@ extern "C" { // fall back to the _Static_assert C11 keyword. // if C99 - static_assert is noop // ref: https://stackoverflow.com/a/53923785/4039976 +#ifndef __cplusplus #ifndef static_assert #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) #define static_assert(cond, msg) _Static_assert(cond, msg) @@ -25,6 +27,7 @@ extern "C" { #define static_assert(cond, msg) struct global_scope_noop_trick #endif #endif +#endif // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) @@ -34,16 +37,17 @@ extern "C" { #ifndef __F16C__ #define __F16C__ #endif +#endif + +// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available +#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)) #ifndef __SSE3__ #define __SSE3__ #endif +#ifndef __SSSE3__ +#define __SSSE3__ +#endif #endif - -#undef MIN -#undef MAX - -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) // 16-bit float // on Arm, we use __fp16 @@ -56,14 +60,30 @@ extern "C" { // #include -#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) -#define GGML_COMPUTE_FP32_TO_FP16(x) (x) +typedef __fp16 ggml_fp16_internal_t; -#define GGML_FP16_TO_FP32(x) ((float) (x)) -#define GGML_FP32_TO_FP16(x) (x) +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + ggml_fp16_internal_t tmp; + memcpy(&tmp, &h, sizeof(ggml_fp16_t)); + return (float)tmp; +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + ggml_fp16_t res; + ggml_fp16_internal_t tmp = f; + memcpy(&res, &tmp, sizeof(ggml_fp16_t)); + return res; +} #else +typedef uint16_t ggml_fp16_internal_t; + #ifdef __wasm_simd128__ #include #else @@ -217,8 +237,7 @@ extern float ggml_table_f32_f16[1 << 16]; // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. -#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) - +#if !defined(GGML_FP16_TO_FP32) inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { uint16_t s; memcpy(&s, &f, sizeof(uint16_t)); @@ -226,19 +245,23 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { } #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) +#endif +#if !defined(GGML_FP32_TO_FP16) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) #endif #define GGML_HASHTABLE_FULL ((size_t)-1) #define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2) +struct ggml_hash_set ggml_hash_set_new(size_t size); + bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key); // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key); -// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full +// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key); // return index, asserts if table is full diff --git a/bindings/ruby/ext/ggml-kompute.h b/bindings/ruby/ext/ggml-kompute.h new file mode 100644 index 00000000000..171465456a5 --- /dev/null +++ b/bindings/ruby/ext/ggml-kompute.h @@ -0,0 +1,46 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_vk_device { + int index; + int type; // same as VkPhysicalDeviceType + size_t heapSize; + const char * name; + const char * vendor; + int subgroupSize; + uint64_t bufferAlignment; + uint64_t maxAlloc; +}; + +struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count); +bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name); +bool ggml_vk_has_vulkan(void); +bool ggml_vk_has_device(void); +struct ggml_vk_device ggml_vk_current_device(void); + +// +// backend API +// + +// forward declaration +typedef struct ggml_backend * ggml_backend_t; + +GGML_API ggml_backend_t ggml_backend_kompute_init(int device); + +GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend); + +GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device); + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ggml-metal.h b/bindings/ruby/ext/ggml-metal.h new file mode 100644 index 00000000000..a5c542189c2 --- /dev/null +++ b/bindings/ruby/ext/ggml-metal.h @@ -0,0 +1,66 @@ +// An interface allowing to compute ggml_cgraph with Metal +// +// This is a fully functional interface that extends ggml with GPU support for Apple devices. +// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.) +// +// How it works? +// +// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this +// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you +// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.) +// +// You only need to make sure that all memory buffers that you used during the graph creation +// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is +// used during the graph evaluation to determine the arguments of the compute kernels. +// +// Synchronization between device and host memory (for example for input and output tensors) +// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions. +// + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include +#include + +// max memory buffers that can be mapped to the device +#define GGML_METAL_MAX_BUFFERS 64 + +struct ggml_tensor; +struct ggml_cgraph; + +#ifdef __cplusplus +extern "C" { +#endif + +// +// backend API +// user-code should use only these functions +// + +GGML_API void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data); + +GGML_API ggml_backend_t ggml_backend_metal_init(void); + +GGML_API bool ggml_backend_is_metal(ggml_backend_t backend); + +GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size); + +GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb); + +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); + +// helper to check if the device supports a specific family +// ideally, the user code should be doing these checks +// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf +GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family); + +// capture all command buffers committed the next time `ggml_backend_graph_compute` is called +GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); + +#ifdef __cplusplus +} +#endif + diff --git a/bindings/ruby/ext/ggml-opencl.h b/bindings/ruby/ext/ggml-opencl.h new file mode 100644 index 00000000000..257a6be6af5 --- /dev/null +++ b/bindings/ruby/ext/ggml-opencl.h @@ -0,0 +1,36 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +GGML_API void ggml_cl_init(void); + +GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); +GGML_API void ggml_cl_add(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); +GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst); +GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); +GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize); + +// GGML_API void * ggml_cl_host_malloc(size_t size); +// GGML_API void ggml_cl_host_free(void * ptr); + +GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor); + +GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor); + +// backend API + +// GGML_API ggml_backend_t ggml_backend_opencl_init(void); + +// GGML_API bool ggml_backend_is_opencl(ggml_backend_t backend); + +GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void); +// GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void); + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ggml-quants.c b/bindings/ruby/ext/ggml-quants.c index 740be6dc5c7..32e84434a8c 100644 --- a/bindings/ruby/ext/ggml-quants.c +++ b/bindings/ruby/ext/ggml-quants.c @@ -1,10 +1,18 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + #include "ggml-quants.h" #include "ggml-impl.h" +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + #include #include #include #include +#include // for qsort +#include // for GGML_ASSERT #ifdef __ARM_NEON @@ -14,32 +22,12 @@ // #include -#if !defined(__aarch64__) -inline static int32_t vaddvq_s16(int16x8_t v) { - return - (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + - (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + - (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + - (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); -} - -inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { - int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); - int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); - return vcombine_s16(a0, b0); -} - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} -#endif - #else #ifdef __wasm_simd128__ #include #else -#ifdef __POWER9_VECTOR__ +#if defined(__POWER9_VECTOR__) || defined(__powerpc64__) #include #undef bool #define bool _Bool @@ -47,13 +35,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) { #if defined(_MSC_VER) || defined(__MINGW32__) #include #else -#if !defined(__riscv) && !defined(__s390__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#if !defined(__riscv) #include #endif #endif #endif #endif #endif +#endif #ifdef __riscv_v_intrinsic #include @@ -61,9 +51,13 @@ inline static int32_t vaddvq_s32(int32x4_t v) { #undef MIN #undef MAX + #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define UNUSED GGML_UNUSED + +// some compilers don't provide _mm256_set_m128i, e.g. gcc 7 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) @@ -138,7 +132,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { } static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if __AVXVNNI__ +#if defined(__AVXVNNI__) || defined(__AVX512VNNI__) const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); return _mm256_cvtepi32_ps(summed_pairs); @@ -284,8 +278,50 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #if defined(__ARM_NEON) +#ifdef _MSC_VER + +#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } + +#else + +#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } + +#endif + #if !defined(__aarch64__) +// 64-bit compatibility + +// vaddvq_s16 +// vpaddq_s16 +// vpaddq_s32 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// vcvtnq_s32_f32 +// vzip1_u8 +// vzip2_u8 + +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { + int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +} + inline static int32_t vaddvq_s32(int32x4_t v) { return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); } @@ -311,7 +347,185 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { return res; } +inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[0]; res[1] = b[0]; + res[2] = a[1]; res[3] = b[1]; + res[4] = a[2]; res[5] = b[2]; + res[6] = a[3]; res[7] = b[3]; + + return res; +} + +inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[4]; res[1] = b[4]; + res[2] = a[5]; res[3] = b[5]; + res[4] = a[6]; res[5] = b[6]; + res[6] = a[7]; res[7] = b[7]; + + return res; +} + +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct ggml_int16x8x2_t { + int16x8_t val[2]; +} ggml_int16x8x2_t; + +inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { + ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct ggml_uint8x16x2_t { + uint8x16_t val[2]; +} ggml_uint8x16x2_t; + +inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { + ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct ggml_uint8x16x4_t { + uint8x16_t val[4]; +} ggml_uint8x16x4_t; + +inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { + ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct ggml_int8x16x2_t { + int8x16_t val[2]; +} ggml_int8x16x2_t; + +inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { + ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct ggml_int8x16x4_t { + int8x16_t val[4]; +} ggml_int8x16x4_t; + +inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { + ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +// NOTE: not tested +inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +// NOTE: not tested +inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { + uint8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +#else + +#define ggml_int16x8x2_t int16x8x2_t +#define ggml_uint8x16x2_t uint8x16x2_t +#define ggml_uint8x16x4_t uint8x16x4_t +#define ggml_int8x16x2_t int8x16x2_t +#define ggml_int8x16x4_t int8x16x4_t + +#define ggml_vld1q_s16_x2 vld1q_s16_x2 +#define ggml_vld1q_u8_x2 vld1q_u8_x2 +#define ggml_vld1q_u8_x4 vld1q_u8_x4 +#define ggml_vld1q_s8_x2 vld1q_s8_x2 +#define ggml_vld1q_s8_x4 vld1q_s8_x4 +#define ggml_vqtbl1q_s8 vqtbl1q_s8 +#define ggml_vqtbl1q_u8 vqtbl1q_u8 + +#endif + +#if !defined(__ARM_FEATURE_DOTPROD) + +inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { + const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); + const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + + return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); +} + +#else + +#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) + #endif + #endif #if defined(__ARM_NEON) || defined(__wasm_simd128__) @@ -330,7 +544,7 @@ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 #endif // reference implementation for deterministic creation of model files -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { +void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -367,11 +581,12 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict } } -void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { +void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q4_0_reference(x, y, k); } -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { + +void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; assert(k % qk == 0); @@ -408,11 +623,11 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict } } -void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { +void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q4_1_reference(x, y, k); } -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { +void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -456,11 +671,11 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict } } -void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { +void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q5_0_reference(x, y, k); } -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { +void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; assert(k % qk == 0); @@ -504,12 +719,12 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict } } -void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { +void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q5_1_reference(x, y, k); } // reference implementation for deterministic creation of model files -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { +void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -534,7 +749,7 @@ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict } } -void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -723,7 +938,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { } // reference implementation for deterministic creation of model files -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { +void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -739,7 +954,7 @@ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); int sum = 0; @@ -754,11 +969,11 @@ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict sum += y[i].qs[QK8_1/2 + j]; } - y[i].s = sum*d; + y[i].s = GGML_FP32_TO_FP16(sum*d); } } -void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -782,7 +997,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); int32x4_t accv = vdupq_n_s32(0); @@ -798,7 +1013,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { accv = vaddq_s32(accv, vi); } - y[i].s = d * vaddvq_s32(accv); + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } #elif defined(__wasm_simd128__) for (int i = 0; i < nb; i++) { @@ -821,7 +1036,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); v128_t accv = wasm_i32x4_splat(0); @@ -837,10 +1052,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { accv = wasm_i32x4_add(accv, vi); } - y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) + - wasm_i32x4_extract_lane(accv, 1) + - wasm_i32x4_extract_lane(accv, 2) + - wasm_i32x4_extract_lane(accv, 3)); + y[i].s = GGML_FP32_TO_FP16( + d * (wasm_i32x4_extract_lane(accv, 0) + + wasm_i32x4_extract_lane(accv, 1) + + wasm_i32x4_extract_lane(accv, 2) + + wasm_i32x4_extract_lane(accv, 3))); } #elif defined(__AVX2__) || defined(__AVX__) for (int i = 0; i < nb; i++) { @@ -865,7 +1081,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { // Quantize these floats const float d = maxScalar / 127.f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -889,7 +1105,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { #if defined(__AVX2__) // Compute the sum of the quants and set y[i].s - y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); // Convert int32 to int16 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 @@ -919,7 +1135,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { // Compute the sum of the quants and set y[i].s const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); - y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1)); + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1))); // Convert int32 to int16 ni0 = _mm_packs_epi32( ni0, ni1 ); @@ -950,7 +1166,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); @@ -967,7 +1183,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { // set y[i].s int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); - y[i].s = sum*d; + y[i].s = GGML_FP32_TO_FP16(sum*d); } #else GGML_UNUSED(nb); @@ -976,7 +1192,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { #endif } -void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -996,7 +1212,7 @@ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int } } -void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { +void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_1; assert(k % qk == 0); @@ -1017,7 +1233,7 @@ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int } } -void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -1043,7 +1259,7 @@ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int } } -void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { +void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int64_t k) { static const int qk = QK5_1; assert(k % qk == 0); @@ -1070,7 +1286,7 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int } } -void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK8_0; assert(k % qk == 0); @@ -1100,7 +1316,8 @@ static inline int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } -static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) { +static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type, + const float * restrict qw) { float max = 0; float amax = 0; for (int i = 0; i < n; ++i) { @@ -1126,14 +1343,18 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * rmse_type = -rmse_type; return_early = true; } - int weight_type = rmse_type%2; float sumlx = 0; float suml2 = 0; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 0; i < n; ++i) { +#else for (int i = 0; i < n; ++i) { +#endif int l = nearest_int(iscale * x[i]); l = MAX(-nmax, MIN(nmax-1, l)); L[i] = l + nmax; - float w = weight_type == 1 ? x[i] * x[i] : 1; + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); sumlx += w*x[i]*l; suml2 += w*l*l; } @@ -1149,7 +1370,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * for (int i = 0; i < n; ++i) { int l = nearest_int(iscale * x[i]); l = MAX(-nmax, MIN(nmax-1, l)); - float w = weight_type == 1 ? x[i] * x[i] : 1; + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); sumlx += w*x[i]*l; suml2 += w*l*l; } @@ -1273,7 +1494,12 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const f float max = x[0]; float sum_w = weights[0]; float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else for (int i = 1; i < n; ++i) { +#endif if (x[i] < min) min = x[i]; if (x[i] > max) max = x[i]; float w = weights[i]; @@ -1355,7 +1581,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * //========================- 2-bit (de)-quantization -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) { +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1432,7 +1658,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict } } -void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) { +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1478,64 +1704,322 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int } } -void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { quantize_row_q2_K_reference(x, vy, k); } -size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - (void)hist; // TODO: collect histograms - - for (int j = 0; j < n; j += k) { - block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K; - quantize_row_q2_K_reference(src + j, y, k); +static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, + uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights ? weights[0] : x[0]*x[0]; + float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else + for (int i = 1; i < n; ++i) { +#endif + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights ? weights[i] : x[i]*x[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) { + min = 0; + } + if (max <= min) { + memset(L, 0, n); + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff*diff; + float w = weights ? weights[i] : x[i]*x[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights ? weights[i] : x[i]*x[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff*diff; + float w = weights ? weights[i] : x[i]*x[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } } - return (n/QK_K*sizeof(block_q2_K)); + *the_min = -min; + return scale; } -//========================= 3-bit (de)-quantization +static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) { + float max = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + if (!max) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = nmax / max; + for (int i = 0; i < n; ++i) { + L[i] = nearest_int(iscale * x[i]); + } + float scale = 1/iscale; + float best_mse = 0; + for (int i = 0; i < n; ++i) { + float diff = x[i] - scale*L[i]; + float w = quant_weights[i]; + best_mse += w*diff*diff; + } + for (int is = -4; is <= 4; ++is) { + if (is == 0) continue; + float iscale_is = (0.1f*is + nmax)/max; + float scale_is = 1/iscale_is; + float mse = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale_is*x[i]); + l = MIN(nmax, l); + float diff = x[i] - scale_is*l; + float w = quant_weights[i]; + mse += w*diff*diff; + } + if (mse < best_mse) { + best_mse = mse; + iscale = iscale_is; + } + } + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MIN(nmax, l); + L[i] = l; + float w = quant_weights[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = quant_weights[i]; + float slx = sumlx - w*x[i]*L[i]; + float sl2 = suml2 - w*L[i]*L[i]; + if (slx > 0 && sl2 > 0) { + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MIN(nmax, new_l); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + return sumlx / suml2; +} -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) { +static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) { + GGML_ASSERT(quant_weights); assert(k % QK_K == 0); const int nb = k / QK_K; + const bool requantize = true; - int8_t L[QK_K]; - float scales[QK_K / 16]; + uint8_t L[QK_K]; + uint8_t Laux[16]; + float mins[QK_K/16]; + float scales[QK_K/16]; + float sw[QK_K/16]; + float weight[16]; + uint8_t Ls[QK_K/16], Lm[QK_K/16]; for (int i = 0; i < nb; i++) { - - float max_scale = 0; - float amax = 0; + memset(sw, 0, QK_K/16*sizeof(float)); + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; + float sigma2 = sumx2/QK_K; for (int j = 0; j < QK_K/16; ++j) { - scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); - float scale = fabsf(scales[j]); - if (scale > amax) { - amax = scale; max_scale = scales[j]; - } + const float * restrict qw = quant_weights + QK_K * i + 16*j; + for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); + for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l]; + scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); } -#if QK_K == 256 - memset(y[i].scales, 0, 12); + float dm, mm; +#if QK_K == 64 + float max_scale = 0, max_min = 0; + for (int j = 0; j < QK_K/16; ++j) { + max_scale = MAX(max_scale, scales[j]); + max_min = MAX(max_min, mins[j]); + } + dm = max_scale/15; + mm = max_min/15; if (max_scale) { - float iscale = -32.f/max_scale; + float id = 1/dm; for (int j = 0; j < QK_K/16; ++j) { - int8_t l = nearest_int(iscale*scales[j]); - l = MAX(-32, MIN(31, l)) + 32; - if (j < 8) { - y[i].scales[j] = l & 0xF; - } else { - y[i].scales[j-8] |= ((l & 0xF) << 4); - } - l >>= 4; - y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + int l = nearest_int(id*scales[j]); + Ls[j] = MAX(0, MIN(15, l)); } - y[i].d = GGML_FP32_TO_FP16(1/iscale); } else { - y[i].d = GGML_FP32_TO_FP16(0.f); + memset(Ls, 0, QK_K/16); } + if (max_min) { + float id = 1/mm; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(id*mins[j]); + Lm[j] = MAX(0, MIN(15, l)); + } + } else { + memset(Lm, 0, QK_K/16); + } +#else + dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); + mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); +#endif + y[i].d = GGML_FP32_TO_FP16(dm); + y[i].dmin = GGML_FP32_TO_FP16(mm); + dm = GGML_FP16_TO_FP32(y[i].d); + mm = GGML_FP16_TO_FP32(y[i].dmin); - int8_t sc; for (int j = 0; j < QK_K/16; ++j) { - sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + y[i].scales[j] = Ls[j] | (Lm[j] << 4); + } + + if (requantize) { + for (int j = 0; j < QK_K/16; ++j) { + const float d = dm * (y[i].scales[j] & 0xF); + if (!d) continue; + const float m = mm * (y[i].scales[j] >> 4); + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int((x[16*j + ii] + m)/d); + l = MAX(0, MIN(3, l)); + L[16*j + ii] = l; + } + } + } + +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif + + x += QK_K; + + } +} + +size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); + if (!quant_weights) { + quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); + float scale = fabsf(scales[j]); + if (scale > amax) { + amax = scale; max_scale = scales[j]; + } + } + +#if QK_K == 256 + memset(y[i].scales, 0, 12); + if (max_scale) { + float iscale = -32.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int8_t l = nearest_int(iscale*scales[j]); + l = MAX(-32, MIN(31, l)) + 32; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = GGML_FP32_TO_FP16(1/iscale); + } else { + y[i].d = GGML_FP32_TO_FP16(0.f); + } + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; float d = GGML_FP16_TO_FP32(y[i].d) * sc; if (!d) { @@ -1608,7 +2092,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict } #if QK_K == 256 -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1658,7 +2142,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } } #else -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); assert(QK_K == 64); const int nb = k / QK_K; @@ -1691,23 +2175,118 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } #endif -void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { quantize_row_q3_K_reference(x, vy, k); } -size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - (void)hist; // TODO: collect histograms +static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { +#if QK_K != 256 + (void)quant_weights; + quantize_row_q3_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int nb = n_per_row / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + float weight[16]; + float sw[QK_K / 16]; + int8_t Ls[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; + float sigma2 = 2*sumx2/QK_K; + + for (int j = 0; j < QK_K/16; ++j) { + if (quant_weights) { + const float * qw = quant_weights ? quant_weights + QK_K * i + 16*j : NULL; + for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]); + } else { + for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l]; + } + float sumw = 0; + for (int l = 0; l < 16; ++l) sumw += weight[l]; + sw[j] = sumw; + + scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight); + + } + + memset(y[i].scales, 0, 12); + + float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw); + for (int j = 0; j < QK_K/16; ++j) { + int l = Ls[j]; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = GGML_FP32_TO_FP16(d_block); + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; + float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } + + memset(y[i].hmask, 0, QK_K/8); + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. + int m = 0; + uint8_t hm = 1; + for (int j = 0; j < QK_K; ++j) { + if (L[j] > 3) { + y[i].hmask[m] |= hm; + L[j] -= 4; + } + if (++m == QK_K/8) { + m = 0; hm <<= 1; + } + } + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } + + x += QK_K; + } +#endif +} - for (int j = 0; j < n; j += k) { - block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K; - quantize_row_q3_K_reference(src + j, y, k); +size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); + if (!quant_weights) { + quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row); } - return (n/QK_K*sizeof(block_q3_K)); + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; } // ====================== 4-bit (de)-quantization -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) { +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1814,7 +2393,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict } } -void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) { +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1853,28 +2432,111 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int } } -void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q4_K * restrict y = vy; quantize_row_q4_K_reference(x, y, k); } -size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - assert(k % QK_K == 0); - (void)hist; // TODO: collect histograms +static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { +#if QK_K != 256 + (void)quant_weights; + quantize_row_q4_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int64_t nb = n_per_row / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[32]; + uint8_t Ls[QK_K/32]; + uint8_t Lm[QK_K/32]; + float weights[32]; + float sw[QK_K/32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + + float sum_x2 = 0; + for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l]; + float sigma2 = 2*sum_x2/QK_K; + float av_x = sqrtf(sigma2); + + for (int j = 0; j < QK_K/32; ++j) { + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 32*j; + for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]); + } else { + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + } + float sumw = 0; + for (int l = 0; l < 32; ++l) sumw += weights[l]; + sw[j] = sumw; + scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + } + + float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw); + float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw); + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = Ls[j]; + uint8_t lm = Lm[j]; + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = GGML_FP32_TO_FP16(d_block); + y[i].dmin = GGML_FP32_TO_FP16(m_block); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } + uint8_t * q = y[i].qs; + for (int j = 0; j < QK_K; j += 64) { + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); + q += 32; + } + + x += QK_K; + + } +#endif +} - for (int j = 0; j < n; j += k) { - block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K; - quantize_row_q4_K_reference(src + j, y, k); +size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); + if (!quant_weights) { + quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } } - return (n/QK_K*sizeof(block_q4_K)); + return nrow * row_size; } // ====================== 5-bit (de)-quantization -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) { +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 uint8_t L[QK_K]; @@ -1965,7 +2627,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict #else float max_scale = 0, amax = 0; for (int j = 0; j < QK_K/16; ++j) { - scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1); + scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1, NULL); float abs_scale = fabsf(scales[j]); if (abs_scale > amax) { amax = abs_scale; @@ -2014,9 +2676,9 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict } } -void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) { +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -2059,78 +2721,181 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int } } -void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q5_K * restrict y = vy; quantize_row_q5_K_reference(x, y, k); } -size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - assert(k % QK_K == 0); - (void)hist; // TODO: collect histograms - - for (int j = 0; j < n; j += k) { - block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; - quantize_row_q5_K_reference(src + j, y, k); - } - return (n/QK_K*sizeof(block_q5_K)); -} - -// ====================== 6-bit (de)-quantization - -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; +static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { +#if QK_K != 256 + (void)quant_weights; + quantize_row_q5_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int64_t nb = n_per_row / QK_K; - int8_t L[QK_K]; - float scales[QK_K/16]; + uint8_t L[QK_K]; + uint8_t Laux[32]; + uint8_t Ls[QK_K/32]; + uint8_t Lm[QK_K/32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + float sw[QK_K/32]; + float weights[32]; for (int i = 0; i < nb; i++) { - float max_scale = 0; - float max_abs_scale = 0; - - for (int ib = 0; ib < QK_K/16; ++ib) { - - const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1); - scales[ib] = scale; + float sum_x2 = 0; + for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l]; + float sigma2 = 2*sum_x2/QK_K; + float av_x = sqrtf(sigma2); - const float abs_scale = fabsf(scale); - if (abs_scale > max_abs_scale) { - max_abs_scale = abs_scale; - max_scale = scale; + for (int j = 0; j < QK_K/32; ++j) { + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 32*j; + for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]); + } else { + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); } + float sumw = 0; + for (int l = 0; l < 32; ++l) sumw += weights[l]; + sw[j] = sumw; + scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); } - if (!max_abs_scale) { - memset(&y[i], 0, sizeof(block_q6_K)); - y[i].d = GGML_FP32_TO_FP16(0.f); - x += QK_K; - continue; - } + float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw); + float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw); - float iscale = -128.f/max_scale; - y[i].d = GGML_FP32_TO_FP16(1/iscale); - for (int ib = 0; ib < QK_K/16; ++ib) { - y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = Ls[j]; + uint8_t lm = Lm[j]; + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } } + y[i].d = GGML_FP32_TO_FP16(d_block); + y[i].dmin = GGML_FP32_TO_FP16(m_block); - for (int j = 0; j < QK_K/16; ++j) { - float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; - if (!d) { - continue; - } - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int(x[16*j + ii]/d); - l = MAX(-32, MIN(31, l)); - L[16*j + ii] = l + 32; + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; } } - uint8_t * restrict ql = y[i].ql; uint8_t * restrict qh = y[i].qh; -#if QK_K == 256 + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + uint8_t m1 = 1, m2 = 2; + for (int n = 0; n < QK_K; n += 64) { + for (int j = 0; j < 32; ++j) { + int l1 = L[n + j]; + if (l1 > 15) { + l1 -= 16; qh[j] |= m1; + } + int l2 = L[n + j + 32]; + if (l2 > 15) { + l2 -= 16; qh[j] |= m2; + } + ql[j] = l1 | (l2 << 4); + } + m1 <<= 2; m2 <<= 2; + ql += 32; + } + + x += QK_K; + + } +#endif +} + +size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); + if (!quant_weights) { + quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K/16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + + const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL); + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; + } + + } + + if (!max_abs_scale) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = GGML_FP32_TO_FP16(0.f); + x += QK_K; + continue; + } + + float iscale = -128.f/max_scale; + y[i].d = GGML_FP32_TO_FP16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + } + + for (int j = 0; j < QK_K/16; ++j) { + float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; + } + } + + uint8_t * restrict ql = y[i].ql; + uint8_t * restrict qh = y[i].qh; +#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { const uint8_t q1 = L[j + l + 0] & 0xF; @@ -2160,9 +2925,9 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict } } -void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) { +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -2207,459 +2972,815 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int } } -void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q6_K * restrict y = vy; quantize_row_q6_K_reference(x, y, k); } -size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK_K == 0); - (void)hist; // TODO: collect histograms +static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { +#if QK_K != 256 + (void)quant_weights; + quantize_row_q6_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int64_t nb = n_per_row / QK_K; - for (int j = 0; j < n; j += k) { - block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K; - quantize_row_q6_K_reference(src + j, y, k); - } - return (n/QK_K*sizeof(block_q6_K)); -} + int8_t L[QK_K]; + float scales[QK_K/16]; + //float weights[16]; -//===================================== Q8_K ============================================== + for (int i = 0; i < nb; i++) { -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; + //float sum_x2 = 0; + //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j]; + //float sigma2 = sum_x2/QK_K; - for (int i = 0; i < nb; i++) { + float max_scale = 0; + float max_abs_scale = 0; - float max = 0; - float amax = 0; - for (int j = 0; j < QK_K; ++j) { - float ax = fabsf(x[j]); - if (ax > amax) { - amax = ax; max = x[j]; + for (int ib = 0; ib < QK_K/16; ++ib) { + + float scale; + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 16*ib; + //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]); + //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights); + scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw); + } else { + scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL); + } + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; } + } - if (!amax) { - y[i].d = 0; - memset(y[i].qs, 0, QK_K); + + if (!max_abs_scale) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = GGML_FP32_TO_FP16(0.f); x += QK_K; continue; } - const float iscale = -128.f/max; - for (int j = 0; j < QK_K; ++j) { - int v = nearest_int(iscale*x[j]); - y[i].qs[j] = MIN(127, v); + + float iscale = -128.f/max_scale; + y[i].d = GGML_FP32_TO_FP16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); } + for (int j = 0; j < QK_K/16; ++j) { - int sum = 0; + float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } for (int ii = 0; ii < 16; ++ii) { - sum += y[i].qs[j*16 + ii]; + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; } - y[i].bsums[j] = sum; } - y[i].d = 1/iscale; + + uint8_t * restrict ql = y[i].ql; + uint8_t * restrict qh = y[i].qh; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[j + l + 0] & 0xF; + const uint8_t q2 = L[j + l + 32] & 0xF; + const uint8_t q3 = L[j + l + 64] & 0xF; + const uint8_t q4 = L[j + l + 96] & 0xF; + ql[l+ 0] = q1 | (q3 << 4); + ql[l+32] = q2 | (q4 << 4); + qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); + } + ql += 64; + qh += 32; + } + x += QK_K; + } +#endif } -void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < QK_K; ++j) { - *y++ = x[i].d * x[i].qs[j]; +size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); + if (!quant_weights) { + quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; } } + return nrow * row_size; } -void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) { - quantize_row_q8_K_reference(x, y, k); -} +static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK4_0 == 32, "QK4_0 must be 32"); -//===================================== Dot ptoducts ================================= + if (!quant_weights) { + quantize_row_q4_0_reference(x, y, n_per_row); + return; + } -// -// Helper functions -// -#if __AVX__ || __AVX2__ || __AVX512F__ + float weight[QK4_0]; + int8_t L[QK4_0]; -// shuffles to pick the required scales in dot products -static inline __m256i get_scale_shuffle_q3k(int i) { - static const uint8_t k_shuffle[128] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m256i get_scale_shuffle_k4(int i) { - static const uint8_t k_shuffle[256] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, - 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, - 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, - 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int64_t nb = n_per_row/QK4_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK4_0 * ib; + const float * qw = quant_weights + QK4_0 * ib; + for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight); + y[ib].d = GGML_FP32_TO_FP16(d); + for (int j = 0; j < 16; ++j) { + y[ib].qs[j] = L[j] | (L[j+16] << 4); + } + } } -static inline __m128i get_scale_shuffle(int i) { - static const uint8_t k_shuffle[128] = { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, - 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, - 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, - 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, - 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 - }; - return _mm_loadu_si128((const __m128i*)k_shuffle + i); + +size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; } -#endif -void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; +static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK4_1 == 32, "QK4_1 must be 32"); - assert(n % qk == 0); + if (!quant_weights) { + quantize_row_q4_1_reference(x, y, n_per_row); + return; + } - const block_q4_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; + float weight[QK4_1]; + uint8_t L[QK4_1], Laux[QK4_1]; -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; - assert(nb % 2 == 0); // TODO: handle odd nb + const int64_t nb = n_per_row/QK4_1; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK4_1 * ib; + const float * qw = quant_weights + QK4_1 * ib; + for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float min; + float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false); + y[ib].d = GGML_FP32_TO_FP16(d); + y[ib].m = GGML_FP32_TO_FP16(-min); + for (int j = 0; j < 16; ++j) { + y[ib].qs[j] = L[j] | (L[j+16] << 4); + } + } +} - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; +size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); +static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK5_0 == 32, "QK5_0 must be 32"); - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); + if (!quant_weights) { + quantize_row_q5_0_reference(x, y, n_per_row); + return; + } - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + float weight[QK5_0]; + int8_t L[QK5_0]; - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + const int64_t nb = n_per_row/QK5_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK5_0 * ib; + const float * qw = quant_weights + QK5_0 * ib; + for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight); + y[ib].d = GGML_FP32_TO_FP16(d); -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + uint32_t qh = 0; - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif + for (int j = 0; j < 16; ++j) { + const uint8_t xi0 = L[j]; + const uint8_t xi1 = L[j+16]; + y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + + memcpy(&y[ib].qh, &qh, sizeof(qh)); } +} - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); +size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} - // Main loop - for (int i = 0; i < nb; ++i) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); +static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK5_1 == 32, "QK5_1 must be 32"); - __m256i bx = bytes_from_nibbles_32(x[i].qs); + if (!quant_weights) { + quantize_row_q5_1_reference(x, y, n_per_row); + return; + } - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8( 8 ); - bx = _mm256_sub_epi8( bx, off ); + float weight[QK5_1]; + uint8_t L[QK5_1], Laux[QK5_1]; - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; - const __m256 q = mul_sum_i8_pairs_float(bx, by); + const int64_t nb = n_per_row/QK5_1; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK5_1 * ib; + const float * qw = quant_weights + QK5_1 * ib; + for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float min; + float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false); + y[ib].d = GGML_FP32_TO_FP16(d); + y[ib].m = GGML_FP32_TO_FP16(-min); - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( d, q, acc ); + uint32_t qh = 0; + for (int j = 0; j < 16; ++j) { + const uint8_t xi0 = L[j]; + const uint8_t xi1 = L[j+16]; + y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + memcpy(&y[ib].qh, &qh, sizeof(qh)); } +} - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); +size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); +size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); + quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); +// ====================== "True" 2-bit (de)-quantization - const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); +void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - __m128i bx = _mm_and_si128(lowMask, tmp); - __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx, by); + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; - bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx, by); + for (int i = 0; i < nb; i++) { - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); + const float d = GGML_FP16_TO_FP32(x[i].d); - // Apply the scale, and accumulate - acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t)); + const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + } } +} - *s = hsum_float_8(acc); -#elif defined(__SSSE3__) - // set constants - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - // Initialize accumulator with zeros - __m128 acc_0 = _mm_setzero_ps(); - __m128 acc_1 = _mm_setzero_ps(); - __m128 acc_2 = _mm_setzero_ps(); - __m128 acc_3 = _mm_setzero_ps(); - - // First round without accumulation - { - _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); +// ====================== 2.3125 bpw (de)-quantization - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); +void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + float db[2]; - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + for (int i = 0; i < nb; i++) { - _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); + const float d = GGML_FP16_TO_FP32(x[i].d); - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; + db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511)); + const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9]; + for (int j = 0; j < 8; ++j) { + y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + } + } +} - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); +// ====================== 2.5625 bpw (de)-quantization - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); +void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + float db[2]; - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); + for (int i = 0; i < nb; i++) { - // Apply the scale - acc_0 = _mm_mul_ps( d_0_1, p0 ); - acc_1 = _mm_mul_ps( d_0_1, p1 ); - acc_2 = _mm_mul_ps( d_2_3, p2 ); - acc_3 = _mm_mul_ps( d_2_3, p3 ); + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; + db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const float dl = db[l/2]; + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + qs += 4; + signs += 4; + } } +} - assert(nb % 2 == 0); // TODO: handle odd nb - - // Main loop - for (int i = 2; i < nb; i+=2) { - _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); +// ====================== 3.0625 bpw (de)-quantization - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); +void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + uint32_t aux32; - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + for (int i = 0; i < nb; i++) { - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * scales_and_signs = qs + QK_K/4; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t)); + const float db = d * (0.5f + (aux32 >> 28)) * 0.5f; + for (int l = 0; l < 4; ++l) { + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]); + for (int j = 0; j < 4; ++j) { + y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qs += 8; + } + } +} - _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); +// ====================== 3.3125 bpw (de)-quantization - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); +void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); + for (int i = 0; i < nb; i++) { - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = x[i].signs; + + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qs += 8; + signs += 4; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qh += 2; + qs += 8; + signs += 4; + } + } +} - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); +// ====================== 1.5625 bpw (de)-quantization - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); +void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - // Apply the scale - __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); - __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); - __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); - __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + for (int i = 0; i < nb; i++) { - // Acummulate - acc_0 = _mm_add_ps(p0_d, acc_0); - acc_1 = _mm_add_ps(p1_d, acc_1); - acc_2 = _mm_add_ps(p2_d, acc_2); - acc_3 = _mm_add_ps(p3_d, acc_3); + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const float dl = d * (2*((qh[ib] >> 12) & 7) + 1); + const float delta = qh[ib] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + y[j] = dl * (grid[j] + delta); + } + y += 8; + } + qs += 4; + } } +} - *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); +void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - for (int i = 0; i < nb; i++) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + float delta[4]; + uint16_t idx[4]; - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); +#if QK_K != 64 + iq1m_scale_t scale; +#endif - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + for (int i = 0; i < nb; i++) { - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + const uint16_t * sc = (const uint16_t *)x[i].scales; +#if QK_K == 64 + const float d = GGML_FP16_TO_FP32(x[i].d); +#else + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = GGML_FP16_TO_FP32(scale.f16); +#endif + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; - // subtract offset - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); + for (int ib = 0; ib < QK_K/32; ++ib) { +#if QK_K == 64 + const float dl1 = d * (2*((sc[ib/2] >> (8*(ib%2)+0)) & 0xf) + 1); + const float dl2 = d * (2*((sc[ib/2] >> (8*(ib%2)+4)) & 0xf) + 1); +#else + const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1); + const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1); +#endif + idx[0] = qs[0] | ((qh[0] << 8) & 0x700); + idx[1] = qs[1] | ((qh[0] << 4) & 0x700); + idx[2] = qs[2] | ((qh[1] << 8) & 0x700); + idx[3] = qs[3] | ((qh[1] << 4) & 0x700); + delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA; + for (int l = 0; l < 2; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + for (int j = 0; j < 8; ++j) { + y[j] = dl1 * (grid[j] + delta[l]); + } + y += 8; + } + for (int l = 2; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + for (int j = 0; j < 8; ++j) { + y[j] = dl2 * (grid[j] + delta[l]); + } + y += 8; + } + qs += 4; + qh += 2; + } + } +} - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); +static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); +void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) { + assert(k % QK4_NL == 0); + const int64_t nb = k / QK4_NL; - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + for (int i = 0; i < nb; i++) { - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + const uint8_t * qs = x[i].qs; - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + const float d = GGML_FP16_TO_FP32(x[i].d); + for (int j = 0; j < QK4_NL/2; ++j) { + y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf]; + y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4]; + } + y += QK4_NL; + qs += QK4_NL/2; } +} - *s = sumf; +void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); +#if QK_K == 64 + dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k); #else - // scalar - float sumf = 0.0; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { - int sumi = 0; - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F) - 8; - const int v1 = (x[i].qs[j] >> 4) - 8; + const uint8_t * qs = x[i].qs; - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4); + const float dl = d * (ls - 32); + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf]; + y[j+16] = dl * kvalues_iq4nl[qs[j] >> 4]; + } + y += 32; + qs += 16; } + } +#endif +} - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); +//===================================== Q8_K ============================================== + +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + //const float iscale = -128.f/max; + // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward + const float iscale = -127.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += QK_K; } +} - *s = sumf; -#endif +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK_K; ++j) { + *y++ = x[i].d * x[i].qs[j]; + } + } } -void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; +void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q8_K_reference(x, y, k); +} + +//===================================== Dot ptoducts ================================= + +// +// Helper functions +// +#if __AVX__ || __AVX2__ || __AVX512F__ + +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} +#endif + +void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; const int nb = n / qk; assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q4_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; - // TODO: add WASM SIMD +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_0 * restrict vx0 = vx; + const block_q4_0 * restrict vx1 = vx + bx; + + const block_q8_0 * restrict vy0 = vy; + const block_q8_0 * restrict vy1 = vy + by; + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_0 * restrict b_x0 = &vx0[i]; + const block_q4_0 * restrict b_x1 = &vx1[i]; + const block_q8_0 * restrict b_y0 = &vy0[i]; + const block_q8_0 * restrict b_y1 = &vy1[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t x0_l = vsubq_s8(v0_0l, s8b); + const int8x16_t x0_h = vsubq_s8(v0_0h, s8b); + const int8x16_t x1_l = vsubq_s8(v0_1l, s8b); + const int8x16_t x1_h = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + return; + } +#endif #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - float summs = 0; - assert(nb % 2 == 0); // TODO: handle odd nb for (int i = 0; i < nb; i += 2) { - const block_q4_1 * restrict x0 = &x[i + 0]; - const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i + 0]; - const block_q8_1 * restrict y1 = &y[i + 1]; - - summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2670,393 +3791,445 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; -#elif defined(__AVX2__) || defined(__AVX__) + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); - float summs = 0; - // Main loop for (int i = 0; i < nb; ++i) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); - const float d1 = y[i].d; - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - const __m256 d0v = _mm256_set1_ps( d0 ); - const __m256 d1v = _mm256_set1_ps( d1 ); + __m256i qx = bytes_from_nibbles_32(x[i].qs); - // Compute combined scales - const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + qx = _mm256_sub_epi8( qx, off ); - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 xy = mul_sum_us8_pairs_float(bx, by); + const __m256 q = mul_sum_i8_pairs_float(qx, qy); - // Accumulate d0*d1*x*y -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d0d1, xy, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); -#endif + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); } - *s = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); - size_t vl = __riscv_vsetvl_e8m1(qk/2); + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - for (int i = 0; i < nb; i++) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + __m128i bx_0 = _mm_and_si128(lowMask, tmp); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); + by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0); - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + // Apply the scale, and accumulate + acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + } - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } + // First round without accumulation + { + _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); - *s = sumf; -#else - // scalar - float sumf = 0.0; + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); - for (int i = 0; i < nb; i++) { - int sumi = 0; + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F); - const int v1 = (x[i].qs[j] >> 4); + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); - } + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } + _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); - *s = sumf; -#endif -} - -void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); - assert(n % qk == 0); - assert(qk == QK5_0); + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); - const block_q5_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - uint32_t qh0; - uint32_t qh1; + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); - uint64_t tmp0[4]; - uint64_t tmp1[4]; + // Apply the scale + acc_0 = _mm_mul_ps( d_0_1, p0 ); + acc_1 = _mm_mul_ps( d_0_1, p1 ); + acc_2 = _mm_mul_ps( d_2_3, p2 ); + acc_3 = _mm_mul_ps( d_2_3, p3 ); + } assert(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q5_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i]; - const block_q8_0 * restrict y1 = &y[i + 1]; + // Main loop + for (int i = 2; i < nb; i+=2) { + _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); - const uint8x16_t m4b = vdupq_n_u8(0x0F); + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - // extract the 5th bit via lookup table ((!b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); - tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); - // 4-bit -> 8-bit - int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); - uint32_t qh; - uint64_t tmp[4]; + // Acummulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } - // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q8_0 * restrict y0 = &y[i]; + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; - const v128_t m4b = wasm_i8x16_splat(0x0F); + size_t vl = __riscv_vsetvl_e8m1(qk/2); - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + for (int i = 0; i < nb; i++) { + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - const v128_t v0 = wasm_v128_load(x0->qs); + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); + // subtract offset + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); - const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - // dot product - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); } - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); + *s = sumf; +#else + // scalar + float sumf = 0.0; - // Main loop for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); - bx = _mm256_or_si256(bx, bxhi); + int sumi = 0; - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F) - 8; + const int v1 = (x[i].qs[j] >> 4) - 8; - const __m256 q = mul_sum_i8_pairs_float(bx, by); + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps(d, q, acc); + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); } - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8((char)0xF0); + *s = sumf; +#endif +} - // Main loop - for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); +void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; - __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_andnot_si128(bxhil, mask); - bxhih = _mm_andnot_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + const block_q4_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; - const __m256 q = mul_sum_i8_pairs_float(bx, by); +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_1 * restrict vx0 = vx; + const block_q4_1 * restrict vx1 = vx + bx; + const block_q8_1 * restrict vy0 = vy; + const block_q8_1 * restrict vy1 = vy + by; - /* Multiply q with scale and accumulate */ - acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); - } + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t summs0 = vdupq_n_f32(0.0f); - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; + for (int i = 0; i < nb; i++) { + const block_q4_1 * restrict b_x0 = &vx0[i]; + const block_q4_1 * restrict b_x1 = &vx1[i]; + const block_q8_1 * restrict b_y0 = &vy0[i]; + const block_q8_1 * restrict b_y1 = &vy1[i]; - uint32_t qh; + float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)}; + summs0 += summs_t; - size_t vl = __riscv_vsetvl_e8m1(qk/2); + const uint8x16_t m4b = vdupq_n_u8(0x0F); - // These tempory registers are for masking and shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); - vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); - vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + // 4-bit -> 8-bit + const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + // mmla into int32x4_t + float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d, + GGML_FP16_TO_FP32(b_x0->d)*b_y1->d, + GGML_FP16_TO_FP32(b_x1->d)*b_y0->d, + GGML_FP16_TO_FP32(b_x1->d)*b_y1->d}; - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); - vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - // load + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + sumv2 = sumv2 + summs0; + + vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + return; + } +#endif + // TODO: add WASM SIMD +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs = 0; + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q4_1 * restrict x0 = &x[i + 0]; + const block_q4_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i + 0]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // dot product into int32x4_t + const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + // Main loop + for (int i = 0; i < nb; ++i) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d1 = GGML_FP16_TO_FP32(y[i].d); + + summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i qx = bytes_from_nibbles_32(x[i].qs); + const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + + const __m256 xy = mul_sum_us8_pairs_float(qx, qy); + + // Accumulate d0*d1*x*y +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d0d1, xy, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + // load elements vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); @@ -3068,7 +4241,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); } *s = sumf; @@ -3077,45 +4250,41 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri float sumf = 0.0; for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - int sumi = 0; for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + const int v0 = (x[i].qs[j] & 0x0F); + const int v1 = (x[i].qs[j] >> 4); - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); } *s = sumf; #endif } -void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; +void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; const int nb = n / qk; assert(n % qk == 0); - assert(qk == QK5_1); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q5_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - float summs0 = 0.0f; - float summs1 = 0.0f; - uint32_t qh0; uint32_t qh1; @@ -3125,29 +4294,26 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri assert(nb % 2 == 0); // TODO: handle odd nb for (int i = 0; i < nb; i += 2) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q5_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i]; - const block_q8_1 * restrict y1 = &y[i + 1]; + const block_q5_0 * restrict x0 = &x[i]; + const block_q5_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i]; + const block_q8_0 * restrict y1 = &y[i + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); - summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s; - summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s; - - // extract the 5th bit via lookup table ((b) << 4) + // extract the 5th bit via lookup table ((!b) << 4) memcpy(&qh0, x0->qh, sizeof(qh0)); memcpy(&qh1, x1->qh, sizeof(qh1)); - tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_1[(qh0 >> 24) ]; - tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); @@ -3158,16 +4324,16 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri const uint8x16_t v0_1 = vld1q_u8(x1->qs); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - // add high bit - const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); @@ -3175,59 +4341,35 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d); + ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif + ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__wasm_simd128__) v128_t sumv = wasm_f32x4_splat(0.0f); - float summs = 0.0f; - uint32_t qh; uint64_t tmp[4]; // TODO: check if unrolling this is better for (int i = 0; i < nb; ++i) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q8_1 * restrict y0 = &y[i]; - - summs += GGML_FP16_TO_FP32(x0->m) * y0->s; + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; - const v128_t m4b = wasm_i8x16_splat(0x0F); + const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit memcpy(&qh, x0->qh, sizeof(qh)); - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; + tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3238,9 +4380,9 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri const v128_t v0l = wasm_v128_and (v0, m4b); const v128_t v0h = wasm_u8x16_shr(v0, 4); - // add high bit - const v128_t v0lf = wasm_v128_or(v0l, qhl); - const v128_t v0hf = wasm_v128_or(v0h, qhh); + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); + const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); // load y const v128_t v1l = wasm_v128_load(y0->qs); @@ -3258,77 +4400,71 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); // dot product - sumv = wasm_f32x4_add(sumv, - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( + wasm_i32x4_add( wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), wasm_i32x4_dot_i16x8(v0lfh, v1lh)), wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d))); + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); } *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); - float summs = 0.0f; - // Main loop for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i qx = bytes_from_nibbles_32(x[i].qs); __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); - bx = _mm256_or_si256(bx, bxhi); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + qx = _mm256_or_si256(qx, bxhi); - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_us8_pairs_float(bx, by); + const __m256 q = mul_sum_i8_pairs_float(qx, qy); - acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); } - *s = hsum_float_8(acc) + summs; + *s = hsum_float_8(acc); #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8(0x10); - - float summs = 0.0f; + __m128i mask = _mm_set1_epi8((char)0xF0); // Main loop for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); const __m256i bxhi = bytes_from_bits_32(x[i].qh); __m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_and_si128(bxhil, mask); - bxhih = _mm_and_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxhil = _mm_andnot_si128(bxhil, mask); + bxhih = _mm_andnot_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); bxl = _mm_or_si128(bxl, bxhil); bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); + bx_0 = MM256_SET_M128I(bxh, bxl); - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_us8_pairs_float(bx, by); + const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); - acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + /* Multiply q with scale and accumulate */ + acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); } - *s = hsum_float_8(acc) + summs; + *s = hsum_float_8(acc); #elif defined(__riscv_v_intrinsic) float sumf = 0.0; @@ -3336,30 +4472,30 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri size_t vl = __riscv_vsetvl_e8m1(qk/2); - // temporary registers for shift operations + // These temporary registers are for masking and shift operations vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); + + vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); + vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); for (int i = 0; i < nb; i++) { memcpy(&qh, x[i].qh, sizeof(uint32_t)); - // load qh - vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); + // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - - // ((qh >> (j + 12)) ) & 0x10; - vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); + // ((qh & (1u << (j + 16))) >> (j + 12)); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); + vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); // load @@ -3374,8 +4510,11 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); @@ -3387,7 +4526,7 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; } *s = sumf; @@ -3402,540 +4541,499 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri int sumi = 0; for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; } *s = sumf; #endif } -void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; +void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_1; const int nb = n / qk; assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q8_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + float summs0 = 0.0f; + float summs1 = 0.0f; - const int8x16_t x0_0 = vld1q_s8(x0->qs); - const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); - const int8x16_t x1_0 = vld1q_s8(x1->qs); - const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + uint32_t qh0; + uint32_t qh1; - // load y - const int8x16_t y0_0 = vld1q_s8(y0->qs); - const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); - const int8x16_t y1_0 = vld1q_s8(y1->qs); - const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + uint64_t tmp0[4]; + uint64_t tmp1[4]; -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + assert(nb % 2 == 0); // TODO: handle odd nb - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + for (int i = 0; i < nb; i += 2) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q5_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i]; + const block_q8_1 * restrict y1 = &y[i + 1]; -#else - const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); - const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); - const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); - const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - - const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); - const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); - const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); - const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); - - const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); - const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); - const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); - const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } + const uint8x16_t m4b = vdupq_n_u8(0x0F); - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); + summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); + summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + // extract the 5th bit via lookup table ((b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); - const __m256 q = mul_sum_i8_pairs_float(bx, by); + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; - // Multiply q with scale and accumulate -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d, q, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); -#endif - } + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - size_t vl = __riscv_vsetvl_e8m1(qk); + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - for (int i = 0; i < nb; i++) { - // load elements - vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); - vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + // add high bit + const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); - int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = sumf; -#else - // scalar - float sumf = 0.0; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); - for (int i = 0; i < nb; i++) { - int sumi = 0; + float summs = 0.0f; - for (int j = 0; j < qk; j++) { - sumi += x[i].qs[j]*y[i].qs[j]; - } + uint32_t qh; + uint64_t tmp[4]; - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); - } + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q8_1 * restrict y0 = &y[i]; - *s = sumf; -#endif -} + summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); -#if QK_K == 256 -void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const v128_t m4b = wasm_i8x16_splat(0x0F); - const block_q2_K * restrict x = vx; - const block_q8_K * restrict y = vy; + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); - const int nb = n / QK_K; + tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh >> 24) ]; -#ifdef __ARM_NEON + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); - const uint8x16_t m3 = vdupq_n_u8(0x3); - const uint8x16_t m4 = vdupq_n_u8(0xF); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif + const v128_t v0 = wasm_v128_load(x0->qs); - int8x16x2_t q2bytes; - uint8_t aux[16]; + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); - float sum = 0; + // add high bit + const v128_t v0lf = wasm_v128_or(v0l, qhl); + const v128_t v0hf = wasm_v128_or(v0h, qhh); - for (int i = 0; i < nb; ++i) { + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint8_t * restrict sc = x[i].scales; + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - const uint8x16_t mins_and_scales = vld1q_u8(sc); - const uint8x16_t scales = vandq_u8(mins_and_scales, m4); - vst1q_u8(aux, scales); + // dot product + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + } - const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; - const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), - vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); - const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), - vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); - sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); - int isum = 0; - int is = 0; + float summs = 0.0f; -// We use this macro instead of a function call because for some reason -// the code runs 2-3% slower, even if the function is declared inline -#if defined(__ARM_FEATURE_DOTPROD) -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; -#else -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - {\ - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ - isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ - } -#endif + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); -#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ - q8bytes = vld1q_s8_x2(q8); q8 += 32;\ - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ - MULTIPLY_ACCUM_WITH_SCALE((index)); + summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + __m256i qx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + qx = _mm256_or_si256(qx, bxhi); - for (int j = 0; j < QK_K/128; ++j) { + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d)); + const __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); - const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; + const __m256 q = mul_sum_us8_pairs_float(qx, qy); - int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); - MULTIPLY_ACCUM_WITH_SCALE(0); + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + *s = hsum_float_8(acc) + summs; +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8(0x10); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + float summs = 0.0f; - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - is += 8; - } - sum += d * isum; + summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); - } + __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_and_si128(bxhil, mask); + bxhih = _mm_and_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx_0 = MM256_SET_M128I(bxh, bxl); - *s = sum; + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d)); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); -#elif defined __AVX2__ + const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); - const __m256i m3 = _mm256_set1_epi8(3); - const __m128i m4 = _mm_set1_epi8(0xF); + acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + } - __m256 acc = _mm256_setzero_ps(); + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; - for (int i = 0; i < nb; ++i) { + uint32_t qh; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + size_t vl = __riscv_vsetvl_e8m1(qk/2); - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + // temporary registers for shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m256i mins = _mm256_cvtepi8_epi16(mins8); - const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + // load qh + vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); - const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + // ((qh >> (j + 0)) << 4) & 0x10; + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - __m256i sumi = _mm256_setzero_si256(); + // ((qh >> (j + 12)) ) & 0x10; + vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); - for (int j = 0; j < QK_K/128; ++j) { + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - const __m256i q2_0 = _mm256_and_si256(q2bits, m3); - const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); - const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); - const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); - __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - p0 = _mm256_add_epi32(p0, p1); - p2 = _mm256_add_epi32(p2, p3); + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); - } + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - } + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - *s = hsum_float_8(acc); + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); -#elif defined __AVX__ + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); + } - const __m128i m3 = _mm_set1_epi8(0x3); - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); + *s = sumf; +#else + // scalar + float sumf = 0.0; - __m256 acc = _mm256_setzero_ps(); + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - for (int i = 0; i < nb; ++i) { + int sumi = 0; - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; - // load mins and scales from block_q2_K.scales[QK_K/16] - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); - const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } - // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 - const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); - const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); + } - // sumf += -dmin * summs in 32bits*8 - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); + *s = sumf; +#endif +} - const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); - const __m128i scales[2] = { scales_0, scales_1 }; +void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - for (int j = 0; j < QK_K/128; ++j) { + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; - // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q8_0 * restrict vx0 = vx; + const block_q8_0 * restrict vx1 = vx + bx; + const block_q8_0 * restrict vy0 = vy; + const block_q8_0 * restrict vy1 = vy + by; - // load 2bits*16*8 from block_q2_K.qs[QK_K/4] - __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_0 = _mm_and_si128(q2bits, m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_1 = _mm_and_si128(q2bits, m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + float32x4_t sumv0 = vdupq_n_f32(0.0f); - // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 - __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); - __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); - __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); - __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); - __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); - __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); - __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); - __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + for (int i = 0; i < nb; i++) { + const block_q8_0 * restrict b_x0 = &vx0[i]; + const block_q8_0 * restrict b_y0 = &vy0[i]; - // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 - __m128i shuffle = _mm_set1_epi16(0x0100); - p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); - shuffle = _mm_add_epi16(shuffle, m2); - p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); - shuffle = _mm_add_epi16(shuffle, m2); - p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); - shuffle = _mm_add_epi16(shuffle, m2); - p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); - shuffle = _mm_add_epi16(shuffle, m2); - p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); - shuffle = _mm_add_epi16(shuffle, m2); - p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); - shuffle = _mm_add_epi16(shuffle, m2); - p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); - shuffle = _mm_add_epi16(shuffle, m2); - p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + const block_q8_0 * restrict b_x1 = &vx1[i]; + const block_q8_0 * restrict b_y1 = &vy1[i]; - p0 = _mm_add_epi32(p0, p1); - p2 = _mm_add_epi32(p2, p3); - p4 = _mm_add_epi32(p4, p5); - p6 = _mm_add_epi32(p6, p7); + const int8x16_t x0_l = vld1q_s8(b_x0->qs); + const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16); + const int8x16_t x1_l = vld1q_s8(b_x1->qs); + const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16); - // isum in 32bits*4*2 - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); - } + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - // sumf += dall * isum - dmin * summs in 32bits - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); - } + float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; - *s = hsum_float_8(acc); + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); -#elif defined __riscv_v_intrinsic + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - float sumf = 0; - uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - for (int i = 0; i < nb; ++i) { + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + return; + } +#endif +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); - size_t vl = 16; + assert(nb % 2 == 0); // TODO: handle odd nb - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - - vl = 32; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - - uint8_t is=0; - int isum=0; - - for (int j = 0; j < QK_K/128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + __m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + const __m256 q = mul_sum_i8_pairs_float(qx, qy); - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + // Multiply q with scale and accumulate +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d, q, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); +#endif + } - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + size_t vl = __riscv_vsetvl_e8m1(qk); - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + for (int i = 0; i < nb; i++) { + // load elements + vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl); + vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - q2+=32; q8+=128; is=8; + vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); - } + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); - sumf += dall * isum; + int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); } *s = sumf; - #else + // scalar + float sumf = 0.0; - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + for (int i = 0; i < nb; i++) { + int sumi = 0; - int summs = 0; - for (int j = 0; j < 16; ++j) { - summs += y[i].bsums[j] * (sc[j] >> 4); + for (int j = 0; j < qk; j++) { + sumi += x[i].qs[j]*y[i].qs[j]; } - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - int isum = 0; - int is = 0; - int d; - for (int k = 0; k < QK_K/128; ++k) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - d = sc[is++] & 0xF; - int isuml = 0; - for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - d = sc[is++] & 0xF; - isuml = 0; - for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - shift += 2; - q8 += 32; - } - q2 += 32; - } - sumf += dall * isum - dmin * summs; + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); } + *s = sumf; #endif } -#else - -void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +#if QK_K == 256 +void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); const block_q2_K * restrict x = vx; const block_q8_K * restrict y = vy; @@ -3943,66 +5041,69 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const int nb = n / QK_K; #ifdef __ARM_NEON - const uint8x16_t m3 = vdupq_n_u8(0x3); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif + const uint8x16_t m4 = vdupq_n_u8(0xF); - int8x16x4_t q2bytes; + const int32x4_t vzero = vdupq_n_s32(0); - uint32_t aux32[2]; - const uint8_t * scales = (const uint8_t *)aux32; + ggml_int8x16x2_t q2bytes; + uint8_t aux[16]; float sum = 0; for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * (float)x[i].d; - const float dmin = -y[i].d * (float)x[i].dmin; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + const uint8_t * restrict sc = x[i].scales; - aux32[0] = sc[0] & 0x0f0f0f0f; - aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); - sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); - int isum1 = 0, isum2 = 0; + int isum = 0; + int is = 0; - const uint8x16_t q2bits = vld1q_u8(q2); +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); - q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); - q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); + for (int j = 0; j < QK_K/128; ++j) { + const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32; -#if defined(__ARM_FEATURE_DOTPROD) - isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; - isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; - isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; - isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; -#else - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum1 += vaddvq_s16(p1) * scales[0]; - isum2 += vaddvq_s16(p2) * scales[1]; - - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum1 += vaddvq_s16(p3) * scales[2]; - isum2 += vaddvq_s16(p4) * scales[3]; -#endif - sum += d * (isum1 + isum2); + ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + sum += d * isum; } *s = sum; @@ -4010,17 +5111,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri #elif defined __AVX2__ const __m256i m3 = _mm256_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); __m256 acc = _mm256_setzero_ps(); - uint32_t ud, um; - const uint8_t * restrict db = (const uint8_t *)&ud; - const uint8_t * restrict mb = (const uint8_t *)&um; - - float summs = 0; - - // TODO: optimize this - for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); @@ -4029,145 +5123,242 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - ud = (sc[0] >> 0) & 0x0f0f0f0f; - um = (sc[0] >> 4) & 0x0f0f0f0f; + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m256i mins = _mm256_cvtepi8_epi16(mins8); + const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); - int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; - summs += dmin * smin; + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); - const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); - const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3); - const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + __m256i sumi = _mm256_setzero_si256(); - const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + for (int j = 0; j < QK_K/128; ++j) { - const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); - const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); - const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); - const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); + const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); - } + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - *s = hsum_float_8(acc) + summs; + const __m256i q2_0 = _mm256_and_si256(q2bits, m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); -#elif defined __AVX__ + __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); + __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); - const __m128i m3 = _mm_set1_epi8(3); + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); - __m256 acc = _mm256_setzero_ps(); + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); - uint32_t ud, um; - const uint8_t * restrict db = (const uint8_t *)&ud; - const uint8_t * restrict mb = (const uint8_t *)&um; + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } - float summs = 0; + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - // TODO: optimize this + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(0x3); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - ud = (sc[0] >> 0) & 0x0f0f0f0f; - um = (sc[0] >> 4) & 0x0f0f0f0f; + // load mins and scales from block_q2_K.scales[QK_K/16] + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); + const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); - int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; - summs += dmin * smin; + // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 + const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); + const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); - const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); - const __m128i q2_0 = _mm_and_si128(q2bits, m3); - const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + // sumf += -dmin * summs in 32bits*8 + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); + const __m128i scales[2] = { scales_0, scales_1 }; - const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); - const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); - const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); - const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); - const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); + for (int j = 0; j < QK_K/128; ++j) { - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); + // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // load 2bits*16*8 from block_q2_K.qs[QK_K/4] + __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_1 = _mm_and_si128(q2bits, m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 + __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); + __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); + __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); + __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); + __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); + __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); + __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); + __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + + // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 + __m128i shuffle = _mm_set1_epi16(0x0100); + p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); + shuffle = _mm_add_epi16(shuffle, m2); + p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); + shuffle = _mm_add_epi16(shuffle, m2); + p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); + shuffle = _mm_add_epi16(shuffle, m2); + p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); + shuffle = _mm_add_epi16(shuffle, m2); + p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); + shuffle = _mm_add_epi16(shuffle, m2); + p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); + shuffle = _mm_add_epi16(shuffle, m2); + p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); + shuffle = _mm_add_epi16(shuffle, m2); + p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + + p0 = _mm_add_epi32(p0, p1); + p2 = _mm_add_epi32(p2, p3); + p4 = _mm_add_epi32(p4, p5); + p6 = _mm_add_epi32(p6, p7); + + // isum in 32bits*4*2 + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); + } + + // sumf += dall * isum - dmin * summs in 32bits + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); } - *s = hsum_float_8(acc) + summs; + *s = hsum_float_8(acc); #elif defined __riscv_v_intrinsic - uint32_t aux32[2]; - const uint8_t * scales = (const uint8_t *)aux32; - float sumf = 0; + uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * (float)x[i].d; - const float dmin = -y[i].d * (float)x[i].dmin; + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - aux32[0] = sc[0] & 0x0f0f0f0f; - aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + size_t vl = 16; - sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - int isum1 = 0; - int isum2 = 0; + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - size_t vl = 16; + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - // load Q2 - vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl); + vl = 32; - vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl)); - vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl)); - vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl)); - vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl)); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - // load Q8, and take product with Q2 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + uint8_t is=0; + int isum=0; - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl); - vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl); - vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl); + for (int j = 0; j < QK_K/128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0]; - isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1]; - isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2]; - isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3]; + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); - sumf += d * (isum1 + isum2); + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2+=32; q8+=128; is=8; + + } + + sumf += dall * isum; } @@ -4177,8 +5368,6 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri float sumf = 0; - int isum[4]; - for (int i = 0; i < nb; ++i) { const uint8_t * q2 = x[i].qs; @@ -4186,156 +5375,95 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const uint8_t * sc = x[i].scales; int summs = 0; - for (int j = 0; j < QK_K/16; ++j) { + for (int j = 0; j < 16; ++j) { summs += y[i].bsums[j] * (sc[j] >> 4); } const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - isum[0] = isum[1] = isum[2] = isum[3] = 0; - for (int l = 0; l < 16; ++l) { - isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); - isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); - isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); - isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); - } - for (int l = 0; l < 4; ++l) { - isum[l] *= (sc[l] & 0xF); + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; } - sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; + sumf += dall * isum - dmin * summs; } *s = sumf; #endif } -#endif -#if QK_K == 256 -void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); +#else - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; +void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q3_K * restrict x = vx; + const block_q2_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; #ifdef __ARM_NEON + const uint8x16_t m3 = vdupq_n_u8(0x3); - uint32_t aux[3]; - uint32_t utmp[4]; - - const uint8x16_t m3b = vdupq_n_u8(0x3); -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t vzero = vdupq_n_s32(0); -#endif + const int32x4_t vzero = vdupq_n_s32(0); - const uint8x16_t m0 = vdupq_n_u8(1); - const uint8x16_t m1 = vshlq_n_u8(m0, 1); - const uint8x16_t m2 = vshlq_n_u8(m0, 2); - const uint8x16_t m3 = vshlq_n_u8(m0, 3); - const int8_t m32 = 32; + ggml_int8x16x4_t q2bytes; - int8x16x4_t q3bytes; + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; float sum = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; + const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - uint8x16x2_t qhbits = vld1q_u8_x2(qh); - - uint8x16x4_t q3h; - - int32_t isum = 0; - - // Set up scales - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= m32; - - for (int j = 0; j < QK_K/128; ++j) { - - const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; - const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; - const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; - - q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); - q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); - q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); - q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; -#else - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif - scale += 4; + sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); - q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); - q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); - q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); - q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + int isum1 = 0, isum2 = 0; - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + const uint8x16_t q2bits = vld1q_u8(q2); -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; -#else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif - scale += 4; + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); - if (j == 0) { - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); - } + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); + q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); + q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); - } - sum += d * isum; + isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; + isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; + isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; + isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; + sum += d * (isum1 + isum2); } *s = sum; @@ -4343,261 +5471,256 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri #elif defined __AVX2__ const __m256i m3 = _mm256_set1_epi8(3); - const __m256i mone = _mm256_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); __m256 acc = _mm256_setzero_ps(); - uint32_t aux[3]; + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - // Set up scales - memcpy(aux, x[i].scales, 12); - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; - // high bit - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; - // integer accumulator - __m256i sumi = _mm256_setzero_si256(); + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3); + const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); - int bit = 0; - int is = 0; + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits - const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - // prepare low and high bits - const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); - const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; + const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); + const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); + const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); + const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); - const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); + } - const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); - const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; + *s = hsum_float_8(acc) + summs; - const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); - const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; +#elif defined __AVX__ - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m128i m3 = _mm_set1_epi8(3); - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + __m256 acc = _mm256_setzero_ps(); - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + float summs = 0; - // multiply with scales - p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + // TODO: optimize this - // accumulate - p16_0 = _mm256_add_epi32(p16_0, p16_1); - p16_2 = _mm256_add_epi32(p16_2, p16_3); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + for (int i = 0; i < nb; ++i) { - } + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; - } + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; - *s = hsum_float_8(acc); + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; -#elif defined __AVX__ + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - const __m128i m3 = _mm_set1_epi8(3); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - __m256 acc = _mm256_setzero_ps(); + const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); - const uint32_t *aux; + const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); + const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); + const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); + const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; + + float sumf = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - // Set up scales - aux = (const uint32_t *)x[i].scales; - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); - const __m128i scales[2] = { scales_0, scales_1 }; + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; - // high bit *128*2 from block_q3_K.hmask[QK_K/8] - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); - // integer accumulator - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + int isum1 = 0; + int isum2 = 0; - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] - const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + size_t vl = 16; - // prepare low and high bits - const int bit = j << 2; + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); - const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); - const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); - const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + // load Q2 + vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl); - const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); - const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl)); + vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl)); - const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); - const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); - const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + // load Q8, and take product with Q2 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); - const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); - const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl); + vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl); + vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl); - // load Q8 quants from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1]; + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3]; - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + sumf += d * (isum1 + isum2); - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + } - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); + *s = sumf; - // multiply with scales - __m128i shuffle = _mm_set1_epi16(0x0100); - p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); - shuffle = _mm_add_epi16(shuffle, m2); - p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); - shuffle = _mm_add_epi16(shuffle, m2); - p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); - shuffle = _mm_add_epi16(shuffle, m2); - p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); - shuffle = _mm_add_epi16(shuffle, m2); - p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); - shuffle = _mm_add_epi16(shuffle, m2); - p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); - shuffle = _mm_add_epi16(shuffle, m2); - p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); - shuffle = _mm_add_epi16(shuffle, m2); - p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); +#else - // accumulate - p16_0 = _mm_add_epi32(p16_0, p16_1); - p16_2 = _mm_add_epi32(p16_2, p16_3); - p16_4 = _mm_add_epi32(p16_4, p16_5); - p16_6 = _mm_add_epi32(p16_6, p16_7); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + float sumf = 0; + + int isum[QK_K/16]; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + int summs = 0; + for (int j = 0; j < QK_K/16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); } - // multiply with block scale and accumulate - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + memset(isum, 0, (QK_K/16)*sizeof(int)); + for (int l = 0; l < 16; ++l) { + isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); + isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); + isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); + isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); + } + for (int l = 0; l < QK_K/16; ++l) { + isum[l] *= (sc[l] & 0xF); + } + sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; } + *s = sumf; +#endif +} +#endif - *s = hsum_float_8(acc); +#if QK_K == 256 +void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); -#elif defined __riscv_v_intrinsic + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON uint32_t aux[3]; uint32_t utmp[4]; - float sumf = 0; + const uint8x16_t m3b = vdupq_n_u8(0x3); + const int32x4_t vzero = vdupq_n_s32(0); + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + ggml_int8x16x4_t q3bytes; + + float sum = 0; + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const uint8_t * restrict q3 = x[i].qs; const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; + const int8_t * restrict q8 = y[i].qs; + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales memcpy(aux, x[i].scales, 12); utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); @@ -4605,90 +5728,409 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; - - - size_t vl = 32; - uint8_t m = 1; + for (int j = 0; j < 16; ++j) scale[j] -= m32; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + for (int j = 0; j < QK_K/128; ++j) { - int sum_t = 0; + const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32; + const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64; + const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64; - for (int j = 0; j < QK_K; j += 128) { + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); - vl = 32; + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + scale += 4; - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl); - m <<= 1; + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl); - m <<= 1; + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl); - m <<= 1; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl); - m <<= 1; + scale += 4; - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } - vl = 16; + } + sum += d * isum; - // retreive lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + } - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + *s = sum; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); +#elif defined __AVX2__ - q3 += 32; q8 += 128; scale += 8; + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); - } + __m256 acc = _mm256_setzero_ps(); - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + uint32_t aux[3]; - sumf += d*sum_t; + for (int i = 0; i < nb; ++i) { - } + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - *s = sumf; + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; -#else - // scalar version - // This function is written like this so the compiler can manage to vectorize most of it + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + const uint32_t *aux; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + aux = (const uint32_t *)x[i].scales; + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); + const __m128i scales[2] = { scales_0, scales_1 }; + + // high bit *128*2 from block_q3_K.hmask[QK_K/8] + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + + // integer accumulator + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] + const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + + // prepare low and high bits + const int bit = j << 2; + + const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); + const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); + const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); + const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); + const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + + const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); + const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); + const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + + const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); + const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); + const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + + // load Q8 quants from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + // multiply with scales + __m128i shuffle = _mm_set1_epi16(0x0100); + p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); + shuffle = _mm_add_epi16(shuffle, m2); + p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); + shuffle = _mm_add_epi16(shuffle, m2); + p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); + shuffle = _mm_add_epi16(shuffle, m2); + p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); + shuffle = _mm_add_epi16(shuffle, m2); + p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); + shuffle = _mm_add_epi16(shuffle, m2); + p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); + shuffle = _mm_add_epi16(shuffle, m2); + p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); + shuffle = _mm_add_epi16(shuffle, m2); + p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); + + // accumulate + p16_0 = _mm_add_epi32(p16_0, p16_1); + p16_2 = _mm_add_epi32(p16_2, p16_3); + p16_4 = _mm_add_epi32(p16_4, p16_5); + p16_6 = _mm_add_epi32(p16_6, p16_7); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + + } + + // multiply with block scale and accumulate + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint32_t aux[3]; + uint32_t utmp[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + + *s = sumf; + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the // manually vectorized version above. Every other version I tried would run at least 4 times slower. // The ideal situation would be if we could just write the code once, and the compiler would @@ -4701,202 +6143,1541 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri int32_t aux32[8]; memset(sums, 0, 8*sizeof(float)); - uint32_t auxs[4]; - const int8_t * scales = (const int8_t*)auxs; - + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +#else + +void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + const int32x4_t vzero = vdupq_n_s32(0); + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const uint8x16_t mh = vdupq_n_u8(4); + + ggml_int8x16x4_t q3bytes; + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + ggml_uint8x16x4_t q3h; + + const uint8x8_t hbits = vld1_u8(x[i].hmask); + const uint8x16_t q3bits = vld1q_u8(x[i].qs); + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs); + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); + q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q3h.val[1] = vandq_u8(mh, htmp); + q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); + q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); + + q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); + q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); + q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); + q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; + + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m1 = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); + const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); + + memcpy(&aux64, x[i].hmask, 8); + + const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux); + __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); + q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); + q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits); + const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + // multiply with scales + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + p16_0 = _mm256_add_epi32(p16_0, p16_1); + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m1 = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); + const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); + const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); + const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); + + memcpy(&aux64, x[i].hmask, 8); + + __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); + __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); + __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); + q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); + q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); + q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); + q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m128i q3l_0 = _mm_and_si128(q3bits, m3); + const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_1, p16_1); + p16_2 = _mm_madd_epi16(scale_2, p16_2); + p16_3 = _mm_madd_epi16(scale_3, p16_3); + + p16_0 = _mm_add_epi32(p16_0, p16_2); + p16_1 = _mm_add_epi32(p16_1, p16_3); + __m256i p16 = MM256_SET_M128I(p16_1, p16_0); + + // multiply with block scale and accumulate + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load qh + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8); + vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + + size_t vl = 16; + + // extend and combine both qh_x1 and qh_x2 + vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl); + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl); + vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl); + vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl); + vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3); + + // load Q8 and take product with Q3 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3]; + + sumf += d * isum; + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + int32_t scales[4]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 8; ++l) { + a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); + a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); + a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); + a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); + a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); + a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); + a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); + a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4); + } + + scales[0] = (x[i].scales[0] & 0xF) - 8; + scales[1] = (x[i].scales[0] >> 4) - 8; + scales[2] = (x[i].scales[1] & 0xF) - 8; + scales[3] = (x[i].scales[1] >> 4) - 8; + + memset(aux32, 0, 8*sizeof(int32_t)); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x2_t q4bytes; + ggml_int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + p16l = _mm256_madd_epi16(scale_l, p16l); + + const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + p16h = _mm256_madd_epi16(scale_h, p16h); + const __m256i sumj = _mm256_add_epi32(p16l, p16h); + + sumi = _mm256_add_epi32(sumi, sumj); + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_0 = _mm_and_si128(q4bits, m4); + const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_1 = _mm_and_si128(q4bits, m4); + const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + + const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_0 = _mm_add_epi32(sumi_0, p16l); + const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16l = _mm_maddubs_epi16(q4l_1, q8l_1); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_1 = _mm_add_epi32(sumi_1, p16l); + + const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_0 = _mm_add_epi32(sumi_0, p16h); + const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16h = _mm_maddubs_epi16(q4h_1, q8h_1); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_1 = _mm_add_epi32(sumi_1, p16h); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + size_t vl = 8; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + vl = 32; + + int32_t sum_1 = 0; + int32_t sum_2 = 0; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; + + } + + sumf += d*(sum_1 + sum_2); + + } + + *s = sumf; + +#else + + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#else +void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + + const int32x4_t mzero = vdupq_n_s32(0); + + float sumf = 0; + + ggml_int8x16x2_t q4bytes; + ggml_int8x16x4_t q8bytes; + + float sum_mins = 0.f; + + uint16_t aux16[2]; + const uint8_t * restrict scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); + sum_mins += y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * summi; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); + + q8bytes = ggml_vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); + const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; + + sumf += d * (sumi1 + sumi2); + } + + *s = sumf - sum_mins; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; + const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + + const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); + + const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; + const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); + const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); + const __m128i q4_0 = _mm_and_si128(q4bits_0, m4); + const __m128i q4_1 = _mm_and_si128(q4bits_1, m4); + const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); + const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); + const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc); + + const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); + const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#elif defined __riscv_v_intrinsic + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + + size_t vl = 32; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl); + + sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1); + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl); + + sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2); + + } + + *s = sumf; + +#else + + uint8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; + for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + + for (int j = 0; j < QK_K/32; ++j) { + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + const float dl = d * scales[j]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32; + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + + sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; + } + + sumf += d * sumi - dmin * sumi_mins; + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + +#if QK_K == 256 + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; +#else + // TODO + const float d = 0, dmin = 0; +#endif + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); + __m256i hmask = mone; + + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); + __m128i hmask = mone; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int bit = 0; + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + + __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); + __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); + __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); + __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_0, p16_1); + + q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); + q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); + q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + q5_0 = _mm_add_epi8(q5l_0, q5h_0); + q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); + p16_2 = _mm_madd_epi16(scale_1, p16_2); + p16_3 = _mm_madd_epi16(scale_1, p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + float sums = 0.0; + + size_t vl; + + for (int i = 0; i < nb; ++i) { + + vl = 8; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + vl = 32; + int32_t aux32 = 0; + int is = 0; + + uint8_t m = 1; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q5 and Q8 + vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); + vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + + // compute mask for addition + vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl); + m <<= 1; + + vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl); + m <<= 1; + + vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); + vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + + vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); + vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + q5 += 32; q8 += 64; + + } + + vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); + sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + + } + + *s = sumf+sums; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; const int8_t * restrict q8 = y[i].qs; memset(aux32, 0, 8*sizeof(int32_t)); int8_t * restrict a = aux8; uint8_t m = 1; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); a += 32; m <<= 1; - q3 += 32; + q4 += 32; } - a = aux8; + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; - memcpy(auxs, x[i].scales, 12); - uint32_t tmp = auxs[2]; - auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); - auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - for (int j = 0; j < QK_K/16; ++j) { + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; - #endif - } #else -void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q3_K * restrict x = vx; + const block_q5_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; #ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mh = vdupq_n_u8(16); + const int32x4_t mzero = vdupq_n_s32(0); -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t vzero = vdupq_n_s32(0); -#endif - - const uint8x16_t m3b = vdupq_n_u8(0x3); - const uint8x16_t mh = vdupq_n_u8(4); - - int8x16x4_t q3bytes; - - uint16_t aux16[2]; - int8_t * scales = (int8_t *)aux16; + ggml_int8x16x4_t q5bytes; + ggml_uint8x16x4_t q5h; - float sum = 0; + float sumf = 0; for (int i = 0; i < nb; ++i) { - uint8x16x4_t q3h; - - const uint8x8_t hbits = vld1_u8(x[i].hmask); - const uint8x16_t q3bits = vld1q_u8(x[i].qs); - const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs); - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - for (int j = 0; j < 4; ++j) scales[j] -= 8; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const int8_t * sc = x[i].scales; - int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - const float d = y[i].d * (float)x[i].d; + const uint8x8_t qhbits = vld1_u8(qh); - const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); - q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); - q3h.val[1] = vandq_u8(mh, htmp); - q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); - q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); - q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); - q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); - q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); - q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); + const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); + q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vbicq_u8(mh, htmp); + q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; -#else - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3]; -#endif + q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); + q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); + q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); + q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); - sum += d * isum; + int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); + int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); + int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); + int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); } - *s = sum; + *s = sumf; #elif defined __AVX2__ - const __m256i m3 = _mm256_set1_epi8(3); - const __m256i m1 = _mm256_set1_epi8(1); + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i mone = _mm256_set1_epi8(1); __m256 acc = _mm256_setzero_ps(); - uint64_t aux64; - - uint16_t aux16[2]; - const int8_t * aux8 = (const int8_t *)aux16; - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); - const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); - memcpy(&aux64, x[i].hmask, 8); + const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); + const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); - const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); - __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux); - __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); - q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); - q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128); - // load low 2 bits - const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); - // prepare low and high bits - const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits); - const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - // load Q8 quants const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - - // multiply with scales - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); + const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); + const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); + const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); + const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); - p16_0 = _mm256_add_epi32(p16_0, p16_1); + const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); } @@ -4904,86 +7685,56 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri #elif defined __AVX__ - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m1 = _mm_set1_epi8(1); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mone = _mm_set1_epi8(1); __m256 acc = _mm256_setzero_ps(); - uint64_t aux64; - - uint16_t aux16[2]; - const int8_t * aux8 = (const int8_t *)aux16; - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); - const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); - const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); - const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); - memcpy(&aux64, x[i].hmask, 8); + const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); + const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); + const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]); + const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); - __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); - __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); - __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); - __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); - q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); - q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); - q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); - q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); - // load low 2 bits - const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); + const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4); + const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4); + const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4); - // prepare low and high bits - const __m128i q3l_0 = _mm_and_si128(q3bits, m3); - const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); - const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); + const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4); + const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4); + const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4); + const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4); - // load Q8 quants const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); - - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - - // multiply with scales - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_1, p16_1); - p16_2 = _mm_madd_epi16(scale_2, p16_2); - p16_3 = _mm_madd_epi16(scale_3, p16_3); + const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1))); + const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1))); - p16_0 = _mm_add_epi32(p16_0, p16_2); - p16_1 = _mm_add_epi32(p16_1, p16_3); - __m256i p16 = MM256_SET_M128I(p16_1, p16_0); + const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); + const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); - // multiply with block scale and accumulate - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc); } @@ -4991,72 +7742,69 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri #elif defined __riscv_v_intrinsic - uint16_t aux16[2]; - int8_t * scales = (int8_t *)aux16; - float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - for (int j = 0; j < 4; ++j) scales[j] -= 8; - - int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const int8_t * sc = x[i].scales; - const float d = y[i].d * (float)x[i].d; + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); // load qh - vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8); + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8); vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); size_t vl = 16; - // extend and combine both qh_x1 and qh_x2 + // combine both qh_1 and qh_2 vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); - vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); - vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl); - vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); - vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl); + vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl); + vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl); + vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); - // load Q3 - vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0); + vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1); + vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2); + vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3); - vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl); - vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl); - vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl); - vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl); + // load q5 + vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl); + vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl); - vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0); - vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1); - vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2); - vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3); + vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl)); + vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl)); + vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl)); + vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl)); - // load Q8 and take product with Q3 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl); + vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl); + vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl); + vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl); + + // load Q8 and multiply it with Q5 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); - isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3]; + int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0); + int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1); + int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2); + int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3); - sumf += d * isum; + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); } @@ -5064,371 +7812,429 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri #else - int8_t aux8[QK_K]; - int16_t aux16[8]; + int8_t aux8[QK_K]; + int16_t aux16[16]; float sums [8]; - int32_t aux32[8]; - int32_t scales[4]; memset(sums, 0, 8*sizeof(float)); float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; const int8_t * restrict q8 = y[i].qs; int8_t * restrict a = aux8; - for (int l = 0; l < 8; ++l) { - a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); - a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); - a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); - a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); - a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); - a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); - a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); - a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4); + for (int l = 0; l < 32; ++l) { + a[l+ 0] = q4[l] & 0xF; + a[l+32] = q4[l] >> 4; + } + for (int is = 0; is < 8; ++is) { + uint8_t m = 1 << is; + for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16); } - scales[0] = (x[i].scales[0] & 0xF) - 8; - scales[1] = (x[i].scales[0] >> 4) - 8; - scales[2] = (x[i].scales[1] & 0xF) - 8; - scales[3] = (x[i].scales[1] >> 4) - 8; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const int8_t * restrict sc = x[i].scales; - memset(aux32, 0, 8*sizeof(int32_t)); for (int j = 0; j < QK_K/16; ++j) { - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; + const float dl = d * sc[j]; + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); + q8 += 16; a += 16; } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; - #endif - } #endif + #if QK_K == 256 -void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q4_K * restrict x = vx; + const block_q6_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - #ifdef __ARM_NEON + float sum = 0; - const uint8x16_t m4b = vdupq_n_u8(0xf); -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t mzero = vdupq_n_s32(0); -#endif + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + //const int8x16_t m32s = vdupq_n_s8(32); - int8x16x2_t q4bytes; - int8x16x2_t q8bytes; + const uint8x16_t mone = vdupq_n_u8(3); - float sumf = 0; + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); + const float d_all = GGML_FP16_TO_FP32(x[i].d); - uint32x2_t mins8 = { 0 }; - mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); - mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; + const int8_t * restrict scale = x[i].scales; - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - sumf -= dmin * vaddvq_s32(prod); + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; - const uint8_t * scales = (const uint8_t *)utmp; + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + int32_t isum = 0; - int32_t sumi1 = 0; - int32_t sumi2 = 0; + for (int j = 0; j < QK_K/128; ++j) { - for (int j = 0; j < QK_K/64; ++j) { + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; + ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); -#ifdef __ARM_FEATURE_DOTPROD - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); - const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + scale += 4; - const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - sumi2 += vaddvq_s32(p2) * scales[2*j+1]; -#else - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); -#endif + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; } - - sumf += d * (sumi1 + sumi2); + //sum += isum * d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); } - - *s = sumf; + *s = sum; #elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); - __m256i sumi = _mm256_setzero_si256(); + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - for (int j = 0; j < QK_K/64; ++j) { + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - p16l = _mm256_madd_epi16(scale_l, p16l); + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); - const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - p16h = _mm256_madd_epi16(scale_h, p16h); - const __m256i sumj = _mm256_add_epi32(p16l, p16h); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); - sumi = _mm256_add_epi32(sumi, sumj); } - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); } - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + *s = hsum_float_8(acc); #elif defined __AVX__ const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); __m128i sumi_0 = _mm_setzero_si128(); __m128i sumi_1 = _mm_setzero_si128(); - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { + __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + for (int j = 0; j < QK_K/128; ++j) { - const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); + const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_0 = _mm_and_si128(q4bits, m4); - const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_1 = _mm_and_si128(q4bits, m4); - const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); + const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); + const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); + const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); + const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); - const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_0 = _mm_add_epi32(sumi_0, p16l); - const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16l = _mm_maddubs_epi16(q4l_1, q8l_1); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_1 = _mm_add_epi32(sumi_1, p16l); + const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_0 = _mm_add_epi32(sumi_0, p16h); - const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16h = _mm_maddubs_epi16(q4h_1, q8h_1); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_1 = _mm_add_epi32(sumi_1, p16h); + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); + const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); + const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); + const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); + const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); + + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); + p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); + p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); + p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); } - __m256 vd = _mm256_set1_ps(d); __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); } - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + *s = hsum_float_8(acc); #elif defined __riscv_v_intrinsic - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - float sumf = 0; - for (int i = 0; i < nb; ++i) { - size_t vl = 8; + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + const int8_t * restrict scale = x[i].scales; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + size_t vl; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - vl = 32; + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - int32_t sum_1 = 0; - int32_t sum_2 = 0; + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + vl = 16; - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - q4 += 32; q8 += 64; + q6 += 64; qh += 32; q8 += 128; is=8; } - sumf += d*(sum_1 + sum_2); + sumf += d * sum_t; } @@ -5436,10 +8242,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri #else - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - int8_t aux8[QK_K]; int16_t aux16[8]; float sums [8]; @@ -5448,35 +8250,26 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; memset(aux32, 0, 8*sizeof(int32_t)); int8_t * restrict a = aux8; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - a += 32; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - a += 32; q4 += 32; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; a = aux8; int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; @@ -5486,1797 +8279,4400 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; #endif } + #else -void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + +void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q4_K * restrict x = vx; + const block_q6_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; #ifdef __ARM_NEON + float sum = 0; - const uint8x16_t m4b = vdupq_n_u8(0xf); - -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t mzero = vdupq_n_s32(0); -#endif - - float sumf = 0; - - int8x16x2_t q4bytes; - int8x16x4_t q8bytes; + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int8x16_t m32s = vdupq_n_s8(32); + const int32x4_t vzero = vdupq_n_s32(0); - float sum_mins = 0.f; + const uint8x16_t mone = vdupq_n_u8(3); - uint16_t aux16[2]; - const uint8_t * restrict scales = (const uint8_t *)aux16; + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t * restrict a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); - sum_mins += y[i].d * (float)x[i].d[1] * summi; - - const float d = y[i].d * (float)x[i].d[0]; + const float d_all = GGML_FP16_TO_FP32(x[i].d); - const uint8x16x2_t q4bits = vld1q_u8_x2(q4); + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; -#ifdef __ARM_FEATURE_DOTPROD - q8bytes = vld1q_s8_x4(q8); - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int8_t * restrict scale = x[i].scales; - const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; + int32_t isum = 0; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + uint8x16_t qhbits = vld1q_u8(qh); + ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6); + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); - const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); - const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits, 2); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 4); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); -#else - q8bytes = vld1q_s8_x4(q8); - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0]; + q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); + q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3]))); - int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; -#endif - sumf += d * (sumi1 + sumi2); + sum += isum * d_all * y[i].d; } - - *s = sumf - sum_mins; + *s = sum; #elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); __m256 acc = _mm256_setzero_ps(); - float summs = 0; - - uint16_t aux16[2]; - const uint8_t * scales = (const uint8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; - const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; - const __m256 vd = _mm256_set1_ps(d); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); - - const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - - const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); - - const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); - - } - - *s = hsum_float_8(acc) - summs; - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - uint16_t aux16[2]; - const uint8_t * scales = (const uint8_t *)aux16; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; - const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; - const __m256 vd = _mm256_set1_ps(d); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); - const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); - const __m128i q4_0 = _mm_and_si128(q4bits_0, m4); - const __m128i q4_1 = _mm_and_si128(q4bits_1, m4); - const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); - const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); - - const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); - const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc); - - const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); - const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc); - - } + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); - *s = hsum_float_8(acc) - summs; + __m256i sumi = _mm256_setzero_si256(); -#elif defined __riscv_v_intrinsic + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); - uint16_t s16[2]; - const uint8_t * restrict scales = (const uint8_t *)s16; + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); - float sumf = 0; + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); - for (int i = 0; i < nb; ++i) { + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - const uint16_t * restrict b = (const uint16_t *)x[i].scales; - s16[0] = b[0] & 0x0f0f; - s16[1] = (b[0] >> 4) & 0x0f0f; + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - size_t vl = 32; + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl); + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } - sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1); + *s = hsum_float_8(acc); - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl); +#elif defined __AVX__ - sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); - } + __m256 acc = _mm256_setzero_ps(); - *s = sumf; + for (int i = 0; i < nb; ++i) { -#else + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - uint8_t aux8[QK_K]; - int16_t aux16[16]; - float sums [8]; - memset(sums, 0, 8*sizeof(float)); + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - uint16_t s16[2]; - const uint8_t * restrict scales = (const uint8_t *)s16; + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - uint8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; - for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); - const uint16_t * restrict b = (const uint16_t *)x[i].scales; - s16[0] = b[0] & 0x0f0f; - s16[1] = (b[0] >> 4) & 0x0f0f; + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); - sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); - for (int j = 0; j < QK_K/32; ++j) { - for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; - q8 += 16; a += 16; - for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; - q8 += 16; a += 16; - const float dl = d * scales[j]; - for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); - } - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} -#endif + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); -#if QK_K == 256 -void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - const block_q5_K * restrict x = vx; - const block_q8_K * restrict y = vy; + __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); - const int nb = n / QK_K; + __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); - uint32_t utmp[4]; + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); -#ifdef __ARM_NEON + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc); + } - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mone = vdupq_n_u8(1); - const uint8x16_t mtwo = vdupq_n_u8(2); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t mzero = vdupq_n_s32(0); -#endif + *s = hsum_float_8(acc); - int8x16x4_t q5bytes; +#elif defined __riscv_v_intrinsic float sumf = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + const float d_all = GGML_FP16_TO_FP32(x[i].d); - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + const int8_t * restrict scale = x[i].scales; - const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - int32_t sumi_mins = vaddvq_s32(prod); + int32_t isum = 0; - const uint8_t * scales = (const uint8_t *)utmp; + size_t vl = 16; - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - uint8x16x2_t qhbits = vld1q_u8_x2(qh); + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl); - uint8x16x4_t q5h; + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); - int32_t sumi = 0; + vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - for (int j = 0; j < QK_K/64; ++j) { + vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl); + vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl); + vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl); + vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl); - const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl); + vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl); + vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl); + vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl); - q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); - q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + // load Q8 and take product + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); - q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); - q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); - q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3]; + + sumf += isum * d_all * y[i].d; + + } -#if defined(__ARM_FEATURE_DOTPROD) + *s = sumf; - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; #else - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; - - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; -#endif + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int l = 0; l < 16; ++l) { + a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; } - - sumf += d * sumi - dmin * sumi_mins; - + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; } - + for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; +#endif +} -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m256i mone = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - -#if QK_K == 256 - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; -#else - // TODO - const float d = 0, dmin = 0; #endif - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); +#if defined (__AVX2__) || defined (__ARM_NEON) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); + const block_iq2_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); - __m256i hmask = mone; + const int nb = n / QK_K; - __m256i sumi = _mm256_setzero_si256(); +#if defined(__ARM_NEON) - int bit = 0; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - for (int j = 0; j < QK_K/64; ++j) { + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; - const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.25f * sumf; - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); - hmask = _mm256_slli_epi16(hmask, 1); +#elif defined(__AVX2__) - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); - hmask = _mm256_slli_epi16(hmask, 1); + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; - __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); +#else - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); +void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - } + const block_iq2_xs * restrict x = vx; + const block_q8_K * restrict y = vy; - *s = hsum_float_8(acc) + summs; + const int nb = n / QK_K; -#elif defined __AVX__ +#if defined(__ARM_NEON) - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m2 = _mm_set1_epi8(2); + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - __m256 acc = _mm256_setzero_ps(); + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; - float summs = 0.f; + int32x4x4_t scales32; + float sumf = 0; for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8x8_t scales8 = vld1_u8(x[i].scales); + const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); + const uint8x8_t scales_h = vshr_n_u8(scales8, 4); + uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); + const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); + const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); + scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); + scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); + scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); + scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); + int32x4_t sumi = vdupq_n_s32(0); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); + const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); + sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); + q2 += 8; + } + sumf += d*vaddvq_s32(sumi); + } + *s = 0.125f * sumf; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); +#elif defined(__AVX2__) - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + const __m256i mone = _mm256_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); + const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); + const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); +#if QK_K == 64 + static const uint8_t k_bit_helper[16] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m128i bit_helper = _mm_loadu_si128((const __m128i*)k_bit_helper); + const __m128i m511 = _mm_set1_epi16(511); + typedef union { + __m128i vec_index; + uint16_t index[8]; + } index_t; + + index_t idx; + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const __m128i q2_data = _mm_loadu_si128((const __m128i*)x[i].qs); + idx.vec_index = _mm_and_si128(q2_data, m511); - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); + const __m128i partial_sign_bits = _mm_srli_epi16(q2_data, 9); + const __m128i partial_sign_bits_upper = _mm_srli_epi16(q2_data, 13); + const __m128i partial_sign_bits_for_counting = _mm_xor_si128(partial_sign_bits, partial_sign_bits_upper); - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); - __m128i hmask = mone; + const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); + const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits); + const __m256i full_signs = MM256_SET_M128I(full_sign_bits, full_sign_bits); - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32)); - int bit = 0; + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[idx.index[3]], iq2xs_grid[idx.index[2]], + iq2xs_grid[idx.index[1]], iq2xs_grid[idx.index[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[idx.index[7]], iq2xs_grid[idx.index[6]], + iq2xs_grid[idx.index[5]], iq2xs_grid[idx.index[4]]); - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { + __m256i signs; + signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); + signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); - const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); - __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); - __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); - __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); + const __m256i sc1 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1)); + const __m256i sc2 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1)); - __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_0, p16_1); + const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2)); - q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); - q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); - q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - q5_0 = _mm_add_epi8(q5l_0, q5h_0); - q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sum), accumf); - q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); - p16_2 = _mm_madd_epi16(scale_1, p16_2); - p16_3 = _mm_madd_epi16(scale_1, p16_3); + } - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + *s = 0.125f * hsum_float_8(accumf); +#else - } + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper); + const __m256i m511 = _mm256_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + uint64_t aux64; - } + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; - *s = hsum_float_8(acc) + summs; + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; -#elif defined __riscv_v_intrinsic + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - float sumf = 0; - float sums = 0.0; + const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; + aux_gindex = _mm256_and_si256(q2_data, m511); - size_t vl; + const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9); + const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13); + const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper); - for (int i = 0; i < nb; ++i) { + const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits); - vl = 8; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], + iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], + iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); + const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], + iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); + const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], + iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits); + const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1); + const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l); + const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h); - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + __m256i signs; + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone)); - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone)); - vl = 32; - int32_t aux32 = 0; - int is = 0; + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3); + const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4); - uint8_t m = 1; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); + const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2))); + const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3))); - for (int j = 0; j < QK_K/64; ++j) { - // load Q5 and Q8 - vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); - vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4)); + } - // compute mask for addition - vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl); - m <<= 1; + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl); - m <<= 1; + } - vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); - vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + *s = 0.125f * hsum_float_8(accumf); +#endif - vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); - vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); +#else - vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); - vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} - aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); - q5 += 32; q8 += 64; +void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - } + const block_iq2_s * restrict x = vx; + const block_q8_K * restrict y = vy; - vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); - sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + const int nb = n / QK_K; - } +#if defined(__ARM_NEON) - *s = sumf+sums; + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; -#else + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; + const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + const uint8x16_t m1 = vdupq_n_u8(1); + const int32x4_t vzero = vdupq_n_s32(0); - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); + uint8x16x2_t vs; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} -#else + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; -void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); + q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); + q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); + q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); + qs += 8; - const block_q5_K * restrict x = vx; - const block_q8_K * restrict y = vy; + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); - const int nb = n / QK_K; + q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); + q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); -#ifdef __ARM_NEON + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mh = vdupq_n_u8(16); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t mzero = vdupq_n_s32(0); -#endif + signs += 4; - int8x16x4_t q5bytes; - uint8x16x4_t q5h; + q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); + q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); - float sumf = 0; + const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); - for (int i = 0; i < nb; ++i) { + sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); + sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); + sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); + sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); + } + sumf += d*(sumi1 + sumi2); + } - const float d = y[i].d * (float)x[i].d; - const int8_t * sc = x[i].scales; + *s = 0.125f * sumf; - const uint8_t * restrict q5 = x[i].qs; +#elif defined(__AVX2__) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); const int8_t * restrict q8 = y[i].qs; - const uint8x8_t qhbits = vld1_u8(qh); + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 - const uint8x16x2_t q5bits = vld1q_u8_x2(q5); - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; - const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); - q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); - q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); - q5h.val[2] = vbicq_u8(mh, htmp); - q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); + __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); - q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); - q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); - q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); - q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); + aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); -#if defined(__ARM_FEATURE_DOTPROD) + signs += 4; - int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); - int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); - int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); - int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 - sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); #else - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1); + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3); + *s = 0.125f * sumf; - sumf += d*sumi; #endif - } +} - *s = sumf; +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); -#elif defined __AVX2__ + const block_iq3_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i mone = _mm256_set1_epi8(1); + const int nb = n / QK_K; - __m256 acc = _mm256_setzero_ps(); +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + ggml_int8x16x4_t q3s; + ggml_int8x16x4_t q8b; + float sumf = 0; for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); + const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); + const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); + const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); + const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); + q3 += 16; + q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); + q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); + q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); + q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.5f * sumf; - const uint8_t * restrict q5 = x[i].qs; +#elif defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); +#else - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + uint32_t aux32; - const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); - const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +#endif +} - int64_t aux64; - memcpy(&aux64, x[i].qh, 8); - const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); - const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128); +void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); + const block_iq3_s * restrict x = vx; + const block_q8_K * restrict y = vy; - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const int nb = n / QK_K; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); +#if defined(__ARM_NEON) - const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); - const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); - const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); - const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); + typedef union { + uint16x8_t vec_index; + uint16_t index[8]; + } vec_index_t; - const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; - acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - } + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; - *s = hsum_float_8(acc); + const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); -#elif defined __AVX__ + const int16x8_t hshift = vld1q_s16(k_shift); + const uint16x8_t m256 = vdupq_n_u16(256); + const uint8x16_t m1 = vdupq_n_u8(1); - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mone = _mm_set1_epi8(1); + uint8x16x2_t vs; + ggml_int8x16x4_t q3s; + ggml_int8x16x4_t q8b; + vec_index_t idx; - __m256 acc = _mm256_setzero_ps(); +#if QK_K == 256 + uint32_t scales32[2]; + const uint8_t * scales8 = (const uint8_t *)scales32; +#endif + float sumf = 0; for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; +#if QK_K == 256 + memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; +#endif - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; + idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); + const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); + const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + signs += 4; + + q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); +#if QK_K == 256 + sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; + sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; +#else + sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4)); +#endif + } + sumf += d*(sumi1 + sumi2); + } + *s = sumf; - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); +#elif defined(__AVX2__) - const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); - const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); - const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]); - const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; - int64_t aux64; - memcpy(&aux64, x[i].qh, 8); - const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); - const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; - const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); - const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4); - const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4); - const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4); + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); - const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4); - const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4); - const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4); - const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4); + const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = _mm256_set1_epi32(256); - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + typedef union { + __m256i vec[2]; + uint32_t index[16]; + } index_t; - const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0))); - const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1))); - const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0))); - const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1))); - const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0))); - const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1))); - const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0))); - const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1))); + index_t idx; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16; + idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]); + idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]); + idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask); + idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask); + idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l))); + idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1))); + + // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. + //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = _mm256_set_epi32( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = _mm256_set_epi32( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); + + __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); + + aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint8_t * restrict signs = x[i].signs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; + const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls1; + sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls2; + } + sumf += d * bsum; + } + *s = sumf; +#endif +} - const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); - const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc); +#ifdef __AVX2__ +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = _mm256_sign_epi8(x, x); + const __m256i sy = _mm256_sign_epi8(y, x); + return _mm256_maddubs_epi16(ax, sy); +} +#endif - } +void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - *s = hsum_float_8(acc); + const block_iq1_s * restrict x = vx; + const block_q8_K * restrict y = vy; -#elif defined __riscv_v_intrinsic + const int nb = n / QK_K; - float sumf = 0; +#if defined __ARM_NEON + + ggml_int8x16x4_t q1b; + ggml_int8x16x4_t q8b; + float sumf = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * (float)x[i].d; - const int8_t * sc = x[i].scales; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + int sumi1 = 0, sumi2 = 0, sumi3 = 0; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + for (int ib = 0; ib < QK_K/32; ib += 2) { - // load qh - vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8); - vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); + qs += 8; - size_t vl = 16; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - // combine both qh_1 and qh_2 - vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); - vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); - vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl); - vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl); - vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + sumi1 += vaddvq_s32(p1) * ls1; + sumi2 += vaddvq_s32(p2) * ls2; + sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); - vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0); - vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1); - vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2); - vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3); + } - // load q5 - vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl); - vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl); + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); + } - vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl)); - vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl)); - vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl)); - vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl)); + *s = sumf; - vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl); - vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl); - vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl); - vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl); +#elif defined __AVX2__ - // load Q8 and multiply it with Q5 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; - int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0); - int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1); - int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2); - int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3); + __m256i sumi = _mm256_setzero_si256(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], + iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], + iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + qs += 8; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2)); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } - sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum); + accum1 += d * sumi1; } - *s = sumf; + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; #else - int8_t aux8[QK_K]; - int16_t aux16[16]; - float sums [8]; - memset(sums, 0, 8*sizeof(float)); - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - int8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) { - a[l+ 0] = q4[l] & 0xF; - a[l+32] = q4[l] >> 4; - } - for (int is = 0; is < 8; ++is) { - uint8_t m = 1 << is; - for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16); - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const int8_t * restrict sc = x[i].scales; + for (int i = 0; i < nb; i++) { - for (int j = 0; j < QK_K/16; ++j) { - const float dl = d * sc[j]; - for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); - q8 += 16; a += 16; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi = 0, sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*((qh[ib] >> 12) & 7) + 1; + const int delta = qh[ib] & 0x8000 ? -1 : 1; + int lsum = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } + q8 += 8; + } + sumi += ls * lsum; + sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); + qs += 4; } + + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); } - for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + #endif } -#endif - -#if QK_K == 256 -void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q6_K * restrict x = vx; - const block_q8_K * restrict y = vy; + const block_iq1_m * restrict x = vx; + const block_q8_K * restrict y = vy; const int nb = n / QK_K; -#ifdef __ARM_NEON +#if QK_K != 64 + iq1m_scale_t scale; +#endif - float sum = 0; +#if defined __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xF); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); +#if QK_K == 64 + const int32x4_t mask = vdupq_n_s32(0xf); +#else + const int32x4_t mask = vdupq_n_s32(0x7); #endif - //const int8x16_t m32s = vdupq_n_s8(32); + const int32x4_t mone = vdupq_n_s32(1); + const int32x4_t mzero = vdupq_n_s32(0); - const uint8x16_t mone = vdupq_n_u8(3); + ggml_int8x16x4_t deltas; + deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1)); + deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1)); + deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1)); + deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1)); + + ggml_int8x16x4_t q1b; + ggml_int8x16x4_t q8b; - int8x16x4_t q6bytes; - uint8x16x4_t q6h; + uint32_t aux32; + const uint8_t * aux8 = (const uint8_t *)&aux32; + float sumf = 0; for (int i = 0; i < nb; ++i) { - const float d_all = GGML_FP16_TO_FP32(x[i].d); + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +#if QK_K != 64 + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); +#endif - const int8_t * restrict scale = x[i].scales; + int32x4_t sumi1 = mzero; + int32x4_t sumi2 = mzero; - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int8x16_t scales = vld1q_s8(scale); - const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; + for (int ib = 0; ib < QK_K/32; ib += 2) { - const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), - vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), - vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), - vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); - int32_t isum_mins = vaddvq_s32(prod); + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700))))); - int32_t isum = 0; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - for (int j = 0; j < QK_K/128; ++j) { + const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1])); + const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3])); + const int32x4_t p12 = vpaddq_s32(p1, p2); - uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32; - uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64; - int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that + aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202); - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 2); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1])); + const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); + const int32x4_t p34 = vpaddq_s32(p3, p4); - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); +#if QK_K == 64 + int32x4_t scales_4 = ggml_vld1q_u32(sc[0] >> 0, sc[0] >> 4, sc[0] >> 8, sc[0] >> 12); +#else + int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); +#endif + scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); -#if defined(__ARM_FEATURE_DOTPROD) + sumi1 = vmlaq_s32(sumi1, scales_4, p12); + sumi2 = vmlaq_s32(sumi2, scales_4, p34); - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; + qs += 8; qh += 4; + + } + +#if QK_K == 64 + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); +#else + sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); +#endif + } + + *s = sumf; +#elif defined __AVX2__ + +#if QK_K == 64 + const __m256i mask = _mm256_set1_epi16(0xf); #else + const __m256i mask = _mm256_set1_epi16(0x7); +#endif + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; +#if QK_K != 64 + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); #endif - q8bytes = vld1q_s8_x4(q8); q8 += 64; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m256i q1b_1 = _mm256_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)] + ); + const __m256i q1b_2 = _mm256_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)] + ); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + + const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + + const __m256i dot3 = mul_add_epi8(delta1, q8b_1); + const __m256i dot4 = mul_add_epi8(delta2, q8b_2); +#if QK_K == 64 + __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0)); + __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8)); +#else + __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); + __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); +#endif + scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); + scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); + const __m256i p1 = _mm256_madd_epi16(dot1, scale1); + const __m256i p2 = _mm256_madd_epi16(dot2, scale2); + const __m256i p3 = _mm256_madd_epi16(dot3, scale1); + const __m256i p4 = _mm256_madd_epi16(dot4, scale2); - shifted = vshrq_n_u8(qhbits.val[0], 4); - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[0], 6); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4)); - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + qs += 8; qh += 4; + } -#if defined(__ARM_FEATURE_DOTPROD) +#if QK_K == 64 + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); +#else + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); +#endif + accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); + accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); - //for (int l = 0; l < 4; ++l) { - // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); - // isum += vaddvq_s32(p) * *scale++; - //} #else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; + + int sum1[2], sum2[2], delta[4]; + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + +#if QK_K != 64 + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); #endif + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + delta[0] = qh[0] & 0x08 ? -1 : 1; + delta[1] = qh[0] & 0x80 ? -1 : 1; + delta[2] = qh[1] & 0x08 ? -1 : 1; + delta[3] = qh[1] & 0x80 ? -1 : 1; + sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); + int lsum1 = 0, lsum2 = 0; + for (int j = 0; j < 8; ++j) { + lsum1 += q8[j] * grid[j]; + lsum2 += q8[j]; + } + q8 += 8; + sum1[l/2] += lsum1; + sum2[l/2] += lsum2*delta[l]; + } +#if QK_K == 64 + const int ls1 = 2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1; + const int ls2 = 2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1; +#else + const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; + const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; +#endif + sumi1 += sum1[0] * ls1 + sum1[1] * ls2; + sumi2 += sum2[0] * ls1 + sum2[1] * ls2; + qs += 4; + qh += 2; } - //sum += isum * d_all * y[i].d; - sum += d_all * y[i].d * (isum - 32 * isum_mins); +#if QK_K == 64 + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); +#else + sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); +#endif } - *s = sum; -#elif defined __AVX2__ + *s = sumf; - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); +#endif +} - __m256 acc = _mm256_setzero_ps(); +void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); - for (int i = 0; i < nb; ++i) { + const block_iq4_nl * restrict x = vx; + const block_q8_0 * restrict y = vy; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const int nb = n / QK4_NL; - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + uint8x16x2_t q4bits; + int8x16x4_t q4b; + int8x16x4_t q8b; + int32x4_t prod_1, prod_2; - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + float sumf = 0; - __m256i sumi = _mm256_setzero_si256(); + for (int ib = 0; ib < nb; ib += 2) { - int is = 0; + q4bits.val[0] = vld1q_u8(x[ib+0].qs); + q4bits.val[1] = vld1q_u8(x[ib+1].qs); + q8b.val[0] = vld1q_s8(y[ib+0].qs); + q8b.val[1] = vld1q_s8(y[ib+0].qs + 16); + q8b.val[2] = vld1q_s8(y[ib+1].qs); + q8b.val[3] = vld1q_s8(y[ib+1].qs + 16); - for (int j = 0; j < QK_K/128; ++j) { + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + sumf += + GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) + + GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2); + } - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + *s = sumf; + +#elif defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int ib = 0; ib < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs); + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs); + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); + const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); + accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)), + _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)), + _mm256_cvtepi32_ps(p_2), accum2); + + y += 2; + x += 2; + } + + *s = hsum_float_8(_mm256_add_ps(accum1, accum2)); + +#else + float sumf = 0; + for (int ib = 0; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +#endif +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); +#if QK_K == 64 + ggml_vec_dot_iq4_nl_q8_0(n, s, bs, vx, bx, vy, by, nrc); +#else + + const block_iq4_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + ggml_uint8x16x2_t q4bits; + ggml_int8x16x4_t q4b; + ggml_int8x16x4_t q8b; + int32x4_t prod_1, prod_2; - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + float sumf = 0; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + for (int ibl = 0; ibl < nb; ++ibl) { - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + const int8_t * q8 = y[ibl].qs; + const uint8_t * q4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/64; ++ib) { - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + sumi1 += vaddvq_s32(prod_1) * ls1; + sumi2 += vaddvq_s32(prod_2) * ls2; } - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); } - *s = hsum_float_8(acc); + *s = sumf; -#elif defined __AVX__ +#elif defined __AVX2__ - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m32s = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1)); + const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2)); + sumi1 = _mm256_add_epi32(p_1, sumi1); + sumi2 = _mm256_add_epi32(p_2, sumi2); + } + accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); - __m256 acc = _mm256_setzero_ps(); +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +#endif +} - for (int i = 0; i < nb; ++i) { +// ================================ IQ2 quantization ============================================= + +typedef struct { + uint64_t * grid; + int * map; + uint16_t * neighbours; +} iq2_entry_t; + +static iq2_entry_t iq2_data[4] = { + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, +}; + +static inline int iq2_data_index(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S); + return type == GGML_TYPE_IQ2_XXS ? 0 : + type == GGML_TYPE_IQ2_XS ? 1 : + type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 2 : 3; +} - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); +static inline int iq2_grid_size(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S); + return type == GGML_TYPE_IQ2_XXS ? 256 : + type == GGML_TYPE_IQ2_XS ? 512 : + type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? NGRID_IQ1S : 1024; +} - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +static int iq2_compare_func(const void * left, const void * right) { + const int * l = (const int *)left; + const int * r = (const int *)right; + return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0; +} - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); +void iq2xs_init_impl(enum ggml_type type) { + const int gindex = iq2_data_index(type); + const int grid_size = iq2_grid_size(type); + if (iq2_data[gindex].grid) { + return; + } + static const uint16_t kgrid_2bit_256[256] = { + 0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97, + 100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642, + 1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288, + 1312, 1350, 1385, 1408, 1425, 1545, 1552, 1600, 1668, 1700, 2048, 2053, 2056, 2068, 2088, 2113, + 2116, 2128, 2130, 2184, 2308, 2368, 2562, 2580, 4097, 4100, 4112, 4129, 4160, 4192, 4228, 4240, + 4245, 4352, 4360, 4384, 4432, 4442, 4480, 4644, 4677, 5120, 5128, 5152, 5157, 5193, 5248, 5400, + 5474, 5632, 5654, 6145, 6148, 6160, 6208, 6273, 6400, 6405, 6560, 6737, 8192, 8194, 8202, 8260, + 8289, 8320, 8322, 8489, 8520, 8704, 8706, 9217, 9220, 9232, 9280, 9302, 9472, 9537, 9572, 9872, + 10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516, + 16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561, + 17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488, + 20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545, + 22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874, + 25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856, + 33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142, + 37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268, + }; + static const uint16_t kgrid_2bit_512[512] = { + 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, + 73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257, + 260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340, + 352, 360, 385, 388, 400, 512, 514, 517, 520, 529, 532, 544, 577, 580, 592, 597, + 640, 650, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1088, 1090, 1093, 1096, + 1105, 1108, 1110, 1120, 1153, 1156, 1168, 1280, 1282, 1285, 1288, 1297, 1300, 1312, 1345, 1348, + 1360, 1377, 1408, 1537, 1540, 1552, 1574, 1600, 1602, 1668, 2048, 2050, 2053, 2056, 2058, 2065, + 2068, 2080, 2085, 2113, 2116, 2128, 2136, 2176, 2208, 2218, 2305, 2308, 2320, 2368, 2433, 2441, + 2560, 2592, 2600, 2710, 2720, 4097, 4100, 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4160, + 4162, 4165, 4168, 4177, 4180, 4192, 4202, 4225, 4228, 4240, 4352, 4354, 4357, 4360, 4369, 4372, + 4384, 4417, 4420, 4432, 4480, 4500, 4502, 4609, 4612, 4614, 4624, 4672, 4704, 5120, 5122, 5125, + 5128, 5137, 5140, 5152, 5185, 5188, 5193, 5200, 5220, 5248, 5377, 5380, 5392, 5440, 5632, 5652, + 5705, 6145, 6148, 6160, 6162, 6208, 6228, 6278, 6400, 6405, 6502, 6737, 6825, 8192, 8194, 8197, + 8200, 8202, 8209, 8212, 8224, 8257, 8260, 8272, 8320, 8352, 8449, 8452, 8464, 8512, 8520, 8549, + 8704, 8738, 8832, 8872, 9217, 9220, 9232, 9257, 9280, 9472, 9537, 9554, 9625, 9729, 9754, 9894, + 10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388, + 16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480, + 16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773, + 16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473, + 17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436, + 18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497, + 20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162, + 21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528, + 22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745, + 24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234, + 32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025, + 33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810, + 33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984, + 35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462, + 37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960, + 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048, + 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690, + }; + static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = { + 0, 2, 5, 8, 10, 17, 21, 32, 34, 40, 42, 69, 81, 84, 86, 101, + 128, 130, 136, 138, 149, 160, 162, 168, 170, 260, 261, 273, 276, 278, 281, 282, + 293, 321, 326, 329, 338, 341, 346, 353, 356, 358, 360, 389, 401, 404, 406, 421, + 512, 514, 520, 522, 533, 544, 546, 552, 554, 581, 593, 601, 612, 617, 640, 642, + 648, 650, 657, 661, 665, 672, 674, 680, 682, 1041, 1044, 1046, 1061, 1089, 1097, 1109, + 1114, 1124, 1125, 1169, 1177, 1189, 1281, 1284, 1285, 1286, 1301, 1304, 1306, 1321, 1344, 1349, + 1354, 1360, 1361, 1364, 1365, 1366, 1369, 1376, 1378, 1381, 1384, 1386, 1409, 1425, 1429, 1432, + 1434, 1441, 1444, 1445, 1446, 1449, 1556, 1561, 1601, 1604, 1616, 1618, 1621, 1624, 1632, 1633, + 1638, 1641, 1669, 1681, 1684, 1689, 2048, 2050, 2056, 2058, 2069, 2080, 2082, 2088, 2090, 2117, + 2129, 2134, 2149, 2176, 2178, 2184, 2186, 2197, 2208, 2210, 2216, 2218, 2309, 2321, 2324, 2329, + 2340, 2341, 2369, 2384, 2385, 2389, 2401, 2404, 2409, 2449, 2452, 2454, 2457, 2469, 2560, 2562, + 2568, 2570, 2581, 2592, 2594, 2600, 2602, 2629, 2641, 2649, 2657, 2661, 2688, 2690, 2693, 2696, + 2698, 2709, 2720, 2722, 2728, 2730, 4112, 4113, 4116, 4121, 4132, 4133, 4161, 4164, 4176, 4181, + 4184, 4193, 4196, 4197, 4201, 4241, 4244, 4246, 4257, 4261, 4353, 4356, 4358, 4361, 4368, 4370, + 4373, 4376, 4385, 4388, 4393, 4421, 4426, 4432, 4433, 4434, 4436, 4437, 4438, 4441, 4448, 4453, + 4484, 4498, 4501, 4513, 4516, 4625, 4628, 4630, 4645, 4672, 4678, 4681, 4690, 4693, 4696, 4698, + 4708, 4710, 4741, 4753, 4756, 4758, 4773, 5121, 5126, 5129, 5140, 5141, 5144, 5145, 5153, 5158, + 5185, 5189, 5190, 5192, 5194, 5201, 5204, 5205, 5206, 5209, 5218, 5221, 5224, 5252, 5257, 5264, + 5268, 5269, 5272, 5273, 5274, 5281, 5284, 5285, 5289, 5378, 5381, 5386, 5393, 5396, 5397, 5398, + 5401, 5408, 5410, 5413, 5416, 5418, 5441, 5444, 5445, 5446, 5457, 5458, 5460, 5461, 5462, 5465, + 5466, 5473, 5476, 5477, 5478, 5481, 5504, 5506, 5508, 5509, 5512, 5514, 5520, 5521, 5524, 5525, + 5526, 5529, 5530, 5536, 5538, 5541, 5633, 5636, 5637, 5638, 5653, 5654, 5656, 5658, 5665, 5670, + 5696, 5698, 5700, 5701, 5704, 5706, 5713, 5717, 5718, 5720, 5721, 5729, 5732, 5733, 5736, 5737, + 5738, 5766, 5770, 5778, 5781, 5796, 5801, 6161, 6166, 6181, 6209, 6212, 6214, 6217, 6224, 6229, + 6232, 6234, 6240, 6241, 6244, 6246, 6249, 6277, 6289, 6292, 6309, 6416, 6418, 6421, 6426, 6433, + 6437, 6466, 6468, 6469, 6472, 6481, 6484, 6485, 6486, 6489, 6490, 6496, 6501, 6506, 6537, 6545, + 6546, 6549, 6552, 6561, 6566, 6569, 6665, 6678, 6692, 6694, 6724, 6726, 6729, 6736, 6738, 6741, + 6744, 6753, 6758, 6761, 6789, 6801, 6806, 6810, 8192, 8194, 8200, 8202, 8213, 8224, 8226, 8229, + 8232, 8234, 8261, 8273, 8281, 8289, 8293, 8320, 8322, 8328, 8330, 8341, 8352, 8354, 8357, 8360, + 8362, 8453, 8465, 8468, 8473, 8485, 8514, 8516, 8521, 8533, 8536, 8538, 8545, 8548, 8549, 8550, + 8581, 8592, 8598, 8601, 8613, 8705, 8712, 8714, 8721, 8725, 8736, 8738, 8744, 8746, 8773, 8785, + 8790, 8793, 8805, 8833, 8840, 8842, 8849, 8853, 8864, 8866, 8872, 8874, 9221, 9236, 9238, 9241, + 9253, 9284, 9285, 9286, 9289, 9298, 9301, 9304, 9306, 9318, 9349, 9361, 9364, 9369, 9377, 9381, + 9481, 9493, 9505, 9513, 9536, 9541, 9544, 9553, 9556, 9557, 9561, 9570, 9573, 9576, 9609, 9616, + 9620, 9621, 9624, 9626, 9633, 9636, 9638, 9641, 9733, 9744, 9746, 9753, 9765, 9793, 9801, 9813, + 9824, 9825, 9833, 9860, 9862, 9872, 9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282, + 10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521, + 10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752, + 10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890, + 10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484, + 16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673, + 16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772, + 16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986, + 16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494, + 17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666, + 17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744, + 17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809, + 17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953, + 17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049, + 18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517, + 18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704, + 18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784, + 18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012, + 19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501, + 20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617, + 20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761, + 20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822, + 20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896, + 20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078, + 21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526, + 21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589, + 21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653, + 21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780, + 21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832, + 21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864, + 21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924, + 21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048, + 22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098, + 22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154, + 22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561, + 22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665, + 22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821, + 22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884, + 22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061, + 23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144, + 23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656, + 24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850, + 24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970, + 24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221, + 25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674, + 25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749, + 25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926, + 25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001, + 26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176, + 26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250, + 26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721, + 26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949, + 26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044, + 27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270, + 27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852, + 32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046, + 33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161, + 33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369, + 33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877, + 33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117, + 34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192, + 34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394, + 34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858, + 34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986, + 35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172, + 35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412, + 35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901, + 36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124, + 37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205, + 37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396, + 37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889, + 37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985, + 37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161, + 38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226, + 38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290, + 38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432, + 38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538, + 38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998, + 39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194, + 39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269, + 39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497, + 39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994, + 41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130, + 41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349, + 41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561, + 41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068, + 42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278, + 42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386, + 42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592, + 42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048, + 43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284, + 43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530, + 43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690, + }; + static const uint16_t kgrid_2bit_1024[1024] = { + 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, + 73, 80, 82, 85, 88, 97, 100, 102, 105, 128, 130, 133, 136, 145, 148, 160, + 165, 170, 257, 260, 262, 265, 272, 274, 277, 280, 289, 292, 320, 322, 325, 328, + 337, 340, 342, 345, 352, 357, 360, 385, 388, 400, 402, 405, 417, 420, 512, 514, + 517, 520, 529, 532, 544, 554, 577, 580, 582, 585, 592, 597, 640, 645, 650, 660, + 674, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1062, 1065, 1088, 1090, 1093, + 1096, 1098, 1105, 1108, 1110, 1113, 1120, 1122, 1125, 1153, 1156, 1158, 1161, 1168, 1173, 1176, + 1185, 1188, 1280, 1282, 1285, 1288, 1290, 1297, 1300, 1302, 1305, 1312, 1317, 1320, 1345, 1348, + 1350, 1353, 1360, 1362, 1365, 1368, 1377, 1380, 1408, 1410, 1413, 1416, 1425, 1428, 1440, 1537, + 1540, 1542, 1545, 1552, 1557, 1600, 1605, 1608, 1617, 1620, 1632, 1665, 1668, 1680, 2048, 2050, + 2053, 2056, 2065, 2068, 2070, 2073, 2080, 2085, 2090, 2113, 2116, 2118, 2121, 2128, 2130, 2133, + 2136, 2145, 2148, 2176, 2181, 2196, 2218, 2305, 2308, 2320, 2322, 2325, 2328, 2337, 2368, 2373, + 2376, 2385, 2388, 2400, 2433, 2448, 2560, 2577, 2580, 2594, 2600, 2602, 2640, 2713, 4097, 4100, + 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4134, 4160, 4162, 4165, 4168, 4177, 4180, 4182, + 4185, 4192, 4194, 4197, 4200, 4225, 4228, 4230, 4240, 4245, 4248, 4257, 4260, 4352, 4354, 4357, + 4360, 4362, 4369, 4372, 4374, 4377, 4384, 4386, 4389, 4392, 4417, 4420, 4422, 4425, 4432, 4434, + 4437, 4440, 4449, 4452, 4480, 4482, 4485, 4488, 4497, 4500, 4609, 4612, 4617, 4624, 4629, 4641, + 4644, 4672, 4677, 4689, 4692, 4737, 4740, 4752, 5120, 5122, 5125, 5128, 5137, 5140, 5142, 5145, + 5152, 5157, 5160, 5185, 5188, 5190, 5193, 5200, 5202, 5205, 5208, 5217, 5220, 5248, 5250, 5253, + 5256, 5265, 5268, 5280, 5377, 5380, 5382, 5385, 5392, 5394, 5397, 5400, 5409, 5412, 5440, 5442, + 5445, 5448, 5457, 5460, 5472, 5505, 5508, 5520, 5632, 5637, 5640, 5649, 5652, 5664, 5697, 5700, + 5712, 5760, 5802, 6145, 6148, 6150, 6153, 6160, 6165, 6168, 6177, 6208, 6210, 6213, 6216, 6225, + 6228, 6240, 6273, 6276, 6400, 6402, 6405, 6408, 6417, 6420, 6432, 6465, 6468, 6480, 6505, 6562, + 6660, 6672, 6720, 6742, 8192, 8194, 8197, 8200, 8209, 8212, 8214, 8217, 8224, 8229, 8234, 8257, + 8260, 8272, 8274, 8277, 8292, 8320, 8330, 8340, 8362, 8449, 8452, 8464, 8466, 8469, 8481, 8512, + 8514, 8517, 8529, 8532, 8544, 8577, 8580, 8592, 8704, 8714, 8738, 8744, 8746, 8772, 8784, 8840, + 8842, 8872, 9217, 9220, 9222, 9225, 9232, 9237, 9240, 9249, 9252, 9280, 9282, 9285, 9288, 9297, + 9300, 9312, 9345, 9348, 9360, 9472, 9477, 9480, 9489, 9492, 9504, 9537, 9540, 9552, 9574, 9600, + 9729, 9732, 9744, 9792, 9817, 10240, 10245, 10257, 10260, 10305, 10308, 10320, 10378, 10410, 10497, 10500, + 10512, 10645, 10762, 10786, 10852, 10888, 10890, 16385, 16388, 16390, 16393, 16400, 16402, 16405, 16408, 16410, + 16417, 16420, 16422, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16470, 16473, 16480, 16482, 16485, 16513, + 16516, 16528, 16533, 16536, 16545, 16548, 16640, 16642, 16645, 16648, 16657, 16660, 16662, 16665, 16672, 16674, + 16677, 16705, 16708, 16710, 16713, 16720, 16722, 16725, 16728, 16737, 16740, 16768, 16770, 16773, 16776, 16785, + 16788, 16800, 16897, 16900, 16912, 16914, 16917, 16920, 16932, 16960, 16965, 16968, 16977, 16980, 16992, 17025, + 17028, 17408, 17410, 17413, 17416, 17418, 17425, 17428, 17430, 17433, 17440, 17442, 17445, 17448, 17473, 17476, + 17478, 17481, 17488, 17490, 17493, 17496, 17505, 17508, 17536, 17538, 17541, 17544, 17553, 17556, 17568, 17665, + 17668, 17670, 17673, 17680, 17682, 17685, 17688, 17697, 17700, 17728, 17730, 17733, 17736, 17745, 17748, 17760, + 17770, 17793, 17796, 17808, 17920, 17922, 17925, 17928, 17937, 17940, 17952, 17985, 17988, 18000, 18048, 18085, + 18433, 18436, 18441, 18448, 18450, 18453, 18456, 18465, 18468, 18496, 18498, 18501, 18504, 18513, 18516, 18528, + 18564, 18576, 18688, 18690, 18693, 18696, 18705, 18708, 18720, 18753, 18756, 18768, 18816, 18838, 18945, 18948, + 18960, 19008, 20480, 20482, 20485, 20488, 20497, 20500, 20502, 20505, 20512, 20514, 20517, 20520, 20545, 20548, + 20550, 20553, 20560, 20562, 20565, 20568, 20577, 20580, 20608, 20610, 20613, 20616, 20625, 20628, 20737, 20740, + 20742, 20745, 20752, 20754, 20757, 20760, 20769, 20772, 20800, 20802, 20805, 20808, 20817, 20820, 20832, 20865, + 20868, 20880, 20992, 20997, 21000, 21009, 21012, 21024, 21057, 21060, 21072, 21097, 21120, 21505, 21508, 21510, + 21513, 21520, 21522, 21525, 21528, 21537, 21540, 21568, 21570, 21573, 21576, 21585, 21588, 21600, 21633, 21636, + 21648, 21760, 21762, 21765, 21768, 21777, 21780, 21792, 21825, 21828, 21840, 21888, 22017, 22020, 22032, 22054, + 22080, 22528, 22530, 22533, 22536, 22545, 22548, 22560, 22593, 22596, 22608, 22618, 22656, 22785, 22788, 22800, + 22848, 23040, 23065, 23173, 23208, 24577, 24580, 24582, 24592, 24594, 24597, 24600, 24609, 24612, 24640, 24645, + 24648, 24657, 24660, 24672, 24708, 24720, 24832, 24834, 24837, 24840, 24849, 24852, 24864, 24897, 24900, 24912, + 24960, 24985, 25092, 25104, 25152, 25174, 25249, 25600, 25605, 25608, 25617, 25620, 25632, 25665, 25668, 25680, + 25728, 25857, 25860, 25872, 25920, 25930, 25960, 26002, 26112, 26260, 26625, 26628, 26640, 26725, 26776, 26880, + 26922, 27202, 27297, 32768, 32770, 32773, 32776, 32785, 32788, 32793, 32800, 32805, 32833, 32836, 32848, 32850, + 32853, 32856, 32865, 32896, 32901, 32913, 32916, 33025, 33028, 33033, 33040, 33042, 33045, 33048, 33057, 33060, + 33088, 33090, 33093, 33096, 33105, 33108, 33153, 33156, 33168, 33193, 33280, 33285, 33290, 33297, 33300, 33345, + 33348, 33360, 33793, 33796, 33798, 33801, 33808, 33810, 33813, 33816, 33825, 33856, 33858, 33861, 33864, 33873, + 33876, 33888, 33921, 33924, 33936, 34048, 34050, 34053, 34056, 34065, 34068, 34080, 34113, 34116, 34128, 34176, + 34186, 34305, 34308, 34320, 34345, 34368, 34816, 34821, 34833, 34836, 34881, 34884, 34896, 34978, 35073, 35076, + 35136, 35173, 35362, 35416, 35418, 35458, 35490, 36865, 36868, 36873, 36880, 36882, 36885, 36888, 36900, 36928, + 36930, 36933, 36936, 36945, 36948, 36960, 36993, 36996, 37008, 37120, 37125, 37137, 37140, 37185, 37188, 37200, + 37210, 37377, 37380, 37392, 37440, 37542, 37888, 37890, 37893, 37896, 37905, 37908, 37920, 37953, 37956, 37968, + 38016, 38038, 38145, 38148, 38160, 38208, 38296, 38305, 38400, 38470, 38500, 38913, 38916, 38928, 38950, 38976, + 39081, 39168, 39241, 39250, 39568, 40960, 40965, 40970, 40980, 40994, 41002, 41025, 41028, 41040, 41122, 41130, + 41280, 41317, 41474, 41482, 41506, 41512, 41514, 41602, 41608, 41610, 41640, 41985, 41988, 42000, 42048, 42121, + 42148, 42240, 42265, 42577, 43018, 43048, 43170, 43348, 43398, 43528, 43530, 43552, 43554, 43560, 43656, 43690, + }; - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + const int kmap_size = 43692; + //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2; + const int nwant = type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2; + const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 : + type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : + type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? kgrid_1bit_2048 : kgrid_2bit_1024; + uint64_t * kgrid_q2xs; + int * kmap_q2xs; + uint16_t * kneighbors_q2xs; + + //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size); + uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t)); + for (int k = 0; k < grid_size; ++k) { + int8_t * pos = (int8_t *)(the_grid + k); + for (int i = 0; i < 8; ++i) { + int l = (kgrid[k] >> 2*i) & 0x3; + pos[i] = 2*l + 1; + } + } + kgrid_q2xs = the_grid; + iq2_data[gindex].grid = the_grid; + kmap_q2xs = (int *)malloc(kmap_size*sizeof(int)); + iq2_data[gindex].map = kmap_q2xs; + for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1; + uint64_t aux64; + uint8_t * aux8 = (uint8_t *)&aux64; + for (int i = 0; i < grid_size; ++i) { + aux64 = kgrid_q2xs[i]; + uint16_t index = 0; + for (int k=0; k<8; ++k) { + uint16_t q = (aux8[k] - 1)/2; + index |= (q << 2*k); + } + kmap_q2xs[index] = i; + } + int8_t pos[8]; + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + int num_neighbors = 0, num_not_in_map = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + ++num_not_in_map; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + num_neighbors += n; + } + //printf("%s: %d neighbours in total\n", __func__, num_neighbors); + kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); + iq2_data[gindex].neighbours = kneighbors_q2xs; + int counter = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + kmap_q2xs[i] = -(counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q2xs[counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q2xs[counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); +} - __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - for (int j = 0; j < QK_K/128; ++j) { +void iq2xs_free_impl(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S); + const int gindex = iq2_data_index(type); + if (iq2_data[gindex].grid) { + free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL; + free(iq2_data[gindex].map); iq2_data[gindex].map = NULL; + free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL; + } +} - const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; +static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_d2 = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 8; ++i) { + float q = pg[i]; + float diff = scale*q - xval[i]; + d2 += weight[i]*diff*diff; + } + if (d2 < best_d2) { + best_d2 = d2; grid_index = neighbours[j]; + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); - const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); - const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); - const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); - const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); +static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { - const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS); - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); - const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); - const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); - const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); - const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); + const int kMaxQ = 3; - __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + const int64_t nbl = n/QK_K; - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); + block_iq2_xxs * y = vy; - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); + float scales[QK_K/32]; + float weight[32]; + float xval[32]; + int8_t L[32]; + int8_t Laux[32]; + float waux[32]; + uint8_t block_signs[4]; + uint32_t q2[2*(QK_K/32)]; - p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); - p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); - p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); - p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); - p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); + for (int ibl = 0; ibl < nbl; ++ibl) { - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(q2, 0, QK_K/4); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = sumx2/QK_K; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const float * xb = xbl + 32*ib; + const float * qw = quant_weights + QK_K*ibl + 32*ib; + for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 4; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } + float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight); + float eff_max = scale*kMaxQ; + float best = 0; + for (int is = -6; is <= 6; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/eff_max; + float this_scale = 1/id; + for (int k = 0; k < 4; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + memcpy(L, Laux, 32); + } + } + if (scale > 0) { + float id = 1/scale; + for (int k = 0; k < 4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index); + for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + q2[2*ib+0] |= (grid_index << 8*k); + q2[2*ib+1] |= (block_signs[k] << 7*k); + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + if (!max_scale) { + memset(y[ibl].qs, 0, QK_K/4); + continue; } - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d); + float id = 1/d; + for (int ib = 0; ib < QK_K/32; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + q2[2*ib+1] |= ((uint32_t)l << 28); + } + memcpy(y[ibl].qs, q2, QK_K/4); } +} - *s = hsum_float_8(acc); +static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { -#elif defined __riscv_v_intrinsic + const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS); - float sumf = 0; - for (int i = 0; i < nb; ++i) { + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + const int kMaxQ = 3; - const int8_t * restrict scale = x[i].scales; + const int64_t nbl = n/QK_K; - size_t vl; + block_iq2_xs * y = vy; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + float scales[QK_K/16]; + float weight[16]; + float xval[16]; + int8_t L[16]; + int8_t Laux[16]; + float waux[16]; + bool is_on_grid[2]; + bool is_on_grid_aux[2]; + uint8_t block_signs[2]; + uint16_t q2[2*(QK_K/16)]; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(q2, 0, QK_K/4); + memset(y[ibl].scales, 0, QK_K/32); - int sum_t = 0; - int is = 0; + float max_scale = 0; - for (int j = 0; j < QK_K/128; ++j) { + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = sumx2/QK_K; - vl = 32; + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 2; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + memset(L, 0, 16); + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + is_on_grid[0] = is_on_grid[1] = true; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/max; + float this_scale = 1/id; + for (int k = 0; k < 2; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 16; ++i) L[i] = Laux[i]; + for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 2; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + L[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + scale = -scale; + for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + q2[2*ib+k] = grid_index | (block_signs[k] << 9); + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + if (!max_scale) { + memset(y[ibl].qs, 0, QK_K/4); + continue; + } - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d); + float id = 1/d; + for (int ib = 0; ib < QK_K/16; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + if (ib%2 == 0) y[ibl].scales[ib/2] = l; + else y[ibl].scales[ib/2] |= (l << 4); + } + memcpy(y[ibl].qs, q2, QK_K/4); - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + } +} - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); +size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_xxs); + } + return nrow * nblock * sizeof(block_iq2_xxs); +} - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); +size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_xs); + } + return nrow * nblock * sizeof(block_iq2_xs); +} + +// +// ============================================= 3-bit using D4 lattice +// + +typedef struct { + uint32_t * grid; + int * map; + uint16_t * neighbours; +} iq3_entry_t; + +static iq3_entry_t iq3_data[2] = { + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, +}; + +static inline int iq3_data_index(int grid_size) { + (void)grid_size; + GGML_ASSERT(grid_size == 256 || grid_size == 512); + return grid_size == 256 ? 0 : 1; +} + +static int iq3_compare_func(const void * left, const void * right) { + const int * l = (const int *)left; + const int * r = (const int *)right; + return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0; +} - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); +void iq3xs_init_impl(int grid_size) { + const int gindex = iq3_data_index(grid_size); + if (iq3_data[gindex].grid) { + return; + } + static const uint16_t kgrid_256[256] = { + 0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74, + 81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159, + 169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321, + 327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531, + 536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664, + 698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978, + 992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105, + 1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228, + 1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553, + 1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722, + 1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063, + 2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389, + 2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746, + 2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153, + 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610, + 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992, + }; + static const uint16_t kgrid_512[512] = { + 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34, + 37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77, + 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142, + 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210, + 217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288, + 291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393, + 395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514, + 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576, + 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653, + 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727, + 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833, + 840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977, + 989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047, + 1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103, + 1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199, + 1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296, + 1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415, + 1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561, + 1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648, + 1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761, + 1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877, + 1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068, + 2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177, + 2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269, + 2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520, + 2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634, + 2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805, + 2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083, + 3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276, + 3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591, + 3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729, + 3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032, + }; - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + const int kmap_size = 4096; + const int nwant = grid_size == 256 ? 2 : 3; + const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512; + uint32_t * kgrid_q3xs; + int * kmap_q3xs; + uint16_t * kneighbors_q3xs; + + //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size); + uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t)); + for (int k = 0; k < grid_size; ++k) { + int8_t * pos = (int8_t *)(the_grid + k); + for (int i = 0; i < 4; ++i) { + int l = (kgrid[k] >> 3*i) & 0x7; + pos[i] = 2*l + 1; + } + } + kgrid_q3xs = the_grid; + iq3_data[gindex].grid = the_grid; + kmap_q3xs = (int *)malloc(kmap_size*sizeof(int)); + iq3_data[gindex].map = kmap_q3xs; + for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1; + uint32_t aux32; + uint8_t * aux8 = (uint8_t *)&aux32; + for (int i = 0; i < grid_size; ++i) { + aux32 = kgrid_q3xs[i]; + uint16_t index = 0; + for (int k=0; k<4; ++k) { + uint16_t q = (aux8[k] - 1)/2; + index |= (q << 3*k); + } + kmap_q3xs[index] = i; + } + int8_t pos[4]; + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + int num_neighbors = 0, num_not_in_map = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + num_neighbors += n; + } + //printf("%s: %d neighbours in total\n", __func__, num_neighbors); + kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); + iq3_data[gindex].neighbours = kneighbors_q3xs; + int counter = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + kmap_q3xs[i] = -(counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); +} - vl = 16; +void iq3xs_free_impl(int grid_size) { + GGML_ASSERT(grid_size == 256 || grid_size == 512); + const int gindex = iq3_data_index(grid_size); + if (iq3_data[gindex].grid) { + free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL; + free(iq3_data[gindex].map); iq3_data[gindex].map = NULL; + free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL; + } +} - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); +static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid, + const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_d2 = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 4; ++i) { + float q = pg[i]; + float diff = scale*q - xval[i]; + d2 += weight[i]*diff*diff; + } + if (d2 < best_d2) { + best_d2 = d2; grid_index = neighbours[j]; + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); +static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int64_t n, + const float * restrict quant_weights) { - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + const int gindex = iq3_data_index(grid_size); - q6 += 64; qh += 32; q8 += 128; is=8; + const uint32_t * kgrid_q3xs = iq3_data[gindex].grid; + const int * kmap_q3xs = iq3_data[gindex].map; + const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours; - } + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - sumf += d * sum_t; + const int kMaxQ = 8; + + const int64_t nbl = n/QK_K; + ggml_fp16_t * dh; + uint8_t * qs; + int block_size; + if (grid_size == 256) { + block_iq3_xxs * y = vy; + dh = &y->d; + qs = y->qs; + block_size = sizeof(block_iq3_xxs); + } else { + block_iq3_s * y = vy; + dh = &y->d; + qs = y->qs; + block_size = sizeof(block_iq3_s); } + int quant_size = block_size - sizeof(ggml_fp16_t); - *s = sumf; + float scales[QK_K/32]; + float weight[32]; + float xval[32]; + int8_t L[32]; + int8_t Laux[32]; + float waux[32]; + bool is_on_grid[8]; + bool is_on_grid_aux[8]; + uint8_t block_signs[8]; + uint8_t q3[3*(QK_K/8)+QK_K/32]; + uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4); + uint8_t * qh = q3 + 3*(QK_K/8); + + for (int ibl = 0; ibl < nbl; ++ibl) { + + dh[0] = GGML_FP32_TO_FP16(0.f); + memset(q3, 0, 3*QK_K/8+QK_K/32); -#else + float max_scale = 0; - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + for (int ib = 0; ib < QK_K/32; ++ib) { + const float * xb = xbl + 32*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + 32*ib; + for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i]; } - a += 128; - q4 += 64; - qh += 32; + for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 4; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + for (int is = -15; is <= 15; ++is) { + float id = (2*kMaxQ-1+is*0.2f)/max; + float this_scale = 1/id; + for (int k = 0; k < 8; ++k) { + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 32; ++i) L[i] = Laux[i]; + for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 8; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 3*i); + } + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index); + for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 8; ++k) { + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + if (grid_size == 256) { + q3[8*ib+k] = grid_index; + } else { + q3[8*ib+k] = grid_index & 255; + qh[ib] |= ((grid_index >> 8) << k); + } + + } + scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21); + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); } - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; + + if (!max_scale) { + memset(qs, 0, quant_size); + dh += block_size/sizeof(ggml_fp16_t); + qs += block_size; + continue; } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + + float d = max_scale/31; + dh[0] = GGML_FP32_TO_FP16(d * 1.0125f); // small improvement via this fudge factor + float id = 1/d; + for (int ib = 0; ib < QK_K/32; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + scales_and_signs[ib] |= ((uint32_t)l << 28); + } + memcpy(qs, q3, quant_size); + + dh += block_size/sizeof(ggml_fp16_t); + qs += block_size; + } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif } -#else +size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq3_xxs); + } + return nrow * nblock * sizeof(block_iq3_xxs); +} -void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); +void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_iq3_xxs * restrict y = vy; + quantize_row_iq3_xxs_reference(x, y, k); +} - const block_q6_K * restrict x = vx; - const block_q8_K * restrict y = vy; +void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_row_iq3_xxs_impl(256, x, y, k, NULL); +} - const int nb = n / QK_K; +static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, void * restrict vy, int n, + const float * restrict quant_weights, + float * scales, + float * weight, + float * xval, + int8_t * L, + int8_t * Laux, + float * waux, + bool * is_on_grid, + bool * is_on_grid_aux, + uint8_t * block_signs) { -#ifdef __ARM_NEON + const int gindex = iq3_data_index(512); - float sum = 0; + const uint32_t * kgrid_q3xs = iq3_data[gindex].grid; + const int * kmap_q3xs = iq3_data[gindex].map; + const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours; - const uint8x16_t m4b = vdupq_n_u8(0xF); - const int8x16_t m32s = vdupq_n_s8(32); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - const uint8x16_t mone = vdupq_n_u8(3); + const int kMaxQ = 8; - int8x16x4_t q6bytes; - uint8x16x4_t q6h; + const int64_t nbl = n/QK_K; - for (int i = 0; i < nb; ++i) { + block_iq3_s * y = vy; - const float d_all = (float)x[i].d; + const int bs4 = block_size/4; + const int bs8 = block_size/8; - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + for (int ibl = 0; ibl < nbl; ++ibl) { - const int8_t * restrict scale = x[i].scales; + memset(&y[ibl], 0, sizeof(block_iq3_s)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); - int32_t isum = 0; + uint8_t * qs = y[ibl].qs; + uint8_t * qh = y[ibl].qh; + uint8_t * signs = y[ibl].signs; - uint8x16_t qhbits = vld1q_u8(qh); - uint8x16x2_t q6bits = vld1q_u8_x2(q6); - int8x16x4_t q8bytes = vld1q_s8_x4(q8); + float max_scale = 0; - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits, 2); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits, 4); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits, 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; - q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); - q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < bs8; ++k) { + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; s |= (1 << i); + } + } + block_signs[k] = s; + } + float max = xval[0]; + for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + for (int k = 0; k < bs4; ++k) is_on_grid[k] = false; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.2f)/max; + float this_scale = 1/id; + for (int k = 0; k < bs4; ++k) { + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < block_size; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < block_size; ++i) L[i] = Laux[i]; + for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < bs4; ++k) { + //if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 3*i); + } + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index); + for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < block_size; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k]; + } + for (int k = 0; k < bs4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + qs[k] = grid_index & 255; + qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8)); + } + qs += bs4; + for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k]; + signs += bs8; + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } -#if defined(__ARM_FEATURE_DOTPROD) + if (!max_scale) { + continue; + } - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; -#else + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f); + float id = 1/d; + for (int ib = 0; ib < QK_K/block_size; ib += 2) { + int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); + l1 = MAX(0, MIN(15, l1)); + int l2 = nearest_int(0.5f*(id*scales[ib+1]-1)); + l2 = MAX(0, MIN(15, l2)); + y[ibl].scales[ib/2] = l1 | (l2 << 4); + } - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif + } +} - sum += isum * d_all * y[i].d; +#define IQ3S_BLOCK_SIZE 32 +size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + float scales[QK_K/IQ3S_BLOCK_SIZE]; + float weight[IQ3S_BLOCK_SIZE]; + float xval[IQ3S_BLOCK_SIZE]; + int8_t L[IQ3S_BLOCK_SIZE]; + int8_t Laux[IQ3S_BLOCK_SIZE]; + float waux[IQ3S_BLOCK_SIZE]; + bool is_on_grid[IQ3S_BLOCK_SIZE/4]; + bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4]; + uint8_t block_signs[IQ3S_BLOCK_SIZE/8]; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights, + scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs); + src += n_per_row; + qrow += nblock*sizeof(block_iq3_s); + } + return nrow * nblock * sizeof(block_iq3_s); +} - } - *s = sum; +void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_iq3_s * restrict y = vy; + quantize_row_iq3_s_reference(x, y, k); +} -#elif defined __AVX2__ +void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq3_s(x, y, 1, k, NULL); +} - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); - __m256 acc = _mm256_setzero_ps(); +// =================================== 1.5 bpw =================================================== + +static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_score = 0; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + *scale = sumqx/sumq2; best_score = *scale * sumqx; + grid_index = neighbours[j]; + } + } + if (grid_index < 0) { + for (int i = 0; i < ngrid; ++i) { + const int8_t * grid_i = (const int8_t *)(grid + i); + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < 8; ++j) { + float w = weight[j]; + float q = (grid_i[j] - 3)/2; + sumqx += w*q*xval[j]; + sumq2 += w*q*q; + } + if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + *scale = sumqx/sumq2; best_score = *scale*sumqx; + grid_index = i; + } + } + } + if (grid_index < 0) { + printf("Oops, did not find grid point\n"); + printf("Have %d neighbours\n", num_neighbors); + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2); + } + } + GGML_ASSERT(grid_index >= 0); + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + *scale *= 1.05f; // This is a fudge factor. Don't ask me why it improves the result. + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} - for (int i = 0; i < nb; ++i) { +static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float scale, const float * restrict xg, int8_t * restrict L, int ngrid) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_score = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 8; ++i) { + float q = xg[(pg[i] - 1)/2]; + float w = weight[i]; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = neighbours[j]; + } + } + if (grid_index < 0) { + for (int i = 0; i < ngrid; ++i) { + const int8_t * grid_i = (const int8_t *)(grid + i); + float d2 = 0; + for (int j = 0; j < 8; ++j) { + float w = weight[j]; + float q = xg[(grid_i[j] - 1)/2]; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = i; + } + } + } + if (grid_index < 0) { + printf("Oops, did not find grid point\n"); + printf("Have %d neighbours\n", num_neighbors); + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = xg[(pg[i] - 1)/2]; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2); + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); +static int iq1_sort_helper(const void * left, const void * right) { + const float * l = left; + const float * r = right; + return *l < *r ? -1 : *l > *r ? 1 : 0; +} - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +#define IQ1S_BLOCK_SIZE 32 +#define IQ1M_BLOCK_SIZE 16 +static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, + float * scales, + float * weight, + float * sumx, + float * sumw, + float * pairs, + int8_t * L, + uint16_t * index, + int8_t * shifts) { - const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); - const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); - const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); - const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + const int gindex = iq2_data_index(GGML_TYPE_IQ1_S); - __m256i sumi = _mm256_setzero_si256(); + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); - const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + block_iq1_s * y = vy; - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); + const int64_t nbl = n/QK_K; - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); + const int block_size = IQ1S_BLOCK_SIZE; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + const float x_p[3] = {-1 + IQ1S_DELTA, IQ1S_DELTA, 1 + IQ1S_DELTA}; + const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA}; - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + int * idx = (int *)(pairs + 1); - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + for (int ibl = 0; ibl < nbl; ++ibl) { - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(y[ibl].qs, 0, QK_K/8); + memset(y[ibl].qh, 0, QK_K/16); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + float max = fabsf(xb[0]); + for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); + if (!max) { + scales[ib] = 0; + memset(L, 1, block_size); + continue; + } + // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. + // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two + // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights + // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and + // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale + // for each possible and score for each split. + for (int j = 0; j < block_size; ++j) { + pairs[2*j] = xb[j]; + idx[2*j] = j; + } + qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); + { + sumx[0] = sumw[0] = 0; + for (int j = 0; j < block_size; ++j) { + int i = idx[2*j]; + sumx[j+1] = sumx[j] + weight[i]*xb[i]; + sumw[j+1] = sumw[j] + weight[i]; + } + } + float best_score = 0, scale = max; + int besti1 = -1, besti2 = -1, best_shift = 0; + for (int i1 = 0; i1 <= block_size; ++i1) { + for (int i2 = i1; i2 <= block_size; ++i2) { + float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2]; + float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2]; + if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + scale = sumqx/sumq2; best_score = scale*sumqx; + besti1 = i1; besti2 = i2; best_shift = 1; + } + sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2]; + sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2]; + if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + scale = sumqx/sumq2; best_score = scale*sumqx; + besti1 = i1; besti2 = i2; best_shift = -1; + } + } + } + GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); + for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; + for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; + for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; + if (scale < 0) { + for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; + scale = -scale; best_shift = -best_shift; + } + bool all_on_grid = true; + const float * xx = best_shift == 1 ? x_p : x_m; + for (int k = 0; k < block_size/8; ++k) { + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + index[k] = grid_index; + } + if (!all_on_grid) { + float sumqx = 0, sumq2 = 0; + for (int k = 0; k < block_size/8; ++k) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]; + sumqx += w*q*xb[8*k+j]; + sumq2 += w*q*q; + } + } + if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2; + } + uint16_t h = 0; + for (int k = 0; k < block_size/8; ++k) { + y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255; + h |= (index[k] >> 8) << 3*k; + } + y[ibl].qh[ib] = h; + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + shifts[ib] = best_shift; + max_scale = MAX(max_scale, scale); + } - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + if (!max_scale) { + continue; + } - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + float d = max_scale/15; + y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed. + float id = 1/d; + for (int ib = 0; ib < QK_K/block_size; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(7, l)); + if (shifts[ib] == -1) l |= 8; + y[ibl].qh[ib] |= (l << 12); + } } +} - *s = hsum_float_8(acc); - -#elif defined __AVX__ +size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + float scales[QK_K/IQ1S_BLOCK_SIZE]; + float weight[IQ1S_BLOCK_SIZE]; + int8_t L[IQ1S_BLOCK_SIZE]; + float sumx[IQ1S_BLOCK_SIZE+1]; + float sumw[IQ1S_BLOCK_SIZE+1]; + float pairs[2*IQ1S_BLOCK_SIZE]; + uint16_t index[IQ1S_BLOCK_SIZE/8]; + int8_t shifts[QK_K/IQ1S_BLOCK_SIZE]; + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts); + src += n_per_row; + qrow += nblock*sizeof(block_iq1_s); + } + return nrow * nblock * sizeof(block_iq1_s); +} - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(3); - const __m128i m32s = _mm_set1_epi8(32); +static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, + float * scales, + float * weight, + float * pairs, + int8_t * L, + uint16_t * index, + int8_t * shifts) { - __m256 acc = _mm256_setzero_ps(); + const int gindex = iq2_data_index(GGML_TYPE_IQ1_M); - for (int i = 0; i < nb; ++i) { + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + block_iq1_m * y = vy; - const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); - const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); - const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); - const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + const int64_t nbl = n/QK_K; - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + const int block_size = IQ1M_BLOCK_SIZE; - const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); - const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA}; + const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA}; + const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88}; - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + int * idx = (int *)(pairs + 1); - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); + float sumqx[4], sumq2[4]; - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); + iq1m_scale_t s; + const float * xx; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + for (int ibl = 0; ibl < nbl; ++ibl) { - __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); - __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); - __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); - __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); +#if QK_K == 64 + y[ibl].d = GGML_FP32_TO_FP16(0.f); +#endif + memset(y[ibl].qs, 0, QK_K/8); + memset(y[ibl].qh, 0, QK_K/16); + memset(y[ibl].scales, 0, QK_K/32); - __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + float max_scale = 0; - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; - p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + float max = fabsf(xb[0]); + for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); + if (!max) { + scales[ib] = 0; + memset(L, 1, block_size); + continue; + } + // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. + // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two + // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights + // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and + // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale + // for each possible and score for each split. + for (int j = 0; j < block_size; ++j) { + pairs[2*j] = xb[j]; + idx[2*j] = j; + } + qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); + float best_score = 0, scale = max; + int besti1 = -1, besti2 = -1, best_k = -1; + // 0: +, + + // 1: +, - + // 2: -, + + // 3: -, - + for (int i1 = 0; i1 <= block_size; ++i1) { + for (int i2 = i1; i2 <= block_size; ++i2) { + memset(sumqx, 0, 4*sizeof(float)); + memset(sumq2, 0, 4*sizeof(float)); + for (int j = 0; j < i1; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[0]*xb[i]; + sumqx[1] += weight[i]*x_p[0]*xb[i]; + sumqx[2] += weight[i]*x_m[0]*xb[i]; + sumqx[3] += weight[i]*x_m[0]*xb[i]; + sumq2[0] += weight[i]*x_p[0]*x_p[0]; + sumq2[1] += weight[i]*x_p[0]*x_p[0]; + sumq2[2] += weight[i]*x_m[0]*x_m[0]; + sumq2[3] += weight[i]*x_m[0]*x_m[0]; + } else { + sumqx[0] += weight[i]*x_p[0]*xb[i]; + sumqx[2] += weight[i]*x_p[0]*xb[i]; + sumqx[1] += weight[i]*x_m[0]*xb[i]; + sumqx[3] += weight[i]*x_m[0]*xb[i]; + sumq2[0] += weight[i]*x_p[0]*x_p[0]; + sumq2[2] += weight[i]*x_p[0]*x_p[0]; + sumq2[1] += weight[i]*x_m[0]*x_m[0]; + sumq2[3] += weight[i]*x_m[0]*x_m[0]; + } + } + for (int j = i1; j < i2; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[1]*xb[i]; + sumqx[1] += weight[i]*x_p[1]*xb[i]; + sumqx[2] += weight[i]*x_m[1]*xb[i]; + sumqx[3] += weight[i]*x_m[1]*xb[i]; + sumq2[0] += weight[i]*x_p[1]*x_p[1]; + sumq2[1] += weight[i]*x_p[1]*x_p[1]; + sumq2[2] += weight[i]*x_m[1]*x_m[1]; + sumq2[3] += weight[i]*x_m[1]*x_m[1]; + } else { + sumqx[0] += weight[i]*x_p[1]*xb[i]; + sumqx[2] += weight[i]*x_p[1]*xb[i]; + sumqx[1] += weight[i]*x_m[1]*xb[i]; + sumqx[3] += weight[i]*x_m[1]*xb[i]; + sumq2[0] += weight[i]*x_p[1]*x_p[1]; + sumq2[2] += weight[i]*x_p[1]*x_p[1]; + sumq2[1] += weight[i]*x_m[1]*x_m[1]; + sumq2[3] += weight[i]*x_m[1]*x_m[1]; + } + } + for (int j = i2; j < block_size; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[2]*xb[i]; + sumqx[1] += weight[i]*x_p[2]*xb[i]; + sumqx[2] += weight[i]*x_m[2]*xb[i]; + sumqx[3] += weight[i]*x_m[2]*xb[i]; + sumq2[0] += weight[i]*x_p[2]*x_p[2]; + sumq2[1] += weight[i]*x_p[2]*x_p[2]; + sumq2[2] += weight[i]*x_m[2]*x_m[2]; + sumq2[3] += weight[i]*x_m[2]*x_m[2]; + } else { + sumqx[0] += weight[i]*x_p[2]*xb[i]; + sumqx[2] += weight[i]*x_p[2]*xb[i]; + sumqx[1] += weight[i]*x_m[2]*xb[i]; + sumqx[3] += weight[i]*x_m[2]*xb[i]; + sumq2[0] += weight[i]*x_p[2]*x_p[2]; + sumq2[2] += weight[i]*x_p[2]*x_p[2]; + sumq2[1] += weight[i]*x_m[2]*x_m[2]; + sumq2[3] += weight[i]*x_m[2]*x_m[2]; + } + } + for (int k = 0; k < 4; ++k) { + if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) { + scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; + besti1 = i1; besti2 = i2; best_k = k; + } + } + } + } + GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0); + for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; + for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; + for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; + if (scale < 0) { + for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; + scale = -scale; + best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0; + } + bool all_on_grid = true; + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = best_k < 2 ? x_p : x_m; + else xx = best_k%2 == 0 ? x_p : x_m; + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + index[k] = grid_index; + } + if (!all_on_grid) { + float sumqx_f = 0, sumq2_f = 0; + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = best_k < 2 ? x_p : x_m; + else xx = best_k%2 == 0 ? x_p : x_m; + const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]; + sumqx_f += w*q*xb[8*k+j]; + sumq2_f += w*q*q; + } + } + if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f; + } + y[ibl].qs[2*ib + 0] = index[0] & 255; + y[ibl].qs[2*ib + 1] = index[1] & 255; + y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4); + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + shifts[ib] = best_k; + max_scale = MAX(max_scale, scale); + } - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + if (!max_scale) { + continue; + } - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc); + uint16_t * sc = (uint16_t *)y[ibl].scales; +#if QK_K == 64 + float d = max_scale/31; +#else + float d = max_scale/15; +#endif + float id = 1/d; + float sumqx_f = 0, sumq2_f = 0; + for (int ib = 0; ib < QK_K/block_size; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib+0]-1)); +#if QK_K == 64 + l = MAX(0, MIN(15, l)); + sc[ib/4] |= (l << 4*(ib%4)); +#else + l = MAX(0, MIN(7, l)); + sc[ib/4] |= (l << 3*(ib%4)); +#endif + y[ibl].qh[ib] |= masks[shifts[ib]]; + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m; + else xx = shifts[ib]%2 == 0 ? x_p : x_m; + const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700)); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]*(2*l+1); + sumqx_f += w*q*xb[8*k+j]; + sumq2_f += w*q*q; + } + } + } + if (sumq2_f > 0) d = sumqx_f/sumq2_f; + s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed. +#if QK_K == 64 + y[ibl].d = s.f16; +#else + sc[0] |= ((s.u16 & 0x000f) << 12); + sc[1] |= ((s.u16 & 0x00f0) << 8); + sc[2] |= ((s.u16 & 0x0f00) << 4); + sc[3] |= ((s.u16 & 0xf000) << 0); +#endif } +} - *s = hsum_float_8(acc); +size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + float scales[QK_K/IQ1M_BLOCK_SIZE]; + float weight[IQ1M_BLOCK_SIZE]; + int8_t L[IQ1M_BLOCK_SIZE]; + float pairs[2*IQ1M_BLOCK_SIZE]; + uint16_t index[IQ1M_BLOCK_SIZE/8]; + int8_t shifts[QK_K/IQ1M_BLOCK_SIZE]; + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts); + src += n_per_row; + qrow += nblock*sizeof(block_iq1_m); + } + return nrow * nblock * sizeof(block_iq1_m); +} -#elif defined __riscv_v_intrinsic +// ============================ 4-bit non-linear quants - float sumf = 0; +static inline int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} - for (int i = 0; i < nb; ++i) { +static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * restrict x, + ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l, + float * scales, float * weight, uint8_t * L, + const int8_t * values, + const float * quant_weights, + const int ntry) { + + float sigma2 = 0; + for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j]; + sigma2 *= 2.f/super_block_size; + + memset(q4, 0, super_block_size/2); + dh[0] = GGML_FP32_TO_FP16(0.f); + + float max_scale = 0, amax_scale = 0; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const float * xb = x + ib*block_size; + uint8_t * Lb = L + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < block_size; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? -max/values[0] : max/values[0]; + float id = 1/d; + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_int8(16, values, al); + Lb[j] = l; + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + d = sumqx/sumq2; + float best = d*sumqx; + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + values[0])/max; + sumqx = sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_int8(16, values, al); + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d * sumqx; + } + } + scales[ib] = d; + float abs_d = fabsf(d); + if (abs_d > amax_scale) { + amax_scale = abs_d; max_scale = d; + } + } + + if (super_block_size/block_size > 1) { + int nb = super_block_size/block_size; + memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t)); + float d = -max_scale/32; + dh[0] = GGML_FP32_TO_FP16(d); + float id = d ? 1/d : 0.f; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + int l = nearest_int(id*scales[ib]); + l = MAX(-32, MIN(31, l)); + float dl = d * l; + float idl = dl ? 1/dl : 0.f; + uint8_t * Lb = L + ib*block_size; + const float * xb = x + ib*block_size; + for (int j = 0; j < block_size; ++j) { + Lb[j] = best_index_int8(16, values, idl*xb[j]); + } + l += 32; + uint8_t l_l = l & 0xf; + uint8_t l_h = l >> 4; + if (ib%2 == 0) scales_l[ib/2] = l_l; + else scales_l[ib/2] |= (l_l << 4); + scales_h[ib/8] |= (l_h << 2*(ib%8)); + } + } else { + dh[0] = GGML_FP32_TO_FP16(scales[0]); + if (ntry > 0) { + float id = scales[0] ? 1/scales[0] : 0; + for (int j = 0; j < super_block_size; ++j) { + L[j] = best_index_int8(16, values, id*x[j]); + } + } + } - const float d_all = (float)x[i].d; + for (int i = 0; i < super_block_size/32; ++i) { + for (int j = 0; j < 16; ++j) { + q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4); + } + } +} - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK4_NL == 0); + int64_t nblock = n_per_row/QK4_NL; + char * qrow = (char *)dst; + uint8_t L[QK4_NL]; + float weight[QK4_NL]; + uint16_t unused_h; + uint8_t * unused_l = NULL; + float scale; + for (int64_t row = 0; row < nrow; ++row) { + block_iq4_nl * iq4 = (block_iq4_nl *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL; + quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, + &scale, weight, L, kvalues_iq4nl, qw, 7); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_nl); + } + return nrow * nblock * sizeof(block_iq4_nl); +} - const int8_t * restrict scale = x[i].scales; +void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k) { + GGML_ASSERT(k%QK4_NL == 0); + int64_t nblock = k/QK4_NL; + uint8_t L[QK4_NL]; + float weight[QK4_NL]; + uint16_t unused_h; + uint8_t * unused_l = NULL; + float scale; + block_iq4_nl * iq4 = (block_iq4_nl *)vy; + for (int ibl = 0; ibl < nblock; ++ibl) { + quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, + &scale, weight, L, kvalues_iq4nl, NULL, -1); + } +} - int32_t isum = 0; +void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { + assert(k % QK4_NL == 0); + quantize_row_iq4_nl(x, y, k); +} - size_t vl = 16; +size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { +#if QK_K == 64 + return quantize_iq4_nl(src, dst, nrow, n_per_row, quant_weights); +#else + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + uint8_t L[QK_K]; + float weight[32]; + float scales[QK_K/32]; + for (int64_t row = 0; row < nrow; ++row) { + block_iq4_xs * iq4 = (block_iq4_xs *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL; + quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l, + scales, weight, L, kvalues_iq4nl, qw, 7); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_xs); + } + return nrow * nblock * sizeof(block_iq4_xs); +#endif +} - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); +void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_iq4_xs * restrict y = vy; + quantize_row_iq4_xs_reference(x, y, k); +} - // load Q6 - vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); - vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl); +void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_xs(x, y, 1, k, NULL); +} - // load qh - vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); +// =============================== 2.5625 bpw - vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); +static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { - vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl); - vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl); - vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl); - vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl); + const int gindex = iq2_data_index(GGML_TYPE_IQ2_S); - vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl); - vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl); - vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl); - vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl); + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - // load Q8 and take product - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + const int kMaxQ = 3; - isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3]; + const int64_t nbl = n/QK_K; - sumf += isum * d_all * y[i].d; + block_iq2_s * y = vy; - } + float scales[QK_K/16]; + float weight[16]; + float xval[16]; + int8_t L[16]; + int8_t Laux[16]; + float waux[16]; + bool is_on_grid[2]; + bool is_on_grid_aux[2]; + uint8_t block_signs[2]; - *s = sumf; + for (int ibl = 0; ibl < nbl; ++ibl) { -#else + memset(&y[ibl], 0, sizeof(block_iq2_s)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); + float max_scale = 0; - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int l = 0; l < 16; ++l) { - a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; + } + for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 2; ++k) { + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; s |= (1 << i); + } + } + block_signs[k] = s; + } + float max = xval[0]; + for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + is_on_grid[0] = is_on_grid[1] = true; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/max; + float this_scale = 1/id; + for (int k = 0; k < 2; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 16; ++i) L[i] = Laux[i]; + for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 2; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + L[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + scale = -scale; + for (int k = 0; k < 2; ++k) block_signs[k] = ~block_signs[k]; + } + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + const int i8 = 2*ib + k; + y[ibl].qs[i8] = grid_index & 255; + y[ibl].qh[i8/4] |= ((grid_index >> 8) << 2*(i8%4)); + y[ibl].qs[QK_K/8 + i8] = block_signs[k]; + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); } - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; + + if (!max_scale) { + continue; + } + + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d * 0.9875f); + float id = 1/d; + for (int ib = 0; ib < QK_K/16; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + if (ib%2 == 0) y[ibl].scales[ib/2] = l; + else y[ibl].scales[ib/2] |= (l << 4); } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif } -#endif +size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_s); + } + return nrow * nblock * sizeof(block_iq2_s); +} + +void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq2_s(x, y, 1, k, NULL); +} + +void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_iq2_s * restrict y = vy; + quantize_row_iq2_s_reference(x, y, k); +} diff --git a/bindings/ruby/ext/ggml-quants.h b/bindings/ruby/ext/ggml-quants.h index 70c12c27465..4d436a8f06b 100644 --- a/bindings/ruby/ext/ggml-quants.h +++ b/bindings/ruby/ext/ggml-quants.h @@ -1,224 +1,133 @@ #pragma once -#include "ggml-impl.h" +#define GGML_COMMON_DECL_C +#include "ggml-common.h" -// GGML internal header - -#include -#include - -#define QK4_0 32 -typedef struct { - ggml_fp16_t d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); - -#define QK4_1 32 -typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; -static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); - -#define QK5_0 32 -typedef struct { - ggml_fp16_t d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; -static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); - -#define QK5_1 32 -typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; -static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); - -#define QK8_0 32 -typedef struct { - ggml_fp16_t d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); - -#define QK8_1 32 -typedef struct { - float d; // delta - float s; // d * sum(qs[i]) - int8_t qs[QK8_1]; // quants -} block_q8_1; -static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); - -// -// Super-block quantization structures -// - -// Super-block size -#ifdef GGML_QKK_64 -#define QK_K 64 -#define K_SCALE_SIZE 4 -#else -#define QK_K 256 -#define K_SCALE_SIZE 12 -#endif +#include "ggml.h" -// 2-bit quantization -// weight is represented as x = a * q + b -// 16 blocks of 16 elements each -// Effectively 2.5625 bits per weight -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins -} block_q2_K; -static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); - -// 3-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 3.4375 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[2]; - ggml_fp16_t d; // super-block scale -} block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); -#else -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[12]; // scales, quantized with 6 bits - ggml_fp16_t d; // super-block scale -} block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); -#endif - -// 4-bit quantization -// 8 blocks of 32 elements each -// weight is represented as x = a * q + b -// Effectively 4.5 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - ggml_fp16_t d[2]; // super-block scales/mins - uint8_t scales[2]; // 4-bit block scales/mins - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); -#else -typedef struct { - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); -#endif +// GGML internal header -// 5-bit quantization -// 8 blocks of 32 elements each -// weight is represented as x = a * q + b -// Effectively 5.5 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - ggml_fp16_t d; // super-block scale - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); -#else -typedef struct { - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#ifdef __cplusplus +extern "C" { #endif -// 6-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 6.5625 bits per weight -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - ggml_fp16_t d; // super-block scale -} block_q6_K; -static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); - -// This is only used for intermediate quantization and dot products -typedef struct { - float d; // delta - int8_t qs[QK_K]; // quants - int16_t bsums[QK_K/16]; // sum of quants in groups of 16 -} block_q8_K; -static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); - - // Quantization -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k); -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k); -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k); -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k); -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k); -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k); - -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k); -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k); -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k); -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k); -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); - -void quantize_row_q4_0(const float * restrict x, void * restrict y, int k); -void quantize_row_q4_1(const float * restrict x, void * restrict y, int k); -void quantize_row_q5_0(const float * restrict x, void * restrict y, int k); -void quantize_row_q5_1(const float * restrict x, void * restrict y, int k); -void quantize_row_q8_0(const float * restrict x, void * restrict y, int k); -void quantize_row_q8_1(const float * restrict x, void * restrict y, int k); - -void quantize_row_q2_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q3_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q4_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); + +void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); + +void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); + +void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization -void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k); -void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k); -void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k); -void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k); -void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k); -//void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k); - -void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k); -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k); -void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k); -void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); -void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); -void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); +void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product -void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); - -void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") +size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +void iq2xs_init_impl(enum ggml_type type); +void iq2xs_free_impl(enum ggml_type type); +void iq3xs_init_impl(int grid_size); +void iq3xs_free_impl(int grid_size); + +#ifdef __cplusplus +} +#endif + diff --git a/bindings/ruby/ext/ggml-sycl.h b/bindings/ruby/ext/ggml-sycl.h new file mode 100644 index 00000000000..a9f776fc1dd --- /dev/null +++ b/bindings/ruby/ext/ggml-sycl.h @@ -0,0 +1,49 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_SYCL_MAX_DEVICES 48 +#define GGML_SYCL_NAME "SYCL" + +// backend API +GGML_API ggml_backend_t ggml_backend_sycl_init(int device); + +// devide buffer +GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device); + +// split tensor buffer that splits matrices by rows across multiple devices +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split); + +// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU +GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void); + +GGML_API void ggml_backend_sycl_print_sycl_devices(void); +GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len); +GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description, size_t description_size); +GGML_API GGML_CALL int ggml_backend_sycl_get_device_count(); +GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total); +GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id); + +// TODO: these are temporary +// ref: https://github.com/ggerganov/llama.cpp/pull/6022#issuecomment-1992615670 +GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index); +GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id); +GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode(); + +// SYCL doesn't support registering host memory, keep here for reference +// GGML_API GGML_CALL bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size); +// GGML_API GGML_CALL void ggml_backend_sycl_unregister_host_buffer(void * buffer); +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ggml-vulkan.h b/bindings/ruby/ext/ggml-vulkan.h new file mode 100644 index 00000000000..af661c2d7d5 --- /dev/null +++ b/bindings/ruby/ext/ggml-vulkan.h @@ -0,0 +1,29 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_VK_NAME "Vulkan" +#define GGML_VK_MAX_DEVICES 16 + +GGML_API void ggml_vk_instance_init(void); + +// backend API +GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num); + +GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend); +GGML_API GGML_CALL int ggml_backend_vk_get_device_count(void); +GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size); +GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total); + +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num); +// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void); + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec new file mode 100644 index 00000000000..508a6a94052 --- /dev/null +++ b/bindings/ruby/whispercpp.gemspec @@ -0,0 +1,28 @@ +Gem::Specification.new do |s| + s.name = "whispercpp" + s.authors = ["Georgi Gerganov", "Todd A. Fisher"] + s.version = '1.3.0' + s.date = '2024-05-14' + s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} + s.email = 'todd.fisher@gmail.com' + s.extra_rdoc_files = ['LICENSE', 'README.md'] + + s.files = ["LICENSE", "README.md", "Rakefile", "ext/extconf.rb", "ext/ggml.c", "ext/ruby_whisper.cpp", "ext/whisper.cpp", "ext/dr_wav.h", "ext/ggml.h", "ext/ruby_whisper.h", "ext/whisper.h"] + + #### Load-time details + s.require_paths = ['lib','ext'] + s.summary = %q{Ruby whisper.cpp bindings} + s.test_files = ["tests/test_whisper.rb"] + + s.extensions << 'ext/extconf.rb' + + + #### Documentation and testing. + s.homepage = 'https://github.com/ggerganov/whisper.cpp' + s.rdoc_options = ['--main', '../../README.md'] + + + s.platform = Gem::Platform::RUBY + + s.licenses = ['MIT'] +end From a7dc2aab16822b80a6491b0bd4bbf4900404a8a0 Mon Sep 17 00:00:00 2001 From: Daniel Valdivia <18384552+dvaldivia@users.noreply.github.com> Date: Sat, 25 May 2024 00:46:22 -0700 Subject: [PATCH 091/100] server : fix typo (#2181) A simple comment typo, PR can be dismissed --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c78b3026e18..2efa4c7a020 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -947,7 +947,7 @@ int main(int argc, char ** argv) { "application/json"); } - // reset params to thier defaults + // reset params to their defaults params = default_params; }); svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ From 05042a782db3e3df5e14dd992c72a89337648a53 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 27 May 2024 10:20:25 +0300 Subject: [PATCH 092/100] Revert "whisper : remove extra backend instance (huh?)" (#2182) This reverts commit 4caa64b73ed4c0e71097c865b0f6a9c136b007c6. --- whisper.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 84aec8238cd..7b8c683fca7 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -818,6 +818,8 @@ struct whisper_state { whisper_decoder decoders[WHISPER_MAX_DECODERS]; + ggml_backend_t backend = nullptr; + // ggml-alloc: // - stores meta info about the intermediate tensors into the `meta` buffers // - stores the actual tensor data into the `data` buffers @@ -2261,7 +2263,7 @@ static bool whisper_encode_internal( } if (!whisper_encode_external(wstate)) { - if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { return false; } } else { @@ -2284,7 +2286,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { return false; } } @@ -2300,7 +2302,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { return false; } } @@ -2801,7 +2803,7 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; - if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { return false; } } @@ -3248,6 +3250,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { whisper_state * state = new whisper_state; + state->backend = whisper_backend_init(ctx->params); + if (!state->backend) { + WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__); + whisper_free_state(state); + return nullptr; + } + // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx // in theory, there can be a case where this is not enough, but in practice it should always be enough const int factor = 3; @@ -3684,6 +3693,8 @@ void whisper_free_state(struct whisper_state * state) { ggml_gallocr_free(state->alloc_cross.alloc); ggml_gallocr_free(state->alloc_decode.alloc); + ggml_backend_free(state->backend); + // [EXPERIMENTAL] Token-level timestamps with DTW aheads_masks_free(state->aheads_masks); From c7b6988678779901d02ceba1a8212d2c9908956e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 27 May 2024 10:35:09 +0300 Subject: [PATCH 093/100] release : v1.6.2 --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/ios | 2 +- bindings/javascript/package.json | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 541be8a5d57..82913aa62ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required (VERSION 3.5) # Allow for the creation of solution folders. set_property(GLOBAL PROPERTY USE_FOLDERS ON) -project(whisper.cpp VERSION 1.6.1) +project(whisper.cpp VERSION 1.6.2) set(SOVERSION 1) # Add path to modules diff --git a/README.md b/README.md index 0c34e8dffce..c2c6bc4a2b1 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.6.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) +Stable: [v1.6.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/ios b/bindings/ios index 9a32de38144..a2085436c2e 160000 --- a/bindings/ios +++ b/bindings/ios @@ -1 +1 @@ -Subproject commit 9a32de3814477ad2e598d4a550fcab4b23a9c576 +Subproject commit a2085436c2eb796af90956b62bd64731f5e5b823 diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index da6a9efdc6c..2b3c806f353 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.6.1", + "version": "1.6.2", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From e130b666425879af4b538f2441f741cc70b6f9d7 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Wed, 29 May 2024 19:09:21 +0300 Subject: [PATCH 094/100] whisper: use global cache for sin/cos vals and Hann window (#2194) - also rename Hanning to Hann as it's named after Julius von Hann as per Wikipedia --- whisper.cpp | 97 +++++++++++++++++++++++++++++------------------------ 1 file changed, 54 insertions(+), 43 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 7b8c683fca7..a22da8896bb 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2857,20 +2857,44 @@ static std::string to_timestamp(int64_t t, bool comma = false) { } #define SIN_COS_N_COUNT WHISPER_N_FFT -static float sin_vals[SIN_COS_N_COUNT]; -static float cos_vals[SIN_COS_N_COUNT]; +namespace { +struct whisper_global_cache { + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. + float sin_vals[SIN_COS_N_COUNT]; + float cos_vals[SIN_COS_N_COUNT]; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + float hann_window[WHISPER_N_FFT]; + float hann_window2x[WHISPER_N_FFT * 2]; + + whisper_global_cache() { + fill_sin_cos_table(); +#define FILL_HANN_WINDOW(arr) fill_hann_window(sizeof(arr) / sizeof(arr[0]), true, arr) + FILL_HANN_WINDOW(hann_window); + FILL_HANN_WINDOW(hann_window2x); + } + + void fill_sin_cos_table() { + for (int i = 0; i < SIN_COS_N_COUNT; i++) { + double theta = (2 * M_PI * i) / SIN_COS_N_COUNT; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } -// In FFT, we frequently use sine and cosine operations with the same values. -// We can use precalculated values to speed up the process. -static void fill_sin_cos_table() { - static bool is_filled = false; - if (is_filled) return; - for (int i = 0; i < SIN_COS_N_COUNT; i++) { - double theta = (2*M_PI*i)/SIN_COS_N_COUNT; - sin_vals[i] = sinf(theta); - cos_vals[i] = cosf(theta); + void fill_hann_window(int length, bool periodic, float* output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } } - is_filled = true; +} global_cache; } // naive Discrete Fourier Transform @@ -2888,8 +2912,8 @@ static void dft(const std::vector & in, std::vector & out) { for (int n = 0; n < N; n++) { int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N - re += in[n]*cos_vals[idx]; // cos(t) - im -= in[n]*sin_vals[idx]; // sin(t) + re += in[n]*global_cache.cos_vals[idx]; // cos(t) + im -= in[n]*global_cache.sin_vals[idx]; // sin(t) } out[k*2 + 0] = re; @@ -2940,8 +2964,8 @@ static void fft(const std::vector & in, std::vector & out) { const int sin_cos_step = SIN_COS_N_COUNT / N; for (int k = 0; k < N/2; k++) { int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = cos_vals[idx]; // cos(t) - float im = -sin_vals[idx]; // sin(t) + float re = global_cache.cos_vals[idx]; // cos(t) + float im = -global_cache.sin_vals[idx]; // sin(t) float re_odd = odd_fft[2*k + 0]; float im_odd = odd_fft[2*k + 1]; @@ -2954,22 +2978,7 @@ static void fft(const std::vector & in, std::vector & out) { } } -static bool hann_window(int length, bool periodic, std::vector & output) { - if (output.size() < static_cast(length)) { - output.resize(length); - } - int offset = -1; - if (periodic) { - offset = 0; - } - for (int i = 0; i < length; i++) { - output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset))); - } - - return true; -} - -static void log_mel_spectrogram_worker_thread(int ith, const std::vector & hann, const std::vector & samples, +static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, int n_samples, int frame_size, int frame_step, int n_threads, const whisper_filters & filters, whisper_mel & mel) { std::vector fft_in(frame_size, 0.0); @@ -2984,7 +2993,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { const int offset = i * frame_step; - // apply Hanning window (~10% faster) + // apply Hann window (~10% faster) for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { fft_in[j] = hann[j] * samples[offset + j]; } @@ -3051,12 +3060,16 @@ static bool log_mel_spectrogram( whisper_mel & mel) { const int64_t t_start_us = ggml_time_us(); - // Hanning window (Use cosf to eliminate difference) - // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html - // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 - std::vector hann; - hann_window(frame_size, true, hann); - + // Hann window + const float * hann = nullptr; + if (frame_size == WHISPER_N_FFT) { + hann = global_cache.hann_window; + } else if (frame_size == 2 * WHISPER_N_FFT) { + hann = global_cache.hann_window2x; + } else { + WHISPER_ASSERT(false && "Unsupported frame_size"); + return false; + } // Calculate the length of padding int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; @@ -3086,7 +3099,7 @@ static bool log_mel_spectrogram( std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded, + log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, std::cref(filters), std::ref(mel)); } @@ -3246,8 +3259,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) { #endif struct whisper_state * whisper_init_state(whisper_context * ctx) { - fill_sin_cos_table(); - whisper_state * state = new whisper_state; state->backend = whisper_backend_init(ctx->params); @@ -7235,7 +7246,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw( // operation (after median filter) // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims - w = ggml_norm(gctx, w, 1e-9); + w = ggml_norm(gctx, w, 1e-9f); w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3); // Pass median filter - this is done over AUDIO_TOKENS dimension. From ad130431aa52092a3d091bb304919f48a045aa70 Mon Sep 17 00:00:00 2001 From: Carlos Zoido Date: Thu, 30 May 2024 14:06:15 +0200 Subject: [PATCH 095/100] readme : add install instructions for Conan (#2189) --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index c2c6bc4a2b1..3ef073174e2 100644 --- a/README.md +++ b/README.md @@ -502,6 +502,16 @@ docker run -it --rm \ whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav" ``` +## Installing with Conan + +You can install pre-built binaries for whisper.cpp or build it from source using [Conan](https://conan.io/). Use the following command: + +``` +conan install --requires="whisper-cpp/[*]" --build=missing +``` + +For detailed instructions on how to use Conan, please refer to the [Conan documentation](https://docs.conan.io/2/). + ## Limitations - Inference only From b87494bb8f1e2b5843ec606294e8c370aa25a368 Mon Sep 17 00:00:00 2001 From: Martin Delille Date: Thu, 30 May 2024 14:43:28 +0200 Subject: [PATCH 096/100] readme : add conan badge (#2196) * Add conan badge * Fix markdown formating --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 3ef073174e2..9ec9684971a 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ [![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) +[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) Stable: [v1.6.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) @@ -720,7 +721,7 @@ The [main](examples/main) example provides support for output of karaoke-style m currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script. This requires to have `ffmpeg` installed. -Here are a few *"typical"* examples: +Here are a few _"typical"_ examples: ```bash ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts From af5833e29819810f2d83228228a9a3077e5ccd93 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Fri, 31 May 2024 11:37:29 +0300 Subject: [PATCH 097/100] whisper : remove `speed_up` and `phase_vocoder*` functions (#2198) * whisper : fix cast warning * whisper : remove phase_vocoder functions, ref #2195 * whisper : remove speed_up from whisper_full_params, closes #2195 --- bindings/go/examples/go-whisper/flags.go | 9 --- bindings/go/params.go | 7 --- bindings/go/pkg/whisper/context.go | 5 -- bindings/go/pkg/whisper/interface.go | 1 - .../whispercpp/WhisperCppJnaLibrary.java | 10 +-- .../whispercpp/params/WhisperFullParams.java | 10 +-- bindings/ruby/ext/ruby_whisper.cpp | 8 --- bindings/ruby/tests/test_whisper.rb | 7 --- examples/addon.node/addon.cpp | 3 - examples/command/command.cpp | 5 -- examples/common.h | 2 +- examples/lsp/lsp.cpp | 5 -- examples/main/main.cpp | 4 -- examples/server/server.cpp | 4 -- examples/stream/stream.cpp | 4 -- examples/talk-llama/talk-llama.cpp | 4 -- examples/talk/talk.cpp | 4 -- examples/wchess/wchess.cmd/wchess.cmd.cpp | 3 - whisper.cpp | 63 ++++--------------- whisper.h | 17 ----- 20 files changed, 14 insertions(+), 161 deletions(-) diff --git a/bindings/go/examples/go-whisper/flags.go b/bindings/go/examples/go-whisper/flags.go index ea204455c80..766c92f1827 100644 --- a/bindings/go/examples/go-whisper/flags.go +++ b/bindings/go/examples/go-whisper/flags.go @@ -68,10 +68,6 @@ func (flags *Flags) GetOut() string { return strings.ToLower(flags.Lookup("out").Value.String()) } -func (flags *Flags) IsSpeedup() bool { - return flags.Lookup("speedup").Value.String() == "true" -} - func (flags *Flags) IsTokens() bool { return flags.Lookup("tokens").Value.String() == "true" } @@ -111,10 +107,6 @@ func (flags *Flags) SetParams(context whisper.Context) error { fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration) context.SetDuration(duration) } - if flags.IsSpeedup() { - fmt.Fprintf(flags.Output(), "Setting speedup to true\n") - context.SetSpeedup(true) - } if threads := flags.GetThreads(); threads != 0 { fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads) context.SetThreads(threads) @@ -146,7 +138,6 @@ func registerFlags(flag *Flags) { flag.Duration("offset", 0, "Time offset") flag.Duration("duration", 0, "Duration of audio to process") flag.Uint("threads", 0, "Number of threads to use") - flag.Bool("speedup", false, "Enable speedup") flag.Uint("max-len", 0, "Maximum segment length in characters") flag.Uint("max-tokens", 0, "Maximum tokens per segment") flag.Float64("word-thold", 0, "Maximum segment score") diff --git a/bindings/go/params.go b/bindings/go/params.go index 5931bb0b199..4b4da032d62 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -47,10 +47,6 @@ func (p *Params) SetPrintTimestamps(v bool) { p.print_timestamps = toBool(v) } -func (p *Params) SetSpeedup(v bool) { - p.speed_up = toBool(v) -} - // Set language id func (p *Params) SetLanguage(lang int) error { if lang == -1 { @@ -177,9 +173,6 @@ func (p *Params) String() string { if p.token_timestamps { str += " token_timestamps" } - if p.speed_up { - str += " speed_up" - } return str + ">" } diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index 0863ef6bb16..ead92648f3e 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -76,11 +76,6 @@ func (context *context) SetTranslate(v bool) { context.params.SetTranslate(v) } -// Set speedup flag -func (context *context) SetSpeedup(v bool) { - context.params.SetSpeedup(v) -} - func (context *context) SetSplitOnWord(v bool) { context.params.SetSplitOnWord(v) } diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index 4339e16f847..b430e7ce853 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -41,7 +41,6 @@ type Context interface { SetOffset(time.Duration) // Set offset SetDuration(time.Duration) // Set duration SetThreads(uint) // Set number of threads to use - SetSpeedup(bool) // Set speedup flag SetSplitOnWord(bool) // Set split on word flag SetTokenThreshold(float32) // Set timestamp token probability threshold SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java index 56a37380136..1a73cee1181 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java @@ -20,7 +20,7 @@ public interface WhisperCppJnaLibrary extends Library { * @return Whisper context on success, null on failure */ Pointer whisper_init_from_file(String path_model); - + /** * Provides default params which can be used with `whisper_init_from_file_with_params()` etc. * Because this function allocates memory for the params, the caller must call either: @@ -304,14 +304,6 @@ public interface WhisperCppJnaLibrary extends Library { /** Language id associated with the provided state */ int whisper_full_lang_id_from_state(Pointer state); - /** - * Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. - * The resulting spectrogram is stored inside the default state of the provided whisper context. - * @return 0 on success - */ - int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads); - - int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads); /** Get the start time of the specified segment. */ long whisper_full_get_segment_t0(Pointer ctx, int i_segment); diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 60d8334b935..90d8c15767c 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -129,14 +129,6 @@ public void splitOnWord(boolean enable) { /** Maximum tokens per segment (0, default = no limit) */ public int max_tokens; - /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */ - public CBool speed_up; - - /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */ - public void speedUp(boolean enable) { - speed_up = enable ? CBool.TRUE : CBool.FALSE; - } - /** Overwrite the audio context size (0 = use default). */ public int audio_ctx; @@ -321,7 +313,7 @@ protected List getFieldOrder() { return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate", "no_context", "single_segment", "no_timestamps", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", - "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx", + "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx", "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 86af9391e2c..9d9334539b8 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -311,12 +311,6 @@ static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, split_on_word, value) } -static VALUE ruby_whisper_params_get_speed_up(VALUE self) { - BOOL_PARAMS_GETTER(self, speed_up) -} -static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, speed_up, value) -} static VALUE ruby_whisper_params_get_diarize(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -408,8 +402,6 @@ void Init_whisper() { rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1); - rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0); - rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1); rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0); rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1); diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index fa6a3e2d4e8..3700671bce6 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -117,13 +117,6 @@ def test_split_on_word assert !@params.split_on_word end - def test_speed_up - @params.speed_up = true - assert @params.speed_up - @params.speed_up = false - assert !@params.speed_up - end - def test_whisper @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) params = Whisper::Params.new diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 53bf1abb5a3..4ada6ca5084 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -25,7 +25,6 @@ struct whisper_params { float entropy_thold = 2.4f; float logprob_thold = -1.0f; - bool speed_up = false; bool translate = false; bool diarize = false; bool output_txt = false; @@ -232,8 +231,6 @@ int run(whisper_params ¶ms, std::vector> &result) { wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; - wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; diff --git a/examples/command/command.cpp b/examples/command/command.cpp index cd6cc023994..84424d4331b 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -38,7 +38,6 @@ struct whisper_params { grammar_parser::parse_state grammar_parsed; - bool speed_up = false; bool translate = false; bool print_special = false; bool print_energy = false; @@ -76,7 +75,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } @@ -115,7 +113,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); @@ -165,7 +162,6 @@ std::string transcribe( wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; wparams.temperature = 0.4f; wparams.temperature_inc = 1.0f; @@ -371,7 +367,6 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; wparams.prompt_tokens = k_tokens.data(); wparams.prompt_n_tokens = k_tokens.size(); diff --git a/examples/common.h b/examples/common.h index 2ed91ca9aa8..de895858ab0 100644 --- a/examples/common.h +++ b/examples/common.h @@ -185,7 +185,7 @@ class wav_writer { // It is assumed that PCM data is normalized to a range from -1 to 1 bool write_audio(const float * data, size_t length) { for (size_t i = 0; i < length; ++i) { - const int16_t intSample = data[i] * 32767; + const int16_t intSample = int16_t(data[i] * 32767); file.write(reinterpret_cast(&intSample), sizeof(int16_t)); dataSize += sizeof(int16_t); } diff --git a/examples/lsp/lsp.cpp b/examples/lsp/lsp.cpp index 3df54266a25..8cca87151bf 100644 --- a/examples/lsp/lsp.cpp +++ b/examples/lsp/lsp.cpp @@ -26,7 +26,6 @@ struct whisper_params { float vad_thold = 0.6f; float freq_thold = 100.0f; - bool speed_up = false; bool translate = false; bool print_special = false; bool print_energy = false; @@ -70,7 +69,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } @@ -102,7 +100,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); @@ -184,7 +181,6 @@ json unguided_transcription(struct whisper_context * ctx, audio_async &audio, js wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; wparams.suppress_non_speech_tokens = true; // run the transformer and a single decoding pass if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { @@ -223,7 +219,6 @@ json guided_transcription(struct whisper_context * ctx, audio_async &audio, cons wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; // TODO: Do some time testing. Does an overly long prompt slow down processing? // Set up command sets/precompute prompts diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 45eb17fe7f3..bb9b7b79ce5 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -47,7 +47,6 @@ struct whisper_params { float temperature = 0.0f; float temperature_inc = 0.2f; - bool speed_up = false; bool debug_mode = false; bool translate = false; bool detect_language = false; @@ -138,7 +137,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); } else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); } - // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } @@ -206,7 +204,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); - // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); @@ -1106,7 +1103,6 @@ int main(int argc, char ** argv) { wparams.split_on_word = params.split_on_word; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; wparams.debug_mode = params.debug_mode; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2efa4c7a020..10aae9c04d3 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -61,7 +61,6 @@ struct whisper_params { float temperature = 0.00f; float temperature_inc = 0.20f; - bool speed_up = false; bool debug_mode = false; bool translate = false; bool detect_language = false; @@ -112,7 +111,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); @@ -159,7 +157,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } @@ -768,7 +765,6 @@ int main(int argc, char ** argv) { wparams.split_on_word = params.split_on_word; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; wparams.debug_mode = params.debug_mode; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 60c1b0894e4..50797e96daa 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -27,7 +27,6 @@ struct whisper_params { float vad_thold = 0.6f; float freq_thold = 100.0f; - bool speed_up = false; bool translate = false; bool no_fallback = false; bool print_special = false; @@ -62,7 +61,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } @@ -100,7 +98,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); @@ -314,7 +311,6 @@ int main(int argc, char ** argv) { wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index 4aab62b9a6f..b15be0b2789 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -59,7 +59,6 @@ struct whisper_params { float vad_thold = 0.6f; float freq_thold = 100.0f; - bool speed_up = false; bool translate = false; bool print_special = false; bool print_energy = false; @@ -100,7 +99,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } @@ -149,7 +147,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); @@ -205,7 +202,6 @@ std::string transcribe( wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { return ""; diff --git a/examples/talk/talk.cpp b/examples/talk/talk.cpp index 3e34e5724ff..b34fad6c2bb 100644 --- a/examples/talk/talk.cpp +++ b/examples/talk/talk.cpp @@ -26,7 +26,6 @@ struct whisper_params { float vad_thold = 0.6f; float freq_thold = 100.0f; - bool speed_up = false; bool translate = false; bool print_special = false; bool print_energy = false; @@ -60,7 +59,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } @@ -96,7 +94,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); @@ -132,7 +129,6 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { return ""; diff --git a/examples/wchess/wchess.cmd/wchess.cmd.cpp b/examples/wchess/wchess.cmd/wchess.cmd.cpp index 09e53f13172..4d049976315 100644 --- a/examples/wchess/wchess.cmd/wchess.cmd.cpp +++ b/examples/wchess/wchess.cmd/wchess.cmd.cpp @@ -26,7 +26,6 @@ struct whisper_params { float grammar_penalty = 100.0f; - bool speed_up = false; bool translate = false; bool print_special = false; bool print_energy = false; @@ -57,7 +56,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); @@ -89,7 +87,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } diff --git a/whisper.cpp b/whisper.cpp index a22da8896bb..dbb235e9f43 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2868,13 +2868,10 @@ struct whisper_global_cache { // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 float hann_window[WHISPER_N_FFT]; - float hann_window2x[WHISPER_N_FFT * 2]; whisper_global_cache() { fill_sin_cos_table(); -#define FILL_HANN_WINDOW(arr) fill_hann_window(sizeof(arr) / sizeof(arr[0]), true, arr) - FILL_HANN_WINDOW(hann_window); - FILL_HANN_WINDOW(hann_window2x); + fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window); } void fill_sin_cos_table() { @@ -2885,7 +2882,7 @@ struct whisper_global_cache { } } - void fill_hann_window(int length, bool periodic, float* output) { + void fill_hann_window(int length, bool periodic, float * output) { int offset = -1; if (periodic) { offset = 0; @@ -3061,15 +3058,8 @@ static bool log_mel_spectrogram( const int64_t t_start_us = ggml_time_us(); // Hann window - const float * hann = nullptr; - if (frame_size == WHISPER_N_FFT) { - hann = global_cache.hann_window; - } else if (frame_size == 2 * WHISPER_N_FFT) { - hann = global_cache.hann_window2x; - } else { - WHISPER_ASSERT(false && "Unsupported frame_size"); - return false; - } + WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size"); + const float * hann = global_cache.hann_window; // Calculate the length of padding int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; @@ -3752,30 +3742,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); } -// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) -int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { - WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); - return -1; - } - - return 0; -} - -// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) -int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads); -} - -// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2 -// TODO - -// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2 -// TODO - -// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2 -// TODO - int whisper_set_mel_with_state( struct whisper_context * ctx, struct whisper_state * state, @@ -4676,7 +4642,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.split_on_word =*/ false, /*.max_tokens =*/ 0, - /*.speed_up =*/ false, /*.debug_mode =*/ false, /*.audio_ctx =*/ 0, @@ -5350,15 +5315,9 @@ int whisper_full_with_state( if (n_samples > 0) { // compute log mel spectrogram - if (params.speed_up) { - // TODO: Replace PV with more advanced algorithm + if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); - return -1; - } else { - if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { - WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); - return -2; - } + return -2; } } @@ -5395,7 +5354,7 @@ int whisper_full_with_state( // if length of spectrogram is less than 1.0s (100 frames), then return // basically don't process anything that is less than 1.0s // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 - if (seek_end < seek_start + (params.speed_up ? 50 : 100)) { + if (seek_end < seek_start + 100) { WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10); return 0; } @@ -6107,8 +6066,8 @@ int whisper_full_with_state( const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); if (!text.empty()) { - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; + const auto tt0 = t0; + const auto tt1 = t1; if (params.print_realtime) { if (params.print_timestamps) { @@ -6154,8 +6113,8 @@ int whisper_full_with_state( if (!text.empty()) { const auto t1 = seek + seek_delta; - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; + const auto tt0 = t0; + const auto tt1 = t1; if (params.print_realtime) { if (params.print_timestamps) { diff --git a/whisper.h b/whisper.h index 9c7c58d874b..2b3d5e574cb 100644 --- a/whisper.h +++ b/whisper.h @@ -266,22 +266,6 @@ extern "C" { int n_samples, int n_threads); - // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. - // The resulting spectrogram is stored inside the default state of the provided whisper context. - // Returns 0 on success - WHISPER_API int whisper_pcm_to_mel_phase_vocoder( - struct whisper_context * ctx, - const float * samples, - int n_samples, - int n_threads); - - WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - const float * samples, - int n_samples, - int n_threads); - // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // n_mel must be 80 @@ -499,7 +483,6 @@ extern "C" { // [EXPERIMENTAL] speed-up techniques // note: these can significantly reduce the quality of the output - bool speed_up; // speed-up the audio by 2x using Phase Vocoder bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel) int audio_ctx; // overwrite the audio context size (0 = use default) From ffef323c4cfa8596cb91cf92d6f791f01a44335e Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Tue, 4 Jun 2024 09:32:23 +0300 Subject: [PATCH 098/100] whisper : add CUDA-specific computation mel spectrograms (#2206) * whisper : use polymorphic class to calculate mel spectrogram * whisper : add cuda-specific mel spectrogram calculation * whisper : conditionally compile cufftGetErrorString to avoid warnings * build : add new files to makefile * ruby : add new files to conf script * build : fix typo in makefile * whisper : suppress cub warning for deprecated C++ std in whisper-mel-cuda --- CMakeLists.txt | 10 +- Makefile | 9 +- bindings/ruby/ext/extconf.rb | 1 + whisper-mel-cuda.cu | 342 +++++++++++++++++++++++++++++++++++ whisper-mel-cuda.hpp | 3 + whisper-mel.hpp | 33 ++++ whisper.cpp | 196 ++++++++++---------- whisper.h | 2 + 8 files changed, 497 insertions(+), 99 deletions(-) create mode 100644 whisper-mel-cuda.cu create mode 100644 whisper-mel-cuda.hpp create mode 100644 whisper-mel.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 82913aa62ba..63d707b66c2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -364,12 +364,12 @@ if (WHISPER_CUDA) if (WHISPER_STATIC) if (WIN32) # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt CUDA::cufft) else () - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static CUDA::cufft_static) endif() else() - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cufft) endif() set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver) @@ -679,6 +679,10 @@ add_library(${TARGET} whisper.cpp ) +if (WHISPER_CUDA) + target_sources(${TARGET} PRIVATE whisper-mel-cuda.cu) +endif() + include_directories ( . ) diff --git a/Makefile b/Makefile index 901fe216035..53f880e88f7 100644 --- a/Makefile +++ b/Makefile @@ -286,8 +286,8 @@ ifdef WHISPER_CUDA CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include - LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib - WHISPER_OBJ += ggml-cuda.o + LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib + WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu)) NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG) @@ -299,6 +299,9 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif +whisper-mel-cuda.o: whisper-mel-cuda.cu whisper.h ggml.h ggml-backend.h whisper-mel.hpp whisper-mel-cuda.hpp + $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ + ifdef WHISPER_HIPBLAS ROCM_PATH ?= /opt/rocm HIPCC ?= $(ROCM_PATH)/bin/hipcc @@ -404,7 +407,7 @@ ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o -whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h +whisper.o: whisper.cpp whisper.h whisper-mel.hpp ggml.h ggml-cuda.h $(CXX) $(CXXFLAGS) -c $< -o $@ ifndef WHISPER_COREML diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 410c08feef5..f22c550ee37 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -1,6 +1,7 @@ require 'mkmf' system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper-mel.hpp')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .") diff --git a/whisper-mel-cuda.cu b/whisper-mel-cuda.cu new file mode 100644 index 00000000000..ad36cae5830 --- /dev/null +++ b/whisper-mel-cuda.cu @@ -0,0 +1,342 @@ +#define CUB_IGNORE_DEPRECATED_CPP_DIALECT +#include "whisper-mel-cuda.hpp" +#include "whisper.h" + +#include +#include +#include +#include +#include +#include + +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4324) // added padding +#endif + +#ifndef NDEBUG +# define DO_CHECKS 1 +#else +# define DO_CHECKS 0 +#endif + +namespace { + +#if DO_CHECKS +const char* cufftGetErrorString(cufftResult_t res) { + switch (res) { + case CUFFT_SUCCESS: return "The cuFFT operation was successful"; + case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle"; + case CUFFT_ALLOC_FAILED: return "cuFFT failed to allocate GPU or CPU memory"; + case CUFFT_INVALID_TYPE: return "No longer used"; + case CUFFT_INVALID_VALUE: return "User specified an invalid pointer or parameter"; + case CUFFT_INTERNAL_ERROR: return "Driver or internal cuFFT library error"; + case CUFFT_EXEC_FAILED: return "Failed to execute an FFT on the GPU"; + case CUFFT_SETUP_FAILED: return "The cuFFT library failed to initialize"; + case CUFFT_INVALID_SIZE: return "User specified an invalid transform size"; + case CUFFT_UNALIGNED_DATA: return "No longer used"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: return "Missing parameters in call"; + case CUFFT_INVALID_DEVICE: return "Execution of a plan was on different GPU than plan creation"; + case CUFFT_PARSE_ERROR: return "Internal plan database error"; + case CUFFT_NO_WORKSPACE: return "No workspace has been provided prior to plan execution"; + case CUFFT_NOT_IMPLEMENTED: return "Function does not implement functionality for parameters given."; + case CUFFT_LICENSE_ERROR: return "Used in previous versions."; + case CUFFT_NOT_SUPPORTED: return "Operation is not supported for parameters given."; + default: return "Unknown error"; + } +} + +# define CUDA_CHECK_GEN(err, success, error_fn) \ + do { \ + auto err_ = (err); \ + if (err_ != (success)) { \ + fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \ + } \ + } while (0) +#else +# define CUDA_CHECK_GEN(err, success, error_fn) err +#endif + +#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) +#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString) +#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString) + +__global__ void k_fill_stft_input( + const float * padded_samples, + const int n_frames, + const float * hann_window, + float * stft_in +) { + auto y = blockIdx.y * blockDim.y + threadIdx.y; + // if (y >= n_frames) return; + auto x = blockIdx.x * blockDim.x + threadIdx.x; + // if (x >= WHISPER_N_FFT) return; + + auto line = padded_samples + y * WHISPER_HOP_LENGTH; + auto outLine = stft_in + y * WHISPER_N_FFT; + + outLine[x] = line[x] * hann_window[x]; +} + +__global__ void k_calc_magnitudes( + const cuComplex* stft_out, + const int n_frames, + float * magnitudes +) { + auto y = blockIdx.y * blockDim.y + threadIdx.y; + // if (y >= n_frames) return; + auto x = blockIdx.x * blockDim.x + threadIdx.x; + // if (x >= WHISPER_N_FFT_HALF) return; + + auto idx = y * WHISPER_N_FFT_HALF + x; + + auto r = stft_out[idx].x; + auto i = stft_out[idx].y; + magnitudes[idx] = r * r + i * i; +} + +__global__ void k_calc_log_mel( + const float * mel_data, + const int n_mel, + const float * max_val, + float * log_mel +) { + auto x = blockIdx.x * blockDim.x + threadIdx.x; + if (x >= n_mel) return; + + float val = mel_data[x]; + + constexpr float e = 1e-10f; + if (val < e) val = e; + + val = log10(val); + + const float max = log10(*max_val) - 8.f; + if (val < max) val = max; + + log_mel[x] = (val + 4) / 4; +} + +void fill_stft_input( + const float * padded_samples, + int n_frames, + const float * hann_window, + float * stft_in, + cudaStream_t stream +) { + dim3 block(WHISPER_N_FFT, 1); + dim3 grid(1, n_frames); + + k_fill_stft_input<<>>(padded_samples, n_frames, hann_window, stft_in); +} + +void calc_magnitudes( + const cuComplex* stft_out, + int n_frames, + float * magnitudes, + cudaStream_t stream +) { + dim3 block(WHISPER_N_FFT_HALF, 1); + dim3 grid(1, n_frames); + k_calc_magnitudes<<>>(stft_out, n_frames, magnitudes); +} + +constexpr auto LOG_MEL_PREFIX_SIZE = 256; + +size_t get_log_mel_temp_storage_size() { + constexpr auto maxPaddedSamples = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT; + constexpr auto maxFrames = 1 + (maxPaddedSamples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + constexpr auto maxMels = 160; + + size_t nbytes = 0; + float * temp = nullptr; + cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, maxFrames * maxMels); + return nbytes + LOG_MEL_PREFIX_SIZE; +} + +void calc_log_mel( + const float * mel_data, + int n_mel, + void * tempStorage, + int tempStorageSize, + float * log_mel, + cudaStream_t stream +) { + float * max_val = reinterpret_cast(tempStorage); + void * maxTemp = reinterpret_cast(tempStorage) + LOG_MEL_PREFIX_SIZE; + + size_t nbytes = size_t(tempStorageSize - LOG_MEL_PREFIX_SIZE); + cub::DeviceReduce::Max(maxTemp, nbytes, mel_data, max_val, n_mel, stream); + + int block = 256; + int grid = (n_mel + block - 1) / block; + + k_calc_log_mel<<>>(mel_data, n_mel, max_val, log_mel); +} + +class mel_calc_cuda : public whisper_mel_calc { + const int m_n_mel; + + ggml_backend_t m_backend = nullptr; + + cudaStream_t m_stream = nullptr; + cublasHandle_t m_cublas_handle = nullptr; + + float * m_hann_window = nullptr; + + size_t m_cufft_workspace_size = 0; + void * m_cufft_workspace = nullptr; + + float * m_filters = nullptr; + + size_t m_log_mel_temp_storage_size = 0; + void * m_log_mel_temp_storage = nullptr; +public: + mel_calc_cuda(ggml_backend_t backend, const whisper_filters& filters) + : m_n_mel(filters.n_mel) + , m_backend(backend) + { + if (filters.n_fft != WHISPER_N_FFT_HALF) { + throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF"); + } + assert(filters.data.size() == filters.n_mel * WHISPER_N_FFT_HALF); + + CUDA_CHECK(cudaStreamCreate(&m_stream)); + CUBLAS_CHECK(cublasCreate(&m_cublas_handle)); + CUBLAS_CHECK(cublasSetMathMode(m_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); + CUBLAS_CHECK(cublasSetStream(m_cublas_handle, m_stream)); + + // create Hann window + { + auto hw = whisper_mel_calc::hann_window(); + CUDA_CHECK(cudaMallocAsync(&m_hann_window, hw.len * sizeof(float), m_stream)); + CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream)); + } + + // create working area + { + constexpr auto maxPaddedSamples = 2 * WHISPER_N_SAMPLES + WHISPER_N_FFT; + constexpr auto maxFrames = 1 + (maxPaddedSamples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, maxFrames, &m_cufft_workspace_size)); + CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream)); + } + + // fill filters + { + auto& f = filters.data; + CUDA_CHECK(cudaMallocAsync(&m_filters, f.size() * sizeof(float), m_stream)); + CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream)); + } + + { + m_log_mel_temp_storage_size = get_log_mel_temp_storage_size(); + CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream)); + } + } + + ~mel_calc_cuda() { + CUDA_CHECK(cudaStreamSynchronize(m_stream)); + CUDA_CHECK(cudaStreamDestroy(m_stream)); + CUDA_CHECK(cudaFree(m_hann_window)); + CUDA_CHECK(cudaFree(m_cufft_workspace)); + CUDA_CHECK(cudaFree(m_filters)); + CUDA_CHECK(cudaFree(m_log_mel_temp_storage)); + } + + virtual whisper_mel calculate(whisper_span samples, int /*n_threads*/) const override { + const size_t mirror_pad = WHISPER_N_FFT / 2; + const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT; + + // pad + std::vector padded_samples(padded_size); + std::reverse_copy(samples.data + 1, samples.data + 1 + mirror_pad, padded_samples.begin()); // reflect + std::copy(samples.data, samples.data + samples.len, padded_samples.begin() + mirror_pad); // copy + + // fill the rest of the data + // it should canonically be mirrored at the end as well, + // but we just assume the last MEL_FRAME_SIZE/2 samples are zeros + std::fill(padded_samples.begin() + mirror_pad + samples.len, padded_samples.end(), 0.f); + + const auto n_frames = 1 + (padded_samples.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + + float * cu_padded_samples = nullptr; + CUDA_CHECK(cudaMallocAsync(&cu_padded_samples, padded_samples.size() * sizeof(float), m_stream)); + CUDA_CHECK(cudaMemcpyAsync(cu_padded_samples, padded_samples.data(), padded_samples.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream)); + + float * stft_in = nullptr; // contiguous buffer for stft input + CUDA_CHECK(cudaMallocAsync(&stft_in, n_frames * WHISPER_N_FFT * sizeof(float), m_stream)); + + fill_stft_input(cu_padded_samples, int(n_frames), m_hann_window, stft_in, m_stream); + + cufftComplex* stft_out; + CUDA_CHECK(cudaMallocAsync(&stft_out, n_frames * WHISPER_N_FFT_HALF * sizeof(cufftComplex), m_stream)); + + cufftHandle plan; + CUFFT_CHECK(cufftCreate(&plan)); + CUFFT_CHECK(cufftSetAutoAllocation(plan, 0)); + { + size_t waSize; + CUFFT_CHECK(cufftMakePlan1d(plan, WHISPER_N_FFT, CUFFT_R2C, int(n_frames), &waSize)); + assert(waSize <= m_cufft_workspace_size); + CUFFT_CHECK(cufftSetWorkArea(plan, m_cufft_workspace)); + CUFFT_CHECK(cufftSetStream(plan, m_stream)); + } + CUFFT_CHECK(cufftExecR2C(plan, stft_in, stft_out)); + + const auto n_mag_frames = n_frames - 1; // drop last frame + float * magnitudes; + CUDA_CHECK(cudaMallocAsync(&magnitudes, n_mag_frames * WHISPER_N_FFT_HALF * sizeof(float), m_stream)); + calc_magnitudes(stft_out, int(n_mag_frames), magnitudes, m_stream); + + float * mel_data = nullptr; + CUDA_CHECK(cudaMallocAsync(&mel_data, m_n_mel * n_mag_frames * sizeof(float), m_stream)); + + const float fone = 1.0f, fzero = 0.0f; + CUBLAS_CHECK(cublasSgemm(m_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, + int(n_mag_frames), m_n_mel, WHISPER_N_FFT_HALF, + &fone, + magnitudes, WHISPER_N_FFT_HALF, + m_filters, WHISPER_N_FFT_HALF, + &fzero, + mel_data, int(n_mag_frames))); + + float * log_mels = nullptr; + CUDA_CHECK(cudaMallocAsync(&log_mels, m_n_mel * n_mag_frames * sizeof(float), m_stream)); + + calc_log_mel( + mel_data, int(m_n_mel * n_mag_frames), + m_log_mel_temp_storage, int(m_log_mel_temp_storage_size), + log_mels, m_stream); + + whisper_mel ret; + ret.n_mel = m_n_mel; + ret.n_len = int(n_mag_frames); + // Calculate semi-padded sample length to ensure compatibility + ret.n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + ret.data.resize(m_n_mel * n_mag_frames); + CUDA_CHECK(cudaMemcpyAsync(ret.data.data(), log_mels, ret.data.size() * sizeof(float), cudaMemcpyDeviceToHost, m_stream)); + + CUDA_CHECK(cudaStreamSynchronize(m_stream)); + + // cleanup + CUFFT_CHECK(cufftDestroy(plan)); + CUDA_CHECK(cudaFreeAsync(log_mels, m_stream)); + CUDA_CHECK(cudaFreeAsync(mel_data, m_stream)); + CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream)); + CUDA_CHECK(cudaFreeAsync(stft_out, m_stream)); + CUDA_CHECK(cudaFreeAsync(stft_in, m_stream)); + CUDA_CHECK(cudaFreeAsync(cu_padded_samples, m_stream)); + + return ret; + } +}; + +} + +whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) { + if (filters.n_fft != WHISPER_N_FFT_HALF) { + return nullptr; + } + return new mel_calc_cuda(backend, filters); +} diff --git a/whisper-mel-cuda.hpp b/whisper-mel-cuda.hpp new file mode 100644 index 00000000000..2acb6505fcb --- /dev/null +++ b/whisper-mel-cuda.hpp @@ -0,0 +1,3 @@ +#include "whisper-mel.hpp" + +whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters); diff --git a/whisper-mel.hpp b/whisper-mel.hpp new file mode 100644 index 00000000000..bc48475feec --- /dev/null +++ b/whisper-mel.hpp @@ -0,0 +1,33 @@ +#pragma once +#include "ggml-backend.h" +#include + +struct whisper_mel { + int n_len; + int n_len_org; + int n_mel; + + std::vector data; +}; + +struct whisper_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +template +struct whisper_span { + T * data; + int len; +}; + +struct whisper_mel_calc { + virtual ~whisper_mel_calc(); + virtual whisper_mel calculate(whisper_span samples, int n_threads) const = 0; + static whisper_span hann_window(); +}; + +// returns a new pointer which needs to be freed with delete +whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters); diff --git a/whisper.cpp b/whisper.cpp index dbb235e9f43..2dd2f591bd8 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -10,6 +10,7 @@ #ifdef GGML_USE_CUDA #include "ggml-cuda.h" +#include "whisper-mel-cuda.hpp" #endif #ifdef GGML_USE_SYCL @@ -24,6 +25,8 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +#include "whisper-mel.hpp" + #include #include #include @@ -380,21 +383,6 @@ static const std::map g_aheads { static std::vector get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head); -struct whisper_mel { - int n_len; - int n_len_org; - int n_mel; - - std::vector data; -}; - -struct whisper_filters { - int32_t n_mel; - int32_t n_fft; - - std::vector data; -}; - struct whisper_vocab { using id = int32_t; using token = std::string; @@ -883,6 +871,8 @@ struct whisper_context { whisper_model model; whisper_vocab vocab; + whisper_mel_calc * mel_calc = nullptr; + whisper_state * state = nullptr; ggml_backend_t backend = nullptr; @@ -2894,6 +2884,14 @@ struct whisper_global_cache { } global_cache; } +// Mel spectrogram + +whisper_mel_calc::~whisper_mel_calc() = default; // export vtable + +whisper_span whisper_mel_calc::hann_window() { + return {global_cache.hann_window, WHISPER_N_FFT}; +} + // naive Discrete Fourier Transform // input is real-valued // output is complex-valued @@ -2976,8 +2974,10 @@ static void fft(const std::vector & in, std::vector & out) { } static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, - int n_samples, int frame_size, int frame_step, int n_threads, + int n_samples, int n_threads, const whisper_filters & filters, whisper_mel & mel) { + const auto frame_size = WHISPER_N_FFT; + const auto frame_step = WHISPER_HOP_LENGTH; std::vector fft_in(frame_size, 0.0); std::vector fft_out(2 * frame_size); int n_fft = filters.n_fft; @@ -3041,99 +3041,95 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const } } } +namespace { +struct mel_calc_cpu : public whisper_mel_calc { + const whisper_filters& m_filters; + mel_calc_cpu(const whisper_filters & filters) : m_filters(filters) {} -// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 -static bool log_mel_spectrogram( - whisper_state & wstate, - const float * samples, - const int n_samples, - const int /*sample_rate*/, - const int frame_size, - const int frame_step, - const int n_mel, - const int n_threads, - const whisper_filters & filters, - const bool debug, - whisper_mel & mel) { - const int64_t t_start_us = ggml_time_us(); + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 + whisper_mel calculate(whisper_span ssamples, int n_threads) const override { + // Hann window + const float * hann = global_cache.hann_window; - // Hann window - WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size"); - const float * hann = global_cache.hann_window; + // Calculate the length of padding + int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; + int64_t stage_2_pad = WHISPER_N_FFT / 2; - // Calculate the length of padding - int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; - int64_t stage_2_pad = frame_size / 2; + const int n_samples = int(ssamples.len); + const float * samples = ssamples.data; - // Initialize a vector and copy data from C array to it. - std::vector samples_padded; - samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); - std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + // Initialize a vector and copy data from C array to it. + std::vector samples_padded; + samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); + std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); - // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio - std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio + std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); - // reflective pad 200 samples at the beginning of audio - std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); + // reflective pad 200 samples at the beginning of audio + std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); - mel.n_mel = n_mel; - // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 - // Calculate number of frames + remove the last frame - mel.n_len = (samples_padded.size() - frame_size) / frame_step; - // Calculate semi-padded sample length to ensure compatibility - mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; - mel.data.resize(mel.n_mel * mel.n_len); + whisper_mel mel; + mel.n_mel = m_filters.n_mel; + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 + // Calculate number of frames + remove the last frame + mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + // Calculate semi-padded sample length to ensure compatibility + mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + mel.data.resize(mel.n_mel * mel.n_len); - { - std::vector workers(n_threads - 1); - for (int iw = 0; iw < n_threads - 1; ++iw) { - workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded, - n_samples + stage_2_pad, frame_size, frame_step, n_threads, - std::cref(filters), std::ref(mel)); - } - - // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); + { + std::vector workers(n_threads - 1); + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw] = std::thread( + log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded, + n_samples + stage_2_pad, n_threads, + std::cref(m_filters), std::ref(mel)); + } - for (int iw = 0; iw < n_threads - 1; ++iw) { - workers[iw].join(); - } - } + // main thread + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, n_threads, m_filters, mel); - // clamping and normalization - double mmax = -1e20; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] > mmax) { - mmax = mel.data[i]; + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } } - } - - mmax -= 8.0; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] < mmax) { - mel.data[i] = mmax; + // clamping and normalization + double mmax = -1e20; + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] > mmax) { + mmax = mel.data[i]; + } } - mel.data[i] = (mel.data[i] + 4.0)/4.0; - } + mmax -= 8.0; - wstate.t_mel_us += ggml_time_us() - t_start_us; + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] < mmax) { + mel.data[i] = mmax; + } - // Dump log_mel_spectrogram - if (debug) { - std::ofstream outFile("log_mel_spectrogram.json"); - outFile << "["; - for (uint64_t i = 0; i < mel.data.size() - 1; i++) { - outFile << mel.data[i] << ", "; + mel.data[i] = (mel.data[i] + 4.0)/4.0; } - outFile << mel.data[mel.data.size() - 1] << "]"; - outFile.close(); + + return mel; } +}; +} - return true; +whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters) { +#if GGML_USE_CUDA + if (ggml_backend_is_cuda(backend)) { + auto ret = whisper_mel_calc_create_cuda(backend, filters); + // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run) + const float warmup[256] = {0}; + ret->calculate({warmup, 256}, 1); + return ret; + } else +#endif + return new mel_calc_cpu(filters); } // split text into tokens @@ -3593,6 +3589,8 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_ return nullptr; } + ctx->mel_calc = whisper_mel_calc_create(ctx->backend, ctx->model.filters); + loader->close(loader->context); return ctx; @@ -3713,6 +3711,8 @@ void whisper_free(struct whisper_context * ctx) { ggml_backend_free(ctx->backend); + delete ctx->mel_calc; + ctx->mel_calc = nullptr; delete ctx; } } @@ -3730,11 +3730,21 @@ void whisper_free_params(struct whisper_full_params * params) { } int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { - WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); - return -1; - } + const int64_t t_start_us = ggml_time_us(); + state->mel = ctx->mel_calc->calculate({samples, n_samples}, n_threads); + state->t_mel_us += ggml_time_us() - t_start_us; + // Dump log_mel_spectrogram + //{ + // auto& mel = state->mel; + // std::ofstream outFile("log_mel_spectrogram.json"); + // outFile << "["; + // for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + // outFile << mel.data[i] << ", "; + // } + // outFile << mel.data[mel.data.size() - 1] << "]"; + // outFile.close(); + //} return 0; } diff --git a/whisper.h b/whisper.h index 2b3d5e574cb..65e88ed7597 100644 --- a/whisper.h +++ b/whisper.h @@ -31,8 +31,10 @@ #define WHISPER_SAMPLE_RATE 16000 #define WHISPER_N_FFT 400 +#define WHISPER_N_FFT_HALF (WHISPER_N_FFT / 2 + 1) #define WHISPER_HOP_LENGTH 160 #define WHISPER_CHUNK_SIZE 30 +#define WHISPER_N_SAMPLES (WHISPER_SAMPLE_RATE * WHISPER_CHUNK_SIZE) #ifdef __cplusplus extern "C" { From f842d31171f9772c443198e1a20c59357aa7a5af Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Thu, 6 Jun 2024 16:20:46 +0300 Subject: [PATCH 099/100] whisper : calculate mel spectrogram directly into a ggml_tensor (#2208) * whisper : calculate mel spectrogram directly into a ggml_tensor * whisper : remove unused temp buffer from state * whisper : fix not initializing wstate.embd_enc --- whisper-mel-cuda.cu | 21 +++--- whisper-mel.hpp | 20 ++++-- whisper.cpp | 170 +++++++++++++++++++++++++++++++------------- 3 files changed, 144 insertions(+), 67 deletions(-) diff --git a/whisper-mel-cuda.cu b/whisper-mel-cuda.cu index ad36cae5830..3f3e3158d3e 100644 --- a/whisper-mel-cuda.cu +++ b/whisper-mel-cuda.cu @@ -8,6 +8,7 @@ #include #include #include +#include #include @@ -301,27 +302,23 @@ public: &fzero, mel_data, int(n_mag_frames))); - float * log_mels = nullptr; - CUDA_CHECK(cudaMallocAsync(&log_mels, m_n_mel * n_mag_frames * sizeof(float), m_stream)); + whisper_mel ret; + // Calculate semi-padded sample length to ensure compatibility + int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; + ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel); + assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float)); + + float* log_mels = reinterpret_cast(ret.tensor->data); calc_log_mel( mel_data, int(m_n_mel * n_mag_frames), - m_log_mel_temp_storage, int(m_log_mel_temp_storage_size), + m_log_mel_temp_storage , int(m_log_mel_temp_storage_size), log_mels, m_stream); - whisper_mel ret; - ret.n_mel = m_n_mel; - ret.n_len = int(n_mag_frames); - // Calculate semi-padded sample length to ensure compatibility - ret.n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; - ret.data.resize(m_n_mel * n_mag_frames); - CUDA_CHECK(cudaMemcpyAsync(ret.data.data(), log_mels, ret.data.size() * sizeof(float), cudaMemcpyDeviceToHost, m_stream)); - CUDA_CHECK(cudaStreamSynchronize(m_stream)); // cleanup CUFFT_CHECK(cufftDestroy(plan)); - CUDA_CHECK(cudaFreeAsync(log_mels, m_stream)); CUDA_CHECK(cudaFreeAsync(mel_data, m_stream)); CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream)); CUDA_CHECK(cudaFreeAsync(stft_out, m_stream)); diff --git a/whisper-mel.hpp b/whisper-mel.hpp index bc48475feec..e52b804d9bc 100644 --- a/whisper-mel.hpp +++ b/whisper-mel.hpp @@ -3,11 +3,23 @@ #include struct whisper_mel { - int n_len; - int n_len_org; - int n_mel; + int n_len_org = 0; - std::vector data; + ggml_tensor * tensor = nullptr; + ggml_context * ctx = nullptr; + ggml_backend_buffer_t buffer = nullptr; + + whisper_mel() = default; + ~whisper_mel(); + + whisper_mel(const whisper_mel &) = delete; + whisper_mel & operator=(const whisper_mel &) = delete; + whisper_mel(whisper_mel &&) noexcept; + whisper_mel & operator=(whisper_mel &&) noexcept; + + void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel); + void reset(); + void take(whisper_mel & other) noexcept; }; struct whisper_filters { diff --git a/whisper.cpp b/whisper.cpp index 2dd2f591bd8..dfbcc9d39a0 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -821,7 +821,6 @@ struct whisper_state { struct ggml_tensor * embd_enc = nullptr; // helpers for GPU offloading - std::vector inp_mel; std::vector inp_mask; // decode output (2-dimensional array: [n_tokens][n_vocab]) @@ -1815,7 +1814,8 @@ static bool whisper_encode_external(const whisper_state & wstate) { static struct ggml_cgraph * whisper_build_graph_conv( whisper_context & wctx, - whisper_state & wstate) { + whisper_state & wstate, + const int mel_offset) { const auto & model = wctx.model; const auto & hparams = model.hparams; @@ -1834,9 +1834,32 @@ static struct ggml_cgraph * whisper_build_graph_conv( ggml_cgraph * gf = ggml_new_graph(ctx0); - struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); - ggml_set_name(mel, "mel"); - ggml_set_input(mel); + ggml_tensor * mel_inp = wstate.mel.tensor; + ggml_tensor * mel; + if (mel_inp) { + const int n_len = int(mel_inp->ne[0]); + const int out_s = 2 * n_ctx; + const int i0 = std::min(mel_offset, n_len); + const int i1 = std::min(mel_offset + out_s, n_len); + const int mel_s = i1 - i0; + + assert(mel_inp->type == GGML_TYPE_F32); + assert(mel_inp->ne[1] == n_mels); + + ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0)); + + if (mel_s < out_s) { + mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0); + } + else { + mel = ggml_cont(ctx0, cur); + } + } + else { + // just create some tensor so that the graph/buffer size estimation is correct + mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels); + } + ggml_set_name(mel, "mel"); // used with external encoding struct ggml_tensor * cur = nullptr; @@ -2218,45 +2241,21 @@ static bool whisper_encode_internal( { auto & alloc = wstate.alloc_conv.alloc; - ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate); + ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset); if (!ggml_gallocr_alloc_graph(alloc, gf)) { // should never happen as we pre-allocate the memory return false; } - struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); - - // set the input - { - const auto & mel_inp = wstate.mel; - const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx; - - assert(mel->type == GGML_TYPE_F32); - assert(mel_inp.n_mel == wctx.model.hparams.n_mels); - - wstate.inp_mel.resize(ggml_nelements(mel)); - - float * dst = wstate.inp_mel.data(); - memset(dst, 0, ggml_nbytes(mel)); - - const int i0 = std::min(mel_offset, mel_inp.n_len); - const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); - - for (int j = 0; j < mel_inp.n_mel; ++j) { - for (int i = i0; i < i1; ++i) { - dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; - } - } - - ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float)); + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + return false; } - if (!whisper_encode_external(wstate)) { - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { - return false; - } - } else { + if (whisper_encode_external(wstate)) { + ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); + assert(mel->ne[1] == wctx.model.hparams.n_mels); + GGML_UNUSED(mel); #if defined(WHISPER_USE_COREML) whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data); #elif defined(WHISPER_USE_OPENVINO) @@ -2886,6 +2885,54 @@ struct whisper_global_cache { // Mel spectrogram +whisper_mel::~whisper_mel() { + reset(); +} + +whisper_mel::whisper_mel(whisper_mel && other) noexcept { + take(other); +} + +whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept { + if (this != &other) { + reset(); + take(other); + } + return *this; +} + +void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) { + this->n_len_org = n_len_org; + assert(!ctx); + ctx = ggml_init({ggml_tensor_overhead(), nullptr, true}); + tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel); + buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend)); + auto alloc = ggml_tallocr_new(buffer); + ggml_tallocr_alloc(&alloc, tensor); +} + +void whisper_mel::reset() { + ggml_free(ctx); + ggml_backend_buffer_free(buffer); + + n_len_org = 0; + tensor = nullptr; + ctx = nullptr; + buffer = nullptr; +} + +void whisper_mel::take(whisper_mel & other) noexcept { + n_len_org = other.n_len_org; + tensor = other.tensor; + ctx = other.ctx; + buffer = other.buffer; + + other.n_len_org = 0; + other.tensor = nullptr; + other.ctx = nullptr; + other.buffer = nullptr; +} + whisper_mel_calc::~whisper_mel_calc() = default; // export vtable whisper_span whisper_mel_calc::hann_window() { @@ -2973,9 +3020,18 @@ static void fft(const std::vector & in, std::vector & out) { } } -static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, +namespace { + +struct whisper_mel_data { + int n_len; + int n_len_org; + int n_mel; + float* data; +}; + +void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, int n_samples, int n_threads, - const whisper_filters & filters, whisper_mel & mel) { + const whisper_filters & filters, whisper_mel_data & mel) { const auto frame_size = WHISPER_N_FFT; const auto frame_step = WHISPER_HOP_LENGTH; std::vector fft_in(frame_size, 0.0); @@ -3041,10 +3097,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const } } } -namespace { + struct mel_calc_cpu : public whisper_mel_calc { + ggml_backend_t m_backend; const whisper_filters& m_filters; - mel_calc_cpu(const whisper_filters & filters) : m_filters(filters) {} + mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {} // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 whisper_mel calculate(whisper_span ssamples, int n_threads) const override { @@ -3069,15 +3126,24 @@ struct mel_calc_cpu : public whisper_mel_calc { // reflective pad 200 samples at the beginning of audio std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); - whisper_mel mel; + whisper_mel_data mel; mel.n_mel = m_filters.n_mel; // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 // Calculate number of frames + remove the last frame mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; // Calculate semi-padded sample length to ensure compatibility mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; - mel.data.resize(mel.n_mel * mel.n_len); + std::vector host_mel_data; + + whisper_mel ret; + ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel); + if (ggml_backend_buffer_is_host(ret.buffer)) { + mel.data = reinterpret_cast(ret.tensor->data); + } else { + host_mel_data.resize(mel.n_len * mel.n_mel); + mel.data = host_mel_data.data(); + } { std::vector workers(n_threads - 1); @@ -3114,7 +3180,12 @@ struct mel_calc_cpu : public whisper_mel_calc { mel.data[i] = (mel.data[i] + 4.0)/4.0; } - return mel; + if (!host_mel_data.empty()) { + // the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it + ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor)); + } + + return ret; } }; } @@ -3129,7 +3200,7 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper return ret; } else #endif - return new mel_calc_cpu(filters); + return new mel_calc_cpu(backend, filters); } // split text into tokens @@ -3347,7 +3418,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { { bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend, [&]() { - return whisper_build_graph_conv(*ctx, *state); + return whisper_build_graph_conv(*ctx, *state, 0); }); if (!ok) { @@ -3763,12 +3834,9 @@ int whisper_set_mel_with_state( return -1; } - state->mel.n_len = n_len; - state->mel.n_len_org = n_len; - state->mel.n_mel = n_mel; - - state->mel.data.resize(n_len*n_mel); - memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + state->mel.reset(); + state->mel.init(ctx->backend, n_len, n_len, n_mel); + ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor)); return 0; } From 87acd6d629461ff48c3d58a504ea797736d4b070 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 6 Jun 2024 18:51:36 +0300 Subject: [PATCH 100/100] whisper : whisper_state/backend fixes (#2217) * whisper : fixes * ci : WHISPER_CUBLAS -> WHISPER_CUDA --- .github/workflows/build.yml | 2 +- whisper-mel-cuda.cu | 4 +- whisper-mel.hpp | 19 ++---- whisper.cpp | 125 +++++++++++++++--------------------- 4 files changed, 60 insertions(+), 90 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e9bf9c28292..2095e70d175 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -498,7 +498,7 @@ jobs: run: > cmake -S . -B ./build -A ${{ matrix.arch }} -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DWHISPER_CUBLAS=${{ matrix.cublas }} + -DWHISPER_CUDA=${{ matrix.cublas }} -DWHISPER_SDL2=${{ matrix.sdl2 }} - name: Build ${{ matrix.cuda-toolkit }} diff --git a/whisper-mel-cuda.cu b/whisper-mel-cuda.cu index 3f3e3158d3e..9a6f1093f8e 100644 --- a/whisper-mel-cuda.cu +++ b/whisper-mel-cuda.cu @@ -194,7 +194,7 @@ class mel_calc_cuda : public whisper_mel_calc { size_t m_log_mel_temp_storage_size = 0; void * m_log_mel_temp_storage = nullptr; public: - mel_calc_cuda(ggml_backend_t backend, const whisper_filters& filters) + mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters) : m_n_mel(filters.n_mel) , m_backend(backend) { @@ -305,7 +305,7 @@ public: whisper_mel ret; // Calculate semi-padded sample length to ensure compatibility int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH; - ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel); + whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel); assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float)); float* log_mels = reinterpret_cast(ret.tensor->data); diff --git a/whisper-mel.hpp b/whisper-mel.hpp index e52b804d9bc..1a54a23c730 100644 --- a/whisper-mel.hpp +++ b/whisper-mel.hpp @@ -5,22 +5,14 @@ struct whisper_mel { int n_len_org = 0; - ggml_tensor * tensor = nullptr; ggml_context * ctx = nullptr; + ggml_tensor * tensor = nullptr; ggml_backend_buffer_t buffer = nullptr; +}; - whisper_mel() = default; - ~whisper_mel(); - - whisper_mel(const whisper_mel &) = delete; - whisper_mel & operator=(const whisper_mel &) = delete; - whisper_mel(whisper_mel &&) noexcept; - whisper_mel & operator=(whisper_mel &&) noexcept; +void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel); - void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel); - void reset(); - void take(whisper_mel & other) noexcept; -}; +void whisper_mel_free(whisper_mel & mel); struct whisper_filters { int32_t n_mel; @@ -40,6 +32,3 @@ struct whisper_mel_calc { virtual whisper_mel calculate(whisper_span samples, int n_threads) const = 0; static whisper_span hann_window(); }; - -// returns a new pointer which needs to be freed with delete -whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters); diff --git a/whisper.cpp b/whisper.cpp index dfbcc9d39a0..e8a1320898b 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -801,6 +801,7 @@ struct whisper_state { whisper_kv_cache kv_pad; whisper_mel mel; + whisper_mel_calc * mel_calc = nullptr; whisper_batch batch; @@ -870,8 +871,6 @@ struct whisper_context { whisper_model model; whisper_vocab vocab; - whisper_mel_calc * mel_calc = nullptr; - whisper_state * state = nullptr; ggml_backend_t backend = nullptr; @@ -893,7 +892,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) { BYTESWAP_VALUE(dest); } -static bool kv_cache_init( +static bool whisper_kv_cache_init( struct whisper_kv_cache & cache, ggml_backend_t backend, ggml_type wtype, @@ -936,7 +935,7 @@ static bool kv_cache_init( return true; } -static void kv_cache_free(struct whisper_kv_cache & cache) { +static void whisper_kv_cache_free(struct whisper_kv_cache & cache) { ggml_free(cache.ctx); ggml_backend_buffer_free(cache.buffer); cache.ctx = nullptr; @@ -1250,9 +1249,12 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params } #endif + GGML_UNUSED(params); + if (backend_gpu) { return backend_gpu; } + return ggml_backend_cpu_init(); } @@ -2885,52 +2887,25 @@ struct whisper_global_cache { // Mel spectrogram -whisper_mel::~whisper_mel() { - reset(); -} - -whisper_mel::whisper_mel(whisper_mel && other) noexcept { - take(other); -} - -whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept { - if (this != &other) { - reset(); - take(other); - } - return *this; -} - -void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) { - this->n_len_org = n_len_org; - assert(!ctx); - ctx = ggml_init({ggml_tensor_overhead(), nullptr, true}); - tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel); - buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend)); - auto alloc = ggml_tallocr_new(buffer); - ggml_tallocr_alloc(&alloc, tensor); -} - -void whisper_mel::reset() { - ggml_free(ctx); - ggml_backend_buffer_free(buffer); - - n_len_org = 0; - tensor = nullptr; - ctx = nullptr; - buffer = nullptr; +void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel) { + WHISPER_LOG_INFO("%s: n_len = %d, n_len_org = %d, n_mel = %d\n", __func__, n_len, n_len_org, n_mel); + mel.n_len_org = n_len_org; + assert(!mel.ctx); + mel.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true}); + mel.tensor = ggml_new_tensor_2d(mel.ctx, GGML_TYPE_F32, n_len, n_mel); + mel.buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(mel.tensor) + ggml_backend_get_alignment(backend)); + auto alloc = ggml_tallocr_new(mel.buffer); + ggml_tallocr_alloc(&alloc, mel.tensor); } -void whisper_mel::take(whisper_mel & other) noexcept { - n_len_org = other.n_len_org; - tensor = other.tensor; - ctx = other.ctx; - buffer = other.buffer; +void whisper_mel_free(whisper_mel & mel) { + ggml_free(mel.ctx); + ggml_backend_buffer_free(mel.buffer); - other.n_len_org = 0; - other.tensor = nullptr; - other.ctx = nullptr; - other.buffer = nullptr; + mel.n_len_org = 0; + mel.ctx = nullptr; + mel.tensor = nullptr; + mel.buffer = nullptr; } whisper_mel_calc::~whisper_mel_calc() = default; // export vtable @@ -3026,7 +3001,7 @@ struct whisper_mel_data { int n_len; int n_len_org; int n_mel; - float* data; + float * data; }; void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, @@ -3100,7 +3075,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v struct mel_calc_cpu : public whisper_mel_calc { ggml_backend_t m_backend; - const whisper_filters& m_filters; + const whisper_filters & m_filters; mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {} // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 @@ -3137,7 +3112,7 @@ struct mel_calc_cpu : public whisper_mel_calc { std::vector host_mel_data; whisper_mel ret; - ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel); + whisper_mel_init(ret, m_backend, mel.n_len, mel.n_len_org, mel.n_mel); if (ggml_backend_buffer_is_host(ret.buffer)) { mel.data = reinterpret_cast(ret.tensor->data); } else { @@ -3325,15 +3300,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { return nullptr; } + state->mel_calc = whisper_mel_calc_create(state->backend, ctx->model.filters); + // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx // in theory, there can be a case where this is not enough, but in practice it should always be enough const int factor = 3; - if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype, + if (!whisper_kv_cache_init(state->kv_self, state->backend, ctx->itype, ctx->model.hparams.n_text_state, ctx->model.hparams.n_text_layer, GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) { - WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); + WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__); whisper_free_state(state); return nullptr; } @@ -3343,11 +3320,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6); } - if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype, + if (!whisper_kv_cache_init(state->kv_cross, state->backend, ctx->itype, ctx->model.hparams.n_text_state, ctx->model.hparams.n_text_layer, GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { - WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); + WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__); whisper_free_state(state); return nullptr; } @@ -3357,11 +3334,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); } - if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype, + if (!whisper_kv_cache_init(state->kv_pad, state->backend, ctx->itype, ctx->model.hparams.n_audio_state, 1, GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { - WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); + WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__); whisper_free_state(state); return nullptr; } @@ -3373,7 +3350,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // [EXPERIMENTAL] Token-level timestamps with DTW if (ctx->params.dtw_token_timestamps) { - if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) { + if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backend)) { WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__); whisper_free_state(state); return nullptr; @@ -3416,7 +3393,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // conv allocator { - bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend, + bool ok = whisper_allocr_graph_init(state->alloc_conv, state->backend, [&]() { return whisper_build_graph_conv(*ctx, *state, 0); }); @@ -3432,7 +3409,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // encoder allocator if (!whisper_encode_external(*state)) { - bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend, + bool ok = whisper_allocr_graph_init(state->alloc_encode, state->backend, [&]() { return whisper_build_graph_encoder(*ctx, *state); }); @@ -3448,7 +3425,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // cross allocator { - bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend, + bool ok = whisper_allocr_graph_init(state->alloc_cross, state->backend, [&]() { return whisper_build_graph_cross(*ctx, *state); }); @@ -3464,7 +3441,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // decoder allocator { - bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend, + bool ok = whisper_allocr_graph_init(state->alloc_decode, state->backend, [&]() { const auto & hparams = ctx->model.hparams; @@ -3660,8 +3637,6 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_ return nullptr; } - ctx->mel_calc = whisper_mel_calc_create(ctx->backend, ctx->model.filters); - loader->close(loader->context); return ctx; @@ -3738,9 +3713,14 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa void whisper_free_state(struct whisper_state * state) { if (state) { - kv_cache_free(state->kv_self); - kv_cache_free(state->kv_cross); - kv_cache_free(state->kv_pad); + whisper_kv_cache_free(state->kv_self); + whisper_kv_cache_free(state->kv_cross); + whisper_kv_cache_free(state->kv_pad); + + whisper_mel_free(state->mel); + + delete state->mel_calc; + state->mel_calc = nullptr; #ifdef WHISPER_USE_COREML if (state->ctx_coreml != nullptr) { @@ -3782,8 +3762,6 @@ void whisper_free(struct whisper_context * ctx) { ggml_backend_free(ctx->backend); - delete ctx->mel_calc; - ctx->mel_calc = nullptr; delete ctx; } } @@ -3800,9 +3778,11 @@ void whisper_free_params(struct whisper_full_params * params) { } } -int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { +int whisper_pcm_to_mel_with_state(struct whisper_context * /*ctx*/, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { const int64_t t_start_us = ggml_time_us(); - state->mel = ctx->mel_calc->calculate({samples, n_samples}, n_threads); + + state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads); + state->t_mel_us += ggml_time_us() - t_start_us; // Dump log_mel_spectrogram @@ -3834,8 +3814,9 @@ int whisper_set_mel_with_state( return -1; } - state->mel.reset(); - state->mel.init(ctx->backend, n_len, n_len, n_mel); + whisper_mel_free(state->mel); + whisper_mel_init(state->mel, ctx->backend, n_len, n_len, n_mel); + ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor)); return 0;