Add a nnx.identity "activation"? #4635
Comments
Okay, this is a nice suggestion, and I should allow you to do it. I found that it uses
I see, so you think the best place for this is in the JAX activations, which end up exported directly by NNX. Makes sense, let me post there.
Done. See jax-ml/jax#27222. If the JAX maintainers add this (and if anyone knows them, could they give it a quick push?), then I can add the re-export in https://github.com/google/flax/blob/main/flax/nnx/nn/activations.py
Yes, I hope someone merges this PR quickly.
There are many cases where I want to set up a flexible neural network parameterization where I try out different "activation" functions and final layers from the CLI. A convenient design for this is to use jsonargparse or alternatives and do things like
Then I can do things like
python my_model.py --depth=3 --activation=nnx.tanh
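To illustrate how a dotted name on the command line can turn into a callable, here is a rough standard-library-only sketch. It is not jsonargparse's mechanism and not the snippet from the original post: `resolve_callable` and `build_parser` are hypothetical names, and `math.tanh` stands in for `nnx.tanh` so the example runs without Flax installed.

```python
import argparse
import importlib
import math


def resolve_callable(dotted_name):
    """Import 'pkg.mod.attr' and return the attribute (hypothetical helper)."""
    module_name, _, attr = dotted_name.rpartition(".")
    if not module_name:
        raise argparse.ArgumentTypeError(
            f"expected 'module.name', got {dotted_name!r}"
        )
    try:
        obj = getattr(importlib.import_module(module_name), attr)
    except (ImportError, AttributeError) as err:
        raise argparse.ArgumentTypeError(str(err))
    if not callable(obj):
        raise argparse.ArgumentTypeError(f"{dotted_name!r} is not callable")
    return obj


def build_parser():
    """Build a parser whose --activation flag is resolved to a function."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--depth", type=int, default=2)
    parser.add_argument("--activation", type=resolve_callable, default=math.tanh)
    return parser


# Equivalent to: python my_model.py --depth=3 --activation=math.tanh
args = build_parser().parse_args(["--depth=3", "--activation=math.tanh"])
```

In this sketch the flag value must be a fully importable path, so a short alias such as `nnx.tanh` would need the full module path instead, or an alias table in front of the resolver.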
However, a common thing I need to do is swap out different settings for an identity function. But this isn't available in jax or nnx. What I would love to do is try things like
python my_model.py --depth=3 --final_activation=nnx.identity
So, could we add this to the nnx activation functions? Something trivial like
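A minimal sketch of what such a function might look like. This is a guess at the shape of the proposal, not code taken from JAX or Flax, and the name `identity` simply follows the issue title.

```python
def identity(x):
    """Return the input unchanged (hypothetical sketch of the proposed function)."""
    return x


# Used like any other activation: the layer output passes through untouched.
final_activation = identity
outputs = final_activation([1.0, -2.0, 3.0])
```

Because it returns its argument as-is, the same function works for arrays, pytrees, or plain Python values, which is what makes it convenient as a drop-in "no activation" choice.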
If so, I can create a PR if you tell me which files to add it to. https://github.com/google/flax/blob/main/flax/nnx/nn/activations.py seemed like a natural place. If a test is required, which file would make the most sense for it?