Commit f8fcb61

Merge branch 'dev' into fix-encoder-dim

2 parents e6b440f + b92b2ce commit f8fcb61

File tree

6 files changed: +689 −15 lines changed

monai/auto3dseg/analyzer.py

Lines changed: 16 additions & 0 deletions
@@ -217,6 +217,22 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
         self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
 
     def __call__(self, data):
+        # Input Validation Addition
+        if not isinstance(data, dict):
+            raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
+        if self.image_key not in data:
+            raise KeyError(f"Key '{self.image_key}' not found in input data.")
+        image = data[self.image_key]
+        if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
+            raise TypeError(
+                f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
+                f"but got {type(image).__name__}."
+            )
+        if image.ndim < 3:
+            raise ValueError(
+                f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
+            )
+        # --- End of validation ---
         """
         Callable to execute the pre-defined functions
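
For context, a minimal sketch of how the new checks surface to a caller. The analyzer class name (`ImageStats`) is inferred from the `__init__` signature in the hunk header and is an assumption, as are the toy shapes; only the added validation is exercised, so the rest of the analyzer pipeline is never reached.

    # Sketch only: exercises just the added validation, assuming the hunk lands in
    # monai.auto3dseg.analyzer.ImageStats (inferred from the signature above).
    import numpy as np
    from monai.auto3dseg.analyzer import ImageStats

    analyzer = ImageStats(image_key="image")

    for bad_input in (
        "not-a-dict",                    # -> TypeError: Input data must be a dict
        {"label": np.zeros((1, 8, 8))},  # -> KeyError: Key 'image' not found
        {"image": np.zeros((8, 8))},     # -> ValueError: at least 3 dimensions required
    ):
        try:
            analyzer(bad_input)
        except (TypeError, KeyError, ValueError) as err:
            print(type(err).__name__, err)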

monai/inferers/inferer.py

Lines changed: 88 additions & 12 deletions
@@ -322,8 +322,36 @@ def __call__(
                 supports callables such as ``lambda x: my_torch_model(x, additional_config)``
             args: optional args to be passed to ``network``.
             kwargs: optional keyword args to be passed to ``network``.
+                condition (torch.Tensor, optional): If provided via `**kwargs`, this tensor must match the
+                    shape of ``inputs`` and will be sliced, patched, or windowed alongside the inputs. The
+                    resulting segments will be passed to the model together with the corresponding input segments.
 
         """
+        # check if there is a conditioning signal
+        condition = kwargs.pop("condition", None)
+        # shape check for condition
+        if condition is not None:
+            if isinstance(inputs, torch.Tensor) and isinstance(condition, torch.Tensor):
+                if condition.shape != inputs.shape:
+                    raise ValueError(
+                        f"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}"
+                    )
+            elif isinstance(inputs, list) and isinstance(condition, list):
+                if len(inputs) != len(condition):
+                    raise ValueError(
+                        f"Length of `condition` must match `inputs`. Got {len(inputs)} and {len(condition)}."
+                    )
+                for (in_patch, _), (cond_patch, _) in zip(inputs, condition):
+                    if cond_patch.shape != in_patch.shape:
+                        raise ValueError(
+                            "Each `condition` patch must match the shape of the corresponding input patch. "
+                            f"Got {cond_patch.shape} and {in_patch.shape}."
+                        )
+            else:
+                raise ValueError(
+                    "`condition` and `inputs` must be of the same type (both Tensor or both list of patches)."
+                )
+
         patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
         if self.splitter is None:
             # handle situations where the splitter is not provided
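
The check runs before any splitting, so a mismatched pair fails fast. A quick sketch of that path, assuming this hunk sits in `PatchInferer.__call__` (the splitter configuration and shapes are illustrative):

    # Sketch of the fail-fast path; PatchInferer/SlidingWindowSplitter usage is assumed.
    import torch
    from monai.inferers import PatchInferer, SlidingWindowSplitter

    inferer = PatchInferer(splitter=SlidingWindowSplitter(patch_size=(16, 16)))
    inputs = torch.rand(1, 1, 32, 32)
    bad_condition = torch.rand(1, 1, 16, 16)  # spatial size differs from inputs

    try:
        inferer(inputs=inputs, network=lambda x, condition=None: x, condition=bad_condition)
    except ValueError as err:
        print(err)  # `condition` must match shape of `inputs` ...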
@@ -344,20 +372,39 @@ def __call__(
                     f"The provided inputs type is {type(inputs)}."
                 )
             patches_locations = inputs
+            if condition is not None:
+                condition_locations = condition
         else:
             # apply splitter
             patches_locations = self.splitter(inputs)
+            if condition is not None:
+                # apply splitter to condition
+                condition_locations = self.splitter(condition)
 
         ratios: list[float] = []
         mergers: list[Merger] = []
-        for patches, locations, batch_size in self._batch_sampler(patches_locations):
-            # run inference
-            outputs = self._run_inference(network, patches, *args, **kwargs)
-            # initialize the mergers
-            if not mergers:
-                mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
-            # aggregate outputs
-            self._aggregate(outputs, locations, batch_size, mergers, ratios)
+        if condition is not None:
+            for (patches, locations, batch_size), (condition_patches, _, _) in zip(
+                self._batch_sampler(patches_locations), self._batch_sampler(condition_locations)
+            ):
+                # add patched condition to kwargs
+                kwargs["condition"] = condition_patches
+                # run inference
+                outputs = self._run_inference(network, patches, *args, **kwargs)
+                # initialize the mergers
+                if not mergers:
+                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
+                # aggregate outputs
+                self._aggregate(outputs, locations, batch_size, mergers, ratios)
+        else:
+            for patches, locations, batch_size in self._batch_sampler(patches_locations):
+                # run inference
+                outputs = self._run_inference(network, patches, *args, **kwargs)
+                # initialize the mergers
+                if not mergers:
+                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
+                # aggregate outputs
+                self._aggregate(outputs, locations, batch_size, mergers, ratios)
 
         # finalize the mergers and get the results
         merged_outputs = [merger.finalize() for merger in mergers]
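
With a matching condition, the tensor is split by the same splitter and each batch of condition patches rides along in `kwargs["condition"]`. A minimal end-to-end sketch (the toy network and shapes are assumptions):

    # Sketch of the success path: `condition` is split with the same splitter and
    # forwarded to the network patch-by-patch via kwargs["condition"].
    import torch
    from monai.inferers import PatchInferer, SlidingWindowSplitter

    def net(x, condition=None):
        # toy stand-in for a conditioned model
        return x if condition is None else x + condition

    inferer = PatchInferer(splitter=SlidingWindowSplitter(patch_size=(16, 16)), batch_size=2)
    inputs = torch.rand(1, 1, 32, 32)
    condition = torch.rand(1, 1, 32, 32)  # must match inputs.shape exactly

    out = inferer(inputs=inputs, network=net, condition=condition)  # same shape as inputs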
@@ -519,8 +566,14 @@ def __call__(
                 supports callables such as ``lambda x: my_torch_model(x, additional_config)``
             args: optional args to be passed to ``network``.
             kwargs: optional keyword args to be passed to ``network``.
-
+                condition (torch.Tensor, optional): If provided via `**kwargs`, this tensor must match the
+                    shape of ``inputs`` and will be sliced, patched, or windowed alongside the inputs. The
+                    resulting segments will be passed to the model together with the corresponding input segments.
         """
+        # shape check for condition
+        condition = kwargs.get("condition", None)
+        if condition is not None and condition.shape != inputs.shape:
+            raise ValueError(f"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}")
 
         device = kwargs.pop("device", self.device)
         buffer_steps = kwargs.pop("buffer_steps", self.buffer_steps)
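
The same keyword works at the sliding-window level; here `condition` is only shape-checked and left in `kwargs`, so it flows down to the windowing code in utils.py. A sketch, assuming this hunk is `SlidingWindowInferer.__call__` (toy model and shapes):

    # Sketch: the condition tensor is windowed in lockstep with the inputs.
    import torch
    from monai.inferers import SlidingWindowInferer

    def net(x, condition=None):
        return x if condition is None else x * condition  # toy conditioned model

    inferer = SlidingWindowInferer(roi_size=(64, 64, 64), sw_batch_size=2)
    inputs = torch.rand(1, 1, 96, 96, 96)
    condition = torch.rand(1, 1, 96, 96, 96)  # same shape as inputs

    out = inferer(inputs=inputs, network=net, condition=condition)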
@@ -728,7 +781,9 @@ def __call__(
             network: 2D model to execute inference on slices in the 3D input
             args: optional args to be passed to ``network``.
             kwargs: optional keyword args to be passed to ``network``.
-        """
+            condition (torch.Tensor, optional): If provided via `**kwargs`, this tensor must match the
+                shape of ``inputs`` and will be sliced, patched, or windowed alongside the inputs. The
+                resulting segments will be passed to the model together with the corresponding input segments."""
         if self.spatial_dim > 2:
             raise ValueError("`spatial_dim` can only be `0, 1, 2` with `[H, W, D]` respectively.")

@@ -742,12 +797,28 @@ def __call__(
                 f"Currently, only 2D `roi_size` ({self.orig_roi_size}) with 3D `inputs` tensor (shape={inputs.shape}) is supported."
             )
 
-        return super().__call__(inputs=inputs, network=lambda x: self.network_wrapper(network, x, *args, **kwargs))
+        # shape check for condition
+        condition = kwargs.get("condition", None)
+        if condition is not None and condition.shape != inputs.shape:
+            raise ValueError(f"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}")
+
+        # check if there is a conditioning signal
+        if condition is not None:
+            return super().__call__(
+                inputs=inputs,
+                network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs),
+                condition=condition,
+            )
+        else:
+            return super().__call__(
+                inputs=inputs, network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs)
+            )
 
     def network_wrapper(
         self,
         network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
         x: torch.Tensor,
+        condition: torch.Tensor | None = None,
         *args: Any,
         **kwargs: Any,
     ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
@@ -756,7 +827,12 @@ def network_wrapper(
         """
         # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.
         x = x.squeeze(dim=self.spatial_dim + 2)
-        out = network(x, *args, **kwargs)
+
+        if condition is not None:
+            condition = condition.squeeze(dim=self.spatial_dim + 2)
+            out = network(x, condition, *args, **kwargs)
+        else:
+            out = network(x, *args, **kwargs)
 
         # Unsqueeze the network output so it is [N, C, D, H, W] as expected by
         # the default SlidingWindowInferer class
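
Slice-by-slice inference follows the same pattern: the wrapper squeezes the singleton slice dimension off both tensors before calling the 2D network. A sketch, assuming this wrapper belongs to `SliceInferer` (the 2D network is a toy):

    # Sketch: both x and condition are squeezed to 2D slices before the network call.
    import torch
    from monai.inferers import SliceInferer

    def net2d(x, condition=None):
        # receives [N, C, H, W] slices; condition (if given) was squeezed identically
        return x if condition is None else x + condition

    inferer = SliceInferer(roi_size=(64, 64), spatial_dim=2, sw_batch_size=1)
    inputs = torch.rand(1, 1, 64, 64, 8)     # 3D volume, sliced along the last dim
    condition = torch.rand(1, 1, 64, 64, 8)  # same shape, sliced in lockstep

    out = inferer(inputs=inputs, network=net2d, condition=condition)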

monai/inferers/utils.py

Lines changed: 13 additions & 3 deletions
@@ -153,6 +153,8 @@ def sliding_window_inference(
     device = device or inputs.device
     sw_device = sw_device or inputs.device
 
+    condition = kwargs.pop("condition", None)
+
     temp_meta = None
     if isinstance(inputs, MetaTensor):
         temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
@@ -168,6 +170,8 @@ def sliding_window_inference(
         pad_size.extend([half, diff - half])
     if any(pad_size):
         inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
+        if condition is not None:
+            condition = F.pad(condition, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
 
     # Store all slices
     scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
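
Padding the condition with the identical `pad_size` is what keeps it voxel-aligned with the padded inputs. A standalone sketch of that invariant (shapes illustrative):

    # Standalone sketch: one pad_size applied to both tensors preserves alignment.
    import torch
    import torch.nn.functional as F

    inputs = torch.rand(1, 1, 30, 30, 30)
    condition = torch.rand(1, 1, 30, 30, 30)

    pad_size = [0, 2, 0, 2, 0, 2]  # pad each spatial dim from 30 up to 32 (last dim first)
    inputs = F.pad(inputs, pad=pad_size, mode="constant", value=0.0)
    condition = F.pad(condition, pad=pad_size, mode="constant", value=0.0)

    assert inputs.shape == condition.shape == (1, 1, 32, 32, 32)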
@@ -220,13 +224,19 @@ def sliding_window_inference(
         ]
         if sw_batch_size > 1:
             win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
+            if condition is not None:
+                win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
+                kwargs["condition"] = win_condition
         else:
             win_data = inputs[unravel_slice[0]].to(sw_device)
+            if condition is not None:
+                win_condition = condition[unravel_slice[0]].to(sw_device)
+                kwargs["condition"] = win_condition
+
         if with_coord:
-            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)  # batched patch
+            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)
         else:
-            seg_prob_out = predictor(win_data, *args, **kwargs)  # batched patch
-
+            seg_prob_out = predictor(win_data, *args, **kwargs)
         # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
         dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
         if process_fn:
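
Putting it together at the functional entry point: `condition` is accepted as an extra keyword, padded and windowed like `inputs`, and each window is handed to the predictor through `kwargs`. A minimal sketch (the toy predictor is an assumption):

    # Sketch of the functional API with a conditioning tensor.
    import torch
    from monai.inferers import sliding_window_inference

    def predictor(patch, condition=None):
        # receives a windowed input plus the matching condition window via kwargs
        return patch if condition is None else patch * condition

    inputs = torch.rand(1, 1, 96, 96, 96)
    condition = torch.rand(1, 1, 96, 96, 96)  # padded/windowed exactly like `inputs`

    out = sliding_window_inference(
        inputs, roi_size=(64, 64, 64), sw_batch_size=2, predictor=predictor, condition=condition
    )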
