"JAX array is set a static" warning is raised unwantedly #863

Open
gautierronan opened this issue Sep 24, 2024 · 5 comments

@gautierronan

gautierronan commented Sep 24, 2024

As of Equinox 0.11.6 and #800, the following MWE raises a UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.

```python
import numpy as np
import equinox as eqx

class Foo(eqx.Module):
    x: tuple[int, int] = eqx.field(static=True)

    def add_one(self):
        x_as_np = np.asarray(self.x)
        return Foo(tuple(x_as_np+1))

x = (3, 2)
foo = Foo(x)
foo.add_one()
# UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
```

This means that one cannot perform NumPy operations (which are often simpler than writing them in plain Python) on a static attribute. This is a use case we have in dynamiqs; see for instance the `__mul__` method of this class, which represents an array in diagonal (DIA) sparse format. Note that we intentionally use NumPy instead of `jax.numpy` so that this logic stays "static".

@gautierronan
Author

A simple fix would be to replace `is_array` with `is_jax_array` (i.e. `isinstance(..., jax.Array)`) in

`is_array, jtu.tree_flatten(getattr(self, field.name))[0]`

but I'm not sure this is in line with the intended use.

@lockwo
Contributor

lockwo commented Sep 24, 2024

I would say this behavior is expected (whether or not it is wanted may be another question). In general NumPy arrays are not hashable, and making a field static puts it in the pytree's aux data (https://github.com/patrick-kidger/equinox/blob/main/equinox/_module.py#L946), which is expected to be hashable, since "JAX sometimes needs to compare treedef for equality, or compute its hash for use in the JIT cache, and so care must be taken to ensure that the auxiliary data specified in the flattening recipe supports meaningful hashing and equality comparisons." (https://jax.readthedocs.io/en/latest/pytrees.html). Getting this wrong can cause silent or (speaking from personal experience) confusing errors, which is why I wanted to add the warning.
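A quick NumPy-only illustration of the hashability problem described above: an `np.ndarray` cannot be hashed, and `==` between two arrays returns an elementwise array rather than a single bool, so neither hashing nor equality comparison is well-defined for arrays stored as aux data. This sketch only uses NumPy and does not touch Equinox or JAX internals.

```python
import numpy as np

a = np.array([3, 2])
b = np.array([3, 2])

# np.ndarray is unhashable, so hash() raises TypeError.
try:
    hash(a)
    array_is_hashable = True
except TypeError:
    array_is_hashable = False

# Equality is elementwise: the result is an array, not a single bool.
eq = a == b

# Using that result where Python expects one bool (e.g. `if a == b:`)
# is ambiguous for arrays with more than one element and raises ValueError.
try:
    bool(eq)
    eq_is_usable_as_bool = True
except ValueError:
    eq_is_usable_as_bool = False

print(array_is_hashable, type(eq).__name__, eq_is_usable_as_bool)
```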

I agree, at the least, that the warning message is wrong, because it isn't a JAX array being set as static (it's a NumPy array), so the message should be changed to match the check. As for NumPy arrays, I think they were included originally because of their hashing problems (to quote @/ jakevdp: "Neither np.ndarray nor jax.Array satisfy this, so they should not be included in aux_data. If you do include such values in aux_data, you'll get unsupported, poorly-defined behavior."). That being said, there definitely are cases where static arrays can be fine and correct (which is why this is a warning that can be ignored rather than an error), and if these cases are very common then the warning could be a burden. WDYT?

@gautierronan
Author

But in our example, the NumPy array is just an intermediate value in the computation. The computation starts from a static tuple and returns a static tuple, which is why I don't think this warning should be expected here.

Also, I don't really see a way around it. The warning is raised when the instance is created, so I don't see how we could filter it. In our library it would be raised every time we perform an operation on this class, and we really cannot make this attribute non-static.

@gautierronan
Author

Actually, after investigating further, the following works without a warning:

```python
import numpy as np
import equinox as eqx

class Foo(eqx.Module):
    x: tuple[int, int] = eqx.field(static=True)

    def add_one(self):
        x_as_np = np.asarray(self.x)
        x_as_np += 1
        x = tuple([i.item() for i in x_as_np])
        return Foo(x)

x = (3, 2)
foo = Foo(x)
foo.add_one()
# no warning
```

So what's being detected in the first example is that the tuple elements are NumPy integer scalars (`np.int64`) rather than Python `int`s.
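The type difference is easy to check directly. Iterating over a NumPy array (which is what `tuple(...)` does) yields NumPy scalars such as `np.int64`, whereas `ndarray.tolist()` converts every element to the closest built-in Python type. `.tolist()` is offered here as an alternative to the `[i.item() for i in x_as_np]` comprehension in the second snippet; the exact integer dtype (`int64` below) is platform-dependent.

```python
import numpy as np

x = (3, 2)
x_as_np = np.asarray(x) + 1

# Iterating over an ndarray yields NumPy scalar objects (np.generic subclasses).
via_tuple = tuple(x_as_np)

# .tolist() converts each element to a built-in Python scalar (here: int).
via_tolist = tuple(x_as_np.tolist())

print([type(v).__name__ for v in via_tuple])   # e.g. ['int64', 'int64'] on most 64-bit platforms
print([type(v).__name__ for v in via_tolist])  # ['int', 'int']
print(via_tuple == via_tolist)                 # True: the values compare equal
```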

@lockwo
Contributor

lockwo commented Sep 25, 2024

Hmm, I see it, yeah, I misread it: it's an `int64` from NumPy. That would be a misuse of `is_array` then, because `np.int64` values are hashable. I can add a flag to exclude basic NumPy scalar types from the check (maybe just exclude NumPy generics from the check; are they all hashable?).
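One possible shape for the narrower check, as a hypothetical NumPy-only sketch: `should_warn_static` is an illustrative name, not Equinox's API, and a real version would also need to cover `jax.Array` (e.g. via `isinstance(..., jax.Array)`, as proposed earlier in the thread). It flags `np.ndarray` leaves but lets NumPy scalars (`np.generic`, which is not an `ndarray` subclass) through, and spot-checks that a few common scalar types are hashable and compare equal to the matching Python value. It does not settle whether *every* `np.generic` subclass is hashable.

```python
import numpy as np


def should_warn_static(leaf) -> bool:
    # Hypothetical helper (not part of Equinox): warn for ndarrays, which are
    # unhashable, but not for NumPy scalars such as np.int64 (np.generic).
    # A real implementation would also need to check for jax.Array.
    return isinstance(leaf, np.ndarray)


# Leaves like those produced by the first MWE: tuple(np.asarray((3, 2)) + 1)
mwe_leaves = tuple(np.asarray((3, 2)) + 1)
flagged_mwe = [should_warn_static(v) for v in mwe_leaves]

# An actual array stored as static should still be flagged.
flagged_array = should_warn_static(np.array([3, 2]))

# Spot-check a few common NumPy scalar types: hash() succeeds and
# each compares equal to the equivalent Python value.
samples = [np.int64(4), np.int32(4), np.float64(1.5)]
python_equivalents = [4, 4, 1.5]
scalars_behave = all(
    isinstance(hash(s), int) and bool(s == p)
    for s, p in zip(samples, python_equivalents)
)

print(flagged_mwe, flagged_array, scalars_behave)
```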
