diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md
new file mode 100644
index 00000000..c1d4a4c7
--- /dev/null
+++ b/CHANGELOG_DEV.md
@@ -0,0 +1,18 @@
+# Changelog
+
+| PR | Type | Ref. Issue(s) | Breaking Changes | PR Description |
+|------------------|------------|---------------|------------------|------------------------------------------------------------------------------------------------|
+| [#154](#pr-154-manual-swiglu-implementation) | Bug Fix | [#14](https://github.com/Modalities/modalities/issues/14) | **Yes** | Towards stable modalities version |
+| | | | | |
+
+
+
+## PR #154 Manual SwiGLU implementation
+
+This [PR](https://github.com/Modalities/modalities/pull/154) adds a manual SwiGLU implementation. The original one from xops was incompatible with activation checkpointing (see issue [#14](https://github.com/Modalities/modalities/issues/14)).
+
+**General changes:**
+* replaces the xops SwiGLU implementation with a custom reimplementation
+
+**Breaking changes:**
+* renames `fused_swiglu` to `swiglu` in `ActivationType` (see [here](https://github.com/Modalities/modalities/pull/154/commits/90fb3bd06a407333423cffeab486711e26ef8ddf) for the respective config changes)
diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml
index 62e6fd42..734d1351 100644
--- a/config_files/training/config_example_coca.yaml
+++ b/config_files/training/config_example_coca.yaml
@@ -203,7 +203,7 @@ model:
     n_embd: 768
     dropout: 0.0
     bias: true
-    activation: fused_swiglu
+    activation: swiglu
     epsilon: 1e-5
     n_pool_head: 8
     n_vision_queries: 256
diff --git a/config_files/training/config_example_openGPTx_dataset.yaml b/config_files/training/config_example_openGPTx_dataset.yaml
index 85114fe4..a0a57114 100644
--- a/config_files/training/config_example_openGPTx_dataset.yaml
+++ b/config_files/training/config_example_openGPTx_dataset.yaml
@@ -151,7 +151,7 @@ model:
     n_embd: 768
     dropout: 0.0
     bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
-    activation: fused_swiglu
+    activation: swiglu
     weight_init:
       mean: 0.0
       std: 0.02
diff --git a/pyproject.toml b/pyproject.toml
index b61cff62..df9289b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,17 +13,16 @@ dependencies = [
     "SentencePiece",
     "accelerate",
     "rich",
-    "xformers",
     "omegaconf",
     "pydantic",
     "click",
     "click_pathlib",
     "jq",
-    "xformers",
     "class_resolver",
     "wandb",
     "packaging",
     "einops>=0.7.0",
+    "mamba-ssm",
     "flash-attn", # install this directly via `pip install flash-attn --no-build-isolation`
     "mamba-ssm",
     "causal-conv1d>=1.2.0",
diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py
index cced19b4..f1c84e05 100644
--- a/src/modalities/models/coca/multi_modal_decoder.py
+++ b/src/modalities/models/coca/multi_modal_decoder.py
@@ -2,11 +2,10 @@
 from typing import Dict

 import torch
-import xformers.ops as xops
 from torch import nn

 from modalities.models.gpt2.gpt2_model import ActivationType
-from modalities.models.model import NNModel
+from modalities.models.model import NNModel, SwiGLU
 from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention
 from modalities.nn.mlp import MLP

@@ -32,8 +31,8 @@ def __init__(
         if activation == ActivationType.GELU:
             mlp = partial(MLP, in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout)
-        elif activation == ActivationType.FUSED_SWIGLU:
-            mlp = partial(xops.SwiGLU, in_features=n_embd, hidden_features=ffn_hidden, bias=bias)
+        elif activation == ActivationType.SWIGLU:
+            mlp = partial(SwiGLU, n_embd=n_embd, bias=bias)
         else:
             raise NotImplementedError(f"activation type {activation} not implemented")

diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py
index a65ab0a6..28dd2308 100644
--- a/src/modalities/models/gpt2/gpt2_model.py
+++ b/src/modalities/models/gpt2/gpt2_model.py
@@ -1,4 +1,5 @@
 import math
+
 from copy import deepcopy
 from enum import Enum
 from functools import partial
@@ -6,13 +7,14 @@
 import torch
 import torch.nn as nn
-import xformers.ops as xops
 from flash_attn import flash_attn_func
 from pydantic import BaseModel, Field, model_validator, validator
+from torch.nn import functional as F
+
 from modalities.config.pydanctic_if_types import PydanticPytorchModuleType
 from modalities.config.utils import convert_base_model_config_to_dict
-from modalities.models.model import NNModel
+from modalities.models.model import NNModel, SwiGLU
 from modalities.util import parse_enum_by_name

 # GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
@@ -106,7 +108,7 @@ class QueryKeyValueTransformType(Enum):

 class ActivationType(str, Enum):
     GELU = "gelu"
-    FUSED_SWIGLU = "fused_swiglu"
+    SWIGLU = "swiglu"


 class AttentionConfig(BaseModel):
@@ -293,7 +295,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.dropout(x)
         return x

-
 class GPT2Block(nn.Module):
     def __init__(
         self,
@@ -323,9 +324,8 @@ def __init__(
         )
         if activation_type == ActivationType.GELU:
             self.mlp = TransformerMLP(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias, dropout=dropout)
-        elif activation_type == ActivationType.FUSED_SWIGLU:
-            hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)
-            self.mlp = xops.SwiGLU(n_embd, hidden_dim, n_embd, bias=False)
+        elif activation_type == ActivationType.SWIGLU:
+            self.mlp = SwiGLU(n_embd=n_embd, bias=bias)
         else:
             raise NotImplementedError("unimplemented activation")

diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py
index dfcb02f6..3a600f71 100644
--- a/src/modalities/models/model.py
+++ b/src/modalities/models/model.py
@@ -21,6 +21,43 @@ def get_parameters(self) -> Dict[str, torch.Tensor]:
         return {name: param for name, param in self.named_parameters()}


+class SwiGLU(nn.Module):
+    def __init__(self, n_embd: int, bias: bool):
+        super().__init__()
+
+        hidden_dim = SwiGLU._get_hidden_dim(n_embd)
+
+        self.c_fc = nn.Linear(
+            in_features=n_embd,
+            out_features=hidden_dim,
+            bias=bias,
+        )
+        self.silu = nn.SiLU()
+        self.c_proj = nn.Linear(
+            in_features=n_embd,
+            out_features=hidden_dim,
+            bias=bias,
+        )
+        self.out_proj = nn.Linear(
+            in_features=hidden_dim,
+            out_features=n_embd,
+            bias=bias,
+        )
+
+    @staticmethod
+    def _get_hidden_dim(n_embd: int) -> int:
+        # Best practice: the Transformer feed-forward hidden dimension is 4 * n_embd (https://arxiv.org/pdf/1706.03762).
+        # To keep the number of parameters in the SwiGLU module, with its additional
+        # linear layer, equivalent to the two-layer TransformerMLP, we shrink its hidden dimension to 2/3:
+        # 2 * (n_embd * hidden_dim) == 3 * (n_embd * 2/3 * hidden_dim)
+        # In addition, we round up to the smallest multiple of 256 that is greater than
+        # or equal to 2/3 * 4 * n_embd.
+        return 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.out_proj(self.silu(self.c_fc(x)) * self.c_proj(x))
+
+
 def model_predict_batch(model: nn.Module, batch: DatasetBatch) -> InferenceResultBatch:
     forward_result = model.forward(batch.samples)
     result_batch = InferenceResultBatch(targets=batch.targets, predictions=forward_result)
diff --git a/tests/nn/test_mlp.py b/tests/nn/test_mlp.py
index 0c94e69b..7bbed65e 100644
--- a/tests/nn/test_mlp.py
+++ b/tests/nn/test_mlp.py
@@ -1,5 +1,7 @@
 import torch
+from torch import nn

+from modalities.models.model import SwiGLU
 from modalities.nn.mlp import MLP


@@ -8,3 +10,33 @@ def test_mlp_forward():
     dummy_input = torch.randn(1, 10, 64)
     out = model(dummy_input)
     assert out.shape == (1, 10, 64)
+
+
+def test_SwiGLU_forward():
+    n_embd = 512
+    bias = True
+    mlp = SwiGLU(n_embd, bias)
+
+    hidden_dim = 1536
+    assert SwiGLU._get_hidden_dim(n_embd) == hidden_dim
+
+    n_embd = 511
+    assert SwiGLU._get_hidden_dim(n_embd) == hidden_dim
+
+    n_embd = 512
+
+    # batch size x sequence length x embedding dim
+    input_tensor = torch.randn(1, 1, n_embd)
+    output_tensor = mlp(input_tensor)
+    assert output_tensor.shape == (1, 1, n_embd)
+
+    c_fc = nn.Linear(in_features=n_embd, out_features=hidden_dim, bias=bias)
+    c_proj = nn.Linear(in_features=n_embd, out_features=hidden_dim, bias=bias)
+    out_proj = nn.Linear(in_features=hidden_dim, out_features=n_embd, bias=bias)
+    silu = nn.SiLU()
+    mlp.c_fc = c_fc
+    mlp.c_proj = c_proj
+    mlp.out_proj = out_proj
+
+    output_tensor = mlp(input_tensor)
+    assert torch.all(output_tensor == out_proj(silu(c_fc(input_tensor)) * c_proj(input_tensor)))
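For reference, here is a minimal sketch (not part of the diff above) that exercises the new `SwiGLU` module and spells out the parameter-count argument from the `_get_hidden_dim` comment. It assumes a Python environment with `torch` and the `modalities` package containing this PR installed; the concrete numbers (`n_embd = 768`, `hidden_dim = 2048`) are only an illustrative example.

```python
import torch

from modalities.models.model import SwiGLU

n_embd = 768
swiglu = SwiGLU(n_embd=n_embd, bias=False)

# hidden_dim = 2/3 * 4 * n_embd, rounded up to the next multiple of 256 -> 2048 for n_embd = 768
hidden_dim = SwiGLU._get_hidden_dim(n_embd)
assert hidden_dim == 2048

# forward pass: out_proj(silu(c_fc(x)) * c_proj(x)) preserves the embedding dimension
x = torch.randn(2, 16, n_embd)  # batch size x sequence length x embedding dim
assert swiglu(x).shape == x.shape

# parameter-count equivalence with the two-layer TransformerMLP (hidden dim 4 * n_embd):
# 3 * (n_embd * 2048) == 2 * (n_embd * 4 * n_embd) == 4,718,592 when bias=False
swiglu_params = sum(p.numel() for p in swiglu.parameters())
assert swiglu_params == 3 * n_embd * hidden_dim == 2 * n_embd * 4 * n_embd
```

With `bias=True`, as `GPT2Block` now passes through when the config sets `bias: true` (the old `xops.SwiGLU` call hard-coded `bias=False`), the two modules differ only by their bias terms.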