From bf0d01fe498c2957375ecedccbc3bb9b5c1f5b54 Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Tue, 6 Aug 2024 07:22:42 +0200
Subject: [PATCH] docs: update self-extend RoPE notes

---
 notes/llama-main.md | 97 +++++++++++++++++++++++----------------------
 1 file changed, 49 insertions(+), 48 deletions(-)

diff --git a/notes/llama-main.md b/notes/llama-main.md
index b01fae4f..d07b3892 100644
--- a/notes/llama-main.md
+++ b/notes/llama-main.md
@@ -419,8 +419,8 @@ of groups we divide `ga_w` by `ga_n` which is 2. So we will have 2 groups.
 const int ib = (ga_n * ga_i) / ga_w;
 ```
 Now, `ga_i` is the index specifying which group we are currently processing.
-This will be incremented with the group size which is `ga_w/ga_n` which is 2 in
-this case. So the first iteration `ga_i` will be 0, 2, 4, 6, 8.
+This will be incremented with the group size which is `ga_w/ga_n`, 2 in this
+case. So over the iterations `ga_i` will take the values 0, 2, 4, 6, 8.
 ```
 ib = (ga_n * ga_i) / ga_w; // index base
 ib = (2 * 0) / 4;
@@ -466,7 +466,6 @@ dd = (4 / 2) - 0*2 - 4;
 dd = 2 - 0 - 4;
 dd = -2
 ```
-_wip_
 Notice that `ga_w / ga_n` gives us the tokens per group. So `ga_w` is the total
 number of tokens used for group attention.
 
@@ -574,7 +573,7 @@ static void llama_kv_cache_seq_div(
 }
 ```
 The cache size is 8192 in this case we are going to iterate over all of them.
-Notice that `cache.cells[i].has_seq_id(seq_id)` checks is a cell has the
+Notice that `cache.cells[i].has_seq_id(seq_id)` checks if a cell has the
 passed-in sequence id.
 
 `cache.cells[i].pos >= p0 && cache.cells[i].pos < p1` checks that the cell's
@@ -743,7 +742,7 @@ Back in the main while loop we then have:
 ```
 Now, this is interesting `n_past` is currently 5 and `bd` is 2 so `n_past` will
 be updated to 3. So instead of having the position of the next token be 5 it
-has become 3.
+has become 3 (which is the next position, following the last pos which was 2).
 
 And `ga_i` will be updated to become 2 (the group size)
 
@@ -812,7 +811,7 @@ $91 = {pos = 1, delta = 0, src = 0, seq_id = std::set with 1 element = {[0] = 0}
 5 = {pos = 7, delta = 2, src = 0, seq_id = std::set with 1 element = {[0] = 0}}
 ```
 Is this done because the positions were added using `n_past` which was 3 and
-then incremented, to the new cells to the positions 3, 4, 5. This add is
+then incremented, the new cells have the positions 3, 4, 5. This add is
 adjusting them to be in the order prior to the adjustment of `n_past`.
 Notice that they are now incremental from the first position, But not in the
 grouping but that will be handled by the next division operation:
@@ -835,7 +834,7 @@ from the addition call.
 }
 }
 ```
-When i = 4 we will enter the block and set `has_shift`. 
+When i = 4 we will enter the block and set `has_shift`.
 ```console
 (gdb) p cache.cells[i]
 $100 = {pos = 4, delta = 2, src = 0, seq_id = std::set with 1 element = {[0] = 0}}
@@ -1019,9 +1018,18 @@ op_params = {0 }, flags = 0, grad = 0x0, src = {
 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, view_src = 0x0, view_offs = 0,
 data = 0x0, name = '\000' , extra = 0x0}
 ```
-Next we iterate over all the layer (42).
-Looking at the above code I cant see that the `kv_cache` is used.
-After that `llama_set_k_shift is called which sets the _delta_ from the cache
+Next we iterate over all the layers (42), and for each layer we will create
+a ggml operation using `ggml_rope_ext_inplace`. This will add an operation for
+rope which will update the target tensors which are the `kv_self.k_l[il]`
+tensors.
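+
+To tie the `ga_n`/`ga_w` arithmetic above together, the following is a small
+standalone sketch of one iteration of the self-extend bookkeeping. This is not
+the actual `examples/main/main.cpp` code, and the `bd` formula is my reading
+of the surrounding logic; the values ga_n = 2, ga_w = 4 and n_past = 5 are the
+ones from the walkthrough above:
+```c
+#include <stdio.h>
+
+int main(void) {
+    int ga_n   = 2; // group attention factor
+    int ga_w   = 4; // group attention width
+    int ga_i   = 0; // start index of the current group
+    int n_past = 5; // position of the next token
+
+    while (n_past >= ga_i + ga_w) {
+        int ib = (ga_n * ga_i) / ga_w;           // index base
+        int bd = (ga_w / ga_n) * (ga_n - 1);     // block delta
+        int dd = (ga_w / ga_n) - ib * bd - ga_w; // delta for the tail
+
+        printf("ga_i=%d ib=%d bd=%d dd=%d -> n_past=%d\n",
+               ga_i, ib, bd, dd, n_past - bd);
+
+        // In main.cpp these values drive the llama_kv_cache_seq_add/seq_div
+        // calls, which record the position deltas in the cache cells.
+        n_past -= bd;
+        ga_i   += ga_w / ga_n;
+    }
+    return 0;
+}
+```
+This prints `ga_i=0 ib=0 bd=2 dd=-2 -> n_past=3`, matching the values above.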
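+
+To see why the rope operations added above for `kv_self.k_l[il]` can take the
+_deltas_ as their positions, here is a minimal scalar sketch (plain C, not the
+ggml kernels) of a single RoPE dimension pair. Rotations compose, so
+re-rotating a key that was stored at position 5 by the angle for -2 gives the
+same result as encoding it at position 3 directly (the 5 -> 3 shift from the
+walkthrough above):
+```c
+#include <math.h>
+#include <stdio.h>
+
+// Rotate one (x0, x1) pair by pos * theta, the basic RoPE rotation for a
+// single dimension pair.
+static void rope_pair(float pos, float theta, const float in[2], float out[2]) {
+    float c = cosf(pos * theta);
+    float s = sinf(pos * theta);
+    out[0] = in[0] * c - in[1] * s;
+    out[1] = in[0] * s + in[1] * c;
+}
+
+int main(void) {
+    float theta = 0.1f;          // frequency for this dimension pair
+    float k[2]  = {0.8f, -0.3f}; // un-rotated key values
+
+    // The key as it sits in the kv-cache: encoded at its old position 5.
+    float cached[2];
+    rope_pair(5.0f, theta, k, cached);
+
+    // K-shift: re-rotate the cached value by the cell's delta (-2).
+    float shifted[2];
+    rope_pair(-2.0f, theta, cached, shifted);
+
+    // Reference: encode the original key directly at the new position 3.
+    float direct[2];
+    rope_pair(3.0f, theta, k, direct);
+
+    printf("shifted: %f %f\n", shifted[0], shifted[1]);
+    printf("direct : %f %f\n", direct[0], direct[1]);
+    return 0;
+}
+```
+The two printed pairs are identical (up to floating point error), which is why
+recording only the deltas and re-roping the cache in place is enough.
+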
+And notice that the position tensor, the third argument to
+`ggml_rope_ext_inplace`, is the `lctx.inp_K_shift` tensor which is the tensor
+that was updated above with the delta values from the cache cells. So this is
+how the kv-cache interacts with rope I think: the cell positions have been
+updated, the cached keys are re-rotated using the deltas, and then a normal
+rope operation will follow. For details about rope see [ggml-rope.md](ggml-rope.md).
+
+
+After that `llama_set_k_shift` is called which sets the _delta_ from the cache
 cells on the tensor:
 
 ```c
@@ -1037,6 +1045,32 @@ static void llama_set_k_shift(llama_context & lctx) {
     }
 }
 ```
+So this is updating the `inp_K_shift` tensor with the delta values from the
+cache cells. Later in `llama_decode_internal` we have:
+```c
+static int llama_decode_internal(
+         llama_context & lctx,
+           llama_batch batch_all) { // TODO: rename back to batch
+    ...
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
+
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*n_tokens) {
+            kv_self.head = 0;
+        }
+
+        if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+            return 1;
+        }
+    ...
+    }
+```
+Notice the call to `llama_kv_cache_update`.
+
 ```console
 $202 = {pos = 0, delta = 0, src = 0, seq_id = std::set with 1 element = {[0] = 0}}
 $203 = {pos = 0, delta = -1, src = 0, seq_id = std::set with 1 element = {[0] = 0}}
@@ -1088,45 +1122,12 @@ struct ggml_tensor * ggml_rope_ext_inplace(
         float                 beta_slow) {
 ```
-So this is setting the values of the tensor `inp_K_shift`.
+So this is rotating the tensor `a` (`kv_self.k_l[il]`) using the positions in
+`b` (`lctx.inp_K_shift`). And recall that `lctx.inp_K_shift` holds the delta
+values from the cache cells.
 
-Later when the computation is done this call will end up in:
-```c
-static void ggml_compute_forward_rope_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const bool forward) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-```
-```console
-(gdb) p *src1
-$220 = {type = GGML_TYPE_I32, backend = GGML_BACKEND_TYPE_CPU, buffer = 0x555555a90660, ne = {8192, 1, 1, 1},
- nb = {4, 32768, 32768, 32768}, op = GGML_OP_NONE, op_params = {0 }, flags = 1, grad = 0x0,
- src = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, view_src = 0x0, view_offs = 0,
- data = 0x7ffadefa0020, name = "leaf_1", '\000' , extra = 0x0}
-```
-```c
-    const int32_t * pos = (const int32_t *) src1->data;
-```
-We can inspect these values to see that they are infact the delta values:
-```console
-(gdb) p *pos
-$237 = 0
-(gdb) p *(pos+1)
-$238 = -1
-(gdb) p *(pos+2)
-$239 = -1
-(gdb) p *(pos+3)
-$240 = -2
-(gdb) p *(pos+4)
-$241 = -2
-```
-I need to understand YaRN to really understand what is happing in the rope
-function. But since we are passing in the delta values from the cache cells
-I think these will be used to adjust the positions in the computation in
-someway. TODO: Read the YaRN paper and try to understand this properly.
+And after this a normal RoPE operation will be done later in the forward pass
+using the positions (which have been adjusted).
 
 Testing this:
 
 So the basic idea is that we have a model which was trained on a certain