Commit d3550dc

jomayerit and tjruwase authored

Adagrad support in ZeRO (#3401)

* Adding torch.optim.Adagrad
* Adding Adagrad for ZeRO stages 1 and 2
* Adding Adagrad support to ZeRO stage 3
* Adding documentation and DeepSpeedCPUAdagrad to the supported-optimizer list

Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 77ebf76 commit d3550dc

4 files changed (+33, -12 lines)

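With this change, Adagrad can be selected through the usual DeepSpeed optimizer config and combined with ZeRO. The snippet below is a minimal usage sketch, not part of the commit: it assumes the standard config schema, an optimizer "type" of "Adagrad" whose "params" mirror torch.optim.Adagrad's constructor arguments, and a GPU node launched via the deepspeed launcher. With optimizer offload to CPU, the engine constructs DeepSpeedCPUAdagrad; without offload it falls back to torch.optim.Adagrad.

import torch
import deepspeed

# Sketch of a DeepSpeed config selecting Adagrad under ZeRO stage 2.
# Offloading the optimizer to CPU makes zero_use_cpu_optimizer() true,
# which picks DeepSpeedCPUAdagrad; otherwise torch.optim.Adagrad is used.
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "Adagrad",
        "params": {"lr": 0.01, "eps": 1e-10, "weight_decay": 0.0}
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"}
    }
}

model = torch.nn.Linear(32, 8)
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)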

deepspeed/runtime/engine.py

Lines changed: 11 additions & 9 deletions
@@ -1200,7 +1200,7 @@ def _configure_basic_optimizer(self, model_parameters):
                 "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
             )

-        if self.optimizer_name() in [ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
+        if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
             torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
             adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)

@@ -1214,14 +1214,10 @@ def _configure_basic_optimizer(self, model_parameters):
                     optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)
             else:
                 if self.zero_use_cpu_optimizer():
-                    if self.optimizer_name() == ADAGRAD_OPTIMIZER:
-                        from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
-                        optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters)
-                    else:
-                        from deepspeed.ops.adam import DeepSpeedCPUAdam
-                        optimizer = DeepSpeedCPUAdam(model_parameters,
-                                                     **optimizer_parameters,
-                                                     adamw_mode=effective_adam_w_mode)
+                    from deepspeed.ops.adam import DeepSpeedCPUAdam
+                    optimizer = DeepSpeedCPUAdam(model_parameters,
+                                                 **optimizer_parameters,
+                                                 adamw_mode=effective_adam_w_mode)
                 else:
                     from deepspeed.ops.adam import FusedAdam

@@ -1231,6 +1227,12 @@ def _configure_basic_optimizer(self, model_parameters):
                     adam_w_mode=effective_adam_w_mode,
                 )

+        elif self.optimizer_name() == ADAGRAD_OPTIMIZER:
+            if self.zero_use_cpu_optimizer():
+                from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
+                optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters)
+            else:
+                optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
             from deepspeed.ops.lamb import FusedLamb
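Outside the engine's config-driven path, DeepSpeedCPUAdagrad is meant to behave like a drop-in torch optimizer. A small standalone sketch, not part of the commit; it assumes only the constructor arguments visible in the diff plus an lr keyword, and that the fused CPU Adagrad op can be JIT-built in the environment.

import torch
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad

# Standalone sketch: DeepSpeedCPUAdagrad used directly on a CPU-resident model
# (the fused CPU op is built on first construction).
model = torch.nn.Linear(16, 4)
optimizer = DeepSpeedCPUAdagrad(model.parameters(), lr=0.01)

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()        # Adagrad update applied to the CPU parameters
optimizer.zero_grad()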

deepspeed/runtime/zero/stage3.py

Lines changed: 11 additions & 1 deletion
@@ -859,6 +859,10 @@ def initialize_optimizer_states(self):

         timer_names = set()

+        # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
+        # which do lazy initialization of the state at the first call to step.
+        is_adagrad = isinstance(self.optimizer, torch.optim.Adagrad)
+
         if self.swap_optimizer:
             self.optimizer_swapper.init_timers()

@@ -888,7 +892,9 @@ def initialize_optimizer_states(self):
             else:
                 self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements)

-            self._optimizer_step(i)
+            # Initialize the optimizer states with the flattened fp32 partition.
+            if not is_adagrad:
+                self._optimizer_step(i)

             if swappable_param_subgroup:
                 self._partitioned_params_swap_out(i)
@@ -900,6 +906,10 @@ def initialize_optimizer_states(self):
                 f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
                 force=False)

+        # Initialize the optimizer states with the flattened fp32 partition.
+        if is_adagrad:
+            self.optimizer = torch.optim.Adagrad(self.fp32_partitioned_groups_flat, **self.optimizer.defaults)
+
         self.stop_timers([INIT_OPTIMIZER_TIMER])
         self.log_timers(timer_names)
deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 7 additions & 1 deletion
@@ -611,7 +611,13 @@ def initialize_optimizer_states(self):
             self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory(
                 single_grad_partition) if self.cpu_offload else single_grad_partition

-        self.optimizer.step()
+        # Initialize the optimizer states with the flattened fp32 partition.
+        # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
+        # which do lazy initialization of the state at the first call to step.
+        if isinstance(self.optimizer, torch.optim.Adagrad):
+            self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)
+        else:
+            self.optimizer.step()

         if not self.cpu_offload:
             for group in self.single_partition_of_fp32_groups:
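Rebuilding the optimizer with **self.optimizer.defaults works because torch.optim.Optimizer records its constructor hyperparameters in defaults. A small sketch of the same pattern, not part of the commit:

import torch

# The user-facing Adagrad is rebuilt over the flat fp32 partitions; passing
# **defaults back into the constructor preserves lr, eps, weight_decay, etc.,
# while the fresh construction eagerly allocates state for the new parameters.
user_opt = torch.optim.Adagrad([torch.nn.Parameter(torch.zeros(4))], lr=0.05, weight_decay=0.01)

flat_fp32_partition = [torch.nn.Parameter(torch.zeros(8))]   # stand-in for the flattened fp32 groups
rebuilt = torch.optim.Adagrad(flat_fp32_partition, **user_opt.defaults)

print(rebuilt.defaults["lr"], rebuilt.defaults["weight_decay"])   # 0.05 0.01
print(len(rebuilt.state))                                         # 1 -- state eagerly created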

deepspeed/runtime/zero/utils.py

Lines changed: 4 additions & 1 deletion
@@ -10,6 +10,7 @@
 from deepspeed import comm as dist
 from deepspeed.utils import logger
 from deepspeed.ops.adam import DeepSpeedCPUAdam
+from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
 from deepspeed.ops.adam import FusedAdam
 from deepspeed.utils.nvtx import instrument_w_nvtx
 from deepspeed.accelerator import get_accelerator
@@ -35,7 +36,9 @@ class ZeRORuntimeException(Exception):
     pass


-ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam]
+ZERO_SUPPORTED_OPTIMIZERS = [
+    torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad
+]

 # Add apex FusedAdam to supported list if apex is installed
 try:
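ZERO_SUPPORTED_OPTIMIZERS is the allow-list ZeRO consults when deciding whether a client optimizer can be wrapped, so adding both Adagrad flavors makes that membership check pass. A quick check along these lines (sketch, not part of the commit):

import torch
from deepspeed.runtime.zero.utils import ZERO_SUPPORTED_OPTIMIZERS

# After this commit, a plain torch.optim.Adagrad instance (and DeepSpeedCPUAdagrad)
# passes ZeRO's type-based support check.
opt = torch.optim.Adagrad([torch.nn.Parameter(torch.zeros(2))])
print(type(opt) in ZERO_SUPPORTED_OPTIMIZERS)   # True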
