-
Notifications
You must be signed in to change notification settings - Fork 465
Bugfix for GeneralizedDiceScore to yield NaN for missing classes
#3251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 17 commits
9505cec
495413c
1963092
981ca9c
4317bc1
44175de
6b4ac90
5dacd29
1410b15
b142344
b820b7e
d2a89e1
4ae909a
8e26b68
ce55fea
59c6150
ef5dade
3524609
c47cb4a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,7 +90,10 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: | |
| if not per_class: | ||
| numerator = torch.sum(numerator, 1) | ||
| denominator = torch.sum(denominator, 1) | ||
| return _safe_divide(numerator, denominator) | ||
| else: | ||
| numerator = torch.sum(numerator, 0, keepdim=True) | ||
| denominator = torch.sum(denominator, 0, keepdim=True) | ||
| return _safe_divide(numerator, denominator, "nan") | ||
|
|
||
|
|
||
| def generalized_dice_score( | ||
|
|
@@ -126,10 +129,7 @@ def generalized_dice_score( | |
| >>> generalized_dice_score(preds, target, num_classes=5) | ||
| tensor([0.4830, 0.4935, 0.5044, 0.4880]) | ||
| >>> generalized_dice_score(preds, target, num_classes=5, per_class=True) | ||
| tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500], | ||
| [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], | ||
| [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], | ||
| [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) | ||
| tensor([[0.4845, 0.4997, 0.4993, 0.4864, 0.4912]]) | ||
|
|
||
| Example (with index tensors): | ||
| >>> from torch import randint | ||
|
|
@@ -139,10 +139,7 @@ def generalized_dice_score( | |
| >>> generalized_dice_score(preds, target, num_classes=5, input_format="index") | ||
| tensor([0.1991, 0.1971, 0.2350, 0.2216]) | ||
| >>> generalized_dice_score(preds, target, num_classes=5, per_class=True, input_format="index") | ||
| tensor([[0.1714, 0.2500, 0.1304, 0.2524, 0.2069], | ||
| [0.1837, 0.2162, 0.0962, 0.2692, 0.1895], | ||
| [0.3866, 0.1348, 0.2526, 0.2301, 0.2083], | ||
| [0.1978, 0.2804, 0.1714, 0.1915, 0.2783]]) | ||
| tensor([[0.2234, 0.2170, 0.1597, 0.2399, 0.2204]]) | ||
|
||
|
|
||
| """ | ||
| _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,9 +12,8 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from collections.abc import Sequence | ||
| from typing import Any, Optional, Union | ||
| from typing import Any, List, Optional, Union | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
| from typing_extensions import Literal | ||
|
|
||
|
|
@@ -24,6 +23,7 @@ | |
| _generalized_dice_validate_args, | ||
| ) | ||
| from torchmetrics.metric import Metric | ||
| from torchmetrics.utilities.data import dim_zero_cat | ||
| from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE | ||
| from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE | ||
|
|
||
|
|
@@ -100,15 +100,16 @@ class GeneralizedDiceScore(Metric): | |
| tensor(0.4992) | ||
| >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True) | ||
| >>> gds(preds, target) | ||
| tensor([0.5001, 0.4993, 0.4982]) | ||
| tensor([0.5000, 0.4993, 0.4983]) | ||
| >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) | ||
| >>> gds(preds, target) | ||
| tensor([0.4993, 0.4982]) | ||
| tensor([0.4993, 0.4983]) | ||
|
|
||
| """ | ||
|
|
||
| score: Tensor | ||
| samples: Tensor | ||
| class_present: Tensor | ||
| numerator: List[Tensor] | ||
| denominator: List[Tensor] | ||
| full_state_update: bool = False | ||
| is_differentiable: bool = False | ||
| higher_is_better: bool = True | ||
|
|
@@ -133,20 +134,21 @@ def __init__( | |
| self.input_format = input_format | ||
|
|
||
| num_classes = num_classes - 1 if not include_background else num_classes | ||
| self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") | ||
| self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum") | ||
| self.add_state("numerator", default=[], dist_reduce_fx="cat") | ||
| self.add_state("denominator", default=[], dist_reduce_fx="cat") | ||
|
|
||
| def update(self, preds: Tensor, target: Tensor) -> None: | ||
| """Update the state with new data.""" | ||
| numerator, denominator = _generalized_dice_update( | ||
| preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format | ||
| ) | ||
| self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0) | ||
| self.samples += preds.shape[0] | ||
| self.numerator.append(numerator) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this still changes the memory format if you add to a list rather than adding to a tensor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please elaborate this a bit more? |
||
| self.denominator.append(denominator) | ||
|
|
||
| def compute(self) -> Tensor: | ||
| """Compute the final generalized dice score.""" | ||
| return self.score / self.samples | ||
| score = _generalized_dice_compute(dim_zero_cat(self.numerator), dim_zero_cat(self.denominator), self.per_class) | ||
| return score.mean(dim=0) | ||
|
|
||
| def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: | ||
| """Plot a single or multiple values from the metric. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the tensorshape shouldn't change here!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've made the changes.