diff --git a/notes/llama-kv-cache.md b/notes/llama-kv-cache.md
index 9bf606b..6f6b120 100644
--- a/notes/llama-kv-cache.md
+++ b/notes/llama-kv-cache.md
@@ -2111,6 +2111,8 @@ So this will include only the logits that belong to the current token's sequence
 To get a feel for how this works there is a standalone example in [llama-att-softmax.c](../ggml/src/llama-att-softmax.c)
 
+### has_shift
+
 ### Connection/link between cells and cache tensors
 This might seem obvious but this was not clear to me initially which is why I'm adding this section. The `cells` member of `llama_kv_cache` is a vector of `llama_kv_cell`:
@@ -2267,5 +2269,63 @@ next 7 belong to sequence 1:
 }
 ```
 
+### Size of the cache
+We can calculate the size of the cache using the following formula. There are
+two caches, one for the keys and one for the values:
+```
+K cache size = ctx.cparams.n_ctx *
+               ctx.model.hparams.n_layer *
+               ctx.model.hparams.n_head_kv(0) *
+               ctx.model.hparams.n_embd_head_k *
+               ctx.kv_self.type_k
+
+V cache size = ctx.cparams.n_ctx *
+               ctx.model.hparams.n_layer *
+               ctx.model.hparams.n_head_kv(0) *
+               ctx.model.hparams.n_embd_head_v *
+               ctx.kv_self.type_v
+```
+Here `type_k` and `type_v` stand for the size in bytes of one element of that
+type, which is 2 for `GGML_TYPE_F16`.
+
+Let's take a concrete example:
+```console
+(gdb) p ctx.cparams.n_ctx
+$11 = 512
+(gdb) p ctx.model.hparams.n_layer
+$12 = 32
+(gdb) p ctx.model.hparams.n_head_kv(0)
+$13 = 32
+(gdb) p ctx.model.hparams.n_embd_head_k
+$14 = 128
+(gdb) p ctx.model.hparams.n_embd_head_v
+$17 = 128
+(gdb) p ctx.kv_self.type_k
+$15 = GGML_TYPE_F16
+(gdb) p ctx.kv_self.type_v
+$16 = GGML_TYPE_F16
+```
+This can also be written as a single value that is doubled when the
+values are the same for both the keys and the values:
+```console
+kv-cache size = 2 *                               // both keys and values
+                ctx.cparams.n_ctx *
+                ctx.model.hparams.n_layer *
+                ctx.model.hparams.n_head_kv(0) *
+                ctx.model.hparams.n_embd_head_k *
+                ctx.kv_self.type_k
+
+k_cache_size  = 512 * 32 * 32 * 128 * 2 = 134217728 bytes = 128 MiB
+kv-cache size = 2 * 134217728           = 268435456 bytes = 256 MiB
+```
+And for another model, with `n_ctx` = 30016, `n_layer` = 32 and
+`n_head_kv` = 8:
+```console
+kv-cache size = 2 * 30016 * 32 * 8 * 128 * 2 bytes
+              = 3934257152 bytes
+              = 3934257152 / (1024*1024)
+              = 3752 MiB
+```
+
+ _wip_
+