Skip to content

Commit 2603db6

Browse files
committed
finalise all tests and time handling.
1 parent 5d72fa7 commit 2603db6

File tree

2 files changed

+108
-17
lines changed

2 files changed

+108
-17
lines changed

src/spikeinterface/widgets/peaks_on_probe.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
104104
ax = self.axes[ax_idx]
105105
plot_probe_map(dp.recording, ax=ax)
106106

107-
time_mask = self._get_peaks_time_mask(dp.recording, time_range, fs, peaks_to_plot)
107+
time_mask = self._get_peaks_time_mask(dp.recording, time_range, peaks_to_plot)
108108

109109
if dp.segment_index is not None:
110110
segment_mask = peaks_to_plot["segment_index"] == dp.segment_index
@@ -142,7 +142,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
142142

143143
self.figure.suptitle(f"Peaks on Probe Plot")
144144

145-
def _get_peaks_time_mask(self, recording, time_range, fs, peaks_to_plot):
145+
def _get_peaks_time_mask(self, recording, time_range, peaks_to_plot):
146146
"""
147147
Return a mask of `True` where the peak is within the given time range
148148
and `False` otherwise.
@@ -162,10 +162,12 @@ def _get_peaks_time_mask(self, recording, time_range, fs, peaks_to_plot):
162162

163163
seg_mask = peaks_to_plot["segment_index"] == seg_idx
164164

165-
time_mask[seg_mask] = (t_start_sample <= peaks_to_plot[seg_mask]["sample_index"]) & (
166-
peaks_to_plot[seg_mask]["sample_index"] <= t_stop_sample
165+
time_mask[seg_mask] = (t_start_sample < peaks_to_plot[seg_mask]["sample_index"]) & (
166+
peaks_to_plot[seg_mask]["sample_index"] < t_stop_sample
167167
)
168168

169+
return time_mask
170+
169171
def _get_min_and_max_times_in_recording(self, recording):
170172
"""
171173
Find the maximum and minimum time across all segments in the recording.
@@ -180,7 +182,7 @@ def _get_min_and_max_times_in_recording(self, recording):
180182

181183
t_starts.append(segment.sample_index_to_time(0))
182184

183-
t_stops.append(segment.sample_index_to_time(segment.get_num_samples()))
185+
t_stops.append(segment.sample_index_to_time(segment.get_num_samples() - 1))
184186

185187
time_range = (np.min(t_starts), np.max(t_stops))
186188

src/spikeinterface/widgets/tests/test_peaks_on_probe.py

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def peak_info(self):
1414
Fixture (created only once per test run) of a small
1515
ground truth recording with peaks and peak locations calculated.
1616
"""
17-
recording, sorting = generate_ground_truth_recording(num_units=5, num_channels=16, durations=[20, 9], seed=0)
17+
recording, _ = generate_ground_truth_recording(num_units=5, num_channels=16, durations=[20, 9], seed=0)
1818
peaks = detect_peaks(recording)
1919

2020
peak_locations = localize_peaks(
@@ -25,7 +25,7 @@ def peak_info(self):
2525
method="center_of_mass",
2626
)
2727

28-
return (recording, sorting, peaks, peak_locations)
28+
return (recording, peaks, peak_locations)
2929

3030
def data_from_widget(self, widget, axes_idx):
3131
"""
@@ -40,7 +40,7 @@ def test_peaks_on_probe_main(self, peak_info):
4040
Plot all peaks, and check every peak is plot.
4141
Check the labels are corect.
4242
"""
43-
recording, sorting, peaks, peak_locations = peak_info
43+
recording, peaks, peak_locations = peak_info
4444

4545
widget = plot_peaks_on_probe(recording, peaks, peak_locations, decimate=1)
4646

@@ -58,7 +58,7 @@ def test_segment_selection(self, peak_info, segment_index):
5858
from a sepecific segment, that only peaks
5959
from that segment are plot.
6060
"""
61-
recording, sorting, peaks, peak_locations = peak_info
61+
recording, peaks, peak_locations = peak_info
6262

