use optimized ViT
gau-nernst committed Aug 20, 2023
1 parent 8971381 commit 5d5a8f2
Showing 1 changed file with 15 additions and 28 deletions.
43 changes: 15 additions & 28 deletions vision_toolbox/backbones/deit.py
@@ -12,10 +12,10 @@
 
 from ..components import LayerScale
 from .base import _act, _norm
-from .vit import ViTBlock
+from .vit import ViT, ViTBlock
 
 
-class DeiT(nn.Module):
+class DeiT(ViT):
     def __init__(
         self,
         d_model: int,
@@ -31,37 +31,20 @@ def __init__(
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
     ) -> None:
-        assert img_size % patch_size == 0
-        super().__init__()
-        self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+        # fmt: off
+        super().__init__(
+            d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio,
+            dropout, layer_scale_init, stochastic_depth, norm, act
+        )
+        # fmt: on
         self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
-        self.pe = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2 + 2, d_model))
-        nn.init.normal_(self.pe, 0, 0.02)
-
-        self.layers = nn.Sequential()
-        for _ in range(depth):
-            block = ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act)
-            self.layers.append(block)
-
-        self.norm = norm(d_model)
 
     def forward(self, imgs: Tensor) -> Tensor:
-        out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
-        out = torch.cat([self.cls_token, self.dist_token, out], 1) + self.pe
+        out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe  # (N, C, H, W) -> (N, H*W, C)
+        out = torch.cat([self.cls_token, self.dist_token, out], 1)
         out = self.layers(out)
         return self.norm(out[:, :2]).mean(1)
 
-    @torch.no_grad()
-    def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
-        pe = self.pe[:, 2:]
-        old_size = int(pe.shape[1] ** 0.5)
-        new_size = size // self.patch_embed.weight.shape[2]
-        pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
-        pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
-        pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
-        self.pe = nn.Parameter(torch.cat((self.pe[:, :2], pe), 1))
-
     @staticmethod
     def from_config(variant: str, img_size: int, version: bool = False, pretrained: bool = False) -> DeiT:
         variant, patch_size = variant.split("_")
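
After this hunk, DeiT no longer builds its own patch embedding, class token, position embedding, blocks, or final norm; it inherits them from the refactored ViT and only adds the distillation token. Judging from the new forward(), the inherited position embedding spans the patch tokens only and is added before the class and distillation tokens are concatenated, so it has two fewer rows than before. A rough shape check of that change (a sketch with illustrative DeiT-Ti/16-style values, not read from the repository):

import torch
import torch.nn as nn

img_size, patch_size, d_model = 224, 16, 192
num_patches = (img_size // patch_size) ** 2  # 196 patch tokens

# Old DeiT: pe also carried rows for the class and distillation tokens.
old_pe = nn.Parameter(torch.empty(1, num_patches + 2, d_model))  # (1, 198, 192)
# After the refactor (assumed from forward()): pe covers patch tokens only
# and is added before cls_token/dist_token are concatenated.
new_pe = nn.Parameter(torch.empty(1, num_patches, d_model))      # (1, 196, 192)

print(old_pe.shape, new_pe.shape)  # torch.Size([1, 198, 192]) torch.Size([1, 196, 192])
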
@@ -97,12 +80,16 @@ def copy_(m: nn.Linear | nn.LayerNorm, prefix: str):
             m.bias.copy_(state_dict.pop(prefix + ".bias"))
 
         copy_(self.patch_embed, "patch_embed.proj")
+        pe = state_dict.pop("pos_embed")
+        self.pe.copy_(pe[:, -self.pe.shape[1] :])
+
         self.cls_token.copy_(state_dict.pop("cls_token"))
+        self.cls_token.add_(pe[:, 0])
         if self.dist_token is not None:
             self.dist_token.copy_(state_dict.pop("dist_token"))
+            self.dist_token.add_(pe[:, 1])
             state_dict.pop("head_dist.weight")
             state_dict.pop("head_dist.bias")
-        self.pe.copy_(state_dict.pop("pos_embed"))
 
         for i, block in enumerate(self.layers):
             block: ViTBlock
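
The last hunk compensates for the forward() change above: official DeiT checkpoints store a pos_embed with rows for the class token, the distillation token, and then the patches, while the refactored model adds its patch-only pe before the tokens are concatenated. Folding pe[:, 0] and pe[:, 1] into cls_token and dist_token reproduces the original computation. A minimal sketch of that equivalence (toy tensors and names are illustrative, not from the repository):

import torch

N, L, D = 2, 16, 8                               # batch, patch tokens, embed dim (toy values)
patches = torch.randn(N, L, D)                   # patch embeddings
cls_tok = torch.randn(1, 1, D)
dist_tok = torch.randn(1, 1, D)
pos_embed = torch.randn(1, L + 2, D)             # official layout: [cls, dist, patches]

# Original DeiT: concatenate the tokens first, then add the full pos_embed.
old = torch.cat([cls_tok.expand(N, -1, -1), dist_tok.expand(N, -1, -1), patches], 1) + pos_embed

# Refactored path: add only the patch rows of pos_embed, with the token rows
# pre-folded into the tokens (what the loading hunk above does with add_()).
cls_folded = cls_tok + pos_embed[:, 0]
dist_folded = dist_tok + pos_embed[:, 1]
new = torch.cat(
    [cls_folded.expand(N, -1, -1), dist_folded.expand(N, -1, -1), patches + pos_embed[:, 2:]], 1
)

assert torch.allclose(old, new)                  # both paths produce the same activations
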
