@@ -377,7 +377,7 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch
377377 See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
378378
379379 Args:
380- x: Input tensor
380+ x: Input tensor with shape BCHW[D]
381381 spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
382382 scale_factor: factor to rescale the spatial dimensions by, must be >=1
383383
@@ -423,7 +423,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor
423423 See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
424424
425425 Args:
426- x: Input tensor
426+ x: Input tensor with shape BCHW[D]
427427 spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
428428 scale_factor: factor to reduce the spatial dimensions by, must be >=1
429429
@@ -443,7 +443,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor
443443
444444 if any (d % factor != 0 for d in input_size [2 :]):
445445 raise ValueError (
446- f"All spatial dimensions must be divisible by factor { factor } . " f"Got spatial dimensions : { input_size [2 :]} "
446+ f"All spatial dimensions must be divisible by factor { factor } . " f", spatial shape is : { input_size [2 :]} "
447447 )
448448 output_size = [batch_size , new_channels ] + [d // factor for d in input_size [2 :]]
449449 reshaped_size = [batch_size , channels ] + sum ([[d // factor , factor ] for d in input_size [2 :]], [])
0 commit comments