diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index de96db2..c2a7957 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -14,27 +14,6 @@
 from ..utils import torch_hub_download
 
 
-__all__ = ["ViT"]
-
-
-configs = dict(
-    Ti=dict(n_layers=12, d_model=192, n_heads=3),
-    S=dict(n_layers=12, d_model=384, n_heads=6),
-    B=dict(n_layers=12, d_model=768, n_heads=12),
-    L=dict(n_layers=24, d_model=1024, n_heads=16),
-    H=dict(n_layers=32, d_model=1280, n_heads=16),
-)
-
-checkpoints = {
-    ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
-    ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
-    ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
-    ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
-    ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
-    ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
-}
-
-
 class MHA(nn.Module):
     def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
         super().__init__()
@@ -137,14 +116,29 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
     @staticmethod
     def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT:
         if pretrained:
-            if (variant, patch_size) not in checkpoints:
-                raise ValueError(f"There is no pre-trained checkpoint for ViT-{variant}/{patch_size}")
-            url = "https://storage.googleapis.com/vit_models/augreg/" + checkpoints[(variant, patch_size)]
-            m = ViT.from_jax_weights(torch_hub_download(url))
+            ckpt = {
+                ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
+                ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
+                ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
+                ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
+            }[(variant, patch_size)]
+            base_url = "https://storage.googleapis.com/vit_models/augreg/"
+            m = ViT.from_jax_weights(torch_hub_download(base_url + ckpt))
             if img_size != 224:
                 m.resize_pe(img_size)
+
         else:
-            m = ViT(**configs[variant], patch_size=patch_size, img_size=img_size)
+            n_layers, d_model, n_heads = dict(
+                Ti=(12, 192, 3),
+                S=(12, 384, 6),
+                B=(12, 768, 12),
+                L=(24, 1024, 16),
+                H=(32, 1280, 16),
+            )[variant]
+            m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
+
         return m
 
     # weights from https://github.com/google-research/vision_transformer
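
Usage sketch: the snippet below shows how the refactored ViT.from_config is expected to be called after this change. The import path mirrors the file path above and the call signature follows the hunk; an installed vision_toolbox package and network access for the checkpoint download are assumed. Note that with the lookup tables inlined, an unsupported combination such as ("H", 16) with pretrained=True now surfaces as a KeyError instead of the previous ValueError.

    from vision_toolbox.backbones.vit import ViT

    # Randomly initialised ViT-B/16 at 224x224, built from the inlined variant table.
    model = ViT.from_config("B", 16, img_size=224)

    # AugReg ImageNet-21k checkpoint; img_size != 224 triggers resize_pe() as in the hunk.
    model_384 = ViT.from_config("B", 16, img_size=384, pretrained=True)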