Skip to content

Commit adcbfdb

Browse files
林旻佑林旻佑
authored andcommitted
WIP: save local changes before rebase
Signed-off-by: 林旻佑 <[email protected]>
1 parent 12a34b7 commit adcbfdb

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

monai/inferers/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@
3636
tqdm, _ = optional_import("tqdm", name="tqdm")
3737
_nearest_mode = "nearest-exact"
3838

39-
__all__ = ["sliding_window_inference"]
39+
__all__ = ["ensure_channel_first","sliding_window_inference"]
4040

4141
def ensure_channel_first(
4242
x: torch.Tensor,
4343
spatial_ndim: Optional[int] = None,
4444
channel_hint: Optional[int] = None,
4545
threshold: int = 32,
46-
) -> Tuple[torch.Tensor, int]:
46+
) -> tuple[torch.Tensor, int]:
4747
"""
4848
Normalize a tensor to channel-first layout (N, C, spatial...).
4949

monai/metrics/meandice.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -309,18 +309,22 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
309309
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
310310
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
311311
"""
312+
313+
312314
# --- Normalize layout to channel-first (N, C, spatial...) ---
313-
n_ch = self.num_classes or y_pred.shape[1]
315+
# Prefer a strong signal when available.
316+
if self.num_classes is not None:
317+
y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes)
318+
else:
319+
y_pred, _ = ensure_channel_first(y_pred)
314320

315-
# Always normalize y_pred with hint
316-
y_pred, _ = ensure_channel_first(y_pred, channel_hint=n_ch)
321+
# Infer channels after normalization (or use provided).
322+
n_ch = self.num_classes or y_pred.shape[1]
317323

318-
# Normalize y if it looks like channel-last (last dim = 1 or n_ch)
324+
# Normalize y if it plausibly is channel-last.
319325
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
320326
y, _ = ensure_channel_first(y, channel_hint=n_ch)
321327

322-
323-
324328
_apply_argmax, _threshold = self.apply_argmax, self.threshold
325329
if self.num_classes is None:
326330
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores

0 commit comments

Comments
 (0)