
Commit 96c4daa

minor refactor loss scaler (#261)
1 parent f502550

File tree

1 file changed: +28 -31 lines changed


deepspeed/pt/loss_scaler.py

Lines changed: 28 additions & 31 deletions
@@ -31,7 +31,29 @@ def to_python_float(t):
     return t[0]
 
 
-class LossScaler:
+class LossScalerBase:
+    """LossScalarBase
+    Base class for a loss scaler
+    """
+    def __init__(self, cur_scale):
+        self.cur_scale = cur_scale
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale
+
+    def scale_gradient(self, module, grad_in, grad_out):
+        return tuple(self.loss_scale * g for g in grad_in)
+
+    def update_scale(self, overflow):
+        pass
+
+    def backward(self, loss, retain_graph=False):
+        scaled_loss = loss * self.loss_scale
+        scaled_loss.backward(retain_graph=retain_graph)
+
+
+class LossScaler(LossScalerBase):
     """
     Class that manages a static loss scale. This class is intended to interact with
     :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
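
The new LossScalerBase gathers the behavior the two scalers previously duplicated: the loss_scale property, the scale_gradient hook, a no-op update_scale, and backward, which multiplies the loss by the current scale before calling backward() on it. A minimal sketch of that shared behavior; the ConstantScaler subclass and the tensors are illustrative only, and the import path follows the file shown above:

import torch

from deepspeed.pt.loss_scaler import LossScalerBase

# Hypothetical subclass used only to exercise the base-class behavior.
class ConstantScaler(LossScalerBase):
    def __init__(self):
        super(ConstantScaler, self).__init__(cur_scale=1024.0)

scaler = ConstantScaler()
x = torch.randn(4, requires_grad=True)
loss = (2 * x).sum()

scaler.backward(loss)        # computes (loss * 1024.0).backward()
print(scaler.loss_scale)     # 1024.0, from the base-class property
print(x.grad)                # each entry is 2 * 1024.0 = 2048.0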
@@ -43,7 +65,7 @@ class LossScaler:
         scale (float, optional, default=1.0): The loss scale.
     """
     def __init__(self, scale=1):
-        self.cur_scale = scale
+        super(LossScaler, self).__init__(scale)
 
     # `params` is a list / generator of torch.Variable
     def has_overflow(self, params):
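
LossScaler keeps its static behavior and now only delegates storage of the scale to the base class. A quick sketch of how the static scaler behaves after this change; the values are illustrative:

from deepspeed.pt.loss_scaler import LossScaler

scaler = LossScaler(scale=128)
print(scaler.loss_scale)        # 128, stored as cur_scale by LossScalerBase
scaler.update_scale(True)       # inherited no-op: a static scale never changes
print(scaler.loss_scale)        # still 128
print(scaler.has_overflow([]))  # False: the static scaler never reports overflow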
@@ -53,22 +75,8 @@ def has_overflow(self, params):
     def _has_inf_or_nan(x):
         return False
 
-    def update_scale(self, overflow):
-        pass
-
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
 
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-
-
-class DynamicLossScaler:
+class DynamicLossScaler(LossScalerBase):
     """
     Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
     indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
@@ -100,7 +108,7 @@ def __init__(self,
                  min_scale=1,
                  delayed_shift=1,
                  consecutive_hysteresis=False):
-        self.cur_scale = init_scale
+        super(DynamicLossScaler, self).__init__(init_scale)
         self.cur_iter = 0
         self.last_overflow_iter = -1
         self.scale_factor = scale_factor
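
DynamicLossScaler.__init__ likewise hands the initial scale to the base class and keeps only its own bookkeeping (cur_iter, last_overflow_iter, scale_factor, ...). A hedged construction sketch; init_scale and scale_factor are the parameters visible in this hunk:

from deepspeed.pt.loss_scaler import DynamicLossScaler

scaler = DynamicLossScaler(init_scale=2**16, scale_factor=2.0)
print(scaler.loss_scale)     # 65536, now stored by LossScalerBase.__init__
print(scaler.scale_factor)   # 2.0, still kept on the subclass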
@@ -113,7 +121,7 @@ def __init__(self,
     # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params):
         for p in params:
-            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
+            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                 return True
 
         return False
@@ -135,7 +143,7 @@ def _has_inf_or_nan(x):
                 raise
             return True
         else:
-            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+            if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
                 return True
             return False
 
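The rewritten condition is behavior-preserving: a membership test replaces the two equality checks against +/-inf, and cpu_sum != cpu_sum remains the NaN test, since NaN is the only float that compares unequal to itself. A standalone sketch of the check, simplified to omit the RuntimeError handling that surrounds it in the file:

import torch

def has_inf_or_nan(x):
    # Reduce on the CPU, then inspect the Python scalar, as in the diff above.
    cpu_sum = float(x.float().sum())
    return cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum

print(has_inf_or_nan(torch.tensor([1.0, 2.0])))           # False
print(has_inf_or_nan(torch.tensor([1.0, float('inf')])))  # True
print(has_inf_or_nan(torch.tensor([1.0, float('nan')])))  # True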
@@ -157,17 +165,6 @@ def update_scale(self, overflow):
                 self.cur_scale *= self.scale_factor
         self.cur_iter += 1
 
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-
 
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
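
For context, a hedged sketch of how a dynamic scaler is typically driven in a training loop; the tiny model, data, and manual gradient unscaling are illustrative and are not the example that follows this point in loss_scaler.py:

import torch

from deepspeed.pt.loss_scaler import DynamicLossScaler

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = DynamicLossScaler(init_scale=2**16)

for step in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    scaler.backward(loss)                                     # scale, then backward
    overflow = scaler.has_overflow_serial(model.parameters())
    scaler.update_scale(overflow)                             # grow or shrink cur_scale
    if not overflow:
        for p in model.parameters():
            p.grad.data.div_(scaler.loss_scale)               # unscale before stepping
        optimizer.step()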
