 """Knowledge Distillation helpers for training with a teacher model."""
 import logging
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
@@ -22,7 +22,7 @@ class DistillationTeacher(nn.Module):
         model_name: Name of the teacher model to create
         num_classes: Number of output classes
         in_chans: Number of input channels
-        pretrained: Whether to load pretrained weights
+        pretrained_path: Optional path to a local checkpoint to load teacher weights from
         device: Device to place the model on (default: 'cuda')
         dtype: Model dtype (default: None, uses float32)
     """
@@ -32,18 +32,27 @@ def __init__(
         model_name: str,
         num_classes: int,
         in_chans: int = 3,
+        pretrained_path: Optional[str] = None,
         device: torch.device = torch.device('cuda'),
         dtype: torch.dtype = None,
     ):
         super().__init__()

         _logger.info(f"Creating KD teacher model: '{model_name}'")

+        pretrained_kwargs = {'pretrained': True}
+        if pretrained_path:
+            # load pretrained weights from a local checkpoint file instead of the default hub weights
+            pretrained_kwargs['pretrained_cfg_overlay'] = dict(
+                file=pretrained_path,
+                num_classes=num_classes,  # match the checkpoint head size so timm does not adapt the classifier
+            )
+
         model_kd = create_model(
             model_name=model_name,
             num_classes=num_classes,
-            pretrained=True,
             in_chans=in_chans,
+            **pretrained_kwargs,
         )

         model_kd = model_kd.to(device=device, dtype=dtype)
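
For reference, a minimal usage sketch of the updated constructor. The module path, model name, and checkpoint paths below are placeholders for illustration, not part of the change; the trailing timm.create_model call shows the equivalent call the constructor issues when pretrained_path is set, using timm's pretrained_cfg_overlay to redirect weight loading from the hub to a local file.

    import timm

    from kd_helpers import DistillationTeacher  # hypothetical module path

    # Default behaviour (pretrained_path=None): download hub weights for the teacher.
    teacher = DistillationTeacher(
        model_name='resnet50',
        num_classes=1000,
    )

    # New behaviour: load the teacher weights from a local checkpoint instead.
    local_teacher = DistillationTeacher(
        model_name='resnet50',
        num_classes=1000,
        pretrained_path='/path/to/teacher.pth',  # placeholder checkpoint path
    )

    # Equivalent direct timm call made inside __init__ when pretrained_path is set:
    model = timm.create_model(
        model_name='resnet50',
        num_classes=1000,
        in_chans=3,
        pretrained=True,
        pretrained_cfg_overlay=dict(
            file='/path/to/teacher.pth',  # load weights from this file
            num_classes=1000,             # match the checkpoint head size
        ),
    )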