Skip to content

Commit 641ee88

Browse files
committed
PEP extraction algorithms: updated tests to work with new implementation
1 parent a4214a4 commit 641ee88

11 files changed

+32
-54
lines changed

src/biopsykit/signals/icg/event_extraction/_b_point_forouzanfar2018.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _get_most_prominent_monotonic_increasing_segment(icg_segment: pd.Series, hei
241241
].index
242242

243243
end_index_drop_rule_b = end_index_drop_rule_b.union(end_index_drop_rule_b - 1)
244-
monotony_df = monotony_df.drop(index=end_index_drop_rule_b)
244+
monotony_df = monotony_df.drop(index=monotony_df.iloc[end_index_drop_rule_b].index)
245245

246246
# Select the monotonic segment with the highest amplitude difference
247247
start_sample = 0

tests/test_algorithms/test_b_point_extraction_arbol2017.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_regression_extract_series(self):
8080
@staticmethod
8181
def _get_regression_reference():
8282
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_b_point_reference_arbol2017.csv"), index_col=0)
83-
data = data.convert_dtypes(infer_objects=True)
83+
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
8484
return data
8585

8686
@staticmethod

tests/test_algorithms/test_b_point_extraction_debski1993.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_regression_extract_series(self):
8080
@staticmethod
8181
def _get_regression_reference():
8282
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_b_point_reference_debski1993.csv"), index_col=0)
83-
data = data.convert_dtypes(infer_objects=True)
83+
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
8484
return data
8585

8686
@staticmethod

tests/test_algorithms/test_b_point_extraction_drost2022.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_regression_extract_series(self):
8080
@staticmethod
8181
def _get_regression_reference():
8282
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_b_point_reference_drost2022.csv"), index_col=0)
83-
data = data.convert_dtypes(infer_objects=True)
83+
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
8484
return data
8585

8686
@staticmethod

tests/test_algorithms/test_b_point_extraction_forouzanfar2018.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_regression_extract_series(self):
8080
@staticmethod
8181
def _get_regression_reference():
8282
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_b_point_reference_forouzanfar2018.csv"), index_col=0)
83-
data = data.convert_dtypes(infer_objects=True)
83+
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
8484
return data
8585

8686
@staticmethod

tests/test_algorithms/test_c_point_extraction_scipy_findpeaks.py

+2-19
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_regression_extract_series(self):
6666
@staticmethod
6767
def _get_regression_reference():
6868
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_c_point_reference_scipy_findpeaks.csv"), index_col=0)
69-
data = data.convert_dtypes(infer_objects=True)
69+
data = data.astype({"c_point_sample": "Int64", "nan_reason": "object"})
7070
return data
7171

