@@ -301,11 +301,10 @@ def __init__(
 
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-        self.head_norm_first = head_norm_first
         self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
         self.head = nn.Sequential(OrderedDict([
                 ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-                ('norm', nn.Identity() if head_norm_first or num_classes == 0 else norm_layer(self.num_features)),
+                ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
                 ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
                 ('drop', nn.Dropout(self.drop_rate)),
                 ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
@@ -336,14 +335,7 @@ def reset_classifier(self, num_classes=0, global_pool=None):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        if num_classes == 0:
-            self.head.norm = nn.Identity()
-            self.head.fc = nn.Identity()
-        else:
-            if not self.head_norm_first:
-                norm_layer = type(self.stem[-1])  # obtain type from stem norm
-                self.head.norm = norm_layer(self.num_features)
-            self.head.fc = nn.Linear(self.num_features, num_classes)
+        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.stem(x)
@@ -407,6 +399,11 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_convnext(variant, pretrained=False, **kwargs):
+    if kwargs.get('pretrained_cfg', '') == 'fcmae':
+        # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
+        # This is a workaround for loading with num_classes=0 w/o removing the norm-layer.
+        kwargs.setdefault('pretrained_strict', False)
+
     model = build_model_with_cfg(
         ConvNeXt, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
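
Context for the `fcmae` branch above (not part of the diff): a minimal usage sketch, assuming timm's standard `create_model` API and a `convnextv2_base.fcmae` pretrained tag as a placeholder variant name. Because FCMAE checkpoints carry no `head.norm` or `head.fc` weights, the model is typically created headless (`num_classes=0`), and `pretrained_strict=False` lets the remaining weights load without key-mismatch errors.

import torch
import timm

# Assumed variant/tag name for illustration; any ConvNeXt-V2 variant with FCMAE weights behaves the same.
model = timm.create_model('convnextv2_base.fcmae', pretrained=True, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # pooled features, shape (1, model.num_features)
print(feats.shape)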