
Commit cef3f92

Authored by hummuscience (Muad Abd El Hay), Claude, and themattinthehatt
Fix checkpoint compatibility for upsampling_layers (#314)
* Fix checkpoint compatibility for upsampling_layers

Older checkpoints may have 'upsampling_layers' parameters without the 'head.' prefix, causing warnings when loading models after head refactoring. This fix remaps these keys during checkpoint loading to ensure backwards compatibility.

Fixes warning: "Found keys that are not in the model state dict but in the checkpoint: ['upsampling_layers.1.weight', 'upsampling_layers.1.bias', ...]"

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

* Add fallback mechanism for checkpoint loading with weights_only=False

- Implement try/except blocks around torch.load() calls in three files
- First attempts standard loading, falls back to weights_only=False on failure
- Provides clear warning messages when fallback is used
- Resolves pickle deserialization errors with older checkpoints
- Maintains security by attempting the safer method first

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

* add more transformer tests

---------

Co-authored-by: Muad Abd El Hay <[email protected]>
Co-authored-by: Claude <[email protected]>
Co-authored-by: Muad Abd El Hay <[email protected]>
Co-authored-by: themattinthehatt <[email protected]>
1 parent 620459a commit cef3f92
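
In short, the patch renames legacy state-dict keys before handing the checkpoint to Lightning. A minimal sketch of that remap (the helper name is hypothetical and not part of the lightning-pose API; the actual change does this inline in load_model_from_checkpoint, shown in the predictions.py diff below):

def remap_upsampling_keys(state_dict: dict) -> dict:
    # prefix legacy 'upsampling_layers.*' keys with 'head.' so old checkpoints
    # line up with the refactored model's parameter names
    return {
        ("head." + key if key.startswith("upsampling_layers.") else key): value
        for key, value in state_dict.items()
    }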

5 files changed: +115 −47 lines

lightning_pose/models/backbones/vits.py

Lines changed: 7 additions & 1 deletion
@@ -63,7 +63,13 @@ def build_backbone(backbone_arch: str, image_size: int = 256, **kwargs):
 
 def load_vit_backbone_checkpoint(base, checkpoint: str):
     print(f"Loading VIT-MAE weights from {checkpoint}")
-    ckpt_vit_pretrain = torch.load(checkpoint, map_location="cpu")
+    # Try loading with default settings first, fallback to weights_only=False if needed
+    try:
+        ckpt_vit_pretrain = torch.load(checkpoint, map_location="cpu")
+    except Exception as e:
+        print(f"Warning: Failed to load checkpoint with default settings: {e}")
+        print("Attempting to load with weights_only=False...")
+        ckpt_vit_pretrain = torch.load(checkpoint, map_location="cpu", weights_only=False)
     # extract state dict if checkpoint contains additional info
     if "state_dict" in ckpt_vit_pretrain:
         ckpt_vit_pretrain = ckpt_vit_pretrain["state_dict"]
lightning_pose/utils/predictions.py

Lines changed: 47 additions & 41 deletions
@@ -559,49 +559,55 @@ def load_model_from_checkpoint(
         map_type=cfg.model.model_type,
         semi_supervised=semi_supervised,
     )
-    # initialize a model instance, with weights loaded from .ckpt file
-    if cfg.model.backbone == "vitb_sam":
-        # see https://github.com/paninski-lab/lightning-pose/issues/134 for explanation of this block
-        from lightning_pose.utils.scripts import get_model
-
-        # load model first
-        model = get_model(
-            cfg,
-            data_module=data_module,
-            loss_factories=loss_factories,
+
+    # initialize a model instance, load weights from .ckpt file (fix state_dict keys if needed)
+    try:
+        checkpoint = torch.load(ckpt_file)
+    except Exception as e:
+        print(f"Warning: Failed to load checkpoint with default settings: {e}")
+        print("Attempting to load with weights_only=False...")
+        checkpoint = torch.load(ckpt_file, weights_only=False)
+    state_dict = checkpoint.get("state_dict", checkpoint)
+
+    # fix state dict key mismatch for upsampling layers
+    # old checkpoints may have 'upsampling_layers' without 'head.' prefix
+    keys_remapped = False
+    for key in list(state_dict.keys()):
+        if key.startswith("upsampling_layers."):
+            # Add 'head.' prefix if missing
+            new_key = "head." + key
+            state_dict[new_key] = state_dict.pop(key)
+            keys_remapped = True
+
+    if keys_remapped:
+        # save the fixed state dict back to checkpoint
+        checkpoint["state_dict"] = state_dict
+        # create a temporary file with the fixed checkpoint
+        import tempfile
+        with tempfile.NamedTemporaryFile(suffix='.ckpt', delete=False) as tmp_file:
+            torch.save(checkpoint, tmp_file.name)
+            fixed_ckpt_file = tmp_file.name
+    else:
+        fixed_ckpt_file = ckpt_file
+
+    if semi_supervised:
+        model = ModelClass.load_from_checkpoint(
+            fixed_ckpt_file,
+            loss_factory=loss_factories["supervised"],
+            loss_factory_unsupervised=loss_factories["unsupervised"],
+            strict=False,
         )
-        # # update model parameter
-        # if model.backbone.pos_embed is not None:
-        #     # re-initialize absolute positional embedding with *finetune* image size.
-        #     finetune_img_size = cfg.data.image_resize_dims.height
-        #     patch_size = model.backbone.patch_size
-        #     embed_dim = 768  # value from lightning_pose.models.backbones.vits.build_backbone
-        #     model.backbone.pos_embed = torch.nn.Parameter(
-        #         torch.zeros(
-        #             1,
-        #             finetune_img_size // patch_size,
-        #             finetune_img_size // patch_size,
-        #             embed_dim,
-        #         )
-        #     )
-        # load weights
-        state_dict = torch.load(ckpt_file)["state_dict"]
-        # put weights into model
-        model.load_state_dict(state_dict, strict=False)
     else:
-        if semi_supervised:
-            model = ModelClass.load_from_checkpoint(
-                ckpt_file,
-                loss_factory=loss_factories["supervised"],
-                loss_factory_unsupervised=loss_factories["unsupervised"],
-                strict=False,
-            )
-        else:
-            model = ModelClass.load_from_checkpoint(
-                ckpt_file,
-                loss_factory=loss_factories["supervised"],
-                strict=False,
-            )
+        model = ModelClass.load_from_checkpoint(
+            fixed_ckpt_file,
+            loss_factory=loss_factories["supervised"],
+            strict=False,
+        )
+
+    # clean up temporary file if created
+    if keys_remapped:
+        import os
+        os.unlink(fixed_ckpt_file)
 
     if eval:
         model.eval()
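
Design note: Lightning's load_from_checkpoint takes a file path rather than an in-memory dict, so when keys are remapped the fixed checkpoint is written to a temporary .ckpt (delete=False) and removed with os.unlink once the model is loaded. A rough usage sketch, assuming the training config is available as cfg; the checkpoint path is illustrative, and the real signature may require further arguments (e.g. data_module or loss factories for semi-supervised models):

from lightning_pose.utils.predictions import load_model_from_checkpoint

model = load_model_from_checkpoint(
    cfg=cfg,                                # OmegaConf config used at training time
    ckpt_file="runs/old_model/model.ckpt",  # older checkpoint without 'head.' prefixes
    eval=True,
)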

lightning_pose/utils/scripts.py

Lines changed: 7 additions & 1 deletion
@@ -489,7 +489,13 @@ def get_model(
         if not ckpt.endswith(".ckpt"):
             import glob
             ckpt = glob.glob(os.path.join(ckpt, "**", "*.ckpt"), recursive=True)[0]
-        state_dict = torch.load(ckpt)["state_dict"]
+        # Try loading with default settings first, fallback to weights_only=False if needed
+        try:
+            state_dict = torch.load(ckpt)["state_dict"]
+        except Exception as e:
+            print(f"Warning: Failed to load checkpoint with default settings: {e}")
+            print("Attempting to load with weights_only=False...")
+            state_dict = torch.load(ckpt, weights_only=False)["state_dict"]
         # try loading all weights
         try:
             model.load_state_dict(state_dict, strict=False)

tests/conftest.py

Lines changed: 9 additions & 3 deletions
@@ -523,8 +523,14 @@ def video_dataloader(cfg, base_dataset, video_list) -> LitDaliWrapper:
 def trainer(cfg) -> pl.Trainer:
     """Create a basic pytorch lightning trainer for testing models."""
 
-    cfg.training.unfreezing_epoch = 1 # exercise unfreezing
-    callbacks = get_callbacks(cfg, early_stopping=False, lr_monitor=False, backbone_unfreeze=True, checkpointing=False)
+    cfg.training.unfreezing_epoch = 1  # exercise unfreezing
+    callbacks = get_callbacks(
+        cfg,
+        early_stopping=False,
+        lr_monitor=False,
+        backbone_unfreeze=True,
+        checkpointing=False,
+    )
 
     trainer = pl.Trainer(
         accelerator="gpu",
@@ -534,7 +540,7 @@ def trainer(cfg) -> pl.Trainer:
         check_val_every_n_epoch=1,
         log_every_n_steps=1,
         callbacks=callbacks,
-        enable_checkpointing = False,
+        enable_checkpointing=False,
         limit_train_batches=2,
         num_sanity_val_steps=0,
         logger=False,
tests/models/test_heatmap_tracker.py

Lines changed: 45 additions & 1 deletion
@@ -24,7 +24,7 @@ def test_supervised_heatmap(
     )
 
 
-def test_supervised_heatmap_vit_sam(
+def test_supervised_heatmap_vitb_sam(
     cfg,
     heatmap_data_module,
     video_dataloader,
@@ -46,6 +46,50 @@ def test_supervised_heatmap_vit_sam(
     )
 
 
+def test_supervised_heatmap_vitb_imagenet(
+    cfg,
+    heatmap_data_module,
+    video_dataloader,
+    trainer,
+    run_model_test,
+):
+    """Test the initialization and training of a supervised heatmap model."""
+
+    cfg_tmp = copy.deepcopy(cfg)
+    cfg_tmp.model.model_type = "heatmap"
+    cfg_tmp.model.backbone = "vitb_imagenet"
+    cfg_tmp.model.losses_to_use = []
+
+    run_model_test(
+        cfg=cfg_tmp,
+        data_module=heatmap_data_module,
+        video_dataloader=video_dataloader,
+        trainer=trainer,
+    )
+
+
+def test_supervised_heatmap_vits_dino(
+    cfg,
+    heatmap_data_module,
+    video_dataloader,
+    trainer,
+    run_model_test,
+):
+    """Test the initialization and training of a supervised heatmap model."""
+
+    cfg_tmp = copy.deepcopy(cfg)
+    cfg_tmp.model.model_type = "heatmap"
+    cfg_tmp.model.backbone = "vits_dino"
+    cfg_tmp.model.losses_to_use = []
+
+    run_model_test(
+        cfg=cfg_tmp,
+        data_module=heatmap_data_module,
+        video_dataloader=video_dataloader,
+        trainer=trainer,
+    )
+
+
 def test_supervised_multiview_heatmap(
     cfg_multiview,
     multiview_heatmap_data_module,

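The two new tests mirror the existing vitb_sam test but exercise the vitb_imagenet and vits_dino backbones. They can be selected with pytest's -k filter, e.g. pytest tests/models/test_heatmap_tracker.py -k "vitb_imagenet or vits_dino"; note that the trainer fixture requests a GPU accelerator.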