diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index b35c912bee..91e2cbf408 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -42,6 +42,7 @@ from .nnx.graph import GraphState as GraphState from .nnx.graph import PureState as PureState from .nnx.object import Object as Object +from .nnx.object import pytree as pytree from .nnx.helpers import Dict as Dict from .nnx.helpers import List as List from .nnx.helpers import Sequential as Sequential @@ -161,4 +162,6 @@ from .nnx.variables import VariableMetadata as VariableMetadata from .nnx.variables import with_metadata as with_metadata from .nnx.visualization import display as display -from .nnx.extract import to_tree, from_tree, TreeNode +from .nnx.extract import to_tree as to_tree +from .nnx.extract import from_tree as from_tree +from .nnx.extract import TreeNode as TreeNode diff --git a/flax/nnx/nnx/bridge/module.py b/flax/nnx/nnx/bridge/module.py index 69c964a9ee..c8e92153fb 100644 --- a/flax/nnx/nnx/bridge/module.py +++ b/flax/nnx/nnx/bridge/module.py @@ -199,9 +199,7 @@ def is_initializing(self) -> bool: return self._object__state._initializing - def __init_subclass__(cls, experimental_pytree: bool = False) -> None: - super().__init_subclass__(experimental_pytree=experimental_pytree) - + def __init_subclass__(cls) -> None: cls = dataclasses.dataclass(repr=False)(cls) diff --git a/flax/nnx/nnx/bridge/wrappers.py b/flax/nnx/nnx/bridge/wrappers.py index 6878be0a19..c206671a88 100644 --- a/flax/nnx/nnx/bridge/wrappers.py +++ b/flax/nnx/nnx/bridge/wrappers.py @@ -20,8 +20,9 @@ from flax import linen from flax.core import meta from flax.nnx.nnx import graph +from flax.nnx.nnx.graph import GraphDef from flax.nnx.nnx.bridge import variables as bv -from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.module import Module from flax.nnx.nnx.rnglib import Rngs from flax.nnx.nnx.state import State from flax.nnx.nnx.object import Object diff --git a/flax/nnx/nnx/extract.py b/flax/nnx/nnx/extract.py index cad3b4113d..da9a75a610 100644 --- a/flax/nnx/nnx/extract.py +++ b/flax/nnx/nnx/extract.py @@ -23,7 +23,7 @@ from flax import struct from flax.nnx.nnx.object import Object from flax.typing import MISSING, PathParts -from flax.nnx.nnx import graph +from flax.nnx.nnx import graph, safe_tree A = tp.TypeVar('A') @@ -75,7 +75,7 @@ def extract_graph_nodes( pytree, prefix_is_leaf=lambda x: x is None, ) - key_leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree) + key_leaves, treedef = safe_tree.flatten_with_path(pytree) assert len(key_leaves) == len(prefix_leaves) @@ -115,7 +115,7 @@ def _maybe_insert(x): return nodes[x.index] return x - return jax.tree.map( + return safe_tree.map( _maybe_insert, pytree, is_leaf=lambda x: isinstance(x, ExtractionIndex) ) @@ -225,11 +225,9 @@ def broadcast_prefix( # ValueError; use prefix_errors to find disagreements and raise more precise # error messages. result = [] - num_leaves = lambda t: jax.tree_util.tree_structure( - t, is_leaf=tree_is_leaf - ).num_leaves + num_leaves = lambda t: safe_tree.structure(t, is_leaf=tree_is_leaf).num_leaves add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree)) - jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=prefix_is_leaf) + safe_tree.map(add_leaves, prefix_tree, full_tree, is_leaf=prefix_is_leaf) return result @@ -342,7 +340,7 @@ def to_tree( tree, prefix_is_leaf=lambda x: x is None, ) - leaf_keys, treedef = jax.tree_util.tree_flatten_with_path(tree) + leaf_keys, treedef = safe_tree.flatten_with_path(tree) assert len(leaf_keys) == len(leaf_prefixes) leaves_out = [] @@ -396,9 +394,7 @@ def from_tree( prefix_is_leaf=lambda x: x is None or is_leaf(x), tree_is_leaf=is_leaf, ) - leaf_keys, treedef = jax.tree_util.tree_flatten_with_path( - tree, is_leaf=is_leaf - ) + leaf_keys, treedef = safe_tree.flatten_with_path(tree, is_leaf=is_leaf) assert len(leaf_keys) == len(leaf_prefixes) leaves_out = [] @@ -416,4 +412,4 @@ def from_tree( return pytree_out def clear_non_graph_nodes(tree): - return jax.tree.map(lambda x: x if graph.is_graph_node(x) else None, tree) \ No newline at end of file + return safe_tree.map(lambda x: x if graph.is_graph_node(x) else None, tree) diff --git a/flax/nnx/nnx/helpers.py b/flax/nnx/nnx/helpers.py index fe2ad8daf3..1d758bf711 100644 --- a/flax/nnx/nnx/helpers.py +++ b/flax/nnx/nnx/helpers.py @@ -35,7 +35,9 @@ import optax from flax.nnx.nnx.graph import Key -from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.module import Module +from flax.nnx.nnx.graph import GraphDef +from flax.nnx.nnx.object import ObjectStaticMetadata from flax.nnx.nnx.proxy_caller import ApplyCaller from flax.nnx.nnx.rnglib import Rngs from flax.nnx.nnx.state import State @@ -111,7 +113,12 @@ def _graph_node_flatten(self): if key not in ('_object__state', '_length') ) nodes.append(('_length', self._length)) - return nodes, (type(self), self._object__state._initializing) + metadata = ObjectStaticMetadata( + type=type(self), + initializing=self._object__state._initializing, + is_pytree=self._object__state._is_pytree, + ) + return nodes, metadata def _graph_node_set_key(self, key: Key, value: tp.Any): if isinstance(key, int): diff --git a/flax/nnx/nnx/module.py b/flax/nnx/nnx/module.py index 48e8987a05..208ebecf82 100644 --- a/flax/nnx/nnx/module.py +++ b/flax/nnx/nnx/module.py @@ -15,19 +15,15 @@ from __future__ import annotations import typing as tp -from functools import partial -import jax.tree_util as jtu from flax.nnx.nnx import ( filterlib, graph, ) from flax.nnx.nnx import variables as variableslib -from flax.nnx.nnx.graph import GraphDef from flax.nnx.nnx.object import Object, ObjectMeta -from flax.nnx.nnx.graph import GraphState, StateLeaf -from flax.nnx.nnx.state import State +from flax.nnx.nnx.graph import GraphState from flax.typing import Key, Path, PathParts A = tp.TypeVar('A') @@ -392,58 +388,24 @@ def eval(self, **attributes): raise_if_not_found=False, ) - def __init_subclass__(cls, experimental_pytree: bool = False) -> None: - super().__init_subclass__() - - if experimental_pytree: - jtu.register_pytree_with_keys( - cls, - partial(_module_flatten, with_keys=True), - _module_unflatten, # type: ignore[arg-type] - flatten_func=partial(_module_flatten, with_keys=False), - ) - def __treescope_repr__(self, path, subtree_renderer): import treescope # type: ignore[import-not-found,import-untyped] + children = {} for name, value in vars(self).items(): if name.startswith('_'): continue children[name] = value return treescope.repr_lib.render_object_constructor( - object_type=type(self), - attributes=children, - path=path, - subtree_renderer=subtree_renderer, - color=treescope.formatting_util.color_from_string( - type(self).__qualname__ - ) + object_type=type(self), + attributes=children, + path=path, + subtree_renderer=subtree_renderer, + color=treescope.formatting_util.color_from_string( + type(self).__qualname__ + ), ) -# ------------------------- -# Pytree Definition -# ------------------------- -def _module_flatten(module: Module, *, with_keys: bool): - graphdef, state = graph.split(module) - key_values = sorted(state.raw_mapping.items()) - keys = tuple(key for key, _ in key_values) - - children: tuple[tp.Any, ...] - if with_keys: - children = tuple((jtu.DictKey(key), value) for key, value in key_values) - else: - children = tuple(value for _, value in key_values) - - return children, (keys, graphdef) - - -def _module_unflatten( - paths_moduledef: tuple[tuple[Path, ...], GraphDef[M]], - variables: tuple[StateLeaf, ...], -) -> M: - paths, graphdef = paths_moduledef - return graph.merge(graphdef, State(zip(paths, variables))) - def first_from(*args: tp.Optional[A], error_msg: str) -> A: """Return the first non-None argument. diff --git a/flax/nnx/nnx/object.py b/flax/nnx/nnx/object.py index 78fed06133..c404927964 100644 --- a/flax/nnx/nnx/object.py +++ b/flax/nnx/nnx/object.py @@ -15,6 +15,7 @@ from __future__ import annotations import dataclasses +from functools import partial import threading import typing as tp from abc import ABCMeta @@ -30,10 +31,11 @@ tracers, ) from flax.nnx.nnx import graph +from flax.nnx.nnx.state import State from flax.nnx.nnx.variables import Variable, VariableState from flax.typing import Key -G = tp.TypeVar('G', bound='Object') +O = tp.TypeVar('O', bound='Object') @dataclasses.dataclass @@ -43,13 +45,31 @@ class GraphUtilsContext(threading.local): CONTEXT = GraphUtilsContext() +@dataclasses.dataclass(frozen=True, repr=False) +class Leaf(tp.Generic[O], reprlib.Representable): + obj: O + + def __nnx_repr__(self): + yield reprlib.Object(type(self)) + yield reprlib.Attr('obj', self.obj) + + def __treescope_repr__(self, path, subtree_renderer): + import treescope # type: ignore[import-not-found,import-untyped] + + return treescope.repr_lib.render_object_constructor( + object_type=type(self), + attributes={'obj': self.obj}, + path=path, + subtree_renderer=subtree_renderer, + ) class ObjectState(reprlib.Representable): - __slots__ = ('_trace_state', '_initializing') + __slots__ = ('_trace_state', '_initializing', '_is_pytree') - def __init__(self, initializing: bool = False): + def __init__(self, initializing: bool, is_pytree: bool): self._trace_state = tracers.TraceState() self._initializing = initializing + self._is_pytree = is_pytree @property def trace_state(self) -> tracers.TraceState: @@ -59,6 +79,10 @@ def trace_state(self) -> tracers.TraceState: def initializing(self) -> bool: return self._initializing + @property + def is_pytree(self) -> bool: + return self._is_pytree + def __nnx_repr__(self): yield reprlib.Object(type(self)) yield reprlib.Attr('trace_state', self._trace_state) @@ -82,11 +106,12 @@ def _object_meta_construct(cls, self, *args, **kwargs): self.__init__(*args, **kwargs) -def _graph_node_meta_call(cls: tp.Type[G], *args, **kwargs) -> G: +def _graph_node_meta_call(cls: tp.Type[O], *args, **kwargs) -> O: node = cls.__new__(cls, *args, **kwargs) - vars(node)['_object__state'] = ObjectState() + vars(node)['_object__state'] = ObjectState( + initializing=False, is_pytree=cls._object__is_pytree + ) cls._object_meta_construct(node, *args, **kwargs) - return node @@ -100,6 +125,8 @@ def __repr__(self): class Object(reprlib.Representable, metaclass=ObjectMeta): + _object__is_pytree: bool = False + if tp.TYPE_CHECKING: _object__state: ObjectState @@ -115,6 +142,13 @@ def __init_subclass__(cls) -> None: clear=cls._graph_node_clear, ) + jax.tree_util.register_pytree_with_keys( + cls, + partial(_flatten_object, with_keys=True), # type: ignore + _unflatten_object, # type: ignore + flatten_func=partial(_flatten_object, with_keys=False), # type: ignore + ) + if not tp.TYPE_CHECKING: def __setattr__(self, name: str, value: Any) -> None: @@ -130,7 +164,7 @@ def check_valid_context(self, error_msg: tp.Callable[[], str]) -> None: if not self._object__state.trace_state.is_valid(): raise errors.TraceContextError(error_msg()) - def __deepcopy__(self: G, memo=None) -> G: + def __deepcopy__(self: O, memo=None) -> O: graphdef, state = graph.split(self) graphdef = deepcopy(graphdef) state = deepcopy(state) @@ -194,7 +228,12 @@ def _graph_node_flatten(self): for key, value in vars(self).items() if key != '_object__state' ) - return nodes, (type(self), self._object__state._initializing) + metadata = ObjectStaticMetadata( + type=type(self), + initializing=self._object__state._initializing, + is_pytree=self._object__state._is_pytree, + ) + return nodes, metadata def _graph_node_set_key(self, key: Key, value: tp.Any): if not isinstance(key, str): @@ -214,10 +253,13 @@ def _graph_node_pop_key(self, key: Key): return vars(self).pop(key) @staticmethod - def _graph_node_create_empty(static: tuple[tp.Type[G], bool]) -> G: - node_type, initializing = static - node = object.__new__(node_type) - vars(node).update(_object__state=ObjectState(initializing)) + def _graph_node_create_empty(metadata: ObjectStaticMetadata[O]) -> O: + node = object.__new__(metadata.type) + vars(node).update( + _object__state=ObjectState( + initializing=metadata.initializing, is_pytree=metadata.is_pytree + ) + ) return node def _graph_node_clear(self): @@ -225,3 +267,66 @@ def _graph_node_clear(self): module_vars = vars(self) module_vars.clear() module_vars['_object__state'] = module_state + +@dataclasses.dataclass(frozen=True) +class ObjectStaticMetadata(tp.Generic[O]): + type: tp.Type[O] + initializing: bool + is_pytree: bool + + +# ------------------------- +# Pytree Definition +# ------------------------- +def _flatten_object(obj: Object, *, with_keys: bool): + is_pytree = obj._object__state._is_pytree + if is_pytree: + graphdef, state = graph.split(obj) + key_values = sorted(state.raw_mapping.items()) + keys = tuple(key for key, _ in key_values) + + nodes: tuple[tp.Any, ...] + if with_keys: + nodes = tuple( + (jax.tree_util.GetAttrKey(str(key)), value) for key, value in key_values + ) + else: + nodes = tuple(value for _, value in key_values) + + return nodes, (keys, graphdef) + else: + if with_keys: + nodes = ((jax.tree_util.GetAttrKey('leaf'), Leaf(obj)),) + return nodes, None + + +def _unflatten_object( + metadata: tuple[tuple[Key, ...], graph.GraphDef[O]] | None, + children: tuple[tp.Any, ...] | tuple[Leaf[O]], +) -> O: + if metadata is None: + if len(children) != 1: + raise ValueError(f'Expected 1 child, got {len(children)}') + elif not isinstance(children[0], Leaf): + raise ValueError(f'Expected Leaf, got {type(children[0])}') + return children[0].obj + else: + _children = tp.cast(tuple[tp.Any, ...], children) + paths, graphdef = metadata + return graph.merge(graphdef, State(zip(paths, _children))) + +# ------------------------- +# pytree API +# ------------------------- +@tp.overload +def pytree(node_or_class: tp.Type[O]) -> tp.Type[O]: ... +@tp.overload +def pytree(node_or_class: O) -> O: ... +def pytree(node_or_class: Object | type[Object]): + if isinstance(node_or_class, type): + node_or_class._object__is_pytree = True + return node_or_class + else: + obj = graph.clone(node_or_class) + obj._object__state._is_pytree = True + return obj \ No newline at end of file diff --git a/flax/nnx/nnx/safe_tree.py b/flax/nnx/nnx/safe_tree.py new file mode 100644 index 0000000000..b0d1772dc9 --- /dev/null +++ b/flax/nnx/nnx/safe_tree.py @@ -0,0 +1,71 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp +from flax.nnx.nnx import graph + +import jax + + +def _get_is_leaf(is_leaf): + def _is_leaf_or_graph_node(x): + if graph.is_graph_node(x): + return True + elif is_leaf is None: + return False + else: + return is_leaf(x) + + return _is_leaf_or_graph_node + + +if tp.TYPE_CHECKING: + map = jax.tree_util.tree_map + leaves = jax.tree_util.tree_leaves + flatten = jax.tree_util.tree_flatten + structure = jax.tree_util.tree_structure + flatten_with_path = jax.tree_util.tree_flatten_with_path + leaves_with_path = jax.tree_util.tree_leaves_with_path + map_with_path = jax.tree_util.tree_map_with_path +else: + + def map(*args, is_leaf=None, **kwargs): + is_leaf = _get_is_leaf(is_leaf) + return jax.tree_util.tree_map(*args, is_leaf=is_leaf, **kwargs) + + def leaves(*args, is_leaf=None, **kwargs): + is_leaf = _get_is_leaf(is_leaf) + return jax.tree_util.tree_leaves(*args, is_leaf=is_leaf, **kwargs) + + def flatten(*args, is_leaf=None, **kwargs): + is_leaf = _get_is_leaf(is_leaf) + return jax.tree_util.tree_flatten(*args, is_leaf=is_leaf, **kwargs) + + def structure(*args, is_leaf=None, **kwargs): + is_leaf = _get_is_leaf(is_leaf) + return jax.tree_util.tree_structure(*args, is_leaf=is_leaf, **kwargs) + + def flatten_with_path(*args, is_leaf=None, **kwargs): + is_leaf = _get_is_leaf(is_leaf) + return jax.tree_util.tree_flatten_with_path( + *args, is_leaf=is_leaf, **kwargs + ) + + def leaves_with_path(*args, is_leaf=None, **kwargs): + is_leaf = _get_is_leaf(is_leaf) + return jax.tree_util.tree_leaves_with_path(*args, is_leaf=is_leaf, **kwargs) + + def map_with_path(*args, is_leaf=None, **kwargs): + is_leaf = _get_is_leaf(is_leaf) + return jax.tree_util.tree_map_with_path(*args, is_leaf=is_leaf, **kwargs) diff --git a/flax/nnx/nnx/transforms/autodiff.py b/flax/nnx/nnx/transforms/autodiff.py index 5795d51ca8..171d76b18b 100644 --- a/flax/nnx/nnx/transforms/autodiff.py +++ b/flax/nnx/nnx/transforms/autodiff.py @@ -23,6 +23,7 @@ extract, filterlib, graph, + safe_tree, variables, ) from flax.nnx.nnx.state import State @@ -164,7 +165,7 @@ def _grad_split_fn( fn_out = gradded_fn(*pure_args) def process_grads(grads): - return jax.tree.map( + return safe_tree.map( lambda x: x.state if isinstance(x, extract.TreeNode) else x, grads, is_leaf=lambda x: isinstance(x, extract.TreeNode), @@ -490,7 +491,7 @@ def __call__(self, *args): metadata, pure_residual = res nondiff = extract.from_tree(nondiff) residual = extract.from_tree(pure_residual) - pure_g = jax.tree.map( + pure_g = safe_tree.map( lambda x: x.state if isinstance(x, extract.TreeNode) else x, pure_g, is_leaf=lambda x: isinstance(x, extract.TreeNode), diff --git a/flax/nnx/nnx/transforms/compilation.py b/flax/nnx/nnx/transforms/compilation.py index 766f25ec81..d42868dcea 100644 --- a/flax/nnx/nnx/transforms/compilation.py +++ b/flax/nnx/nnx/transforms/compilation.py @@ -35,6 +35,7 @@ extract, filterlib, graph, + safe_tree, ) import jax import jax.core @@ -303,13 +304,13 @@ def jit( abstracted_axes=abstracted_axes, ) # type: ignore[return-value] kwarg_shardings = None - jax_in_shardings = jax.tree.map( + jax_in_shardings = safe_tree.map( lambda x: extract.TreeNode.from_prefixes(x.shardings, metadata=x) if isinstance(x, StateSharding) else x, in_shardings, ) - jax_out_shardings = jax.tree.map( + jax_out_shardings = safe_tree.map( lambda x: extract.TreeNode.from_prefixes(x.shardings, metadata=x) if isinstance(x, StateSharding) else x, diff --git a/flax/nnx/nnx/transforms/deprecated.py b/flax/nnx/nnx/transforms/deprecated.py index 3dd826b254..ea63893317 100644 --- a/flax/nnx/nnx/transforms/deprecated.py +++ b/flax/nnx/nnx/transforms/deprecated.py @@ -35,7 +35,8 @@ from flax import struct from flax.core.frozen_dict import FrozenDict from flax.nnx.nnx import extract, filterlib, graph, rnglib, spmd, variables -from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.module import Module +from flax.nnx.nnx.graph import GraphDef from flax.nnx.nnx.proxy_caller import DelayedAccessor from flax.nnx.nnx.state import State from flax.nnx.nnx.transforms.transforms import LiftedModule diff --git a/flax/nnx/nnx/transforms/iteration.py b/flax/nnx/nnx/transforms/iteration.py index 4315f1a379..9059603fa6 100644 --- a/flax/nnx/nnx/transforms/iteration.py +++ b/flax/nnx/nnx/transforms/iteration.py @@ -34,7 +34,7 @@ from flax import struct from flax.core.frozen_dict import FrozenDict -from flax.nnx.nnx import extract, filterlib, graph, spmd +from flax.nnx.nnx import extract, filterlib, graph, safe_tree, spmd from flax.nnx.nnx.module import Module from flax.nnx.nnx.state import State from flax.nnx.nnx.transforms.transforms import resolve_kwargs @@ -134,8 +134,8 @@ def _update_axes_fn(tree_node): return tree_node.replace(graphdef_states=tuple(graphdef_states_out)) return tree_node - return jax.tree.map( - _update_axes_fn, tree, is_leaf=lambda x: isinstance(x, extract.TreeNode) + return safe_tree.map( + _update_axes_fn, tree, is_leaf=lambda x: isinstance(x, extract.TreeNode) ) @@ -311,17 +311,17 @@ def vmap( transform_metadata=transform_metadata, ) # type: ignore[return-value] - jax_in_axes = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) - if isinstance(x, StateAxes) - else x, - in_axes, + jax_in_axes = safe_tree.map( + lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + in_axes, ) - jax_out_axes = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) - if isinstance(x, StateAxes) - else x, - out_axes, + jax_out_axes = safe_tree.map( + lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + out_axes, ) vmapped_fn = jax.vmap( VmapFn(f, transform_metadata, in_axes, out_axes), @@ -531,17 +531,17 @@ def pmap( transform_metadata=transform_metadata, ) # type: ignore[return-value] - jax_in_axes = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) - if isinstance(x, StateAxes) - else x, - in_axes, + jax_in_axes = safe_tree.map( + lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + in_axes, ) - jax_out_axes = jax.tree.map( - lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) - if isinstance(x, StateAxes) - else x, - out_axes, + jax_out_axes = safe_tree.map( + lambda x: extract.TreeNode.from_prefixes(x.axes, metadata=x) + if isinstance(x, StateAxes) + else x, + out_axes, ) pmapped_fn = jax.pmap( PmapFn(f, transform_metadata, in_axes, out_axes), @@ -586,7 +586,7 @@ def _get_carry_argnum(axes, is_in_axes: bool): obj_repr = 'in_axes' if is_in_axes else 'out_axes' carry_argnum: int | None = None prev_key: tp.Any = None - for key, x in jax.tree_util.tree_leaves_with_path(axes): + for key, x in safe_tree.leaves_with_path(axes): if x is not Carry: continue assert isinstance(key[0], jax.tree_util.SequenceKey) @@ -641,9 +641,7 @@ def check_carry_same_references(key_path, arg, out): f'at carry{jax.tree_util.keystr(key_path)}' ) - jax.tree_util.tree_map_with_path( - check_carry_same_references, carry_arg, carry_arg_out - ) + safe_tree.map_with_path(check_carry_same_references, carry_arg, carry_arg_out) def _extract_index_mappings( pure_carry_arg_out, @@ -662,7 +660,7 @@ def extract_index_mappings(x): ) return x - pure_carry_arg_out = jax.tree.map( + pure_carry_arg_out = safe_tree.map( extract_index_mappings, pure_carry_arg_out, is_leaf=lambda x: isinstance(x, extract.GraphDefState), @@ -685,7 +683,7 @@ def insert_index_mappings(x): ) return x - pure_carry_arg_out = jax.tree.map( + pure_carry_arg_out = safe_tree.map( insert_index_mappings, pure_carry_arg_out, is_leaf=lambda x: isinstance(x, extract.GraphDefState), diff --git a/flax/nnx/tests/graph_utils_test.py b/flax/nnx/tests/graph_utils_test.py index 57b0f2e3c1..d6b392af14 100644 --- a/flax/nnx/tests/graph_utils_test.py +++ b/flax/nnx/tests/graph_utils_test.py @@ -794,7 +794,8 @@ class SimpleModule(nnx.Module): pass -class SimplePyTreeModule(nnx.Module, experimental_pytree=True): +@nnx.pytree +class SimplePyTreeModule(nnx.Module): pass diff --git a/flax/nnx/tests/module_test.py b/flax/nnx/tests/module_test.py index d5aeae08cd..ca6835f733 100644 --- a/flax/nnx/tests/module_test.py +++ b/flax/nnx/tests/module_test.py @@ -486,7 +486,8 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): class TestModulePytree: def test_tree_map(self): - class Foo(nnx.Module, experimental_pytree=True): + @nnx.pytree + class Foo(nnx.Module): def __init__(self): self.node = nnx.Param(1) self.graphdef = 1 @@ -499,7 +500,8 @@ def __init__(self): assert m.graphdef == 1 def test_static(self): - class C(nnx.Module, experimental_pytree=True): + @nnx.pytree + class C(nnx.Module): def __init__(self, x): self.x = x