Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions monai/transforms/croppad/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,20 @@ def pad_nd(
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
try:
_pad = _np_pad
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in {
torch.int16,
torch.int64,
torch.bool,
torch.uint8,
}:
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"}:
# Try PyTorch pad for these modes; fallback to NumPy on error.
_pad = _pt_pad
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
call_kwargs = dict(kwargs)
if mode != "constant":
call_kwargs.pop("value", None)
return _pad(img, pad_width=to_pad, mode=mode, **call_kwargs)
except NotImplementedError:
return _np_pad(img, pad_width=to_pad, mode=mode, **call_kwargs)
except (ValueError, TypeError, RuntimeError) as err:
if isinstance(err, NotImplementedError) or any(
if any(
k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")
):
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
return _np_pad(img, pad_width=to_pad, mode=mode, **call_kwargs)
Comment on lines +106 to +112
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the file structure and read the relevant section
cd /repo
wc -l monai/transforms/croppad/functional.py

Repository: Project-MONAI/MONAI

Length of output: 178


🏁 Script executed:

# Read the code around lines 106-112 with more context
head -130 monai/transforms/croppad/functional.py | tail -35

Repository: Project-MONAI/MONAI

Length of output: 1953


🏁 Script executed:

# Search for the function that contains this error handling
grep -n "def.*pad" monai/transforms/croppad/functional.py | head -20

Repository: Project-MONAI/MONAI

Length of output: 507


🌐 Web query:

PyTorch torch.nn.functional.pad exception types error handling documentation

💡 Result:

Short answer: PyTorch's torch.nn.functional.pad doesn't have a dedicated "exceptions" section in the docs — it follows normal Python/C++ argument-checking and backend checks, so you should expect primarily TypeError and ValueError from bad arguments (and occasional RuntimeError from lower-level/back-end kernels). See examples and sources below.

Common cases and what raises them

  • TypeError: wrong argument types (e.g., TorchScript expects a List[int] for pad, not a tuple; wrong type for value). [3]
  • ValueError: invalid pad shape/length (e.g., pad length not compatible with input dims or torchvision explicitly checks and raises for unsupported lengths). Also raised when a mode (reflect/replicate) requires constraints (padding must be <= input dim etc.). See docs on pad/modes. [1][2]
  • RuntimeError (or backend errors): lower-level kernel or device-specific errors (e.g., CUDA nondeterminism warnings/errors or kernel failures) may surface as RuntimeError from the op at runtime. [1]

References

  • torch.nn.functional.pad docs (describes pad format and mode constraints). [1]
  • PyTorch GitHub issue showing ValueError from torchvision and differences in pad-length checking. [2]
  • PyTorch forum thread showing a TorchScript TypeError when pad is not a List[int]. [3]

If you want, I can: (a) show example inputs that trigger each exception and how to catch them (try/except), or (b) search the exact source-code checks for the PyTorch version you use — tell me your PyTorch version.


Error-handling approach is reasonable but fragile.

Matching substrings like "supported", "unexpected keyword", "implemented", "value" in exception messages works but is brittle—error text can change across PyTorch versions. The approach tries to distinguish fallback cases (unsupported modes/dtypes) from genuine errors, but PyTorch doesn't provide dedicated exception types for unsupported operations.

Document why substring matching is necessary here, or investigate whether PyTorch's error hierarchy or error codes could provide more stable detection. If documentation is chosen, explain the fallback detection strategy clearly for future maintainers.

🤖 Prompt for AI Agents
In monai/transforms/croppad/functional.py around lines 106 to 112, the current
exception handling falls back to numpy based on fragile substring checks in the
PyTorch error message; either replace brittle string matching with a more stable
detection (e.g., detect specific PyTorch exception classes or inspect error
attributes/errno where available) or, if no stable API exists, add an inline
comment and unit-test-backed docstring that explains why substring matching is
necessary, lists the exact substrings being matched, and notes the PyTorch
versions observed—update the code to centralize the matching logic into a small
helper function with tests and a clear explanatory comment so future maintainers
can safely modify or replace it.

raise ValueError(
f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}"
) from err
Expand Down
69 changes: 69 additions & 0 deletions tests/transforms/croppad/test_pad_nd_dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for pad_nd dtype support and backend selection.
Validates PyTorch padding preference and NumPy fallback behavior.
"""
from __future__ import annotations
from unittest.mock import Mock, patch
import pytest
import torch
import monai.transforms.croppad.functional as F
from monai.transforms.croppad.functional import pad_nd

def test_pad_uses_pt_for_bool():
"""Test that pad_nd uses PyTorch backend for bool dtype in constant mode."""
img = torch.ones((1, 4, 4), dtype=torch.bool)
to_pad = [(0, 0), (1, 1), (2, 2)]
with patch.object(F, "_pt_pad", wraps=F._pt_pad) as mock_pt, patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np:
out = pad_nd(img, to_pad, mode="constant", value=0)

assert mock_pt.called
assert not mock_np.called
assert out.dtype == img.dtype

def test_pad_falls_back_to_np_if_pt_raises():
"""Test that pad_nd falls back to NumPy when PyTorch raises NotImplementedError."""
img = torch.ones((1, 4, 4), dtype=torch.bool)
to_pad = [(0, 0), (1, 1), (2, 2)]
with (
patch.object(F, "_pt_pad", new=Mock(side_effect=NotImplementedError("no"))) as mock_pt,
patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np,
):
out = pad_nd(img, to_pad, mode="constant", value=0)

assert mock_pt.called
assert mock_np.called
assert out.dtype == img.dtype

@pytest.mark.parametrize(
"dtype", [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.float32]
)
def test_pad_dtype_no_error_and_dtype_preserved(dtype):
"""Test that pad_nd handles various dtypes without error and preserves dtype."""
img = torch.ones((1, 4, 4), dtype=dtype)
to_pad = [(0, 0), (1, 1), (2, 2)]
out = pad_nd(img, to_pad, mode="constant", value=0)

assert out.shape == (1, 6, 8)
assert out.dtype == img.dtype

@pytest.mark.parametrize("mode", ["constant", "reflect", "replicate"])
@pytest.mark.parametrize("dtype", [torch.bool, torch.int8, torch.float32])
def test_pad_multiple_modes_dtype_preserved(mode, dtype):
"""Test that pad_nd preserves dtype across multiple padding modes."""
img = torch.ones((1, 4, 4), dtype=dtype)
to_pad = [(0, 0), (1, 1), (2, 2)]

out = pad_nd(img, to_pad, mode=mode, value=0)

assert out.shape == (1, 6, 8)
assert out.dtype == img.dtype
Loading