Skip to content

Commit

Permalink
add resize PE
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 23, 2023
1 parent 68894b3 commit ff7e3af
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
10 changes: 10 additions & 0 deletions tests/test_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch

from vision_toolbox.backbones import ViT


def test_resize_pe():
m = ViT.from_config("Ti", 16, 224)
m(torch.randn(1, 3, 224, 224))
m.resize_pe(256)
m(torch.randn(1, 3, 256, 256))
24 changes: 21 additions & 3 deletions vision_toolbox/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
# https://arxiv.org/abs/2106.10270
# https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py

from __future__ import annotations

from typing import Mapping

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn


Expand Down Expand Up @@ -33,7 +36,7 @@ def __init__(
cls_token: bool = True,
dropout: float = 0.0,
norm_eps: float = 1e-6,
):
) -> None:
super().__init__()
self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if cls_token else None
Expand All @@ -58,13 +61,28 @@ def forward(self, imgs: Tensor) -> Tensor:
out = out[:, 0] if self.cls_token is not None else out.mean(1)
return out

@torch.no_grad()
def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
pe = self.pe if self.cls_token is None else self.pe[:, 1:]

old_size = int(pe.shape[1] ** 0.5)
new_size = size // self.patch_embed.weight.shape[2]
pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
pe = pe.permute(0, 2, 3, 1).flatten(1, 2)

if self.cls_token is not None:
pe = torch.cat((self.pe[:, 0:1], pe), 1)

self.pe = nn.Parameter(pe)

@staticmethod
def from_config(variant: str, patch_size: int, img_size: int) -> "ViT":
def from_config(variant: str, patch_size: int, img_size: int) -> ViT:
return ViT(**configs[variant], patch_size=patch_size, img_size=img_size)

# weights from https://github.com/google-research/vision_transformer
@staticmethod
def from_jax_weights(path: str) -> "ViT":
def from_jax_weights(path: str) -> ViT:
jax_weights: Mapping[str, np.ndarray] = np.load(path)

n_layers = 1
Expand Down

0 comments on commit ff7e3af

Please sign in to comment.