Skip to content

Commit 1283125

Browse files
Prevent implicit conversion of metatensor to numpy array
Signed-off-by: Davis Vigneault <[email protected]>
1 parent 15fd428 commit 1283125

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

monai/data/image_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def convert_to_channel_last(
324324
data = data[..., 0, :]
325325
# if desired, remove trailing singleton dimensions
326326
while squeeze_end_dims and data.shape[-1] == 1:
327-
data = np.squeeze(data, -1)
327+
data = data.squeeze(-1)
328328
if contiguous:
329329
data = ascontiguousarray(data)
330330
return data

tests/data/test_itk_writer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919
import torch
2020

21-
from monai.data import ITKWriter
21+
from monai.data import ITKWriter, MetaTensor
2222
from monai.utils import optional_import
2323

2424
itk, has_itk = optional_import("itk")
@@ -64,6 +64,13 @@ def test_no_channel(self):
6464
np.testing.assert_allclose(output.shape, (4, 4, 3))
6565
np.testing.assert_allclose(output[1, 1], (5, 21, 37))
6666

67+
def test_metatensor_preserved(self):
68+
data = MetaTensor(np.arange(48).reshape(3, 4, 4, 1), meta={"test_key": "test_value"})
69+
writer = ITKWriter()
70+
writer.set_data_array(data, channel_dim=-1, squeeze_end_dims=True)
71+
self.assertIsInstance(writer.data_obj, MetaTensor)
72+
self.assertEqual(writer.data_obj.meta.get("test_key"), "test_value")
73+
6774

6875
if __name__ == "__main__":
6976
unittest.main()

0 commit comments

Comments
 (0)