1313
1414import itertools
1515from collections .abc import Sequence
16- from typing import Final
1716
1817import numpy as np
1918import torch
@@ -51,8 +50,6 @@ class SwinUNETR(nn.Module):
5150 <https://arxiv.org/abs/2201.01266>"
5251 """
5352
54- patch_size : Final [int ] = 2
55-
5653 @deprecated_arg (
5754 name = "img_size" ,
5855 since = "1.3" ,
@@ -65,18 +62,24 @@ def __init__(
6562 img_size : Sequence [int ] | int ,
6663 in_channels : int ,
6764 out_channels : int ,
65+ patch_size : int = 2 ,
6866 depths : Sequence [int ] = (2 , 2 , 2 , 2 ),
6967 num_heads : Sequence [int ] = (3 , 6 , 12 , 24 ),
68+ window_size : Sequence [int ] | int = 7 ,
69+ qkv_bias : bool = True ,
70+ mlp_ratio : float = 4.0 ,
7071 feature_size : int = 24 ,
7172 norm_name : tuple | str = "instance" ,
7273 drop_rate : float = 0.0 ,
7374 attn_drop_rate : float = 0.0 ,
7475 dropout_path_rate : float = 0.0 ,
7576 normalize : bool = True ,
77+ norm_layer : type [LayerNorm ] = nn .LayerNorm ,
78+ patch_norm : bool = True ,
7679 use_checkpoint : bool = False ,
7780 spatial_dims : int = 3 ,
78- downsample = "merging" ,
79- use_v2 = False ,
81+ downsample : str | nn . Module = "merging" ,
82+ use_v2 : bool = False ,
8083 ) -> None :
8184 """
8285 Args:
@@ -86,14 +89,20 @@ def __init__(
8689 It will be removed in an upcoming version.
8790 in_channels: dimension of input channels.
8891 out_channels: dimension of output channels.
92+ patch_size: size of the patch token.
8993 feature_size: dimension of network feature size.
9094 depths: number of layers in each stage.
9195 num_heads: number of attention heads.
96+ window_size: local window size.
97+ qkv_bias: add a learnable bias to query, key, value.
98+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
9299 norm_name: feature normalization type and arguments.
93100 drop_rate: dropout rate.
94101 attn_drop_rate: attention dropout rate.
95102 dropout_path_rate: drop path rate.
96103 normalize: normalize output intermediate features in each stage.
104+ norm_layer: normalization layer.
105+ patch_norm: whether to apply normalization to the patch embedding.
97106 use_checkpoint: use gradient checkpointing for reduced memory usage.
98107 spatial_dims: number of spatial dims.
99108 downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
@@ -116,13 +125,15 @@ def __init__(
116125
117126 super ().__init__ ()
118127
119- img_size = ensure_tuple_rep (img_size , spatial_dims )
120- patch_sizes = ensure_tuple_rep (self .patch_size , spatial_dims )
121- window_size = ensure_tuple_rep (7 , spatial_dims )
122-
123128 if spatial_dims not in (2 , 3 ):
124129 raise ValueError ("spatial dimension should be 2 or 3." )
125130
131+ self .patch_size = patch_size
132+
133+ img_size = ensure_tuple_rep (img_size , spatial_dims )
134+ patch_sizes = ensure_tuple_rep (self .patch_size , spatial_dims )
135+ window_size = ensure_tuple_rep (window_size , spatial_dims )
136+
126137 self ._check_input_size (img_size )
127138
128139 if not (0 <= drop_rate <= 1 ):
@@ -146,12 +157,13 @@ def __init__(
146157 patch_size = patch_sizes ,
147158 depths = depths ,
148159 num_heads = num_heads ,
149- mlp_ratio = 4.0 ,
150- qkv_bias = True ,
160+ mlp_ratio = mlp_ratio ,
161+ qkv_bias = qkv_bias ,
151162 drop_rate = drop_rate ,
152163 attn_drop_rate = attn_drop_rate ,
153164 drop_path_rate = dropout_path_rate ,
154- norm_layer = nn .LayerNorm ,
165+ norm_layer = norm_layer ,
166+ patch_norm = patch_norm ,
155167 use_checkpoint = use_checkpoint ,
156168 spatial_dims = spatial_dims ,
157169 downsample = look_up_option (downsample , MERGING_MODE ) if isinstance (downsample , str ) else downsample ,
0 commit comments