Flax 0.12.0

This release includes many updates and some important breaking changes to the NNX API.
Breaking Changes
Pytree Strict Attributes
`nnx.Pytree`, and therefore `nnx.Module`, is now stricter about attributes that contain Arrays and about changing the status of attributes. For example, the code below now fails:
```python
from flax import nnx
import jax
import jax.numpy as jnp

class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = [  # ERROR
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ]
    self.bias = None  # status = static
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform((3,)))  # ERROR
```
This happens for two reasons:
- JAX pytree structures that contain Arrays now have to be marked with `nnx.data`. Alternatively, if the container pytree is a `list` or a `dict`, you can use `nnx.List` or `nnx.Dict`, which additionally allow mixed "data" and "static" elements.
- Attributes will no longer automatically change their status; this now has to be done explicitly using `nnx.data` or `nnx.static` (see the sketch below). Additionally, assigning Arrays or structures with Arrays to static attributes is now an error, as they will not automatically change to data.
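For illustration, here is a minimal sketch of marking attribute status explicitly; the class and attribute names are made up, but the `nnx.data`/`nnx.static`/`nnx.List` usage follows the rules above:

```python
from flax import nnx
import jax.numpy as jnp

class Bar(nnx.Module):
  def __init__(self):
    # Mark as data on first assignment so Arrays can be stored here later.
    self.cache = nnx.data(None)
    # Mark as static explicitly; assigning Arrays to it would be an error.
    self.label = nnx.static('bar')
    # nnx.List can mix data elements (Arrays) and static elements (strings).
    self.items = nnx.List([jnp.zeros(3), 'metadata'])
```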
To fix the original example, create `layers` as an `nnx.List`, which is automatically recognized as data, and be explicit about `bias` being a data attribute on its first assignment by using `nnx.data`:
```python
class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = nnx.List([  # nnx.data also works, but List is recommended
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ])
    self.bias = nnx.data(None)
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform((3,)))
```
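As a quick sanity check, data attributes now show up when the module is flattened as a JAX pytree, while static attributes live in the treedef. A small sketch, assuming the fixed `Foo` above (exact leaf types are an implementation detail):

```python
model = Foo(use_bias=True, rngs=nnx.Rngs(0))

# The Linear kernels/biases and `bias` are data and contribute pytree
# leaves; static attributes are folded into the tree structure instead.
print(len(jax.tree.leaves(model)))
```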
For more information, check the Module & Pytree guide.
Eager Sharding
Variables will now eagerly shard their values when `sharding_names` metadata is provided. A mesh is required; it can be provided either by passing a `mesh` metadata attribute or by setting the global mesh context via `jax.set_mesh`. This simplifies sharding a Variable by moving it to construction time:
```python
jax.config.update('jax_num_cpu_devices', 8)

mesh = jax.make_mesh((2, 4), ('data', 'model'))
with jax.set_mesh(mesh):
  variable = nnx.Param(jnp.ones((16, 32)), sharding_names=(None, 'model'))

print(variable.value.sharding)  # NamedSharding with spec (None, 'model')
```
Eager sharding also applies when using the `nnx.with_partitioning` initializer decorator, and it automatically extends to the Optimizer. This means that both model and optimizer are sharded at construction, without the somewhat cumbersome `nnx.get_partition_spec` + `jax.lax.with_sharding_constraint` + `nnx.update` pattern (shown for contrast after the example below):
```python
import optax

with jax.set_mesh(mesh):
  linear = nnx.Linear(
    in_features=16, out_features=16, use_bias=False,
    kernel_init=nnx.with_partitioning(
      nnx.initializers.lecun_normal(), (None, 'model')
    ),
    rngs=nnx.Rngs(0),
  )
  optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)

print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)
```
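For reference, the pattern this replaces looked roughly like the sketch below (not verbatim from this release; it follows the existing GSPMD guide, reusing `linear` from above):

```python
# Previous approach: extract the state, compute partition specs, apply a
# sharding constraint (typically inside jax.jit), and write the result back.
state = nnx.state(linear)
pspecs = nnx.get_partition_spec(state)
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
nnx.update(linear, sharded_state)
```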
For projects that currently rely on other means of sharding, eager sharding can be turned off by passing `eager_sharding=False` to the Variable constructor, either directly or through initializer decorators like `nnx.with_partitioning`:
```python
linear = nnx.Linear(
  in_features=16, out_features=16, use_bias=False,
  kernel_init=nnx.with_partitioning(
    nnx.initializers.lecun_normal(), (None, 'model'), eager_sharding=False
  ),
  rngs=nnx.Rngs(0),
)
optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)

print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)
```
Eager sharding can also be turned off globally via the `flax_always_shard_variable` config flag or the `FLAX_ALWAYS_SHARD_VARIABLE` environment variable:
```python
import flax

flax.config.update('flax_always_shard_variable', False)
```
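Equivalently, a sketch of using the environment variable; the assumption here is that it must be set before Flax is imported, since config flags read the environment when the module loads:

```python
import os

# Assumed to take effect only if set before `import flax`.
os.environ['FLAX_ALWAYS_SHARD_VARIABLE'] = 'false'

import flax
```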
For more information, check out the Variable eager sharding FLIP.
In-Place Operators No Longer Allowed
In-place operators on Variables will now raise an error. This is part of the push for Variables to be compatible with `Tracer` semantics:
```python
w = nnx.Variable(jnp.array(0))
w += 1  # ERROR
```
The fix is simply to operate on the `.value` property instead:
```python
w.value += 1
```
All Changes
- Doc fix: remove dead link to pre-Orbax checkpointing. by @copybara-service[bot] in #4914
- Fix typo in unflatten docs by @copybara-service[bot] in #4918
- fix RNN by @copybara-service[bot] in #4917
- Update optimizer.py to support masked variable from optax. by @ywrt in #4904
- Added missing functions to graph.rst by @vfdev-5 in #4922
- Update flax/docs_nnx/guides/performance.md and .ipynb by @hanrach9 in #4919
- Added preferred_element_type arg to nnx.Linear*, nnx.Conv*, nnx.Einsum by @vfdev-5 in #4920
- Update README badges and remove invalid ones by @IvyZX in #4905
- static + pytree guide by @cgarciae in #4897
- fix mypy by @copybara-service[bot] in #4931
- Avoid passing non-boolean mask to `where` argument of `jax.numpy` reductions. Non-boolean mask inputs have been deprecated for several releases, and will result in an error starting in JAX v0.8.0. by @copybara-service[bot] in #4923
- Ported nnx.PReLU from linen by @vfdev-5 in #4934
- Added nnx.scan docs and few minor docs fixes by @vfdev-5 in #4930
- add variables argument to nnx.clone by @cgarciae in #4945
- only copy dicts on State.getitem by @cgarciae in #4946
- always differentiate standalone Variables in nnx.grad by @cgarciae in #4947
- Implement instance norm in NNX by @mattbahr in #4939
- Automatically apply sharding constraints to sharded models by @IvyZX in #4844
- Add reference of flip doc to gspmd guide by @IvyZX in #4949
- Fixed nnx.is_data docstring rendering by @vfdev-5 in #4957
- expose pytree guide by @cgarciae in #4951
- fix toy examples by @cgarciae in #4952
- Explicitly cast attribute names to string before checking for private attributes. by @copybara-service[bot] in #4955
- add flax_hijax_variable flag by @cgarciae in #4953
- mark shard_map as implemented in transforms guide by @cgarciae in #4738
- improve Variable flatten by @cgarciae in #4954
- Minor typo fix in nnx.call docstring by @vfdev-5 in #4959
- allow split tuples in Rngs.fork by @cgarciae in #4958
- Fixed Gemma example using Gemma2 models by @vfdev-5 in #4830
- finish pytree guide by @cgarciae in #4929
- update bridge wrappers from maxtext by @cgarciae in #4937
- fix HashableMapping hash definition for mixed key types by @copybara-service[bot] in #4936
- Flax RNG guide for jax.jit: clarify rng outputs are shared but not inputs. by @copybara-service[bot] in #4956
- fix Variable pytree flatten by @copybara-service[bot] in #4962
- import PathParts from flax.typing by @cgarciae in #4966
- Correctly expose `flax.config.temp_flip_flag` by @IvyZX in #4969
- raise on Variable inplace operators by @cgarciae in #4967
- Copybara import of the project: by @copybara-service[bot] in #4976
- update to version 0.12.0 by @cgarciae in #4982
- Minor typo fixes in flax gspmd guide by @vfdev-5 in #4970
- ignore uv.lock by @copybara-service[bot] in #4974
- [nnx] preserve the function's type information in jit by @cgarciae in #4981
- add Variable.set_metadata by @cgarciae in #4968
- propagate eager sharding by @cgarciae in #4983
New Contributors
Full Changelog: v0.11.2...v0.12.0