Skip to content

Commit d3dad8c

Browse files
authored
fix: handle metadata loading and shape calculation in transforms
Signed-off-by: Tristan Kirscher <[email protected]>
1 parent 211cfd8 commit d3dad8c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

modules/dynunet_pipeline/transforms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num, num_sampl
4242
keys = ["image"]
4343

4444
load_transforms = [
45-
LoadImaged(keys=keys),
45+
LoadImaged(keys=keys, image_only=False, ensure_channel_first=True),
4646
EnsureChannelFirstd(keys=keys),
4747
]
4848
# 2. sampling
@@ -284,6 +284,8 @@ def __init__(
284284

285285
def calculate_new_shape(self, spacing, shape):
286286
spacing_ratio = np.array(spacing) / np.array(self.target_spacing)
287+
if len(shape) == 4: # If shape includes channel dimension
288+
shape = shape[1:]
287289
new_shape = (spacing_ratio * np.array(shape)).astype(int).tolist()
288290
return new_shape
289291

0 commit comments

Comments
 (0)