add debug script
samsja committed Dec 3, 2024
1 parent 0fa81e6 commit a556771
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions debug_flex.py
@@ -0,0 +1,29 @@
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Uncomment to exercise the compiled paths instead of eager:
# flex_attention = torch.compile(flex_attention, dynamic=False)
# create_block_mask = torch.compile(create_block_mask, dynamic=False)


def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor:
    """Convert a list of per-sample sequence-length tensors into a tensor of document indices.

    Example:
        seqlens = [tensor([2,2,1]), tensor([2,2,1])]  # list of 2 tensors
        docs = [[0,0,1,1,2], [0,0,1,1,2]]  # each doc_id repeated per its length
    """
    return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])


SEQ_LEN = 16
BS = 8


# Each batch element holds 4 documents of length SEQ_LEN // 4, so the lengths sum to SEQ_LEN.
seqlens = [torch.tensor([SEQ_LEN // 4] * 4, dtype=torch.int32, device="cuda") for _ in range(BS)]
# docs[b] == tensor([0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3]) for every batch index b.
docs = seqlens_to_docs_tensor(seqlens)


def document_masking(b, h, q_idx, kv_idx):
    # mask_mod: a query position may only attend to key positions in the same document.
    return docs[b, q_idx] == docs[b, kv_idx]


# Toggle _compile to compare the compiled and eager block-mask construction paths:
# block_mask = create_block_mask(document_masking, BS, None, SEQ_LEN, SEQ_LEN, device="cuda", _compile=True)
block_mask = create_block_mask(document_masking, BS, None, SEQ_LEN, SEQ_LEN, device="cuda", _compile=False)
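
For context, a minimal sketch of how the resulting block_mask would typically be consumed (not part of this commit); the head count and head dimension are arbitrary illustrative values:

# Hypothetical usage sketch, relying on the flex_attention import above.
# Shapes are (batch, heads, seq, head_dim); 2 heads and head_dim=8 are
# illustrative choices, not part of the commit.
H_HEADS, HEAD_DIM = 2, 8
q = torch.randn(BS, H_HEADS, SEQ_LEN, HEAD_DIM, device="cuda")
k = torch.randn(BS, H_HEADS, SEQ_LEN, HEAD_DIM, device="cuda")
v = torch.randn(BS, H_HEADS, SEQ_LEN, HEAD_DIM, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([8, 2, 16, 8])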
