
 from smac.tae import StatusType

-from autoPyTorch.automl_common.common.utils.backend import Backend
-from autoPyTorch.constants import (
-    CLASSIFICATION_TASKS,
-    MULTICLASSMULTIOUTPUT,
+from autoPyTorch.datasets.resampling_strategy import (
+    CrossValTypes,
+    NoResamplingStrategyTypes,
+    check_resampling_strategy
 )
-from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
 from autoPyTorch.evaluation.abstract_evaluator import (
     AbstractEvaluator,
     EvaluationResults,
 from autoPyTorch.evaluation.abstract_evaluator import EvaluatorParams, FixedPipelineParams
 from autoPyTorch.utils.common import dict_repr, subsampler

-__all__ = ['TrainEvaluator', 'eval_train_function']
+__all__ = ['Evaluator', 'eval_fn']
+

 class _CrossValidationResultsManager:
     def __init__(self, num_folds: int):
@@ -83,15 +83,13 @@ def get_result_dict(self) -> Dict[str, Any]:
         )


-class TrainEvaluator(AbstractEvaluator):
+class Evaluator(AbstractEvaluator):
     """
     This class builds a pipeline using the provided configuration.
     A pipeline implementing the provided configuration is fitted
     using the datamanager object retrieved from disc, via the backend.
     After the pipeline is fitted, it is saved to disc and the performance estimate
-    is communicated to the main process via a Queue. It is only compatible
-    with `CrossValTypes`, `HoldoutValTypes`, i.e., when the training data
-    is split and the validation set is used for SMBO optimisation.
+    is communicated to the main process via a Queue.

     Args:
         queue (Queue):
@@ -101,43 +99,17 @@ class TrainEvaluator(AbstractEvaluator):
             Fixed parameters for a pipeline
         evaluator_params (EvaluatorParams):
             The parameters for an evaluator.
+
+    Attributes:
+        train (bool):
+            Whether the training data is split and the validation set is used for SMBO optimisation.
+        cross_validation (bool):
+            Whether we use cross validation or not.
     """
-    def __init__(self, backend: Backend, queue: Queue,
-                 metric: autoPyTorchMetric,
-                 budget: float,
-                 configuration: Union[int, str, Configuration],
-                 budget_type: str = None,
-                 pipeline_config: Optional[Dict[str, Any]] = None,
-                 seed: int = 1,
-                 output_y_hat_optimization: bool = True,
-                 num_run: Optional[int] = None,
-                 include: Optional[Dict[str, Any]] = None,
-                 exclude: Optional[Dict[str, Any]] = None,
-                 disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-                 init_params: Optional[Dict[str, Any]] = None,
-                 logger_port: Optional[int] = None,
-                 keep_models: Optional[bool] = None,
-                 all_supported_metrics: bool = True,
-                 search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None) -> None:
-        super().__init__(
-            backend=backend,
-            queue=queue,
-            configuration=configuration,
-            metric=metric,
-            seed=seed,
-            output_y_hat_optimization=output_y_hat_optimization,
-            num_run=num_run,
-            include=include,
-            exclude=exclude,
-            disable_file_output=disable_file_output,
-            init_params=init_params,
-            budget=budget,
-            budget_type=budget_type,
-            logger_port=logger_port,
-            all_supported_metrics=all_supported_metrics,
-            pipeline_config=pipeline_config,
-            search_space_updates=search_space_updates
-        )
+    def __init__(self, queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams):
+        resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+        self.train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
+        self.cross_validation = isinstance(resampling_strategy, CrossValTypes)

         if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
             raise ValueError(
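The whole refactor hinges on the two booleans set in the new constructor. A minimal, self-contained sketch of that mode inference, with stub classes standing in for autoPyTorch's `CrossValTypes`, `HoldoutValTypes` and `NoResamplingStrategyTypes` (the stubs are an assumption for illustration; the real types live in `autoPyTorch.datasets.resampling_strategy`):

```python
from typing import Tuple

# Stub stand-ins for autoPyTorch's resampling-strategy types (illustration only).
class CrossValTypes: ...
class HoldoutValTypes: ...
class NoResamplingStrategyTypes: ...

def infer_mode(resampling_strategy: object) -> Tuple[bool, bool]:
    # train: a validation split exists and is used for SMBO optimisation
    train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
    # cross_validation: several train/opt splits are fitted and aggregated
    cross_validation = isinstance(resampling_strategy, CrossValTypes)
    return train, cross_validation

assert infer_mode(CrossValTypes()) == (True, True)
assert infer_mode(HoldoutValTypes()) == (True, False)
assert infer_mode(NoResamplingStrategyTypes()) == (False, False)
```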
@@ -175,7 +147,7 @@ def _evaluate_on_split(self, split_id: int) -> EvaluationResults:

         return EvaluationResults(
             pipeline=pipeline,
-            opt_loss=self._loss(labels=self.y_train[opt_split], preds=opt_pred),
+            opt_loss=self._loss(labels=self.y_train[opt_split] if self.train else self.y_test, preds=opt_pred),
             train_loss=self._loss(labels=self.y_train[train_split], preds=train_pred),
             opt_pred=opt_pred,
             valid_pred=valid_pred,
@@ -201,6 +173,7 @@ def _cross_validation(self) -> EvaluationResults:
             results = self._evaluate_on_split(split_id)

             self.pipelines[split_id] = results.pipeline
+            assert opt_split is not None  # mypy redefinition
             cv_results.update(split_id, results, len(train_split), len(opt_split))

         self.y_opt = np.concatenate([y_opt for y_opt in Y_opt if y_opt is not None])
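In the cross-validation path, the out-of-fold targets collected per split are stitched into a single `y_opt` array, skipping any slot that stayed `None`. A toy numpy illustration with made-up fold arrays:

```python
import numpy as np

# Hypothetical per-fold out-of-fold targets; unused slots remain None.
Y_opt = [np.array([0, 1]), None, np.array([1, 0, 1])]
y_opt = np.concatenate([y for y in Y_opt if y is not None])
print(y_opt)  # [0 1 1 0 1]
```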
@@ -212,15 +185,16 @@ def evaluate_loss(self) -> None:
         if self.splits is None:
             raise ValueError(f"cannot fit pipeline {self.__class__.__name__} with datamanager.splits None")

-        if self.num_folds == 1:
+        if self.cross_validation:
+            results = self._cross_validation()
+        else:
             _, opt_split = self.splits[0]
             results = self._evaluate_on_split(split_id=0)
-            self.y_opt, self.pipelines[0] = self.y_train[opt_split], results.pipeline
-        else:
-            results = self._cross_validation()
+            self.pipelines[0] = results.pipeline
+            self.y_opt = self.y_train[opt_split] if self.train else self.y_test

         self.logger.debug(
-            f"In train evaluator. evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
+            f"In evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
             f" status: {results.status},\n additional run info:\n{dict_repr(results.additional_run_info)}"
         )
         self.record_evaluation(results=results)
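The `y_opt` selection follows the `train` flag: holdout labels when a validation split exists, the test labels under `NoResamplingStrategyTypes`. A toy illustration with made-up arrays:

```python
import numpy as np

y_train = np.array([0, 1, 1, 0, 1])
y_test = np.array([1, 0])
opt_split = np.array([3, 4])  # hypothetical holdout indices

train = True  # holdout / cross-validation case
y_opt = y_train[opt_split] if train else y_test
print(y_opt)  # [0 1]
```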
@@ -240,41 +214,23 @@ def _fit_and_evaluate_loss(

         kwargs = {'pipeline': pipeline, 'unique_train_labels': self.unique_train_labels[split_id]}
         train_pred = self.predict(subsampler(self.X_train, train_indices), **kwargs)
-        opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs)
-        valid_pred = self.predict(self.X_valid, **kwargs)
         test_pred = self.predict(self.X_test, **kwargs)
+        valid_pred = self.predict(self.X_valid, **kwargs)
+
+        # No resampling ===> evaluate on test dataset
+        opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs) if self.train else test_pred

         assert train_pred is not None and opt_pred is not None  # mypy check
         return train_pred, opt_pred, valid_pred, test_pred


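The one behavioural change in `_fit_and_evaluate_loss` is that, under no resampling, the optimisation predictions simply alias the test predictions instead of being recomputed. A toy sketch with a hypothetical `predict` stub:

```python
import numpy as np

def predict(X: np.ndarray) -> np.ndarray:
    # hypothetical stand-in for the fitted pipeline's predict
    return np.zeros(len(X))

X_train, X_test = np.ones((10, 3)), np.ones((4, 3))
opt_indices = np.array([7, 8, 9])

train = False  # NoResamplingStrategyTypes ==> evaluate on the test set
test_pred = predict(X_test)
opt_pred = predict(X_train[opt_indices]) if train else test_pred
assert opt_pred is test_pred  # no second predict call, same array object
```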
-# create closure for evaluating an algorithm
-def eval_train_function(
-    backend: Backend,
-    queue: Queue,
-    metric: autoPyTorchMetric,
-    budget: float,
-    config: Optional[Configuration],
-    seed: int,
-    output_y_hat_optimization: bool,
-    num_run: int,
-    include: Optional[Dict[str, Any]],
-    exclude: Optional[Dict[str, Any]],
-    disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-    pipeline_config: Optional[Dict[str, Any]] = None,
-    budget_type: str = None,
-    init_params: Optional[Dict[str, Any]] = None,
-    logger_port: Optional[int] = None,
-    all_supported_metrics: bool = True,
-    search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
-    instance: str = None,
-) -> None:
+def eval_fn(queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams) -> None:
     """
     This closure allows the communication between the TargetAlgorithmQuery and the
-    pipeline trainer (TrainEvaluator).
+    pipeline trainer (Evaluator).

     Fundamentally, smac calls the TargetAlgorithmQuery.run() method, which internally
-    builds a TrainEvaluator. The TrainEvaluator builds a pipeline, stores the output files
+    builds an Evaluator. The Evaluator builds a pipeline, stores the output files
     to disc via the backend, and puts the performance result of the run in the queue.

     Args:
@@ -286,7 +242,11 @@ def eval_train_function(
         evaluator_params (EvaluatorParams):
             The parameters for an evaluator.
     """
-    evaluator = TrainEvaluator(
+    resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+    check_resampling_strategy(resampling_strategy)
+
+    # NoResamplingStrategyTypes ==> test evaluator, otherwise ==> train evaluator
+    evaluator = Evaluator(
         queue=queue,
         evaluator_params=evaluator_params,
         fixed_pipeline_params=fixed_pipeline_params
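`eval_fn` talks to the parent process exclusively through the queue, which is what lets SMAC run it in a separate process and collect the result afterwards. A toy sketch of that queue protocol (the worker and payload below are hypothetical, not autoPyTorch's actual `TargetAlgorithmQuery`):

```python
from multiprocessing import Queue

def toy_eval_fn(queue: Queue) -> None:
    # stand-in for eval_fn: fit a pipeline, then report through the queue
    queue.put({'loss': 0.42, 'status': 'SUCCESS'})

q = Queue()
toy_eval_fn(q)
print(q.get())  # {'loss': 0.42, 'status': 'SUCCESS'}
```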