Skip to content

Commit

Permalink
[nnx] add nnx.pytree
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 5, 2024
1 parent 90715be commit 6a34d62
Show file tree
Hide file tree
Showing 14 changed files with 262 additions and 115 deletions.
5 changes: 4 additions & 1 deletion flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .nnx.graph import GraphState as GraphState
from .nnx.graph import PureState as PureState
from .nnx.object import Object as Object
from .nnx.object import pytree as pytree
from .nnx.helpers import Dict as Dict
from .nnx.helpers import List as List
from .nnx.helpers import Sequential as Sequential
Expand Down Expand Up @@ -161,4 +162,6 @@
from .nnx.variables import VariableMetadata as VariableMetadata
from .nnx.variables import with_metadata as with_metadata
from .nnx.visualization import display as display
from .nnx.extract import to_tree, from_tree, TreeNode
from .nnx.extract import to_tree as to_tree
from .nnx.extract import from_tree as from_tree
from .nnx.extract import TreeNode as TreeNode
4 changes: 1 addition & 3 deletions flax/nnx/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,7 @@ def is_initializing(self) -> bool:

return self._object__state._initializing

def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__(experimental_pytree=experimental_pytree)

def __init_subclass__(cls) -> None:
cls = dataclasses.dataclass(repr=False)(cls)


Expand Down
3 changes: 2 additions & 1 deletion flax/nnx/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from flax import linen
from flax.core import meta
from flax.nnx.nnx import graph
from flax.nnx.nnx.graph import GraphDef
from flax.nnx.nnx.bridge import variables as bv
from flax.nnx.nnx.module import GraphDef, Module
from flax.nnx.nnx.module import Module
from flax.nnx.nnx.rnglib import Rngs
from flax.nnx.nnx.state import State
from flax.nnx.nnx.object import Object
Expand Down
20 changes: 8 additions & 12 deletions flax/nnx/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from flax import struct
from flax.nnx.nnx.object import Object
from flax.typing import MISSING, PathParts
from flax.nnx.nnx import graph
from flax.nnx.nnx import graph, safe_tree


A = tp.TypeVar('A')
Expand Down Expand Up @@ -75,7 +75,7 @@ def extract_graph_nodes(
pytree,
prefix_is_leaf=lambda x: x is None,
)
key_leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree)
key_leaves, treedef = safe_tree.flatten_with_path(pytree)

assert len(key_leaves) == len(prefix_leaves)

Expand Down Expand Up @@ -115,7 +115,7 @@ def _maybe_insert(x):
return nodes[x.index]
return x

return jax.tree.map(
return safe_tree.map(
_maybe_insert, pytree, is_leaf=lambda x: isinstance(x, ExtractionIndex)
)

Expand Down Expand Up @@ -225,11 +225,9 @@ def broadcast_prefix(
# ValueError; use prefix_errors to find disagreements and raise more precise
# error messages.
result = []
num_leaves = lambda t: jax.tree_util.tree_structure(
t, is_leaf=tree_is_leaf
).num_leaves
num_leaves = lambda t: safe_tree.structure(t, is_leaf=tree_is_leaf).num_leaves
add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=prefix_is_leaf)
safe_tree.map(add_leaves, prefix_tree, full_tree, is_leaf=prefix_is_leaf)
return result


Expand Down Expand Up @@ -342,7 +340,7 @@ def to_tree(
tree,
prefix_is_leaf=lambda x: x is None,
)
leaf_keys, treedef = jax.tree_util.tree_flatten_with_path(tree)
leaf_keys, treedef = safe_tree.flatten_with_path(tree)

assert len(leaf_keys) == len(leaf_prefixes)
leaves_out = []
Expand Down Expand Up @@ -396,9 +394,7 @@ def from_tree(
prefix_is_leaf=lambda x: x is None or is_leaf(x),
tree_is_leaf=is_leaf,
)
leaf_keys, treedef = jax.tree_util.tree_flatten_with_path(
tree, is_leaf=is_leaf
)
leaf_keys, treedef = safe_tree.flatten_with_path(tree, is_leaf=is_leaf)
assert len(leaf_keys) == len(leaf_prefixes)
leaves_out = []

Expand All @@ -416,4 +412,4 @@ def from_tree(
return pytree_out

def clear_non_graph_nodes(tree):
return jax.tree.map(lambda x: x if graph.is_graph_node(x) else None, tree)
return safe_tree.map(lambda x: x if graph.is_graph_node(x) else None, tree)
11 changes: 9 additions & 2 deletions flax/nnx/nnx/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
import optax

from flax.nnx.nnx.graph import Key
from flax.nnx.nnx.module import GraphDef, Module
from flax.nnx.nnx.module import Module
from flax.nnx.nnx.graph import GraphDef
from flax.nnx.nnx.object import ObjectStaticMetadata
from flax.nnx.nnx.proxy_caller import ApplyCaller
from flax.nnx.nnx.rnglib import Rngs
from flax.nnx.nnx.state import State
Expand Down Expand Up @@ -111,7 +113,12 @@ def _graph_node_flatten(self):
if key not in ('_object__state', '_length')
)
nodes.append(('_length', self._length))
return nodes, (type(self), self._object__state._initializing)
metadata = ObjectStaticMetadata(
type=type(self),
initializing=self._object__state._initializing,
is_pytree=self._object__state._is_pytree,
)
return nodes, metadata

def _graph_node_set_key(self, key: Key, value: tp.Any):
if isinstance(key, int):
Expand Down
56 changes: 9 additions & 47 deletions flax/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,15 @@
from __future__ import annotations

import typing as tp
from functools import partial

import jax.tree_util as jtu

from flax.nnx.nnx import (
filterlib,
graph,
)
from flax.nnx.nnx import variables as variableslib
from flax.nnx.nnx.graph import GraphDef
from flax.nnx.nnx.object import Object, ObjectMeta
from flax.nnx.nnx.graph import GraphState, StateLeaf
from flax.nnx.nnx.state import State
from flax.nnx.nnx.graph import GraphState
from flax.typing import Key, Path, PathParts

A = tp.TypeVar('A')
Expand Down Expand Up @@ -392,58 +388,24 @@ def eval(self, **attributes):
raise_if_not_found=False,
)

def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__()

if experimental_pytree:
jtu.register_pytree_with_keys(
cls,
partial(_module_flatten, with_keys=True),
_module_unflatten, # type: ignore[arg-type]
flatten_func=partial(_module_flatten, with_keys=False),
)

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]

children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
color=treescope.formatting_util.color_from_string(
type(self).__qualname__
)
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
color=treescope.formatting_util.color_from_string(
type(self).__qualname__
),
)

# -------------------------
# Pytree Definition
# -------------------------
def _module_flatten(module: Module, *, with_keys: bool):
graphdef, state = graph.split(module)
key_values = sorted(state.raw_mapping.items())
keys = tuple(key for key, _ in key_values)

children: tuple[tp.Any, ...]
if with_keys:
children = tuple((jtu.DictKey(key), value) for key, value in key_values)
else:
children = tuple(value for _, value in key_values)

return children, (keys, graphdef)


def _module_unflatten(
paths_moduledef: tuple[tuple[Path, ...], GraphDef[M]],
variables: tuple[StateLeaf, ...],
) -> M:
paths, graphdef = paths_moduledef
return graph.merge(graphdef, State(zip(paths, variables)))


def first_from(*args: tp.Optional[A], error_msg: str) -> A:
"""Return the first non-None argument.
Expand Down
Loading

0 comments on commit 6a34d62

Please sign in to comment.