Releases: google/flax

Version 0.10.1

31 Oct 23:09

What's Changed

  • Add Flax NNX GraphDef docstring by @8bitmp3 in #4302
  • Flesh out the Haiku/Flax guide by @IvyZX in #4305
  • [nnx] improve mnist tutorial by @cgarciae in #4316
  • Update Flax Evolution from Linen to NNX guide by @8bitmp3 in #4289
  • [nnx] try casting integers keys in State.replace_by_pure_dict by @cgarciae in #4317
  • Fixed nnx examples bad links in the README.md by @vfdev-5 in #4282
  • Fix philosophy link by @jorisSchaller in #4313
  • [nnx] add gemma notebook by @cgarciae in #4075
  • [nnx] improve init_cache docs by @cgarciae in #4291
  • remove markdown from section titles by @cgarciae in #4322
  • Avoid depending on JAX internals, which are about to change. by @copybara-service in #4326
  • Remove outdated compatibility code. by @jakevdp in #4324
  • fix ruff complaints by @levskaya in #4331
  • Remove GeGLU activation function and golden tests. by @copybara-service in #4303
  • Avoid using float32 in normalization for mean/var and scale/bias parameters when force_float32_reductions=False by @copybara-service in #4314
  • Avoid assert_array_equal on PRNG keys. by @jakevdp in #4332
  • Fix typos in Flax NNX Migrating from Haiku to Flax by @8bitmp3 in #4337
  • Add API reference for flax.nnx.nn and improve landing page by @IvyZX in #4338
  • [nnx] improve transforms guide by @cgarciae in #4333
  • [nnx] cleanup gemma notebook by @cgarciae in #4334
  • Remove non-lazy RNG compat mode and flag from flax. by @copybara-service in #4339
  • [nnx] fix custom_vjp by @cgarciae in #4306
  • Define model surgery in docs by @8bitmp3 in #4349
  • [nnx] update State and variables docstrings by @cgarciae in #4346
  • Add NNX transforms nnx.while_loop and nnx.switch by @IvyZX in #4343
  • update version to v0.10.1 by @cgarciae in #4345

New Contributors

Full Changelog: v0.10.0...v0.10.1

Version 0.10.0

16 Oct 23:25

