Skip to content

Commit ae79249

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Improve error message when collective APIs are called without a shard_map
Before: `unbound axis name: x` After: `Found an unbound axis name: x. To fix this, please call psum under jax.shard_map` PiperOrigin-RevId: 778632500
1 parent 14b2c90 commit ae79249

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

jax/_src/lax/parallel.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -928,7 +928,7 @@ def _allreduce_impl(prim, pos_reducer, *args, axes, axis_index_groups):
928928
return [pos_reducer(arg, axes) for arg in args]
929929

930930
def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
931-
_check_axis_names(axes)
931+
_check_axis_names(axes, 'psum')
932932
named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
933933
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
934934
if axis_index_groups is not None:
@@ -949,7 +949,7 @@ def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups):
949949
*args, axes=axes, axis_index_groups=axis_index_groups)
950950

951951
assert isinstance(axes, tuple)
952-
_check_axis_names(axes)
952+
_check_axis_names(axes, 'psum')
953953
arg_vma = [a.vma for a in args]
954954
# If intersection between arg_vma and axes is empty, error
955955
if any(not set(axes) & a for a in arg_vma):
@@ -985,12 +985,14 @@ def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups):
985985
return _psum_invariant_abstract_eval(
986986
name, *args, axes=axes, axis_index_groups=axis_index_groups)
987987

988-
def _check_axis_names(axes):
988+
def _check_axis_names(axes, api_name):
989989
named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
990990
axis_env = core.get_axis_env()
991991
for name in named_axes:
992992
if not axis_env.axis_exists(name):
993-
raise NameError(f"unbound axis name: {name}")
993+
raise NameError(
994+
f"Found an unbound axis name: {name}. To fix this, please call"
995+
f" {api_name} under `jax.shard_map`.")
994996

995997
def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
996998
if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms):
@@ -1166,7 +1168,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):
11661168
return v.take(perm_indices, d), d
11671169

11681170
def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
1169-
_check_axis_names(axis_name)
1171+
_check_axis_names(axis_name, 'ppermute')
11701172
collective_vma_rule('ppermute', axis_name, x)
11711173
return x
11721174

@@ -1218,7 +1220,7 @@ def _psend_lowering_gpu(ctx, x, *, axis_name, perm):
12181220

12191221

12201222
def _psend_abstract_eval(x, *, axis_name, **params):
1221-
_check_axis_names(axis_name)
1223+
_check_axis_names(axis_name, 'psend')
12221224
return abstract_token, {
12231225
*map(core.NamedAxisEffect, axis_name),
12241226
single_side_collective_effect,
@@ -1492,7 +1494,7 @@ def _all_to_all_effectful_abstract_eval(
14921494
del tiled # expand_dims and squeeze is done in `all_to_all` if `True`
14931495
if not isinstance(axis_name, (list, tuple)):
14941496
axis_name = (axis_name,)
1495-
_check_axis_names(axis_name)
1497+
_check_axis_names(axis_name, 'all_to_all')
14961498
shape = list(input_aval.shape)
14971499
axis_size = (
14981500
_axis_size(axis_name)
@@ -1581,7 +1583,7 @@ def _ragged_all_to_all_effectful_abstract_eval(
15811583
" size, but got shape {}".format(recv_sizes.shape)
15821584
)
15831585

1584-
_check_axis_names(axis_name)
1586+
_check_axis_names(axis_name, 'ragged_all_to_all')
15851587
out_aval = output.update(shape=output.shape, weak_type=False)
15861588
effects = {*map(core.NamedAxisEffect, axis_name)}
15871589
return out_aval, effects
@@ -1802,7 +1804,7 @@ def _all_gather_effectful_abstract_eval(
18021804
):
18031805
if not isinstance(axis_name, (list, tuple)):
18041806
axis_name = (axis_name,)
1805-
_check_axis_names(axis_name)
1807+
_check_axis_names(axis_name, 'all_gather')
18061808
new_shape = list(x_aval.shape)
18071809
if tiled:
18081810
new_shape[all_gather_dimension] *= axis_size
@@ -1920,7 +1922,7 @@ def bind(leaf):
19201922
def _all_gather_invariant_effectful_abstract_eval(
19211923
x_aval, *, all_gather_dimension, axis_name, axis_size, tiled
19221924
):
1923-
_check_axis_names(axis_name)
1925+
_check_axis_names(axis_name, 'all_gather_invariant')
19241926
new_shape = list(x_aval.shape)
19251927
if tiled:
19261928
new_shape[all_gather_dimension] *= axis_size
@@ -2026,7 +2028,7 @@ def _reduce_scatter_effectful_abstract_eval(
20262028
):
20272029
if not isinstance(axis_name, (list, tuple)):
20282030
axis_name = (axis_name,)
2029-
_check_axis_names(axis_name)
2031+
_check_axis_names(axis_name, 'reduce_scatter')
20302032
new_shape = list(x_aval.shape)
20312033
scatter_dim_input_size = x_aval.shape[scatter_dimension]
20322034
if tiled:
@@ -2244,7 +2246,7 @@ def _axis_index_lowering(ctx, *, axis_name):
22442246
def _axis_index_effectful_abstract_eval(*, axis_name):
22452247
effect = {core.NamedAxisEffect(axis_name)}
22462248
axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
2247-
_check_axis_names(axis_name)
2249+
_check_axis_names(axis_name, 'axis_index')
22482250
mesh = get_abstract_mesh()
22492251
sharding = NamedSharding(mesh, P())
22502252
vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset())
@@ -2280,7 +2282,7 @@ def _pgather_impl(src, idx, *, axes):
22802282
def _pgather_abstract_eval(src, idx, *, axes):
22812283
# TODO: Avals with names rule: remove all axes from src, insert those from idx
22822284
# The order is important, because it is ok to re-insert one of the deleted axes!
2283-
_check_axis_names(axes)
2285+
_check_axis_names(axes, 'pgather')
22842286
shape = list(src.shape)
22852287
for axis in sorted((a for a in axes if isinstance(a, int)), reverse=True):
22862288
del shape[axis]

tests/shard_map_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2402,6 +2402,22 @@ def ar(a):
24022402
self.assertArraysAllClose(ex_out1, out1, rtol=2e-4)
24032403
self.assertArraysAllClose(ex_out2, out2, rtol=2e-4)
24042404

2405+
def test_psum_not_under_shmap_error(self):
2406+
mesh = jtu.create_mesh((2,), 'x')
2407+
2408+
@jax.jit
2409+
def f(x):
2410+
return jax.lax.psum(x, 'x')
2411+
2412+
with self.assertRaisesRegex(
2413+
NameError,
2414+
'Found an unbound axis name: x. To fix this, please call psum under'
2415+
' `jax.shard_map`'):
2416+
f(jnp.arange(8.))
2417+
2418+
# fixes the above error
2419+
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P()) # doesn't crash
2420+
24052421
def test_shmap_auto_unreduced_error(self):
24062422
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
24072423
with self.assertRaisesRegex(

0 commit comments

Comments
 (0)