Skip to content

Commit cb78159

Browse files
authored
Merge pull request #201 from allenai/josh/helios-upath
Use UPath in helios/model.py
2 parents d22e8a9 + 94ea40a commit cb78159

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

rslp/helios/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Helios model wrapper for fine-tuning in rslearn."""
22

33
import json
4-
import os
54
from contextlib import nullcontext
65
from typing import Any
76

@@ -12,6 +11,7 @@
1211
from helios.train.masking import MaskedHeliosSample, MaskValue
1312
from olmo_core.config import Config
1413
from olmo_core.distributed.checkpoint import load_model_and_optim_state
14+
from upath import UPath
1515

1616
from rslp.log_utils import get_logger
1717

@@ -63,6 +63,7 @@ def __init__(
6363
autocast_dtype: which dtype to use for autocasting, or set None to disable.
6464
"""
6565
super().__init__()
66+
_checkpoint_path = UPath(checkpoint_path)
6667
self.forward_kwargs = forward_kwargs
6768
self.embedding_size = embedding_size
6869
self.patch_size = patch_size
@@ -75,17 +76,17 @@ def __init__(
7576
# Load the model config and initialize it.
7677
# We avoid loading the train module here because it depends on running within
7778
# olmo_core.
78-
with open(f"{checkpoint_path}/config.json") as f:
79+
with (_checkpoint_path / "config.json").open() as f:
7980
config_dict = json.load(f)
8081
model_config = Config.from_dict(config_dict["model"])
8182

8283
model = model_config.build()
8384

8485
# Load the checkpoint.
8586
if not random_initialization:
86-
train_module_dir = os.path.join(checkpoint_path, "model_and_optim")
87-
if os.path.exists(train_module_dir):
88-
load_model_and_optim_state(train_module_dir, model)
87+
train_module_dir = _checkpoint_path / "model_and_optim"
88+
if train_module_dir.exists():
89+
load_model_and_optim_state(str(train_module_dir), model)
8990
logger.info(f"loaded helios encoder from {train_module_dir}")
9091
else:
9192
logger.info(f"could not find helios encoder at {train_module_dir}")

0 commit comments

Comments
 (0)