Skip to content

Commit

Permalink
mod: root argument for OverlapCalculator
Browse files Browse the repository at this point in the history
- all classes deriving from OverlapCalculator now support the 'root'
keyword
- ORCA sets appropriate iroot, even when keyword is not present in the
input
  • Loading branch information
Johannes Steinmetzer authored and eljost committed Sep 26, 2023
1 parent d755903 commit d08c906
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 52 deletions.
25 changes: 14 additions & 11 deletions pysisyphus/calculators/DFTBp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools as it

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'it' is not used.
from math import ceil
from pathlib import Path
import re
Expand Down Expand Up @@ -35,13 +36,15 @@ def parse_xplusy(text):
for i in range(states):
block = lines[i * block_size : (i + 1) * block_size]
_, *rest = block
xpy = np.array([line.split() for line in rest], dtype=float)
xpy = list()
for line in block[1:]:
xpy.extend(line.split())
xpy = np.array(xpy, dtype=float)
xpys.append(xpy)
return size, states, np.array(xpys)


class DFTBp(OverlapCalculator):

conf_key = "dftbp"
_set_plans = (
"out",
Expand Down Expand Up @@ -95,7 +98,7 @@ class DFTBp(OverlapCalculator):
},
}

def __init__(self, parameter, *args, slakos=None, root=None, **kwargs):
def __init__(self, parameter, *args, slakos=None, **kwargs):
super().__init__(*args, **kwargs)

assert self.mult == 1, "Open-shell not yet supported!"
Expand All @@ -107,8 +110,6 @@ def __init__(self, parameter, *args, slakos=None, root=None, **kwargs):
f"Expected '{self.parameter}' sub-directory in '{self.slakos_prefix}' "
"but could not find it!"
)
self.root = root

self.base_cmd = self.get_cmd()
self.gen_geom_fn = "geometry.gen"
self.inp_fn = "dftb_in.hsd"
Expand Down Expand Up @@ -203,8 +204,8 @@ def get_gen_str(atoms, coords):
return gen_str

@staticmethod
def get_excited_state_str(root, forces=False):
if root is None:
def get_excited_state_str(track, root, nroots, forces=False):
if root is None and (track == False):
return ""

casida_tpl = jinja2.Template(
Expand All @@ -213,7 +214,7 @@ def get_excited_state_str(root, forces=False):
Casida {
NrOfExcitations = {{ nstates }}
Symmetry = Singlet
StateOfInterest = {{ root }}
{% if root %}StateOfInterest = {{ root }}{% endif %}
WriteXplusY = Yes
{{ es_forces }}
}
Expand All @@ -222,7 +223,7 @@ def get_excited_state_str(root, forces=False):
)
es_forces = "ExcitedStateForces = Yes" if forces else ""
es_str = casida_tpl.render(
nstates=root + 5,
nstates=nroots if nroots else root + 5,
root=root,
es_forces=es_forces,
)
Expand All @@ -236,7 +237,7 @@ def prepare_input(self, atoms, coords, calc_type):
analysis = list()
if calc_type == "forces":
analysis.append("CalculateForces = Yes")
if self.root:
if self.track or self.root:
analysis.extend(("WriteEigenvectors = Yes", "EigenvectorsAsText = Yes"))
ang_moms = self.max_ang_moms[self.parameter]
unique_atoms = set(atoms)
Expand Down Expand Up @@ -265,7 +266,9 @@ def prepare_input(self, atoms, coords, calc_type):
parameter=self.parameter,
max_ang_moms=max_ang_moms,
hubbard_derivs=hubbard_derivs,
excited_state_str=self.get_excited_state_str(self.root, es_forces),
excited_state_str=self.get_excited_state_str(
self.track, self.root, self.nroots, es_forces
),
analysis=analysis,
)
return inp, path
Expand Down
25 changes: 11 additions & 14 deletions pysisyphus/calculators/Gaussian16.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import subprocess
import textwrap
import warnings

