|
19 | 19 | import torch |
20 | 20 |
|
21 | 21 | from monai.config import IndexSelection, KeysCollection, NdarrayOrTensor |
| 22 | +from monai.data.meta_obj import get_meta_dict_name |
22 | 23 | from monai.networks.layers import GaussianFilter |
23 | 24 | from monai.transforms import Resize, SpatialCrop |
24 | 25 | from monai.transforms.transform import MapTransform, Randomizable, Transform |
@@ -546,7 +547,12 @@ def _apply(self, pos_clicks, neg_clicks, factor, slice_num): |
546 | 547 |
|
547 | 548 | def __call__(self, data): |
548 | 549 | d = dict(data) |
549 | | - meta_dict_key = self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}" |
| 550 | + meta_dict_key = self.meta_keys |
| 551 | + if not meta_dict_key: |
| 552 | + candidate_meta_key = f"{self.ref_image}_{self.meta_key_postfix}" |
| 553 | + meta_dict = d.get(candidate_meta_key, None) |
| 554 | + if meta_dict is None: |
| 555 | + meta_dict_key = get_meta_dict_name(self.ref_image, d) |
550 | 556 | if meta_dict_key not in d: |
551 | 557 | raise RuntimeError(f"Missing meta_dict {meta_dict_key} in data!") |
552 | 558 | if "spatial_shape" not in d[meta_dict_key]: |
@@ -742,7 +748,10 @@ def __init__( |
742 | 748 | def __call__(self, data: Any) -> dict: |
743 | 749 | d = dict(data) |
744 | 750 | guidance = d[self.guidance] |
745 | | - meta_dict: dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"] |
| 751 | + # meta_dict: dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"] |
| 752 | + meta_dict: dict = d.get(self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}", None) |
| 753 | + if meta_dict is None: |
| 754 | + meta_dict = d[get_meta_dict_name(self.ref_image, d)] |
746 | 755 | current_shape = d[self.ref_image].shape[1:] |
747 | 756 | cropped_shape = meta_dict[self.cropped_shape_key][1:] |
748 | 757 | factor = np.divide(current_shape, cropped_shape) |
@@ -852,7 +861,8 @@ def __init__( |
852 | 861 |
|
853 | 862 | def __call__(self, data: Any) -> dict: |
854 | 863 | d = dict(data) |
855 | | - meta_dict: dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] |
| 864 | + meta_dict: dict = d.get(f"{self.ref_image}_{self.meta_key_postfix}", |
| 865 | + d[get_meta_dict_name(self.ref_image, d)]) |
856 | 866 |
|
857 | 867 | for key, mode, align_corners, meta_key in self.key_iterator(d, self.mode, self.align_corners, self.meta_keys): |
858 | 868 | image = d[key] |
@@ -969,5 +979,8 @@ def __call__(self, data): |
969 | 979 | for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): |
970 | 980 | img_slice, idx = self._apply(d[key], guidance) |
971 | 981 | d[key] = img_slice |
972 | | - d[meta_key or f"{key}_{meta_key_postfix}"]["slice_idx"] = idx |
| 982 | + # d[meta_key or f"{key}_{meta_key_postfix}"]["slice_idx"] = idx |
| 983 | + if meta_key not in d: |
| 984 | + meta_key = get_meta_dict_name(key, d) |
| 985 | + d[meta_key] = idx |
973 | 986 | return d |
0 commit comments