@@ -51,6 +51,7 @@ def make_gaussian_kernel(kernel_size: int) -> torch.Tensor:
5151class LocalNormalizedCrossCorrelationLoss (_Loss ):
5252 """
5353 Local squared zero-normalized cross-correlation.
54+
5455 The loss is based on a moving kernel/window over the y_true/y_pred,
5556 within the window the square of zncc is calculated.
5657 The kernel can be a rectangular / triangular / gaussian window.
@@ -59,6 +60,35 @@ class LocalNormalizedCrossCorrelationLoss(_Loss):
5960 Adapted from:
6061 https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
6162 DeepReg (https://github.com/DeepRegNet/DeepReg)
63+
64+ Args:
65+ spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.
66+ kernel_size: kernel spatial size, must be odd.
67+ kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``.
68+ reduction: {``"none"``, ``"mean"``, ``"sum"``}
69+ Specifies the reduction to apply to the output. Defaults to ``"mean"``.
70+
71+ - ``"none"``: no reduction will be applied.
72+ - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
73+ - ``"sum"``: the output will be summed.
74+ smooth_nr: a small constant added to the numerator to avoid nan.
75+ smooth_dr: a small constant added to the denominator to avoid nan.
76+
77+ Returns:
78+ torch.Tensor: The computed loss value. The output range is approximately [-1, 0], where:
79+ - Values closer to -1 indicate higher correlation (better match)
80+ - Values closer to 0 indicate lower correlation (worse match)
81+ - This loss should be **minimized** during optimization
82+
83+ Note:
84+ The implementation computes the squared normalized cross-correlation coefficient
85+ and then negates it, transforming the correlation maximization problem into a
86+ loss minimization problem suitable for standard PyTorch optimizers.
87+
88+ Interpretation:
89+ - Loss ≈ -1: Perfect correlation between images
90+ - Loss ≈ 0: No correlation between images
91+ - Lower (more negative) values indicate better alignment
6292 """
6393
6494 def __init__ (
@@ -70,21 +100,6 @@ def __init__(
70100 smooth_nr : float = 0.0 ,
71101 smooth_dr : float = 1e-5 ,
72102 ) -> None :
73- """
74- Args:
75- spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.
76- kernel_size: kernel spatial size, must be odd.
77- kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``.
78- reduction: {``"none"``, ``"mean"``, ``"sum"``}
79- Specifies the reduction to apply to the output. Defaults to ``"mean"``.
80-
81- - ``"none"``: no reduction will be applied.
82- - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
83- - ``"sum"``: the output will be summed.
84- smooth_nr: a small constant added to the numerator to avoid nan.
85- smooth_dr: a small constant added to the denominator to avoid nan.
86-
87- """
88103 super ().__init__ (reduction = LossReduction (reduction ).value )
89104
90105 self .ndim = spatial_dims
0 commit comments