From 7b737a71d79bdab27769e7f81accd033c14c9a15 Mon Sep 17 00:00:00 2001
From: Erick Matsen
Date: Tue, 9 Jul 2024 10:09:08 -0700
Subject: [PATCH] make format

---
 netam/framework.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/netam/framework.py b/netam/framework.py
index 26c0e0fb..a9fdce6a 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -422,7 +422,9 @@ def reset_optimization(self, learning_rate=None):
             learning_rate = self.learning_rate
         # copied from
         # https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L264
-        param_dict = {pn: p for pn, p in self.model.named_parameters() if p.requires_grad}
+        param_dict = {
+            pn: p for pn, p in self.model.named_parameters() if p.requires_grad
+        }
         # Do not apply weight decay to 1D parameters (biases and layernorm weights).
         decay_params = [p for p in param_dict.values() if p.dim() >= 2]
         nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
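
For context, a minimal sketch of how parameter groups like these are typically passed to the optimizer, following the nanoGPT pattern that the linked comment cites. The model, `learning_rate`, and `weight_decay` values below are illustrative assumptions, not taken from netam:

```python
import torch
import torch.nn as nn

# Illustrative stand-ins (assumptions, not from the patch).
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
learning_rate = 1e-3
weight_decay = 0.1

param_dict = {
    pn: p for pn, p in model.named_parameters() if p.requires_grad
}
# Do not apply weight decay to 1D parameters (biases and layernorm weights);
# splitting on p.dim() separates weight matrices from biases/norm scales.
decay_params = [p for p in param_dict.values() if p.dim() >= 2]
nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
optim_groups = [
    {"params": decay_params, "weight_decay": weight_decay},
    {"params": nodecay_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate)
```

The reformatted dict-comprehension in the patch is a pure style change (black-style line wrapping); the optimizer behavior is unaffected.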