Skip to content

Commit 97e4a32

Browse files
committed
build error fix
1 parent cec65f5 commit 97e4a32

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,15 +364,22 @@ def on_train_batch_end(
364364

365365
# For manual optimization, we save the model state that was captured in training_step
366366
# before the optimizer step. The test case saves this state in model.saved_models.
367-
if hasattr(pl_module, "saved_models") and pl_module.saved_models:
367+
if hasattr(pl_module, "saved_models") and pl_module.saved_models and hasattr(pl_module, "layer"):
368368
latest_step = max(pl_module.saved_models.keys())
369369
# Save the checkpoint with the pre-optimization state
370370
with torch.no_grad():
371371
# Save the current state
372+
if not isinstance(pl_module.layer, torch.nn.Module):
373+
raise TypeError("pl_module.layer must be a torch.nn.Module for state dict operations")
374+
372375
original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
373376
try:
374377
# Restore the pre-optimization state
375-
pl_module.layer.load_state_dict(pl_module.saved_models[latest_step])
378+
saved_state = pl_module.saved_models[latest_step]
379+
if not isinstance(saved_state, dict):
380+
raise TypeError("Saved model state must be a dictionary")
381+
382+
pl_module.layer.load_state_dict(saved_state)
376383
# Save the checkpoint
377384
self._save_topk_checkpoint(trainer, monitor_candidates)
378385
self._save_last_checkpoint(trainer, monitor_candidates)

tests/tests_pytorch/callbacks/test_model_checkpoint_manual_opt.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import torch
66
from torch.utils.data import DataLoader, Dataset
77

8-
from lightning import Trainer
9-
from lightning.pytorch import LightningModule
8+
from lightning.pytorch import LightningModule, Trainer
109
from lightning.pytorch.callbacks import ModelCheckpoint
1110

1211

0 commit comments

Comments (0)