From c2a0a870c16ceffec6ee18609753e69361e8220c Mon Sep 17 00:00:00 2001
From: Sami Jaghouar <sami.jaghouar@hotmail.fr>
Date: Tue, 10 Dec 2024 00:24:47 +0000
Subject: [PATCH] fix block size

---
 debug_flex.py                      | 29 -----------------------------
 src/zeroband/models/llama/model.py | 11 +++++++++--
 tests/test_model.py                | 18 +++++++-----------
 3 files changed, 16 insertions(+), 42 deletions(-)
 delete mode 100644 debug_flex.py

diff --git a/debug_flex.py b/debug_flex.py
deleted file mode 100644
index 4b069e9..0000000
--- a/debug_flex.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import torch
-from torch.nn.attention.flex_attention import create_block_mask
-# flex_attention = torch.compile(flex_attention, dynamic=False)
-# create_block_mask = torch.compile(create_block_mask, dynamic=False)
-
-
-def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor:
-    """Converts list of sequence lengths to document indices tensor.
-    Example:
-        seqlens = [tensor([2,2,1]), tensor([2,2,1])] # List of 2 tensors
-        docs = [[0,0,1,1,2], [0,0,1,1,2]] # Each doc_id repeated per its length
-    """
-    return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])
-
-
-SEQ_LEN = 16
-BS = 8
-
-
-seqlens = [torch.Tensor([16 // 4] * 4).int().to("cuda") for _ in range(BS)]
-docs = seqlens_to_docs_tensor(seqlens)
-
-
-def document_masking(b, h, q_idx, kv_idx):
-    return docs[b, q_idx] == docs[b, kv_idx]
-
-
-# block_mask = create_block_mask(document_masking, BS, None, SEQ_LEN, SEQ_LEN, device="cuda", _compile=True)
-block_mask = create_block_mask(document_masking, BS, None, SEQ_LEN, SEQ_LEN, device="cuda", _compile=False)
diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py
index 4cc9713..7fe5b31 100644
--- a/src/zeroband/models/llama/model.py
+++ b/src/zeroband/models/llama/model.py
@@ -19,7 +19,7 @@
 from torch import nn
 
 from zeroband.models.norms import build_norm
-from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention, _DEFAULT_SPARSE_BLOCK_SIZE
 
 flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
 
@@ -237,7 +237,14 @@ def document_causal_mask(b, h, q_idx, kv_idx):
             return causal_mask & document_mask
 
         block_mask = create_block_mask(
-            document_causal_mask, batch_size, None, max_seq_len, max_seq_len, device="cuda", _compile=True
+            document_causal_mask,
+            batch_size,
+            None,
+            max_seq_len,
+            max_seq_len,
+            device="cuda",
+            _compile=True,
+            BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE,
         )
 
         output = flex_attention_compiled(xq, xk, xv, block_mask=block_mask)
diff --git a/tests/test_model.py b/tests/test_model.py
index b44b585..9f46a23 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -119,7 +119,6 @@ def test_sequence_packing_vs_normal(llama_config: ModelArgs):
     take two sequences and compare the outout of attention on individual sequences
     vs the output of attention on the packed sequence
     """
-    llama_config.attn_fn = "flash"
     model = Attention(llama_config).to("cuda")
     emb = torch.nn.Embedding(10, llama_config.dim).to("cuda")
 
@@ -129,8 +128,7 @@
     seq_2 = [3, 7, 5, 6]
 
     input_packed_raw = torch.Tensor([seq_1 + seq_2]).long().to("cuda")
-    seqlens = [len(seq_1), len(seq_2)]
-    seqlens = torch.Tensor(seqlens).int().to("cuda")
+    seqlens = [torch.Tensor([len(seq_1), len(seq_2)]).int().to("cuda")]
 
     input_packed = emb(input_packed_raw)
 
@@ -166,7 +164,6 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):
     take two sequences and compare the outout of attention on individual sequences
     vs the output of attention on the packed sequence
     """
-    llama_config.attn_fn = "flash"
     model = Attention(llama_config).to("cuda")
 
     freqs_cis = get_freqs_cis(llama_config)
@@ -176,11 +173,12 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):
     for _ in range(10):
         seq_len_cutoff = random.randint(1, MAX_SEQ_LEN)
 
-        input_1 = torch.rand(1, seq_len_cutoff, llama_config.dim).to("cuda")
-        input_2 = torch.rand(1, MAX_SEQ_LEN - seq_len_cutoff, llama_config.dim).to("cuda")
+        seq1 = seq_len_cutoff
+        seq2 = MAX_SEQ_LEN - seq_len_cutoff
+        input_1 = torch.rand(1, seq1, llama_config.dim).to("cuda")
+        input_2 = torch.rand(1, seq2, llama_config.dim).to("cuda")
 
-        seqlens = [seq_len_cutoff, MAX_SEQ_LEN - seq_len_cutoff]
-        seqlens = torch.Tensor(seqlens).int().to("cuda")
+        seqlens = [torch.Tensor([seq1, seq2]).int().to("cuda")]
 
         packed_input = torch.cat([input_1, input_2], dim=1)
 
@@ -208,7 +206,6 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):
 
 
 def test_end_to_end_packing(llama_config: ModelArgs):
-    llama_config.attn_fn = "flash"
     model = Transformer(llama_config).to("cuda")
 
     BS = 8
@@ -216,8 +213,7 @@
 
     input_ = torch.randint(1, llama_config.vocab_size, (BS, SEQ_LEN)).to("cuda")
 
-    seqlens = [SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2]
-    seqlens = torch.Tensor(seqlens).int().to("cuda")
+    seqlens = [torch.Tensor([SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2]).int().to("cuda") for _ in range(BS)]
 
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
         output = model(input_, seqlens=seqlens)
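
Note (not part of the patch): below is a minimal standalone sketch of the BLOCK_SIZE clamp this commit applies, assuming a CUDA device and a PyTorch build that ships torch.nn.attention.flex_attention. SEQ_LEN, BS, and the mask helper are illustrative values mirroring the deleted debug_flex.py, not part of the library API.

import torch
from torch.nn.attention.flex_attention import create_block_mask, _DEFAULT_SPARSE_BLOCK_SIZE

SEQ_LEN = 16  # shorter than the default sparse block size (128)
BS = 8

# One tensor of document lengths per batch element: four docs of length SEQ_LEN // 4.
seqlens = [torch.tensor([SEQ_LEN // 4] * 4, dtype=torch.int32, device="cuda") for _ in range(BS)]

# Expand lengths into per-token document ids, shape (BS, SEQ_LEN).
docs = torch.stack([torch.repeat_interleave(torch.arange(len(s), device=s.device), s) for s in seqlens])


def document_causal_mask(b, h, q_idx, kv_idx):
    # Causal attention restricted to tokens of the same packed document.
    return (q_idx >= kv_idx) & (docs[b, q_idx] == docs[b, kv_idx])


# Clamp BLOCK_SIZE when the sequence is shorter than the default block size,
# matching the argument the patch adds to the create_block_mask call in model.py.
block_size = SEQ_LEN if SEQ_LEN < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE
block_mask = create_block_mask(
    document_causal_mask, BS, None, SEQ_LEN, SEQ_LEN, device="cuda", BLOCK_SIZE=block_size
)
print(block_mask)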