refactor jax weight loading
gau-nernst committed Oct 29, 2023
1 parent d38b613 commit b716d9c
Showing 2 changed files with 72 additions and 73 deletions.
40 changes: 12 additions & 28 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -10,7 +10,7 @@
 from torch import Tensor, nn
 
 from ..utils import torch_hub_download
-from .vit import MLP
+from .vit import MLP, load_jax_conv2d, load_jax_linear, load_jax_ln
 
 
 class MixerBlock(nn.Module):
@@ -84,33 +84,17 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
         return m
 
     @torch.no_grad()
-    def load_jax_weights(self, path: str) -> MLPMixer:
-        jax_weights: Mapping[str, np.ndarray] = np.load(path)
+    def load_jax_weights(self, path: str) -> None:
+        jax_weights = {k: torch.from_numpy(v) for k, v in np.load(path).items()}
 
-        def get_w(key: str) -> Tensor:
-            return torch.from_numpy(jax_weights[key])
-
-        self.patch_embed.weight.copy_(get_w("stem/kernel").permute(3, 2, 0, 1))
-        self.patch_embed.bias.copy_(get_w("stem/bias"))
+        load_jax_conv2d(self.patch_embed, jax_weights, "stem")
+        load_jax_ln(self.norm, jax_weights, "pre_head_layer_norm")
 
         for i, layer in enumerate(self.layers):
-            layer: MixerBlock
-            prefix = f"MixerBlock_{i}/"
-
-            layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
-            layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
-            layer.token_mixing.linear1.weight.copy_(get_w(prefix + "token_mixing/Dense_0/kernel").T)
-            layer.token_mixing.linear1.bias.copy_(get_w(prefix + "token_mixing/Dense_0/bias"))
-            layer.token_mixing.linear2.weight.copy_(get_w(prefix + "token_mixing/Dense_1/kernel").T)
-            layer.token_mixing.linear2.bias.copy_(get_w(prefix + "token_mixing/Dense_1/bias"))
-
-            layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_1/scale"))
-            layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_1/bias"))
-            layer.channel_mixing.linear1.weight.copy_(get_w(prefix + "channel_mixing/Dense_0/kernel").T)
-            layer.channel_mixing.linear1.bias.copy_(get_w(prefix + "channel_mixing/Dense_0/bias"))
-            layer.channel_mixing.linear2.weight.copy_(get_w(prefix + "channel_mixing/Dense_1/kernel").T)
-            layer.channel_mixing.linear2.bias.copy_(get_w(prefix + "channel_mixing/Dense_1/bias"))
-
-        self.norm.weight.copy_(get_w("pre_head_layer_norm/scale"))
-        self.norm.bias.copy_(get_w("pre_head_layer_norm/bias"))
-        return self
+            load_jax_ln(layer.norm1, jax_weights, f"MixerBlock_{i}/LayerNorm_0")
+            load_jax_linear(layer.token_mixing.linear1, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_0")
+            load_jax_linear(layer.token_mixing.linear2, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_1")
+
+            load_jax_ln(layer.norm2, jax_weights, f"MixerBlock_{i}/LayerNorm_1")
+            load_jax_linear(layer.channel_mixing.linear1, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_0")
+            load_jax_linear(layer.channel_mixing.linear2, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_1")
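For reference, a minimal usage sketch of the refactored loader. The variant name, patch size, and checkpoint filename below are placeholders, not taken from this commit:

from vision_toolbox.backbones.mlp_mixer import MLPMixer

# Hypothetical usage: build a Mixer and point load_jax_weights at a local Flax .npz checkpoint.
# "B", patch_size=16, and the filename are illustrative values only.
model = MLPMixer.from_config("B", patch_size=16, img_size=224)
model.load_jax_weights("Mixer-B_16.npz")  # the method is already decorated with @torch.no_grad()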
105 changes: 60 additions & 45 deletions vision_toolbox/backbones/vit.py
@@ -5,7 +5,6 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Mapping
 
 import numpy as np
 import torch
@@ -153,14 +152,17 @@ def from_config(variant: str, img_size: int, *, weights: str | None = None) -> V
         if weights == "augreg":
             assert img_size == 224
             ckpt = {
-                ("Ti", 16): "augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
-                ("S", 32): "augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
-                ("S", 16): "augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
-                ("B", 32): "augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
-                ("B", 16): "augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
-                ("L", 16): "augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
+                ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
+                ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
+                ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
             }[(variant, patch_size)]
-            m.load_vision_transformer_jax_weights(ckpt)
+            m.load_jax_ckpt(f"augreg/{ckpt}")
+
+        elif weights == "siglip":
+            raise NotImplementedError
 
         elif not weights is None:
             raise ValueError(f"Unsupported weights={weights}")
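A hedged sketch of what the augreg branch above amounts to at call time. The "B/16" variant spelling and the way the model is constructed are assumptions; only the checkpoint filename comes from the table above:

from vision_toolbox.backbones.vit import ViT

# Roughly what from_config(..., weights="augreg") does internally:
# build the model, then pull the matching AugReg checkpoint by name.
model = ViT.from_config("B/16", img_size=224)  # variant string format is an assumption
model.load_jax_ckpt("augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz")

# big_vision checkpoints use a different Flax module layout, hence the extra arguments;
# the filename and prefix below are purely illustrative.
# model.load_jax_ckpt("vit_b16_i21k.npz", big_vision=True, prefix="params/")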
@@ -169,44 +171,57 @@ def from_config(variant: str, img_size: int, *, weights: str | None = None) -> V
 
     # https://github.com/google-research/vision_transformer
     @torch.no_grad()
-    def load_vision_transformer_jax_weights(self, ckpt: str) -> ViT:
-        base_url = "https://storage.googleapis.com/vit_models/"
-        path = torch_hub_download(base_url + ckpt)
-        jax_weights: Mapping[str, np.ndarray] = np.load(path)
+    def load_jax_ckpt(self, ckpt: str, big_vision: bool = False, prefix: str = "") -> None:
+        if big_vision:
+            gcs_bucket = "big_vision"
+            mha_norm = "LayerNorm_0"
+            mha = "MultiHeadDotProductAttention_0"
+            mlp_norm = "LayerNorm_1"
+            mlp = "MlpBlock_0"
+
+        else:
+            gcs_bucket = "vit_models"
+            mha_norm = "LayerNorm_0"
+            mha = "MultiHeadDotProductAttention_1"
+            mlp_norm = "LayerNorm_2"
+            mlp = "MlpBlock_3"
 
-        def get_w(key: str) -> Tensor:
-            return torch.from_numpy(jax_weights[key])
+        path = torch_hub_download(f"https://storage.googleapis.com/{gcs_bucket}/{ckpt}")
+        jax_weights = {k.lstrip(prefix): torch.from_numpy(v) for k, v in np.load(path).items() if k.startswith(prefix)}
 
-        self.cls_token.copy_(get_w("cls"))
-        pe = get_w("Transformer/posembed_input/pos_embedding")
+        self.cls_token.copy_(jax_weights["cls"])
+        pe = jax_weights["Transformer/posembed_input/pos_embedding"]
         self.cls_token.add_(pe[:, 0])
         self.pe.copy_(pe[:, 1:])
-        self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
-        self.patch_embed.bias.copy_(get_w("embedding/bias"))
-
-        for idx, layer in enumerate(self.layers):
-            layer: ViTBlock
-            prefix = f"Transformer/encoderblock_{idx}/"
-            mha_prefix = prefix + "MultiHeadDotProductAttention_1/"
-
-            layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
-            layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
-            layer.mha[1].q_proj.weight.copy_(get_w(mha_prefix + "query/kernel").flatten(1).T)
-            layer.mha[1].k_proj.weight.copy_(get_w(mha_prefix + "key/kernel").flatten(1).T)
-            layer.mha[1].v_proj.weight.copy_(get_w(mha_prefix + "value/kernel").flatten(1).T)
-            layer.mha[1].q_proj.bias.copy_(get_w(mha_prefix + "query/bias").flatten())
-            layer.mha[1].k_proj.bias.copy_(get_w(mha_prefix + "key/bias").flatten())
-            layer.mha[1].v_proj.bias.copy_(get_w(mha_prefix + "value/bias").flatten())
-            layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
-            layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))
-
-            layer.mlp[0].weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
-            layer.mlp[0].bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
-            layer.mlp[1].linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
-            layer.mlp[1].linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
-            layer.mlp[1].linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
-            layer.mlp[1].linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
-
-        self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
-        self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
-        return self
+        load_jax_conv2d(self.patch_embed, jax_weights, "embedding")
+        load_jax_ln(self.norm, jax_weights, "Transformer/encoder_norm")
+
+        for i, layer in enumerate(self.layers):
+            jax_prefix = f"Transformer/encoderblock_{i}"
+
+            load_jax_ln(layer.mha[0], jax_weights, f"{jax_prefix}/{mha_norm}")
+            for x in ("query", "key", "value"):
+                proj = getattr(layer.mha[1], f"{x[0]}_proj")
+                proj.weight.copy_(jax_weights[f"{jax_prefix}/{mha}/{x}/kernel"].flatten(1).T)
+                proj.bias.copy_(jax_weights[f"{jax_prefix}/{mha}/{x}/bias"].flatten())
+            layer.mha[1].out_proj.weight.copy_(jax_weights[f"{jax_prefix}/{mha}/out/kernel"].flatten(0, 1).T)
+            layer.mha[1].out_proj.bias.copy_(jax_weights[f"{jax_prefix}/{mha}/out/bias"].flatten())
+
+            load_jax_ln(layer.mlp[0], jax_weights, f"{jax_prefix}/{mlp_norm}")
+            load_jax_linear(layer.mlp[1].linear1, jax_weights, f"{jax_prefix}/{mlp}/Dense_0")
+            load_jax_linear(layer.mlp[1].linear2, jax_weights, f"{jax_prefix}/{mlp}/Dense_1")
+
+
+def load_jax_ln(norm: nn.LayerNorm, weights: dict[str, Tensor], prefix: str) -> None:
+    norm.weight.copy_(weights[f"{prefix}/scale"])
+    norm.bias.copy_(weights[f"{prefix}/bias"])
+
+
+def load_jax_linear(linear: nn.Linear, weights: dict[str, Tensor], prefix: str) -> None:
+    linear.weight.copy_(weights[f"{prefix}/kernel"].T)
+    linear.bias.copy_(weights[f"{prefix}/bias"])
+
+
+def load_jax_conv2d(conv2d: nn.Conv2d, weights: dict[str, Tensor], prefix: str) -> None:
+    conv2d.weight.copy_(weights[f"{prefix}/kernel"].permute(3, 2, 0, 1))
+    conv2d.bias.copy_(weights[f"{prefix}/bias"])
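The three helpers encode the usual Flax-to-PyTorch shape conventions: Dense kernels are stored as (in, out) and conv kernels as (H, W, in, out), so they are transposed or permuted before being copied. A small self-contained sketch with synthetic tensors (not a real checkpoint) to illustrate:

import torch
from torch import nn

from vision_toolbox.backbones.vit import load_jax_linear, load_jax_ln

# Fake "checkpoint" entries with Flax-style shapes.
weights = {
    "head/kernel": torch.randn(768, 10),   # Dense kernel: (in_features, out_features)
    "head/bias": torch.randn(10),
    "norm/scale": torch.ones(768),         # LayerNorm gamma is called "scale" in Flax
    "norm/bias": torch.zeros(768),
}

linear = nn.Linear(768, 10)
norm = nn.LayerNorm(768)
with torch.no_grad():                      # copy_ on parameters, as in the loaders above
    load_jax_linear(linear, weights, "head")  # copies kernel.T into weight of shape (10, 768)
    load_jax_ln(norm, weights, "norm")
assert linear.weight.shape == (10, 768)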
