Merge pull request #80 from Modalities/79-add-gradient-clipping
Added gradient clipping to training.
Showing 9 changed files with 122 additions and 0 deletions.
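For context, gradient clipping caps the size of parameter gradients between the backward pass and the optimizer step so that one unusually large gradient cannot destabilize training. The new module implements the two standard families; the formulas below are the documented behavior of PyTorch's `clip_grad_norm_` / `clip_grad_value_` (up to a small epsilon in the norm case), with $g$ the full gradient vector and $c$ the configured threshold:

$$
\text{norm clipping: } g \leftarrow g \cdot \min\!\Big(1, \frac{c}{\lVert g \rVert_p}\Big) \quad (p = 2 \text{ or } \infty),
\qquad
\text{value clipping: } g_i \leftarrow \max(-c, \min(c, g_i)).
$$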
`modalities/utils/gradient_clipping.py` (new file, +34 lines):

```python
from typing import Callable, Optional

from torch.nn.utils import clip_grad_norm_, clip_grad_value_

from modalities.config.config import GradientClippingMode
from modalities.models.model import NNModel


def build_gradient_clipper(
    gradient_clipping_mode: GradientClippingMode, gradient_clipping_threshold: Optional[float]
) -> Callable[[NNModel], None]:
    """Returns a function that applies gradient clipping to a given model (in place).

    :param gradient_clipping_mode: Selection between the norm-based modes,
        value-based clipping, and no clipping.
    :type gradient_clipping_mode: GradientClippingMode
    :param gradient_clipping_threshold: Value at which gradients will be clipped.
        Must be None exactly when clipping is deactivated.
    :type gradient_clipping_threshold: Optional[float]
    :return: A function taking a model as input and producing no output.
    :rtype: Callable[[NNModel], None]
    """
    if (gradient_clipping_threshold is None) != (gradient_clipping_mode == GradientClippingMode.NONE):
        raise ValueError(
            "Either gradient clipping is deactivated and no threshold given or activated and a threshold set."
        )
    if gradient_clipping_mode == GradientClippingMode.P2_NORM:
        # clip_grad_norm_ returns the total norm; discard it via the tuple
        # trick to satisfy the Callable[[NNModel], None] interface.
        return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, 2), None)[-1]
    if gradient_clipping_mode == GradientClippingMode.MAX_NORM:
        # Same as above, but using the infinity norm.
        return lambda model: (clip_grad_norm_(model.parameters(), gradient_clipping_threshold, "inf"), None)[-1]
    if gradient_clipping_mode == GradientClippingMode.VALUE:
        # clip_grad_value_ already returns None.
        return lambda model: clip_grad_value_(model.parameters(), gradient_clipping_threshold)
    # GradientClippingMode.NONE: a no-op clipper.
    return lambda model: None
```
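A minimal usage sketch (not part of the commit; the mode and threshold values are illustrative). Since the returned callable only touches `model.parameters()`, a plain `nn.Module` stands in for an `NNModel` here; the clipper runs between `loss.backward()` and `optimizer.step()`:

```python
import torch
import torch.nn as nn

from modalities.config.config import GradientClippingMode
from modalities.utils.gradient_clipping import build_gradient_clipper

# Build the clipper once from the configured mode and threshold.
clipper = build_gradient_clipper(
    gradient_clipping_mode=GradientClippingMode.P2_NORM,
    gradient_clipping_threshold=1.0,  # illustrative value
)

model = nn.Linear(8, 2)  # stand-in for an NNModel
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs, targets = torch.randn(4, 8), torch.randn(4, 2)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(inputs), targets)
loss.backward()
clipper(model)    # clip gradients in place after backward() ...
optimizer.step()  # ... and before the optimizer step
```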
New test file for `build_gradient_clipper` (+49 lines):

```python
from typing import Dict, Optional

import pytest
import torch
import torch.nn as nn

from modalities.config.config import GradientClippingMode
from modalities.models.model import NNModel
from modalities.utils.gradient_clipping import build_gradient_clipper


@pytest.mark.parametrize(
    "gradient_clipping_mode", [mode for mode in GradientClippingMode if mode != GradientClippingMode.NONE]
)
def test_clipping_gradients_makes_them_smaller(gradient_clipping_mode: GradientClippingMode):
    grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(gradient_clipping_mode)
    assert grad_sum1 > grad_sum2


def test_gradient_clipping_mode_none_does_not_change_gradients():
    grad_sum1, grad_sum2 = _run_gradient_clipping_experiment(GradientClippingMode.NONE, threshold=None)
    assert grad_sum1 == grad_sum2


class TestModel(NNModel):
    def __init__(self):
        super().__init__()
        # A single linear layer with all weights set to 1.0 so the
        # gradient sums below are deterministic.
        self._weights = nn.Linear(2, 3)
        self._weights.weight = nn.Parameter(torch.ones_like(self._weights.weight))

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        output = self._weights(**inputs)
        return {"output": output}

    def get_grad_sum(self) -> float:
        return self._weights.weight.grad.sum().item()


def _run_gradient_clipping_experiment(
    gradient_clipping_mode: GradientClippingMode, threshold: Optional[float] = 0.001
):
    model = TestModel()
    inputs = {"input": torch.ones(2, 2)}
    output: torch.Tensor = model(inputs)["output"]
    loss = output.sum()
    loss.backward()
    grad_sum1 = model.get_grad_sum()
    clipper = build_gradient_clipper(gradient_clipping_mode, threshold)
    clipper(model)
    grad_sum2 = model.get_grad_sum()
    return grad_sum1, grad_sum2
```
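To see why `test_clipping_gradients_makes_them_smaller` holds numerically, here is a self-contained sketch in plain PyTorch (no Modalities imports) reproducing the same arithmetic: with all-ones weights and an all-ones `(2, 2)` input batch, every gradient entry is 2.0, so the 3x2 weight gradient sums to 12.0 before clipping.

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

# Same setup as TestModel above, without the NNModel wrapper.
layer = nn.Linear(2, 3)
layer.weight = nn.Parameter(torch.ones_like(layer.weight))

out = layer(torch.ones(2, 2))  # output shape (2, 3)
out.sum().backward()

# d(loss)/dW[i, j] = sum of input[:, j] over the batch = 2.0 per entry.
print(layer.weight.grad.sum().item())  # 12.0

# P2_NORM mode with the test's default threshold of 0.001: the total
# gradient norm is 6.0, so every entry is scaled by ~0.001 / 6.
clip_grad_norm_(layer.parameters(), max_norm=0.001, norm_type=2)
print(layer.weight.grad.sum().item())  # ~0.002, far below 12.0
```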