From 59cc7fca066e12210d2e85701cb34f0a63e48196 Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 18 Mar 2024 11:12:06 +0100 Subject: [PATCH 1/7] feat(trainer): Added gradient clipping to training. --- ...ig_example_hf_meditron_7B_instruction.yaml | 2 + .../config_example_mem_map_dataset.yaml | 2 + config_files/config_lorem_ipsum.yaml | 2 + src/modalities/__main__.py | 5 ++ src/modalities/config/config.py | 9 ++++ src/modalities/trainer.py | 3 ++ src/modalities/utils/gradient_clipping.py | 18 +++++++ tests/utils/test_gradient_clipping.py | 49 +++++++++++++++++++ 8 files changed, 90 insertions(+) create mode 100644 src/modalities/utils/gradient_clipping.py create mode 100644 tests/utils/test_gradient_clipping.py diff --git a/config_files/config_example_hf_meditron_7B_instruction.yaml b/config_files/config_example_hf_meditron_7B_instruction.yaml index 590525dc..6bafbe3d 100644 --- a/config_files/config_example_hf_meditron_7B_instruction.yaml +++ b/config_files/config_example_hf_meditron_7B_instruction.yaml @@ -12,6 +12,8 @@ settings: gradient_acc_steps: 1 local_train_micro_batch_size: 1 sequence_length: 4096 + gradient_clipping_mode: none + gradient_clipping_threshold: 1.0 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml index 1175a731..de236646 100644 --- a/config_files/config_example_mem_map_dataset.yaml +++ b/config_files/config_example_mem_map_dataset.yaml @@ -12,6 +12,8 @@ settings: gradient_acc_steps: 1 local_train_micro_batch_size: 16 sequence_length: 4096 + gradient_clipping_mode: none + gradient_clipping_threshold: 1.0 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} diff --git a/config_files/config_lorem_ipsum.yaml b/config_files/config_lorem_ipsum.yaml index 94e485dc..75431f27 100644 --- a/config_files/config_lorem_ipsum.yaml +++ b/config_files/config_lorem_ipsum.yaml @@ -11,6 +11,8 @@ settings: gradient_acc_steps: 1 local_train_micro_batch_size: 3 sequence_length: 256 + gradient_clipping_mode: p2_norm + gradient_clipping_threshold: 1.0 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 1b712250..da91870f 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -28,6 +28,7 @@ from modalities.trainer import Trainer from modalities.util import compute_number_of_trainable_parameters, get_callback_interval_in_batches_per_rank from modalities.utils.generate_text import main as generate_text_main +from modalities.utils.gradient_clipping import build_gradient_clipper @click.group() @@ -236,6 +237,10 @@ def run(self): batch_progress_publisher=batch_processed_publisher, evaluation_result_publisher=evaluation_result_publisher, gradient_acc_steps=components.settings.training.gradient_acc_steps, + gradient_clipper=build_gradient_clipper( + gradient_clipping_mode=components.settings.training.gradient_clipping_mode, + gradient_clipping_threshold=components.settings.training.gradient_clipping_threshold, + ), ) # Evaluator diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 55d4b9af..87a96be5 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -83,6 +83,13 @@ class WandbMode(LookupEnum): DISABLED = "DISABLED" +class GradientClippingMode(LookupEnum): + NONE = "NONE" + VALUE = "value" + P2_NORM = "p2_norm" + MAX_NORM = "max_norm" + + class 
ReferenceConfig(BaseModel): instance_key: str pass_type: PassType @@ -291,6 +298,8 @@ class Training(BaseModel): gradient_acc_steps: Annotated[int, Field(strict=True, ge=1)] local_train_micro_batch_size: Annotated[int, Field(strict=True, ge=1)] sequence_length: Annotated[int, Field(strict=True, ge=1)] + gradient_clipping_mode: GradientClippingMode + gradient_clipping_threshold: Annotated[float, Field(strict=True, gt=0.0)] = 1.0 class Paths(BaseModel): checkpointing_path: Path diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 6994af98..a1055a2d 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -27,11 +27,13 @@ def __init__( batch_progress_publisher: MessagePublisher[BatchProgressUpdate], evaluation_result_publisher: MessagePublisher[EvaluationResultBatch], gradient_acc_steps: int, + gradient_clipper: Callable[[NNModel], None], ) -> None: self.local_rank = local_rank self.batch_progress_publisher = batch_progress_publisher self.evaluation_result_publisher = evaluation_result_publisher self.gradient_acc_steps = gradient_acc_steps + self.gradient_clipper = gradient_clipper def _train_batch( self, @@ -45,6 +47,7 @@ def _train_batch( result_batch = model_predict_batch(model=model, batch=batch) loss = loss_fun(result_batch) / self.gradient_acc_steps loss.backward() + self.gradient_clipper(model) if (batch_id + 1) % self.gradient_acc_steps == 0 or (batch_id + 1) == len(data_loader): optimizer.step() diff --git a/src/modalities/utils/gradient_clipping.py b/src/modalities/utils/gradient_clipping.py new file mode 100644 index 00000000..30bd6c6f --- /dev/null +++ b/src/modalities/utils/gradient_clipping.py @@ -0,0 +1,18 @@ +from typing import Callable + +from torch.nn.utils import clip_grad_norm_, clip_grad_value_ + +from modalities.config.config import GradientClippingMode +from modalities.models.model import NNModel + + +def build_gradient_clipper( + gradient_clipping_mode: GradientClippingMode, gradient_clipping_threshold: float +) -> Callable[[NNModel], None]: + if gradient_clipping_mode == GradientClippingMode.P2_NORM: + return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, 2), None)[-1] + if gradient_clipping_mode == GradientClippingMode.MAX_NORM: + return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, "inf"), None)[-1] + if gradient_clipping_mode == GradientClippingMode.VALUE: + return lambda model: clip_grad_value_(model.parameters(), gradient_clipping_threshold) + return lambda model: None diff --git a/tests/utils/test_gradient_clipping.py b/tests/utils/test_gradient_clipping.py new file mode 100644 index 00000000..60635d40 --- /dev/null +++ b/tests/utils/test_gradient_clipping.py @@ -0,0 +1,49 @@ +from typing import Dict + +import pytest +import torch +import torch.nn as nn + +from modalities.config.config import GradientClippingMode +from modalities.models.model import NNModel +from modalities.utils.gradient_clipping import build_gradient_clipper + + +@pytest.mark.parametrize( + "gradient_clipping_mode", [mode for mode in GradientClippingMode if mode != GradientClippingMode.NONE] +) +def test_clipping_gradients_makes_them_smaller(gradient_clipping_mode: GradientClippingMode): + grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(gradient_clipping_mode) + assert grad_sum1 > grad_sum2 + + +def test_gradient_clipping_mode_none_does_not_change_gradients(): + grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(GradientClippingMode.NONE) + assert grad_sum1 == grad_sum2 + + 
+class TestModel(NNModel): + def __init__(self): + super().__init__() + self._weights = nn.Linear(2, 3) + self._weights.weight = nn.Parameter(torch.ones_like(self._weights.weight)) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + output = self._weights(**inputs) + return {"output": output} + + def get_grad_sum(self) -> float: + return self._weights.weight.grad.sum().item() + + +def _run_gradient_clipping_experiment(gradient_clipping_mode): + model = TestModel() + inputs = {"input": torch.rand(2, 2)} + output: torch.Tensor = model(inputs)["output"] + loss = output.sum() + loss.backward() + grad_sum1 = model.get_grad_sum() + clipper = build_gradient_clipper(gradient_clipping_mode, 0.001) + clipper(model) + grad_sum2 = model.get_grad_sum() + return grad_sum1, grad_sum2 From efb8e551484b695a9534790f52e29ea107ed421b Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 18 Mar 2024 14:17:53 +0100 Subject: [PATCH 2/7] test(trainer): Added missing parameter in trainer fixture. --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index 1ee049ca..f75701c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,6 +124,7 @@ def trainer(progress_publisher_mock): batch_progress_publisher=progress_publisher_mock, evaluation_result_publisher=progress_publisher_mock, gradient_acc_steps=1, + gradient_clipper=lambda model: None, ) From 35bf777bc06efb252ebbf75cdcee576bf69cf63e Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 18 Mar 2024 14:18:43 +0100 Subject: [PATCH 3/7] docs(utilities): Added explanation comments in build_gradient_clipper(). --- src/modalities/utils/gradient_clipping.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/modalities/utils/gradient_clipping.py b/src/modalities/utils/gradient_clipping.py index 30bd6c6f..4302a7f4 100644 --- a/src/modalities/utils/gradient_clipping.py +++ b/src/modalities/utils/gradient_clipping.py @@ -10,8 +10,10 @@ def build_gradient_clipper( gradient_clipping_mode: GradientClippingMode, gradient_clipping_threshold: float ) -> Callable[[NNModel], None]: if gradient_clipping_mode == GradientClippingMode.P2_NORM: + # Always return None to satisfy the Callable[[NNModel], None] interface. return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, 2), None)[-1] if gradient_clipping_mode == GradientClippingMode.MAX_NORM: + # Always return None to satisfy the Callable[[NNModel], None] interface. return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, "inf"), None)[-1] if gradient_clipping_mode == GradientClippingMode.VALUE: return lambda model: clip_grad_value_(model.parameters(), gradient_clipping_threshold) From 101e46064453ade24384771825189ffa2b540f3b Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 18 Mar 2024 14:20:41 +0100 Subject: [PATCH 4/7] feat(config): Made gradient clipping config more hierarchical. 
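Note (illustration only): the nested block is expected to parse roughly as in the standalone sketch below. The GradientClipping model in this note is a stand-in written for the commit message, with mode typed as a plain string instead of GradientClippingMode so the snippet stays self-contained; the field names and the threshold constraint mirror the diff below.

    from typing import Annotated

    from pydantic import BaseModel, Field


    class GradientClipping(BaseModel):
        # Stand-in for Settings.Training.GradientClipping; the real model
        # types `mode` as GradientClippingMode.
        mode: str
        threshold: Annotated[float, Field(strict=True, gt=0.0)] = 1.0


    class Training(BaseModel):
        gradient_clipping: GradientClipping


    # Parsed from the new YAML layout:
    #   gradient_clipping:
    #     mode: p2_norm
    #     threshold: 1.0
    training = Training(gradient_clipping={"mode": "p2_norm", "threshold": 1.0})
    print(training.gradient_clipping.mode, training.gradient_clipping.threshold)
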
--- .../config_example_hf_meditron_7B_instruction.yaml | 5 +++-- config_files/config_example_mem_map_dataset.yaml | 5 +++-- config_files/config_lorem_ipsum.yaml | 5 +++-- src/modalities/__main__.py | 4 ++-- src/modalities/config/config.py | 7 +++++-- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/config_files/config_example_hf_meditron_7B_instruction.yaml b/config_files/config_example_hf_meditron_7B_instruction.yaml index 6bafbe3d..f6b11821 100644 --- a/config_files/config_example_hf_meditron_7B_instruction.yaml +++ b/config_files/config_example_hf_meditron_7B_instruction.yaml @@ -12,8 +12,9 @@ settings: gradient_acc_steps: 1 local_train_micro_batch_size: 1 sequence_length: 4096 - gradient_clipping_mode: none - gradient_clipping_threshold: 1.0 + gradient_clipping: + mode: none + threshold: 1.0 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml index de236646..3e0ce753 100644 --- a/config_files/config_example_mem_map_dataset.yaml +++ b/config_files/config_example_mem_map_dataset.yaml @@ -12,8 +12,9 @@ settings: gradient_acc_steps: 1 local_train_micro_batch_size: 16 sequence_length: 4096 - gradient_clipping_mode: none - gradient_clipping_threshold: 1.0 + gradient_clipping: + mode: none + threshold: 1.0 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} diff --git a/config_files/config_lorem_ipsum.yaml b/config_files/config_lorem_ipsum.yaml index 75431f27..ba2b252c 100644 --- a/config_files/config_lorem_ipsum.yaml +++ b/config_files/config_lorem_ipsum.yaml @@ -11,8 +11,9 @@ settings: gradient_acc_steps: 1 local_train_micro_batch_size: 3 sequence_length: 256 - gradient_clipping_mode: p2_norm - gradient_clipping_threshold: 1.0 + gradient_clipping: + mode: p2_norm + threshold: 1.0 cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index da91870f..ebfdef82 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -238,8 +238,8 @@ def run(self): evaluation_result_publisher=evaluation_result_publisher, gradient_acc_steps=components.settings.training.gradient_acc_steps, gradient_clipper=build_gradient_clipper( - gradient_clipping_mode=components.settings.training.gradient_clipping_mode, - gradient_clipping_threshold=components.settings.training.gradient_clipping_threshold, + gradient_clipping_mode=components.settings.training.gradient_clipping.mode, + gradient_clipping_threshold=components.settings.training.gradient_clipping.threshold, ), ) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 87a96be5..8c719379 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -291,6 +291,10 @@ class CudaEnv(BaseModel): class Settings(BaseModel): class Training(BaseModel): + class GradientClipping(BaseModel): + mode: GradientClippingMode + threshold: Annotated[float, Field(strict=True, gt=0.0)] = 1.0 + callback_interval_in_samples: Annotated[int, Field(strict=True, ge=1)] global_num_training_samples: Annotated[int, Field(strict=True, ge=1)] global_num_seen_samples: Annotated[int, Field(strict=True, ge=0)] @@ -298,8 +302,7 @@ class Training(BaseModel): gradient_acc_steps: Annotated[int, Field(strict=True, ge=1)] local_train_micro_batch_size: Annotated[int, Field(strict=True, ge=1)] sequence_length: Annotated[int, Field(strict=True, ge=1)] - gradient_clipping_mode: 
GradientClippingMode - gradient_clipping_threshold: Annotated[float, Field(strict=True, gt=0.0)] = 1.0 + gradient_clipping: GradientClipping class Paths(BaseModel): checkpointing_path: Path From 3fe94b8a3f832a717dc46e67a3ec60483b28bc68 Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 18 Mar 2024 14:29:07 +0100 Subject: [PATCH 5/7] docs(gradient_clipping): Added some additional documentation. --- src/modalities/config/config.py | 11 +++++++---- src/modalities/utils/gradient_clipping.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 8c719379..b867a8d8 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -84,10 +84,13 @@ class WandbMode(LookupEnum): class GradientClippingMode(LookupEnum): - NONE = "NONE" - VALUE = "value" - P2_NORM = "p2_norm" - MAX_NORM = "max_norm" + NONE = "NONE" # Do not apply gradient clipping. + VALUE = "value" # Clip all gradient values independently. + # For norm based clipping modes, the norm is computed over + # all gradients together, as if they were concatenated + # into a single vector. + P2_NORM = "p2_norm" # Euclidean norm based clipping. + MAX_NORM = "max_norm" # Maximum norm based clipping. class ReferenceConfig(BaseModel): diff --git a/src/modalities/utils/gradient_clipping.py b/src/modalities/utils/gradient_clipping.py index 4302a7f4..0492d0f8 100644 --- a/src/modalities/utils/gradient_clipping.py +++ b/src/modalities/utils/gradient_clipping.py @@ -9,6 +9,16 @@ def build_gradient_clipper( gradient_clipping_mode: GradientClippingMode, gradient_clipping_threshold: float ) -> Callable[[NNModel], None]: + """Returns a function that applies gradient clipping to a given model (in place). + + :param gradient_clipping_mode: Selection between different norm based modes, + value based clipping and no clipping + :type gradient_clipping_mode: GradientClippingMode + :param gradient_clipping_threshold: Value at which will be clipped. + :type gradient_clipping_threshold: float + :return: A function taking a model as input and producing no output. + :rtype: Callable[[NNModel], None] + """ if gradient_clipping_mode == GradientClippingMode.P2_NORM: # Always return None to satisfy the Callable[[NNModel], None] interface. return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, 2), None)[-1] From 4a400e5883d6adb34cbc928499f423e3ddd0a3b1 Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 18 Mar 2024 15:50:27 +0100 Subject: [PATCH 6/7] feat(config): Added model validation for gradient clipping. It checks that a threshold can be set only when the clipping is active. 
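Illustration of the intended accept/reject behaviour (a standalone stand-in model written for this note, not the project class; the validator body mirrors the diff below, with mode reduced to a plain string):

    from typing import Annotated, Optional

    from pydantic import BaseModel, Field, ValidationError, model_validator


    class GradientClipping(BaseModel):
        mode: str  # "none", "value", "p2_norm" or "max_norm" in the real enum
        threshold: Optional[Annotated[float, Field(strict=True, gt=0.0)]] = None

        @model_validator(mode="after")
        def check_mode_none_iff_threshold_none(self) -> "GradientClipping":
            if self.mode == "none" and self.threshold is not None:
                raise ValueError("If gradient clipping is deactivated, no threshold should be set.")
            if self.mode != "none" and self.threshold is None:
                raise ValueError("A threshold value is required when gradient clipping is used.")
            return self


    GradientClipping(mode="p2_norm", threshold=1.0)  # accepted
    GradientClipping(mode="none")                    # accepted, no threshold
    try:
        GradientClipping(mode="none", threshold=1.0)  # rejected by the validator
    except ValidationError as error:
        print(error)
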
--- .../config_example_hf_meditron_7B_instruction.yaml | 2 +- config_files/config_example_mem_map_dataset.yaml | 2 +- src/modalities/config/config.py | 12 ++++++++++-- src/modalities/utils/gradient_clipping.py | 8 ++++++-- tests/utils/test_gradient_clipping.py | 10 +++++----- 5 files changed, 23 insertions(+), 11 deletions(-) diff --git a/config_files/config_example_hf_meditron_7B_instruction.yaml b/config_files/config_example_hf_meditron_7B_instruction.yaml index f6b11821..efdf9ed7 100644 --- a/config_files/config_example_hf_meditron_7B_instruction.yaml +++ b/config_files/config_example_hf_meditron_7B_instruction.yaml @@ -14,7 +14,7 @@ settings: sequence_length: 4096 gradient_clipping: mode: none - threshold: 1.0 + threshold: cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml index 3e0ce753..c63da2ac 100644 --- a/config_files/config_example_mem_map_dataset.yaml +++ b/config_files/config_example_mem_map_dataset.yaml @@ -14,7 +14,7 @@ settings: sequence_length: 4096 gradient_clipping: mode: none - threshold: 1.0 + threshold: cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index b867a8d8..1210446b 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -4,7 +4,7 @@ import torch.nn as nn from omegaconf import OmegaConf -from pydantic import BaseModel, Field, FilePath, GetCoreSchemaHandler, PositiveInt, field_validator +from pydantic import BaseModel, Field, FilePath, GetCoreSchemaHandler, PositiveInt, field_validator, model_validator from pydantic_core import core_schema from torch.distributed.fsdp import ShardingStrategy from torch.optim import Optimizer @@ -296,7 +296,15 @@ class Settings(BaseModel): class Training(BaseModel): class GradientClipping(BaseModel): mode: GradientClippingMode - threshold: Annotated[float, Field(strict=True, gt=0.0)] = 1.0 + threshold: Optional[Annotated[float, Field(strict=True, gt=0.0)]] = None + + @model_validator(mode="after") + def check_mode_none_iff_threshold_none(self) -> BaseModel: + if self.mode == GradientClippingMode.NONE and self.threshold is not None: + raise ValueError("If gradient clipping is deactivated, no threshold should be set.") + if self.mode != GradientClippingMode.NONE and self.threshold is None: + raise ValueError("A threshold value is required when gradient clipping is used.") + return self callback_interval_in_samples: Annotated[int, Field(strict=True, ge=1)] global_num_training_samples: Annotated[int, Field(strict=True, ge=1)] diff --git a/src/modalities/utils/gradient_clipping.py b/src/modalities/utils/gradient_clipping.py index 0492d0f8..39a65a0d 100644 --- a/src/modalities/utils/gradient_clipping.py +++ b/src/modalities/utils/gradient_clipping.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Optional from torch.nn.utils import clip_grad_norm_, clip_grad_value_ @@ -7,7 +7,7 @@ def build_gradient_clipper( - gradient_clipping_mode: GradientClippingMode, gradient_clipping_threshold: float + gradient_clipping_mode: GradientClippingMode, gradient_clipping_threshold: Optional[float] ) -> Callable[[NNModel], None]: """Returns a function that applies gradient clipping to a given model (in place). @@ -19,6 +19,10 @@ def build_gradient_clipper( :return: A function taking a model as input and producing no output. 
:rtype: Callable[[NNModel], None] """ + if (gradient_clipping_threshold is None) != (gradient_clipping_mode == GradientClippingMode.NONE): + raise ValueError( + "Either gradient clipping is deactivated and no threshold given or activated and a threshold set." + ) if gradient_clipping_mode == GradientClippingMode.P2_NORM: # Always return None to satisfy the Callable[[NNModel], None] interface. return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, 2), None)[-1] diff --git a/tests/utils/test_gradient_clipping.py b/tests/utils/test_gradient_clipping.py index 60635d40..52f7f36d 100644 --- a/tests/utils/test_gradient_clipping.py +++ b/tests/utils/test_gradient_clipping.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional import pytest import torch @@ -18,7 +18,7 @@ def test_clipping_gradients_makes_them_smaller(gradient_clipping_mode: GradientC def test_gradient_clipping_mode_none_does_not_change_gradients(): - grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(GradientClippingMode.NONE) + grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(GradientClippingMode.NONE, threshold=None) assert grad_sum1 == grad_sum2 @@ -36,14 +36,14 @@ def get_grad_sum(self) -> float: return self._weights.weight.grad.sum().item() -def _run_gradient_clipping_experiment(gradient_clipping_mode): +def _run_gradient_clipping_experiment(gradient_clipping_mode: GradientClippingMode, threshold: Optional[float] = 0.001): model = TestModel() - inputs = {"input": torch.rand(2, 2)} + inputs = {"input": torch.ones(2, 2)} output: torch.Tensor = model(inputs)["output"] loss = output.sum() loss.backward() grad_sum1 = model.get_grad_sum() - clipper = build_gradient_clipper(gradient_clipping_mode, 0.001) + clipper = build_gradient_clipper(gradient_clipping_mode, threshold) clipper(model) grad_sum2 = model.get_grad_sum() return grad_sum1, grad_sum2 From d2d61aaff08d73e2b9ef5130a238b6f77b62c1bf Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Mon, 18 Mar 2024 16:11:06 +0100 Subject: [PATCH 7/7] fix(config): In gradient clipping configs: Removed threshold field when mode is none. --- config_files/config_example_hf_meditron_7B_instruction.yaml | 1 - config_files/config_example_mem_map_dataset.yaml | 1 - 2 files changed, 2 deletions(-) diff --git a/config_files/config_example_hf_meditron_7B_instruction.yaml b/config_files/config_example_hf_meditron_7B_instruction.yaml index efdf9ed7..e37305a6 100644 --- a/config_files/config_example_hf_meditron_7B_instruction.yaml +++ b/config_files/config_example_hf_meditron_7B_instruction.yaml @@ -14,7 +14,6 @@ settings: sequence_length: 4096 gradient_clipping: mode: none - threshold: cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK} diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml index c63da2ac..9ee4743a 100644 --- a/config_files/config_example_mem_map_dataset.yaml +++ b/config_files/config_example_mem_map_dataset.yaml @@ -14,7 +14,6 @@ settings: sequence_length: 4096 gradient_clipping: mode: none - threshold: cuda_env: local_rank: ${cuda_env:LOCAL_RANK} global_rank: ${cuda_env:RANK}
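
End-to-end usage sketch for the series as a whole. This is not part of the patches: the linear model, optimizer and random batch are stand-ins for the real NNModel and training loop, but the build_gradient_clipper call and the clip-after-backward ordering follow the diffs above.

    import torch
    import torch.nn as nn

    from modalities.config.config import GradientClippingMode
    from modalities.utils.gradient_clipping import build_gradient_clipper

    model = nn.Linear(4, 2)  # stand-in for the real NNModel
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Built once in __main__.py from settings.training.gradient_clipping.
    clipper = build_gradient_clipper(
        gradient_clipping_mode=GradientClippingMode.P2_NORM,
        gradient_clipping_threshold=1.0,
    )

    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    clipper(model)  # clips in place; Trainer._train_batch calls this right after loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With GradientClippingMode.NONE the returned callable is a no-op, and per PATCH 6/7 the threshold must then be omitted from the config.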