What is needed is to construct the DynamicLossScale as jmp.DynamicLossScale(jnp.float32(2**15)) and to change loss_scale.py:132 to return jax.tree_util.tree_map(lambda x: (x * self.loss_scale).astype(x.dtype), tree).
This way gradients are still computed in float16, and loss_scale.loss_scale won't overflow after the first 2000 steps (if it were stored in float16, jmp would increase it to 2**16, which is outside the legal range of float16).
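Roughly, the proposal looks like this (the standalone scale function is just my sketch of the patched method at loss_scale.py:132; the demo values are mine):

```python
import jax
import jax.numpy as jnp
import jmp

# float16's largest finite value is 65504, so doubling 2**15 in half
# precision overflows straight to inf:
print(jnp.float16(2 ** 15) * 2)  # -> inf

# Keeping the scale itself in float32 avoids that, while losses and
# gradients remain float16:
loss_scale = jmp.DynamicLossScale(jnp.float32(2 ** 15))

# Sketch of the patched DynamicLossScale.scale (loss_scale.py:132):
# multiply by the float32 scale, then cast each leaf back to its own
# dtype so the scaled tree stays float16.
def scale(self, tree):
    return jax.tree_util.tree_map(
        lambda x: (x * self.loss_scale).astype(x.dtype), tree)
```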
What really puzzles me is that this is the only jax mixed precision package that comes up in searches, and it is evidently not just dead but has been broken for months, and no one cares. Which raises two possibilities:
Does everyone use jax to train their models strictly in float32 or bf16?
Does no one use jax any more?
I'll try that with a fork of this repo when I have time, thanks for the suggestion!
I think jax is growing in popularity though haha ;p. That said, these open-source projects might not be a DeepMind priority.
Following the provided example for DynamicLossScale causes errors if run directly.
This is of course fixed by doing jmp.DynamicLossScale(jnp.float32(2**15)), but doesn't this defeat the purpose of this object?
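For completeness, here's a minimal sketch of how the float32-constructed scale slots into a training step. The scale/unscale/adjust/all_finite/select_tree calls follow jmp's documented API as I read it; the loss function, learning rate, and shapes are made up for illustration:

```python
import jax
import jax.numpy as jnp
import jmp

def loss_fn(params, batch):
    # Stand-in for a real model's float16 loss.
    return jnp.mean((params * batch) ** 2)

def train_step(params, batch, loss_scale):
    def scaled_loss(p):
        # Scale the loss up so small float16 gradients don't flush to zero.
        return loss_scale.scale(loss_fn(p, batch))

    grads = jax.grad(scaled_loss)(params)
    grads = loss_scale.unscale(grads)

    # jmp halves the scale on overflow and doubles it after 2000
    # consecutive finite steps (the defaults); skip the update when
    # the gradients are non-finite.
    grads_finite = jmp.all_finite(grads)
    loss_scale = loss_scale.adjust(grads_finite)
    new_params = jax.tree_util.tree_map(
        lambda p, g: (p - 1e-3 * g).astype(p.dtype), params, grads)
    params = jmp.select_tree(grads_finite, new_params, params)
    return params, loss_scale

# Constructed in float32 so repeated doubling can't overflow the scale.
loss_scale = jmp.DynamicLossScale(jnp.float32(2 ** 15))
params = jnp.ones((4,), dtype=jnp.float16)
batch = jnp.ones((4,), dtype=jnp.float16)
params, loss_scale = train_step(params, batch, loss_scale)
```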