
Commit 080b55b

Add pretrained_path arg for kd
1 parent 743c375 commit 080b55b

File tree

1 file changed (+11 -3 lines changed)


timm/kd/distillation.py

Lines changed: 11 additions & 3 deletions
@@ -1,6 +1,6 @@
 """Knowledge Distillation helpers for training with a teacher model."""
 import logging
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -22,7 +22,6 @@ class DistillationTeacher(nn.Module):
         model_name: Name of the teacher model to create
         num_classes: Number of output classes
         in_chans: Number of input channels
-        pretrained: Whether to load pretrained weights
         device: Device to place the model on (default: 'cuda')
         dtype: Model dtype (default: None, uses float32)
     """
@@ -32,18 +31,27 @@ def __init__(
             model_name: str,
             num_classes: int,
             in_chans: int = 3,
+            pretrained_path: Optional[str] = None,
             device: torch.device = torch.device('cuda'),
             dtype: torch.dtype = None,
     ):
         super().__init__()
 
         _logger.info(f"Creating KD teacher model: '{model_name}'")
 
+        pretrained_kwargs = {'pretrained': True}
+        if pretrained_path:
+            # specify a local checkpoint path to load pretrained weights from
+            pretrained_kwargs['pretrained_cfg_overlay'] = dict(
+                file=pretrained_path,
+                num_classes=num_classes,  # needed to avoid head adaptation?
+            )
+
         model_kd = create_model(
             model_name=model_name,
             num_classes=num_classes,
-            pretrained=True,
             in_chans=in_chans,
+            **pretrained_kwargs,
         )
 
         model_kd = model_kd.to(device=device, dtype=dtype)
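
For context, a minimal sketch of the loading pattern this commit adopts, written against timm's public create_model API. The model name, class count, and checkpoint path below are placeholders, and the DistillationTeacher import path is assumed from the file location shown above; treat this as an illustration, not part of the diff.

import timm
from timm.kd.distillation import DistillationTeacher  # import path assumed from the file above

# Pre-existing behaviour: pretrained=True downloads the model's default hub checkpoint.
teacher_backbone = timm.create_model('resnet50', pretrained=True, num_classes=1000, in_chans=3)

# Pattern added by this commit: overlay the pretrained config so weights are read
# from a local file, with num_classes kept in sync to avoid classifier head adaptation.
teacher_backbone = timm.create_model(
    'resnet50',
    pretrained=True,
    num_classes=1000,
    in_chans=3,
    pretrained_cfg_overlay=dict(
        file='/path/to/teacher_checkpoint.pth.tar',  # placeholder local path
        num_classes=1000,
    ),
)

# Equivalent call through the wrapper touched by this diff (constructor args as shown above).
# Note the wrapper defaults to device='cuda' per its signature.
teacher = DistillationTeacher(
    model_name='resnet50',
    num_classes=1000,
    pretrained_path='/path/to/teacher_checkpoint.pth.tar',  # placeholder local path
)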
