make python code look pretty
Graeme22 authored Jan 3, 2021
1 parent 3843c2a commit 4e8abec
Showing 1 changed file with 7 additions and 7 deletions.
README.md (14 changes: 7 additions & 7 deletions)
@@ -7,7 +7,7 @@ This repository will allow you to use distillation techniques with vision transf

Install with `pip install distillable_vision_transformer` and load a pretrained transformer with:

-```
+```python
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
```
@@ -32,14 +32,14 @@ pip install -e .

Load a model architecture:

-```
+```python
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_name('ViT-B_16')
```

Load a pretrained model:

-```
+```python
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
```
@@ -60,15 +60,15 @@ Default hyper parameters:

If you need to modify these hyperparameters, just overwrite them:

-```
+```python
model = DistillableVisionTransformer.from_name('ViT-B_16', patch_size=64, emb_dim=2048, ...)
```

### Training

Wrap the student (an instance of `DistillableVisionTransformer`) and the teacher (any network you want to use to train the student) in a `DistillationTrainer`:

-```
+```python
from distillable_vision_transformer import DistillableVisionTransformer, DistillationTrainer

student = DistillableVisionTransformer.from_pretrained('ViT-B_16')
@@ -78,7 +78,7 @@ trainer = DistillationTrainer(teacher=teacher, student=student) # where teacher is
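
The definition of `teacher` falls outside the hunks shown above, so it does not appear in this diff. As a purely illustrative sketch (an editor's assumption, not the README's code), any pretrained classifier could fill that role, for example a torchvision ResNet:

```python
from torchvision.models import resnet50  # hypothetical teacher; the README only requires "any network"

teacher = resnet50(pretrained=True)  # assumed teacher choice
trainer = DistillationTrainer(teacher=teacher, student=student)
```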
For the loss function, it is recommended that you use the `DistilledLoss` class, a hybrid of cross-entropy and KL-divergence loss.
It takes as arguments `teacher_logits`, `student_logits`, and `distill_logits`, which are obtained from the forward pass of `DistillationTrainer`, as well as the true labels `labels`.

-```
+```python
from distillable_vision_transformer import DistilledLoss

loss_fn = DistilledLoss(alpha=0.5, temperature=3.0)
@@ -89,7 +89,7 @@ loss = loss_fn(teacher_logits, student_logits, distill_logits, labels)
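
Putting the pieces together, a minimal training step could look like the sketch below. This is an editor's illustration, not code from the README: it assumes the trainer's forward pass returns the logits in the order `(teacher_logits, student_logits, distill_logits)` (matching the argument names that `DistilledLoss` expects), that only the student's parameters are optimized, and that the batch shape and class count are placeholders.

```python
import torch

# Hypothetical single training step using the trainer and loss defined above.
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)  # assumption: train only the student

images = torch.randn(8, 3, 384, 384)   # dummy batch; the 384x384 input size is an assumption
labels = torch.randint(0, 1000, (8,))  # dummy labels for an assumed 1000-class problem

teacher_logits, student_logits, distill_logits = trainer(images)  # assumed return order
loss = loss_fn(teacher_logits, student_logits, distill_logits, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```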

For inference, we want to use the `DistillableVisionTransformer` instance, not its `DistillationTrainer` wrapper.

-```
+```python
import torch
from distillable_vision_transformer import DistillableVisionTransformer
