Commit 4296895

arnoegw authored and Flax Authors committed
Clarify *Norm layer docstrings: axis_index_groups is unused under SPMD jit.
PiperOrigin-RevId: 806240620
1 parent 4887f7d commit 4296895

2 files changed: +11 -11 lines changed

flax/linen/normalization.py

Lines changed: 6 additions & 6 deletions
@@ -90,7 +90,7 @@ def _compute_stats(
 this is only used for pmap and shard map. For SPMD jit, you do not need to
 manually synchronize. Just make sure that the axes are correctly annotated
 and XLA:SPMD will insert the necessary collectives.
-axis_index_groups: Optional axis indices.
+axis_index_groups: Optional groups of indices within that named axis.
 use_mean: If true, calculate the mean from the input and use it when
 computing the variance. If false, set the mean to zero and compute the
 variance without subtracting the mean.
@@ -300,7 +300,7 @@ class BatchNorm(Module):
 representing subsets of devices to reduce over (default: None). For
 example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
 examples on the first two and last two devices. See ``jax.lax.psum`` for
-more details.
+more details. This argument is currently not supported for SPMD jit.
 use_fast_variance: If true, use a faster, but less numerically stable,
 calculation for the variance.
 """
@@ -478,7 +478,7 @@ class LayerNorm(Module):
 representing subsets of devices to reduce over (default: None). For
 example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
 examples on the first two and last two devices. See ``jax.lax.psum`` for
-more details.
+more details. This argument is currently not supported for SPMD jit.
 use_fast_variance: If true, use a faster, but less numerically stable,
 calculation for the variance.
 """
@@ -580,7 +580,7 @@ class RMSNorm(Module):
 representing subsets of devices to reduce over (default: None). For
 example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
 examples on the first two and last two devices. See ``jax.lax.psum`` for
-more details.
+more details. This argument is currently not supported for SPMD jit.
 use_fast_variance: If true, use a faster, but less numerically stable,
 calculation for the variance.
 """
@@ -703,7 +703,7 @@ class GroupNorm(Module):
 representing subsets of devices to reduce over (default: None). For
 example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
 examples on the first two and last two devices. See ``jax.lax.psum`` for
-more details.
+more details. This argument is currently not supported for SPMD jit.
 use_fast_variance: If true, use a faster, but less numerically stable,
 calculation for the variance.
 """
@@ -879,7 +879,7 @@ class InstanceNorm(Module):
 representing subsets of devices to reduce over (default: None). For
 example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
 examples on the first two and last two devices. See ``jax.lax.psum`` for
-more details.
+more details. This argument is currently not supported for SPMD jit.
 use_fast_variance: If true, use a faster, but less numerically stable,
 calculation for the variance.
 """

flax/nnx/nn/normalization.py

Lines changed: 5 additions & 5 deletions
@@ -76,7 +76,7 @@ def _compute_stats(
 this is only used for pmap and shard map. For SPMD jit, you do not need to
 manually synchronize. Just make sure that the axes are correctly annotated
 and XLA:SPMD will insert the necessary collectives.
-axis_index_groups: Optional axis indices.
+axis_index_groups: Optional groups of indices within that named axis.
 use_mean: If true, calculate the mean from the input and use it when
 computing the variance. If false, set the mean to zero and compute the
 variance without subtracting the mean.
@@ -254,7 +254,7 @@ class BatchNorm(Module):
 representing subsets of devices to reduce over (default: None). For
 example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
 the examples on the first two and last two devices. See ``jax.lax.psum``
-for more details.
+for more details. This argument is currently not supported for SPMD jit.
 use_fast_variance: If true, use a faster, but less numerically stable,
 calculation for the variance.
 rngs: rng key.
@@ -429,7 +429,7 @@ class LayerNorm(Module):
 representing subsets of devices to reduce over (default: None). For
 example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
 the examples on the first two and last two devices. See ``jax.lax.psum``
-for more details.
+for more details. This argument is currently not supported for SPMD jit.
 use_fast_variance: If true, use a faster, but less numerically stable,
 calculation for the variance.
 rngs: rng key.
@@ -560,7 +560,7 @@ class RMSNorm(Module):
 representing subsets of devices to reduce over (default: None). For
 example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
 the examples on the first two and last two devices. See ``jax.lax.psum``
-for more details.
+for more details. This argument is currently not supported for SPMD jit.
 use_fast_variance: If true, use a faster, but less numerically stable,
 calculation for the variance.
 rngs: rng key.
@@ -702,7 +702,7 @@ class GroupNorm(Module):
 representing subsets of devices to reduce over (default: None). For
 example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
 examples on the first two and last two devices. See ``jax.lax.psum`` for
-more details.
+more details. This argument is currently not supported for SPMD jit.
 use_fast_variance: If true, use a faster, but less numerically stable,
 calculation for the variance.
 rngs: rng key.
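
Note on the ``[[0, 1], [2, 3]]`` example in the docstrings above: the following is a minimal pure-Python sketch of what grouped reduction means, not Flax's or JAX's implementation. The helper name `grouped_mean` and the sample values are hypothetical. It assumes each device along the named axis holds one scalar statistic, and that every device receives the mean over its own group, which is roughly the effect of passing `axis_index_groups` to a collective such as `jax.lax.pmean` under pmap.

```python
def grouped_mean(per_device_values, axis_index_groups=None):
    """Return, for each device index, the mean over that device's group.

    per_device_values[i] stands in for the statistic held by device i.
    With axis_index_groups=None, all devices form a single group.
    """
    n = len(per_device_values)
    groups = [list(range(n))] if axis_index_groups is None else axis_index_groups
    out = [None] * n
    for group in groups:
        values = [per_device_values[i] for i in group]
        mean = sum(values) / len(values)
        for i in group:
            out[i] = mean  # every member of the group sees the same result
    return out


# Hypothetical per-device batch means on four devices.
per_device_means = [1.0, 3.0, 10.0, 20.0]
print(grouped_mean(per_device_means))                    # [8.5, 8.5, 8.5, 8.5]
print(grouped_mean(per_device_means, [[0, 1], [2, 3]]))  # [2.0, 2.0, 15.0, 15.0]
```

With the groups set, devices 0 and 1 normalize with one shared statistic and devices 2 and 3 with another. Per the docstrings changed in this commit, this grouping applies to pmap and shard map and is currently not supported for SPMD jit.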
