diff --git a/fundamentals/llama.cpp/src/embeddings.cpp b/fundamentals/llama.cpp/src/embeddings.cpp
index db6ab7a..334c1d4 100644
--- a/fundamentals/llama.cpp/src/embeddings.cpp
+++ b/fundamentals/llama.cpp/src/embeddings.cpp
@@ -104,74 +104,90 @@ int main(int argc, char** argv) {
     }

     std::vector<llama_token> input_tokens = tokenize_prompt(model, prompt);
-
     llama_batch prompt_batch = create_batch(ctx_params.n_batch, input_tokens, model);
+
+    // Decode the prompt to generate the embeddings. We are not going to use
+    // the logits at this stage.
     if (llama_decode(embd_ctx, prompt_batch) != 0) {
         fprintf(stderr, "llama_decode() failed\n");
         return 1;
     }

-    //float* embd = nullptr;
+    // Now we will extract the embeddings.
     int n_embd = llama_n_embd(model);
-    //std::vector<float*> token_embeddings;
-    std::vector<std::vector<float>> token_embeddings;
-
+    std::vector<float> token_embeddings;
     for (size_t i = 0; i < input_tokens.size(); i++) {
         float* embd = llama_get_embeddings_ith(embd_ctx, i);
-        //token_embeddings.push_back(embd);
-        token_embeddings.push_back(std::vector<float>(embd, embd + n_embd));
+        token_embeddings.insert(token_embeddings.end(), embd, embd + n_embd);

-        printf("%stoken %ld embeddings: %d.\%s\n", BLUE, i, n_embd, RESET);
+        printf("Original embedding %zu: ", i);
+        for (int j = 0; j < 5; j++) {
+            printf("%f ", embd[j]);
+        }
+        printf("\n");
+    }
+
+    // Print out the first 5 embeddings from all token embeddings generated.
+    for (size_t i = 0; i < input_tokens.size(); i++) {
+        printf("%sembedding %ld \%s", BLUE, i, RESET);
+        float* token_embd = token_embeddings.data() + (i * n_embd);
         for (int j = 0; j < 5; j++) {
-            printf("%s%f %s", BLUE, embd[j], RESET);
+            printf("%f ", token_embd[j]);
         }
         printf("\n");
     }

+    // Now we are done with the context used to generate the embeddings. This
+    // is to simulate a case where the embeddings were generated as a previous
+    // stage for usage later.
     llama_kv_cache_clear(embd_ctx);
     llama_free(embd_ctx);
     llama_batch_free(prompt_batch);

+    // Now we are going to create a new context for inference.
     llama_context_params inf_ctx_params = llama_context_default_params();
     inf_ctx_params.n_ctx = 1024;
     inf_ctx_params.n_threads = 1;
     inf_ctx_params.n_threads_batch = 1;
     llama_context* inf_ctx = llama_new_context_with_model(model, inf_ctx_params);
+    int pos = 0;
+
+    // Next we create a batch for the token embeddings generated above.
+    // The following is creating a single batch with 6 token embeddings in it.
+    llama_batch embd_batch = llama_batch_init(1024, n_embd, 1);
+    embd_batch.n_tokens = input_tokens.size();
+    embd_batch.embd = token_embeddings.data();
+    printf("%sToken embeddings size: %d, n_tokens: %d%s\n", GREEN, n_embd, embd_batch.n_tokens, RESET);
+    for (size_t i = 0; i < input_tokens.size(); i++) {
+        embd_batch.pos[i] = i;
+        embd_batch.n_seq_id[i] = 1;
+        embd_batch.seq_id[i][0] = 10;
+        embd_batch.logits[i] = true;
+    }
+    printf("%slast position : %d%s\n", GREEN, embd_batch.pos[input_tokens.size() - 1], RESET);

-    for (size_t i = 0; i < token_embeddings.size(); i++) {
-        llama_batch embd_batch = llama_batch_init(1, n_embd, 1);
-        embd_batch.n_tokens = 1;
-        embd_batch.embd = token_embeddings[i].data();
-        embd_batch.pos[0] = i;
-        embd_batch.n_seq_id[0] = 1;
-        embd_batch.seq_id[0][0] = 0;
-        embd_batch.logits[0] = true;
-
-        if (llama_decode(inf_ctx, embd_batch) != 0) {
-            fprintf(stderr, "llama_decode() failed for token %zu\n", i);
-            return 1;
-        }
-        //llama_batch_free(embd_batch);
+    // Decode the token embeddings to generate the logits.
+    if (llama_decode(inf_ctx, embd_batch) != 0) {
+        fprintf(stderr, "llama_decode() failed for token\n");
+        return 1;
+        llama_batch_free(embd_batch);
     }
+    pos = embd_batch.pos[input_tokens.size() - 1];

+    // Next create a sampler chain for sampling the next token.
     auto sparams = llama_sampler_chain_default_params();
     llama_sampler* sampler = llama_sampler_chain_init(sparams);
-    llama_sampler_chain_add(sampler, llama_sampler_init_top_k(40));
-    llama_sampler_chain_add(sampler, llama_sampler_init_temp(0.8f));
+    llama_sampler_chain_add(sampler, llama_sampler_init_top_k(3));
     llama_sampler_chain_add(sampler, llama_sampler_init_dist(1234));

     std::vector<std::string> output;
     // Sample a token (sp=sampled token)
-    llama_token sp_token = llama_sampler_sample(sampler, inf_ctx, 0);
+    llama_token sp_token = llama_sampler_sample(sampler, inf_ctx, -1);
     std::string sp_str = token_as_string(model, sp_token);
-    output.push_back(sp_str);
     printf("%stoken_seq: %d : token_str [%s]%s\n", ORANGE, sp_token, sp_str.c_str(), RESET);
+    output.push_back(sp_str);
     llama_sampler_reset(sampler);

-    int decode_calls = 10;
-    int pos = token_embeddings.size() - 1;
-
-    printf("%sInference:%s\n", ORANGE, RESET);
+    int decode_calls = 5;

     while (decode_calls--) {
         llama_batch update_batch = llama_batch_init(1, 0, 1);
         update_batch.n_tokens = 1;
@@ -179,8 +195,9 @@ int main(int argc, char** argv) {
         update_batch.pos[0] = pos++;

         update_batch.n_seq_id[0] = 1;
-        update_batch.seq_id[0][0] = 0;
+        update_batch.seq_id[0][0] = 10;
         update_batch.logits[0] = true;
+        printf("%sInference: token: %d, pos: %d %s\n", ORANGE, update_batch.token[0], update_batch.pos[0], RESET);

         if (llama_decode(inf_ctx, update_batch) != 0) {
             fprintf(stderr, "llama_decode() failed\n");
@@ -193,8 +210,7 @@ int main(int argc, char** argv) {
         printf("%stoken_seq: %.4d : token [%s]%s\n", ORANGE, sp_token, sp_str.c_str(), RESET);

         llama_sampler_reset(sampler);
-
-        //llama_batch_free(update_batch);
+        llama_batch_free(update_batch);
     }

     printf("Generated output:\n");
@@ -203,10 +219,10 @@ int main(int argc, char** argv) {
     }
     printf("\n");

-    //llama_free(inf_ctx);
-    //llama_free_model(model);
-    //llama_backend_free();
-    //llama_sampler_free(sampler);
+    llama_free(inf_ctx);
+    llama_free_model(model);
+    llama_backend_free();
+    llama_sampler_free(sampler);

     return 0;
 }
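One thing to note about the batch setup above: llama_batch_init(1024, n_embd, 1) already allocates a batch-owned embd buffer (in llama.cpp, when the embd argument is non-zero the init function mallocs n_tokens * embd floats, and llama_batch_free() frees every non-null array in the batch, including batch.embd). Re-pointing embd_batch.embd at token_embeddings.data() therefore leaks the allocated buffer, and actually freeing the batch would then free memory the batch does not own, which is presumably why the llama_batch_free(embd_batch) call in the patch sits unreachably after return 1. A minimal sketch of the copy-based alternative, reusing the same variables as the patch (input_tokens, n_embd, token_embeddings, inf_ctx) and assuming the allocation behavior described above; it also needs <cstring> for memcpy:

    // Copy the embeddings into the batch-owned buffer instead of re-pointing
    // embd_batch.embd at the vector's storage. Assumes llama_batch_init()
    // allocated n_tokens * n_embd floats for batch.embd (embd arg != 0).
    llama_batch embd_batch = llama_batch_init((int32_t) input_tokens.size(), n_embd, 1);
    embd_batch.n_tokens = (int32_t) input_tokens.size();
    memcpy(embd_batch.embd, token_embeddings.data(),
           token_embeddings.size() * sizeof(float));
    for (size_t i = 0; i < input_tokens.size(); i++) {
        embd_batch.pos[i]       = i;
        embd_batch.n_seq_id[i]  = 1;
        embd_batch.seq_id[i][0] = 10; // must match the seq_id used in the generation loop
        embd_batch.logits[i]    = true;
    }
    if (llama_decode(inf_ctx, embd_batch) != 0) {
        fprintf(stderr, "llama_decode() failed\n");
        return 1;
    }
    llama_batch_free(embd_batch); // safe: embd still points at batch-owned memory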
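The sampling index change from 0 to -1 follows from the batching change: in the old code each llama_decode() call carried a single token embedding, so row 0 was the only logits row available, whereas the new code decodes the whole prompt in one batch with logits requested for every position, so -1 (shorthand for the last position, per the llama.h description of llama_get_logits_ith) picks the logits of the final prompt token. A small sketch of that equivalence, assuming every prompt position requested logits as in the patch:

    // Index -1 and the explicit last index should resolve to the same
    // logits row after the single-batch decode above (assumption based on
    // llama_get_logits_ith supporting negative indices from the end).
    const int32_t last = (int32_t) input_tokens.size() - 1;
    float* row_explicit = llama_get_logits_ith(inf_ctx, last);
    float* row_last     = llama_get_logits_ith(inf_ctx, -1);
    printf("same logits row: %s\n", row_explicit == row_last ? "yes" : "no");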