@@ -16,6 +16,7 @@ def __init__(
16
16
log_loss : bool = False ,
17
17
from_logits : bool = True ,
18
18
smooth : float = 0.0 ,
19
+ ignore_index : Optional [int ] = None ,
19
20
eps : float = 1e-7 ,
20
21
):
21
22
"""Jaccard loss for image segmentation task.
@@ -29,6 +30,8 @@ def __init__(
29
30
otherwise `1 - jaccard_coeff`
30
31
from_logits: If True, assumes input is raw logits
31
32
smooth: Smoothness constant for dice coefficient
33
+ ignore_index: Label that indicates ignored pixels
34
+ (does not contribute to loss)
32
35
eps: A small epsilon for numerical stability to avoid zero division error
33
36
(denominator will be always greater or equal to eps)
34
37
@@ -53,6 +56,7 @@ def __init__(
53
56
self .from_logits = from_logits
54
57
self .smooth = smooth
55
58
self .eps = eps
59
+ self .ignore_index = ignore_index
56
60
self .log_loss = log_loss
57
61
58
62
def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
@@ -76,17 +80,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
76
80
y_true = y_true .view (bs , 1 , - 1 )
77
81
y_pred = y_pred .view (bs , 1 , - 1 )
78
82
83
+ if self .ignore_index is not None :
84
+ mask = y_true != self .ignore_index
85
+ y_pred = y_pred * mask
86
+ y_true = y_true * mask
87
+
79
88
if self .mode == MULTICLASS_MODE :
80
89
y_true = y_true .view (bs , - 1 )
81
90
y_pred = y_pred .view (bs , num_classes , - 1 )
82
91
83
- y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
84
- y_true = y_true .permute (0 , 2 , 1 ) # H, C, H*W
92
+ if self .ignore_index is not None :
93
+ mask = y_true != self .ignore_index
94
+ y_pred = y_pred * mask .unsqueeze (1 )
95
+
96
+ y_true = F .one_hot (
97
+ (y_true * mask ).to (torch .long ), num_classes
98
+ ) # N,H*W -> N,H*W, C
99
+ y_true = y_true .permute (0 , 2 , 1 ) * mask .unsqueeze (1 ) # N, C, H*W
100
+ else :
101
+ y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
102
+ y_true = y_true .permute (0 , 2 , 1 ) # H, C, H*W
85
103
86
104
if self .mode == MULTILABEL_MODE :
87
105
y_true = y_true .view (bs , num_classes , - 1 )
88
106
y_pred = y_pred .view (bs , num_classes , - 1 )
89
107
108
+ if self .ignore_index is not None :
109
+ mask = y_true != self .ignore_index
110
+ y_pred = y_pred * mask
111
+ y_true = y_true * mask
112
+
90
113
scores = soft_jaccard_score (
91
114
y_pred ,
92
115
y_true .type (y_pred .dtype ),
0 commit comments