Train with JAX on Multiple Heterogeneous GPUs (RTX 4060Ti and RTX 5060Ti) #32027
animikhaich asked this question in Q&A (unanswered)
I have two GPUs: an RTX 4060 Ti and an RTX 5060 Ti.
Coming from a PyTorch background, I have trained on both GPUs together using PyTorch's DDP.
I was playing around with JAX and Flax, so I started with the official MNIST example on Flax. It ran fine on either GPU individually, and I wanted to test distributed training across both GPUs.
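For context, here is a quick sanity check of what XLA sees on this machine; the device indices in the comments are assumptions:

```python
import jax

# Minimal sketch: list the accelerators XLA can see on this machine.
for d in jax.devices():
    print(d.id, d.platform, d.device_kind)

# Running the Flax MNIST example on one GPU at a time just meant restricting
# visibility before starting the process (device indices are assumptions):
#   CUDA_VISIBLE_DEVICES=0 python main.py
#   CUDA_VISIBLE_DEVICES=1 python main.py
```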
With Google Gemini's help, I put together a data-parallel version of the training code.
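A minimal sketch of the kind of pmap-based data-parallel train step I mean is below; the toy linear model, shapes, and names are illustrative assumptions, not my exact MNIST code:

```python
import jax
import jax.numpy as jnp

# Illustrative sketch of a single-host data-parallel train step with jax.pmap.
# The toy linear model and shapes are stand-ins, not the actual Flax MNIST code.

LR = 1e-3

def loss_fn(params, batch):
    x, y = batch
    preds = x @ params["w"] + params["b"]
    return jnp.mean((preds - y) ** 2)

def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Average gradients across all devices participating in the pmap.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

p_train_step = jax.pmap(train_step, axis_name="devices")

n_dev = jax.local_device_count()
key = jax.random.PRNGKey(0)
params = {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))}
# Replicate parameters across devices; shard the batch along a new leading axis.
params = jax.device_put_replicated(params, jax.local_devices())
x = jax.random.normal(key, (n_dev, 32, 8))   # [devices, per_device_batch, features]
y = jax.random.normal(key, (n_dev, 32, 1))
params = p_train_step(params, (x, y))
```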
When I ran it across both GPUs, I got an error.
After some digging, I got the impression that data parallelism is not supported across heterogeneous hardware in JAX. Is that true? Is there any way around it?
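For what it's worth, I also looked at expressing the same thing with the jax.sharding API. A minimal sketch (shapes and names are assumptions) of sharding a batch across both GPUs is below; whether a single mesh can mix two device kinds (4060 Ti + 5060 Ti) is exactly what I'm asking about:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Sketch of expressing data parallelism with jax.sharding.
# Shapes are assumptions; the open question is whether a mesh built from
# two different GPU device kinds is supported at all.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data"))  # split the batch dim across GPUs

x = jnp.ones((64, 784))
x = jax.device_put(x, batch_sharding)            # place one shard on each GPU
print(x.sharding)
```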