@@ -32,11 +32,11 @@ def __init__(
         num_heads: int,
         ffn_expansion_factor: float,
         bias: bool,
-        LayerNorm_type: str,
+        layer_norm_type: str = "BiasFree",
         flash_attention: bool = False,
     ):
         super().__init__()
-        use_bias = LayerNorm_type != "BiasFree"
+        use_bias = layer_norm_type != "BiasFree"
         self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias)
         self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention)
         self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias)
@@ -88,19 +88,19 @@ class Restormer(nn.Module):
 
     def __init__(
         self,
-        spatial_dims=2,
-        inp_channels=3,
-        out_channels=3,
-        dim=48,
-        num_blocks=[1, 1, 1, 1],
-        heads=[1, 1, 1, 1],
-        num_refinement_blocks=4,
-        ffn_expansion_factor=2.66,
-        bias=False,
-        LayerNorm_type="WithBias",
-        dual_pixel_task=False,
-        flash_attention=False,
-    ):
+        spatial_dims: int = 2,
+        inp_channels: int = 3,
+        out_channels: int = 3,
+        dim: int = 48,
+        num_blocks: tuple[int, ...] = (1, 1, 1, 1),
+        heads: tuple[int, ...] = (1, 1, 1, 1),
+        num_refinement_blocks: int = 4,
+        ffn_expansion_factor: float = 2.66,
+        bias: bool = False,
+        layer_norm_type: str = "WithBias",
+        dual_pixel_task: bool = False,
+        flash_attention: bool = False,
+    ) -> None:
         super().__init__()
         """Initialize Restormer model.
 
@@ -113,14 +113,14 @@ def __init__(
             heads: Number of attention heads at each scale
             ffn_expansion_factor: Expansion factor for feed-forward network
             bias: Whether to use bias in convolutions
-            LayerNorm_type: Type of normalization ('WithBias' or 'BiasFree')
+            layer_norm_type: Type of normalization ('WithBias' or 'BiasFree')
             dual_pixel_task: Enable dual-pixel specific processing
             flash_attention: Use flash attention if available
         """
         # Check input parameters
         assert len(num_blocks) > 1, "Number of blocks must be greater than 1"
         assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal"
-        assert all([n > 0 for n in num_blocks]), "Number of blocks must be greater than 0"
+        assert all(n > 0 for n in num_blocks), "Number of blocks must be greater than 0"
 
         # Initial feature extraction
         self.patch_embed = OverlapPatchEmbed(spatial_dims, inp_channels, dim)
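> Note: a minimal instantiation sketch of the updated signature, assuming the module under change exposes `Restormer` directly (the import path below is a placeholder, not part of this diff). It shows the renamed `layer_norm_type` keyword and the tuple-typed `num_blocks`/`heads`, which must be equal-length tuples of positive ints to pass the assertions above.

```python
import torch

from restormer import Restormer  # placeholder import path; adjust to the actual module

# num_blocks and heads must have equal length and positive entries,
# otherwise the __init__ assertions raise.
model = Restormer(
    spatial_dims=2,
    inp_channels=3,
    out_channels=3,
    dim=48,
    num_blocks=(2, 3, 3, 4),
    heads=(1, 2, 4, 8),
    layer_norm_type="WithBias",  # renamed from LayerNorm_type
)
```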
@@ -147,7 +147,7 @@ def __init__(
                         num_heads=heads[n],
                         ffn_expansion_factor=ffn_expansion_factor,
                         bias=bias,
-                        LayerNorm_type=LayerNorm_type,
+                        layer_norm_type=layer_norm_type,
                         flash_attention=flash_attention,
                     )
                     for _ in range(num_blocks[n])
@@ -176,7 +176,7 @@ def __init__(
                     num_heads=heads[num_steps],
                     ffn_expansion_factor=ffn_expansion_factor,
                     bias=bias,
-                    LayerNorm_type=LayerNorm_type,
+                    layer_norm_type=layer_norm_type,
                     flash_attention=flash_attention,
                 )
                 for _ in range(num_blocks[num_steps])
@@ -224,7 +224,7 @@ def __init__(
                         num_heads=heads[n],
                         ffn_expansion_factor=ffn_expansion_factor,
                         bias=bias,
-                        LayerNorm_type=LayerNorm_type,
+                        layer_norm_type=layer_norm_type,
                         flash_attention=flash_attention,
                     )
                     for _ in range(num_blocks[n])
@@ -241,7 +241,7 @@ def __init__(
                     num_heads=heads[0],
                     ffn_expansion_factor=ffn_expansion_factor,
                     bias=bias,
-                    LayerNorm_type=LayerNorm_type,
+                    layer_norm_type=layer_norm_type,
                     flash_attention=flash_attention,
                 )
                 for _ in range(num_refinement_blocks)
@@ -286,7 +286,7 @@ def forward(self, x) -> torch.Tensor:
         skip_connections = []
 
         # Encoding path
-        for idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):
+        for _idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):
             x = encoder(x)
             skip_connections.append(x)
             x = downsample(x)
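> For completeness, a hedged forward-pass sketch continuing the instantiation example above; the input shape is illustrative and assumes the usual Restormer constraint that spatial dimensions are divisible by the total downsampling factor (8 for four scales).

```python
# Continuing the sketch above: run an image-restoration-style forward pass.
x = torch.randn(1, 3, 128, 128)  # (batch, inp_channels, H, W); H and W divisible by 8 for four scales
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 3, 128, 128]), same spatial size as the input
```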