Commit e38b7c6

graph : support cacheless embeddings with FA and iSWA (ggml-org#16528)
* graph : support cacheless embeddings with FA and iSWA
* cont : deduplicate mask creation
* cont : fix name
1 parent 5016b72 commit e38b7c6
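
In short: the cacheless (no KV cache) attention path now builds two KQ masks from one shared fill routine, a dense causal mask for the full-attention layers and a sliding-window (SWA) mask for the SWA layers, so interleaved-SWA (iSWA) embedding models such as LLM_ARCH_GEMMA_EMBEDDING can run without a cache; the flash-attention branch also drops its n_kv % 256 padding guard. The standalone sketch below illustrates the masking idea only and is not llama.cpp code: it assumes a single sequence with positions 0..n_tokens-1, and its is_masked_swa() helper is a simplified stand-in for llama_hparams::is_masked_swa() that models only a standard trailing window of size n_swa.

// Toy sketch of the dual-mask idea (assumptions noted above; not llama.cpp code).
#include <cmath>
#include <cstdio>
#include <vector>

// simplified stand-in for llama_hparams::is_masked_swa(): standard window only
static bool is_masked_swa(int n_swa, int p0, int p1) {
    return n_swa > 0 && (p1 - p0) >= n_swa;
}

int main() {
    const int n_tokens = 8; // cacheless: n_kv == n_tokens
    const int n_swa    = 4; // hypothetical window size

    // shared fill routine, analogous to the fill_mask lambda added in this commit
    auto fill_mask = [&](std::vector<float> & mask, int n_swa_cur) {
        mask.assign(n_tokens*n_tokens, -INFINITY);
        for (int i1 = 0; i1 < n_tokens; ++i1) {     // query token (row)
            for (int i0 = 0; i0 < n_tokens; ++i0) { // key token (column)
                if (i0 > i1) {
                    continue; // causal: no future tokens
                }
                if (is_masked_swa(n_swa_cur, i0, i1)) {
                    continue; // outside the sliding window
                }
                mask[i1*n_tokens + i0] = 0.0f;      // can attend
            }
        }
    };

    std::vector<float> mask_dense, mask_swa;
    fill_mask(mask_dense, 0);   // used by full-attention layers
    fill_mask(mask_swa, n_swa); // used by the SWA layers of an iSWA model

    // print the SWA mask: '0' = can attend, 'x' = masked
    for (int i1 = 0; i1 < n_tokens; ++i1) {
        for (int i0 = 0; i0 < n_tokens; ++i0) {
            printf("%c", mask_swa[i1*n_tokens + i0] == 0.0f ? '0' : 'x');
        }
        printf("\n");
    }
    return 0;
}

With n_swa = 4, the printed row for query token 6 is xxx0000x: it can attend only to keys 3..6, while the dense mask would allow keys 0..6.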

4 files changed: +87, -51 lines changed

src/llama-graph.cpp

Lines changed: 74 additions & 42 deletions
@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
-                                (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
-                                (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
-                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    const char * swa_type_str = "unknown";
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
+        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
+        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
+
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-    float * data = (float *) kq_mask->data;
-
-    // [TAG_NO_CACHE_ISWA]
-    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const llama_pos    p1 = ubatch->pos[i1];
 
-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
 
-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
-
-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                    const llama_pos    p0 = ubatch->pos[i0];
 
+                    // mask different sequences
                     if (s0 != s1) {
-                        continue; // skip different sequences
+                        continue;
                     }
 
-                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
-                        continue; // skip future tokens for causal attention
+                    // mask future tokens
+                    if (cparams.causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
-                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    //    continue; // skip masked tokens for SWA
-                    //}
-
-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (hparams.use_alibi) {
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
+                    // apply SWA if any
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        continue;
                     }
+
+                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
-                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
+    };
+
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+        float * data = (float *) self_kq_mask->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+        }
     }
-    if (debug) {
-        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+        float * data = (float *) self_kq_mask_swa->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+        fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
     }
 }
 
@@ -1299,12 +1321,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_kv = k->ne[1];
-
     ggml_tensor * cur;
 
     // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+    if (cparams.flash_attn && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
@@ -1419,10 +1439,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->kq_mask);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->self_kq_mask);
+
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa     = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
@@ -1447,7 +1477,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto & kq_mask = inp->get_kq_mask();
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
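
The last hunk above is the per-layer wiring: at graph-build time each layer asks hparams.is_swa(il) and attaches either the SWA mask or the dense mask to its attention node. The snippet below only illustrates that selection; the "one full-attention layer every 6" pattern is a hypothetical iSWA layout, not something read from any real model's hyperparameters.

// Illustration of per-layer mask selection (hypothetical layer pattern).
#include <cstdio>

struct toy_hparams {
    int n_layer    = 12;
    int swa_period = 6; // hypothetical: every 6th layer uses full attention

    bool is_swa(int il) const { return (il + 1) % swa_period != 0; }
};

int main() {
    toy_hparams hparams;

    for (int il = 0; il < hparams.n_layer; ++il) {
        // mirrors: kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
        const bool is_swa = hparams.is_swa(il);
        printf("layer %2d -> %s\n", il, is_swa ? "SWA mask" : "dense mask");
    }
    return 0;
}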

src/llama-graph.h

Lines changed: 7 additions & 3 deletions
@@ -257,10 +257,14 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1]
-    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]
+    // n_tokens == n_batch
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams hparams;
     const llama_cparams cparams;

src/llama-model.cpp

Lines changed: 5 additions & 6 deletions
@@ -11358,8 +11358,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     }
 };
 
-struct llm_build_gemma_embedding_iswa : public llm_graph_context {
-    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+struct llm_build_gemma_embedding : public llm_graph_context {
+    llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -11376,8 +11376,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
-        auto * inp_attn = build_attn_inp_kv_iswa();
+        auto * inp_attn = build_attn_inp_no_cache();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -19378,7 +19377,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:
@@ -19671,7 +19670,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            } break;
        case LLM_ARCH_GEMMA_EMBEDDING:
            {
-               llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+               llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
            } break;
        case LLM_ARCH_STARCODER2:
            {

src/llama.cpp

Lines changed: 1 addition & 0 deletions
@@ -312,6 +312,7 @@ struct llama_model * llama_model_load_from_splits(
         LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
         return nullptr;
     }
+    splits.reserve(n_paths);
     for (size_t i = 0; i < n_paths; ++i) {
         splits.push_back(paths[i]);
     }
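
The src/llama.cpp hunk is a small cleanup, independent of the mask changes: the vector's capacity is reserved before the push_back loop, so the splits list is built with a single allocation. A generic sketch of the same pattern, with nothing llama.cpp-specific:

// Generic reserve-before-push_back pattern.
#include <cstddef>
#include <string>
#include <vector>

std::vector<std::string> collect_paths(const char * const * paths, std::size_t n_paths) {
    std::vector<std::string> splits;
    splits.reserve(n_paths);              // allocate once up front
    for (std::size_t i = 0; i < n_paths; ++i) {
        splits.push_back(paths[i]);       // no reallocation while the loop runs
    }
    return splits;
}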
