
Add a nnx.identity "activation"? #4635

Open

jlperla opened this issue Mar 18, 2025 · 4 comments · May be fixed by #4652

Comments

@jlperla
Contributor

jlperla commented Mar 18, 2025

There are many cases where I want to set up a flexible neural network parameterization and try out different "activation" functions and final layers from the CLI. A convenient design for this is to use jsonargparse (or an alternative) and write something like:

from types import MethodType
from typing import Callable

import jsonargparse
from flax import nnx

# Workaround: https://github.com/omni-us/jsonargparse/issues/619#issuecomment-2466451720
nnx.softplus.__deepcopy__ = MethodType(lambda self, memo: self, nnx.softplus)

def my_model(
    width: int = 64,
    depth: int = 4,
    activation: Callable = nnx.relu,
    final_activation: Callable = nnx.softplus,
    ...):
    # use the settings to create a neural network
    ...

if __name__ == "__main__":
    jsonargparse.CLI(my_model)

Then I can do things like python my_model.py --depth=3 --activation=nnx.tanh

However, a common need is to swap one of these settings out for an identity function, which isn't available in jax or nnx. What I would love to do is try things like python my_model.py --depth=3 --final_activation=nnx.identity
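As a stopgap I can define a stand-in locally (a minimal sketch; the name is mine, and whether jsonargparse can resolve it from the CLI depends on its import path):

def identity(x):
    # Local stand-in: return the argument unchanged.
    return x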

So, could we add this to the nnx activation functions? Something trivial like

import jax
from jax import Array
from jax.typing import ArrayLike

@jax.jit
def identity(x: ArrayLike) -> Array:
  r"""Identity activation function.

  Returns the argument unmodified.

  Args:
    x: input array

  Returns:
    The argument `x` unmodified.

  Examples:
    >>> nnx.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
    Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32)
  """
  return x
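To show how it would slot into the parameterization above, here is a hypothetical usage sketch (the TinyNet module and its names are mine, purely illustrative):

from typing import Callable

import jax.numpy as jnp
from flax import nnx

class TinyNet(nnx.Module):
  # Illustrative module: final_activation could be identity, nnx.softplus,
  # etc., selected from the CLI.
  def __init__(self, width: int, final_activation: Callable, rngs: nnx.Rngs):
    self.linear = nnx.Linear(width, width, rngs=rngs)
    self.final_activation = final_activation

  def __call__(self, x):
    return self.final_activation(self.linear(x))

net = TinyNet(4, final_activation=identity, rngs=nnx.Rngs(0))
y = net(jnp.ones((1, 4)))  # identity leaves the final output unchanged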

If so, I can create a PR if you tell me which files to add it to. https://github.com/google/flax/blob/main/flax/nnx/nn/activations.py seemed like a natural place. If a test is required, which file would make the most sense?
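For reference, once the function exists a test could be as small as the following (a sketch; where it should live is for the maintainers to say):

import jax.numpy as jnp
from flax import nnx

def test_identity():
  x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
  # The identity activation must return its input unchanged.
  assert (nnx.identity(x) == x).all()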

@Ruhaan838

This is a nice suggestion. I found that the nnx activations come from jax, so the function needs to be added there:
-> Main source for the functions:
https://github.com/jax-ml/jax/blob/main/jax/_src/nn/functions.py
-> and the public re-export alongside it:
https://github.com/jax-ml/jax/blob/main/jax/nn/__init__.py
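A sketch of what those two additions might look like, assuming they mirror the existing activations (the exact form is for the JAX reviewers to decide):

# In jax/_src/nn/functions.py:
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def identity(x: ArrayLike) -> Array:
  """Identity activation function (sketch)."""
  return jnp.asarray(x)

# In jax/nn/__init__.py, which uses explicit `name as name` re-exports:
# from jax._src.nn.functions import identity as identity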

@jlperla
Contributor Author

jlperla commented Mar 18, 2025

I see, so you think the best place for this is in the jax activations, which NNX then re-exports directly. Makes sense, let me post there.

@jlperla
Contributor Author

jlperla commented Mar 18, 2025

Done. See jax-ml/jax#27222

If the JAX maintainers add this (and if anyone knows them, a quick push would help), I can then add the re-export in https://github.com/google/flax/blob/main/flax/nnx/nn/activations.py
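That re-export would likely be a one-liner (a sketch, assuming activations.py keeps its current pattern of importing activations from jax.nn):

# In flax/nnx/nn/activations.py, extending the existing jax.nn imports:
from jax.nn import identity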

@Ruhaan838

Yes, I hope someone merges this PR quickly.

@jlperla jlperla linked a pull request Mar 25, 2025 that will close this issue