diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py index 9cc9f135e..b60206dbe 100644 --- a/src/fairchem/core/_cli_hydra.py +++ b/src/fairchem/core/_cli_hydra.py @@ -32,10 +32,12 @@ from submitit.helpers import Checkpointable, DelayedSubmission from fairchem.core.common import distutils +from fairchem.core.common.logger import WandBSingletonLogger from fairchem.core.common.utils import ( get_commit_hash, get_timestamp_uid, setup_env_vars, + setup_logging, ) # this effects the cli only since the actual job will be run in subprocesses or remoe @@ -139,22 +141,27 @@ def _set_deterministic_mode() -> None: class Submitit(Checkpointable): + def __init__(self) -> None: + self.config = None + self.runner = None + def __call__(self, dict_config: DictConfig) -> None: self.config = dict_config # TODO also load job config here setup_env_vars() + setup_logging() distutils.setup(map_job_config_to_dist_config(self.config.job)) self._init_logger() _set_seeds(self.config.job.seed) if self.config.job.deterministic: _set_deterministic_mode() - runner: Runner = hydra.utils.instantiate(dict_config.runner) - runner.config = self.config + self.runner: Runner = hydra.utils.instantiate(dict_config.runner) + self.runner.config = self.config # must call resume state AFTER the runner has been initialized if self.config.job.runner_state_path: - runner.load_state(self.config.job.runner_state_path) - runner.run() + self.runner.load_state(self.config.job.runner_state_path) + self.runner.run() distutils.cleanup() def _init_logger(self) -> None: @@ -164,7 +171,7 @@ def _init_logger(self) -> None: and not self.config.job.debug ): # get a partial function from the config and instantiate wandb with it - # currently this assume we use a wandb logger + # currently code assumes that we only use the WandBSingletonLogger logger_initializer = hydra.utils.instantiate(self.config.job.logger) simple_config = OmegaConf.to_container( self.config, resolve=True, throw_on_missing=True @@ -177,13 +184,14 @@ def _init_logger(self) -> None: ) def checkpoint(self, *args, **kwargs) -> DelayedSubmission: - # TODO: this is yet to be tested properly - logging.info("Submitit checkpointing callback is triggered") + logging.error("Submitit checkpointing callback is triggered") save_path = self.config.job.metadata.preemption_checkpoint_dir self.runner.save_state(save_path) - logging.info("Submitit checkpointing callback is completed") cfg_copy = self.config.copy() cfg_copy.job.runner_state_path = save_path + if WandBSingletonLogger.initialized(): + WandBSingletonLogger.get_instance().mark_preempting() + logging.info("Submitit checkpointing callback is completed") return DelayedSubmission(Submitit(), cfg_copy) @@ -240,7 +248,11 @@ def get_hydra_config_from_yaml( config_directory = os.path.dirname(os.path.abspath(config_yml)) config_name = os.path.basename(config_yml) hydra.initialize_config_dir(config_directory, version_base="1.1") - return hydra.compose(config_name=config_name, overrides=overrides_args) + cfg = hydra.compose(config_name=config_name, overrides=overrides_args) + # merge default structured config with initialized job object + cfg = OmegaConf.merge({"job": OmegaConf.structured(JobConfig)}, cfg) + # canonicalize config (remove top level keys that just used replacing variables) + return get_canonical_config(cfg) def runner_wrapper(config: DictConfig): @@ -256,10 +268,6 @@ def main( args, override_args = parser.parse_known_args() cfg = get_hydra_config_from_yaml(args.config, override_args) - # merge default structured config with initialized job object - cfg = OmegaConf.merge({"job": OmegaConf.structured(JobConfig)}, cfg) - # canonicalize config (remove top level keys that just used replacing variables) - cfg = get_canonical_config(cfg) log_dir = cfg.job.metadata.log_dir os.makedirs(cfg.job.run_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) @@ -283,7 +291,7 @@ def main( slurm_qos=scheduler_cfg.slurm.qos, slurm_account=scheduler_cfg.slurm.account, ) - job = executor.submit(runner_wrapper, cfg) + job = executor.submit(Submitit(), cfg) logging.info( f"Submitted job id: {cfg.job.timestamp_id}, slurm id: {job.job_id}, logs: {cfg.job.metadata.log_dir}" ) diff --git a/tests/core/test_hydra_cli.py b/tests/core/test_hydra_cli.py index 65de8433a..baac71e47 100644 --- a/tests/core/test_hydra_cli.py +++ b/tests/core/test_hydra_cli.py @@ -5,7 +5,11 @@ import hydra import pytest -from fairchem.core._cli_hydra import ALLOWED_TOP_LEVEL_KEYS, main +from fairchem.core._cli_hydra import ( + ALLOWED_TOP_LEVEL_KEYS, + get_hydra_config_from_yaml, + main, +) from fairchem.core.common import distutils @@ -58,3 +62,12 @@ def test_hydra_cli_throws_error_on_disallowed_top_level_keys(): sys.argv[1:] = sys_args with pytest.raises(ValueError): main() + + +def get_cfg_from_yaml(): + yaml = "tests/core/test_hydra_cli.yml" + cfg = get_hydra_config_from_yaml(yaml) + # assert fields got initialized properly + assert cfg.job.run_name is not None + assert cfg.job.seed is not None + assert cfg.keys() == ALLOWED_TOP_LEVEL_KEYS