Skip to content

Commit

Permalink
[MNT] Update and consolidate general estimator checks (#2377)
Browse files Browse the repository at this point in the history
* tidy general estimator tests

* fixes

* fixes
  • Loading branch information
MatthewMiddlehurst authored Nov 22, 2024
1 parent e31eddf commit 108948d
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 242 deletions.
4 changes: 2 additions & 2 deletions aeon/testing/estimator_checking/_estimator_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ class is passed.
>>> results = check_estimator(MockClassifier())
Running specific check for MockClassifier
>>> check_estimator(MockClassifier, checks_to_run="check_clone")
{'check_clone(estimator=MockClassifier())': 'PASSED'}
>>> check_estimator(MockClassifier, checks_to_run="check_get_params")
{'check_get_params(estimator=MockClassifier())': 'PASSED'}
"""
# check if estimator has soft dependencies installed
_check_estimator_deps(estimator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,6 @@ def check_classifier_overrides_and_tags(estimator_class):
f"Override _{method} instead."
)

# axis class parameter is for internal use only
assert "axis" not in estimator_class.__dict__

# Test valid tag for X_inner_type
X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type")
if isinstance(X_inner_type, str):
Expand All @@ -172,11 +169,6 @@ def check_classifier_overrides_and_tags(estimator_class):
else: # must be a list
assert any([t in valid_unequal_types for t in X_inner_type])

# Must have at least one set to True
multi = estimator_class.get_class_tag(tag_name="capability:multivariate")
uni = estimator_class.get_class_tag(tag_name="capability:univariate")
assert multi or uni

valid_algorithm_types = [
"distance",
"deeplearning",
Expand Down
Loading

0 comments on commit 108948d

Please sign in to comment.