docs: save in-progress work on llama.md
danbev committed Nov 13, 2024
1 parent 37a4c5e commit f6b47e1
119 changes: 86 additions & 33 deletions notes/llama.md
@@ -5,6 +5,10 @@ biases of the model. These models come in different sizes and are trained on
different datasets. The larger the model, the more data it has typically been
trained on, and the more capable it tends to be.

## Table of Contents
- [Llama 2](#llama-2)
- [llama_batch](#llama_batch)

### Llama 2
Llama 2 is really a family of pre-trained models at various scales (numbers of
weights), ranging from 7B to 70B parameters.
@@ -17,9 +21,6 @@ It is based on the transformer architecture with some improvements like:
  (apparently inspired by GPT-Neo), see [rope.md](./rope.md) and the sketch
  after this list
* AdamW optimizer
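
Since [rope.md](./rope.md) has the details, here is just a minimal sketch of
the core idea behind RoPE, assuming the usual `theta = pos * base^(-i/d)`
scheme with base 10000. This is illustrative only; llama.cpp's actual
implementation differs in memory layout and optimizations:
```c++
#include <cmath>
#include <vector>

// Rotate each consecutive pair of dimensions of a query/key vector by an
// angle that depends on the token's position. This bakes position
// information directly into the dot products used by attention.
void apply_rope(std::vector<float> & x, int pos, float base = 10000.0f) {
    const int d = x.size();
    for (int i = 0; i < d; i += 2) {
        const float theta = pos * std::pow(base, -(float)i / d);
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float x0 = x[i];
        const float x1 = x[i + 1];
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
    }
}
```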

TODO: I'm not familiar with all of the above, so look into these separately.

I've seen the architecture of a transformer where there is an encoder and a
decoder. But Llama is a decoder-only model, so there is only the decoder part
of the transformer.
@@ -204,35 +205,6 @@ are two drivers: OpenAI and Llama. Llama uses a binding to llama.cpp, which
is created using bindgen. The crate llm-chain-llama-sys contains the binding
and llm-chain-llama contains the Rust API.

#### llama_batch
This struct holds `input` data for `llama_decode` and is defined as:
```c++
typedef struct llama_batch {
    int32_t n_tokens;       // number of tokens in this batch

    llama_token  * token;   // the token ids
    float        * embd;    // token embeddings (used instead of `token`)
    llama_pos    * pos;     // position of each token in its sequence
    llama_seq_id * seq_id;  // the sequence each token belongs to
    int8_t       * logits;  // whether to compute logits for each token
} llama_batch;
```
`n_tokens` is the number of tokens that this batch contains.

A `llama_batch` is similar to the concept of context we talked about in
[llm.md](../../notes/llm.md#context_size). Below we are adding the input query
tokens to this batch/context. So it will initially just contain the tokens for
our query. But after running the inference, we feed the predicted token back in
and run the inference again to predict the next token, now with more context
(the previous token).

The `embd` field holds the embeddings of the tokens. This was not obvious to
me at first, but recall that tokens are just integer representations of
words/subwords, like a mapping, and don't contain any semantic information
themselves. Recall that this struct is set up as input for `llama_decode`, and
I think this field is used when embeddings are already available, instead of
`token`. TODO: verify this.
The `pos` field holds the position of each token in its sequence.
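
To make this concrete, here is a minimal sketch of populating a batch with
prompt tokens and then submitting one sampled token per decode step. It
assumes the field layout shown above (newer llama.cpp versions differ), that
`ctx` is a valid `llama_context` pointer, and that `prompt_tokens` holds the
tokenized query; `sample_token` is a hypothetical helper:
```c++
// Error handling omitted. llama_batch_init allocates the per-token arrays.
llama_batch batch = llama_batch_init(/* n_tokens */ 512, /* embd */ 0, /* n_seq_max */ 1);

const int n_prompt = prompt_tokens.size();
for (int i = 0; i < n_prompt; i++) {
    batch.token[i]  = prompt_tokens[i];
    batch.pos[i]    = i;      // position of this token in the sequence
    batch.seq_id[i] = 0;      // all tokens belong to the same sequence
    batch.logits[i] = false;
}
batch.logits[n_prompt - 1] = true;  // we only need logits for the last token
batch.n_tokens = n_prompt;

llama_decode(ctx, batch);

// After the prompt has been decoded, only the newly sampled token needs to
// be submitted; the KV cache keeps the earlier context.
llama_token next = sample_token(ctx);  // hypothetical sampling helper
batch.n_tokens  = 1;
batch.token[0]  = next;
batch.pos[0]    = n_prompt;  // the position right after the prompt
batch.seq_id[0] = 0;
batch.logits[0] = true;
llama_decode(ctx, batch);
```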

### Key-Value Cache
This section tries to explain the key-value cache used in the llama 2
@@ -5936,7 +5908,7 @@ struct llama_sbatch {
```c++
    std::vector<int8_t> ubatch_output;
    ...
```
When we call `from_batch` we are setting the `simple_split` argument to true
as the model we are using is not recurrent. And `n_outputs` is only 2 in our
case and `n_tokens_all` is 13 so `logits_all` will be false.

@@ -5954,6 +5926,28 @@ So let's now look closer at `from_batch`:
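The relevant assignments look something like this (an abridged sketch; the
exact code varies between llama.cpp versions):
```c++
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd,
                              bool simple_split, bool logits_all) {
    // Store a pointer to the caller's batch rather than copying its data.
    this->batch      = &batch;
    this->n_embd     = n_embd;
    this->logits_all = logits_all;
    n_tokens = batch.n_tokens;
    ...
}
```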
So it is just setting the members. `this` is used because these fields have the
same name as the parameters that this function takes.
Notice that the sbatch stores a pointer to the original batch in its `batch`
field. And this happens for each decode call.

If this is the first time we call `llama_decode_internal` then the `sbatch`
will have been default initialized as part of the `llama_context` struct.
```c++
static int llama_decode_internal(llama_context & lctx, llama_batch batch) {
    ...

    lctx.sbatch.from_batch(batch, n_embd,
            /* simple_split */ !kv_self.recurrent,
            /* logits_all   */ n_outputs == n_tokens_all);
```
So the sbatch will get updated with a reference to the original batch:
```
lctx.sbatch
+-----------------+           +----------------+
| batch           | --------> |  llama_batch   |
+-----------------+           +----------------+

__wip__
```

Next the vector of ids (token indices not sequence ids) is resized to the
number of tokens (so 13 in our case):
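Roughly like this (a sketch; for non-simple splits the ids may later be
reordered so that tokens are grouped by sequence):
```c++
ids.resize(n_tokens);
for (size_t i = 0; i < n_tokens; i++) {
    ids[i] = i;  // identity mapping to begin with
}
```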
@@ -6969,3 +6963,62 @@ struct llama_sbatch_seq {

_wip_

### batch/ubatch/sbatch
A batch, or rather a `llama_batch`, is what we pass into the `llama_decode`
function and is really the only thing that an external caller knows anything
about (well, that is not entirely true, as they can set a ubatch size in the
context parameters, but the struct itself is private). Internally the decode
operation can be split into smaller units called ubatches (micro batches). The
internal decode creates an sbatch (sequence-aware batch) that manages splitting
the batch into ubatches; a sketch of the related parameters follows below.
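
For reference, the logical vs. physical batch sizes can be set through
`llama_context_params` (a minimal sketch; it assumes `model` is an already
loaded `llama_model`, and the values are arbitrary):
```c++
llama_context_params cparams = llama_context_default_params();
cparams.n_batch  = 512;  // logical batch size: max tokens per llama_decode call
cparams.n_ubatch = 128;  // physical batch size: tokens per internal ubatch
llama_context * ctx = llama_new_context_with_model(model, cparams);
```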

Let's take a look at when an sbatch is created. `llama_sbatch` is a struct and
it is a member of `llama_context`. When we create a new context from a model,
one instance of this will be created. For example, if we look at the main
example we have:
```c++
common_init_result llama_init = common_init_from_params(params);
```
This will call:
```c++
llama_context * lctx = llama_new_context_with_model(model, cparams)
```
Which in turn will call:
```c++
llama_context * ctx = new llama_context(*model);
```
And this is where the sbatch is created. At this point it does not contain any
valid data (it is default initialized, and since the scalar members have no
initializers their values below are garbage).
```console
(gdb) p ctx->sbatch
$2 = {n_tokens = 93802895217820, n_embd = 8587913089966890862, logits_all = 128, ids = std::vector of length 0, capacity 0,
out_ids = std::vector of length 0, capacity 0, seq = std::vector of length 0, capacity 0, batch = 0x0,
ubatch_token = std::vector of length 0, capacity 0, ubatch_embd = std::vector of length 0, capacity 0,
ubatch_pos = std::vector of length 0, capacity 0, ubatch_n_seq_id = std::vector of length 0, capacity 0,
ubatch_seq_id = std::vector of length 0, capacity 0, ubatch_output = std::vector of length 0, capacity 0}
```

A `llama_batch` is what a caller will create and pass to `llama_decode`. This
batch will be passed to the sbatch's `from_batch` method:
```c++
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all);
```


```console
llama_batch                          ctx->sbatch
+--------------+                    +--------------+
|              |    llama_decode    | batch        |
|              | -----------------> |              |
+--------------+                    +--------------+
```

