Skip to content

Commit 6394965

Browse files
author
Johannes Steinmetzer
committed
add: PySCF.prepare_mol/prepare_mf to allow further customization
implements #247 DFTD3-Gradients don't seem to work though, or I'm not clever enough...
1 parent 42f4c6e commit 6394965

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

pysisyphus/calculators/PySCF.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111

1212
class PySCF(OverlapCalculator):
13-
1413
conf_key = "pyscf"
1514
drivers = {
1615
# key: (method, unrestricted?)
@@ -97,6 +96,10 @@ def build_grid(self, mf):
9796
mf.grids.prune = self.pruning_method[self.pruning]
9897
mf.grids.build()
9998

99+
def prepare_mf(self, mf):
100+
# Method can be overriden in a subclass to modify the mf object.
101+
return mf
102+
100103
def get_driver(self, step, mol=None, mf=None):
101104
def _get_driver():
102105
return self.drivers[(step, self.unrestricted)]
@@ -107,10 +110,12 @@ def _get_driver():
107110
mf.xc = self.xc
108111
self.set_scf_params(mf)
109112
self.build_grid(mf)
113+
mf = self.prepare_mf(mf)
110114
elif mol and (step == "scf"):
111115
driver = _get_driver()
112116
mf = driver(mol)
113117
self.set_scf_params(mf)
118+
mf = self.prepare_mf(mf)
114119
elif mf and (step == "mp2"):
115120
mp2_mf = _get_driver()
116121
mf = mp2_mf(mf)
@@ -124,7 +129,7 @@ def _get_driver():
124129
raise Exception("Unknown method '{step}'!")
125130
return mf
126131

127-
def prepare_input(self, atoms, coords):
132+
def prepare_mol(self, atoms, coords, build=True):
128133
mol = gto.Mole()
129134
mol.atom = [(atom, c) for atom, c in zip(atoms, coords.reshape(-1, 3))]
130135
mol.basis = self.basis
@@ -143,8 +148,13 @@ def prepare_input(self, atoms, coords):
143148
# Search for "Large deviations found" in scf/{uhf,dhf,ghf}.py
144149
mol.output = self.make_fn(self.out_fn)
145150
mol.max_memory = self.mem * self.pal
146-
mol.build(parse_arg=False)
151+
if build:
152+
mol.build(parse_arg=False)
153+
return mol
147154

155+
def prepare_input(self, atoms, coords, build=True):
156+
mol = self.prepare_mol(atoms, coords, build=build)
157+
assert mol._built, "Please call mol.build(parse_arg=False)!"
148158
return mol
149159

150160
def store_and_track(self, results, func, atoms, coords, **prepare_kwargs):

0 commit comments

Comments
 (0)