@@ -338,7 +338,7 @@ def get_classifier(self):
338338
339339 def reset_classifier (self , num_classes : int ):
340340 self .num_classes = num_classes
341- self .classifier = nn .Linear (round (self .cfg [self .networks [self .network_idx ]][- 1 ][1 ] * self .scale ), num_classes )
341+ self .classifier = nn .Linear (round (self .cfg [self .networks [self .network_idx ]][- 1 ][0 ] * self .scale ), num_classes )
342342
343343 def forward_features (self , x : torch .Tensor ) -> torch .Tensor :
344344 return self .features (x )
@@ -367,15 +367,18 @@ def _gen_simplenet(
367367) -> SimpleNet :
368368
369369 model_args = dict (
370- num_classes = num_classes ,
371- in_chans = in_chans ,
372- scale = scale ,
373- network_idx = network_idx ,
374- mode = mode ,
375- drop_rates = drop_rates ,
376- ** kwargs ,
370+ in_chans = in_chans , scale = scale , network_idx = network_idx , mode = mode , drop_rates = drop_rates , ** kwargs ,
377371 )
372+ # to allow for seamless finetuning, remove the num_classes
373+ # and load the model intact, we apply the changes afterward!
374+ if "num_classes" in kwargs :
375+ kwargs .pop ("num_classes" )
378376 model = build_model_with_cfg (SimpleNet , model_variant , pretrained , ** model_args )
377+ # if the num_classes is different than imagenet's, it
378+ # means it's going to be fine-tuned, so only create a
379+ # new classifier after the whole model is loaded!
380+ if num_classes != 1000 :
381+ model .reset_classifier (num_classes )
379382 return model
380383
381384
@@ -436,7 +439,7 @@ def remove_network_settings(kwargs: Dict[str, Any]) -> Dict[str, Any]:
436439 Returns:
437440 Dict[str,Any]: cleaned kwargs
438441 """
439- model_args = {k : v for k , v in kwargs .items () if k not in ["scale" , "network_idx" , "mode" ,"drop_rate" ]}
442+ model_args = {k : v for k , v in kwargs .items () if k not in ["scale" , "network_idx" , "mode" , "drop_rate" ]}
440443 return model_args
441444
442445
0 commit comments