From 2999335cc96bfc615e826ac44a02b3548883eace Mon Sep 17 00:00:00 2001
From: LorenzLamm
Date: Mon, 10 Jul 2023 16:00:17 +0200
Subject: [PATCH] Implement mixed precision for training and inference

---
 src/membrain_seg/segment.py              | 20 +++++++++++---------
 src/membrain_seg/train.py                |  4 ++--
 src/membrain_seg/training/optim_utils.py | 19 +++++++++++++------
 3 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/src/membrain_seg/segment.py b/src/membrain_seg/segment.py
index 3fef8c1..edf34be 100644
--- a/src/membrain_seg/segment.py
+++ b/src/membrain_seg/segment.py
@@ -91,16 +91,18 @@ def segment(
         print("Performing 8-fold test-time augmentation.")
     for m in range(8):
         with torch.no_grad():
-            predictions += (
-                get_mirrored_img(
-                    inferer(get_mirrored_img(new_data.clone(), m).to(device), pl_model)[
-                        0
-                    ],
-                    m,
+            with torch.cuda.amp.autocast():
+                predictions += (
+                    get_mirrored_img(
+                        inferer(
+                            get_mirrored_img(new_data.clone(), m).to(device), pl_model
+                        )[0],
+                        m,
+                    )
+                    .detach()
+                    .cpu()
                 )
-                .detach()
-                .cpu()
-            )
+
     predictions /= 8.0

     # Extract segmentations and store them in an output file.
diff --git a/src/membrain_seg/train.py b/src/membrain_seg/train.py
index 16b9fa5..52986e1 100644
--- a/src/membrain_seg/train.py
+++ b/src/membrain_seg/train.py
@@ -89,8 +89,7 @@ def train(
         save_top_k=-1,  # Save all checkpoints
         every_n_epochs=100,
         dirpath="checkpoints/",
-        filename=checkpointing_name
-        + "-{epoch}-{val_loss:.2f}",  # Customize the filename of saved checkpoints
+        filename=checkpointing_name + "-{epoch}-{val_loss:.2f}",
         verbose=True,  # Print a message when a checkpoint is saved
     )

@@ -104,6 +103,7 @@ def on_epoch_start(self, trainer, pl_module):
     print_lr_cb = PrintLearningRate()
     # Set up the trainer
     trainer = pl.Trainer(
+        precision="16-mixed",
         logger=[csv_logger, wandb_logger],
         callbacks=[
             checkpoint_callback_val_loss,
diff --git a/src/membrain_seg/training/optim_utils.py b/src/membrain_seg/training/optim_utils.py
index 09c9be8..8dbdc6b 100644
--- a/src/membrain_seg/training/optim_utils.py
+++ b/src/membrain_seg/training/optim_utils.py
@@ -2,7 +2,10 @@
 from monai.losses import DiceLoss, MaskedLoss
 from monai.networks.nets import DynUNet
 from monai.utils import LossReduction
-from torch.nn.functional import binary_cross_entropy, sigmoid
+from torch.nn.functional import (
+    binary_cross_entropy_with_logits,
+    sigmoid,
+)
 from torch.nn.modules.loss import _Loss


@@ -79,6 +82,7 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             The calculated loss.
         """
         # Create a mask to ignore the specified label in the target
+        orig_data = data.clone()
         data = sigmoid(data)
         mask = target != self.ignore_label

@@ -86,7 +90,10 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         target_comp = target.clone()
         target_comp[target == self.ignore_label] = 0
         target_tensor = torch.tensor(target_comp, dtype=data.dtype, device=data.device)
-        bce_loss = binary_cross_entropy(data, target_tensor, reduction="none")
+
+        bce_loss = binary_cross_entropy_with_logits(
+            orig_data, target_tensor, reduction="none"
+        )
         bce_loss[~mask] = 0.0
         bce_loss = torch.sum(bce_loss) / torch.sum(mask)
         dice_loss = self.dice_loss(data, target, mask)
@@ -101,10 +108,10 @@ class DeepSuperVisionLoss(_Loss):
     Deep Supervision loss using downsampled GT and low-res outputs.

     Implementation based on nnU-Net's implementation with downsampled images.
-    Reference: Zeng, Guodong, et al. "3D U-net with multi-level deep supervision:
-    fully automatic segmentation of proximal femur in 3D MR images." Machine Learning
-    in Medical Imaging: 8th International Workshop, MLMI 2017, Held in Conjunction with
-    MICCAI 2017, Quebec City, QC, Canada, September 10, 2017, Proceedings 8. Springer
+    Reference: Zeng, Guodong, et al. "3D U-net with multi-level deep supervision:
+    fully automatic segmentation of proximal femur in 3D MR images." Machine Learning
+    in Medical Imaging: 8th International Workshop, MLMI 2017, Held in Conjunction with
+    MICCAI 2017, Quebec City, QC, Canada, September 10, 2017, Proceedings 8. Springer
     International Publishing, 2017.

     Parameters
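
Note on the loss change: inside an autocast region, PyTorch rejects torch.nn.functional.binary_cross_entropy on sigmoid outputs (it raises a RuntimeError stating that binary_cross_entropy is unsafe to autocast), while binary_cross_entropy_with_logits is run in float32 internally and is therefore autocast-safe. That is why the patch keeps the raw logits in orig_data and applies sigmoid only for the Dice term. Below is a minimal sketch of this pattern; the Conv3d stand-in model, tensor shapes, and dummy data are illustrative only (not from this patch) and it assumes a CUDA device is available:

import torch
from torch.nn.functional import binary_cross_entropy_with_logits

# Hypothetical stand-in for the segmentation network: any module that
# returns raw logits serves the purpose of this sketch.
model = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1).cuda()
data = torch.randn(2, 1, 16, 16, 16, device="cuda")
target = torch.randint(0, 2, (2, 1, 16, 16, 16), device="cuda").float()

with torch.cuda.amp.autocast():
    logits = model(data)  # computed in float16 inside the autocast region
    # binary_cross_entropy(sigmoid(logits), target) would raise a
    # RuntimeError here; the with-logits variant is autocast-safe because
    # PyTorch evaluates it in float32 regardless of the input dtype.
    loss = binary_cross_entropy_with_logits(logits, target, reduction="none")

On the training side, precision="16-mixed" is the PyTorch Lightning 2.x setting that enables the same AMP machinery (autocast plus gradient scaling) for the whole training loop, so no explicit autocast context is needed there.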