Skip to content

Commit

Permalink
slightly optimize ViT
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 20, 2023
1 parent 7d5698e commit 6fca040
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions vision_toolbox/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ def forward(self, imgs: Tensor) -> Tensor:
if self.cls_token is not None:
out = torch.cat([self.cls_token, out], 1)
out = self.layers(out + self.pe)
out = self.norm(out)
out = out[:, 0] if self.cls_token is not None else out.mean(1)
return out
return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1)

@torch.no_grad()
def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
Expand Down

0 comments on commit 6fca040

Please sign in to comment.