Commit

llama.cpp: update embeddings.cpp (wip)

danbev committed Oct 24, 2024
1 parent daf6857 commit 8a7f982
Showing 1 changed file with 54 additions and 38 deletions.

92 changes: 54 additions & 38 deletions fundamentals/llama.cpp/src/embeddings.cpp
@@ -104,83 +104,100 @@ int main(int argc, char** argv) {
    }

    std::vector<llama_token> input_tokens = tokenize_prompt(model, prompt);
    llama_batch prompt_batch = create_batch(ctx_params.n_batch, input_tokens, model);

    // Decode the prompt to generate the embeddings. We are not going to use
    // the logits at this stage.
    if (llama_decode(embd_ctx, prompt_batch) != 0) {
        fprintf(stderr, "llama_decode() failed\n");
        return 1;
    }
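    // The decode above populates per-token embeddings that can be read back
    // with llama_get_embeddings_ith below. This assumes that embd_ctx was
    // created with embeddings enabled (ctx_params.embeddings = true), which
    // happens outside this diff hunk.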

    // Now we will extract the embeddings, one vector per prompt token.
    int n_embd = llama_n_embd(model);
    std::vector<std::vector<float>> token_embeddings;
    for (size_t i = 0; i < input_tokens.size(); i++) {
        float* embd = llama_get_embeddings_ith(embd_ctx, i);
        token_embeddings.push_back(std::vector<float>(embd, embd + n_embd));

        // Print the first 5 values of each embedding as a sanity check.
        printf("Original embedding %zu: ", i);
        for (int j = 0; j < 5; j++) {
            printf("%f ", embd[j]);
        }
        printf("\n");
    }
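    // Each embedding is copied into a std::vector of its own because embd
    // points into storage owned by embd_ctx, which is freed below.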

    // Now we are done with the context used to generate the embeddings. This
    // simulates a case where the embeddings were generated in an earlier
    // stage, to be used by a later one.
    llama_kv_cache_clear(embd_ctx);
    llama_free(embd_ctx);
    llama_batch_free(prompt_batch);

    // Now we are going to create a new context for inference.
    llama_context_params inf_ctx_params = llama_context_default_params();
    inf_ctx_params.n_ctx = 1024;
    inf_ctx_params.n_threads = 1;
    inf_ctx_params.n_threads_batch = 1;
    llama_context* inf_ctx = llama_new_context_with_model(model, inf_ctx_params);
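    // The new context reuses the already loaded model weights; only
    // per-context state such as the KV cache is allocated fresh.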

    // Next we decode the token embeddings generated above, one embedding per
    // batch, which populates the KV cache of the inference context.
    for (size_t i = 0; i < token_embeddings.size(); i++) {
        llama_batch embd_batch = llama_batch_init(1, n_embd, 1);
        embd_batch.n_tokens = 1;
        // Copy into the buffer allocated by llama_batch_init so that
        // llama_batch_free below frees memory it actually owns.
        memcpy(embd_batch.embd, token_embeddings[i].data(), n_embd * sizeof(float));
        embd_batch.pos[0] = i;
        embd_batch.n_seq_id[0] = 1;
        embd_batch.seq_id[0][0] = 0;
        embd_batch.logits[0] = true;

        // Decode the token embedding to generate the logits.
        if (llama_decode(inf_ctx, embd_batch) != 0) {
            fprintf(stderr, "llama_decode() failed for token %zu\n", i);
            return 1;
        }
        llama_batch_free(embd_batch);
    }
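    // At this point the KV cache of inf_ctx holds one entry per prompt token,
    // all in sequence 0 at positions 0..token_embeddings.size()-1.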

    // Next create a sampler chain for sampling the next token.
    auto sparams = llama_sampler_chain_default_params();
    llama_sampler* sampler = llama_sampler_chain_init(sparams);
    llama_sampler_chain_add(sampler, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(sampler, llama_sampler_init_top_k(3));
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(1234));
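    // The samplers run in the order they were added: temperature scaling
    // first, then top-k filtering, and finally llama_sampler_init_dist, which
    // draws the actual sample (1234 is the RNG seed).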

    std::vector<std::string> output;
    // Sample a token (sp = sampled token). The -1 index means that the logits
    // of the last decoded token are used.
    llama_token sp_token = llama_sampler_sample(sampler, inf_ctx, -1);
    std::string sp_str = token_as_string(model, sp_token);
    printf("%stoken_seq: %d : token_str [%s]%s\n", ORANGE, sp_token, sp_str.c_str(), RESET);
    output.push_back(sp_str);
    llama_sampler_reset(sampler);
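    // llama_sampler_reset clears the internal state of the samplers in the
    // chain, for example the RNG state of the dist sampler.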

    // The position of the next token is the one following the last prompt
    // embedding.
    int pos = (int) token_embeddings.size();

    printf("%sInference:%s\n", ORANGE, RESET);
    int decode_calls = 5;
    while (decode_calls--) {
        llama_batch update_batch = llama_batch_init(1, 0, 1);
        update_batch.n_tokens = 1;
        update_batch.token[0] = sp_token;
        update_batch.pos[0] = pos++;
        update_batch.n_seq_id[0] = 1;
        // Use the same sequence id as the prompt embeddings so that the new
        // token can attend to them.
        update_batch.seq_id[0][0] = 0;
        update_batch.logits[0] = true;
        printf("%sInference: token: %d, pos: %d %s\n", ORANGE, update_batch.token[0], update_batch.pos[0], RESET);

        if (llama_decode(inf_ctx, update_batch) != 0) {
            fprintf(stderr, "llama_decode() failed\n");
@@ -193,8 +210,7 @@ int main(int argc, char** argv) {
printf("%stoken_seq: %.4d : token [%s]%s\n", ORANGE, sp_token, sp_str.c_str(), RESET);

llama_sampler_reset(sampler);

//llama_batch_free(update_batch);
llama_batch_free(update_batch);
}
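    // Only the most recently sampled token is fed back in each iteration;
    // the earlier tokens are already represented in the KV cache.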

printf("Generated output:\n");
Expand All @@ -203,10 +219,10 @@ int main(int argc, char** argv) {
}
printf("\n");

    llama_free(inf_ctx);
    llama_free_model(model);
    llama_backend_free();
    llama_sampler_free(sampler);

    return 0;
}