What's Changed

  • [nnx] clear nnx basics pip logs by @cgarciae in #4149
  • Support linen <-> nnx metadata box converging in nnx.bridge by @IvyZX in #4145
  • Add nnx bridge API reference to site by @IvyZX in #4158
  • [nnx] use jax-style transforms API in nnx_basics by @cgarciae in #4155
  • [nnx] improve nnx.scan in_axes/out_axes by @cgarciae in #4157
  • Support direct quantization for FP8 matmul by @wenscarl in #3922
  • Upgrade Flax NNX Model Surgery by @8bitmp3 in #4135
  • [nnx] add more Variable proxy methods by @cgarciae in #4170
  • [nnx] disallow Array leaves by @copybara-service in #4172
  • Internal change by @copybara-service in #4176
  • [nnx] improve landing page and nnx_basics messaging by @cgarciae in #4168
  • Fixes a small bug in flax.linen.share_scope, where the scopes of children of the module being merged that were created before setup() were not being updated to point to the new scope, and so they would end up staying under the original tree. by @copybara-service in #4150
  • Move all NNX content up a level to be equal with Linen, to make python packaging more consistent. by @copybara-service in #4177
  • Add a guide for nnx.bridge by @IvyZX in #4171
  • [nnx] improve Optimizer metadata propagation by @cgarciae in #4180
  • [nnx] enable sharding transformation on integer prefixes by @cgarciae in #4185
  • Support linen.LogicallyPartitioned <-> nnx.Variable by @IvyZX in #4161
  • Clean up axis hooks in nnx.Variable by @IvyZX in #4189
  • Merge nnx.errors to flax.errors by @IvyZX in #4186
  • [nnx] optimize jit by @cgarciae in #4191
  • Split documentation for Linen and NNX by @cgarciae in #4192
  • Partially revert #4192 which sets back a bunch of previous merged pushes. by @copybara-service in #4201
  • Align bridge variable tree structures by @IvyZX in #4194
  • [NNX site] Fix landing page and banner phrasing and add examples page by @IvyZX in #4202
  • shorten banners by @cgarciae in #4206
  • Add trimmed Linen to NNX guide by @IvyZX in #4209
  • Minor documentation fixes for AxisMetadata. by @copybara-service in #4178
  • fix tests for numpy 2.0 compatibility by @copybara-service in #4215
  • Forward all arguments when using nnx.transforms.deprecated.scan as a decorator. by @copybara-service in #4208
  • [nnx] add transforms guide by @cgarciae in #4197
  • [nnx] fix transforms guide by @cgarciae in #4223
  • Flax NNX GSPMD guide by @IvyZX in #4220
  • Update libraries to use JAX's limited (and ill-advised) trace-state-querying APIs rather than depending on JAX's deeper internals, which are about to change. by @copybara-service in #4225
  • [nnx] add Randomness guide by @cgarciae in #4216
  • Add pure dict conversion util functions to nnx.State. by @IvyZX in #4230
  • [nnx] Simplify traversal by @cgarciae in #4205
  • Fix false positive tracer leaks in flax library. by @copybara-service in #4232
  • [nnx] add flaxlib by @copybara-service in #4235
  • [nnx] improve docs by @cgarciae in #4236
  • point nnx banner to flax-linen by @cgarciae in #4237
  • update banners by @cgarciae in #4238
  • Fix scale dtype and refactor q_dot_dq by @wenscarl in #4229
  • update banners by @cgarciae in #4241
  • Add redirects for Linen guide links in the NNX site scope. by @IvyZX in #4242
  • Internal change by @copybara-service in #4243
  • Copybara import of the project: by @copybara-service in #4245
  • Update Flax NNX Scale Up SPMD guide by @8bitmp3 in #4239
  • Upgrade Flax NNX basics doc by @8bitmp3 in #4173
  • Improve landing page, glossary and misc by @IvyZX in #4244
  • Nitting and adding links by @8bitmp3 in #4248
  • enable doctest on notebooks by @cgarciae in #4250
  • Update index.rst by @ariG23498 in #4251
  • Add NNX checkpointing guide by @IvyZX in #4249
  • Add checkpointing guide to website index. by @copybara-service in #4263
  • Update to Flax NNX Transforms doc by @8bitmp3 in #4264
  • Add why nnx by @cgarciae in #4240
  • [nnx] add cloudpickle support by @cgarciae in #4253
  • Fix typo: impost to import by @Vilin97 in #4256
  • [nnx] revive TrainState toy example by @cgarciae in #4226
  • [nnx] add custom_vjp to docs by @cgarciae in #4266
  • remove flax-nnx urls by @cgarciae in #4267
  • Add flatten to nnx.graph autosummary in graph.rst by @8bitmp3 in #4255
  • [nnx] add FSDP toy example with custom optimizer by @cgarciae in #4183
  • Update Flax NNX Landing Page by @8bitmp3 in #4274
  • Update to Flax NNX Model Surgery by @8bitmp3 in #4276
  • Update Why Flax NNX guide by @8bitmp3 in #4262
  • Update to Flax NNX MNIST tutorial by @8bitmp3 in #4277
  • [nnx] improve randomness guide by @cgarciae in #4281
  • Remove notebook exceptions in docs_nnx doctest by @IvyZX in #4285
  • [nnx] add PrefixMapping by @cgarciae in #4278
  • [nnx] state filters by @cgarciae in #4288
  • Fix devcontainer setup by @jorisSchaller in #4299
  • Upgrade Flax NNX Checkpointing guide by @8bitmp3 in #4294
  • Update Flax NNX Scale Up guide by @8bitmp3 in #4296
  • Porting RNN from Linen to NNX by @zinccat in #4272
  • Update Flax NNX Glossary by @8bitmp3 in #4284
  • update version to 0.10.0 by @cgarciae in #4292

New Contributors

Full Changelog: v0.9.0...v0.10.0

v0.9.0

27 Aug 17:51

