
Conversation

mauvilsa
Contributor

@mauvilsa commented Aug 25, 2025

What does this PR do?

Fixes #20310
Fixes #20801

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21116.org.readthedocs.build/en/21116/

The github-actions bot added the pl label (Generic label for PyTorch Lightning package) Aug 25, 2025
@mauvilsa marked this pull request as ready for review August 25, 2025 03:17
Contributor

Copilot AI left a comment


Pull Request Overview

This PR fixes an issue where LightningCLI was not properly using hyperparameters from checkpoint files when instantiating model classes. The fix adds functionality to automatically parse and apply hyperparameters from checkpoint files when a ckpt_path is specified.

  • Adds automatic parsing of hyperparameters from checkpoint files when ckpt_path is provided
  • Includes proper error handling for failed hyperparameter parsing
  • Updates changelog to document the fix

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

  • src/lightning/pytorch/cli.py: Implements the core fix by adding a _parse_ckpt_path() method to extract and apply hyperparameters from checkpoints
  • tests/tests_pytorch/test_cli.py: Adds test coverage for the checkpoint hyperparameter parsing functionality
  • src/lightning/pytorch/CHANGELOG.md: Documents the bug fix in the changelog
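For illustration, a minimal sketch of the mechanism described above, written as a standalone helper; this is an assumption-laden sketch, not the actual implementation in src/lightning/pytorch/cli.py.

# Hedged sketch only, not the actual cli.py code: reads hyper_parameters from a
# checkpoint and strips internal keys before they are fed back to the parser.
import torch

def parse_ckpt_hparams(ckpt_path: str) -> dict:
    """Return the hyperparameters stored in a Lightning checkpoint, if any."""
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    hparams = dict(ckpt.get("hyper_parameters", {}))
    # Internal bookkeeping keys are not constructor arguments, so drop them.
    hparams.pop("_instantiator", None)
    return hparams

In the real fix these values are then passed back through the CLI parser (the self.parser.parse_object call discussed further down in this thread), which is where the error handling mentioned above comes in.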


@Borda changed the title: Fix LightningCLI not using ckpt_path hyperparameters to instantiate classes (Sep 2, 2025)
@Borda merged commit 4824cc1 into Lightning-AI:master on Sep 2, 2025
116 of 119 checks passed
@mauvilsa deleted the issue-20801-cli-subcommand-ckpt-path-hparams branch on September 2, 2025 14:03
Borda pushed a commit that referenced this pull request Sep 3, 2025
…te classes (#21116)

* Fix LightningCLI not using ckpt_path hyperparameters to instantiate classes
* Add message about ckpt_path hyperparameters when parsing fails
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka B <[email protected]>

(cherry picked from commit 4824cc1)
lantiga pushed a commit that referenced this pull request Sep 5, 2025
…te classes (#21116)

* Fix LightningCLI not using ckpt_path hyperparameters to instantiate classes
* Add message about ckpt_path hyperparameters when parsing fails
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka B <[email protected]>

(cherry picked from commit 4824cc1)
@isaaccorley

isaaccorley commented Sep 22, 2025

@Borda @SkafteNicki This PR seems to have broken backwards compatibility of checkpoints...

@SkafteNicki
Collaborator

@Borda @SkafteNicki This PR seems to have broken backwards compatibility of checkpoints...

@isaaccorley dammit, thanks for reporting it. So the problem is that checkpoints saved with <v2.5.5 cannot be loaded using v2.5.5, is that correctly understood?

@isaaccorley

@Borda @SkafteNicki This PR seems to have broken backwards compatibility of checkpoints...

@isaaccorley dammit, thanks for reporting it. So the problem is that checkpoints saved with <v2.5.5 cannot be loaded using v2.5.5, is that correctly understood?

Correct, see the references to this PR above from torchgeo and others.

@Borda
Contributor

Borda commented Sep 22, 2025

Needs to be resolved as P0.

@mauvilsa
Contributor Author

So the problem is that checkpoints saved with <v2.5.5 cannot be loaded using v2.5.5, is that correctly understood?

I am not sure about this. This pull request doesn't change the saving of checkpoints, so I don't see why checkpoints saved with <v2.5.5 and >=v2.5.5 would be any different. This pull request fixes a bug, namely that hyperparameters found inside the checkpoint were not being considered. There could be a bug in the fix, but it would be good to confirm that it is a bug, and not that other code depended on the buggy behavior this was supposed to fix.

The case of mehta-lab/VisCy#291 seems a bit different: "checkpoints trained with one wrapper class is reused for another". If the hyperparameters in the checkpoint are not compatible with the other wrapper class, then it makes sense that it doesn't work. I am not sure this is something to officially support, since there are cases in which it wouldn't work, like when the shapes of the tensors differ, which is what this fix was about. Anyway, for the VisCy case I guess there would need to be a way to disable the loading of hyperparameters found in the checkpoint.

@isaaccorley

@mauvilsa the error message introduced in this PR, "Parsing of ckpt_path hyperparameters failed!", is what gets triggered, so it's possible the if/else logic has a bug.

@ziw-liu

ziw-liu commented Sep 22, 2025

for the VisCy case I guess there would need to be a way to disable the loading of hyperparameters found in the checkpoint.

For our use case this would be great. Specifically, this needs to be exposed at the CLI/YAML level so the existing dependency injection-based workflow can continue to function.

@mauvilsa
Contributor Author

@ziw-liu you can already disable the parsing of hyperparameters in the checkpoint by doing as below, the caveat being that it is not an officially public way of doing it.

from lightning.pytorch.cli import LightningCLI

class MyCLI(LightningCLI):
    def _parse_ckpt_path(self):
        pass  # skip parsing of hyperparameters stored in the checkpoint
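For context, a hypothetical entry point continuing the snippet above; BoringModel and BoringDataModule are Lightning's demo classes, standing in for real project classes, and the file name in the comment is only an example.

# Hypothetical usage of the MyCLI override above; run e.g. with
#   python main.py predict --ckpt_path=path/to.ckpt --config=config.yaml
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel

if __name__ == "__main__":
    MyCLI(BoringModel, BoringDataModule)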

@mauvilsa
Contributor Author

@mauvilsa the error message introduced in this PR Parsing of ckpt_path hyperparameters failed! is what gets triggered so it's possible the if/else logic has a bug.

@isaaccorley the error message only happens when self.parser.parse_object(hparams, self.config) fails to parse. Why the hyperparameters are invalid in the torchgeo case but not in the unit test, I don't know, but it doesn't look like a problem with the if/else logic.
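A fragment-level sketch of that flow (hypothetical; the exact exception handling in cli.py is not shown here, and self, hparams and self.config belong to the surrounding CLI method):

# Sketch of the described behavior, not the actual cli.py code: the failure
# message appears only when parse_object() rejects the checkpoint hparams.
try:
    self.config = self.parser.parse_object(hparams, self.config)
except Exception as ex:  # assumption: the real code may catch a narrower type
    raise ValueError("Parsing of ckpt_path hyperparameters failed!") from ex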

@isaaccorley

@mauvilsa Not sure if this provides more insight, but these are the ckpt["hyper_parameters"] of checkpoints saved prior to lightning v2.5.5:

{
   '_class_path': 'torchgeo.trainers.ClassificationTask',
   'model': 'resnet18',
   'in_channels': 13,
   'task': 'multiclass',
   'num_classes': 10,
   'num_labels': None,
   'loss': 'ce',
   'class_weights': None,
   'lr': 0.001,
   'patience': 10,
   'freeze_backbone': False,
   '_instantiator': 'lightning.pytorch.cli.instantiate_module'
}

here is an example trained checkpoint

@mauvilsa
Contributor Author

I think the issue is that I didn't consider the subclass mode case. There is just a pop of _instantiator but no handling of _class_path. That is, something like cli.py#L885-L888 is missing, and a corresponding unit test.
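A rough illustration of the missing handling (hypothetical code, not the actual cli.py lines), using the hparams dict shown above: in subclass mode the flat keys need to be wrapped into the class_path/init_args structure the parser expects.

# Hypothetical illustration of subclass-mode handling; the wrapping mirrors
# jsonargparse's standard {"class_path": ..., "init_args": {...}} form.
hparams = {
    "_class_path": "torchgeo.trainers.ClassificationTask",
    "model": "resnet18",
    "in_channels": 13,
    "_instantiator": "lightning.pytorch.cli.instantiate_module",
}
hparams.pop("_instantiator", None)
class_path = hparams.pop("_class_path", None)
if class_path is not None:
    # In subclass mode the config must name the concrete class explicitly.
    hparams = {"class_path": class_path, "init_args": hparams}
print(hparams)
# {'class_path': 'torchgeo.trainers.ClassificationTask',
#  'init_args': {'model': 'resnet18', 'in_channels': 13}}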

@isaaccorley

I think the issue is that I didn't consider the subclass mode case. There is just a pop of _instantiator but no handling of _class_path. That is, something like cli.py#L885-L888 is missing, and a corresponding unit test.

Got it! That makes sense because we have a BaseTask module that has some common logic, and then several sub tasks for segmentation/classification/detection that inherit from this BaseTask!
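For reference, a simplified sketch of that kind of hierarchy (illustrative names only, not torchgeo's actual classes):

# Simplified, hypothetical BaseTask hierarchy; when instantiated through
# LightningCLI in subclass mode, _class_path and _instantiator end up in the
# saved hyperparameters, as in the dict shown earlier in this thread.
from lightning.pytorch import LightningModule

class BaseTask(LightningModule):
    """Common logic shared by all tasks."""

class ClassificationTask(BaseTask):
    def __init__(self, model: str = "resnet18", num_classes: int = 10, lr: float = 1e-3) -> None:
        super().__init__()
        self.save_hyperparameters()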

@SkafteNicki
Collaborator

@mauvilsa is this something you will be able to fix?

@mauvilsa
Contributor Author

I created pull request #21246.

@ziw-liu could you please try out the branch of the fix to see whether it solves the problem? At first glance it seemed to me like a different issue, but it could be that it is just the same bug.

@isaaccorley can you please test the fix with torchgeo?

@ziw-liu

ziw-liu commented Sep 25, 2025

@mauvilsa I think we have a different issue than what #21246 fixes, where we want to load the state dict in a different lightning module. E.g. in the following we will have lr saved as a hyperparameter in the checkpoint from TrainingModule, but the InferenceModule does not take it as an argument:

import torch
from lightning.pytorch import LightningModule


class TrainingModule(LightningModule):
    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        self.model = torch.nn.Linear(16, 2)
        self.lr = lr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        dataset = torch.utils.data.TensorDataset(
            torch.rand(32, 16), torch.randint(0, 2, (32,))
        )
        return torch.utils.data.DataLoader(dataset, batch_size=8)


class InferenceModule(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Linear(16, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def predict_dataloader(self) -> torch.utils.data.DataLoader:
        dataset = torch.utils.data.TensorDataset(torch.rand(32, 16))
        return torch.utils.data.DataLoader(dataset, batch_size=128)

    def predict_step(
        self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = 0
    ) -> torch.Tensor:
        x = batch
        y_hat = self(x)
        return y_hat
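For completeness, a minimal sketch of the intended weight reuse, continuing the example above ("training.ckpt" is a placeholder path, not a file from this thread):

# Minimal sketch of reusing TrainingModule weights in InferenceModule; only the
# state dict is loaded here, so the extra `lr` hyperparameter is simply unused.
ckpt = torch.load("training.ckpt", map_location="cpu", weights_only=False)
inference_module = InferenceModule()
inference_module.load_state_dict(ckpt["state_dict"])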

@mauvilsa
Contributor Author

@mauvilsa I think we have a different issue than what #21246 fixes, where we want to load the state dict in a different lightning module. E.g. in the following we will have lr saved as a hyperparameter in the checkpoint from TrainingModule, but the InferenceModule does not take it as an argument:

@ziw-liu thank you for the additional details of your use case. Using a different class for inference is not the standard case, so it seems to require a new feature request. It would be better to create a GitHub issue so there is a proper place to discuss it. I might already have an idea to propose.

@mauvilsa
Contributor Author

Now I think there is another issue with this fix. The loading of checkpoint hparams works the same for any subcommand. However, fit also has the ckpt_path parameter, and I guess the only reason someone would provide ckpt_path to fit is to reuse the weights, not the hparams. This could be fixed by not doing it for fit. But maybe it is also necessary to print a warning when hparams are loaded from the checkpoint, and to describe this behavior in the docs.
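A hedged sketch of the kind of warning being suggested (hypothetical, not existing Lightning code; hparams and ckpt_path would come from the parsing step, and the wording is illustrative only):

# Hypothetical warning, illustrative wording only.
from lightning.pytorch.utilities import rank_zero_warn

if hparams:
    rank_zero_warn(
        f"Instantiating classes using hyperparameters found in the checkpoint {ckpt_path!r}; "
        "pass explicit arguments to override them."
    )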

@mauvilsa
Contributor Author

Regarding the use of a different class for inference, I described a proposal in #21255. Whoever is interested, please have a look there.
