How to freeze parameters with nnx and optax? #4167
Comments
Check out 08_save_load_checkpoints.py which contains a very simple example of orbax checkpointing. We will have a proper checkpointing guide in the future as we migrate the Linen docs.
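(For readers who did come here for checkpointing: a minimal orbax sketch, assuming orbax-checkpoint's `PyTreeCheckpointer` API and an illustrative `/tmp/ckpt` path; this is not the contents of `08_save_load_checkpoints.py`:)

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}  # any pytree of arrays

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save('/tmp/ckpt', params)  # target path must not already exist
restored = checkpointer.restore('/tmp/ckpt')
```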
Thank you for your response, but I believe there's been a misunderstanding. My question was about freezing parameters for transfer learning, not about checkpointing or saving/loading model states.
Oh god I'm sorry! I read orbax. Transfer Learning and Surgery in general are a lot simpler in the new NNX API:

```python
from flax import nnx
import jax.numpy as jnp
import optax


class Classifier(nnx.Module):
  def __init__(self, embed_dim, num_classes, backbone, rngs):
    self.backbone = backbone
    self.head = nnx.Linear(embed_dim, num_classes, rngs=rngs)

  def __call__(self, x):
    x = self.backbone(x)
    x = self.head(x)
    return x


def load_model():
  # stand-in for loading a pretrained backbone
  return nnx.Linear(784, 1024, rngs=nnx.Rngs(0))


backbone = load_model()
classifier = Classifier(1024, 10, backbone, rngs=nnx.Rngs(1))

# filter to select only Params on the 'head' path
head_params = nnx.All(nnx.Param, nnx.PathContains('head'))

optimizer = nnx.Optimizer(
  classifier,
  tx=optax.adamw(3e-4),
  wrt=head_params,  # only optimize the head params
)

# simple train step
@nnx.jit
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    logits = model(x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

  # differentiate only the head params of the first argument
  diff_state = nnx.DiffState(0, head_params)
  grads = nnx.grad(loss_fn, argnums=diff_state)(model)
  optimizer.update(grads)

x = jnp.ones((1, 784))
y = jnp.ones((1,), jnp.int32)
train_step(classifier, optimizer, x, y)
```
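As a quick sanity check (not part of the original comment), the backbone params can be snapshotted before and after a step; a sketch assuming the objects defined above:

```python
import jax

before = nnx.state(classifier.backbone, nnx.Param)
train_step(classifier, optimizer, x, y)
after = nnx.state(classifier.backbone, nnx.Param)

# if freezing works, every backbone leaf is unchanged (all True)
unchanged = jax.tree_util.tree_map(lambda a, b: bool((a == b).all()), before, after)
print(unchanged)
```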
Excellent, that's exactly what I was looking for!
I also have a similar question about
@mmorinag127 you can build the mask from the state filtered with `head_params`:

```python
state = nnx.state(classifier, head_params)
mask = create_mask(state)  # TODO

optimizer = nnx.Optimizer(
  classifier,
  tx=optax.adamw(3e-4, mask=mask),
  wrt=head_params,  # filter head params
)
```
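The `create_mask` above is left as a TODO in the thread; one plausible sketch (my assumption, not from the comment) marks only multi-dimensional kernels, since `optax.adamw`'s `mask` argument selects which leaves receive weight decay:

```python
import jax

def create_mask(state):
  # True leaves receive weight decay, False leaves are skipped;
  # here: decay kernels (ndim > 1), skip biases and scalars
  return jax.tree_util.tree_map(lambda p: p.ndim > 1, state)
```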
Thanks a lot @cgarciae !!
I would like to know what is the best way to freeze parameters in a model using `nnx` and `optax` (https://flax.readthedocs.io/en/latest/guides/training_techniques/transfer_learning.html#optax-multi-transform). I think it would be useful to add an example to https://flax.readthedocs.io/en/latest/nnx/index.html.
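For context, the Linen guide linked above freezes params via `optax.multi_transform`; a rough, self-contained sketch of that pattern (the toy param tree and labels are illustrative, not from the guide):

```python
import jax
import jax.numpy as jnp
import optax

# toy params: freeze the backbone, train only the head
params = {
    'backbone': {'kernel': jnp.ones((4, 4))},
    'head': {'kernel': jnp.ones((4, 2))},
}
labels = {
    'backbone': {'kernel': 'frozen'},
    'head': {'kernel': 'trainable'},
}
tx = optax.multi_transform(
    {'trainable': optax.adamw(3e-4), 'frozen': optax.set_to_zero()},
    param_labels=labels,
)
opt_state = tx.init(params)

grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = tx.update(grads, opt_state, params)
# updates['backbone'] is all zeros; updates['head'] is not
```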