Add parametrization for normalization #20
base: develop
Conversation
It does not depend on the input shape anymore.
Very interesting PR!
Just a small suggestion for vanilla_export
```diff
@@ -8,7 +8,7 @@ jobs:
   strategy:
     max-parallel: 4
     matrix:
-      python-version: [3.6, 3.7, 3.8]
+      python-version: [3.7, 3.8, 3.9]
```
Judging by the status of Python versions, Python 3.7 is already end-of-life. Maybe we could use a broader test matrix with more recent Python versions (and also PyTorch versions?). That is a larger change, though, and it could be postponed to a future PR.
Prerequisite: Torch parametrization tutorial

Features

- `model.eval()` mode: when `model.training` is set to `False`, spectral and Björck normalizations use cached tensors, so the normalization comes for free.
- `vanilla_model` export. Be careful, this is an in-place conversion!
- The `parametrize.cached()` feature is now also usable on Lipschitz layers, saving memory and compute when the same kernel is applied multiple times in an inference step (very useful for RNNs, multi-level convolutions, etc.).

Models using parametrized modules can only be serialized through `state_dict()`, so `torch.save(model, PATH)` is no longer possible and will raise an error. Instead, save and load your models through their `state_dict()`. For more information, check this torch tutorial.
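A minimal sketch of the `parametrize.cached()` usage described above. Torch's built-in `spectral_norm` parametrization stands in for a Lipschitz layer here; the layer and shapes are illustrative, not taken from this PR:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import spectral_norm

# Stand-in for a Lipschitz layer: a spectrally normalized Linear layer
# built with torch's parametrization API.
layer = spectral_norm(nn.Linear(4, 4))
layer.eval()

x = torch.randn(8, 4)
with parametrize.cached():
    # Inside this context the parametrized weight is computed once and
    # cached, so applying the layer twice does not re-run the normalization.
    y = layer(layer(x))
```

Outside the context manager, each forward pass recomputes the parametrized weight as usual.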
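A hedged sketch of the `state_dict()` save/load flow for parametrized models; `build_model` is a hypothetical constructor standing in for the real architecture, which must be rebuilt identically before loading the weights:

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm


def build_model():
    # Hypothetical constructor: recreate the same parametrized
    # architecture before loading the saved weights back in.
    return spectral_norm(nn.Linear(4, 4))


model = build_model()
path = os.path.join(tempfile.mkdtemp(), "model.pt")

# Save the weights only -- torch.save(model, path) would raise an error
# for a model containing parametrized modules.
torch.save(model.state_dict(), path)

# Rebuild the architecture, then restore the weights.
restored = build_model()
restored.load_state_dict(torch.load(path))
```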
TODO

- `parametrize.cached()` feature