10
10
11
11
12
12
class PySCF (OverlapCalculator ):
13
-
14
13
conf_key = "pyscf"
15
14
drivers = {
16
15
# key: (method, unrestricted?)
@@ -97,6 +96,10 @@ def build_grid(self, mf):
97
96
mf .grids .prune = self .pruning_method [self .pruning ]
98
97
mf .grids .build ()
99
98
99
+ def prepare_mf (self , mf ):
100
+ # Method can be overriden in a subclass to modify the mf object.
101
+ return mf
102
+
100
103
def get_driver (self , step , mol = None , mf = None ):
101
104
def _get_driver ():
102
105
return self .drivers [(step , self .unrestricted )]
@@ -107,10 +110,12 @@ def _get_driver():
107
110
mf .xc = self .xc
108
111
self .set_scf_params (mf )
109
112
self .build_grid (mf )
113
+ mf = self .prepare_mf (mf )
110
114
elif mol and (step == "scf" ):
111
115
driver = _get_driver ()
112
116
mf = driver (mol )
113
117
self .set_scf_params (mf )
118
+ mf = self .prepare_mf (mf )
114
119
elif mf and (step == "mp2" ):
115
120
mp2_mf = _get_driver ()
116
121
mf = mp2_mf (mf )
@@ -124,7 +129,7 @@ def _get_driver():
124
129
raise Exception ("Unknown method '{step}'!" )
125
130
return mf
126
131
127
- def prepare_input (self , atoms , coords ):
132
+ def prepare_mol (self , atoms , coords , build = True ):
128
133
mol = gto .Mole ()
129
134
mol .atom = [(atom , c ) for atom , c in zip (atoms , coords .reshape (- 1 , 3 ))]
130
135
mol .basis = self .basis
@@ -143,8 +148,13 @@ def prepare_input(self, atoms, coords):
143
148
# Search for "Large deviations found" in scf/{uhf,dhf,ghf}.py
144
149
mol .output = self .make_fn (self .out_fn )
145
150
mol .max_memory = self .mem * self .pal
146
- mol .build (parse_arg = False )
151
+ if build :
152
+ mol .build (parse_arg = False )
153
+ return mol
147
154
155
+ def prepare_input (self , atoms , coords , build = True ):
156
+ mol = self .prepare_mol (atoms , coords , build = build )
157
+ assert mol ._built , "Please call mol.build(parse_arg=False)!"
148
158
return mol
149
159
150
160
def store_and_track (self , results , func , atoms , coords , ** prepare_kwargs ):
0 commit comments