From 2716f92abc3ce62fcefd02ce766e806c02a21140 Mon Sep 17 00:00:00 2001 From: rgao Date: Tue, 11 Feb 2025 06:59:12 +0000 Subject: [PATCH] modify runner and cli for state checkpointing --- src/fairchem/core/_cli_hydra.py | 36 +++++++++++++++++++------- src/fairchem/core/components/runner.py | 16 +++++------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py index 6c6906832..086cce976 100644 --- a/src/fairchem/core/_cli_hydra.py +++ b/src/fairchem/core/_cli_hydra.py @@ -40,6 +40,11 @@ ALLOWED_TOP_LEVEL_KEYS = {"job", "runner"} +LOG_DIR_NAME = "logs" +CHECKPOINT_DIR_NAME = "checkpoints" +CONFIG_FILE_NAME = "canonical_config.yaml" +PREEMPTION_STATE_DIR_NAME = "preemption_state" + class SchedulerType(str, Enum): LOCAL = "local" @@ -79,24 +84,31 @@ class JobConfig: device_type: DeviceType = DeviceType.CUDA debug: bool = False scheduler: SchedulerConfig = field(default_factory=lambda: SchedulerConfig) - log_dir_name: str = "logs" - checkpoint_dir_name: str = "checkpoint" - config_file_name: str = "canonical_config.yaml" logger: Optional[dict] = None # noqa: UP007 python 3.9 requires Optional still seed: int = 0 deterministic: bool = False + runner_state: Optional[str] = None # noqa: UP007 python 3.9 requires Optional still @property def log_dir(self) -> str: - return os.path.join(self.run_dir, self.timestamp_id, self.log_dir_name) + return os.path.join(self.run_dir, self.timestamp_id, LOG_DIR_NAME) @property def checkpoint_dir(self) -> str: - return os.path.join(self.run_dir, self.timestamp_id, self.checkpoint_dir_name) + return os.path.join(self.run_dir, self.timestamp_id, CHECKPOINT_DIR_NAME) + + @property + def preemption_checkpoint_dir(self) -> str: + return os.path.join( + self.run_dir, + self.timestamp_id, + self.checkpoint_dir_name, + PREEMPTION_STATE_DIR_NAME, + ) @property def config_path(self) -> str: - return os.path.join(self.run_dir, self.timestamp_id, self.config_file_name) + return os.path.join(self.run_dir, self.timestamp_id, CONFIG_FILE_NAME) def _set_seeds(seed: int) -> None: @@ -127,8 +139,10 @@ def __call__(self, dict_config: DictConfig) -> None: _set_deterministic_mode() runner: Runner = hydra.utils.instantiate(dict_config.runner) - runner.job_config = self.job_config - runner.load_state() + runner.config = self.config + # must call resume state AFTER the runner has been initialized + if self.job_config.runner_state: + runner.load_state(self.job_config.runner_state) runner.run() distutils.cleanup() @@ -155,9 +169,11 @@ def checkpoint(self, *args, **kwargs) -> DelayedSubmission: # TODO: this is yet to be tested properly logging.info("Submitit checkpointing callback is triggered") new_runner = Submitit() - self.runner.save_state() + self.runner.save_state(self.job_config.preemption_checkpoint_dir) logging.info("Submitit checkpointing callback is completed") - return DelayedSubmission(new_runner, self.config, self.cli_args) + cfg_copy = self.config.copy() + cfg_copy.job_config.runner_state = self.job_config.preemption_checkpoint_dir + return DelayedSubmission(new_runner, cfg_copy, self.cli_args) def map_job_config_to_dist_config(job_cfg: JobConfig) -> dict: diff --git a/src/fairchem/core/components/runner.py b/src/fairchem/core/components/runner.py index f25f6076e..7bff496c2 100644 --- a/src/fairchem/core/components/runner.py +++ b/src/fairchem/core/components/runner.py @@ -6,8 +6,6 @@ if TYPE_CHECKING: from omegaconf import DictConfig - from fairchem.core._cli_hydra import JobConfig - class Runner(metaclass=ABCMeta): """ @@ -17,23 +15,23 @@ class Runner(metaclass=ABCMeta): """ @property - def job_config(self) -> JobConfig: - return self._job_config + def config(self) -> DictConfig: + return self._config - @job_config.setter - def job_config(self, cfg: DictConfig): - self._job_config = cfg + @config.setter + def config(self, cfg: DictConfig): + self._config = cfg @abstractmethod def run(self) -> Any: raise NotImplementedError @abstractmethod - def save_state(self) -> None: + def save_state(self, checkpoint_location: str) -> None: raise NotImplementedError @abstractmethod - def load_state(self) -> None: + def load_state(self, checkpoint_location: str) -> None: raise NotImplementedError