Skip to content

Commit d62d277

Browse files
[MNT] Make Multi Rec test in deep clustering faster (#2315)
* Update test_clusterer_features.py * precommit * fixed bug * remove test replace by combination * re do test
1 parent 216ac16 commit d62d277

File tree

5 files changed

+20
-21
lines changed

5 files changed

+20
-21
lines changed

aeon/clustering/deep_learning/_ae_fcn.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -489,17 +489,17 @@ def _get_test_params(cls, parameter_set="default"):
489489
Each dict are parameters to construct an "interesting" test instance, i.e.,
490490
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
491491
"""
492-
param1 = {
493-
"n_epochs": 1,
492+
param = {
493+
"n_epochs": 2,
494494
"batch_size": 4,
495495
"use_bias": False,
496-
"n_layers": 1,
497-
"n_filters": 4,
498-
"kernel_size": 2,
496+
"n_layers": 2,
497+
"n_filters": [2, 2],
498+
"kernel_size": [2, 2],
499499
"padding": "same",
500500
"strides": 1,
501-
"latent_space_dim": 4,
501+
"latent_space_dim": 2,
502502
"estimator": DummyClusterer(n_clusters=2),
503503
}
504504

505-
return [param1]
505+
return [param]

aeon/clustering/deep_learning/_ae_resnet.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,11 @@ def _get_test_params(cls, parameter_set="default"):
502502
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
503503
"""
504504
param = {
505-
"n_epochs": 1,
505+
"n_epochs": 2,
506506
"batch_size": 4,
507-
"n_residual_blocks": 1,
507+
"n_residual_blocks": 2,
508508
"n_conv_per_residual_block": 1,
509-
"n_filters": 1,
509+
"n_filters": [2, 2],
510510
"kernel_size": 2,
511511
"use_bias": False,
512512
"estimator": DummyClusterer(n_clusters=2),

aeon/clustering/deep_learning/tests/test_clusterer_features.py renamed to aeon/clustering/deep_learning/tests/test_deep_clusterer_features.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,12 @@
1414
def test_multi_rec_fcn():
1515
"""Tests whether multi-rec loss works fine or not."""
1616
X = np.random.random((100, 5, 2))
17-
clst = AEFCNClusterer(
18-
n_clusters=2, n_epochs=10, n_filters=[2, 3, 4], loss="multi_rec"
19-
)
17+
clst = AEFCNClusterer(**AEFCNClusterer._get_test_params()[0], loss="multi_rec")
2018
clst.fit(X)
21-
assert (
22-
clst.history["loss"][0] > clst.history["loss"][9]
23-
) # Check if loss is decreasing.
24-
clst = AEResNetClusterer(n_clusters=2, n_epochs=10, loss="multi_rec")
19+
assert isinstance(clst.history["loss"][-1], float)
20+
21+
clst = AEResNetClusterer(
22+
**AEResNetClusterer._get_test_params()[0], loss="multi_rec"
23+
)
2524
clst.fit(X)
26-
assert clst.history["loss"][0] > clst.history["loss"][9]
25+
assert isinstance(clst.history["loss"][-1], float)

aeon/networks/_ae_resnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def build_network(self, input_shape, **kwargs):
225225
conv = tf.keras.layers.Conv1D(
226226
filters=self._n_filters[d],
227227
kernel_size=self._kernel_size[c],
228-
strides=self._strides[d],
228+
strides=self._strides[c],
229229
padding=self._padding[c],
230230
dilation_rate=self._dilation_rate[c],
231231
)(x)
@@ -290,7 +290,7 @@ def build_network(self, input_shape, **kwargs):
290290
conv = tf.keras.layers.Conv1DTranspose(
291291
filters=self._n_filters[d],
292292
kernel_size=self._kernel_size[c],
293-
strides=self._strides[d],
293+
strides=self._strides[c],
294294
padding=self._padding[c],
295295
dilation_rate=self._dilation_rate[c],
296296
)(x)

aeon/networks/_resnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def build_network(self, input_shape, **kwargs):
209209
conv = tf.keras.layers.Conv1D(
210210
filters=self._n_filters[d],
211211
kernel_size=self._kernel_size[c],
212-
strides=self._strides[d],
212+
strides=self._strides[c],
213213
padding=self._padding[c],
214214
dilation_rate=self._dilation_rate[c],
215215
)(x)

0 commit comments

Comments
 (0)