What's Changed

  • Add NNX surgery guide by @IvyZX in #4005
  • Port gemma/transformer to NNX by @copybara-service in #4019
  • upgrade python to 3.10 + use pyupgrade by @cgarciae in #4038
  • [nnx] add Using Filters guide by @cgarciae in #4028
  • v0.8.6 by @cgarciae in #4040
  • allow imagenet training profiling to be disabled in config by @copybara-service in #4043
  • [nnx] LoRAParam inherits from Param by @cgarciae in #3988
  • [linen] allows multiple compact methods by @cgarciae in #3808
  • Added support of NANOO fp8. by @wenchenvincent in #3993
  • Add functools.wraps() annotation to flax.nn.jit. by @copybara-service in #4051
  • Fix typo in nnx_basics doc by @rajasekharporeddy in #4047
  • [nnx] fix Variable overloads and add shape/dtype properties by @cgarciae in #4049
  • Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4039
  • [nnx] stabilize unsafe_pytree by @cgarciae in #4030
  • Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4055
  • [NVIDIA] Rename fp8 custom dtype to fp32_max_grad by @kaixih in #3984
  • [nnx] fix mnist_tutorial colab link by @cgarciae in #4063
  • [nnx] fix Accuracy on eager mode by @cgarciae in #4065
  • Update orbax_upgrade_guide.rst for async checkpointing usage examples by @kaushaladiti-2802 in #4036
  • Re-enable some tests after Python 3.9 is dropped by @IvyZX in #4067
  • Rename nnx.compat to nnx.bridge by @IvyZX in #4066
  • [nnx] improve mnist tutorial by @cgarciae in #4070
  • Modify Flax checkpointing in preparation for cl/650338576. by @copybara-service in #4072
  • Remove some outdated backward-compatibility code. by @copybara-service in #4068
  • [NVIDIA] Add a user guide for fp8 by @kaixih in #4076
  • [nnx] add extract APIs by @cgarciae in #4078
  • [example]: remove lm1b useless parallelism rules by @knightXun in #4077
  • [nnx] improve filters guide by @cgarciae in #4059
  • [nnx] add call by @cgarciae in #4004
  • Ignore Orbax warning in deprecated flax.training.checkpoints.py to unbreak head doctest by @IvyZX in #4092
  • fix mypy failures due to numpy update by @cgarciae in #4098
  • [linen] generalize transform caching by @copybara-service in #4057
  • [linen] fold rngs on jit to improve caching by @copybara-service in #4064
  • Add shape-based lazy init to LinenToNNX (prev LinenWrapper) by @IvyZX in #4081
  • [nnx] add reseed by @cgarciae in #4099
  • [nnx] add split/merge_inputs by @cgarciae in #4084
  • Perform shape checks for self.param AFTER unboxing by @danielwatson6 in #4079
  • fix restore_checkpoint example in docstring by @copybara-service in #4101
  • [numpy] Fix users of NumPy APIs that are removed in NumPy 2.0. by @copybara-service in #4104
  • Set profile_duration_ms=None: periodic_actions has default values for both num_profile_steps and profile_duration_ms, and profiling stops only when both conditions are satisfied, so setting profile_duration_ms=None ensures the passed num_profile_steps value is used. by @copybara-service in #4096
  • [linen] add share_scope by @cgarciae in #4102
  • Allow metadata pass-through in flax.struct.field by @cool-RR in #4056
  • avoid mixing einsum_dot_general and einsum argument by specifying them explicitly in the caller. by @copybara-service in #4115
  • Add logging to track deprecated codepaths. by @copybara-service in #4121
  • [pmap no rank reduce cleanup]: When flipping the by @copybara-service in #4125
  • Add NNXToLinen wrapper to nnx.bridge by @IvyZX in #4126
  • Switch NNX to use Treescope instead of Penzai. by @copybara-service in #4132
  • Add GroupNorm to NNX normalization layers by @treigerm in #4095
  • [nnx] fix initializing propagation by @cgarciae in #4134
  • add JAX-style NNX Transforms FLIP by @cgarciae in #4108
  • Fix _ParentType annotation by @dcharatan in #4120
  • add uv.lock file by @copybara-service in #4139
  • use uv package manager by @cgarciae in #4136
  • More testing and misc fixes on wrappers by @IvyZX in #4137
  • Fix link to orbax documentation by @cool-RR in #4123
  • [nnx] experimental transforms by @cgarciae in #3963
  • [nnx] improve docs by @cgarciae in #4141
  • remove repeated license headers by @cgarciae in #4148
  • update Flax to version 0.9.0 by @copybara-service in #4147

New Contributors

Full Changelog: v0.8.5...v0.9.0

v0.8.5

26 Jun 09:27

