diff --git a/chex/_src/variants.py b/chex/_src/variants.py index cd90a2e..ea961ef 100644 --- a/chex/_src/variants.py +++ b/chex/_src/variants.py @@ -496,27 +496,28 @@ def _with_pmap(fn, # Set up a reduce function. if reduce_fn == "first_device_output": - # Avoid degraded performance under the new jax.pmap. See - # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays. - if jax.config.jax_pmap_shmap_merge: - def reduce_fn(x): # pylint: disable=function-redefined - def _reduce_leaf(leaf): - if (hasattr(leaf, "__getitem__") and - hasattr(leaf, "shape") and - leaf.shape): - if (not isinstance(leaf, jax.core.Tracer) and - hasattr(leaf, "addressable_shards") and - leaf.addressable_shards): - data = leaf.addressable_shards[0].data - return data if not data.shape[0] else data[0] - - # Fallback for tracers or other indexable outputs. - return leaf if not leaf.shape[0] else leaf[0] - return leaf - - return tree_map(_reduce_leaf, x) - else: - reduce_fn = lambda t: tree_map(lambda x: x[0], t) + + def reduce_fn(x): # pylint: disable=function-redefined + def _reduce_leaf(leaf): + if ( + hasattr(leaf, "__getitem__") + and hasattr(leaf, "shape") + and leaf.shape + ): + if ( + not isinstance(leaf, jax.core.Tracer) + and hasattr(leaf, "addressable_shards") + and leaf.addressable_shards + ): + data = leaf.addressable_shards[0].data + return data if not data.shape[0] else data[0] + + # Fallback for tracers or other indexable outputs. + return leaf if not leaf.shape[0] else leaf[0] + return leaf + + return tree_map(_reduce_leaf, x) + elif reduce_fn == "identity" or reduce_fn is None: # Identity. reduce_fn = lambda t: t @@ -545,17 +546,13 @@ def wrapper(*args: pytypes.ArrayTree, **kwargs: pytypes.ArrayTree): raise ValueError("Number of available devices is less than required for " f"test ({len(devices_)} < {n_devices_})") - if jax.config.jax_pmap_shmap_merge: - def bcast_fn(x): - x = jnp.asarray(x) - x = jnp.broadcast_to(x, (n_devices_,) + x.shape) - if not isinstance(x, jax.core.Tracer): - return jax.device_put_sharded(list(x), devices_) - return x - else: - bcast_fn = lambda x: jnp.broadcast_to( - x, (n_devices_,) + jnp.asarray(x).shape - ) + def bcast_fn(x): + x = jnp.asarray(x) + x = jnp.broadcast_to(x, (n_devices_,) + x.shape) + if not isinstance(x, jax.core.Tracer): + return jax.device_put_sharded(list(x), devices_) + return x + if broadcast_args_to_devices: args = [ tree_map(bcast_fn, arg) if idx not in static_argnums else arg