Implement mixed precision for training and inference #20
Conversation
Needed to switch to `binary_cross_entropy_with_logits` because the normal `binary_cross_entropy` was not compatible with mixed precision.
LGTM @LorenzLamm! I have a question below, but it's more out of curiosity.
```diff
 data = sigmoid(data)
 mask = target != self.ignore_label

 # Compute the cross entropy loss while ignoring the ignore_label
 target_comp = target.clone()
 target_comp[target == self.ignore_label] = 0
 target_tensor = torch.tensor(target_comp, dtype=data.dtype, device=data.device)
-bce_loss = binary_cross_entropy(data, target_tensor, reduction="none")
+bce_loss = binary_cross_entropy_with_logits(
```
If I understand correctly, this function applies a sigmoid to the prediction before computing the loss. Have you checked how this impacts training performance?
https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss
I haven't checked in detail how it influences training performance, except that I re-ran training to see whether the loss curves behave the same: in that sense, there is no difference.
Including the sigmoid in the loss function is supposed to be more numerically stable, as the implementation uses the log-sum-exp trick to compute the loss: https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
I guess this helps especially in case of exploding or vanishing gradients, which I didn't observe during training.
Other than that, the loss function is the same as the one I used previously. Note that `binary_cross_entropy_with_logits` takes the `orig_data` variable as input instead of `data`, to which the sigmoid is applied above.
So both the previous version and the new version apply a sigmoid and then the cross entropy. The only difference is that the new version fuses the two steps into a single, numerically stable computation.
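The stability point can be illustrated with plain Python, independent of PyTorch. Below is a minimal sketch (the helper names `bce_naive` and `bce_with_logits` are made up for the example): applying the sigmoid first and taking the log afterwards fails once the sigmoid saturates, while the fused log-sum-exp form stays well defined.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_naive(z, y):
    # Sigmoid first, then log: breaks when sigmoid rounds to exactly 0 or 1.
    p = sigmoid(z)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_with_logits(z, y):
    # Fused, numerically stable form: max(z, 0) - z*y + log(1 + exp(-|z|)).
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

# Moderate logits: both forms agree.
print(bce_naive(2.0, 1.0))        # ~0.12693
print(bce_with_logits(2.0, 1.0))  # ~0.12693

# Extreme logit: sigmoid(40) rounds to 1.0 in float64, so the naive
# form hits log(0) and raises, while the fused form returns ~40.
try:
    bce_naive(40.0, 0.0)
except ValueError:
    print("naive BCE failed at z=40")
print(bce_with_logits(40.0, 0.0))  # ~40.0
```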
Mixed precision for both training and inference.
Using mixed precision during training allows for larger batch sizes and slightly faster training (the time-wise bottleneck is still data augmentation, which is not affected here).
For segmentation, there is also a slight speed-up, but the main advantage is that inference requires less GPU memory. It should now be easily possible to run prediction on an 8GB GPU -- previously, 8GB was right at the margin and caused problems for some users.
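As a rough illustration of the pattern this PR enables, here is a minimal sketch of a PyTorch mixed-precision step. The toy two-layer model is a stand-in, not the repository's actual network or training loop; the sketch falls back to bfloat16 on CPU so it runs without a GPU.

```python
import torch
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits

# Hypothetical stand-in for the real segmentation network.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16
model.to(device)

# GradScaler rescales the loss so float16 gradients don't underflow;
# it is a no-op when disabled (e.g. on CPU).
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(4, 8, device=device)
target = torch.randint(0, 2, (4, 1), device=device).float()

# Training step: forward pass under autocast, scaled backward pass.
with torch.autocast(device_type=device, dtype=amp_dtype):
    logits = model(x)
    # Plain binary_cross_entropy raises under CUDA autocast; the
    # *_with_logits variant is the autocast-safe choice.
    loss = binary_cross_entropy_with_logits(logits, target)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# Inference: autocast alone is enough, no scaler needed.
with torch.no_grad(), torch.autocast(device_type=device, dtype=amp_dtype):
    pred = torch.sigmoid(model(x))
```

The memory savings come from storing activations in half precision during the forward pass; the optimizer state and master weights remain float32.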