Skip to content

Commit

Permalink
avoid yaff imports when using openmm backend
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Jan 30, 2024
1 parent 2018ce6 commit 19c729f
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 204 deletions.
2 changes: 1 addition & 1 deletion psiflow/walkers/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def evaluate_bias(

from psiflow.data import read_dataset
from psiflow.walkers.bias import try_manual_plumed_linking
from psiflow.walkers.utils import ForcePartPlumed
from psiflow.walkers.molecular_dynamics_yaff import ForcePartPlumed

dataset = read_dataset(slice(None), inputs=[inputs[0]])
values = np.zeros((len(dataset), len(variables) + 1)) # column 0 for CV, 1 for bias
Expand Down
201 changes: 194 additions & 7 deletions psiflow/walkers/molecular_dynamics_yaff.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,204 @@
import numpy as np
import torch
import yaff
from ase import Atoms
from ase.geometry import Cell
from ase.geometry.geometry import find_mic
from ase.io import read, write
from ase.units import Bohr, Ha

from psiflow.walkers.bias import try_manual_plumed_linking
from psiflow.walkers.utils import (
DataHook,
ExtXYZHook,
ForcePartPlumed,
ForceThresholdExceededException,
create_forcefield,
)


class ForcePartPlumed(yaff.external.ForcePartPlumed):
"""Remove timer from _internal_compute to avoid pathological errors"""

def _internal_compute(self, gpos, vtens):
self.plumed.cmd("setStep", self.plumedstep)
self.plumed.cmd("setPositions", self.system.pos)
self.plumed.cmd("setMasses", self.system.masses)
if self.system.charges is not None:
self.plumed.cmd("setCharges", self.system.charges)
if self.system.cell.nvec > 0:
rvecs = self.system.cell.rvecs.copy()
self.plumed.cmd("setBox", rvecs)
# PLUMED always needs arrays to write forces and virial to, so
# provide dummy arrays if Yaff does not provide them
# Note that gpos and forces differ by a minus sign, which has to be
# corrected for when interacting with PLUMED
if gpos is None:
my_gpos = np.zeros(self.system.pos.shape)
else:
gpos[:] *= -1.0
my_gpos = gpos
self.plumed.cmd("setForces", my_gpos)
if vtens is None:
my_vtens = np.zeros((3, 3))
else:
my_vtens = vtens
self.plumed.cmd("setVirial", my_vtens)
# Do the actual calculation, without an update; this should
# only be done at the end of a time step
self.plumed.cmd("prepareCalc")
self.plumed.cmd("performCalcNoUpdate")
if gpos is not None:
gpos[:] *= -1.0
# Retrieve biasing energy
energy = np.zeros((1,))
self.plumed.cmd("getBias", energy)
return energy[0]


class ForceThresholdExceededException(Exception):
pass


class ForcePartASE(yaff.pes.ForcePart):
"""YAFF Wrapper around an ASE calculator"""

def __init__(self, system, atoms):
"""Constructor
Parameters
----------
system : yaff.System
system object
atoms : ase.Atoms
atoms object with calculator included.
force_threshold : float [eV/A]
"""
yaff.pes.ForcePart.__init__(self, "ase", system)
self.system = system # store system to obtain current pos and box
self.atoms = atoms

def _internal_compute(self, gpos=None, vtens=None):
self.atoms.set_positions(self.system.pos * Bohr)
if self.atoms.pbc.all():
self.atoms.set_cell(Cell(self.system.cell._get_rvecs() * Bohr))
energy = self.atoms.get_potential_energy() / Ha
if gpos is not None:
forces = self.atoms.get_forces()
gpos[:] = -forces / (Ha / Bohr)
if vtens is not None:
if self.atoms.pbc.all():
stress = self.atoms.get_stress(voigt=False)
volume = np.linalg.det(self.atoms.get_cell())
vtens[:] = volume * stress / Ha
else:
vtens[:] = 0.0
return energy


class ForceField(yaff.pes.ForceField):
"""Implements force threshold check"""

def __init__(self, *args, force_threshold=20, **kwargs):
super().__init__(*args, **kwargs)
self.force_threshold = force_threshold

def _internal_compute(self, gpos, vtens):
if self.needs_nlist_update: # never necessary?
self.nlist.update()
self.needs_nlist_update = False
result = sum([part.compute(gpos, vtens) for part in self.parts])
forces = (-1.0) * gpos * (Ha / Bohr)
self.check_threshold(forces)
return result

def check_threshold(self, forces):
max_force = np.max(np.linalg.norm(forces, axis=1))
index = np.argmax(np.linalg.norm(forces, axis=1))
if max_force > self.force_threshold:
raise ForceThresholdExceededException(
"Max force exceeded: {} eV/A by atom index {}".format(max_force, index),
)


def create_forcefield(atoms, force_threshold):
"""Creates force field from ASE atoms instance"""
if atoms.pbc.all():
rvecs = atoms.get_cell() / Bohr
else:
rvecs = None
system = yaff.System(
numbers=atoms.get_atomic_numbers(),
pos=atoms.get_positions() / Bohr,
rvecs=rvecs,
)
system.set_standard_masses()
part_ase = ForcePartASE(system, atoms)
return ForceField(system, [part_ase], force_threshold=force_threshold)


class DataHook(yaff.VerletHook):
def __init__(self, start=0, step=1):
super().__init__(start, step)
self.atoms = None
self.data = []

def init(self, iterative):
if iterative.ff.system.cell.nvec > 0:
cell = iterative.ff.system.cell._get_rvecs() * Bohr
else:
cell = None
self.atoms = Atoms(
numbers=iterative.ff.system.numbers.copy(),
positions=iterative.ff.system.pos * Bohr,
cell=cell,
pbc=cell is not None,
)

def pre(self, iterative):
pass

def post(self, iterative):
pass

def __call__(self, iterative):
self.atoms.set_positions(iterative.ff.system.pos * Bohr)
if self.atoms.pbc.all():
self.atoms.set_cell(iterative.ff.system.cell._get_rvecs() * Bohr)
self.data.append(self.atoms.copy())


class ExtXYZHook(yaff.VerletHook): # xyz file writer; obsolete
def __init__(self, path_xyz, start=0, step=1):
super().__init__(start, step)
self.path_xyz = path_xyz
self.atoms = None
self.nwrites = 0
self.temperatures = []

def init(self, iterative):
if iterative.ff.system.cell.nvec > 0:
cell = iterative.ff.system.cell._get_rvecs() * Bohr
else:
cell = None
self.atoms = Atoms(
numbers=iterative.ff.system.numbers.copy(),
positions=iterative.ff.system.pos * Bohr,
cell=cell,
pbc=cell is not None,
)

def pre(self, iterative):
pass

def post(self, iterative):
pass

def __call__(self, iterative):
if iterative.counter > 0: # first write is manual
self.atoms.set_positions(iterative.ff.system.pos * Bohr)
if self.atoms.pbc.all():
self.atoms.set_cell(iterative.ff.system.cell._get_rvecs() * Bohr)
write(self.path_xyz, self.atoms, append=True)
self.nwrites += 1
self.temperatures.append(iterative.temp)


def main():
Expand Down
Loading

0 comments on commit 19c729f

Please sign in to comment.