Replies: 1 comment
-
I think this is a mistake made when porting the old BatchNorm to linen. It should be an argument to `__call__`. We have so far not included a global training=False/True switch on the Module like many other NN APIs have. One nice pattern to avoid errors is the following:
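A minimal sketch of such a pattern (the module name and layer sizes here are just illustrative): derive both `use_running_average` and `deterministic` from a single `train` flag on the parent module's `__call__`, so the mode is decided in exactly one place.

```python
import flax.linen as nn


class Classifier(nn.Module):
    """Illustrative model: one `train` flag drives both normalization and dropout modes."""

    @nn.compact
    def __call__(self, x, *, train: bool):
        x = nn.Dense(features=128)(x)
        # Use batch statistics while training, running averages at eval time.
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        # Stochastic while training, a no-op at eval time.
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        return nn.Dense(features=10)(x)
```

Applying it then looks roughly like `model.apply(variables, x, train=True, rngs={'dropout': key}, mutable=['batch_stats'])` for a training step and `model.apply(variables, x, train=False)` for evaluation (again, just a sketch).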
#683 should allow for a
-
The training state for BatchNorm is set via the `self.use_running_average` attribute. For Dropout, it is passed via the `deterministic` arg in `__call__`. (I realize those modes are not strictly specific to training vs. not training.)
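That is, something like (illustrative rate):

```python
import flax.linen as nn

# BatchNorm: the mode is baked in as a module attribute at construction time.
norm = nn.BatchNorm(use_running_average=True)

# Dropout: the mode is supplied per call instead, e.g. inside a parent module:
#     y = nn.Dropout(rate=0.1)(x, deterministic=True)
drop = nn.Dropout(rate=0.1)
```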
Is there any reason for this difference? I was planning to use a `training=False/True` arg in my `__call__` chain to pass the training state, as opposed to binding it at layer creation. I believe that still works fine with `jit` as long as the flag is marked as static (see the sketch at the end of this post)?

Having the training/not-training state passed in some cases as an arg to the layer constructor and in others as an arg to `__call__` is a bit jarring and seems error prone (I've already messed it up once).
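To make the `jit` point concrete, this is the kind of thing I mean by marking the flag static (a self-contained sketch; the function body is just a stand-in for a real forward pass):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=('train',))
def forward(x, *, train: bool):
    # A plain Python bool can branch freely here because jit treats it as
    # static and simply compiles one version per value (True and False).
    return x * 0.5 if train else x


forward(jnp.ones((4, 8)), train=True)
forward(jnp.ones((4, 8)), train=False)
```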