File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
tests/tests_pytorch/plugins/precision Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change 14
14
from unittest .mock import Mock
15
15
16
16
import pytest
17
- from torch .distributed .fsdp import FullyShardedDataParallel as FSDP
18
17
from torch .nn import Module
19
18
from torch .optim import Optimizer
20
19
24
23
25
24
def test_clip_gradients ():
26
25
"""Test that `.clip_gradients()` is a no-op when clipping is disabled."""
27
- module = FSDP ( Mock (spec = Module ) )
26
+ module = Mock (spec = Module )
28
27
optimizer = Mock (spec = Optimizer )
29
28
precision = MixedPrecision (precision = "16-mixed" , device = "cuda:0" , scaler = Mock ())
30
29
precision .clip_grad_by_value = Mock ()
@@ -49,8 +48,9 @@ def test_optimizer_amp_scaling_support_in_step_method():
49
48
"""Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
50
49
gradient clipping (example: fused Adam)."""
51
50
51
+ module = Mock (spec = Module )
52
52
optimizer = Mock (_step_supports_amp_scaling = True )
53
53
precision = MixedPrecision (precision = "16-mixed" , device = "cuda:0" , scaler = Mock ())
54
54
55
55
with pytest .raises (RuntimeError , match = "The current optimizer.*does not allow for gradient clipping" ):
56
- precision .clip_gradients (optimizer , clip_val = 1.0 )
56
+ precision .clip_gradients (module , optimizer , clip_val = 1.0 )
You can’t perform that action at this time.
0 commit comments