@@ -90,7 +90,7 @@ def _compute_stats(
90
90
this is only used for pmap and shard map. For SPMD jit, you do not need to
91
91
manually synchronize. Just make sure that the axes are correctly annotated
92
92
and XLA:SPMD will insert the necessary collectives.
93
- axis_index_groups: Optional axis indices.
93
+ axis_index_groups: Optional groups of indices within that named axis.
94
94
use_mean: If true, calculate the mean from the input and use it when
95
95
computing the variance. If false, set the mean to zero and compute the
96
96
variance without subtracting the mean.
@@ -300,7 +300,7 @@ class BatchNorm(Module):
300
300
representing subsets of devices to reduce over (default: None). For
301
301
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
302
302
examples on the first two and last two devices. See ``jax.lax.psum`` for
303
- more details.
303
+ more details. This argument is currently not supported for SPMD jit.
304
304
use_fast_variance: If true, use a faster, but less numerically stable,
305
305
calculation for the variance.
306
306
"""
@@ -478,7 +478,7 @@ class LayerNorm(Module):
478
478
representing subsets of devices to reduce over (default: None). For
479
479
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
480
480
examples on the first two and last two devices. See ``jax.lax.psum`` for
481
- more details.
481
+ more details. This argument is currently not supported for SPMD jit.
482
482
use_fast_variance: If true, use a faster, but less numerically stable,
483
483
calculation for the variance.
484
484
"""
@@ -580,7 +580,7 @@ class RMSNorm(Module):
580
580
representing subsets of devices to reduce over (default: None). For
581
581
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
582
582
examples on the first two and last two devices. See ``jax.lax.psum`` for
583
- more details.
583
+ more details. This argument is currently not supported for SPMD jit.
584
584
use_fast_variance: If true, use a faster, but less numerically stable,
585
585
calculation for the variance.
586
586
"""
@@ -703,7 +703,7 @@ class GroupNorm(Module):
703
703
representing subsets of devices to reduce over (default: None). For
704
704
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
705
705
examples on the first two and last two devices. See ``jax.lax.psum`` for
706
- more details.
706
+ more details. This argument is currently not supported for SPMD jit.
707
707
use_fast_variance: If true, use a faster, but less numerically stable,
708
708
calculation for the variance.
709
709
"""
@@ -879,7 +879,7 @@ class InstanceNorm(Module):
879
879
representing subsets of devices to reduce over (default: None). For
880
880
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
881
881
examples on the first two and last two devices. See ``jax.lax.psum`` for
882
- more details.
882
+ more details. This argument is currently not supported for SPMD jit.
883
883
use_fast_variance: If true, use a faster, but less numerically stable,
884
884
calculation for the variance.
885
885
"""
0 commit comments