Fix LightningCLI not using ckpt_path hyperparameters to instantiate classes #21116
Conversation
Pull Request Overview
This PR fixes an issue where LightningCLI was not properly using hyperparameters from checkpoint files when instantiating model classes. The fix adds functionality to automatically parse and apply hyperparameters from checkpoint files when a ckpt_path is specified.
- Adds automatic parsing of hyperparameters from checkpoint files when ckpt_path is provided
- Includes proper error handling for failed hyperparameter parsing
- Updates changelog to document the fix
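As a rough orientation, a simplified sketch of this kind of checkpoint-hyperparameter pass is shown below. The helper name and the exact merge logic are illustrative assumptions rather than the code added to cli.py; only the "hyper_parameters" checkpoint key and parser.parse_object are referenced elsewhere in this conversation.

# Illustrative sketch only, not the actual LightningCLI implementation.
# Assumes a jsonargparse-style parser exposing parse_object(), and a Lightning
# checkpoint that stores hyperparameters under the "hyper_parameters" key.
import torch

def apply_ckpt_hparams(parser, config, ckpt_path):
    """Merge hyperparameters found in a checkpoint into an already-parsed CLI config."""
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    hparams = checkpoint.get("hyper_parameters", {})
    if not hparams:
        return config
    try:
        # Re-parse so the checkpoint values are applied on top of the current config.
        return parser.parse_object({"model": dict(hparams)}, config)
    except Exception as err:
        raise RuntimeError(f"Failed to apply hyperparameters from {ckpt_path!r}: {err}") from err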
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
src/lightning/pytorch/cli.py | Implements the core fix by adding _parse_ckpt_path() method to extract and apply hyperparameters from checkpoints |
tests/tests_pytorch/test_cli.py | Adds comprehensive test coverage for the checkpoint hyperparameter parsing functionality |
src/lightning/pytorch/CHANGELOG.md | Documents the bug fix in the changelog |
Fix LightningCLI not using ckpt_path hyperparameters to instantiate classes (#21116)
* Fix LightningCLI not using ckpt_path hyperparameters to instantiate classes
* Add message about ckpt_path hyperparameters when parsing fails
* Apply suggestions from code review
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka B <[email protected]>
(cherry picked from commit 4824cc1)
@Borda @SkafteNicki This PR seems to have broken backwards compatibility of checkpoints...
@isaaccorley dammit, thanks for reporting it. So the problem is that checkpoints saved with <v2.5.5 cannot be loaded using v2.5.5, is that correctly understood?
Correct, see the references to this PR above from torchgeo and others.
Needs to be resolved as P0.
I am not sure about this. This pull request doesn't change the saving of checkpoints, so I don't see why checkpoints saved with <v2.5.5 and >=v2.5.5 would be any different. This pull request fixes a bug, which was that hyperparameters found inside the checkpoint were not being considered. It could be that there is a bug in the fix, but it would be good to confirm that it is a bug, and not that other code depended on the buggy behavior that this was supposed to fix. The case of mehta-lab/VisCy#291 seems a bit different: "checkpoints trained with one wrapper class is reused for another". If the hyperparameters in the checkpoint are not compatible with the other wrapper class, then it makes sense that it doesn't work. I am not sure whether this is something to officially support, since there are cases in which it wouldn't work, like when the shapes of the tensors differ, which is what this fix was about. Anyway, for the VisCy case I guess there would need to be a way to disable the loading of hyperparameters found in the checkpoint.
@mauvilsa the error message introduced in this PR |
For our use case this would be great. Specifically, this needs to be exposed at the CLI/YAML level so the existing dependency injection-based workflow can continue to function. |
@ziw-liu you can already disable the parsing of hyperparameters in the checkpoint by doing as below, the caveat being that it is not an officially public way of doing it.

class MyCLI(LightningCLI):
    def _parse_ckpt_path(self):
        pass
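A hypothetical end-to-end usage of that override might look like the following; MyModel and MyDataModule are placeholder names standing in for your own classes, not anything defined in this thread.

# Hypothetical usage of the override above; MyModel and MyDataModule are placeholders.
from lightning.pytorch.cli import LightningCLI
from my_project import MyModel, MyDataModule  # placeholder imports

class MyCLI(LightningCLI):
    def _parse_ckpt_path(self):
        # Intentionally skip applying hyperparameters from the ckpt_path checkpoint.
        pass

def cli_main():
    # Behaves like LightningCLI, except ckpt_path hyperparameters are ignored.
    MyCLI(MyModel, MyDataModule)

if __name__ == "__main__":
    cli_main()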
@isaaccorley the error message only happens when self.parser.parse_object(hparams, self.config) fails to parse. Why the hyperparameters are invalid in the case of torchgeo and not in the unit test, I don't know. But it doesn't look like a problem with some if/else.
@mauvilsa Not sure if this provides more insight, but these are the hyperparameters stored in the checkpoint:

{
    '_class_path': 'torchgeo.trainers.ClassificationTask',
    'model': 'resnet18',
    'in_channels': 13,
    'task': 'multiclass',
    'num_classes': 10,
    'num_labels': None,
    'loss': 'ce',
    'class_weights': None,
    'lr': 0.001,
    'patience': 10,
    'freeze_backbone': False,
    '_instantiator': 'lightning.pytorch.cli.instantiate_module'
}

Here is an example trained checkpoint.
I think the issue is that I didn't consider the subclass mode case. There is just a pop of
Got it! That makes sense because we have a BaseTask module that has some common logic, and then we have several subtasks for segmentation/classification/detection that inherit from this BaseTask!
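As an illustration only (not taken from the PR): in LightningCLI's subclass mode the model config is a nested class_path/init_args structure, whereas the checkpoint hyperparameters shown above are a flat dict carrying a _class_path key, roughly as in the following sketch.

# Flat hyperparameters as stored in the checkpoint (shape shown earlier in this thread):
ckpt_hparams = {
    "_class_path": "torchgeo.trainers.ClassificationTask",
    "model": "resnet18",
    "in_channels": 13,
    # ... remaining init arguments from the dict above ...
}

# Nested structure that LightningCLI's subclass mode expects for the model:
subclass_mode_config = {
    "model": {
        "class_path": "torchgeo.trainers.ClassificationTask",
        "init_args": {"model": "resnet18", "in_channels": 13},
    }
}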
@mauvilsa is this something you will be able to fix?
I created pull request #21246. @ziw-liu could you please try out the branch of the fix to see if it fixes the problem? At first glance it seemed to me like a different issue, but it could be that it is just the same bug. @isaaccorley can you please test the fix with torchgeo?
@mauvilsa I think we have a different issue than what #21246 fixes, where we want to load the state dict in a different lightning module. E.g. in the following:

import torch
from lightning.pytorch import LightningModule


class TrainingModule(LightningModule):
    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        self.model = torch.nn.Linear(16, 2)
        self.lr = lr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        dataset = torch.utils.data.TensorDataset(
            torch.rand(32, 16), torch.randint(0, 2, (32,))
        )
        return torch.utils.data.DataLoader(dataset, batch_size=8)


class InferenceModule(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Linear(16, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def predict_dataloader(self) -> torch.utils.data.DataLoader:
        dataset = torch.utils.data.TensorDataset(torch.rand(32, 16))
        return torch.utils.data.DataLoader(dataset, batch_size=128)

    def predict_step(
        self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = 0
    ) -> torch.Tensor:
        x = batch
        y_hat = self(x)
        return y_hat
@ziw-liu thank you for the additional details of your use case. Using a different class for inference is not the standard case, so it seems to require a new feature request. It might be better to create a GitHub issue so there is a better place to discuss it. I might have an idea to propose already.
Now I think there is another issue with this fix. The load of checkpoint hparams works the same for any subcommand. However,
Regarding this, I described a proposal in #21255. Whoever is interested, please have a look there.
What does this PR do?
Fixes #20310
Fixes #20801
📚 Documentation preview 📚: https://pytorch-lightning--21116.org.readthedocs.build/en/21116/