Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Dec 1, 2024
1 parent 1d5dc56 commit 0fa81e6
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 130 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ description = "ZeroBand is a production ready codebase for decentralized trainin
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch==2.4.1",
"torch==2.5.1",
"numpy",
"setuptools",
"transformers>=4.44.2",
Expand Down
6 changes: 3 additions & 3 deletions src/zeroband/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def load_state_dict(self, state_dict):
self.state = SequencePackingDataSetState(**state_dict["state"])


def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor]:
def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor | list[torch.LongTensor]]:
assert samples[0].keys() == {"input_ids", "labels", "seqlens"}

inputs_ids = []
Expand All @@ -135,12 +135,12 @@ def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.Lo
inputs_ids.append(sample["input_ids"])
labels.append(sample["labels"])

seqlens.extend(sample["seqlens"])
seqlens.extend(torch.Tensor(sample["seqlens"]).long())

return {
"input_ids": torch.stack(inputs_ids, dim=0),
"labels": torch.stack(labels, dim=0),
"seqlens": torch.Tensor(seqlens).long(),
"seqlens": seqlens,
}


Expand Down
6 changes: 4 additions & 2 deletions src/zeroband/models/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@


def get_model(
name_model: str, type_model: str, vocab_size: int, seq_length: int, attn_fn: str
name_model: str,
type_model: str,
vocab_size: int,
seq_length: int,
) -> tuple[Transformer, ModelArgs]:
"""get the transformer model"""

