Skip to content

Commit

Permalink
[ENH]check for algorithm_type (#2339)
Browse files Browse the repository at this point in the history
* check for algorithm_type

* removed check for missing algorithm_type

* removed check for string

* skip test for none

* removed skip
  • Loading branch information
aryanpola authored Nov 13, 2024
1 parent 8c9edfe commit 5631c01
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions aeon/testing/estimator_checking/_yield_classification_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def _yield_classification_checks(estimator_class, estimator_instances, datatypes
estimator_class=estimator_class,
)

# Algorithm_type check
yield partial(
check_algorithm_type,
estimator_class=estimator_class,
)

# data type irrelevant
if _get_tag(estimator_class, "capability:contractable", raise_error=True):
yield partial(
Expand Down Expand Up @@ -126,6 +132,30 @@ def check_classifier_against_expected_results(estimator_class):
)


def check_algorithm_type(estimator_class):
"""Test the tag algorithm_type is classifier."""
valid_algorithm_types = [
"distance",
"deeplearning",
"convolution",
"dictionary",
"interval",
"feature",
"hybrid",
"shapelet",
]
algorithm_type = estimator_class.get_class_tag("algorithm_type")

# Pass the test
if algorithm_type is None:
return

assert algorithm_type in valid_algorithm_types, (
f"Estimator {estimator_class.__name__} has an invalid 'algorithm_type' tag: "
f"'{algorithm_type}'. Valid types are {valid_algorithm_types}."
)


def check_classifier_tags_consistent(estimator_class):
"""Test the tag X_inner_type is consistent with capability:unequal_length."""
valid_types = {"np-list", "df-list", "pd-multiindex"}
Expand Down

0 comments on commit 5631c01

Please sign in to comment.