@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
-                          (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
-                          (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
-                          (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    const char * swa_type_str = "unknown";
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
+        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
+        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
+
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
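
Note on the debug output above: going by the legend in the log messages (rows are query tokens, columns are key/value tokens, '0' = can attend, '∞' = masked), a single-sequence causal batch of 4 tokens without SWA should print a lower-triangular pattern along these lines. This is an illustrative sketch, not verbatim output of `print_mask`:

```cpp
// Illustration only (assumed formatting): 4 query rows x 4 key/value columns,
// causal attention, no sliding window.
// 0 ∞ ∞ ∞
// 0 0 ∞ ∞
// 0 0 0 ∞
// 0 0 0 0
```
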
@@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-    float * data = (float *) kq_mask->data;
-
-    // [TAG_NO_CACHE_ISWA]
-    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const llama_pos    p1 = ubatch->pos[i1];
 
-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
 
-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
-
-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                    const llama_pos p0    = ubatch->pos[i0];
 
+                    // mask different sequences
                     if (s0 != s1) {
-                        continue; // skip different sequences
+                        continue;
                     }
 
-                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
-                        continue; // skip future tokens for causal attention
+                    // mask future tokens
+                    if (cparams.causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
-                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    //    continue; // skip masked tokens for SWA
-                    //}
-
-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (hparams.use_alibi) {
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
+                    // apply SWA if any
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        continue;
                     }
+
+                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
-                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
+    };
+
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+        float * data = (float *) self_kq_mask->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+        }
     }
-    if (debug) {
-        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+        float * data = (float *) self_kq_mask_swa->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+        fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
     }
 }
 
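
To make the new mask-building path concrete, here is a standalone sketch of what `fill_mask` computes for one head slice. It assumes a single sequence (so the seq_id check is omitted), and `is_masked_swa_standard`/`build_mask` are hypothetical helpers written for this note only; the real predicate is `llama_hparams::is_masked_swa`, which also covers the chunked and symmetric window types, and the real code writes into the padded `self_kq_mask`/`self_kq_mask_swa` buffers after pre-filling them with `-INFINITY`.

```cpp
// Standalone sketch, not part of the patch: approximate semantics of the dense,
// cacheless attention mask for a single head slice and a single sequence.
#include <cmath>
#include <cstdint>
#include <vector>

using llama_pos_t = int32_t; // stand-in for llama_pos

// Assumed LLAMA_SWA_TYPE_STANDARD rule: a key at p0 is masked for a query at p1
// when it lies n_swa or more positions behind the query.
static bool is_masked_swa_standard(int32_t n_swa, llama_pos_t p0, llama_pos_t p1) {
    return n_swa > 0 && (p1 - p0) >= n_swa;
}

// Row i1 = query token, column i0 = key/value token; invisible pairs stay -INFINITY.
static std::vector<float> build_mask(const std::vector<llama_pos_t> & pos,
                                     int32_t n_swa, bool causal, bool use_alibi) {
    const size_t n = pos.size(); // n_tokens == n_kv in the cacheless case
    std::vector<float> mask(n*n, -INFINITY);

    for (size_t i1 = 0; i1 < n; ++i1) {
        for (size_t i0 = 0; i0 < n; ++i0) {
            if (causal && pos[i0] > pos[i1]) {
                continue; // future token
            }
            if (is_masked_swa_standard(n_swa, pos[i0], pos[i1])) {
                continue; // outside the sliding window
            }
            // 0.0f lets the token attend; with ALiBi the bias is the negative distance
            mask[i1*n + i0] = use_alibi ? -std::abs(float(pos[i0] - pos[i1])) : 0.0f;
        }
    }
    return mask;
}
```
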
@@ -1299,12 +1321,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_kv = k->ne[1];
-
     ggml_tensor * cur;
 
     // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+    if (cparams.flash_attn && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
@@ -1419,10 +1439,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->kq_mask);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->self_kq_mask);
+
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa     = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
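
Both mask tensors are created with the row dimension rounded up via `GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)` and are cast to F16 when flash attention is enabled. The sketch below shows the usual round-up-to-a-multiple idiom that this kind of padding macro implements; `PAD_TO` and the pad value 64 are assumptions standing in for `GGML_PAD` and `GGML_KQ_MASK_PAD`, whose actual definitions live in the ggml/llama.cpp headers.

```cpp
// Sketch of the mask row padding. PAD_TO is a hypothetical re-implementation of the
// round-up idiom; 64 is an assumed stand-in for GGML_KQ_MASK_PAD.
#include <cstdio>

#define PAD_TO(x, n) (((x) + (n) - 1) & ~((n) - 1)) // n must be a power of two

int main() {
    const int n_tokens = 100;
    const int n_pad    = 64;
    // mask shape in the cacheless case: [n_kv = n_tokens, PAD_TO(n_tokens, n_pad), 1, 1]
    std::printf("query rows padded from %d to %d\n", n_tokens, PAD_TO(n_tokens, n_pad));
    return 0;
}
```
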
@@ -1447,7 +1477,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto & kq_mask = inp->get_kq_mask();
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
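
For iSWA models this gives each layer the mask that matches its attention type: layers for which `hparams.is_swa(il)` is true read the windowed mask, every other layer reads the dense one, and a model without SWA never allocates the second mask at all (`self_kq_mask_swa` stays `nullptr`). A minimal sketch of that selection, with hypothetical types:

```cpp
// Hypothetical sketch of the per-layer mask selection done in build_attn.
struct attn_mask_inputs {
    const float * kq_mask     = nullptr; // dense mask, always present
    const float * kq_mask_swa = nullptr; // windowed mask, only when the model uses SWA
};

static const float * select_kq_mask(const attn_mask_inputs & inp, bool layer_is_swa) {
    // SWA layers read the windowed mask; non-SWA layers read the dense one.
    return layer_is_swa ? inp.kq_mask_swa : inp.kq_mask;
}
```
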