diff --git a/vmoe/app.py b/vmoe/app.py index ed9c380..7915b77 100644 --- a/vmoe/app.py +++ b/vmoe/app.py @@ -66,11 +66,6 @@ def _main(argv, *, main) -> None: # Log JAX compilation steps. jax.config.update('jax_log_compiles', True) jax.config.update('jax_default_prng_impl', 'unsafe_rbg') - # Enable experimental xmap spmd lowering. Necessary for mixing pjit and xmap. - # Calling xmap from within pjit is necessary because pure_callback is not - # fully supported with pjit. - jax.config.update('experimental_xmap_spmd_lowering', True) - jax.config.update('experimental_xmap_spmd_lowering_manual', True) # Log useful information to identify the process running in the logs. logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) logging.info('JAX local devices: %r', jax.local_devices())