Skip to content

Commit

Permalink
llama.cpp: fix issue with position in multi-prompt
Browse files · Browse the repository at this point in the history
This example needs some cleanup as it is very messy at the moment.
  • Loading branch information
danbev committed Oct 29, 2024
1 parent 399f850 commit 74ba0b1
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions fundamentals/llama.cpp/src/simple-prompt-multi.cpp
Original file line number · Diff line number · Diff line change
Expand Up @@ -65,9 +65,9 @@ int main(int argc, char** argv) {

printf("add_bos: %d\n", add_bos);
printf("prompt.len: %ld\n", prompt.length());
int n_tokens = prompt.length() + add_bos;
std::vector<llama_token> input_tokens(n_tokens);
n_tokens = llama_tokenize(model,
int input1_len = prompt.length();
std::vector<llama_token> input_tokens(input1_len);
int n_tokens = llama_tokenize(model,
prompt.data(),
prompt.length(),
input_tokens.data(),
Expand All @@ -87,17 +87,17 @@ int main(int argc, char** argv) {
fprintf(stderr, "\n");
fprintf(stdout, "n_tokens: %d\n", n_tokens);

int n_tokens2 = prompt2.length() + add_bos;
std::vector<llama_token> input_tokens2(n_tokens2);
n_tokens = llama_tokenize(model,
int input2_len = prompt2.length() + add_bos;
std::vector<llama_token> input_tokens2(input2_len);
int n_tokens2 = llama_tokenize(model,
prompt2.data(),
prompt2.length(),
input_tokens2.data(),
input_tokens2.size(),
true,
false);
if (n_tokens2 < 0) {
input_tokens2.resize(-n_tokens);
input_tokens2.resize(-n_tokens2);
int new_len = llama_tokenize(model, prompt2.data(), prompt2.length(), input_tokens2.data(), input_tokens2.size(), add_bos, false);
} else {
input_tokens2.resize(n_tokens2);
Expand All @@ -115,8 +115,9 @@ int main(int argc, char** argv) {
batch.n_tokens++;
}

int pos = batch.n_tokens;
for (int i = 0; i < n_tokens2; i++) {
int idx = n_tokens + i;
int idx = pos + i;
batch.token[idx] = input_tokens2[i];
batch.pos[idx] = idx,
batch.n_seq_id[idx] = 1;
Expand Down Expand Up @@ -182,7 +183,7 @@ int main(int argc, char** argv) {
single_token_batch.token[0] = new_token_id; // the new token id.
single_token_batch.pos[0] = n_cur, // the position in the sequence.
single_token_batch.n_seq_id[0] = 1; // the number of sequences for this token.
single_token_batch.seq_id[0][0] = 0; // the actual sequence id.
single_token_batch.seq_id[0][0] = 1; // the actual sequence id.
single_token_batch.logits[0] = true;
n_batch_tokens = single_token_batch.n_tokens;

Expand Down

0 comments on commit 74ba0b1

Please sign in to comment.