llama: Ensure KV cache is fully defragmented. #10873

Open · wants to merge 1 commit into base: master
src/llama.cpp: 99 changes (46 additions, 53 deletions)
@@ -2955,6 +2955,13 @@ struct llama_kv_cache {
     }
 };
 
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+    uint32_t src;
+    uint32_t dst;
+    uint32_t len;
+};
+
 struct llama_control_vector {
     std::vector<struct ggml_tensor *> tensors; // per layer
     std::vector<ggml_context_ptr> ctxs;
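
The struct above replaces the flat `ids` remapping vector that `build_defrag` previously consumed directly. As a reference point, here is a minimal standalone sketch of how a per-cell remapping of that shape can be coalesced into contiguous `llama_kv_defrag_move` blocks, approximating what the updated `llama_kv_cache_defrag_internal` does while it scans the cells; the helper `coalesce_moves` is illustrative only, not llama.cpp code:

```cpp
#include <cstdint>
#include <vector>

struct llama_kv_defrag_move {
    uint32_t src;
    uint32_t dst;
    uint32_t len;
};

// Illustrative helper, not llama.cpp code: collapse a per-cell target map
// (ids[i] == i or ids[i] == ids.size() means "cell i does not move") into
// contiguous move blocks, one entry per run of adjacent source/target cells.
static std::vector<llama_kv_defrag_move> coalesce_moves(const std::vector<uint32_t> & ids) {
    std::vector<llama_kv_defrag_move> moves;
    for (uint32_t i = 0; i < (uint32_t) ids.size(); ++i) {
        const uint32_t id = ids[i];
        if (i == id || id == ids.size()) {
            continue; // cell stays in place or is unused
        }
        if (!moves.empty() &&
            moves.back().src + moves.back().len == i &&
            moves.back().dst + moves.back().len == id) {
            moves.back().len++; // extend the current contiguous block
        } else {
            moves.push_back({i, id, 1}); // start a new block
        }
    }
    return moves;
}
```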
@@ -10652,67 +10659,53 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
+    struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-        for (uint32_t i = 0; i < ids.size(); ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == ids.size()) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
+        for (const auto & move : moves) {
             for (int il = 0; il < n_layer; ++il) {
                 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
                 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
+                        n_embd_k_gqa, move.len,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
 
                 ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
+                        n_embd_k_gqa, move.len,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
 
                 ggml_tensor * view_v_src;
                 ggml_tensor * view_v_dst;
 
                 if (flash_attn) {
                     // NOTE: the V cache is not transposed when using flash attention
                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
+                            n_embd_v_gqa, move.len,
                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
 
                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
+                            n_embd_v_gqa, move.len,
                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
                 } else {
                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
+                            move.len, n_embd_v_gqa,
                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, i));
+                            ggml_row_size(kv_self.v_l[il]->type, move.src));
 
                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
+                            move.len, n_embd_v_gqa,
                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, id));
+                            ggml_row_size(kv_self.v_l[il]->type, move.dst));
                 }
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
             }
-
-            i += nm - 1;
         }
 
         //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -16944,7 +16937,7 @@ struct llm_build_context {
     }
 };
 
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
+static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
     llama_ubatch dummy = {};
     dummy.equal_seqs = true;
 
@@ -16954,7 +16947,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
 
     llm.init();
 
-    struct ggml_cgraph * result = llm.build_defrag(ids);
+    struct ggml_cgraph * result = llm.build_defrag(moves);
 
     llm.free();
 
@@ -17957,7 +17950,12 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            if (!slot) {
+                llama_kv_cache_defrag(kv_self);
+                llama_kv_cache_update(&lctx);
+                slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            }
             if (!slot) {
                 return 1;
             }
@@ -18359,8 +18357,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 
     //const int64_t t_start = ggml_time_us();
 
-    // number of cells moved
-    uint32_t n_moves = 0;
+    // groups of cells moved
+    std::vector<struct llama_kv_defrag_move> moves;
 
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
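
To put that 6*n_layer figure in perspective, here is a quick back-of-the-envelope check; the layer count and node budget below are assumed purely for illustration (the real budget comes from llama_model_max_nodes):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Assumed values, for illustration only.
    const uint32_t n_layer   = 32;   // hypothetical model depth
    const uint32_t max_nodes = 8192; // assumed graph node budget

    // Per the comment above: each move costs 6 tensors per layer
    // (source view, destination view and copy, for both K and V).
    const uint32_t max_moves = max_nodes / (6 * n_layer);

    std::printf("up to %u moves fit in one defrag graph\n", max_moves); // prints 42
    return 0;
}
```

With the old single-graph approach, any moves beyond that bound were left for a later defrag pass; the chunked loop at the end of this patch is what removes that limit.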
@@ -18424,19 +18422,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
         // are we moving a continuous block of memory?
         bool cont = false;
 
-        // should we stop searching for the next move?
-        bool stop = false;
-
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
             auto & cell1 = kv_self.cells[i1];
 
             if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
                 cont = false;
                 continue;
             }
@@ -18452,8 +18442,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
             kv_self.head = n_used;
 
             if (!cont) {
-                n_moves++;
+                moves.push_back({i1, i0 + nf, 1});
                 cont = true;
+            } else {
+                moves.back().len++;
             }
 
             nf++;
Expand All @@ -18463,22 +18455,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
}
}

if (stop || n_moves == max_moves) {
break;
}

//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);

i0 += nh - 1;
}

if (n_moves == 0) {
if (moves.size() == 0) {
return;
}

//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);

//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());

#if 0
// CPU defrag
@@ -18553,11 +18539,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 #else
     // ggml_graph defrag
 
-    ggml_backend_sched_reset(lctx.sched.get());
+    for (std::size_t i = 0; i < moves.size(); i += max_moves) {
+        std::vector<struct llama_kv_defrag_move> chunk;
+        auto end = std::min(i + max_moves, moves.size());
+        chunk.assign(moves.begin() + i, moves.begin() + end);
 
-    ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
+        ggml_backend_sched_reset(lctx.sched.get());
+
+        //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
+        ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
 
-    llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+        llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+    }
 #endif
 
     //const int64_t t_end = ggml_time_us();
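
The loop in the #else branch above boils down to a generic batching pattern: take the move list in slices of at most max_moves and build and compute one defrag graph per slice, so no pending move is ever dropped. A self-contained sketch of that pattern, assuming made-up helper and callback names that are not llama.cpp API:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative only: hand 'items' to 'process' in chunks of at most
// 'max_per_chunk' elements, mirroring how the patch builds and computes
// one defrag graph per batch of KV-cache moves.
template <typename T, typename Fn>
void process_in_chunks(const std::vector<T> & items, std::size_t max_per_chunk, Fn && process) {
    for (std::size_t i = 0; i < items.size(); i += max_per_chunk) {
        const std::size_t end = std::min(i + max_per_chunk, items.size());
        std::vector<T> chunk(items.begin() + i, items.begin() + end);
        process(chunk); // e.g. llama_build_graph_defrag(...) followed by llama_graph_compute(...)
    }
}
```

Bounding each chunk keeps every graph under the node budget, while iterating over all chunks is what ensures the cache ends up fully defragmented, which is the point of this PR.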