Open
Labels: bug (Something isn't working)
Description
When shard_mapping a function with relatively long (and variable) runtimes for different inputs across multiple devices, JAX times out with the following error:
2025-09-15 14:00:55.996594: F external/xla/xla/service/rendezvous.cc:127] [id=6] Termination timeout for `all reduce RendezvousKey{run_id=RunId: 1416913065, global_devices=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27], num_local_participants=28, collective_op_kind=cross_module, op_id=1}` of 10 seconds exceeded. Exiting to ensure a consistent program state. Expected 28 threads to join the rendezvous, but only 11 of them arrived on time.
Aborted (core dumped)
I understand the problem seems related to uneven work distribution: some inputs take longer to evaluate than others, so some of the 28 threads join the all-reduce rendezvous late and the 10-second timeout fires. I do not expect my function to take 10 seconds to evaluate for any input, so I guess the delay might come from compilation or something else (I am not sure here).
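One way to check whether compilation, rather than evaluation, accounts for the delay is to time the first and second call for a given input on a single device. The following is only a minimal sketch: `example_fn` and `time_first_and_second_call` are hypothetical names introduced here for illustration, and `example_fn` is a stand-in, not the function from this report.

```python
import time

import jax
import jax.numpy as jnp


def time_first_and_second_call(fn, x):
    """Return (first_call_seconds, second_call_seconds) for jit(fn)(x)."""
    jitted = jax.jit(fn)

    t0 = time.perf_counter()
    # First call: includes tracing and compilation unless already cached.
    jax.block_until_ready(jitted(x))
    t1 = time.perf_counter()
    # Second call with the same shape and dtype: reuses the compiled executable.
    jax.block_until_ready(jitted(x))
    t2 = time.perf_counter()

    return t1 - t0, t2 - t1


def example_fn(x):
    # Hypothetical stand-in for the real per-input computation.
    return jnp.sum(jnp.sin(x) ** 2)


if __name__ == "__main__":
    for n in (10, 1000):
        x = jnp.linspace(0.0, 1.0, n)
        first, second = time_first_and_second_call(example_fn, x)
        print(f"n={n}: first call {first:.4f}s, second call {second:.4f}s")
```

If the first call is much slower than the second for some inputs, compilation is a plausible contributor; if both calls are well under 10 seconds, the cause is more likely elsewhere.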
In any case, I couldn't find a way to disable the timeout. Could you expose a way for JAX users to override it? (Or is there a different suggestion for executing this workflow without running into the timeout?)
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.6.2
jaxlib: 0.6.2
numpy: 2.3.3
python: 3.13.7 (main, Sep 2 2025, 14:21:46) [Clang 20.1.4 ]
device info: cpu-28, 28 local devices"
process_count: 1
platform: uname_result(system='Linux', node='mucajai', release='6.14.0-29-generic', version='#29~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Aug 14 16:52:50 UTC 2', machine='x86_64')