Skip to content

Commit

Permalink
[nnx] disallow Array leaves
Browse files Browse the repository at this point in the history
Numpy and JAX Array's are no longer consider state leaves. This makes the structure of the State completely determined by Variables, which apart from being more predictable it produces structural stability invariant to leaf type changes which let to issues such as #4142.

```python
class Foo(nnx.Module):
  def __init__(self):
    self.a = jnp.array(1) # no longer allowed, instead...
    self.b = nnx.Param(jnp.array(1)) # just use Variables
```

PiperOrigin-RevId: 670949705
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Sep 5, 2024
1 parent aded9ac commit 8bed224
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 80 deletions.
14 changes: 10 additions & 4 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,18 @@
Leaf = tp.TypeVar('Leaf')
AuxData = tp.TypeVar('AuxData')

StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array]
StateLeaf = VariableState[tp.Any]
NodeLeaf = VariableState[tp.Any]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
return isinstance(x, (VariableState, np.ndarray, jax.Array))
return isinstance(x, VariableState)


def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
return isinstance(x, (Variable, np.ndarray, jax.Array))
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
return isinstance(x, Variable)


class _HashById(tp.Hashable, tp.Generic[A]):
Expand Down Expand Up @@ -416,6 +417,11 @@ def _graph_flatten(
flat_state[(*path, key)] = value
leaves.append((key, None))
else:
if isinstance(value, (jax.Array, np.ndarray)):
path_str = '/'.join(map(str, (*path, key)))
raise ValueError(
f'Arrays leaves are not supported, at {path_str!r}: {value}'
)
static_fields.append((key, value))

nodedef = NodeDef.create(
Expand Down
6 changes: 3 additions & 3 deletions flax/nnx/nnx/training/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
self.step = OptState(jnp.array(0, dtype=jnp.uint32))
self.model = model
self.tx = tx
self.opt_state = tx.init(nnx.state(model, wrt))
self.opt_state = OptState(tx.init(nnx.state(model, wrt)))
self.wrt = wrt

def split(self, *filters: filterlib.Filter):
Expand Down Expand Up @@ -198,10 +198,10 @@ def update(self, grads):
"""
state = nnx.state(self.model, self.wrt)

updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
updates, new_opt_state = self.tx.update(grads, self.opt_state.value, state)
new_params = optax.apply_updates(state, updates)
assert isinstance(new_params, nnx.State)

self.step.value += 1
nnx.update(self.model, new_params)
self.opt_state = new_opt_state
self.opt_state.value = new_opt_state
29 changes: 8 additions & 21 deletions flax/nnx/tests/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from copy import deepcopy
import dataclasses
from typing import Any, TypeVar

from absl.testing import absltest
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax import nnx

A = TypeVar('A')


class TestModule:
class TestModule(absltest.TestCase):
def test_has_module_state(self):
class Foo(nnx.Module): ...

Expand Down Expand Up @@ -662,26 +662,10 @@ def __init__(self, *, rngs: nnx.Rngs):
assert modules[1][0] == 'linear'
assert isinstance(modules[1][1], nnx.Linear)

def test_array_in_module(self):
class Foo(nnx.Module):
def __init__(self):
self.a = jnp.array(1.0)

foo = Foo()

graphdef, state = nnx.split(foo)

assert isinstance(state, nnx.State)
assert isinstance(state.a, jax.Array)

foo2 = nnx.merge(graphdef, state)

assert isinstance(foo2.a, jax.Array)

def test_state_in_module(self):
class Foo(nnx.Module):
def __init__(self):
self.a = nnx.State({'b': jnp.array(1.0)})
self.a = nnx.State({'b': nnx.Param(jnp.array(1.0))})

foo = Foo()

Expand All @@ -693,3 +677,6 @@ def __init__(self):
foo2 = nnx.merge(graphdef, state)

assert isinstance(foo2.a, nnx.State)

if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 8bed224

Please sign in to comment.