forked from dcampora/velopix_tracking
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgraph_dfs.py
298 lines (244 loc) · 12.6 KB
/
graph_dfs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
from event_model import *
class segment(object):
    """A segment for the graph dfs.

    A segment links two hits, ``h0`` and ``h1``, and carries the bookkeeping
    the search needs: its index in the global segment list, a weight
    (initially 0) and a flag telling whether it is a root segment
    (initially False).
    """

    def __init__(self, h0, h1, seg_number):
        """Store both hits and the segment index; weight and root flag start unset."""
        self.h0, self.h1 = h0, h1
        self.weight = 0
        self.segment_number = seg_number
        self.root_segment = False

    def __repr__(self):
        """Return a multi-line, human readable description of the segment."""
        return "Segment %s:\n h0: %s\n h1: %s\n Weight: %s" % (
            str(self.segment_number),
            str(self.h0),
            str(self.h1),
            str(self.weight),
        )
class graph_dfs(object):
    """This method creates a directed graph, and traverses
    it with a DFS method in order to create tracks.

    It grabs inspiration from the "CA" algorithm.

    Steps:
      0. Preorder all hits in each sensor by x,
         and update their hit_number.
      1. Fill candidates
           index: hit index
           contents: {sensor_index: [candidate start, candidate end], ...}
      2. Create all segments, indexed by outer hit number.
      3. Assign weights and get roots.
      4. Depth first search.
      5. Clone and ghost killing.
    """

    def __init__(self, max_slopes=(0.7, 0.7), max_tolerance=(0.4, 0.4), max_scatter=0.4,
                 minimum_root_weight=1, weight_assignment_iterations=2, allowed_skip_sensors=1,
                 allow_cross_track=True, clone_ghost_killing=True):
        """Store the configuration of the algorithm.

        max_slopes: (x, y) pair of maximum slopes used when pairing two hits.
        max_tolerance: (x, y) pair of maximum distances between the
          extrapolation of a segment and a third hit.
        max_scatter: upper bound on the scatter computed in check_tolerance.
        minimum_root_weight: minimum weight a root segment needs to seed a track.
        weight_assignment_iterations: number of weight propagation passes.
        allowed_skip_sensors: number of sensors that may be skipped when
          looking for candidates (the stride depends on allow_cross_track).
        allow_cross_track: if True, candidates are searched in consecutive
          sensors; otherwise in every other sensor.
        clone_ghost_killing: if True, solve() prunes short tracks.
        """
        self.__max_slopes = max_slopes
        self.__max_tolerance = max_tolerance
        self.__max_scatter = max_scatter
        self.__minimum_root_weight = minimum_root_weight
        self.__weight_assignment_iterations = weight_assignment_iterations
        self.__allow_cross_track = allow_cross_track
        self.__allowed_skip_sensors = allowed_skip_sensors
        self.__clone_ghost_killing = clone_ghost_killing

    def are_compatible_in_x(self, hit_0, hit_1):
        """Checks if two hits are compatible according
        to the configured max_slope in x.

        Hits are read by index: [0] is compared against the x slope and [2]
        gives the distance between the hits (presumably x and z; confirm
        against the hit class of event_model).
        """
        hit_distance = abs(hit_1[2] - hit_0[2])
        dxmax = self.__max_slopes[0] * hit_distance
        return abs(hit_1[0] - hit_0[0]) < dxmax

    def are_compatible_in_y(self, hit_0, hit_1):
        """Checks if two hits are compatible according
        to the configured max_slope in y.

        Hits are read by index: [1] is compared against the y slope and [2]
        gives the distance between the hits.
        """
        hit_distance = abs(hit_1[2] - hit_0[2])
        dymax = self.__max_slopes[1] * hit_distance
        return abs(hit_1[1] - hit_0[1]) < dymax

    def are_compatible(self, hit_0, hit_1):
        """Checks if two hits are compatible according to
        the configured max_slope (both in x and in y).
        """
        return self.are_compatible_in_x(hit_0, hit_1) and self.are_compatible_in_y(hit_0, hit_1)

    def check_tolerance(self, hit_0, hit_1, hit_2):
        """Checks if three hits are compatible by
        extrapolating the segment conformed by the
        first two hits (hit_0, hit_1) and comparing
        it to the third hit.

        The parameters that control this tolerance are
        max_tolerance and max_scatter.

        Raises ZeroDivisionError if hit_0 and hit_1, or hit_1 and hit_2,
        share the same z.
        """
        # Slopes of the (hit_0, hit_1) segment.
        td = 1.0 / (hit_1.z - hit_0.z)
        tx = (hit_1.x - hit_0.x) * td
        ty = (hit_1.y - hit_0.y) * td

        # Distance between the extrapolation to hit_2.z and hit_2 itself.
        dz = hit_2.z - hit_0.z
        dx = abs(hit_0.x + tx * dz - hit_2.x)
        dy = abs(hit_0.y + ty * dz - hit_2.y)
        tolx_condition = dx < self.__max_tolerance[0]
        toly_condition = dy < self.__max_tolerance[1]

        # Squared distance scaled by the squared inverse z gap to hit_1.
        scatter_num = (dx * dx) + (dy * dy)
        scatter_denom = 1.0 / (hit_2.z - hit_1.z)
        scatter = scatter_num * scatter_denom * scatter_denom
        scatter_condition = scatter < self.__max_scatter

        return tolx_condition and toly_condition and scatter_condition

    def are_segments_compatible(self, seg0, seg1):
        """Checks whether two segments are compatible, applying
        the tolerance check.

        seg1 should start where seg0 ends
        (ie. seg0.h1 and seg1.h0 should be the same).
        A mismatch only prints a warning; the check is still performed.
        """
        if seg0.h1 != seg1.h0:
            print("Warning: seg0 h1 and seg1 h0 are not the same")
            print(seg0.h1)
            print(seg1.h0)
        return self.check_tolerance(seg0.h0, seg0.h1, seg1.h1)

    def order_hits(self, event):
        """Preorder all hits in each sensor by x,
        and update their hit_number.

        The event is modified in place.
        """
        for s in event.sensors:
            hit_start, hit_end = s.hit_start_index, s.hit_end_index
            event.hits[hit_start:hit_end] = sorted(event.hits[hit_start:hit_end], key=lambda h: h.x)
        # After sorting, hit_number is simply the position in event.hits.
        for number, h in enumerate(event.hits):
            h.hit_number = number

    def fill_candidates(self, event):
        """Fill candidates
          index: hit index
          contents: {sensor_index: [candidate start, candidate end], ...}

        For every hit h0 of a starting sensor and every reachable sensor,
        the half-open range [start, end) of hit numbers compatible in x
        with h0 is stored. [-1, -1] means that no hit was compatible.
        Relies on the hits of each sensor being ordered by x (see order_hits).
        """
        candidates = [{} for _ in range(event.number_of_hits)]

        # With cross tracks the neighbouring sensor is searched first and
        # skipped sensors are one apart; otherwise everything moves in
        # steps of two sensors.
        if self.__allow_cross_track:
            substraction_starting_sensor = 1
            skip_stride = 1
        else:
            substraction_starting_sensor = 2
            skip_stride = 2

        # NOTE(review): sensors[2:] means sensor 1 never acts as a starting
        # sensor, even when allow_cross_track is set -- confirm this is intended.
        starting_sensors = reversed(event.sensors[2:])
        starting_indices = reversed(range(len(event.sensors) - substraction_starting_sensor))

        for s0, starting_sensor_index in zip(starting_sensors, starting_indices):
            for h0 in s0.hits():
                for missing_sensors in range(self.__allowed_skip_sensors + 1):
                    sensor_index = starting_sensor_index - missing_sensors * skip_stride
                    if sensor_index < 0:
                        continue

                    s1 = event.sensors[sensor_index]
                    window = [-1, -1]
                    candidates[h0.hit_number][sensor_index] = window
                    begin_found = False
                    end_found = False

                    for h1 in s1.hits():
                        compatible = self.are_compatible_in_x(h0, h1)
                        if not begin_found and compatible:
                            window[0] = h1.hit_number
                            window[1] = h1.hit_number + 1
                            begin_found = True
                        elif begin_found and not compatible:
                            window[1] = h1.hit_number
                            end_found = True
                            break

                    # The window runs until the last hit of the sensor.
                    if begin_found and not end_found:
                        window[1] = s1.hits()[-1].hit_number + 1

        return candidates

    def populate_segments(self, event, candidates):
        """Create segments and populate compatible segments.

        Returns the tuple
        (segments, outer_hit_segment_list, compatible_segments,
        populated_compatible_segments), where

        segments: All segments.
        outer_hit_segment_list: Segment indices, indexed by outer hit.
          Note: Outer hit number is the one with smaller z.
        compatible_segments: Compatible segment indices, indexed by segment index.
        populated_compatible_segments: Indices of all compatible_segments.
        """
        segments = []
        outer_hit_segment_list = [[] for _ in event.hits]

        for h0_number in range(event.number_of_hits):
            h0 = event.hits[h0_number]
            for first, last in candidates[h0_number].values():
                for h1_number in range(first, last):
                    h1 = event.hits[h1_number]
                    if self.are_compatible_in_y(h0, h1):
                        outer_hit_segment_list[h1_number].append(len(segments))
                        segments.append(segment(h0, h1, len(segments)))

        # seg0 -> seg1 whenever seg1 starts at the hit where seg0 ends
        # and the three hits pass the tolerance check.
        compatible_segments = [[] for _ in segments]
        for seg1 in segments:
            for seg0_index in outer_hit_segment_list[seg1.h0.hit_number]:
                seg0 = segments[seg0_index]
                if self.are_segments_compatible(seg0, seg1):
                    compatible_segments[seg0.segment_number].append(seg1.segment_number)

        populated_compatible_segments = [
            seg_index for seg_index, followers in enumerate(compatible_segments) if followers
        ]

        return (segments, outer_hit_segment_list, compatible_segments, populated_compatible_segments)

    def assign_weights_and_populate_roots(self, segments, compatible_segments, populated_compatible_segments):
        """Assigns weights to the segments according to the configured
        number of iterations weight_assignment_iterations.

        It also populates the root_segment attribute of segments.
        The segments are modified in place; nothing is returned.
        """
        for _ in range(self.__weight_assignment_iterations):
            for seg0_index in populated_compatible_segments:
                segments[seg0_index].weight = 1 + max(
                    segments[seg_number].weight for seg_number in compatible_segments[seg0_index]
                )

        # Find out root segments: mark every segment with followers as root...
        for seg0_index in populated_compatible_segments:
            segments[seg0_index].root_segment = True

        # ...and unmark those that are themselves the follower of another one.
        for seg0_index in populated_compatible_segments:
            for seg1_index in compatible_segments[seg0_index]:
                segments[seg1_index].root_segment = False

    def dfs(self, segment, segments, compatible_segments):
        """Returns tracks found extrapolating this segment,
        by traversing the segments following a depth first search strategy.

        The result is a list of hit lists. Every list starts with
        segment.h1 and follows one path of compatible segments until a
        segment without compatible segments is reached. All branches are
        explored, so a segment with several compatible segments yields one
        hit list per path.
        """
        followers = compatible_segments[segment.segment_number]
        if not followers:
            return [[segment.h1]]
        # Collect the continuations of every follower. Returning from inside
        # a loop over the followers would keep only the first branch.
        return [
            [segment.h1] + dfs_segments
            for segid in followers
            for dfs_segments in self.dfs(segments[segid], segments, compatible_segments)
        ]

    def prune_short_tracks(self, tracks):
        """Kills clones and weak tracks with
        three hits and a shared hit.

        Tracks with more than three hits are always kept. Shorter tracks
        are kept only if none of their hits belongs to one of the longer
        tracks.
        """
        # A set keeps the membership test below constant time.
        used_hits = {h.hit_number for t in tracks if len(t.hits) > 3 for h in t.hits}
        return [
            t for t in tracks
            if len(t.hits) > 3 or all(h.hit_number not in used_hits for h in t.hits)
        ]

    def print_compatible_segments(self, segments, compatible_segments, populated_compatible_segments):
        """Prints all compatible segments."""
        for seg0_index in populated_compatible_segments:
            seg0 = segments[seg0_index]
            followers = [segments[seg_index] for seg_index in compatible_segments[seg0_index]]
            print("%s\nis compatible with segments \n%s\n" % (seg0, followers))

    def solve(self, event):
        """Solves the event according to the strategy
        defined in the class definition.

        The event itself is not modified: the algorithm works on
        event.copy(). Returns the list of tracks found.
        """
        print(
            "Invoking graph dfs with\n max slopes: %s\n max tolerance: %s\n"
            "max scatter: %s\n weight assignment iterations: %s\n minimum root weight: %s\n"
            "allow cross track: %s\n allowed skip sensors: %s (its behaviour depends on allow cross track)\n"
            "clone ghost killing: %s\n\n"
            % (self.__max_slopes, self.__max_tolerance, self.__max_scatter,
               self.__weight_assignment_iterations, self.__minimum_root_weight,
               self.__allow_cross_track, self.__allowed_skip_sensors,
               self.__clone_ghost_killing)
        )

        # 0. Preorder all hits in each sensor by x,
        #    and update their hit_number.
        #    Work with a copy of event
        event_copy = event.copy()
        self.order_hits(event_copy)

        # 1. Fill candidates
        #    index: hit index
        #    contents: {sensor_index: [candidate start, candidate end], ...}
        candidates = self.fill_candidates(event_copy)

        # 2. Create all segments, indexed by outer hit number
        (segments, outer_hit_segment_list, compatible_segments, populated_compatible_segments) = \
            self.populate_segments(event_copy, candidates)

        # 3. Assign weights and get roots
        self.assign_weights_and_populate_roots(segments, compatible_segments, populated_compatible_segments)
        root_segments = [
            segid for segid in populated_compatible_segments
            if segments[segid].root_segment
            and segments[segid].weight >= self.__minimum_root_weight
        ]

        # 4. Depth first search
        tracks = []
        for segment_id in root_segments:
            root_segment = segments[segment_id]
            tracks += [
                track([root_segment.h0] + dfs_segments)
                for dfs_segments in self.dfs(root_segment, segments, compatible_segments)
            ]

        # 5. Clone and ghost killing
        #    Note: For now, just short track killing
        if self.__clone_ghost_killing:
            tracks = self.prune_short_tracks(tracks)

        return tracks