deprecated fa2 and use flex attention #166

Merged · 7 commits · Dec 10, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ description = "ZeroBand is a production ready codebase for decentralized trainin
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch==2.4.1",
"torch==2.5.1",
"numpy",
"setuptools",
"transformers>=4.44.2",
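The only dependency change is the torch bump from 2.4.1 to 2.5.1, which is what makes the FlexAttention API used below available. A minimal sanity check, assuming a standard install of torch 2.5.x:

import torch

# FlexAttention ships with torch >= 2.5; on 2.4.x this import fails,
# which is why the pin moves to 2.5.1.
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, BlockMask

print(torch.__version__)  # expected to report 2.5.1 with the new pin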
6 changes: 3 additions & 3 deletions src/zeroband/data.py
@@ -125,7 +125,7 @@ def load_state_dict(self, state_dict):
self.state = SequencePackingDataSetState(**state_dict["state"])


def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor]:
def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor | list[torch.LongTensor]]:
assert samples[0].keys() == {"input_ids", "labels", "seqlens"}

inputs_ids = []
@@ -136,12 +136,12 @@ def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.Lo
inputs_ids.append(sample["input_ids"])
labels.append(sample["labels"])

seqlens.extend(sample["seqlens"])
seqlens.append(torch.Tensor(sample["seqlens"]).long())

return {
"input_ids": torch.stack(inputs_ids, dim=0),
"labels": torch.stack(labels, dim=0),
"seqlens": torch.Tensor(seqlens).long(),
"seqlens": seqlens,
}


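With this change, collate_fn returns "seqlens" as a list with one LongTensor of document lengths per packed sample, instead of a single flattened tensor. A hedged sketch of the new contract (the sample values are made up):

import torch
from zeroband.data import collate_fn

# Two hypothetical samples, each already packed to a sequence length of 5.
samples = [
    {"input_ids": torch.arange(5), "labels": torch.arange(5), "seqlens": [2, 2, 1]},
    {"input_ids": torch.arange(5), "labels": torch.arange(5), "seqlens": [3, 2]},
]

batch = collate_fn(samples)
print(batch["input_ids"].shape)  # torch.Size([2, 5])
print(batch["labels"].shape)     # torch.Size([2, 5])
print(batch["seqlens"])          # [tensor([2, 2, 1]), tensor([3, 2])], one LongTensor per sample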
6 changes: 4 additions & 2 deletions src/zeroband/models/llama/__init__.py
@@ -81,7 +81,10 @@


def get_model(
name_model: str, type_model: str, vocab_size: int, seq_length: int, attn_fn: str
name_model: str,
type_model: str,
vocab_size: int,
seq_length: int,
) -> tuple[Transformer, ModelArgs]:
"""get the transformer model"""

@@ -94,6 +97,5 @@ def get_model(

config.vocab_size = vocab_size
config.max_seq_len = seq_length
config.attn_fn = attn_fn

return Transformer(config), config
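Since attn_fn is gone from the signature, a get_model call now reduces to the following (hedged sketch; the type_model value is illustrative, not taken from this diff):

from zeroband.models.llama import get_model

# No attn_fn argument anymore: the attention path is chosen at forward time,
# depending on whether a block mask is passed.
model, config = get_model(
    name_model="debugmodel",  # name used by the repo's debug/test path
    type_model="llama2",      # illustrative; use whatever model type the repo supports
    vocab_size=1024,
    seq_length=1024,
)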
142 changes: 79 additions & 63 deletions src/zeroband/models/llama/model.py
@@ -12,18 +12,31 @@


from dataclasses import dataclass
from importlib.util import find_spec
from typing import Literal, Optional, Tuple
from einops import rearrange
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from zeroband.models.norms import build_norm

flash_attn_available = find_spec("flash_attn") is not None
if flash_attn_available:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE

_flex_attention_compiled = torch.compile(flex_attention, dynamic=False)


# copied from https://github.com/pytorch/torchtune/blob/f2bd4bc25b24587aef40f486087412b9da8f1d94/torchtune/modules/attention_utils.py#L27
# We cannot do nested compile, but flex attention only has perf benefits
# when compiled. To insulate it from the compiler, we wrap it with
# compiler.disable so that it can be used regardless of whether the model
# is compiled or not, and flex attention always remains compiled.
@torch.compiler.disable(recursive=False)
def flex_attention_compiled(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_mask: BlockMask,
) -> torch.Tensor:
return _flex_attention_compiled(q, k, v, block_mask=block_mask)


@dataclass
@@ -45,8 +58,6 @@ class ModelArgs:
depth_init: bool = True
norm_type: str = "fused_rmsnorm"

attn_fn: Literal["sdpa", "flash"] = "sdpa"


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
"""
@@ -138,6 +149,48 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
)


def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor:
"""Converts list of sequence lengths to document indices tensor.
Example:
seqlens = [tensor([2,2,1]), tensor([2,2,1])] # List of 2 tensors
docs = [[0,0,1,1,2], [0,0,1,1,2]] # Each doc_id repeated per its length
"""
return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])


def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask:
"""Creates a block mask from a list of sequence lengths.

Example:
seqlens = [tensor([2,2,1])] # List with a single tensor
docs = [[0,0,1,1,2]] # Each doc_id repeated per its length

mask = [[1 1 0 0 0] # First token of doc 0 can see itself and second token of doc 0
[1 1 0 0 0] # Second token of doc 0 can see both tokens of doc 0
[0 0 1 1 0] # First token of doc 1 can see itself and second token of doc 1
[0 0 1 1 0] # Second token of doc 1 can see both tokens of doc 1
[0 0 0 0 1]] # Token of doc 2 can only see itself
"""
docs = seqlens_to_docs_tensor(seqlens).to("cuda")
batch_size, max_seq_len = docs.shape

def document_causal_mask(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[b, q_idx] == docs[b, kv_idx]
return causal_mask & document_mask

return create_block_mask(
document_causal_mask,
batch_size,
None,
max_seq_len,
max_seq_len,
device="cuda",
_compile=True,
BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE,
)


class Attention(nn.Module):
"""
Multi-head attention module.
@@ -164,8 +217,6 @@ def __init__(self, model_args: ModelArgs):
self.n_rep = self.n_heads // self.n_kv_heads
self.head_dim = model_args.dim // model_args.n_heads

self.attn_fn = model_args.attn_fn

self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
@@ -180,7 +231,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
seqlens: torch.Tensor | None = None,
block_mask: BlockMask | None = None,
):
"""
Forward pass of the attention module.
@@ -214,7 +265,7 @@ def forward(
xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)

output = self.self_attention(xq, xk, xv, seqlens)
output = self.self_attention(xq, xk, xv, block_mask)

output = output.view(bs, seqlen, -1)
return self.wo(output)
@@ -224,55 +275,21 @@ def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor:
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
return output

def _flash_attention(self, xq, xk, xv) -> torch.Tensor:
q = rearrange(xq, "b n t h -> b t n h")
k = rearrange(xk, "b n t h -> b t n h")
v = rearrange(xv, "b n t h -> b t n h")
# q/k/b is [b, nh, t, hs] but fa2 expected [b , t, nh, hs]
return flash_attn_func(q, k, v, causal=True)

def _fa_attention_with_seqlens(self, xq, xk, xv, seqlens) -> torch.Tensor:
b = xq.shape[0]
cu_seqlens = (
torch.concat([torch.tensor([0]).to(xq.device), seqlens.cumsum(0)], dim=0).to(torch.int32).to(xq.device)
)
max_seqlen = seqlens.max()

q = rearrange(xq, "b n t h -> (b t) n h")
k = rearrange(xk, "b n t h -> (b t) n h")
v = rearrange(xv, "b n t h -> (b t) n h")
# q/k/v is [b, nh, t, hs] but fa expected [b * t, nh, hs]

y = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=True,
)

y = rearrange(y, "(b t) n h -> b t n h", b=b)
return y
def _flex_attention_with_seqlens(self, xq, xk, xv, block_mask: BlockMask) -> torch.Tensor:
output = flex_attention_compiled(xq, xk, xv, block_mask=block_mask)
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
return output
# output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
# output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
# return output

def self_attention(
self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, seqlens: torch.Tensor | None = None
self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, block_mask: BlockMask | None = None
) -> torch.Tensor:
if self.attn_fn == "sdpa":
if seqlens is not None:
raise NotImplementedError("SDPA with seqlens is not implemented.")
return self._sdpa_attention(xq, xk, xv)
elif self.attn_fn == "flash":
if not flash_attn_available:
raise RuntimeError("Flash attention is not available. Please install flash_attn.")
if seqlens is not None:
return self._fa_attention_with_seqlens(xq, xk, xv, seqlens)
else:
return self._flash_attention(xq, xk, xv)
if block_mask is not None:
return self._flex_attention_with_seqlens(xq, xk, xv, block_mask)
else:
raise ValueError(f"Unknown attention function: {self.attn_fn}")
return self._sdpa_attention(xq, xk, xv)


class FeedForward(nn.Module):
@@ -365,7 +382,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
seqlens: torch.Tensor | None = None,
block_mask: BlockMask | None = None,
):
"""
Perform a forward pass through the TransformerBlock.
Expand All @@ -378,7 +395,7 @@ def forward(
torch.Tensor: Output tensor after applying attention and feedforward layers.

"""
h = x + self.attention(self.attention_norm(x), freqs_cis, seqlens=seqlens)
h = x + self.attention(self.attention_norm(x), freqs_cis, block_mask=block_mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out

@@ -475,14 +492,13 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
self.model_args.rope_theta,
)

def forward(self, tokens: torch.Tensor, seqlens: torch.Tensor | None = None):
def forward(self, tokens: torch.Tensor, block_mask: BlockMask | None = None):
"""
Perform a forward pass through the Transformer model.

Args:
tokens (torch.Tensor): Input token indices.
seqlens (torch.Tensor | None): Sequence lengths tensor for packing.

block_mask (BlockMask | None): Block mask for attention.
Returns:
torch.Tensor: Output logits after applying the Transformer model.

@@ -491,7 +507,7 @@ def forward(self, tokens: torch.Tensor, seqlens: torch.Tensor | None = None):
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
h = layer(h, self.freqs_cis, seqlens=seqlens)
h = layer(h, self.freqs_cis, block_mask=block_mask)

h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
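Putting the new model-side pieces together, here is a minimal sketch of the FlexAttention path, assuming a CUDA device is available (shapes and lengths are illustrative):

import torch
from zeroband.models.llama.model import (
    create_block_mask_from_seqlens,
    flex_attention_compiled,
)

# Two packed samples of length 128, each holding documents whose lengths sum to 128.
seqlens = [torch.tensor([64, 32, 32]), torch.tensor([100, 28])]
block_mask = create_block_mask_from_seqlens(seqlens)  # document-causal BlockMask built on CUDA

bs, n_heads, seq_len, head_dim = 2, 8, 128, 64
q = torch.randn(bs, n_heads, seq_len, head_dim, device="cuda")
k = torch.randn(bs, n_heads, seq_len, head_dim, device="cuda")
v = torch.randn(bs, n_heads, seq_len, head_dim, device="cuda")

# The wrapper keeps flex_attention itself compiled while staying safe to call
# from an uncompiled (or torch.compile'd) model, per the torchtune comment above.
out = flex_attention_compiled(q, k, v, block_mask=block_mask)  # (bs, n_heads, seq_len, head_dim)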
19 changes: 9 additions & 10 deletions src/zeroband/train.py
@@ -1,6 +1,7 @@
import os
from typing import Literal
import time
import warnings
import psutil
from pydantic import model_validator
from multiprocessing.process import _children
@@ -19,6 +20,7 @@
from zeroband.diloco import Diloco, DilocoConfig
from zeroband.comms import ElasticDeviceMesh
from zeroband.loss import cross_entropy_max_z_loss
from zeroband.models.llama.model import create_block_mask_from_seqlens

from zeroband.utils import (
FakeTokenizer,
@@ -75,12 +77,12 @@ class TrainConfig(BaseConfig):
memory_profiler: MemoryProfilerConfig | None = None

sequence_packing: bool = True
attn_fn: Literal["flash", "sdpa"] = "flash"
attn_fn: Literal["flash", "sdpa"] | None = None

@model_validator(mode="after")
def validate_attn_fn(self):
if self.attn_fn == "sdpa" and self.sequence_packing:
raise ValueError("SDPA does not support sequence packing")
if self.attn_fn is not None:
warnings.warn("attn_fn argument is deprecated")

return self

@@ -162,7 +164,6 @@ def train(config: Config):
config.type_model,
vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
seq_length=config.data.seq_length,
attn_fn=config.train.attn_fn,
)

model = model.to(world_info.local_rank)
@@ -360,14 +361,12 @@ def train(config: Config):
input_ids = batch["input_ids"].to("cuda")
labels = batch["labels"].to("cuda")
if config.train.sequence_packing:
seqlens = batch["seqlens"].to("cuda")
# seqlens has a dynamic shape but fixed dimension, this allow to still torch compile
# https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html
torch._dynamo.mark_dynamic(seqlens, 0)
seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]]
block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None
else:
seqlens = None
block_mask = None

logits = model(tokens=input_ids, seqlens=seqlens).contiguous()
logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
flatten_labels = rearrange(labels, "b seq -> (b seq)")
