
Updates to flax, jax, and optax caused regressions in google-research projects #4944

@xvdp

Who is responsible for maintaining code in other parts of your org that is broken by deprecations?
Project https://github.com/google-research/google-research/tree/master/diffusion_distillation, the code for a very nice paper from 2022, used flax.optim and a custom TrainState based on @flax.struct.dataclass, which have since been replaced by optax and flax.training.train_state.TrainState. I'm sure the new APIs are great, but between Google Colab forcing specific numpy, jax, flax, etc. versions, Python 3.12, jaxlib 0.4.1 no longer being on PyPI, and custom installation from wheels containing no CUDA versions, it is impossible to recreate old projects, especially if one (partly because of these deprecations) leans towards torch.

* I do prefer JAX in principle, but I don't use it because "old" projects are broken (three-year-old work that is still very relevant should not be "old").

I tried to fix the project but only partially succeeded: something in either the flax.optim -> optax migration or the TrainState -> flax.training.train_state.TrainState migration breaks my attempted fix.

google-research/google-research#2976

Can someone quickly fix this, or suggest what I need to change in these two files so I can make a pull request?
https://github.com/xvdp/google-research/blob/fixflaxversion/diffusion_distillation/diffusion_distillation/model.py
https://github.com/xvdp/google-research/blob/fixflaxversion/diffusion_distillation/diffusion_distillation.ipynb
