
How to freeze parameters with nnx and optax? #4167

Open
maxencefaldor opened this issue Sep 1, 2024 · 7 comments

maxencefaldor commented Sep 1, 2024

I would like to know the best way to freeze parameters in a model using nnx and optax (see https://flax.readthedocs.io/en/latest/guides/training_techniques/transfer_learning.html#optax-multi-transform).

I think it would be useful to add an example to https://flax.readthedocs.io/en/latest/nnx/index.html.
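
For reference, the optax.multi_transform pattern described in the linked guide looks roughly like this (a minimal sketch with a hypothetical params tree split into 'backbone' and 'head'; optax.set_to_zero is used to freeze the backbone):

import jax
import jax.numpy as jnp
import optax

# Hypothetical params pytree with a frozen 'backbone' and a trainable 'head'.
params = {
  'backbone': {'kernel': jnp.ones((784, 1024))},
  'head': {'kernel': jnp.ones((1024, 10))},
}

# Label each leaf by its top-level key.
param_labels = jax.tree_util.tree_map_with_path(
  lambda path, _: 'frozen' if path[0].key == 'backbone' else 'trainable',
  params,
)

tx = optax.multi_transform(
  {'trainable': optax.adamw(3e-4), 'frozen': optax.set_to_zero()},
  param_labels,
)
opt_state = tx.init(params)

grads = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy grads
updates, opt_state = tx.update(grads, opt_state, params)
# updates['backbone'] is all zeros, so the backbone stays frozen;
# updates['head'] contains the adamw updates.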

cgarciae (Collaborator) commented Sep 2, 2024

Check out 08_save_load_checkpoints.py, which contains a very simple example of Orbax checkpointing. We will have a proper checkpointing guide in the future as we migrate the Linen docs.

maxencefaldor (Author) commented

Thank you for your response, but I believe there's been a misunderstanding. My question was about freezing parameters for transfer learning, not about checkpointing or saving/loading model states.

cgarciae (Collaborator) commented Sep 2, 2024

Oh no, I'm sorry! I misread that as an Orbax question.

Transfer learning, and model surgery in general, is a lot simpler with the new nnx API. Here's a small working example:

import jax.numpy as jnp
import optax
from flax import nnx

class Classifier(nnx.Module):
  def __init__(self, embed_dim, num_classes, backbone, rngs):
    self.backbone = backbone
    self.head = nnx.Linear(embed_dim, num_classes, rngs=rngs)

  def __call__(self, x):
    x = self.backbone(x)
    x = self.head(x)
    return x

def load_model():
  # stand-in for loading a pretrained backbone
  return nnx.Linear(784, 1024, rngs=nnx.Rngs(0))

backbone = load_model()
classifier = Classifier(1024, 10, backbone, rngs=nnx.Rngs(1))

# filter that selects only the Params under the 'head' path
head_params = nnx.All(nnx.Param, nnx.PathContains('head'))

optimizer = nnx.Optimizer(
  classifier,
  tx=optax.adamw(3e-4),
  wrt=head_params,  # only optimize the head params
)

# simple train step
@nnx.jit
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    logits = model(x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

  # differentiate only w.r.t. the head params of the first argument (argnums=0)
  diff_state = nnx.DiffState(0, head_params)
  grads = nnx.grad(loss_fn, argnums=diff_state)(model)
  optimizer.update(grads)

x = jnp.ones((1, 784))
y = jnp.ones((1,), jnp.int32)
train_step(classifier, optimizer, x, y)
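
A quick sanity check (a sketch of my own, assuming the example above has already run): since jax arrays are immutable, you can snapshot the parameter leaves before a step and compare them afterwards; only the head leaves should change.

import jax

# Snapshot all Param leaves, take one step, compare.
before = jax.tree_util.tree_leaves(nnx.state(classifier, nnx.Param))
train_step(classifier, optimizer, x, y)
after = jax.tree_util.tree_leaves(nnx.state(classifier, nnx.Param))

# True for the frozen backbone leaves, False for the updated head leaves.
print([bool(jnp.all(a == b)) for a, b in zip(before, after)])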

maxencefaldor (Author) commented

Excellent, that's exactly what I was looking for!

mmorinag127 commented

I also have a similar question about adamw's mask for weight decay.
How can I specify which parameters the weight decay is applied to?
Is there an elegant way to do it, like the above?

cgarciae (Collaborator) commented Sep 17, 2024

@mmorinag127 Because the nnx.Optimizer wrapper is very generic, there is no support for masks using filters, but you can probably generate a compatible mask from the state, e.g.:

state = nnx.state(classifier, head_params)
mask = create_mask(state) # TODO

optimizer = nnx.Optimizer(
  classifier,
  tx=optax.adamw(3e-4, mask=mask),
  wrt=head_params,  # filter head params
)
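
The create_mask above is left as a TODO; one possible sketch (my own assumption, not from the thread) is to apply weight decay only to leaves with two or more dimensions, i.e. kernels but not biases:

import jax

# Hypothetical create_mask: returns a boolean pytree with the same structure
# as `state`, True where weight decay should apply (2D+ leaves), False elsewhere.
def create_mask(state):
  return jax.tree_util.tree_map(lambda leaf: leaf.ndim >= 2, state)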

mmorinag127 commented

Thanks a lot @cgarciae !!