import numpy as np
import pyparsing as pp
Expand Down Expand Up @@ -53,18 +54,16 @@ def __init__(
for kw, option in [self.parse_keyword(kw) for kw in self.route.split()]
}
exc_keyword = [key for key in "td tda cis".split() if key in keywords]
self.root = None
self.nstates = None
self.nstates = self.nroots
if exc_keyword:
self.exc_key = exc_keyword[0]
exc_dict = keywords[self.exc_key]
self.nstates = int(exc_dict["nstates"])
try:
self.root = int(exc_dict["root"])

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute root, which was previously defined in superclass
OverlapCalculator
.
except KeyError:
self.root = 1
self.log("No explicit root was specified! Using root=1 as default!")
# Collect remaining options if specified
warnings.warn("No explicit root was specified!")
# Collect remaining additional options if specified
self.exc_args = {
k: v for k, v in exc_dict.items() if k not in ("nstates", "root")
}
Expand Down Expand Up @@ -120,13 +119,14 @@ def __init__(

def make_exc_str(self):
# Ground state calculation
if not self.root:
if not self.track and (self.root is None):
return ""
root = f"root={self.root}"
root = self.root if self.root is not None else 1
root_str = f"root={root}"
nstates = f"nstates={self.nstates}"
pair2str = lambda k, v: f"{k}" + (f"={v}" if v else "")
arg_str = ",".join([pair2str(k, v) for k, v in self.exc_args.items()])
exc_str = f"{self.exc_key}=({root},{nstates},{arg_str})"
exc_str = f"{self.exc_key}=({root_str},{nstates},{arg_str})"
return exc_str

def reuse_data(self, path):
Expand Down Expand Up @@ -341,11 +341,7 @@ def store_and_track(self, results, func, atoms, coords, **prepare_kwargs):
return results

def get_energy(self, atoms, coords, **prepare_kwargs):
results = self.get_forces(atoms, coords, **prepare_kwargs)
results = self.store_and_track(
results, self.get_energy, atoms, coords, **prepare_kwargs
)
return results
return self.get_forces(atoms, coords, **prepare_kwargs)

def get_forces(self, atoms, coords, **prepare_kwargs):
did_stable = False
Expand Down Expand Up @@ -592,7 +588,8 @@ def parse_force(self, path):
exc_energies = self.parse_tddft(path)
# G16 root input is 1 based, so we substract 1 to get
# the right index here.
root_exc_en = exc_energies[self.root - 1]
root = self.root if self.root is not None else 1
root_exc_en = exc_energies[root - 1]
gs_energy = fchk_dict["SCF Energy"]
# Add excitation energy to ground state energy.
results["energy"] += root_exc_en
Expand Down
26 changes: 22 additions & 4 deletions pysisyphus/calculators/ORCA.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,10 +599,20 @@ def __init__(
)

self.do_tddft = False
if "tddft" in self.blocks:
self.es_block_header = [key for key in ("tddft", "cis") if key in self.blocks]
if self.es_block_header:
assert len(self.es_block_header) == 1
self.es_block_header = self.es_block_header[0]
if self.es_block_header:
self.do_tddft = True
try:
self.root = int(re.search(r"iroot\s*(\d+)", self.blocks).group(1))

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute root, which was previously defined in superclass
OverlapCalculator
.
warnings.warn(
f"Using root {self.root}, as specified w/ 'iroot' keyword. Please "
"use the designated 'root' keyword in the future, as the 'iroot' route "
"will be deprecated.",
DeprecationWarning,
)
except AttributeError:
self.log("Doing TDA/TDDFT calculation without gradient.")
self.triplets = bool(re.search(r"triplets\s+true", self.blocks))
Expand Down Expand Up @@ -694,9 +704,17 @@ def prepare_input(

def get_block_str(self):
block_str = self.blocks
# Use the correct root if we track it
if self.track:
block_str = re.sub(r"iroot\s+(\d+)", f"iroot {self.root}", self.blocks)
# Use the correct root if we track and a root is supplied
if self.track and (self.root is not None):
if "iroot" in self.blocks:
block_str = re.sub(r"iroot\s+(\d+)", f"iroot {self.root}", self.blocks)
# Insert appropriate iroot keyword if not already present
else:
block_str = re.sub(
f"{self.es_block_header}",
f"{self.es_block_header} iroot {self.root}",
self.blocks,
)
self.log(f"Using iroot '{self.root}' for excited state gradient.")
return block_str

Expand Down
9 changes: 8 additions & 1 deletion pysisyphus/calculators/OverlapCalculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class OverlapCalculator(Calculator):
def __init__(
self,
*args,
root=None,
nroots=None,
track=False,
ovlp_type="tden",
double_mol=False,
Expand All @@ -126,7 +128,13 @@ def __init__(
):
super().__init__(*args, **kwargs)

self.root = root
self.nroots = nroots
self.track = track

# TODO: enable this, when all calculators implement self.root & self.nroots
# if self.track:
# assert self.root <= self.nroots, "'root' must be smaller " "than 'nroots'!"
self.ovlp_type = ovlp_type
assert (
self.ovlp_type in self.OVLP_TYPE_VERBOSE.keys()
Expand Down Expand Up @@ -219,7 +227,6 @@ def __init__(
self.ref_cycle = 0
self.ref_cycles = list()
self.atoms = None
self.root = None

if track:
self.log(
Expand Down
23 changes: 11 additions & 12 deletions pysisyphus/calculators/PySCF.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ def __init__(
basis,
xc=None,
method="scf",
root=None,
nstates=None,
auxbasis=None,
keep_chk=True,
verbose=0,
Expand All @@ -58,14 +56,8 @@ def __init__(
self.multisteps[self.method] = ("scf", self.method)
if self.xc and self.method != "tddft":
self.method = "dft"
self.root = root
self.nstates = nstates
if self.method == "tddft":
assert self.nstates, "nstates must be set with method='tddft'!"
if self.track:
assert self.root <= self.nstates, (
"'root' must be smaller " "than 'nstates'!"
)
assert self.nroots, "nroots must be set with method='tddft'!"
self.auxbasis = auxbasis
self.keep_chk = keep_chk
self.verbose = int(verbose)
Expand Down Expand Up @@ -121,10 +113,10 @@ def _get_driver():
mf = mp2_mf(mf)
elif mf and (step == "tddft"):
mf = pyscf.tddft.TDDFT(mf)
mf.nstates = self.nstates
mf.nstates = self.nroots
elif mf and (step == "tda"):
mf = pyscf.tddft.TDA(mf)
mf.nstates = self.nstates
mf.nstates = self.nroots
else:
raise Exception("Unknown method '{step}'!")
return mf
Expand Down Expand Up @@ -171,8 +163,15 @@ def get_energy(self, atoms, coords, **prepare_kwargs):

mol = self.prepare_input(atoms, coords)
mf = self.run(mol, point_charges=point_charges)
energy = mf.e_tot
root = 0 if self.root is None else self.root
try:
energy = energy[root]
except TypeError:

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass

results = {
"energy": mf.e_tot,
"energy": energy,
}
results = self.store_and_track(
results, self.get_energy, atoms, coords, **prepare_kwargs
Expand Down
10 changes: 2 additions & 8 deletions pysisyphus/calculators/Turbomole.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def __init__(
self,
control_path=None,
simple_input=None,
root=None,
double_mol_path=None,
cosmo_kwargs=None,
**kwargs,
Expand All @@ -186,9 +185,6 @@ def __init__(

# Handle simple input
if simple_input:
control_path = Path(
"/home/johannes/Code/pysisyphus/tests/test_turbomole/sic"
)
control_path = (self.out_dir / get_random_path("control_path")).absolute()
self.log(
"Set 'control_path' to '{control_path}'. Creating 'control' from simple input in it."
Expand All @@ -209,7 +205,6 @@ def __init__(
# Set provided control_path or use the one generated for simple_input
self.control_path = Path(control_path).absolute()

self.root = root
self.double_mol_path = double_mol_path
if self.double_mol_path:
self.double_mol_path = Path(self.double_mol_path)
Expand Down Expand Up @@ -353,11 +348,9 @@ def get_cmd(cmd):
self.log("\tHessian cmd: " + self.hessian_cmd)

if self.td or self.ricc2 and (self.root is None):
self.root = 1
warnings.warn(
"No root set! Either include '$exopt' for TDA/TDDFT or "
"'geoopt' for ricc2 in the control or supply a value for 'root'! "
f"Continuing with root={self.root}."
)

def set_occ_and_mo_nums(self, text):
Expand Down Expand Up @@ -649,7 +642,8 @@ def parse_energy(self, path):

if self.td:
# Drop ground state energy that is repeated
tot_en = tot_ens[1:][self.root]
root = self.root if self.root is not None else 1
tot_en = tot_ens[1:][root]
elif self.ricc2 and self.ricc2_opt:
results = parse_turbo_gradient(path)
tot_en = results["energy"]
Expand Down
Loading

0 comments on commit d08c906

Please sign in to comment.