
Overhaul of torchseg.losses #15

Draft · notprime wants to merge 10 commits into main

Conversation

notprime

Overhaul of torchseg.losses, starting from DiceLoss.

General changes applied to DiceLoss (to be applied later to all other losses):

  • switched from assert to warnings and raise;
  • eliminated the need to specify binary, multiclass, and multilabel modes when instantiating a loss function by changing the mask shape: masks must now have shape [B, C, H, W] in the multiclass and multilabel cases and [B, 1, H, W] in the binary case. A [B, 1, H, W] mask is also accepted in the multiclass case as long as mask_to_one_hot = True (see the sketch after this list);
  • introduced proper reduction techniques (MEAN, SUM, NONE) and a LossReduction class to cleanly store the options;
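
A minimal sketch of the intended call patterns (`mask_to_one_hot` is the name used in this PR; the exact constructor signature may differ):

```python
import torch
from torchseg.losses import DiceLoss

B, C, H, W = 2, 3, 64, 64
logits = torch.randn(B, C, H, W)

# multiclass / multilabel: mask already one-hot / multi-hot, shape [B, C, H, W]
onehot_mask = torch.randint(0, 2, (B, C, H, W)).float()
loss = DiceLoss()(logits, onehot_mask)

# binary: single-channel mask, shape [B, 1, H, W]
binary_mask = torch.randint(0, 2, (B, 1, H, W)).float()
loss = DiceLoss()(torch.randn(B, 1, H, W), binary_mask)

# multiclass with integer labels of shape [B, 1, H, W]: one-hot encode internally
labels = torch.randint(0, C, (B, 1, H, W))
loss = DiceLoss(mask_to_one_hot=True)(logits, labels)
```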

Specific changes applied to DiceLoss:

  • added a power argument applied to the denominator of the Dice coefficient (still needs to be added to the class documentation); see the formula below
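
With `power = p`, the score sketched in the hunk reviewed below becomes

$$\text{dice} = \frac{2\sum_i o_i\, t_i + \textit{smooth}}{\sum_i o_i^{\,p} + \sum_i t_i^{\,p} + \textit{eps}}$$

so `power=1` recovers the standard soft Dice and `power=2` the squared-denominator variant.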

notprime (Author) commented Mar 11, 2024

@isaaccorley let me know what you think about the "new" DiceLoss; it seems cleaner and more efficient without the burden of specifying binary, multilabel, and multiclass modes.

I'm not particularly sure about lines 135 and 136 in the forward of DiceLoss; they don't make much sense to me, for the two reasons I've written in the comment.

I think we should just rewrite the existing losses before implementing the new ones.

isaaccorley (Owner) left a comment:

This looks really clean. I like it. Just a few comments. We should also add tests. Also, the library is using black/isort/flake8 and checking for them in the Actions, so make sure to run them.

```diff
     smooth: float = 0.0,
     eps: float = 1e-7,
-    dims=None,
+    dims = None,
 ) -> torch.Tensor:
```
isaaccorley (Owner):

Can you add a docstring to this?

```python
output_pow = torch.sum(output ** power)
target_pow = torch.sum(target ** power)
cardinality = output_pow + target_pow
dice_score = (2.0 * intersection + smooth) / (cardinality + eps)
```
isaaccorley (Owner):

Do we not need to add smooth to cardinality?

notprime (Author):

eps and smooth basically do the same thing (smooth avoids a score of 0, eps avoids NaN). Supposing that smooth should be very small (same order as eps), we could just use one of the two, but we can also treat them separately; let me know.

I think leaving them both gives more customization.
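
A small numeric sketch of the difference in the degenerate all-empty case (illustrative values, not from the thread):

```python
smooth, eps = 1.0, 1e-7
intersection = cardinality = 0.0  # empty prediction and empty target

# as currently written (smooth only in the numerator):
(2.0 * intersection + smooth) / (cardinality + eps)            # ~1e7, unbounded
# with smooth added to the denominator too (the usual formulation):
(2.0 * intersection + smooth) / (cardinality + smooth + eps)   # ~1.0
```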

```python
y_true = y_true.view(bs, -1)
y_pred = y_pred.view(bs, num_classes, -1)

if self.ignore_index is not None:
```
isaaccorley (Owner):

We are removing the ignore_index filtering, but I don't see it added back below in the new code.
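
For reference, a hypothetical sketch of the kind of filtering the old code performed (not the PR's code; shapes as in the hunk above):

```python
if self.ignore_index is not None:
    keep = y_true != self.ignore_index   # [B, H*W], True where the pixel counts
    y_pred = y_pred * keep.unsqueeze(1)  # broadcast over the class dimension
    y_true = y_true * keep               # zero out ignored pixels
```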

```python
# maybe there is a better way to handle this?
permute_dims = tuple(dim - 1 for dim in spatial_dims)
y_true = F.one_hot(y_true, num_classes).squeeze(dim=1)  # N, 1, H, W, ... ---> N, H, W, ..., C
y_true = y_true.permute(0, -1, *permute_dims)  # N, H, W, ..., C ---> N, C, H, W, ...
```
isaaccorley (Owner):

You can use einops for this. It's cleaner.
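
For example, something like the following sketch, continuing the snippet above (einops' `...` handles the variable number of spatial dims):

```python
from einops import rearrange

y_true = F.one_hot(y_true.squeeze(1), num_classes)  # N, H, W, ... ---> N, H, W, ..., C
y_true = rearrange(y_true, "n ... c -> n c ...")    # N, H, W, ..., C ---> N, C, H, W, ...
```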

```python
elif self.reduction == LossReduction.SUM:
    loss = torch.sum(loss)
elif self.reduction == LossReduction.NONE:
    broadcast_shape = list(loss.shape[0:2]) + [1] * (len(y_true.shape) - 2)
```
isaaccorley (Owner):

What does this do?

notprime (Author):

Just giving users the ability to choose a preferred reduction instead of the classic mean. The NONE option basically returns a Dice score per channel: if y_true has shape [B, C, H, W], selecting LossReduction.NONE returns a loss with shape [B, C, 1, 1].
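
A quick sketch of what that line sets up (the final `view` is an assumption about the cut-off continuation of the hunk):

```python
# loss holds one Dice value per sample and per channel: shape [B, C]
broadcast_shape = list(loss.shape[0:2]) + [1] * (len(y_true.shape) - 2)
loss = loss.view(broadcast_shape)  # e.g. [B, C] -> [B, C, 1, 1] for 4D y_true
```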
