deprecated fa2 and use flex attention (#166)
* wip

* add debug script

* fix block size

* only use one block

* refactor: do block mask before torch compile

* add example

* remove fa2 dependencies from docs and ci
samsja authored Dec 10, 2024
1 parent 8f334cb commit a116ef1
Showing 10 changed files with 206 additions and 173 deletions.
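
At the API level, the change is that variable-length sequence packing is no longer handled by flash-attn's varlen kernels with a flat seqlens tensor; instead, a document-causal BlockMask is built from per-sample sequence lengths and passed into the model forward. A minimal sketch of the new calling convention, using the names introduced in this diff (tensor values are illustrative, and a CUDA device is assumed because the mask is built on "cuda"):

import torch
from zeroband.models.llama.model import create_block_mask_from_seqlens

# One LongTensor of packed document lengths per sample, as produced by the updated collate_fn.
seqlens = [torch.tensor([2, 2, 1], device="cuda"), torch.tensor([3, 2], device="cuda")]
block_mask = create_block_mask_from_seqlens(seqlens)

# The Transformer forward now takes the mask instead of raw seqlens:
#   logits = model(tokens=input_ids, block_mask=block_mask)
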
6 changes: 0 additions & 6 deletions .github/workflows/gpu.yml
@@ -24,16 +24,10 @@ jobs:
enable-cache: true
cache-dependency-glob: "uv.lock"



- name: Set up Python
run: uv python install 3.10.13

- name: Install the project
run: uv sync --all-extras --dev

- name: Install flash attention
run: uv pip install flash-attn --no-build-isolation

- name: Run tests
run: uv run pytest tests
1 change: 0 additions & 1 deletion README.md
@@ -66,7 +66,6 @@ sudo apt install iperf -y
uv venv
source .venv/bin/activate
uv sync --extra all
uv pip install flash-attn==2.6.3 --no-build-isolation
git submodule update --init --recursive
```

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ description = "ZeroBand is a production ready codebase for decentralized trainin
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch==2.4.1",
"torch==2.5.1",
"numpy",
"setuptools",
"transformers>=4.44.2",
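
The torch bump from 2.4.1 to 2.5.1 is what makes the rest of the diff possible: the torch.nn.attention.flex_attention module (flex_attention, create_block_mask, BlockMask) ships with the 2.5 release line. A quick sanity check, assuming the pinned environment:

import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

print(torch.__version__)  # expected to report 2.5.1 with this lockfile
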
5 changes: 1 addition & 4 deletions scripts/install/install.sh
@@ -46,10 +46,7 @@ main() {

log_info "Installing dependencies..."
uv sync --extra all

log_info "Installing flash-attn..."
uv pip install flash-attn==2.6.3 --no-build-isolation


log_info "Updating git submodules..."
git submodule update --init --recursive

6 changes: 3 additions & 3 deletions src/zeroband/data.py
@@ -125,7 +125,7 @@ def load_state_dict(self, state_dict):
self.state = SequencePackingDataSetState(**state_dict["state"])


def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor]:
def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor | list[torch.LongTensor]]:
assert samples[0].keys() == {"input_ids", "labels", "seqlens"}

inputs_ids = []
@@ -136,12 +136,12 @@ def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.Lo
inputs_ids.append(sample["input_ids"])
labels.append(sample["labels"])

seqlens.extend(sample["seqlens"])
seqlens.append(torch.Tensor(sample["seqlens"]).long())

return {
"input_ids": torch.stack(inputs_ids, dim=0),
"labels": torch.stack(labels, dim=0),
"seqlens": torch.Tensor(seqlens).long(),
"seqlens": seqlens,
}


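
The collate change keeps one LongTensor of packed document lengths per sample instead of flattening them into a single batch-level tensor, which is the shape the block-mask builder in model.py now expects. A small before/after sketch (toy values, assuming collate_fn is imported from zeroband.data):

import torch
from zeroband.data import collate_fn

samples = [
    {"input_ids": torch.zeros(5, dtype=torch.long), "labels": torch.zeros(5, dtype=torch.long), "seqlens": [2, 2, 1]},
    {"input_ids": torch.ones(5, dtype=torch.long), "labels": torch.ones(5, dtype=torch.long), "seqlens": [3, 2]},
]

batch = collate_fn(samples)
print(batch["seqlens"])
# old behaviour: tensor([2, 2, 1, 3, 2])               (one flat tensor for the whole batch)
# new behaviour: [tensor([2, 2, 1]), tensor([3, 2])]   (one tensor per sample)
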
6 changes: 4 additions & 2 deletions src/zeroband/models/llama/__init__.py
@@ -81,7 +81,10 @@


def get_model(
name_model: str, type_model: str, vocab_size: int, seq_length: int, attn_fn: str
name_model: str,
type_model: str,
vocab_size: int,
seq_length: int,
) -> tuple[Transformer, ModelArgs]:
"""get the transformer model"""

@@ -94,6 +97,5 @@ def get_model(

config.vocab_size = vocab_size
config.max_seq_len = seq_length
config.attn_fn = attn_fn

return Transformer(config), config
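
For callers, get_model simply loses the attn_fn argument; attention behaviour is now decided at forward time by whether a block_mask is passed. A hedged sketch of the updated call (the name_model and type_model values below are placeholders, not checked against the repo's model registry):

from zeroband.models.llama import get_model

model, model_args = get_model(
    name_model="debugmodel",   # placeholder value
    type_model="llama2",       # placeholder value
    vocab_size=32000,
    seq_length=1024,
)
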
142 changes: 79 additions & 63 deletions src/zeroband/models/llama/model.py
@@ -12,18 +12,31 @@


from dataclasses import dataclass
from importlib.util import find_spec
from typing import Literal, Optional, Tuple
from einops import rearrange
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from zeroband.models.norms import build_norm

flash_attn_available = find_spec("flash_attn") is not None
if flash_attn_available:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE

_flex_attention_compiled = torch.compile(flex_attention, dynamic=False)


# copied from https://github.com/pytorch/torchtune/blob/f2bd4bc25b24587aef40f486087412b9da8f1d94/torchtune/modules/attention_utils.py#L27
# We cannot do nested compile, but flex attention only has perf benefits
# when compiled. To insulate it from the compiler, we wrap it with
# compiler.disable so that it can be used regardless of whether the model
# is compiled or not, and flex attention always remains compiled.
@torch.compiler.disable(recursive=False)
def flex_attention_compiled(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_mask: BlockMask,
) -> torch.Tensor:
return _flex_attention_compiled(q, k, v, block_mask=block_mask)


@dataclass
@@ -45,8 +58,6 @@ class ModelArgs:
depth_init: bool = True
norm_type: str = "fused_rmsnorm"

attn_fn: Literal["sdpa", "flash"] = "sdpa"


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
"""
@@ -138,6 +149,48 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
)


def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor:
"""Converts list of sequence lengths to document indices tensor.
Example:
seqlens = [tensor([2,2,1]), tensor([2,2,1])] # List of 2 tensors
docs = [[0,0,1,1,2], [0,0,1,1,2]] # Each doc_id repeated per its length
"""
return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])


def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask:
"""Creates a block mask from a list of sequence lengths.
Example:
seqlens = [tensor([2,2,1])]  # List with a single packed sample
docs = [[0,0,1,1,2]]  # Each doc_id repeated per its length
mask = [[1 0 0 0 0]  # First token of doc 0 can only see itself (causal)
[1 1 0 0 0]  # Second token of doc 0 can see both tokens of doc 0
[0 0 1 0 0]  # First token of doc 1 can only see itself (causal)
[0 0 1 1 0]  # Second token of doc 1 can see both tokens of doc 1
[0 0 0 0 1]]  # Token of doc 2 can only see itself
"""
docs = seqlens_to_docs_tensor(seqlens).to("cuda")
batch_size, max_seq_len = docs.shape

def document_causal_mask(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[b, q_idx] == docs[b, kv_idx]
return causal_mask & document_mask

return create_block_mask(
document_causal_mask,
batch_size,
None,
max_seq_len,
max_seq_len,
device="cuda",
_compile=True,
BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE,
)


class Attention(nn.Module):
"""
Multi-head attention module.
@@ -164,8 +217,6 @@ def __init__(self, model_args: ModelArgs):
self.n_rep = self.n_heads // self.n_kv_heads
self.head_dim = model_args.dim // model_args.n_heads

self.attn_fn = model_args.attn_fn

self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
@@ -180,7 +231,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
seqlens: torch.Tensor | None = None,
block_mask: BlockMask | None = None,
):
"""
Forward pass of the attention module.
@@ -214,7 +265,7 @@ def forward(
xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)

output = self.self_attention(xq, xk, xv, seqlens)
output = self.self_attention(xq, xk, xv, block_mask)

output = output.view(bs, seqlen, -1)
return self.wo(output)
@@ -224,55 +275,21 @@ def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor:
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
return output

def _flash_attention(self, xq, xk, xv) -> torch.Tensor:
q = rearrange(xq, "b n t h -> b t n h")
k = rearrange(xk, "b n t h -> b t n h")
v = rearrange(xv, "b n t h -> b t n h")
# q/k/b is [b, nh, t, hs] but fa2 expected [b , t, nh, hs]
return flash_attn_func(q, k, v, causal=True)

def _fa_attention_with_seqlens(self, xq, xk, xv, seqlens) -> torch.Tensor:
b = xq.shape[0]
cu_seqlens = (
torch.concat([torch.tensor([0]).to(xq.device), seqlens.cumsum(0)], dim=0).to(torch.int32).to(xq.device)
)
max_seqlen = seqlens.max()

q = rearrange(xq, "b n t h -> (b t) n h")
k = rearrange(xk, "b n t h -> (b t) n h")
v = rearrange(xv, "b n t h -> (b t) n h")
# q/k/v is [b, nh, t, hs] but fa expected [b * t, nh, hs]

y = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=True,
)

y = rearrange(y, "(b t) n h -> b t n h", b=b)
return y
def _flex_attention_with_seqlens(self, xq, xk, xv, block_mask: BlockMask) -> torch.Tensor:
output = flex_attention_compiled(xq, xk, xv, block_mask=block_mask)
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
return output
# output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
# output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
# return output

def self_attention(
self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, seqlens: torch.Tensor | None = None
self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, block_mask: BlockMask | None = None
) -> torch.Tensor:
if self.attn_fn == "sdpa":
if seqlens is not None:
raise NotImplementedError("SDPA with seqlens is not implemented.")
return self._sdpa_attention(xq, xk, xv)
elif self.attn_fn == "flash":
if not flash_attn_available:
raise RuntimeError("Flash attention is not available. Please install flash_attn.")
if seqlens is not None:
return self._fa_attention_with_seqlens(xq, xk, xv, seqlens)
else:
return self._flash_attention(xq, xk, xv)
if block_mask is not None:
return self._flex_attention_with_seqlens(xq, xk, xv, block_mask)
else:
raise ValueError(f"Unknown attention function: {self.attn_fn}")
return self._sdpa_attention(xq, xk, xv)


class FeedForward(nn.Module):
@@ -365,7 +382,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
seqlens: torch.Tensor | None = None,
block_mask: BlockMask | None = None,
):
"""
Perform a forward pass through the TransformerBlock.
@@ -378,7 +395,7 @@ def forward(
torch.Tensor: Output tensor after applying attention and feedforward layers.
"""
h = x + self.attention(self.attention_norm(x), freqs_cis, seqlens=seqlens)
h = x + self.attention(self.attention_norm(x), freqs_cis, block_mask=block_mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out

@@ -475,14 +492,13 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
self.model_args.rope_theta,
)

def forward(self, tokens: torch.Tensor, seqlens: torch.Tensor | None = None):
def forward(self, tokens: torch.Tensor, block_mask: BlockMask | None = None):
"""
Perform a forward pass through the Transformer model.
Args:
tokens (torch.Tensor): Input token indices.
seqlens (torch.Tensor | None): Sequence lengths tensor for packing.
block_mask (BlockMask | None): Block mask for attention.
Returns:
torch.Tensor: Output logits after applying the Transformer model.
@@ -491,7 +507,7 @@ def forward(self, tokens: torch.Tensor, seqlens: torch.Tensor | None = None):
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
h = layer(h, self.freqs_cis, seqlens=seqlens)
h = layer(h, self.freqs_cis, block_mask=block_mask)

h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
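
The docstrings above describe the mask in terms of per-token document ids; the same predicate used by document_causal_mask can be materialized densely on CPU to check the semantics without flex attention or a GPU. A sketch reusing seqlens_to_docs_tensor from this file (toy values, assuming the zeroband package is importable):

import torch
from zeroband.models.llama.model import seqlens_to_docs_tensor

seqlens = [torch.tensor([2, 2, 1])]
docs = seqlens_to_docs_tensor(seqlens)   # tensor([[0, 0, 1, 1, 2]])

idx = torch.arange(docs.shape[1])
causal = idx[:, None] >= idx[None, :]            # q_idx >= kv_idx
same_doc = docs[0][:, None] == docs[0][None, :]  # docs[q_idx] == docs[kv_idx]
print((causal & same_doc).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 1]])
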
19 changes: 9 additions & 10 deletions src/zeroband/train.py
@@ -1,6 +1,7 @@
import os
from typing import Literal
import time
import warnings
import psutil
from pydantic import model_validator
from multiprocessing.process import _children
@@ -19,6 +20,7 @@
from zeroband.diloco import Diloco, DilocoConfig
from zeroband.comms import ElasticDeviceMesh
from zeroband.loss import cross_entropy_max_z_loss
from zeroband.models.llama.model import create_block_mask_from_seqlens

from zeroband.utils import (
FakeTokenizer,
@@ -75,12 +77,12 @@ class TrainConfig(BaseConfig):
memory_profiler: MemoryProfilerConfig | None = None

sequence_packing: bool = True
attn_fn: Literal["flash", "sdpa"] = "flash"
attn_fn: Literal["flash", "sdpa"] | None = None

@model_validator(mode="after")
def validate_attn_fn(self):
if self.attn_fn == "sdpa" and self.sequence_packing:
raise ValueError("SDPA does not support sequence packing")
if self.attn_fn is not None:
warnings.warn("attn_fn argument is deprecated")

return self

@@ -162,7 +164,6 @@ def train(config: Config):
config.type_model,
vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
seq_length=config.data.seq_length,
attn_fn=config.train.attn_fn,
)

model = model.to(world_info.local_rank)
@@ -360,14 +361,12 @@ def train(config: Config):
input_ids = batch["input_ids"].to("cuda")
labels = batch["labels"].to("cuda")
if config.train.sequence_packing:
seqlens = batch["seqlens"].to("cuda")
# seqlens has a dynamic shape but fixed dimension, this allow to still torch compile
# https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html
torch._dynamo.mark_dynamic(seqlens, 0)
seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]]
block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None
else:
seqlens = None
block_mask = None

logits = model(tokens=input_ids, seqlens=seqlens).contiguous()
logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
flatten_labels = rearrange(labels, "b seq -> (b seq)")
