Skip to content

Commit

Permalink
fix optimizer reset
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Mar 16, 2024
1 parent 8df7b88 commit 571cdd4
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/axolotl/monkeypatch/relora.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def reset_optimizer(
optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
):
# pylint:disable=unused-argument
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
n_zeros = 0
n_total = 0
Expand All @@ -56,16 +57,16 @@ def reset_optimizer(
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state

for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()
for group in optimizer.param_groups:
for param in group["params"]:
state = optimizer_state[param]
for key, value in state.items():
if key not in optimizer_state_keys:
continue
if torch.is_tensor(value):
pruning_fn(value)
n_total += value.numel()
n_zeros += torch.sum(value == 0).item()

_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
Expand Down

0 comments on commit 571cdd4

Please sign in to comment.