
Commit

change function signature
gau-nernst committed Oct 28, 2023
1 parent f6d7c4b commit d38b613
Showing 2 changed files with 17 additions and 13 deletions.
tests/test_vit.py (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@ def test_resize_pe():


 def test_from_pretrained():
-    m = ViT.from_config("Ti_16", 224, True).eval()
+    m = ViT.from_config("Ti_16", 224, weights="augreg").eval()
     x = torch.randn(1, 3, 224, 224)
     out = m(x)

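The test change also reflects that the new `weights` parameter is keyword-only (note the bare `*` in the vit.py signature below), so the old positional boolean no longer works. A minimal call-site sketch, assuming `ViT` is importable from the changed module `vision_toolbox.backbones.vit`:

    from vision_toolbox.backbones.vit import ViT  # assumed import path

    # New API: weights defaults to None, so no checkpoint is downloaded.
    m = ViT.from_config("Ti_16", 224)

    # New API: select the AugReg checkpoints by name.
    m = ViT.from_config("Ti_16", 224, weights="augreg")

    # Old API: the positional flag is now rejected because `weights` is keyword-only.
    try:
        ViT.from_config("Ti_16", 224, True)
    except TypeError as exc:
        print(exc)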
vision_toolbox/backbones/vit.py (28 changes: 16 additions & 12 deletions)
@@ -136,7 +136,7 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
         self.pe = nn.Parameter(pe)

     @staticmethod
-    def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
+    def from_config(variant: str, img_size: int, *, weights: str | None = None) -> ViT:
         variant, patch_size = variant.split("_")

         d_model, depth, n_heads = dict(
@@ -150,24 +150,28 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
         patch_size = int(patch_size)
         m = ViT(d_model, depth, n_heads, patch_size, img_size)

-        if pretrained:
+        if weights == "augreg":
             assert img_size == 224
             ckpt = {
-                ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
-                ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
-                ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
-                ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
-                ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
-                ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("Ti", 16): "augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
+                ("S", 32): "augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
+                ("S", 16): "augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
+                ("B", 32): "augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("B", 16): "augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("L", 16): "augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
             }[(variant, patch_size)]
-            base_url = "https://storage.googleapis.com/vit_models/augreg/"
-            m.load_jax_weights(torch_hub_download(base_url + ckpt))
+            m.load_vision_transformer_jax_weights(ckpt)
+
+        elif weights is not None:
+            raise ValueError(f"Unsupported weights={weights}")

         return m

-    # weights from https://github.com/google-research/vision_transformer
+    # https://github.com/google-research/vision_transformer
     @torch.no_grad()
-    def load_jax_weights(self, path: str) -> ViT:
+    def load_vision_transformer_jax_weights(self, ckpt: str) -> ViT:
+        base_url = "https://storage.googleapis.com/vit_models/"
+        path = torch_hub_download(base_url + ckpt)
         jax_weights: Mapping[str, np.ndarray] = np.load(path)

         def get_w(key: str) -> Tensor:
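For context, a minimal end-to-end sketch of the revised loader (not part of the commit; it assumes the package is installed and that downloading the AugReg checkpoint from storage.googleapis.com is acceptable):

    import torch

    from vision_toolbox.backbones.vit import ViT  # assumed import path

    # Build ViT-Ti/16 at 224x224 and load the AugReg i21k checkpoint; from_config
    # now resolves the file under the "augreg/" prefix and delegates the download
    # to load_vision_transformer_jax_weights.
    model = ViT.from_config("Ti_16", 224, weights="augreg").eval()

    x = torch.randn(1, 3, 224, 224)  # dummy batch, as in tests/test_vit.py
    with torch.no_grad():
        out = model(x)
    print(out.shape)

    # Any other non-None value for `weights` raises ValueError(f"Unsupported weights={weights}").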
