Training works locally on 4 chips with fsdp=4: `export PJRT_DEVICE=TPU && export TORCHPRIME_TPU_TYPE=v6e-4 && python torchprime/torch_xla_models/train.py model=flex-qwen-1b`
MFU: 0.21
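For context, MFU here is model FLOPs utilization. A minimal back-of-envelope sketch, assuming the standard definition (achieved model FLOPs per second over peak chip FLOPs) and the usual ~6 * params FLOPs-per-token estimate for a dense decoder; torchprime's exact accounting may differ, and all inputs below are placeholders, not values from this log:

```python
def mfu(tokens_per_sec: float, n_params: float, peak_flops_per_sec: float) -> float:
    """Rough MFU estimate: achieved model FLOPs per second / peak hardware FLOPs.

    Assumes ~6 * n_params FLOPs per token (forward + backward) for a dense
    decoder-only model; these are illustrative assumptions, not measurements.
    """
    achieved_flops_per_sec = 6.0 * n_params * tokens_per_sec
    return achieved_flops_per_sec / peak_flops_per_sec
```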
On a v5p-128 cluster, sweeping ici_mesh.fsdp=x and ici_mesh.tensor=y with `tp run --name jialei-0812-qwen-fsdp32tensor2 torchprime/torch_xla_models/train.py model=flex-qwen-1b task.global_batch_size=64 ici_mesh.fsdp=x ici_mesh.tensor=y` (sweep sketch after the list):
- fsdp=64, tensor=1: hang >.<
- fsdp=32, tensor=2: finished, MFU 0.22
- fsdp=16, tensor=4: finished, MFU 0.19
- fsdp=8, tensor=8: finished, MFU 0.11
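A minimal sketch of the swept (fsdp, tensor) combinations above; the actual runs were launched individually with `tp run`, and the run names for combinations other than fsdp32tensor2 are assumed to follow the same naming pattern:

```python
import subprocess

# Sweep over the (fsdp, tensor) mesh shapes tried on the v5p-128 cluster.
# Run names besides jialei-0812-qwen-fsdp32tensor2 are assumed, not from the log.
for fsdp, tensor in [(64, 1), (32, 2), (16, 4), (8, 8)]:
    subprocess.run(
        [
            "tp", "run",
            "--name", f"jialei-0812-qwen-fsdp{fsdp}tensor{tensor}",
            "torchprime/torch_xla_models/train.py",
            "model=flex-qwen-1b",
            "task.global_batch_size=64",
            f"ici_mesh.fsdp={fsdp}",
            f"ici_mesh.tensor={tensor}",
        ],
        check=True,
    )
```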
Also see the cluster log: https://b.corp.google.com/issues/436664633#comment40.
Runs for the Llama model: https://b.corp.google.com/issues/436664633#comment33