diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 1ce44a2b93..1df785a157 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -559,19 +559,20 @@ def _segment_ids_pos_to_seqlens_offsets( attn_mask = jnp.logical_and(segment_mask, causal_mask) # TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets - swa_mask = ( - make_swa_mask( - segment_pos_q, - segment_pos_kv, - window_size, - dtype=jnp.bool, - segment_ids_q=segment_ids_q, - segment_ids_kv=segment_ids_kv, - ) - if attn_mask_type.is_bottom_right() - else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) - ) - attn_mask = jnp.logical_and(attn_mask, swa_mask) + # swa_mask = ( + # make_swa_mask( + # segment_pos_q, + # segment_pos_kv, + # window_size, + # dtype=jnp.bool, + # segment_ids_q=segment_ids_q, + # segment_ids_kv=segment_ids_kv, + # ) + # if attn_mask_type.is_bottom_right() + # else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) + # ) + # swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) + # attn_mask = jnp.logical_and(attn_mask, swa_mask) attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(