Merge pull request #67 from Modalities/rms_norm
RMS norm implementation
le1nux committed Mar 13, 2024
2 parents 14f4f2e + bedb564 commit 4f509cc
Showing 9 changed files with 171 additions and 60 deletions.
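For reference, the RMS normalization introduced here (following the paper linked in the new module) rescales each feature vector by the reciprocal of its root mean square and then applies a learnable gain and an optional additive bias, roughly

y = gain * x / sqrt(mean(x^2) + epsilon) + bias

which is what the new RMSLayerNorm below computes in its _norm and forward methods.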
37 changes: 32 additions & 5 deletions config_files/config_example_mem_map_dataset.yaml
@@ -31,7 +31,7 @@ train_dataset:
component_key: dataset
variant_key: packed_mem_map_dataset_megatron
config:
raw_data_path: /raid/s3/opengptx/max_lue/LLMgym/data/redpyjama_v2_default_DE_num_docs_16777216.pbin
raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_1050391.pbin
block_size: ${settings.training.sequence_length}
sample_key: ${settings.referencing_keys.sample_key}

@@ -136,21 +136,48 @@ model:
sample_key: ${settings.referencing_keys.sample_key}
prediction_key: ${settings.referencing_keys.prediction_key}
block_size: ${settings.training.sequence_length}
poe_type: NOPE
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention:
bias: true # True: bias in Linears, like GPT-2. False: a bit better and faster
attention_config:
attention_type: pytorch_flash_attention
scaling_factor: 3
activation: gelu
epsilon: 1e-5
qkv_transforms:
- type_hint: RotaryTransform
config:
n_embd: ${model.config.n_embd}
n_head: ${model.config.n_head}
seq_length_dim: -2
activation_type: gelu
weight_init:
mean: 0.0
std: 0.02
attention_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
bias: true
epsilon: 1e-5
ffn_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
bias: true
epsilon: 1e-5
lm_head_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
bias: true
epsilon: 1e-5

wrapped_model:
component_key: model
14 changes: 7 additions & 7 deletions src/modalities/config/config.py
@@ -52,7 +52,7 @@ def __get_pydantic_core_schema__(
PydanticCheckpointingExecutionIFType = Annotated[
CheckpointingExecutionIF, PydanticThirdPartyTypeIF(CheckpointingExecutionIF)
]
PydanticModelIFType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)]
PydanticPytorchModuleType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)]
PydanticTokenizerIFType = Annotated[PreTrainedTokenizerFast, PydanticThirdPartyTypeIF(PreTrainedTokenizerFast)]
PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)]
PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)]
@@ -134,24 +134,24 @@ class CheckpointingConfig(BaseModel):

class AdamWOptimizerConfig(BaseModel):
lr: float
wrapped_model: PydanticModelIFType
wrapped_model: PydanticPytorchModuleType


class CheckpointedOptimizerConfig(BaseModel):
checkpointing: PydanticCheckpointingIFType
checkpoint_path: Path
wrapped_model: PydanticModelIFType
wrapped_model: PydanticPytorchModuleType
optimizer: PydanticOptimizerIFType


class CheckpointedModelConfig(BaseModel):
checkpointing: PydanticCheckpointingIFType
checkpoint_path: Path
model: PydanticModelIFType
model: PydanticPytorchModuleType


class FSDPWrappedModelConfig(BaseModel):
model: PydanticModelIFType
model: PydanticPytorchModuleType
sync_module_states: bool
mixed_precision_settings: MixedPrecisionSettings
sharding_strategy: ShardingStrategy
@@ -303,7 +303,7 @@ class Paths(BaseModel):


class ComponentsModel(BaseModel):
wrapped_model: PydanticModelIFType
wrapped_model: PydanticPytorchModuleType
optimizer: PydanticOptimizerIFType
loss_fn: PydanticLossIFType
train_dataloader: PydanticLLMDataLoaderIFType
@@ -315,7 +315,7 @@ class ComponentsModel(BaseModel):


class ComponentsInferenceModel(BaseModel):
wrapped_model: PydanticModelIFType
wrapped_model: PydanticPytorchModuleType
cuda_env: CudaEnv


Empty file.
60 changes: 60 additions & 0 deletions src/modalities/models/components/layer_norms.py
@@ -0,0 +1,60 @@
from typing import Annotated

import torch
import torch.nn as nn
from pydantic import BaseModel, Field

from modalities.config.lookup_enum import LookupEnum


class RMSLayerNorm(nn.Module):
    def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-5):
        """
        Initialize the RMSNorm normalization layer.
        Original paper: https://arxiv.org/pdf/1910.07467.pdf
        Source code adapted from https://github.com/facebookresearch/llama/blob/a0a4da8b497c566403941ceec47c2512ecf9dd20/llama/model.py#L34C1-L77C36
        Args:
            ndim (int): The dimension of the input tensor.
            bias (bool, optional): If True, the layer will learn an additive bias. Default is True.
            epsilon (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
        """
        super().__init__()
        self.epsilon = epsilon
        self.gain = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias_tensor = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias_tensor = None

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.epsilon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        if self.bias_tensor is None:
            return output * self.gain
        else:
            return output * self.gain + self.bias_tensor


class LayerNorms(LookupEnum):
    """
    An enumeration of the different layer normalization techniques.
    """

    RMSNorm = RMSLayerNorm
    LayerNorm = nn.LayerNorm


class LayerNormConfig(BaseModel):
    normalized_shape: Annotated[int, Field(strict=True, ge=1)]
    eps: Annotated[float, Field(strict=True, gt=0, default=1e-6)]
    elementwise_affine: Annotated[bool, Field(strict=True, default=True)]
    bias: Annotated[bool, Field(strict=True, default=True)]


class RMSLayerNormConfig(BaseModel):
    ndim: Annotated[int, Field(strict=True, ge=1)]
    epsilon: Annotated[float, Field(gt=0, default=1e-6)]
    bias: Annotated[bool, Field(strict=True, default=True)]
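A minimal usage sketch of the new RMSLayerNorm; the tensor shapes and tolerance are illustrative assumptions, not values taken from the repository:

import torch
from modalities.models.components.layer_norms import RMSLayerNorm

# Illustrative hidden states of shape (batch, sequence, n_embd).
x = torch.randn(4, 16, 768)
norm = RMSLayerNorm(ndim=768, bias=True, epsilon=1e-5)
y = norm(x)

assert y.shape == x.shape
# With the default gain (ones) and bias (zeros), the output equals the input
# rescaled by the reciprocal of its per-token root mean square.
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
assert torch.allclose(y, expected, atol=1e-6)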
90 changes: 42 additions & 48 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -1,4 +1,5 @@
import math
from copy import deepcopy
from enum import Enum
from functools import partial
from typing import Annotated, Dict, List, Tuple
@@ -9,6 +10,7 @@
from pydantic import BaseModel, Field, model_validator, validator
from torch.nn import functional as F

from modalities.config.config import PydanticPytorchModuleType
from modalities.config.utils import convert_base_model_config_to_dict
from modalities.models.model import NNModel
from modalities.util import parse_enum_by_name
@@ -152,12 +154,13 @@ class GPT2LLMConfig(BaseModel):
n_head: Annotated[int, Field(strict=True, ge=1)]
n_embd: Annotated[int, Field(strict=True, ge=1)]
ffn_hidden: Annotated[int, Field(strict=True, ge=1)]

dropout: Annotated[float, Field(strict=True, ge=0.0)]
bias: bool # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention: AttentionConfig
activation: ActivationType
epsilon: Annotated[float, Field(strict=True, ge=0.0)]
bias: bool # True: bias in Linears like GPT-2. False: a bit better and faster
attention_config: AttentionConfig
activation_type: ActivationType
attention_norm: PydanticPytorchModuleType
ffn_norm: PydanticPytorchModuleType
lm_head_norm: PydanticPytorchModuleType
weight_init: WeightInitializationConfig

@model_validator(mode="after")
@@ -171,31 +174,12 @@ def validate_sizes(self) -> "GPT2LLMConfig":
return self


class LayerNorm(nn.Module):
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""

def __init__(self, ndim: int, bias: bool, epsilon: float):
super().__init__()
self.weight = nn.Parameter(torch.ones(ndim))
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
self.epsilon = epsilon

def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.layer_norm(
input=input,
normalized_shape=self.weight.shape,
weight=self.weight,
bias=self.bias,
eps=self.epsilon,
)


class CausalSelfAttention(nn.Module):
def __init__(
self,
n_head: int,
n_embd: int,
attention: AttentionConfig,
attention_config: AttentionConfig,
bias: bool,
dropout: float,
block_size: int,
@@ -223,12 +207,12 @@ def __init__(
self.n_head = n_head
self.n_embd = n_embd
self.dropout = dropout
self.flash = attention.attention_type == AttentionType.PYTORCH_FLASH_ATTENTION
self.flash = attention_config.attention_type == AttentionType.PYTORCH_FLASH_ATTENTION

# TODO: inject QKVTransforms from outside
self.qkv_transforms = nn.ModuleList(
transform_config.type_hint.value(**convert_base_model_config_to_dict(transform_config.config))
for transform_config in attention.qkv_transforms
for transform_config in attention_config.qkv_transforms
)

if not self.flash:
@@ -281,7 +265,7 @@ def __init__(self, n_embd: int, ffn_hidden: int, bias: bool, dropout: float):
super().__init__()
self.c_fc = nn.Linear(
in_features=n_embd,
out_features=ffn_hidden, # 4 * n_embd,
out_features=ffn_hidden, # best practice: 4 * n_embd,
bias=bias,
)
self.gelu = nn.GELU()
@@ -305,32 +289,39 @@ def __init__(
self,
n_embd: int,
bias: bool,
epsilon: float,
activation: ActivationType,
activation_type: ActivationType,
n_head: int,
attention: AttentionConfig,
attention_config: AttentionConfig,
dropout: float,
block_size: int,
ffn_hidden: int,
attention_norm: nn.Module,
ffn_norm: nn.Module,
):
super().__init__()
self.ln_1 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon)
self.attention_norm = attention_norm
self.ffn_norm = ffn_norm
self.attn = CausalSelfAttention(
n_head=n_head, n_embd=n_embd, attention=attention, bias=bias, dropout=dropout, block_size=block_size
n_head=n_head,
n_embd=n_embd,
attention_config=attention_config,
bias=bias,
dropout=dropout,
block_size=block_size,
)
self.ln_2 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon)

if activation == ActivationType.GELU:
if activation_type == ActivationType.GELU:
self.mlp = TransformerMLP(n_embd=n_embd, ffn_hidden=ffn_hidden, bias=bias, dropout=dropout)
elif activation == ActivationType.FUSED_SWIGLU:
elif activation_type == ActivationType.FUSED_SWIGLU:
hidden_dim = 256 * ((int(2 * 4 * n_embd / 3) + 256 - 1) // 256)
self.mlp = xops.SwiGLU(n_embd, hidden_dim, n_embd, bias=False)
else:
raise Exception("unimplemented activation")
raise NotImplementedError("unimplemented activation")

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
x = self.attention_norm(x)
x = x + self.attn(x)
x = self.ffn_norm(x)
x = x + self.mlp(x)
return x


@@ -348,10 +339,12 @@ def __init__(
ffn_hidden: int,
dropout: float,
bias: bool,
attention: AttentionConfig,
activation: ActivationType,
epsilon: float,
attention_config: AttentionConfig,
activation_type: ActivationType,
weight_init: WeightInitializationConfig,
attention_norm: nn.Module,
ffn_norm: nn.Module,
lm_head_norm: nn.Module,
):
super().__init__()
self.sample_key = sample_key
@@ -374,7 +367,7 @@ def __init__(
raise TypeError(f"{poe_type} not supported")

if poe_type is not PositionTypes.NOPE and RotaryTransform in [
config.type_hint.value for config in attention.qkv_transforms
config.type_hint.value for config in attention_config.qkv_transforms
]:
raise ValueError('It is expected to use "RotaryTransform" together with "NOPE".')

@@ -388,18 +381,19 @@
GPT2Block(
n_embd=n_embd,
bias=bias,
epsilon=epsilon,
activation=activation,
activation_type=activation_type,
n_head=n_head,
attention=attention,
attention_config=attention_config,
dropout=dropout,
block_size=block_size,
ffn_hidden=ffn_hidden,
attention_norm=deepcopy(attention_norm),
ffn_norm=deepcopy(ffn_norm),
)
for _ in range(n_layer)
]
),
ln_f=LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon),
ln_f=lm_head_norm,
)
)
self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size, bias=False)
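One design detail from the block stack above: the injected norm modules are deep-copied per layer, so every GPT2Block owns its own gain/bias parameters instead of sharing a single module instance. A small sketch of the effect, using illustrative values:

from copy import deepcopy
from modalities.models.components.layer_norms import RMSLayerNorm

# A single prototype norm is supplied via the config; each block receives
# a deep copy so the learnable parameters are not shared across layers.
prototype = RMSLayerNorm(ndim=768)
norms = [deepcopy(prototype) for _ in range(2)]

norms[0].gain.data.fill_(2.0)
print(norms[1].gain[:3])  # still ones -> the copies do not share parameters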
5 changes: 5 additions & 0 deletions src/modalities/registry/components.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Type

import torch.nn as nn
from pydantic import BaseModel
from torch.utils.data import BatchSampler, DistributedSampler
from transformers import GPT2TokenizerFast
@@ -43,6 +44,7 @@
ResultsSubscriberFactory,
)
from modalities.loss_functions import CLMCrossEntropyLoss
from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig
from modalities.models.gpt2.collator import GPT2LLMCollateFn
from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2LLMConfig
from modalities.models.huggingface.huggingface_models import (
@@ -154,4 +156,7 @@ class ComponentEntity:
ResultsSubscriberFactory.get_wandb_result_subscriber,
WandBEvaluationResultSubscriberConfig,
),
# layer norms
ComponentEntity("layer_norm", "rms_norm", RMSLayerNorm, RMSLayerNormConfig),
ComponentEntity("layer_norm", "layer_norm", nn.LayerNorm, LayerNormConfig),
]
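A hedged sketch of what the two new registry entries amount to when a block such as attention_norm in the example YAML is resolved; the actual builder machinery is omitted, pydantic v2 is assumed, and the values simply mirror the example config:

import torch.nn as nn
from modalities.models.components.layer_norms import RMSLayerNorm, RMSLayerNormConfig

# Raw dict as it appears under model.config.attention_norm.config in the
# example YAML; the pair ("layer_norm", "rms_norm") selects RMSLayerNorm and
# its pydantic config class from the registry.
raw_config = {"ndim": 768, "bias": True, "epsilon": 1e-5}

# Conceptually, the builder validates the raw dict against the config class
# and instantiates the component with the validated fields.
validated = RMSLayerNormConfig(**raw_config)
attention_norm: nn.Module = RMSLayerNorm(**validated.model_dump())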
Empty file added tests/models/__init__.py
Empty file.
Empty file.