Skip to content

Commit

Permalink
Use shard_map instead of xmap
Browse files Browse the repository at this point in the history
xmap is soft-deprecated with shard_map being the recommended alternative.

See https://jax.readthedocs.io/en/latest/notebooks/shard_map.html.

PiperOrigin-RevId: 611047495
  • Loading branch information
superbobry authored and copybara-github committed Feb 28, 2024
1 parent 90f8f94 commit 5361422
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions vmoe/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,6 @@ def _main(argv, *, main) -> None:
# Log JAX compilation steps.
jax.config.update('jax_log_compiles', True)
jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
# Enable experimental xmap spmd lowering. Necessary for mixing pjit and xmap.
# Calling xmap from within pjit is necessary because pure_callback is not
# fully supported with pjit.
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
# Log useful information to identify the process running in the logs.
logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
logging.info('JAX local devices: %r', jax.local_devices())
Expand Down

0 comments on commit 5361422

Please sign in to comment.