1616 AbstractEvaluator ,
1717 fit_and_suppress_warnings
1818)
19+ from autoPyTorch .evaluation .utils import DisableFileOutputParameters
1920from autoPyTorch .pipeline .components .training .metrics .base import autoPyTorchMetric
2021from autoPyTorch .utils .common import subsampler
2122from autoPyTorch .utils .hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
@@ -33,7 +34,7 @@ def __init__(self, backend: Backend, queue: Queue,
3334 num_run : Optional [int ] = None ,
3435 include : Optional [Dict [str , Any ]] = None ,
3536 exclude : Optional [Dict [str , Any ]] = None ,
36- disable_file_output : Union [bool , List ] = False ,
37+ disable_file_output : Optional [ List [ Union [str , DisableFileOutputParameters ]]] = None ,
3738 init_params : Optional [Dict [str , Any ]] = None ,
3839 logger_port : Optional [int ] = None ,
3940 keep_models : Optional [bool ] = None ,
@@ -241,14 +242,11 @@ def file_output(
241242 )
242243
243244 # Abort if we don't want to output anything.
244- if hasattr (self , 'disable_file_output' ):
245- if self .disable_file_output :
246- return None , {}
247- else :
248- self .disabled_file_outputs = []
245+ if 'all' in self .disable_file_output :
246+ return None , {}
249247
250- if hasattr (self , 'pipeline' ) and self . pipeline is not None :
251- if 'pipeline' not in self .disabled_file_outputs :
248+ if getattr (self , 'pipeline' , None ) is not None :
249+ if 'pipeline' not in self .disable_file_output :
252250 pipeline = self .pipeline
253251 else :
254252 pipeline = None
@@ -265,11 +263,11 @@ def file_output(
265263 ensemble_predictions = None ,
266264 valid_predictions = (
267265 Y_valid_pred if 'y_valid' not in
268- self .disabled_file_outputs else None
266+ self .disable_file_output else None
269267 ),
270268 test_predictions = (
271269 Y_test_pred if 'y_test' not in
272- self .disabled_file_outputs else None
270+ self .disable_file_output else None
273271 ),
274272 )
275273
@@ -287,8 +285,8 @@ def eval_function(
287285 num_run : int ,
288286 include : Optional [Dict [str , Any ]],
289287 exclude : Optional [Dict [str , Any ]],
290- disable_file_output : Union [bool , List ],
291288 output_y_hat_optimization : bool = False ,
289+ disable_file_output : Optional [List [Union [str , DisableFileOutputParameters ]]] = None ,
292290 pipeline_config : Optional [Dict [str , Any ]] = None ,
293291 budget_type : str = None ,
294292 init_params : Optional [Dict [str , Any ]] = None ,
@@ -297,14 +295,75 @@ def eval_function(
297295 search_space_updates : Optional [HyperparameterSearchSpaceUpdates ] = None ,
298296 instance : str = None ,
299297) -> None :
298+ """
299+ This closure allows the communication between the ExecuteTaFuncWithQueue and the
300+ pipeline trainer (FitEvaluator).
301+
302+ Fundamentally, smac calls the ExecuteTaFuncWithQueue.run() method, which internally
303+ builds a FitEvaluator. The FitEvaluator builds a pipeline, stores the output files
304+ to disk via the backend, and puts the performance result of the run in the queue.
305+
306+
307+ Args:
308+ backend (Backend):
309+ An object to interface with the disk storage. In particular, it allows
310+ access to the train and test datasets.
311+ queue (Queue):
312+ Each worker available will instantiate an evaluator, and after completion,
313+ it will return the evaluation result via a multiprocessing queue
314+ metric (autoPyTorchMetric):
315+ A scorer object that is able to evaluate how good a pipeline was fit. It
315+ is a wrapper on top of the actual score method (for example, a wrapper on top
316+ of scikit-learn's accuracy) that formats the predictions accordingly.
318+ budget (float):
319+ The number of epochs or the amount of time a configuration is allowed to run.
320+ budget_type (str):
321+ The budget type, which can be epochs or time
322+ pipeline_config (Optional[Dict[str, Any]]):
323+ Defines the content of the pipeline being evaluated. For example, it
324+ contains pipeline specific settings like logging name, or whether or not
325+ to use tensorboard.
326+ config (Union[int, str, Configuration]):
327+ Determines the pipeline to be constructed.
328+ seed (int):
329+ An integer that allows for reproducibility of results.
330+ output_y_hat_optimization (bool):
331+ Whether this worker should output the target predictions, so that they are
332+ stored on disk. Fundamentally, the resampling strategy might shuffle the
333+ Y_train targets, so we store the split in order to re-use them for ensemble
334+ selection.
335+ num_run (Optional[int]):
336+ An identifier of the current configuration being fit. This number is unique per
337+ configuration.
338+ include (Optional[Dict[str, Any]]):
339+ An optional dictionary to include components of the pipeline steps.
340+ exclude (Optional[Dict[str, Any]]):
341+ An optional dictionary to exclude components of the pipeline steps.
342+ disable_file_output (Optional[List[Union[str, DisableFileOutputParameters]]]):
343+ By default, the model, its predictions and other metadata are stored on disk
344+ for each finished configuration. This argument allows the user to skip
345+ saving certain file types, for example the model, from being written to disk.
346+ init_params (Optional[Dict[str, Any]]):
347+ Optional argument that is passed to each pipeline step. It is the equivalent of
348+ kwargs for the pipeline steps.
349+ logger_port (Optional[int]):
350+ Logging is performed using a socket-server scheme to be robust against many
351+ parallel entities that want to write to the same file. This integer states the
352+ socket port for the communication channel. If None is provided, a traditional
353+ logger is used.
354+ instance (str):
355+ An instance on which to evaluate the current pipeline. By default we work
356+ with a single instance, being the provided X_train, y_train of a single dataset.
357+ This instance is a compatibility argument for SMAC, which is capable of working
358+ with multiple datasets at the same time.
359+ """
300360 evaluator = FitEvaluator (
301361 backend = backend ,
302362 queue = queue ,
303363 metric = metric ,
304364 configuration = config ,
305365 seed = seed ,
306366 num_run = num_run ,
307- output_y_hat_optimization = output_y_hat_optimization ,
308367 include = include ,
309368 exclude = exclude ,
310369 disable_file_output = disable_file_output ,
0 commit comments