6363
widget = plot_peaks_on_probe(
6464
recording,
@@ -82,7 +82,7 @@ def test_multiple_inputs(self, peak_info):
8282
Check that these separate peaks / peak locations
8383
are plot on different axes.
8484
"""
85-
recording, sorting, peaks, peak_locations = peak_info
85+
recording, peaks, peak_locations = peak_info
8686

8787
half_num_peaks = int(peaks.shape[0] / 2)
8888

@@ -107,16 +107,14 @@ def test_multiple_inputs(self, peak_info):
107107

108108
assert np.array_equal(np.sort(locs_change["y"]), np.sort(ax_1_y_data))
109109

110-
def test_times(self, peak_info):
110+
def test_times_all(self, peak_info):
111111
"""
112112
Check that when the times of peaks to plot is restricted,
113113
only peaks within the given time range are plot. Set the
114114
limits just before and after the second peak, and check only
115115
that peak is plot.
116116
"""
117-
recording, sorting, peaks, peak_locations = peak_info
118-
119-
peak_times_ms = recording.sample_index_to_time(peaks["sample_index"], segment_index=0) * 1000
117+
recording, peaks, peak_locations = peak_info
120118

121119
peak_idx = 1
122120
peak_cutoff_low = peaks["sample_index"][peak_idx] - 1
@@ -137,13 +135,104 @@ def test_times(self, peak_info):
137135

138136
assert np.array_equal([peak_locations[peak_idx]["y"]], ax_y_data)
139137

138+
def test_times_per_segment(self, peak_info):
139+
"""
140+
Test that the time bounds for multi-segment recordings
141+
with different times are handled properly. The time bounds
142+
given must respect the times for each segment. Here, we build
143+
two segments with times 0-100s and 100-200s. We set the
144+
time limits for peaks to plot as 50-150 i.e. all peaks
145+
from the second half of the first segment, and the first half
146+
of the second segment, should be plotted.
147+
148+
Recompute peaks here for completeness even though this does
149+
duplicate the fixture.
150+
"""
151+
recording, _, _ = peak_info
152+
153+
first_seg_times = np.linspace(0, 100, recording.get_num_samples(0))
154+
second_seg_times = np.linspace(100, 200, recording.get_num_samples(1))
155+
156+
recording.set_times(first_seg_times, segment_index=0)
157+
recording.set_times(second_seg_times, segment_index=1)
158+
159+
# After setting the peak times above, re-detect peaks and plot
160+
# with a time range 50-150 s
161+
peaks = detect_peaks(recording)
162+
163+
peak_locations = localize_peaks(
164+
recording,
165+
peaks,
166+
ms_before=0.3,
167+
ms_after=0.6,
168+
method="center_of_mass",
169+
)
170+
171+
widget = plot_peaks_on_probe(
172+
recording,
173+
peaks,
174+
peak_locations,
175+
decimate=1,
176+
time_range=(
177+
50,
178+
150,
179+
),
180+
)
181+
182+
# Find the peaks that are expected to be plot given the time
183+
# restriction (second half of first segment, first half of
184+
# second segment) and check that indeed the expected locations
185+
# are displayed.
186+
seg_one_num_samples = recording.get_num_samples(0)
187+
seg_two_num_samples = recording.get_num_samples(1)
188+
189+
okay_peaks_one = np.logical_and(
190+
peaks["segment_index"] == 0, peaks["sample_index"] > int(seg_one_num_samples / 2)
191+
)
192+
okay_peaks_two = np.logical_and(
193+
peaks["segment_index"] == 1, peaks["sample_index"] < int(seg_two_num_samples / 2)
194+
)
195+
okay_peaks = np.logical_or(okay_peaks_one, okay_peaks_two)
196+
197+
ax_y_data = self.data_from_widget(widget, 0)[:, 1]
198+
199+
assert any(okay_peaks), "someting went wrong in test generation, no peaks within the set time bounds detected"
200+
201+
assert np.array_equal(np.sort(ax_y_data), np.sort(peak_locations[okay_peaks]["y"]))
202+
203+
def test_get_min_and_max_times_in_recording(self, peak_info):
204+
"""
205+
Check that the function which finds the minimum and maximum times
206+
across all segments in the recording returns correctly. First
207+
set times of the segments such that the earliest time is 50s and
208+
latest 200s. Check the function returns (50, 200).
209+
"""
210+
recording, peaks, peak_locations = peak_info
211+
212+
first_seg_times = np.linspace(50, 100, recording.get_num_samples(0))
213+
second_seg_times = np.linspace(100, 200, recording.get_num_samples(1))
214+
215+
recording.set_times(first_seg_times, segment_index=0)
216+
recording.set_times(second_seg_times, segment_index=1)
217+
218+
widget = plot_peaks_on_probe(
219+
recording,
220+
peaks,
221+
peak_locations,
222+
decimate=1,
223+
)
224+
225+
min_max_times = widget._get_min_and_max_times_in_recording(recording)
226+
227+
assert min_max_times == (50, 200)
228+
140229
def test_ylim(self, peak_info):
141230
"""
142231
Specify some y-axis limits (which is the probe height
143232
to show) and check that the plot is restricted to
144233
these limits.
145234
"""
146-
recording, sorting, peaks, peak_locations = peak_info
235+
recording, peaks, peak_locations = peak_info
147236

148237
widget = plot_peaks_on_probe(
149238
recording,
@@ -163,7 +252,7 @@ def test_decimate(self, peak_info):
163252
checks the decimate argument, to ensure peaks that are
164253
plot are correctly decimated.
165254
"""
166-
recording, sorting, peaks, peak_locations = peak_info
255+
recording, peaks, peak_locations = peak_info
167256

168257
decimate = 5
169258

@@ -184,7 +273,7 @@ def test_errors(self, peak_info):
184273
Test all validation errors are raised when data in
185274
incorrect form is passed to the plotting function.
186275
"""
187-
recording, sorting, peaks, peak_locations = peak_info
276+
recording, peaks, peak_locations = peak_info
188277

189278
# All lists must be same length
190279
with pytest.raises(ValueError) as e:

0 commit comments

Comments
 (0)