diff --git a/notes/llama.md b/notes/llama.md index f320061..7fdccf0 100644 --- a/notes/llama.md +++ b/notes/llama.md @@ -5,6 +5,10 @@ biases of the model. These models come in different sizes and are trained on different datasets. The larger the model the more data it has been trained on and the more accurate it is. +## Table of Contents +- [Llama 2](#llama-2) +- [llama_batch](#llama_batch) + ### Llama 2 Is really a family of pre-trained models in various scales (the number of weights). From 7B to 70B. @@ -17,9 +21,6 @@ It is based on the transformer architecture which some improvements like: (apperently inspired by GPT-Neo), see [rope.md](./rope.md) * AdamW optimizer -TODO: I'm not familiar with any of the above so this so look into these -separately. - I've seen the achitecture of a transformer where there is an encoder and a decoder. But my understanding of Llama is that there is only an encoder. ``` @@ -204,35 +205,6 @@ are two drivers: OpenAI and Llama. Llama uses a binding to llama.cpp, and is is created using bindget. The crate llm-chain-llama-sys contains the binding and llm-chain-llama contains Rust API. -#### llama_batch -This struct holdes `input` data for `llama_decode` and is defined as: -```c++ - typedef struct llama_batch { - int32_t n_tokens; - - llama_token * token; - float * embd; - llama_pos * pos; - llama_seq_id * seq_id; - int8_t * logits; - } llama_batch; -``` -The `n_tokens` is a counter of the number of tokens that this batch contains. - -A `llama_batch` is simlilar to the contept of context we talked about -[llm.md](../../notes/llm.md#context_size). Below we are adding the input query -tokens to this batch/context. So it will initially just contain the tokens for -our query. But after running the inference, we will append the next token to the -batch and run the inference again and then run the inference again to predict -the next token, now with more context (the previous token). - -The `embd` is the embedding of the tokens (I think). So this was not obvious to -me at first, but recall that the tokens just integer representations of -works/subwords, like a mapping. But they don't contains any semantic -information. Recall that this is data which is setup as input for llama_decode -and I think this is used when embeddings are already available perhaps. -TODO: verify this. -The `pos` is the position of the tokens in the sequence. ### Key-Value Cache This section tries to explain the key-value cache used in the llama 2 @@ -5936,7 +5908,7 @@ struct llama_sbatch { std::vector ubatch_output; ... ``` -When we call `from_batch` we are setting the `simple_slip` argument to true +When we call `from_batch` we are setting the `simple_split` argument to true as the model we are using is not recurrent. And `n_outputs` is only 2 in our case and `n_tokens_all` is 13 so `logits_all` will be false. @@ -5954,6 +5926,28 @@ So lets now looks closer at `from_batch`: ``` So is just setting the members. `this` is used because these fields have the same name as the parameters that this function takes. +Notice that this is the sbatch as a pointer to the original batch which is +stored inte batch field. And this happens for each decode call. + +If this is the first time we call `llama_decode_internal` then the `sbatch` +will have been default initilized as part of the `llama_context` struct. +```c++ +int ret = llama_decode_internal(llama_context & lctx, llama_batch batch); + ... + + lctx.sbatch.from_batch(batch, n_embd, + /* simple_split */ !kv_self.recurrent, + /* logits_all */ n_outputs == n_tokens_all); +``` +So the sbatch will get updated with a reference to the original batch: +``` + lctx.sbatch + +-----------------+ +----------------+ + | batch | ------------>| llama_batch | + | | +----------------+ + +__wip__ +``` Next the vector of ids (token indices not sequence ids) is resized to the number of tokens (so 13 in our case): @@ -6969,3 +6963,62 @@ struct llama_sbatch_seq { ``` _wip_ + +### batch/ubatch/sbatch +So a batch, or rather a `llama_batch` is what we pass into the `llama_decode` +function and is really the only thing that an external caller know anything +about (well that is not entirely true as they can set a ubatch value but the +struct itself is private). But the internal decode operation is/can be split +into smaller units called ubatches (update batches?). The internal decode +will create an sbatch, sequence-aware that manages the sequences of the ubatches. + +Lets take a look at when an sbatch is created. `llama_sbatch` is a struct and it +is a member of `llama_context`. When we create a new context with from a model +one instance of this will be created. For example, if we look at the main +example we have: +```c++ +common_init_result llama_init = common_init_from_params(params); +``` +This will call: +```c++ +llama_context * lctx = llama_new_context_with_model(model, cparams) +``` +Which in turn will call: +```c++ +llama_context * ctx = new llama_context(*model); +``` +And this is where the sbatch is created. At this point it does not contain +any data (default initilized). +```console +(gdb) p ctx->sbatch +$2 = {n_tokens = 93802895217820, n_embd = 8587913089966890862, logits_all = 128, ids = std::vector of length 0, capacity 0, + out_ids = std::vector of length 0, capacity 0, seq = std::vector of length 0, capacity 0, batch = 0x0, + ubatch_token = std::vector of length 0, capacity 0, ubatch_embd = std::vector of length 0, capacity 0, + ubatch_pos = std::vector of length 0, capacity 0, ubatch_n_seq_id = std::vector of length 0, capacity 0, + ubatch_seq_id = std::vector of length 0, capacity 0, ubatch_output = std::vector of length 0, capacity 0} +``` + +A `llama_batch` is what a caller will create and pass to `llama_decode`. This +batch will be passed to the the sbatch's `from_batch` method: +```c++ + lctx.sbatch.from_batch(batch, n_embd, + /* simple_split */ !kv_self.recurrent, + /* logits_all */ n_outputs == n_tokens_all); +``` + + +```console + + llama_batch ctx->llama_sbatch + +--------------+ +--------------+ + | | llama_decode | | + | | ------------> | | + | | | | + | | | | + | | | | + +--------------+ +--------------+ + + +``` + +