Skip to content

Commit

Permalink
Started moving some testing utilities into the actual tpcp package'
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Aug 30, 2023
1 parent 6e098cd commit 83ac999
Show file tree
Hide file tree
Showing 6 changed files with 375 additions and 111 deletions.
94 changes: 0 additions & 94 deletions tests/mixins/test_algorithm_mixin.py

This file was deleted.

21 changes: 11 additions & 10 deletions tests/test_pipelines/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
from sklearn.model_selection import ParameterGrid, PredefinedSplit

from tests.mixins.test_algorithm_mixin import TestAlgorithmMixin
from tests.test_pipelines.conftest import (
DummyDataset,
DummyOptimizablePipeline,
Expand All @@ -27,26 +26,26 @@
from tpcp._utils._score import _optimize_and_score
from tpcp.exceptions import OptimizationError, PotentialUserErrorWarning, TestError
from tpcp.optimize import DummyOptimize, GridSearch, GridSearchCV, Optimize
from tpcp.testing import TestAlgorithmMixin
from tpcp.validate import Aggregator, Scorer


class TestMetaFunctionalityGridSearch(TestAlgorithmMixin):
__test__ = True
algorithm_class = GridSearch
ALGORITHM_CLASS = GridSearch
ONLY_DEFAULT_PARAMS = False

@pytest.fixture()
def after_action_instance(self) -> GridSearch:
gs = GridSearch(DummyOptimizablePipeline(), ParameterGrid({"para_1": [1]}), scoring=dummy_single_score_func)
gs.optimize(DummyDataset())
return gs

def test_empty_init(self):
pytest.skip()


class TestMetaFunctionalityGridSearchCV(TestAlgorithmMixin):
__test__ = True
algorithm_class = GridSearchCV
ALGORITHM_CLASS = GridSearchCV
ONLY_DEFAULT_PARAMS = False

@pytest.fixture()
def after_action_instance(self) -> GridSearchCV:
Expand Down Expand Up @@ -94,11 +93,12 @@ def run(self, dataset):

class TestMetaFunctionalityOptimize(TestAlgorithmMixin):
__test__ = True
algorithm_class = Optimize
ALGORITHM_CLASS = Optimize
ONLY_DEFAULT_PARAMS = False

@pytest.fixture()
def after_action_instance(self) -> Optimize:
gs = self.algorithm_class(DummyOptimizablePipelineWithInfo())
gs = self.ALGORITHM_CLASS(DummyOptimizablePipelineWithInfo())
gs.optimize(DummyDataset())
return gs

Expand All @@ -108,11 +108,12 @@ def test_empty_init(self):

class TestMetaFunctionalityDummyOptimize(TestAlgorithmMixin):
__test__ = True
algorithm_class = DummyOptimize
ALGORITHM_CLASS = DummyOptimize
ONLY_DEFAULT_PARAMS = False

@pytest.fixture()
def after_action_instance(self) -> DummyOptimize:
gs = self.algorithm_class(DummyPipeline())
gs = self.ALGORITHM_CLASS(DummyPipeline())
gs.optimize(DummyDataset())
return gs

Expand Down
13 changes: 6 additions & 7 deletions tests/test_pipelines/test_optuna_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from optuna.samplers import GridSampler, RandomSampler, TPESampler
from optuna.trial import FrozenTrial

from tests.mixins.test_algorithm_mixin import TestAlgorithmMixin
from tests.test_pipelines.conftest import (
DummyDataset,
DummyOptimizablePipeline,
Expand All @@ -20,6 +19,7 @@
from tpcp._dataset import DatasetT
from tpcp._pipeline import OptimizablePipeline, PipelineT
from tpcp.optimize.optuna import CustomOptunaOptimize, OptunaSearch, StudyParamsDict
from tpcp.testing import TestAlgorithmMixin
from tpcp.validate import Scorer


Expand Down Expand Up @@ -86,8 +86,9 @@ def _get_study_params(seed):

class TestMetaFunctionalityOptuna(TestAlgorithmMixin):
__test__ = True
algorithm_class = DummyOptunaOptimizer
_ignored_names = ("create_search_space", "scoring", "mock_objective")
ALGORITHM_CLASS = DummyOptunaOptimizer
ONLY_DEFAULT_PARAMS = False
_IGNORED_NAMES = ("create_search_space", "scoring", "mock_objective")

@pytest.fixture()
def after_action_instance(self) -> DummyOptunaOptimizer:
Expand Down Expand Up @@ -254,7 +255,8 @@ def search_space(trial):

class TestMetaFunctionalityOptunaSearch(TestAlgorithmMixin):
__test__ = True
algorithm_class = OptunaSearch
ALGORITHM_CLASS = OptunaSearch
ONLY_DEFAULT_PARAMS = False

@pytest.fixture()
def after_action_instance(self) -> OptunaSearch:
Expand All @@ -268,9 +270,6 @@ def after_action_instance(self) -> OptunaSearch:
gs.optimize(DummyDataset())
return gs

def test_empty_init(self):
pytest.skip()


class TestOptunaSearch:
def test_single_score(self):
Expand Down
4 changes: 4 additions & 0 deletions tpcp/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Helper for testing of algorithms and pipelines implemented in tpcp."""
from tpcp.testing._algorithm_test_mixin import TestAlgorithmMixin

__all__ = ["TestAlgorithmMixin"]
Loading

0 comments on commit 83ac999

Please sign in to comment.