forked from gwastro/ml-mock-data-challenge-1
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
executable file
·345 lines (304 loc) · 13.8 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
#!/usr/bin/env python
"""A program to calculate the false-alarm rate as well as the sensitive
distance from a search algorithm. (Part of the MLGWSC-1)
"""
import argparse
import numpy as np
import h5py
import os
import logging
from pycbc.conversions import distance_from_chirp_distance_mchirp
def find_injection_times(fgfiles, injfile, padding_start=0, padding_end=0):
    """Flag the injections that fall inside the given foreground data.

    Arguments
    ---------
    fgfiles : list of str
        Paths to the HDF5 files holding the foreground data (noise +
        injections).
    injfile : str
        Path to the HDF5 file describing the injections; its `tc`
        dataset is read.
    padding_start : {float, 0}
        Time (in seconds) cut from the start of every segment before
        injections are matched against it.
    padding_end : {float, 0}
        Time (in seconds) cut from the end of every segment before
        injections are matched against it.

    Returns
    -------
    duration:
        Sum of the unpadded lengths (in seconds) of all segments found
        in the foreground files.
    bool-indices:
        1D boolean array with one entry per injection; True where the
        injection time lies inside at least one padded segment.
    """
    total_duration = 0
    segments = []
    for path in fgfiles:
        with h5py.File(path, 'r') as fgfile:
            # Only the first top-level group of each file is inspected.
            detector = list(fgfile.keys())[0]
            for name in fgfile[detector].keys():
                dataset = fgfile[f'{detector}/{name}']
                seg_start = dataset.attrs['start_time']
                seg_end = seg_start + len(dataset) * dataset.attrs['delta_t']
                # The duration counts the full segment, padding included.
                total_duration += seg_end - seg_start
                padded_start = seg_start + padding_start
                padded_end = seg_end - padding_end
                # Segments shorter than the total padding are dropped.
                if padded_end > padded_start:
                    segments.append([padded_start, padded_end])

    with h5py.File(injfile, 'r') as injections:
        injtimes = injections['tc'][()]

    # OR the per-segment masks together instead of stacking them.
    contained = np.zeros(len(injtimes), dtype=bool)
    for lower, upper in segments:
        contained |= np.logical_and(lower <= injtimes, injtimes <= upper)
    return total_duration, contained
def find_closest_index(array, value, assume_sorted=False):
    """Find the index of the closest element in the array for the given
    value(s).

    The returned indices always refer to `array` as it was passed in,
    even when the array is not sorted.

    Arguments
    ---------
    array : np.array
        1D numpy array. It is not modified.
    value : number or np.array
        The value(s) of which the closest array element should be found.
    assume_sorted : {bool, False}
        Assume that the array is sorted in ascending order. Skips the
        internal sorting step and may improve evaluation speed.

    Returns
    -------
    indices:
        Array of indices with the same shape as value (a 0-d result for
        a scalar value). Each index specifies the element in array that
        is closest to the value at the same position. If a value is
        exactly halfway between two elements, the larger element is
        chosen.

    Raises
    ------
    ValueError
        If array is empty.
    """
    array = np.asarray(array)
    if len(array) == 0:
        raise ValueError('Cannot find closest index for empty input array.')
    value = np.asarray(value)

    if assume_sorted:
        order = None
        sorted_array = array
    else:
        # Remember the permutation so that positions in the sorted copy
        # can be translated back to positions in the caller's array. A
        # stable sort keeps already sorted input in its original order.
        order = np.argsort(array, kind='stable')
        sorted_array = array[order]

    last = len(sorted_array) - 1
    # Insertion point to the right of any equal elements; the candidates
    # are the elements directly left and right of that point.
    ridxs = np.searchsorted(sorted_array, value, side='right')
    lidxs = np.maximum(ridxs - 1, 0)
    left_closer = (np.fabs(sorted_array[lidxs] - value)
                   < np.fabs(sorted_array[np.minimum(ridxs, last)] - value))
    # Past the last element only the left neighbour exists.
    take_left = np.logical_or(ridxs > last, left_closer)
    # np.where (instead of in-place item assignment) also handles scalars.
    closest = np.where(take_left, ridxs - 1, ridxs)

    if order is None:
        return closest
    return order[closest]
