Skip to content

Commit 3f58ca2

Browse files
committed
Modifying argparse and checking nstates validity
1 parent 497446b commit 3f58ca2

1 file changed

Lines changed: 45 additions & 43 deletions

File tree

analysis/tsh_traj_anal.py

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,15 @@ def parse_cmd():
3232
description="Analyze Surface Hopping trajectories",
3333
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
3434
)
35-
parser.add_argument("-p", "--print", action="store_true", help="Save png figure")
36-
parser.add_argument("-c", "--convert", action="store_false", help="Convert a.u. to eV")
35+
parser.add_argument("-s", "--save-fig", action="store_true", help="Save png figure")
36+
parser.add_argument("-eu", "--energy-units", choices=["eV", "au"], default="eV", help="Energy units")
3737
parser.add_argument(
38-
"-s",
39-
"--shift",
38+
"-es",
39+
"--energy-shift",
4040
action="store_false",
4141
help="Shift curves so that the lowest state minimum has 0 energy",
4242
)
4343
parser.add_argument("-n", "--nstates", type=int, help="Number of states to plot (default: all)")
44-
parser.add_argument(
45-
"-pop",
46-
"--population",
47-
default="pop.dat",
48-
help="File with populations from simulations",
49-
)
50-
parser.add_argument("-pes", "--pes", default="PES.dat", help="File with PES from simulations")
51-
parser.add_argument(
52-
"-en",
53-
"--energy",
54-
default="energies.dat",
55-
help="Energy fileFile with energies from simulations - used to get running PES",
56-
)
5744
args = parser.parse_args()
5845
return vars(args)
5946

@@ -66,43 +53,70 @@ def file_exists(fname: str):
6653

6754

6855
config = parse_cmd()
69-
file_exists(config["population"])
70-
file_exists(config["pes"])
71-
file_exists(config["energy"])
56+
57+
# input files
58+
popfile = "pop.dat"
59+
pesfile = "PES.dat"
60+
enfile = "energies.dat"
61+
62+
# check if files exist
63+
file_exists(popfile)
64+
file_exists(pesfile)
65+
file_exists(enfile)
7266

7367
# Lazy imports to speed up help printing
7468
import numpy as np # noqa: E402
7569
import matplotlib.pyplot as plt # noqa: E402
7670

7771
# reading data from files
78-
data = np.genfromtxt(config["pes"])
79-
energies = np.genfromtxt(config["energy"])
80-
pop = np.genfromtxt(config["population"])
72+
pop = np.genfromtxt(popfile)
73+
data = np.genfromtxt(pesfile)
74+
energies = np.genfromtxt(enfile)
8175

82-
# Set nstates to all unless specified otherwise
76+
# Set nstates to all unless specified otherwise, or check if nstates is valid
8377
nstates = config["nstates"]
8478
if nstates is None:
8579
nstates = len(data.T) - 1
80+
elif nstates > len(data.T) - 1:
81+
exit(f"ERROR: nstates ({nstates}) is larger than the number of states in the data ({len(data.T) - 1})")
82+
elif nstates < 1:
83+
exit(f"ERROR: nstates ({nstates}) must be at least 1")
8684

8785
# converting data to eV
88-
if config["convert"]:
86+
if config["energy_units"] == "eV":
8987
data.T[1:, :] = data.T[1:, :] * 27.2114
9088
energies.T[1:, :] = energies.T[1:, :] * 27.2114
9189
enunits = "eV"
9290
else:
9391
enunits = "a.u."
9492

9593
# shifting data
96-
if config["shift"]:
94+
if config["energy_shift"]:
9795
minE = np.min(data.T[1:, :])
9896
data.T[1:, :] = data.T[1:, :] - minE
9997
energies.T[1, :] = energies.T[1, :] - minE
10098
energies.T[-1, :] = energies.T[-1, :] - minE
10199

102-
# plotting
100+
### plotting
103101
colors = plt.cm.viridis(np.linspace(0, 0.8, nstates))
104102
fig, axs = plt.subplots(4, 1, figsize=(8, 7.5), gridspec_kw={"height_ratios": [1, 2, 1, 1]}, sharex=True)
105103

104+
# plotting populations
105+
for i in range(nstates):
106+
axs[0].plot(pop.T[0, :], pop.T[i + 2, :], color=colors[i], label=rf"$S_{i}$")
107+
108+
axs[0].plot(
109+
pop.T[0, :],
110+
pop.T[1, :] - 1,
111+
color="black",
112+
linestyle="dashed",
113+
label="act. state\nindex",
114+
)
115+
116+
axs[0].set_ylabel("El. populations")
117+
axs[0].legend(labelspacing=0)
118+
119+
# plotting PES
106120
for i in range(0, nstates):
107121
axs[1].plot(data.T[0, :], data.T[i + 1, :], color=colors[i], label=rf"$S_{i}$")
108122

@@ -127,24 +141,11 @@ def file_exists(fname: str):
127141
axs[1].set_ylabel(f"Energy ({enunits})")
128142
axs[1].legend(labelspacing=0)
129143

130-
131-
for i in range(nstates):
132-
axs[0].plot(pop.T[0, :], pop.T[i + 2, :], color=colors[i], label=rf"$S_{i}$")
133-
134-
axs[0].plot(
135-
pop.T[0, :],
136-
pop.T[1, :] - 1,
137-
color="black",
138-
linestyle="dashed",
139-
label="act. state\nindex",
140-
)
141-
142-
axs[0].set_ylabel("El. populations")
143-
axs[0].legend(labelspacing=0)
144-
144+
# plotting kinetic energy
145145
axs[2].plot(energies.T[0, :], energies.T[2, :], color="black", alpha=1, label=r"$E_k$")
146146
axs[2].set_ylabel(f"Kin. energy ({enunits})")
147147

148+
# plotting total energy
148149
for i in range(0, nstates):
149150
axs[3].scatter(
150151
energies.T[0, pop.T[1, :] == (i + 1)],
@@ -165,7 +166,8 @@ def file_exists(fname: str):
165166

166167
plt.tight_layout()
167168
plt.subplots_adjust(hspace=0)
169+
168170
# save figure
169-
if config["print"]:
171+
if config["save_fig"]:
170172
plt.savefig("PES_pop", dpi=300)
171173
plt.show()

0 commit comments

Comments
 (0)