@@ -14,7 +14,7 @@ def peak_info(self):
1414 Fixture (created only once per test run) of a small
1515 ground truth recording with peaks and peak locations calculated.
1616 """
17- recording , sorting = generate_ground_truth_recording (num_units = 5 , num_channels = 16 , durations = [20 , 9 ], seed = 0 )
17+ recording , _ = generate_ground_truth_recording (num_units = 5 , num_channels = 16 , durations = [20 , 9 ], seed = 0 )
1818 peaks = detect_peaks (recording )
1919
2020 peak_locations = localize_peaks (
@@ -25,7 +25,7 @@ def peak_info(self):
2525 method = "center_of_mass" ,
2626 )
2727
28- return (recording , sorting , peaks , peak_locations )
28+ return (recording , peaks , peak_locations )
2929
3030 def data_from_widget (self , widget , axes_idx ):
3131 """
@@ -40,7 +40,7 @@ def test_peaks_on_probe_main(self, peak_info):
4040 Plot all peaks, and check every peak is plot.
4141 Check the labels are corect.
4242 """
43- recording , sorting , peaks , peak_locations = peak_info
43+ recording , peaks , peak_locations = peak_info
4444
4545 widget = plot_peaks_on_probe (recording , peaks , peak_locations , decimate = 1 )
4646
@@ -58,7 +58,7 @@ def test_segment_selection(self, peak_info, segment_index):
5858 from a sepecific segment, that only peaks
5959 from that segment are plot.
6060 """
61- recording , sorting , peaks , peak_locations = peak_info
61+ recording , peaks , peak_locations = peak_info
6262
6363 widget = plot_peaks_on_probe (
6464 recording ,
@@ -82,7 +82,7 @@ def test_multiple_inputs(self, peak_info):
8282 Check that these separate peaks / peak locations
8383 are plot on different axes.
8484 """
85- recording , sorting , peaks , peak_locations = peak_info
85+ recording , peaks , peak_locations = peak_info
8686
8787 half_num_peaks = int (peaks .shape [0 ] / 2 )
8888
@@ -107,16 +107,14 @@ def test_multiple_inputs(self, peak_info):
107107
108108 assert np .array_equal (np .sort (locs_change ["y" ]), np .sort (ax_1_y_data ))
109109
110- def test_times (self , peak_info ):
110+ def test_times_all (self , peak_info ):
111111 """
112112 Check that when the times of peaks to plot is restricted,
113113 only peaks within the given time range are plot. Set the
114114 limits just before and after the second peak, and check only
115115 that peak is plot.
116116 """
117- recording , sorting , peaks , peak_locations = peak_info
118-
119- peak_times_ms = recording .sample_index_to_time (peaks ["sample_index" ], segment_index = 0 ) * 1000
117+ recording , peaks , peak_locations = peak_info
120118
121119 peak_idx = 1
122120 peak_cutoff_low = peaks ["sample_index" ][peak_idx ] - 1
@@ -137,13 +135,104 @@ def test_times(self, peak_info):
137135
138136 assert np .array_equal ([peak_locations [peak_idx ]["y" ]], ax_y_data )
139137
138+ def test_times_per_segment (self , peak_info ):
139+ """
140+ Test that the time bounds for multi-segment recordings
141+ with different times are handled properly. The time bounds
142+ given must respect the times for each segment. Here, we build
143+ two segments with times 0-100s and 100-200s. We set the
144+ time limits for peaks to plot as 50-150 i.e. all peaks
145+ from the second half of the first segment, and the first half
146+ of the second segment, should be plotted.
147+
148+ Recompute peaks here for completeness even though this does
149+ duplicate the fixture.
150+ """
151+ recording , _ , _ = peak_info
152+
153+ first_seg_times = np .linspace (0 , 100 , recording .get_num_samples (0 ))
154+ second_seg_times = np .linspace (100 , 200 , recording .get_num_samples (1 ))
155+
156+ recording .set_times (first_seg_times , segment_index = 0 )
157+ recording .set_times (second_seg_times , segment_index = 1 )
158+
159+ # After setting the peak times above, re-detect peaks and plot
160+ # with a time range 50-150 s
161+ peaks = detect_peaks (recording )
162+
163+ peak_locations = localize_peaks (
164+ recording ,
165+ peaks ,
166+ ms_before = 0.3 ,
167+ ms_after = 0.6 ,
168+ method = "center_of_mass" ,
169+ )
170+
171+ widget = plot_peaks_on_probe (
172+ recording ,
173+ peaks ,
174+ peak_locations ,
175+ decimate = 1 ,
176+ time_range = (
177+ 50 ,
178+ 150 ,
179+ ),
180+ )
181+
182+ # Find the peaks that are expected to be plot given the time
183+ # restriction (second half of first segment, first half of
184+ # second segment) and check that indeed the expected locations
185+ # are displayed.
186+ seg_one_num_samples = recording .get_num_samples (0 )
187+ seg_two_num_samples = recording .get_num_samples (1 )
188+
189+ okay_peaks_one = np .logical_and (
190+ peaks ["segment_index" ] == 0 , peaks ["sample_index" ] > int (seg_one_num_samples / 2 )
191+ )
192+ okay_peaks_two = np .logical_and (
193+ peaks ["segment_index" ] == 1 , peaks ["sample_index" ] < int (seg_two_num_samples / 2 )
194+ )
195+ okay_peaks = np .logical_or (okay_peaks_one , okay_peaks_two )
196+
197+ ax_y_data = self .data_from_widget (widget , 0 )[:, 1 ]
198+
199+ assert any (okay_peaks ), "someting went wrong in test generation, no peaks within the set time bounds detected"
200+
201+ assert np .array_equal (np .sort (ax_y_data ), np .sort (peak_locations [okay_peaks ]["y" ]))
202+
203+ def test_get_min_and_max_times_in_recording (self , peak_info ):
204+ """
205+ Check that the function which finds the minimum and maximum times
206+ across all segments in the recording returns correctly. First
207+ set times of the segments such that the earliest time is 50s and
208+ latest 200s. Check the function returns (50, 200).
209+ """
210+ recording , peaks , peak_locations = peak_info
211+
212+ first_seg_times = np .linspace (50 , 100 , recording .get_num_samples (0 ))
213+ second_seg_times = np .linspace (100 , 200 , recording .get_num_samples (1 ))
214+
215+ recording .set_times (first_seg_times , segment_index = 0 )
216+ recording .set_times (second_seg_times , segment_index = 1 )
217+
218+ widget = plot_peaks_on_probe (
219+ recording ,
220+ peaks ,
221+ peak_locations ,
222+ decimate = 1 ,
223+ )
224+
225+ min_max_times = widget ._get_min_and_max_times_in_recording (recording )
226+
227+ assert min_max_times == (50 , 200 )
228+
140229 def test_ylim (self , peak_info ):
141230 """
142231 Specify some y-axis limits (which is the probe height
143232 to show) and check that the plot is restricted to
144233 these limits.
145234 """
146- recording , sorting , peaks , peak_locations = peak_info
235+ recording , peaks , peak_locations = peak_info
147236
148237 widget = plot_peaks_on_probe (
149238 recording ,
@@ -163,7 +252,7 @@ def test_decimate(self, peak_info):
163252 checks the decimate argument, to ensure peaks that are
164253 plot are correctly decimated.
165254 """
166- recording , sorting , peaks , peak_locations = peak_info
255+ recording , peaks , peak_locations = peak_info
167256
168257 decimate = 5
169258
@@ -184,7 +273,7 @@ def test_errors(self, peak_info):
184273 Test all validation errors are raised when data in
185274 incorrect form is passed to the plotting function.
186275 """
187- recording , sorting , peaks , peak_locations = peak_info
276+ recording , peaks , peak_locations = peak_info
188277
189278 # All lists must be same length
190279 with pytest .raises (ValueError ) as e :
0 commit comments