Skip to content

Commit df7711d

Browse files
committed
ENH: support additional dtypes in pad_nd
Prefer the PyTorch padding backend when supported and safely fall back to NumPy on error. Add unit tests to validate backend selection and ensure output dtype is preserved. Signed-off-by: Shubham Chandravanshi <[email protected]>
1 parent 15fd428 commit df7711d

File tree

2 files changed

+60
-6
lines changed

2 files changed

+60
-6
lines changed

monai/transforms/croppad/functional.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,8 @@ def pad_nd(
9696
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
9797
try:
9898
_pad = _np_pad
99-
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in {
100-
torch.int16,
101-
torch.int64,
102-
torch.bool,
103-
torch.uint8,
104-
}:
99+
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"}:
100+
# Try PyTorch pad for these modes; fallback to NumPy on error.
105101
_pad = _pt_pad
106102
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
107103
except (ValueError, TypeError, RuntimeError) as err:
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
13+
from __future__ import annotations
14+
15+
from unittest.mock import Mock, patch
16+
17+
import pytest
18+
import torch
19+
20+
import monai.transforms.croppad.functional as F
21+
from monai.transforms.croppad.functional import pad_nd
22+
23+
24+
def test_pad_uses_pt_for_bool():
25+
img = torch.ones((1, 4, 4), dtype=torch.bool)
26+
to_pad = [(0, 0), (1, 1), (2, 2)]
27+
with patch.object(F, "_pt_pad", wraps=F._pt_pad) as mock_pt, patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np:
28+
out = pad_nd(img, to_pad, mode="constant", value=0)
29+
30+
assert mock_pt.called
31+
assert not mock_np.called
32+
assert out.dtype == img.dtype
33+
34+
35+
def test_pad_falls_back_to_np_if_pt_raises():
36+
img = torch.ones((1, 4, 4), dtype=torch.bool)
37+
to_pad = [(0, 0), (1, 1), (2, 2)]
38+
with (
39+
patch.object(F, "_pt_pad", new=Mock(side_effect=NotImplementedError("no"))) as mock_pt,
40+
patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np,
41+
):
42+
out = pad_nd(img, to_pad, mode="constant", value=0)
43+
44+
assert mock_pt.called
45+
assert mock_np.called
46+
assert out.dtype == img.dtype
47+
48+
49+
@pytest.mark.parametrize(
50+
"dtype", [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.float32]
51+
)
52+
def test_pad_dtype_no_error_and_dtype_preserved(dtype):
53+
img = torch.ones((1, 4, 4), dtype=dtype)
54+
to_pad = [(0, 0), (1, 1), (2, 2)]
55+
out = pad_nd(img, to_pad, mode="constant", value=0)
56+
57+
assert out.shape == (1, 6, 8)
58+
assert out.dtype == img.dtype

0 commit comments

Comments
 (0)