18
18
import numba
19
19
from aeon .utils .validation ._dependencies import _check_soft_dependencies
20
20
21
- from tsml_eval .experiments import load_and_run_clustering_experiment
22
- from tsml_eval .experiments .set_clusterer import get_clusterer_by_name
21
+ from tsml_eval .experiments import (
22
+ get_clusterer_by_name ,
23
+ get_data_transform_by_name ,
24
+ load_and_run_clustering_experiment ,
25
+ )
23
26
from tsml_eval .experiments .tests import _CLUSTERER_RESULTS_PATH
24
27
from tsml_eval .testing .testing_utils import _TEST_DATA_PATH
25
28
from tsml_eval .utils .arguments import parse_args
@@ -88,10 +91,19 @@ def run_experiment(args):
88
91
row_normalise = args .row_normalise ,
89
92
** args .kwargs ,
90
93
),
91
- row_normalise = args .row_normalise ,
92
94
n_clusters = args .n_clusters ,
93
95
clusterer_name = args .estimator_name ,
94
96
resample_id = args .resample_id ,
97
+ data_transforms = get_data_transform_by_name (
98
+ args .data_transform_name ,
99
+ row_normalise = args .row_normalise ,
100
+ random_state = (
101
+ args .resample_id
102
+ if args .random_seed is None
103
+ else args .random_seed
104
+ ),
105
+ n_jobs = 1 ,
106
+ ),
95
107
build_test_file = args .test_fold ,
96
108
write_attributes = args .write_attributes ,
97
109
att_max_shape = args .att_max_shape ,
@@ -110,6 +122,7 @@ def run_experiment(args):
110
122
estimator_name = "KMeans"
111
123
dataset_name = "MinimalChinatown"
112
124
row_normalise = False
125
+ transform_name = None
113
126
n_clusters = - 1
114
127
resample_id = 0
115
128
test_fold = False
@@ -133,17 +146,22 @@ def run_experiment(args):
133
146
row_normalise = row_normalise ,
134
147
** kwargs ,
135
148
)
149
+ transform = get_data_transform_by_name (
150
+ transform_name ,
151
+ row_normalise = row_normalise ,
152
+ random_state = resample_id ,
153
+ )
136
154
print (f"Local Run of { estimator_name } ({ clusterer .__class__ .__name__ } )." )
137
155
138
156
load_and_run_clustering_experiment (
139
157
data_path ,
140
158
results_path ,
141
159
dataset_name ,
142
160
clusterer ,
143
- row_normalise = row_normalise ,
144
161
n_clusters = n_clusters ,
145
162
clusterer_name = estimator_name ,
146
163
resample_id = resample_id ,
164
+ data_transforms = transform ,
147
165
build_test_file = test_fold ,
148
166
write_attributes = write_attributes ,
149
167
att_max_shape = att_max_shape ,
0 commit comments