From 9ed153c21d32d47a71d8fabb7d01cb523451b58d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 4 Nov 2024 06:03:09 +0100 Subject: [PATCH] docs: add more notes on the kv cache for recurrent models Work in progress. --- notes/llama-kv-cache.md | 376 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 369 insertions(+), 7 deletions(-) diff --git a/notes/llama-kv-cache.md b/notes/llama-kv-cache.md index 1a42b98..56279c5 100644 --- a/notes/llama-kv-cache.md +++ b/notes/llama-kv-cache.md @@ -2196,6 +2196,7 @@ The `llama_batch` that is passed to `llama_decode` looks like this: $5 = (const llama_batch &) @0x7fffffffd538: {n_tokens = 9, token = 0x555555a48f10, embd = 0x0, pos = 0x555555a8a9b0, n_seq_id = 0x555555a32c40, seq_id = 0x555555a33450, logits = 0x555555b071a0 ""} ``` +Notice that we have 9 tokens in this batch and 2 sequences. Let now look at `llama_kv_cache_init` which is pretty much the same as we discussed before but this time recurrent will be true: @@ -2204,18 +2205,164 @@ skipped earlier. (gdb) p cache.recurrent $11 = true ``` + +Next lets turn our attention to `llama_decode_internal`: +```c++ +static int llama_decode_internal( + llama_context & lctx, + llama_batch inp_batch) { + ... + // temporary allocate memory for the input batch if needed + llama_batch_allocr batch_allocr(lctx, inp_batch); + const llama_batch & batch = batch_allocr.batch; + const uint32_t n_tokens_all = batch.n_tokens; +```` +```console +(lldb) expr n_tokens_all +(const uint32_t) $0 = 9 +``` +There will be a check to make sure that all the tokens in the batch are in the +models vocabulary. + +Then we will set/update/initialize the sbatch member of `llama_context`, and notice +that `kv_self_recurrent` is used to set the `simple_split` parameter: +```c++ + lctx.sbatch.from_batch(batch, n_embd, + /* simple_split */ !kv_self.recurrent, + /* logits_all */ n_outputs == n_tokens_all); +``` +So in this case it is not a `simple_spit` which has been the case before when we went +through this. I actually think I've gone through this as well so I'll link to this +later. But lets take a look at the sbatch to get an overview: +```console +(lldb) p lctx.sbatch +(llama_sbatch) { + n_tokens = 9 + n_embd = 2048 + logits_all = false + ids = size=9 { + [0] = 0 + [1] = 1 + [2] = 2 + [3] = 3 + [4] = 4 + [5] = 5 + [6] = 6 + [7] = 7 + [8] = 8 + } + out_ids = size=0 {} + seq = size=2 { + [0] = { + n_seq_id = 1 + seq_id = 0x00006000015c5170 + offset = 0 + length = 5 + } + [1] = { + n_seq_id = 1 + seq_id = 0x00006000015c65f0 + offset = 5 + length = 4 + } + } + batch = 0x000000016fdfe630 + ubatch_token = size=0 {} + ubatch_embd = size=0 {} + ubatch_pos = size=0 {} + ubatch_n_seq_id = size=0 {} + ubatch_seq_id = size=0 {} + ubatch_output = size=0 {} +} +``` +And then there will be a while loop over the `n_tokens` in the sbatch: +```c++ + while (lctx.sbatch.n_tokens > 0) { + llama_ubatch ubatch; + if (kv_self.recurrent) { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = lctx.sbatch.split_seq(n_ubatch); + } else { + // recurrent model architectures are easier to implement + // with equal-length sequences + ubatch = lctx.sbatch.split_equal(n_ubatch); + } + } else { + ubatch = lctx.sbatch.split_simple(n_ubatch); + } + const uint32_t n_tokens = ubatch.n_tokens; +``` +In our case this will call `lctx.sbatch.split_equal`. Now it is good to keep in mind +that recurrent models differ from causul models in that they process tokens in sequence +where as a causal model can process all tokens in one go (using a causal mask): +``` + SSM (Mamba) Transformer + +--+ +--+ +--+ +--------------------+ + |t1| -> |t2| -> |t3| ... | t1 t2 t3 ... | + +--+ +--+ +--+ +--------------------+ +``` +After this the returned ubatch will look like this: +```console +(llama_ubatch) { + equal_seqs = true + n_tokens = 8 + n_seq_tokens = 4 + n_seqs = 2 + token = 0x00006000019c3180 + embd = 0x0000000000000000 + pos = 0x00006000019c3150 + n_seq_id = 0x00006000019c31e0 + seq_id = 0x00006000034dbf20 + output = 0x00006000015d0000 +} +``` +So this has split the batch into two equals sized ubatches (not that we have +9 tokens in total so there is one left over which I'll return to later). +We can see that this ubatch has two sequences, and that each sequence is of equal +size and the size of each is 4 tokens. +```console +(lldb) p ubatch.token[0] +(llama_token) 15961 +(lldb) p ubatch.token[1] +(llama_token) 14528 +(lldb) p ubatch.token[2] +(llama_token) 6749 +(lldb) p ubatch.token[3] +(llama_token) 10777 +(lldb) p ubatch.token[4] +(llama_token) 1276 +(lldb) p ubatch.token[5] +(llama_token) 310 +(lldb) p ubatch.token[6] +(llama_token) 9497 +(lldb) p ubatch.token[7] +(llama_token) 5214 + +(lldb) expr lctx.model.vocab.id_to_token[15961].text +(const llama_vocab::token) $2 = "Dan" +(lldb) expr lctx.model.vocab.id_to_token[14528].text +(const llama_vocab::token) $3 = "Ġloves" +(lldb) expr lctx.model.vocab.id_to_token[6749].text +(const llama_vocab::token) $4 = "Ġice" +(lldb) expr lctx.model.vocab.id_to_token[10777].text +(const llama_vocab::token) $5 = "Ġcream" +``` So now lets look at `llama_kv_cache_find_slot` and the recurrent block that we skipped earlier: ```c++ static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, const struct llama_ubatch & batch) { + const uint32_t n_tokens = batch.n_tokens; // 8 + const uint32_t n_seqs = batch.n_seqs; // 2 + const uint32_t n_seq_tokens = batch.n_seq_tokens; // 4 ... if (cache.recurrent) { // can only process batches with an equal number of new tokens in each sequence GGML_ASSERT(batch.equal_seqs); - int32_t min = cache.size - 1; + int32_t min = cache.size - 1; // cache.size = 2, so min = -1 int32_t max = 0; // everything should fit if all seq_ids are smaller than the max @@ -2247,18 +2394,233 @@ static bool llama_kv_cache_find_slot( } } } + } +``` +This is what the cache cells look like before any updated are made: +```console +(lldb) p cache.cells +(std::vector) size=2 { + [0] = { + pos = -1 + delta = 0 + src = -1 + tail = -1 + seq_id = size=0 {} + } + [1] = { + pos = -1 + delta = 0 + src = -1 + tail = -1 + seq_id = size=0 {} + } +} +``` +And this is what it looks like afterwards: +```console +(lldb) p cache.cells +(std::vector) size=2 { + [0] = { + pos = 8 + delta = 0 + src = -1 + tail = 1 + seq_id = size=1 { + [0] = 1 + } + } + [1] = { + pos = 3 + delta = 0 + src = -1 + tail = 0 + seq_id = size=1 { + [0] = 0 + } + } +} +``` +So for the cell 0, the position 8 is the last token for that sequence that will be processed, we +can verify this by taking the sequence id that this cells has, 1 and look it up in the batch: +```console +(lldb) expr lctx.sbatch.batch->seq_id[8][0] +(llama_seq_id) $18 = 1 +``` +For the second entry notice that pos is 3 as this is the last token that will be processed +for that sequence and that sequence still has one more token to be processed. + +Now, `tail` for cell 0 set set to 1 and for cell 1 it is set to 0. So this is a circular +kind of reference here. +I still don't understand what `tail` is so lets try to add a call to +`llama_kv_cache_seq_rm` to perhaps help undertand this: +```c++ + llama_kv_cache_seq_rm(ctx, 0, 0, 2); +``` +```c++ +static bool llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + ... + if (cache.recurrent) { + ... + int32_t & tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + const llama_kv_cell & cell = cache.cells[tail_id]; +``` +So the `seq_id` that we passed in, which is 0, is looked up in the cache cells, and that +cells tails will be set to `tail_id`: +```console +(lldb) p cache.cells[0].tail +(int32_t) 1 +``` +And this is creater than 0 so this will then look up the cell that tail points to: +```console +(lldb) p cell +(const llama_kv_cell &) 0x000060000198d2e8: { + pos = 4 + delta = 0 + src = 1 + tail = 0 + seq_id = size=1 { + [0] = 0 + } +} +``` +Next we check the range: +```c++ + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } +``` +So this is checking that p0 is positive and less than or equal to the position of the cell, +and similarly for p1. +```console +(lldb) p batch.seq_id[0][0] +(llama_seq_id) 0 +(lldb) p batch.seq_id[1][0] +(llama_seq_id) 0 +(lldb) p batch.seq_id[2][0] +(llama_seq_id) 0 +(lldb) p batch.seq_id[3][0] +(llama_seq_id) 0 +(lldb) p batch.seq_id[4][0] +(llama_seq_id) 0 +(lldb) p batch.seq_id[5][0] +(llama_seq_id) 1 +(lldb) p batch.seq_id[6][0] +(llama_seq_id) 1 +(lldb) p batch.seq_id[7][0] +(llama_seq_id) 1 +(lldb) p batch.seq_id[8][0] +(llama_seq_id) 1 +(lldb) p batch.seq_id[9][0] +(llama_seq_id) 0 +``` +So we can specify that we would like to remove the positions 0-5 for sequence 0 using: +```c++ + llama_kv_cache_seq_rm(ctx, 0, 0, 5); +``` + +```console +(lldb) p cache.cells +(std::vector) size=2 { + [0] = { + pos = 8 + delta = 0 + src = 0 + tail = 1 + seq_id = size=1 { + [0] = 1 } + } + [1] = { + pos = 4 + delta = 0 + src = 1 + tail = 0 + seq_id = size=1 { + [0] = 0 + } + } +} +``` +Now we can see that the first entry in the cache is not the first sequence in our batch +but that does not matter as the sequence id we pass in will look up the cell with index +of the sequence id but use it's tail. +```c++ + const int32_t tail_id = cache.cells[seq_id].tail; +``` +And this will be the index of the cell that contains the sequence: +```console +(const llama_kv_cell &) 0x000060000214dce8: { + pos = 4 + delta = 0 + src = 1 + tail = 0 + seq_id = size=1 { + [0] = 0 + } +} +``` +And what will happen is that the `tail_id`, which is a reference to the tail of the +cell that we looked up and this will set it to -1. + +If we look again as `llama_kv_cache_find_slot` we have the following code: +```c++ + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + llama_kv_cell & seq_meta = cache.cells[seq_id]; ``` +This will iterate over the number of sequences in the ubatch which is 2 in this case. So +the sequence id for 0 will be looked up in the ubatch and this will be 1 as that is +how the ubatch was created. With that we lookup the cell for that sequence. -The ubatch (`n_seq`) looks like this: +I think that using the tail is a way to enable things like speculative decoding and +allows for not having to copy the state. For example, lets say we have decoded a +sequence and our cache looks like it this: ```console -$16 = (const llama_ubatch &) @0x7fffffffd490: {equal_seqs = true, n_tokens = 8, n_seq_tokens = 4, n_seqs = 2, - token = 0x555555a137d0, embd = 0x0, pos = 0x555555a137a0, n_seq_id = 0x555555a13880, seq_id = 0x555555a74060, - output = 0x555555b073b0 ""} +(lldb) p kv_self.cells +(std::vector) size=2 { + [0] = { + pos = 8 + delta = 0 + src = -1 + tail = 1 + seq_id = size=1 { + [0] = 1 + } + } + [1] = { + pos = 3 + delta = 0 + src = -1 + tail = 0 + seq_id = size=1 { + [0] = 0 + } + } +} +``` +If we wanted to speculatively decode the next token for sequence 0 we could use the tail +```console + [2] = { + pos = 8 + delta = 0 + src = -1 + tail = 1 + seq_id = size=1 { + [0] = 2 + } + } ``` -TODO: Take another looks as the ubatch splitting/creation to understand how this -works for recurrent models as it seems to be different. +I'm still now 100% sure about this and I need to revist later. _wip_