diff --git a/aeon/clustering/_elastic_som.py b/aeon/clustering/_elastic_som.py index da88bb541c..08180c2811 100644 --- a/aeon/clustering/_elastic_som.py +++ b/aeon/clustering/_elastic_som.py @@ -6,6 +6,7 @@ from aeon.clustering.base import BaseClusterer from aeon.distances import get_alignment_path_function, pairwise_distance +from aeon.utils.tags.enum_tags import AlgorithmType VALID_ELASTIC_SOM_METRICS = [ "dtw", @@ -148,7 +149,7 @@ class ElasticSOM(BaseClusterer): _tags = { "capability:multivariate": True, - "algorithm_type": "distance", + "algorithm_type": AlgorithmType.DISTANCE.value, } def __init__( diff --git a/aeon/clustering/_k_means.py b/aeon/clustering/_k_means.py index e4e459a5cf..f2c6bb3a7c 100644 --- a/aeon/clustering/_k_means.py +++ b/aeon/clustering/_k_means.py @@ -2,6 +2,8 @@ from typing import Optional +from aeon.utils.tags.enum_tags import AlgorithmType + __maintainer__ = [] from typing import Callable, Union @@ -153,7 +155,7 @@ class TimeSeriesKMeans(BaseClusterer): _tags = { "capability:multivariate": True, - "algorithm_type": "distance", + "algorithm_type": AlgorithmType.DISTANCE.value, } def __init__( diff --git a/aeon/clustering/_k_medoids.py b/aeon/clustering/_k_medoids.py index 12d0f2819d..c6003c4251 100644 --- a/aeon/clustering/_k_medoids.py +++ b/aeon/clustering/_k_medoids.py @@ -2,6 +2,8 @@ from typing import Optional +from aeon.utils.tags.enum_tags import AlgorithmType + __maintainer__ = [] import warnings @@ -146,7 +148,7 @@ class TimeSeriesKMedoids(BaseClusterer): _tags = { "capability:multivariate": True, - "algorithm_type": "distance", + "algorithm_type": AlgorithmType.DISTANCE.value, } def __init__( diff --git a/aeon/clustering/_k_sc.py b/aeon/clustering/_k_sc.py index 1ace94b245..f243b6a9b9 100644 --- a/aeon/clustering/_k_sc.py +++ b/aeon/clustering/_k_sc.py @@ -6,6 +6,7 @@ from numpy.random import RandomState from aeon.clustering import TimeSeriesKMeans +from aeon.utils.tags.enum_tags import AlgorithmType class KSpectralCentroid(TimeSeriesKMeans): @@ -92,7 +93,7 @@ class KSpectralCentroid(TimeSeriesKMeans): _tags = { "capability:multivariate": True, - "algorithm_type": "distance", + "algorithm_type": AlgorithmType.DISTANCE.value, } def __init__( diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index aa8d8a3b64..b1b53084b1 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -6,6 +6,7 @@ from numpy.random import RandomState from aeon.clustering.base import BaseClusterer +from aeon.utils.tags.enum_tags import AlgorithmType class TimeSeriesKShape(BaseClusterer): @@ -70,7 +71,7 @@ class TimeSeriesKShape(BaseClusterer): _tags = { "capability:multivariate": True, "python_dependencies": "tslearn", - "algorithm_type": "distance", + "algorithm_type": AlgorithmType.DISTANCE.value, } def __init__( diff --git a/aeon/clustering/compose/_pipeline.py b/aeon/clustering/compose/_pipeline.py index fef3f87e0b..be76dc61a5 100644 --- a/aeon/clustering/compose/_pipeline.py +++ b/aeon/clustering/compose/_pipeline.py @@ -6,6 +6,7 @@ from aeon.base._estimators.compose.collection_pipeline import BaseCollectionPipeline from aeon.clustering import BaseClusterer +from aeon.utils.tags.enum_tags import AlgorithmType class ClustererPipeline(BaseCollectionPipeline, BaseClusterer): @@ -75,7 +76,10 @@ class ClustererPipeline(BaseCollectionPipeline, BaseClusterer): """ _tags = { - "X_inner_type": ["np-list", "numpy3D"], + "X_inner_type": [ + AlgorithmType.NP_LIST.value, + AlgorithmType.NUMPY3D.value, + ], } def __init__(self, transformers, clusterer, random_state=None): diff --git a/aeon/clustering/deep_learning/base.py b/aeon/clustering/deep_learning/base.py index 4a05e8c662..f946ed9f4d 100644 --- a/aeon/clustering/deep_learning/base.py +++ b/aeon/clustering/deep_learning/base.py @@ -8,6 +8,7 @@ from aeon.base._base import _clone_estimator from aeon.clustering._k_means import TimeSeriesKMeans from aeon.clustering.base import BaseClusterer +from aeon.utils.tags.enum_tags import AlgorithmType class BaseDeepClusterer(BaseClusterer): @@ -29,9 +30,9 @@ class BaseDeepClusterer(BaseClusterer): """ _tags = { - "X_inner_type": "numpy3D", + "X_inner_type": AlgorithmType.NUMPY3D.value, "capability:multivariate": True, - "algorithm_type": "deeplearning", + "algorithm_type": AlgorithmType.DEEPLEARNING.value, "non_deterministic": True, "cant_pickle": True, "python_dependencies": "tensorflow", diff --git a/aeon/clustering/feature_based/_catch22.py b/aeon/clustering/feature_based/_catch22.py index 33f0b79bc5..03d142fbb1 100644 --- a/aeon/clustering/feature_based/_catch22.py +++ b/aeon/clustering/feature_based/_catch22.py @@ -12,6 +12,7 @@ from aeon.base._base import _clone_estimator from aeon.clustering import BaseClusterer from aeon.transformations.collection.feature_based import Catch22 +from aeon.utils.tags.enum_tags import AlgorithmType class Catch22Clusterer(BaseClusterer): @@ -92,11 +93,14 @@ class Catch22Clusterer(BaseClusterer): """ _tags = { - "X_inner_type": ["np-list", "numpy3D"], + "X_inner_type": [ + AlgorithmType.NP_LIST.value, + AlgorithmType.NUMPY3D.value, + ], "capability:multivariate": True, "capability:unequal_length": True, "capability:multithreading": True, - "algorithm_type": "feature", + "algorithm_type": AlgorithmType.FEATURE.value, } def __init__( diff --git a/aeon/clustering/feature_based/_summary.py b/aeon/clustering/feature_based/_summary.py index 309d3ac92f..7299edccbd 100644 --- a/aeon/clustering/feature_based/_summary.py +++ b/aeon/clustering/feature_based/_summary.py @@ -12,6 +12,7 @@ from aeon.base._base import _clone_estimator from aeon.clustering import BaseClusterer from aeon.transformations.collection.feature_based import SevenNumberSummary +from aeon.utils.tags.enum_tags import AlgorithmType class SummaryClusterer(BaseClusterer): @@ -64,7 +65,7 @@ class SummaryClusterer(BaseClusterer): _tags = { "capability:multivariate": True, "capability:multithreading": True, - "algorithm_type": "feature", + "algorithm_type": AlgorithmType.FEATURE.value, } def __init__( diff --git a/aeon/clustering/feature_based/_tsfresh.py b/aeon/clustering/feature_based/_tsfresh.py index ed14e90a47..ad61efe0ce 100644 --- a/aeon/clustering/feature_based/_tsfresh.py +++ b/aeon/clustering/feature_based/_tsfresh.py @@ -15,6 +15,7 @@ from aeon.base._base import _clone_estimator from aeon.clustering import BaseClusterer from aeon.transformations.collection.feature_based import TSFresh +from aeon.utils.tags.enum_tags import AlgorithmType class TSFreshClusterer(BaseClusterer): @@ -74,7 +75,7 @@ class TSFreshClusterer(BaseClusterer): _tags = { "capability:multivariate": True, "capability:multithreading": True, - "algorithm_type": "feature", + "algorithm_type": AlgorithmType.FEATURE.value, "python_dependencies": "tsfresh", } diff --git a/aeon/utils/tags/__init__.py b/aeon/utils/tags/__init__.py index cd506c4460..17fc313f81 100644 --- a/aeon/utils/tags/__init__.py +++ b/aeon/utils/tags/__init__.py @@ -1,6 +1,7 @@ """Estimator tags and tag utilities.""" __all__ = [ + "AlgorithmType", "ESTIMATOR_TAGS", "check_valid_tags", "all_tags_for_estimator", @@ -9,3 +10,4 @@ from aeon.utils.tags._discovery import all_tags_for_estimator from aeon.utils.tags._tags import ESTIMATOR_TAGS from aeon.utils.tags._validate import check_valid_tags +from aeon.utils.tags.enum_tags import AlgorithmType diff --git a/aeon/utils/tags/enum_tags.py b/aeon/utils/tags/enum_tags.py new file mode 100644 index 0000000000..eabb943c45 --- /dev/null +++ b/aeon/utils/tags/enum_tags.py @@ -0,0 +1,31 @@ +"""Apply Enumeration in module.""" + +__all__ = ["AlgorithmType"] + +from enum import Enum + + +class AlgorithmType(Enum): + """ + An enumeration of algorithm types and data structures. + + Attributes + ---------- + Algorithm Types: + - DISTANCE: Clustering based on distance metrics + - DEEPLEARNING: Clustering using deep learning techniques + - FEATURE: Clustering driven by feature extraction + + Data Structure Types: + - NP_LIST: Numpy list-based data structure + - NUMPY3D: Three-dimensional Numpy array + """ + + # Algorithm types for clustering strategies + DISTANCE = "distance" + DEEPLEARNING = "deeplearning" + FEATURE = "feature" + + # Data structure types for clustering input + NP_LIST = "np-list" + NUMPY3D = "numpy3D"