
from smac.tae import StatusType

- from autoPyTorch.automl_common.common.utils.backend import Backend
- from autoPyTorch.constants import (
-     CLASSIFICATION_TASKS,
-     MULTICLASSMULTIOUTPUT,
+ from autoPyTorch.datasets.resampling_strategy import (
+     CrossValTypes,
+     NoResamplingStrategyTypes,
+     check_resampling_strategy
)
- from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
from autoPyTorch.evaluation.abstract_evaluator import (
    AbstractEvaluator,
    EvaluationResults,
from autoPyTorch.evaluation.abstract_evaluator import EvaluatorParams, FixedPipelineParams
from autoPyTorch.utils.common import dict_repr, subsampler

- __all__ = ['TrainEvaluator', 'eval_train_function']
+ __all__ = ['Evaluator', 'eval_fn']
+

class _CrossValidationResultsManager:
    def __init__(self, num_folds: int):
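With the `__all__` change above, downstream code imports the new public names. A minimal caller-side sketch, assuming the module path stays the same as before the rename:

    # sketch only: switch imports from the old public names to the new ones
    # from autoPyTorch.evaluation.train_evaluator import TrainEvaluator, eval_train_function  # old
    from autoPyTorch.evaluation.train_evaluator import Evaluator, eval_fn  # new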
@@ -83,15 +83,13 @@ def get_result_dict(self) -> Dict[str, Any]:
        )


- class TrainEvaluator(AbstractEvaluator):
+ class Evaluator(AbstractEvaluator):
    """
    This class builds a pipeline using the provided configuration.
    A pipeline implementing the provided configuration is fitted
    using the datamanager object retrieved from disc, via the backend.
    After the pipeline is fitted, it is saved to disc and the performance estimate
-     is communicated to the main process via a Queue. It is only compatible
-     with `CrossValTypes`, `HoldoutValTypes`, i.e., when the training data
-     is split and the validation set is used for SMBO optimisation.
+     is communicated to the main process via a Queue.

    Args:
        queue (Queue):
@@ -101,54 +99,27 @@ class TrainEvaluator(AbstractEvaluator):
            Fixed parameters for a pipeline
        evaluator_params (EvaluatorParams):
            The parameters for an evaluator.
+
+     Attributes:
+         train (bool):
+             Whether the training data is split and the validation set is used for SMBO optimisation.
+         cross_validation (bool):
+             Whether we use cross validation or not.
    """
-     def __init__(self, backend: Backend, queue: Queue,
-                  metric: autoPyTorchMetric,
-                  budget: float,
-                  configuration: Union[int, str, Configuration],
-                  budget_type: str = None,
-                  pipeline_config: Optional[Dict[str, Any]] = None,
-                  seed: int = 1,
-                  output_y_hat_optimization: bool = True,
-                  num_run: Optional[int] = None,
-                  include: Optional[Dict[str, Any]] = None,
-                  exclude: Optional[Dict[str, Any]] = None,
-                  disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-                  init_params: Optional[Dict[str, Any]] = None,
-                  logger_port: Optional[int] = None,
-                  keep_models: Optional[bool] = None,
-                  all_supported_metrics: bool = True,
-                  search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None) -> None:
-         super().__init__(
-             backend=backend,
-             queue=queue,
-             configuration=configuration,
-             metric=metric,
-             seed=seed,
-             output_y_hat_optimization=output_y_hat_optimization,
-             num_run=num_run,
-             include=include,
-             exclude=exclude,
-             disable_file_output=disable_file_output,
-             init_params=init_params,
-             budget=budget,
-             budget_type=budget_type,
-             logger_port=logger_port,
-             all_supported_metrics=all_supported_metrics,
-             pipeline_config=pipeline_config,
-             search_space_updates=search_space_updates
-         )
+     def __init__(self, queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams):
+         resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+         self.train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
+         self.cross_validation = isinstance(resampling_strategy, CrossValTypes)

-         if not isinstance(self.datamanager.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
-             resampling_strategy = self.datamanager.resampling_strategy
-             raise ValueError(
-                 f'resampling_strategy for TrainEvaluator must be in '
-                 f'(CrossValTypes, HoldoutValTypes), but got {resampling_strategy}'
-             )
+         if not self.train and fixed_pipeline_params.save_y_opt:
+             # TODO: Add a test to cover this branch
+             # No-resampling runs cannot be used for building ensembles; save_y_opt=False ensures that
+             fixed_pipeline_params = fixed_pipeline_params._replace(save_y_opt=False)
+
+         super().__init__(queue=queue, fixed_pipeline_params=fixed_pipeline_params, evaluator_params=evaluator_params)

-         self.splits = self.datamanager.splits
-         self.num_folds: int = len(self.splits)
-         self.logger.debug("Search space updates :{}".format(self.search_space_updates))
+         if self.train:
+             self.logger.debug("Search space updates :{}".format(self.fixed_pipeline_params.search_space_updates))

    def _evaluate_on_split(self, split_id: int) -> EvaluationResults:
        """
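The `train` and `cross_validation` attributes documented in the hunk above follow purely from the dataset's resampling strategy, mirroring the isinstance checks in the new __init__. A minimal sketch of that mapping (HoldoutValTypes is assumed to still be importable from the same module):

    from autoPyTorch.datasets.resampling_strategy import (
        CrossValTypes,
        HoldoutValTypes,
        NoResamplingStrategyTypes,
    )

    def resampling_flags(strategy):
        # train is False only when no resampling is used; cross_validation only for CV strategies
        train = not isinstance(strategy, NoResamplingStrategyTypes)
        cross_validation = isinstance(strategy, CrossValTypes)
        return train, cross_validation

    # holdout -> (True, False), cross-validation -> (True, True), no resampling -> (False, False)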
@@ -177,7 +148,7 @@ def _evaluate_on_split(self, split_id: int) -> EvaluationResults:

        return EvaluationResults(
            pipeline=pipeline,
-             opt_loss=self._loss(labels=self.y_train[opt_split], preds=opt_pred),
+             opt_loss=self._loss(labels=self.y_train[opt_split] if self.train else self.y_test, preds=opt_pred),
            train_loss=self._loss(labels=self.y_train[train_split], preds=train_pred),
            opt_pred=opt_pred,
            valid_pred=valid_pred,
@@ -203,6 +174,7 @@ def _cross_validation(self) -> EvaluationResults:
            results = self._evaluate_on_split(split_id)

            self.pipelines[split_id] = results.pipeline
+             assert opt_split is not None  # mypy redefinition
            cv_results.update(split_id, results, len(train_split), len(opt_split))

        self.y_opt = np.concatenate([y_opt for y_opt in Y_opt if y_opt is not None])
@@ -214,15 +186,16 @@ def evaluate_loss(self) -> None:
        if self.splits is None:
            raise ValueError(f"cannot fit pipeline {self.__class__.__name__} with datamanager.splits None")

-         if self.num_folds == 1:
+         if self.cross_validation:
+             results = self._cross_validation()
+         else:
            _, opt_split = self.splits[0]
            results = self._evaluate_on_split(split_id=0)
-             self.y_opt, self.pipelines[0] = self.y_train[opt_split], results.pipeline
-         else:
-             results = self._cross_validation()
+             self.pipelines[0] = results.pipeline
+             self.y_opt = self.y_train[opt_split] if self.train else self.y_test

        self.logger.debug(
-             f"In train evaluator. evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
+             f"In evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
            f" status: {results.status},\n additional run info:\n{dict_repr(results.additional_run_info)}"
        )
        self.record_evaluation(results=results)
@@ -242,41 +215,23 @@ def _fit_and_evaluate_loss(

        kwargs = {'pipeline': pipeline, 'unique_train_labels': self.unique_train_labels[split_id]}
        train_pred = self.predict(subsampler(self.X_train, train_indices), **kwargs)
-         opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs)
-         valid_pred = self.predict(self.X_valid, **kwargs)
        test_pred = self.predict(self.X_test, **kwargs)
+         valid_pred = self.predict(self.X_valid, **kwargs)
+
+         # No resampling ===> evaluate on test dataset
+         opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs) if self.train else test_pred

        assert train_pred is not None and opt_pred is not None  # mypy check
        return train_pred, opt_pred, valid_pred, test_pred


- # create closure for evaluating an algorithm
- def eval_train_function(
-     backend: Backend,
-     queue: Queue,
-     metric: autoPyTorchMetric,
-     budget: float,
-     config: Optional[Configuration],
-     seed: int,
-     output_y_hat_optimization: bool,
-     num_run: int,
-     include: Optional[Dict[str, Any]],
-     exclude: Optional[Dict[str, Any]],
-     disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-     pipeline_config: Optional[Dict[str, Any]] = None,
-     budget_type: str = None,
-     init_params: Optional[Dict[str, Any]] = None,
-     logger_port: Optional[int] = None,
-     all_supported_metrics: bool = True,
-     search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
-     instance: str = None,
- ) -> None:
+ def eval_fn(queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams) -> None:
    """
    This closure allows the communication between the TargetAlgorithmQuery and the
-     pipeline trainer (TrainEvaluator).
+     pipeline trainer (Evaluator).

    Fundamentally, smac calls the TargetAlgorithmQuery.run() method, which internally
-     builds a TrainEvaluator. The TrainEvaluator builds a pipeline, stores the output files
+     builds an Evaluator. The Evaluator builds a pipeline, stores the output files
    to disc via the backend, and puts the performance result of the run in the queue.

    Args:
@@ -288,7 +243,11 @@ def eval_train_function(
        evaluator_params (EvaluatorParams):
            The parameters for an evaluator.
    """
-     evaluator = TrainEvaluator(
+     resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+     check_resampling_strategy(resampling_strategy)
+
+     # NoResamplingStrategyTypes ==> test evaluator, otherwise ==> train evaluator
+     evaluator = Evaluator(
        queue=queue,
        evaluator_params=evaluator_params,
        fixed_pipeline_params=fixed_pipeline_params
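The queue-based protocol described in the eval_fn docstring can be illustrated roughly as below; TargetAlgorithmQuery is the real caller inside autoPyTorch, and the two param-building helpers are hypothetical placeholders, not part of the library:

    import multiprocessing

    queue = multiprocessing.Queue()
    eval_fn(
        queue=queue,
        fixed_pipeline_params=build_fixed_pipeline_params(),  # hypothetical helper returning FixedPipelineParams
        evaluator_params=build_evaluator_params(),            # hypothetical helper returning EvaluatorParams
    )
    result = queue.get()  # the performance estimate that the evaluator put in the queue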