Flax 0.12.0 includes many updates and some important breaking changes to the NNX API.
Breaking Changes
Pytree Strict Attributes
nnx.Pytree, and therefore nnx.Module, are now stricter about attributes that contain Arrays and about changes to an attribute's status (data vs. static). For example, the code below now fails:
This happens for two reasons:
1. JAX pytree structures that contain Arrays now have to be marked with nnx.data. Alternatively, if the container pytree is a list or a dict, you can use nnx.List or nnx.Dict, which additionally allow mixed "data" and "static" elements.
2. Attributes will no longer automatically change their status; this now has to be done explicitly using nnx.data or nnx.static. Additionally, assigning Arrays or structures with Arrays to static attributes is now an error, as they will not automatically change to data.
To fix the above, create layers as an nnx.List Module, which is automatically recognized as data, and make bias explicitly a data attribute on its first assignment by using nnx.data:
from flax import nnx

class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = nnx.List([  # nnx.data also works but List is recommended
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ])
    self.bias = nnx.data(None)
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform((3,)))
For more information, check the Module & Pytree guide.
Eager Sharding
Variables will now eagerly shard their values when sharding_names metadata is provided. A mesh is required: it can be provided either by passing a mesh metadata attribute or by setting the global mesh context via jax.set_mesh. This simplifies sharding by moving it to Variable construction time:
Eager sharding will also occur when using the nnx.with_partitioning initializer decorator and will automatically extend to the Optimizer. This means that both model and optimizer will be sharded at construction without the need for the somewhat cumbersome nnx.get_partition_spec + jax.lax.with_sharding_constraint + nnx.update pattern:
For projects that currently rely on other means for sharding, eager sharding can be turned off by passing eager_sharding=False to the Variable constructor, either directly or through initializer decorators like nnx.with_partitioning:
Eager sharding can also be turned off globally via the flax_always_shard_variable config flag or the FLAX_ALWAYS_SHARD_VARIABLE environment variable. For more information, check out the Variable eager sharding FLIP.
In-Place Operators No Longer Allowed
In-place operators will now raise an error. This is done as part of the push for Variables to be compatible with Tracer semantics. The fix is to simply operate on the .value property instead.
All Changes
- Avoid passing non-boolean mask to where argument of jax.numpy reductions. Non-boolean mask inputs have been deprecated for several releases, and will result in an error starting in JAX v0.8.0. by @copybara-service[bot] in #4923
- Correctly expose flax.config.temp_flip_flag by @IvyZX in #4969
New Contributors
Full Changelog: v0.11.2...v0.12.0
This discussion was created from the release 0.12.0.