Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions rslp/helios/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Helios model wrapper for fine-tuning in rslearn."""

import json
import os
from contextlib import nullcontext
from typing import Any

Expand All @@ -12,6 +11,7 @@
from helios.train.masking import MaskedHeliosSample, MaskValue
from olmo_core.config import Config
from olmo_core.distributed.checkpoint import load_model_and_optim_state
from upath import UPath

from rslp.log_utils import get_logger

Expand Down Expand Up @@ -63,6 +63,7 @@ def __init__(
autocast_dtype: which dtype to use for autocasting, or set None to disable.
"""
super().__init__()
_checkpoint_path = UPath(checkpoint_path)
self.forward_kwargs = forward_kwargs
self.embedding_size = embedding_size
self.patch_size = patch_size
Expand All @@ -75,16 +76,16 @@ def __init__(
# Load the model config and initialize it.
# We avoid loading the train module here because it depends on running within
# olmo_core.
with open(f"{checkpoint_path}/config.json") as f:
with (_checkpoint_path / "config.json").open() as f:
config_dict = json.load(f)
model_config = Config.from_dict(config_dict["model"])

model = model_config.build()

# Load the checkpoint.
if not random_initialization:
train_module_dir = os.path.join(checkpoint_path, "model_and_optim")
if os.path.exists(train_module_dir):
train_module_dir = _checkpoint_path / "model_and_optim"
if train_module_dir.exists():
load_model_and_optim_state(train_module_dir, model)
logger.info(f"loaded helios encoder from {train_module_dir}")
else:
Expand Down
Loading