From 9061e15154a1be174f23851f787aafcf1126752f Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Mon, 25 Aug 2025 05:10:44 +0200 Subject: [PATCH 1/4] Fix LightningCLI not using ckpt_path hyperparameters to instantiate classes (#20310, #20801). --- src/lightning/pytorch/CHANGELOG.md | 4 ++++ src/lightning/pytorch/cli.py | 15 +++++++++++++ tests/tests_pytorch/test_cli.py | 34 ++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 10d90d68fcd45..0cd3e99c5f476 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -42,6 +42,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `RichProgressBar` crashing when sanity checking using val dataloader with 0 len ([#21108](https://github.com/Lightning-AI/pytorch-lightning/pull/21108)) + +- Fixed `LightningCLI` not using `ckpt_path` hyperparameters to instantiate classes ([#21116](https://github.com/Lightning-AI/pytorch-lightning/pull/21116)) + + --- ## [2.5.3] - 2025-08-13 diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 225296240674a..2893dfcb54b32 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -16,6 +16,7 @@ import sys from collections.abc import Iterable from functools import partial, update_wrapper +from pathlib import Path from types import MethodType from typing import Any, Callable, Optional, TypeVar, Union @@ -397,6 +398,7 @@ def __init__( main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs) self.setup_parser(run, main_kwargs, subparser_kwargs) self.parse_arguments(self.parser, args) + self._parse_ckpt_path() self.subcommand = self.config["subcommand"] if run else None @@ -551,6 +553,19 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No else: self.config = parser.parse_args(args) + def _parse_ckpt_path(self) -> None: + """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config.""" + if not self.config.get("subcommand"): + return + ckpt_path = self.config[self.config.subcommand].get("ckpt_path") + if ckpt_path and Path(ckpt_path).is_file(): + ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu") + hparams = ckpt.get("hyper_parameters", {}) + hparams.pop("_instantiator", None) + if hparams: + hparams = {self.config.subcommand: {"model": hparams}} + self.config = self.parser.parse_object(hparams, self.config) + def _dump_config(self) -> None: if hasattr(self, "config_dump"): return diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 1b883dda0282a..f081731c978d6 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -487,6 +487,40 @@ def test_lightning_cli_print_config(): assert outval["ckpt_path"] is None +class BoringCkptPathModel(BoringModel): + def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None: + super().__init__() + self.save_hyperparameters() + self.layer = torch.nn.Linear(32, out_dim) + + +def test_lightning_cli_ckpt_path_argument_hparams(cleandir): + class CkptPathCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) + + cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = CkptPathCLI(BoringCkptPathModel) + + assert cli.config.fit.model.out_dim == 3 + assert cli.config.fit.model.hidden_dim == 6 + hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml" + assert hparams_path.is_file() + hparams = yaml.safe_load(hparams_path.read_text()) + assert hparams["out_dim"] == 3 + assert hparams["hidden_dim"] == 6 + + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")) + cli_args = ["predict", f"--ckpt_path={checkpoint_path}"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = CkptPathCLI(BoringCkptPathModel) + + assert cli.config.predict.model.out_dim == 3 + assert cli.config.predict.model.hidden_dim == 6 + assert cli.config_init.predict.model.layer.out_features == 3 + + def test_lightning_cli_submodules(cleandir): class MainModule(BoringModel): def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1): From aeb9e29e8403a6b9a923570998d308e680e26e71 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 26 Aug 2025 07:29:16 +0200 Subject: [PATCH 2/4] Add message about ckpt_path hyperparameters when parsing fails --- src/lightning/pytorch/cli.py | 6 +++++- tests/tests_pytorch/test_cli.py | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 2893dfcb54b32..ed5a7aa9c9d7e 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -564,7 +564,11 @@ def _parse_ckpt_path(self) -> None: hparams.pop("_instantiator", None) if hparams: hparams = {self.config.subcommand: {"model": hparams}} - self.config = self.parser.parse_object(hparams, self.config) + try: + self.config = self.parser.parse_object(hparams, self.config) + except SystemExit: + sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n") + raise def _dump_config(self) -> None: if hasattr(self, "config_dump"): diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index f081731c978d6..50e6a6f721ea1 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -17,7 +17,7 @@ import operator import os import sys -from contextlib import ExitStack, contextmanager, redirect_stdout +from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout from io import StringIO from pathlib import Path from typing import Callable, Optional, Union @@ -520,6 +520,11 @@ def add_arguments_to_parser(self, parser): assert cli.config.predict.model.hidden_dim == 6 assert cli.config_init.predict.model.layer.out_features == 3 + err = StringIO() + with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stderr(err), pytest.raises(SystemExit): + cli = LightningCLI(BoringModel) + assert "Parsing of ckpt_path hyperparameters failed" in err.getvalue() + def test_lightning_cli_submodules(cleandir): class MainModule(BoringModel): From cf41d9c10f9a6cd3c079bc32ae03859547901839 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 2 Sep 2025 10:01:40 +0200 Subject: [PATCH 3/4] Apply suggestions from code review --- src/lightning/pytorch/CHANGELOG.md | 3 --- src/lightning/pytorch/cli.py | 24 +++++++++++++----------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index eb0dcfe02011d..524eacdaafa3d 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -49,9 +49,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed misalignment column while using rich model summary in `DeepSpeedstrategy` ([#21100](https://github.com/Lightning-AI/pytorch-lightning/pull/21100)) - Fixed `RichProgressBar` crashing when sanity checking using val dataloader with 0 len ([#21108](https://github.com/Lightning-AI/pytorch-lightning/pull/21108)) - ---- - ## [2.5.3] - 2025-08-13 ### Changed diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index ed5a7aa9c9d7e..ef434f44a0b04 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -558,17 +558,19 @@ def _parse_ckpt_path(self) -> None: if not self.config.get("subcommand"): return ckpt_path = self.config[self.config.subcommand].get("ckpt_path") - if ckpt_path and Path(ckpt_path).is_file(): - ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu") - hparams = ckpt.get("hyper_parameters", {}) - hparams.pop("_instantiator", None) - if hparams: - hparams = {self.config.subcommand: {"model": hparams}} - try: - self.config = self.parser.parse_object(hparams, self.config) - except SystemExit: - sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n") - raise + if not ckpt_path or not Path(ckpt_path).is_file(): + return + ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu") + hparams = ckpt.get("hyper_parameters", {}) + hparams.pop("_instantiator", None) + if not hparams: + return + hparams = {self.config.subcommand: {"model": hparams}} + try: + self.config = self.parser.parse_object(hparams, self.config) + except SystemExit: + sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n") + raise def _dump_config(self) -> None: if hasattr(self, "config_dump"): From cd95d50cdc4dfda52790325f6c75151d3c812463 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 2 Sep 2025 13:58:21 +0200 Subject: [PATCH 4/4] revert? --- src/lightning/pytorch/cli.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index ef434f44a0b04..91247127f6c87 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -558,19 +558,18 @@ def _parse_ckpt_path(self) -> None: if not self.config.get("subcommand"): return ckpt_path = self.config[self.config.subcommand].get("ckpt_path") - if not ckpt_path or not Path(ckpt_path).is_file(): - return - ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu") - hparams = ckpt.get("hyper_parameters", {}) - hparams.pop("_instantiator", None) - if not hparams: - return - hparams = {self.config.subcommand: {"model": hparams}} - try: - self.config = self.parser.parse_object(hparams, self.config) - except SystemExit: - sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n") - raise + if ckpt_path and Path(ckpt_path).is_file(): + ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu") + hparams = ckpt.get("hyper_parameters", {}) + hparams.pop("_instantiator", None) + if not hparams: + return + hparams = {self.config.subcommand: {"model": hparams}} + try: + self.config = self.parser.parse_object(hparams, self.config) + except SystemExit: + sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n") + raise def _dump_config(self) -> None: if hasattr(self, "config_dump"):