Skip to content

Commit

Permalink
add more pre-trained ckpts
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 7, 2023
1 parent 7c4d93f commit d029bde
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions vision_toolbox/backbones/mlp_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,22 @@ def forward(self, x: Tensor) -> Tensor:
def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> MLPMixer:
# Table 1 in https://arxiv.org/pdf/2105.01601.pdf
n_layers, d_model, tokens_mlp_dim, channels_mlp_dim = dict(
S=(8, 512, 256, 2048), B=(12, 768, 384, 3072), L=(24, 1024, 512, 4096), H=(32, 1280, 640, 5120)
S=(8, 512, 256, 2048),
B=(12, 768, 384, 3072),
L=(24, 1024, 512, 4096),
H=(32, 1280, 640, 5120),
)[variant]
m = MLPMixer(n_layers, d_model, patch_size, img_size, tokens_mlp_dim, channels_mlp_dim)
if pretrained:
ckpt = {("B", 16): "Mixer-B_16.npz", ("L", 16): "Mixer-L_16.npz"}[(variant, patch_size)]
base_url = "https://storage.googleapis.com/mixer_models/imagenet21k/"
ckpt = {
("S", 8): "gsam/Mixer-S_8.npz",
("S", 16): "gsam/Mixer-S_16.npz",
("S", 32): "gsam/Mixer-S_32.npz",
("B", 16): "imagenet21k/Mixer-B_16.npz", # also available: gsam, sam
("B", 32): "gsam/Mixer-B_32.npz", # also availale: sam
("L", 16): "imagenet21k/Mixer-L_16.npz",
}[(variant, patch_size)]
base_url = "https://storage.googleapis.com/mixer_models/"
m.load_jax_weights(torch_hub_download(base_url + ckpt))
return m

Expand Down

0 comments on commit d029bde

Please sign in to comment.