Commit

add pre-trained checkpoints
gau-nernst committed Jul 23, 2023
1 parent a46f670 commit 9dec23d
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions vision_toolbox/backbones/vit.py
@@ -11,6 +11,8 @@
 import torch.nn.functional as F
 from torch import Tensor, nn
 
+from ..utils import torch_hub_download
+
 
 __all__ = ["ViT"]
 
@@ -23,6 +25,15 @@
     H=dict(n_layers=32, d_model=1280, n_heads=16),
 )
 
+checkpoints = {
+    ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
+    ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
+    ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
+    ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
+    ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
+    ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
+}
+
 
 class MHA(nn.Module):
     def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
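For reference, these filenames follow the naming scheme of the AugReg checkpoints released by google-research/vision_transformer: variant and patch size, ImageNet-21k pre-training for 300 epochs, then learning rate, augmentation recipe, weight decay, dropout, and stochastic-depth rate. A small sketch of unpacking one of these names (a hypothetical helper, not part of the commit):

def parse_checkpoint_name(name: str) -> dict:
    # e.g. "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz"
    fields = name.removesuffix(".npz").split("-")  # Python 3.9+
    variant, patch_size = fields[0].split("_")
    info = {"variant": variant, "patch_size": int(patch_size),
            "data": fields[1], "schedule": fields[2]}
    for field in fields[3:]:
        key, value = field.split("_", 1)
        info[key] = value  # lr, aug, wd (weight decay), do (dropout), sd (stochastic depth)
    return info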
@@ -117,8 +128,16 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
         self.pe = nn.Parameter(pe)
 
     @staticmethod
-    def from_config(variant: str, patch_size: int, img_size: int) -> ViT:
-        return ViT(**configs[variant], patch_size=patch_size, img_size=img_size)
+    def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT:
+        if pretrained:
+            if (variant, patch_size) not in checkpoints:
+                raise ValueError(f"There is no pre-trained checkpoint for ViT-{variant}/{patch_size}")
+            m = ViT.from_jax_weights(torch_hub_download(checkpoints[(variant, patch_size)]))
+            if img_size != 224:
+                m.resize_pe(img_size)
+        else:
+            m = ViT(**configs[variant], patch_size=patch_size, img_size=img_size)
+        return m
 
     # weights from https://github.com/google-research/vision_transformer
     @torch.no_grad()
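With this change, a pre-trained backbone can be built in one call. A minimal usage sketch, assuming the module path vision_toolbox.backbones.vit shown in the diff header and that torch_hub_download fetches the listed .npz file before from_jax_weights converts it:

import torch
from vision_toolbox.backbones.vit import ViT

# ViT-B/16 with ImageNet-21k pre-trained weights at the default 224 px.
model = ViT.from_config("B", 16, img_size=224, pretrained=True)

# A non-default image size triggers resize_pe(), which interpolates the
# position embeddings from the 224 px grid (bicubic by default).
model_384 = ViT.from_config("B", 16, img_size=384, pretrained=True)

# A (variant, patch_size) pair missing from `checkpoints` raises ValueError:
# ViT.from_config("H", 14, img_size=224, pretrained=True)  # -> ValueError

x = torch.randn(1, 3, 224, 224)
features = model(x)  # forward pass; output shape depends on the backbone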
