 
 from monai.losses.focal_loss import FocalLoss
 from monai.losses.spatial_mask import MaskedLoss
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
 
@@ -39,8 +40,16 @@ class DiceLoss(_Loss):
     The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
     the inter-over-union calculation to smooth results respectively, these values should be small.
 
-    The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric
-    Medical Image Segmentation, 3DV, 2016.
+    The original papers:
+
+    Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
+    Medical Image Segmentation. 3DV 2016.
+
+    Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
+    Soft Labels. NeurIPS 2023.
+
+    Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
 
     """
 
@@ -58,6 +67,7 @@ def __init__(
         smooth_dr: float = 1e-5,
         batch: bool = False,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -89,6 +99,8 @@ def __init__(
                 of the sequence should be the same as the number of classes. If not ``include_background``,
                 the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
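Note (not part of the diff): a minimal usage sketch of the new flag, assuming standard DiceLoss arguments; the tensor shapes and values below are illustrative, not taken from this commit.

import torch
from monai.losses import DiceLoss

# Hypothetical example: soft (probabilistic) targets instead of one-hot masks.
loss_fn = DiceLoss(sigmoid=True, soft_label=True)  # soft_label enables the soft-label formulation
logits = torch.randn(2, 3, 32, 32)                 # (batch, classes, H, W) raw network output
soft_target = torch.rand(2, 3, 32, 32)             # values in [0, 1], e.g. from label smoothing
loss = loss_fn(logits, soft_target)                # scalar with the default "mean" reduction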
@@ -114,6 +126,7 @@ def __init__(
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -174,21 +187,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        intersection = torch.sum(target * input, dim=reduce_axis)
-
-        if self.squared_pred:
-            ground_o = torch.sum(target**2, dim=reduce_axis)
-            pred_o = torch.sum(input**2, dim=reduce_axis)
-        else:
-            ground_o = torch.sum(target, dim=reduce_axis)
-            pred_o = torch.sum(input, dim=reduce_axis)
-
-        denominator = ground_o + pred_o
-
-        if self.jaccard:
-            denominator = 2.0 * (denominator - intersection)
+        ord = 2 if self.squared_pred else 1
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
+        if not self.jaccard:
+            fp *= 0.5
+            fn *= 0.5
+        numerator = 2 * tp + self.smooth_nr
+        denominator = 2 * (tp + fp + fn) + self.smooth_dr
 
-        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
+        f: torch.Tensor = 1 - numerator / denominator
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
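Note (not part of the diff): `compute_tp_fp_fn` lives in `monai.losses.utils` and is not shown here. Working backwards from the removed lines, a hard-label-only sketch of what it must return so the Dice/Jaccard values stay unchanged could look like this; it is an illustration derived from the old code, not the library implementation, and the `soft_label=True` branch is omitted.

import torch

def compute_tp_fp_fn_sketch(input: torch.Tensor, target: torch.Tensor,
                            reduce_axis: list[int], ord: int, soft_label: bool):
    """Illustrative stand-in for monai.losses.utils.compute_tp_fp_fn (hard-label case only)."""
    tp = torch.sum(target * input, dim=reduce_axis)       # was `intersection`
    pred_o = torch.sum(input**ord, dim=reduce_axis)       # was `pred_o`
    ground_o = torch.sum(target**ord, dim=reduce_axis)    # was `ground_o`
    fp = pred_o - tp    # then 2 * (tp + 0.5*fp + 0.5*fn) == ground_o + pred_o (old Dice denominator)
    fn = ground_o - tp  # and 2 * (tp + fp + fn) == 2 * (ground_o + pred_o - tp) (old Jaccard denominator)
    # The soft-label variant from the Dice Semimetric Losses paper is not reconstructed here.
    return tp, fp, fn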
@@ -272,6 +279,7 @@ def __init__(
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -295,6 +303,8 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, intersection over union is computed from each item in the batch.
                 If True, the class-weighted intersection and union areas are first summed across the batches.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -319,6 +329,7 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def w_func(self, grnd):
         if self.w_type == str(Weight.SIMPLE):
@@ -370,13 +381,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             reduce_axis = [0] + reduce_axis
-        intersection = torch.sum(target * input, reduce_axis)
 
-        ground_o = torch.sum(target, reduce_axis)
-        pred_o = torch.sum(input, reduce_axis)
-
-        denominator = ground_o + pred_o
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
+        fp *= 0.5
+        fn *= 0.5
+        denominator = 2 * (tp + fp + fn)
 
+        ground_o = torch.sum(target, reduce_axis)
         w = self.w_func(ground_o.float())
         infs = torch.isinf(w)
         if self.batch:
@@ -388,7 +399,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             w = w + infs * max_values
 
         final_reduce_dim = 0 if self.batch else 1
-        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
+        numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
         denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
         f: torch.Tensor = 1.0 - (numer / denom)
 
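Note (not part of the diff): with `ord` fixed at 1 and hard labels, the rewritten GeneralizedDiceLoss denominator `2 * (tp + fp + fn)` (after halving `fp` and `fn`) collapses back to the removed `ground_o + pred_o`, since `fp = pred_o - tp` and `fn = ground_o - tp` in that case. A small check under those assumptions:

import torch

# Illustrative check, assuming the hard-label identities fp = pred_o - tp and fn = ground_o - tp.
input = torch.rand(2, 3, 8, 8)
target = (torch.rand(2, 3, 8, 8) > 0.5).float()
reduce_axis = [2, 3]

tp = torch.sum(target * input, reduce_axis)
pred_o = torch.sum(input, reduce_axis)
ground_o = torch.sum(target, reduce_axis)
fp, fn = pred_o - tp, ground_o - tp

old_denominator = ground_o + pred_o                # removed lines
new_denominator = 2 * (tp + 0.5 * fp + 0.5 * fn)   # added lines, after fp *= 0.5 and fn *= 0.5
assert torch.allclose(old_denominator, new_denominator)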