Merge pull request #161 from Modalities/feat/initialization
Feat: Various Configurable Initializations
le1nux authored Jul 9, 2024
2 parents 3ad5d56 + f549392 commit 0b8dfc0
Showing 35 changed files with 1,172 additions and 78 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -83,6 +83,7 @@ In the following, we list the already implemented, planned and in-progress features
|--------------------------------|------------------|-------------------------------------------------------------------------------------------------------------------|
| SwiGLU | supported | A nonlinear activation function combining Gated Linear Units (GLU) with Swish for enhancing model capacity and learning efficiency. |
| Weight Decay | supported | Regularization technique that adds a penalty on the size of weights, encouraging smaller weights to reduce overfitting and improve generalization. |
| Weight Initialization | supported | Choose between different, configurable weight initialization techniques to stabilize training. |
| RMSNorm (pre-normalization) | supported | Normalizes the pre-activation weights in a layer to stabilize training, often used as an alternative to LayerNorm for improved training dynamics. |
| Rotary Positional Embeddings (RoPE) | supported | Encodes sequence position information into attention mechanisms, preserving relative positional information and improving model's understanding of sequence order. |
| Grouped-query Attention (GQA) | supported | Enhances attention mechanisms by grouping queries to reduce computation and memory footprint while maintaining or improving performance. |
@@ -62,6 +62,7 @@ raw_model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
@@ -65,6 +65,7 @@ raw_model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
27 changes: 20 additions & 7 deletions config_files/training/config_example_coca.yaml
@@ -42,7 +42,7 @@ train_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
num_samples: 64
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
@@ -55,7 +55,7 @@ val_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
num_samples: 32
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
@@ -170,7 +170,23 @@ wrapped_model:
sharding_strategy: FULL_SHARD
block_names: [TransformerBlock, VisionTransformerBlock]

model:
model:
component_key: model
variant_key: model_initialized
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
model_initializer:
component_key: model_initialization
variant_key: composed
config:
model_type: coca
weight_init_type: plain
mean: 0.0
std: 0.02

model_raw:
component_key: model
variant_key: coca
config:
@@ -215,9 +231,6 @@ model:
n_vision_queries: 256
bias_attn_pool: False
epsilon_attn_pool: 1e-5
weight_init:
mean: 0.0
std: 0.02

scheduler:
component_key: scheduler
@@ -229,7 +242,7 @@ scheduler:
max_lr: 6e-4
div_factor: 10
final_div_factor: 1
total_steps: 4
total_steps: 64
pct_start: 0.01
anneal_strategy: cos
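
For reference, the `plain` weight_init_type selected above presumably corresponds to the normal-distribution initialization that previously lived inside the CoCa model itself (see the `_init_weights` method removed from `coca_model.py` later in this diff). A minimal sketch of that behaviour in plain PyTorch, not the Modalities interface:

```python
import torch
from torch import nn


def plain_init_(module: nn.Module, mean: float = 0.0, std: float = 0.02) -> None:
    # Normal init for linear and embedding weights, zero bias; mirrors the
    # in-model _init_weights that this PR removes further down in the diff.
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=mean, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=mean, std=std)


# usage sketch: model.apply(lambda m: plain_init_(m, mean=0.0, std=0.02))
```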

1 change: 1 addition & 0 deletions config_files/training/config_example_mem_map_dataset.yaml
@@ -167,6 +167,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
1 change: 1 addition & 0 deletions config_files/training/config_example_openGPTx_dataset.yaml
@@ -155,6 +155,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled

scheduler:
type_hint: StepLR
@@ -172,6 +172,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
@@ -172,6 +172,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
@@ -172,6 +172,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
32 changes: 23 additions & 9 deletions config_files/training/config_lorem_ipsum.yaml
@@ -178,7 +178,24 @@ wrapped_model:
sharding_strategy: FULL_SHARD
block_names: [GPT2Block]

model:
model:
component_key: model
variant_key: model_initialized
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
model_initializer:
component_key: model_initialization
variant_key: composed
config:
model_type: gpt2
weight_init_type: scaled_embed
mean: 0.0
std: 0.02
num_layers: ${model_raw.config.n_layer}

model_raw:
component_key: model
variant_key: gpt2
config:
@@ -198,33 +215,30 @@ model:
qkv_transforms:
- type_hint: RotaryTransform
config:
n_embd: ${model.config.n_embd}
n_head: ${model.config.n_head_q} #it has to be head_q here
n_embd: ${model_raw.config.n_embd}
n_head: ${model_raw.config.n_head_q} #it has to be head_q here
seq_length_dim: -2
attention_implementation: manual
activation_type: gelu
weight_init:
mean: 0.0
std: 0.02
attention_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
ffn_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
lm_head_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
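
The `num_layers: ${model_raw.config.n_layer}` reference feeds the layer count into the scaled initialization, which, per the code removed from `gpt2_model.py` below, divides the configured std by `sqrt(2 * num_layers)` for the residual projections; how `scaled_embed` extends this is not visible in this diff. A quick sanity check of the resulting value, assuming the 0.02 base std from this config and an illustrative layer count:

```python
import math

base_std = 0.02
num_layers = 12  # illustrative; the config resolves it from model_raw.config.n_layer
scaled_std = base_std / math.sqrt(2 * num_layers)
print(round(scaled_std, 5))  # 0.00408
```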

1 change: 1 addition & 0 deletions examples/getting_started/example_config.yaml
@@ -181,6 +181,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
@@ -66,6 +66,7 @@ raw_model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
1 change: 1 addition & 0 deletions examples/library_usage/config_lorem_ipsum.yaml
@@ -205,6 +205,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled

optimizer:
component_key: optimizer
6 changes: 6 additions & 0 deletions src/modalities/config/config.py
@@ -17,6 +17,7 @@
PydanticCollateFnIFType,
PydanticDatasetIFType,
PydanticLLMDataLoaderIFType,
PydanticModelInitializationIFType,
PydanticOptimizerIFType,
PydanticPytorchDeviceType,
PydanticPytorchModuleType,
@@ -228,6 +229,11 @@ def parse_sharding_strategy_by_name(cls, name):
return parse_enum_by_name(name=name, enum_type=ShardingStrategy)


class WeightInitializedModelConfig(BaseModel):
model: PydanticPytorchModuleType
model_initializer: PydanticModelInitializationIFType


class PreTrainedHFTokenizerConfig(BaseModel):
pretrained_model_name_or_path: str
max_length: Annotated[int, Field(strict=True, ge=0)]
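
`WeightInitializedModelConfig` pairs a model (passed by reference from `model_raw`) with a `ModelInitializationIF` instance. The matching `model_initialized` component presumably applies the initializer to the referenced model in place and returns it; a rough sketch of that wiring, with the `initialize_in_place` method name being an assumption rather than the real interface:

```python
from torch import nn


class ModelInitializationIF:  # stand-in for modalities.nn.model_initialization.initialization_if
    def initialize_in_place(self, model: nn.Module) -> None:  # method name is an assumption
        raise NotImplementedError


def build_weight_initialized_model(model: nn.Module, model_initializer: ModelInitializationIF) -> nn.Module:
    # Mirrors WeightInitializedModelConfig: the raw model is passed by reference,
    # its weights are (re-)initialized in place, and the same module is returned.
    model_initializer.initialize_in_place(model)
    return model
```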
2 changes: 2 additions & 0 deletions src/modalities/config/pydanctic_if_types.py
@@ -17,6 +17,7 @@
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.loss_functions import Loss
from modalities.models.gpt2.collator import CollateFnIF
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF

@@ -61,3 +62,4 @@ def __get_pydantic_core_schema__(
PydanticPytorchDeviceType = Annotated[torch.device, PydanticThirdPartyTypeIF(torch.device)]
PydanticTextInferenceComponentType = Annotated[TextInferenceComponent, PydanticThirdPartyTypeIF(TextInferenceComponent)]
PydanticGradientClipperIFType = Annotated[GradientClipperIF, PydanticThirdPartyTypeIF(GradientClipperIF)]
PydanticModelInitializationIFType = Annotated[ModelInitializationIF, PydanticThirdPartyTypeIF(ModelInitializationIF)]
27 changes: 1 addition & 26 deletions src/modalities/models/coca/coca_model.py
@@ -1,5 +1,3 @@
import math
from functools import partial
from typing import Annotated, Dict, Tuple

import torch
@@ -10,8 +8,7 @@
from modalities.models.coca.attention_pooling import AttentionPooling
from modalities.models.coca.multi_modal_decoder import MultiModalTextDecoder
from modalities.models.coca.text_decoder import TextDecoder
from modalities.models.gpt2.gpt2_model import ActivationType, WeightInitializationConfig
from modalities.models.model import NNModel
from modalities.models.model import ActivationType, NNModel
from modalities.models.vision_transformer.vision_transformer_model import VisionTransformer, VisionTransformerConfig
from modalities.nn.attention import AttentionConfig

@@ -45,7 +42,6 @@ class CoCaConfig(BaseModel):
n_vision_queries: Annotated[int, Field(ge=1)]
bias_attn_pool: bool
epsilon_attn_pool: Annotated[float, Field(ge=0.0)]
weight_init: WeightInitializationConfig


class CoCa(NNModel):
@@ -71,7 +67,6 @@ def __init__(
epsilon_attn_pool: float,
vision_encoder_config: VisionTransformerConfig,
text_decoder_config: TextDecoderConfig,
weight_init: WeightInitializationConfig,
) -> None:
super().__init__()
self.prediction_key = prediction_key
@@ -126,26 +121,6 @@ def __init__(
attention_config=text_decoder_config.attention_config,
)

# init all weights
self.apply(partial(self._init_weights, weight_init=weight_init))
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(
p,
mean=weight_init.mean,
std=weight_init.std
/ math.sqrt(2 * (text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text)),
)

def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)

def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
vision_embd, vision_cls_token = self._forward_encode_vision(inputs)
text_embd, text_cls_token = self._forward_encode_text(inputs)
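
The block removed above applied the GPT-2-style scaled initialization inside the CoCa constructor, with the denominator built from the combined text and multimodal decoder depth. Restated as a small helper for clarity (the layer counts in the example are illustrative):

```python
import math


def coca_scaled_std(std: float, n_layer_text: int, n_layer_multimodal_text: int) -> float:
    # Residual-projection std from the removed block: shrink by the combined decoder depth.
    return std / math.sqrt(2 * (n_layer_text + n_layer_multimodal_text))


print(round(coca_scaled_std(0.02, 12, 12), 5))  # 0.00289
```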
30 changes: 1 addition & 29 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -1,7 +1,6 @@
import math
from copy import deepcopy
from enum import Enum
from functools import partial
from typing import Annotated, Dict, List, Tuple

import torch
@@ -16,7 +15,7 @@

from modalities.config.pydanctic_if_types import PydanticPytorchModuleType
from modalities.config.utils import convert_base_model_config_to_dict
from modalities.models.model import NNModel
from modalities.models.model import ActivationType, NNModel
from modalities.util import parse_enum_by_name

# GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
@@ -108,11 +107,6 @@ class QueryKeyValueTransformType(Enum):
RotaryTransform = RotaryTransform


class ActivationType(str, Enum):
GELU = "gelu"
FUSED_SWIGLU = "fused_swiglu"


class AttentionImplementation(str, Enum):
MANUAL = "manual"
PYTORCH_FLASH = "pytorch_flash"
@@ -139,11 +133,6 @@ def parse_sharding_strategy_by_name(cls, name):
qkv_transforms: List[QueryKeyValueTransformConfig]


class WeightInitializationConfig(BaseModel):
mean: Annotated[float, Field(strict=True, ge=0.0)]
std: Annotated[float, Field(strict=True, ge=0.0)]


class GPT2LLMConfig(BaseModel):
sample_key: str
prediction_key: str
@@ -165,7 +154,6 @@ class GPT2LLMConfig(BaseModel):
attention_norm: PydanticPytorchModuleType
ffn_norm: PydanticPytorchModuleType
lm_head_norm: PydanticPytorchModuleType
weight_init: WeightInitializationConfig

@model_validator(mode="after")
def check_divisibility(self) -> "GPT2LLMConfig":
@@ -430,7 +418,6 @@ def __init__(
bias: bool,
activation_type: ActivationType,
attention_implementation: AttentionImplementation,
weight_init: WeightInitializationConfig,
attention_config: AttentionConfig,
attention_norm: nn.Module,
ffn_norm: nn.Module,
@@ -499,21 +486,6 @@ def __init__(
# not 100% sure what this is, so far seems to be harmless. TODO investigate
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

# init all weights
self.apply(partial(self._init_weights, weight_init=weight_init))
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(p, mean=weight_init.mean, std=weight_init.std / math.sqrt(2 * n_layer))

def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)

def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
input_ids = inputs[self.sample_key]
device = input_ids.device
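
Taken together, these deletions move all weight initialization out of the model constructors, so the composed initializer configured in the YAML files has to reproduce the removed behaviour externally. A minimal sketch of such an initializer, written against plain PyTorch and not the actual `ModelInitializationIF` interface:

```python
import math
from typing import Optional

import torch
from torch import nn


class ComposedInitSketch:
    """Rough stand-in for the 'composed' model_initialization component; not the actual Modalities API."""

    def __init__(self, weight_init_type: str, mean: float, std: float, num_layers: Optional[int] = None):
        self.weight_init_type = weight_init_type
        self.mean = mean
        self.std = std
        self.num_layers = num_layers

    def initialize_in_place(self, model: nn.Module) -> None:  # entry-point name is an assumption
        # Plain normal init for linear/embedding weights, zero biases (as in the removed _init_weights).
        for module in model.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=self.mean, std=self.std)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=self.mean, std=self.std)
        # Scaled variants: re-init residual projections with std / sqrt(2 * num_layers), per the GPT-2 paper.
        # How 'scaled_embed' differs from 'scaled' is not visible in this diff.
        if self.weight_init_type in ("scaled", "scaled_embed") and self.num_layers is not None:
            scaled_std = self.std / math.sqrt(2 * self.num_layers)
            for name, param in model.named_parameters():
                if name.endswith("c_proj.weight"):
                    torch.nn.init.normal_(param, mean=self.mean, std=scaled_std)
```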