Skip to content
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-



- Fixed `GeneralizedDiceScore` to yield `NaN` if there are missing classes ([#2846](https://github.com/Lightning-AI/torchmetrics/issues/2846))

---

## [1.8.2] - 2025-09-03
Expand Down
15 changes: 6 additions & 9 deletions src/torchmetrics/functional/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class:
if not per_class:
numerator = torch.sum(numerator, 1)
denominator = torch.sum(denominator, 1)
return _safe_divide(numerator, denominator)
else:
numerator = torch.sum(numerator, 0, keepdim=True)
denominator = torch.sum(denominator, 0, keepdim=True)
return _safe_divide(numerator, denominator, "nan")


def generalized_dice_score(
Expand Down Expand Up @@ -126,10 +129,7 @@ def generalized_dice_score(
>>> generalized_dice_score(preds, target, num_classes=5)
tensor([0.4830, 0.4935, 0.5044, 0.4880])
>>> generalized_dice_score(preds, target, num_classes=5, per_class=True)
tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500],
[0.4571, 0.4980, 0.5191, 0.4380, 0.5649],
[0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])
tensor([[0.4845, 0.4997, 0.4993, 0.4864, 0.4912]])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tensor shape shouldn't change here!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made the changes.


Example (with index tensors):
>>> from torch import randint
Expand All @@ -139,10 +139,7 @@ def generalized_dice_score(
>>> generalized_dice_score(preds, target, num_classes=5, input_format="index")
tensor([0.1991, 0.1971, 0.2350, 0.2216])
>>> generalized_dice_score(preds, target, num_classes=5, per_class=True, input_format="index")
tensor([[0.1714, 0.2500, 0.1304, 0.2524, 0.2069],
[0.1837, 0.2162, 0.0962, 0.2692, 0.1895],
[0.3866, 0.1348, 0.2526, 0.2301, 0.2083],
[0.1978, 0.2804, 0.1714, 0.1915, 0.2783]])
tensor([[0.2234, 0.2170, 0.1597, 0.2399, 0.2204]])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the shape should not change here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


"""
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format)
Expand Down
24 changes: 13 additions & 11 deletions src/torchmetrics/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

Expand All @@ -24,6 +23,7 @@
_generalized_dice_validate_args,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

Expand Down Expand Up @@ -100,15 +100,16 @@ class GeneralizedDiceScore(Metric):
tensor(0.4992)
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True)
>>> gds(preds, target)
tensor([0.5001, 0.4993, 0.4982])
tensor([0.5000, 0.4993, 0.4983])
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False)
>>> gds(preds, target)
tensor([0.4993, 0.4982])
tensor([0.4993, 0.4983])

"""

score: Tensor
samples: Tensor
class_present: Tensor
numerator: List[Tensor]
denominator: List[Tensor]
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
Expand All @@ -133,20 +134,21 @@ def __init__(
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("numerator", default=[], dist_reduce_fx="cat")
self.add_state("denominator", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with new data."""
numerator, denominator = _generalized_dice_update(
preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format
)
self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0)
self.samples += preds.shape[0]
self.numerator.append(numerator)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still changes the memory format if you append to a list rather than adding to a tensor.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate on this a bit more?

self.denominator.append(denominator)

def compute(self) -> Tensor:
"""Compute the final generalized dice score."""
return self.score / self.samples
score = _generalized_dice_compute(dim_zero_cat(self.numerator), dim_zero_cat(self.denominator), self.per_class)
return score.mean(dim=0)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/segmentation/test_generalized_dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def _reference_generalized_dice(
class TestGeneralizedDiceScore(MetricTester):
"""Test class for `GeneralizedDiceScore` metric."""

atol = 2e-3

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp):
"""Test class implementation of metric."""
Expand Down Expand Up @@ -122,3 +124,52 @@ def test_generalized_dice_functional(self, preds, target, input_format, include_
"input_format": input_format,
},
)


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_samples_with_missing_classes(per_class, include_background):
    """Test GeneralizedDiceScore with missing classes in some samples.

    Sample 0 holds one matching pixel of class 0 and sample 2 one matching pixel of class 1;
    every other entry of ``preds`` and ``target`` is zero. With ``per_class=True`` each class
    that occurs in the (optionally background-stripped) target must score exactly ``1.0`` and
    every other class must be ``NaN``. With ``per_class=False`` the score is expected to be ``NaN``.
    """
    shape = (4, NUM_CLASSES, 128, 128)
    target = torch.zeros(shape, dtype=torch.int8)
    preds = torch.zeros(shape, dtype=torch.int8)

    # One perfectly predicted pixel per populated (sample, class) pair.
    for sample_idx, class_idx in ((0, 0), (2, 1)):
        target[sample_idx, class_idx, 0, 0] = 1
        preds[sample_idx, class_idx, 0, 0] = 1

    metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)
    score = metric(preds, target)

    if not per_class:
        assert score.isnan()
        return

    # Dropping the background removes channel 0 from both the output and the reference target.
    expected_classes = NUM_CLASSES if include_background else NUM_CLASSES - 1
    kept_target = target if include_background else target[:, 1:]

    assert len(score) == expected_classes
    for class_idx in range(expected_classes):
        if kept_target[:, class_idx].sum() > 0:
            assert score[class_idx] == 1.0
        else:
            assert torch.isnan(score[class_idx])


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_generalized_dice_zero_denominator(per_class, include_background):
    """Check that GeneralizedDiceScore returns NaN when the denominator is all zero (no class present).

    Both ``preds`` and ``target`` are all-zero tensors. With ``per_class=True`` the result must
    have one entry per reported class and every entry must be ``NaN``; otherwise the single
    returned score must be ``NaN``.
    """
    shape = (4, NUM_CLASSES, 128, 128)
    target = torch.zeros(shape, dtype=torch.int8)
    preds = torch.zeros(shape, dtype=torch.int8)

    metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)
    score = metric(preds, target)

    if not per_class:
        # Expect scalar NaN
        assert score.isnan()
        return

    # Excluding the background drops one class from the reported vector.
    expected_len = NUM_CLASSES if include_background else NUM_CLASSES - 1
    assert len(score) == expected_len
    assert all(t.isnan() for t in score)
Loading