Skip to content

Commit

Permalink
upgrading init2winit from pmap to jit
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673695362
  • Loading branch information
sourabh2k15 authored and The precondition Authors committed Dec 18, 2024
1 parent 6f23374 commit 48a3487
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions precondition/distributed_shampoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2832,10 +2832,7 @@ def _pmap_compute_preconditioners(states, step, statistics,
Returns:
New optimizer states after computing the preconditioner.
"""
if batch_axis_name:
num_devices = lax.psum(1, batch_axis_name)
else:
num_devices = 1
num_devices = jax.device_count()
num_statistics = len(statistics)
# Pad statistics and exponents to next multiple of num_devices.
packed_statistics = [
Expand Down Expand Up @@ -3033,7 +3030,7 @@ def _pmap_quantized_compute_preconditioners(states, step, statistics,
Returns:
New optimizer states after computing the preconditioner.
"""
num_devices = lax.psum(1, batch_axis_name)
num_devices = jax.device_count()
num_statistics = len(statistics)
quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
# Complexity here is around: shapes needing be statically shaped,
Expand Down

0 comments on commit 48a3487

Please sign in to comment.