@@ -928,7 +928,7 @@ def _allreduce_impl(prim, pos_reducer, *args, axes, axis_index_groups):
928
928
return [pos_reducer (arg , axes ) for arg in args ]
929
929
930
930
def _allreduce_effectful_abstract_eval (* args , axes , axis_index_groups ):
931
- _check_axis_names (axes )
931
+ _check_axis_names (axes , 'psum' )
932
932
named_axes = tuple (axis for axis in axes if not isinstance (axis , int ))
933
933
pos_axes = tuple (axis for axis in axes if isinstance (axis , int ))
934
934
if axis_index_groups is not None :
@@ -949,7 +949,7 @@ def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups):
949
949
* args , axes = axes , axis_index_groups = axis_index_groups )
950
950
951
951
assert isinstance (axes , tuple )
952
- _check_axis_names (axes )
952
+ _check_axis_names (axes , 'psum' )
953
953
arg_vma = [a .vma for a in args ]
954
954
# If intersection between arg_vma and axes is empty, error
955
955
if any (not set (axes ) & a for a in arg_vma ):
@@ -985,12 +985,14 @@ def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups):
985
985
return _psum_invariant_abstract_eval (
986
986
name , * args , axes = axes , axis_index_groups = axis_index_groups )
987
987
988
- def _check_axis_names (axes ):
988
+ def _check_axis_names (axes , api_name ):
989
989
named_axes = tuple (axis for axis in axes if not isinstance (axis , int ))
990
990
axis_env = core .get_axis_env ()
991
991
for name in named_axes :
992
992
if not axis_env .axis_exists (name ):
993
- raise NameError (f"unbound axis name: { name } " )
993
+ raise NameError (
994
+ f"Found an unbound axis name: { name } . To fix this, please call"
995
+ f" { api_name } under `jax.shard_map`." )
994
996
995
997
def _allreduce_lowering (prim , pos_fn , ctx , * args , axes , axis_index_groups ):
996
998
if axis_index_groups is not None and ("tpu" in ctx .module_context .platforms ):
@@ -1166,7 +1168,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):
1166
1168
return v .take (perm_indices , d ), d
1167
1169
1168
1170
def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
  """Abstract-eval rule that returns its input ``x`` unchanged.

  Before returning, ``axis_name`` is validated with ``_check_axis_names`` and
  passed, together with ``x``, to ``collective_vma_rule``; both calls are
  labelled ``'ppermute'``. Any additional keyword parameters are accepted and
  ignored.
  """
  del params  # accepted for signature compatibility; not used here
  api_name = 'ppermute'
  _check_axis_names(axis_name, api_name)
  collective_vma_rule(api_name, axis_name, x)
  return x
1172
1174
@@ -1218,7 +1220,7 @@ def _psend_lowering_gpu(ctx, x, *, axis_name, perm):
1218
1220
1219
1221
1220
1222
def _psend_abstract_eval (x , * , axis_name , ** params ):
1221
- _check_axis_names (axis_name )
1223
+ _check_axis_names (axis_name , 'psend' )
1222
1224
return abstract_token , {
1223
1225
* map (core .NamedAxisEffect , axis_name ),
1224
1226
single_side_collective_effect ,
@@ -1492,7 +1494,7 @@ def _all_to_all_effectful_abstract_eval(
1492
1494
del tiled # expand_dims and squeeze is done in `all_to_all` if `True`
1493
1495
if not isinstance (axis_name , (list , tuple )):
1494
1496
axis_name = (axis_name ,)
1495
- _check_axis_names (axis_name )
1497
+ _check_axis_names (axis_name , 'all_to_all' )
1496
1498
shape = list (input_aval .shape )
1497
1499
axis_size = (
1498
1500
_axis_size (axis_name )
@@ -1581,7 +1583,7 @@ def _ragged_all_to_all_effectful_abstract_eval(
1581
1583
" size, but got shape {}" .format (recv_sizes .shape )
1582
1584
)
1583
1585
1584
- _check_axis_names (axis_name )
1586
+ _check_axis_names (axis_name , 'ragged_all_to_all' )
1585
1587
out_aval = output .update (shape = output .shape , weak_type = False )
1586
1588
effects = {* map (core .NamedAxisEffect , axis_name )}
1587
1589
return out_aval , effects
@@ -1802,7 +1804,7 @@ def _all_gather_effectful_abstract_eval(
1802
1804
):
1803
1805
if not isinstance (axis_name , (list , tuple )):
1804
1806
axis_name = (axis_name ,)
1805
- _check_axis_names (axis_name )
1807
+ _check_axis_names (axis_name , 'all_gather' )
1806
1808
new_shape = list (x_aval .shape )
1807
1809
if tiled :
1808
1810
new_shape [all_gather_dimension ] *= axis_size
@@ -1920,7 +1922,7 @@ def bind(leaf):
1920
1922
def _all_gather_invariant_effectful_abstract_eval (
1921
1923
x_aval , * , all_gather_dimension , axis_name , axis_size , tiled
1922
1924
):
1923
- _check_axis_names (axis_name )
1925
+ _check_axis_names (axis_name , 'all_gather_invariant' )
1924
1926
new_shape = list (x_aval .shape )
1925
1927
if tiled :
1926
1928
new_shape [all_gather_dimension ] *= axis_size
@@ -2026,7 +2028,7 @@ def _reduce_scatter_effectful_abstract_eval(
2026
2028
):
2027
2029
if not isinstance (axis_name , (list , tuple )):
2028
2030
axis_name = (axis_name ,)
2029
- _check_axis_names (axis_name )
2031
+ _check_axis_names (axis_name , 'reduce_scatter' )
2030
2032
new_shape = list (x_aval .shape )
2031
2033
scatter_dim_input_size = x_aval .shape [scatter_dimension ]
2032
2034
if tiled :
@@ -2244,7 +2246,7 @@ def _axis_index_lowering(ctx, *, axis_name):
2244
2246
def _axis_index_effectful_abstract_eval (* , axis_name ):
2245
2247
effect = {core .NamedAxisEffect (axis_name )}
2246
2248
axis_name = (axis_name ,) if not isinstance (axis_name , tuple ) else axis_name
2247
- _check_axis_names (axis_name )
2249
+ _check_axis_names (axis_name , 'axis_index' )
2248
2250
mesh = get_abstract_mesh ()
2249
2251
sharding = NamedSharding (mesh , P ())
2250
2252
vma = ((frozenset (axis_name ) if mesh ._any_axis_manual else frozenset ())
@@ -2280,7 +2282,7 @@ def _pgather_impl(src, idx, *, axes):
2280
2282
def _pgather_abstract_eval (src , idx , * , axes ):
2281
2283
# TODO: Avals with names rule: remove all axes from src, insert those from idx
2282
2284
# The order is important, because it is ok to re-insert one of the deleted axes!
2283
- _check_axis_names (axes )
2285
+ _check_axis_names (axes , 'pgather' )
2284
2286
shape = list (src .shape )
2285
2287
for axis in sorted ((a for a in axes if isinstance (a , int )), reverse = True ):
2286
2288
del shape [axis ]
0 commit comments