DiceLoss question in pretrained-backbones-unet/backbones_unet/model/losses.py #7

Open
plun11 opened this issue Dec 7, 2023 · 1 comment

Comments


plun11 commented Dec 7, 2023

Hello, thank you for the great library.

There is a line in the DiceLoss class:

```python
if not self.from_logits:
    y_pred = F.sigmoid(y_pred)
```

I'm not sure, but judging from the description, I think the sigmoid is meant to be applied when `self.from_logits` is True. If so, the condition should be `if self.from_logits:` rather than `if not self.from_logits:`.
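For illustration, here is a minimal sketch of a binary Dice loss that uses the condition as proposed above, so the sigmoid runs only when `from_logits` is True. The class name `BinaryDiceLoss` and the `smooth` term are hypothetical choices for this sketch, not the library's actual implementation.

```python
import torch
import torch.nn as nn


class BinaryDiceLoss(nn.Module):
    """Hypothetical binary Dice loss illustrating the proposed from_logits check."""

    def __init__(self, from_logits: bool = True, smooth: float = 1.0):
        super().__init__()
        self.from_logits = from_logits  # True means y_pred holds raw, unbounded scores
        self.smooth = smooth            # assumed smoothing term added to numerator and denominator

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Raw logits must be squashed into [0, 1]; probabilities are used as-is.
        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)
        # Flatten both tensors; this assumes they contain the same number of elements.
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1).float()
        intersection = (y_pred * y_true).sum()
        denominator = y_pred.sum() + y_true.sum()
        dice = (2.0 * intersection + self.smooth) / (denominator + self.smooth)
        return 1.0 - dice
```

In this sketch, passing logits `[10.0, -10.0]` against target `[1, 0]` with `from_logits=True` gives a loss close to 0. If the sigmoid is skipped for the same logits (which is what the current `if not` check would do when `from_logits=True`), the "loss" comes out as -9.5, which is a quick way to spot the inverted condition.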


physgorg commented Mar 8, 2024

I agree with this comment. I also found that the `.view(-1)` call raises an error for me when my label tensor (`y_true`) has shape `(batch_size, H, W)`. I used this surrogate Dice loss function instead:
```python
import torch


def dice_loss(y_pred, y_true):
    """
    Compute the Dice loss between predictions and true labels.

    Args:
        y_pred (torch.Tensor): Raw, unnormalized class scores (logits) with shape (batch_size, num_classes, H, W).
        y_true (torch.Tensor): Integer class labels with shape (batch_size, H, W).

    Returns:
        torch.Tensor: Scalar Dice loss.
    """
    # Convert y_true to one-hot format to match the shape of y_pred
    y_true = y_true.long()
    y_true_one_hot = torch.nn.functional.one_hot(y_true, num_classes=y_pred.shape[1]).permute(0, 3, 1, 2).float()

    # Apply softmax to y_pred to get class probabilities
    y_pred = torch.nn.functional.softmax(y_pred, dim=1)

    # Compute intersection and union for Dice score
    intersection = torch.sum(y_pred * y_true_one_hot, dim=(2, 3))
    union = torch.sum(y_pred, dim=(2, 3)) + torch.sum(y_true_one_hot, dim=(2, 3))

    # Compute Dice loss
    dice_score = 2.0 * intersection / (union + 1e-6)  # Add a small epsilon to avoid division by zero
    loss = 1 - dice_score  # Dice loss is 1 minus the Dice score

    # Return the mean Dice loss over all classes and batch
    return torch.mean(loss)
```

This version avoids the error.
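Flattening a `(B, C, H, W)` prediction and a `(B, H, W)` mask produces tensors of different lengths (B·C·H·W vs. B·H·W), which is one plausible source of the `.view(-1)` error; also note that `.view()` fails on some non-contiguous tensors where `.reshape()` would copy. The sketch below combines both observations in this thread: it respects a `from_logits` flag and one-hot encodes integer masks before computing per-class Dice. The class name `MulticlassDiceLoss` and the `eps` placement are assumptions for illustration, not the library's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MulticlassDiceLoss(nn.Module):
    """Hypothetical multiclass Dice loss for integer masks of shape (B, H, W)."""

    def __init__(self, from_logits: bool = True, eps: float = 1e-6):
        super().__init__()
        self.from_logits = from_logits  # True means y_pred holds raw class scores
        self.eps = eps                  # assumed stabilizer, added to numerator and denominator

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # y_pred: (B, C, H, W) scores; y_true: (B, H, W) class indices in [0, C).
        num_classes = y_pred.shape[1]
        if self.from_logits:
            y_pred = torch.softmax(y_pred, dim=1)  # per-pixel distribution over classes
        # one_hot puts the class axis last, (B, H, W, C); permute it to (B, C, H, W).
        y_true = F.one_hot(y_true.long(), num_classes=num_classes).permute(0, 3, 1, 2)
        y_true = y_true.to(y_pred.dtype)
        dims = (2, 3)  # sum over spatial axes, keeping per-sample, per-class scores
        intersection = (y_pred * y_true).sum(dims)
        union = y_pred.sum(dims) + y_true.sum(dims)
        dice = (2.0 * intersection + self.eps) / (union + self.eps)
        return 1.0 - dice.mean()
```

Unlike the surrogate above, this sketch adds `eps` to the numerator as well, so a class that is absent from both the prediction and the mask scores a Dice of 1 instead of 0; whether that is desirable depends on how empty classes should be treated.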
