From e297d8b8b9e5a9aa9352047b988cc7d3d2c7117a Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Tue, 29 Oct 2024 12:42:20 +0100
Subject: [PATCH] docs: update llama-kv-cache.md

---
 notes/llama-kv-cache.md | 164 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 160 insertions(+), 4 deletions(-)

diff --git a/notes/llama-kv-cache.md b/notes/llama-kv-cache.md
index 4a35f1a..d3bcf71 100644
--- a/notes/llama-kv-cache.md
+++ b/notes/llama-kv-cache.md
@@ -2085,12 +2085,12 @@ So as we can expect and have seen before the result of the Q and K matrix is a s
 matrix, and recall that this is per layer we are seeing. So what this is doing is calculating
 the softmax of the logits in `kq` which, as we said, contains the dot product of the current
 token with all the cached Key values.
-In this case the first 6 tokens in the key cache belong to sequence 0, and the ones from 6-13 are
+In this case the first 6 tokens in the key cache belong to sequence 0, and the ones from 7-13 are
 the ones for sequence 1, which the current token belongs to:
 ```console
 kq                kq_mask
 z0      0    1    2    3    4    5    6    7    8    9    10   11   12   13  14  ... 31]
-   [0 ... 31]    [-inf -inf -inf -inf -inf -inf 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  -inf -inf]
+   [0 ... 31]    [-inf -inf -inf -inf -inf -inf -inf 0.0  0.0  0.0  0.0  0.0  0.0  0.0  -inf -inf]
    [-inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf ...-inf]
    ...
 31 [-inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf ...-inf]
@@ -2098,14 +2098,170 @@ z0      0    1    2    3    4    5    6    7    8    9    10   11   1
 ...
 z31
-   [0 ... 31]  0 [-inf -inf -inf -inf -inf -inf 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  -inf -inf]
+   [0 ... 31]  0 [-inf -inf -inf -inf -inf -inf -inf 0.0  0.0  0.0  0.0  0.0  0.0  0.0  -inf -inf]
    [-inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf ...-inf]
    ...
 31 [-inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf ...-inf]
 ```
 So this will include only the logits that belong to the current token's sequence.
-So get a feel for how this works there is a standalone example in
+To get a feel for how this works there is a standalone example in
 [llama-att-softmax.c](../ggml/src/llama-att-softmax.c)
+
+### Connection/link between cells and cache tensors
+This might seem obvious but this was not clear to me initially, which is why I'm adding this
+section. The `cells` member of `llama_kv_cache` is a vector of `llama_kv_cell`:
+```c++
+struct llama_kv_cache {
+    ...
+    std::vector<llama_kv_cell> cells;
+    ...
+};
+
+struct llama_kv_cell {
+    llama_pos pos   = -1;
+    llama_pos delta = 0;
+    int32_t   src   = -1; // used by recurrent state models to copy states
+    int32_t   tail  = -1;
+
+    std::set<llama_seq_id> seq_id;
+    ...
+};
+```
+As tokens are decoded they are added one by one to the cells vector. In the previous example,
+where we had two prompts with different sequences 0 and 1, they would populate the first 13
+positions in the cells vector. As we saw in the previous section, `llama_set_inputs` iterates
+through the cells and checks if the current token's sequence id is in that cell.
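+
+To make this connection concrete, below is a minimal sketch of how one row of the
+`kq_mask` could be derived from the cells. This is not the actual llama.cpp
+implementation: the `Cell` struct and `build_mask_row` function are simplified,
+hypothetical stand-ins for `llama_kv_cell` and the logic in `llama_set_inputs`. A
+slot is left unmasked (0.0) only if its cell holds the current token's sequence id
+and a position not after the current token's position:
+```c++
+#include <cmath>
+#include <cstdio>
+#include <set>
+#include <vector>
+
+// Simplified stand-in for llama_kv_cell (hypothetical, for illustration only).
+struct Cell {
+    int pos = -1;         // position of the cached token, -1 means the cell is empty
+    std::set<int> seq_id; // sequence ids this cell belongs to
+};
+
+// Build one kq_mask row for a token at position `pos` in sequence `seq`:
+// 0.0f lets the logit through, -INFINITY removes it from the softmax.
+std::vector<float> build_mask_row(const std::vector<Cell>& cells, int seq, int pos) {
+    std::vector<float> row(cells.size(), -INFINITY);
+    for (size_t i = 0; i < cells.size(); ++i) {
+        // Only attend to cells of the same sequence, up to the current position (causal).
+        if (cells[i].seq_id.count(seq) > 0 && cells[i].pos <= pos) {
+            row[i] = 0.0f;
+        }
+    }
+    return row;
+}
+
+int main() {
+    // Mirror the dump below: cells 0-5 hold sequence 0, cells 6-12 hold sequence 1.
+    std::vector<Cell> cells(16);
+    for (int i = 0; i < 6;  i++) { cells[i].pos = i; cells[i].seq_id.insert(0); }
+    for (int i = 6; i < 13; i++) { cells[i].pos = i; cells[i].seq_id.insert(1); }
+
+    // Mask row for the last token of sequence 1 (position 12 in the dump below).
+    for (float f : build_mask_row(cells, 1, 12)) {
+        printf("%s ", std::isinf(f) ? "-inf" : "0.0");
+    }
+    printf("\n");
+}
+```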
+
+If we inspect the cache we can see that the first 6 cells belong to sequence 0 and the
+next 7 belong to sequence 1:
+```console
+(lldb) p ctx->kv_self->cells
+(std::vector<llama_kv_cell>) size=1024 {
+  [0] = {
+    pos = 0
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 0
+    }
+  }
+  [1] = {
+    pos = 1
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 0
+    }
+  }
+  [2] = {
+    pos = 2
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 0
+    }
+  }
+  [3] = {
+    pos = 3
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 0
+    }
+  }
+  [4] = {
+    pos = 4
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 0
+    }
+  }
+  [5] = {
+    pos = 5
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 0
+    }
+  }
+  [6] = {
+    pos = 6
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 1
+    }
+  }
+  [7] = {
+    pos = 7
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 1
+    }
+  }
+  [8] = {
+    pos = 8
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 1
+    }
+  }
+  [9] = {
+    pos = 9
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 1
+    }
+  }
+  [10] = {
+    pos = 10
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 1
+    }
+  }
+  [11] = {
+    pos = 11
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 1
+    }
+  }
+  [12] = {
+    pos = 12
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=1 {
+      [0] = 1
+    }
+  }
+  [13] = {
+    pos = -1
+    delta = 0
+    src = -1
+    tail = -1
+    seq_id = size=0 {}
+  }
+  ...
+```
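+
+Since the mask is applied inside a softmax, here is a small sketch of what that last
+step does with one `kq` row. It is in the same spirit as the llama-att-softmax.c
+example mentioned earlier, but uses plain C++ floats instead of ggml, and the logit
+values are made up for illustration: slots masked with -inf contribute exp(-inf) = 0
+to the sum and end up with attention weight 0, so only the current sequence's cached
+tokens receive any weight:
+```c++
+#include <algorithm>
+#include <cmath>
+#include <cstdio>
+#include <vector>
+
+// Softmax over one row of kq after adding the mask to the logits.
+std::vector<float> masked_softmax(std::vector<float> logits, const std::vector<float>& mask) {
+    float max_val = -INFINITY;
+    for (size_t i = 0; i < logits.size(); i++) {
+        logits[i] += mask[i];
+        max_val = std::max(max_val, logits[i]);
+    }
+    float sum = 0.0f;
+    for (float& l : logits) {
+        l = expf(l - max_val); // subtract the max for numerical stability
+        sum += l;
+    }
+    for (float& l : logits) {
+        l /= sum;
+    }
+    return logits;
+}
+
+int main() {
+    // One kq row: dot products of the current token with 16 cached keys.
+    std::vector<float> kq(16, 1.0f);
+    // Mask matching the cells above: only slots 6-12 (sequence 1) are visible.
+    std::vector<float> mask(16, -INFINITY);
+    for (int i = 6; i < 13; i++) {
+        mask[i] = 0.0f;
+    }
+    // Prints 0.000 for the masked slots and 0.143 (1/7) for the unmasked ones.
+    for (float p : masked_softmax(kq, mask)) {
+        printf("%.3f ", p);
+    }
+    printf("\n");
+}
+```

 _wip_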