docs: update llama-self-extend notes
danbev committed Oct 31, 2024
1 parent b3a57f9 commit 6e06feb
Showing 1 changed file with 105 additions and 30 deletions.
135 changes: 105 additions & 30 deletions notes/llama-self-extend.md
@@ -875,12 +875,12 @@ By setting `cache.has_shift` to true when `llama_decode_internal` calls
}
```

When the kv-cache `has_shift` is true, as in this case where we made the
updates discussed above in the self-extend code:
```c++
static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
    llama_ubatch dummy = {};
    dummy.equal_seqs = true;

    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };

@@ -896,9 +896,9 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
}
```
```c++
struct ggml_cgraph * build_k_shift() {
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
    GGML_ASSERT(kv_self.size == n_ctx);
@@ -907,18 +907,39 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
    ggml_set_input(lctx.inp_K_shift);

    for (int il = 0; il < n_layer; ++il) {
        const int64_t n_head_kv    = hparams.n_head_kv(il);
        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

        struct ggml_tensor * rope_factors = build_rope_factors(il);

        struct ggml_tensor * k =
            ggml_view_3d(ctx0, kv_self.k_l[il],
                n_embd_head_k, n_head_kv, n_ctx,
                ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                0);

        struct ggml_tensor * tmp;
        if (ggml_is_quantized(k->type)) {
            // dequantize to f32 -> RoPE -> quantize back
            tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
            cb(tmp, "K_f32", il);
            for (auto * backend : lctx.backends) {
                // Figure out which backend KV cache belongs to
                if (ggml_backend_supports_buft(backend, ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {
                    ggml_backend_sched_set_tensor_backend(lctx.sched, tmp, backend);
                    break;
                }
            }
            tmp = ggml_rope_ext_inplace(ctx0, tmp,
                lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
            cb(tmp, "K_shifted_f32", il);
            tmp = ggml_cpy(ctx0, tmp, k);
        } else {
            // we rotate only the first n_rot dimensions
            tmp = ggml_rope_ext_inplace(ctx0, k,
                lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
        }
        cb(tmp, "K_shifted", il);
        ggml_build_forward_expand(gf, tmp);
    }
@@ -938,15 +959,69 @@ op_params = {0 <repeats 16 times>}, flags = 0, grad = 0x0, src = {
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, view_src = 0x0, view_offs = 0, data = 0x0,
name = '\000' <repeats 63 times>, extra = 0x0}
```

Now, if we recall from [llama-kv-cache.md](./llama-kv-cache.md), for each layer
in the model there is a key tensor stored in the `kv_self.k_l` array, and the
values in these tensors are the roped key values that have been cached so far.
```console
(gdb) p n_head_kv
$3 = 32
(gdb) p n_embd_k_gqa
$4 = 4096
(gdb) p n_embd_head_k
$7 = 128
(gdb) p n_ctx
$8 = 8000
(gdb) p kv_self.k_l[il].ne
$6 = {32768000, 1, 1, 1}
```
```c++
struct ggml_tensor * k =
ggml_view_3d(ctx0, kv_self.k_l[il],
n_embd_head_k, n_head_kv, n_ctx,
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
0);
```
The value 32768000 can be thought of as 8000 rows (the context size), each with
a dimension of 4096, so one row for each entry in the cache; 4096 is the
embedding dimension size. This creates a 3d view of the tensor where we have 32
heads, each with an embedding dimension of 128, and we have 8000 of these.
Something like this:
```console
4096 * 8000 = 32768000
     /    \
128 * 32 * 8000 = 32768000

0    [0 ... 4095]
     ...
7999 [0 ... 4095]

z_0
     0  [0 ... 127]
        ...
     31 [0 ... 127]
...
z_7999
     0  [0 ... 127]
        ...
     31 [0 ... 127]

(gdb) p k->ne
$15 = {128, 32, 8000, 1}
```
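
To make the index arithmetic concrete, here is a small standalone sketch (plain
C++, not llama.cpp code; the dimensions are taken from the gdb session above)
that maps a (dimension, head, cache entry) coordinate onto the flat per-layer
`k_l` buffer, which is what the strides of the 3d view express (in bytes, via
`ggml_row_size`):
```c++
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Dimensions from the gdb session above.
    const int64_t n_embd_head_k = 128;   // embedding dimension per head
    const int64_t n_head_kv     = 32;    // number of kv heads
    const int64_t n_ctx         = 8000;  // cache entries (context size)

    const int64_t n_embd_k_gqa = n_embd_head_k * n_head_kv;   // 4096
    assert(n_embd_k_gqa * n_ctx == 32768000);                 // the flat k_l[il] size

    // Element index of dimension i, head h, cache entry c in the flat buffer.
    // This is what the 3d view [128, 32, 8000] addresses through its strides.
    auto idx = [&](int64_t i, int64_t h, int64_t c) {
        return c * n_embd_k_gqa + h * n_embd_head_k + i;
    };

    printf("%lld\n", (long long) idx(0,   0,  0));     // 0
    printf("%lld\n", (long long) idx(0,   1,  0));     // 128: next head
    printf("%lld\n", (long long) idx(0,   0,  1));     // 4096: next cache entry
    printf("%lld\n", (long long) idx(127, 31, 7999));  // 32767999: last element
    return 0;
}
```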

So this will iterate over all the layers (42 in this session), and for each
layer we will create a ggml operation using `ggml_rope_ext_inplace`. This will
add an operation for rope which will update the target tensors, which are the
roped `kv_self.k_l[il]` values. And notice that the position tensor, the third
argument to `ggml_rope_ext_inplace`, is the `lctx.inp_K_shift` tensor, which
will be updated with the delta values from the cache cells (see below). So this
is how the kv-cache interacts with rope I think: the positions will have been
updated to reflect the shifted positions, and then a normal rope operation will
follow. For details about rope see
[ggml-rope.md](position-embeddings/ggml-rope.md).
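
The reason re-applying rope with only the deltas is enough is that rope
rotations compose additively: rotating a dimension pair by `pos * theta` and
then by `delta * theta` gives the same result as rotating it once by
`(pos + delta) * theta`. A minimal numeric sketch (standalone C++, not
llama.cpp code, a single pair and a single frequency):
```c++
#include <cmath>
#include <cstdio>

// Rotate a 2-d pair (x0, x1) by pos * theta, which is what rope does to each
// pair of dimensions.
static void rotate(double & x0, double & x1, double pos, double theta) {
    const double c = std::cos(pos * theta);
    const double s = std::sin(pos * theta);
    const double r0 = x0 * c - x1 * s;
    const double r1 = x0 * s + x1 * c;
    x0 = r0;
    x1 = r1;
}

int main() {
    const double theta = 0.01;

    // Rotate at the original position 6, then shift by a delta of -2 ...
    double a0 = 1.0, a1 = 2.0;
    rotate(a0, a1, 6.0, theta);
    rotate(a0, a1, -2.0, theta);

    // ... which matches rotating once at the new position 4.
    double b0 = 1.0, b1 = 2.0;
    rotate(b0, b1, 4.0, theta);

    printf("%f %f vs %f %f\n", a0, a1, b0, b1);
    return 0;
}
```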


After that `llama_set_k_shift` is called which sets the _delta_ from the cache
@@ -965,8 +1040,8 @@ static void llama_set_k_shift(llama_context & lctx) {
}
}
```
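The body of `llama_set_k_shift` is collapsed in the hunk above, but what it
boils down to (a rough sketch with mock types and illustrative values, not the
verbatim implementation) is copying each cell's `delta` into the `inp_K_shift`
tensor, one `int32` slot per kv cell:
```c++
#include <cstdint>
#include <cstdio>
#include <vector>

// Mock of the relevant cell fields; the real struct lives in llama.cpp.
struct kv_cell { int32_t pos; int32_t delta; };

int main() {
    // Illustrative cells after a self-extend shift.
    std::vector<kv_cell> cells = { {0, 0}, {1, -2}, {2, -2} };

    // Stands in for the host buffer of lctx.inp_K_shift.
    std::vector<int32_t> inp_K_shift(cells.size());

    // The gist of llama_set_k_shift: one delta per cell.
    for (size_t i = 0; i < cells.size(); ++i) {
        inp_K_shift[i] = cells[i].delta;
    }

    for (int32_t d : inp_K_shift) printf("%d ", d);   // 0 -2 -2
    printf("\n");
    return 0;
}
```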
So this is updating the `inp_K_shift` tensor with the delta values from the
cache cells. Later in `llama_decode_internal` we have:
```c
static int llama_decode_internal(
llama_context & lctx,
@@ -999,7 +1074,7 @@ $205 = {pos = 1, delta = -2, src = 0, seq_id = std::set with 1 element = {[0] =
$206 = {pos = 2, delta = -2, src = 0, seq_id = std::set with 1 element = {[0] = 0}}
```
So I was wrong when I said that it does not look like the function
`build_k_shift` uses the `kv_cache`, because it does: it uses the tensor
`lctx.inp_K_shift`.

```c
@@ -1157,8 +1232,9 @@ the while loop to start by updating the range 0-64. Since this is the first
time through, the first shift does nothing, but it is needed later when `n_past`
is updated in this block. This update will cause the next decode to increment
the positions using the new self-extend aware cache cells. The division will
take the range of 0-256 and divide those cells' positions by 4, and update
their delta with the gt.
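
A standalone sketch (plain C++, not the actual `llama_kv_cache_seq_div`
implementation, and assuming the delta records the new position minus the old
one) of what the division does to cells in the range 0-256 with a divisor of 4:
```c++
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_cell { int32_t pos; int32_t delta; };

int main() {
    const int32_t n_cells = 2048;
    std::vector<kv_cell> cells(n_cells);
    for (int32_t i = 0; i < n_cells; ++i) cells[i] = {i, 0};   // sequential positions

    // Divide the positions of cells in [0, 256) by 4 and record the change in delta.
    const int32_t p0 = 0, p1 = 256, d = 4;
    for (int32_t i = p0; i < p1; ++i) {
        const int32_t new_pos = cells[i].pos / d;
        cells[i].delta += new_pos - cells[i].pos;
        cells[i].pos    = new_pos;
    }

    printf("cells[255]: pos = %d, delta = %d\n", cells[255].pos, cells[255].delta); // 63, -192
    printf("cells[256]: pos = %d, delta = %d\n", cells[256].pos, cells[256].delta); // 256, 0
    return 0;
}
```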

We can inspect what the cells look like after the division function returns:
```console
(gdb) p ctx.kv_self.cells[255]
@@ -1171,8 +1247,8 @@ Now, if we take a look at entry 256 we find:
$24 = {pos = 256, delta = 0, src = 0, seq_id = std::set with 1 element = {[0] = 0}}
```
Notice that these positions are not sequential after the division. This is what
the last shift corrects. It will go through the entries in the cache and update
the ones for this sequence id where the position is in the range 256-2048.
After this we can again inspect the cells:
```console
(gdb) p ctx.kv_self.cells[256]
Expand Down Expand Up @@ -1218,7 +1294,6 @@ Next time though the loop call 257 will be updated to pos 257.
After that the division will operate on the range 256-512.



Configuration parameters:
* prompt token size (to determine if self-extend is required)
* context size (-c) instead of using the context size from the model
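
For reference, a hypothetical invocation of the main example with self-extend
enabled might look something like this (the group-attention flag names and
values are assumptions, not taken from the notes above):
```console
./llama-cli -m model.gguf -f long-prompt.txt -c 8000 --grp-attn-n 4 --grp-attn-w 256
```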