Add finetune function #68

Closed
wants to merge 14 commits
1 change: 1 addition & 0 deletions src/membrain_seg/segmentation/cli/__init__.py
@@ -2,6 +2,7 @@

# These imports are necessary to register CLI commands. Do not remove!
from .cli import cli # noqa: F401
from .fine_tune_cli import finetune # noqa: F401
from .segment_cli import segment # noqa: F401
from .ske_cli import skeletonize # noqa: F401
from .train_cli import data_dir_help, train # noqa: F401
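
As a side note (not part of this diff): because the import above registers the command on the shared Typer app, a quick smoke test could look roughly like the sketch below. It assumes cli is the typer.Typer instance used by the @cli.command decorators, as the imports suggest.

from typer.testing import CliRunner

from membrain_seg.segmentation.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0
# The new command should be listed alongside segment, skeletonize, and train.
assert "finetune" in result.output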
153 changes: 153 additions & 0 deletions src/membrain_seg/segmentation/cli/fine_tune_cli.py
@@ -0,0 +1,153 @@
from typing import List, Optional

from typer import Option
from typing_extensions import Annotated

from ..finetune import fine_tune as _fine_tune
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
from .cli import cli


@cli.command(name="finetune", no_args_is_help=True)
def finetune(
Collaborator comment:
This looks good!
Can we have a finetune and a finetune_advanced command, similar to the training options, though?

I think it would be nice not to overwhelm standard users with the full choice of parameters, but instead have most standard parameters set by default. (A rough sketch of such a split is included after this file's diff.)

pretrained_checkpoint_path: str = Option( # noqa: B008
...,
help="Path to the checkpoint of the pre-trained model.",
**PKWARGS,
),
finetune_data_dir: str = Option( # noqa: B008
...,
help='Path to the directory containing the new data for fine-tuning. \
It must follow the same structure as required by the train function. \
To learn more about the required \
data structure, type "membrain data_structure_help"',
**PKWARGS,
),
log_dir: str = Option( # noqa: B008
"logs_fine_tune/",
help="Log directory path. Finetuning logs will be stored here.",
),
batch_size: int = Option( # noqa: B008
2,
help="Batch size for training.",
),
num_workers: int = Option( # noqa: B008
8,
help="Number of worker threads for data loading.",
),
max_epochs: int = Option( # noqa: B008
100,
help="Maximum number of epochs for fine-tuning.",
),
early_stop_threshold: float = Option( # noqa: B008
0.05,
help="Threshold for early stopping based on validation loss deviation.",
),
aug_prob_to_one: bool = Option( # noqa: B008
True,
help='Whether to augment with a probability of one. This helps with the \
model\'s generalization,\
but also severely increases training time.\
Pass "True" or "False".',
),
use_surface_dice: bool = Option( # noqa: B008
False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".'
),
surface_dice_weight: float = Option( # noqa: B008
1.0, help="Scaling factor for the Surface-Dice loss. "
),
surface_dice_tokens: Annotated[
Optional[List[str]],
Option(
help='List of tokens to \
use for the Surface-Dice loss. \
Pass tokens separately: \
For example, finetune --surface_dice_tokens "ds1" \
--surface_dice_tokens "ds2"'
),
] = None,
use_deep_supervision: bool = Option( # noqa: B008
True, help='Whether to use deep supervision. Pass "True" or "False".'
),
project_name: str = Option( # noqa: B008
"membrain-seg_v0_finetune",
help="Project name. This helps to find your model again.",
),
sub_name: str = Option( # noqa: B008
"1",
help="Subproject name. For multiple runs in the same project,\
please specify sub_names.",
),
):
"""
Initiates fine-tuning of a pre-trained model on new datasets,
with validation on the original datasets.

This function finetunes a pre-trained U-Net model on new data provided by the user.
The `finetune_data_dir` should contain the following directories:
- `imagesTr` and `labelsTr` for the user's own new training data.
- `imagesVal` and `labelsVal` for the old data, which will be used
for validation to ensure that the fine-tuned model's performance
is not significantly worse on the original training data than the
pre-trained model.

Parameters
----------
pretrained_checkpoint_path : str
Path to the checkpoint file of the pre-trained model.
finetune_data_dir : str
Directory containing the new dataset for fine-tuning,
structured according to MemBrain's requirements.
Use "membrain data_structure_help" for detailed information
on the required data structure.
log_dir : str
Path to the directory where logs will be stored, by default 'logs_fine_tune/'.
batch_size : int
Number of samples per batch, by default 2.
num_workers : int
Number of worker threads for data loading, by default 8.
max_epochs : int
Maximum number of fine-tuning epochs, by default 100.
early_stop_threshold : float
Threshold for early stopping based on validation loss deviation,
by default 0.05.
aug_prob_to_one : bool
Determines whether to apply very strong data augmentation, by default True.
If set to False, data augmentation still happens, but not as frequently.
More data augmentation can lead to better performance, but also increases the
training time substantially.
use_surface_dice : bool
Determines whether to use Surface-Dice loss, by default False.
surface_dice_weight : float
Scaling factor for the Surface-Dice loss, by default 1.0.
surface_dice_tokens : list
List of tokens to use for the Surface-Dice loss.
use_deep_supervision : bool
Determines whether to use deep supervision, by default True.
project_name : str
Name of the project for logging purposes, by default 'membrain-seg_v0_finetune'.
sub_name : str
Sub-name for the project, by default '1'.

Note
----
This command configures and executes a fine-tuning session
using the provided model checkpoint.
The actual fine-tuning logic resides in the function '_fine_tune'.
"""
_fine_tune(
pretrained_checkpoint_path=pretrained_checkpoint_path,
finetune_data_dir=finetune_data_dir,
log_dir=log_dir,
batch_size=batch_size,
num_workers=num_workers,
max_epochs=max_epochs,
early_stop_threshold=early_stop_threshold,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
use_surf_dice=use_surface_dice,
surf_dice_weight=surface_dice_weight,
surf_dice_tokens=surface_dice_tokens,
project_name=project_name,
sub_name=sub_name,
)
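
Following up on the review comment above, here is a rough sketch of what a reduced-parameter finetune command could look like if the full command above were renamed finetune_advanced. This is illustrative only and not part of the PR; it assumes the same imports as this file (Option, PKWARGS, cli, _fine_tune) and leaves every other argument at the defaults of _fine_tune. The name finetune_simple is hypothetical.

@cli.command(name="finetune", no_args_is_help=True)
def finetune_simple(
    pretrained_checkpoint_path: str = Option(  # noqa: B008
        ..., help="Path to the checkpoint of the pre-trained model.", **PKWARGS
    ),
    finetune_data_dir: str = Option(  # noqa: B008
        ..., help="Path to the directory containing the new data for fine-tuning.", **PKWARGS
    ),
):
    """Fine-tune with sensible defaults; use finetune_advanced for full control."""
    # All remaining parameters fall back to the defaults defined in _fine_tune.
    _fine_tune(
        pretrained_checkpoint_path=pretrained_checkpoint_path,
        finetune_data_dir=finetune_data_dir,
    )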
225 changes: 225 additions & 0 deletions src/membrain_seg/segmentation/finetune.py
@@ -0,0 +1,225 @@
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from membrain_seg.segmentation.dataloading.memseg_pl_datamodule import (
MemBrainSegDataModule,
)
from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.segmentation.training.training_param_summary import (
print_training_parameters,
)


def fine_tune(
pretrained_checkpoint_path: str,
finetune_data_dir: str,
finetune_learning_rate: float = 1e-5,
log_dir: str = "logs_finetune/",
batch_size: int = 2,
num_workers: int = 8,
max_epochs: int = 100,
early_stop_threshold: float = 0.05,
aug_prob_to_one: bool = False,
use_deep_supervision: bool = False,
project_name: str = "membrain-seg_finetune",
sub_name: str = "1",
use_surf_dice: bool = False,
surf_dice_weight: float = 1.0,
surf_dice_tokens: Optional[list] = None,
) -> None:
"""
Fine-tune a pre-trained U-Net model on new datasets.

This function finetunes a pre-trained U-Net model on new data provided by the user.
The `finetune_data_dir` should contain the following directories:
- `imagesTr` and `labelsTr` for the user's own new training data.
- `imagesVal` and `labelsVal` for the old data, which will be used
for validation to ensure that the fine-tuned model's performance
is not significantly worse on the original training data than the
pre-trained model.

Callbacks used during the fine-tuning process
---------------------------------------------
- ModelCheckpoint: Saves the model checkpoints based on training loss
and at regular intervals.
- ToleranceCallback: Stops training if the validation loss deviates significantly
from the baseline value set after the first epoch.
- LearningRateMonitor: Monitors and logs the learning rate during training.
- PrintLearningRate: Prints the current learning rate at the start of each epoch.

Parameters
----------
pretrained_checkpoint_path : str
Path to the checkpoint of the pre-trained model.
finetune_data_dir : str
Path to the directory containing the new data for fine-tuning
and old data for validation.
finetune_learning_rate : float, optional
Learning rate for fine-tuning, by default 1e-5.
log_dir : str, optional
Path to the directory where logs should be stored.
batch_size : int, optional
Number of samples per batch of input data.
num_workers : int, optional
Number of subprocesses to use for data loading.
max_epochs : int, optional
Maximum number of epochs to finetune, by default 100.
early_stop_threshold : float, optional
Threshold for early stopping based on validation loss deviation,
by default 0.05.
aug_prob_to_one : bool, optional
If True, all augmentation probabilities are set to 1.
use_deep_supervision : bool, optional
If True, enables deep supervision in the U-Net model.
project_name : str, optional
Name of the project for logging purposes.
sub_name : str, optional
Sub-name of the project for logging purposes.
use_surf_dice : bool, optional
If True, enables Surface-Dice loss.
surf_dice_weight : float, optional
Weight for the Surface-Dice loss.
surf_dice_tokens : list, optional
List of tokens to use for the Surface-Dice loss.

Returns
-------
None
"""
# Print training parameters for verification
print_training_parameters(
data_dir=finetune_data_dir,
log_dir=log_dir,
batch_size=batch_size,
num_workers=num_workers,
max_epochs=max_epochs,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
project_name=project_name,
sub_name=sub_name,
use_surf_dice=use_surf_dice,
surf_dice_weight=surf_dice_weight,
surf_dice_tokens=surf_dice_tokens,
)
print("————————————————————————————————————————————————————————")
print(
f"Pretrained Checkpoint:\n"
f" '{pretrained_checkpoint_path}' \n"
f" Path to the pretrained model checkpoint."
)
print("\n")

# Initialize the data module with fine-tuning datasets
# New data for finetuning and old data for validation
finetune_data_module = MemBrainSegDataModule(
data_dir=finetune_data_dir,
batch_size=batch_size,
num_workers=num_workers,
aug_prob_to_one=aug_prob_to_one,
)

# Load the pre-trained model with updated learning rate
pretrained_model = SemanticSegmentationUnet.load_from_checkpoint(
pretrained_checkpoint_path, learning_rate=finetune_learning_rate
)

checkpointing_name = project_name + "_" + sub_name

# Set up logging
csv_logger = pl_loggers.CSVLogger(log_dir)

# Set up model checkpointing based on training loss
checkpoint_callback_train_loss = ModelCheckpoint(
dirpath="finetuned_checkpoints/",
filename=checkpointing_name + "-{epoch:02d}-{train_loss:.2f}",
monitor="train_loss",
mode="min",
save_top_k=3,
)

# Set up regular checkpointing every 5 epochs
checkpoint_callback_regular = ModelCheckpoint(
save_top_k=-1, # Save all checkpoints
every_n_epochs=5,
dirpath="finetuned_checkpoints/",
filename=checkpointing_name + "-{epoch}-{train_loss:.2f}",
verbose=True, # Print a message when a checkpoint is saved
)

class ToleranceCallback(Callback):
Collaborator comment:
can we put this class into a separate file? Does it fit to optim_utils?

"""
Callback to stop training if the monitored metric deviates
beyond a certain threshold from the baseline value obtained
after the first epoch.
"""

def __init__(self, metric_name: str, threshold: float):
super().__init__()
self.metric_name = metric_name
self.threshold = threshold
self.baseline_value: Optional[float] = (
None # Baseline value will be set after the first epoch
)

def on_validation_epoch_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
):
# Access the metric value from the validation metrics
metric_value = trainer.callback_metrics.get(self.metric_name)

# If the metric value is a tensor, convert it to a float
if isinstance(metric_value, torch.Tensor):
metric_value = metric_value.item()

# Set the baseline value after the first validation epoch
if metric_value is not None:
if self.baseline_value is None:
self.baseline_value = metric_value
print(f"Baseline {self.metric_name} set to {self.baseline_value}")
return  # baseline recorded; nothing to compare against yet

# Check if the metric value deviates beyond the threshold
if abs(metric_value - self.baseline_value) > self.threshold:
print(
f"Stopping training as {self.metric_name} "
f"deviates too far from the baseline value."
)
trainer.should_stop = True

early_stop_metric = "val_loss"

tolerance_callback = ToleranceCallback(early_stop_metric, early_stop_threshold)

# Monitor learning rate changes
lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)

class PrintLearningRate(Callback):
Collaborator comment:
also put this into another file?

"""Callback to print the current learning rate at the start of each epoch."""

# Use on_train_epoch_start; the generic on_epoch_start callback hook was removed in newer PyTorch Lightning releases.
def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
current_lr = trainer.optimizers[0].param_groups[0]["lr"]
print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}")

print_lr_cb = PrintLearningRate()

# Initialize the trainer with specified precision, logger, and callbacks
trainer = pl.Trainer(
precision="16-mixed",
logger=[csv_logger],
callbacks=[
checkpoint_callback_train_loss,
checkpoint_callback_regular,
lr_monitor,
print_lr_cb,
tolerance_callback,
],
max_epochs=max_epochs,
)

# Start the fine-tuning process
trainer.fit(pretrained_model, finetune_data_module)
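
For reference, fine_tune can also be called directly from Python with the parameters shown above. The paths below are placeholders; finetune_data_dir must contain imagesTr/labelsTr (new data) and imagesVal/labelsVal (original data), as described in the docstring.

from membrain_seg.segmentation.finetune import fine_tune

fine_tune(
    pretrained_checkpoint_path="checkpoints/pretrained_model.ckpt",  # placeholder path
    finetune_data_dir="data/finetune/",  # placeholder path
    finetune_learning_rate=1e-5,
    max_epochs=100,
    early_stop_threshold=0.05,
    use_surf_dice=False,
)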
2 changes: 2 additions & 0 deletions src/membrain_seg/segmentation/training/optim_utils.py
@@ -115,6 +115,8 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
combined_loss = combined_loss.mean()
elif self.reduction == "sum":
combined_loss = combined_loss.sum()
elif self.reduction == "none":
return combined_loss
else:
raise ValueError(
f"Invalid reduction type {self.reduction}. "