From 10274545a2c35248e1f3659ff8c6da517aad90e9 Mon Sep 17 00:00:00 2001
From: vivekh2000 <90518409+vivekh2000@users.noreply.github.com>
Date: Thu, 25 Jul 2024 22:01:20 +0530
Subject: [PATCH] Update distill.py to include device-agnostic code for the
 `distill_mlp` head and `distillation_token`

In the current code, `distillation_token` and the `distill_mlp` head are
defined in the DistillWrapper class rather than in the model class, so sending
an instance of the DistillableViT class to the GPU does not send them to the
GPU. While training a model with this code I got a device mismatch error, and
it was hard to figure out its source. The `distillation_token` and
`distill_mlp` finally turned out to be the culprits, since they are defined
not in the model class but in the DistillWrapper class. I have therefore
suggested the following changes. When training a model on the GPU, the
training code should set
`device = "cuda" if torch.cuda.is_available() else "cpu"`, or the same check
can be incorporated in the constructor of the DistillWrapper class, as is done
here.
---
 vit_pytorch/distill.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vit_pytorch/distill.py b/vit_pytorch/distill.py
index b480e23..d814119 100644
--- a/vit_pytorch/distill.py
+++ b/vit_pytorch/distill.py
@@ -116,6 +116,7 @@ def __init__(
         super().__init__()
         assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
 
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.teacher = teacher
         self.student = student
 
@@ -125,12 +126,12 @@ def __init__(
         self.alpha = alpha
         self.hard = hard
 
-        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim, device = device))
 
         self.distill_mlp = nn.Sequential(
             nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
             nn.Linear(dim, num_classes)
-        )
+        ).to(device)
 
     def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
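
For context, below is a minimal training-side sketch of the scenario described in the commit message. It follows the documented `DistillableViT` / `DistillWrapper` usage from the vit_pytorch README; the ResNet-50 teacher, the hyperparameter values, and the batch sizes are placeholder assumptions and are not part of this patch.

```python
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

# Device-agnostic selection, as suggested in the commit message.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

teacher = resnet50(num_classes = 1000)          # placeholder teacher network

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,
    alpha = 0.5
)

# Moving only `student` to the GPU reproduces the original device-mismatch
# error, because `distillation_token` and `distill_mlp` are attributes of the
# wrapper, not the model. With the patch they are created on `device`;
# alternatively, moving the whole wrapper keeps every parameter on one device.
distiller.to(device)

# Dummy batch just to exercise the forward/backward pass.
img = torch.randn(2, 3, 256, 256, device = device)
labels = torch.randint(0, 1000, (2,), device = device)

loss = distiller(img, labels)   # combined classification + distillation loss
loss.backward()
```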