add backbone override to hydra (#997)
* add backbone override to hydra

* make it harder to sneak fields into finetune config

* fix

* fix lint

* add a test to make sure override works

* fix override test by adding cleanup after assertion is caught
misko authored Feb 20, 2025
1 parent 680a3a4 commit 56d77c5
Showing 2 changed files with 43 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/fairchem/core/models/base.py
@@ -251,6 +251,11 @@ def __init__(
        # if finetune_config is provided, then attempt to load the model from the given finetune checkpoint
        starting_model = None
        if finetune_config is not None:
            # Make it hard to sneak more fields into finetune_config
            assert (
                len(set(finetune_config.keys()) - {"starting_checkpoint", "override"})
                == 0
            )
            starting_model: HydraModel = load_model_and_weights_from_checkpoint(
                finetune_config["starting_checkpoint"]
            )
@@ -260,6 +265,10 @@ def __init__(
            assert isinstance(
                starting_model, HydraModel
            ), "Can only finetune starting from other hydra models!"
            # TODO: overriding attrs in the backbone like this is a bit hacky
            if "override" in finetune_config:
                for key, value in finetune_config["override"].items():
                    setattr(starting_model.backbone, key, value)

        if backbone is not None:
            backbone = copy.deepcopy(backbone)
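For context, the tightened key check above means any unexpected field in finetune_config now fails fast when the model is constructed. A minimal sketch of the behavior follows; the config values and the otf_graph key are hypothetical, used only to illustrate the set difference the assertion relies on.

# Hypothetical configs; only the key names matter for the new assertion.
allowed = {"starting_checkpoint", "override"}

ok_config = {"starting_checkpoint": "ckpt.pt", "override": {"otf_graph": True}}
bad_config = {"starting_checkpoint": "ckpt.pt", "extra_field": 123}

print(set(ok_config) - allowed)   # set() -> passes the assert in base.py
print(set(bad_config) - allowed)  # {'extra_field'} -> would trip the assert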
34 changes: 34 additions & 0 deletions tests/core/e2e/test_e2e_finetune_hydra.py
@@ -8,6 +8,7 @@
import torch
from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths

from fairchem.core.common import distutils
from fairchem.core.scripts.convert_hydra_to_release import convert_fine_tune_checkpoint


@@ -283,6 +284,39 @@ def test_finetune_hydra_retain_backbone(tutorial_val_src):
)


def test_finetune_hydra_override(tutorial_val_src):
    with tempfile.TemporaryDirectory() as orig_ckpt_dir:
        starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
        # now finetune the model with the checkpoint from the first job
        with tempfile.TemporaryDirectory() as ft_temp_dir:
            ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
            ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt")
            model_config = {
                "name": "hydra",
                "finetune_config": {
                    "starting_checkpoint": starting_ckpt,
                    "override": {"forward": None},
                },
                "heads": {
                    "energy": {"module": "equiformer_v2_energy_head"},
                    "forces": {"module": "equiformer_v2_force_head"},
                },
            }

            # TODO add a better test for override when we get there
            # for now just override .forward() with None
            with pytest.raises(TypeError):
                run_main_with_ft_hydra(
                    tempdir=ft_temp_dir,
                    yaml=ft_yml,
                    data_src=tutorial_val_src,
                    run_args={"seed": 1000},
                    model_config=model_config,
                    output_checkpoint=ck_ft_path,
                )
            distutils.cleanup()


def test_finetune_hydra_data_only(tutorial_val_src):
    with tempfile.TemporaryDirectory() as orig_ckpt_dir:
        starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
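Beyond the failure-injection test above, a regular finetune model config exercising the new override key might look like the sketch below. The checkpoint path and the overridden otf_graph attribute are illustrative placeholders, not values taken from this commit; the heads mirror the test config above.

# Sketch only: checkpoint path and the overridden attribute are hypothetical.
model_config = {
    "name": "hydra",
    "finetune_config": {
        "starting_checkpoint": "/tmp/starting_checkpoint.pt",
        # each key/value pair is applied to the loaded backbone via setattr()
        "override": {"otf_graph": True},
    },
    "heads": {
        "energy": {"module": "equiformer_v2_energy_head"},
        "forces": {"module": "equiformer_v2_force_head"},
    },
}

Because the override is a plain setattr on the loaded backbone, it can replace any backbone attribute, which is exactly what the test above exploits by nulling out forward to force a TypeError.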

