Z3: optimizations for grad norm calculation and gradient clipping #5504

Merged: 25 commits, Aug 15, 2024

Commits (25):
43792bb  z3 scaled_global_grad_norm: replace get_global_norm with torch.norm (nelyahu, May 7, 2024)
cac04d9  Merge branch 'master' into zero3_scaled_global (tjruwase, May 20, 2024)
cbb6b6a  Merge branch 'master' into zero3_scaled_global (loadams, May 20, 2024)
37b4cb7  Merge branch 'master' into zero3_scaled_global (tjruwase, May 23, 2024)
5dd50c3  fix grad norm calc in cpu offload and use torch.clip for grad clipping (nelyahu, May 27, 2024)
f918054  Merge branch 'master' into zero3_scaled_global (lekurile, May 30, 2024)
ba9fd42  Merge branch 'master' into zero3_scaled_global (nelyahu, Jun 7, 2024)
238ab34  Merge branch 'master' into zero3_scaled_global (nelyahu, Jun 17, 2024)
6b6a834  Merge branch 'master' into zero3_scaled_global (loadams, Jun 26, 2024)
b14f920  Merge branch 'master' into zero3_scaled_global (tjruwase, Jun 26, 2024)
45e62d2  Merge branch 'master' into zero3_scaled_global (tjruwase, Jun 27, 2024)
ffdb7f7  Merge branch 'master' into zero3_scaled_global (loadams, Jul 9, 2024)
d99524b  adding gradient clipping to TestZeroPartialOffloadConfigSweep (nelyahu, Jul 18, 2024)
e5d5d7c  Merge branch 'master' into zero3_scaled_global (tjruwase, Jul 18, 2024)
67d2e35  Merge branch 'master' into zero3_scaled_global (loadams, Jul 23, 2024)
214103d  Merge branch 'master' into zero3_scaled_global (loadams, Jul 23, 2024)
6134da4  Merge branch 'master' into zero3_scaled_global (loadams, Jul 23, 2024)
4b16d7c  Merge branch 'master' into zero3_scaled_global (loadams, Jul 24, 2024)
38bb860  Merge branch 'master' into zero3_scaled_global (tjruwase, Aug 4, 2024)
ec1aa8c  Merge branch 'master' into zero3_scaled_global (loadams, Aug 6, 2024)
95cd3a1  Merge branch 'master' into zero3_scaled_global (loadams, Aug 9, 2024)
26676f2  Merge branch 'master' into zero3_scaled_global (BacharL, Aug 14, 2024)
e455962  Merge branch 'master' into zero3_scaled_global (loadams, Aug 14, 2024)
3407e93  Merge branch 'master' into zero3_scaled_global (loadams, Aug 14, 2024)
cb51597  Merge branch 'master' into zero3_scaled_global (loadams, Aug 14, 2024)
deepspeed/runtime/zero/stage3.py (10 changes: 5 additions & 5 deletions)
@@ -15,7 +15,7 @@
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -1413,7 +1413,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
         err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
         total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

-        return total_norm
+        return total_norm.cpu()

     @instrument_w_nvtx
     def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
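A note on the overflow handling visible in this hunk: the inf/nan check swaps in a -1.0 sentinel using tensor arithmetic rather than a Python if, so no branch on a device tensor is needed. A minimal standalone sketch of that pattern, where the helper name finalize_norm is hypothetical and not part of DeepSpeed:

import torch

def finalize_norm(total_norm_sq: torch.Tensor) -> torch.Tensor:
    # take the square root of the accumulated squared norm
    total_norm = total_norm_sq.pow(0.5)
    # boolean flag: True if the norm overflowed or is NaN
    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
    err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
    # blend the sentinel and the computed norm with tensor ops only,
    # mirroring the two context lines in the hunk above
    return inf_or_nan * err + inf_or_nan.logical_not() * total_norm

For finite input this is the identity on the norm, e.g. finalize_norm(torch.tensor(25.0)) evaluates to tensor(5.).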
@@ -2028,7 +2028,7 @@ def step(self, closure=None):
            return

        norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
+        scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
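For context on the torch.linalg.norm change above: each entry of norm_groups is already an L2 norm for one parameter group, so the global norm is just the L2 norm of that vector of norms, sqrt(sum(n_i**2)). A small standalone sketch with illustrative values (not DeepSpeed code):

import torch

# per-group gradient norms (example values only)
norm_groups = [torch.tensor(3.0), torch.tensor(4.0)]

# single fused op: sqrt(3**2 + 4**2) == 5.0
global_norm = torch.linalg.norm(torch.stack(norm_groups))
assert torch.allclose(global_norm, torch.tensor(5.0))

The practical difference, as far as I can tell, is that the result stays a single device tensor instead of being accumulated by the get_global_norm helper, which fits the tensor-only clipping path below.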
@@ -2112,8 +2112,8 @@ def unscale_and_clip_grads(self, sub_group_id, total_norm):
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale

        self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)

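Regarding the unscale_and_clip_grads hunk just above: clamping the clip coefficient at 1.0 is algebraically the same as the old if clip > 1: branch, but it keeps the value in tensor form and avoids a data-dependent Python branch on what may be a device tensor. A hedged standalone sketch of the resulting scale, with combined_scale written as a free function purely for illustration:

import torch

def combined_scale(total_norm: torch.Tensor, loss_scale: float, clip_grad: float) -> torch.Tensor:
    # total_norm here is the scaled norm, i.e. norm * loss_scale
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    # no-op unless the unscaled norm exceeds clip_grad
    clip = torch.clamp(clip, min=1.0)
    return clip * loss_scale

Gradients are then multiplied by 1 / combined_scale(...), which removes the loss scaling and, only when needed, also clips the global gradient norm.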
tests/unit/runtime/zero/test_zero_offloadpp.py (1 change: 1 addition & 0 deletions)
@@ -43,6 +43,7 @@ def test(self, h_dim: int, n_layers: int) -> None:
         config_dict = {
             "train_batch_size": 256,
             "steps_per_print": 1,
+            "gradient_clipping": 1.0,
             "optimizer": {
                 "type": "Adam",
                 "params": {