diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index e4352cbe3..5226b9e03 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -48,6 +48,7 @@ def reset_optimizer( optimizer_state_keys: list[str], prune_ratio: float = 0.9, ): + # pylint:disable=unused-argument pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) n_zeros = 0 n_total = 0 @@ -56,16 +57,16 @@ def reset_optimizer( if isinstance(optimizer, ZeroRedundancyOptimizer): optimizer_state = optimizer.optim.state - for param in reset_params: - param_state = optimizer_state[param] - if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer - continue - for key in optimizer_state_keys: - pruning_fn( - param_state[key] - ) # pruning fn has to be inplace to keep the same keys in the dict - n_total += param_state[key].numel() - n_zeros += torch.sum(param_state[key] == 0).item() + for group in optimizer.param_groups: + for param in group["params"]: + state = optimizer_state[param] + for key, value in state.items(): + if key not in optimizer_state_keys: + continue + if torch.is_tensor(value): + pruning_fn(value) + n_total += value.numel() + n_zeros += torch.sum(value == 0).item() _zeroed = n_zeros / (1e-7 + n_total) * 100 LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")