SoftDiceclDiceLoss is not well documented and produces zero loss as it is #8239

@chezhia

Description

Describe the bug
The SoftDiceclDiceLoss implementation differs from DiceLoss and, in its current form, cannot be used as a drop-in replacement for Dice or the other popular losses MONAI offers. There is no option to exclude the background channel, apply an activation function, and so on. The expected input is also undocumented: should y_pred contain probabilities, logits, or a binary mask like the ground truth (y_true)?

To Reproduce
Use this enhanced class that mimics other MONAI losses:

import warnings
from collections.abc import Callable

import torch
from monai.losses import SoftDiceclDiceLoss
from monai.networks import one_hot
from torch.nn.modules.loss import _Loss


class EnhancedSoftDiceClDiceLoss(_Loss):
    """
    Enhanced version of SoftDiceclDiceLoss with support for:
    - Excluding background channel
    - Applying activations (sigmoid, softmax, or custom)
    - Handling one-hot encoded targets
    - Flexible reduction (mean, sum, none)
    """

    def __init__(
        self,
        iter_: int = 3,
        alpha: float = 0.5,
        smooth: float = 1.0,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        reduction: str = "mean",
    ) -> None:
        """
        Args:
            iter_: Number of iterations for skeletonization.
            alpha: Weighting factor for the clDice term.
            smooth: Smoothing parameter.
            include_background: If False, excludes the background channel from the loss computation.
            to_onehot_y: If True, converts `y_true` into one-hot format. Defaults to False.
            sigmoid: If True, applies sigmoid activation to predictions.
            softmax: If True, applies softmax activation to predictions.
            other_act: Callable function for custom activation (e.g., torch.tanh).
            reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction applied to the output.
        """
        super().__init__()
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.reduction = reduction.lower()

        # Validate activation settings: at most one activation may be selected
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Only one of [sigmoid=True, softmax=True, other_act] can be set.")
        if self.reduction not in ["mean", "sum", "none"]:
            raise ValueError(f"Unsupported reduction mode: {self.reduction}")

        # Create an instance of the original SoftDiceclDiceLoss
        self.base_loss = SoftDiceclDiceLoss(iter_=iter_, alpha=alpha, smooth=smooth)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Args:
            y_pred: Predicted tensor with shape [B, C, H, W, ...].
            y_true: Ground truth tensor with shape [B, C, H, W, ...]
                (or [B, 1, H, W, ...] when to_onehot_y=True).

        Returns:
            Computed loss value.
        """
        n_pred_ch = y_pred.shape[1]

        # Convert ground truth to one-hot if necessary
        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("Single channel prediction, to_onehot_y=True ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        # Exclude background channel if specified
        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("Single channel prediction, include_background=False ignored.")
            else:
                y_pred = y_pred[:, 1:]
                y_true = y_true[:, 1:]

        # Apply activation if specified
        if self.sigmoid:
            y_pred = torch.sigmoid(y_pred)
            # Differentiable approximation of thresholding at 0.5
            y_pred = torch.sigmoid((y_pred - 0.5) * 10)
        elif self.softmax:
            if y_pred.shape[1] == 1:
                warnings.warn("Single channel prediction, softmax=True ignored and sigmoid applied")
                y_pred = torch.sigmoid(y_pred)
            else:
                y_pred = torch.softmax(y_pred, dim=1)
        elif self.other_act is not None:
            y_pred = self.other_act(y_pred)

        # Ensure shapes match
        if y_pred.shape != y_true.shape:
            raise AssertionError(f"Shape mismatch: y_pred {y_pred.shape}, y_true {y_true.shape}")

        # Delegate loss computation to the original SoftDiceclDiceLoss
        # (the ground truth is passed first here)
        loss = self.base_loss(y_true, y_pred)

        # Apply reduction if necessary
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss
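
To make the sigmoid branch of forward concrete, here is a small standard-library sketch of what the second sigmoid does to a probability. The helper names sigmoid and sharpen are hypothetical and used only for illustration; the constants 0.5 and 10 are the ones from the class above.

```python
import math


def sigmoid(x: float) -> float:
    """Plain logistic function."""
    return 1.0 / (1.0 + math.exp(-x))


def sharpen(p: float, steepness: float = 10.0) -> float:
    """Soft threshold at 0.5, same form as sigmoid((p - 0.5) * 10) in forward."""
    return sigmoid((p - 0.5) * steepness)


# Walk a few logits through both steps: logit -> probability -> sharpened value.
for logit in (-3.0, 0.0, 3.0):
    p = sigmoid(logit)
    print(f"logit={logit:+.1f}  prob={p:.3f}  sharpened={sharpen(p):.3f}")
```

Because the probability stays within [0, 1], the sharpened value is confined to the range between sigmoid(-5) and sigmoid(5), so it gets close to a binary mask without ever reaching exactly 0 or 1.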

Expected behavior
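The loss should be non-zero when the prediction differs from the ground truth. Instead, the computed loss is zero, even when using the custom wrapper class above with its enhancements. I would appreciate input on how to avoid a zero loss.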
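
One possible explanation, offered as an assumption rather than a confirmed diagnosis: if the underlying loss sums only over channels 1 onward (treating channel 0 as background), then a single-channel input, or one whose background channel was already sliced off, leaves nothing to sum. The smoothing term alone then makes the Dice coefficient 1 and the loss 0. The pure-Python sketch below illustrates that arithmetic on nested lists; soft_dice_loss_skipping_channel0 is a hypothetical helper, not the MONAI function, and whether the installed version behaves this way should be checked against its source.

```python
def soft_dice_loss_skipping_channel0(y_true, y_pred, smooth=1.0):
    """Smoothed Dice loss over channels 1.. of a [C][N] nested list.

    Hypothetical helper: it encodes the assumption that channel 0 is
    treated as background and left out of every sum.
    """
    fg_true = y_true[1:]  # empty when there is only one channel
    fg_pred = y_pred[1:]
    intersection = sum(
        t * p for ct, cp in zip(fg_true, fg_pred) for t, p in zip(ct, cp)
    )
    total = sum(sum(c) for c in fg_true) + sum(sum(c) for c in fg_pred)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)


# Single channel, prediction completely wrong: nothing is left to sum.
single_true = [[1.0, 1.0, 0.0, 0.0]]
single_pred = [[0.0, 0.0, 1.0, 1.0]]
print(soft_dice_loss_skipping_channel0(single_true, single_pred))  # 0.0

# Same data as two channels (background, foreground): the error shows up.
two_true = [[0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0]]
two_pred = [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
print(soft_dice_loss_skipping_channel0(two_true, two_pred))
```

If this assumption holds, passing two-channel (background plus foreground) tensors and leaving include_background=True in the wrapper would be the first thing to try.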

Environment

Output of python -c "import monai; monai.config.print_debug_info()":

/root/.cache/pypoetry/virtualenvs/segmentation-codebase-os60uNmW-py3.10/lib/python3.10/site-packages/_distutils_hack/__init__.py:55: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml
  warnings.warn(
/root/.cache/pypoetry/virtualenvs/segmentation-codebase-os60uNmW-py3.10/lib/python3.10/site-packages/ignite/handlers/checkpoint.py:17: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
  from torch.distributed.optim import ZeroRedundancyOptimizer
================================
Printing MONAI config...
================================
MONAI version: 1.3.0
Numpy version: 1.26.4
Pytorch version: 2.4.0+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /<username>/.cache/pypoetry/virtualenvs/segmentation-codebase-os60uNmW-py3.10/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.0
Nibabel version: 5.2.1
scikit-image version: 0.24.0
scipy version: 1.14.0
Pillow version: 10.4.0
Tensorboard version: 2.17.0
gdown version: 5.2.0
TorchVision version: 0.19.0+cu121
tqdm version: 4.66.5
lmdb version: 1.5.1
psutil version: 6.0.0
pandas version: 2.2.2
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.15.1
pynrrd version: 1.0.0
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Linux
Linux version: Ubuntu 22.04.2 LTS
Platform: Linux-5.4.0-169-generic-x86_64-with-glibc2.35
Processor: x86_64
Machine: x86_64
Python version: 3.10.12
Process name: pt_main_thread
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 128
Num logical CPUs: 256
Num usable CPUs: 256
CPU usage (%): [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 2.8, 1.7, 0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.8, 2.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 3.3, 0.0, 0.0, 3.3, 3.1, 3.3, 2.8, 0.0, 0.0, 0.0, 3.1, 3.1, 1.1, 3.3, 0.0, 2.8, 3.1, 3.1, 0.0, 3.1, 0.0, 0.0, 0.0, 0.6, 0.0, 0.3, 0.0, 5.0, 0.0, 0.0, 3.3, 0.0, 0.0, 2.8, 0.0, 3.1, 2.8, 2.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 3.3, 2.8, 3.3, 3.1, 3.9, 3.3, 3.9, 3.4, 3.3, 3.6, 3.6, 0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 0.0, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.1, 0.0, 0.3, 0.0, 0.0, 0.0, 3.1, 0.0, 3.3, 3.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.1, 0.0, 3.1, 0.0, 0.0, 0.0, 3.3, 0.0, 0.0, 3.3, 3.3, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0, 3.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.9, 0.0, 0.0, 3.3, 3.6, 3.3, 3.3, 3.1, 3.3, 3.6, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 0.0, 1.9, 99.2]
CPU freq. (MHz): 2924
Load avg. in last 1, 5, 15 mins (%): [0.1, 0.1, 0.2]
Disk usage (%): 95.3
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 1007.7
Available memory (GB): 965.3
Used memory (GB): 34.6

================================
Printing GPU config...
================================
Num GPUs: 8
Has CUDA: True
CUDA version: 12.1
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: 1
cuDNN version: 90100
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA A100-SXM4-40GB
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 108
GPU 0 Total memory (GB): 39.4
GPU 0 CUDA capability (maj.min): 8.0
GPU 1 Name: NVIDIA A100-SXM4-40GB
GPU 1 Is integrated: False
GPU 1 Is multi GPU board: False
GPU 1 Multi processor count: 108
GPU 1 Total memory (GB): 39.4
GPU 1 CUDA capability (maj.min): 8.0
GPU 2 Name: NVIDIA A100-SXM4-40GB
GPU 2 Is integrated: False
GPU 2 Is multi GPU board: False
GPU 2 Multi processor count: 108
GPU 2 Total memory (GB): 39.4
GPU 2 CUDA capability (maj.min): 8.0
GPU 3 Name: NVIDIA A100-SXM4-40GB
GPU 3 Is integrated: False
GPU 3 Is multi GPU board: False
GPU 3 Multi processor count: 108
GPU 3 Total memory (GB): 39.4
GPU 3 CUDA capability (maj.min): 8.0
GPU 4 Name: NVIDIA A100-SXM4-40GB
GPU 4 Is integrated: False
GPU 4 Is multi GPU board: False
GPU 4 Multi processor count: 108
GPU 4 Total memory (GB): 39.4
GPU 4 CUDA capability (maj.min): 8.0
GPU 5 Name: NVIDIA A100-SXM4-40GB
GPU 5 Is integrated: False
GPU 5 Is multi GPU board: False
GPU 5 Multi processor count: 108
GPU 5 Total memory (GB): 39.4
GPU 5 CUDA capability (maj.min): 8.0
GPU 6 Name: NVIDIA A100-SXM4-40GB
GPU 6 Is integrated: False
GPU 6 Is multi GPU board: False
GPU 6 Multi processor count: 108
GPU 6 Total memory (GB): 39.4
GPU 6 CUDA capability (maj.min): 8.0
GPU 7 Name: NVIDIA A100-SXM4-40GB
GPU 7 Is integrated: False
GPU 7 Is multi GPU board: False
GPU 7 Multi processor count: 108
GPU 7 Total memory (GB): 39.4
GPU 7 CUDA capability (maj.min): 8.0

Additional context
I am trying to use this loss for airway segmentation.
