
only use one block
samsja committed Dec 10, 2024
1 parent 1643a6e commit 3a7f9cc
Showing 2 changed files with 45 additions and 38 deletions.
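
The change hoists the document-causal mask construction out of Attention._flex_attention_with_seqlens into a module-level create_block_mask_from_seqlens helper, so a single BlockMask is built per forward pass and shared by every layer instead of being rebuilt inside each attention call. A minimal CPU-only sketch of what that mask encodes; the names and sizes below are chosen purely for illustration and are not part of the commit:

import torch

# One packed row containing two documents of lengths 3 and 2 (5 tokens total).
seqlens = [torch.tensor([3, 2], dtype=torch.int32)]

# Same construction as seqlens_to_docs_tensor: map each token to its document id.
docs = torch.stack(
    [torch.repeat_interleave(torch.arange(len(seq)), seq) for seq in seqlens]
)
# docs == tensor([[0, 0, 0, 1, 1]])

q_idx = torch.arange(docs.shape[1]).unsqueeze(1)   # query positions, column vector
kv_idx = torch.arange(docs.shape[1]).unsqueeze(0)  # key/value positions, row vector

causal = q_idx >= kv_idx                       # attend only to earlier/current tokens
same_doc = docs[0, q_idx] == docs[0, kv_idx]   # attend only within the same document
dense_mask = causal & same_doc                 # dense version of document_causal_mask

print(dense_mask.int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 1]])
# create_block_mask_from_seqlens compiles this same predicate into a sparse
# BlockMask on the GPU rather than materializing a dense boolean matrix.
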
62 changes: 33 additions & 29 deletions src/zeroband/models/llama/model.py
@@ -19,7 +19,7 @@
from torch import nn
from zeroband.models.norms import build_norm

from torch.nn.attention.flex_attention import create_block_mask, flex_attention, _DEFAULT_SPARSE_BLOCK_SIZE
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE

flex_attention_compiled = torch.compile(flex_attention, dynamic=False)

@@ -143,6 +143,27 @@ def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor:
return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])


def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask:
docs = seqlens_to_docs_tensor(seqlens).to("cuda")
batch_size, max_seq_len = docs.shape

def document_causal_mask(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[b, q_idx] == docs[b, kv_idx]
return causal_mask & document_mask

return create_block_mask(
document_causal_mask,
batch_size,
None,
max_seq_len,
max_seq_len,
device="cuda",
_compile=True,
BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE,
)


class Attention(nn.Module):
"""
Multi-head attention module.
@@ -183,7 +204,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
seqlens: list[torch.Tensor] | None = None,
block_mask: BlockMask | None = None,
):
"""
Forward pass of the attention module.
@@ -217,7 +238,7 @@ def forward(
xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)

output = self.self_attention(xq, xk, xv, seqlens)
output = self.self_attention(xq, xk, xv, block_mask)

output = output.view(bs, seqlen, -1)
return self.wo(output)
@@ -227,35 +248,16 @@ def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor:
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
return output

def _flex_attention_with_seqlens(self, xq, xk, xv, seqlens: list[torch.Tensor]) -> torch.Tensor:
docs = seqlens_to_docs_tensor(seqlens).to(xq.device)
batch_size, max_seq_len = docs.shape

def document_causal_mask(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[b, q_idx] == docs[b, kv_idx]
return causal_mask & document_mask

block_mask = create_block_mask(
document_causal_mask,
batch_size,
None,
max_seq_len,
max_seq_len,
device="cuda",
_compile=True,
BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE,
)

def _flex_attention_with_seqlens(self, xq, xk, xv, block_mask: BlockMask) -> torch.Tensor:
output = flex_attention_compiled(xq, xk, xv, block_mask=block_mask)
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
return output

def self_attention(
self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, seqlens: list[torch.Tensor] | None = None
self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, block_mask: BlockMask | None = None
) -> torch.Tensor:
if seqlens is not None:
return self._flex_attention_with_seqlens(xq, xk, xv, seqlens)
if block_mask is not None:
return self._flex_attention_with_seqlens(xq, xk, xv, block_mask)
else:
return self._sdpa_attention(xq, xk, xv)

@@ -350,7 +352,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
seqlens: torch.Tensor | None = None,
block_mask: BlockMask | None = None,
):
"""
Perform a forward pass through the TransformerBlock.
@@ -363,7 +365,7 @@
torch.Tensor: Output tensor after applying attention and feedforward layers.
"""
h = x + self.attention(self.attention_norm(x), freqs_cis, seqlens=seqlens)
h = x + self.attention(self.attention_norm(x), freqs_cis, block_mask=block_mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out

@@ -475,8 +477,10 @@ def forward(self, tokens: torch.Tensor, seqlens: torch.Tensor | None = None):
# passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None

for layer in self.layers.values():
h = layer(h, self.freqs_cis, seqlens=seqlens)
h = layer(h, self.freqs_cis, block_mask=block_mask)

h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
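
After this refactor the mask is created once in Transformer.forward and threaded through every TransformerBlock as block_mask; callers still pass seqlens at the top level. A hedged sketch of the call path following the patterns in tests/test_model.py; the "debugmodel" config key and the vocab-size override are assumptions about the llama_config fixture and are not shown in this diff:

import torch
from zeroband.models.llama import Transformer, llama2_configs

config = llama2_configs["debugmodel"]   # assumed config key, for illustration only
config.vocab_size = 1024                # matches VOCAB_SIZE used in the tests

model = Transformer(config).to("cuda")

bs, seq_len = 2, 512
tokens = torch.randint(0, config.vocab_size, (bs, seq_len)).to("cuda")

# Each batch row packs two documents of seq_len // 2 tokens.
seqlens = [
    torch.tensor([seq_len // 2, seq_len // 2], dtype=torch.int32, device="cuda")
    for _ in range(bs)
]

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # forward builds the BlockMask once via create_block_mask_from_seqlens and
    # hands the same mask to every layer; passing no seqlens falls back to SDPA.
    logits = model(tokens, seqlens=seqlens)
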
21 changes: 12 additions & 9 deletions tests/test_model.py
@@ -2,7 +2,7 @@
import pytest
import torch
from zeroband.models.llama import Transformer, llama2_configs
from zeroband.models.llama.model import Attention, ModelArgs
from zeroband.models.llama.model import Attention, ModelArgs, create_block_mask_from_seqlens


VOCAB_SIZE = 1024
@@ -26,11 +26,9 @@ def llama_config() -> ModelArgs:
return config


@pytest.mark.parametrize("attn_fn", ["flash", "sdpa"])
def test_llama(llama_config: ModelArgs, attn_fn):
def test_llama(llama_config: ModelArgs):
seq_len = 512
bs = 8
llama_config.attn_fn = attn_fn
model = Transformer(llama_config).to("cuda")
input_ = torch.randint(0, llama_config.vocab_size, (bs, seq_len)).to("cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
@@ -51,14 +49,15 @@ def test_attn(llama_config: ModelArgs):
freqs_cis = get_freqs_cis(llama_config)
input_ = torch.rand(bs, seq_len, llama_config.dim).to("cuda")
seqlens = [torch.Tensor([seq_len]).int().to("cuda") for _ in range(bs)]
block_mask = create_block_mask_from_seqlens(seqlens)

attn = Attention(llama_config).to("cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output_sdpa = attn(input_, freqs_cis)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output_flex = attn(input_, freqs_cis, seqlens=seqlens)
output_flex = attn(input_, freqs_cis, block_mask=block_mask)

rtol = ERROR_RTOL[torch.bfloat16]
atol = ERROR_ATOL[torch.bfloat16]
@@ -73,11 +72,12 @@ def test_packing_simple(llama_config: ModelArgs):
freqs_cis = get_freqs_cis(llama_config)
input_ = torch.rand(bs, seq_len, llama_config.dim).to("cuda")
seqlens = [torch.Tensor([seq_len // 4] * 4).int().to("cuda") for _ in range(bs)]
block_mask = create_block_mask_from_seqlens(seqlens)

attn = Attention(llama_config).to("cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output = attn(input_, freqs_cis, seqlens=seqlens)
output = attn(input_, freqs_cis, block_mask=block_mask)

assert output.shape == (bs, seq_len, llama_config.dim)

@@ -95,13 +95,14 @@ def test_sequence_packing_two_time_same_sequence(llama_config: ModelArgs):
seq = [2, 1, 4, 8]
input_stuff_raw = torch.Tensor([seq + seq]).long().to("cuda")
seqlens = [torch.Tensor([len(seq), len(seq)]).int().to("cuda")]
block_mask = create_block_mask_from_seqlens(seqlens)

input_stuff = emb(input_stuff_raw)

freqs_cis = get_freqs_cis(llama_config)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output = model(input_stuff, freqs_cis, seqlens=seqlens)
output = model(input_stuff, freqs_cis, block_mask=block_mask)

output_left = output[:, :4, :]
output_right = output[:, 4:, :]
@@ -129,11 +130,12 @@ def test_sequence_packing_vs_normal(llama_config: ModelArgs):

input_packed_raw = torch.Tensor([seq_1 + seq_2]).long().to("cuda")
seqlens = [torch.Tensor([len(seq_1), len(seq_2)]).int().to("cuda")]
block_mask = create_block_mask_from_seqlens(seqlens)

input_packed = emb(input_packed_raw)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output = model(input_packed, freqs_cis, seqlens=seqlens)
output = model(input_packed, freqs_cis, block_mask=block_mask)

output_packed_1 = output[:, :4, :]
output_packed_2 = output[:, 4:, :]
@@ -179,12 +181,13 @@ def test_sequence_packing_vs_normal_random(llama_config: ModelArgs):
input_2 = torch.rand(1, seq2, llama_config.dim).to("cuda")

seqlens = [torch.Tensor([seq1, seq2]).int().to("cuda")]
block_mask = create_block_mask_from_seqlens(seqlens)

packed_input = torch.cat([input_1, input_2], dim=1)

# packed output
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output = model(packed_input, freqs_cis, seqlens=seqlens)
output = model(packed_input, freqs_cis, block_mask=block_mask)

output_packed_1 = output[:, :seq_len_cutoff, :]
output_packed_2 = output[:, seq_len_cutoff:, :]
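
The assertions truncated above all reduce to the same idea: with one shared BlockMask, each packed segment must match what the unpacked sequence would produce. A small illustrative sketch of sharing a single mask across several attention layers, mirroring what Transformer.forward now does; llama_config and get_freqs_cis refer to the fixture and helper defined in tests/test_model.py, and the layer count and shapes are arbitrary:

import torch
from zeroband.models.llama.model import Attention, create_block_mask_from_seqlens

bs, seq_len = 2, 512
# Two packed documents of equal length per batch row.
seqlens = [
    torch.tensor([seq_len // 2, seq_len // 2], dtype=torch.int32, device="cuda")
    for _ in range(bs)
]
block_mask = create_block_mask_from_seqlens(seqlens)  # built exactly once

layers = [Attention(llama_config).to("cuda") for _ in range(4)]
freqs_cis = get_freqs_cis(llama_config)
h = torch.rand(bs, seq_len, llama_config.dim).to("cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    for layer in layers:
        # the same BlockMask is reused by every layer instead of being rebuilt
        h = layer(h, freqs_cis, block_mask=block_mask)

assert h.shape == (bs, seq_len, llama_config.dim)
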
