327 changes: 327 additions & 0 deletions generate_startend_row_indices.py
@@ -1,6 +1,122 @@
import paddle
import numpy as np

def _scale_int_list(values, seqlen_k: int, base_seqlen_k: int = 8192, min_value: int = 1):
scale = seqlen_k / base_seqlen_k
return [max(min_value, int(v * scale)) for v in values]

def _default_doc_seqlens_int(seqlen_k: int):
    # Keep in sync with the defaults used by generate_causal_document_mask,
    # generate_document_mask, generate_share_question_mask, and generate_causal_blockwise_mask.
doc_seqlens = [2538, 1742, 3213]
if seqlen_k != 8192:
doc_seqlens = _scale_int_list(doc_seqlens, seqlen_k=seqlen_k)
return doc_seqlens

def _default_doc_seqlens_prefix_pairs(seqlen_k: int):
# Keep in sync with the defaults used by generate_prefix_lm_document_mask.
doc_seqlens = [(1024, 2538), (1742, 1742), (512, 3213)]
if seqlen_k != 8192:
scale = seqlen_k / 8192
doc_seqlens = [tuple(max(1, int(v * scale)) for v in pair) for pair in doc_seqlens]
return doc_seqlens

def _perturb_doc_seqlens_int(doc_seqlens, head_idx: int, seqlen_k: int):
"""
Deterministically perturb integer segment lengths to avoid identical per-head masks,
even when `doc_seqlens` is None.
"""
if not doc_seqlens:
return doc_seqlens
if len(doc_seqlens) == 1:
return [min(seqlen_k, max(1, int(doc_seqlens[0]) + ((head_idx % 3) - 1)))]

segs = [max(1, int(v)) for v in doc_seqlens]
delta = int((head_idx * 7) % 17) - 8 # [-8, 8], period 17
if delta == 0:
delta = 1
i = head_idx % len(segs)
j = (i + 1) % len(segs)

segs[i] = max(1, segs[i] + delta)
segs[j] = max(1, segs[j] - delta)

total = sum(segs)
if total > seqlen_k:
excess = total - seqlen_k
k = max(range(len(segs)), key=lambda t: segs[t])
segs[k] = max(1, segs[k] - excess)
return segs
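
# Illustrative sketch, not part of the original change: a worked example of the
# perturbation above. For head_idx=1, delta = ((1 * 7) % 17) - 8 = -1, so segment 1
# shrinks by one token, segment 2 grows by one, and the total stays <= seqlen_k.
def _example_perturb_int():
    segs = _perturb_doc_seqlens_int([2538, 1742, 3213], head_idx=1, seqlen_k=8192)
    assert segs == [2538, 1741, 3214] and sum(segs) <= 8192
    return segs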

def _perturb_doc_seqlens_prefix_pairs(doc_seqlens, head_idx: int, seqlen_k: int):
"""
Deterministically perturb prefix-LM (prefix_length, seq_length) segments per head.
Keeps 0 <= prefix_length <= seq_length and sum(seq_length) <= seqlen_k.
"""
if not doc_seqlens:
return doc_seqlens
pairs = [(max(0, int(p)), max(1, int(s))) for p, s in doc_seqlens]
if len(pairs) == 1:
p, s = pairs[0]
p = min(s, max(0, p + ((head_idx % 5) - 2)))
s = min(seqlen_k, max(1, s + ((head_idx % 3) - 1)))
p = min(p, s)
return [(p, s)]

delta = int((head_idx * 7) % 17) - 8
if delta == 0:
delta = 1
i = head_idx % len(pairs)
j = (i + 1) % len(pairs)

prefixes = [p for p, _ in pairs]
lengths = [s for _, s in pairs]

lengths[i] = max(1, lengths[i] + delta)
lengths[j] = max(1, lengths[j] - delta)

prefixes[i] = min(lengths[i], max(0, prefixes[i] + (delta // 2)))
prefixes[j] = min(lengths[j], max(0, prefixes[j] - (delta // 2)))

total = sum(lengths)
if total > seqlen_k:
excess = total - seqlen_k
k = max(range(len(lengths)), key=lambda t: lengths[t])
lengths[k] = max(1, lengths[k] - excess)
prefixes[k] = min(prefixes[k], lengths[k])

return list(zip(prefixes, lengths))
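
# Illustrative sketch, not part of the original change: checks the documented
# invariants 0 <= prefix_length <= seq_length and sum(seq_length) <= seqlen_k
# after a per-head perturbation.
def _example_perturb_prefix_pairs():
    pairs = _perturb_doc_seqlens_prefix_pairs(
        [(1024, 2538), (1742, 1742), (512, 3213)], head_idx=3, seqlen_k=8192
    )
    assert all(0 <= p <= s for p, s in pairs)
    assert sum(s for _, s in pairs) <= 8192
    return pairs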

def _rotate_list(values, shift: int):
if not values:
return values
shift = shift % len(values)
return values[shift:] + values[:shift]

def _stack_per_head_startend_row_indices(
batch_size,
seqlen_q,
seqlen_k,
h,
per_head_fn,
):
"""
Build [b, h, seqlen_k, bound_num] by calling a per-head generator that returns
[b, 1, seqlen_k, bound_num] for each head.
"""
if h == 1:
return per_head_fn(0)
startend_list = []
causal = None
for head_idx in range(h):
s, c = per_head_fn(head_idx)
if causal is None:
causal = c
else:
assert c == causal
assert s.shape[0] == batch_size and s.shape[1] == 1 and s.shape[2] == seqlen_k
startend_list.append(s)
return paddle.concat(startend_list, axis=1), causal
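
# Illustrative sketch, not part of the original change: how the helper is meant to
# be wired. `per_head_fn` wraps a 1-head generator and returns its
# (startend_row_indices, causal) tuple; the helper stacks the per-head tensors
# along the head axis. The segment lengths below are made up for the example.
def _example_stack_per_head(batch_size=1, seqlen_q=128, seqlen_k=128, h=4):
    def _per_head(head_idx):
        return generate_causal_document_mask(
            batch_size, seqlen_q, seqlen_k, 1, [64 + head_idx, 64 - head_idx]
        )

    s, causal = _stack_per_head_startend_row_indices(
        batch_size, seqlen_q, seqlen_k, h, _per_head
    )
    assert s.shape[:3] == [batch_size, h, seqlen_k]
    return s, causal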

def startend_row_indices_to_attn_bias(startend_row_indices, seqlen_q, nheads, dtype, causal=True):
if startend_row_indices is None:
return None
@@ -54,6 +170,30 @@ def generate_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, window_size=
    causal = True
    return startend_row_indices, causal

def generate_sliding_window_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, window_size=None):
"""
Per-head distinct sliding window: each head uses a different window_size.
Output shape: [b, h, seqlen_k, 1]
"""
if window_size is None:
window_size = 1024
if seqlen_k != 8192:
window_size = int(window_size * (seqlen_k / 8192))
print(f"{seqlen_k=}, auto setting window_size to {window_size}")

# Choose window sizes spanning roughly [window_size/2, window_size], clipped to [1, seqlen_q].
max_window = max(1, int(window_size))
head_ids = paddle.arange(h, dtype=paddle.int32)
denom = max(1, h - 1)
# linear ramp: largest window at head 0, smallest at head h-1
win = max_window - (head_ids * (max_window // 2) // denom)
win = paddle.clip(win, min=1, max=seqlen_q).reshape((1, h, 1, 1))

col = paddle.arange(seqlen_k, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1))
startend_row_indices = paddle.clip(col + win, max=seqlen_q).repeat_interleave(batch_size, 0)
causal = True
return startend_row_indices, causal
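
# Illustrative sketch, not part of the original change: with the defaults at
# seqlen_k=8192 and h=4, the ramp above yields per-head windows
# [1024, 854, 683, 512], which can be read off directly at key column 0.
def _example_sliding_window_heads():
    s, causal = generate_sliding_window_mask_distinct_heads(1, 8192, 8192, 4)
    assert [int(s[0, head, 0, 0]) for head in range(4)] == [1024, 854, 683, 512]
    return s, causal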

def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
# TODO: this seems buggy, to be fixed
    if doc_seqlens is None:
@@ -75,6 +215,22 @@ def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens
causal = True
return startend_row_indices, causal

def generate_causal_document_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
"""
Per-head distinct causal document mask via rotating document segments per head.
Output shape: [b, h, seqlen_k, 1]
"""
base_doc_seqlens = _default_doc_seqlens_int(seqlen_k) if doc_seqlens is None else list(doc_seqlens)

def _per_head(head_idx: int):
per_head = _rotate_list(list(base_doc_seqlens), head_idx)
per_head = _perturb_doc_seqlens_int(per_head, head_idx=head_idx, seqlen_k=seqlen_k)
return generate_causal_document_mask(
batch_size, seqlen_q, seqlen_k, 1, per_head
)

return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head)
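
# Illustrative sketch, not part of the original change: every *_distinct_heads
# generator in this file is called the same way, and the result can be expanded
# into a dense bias with startend_row_indices_to_attn_bias for reference checks.
def _example_distinct_heads_to_bias(batch_size=1, seqlen_q=8192, seqlen_k=8192, h=4):
    s, causal = generate_causal_document_mask_distinct_heads(
        batch_size, seqlen_q, seqlen_k, h
    )
    return startend_row_indices_to_attn_bias(
        s, seqlen_q, nheads=h, dtype=paddle.float16, causal=causal
    )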

def generate_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
# TODO: this seems buggy, to be fixed
    if doc_seqlens is None:
@@ -114,6 +270,22 @@ def generate_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
causal = False
return startend_row_indices, causal

def generate_document_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
"""
Per-head distinct bidirectional document mask via rotating document segments per head.
Output shape: [b, h, seqlen_k, 2]
"""
base_doc_seqlens = _default_doc_seqlens_int(seqlen_k) if doc_seqlens is None else list(doc_seqlens)

def _per_head(head_idx: int):
per_head = _rotate_list(list(base_doc_seqlens), head_idx)
per_head = _perturb_doc_seqlens_int(per_head, head_idx=head_idx, seqlen_k=seqlen_k)
return generate_document_mask(
batch_size, seqlen_q, seqlen_k, 1, per_head
)

return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head)

def generate_share_question_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
    if doc_seqlens is None:
doc_seqlens = [2538, 1742, 3213]
@@ -146,6 +318,22 @@ def generate_share_question_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=
causal = True
return startend_row_indices, causal

def generate_share_question_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
"""
Per-head distinct share-question mask via rotating document segments per head.
Output shape: [b, h, seqlen_k, 1]
"""
base_doc_seqlens = _default_doc_seqlens_int(seqlen_k) if doc_seqlens is None else list(doc_seqlens)

def _per_head(head_idx: int):
per_head = _rotate_list(list(base_doc_seqlens), head_idx)
per_head = _perturb_doc_seqlens_int(per_head, head_idx=head_idx, seqlen_k=seqlen_k)
return generate_share_question_mask(
batch_size, seqlen_q, seqlen_k, 1, per_head
)

return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head)

def generate_global_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, global_token=16, window_size=None):
    if window_size is None:
window_size = (512, 512)
@@ -184,6 +372,51 @@ def generate_global_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, globa
causal = False
return startend_row_indices, causal

def generate_global_sliding_window_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, global_token=16, window_size=None):
"""
Per-head distinct global sliding window mask by varying (global_token, left/right window) per head.
Output shape: [b, h, seqlen_k, 4]
"""
if window_size is None:
window_size = (512, 512)
if seqlen_k != 8192:
window_size = tuple(int(ws * (seqlen_k / 8192)) for ws in window_size)
print(f"{seqlen_k=}, auto setting window_size to {window_size}")
left_base, right_base = window_size

head_ids = paddle.arange(h, dtype=paddle.int32)
denom = max(1, h - 1)

# Vary global token count modestly across heads.
gt_base = int(global_token)
gt = gt_base + (head_ids - (h // 2)) # centered offsets
gt = paddle.clip(gt, min=0, max=seqlen_k).reshape((1, h, 1, 1))

# Vary left/right windows across heads.
left = int(left_base) - (head_ids * max(1, int(left_base) // 4) // denom)
right = int(right_base) - (head_ids * max(1, int(right_base) // 4) // denom)
left = paddle.clip(left, min=0, max=seqlen_q).reshape((1, h, 1, 1))
right = paddle.clip(right, min=0, max=seqlen_q).reshape((1, h, 1, 1))

col = paddle.arange(seqlen_k, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1))

down_left_start = paddle.clip(col + left + 1, max=seqlen_q)
down_left_start = paddle.where(col < gt, paddle.full_like(down_left_start, seqlen_q), down_left_start)
down_left_end = paddle.full_like(down_left_start, seqlen_q)

up_right_start = paddle.full_like(down_left_start, 0)
up_right_start = paddle.where(col >= (gt + right + 1), gt, up_right_start)

up_right_end = col - right
up_right_end = paddle.where(col < (gt + right + 1), paddle.zeros_like(up_right_end), up_right_end)

startend_row_indices = paddle.concat(
[down_left_start, down_left_end, up_right_start, up_right_end], axis=-1
).repeat_interleave(batch_size, 0)
startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q)
causal = False
return startend_row_indices, causal
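
# Illustrative sketch, not part of the original change: the four bound columns are
# (down_left_start, down_left_end, up_right_start, up_right_end). Because each head
# gets its own (global_token, left, right) triple, inspecting a fixed key column
# well past the global tokens shows a different bound vector per head.
def _example_global_sliding_window_heads(h=4, seqlen=8192):
    s, _ = generate_global_sliding_window_mask_distinct_heads(1, seqlen, seqlen, h)
    col = seqlen // 2  # an arbitrary key column outside the global-token range
    return [s[0, head, col, :].tolist() for head in range(h)]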

def generate_causal_blockwise_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
# TODO: this seems buggy, to be fixed
    if doc_seqlens is None:
@@ -216,6 +449,22 @@ def generate_causal_blockwise_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlen
causal = True
return startend_row_indices, causal

def generate_causal_blockwise_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
"""
Per-head distinct causal blockwise mask via rotating document segments per head.
Output shape: [b, h, seqlen_k, 2]
"""
base_doc_seqlens = _default_doc_seqlens_int(seqlen_k) if doc_seqlens is None else list(doc_seqlens)

def _per_head(head_idx: int):
per_head = _rotate_list(list(base_doc_seqlens), head_idx)
per_head = _perturb_doc_seqlens_int(per_head, head_idx=head_idx, seqlen_k=seqlen_k)
return generate_causal_blockwise_mask(
batch_size, seqlen_q, seqlen_k, 1, per_head
)

return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head)

def generate_prefix_lm_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
"""
tuple(prefix_length, seq_length)
@@ -260,6 +509,22 @@ def generate_prefix_lm_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seql
causal = False
return startend_row_indices, causal

def generate_prefix_lm_document_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
"""
Per-head distinct prefix-LM document mask via rotating (prefix_length, seq_length) segments per head.
Output shape: [b, h, seqlen_k, 2]
"""
base_doc_seqlens = _default_doc_seqlens_prefix_pairs(seqlen_k) if doc_seqlens is None else list(doc_seqlens)

def _per_head(head_idx: int):
per_head = _rotate_list(list(base_doc_seqlens), head_idx)
per_head = _perturb_doc_seqlens_prefix_pairs(per_head, head_idx=head_idx, seqlen_k=seqlen_k)
return generate_prefix_lm_document_mask(
batch_size, seqlen_q, seqlen_k, 1, per_head
)

return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head)

def generate_prefix_lm_causal_mask(batch_size, seqlen_q, seqlen_k, h, prefix_length=None):
"""
tuple(prefix_length, seq_length)
@@ -278,6 +543,32 @@ def generate_prefix_lm_causal_mask(batch_size, seqlen_q, seqlen_k, h, prefix_len
causal = False
return startend_row_indices, causal

def generate_prefix_lm_causal_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, prefix_length=None):
"""
Per-head distinct prefix-LM causal mask by varying prefix_length per head.
Output shape: [b, h, seqlen_k, 2]
"""
if prefix_length is None:
prefix_length = 1024
if seqlen_k != 8192:
prefix_length = int(prefix_length * (seqlen_k / 8192))
print(f"{seqlen_k=}, auto setting doc_seqlens to {prefix_length}")
base = int(prefix_length)
head_ids = paddle.arange(h, dtype=paddle.int32)
denom = max(1, h - 1)
# Spread prefix lengths roughly in [base/2, base], clipped.
pref = base - (head_ids * (base // 2) // denom)
pref = paddle.clip(pref, min=0, max=seqlen_k).reshape((1, h, 1, 1))

col = paddle.arange(seqlen_k, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1))
down_left = paddle.full((1, h, seqlen_k, 1), seqlen_k, dtype=paddle.int32)
up_right = paddle.where(col < pref, paddle.zeros_like(col), col)

startend_row_indices = paddle.concat([down_left, up_right], axis=-1).repeat_interleave(batch_size, 0)
startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q)
causal = False
return startend_row_indices, causal
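
# Illustrative sketch, not part of the original change: with the default
# prefix_length=1024 at seqlen_k=8192 and h=4, the per-head prefixes are
# [1024, 854, 683, 512], so key column 600 is bidirectional (upper bound 0) for
# heads 0-2 but causal (upper bound 600) for head 3.
def _example_prefix_lm_causal_heads():
    s, _ = generate_prefix_lm_causal_mask_distinct_heads(1, 8192, 8192, 4)
    assert [int(s[0, head, 600, 1]) for head in range(4)] == [0, 0, 0, 600]
    return s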

def generate_qk_sparse_mask(batch_size, seqlen_q, seqlen_k, h, maskout_pair=None):
"""
tuple(offset, maskout_len)
@@ -313,6 +604,42 @@ def generate_qk_sparse_mask(batch_size, seqlen_q, seqlen_k, h, maskout_pair=None
causal = True
return startend_row_indices, causal

def generate_qk_sparse_mask_distinct_heads(batch_size, seqlen_q, seqlen_k, h, maskout_pair=None):
"""
Per-head distinct qk-sparse mask by shifting maskout offsets per head.
Output shape: [b, h, seqlen_k, 2]
"""
if maskout_pair is None:
maskout_pair = [(1024, 538), (2358, 1700)]
if seqlen_k != 8192:
scale = seqlen_k / 8192
maskout_pair = [tuple(int(v * scale) for v in pair) for pair in maskout_pair]
print(f"{seqlen_k=}, auto setting maskout_pair to {maskout_pair}")

base_pairs = list(maskout_pair)
shift_stride = max(1, seqlen_k // max(8, h))

def _per_head(head_idx: int):
shift = head_idx * shift_stride
pairs = []
for offset, length in base_pairs:
o = min(max(0, offset + shift), seqlen_k)
l = min(max(0, length), max(0, seqlen_k - o))
pairs.append((o, l))
# Ensure offsets are strictly increasing for the generator.
pairs.sort(key=lambda x: x[0])
dedup = []
last_offset = -1
for o, l in pairs:
o = max(o, last_offset + 1)
o = min(o, seqlen_k)
l = min(l, max(0, seqlen_k - o))
dedup.append((o, l))
last_offset = o
return generate_qk_sparse_mask(batch_size, seqlen_q, seqlen_k, 1, dedup)

return _stack_per_head_startend_row_indices(batch_size, seqlen_q, seqlen_k, h, _per_head)
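
# Illustrative sketch, not part of the original change: the shift logic above in
# plain Python. With seqlen_k=8192 and h=4, shift_stride = 8192 // max(8, 4) = 1024,
# so head 1 moves both maskout windows 1024 keys later while offsets stay strictly
# increasing and inside the sequence.
def _example_qk_sparse_shift(head_idx=1, seqlen_k=8192, h=4):
    shift = head_idx * max(1, seqlen_k // max(8, h))
    shifted = [(min(o + shift, seqlen_k), l) for o, l in [(1024, 538), (2358, 1700)]]
    assert shifted == [(2048, 538), (3382, 1700)]
    return shifted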

def generate_random_eviction_mask(batch_size, seqlen_q, seqlen_k, h, start_row=None):
# np.random.seed(0)
    if start_row is None: