Skip to content

Commit

Permalink
modify runner and cli for state checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
rgao committed Feb 11, 2025
1 parent 7ef2813 commit 2716f92
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 19 deletions.
36 changes: 26 additions & 10 deletions src/fairchem/core/_cli_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@

ALLOWED_TOP_LEVEL_KEYS = {"job", "runner"}

LOG_DIR_NAME = "logs"
CHECKPOINT_DIR_NAME = "checkpoints"
CONFIG_FILE_NAME = "canonical_config.yaml"
PREEMPTION_STATE_DIR_NAME = "preemption_state"


class SchedulerType(str, Enum):
LOCAL = "local"
Expand Down Expand Up @@ -79,24 +84,31 @@ class JobConfig:
device_type: DeviceType = DeviceType.CUDA
debug: bool = False
scheduler: SchedulerConfig = field(default_factory=lambda: SchedulerConfig)
log_dir_name: str = "logs"
checkpoint_dir_name: str = "checkpoint"
config_file_name: str = "canonical_config.yaml"
logger: Optional[dict] = None # noqa: UP007 python 3.9 requires Optional still
seed: int = 0
deterministic: bool = False
runner_state: Optional[str] = None # noqa: UP007 python 3.9 requires Optional still

@property
def log_dir(self) -> str:
return os.path.join(self.run_dir, self.timestamp_id, self.log_dir_name)
return os.path.join(self.run_dir, self.timestamp_id, LOG_DIR_NAME)

@property
def checkpoint_dir(self) -> str:
return os.path.join(self.run_dir, self.timestamp_id, self.checkpoint_dir_name)
return os.path.join(self.run_dir, self.timestamp_id, CHECKPOINT_DIR_NAME)

@property
def preemption_checkpoint_dir(self) -> str:
return os.path.join(
self.run_dir,
self.timestamp_id,
self.checkpoint_dir_name,
PREEMPTION_STATE_DIR_NAME,
)

@property
def config_path(self) -> str:
return os.path.join(self.run_dir, self.timestamp_id, self.config_file_name)
return os.path.join(self.run_dir, self.timestamp_id, CONFIG_FILE_NAME)


def _set_seeds(seed: int) -> None:
Expand Down Expand Up @@ -127,8 +139,10 @@ def __call__(self, dict_config: DictConfig) -> None:
_set_deterministic_mode()

runner: Runner = hydra.utils.instantiate(dict_config.runner)
runner.job_config = self.job_config
runner.load_state()
runner.config = self.config
# must call resume state AFTER the runner has been initialized
if self.job_config.runner_state:
runner.load_state(self.job_config.runner_state)
runner.run()
distutils.cleanup()

Expand All @@ -155,9 +169,11 @@ def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
# TODO: this is yet to be tested properly
logging.info("Submitit checkpointing callback is triggered")
new_runner = Submitit()
self.runner.save_state()
self.runner.save_state(self.job_config.preemption_checkpoint_dir)
logging.info("Submitit checkpointing callback is completed")
return DelayedSubmission(new_runner, self.config, self.cli_args)
cfg_copy = self.config.copy()
cfg_copy.job_config.runner_state = self.job_config.preemption_checkpoint_dir
return DelayedSubmission(new_runner, cfg_copy, self.cli_args)


def map_job_config_to_dist_config(job_cfg: JobConfig) -> dict:
Expand Down
16 changes: 7 additions & 9 deletions src/fairchem/core/components/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
if TYPE_CHECKING:
from omegaconf import DictConfig

from fairchem.core._cli_hydra import JobConfig


class Runner(metaclass=ABCMeta):
"""
Expand All @@ -17,23 +15,23 @@ class Runner(metaclass=ABCMeta):
"""

@property
def job_config(self) -> JobConfig:
return self._job_config
def config(self) -> DictConfig:
return self._config

@job_config.setter
def job_config(self, cfg: DictConfig):
self._job_config = cfg
@config.setter
def config(self, cfg: DictConfig):
self._config = cfg

@abstractmethod
def run(self) -> Any:
raise NotImplementedError

@abstractmethod
def save_state(self) -> None:
def save_state(self, checkpoint_location: str) -> None:
raise NotImplementedError

@abstractmethod
def load_state(self) -> None:
def load_state(self, checkpoint_location: str) -> None:
raise NotImplementedError


Expand Down

0 comments on commit 2716f92

Please sign in to comment.