Skip to content

Commit ad97133

Browse files
committed
some tests
1 parent 46d1d27 commit ad97133

File tree

2 files changed

+116
-4
lines changed

2 files changed

+116
-4
lines changed

WrightSim/hamiltonian/default.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ def __init__(
7171
Order matters, and meaning is dependent on the individual Hamiltonian.
7272
Default is two values, both initially 1.0.
7373
omega : 1-D array <float64> (optional)
74-
The energies of various transitions.
74+
The energies of various transitions (wavenumbers).
7575
The default uses w_central and coupling parameters to compute the appropriate
7676
values for a TRIVE Hamiltonian
77-
w_central : float (optional)
78-
The cetral frequency of a resonance for a TRIVE Hamiltonian.
77+
w_central : float (optional)
78+
The central frequency (wavenumbers) of a resonance for a TRIVE Hamiltonian.
7979
Used only when ``omega`` is ``None``.
80-
coupling : float (optional)
80+
coupling : float (optional) (wavenumbers)
8181
The copuling of states for a TRIVE Hamiltonian.
8282
Used only when ``omega`` is ``None``.
8383
propagator : function (optional)

tests/mixed/default.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""smokescreen checks of mixed domain default hamiltonian integration"""
2+
import WrightSim as ws
3+
import WrightTools as wt
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
7+
8+
dt = 20
9+
nt = 21
10+
wn_to_omega = 2 * np.pi * 3e-5 # cm / fs
11+
w_central = 3000 # wn
12+
coupling = 0 # wn
13+
14+
ham = ws.hamiltonian.Hamiltonian(
15+
w_central=w_central,
16+
coupling=coupling,
17+
tau=100,
18+
)
19+
ham.recorded_elements = [7, 8]
20+
21+
22+
# @pytest.mark.skip("this test currently fails; bugfix needed")
23+
def test_windowed():
24+
exp = ws.experiment.builtin('trive')
25+
exp.w1.points = w_central # wn
26+
exp.w2.points = w_central # wn
27+
exp.d2.points = 50 # np.zeros((1,)) # fs
28+
exp.d1.points = 0 # fs
29+
exp.s1.points = exp.s2.points = dt # fs
30+
31+
exp.d1.active = exp.d2.active = False
32+
33+
# 400 time points
34+
exp.timestep = 1
35+
exp.early_buffer = 100.
36+
exp.late_buffer = 300.
37+
38+
scan = exp.run(ham, mp=False)
39+
data = scan.sig
40+
41+
# shift delay so emission is timed differently
42+
exp2 = ws.experiment.builtin('trive')
43+
exp2.w1.points = w_central # wn
44+
exp2.w2.points = w_central # wn
45+
exp2.d2.points = 50 # np.zeros((1,)) # fs
46+
exp2.d1.points = 0 # fs
47+
exp2.s1.points = exp2.s2.points = dt # fs
48+
49+
exp2.d1.active = exp2.d2.active = False
50+
51+
exp2.timestep = 1
52+
exp2.early_buffer = 100.
53+
exp2.late_buffer = 300.
54+
55+
scan2 = exp2.run(ham, mp=False, windowed=True)
56+
data2 = scan2.sig
57+
58+
if True:
59+
fig, (ax1, ax2) = plt.subplots(nrows=2)
60+
ax1.plot(data.time[:], data.channels[0][:].real)
61+
62+
wn = np.fft.fftfreq(n=data.time.size, d=exp.timestep) / 3e-5
63+
sig_fft = np.abs(np.fft.fft(data.channels[0][:]))
64+
ax2.plot(wn, sig_fft)
65+
66+
ax1.plot(data2.time[:], data2.channels[0][:].real)
67+
# ax1.plot(data2.time[:], data2.channels[0][:].imag)
68+
69+
wn2 = np.fft.fftfreq(n=data.time.size, d=exp.timestep) / 3e-5
70+
sig_fft2 = np.abs(np.fft.fft(data.channels[0][:]))
71+
ax2.plot(wn2, sig_fft2)
72+
73+
ax2.set_xlim(-4000, -2000)
74+
75+
plt.show()
76+
77+
assert data2.time.size == data.time.size
78+
assert np.all(np.isclose(data2.channels[0][:], data.channels[0][:]))
79+
80+
81+
def test_frequency():
82+
83+
exp = ws.experiment.builtin('trive')
84+
exp.w1.points = w_central # wn
85+
exp.w2.points = w_central # wn
86+
exp.d2.points = 0 # np.zeros((1,)) # fs
87+
exp.d1.points = 0 # fs
88+
exp.s1.points = exp.s2.points = dt # fs
89+
90+
exp.d1.active = exp.d2.active = False
91+
92+
# 400 time points
93+
exp.timestep = 1
94+
exp.early_buffer = 100.
95+
exp.late_buffer = 300.
96+
97+
scan = exp.run(ham, mp=False)
98+
data = scan.sig
99+
wn = np.fft.fftfreq(n=data.time.size, d=exp.timestep) / 3e-5
100+
sig_fft = np.abs(np.fft.fft(data.channels[0][:]))
101+
102+
if False:
103+
fig, (ax1, ax2) = plt.subplots(nrows=2)
104+
ax2.plot(wn, sig_fft)
105+
plt.show()
106+
107+
assert np.abs(wn[np.argmax(sig_fft)] + w_central) < np.abs(wn[1] - wn[0])
108+
109+
110+
if __name__ == "__main__":
111+
test_windowed() # fails atm
112+
test_frequency()

0 commit comments

Comments
 (0)