Skip to content

Commit 8faa5da

Browse files
committed
rename args and fix imports
1 parent c7b1af4 commit 8faa5da

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

monai/networks/nets/restormer.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ def __init__(
3232
num_heads: int,
3333
ffn_expansion_factor: float,
3434
bias: bool,
35-
LayerNorm_type: str,
35+
layer_norm_type: str = "BiasFree",
3636
flash_attention: bool = False,
3737
):
3838
super().__init__()
39-
use_bias = LayerNorm_type != "BiasFree"
39+
use_bias = layer_norm_type != "BiasFree"
4040
self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias)
4141
self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention)
4242
self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias)
@@ -88,19 +88,19 @@ class Restormer(nn.Module):
8888

8989
def __init__(
9090
self,
91-
spatial_dims=2,
92-
inp_channels=3,
93-
out_channels=3,
94-
dim=48,
95-
num_blocks=[1, 1, 1, 1],
96-
heads=[1, 1, 1, 1],
97-
num_refinement_blocks=4,
98-
ffn_expansion_factor=2.66,
99-
bias=False,
100-
LayerNorm_type="WithBias",
101-
dual_pixel_task=False,
102-
flash_attention=False,
103-
):
91+
spatial_dims: int = 2,
92+
inp_channels: int = 3,
93+
out_channels: int = 3,
94+
dim: int = 48,
95+
num_blocks: tuple[int, ...] = (1, 1, 1, 1),
96+
heads: tuple[int, ...] = (1, 1, 1, 1),
97+
num_refinement_blocks: int = 4,
98+
ffn_expansion_factor: float = 2.66,
99+
bias: bool = False,
100+
layer_norm_type: str = "WithBias",
101+
dual_pixel_task: bool = False,
102+
flash_attention: bool = False,
103+
) -> None:
104104
super().__init__()
105105
"""Initialize Restormer model.
106106
@@ -113,14 +113,14 @@ def __init__(
113113
heads: Number of attention heads at each scale
114114
ffn_expansion_factor: Expansion factor for feed-forward network
115115
bias: Whether to use bias in convolutions
116-
LayerNorm_type: Type of normalization ('WithBias' or 'BiasFree')
116+
layer_norm_type: Type of normalization ('WithBias' or 'BiasFree')
117117
dual_pixel_task: Enable dual-pixel specific processing
118118
flash_attention: Use flash attention if available
119119
"""
120120
# Check input parameters
121121
assert len(num_blocks) > 1, "Number of blocks must be greater than 1"
122122
assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal"
123-
assert all([n > 0 for n in num_blocks]), "Number of blocks must be greater than 0"
123+
assert all(n > 0 for n in num_blocks), "Number of blocks must be greater than 0"
124124

125125
# Initial feature extraction
126126
self.patch_embed = OverlapPatchEmbed(spatial_dims, inp_channels, dim)
@@ -147,7 +147,7 @@ def __init__(
147147
num_heads=heads[n],
148148
ffn_expansion_factor=ffn_expansion_factor,
149149
bias=bias,
150-
LayerNorm_type=LayerNorm_type,
150+
layer_norm_type=layer_norm_type,
151151
flash_attention=flash_attention,
152152
)
153153
for _ in range(num_blocks[n])
@@ -176,7 +176,7 @@ def __init__(
176176
num_heads=heads[num_steps],
177177
ffn_expansion_factor=ffn_expansion_factor,
178178
bias=bias,
179-
LayerNorm_type=LayerNorm_type,
179+
layer_norm_type=layer_norm_type,
180180
flash_attention=flash_attention,
181181
)
182182
for _ in range(num_blocks[num_steps])
@@ -224,7 +224,7 @@ def __init__(
224224
num_heads=heads[n],
225225
ffn_expansion_factor=ffn_expansion_factor,
226226
bias=bias,
227-
LayerNorm_type=LayerNorm_type,
227+
layer_norm_type=layer_norm_type,
228228
flash_attention=flash_attention,
229229
)
230230
for _ in range(num_blocks[n])
@@ -241,7 +241,7 @@ def __init__(
241241
num_heads=heads[0],
242242
ffn_expansion_factor=ffn_expansion_factor,
243243
bias=bias,
244-
LayerNorm_type=LayerNorm_type,
244+
layer_norm_type=layer_norm_type,
245245
flash_attention=flash_attention,
246246
)
247247
for _ in range(num_refinement_blocks)
@@ -286,7 +286,7 @@ def forward(self, x) -> torch.Tensor:
286286
skip_connections = []
287287

288288
# Encoding path
289-
for idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):
289+
for _idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):
290290
x = encoder(x)
291291
skip_connections.append(x)
292292
x = downsample(x)

0 commit comments

Comments
 (0)