This document will take a look at the `embd` field of `llama_batch` and see how
it can be used and for what purposes.

So this is a field similar to the `llama_batch.token` field, which is a pointer
to the input tokens. So we might have a prompt which we tokenize, which is
basically splitting the prompt into tokens and looking them up in the model's
vocabulary. This is what is passed to the `llama_batch.token` field. We can see
in the following function that if the batch has tokens then a 1d tensor will be
created for the `inp_tokens`:
```console
...
name = "token_embd.weight", '\000' <repeats 46 times>, extra = 0x0}
```
So this is how the lookup is performed.
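
As a side note, a minimal sketch of how a prompt typically ends up in
`llama_batch.token` might look something like the following. This is an
illustration and not the exact code I used, and the `llama_tokenize` and
`llama_batch_init` signatures have changed between llama.cpp versions, so
treat the calls below as assumptions:
```c++
std::string prompt = "What is LoRA?";
std::vector<llama_token> tokens(prompt.size() + 8);  // rough upper bound

// Tokenize: split the prompt and look the pieces up in the vocabulary.
int n_tokens = llama_tokenize(model, prompt.c_str(), prompt.size(),
                              tokens.data(), tokens.size(),
                              /*add_special*/ true, /*parse_special*/ false);
tokens.resize(n_tokens);

// embd = 0 means batch.token (token ids) is allocated, not batch.embd.
llama_batch batch = llama_batch_init(n_tokens, 0, 1);
batch.n_tokens = n_tokens;
for (int i = 0; i < n_tokens; i++) {
    batch.token[i]     = tokens[i];
    batch.pos[i]       = i;
    batch.n_seq_id[i]  = 1;
    batch.seq_id[i][0] = 0;
    batch.logits[i]    = (i == n_tokens - 1);  // logits only for the last token
}
```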

Now, with that in mind, when we have `batch.embd` set to point to embeddings
instead of having tokens in the batch, the `inp_embd` tensor will be created
directly as a 2d tensor of float32 values:
```c++
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
inpL = lctx.inp_embd;
ggml_set_input(lctx.inp_embd);
```
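
Note that ggml lists the fastest-varying dimension first, so this is a tensor
where each of the `batch.n_tokens` rows is an `n_embd`-long vector of floats.
A small sketch of what I believe the shape works out to (assuming `n_embd` is
4096, as for a 7B model, and a batch of 6 tokens):
```c++
// inp_embd as created above; ne[0] is the fastest-varying dimension.
printf("inp_embd ne = [%lld, %lld]\n",
       (long long) lctx.inp_embd->ne[0],   // n_embd, e.g. 4096
       (long long) lctx.inp_embd->ne[1]);  // batch.n_tokens, e.g. 6
printf("element size = %zu\n", ggml_element_size(lctx.inp_embd));  // 4 (f32)
```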
And later in the `llama_set_inputs` function the `inp_embd` tensor will be populated
with the embeddings from the batch:
```c++
static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
    ...
    if (batch.embd) {
        const int64_t n_embd   = hparams.n_embd;
        const int64_t n_tokens = batch.n_tokens;
        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
    }
    ...
}
```
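
On the caller side this means `batch.embd` must point at `n_tokens * n_embd`
contiguous float values, one embedding vector per token. A hedged sketch of
setting this up, assuming that passing a non-zero `embd` argument to
`llama_batch_init` allocates `batch.embd` instead of `batch.token`, and that
`embeddings` is a caller-provided float array:
```c++
const int n_embd   = llama_n_embd(model);
const int n_tokens = 6;

// embd = n_embd allocates batch.embd (n_tokens * n_embd floats).
llama_batch batch = llama_batch_init(n_tokens, n_embd, 1);
batch.n_tokens = n_tokens;

for (int i = 0; i < n_tokens; i++) {
    // Copy one embedding vector per token into the batch.
    memcpy(batch.embd + i * n_embd, embeddings + i * n_embd, n_embd * sizeof(float));
    batch.pos[i]       = i;
    batch.n_seq_id[i]  = 1;
    batch.seq_id[i][0] = 0;
    batch.logits[i]    = (i == n_tokens - 1);
}
```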
So my understanding of `batch.embd` is that the embeddings we pass to the
model should be the context _unaware_ embeddings.

My initial thought about the use case for this field was that we could take
the output of a program like `llama_embedding`, which can produce context
aware embeddings, and then use these embeddings, perhaps at some later point,
to perform inference based on them, similar to how we would use a token
prompt. I actually tried this but was not able to get it to work, which is why
I decided to write this document to sort out my thoughts.
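
For reference, the kind of round trip I was attempting looks roughly like the
following sketch. It assumes the context was created with the `embeddings`
flag enabled and no pooling, so that `llama_get_embeddings` returns one vector
per token; the exact behavior depends on the context parameters:
```c++
// 1) Decode the tokenized prompt to produce context aware embeddings.
if (llama_decode(ctx, prompt_batch) != 0) {
    fprintf(stderr, "llama_decode failed\n");
}

// 2) Read the embeddings back, one n_embd vector per prompt token.
const float * ctx_embd = llama_get_embeddings(ctx);
const int     n_embd   = llama_n_embd(model);

// 3) Feed them back in as inputs via batch.embd, as shown above.
llama_batch embd_batch = llama_batch_init(prompt_batch.n_tokens, n_embd, 1);
embd_batch.n_tokens = prompt_batch.n_tokens;
memcpy(embd_batch.embd, ctx_embd,
       prompt_batch.n_tokens * n_embd * sizeof(float));
```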

In that case I was trying to use context aware embeddings, so they had already
been processed by llama.cpp. But as we saw above, these would then be used
as inputs to the model and go through all the layers of the model (with the
self-attention and feed-forward layers), and then we would predict the next
token based on those embeddings. I was somewhat concerned about this, as these
embeddings would go through the model layers a second time, with the
self-attention and feed-forward layers operating on them again.

This is what I did when testing:

```console
...
$54 = {id = 7228, logit = 14.2697344, p = 0}
(gdb) p ctx.model.vocab.id_to_token[7228]
$55 = {text = "PA", score = -6969, attr = LLAMA_TOKEN_ATTR_NORMAL}
```

So after some debugging I realized that I was using a chat/instruct model, in
this case Llama 2, which requires a prompt template matching what the model
was trained on.
```c++
std::string model_path = "models/llama-2-7b-chat.Q4_K_M.gguf";
std::string prompt = "<s>[INST] <<SYS>>\n\n<</SYS>>\n\nWhat is LoRA? [/INST]";
```
This took care of the above "PA" token problem. The output is still not great,
but the predictions are closer to the context of the original prompt
"What is LoRA?":
```console
Top 5 logits:
Token 4309 ( Lo): 13.126865
Token 3410 (Lo): 11.235941
Token 322 ( and): 10.326930
Token 13 (
): 9.867510
Token 297 ( in): 9.695880
token_seq: 4309 : token_str [ Lo]
Inference: token: 4309, pos: 6

Top 5 logits:
Token 29949 (O): 13.566551
Token 29877 (o): 12.267788
Token 29892 (,): 11.640937
Token 3410 (Lo): 11.488244
Token 13 (
): 10.990350
token_seq: 29949 : token [O]
Inference: token: 29949, pos: 7
...
```
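
The output above came from inspecting the raw logits after each `llama_decode`
call. A minimal sketch of that kind of loop, greedily picking the argmax token
(which I believe matches what the log shows, though the actual test code may
differ):
```c++
// Logits for the last token in the batch that had batch.logits set.
const float * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
const int     n_vocab = llama_n_vocab(model);

llama_token best_id = 0;
for (llama_token id = 1; id < n_vocab; id++) {
    if (logits[id] > logits[best_id]) {
        best_id = id;
    }
}
// best_id is then decoded as a single-token batch with the position
// incremented, which is what the "Inference: token: ..., pos: ..." lines show.
```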

_wip_
