Using Laplace approximation for LSTMs as base models #143
Hi Keyvan, can you change the last line of the
Hi Runa, I changed the shape of the output as you mentioned and tried the default settings for Laplace, but I got another error:
It would be great if you could advise me on this error. I call model.train() before fitting Laplace, so the model should be in training mode, but this does not seem to have any effect: if I remove the call, I get the same error message. Thanks a lot for your support.
We call
This might result in a slow-down, but should make it work.
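The exact call the maintainer refers to was not preserved above. One plausible reading of the slow-down remark, offered here purely as an assumption: cuDNN's fused LSTM kernels only implement the backward pass in training mode, so curvature computations that switch the model to eval mode can fail on RNNs, and disabling cuDNN works around this at the cost of speed. A minimal sketch of that workaround:

```python
import torch

# Assumption (hedged): cuDNN's fused RNN/LSTM kernels only support backward
# in training mode. Disabling cuDNN falls back to PyTorch's slower native
# RNN implementation, which also supports backward in eval mode.
torch.backends.cudnn.enabled = False
```

This is a global switch, so it is usually set once before fitting and can be re-enabled afterwards.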
Hi Runa,
Glad it works now. Can you double-check that the conditions actually hold? This error is only raised when at least one of them does not. Regarding applying Laplace to the full network: the issue is that not all modules used in your model are supported by our two backends,
Thanks for your advice. I see your point about the full network. But regarding the subnetwork, I am quite sure that everything should work. I am using this code to check the conditions
and this is the result:
Once again, thanks for your prompt support.
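The checking code referenced in this comment was not preserved. Purely as an illustration, here is a rough, self-contained sketch of the kind of conditions a subnetwork specification typically has to satisfy: integer indices into the vectorized parameters, unique and within bounds. The helper name and exact checks are assumptions for this sketch, not the laplace library's actual validation code.

```python
def check_subnet_indices(n_params, subnetwork_indices):
    """Hypothetical helper (not part of the laplace API): sanity-check a
    list of subnetwork indices against the total number of parameters."""
    idx = list(subnetwork_indices)
    if not all(isinstance(i, int) for i in idx):
        return False, "indices must be integers"
    if len(idx) != len(set(idx)):
        return False, "indices must be unique"
    if not all(0 <= i < n_params for i in idx):
        return False, "indices must lie in [0, n_params)"
    return True, "ok"
```

Running each condition separately like this makes it easy to see which one a cryptic validation error actually refers to.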
Did you install the package via
Yeah, that's correct, and installing from the repo resolves the problem. Unfortunately, I am not able to use this option, as I suspect the batch-normalization layer in my model is not supported by the Laplace library and its backends. Anyway, thanks a lot for your support.
@keyvan-amiri could you try this with our new backend? Install the
Then use the diagonal GGN/EF from curvlinops (I don't think curvlinops KFAC supports BatchNorm):

```python
from laplace import Laplace
from laplace.curvature import CurvlinopsGGN, CurvlinopsEF

model = ...
la = Laplace(model, likelihood, hessian_structure='diag', backend=CurvlinopsEF)
```
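For intuition about what a diagonal empirical Fisher (EF) backend estimates, the quantity can be sketched in plain PyTorch: the sum over data points of the squared per-sample loss gradient for each parameter. This is a conceptual illustration of the EF definition only, not the curvlinops or laplace implementation.

```python
import torch

def diag_empirical_fisher(model, loss_fn, data):
    """Conceptual sketch: accumulate squared per-sample loss gradients,
    giving a diagonal empirical Fisher approximation of the curvature."""
    params = [p for p in model.parameters() if p.requires_grad]
    diag = [torch.zeros_like(p) for p in params]
    for x, y in data:
        model.zero_grad()
        loss = loss_fn(model(x), y)
        grads = torch.autograd.grad(loss, params)
        for d, g in zip(diag, grads):
            d += g.detach() ** 2  # square of the per-sample gradient
    return diag
```

Because this only needs first-order gradients of the loss, it tends to support a wider range of modules (e.g. BatchNorm) than structured curvature approximations such as KFAC.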
Hi everyone, and thanks for all the effort you have put into this library. I am Keyvan, and I am working on a regression task in which we use an existing pre-trained model consisting of LSTM layers (plus dropout and batch normalization) to get a deterministic point estimate, and then try to apply post-hoc prior-precision tuning through the laplace library.
Even though I am following the documentation of the Laplace library, I get different errors when Laplace.fit is called. For instance, using the default values of the Laplace class, I encounter this error: "ValueError: Only 2D inputs are currently supported for MSELoss", and if I change some of the parameters, the line self.H += H_batch (within the fit method of ParametricLaplace) leads to a shape-mismatch error. A few days ago I wrote an email to Alexander, and he suggested that the best way forward is to create a new issue here to ask for your kind support. I guess you have not yet tried the library with LSTMs, but it might be possible with the different backends available now. More details about my implementation are as follows:
Loading pretrained model and calling post-hoc laplace method:
The relevant method is:
and the error is as follows:
```
File "/home/kamiriel/miniconda3/envs/laplace/lib/python3.8/site-packages/laplace/baselaplace.py", line 379, in fit
    self.H += H_batch
RuntimeError: The size of tensor a (6) must match the size of tensor b (192) at non-singleton dimension 1
```
while the pretrained model is defined as follows:
Additional context:
The input of the model has the shape batch_size × sequence_dim × feature_dim (in this example 3214453), while the output shape of the model is batch_size (1-D), since we have a regression task at hand. The last layer of the model is a linear layer of size linear_hidden_size (in this example 5).
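The 1-D output shape described above is a likely cause of the "Only 2D inputs are currently supported for MSELoss" error: for regression, the fitting code expects outputs of shape (batch_size, 1) rather than (batch_size,). One way to adapt an existing pre-trained model without retraining is a thin wrapper that adds the missing trailing dimension. The wrapper name is hypothetical, not part of the laplace API:

```python
import torch

class Reshape2DWrapper(torch.nn.Module):
    """Hypothetical wrapper: turn a 1-D regression output (batch,)
    into the 2-D shape (batch, 1) that MSELoss-based fitting expects."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        # (batch,) -> (batch, 1); leave already-2-D outputs untouched
        return out.unsqueeze(-1) if out.dim() == 1 else out
```

The targets in the data loader would need the matching (batch_size, 1) shape as well, so that the loss broadcasts correctly.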
The relevant parameters are set through a config file as follows:
and the get_backend method is simply:
Your support is highly appreciated.
Please let me know if I should provide more details.
Regards,
Keyvan