examples/embedding/embedding.cpp: 4 additions & 3 deletions
@@ -104,12 +104,16 @@ int main(int argc, char ** argv) {

     params.embedding = true;

+    // get max number of sequences per batch
+    const int n_seq_max = llama_max_parallel_sequences();
+
     // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
     // --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
     // in order to support any number of prompts
     if (params.n_parallel == 1) {
         LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
         params.kv_unified = true;
+        params.n_parallel = n_seq_max;
     }

     // utilize the full context
@@ -123,9 +127,6 @@
         params.n_ubatch = params.n_batch;
     }

-    // get max number of sequences per batch
-    const int n_seq_max = llama_max_parallel_sequences();
-
     llama_backend_init();
     llama_numa_init(params.numa);
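Usage note (an illustration, not part of the PR): if the input is known to contain, say, four prompts, passing `--parallel 4` to the embedding example sizes the KV cache for exactly four sequences, which is the more efficient path the comment above describes; leaving `--parallel` at its default of 1 now triggers the unified-cache fallback so that any number of prompts is still supported.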
src/llama-context.cpp: 2 additions & 2 deletions

@@ -299,7 +299,7 @@ llama_context::llama_context(

     cross.v_embd.clear();

-    const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+    const uint32_t n_seqs = cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

     // avoid reserving graphs with zero outputs - assume one output per sequence
@@ -542,7 +542,7 @@ bool llama_context::memory_update(bool optimize) {
         throw std::runtime_error("failed to initialize memory context");
     }

-    const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
Member Author: I can't recall why this `cparams.kv_unified` check was added in #14363. It seems unnecessary now, and removing it gives a better estimate of the worst-case graph.

+    const uint32_t n_seqs = cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
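To make the effect of this change concrete, here is a small self-contained sketch (not code from the PR; the struct and the numbers are hypothetical, only the field names follow the diff) comparing how many sequences the worst-case graph is reserved for before and after:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// hypothetical stand-in for the cparams fields referenced in the diff
struct cparams_t {
    bool     kv_unified;
    uint32_t n_seq_max;
    uint32_t n_ctx;
    uint32_t n_ubatch;
};

int main() {
    // example values: unified KV cache with up to 8 sequences
    const cparams_t cparams = { /*kv_unified=*/true, /*n_seq_max=*/8, /*n_ctx=*/4096, /*n_ubatch=*/512 };

    // before: a unified KV cache reserved the graph for a single sequence
    const uint32_t n_seqs_before = cparams.kv_unified ? 1 : cparams.n_seq_max;

    // after: always reserve for n_seq_max sequences, i.e. the worst case
    const uint32_t n_seqs_after = cparams.n_seq_max;

    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

    printf("reserve: before n_seqs=%u, after n_seqs=%u, n_tokens=%u\n",
           (unsigned) n_seqs_before, (unsigned) n_seqs_after, (unsigned) n_tokens);
    return 0;
}
```

With these example values the reservation grows from 1 sequence to 8, which is what "a better estimate of the worst-case graph" refers to in the comment above.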