@@ -64,6 +64,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        else:
            y_pred = torch.sigmoid(y_pred)

+        if y_pred.shape[1] == 1:
+            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+            y_true = torch.cat([1 - y_true, y_true], dim=1)
+
        n_pred_ch = y_pred.shape[1]

        if self.to_onehot_y:
@@ -77,7 +81,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

        # clip the prediction to avoid NaN
        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
-
        axis = list(range(2, len(y_pred.shape)))

        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
@@ -86,18 +89,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        fp = torch.sum((1 - y_true) * y_pred, dim=axis)
        dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

-        # Calculate losses separately for each class, enhancing both classes
+        # Class 0 is Background
        back_dice = 1 - dice_class[:, 0]

-        if n_pred_ch > 1:
-            fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)
+        # Class 1+ is Foreground
+        fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)

-            if fore_dice.shape[1] > 1:
-                fore_dice = torch.mean(fore_dice, dim=1)
-            else:
-                fore_dice = fore_dice.squeeze(1)
+        if fore_dice.shape[1] > 1:
+            fore_dice = torch.mean(fore_dice, dim=1)
        else:
-            fore_dice = torch.zeros_like(back_dice)
+            fore_dice = fore_dice.squeeze(1)

        # Average class scores
        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
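
For a single-channel sigmoid prediction, the added `torch.cat([1 - y_pred, y_pred], dim=1)` branch builds explicit [background, foreground] channels so the per-class Dice terms above can index channel 0 and channel 1. A minimal standalone sketch of that path (the tensor shapes and the delta/epsilon values are illustrative assumptions, not taken from this patch):

import torch

# Hypothetical inputs: batch of 2, single-channel sigmoid output over a 4x4 map.
y_pred = torch.sigmoid(torch.randn(2, 1, 4, 4))
y_true = torch.randint(0, 2, (2, 1, 4, 4)).float()

# Expand the single channel into explicit [background, foreground] channels.
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
y_true = torch.cat([1 - y_true, y_true], dim=1)

delta, epsilon = 0.7, 1e-7  # assumed values for illustration
axis = list(range(2, len(y_pred.shape)))  # reduce over spatial dims only

tp = torch.sum(y_true * y_pred, dim=axis)
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
dice_class = (tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon)
print(dice_class.shape)  # torch.Size([2, 2]): one Dice score per class per sample
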
@@ -149,6 +150,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
            y_log_pred = F.logsigmoid(y_pred)
            y_pred = torch.sigmoid(y_pred)

+        if y_pred.shape[1] == 1:
+            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+            y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))
+            y_true = torch.cat([1 - y_true, y_true], dim=1)
+
        n_pred_ch = y_pred.shape[1]

        if self.to_onehot_y:
@@ -163,19 +169,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
        cross_entropy = -y_true * y_log_pred

+        # Class 0: Background
        back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
        back_ce = (1 - self.delta) * back_ce

-        if n_pred_ch > 1:
-            fore_ce = cross_entropy[:, 1:]
-            fore_ce = self.delta * fore_ce
+        # Class 1+: Foreground
+        fore_ce = cross_entropy[:, 1:]
+        fore_ce = self.delta * fore_ce

-            if fore_ce.shape[1] > 1:
-                fore_ce = torch.sum(fore_ce, dim=1)
-            else:
-                fore_ce = fore_ce.squeeze(1)
+        if fore_ce.shape[1] > 1:
+            fore_ce = torch.sum(fore_ce, dim=1)
        else:
-            fore_ce = torch.zeros_like(back_ce)
+            fore_ce = fore_ce.squeeze(1)

        loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1))
        return loss
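
The second hunk applies the same single-channel expansion to the focal cross-entropy path: background (class 0) gets the focal modulation, foreground (class 1+) gets the delta weight. A standalone sketch mirroring the patched branch, with tensor shapes and delta/gamma/epsilon values assumed for illustration:

import torch

# Hypothetical inputs: single-channel logits and binary targets for a 4x4 map.
logits = torch.randn(2, 1, 4, 4)
y_true = torch.randint(0, 2, (2, 1, 4, 4)).float()

delta, gamma, epsilon = 0.7, 2.0, 1e-7  # assumed hyper-parameters

y_pred = torch.sigmoid(logits)
# Single-channel case: build explicit [background, foreground] channels as in the patch.
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))
y_true = torch.cat([1 - y_true, y_true], dim=1)

y_pred = torch.clamp(y_pred, epsilon, 1.0 - epsilon)
cross_entropy = -y_true * y_log_pred

# Background (class 0) gets the focal modulation; foreground (class 1+) gets the delta weight.
back_ce = (1 - delta) * torch.pow(1 - y_pred[:, 0], gamma) * cross_entropy[:, 0]
fore_ce = delta * cross_entropy[:, 1:]
fore_ce = torch.sum(fore_ce, dim=1) if fore_ce.shape[1] > 1 else fore_ce.squeeze(1)

loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1))
print(loss.item())
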