Merge pull request #80 from Modalities/79-add-gradient-clipping
Added gradient clipping to training.
BlueCrescent committed Mar 18, 2024
2 parents b00e2a6 + d2d61aa commit 842294a
Showing 9 changed files with 122 additions and 0 deletions.
2 changes: 2 additions & 0 deletions config_files/config_example_hf_meditron_7B_instruction.yaml
@@ -12,6 +12,8 @@ settings:
gradient_acc_steps: 1
local_train_micro_batch_size: 1
sequence_length: 4096
gradient_clipping:
mode: none
cuda_env:
local_rank: ${cuda_env:LOCAL_RANK}
global_rank: ${cuda_env:RANK}
2 changes: 2 additions & 0 deletions config_files/config_example_mem_map_dataset.yaml
@@ -12,6 +12,8 @@ settings:
gradient_acc_steps: 1
local_train_micro_batch_size: 16
sequence_length: 4096
gradient_clipping:
mode: none
cuda_env:
local_rank: ${cuda_env:LOCAL_RANK}
global_rank: ${cuda_env:RANK}
3 changes: 3 additions & 0 deletions config_files/config_lorem_ipsum.yaml
@@ -11,6 +11,9 @@ settings:
gradient_acc_steps: 1
local_train_micro_batch_size: 3
sequence_length: 256
gradient_clipping:
mode: p2_norm
threshold: 1.0
cuda_env:
local_rank: ${cuda_env:LOCAL_RANK}
global_rank: ${cuda_env:RANK}
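
For orientation, a minimal sketch (not part of the commit) of what the two configuration variants above mean at the PyTorch level: mode: none leaves gradients untouched, while mode: p2_norm with threshold: 1.0 corresponds to clipping the global Euclidean norm of all gradients via torch.nn.utils.clip_grad_norm_:

import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(4, 4)
model(torch.ones(2, 4)).sum().backward()

# mode: p2_norm, threshold: 1.0 (as in config_lorem_ipsum.yaml) clips the global
# L2 norm of all gradients to at most 1.0; mode: none would skip this call entirely.
clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2)
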
5 changes: 5 additions & 0 deletions src/modalities/__main__.py
@@ -28,6 +28,7 @@
from modalities.trainer import Trainer
from modalities.util import compute_number_of_trainable_parameters, get_callback_interval_in_batches_per_rank
from modalities.utils.generate_text import main as generate_text_main
from modalities.utils.gradient_clipping import build_gradient_clipper


@click.group()
@@ -236,6 +237,10 @@ def run(self):
batch_progress_publisher=batch_processed_publisher,
evaluation_result_publisher=evaluation_result_publisher,
gradient_acc_steps=components.settings.training.gradient_acc_steps,
gradient_clipper=build_gradient_clipper(
gradient_clipping_mode=components.settings.training.gradient_clipping.mode,
gradient_clipping_threshold=components.settings.training.gradient_clipping.threshold,
),
)

# Evaluator
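
A stand-alone sketch of the wiring above, assuming the modalities package from this commit is importable; the SimpleNamespace is a hypothetical stand-in for the parsed components.settings.training.gradient_clipping object:

from types import SimpleNamespace

from modalities.config.config import GradientClippingMode
from modalities.utils.gradient_clipping import build_gradient_clipper

# Hypothetical stand-in for components.settings.training.gradient_clipping.
gradient_clipping = SimpleNamespace(mode=GradientClippingMode.P2_NORM, threshold=1.0)

gradient_clipper = build_gradient_clipper(
    gradient_clipping_mode=gradient_clipping.mode,
    gradient_clipping_threshold=gradient_clipping.threshold,
)
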
23 changes: 23 additions & 0 deletions src/modalities/config/config.py
@@ -85,6 +85,16 @@ class WandbMode(LookupEnum):
DISABLED = "DISABLED"


class GradientClippingMode(LookupEnum):
NONE = "NONE" # Do not apply gradient clipping.
VALUE = "value" # Clip all gradient values independently.
# For norm based clipping modes, the norm is computed over
# all gradients together, as if they were concatenated
# into a single vector.
P2_NORM = "p2_norm" # Euclidean norm based clipping.
MAX_NORM = "max_norm" # Maximum norm based clipping.


class ReferenceConfig(BaseModel):
instance_key: str
pass_type: PassType
@@ -349,13 +359,26 @@ class CudaEnv(BaseModel):

class Settings(BaseModel):
class Training(BaseModel):
class GradientClipping(BaseModel):
mode: GradientClippingMode
threshold: Optional[Annotated[float, Field(strict=True, gt=0.0)]] = None

@model_validator(mode="after")
def check_mode_none_iff_threshold_none(self) -> BaseModel:
if self.mode == GradientClippingMode.NONE and self.threshold is not None:
raise ValueError("If gradient clipping is deactivated, no threshold should be set.")
if self.mode != GradientClippingMode.NONE and self.threshold is None:
raise ValueError("A threshold value is required when gradient clipping is used.")
return self

callback_interval_in_samples: Annotated[int, Field(strict=True, ge=1)]
global_num_training_samples: Annotated[int, Field(strict=True, ge=1)]
global_num_seen_samples: Annotated[int, Field(strict=True, ge=0)]
do_apply_activation_checkpointing: bool
gradient_acc_steps: Annotated[int, Field(strict=True, ge=1)]
local_train_micro_batch_size: Annotated[int, Field(strict=True, ge=1)]
sequence_length: Annotated[int, Field(strict=True, ge=1)]
gradient_clipping: GradientClipping

class Paths(BaseModel):
checkpointing_path: Path
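
To illustrate the mode/threshold consistency check, here is a minimal stand-alone mirror of the nested GradientClipping model (simplified: the strict Field constraint on the threshold is dropped, and the real class is nested inside Settings.Training):

from enum import Enum
from typing import Optional

from pydantic import BaseModel, model_validator


class GradientClippingMode(Enum):
    NONE = "NONE"
    VALUE = "value"
    P2_NORM = "p2_norm"
    MAX_NORM = "max_norm"


class GradientClipping(BaseModel):
    mode: GradientClippingMode
    threshold: Optional[float] = None

    @model_validator(mode="after")
    def check_mode_none_iff_threshold_none(self) -> "GradientClipping":
        if self.mode == GradientClippingMode.NONE and self.threshold is not None:
            raise ValueError("If gradient clipping is deactivated, no threshold should be set.")
        if self.mode != GradientClippingMode.NONE and self.threshold is None:
            raise ValueError("A threshold value is required when gradient clipping is used.")
        return self


GradientClipping(mode=GradientClippingMode.P2_NORM, threshold=1.0)  # valid
GradientClipping(mode=GradientClippingMode.NONE)                    # valid
# GradientClipping(mode=GradientClippingMode.VALUE)                 # would raise: threshold missing
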
3 changes: 3 additions & 0 deletions src/modalities/trainer.py
@@ -28,11 +28,13 @@ def __init__(
batch_progress_publisher: MessagePublisher[BatchProgressUpdate],
evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
gradient_acc_steps: int,
gradient_clipper: Callable[[NNModel], None],
) -> None:
self.local_rank = local_rank
self.batch_progress_publisher = batch_progress_publisher
self.evaluation_result_publisher = evaluation_result_publisher
self.gradient_acc_steps = gradient_acc_steps
self.gradient_clipper = gradient_clipper

def _train_batch(
self,
@@ -47,6 +49,7 @@ def _train_batch(
result_batch = model_predict_batch(model=model, batch=batch)
loss = loss_fun(result_batch) / self.gradient_acc_steps
loss.backward()
self.gradient_clipper(model)

if (batch_id + 1) % self.gradient_acc_steps == 0 or (batch_id + 1) == len(data_loader):
optimizer.step()
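
A condensed, self-contained sketch (not the Trainer itself) of the ordering introduced above: the clipper runs after every backward pass and therefore before each optimizer step on gradient-accumulation boundaries:

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_acc_steps = 2
data_loader = [torch.randn(4, 8) for _ in range(4)]


def gradient_clipper(m):  # stand-in for the callable built by build_gradient_clipper
    torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1.0, norm_type=2)


for batch_id, batch in enumerate(data_loader):
    loss = model(batch).sum() / gradient_acc_steps
    loss.backward()
    gradient_clipper(model)  # clip right after backward, as in Trainer._train_batch
    if (batch_id + 1) % gradient_acc_steps == 0 or (batch_id + 1) == len(data_loader):
        optimizer.step()
        optimizer.zero_grad()
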
34 changes: 34 additions & 0 deletions src/modalities/utils/gradient_clipping.py
@@ -0,0 +1,34 @@
from typing import Callable, Optional

from torch.nn.utils import clip_grad_norm_, clip_grad_value_

from modalities.config.config import GradientClippingMode
from modalities.models.model import NNModel


def build_gradient_clipper(
gradient_clipping_mode: GradientClippingMode, gradient_clipping_threshold: Optional[float]
) -> Callable[[NNModel], None]:
"""Returns a function that applies gradient clipping to a given model (in place).
:param gradient_clipping_mode: Selects between norm-based clipping,
value-based clipping and no clipping.
:type gradient_clipping_mode: GradientClippingMode
:param gradient_clipping_threshold: Threshold at which gradients are clipped;
must be None if and only if the mode is NONE.
:type gradient_clipping_threshold: Optional[float]
:return: A function taking a model as input and producing no output.
:rtype: Callable[[NNModel], None]
"""
if (gradient_clipping_threshold is None) != (gradient_clipping_mode == GradientClippingMode.NONE):
raise ValueError(
"Either gradient clipping is deactivated and no threshold given or activated and a threshold set."
)
if gradient_clipping_mode == GradientClippingMode.P2_NORM:
# Always return None to satisfy the Callable[[NNModel], None] interface.
return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, 2), None)[-1]
if gradient_clipping_mode == GradientClippingMode.MAX_NORM:
# Always return None to satisfy the Callable[[NNModel], None] interface.
return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, "inf"), None)[-1]
if gradient_clipping_mode == GradientClippingMode.VALUE:
return lambda model: clip_grad_value_(model.parameters(), gradient_clipping_threshold)
return lambda model: None
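
A usage sketch of the new helper, assuming the modalities package from this commit is importable; the returned clipper only needs an object exposing .parameters(), so a plain nn.Linear stands in for an NNModel here:

import torch

from modalities.config.config import GradientClippingMode
from modalities.utils.gradient_clipping import build_gradient_clipper

model = torch.nn.Linear(4, 4)
model(torch.ones(2, 4)).sum().backward()

clipper = build_gradient_clipper(
    gradient_clipping_mode=GradientClippingMode.VALUE,
    gradient_clipping_threshold=0.5,
)
clipper(model)  # in place: every gradient entry is clamped to [-0.5, 0.5]
assert all(p.grad.abs().max() <= 0.5 for p in model.parameters())
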
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -130,6 +130,7 @@ def trainer(progress_publisher_mock):
batch_progress_publisher=progress_publisher_mock,
evaluation_result_publisher=progress_publisher_mock,
gradient_acc_steps=1,
gradient_clipper=lambda model: None,
)


49 changes: 49 additions & 0 deletions tests/utils/test_gradient_clipping.py
@@ -0,0 +1,49 @@
from typing import Dict, Optional

import pytest
import torch
import torch.nn as nn

from modalities.config.config import GradientClippingMode
from modalities.models.model import NNModel
from modalities.utils.gradient_clipping import build_gradient_clipper


@pytest.mark.parametrize(
"gradient_clipping_mode", [mode for mode in GradientClippingMode if mode != GradientClippingMode.NONE]
)
def test_clipping_gradients_makes_them_smaller(gradient_clipping_mode: GradientClippingMode):
grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(gradient_clipping_mode)
assert grad_sum1 > grad_sum2


def test_gradient_clipping_mode_none_does_not_change_gradients():
grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(GradientClippingMode.NONE, threshold=None)
assert grad_sum1 == grad_sum2


class TestModel(NNModel):
def __init__(self):
super().__init__()
self._weights = nn.Linear(2, 3)
self._weights.weight = nn.Parameter(torch.ones_like(self._weights.weight))

def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
output = self._weights(**inputs)
return {"output": output}

def get_grad_sum(self) -> float:
return self._weights.weight.grad.sum().item()


def _run_gradient_clipping_experiment(gradient_clipping_mode: GradientClippingMode, threshold: Optional[float] = 0.001):
model = TestModel()
inputs = {"input": torch.ones(2, 2)}
output: torch.Tensor = model(inputs)["output"]
loss = output.sum()
loss.backward()
grad_sum1 = model.get_grad_sum()
clipper = build_gradient_clipper(gradient_clipping_mode, threshold)
clipper(model)
grad_sum2 = model.get_grad_sum()
return grad_sum1, grad_sum2
