add descriptions of how to do training and inference
Graeme22 authored Jan 3, 2021
1 parent 2354dd4 commit 3843c2a
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions README.md
@@ -64,6 +64,43 @@ If you need to modify these hyperparameters, just overwrite them:
model = DistillableVisionTransformer.from_name('ViT-B_16', patch_size=64, emb_dim=2048, ...)
```

### Training

Wrap the student (an instance of `DistillableVisionTransformer`) and the teacher (any network you want the student to learn from) in a `DistillationTrainer`:

```
from distillable_vision_transformer import DistillableVisionTransformer, DistillationTrainer
student = DistillableVisionTransformer.from_pretrained('ViT-B_16')
# `teacher` must already be defined: any pretrained network, e.g. an EfficientNet
trainer = DistillationTrainer(teacher=teacher, student=student)
```

For the loss function, we recommend the `DistilledLoss` class, which combines cross-entropy and KL-divergence losses.
It takes as arguments `teacher_logits`, `student_logits`, and `distill_logits`, which are obtained from the forward pass on `DistillationTrainer`, as well as the ground-truth labels `labels`.

```
from distillable_vision_transformer import DistilledLoss
loss_fn = DistilledLoss(alpha=0.5, temperature=3.0)
loss = loss_fn(teacher_logits, student_logits, distill_logits, labels)
```
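
To make the roles of `alpha` and `temperature` concrete, here is a minimal pure-Python sketch of one common way to combine cross-entropy with KL-divergence for a single sample. It is an assumption for illustration only and not necessarily how `DistilledLoss` is implemented: the function name `distilled_loss_sketch`, the `(1 - alpha)` / `alpha` weighting, and the `temperature ** 2` scaling are all hypothetical choices.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits, softened by `temperature`."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Cross-entropy between the softmax of `logits` and the true class index `label`."""
    return -math.log(softmax(logits)[label])


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distilled_loss_sketch(teacher_logits, student_logits, distill_logits, label,
                          alpha=0.5, temperature=3.0):
    """Hypothetical hybrid loss (an assumed formulation, not the library's code).

    The hard term compares the student's class output with the true label;
    the soft term compares the distillation output with the teacher's
    temperature-softened distribution.
    """
    hard = cross_entropy(student_logits, label)
    soft = kl_divergence(softmax(teacher_logits, temperature),
                         softmax(distill_logits, temperature)) * temperature ** 2
    return (1 - alpha) * hard + alpha * soft


teacher = [2.0, 0.5, -1.0]
student = [1.5, 0.2, -0.3]
distill = [1.0, 0.8, -0.5]
loss = distilled_loss_sketch(teacher, student, distill, label=0)
```

Under these assumptions, `alpha` trades off fitting the true labels against matching the teacher, and a higher `temperature` flattens both distributions so that the teacher's relative preferences among the non-target classes carry more weight.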

### Inference

For inference, use the `DistillableVisionTransformer` instance directly, not its `DistillationTrainer` wrapper.

```
import torch
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
model.eval()
inputs = torch.rand(1, 3, *model.image_size)
# we can discard the distillation tokens, as they are only needed to calculate the loss
with torch.no_grad():  # gradients are not needed at inference time
    outputs, _ = model(inputs)
```

### Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.
