From b01f5ef7da2fb86a9f243ed7e858b0082ec0f25b Mon Sep 17 00:00:00 2001 From: Sourabh Medapati Date: Wed, 11 Sep 2024 23:26:54 -0700 Subject: [PATCH] upgrading init2winit from pmap to jit PiperOrigin-RevId: 673695362 --- precondition/distributed_shampoo.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/precondition/distributed_shampoo.py b/precondition/distributed_shampoo.py index 8ad9655..e114ea3 100644 --- a/precondition/distributed_shampoo.py +++ b/precondition/distributed_shampoo.py @@ -2832,10 +2832,7 @@ def _pmap_compute_preconditioners(states, step, statistics, Returns: New optimizer states after computing the preconditioner. """ - if batch_axis_name: - num_devices = lax.psum(1, batch_axis_name) - else: - num_devices = 1 + num_devices = jax.device_count() num_statistics = len(statistics) # Pad statistics and exponents to next multiple of num_devices. packed_statistics = [ @@ -3033,7 +3030,7 @@ def _pmap_quantized_compute_preconditioners(states, step, statistics, Returns: New optimizer states after computing the preconditioner. """ - num_devices = lax.psum(1, batch_axis_name) + num_devices = jax.device_count() num_statistics = len(statistics) quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers() # Complexity here is around: shapes needing be statically shaped,