From bff009c781fa15a0ca2496bca5731e7eb85a4ccc Mon Sep 17 00:00:00 2001 From: mali-git Date: Wed, 12 Jun 2024 14:19:27 +0200 Subject: [PATCH 01/14] feat: implemnet SwiGLU --- src/modalities/models/gpt2/gpt2_model.py | 32 +++++++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 9fa4063e..998fec35 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -110,7 +110,7 @@ class QueryKeyValueTransformType(Enum): class ActivationType(str, Enum): GELU = "gelu" - FUSED_SWIGLU = "fused_swiglu" + SWIGLU = "swiglu" class AttentionConfig(BaseModel): @@ -297,6 +297,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.dropout(x) return x +class SwiGLU(nn.Module): + def __init__(self, n_embd: int, bias: bool): + super().__init__() + # Best practice: 4 * n_embd + hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256) + + self.c_fc = nn.Linear( + in_features=n_embd, + out_features=hidden_dim, + bias=bias, + ) + self.silu = nn.SiLU() + self.c_proj = nn.Linear( + in_features=n_embd, + out_features=hidden_dim, + bias=bias, + ) + self.out_proj = nn.Linear( + in_features=hidden_dim, + out_features=n_embd, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_proj(self.silu(self.c_fc(x)) * self.c_proj(x)) class GPT2Block(nn.Module): def __init__( @@ -327,9 +352,8 @@ def __init__( ) if activation_type == ActivationType.GELU: self.mlp = TransformerMLP(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias, dropout=dropout) - elif activation_type == ActivationType.FUSED_SWIGLU: - hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256) - self.mlp = xops.SwiGLU(n_embd, hidden_dim, n_embd, bias=False) + elif activation_type == ActivationType.SWIGLU: + self.mlp = SwiGLU(n_embd=n_embd, bias=bias) else: raise NotImplementedError("unimplemented activation") From 4cb62b1a582bee55744ac67cddd466f49a1a9845 Mon Sep 17 00:00:00 2001 From: mali-git Date: Wed, 12 Jun 2024 14:24:07 +0200 Subject: [PATCH 02/14] tests: test SwiGLU --- tests/nn/test_mlp.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/nn/test_mlp.py b/tests/nn/test_mlp.py index 0c94e69b..c95651e6 100644 --- a/tests/nn/test_mlp.py +++ b/tests/nn/test_mlp.py @@ -1,5 +1,5 @@ import torch - +from modalities.models.gpt2.gpt2_model import SwiGLU from modalities.nn.mlp import MLP @@ -8,3 +8,12 @@ def test_mlp_forward(): dummy_input = torch.randn(1, 10, 64) out = model(dummy_input) assert out.shape == (1, 10, 64) + +def test_SwiGLU_forward(): + n_embd = 512 + bias = True + model = SwiGLU(n_embd, bias) + input_tensor = torch.randn(1, n_embd) + output_tensor = model(input_tensor) + assert output_tensor.shape == (1, n_embd) + From dd972a3088b9a30ca15abd702b4f847b13e154cb Mon Sep 17 00:00:00 2001 From: mali-git Date: Wed, 12 Jun 2024 14:26:06 +0200 Subject: [PATCH 03/14] refactor: use custom SwiGLU --- src/modalities/models/coca/multi_modal_decoder.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index d0d92ccb..8db302c3 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -2,14 +2,12 @@ from typing import Dict, List import torch -import xformers.ops as xops from torch import nn -from modalities.models.gpt2.gpt2_model 
import ActivationType +from modalities.models.gpt2.gpt2_model import ActivationType, SwiGLU from modalities.models.model import NNModel from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention from modalities.nn.mlp import MLP -from transformers import PreTrainedTokenizer class TransformerBlock(nn.Module): @@ -33,8 +31,8 @@ def __init__( if activation == ActivationType.GELU: mlp = partial(MLP, in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout) - elif activation == ActivationType.FUSED_SWIGLU: - mlp = partial(xops.SwiGLU, in_features=n_embd, hidden_features=ffn_hidden, bias=bias) + elif activation == ActivationType.SWIGLU: + mlp = partial(SwiGLU(n_embd, bias)) else: raise NotImplementedError(f"activation type {activation} not implemented") From 58aea7ae56b8fbde209c58dd0adf3f340de208b8 Mon Sep 17 00:00:00 2001 From: mali-git Date: Wed, 12 Jun 2024 14:26:30 +0200 Subject: [PATCH 04/14] refactor: update dependencies --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b20ffdac..8de89bb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,17 +13,16 @@ dependencies = [ "SentencePiece", "accelerate", "rich", - "xformers", "omegaconf", "pydantic", "click", "click_pathlib", "jq", - "xformers", "class_resolver", "wandb", "packaging", "einops>=0.7.0", + "mamba-ssm", "flash-attn", # install this directly via `pip install flash-attn --no-build-isolation` ] From 90fb3bd06a407333423cffeab486711e26ef8ddf Mon Sep 17 00:00:00 2001 From: mali-git Date: Wed, 12 Jun 2024 14:28:46 +0200 Subject: [PATCH 05/14] refactor: update configs --- config_files/training/config_example_coca.yaml | 2 +- config_files/training/config_example_openGPTx_dataset.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index fcc886f9..decaa197 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -206,7 +206,7 @@ model: n_embd: 768 dropout: 0.0 bias: true - activation: fused_swiglu + activation: swiglu epsilon: 1e-5 n_pool_head: 8 n_vision_queries: 256 diff --git a/config_files/training/config_example_openGPTx_dataset.yaml b/config_files/training/config_example_openGPTx_dataset.yaml index 85114fe4..a0a57114 100644 --- a/config_files/training/config_example_openGPTx_dataset.yaml +++ b/config_files/training/config_example_openGPTx_dataset.yaml @@ -151,7 +151,7 @@ model: n_embd: 768 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster - activation: fused_swiglu + activation: swiglu weight_init: mean: 0.0 std: 0.02 From 13c6554620526fa1ec6770be66ecc338266306a8 Mon Sep 17 00:00:00 2001 From: mali-git Date: Thu, 13 Jun 2024 21:48:00 +0200 Subject: [PATCH 06/14] refactor: move class out and fix imports --- .../models/coca/multi_modal_decoder.py | 4 +-- src/modalities/models/gpt2/gpt2_model.py | 32 ++----------------- tests/nn/test_mlp.py | 3 +- 3 files changed, 6 insertions(+), 33 deletions(-) diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index 8db302c3..79533c7f 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -4,8 +4,8 @@ import torch from torch import nn -from modalities.models.gpt2.gpt2_model import ActivationType, SwiGLU -from modalities.models.model import NNModel +from modalities.models.gpt2.gpt2_model import ActivationType +from modalities.models.model import NNModel, SwiGLU from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention from modalities.nn.mlp import MLP diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 998fec35..b564ef23 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -1,5 +1,5 @@ import math -import sys + from copy import deepcopy from enum import Enum from functools import partial @@ -7,15 +7,13 @@ import torch import torch.nn as nn -import xformers.ops as xops from flash_attn import flash_attn_func from pydantic import BaseModel, Field, model_validator, validator from torch.nn import functional as F -from transformers import PreTrainedTokenizer from modalities.config.pydanctic_if_types import PydanticPytorchModuleType from modalities.config.utils import convert_base_model_config_to_dict -from modalities.models.model import NNModel +from modalities.models.model import NNModel, SwiGLU from modalities.util import parse_enum_by_name @@ -297,32 +295,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.dropout(x) return x -class SwiGLU(nn.Module): - def __init__(self, n_embd: int, bias: bool): - super().__init__() - # Best practice: 4 * n_embd - hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256) - - self.c_fc = nn.Linear( - in_features=n_embd, - out_features=hidden_dim, - bias=bias, - ) - self.silu = nn.SiLU() - self.c_proj = nn.Linear( - in_features=n_embd, - out_features=hidden_dim, - bias=bias, - ) - self.out_proj = nn.Linear( - in_features=hidden_dim, - out_features=n_embd, - bias=bias, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.out_proj(self.silu(self.c_fc(x)) * self.c_proj(x)) - class GPT2Block(nn.Module): def __init__( self, diff --git a/tests/nn/test_mlp.py b/tests/nn/test_mlp.py index c95651e6..34ddf349 100644 --- a/tests/nn/test_mlp.py +++ b/tests/nn/test_mlp.py @@ -1,5 +1,6 @@ import torch -from modalities.models.gpt2.gpt2_model import SwiGLU + +from modalities.models.model import SwiGLU from modalities.nn.mlp import MLP From 0b18b12770d84f457779d38ae30a8aec562a8113 Mon Sep 17 00:00:00 2001 From: mali-git Date: Thu, 13 Jun 2024 21:48:28 +0200 Subject: [PATCH 07/14] chore: move class and add comment --- src/modalities/models/model.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index a20cfd47..0b0db490 100644 --- 
a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -21,7 +21,35 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: def get_parameters(self) -> Dict[str, torch.Tensor]: return {name: param for name, param in self.named_parameters()} - +class SwiGLU(nn.Module): + def __init__(self, n_embd: int, bias: bool): + super().__init__() + # Best practice: 4 * n_embd + # Because we add an additional linear layer, we need to adjust the hidden_dim to 2/3 of the original value + # which is equivalent to the number of parameters in TransformerMLP, i.e. + # 2 * (n_embd * hidden_dim) == 3 * (n_embd * 2/3 * hidden_dim) + # Besides, we ensure that hidden_dim is the smallest multiple of 256 that is greater than or equal the provided hidden_dim + hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256) + + self.c_fc = nn.Linear( + in_features=n_embd, + out_features=hidden_dim, + bias=bias, + ) + self.silu = nn.SiLU() + self.c_proj = nn.Linear( + in_features=n_embd, + out_features=hidden_dim, + bias=bias, + ) + self.out_proj = nn.Linear( + in_features=hidden_dim, + out_features=n_embd, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_proj(self.silu(self.c_fc(x)) * self.c_proj(x)) def model_predict_batch(model: nn.Module, batch: DatasetBatch) -> InferenceResultBatch: forward_result = model.forward(batch.samples) From 23edda1c489747a84c4e3759aa4ff3d4641ee508 Mon Sep 17 00:00:00 2001 From: Mehdi Ali <33023925+mali-git@users.noreply.github.com> Date: Fri, 14 Jun 2024 12:08:10 +0200 Subject: [PATCH 08/14] Update src/modalities/models/coca/multi_modal_decoder.py Co-authored-by: Julian Spravil <979130+spravil@users.noreply.github.com> --- src/modalities/models/coca/multi_modal_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index 79533c7f..71198d3b 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -32,7 +32,7 @@ def __init__( if activation == ActivationType.GELU: mlp = partial(MLP, in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout) elif activation == ActivationType.SWIGLU: - mlp = partial(SwiGLU(n_embd, bias)) + mlp = partial(SwiGLU, n_embd=n_embd, bias=bias) else: raise NotImplementedError(f"activation type {activation} not implemented") From b39265cca90d7e1b5d372ed165af40e57bd8fe70 Mon Sep 17 00:00:00 2001 From: mali-git Date: Sun, 16 Jun 2024 23:24:28 +0200 Subject: [PATCH 09/14] chore: add reference --- src/modalities/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index 0b0db490..f83ae553 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -24,7 +24,7 @@ def get_parameters(self) -> Dict[str, torch.Tensor]: class SwiGLU(nn.Module): def __init__(self, n_embd: int, bias: bool): super().__init__() - # Best practice: 4 * n_embd + # Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762) # Because we add an additional linear layer, we need to adjust the hidden_dim to 2/3 of the original value # which is equivalent to the number of parameters in TransformerMLP, i.e. 
# 2 * (n_embd * hidden_dim) == 3 * (n_embd * 2/3 * hidden_dim) From bd2e32b0db243fffd535dc6f543fd385ca45ec5e Mon Sep 17 00:00:00 2001 From: mali-git Date: Mon, 17 Jun 2024 00:29:54 +0200 Subject: [PATCH 10/14] refactor: move hidden_dim computation to own fct. --- src/modalities/models/model.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index f83ae553..ac2521ef 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -24,12 +24,8 @@ def get_parameters(self) -> Dict[str, torch.Tensor]: class SwiGLU(nn.Module): def __init__(self, n_embd: int, bias: bool): super().__init__() - # Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762) - # Because we add an additional linear layer, we need to adjust the hidden_dim to 2/3 of the original value - # which is equivalent to the number of parameters in TransformerMLP, i.e. - # 2 * (n_embd * hidden_dim) == 3 * (n_embd * 2/3 * hidden_dim) - # Besides, we ensure that hidden_dim is the smallest multiple of 256 that is greater than or equal the provided hidden_dim - hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256) + + hidden_dim = self._get_hidden_dim(n_embd) self.c_fc = nn.Linear( in_features=n_embd, @@ -48,6 +44,14 @@ def __init__(self, n_embd: int, bias: bool): bias=bias, ) + def _get_hidden_dim(self, n_embd: int) -> int: + # Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762) + # Because we add an additional linear layer, we need to adjust the hidden_dim to 2/3 of the original value + # which is equivalent to the number of parameters in TransformerMLP, i.e. + # 2 * (n_embd * hidden_dim) == 3 * (n_embd * 2/3 * hidden_dim) + # Besides, we ensure that hidden_dim is the smallest multiple of 256 that is greater than or equal the provided hidden_dim + return 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256) + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.out_proj(self.silu(self.c_fc(x)) * self.c_proj(x)) From 6a65cc974acdd6cfd55a4c671c9903cfe2c32b6b Mon Sep 17 00:00:00 2001 From: mali-git Date: Mon, 17 Jun 2024 00:30:13 +0200 Subject: [PATCH 11/14] test: finalize SwiGLU test --- tests/nn/test_mlp.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/tests/nn/test_mlp.py b/tests/nn/test_mlp.py index 34ddf349..a5f1fa1d 100644 --- a/tests/nn/test_mlp.py +++ b/tests/nn/test_mlp.py @@ -1,5 +1,7 @@ import torch +from torch import nn + from modalities.models.model import SwiGLU from modalities.nn.mlp import MLP @@ -13,8 +15,30 @@ def test_mlp_forward(): def test_SwiGLU_forward(): n_embd = 512 bias = True - model = SwiGLU(n_embd, bias) - input_tensor = torch.randn(1, n_embd) - output_tensor = model(input_tensor) - assert output_tensor.shape == (1, n_embd) + mlp = SwiGLU(n_embd, bias) + + hidden_dim = 1536 + assert mlp._get_hidden_dim(n_embd) == hidden_dim + + n_embd = 511 + assert mlp._get_hidden_dim(n_embd) == hidden_dim + + n_embd = 512 + + # batch size x sequence length x embedding dim + input_tensor = torch.randn(1, 1, n_embd) + output_tensor = mlp(input_tensor) + assert output_tensor.shape == (1, 1, n_embd) + + + c_fc = nn.Linear(in_features=n_embd, out_features=hidden_dim, bias=bias) + c_proj = nn.Linear(in_features=n_embd,out_features=hidden_dim,bias=bias) + out_proj = nn.Linear(in_features=hidden_dim, out_features=n_embd,bias=bias) + silu = nn.SiLU() + mlp.c_fc = c_fc + mlp.c_proj = c_proj + mlp.out_proj = out_proj + + output_tensor 
= mlp(input_tensor) + assert torch.all(output_tensor == out_proj(silu(c_fc(input_tensor)) * c_proj(input_tensor))) From 4e579b6026b4b0d577892bf68d946af1f8aa545e Mon Sep 17 00:00:00 2001 From: fromm-m Date: Wed, 19 Jun 2024 10:07:23 +0000 Subject: [PATCH 12/14] refactor: made _get_hidden_dim for SwiGLU an static method --- src/modalities/models/model.py | 18 +++++++++++------- tests/nn/test_mlp.py | 16 +++++++--------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index 976e683a..3a600f71 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -20,15 +20,16 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: def get_parameters(self) -> Dict[str, torch.Tensor]: return {name: param for name, param in self.named_parameters()} + class SwiGLU(nn.Module): def __init__(self, n_embd: int, bias: bool): super().__init__() - hidden_dim = self._get_hidden_dim(n_embd) + hidden_dim = SwiGLU._get_hidden_dim(n_embd) self.c_fc = nn.Linear( in_features=n_embd, - out_features=hidden_dim, + out_features=hidden_dim, bias=bias, ) self.silu = nn.SiLU() @@ -43,17 +44,20 @@ def __init__(self, n_embd: int, bias: bool): bias=bias, ) - def _get_hidden_dim(self, n_embd: int) -> int: + @staticmethod + def _get_hidden_dim(n_embd: int) -> int: # Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762) - # Because we add an additional linear layer, we need to adjust the hidden_dim to 2/3 of the original value - # which is equivalent to the number of parameters in TransformerMLP, i.e. + # To ensure that the number of parameters in the SwiGLU module with its additional + # linear layer are equivalent to the TransformerMLP, we need to adapt the SwiGLU hidden dimension as follows: # 2 * (n_embd * hidden_dim) == 3 * (n_embd * 2/3 * hidden_dim) - # Besides, we ensure that hidden_dim is the smallest multiple of 256 that is greater than or equal the provided hidden_dim + # Besides, we ensure that hidden_dim is the smallest multiple of + # 256 that is greater than or equal the provided hidden_dim return 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256) - + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.out_proj(self.silu(self.c_fc(x)) * self.c_proj(x)) + def model_predict_batch(model: nn.Module, batch: DatasetBatch) -> InferenceResultBatch: forward_result = model.forward(batch.samples) result_batch = InferenceResultBatch(targets=batch.targets, predictions=forward_result) diff --git a/tests/nn/test_mlp.py b/tests/nn/test_mlp.py index a5f1fa1d..7bbed65e 100644 --- a/tests/nn/test_mlp.py +++ b/tests/nn/test_mlp.py @@ -1,5 +1,4 @@ import torch - from torch import nn from modalities.models.model import SwiGLU @@ -12,17 +11,18 @@ def test_mlp_forward(): out = model(dummy_input) assert out.shape == (1, 10, 64) + def test_SwiGLU_forward(): n_embd = 512 bias = True mlp = SwiGLU(n_embd, bias) hidden_dim = 1536 - assert mlp._get_hidden_dim(n_embd) == hidden_dim - + assert SwiGLU._get_hidden_dim(n_embd) == hidden_dim + n_embd = 511 - assert mlp._get_hidden_dim(n_embd) == hidden_dim - + assert SwiGLU._get_hidden_dim(n_embd) == hidden_dim + n_embd = 512 # batch size x sequence length x embedding dim @@ -30,10 +30,9 @@ def test_SwiGLU_forward(): output_tensor = mlp(input_tensor) assert output_tensor.shape == (1, 1, n_embd) - c_fc = nn.Linear(in_features=n_embd, out_features=hidden_dim, bias=bias) - c_proj = nn.Linear(in_features=n_embd,out_features=hidden_dim,bias=bias) - out_proj = 
nn.Linear(in_features=hidden_dim, out_features=n_embd,bias=bias) + c_proj = nn.Linear(in_features=n_embd, out_features=hidden_dim, bias=bias) + out_proj = nn.Linear(in_features=hidden_dim, out_features=n_embd, bias=bias) silu = nn.SiLU() mlp.c_fc = c_fc mlp.c_proj = c_proj @@ -41,4 +40,3 @@ def test_SwiGLU_forward(): output_tensor = mlp(input_tensor) assert torch.all(output_tensor == out_proj(silu(c_fc(input_tensor)) * c_proj(input_tensor))) - From 4114a0cd0e6941cdca31b79bb4db32652bc108ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20L=C3=BCbbering?= <2804731+le1nux@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:21:27 +0200 Subject: [PATCH 13/14] feat: Create CHANGELOG_DEV.md --- CHANGELOG_DEV.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 CHANGELOG_DEV.md diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md new file mode 100644 index 00000000..7711ee1e --- /dev/null +++ b/CHANGELOG_DEV.md @@ -0,0 +1,25 @@ +# Changelog + +| PR | Type | Ref. Issue(s) | Breaking Changes |PR Description| +|------------------|------------|---------------|------------------|------------------------------------------------------------------------------------------------| +| [#141](#pr-141-towards-stable-modalities-version) | Bug Fix | [#129](https://github.com/Modalities/modalities/issues/129) | **Yes** | Towards stable modalities version | +| | | | | | + + + +## PR #141 Towards stable modalities version + +This PR further stabilise the codebase and makes training more robust also w.r.t. loss spikes, which we fixed via scaled weight initialisation and an increased batch size in our experiments. +The PR also fixes all failing tests and adds a simple entrypoint for running cpu, single-gpu and multi-gpu tests. The PR contains multiple sub PRs. + +**General changes:** +* Bug fix: the model evaluation mode is now properly deactivated after evaluation (see PR [#131](https://github.com/Modalities/modalities/pull/131)) +* Bug fix: Fixed the implementation of Pre-LN for GPT2 model (see PR [#136](https://github.com/Modalities/modalities/pull/136)) +* Enhancement: Further mixed precision strategies; also added one matching MegatronLM's. +* Enhancement: Single, unified entrypoint for running cpu, single-gpu and multi-gpu tests. All tests fixed. (PR [#155](https://github.com/Modalities/modalities/pull/155)) + +**Breaking changes:** +* Enhancement: Logging is now always based on #training steps and #consumed tokens (PR [#137](https://github.com/Modalities/modalities/pull/137)) + This change is a breaking change and the experiment configs need to adapated as shown [here](https://github.com/Modalities/modalities/pull/137/files#diff-2bea5a6678ec91ea603cc2e80d17847360af5e9f7624c8e710f329ee1eb9b4f4). +* Enhancement: The model parameters are now grouped within the respective model. The optimizer can leverage these groups to e.g., only apply weight decay to non-layer-norm weights. See [here](https://github.com/Modalities/modalities/pull/139/files#diff-2bea5a6678ec91ea603cc2e80d17847360af5e9f7624c8e710f329ee1eb9b4f4) for the necessary config changes. (PR [#139](https://github.com/Modalities/modalities/pull/139)) +* Enhancement: We support now different attention implementations (manual, pytorch flash, DAO flash) See [here](https://github.com/Modalities/modalities/pull/138/files#diff-2bea5a6678ec91ea603cc2e80d17847360af5e9f7624c8e710f329ee1eb9b4f4) for the respective config changes. 
(PR [#138](https://github.com/Modalities/modalities/pull/138))

From 333fcc0477838edfbc9ee77f10785157f8bf3202 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Max=20L=C3=BCbbering?= <2804731+le1nux@users.noreply.github.com>
Date: Wed, 19 Jun 2024 12:23:46 +0200
Subject: [PATCH 14/14] Update CHANGELOG_DEV.md

---
 CHANGELOG_DEV.md | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md
index 7711ee1e..c1d4a4c7 100644
--- a/CHANGELOG_DEV.md
+++ b/CHANGELOG_DEV.md
@@ -2,24 +2,17 @@

 | PR | Type | Ref. Issue(s) | Breaking Changes |PR Description|
 |------------------|------------|---------------|------------------|------------------------------------------------------------------------------------------------|
-| [#141](#pr-141-towards-stable-modalities-version) | Bug Fix | [#129](https://github.com/Modalities/modalities/issues/129) | **Yes** | Towards stable modalities version |
+| [#154](#pr-154-manual-swiglu-implementation) | Bug Fix | [#14](https://github.com/Modalities/modalities/issues/14) | **Yes** | Manual SwiGLU implementation |
 | | | | | |



-## PR #141 Towards stable modalities version
+## PR #154 Manual SwiGLU implementation

-This PR further stabilise the codebase and makes training more robust also w.r.t. loss spikes, which we fixed via scaled weight initialisation and an increased batch size in our experiments.
-The PR also fixes all failing tests and adds a simple entrypoint for running cpu, single-gpu and multi-gpu tests. The PR contains multiple sub PRs.
+This [PR](https://github.com/Modalities/modalities/pull/154) adds a manual SwiGLU implementation. The original one from xops was incompatible with activation checkpointing (see issue [#14](https://github.com/Modalities/modalities/issues/14)).

 **General changes:**
-* Bug fix: the model evaluation mode is now properly deactivated after evaluation (see PR [#131](https://github.com/Modalities/modalities/pull/131))
-* Bug fix: Fixed the implementation of Pre-LN for GPT2 model (see PR [#136](https://github.com/Modalities/modalities/pull/136))
-* Enhancement: Further mixed precision strategies; also added one matching MegatronLM's.
-* Enhancement: Single, unified entrypoint for running cpu, single-gpu and multi-gpu tests. All tests fixed. (PR [#155](https://github.com/Modalities/modalities/pull/155))
+* replaces the xops SwiGLU implementation with a custom reimplementation

 **Breaking changes:**
-* Enhancement: Logging is now always based on #training steps and #consumed tokens (PR [#137](https://github.com/Modalities/modalities/pull/137))
-  This change is a breaking change and the experiment configs need to adapated as shown [here](https://github.com/Modalities/modalities/pull/137/files#diff-2bea5a6678ec91ea603cc2e80d17847360af5e9f7624c8e710f329ee1eb9b4f4).
-* Enhancement: The model parameters are now grouped within the respective model. The optimizer can leverage these groups to e.g., only apply weight decay to non-layer-norm weights. See [here](https://github.com/Modalities/modalities/pull/139/files#diff-2bea5a6678ec91ea603cc2e80d17847360af5e9f7624c8e710f329ee1eb9b4f4) for the necessary config changes. (PR [#139](https://github.com/Modalities/modalities/pull/139))
-* Enhancement: We support now different attention implementations (manual, pytorch flash, DAO flash) See [here](https://github.com/Modalities/modalities/pull/138/files#diff-2bea5a6678ec91ea603cc2e80d17847360af5e9f7624c8e710f329ee1eb9b4f4) for the respective config changes. (PR [#138](https://github.com/Modalities/modalities/pull/138))
+* renaming of `fused_swiglu` to `swiglu` in `ActivationType` (see [here](https://github.com/Modalities/modalities/pull/154/commits/90fb3bd06a407333423cffeab486711e26ef8ddf) for the respective config changes)
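For reference, the sketch below condenses the end state of this series (PATCH 12/14) into a single runnable script: the `SwiGLU` module from `src/modalities/models/model.py` with its static `_get_hidden_dim` helper, plus a small shape check mirroring `tests/nn/test_mlp.py`. The `__main__` block and the example tensor shapes are illustrative additions for this summary, not part of the patches themselves.

```python
import torch
from torch import nn


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: out_proj(silu(c_fc(x)) * c_proj(x))."""

    def __init__(self, n_embd: int, bias: bool):
        super().__init__()
        hidden_dim = SwiGLU._get_hidden_dim(n_embd)

        self.c_fc = nn.Linear(in_features=n_embd, out_features=hidden_dim, bias=bias)
        self.silu = nn.SiLU()
        self.c_proj = nn.Linear(in_features=n_embd, out_features=hidden_dim, bias=bias)
        self.out_proj = nn.Linear(in_features=hidden_dim, out_features=n_embd, bias=bias)

    @staticmethod
    def _get_hidden_dim(n_embd: int) -> int:
        # Start from the usual 4 * n_embd feed-forward width, scale it by 2/3 so that the
        # three weight matrices of SwiGLU roughly match the parameter count of the
        # two-matrix TransformerMLP, and round up to a multiple of 256.
        return 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(self.silu(self.c_fc(x)) * self.c_proj(x))


if __name__ == "__main__":
    n_embd = 512
    # int(2 * 4 * 512 / 3) = 1365, rounded up to a multiple of 256 -> 1536
    assert SwiGLU._get_hidden_dim(n_embd) == 1536

    mlp = SwiGLU(n_embd=n_embd, bias=True)
    x = torch.randn(2, 16, n_embd)  # batch size x sequence length x embedding dim
    y = mlp(x)
    assert y.shape == (2, 16, n_embd)  # the block preserves the embedding dimension
    print("SwiGLU forward pass OK:", tuple(y.shape))
```

Because the gating is built from plain `nn.Linear` layers and `nn.SiLU`, the module drops the xformers dependency whose fused SwiGLU was incompatible with activation checkpointing (issue #14).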