Skip to content

Commit

Permalink
use broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 20, 2023
1 parent 36e82b7 commit b735bee
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion vision_toolbox/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
def forward(self, imgs: Tensor) -> Tensor:
out = self.patch_embed(imgs).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
if self.cls_token is not None:
out = torch.cat([self.cls_token.expand(out.shape[0], -1, -1), out], 1)
out = torch.cat([self.cls_token, out], 1)
out = self.layers(out + self.pe)
out = self.norm(out)
out = out[:, 0] if self.cls_token is not None else out.mean(1)
Expand Down

0 comments on commit b735bee

Please sign in to comment.