Skip to content

Commit

Permalink
bugfix for scan_error_from_target
Browse files Browse the repository at this point in the history
  • Loading branch information
degiacom committed Jul 5, 2023
1 parent aac0a6a commit c0f345c
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/molearn/analysis/analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,17 +257,23 @@ def scan_error_from_target(self, key, index=None, align=False):
Calculate landscape of RMSD vs single target structure. Target should be previously loaded datset containing a single conformation.
:param str key: key pointing to a dataset previously loaded with :func:`set_dataset <molearn.analysis.MolearnAnalysis.set_dataset>`
:param int index: index of conformation to be selected from dataset containing multiple conformations.
:param bool align: if True, structures generated from the grid are aligned to target prior RMSD calculation.
:return: RMSD latent space NxN surface
:return: x-axis values
:return: y-axis values
'''
s_key = f'RMSD_from_{key}' if index is None else f'RMSD_from_{key}_index_{index}'
if s_key not in self.surfaces:
assert 'grid' in self._encoded, 'make sure to call MolearnAnalysis.setup_grid first'
target = self.get_dataset(key) if index is None else self.get_dataset(key)[index]
assert target.shape[0] == 1, f'The key {key} points to more than one structure, '
+'either pass a key that points to a single structure or pass the index of the '
+'structure you want i.e. analyser.scan_error_from_target(key, index=0)'
target = self.get_dataset(key) if index is None else self.get_dataset(key)[index].unsqueeze(0)
if target.shape[0] != 1:
msg = f'dataset {key} shape is {target.shape}. \
A dataset with a single conformation is expected.\
Either pass a key that points to a single structure or pass the index of the \
structure you want, e.g., analyser.scan_error_from_target(key, index=0)'
raise Exception(msg)

decoded = self.get_decoded('grid')
if align:
crd_ref = as_numpy(target.permute(0,2,1))*self.stdval
Expand Down

0 comments on commit c0f345c

Please sign in to comment.