Skip to content

Commit 7e8f152

Browse files
author
Flax Authors
committed
Merge pull request #4970 from google:fix-typos-flax-gspmd
PiperOrigin-RevId: 811468958
2 parents 8fe45a7 + e1febda commit 7e8f152

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

docs_nnx/guides/flax_gspmd.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@
273273
"\n",
274274
"* **Additional axis annotation**: Transforms like `vmap` and `scan` will add additional dimensions to the JAX arrays. Unfortunately, in auto sharding mode you will need to use `nnx.vmap` and `nnx.scan` instead of raw JAX transforms, so that both JAX and Flax knows how to shard this dimension. You won't need this in [explicit sharding mode](#explicit-sharding).\n",
275275
"\n",
276-
"* [`jax.lax.with_sharding_constraint](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#constraining-shardings-of-intermediates-in-jitted-code): They can help you to enforce specific shardings on intermediate activations. Only works under an auto mode mesh context."
276+
"* [`jax.lax.with_sharding_constraint`](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#constraining-shardings-of-intermediates-in-jitted-code): They can help you to enforce specific shardings on intermediate activations. Only works under an auto mode mesh context."
277277
]
278278
},
279279
{
@@ -446,7 +446,7 @@
446446
"source": [
447447
"## Logical axis annotation\n",
448448
"\n",
449-
"JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD]((https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD)) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes.\n",
449+
"JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes.\n",
450450
"\n",
451451
"You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below."
452452
]

docs_nnx/guides/flax_gspmd.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ Make note of the following:
119119

120120
* **Additional axis annotation**: Transforms like `vmap` and `scan` will add additional dimensions to the JAX arrays. Unfortunately, in auto sharding mode you will need to use `nnx.vmap` and `nnx.scan` instead of raw JAX transforms, so that both JAX and Flax knows how to shard this dimension. You won't need this in [explicit sharding mode](#explicit-sharding).
121121

122-
* [`jax.lax.with_sharding_constraint](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#constraining-shardings-of-intermediates-in-jitted-code): They can help you to enforce specific shardings on intermediate activations. Only works under an auto mode mesh context.
122+
* [`jax.lax.with_sharding_constraint`](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#constraining-shardings-of-intermediates-in-jitted-code): They can help you to enforce specific shardings on intermediate activations. Only works under an auto mode mesh context.
123123

124124
```{code-cell} ipython3
125125
class DotReluDot(nnx.Module):
@@ -230,7 +230,7 @@ print(restored_model.layers.dot1.kernel.shape)
230230

231231
## Logical axis annotation
232232

233-
JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD]((https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD)) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes.
233+
JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes.
234234

235235
You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below.
236236

0 commit comments

Comments
 (0)