Multi-Host training checkpointing #21290
Replies: 2 comments
-
I managed to solve it.
-
For those googling this error: I also ran into it because of a threading issue. Specifically, I was inadvertently launching TPU JAX kernels from a background data-loading thread. Make sure you always launch the same TPU kernel on all workers at the same time; both the original error and mine come down to that.
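A minimal sketch of that pattern, with a placeholder loader standing in for real data loading (the function bodies are illustrative, not from the original report): keep every `jax` call on the main thread so all hosts issue the same compiled computation in the same order, and let background threads do host-side work only.

```python
import queue
import threading

import numpy as np
import jax
import jax.numpy as jnp


def loader_worker(prefetch_queue: queue.Queue) -> None:
    """Background thread: host-side work only (NumPy, file I/O) -- never call jax.* here."""
    rng = np.random.default_rng(0)
    while True:
        # Stand-in for real data loading; only CPU/NumPy work happens on this thread.
        prefetch_queue.put(rng.standard_normal((32, 128)).astype(np.float32))


@jax.jit
def train_step(params, batch):
    # Placeholder update; a real step would compute gradients here.
    return params - 1e-3 * batch.mean()


def train_loop(params, num_steps: int):
    prefetch_queue: queue.Queue = queue.Queue(maxsize=4)
    threading.Thread(target=loader_worker, args=(prefetch_queue,), daemon=True).start()

    for _ in range(num_steps):
        batch = jnp.asarray(prefetch_queue.get())  # device transfer on the main thread
        # Every host reaches this call at the same step with the same computation,
        # which is what multi-host TPU execution requires.
        params = train_step(params, batch)
    return params
```

The same reasoning applies to checkpointing: if only one host (or a stray thread) triggers a compiled computation or a collective, the remaining hosts block on it and the pod eventually fails.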
-
Hi,
I am trying to do multi-host training on a TPU pod. I managed to run backpropagation, but I got stuck on saving checkpoints, specifically on saving the distributed Flax train state.
I distributed the initialised state using the following:
After a few updates I tried to save the state, but I failed to get the parameters onto one process. The things I tried:
Do you have any advice on how to get the parameters onto one host and save them to disk with Orbax? (One possible approach is sketched after this message.)
I also tried saving directly with the legacy API:
but the error that I get is:
Regards,
Rares
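A sketch of one way to approach the Orbax question above, not the original poster's code: with a sharded train state you generally do not need to gather the parameters onto one host at all; every process calls `save()` on the same global state, and Orbax writes only the shards that host owns to a directory visible to all hosts. The `gs://` path is hypothetical, and `state` / `abstract_state` are assumed to come from the training code.

```python
import orbax.checkpoint as ocp

# Assumed to exist from the training code: `state`, a Flax TrainState whose
# leaves are globally sharded jax.Arrays.
ckpt_dir = "gs://my-bucket/run-0/checkpoints/step_100"  # hypothetical; must be reachable from every host

checkpointer = ocp.PyTreeCheckpointer()
# Call this on every host, with the same arguments, at the same point in the
# training loop; each host writes only its addressable shards.
checkpointer.save(ckpt_dir, state)

# Restoring (also on every host). Passing a target/abstract state tells Orbax
# which shardings to restore into.
# restored = checkpointer.restore(ckpt_dir, item=abstract_state)
```

If a host-local copy of the full parameters is really needed, `jax.experimental.multihost_utils.process_allgather` (itself a collective, so call it on every host) returns fully replicated NumPy arrays on each process, after which process 0 alone can write them to disk.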