Draft: Feat/initialization component #168

Merged: 42 commits, merged on Jul 3, 2024

Commits (42)
53616e0
feat: added Weight Initialization Factory
le1nux Jun 30, 2024
03b7a7f
feat: implemented model-wise and named-parameter-wise initalization c…
le1nux Jun 30, 2024
d2ecf0d
feat: drafted the init configs
le1nux Jun 30, 2024
0efd32d
chore: added missing init file
le1nux Jun 30, 2024
dcafa38
refactor: removed previous initialization code in the abstract model …
le1nux Jun 30, 2024
83ce570
feat: added initialization config classes
le1nux Jun 30, 2024
ce48a57
refactor: moved WeightInitializationIF to separate file to prevent ci…
le1nux Jun 30, 2024
a08d4d9
feat: added WeightInitializerWrapperConfig to config
le1nux Jun 30, 2024
c39eabd
feat: added WeightInitializerWrapper
le1nux Jun 30, 2024
2af2625
feat: wired up all the weight initializers as components
le1nux Jun 30, 2024
a950746
feat: added functionaliy to initalize model weights
le1nux Jun 30, 2024
416831a
feat: added the weight init to config lorem ipsum
le1nux Jun 30, 2024
86308d4
refactor: removed raising an excpetion when module is not covered for…
le1nux Jun 30, 2024
d2f2e15
refactor: fixed pydantic model_validator in PlainWeightInitialization…
le1nux Jun 30, 2024
2dc37e2
refactor: removed all weight init code from the models
le1nux Jun 30, 2024
f7d6d0d
chore: fixed typo
le1nux Jun 30, 2024
8fcaf08
refactor: finalized config lorem ipsum for weight initialization
le1nux Jun 30, 2024
fe2fb49
refactor: moved configs back to the init factory and removed not need…
le1nux Jul 1, 2024
7164482
feat: added HighLevelWeightInitializationFactory
le1nux Jul 1, 2024
327dbd8
feat: wired up HighLevelWeightInitializationFactory
le1nux Jul 1, 2024
89d5a29
refactor: config_lorem_ipsum.yaml now suppports HighLevelWeightInitia…
le1nux Jul 1, 2024
010f362
refactor: mean of weight init constraint is now relaxed to float-only
le1nux Jul 1, 2024
c0a9aa5
chore: added reference citation to the weigh init for gpt2
le1nux Jul 1, 2024
7d26675
fix: replaced std of 0.4 by math.sqrt(0.4) for scaled_embed
le1nux Jul 1, 2024
bb4cd36
chore: added citation
le1nux Jul 1, 2024
48dced9
refactor: added edge case handling when module is not of type linear …
le1nux Jul 2, 2024
93a0cb4
refactor: removed ModulewiseNormalInitialization and extended NamedPa…
le1nux Jul 2, 2024
c2817f7
refactor: removed old weigh_init from coca config
le1nux Jul 2, 2024
98fe633
refactor: all weight init code is now based on regex matching on name…
le1nux Jul 2, 2024
b8e42c1
feat: added plain weight init for CoCa
le1nux Jul 2, 2024
054d6a2
refactor: renamed plain_std -> std
le1nux Jul 2, 2024
602336c
refactor: renamed plain_std -> std (missed one instance)
le1nux Jul 2, 2024
824da6a
refactor: some renamings and passing now the calculated std from plai…
le1nux Jul 2, 2024
4573364
refactor: removed legacy NamedParameterwiseNormalInitializationConfig
le1nux Jul 2, 2024
0e29447
refactor: simplified the initialization structure and improved naming
le1nux Jul 2, 2024
eea99bc
refactor: for plain init we do not allow hidde_dim to be specified wh…
le1nux Jul 3, 2024
d7c0825
refactor: removed currently not needed init routines from component r…
le1nux Jul 3, 2024
2c50e06
test: initialization unit test adjustments (first steps)
flxst Jul 3, 2024
96c0657
test: initialization unit test adjustments (config fix)
flxst Jul 3, 2024
d3c1683
test: fix coca config
flxst Jul 3, 2024
1a61bf6
feat: added plain filters for coca
le1nux Jul 3, 2024
003eb99
test: readd test for initialization
flxst Jul 3, 2024
34 changes: 24 additions & 10 deletions config_files/training/config_lorem_ipsum.yaml
@@ -178,7 +178,25 @@ wrapped_model:
sharding_strategy: FULL_SHARD
block_names: [GPT2Block]

model:
model:
component_key: model
variant_key: weight_initialized
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
weight_initializer:
component_key: weight_initialization
variant_key: composed
config:
model_type: gpt2
weight_init_type: scaled_embed
mean: 0.0
plain_std: 0.02
hidden_dim: ${model_raw.config.n_embd}
num_layers: ${model_raw.config.n_layer}

model_raw:
component_key: model
variant_key: gpt2
config:
@@ -198,34 +216,30 @@ model:
qkv_transforms:
- type_hint: RotaryTransform
config:
n_embd: ${model.config.n_embd}
n_head: ${model.config.n_head_q} #it has to be head_q here
n_embd: ${model_raw.config.n_embd}
n_head: ${model_raw.config.n_head_q} #it has to be head_q here
seq_length_dim: -2
attention_implementation: manual
activation_type: gelu
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
ffn_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
lm_head_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
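
Note: the composed weight initializer above takes over what was previously hard-coded inside the models (see the removed weight_init block and NNModel.initialize_weights further down in this diff). As a rough sketch only, assuming the formulas from the removed code carry over unchanged and using math.sqrt(0.4) for the embedding std as introduced in commit 7d26675, the standard deviations implied by this scaled_embed config work out as follows; the helper name and the example call are illustrative, not part of the PR:

import math

def effective_stds(plain_std: "float | str", num_layers: int, hidden_dim: "int | None" = None) -> dict:
    # "auto" derived the base std from the hidden dimension in the removed code.
    if plain_std == "auto":
        assert hidden_dim is not None, "hidden_dim is required when plain_std='auto'"
        plain_std = math.sqrt(2 / (5 * hidden_dim))
    return {
        "linear/embedding weights (plain)": plain_std,
        "residual projections (*.c_proj.weight)": plain_std / math.sqrt(2 * num_layers),
        "token/position embeddings (wte/wpe), scaled_embed only": math.sqrt(0.4),
    }

# plain_std=0.02 as in the config above; num_layers=12 is only a placeholder, since
# the config resolves the real value from ${model_raw.config.n_layer}.
print(effective_stds(plain_std=0.02, num_layers=12))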

6 changes: 6 additions & 0 deletions src/modalities/config/config.py
@@ -22,6 +22,7 @@
PydanticPytorchModuleType,
PydanticSamplerIFType,
PydanticTokenizerIFType,
PydanticWeightInitializationIFType,
)
from modalities.config.utils import parse_torch_device
from modalities.running_env.env_utils import MixedPrecisionSettings, has_bfloat_support
@@ -228,6 +229,11 @@ def parse_sharding_strategy_by_name(cls, name):
return parse_enum_by_name(name=name, enum_type=ShardingStrategy)


class WeightInitializedModelConfig(BaseModel):
model: PydanticPytorchModuleType
weight_initializer: PydanticWeightInitializationIFType


class PreTrainedHFTokenizerConfig(BaseModel):
pretrained_model_name_or_path: str
max_length: Annotated[int, Field(strict=True, ge=0)]
2 changes: 2 additions & 0 deletions src/modalities/config/pydanctic_if_types.py
@@ -17,6 +17,7 @@
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.loss_functions import Loss
from modalities.models.gpt2.collator import CollateFnIF
from modalities.nn.weight_init.weight_init_if import WeightInitializationIF
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF

@@ -61,3 +62,4 @@ def __get_pydantic_core_schema__(
PydanticPytorchDeviceType = Annotated[torch.device, PydanticThirdPartyTypeIF(torch.device)]
PydanticTextInferenceComponentType = Annotated[TextInferenceComponent, PydanticThirdPartyTypeIF(TextInferenceComponent)]
PydanticGradientClipperIFType = Annotated[GradientClipperIF, PydanticThirdPartyTypeIF(GradientClipperIF)]
PydanticWeightInitializationIFType = Annotated[WeightInitializationIF, PydanticThirdPartyTypeIF(WeightInitializationIF)]
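
The PydanticThirdPartyTypeIF wrapper used throughout this file lets plain third-party classes (torch modules, tokenizers, and now WeightInitializationIF) appear as validated field types in Pydantic configs via Annotated. Its definition lies outside this hunk, so the self-contained sketch below only illustrates the general pattern (is_instance validation through __get_pydantic_core_schema__); the names are stand-ins, not the file's actual code:

from typing import Annotated, Any

import torch.nn as nn
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema

class ThirdPartyTypeIF:  # illustrative stand-in for PydanticThirdPartyTypeIF
    def __init__(self, third_party_type: type):
        self.third_party_type = third_party_type

    def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler):
        # Accept any value that is an instance of the wrapped third-party class.
        return core_schema.is_instance_schema(self.third_party_type)

PydanticModuleType = Annotated[nn.Module, ThirdPartyTypeIF(nn.Module)]

class ExampleConfig(BaseModel):
    model: PydanticModuleType

ExampleConfig(model=nn.Linear(4, 4))   # passes validation
# ExampleConfig(model="not a module")  # would raise a ValidationError
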
12 changes: 1 addition & 11 deletions src/modalities/models/coca/coca_model.py
@@ -8,7 +8,7 @@
from modalities.models.coca.attention_pooling import AttentionPooling
from modalities.models.coca.multi_modal_decoder import MultiModalTextDecoder
from modalities.models.coca.text_decoder import TextDecoder
from modalities.models.model import ActivationType, NNModel, WeightInitializationConfig
from modalities.models.model import ActivationType, NNModel
from modalities.models.vision_transformer.vision_transformer_model import VisionTransformer, VisionTransformerConfig
from modalities.nn.attention import AttentionConfig

@@ -42,7 +42,6 @@ class CoCaConfig(BaseModel):
n_vision_queries: Annotated[int, Field(ge=1)]
bias_attn_pool: bool
epsilon_attn_pool: Annotated[float, Field(ge=0.0)]
weight_init: WeightInitializationConfig


class CoCa(NNModel):
@@ -68,7 +67,6 @@ def __init__(
epsilon_attn_pool: float,
vision_encoder_config: VisionTransformerConfig,
text_decoder_config: TextDecoderConfig,
weight_init: WeightInitializationConfig,
) -> None:
super().__init__()
self.prediction_key = prediction_key
@@ -123,14 +121,6 @@ def __init__(
attention_config=text_decoder_config.attention_config,
)

# init all weights
assert weight_init.type in ["plain", "scaled"], f"ERROR! weight_init.type = {weight_init.type} not implemented."
self.initialize_weights(
weight_init,
number_of_layers=text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text,
hidden_dim=None, # not well-defined as hidden_dim can be different for the text and multimodal decoder
)

def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
vision_embd, vision_cls_token = self._forward_encode_vision(inputs)
text_embd, text_cls_token = self._forward_encode_text(inputs)
7 changes: 1 addition & 6 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -15,7 +15,7 @@

from modalities.config.pydanctic_if_types import PydanticPytorchModuleType
from modalities.config.utils import convert_base_model_config_to_dict
from modalities.models.model import ActivationType, NNModel, WeightInitializationConfig
from modalities.models.model import ActivationType, NNModel
from modalities.util import parse_enum_by_name

# GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
@@ -154,7 +154,6 @@ class GPT2LLMConfig(BaseModel):
attention_norm: PydanticPytorchModuleType
ffn_norm: PydanticPytorchModuleType
lm_head_norm: PydanticPytorchModuleType
weight_init: WeightInitializationConfig

@model_validator(mode="after")
def check_divisibility(self) -> "GPT2LLMConfig":
@@ -419,7 +418,6 @@ def __init__(
bias: bool,
activation_type: ActivationType,
attention_implementation: AttentionImplementation,
weight_init: WeightInitializationConfig,
attention_config: AttentionConfig,
attention_norm: nn.Module,
ffn_norm: nn.Module,
@@ -489,9 +487,6 @@ def __init__(
# not 100% sure what this is, so far seems to be harmless. TODO investigate
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

# init all weights
self.initialize_weights(weight_init, number_of_layers=n_layer, hidden_dim=n_embd)

def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
input_ids = inputs[self.sample_key]
device = input_ids.device
47 changes: 1 addition & 46 deletions src/modalities/models/model.py
@@ -1,12 +1,9 @@
import math
from abc import abstractmethod
from enum import Enum
from functools import partial
from typing import Annotated, Dict, List, Optional
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from pydantic import BaseModel, Field

from modalities.batch import DatasetBatch, InferenceResultBatch

@@ -18,12 +15,6 @@ class ActivationType(str, Enum):
FUSED_SWIGLU = "fused_swiglu"


class WeightInitializationConfig(BaseModel):
mean: Annotated[float, Field(strict=True, ge=0.0)]
std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto"
type: str


class NNModel(nn.Module):
def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
if seed is not None:
@@ -42,42 +33,6 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def get_parameters(self) -> Dict[str, torch.Tensor]:
return {name: param for name, param in self.named_parameters()}

def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)

def initialize_weights(
self, weight_init: WeightInitializationConfig, number_of_layers: int, hidden_dim: Optional[int] = None
):
# auto: choose std automatically
if weight_init.std == "auto":
assert hidden_dim is not None, "ERROR! weight_init.std = auto not implemented"
weight_init.std = math.sqrt(2 / (5 * hidden_dim))

# initialize weights
self.apply(partial(self._init_weights, weight_init=weight_init))

if weight_init.type == "plain":
pass # nothing more to do
elif weight_init.type in ["scaled", "scaled_embed"]:
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(
p, mean=weight_init.mean, std=weight_init.std / math.sqrt(2 * number_of_layers)
)
if weight_init.type == "scaled_embed":
# apply special init to embeddings, see https://arxiv.org/abs/2312.16903
for pn, p in self.named_parameters():
if pn.endswith("wte.weight") or pn.endswith("wpe.weight"):
torch.nn.init.normal_(p, mean=weight_init.mean, std=0.4)
else:
raise Exception(f"ERROR! weight_init.type = {weight_init.type} not implemented.")


def model_predict_batch(model: nn.Module, batch: DatasetBatch) -> InferenceResultBatch:
forward_result = model.forward(batch.samples)
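
For context, the initialize_weights logic removed above is superseded by standalone initializer components under src/modalities/nn/weight_init/, which per commit 98fe633 select parameters by regex matching on their names instead of isinstance checks inside the model. Those new files are not part of this excerpt, so the following is only a sketch built from the removed logic and the initialize_in_place call in the model factory below; the class and argument names are hypothetical:

import re
from typing import List

import torch
import torch.nn as nn

class NamedParameterwiseNormalInit:  # hypothetical sketch, not the PR's actual class
    """Initialize every parameter whose qualified name matches one of the given
    regex patterns from a normal distribution, in place."""

    def __init__(self, mean: float, std: float, parameter_name_regexes: List[str]):
        self.mean = mean
        self.std = std
        self.patterns = [re.compile(p) for p in parameter_name_regexes]

    def initialize_in_place(self, model: nn.Module) -> None:
        for name, param in model.named_parameters():
            if any(p.search(name) for p in self.patterns):
                torch.nn.init.normal_(param, mean=self.mean, std=self.std)

# A composed scaled_embed initialization could then chain three such initializers,
# mirroring the removed code and commit 7d26675:
#   std = plain_std                       on linear/embedding weight parameters
#   std = plain_std / sqrt(2*num_layers)  on names matching r"c_proj\.weight$"
#   std = sqrt(0.4)                       on names matching r"(wte|wpe)\.weight$"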
6 changes: 6 additions & 0 deletions src/modalities/models/model_factory.py
@@ -8,6 +8,7 @@
from torch.distributed.fsdp import ShardingStrategy

from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
from modalities.nn.weight_init.weight_init_if import WeightInitializationIF
from modalities.running_env.env_utils import MixedPrecisionSettings
from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory
from modalities.util import compute_number_of_trainable_parameters
@@ -55,3 +56,8 @@ def get_fsdp_wrapped_model(
)

return fsdp_model

@staticmethod
def get_weight_initalized_model(model: nn.Module, weight_initializer: WeightInitializationIF) -> nn.Module:
weight_initializer.initialize_in_place(model)
return model
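
The WeightInitializationIF imported above is not shown in this excerpt; judging from the single call made by get_weight_initalized_model, it presumably amounts to little more than an in-place hook along these lines (a sketch under that assumption, not the repository's definition):

from abc import ABC, abstractmethod

import torch.nn as nn

class WeightInitializationIF(ABC):
    @abstractmethod
    def initialize_in_place(self, model: nn.Module) -> None:
        """Mutate the given model's parameters; no new module is returned."""
        raise NotImplementedError

Because initialization mutates the model in place, the weight_initialized component can simply return the very module it received BY_REFERENCE from model_raw.
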
1 change: 1 addition & 0 deletions src/modalities/nn/weight_init/__init__.py
@@ -0,0 +1 @@
