@@ -112,12 +112,17 @@ def __init__(
112112 self .include_background = include_background
113113 self .to_onehot_y = to_onehot_y
114114 self .gamma = gamma
115- self .alpha = alpha
116115 self .weight = weight
117116 self .use_softmax = use_softmax
118117 weight = torch .as_tensor (weight ) if weight is not None else None
119118 self .register_buffer ("class_weight" , weight )
120119 self .class_weight : None | torch .Tensor
120+ self .alpha : float | torch .Tensor | None
121+
122+ if isinstance (alpha , (list , tuple )):
123+ self .alpha = torch .tensor (alpha )
124+ else :
125+ self .alpha = alpha
121126
122127 def forward (self , input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
123128 """
@@ -159,7 +164,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
159164 input = input .float ()
160165 target = target .float ()
161166
162- alpha_arg = self .alpha
167+ alpha_arg : float | torch .Tensor | None = self .alpha
168+ if isinstance (alpha_arg , torch .Tensor ):
169+ alpha_arg = alpha_arg .to (input .device )
170+
163171 if self .use_softmax :
164172 if not self .include_background and self .alpha is not None :
165173 if isinstance (self .alpha , (float , int )):
@@ -208,7 +216,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
208216
209217
210218def softmax_focal_loss (
211- input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | Sequence [ float ] | None = None
219+ input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | torch . Tensor | None = None
212220) -> torch .Tensor :
213221 """
214222 FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -241,7 +249,7 @@ def softmax_focal_loss(
241249
242250
243251def sigmoid_focal_loss (
244- input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | Sequence [ float ] | None = None
252+ input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | torch . Tensor | None = None
245253) -> torch .Tensor :
246254 """
247255 FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
0 commit comments