@@ -88,8 +88,8 @@ def __init__(
8888
8989 for layer in self .conv .children ():
9090 if isinstance (layer , conv_type ): # type: ignore
91- torch .nn .init .normal_ (layer .weight , std = 0.01 )
92- torch .nn .init .constant_ (layer .bias , 0 )
91+ torch .nn .init .normal_ (layer .weight , std = 0.01 ) # type: ignore[arg-type]
92+ torch .nn .init .constant_ (layer .bias , 0 ) # type: ignore[arg-type]
9393
9494 self .cls_logits = conv_type (in_channels , num_anchors * num_classes , kernel_size = 3 , stride = 1 , padding = 1 )
9595 torch .nn .init .normal_ (self .cls_logits .weight , std = 0.01 )
@@ -167,8 +167,8 @@ def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int):
167167
168168 for layer in self .conv .children ():
169169 if isinstance (layer , conv_type ): # type: ignore
170- torch .nn .init .normal_ (layer .weight , std = 0.01 )
171- torch .nn .init .zeros_ (layer .bias )
170+ torch .nn .init .normal_ (layer .weight , std = 0.01 ) # type: ignore[arg-type]
171+ torch .nn .init .zeros_ (layer .bias ) # type: ignore[arg-type]
172172
173173 def forward (self , x : list [Tensor ]) -> list [Tensor ]:
174174 """
@@ -297,7 +297,7 @@ def __init__(
297297 )
298298 self .feature_extractor = feature_extractor
299299
300- self .feature_map_channels : int = self .feature_extractor .out_channels
300+ self .feature_map_channels : int = self .feature_extractor .out_channels # type: ignore[assignment]
301301 self .num_anchors = num_anchors
302302 self .classification_head = RetinaNetClassificationHead (
303303 self .feature_map_channels , self .num_anchors , self .num_classes , spatial_dims = self .spatial_dims
0 commit comments