TODO: Add this information to the above section that walks through this part
of the code (but it does so for the initial prompt).

For a single token decode the matrix multiplication between Q and K looks like
this:
```c++
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
```
Now, if we inspect the shapes of the K and Q tensors:
```console
(gdb) p q->ne
$2 = {128, 1, 32, 1}

(gdb) p k->ne
$1 = {128, 32, 32, 1}

(gdb) p kq_mask->ne
$6 = {32, 32, 1, 1}
```
Recall that the embedding dimension is 4096 and the number of heads is 32, so
each head gets 128 dimensions/features.

We can visualize this something like this:
```
       Query (q)                     Key (k)

z_0                            z_0
   0  [0 ... 127]                 0  [0 ... 127]
                                  1  [0 ... 127]
                                  ...
                                 31  [0 ... 127]
...                            ...
...                            ...
...                            ...

z_31                           z_31
   0  [0 ... 127]                 0  [0 ... 127]
                                  1  [0 ... 127]
                                  ...
                                 31  [0 ... 127]
```
Notice that the Key matrix has 32 rows (tokens) per head, which in this case
contain the previously roped k values stored in the cache. Now, my
understanding is that the values in this matrix might not all belong to the
same sequence as the current token. But this is where the mask comes in: it
will mask out those values when the softmax is calculated. This works because
when the matrix multiplication is done, the dot products against the roped
k values that are not part of the current sequence are independent of the
others. What I mean is, if we have the following simplified example:
```
 q = [1, 2, 3]

 k = [1, 2, 3] (seq0)      [1 2 3]
     [4, 5, 6] (seq1)      [4 5 6] x [1 2 3] = [14 32 50]
     [7, 8, 9] (seq1)      [7 8 9]

 mask = [0 1 1]
 [14 32 50] -> [32 50]
```
The computation for seq0 is still performed, but it will not be part of the
softmax calculation: the mask is applied and removes the value that belongs to
seq0.
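Just to convince myself that the masked entry really drops out of the softmax,
here is a minimal standalone sketch of the example above (plain C++, not
llama.cpp code; the q/k values and the masked entry are the ones from the
simplified example):
```c++
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Standalone sketch of the "compute everything, then mask" pattern:
// q belongs to seq1, the first cached k row belongs to seq0 and must not
// influence the result.
int main() {
    std::vector<float> q = {1, 2, 3};
    std::vector<std::vector<float>> k = {
        {1, 2, 3},   // seq0 (should be masked out)
        {4, 5, 6},   // seq1
        {7, 8, 9},   // seq1
    };
    // Additive mask: -inf for entries that do not belong to q's sequence.
    std::vector<float> mask = {-INFINITY, 0.0f, 0.0f};

    // The dot products are computed for every cached row, masked or not.
    std::vector<float> kq(k.size());
    for (size_t i = 0; i < k.size(); i++) {
        float dot = 0.0f;
        for (size_t d = 0; d < q.size(); d++) {
            dot += k[i][d] * q[d];
        }
        kq[i] = dot + mask[i];   // the mask is simply added before the softmax
    }

    // Softmax: exp(-inf) = 0, so the seq0 entry gets zero weight.
    float max = -INFINITY;
    for (float v : kq) max = std::max(max, v);
    float sum = 0.0f;
    std::vector<float> w(kq.size());
    for (size_t i = 0; i < kq.size(); i++) {
        w[i] = std::exp(kq[i] - max);
        sum += w[i];
    }
    for (size_t i = 0; i < w.size(); i++) {
        printf("weight[%zu] = %f\n", i, w[i] / sum);
    }
    return 0;
}
```
Running this prints a weight of 0 for the seq0 entry, and the two seq1 entries
get all of the probability mass.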
My initial reaction to this was that it seemed wasteful to compute the dot
product for the cached key values that don't belong to the current token's
sequence. But modern CPUs and GPUs are optimized for exactly these kinds of
dense matrix operations, and filtering or breaking up the computation would
require irregular memory access patterns, which would be less efficient. It
sounds like this is a common pattern in transformer implementations: compute
everything, then mask.

This part of the code is just creating the operation; we don't have any actual
values yet. The values are set in the `llama_set_inputs` function, which we
discussed above before I derailed the discussion to this part. Switching back
and forth like this will happen from time to time to connect the tensors with
the actual data.
```c++
static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
    ...
    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
        if (cparams.causal_attn && !lctx.is_encoding) {
            const int64_t n_kv         = kv_self.n;
            const int64_t n_tokens     = ubatch.n_tokens;
            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
            const int64_t n_seqs       = ubatch.n_seqs;

            float * data = nullptr;

            if (lctx.inp_KQ_mask) {
                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
                data = (float *) lctx.inp_KQ_mask->data;
            }
```
```console
(gdb) p lctx->inp_KQ_mask->ne
$8 = {32, 32, 1, 1}
(gdb) p seq_id
$10 = 1
(gdb) p n_seq_tokens
$11 = 1
(gdb) p ubatch.pos[0]
$16 = 13
(gdb) p n_kv
$17 = 32
```
```c++
            for (int h = 0; h < 1; ++h) {
                for (int s = 0; s < n_seqs; ++s) {
                    const llama_seq_id seq_id = ubatch.seq_id[s][0];

                    for (int j = 0; j < n_seq_tokens; ++j) {
                        const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];

(1) ----->              for (int i = 0; i < n_kv; ++i) {
                            float f;
                            if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
                                f = -INFINITY;
                            } else {
                                if (hparams.use_alibi) {
                                    f = -std::abs(kv_self.cells[i].pos - pos);
                                } else {
                                    f = 0.0f;
                                }
                            }

                            if (data) {
                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                            }

                            // may need to cut off old tokens for sliding window
                            if (data_swa) {
                                if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
                                    f = -INFINITY;
                                }
                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                            }
                        }
                    }
                }

(2) ----->      if (data) {
                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                        for (int j = 0; j < n_kv; ++j) {
                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                        }
                    }
                }
```
So loop (1) iterates over all 32 cells in the cache and masks out the ones that
don't belong to the current token's sequence (or that come after the current
position):
```console
      0    1    2    3    4    5    6    7    8    9   10   11   12   13   14   15        31
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
    |-inf|-inf|-inf|-inf|-inf|-inf|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|-inf|-inf|  ...  |-inf|
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
```
Notice that this inner loop (1) only produces a single row in this case, since
we are decoding a single token. The second loop (2) will then start at the row
after the first token and fill the remaining rows up to the padded number of
tokens, `GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)`, which is 32 in this case:
```console
      0    1    2    3    4    5    6    7    8    9   10   11   12   13   14   15        31
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
  0 |-inf|-inf|-inf|-inf|-inf|-inf|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|-inf|-inf|  ...  |-inf|
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
  1 |-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|  ...  |-inf|
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
```
And this continues until all 32 rows of the mask have been filled:
```console
      0    1    2    3    4    5    6    7    8    9   10   11   12   13   14   15        31
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
  0 |-inf|-inf|-inf|-inf|-inf|-inf|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|-inf|-inf|  ...  |-inf|
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
  1 |-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|  ...  |-inf|
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
    ...
 31 |-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|  ...  |-inf|
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
```
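To make loops (1) and (2) above a bit more concrete, here is a minimal
standalone sketch that fills in the same kind of mask for a toy cache. This is
not llama.cpp code: the `Cell` struct and the hard-coded positions/sequence ids
are made up to match the gdb session above, and the ALiBi and sliding-window
cases are left out:
```c++
#include <cmath>
#include <cstdio>
#include <vector>

// Minimal stand-in for a KV cache cell (hypothetical, just for this sketch).
struct Cell {
    int pos;
    int seq_id;
};

int main() {
    const int n_kv     = 32;  // cells currently in view of the cache
    const int n_tokens = 1;   // single token decode
    const int pad      = 32;  // stand-in for GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)

    // Pretend cells 6..13 belong to our sequence (seq 1) with positions 6..13,
    // and everything else belongs to another sequence. The token being decoded
    // has pos = 13 and seq_id = 1, matching the gdb session above.
    std::vector<Cell> cells(n_kv, Cell{0, 0});
    for (int i = 6; i <= 13; i++) {
        cells[i] = Cell{i, 1};
    }
    const int pos    = 13;
    const int seq_id = 1;

    std::vector<float> mask(pad * n_kv, 0.0f);

    // Loop (1): one row per token in the batch (just one row here).
    for (int j = 0; j < n_tokens; j++) {
        for (int i = 0; i < n_kv; i++) {
            const bool same_seq = cells[i].seq_id == seq_id;
            const bool causal   = cells[i].pos <= pos;
            mask[j*n_kv + i] = (same_seq && causal) ? 0.0f : -INFINITY;
        }
    }
    // Loop (2): the rows for the padded tokens are masked out completely.
    for (int j = n_tokens; j < pad; j++) {
        for (int i = 0; i < n_kv; i++) {
            mask[j*n_kv + i] = -INFINITY;
        }
    }

    // Print the first row, which should match the first diagram above.
    for (int i = 0; i < n_kv; i++) {
        printf("%s ", std::isinf(mask[i]) ? "-inf" : "0.0f");
    }
    printf("\n");
    return 0;
}
```
The printed row should show `-inf` for cells 0-5 and 14-31 and `0.0f` for cells
6-13, just like the first diagram.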
So now that we know what the mask looks like, we can revisit how it is used,
and for that we have to turn back to `llm_build_kqv`:
```c++
    kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
```
As a reminder, these are the shapes involved:
```console
(gdb) p q->ne
$2 = {128, 1, 32, 1}

(gdb) p k->ne
$1 = {128, 32, 32, 1}

(gdb) p kq_mask->ne
$6 = {32, 32, 1, 1}
```

_wip_
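My understanding of what happens here (an assumption on my part, I have not
stepped through `ggml_soft_max_ext` itself yet) is that for this single-token
decode `kq` holds one row of 32 logits per head, the single mask row is shared
(broadcast) across all 32 heads, and the logits are scaled by `kq_scale`
(presumably 1/sqrt(128), the head dimension) before the mask is added and the
softmax is taken. A standalone sketch of just that step, with dummy values:
```c++
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Sketch of a scaled, masked softmax for a single decoded token:
// kq is laid out as [n_head][n_kv], the mask row is shared by all heads.
// ALiBi (the max_bias parameter) is ignored here.
static void masked_softmax(std::vector<float> & kq,          // n_head * n_kv
                           const std::vector<float> & mask,  // n_kv
                           int n_head, int n_kv, float scale) {
    for (int h = 0; h < n_head; h++) {
        float * row = kq.data() + h*n_kv;

        // Scale, add the mask row, and track the maximum for a stable softmax.
        float max = -INFINITY;
        for (int i = 0; i < n_kv; i++) {
            row[i] = row[i]*scale + mask[i];
            max = std::max(max, row[i]);
        }
        float sum = 0.0f;
        for (int i = 0; i < n_kv; i++) {
            row[i] = std::exp(row[i] - max);
            sum += row[i];
        }
        for (int i = 0; i < n_kv; i++) {
            row[i] /= sum;   // masked entries end up exactly 0
        }
    }
}

int main() {
    const int n_head  = 32;
    const int n_kv    = 32;
    const float scale = 1.0f / std::sqrt(128.0f);  // 1/sqrt(head dim)

    std::vector<float> kq(n_head * n_kv, 1.0f);    // dummy kq logits
    std::vector<float> mask(n_kv, -INFINITY);
    for (int i = 6; i <= 13; i++) mask[i] = 0.0f;  // same mask row as above

    masked_softmax(kq, mask, n_head, n_kv, scale);
    printf("head 0, cell 0:  %f (masked)\n", kq[0]);
    printf("head 0, cell 13: %f\n", kq[13]);
    return 0;
}
```
With this, every head ends up with attention weights of exactly zero for the
masked cells, which is what the simplified single-row example earlier showed.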