diff --git a/precondition/distributed_shampoo.py b/precondition/distributed_shampoo.py index 8ad9655..e114ea3 100644 --- a/precondition/distributed_shampoo.py +++ b/precondition/distributed_shampoo.py @@ -2832,10 +2832,7 @@ def _pmap_compute_preconditioners(states, step, statistics, Returns: New optimizer states after computing the preconditioner. """ - if batch_axis_name: - num_devices = lax.psum(1, batch_axis_name) - else: - num_devices = 1 + num_devices = jax.device_count() num_statistics = len(statistics) # Pad statistics and exponents to next multiple of num_devices. packed_statistics = [ @@ -3033,7 +3030,7 @@ def _pmap_quantized_compute_preconditioners(states, step, statistics, Returns: New optimizer states after computing the preconditioner. """ - num_devices = lax.psum(1, batch_axis_name) + num_devices = jax.device_count() num_statistics = len(statistics) quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers() # Complexity here is around: shapes needing be statically shaped,