Skip to content

Commit

Permalink
Fix preemption behavior and combine config init (#1016)
Browse files Browse the repository at this point in the history
* combine and test config init

* fix preemption

* update

---------

Co-authored-by: rgao <rgao@meta>
  • Loading branch information
rayg1234 and rgao authored Feb 15, 2025
1 parent a94109f commit 7bc606e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
36 changes: 22 additions & 14 deletions src/fairchem/core/_cli_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@
from submitit.helpers import Checkpointable, DelayedSubmission

from fairchem.core.common import distutils
from fairchem.core.common.logger import WandBSingletonLogger
from fairchem.core.common.utils import (
get_commit_hash,
get_timestamp_uid,
setup_env_vars,
setup_logging,
)

# this affects the CLI only, since the actual job will be run in subprocesses or remotely
Expand Down Expand Up @@ -139,22 +141,27 @@ def _set_deterministic_mode() -> None:


class Submitit(Checkpointable):
def __init__(self) -> None:
self.config = None
self.runner = None

def __call__(self, dict_config: DictConfig) -> None:
self.config = dict_config
# TODO also load job config here
setup_env_vars()
setup_logging()
distutils.setup(map_job_config_to_dist_config(self.config.job))
self._init_logger()
_set_seeds(self.config.job.seed)
if self.config.job.deterministic:
_set_deterministic_mode()

runner: Runner = hydra.utils.instantiate(dict_config.runner)
runner.config = self.config
self.runner: Runner = hydra.utils.instantiate(dict_config.runner)
self.runner.config = self.config
# must call resume state AFTER the runner has been initialized
if self.config.job.runner_state_path:
runner.load_state(self.config.job.runner_state_path)
runner.run()
self.runner.load_state(self.config.job.runner_state_path)
self.runner.run()
distutils.cleanup()

def _init_logger(self) -> None:
Expand All @@ -164,7 +171,7 @@ def _init_logger(self) -> None:
and not self.config.job.debug
):
# get a partial function from the config and instantiate wandb with it
# currently this assume we use a wandb logger
# currently code assumes that we only use the WandBSingletonLogger
logger_initializer = hydra.utils.instantiate(self.config.job.logger)
simple_config = OmegaConf.to_container(
self.config, resolve=True, throw_on_missing=True
Expand All @@ -177,13 +184,14 @@ def _init_logger(self) -> None:
)

def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
# TODO: this is yet to be tested properly
logging.info("Submitit checkpointing callback is triggered")
logging.error("Submitit checkpointing callback is triggered")
save_path = self.config.job.metadata.preemption_checkpoint_dir
self.runner.save_state(save_path)
logging.info("Submitit checkpointing callback is completed")
cfg_copy = self.config.copy()
cfg_copy.job.runner_state_path = save_path
if WandBSingletonLogger.initialized():
WandBSingletonLogger.get_instance().mark_preempting()
logging.info("Submitit checkpointing callback is completed")
return DelayedSubmission(Submitit(), cfg_copy)


Expand Down Expand Up @@ -240,7 +248,11 @@ def get_hydra_config_from_yaml(
config_directory = os.path.dirname(os.path.abspath(config_yml))
config_name = os.path.basename(config_yml)
hydra.initialize_config_dir(config_directory, version_base="1.1")
return hydra.compose(config_name=config_name, overrides=overrides_args)
cfg = hydra.compose(config_name=config_name, overrides=overrides_args)
# merge default structured config with initialized job object
cfg = OmegaConf.merge({"job": OmegaConf.structured(JobConfig)}, cfg)
# canonicalize config (remove top-level keys that are only used for variable replacement)
return get_canonical_config(cfg)


def runner_wrapper(config: DictConfig):
Expand All @@ -256,10 +268,6 @@ def main(
args, override_args = parser.parse_known_args()

cfg = get_hydra_config_from_yaml(args.config, override_args)
# merge default structured config with initialized job object
cfg = OmegaConf.merge({"job": OmegaConf.structured(JobConfig)}, cfg)
# canonicalize config (remove top-level keys that are only used for variable replacement)
cfg = get_canonical_config(cfg)
log_dir = cfg.job.metadata.log_dir
os.makedirs(cfg.job.run_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
Expand All @@ -283,7 +291,7 @@ def main(
slurm_qos=scheduler_cfg.slurm.qos,
slurm_account=scheduler_cfg.slurm.account,
)
job = executor.submit(runner_wrapper, cfg)
job = executor.submit(Submitit(), cfg)
logging.info(
f"Submitted job id: {cfg.job.timestamp_id}, slurm id: {job.job_id}, logs: {cfg.job.metadata.log_dir}"
)
Expand Down
15 changes: 14 additions & 1 deletion tests/core/test_hydra_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import hydra
import pytest

from fairchem.core._cli_hydra import ALLOWED_TOP_LEVEL_KEYS, main
from fairchem.core._cli_hydra import (
ALLOWED_TOP_LEVEL_KEYS,
get_hydra_config_from_yaml,
main,
)
from fairchem.core.common import distutils


Expand Down Expand Up @@ -58,3 +62,12 @@ def test_hydra_cli_throws_error_on_disallowed_top_level_keys():
sys.argv[1:] = sys_args
with pytest.raises(ValueError):
main()


def get_cfg_from_yaml():
    """Load the test YAML via ``get_hydra_config_from_yaml`` and check the result.

    Asserts that ``job.run_name`` and ``job.seed`` are populated and that the
    config's top-level keys equal ``ALLOWED_TOP_LEVEL_KEYS``.
    """
    # NOTE(review): the name lacks the ``test_`` prefix, so pytest's default
    # discovery rules will not collect this function -- confirm whether it
    # should be renamed to ``test_get_cfg_from_yaml``.
    # NOTE(review): the path is relative, so this presumably expects to be run
    # from the repository root -- confirm.
    yaml = "tests/core/test_hydra_cli.yml"
    cfg = get_hydra_config_from_yaml(yaml)
    # assert fields got initialized properly
    assert cfg.job.run_name is not None
    assert cfg.job.seed is not None
    assert cfg.keys() == ALLOWED_TOP_LEVEL_KEYS

0 comments on commit 7bc606e

Please sign in to comment.