diff --git a/tests/test_aev_hook.py b/tests/test_aev_hook.py new file mode 100644 index 000000000..7703532e1 --- /dev/null +++ b/tests/test_aev_hook.py @@ -0,0 +1,49 @@ +import unittest + +import torch +import torchani + +from torch import Tensor +from typing import Optional, Tuple + +from torchani.testing import TestCase + + +class TestAEVHook(TestCase): + def test_aev_hook(self): + # Create a test module. + + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + # Create a ANI2x model. + self._model = torchani.models.ANI2x() + + # Define a dummy hook. This doesn't do anything with the output + # but triggers the exception when kwargs are used when calling + # AEVComputer.forward. + def hook( + module, + input: Tuple[ + Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor] + ], + output: torchani.aev.SpeciesAEV, + ): + pass + + # Register the hook. + self._aev_hook = self._model.aev_computer.register_forward_hook(hook) + + def forward(self, species, coordinates): + return self._model((species, coordinates)) + + # Create a test module. + model = TestModule() + + # Convert the model to TorchScript. + model = torch.jit.script(model) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchani/models.py b/torchani/models.py index 117cb4a9e..522e6d0d0 100644 --- a/torchani/models.py +++ b/torchani/models.py @@ -103,7 +103,7 @@ def forward(self, species_coordinates: Tuple[Tensor, Tensor], if species_coordinates[0].ge(self.aev_computer.num_species).any(): raise ValueError(f'Unknown species found in {species_coordinates[0]}') - species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) + species_aevs = self.aev_computer(species_coordinates, cell, pbc) species_energies = self.neural_networks(species_aevs) return self.energy_shifter(species_energies) @@ -135,7 +135,7 @@ def atomic_energies(self, species_coordinates: Tuple[Tensor, Tensor], """ if self.periodic_table_index: species_coordinates = self.species_converter(species_coordinates) - species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) + species, aevs = self.aev_computer(species_coordinates, cell, pbc) atomic_energies = self.neural_networks._atomic_energies((species, aevs)) self_energies = self.energy_shifter.self_energies.clone().to(species.device) self_energies = self_energies[species] @@ -236,7 +236,7 @@ def atomic_energies(self, species_coordinates: Tuple[Tensor, Tensor], """ if self.periodic_table_index: species_coordinates = self.species_converter(species_coordinates) - species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) + species, aevs = self.aev_computer(species_coordinates, cell, pbc) members_list = [] for nnp in self.neural_networks: members_list.append(nnp._atomic_energies((species, aevs)).unsqueeze(0)) @@ -322,7 +322,7 @@ def members_energies(self, species_coordinates: Tuple[Tensor, Tensor], """ if self.periodic_table_index: species_coordinates = self.species_converter(species_coordinates) - species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) + species, aevs = self.aev_computer(species_coordinates, cell, pbc) member_outputs = [] for nnp in self.neural_networks: unshifted_energies = nnp((species, aevs)).energies