Implement mixed precision for training and inference #20

Merged · 2 commits into main from mixed_precision · Jul 20, 2023

Conversation

@LorenzLamm (Collaborator) commented on Jul 10, 2023

Mixed precision for both training and inference.

Using mixed precision during training allows for larger batch sizes and slightly faster training (the main time bottleneck is still data augmentation, which is not affected by this change).

For segmentation, there is also a slight speed-up, but the main advantage is that inference requires less GPU memory. It should now be easily possible to run prediction on an 8 GB GPU; previously, 8 GB was right at the limit and caused problems for some users.
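For context, a minimal sketch of the usual torch.cuda.amp pattern for mixed-precision training and inference; this illustrates the general mechanism rather than the exact code in this PR, and the model, optimizer, and data below are toy stand-ins:

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

# Toy stand-ins so the sketch is self-contained; not the project's real model.
model = nn.Sequential(nn.Linear(16, 1)).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.BCEWithLogitsLoss()
scaler = GradScaler()  # rescales the loss so fp16 gradients do not underflow

for _ in range(10):  # dummy training "batches"
    x = torch.randn(8, 16, device="cuda")
    y = torch.randint(0, 2, (8, 1), device="cuda").float()
    optimizer.zero_grad()
    with autocast():               # forward pass runs in mixed fp16/fp32 precision
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.step(optimizer)
    scaler.update()

# Inference only needs autocast; the fp16 activations are what cut GPU memory.
model.eval()
with torch.no_grad(), autocast():
    pred = torch.sigmoid(model(torch.randn(8, 16, device="cuda")))
```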

@LorenzLamm marked this pull request as ready for review on July 10, 2023, 14:16
@LorenzLamm (Collaborator, Author) commented:

Needed to switch to binary_cross_entropy_with_logits because the plain binary_cross_entropy is not compatible with mixed precision.
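For context: recent PyTorch versions explicitly disallow binary_cross_entropy inside an autocast region (its computation is not fp16-safe) and point users to the with_logits variant. A minimal sketch of the switch, with placeholder tensors rather than the PR's actual variables:

```python
import torch
from torch.nn.functional import binary_cross_entropy, binary_cross_entropy_with_logits

logits = torch.randn(4, device="cuda")
target = torch.randint(0, 2, (4,), device="cuda").float()

with torch.cuda.amp.autocast():
    # binary_cross_entropy(torch.sigmoid(logits), target)   # raises a RuntimeError
    # under autocast on current PyTorch versions, which recommends the line below.
    loss = binary_cross_entropy_with_logits(logits, target)  # autocast-safe
```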

@kevinyamauchi (Collaborator) left a comment:

LGTM @LorenzLamm ! I have a question below, but it's more of a curiosity.

```python
data = sigmoid(data)
mask = target != self.ignore_label

# Compute the cross entropy loss while ignoring the ignore_label
target_comp = target.clone()
target_comp[target == self.ignore_label] = 0
target_tensor = torch.tensor(target_comp, dtype=data.dtype, device=data.device)
bce_loss = binary_cross_entropy(data, target_tensor, reduction="none")

bce_loss = binary_cross_entropy_with_logits(
```
@kevinyamauchi (Collaborator) commented:

If I understand correctly, this function applies a sigmoid to the prediction before computing the loss. Have you checked how this impacts training performance?

https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss

@LorenzLamm (Collaborator, Author) replied:

I haven't checked in detail how it influences training performance, but I re-ran training to see whether the loss curves behave the same: in that sense, there is no difference.

Including the sigmoid in the loss function is supposed to be more numerically stable, as the implementation uses the log-sum-exp trick to compute the loss: https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
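(For reference, and not taken from this PR: one standard numerically stable rewriting of the per-element loss that such with_logits implementations rely on is

$$
\ell(x, y) = \max(x, 0) - x\,y + \log\bigl(1 + e^{-\lvert x \rvert}\bigr),
$$

which is algebraically equal to $-\bigl[y \log \sigma(x) + (1 - y) \log\bigl(1 - \sigma(x)\bigr)\bigr]$ but never exponentiates a large positive argument.)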

I guess this helps especially in the case of exploding or vanishing gradients, which I didn't observe during training.

Other than that, the loss function is the same as the one I used previously. Note that binary_cross_entropy_with_logits now takes the raw orig_data variable as input, instead of the data variable to which the sigmoid was applied above.
So both the previous and the new version effectively apply a sigmoid followed by cross entropy; the new version just fuses the two into a single, numerically more stable computation.
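A small sketch, not from the PR, checking that the two formulations give the same loss for ordinary logits and showing the overflow that the fused version avoids:

```python
import torch
from torch.nn.functional import binary_cross_entropy, binary_cross_entropy_with_logits

orig_data = torch.randn(1000)                  # raw network outputs (logits)
target = torch.randint(0, 2, (1000,)).float()

old = binary_cross_entropy(torch.sigmoid(orig_data), target)  # explicit sigmoid, then BCE
new = binary_cross_entropy_with_logits(orig_data, target)     # sigmoid fused into the loss
print(torch.allclose(old, new, atol=1e-5))     # True: identical loss on ordinary inputs

# The fused form uses the log-sum-exp trick, so it never evaluates exp on a
# large positive argument; the naive formula overflows in float32:
x = torch.tensor(100.0)
print(torch.log1p(torch.exp(x)))                                 # inf
print(torch.clamp(x, min=0) + torch.log1p(torch.exp(-x.abs())))  # tensor(100.)
```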

@LorenzLamm merged commit 0e5d732 into main on Jul 20, 2023
11 checks passed
@LorenzLamm deleted the mixed_precision branch on July 20, 2023, 17:11