Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,10 +1060,13 @@ def _assert_fn(path, leaf):
# This is for backwards compatibility.
def _check_sharding(x):
if hasattr(jax, "Array") and isinstance(x, jax.Array):
if not jax.typeof(x).sharding.is_fully_replicated:
return True
else:
return len(x.sharding.device_set) > 1
# Use x.sharding directly for concrete arrays.
sharding = getattr(x, 'sharding', None)
if sharding is not None:
if not sharding.is_fully_replicated:
return True
else:
return len(sharding.device_set) > 1
# pytype: disable=attribute-error
return (
hasattr(jax, "pxla")
Expand Down Expand Up @@ -1184,8 +1187,8 @@ def _assert_fn(path, leaf):
if isinstance(leaf, jax.Array):
if _check_sharding(leaf):
errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' is a "
f"ShardedDeviceArray which are disallowed. "
f" (type={type(leaf)})."))
f"sharded JAX array (historically ShardedDeviceArray) "
f"which are disallowed. (type={type(leaf)})."))
else: # DeviceArray and not ShardedDeviceArray
# Check the platform.
leaf_device = list(leaf.devices())[0]
Expand Down
Loading