@@ -73,7 +73,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7373
7474 if not self .include_background :
7575 if n_pred_ch == 1 :
76- warnings .warn ("single channel prediction, `include_background=False` ignored." )
76+ warnings .warn ("single channel prediction, `include_background=False` ignored." , stacklevel = 2 )
7777 else :
7878 # if skipping background, removing first channel
7979 y_true = y_true [:, 1 :]
@@ -110,7 +110,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
110110 loss_list .append (1 - dice_class [:, i ])
111111 else :
112112 # Foreground classes: apply focal modulation
113- # Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
114113 back_dice = torch .clamp (1 - dice_class [:, i ], min = self .epsilon )
115114 loss_list .append (back_dice * torch .pow (back_dice , - self .gamma ))
116115
@@ -176,8 +175,11 @@ def __init__(
176175 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
177176 """
178177 Args:
179- y_pred: (BNH[WD]) Logits (raw scores).
178+ y_pred: (BNH[WD]) Logits (raw scores, not probabilities).
179+ Do not pass pre-activated inputs; activation is applied internally.
180180 y_true: (BNH[WD]) Ground truth labels.
181+ Returns:
182+ torch.Tensor: Weighted combination of focal loss and asymmetric focal Tversky loss.
181183 """
182184 focal_loss = self .focal_loss (y_pred , y_true )
183185
0 commit comments