llama: Ensure KV cache is fully defragmented.
Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding
will not be able to find a KV cache slot. This is particularly
difficult for the caller to handle if it happens between
ubatches. To avoid this, we should immediately trigger a defrag.

In addition, a heavily fragmented cache can require more than
max_moves moves to defragment. Currently, we stop when we hit the
limit, but this can leave a cache that still does not have adequate
space even after defragmentation is triggered. Instead, we should
do multiple batches of processing until everything is complete.
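
The first change is easiest to see end to end. Below is a minimal, self-contained analogue of the new decode path: if no contiguous slot is found for the ubatch, defragment immediately and retry once. The toy_kv_cache, find_slot, defrag, and decode_step names are simplified stand-ins for illustration, not the llama.cpp API (the real code uses llama_kv_cache_find_slot, llama_kv_cache_defrag, and llama_kv_cache_update, as shown in the diff below).

#include <cstdint>
#include <optional>
#include <vector>

struct toy_kv_cache {
    std::vector<bool> cells; // one flag per KV cell, true = occupied
};

// Look for a contiguous run of n free cells; return its start index.
std::optional<uint32_t> find_slot(const toy_kv_cache & kv, uint32_t n) {
    uint32_t run = 0;
    for (uint32_t i = 0; i < kv.cells.size(); ++i) {
        run = kv.cells[i] ? 0 : run + 1;
        if (run == n) {
            return i + 1 - n;
        }
    }
    return std::nullopt;
}

// Compact all occupied cells to the front, leaving one contiguous free tail.
void defrag(toy_kv_cache & kv) {
    uint32_t dst = 0;
    for (uint32_t i = 0; i < kv.cells.size(); ++i) {
        if (kv.cells[i]) {
            kv.cells[i] = false;
            kv.cells[dst++] = true;
        }
    }
}

// Decode-time behavior after this commit: retry once after a forced defrag.
int decode_step(toy_kv_cache & kv, uint32_t n_tokens) {
    auto slot = find_slot(kv, n_tokens);
    if (!slot) {
        defrag(kv);
        slot = find_slot(kv, n_tokens);
    }
    if (!slot) {
        return 1; // still no room: the cache is genuinely full
    }
    // ... write the ubatch into cells [*slot, *slot + n_tokens) ...
    return 0;
}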
jessegross committed Dec 17, 2024
1 parent 081b29b commit a2d4b6f
Showing 1 changed file with 46 additions and 53 deletions.
99 changes: 46 additions & 53 deletions src/llama.cpp
@@ -2955,6 +2955,13 @@ struct llama_kv_cache {
}
};

// block of KV slots to move when defragging
struct llama_kv_defrag_move {
uint32_t src;
uint32_t dst;
uint32_t len;
};
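
The new llama_kv_defrag_move describes one contiguous block copy: len cells starting at src are moved to dst. As a rough standalone illustration (operating on plain ints rather than per-layer K/V tensors), applying a list of moves looks like this; apply_moves is a hypothetical helper for this sketch, not part of llama.cpp:

#include <algorithm>
#include <cstdint>
#include <vector>

struct kv_defrag_move {   // mirrors llama_kv_defrag_move
    uint32_t src;
    uint32_t dst;
    uint32_t len;
};

// Apply moves to a flat array of cells. In llama.cpp the same moves are
// applied per layer to the K and V tensors via ggml views + copies.
void apply_moves(std::vector<int> & cells, const std::vector<kv_defrag_move> & moves) {
    for (const auto & move : moves) {
        std::copy_n(cells.begin() + move.src, move.len, cells.begin() + move.dst);
    }
}

// Example: cells 8..10 are compacted down to positions 2..4:
// apply_moves(cells, {{8, 2, 3}});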

struct llama_control_vector {
std::vector<struct ggml_tensor *> tensors; // per layer
std::vector<ggml_context_ptr> ctxs;
@@ -10652,67 +10659,53 @@ struct llm_build_context {
return gf;
}

struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

for (uint32_t i = 0; i < ids.size(); ++i) {
const uint32_t id = ids[i];

if (i == id || id == ids.size()) {
continue;
}

uint32_t nm = 1;

while (i + nm < ids.size() && ids[i + nm] == id + nm) {
nm++;
}

for (const auto & move : moves) {
for (int il = 0; il < n_layer; ++il) {
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
n_embd_k_gqa, nm,
n_embd_k_gqa, move.len,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));

ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
n_embd_k_gqa, nm,
n_embd_k_gqa, move.len,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));

ggml_tensor * view_v_src;
ggml_tensor * view_v_dst;

if (flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm,
n_embd_v_gqa, move.len,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));

view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm,
n_embd_v_gqa, move.len,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
} else {
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa,
move.len, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, i));
ggml_row_size(kv_self.v_l[il]->type, move.src));

view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa,
move.len, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, id));
ggml_row_size(kv_self.v_l[il]->type, move.dst));
}

ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
}

i += nm - 1;
}

//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -16944,7 +16937,7 @@ struct llm_build_context {
}
};

static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
llama_ubatch dummy = {};
dummy.equal_seqs = true;

@@ -16954,7 +16947,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const

llm.init();

struct ggml_cgraph * result = llm.build_defrag(ids);
struct ggml_cgraph * result = llm.build_defrag(moves);

llm.free();

@@ -17957,7 +17950,12 @@ static int llama_decode_internal(
kv_self.head = 0;
}

const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
llama_kv_cache_defrag(kv_self);
llama_kv_cache_update(&lctx);
slot = llama_kv_cache_find_slot(kv_self, ubatch);
}
if (!slot) {
return 1;
}
@@ -18359,8 +18357,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {

//const int64_t t_start = ggml_time_us();

// number of cells moved
uint32_t n_moves = 0;
// groups of cells moved
std::vector<struct llama_kv_defrag_move> moves;

// each move requires 6*n_layer tensors (see build_defrag)
// - source view, destination view, copy operation
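
Each move expands into roughly 6*n_layer graph nodes (source view, destination view, and copy, for both the K and V tensors in every layer), so the number of moves in one defrag graph is capped by max_moves. A back-of-the-envelope check, using an assumed node budget for illustration rather than the exact constant llama.cpp derives:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_layer   = 32;    // e.g. a 7B-class model
    const uint32_t max_nodes = 8192;  // assumed graph node budget for illustration

    // ~6 tensors per move per layer: K view src/dst + copy, V view src/dst + copy
    const uint32_t nodes_per_move = 6 * n_layer;
    const uint32_t max_moves      = max_nodes / nodes_per_move;

    std::printf("nodes per move: %u, max moves per graph: %u\n", nodes_per_move, max_moves);
    // With these numbers: 192 nodes per move, 42 moves per graph. A heavily
    // fragmented cache can need more moves than that, hence the chunked loop below.
    return 0;
}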
@@ -18424,19 +18422,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// are we moving a continuous block of memory?
bool cont = false;

// should we stop searching for the next move?
bool stop = false;

// go back and move the nf cells to the hole
for (; i1 < n_kv; ++i1) {
auto & cell1 = kv_self.cells[i1];

if (cell1.is_empty() || ids[i1] != n_kv) {
if (n_moves == max_moves) {
stop = true;
break;
}

cont = false;
continue;
}
@@ -18452,8 +18442,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
kv_self.head = n_used;

if (!cont) {
n_moves++;
moves.push_back({i1, i0 + nf, 1});
cont = true;
} else {
moves.back().len++;
}

nf++;
@@ -18463,22 +18455,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
}
}

if (stop || n_moves == max_moves) {
break;
}

//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);

i0 += nh - 1;
}
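
The bookkeeping above replaces the old n_moves counter: each relocated cell either extends the previous move (when both source and destination remain contiguous) or starts a new one. A standalone sketch of that run-length coalescing, given a precomputed ids[] array mapping each source cell to its destination (with n_kv meaning "does not move"); coalesce is a hypothetical helper that mirrors the logic, not the actual function:

#include <cstdint>
#include <vector>

struct kv_defrag_move {
    uint32_t src;
    uint32_t dst;
    uint32_t len;
};

// Coalesce per-cell relocations into contiguous moves, in the spirit of
// llama_kv_cache_defrag_internal (simplified: no hole/head bookkeeping).
std::vector<kv_defrag_move> coalesce(const std::vector<uint32_t> & ids, uint32_t n_kv) {
    std::vector<kv_defrag_move> moves;
    for (uint32_t i = 0; i < ids.size(); ++i) {
        const uint32_t dst = ids[i];
        if (dst == n_kv || dst == i) {
            continue; // cell is not moving
        }
        if (!moves.empty() && moves.back().src + moves.back().len == i
                           && moves.back().dst + moves.back().len == dst) {
            moves.back().len++; // extends the previous contiguous run
        } else {
            moves.push_back({i, dst, 1});
        }
    }
    return moves;
}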

if (n_moves == 0) {
if (moves.size() == 0) {
return;
}

//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);

//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());

#if 0
// CPU defrag
@@ -18553,11 +18539,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
#else
// ggml_graph defrag

ggml_backend_sched_reset(lctx.sched.get());
for (std::size_t i = 0; i < moves.size(); i += max_moves) {
std::vector<struct llama_kv_defrag_move> chunk;
auto end = std::min(i + max_moves, moves.size());
chunk.assign(moves.begin() + i, moves.begin() + end);

ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
ggml_backend_sched_reset(lctx.sched.get());

//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);

llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
}
#endif

//const int64_t t_end = ggml_time_us();
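
The final hunk replaces the single defrag graph with a loop over chunks of at most max_moves moves, resetting the scheduler and computing one graph per chunk, so even a heavily fragmented cache is eventually fully compacted. The chunking itself is the ordinary split-into-batches idiom; a standalone sketch, with process_chunk as a hypothetical stand-in for building and computing one defrag graph:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

struct kv_defrag_move { unsigned src, dst, len; };

// Stand-in for "reset the scheduler, build one defrag graph, compute it".
void process_chunk(const std::vector<kv_defrag_move> & chunk) {
    std::printf("defrag graph with %zu moves\n", chunk.size());
}

int main() {
    std::vector<kv_defrag_move> moves(100, {0, 0, 1}); // pretend 100 moves are queued
    const std::size_t max_moves = 42;                  // per-graph cap (see node budget above)

    for (std::size_t i = 0; i < moves.size(); i += max_moves) {
        const std::size_t end = std::min(i + max_moves, moves.size());
        std::vector<kv_defrag_move> chunk(moves.begin() + i, moves.begin() + end);
        process_chunk(chunk); // 42 + 42 + 16 moves across three graphs
    }
    return 0;
}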
