@@ -347,9 +347,14 @@ def __call__(
347347 f"The provided inputs type is { type (inputs )} ."
348348 )
349349 patches_locations = inputs
350+ if condition is not None :
351+ condition_locations = condition
350352 else :
351353 # apply splitter
352354 patches_locations = self .splitter (inputs )
355+ if condition is not None :
356+ # apply splitter to condition
357+ condition_locations = self .splitter (condition )
353358
354359 ratios : list [float ] = []
355360 mergers : list [Merger ] = []
@@ -776,6 +781,7 @@ def network_wrapper(
776781 self ,
777782 network : Callable [..., torch .Tensor | Sequence [torch .Tensor ] | dict [Any , torch .Tensor ]],
778783 x : torch .Tensor ,
784+ condition : torch .Tensor | None = None ,
779785 * args : Any ,
780786 ** kwargs : Any ,
781787 ) -> torch .Tensor | tuple [torch .Tensor , ...] | dict [Any , torch .Tensor ]:
@@ -784,7 +790,12 @@ def network_wrapper(
784790 """
785791 # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.
786792 x = x .squeeze (dim = self .spatial_dim + 2 )
787- out = network (x , * args , ** kwargs )
793+
794+ if condition is not None :
795+ condition = condition .squeeze (dim = self .spatial_dim + 2 )
796+ out = network (x , condition , * args , ** kwargs )
797+ else :
798+ out = network (x , * args , ** kwargs )
788799
789800 # Unsqueeze the network output so it is [N, C, D, H, W] as expected by
790801 # the default SlidingWindowInferer class
0 commit comments