Skip to content

Commit bdea2e8

Browse files
authored
Add classes/ignore_index to losses (#11)
1 parent 3401b84 commit bdea2e8

File tree

3 files changed

+49
-8
lines changed

3 files changed

+49
-8
lines changed

tests/test_losses.py

+13
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torchseg.losses._functional as F
66
from torchseg.losses import (
77
DiceLoss,
8+
FocalLoss,
89
JaccardLoss,
910
MCCLoss,
1011
SoftBCEWithLogitsLoss,
@@ -333,3 +334,15 @@ def test_binary_mcc_loss():
333334

334335
loss = criterion(y_pred, y_true)
335336
assert float(loss) == pytest.approx(0.5, abs=eps)
337+
338+
339+
@torch.no_grad()
340+
@pytest.mark.parametrize("loss_fn", [DiceLoss, JaccardLoss, FocalLoss])
341+
@pytest.mark.parametrize("classes", [None, [1]])
342+
@pytest.mark.parametrize("ignore_index", [None, 0, -255])
343+
def test_classes_arg(loss_fn, classes, ignore_index):
344+
criterion = loss_fn(mode="multiclass", classes=classes, ignore_index=ignore_index)
345+
y_pred = torch.zeros(1, 2, 128, 128, dtype=torch.float)
346+
y_pred[:, 0, ...] = 1.0
347+
y_true = torch.ones(1, 128, 128, dtype=torch.long)
348+
criterion(y_pred, y_true)

torchseg/losses/focal.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(
1818
reduction: Optional[str] = "mean",
1919
normalized: bool = False,
2020
reduced_threshold: Optional[float] = None,
21+
classes: Optional[list[int]] = None,
2122
):
2223
"""Compute Focal loss
2324
@@ -30,6 +31,8 @@ def __init__(
3031
normalized: Use normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf)
3132
reduced_threshold: Switch to reduced focal loss.
3233
Note, when using this mode you should use `reduction="sum"`.
34+
classes: List of classes that contribute in loss computation.
35+
By default, all channels are included. Only supported in multiclass mode
3336
3437
Shape
3538
- **y_pred** - torch.Tensor of shape (N, C, H, W)
@@ -44,6 +47,7 @@ def __init__(
4447

4548
self.mode = mode
4649
self.ignore_index = ignore_index
50+
self.classes = classes
4751
self.focal_loss_fn = partial(
4852
focal_loss_with_logits,
4953
alpha=alpha,
@@ -75,13 +79,14 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7579
not_ignored = y_true != self.ignore_index
7680

7781
for cls in range(num_classes):
78-
cls_y_true = (y_true == cls).long()
79-
cls_y_pred = y_pred[:, cls, ...]
82+
if self.classes is None or cls in self.classes:
83+
cls_y_true = (y_true == cls).long()
84+
cls_y_pred = y_pred[:, cls, ...]
8085

81-
if self.ignore_index is not None:
82-
cls_y_true = cls_y_true[not_ignored]
83-
cls_y_pred = cls_y_pred[not_ignored]
86+
if self.ignore_index is not None:
87+
cls_y_true = cls_y_true[not_ignored]
88+
cls_y_pred = cls_y_pred[not_ignored]
8489

85-
loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
90+
loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
8691

8792
return loss

torchseg/losses/jaccard.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(
1616
log_loss: bool = False,
1717
from_logits: bool = True,
1818
smooth: float = 0.0,
19+
ignore_index: Optional[int] = None,
1920
eps: float = 1e-7,
2021
):
2122
"""Jaccard loss for image segmentation task.
@@ -29,6 +30,8 @@ def __init__(
2930
otherwise `1 - jaccard_coeff`
3031
from_logits: If True, assumes input is raw logits
3132
smooth: Smoothness constant for dice coefficient
33+
ignore_index: Label that indicates ignored pixels
34+
(does not contribute to loss)
3235
eps: A small epsilon for numerical stability to avoid zero division error
3336
(denominator will be always greater or equal to eps)
3437
@@ -53,6 +56,7 @@ def __init__(
5356
self.from_logits = from_logits
5457
self.smooth = smooth
5558
self.eps = eps
59+
self.ignore_index = ignore_index
5660
self.log_loss = log_loss
5761

5862
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
@@ -76,17 +80,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7680
y_true = y_true.view(bs, 1, -1)
7781
y_pred = y_pred.view(bs, 1, -1)
7882

83+
if self.ignore_index is not None:
84+
mask = y_true != self.ignore_index
85+
y_pred = y_pred * mask
86+
y_true = y_true * mask
87+
7988
if self.mode == MULTICLASS_MODE:
8089
y_true = y_true.view(bs, -1)
8190
y_pred = y_pred.view(bs, num_classes, -1)
8291

83-
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
84-
y_true = y_true.permute(0, 2, 1) # H, C, H*W
92+
if self.ignore_index is not None:
93+
mask = y_true != self.ignore_index
94+
y_pred = y_pred * mask.unsqueeze(1)
95+
96+
y_true = F.one_hot(
97+
(y_true * mask).to(torch.long), num_classes
98+
) # N,H*W -> N,H*W, C
99+
y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W
100+
else:
101+
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
102+
y_true = y_true.permute(0, 2, 1) # H, C, H*W
85103

86104
if self.mode == MULTILABEL_MODE:
87105
y_true = y_true.view(bs, num_classes, -1)
88106
y_pred = y_pred.view(bs, num_classes, -1)
89107

108+
if self.ignore_index is not None:
109+
mask = y_true != self.ignore_index
110+
y_pred = y_pred * mask
111+
y_true = y_true * mask
112+
90113
scores = soft_jaccard_score(
91114
y_pred,
92115
y_true.type(y_pred.dtype),

0 commit comments

Comments
 (0)