def mchirp(mass1, mass2):
    """Return the chirp mass of the two component masses.

    Computes (mass1 * mass2)^(3/5) / (mass1 + mass2)^(1/5). Works on
    plain numbers as well as on numpy arrays (element-wise).
    """
    product = mass1 * mass2
    total = mass1 + mass2
    return product ** (3. / 5.) / total ** (1. / 5.)
def get_stats(fgevents, bgevents, injparams, duration=None, chirp_distance=False):
    """Calculate the false-alarm rate and sensitivity of a search
    algorithm.

    Arguments
    ---------
    fgevents : np.array
        A numpy array with three rows. The first row has to contain the
        times returned by the search algorithm where it believes to have
        found a true signal. The second row contains a ranking statistic
        like quantity for each time. The third row contains the maximum
        distance to an injection for the given event to be counted as
        true positive. The values have to be determined on the
        foreground data, i.e. noise plus additive signals.
    bgevents : np.array
        A numpy array with three rows. The first row has to contain the
        times returned by the search algorithm where it believes to have
        found a true signal. The second row contains a ranking statistic
        like quantity for each time. The third row contains the maximum
        distance to an injection for the given event to be counted as
        true positive. The values have to be determined on the
        background data, i.e. pure noise.
    injparams : dict
        A dictionary containing at least two entries with keys `tc` and
        `distance`. Both entries have to be numpy arrays of the same
        length. The entry `tc` contains the times at which injections
        were made in the foreground. The entry `distance` contains the
        according luminosity distances of these injections. If
        `chirp_distance` is True, the entries `mass1` and `mass2` are
        required as well.
    duration : {None or float, None}
        The duration of the analyzed background. If None the injections
        are used to infer the duration (latest minus earliest injection
        time).
    chirp_distance : {bool, False}
        If True, every found injection is weighted by its chirp mass to
        the power 5/2 when the sensitive volume is computed. If False,
        found injections are simply counted.

    Returns
    -------
    dict:
        Returns a dictionary, where each key-value pair specifies some
        statistic. The most important are the keys `far` and
        `sensitive-distance`.
    """
    ret = {}
    injtimes = injparams['tc']
    dist = injparams['distance']
    if chirp_distance:
        massc = mchirp(injparams['mass1'], injparams['mass2'])
    if duration is None:
        duration = injtimes.max() - injtimes.min()

    logging.info('Sorting foreground event times')
    sidxs = fgevents[0].argsort()
    fgevents = fgevents.T[sidxs].T

    logging.info('Finding injection times closest to event times')
    idxs = find_closest_index(injtimes, fgevents[0])
    diff = np.abs(injtimes[idxs] - fgevents[0])

    logging.info('Finding true- and false-positives')
    # An event is a true positive if its nearest injection is no further
    # away than the tolerance stored in the third row of the event.
    event_idxs = np.arange(len(fgevents[0]))
    tpidxs = event_idxs[diff <= fgevents[2]]
    fpidxs = event_idxs[diff > fgevents[2]]

    tpevents = fgevents.T[tpidxs].T
    fpevents = fgevents.T[fpidxs].T

    ret['fg-events'] = fgevents
    # NOTE(review): this holds the nearest injection of *every* event,
    # including false positives - confirm that this is intended.
    ret['found-indices'] = np.arange(len(injtimes))[idxs]
    ret['missed-indices'] = np.setdiff1d(np.arange(len(injtimes)),
                                         ret['found-indices'])
    ret['true-positive-event-indices'] = tpidxs
    ret['false-positive-event-indices'] = fpidxs
    ret['sorting-indices'] = sidxs
    ret['true-positive-diffs'] = diff[tpidxs]
    ret['false-positive-diffs'] = diff[fpidxs]
    ret['true-positives'] = tpevents
    ret['false-positives'] = fpevents

    # Calculate foreground FAR: for the i-th smallest false-positive
    # statistic, the number of louder false positives per unit time.
    logging.info('Calculating foreground FAR')
    fg_noise_stats = fpevents[1].copy()
    fg_noise_stats.sort()
    fgfar = len(fg_noise_stats) - np.arange(len(fg_noise_stats)) - 1
    fgfar = fgfar / duration
    ret['fg-far'] = fgfar

    # Calculate background FAR in the same way from the background events
    logging.info('Calculating background FAR')
    noise_stats = bgevents[1].copy()
    noise_stats.sort()
    far = len(noise_stats) - np.arange(len(noise_stats)) - 1
    far = far / duration
    ret['far'] = far

    # Calculate sensitivity
    # CARE! THIS APPLIES ONLY WHEN THE DISTRIBUTION IS CHOSEN CORRECTLY
    logging.info('Calculating sensitivity')
    tp_order = tpevents[1].argsort()
    tp_sort = tpevents[1][tp_order]
    if chirp_distance:
        # Chirp masses of the found injections, ordered like tp_sort
        found_mchirp_total = massc[idxs[tpidxs]]
        found_mchirp_total = found_mchirp_total[tp_order]
        mchirp_max = massc.max()
    max_distance = dist.max()
    vtot = (4. / 3.) * np.pi * max_distance**3.
    Ninj = len(dist)
    if chirp_distance:
        mc_norm = mchirp_max ** (5. / 2.) * len(massc)
    else:
        mc_norm = Ninj
    prefactor = vtot / mc_norm

    # Number of true positives louder than each background statistic
    nfound = len(tp_sort) - np.searchsorted(tp_sort, noise_stats,
                                            side='right')
    if chirp_distance:
        # Get found chirp-mass indices for given threshold
        fidxs = np.searchsorted(tp_sort, noise_stats, side='right')
        found_mchirp_total = np.flip(found_mchirp_total)

        # Calculate sum(found_mchirp ** (5/2))
        # with found_mchirp = found_mchirp_total[i:]
        # and i looped over fidxs
        # Code below is a vectorized form of that
        cumsum = np.flip(np.cumsum(found_mchirp_total ** (5. / 2.)))
        # Trailing zero: a threshold above all true positives finds nothing
        cumsum = np.concatenate([cumsum, np.zeros(1)])
        mc_sum = cumsum[fidxs]
        Ninj = np.sum((mchirp_max / massc) ** (5. / 2.))

        cumsumsq = np.flip(np.cumsum(found_mchirp_total ** 5))
        cumsumsq = np.concatenate([cumsumsq, np.zeros(1)])
        sample_variance_prefactor = cumsumsq[fidxs]
        sample_variance = sample_variance_prefactor / Ninj - (mc_sum / Ninj) ** 2
    else:
        mc_sum = nfound
        sample_variance = nfound / Ninj - (nfound / Ninj) ** 2

    vol = prefactor * mc_sum
    vol_err = prefactor * (Ninj * sample_variance) ** 0.5
    # Radius of a sphere with the sensitive volume
    rad = (3 * vol / (4 * np.pi))**(1. / 3.)

    ret['sensitive-volume'] = vol
    ret['sensitive-distance'] = rad
    ret['sensitive-volume-error'] = vol_err
    ret['sensitive-fraction'] = nfound / Ninj

    return ret
