Skip to content

Commit

Permalink
unit tests for pdb loader
Browse files Browse the repository at this point in the history
  • Loading branch information
SCMusson committed Jul 5, 2023
1 parent feaf543 commit aac0a6a
Showing 1 changed file with 119 additions and 2 deletions.
121 changes: 119 additions & 2 deletions test/test_openmm_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,122 @@

import torch

class Test_PDBData_Basics(unittest.TestCase):
def setUp(self):
self.data = molearn.PDBData()
self.data.import_pdb('MurD_test.pdb')
self.data.prepare_dataset()

def test_dataset_is_tensor(self,):
self.assertIsInstance(self.data.dataset, torch.Tensor)

def test_dataset_dimensions(self,):
self.assertTrue(self.data.dataset.shape[0]==16)
self.assertTrue(self.data.dataset.shape[1]==3)
self.assertTrue(self.data.dataset.shape[2]==3286)

def test_mean(self):
self.assertIsInstance(self.data.mean, float)
self.assertTrue(torch.isclose(torch.zeros(1), self.data.dataset.mean()))

def test_std(self):
self.assertIsInstance(self.data.std, float)
self.assertTrue(torch.isclose(torch.ones(1), self.data.dataset.std()))

def test_atominfo(self):
atominfo = self.data.get_atominfo()
self.assertTrue(atominfo.shape[0] == self.data.dataset.shape[2])
self.assertTrue(atominfo.shape[1]==3)
# Are the types correct (str, str, int)
self.assertTrue(all([all([isinstance(a, str), isinstance(b, str), isinstance(c, int)]) for a,b,c in atominfo]))

def test_frame(self):
frame = self.data.frame()
self.assertTrue(frame.coordinates.shape[0] == 1)
self.assertTrue(frame.coordinates.shape[1] == self.data.dataset.shape[2])
self.assertTrue(frame.coordinates.shape[2] == 3)

def test_get_dataloader(self):
bs = 4
train, valid = self.data.get_dataloader(batch_size=bs, validation_split=0.5)
self.assertIsInstance(train, torch.utils.data.DataLoader)
self.assertIsInstance(valid, torch.utils.data.DataLoader)
train_batch = next(iter(train))
valid_batch = next(iter(valid))

self.assertTrue(train_batch[0].shape[0]==bs)
self.assertTrue(valid_batch[0].shape[0]==bs)
self.assertTrue(train_batch[0].shape[1]==3)
self.assertTrue(valid_batch[0].shape[1]==3)
self.assertTrue(train_batch[0].shape[2]==self.data.dataset.shape[2])
self.assertTrue(valid_batch[0].shape[2]==self.data.dataset.shape[2])

def test_split(self):
train, valid = self.data.split(validation_split = 0.5)
self.assertIsInstance(train, molearn.PDBData)
self.assertIsInstance(valid, molearn.PDBData)
self.assertTrue(train.dataset.shape[0]==int(self.data.dataset.shape[0]/2))

def test_get_datasets(self):
train, valid = self.data.get_datasets(validation_split = 0.5)
cdist = torch.cdist(self.data.dataset.reshape(16,-1).unsqueeze(0), train.reshape(8,-1).unsqueeze(0))[0]

mask = torch.isclose(cdist, torch.zeros_like(cdist))
axis_0 = mask.sum(axis=0)
self.assertTrue(torch.allclose(axis_0, torch.ones_like(axis_0)))
axis_1 = mask.sum(axis=1)
self.assertTrue(axis_1.max()==1)
self.assertTrue(axis_1.sum() == 8)

def test_atoms(self):
atoms = self.data.atoms
self.assertTrue(len(atoms)==37)
self.assertTrue(all([isinstance(atom, str) for atom in atoms]))

class Test_PDBData_atomselect_bb(Test_PDBData_Basics):
def setUp(self):
self.data = molearn.PDBData()
self.data.import_pdb('MurD_test.pdb')
self.data.atomselect(atoms = ['CA', 'C', 'CB', 'O', 'N'])
self.data.prepare_dataset()

def test_dataset_dimensions(self,):
self.assertTrue(self.data.dataset.shape[0]==16)
self.assertTrue(self.data.dataset.shape[1]==3)
self.assertTrue(self.data.dataset.shape[2]==2145)

def test_atoms(self):
atoms = self.data.atoms
self.assertTrue(len(atoms)==5)
self.assertTrue(all([isinstance(atom, str) for atom in atoms]))

class Test_PDBData_atomselect_no_hydrogen(Test_PDBData_Basics):
def setUp(self):
self.data = molearn.PDBData()
self.data.import_pdb('MurD_test.pdb')
self.data.atomselect(atoms = 'no_hydrogen')
self.data.prepare_dataset()

def test_dataset_dimensions(self,):
self.assertTrue(self.data.dataset.shape[0]==16)
self.assertTrue(self.data.dataset.shape[1]==3)
self.assertTrue(self.data.dataset.shape[2]==3286)

def test_atoms(self):
atoms = self.data.atoms
self.assertTrue(len(atoms)==37)
self.assertTrue(all([isinstance(atom, str) for atom in atoms]))


class Test_Trainers(unittest.TestCase):
def setUp(self):
self.data = molearn.PDBData()
self.data.import_pdb('MurD_test.pdb')
self.data.atomselect(atoms = ['CA', 'C', 'CB', 'O', 'N'])
self.data.prepare_dataset()
def test_init(self):
pass

class Test_openmm_plugin(unittest.TestCase):
def test_openmm_energy(self,):

Expand All @@ -27,7 +143,8 @@ def forward(self, x):
_d = data.dataset.to(device).float()
#d = next(iter(dataloader))[0].to(device).float()
para = mymodule(_d.clone())
openmmscore = molearn.openmm_energy(data.mol, data.std, clamp=None, platform = 'CUDA') #xml_file = ['modified_amber_protein.xml',])
from molearn.loss_functions import openmm_energy
openmmscore = openmm_energy(data.mol, data.std, clamp=None, platform = 'CUDA') #xml_file = ['modified_amber_protein.xml',])
opt = torch.optim.SGD(para.parameters(), lr=0.0001)
scores = []
for i in range(1000):
Expand All @@ -42,7 +159,7 @@ def forward(self, x):
if i > 20:
ratio = (torch.tensor(scores)>loss.item()).sum()/i
self.assertGreater(ratio, 0.9)
if i%10==0:
if i%100==0:
print(ratio)
print(loss)
scores.append(loss.item())
Expand Down

0 comments on commit aac0a6a

Please sign in to comment.