finetuning with hydra config
xiangfu committed Feb 15, 2025
1 parent a94109f commit d015afe
Showing 2 changed files with 27 additions and 7 deletions.
src/fairchem/core/common/utils.py: 26 additions & 6 deletions
@@ -30,6 +30,7 @@
 from typing import TYPE_CHECKING, Any
 from uuid import uuid4
 
+import hydra
 import numpy as np
 import torch
 import torch.nn as nn
@@ -1433,18 +1434,37 @@ def update_config(base_config):
     return config
 
 
-def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module:
+def load_model_and_weights_from_checkpoint(
+    checkpoint_path: str, overrides: dict | None = None
+) -> nn.Module:
     if not os.path.isfile(checkpoint_path):
         raise FileNotFoundError(
             errno.ENOENT, "Checkpoint file not found", checkpoint_path
         )
     logging.info(f"Loading checkpoint from: {checkpoint_path}")
     checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
-    # this assumes the checkpoint also contains the config with the full model in it
-    # TODO: need to schematize how we save and load the config from checkpoint
-    config = checkpoint["config"]["model"]
-    name = config.pop("name")
-    model = registry.get_model_class(name)(**config)
+
+    if "model_config" in checkpoint:
+        model_config = checkpoint["model_config"]
+    elif "config" in checkpoint:
+        model_config = checkpoint["config"]["model"]
+    else:
+        raise ValueError("Model config not found in checkpoint")
+
+    if overrides is not None:
+        for k, v in overrides.items():
+            if k not in model_config["backbone"]:
+                raise ValueError(f"Override key <{k}> not found in model config")
+            orig_v = model_config["backbone"][k]
+            logging.info(f"Overriding {k}: <{orig_v}> to <{v}>")
+            model_config["backbone"][k] = v
+
+    if "model_config" in checkpoint:
+        model = hydra.utils.instantiate(model_config)
+    elif "config" in checkpoint:
+        name = model_config.pop("name")
+        model = registry.get_model_class(name)(**model_config)
+
     matched_dict = match_state_dict(model.state_dict(), checkpoint["state_dict"])
     load_state_dict(model, matched_dict, strict=True)
     return model
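The new loading path supports two checkpoint layouts: a top-level model_config entry, which is built directly with hydra.utils.instantiate (Hydra instantiates whatever class the config's _target_ field names), and the legacy config["model"] entry, which still goes through the model registry. A minimal usage sketch of the new overrides argument, where the checkpoint path and the override key are illustrative assumptions rather than values from this commit:

    from fairchem.core.common.utils import load_model_and_weights_from_checkpoint

    # Load pretrained weights and swap one backbone hyperparameter before the
    # model is instantiated. The key must already exist in
    # model_config["backbone"]; unknown keys raise a ValueError.
    model = load_model_and_weights_from_checkpoint(
        "checkpoints/pretrained.pt",    # assumed path
        overrides={"otf_graph": True},  # assumed key; any existing backbone key works
    )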
src/fairchem/core/models/base.py: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ def __init__(
         starting_model = None
         if finetune_config is not None:
             starting_model: HydraModel = load_model_and_weights_from_checkpoint(
-                finetune_config["starting_checkpoint"]
+                finetune_config["starting_checkpoint"], finetune_config.get("overrides", None)
             )
             logging.info(
                 f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)"
