@@ -309,18 +309,22 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
309309 the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
310310 y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
311311 """
312+
313+
312314 # --- Normalize layout to channel-first (N, C, spatial...) ---
313- n_ch = self .num_classes or y_pred .shape [1 ]
315+ # Prefer a strong signal when available.
316+ if self .num_classes is not None :
317+ y_pred , _ = ensure_channel_first (y_pred , channel_hint = self .num_classes )
318+ else :
319+ y_pred , _ = ensure_channel_first (y_pred )
314320
315- # Always normalize y_pred with hint
316- y_pred , _ = ensure_channel_first ( y_pred , channel_hint = n_ch )
321+ # Infer channels after normalization (or use provided).
322+ n_ch = self . num_classes or y_pred . shape [ 1 ]
317323
318- # Normalize y if it looks like channel-last (last dim = 1 or n_ch)
324+ # Normalize y if it plausibly is channel-last.
319325 if y .ndim == y_pred .ndim and y .shape [- 1 ] in (1 , n_ch ):
320326 y , _ = ensure_channel_first (y , channel_hint = n_ch )
321327
322-
323-
324328 _apply_argmax , _threshold = self .apply_argmax , self .threshold
325329 if self .num_classes is None :
326330 n_pred_ch = y_pred .shape [1 ] # y_pred is in one-hot format or multi-channel scores
0 commit comments