
Commit c2a0a87: fix block size
samsja committed Dec 10, 2024 (1 parent: a556771)
Showing 3 changed files with 16 additions and 42 deletions.
29 changes: 0 additions & 29 deletions debug_flex.py

This file was deleted.

11 changes: 9 additions & 2 deletions src/zeroband/models/llama/model.py
@@ -19,7 +19,7 @@
 from torch import nn
 from zeroband.models.norms import build_norm

-from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention, _DEFAULT_SPARSE_BLOCK_SIZE

 flex_attention_compiled = torch.compile(flex_attention, dynamic=False)

@@ -237,7 +237,14 @@ def document_causal_mask(b, h, q_idx, kv_idx):
             return causal_mask & document_mask

         block_mask = create_block_mask(
-            document_causal_mask, batch_size, None, max_seq_len, max_seq_len, device="cuda", _compile=True
+            document_causal_mask,
+            batch_size,
+            None,
+            max_seq_len,
+            max_seq_len,
+            device="cuda",
+            _compile=True,
+            BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE,
         )

         output = flex_attention_compiled(xq, xk, xv, block_mask=block_mask)
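The core of the fix is the new BLOCK_SIZE argument: create_block_mask tiles the mask in blocks of _DEFAULT_SPARSE_BLOCK_SIZE (128) tokens, and the commit clamps that to max_seq_len whenever the sequence is shorter than one tile, presumably so the block size never exceeds the sequence length. A minimal standalone sketch of the same clamping follows; it is not the repository's code, and the causal mask_mod, shapes, and dtypes are illustrative assumptions.

# Minimal sketch of the BLOCK_SIZE clamp introduced above (illustrative, not zeroband code).
import torch
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,  # 128 at the time of writing
    create_block_mask,
    flex_attention,
)

def causal_mask(b, h, q_idx, kv_idx):
    # plain causal masking; zeroband additionally restricts attention to the same document
    return q_idx >= kv_idx

max_seq_len = 64  # shorter than the default sparse block of 128 tokens
block_size = max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE

block_mask = create_block_mask(
    causal_mask, 1, None, max_seq_len, max_seq_len, device="cuda", BLOCK_SIZE=block_size
)

q = k = v = torch.randn(1, 8, max_seq_len, 64, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=block_mask)  # shape (1, 8, 64, 64)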
18 changes: 7 additions & 11 deletions tests/test_model.py
@@ -119,7 +119,6 @@ def test_sequence_packing_vs_normal(llama_config: ModelArgs):
     take two sequences and compare the output of attention on individual sequences vs the output of attention on the packed sequence
     """

-    llama_config.attn_fn = "flash"
     model = Attention(llama_config).to("cuda")
     emb = torch.nn.Embedding(10, llama_config.dim).to("cuda")

@@ -129,8 +128,7 @@ def test_sequence_packing_vs_normal(llama_config: ModelArgs):
     seq_2 = [3, 7, 5, 6]

     input_packed_raw = torch.Tensor([seq_1 + seq_2]).long().to("cuda")
-    seqlens = [len(seq_1), len(seq_2)]
-    seqlens = torch.Tensor(seqlens).int().to("cuda")
+    seqlens = [torch.Tensor([len(seq_1), len(seq_2)]).int().to("cuda")]

     input_packed = emb(input_packed_raw)

@@ -166,7 +164,6 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):
     take two sequences and compare the output of attention on individual sequences vs the output of attention on the packed sequence
     """

-    llama_config.attn_fn = "flash"
     model = Attention(llama_config).to("cuda")

     freqs_cis = get_freqs_cis(llama_config)
@@ -176,11 +173,12 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):
     for _ in range(10):
         seq_len_cutoff = random.randint(1, MAX_SEQ_LEN)

-        input_1 = torch.rand(1, seq_len_cutoff, llama_config.dim).to("cuda")
-        input_2 = torch.rand(1, MAX_SEQ_LEN - seq_len_cutoff, llama_config.dim).to("cuda")
+        seq1 = seq_len_cutoff
+        seq2 = MAX_SEQ_LEN - seq_len_cutoff
+        input_1 = torch.rand(1, seq1, llama_config.dim).to("cuda")
+        input_2 = torch.rand(1, seq2, llama_config.dim).to("cuda")

-        seqlens = [seq_len_cutoff, MAX_SEQ_LEN - seq_len_cutoff]
-        seqlens = torch.Tensor(seqlens).int().to("cuda")
+        seqlens = [torch.Tensor([seq1, seq2]).int().to("cuda")]

         packed_input = torch.cat([input_1, input_2], dim=1)

@@ -208,16 +206,14 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):


 def test_end_to_end_packing(llama_config: ModelArgs):
-    llama_config.attn_fn = "flash"
     model = Transformer(llama_config).to("cuda")

     BS = 8
     SEQ_LEN = 128

     input_ = torch.randint(1, llama_config.vocab_size, (BS, SEQ_LEN)).to("cuda")

-    seqlens = [SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2]
-    seqlens = torch.Tensor(seqlens).int().to("cuda")
+    seqlens = [torch.Tensor([SEQ_LEN // 4, SEQ_LEN // 4, SEQ_LEN // 2]).int().to("cuda") for _ in range(BS)]

     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
         output = model(input_, seqlens=seqlens)
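Across all three tests the seqlens format changes from a single flat int tensor to a list with one int tensor of per-document lengths per batch row, which is presumably what the model uses to build the document-causal block mask shown in the model.py hunk. As a rough illustration of how per-document lengths map to such a mask, here is a sketch for a single packed row; it is not zeroband's implementation, and the helper names and shapes are assumptions.

# Illustrative sketch: build a document-causal mask_mod from per-document lengths
# for one packed row, in the spirit of the document_causal_mask referenced above.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

seq_len = 128
seqlens = torch.tensor([64, 32, 32], device="cuda")  # three documents packed into one row
# document id of every token position: 64 zeros, then 32 ones, then 32 twos
document_ids = torch.repeat_interleave(torch.arange(len(seqlens), device="cuda"), seqlens)

def document_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    document_mask = document_ids[q_idx] == document_ids[kv_idx]
    return causal_mask & document_mask

block_mask = create_block_mask(document_causal, 1, None, seq_len, seq_len, device="cuda")
q = k = v = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=block_mask)  # tokens attend only within their own document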
