Skip to content

Commit

Permalink
[REF] Refactor performance_metrics into benchmarking.metrics and …
Browse files Browse the repository at this point in the history
…use a local version of `sphinx-remove-toctrees` (#2353)

* refactor metrics

* doc fixes

* suggested changes

* fix
  • Loading branch information
MatthewMiddlehurst authored Nov 15, 2024
1 parent 344c831 commit 78f025e
Show file tree
Hide file tree
Showing 27 changed files with 224 additions and 115 deletions.
1 change: 1 addition & 0 deletions aeon/benchmarking/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Performance metrics."""
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
"""Metrics for anomaly detection."""

from aeon.performance_metrics.anomaly_detection._binary import (
__all__ = [
"range_precision",
"range_recall",
"range_f_score",
"roc_auc_score",
"pr_auc_score",
"rp_rr_auc_score",
"f_score_at_k_points",
"f_score_at_k_ranges",
"range_pr_roc_auc_support",
"range_roc_auc_score",
"range_pr_auc_score",
"range_pr_vus_score",
"range_roc_vus_score",
]

from aeon.benchmarking.metrics.anomaly_detection._binary import (
range_f_score,
range_precision,
range_recall,
)
from aeon.performance_metrics.anomaly_detection._continuous import (
from aeon.benchmarking.metrics.anomaly_detection._continuous import (
f_score_at_k_points,
f_score_at_k_ranges,
pr_auc_score,
roc_auc_score,
rp_rr_auc_score,
)
from aeon.performance_metrics.anomaly_detection._vus_metrics import (
from aeon.benchmarking.metrics.anomaly_detection._vus_metrics import (
range_pr_auc_score,
range_pr_roc_auc_support,
range_pr_vus_score,
range_roc_auc_score,
range_roc_vus_score,
)

__all__ = [
"range_precision",
"range_recall",
"range_f_score",
"roc_auc_score",
"pr_auc_score",
"rp_rr_auc_score",
"f_score_at_k_points",
"f_score_at_k_ranges",
"range_pr_roc_auc_support",
"range_roc_auc_score",
"range_pr_auc_score",
"range_pr_vus_score",
"range_roc_vus_score",
]
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from aeon.performance_metrics.anomaly_detection._util import check_y
from aeon.benchmarking.metrics.anomaly_detection._util import check_y
from aeon.utils.validation._dependencies import _check_soft_dependencies


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
"rp_rr_auc_score",
]


import warnings

import numpy as np
from sklearn.metrics import auc, f1_score, precision_recall_curve
from sklearn.metrics import roc_auc_score as _roc_auc_score

from aeon.performance_metrics.anomaly_detection._util import check_y
from aeon.performance_metrics.anomaly_detection.thresholding import (
from aeon.benchmarking.metrics.anomaly_detection._util import check_y
from aeon.benchmarking.metrics.anomaly_detection.thresholding import (
top_k_points_threshold,
top_k_ranges_threshold,
)
Expand Down Expand Up @@ -116,7 +117,7 @@ def f_score_at_k_points(
See Also
--------
aeon.performance_metrics.anomaly_detection.thresholding.top_k_points_threshold
aeon.benchmarking.metrics.anomaly_detection.thresholding.top_k_points_threshold
Function used to find the threshold.
"""
y_true, y_pred = check_y(y_true, y_score, force_y_pred_continuous=True)
Expand Down Expand Up @@ -163,7 +164,7 @@ def f_score_at_k_ranges(
See Also
--------
aeon.performance_metrics.anomaly_detection.thresholding.top_k_ranges_threshold
aeon.benchmarking.metrics.anomaly_detection.thresholding.top_k_ranges_threshold
Function used to find the threshold.
"""
_check_soft_dependencies(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
"range_pr_roc_auc_support",
]


import warnings

import numpy as np

from aeon.performance_metrics.anomaly_detection._util import check_y
from aeon.benchmarking.metrics.anomaly_detection._util import check_y


def _anomaly_bounds(y_true: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from aeon.performance_metrics.anomaly_detection import (
from aeon.benchmarking.metrics.anomaly_detection import (
f_score_at_k_points,
f_score_at_k_ranges,
pr_auc_score,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from aeon.performance_metrics.anomaly_detection.thresholding import (
from aeon.benchmarking.metrics.anomaly_detection.thresholding import (
percentile_threshold,
sigma_threshold,
top_k_points_threshold,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"top_k_ranges_threshold",
]


import numpy as np


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Clustering performance metric functions."""

__maintainer__ = []

__maintainer__ = ["MatthewMiddlehurst", "chrisholder"]
__all__ = ["clustering_accuracy_score"]


Expand Down Expand Up @@ -30,7 +29,7 @@ def clustering_accuracy_score(y_true, y_pred):
Examples
--------
>>> from aeon.performance_metrics.clustering import clustering_accuracy_score
>>> from aeon.benchmarking.metrics.clustering import clustering_accuracy_score
>>> clustering_accuracy_score([0, 0, 1, 1], [1, 1, 0, 0]) # doctest: +SKIP
1.0
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sklearn.utils import check_array

__maintainer__ = []
__all__ = ["count_error", "hausdorff_error", "prediction_ratio"]


def count_error(
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Tests for performance metric functions."""

__maintainer__ = []

import numpy as np

from aeon.performance_metrics.clustering import clustering_accuracy_score
from aeon.benchmarking.metrics.clustering import clustering_accuracy_score


def test_clustering_accuracy():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from aeon.performance_metrics.segmentation.metrics import (
from aeon.benchmarking.metrics.segmentation import (
count_error,
hausdorff_error,
prediction_ratio,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Functions to compute stats and get p-values."""

__maintainer__ = []

__all__ = ["check_friedman", "nemenyi_test", "wilcoxon_test"]

import warnings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import aeon
from aeon.benchmarking.results_loaders import get_estimator_results_as_array
from aeon.benchmarking.stats import check_friedman, nemenyi_test, wilcoxon_test
from aeon.datasets.tsc_datasets import univariate_equal_length
from aeon.performance_metrics.stats import check_friedman, nemenyi_test, wilcoxon_test

data_path = os.path.join(
os.path.dirname(aeon.__file__),
Expand Down
1 change: 0 additions & 1 deletion aeon/performance_metrics/__init__.py

This file was deleted.

7 changes: 0 additions & 7 deletions aeon/performance_metrics/segmentation/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion aeon/visualisation/results/_critical_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from scipy.stats import rankdata

from aeon.performance_metrics.stats import check_friedman, nemenyi_test, wilcoxon_test
from aeon.benchmarking.stats import check_friedman, nemenyi_test, wilcoxon_test
from aeon.utils.validation._dependencies import _check_soft_dependencies


Expand Down
2 changes: 1 addition & 1 deletion aeon/visualisation/results/_significance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from scipy.stats import rankdata

from aeon.performance_metrics.stats import check_friedman, nemenyi_test, wilcoxon_test
from aeon.benchmarking.stats import check_friedman, nemenyi_test, wilcoxon_test
from aeon.utils.validation._dependencies import _check_soft_dependencies


Expand Down
56 changes: 56 additions & 0 deletions docs/_sphinxext/sphinx_remove_toctrees.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""A small sphinx extension to remove toctrees.
Original extension:
https://github.com/executablebooks/sphinx-remove-toctrees
This file was adapted by the developers of the MNE-LSL project, this is just
a copy for use in the aeon documentation.
https://github.com/mne-tools/mne-lsl
https://github.com/mne-tools/mne-lsl/blob/main/doc/_sphinxext/sphinx_remove_toctrees.py
"""

from pathlib import Path

from sphinx import addnodes


def remove_toctrees(app, env):
    """Remove toctrees from pages a user provides.

    This runs at the end of the build process (on the ``env-updated`` event),
    so even though the toctrees are removed, sphinx will not raise warnings
    about unreferenced pages.

    Parameters
    ----------
    app : sphinx.application.Sphinx
        The sphinx application; ``app.config.remove_from_toctrees`` holds the
        glob pattern(s) to remove (a string or list of strings).
    env : sphinx.environment.BuildEnvironment
        The build environment whose ``tocs`` are pruned in place.
    """
    patterns = app.config.remove_from_toctrees
    if isinstance(patterns, str):
        patterns = [patterns]

    # inputs should either be a glob pattern or a direct path so just use glob;
    # srcdir is loop-invariant, so resolve it once, and collect matches into a
    # set for O(1) membership tests below.
    srcdir = Path(env.srcdir)
    to_remove = {
        str(matched.relative_to(srcdir).with_suffix("").as_posix())
        for pattern in patterns
        for matched in srcdir.glob(pattern)
    }

    # loop through all tocs and remove the entries that match our patterns
    for tocs in env.tocs.values():
        for toctree in tocs.traverse(addnodes.toctree):
            new_entries = [
                entry
                for entry in toctree.attributes.get("entries", [])
                if entry[1] not in to_remove
            ]
            if new_entries:
                toctree.attributes["entries"] = new_entries
            else:
                # if there are no more entries just remove the whole toctree
                toctree.parent.remove(toctree)


def setup(app):  # noqa: D103
    # Sphinx extension entry point: declare the ``remove_from_toctrees``
    # config value and prune toctrees once the environment is updated.
    app.add_config_value("remove_from_toctrees", [], "html")
    app.connect("env-updated", remove_toctrees)
    # Safe for parallel reads and writes — the hook only mutates env.tocs.
    metadata = {"parallel_read_safe": True, "parallel_write_safe": True}
    return metadata
1 change: 0 additions & 1 deletion docs/api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ api_reference/data_format
api_reference/datasets
api_reference/distances
api_reference/networks
api_reference/performance_metrics
api_reference/regression
api_reference/segmentation
api_reference/similarity_search
Expand Down
Loading

0 comments on commit 78f025e

Please sign in to comment.