Implement mixed precision for training and inference #20
```diff
@@ -2,7 +2,10 @@
 from monai.losses import DiceLoss, MaskedLoss
 from monai.networks.nets import DynUNet
 from monai.utils import LossReduction
-from torch.nn.functional import binary_cross_entropy, sigmoid
+from torch.nn.functional import (
+    binary_cross_entropy_with_logits,
+    sigmoid,
+)
 from torch.nn.modules.loss import _Loss
```
```diff
@@ -79,14 +82,18 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             The calculated loss.
         """
         # Create a mask to ignore the specified label in the target
+        orig_data = data.clone()
         data = sigmoid(data)
         mask = target != self.ignore_label

         # Compute the cross entropy loss while ignoring the ignore_label
         target_comp = target.clone()
         target_comp[target == self.ignore_label] = 0
         target_tensor = torch.tensor(target_comp, dtype=data.dtype, device=data.device)
-        bce_loss = binary_cross_entropy(data, target_tensor, reduction="none")
+        bce_loss = binary_cross_entropy_with_logits(
+            orig_data, target_tensor, reduction="none"
+        )
         bce_loss[~mask] = 0.0
         bce_loss = torch.sum(bce_loss) / torch.sum(mask)
         dice_loss = self.dice_loss(data, target, mask)
```

Review thread on the `binary_cross_entropy_with_logits` call:

**Reviewer:** If I understand correctly, this function applies a sigmoid to the prediction before computing the loss. Have you checked how this impacts training performance? https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss

**Author:** I haven't checked in detail how it influences training performance, except that I re-ran training to see whether the loss curves behave the same: in that sense, there is no difference. Folding the sigmoid into the loss function is supposed to be more numerically stable, as the implementation uses the log-sum-exp trick to compute the loss: https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/ I guess this helps especially in the case of exploding or vanishing gradients, which I didn't observe during training. Other than that, the loss function is the same as the one I used previously. Note that the […]
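To illustrate the numerical-stability point above, here is a minimal standalone sketch (not part of this PR; the logits and target values are made up for illustration). For large-magnitude negative logits with a positive target, `sigmoid` underflows to exactly 0.0 in float32, so the two-step `sigmoid` + `binary_cross_entropy` version saturates, while `binary_cross_entropy_with_logits` computes the loss from the raw logits via log-sum-exp and stays accurate:

```python
import torch
from torch.nn.functional import (
    binary_cross_entropy,
    binary_cross_entropy_with_logits,
    sigmoid,
)

# The true loss for target 1 is -log(sigmoid(x)), which approaches -x
# for very negative logits x.
logits = torch.tensor([-1.0, -50.0, -200.0])
target = torch.ones_like(logits)

# Two-step version: sigmoid(-200) underflows to 0.0 in float32, and
# binary_cross_entropy clamps its log outputs at -100, so the loss
# saturates at 100 instead of the true value of ~200.
two_step = binary_cross_entropy(sigmoid(logits), target, reduction="none")

# Fused version: evaluated directly from the logits with the
# log-sum-exp trick, so it returns the correct ~200 without underflow.
fused = binary_cross_entropy_with_logits(logits, target, reduction="none")

print(two_step)  # ~[1.3133, 50.0, 100.0]  (last entry clamped)
print(fused)     # ~[1.3133, 50.0, 200.0]  (correct)
```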
**Author** (follow-up on the same change): Needed to switch to `binary_cross_entropy_with_logits` because the plain `binary_cross_entropy` is not compatible with mixed precision: PyTorch's autocast rejects it as numerically unsafe in reduced precision.
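For reference, a minimal sketch of that incompatibility (assumed setup, not from this repository; requires a CUDA device). Under autocast, PyTorch raises a `RuntimeError` for `binary_cross_entropy` and directs users to `binary_cross_entropy_with_logits`, which it runs in float32 internally:

```python
import torch
from torch.nn.functional import (
    binary_cross_entropy,
    binary_cross_entropy_with_logits,
    sigmoid,
)

if torch.cuda.is_available():
    logits = torch.randn(8, device="cuda")
    target = torch.randint(0, 2, (8,), device="cuda").float()

    with torch.autocast("cuda"):
        # Allowed: autocast executes this op in float32 for stability.
        loss = binary_cross_entropy_with_logits(logits, target)
        print(loss)

        try:
            # Rejected: autocast refuses binary_cross_entropy / BCELoss
            # because they are unsafe in reduced precision.
            binary_cross_entropy(sigmoid(logits), target)
        except RuntimeError as err:
            print(err)  # "...are unsafe to autocast..."
```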