@@ -31,7 +31,29 @@ def to_python_float(t):
         return t[0]
 
 
-class LossScaler:
+class LossScalerBase:
+    """LossScalerBase
+    Base class for a loss scaler
+    """
+    def __init__(self, cur_scale):
+        self.cur_scale = cur_scale
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale
+
+    def scale_gradient(self, module, grad_in, grad_out):
+        return tuple(self.loss_scale * g for g in grad_in)
+
+    def update_scale(self, overflow):
+        pass
+
+    def backward(self, loss, retain_graph=False):
+        scaled_loss = loss * self.loss_scale
+        scaled_loss.backward(retain_graph=retain_graph)
+
+
+class LossScaler(LossScalerBase):
     """
     Class that manages a static loss scale. This class is intended to interact with
     :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
@@ -43,7 +65,7 @@ class LossScaler:
         scale (float, optional, default=1.0): The loss scale.
     """
     def __init__(self, scale=1):
-        self.cur_scale = scale
+        super(LossScaler, self).__init__(scale)
 
     # `params` is a list / generator of torch.Variable
     def has_overflow(self, params):
@@ -53,22 +75,8 @@ def has_overflow(self, params):
     def _has_inf_or_nan(x):
         return False
 
-    def update_scale(self, overflow):
-        pass
-
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
 
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-
-
-class DynamicLossScaler:
+class DynamicLossScaler(LossScalerBase):
     """
     Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
     indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
@@ -100,7 +108,7 @@ def __init__(self,
                  min_scale=1,
                  delayed_shift=1,
                  consecutive_hysteresis=False):
-        self.cur_scale = init_scale
+        super(DynamicLossScaler, self).__init__(init_scale)
         self.cur_iter = 0
         self.last_overflow_iter = -1
         self.scale_factor = scale_factor
@@ -113,7 +121,7 @@ def __init__(self,
     # `params` is a list / generator of torch.Variable
     def has_overflow_serial(self, params):
         for p in params:
-            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
+            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                 return True
 
         return False
@@ -135,7 +143,7 @@ def _has_inf_or_nan(x):
                 raise
             return True
         else:
-            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+            if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
                 return True
             return False
 
@@ -157,17 +165,6 @@ def update_scale(self, overflow):
                 self.cur_scale *= self.scale_factor
         self.cur_iter += 1
 
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-
 
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
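For context, a minimal usage sketch of the refactored scalers follows. The import path and the bare training loop are assumptions for illustration only; in the real code path :class:`FP16_Optimizer` drives these objects, as the docstrings above note. Only names that appear in the diff (DynamicLossScaler, backward, has_overflow_serial, update_scale, loss_scale) are used.

import torch
from loss_scaler import DynamicLossScaler  # assumed import path for this sketch

model = torch.nn.Linear(4, 2)                 # fp32 on CPU just to keep the sketch runnable
scaler = DynamicLossScaler(init_scale=2**16)

for _ in range(5):
    loss = model(torch.randn(8, 4)).sum()
    scaler.backward(loss)                     # LossScalerBase.backward: scale the loss, then call .backward()
    overflow = scaler.has_overflow_serial(model.parameters())
    if not overflow:
        for p in model.parameters():          # unscale gradients before an optimizer step
            if p.grad is not None:
                p.grad.data.div_(scaler.loss_scale)
        # optimizer.step() would go here
    scaler.update_scale(overflow)             # grow or shrink cur_scale based on overflow history
    model.zero_grad()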