
Commit cdf9ff8

Update tutorial examples to thread explicit RNGs
1 parent a9dccc0 commit cdf9ff8

File tree

6 files changed: +547 −402 lines


docs_nnx/guides/randomness.ipynb

Lines changed: 280 additions & 66 deletions
Large diffs are not rendered by default.

docs_nnx/guides/randomness.md

Lines changed: 101 additions & 31 deletions
@@ -64,35 +64,43 @@ nnx.display(rngs)
 
 Note that the `key` attribute does not change when new PRNG keys are generated.
 
-### Standard PRNG key stream names
++++
 
-There are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below:
+### Using random state with Flax modules
 
-| PRNG key stream name | Description |
-|----------------------|-----------------------------------------------|
-| `params` | Used for parameter initialization |
-| `dropout` | Used by `nnx.Dropout` to create dropout masks |
+Almost all Flax modules require random state for initialization. In a `Linear` layer, for example, we need to sample the initial weights from an appropriately scaled distribution. Random state is provided using the `rngs` keyword argument at initialization.
 
-- `params` is used by most of the standard layers (such as `nnx.Linear`, `nnx.Conv`, `nnx.MultiHeadAttention`, and so on) during the construction to initialize their parameters.
-- `dropout` is used by `nnx.Dropout` and `nnx.MultiHeadAttention` to generate dropout masks.
+```{code-cell} ipython3
+linear = nnx.Linear(20, 10, rngs=rngs)
+```
 
-Below is a simple example of a model that uses `params` and `dropout` PRNG key streams:
+Specifically, this will use the `RngStream` `rngs.params` for weight initialization. The `params` stream is also used for the initialization of `nnx.Conv`, `nnx.ConvTranspose`, and `nnx.Embed`.
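+
+For instance, here is a minimal sketch of other layers drawing their initial parameters from the same `params` stream (the layer sizes below are arbitrary):
+
+```{code-cell} ipython3
+# Each constructor consumes keys from `rngs.params` to initialize its parameters.
+conv = nnx.Conv(in_features=3, out_features=8, kernel_size=(3, 3), rngs=rngs)
+embed = nnx.Embed(num_embeddings=100, features=16, rngs=rngs)
+```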
 
-```{code-cell} ipython3
-class Model(nnx.Module):
-  def __init__(self, rngs: nnx.Rngs):
-    self.linear = nnx.Linear(20, 10, rngs=rngs)
-    self.drop = nnx.Dropout(0.1, rngs=rngs)
++++
 
-  def __call__(self, x):
-    return nnx.relu(self.drop(self.linear(x)))
+The `nnx.Dropout` module also requires random state, but it needs this state at *call* time rather than at initialization. Once again, we can pass it random state using the `rngs` keyword argument.
 
-model = Model(nnx.Rngs(params=0, dropout=1))
+```{code-cell} ipython3
+dropout = nnx.Dropout(0.5)
+```
 
-y = model(x=jnp.ones((1, 20)))
-print(f'{y.shape = }')
+```{code-cell} ipython3
+dropout(jnp.ones(4), rngs=rngs)
 ```
 
+The `nnx.Dropout` layer draws from the `dropout` stream of the `Rngs` object. The same applies to Modules that use `Dropout` as a sub-Module, such as `nnx.MultiHeadAttention`.
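+
+As a minimal sketch (the seed values are arbitrary), the stream can be named explicitly when constructing the `Rngs` object:
+
+```{code-cell} ipython3
+named_rngs = nnx.Rngs(params=0, dropout=1)
+dropout(jnp.ones(4), rngs=named_rngs)  # mask keys are drawn from `named_rngs.dropout`
+```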
+
++++
+
+To summarize, there are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below:
+
+| PRNG key stream name | Description |
+|----------------------|-----------------------------------------------|
+| `params` | Used for parameter initialization |
+| `dropout` | Used by `nnx.Dropout` to create dropout masks |
+
++++
+
 ### Default PRNG key stream
 
 One of the downsides of having named streams is that the user needs to know all the possible names that a model will use when creating the `nnx.Rngs` object. While this could be solved with some documentation, Flax NNX provides a `default` stream that can be
@@ -105,10 +113,6 @@ key1 = rngs.params() # Call params.
 key2 = rngs.dropout() # Fallback to the default stream.
 key3 = rngs() # Call the default stream directly.
 
-# Test with the `Model` that uses `params` and `dropout`.
-model = Model(rngs)
-y = model(jnp.ones((1, 20)))
-
 nnx.display(rngs)
 ```
 
@@ -134,9 +138,80 @@ z1 = rngs.normal((2, 3)) # generates key from rngs.default
 z2 = rngs.params.bernoulli(0.5, (10,)) # generates key from rngs.params
 ```
 
+## Forking random state
+
++++
+
+Say you want to train a model that uses dropout on a batch of data. You don't want to use the same random state for every dropout mask in the batch; instead, you want to fork the random state into a separate piece for each batch element. This can be accomplished with the `fork` method, as shown below.
+
+```{code-cell} ipython3
+class Model(nnx.Module):
+  def __init__(self, rngs: nnx.Rngs):
+    self.linear = nnx.Linear(20, 10, rngs=rngs)
+    self.drop = nnx.Dropout(0.1)
+
+  def __call__(self, x, rngs):
+    return nnx.relu(self.drop(self.linear(x), rngs=rngs))
+```
+
+```{code-cell} ipython3
+model = Model(rngs=nnx.Rngs(0))
+```
+
+```{code-cell} ipython3
+@nnx.vmap(in_axes=(None, 0, 0), out_axes=0)
+def train_step(model, x, rngs):
+  return model(x, rngs=rngs)
+```
+
+```{code-cell} ipython3
+dropout_rngs = nnx.Rngs(1)
+```
+
+```{code-cell} ipython3
+dropout_rngs
+```
+
+```{code-cell} ipython3
+forked_rngs = dropout_rngs.fork(split=5)
+```
+
+```{code-cell} ipython3
+forked_rngs
+```
+
+```{code-cell} ipython3
+train_step(model, jnp.ones((5, 20)), forked_rngs)
+```
+
+The output of `Rngs.fork` is another `Rngs` whose keys and counts have an expanded shape. In the example above, the `RngKey` and `RngCount` of `dropout_rngs` have shape `()`, but in `forked_rngs` they have shape `(5,)`.
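+
+As a quick sketch of that shape change (this assumes the single stream created by `nnx.Rngs(1)` is named `default`):
+
+```{code-cell} ipython3
+print(dropout_rngs.default.key.value.shape)  # ()
+print(forked_rngs.default.key.value.shape)   # (5,)
+```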
+
++++
+
+## Implicit random state
+
++++
+
+So far, we have looked at passing random state directly to each Module when it is called. But there is another way to handle call-time randomness in Flax: we can bundle the random state into the Module itself. This requires passing the `rngs` keyword argument when initializing the Module rather than when calling it. For example, here is how we might construct the simple `Model` we defined earlier in an implicit style.
+
+```{code-cell} ipython3
+class Model(nnx.Module):
+  def __init__(self, rngs: nnx.Rngs):
+    self.linear = nnx.Linear(20, 10, rngs=rngs)
+    self.drop = nnx.Dropout(0.1, rngs=rngs)
+
+  def __call__(self, x):
+    return nnx.relu(self.drop(self.linear(x)))
+
+model = Model(nnx.Rngs(params=0, dropout=1))
+
+y = model(x=jnp.ones((1, 20)))
+print(f'{y.shape = }')
+```
+
 ## Filtering random state
 
-Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:
+Implicit random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:
 
 ```{code-cell} ipython3
 model = Model(nnx.Rngs(params=0, dropout=1))
@@ -157,7 +232,7 @@ In Haiku and Flax Linen, random states are explicitly passed to `Module.apply` e
 
 In Flax NNX, there are two ways to approach this:
 
-1. By passing an `nnx.Rngs` object through the `__call__` stack manually. Standard layers like `nnx.Dropout` and `nnx.MultiHeadAttention` accept the `rngs` argument if you want to have tight control over the random state.
+1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously.
 2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.
 
 `nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s; this includes both setting the `key` to a possibly new value and resetting the `count` to zero.
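 
 For illustration, here is a minimal sketch of `nnx.reseed` applied to the implicit `Model` defined above (the seed value is arbitrary):
 
 ```{code-cell} ipython3
 model = Model(nnx.Rngs(params=0, dropout=1))
 y1 = model(jnp.ones((1, 20)))
 
 # Replace the `dropout` stream's key (derived from seed 42) and reset its count to zero.
 nnx.reseed(model, dropout=42)
 y2 = model(jnp.ones((1, 20)))  # uses the freshly reseeded dropout stream
 ```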
@@ -180,12 +255,7 @@ assert jnp.allclose(y1, y3) # same
 
 ## Splitting PRNG keys
 
-When interacting with [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`, it is often necessary to split the random state such that each replica has its own unique state. This can be done in two ways:
-
-- By manually splitting a key before passing it to one of the `nnx.Rngs` streams; or
-- By using the `nnx.split_rngs` decorator which will automatically split the random state of any `nnx.RngStream`s found in the inputs of the function, and automatically "lower" them once the function call ends.
-
-It is more convenient to use `nnx.split_rngs`, since it works nicely with Flax NNX transforms, so here’s one example:
+When interacting with [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`, it is often necessary to split the implicit random state so that each replica has its own unique state. This can be done using the `nnx.split_rngs` decorator, which automatically splits the random state of any `nnx.RngStream`s found in the inputs of the function and automatically "lowers" them once the function call ends. Here’s an example:
 
 ```{code-cell} ipython3
 rngs = nnx.Rngs(params=0, dropout=1)
