Commit

merge configs to from_config
gau-nernst committed Jul 24, 2023
1 parent 51b1b9a commit 8bcf593
Showing 1 changed file with 20 additions and 26 deletions.
46 changes: 20 additions & 26 deletions vision_toolbox/backbones/vit.py
@@ -14,27 +14,6 @@
 from ..utils import torch_hub_download
 
 
-__all__ = ["ViT"]
-
-
-configs = dict(
-    Ti=dict(n_layers=12, d_model=192, n_heads=3),
-    S=dict(n_layers=12, d_model=384, n_heads=6),
-    B=dict(n_layers=12, d_model=768, n_heads=12),
-    L=dict(n_layers=24, d_model=1024, n_heads=16),
-    H=dict(n_layers=32, d_model=1280, n_heads=16),
-)
-
-checkpoints = {
-    ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
-    ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
-    ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
-    ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
-    ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
-    ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
-}
-
-
 class MHA(nn.Module):
     def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
         super().__init__()
@@ -137,14 +116,29 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
     @staticmethod
     def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT:
         if pretrained:
-            if (variant, patch_size) not in checkpoints:
-                raise ValueError(f"There is no pre-trained checkpoint for ViT-{variant}/{patch_size}")
-            url = "https://storage.googleapis.com/vit_models/augreg/" + checkpoints[(variant, patch_size)]
-            m = ViT.from_jax_weights(torch_hub_download(url))
+            ckpt = {
+                ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
+                ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
+                ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
+                ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
+            }[(variant, patch_size)]
+            base_url = "https://storage.googleapis.com/vit_models/augreg/"
+            m = ViT.from_jax_weights(torch_hub_download(base_url + ckpt))
             if img_size != 224:
                 m.resize_pe(img_size)
 
         else:
-            m = ViT(**configs[variant], patch_size=patch_size, img_size=img_size)
+            n_layers, d_model, n_heads = dict(
+                Ti=(12, 192, 3),
+                S=(12, 384, 6),
+                B=(12, 768, 12),
+                L=(24, 1024, 16),
+                H=(32, 1280, 16),
+            )[variant]
+            m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
 
         return m
 
     # weights from https://github.com/google-research/vision_transformer
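After this change, from_config is self-contained: the per-variant architecture settings and the AugReg checkpoint names live inside the method rather than in module-level dicts. A minimal usage sketch; the import path is inferred from the file path shown in this diff, and the argument names follow the from_config signature above:

    from vision_toolbox.backbones.vit import ViT

    # Pre-trained ViT-B/16: downloads the AugReg ImageNet-21k checkpoint,
    # then resizes the 224px position embeddings since img_size != 224.
    model = ViT.from_config("B", patch_size=16, img_size=384, pretrained=True)

    # Randomly initialized ViT-S/32; no checkpoint download.
    model = ViT.from_config("S", patch_size=32, img_size=224)

One behavioral note: requesting an unsupported (variant, patch_size) pair with pretrained=True now raises a plain KeyError from the inline dict lookup, where the old code raised a descriptive ValueError.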
