[WIP] Voltron v0 #1

Open. Wants to merge 42 commits into base: main.

Commits (42)
f43824f
wip for v0 for testing
winglian Apr 13, 2024
b7cd6ea
mini fixes
winglian Apr 13, 2024
9b56525
fix install and train
winglian Apr 13, 2024
a357c7f
fix config and dataset
winglian Apr 13, 2024
1bfbdb5
fix lr
winglian Apr 13, 2024
072b74a
fix text field of dataset
winglian Apr 13, 2024
38a872d
improve data handling
winglian Apr 13, 2024
6c0e92b
flesh out the model w/ attn
winglian Apr 13, 2024
91c4fa2
fix args/kwargs ordering
winglian Apr 13, 2024
c298db1
use LlamaBitMGQA
winglian Apr 13, 2024
b5487fd
fix order of init for module
winglian Apr 13, 2024
a7854a2
make tinier and fix dataset map
winglian Apr 13, 2024
bc4ce8a
fix back to use dataset.data.columns
winglian Apr 13, 2024
122316e
make sure to remove extra columns
winglian Apr 13, 2024
a1c7b2a
fix data loop and make tinier
winglian Apr 13, 2024
0d76c4e
use generic collator to pad equally
winglian Apr 13, 2024
91d4a11
account for position_ids in mod block
winglian Apr 13, 2024
81e18b9
flesh out rotary embeddings
winglian Apr 13, 2024
bc5ac07
misc fixes
winglian Apr 13, 2024
c5a4c71
fix tokenizer and activation checkpointing
winglian Apr 13, 2024
c493b65
more fixes
winglian Apr 13, 2024
8aa8d81
remove hard dependencies from axolotl
winglian Apr 13, 2024
e6edf93
remove more hard deps
winglian Apr 13, 2024
270150f
re-enable DWA again
winglian Apr 13, 2024
49cc04a
actually check for dwa
winglian Apr 13, 2024
4437672
wandb on main rank only
winglian Apr 13, 2024
c2f804c
fix modulo for log steps
winglian Apr 13, 2024
88f25a9
attempt to use accelerator loop
winglian Apr 13, 2024
a10c31a
update configuration
winglian Apr 13, 2024
8239c6e
wip rms norm
winglian Apr 14, 2024
2b2f332
use apex rms norm optim
winglian Apr 14, 2024
cba6e66
queued dataloader and gradient norm
winglian Apr 14, 2024
ffafd7a
fixes for loss calc, grad accum, dataloader for dispatch_batches
winglian Apr 14, 2024
0255ffe
tweak size names
winglian Apr 15, 2024
d289d98
upcast/downcast
winglian Apr 15, 2024
84d755a
integrate infini-attention
winglian Apr 15, 2024
665530b
handle position_id if passed, throw it on the floor
winglian Apr 15, 2024
2439688
match infini-attention segment len to mixture of depth
winglian Apr 15, 2024
a77e6ae
misc fixes for integrations
winglian Apr 15, 2024
6a49def
make infini-attention work
winglian Apr 15, 2024
555087a
fix dimensions passed to infini-attention
winglian Apr 15, 2024
b117faa
fix perplexity calculation and add quick instructions
winglian Apr 16, 2024
12 changes: 12 additions & 0 deletions README.md
@@ -9,6 +9,18 @@ Assembling the best SotA AI techniques into a unified model

https://twitter.com/winglian/status/1778675583817326842

## Easy Start

Use the official NVIDIA PyTorch Docker container, `nvcr.io/nvidia/pytorch:24.03-py3`:

```bash
git clone https://github.com/OpenAccess-AI-Collective/voltronformers.git
cd voltronformers
pip install -e .
accelerate launch train.py
```


# References

## BitNet
36 changes: 36 additions & 0 deletions pyproject.toml
@@ -0,0 +1,36 @@
[project]
name = "voltronformers"
dynamic = ["version"]
requires-python = ">= 3.10"
dependencies = [
"accelerate",
"addict",
"bitnet",
"schedulefree",
"bitsandbytes",
"datasets",
"einops",
"flash-attn",
"mosaicml-streaming",
"numba",
"numpy",
"safetensors",
"wandb",
"tqdm",
"transformers==4.39.3",
"zstandard",
"denseformer @ git+https://github.com/epfml/DenseFormer.git@main",
]
maintainers = [
{name="Wing Lian", email="[email protected]"},
]
description = "voltronformers: Assembling the best SotA AI techniques into a unified model"

[project.optional-dependencies]
dev = [
"tox",
"pre-commit",
"black",
"mypy",
"pytest",
]
Empty file added src/__init__.py
Empty file added src/voltronformer/__init__.py
3 changes: 3 additions & 0 deletions src/voltronformer/bitlinear/__init__.py
@@ -0,0 +1,3 @@
# from .cg123 import BitLinear
from .official import BitLinear
from .attention import scaled_dot_product_gqa
143 changes: 143 additions & 0 deletions src/voltronformer/bitlinear/attention.py
@@ -0,0 +1,143 @@
from typing import Optional

import torch
import torch.nn.functional as F
from einops import einsum, rearrange
from torch import Tensor


def scaled_dot_product_gqa(
query: Tensor,
key: Tensor,
value: Tensor,
dropout: float = 0.0,
scale: Optional[float] = None,
mask: Optional[Tensor] = None,
is_causal: Optional[bool] = None,
need_weights: bool = False,
average_attn_weights: bool = False,
force_grouped: bool = False,
):
"""Scaled dot product attention with support for grouped queries.

Einstein notation:
- b: batch size
- n / s: sequence length
- h: number of heads
- g: number of groups
- d: dimension of query/key/value

Args:
query: Query tensor of shape (b, n, h, d)
key: Key tensor of shape (b, s, h, d)
value: Value tensor of shape (b, s, h, d)
dropout: Dropout probability (default: 0.0)
scale: Scale factor for query (default: d_query ** 0.5)
        mask: Mask tensor of shape (b, n, s) or (b, s). If 'ndim == 2', the mask is
            applied to all 'n' rows of the attention matrix. (default: None)
        is_causal: If True, apply a causal (lower-triangular) mask so that each position
            can only attend to itself and earlier positions. Mutually exclusive with
            'mask'. (default: None)
        need_weights: If True, also return the attention weights. (default: False)
        force_grouped: If True, apply grouped-query attention even if the number of
            heads is equal for query, key, and value. (default: False)

Returns:
2-tuple of:
- Attention output with shape (b, n, h, d)
- (Optional) Attention weights with shape (b, h, n, s). Only returned if
'need_weights' is True.
"""
if (mask is not None) and (is_causal is not None):
raise ValueError(
"Only one of 'mask' and 'is_causal' should be provided, but got both."
)
elif not query.ndim == key.ndim == value.ndim == 4:
raise ValueError(
f"Expected query, key, and value to be 4-dimensional, but got shapes "
f"{query.shape}, {key.shape}, and {value.shape}."
)

# Move sequence length dimension to axis 2.
# This makes the attention operations below *much* faster.
query = rearrange(query, "b n h d -> b h n d")
key = rearrange(key, "b s h d -> b h s d")
value = rearrange(value, "b s h d -> b h s d")

bq, hq, nq, dq = query.shape
bk, hk, nk, dk = key.shape
bv, hv, nv, dv = value.shape
if not (bq == bk == bv and dq == dk == dv):
raise ValueError(
"Expected query, key, and value to have the same batch size (dim=0) and "
f"embedding dimension (dim=3), but got query: {query.shape}, "
f"key: {key.shape}, and value: {value.shape}."
)
elif (hk != hv) or (nk != nv):
raise ValueError(
"Expected key and value to have the same size in dimensions 1 and 2, but "
f"got key: {key.shape} and value: {value.shape}."
)
elif hq % hk != 0:
raise ValueError(
"Expected query heads to be a multiple of key/value heads, but got "
f"query: {query.shape} and key/value: {key.shape}."
)

if scale is None:
scale = query.size(-1) ** 0.5
query = query / scale

num_head_groups = hq // hk
if num_head_groups > 1 or force_grouped:
        # Separate the query heads into 'num_head_groups' chunks. The einsum below
        # contracts over the group dimension, so each key/value head attends on
        # behalf of its whole group of query heads.
query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
similarity = einsum(query, key, "b g h n d, b h s d -> b h n s")
else:
# If the number of query/key heads is equal, we can skip grouping the queries,
        # and just use standard scaled dot-product attention.
similarity = einsum(query, key, "b h n d, b h s d -> b h n s")

if is_causal:
# Mask out the upper triangular portion of the attention matrix. This prevents
# the model from attending to tokens in the future.
mask = torch.ones(
(bq, nq, nk),
device=query.device,
dtype=torch.bool,
).tril_()

if mask is not None:
# Expand mask to match the shape of the attention matrix.
# If mask is 2D, assume that it is applied to the key/value sequence dimension.
# Else if mask is 3D, assume that it is applied to the query/key/value sequence
# dimension for all attention heads.
#
# Users could also provide a 4D mask, which is applied to the query/key/value
# sequence dimension for each attention head (though I don't have a particular
# use case in mind for that).
if mask.ndim == 2:
mask = rearrange(mask, "b s -> b () () s")
elif mask.ndim == 3:
mask = rearrange(mask, "b n s -> b () n s")
# Mask similarity values by setting them to negative infinity. This guarantees
# that they will not contribute to the softmax computation below.
similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)

    # The query was already scaled above, so the similarity is not scaled again here.
    attention = F.softmax(similarity, dim=-1, dtype=torch.float32).to(dtype=query.dtype)
if dropout > 0.0:
attention = F.dropout(attention, p=dropout)

# Apply attention matrix to the value Tensor.
out = einsum(attention, value, "b h n s, b h s d -> b h n d")
# Move head dimension back to axis 2
out = rearrange(out, "b h n d -> b n h d")

attn_weights: Optional[Tensor] = None
if need_weights:
# Move the sequence dimensions back to positions 1, 2. Move the head dimension
# to position 3. This more closely matches the return shape of the attention
# output: (b, n, h, d).
attn_weights = rearrange(attention, "b h n s -> b n s h")
if average_attn_weights:
attn_weights = attn_weights.mean(dim=1)

return out, attn_weights
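For reference, a minimal usage sketch of the `scaled_dot_product_gqa` above (not part of the diff; the import path and tensor sizes are illustrative and assume the package installs as `voltronformer`):

```python
import torch

from voltronformer.bitlinear import scaled_dot_product_gqa

# 8 query heads share 2 key/value heads (4 query heads per group).
batch, seq_len, dim = 2, 16, 64
query = torch.randn(batch, seq_len, 8, dim)  # (b, n, h, d)
key = torch.randn(batch, seq_len, 2, dim)    # (b, s, kv_heads, d)
value = torch.randn(batch, seq_len, 2, dim)

out, weights = scaled_dot_product_gqa(query, key, value, is_causal=True, need_weights=True)
# The grouped branch contracts the query groups in the similarity einsum, so the
# output carries the key/value head count: (2, 16, 2, 64); weights are (2, 16, 16, 2).
print(out.shape, weights.shape)
```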
172 changes: 172 additions & 0 deletions src/voltronformer/bitlinear/cg123.py
@@ -0,0 +1,172 @@
"""
Implementation of the BitLinear layer described in the papers:

1. "BitNet: Scaling 1-bit Transformers for Large Language Models"
2. "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits"

References:
- https://arxiv.org/abs/2310.11453
- https://arxiv.org/abs/2402.17764
"""

#!/usr/bin/env python3
# Copyright (C) 2024 Charles O. Goddard

import math
from typing import NamedTuple, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _ste(x: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
"""Straight-through estimator."""
return x0 + (x - x0).detach()


@torch.compile()
def _quantize(
x: Optional[torch.Tensor], is_input: bool, num_groups: int, eps: float
) -> Tuple[torch.Tensor, torch.Tensor]:
if x is None:
return None, None

x0 = x
if is_input:
# split last dimension into num_groups
x = x.view(list(x.shape[:-1]) + [num_groups, -1])
scale_factor = x.abs().max(dim=-1, keepdim=True).values
else:
# first dimension is output features, so split that
x = x.view([num_groups, -1] + list(x.shape[1:]))
scale_factor = x.abs().mean(dim=list(range(1, len(x.shape))), keepdim=True)

x_scaled = x / (scale_factor + eps)
if is_input:
x_q = (x_scaled * 127).clamp(-127, 127).to(torch.int8)
else:
x_q = x_scaled.round().clamp(-1, 1).to(torch.int8)

    if not is_input:
        # reshape the per-group weight scale so it broadcasts like the input scale
        scale_factor = scale_factor.view(1, 1, num_groups, 1)

return _ste(x_q, x_scaled).view_as(x0), scale_factor


class QuantizedWeights(NamedTuple):
"""Quantized weight and optional bias tensor for BitLinear."""

w_q: torch.Tensor
bias_q: Optional[torch.Tensor]
beta: torch.Tensor


@torch.compile()
def _quantize_weights(
weight: torch.Tensor,
bias: Optional[torch.Tensor],
num_groups: int,
eps: float,
) -> QuantizedWeights:
w_q, beta = _quantize(weight, is_input=False, num_groups=num_groups, eps=eps)
bias_q, _ = _quantize(bias, is_input=True, num_groups=num_groups, eps=eps)
# bias assumes the scale factor of weights
return QuantizedWeights(w_q=w_q, bias_q=bias_q, beta=beta)


def _pack_ternary(x: torch.Tensor) -> torch.Tensor:
"""Pack ternary float tensor into int8 tensor. Uses ~1.6 bits per element."""

x_packed = torch.empty(
        x.shape[:-1] + (math.ceil(x.shape[-1] / 5),), dtype=torch.int8
)
for i in range(0, x.shape[-1], 5):
chunk = x[..., i : i + 5].to(torch.int8).view(x.shape[:-1] + (1, 5))
# -1 -> 0, 0 -> 1, 1 -> 2
chunk = chunk + 1
# store as base-3 number
chunk = (
chunk
* torch.tensor([1, 3, 9, 27, 81], device=chunk.device, dtype=chunk.dtype)
).sum(dim=-1)
x_packed[..., i // 5] = chunk
return x_packed


class BitLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
*args,
preserve_scale: bool = False,
num_groups: int = 1,
eps: float = 1e-7,
bias: bool = False,
**kwargs,
):
if num_groups < 1:
raise ValueError("num_groups must be >= 1")
if num_groups > 1 and out_features % num_groups != 0:
raise ValueError("out_features must be divisible by num_groups")

super().__init__(in_features, out_features, *args, bias=bias, **kwargs)
self.input_norm = nn.LayerNorm(self.in_features, elementwise_affine=False)
self.preserve_scale = preserve_scale
self.num_groups = num_groups
self.eps = eps

@torch.compile()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.input_norm(x)
x_q, gamma = _quantize(
x, is_input=True, num_groups=self.num_groups, eps=self.eps
)
w_q, bias_q, beta = _quantize_weights(
self.weight, self.bias, num_groups=self.num_groups, eps=self.eps
)

y = F.linear(x_q, w_q, bias_q)
y = y.to(x.dtype) / 127
if self.preserve_scale:
y_grouped = y.view(list(y.shape[:-1]) + [self.num_groups, -1])
y = (y_grouped * gamma * beta).reshape_as(y)

return y


class BitConv2d(nn.Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
*args,
preserve_scale: bool = False,
eps: float = 1e-7,
bias: bool = False,
**kwargs,
):
super().__init__(
in_channels, out_channels, kernel_size, *args, bias=bias, **kwargs
)
self.input_norm = nn.GroupNorm(1, self.in_channels, affine=False)
self.preserve_scale = preserve_scale
self.eps = eps

@torch.compile()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.input_norm(x)
x_q, gamma = _quantize(x, is_input=True, num_groups=1, eps=self.eps)
w_q, bias_q, beta = _quantize_weights(
self.weight, self.bias, num_groups=1, eps=self.eps
)

y = F.conv2d(x_q, w_q, bias_q, self.stride, self.padding, self.dilation)
y = y.to(x.dtype) / 127
if self.preserve_scale:
y_grouped = y.view(list(y.shape[:-1]) + [1, -1])
y = (y_grouped * gamma * beta).reshape_as(y)

return y
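For completeness, a minimal sketch of using this `BitLinear` as a drop-in for `nn.Linear` (illustrative sizes; note that `bitlinear/__init__.py` currently imports the `official` BitLinear instead of this variant):

```python
import torch

from voltronformer.bitlinear.cg123 import BitLinear

# 256 input features split into 4 quantization groups; 512 ternary-weight outputs.
layer = BitLinear(in_features=256, out_features=512, num_groups=4, preserve_scale=True)

x = torch.randn(8, 256)
y = layer(x)  # activations quantized to the int8 range, weights to {-1, 0, 1} via the STE
print(y.shape)  # torch.Size([8, 512])
```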