|
27 | 27 |
|
28 | 28 | import pandas as pd |
29 | 29 |
|
30 | | -from smac.runhistory.runhistory import DataOrigin, RunHistory, RunInfo, RunValue |
| 30 | +from smac.runhistory.runhistory import DataOrigin, RunHistory |
31 | 31 | from smac.stats.stats import Stats |
32 | 32 | from smac.tae import StatusType |
33 | 33 |
|
@@ -593,11 +593,16 @@ def _load_models(self) -> bool: |
593 | 593 | raise ValueError("Resampling strategy is needed to determine what models to load") |
594 | 594 | self.ensemble_ = self._backend.load_ensemble(self.seed) |
595 | 595 |
|
596 | | - if isinstance(self._disable_file_output, List): |
597 | | - disabled_file_outputs = self._disable_file_output |
| 596 | + # TODO: remove this code after `fit_pipeline` is rebased. |
| 597 | + if hasattr(self, '_disable_file_output'): |
| 598 | + if isinstance(self._disable_file_output, List): |
| 599 | + disabled_file_outputs = self._disable_file_output |
| 600 | + disable_file_output = False |
| 601 | + elif isinstance(self._disable_file_output, bool): |
| 602 | + disable_file_output = self._disable_file_output |
| 603 | + disabled_file_outputs = [] |
| 604 | + else: |
598 | 605 | disable_file_output = False |
599 | | - elif isinstance(self._disable_file_output, bool): |
600 | | - disable_file_output = self._disable_file_output |
601 | 606 | disabled_file_outputs = [] |
602 | 607 |
|
603 | 608 | # If no ensemble is loaded, try to get the best performing model |
@@ -901,18 +906,15 @@ def run_traditional_ml( |
901 | 906 | learning algorithm runs over the time limit. |
902 | 907 | """ |
903 | 908 | assert self._logger is not None # for mypy compliancy |
904 | | - if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS: |
905 | | - self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...") |
906 | | - else: |
907 | | - traditional_task_name = 'runTraditional' |
908 | | - self._stopwatch.start_task(traditional_task_name) |
909 | | - elapsed_time = self._stopwatch.wall_elapsed(current_task_name) |
910 | | - time_for_traditional = int(runtime_limit - elapsed_time) |
911 | | - self._do_traditional_prediction( |
912 | | - func_eval_time_limit_secs=func_eval_time_limit_secs, |
913 | | - time_left=time_for_traditional, |
914 | | - ) |
915 | | - self._stopwatch.stop_task(traditional_task_name) |
| 909 | + traditional_task_name = 'runTraditional' |
| 910 | + self._stopwatch.start_task(traditional_task_name) |
| 911 | + elapsed_time = self._stopwatch.wall_elapsed(current_task_name) |
| 912 | + time_for_traditional = int(runtime_limit - elapsed_time) |
| 913 | + self._do_traditional_prediction( |
| 914 | + func_eval_time_limit_secs=func_eval_time_limit_secs, |
| 915 | + time_left=time_for_traditional, |
| 916 | + ) |
| 917 | + self._stopwatch.stop_task(traditional_task_name) |
916 | 918 |
|
917 | 919 | def _search( |
918 | 920 | self, |
@@ -1282,22 +1284,7 @@ def _search( |
1282 | 1284 | self._logger.info("Starting Shutdown") |
1283 | 1285 |
|
1284 | 1286 | if proc_ensemble is not None: |
1285 | | - self._results_manager.ensemble_performance_history = list(proc_ensemble.history) |
1286 | | - |
1287 | | - if len(proc_ensemble.futures) > 0: |
1288 | | - # Also add ensemble runs that did not finish within smac time |
1289 | | - # and add them into the ensemble history |
1290 | | - self._logger.info("Ensemble script still running, waiting for it to finish.") |
1291 | | - result = proc_ensemble.futures.pop().result() |
1292 | | - if result: |
1293 | | - ensemble_history, _, _, _ = result |
1294 | | - self._results_manager.ensemble_performance_history.extend(ensemble_history) |
1295 | | - self._logger.info("Ensemble script finished, continue shutdown.") |
1296 | | - |
1297 | | - # save the ensemble performance history file |
1298 | | - if len(self.ensemble_performance_history) > 0: |
1299 | | - pd.DataFrame(self.ensemble_performance_history).to_json( |
1300 | | - os.path.join(self._backend.internals_directory, 'ensemble_history.json')) |
| 1287 | + self._collect_results_ensemble(proc_ensemble) |
1301 | 1288 |
|
1302 | 1289 | if load_models: |
1303 | 1290 | self._logger.info("Loading models...") |
@@ -1557,7 +1544,7 @@ def fit_pipeline( |
1557 | 1544 | exclude=self.exclude_components, |
1558 | 1545 | search_space_updates=self.search_space_updates) |
1559 | 1546 | dataset_properties = dataset.get_dataset_properties(dataset_requirements) |
1560 | | - self._backend.replace_datamanager(dataset) |
| 1547 | + self._backend.save_datamanager(dataset) |
1561 | 1548 |
|
1562 | 1549 | if self._logger is None: |
1563 | 1550 | self._logger = self._get_logger(dataset.dataset_name) |
@@ -1747,7 +1734,7 @@ def fit_ensemble( |
1747 | 1734 | ensemble_fit_task_name = 'EnsembleFit' |
1748 | 1735 | self._stopwatch.start_task(ensemble_fit_task_name) |
1749 | 1736 | if enable_traditional_pipeline: |
1750 | | - if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_for_task: |
| 1737 | + if func_eval_time_limit_secs > time_for_task: |
1751 | 1738 | self._logger.warning( |
1752 | 1739 | 'Time limit for a single run is higher than total time ' |
1753 | 1740 | 'limit. Capping the limit for a single run to the total ' |
@@ -1788,12 +1775,8 @@ def fit_ensemble( |
1788 | 1775 | ) |
1789 | 1776 |
|
1790 | 1777 | manager.build_ensemble(self._dask_client) |
1791 | | - future = manager.futures.pop() |
1792 | | - result = future.result() |
1793 | | - if result is None: |
1794 | | - raise ValueError("Errors occurred while building the ensemble - please" |
1795 | | - " check the log file and command line output for error messages.") |
1796 | | - self.ensemble_performance_history, _, _, _ = result |
| 1778 | + if manager is not None: |
| 1779 | + self._collect_results_ensemble(manager) |
1797 | 1780 |
|
1798 | 1781 | if load_models: |
1799 | 1782 | self._load_models() |
@@ -1871,6 +1854,31 @@ def _init_ensemble_builder( |
1871 | 1854 |
|
1872 | 1855 | return proc_ensemble |
1873 | 1856 |
|
| 1857 | + def _collect_results_ensemble( |
| 1858 | + self, |
| 1859 | + manager: EnsembleBuilderManager |
| 1860 | + ) -> None: |
| 1861 | + |
| 1862 | + if self._logger is None: |
| 1863 | + raise ValueError("logger should be initialized to fit ensemble") |
| 1864 | + |
| 1865 | + self._results_manager.ensemble_performance_history = list(manager.history) |
| 1866 | + |
| 1867 | + if len(manager.futures) > 0: |
| 1868 | + # Also add ensemble runs that did not finish within smac time |
| 1869 | + # and add them into the ensemble history |
| 1870 | + self._logger.info("Ensemble script still running, waiting for it to finish.") |
| 1871 | + result = manager.futures.pop().result() |
| 1872 | + if result: |
| 1873 | + ensemble_history, _, _, _ = result |
| 1874 | + self._results_manager.ensemble_performance_history.extend(ensemble_history) |
| 1875 | + self._logger.info("Ensemble script finished, continue shutdown.") |
| 1876 | + |
| 1877 | + # save the ensemble performance history file |
| 1878 | + if len(self.ensemble_performance_history) > 0: |
| 1879 | + pd.DataFrame(self.ensemble_performance_history).to_json( |
| 1880 | + os.path.join(self._backend.internals_directory, 'ensemble_history.json')) |
| 1881 | + |
1874 | 1882 | def predict( |
1875 | 1883 | self, |
1876 | 1884 | X_test: np.ndarray, |
|
0 commit comments