Skip to content

Commit

Permalink
Updated test_idk.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Ramana-Raja authored Jan 5, 2025
1 parent 1c4262b commit f91793b
Showing 1 changed file with 5 additions and 14 deletions.
19 changes: 5 additions & 14 deletions aeon/anomaly_detection/tests/test_idk.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,23 @@
"""Tests for the IDK Class."""

import numpy as np
import pytest
from sklearn.utils import check_random_state

from aeon.anomaly_detection import IDK
from aeon.utils.validation._dependencies import _check_estimator_deps


@pytest.mark.skipif(
not _check_estimator_deps(
IDK(psi1=8, psi2=2, width=1, random_state=42), severity="none"
),
reason="skip test if required soft dependencies not available",
)
def test_idk_univariate():
"""Test IDK on univariate data."""
rng = check_random_state(seed=2)
rng = np.random.default_rng(seed=2)
series = rng.normal(size=(100,))
series[50:58] -= 10
series[50:58] -= 5

ad = IDK(psi1=8, psi2=2, width=1, random_state=42)
ad = IDK(psi1=8, psi2=2, width=1, random_state=2)
pred = ad.fit_predict(series)
ad_sliding = IDK(psi1=16, psi2=4, width=10, sliding=True, random_state=1)
pred_sliding = ad_sliding.fit_predict(series)

assert pred.shape == (100,)
assert pred.dtype == np.float64
assert 50 <= np.argmax(pred) <= 58
assert pred_sliding.shape == (91,)
assert pred_sliding.dtype == np.float64
assert 50 <= np.argmax(pred_sliding) <= 68
assert 60 <= np.argmax(pred_sliding) <= 80

0 comments on commit f91793b

Please sign in to comment.