
Commit 3a7b077

Update tutorial examples to thread explicit RNGs
1 parent f400015 commit 3a7b077

File tree

6 files changed: +547 −384 lines changed


docs_nnx/guides/randomness.ipynb

Lines changed: 265 additions & 65 deletions
Large diffs are not rendered by default.

docs_nnx/guides/randomness.md

Lines changed: 94 additions & 32 deletions
@@ -62,36 +62,40 @@ dropout_key = rngs.dropout()
 nnx.display(rngs)
 ```
 
-Note that the `key` attribute does not change when new PRNG keys are generated.
+### Using random state with Flax Modules
 
-### Standard PRNG key stream names
+Almost all Flax Modules require random state for initialization. In a `Linear` layer, for example, the initial kernel values must be sampled from an appropriate distribution. Random state is generally provided via the `rngs` keyword argument at initialization:
 
-There are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below:
+```{code-cell} ipython3
+linear = nnx.Linear(20, 10, rngs=rngs)
+```
 
-| PRNG key stream name | Description                                   |
-|----------------------|-----------------------------------------------|
-| `params`             | Used for parameter initialization             |
-| `dropout`            | Used by `nnx.Dropout` to create dropout masks |
+Specifically, this uses the `RngStream` `rngs.params` for weight initialization. The `params` stream is also used to initialize `nnx.Conv`, `nnx.ConvTranspose`, and `nnx.Embed`.
 
-- `params` is used by most of the standard layers (such as `nnx.Linear`, `nnx.Conv`, `nnx.MultiHeadAttention`, and so on) during the construction to initialize their parameters.
-- `dropout` is used by `nnx.Dropout` and `nnx.MultiHeadAttention` to generate dropout masks.
++++
 
-Below is a simple example of a model that uses `params` and `dropout` PRNG key streams:
+The `nnx.Dropout` module also requires random state when it is called. Once again, this is provided via the `rngs` keyword argument:
 
 ```{code-cell} ipython3
-class Model(nnx.Module):
-  def __init__(self, rngs: nnx.Rngs):
-    self.linear = nnx.Linear(20, 10, rngs=rngs)
-    self.drop = nnx.Dropout(0.1, rngs=rngs)
+dropout = nnx.Dropout(0.5)
+```
 
-  def __call__(self, x):
-    return nnx.relu(self.drop(self.linear(x)))
+```{code-cell} ipython3
+dropout(jnp.ones(4), rngs=rngs)
+```
 
-model = Model(nnx.Rngs(params=0, dropout=1))
+This uses the `dropout` stream of the `Rngs` object. The same applies to Modules that use `Dropout` as a sub-Module, such as `nnx.MultiHeadAttention`.
 
-y = model(x=jnp.ones((1, 20)))
-print(f'{y.shape = }')
-```
++++
+
+To summarize, there are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below:
+
+| PRNG key stream name | Description                                   |
+|----------------------|-----------------------------------------------|
+| `params`             | Used for parameter initialization             |
+| `dropout`            | Used by `nnx.Dropout` to create dropout masks |
+
++++
 
 ### Default PRNG key stream
 
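The new section above boils down to those two standard streams. Here is a minimal sketch of how they divide the work; the imports and seed values are my own assumptions, not code from the guide:

```python
import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(params=0, dropout=1)

# Construction consumes the `params` stream...
linear = nnx.Linear(20, 10, rngs=rngs)
embed = nnx.Embed(num_embeddings=100, features=20, rngs=rngs)  # also `params`

# ...while stochastic layers consume the `dropout` stream at call time.
drop = nnx.Dropout(0.5)
y = drop(linear(jnp.ones((1, 20))), rngs=rngs)
```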

@@ -105,10 +109,6 @@ key1 = rngs.params() # Call params.
 key2 = rngs.dropout() # Fallback to the default stream.
 key3 = rngs() # Call the default stream directly.
 
-# Test with the `Model` that uses `params` and `dropout`.
-model = Model(rngs)
-y = model(jnp.ones((1, 20)))
-
 nnx.display(rngs)
 ```
 
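One detail the hunk above implies but does not show: named streams and the default stream can be mixed, and only missing names fall back. A small sketch, with arbitrarily chosen seed values:

```python
from flax import nnx

rngs = nnx.Rngs(0, params=1)  # positional seed -> the `default` stream
w_key = rngs.params()   # drawn from the dedicated `params` stream (seed 1)
d_key = rngs.dropout()  # no `dropout` stream defined, falls back to `default`
```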

@@ -134,9 +134,76 @@ z1 = rngs.normal((2, 3)) # generates key from rngs.default
 z2 = rngs.params.bernoulli(0.5, (10,)) # generates key from rngs.params
 ```
 
+## Forking random state
+
++++
+
+Say you want to train a model that uses dropout on a batch of data. You don't want every example in the batch to receive the same dropout mask; instead, you want to fork the random state into a separate piece for each batch element. This can be accomplished with the `fork` method, as shown below.
+
+```{code-cell} ipython3
+class Model(nnx.Module):
+  def __init__(self, rngs: nnx.Rngs):
+    self.linear = nnx.Linear(20, 10, rngs=rngs)
+    self.drop = nnx.Dropout(0.1)
+
+  def __call__(self, x, rngs):
+    return nnx.relu(self.drop(self.linear(x), rngs=rngs))
+```
+
+```{code-cell} ipython3
+model = Model(rngs=nnx.Rngs(0))
+```
+
+```{code-cell} ipython3
+@nnx.vmap(in_axes=(None, 0, 0), out_axes=0)
+def train_step(model, x, rngs):
+  return model(x, rngs=rngs)
+```
+
+```{code-cell} ipython3
+param_rngs = nnx.Rngs(1)
+```
+
+```{code-cell} ipython3
+param_rngs
+```
+
+```{code-cell} ipython3
+forked_rngs = param_rngs.fork(split=5)
+```
+
+```{code-cell} ipython3
+forked_rngs
+```
+
+```{code-cell} ipython3
+train_step(model, jnp.ones((5, 20)), forked_rngs)
+```
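A quick sanity check of the forked call, sketched on the assumption that `train_step` returns the model output (as fixed above):

```python
import jax.numpy as jnp

# Each of the 5 forked keys drives one replica of the vmapped step,
# so the output carries one row per batch element.
out = train_step(model, jnp.ones((5, 20)), forked_rngs)
assert out.shape == (5, 10)  # 5 replicas, 10 output features from the Linear layer
```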
+
+## Implicit random state
+
++++
+
+So far, we have passed random state directly to each Module when it is called. But there is another way to handle call-time randomness in Flax: bundling the random state into the Module itself. This requires passing the `rngs` keyword argument when initializing the Module rather than when calling it. For example, here is how we might construct the simple `Model` defined earlier in the implicit style.
+
+```{code-cell} ipython3
+class Model(nnx.Module):
+  def __init__(self, rngs: nnx.Rngs):
+    self.linear = nnx.Linear(20, 10, rngs=rngs)
+    self.drop = nnx.Dropout(0.1, rngs=rngs)
+
+  def __call__(self, x):
+    return nnx.relu(self.drop(self.linear(x)))
+
+model = Model(nnx.Rngs(params=0, dropout=1))
+
+y = model(x=jnp.ones((1, 20)))
+print(f'{y.shape = }')
+```
+
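Note that with the implicit style each call advances the bundled stream's internal `count`, so repeated calls draw fresh dropout masks without any `rngs` argument. A brief sketch:

```python
import jax.numpy as jnp

# Two calls on the same input: the bundled `dropout` stream advances
# between calls, so the masks (and typically the outputs) differ.
y1 = model(jnp.ones((1, 20)))
y2 = model(jnp.ones((1, 20)))
```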
 ## Filtering random state
 
-Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:
+Implicit random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:
 
 ```{code-cell} ipython3
 model = Model(nnx.Rngs(params=0, dropout=1))
@@ -157,7 +224,7 @@ In Haiku and Flax Linen, random states are explicitly passed to `Module.apply` e
 
 In Flax NNX, there are two ways to approach this:
 
-1. By passing an `nnx.Rngs` object through the `__call__` stack manually. Standard layers like `nnx.Dropout` and `nnx.MultiHeadAttention` accept the `rngs` argument if you want to have tight control over the random state.
+1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously.
 2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.
 
 `nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s; this includes both setting the `key` to a possibly new value and resetting the `count` to zero.
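As a minimal sketch of the call shape described in that paragraph (the seed value is arbitrary):

```python
from flax import nnx

# Stream names become keyword arguments: reset only `dropout`, setting its
# `key` from the new seed and its `count` back to zero; `params` is untouched.
nnx.reseed(model, dropout=42)
```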
@@ -180,12 +247,7 @@ assert jnp.allclose(y1, y3) # same
 
 ## Splitting PRNG keys
 
-When interacting with [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`, it is often necessary to split the random state such that each replica has its own unique state. This can be done in two ways:
-
-- By manually splitting a key before passing it to one of the `nnx.Rngs` streams; or
-- By using the `nnx.split_rngs` decorator which will automatically split the random state of any `nnx.RngStream`s found in the inputs of the function, and automatically "lower" them once the function call ends.
-
-It is more convenient to use `nnx.split_rngs`, since it works nicely with Flax NNX transforms, so here’s one example:
+When interacting with [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`, it is often necessary to split the implicit random state so that each replica has its own unique state. This can be done with the `nnx.split_rngs` decorator, which automatically splits the random state of any `nnx.RngStream`s found in the inputs of the function and "lowers" them again once the function call ends. Here’s an example:
 
 ```{code-cell} ipython3
 rngs = nnx.Rngs(params=0, dropout=1)
