 from monai.networks.blocks.spatialattention import SpatialAttentionBlock
 from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
 from monai.utils.type_conversion import convert_to_tensor
+from monai.utils.deprecate_utils import deprecated_arg
 
 # Set up logging configuration
 logger = logging.getLogger(__name__)
@@ -33,7 +34,7 @@ def _empty_cuda_cache(save_mem: bool) -> None:
         torch.cuda.empty_cache()
     return
 
-
+@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
 class MaisiGroupNorm3D(nn.GroupNorm):
     """
     Custom 3D Group Normalization with optional print_info output.
@@ -43,7 +44,7 @@ class MaisiGroupNorm3D(nn.GroupNorm):
         num_channels: Number of channels for the group norm.
         eps: Epsilon value for numerical stability.
         affine: Whether to use learnable affine parameters, default to `True`.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: Deprecated argument; it is ignored and will be removed in version 1.7.0.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -59,14 +60,15 @@ def __init__(
         save_mem: bool = True,
     ):
         super().__init__(num_groups, num_channels, eps, affine)
-        self.norm_float16 = norm_float16
         self.print_info = print_info
         self.save_mem = save_mem
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.print_info:
             logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")
 
+        target_dtype = input.dtype
+
         if len(input.shape) != 5:
             raise ValueError("Expected a 5D tensor")
@@ -75,13 +77,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
         inputs = []
         for i in range(input.size(1)):
-            array = input[:, i : i + 1, ...].to(dtype=torch.float32)
+            array = input[:, i : i + 1, ...]
             mean = array.mean([2, 3, 4, 5], keepdim=True)
             std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
-            if self.norm_float16:
-                inputs.append(((array - mean) / std).to(dtype=torch.float16))
-            else:
-                inputs.append((array - mean) / std)
+            inputs.append(((array - mean) / std).to(dtype=target_dtype))
 
         del input
         _empty_cuda_cache(self.save_mem)
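For reference, here is a minimal standalone sketch (not part of the diff) of what the updated loop computes: per-group mean/variance normalization whose output simply keeps the input tensor's dtype, which is what makes the `norm_float16` flag unnecessary. The helper name and tensor shapes below are illustrative only.

```python
import torch

def group_normalize(x: torch.Tensor, num_groups: int, eps: float = 1e-6) -> torch.Tensor:
    # (N, C, D, H, W) -> (N, G, C // G, D, H, W): each group is normalized independently.
    n, c, d, h, w = x.shape
    x = x.view(n, num_groups, c // num_groups, d, h, w)
    target_dtype = x.dtype
    groups = []
    for i in range(x.size(1)):
        g = x[:, i : i + 1, ...]
        mean = g.mean([2, 3, 4, 5], keepdim=True)
        std = g.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(eps).sqrt_()
        # The normalized output stays in the input dtype; no float32 round-trip, no float16 flag.
        groups.append(((g - mean) / std).to(dtype=target_dtype))
    return torch.cat(groups, dim=1).view(n, c, d, h, w)

x = torch.randn(1, 4, 8, 8, 8)  # float32 here; float16/bfloat16 inputs behave the same on GPU
assert group_normalize(x, num_groups=2).dtype == x.dtype
```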
@@ -375,7 +374,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv(x)
         return x
 
-
+@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
 class MaisiResBlock(nn.Module):
     """
     Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
@@ -417,7 +416,6 @@ def __init__(
             num_channels=in_channels,
             eps=norm_eps,
             affine=True,
-            norm_float16=norm_float16,
             print_info=print_info,
             save_mem=save_mem,
         )
@@ -439,7 +437,6 @@ def __init__(
             num_channels=out_channels,
             eps=norm_eps,
             affine=True,
-            norm_float16=norm_float16,
             print_info=print_info,
             save_mem=save_mem,
         )
@@ -500,7 +497,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         out_tensor: torch.Tensor = convert_to_tensor(out)
         return out_tensor
 
-
+@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
 class MaisiEncoder(nn.Module):
     """
     Convolutional cascade that downsamples the image into a spatial latent space.
506503 Convolutional cascade that downsamples the image into a spatial latent space.
@@ -520,7 +517,7 @@ class MaisiEncoder(nn.Module):
520517 use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
521518 num_splits: Number of splits for the input tensor.
522519 dim_split: Dimension of splitting for the input tensor.
523- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
520+ norm_float16: Deprecated argument .
524521 print_info: Whether to print information, default to `False`.
525522 save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
526523 """
@@ -591,7 +588,6 @@ def __init__(
                         out_channels=output_channel,
                         num_splits=num_splits,
                         dim_split=dim_split,
-                        norm_float16=norm_float16,
                         print_info=print_info,
                         save_mem=save_mem,
                     )
@@ -660,7 +656,6 @@ def __init__(
                 num_channels=num_channels[-1],
                 eps=norm_eps,
                 affine=True,
-                norm_float16=norm_float16,
                 print_info=print_info,
                 save_mem=save_mem,
             )
@@ -689,7 +684,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             _empty_cuda_cache(self.save_mem)
         return x
 
-
+@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
 class MaisiDecoder(nn.Module):
     """
     Convolutional cascade upsampling from a spatial latent space into an image space.
@@ -710,7 +705,7 @@ class MaisiDecoder(nn.Module):
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: Deprecated argument; it is ignored and will be removed in version 1.7.0.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -809,7 +804,6 @@ def __init__(
                         out_channels=block_out_ch,
                         num_splits=num_splits,
                         dim_split=dim_split,
-                        norm_float16=norm_float16,
                         print_info=print_info,
                         save_mem=save_mem,
                     )
@@ -848,7 +842,6 @@ def __init__(
                 num_channels=block_in_ch,
                 eps=norm_eps,
                 affine=True,
-                norm_float16=norm_float16,
                 print_info=print_info,
                 save_mem=save_mem,
             )
@@ -878,6 +871,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
 class AutoencoderKlMaisi(AutoencoderKL):
     """
     AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
@@ -901,7 +895,7 @@ class AutoencoderKlMaisi(AutoencoderKL):
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: Deprecated argument; it is ignored and will be removed in version 1.7.0.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -964,7 +958,6 @@ def __init__(
             use_flash_attention=use_flash_attention,
             num_splits=num_splits,
             dim_split=dim_split,
-            norm_float16=norm_float16,
             print_info=print_info,
             save_mem=save_mem,
         )
@@ -985,7 +978,6 @@ def __init__(
             use_convtranspose=use_convtranspose,
             num_splits=num_splits,
             dim_split=dim_split,
-            norm_float16=norm_float16,
             print_info=print_info,
             save_mem=save_mem,
         )