From 4e8abecf27e92dffd8d00f3d9b5ad4a21079cd0e Mon Sep 17 00:00:00 2001
From: Graeme Holliday
Date: Sat, 2 Jan 2021 22:28:03 -0800
Subject: [PATCH] make python code look pretty

---
 README.md | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 66e6f80..441a7e1 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ This repository will allow you to use distillation techniques with vision transf
 
 Install with `pip install distillable_vision_transformer` and load a pretrained transformer with:
 
-```
+```python
 from distillable_vision_transformer import DistillableVisionTransformer
 model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
 ```
@@ -32,14 +32,14 @@ pip install -e .
 
 Load a model architecture:
 
-```
+```python
 from distillable_vision_transformer import DistillableVisionTransformer
 model = DistillableVisionTransformer.from_name('ViT-B_16')
 ```
 
 Load a pretrained model:
 
-```
+```python
 from distillable_vision_transformer import DistillableVisionTransformer
 model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
 ```
@@ -60,7 +60,7 @@ Default hyper parameters:
 
 If you need to modify these hyperparameters, just overwrite them:
 
-```
+```python
 model = DistillableVisionTransformer.from_name('ViT-B_16', patch_size=64, emb_dim=2048, ...)
 ```
 
@@ -68,7 +68,7 @@ model = DistillableVisionTransformer.from_name('ViT-B_16', patch_size=64, emb_di
 
 Wrap the student (instance of `DistillableVisionTransformer`) and the teacher (any network that you want to use to train the student) with a `DistillationTrainer`:
 
-```
+```python
 from distillable_vision_transformer import DistillableVisionTransformer, DistillationTrainer
 
 student = DistillableVisionTransformer.from_pretrained('ViT-B_16')
@@ -78,7 +78,7 @@ trainer = DistillationTrainer(teacher=teacher, student=student) # where teacher
 
 For the loss function, it is recommended that you use the `DistilledLoss` class, which is a kind of hybrid between cross-entropy and KL-divergence loss. It takes as arguments `teacher_logits`, `student_logits`, and `distill_logits`, which are obtained from the forward pass on `DistillationTrainer`, as well as the true labels `labels`.
 
-```
+```python
 from distillable_vision_transformer import DistilledLoss
 loss_fn = DistilledLoss(alpha=0.5, temperature=3.0)
 loss = loss_fn(teacher_logits, student_logits, distill_logits, labels)
@@ -89,7 +89,7 @@
 
 For inference, we want to use the `DistillableVisionTransformer` instance, not its `DistillationTrainer` wrapper.
 
-```
+```python
 import torch
 from distillable_vision_transformer import DistillableVisionTransformer
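
Taken together, the hunks above describe the intended training workflow: wrap the student and teacher in a `DistillationTrainer`, run a forward pass to obtain `teacher_logits`, `student_logits`, and `distill_logits`, then feed those to `DistilledLoss` along with the true labels. Below is a minimal sketch of one training step; it assumes the trainer's forward pass returns those three tensors in that order (the README text only says they come from the forward pass), and the stand-in teacher, batch shape, class count, and optimizer are illustrative placeholders rather than anything documented by the library.

```python
# Illustrative sketch only (not part of the patch). Class names come from the
# README hunks above; everything marked "assumed" is a placeholder.
import torch
from distillable_vision_transformer import (
    DistillableVisionTransformer,
    DistillationTrainer,
    DistilledLoss,
)

student = DistillableVisionTransformer.from_pretrained('ViT-B_16')

# Assumed teacher: per the README, any network you want to train the student
# against will do; a tiny CNN keeps the sketch self-contained.
teacher = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(16, 1000),
)

trainer = DistillationTrainer(teacher=teacher, student=student)
loss_fn = DistilledLoss(alpha=0.5, temperature=3.0)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)  # assumed optimizer

images = torch.randn(8, 3, 384, 384)   # assumed batch size and input resolution
labels = torch.randint(0, 1000, (8,))  # assumed number of classes

# Assumed return order; the README only states that these logits are obtained
# from the forward pass on DistillationTrainer.
teacher_logits, student_logits, distill_logits = trainer(images)

loss = loss_fn(teacher_logits, student_logits, distill_logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

At inference time, as the final hunk notes, call the `DistillableVisionTransformer` student directly rather than its `DistillationTrainer` wrapper.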