`docs_nnx/guides/randomness.md` (+94 −32 lines)
```{code-cell} ipython3
dropout_key = rngs.dropout()
nnx.display(rngs)
```
Note that the `key` attribute does not change when new PRNG keys are generated.

### Using random state with flax Modules

Almost all flax Modules require random state for initialization. In a `Linear` layer, for example, we need to sample the weights and biases from the appropriate normal distribution. Random state is generally provided using the `rngs` keyword argument at initialization.
Specifically, this will use the `RngStream` `rngs.params` for weight initialization. The `params` stream is also used for the initialization of `nnx.Conv`, `nnx.ConvTranspose`, and `nnx.Embed`.
+++
The `nnx.Dropout` module also requires random state when called. Once again, we provide it by passing the `rngs` keyword argument.
```{code-cell} ipython3
dropout = nnx.Dropout(0.5)
```
```{code-cell} ipython3
dropout(jnp.ones(4), rngs=rngs)
```
This will use the `dropout` stream of `rngs`. The same applies to Modules that use `Dropout` as a sub-Module, such as `nnx.MultiHeadAttention`.
+++
To summarize, there are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below:

| PRNG key stream name | Usage |
|----------------------|-------|
| `params`  | Used by standard layers (such as `nnx.Linear`, `nnx.Conv`, and `nnx.MultiHeadAttention`) to initialize their parameters |
| `dropout` | Used by `nnx.Dropout` to create dropout masks |

```{code-cell} ipython3
z2 = rngs.params.bernoulli(0.5, (10,))  # generates key from rngs.params
```
## Forking random state

+++
Say you want to train a model that uses dropout on a batch of data. You don't want to use the same random state for every dropout mask in your batch. Instead, you want to fork the random state into a separate piece for each element of the batch. This can be accomplished with the `fork` method, as shown below.
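The `fork`-based example itself is not included in this excerpt. Conceptually, forking corresponds to what plain `jax.random.split` does with a raw key: one parent key is split into independent per-example keys (the shapes below are illustrative):

```python
import jax

key = jax.random.key(0)                # parent key
batch_keys = jax.random.split(key, 4)  # one independent key per batch element
# e.g. a separate dropout-style mask per batch element
masks = jax.vmap(lambda k: jax.random.bernoulli(k, 0.9, (8,)))(batch_keys)
```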
+++

So far, we have looked at passing random state directly to each Module when it gets called. But there's another way to handle call-time randomness in flax: we can bundle the random state into the Module itself. This requires passing the `rngs` keyword argument when initializing the Module rather than when calling it. For example, here is how we might construct the simple `Module` we defined earlier in an implicit style.
```{code-cell} ipython3
class Model(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(20, 10, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.drop(self.linear(x)))

model = Model(nnx.Rngs(params=0, dropout=1))

y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
```
## Filtering random state
Implicit random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:
```{code-cell} ipython3
model = Model(nnx.Rngs(params=0, dropout=1))
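# The filter examples themselves are elided in this excerpt; below is a sketch
# of the calls the surrounding text describes (the filter types and the
# stream-name strings are the ones named above; the variable names are ours):
rng_state = nnx.state(model, nnx.RngState)    # all random state
key_state = nnx.state(model, nnx.RngKey)      # only the PRNG keys
count_state = nnx.state(model, nnx.RngCount)  # only the counts
dropout_state = nnx.state(model, 'dropout')   # only the `dropout` stream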
```

In Haiku and Flax Linen, random states are explicitly passed to `Module.apply` each time the model is called.
In Flax NNX, there are two ways to approach this:
1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously.
2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.
`nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s; this includes both setting the `key` to a possibly new value and resetting the `count` to zero.
```{code-cell} ipython3
assert jnp.allclose(y1, y3)  # same
```

## Splitting PRNG keys
When interacting with [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`, it is often necessary to split the implicit random state so that each replica has its own unique state. This can be done using the `nnx.split_rngs` decorator, which will automatically split the random state of any `nnx.RngStream`s found in the inputs of the function and automatically "lower" them once the function call ends. Here's an example: