@@ -1,7 +1,6 @@
 """Helios model wrapper for fine-tuning in rslearn."""

 import json
-import os
 from contextlib import nullcontext
 from typing import Any

@@ -12,6 +11,7 @@
 from helios.train.masking import MaskedHeliosSample, MaskValue
 from olmo_core.config import Config
 from olmo_core.distributed.checkpoint import load_model_and_optim_state
+from upath import UPath

 from rslp.log_utils import get_logger
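The switch from os.path to UPath is what makes the rest of the diff work against remote checkpoint locations, not just local disk. A minimal sketch of the idea, assuming the universal_pathlib package (which provides upath) and, for the remote case, an fsspec backend such as s3fs; the paths and bucket name below are hypothetical:

from upath import UPath

# UPath exposes the pathlib API over fsspec filesystems, so the same code
# handles a local checkpoint directory and one in object storage.
for root in ["/tmp/checkpoints/step1000", "s3://my-bucket/checkpoints/step1000"]:
    ckpt = UPath(root)
    config_path = ckpt / "config.json"  # the / operator joins paths on any backend
    print(config_path, config_path.exists())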
@@ -63,6 +63,7 @@ def __init__(
             autocast_dtype: which dtype to use for autocasting, or set None to disable.
         """
         super().__init__()
+        _checkpoint_path = UPath(checkpoint_path)
         self.forward_kwargs = forward_kwargs
         self.embedding_size = embedding_size
         self.patch_size = patch_size
@@ -75,17 +76,17 @@ def __init__(
         # Load the model config and initialize it.
         # We avoid loading the train module here because it depends on running within
         # olmo_core.
-        with open(f"{checkpoint_path}/config.json") as f:
+        with (_checkpoint_path / "config.json").open() as f:
             config_dict = json.load(f)
         model_config = Config.from_dict(config_dict["model"])

         model = model_config.build()

         # Load the checkpoint.
         if not random_initialization:
-            train_module_dir = os.path.join(checkpoint_path, "model_and_optim")
-            if os.path.exists(train_module_dir):
-                load_model_and_optim_state(train_module_dir, model)
+            train_module_dir = _checkpoint_path / "model_and_optim"
+            if train_module_dir.exists():
+                load_model_and_optim_state(str(train_module_dir), model)
                 logger.info(f"loaded helios encoder from {train_module_dir}")
             else:
                 logger.info(f"could not find helios encoder at {train_module_dir}")