Skip to content

Commit 1b24834

Browse files
committed
Fix undefined-type error by explicitly annotating `self.alpha` and `alpha_arg` (`float | torch.Tensor | None`) and converting list/tuple `alpha` values to a tensor in `monai/losses/focal_loss.py`
Signed-off-by: ytl0623 <[email protected]>
1 parent 50cc7e9 commit 1b24834

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

monai/losses/focal_loss.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,17 @@ def __init__(
112112
self.include_background = include_background
113113
self.to_onehot_y = to_onehot_y
114114
self.gamma = gamma
115-
self.alpha = alpha
116115
self.weight = weight
117116
self.use_softmax = use_softmax
118117
weight = torch.as_tensor(weight) if weight is not None else None
119118
self.register_buffer("class_weight", weight)
120119
self.class_weight: None | torch.Tensor
120+
self.alpha: float | torch.Tensor | None
121+
122+
if isinstance(alpha, (list, tuple)):
123+
self.alpha = torch.tensor(alpha)
124+
else:
125+
self.alpha = alpha
121126

122127
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
123128
"""
@@ -159,7 +164,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
159164
input = input.float()
160165
target = target.float()
161166

162-
alpha_arg = self.alpha
167+
alpha_arg: float | torch.Tensor | None = self.alpha
168+
if isinstance(alpha_arg, torch.Tensor):
169+
alpha_arg = alpha_arg.to(input.device)
170+
163171
if self.use_softmax:
164172
if not self.include_background and self.alpha is not None:
165173
if isinstance(self.alpha, (float, int)):
@@ -208,7 +216,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
208216

209217

210218
def softmax_focal_loss(
211-
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
219+
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
212220
) -> torch.Tensor:
213221
"""
214222
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -241,7 +249,7 @@ def softmax_focal_loss(
241249

242250

243251
def sigmoid_focal_loss(
244-
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
252+
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
245253
) -> torch.Tensor:
246254
"""
247255
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

0 commit comments

Comments
 (0)