Feat: Various Configurable Initializations #161

Merged 52 commits on Jul 9, 2024
Commits
f75f9dd
test: scaled initialization for gpt2
flxst Jun 25, 2024
16e31ab
test: scaled initialization for coca
flxst Jun 25, 2024
c2b44c9
test(initialization): scaled initialization for gpt2/coca: improvements
flxst Jun 26, 2024
c332fe3
feat: introduce configurable initialization (plain, scaled)
flxst Jun 26, 2024
5a2d299
refactor: move weight initialization for gpt2 & coca to common parent…
flxst Jun 26, 2024
5ae48e5
feat: introduce auto option for weight initialization standard deviat…
flxst Jun 26, 2024
6376b0a
feat: introduce scaled_embed initialization
flxst Jun 26, 2024
ec520f4
docs: add weight initialization to features in README
flxst Jun 27, 2024
53616e0
feat: added Weight Initialization Factory
le1nux Jun 30, 2024
03b7a7f
feat: implemented model-wise and named-parameter-wise initialization c…
le1nux Jun 30, 2024
d2ecf0d
feat: drafted the init configs
le1nux Jun 30, 2024
0efd32d
chore: added missing init file
le1nux Jun 30, 2024
dcafa38
refactor: removed previous initialization code in the abstract model …
le1nux Jun 30, 2024
83ce570
feat: added initialization config classes
le1nux Jun 30, 2024
ce48a57
refactor: moved WeightInitializationIF to separate file to prevent ci…
le1nux Jun 30, 2024
a08d4d9
feat: added WeightInitializerWrapperConfig to config
le1nux Jun 30, 2024
c39eabd
feat: added WeightInitializerWrapper
le1nux Jun 30, 2024
2af2625
feat: wired up all the weight initializers as components
le1nux Jun 30, 2024
a950746
feat: added functionality to initialize model weights
le1nux Jun 30, 2024
416831a
feat: added the weight init to config lorem ipsum
le1nux Jun 30, 2024
86308d4
refactor: removed raising an exception when module is not covered for…
le1nux Jun 30, 2024
d2f2e15
refactor: fixed pydantic model_validator in PlainWeightInitialization…
le1nux Jun 30, 2024
2dc37e2
refactor: removed all weight init code from the models
le1nux Jun 30, 2024
f7d6d0d
chore: fixed typo
le1nux Jun 30, 2024
8fcaf08
refactor: finalized config lorem ipsum for weight initialization
le1nux Jun 30, 2024
fe2fb49
refactor: moved configs back to the init factory and removed not need…
le1nux Jul 1, 2024
7164482
feat: added HighLevelWeightInitializationFactory
le1nux Jul 1, 2024
327dbd8
feat: wired up HighLevelWeightInitializationFactory
le1nux Jul 1, 2024
89d5a29
refactor: config_lorem_ipsum.yaml now supports HighLevelWeightInitia…
le1nux Jul 1, 2024
010f362
refactor: mean of weight init constraint is now relaxed to float-only
le1nux Jul 1, 2024
c0a9aa5
chore: added reference citation to the weight init for gpt2
le1nux Jul 1, 2024
7d26675
fix: replaced std of 0.4 by math.sqrt(0.4) for scaled_embed
le1nux Jul 1, 2024
bb4cd36
chore: added citation
le1nux Jul 1, 2024
48dced9
refactor: added edge case handling when module is not of type linear …
le1nux Jul 2, 2024
93a0cb4
refactor: removed ModulewiseNormalInitialization and extended NamedPa…
le1nux Jul 2, 2024
c2817f7
refactor: removed old weight_init from coca config
le1nux Jul 2, 2024
98fe633
refactor: all weight init code is now based on regex matching on name…
le1nux Jul 2, 2024
b8e42c1
feat: added plain weight init for CoCa
le1nux Jul 2, 2024
054d6a2
refactor: renamed plain_std -> std
le1nux Jul 2, 2024
602336c
refactor: renamed plain_std -> std (missed one instance)
le1nux Jul 2, 2024
824da6a
refactor: some renamings; now passing the calculated std from plai…
le1nux Jul 2, 2024
4573364
refactor: removed legacy NamedParameterwiseNormalInitializationConfig
le1nux Jul 2, 2024
0e29447
refactor: simplified the initialization structure and improved naming
le1nux Jul 2, 2024
eea99bc
refactor: for plain init we do not allow hidden_dim to be specified wh…
le1nux Jul 3, 2024
d7c0825
refactor: removed currently not needed init routines from component r…
le1nux Jul 3, 2024
2c50e06
test: initialization unit test adjustments (first steps)
flxst Jul 3, 2024
96c0657
test: initialization unit test adjustments (config fix)
flxst Jul 3, 2024
d3c1683
test: fix coca config
flxst Jul 3, 2024
1a61bf6
feat: added plain filters for coca
le1nux Jul 3, 2024
003eb99
test: readd test for initialization
flxst Jul 3, 2024
9f5651b
Merge pull request #168 from Modalities/feat/initialization_component
le1nux Jul 3, 2024
f549392
refactor: Added references and minor improvements
le1nux Jul 9, 2024
1 change: 1 addition & 0 deletions README.md
@@ -83,6 +83,7 @@ In the following, we list the already implemented, planned and in-progress featu
|--------------------------------|------------------|-------------------------------------------------------------------------------------------------------------------|
| SwiGLU | supported | A nonlinear activation function combining Gated Linear Units (GLU) with Swish for enhancing model capacity and learning efficiency. |
| Weight Decay | supported | Regularization technique that adds a penalty on the size of weights, encouraging smaller weights to reduce overfitting and improve generalization. |
| Weight Initialization | supported | Choose between different, configurable weight initialization techniques to stabilize training. |
| RMSNorm (pre-normalization) | supported | Normalizes the pre-activation weights in a layer to stabilize training, often used as an alternative to LayerNorm for improved training dynamics. |
| Rotary Positional Embeddings (RoPE) | supported | Encodes sequence position information into attention mechanisms, preserving relative positional information and improving model's understanding of sequence order. |
| Grouped-query Attention (GQA) | supported | Enhances attention mechanisms by grouping queries to reduce computation and memory footprint while maintaining or improving performance. |
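The feature row above refers to the three `weight_init_type` variants used throughout this PR: `plain`, `scaled`, and `scaled_embed`. As a rough, hedged summary — reconstructed from the commit messages and from the initialization code removed from `gpt2_model.py` and `coca_model.py` further down, not from the new initializer components themselves — the standard deviation per parameter looks roughly like this:

```python
import math

# Hedged sketch: approximate std per initialization variant, as far as it can be
# inferred from this PR. The authoritative logic lives in the new
# model_initialization components; the parameter-name checks here are assumptions.
def approximate_init_std(weight_init_type: str, std: float, num_layers: int, param_name: str) -> float:
    if weight_init_type == "plain":
        # all Linear/Embedding weights ~ N(mean, std); biases are zeroed
        return std
    if weight_init_type in ("scaled", "scaled_embed"):
        if param_name.endswith("c_proj.weight"):
            # residual projections get a depth-scaled std, per the GPT-2 paper
            return std / math.sqrt(2 * num_layers)
        if weight_init_type == "scaled_embed" and "wte" in param_name:
            # commit 7d26675 hints that embeddings use std = sqrt(0.4) (assumption)
            return math.sqrt(0.4)
        return std
    raise ValueError(f"unknown weight_init_type: {weight_init_type}")
```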
@@ -62,6 +62,7 @@ raw_model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
@@ -65,6 +65,7 @@ raw_model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
27 changes: 20 additions & 7 deletions config_files/training/config_example_coca.yaml
@@ -42,7 +42,7 @@ train_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
num_samples: 64
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
@@ -55,7 +55,7 @@ val_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
num_samples: 32
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
@@ -170,7 +170,23 @@ wrapped_model:
sharding_strategy: FULL_SHARD
block_names: [TransformerBlock, VisionTransformerBlock]

model:
model:
component_key: model
variant_key: model_initialized
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
model_initializer:
component_key: model_initialization
variant_key: composed
config:
model_type: coca
weight_init_type: plain
mean: 0.0
std: 0.02

model_raw:
component_key: model
variant_key: coca
config:
@@ -215,9 +231,6 @@ model:
n_vision_queries: 256
bias_attn_pool: False
epsilon_attn_pool: 1e-5
weight_init:
mean: 0.0
std: 0.02

scheduler:
component_key: scheduler
@@ -229,7 +242,7 @@ scheduler:
max_lr: 6e-4
div_factor: 10
final_div_factor: 1
total_steps: 4
total_steps: 64
pct_start: 0.01
anneal_strategy: cos

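The `model_initialized` component above wraps the raw CoCa model and applies a `plain` initialization with the given mean and std. Judging from the `_init_weights` method removed from `coca_model.py` later in this diff, a plain pass amounts to roughly the following — a minimal sketch, assuming the new initializer performs the same normal/zero init rather than something more elaborate:

```python
import torch.nn as nn

def plain_init_(model: nn.Module, mean: float = 0.0, std: float = 0.02) -> None:
    """Sketch of a 'plain' initialization pass (assumed to mirror the removed _init_weights)."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=mean, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=mean, std=std)
```

In the config this runs through the component registry (`model_initialization` / `composed`) instead of inside the model's `__init__`, which is why the `weight_init` block is removed from the CoCa model config above.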
1 change: 1 addition & 0 deletions config_files/training/config_example_mem_map_dataset.yaml
@@ -167,6 +167,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
1 change: 1 addition & 0 deletions config_files/training/config_example_openGPTx_dataset.yaml
@@ -155,6 +155,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled

scheduler:
type_hint: StepLR
@@ -172,6 +172,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
@@ -172,6 +172,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
@@ -172,6 +172,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
32 changes: 23 additions & 9 deletions config_files/training/config_lorem_ipsum.yaml
@@ -178,7 +178,24 @@ wrapped_model:
sharding_strategy: FULL_SHARD
block_names: [GPT2Block]

model:
model:
component_key: model
variant_key: model_initialized
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
model_initializer:
component_key: model_initialization
variant_key: composed
config:
model_type: gpt2
weight_init_type: scaled_embed
mean: 0.0
std: 0.02
num_layers: ${model_raw.config.n_layer}

model_raw:
component_key: model
variant_key: gpt2
config:
@@ -198,33 +215,30 @@ model:
qkv_transforms:
- type_hint: RotaryTransform
config:
n_embd: ${model.config.n_embd}
n_head: ${model.config.n_head_q} #it has to be head_q here
n_embd: ${model_raw.config.n_embd}
n_head: ${model_raw.config.n_head_q} #it has to be head_q here
seq_length_dim: -2
attention_implementation: manual
activation_type: gelu
weight_init:
mean: 0.0
std: 0.02
attention_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
ffn_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
lm_head_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5

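In the GPT-2 config above, `num_layers` is interpolated from `model_raw.config.n_layer` because the scaled variants divide the residual-projection std by the network depth — the same rule as the inline GPT-2 initialization this PR removes from `gpt2_model.py` further down. A small illustration, where the layer count is only an example value and not taken from the config:

```python
import math

std = 0.02       # weight_init std from the config above
num_layers = 12  # example stand-in for model_raw.config.n_layer

# std applied to parameters ending in c_proj.weight under 'scaled'/'scaled_embed'
scaled_std = std / math.sqrt(2 * num_layers)
print(f"residual projection std: {scaled_std:.5f}")  # ≈ 0.00408 for 12 layers
```

For `scaled_embed`, commit 7d26675 additionally suggests that embedding weights are drawn with `std = math.sqrt(0.4)`; which parameters that touches is decided by the regex-based initializers introduced in this PR.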
1 change: 1 addition & 0 deletions examples/getting_started/example_config.yaml
@@ -175,6 +175,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
@@ -65,6 +65,7 @@ raw_model:
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
1 change: 1 addition & 0 deletions examples/library_usage/config_lorem_ipsum.yaml
@@ -205,6 +205,7 @@ model:
weight_init:
mean: 0.0
std: 0.02
type: scaled

optimizer:
component_key: optimizer
6 changes: 6 additions & 0 deletions src/modalities/config/config.py
@@ -17,6 +17,7 @@
PydanticCollateFnIFType,
PydanticDatasetIFType,
PydanticLLMDataLoaderIFType,
PydanticModelInitializationIFType,
PydanticOptimizerIFType,
PydanticPytorchDeviceType,
PydanticPytorchModuleType,
@@ -228,6 +229,11 @@ def parse_sharding_strategy_by_name(cls, name):
return parse_enum_by_name(name=name, enum_type=ShardingStrategy)


class WeightInitializedModelConfig(BaseModel):
model: PydanticPytorchModuleType
model_initializer: PydanticModelInitializationIFType


class PreTrainedHFTokenizerConfig(BaseModel):
pretrained_model_name_or_path: str
max_length: Annotated[int, Field(strict=True, ge=0)]
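`WeightInitializedModelConfig` simply pairs a model (passed `BY_REFERENCE` as `model_raw` in the configs above) with a `ModelInitializationIF` instance. The factory behind the `model_initialized` component is not part of this excerpt; presumably it does little more than the following sketch, where the method name `initialize_in_place` is an assumption rather than something shown in the diff:

```python
import torch.nn as nn

# Hedged sketch of the component behind WeightInitializedModelConfig.
def get_weight_initialized_model(model: nn.Module, model_initializer) -> nn.Module:
    model_initializer.initialize_in_place(model)  # mutate the referenced model's weights
    return model
```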
2 changes: 2 additions & 0 deletions src/modalities/config/pydanctic_if_types.py
@@ -17,6 +17,7 @@
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.loss_functions import Loss
from modalities.models.gpt2.collator import CollateFnIF
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF

@@ -61,3 +62,4 @@ def __get_pydantic_core_schema__(
PydanticPytorchDeviceType = Annotated[torch.device, PydanticThirdPartyTypeIF(torch.device)]
PydanticTextInferenceComponentType = Annotated[TextInferenceComponent, PydanticThirdPartyTypeIF(TextInferenceComponent)]
PydanticGradientClipperIFType = Annotated[GradientClipperIF, PydanticThirdPartyTypeIF(GradientClipperIF)]
PydanticModelInitializationIFType = Annotated[ModelInitializationIF, PydanticThirdPartyTypeIF(ModelInitializationIF)]
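The new `PydanticModelInitializationIFType` follows the same `PydanticThirdPartyTypeIF` pattern as the other annotated types above, so initializer components can be validated in Pydantic configs. The interface itself is not included in this diff; a minimal sketch of what `ModelInitializationIF` likely provides (the method name is assumed, consistent with the sketch above):

```python
from abc import ABC, abstractmethod

import torch.nn as nn


class ModelInitializationIF(ABC):
    """Assumed shape of the initializer interface registered above."""

    @abstractmethod
    def initialize_in_place(self, model: nn.Module) -> None:
        ...
```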
27 changes: 1 addition & 26 deletions src/modalities/models/coca/coca_model.py
@@ -1,5 +1,3 @@
import math
from functools import partial
from typing import Annotated, Dict, Tuple

import torch
@@ -10,8 +8,7 @@
from modalities.models.coca.attention_pooling import AttentionPooling
from modalities.models.coca.multi_modal_decoder import MultiModalTextDecoder
from modalities.models.coca.text_decoder import TextDecoder
from modalities.models.gpt2.gpt2_model import ActivationType, WeightInitializationConfig
from modalities.models.model import NNModel
from modalities.models.model import ActivationType, NNModel
from modalities.models.vision_transformer.vision_transformer_model import VisionTransformer, VisionTransformerConfig
from modalities.nn.attention import AttentionConfig

@@ -45,7 +42,6 @@ class CoCaConfig(BaseModel):
n_vision_queries: Annotated[int, Field(ge=1)]
bias_attn_pool: bool
epsilon_attn_pool: Annotated[float, Field(ge=0.0)]
weight_init: WeightInitializationConfig


class CoCa(NNModel):
@@ -71,7 +67,6 @@ def __init__(
epsilon_attn_pool: float,
vision_encoder_config: VisionTransformerConfig,
text_decoder_config: TextDecoderConfig,
weight_init: WeightInitializationConfig,
) -> None:
super().__init__()
self.prediction_key = prediction_key
@@ -126,26 +121,6 @@ def __init__(
attention_config=text_decoder_config.attention_config,
)

# init all weights
self.apply(partial(self._init_weights, weight_init=weight_init))
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(
p,
mean=weight_init.mean,
std=weight_init.std
/ math.sqrt(2 * (text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text)),
)

def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)

def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
vision_embd, vision_cls_token = self._forward_encode_vision(inputs)
text_embd, text_cls_token = self._forward_encode_text(inputs)
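The block removed here applied the GPT-2-style scaled init inline: every parameter ending in `c_proj.weight` was re-drawn with `std / sqrt(2 * (n_layer_text + n_layer_multimodal_text))`. Commit 98fe633 replaces this with regex matching on parameter names inside the new initializers; a hedged sketch of that style, where the regex and the helper name are assumptions:

```python
import math
import re

import torch.nn as nn

def init_residual_projections_(model: nn.Module, mean: float, std: float,
                               num_layers: int, pattern: str = r"c_proj\.weight$") -> None:
    """Sketch of a regex-based scaled init for residual projections."""
    scaled_std = std / math.sqrt(2 * num_layers)
    regex = re.compile(pattern)
    for name, param in model.named_parameters():
        if regex.search(name):
            nn.init.normal_(param, mean=mean, std=scaled_std)
```

For CoCa, `num_layers` corresponds to the combined depth `n_layer_text + n_layer_multimodal_text`, matching the denominator in the removed code.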
30 changes: 1 addition & 29 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -1,7 +1,6 @@
import math
from copy import deepcopy
from enum import Enum
from functools import partial
from typing import Annotated, Dict, List, Tuple

import torch
@@ -16,7 +15,7 @@

from modalities.config.pydanctic_if_types import PydanticPytorchModuleType
from modalities.config.utils import convert_base_model_config_to_dict
from modalities.models.model import NNModel
from modalities.models.model import ActivationType, NNModel
from modalities.util import parse_enum_by_name

# GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
@@ -108,11 +107,6 @@ class QueryKeyValueTransformType(Enum):
RotaryTransform = RotaryTransform


class ActivationType(str, Enum):
GELU = "gelu"
FUSED_SWIGLU = "fused_swiglu"


class AttentionImplementation(str, Enum):
MANUAL = "manual"
PYTORCH_FLASH = "pytorch_flash"
@@ -139,11 +133,6 @@ def parse_sharding_strategy_by_name(cls, name):
qkv_transforms: List[QueryKeyValueTransformConfig]


class WeightInitializationConfig(BaseModel):
mean: Annotated[float, Field(strict=True, ge=0.0)]
std: Annotated[float, Field(strict=True, ge=0.0)]


class GPT2LLMConfig(BaseModel):
sample_key: str
prediction_key: str
@@ -165,7 +154,6 @@ class GPT2LLMConfig(BaseModel):
attention_norm: PydanticPytorchModuleType
ffn_norm: PydanticPytorchModuleType
lm_head_norm: PydanticPytorchModuleType
weight_init: WeightInitializationConfig

@model_validator(mode="after")
def check_divisibility(self) -> "GPT2LLMConfig":
@@ -430,7 +418,6 @@ def __init__(
bias: bool,
activation_type: ActivationType,
attention_implementation: AttentionImplementation,
weight_init: WeightInitializationConfig,
attention_config: AttentionConfig,
attention_norm: nn.Module,
ffn_norm: nn.Module,
@@ -500,21 +487,6 @@ def __init__(
# not 100% sure what this is, so far seems to be harmless. TODO investigate
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

# init all weights
self.apply(partial(self._init_weights, weight_init=weight_init))
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(p, mean=weight_init.mean, std=weight_init.std / math.sqrt(2 * n_layer))

def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)

def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
input_ids = inputs[self.sample_key]
device = input_ids.device
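With the inline initialization removed, GPT-2 weights are now drawn by whichever initializer the config wires in, rather than inside `__init__`. A quick, hedged sanity check of the intended effect of the `scaled` rule on a residual projection (the tensor shape and layer count are illustrative only):

```python
import math

import torch

std, n_layer = 0.02, 12
w = torch.empty(768, 768)
torch.nn.init.normal_(w, mean=0.0, std=std / math.sqrt(2 * n_layer))
print(round(w.std().item(), 5))  # close to 0.02 / sqrt(24) ≈ 0.00408
```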