Almost all Flax Modules require a random state for initialization. In a `Linear` layer, for example, we need to sample the weights and biases from the appropriate normal distribution. Random state is provided using the `rngs` keyword argument at initialization.
```{code-cell} ipython3
linear = nnx.Linear(20, 10, rngs=rngs)
```
Specifically, this will use the `RngStream` `rngs.params` for weight initialization. The `params` stream is also used for the initialization of `nnx.Conv`, `nnx.ConvTranspose`, and `nnx.Embed`.
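As a brief sketch, here are two such layers drawing their initial parameters from the same stream (the specific constructor arguments below are illustrative assumptions, not canonical values):

```{code-cell} ipython3
# Both layers consume keys from the `params` stream of `rngs` at construction.
conv = nnx.Conv(in_features=3, out_features=8, kernel_size=(3, 3), rngs=rngs)
embed = nnx.Embed(num_embeddings=100, features=16, rngs=rngs)
```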
+++
The `nnx.Dropout` module also requires random state, but at *call* time rather than at initialization. Once again, we can pass it random state using the `rngs` keyword argument.
```{code-cell} ipython3
dropout = nnx.Dropout(0.5)
```
```{code-cell} ipython3
dropout(jnp.ones(4), rngs=rngs)
```
The `nnx.Dropout` layer will use the `dropout` stream of the `Rngs` object. This also applies to Modules that use `Dropout` as a sub-Module, like `nnx.MultiHeadAttention`.
+++
To summarize, there are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below:
| PRNG key stream name | Description |
|---|---|
| `params` | Used by standard layers (`nnx.Linear`, `nnx.Conv`, `nnx.Embed`, etc.) to initialize their parameters |
| `dropout` | Used by `nnx.Dropout` to create dropout masks |
+++
### Default PRNG key stream
One of the downsides of having named streams is that the user needs to know all the possible names that a model will use when creating the `nnx.Rngs` object. While this could be solved with some documentation, Flax NNX provides a `default` stream that can be used as a fallback when a given stream name is not found.
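For example, here is a minimal sketch of the fallback behavior (we assume, as an illustration, that a seed passed as the first positional argument populates the `default` stream):

```{code-cell} ipython3
rngs = nnx.Rngs(0, params=1)  # positional seed -> `default` stream (assumption)

key1 = rngs.params()   # drawn from the `params` stream
key2 = rngs.dropout()  # no `dropout` stream defined, falls back to `default`
```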
```{code-cell} ipython3
z2 = rngs.params.bernoulli(0.5, (10,)) # generates key from rngs.params
```
## Forking random state
+++
Say you want to train a model that uses dropout on a batch of data. You don't want to use the same random state for every dropout mask in your batch. Instead, you want to fork the random state into a separate piece for each element of the batch. This can be accomplished with the `fork` method, as shown below.
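Here is a minimal sketch of what that can look like (the exact signature of `fork` is an assumption; we spell the number of forks as `split=5`):

```{code-cell} ipython3
dropout_rngs = nnx.Rngs(dropout=0)

# Fork the random state into 5 independent pieces
# (the `split` keyword is a hypothetical spelling).
forked_rngs = dropout_rngs.fork(split=5)
```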
The output of `rngs.fork` is another `Rngs` whose keys and counts have an expanded shape. In the example above, the `RngKey` and `RngCount` of `dropout_rngs` have shape `()`, but in `forked_rngs` they have shape `(5,)`.
+++
## Implicit random state
+++
So far, we have looked at passing random state directly to each Module when it gets called. But there's another way to handle call-time randomness in Flax: we can bundle the random state into the Module itself. This requires passing the `rngs` keyword argument when initializing the module rather than when calling it. For example, here is how we might combine the `Linear` and `Dropout` layers from earlier into a single `Module` using the implicit style.
```{code-cell} ipython3
class Model(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(20, 10, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.drop(self.linear(x)))

model = Model(nnx.Rngs(params=0, dropout=1))

y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
```
## Filtering random state
Implicit random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:
```{code-cell} ipython3
model = Model(nnx.Rngs(params=0, dropout=1))

rng_state = nnx.state(model, nnx.RngState)  # all random state
key_state = nnx.state(model, nnx.RngKey)  # only the keys
count_state = nnx.state(model, nnx.RngCount)  # only the counts
params_state = nnx.state(model, 'params')  # only the `params` stream
dropout_state = nnx.state(model, 'dropout')  # only the `dropout` stream
```

### Reseeding

In Haiku and Flax Linen, random states are explicitly passed to `Module.apply` each time the model is called.
In Flax NNX, there are two ways to approach this:
1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously.
2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.
`nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s; this includes both setting the `key` to a possibly new value and resetting the `count` to zero.
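For example, here is a sketch of reseeding the `dropout` stream of the `Model` defined above; resetting the stream to its original seed makes the first post-reseed call reproduce the first call:

```{code-cell} ipython3
model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))

y1 = model(x)  # uses the first `dropout` key
y2 = model(x)  # uses the second `dropout` key

nnx.reseed(model, dropout=1)  # reset `dropout`: same seed, count back to zero
y3 = model(x)  # repeats the first `dropout` key

assert not jnp.allclose(y1, y2)  # different
assert jnp.allclose(y1, y3)  # same
```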
## Splitting PRNG keys
When interacting with [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`, it is often necessary to split the implicit random state such that each replica has its own unique state. This can be done using the `nnx.split_rngs` decorator, which will automatically split the random state of any `nnx.RngStream`s found in the inputs of the function, and automatically "lower" them once the function call ends. Here's an example:
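```{code-cell} ipython3
# A sketch of `nnx.split_rngs`; the `splits` and `only` arguments are assumptions.
rngs = nnx.Rngs(params=0, dropout=1)

@nnx.split_rngs(splits=5, only='dropout')
def forward(rngs: nnx.Rngs):
  # Inside the function, the `dropout` stream has been split into 5 keys,
  # while the `params` stream is left untouched.
  print(f'{rngs.dropout.key.shape = }')
  print(f'{rngs.params.key.shape = }')

forward(rngs)
```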