Merge pull request #77 from Modalities/lr_schedulers

feat: integrate LR schedulers

mali-git authored Mar 18, 2024
2 parents cd5ef85 + 8a4702d commit b00e2a6
Showing 12 changed files with 211 additions and 41 deletions.
24 changes: 18 additions & 6 deletions config_files/config_lorem_ipsum.yaml
@@ -232,17 +232,29 @@ model:
ndim: ${model.config.n_embd}
bias: true
epsilon: 1e-5
# scheduler:
# type_hint: StepLR
# config:
# step_size: 1
# gamma: 0.1

scheduler:
component_key: scheduler
variant_key: onecycle_lr
config:
optimizer:
instance_key: optimizer
pass_type: BY_REFERENCE
max_lr: 6e-4
div_factor: 10
final_div_factor: 1
total_steps: 4
pct_start: 0.01
anneal_strategy: cos

optimizer:
component_key: optimizer
variant_key: adam_w
config:
lr: 0.0001
betas: [0.9, 0.95]
eps: 1e-8
weight_decay: 1e-1
wrapped_model:
instance_key: wrapped_model
pass_type: BY_REFERENCE
@@ -270,7 +282,7 @@ evaluation_subscriber:
config:
local_rank: ${settings.cuda_env.local_rank}
project: modalities
mode: OFFLINE
mode: ONLINE
experiment_id: ${settings.experiment_id}
directory: "."
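
For orientation, a minimal sketch of what the scheduler and optimizer blocks above resolve to, assuming the component factory forwards the config fields directly as keyword arguments to PyTorch; the linear layer is only a stand-in for the wrapped model:

import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

model = nn.Linear(8, 8)  # stand-in for the wrapped model
optimizer = AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=1e-1)
scheduler = OneCycleLR(
    optimizer,
    max_lr=6e-4,
    div_factor=10,
    final_div_factor=1,
    total_steps=4,
    pct_start=0.01,
    anneal_strategy="cos",
)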

1 change: 1 addition & 0 deletions src/modalities/__main__.py
@@ -272,6 +272,7 @@ def run(self):
checkpointing=components.checkpointing,
model=wrapped_model,
optimizer=components.optimizer,
scheduler=components.scheduler,
)
print("done")

70 changes: 68 additions & 2 deletions src/modalities/config/config.py
@@ -1,13 +1,14 @@
import os
from pathlib import Path
from typing import Annotated, Any, Dict, List, Optional
from typing import Annotated, Any, Dict, List, Optional, Tuple

import torch.nn as nn
from omegaconf import OmegaConf
from pydantic import BaseModel, Field, FilePath, GetCoreSchemaHandler, PositiveInt, field_validator
from pydantic import BaseModel, Field, FilePath, GetCoreSchemaHandler, PositiveInt, field_validator, model_validator
from pydantic_core import core_schema
from torch.distributed.fsdp import ShardingStrategy
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Sampler
from torch.utils.data.dataset import Dataset
from transformers import GPT2TokenizerFast
@@ -59,6 +60,7 @@ def __get_pydantic_core_schema__(
PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)]
PydanticLLMDataLoaderIFType = Annotated[LLMDataLoader, PydanticThirdPartyTypeIF(LLMDataLoader)]
PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)]
PydanticLRSchedulerIFType = Annotated[LRScheduler, PydanticThirdPartyTypeIF(LRScheduler)]
PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)]
PydanticMessageSubscriberIFType = Annotated[MessageSubscriberIF, PydanticThirdPartyTypeIF(MessageSubscriberIF)]

@@ -132,9 +134,72 @@ class CheckpointingConfig(BaseModel):
checkpointing_execution: PydanticCheckpointingExecutionIFType


class AdamOptimizerConfig(BaseModel):
lr: float
wrapped_model: PydanticPytorchModuleType
betas: Tuple[float, float]
eps: float
weight_decay: float


class AdamWOptimizerConfig(BaseModel):
lr: float
wrapped_model: PydanticPytorchModuleType
betas: Tuple[float, float]
eps: float
weight_decay: float


class StepLRSchedulerConfig(BaseModel):
optimizer: PydanticOptimizerIFType
step_size: Annotated[int, Field(strict=True, gt=0)]
gamma: Annotated[float, Field(strict=True, ge=0.0)]
last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
verbose: bool = False


class OneCycleLRSchedulerConfig(BaseModel):
optimizer: PydanticOptimizerIFType
max_lr: Annotated[float, Field(strict=True, gt=0.0)] | List[Annotated[float, Field(strict=True, gt=0.0)]]
total_steps: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
epochs: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
steps_per_epoch: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)]
anneal_strategy: str
cycle_momentum: bool = True
base_momentum: Annotated[float, Field(strict=True, gt=0)] | List[
Annotated[float, Field(strict=True, gt=0.0)]
] = 0.85
max_momentum: Annotated[float, Field(strict=True, gt=0.0)] | List[
Annotated[float, Field(strict=True, gt=0.0)]
] = 0.95
div_factor: Annotated[float, Field(strict=True, gt=0.0)]
final_div_factor: Annotated[float, Field(strict=True, gt=0.0)]
three_phase: bool = False
last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
verbose: bool = False

@model_validator(mode="after")
def check_total_steps_and_epochs(self) -> "OneCycleLRSchedulerConfig":
if self.total_steps is None and (self.epochs is None or self.steps_per_epoch is None):
raise ValueError("Please define total_steps or (epochs and steps_per_epoch).")
return self


class ConstantLRSchedulerConfig(BaseModel):
optimizer: PydanticOptimizerIFType
factor: Annotated[float, Field(strict=True, ge=0.0, le=1.0)]
total_iters: Annotated[int, Field(strict=True, gt=0)]
last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
verbose: bool = False


class CosineAnnealingLRSchedulerConfig(BaseModel):
optimizer: PydanticOptimizerIFType
t_max: Annotated[int, Field(strict=True, gt=0)]
eta_min: Annotated[float, Field(strict=True, ge=0.0)]
last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
verbose: bool = False


class CheckpointedOptimizerConfig(BaseModel):
@@ -305,6 +370,7 @@ class Paths(BaseModel):
class ComponentsModel(BaseModel):
wrapped_model: PydanticPytorchModuleType
optimizer: PydanticOptimizerIFType
scheduler: PydanticLRSchedulerIFType
loss_fn: PydanticLossIFType
train_dataloader: PydanticLLMDataLoaderIFType
eval_dataloaders: List[PydanticLLMDataLoaderIFType]
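
To illustrate the model_validator added to OneCycleLRSchedulerConfig, here is a stripped-down demo class (not the repository class, which also carries the optimizer reference and LR fields) showing how the either/or constraint behaves under pydantic v2:

from typing import Optional

from pydantic import BaseModel, model_validator


class OneCycleStepsDemo(BaseModel):
    total_steps: Optional[int] = None
    epochs: Optional[int] = None
    steps_per_epoch: Optional[int] = None

    @model_validator(mode="after")
    def check_total_steps_and_epochs(self) -> "OneCycleStepsDemo":
        # Same rule as above: either total_steps or both epochs and steps_per_epoch.
        if self.total_steps is None and (self.epochs is None or self.steps_per_epoch is None):
            raise ValueError("Please define total_steps or (epochs and steps_per_epoch).")
        return self


OneCycleStepsDemo(total_steps=4)                  # valid
OneCycleStepsDemo(epochs=2, steps_per_epoch=100)  # valid
# OneCycleStepsDemo(epochs=2)                     # raises ValidationError
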
3 changes: 3 additions & 0 deletions src/modalities/gym.py
@@ -2,6 +2,7 @@
from typing import List

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from modalities.checkpointing.checkpointing import Checkpointing
from modalities.dataloader.dataloader import LLMDataLoader
@@ -22,6 +23,7 @@ def run(
self,
model: NNModel,
optimizer: Optimizer,
scheduler: LRScheduler,
callback_interval_in_batches: int,
train_data_loader: LLMDataLoader,
evaluation_data_loaders: List[LLMDataLoader],
@@ -42,6 +44,7 @@ def run(
train_loader=train_data_loader,
loss_fun=self.loss_fun,
optimizer=optimizer,
scheduler=scheduler,
callback_interval_in_batches=callback_interval_in_batches,
epoch_done_callback=partial( # TODO rename to something more meaningful
self._run_evaluation_and_checkpointing,
20 changes: 16 additions & 4 deletions src/modalities/optimizers/optimizer_factory.py
@@ -1,20 +1,32 @@
from typing import Tuple

import torch.nn as nn
from torch.optim import AdamW, Optimizer
from torch.optim import Adam, AdamW, Optimizer

from modalities.checkpointing.checkpointing import Checkpointing


class OptimizerFactory:
@staticmethod
def get_adam_w(lr: float, wrapped_model: nn.Module):
def get_adam(
lr: float, betas: Tuple[float, float], eps: float, weight_decay: float, wrapped_model: nn.Module
) -> Optimizer:
model_parameters = wrapped_model.parameters()
optimizer = Adam(params=model_parameters, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
return optimizer

@staticmethod
def get_adam_w(
lr: float, betas: Tuple[float, float], eps: float, weight_decay: float, wrapped_model: nn.Module
) -> Optimizer:
model_parameters = wrapped_model.parameters()
optimizer = AdamW(params=model_parameters, lr=lr)
optimizer = AdamW(params=model_parameters, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
return optimizer

@staticmethod
def get_checkpointed_optimizer(
checkpointing: Checkpointing, checkpoint_path, wrapped_model: nn.Module, optimizer: Optimizer
):
) -> Optimizer:
wrapped_optimizer = checkpointing.load_optimizer_checkpoint(
file_path=checkpoint_path, optimizer=optimizer, wrapped_model=wrapped_model
)
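
A quick illustration (not from the diff) of what the extended factory signatures amount to: both methods now forward betas, eps and weight_decay to PyTorch instead of relying on defaults, with AdamW applying decoupled weight decay and Adam coupling it with the gradient:

import torch.nn as nn
from torch.optim import Adam, AdamW

model = nn.Linear(4, 4)
adam = Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=1e-1)
adam_w = AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=1e-1)
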
16 changes: 13 additions & 3 deletions src/modalities/registry/components.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Type

import torch
import torch.nn as nn
from pydantic import BaseModel
from torch.utils.data import BatchSampler, DistributedSampler
@@ -13,12 +14,15 @@
SaveKMostRecentCheckpointsStrategy,
)
from modalities.config.config import (
AdamOptimizerConfig,
AdamWOptimizerConfig,
BatchSamplerConfig,
CheckpointedModelConfig,
CheckpointedOptimizerConfig,
CheckpointingConfig,
CLMCrossEntropyLossConfig,
ConstantLRSchedulerConfig,
CosineAnnealingLRSchedulerConfig,
DistributedSamplerConfig,
DummyProgressSubscriberConfig,
DummyResultSubscriberConfig,
@@ -28,13 +32,15 @@
GPT2TokenizerFastConfig,
LLMDataLoaderConfig,
MemMapDatasetConfig,
OneCycleLRSchedulerConfig,
OpenGPTXMMapDatasetConfig,
PackedMemMapDatasetContinuousConfig,
PackedMemMapDatasetMegatronConfig,
RichProgressSubscriberConfig,
RichResultSubscriberConfig,
SaveEveryKStepsCheckpointingStrategyConfig,
SaveKMostRecentCheckpointsStrategyConfig,
StepLRSchedulerConfig,
WandBEvaluationResultSubscriberConfig,
)
from modalities.dataloader.dataloader_factory import DataloaderFactory
@@ -74,14 +80,18 @@ class ComponentEntity:
# losses
ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig),
# optimizers
ComponentEntity("optimizer", "adam", OptimizerFactory.get_adam, AdamOptimizerConfig),
ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig),
ComponentEntity(
"optimizer", "checkpointed", OptimizerFactory.get_checkpointed_optimizer, CheckpointedOptimizerConfig
),
# schedulers
# ComponentEntity("scheduler", "step_lr", torch.optim.lr_scheduler.StepLR, None), # TODO
# ComponentEntity("scheduler", "constant_lr", torch.optim.lr_scheduler.ConstantLR, None), # TODO
# ComponentEntity("scheduler", "onecycle_lr", torch.optim.lr_scheduler.OneCycleLR, None), # TODO
ComponentEntity("scheduler", "step_lr", torch.optim.lr_scheduler.StepLR, StepLRSchedulerConfig),
ComponentEntity("scheduler", "constant_lr", torch.optim.lr_scheduler.ConstantLR, ConstantLRSchedulerConfig),
ComponentEntity("scheduler", "onecycle_lr", torch.optim.lr_scheduler.OneCycleLR, OneCycleLRSchedulerConfig),
ComponentEntity(
"scheduler", "cosine_annealing_lr", torch.optim.lr_scheduler.CosineAnnealingLR, CosineAnnealingLRSchedulerConfig
),
# tokenizers
ComponentEntity("tokenizer", "gpt2_tokenizer_fast", GPT2TokenizerFast, GPT2TokenizerFastConfig),
# ComponentEntity("tokenizer", "llama_tokenizer_fast", GPT2TokenizerFast, None), # TODO
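
The registrations above can be read as a lookup table from (component_key, variant_key) to a constructor plus its pydantic config class; the following is only a sketch of that idea with a hypothetical SCHEDULER_VARIANTS dict, not the Modalities registry API:

import torch
from torch.optim.lr_scheduler import ConstantLR, CosineAnnealingLR, OneCycleLR, StepLR

SCHEDULER_VARIANTS = {
    ("scheduler", "step_lr"): StepLR,
    ("scheduler", "constant_lr"): ConstantLR,
    ("scheduler", "onecycle_lr"): OneCycleLR,
    ("scheduler", "cosine_annealing_lr"): CosineAnnealingLR,
}

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Resolve a variant and build it from (validated) config kwargs.
scheduler = SCHEDULER_VARIANTS[("scheduler", "step_lr")](optimizer, step_size=1, gamma=0.1)
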
7 changes: 6 additions & 1 deletion src/modalities/trainer.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed as dist
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from modalities.batch import DatasetBatch, EvaluationResultBatch
from modalities.dataloader.dataloader import LLMDataLoader
@@ -38,6 +39,7 @@ def _train_batch(
batch: DatasetBatch,
model: NNModel,
optimizer: Optimizer,
scheduler: LRScheduler,
loss_fun: Loss,
batch_id: int,
data_loader: LLMDataLoader,
@@ -48,14 +50,16 @@

if (batch_id + 1) % self.gradient_acc_steps == 0 or (batch_id + 1) == len(data_loader):
optimizer.step()
scheduler.step()
optimizer.zero_grad()
return loss

def train(
self,
model: NNModel,
train_loader: LLMDataLoader,
optimizer,
optimizer: Optimizer,
scheduler: LRScheduler,
loss_fun: Loss,
callback_interval_in_batches: int,
# TODO: remove
@@ -82,6 +86,7 @@ def train(
batch=batch,
model=model,
optimizer=optimizer,
scheduler=scheduler,
loss_fun=loss_fun,
batch_id=batch_id,
data_loader=train_loader,
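
The behavioral change in _train_batch is that the scheduler is stepped together with the optimizer, i.e. once per effective batch under gradient accumulation rather than once per micro-batch. A self-contained sketch with toy data (not the Trainer class itself):

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(4, 1)
optimizer = AdamW(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
gradient_acc_steps = 2
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(6)]

for batch_id, (x, y) in enumerate(data):
    loss = F.mse_loss(model(x), y)
    loss.backward()
    if (batch_id + 1) % gradient_acc_steps == 0 or (batch_id + 1) == len(data):
        optimizer.step()
        scheduler.step()  # LR update follows the optimizer update
        optimizer.zero_grad()
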
10 changes: 2 additions & 8 deletions src/modalities/util.py
@@ -12,12 +12,14 @@
from modalities.exceptions import TimeRecorderStateError
from modalities.running_env.fsdp.reducer import Reducer


def parse_enum_by_name(name: str, enum_type: Type[Enum]) -> Enum:
try:
return enum_type[name]
except KeyError:
raise ValidationError(f"Invalid {enum_type} member name: {name}")


def get_callback_interval_in_batches_per_rank(
callback_interval_in_samples: int, local_train_micro_batch_size: int, world_size: int, gradient_acc_steps: int
):
@@ -36,14 +38,6 @@ def get_callback_interval_in_batches_per_rank(
return num_local_train_micro_batches_ret


def parse_enum_by_name(name: str, enum_type: Type[Enum]) -> Enum:
try:
val = enum_type[name]
return val
except KeyError:
raise ValidationError(f"Invalid {enum_type} member name: {name}")


def get_date_of_run():
"""create date and time for file save uniqueness
example: 2022-05-07__14-31-22'
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -8,6 +8,7 @@
import pytest
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data.sampler import BatchSampler, SequentialSampler
from transformers import GPT2TokenizerFast

@@ -102,6 +103,11 @@ def optimizer_mock():
return MagicMock(spec=Optimizer)


@pytest.fixture(scope="function")
def scheduler_mock():
return MagicMock(spec=LRScheduler)


@pytest.fixture(scope="function")
def loss_mock():
return MagicMock(spec=Loss, return_value=torch.rand(1, requires_grad=True))
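
The new fixture mirrors optimizer_mock: a MagicMock spec'd against LRScheduler lets tests assert that the trainer stepped the scheduler without constructing a real optimizer/scheduler pair. A hypothetical usage sketch (not from the test suite):

from unittest.mock import MagicMock

from torch.optim.lr_scheduler import LRScheduler

scheduler_mock = MagicMock(spec=LRScheduler)
scheduler_mock.step()
scheduler_mock.step.assert_called_once()
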