From d029bdedbb1770ab5e157ca68589939c30c4ffe5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 7 Aug 2023 23:29:45 +0800 Subject: [PATCH] add more pre-trained ckpts --- vision_toolbox/backbones/mlp_mixer.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py index d69689e..f109d7c 100644 --- a/vision_toolbox/backbones/mlp_mixer.py +++ b/vision_toolbox/backbones/mlp_mixer.py @@ -77,12 +77,22 @@ def forward(self, x: Tensor) -> Tensor: def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> MLPMixer: # Table 1 in https://arxiv.org/pdf/2105.01601.pdf n_layers, d_model, tokens_mlp_dim, channels_mlp_dim = dict( - S=(8, 512, 256, 2048), B=(12, 768, 384, 3072), L=(24, 1024, 512, 4096), H=(32, 1280, 640, 5120) + S=(8, 512, 256, 2048), + B=(12, 768, 384, 3072), + L=(24, 1024, 512, 4096), + H=(32, 1280, 640, 5120), )[variant] m = MLPMixer(n_layers, d_model, patch_size, img_size, tokens_mlp_dim, channels_mlp_dim) if pretrained: - ckpt = {("B", 16): "Mixer-B_16.npz", ("L", 16): "Mixer-L_16.npz"}[(variant, patch_size)] - base_url = "https://storage.googleapis.com/mixer_models/imagenet21k/" + ckpt = { + ("S", 8): "gsam/Mixer-S_8.npz", + ("S", 16): "gsam/Mixer-S_16.npz", + ("S", 32): "gsam/Mixer-S_32.npz", + ("B", 16): "imagenet21k/Mixer-B_16.npz", # also available: gsam, sam + ("B", 32): "gsam/Mixer-B_32.npz", # also availale: sam + ("L", 16): "imagenet21k/Mixer-L_16.npz", + }[(variant, patch_size)] + base_url = "https://storage.googleapis.com/mixer_models/" m.load_jax_weights(torch_hub_download(base_url + ckpt)) return m