Skip to content

Commit baa1eab

Browse files
committed
Revamp distillation by using a lightweight task arch. It encapsulates extra projections, etc., that may be needed.
1 parent 080b55b commit baa1eab

File tree

7 files changed

+888
-208
lines changed

7 files changed

+888
-208
lines changed

timm/kd/__init__.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

timm/kd/distillation.py

Lines changed: 0 additions & 150 deletions
This file was deleted.

timm/task/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Training task abstractions for timm.
2+
3+
This module provides task-based abstractions for training loops where each task
4+
encapsulates both the forward pass and loss computation, returning a dictionary
5+
with loss components and outputs for logging.
6+
"""
7+
from .task import TrainingTask
from .classification import ClassificationTask
from .distillation import (
    DistillationTeacher,
    LogitDistillationTask,
    FeatureDistillationTask,
)

# Names re-exported as the public surface of this package.
__all__ = [
    'TrainingTask',
    'ClassificationTask',
    'DistillationTeacher',
    'LogitDistillationTask',
    'FeatureDistillationTask',
]

timm/task/classification.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""Classification training task."""
2+
import logging
3+
from typing import Callable, Dict, Optional, Union
4+
5+
import torch
6+
import torch.nn as nn
7+
8+
from .task import TrainingTask
9+
10+
# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)
11+
12+
13+
class ClassificationTask(TrainingTask):
    """Supervised classification training task.

    Runs the wrapped model on the input batch and applies the criterion to
    the model output and the target, returning both the loss and the output.

    Args:
        model: The model to train
        criterion: Loss function (e.g., CrossEntropyLoss)
        device: Device for task tensors/buffers
        dtype: Dtype for task tensors/buffers
        verbose: Enable info logging

    Example:
        >>> task = ClassificationTask(model, nn.CrossEntropyLoss(), device=torch.device('cuda'))
        >>> result = task(input, target)
        >>> result['loss'].backward()
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: Union[nn.Module, Callable],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        verbose: bool = True,
    ):
        super().__init__(device=device, dtype=dtype, verbose=verbose)
        self.model = model
        self.criterion = criterion

        if not self.verbose:
            return

        # Use the criterion's own ``__name__`` when it has a non-empty one,
        # otherwise fall back to the name of its class.
        criterion_label = getattr(criterion, '__name__', None)
        if not criterion_label:
            criterion_label = type(criterion).__name__
        _logger.info(f"ClassificationTask: criterion={criterion_label}")

    def prepare_distributed(
        self,
        device_ids: Optional[list] = None,
        **ddp_kwargs
    ) -> 'ClassificationTask':
        """Wrap the model in DistributedDataParallel for distributed training.

        Args:
            device_ids: List of device IDs for DDP (e.g., [local_rank])
            **ddp_kwargs: Extra keyword arguments forwarded to
                DistributedDataParallel

        Returns:
            This task instance, so calls can be chained.
        """
        from torch.nn.parallel import DistributedDataParallel

        wrapped_model = DistributedDataParallel(
            self.model,
            device_ids=device_ids,
            **ddp_kwargs,
        )
        self.model = wrapped_model
        return self

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Run the model on ``input`` and compute the loss against ``target``.

        Args:
            input: Input tensor [B, C, H, W]
            target: Target labels [B]

        Returns:
            Dictionary containing:
                - 'loss': Result of the criterion applied to (output, target)
                - 'output': Model output (logits)
        """
        logits = self.model(input)
        return {
            'loss': self.criterion(logits, target),
            'output': logits,
        }

0 commit comments

Comments
 (0)