diff --git a/netam/framework.py b/netam/framework.py
index 26c0e0fb..a9fdce6a 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -422,7 +422,9 @@ def reset_optimization(self, learning_rate=None):
             learning_rate = self.learning_rate
         # copied from
         # https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L264
-        param_dict = {pn: p for pn, p in self.model.named_parameters() if p.requires_grad}
+        param_dict = {
+            pn: p for pn, p in self.model.named_parameters() if p.requires_grad
+        }
         # Do not apply weight decay to 1D parameters (biases and layernorm weights).
         decay_params = [p for p in param_dict.values() if p.dim() >= 2]
         nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
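
For context, the decay/nodecay split in the hunk above follows the nanoGPT pattern linked in the comment: weight decay is applied only to parameters with two or more dimensions, while biases and layernorm weights are excluded. Below is a minimal sketch of how such groups are typically handed to the optimizer; the use of AdamW and the particular weight_decay value are assumptions for illustration, since the diff only shows the grouping step.

```python
import torch


def build_optimizer(model, learning_rate, weight_decay=0.1):
    # Collect only trainable parameters, as in the diff above.
    param_dict = {
        pn: p for pn, p in model.named_parameters() if p.requires_grad
    }
    # Matrices (dim >= 2) get weight decay; biases and layernorm
    # weights (dim < 2) do not.
    decay_params = [p for p in param_dict.values() if p.dim() >= 2]
    nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    # AdamW applies per-group weight_decay, so the 1D parameters
    # are effectively undecayed.
    return torch.optim.AdamW(optim_groups, lr=learning_rate)
```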