Merge pull request #154 from Modalities/add_swiglu
Manual SwiGLU implementation
le1nux committed Jun 19, 2024
2 parents ed3fb62 + 333fcc0 commit 4aa2e88
Showing 8 changed files with 100 additions and 15 deletions.
18 changes: 18 additions & 0 deletions CHANGELOG_DEV.md
@@ -0,0 +1,18 @@
# Changelog

| PR | Type | Ref. Issue(s) | Breaking Changes |PR Description|
|------------------|------------|---------------|------------------|------------------------------------------------------------------------------------------------|
| [#154](#pr-154-manual-swiglu-implementation) | Bug Fix | [#14](https://github.com/Modalities/modalities/issues/14) | **Yes** | Towards stable modalities version |
| | | | | |



## PR #154 Manual SwiGLU implementation

This [PR](https://github.com/Modalities/modalities/pull/154) adds a manual SwiGLU implementation. The original one from xops was incompatible with activation checkpointing (see issue [#14](https://github.com/Modalities/modalities/issues/14)).

**General changes:**
* replaces the xops SwiGLU implementation with a custom reimplementation

**Breaking changes:**
* renaming of `fused_swiglu` to `swiglu` in `ActivationType` (see [here](https://github.com/Modalities/modalities/pull/154/commits/90fb3bd06a407333423cffeab486711e26ef8ddf) for the respective config changes)
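Because the enum value itself changed, existing configs that still set `activation: fused_swiglu` will no longer validate. Below is a minimal sketch (not part of the diff) of what the rename means in practice, using the `ActivationType` definition from the gpt2_model.py diff further down; the assertion-style checks are purely illustrative.

```python
from enum import Enum

class ActivationType(str, Enum):
    GELU = "gelu"
    SWIGLU = "swiglu"  # was FUSED_SWIGLU = "fused_swiglu" before this PR

# Configs therefore have to switch `activation: fused_swiglu` to `activation: swiglu`.
assert ActivationType("swiglu") is ActivationType.SWIGLU
try:
    ActivationType("fused_swiglu")
except ValueError:
    pass  # the old config value no longer resolves to an enum member
```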
2 changes: 1 addition & 1 deletion config_files/training/config_example_coca.yaml
@@ -203,7 +203,7 @@ model:
n_embd: 768
dropout: 0.0
bias: true
activation: fused_swiglu
activation: swiglu
epsilon: 1e-5
n_pool_head: 8
n_vision_queries: 256
2 changes: 1 addition & 1 deletion config_files/training/config_example_openGPTx_dataset.yaml
@@ -151,7 +151,7 @@ model:
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
activation: fused_swiglu
activation: swiglu
weight_init:
mean: 0.0
std: 0.02
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -13,17 +13,16 @@ dependencies = [
"SentencePiece",
"accelerate",
"rich",
"xformers",
"omegaconf",
"pydantic",
"click",
"click_pathlib",
"jq",
"xformers",
"class_resolver",
"wandb",
"packaging",
"einops>=0.7.0",
"mamba-ssm",
"flash-attn", # install this directly via `pip install flash-attn --no-build-isolation`
"mamba-ssm",
"causal-conv1d>=1.2.0",
7 changes: 3 additions & 4 deletions src/modalities/models/coca/multi_modal_decoder.py
@@ -2,11 +2,10 @@
from typing import Dict

import torch
import xformers.ops as xops
from torch import nn

from modalities.models.gpt2.gpt2_model import ActivationType
from modalities.models.model import NNModel
from modalities.models.model import NNModel, SwiGLU
from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention
from modalities.nn.mlp import MLP

@@ -32,8 +31,8 @@ def __init__(

if activation == ActivationType.GELU:
mlp = partial(MLP, in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout)
elif activation == ActivationType.FUSED_SWIGLU:
mlp = partial(xops.SwiGLU, in_features=n_embd, hidden_features=ffn_hidden, bias=bias)
elif activation == ActivationType.SWIGLU:
mlp = partial(SwiGLU, n_embd=n_embd, bias=bias)
else:
raise NotImplementedError(f"activation type {activation} not implemented")

14 changes: 7 additions & 7 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -1,18 +1,20 @@
import math

from copy import deepcopy
from enum import Enum
from functools import partial
from typing import Annotated, Dict, List, Tuple

import torch
import torch.nn as nn
import xformers.ops as xops
from flash_attn import flash_attn_func
from pydantic import BaseModel, Field, model_validator, validator

from torch.nn import functional as F

from modalities.config.pydanctic_if_types import PydanticPytorchModuleType
from modalities.config.utils import convert_base_model_config_to_dict
from modalities.models.model import NNModel
from modalities.models.model import NNModel, SwiGLU
from modalities.util import parse_enum_by_name

# GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
@@ -106,7 +108,7 @@ class QueryKeyValueTransformType(Enum):

class ActivationType(str, Enum):
GELU = "gelu"
FUSED_SWIGLU = "fused_swiglu"
SWIGLU = "swiglu"


class AttentionConfig(BaseModel):
@@ -293,7 +295,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.dropout(x)
return x


class GPT2Block(nn.Module):
def __init__(
self,
@@ -323,9 +324,8 @@ def __init__(
)
if activation_type == ActivationType.GELU:
self.mlp = TransformerMLP(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias, dropout=dropout)
elif activation_type == ActivationType.FUSED_SWIGLU:
hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)
self.mlp = xops.SwiGLU(n_embd, hidden_dim, n_embd, bias=False)
elif activation_type == ActivationType.SWIGLU:
self.mlp = SwiGLU(n_embd=n_embd, bias=bias)
else:
raise NotImplementedError("unimplemented activation")

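Note that `GPT2Block` no longer computes the SwiGLU hidden dimension inline and no longer hard-codes `bias=False`; the rounding formula now lives in `SwiGLU._get_hidden_dim` (see model.py below), and the bias follows the block's `bias` flag. A small sketch, not part of the diff, checking that the old inline expression and the new helper agree for the `n_embd=768` used in the example configs:

```python
from modalities.models.model import SwiGLU  # import path introduced in this PR

n_embd = 768  # value used in the example configs above
# formula previously inlined in GPT2Block for the xops fused SwiGLU
old_inline_hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)
assert old_inline_hidden_dim == SwiGLU._get_hidden_dim(n_embd) == 2048
```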
37 changes: 37 additions & 0 deletions src/modalities/models/model.py
@@ -21,6 +21,43 @@ def get_parameters(self) -> Dict[str, torch.Tensor]:
return {name: param for name, param in self.named_parameters()}


class SwiGLU(nn.Module):
def __init__(self, n_embd: int, bias: bool):
super().__init__()

hidden_dim = SwiGLU._get_hidden_dim(n_embd)

self.c_fc = nn.Linear(
in_features=n_embd,
out_features=hidden_dim,
bias=bias,
)
self.silu = nn.SiLU()
self.c_proj = nn.Linear(
in_features=n_embd,
out_features=hidden_dim,
bias=bias,
)
self.out_proj = nn.Linear(
in_features=hidden_dim,
out_features=n_embd,
bias=bias,
)

@staticmethod
def _get_hidden_dim(n_embd: int) -> int:
# Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762)
# To ensure that the number of parameters in the SwiGLU module with its additional
# linear layer is equivalent to that of the TransformerMLP, we need to adapt the SwiGLU hidden dimension as follows:
# 2 * (n_embd * hidden_dim) == 3 * (n_embd * 2/3 * hidden_dim)
# Besides, we ensure that hidden_dim is the smallest multiple of
# 256 that is greater than or equal to the provided hidden_dim
return 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.out_proj(self.silu(self.c_fc(x)) * self.c_proj(x))


def model_predict_batch(model: nn.Module, batch: DatasetBatch) -> InferenceResultBatch:
forward_result = model.forward(batch.samples)
result_batch = InferenceResultBatch(targets=batch.targets, predictions=forward_result)
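A worked example, not part of the diff, of the parameter-count argument in the `_get_hidden_dim` comment: assuming, as the comment implies, that `TransformerMLP` uses two `n_embd x 4*n_embd` weight matrices, scaling the SwiGLU hidden dimension by 2/3 (before rounding) makes its three weight matrices hold the same number of weights. Bias terms, which differ slightly between the two modules, are ignored here.

```python
n_embd = 768
swiglu_hidden = 2 * 4 * n_embd // 3             # 2048, already a multiple of 256
transformer_mlp_hidden = 4 * n_embd             # 3072

swiglu_weights = 3 * n_embd * swiglu_hidden                    # c_fc + c_proj + out_proj
transformer_mlp_weights = 2 * n_embd * transformer_mlp_hidden  # two weight matrices
assert swiglu_weights == transformer_mlp_weights == 4_718_592

# For n_embd values where 2/3 * 4 * n_embd is not already a multiple of 256,
# _get_hidden_dim rounds up, e.g. n_embd=512 -> 1365 -> 1536 (asserted in the test below).
```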
32 changes: 32 additions & 0 deletions tests/nn/test_mlp.py
@@ -1,5 +1,7 @@
import torch
from torch import nn

from modalities.models.model import SwiGLU
from modalities.nn.mlp import MLP


@@ -8,3 +10,33 @@ def test_mlp_forward():
dummy_input = torch.randn(1, 10, 64)
out = model(dummy_input)
assert out.shape == (1, 10, 64)


def test_SwiGLU_forward():
n_embd = 512
bias = True
mlp = SwiGLU(n_embd, bias)

hidden_dim = 1536
assert SwiGLU._get_hidden_dim(n_embd) == hidden_dim

n_embd = 511
assert SwiGLU._get_hidden_dim(n_embd) == hidden_dim

n_embd = 512

# batch size x sequence length x embedding dim
input_tensor = torch.randn(1, 1, n_embd)
output_tensor = mlp(input_tensor)
assert output_tensor.shape == (1, 1, n_embd)

c_fc = nn.Linear(in_features=n_embd, out_features=hidden_dim, bias=bias)
c_proj = nn.Linear(in_features=n_embd, out_features=hidden_dim, bias=bias)
out_proj = nn.Linear(in_features=hidden_dim, out_features=n_embd, bias=bias)
silu = nn.SiLU()
mlp.c_fc = c_fc
mlp.c_proj = c_proj
mlp.out_proj = out_proj

output_tensor = mlp(input_tensor)
assert torch.all(output_tensor == out_proj(silu(c_fc(input_tensor)) * c_proj(input_tensor)))
