Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed preventing recursive symlink creation iwhen `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))


- Fixed `LightningCLI` loading of hyperparameters from `ckpt_path` failing for subclass model mode ([#21246](https://github.com/Lightning-AI/pytorch-lightning/pull/21246))


---

## [2.5.5] - 2025-09-05
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,11 @@ def _parse_ckpt_path(self) -> None:
hparams.pop("_instantiator", None)
if not hparams:
return
if "_class_path" in hparams:
hparams = {
"class_path": hparams["_class_path"],
"dict_kwargs": {k: v for k, v in hparams.items() if k != "_class_path"},
}
hparams = {self.config.subcommand: {"model": hparams}}
try:
self.config = self.parser.parse_object(hparams, self.config)
Expand Down
36 changes: 36 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ class BoringCkptPathModel(BoringModel):
def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
super().__init__()
self.save_hyperparameters()
self.hidden_dim = hidden_dim
self.layer = torch.nn.Linear(32, out_dim)


Expand Down Expand Up @@ -526,6 +527,41 @@ def add_arguments_to_parser(self, parser):
assert "Parsing of ckpt_path hyperparameters failed" in err.getvalue()


class BoringCkptPathSubclass(BoringCkptPathModel):
def __init__(self, extra: bool = True, **kwargs) -> None:
super().__init__(**kwargs)
self.extra = extra


def test_lightning_cli_ckpt_path_argument_hparams_subclass_mode(cleandir):
class CkptPathCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("model.init_args.out_dim", "model.init_args.hidden_dim", compute_fn=lambda x: x * 2)

cli_args = ["fit", "--model=BoringCkptPathSubclass", "--model.out_dim=4", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)

assert cli.config.fit.model.class_path.endswith(".BoringCkptPathSubclass")
assert cli.config.fit.model.init_args == Namespace(out_dim=4, hidden_dim=8, extra=True)
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())
assert hparams["out_dim"] == 4
assert hparams["hidden_dim"] == 8
assert hparams["extra"] is True

checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
cli_args = ["predict", "--model=BoringCkptPathModel", f"--ckpt_path={checkpoint_path}"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = CkptPathCLI(BoringCkptPathModel, subclass_mode_model=True)

assert isinstance(cli.model, BoringCkptPathSubclass)
assert cli.model.hidden_dim == 8
assert cli.model.extra is True
assert cli.model.layer.out_features == 4


def test_lightning_cli_submodules(cleandir):
class MainModule(BoringModel):
def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
Expand Down
Loading