Laplace Subnetwork with timm library model #128

Open
nilsleh opened this issue Jul 6, 2023 · 6 comments

nilsleh commented Jul 6, 2023

Hi,

I would like to apply the Laplace Subnetwork approach to a timm library model (a standard resnet18). I think the problem I am encountering is not specific to timm models per se, but rather to in-place operations. I have made a small reproducible example in this Google Colab. The error

Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

only occurs with the subnetwork approach, not with the default Laplace parameters. I have also tried changing the timm resnet activation functions to not be in-place, but maybe it is also related to the skip connections? Even though the error does not immediately occur inside the Laplace library itself, I was wondering if you had any suggestions or pointers to make this approach work.
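
For reference, a minimal sketch of the setup being described (the mask choice, n_params_subnet, and train_loader are assumptions for illustration; the Colab may differ), using laplace-torch's subnetwork utilities with a timm resnet18:

import timm
import torch
from laplace import Laplace
from laplace.utils import LargestMagnitudeSubnetMask

model = timm.create_model("resnet18", num_classes=10)

# timm resnets use in-place ReLUs; turning that off was one of the things tried
for m in model.modules():
    if isinstance(m, torch.nn.ReLU):
        m.inplace = False

# select a subnetwork, e.g. the 1000 largest-magnitude parameters
subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=1000)
subnetwork_indices = subnetwork_mask.select()

la = Laplace(
    model,
    "classification",
    subset_of_weights="subnetwork",
    hessian_structure="full",  # subnetwork Laplace uses a full Hessian structure
    subnetwork_indices=subnetwork_indices,
)
la.fit(train_loader)  # train_loader: any torch DataLoader (assumed)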

nilsleh changed the title from "Laplace Subnetwork with timm libray model" to "Laplace Subnetwork with timm library model" on Jul 6, 2023
nilsleh (Author) commented Jul 10, 2023

wiseodd (Collaborator) commented Jul 10, 2023

If it's a BackPACK issue, then maybe switching the backend will help. Can you try the following?

from laplace import Laplace
from laplace.curvature import AsdlGGN

la = Laplace(model, ..., backend=AsdlGGN)
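
For the subnetwork case above, the full call would look roughly like this (a sketch; every argument except backend is carried over from the setup sketch earlier, not from this comment):

from laplace import Laplace
from laplace.curvature import AsdlGGN

la = Laplace(
    model,
    "classification",
    subset_of_weights="subnetwork",
    hessian_structure="full",
    subnetwork_indices=subnetwork_indices,
    backend=AsdlGGN,  # swap out the default BackPACK-based backend
)
la.fit(train_loader)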

runame (Collaborator) commented Jul 10, 2023

Was just about to suggest the same thing. However, if you want to use AsdlGGN together with regression, you will have to install the Laplace library from source and check out the integrate-latest-asdl branch. Also, you will have to install ASDL from source and just use the master branch. Let us know if you run into any issues with this!
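
A sketch of the from-source installs described above (the repository URLs are assumptions; adjust them to wherever the Laplace and ASDL sources live):

git clone https://github.com/aleximmer/Laplace.git
cd Laplace
git checkout integrate-latest-asdl
pip install -e .

git clone https://github.com/kazukiosawa/asdl.git
cd asdl
pip install -e .  # master is the default branch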

nilsleh (Author) commented Jul 11, 2023

Thanks for the recommendation. I believe I installed everything as you suggested @runame; however, I get

laplace/curvature/asdl.py", line 132, in diag
    fisher_maker = get_fisher_maker(self.model, cfg, self.kfac_conv)
TypeError: get_fisher_maker() takes 2 positional arguments but 3 were given

runame (Collaborator) commented Jul 11, 2023

Ah right, can you try removing the self.kfac_conv argument from the get_fisher_maker call?

Edit: I also just fixed this on the integrate-latest-asdl branch, so you can just pull again.
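
In other words, the call in laplace/curvature/asdl.py would change roughly like so (a sketch of the suggested edit, not necessarily the exact committed fix):

# before: passes three positional arguments
fisher_maker = get_fisher_maker(self.model, cfg, self.kfac_conv)

# after: drop self.kfac_conv
fisher_maker = get_fisher_maker(self.model, cfg)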

nilsleh (Author) commented Aug 29, 2023

Thank you for the help. For some models I do get the following PyTorch warning, and I am not sure what it implies:
UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.

From this trace:

lib/python3.9/site-packages/laplace/baselaplace.py:377: in fit
    loss_batch, H_batch = self._curv_closure(X, y, N)
lib/python3.9/site-packages/laplace/baselaplace.py:777: in _curv_closure
    return self.backend.kron(X, y, N=N)
lib/python3.9/site-packages/laplace/curvature/asdl.py:164: in kron
    f, _ = fisher_maker.forward_and_backward()
lib/python3.9/site-packages/asdl/fisher.py:116: in forward_and_backward
    self.call_model()
lib/python3.9/site-packages/asdl/grad_maker.py:258: in call_model
    self._model_output = self._model_fn(*self._model_args, **self._model_kwargs)
lib/python3.9/site-packages/torch/nn/modules/module.py:1571: in _call_impl
    self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
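
For context, the warning concerns PyTorch's deprecated Module.register_backward_hook, which can report an incomplete grad_input when a module's forward creates multiple autograd nodes; register_full_backward_hook is the documented replacement. A minimal standalone illustration (not laplace or ASDL code):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

def hook(module, grad_input, grad_output):
    # a full backward hook receives the complete grad_input tuple
    print([None if g is None else tuple(g.shape) for g in grad_input])

# deprecated variant that triggers the warning above:
# layer.register_backward_hook(hook)

# documented replacement:
layer.register_full_backward_hook(hook)

layer(torch.randn(2, 4)).sum().backward()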
