-
Notifications
You must be signed in to change notification settings - Fork 465
Adding ignore_index to segmentation mean_iou
#3266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 4 commits
2119739
e1aa57a
b1adbc8
59744c1
caff134
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,8 +72,13 @@ def _mean_iou_update( | |
| num_classes: Optional[int] = None, | ||
| include_background: bool = False, | ||
| input_format: Literal["one-hot", "index", "mixed"] = "one-hot", | ||
| ignore_index: Optional[int] = None, | ||
| ) -> tuple[Tensor, Tensor]: | ||
| """Update the intersection and union counts for the mean IoU computation.""" | ||
| if ignore_index is not None and input_format == "index": | ||
| idx = target == ignore_index | ||
| target, preds = target[~idx], preds[~idx] | ||
|
||
|
|
||
| preds, target = _mean_iou_reshape_args(preds, target, input_format) | ||
|
|
||
| preds, target = _segmentation_inputs_format(preds, target, include_background, num_classes, input_format) | ||
|
|
@@ -102,6 +107,7 @@ def mean_iou( | |
| include_background: bool = True, | ||
| per_class: bool = False, | ||
| input_format: Literal["one-hot", "index", "mixed"] = "one-hot", | ||
| ignore_index: Optional[int] = None, | ||
| ) -> Tensor: | ||
| """Calculates the mean Intersection over Union (mIoU) for semantic segmentation. | ||
|
|
||
|
|
@@ -117,6 +123,8 @@ def mean_iou( | |
| input_format: What kind of input the function receives. | ||
| Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors | ||
| or ``"mixed"`` for one one-hot encoded and one index tensor | ||
| ignore_index: Class index to ignore in the target. This class will be ignored | ||
| in both the intersection and union computation. Only used when ``input_format="index"``. | ||
|
|
||
| Returns: | ||
| The mean IoU score | ||
|
|
@@ -151,7 +159,7 @@ def mean_iou( | |
|
|
||
| """ | ||
| _mean_iou_validate_args(num_classes, include_background, per_class, input_format) | ||
| intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format) | ||
| intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format, ignore_index) | ||
| scores = _mean_iou_compute(intersection, union, zero_division="nan") | ||
| valid_classes = union > 0 | ||
| return scores.nan_to_num(-1.0) if per_class else scores.nansum(dim=-1) / valid_classes.sum(dim=-1) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -22,7 +22,7 @@ | |||||||||||
| from torchmetrics.functional.segmentation.mean_iou import mean_iou | ||||||||||||
| from torchmetrics.segmentation.mean_iou import MeanIoU | ||||||||||||
| from unittests import NUM_CLASSES | ||||||||||||
| from unittests._helpers.testers import MetricTester | ||||||||||||
| from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index | ||||||||||||
| from unittests.segmentation.inputs import ( | ||||||||||||
| _index_input_1, | ||||||||||||
| _mixed_input_1, | ||||||||||||
|
|
@@ -41,27 +41,30 @@ def _reference_mean_iou( | |||||||||||
| include_background: bool = True, | ||||||||||||
| per_class: bool = True, | ||||||||||||
| reduce: bool = True, | ||||||||||||
| ignore_index: Optional[int] = None, | ||||||||||||
| ): | ||||||||||||
| """Calculate reference metric for `MeanIoU`.""" | ||||||||||||
| if input_format == "index": | ||||||||||||
| target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) | ||||||||||||
| preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) | ||||||||||||
| target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) | ||||||||||||
| elif input_format == "mixed": | ||||||||||||
| if preds.dim() == (target.dim() + 1): | ||||||||||||
| if torch.is_floating_point(preds): | ||||||||||||
| preds = preds.argmax(dim=1) | ||||||||||||
| preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) | ||||||||||||
| target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) | ||||||||||||
| preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) | ||||||||||||
| target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) | ||||||||||||
| elif (preds.dim() + 1) == target.dim(): | ||||||||||||
| if torch.is_floating_point(target): | ||||||||||||
| target = target.argmax(dim=1) | ||||||||||||
| target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) | ||||||||||||
| preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) | ||||||||||||
| target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) | ||||||||||||
| preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) | ||||||||||||
|
|
||||||||||||
| val = compute_iou(preds, target, include_background=include_background) | ||||||||||||
| val[torch.isnan(val)] = 0.0 | ||||||||||||
| if reduce: | ||||||||||||
| return torch.mean(val, 0) if per_class else torch.mean(val) | ||||||||||||
|
|
||||||||||||
| return val | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
|
|
@@ -83,11 +86,14 @@ def _reference_mean_iou( | |||||||||||
| class TestMeanIoU(MetricTester): | ||||||||||||
| """Test class for `MeanIoU` metric.""" | ||||||||||||
|
|
||||||||||||
|
||||||||||||
| # The tolerance has been relaxed from 1e-4 to 1e-2 due to minor numerical differences | |
| # between the reference and implementation, likely caused by floating point precision | |
| # or differences in third-party library computations (e.g., MONAI, PyTorch). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No additional comment needed here since the rationale is self-understood.
Copilot
AI
Sep 19, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The expected IoU values are hardcoded without explanation of how they were calculated. Consider adding a comment explaining the calculation or using a more descriptive variable name to make the test more maintainable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No additional comment needed here since the rationale is self-understood.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we support that also in onehot by basically just removing the row in that dimension?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could do that. However, in most cases the ignore_index value is 255 which will blow up the one-hot tensor, right? So, for one-hot tensors can we assume that the user would handle the masking?