def _read_events(paths):
    """Read the `time`, `stat` and `var` datasets from every given HDF5
    file and join the events of all files into one array with three rows.
    """
    events = []
    for fpath in paths:
        with h5py.File(fpath, 'r') as fp:
            events.append(np.vstack([fp['time'],
                                     fp['stat'],
                                     fp['var']]))
    return np.concatenate(events, axis=-1)


def main(doc):
    """Command-line entry point.

    Parses the command line, determines which injections are contained
    in the foreground files, reads the foreground and background events,
    evaluates them with `get_stats` and writes every resulting statistic
    as a dataset into the output HDF5 file.

    Arguments
    ---------
    doc : str
        Text used as the description of the command-line parser.

    Raises
    ------
    ValueError
        If the output file does not have the extension `.hdf`.
    IOError
        If the output file already exists and `--force` is not set.
    RuntimeError
        If none of the injections are contained in the foreground data.
    """
    parser = argparse.ArgumentParser(description=doc)

    parser.add_argument('--injection-file', type=str, required=True,
                        help=("Path to the file containing information "
                              "on the injections. (The file returned by "
                              "`generate_data.py --output-injection-file`)"))
    parser.add_argument('--foreground-events', type=str, nargs='+', required=True,
                        help=("Path to the file containing the events "
                              "returned by the search on the foreground "
                              "data set as returned by "
                              "`generate_data.py --output-foreground-file`."))
    parser.add_argument('--foreground-files', type=str, nargs='+', required=True,
                        help=("Path to the file containing the analyzed "
                              "foreground data output by "
                              "`generate_data.py --output-foreground-file`."))
    parser.add_argument('--background-events', type=str, required=True, nargs='+',
                        help=("Path to the file containing the events "
                              "returned by the search on the background "
                              "data set as returned by "
                              "`generate_data.py --output-background-file`."))
    parser.add_argument('--output-file', type=str, required=True,
                        help=("Path at which to store the output HDF5 "
                              "file. (Path must end in `.hdf`)"))

    parser.add_argument('--verbose', action='store_true',
                        help="Print update messages.")
    parser.add_argument('--force', action='store_true',
                        help="Overwrite existing files.")

    args = parser.parse_args()

    # Setup logging
    log_level = logging.INFO if args.verbose else logging.WARN
    logging.basicConfig(format='%(levelname)s | %(asctime)s: %(message)s',
                        level=log_level, datefmt='%d-%m-%Y %H:%M:%S')

    # Sanity check arguments here
    if os.path.splitext(args.output_file)[1] != '.hdf':
        raise ValueError('The output file must have the extension `.hdf`.')
    if os.path.isfile(args.output_file) and not args.force:
        raise IOError(f'The file {args.output_file} already exists. Set the flag `force` to overwrite it.')

    # Find indices contained in foreground file
    logging.info('Finding injections contained in data')
    # Injections within this many seconds of a segment edge are ignored.
    padding_start, padding_end = 30, 30
    dur, idxs = find_injection_times(args.foreground_files,
                                     args.injection_file,
                                     padding_start=padding_start,
                                     padding_end=padding_end)
    if np.sum(idxs) == 0:
        msg = 'The foreground data contains no injections! '
        msg += 'Probably a too small section of data was generated. '
        msg += 'Please make sure to generate at least {} seconds of data. '
        msg += 'Otherwise a sensitive distance cannot be calculated.'
        # NOTE(review): the extra 24 seconds are presumably the spacing
        # between injections - confirm against the data generation.
        msg = msg.format(padding_start + padding_end + 24)
        raise RuntimeError(msg)

    # Read injection parameters (only for the contained injections)
    logging.info(f'Reading injections from {args.injection_file}')
    injparams = {}
    with h5py.File(args.injection_file, 'r') as fp:
        injparams['tc'] = fp['tc'][()][idxs]
        injparams['distance'] = fp['distance'][()][idxs]
        injparams['mass1'] = fp['mass1'][()][idxs]
        injparams['mass2'] = fp['mass2'][()][idxs]
        # Chirp-mass weighting is used if the file has this dataset.
        use_chirp_distance = 'chirp_distance' in fp.keys()

    # Read foreground events
    logging.info(f'Reading foreground events from {args.foreground_events}')
    fg_events = _read_events(args.foreground_events)

    # Read background events
    logging.info(f'Reading background events from {args.background_events}')
    bg_events = _read_events(args.background_events)

    stats = get_stats(fg_events, bg_events, injparams,
                      duration=dur,
                      chirp_distance=use_chirp_distance)

    # Store results
    logging.info(f'Writing output to {args.output_file}')
    # Without --force, 'x' refuses to replace a file that appeared in the
    # meantime.
    mode = 'w' if args.force else 'x'
    with h5py.File(args.output_file, mode) as fp:
        for key, val in stats.items():
            fp.create_dataset(key, data=np.array(val))
if __name__ == "__main__":
    # The module docstring doubles as the command-line description.
    main(doc=__doc__)