What's Changed

  • v0.8.5 by @cgarciae in #3941
  • [nnx] improve vmap axis size detection by @cgarciae in #3947
  • Add direct penzai.treescope support for NNX objects. by @copybara-service in #3948
  • [nnx] fix nnx_basics dependencies by @cgarciae in #3942
  • Rename all the NNX tests to internal naming & build conventions. by @copybara-service in #3952
  • updated rng guide by @chiamp in #3912
  • upgraded haiku guide to include NNX by @chiamp in #3923
  • parameterized NNX transforms tests by @chiamp in #3906
  • Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes. by @copybara-service in #3957
  • fix HEAD by @chiamp in #3960
  • Minor grammar fixes to NNX documentation. by @mcsmart76 in #3953
  • Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3928
  • Adding Welford metric. by @copybara-service in #3959
  • Modify Welford metric to return mean value. by @copybara-service in #3970
  • [nnx] make State generic by @cgarciae in #3964
  • updated NNX nn docstrings by @chiamp in #3972
  • make flax work with upcoming JAX change to tree_map (being more careful about by @copybara-service in #3976
  • updated nnx.module docstrings by @chiamp in #3966
  • updated nnx.Conv and nnx.ConvTranspose by @chiamp in #3974
  • updated nnx.graph docstrings by @chiamp in #3958
  • Adds pmap and Pmap. static_broadcasted_argnums, donate_argnums, and global_arg_shapes are not yet supported. by @copybara-service in #3978
  • Fixes for batch norm docs by @jkarwowski in #3982
  • fix deprecation warning by @chiamp in #3981
  • updated NNX rnglib docstring by @chiamp in #3980
  • updated nnx.training by @chiamp in #3975
  • updated nnx.variables docstrings by @chiamp in #3986
  • [nnx] vectorize vmap split counts by @cgarciae in #3989
  • added wrt option to nnx.Optimizer by @chiamp in #3983
  • Added nnx.graph.iter_children by @chiamp in #3991
  • [nnx] fix vmap by @copybara-service in #3995
  • Fix head pytest breakage by @IvyZX in #4006
  • Helper function for loading params from a linen module by @copybara-service in #4012
  • Port gemma/layers to NNX by @copybara-service in #4013
  • [nnx] fix grad by @cgarciae in #4007
  • [nnx] add PathContains Filter by @cgarciae in #4011
  • Support Python 3.9 by @copybara-service in #4018
  • Port gemma/modules to NNX by @copybara-service in #4014
  • Internal change to fix current head CI by @copybara-service in #4017
  • Unpin the Orbax pip version. by @copybara-service in #4024
  • Fix Gemma test to unbreak head by @IvyZX in #4025
  • Fix pickling of exceptions by @sanderland in #4002
  • Call user-defined variable transforms before determining axis size in nn.vmap. by @copybara-service in #4026
  • CI: add test run against oldest supported jax version by @jakevdp in #3996
  • Make force_fp32_for_softmax arg in MultiHeadDotProductAttention useful. by @copybara-service in #4029

New Contributors

Full Changelog: v0.8.4...v0.8.5

v0.8.4

24 May 17:09

What's Changed

Full Changelog: v0.8.3...v0.8.4

v0.8.3

30 Apr 09:56

What's Changed

  • Add git fetch upstream to contributing doc. by @carlosgmartin in #3757
  • removed getattr/setattr unboxing magic from nnx.Pytree by @chiamp in #3743
  • added Einsum layer to NNX by @chiamp in #3741
  • Make TrainState's step possibly jax.Array. This makes replicate valid for type checking. by @copybara-service in #3763
  • v0.8.3 by @cgarciae in #3758
  • [nnx] fix demo notebook by @cgarciae in #3744
  • added nnx api reference by @chiamp in #3762
  • updated rng docstring for init, apply and make_rng by @chiamp in #3765
  • use note box in make_rng docstring by @cgarciae in #3767
  • [nnx] improved graph update mechanism by @cgarciae in #3759
  • use note box in docstrings by @chiamp in #3769
  • Add reset_gate flag to MGUCell. by @carlosgmartin in #3760
  • Access thread_resources via jax.interpreters.pxla instead of jax.experimental.maps by @copybara-service in #3775
  • Minor doc improvements by @canyon289 in #3588
  • added MGU reset_gate test by @chiamp in #3773
  • [nnx] Pytrees are Trees by @cgarciae in #3768
  • Use short-circuiting access to debug_key_reuse by @copybara-service in #3781
  • fix tabulate on norm wrappers by @chiamp in #3772
  • Add kw_only struct.dataclass test by @chiamp in #3651
  • extended PyTreeNode to take dataclass kwargs by @chiamp in #3785
  • [nnx] Arrays are state by @cgarciae in #3791
  • [nnx] add GraphNode base class by @cgarciae in #3790
  • [nnx] jit accepts many Modules by @cgarciae in #3783
  • Exposing the experimental _split_transpose JAX scan parameter in Flax. by @copybara-service in #3795
  • Expose nnx.GraphNode by @chiamp in #3796
  • [nnx] Rngs and RngStream inherit from GraphNode by @cgarciae in #3793
  • [nnx] TrainState uses struct by @cgarciae in #3788
  • [nnx] split returns graphdef first by @cgarciae in #3794
  • Remove the uninitialized field "embedding" in nn.Embed by @copybara-service in #3801
  • Add nnx.training by @chiamp in #3782
  • [nnx] non-str State keys by @cgarciae in #3802
  • [nnx] allow all jit kwargs in nnx.jit by @cgarciae in #3809
  • [nnx] simplify readme by @cgarciae in #3805
  • [nnx] Fix nnx basics by @cgarciae in #3812
  • [nnx] grad accepts argnums by @cgarciae in #3798
  • [nnx] improve toy examples by @cgarciae in #3813
  • [nnx] expose Sequential by @cgarciae in #3814
  • [nnx] Rng Variable tags by @cgarciae in #3807
  • [nnx] remove copy in graph unflatten by @cgarciae in #3804
  • fixed optax guide links and docstring typos by @chiamp in #3789
  • added dropout broadcast test by @chiamp in #3776
  • relaxed grads kwarg for Optimizer.update by @chiamp in #3818
  • added tree_map deprecation warning filter by @chiamp in #3828
  • updated tree_map by @chiamp in #3823
  • added NNX vs JAX transformations guide by @chiamp in #3819
  • Updated NNX MNIST tutorial by @chiamp in #3810
  • [nnx] add Dropout.rngs by @cgarciae in #3815
  • removed autosummary from linen docs by @chiamp in #3792
  • Fix cloudpickle sentinel cloning by @cgarciae in #3825
  • [nnx] remove pytreelib by @cgarciae in #3816
  • [nnx] fix nnx_basics by @cgarciae in #3839
  • [linen] fix DenseGeneral init by @cgarciae in #3834
  • [nnx] jit constrain object state by @cgarciae in #3817
  • Copybara import of the project: by @copybara-service in #3857
  • Add example of unbox() and replace_boxed() to the jit guide by @IvyZX in #3843
  • RNNCellBase refactor FLIP by @cgarciae in #3099
  • [nnx] Some small documentation suggestions. by @gnecula in #3861
  • updated nnx dropout by @chiamp in #3841
  • Fix LogicalRules type annotation. (Tuple[str] is a tuple with single element string, by @copybara-service in #3877
  • Add option to skip float32 promotion when computing means and variances for normalization. by @copybara-service in #3873
  • added nnx api reference link by @chiamp in #3871
  • option of forcing the input of softmax to be fp32 for better numerical stability in mixed-precision training. by @copybara-service in #3874
  • allow custom dot_general for einsum. by @copybara-service in #3884
  • [NVIDIA] Extend the custom fp8 accumulate dtype in non-jit scenarios by @kaixih in #3827
  • updated robots.txt by @chiamp in #3886
  • fixed autosummary links by @chiamp in #3887
  • Fix jax.tree_util.register_dataclass in older JAX versions. by @copybara-service in #3885
  • [nnx] v0.1 by @cgarciae in #3876

Full Changelog: v0.8.2...v0.8.3

v0.8.2

14 Mar 11:34

What's Changed

Full Changelog: v0.8.1...v0.8.2

Version 0.8.1

07 Feb 21:52

What's Changed

Full Changelog: v0.8.0...v0.8.1

v0.8.0

23 Jan 23:16

What's Changed

New Contributors

v0.7.5

28 Oct 02:07

What's Changed

New Contributors

Full Changelog: v0.7.4...v0.7.5