Commit: Merge branch 'main' into fix_openneb_load

misko authored Feb 5, 2025
2 parents 9085f48 + f4c510c commit b6e4404
Showing 5 changed files with 70 additions and 12 deletions.
2 changes: 1 addition & 1 deletion packages/fairchem-core/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [

 [project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev]
 dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"]
-docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "umap-learn", "vdict"]
+docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "astroid<4", "umap-learn", "vdict"]
 adsorbml = ["dscribe","x3dase","scikit-image"]

 [project.scripts]
56 changes: 48 additions & 8 deletions src/fairchem/core/_cli_hydra.py
@@ -18,6 +18,7 @@

 import hydra
 from omegaconf import OmegaConf
+from omegaconf.errors import InterpolationKeyError

 if TYPE_CHECKING:
     from omegaconf import DictConfig
@@ -35,6 +36,9 @@
 logging.basicConfig(level=logging.INFO)


+ALLOWED_TOP_LEVEL_KEYS = {"job", "runner"}


 class SchedulerType(str, Enum):
     LOCAL = "local"
     SLURM = "slurm"
@@ -74,6 +78,7 @@ class JobConfig:
     scheduler: SchedulerConfig = field(default_factory=lambda: SchedulerConfig)
     log_dir_name: str = "logs"
     checkpoint_dir_name: str = "checkpoint"
+    config_file_name: str = "canonical_config.yaml"
     logger: Optional[dict] = None  # noqa: UP007 python 3.9 requires Optional still

     @property
@@ -84,6 +89,10 @@ def log_dir(self) -> str:
     def checkpoint_dir(self) -> str:
         return os.path.join(self.run_dir, self.timestamp_id, self.checkpoint_dir_name)

+    @property
+    def config_path(self) -> str:
+        return os.path.join(self.run_dir, self.timestamp_id, self.config_file_name)


 class Submitit(Checkpointable):
     def __call__(self, dict_config: DictConfig) -> None:
@@ -142,6 +151,33 @@ def map_job_config_to_dist_config(job_cfg: JobConfig) -> dict:
     }


+def get_canonical_config(config: DictConfig) -> DictConfig:
+    # check that every top-level key other than the allowed ones is actually used in the config:
+    # first, find all top-level keys that are not in the allowed set
+    all_keys = set(config.keys()).difference(ALLOWED_TOP_LEVEL_KEYS)
+    used_keys = set()
+    for key in all_keys:
+        # make a copy of the config with every key except the one in question
+        copy_cfg = OmegaConf.create({k: v for k, v in config.items() if k != key})
+        try:
+            OmegaConf.resolve(copy_cfg)
+        except InterpolationKeyError:
+            # if this error is thrown, the key was actually required
+            used_keys.add(key)
+
+    unused_keys = all_keys.difference(used_keys)
+    if unused_keys != set():
+        raise ValueError(
+            f"Found unused keys in the config: {unused_keys}, please remove them! Only keys in {ALLOWED_TOP_LEVEL_KEYS}, or ones that are used as interpolation variables, are allowed."
+        )
+
+    # resolve the config to fully replace the variables, then drop all top-level keys except the ALLOWED_TOP_LEVEL_KEYS
+    OmegaConf.resolve(config)
+    return OmegaConf.create(
+        {k: v for k, v in config.items() if k in ALLOWED_TOP_LEVEL_KEYS}
+    )


 def get_hydra_config_from_yaml(
     config_yml: str, overrides_args: list[str]
 ) -> DictConfig:
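Note: a minimal sketch of how the new get_canonical_config behaves, using values
that mirror tests/core/test_hydra_cli.yml (the concrete config below is
illustrative, not part of the commit):

    from omegaconf import OmegaConf
    from fairchem.core._cli_hydra import get_canonical_config

    cfg = OmegaConf.create({
        "replacement_var": 5,  # only referenced via the interpolation below
        "job": {},
        "runner": {"x": 10, "y": 23, "z": "${replacement_var}"},
    })
    canonical = get_canonical_config(cfg)
    # removing replacement_var would raise InterpolationKeyError, so it counts
    # as "used": it is resolved into runner.z and then dropped from the top level
    print(OmegaConf.to_yaml(canonical))  # job: {} ... runner: {x: 10, y: 23, z: 5}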
@@ -166,20 +202,24 @@ def main(
     args, override_args = parser.parse_known_args()

     cfg = get_hydra_config_from_yaml(args.config, override_args)
-    # merge default structured config with job
+    # merge default structured config with initialized job object
     cfg = OmegaConf.merge({"job": OmegaConf.structured(JobConfig)}, cfg)
-    log_dir = OmegaConf.to_object(cfg.job).log_dir
-    os.makedirs(cfg.job.run_dir, exist_ok=True)
+    # canonicalize config (remove top-level keys that are just used for replacing variables)
+    cfg = get_canonical_config(cfg)
+    job_obj = OmegaConf.to_object(cfg.job)
+    log_dir = job_obj.log_dir
+    os.makedirs(job_obj.run_dir, exist_ok=True)
     os.makedirs(log_dir, exist_ok=True)

-    job_cfg = cfg.job
-    scheduler_cfg = cfg.job.scheduler
+    OmegaConf.save(cfg, job_obj.config_path)
+    logging.info(f"saved canonical config to {job_obj.config_path}")

+    scheduler_cfg = job_obj.scheduler
     logging.info(f"Running fairchemv2 cli with {cfg}")
     if scheduler_cfg.mode == SchedulerType.SLURM:  # Run on cluster
         executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3)
         executor.update_parameters(
-            name=job_cfg.run_name,
+            name=job_obj.run_name,
             mem_gb=scheduler_cfg.slurm.mem_gb,
             timeout_min=scheduler_cfg.slurm.timeout_hr * 60,
             slurm_partition=scheduler_cfg.slurm.partition,
@@ -192,13 +232,13 @@
         )
         job = executor.submit(runner_wrapper, cfg)
         logging.info(
-            f"Submitted job id: {job_cfg.timestamp_id}, slurm id: {job.job_id}, logs: {job_cfg.log_dir}"
+            f"Submitted job id: {job_obj.timestamp_id}, slurm id: {job.job_id}, logs: {job_obj.log_dir}"
         )
     else:
         from torch.distributed.launcher.api import LaunchConfig, elastic_launch

         if scheduler_cfg.ranks_per_node > 1:
-            logging.info(f"Running in local mode with {job_cfg.ranks_per_node} ranks")
+            logging.info(f"Running in local mode with {job_obj.ranks_per_node} ranks")
             launch_config = LaunchConfig(
                 min_nodes=1,
                 max_nodes=1,
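Note: with this change, main() canonicalizes the config before scheduling and
persists it next to the logs. A sketch of driving the CLI the way the tests do
(the override value is illustrative):

    import sys
    from fairchem.core._cli_hydra import main

    # "-c" selects the YAML config; trailing args are Hydra-style overrides
    sys.argv[1:] = ["-c", "tests/core/test_hydra_cli.yml", "runner.x=1"]
    # resolves and canonicalizes the config, saves it to
    # <run_dir>/<timestamp_id>/canonical_config.yaml, then launches the runner
    main()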
3 changes: 2 additions & 1 deletion src/fairchem/core/components/runner.py
@@ -39,9 +39,10 @@ def load_state(self) -> None:

 # Used for testing
 class MockRunner(Runner):
-    def __init__(self, x: int, y: int):
+    def __init__(self, x: int, y: int, z: int):
         self.x = x
         self.y = y
+        self.z = z

     def run(self) -> Any:
         if self.x * self.y > 1000:
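Note: MockRunner gains a z argument so the interpolated value in the test config
has somewhere to land. A sketch of instantiating it from the _target_ config,
assuming the usual Hydra instantiate pattern (the actual call site is not shown
in this diff):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    runner_cfg = OmegaConf.create({
        "_target_": "fairchem.core.components.runner.MockRunner",
        "x": 10, "y": 23, "z": 5,
    })
    runner = instantiate(runner_cfg)  # builds MockRunner(x=10, y=23, z=5)
    runner.run()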
18 changes: 16 additions & 2 deletions tests/core/test_hydra_cli.py
@@ -5,7 +5,7 @@
 import hydra
 import pytest

-from fairchem.core._cli_hydra import main
+from fairchem.core._cli_hydra import ALLOWED_TOP_LEVEL_KEYS, main
 from fairchem.core.common import distutils


@@ -39,8 +39,22 @@ def test_hydra_cli_throws_error_on_invalid_inputs():
"-c",
"tests/core/test_hydra_cli.yml",
"runner.x=1000",
"runner.z=5", # z is not a valid input argument to runner
"runner.a=5", # a is not a valid input argument to runner
]
sys.argv[1:] = sys_args
with pytest.raises(hydra.errors.ConfigCompositionException):
main()


def test_hydra_cli_throws_error_on_disallowed_top_level_keys():
distutils.cleanup()
hydra.core.global_hydra.GlobalHydra.instance().clear()
assert "x" not in ALLOWED_TOP_LEVEL_KEYS
sys_args = [
"-c",
"tests/core/test_hydra_cli.yml",
"+x=1000", # this is not allowed because we are adding a key that is not in ALLOWED_TOP_LEVEL_KEYS
]
sys.argv[1:] = sys_args
with pytest.raises(ValueError):
main()
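Note: the new test exercises the ValueError path in get_canonical_config: the
"+x=1000" override adds a top-level key that is neither in
ALLOWED_TOP_LEVEL_KEYS nor referenced by any interpolation. A minimal sketch of
that failure mode:

    from omegaconf import OmegaConf
    from fairchem.core._cli_hydra import get_canonical_config

    cfg = OmegaConf.create({"job": {}, "runner": {}, "x": 1000})
    get_canonical_config(cfg)
    # ValueError: Found unused keys in the config: {'x'}, please remove them! ...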
3 changes: 3 additions & 0 deletions tests/core/test_hydra_cli.yml
@@ -3,7 +3,10 @@ job:
   scheduler:
     mode: LOCAL

+replacement_var: 5
+
 runner:
   _target_: fairchem.core.components.runner.MockRunner
   x: 10
   y: 23
+  z: ${replacement_var}
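Note: a quick sketch of how the new ${replacement_var} interpolation resolves
with OmegaConf:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("tests/core/test_hydra_cli.yml")
    OmegaConf.resolve(cfg)    # replaces ${replacement_var} in place
    assert cfg.runner.z == 5  # picked up from the top-level replacement_var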
