Skip to content

Commit 2003cf2

Browse files
authored
[ENH] Shapelet visualization tools (#1715)
* Base structure for shapelet visualizers, some docstring corrections * Add docstring examples, correct bugs, WIP * fix boxplots, add customization dict * Fix plt import * Remove test function for RTF, add estimators to docs * Remove examples crashing CI due to softdeps * Add support for other classifiers * Add tests, fix bugs and add distance parameter * Fix pytest parametrize naming * Fix RDST tests * Correct docstring and axes dimensions * Fix precommit * Fix precommit (I'll never work on web git IDE again)
1 parent 7555498 commit 2003cf2

File tree

10 files changed

+1541
-101
lines changed

10 files changed

+1541
-101
lines changed

CODEOWNERS

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ aeon/transformations/theta.py @GuzalBulatova
4444

4545
aeon/utils/numba/ @baraline @MatthewMiddlehurst
4646

47+
aeon/visualisation/ @baraline
48+
4749
.github/ @aeon-toolkit/aeon-infrastructure-workgroup
4850
build_tools/ @aeon-toolkit/aeon-infrastructure-workgroup
4951

aeon/classification/shapelet_based/_rdst.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class RDSTClassifier(BaseClassifier):
6060
If True, restrict the value of the shapelet dilation parameter to be prime
6161
values. This can greatly speed-up the algorithm for long time series and/or
6262
short shapelet length, possibly at the cost of some accuracy.
63+
distance : str, default="manhattan"
64+
Name of the distance function to be used. By default this is the
65+
manhattan distance. Other distances from the aeon distance module can be used.
6366
estimator : BaseEstimator or None, default=None
6467
Base estimator for the ensemble, can be supplied a sklearn `BaseEstimator`. If
6568
`None` a default `RidgeClassifierCV` classifier is used with standard scaling.
@@ -134,6 +137,7 @@ def __init__(
134137
use_prime_dilations: bool = False,
135138
estimator=None,
136139
save_transformed_data: bool = False,
140+
distance: str = "manhattan",
137141
n_jobs: int = 1,
138142
random_state: Union[int, Type[np.random.RandomState], None] = None,
139143
) -> None:
@@ -143,7 +147,7 @@ def __init__(
143147
self.threshold_percentiles = threshold_percentiles
144148
self.alpha_similarity = alpha_similarity
145149
self.use_prime_dilations = use_prime_dilations
146-
150+
self.distance = distance
147151
self.estimator = estimator
148152
self.save_transformed_data = save_transformed_data
149153
self.random_state = random_state
@@ -184,6 +188,7 @@ def _fit(self, X, y):
184188
use_prime_dilations=self.use_prime_dilations,
185189
n_jobs=self.n_jobs,
186190
random_state=self.random_state,
191+
distance=self.distance,
187192
)
188193
if self.estimator is None:
189194
self._estimator = make_pipeline(

aeon/transformations/collection/shapelet_based/_dilated_shapelet_transform.py

+39-21
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from numba.typed import List
1515
from sklearn.preprocessing import LabelEncoder
1616

17-
from aeon.distances import manhattan_distance
17+
from aeon.distances import get_distance_function
1818
from aeon.transformations.collection import BaseCollectionTransformer
1919
from aeon.utils.numba.general import (
2020
AEON_NUMBA_STD_THRESHOLD,
@@ -83,6 +83,9 @@ class RandomDilatedShapeletTransform(BaseCollectionTransformer):
8383
If True, restrict the value of the shapelet dilation parameter to be prime
8484
values. This can greatly speed up the algorithm for long time series and/or
8585
short shapelet length, possibly at the cost of some accuracy.
86+
distance : str, default="manhattan"
87+
Name of the distance function to be used. By default this is the
88+
manhattan distance. Other distances from the aeon distance module can be used.
8689
n_jobs : int, default=1
8790
The number of threads used for both `fit` and `transform`.
8891
random_state : int or None, default=None
@@ -153,6 +156,7 @@ def __init__(
153156
alpha_similarity=0.5,
154157
use_prime_dilations=False,
155158
random_state=None,
159+
distance="manhattan",
156160
n_jobs=1,
157161
):
158162
self.max_shapelets = max_shapelets
@@ -162,6 +166,7 @@ def __init__(
162166
self.alpha_similarity = alpha_similarity
163167
self.use_prime_dilations = use_prime_dilations
164168
self.random_state = random_state
169+
self.distance = distance
165170
self.n_jobs = n_jobs
166171

167172
super().__init__()
@@ -183,7 +188,8 @@ def _fit(self, X, y=None):
183188
self : RandomDilatedShapeletTransform
184189
This estimator.
185190
"""
186-
# Numba does not yet support new random numpy API with generator
191+
self.distance_func = get_distance_function(self.distance)
192+
187193
if isinstance(self.random_state, int):
188194
self._random_state = np.int32(self.random_state)
189195
else:
@@ -218,6 +224,7 @@ def _fit(self, X, y=None):
218224
self.alpha_similarity,
219225
self.use_prime_dilations,
220226
self._random_state,
227+
self.distance_func,
221228
)
222229
if len(self.shapelets_[0]) == 0:
223230
raise RuntimeError(
@@ -259,7 +266,11 @@ def _transform(self, X, y=None):
259266
"calling transform."
260267
)
261268

262-
X_new = dilated_shapelet_transform(X, self.shapelets_)
269+
X_new = dilated_shapelet_transform(
270+
X,
271+
self.shapelets_,
272+
self.distance_func,
273+
)
263274
if np.isinf(X_new).any() or np.isnan(X_new).any():
264275
warnings.warn(
265276
"Some invalid values (inf or nan) where converted from to 0 during the"
@@ -482,6 +493,7 @@ def random_dilated_shapelet_extraction(
482493
alpha_similarity,
483494
use_prime_dilations,
484495
seed,
496+
distance,
485497
):
486498
"""Randomly generate a set of shapelets given the input parameters.
487499
@@ -518,6 +530,10 @@ def random_dilated_shapelet_extraction(
518530
short shapelet length, possibly at the cost of some accuracy.
519531
seed : int
520532
Seed for random number generation.
533+
distance: CPUDispatcher
534+
A Numba function used to compute the distance between two multidimensional
535+
time series of shape (n_channels, length). Used as distance function between
536+
shapelets and candidate subsequences.
521537
522538
Returns
523539
-------
@@ -641,7 +657,7 @@ def random_dilated_shapelet_extraction(
641657
X[id_test], length, dilation
642658
)
643659
X_subs = normalize_subsequences(X_subs, X_means, X_stds)
644-
x_dist = compute_shapelet_dist_vector(X_subs, _val, length)
660+
x_dist = compute_shapelet_dist_vector(X_subs, _val, length, distance)
645661

646662
lower_bound = np.percentile(x_dist, threshold_percentiles[0])
647663
upper_bound = np.percentile(x_dist, threshold_percentiles[1])
@@ -669,7 +685,7 @@ def random_dilated_shapelet_extraction(
669685

670686

671687
@njit(fastmath=True, cache=True, parallel=True)
672-
def dilated_shapelet_transform(X, shapelets):
688+
def dilated_shapelet_transform(X, shapelets, distance):
673689
"""Perform the shapelet transform with a set of shapelets and a set of time series.
674690
675691
Parameters
@@ -692,6 +708,10 @@ def dilated_shapelet_transform(X, shapelets):
692708
Means of the shapelets
693709
- stds : array, shape (n_shapelets, n_channels)
694710
Standard deviation of the shapelets
711+
distance: CPUDispatcher
712+
A Numba function used to compute the distance between two multidimensional
713+
time series of shape (n_channels, length).
714+
695715
696716
Returns
697717
-------
@@ -728,7 +748,7 @@ def dilated_shapelet_transform(X, shapelets):
728748
for i_shp in idx_no_norm:
729749
X_new[i_x, (n_ft * i_shp) : (n_ft * i_shp + n_ft)] = (
730750
compute_shapelet_features(
731-
X_subs, values[i_shp], length, threshold[i_shp]
751+
X_subs, values[i_shp], length, threshold[i_shp], distance
732752
)
733753
)
734754

@@ -739,7 +759,7 @@ def dilated_shapelet_transform(X, shapelets):
739759
for i_shp in idx_norm:
740760
X_new[i_x, (n_ft * i_shp) : (n_ft * i_shp + n_ft)] = (
741761
compute_shapelet_features(
742-
X_subs, values[i_shp], length, threshold[i_shp]
762+
X_subs, values[i_shp], length, threshold[i_shp], distance
743763
)
744764
)
745765
return X_new
@@ -808,7 +828,7 @@ def get_all_subsequences(X, length, dilation):
808828

809829

810830
@njit(fastmath=True, cache=True)
811-
def compute_shapelet_features(X_subs, values, length, threshold):
831+
def compute_shapelet_features(X_subs, values, length, threshold, distance):
812832
"""Extract the features from a shapelet distance vector.
813833
814834
Given a shapelet and a time series, extract three features from the resulting
@@ -826,10 +846,11 @@ def compute_shapelet_features(X_subs, values, length, threshold):
826846
The value array of the shapelet
827847
length : int
828848
Length of the shapelet
829-
values : array, shape (n_channels, length)
830-
The resulting subsequence
831849
threshold : float
832850
The threshold parameter of the shapelet
851+
distance: CPUDispatcher
852+
A Numba function used to compute the distance between two multidimensional
853+
time series of shape (n_channels, length).
833854
834855
Returns
835856
-------
@@ -843,7 +864,7 @@ def compute_shapelet_features(X_subs, values, length, threshold):
843864
n_subsequences = X_subs.shape[0]
844865

845866
for i_sub in prange(n_subsequences):
846-
_dist = manhattan_distance(X_subs[i_sub], values[:, :length])
867+
_dist = distance(X_subs[i_sub], values[:, :length])
847868
if _dist < _min:
848869
_min = _dist
849870
_argmin = i_sub
@@ -854,7 +875,7 @@ def compute_shapelet_features(X_subs, values, length, threshold):
854875

855876

856877
@njit(fastmath=True, cache=True)
857-
def compute_shapelet_dist_vector(X_subs, values, length):
878+
def compute_shapelet_dist_vector(X_subs, values, length, distance):
858879
"""Extract the features from a shapelet distance vector.
859880
860881
Given a shapelet and a time series, extract three features from the resulting
@@ -872,20 +893,17 @@ def compute_shapelet_dist_vector(X_subs, values, length):
872893
The value array of the shapelet
873894
length : int
874895
Length of the shapelet
875-
dilation : int
876-
Dilation of the shapelet
877-
values : array, shape (n_channels, length)
878-
The resulting subsequence
879-
threshold : float
880-
The threshold parameter of the shapelet
896+
distance: CPUDispatcher
897+
A Numba function used to compute the distance between two multidimensional
898+
time series of shape (n_channels, length).
881899
882900
Returns
883901
-------
884-
min, argmin, shapelet occurence
885-
The three computed features as float dtypes
902+
dist_vector : array, shape = (n_timestamps-(length-1)*dilation)
903+
The distance vector between the shapelet and the candidate subsequences.
886904
"""
887905
n_subsequences = X_subs.shape[0]
888906
dist_vector = np.zeros(n_subsequences)
889907
for i_sub in prange(n_subsequences):
890-
dist_vector[i_sub] = manhattan_distance(X_subs[i_sub], values[:, :length])
908+
dist_vector[i_sub] = distance(X_subs[i_sub], values[:, :length])
891909
return dist_vector

aeon/transformations/collection/shapelet_based/_rsast.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ def _fit(self, X, y):
179179
# 2--calculate PACF and ACF for each TS chosen in each class
180180

181181
for i, c in enumerate(classes):
182-
183182
X_c = X_[y == c]
184183

185184
cnt = np.min([self.nb_inst_per_class, X_c.shape[0]]).astype(int)
@@ -313,7 +312,7 @@ def _transform(self, X, y=None):
313312
314313
Returns
315314
-------
316-
X_transformed: np.ndarray shape (n_cases, n_timepoints),
315+
X_transformed: np.ndarray shape (n_cases, n_kernels),
317316
The transformed data
318317
"""
319318
X_ = np.reshape(X, (X.shape[0], X.shape[-1]))

aeon/transformations/collection/shapelet_based/tests/test_dilated_shapelet_transform.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ def test_compute_shapelet_features(dtype):
144144
dilation = 1
145145
threshold = 0.01
146146
X_subs = get_all_subsequences(X, length, dilation)
147-
_min, _argmin, SO = compute_shapelet_features(X_subs, values, length, threshold)
147+
_min, _argmin, SO = compute_shapelet_features(
148+
X_subs, values, length, threshold, manhattan_distance
149+
)
148150

149151
# On some occasions, float32 precision with fastmath returns things like
150152
# 2.1835059227370834e-07 instead of 0
@@ -155,7 +157,9 @@ def test_compute_shapelet_features(dtype):
155157
dilation = 2
156158
threshold = 0.1
157159
X_subs = get_all_subsequences(X, length, dilation)
158-
_min, _argmin, SO = compute_shapelet_features(X_subs, values, length, threshold)
160+
_min, _argmin, SO = compute_shapelet_features(
161+
X_subs, values, length, threshold, manhattan_distance
162+
)
159163

160164
assert_almost_equal(_min, 0.0, decimal=4)
161165
assert _argmin == 7.0
@@ -164,7 +168,9 @@ def test_compute_shapelet_features(dtype):
164168
dilation = 4
165169
threshold = 2
166170
X_subs = get_all_subsequences(X, length, dilation)
167-
_min, _argmin, SO = compute_shapelet_features(X_subs, values, length, threshold)
171+
_min, _argmin, SO = compute_shapelet_features(
172+
X_subs, values, length, threshold, manhattan_distance
173+
)
168174

169175
assert_almost_equal(_min, 0.0, decimal=4)
170176
assert _argmin == 3.0
@@ -179,7 +185,9 @@ def test_compute_shapelet_dist_vector(dtype):
179185
for dilation in [1, 3, 5]:
180186
values = np.random.rand(3, length).astype(dtype)
181187
X_subs = get_all_subsequences(X, length, dilation)
182-
d_vect = compute_shapelet_dist_vector(X_subs, values, length)
188+
d_vect = compute_shapelet_dist_vector(
189+
X_subs, values, length, manhattan_distance
190+
)
183191
true_vect = np.zeros(X.shape[1] - (length - 1) * dilation)
184192
for i_sub in range(true_vect.shape[0]):
185193
_idx = [i_sub + j * dilation for j in range(length)]

aeon/visualisation/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,18 @@
2222
"plot_series_with_profiles",
2323
"plot_cluster_algorithm",
2424
"plot_temporal_importance_curves",
25+
"ShapeletVisualizer",
26+
"ShapeletTransformerVisualizer",
27+
"ShapeletClassifierVisualizer",
2528
]
2629

2730
from aeon.visualisation.estimator._clasp import plot_series_with_profiles
2831
from aeon.visualisation.estimator._clustering import plot_cluster_algorithm
32+
from aeon.visualisation.estimator._shapelets import (
33+
ShapeletClassifierVisualizer,
34+
ShapeletTransformerVisualizer,
35+
ShapeletVisualizer,
36+
)
2937
from aeon.visualisation.estimator._temporal_importance_curves import (
3038
plot_temporal_importance_curves,
3139
)

0 commit comments

Comments
 (0)