Skip to content

Commit 0edbe87

Browse files
author
Ben Murray
committed
Initial commit to make codebase robust to use of non-standard meta_dict names
Signed-off-by: Ben Murray <[email protected]>
1 parent bfcb318 commit 0edbe87

File tree

8 files changed

+45
-7
lines changed

8 files changed

+45
-7
lines changed

monai/apps/deepgrow/transforms.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020

2121
from monai.config import IndexSelection, KeysCollection, NdarrayOrTensor
22+
from monai.data.meta_obj import get_meta_dict_name
2223
from monai.networks.layers import GaussianFilter
2324
from monai.transforms import Resize, SpatialCrop
2425
from monai.transforms.transform import MapTransform, Randomizable, Transform
@@ -546,7 +547,12 @@ def _apply(self, pos_clicks, neg_clicks, factor, slice_num):
546547

547548
def __call__(self, data):
548549
d = dict(data)
549-
meta_dict_key = self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"
550+
meta_dict_key = self.meta_keys
551+
if not meta_dict_key:
552+
candidate_meta_key = f"{self.ref_image}_{self.meta_key_postfix}"
553+
meta_dict = d.get(candidate_meta_key, None)
554+
if meta_dict is None:
555+
meta_dict_key = get_meta_dict_name(self.ref_image, d)
550556
if meta_dict_key not in d:
551557
raise RuntimeError(f"Missing meta_dict {meta_dict_key} in data!")
552558
if "spatial_shape" not in d[meta_dict_key]:
@@ -742,7 +748,10 @@ def __init__(
742748
def __call__(self, data: Any) -> dict:
743749
d = dict(data)
744750
guidance = d[self.guidance]
745-
meta_dict: dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"]
751+
# meta_dict: dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"]
752+
meta_dict: dict = d.get(self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}", None)
753+
if meta_dict is None:
754+
meta_dict = d[get_meta_dict_name(self.ref_image, d)]
746755
current_shape = d[self.ref_image].shape[1:]
747756
cropped_shape = meta_dict[self.cropped_shape_key][1:]
748757
factor = np.divide(current_shape, cropped_shape)
@@ -852,7 +861,8 @@ def __init__(
852861

853862
def __call__(self, data: Any) -> dict:
854863
d = dict(data)
855-
meta_dict: dict = d[f"{self.ref_image}_{self.meta_key_postfix}"]
864+
meta_dict: dict = d.get(f"{self.ref_image}_{self.meta_key_postfix}",
865+
d[get_meta_dict_name(self.ref_image, d)])
856866

857867
for key, mode, align_corners, meta_key in self.key_iterator(d, self.mode, self.align_corners, self.meta_keys):
858868
image = d[key]
@@ -969,5 +979,8 @@ def __call__(self, data):
969979
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
970980
img_slice, idx = self._apply(d[key], guidance)
971981
d[key] = img_slice
972-
d[meta_key or f"{key}_{meta_key_postfix}"]["slice_idx"] = idx
982+
# d[meta_key or f"{key}_{meta_key_postfix}"]["slice_idx"] = idx
983+
if meta_key not in d:
984+
meta_key = get_meta_dict_name(key, d)
985+
d[meta_key] = idx
973986
return d

monai/apps/detection/transforms/dictionary.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from monai.config import KeysCollection, SequenceStr
4242
from monai.config.type_definitions import DtypeLike, NdarrayOrTensor
4343
from monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image
44+
from monai.data.meta_obj import get_meta_dict_name
4445
from monai.data.meta_tensor import MetaTensor, get_track_meta
4546
from monai.data.utils import orientation_ras_lps
4647
from monai.transforms import Flip, RandFlip, RandZoom, Rotate90, SpatialCrop, Zoom
@@ -308,7 +309,9 @@ def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> tuple[Ndarray
308309
elif meta_key in d:
309310
meta_dict = d[meta_key]
310311
else:
311-
raise ValueError(f"{meta_key} is not found. Please check whether it is the correct the image meta key.")
312+
meta_key = get_meta_dict_name(self.box_ref_image_keys, d)
313+
if meta_key not in d:
314+
raise ValueError(f"{self.image_meta_key} is not found. Please check whether it is the correct the image meta key.")
312315
if "affine" not in meta_dict:
313316
raise ValueError(
314317
f"'affine' is not found in {meta_key}. \

monai/data/meta_obj.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,11 @@ def is_batch(self) -> bool:
242242
def is_batch(self, val: bool) -> None:
243243
"""Set whether object is part of batch or not."""
244244
self._is_batch = val
245+
246+
247+
def get_meta_dict_name(key, dictionary):
248+
for kv, kd in dictionary.items():
249+
if isinstance(kd, dict):
250+
if kd.get("tensor_name", None) == key:
251+
return kv
252+
return None

monai/transforms/intensity/dictionary.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from monai.config import DtypeLike, KeysCollection
2626
from monai.config.type_definitions import NdarrayOrTensor
27-
from monai.data.meta_obj import get_track_meta
27+
from monai.data.meta_obj import get_meta_dict_name, get_track_meta
2828
from monai.transforms.intensity.array import (
2929
AdjustContrast,
3030
ClipIntensityPercentiles,
@@ -358,6 +358,8 @@ def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]:
358358
d, self.factor_key, self.meta_keys, self.meta_key_postfix
359359
):
360360
meta_key = meta_key or f"{key}_{meta_key_postfix}"
361+
if meta_key not in d:
362+
meta_key = get_meta_dict_name(key, d)
361363
factor: float | None = d[meta_key].get(factor_key) if meta_key in d else None
362364
offset = None if factor is None else self.shifter.offset * factor
363365
d[key] = self.shifter(d[key], offset=offset)

monai/transforms/io/dictionary.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def __call__(self, data, reader: ImageReader | None = None):
169169
f"loader must return a tuple or list (because image_only=False was used), got {type(data)}."
170170
)
171171
d[key] = data[0]
172+
data[1]['tensor_name'] = key
172173
if not isinstance(data[1], dict):
173174
raise ValueError(f"metadata must be a dict, got {type(data[1])}.")
174175
meta_key = meta_key or f"{key}_{meta_key_postfix}"

monai/transforms/meta_utility/dictionary.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import torch
2424

2525
from monai.config.type_definitions import KeysCollection, NdarrayOrTensor
26+
from monai.data.meta_obj import get_meta_dict_name
2627
from monai.data.meta_tensor import MetaTensor
2728
from monai.transforms.inverse import InvertibleTransform
2829
from monai.transforms.transform import MapTransform
@@ -77,7 +78,10 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Nd
7778
_ = self.get_most_recent_transform(d, key)
7879
# do the inverse
7980
im = d[key]
80-
meta = d.pop(PostFix.meta(key), None)
81+
if PostFix.meta(key) in d:
82+
meta = d.pop(PostFix.meta(key), None)
83+
else:
84+
meta = d.pop(get_meta_dict_name(key, d))
8185
transforms = d.pop(PostFix.transforms(key), None)
8286
im = MetaTensor(im, meta=meta, applied_operations=transforms) # type: ignore
8387
d[key] = im
@@ -101,6 +105,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
101105
for key in self.key_iterator(d):
102106
self.push_transform(d, key)
103107
im = d[key]
108+
104109
meta = d.pop(PostFix.meta(key), None)
105110
transforms = d.pop(PostFix.transforms(key), None)
106111
im = MetaTensor(im, meta=meta, applied_operations=transforms) # type: ignore

monai/transforms/post/dictionary.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from monai import config
2929
from monai.config.type_definitions import KeysCollection, NdarrayOrTensor, PathLike
3030
from monai.data.csv_saver import CSVSaver
31+
from monai.data.meta_obj import get_meta_dict_name
3132
from monai.data.meta_tensor import MetaTensor
3233
from monai.transforms.inverse import InvertibleTransform
3334
from monai.transforms.post.array import (
@@ -797,6 +798,8 @@ def __call__(self, data):
797798
if meta_key is None and meta_key_postfix is not None:
798799
meta_key = f"{key}_{meta_key_postfix}"
799800
meta_data = d[meta_key] if meta_key is not None else None
801+
if meta_data is None:
802+
meta_data = d.get(get_meta_dict_name(key, d), None)
800803
self.saver.save(data=d[key], meta_data=meta_data)
801804
if self.flush:
802805
self.saver.finalize()

monai/transforms/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import monai
2828
from monai.config import DtypeLike, IndexSelection
2929
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
30+
from monai.data.meta_obj import get_meta_dict_name
3031
from monai.data.utils import to_affine_nd
3132
from monai.networks.layers import GaussianFilter
3233
from monai.networks.utils import meshgrid_ij
@@ -2144,6 +2145,8 @@ def sync_meta_info(key, data_dict, t: bool = True):
21442145

21452146
# update meta dicts
21462147
meta_dict_key = PostFix.meta(key)
2148+
if meta_dict_key not in d:
2149+
meta_dict_key = get_meta_dict_name(key, d)
21472150
if meta_dict_key not in d:
21482151
d[meta_dict_key] = monai.data.MetaTensor.get_default_meta()
21492152
if not isinstance(d[key], monai.data.MetaTensor):

0 commit comments

Comments
 (0)