Skip to content

Commit

Permalink
Improve log message and print only on hosts that broadcast. Also bloc…
Browse files Browse the repository at this point in the history
…k subtree instead of tree.

PiperOrigin-RevId: 658166778
  • Loading branch information
Orbax Authors committed Jul 31, 2024
1 parent 3cc343c commit 71a1fd8
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions checkpoint/orbax/checkpoint/multihost/multislice_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,11 @@ def pre_jit(x, per_replica_sharding):
out_shardings=out_sharding,
)(in_tree_sharded)
out_tree.extend(out_subtree)
jax.block_until_ready(out_tree)
jax.block_until_ready(out_subtree)
start = end
logging.info('Number of broadcasts: %d', num_broadcasts)

if is_source:
logging.info('Total number of broadcasts: %d', num_broadcasts)
return tuple(out_tree), num_broadcasts


Expand Down

0 comments on commit 71a1fd8

Please sign in to comment.