From c2a0a870c16ceffec6ee18609753e69361e8220c Mon Sep 17 00:00:00 2001
From: Sami Jaghouar <sami.jaghouar@hotmail.fr>
Date: Tue, 10 Dec 2024 00:24:47 +0000
Subject: [PATCH] fix block size

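create_block_mask builds the mask at the granularity of its sparse
block size, which defaults to _DEFAULT_SPARSE_BLOCK_SIZE and can be
larger than a short sequence. Cap BLOCK_SIZE at max_seq_len when
max_seq_len is smaller than the default so the block size never
exceeds the sequence being masked.

The tests now pass seqlens as a list of per-sample int32 tensors (one
tensor of document lengths per batch element) instead of a single flat
tensor, and the stale debug_flex.py scratch script is removed.

A minimal sketch of the new seqlens layout, with an illustrative batch
size and sequence length (the names below are examples only, not part
of this patch):

    import torch

    BS, SEQ_LEN = 8, 64
    # one int32 tensor of document lengths per batch element;
    # in the tests each tensor sums to the packed sequence length
    seqlens = [
        torch.tensor([SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2],
                     dtype=torch.int32, device="cuda")
        for _ in range(BS)
    ]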
---
 debug_flex.py                      | 29 -----------------------------
 src/zeroband/models/llama/model.py | 12 ++++++++++--
 tests/test_model.py                | 19 ++++++++-----------
 3 files changed, 18 insertions(+), 42 deletions(-)
 delete mode 100644 debug_flex.py

diff --git a/debug_flex.py b/debug_flex.py
deleted file mode 100644
index 4b069e9..0000000
--- a/debug_flex.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import torch
-from torch.nn.attention.flex_attention import create_block_mask
-# flex_attention = torch.compile(flex_attention, dynamic=False)
-# create_block_mask = torch.compile(create_block_mask, dynamic=False)
-
-
-def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor:
-    """Converts list of sequence lengths to document indices tensor.
-    Example:
-        seqlens = [tensor([2,2,1]), tensor([2,2,1])]  # List of 2 tensors
-        docs = [[0,0,1,1,2], [0,0,1,1,2]] # Each doc_id repeated per its length
-    """
-    return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])
-
-
-SEQ_LEN = 16
-BS = 8
-
-
-seqlens = [torch.Tensor([16 // 4] * 4).int().to("cuda") for _ in range(BS)]
-docs = seqlens_to_docs_tensor(seqlens)
-
-
-def document_masking(b, h, q_idx, kv_idx):
-    return docs[b, q_idx] == docs[b, kv_idx]
-
-
-# block_mask = create_block_mask(document_masking, BS, None, SEQ_LEN, SEQ_LEN, device="cuda", _compile=True)
-block_mask = create_block_mask(document_masking, BS, None, SEQ_LEN, SEQ_LEN, device="cuda", _compile=False)
diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py
index 4cc9713..7fe5b31 100644
--- a/src/zeroband/models/llama/model.py
+++ b/src/zeroband/models/llama/model.py
@@ -19,7 +19,7 @@
 from torch import nn
 from zeroband.models.norms import build_norm
 
-from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention, _DEFAULT_SPARSE_BLOCK_SIZE
 
 flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
 
@@ -237,7 +237,15 @@ def document_causal_mask(b, h, q_idx, kv_idx):
             return causal_mask & document_mask
 
         block_mask = create_block_mask(
-            document_causal_mask, batch_size, None, max_seq_len, max_seq_len, device="cuda", _compile=True
+            document_causal_mask,
+            batch_size,
+            None,
+            max_seq_len,
+            max_seq_len,
+            device="cuda",
+            _compile=True,
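+            # use max_seq_len as the block size when the sequence is shorter than the default sparse block size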
+            BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE,
         )
 
         output = flex_attention_compiled(xq, xk, xv, block_mask=block_mask)
diff --git a/tests/test_model.py b/tests/test_model.py
index b44b585..9f46a23 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -119,7 +119,6 @@ def test_sequence_packing_vs_normal(llama_config: ModelArgs):
     take two sequences and compare the outout of attention on individual sequences vs the output of attention on the packed sequence
     """
 
-    llama_config.attn_fn = "flash"
     model = Attention(llama_config).to("cuda")
     emb = torch.nn.Embedding(10, llama_config.dim).to("cuda")
 
@@ -129,8 +128,7 @@ def test_sequence_packing_vs_normal(llama_config: ModelArgs):
     seq_2 = [3, 7, 5, 6]
 
     input_packed_raw = torch.Tensor([seq_1 + seq_2]).long().to("cuda")
-    seqlens = [len(seq_1), len(seq_2)]
-    seqlens = torch.Tensor(seqlens).int().to("cuda")
+    seqlens = [torch.Tensor([len(seq_1), len(seq_2)]).int().to("cuda")]
 
     input_packed = emb(input_packed_raw)
 
@@ -166,7 +164,6 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):
     take two sequences and compare the outout of attention on individual sequences vs the output of attention on the packed sequence
     """
 
-    llama_config.attn_fn = "flash"
     model = Attention(llama_config).to("cuda")
 
     freqs_cis = get_freqs_cis(llama_config)
@@ -176,11 +173,12 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):
     for _ in range(10):
         seq_len_cutoff = random.randint(1, MAX_SEQ_LEN)
 
-        input_1 = torch.rand(1, seq_len_cutoff, llama_config.dim).to("cuda")
-        input_2 = torch.rand(1, MAX_SEQ_LEN - seq_len_cutoff, llama_config.dim).to("cuda")
+        seq1 = seq_len_cutoff
+        seq2 = MAX_SEQ_LEN - seq_len_cutoff
+        input_1 = torch.rand(1, seq1, llama_config.dim).to("cuda")
+        input_2 = torch.rand(1, seq2, llama_config.dim).to("cuda")
 
-        seqlens = [seq_len_cutoff, MAX_SEQ_LEN - seq_len_cutoff]
-        seqlens = torch.Tensor(seqlens).int().to("cuda")
+        seqlens = [torch.Tensor([seq1, seq2]).int().to("cuda")]
 
         packed_input = torch.cat([input_1, input_2], dim=1)
 
@@ -208,7 +206,6 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):
 
 
 def test_end_to_end_packing(llama_config: ModelArgs):
-    llama_config.attn_fn = "flash"
     model = Transformer(llama_config).to("cuda")
 
     BS = 8
@@ -216,8 +213,8 @@ def test_end_to_end_packing(llama_config: ModelArgs):
 
     input_ = torch.randint(1, llama_config.vocab_size, (BS, SEQ_LEN)).to("cuda")
 
-    seqlens = [SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2]
-    seqlens = torch.Tensor(seqlens).int().to("cuda")
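+    # one int32 tensor of document lengths per batch element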
+    seqlens = [torch.Tensor([SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2]).int().to("cuda") for _ in range(BS)]
 
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
         output = model(input_, seqlens=seqlens)