Replies: 6 comments
-
A quick workaround is to set `train=False` during init. I'll have to think a bit more about how to fix this nicely.
-
Thanks for your quick answer. From your answer, I understood that lazy init is part of the issue. I therefore switched to direct init (`module.init` instead of `init_by_shape`; full updated and working colab here: https://colab.research.google.com/drive/1Cm1MIHkKBmZ-xBq21fEG6byzd1IbrGSz?usp=sharing). It now works, but I wonder: is there any reason why I would prefer the lazy `init_by_shape` over a direct `init`?

For the record, the new `create_model` in the new colab:

```python
def create_model(prng_key, use_bn=True, shared=True):
  input_shape = (100, 64, 64, 2)
  model_dtype = jnp.float32
  module = MySiameseNet.partial(train=True, use_bn=use_bn, shared=shared)
  with nn.stateful() as init_state:
    with flax.nn.stochastic(prng_key):
      x = jnp.zeros(input_shape, dtype=model_dtype)
      _, initial_params = module.init(prng_key, x)
  model = nn.Model(module, initial_params)
  return model, init_state
```

Thanks
-
You don't need train=True during init. The problem is that with train=True you are trying to gather batch statistics (they don't exist because the init is lazy). Of course you can set train=True again during the actual train steps.
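For context on why that fails: in the pre-Linen API, `train` is typically wired into `use_running_average`, so `train=True` asks BatchNorm to gather batch statistics that a lazy, shape-only init never creates. A minimal sketch of that pattern (illustrative only, not the actual `MySiameseNet` code from the colab):

```python
import jax.numpy as jnp
from flax import nn  # pre-Linen flax.nn API

class Branch(nn.Module):
  """One branch of a Siamese net (illustrative sketch)."""

  def apply(self, x, train=True, use_bn=True):
    x = nn.Conv(x, features=32, kernel_size=(3, 3), bias=False)
    if use_bn:
      # train=True: compute statistics from the current batch and update the
      # running averages kept in the nn.stateful() state collection.
      # train=False: read the stored running averages instead.
      x = nn.BatchNorm(x, use_running_average=not train,
                       momentum=0.9, epsilon=1e-5)
    return nn.relu(x)
```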
-
Thanks, your answers made everything super clear to me. My mind was biased by Keras: in Keras, if you set trainable = False at BatchNorm creation, it is not possible to go back by setting it to True, since the layer runs in inference mode forever (at least from what I understand of this post: https://keras.io/guides/transfer_learning/#do-a-round-of-finetuning-of-the-entire-model). Flax's behavior, where the operator is the same in both modes, is excellent news.

For the record, based on your proposition: https://colab.research.google.com/drive/12Bgq0XSy-Y8G2a3HhHLaF6xKuRRQa5Z9?usp=sharing

```python
def create_model(prng_key, use_bn=True, shared=True):
  input_shape = (100, 64, 64, 2)
  model_dtype = jnp.float32
  # workaround for avoiding shared lazy init issues with bn
  module_for_init = MySiameseNet.partial(train=False, use_bn=use_bn, shared=shared)
  with nn.stateful() as init_state:
    _, initial_params = module_for_init.init_by_shape(prng_key, [(input_shape, model_dtype)])
  module_for_train = MySiameseNet.partial(train=True, use_bn=use_bn, shared=shared)
  model = nn.Model(module_for_train, initial_params)
  return model, init_state
```
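A hedged sketch of how the returned `model` and `init_state` might then be used in a train step (the actual training loop is not shown in this thread):

```python
# Hypothetical continuation of the snippet above.
model, state = create_model(prng_key, use_bn=True, shared=True)

# In each train step, re-enter nn.stateful with the current batch-norm state
# so the train=True module can update the running statistics.
x = jnp.zeros((100, 64, 64, 2), dtype=jnp.float32)  # stand-in for a real batch
with nn.stateful(state) as new_state:
  with flax.nn.stochastic(prng_key):
    outputs = model(x)
# Carry `new_state` (the updated batch statistics) into the next step.
```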
-
Yes, this works correctly, although I would prefer to write it as follows for simplicity:
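A plausible shape for that simpler version (a sketch, not the exact code from this reply; it assumes `init_by_shape` forwards keyword arguments like `train` to the module, so one partial can serve both init and training):

```python
def create_model(prng_key, use_bn=True, shared=True):
  input_shape = (100, 64, 64, 2)
  model_dtype = jnp.float32
  # A single partial; train is passed per call instead of being baked in.
  module = MySiameseNet.partial(use_bn=use_bn, shared=shared)
  with nn.stateful() as init_state:
    _, initial_params = module.init_by_shape(
        prng_key, [(input_shape, model_dtype)], train=False)
  model = nn.Model(module, initial_params)
  return model, init_state
```

During training, `train=True` would then be passed explicitly at call time, e.g. `model(x, train=True)`.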
-
IIUC, this is no longer an issue in Linen because we no longer have lazy initialization.
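For reference, a minimal Linen sketch of the same idea (names here are illustrative, not taken from the colabs): BatchNorm state lives in an explicit `batch_stats` collection, initialization is eager, and `train` is just a call-time argument.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Branch(nn.Module):
  @nn.compact
  def __call__(self, x, train: bool = True):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    return nn.relu(x)

x = jnp.zeros((100, 64, 64, 2), dtype=jnp.float32)
variables = Branch().init(jax.random.PRNGKey(0), x, train=False)
# During training, declare batch_stats as mutable so the updated running
# averages are returned alongside the outputs.
y, mutated = Branch().apply(variables, x, train=True, mutable=['batch_stats'])
```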
-
Problem you have encountered:
While trying to train a Siamese network based on the ImageNet example, which replicates model states across devices, I found an issue with batch norms.
Running on:
Logs, error messages, etc:
Steps to reproduce:
Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.
https://colab.research.google.com/drive/1tRqd7rykyvjxgF6Cqv9OqTdXVffYbqL2?usp=sharing