Skip to content

Commit 23de2bc

Browse files
Refactor forward method in ModelnnUNetWrapper for clarity and type consistency
Signed-off-by: simben <[email protected]>
1 parent a2bc247 commit 23de2bc

File tree

2 files changed

+18
-444
lines changed

2 files changed

+18
-444
lines changed

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,14 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
257257
# End Block
258258
self.network_weights = self.predictor.network
259259

260-
def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]):
260+
def forward(self, x: MetaTensor) -> MetaTensor:
261261
"""
262262
Forward pass for the nnUNet model.
263263
264264
:no-index:
265265
266266
Args:
267-
x (Union[MetaTensor, Tuple[MetaTensor]]): Input tensor or a tuple of MetaTensors. If the input is a tuple,
267+
x (MetaTensor): Input tensor. If the input is a tuple,
268268
it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch.
269269
270270
Returns:
@@ -280,9 +280,9 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]):
280280
- The predictions are converted to torch tensors, with added batch and channel dimensions.
281281
- The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.
282282
"""
283-
if isinstance(x, tuple): # if batch is decollated (list of tensors)
284-
properties_or_list_of_properties = []
285-
image_or_list_of_images = []
283+
#if isinstance(x, tuple): # if batch is decollated (list of tensors)
284+
# properties_or_list_of_properties = []
285+
# image_or_list_of_images = []
286286

287287
# for img in x:
288288
# if isinstance(img, MetaTensor):
@@ -291,15 +291,16 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]):
291291
# else:
292292
# raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
293293

294-
else: # if batch is collated
295-
if isinstance(x, MetaTensor):
296-
if "pixdim" in x.meta:
297-
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
298-
else:
299-
properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]}
294+
#else: # if batch is collated
295+
if isinstance(x, MetaTensor):
296+
if "pixdim" in x.meta:
297+
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
300298
else:
301-
raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
302-
image_or_list_of_images = x.cpu().numpy()[0, :]
299+
properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]}
300+
else:
301+
raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
302+
303+
image_or_list_of_images = x.cpu().numpy()[0, :]
303304

304305
# input_files should be a list of file paths, one per modality
305306
prediction_output = self.predictor.predict_from_list_of_npy_arrays(
@@ -318,10 +319,10 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]):
318319
out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)))
319320
out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension
320321

321-
if type(x) is tuple:
322-
return MetaTensor(out_tensor, meta=x[0].meta)
323-
else:
324-
return MetaTensor(out_tensor, meta=x.meta)
322+
#if type(x) is tuple:
323+
# return MetaTensor(out_tensor, meta=x[0].meta)
324+
#else:
325+
return MetaTensor(out_tensor, meta=x.meta)
325326

326327

327328
def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):

0 commit comments

Comments
 (0)