docs: add more notes about the llama-kv-cache.md (wip)
danbev committed Oct 29, 2024
1 parent 0bd71eb commit 9bf67f4
notes/llama-kv-cache.md (198 additions, 0 deletions)

```


TODO: Add this information to the above section that walks through this part
of the code (but it does so for the initial prompt).

For a single token decode, the matrix multiplication between Q and K looks like
this:
```c++
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
```
Now, if we inspect the shapes of the Q and K tensors (and the mask):
```console
(gdb) p q->ne
$2 = {128, 1, 32, 1}
(gdb) p k->ne
$1 = {128, 32, 32, 1}
(gdb) p kq_mask->ne
$6 = {32, 32, 1, 1}
```
Recall that the embedding dimension is 4096 and the number of heads is 32. So
each head will have 128 dimensions/features.

We can visualize this roughly like this:
```
Query (q) Key(k)
z0 z_0
0 [0 ... 127] 0 [0...31]
...
...
...
127 [0...31]
... ...
... ...
... ...
z_31 z_31
0 [0 ... 127] 0 [0...31]
...
...
...
127 [0...31]
```
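
To see how these shapes combine, here is a minimal standalone sketch (my own, not
code from llama.cpp) that creates tensors with the shapes above using the ggml API
and lets ggml derive the shape of `ggml_mul_mat(ctx, k, q)`. Recall that the
per-head dimension is 4096/32 = 128:
```c++
#include "ggml.h"

#include <cstdio>

int main() {
    // no_alloc = true: we only create tensor metadata, no data buffers
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    // q: {128, 1, 32}  -> 128 dims, 1 new token,      32 heads
    // k: {128, 32, 32} -> 128 dims, 32 cached tokens, 32 heads
    struct ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128,  1, 32);
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 32, 32);

    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

    // expected: {32, 1, 32, 1} -> 32 scores (one per cached token) for the
    // single new token, per head
    printf("kq->ne = {%lld, %lld, %lld, %lld}\n",
           (long long) kq->ne[0], (long long) kq->ne[1],
           (long long) kq->ne[2], (long long) kq->ne[3]);

    ggml_free(ctx);
    return 0;
}
```
So per head we get a single row of 32 attention scores for the new token, one
score per cached token, and this is what the mask below operates on.
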
Notice that the Key matrix has 32 rows (tokens), which in this case contain the
previously roped k values from the cache. My understanding is that the values
in this matrix might not all belong to the same sequence as the current token.
This is where the mask comes in: it will mask out those values when the softmax
is calculated. This works because, when the matrix multiplication is done, the
dot products against the roped k values that are not part of the current
sequence are independent of the rest. What I mean is, if we have the following
simplified example:
```
q = [1, 2, 3]

k = [1, 2, 3] (seq0)        [1 2 3]
    [4, 5, 6] (seq1)        [4 5 6] x [1 2 3] = [14 32 50]
    [7, 8, 9] (seq1)        [7 8 9]

mask = [-inf 0 0]

[14 32 50] + [-inf 0 0] = [-inf 32 50]
```
The dot product for seq0 is still computed, but since the mask adds -inf to its
score it contributes nothing once the softmax is applied, so it is effectively
removed from the attention calculation.
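
To make this concrete, the following is a small standalone sketch (plain C++, not
the llama.cpp code path) of the simplified example above: the dot product against
the seq0 row is still computed, but adding -inf before the softmax drives its
weight to zero:
```c++
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> q = {1, 2, 3};
    std::vector<std::vector<float>> k = {
        {1, 2, 3},   // seq0 (not the current token's sequence)
        {4, 5, 6},   // seq1
        {7, 8, 9},   // seq1
    };
    std::vector<float> mask = {-INFINITY, 0.0f, 0.0f};

    // scores[i] = q . k[i] + mask[i]  (the seq0 dot product is still computed)
    std::vector<float> scores(k.size());
    for (size_t i = 0; i < k.size(); ++i) {
        float dot = 0.0f;
        for (size_t d = 0; d < q.size(); ++d) {
            dot += q[d] * k[i][d];
        }
        scores[i] = dot + mask[i];
    }

    // softmax: exp(-inf) == 0, so the masked entry gets zero weight
    float max_s = -INFINITY;
    for (float s : scores) max_s = std::max(max_s, s);

    float sum = 0.0f;
    std::vector<float> p(scores.size());
    for (size_t i = 0; i < scores.size(); ++i) {
        p[i]  = std::exp(scores[i] - max_s);
        sum  += p[i];
    }
    for (size_t i = 0; i < p.size(); ++i) {
        printf("p[%zu] = %f\n", i, p[i] / sum);   // p[0] is 0.0
    }
    return 0;
}
```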

My initial reaction to this was that it seemed wasteful to compute the dot
product for the cached key values that don't belong to the current token's
sequence. But modern CPUs and GPUs are optimized for these types of dense
matrix operations; filtering or breaking up the computation would require
irregular memory access patterns, which would be less efficient. It sounds like
this is a common pattern in transformer implementations: compute everything and
then mask.

This part of the code is just creating the operation; we don't have any actual
values yet. Setting the actual input values is done in the `llama_set_inputs`
function, which we discussed above before I derailed the discussion to this
part. This will happen from time to time in order to connect the tensors with
the actual data.
```c++
static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
    ...
    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
        if (cparams.causal_attn && !lctx.is_encoding) {
            const int64_t n_kv         = kv_self.n;
            const int64_t n_tokens     = ubatch.n_tokens;
            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
            const int64_t n_seqs       = ubatch.n_seqs;

            float * data = nullptr;

            if (lctx.inp_KQ_mask) {
                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
                data = (float *) lctx.inp_KQ_mask->data;
            }
```
```console
(gdb) p lctx->inp_KQ_mask->ne
$8 = {32, 32, 1, 1}
(gdb) p seq_id
$10 = 1
(gdb) p n_seq_tokens
$11 = 1
(gdb) p ubatch.pos[0]
$16 = 13
(gdb) p n_kv
$17 = 32
```
```c++
            for (int h = 0; h < 1; ++h) {
                for (int s = 0; s < n_seqs; ++s) {
                    const llama_seq_id seq_id = ubatch.seq_id[s][0];

                    for (int j = 0; j < n_seq_tokens; ++j) {
                        const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];

(1) ------------------> for (int i = 0; i < n_kv; ++i) {
                            float f;
                            if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
                                f = -INFINITY;
                            } else {
                                if (hparams.use_alibi) {
                                    f = -std::abs(kv_self.cells[i].pos - pos);
                                } else {
                                    f = 0.0f;
                                }
                            }

                            if (data) {
                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                            }

                            // may need to cut off old tokens for sliding window
                            if (data_swa) {
                                if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
                                    f = -INFINITY;
                                }
                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                            }
                        }
                    }
                }

(2) ----------> if (data) {
                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                        for (int j = 0; j < n_kv; ++j) {
                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                        }
                    }
                }
```
So we iterate over all 32 cells in the cache and mask out the ones that don't
belong to the current token's sequence (or whose position is after the current
token's position):
```console
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 31
+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
|-inf|-inf|-inf|-inf|-inf|-inf|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|-inf|-inf| ... |-inf|
+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+------------+
```
Notice that loop (1) only runs for a single token in this case (n_seq_tokens is
1), so it only fills in the first row of the mask. The second loop (2) then
starts from the row after the last actual token (i = n_tokens, which is 1 here)
and fills the remaining rows, up to the padded number of token rows
(GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), which is 32 here), with -INFINITY:
```console
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 31
+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
0 |-inf|-inf|-inf|-inf|-inf|-inf|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|-inf|-inf| ... |-inf|
+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+------------+
|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf| ... |-inf|
+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+------------+
```
And this continues for all 32 rows of the (padded) mask:
```console
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 31
+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+-------+----+
0 |-inf|-inf|-inf|-inf|-inf|-inf|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|0.0f|-inf|-inf| ... |-inf|
+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+------------+
|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf| ... |-inf|
+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+------------+
...
31 |-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf|-inf| ... |-inf|
+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+------------+
```
So now that we know what the mask looks like, we can revisit how it is used,
and for that we have to turn back to `llm_build_kqv`:
```c++
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
```
```console
(gdb) p q->ne
$2 = {128, 1, 32, 1}

(gdb) p k->ne
$1 = {128, 32, 32, 1}

(gdb) p kq_mask->ne
$6 = {32, 32, 1, 1}
```
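
Conceptually, when `f_max_alibi_bias` is 0 (no ALiBi), `ggml_soft_max_ext`
performs a row-wise `softmax(kq*kq_scale + kq_mask)`, where `kq_scale` is
typically 1/sqrt(n_embd_head). The following is a rough scalar sketch of that
math (not the ggml implementation), applied to the scores from the simplified
example above:
```c++
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// softmax(row*scale + mask) for a single row of attention scores
static void soft_max_row(std::vector<float> & row, const std::vector<float> & mask, float scale) {
    float max_v = -INFINITY;
    for (size_t i = 0; i < row.size(); ++i) {
        row[i] = row[i]*scale + mask[i];
        max_v  = std::max(max_v, row[i]);
    }
    float sum = 0.0f;
    for (float & v : row) { v = std::exp(v - max_v); sum += v; }
    for (float & v : row) { v /= sum; }
}

int main() {
    // scores and mask from the simplified example, scale = 1/sqrt(head_dim)
    std::vector<float> kq   = {14.0f, 32.0f, 50.0f};
    std::vector<float> mask = {-INFINITY, 0.0f, 0.0f};

    soft_max_row(kq, mask, 1.0f/std::sqrt(128.0f));

    for (float v : kq) printf("%f ", v);   // first (masked) value is 0.0
    printf("\n");
    return 0;
}
```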

<a name="wip"></a>
_wip_
