@@ -322,8 +322,36 @@ def __call__(
322322 supports callables such as ``lambda x: my_torch_model(x, additional_config)``
323323 args: optional args to be passed to ``network``.
324324 kwargs: optional keyword args to be passed to ``network``.
325+ condition (torch.Tensor, optional): If provided via `**kwargs`,
326+ this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.
327+ The resulting segments will be passed to the model together with the corresponding input segments.
325328
326329 """
330+ # check if there is a conditioning signal
331+ condition = kwargs .pop ("condition" , None )
332+ # shape check for condition
333+ if condition is not None :
334+ if isinstance (inputs , torch .Tensor ) and isinstance (condition , torch .Tensor ):
335+ if condition .shape != inputs .shape :
336+ raise ValueError (
337+ f"`condition` must match shape of `inputs` ({ inputs .shape } ), but got { condition .shape } "
338+ )
339+ elif isinstance (inputs , list ) and isinstance (condition , list ):
340+ if len (inputs ) != len (condition ):
341+ raise ValueError (
342+ f"Length of `condition` must match `inputs`. Got { len (inputs )} and { len (condition )} ."
343+ )
344+ for (in_patch , _ ), (cond_patch , _ ) in zip (inputs , condition ):
345+ if cond_patch .shape != in_patch .shape :
346+ raise ValueError (
347+ "Each `condition` patch must match the shape of the corresponding input patch. "
348+ f"Got { cond_patch .shape } and { in_patch .shape } ."
349+ )
350+ else :
351+ raise ValueError (
352+ "`condition` and `inputs` must be of the same type (both Tensor or both list of patches)."
353+ )
354+
327355 patches_locations : Iterable [tuple [torch .Tensor , Sequence [int ]]] | MetaTensor
328356 if self .splitter is None :
329357 # handle situations where the splitter is not provided
@@ -344,20 +372,39 @@ def __call__(
344372 f"The provided inputs type is { type (inputs )} ."
345373 )
346374 patches_locations = inputs
375+ if condition is not None :
376+ condition_locations = condition
347377 else :
348378 # apply splitter
349379 patches_locations = self .splitter (inputs )
380+ if condition is not None :
381+ # apply splitter to condition
382+ condition_locations = self .splitter (condition )
350383
351384 ratios : list [float ] = []
352385 mergers : list [Merger ] = []
353- for patches , locations , batch_size in self ._batch_sampler (patches_locations ):
354- # run inference
355- outputs = self ._run_inference (network , patches , * args , ** kwargs )
356- # initialize the mergers
357- if not mergers :
358- mergers , ratios = self ._initialize_mergers (inputs , outputs , patches , batch_size )
359- # aggregate outputs
360- self ._aggregate (outputs , locations , batch_size , mergers , ratios )
386+ if condition is not None :
387+ for (patches , locations , batch_size ), (condition_patches , _ , _ ) in zip (
388+ self ._batch_sampler (patches_locations ), self ._batch_sampler (condition_locations )
389+ ):
390+ # add patched condition to kwargs
391+ kwargs ["condition" ] = condition_patches
392+ # run inference
393+ outputs = self ._run_inference (network , patches , * args , ** kwargs )
394+ # initialize the mergers
395+ if not mergers :
396+ mergers , ratios = self ._initialize_mergers (inputs , outputs , patches , batch_size )
397+ # aggregate outputs
398+ self ._aggregate (outputs , locations , batch_size , mergers , ratios )
399+ else :
400+ for patches , locations , batch_size in self ._batch_sampler (patches_locations ):
401+ # run inference
402+ outputs = self ._run_inference (network , patches , * args , ** kwargs )
403+ # initialize the mergers
404+ if not mergers :
405+ mergers , ratios = self ._initialize_mergers (inputs , outputs , patches , batch_size )
406+ # aggregate outputs
407+ self ._aggregate (outputs , locations , batch_size , mergers , ratios )
361408
362409 # finalize the mergers and get the results
363410 merged_outputs = [merger .finalize () for merger in mergers ]
@@ -519,8 +566,14 @@ def __call__(
519566 supports callables such as ``lambda x: my_torch_model(x, additional_config)``
520567 args: optional args to be passed to ``network``.
521568 kwargs: optional keyword args to be passed to ``network``.
522-
569+ condition (torch.Tensor, optional): If provided via `**kwargs`,
570+ this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.
571+ The resulting segments will be passed to the model together with the corresponding input segments.
523572 """
573+ # shape check for condition
574+ condition = kwargs .get ("condition" , None )
575+ if condition is not None and condition .shape != inputs .shape :
576+ raise ValueError (f"`condition` must match shape of `inputs` ({ inputs .shape } ), but got { condition .shape } " )
524577
525578 device = kwargs .pop ("device" , self .device )
526579 buffer_steps = kwargs .pop ("buffer_steps" , self .buffer_steps )
@@ -728,7 +781,9 @@ def __call__(
728781 network: 2D model to execute inference on slices in the 3D input
729782 args: optional args to be passed to ``network``.
730783 kwargs: optional keyword args to be passed to ``network``.
731- """
784+ condition (torch.Tensor, optional): If provided via `**kwargs`,
785+ this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.
786+ The resulting segments will be passed to the model together with the corresponding input segments."""
732787 if self .spatial_dim > 2 :
733788 raise ValueError ("`spatial_dim` can only be `0, 1, 2` with `[H, W, D]` respectively." )
734789
@@ -742,12 +797,28 @@ def __call__(
742797 f"Currently, only 2D `roi_size` ({ self .orig_roi_size } ) with 3D `inputs` tensor (shape={ inputs .shape } ) is supported."
743798 )
744799
745- return super ().__call__ (inputs = inputs , network = lambda x : self .network_wrapper (network , x , * args , ** kwargs ))
800+ # shape check for condition
801+ condition = kwargs .get ("condition" , None )
802+ if condition is not None and condition .shape != inputs .shape :
803+ raise ValueError (f"`condition` must match shape of `inputs` ({ inputs .shape } ), but got { condition .shape } " )
804+
805+ # check if there is a conditioning signal
806+ if condition is not None :
807+ return super ().__call__ (
808+ inputs = inputs ,
809+ network = lambda x , * args , ** kwargs : self .network_wrapper (network , x , * args , ** kwargs ),
810+ condition = condition ,
811+ )
812+ else :
813+ return super ().__call__ (
814+ inputs = inputs , network = lambda x , * args , ** kwargs : self .network_wrapper (network , x , * args , ** kwargs )
815+ )
746816
747817 def network_wrapper (
748818 self ,
749819 network : Callable [..., torch .Tensor | Sequence [torch .Tensor ] | dict [Any , torch .Tensor ]],
750820 x : torch .Tensor ,
821+ condition : torch .Tensor | None = None ,
751822 * args : Any ,
752823 ** kwargs : Any ,
753824 ) -> torch .Tensor | tuple [torch .Tensor , ...] | dict [Any , torch .Tensor ]:
@@ -756,7 +827,12 @@ def network_wrapper(
756827 """
757828 # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.
758829 x = x .squeeze (dim = self .spatial_dim + 2 )
759- out = network (x , * args , ** kwargs )
830+
831+ if condition is not None :
832+ condition = condition .squeeze (dim = self .spatial_dim + 2 )
833+ out = network (x , condition , * args , ** kwargs )
834+ else :
835+ out = network (x , * args , ** kwargs )
760836
761837 # Unsqueeze the network output so it is [N, C, D, H, W] as expected by
762838 # the default SlidingWindowInferer class
0 commit comments