Commit 2afab14

Merge branch 'dev' into signal_type_fix

2 parents 67916d0 + 0968da2

File tree: 4 files changed, +206 -22 lines changed

monai/transforms/spatial/array.py: 44 additions & 5 deletions

@@ -64,6 +64,7 @@
     GridSamplePadMode,
     InterpolateMode,
     NumpyPadMode,
+    SpaceKeys,
     convert_to_cupy,
     convert_to_dst_type,
     convert_to_numpy,
@@ -75,6 +76,7 @@
     issequenceiterable,
     optional_import,
 )
+from monai.utils.deprecate_utils import deprecated_arg_default
 from monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends
 from monai.utils.misc import ImageMetaKey as Key
 from monai.utils.module import look_up_option
@@ -556,11 +558,20 @@ class Orientation(InvertibleTransform, LazyTransform):

     backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

+    @deprecated_arg_default(
+        name="labels",
+        old_default=(("L", "R"), ("P", "A"), ("I", "S")),
+        new_default=None,
+        msg_suffix=(
+            "Default value changed to None meaning that the transform now uses the 'space' of a "
+            "meta-tensor, if applicable, to determine appropriate axis labels."
+        ),
+    )
     def __init__(
         self,
         axcodes: str | None = None,
         as_closest_canonical: bool = False,
-        labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
+        labels: Sequence[tuple[str, str]] | None = None,
         lazy: bool = False,
     ) -> None:
         """
@@ -573,7 +584,14 @@ def __init__(
             as_closest_canonical: if True, load the image as closest to canonical axis format.
             labels: optional, None or sequence of (2,) sequences
                 (2,) sequences are labels for (beginning, end) of output axis.
-                Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
+                If ``None``, an appropriate value is chosen depending on the
+                value of the ``"space"`` metadata item of a metatensor: if
+                ``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
+                ('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
+                input is not a meta-tensor or has no ``"space"`` item, the
+                value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
+                ``None``, the provided value is always used and the ``"space"``
+                metadata item (if any) of the input is ignored.
             lazy: a flag to indicate whether this transform should execute lazily or not.
                 Defaults to False

@@ -619,9 +637,19 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:
             raise ValueError(f"data_array must have at least one spatial dimension, got {spatial_shape}.")
         affine_: np.ndarray
         affine_np: np.ndarray
+        labels = self.labels
         if isinstance(data_array, MetaTensor):
             affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray)
             affine_ = to_affine_nd(sr, affine_np)
+
+            # Set up "labels" such that LPS tensors are handled correctly by default
+            if (
+                self.labels is None
+                and "space" in data_array.meta
+                and SpaceKeys(data_array.meta["space"]) == SpaceKeys.LPS
+            ):
+                labels = (("R", "L"), ("A", "P"), ("I", "S"))  # value for LPS
+
         else:
             warnings.warn("`data_array` is not of type `MetaTensor, assuming affine to be identity.")
             # default to identity
@@ -640,7 +668,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:
                 f"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]},"
                 "please make sure the input is in the channel-first format."
             )
-        dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
+        dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=labels)
         if len(dst) < sr:
             raise ValueError(
                 f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D"
@@ -653,8 +681,19 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
         transform = self.pop_transform(data)
         # Create inverse transform
         orig_affine = transform[TraceKeys.EXTRA_INFO]["original_affine"]
-        orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
-        inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=self.labels)
+        labels = self.labels
+
+        # Set up "labels" such that LPS tensors are handled correctly by default
+        if (
+            isinstance(data, MetaTensor)
+            and self.labels is None
+            and "space" in data.meta
+            and SpaceKeys(data.meta["space"]) == SpaceKeys.LPS
+        ):
+            labels = (("R", "L"), ("A", "P"), ("I", "S"))  # value for LPS
+
+        orig_axcodes = nib.orientations.aff2axcodes(orig_affine, labels=labels)
+        inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=labels)
         # Apply inverse
         with inverse_transform.trace_transform(False):
             data = inverse_transform(data)

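Taken together, the array.py changes mean that a meta-tensor tagged with an ``LPS`` "space" is reoriented and its axcodes are reported using the RL/AP/IS label pairs whenever `labels` is left at its new default of `None`, while passing `labels` explicitly preserves the old behaviour. A minimal sketch of the intended usage (illustrative shapes and affine; assumes a MONAI build that includes this change):

import torch
from monai.data import MetaTensor
from monai.transforms import Orientation
from monai.utils import SpaceKeys

# A channel-first 3D volume whose affine and metadata follow the LPS convention.
img = MetaTensor(
    torch.rand(1, 8, 9, 10),
    affine=torch.diag(torch.tensor([-1.0, -1.0, 1.0, 1.0])),
    meta={"space": SpaceKeys.LPS},
)

# labels defaults to None, so the transform reads img.meta["space"] and
# internally uses (("R", "L"), ("A", "P"), ("I", "S")) for LPS data.
oriented = Orientation(axcodes="LPS")(img)

# Passing labels explicitly overrides the metadata-driven selection.
oriented_explicit = Orientation(axcodes="RAS", labels=(("L", "R"), ("P", "A"), ("I", "S")))(img)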
monai/transforms/spatial/dictionary.py: 19 additions & 2 deletions

@@ -71,6 +71,7 @@
     ensure_tuple_rep,
     fall_back_tuple,
 )
+from monai.utils.deprecate_utils import deprecated_arg_default
 from monai.utils.enums import TraceKeys
 from monai.utils.module import optional_import

@@ -545,12 +546,21 @@ class Orientationd(MapTransform, InvertibleTransform, LazyTransform):

     backend = Orientation.backend

+    @deprecated_arg_default(
+        name="labels",
+        old_default=(("L", "R"), ("P", "A"), ("I", "S")),
+        new_default=None,
+        msg_suffix=(
+            "Default value changed to None meaning that the transform now uses the 'space' of a "
+            "meta-tensor, if applicable, to determine appropriate axis labels."
+        ),
+    )
     def __init__(
         self,
         keys: KeysCollection,
         axcodes: str | None = None,
         as_closest_canonical: bool = False,
-        labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
+        labels: Sequence[tuple[str, str]] | None = None,
         allow_missing_keys: bool = False,
         lazy: bool = False,
     ) -> None:
@@ -564,7 +574,14 @@ def __init__(
             as_closest_canonical: if True, load the image as closest to canonical axis format.
             labels: optional, None or sequence of (2,) sequences
                 (2,) sequences are labels for (beginning, end) of output axis.
-                Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
+                If ``None``, an appropriate value is chosen depending on the
+                value of the ``"space"`` metadata item of a metatensor: if
+                ``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
+                ('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
+                input is not a meta-tensor or has no ``"space"`` item, the
+                value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
+                ``None``, the provided value is always used and the ``"space"``
+                metadata item (if any) of the input is ignored.
             allow_missing_keys: don't raise exception if key is missing.
             lazy: a flag to indicate whether this transform should execute lazily or not.
                 Defaults to False

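The dictionary-based wrapper inherits the same default, so dict pipelines pick up the space-aware labels without extra arguments. An illustrative sketch (the key name "image" is arbitrary; assumes a MONAI build that includes this change):

import torch
from monai.data import MetaTensor
from monai.transforms import Orientationd
from monai.utils import SpaceKeys

sample = {
    "image": MetaTensor(
        torch.rand(1, 8, 9, 10),
        affine=torch.diag(torch.tensor([-1.0, -1.0, 1.0, 1.0])),
        meta={"space": SpaceKeys.LPS},
    )
}

# With labels left at the new default (None), the wrapped Orientation transform
# consults each MetaTensor's "space" metadata when interpreting the axcodes.
oriented = Orientationd(keys="image", axcodes="LPS")(sample)
print(oriented["image"].affine)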
tests/transforms/test_orientation.py: 87 additions & 9 deletions

@@ -12,6 +12,7 @@
 from __future__ import annotations

 import unittest
+from typing import cast

 import nibabel as nib
 import numpy as np
@@ -21,6 +22,7 @@
 from monai.data.meta_obj import set_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.transforms import Orientation, create_rotate, create_translate
+from monai.utils import SpaceKeys
 from tests.lazy_transforms_utils import test_resampler_lazy
 from tests.test_utils import TEST_DEVICES, assert_allclose

@@ -33,6 +35,18 @@
             torch.eye(4),
             torch.arange(12).reshape((2, 1, 2, 3)),
             "RAS",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "LPS"},
+            torch.arange(12).reshape((2, 1, 2, 3)),
+            torch.eye(4),
+            torch.arange(12).reshape((2, 1, 2, 3)),
+            "LPS",
+            True,
             *device,
         ]
     )
@@ -43,6 +57,18 @@
             torch.as_tensor(np.diag([-1, -1, 1, 1])),
             torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
             "ALS",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "PRS"},
+            torch.arange(12).reshape((2, 1, 2, 3)),
+            torch.as_tensor(np.diag([-1, -1, 1, 1])),
+            torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
+            "PRS",
+            True,
             *device,
         ]
     )
@@ -53,6 +79,18 @@
             torch.as_tensor(np.diag([-1, -1, 1, 1])),
             torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
             "RAS",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "LPS"},
+            torch.arange(12).reshape((2, 1, 2, 3)),
+            torch.as_tensor(np.diag([-1, -1, 1, 1])),
+            torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
+            "LPS",
+            True,
             *device,
         ]
     )
@@ -63,6 +101,18 @@
             torch.eye(3),
             torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
             "AL",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "PR"},
+            torch.arange(6).reshape((2, 1, 3)),
+            torch.eye(3),
+            torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
+            "PR",
+            True,
             *device,
         ]
     )
@@ -73,6 +123,18 @@
             torch.eye(2),
             torch.tensor([[2, 1, 0], [5, 4, 3]]),
             "L",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "R"},
+            torch.arange(6).reshape((2, 3)),
+            torch.eye(2),
+            torch.tensor([[2, 1, 0], [5, 4, 3]]),
+            "R",
+            True,
             *device,
         ]
     )
@@ -83,6 +145,7 @@
             torch.eye(2),
             torch.tensor([[2, 1, 0], [5, 4, 3]]),
             "L",
+            False,
             *device,
         ]
     )
@@ -93,6 +156,7 @@
             torch.as_tensor(np.diag([-1, 1])),
             torch.arange(6).reshape((2, 3)),
             "L",
+            False,
             *device,
         ]
     )
@@ -107,6 +171,7 @@
             ),
             torch.tensor([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]),
             "LPS",
+            False,
             *device,
         ]
     )
@@ -121,6 +186,7 @@
             ),
             torch.tensor([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]),
             "RAS",
+            False,
             *device,
         ]
     )
@@ -131,6 +197,7 @@
             torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
             torch.tensor([[[3, 0], [4, 1], [5, 2]]]),
             "RA",
+            False,
             *device,
         ]
     )
@@ -141,6 +208,7 @@
             torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
             torch.tensor([[[2, 5], [1, 4], [0, 3]]]),
             "LP",
+            False,
             *device,
         ]
     )
@@ -151,6 +219,7 @@
             torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
             torch.zeros((1, 2, 3, 4, 5)),
             "LPID",
+            False,
             *device,
         ]
     )
@@ -161,6 +230,7 @@
             torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
             torch.zeros((1, 2, 3, 4, 5)),
             "RASD",
+            False,
             *device,
         ]
     )
@@ -175,6 +245,11 @@
     [{"axcodes": "RA"}, torch.arange(12).reshape((2, 1, 2, 3)), torch.eye(4)]
 ]

+TESTS_INVERSE = []
+for device in TEST_DEVICES:
+    TESTS_INVERSE.append([True, *device])
+    TESTS_INVERSE.append([False, *device])
+

 class TestOrientationCase(unittest.TestCase):
     @parameterized.expand(TESTS)
@@ -185,17 +260,20 @@ def test_ornt_meta(
         affine: torch.Tensor,
         expected_data: torch.Tensor,
         expected_code: str,
+        lps_convention: bool,
         device,
     ):
-        img = MetaTensor(img, affine=affine).to(device)
+        meta = {"space": SpaceKeys.LPS} if lps_convention else None
+        img = MetaTensor(img, affine=affine, meta=meta).to(device)
         ornt = Orientation(**init_param)
         call_param = {"data_array": img}
         res = ornt(**call_param)  # type: ignore[arg-type]
         if img.ndim in (3, 4):
             test_resampler_lazy(ornt, res, init_param, call_param)

         assert_allclose(res, expected_data.to(device))
-        new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels)  # type: ignore
+        labels = (("R", "L"), ("A", "P"), ("I", "S")) if lps_convention else ornt.labels
+        new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=labels)  # type: ignore
         self.assertEqual("".join(new_code), expected_code)

     @parameterized.expand(TESTS_TORCH)
@@ -224,23 +302,23 @@ def test_bad_params(self, init_param, img: torch.Tensor, affine: torch.Tensor):
         with self.assertRaises(ValueError):
             Orientation(**init_param)(img)

-    @parameterized.expand(TEST_DEVICES)
-    def test_inverse(self, device):
+    @parameterized.expand(TESTS_INVERSE)
+    def test_inverse(self, lps_convention: bool, device):
         img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
         affine = torch.tensor(
             [[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
         )
-        meta = {"fname": "somewhere"}
+        meta = {"fname": "somewhere", "space": SpaceKeys.LPS if lps_convention else SpaceKeys.RAS}
         img = MetaTensor(img_t, affine=affine, meta=meta)
         tr = Orientation("LPS")
         # check that image and affine have changed
-        img = tr(img)
+        img = cast(MetaTensor, tr(img))
         self.assertNotEqual(img.shape, img_t.shape)
-        self.assertGreater((affine - img.affine).max(), 0.5)
+        self.assertGreater(float((affine - img.affine).max()), 0.5)
         # check that with inverse, image affine are back to how they were
-        img = tr.inverse(img)
+        img = cast(MetaTensor, tr.inverse(img))
         self.assertEqual(img.shape, img_t.shape)
-        self.assertLess((affine - img.affine).max(), 1e-2)
+        self.assertLess(float((affine - img.affine).max()), 1e-2)


 if __name__ == "__main__":

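For reference, the label pairs used for LPS data can be checked directly against nibabel, independent of MONAI; this is essentially what the updated test_ornt_meta assertion does. A small sketch:

import numpy as np
import nibabel as nib

affine = np.eye(4)

# nibabel's default labels follow the RAS convention: an identity affine reads as ('R', 'A', 'S').
print(nib.orientations.aff2axcodes(affine))

# With the LPS label pairs used in the tests above, the same affine reads as ('L', 'P', 'S').
lps_labels = (("R", "L"), ("A", "P"), ("I", "S"))
print(nib.orientations.aff2axcodes(affine, labels=lps_labels))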