Merge pull request #168 from Modalities/feat/initialization_component
Draft: Feat/initialization component
le1nux committed Jul 3, 2024
2 parents ec520f4 + 003eb99 · commit 9f5651b
Showing 18 changed files with 797 additions and 157 deletions.
28 changes: 20 additions & 8 deletions config_files/training/config_example_coca.yaml
@@ -42,7 +42,7 @@ train_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
num_samples: 64
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
@@ -55,7 +55,7 @@ val_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
num_samples: 32
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
@@ -170,7 +170,23 @@ wrapped_model:
sharding_strategy: FULL_SHARD
block_names: [TransformerBlock, VisionTransformerBlock]

model:
model:
component_key: model
variant_key: model_initialized
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
model_initializer:
component_key: model_initialization
variant_key: composed
config:
model_type: coca
weight_init_type: plain
mean: 0.0
std: 0.02

model_raw:
component_key: model
variant_key: coca
config:
@@ -215,10 +231,6 @@ model:
n_vision_queries: 256
bias_attn_pool: False
epsilon_attn_pool: 1e-5
weight_init:
mean: 0.0
std: 0.02
type: scaled

scheduler:
component_key: scheduler
@@ -230,7 +242,7 @@ scheduler:
max_lr: 6e-4
div_factor: 10
final_div_factor: 1
total_steps: 4
total_steps: 64
pct_start: 0.01
anneal_strategy: cos

33 changes: 23 additions & 10 deletions config_files/training/config_lorem_ipsum.yaml
@@ -178,7 +178,24 @@ wrapped_model:
sharding_strategy: FULL_SHARD
block_names: [GPT2Block]

model:
model:
component_key: model
variant_key: model_initialized
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
model_initializer:
component_key: model_initialization
variant_key: composed
config:
model_type: gpt2
weight_init_type: scaled_embed
mean: 0.0
std: 0.02
num_layers: ${model_raw.config.n_layer}

model_raw:
component_key: model
variant_key: gpt2
config:
@@ -198,34 +215,30 @@ model:
qkv_transforms:
- type_hint: RotaryTransform
config:
n_embd: ${model.config.n_embd}
n_head: ${model.config.n_head_q} #it has to be head_q here
n_embd: ${model_raw.config.n_embd}
n_head: ${model_raw.config.n_head_q} #it has to be head_q here
seq_length_dim: -2
attention_implementation: manual
activation_type: gelu
weight_init:
mean: 0.0
std: 0.02
type: scaled
attention_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
ffn_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5
lm_head_norm:
component_key: layer_norm
variant_key: rms_norm
config:
ndim: ${model.config.n_embd}
ndim: ${model_raw.config.n_embd}
bias: true
epsilon: 1e-5

6 changes: 6 additions & 0 deletions src/modalities/config/config.py
@@ -17,6 +17,7 @@
PydanticCollateFnIFType,
PydanticDatasetIFType,
PydanticLLMDataLoaderIFType,
PydanticModelInitializationIFType,
PydanticOptimizerIFType,
PydanticPytorchDeviceType,
PydanticPytorchModuleType,
@@ -228,6 +229,11 @@ def parse_sharding_strategy_by_name(cls, name):
return parse_enum_by_name(name=name, enum_type=ShardingStrategy)


class WeightInitializedModelConfig(BaseModel):
model: PydanticPytorchModuleType
model_initializer: PydanticModelInitializationIFType


class PreTrainedHFTokenizerConfig(BaseModel):
pretrained_model_name_or_path: str
max_length: Annotated[int, Field(strict=True, ge=0)]
2 changes: 2 additions & 0 deletions src/modalities/config/pydanctic_if_types.py
@@ -17,6 +17,7 @@
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.loss_functions import Loss
from modalities.models.gpt2.collator import CollateFnIF
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF

@@ -61,3 +62,4 @@ def __get_pydantic_core_schema__(
PydanticPytorchDeviceType = Annotated[torch.device, PydanticThirdPartyTypeIF(torch.device)]
PydanticTextInferenceComponentType = Annotated[TextInferenceComponent, PydanticThirdPartyTypeIF(TextInferenceComponent)]
PydanticGradientClipperIFType = Annotated[GradientClipperIF, PydanticThirdPartyTypeIF(GradientClipperIF)]
PydanticModelInitializationIFType = Annotated[ModelInitializationIF, PydanticThirdPartyTypeIF(ModelInitializationIF)]
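
The ModelInitializationIF imported here lives in the new src/modalities/nn/model_initialization/initialization_if.py, which is not loaded in this view. A minimal sketch of what that interface could look like, assuming it declares only the initialize_in_place method that the factory calls further below; this is an assumption, not the file's actual contents:

# Hypothetical sketch of initialization_if.py; only the initialize_in_place
# method is confirmed by this diff (it is invoked in model_factory.py below).
from abc import ABC, abstractmethod

import torch.nn as nn


class ModelInitializationIF(ABC):
    @abstractmethod
    def initialize_in_place(self, model: nn.Module) -> None:
        """Mutate the parameters of the given model in place."""
        raise NotImplementedError
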
12 changes: 1 addition & 11 deletions src/modalities/models/coca/coca_model.py
@@ -8,7 +8,7 @@
from modalities.models.coca.attention_pooling import AttentionPooling
from modalities.models.coca.multi_modal_decoder import MultiModalTextDecoder
from modalities.models.coca.text_decoder import TextDecoder
from modalities.models.model import ActivationType, NNModel, WeightInitializationConfig
from modalities.models.model import ActivationType, NNModel
from modalities.models.vision_transformer.vision_transformer_model import VisionTransformer, VisionTransformerConfig
from modalities.nn.attention import AttentionConfig

@@ -42,7 +42,6 @@ class CoCaConfig(BaseModel):
n_vision_queries: Annotated[int, Field(ge=1)]
bias_attn_pool: bool
epsilon_attn_pool: Annotated[float, Field(ge=0.0)]
weight_init: WeightInitializationConfig


class CoCa(NNModel):
@@ -68,7 +67,6 @@ def __init__(
epsilon_attn_pool: float,
vision_encoder_config: VisionTransformerConfig,
text_decoder_config: TextDecoderConfig,
weight_init: WeightInitializationConfig,
) -> None:
super().__init__()
self.prediction_key = prediction_key
@@ -123,14 +121,6 @@ def __init__(
attention_config=text_decoder_config.attention_config,
)

# init all weights
assert weight_init.type in ["plain", "scaled"], f"ERROR! weight_init.type = {weight_init.type} not implemented."
self.initialize_weights(
weight_init,
number_of_layers=text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text,
hidden_dim=None, # not well-defined as hidden_dim can be different for the text and multimodal decoder
)

def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
vision_embd, vision_cls_token = self._forward_encode_vision(inputs)
text_embd, text_cls_token = self._forward_encode_text(inputs)
7 changes: 1 addition & 6 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -15,7 +15,7 @@

from modalities.config.pydanctic_if_types import PydanticPytorchModuleType
from modalities.config.utils import convert_base_model_config_to_dict
from modalities.models.model import ActivationType, NNModel, WeightInitializationConfig
from modalities.models.model import ActivationType, NNModel
from modalities.util import parse_enum_by_name

# GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
@@ -154,7 +154,6 @@ class GPT2LLMConfig(BaseModel):
attention_norm: PydanticPytorchModuleType
ffn_norm: PydanticPytorchModuleType
lm_head_norm: PydanticPytorchModuleType
weight_init: WeightInitializationConfig

@model_validator(mode="after")
def check_divisibility(self) -> "GPT2LLMConfig":
@@ -419,7 +418,6 @@ def __init__(
bias: bool,
activation_type: ActivationType,
attention_implementation: AttentionImplementation,
weight_init: WeightInitializationConfig,
attention_config: AttentionConfig,
attention_norm: nn.Module,
ffn_norm: nn.Module,
@@ -489,9 +487,6 @@ def __init__(
# not 100% sure what this is, so far seems to be harmless. TODO investigate
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

# init all weights
self.initialize_weights(weight_init, number_of_layers=n_layer, hidden_dim=n_embd)

def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
input_ids = inputs[self.sample_key]
device = input_ids.device
47 changes: 1 addition & 46 deletions src/modalities/models/model.py
@@ -1,12 +1,9 @@
import math
from abc import abstractmethod
from enum import Enum
from functools import partial
from typing import Annotated, Dict, List, Optional
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from pydantic import BaseModel, Field

from modalities.batch import DatasetBatch, InferenceResultBatch

@@ -18,12 +15,6 @@ class ActivationType(str, Enum):
FUSED_SWIGLU = "fused_swiglu"


class WeightInitializationConfig(BaseModel):
mean: Annotated[float, Field(strict=True, ge=0.0)]
std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto"
type: str


class NNModel(nn.Module):
def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
if seed is not None:
@@ -42,42 +33,6 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def get_parameters(self) -> Dict[str, torch.Tensor]:
return {name: param for name, param in self.named_parameters()}

def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std)

def initialize_weights(
self, weight_init: WeightInitializationConfig, number_of_layers: int, hidden_dim: Optional[int] = None
):
# auto: choose std automatically
if weight_init.std == "auto":
assert hidden_dim is not None, "ERROR! weight_init.std = auto not implemented"
weight_init.std = math.sqrt(2 / (5 * hidden_dim))

# initialize weights
self.apply(partial(self._init_weights, weight_init=weight_init))

if weight_init.type == "plain":
pass # nothing more to do
elif weight_init.type in ["scaled", "scaled_embed"]:
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(
p, mean=weight_init.mean, std=weight_init.std / math.sqrt(2 * number_of_layers)
)
if weight_init.type == "scaled_embed":
# apply special init to embeddings, see https://arxiv.org/abs/2312.16903
for pn, p in self.named_parameters():
if pn.endswith("wte.weight") or pn.endswith("wpe.weight"):
torch.nn.init.normal_(p, mean=weight_init.mean, std=0.4)
else:
raise Exception(f"ERROR! weight_init.type = {weight_init.type} not implemented.")


def model_predict_batch(model: nn.Module, batch: DatasetBatch) -> InferenceResultBatch:
forward_result = model.forward(batch.samples)
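
The initialize_weights logic removed from NNModel above is what the new model_initialization component takes over; its implementation files are not loaded in this view. Below is a sketch of how the removed plain/scaled/scaled_embed behaviour could be expressed behind the new interface. The class name PlainScaledInitializer and its constructor are hypothetical, loosely mirroring the composed initializer's YAML fields (weight_init_type, mean, std, num_layers); the removed std == "auto" branch is left out.

# Hypothetical port of the removed NNModel.initialize_weights logic; the real
# "composed" initializer shipped in this PR may be structured differently.
import math
from typing import Optional

import torch
import torch.nn as nn


class PlainScaledInitializer:  # would implement ModelInitializationIF
    def __init__(self, weight_init_type: str, mean: float, std: float, num_layers: Optional[int] = None):
        if weight_init_type not in ["plain", "scaled", "scaled_embed"]:
            raise ValueError(f"weight_init_type = {weight_init_type} not implemented.")
        self.weight_init_type = weight_init_type
        self.mean = mean
        self.std = std
        self.num_layers = num_layers

    def _init_module(self, module: nn.Module) -> None:
        # plain init: normal for Linear/Embedding weights, zeros for Linear biases
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=self.mean, std=self.std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=self.mean, std=self.std)

    def initialize_in_place(self, model: nn.Module) -> None:
        model.apply(self._init_module)
        if self.weight_init_type in ["scaled", "scaled_embed"]:
            # scaled init for residual projections, per the GPT-2 paper
            for pn, p in model.named_parameters():
                if pn.endswith("c_proj.weight"):
                    torch.nn.init.normal_(
                        p, mean=self.mean, std=self.std / math.sqrt(2 * self.num_layers)
                    )
        if self.weight_init_type == "scaled_embed":
            # special init for embeddings, see https://arxiv.org/abs/2312.16903
            for pn, p in model.named_parameters():
                if pn.endswith("wte.weight") or pn.endswith("wpe.weight"):
                    torch.nn.init.normal_(p, mean=self.mean, std=0.4)
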
6 changes: 6 additions & 0 deletions src/modalities/models/model_factory.py
@@ -8,6 +8,7 @@
from torch.distributed.fsdp import ShardingStrategy

from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.running_env.env_utils import MixedPrecisionSettings
from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory
from modalities.util import compute_number_of_trainable_parameters
@@ -55,3 +56,8 @@ def get_fsdp_wrapped_model(
)

return fsdp_model

@staticmethod
def get_weight_initalized_model(model: nn.Module, model_initializer: ModelInitializationIF) -> nn.Module:
model_initializer.initialize_in_place(model)
return model
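
For orientation, a hypothetical end-to-end use of this factory method, mirroring how the model entry (variant model_initialized) in the YAML configs above wires model_raw together with an initializer. It assumes the static method sits on the ModelFactory class in model_factory.py and reuses the hypothetical PlainScaledInitializer sketched above; in practice the component registry builds both objects from the config.

# Hypothetical wiring; normally the registry instantiates these from the
# "model_raw" and "model_initializer" YAML entries rather than by hand.
import torch.nn as nn

from modalities.models.model_factory import ModelFactory

raw_model = nn.Sequential(nn.Embedding(128, 32), nn.Linear(32, 128))  # toy stand-in for model_raw
initializer = PlainScaledInitializer(weight_init_type="plain", mean=0.0, std=0.02)

initialized_model = ModelFactory.get_weight_initalized_model(
    model=raw_model, model_initializer=initializer
)
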
1 change: 1 addition & 0 deletions src/modalities/nn/model_initialization/__init__.py
@@ -0,0 +1 @@

(Remaining changed files not loaded in this view.)
