1212from __future__ import annotations
1313
1414import unittest
15+ from typing import cast
1516
1617import nibabel as nib
1718import numpy as np
2122from monai .data .meta_obj import set_track_meta
2223from monai .data .meta_tensor import MetaTensor
2324from monai .transforms import Orientation , create_rotate , create_translate
25+ from monai .utils import SpaceKeys
2426from tests .lazy_transforms_utils import test_resampler_lazy
2527from tests .test_utils import TEST_DEVICES , assert_allclose
2628
3335 torch .eye (4 ),
3436 torch .arange (12 ).reshape ((2 , 1 , 2 , 3 )),
3537 "RAS" ,
38+ False ,
39+ * device ,
40+ ]
41+ )
42+ TESTS .append (
43+ [
44+ {"axcodes" : "LPS" },
45+ torch .arange (12 ).reshape ((2 , 1 , 2 , 3 )),
46+ torch .eye (4 ),
47+ torch .arange (12 ).reshape ((2 , 1 , 2 , 3 )),
48+ "LPS" ,
49+ True ,
3650 * device ,
3751 ]
3852 )
4357 torch .as_tensor (np .diag ([- 1 , - 1 , 1 , 1 ])),
4458 torch .tensor ([[[[3 , 4 , 5 ]], [[0 , 1 , 2 ]]], [[[9 , 10 , 11 ]], [[6 , 7 , 8 ]]]]),
4559 "ALS" ,
60+ False ,
61+ * device ,
62+ ]
63+ )
64+ TESTS .append (
65+ [
66+ {"axcodes" : "PRS" },
67+ torch .arange (12 ).reshape ((2 , 1 , 2 , 3 )),
68+ torch .as_tensor (np .diag ([- 1 , - 1 , 1 , 1 ])),
69+ torch .tensor ([[[[3 , 4 , 5 ]], [[0 , 1 , 2 ]]], [[[9 , 10 , 11 ]], [[6 , 7 , 8 ]]]]),
70+ "PRS" ,
71+ True ,
4672 * device ,
4773 ]
4874 )
5379 torch .as_tensor (np .diag ([- 1 , - 1 , 1 , 1 ])),
5480 torch .tensor ([[[[3 , 4 , 5 ], [0 , 1 , 2 ]]], [[[9 , 10 , 11 ], [6 , 7 , 8 ]]]]),
5581 "RAS" ,
82+ False ,
83+ * device ,
84+ ]
85+ )
86+ TESTS .append (
87+ [
88+ {"axcodes" : "LPS" },
89+ torch .arange (12 ).reshape ((2 , 1 , 2 , 3 )),
90+ torch .as_tensor (np .diag ([- 1 , - 1 , 1 , 1 ])),
91+ torch .tensor ([[[[3 , 4 , 5 ], [0 , 1 , 2 ]]], [[[9 , 10 , 11 ], [6 , 7 , 8 ]]]]),
92+ "LPS" ,
93+ True ,
5694 * device ,
5795 ]
5896 )
63101 torch .eye (3 ),
64102 torch .tensor ([[[0 ], [1 ], [2 ]], [[3 ], [4 ], [5 ]]]),
65103 "AL" ,
104+ False ,
105+ * device ,
106+ ]
107+ )
108+ TESTS .append (
109+ [
110+ {"axcodes" : "PR" },
111+ torch .arange (6 ).reshape ((2 , 1 , 3 )),
112+ torch .eye (3 ),
113+ torch .tensor ([[[0 ], [1 ], [2 ]], [[3 ], [4 ], [5 ]]]),
114+ "PR" ,
115+ True ,
66116 * device ,
67117 ]
68118 )
73123 torch .eye (2 ),
74124 torch .tensor ([[2 , 1 , 0 ], [5 , 4 , 3 ]]),
75125 "L" ,
126+ False ,
127+ * device ,
128+ ]
129+ )
130+ TESTS .append (
131+ [
132+ {"axcodes" : "R" },
133+ torch .arange (6 ).reshape ((2 , 3 )),
134+ torch .eye (2 ),
135+ torch .tensor ([[2 , 1 , 0 ], [5 , 4 , 3 ]]),
136+ "R" ,
137+ True ,
76138 * device ,
77139 ]
78140 )
83145 torch .eye (2 ),
84146 torch .tensor ([[2 , 1 , 0 ], [5 , 4 , 3 ]]),
85147 "L" ,
148+ False ,
86149 * device ,
87150 ]
88151 )
93156 torch .as_tensor (np .diag ([- 1 , 1 ])),
94157 torch .arange (6 ).reshape ((2 , 3 )),
95158 "L" ,
159+ False ,
96160 * device ,
97161 ]
98162 )
107171 ),
108172 torch .tensor ([[[[2 , 5 ]], [[1 , 4 ]], [[0 , 3 ]]], [[[8 , 11 ]], [[7 , 10 ]], [[6 , 9 ]]]]),
109173 "LPS" ,
174+ False ,
110175 * device ,
111176 ]
112177 )
121186 ),
122187 torch .tensor ([[[[0 , 3 ]], [[1 , 4 ]], [[2 , 5 ]]], [[[6 , 9 ]], [[7 , 10 ]], [[8 , 11 ]]]]),
123188 "RAS" ,
189+ False ,
124190 * device ,
125191 ]
126192 )
131197 torch .as_tensor (create_translate (2 , (10 , 20 )) @ create_rotate (2 , (np .pi / 3 )) @ np .diag ([- 1 , - 0.2 , 1 ])),
132198 torch .tensor ([[[3 , 0 ], [4 , 1 ], [5 , 2 ]]]),
133199 "RA" ,
200+ False ,
134201 * device ,
135202 ]
136203 )
141208 torch .as_tensor (create_translate (2 , (10 , 20 )) @ create_rotate (2 , (np .pi / 3 )) @ np .diag ([- 1 , - 0.2 , 1 ])),
142209 torch .tensor ([[[2 , 5 ], [1 , 4 ], [0 , 3 ]]]),
143210 "LP" ,
211+ False ,
144212 * device ,
145213 ]
146214 )
151219 torch .as_tensor (np .diag ([- 1 , - 0.2 , - 1 , 1 , 1 ])),
152220 torch .zeros ((1 , 2 , 3 , 4 , 5 )),
153221 "LPID" ,
222+ False ,
154223 * device ,
155224 ]
156225 )
161230 torch .as_tensor (np .diag ([- 1 , - 0.2 , - 1 , 1 , 1 ])),
162231 torch .zeros ((1 , 2 , 3 , 4 , 5 )),
163232 "RASD" ,
233+ False ,
164234 * device ,
165235 ]
166236 )
175245 [{"axcodes" : "RA" }, torch .arange (12 ).reshape ((2 , 1 , 2 , 3 )), torch .eye (4 )]
176246]
177247
248+ TESTS_INVERSE = []
249+ for device in TEST_DEVICES :
250+ TESTS_INVERSE .append ([True , * device ])
251+ TESTS_INVERSE .append ([False , * device ])
252+
178253
179254class TestOrientationCase (unittest .TestCase ):
180255 @parameterized .expand (TESTS )
@@ -185,17 +260,20 @@ def test_ornt_meta(
185260 affine : torch .Tensor ,
186261 expected_data : torch .Tensor ,
187262 expected_code : str ,
263+ lps_convention : bool ,
188264 device ,
189265 ):
190- img = MetaTensor (img , affine = affine ).to (device )
266+ meta = {"space" : SpaceKeys .LPS } if lps_convention else None
267+ img = MetaTensor (img , affine = affine , meta = meta ).to (device )
191268 ornt = Orientation (** init_param )
192269 call_param = {"data_array" : img }
193270 res = ornt (** call_param ) # type: ignore[arg-type]
194271 if img .ndim in (3 , 4 ):
195272 test_resampler_lazy (ornt , res , init_param , call_param )
196273
197274 assert_allclose (res , expected_data .to (device ))
198- new_code = nib .orientations .aff2axcodes (res .affine .cpu (), labels = ornt .labels ) # type: ignore
275+ labels = (("R" , "L" ), ("A" , "P" ), ("I" , "S" )) if lps_convention else ornt .labels
276+ new_code = nib .orientations .aff2axcodes (res .affine .cpu (), labels = labels ) # type: ignore
199277 self .assertEqual ("" .join (new_code ), expected_code )
200278
201279 @parameterized .expand (TESTS_TORCH )
@@ -224,23 +302,23 @@ def test_bad_params(self, init_param, img: torch.Tensor, affine: torch.Tensor):
224302 with self .assertRaises (ValueError ):
225303 Orientation (** init_param )(img )
226304
227- @parameterized .expand (TEST_DEVICES )
228- def test_inverse (self , device ):
305+ @parameterized .expand (TESTS_INVERSE )
306+ def test_inverse (self , lps_convention : bool , device ):
229307 img_t = torch .rand ((1 , 10 , 9 , 8 ), dtype = torch .float32 , device = device )
230308 affine = torch .tensor (
231309 [[0 , 0 , - 1 , 0 ], [1 , 0 , 0 , 0 ], [0 , 1 , 0 , 0 ], [0 , 0 , 0 , 1 ]], dtype = torch .float32 , device = "cpu"
232310 )
233- meta = {"fname" : "somewhere" }
311+ meta = {"fname" : "somewhere" , "space" : SpaceKeys . LPS if lps_convention else SpaceKeys . RAS }
234312 img = MetaTensor (img_t , affine = affine , meta = meta )
235313 tr = Orientation ("LPS" )
236314 # check that image and affine have changed
237- img = tr (img )
315+ img = cast ( MetaTensor , tr (img ) )
238316 self .assertNotEqual (img .shape , img_t .shape )
239- self .assertGreater (( affine - img .affine ).max (), 0.5 )
317+ self .assertGreater (float (( affine - img .affine ).max () ), 0.5 )
240318 # check that with inverse, image affine are back to how they were
241- img = tr .inverse (img )
319+ img = cast ( MetaTensor , tr .inverse (img ) )
242320 self .assertEqual (img .shape , img_t .shape )
243- self .assertLess (( affine - img .affine ).max (), 1e-2 )
321+ self .assertLess (float (( affine - img .affine ).max () ), 1e-2 )
244322
245323
246324if __name__ == "__main__" :
0 commit comments