Skip to content

Commit 2da5ca9

Browse files
DCO Remediation Commit for simben <[email protected]>
I, simben <[email protected]>, hereby add my Signed-off-by to this commit: 43c694b I, simben <[email protected]>, hereby add my Signed-off-by to this commit: 569df7b I, simben <[email protected]>, hereby add my Signed-off-by to this commit: fcf5ac0 I, simben <[email protected]>, hereby add my Signed-off-by to this commit: c846b6d Signed-off-by: simben <[email protected]>
1 parent c846b6d commit 2da5ca9

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

monai/bundle/nnunet.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,28 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
259259
self.network_weights = self.predictor.network
260260

261261
def forward(self, x):
262+
"""
263+
Forward pass for the nnUNet model.
264+
265+
:no-index:
266+
267+
Args:
268+
x (Union[torch.Tensor, Tuple[MetaTensor]]): Input tensor or a tuple of MetaTensors. If the input is a tuple,
269+
it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch.
270+
271+
Returns:
272+
MetaTensor: The output tensor with the same metadata as the input.
273+
274+
Raises:
275+
TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors.
276+
277+
Notes:
278+
- If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple.
279+
- If the input is a collated batch, the filenames are extracted from the metadata of the input tensor.
280+
- The filenames are used to generate predictions using the nnUNet predictor.
281+
- The predictions are converted to torch tensors, with added batch and channel dimensions.
282+
- The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata as the input.
283+
"""
262284
if type(x) is tuple: # if batch is decollated (list of tensors)
263285
input_files = [img.meta["filename_or_obj"][0] for img in x]
264286
else: # if batch is collated

0 commit comments

Comments
 (0)