Who is responsible for the impact of deprecated code on other parts of your org?
The project https://github.com/google-research/google-research/tree/master/diffusion_distillation, from a very nice 2022 paper, uses flax.optim and a custom TrainState based on @flax.struct.dataclass, both of which have since been replaced by optax and flax.training.train_state.TrainState. I'm sure the new API is great, but between Google Colab forcing particular numpy, jax, flax, etc. versions, Python 3.12, jaxlib 0.4.1 no longer being on PyPI, and custom installation from wheels containing no CUDA builds, it is impossible to recreate old projects. Especially if one (partly because of these deprecations) leans towards torch. I do prefer jax in principle, but I don't use it because "old" projects are broken (3-year-old work that is still very relevant should not count as "old").
I tried to fix the project but only got partway; something in either the flax.optim -> optax migration or the custom TrainState -> flax.training.train_state.TrainState migration breaks my attempted fix (a generic sketch of the migration is below).
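For reference, this is a minimal, generic sketch of what I understand the flax.optim -> optax / TrainState migration to look like, not the actual diffusion_distillation fix; TinyModel, the loss, and the batch shapes are hypothetical stand-ins:

```python
# Minimal sketch of the flax.optim -> optax migration (assumed, generic example,
# not the diffusion_distillation code itself).
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state  # replaces the custom @flax.struct.dataclass TrainState


class TinyModel(nn.Module):  # hypothetical stand-in for the real model
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)


model = TinyModel()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))["params"]

# flax.optim.Adam(...).create(params) becomes an optax transform passed as `tx`
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-4),
)


@jax.jit
def train_step(state, x, y):
    def loss_fn(p):
        preds = state.apply_fn({"params": p}, x)
        return jnp.mean((preds - y) ** 2)

    grads = jax.grad(loss_fn)(state.params)
    # the old optimizer.apply_gradient(grads) becomes state.apply_gradients(grads=grads)
    return state.apply_gradients(grads=grads)
```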
google-research/google-research#2976
Can someone quickly fix that? Or suggest what I need to change in these two files so I can make a pull request?
https://github.com/xvdp/google-research/blob/fixflaxversion/diffusion_distillation/diffusion_distillation/model.py
https://github.com/xvdp/google-research/blob/fixflaxversion/diffusion_distillation/diffusion_distillation.ipynb