From b716d9cfcf5961f193696d5ef0b46426b1265711 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sun, 29 Oct 2023 10:17:19 +0800
Subject: [PATCH] refactor jax weight loading

---
 vision_toolbox/backbones/mlp_mixer.py |  40 +++-------
 vision_toolbox/backbones/vit.py       | 105 +++++++++++++++-----------
 2 files changed, 72 insertions(+), 73 deletions(-)

diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py
index 85d0081..3308a44 100644
--- a/vision_toolbox/backbones/mlp_mixer.py
+++ b/vision_toolbox/backbones/mlp_mixer.py
@@ -10,7 +10,7 @@ from torch import Tensor, nn
 
 from ..utils import torch_hub_download
-from .vit import MLP
+from .vit import MLP, load_jax_conv2d, load_jax_linear, load_jax_ln
 
 
 class MixerBlock(nn.Module):
@@ -84,33 +84,17 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
         return m
 
     @torch.no_grad()
-    def load_jax_weights(self, path: str) -> MLPMixer:
-        jax_weights: Mapping[str, np.ndarray] = np.load(path)
+    def load_jax_weights(self, path: str) -> None:
+        jax_weights = {k: torch.from_numpy(v) for k, v in np.load(path).items()}
 
-        def get_w(key: str) -> Tensor:
-            return torch.from_numpy(jax_weights[key])
-
-        self.patch_embed.weight.copy_(get_w("stem/kernel").permute(3, 2, 0, 1))
-        self.patch_embed.bias.copy_(get_w("stem/bias"))
+        load_jax_conv2d(self.patch_embed, jax_weights, "stem")
+        load_jax_ln(self.norm, jax_weights, "pre_head_layer_norm")
 
         for i, layer in enumerate(self.layers):
-            layer: MixerBlock
-            prefix = f"MixerBlock_{i}/"
-
-            layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
-            layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
-            layer.token_mixing.linear1.weight.copy_(get_w(prefix + "token_mixing/Dense_0/kernel").T)
-            layer.token_mixing.linear1.bias.copy_(get_w(prefix + "token_mixing/Dense_0/bias"))
-            layer.token_mixing.linear2.weight.copy_(get_w(prefix + "token_mixing/Dense_1/kernel").T)
-            layer.token_mixing.linear2.bias.copy_(get_w(prefix + "token_mixing/Dense_1/bias"))
-
-            layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_1/scale"))
-            layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_1/bias"))
-            layer.channel_mixing.linear1.weight.copy_(get_w(prefix + "channel_mixing/Dense_0/kernel").T)
-            layer.channel_mixing.linear1.bias.copy_(get_w(prefix + "channel_mixing/Dense_0/bias"))
-            layer.channel_mixing.linear2.weight.copy_(get_w(prefix + "channel_mixing/Dense_1/kernel").T)
-            layer.channel_mixing.linear2.bias.copy_(get_w(prefix + "channel_mixing/Dense_1/bias"))
-
-        self.norm.weight.copy_(get_w("pre_head_layer_norm/scale"))
-        self.norm.bias.copy_(get_w("pre_head_layer_norm/bias"))
-        return self
+            load_jax_ln(layer.norm1, jax_weights, f"MixerBlock_{i}/LayerNorm_0")
+            load_jax_linear(layer.token_mixing.linear1, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_0")
+            load_jax_linear(layer.token_mixing.linear2, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_1")
+
+            load_jax_ln(layer.norm2, jax_weights, f"MixerBlock_{i}/LayerNorm_1")
+            load_jax_linear(layer.channel_mixing.linear1, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_0")
+            load_jax_linear(layer.channel_mixing.linear2, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_1")

diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index 680c819..69adb75 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -5,7 +5,6 @@ from __future__ import annotations
 
 from functools import partial
-from typing import Mapping
 
 import numpy as np
 import torch
@@ -153,14 +152,17 @@ def from_config(variant: str, img_size: int, *, weights: str | None = None) -> V
         if weights == "augreg":
             assert img_size == 224
             ckpt = {
-                ("Ti", 16): "augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
-                ("S", 32): "augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
-                ("S", 16): "augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
-                ("B", 32): "augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
-                ("B", 16): "augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
-                ("L", 16): "augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
+                ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
+                ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
+                ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
+                ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
             }[(variant, patch_size)]
-            m.load_vision_transformer_jax_weights(ckpt)
+            m.load_jax_ckpt(f"augreg/{ckpt}")
+
+        elif weights == "siglip":
+            raise NotImplementedError
 
         elif not weights is None:
             raise ValueError(f"Unsupported weights={weights}")
@@ -169,44 +171,57 @@ def from_config(variant: str, img_size: int, *, weights: str | None = None) -> V
 
     # https://github.com/google-research/vision_transformer
     @torch.no_grad()
-    def load_vision_transformer_jax_weights(self, ckpt: str) -> ViT:
-        base_url = "https://storage.googleapis.com/vit_models/"
-        path = torch_hub_download(base_url + ckpt)
-        jax_weights: Mapping[str, np.ndarray] = np.load(path)
+    def load_jax_ckpt(self, ckpt: str, big_vision: bool = False, prefix: str = "") -> None:
+        if big_vision:
+            gcs_bucket = "big_vision"
+            mha_norm = "LayerNorm_0"
+            mha = "MultiHeadDotProductAttention_0"
+            mlp_norm = "LayerNorm_1"
+            mlp = "MlpBlock_0"
+
+        else:
+            gcs_bucket = "vit_models"
+            mha_norm = "LayerNorm_0"
+            mha = "MultiHeadDotProductAttention_1"
+            mlp_norm = "LayerNorm_2"
+            mlp = "MlpBlock_3"
 
-        def get_w(key: str) -> Tensor:
-            return torch.from_numpy(jax_weights[key])
+        path = torch_hub_download(f"https://storage.googleapis.com/{gcs_bucket}/{ckpt}")
+        jax_weights = {k.lstrip(prefix): torch.from_numpy(v) for k, v in np.load(path).items() if k.startswith(prefix)}
 
-        self.cls_token.copy_(get_w("cls"))
-        pe = get_w("Transformer/posembed_input/pos_embedding")
+        self.cls_token.copy_(jax_weights["cls"])
+        pe = jax_weights["Transformer/posembed_input/pos_embedding"]
         self.cls_token.add_(pe[:, 0])
         self.pe.copy_(pe[:, 1:])
-        self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
-        self.patch_embed.bias.copy_(get_w("embedding/bias"))
-
-        for idx, layer in enumerate(self.layers):
-            layer: ViTBlock
-            prefix = f"Transformer/encoderblock_{idx}/"
-            mha_prefix = prefix + "MultiHeadDotProductAttention_1/"
-
-            layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
-            layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
-            layer.mha[1].q_proj.weight.copy_(get_w(mha_prefix + "query/kernel").flatten(1).T)
-            layer.mha[1].k_proj.weight.copy_(get_w(mha_prefix + "key/kernel").flatten(1).T)
-            layer.mha[1].v_proj.weight.copy_(get_w(mha_prefix + "value/kernel").flatten(1).T)
-            layer.mha[1].q_proj.bias.copy_(get_w(mha_prefix + "query/bias").flatten())
-            layer.mha[1].k_proj.bias.copy_(get_w(mha_prefix + "key/bias").flatten())
-            layer.mha[1].v_proj.bias.copy_(get_w(mha_prefix + "value/bias").flatten())
-            layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
-            layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))
-
-            layer.mlp[0].weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
-            layer.mlp[0].bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
-            layer.mlp[1].linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
-            layer.mlp[1].linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
-            layer.mlp[1].linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
-            layer.mlp[1].linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
-
-        self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
-        self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
-        return self
+        load_jax_conv2d(self.patch_embed, jax_weights, "embedding")
+        load_jax_ln(self.norm, jax_weights, "Transformer/encoder_norm")
+
+        for i, layer in enumerate(self.layers):
+            jax_prefix = f"Transformer/encoderblock_{i}"
+
+            load_jax_ln(layer.mha[0], jax_weights, f"{jax_prefix}/{mha_norm}")
+            for x in ("query", "key", "value"):
+                proj = getattr(layer.mha[1], f"{x[0]}_proj")
+                proj.weight.copy_(jax_weights[f"{jax_prefix}/{mha}/{x}/kernel"].flatten(1).T)
+                proj.bias.copy_(jax_weights[f"{jax_prefix}/{mha}/{x}/bias"].flatten())
+            layer.mha[1].out_proj.weight.copy_(jax_weights[f"{jax_prefix}/{mha}/out/kernel"].flatten(0, 1).T)
+            layer.mha[1].out_proj.bias.copy_(jax_weights[f"{jax_prefix}/{mha}/out/bias"].flatten())
+
+            load_jax_ln(layer.mlp[0], jax_weights, f"{jax_prefix}/{mlp_norm}")
+            load_jax_linear(layer.mlp[1].linear1, jax_weights, f"{jax_prefix}/{mlp}/Dense_0")
+            load_jax_linear(layer.mlp[1].linear2, jax_weights, f"{jax_prefix}/{mlp}/Dense_1")
+
+
+def load_jax_ln(norm: nn.LayerNorm, weights: dict[str, Tensor], prefix: str) -> None:
+    norm.weight.copy_(weights[f"{prefix}/scale"])
+    norm.bias.copy_(weights[f"{prefix}/bias"])
+
+
+def load_jax_linear(linear: nn.Linear, weights: dict[str, Tensor], prefix: str) -> None:
+    linear.weight.copy_(weights[f"{prefix}/kernel"].T)
+    linear.bias.copy_(weights[f"{prefix}/bias"])
+
+
+def load_jax_conv2d(conv2d: nn.Conv2d, weights: dict[str, Tensor], prefix: str) -> None:
+    conv2d.weight.copy_(weights[f"{prefix}/kernel"].permute(3, 2, 0, 1))
+    conv2d.bias.copy_(weights[f"{prefix}/bias"])
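
The patch centralizes the Flax-to-PyTorch layout conversions in three helpers. As a standalone sketch (not part of the patch), the script below builds a synthetic .npz checkpoint with Flax-style arrays and loads it into nn.Linear and nn.Conv2d using the same conversions as load_jax_linear and load_jax_conv2d: Dense kernels are stored as (in, out) and need a transpose, Conv kernels as (H, W, in, out) and need permute(3, 2, 0, 1). The key names "head" and "embedding" and all shapes are invented for the demo. The sketch strips the optional key prefix with str.removeprefix (Python 3.9+) rather than str.lstrip, since lstrip interprets its argument as a character set; with the default prefix of "" the two behave identically.

# Standalone sketch, not part of the patch: synthetic Flax-style checkpoint
# loaded with the same conversions that load_jax_linear / load_jax_conv2d use.
import io

import numpy as np
import torch
from torch import nn

# Flax stores Dense kernels as (in, out) and Conv kernels as (H, W, in, out);
# the names and shapes below are invented for the example.
fake_ckpt = {
    "head/kernel": np.random.randn(192, 10).astype(np.float32),
    "head/bias": np.random.randn(10).astype(np.float32),
    "embedding/kernel": np.random.randn(16, 16, 3, 192).astype(np.float32),
    "embedding/bias": np.random.randn(192).astype(np.float32),
}
buf = io.BytesIO()
np.savez(buf, **fake_ckpt)
buf.seek(0)

# Same dict comprehension as load_jax_ckpt, with str.removeprefix instead of
# str.lstrip (lstrip treats its argument as a character set, not a prefix).
prefix = ""
jax_weights = {
    k.removeprefix(prefix): torch.from_numpy(v)
    for k, v in np.load(buf).items()
    if k.startswith(prefix)
}

linear = nn.Linear(192, 10)
conv2d = nn.Conv2d(3, 192, kernel_size=16, stride=16)

with torch.no_grad():
    # Dense: (in, out) -> (out, in), as in load_jax_linear
    linear.weight.copy_(jax_weights["head/kernel"].T)
    linear.bias.copy_(jax_weights["head/bias"])
    # Conv: (H, W, in, out) -> (out, in, H, W), as in load_jax_conv2d
    conv2d.weight.copy_(jax_weights["embedding/kernel"].permute(3, 2, 0, 1))
    conv2d.bias.copy_(jax_weights["embedding/bias"])

print(conv2d(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 192, 14, 14])

The attention projections in load_jax_ckpt follow the same rule, with the trailing (num_heads, head_dim) axes flattened before the transpose: flatten(1).T for kernels and flatten() for biases.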