Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions rslp/helios/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Helios model wrapper for fine-tuning in rslearn."""

import json
-import os
from contextlib import nullcontext
from typing import Any

Expand All @@ -12,6 +11,7 @@
from helios.train.masking import MaskedHeliosSample, MaskValue
from olmo_core.config import Config
from olmo_core.distributed.checkpoint import load_model_and_optim_state
+from upath import UPath

from rslp.log_utils import get_logger

Expand Down Expand Up @@ -63,6 +63,7 @@ def __init__(
autocast_dtype: which dtype to use for autocasting, or set None to disable.
"""
super().__init__()
+_checkpoint_path = UPath(checkpoint_path)
self.forward_kwargs = forward_kwargs
self.embedding_size = embedding_size
self.patch_size = patch_size
Expand All @@ -75,17 +76,17 @@ def __init__(
# Load the model config and initialize it.
# We avoid loading the train module here because it depends on running within
# olmo_core.
-with open(f"{checkpoint_path}/config.json") as f:
+with (_checkpoint_path / "config.json").open() as f:
config_dict = json.load(f)
model_config = Config.from_dict(config_dict["model"])

model = model_config.build()

# Load the checkpoint.
if not random_initialization:
-train_module_dir = os.path.join(checkpoint_path, "model_and_optim")
-if os.path.exists(train_module_dir):
-load_model_and_optim_state(train_module_dir, model)
+train_module_dir = _checkpoint_path / "model_and_optim"
+if train_module_dir.exists():
+load_model_and_optim_state(str(train_module_dir), model)
logger.info(f"loaded helios encoder from {train_module_dir}")
else:
logger.info(f"could not find helios encoder at {train_module_dir}")
Expand Down
Loading