Experimental-pytree flag causes crash #4142
Comments
Interesting counterexamples for Module's pytree definition. The issue is that we decide what is static based on each attribute's type.
Okay, I understand. Feel free to close if you like.
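The type-based split described in the first comment can be reproduced outside Flax. The following is a minimal sketch that uses a hypothetical toy container (`TypeBasedModule`, not Flax's actual `Module` implementation) registered with `jax.tree_util.register_pytree_node`. Array-typed attributes become leaves and everything else becomes static metadata. As a result, the same logical module has a different tree structure depending on whether `epsilon` is a Python `float` or a `jax.Array`. A loop whose body turns the float into an array then fails, because the loop carry changes structure.

```python
import jax
import jax.numpy as jnp


class TypeBasedModule:
    """Hypothetical toy container (not Flax's Module): jax.Array attributes
    become pytree leaves, everything else becomes static metadata."""

    def __init__(self, **attrs):
        for name, value in attrs.items():
            setattr(self, name, value)


def _flatten(module):
    names = sorted(vars(module))
    # The dynamic/static split depends only on each value's *type*.
    dynamic = tuple(n for n in names if isinstance(getattr(module, n), jax.Array))
    static = tuple((n, getattr(module, n)) for n in names if n not in dynamic)
    children = tuple(getattr(module, n) for n in dynamic)
    return children, (dynamic, static)


def _unflatten(aux, children):
    dynamic, static = aux
    module = object.__new__(TypeBasedModule)
    for name, value in static:
        setattr(module, name, value)
    for name, value in zip(dynamic, children):
        setattr(module, name, value)
    return module


jax.tree_util.register_pytree_node(TypeBasedModule, _flatten, _unflatten)

w = jnp.ones(3)
m_float = TypeBasedModule(w=w, epsilon=1e-5)             # epsilon is static
m_array = TypeBasedModule(w=w, epsilon=jnp.array(1e-5))  # epsilon is a leaf

# Same logical module, two different tree structures.
print(len(jax.tree_util.tree_leaves(m_float)))  # 1
print(len(jax.tree_util.tree_leaves(m_array)))  # 2


def body(i, m):
    # Arithmetic turns the float epsilon into a jax.Array, so the carry's
    # tree structure differs between loop input and loop output.
    return TypeBasedModule(w=m.w, epsilon=m.epsilon * jnp.float32(1.0))


try:
    jax.lax.fori_loop(0, 3, body, m_float)
    loop_failed = False
except Exception as err:  # JAX rejects a carry whose structure changed
    loop_failed = True
    print(type(err).__name__)
```

Starting with a `float` and ending with an array (or vice versa) is exactly the kind of leaf-type change that makes the structure unstable. This matches the report that using a `float` for `epsilon` consistently avoids the problem.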
copybara-service bot pushed a commit that referenced this issue on Sep 5, 2024:

NumPy and JAX arrays are no longer considered state leaves. This makes the structure of the State completely determined by Variables. Besides being more predictable, this keeps the structure stable when a leaf's type changes; such changes previously led to issues like #4142.

```python
import jax.numpy as jnp
from flax import nnx

class Foo(nnx.Module):
    def __init__(self):
        self.a = jnp.array(1)  # no longer allowed, instead...
        self.b = nnx.Param(jnp.array(1))  # just use Variables
```

PiperOrigin-RevId: 670949705
copybara-service bot pushed a commit that referenced this issue on Sep 5, 2024:

NumPy and JAX arrays are no longer considered state leaves. This makes the structure of the State completely determined by Variables. Besides being more predictable, this keeps the structure stable when a leaf's type changes; such changes previously led to issues like #4142.

```python
import jax.numpy as jnp
from flax import nnx

class Foo(nnx.Module):
    def __init__(self):
        self.a = jnp.array(1)  # no longer allowed, instead...
        self.b = nnx.Param(jnp.array(1))  # just use Variables
```

It also migrates all remaining tests from pytest to absl so that they are tested correctly internally.

PiperOrigin-RevId: 671372717
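The idea behind the commit, that the structure is determined only by Variables, can be modeled with the same kind of toy. This is a minimal sketch under stated assumptions: `Var` and `VarModule` are hypothetical stand-ins for `nnx.Variable`/`nnx.Param` and `nnx.Module`, not their real implementations. Only the contents of a `Var` become leaves, other attributes are static metadata, and a bare array raises `TypeError`. Raising an error is one assumed way to model "no longer allowed". Swapping a `float` for a `jax.Array` inside a `Var` then changes only the leaf value, not the tree structure.

```python
import jax
import jax.numpy as jnp
import numpy as np


class Var:
    """Hypothetical stand-in for a Variable such as nnx.Param:
    whatever it holds is always treated as a single leaf."""

    def __init__(self, value):
        self.value = value


class VarModule:
    """Hypothetical toy container whose structure is fixed by its Var attributes."""

    def __init__(self, **attrs):
        for name, value in attrs.items():
            setattr(self, name, value)


def _flatten(module):
    var_names, static = [], []
    for name in sorted(vars(module)):
        value = getattr(module, name)
        if isinstance(value, Var):
            var_names.append(name)
        elif isinstance(value, (jax.Array, np.ndarray)):
            # Toy version of the rule: bare arrays are not allowed.
            raise TypeError(f"array attribute {name!r} must be wrapped in Var")
        else:
            static.append((name, value))
    children = tuple(getattr(module, n).value for n in var_names)
    return children, (tuple(var_names), tuple(static))


def _unflatten(aux, children):
    var_names, static = aux
    module = object.__new__(VarModule)
    for name, value in static:
        setattr(module, name, value)
    for name, value in zip(var_names, children):
        setattr(module, name, Var(value))
    return module


jax.tree_util.register_pytree_node(VarModule, _flatten, _unflatten)

a = VarModule(w=Var(jnp.ones(3)), epsilon=Var(1e-5), use_bias=True)
b = VarModule(w=Var(jnp.ones(3)), epsilon=Var(jnp.array(1e-5)), use_bias=True)

# The leaf's type changed (float vs. jax.Array), but the structure did not.
print(jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b))  # True

try:
    jax.tree_util.tree_leaves(VarModule(w=jnp.ones(3)))  # bare array
except TypeError as err:
    print(err)
```

The design trade-off is that the user must mark state explicitly. In return, jit, loops, and other transforms see a structure that does not depend on the runtime type of each value.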
gives
Assigning a `float` to `epsilon` makes the problem disappear. (Tested on main and latest.)