
Commit 306d338

add docstring
Signed-off-by: ytl0623 <[email protected]>
1 parent c9002e0 commit 306d338

File tree

2 files changed: +13 -3 lines changed


monai/losses/unified_focal_loss.py

Lines changed: 5 additions & 3 deletions
@@ -73,7 +73,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
 
         if not self.include_background:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
             else:
                 # if skipping background, removing first channel
                 y_true = y_true[:, 1:]
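Note (not part of the diff): `stacklevel=2` makes Python attribute the warning to the code that called `forward`, rather than to the `warnings.warn` line inside the loss. A minimal sketch of the effect, using a stand-in function:

import warnings

def forward_stub():
    # With stacklevel=2 the reported file/line point at the caller of
    # forward_stub(), so users see where they triggered the warning.
    warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

forward_stub()  # the UserWarning is reported at this call site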
@@ -110,7 +110,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
                 loss_list.append(1 - dice_class[:, i])
             else:
                 # Foreground classes: apply focal modulation
-                # Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
                 back_dice = torch.clamp(1 - dice_class[:, i], min=self.epsilon)
                 loss_list.append(back_dice * torch.pow(back_dice, -self.gamma))
 
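The deleted comment documented the algebra these two lines implement: back_dice * back_dice**(-gamma) equals (1 - dice)^(1 - gamma), and the clamp keeps the base strictly positive so the negative exponent cannot produce inf/nan. A small numeric sketch (the gamma and epsilon values are illustrative, not read from this diff):

import torch

gamma, epsilon = 0.75, 1e-7
dice = torch.tensor([1.0, 0.5, 0.9])  # example per-class Dice scores

# Without the clamp, dice == 1.0 would give 0 * 0**(-gamma) = 0 * inf = nan.
back_dice = torch.clamp(1 - dice, min=epsilon)
loss_term = back_dice * torch.pow(back_dice, -gamma)
print(loss_term)                        # tensor([0.0178, 0.8409, 0.5623])
print(torch.pow(back_dice, 1 - gamma))  # identical result: (1 - dice)^(1 - gamma)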
@@ -176,8 +175,11 @@ def __init__(
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            y_pred: (BNH[WD]) Logits (raw scores).
+            y_pred: (BNH[WD]) Logits (raw scores, not probabilities).
+                Do not pass pre-activated inputs; activation is applied internally.
             y_true: (BNH[WD]) Ground truth labels.
+        Returns:
+            torch.Tensor: Weighted combination of focal loss and asymmetric focal Tversky loss.
         """
         focal_loss = self.focal_loss(y_pred, y_true)
 
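The docstring change makes the input contract explicit: pass raw logits and let the loss apply the activation internally. A minimal usage sketch; the constructor keywords shown (use_softmax, include_background) are the ones the test docstring below mentions for this PR branch and may differ from the released API:

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Assumed constructor kwargs for this PR branch.
loss_fn = AsymmetricUnifiedFocalLoss(use_softmax=True, include_background=True)

logits = torch.randn(2, 2, 64, 64)                    # (B, N, H, W) raw scores, no softmax/sigmoid applied
labels = torch.randint(0, 2, (2, 1, 64, 64)).float()  # (B, 1, H, W) ground-truth labels

loss = loss_fn(logits, labels)  # activation and channel handling happen inside forward
print(loss)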
tests/losses/test_unified_focal_loss.py

Lines changed: 8 additions & 0 deletions
@@ -72,6 +72,14 @@ class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):
 
     @parameterized.expand(TEST_CASES)
     def test_result(self, input_param, input_data, expected_val):
+        """
+        Test AsymmetricUnifiedFocalLoss with various configurations.
+
+        Args:
+            input_param: Dict of loss constructor parameters (use_softmax, include_background, etc.).
+            input_data: Dict containing y_pred (logits) and y_true (ground truth) tensors.
+            expected_val: Expected loss value.
+        """
         loss = AsymmetricUnifiedFocalLoss(**input_param)
         result = loss(**input_data)
         np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
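For readers unfamiliar with parameterized.expand, each TEST_CASES entry unpacks into the three arguments the new docstring describes. A hypothetical entry shape (names and values are illustrative, not copied from the test file):

import torch

# (input_param, input_data, expected_val) -- one hypothetical TEST_CASES entry
EXAMPLE_CASE = [
    {"use_softmax": True, "include_background": True},        # loss constructor kwargs
    {
        "y_pred": torch.randn(1, 2, 8, 8),                    # raw logits, (B, N, H, W)
        "y_true": torch.randint(0, 2, (1, 1, 8, 8)).float(),  # ground-truth labels
    },
    0.1234,                                                    # placeholder expected loss value
]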

0 commit comments
