-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
235 lines (189 loc) · 7.22 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""
Utility functions for this method-merge-tests repository.
"""
import numpy as np
import pandas as pd
import pints
def run_replicates(iterations, n_replicates, test, parallel=False):
    """
    Runs ``test(i)`` for all entries ``i`` in ``iterations``, repeating each
    test ``n_replicates`` times.

    The argument ``test`` is expected to return a dictionary of (scalar valued)
    results. The keys of the first result determine the result columns, so
    every call should return the same keys.

    The returned value is a pandas DataFrame with
    ``len(iterations) * n_replicates`` rows. Each row contains an index, the
    number of iterations performed as ``iterations``, the index of the repeat
    as ``replicate``, followed by the entries of the corresponding test
    result.

    Parallel evaluation can be enabled by setting ``parallel`` to ``True`` or
    to the number of worker processes to use. However, this can cause issues in
    Jupyter notebooks.

    Raises an ``AssertionError`` if no results were generated (e.g. because
    ``iterations`` is empty).
    """
    n_rows = len(iterations) * n_replicates
    df = pd.DataFrame(index=np.arange(n_rows))
    df['iterations'] = np.repeat(iterations, n_replicates)
    df['replicate'] = np.tile(np.arange(n_replicates), len(iterations))

    # Evaluate the cases in reverse order:
    #  - Assuming that the iterations are sorted from low to high, the longest
    #    running tasks are at the end.
    #  - If we start with short tasks and end with long ones, the last process
    #    to start will be the last one to finish.
    #  - Instead, do the long running tasks first, and then whoever finishes
    #    first can start on the shorter tasks.
    iterations = list(reversed(df['iterations']))
    results = pints.evaluate(test, iterations, parallel=parallel)

    # Restore the original ordering, so that results line up with the rows
    results.reverse()

    assert len(results) > 0, 'Empty result set generated'

    # One column per key in the result dictionaries. ``np.asarray`` is used
    # instead of ``np.array(..., copy=False)``: on NumPy >= 2.0 the latter
    # raises a ValueError whenever a copy is unavoidable, which is always the
    # case when converting from a list.
    for key in results[0]:
        df[key] = np.asarray([r[key] for r in results])

    return df
def ecdf_norm_plotter(draws, normal_sd, x=np.linspace(-5, 5, 100)):
    """
    Shows a scatter plot comparing the empirical CDF of ``draws`` with the CDF
    of a normal distribution with mean zero and standard deviation
    ``normal_sd``.

    Both CDFs are evaluated at the points in ``x``. The empirical values are
    plotted on the x-axis and the normal CDF values on the y-axis, together
    with a black line from (0, 0) to (1, 1) for reference. The figure is
    displayed using ``plt.show()``.
    """
    import matplotlib.pyplot as plt
    from scipy.stats import norm
    from statsmodels.distributions.empirical_distribution import ECDF

    # Empirical and exact CDF, evaluated point by point
    empirical = ECDF(draws)
    estimated = list(map(empirical, x))
    exact = [norm.cdf(point, loc=0, scale=normal_sd) for point in x]

    # Reference line: where both CDFs agree
    diagonal = np.linspace(0, 1, 100)

    plt.scatter(estimated, exact)
    plt.plot(diagonal, list(diagonal), 'k-')
    plt.xlabel('Estimated cdf')
    plt.ylabel('True cdf')
    plt.show()
def technicolor_dreamline(ax, x, y, z=None, lw=1):
    """
    Draws a multi-coloured line on a set of matplotlib axes ``ax``.

    The points to plot should be passed in as ``x`` and ``y``, and optionally
    ``z`` for a 3d plot. The colour changes along the line, from the start of
    the ``jet`` colour map at the first point to its end at the last point.

    Line width can be set with ``lw``.

    Code adapted from: https://github.com/CardiacModelling/FourWaysOfFitting
    """
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection
    from matplotlib.colors import Normalize
    from mpl_toolkits.mplot3d.art3d import Line3DCollection

    colormap = 'jet'
    cmap_fix = 1

    # Gather the coordinates: two arrays for a 2d line, three for a 3d line
    if z is None:
        coords = (np.asarray(x), np.asarray(y))
    else:
        coords = (np.asarray(x), np.asarray(y), np.asarray(z))
    n_points = len(coords[0])

    # Invisible plot for automatic x & y limits
    ax.plot(*coords, alpha=0)

    # Create collection of line segments. For long inputs, several points are
    # grouped into a single segment to limit the number of segments. Each
    # segment includes the first point of the next one (hence the ``+ 1``), so
    # that the drawn line has no gaps.
    stride = max(1, int(n_points / 1000))
    n = 1 + (n_points - 1) // stride
    segments = [
        np.vstack([c[lo:lo + stride + 1] for c in coords]).T
        for lo in range(0, n * stride, stride)
    ]

    Collection = LineCollection if z is None else Line3DCollection

    # ``plt.cm.get_cmap`` was removed in matplotlib 3.9; ``plt.get_cmap`` is
    # available in both old and new versions.
    cmap = plt.get_cmap(colormap)
    norm = Normalize(0, cmap_fix)

    # Colour each segment by its relative position along the line
    idxs = np.linspace(0, 1, len(segments))
    ax.add_collection(
        Collection(segments, cmap=cmap, norm=norm, array=idxs, lw=lw))
def function_between_points(
        ax, f, x_true, x_found, padding=0.25, evaluations=20):
    """
    Like :meth:`pints.plot.function_between_points`, but takes a matplotlib
    axes as first argument.

    Evaluates ``f`` (a :class:`pints.LogPDF` or :class:`pints.ErrorMeasure`)
    at ``evaluations`` equally spaced points on the line through ``x_true``
    and ``x_found``, and plots the result on ``ax``. The line is parametrised
    so that ``0`` corresponds to ``x_true`` and ``1`` to ``x_found``; the
    plotted range is ``[-padding, 1 + padding]``. Vertical lines mark both
    points.

    Raises a ``ValueError`` if ``f`` has the wrong type, if the points do not
    both have ``f.n_parameters()`` entries, if ``padding`` is negative, or if
    ``evaluations`` is less than 3.
    """
    # Check function and get n_parameters
    if not isinstance(f, (pints.LogPDF, pints.ErrorMeasure)):
        raise ValueError(
            'Given function must be pints.LogPDF or pints.ErrorMeasure.')
    n_param = f.n_parameters()

    # Check points
    point_1 = pints.vector(x_true)
    point_2 = pints.vector(x_found)
    if not (len(point_1) == len(point_2) == n_param):
        raise ValueError(
            'Both points must have the same number of parameters'
            ' as the given function.')

    # Check padding
    padding = float(padding)
    if padding < 0:
        raise ValueError('Padding cannot be negative.')

    # Check evaluation
    evaluations = int(evaluations)
    if evaluations < 3:
        raise ValueError('The number of evaluations must be 3 or greater.')

    # Figure setting
    ax.set_ylabel('Error')

    # Positions along the line: 0 is the first point, 1 is the second
    s = np.linspace(-padding, 1 + padding, evaluations)

    # Direction
    r = point_2 - point_1

    # Evaluate the function at each position on the line
    x = [point_1 + sj * r for sj in s]
    y = pints.evaluate(f, x, parallel=False)

    # Plot
    ax.plot(s, y, color='green')
    ax.axvline(0, color='#1f77b4', label='True parameters')
    ax.axvline(1, color='#7f7f7f', label='Estimated parameters')
    ax.legend()
def function(axes, f, x, scales=None, evaluations=20):
    """
    Like :class:`pints.plot.function`, but takes a set of axes as input, and
    uses a list of scales instead of lower and upper bounds.

    For each parameter ``j``, the function ``f`` (a :class:`pints.LogPDF` or
    :class:`pints.ErrorMeasure`) is evaluated at ``evaluations`` equally
    spaced points between ``x[j] - scales[j]`` and ``x[j] + scales[j]``, while
    all other parameters are kept fixed at their value in ``x``. The result is
    plotted on ``axes[j]``, along with a vertical line at ``x[j]``.

    If no ``scales`` are given, 5% of each entry in ``x`` is used.

    Raises a ``ValueError`` if ``f`` has the wrong type, if ``axes``, ``x``,
    or ``scales`` do not have length ``f.n_parameters()``, or if
    ``evaluations`` is less than 1.
    """
    # Check function and get n_parameters
    if not isinstance(f, (pints.LogPDF, pints.ErrorMeasure)):
        raise ValueError(
            'Given function must be pints.LogPDF or pints.ErrorMeasure.')
    n_param = f.n_parameters()

    # Check axes
    if len(axes) != n_param:
        raise ValueError('Axes list must have length f.n_parameters().')

    # Check point
    x = pints.vector(x)
    if len(x) != n_param:
        raise ValueError('Point x must have length f.n_parameters().')

    # Check scales
    if scales is None:
        # Guess boundaries based on point x
        scales = x * 0.05
    else:
        scales = pints.vector(scales)
    if len(scales) != n_param:
        raise ValueError('Scales must have length f.n_parameters().')
    lower = x - scales
    upper = x + scales

    # Check number of evaluations
    evaluations = int(evaluations)
    if evaluations < 1:
        raise ValueError('Number of evaluations must be greater than zero.')

    # Create points to plot: one group of ``evaluations`` rows per parameter,
    # in which only that parameter varies
    xs = np.tile(x, (n_param * evaluations, 1))
    for j in range(n_param):
        i1 = j * evaluations
        i2 = i1 + evaluations
        xs[i1:i2, j] = np.linspace(lower[j], upper[j], evaluations)

    # Evaluate points
    fs = pints.evaluate(f, xs, parallel=False)

    # Plot each group on its own axes
    for j, p in enumerate(x):
        i1 = j * evaluations
        i2 = i1 + evaluations
        axes[j].plot(xs[i1:i2, j], fs[i1:i2], c='green', label='Function')
        axes[j].axvline(p, c='blue', label='Value')
        axes[j].set_xlabel('Parameter ' + str(1 + j))
        axes[j].legend()