
Commit 88b1182

Weights in alpha for FocalLoss
Signed-off-by: ytl0623 <[email protected]>
1 parent 15fd428 commit 88b1182


monai/losses/focal_loss.py

Lines changed: 41 additions & 13 deletions
@@ -70,19 +70,21 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
-        alpha: float | None = None,
+        alpha: float | Sequence[float] | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
         use_softmax: bool = False,
     ) -> None:
         """
         Args:
             include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
-                If False, `alpha` is invalid when using softmax.
+                If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit class weights).
             to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
             gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
             alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
-                The value should be in [0, 1]. Defaults to None.
+                The value should be in [0, 1].
+                If a sequence is provided, it must match the number of classes (after excluding background if set).
+                Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
                 The input can be a single value (same weight for all classes), a sequence of values (the length
                 of the sequence should be the same as the number of classes. If not ``include_background``,
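
A minimal usage sketch of the per-class `alpha` enabled by this change; the class count, shapes, and values below are illustrative only, not taken from the diff:

    import torch
    from monai.losses import FocalLoss

    # three classes (background + two foreground), one alpha weight per class
    loss_fn = FocalLoss(to_onehot_y=True, use_softmax=True, alpha=[0.25, 0.5, 0.25])

    pred = torch.randn(2, 3, 32, 32)             # (B, C, H, W) logits
    label = torch.randint(0, 3, (2, 1, 32, 32))  # class-index labels, one-hot encoded by the loss
    value = loss_fn(pred, label)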
@@ -156,13 +158,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss: Optional[torch.Tensor] = None
         input = input.float()
         target = target.float()
+
+        alpha_arg = self.alpha
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
-                self.alpha = None
-                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
-            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+                if isinstance(self.alpha, (float, int)):
+                    alpha_arg = None
+                    warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
+            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)

         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
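
A standalone sketch of the dispatch added in `forward()`: with softmax and `include_background=False`, a scalar `alpha` is dropped with a warning, while a sequence `alpha` is passed through. The helper name below is hypothetical, used only to illustrate the logic:

    import warnings

    def resolve_alpha(alpha, include_background: bool, use_softmax: bool):
        # mirrors the new forward() logic: only a scalar alpha is discarded
        if use_softmax and not include_background and isinstance(alpha, (float, int)):
            warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
            return None
        return alpha

    assert resolve_alpha(0.25, include_background=False, use_softmax=True) is None
    assert resolve_alpha([0.3, 0.7], include_background=False, use_softmax=True) == [0.3, 0.7]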
@@ -203,7 +208,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


 def softmax_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -215,8 +220,18 @@ def softmax_focal_loss(
     loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target

     if alpha is not None:
-        # (1-alpha) for the background class and alpha for the other classes
-        alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+
+        if alpha_t.ndim == 0:  # scalar
+            # (1-alpha) for the background class and alpha for the other classes
+            alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        else:  # sequence
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            alpha_fac = alpha_t
+
         broadcast_dims = [-1] + [1] * len(target.shape[2:])
         alpha_fac = alpha_fac.view(broadcast_dims)
         loss = alpha_fac * loss
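
The broadcast at the end of this hunk reshapes the per-class weights to (C, 1, ..., 1) so they multiply the (B, C, spatial...) loss channel-wise. A small shape check with dummy tensors, assuming a 2D spatial target:

    import torch

    target = torch.zeros(2, 3, 4, 4)                     # (B, C, H, W) one-hot target
    alpha_fac = torch.tensor([0.2, 0.3, 0.5])            # one weight per class
    broadcast_dims = [-1] + [1] * len(target.shape[2:])  # -> [-1, 1, 1]
    alpha_fac = alpha_fac.view(broadcast_dims)           # shape (3, 1, 1)
    assert (alpha_fac * target).shape == target.shape    # broadcasts over batch and spatial dims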
@@ -225,7 +240,7 @@ def softmax_focal_loss(


 def sigmoid_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -248,8 +263,21 @@ def sigmoid_focal_loss(
     loss = (invprobs * gamma).exp() * loss

     if alpha is not None:
-        # alpha if t==1; (1-alpha) if t==0
-        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+        if alpha_t.ndim == 0:  # scalar
+            # alpha if t==1; (1-alpha) if t==0
+            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+        else:  # sequence / per-channel alpha
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            # Reshape alpha for broadcasting: (1, C, 1, 1...)
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            alpha_t = alpha_t.view(broadcast_dims)
+            # Apply alpha_c if t==1, (1-alpha_c) if t==0 for channel c
+            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+
         loss = alpha_factor * loss

     return loss
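
For the sigmoid path, the per-channel factor selects alpha_c where target == 1 and (1 - alpha_c) where target == 0. A quick standalone check with dummy values:

    import torch

    target = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])  # (B=1, C=2, W=2)
    alpha_t = torch.tensor([0.25, 0.75]).view(-1, 1)   # (C, 1), broadcasts over the spatial dim
    alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
    # channel 0: 0.25 where t==1, 0.75 where t==0; channel 1: 0.75 where t==1, 0.25 where t==0
    print(alpha_factor)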