Expand All @@ -94,6 +97,5 @@ def get_model(

config.vocab_size = vocab_size
config.max_seq_len = seq_length
config.attn_fn = attn_fn

return Transformer(config), config
86 changes: 32 additions & 54 deletions src/zeroband/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@


from dataclasses import dataclass
from importlib.util import find_spec
from typing import Literal, Optional, Tuple
from einops import rearrange
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from zeroband.models.norms import build_norm

flash_attn_available = find_spec("flash_attn") is not None
if flash_attn_available:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

flex_attention_compiled = torch.compile(flex_attention, dynamic=False)


@dataclass
Expand All @@ -45,8 +43,6 @@ class ModelArgs:
depth_init: bool = True
norm_type: str = "fused_rmsnorm"

attn_fn: Literal["sdpa", "flash"] = "sdpa"


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
"""
Expand Down Expand Up @@ -138,6 +134,15 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
)


def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor:
"""Converts list of sequence lengths to document indices tensor.
Example:
seqlens = [tensor([2,2,1]), tensor([2,2,1])] # List of 2 tensors
docs = [[0,0,1,1,2], [0,0,1,1,2]] # Each doc_id repeated per its length
"""
return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])


class Attention(nn.Module):
"""
Multi-head attention module.
Expand All @@ -164,8 +169,6 @@ def __init__(self, model_args: ModelArgs):
self.n_rep = self.n_heads // self.n_kv_heads
self.head_dim = model_args.dim // model_args.n_heads

self.attn_fn = model_args.attn_fn

self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
Expand All @@ -180,7 +183,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
seqlens: torch.Tensor | None = None,
seqlens: list[torch.Tensor] | None = None,
):
"""
Forward pass of the attention module.
Expand Down Expand Up @@ -224,55 +227,30 @@ def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor:
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
return output

def _flash_attention(self, xq, xk, xv) -> torch.Tensor:
q = rearrange(xq, "b n t h -> b t n h")
k = rearrange(xk, "b n t h -> b t n h")
v = rearrange(xv, "b n t h -> b t n h")
# q/k/b is [b, nh, t, hs] but fa2 expected [b , t, nh, hs]
return flash_attn_func(q, k, v, causal=True)

def _fa_attention_with_seqlens(self, xq, xk, xv, seqlens) -> torch.Tensor:
b = xq.shape[0]
cu_seqlens = (
torch.concat([torch.tensor([0]).to(xq.device), seqlens.cumsum(0)], dim=0).to(torch.int32).to(xq.device)
)
max_seqlen = seqlens.max()

q = rearrange(xq, "b n t h -> (b t) n h")
k = rearrange(xk, "b n t h -> (b t) n h")
v = rearrange(xv, "b n t h -> (b t) n h")
# q/k/v is [b, nh, t, hs] but fa expected [b * t, nh, hs]

y = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=True,
def _flex_attention_with_seqlens(self, xq, xk, xv, seqlens: list[torch.Tensor]) -> torch.Tensor:
docs = seqlens_to_docs_tensor(seqlens).to(xq.device)
batch_size, max_seq_len = docs.shape

def document_causal_mask(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[b, q_idx] == docs[b, kv_idx]
return causal_mask & document_mask

block_mask = create_block_mask(
document_causal_mask, batch_size, None, max_seq_len, max_seq_len, device="cuda", _compile=True
)

y = rearrange(y, "(b t) n h -> b t n h", b=b)
return y
output = flex_attention_compiled(xq, xk, xv, block_mask=block_mask)
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
return output

def self_attention(
self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, seqlens: torch.Tensor | None = None
self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, seqlens: list[torch.Tensor] | None = None
) -> torch.Tensor:
if self.attn_fn == "sdpa":
if seqlens is not None:
raise NotImplementedError("SDPA with seqlens is not implemented.")
return self._sdpa_attention(xq, xk, xv)
elif self.attn_fn == "flash":
if not flash_attn_available:
raise RuntimeError("Flash attention is not available. Please install flash_attn.")
if seqlens is not None:
return self._fa_attention_with_seqlens(xq, xk, xv, seqlens)
else:
return self._flash_attention(xq, xk, xv)
if seqlens is not None:
return self._flex_attention_with_seqlens(xq, xk, xv, seqlens)
else:
raise ValueError(f"Unknown attention function: {self.attn_fn}")
return self._sdpa_attention(xq, xk, xv)


class FeedForward(nn.Module):
Expand Down
14 changes: 8 additions & 6 deletions src/zeroband/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Literal
import time
import warnings
import psutil
from pydantic import model_validator
from multiprocessing.process import _children
Expand Down Expand Up @@ -75,12 +76,12 @@ class TrainConfig(BaseConfig):
memory_profiler: MemoryProfilerConfig | None = None

sequence_packing: bool = True
attn_fn: Literal["flash", "sdpa"] = "flash"
attn_fn: Literal["flash", "sdpa"] | None = None

@model_validator(mode="after")
def validate_attn_fn(self):
if self.attn_fn == "sdpa" and self.sequence_packing:
raise ValueError("SDPA does not support sequence packing")
if self.attn_fn is not None:
warnings.warn("attn_fn argument is deprecated")

return self

Expand Down Expand Up @@ -162,7 +163,6 @@ def train(config: Config):
config.type_model,
vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
seq_length=config.data.seq_length,
attn_fn=config.train.attn_fn,
)

model = model.to(world_info.local_rank)
Expand Down Expand Up @@ -360,10 +360,12 @@ def train(config: Config):
input_ids = batch["input_ids"].to("cuda")
labels = batch["labels"].to("cuda")
if config.train.sequence_packing:
seqlens = batch["seqlens"].to("cuda")
seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]]

# seqlens has a dynamic shape but fixed dimension, this allow to still torch compile
# https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html
torch._dynamo.mark_dynamic(seqlens, 0)
# torch._dynamo.mark_dynamic(seqlens, 0)
logger.debug(f"seqlens: {seqlens}")
else:
seqlens = None

Expand Down
29 changes: 21 additions & 8 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,36 @@ def test_attn(llama_config: ModelArgs):

freqs_cis = get_freqs_cis(llama_config)
input_ = torch.rand(bs, seq_len, llama_config.dim).to("cuda")
seqlens = [torch.Tensor([seq_len]).int().to("cuda") for _ in range(bs)]

attn = Attention(llama_config).to("cuda")
attn.attn_fn = "sdpa"

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output_sdpa = attn(input_, freqs_cis)

attn.attn_fn = "flash"
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output_fa = attn(input_, freqs_cis)
output_flex = attn(input_, freqs_cis, seqlens=seqlens)

rtol = ERROR_RTOL[torch.bfloat16]
atol = ERROR_ATOL[torch.bfloat16]
assert output_sdpa.shape == output_fa.shape
torch.testing.assert_close(output_sdpa, output_fa, rtol=rtol, atol=atol)
assert output_sdpa.shape == output_flex.shape
torch.testing.assert_close(output_sdpa, output_flex, rtol=rtol, atol=atol)


def test_packing_simple(llama_config: ModelArgs):
seq_len = 512
bs = 8

freqs_cis = get_freqs_cis(llama_config)
input_ = torch.rand(bs, seq_len, llama_config.dim).to("cuda")
seqlens = [torch.Tensor([seq_len // 4] * 4).int().to("cuda") for _ in range(bs)]

attn = Attention(llama_config).to("cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output = attn(input_, freqs_cis, seqlens=seqlens)

assert output.shape == (bs, seq_len, llama_config.dim)


def test_sequence_packing_two_time_same_sequence(llama_config: ModelArgs):
Expand All @@ -73,15 +88,13 @@ def test_sequence_packing_two_time_same_sequence(llama_config: ModelArgs):
We then pass the packed sequence to the attention layer and check that the output for each sequence is the same.
"""

llama_config.attn_fn = "flash"
model = Attention(llama_config).to("cuda")

emb = torch.nn.Embedding(10, llama_config.dim).to("cuda")

seq = [2, 1, 4, 8]
input_stuff_raw = torch.Tensor([seq + seq]).long().to("cuda")
seqlens = [len(seq), len(seq)]
seqlens = torch.Tensor(seqlens).int().to("cuda")
seqlens = [torch.Tensor([len(seq), len(seq)]).int().to("cuda")]

input_stuff = emb(input_stuff_raw)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_torchrun/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_z_loss():
_test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--optim.z_loss"])


@pytest.mark.parametrize("packing", [True, False])
@pytest.mark.parametrize("packing", [True]) # , False])
def test_packing(packing: bool):
num_gpus = [2, 1]
packing_arg = "--train.sequence_packing" if packing else "--no-train.sequence_packing"
Expand Down
Loading

0 comments on commit 0fa81e6

Please sign in to comment.