From d38b6139cd07782c101c38f13febf82af4d8d180 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sat, 28 Oct 2023 09:07:39 +0800
Subject: [PATCH] change function signature

---
 tests/test_vit.py               |  2 +-
 vision_toolbox/backbones/vit.py | 28 ++++++++++++++++------------
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/tests/test_vit.py b/tests/test_vit.py
index 207bfc0..2fb3bdf 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -17,7 +17,7 @@ def test_resize_pe():
 
 
 def test_from_pretrained():
-    m = ViT.from_config("Ti_16", 224, True).eval()
+    m = ViT.from_config("Ti_16", 224, weights="augreg").eval()
     x = torch.randn(1, 3, 224, 224)
     out = m(x)
 
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index 5ded0b0..680c819 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -136,7 +136,7 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
         self.pe = nn.Parameter(pe)
 
     @staticmethod
-    def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
+    def from_config(variant: str, img_size: int, *, weights: str | None = None) -> ViT:
         variant, patch_size = variant.split("_")
 
         d_model, depth, n_heads = dict(
@@ -150,24 +150,28 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
         patch_size = int(patch_size)
         m = ViT(d_model, depth, n_heads, patch_size, img_size)
 
-        if pretrained:
+        if weights == "augreg":
             assert img_size == 224
             ckpt = {
-                ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
-                ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
-                ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
-                ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
-                ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
-                ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("Ti", 16): "augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
+                ("S", 32): "augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
+                ("S", 16): "augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
+                ("B", 32): "augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("B", 16): "augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("L", 16): "augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
             }[(variant, patch_size)]
-            base_url = "https://storage.googleapis.com/vit_models/augreg/"
-            m.load_jax_weights(torch_hub_download(base_url + ckpt))
+            m.load_vision_transformer_jax_weights(ckpt)
+
+        elif weights is not None:
+            raise ValueError(f"Unsupported weights={weights}")
 
         return m
 
-    # weights from https://github.com/google-research/vision_transformer
+    # https://github.com/google-research/vision_transformer
     @torch.no_grad()
-    def load_jax_weights(self, path: str) -> ViT:
+    def load_vision_transformer_jax_weights(self, ckpt: str) -> ViT:
+        base_url = "https://storage.googleapis.com/vit_models/"
+        path = torch_hub_download(base_url + ckpt)
         jax_weights: Mapping[str, np.ndarray] = np.load(path)
 
         def get_w(key: str) -> Tensor:
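
Note for reviewers: this patch replaces the positional `pretrained: bool` flag
with a keyword-only `weights: str | None` argument, so existing call sites like
`ViT.from_config("Ti_16", 224, True)` must be updated to pass `weights` by
keyword. A minimal migration sketch of the new API follows; it assumes `ViT` is
importable from `vision_toolbox.backbones` (import path inferred from the
repository layout in this diff, not confirmed by it):

    import torch
    from vision_toolbox.backbones import ViT  # assumed import path

    # "augreg" selects one of the AugReg ImageNet-21k checkpoints listed in the
    # ckpt table above; the bare `*` in the new signature forces keyword usage.
    m = ViT.from_config("Ti_16", 224, weights="augreg").eval()

    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = m(x)

    # weights=None (the default) builds a randomly initialized model, and any
    # other string is rejected:
    # ViT.from_config("Ti_16", 224, weights="imagenet")  # raises ValueError

Making `weights` a string rather than a bool also leaves room for checkpoint
families other than "augreg" without another signature change.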