diff --git a/notes/llama-batch-embd.md b/notes/llama-batch-embd.md index a838efe..0a02b1c 100644 --- a/notes/llama-batch-embd.md +++ b/notes/llama-batch-embd.md @@ -3,8 +3,8 @@ This document will take a look at the `embd` field of `llama_batch` and see how it can be used and for what purposes. So this is a field similar to the `llama_batch.token` field which is a pointer -to the input tokens. So we might have a prompt which we tokenize which the -is bacically splitting the prompt into tokens and looking them up in the models +to the input tokens. So we might have a prompt which we tokenize which is bacically +splitting the prompt into tokens and looking them up in the models vocabulary. This is what is passed to the `llama_batch.token` field. We can see in following function that if the batch has tokens then a 1d tensor will be created for the `inp_tokens`: @@ -76,15 +76,15 @@ name = "token_embd.weight", '\000' , extra = 0x0} ``` So this is how the lookup is performed. -Now, with that in mind when we have `batch.embd` set to point to embeddings in -stead of having tokenss in the batch the `inp_embd` tensor will be created +Now, with that in mind when we have `batch.embd` set to point to embeddings +instead of having tokens in the batch, the `inp_embd` tensor will be created directly as a 2d tensor of float32 values: ```c++ lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); inpL = lctx.inp_embd; ggml_set_input(lctx.inp_embd); ``` -And later in the `llama_set_inputs` the `inp_embd` tensor will be populated +And later in the `llama_set_inputs` function the `inp_embd` tensor will be populated with the embeddings from the batch: ```c++ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { @@ -96,19 +96,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); } ``` -So my understanding of `batch.embd` is that the embeddings we can pass to the -model should be the context _unaware_ embeddings. +My initial thought about the use case for using this field was that we could take the +result of a program like `llama_embedding` which can output context aware embeddings and +then use these embeddings, perhaps at some later point, to perform inference based on these +embeddings, similar to how we would used a token prompt. -I initially thought that these embeddings could be the result of a program like -`llama_embedding` which can output context aware embeddings. I actually tried -this but was not able to get it to work and while I decided to write this -document to sort out my thoughts. - -In that case I was trying to use context aware embeddings, so they had already +In this case what I am trying to use context aware embeddings, so they had already been processed by the llama.cpp. But as we saw above these would then be used as inputs to the model and go through all the layers of the model (with the self-attention and feed-forward layers) and then we would have predict the -next token based on those embeddings. +next token based on those embeddings. I was somewhat concerned about this as these +embeddings will go through the model layers and self-attention and feed-forward +layers will operate on these embeddings. This is what I did when testing: @@ -184,5 +183,39 @@ $54 = {id = 7228, logit = 14.2697344, p = 0} (gdb) p ctx.model.vocab.id_to_token[7228] $55 = {text = "PA", score = -6969, attr = LLAMA_TOKEN_ATTR_NORMAL} ``` + +So after to debugging I realized that I was using a chat/instruct model and in this +case llama 2 which requires a template matching what the model was trained on. +```c++ + std::string model_path = "models/llama-2-7b-chat.Q4_K_M.gguf"; + std::string prompt = "[INST] <>\n\n<>\n\nWhat is LoRA? [/INST]"; +``` +This took care of the above "PA" token problem but the output is still not great but +the predictions are closer to the context of the original prompt "What is LoRA?": +```console +Top 5 logits: +Token 4309 ( Lo): 13.126865 +Token 3410 (Lo): 11.235941 +Token 322 ( and): 10.326930 +Token 13 ( +): 9.867510 +Token 297 ( in): 9.695880 +token_seq: 4309 : token_str [ Lo] +Inference: token: 4309, pos: 6 + +Top 5 logits: +Token 29949 (O): 13.566551 +Token 29877 (o): 12.267788 +Token 29892 (,): 11.640937 +Token 3410 (Lo): 11.488244 +Token 13 ( +): 10.990350 +token_seq: 29949 : token [O] +Inference: token: 29949, pos: 7 +... +``` + + + _wip_