|
| 1 | +import numpy as np |
| 2 | + |
| 3 | +from scipy.stats import pearsonr, spearmanr |
| 4 | + |
| 5 | +from sklearn.model_selection import KFold, ShuffleSplit |
| 6 | +from sklearn.datasets import load_breast_cancer, load_diabetes |
| 7 | + |
| 8 | +from photonai import Hyperpipe, PipelineElement |
| 9 | +from photonai.helper.photon_base_test import PhotonBaseTest |
| 10 | + |
| 11 | +from photonai.modelwrapper.cpm_feature_selection import CPMFeatureSelection |
| 12 | + |
| 13 | + |
| 14 | +class CPMFeatureSelectionTest(PhotonBaseTest): |
| 15 | + |
| 16 | + @classmethod |
| 17 | + def setUpClass(cls) -> None: |
| 18 | + cls.file = __file__ |
| 19 | + super(CPMFeatureSelectionTest, cls).setUpClass() |
| 20 | + |
| 21 | + def setUp(self): |
| 22 | + super(CPMFeatureSelectionTest, self).setUp() |
| 23 | + self.X_classif, self.y_classif = load_breast_cancer(return_X_y=True) |
| 24 | + self.X_regr, self.y_regr = load_diabetes(return_X_y=True) |
| 25 | + self.pipe_classif = Hyperpipe("cpm_feature_selection_pipe_classif", |
| 26 | + outer_cv=ShuffleSplit(test_size=0.2, n_splits=1, random_state=15), |
| 27 | + inner_cv= KFold(n_splits=3, shuffle=True, random_state=15), |
| 28 | + metrics=["accuracy"], best_config_metric="accuracy", |
| 29 | + project_folder=self.tmp_folder_path) |
| 30 | + self.pipe_regr = Hyperpipe("cpm_feature_selection_pipe_regr", |
| 31 | + outer_cv=ShuffleSplit(test_size=0.2, n_splits=1, random_state=15), |
| 32 | + inner_cv= KFold(n_splits=3, shuffle=True, random_state=15), |
| 33 | + metrics=["mean_absolute_error"], best_config_metric="mean_absolute_error", |
| 34 | + project_folder=self.tmp_folder_path) |
| 35 | + |
| 36 | + def test_cpm_regression(self): |
| 37 | + self.pipe_regr += PipelineElement('CPMFeatureSelection', hyperparameters={}) |
| 38 | + self.pipe_regr += PipelineElement('LinearRegression') |
| 39 | + self.pipe_regr.fit(self.X_regr, self.y_regr) |
| 40 | + |
| 41 | + def test_cpm_classification(self): |
| 42 | + self.pipe_classif += PipelineElement('CPMFeatureSelection', |
| 43 | + hyperparameters={'corr_method': ['pearson', 'spearman']}) |
| 44 | + self.pipe_classif += PipelineElement('LogisticRegression') |
| 45 | + self.pipe_classif.fit(self.X_classif, self.y_classif) |
| 46 | + |
| 47 | + def test_columnwise_correlation(self): |
| 48 | + for cpm_corr_method, scipy_corr_method in [(CPMFeatureSelection._columnwise_pearson, pearsonr), |
| 49 | + (CPMFeatureSelection._columnwise_spearman, spearmanr)]: |
| 50 | + r_values, p_values = cpm_corr_method(self.X_classif, self.y_classif) |
| 51 | + r_scipy_first = scipy_corr_method(self.X_classif[:, 0], self.y_classif) |
| 52 | + r_scipy_last = scipy_corr_method(self.X_classif[:, -1], self.y_classif) |
| 53 | + self.assertAlmostEqual(r_values[0], r_scipy_first.statistic) |
| 54 | + self.assertAlmostEqual(p_values[0], r_scipy_first.pvalue) |
| 55 | + self.assertAlmostEqual(r_values[-1], r_scipy_last.statistic) |
| 56 | + self.assertAlmostEqual(p_values[-1], r_scipy_last.pvalue) |
| 57 | + |
| 58 | + def test_cpm_inverse(self): |
| 59 | + cpm = PipelineElement('CPMFeatureSelection', |
| 60 | + hyperparameters={'corr_method': ['pearson']}) |
| 61 | + |
| 62 | + cpm.fit(self.X_classif, self.y_classif) |
| 63 | + X_transformed, _, _ = cpm.transform(self.X_classif) |
| 64 | + X_back, _, _ = cpm.inverse_transform(np.asarray([3, -3])) |
| 65 | + self.assertEqual(X_transformed.shape[1], 2) |
| 66 | + self.assertEqual(self.X_classif.shape[1], X_back.shape[1]) |
| 67 | + self.assertEqual(np.min(X_back), -3) |
| 68 | + self.assertEqual(np.max(X_back), 3) |
| 69 | + |
| 70 | + with self.assertRaises(ValueError): |
| 71 | + cpm.inverse_transform(X_transformed) |
| 72 | + |
| 73 | + with self.assertRaises(ValueError): |
| 74 | + cpm.inverse_transform(X_transformed.T) |
| 75 | + |
| 76 | + def test_wrong_corr_method(self): |
| 77 | + with self.assertRaises(NotImplementedError): |
| 78 | + PipelineElement('CPMFeatureSelection', corr_method='Pearsons') |
| 79 | + |
| 80 | + def test_cpm_transform(self): |
| 81 | + element = PipelineElement('CPMFeatureSelection', hyperparameters={}) |
| 82 | + element.fit(self.X_classif, self.y_classif) |
| 83 | + X_transformed, _, _ = element.transform(self.X_classif) |
| 84 | + self.assertEqual(X_transformed.shape[0], self.X_classif.shape[0]) |
| 85 | + self.assertEqual(X_transformed.shape[1], 2) |
0 commit comments