@@ -354,10 +354,48 @@ class BackboneFinetuning(BaseFinetuning):
354354
355355 Example::
356356
357- >>> from lightning.pytorch import Trainer
357+ >>> import torch
358+ >>> import torch.nn as nn
359+ >>> from lightning.pytorch import LightningModule, Trainer
358360 >>> from lightning.pytorch.callbacks import BackboneFinetuning
361+ >>> import torchvision.models as models
362+ >>>
363+ >>> class TransferLearningModel(LightningModule):
364+ ... def __init__(self, num_classes=10):
365+ ... super().__init__()
366+ ... # REQUIRED: Your model must have a 'backbone' attribute
367+ ... self.backbone = models.resnet50(weights="DEFAULT")
368+ ... # Remove the final classification layer from backbone
369+ ... self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
370+ ...
371+ ... # Add your task-specific head (ResNet-50's pooled features are 2048-dim)
372+ ... self.head = nn.Sequential(
373+ ... nn.Flatten(),
374+ ... nn.Linear(2048, 512),
375+ ... nn.ReLU(),
376+ ... nn.Linear(512, num_classes)
377+ ... )
378+ ...
379+ ... def forward(self, x):
380+ ... # Extract features with backbone
381+ ... features = self.backbone(x)
382+ ... # Classify with head
383+ ... return self.head(features)
384+ ...
385+ ... def configure_optimizers(self):
386+ ... # Optimize only the head at first; the callback adds the backbone params once unfrozen
387+ ... return torch.optim.Adam(self.head.parameters(), lr=1e-3)
388+ ...
389+ >>> # Set up the callback
359390 >>> multiplicative = lambda epoch: 1.5
360- >>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
391+ >>> backbone_finetuning = BackboneFinetuning(
392+ ... unfreeze_backbone_at_epoch=10, # Unfreeze the backbone at epoch 10
393+ ... lambda_func=multiplicative, # Gradually increase backbone LR
394+ ... backbone_initial_ratio_lr=0.1, # Start backbone at 10% of head LR
395+ ... )
396+ >>> model = TransferLearningModel()
361397 >>> trainer = Trainer(callbacks=[backbone_finetuning])
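398+ >>> # Calling trainer.fit(model) would keep the backbone frozen until epoch 10,
399+ >>> # then train it at 10% of the head LR, scaled up by 1.5x each epoch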
362400
363401 """