Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions monai/transforms/croppad/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,8 @@ def pad_nd(
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
try:
_pad = _np_pad
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in {
torch.int16,
torch.int64,
torch.bool,
torch.uint8,
}:
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"}:
# Try PyTorch pad for these modes; fallback to NumPy on error.
_pad = _pt_pad
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
except (ValueError, TypeError, RuntimeError) as err:
Expand Down
58 changes: 58 additions & 0 deletions tests/transforms/croppad/test_pad_nd_dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

from unittest.mock import Mock, patch

import pytest
import torch

import monai.transforms.croppad.functional as F
from monai.transforms.croppad.functional import pad_nd


def test_pad_uses_pt_for_bool():
    """pad_nd should dispatch to the torch pad path (not NumPy) for bool tensors."""
    image = torch.ones((1, 4, 4), dtype=torch.bool)
    padding = [(0, 0), (1, 1), (2, 2)]
    with (
        patch.object(F, "_pt_pad", wraps=F._pt_pad) as pt_spy,
        patch.object(F, "_np_pad", wraps=F._np_pad) as np_spy,
    ):
        result = pad_nd(image, padding, mode="constant", value=0)

    # torch path must be taken; NumPy path must never be reached
    assert pt_spy.called
    assert not np_spy.called
    assert result.dtype == image.dtype


def test_pad_falls_back_to_np_if_pt_raises():
    """If the torch pad path raises, pad_nd must fall back to NumPy padding."""
    image = torch.ones((1, 4, 4), dtype=torch.bool)
    padding = [(0, 0), (1, 1), (2, 2)]
    # NotImplementedError subclasses RuntimeError, which pad_nd catches for fallback
    failing_pt = Mock(side_effect=NotImplementedError("no"))
    with (
        patch.object(F, "_pt_pad", new=failing_pt) as pt_mock,
        patch.object(F, "_np_pad", wraps=F._np_pad) as np_mock,
    ):
        result = pad_nd(image, padding, mode="constant", value=0)

    # torch path was attempted, then the NumPy path completed the pad
    assert pt_mock.called
    assert np_mock.called
    assert result.dtype == image.dtype


@pytest.mark.parametrize(
    "dtype", [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.float32]
)
def test_pad_dtype_no_error_and_dtype_preserved(dtype):
    """Padding must succeed for every common dtype and preserve the input dtype."""
    image = torch.ones((1, 4, 4), dtype=dtype)
    padding = [(0, 0), (1, 1), (2, 2)]

    padded = pad_nd(image, padding, mode="constant", value=0)

    # (1, 4, 4) padded by (1, 1) and (2, 2) on the spatial axes -> (1, 6, 8)
    assert padded.shape == (1, 6, 8)
    assert padded.dtype == dtype
Loading