Pass pbc and cell as args to allow use of forward hook with TorchScript #648
Comments
Hi Lester, could you pass empty tensors for cell and pbc in your hook?
No, it doesn't seem to be possible, unless I'm misunderstanding things. The hook signature needs to be:

def hook_wrapper():
    def hook(
        module,
        input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
        output: torchani.aev.SpeciesAEV,
    ):
        # Do something with the AEVComputer.forward output here.
        pass

    return hook

This is then registered with something like:

self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook_wrapper())

(Here … The …)

For reference, your signature is:

def forward(self, input_: Tuple[Tensor, Tensor],
            cell: Optional[Tensor] = None,
            pbc: Optional[Tensor] = None) -> SpeciesAEV:

It seems that TorchScript can't resolve that … Here's a minimal example that reproduces the issue:

import torch
import torchani
from torch import Tensor
from typing import Optional, Tuple
class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self._ani2x = torchani.models.ANI2x(periodic_table_index=True)
        # Assign a tensor attribute that can be used to store the AEVs.
        self._ani2x.aev_computer._aev = torch.empty(0)

        # Hook the forward pass of the ANI2x model to get the AEV features.
        def hook_wrapper():
            def hook(
                module,
                input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
                output: torchani.aev.SpeciesAEV,
            ):
                # Store the AEVs of the first molecule in the batch.
                module._aev = output[1][0]

            return hook

        # Register the hook.
        self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook_wrapper())

    def forward(self, species: Tensor, coordinates: Tensor) -> Tensor:
        # Forward pass of the ANI2x model; the energies are the second field
        # of the returned SpeciesEnergies tuple.
        energy = self._ani2x((species, coordinates)).energies
        # Do something with the AEV features.
        aevs = self._ani2x.aev_computer._aev
        return energy


# Create an instance of the model.
model = Test()
# Convert to TorchScript.
script_model = torch.jit.script(model)
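For context: in eager PyTorch, a standard forward hook (one registered without with_kwargs) receives only the positional arguments of the call in its input tuple. Keyword arguments are not included. The sketch below checks this with a hypothetical toy module, ToyAEV, whose forward mimics the (input_, cell, pbc) signature quoted above. It is not torchani code, and it only exercises eager mode, not TorchScript.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


class ToyAEV(torch.nn.Module):
    """Hypothetical stand-in with the same call signature as AEVComputer.forward."""

    def forward(
        self,
        input_: Tuple[Tensor, Tensor],
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> Tensor:
        species, coordinates = input_
        # Toy "features": just the coordinates scaled by 2.
        return coordinates * 2.0


# Record how many elements each hook call sees in its input tuple.
seen_lengths = []


def hook(module, input, output):
    seen_lengths.append(len(input))


toy = ToyAEV()
handle = toy.register_forward_hook(hook)

species = torch.tensor([[6, 1, 1, 1, 1]])
coordinates = torch.zeros(1, 5, 3)

# Keyword arguments are not part of the hook's input tuple...
toy((species, coordinates), cell=None, pbc=None)
# ...whereas positional arguments are.
toy((species, coordinates), None, None)

handle.remove()
print(seen_lengths)  # [1, 3]
```

When cell and pbc are passed by keyword, the hook sees a one-element tuple. A three-element annotation such as Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]] therefore describes only the positional call.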
There's the possibility of using with_kwargs=True when registering the hook:

import torch
import torchani


def hook(module, args, kwargs, output):
    print(module)
    print(len(args))
    print(args)
    print(len(kwargs))
    print(kwargs)


class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self._ani2x = torchani.models.ANI2x(periodic_table_index=True)
        self._aev_hook = self._ani2x.aev_computer.register_forward_hook(
            hook, with_kwargs=True
        )

    def forward(self, coordinates, species):
        return self._ani2x((species, coordinates), cell=None, pbc=None).energies


coordinates = torch.tensor(
    [
        [
            [0.03192167, 0.00638559, 0.01301679],
            [-0.83140486, 0.39370209, -0.26395324],
            [-0.66518241, -0.84461308, 0.20759389],
            [0.45554739, 0.54289633, 0.81170881],
            [0.66091919, -0.16799635, -0.91037834],
        ]
    ],
    requires_grad=True,
    dtype=torch.float32,
)
# In the periodic table, C = 6 and H = 1.
species = torch.tensor([[6, 1, 1, 1, 1]])
# Create model.
model = Test()
# Convert to TorchScript.
# model = torch.jit.script(model)
# Compute energies.
energies = model(coordinates, species)

This gives:

AEVComputer()
1
(SpeciesCoordinates(species=tensor([[1, 0, 0, 0, 0]]), coordinates=tensor([[[ 0.0319, 0.0064, 0.0130],
[-0.8314, 0.3937, -0.2640],
[-0.6652, -0.8446, 0.2076],
[ 0.4555, 0.5429, 0.8117],
[ 0.6609, -0.1680, -0.9104]]], requires_grad=True)),)
2
{'cell': None, 'pbc': None}

However, this doesn't work with TorchScript, i.e. when the model = torch.jit.script(model) line above is uncommented.
So it looks like TorchScript only supports forward hooks that take positional args, and therefore expects the hooked module to be called with positional args too. That's why I needed to convert the kwargs to args in your models.py file.
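The kwargs-to-args change described above can be sketched with hypothetical toy classes (ToyAEV and ToyModel). This is not the actual torchani models.py patch, which isn't reproduced in this thread. The outer module calls its AEV submodule with cell and pbc as positional arguments, so a plain forward hook receives all three inputs and matches the three-element Tuple annotation. The sketch is checked in eager mode only.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


class ToyAEV(torch.nn.Module):
    """Hypothetical stand-in for AEVComputer with the same forward signature."""

    def forward(
        self,
        input_: Tuple[Tensor, Tensor],
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        species, coordinates = input_
        # Toy "AEVs": the squared coordinates.
        return species, coordinates ** 2


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for the ANI model that owns the AEV computer."""

    def __init__(self):
        super().__init__()
        self.aev_computer = ToyAEV()

    def forward(
        self,
        species_coordinates: Tuple[Tensor, Tensor],
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> Tensor:
        # Pass cell and pbc positionally instead of as cell=cell, pbc=pbc,
        # so that forward hooks see them in their input tuple.
        _, aevs = self.aev_computer(species_coordinates, cell, pbc)
        return aevs.sum()


captured = {}


def hook(
    module,
    input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
    output: Tuple[Tensor, Tensor],
):
    # Keep a reference to the AEVs so they need not be recomputed.
    captured["n_inputs"] = len(input)
    captured["aevs"] = output[1]


model = ToyModel()
model.aev_computer.register_forward_hook(hook)

species = torch.tensor([[6, 1, 1, 1, 1]])
coordinates = torch.ones(1, 5, 3)
energy = model((species, coordinates))
print(captured["n_inputs"])  # 3
```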
Hi, thanks for the update, sounds good. Could you use your local fork to make the changes? We plan to have a major update later this year, and the current code base is frozen.
No problem. I'll provide a patch for our users since it's just a modification to a single file. Do you want me to raise a PR with a fix and test anyway? I appreciate that it won't be merged, but at least you'll have a record if you want to apply it to whatever update comes later in the year. Cheers.
Yes, please open a PR, thank you!
Hello there,
I am writing a module based on ANI2x that requires AEVs, and I have been trying to use a forward hook on the ANI2x AEVComputer to avoid duplicating the calculation. While this works perfectly fine in PyTorch, I appear to be unable to serialise my model using TorchScript, since you pass some of the arguments that make up the hook's input as kwargs. In particular, cell and pbc are always passed as cell=cell and pbc=pbc. The exception that I get is:

…

Applying the following patch to models.py gets things to work:

…

(I've confirmed that this gives identical results to before.)
Is there a reason why kwargs need to be used? If not, would the proposed patch be acceptable? Using a forward hook gives an appreciable speed gain for my model by removing the need to compute AEVs twice.
Just to note that I am pretty new to TorchScript, so perhaps there is a way to get it to correctly detect the use of kwargs in hooks. Alternatively, perhaps there is another way to get the AEVs from the module. (Currently you use them as an intermediate in the calculation, but they could be stored as a module attribute.) I certainly could derive from ANI2x and override forward, but the way I have things now is quite a bit more flexible for my use case.
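The alternative mentioned above (subclassing and overriding forward so the AEVs are stored as a module attribute) could look roughly like the following. ToyAEV and CachingAEV are hypothetical classes, not torchani API. In practice the subclass would wrap whichever module actually produces the AEVs. The sketch runs in eager mode; TorchScript compatibility of this pattern is not claimed here.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


class ToyAEV(torch.nn.Module):
    """Hypothetical stand-in for an AEV computer."""

    def forward(
        self,
        input_: Tuple[Tensor, Tensor],
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        species, coordinates = input_
        # Toy "AEVs": coordinates shifted by one.
        return species, coordinates + 1.0


class CachingAEV(ToyAEV):
    """Subclass that overrides forward to keep the last AEVs as an attribute."""

    def __init__(self):
        super().__init__()
        # Empty placeholder so the attribute always exists and is a Tensor.
        self.last_aevs = torch.empty(0)

    def forward(
        self,
        input_: Tuple[Tensor, Tensor],
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        species, aevs = super().forward(input_, cell, pbc)
        # Store the AEVs so callers can reuse them without recomputing.
        self.last_aevs = aevs
        return species, aevs


aev_computer = CachingAEV()
species = torch.tensor([[6, 1, 1, 1, 1]])
coordinates = torch.zeros(1, 5, 3)
aev_computer((species, coordinates), cell=None, pbc=None)
print(aev_computer.last_aevs.shape)  # torch.Size([1, 5, 3])
```

No hook inspects the call's arguments here, so in eager mode this behaves the same whether cell and pbc are passed by keyword or positionally.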
Many thanks.