Strictly enforce no unused keys in yaml #996

Merged 1 commit on Feb 5, 2025
src/fairchem/core/_cli_hydra.py (48 additions, 8 deletions)

@@ -18,6 +18,7 @@

 import hydra
 from omegaconf import OmegaConf
+from omegaconf.errors import InterpolationKeyError

 if TYPE_CHECKING:
     from omegaconf import DictConfig
@@ -35,6 +36,9 @@
 logging.basicConfig(level=logging.INFO)


+ALLOWED_TOP_LEVEL_KEYS = {"job", "runner"}
+
+
 class SchedulerType(str, Enum):
     LOCAL = "local"
     SLURM = "slurm"
@@ -74,6 +78,7 @@ class JobConfig:
     scheduler: SchedulerConfig = field(default_factory=lambda: SchedulerConfig)
     log_dir_name: str = "logs"
     checkpoint_dir_name: str = "checkpoint"
+    config_file_name: str = "canonical_config.yaml"
     logger: Optional[dict] = None  # noqa: UP007 python 3.9 requires Optional still

     @property
@@ -84,6 +89,10 @@ def log_dir(self) -> str:
     def checkpoint_dir(self) -> str:
         return os.path.join(self.run_dir, self.timestamp_id, self.checkpoint_dir_name)

+    @property
+    def config_path(self) -> str:
+        return os.path.join(self.run_dir, self.timestamp_id, self.config_file_name)
+

 class Submitit(Checkpointable):
     def __call__(self, dict_config: DictConfig) -> None:
@@ -142,6 +151,33 @@ def map_job_config_to_dist_config(job_cfg: JobConfig) -> dict:
     }


+def get_canonical_config(config: DictConfig) -> DictConfig:
+    # check that every key other than the allowed top-level keys is used in the config:
+    # find all top-level keys that are not in the allowed set
+    all_keys = set(config.keys()).difference(ALLOWED_TOP_LEVEL_KEYS)
+    used_keys = set()
+    for key in all_keys:
+        # make a copy of the config with every key except the one in question
+        copy_cfg = OmegaConf.create({k: v for k, v in config.items() if k != key})
+        try:
+            OmegaConf.resolve(copy_cfg)
+        except InterpolationKeyError:
+            # if this error is thrown, the key was actually required
+            used_keys.add(key)
+
+    unused_keys = all_keys.difference(used_keys)
+    if unused_keys:
+        raise ValueError(
+            f"Found unused keys in the config: {unused_keys}; please remove them. Only the keys {ALLOWED_TOP_LEVEL_KEYS}, plus keys that are used as variables, are allowed."
+        )
+
+    # resolve the config to fully replace the variables, then drop all top-level
+    # keys except for ALLOWED_TOP_LEVEL_KEYS
+    OmegaConf.resolve(config)
+    return OmegaConf.create(
+        {k: v for k, v in config.items() if k in ALLOWED_TOP_LEVEL_KEYS}
+    )
+
+
 def get_hydra_config_from_yaml(
     config_yml: str, overrides_args: list[str]
 ) -> DictConfig:
@@ -166,20 +202,24 @@ def main(
     args, override_args = parser.parse_known_args()

     cfg = get_hydra_config_from_yaml(args.config, override_args)
-    # merge default structured config with job
+    # merge default structured config with initialized job object
     cfg = OmegaConf.merge({"job": OmegaConf.structured(JobConfig)}, cfg)
-    log_dir = OmegaConf.to_object(cfg.job).log_dir
-    os.makedirs(cfg.job.run_dir, exist_ok=True)
+    # canonicalize config (remove top-level keys that are only used as replacement variables)
+    cfg = get_canonical_config(cfg)
+    job_obj = OmegaConf.to_object(cfg.job)
+    log_dir = job_obj.log_dir
+    os.makedirs(job_obj.run_dir, exist_ok=True)
     os.makedirs(log_dir, exist_ok=True)

-    job_cfg = cfg.job
-    scheduler_cfg = cfg.job.scheduler
+    OmegaConf.save(cfg, job_obj.config_path)
+    logging.info(f"saved canonical config to {job_obj.config_path}")

+    scheduler_cfg = job_obj.scheduler
     logging.info(f"Running fairchemv2 cli with {cfg}")
     if scheduler_cfg.mode == SchedulerType.SLURM:  # Run on cluster
         executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3)
         executor.update_parameters(
-            name=job_cfg.run_name,
+            name=job_obj.run_name,
             mem_gb=scheduler_cfg.slurm.mem_gb,
             timeout_min=scheduler_cfg.slurm.timeout_hr * 60,
             slurm_partition=scheduler_cfg.slurm.partition,
@@ -192,13 +232,13 @@
             )
         job = executor.submit(runner_wrapper, cfg)
         logging.info(
-            f"Submitted job id: {job_cfg.timestamp_id}, slurm id: {job.job_id}, logs: {job_cfg.log_dir}"
+            f"Submitted job id: {job_obj.timestamp_id}, slurm id: {job.job_id}, logs: {job_obj.log_dir}"
         )
     else:
         from torch.distributed.launcher.api import LaunchConfig, elastic_launch

         if scheduler_cfg.ranks_per_node > 1:
-            logging.info(f"Running in local mode with {job_cfg.ranks_per_node} ranks")
+            logging.info(f"Running in local mode with {job_obj.ranks_per_node} ranks")
             launch_config = LaunchConfig(
                 min_nodes=1,
                 max_nodes=1,
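A standalone sketch of the probe used by get_canonical_config above (the config contents here are hypothetical): each candidate top-level key is dropped in turn and the remaining config is resolved; if resolution fails with InterpolationKeyError, the key is referenced as a variable and therefore counts as used.

from omegaconf import OmegaConf
from omegaconf.errors import InterpolationKeyError

cfg = OmegaConf.create(
    {
        "job": {"run_name": "demo"},
        "runner": {"lr": "${base_lr}"},  # interpolates a top-level variable
        "base_lr": 0.001,  # referenced via ${base_lr}, so it counts as used
        "stale": 42,  # referenced nowhere, so it is unused
    }
)

allowed = {"job", "runner"}
candidates = set(cfg.keys()) - allowed  # {"base_lr", "stale"}
used = set()
for key in candidates:
    # resolve a copy with `key` removed; a failure means the key was needed
    probe = OmegaConf.create({k: v for k, v in cfg.items() if k != key})
    try:
        OmegaConf.resolve(probe)
    except InterpolationKeyError:
        used.add(key)

print(candidates - used)  # {'stale'}: exactly what would trigger the ValueError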
src/fairchem/core/components/runner.py (2 additions, 1 deletion)

@@ -39,9 +39,10 @@ def load_state(self) -> None:

 # Used for testing
 class MockRunner(Runner):
-    def __init__(self, x: int, y: int):
+    def __init__(self, x: int, y: int, z: int):
         self.x = x
         self.y = y
+        self.z = z

     def run(self) -> Any:
         if self.x * self.y > 1000:
tests/core/test_hydra_cli.py (16 additions, 2 deletions)

@@ -5,7 +5,7 @@
 import hydra
 import pytest

-from fairchem.core._cli_hydra import main
+from fairchem.core._cli_hydra import ALLOWED_TOP_LEVEL_KEYS, main
 from fairchem.core.common import distutils

@@ -39,8 +39,22 @@ def test_hydra_cli_throws_error_on_invalid_inputs():
         "-c",
         "tests/core/test_hydra_cli.yml",
         "runner.x=1000",
-        "runner.z=5",  # z is not a valid input argument to runner
+        "runner.a=5",  # a is not a valid input argument to runner
     ]
     sys.argv[1:] = sys_args
     with pytest.raises(hydra.errors.ConfigCompositionException):
         main()
+
+
+def test_hydra_cli_throws_error_on_disallowed_top_level_keys():
+    distutils.cleanup()
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    assert "x" not in ALLOWED_TOP_LEVEL_KEYS
+    sys_args = [
+        "-c",
+        "tests/core/test_hydra_cli.yml",
+        "+x=1000",  # not allowed: adds a top-level key that is not in ALLOWED_TOP_LEVEL_KEYS
+    ]
+    sys.argv[1:] = sys_args
+    with pytest.raises(ValueError):
+        main()
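The same failure can be reproduced without going through the CLI by calling get_canonical_config directly; the config literal below is a hypothetical stand-in for what the merged config looks like after the +x=1000 override:

import pytest
from omegaconf import OmegaConf

from fairchem.core._cli_hydra import get_canonical_config

cfg = OmegaConf.create(
    {
        "job": {},
        "runner": {"x": 10, "y": 23, "z": 5},
        "x": 1000,  # never interpolated and not in ALLOWED_TOP_LEVEL_KEYS
    }
)
with pytest.raises(ValueError):
    get_canonical_config(cfg)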
tests/core/test_hydra_cli.yml (3 additions, 0 deletions)

@@ -3,7 +3,10 @@ job:
   scheduler:
     mode: LOCAL

+replacement_var: 5
+
 runner:
   _target_: fairchem.core.components.runner.MockRunner
   x: 10
   y: 23
+  z: ${replacement_var}
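For reference, a sketch of how this yaml flows through the new code path, assuming it is run from the repository root: get_canonical_config resolves z: ${replacement_var} to 5 and strips the helper key, after which Hydra can build the runner from its _target_ (in the real CLI the config is first merged with the structured JobConfig; that step is skipped here).

import hydra
from omegaconf import OmegaConf

from fairchem.core._cli_hydra import get_canonical_config

cfg = OmegaConf.load("tests/core/test_hydra_cli.yml")
cfg = get_canonical_config(cfg)  # replacement_var is used as a variable, so no error
print(OmegaConf.to_yaml(cfg))  # only the `job` and `runner` keys remain; z is now 5
runner = hydra.utils.instantiate(cfg.runner)  # MockRunner(x=10, y=23, z=5)
assert runner.z == 5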