Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 29 additions & 32 deletions chex/_src/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,27 +496,28 @@ def _with_pmap(fn,

# Set up a reduce function.
if reduce_fn == "first_device_output":
# Avoid degraded performance under the new jax.pmap. See
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
if jax.config.jax_pmap_shmap_merge:
def reduce_fn(x):  # pylint: disable=function-redefined
  """Maps every leaf of `x` to its first entry along the leading axis.

  Leaves that are not indexable, have no `shape`, or have an empty shape
  are returned unchanged. A leaf whose leading dimension is 0 is also
  returned as-is instead of being indexed.

  Args:
    x: a tree of outputs to reduce.

  Returns:
    A tree with the same structure as `x` holding the reduced leaves.
  """

  def _first_entry(leaf):
    indexable = hasattr(leaf, "__getitem__") and hasattr(leaf, "shape")
    if not (indexable and leaf.shape):
      return leaf

    if (isinstance(leaf, jax.core.Tracer) or
        not hasattr(leaf, "addressable_shards") or
        not leaf.addressable_shards):
      # Tracers and other indexable outputs are indexed directly.
      source = leaf
    else:
      # Concrete arrays exposing shards: read the first addressable shard
      # rather than indexing the whole array.
      source = leaf.addressable_shards[0].data

    return source[0] if source.shape[0] else source

  return tree_map(_first_entry, x)
else:
reduce_fn = lambda t: tree_map(lambda x: x[0], t)

def reduce_fn(x):  # pylint: disable=function-redefined
  """Maps every leaf of `x` to its first entry along the leading axis.

  Leaves that are not indexable, have no `shape`, or have an empty shape
  are returned unchanged. A leaf whose leading dimension is 0 is also
  returned as-is instead of being indexed.

  Args:
    x: a tree of outputs to reduce.

  Returns:
    A tree with the same structure as `x` holding the reduced leaves.
  """

  def _first_entry(leaf):
    indexable = hasattr(leaf, "__getitem__") and hasattr(leaf, "shape")
    if not (indexable and leaf.shape):
      return leaf

    if (isinstance(leaf, jax.core.Tracer) or
        not hasattr(leaf, "addressable_shards") or
        not leaf.addressable_shards):
      # Tracers and other indexable outputs are indexed directly.
      source = leaf
    else:
      # Concrete arrays exposing shards: read the first addressable shard
      # rather than indexing the whole array.
      source = leaf.addressable_shards[0].data

    return source[0] if source.shape[0] else source

  return tree_map(_first_entry, x)

elif reduce_fn == "identity" or reduce_fn is None: # Identity.
reduce_fn = lambda t: t

Expand Down Expand Up @@ -545,17 +546,13 @@ def wrapper(*args: pytypes.ArrayTree, **kwargs: pytypes.ArrayTree):
raise ValueError("Number of available devices is less than required for "
f"test ({len(devices_)} < {n_devices_})")

if jax.config.jax_pmap_shmap_merge:
def bcast_fn(x):
  """Replicates `x` along a new leading axis of length `n_devices_`.

  Args:
    x: a value convertible with `jnp.asarray`.

  Returns:
    The broadcast value itself when it is a tracer; otherwise the result of
    passing its leading-axis slices and `devices_` to
    `jax.device_put_sharded`.
  """
  arr = jnp.asarray(x)
  replicated = jnp.broadcast_to(arr, (n_devices_,) + arr.shape)
  if isinstance(replicated, jax.core.Tracer):
    # Tracers are returned as-is, without device placement.
    return replicated
  return jax.device_put_sharded(list(replicated), devices_)
else:
bcast_fn = lambda x: jnp.broadcast_to(
x, (n_devices_,) + jnp.asarray(x).shape
)
def bcast_fn(x):
  """Replicates `x` along a new leading axis of length `n_devices_`.

  Args:
    x: a value convertible with `jnp.asarray`.

  Returns:
    The broadcast value itself when it is a tracer; otherwise the result of
    passing its leading-axis slices and `devices_` to
    `jax.device_put_sharded`.
  """
  arr = jnp.asarray(x)
  replicated = jnp.broadcast_to(arr, (n_devices_,) + arr.shape)
  if isinstance(replicated, jax.core.Tracer):
    # Tracers are returned as-is, without device placement.
    return replicated
  return jax.device_put_sharded(list(replicated), devices_)

if broadcast_args_to_devices:
args = [
tree_map(bcast_fn, arg) if idx not in static_argnums else arg
Expand Down
Loading