llama.cpp: stash changes (wip)
danbev committed Oct 20, 2024
1 parent f89a549 commit 61615c1
Showing 1 changed file with 41 additions and 24 deletions.
fundamentals/llama.cpp/src/batch.cpp
@@ -39,17 +39,20 @@ struct token_position {
     size_t index;
     token_position() : seq_id(0), index(0) {}
     token_position(size_t s, size_t i) : seq_id(s), index(i) {}
+
+    std::string to_string() const {
+        return "{ seq_id: " + std::to_string(seq_id) + ", index: " + std::to_string(index) + " }";
+    }
 };
 
-std::vector<std::pair<llama_token, std::vector<token_position>>> find_common_tokens(
+std::unordered_map<llama_token, std::vector<token_position>> find_common_tokens(
     const std::vector<std::vector<llama_token>>& input_tokens,
     llama_model* model) {
     if (input_tokens.empty()) {
         return {};
     }
 
     std::unordered_map<llama_token, std::unordered_map<size_t, token_position>> token_positions;
 
     for (size_t seq_id = 0; seq_id < input_tokens.size(); ++seq_id) {
         const auto& current_vec = input_tokens[seq_id];
         for (size_t token_idx = 0; token_idx < current_vec.size(); ++token_idx) {
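
The to_string added here only feeds the debug printf that create_batch gains further down. A quick check of the format it produces, assuming the struct as shown:

    token_position tp(1, 4);
    printf("%s\n", tp.to_string().c_str()); // prints: { seq_id: 1, index: 4 }
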
@@ -60,32 +63,25 @@ std::vector<std::pair<llama_token, std::vector<token_position>>> find_common_tokens(
         }
     }
 
-    std::vector<std::pair<llama_token, std::vector<token_position>>> common_tokens;
-
+    std::unordered_map<llama_token, std::vector<token_position>> common_tokens;
     for (const auto& entry : token_positions) {
         if (llama_add_bos_token(model) && entry.first == 1) {
             continue;
         }
 
         if (entry.second.size() > 1) {
             std::vector<token_position> positions;
             positions.reserve(entry.second.size());
             for (const auto& seq_pos : entry.second) {
                 positions.push_back(seq_pos.second);
             }
-            common_tokens.emplace_back(entry.first, std::move(positions));
+            common_tokens[entry.first] = std::move(positions);
         }
     }
 
-    std::sort(common_tokens.begin(), common_tokens.end(),
-        [](const std::pair<llama_token, std::vector<token_position>>& a,
-           const std::pair<llama_token, std::vector<token_position>>& b) {
-            return a.first < b.first;
-        });
-
     return common_tokens;
 }
 
-void print_common_tokens(const std::vector<std::pair<llama_token, std::vector<token_position>>>& common_tokens) {
+void print_common_tokens(std::unordered_map<llama_token, std::vector<token_position>> common_tokens) {
     for (const auto& token_info : common_tokens) {
         printf("Token id [%d] in common at positions:\n", token_info.first);
         for (const auto& pos : token_info.second) {
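
As a minimal standalone sketch of what the reworked find_common_tokens computes (plain int standing in for llama_token, the llama_add_bos_token check dropped since it needs a model, and first-occurrence-per-sequence assumed, since the insertion code is collapsed out of this diff):

#include <cstddef>
#include <cstdio>
#include <unordered_map>
#include <vector>

struct token_position { size_t seq_id; size_t index; };

int main() {
    std::vector<std::vector<int>> seqs = {{5, 7, 9}, {7, 2, 5, 7}};

    // token -> (seq_id -> one recorded position in that sequence)
    std::unordered_map<int, std::unordered_map<size_t, token_position>> token_positions;
    for (size_t s = 0; s < seqs.size(); ++s) {
        for (size_t i = 0; i < seqs[s].size(); ++i) {
            int t = seqs[s][i];
            if (token_positions[t].count(s) == 0) {
                token_positions[t][s] = token_position{s, i};
            }
        }
    }

    // a token is "common" when it occurs in more than one distinct sequence
    std::unordered_map<int, std::vector<token_position>> common;
    for (const auto& e : token_positions) {
        if (e.second.size() > 1) {
            for (const auto& sp : e.second) {
                common[e.first].push_back(sp.second);
            }
        }
    }

    for (const auto& e : common) {
        for (const auto& p : e.second) {
            printf("token %d: { seq_id: %zu, index: %zu }\n", e.first, p.seq_id, p.index);
        }
    }
    return 0;
}

With these two toy sequences, tokens 5 and 7 come out as common; the repeat of 7 at index 3 of the second sequence is not recorded again. Note that dropping the old std::sort leaves iteration order over the result unspecified, which only affects the debug printout, not the lookups create_batch does below.
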
@@ -108,19 +104,35 @@ llama_batch create_batch(int size, std::vector<std::vector<llama_token>> input_tokens, llama_model* model) {
     // Create a single batch for both prompts.
     llama_batch batch = llama_batch_init(size, 0, n_prompts);
 
-    for (size_t p = 0; p < input_tokens.size(); p++) {
-        printf("Processing prompt %ld, batch_n_tokens: %d \n", p, batch.n_tokens);
-        std::vector<llama_token> prompt_tokens = input_tokens[p];
+    for (size_t s = 0; s < input_tokens.size(); s++) {
+        printf("Processing prompt %ld, batch_n_tokens: %d \n", s, batch.n_tokens);
+        std::vector<llama_token> prompt_tokens = input_tokens[s];
         for (size_t i = 0; i < prompt_tokens.size(); i++) {
+            int token_id = prompt_tokens[i];
             int idx = batch.n_tokens;
-            batch.token[idx] = prompt_tokens[i];
-            batch.pos[idx] = i,
-            batch.n_seq_id[idx] = 1;
-            batch.seq_id[idx][0] = p; // the sequence id
+            batch.token[idx] = token_id;
+            batch.pos[idx] = i;
+
+            auto it = common_tokens.find(token_id);
+            if (it != common_tokens.end()) {
+                std::vector<token_position> tps = it->second;
+                batch.n_seq_id[idx] = tps.size();
+                batch.seq_id[idx][0] = s;
+                for (size_t j = 1; j < tps.size(); j++) {
+                    auto tp = tps[j];
+                    if (tp.seq_id != s) {
+                        printf("seq: %zu, token %d: %s \n", s, token_id, tps[j].to_string().c_str());
+                        batch.seq_id[idx][j] = tps[j].seq_id;
+                    }
+                }
+            } else {
+                batch.n_seq_id[idx] = 1;
+                batch.seq_id[idx][0] = s; // the sequence id
+            }
             batch.logits[idx] = i == prompt_tokens.size() - 1;
 
             batch.n_tokens++;
-            printf("idx: %4d, token: %6d, seq_id: %ld, logits: %d\n", idx, prompt_tokens[i], p, batch.logits[idx]);
+            printf("idx: %4d, token: %6d, seq_id: %ld, logits: %d\n", idx, token_id, s, batch.logits[idx]);
         }
     }
     return batch;
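
The point of the new branch above is that a token shared across prompts can carry several sequence ids on a single batch entry, instead of being added once per prompt. For a token present in sequences 0 and 1, the intended end state is something like this (a hand-written illustration against the llama.h of this commit, where llama_batch_init(n_tokens, embd, n_seq_max) has this shape; the concrete values are made up, not taken from a run):

    llama_batch batch = llama_batch_init(512, 0, 2);
    int idx = batch.n_tokens;
    batch.token[idx]     = 42;     // a token id shared by both prompts (hypothetical)
    batch.pos[idx]       = 3;      // its position within the sequence
    batch.n_seq_id[idx]  = 2;      // the entry belongs to two sequences
    batch.seq_id[idx][0] = 0;      // the current sequence
    batch.seq_id[idx][1] = 1;      // the other sequence sharing this token
    batch.logits[idx]    = false;  // logits are only requested for final tokens
    batch.n_tokens++;

One wip caveat in the loop as committed: slot 0 is overwritten with s even when tps[0] belongs to another sequence, and slot j is skipped entirely when tps[j].seq_id == s, so n_seq_id can claim more slots than were actually written.
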
@@ -138,14 +150,17 @@ void print_batch(llama_batch batch) {
 int main(int argc, char** argv) {
     fprintf(stdout, "llama.cpp batch exploration\n");
     llama_model_params model_params = llama_model_default_params();
-    std::string model_path = "models/llama-2-7b.Q4_K_M.gguf";
+    //std::string model_path = "models/llama-2-7b.Q4_K_M.gguf";
+    std::string model_path = "models/mamba-1.4b-f16.gguf";
 
     model_params.main_gpu = 0;
     model_params.n_gpu_layers = 0;
 
     // This prompt is 69 tokens
-    std::string prompt1 = R"(You are an AI assistant specializing in task completion. Your goal is to provide clear, concise, and accurate responses to user queries. Always maintain a helpful and professional tone. If a request is unclear, ask for clarification. Prioritize user safety and ethical considerations in your answers.)";
-    std::string prompt2 = "What is the day following Thursday?";
+    //std::string prompt1 = R"(You are an AI assistant specializing in task completion. Your goal is to provide clear, concise, and accurate responses to user queries. Always maintain a helpful and professional tone. If a request is unclear, ask for clarification. Prioritize user safety and ethical considerations in your answers.)";
+    std::string prompt1 = "Yesterday was Friday";
+    //std::string prompt2 = "How many r's are there in strawberry?";
+    std::string prompt2 = "Tomorrow was Monday";
 
     llama_backend_init();
     llama_numa_init(GGML_NUMA_STRATEGY_DISABLED);
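
The tokenization of prompt1 and prompt2 sits in a collapsed part of the diff. For reference, input_tokens1/input_tokens2 would typically be produced like this against the llama.cpp API of this period (a sketch; the helper name and the exact flag choices are assumptions, not the commit's actual code):

    static std::vector<llama_token> tokenize(llama_model* model, const std::string& text) {
        // worst case is one token per byte, plus one slot for a leading BOS token
        std::vector<llama_token> tokens(text.size() + 1);
        int n = llama_tokenize(model, text.c_str(), (int) text.size(),
                               tokens.data(), (int) tokens.size(),
                               /*add_special=*/ true, /*parse_special=*/ false);
        tokens.resize(n < 0 ? 0 : n);
        return tokens;
    }

    std::vector<llama_token> input_tokens1 = tokenize(model, prompt1);
    std::vector<llama_token> input_tokens2 = tokenize(model, prompt2);
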
@@ -181,6 +196,7 @@ int main(int argc, char** argv) {
     llama_batch batch = create_batch(512, {input_tokens1, input_tokens2}, model);
     print_batch(batch);
 
+    /*
     if (llama_decode(ctx, batch) != 0) {
         fprintf(stderr, "llama_decode() failed\n");
         return 1;
@@ -197,6 +213,7 @@ int main(int argc, char** argv) {
     for (int i = embd_size - 10; i < embd_size; i++) {
         fprintf(stderr, "logits2[%d]: %f\n", i, logits2[i]);
     }
+    */
 
     llama_batch_free(batch);
     llama_free(ctx);
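
The decode path is fenced off with the new /* ... */ while the batch layout is worked out. When re-enabled, the usual pattern is to decode once and then read logits only at indices where batch.logits was set (a sketch with standard llama.h calls; logits2 and embd_size are defined in the collapsed region, so the names below are assumptions):

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode() failed\n");
        return 1;
    }

    // the last token of the batch had batch.logits[idx] set in create_batch
    float* logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
    int n_vocab = llama_n_vocab(model);
    for (int i = n_vocab - 10; i < n_vocab; i++) {
        fprintf(stderr, "logits[%d]: %f\n", i, logits[i]);
    }
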