From 6fca040b7dd8ddb7934ab99b0967965b852ca918 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 20 Aug 2023 13:43:35 +0800 Subject: [PATCH] slightly optimize ViT --- vision_toolbox/backbones/vit.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py index bcd2dbd..850450c 100644 --- a/vision_toolbox/backbones/vit.py +++ b/vision_toolbox/backbones/vit.py @@ -131,9 +131,7 @@ def forward(self, imgs: Tensor) -> Tensor: if self.cls_token is not None: out = torch.cat([self.cls_token, out], 1) out = self.layers(out + self.pe) - out = self.norm(out) - out = out[:, 0] if self.cls_token is not None else out.mean(1) - return out + return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1) @torch.no_grad() def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None: