diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py index bcd2dbd..850450c 100644 --- a/vision_toolbox/backbones/vit.py +++ b/vision_toolbox/backbones/vit.py @@ -131,9 +131,7 @@ def forward(self, imgs: Tensor) -> Tensor: if self.cls_token is not None: out = torch.cat([self.cls_token, out], 1) out = self.layers(out + self.pe) - out = self.norm(out) - out = out[:, 0] if self.cls_token is not None else out.mean(1) - return out + return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1) @torch.no_grad() def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None: