Skip to content

Commit

Permalink
Improve log message and print only on hosts that broadcast. Also bloc…
Browse files Browse the repository at this point in the history
…k subtree instead of tree.

PiperOrigin-RevId: 658166778
  • Loading branch information
Orbax Authors committed Aug 1, 2024
1 parent 6a7fdf6 commit b37e71a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions checkpoint/orbax/checkpoint/multihost/multislice_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,11 @@ def pre_jit(x, per_replica_sharding):
out_shardings=out_sharding,
)(in_tree_sharded)
out_tree.extend(out_subtree)
jax.block_until_ready(out_tree)
jax.block_until_ready(out_subtree)
start = end
logging.info('Number of broadcasts: %d', num_broadcasts)

if is_source:
logging.info('Total number of broadcasts: %d', num_broadcasts)
return tuple(out_tree), num_broadcasts


Expand Down

0 comments on commit b37e71a

Please sign in to comment.