Overhaul of torchseg.losses #15
base: main
Conversation
@isaaccorley let me know what you think about the "new" DiceLoss; it seems cleaner and more efficient without the burden of specifying the mode. I'm not particularly sure about lines 135 and 136. I think that we should just rewrite the losses before implementing the new ones.
This looks really clean. I like it. Just a few comments. We should also add tests. Also, the library uses black/isort/flake8 and checks for them in the actions, so make sure to run them.
```python
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
```
Can you add a docstring to this?
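For reference, a minimal sketch of what the docstring could look like (the function name `soft_dice_score` and the `output`/`target` parameter names are assumptions based on the surrounding diff):

```python
import torch

def soft_dice_score(
    output: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims=None,
) -> torch.Tensor:
    """Compute the soft Dice score between ``output`` and ``target``.

    Args:
        output: predicted probabilities with shape ``[B, C, ...]``.
        target: ground truth tensor with the same shape as ``output``.
        smooth: smoothing constant that avoids a zero score on empty masks.
        eps: small constant added to the denominator to avoid NaN.
        dims: dimensions to reduce over; ``None`` reduces over all of them.

    Returns:
        The Dice score tensor.
    """
    ...
```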
```python
output_pow = torch.sum(output ** power)
target_pow = torch.sum(target ** power)
cardinality = output_pow + target_pow
dice_score = (2.0 * intersection + smooth) / (cardinality + eps)
```
Do we not need to add smooth to cardinality?
eps and smooth basically do the same thing (smooth avoids a score of 0, eps avoids NaN). Assuming that smooth should be very small (same order as eps), we could just use one of the two, but we can also treat them separately; let me know.
I think keeping both gives more customization.
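To make the difference concrete, here is a minimal sketch (illustrative names; the symmetric `smooth` in the numerator and denominator is the conventional formulation, not necessarily what this PR settles on):

```python
import torch

def soft_dice_score(output, target, smooth=0.0, eps=1e-7):
    intersection = torch.sum(output * target)
    cardinality = torch.sum(output) + torch.sum(target)
    # smooth keeps the score at 1 when both masks are empty;
    # eps alone only prevents the division from producing NaN
    return (2.0 * intersection + smooth) / (cardinality + smooth + eps)

empty = torch.zeros(1, 4, 4)
print(soft_dice_score(empty, empty, smooth=1.0))  # ~1.0: empty masks count as a match
print(soft_dice_score(empty, empty, smooth=0.0))  # 0.0: no NaN, but the score collapses
```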
```python
y_true = y_true.view(bs, -1)
y_pred = y_pred.view(bs, num_classes, -1)

if self.ignore_index is not None:
```
We are removing the ignore_index filtering, but I don't see it added back in the new code below.
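For reference, a sketch of how the filtering could be reintroduced after the reshape (the helper name and the masking strategy are assumptions, not the PR's code):

```python
import torch

def apply_ignore_index(y_true, y_pred, ignore_index):
    # y_true: [B, N] integer labels, y_pred: [B, C, N] predictions
    mask = y_true != ignore_index            # [B, N], True where pixels count
    y_pred = y_pred * mask.unsqueeze(1)      # zero predictions at ignored pixels
    # clamp ignored labels to a valid class so one_hot doesn't fail;
    # the same mask must zero their one-hot rows afterwards
    y_true = torch.where(mask, y_true, torch.zeros_like(y_true))
    return y_true, y_pred, mask
```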
torchseg/losses/dice.py
Outdated
```python
# maybe there is a better way to handle this?
permute_dims = tuple(dim - 1 for dim in spatial_dims)
y_true = F.one_hot(y_true, num_classes).squeeze(dim=1)  # N, 1, H, W, ... ---> N, H, W, ..., C
y_true = y_true.permute(0, -1, *permute_dims)  # N, H, W, ..., C ---> N, C, H, W, ...
```
You can use einops for this. It's cleaner.
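For example, something like this (assuming `einops` is available as a dependency):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

y_true = torch.randint(0, 3, (2, 1, 4, 4))                 # N, 1, H, W
y_true = F.one_hot(y_true, num_classes=3).squeeze(dim=1)   # N, 1, H, W ---> N, H, W, C
y_true = rearrange(y_true, "n ... c -> n c ...")           # N, H, W, C ---> N, C, H, W
```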
```python
elif self.reduction == LossReduction.SUM:
    loss = torch.sum(loss)
elif self.reduction == LossReduction.NONE:
    broadcast_shape = list(loss.shape[0:2]) + [1] * (len(y_true.shape) - 2)
```
What does this do?
Just giving users the ability to choose a preferred reduction instead of the classic mean, e.g. `LossReduction.NONE`. The latter basically returns a Dice score per channel: if y_true has shape [B, C, H, W], selecting `LossReduction.NONE` returns a loss with shape [B, C, 1, 1].
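A minimal sketch of how the three reductions could behave, following the description above (`reduce_loss` is an illustrative helper, not the PR's actual function; the broadcast logic mirrors the diff):

```python
import torch
from enum import Enum

class LossReduction(Enum):
    MEAN = "mean"
    SUM = "sum"
    NONE = "none"

def reduce_loss(loss, y_true_shape, reduction):
    # loss: per-(batch, channel) Dice loss with shape [B, C]
    if reduction == LossReduction.MEAN:
        return torch.mean(loss)
    if reduction == LossReduction.SUM:
        return torch.sum(loss)
    # NONE: keep one value per (batch, channel), broadcastable over the
    # spatial dims of y_true: [B, C] -> [B, C, 1, 1] for a [B, C, H, W] target
    broadcast_shape = list(loss.shape[0:2]) + [1] * (len(y_true_shape) - 2)
    return loss.view(broadcast_shape)

loss = torch.rand(2, 3)
print(reduce_loss(loss, (2, 3, 8, 8), LossReduction.NONE).shape)  # torch.Size([2, 3, 1, 1])
```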
Overhaul of torchseg.losses, starting from DiceLoss.

General changes applied to DiceLoss (to be applied later to all other losses):

- changed `assert` to `warnings` and `raise`;
- removed the `binary`, `multiclass` and `multilabels` modes when instantiating a loss function, by changing the mask shapes: now the shape must be [B, C, H, W] for the multiclass and multilabel cases and [B, 1, H, W] for the binary case. [B, 1, H, W] is also accepted for the multiclass case as long as `mask_to_one_hot = True`;
- support for different reduction techniques (`MEAN`, `SUM`, `NONE`) and definition of a `LossReduction` class to easily save the techniques.

Specific changes applied to DiceLoss:

- added a `power` parameter to be applied to the denominator of the Dice coefficient (still need to add this to the documentation of the class).
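Based on this description, usage would look roughly as follows (a sketch under the assumptions above; the constructor arguments shown are hypothetical and may differ from the final API):

```python
import torch
from torchseg.losses import DiceLoss

# multiclass / multilabel: masks of shape [B, C, H, W]
y_pred = torch.rand(4, 3, 64, 64)
y_true = torch.randint(0, 2, (4, 3, 64, 64)).float()

loss_fn = DiceLoss(power=2, reduction="none")  # hypothetical arguments
loss = loss_fn(y_pred, y_true)                 # shape [B, C, 1, 1] with reduction "none"

# binary: masks of shape [B, 1, H, W]; index masks [B, 1, H, W] are also
# accepted for multiclass as long as mask_to_one_hot=True
```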