7272
@staticmethod
@@ -78,7 +78,6 @@ class TestCPointExtractionSciPyFindpeaksParameters:
7878
def setup(
7979
self,
8080
window_c_correction: Optional[int] = 3,
81-
save_candidates: Optional[bool] = False,
8281
):
8382
# Sample ECG data
8483
self.ecg_data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_ecg.csv"), index_col=0)
@@ -88,9 +87,7 @@ def setup(
8887
self.heartbeats = self.segmenter.extract(
8988
ecg=self.ecg_data, sampling_rate_hz=self.sampling_rate_hz
9089
).heartbeat_list_
91-
self.extract_algo = CPointExtractionScipyFindPeaks(
92-
window_c_correction=window_c_correction, save_candidates=save_candidates
93-
)
90+
self.extract_algo = CPointExtractionScipyFindPeaks(window_c_correction=window_c_correction)
9491
self.test_case = unittest.TestCase()
9592

9693
@pytest.mark.parametrize(
@@ -102,20 +99,6 @@ def test_extract_window_c_correction(self, window_c_correction):
10299

103100
self.extract_algo.extract(icg=self.icg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
104101

105-
print(self.extract_algo.points_)
106-
107102
assert isinstance(self.extract_algo.points_, pd.DataFrame)
108103
assert "c_point_sample" in self.extract_algo.points_.columns
109104
assert "nan_reason" in self.extract_algo.points_.columns
110-
111-
@pytest.mark.parametrize(
112-
("save_candidates", "expected_columns"),
113-
[(True, ["c_point_sample", "nan_reason", "c_point_candidates"]), (False, ["c_point_sample", "nan_reason"])],
114-
)
115-
def test_extract_window_(self, save_candidates, expected_columns):
116-
self.setup(save_candidates=save_candidates)
117-
118-
self.extract_algo.extract(icg=self.icg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
119-
120-
assert isinstance(self.extract_algo.points_, pd.DataFrame)
121-
self.test_case.assertListEqual(expected_columns, self.extract_algo.points_.columns.tolist())

tests/test_algorithms/test_heatbeat_segmentation.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,6 @@ def test_regression_extract_variable_length_dataframe(self):
6262
_assert_is_dtype(ecg_data, pd.DataFrame)
6363

6464
self.segmenter.extract(ecg=ecg_data, sampling_rate_hz=self.sampling_rate_hz)
65-
66-
# print(self.segmenter.heartbeat_list_["start_time"].dtype)
67-
# print(reference_heartbeats["start_time"].dtype)
68-
6965
# check if the extraction is equal
7066
self._check_heartbeats_equal(reference_heartbeats, self.segmenter.heartbeat_list_)
7167

@@ -136,7 +132,7 @@ def _check_heartbeats_equal(reference_heartbeats, extracted_heartbeats):
136132
("data", "expected"),
137133
[
138134
(None, pytest.raises(ValueError)),
139-
(pd.Series([]), pytest.raises(ValueError)),
135+
(pd.Series([], dtype="Float64"), pytest.raises(ValueError)),
140136
(pd.DataFrame(), pytest.raises(ValidationError)),
141137
],
142138
)

tests/test_algorithms/test_outlier_correction_forouzanfar2018.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _get_b_point_outlier_middle(self):
8686

8787
def _get_regression_reference(self):
8888
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_icg_outlier_correction_forouzanfar2018.csv"), index_col=0)
89-
data = data.convert_dtypes(infer_objects=True)
89+
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
9090
return data
9191

9292
@staticmethod

tests/test_algorithms/test_outlier_correction_interpolation.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@ def test_regression_correct_outlier(self, outlier_type):
6969
b_points=b_points, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
7070
)
7171

72-
print(self.outlier_algo.points_)
73-
7472
corrected_beats = (self.b_points - self.outlier_algo.points_)["b_point_sample"] != 0
7573
corrected_beats = self.b_points.index[corrected_beats]
7674

@@ -87,7 +85,7 @@ def _get_b_point_outlier_middle(self):
8785

8886
def _get_regression_reference(self):
8987
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_icg_outlier_correction_interpolation.csv"), index_col=0)
90-
data = data.convert_dtypes(infer_objects=True)
88+
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
9189
return data
9290

9391
@staticmethod

tests/test_algorithms/test_q_wave_onset_extraction.py renamed to tests/test_algorithms/test_q_peak_extraction.py

+21-20
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from contextlib import contextmanager
33
from pathlib import Path
44

5+
import numpy as np
56
import pandas as pd
67
import pytest
78

@@ -35,44 +36,44 @@ def test_extract(self):
3536
self.extract_algo.extract(ecg=self.ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
3637

3738
assert isinstance(self.extract_algo.points_, pd.DataFrame)
38-
assert "q_wave_onset_sample" in self.extract_algo.points_.columns
39+
assert "q_peak_sample" in self.extract_algo.points_.columns
3940
assert "nan_reason" in self.extract_algo.points_.columns
4041

41-
# add regression test to check if the extracted q-wave onsets match with the saved reference
42+
# add regression test to check if the extracted q-peaks match with the saved reference
4243
def test_regression_extract_dataframe(self):
4344
self.setup()
4445

4546
ecg_data = self.ecg_data
4647
_assert_is_dtype(ecg_data, pd.DataFrame)
4748

48-
reference_q_wave_onsets = self._get_regression_reference()
49+
reference_q_peaks = self._get_regression_reference()
4950
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
5051

51-
self._check_q_wave_onset_equal(reference_q_wave_onsets, self.extract_algo.points_)
52+
self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)
5253

5354
def test_regression_extract_series(self):
5455
self.setup()
5556

5657
ecg_data = self.ecg_data.squeeze()
5758
_assert_is_dtype(ecg_data, pd.Series)
5859

59-
reference_q_wave_onsets = self._get_regression_reference()
60+
reference_q_peaks = self._get_regression_reference()
6061
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
6162

62-
self._check_q_wave_onset_equal(reference_q_wave_onsets, self.extract_algo.points_)
63+
self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)
6364

6465
@staticmethod
6566
def _get_regression_reference():
66-
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_q_wave_onset_reference_neurokit_dwt.csv"), index_col=0)
67-
data = data.convert_dtypes(infer_objects=True)
67+
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_q_peak_reference_neurokit_dwt.csv"), index_col=0)
68+
data = data.astype({"q_peak_sample": "Int64", "nan_reason": "object"})
6869
return data
6970

7071
@staticmethod
71-
def _check_q_wave_onset_equal(reference_heartbeats, extracted_heartbeats):
72+
def _check_q_peaks_equal(reference_heartbeats, extracted_heartbeats):
7273
pd.testing.assert_frame_equal(reference_heartbeats, extracted_heartbeats)
7374

7475

75-
class TestQWaveOnsetExtractionVanLien2013:
76+
class TestQPeakExtractionVanLien2013:
7677
def setup(self, time_interval_ms: int = 40):
7778
# Sample ECG data
7879
self.ecg_data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_ecg.csv"), index_col=0)
@@ -98,9 +99,9 @@ def test_extract(self):
9899
self.extract_algo.extract(ecg=self.ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
99100

100101
assert isinstance(self.extract_algo.points_, pd.DataFrame)
101-
assert "q_wave_onset_sample" in self.extract_algo.points_.columns
102+
assert "q_peak_sample" in self.extract_algo.points_.columns
102103

103-
# add regression test to check if the extracted q-wave onsets match with the saved reference
104+
# add regression test to check if the extracted q-peaks match with the saved reference
104105
@pytest.mark.parametrize(
105106
("time_interval_ms"),
106107
[34, 36, 38, 40],
@@ -111,9 +112,9 @@ def test_regression_extract_dataframe(self, time_interval_ms):
111112
ecg_data = self.ecg_data
112113
_assert_is_dtype(ecg_data, pd.DataFrame)
113114

114-
reference_q_wave_onsets = self._get_regression_reference(time_interval_ms)
115+
reference_q_peaks = self._get_regression_reference(time_interval_ms)
115116
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
116-
self._check_q_wave_onset_equal(reference_q_wave_onsets, self.extract_algo.points_)
117+
self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)
117118

118119
@pytest.mark.parametrize(
119120
("time_interval_ms"),
@@ -125,20 +126,20 @@ def test_regression_extract_series(self, time_interval_ms):
125126
ecg_data = self.ecg_data.squeeze()
126127
_assert_is_dtype(ecg_data, pd.Series)
127128

128-
reference_q_wave_onsets = self._get_regression_reference(time_interval_ms)
129+
reference_q_peaks = self._get_regression_reference(time_interval_ms)
129130
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
130-
self._check_q_wave_onset_equal(reference_q_wave_onsets, self.extract_algo.points_)
131+
self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)
131132

132133
def _get_regression_reference(self, time_interval_ms: int = 40):
133134
data = pd.read_csv(
134135
TEST_FILE_PATH.joinpath("pep_test_heartbeat_reference_variable_length.csv"), index_col=0, parse_dates=True
135136
)
136-
data = data.convert_dtypes(infer_objects=True)
137137
data = data[["r_peak_sample"]] - int((time_interval_ms / self.sampling_rate_hz) * 1000)
138-
data.columns = ["q_wave_onset_sample"]
139-
138+
data = data.assign(nan_reason=np.NAN)
139+
data.columns = ["q_peak_sample", "nan_reason"]
140+
data = data.astype({"q_peak_sample": "Int64", "nan_reason": "object"})
140141
return data
141142

142143
@staticmethod
143-
def _check_q_wave_onset_equal(reference_heartbeats, extracted_heartbeats):
144+
def _check_q_peaks_equal(reference_heartbeats, extracted_heartbeats):
144145
pd.testing.assert_frame_equal(reference_heartbeats, extracted_heartbeats)

tests/test_data/pep/pep_test_q_wave_onset_reference_neurokit_dwt.csv renamed to tests/test_data/pep/pep_test_q_peak_reference_neurokit_dwt.csv

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
heartbeat_id,q_wave_onset_sample,nan_reason
1+
heartbeat_id,q_peak_sample,nan_reason
22
0,423,
33
1,1012,
44
2,1614,

0 commit comments

Comments
 (0)