diff --git a/config_files/config_example_hf_meditron_7B_instruction.yaml b/config_files/config_example_hf_meditron_7B_instruction.yaml
index 590525dc..e37305a6 100644
--- a/config_files/config_example_hf_meditron_7B_instruction.yaml
+++ b/config_files/config_example_hf_meditron_7B_instruction.yaml
@@ -12,6 +12,8 @@ settings:
     gradient_acc_steps: 1
     local_train_micro_batch_size: 1
     sequence_length: 4096
+    gradient_clipping:
+      mode: none
   cuda_env:
     local_rank: ${cuda_env:LOCAL_RANK}
     global_rank: ${cuda_env:RANK}
diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml
index 1175a731..9ee4743a 100644
--- a/config_files/config_example_mem_map_dataset.yaml
+++ b/config_files/config_example_mem_map_dataset.yaml
@@ -12,6 +12,8 @@ settings:
     gradient_acc_steps: 1
     local_train_micro_batch_size: 16
     sequence_length: 4096
+    gradient_clipping:
+      mode: none
   cuda_env:
     local_rank: ${cuda_env:LOCAL_RANK}
     global_rank: ${cuda_env:RANK}
diff --git a/config_files/config_lorem_ipsum.yaml b/config_files/config_lorem_ipsum.yaml
index 7bcaa1fe..b9332f03 100644
--- a/config_files/config_lorem_ipsum.yaml
+++ b/config_files/config_lorem_ipsum.yaml
@@ -11,6 +11,9 @@ settings:
     gradient_acc_steps: 1
     local_train_micro_batch_size: 3
     sequence_length: 256
+    gradient_clipping:
+      mode: p2_norm
+      threshold: 1.0
   cuda_env:
     local_rank: ${cuda_env:LOCAL_RANK}
     global_rank: ${cuda_env:RANK}
diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py
index 1b320d5c..bcd4e630 100644
--- a/src/modalities/__main__.py
+++ b/src/modalities/__main__.py
@@ -28,6 +28,7 @@
 from modalities.trainer import Trainer
 from modalities.util import compute_number_of_trainable_parameters, get_callback_interval_in_batches_per_rank
 from modalities.utils.generate_text import main as generate_text_main
+from modalities.utils.gradient_clipping import build_gradient_clipper


 @click.group()
@@ -236,6 +237,10 @@ def run(self):
             batch_progress_publisher=batch_processed_publisher,
             evaluation_result_publisher=evaluation_result_publisher,
             gradient_acc_steps=components.settings.training.gradient_acc_steps,
+            gradient_clipper=build_gradient_clipper(
+                gradient_clipping_mode=components.settings.training.gradient_clipping.mode,
+                gradient_clipping_threshold=components.settings.training.gradient_clipping.threshold,
+            ),
         )

         # Evaluator
diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 4765fb56..593e590d 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -85,6 +85,16 @@ class WandbMode(LookupEnum):
     DISABLED = "DISABLED"


+class GradientClippingMode(LookupEnum):
+    NONE = "NONE"  # Do not apply gradient clipping.
+    VALUE = "value"  # Clip all gradient values independently.
+    # For norm based clipping modes, the norm is computed over
+    # all gradients together, as if they were concatenated
+    # into a single vector.
+    P2_NORM = "p2_norm"  # Euclidean norm based clipping.
+    MAX_NORM = "max_norm"  # Maximum norm based clipping.
+
+
 class ReferenceConfig(BaseModel):
     instance_key: str
     pass_type: PassType
@@ -349,6 +359,18 @@ class CudaEnv(BaseModel):

 class Settings(BaseModel):
     class Training(BaseModel):
+        class GradientClipping(BaseModel):
+            mode: GradientClippingMode
+            threshold: Optional[Annotated[float, Field(strict=True, gt=0.0)]] = None
+
+            @model_validator(mode="after")
+            def check_mode_none_iff_threshold_none(self) -> BaseModel:
+                if self.mode == GradientClippingMode.NONE and self.threshold is not None:
+                    raise ValueError("If gradient clipping is deactivated, no threshold should be set.")
+                if self.mode != GradientClippingMode.NONE and self.threshold is None:
+                    raise ValueError("A threshold value is required when gradient clipping is used.")
+                return self
+
         callback_interval_in_samples: Annotated[int, Field(strict=True, ge=1)]
         global_num_training_samples: Annotated[int, Field(strict=True, ge=1)]
         global_num_seen_samples: Annotated[int, Field(strict=True, ge=0)]
@@ -356,6 +378,7 @@ class Training(BaseModel):
         gradient_acc_steps: Annotated[int, Field(strict=True, ge=1)]
         local_train_micro_batch_size: Annotated[int, Field(strict=True, ge=1)]
         sequence_length: Annotated[int, Field(strict=True, ge=1)]
+        gradient_clipping: GradientClipping

     class Paths(BaseModel):
         checkpointing_path: Path
diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py
index 9a719d23..d4fc9679 100644
--- a/src/modalities/trainer.py
+++ b/src/modalities/trainer.py
@@ -28,11 +28,13 @@ def __init__(
         batch_progress_publisher: MessagePublisher[BatchProgressUpdate],
         evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
         gradient_acc_steps: int,
+        gradient_clipper: Callable[[NNModel], None],
     ) -> None:
         self.local_rank = local_rank
         self.batch_progress_publisher = batch_progress_publisher
         self.evaluation_result_publisher = evaluation_result_publisher
         self.gradient_acc_steps = gradient_acc_steps
+        self.gradient_clipper = gradient_clipper

     def _train_batch(
         self,
@@ -47,6 +49,7 @@ def _train_batch(
         result_batch = model_predict_batch(model=model, batch=batch)
         loss = loss_fun(result_batch) / self.gradient_acc_steps
         loss.backward()
+        self.gradient_clipper(model)

         if (batch_id + 1) % self.gradient_acc_steps == 0 or (batch_id + 1) == len(data_loader):
             optimizer.step()
diff --git a/src/modalities/utils/gradient_clipping.py b/src/modalities/utils/gradient_clipping.py
new file mode 100644
index 00000000..39a65a0d
--- /dev/null
+++ b/src/modalities/utils/gradient_clipping.py
@@ -0,0 +1,34 @@
+from typing import Callable, Optional
+
+from torch.nn.utils import clip_grad_norm_, clip_grad_value_
+
+from modalities.config.config import GradientClippingMode
+from modalities.models.model import NNModel
+
+
+def build_gradient_clipper(
+    gradient_clipping_mode: GradientClippingMode, gradient_clipping_threshold: Optional[float]
+) -> Callable[[NNModel], None]:
+    """Returns a function that applies gradient clipping to a given model (in place).
+
+    :param gradient_clipping_mode: Selection between the different norm based clipping modes,
+        value based clipping and no clipping
+    :type gradient_clipping_mode: GradientClippingMode
+    :param gradient_clipping_threshold: Value at which gradients will be clipped.
+    :type gradient_clipping_threshold: Optional[float]
+    :return: A function taking a model as input and producing no output.
+    :rtype: Callable[[NNModel], None]
+    """
+    if (gradient_clipping_threshold is None) != (gradient_clipping_mode == GradientClippingMode.NONE):
+        raise ValueError(
+            "Either gradient clipping is deactivated and no threshold given or activated and a threshold set."
+        )
+    if gradient_clipping_mode == GradientClippingMode.P2_NORM:
+        # Always return None to satisfy the Callable[[NNModel], None] interface.
+        return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, 2), None)[-1]
+    if gradient_clipping_mode == GradientClippingMode.MAX_NORM:
+        # Always return None to satisfy the Callable[[NNModel], None] interface.
+        return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, "inf"), None)[-1]
+    if gradient_clipping_mode == GradientClippingMode.VALUE:
+        return lambda model: clip_grad_value_(model.parameters(), gradient_clipping_threshold)
+    return lambda model: None
diff --git a/tests/conftest.py b/tests/conftest.py
index 7f52a76a..b8459817 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -130,6 +130,7 @@ def trainer(progress_publisher_mock):
         batch_progress_publisher=progress_publisher_mock,
         evaluation_result_publisher=progress_publisher_mock,
         gradient_acc_steps=1,
+        gradient_clipper=lambda model: None,
     )


diff --git a/tests/utils/test_gradient_clipping.py b/tests/utils/test_gradient_clipping.py
new file mode 100644
index 00000000..52f7f36d
--- /dev/null
+++ b/tests/utils/test_gradient_clipping.py
@@ -0,0 +1,49 @@
+from typing import Dict, Optional
+
+import pytest
+import torch
+import torch.nn as nn
+
+from modalities.config.config import GradientClippingMode
+from modalities.models.model import NNModel
+from modalities.utils.gradient_clipping import build_gradient_clipper
+
+
+@pytest.mark.parametrize(
+    "gradient_clipping_mode", [mode for mode in GradientClippingMode if mode != GradientClippingMode.NONE]
+)
+def test_clipping_gradients_makes_them_smaller(gradient_clipping_mode: GradientClippingMode):
+    grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(gradient_clipping_mode)
+    assert grad_sum1 > grad_sum2
+
+
+def test_gradient_clipping_mode_none_does_not_change_gradients():
+    grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(GradientClippingMode.NONE, threshold=None)
+    assert grad_sum1 == grad_sum2
+
+
+class TestModel(NNModel):
+    def __init__(self):
+        super().__init__()
+        self._weights = nn.Linear(2, 3)
+        self._weights.weight = nn.Parameter(torch.ones_like(self._weights.weight))
+
+    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        output = self._weights(**inputs)
+        return {"output": output}
+
+    def get_grad_sum(self) -> float:
+        return self._weights.weight.grad.sum().item()
+
+
+def _run_gradient_clipping_experiment(gradient_clipping_mode: GradientClippingMode, threshold: Optional[float] = 0.001):
+    model = TestModel()
+    inputs = {"input": torch.ones(2, 2)}
+    output: torch.Tensor = model(inputs)["output"]
+    loss = output.sum()
+    loss.backward()
+    grad_sum1 = model.get_grad_sum()
+    clipper = build_gradient_clipper(gradient_clipping_mode, threshold)
+    clipper(model)
+    grad_sum2 = model.get_grad_sum()
+    return grad_sum1, grad_sum2
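
Usage sketch (illustrative only, not part of the patch): how the clipper built by build_gradient_clipper is applied between loss.backward() and optimizer.step(), mirroring the wiring in __main__.py and Trainer._train_batch. A plain nn.Linear stands in for an NNModel subclass, and the mode/threshold values are made up.

import torch
import torch.nn as nn

from modalities.config.config import GradientClippingMode
from modalities.utils.gradient_clipping import build_gradient_clipper

# Toy stand-in for an NNModel; the returned clipper only needs model.parameters().
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Corresponds to the call in __main__.py, with config values inlined.
gradient_clipper = build_gradient_clipper(
    gradient_clipping_mode=GradientClippingMode.P2_NORM,
    gradient_clipping_threshold=1.0,
)

loss = model(torch.ones(8, 4)).sum()
loss.backward()
gradient_clipper(model)  # clips gradients in place, as in Trainer._train_batch
optimizer.step()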