@@ -79,14 +79,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7979 dice_class = (tp + self .epsilon ) / (tp + self .delta * fn + (1 - self .delta ) * fp + self .epsilon )
8080
8181 # Calculate losses separately for each class, enhancing both classes
82- back_dice = 1 - dice_class [:, 0 ]
83- fore_dice = (1 - dice_class [:, 1 ]) * torch .pow (1 - dice_class [:, 1 ], - self .gamma )
82+ back_dice = 1 - dice_class [:, 0 : 1 ]
83+ fore_dice = (1 - dice_class [:, 1 : ]) * torch .pow (1 - dice_class [:, 1 : ], - self .gamma )
8484
8585 if not self .include_background :
8686 back_dice = back_dice * 0.0
8787
88+ all_dice = torch .cat ([back_dice , fore_dice ], dim = 1 )
89+
8890 # Average class scores
89- loss = torch .mean (torch . stack ([ back_dice , fore_dice ], dim = - 1 ) )
91+ loss = torch .mean (all_dice )
9092 return loss
9193
9294
@@ -141,16 +143,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
141143 y_pred = torch .clamp (y_pred , self .epsilon , 1.0 - self .epsilon )
142144 cross_entropy = - y_true * torch .log (y_pred )
143145
144- back_ce = torch .pow (1 - y_pred [:, 0 ], self .gamma ) * cross_entropy [:, 0 ]
146+ back_ce = torch .pow (1 - y_pred [:, 0 : 1 ], self .gamma ) * cross_entropy [:, 0 : 1 ]
145147 back_ce = (1 - self .delta ) * back_ce
146148
147- fore_ce = cross_entropy [:, 1 ]
149+ fore_ce = cross_entropy [:, 1 : ]
148150 fore_ce = self .delta * fore_ce
149151
150152 if not self .include_background :
151153 back_ce = back_ce * 0.0
152154
153- loss = torch .mean (torch .sum (torch .stack ([back_ce , fore_ce ], dim = 1 ), dim = 1 ))
155+ all_ce = torch .cat ([back_ce , fore_ce ], dim = 1 )
156+
157+ loss = torch .mean (torch .sum (all_ce , dim = 1 ))
154158 return loss
155159
156160
0 commit comments