simplify from jax weights
gau-nernst committed Jul 23, 2023
1 parent ff7e3af commit e0cfc91
Showing 3 changed files with 55 additions and 43 deletions.
9 changes: 9 additions & 0 deletions tests/test_vit.py
@@ -1,10 +1,19 @@
import torch

from vision_toolbox.backbones import ViT
from vision_toolbox.utils import torch_hub_download


def test_resize_pe():
    m = ViT.from_config("Ti", 16, 224)
    m(torch.randn(1, 3, 224, 224))
    m.resize_pe(256)
    m(torch.randn(1, 3, 256, 256))


def test_from_jax():
    url = (
        "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
    )
    m = ViT.from_jax_weights(torch_hub_download(url))
    m(torch.randn(1, 3, 224, 224))
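
The new test downloads an AugReg ViT-Ti/16 checkpoint and runs a forward pass through the converted model. The existing test_resize_pe exercises position-embedding resizing; the usual technique there is to bicubically interpolate the patch-position grid while leaving the class token untouched. A generic sketch of that idea (illustrative only, not necessarily how ViT.resize_pe is implemented):

import torch
import torch.nn.functional as F

def interpolate_pe(pe: torch.Tensor, old_grid: int, new_grid: int) -> torch.Tensor:
    # pe has shape (1, 1 + old_grid**2, d_model), class token first.
    # Split off the class-token embedding, reshape patch embeddings to a 2D grid,
    # resize with bicubic interpolation, then flatten back and re-attach the token.
    cls_pe, patch_pe = pe[:, :1], pe[:, 1:]
    d = pe.shape[-1]
    grid = patch_pe.reshape(1, old_grid, old_grid, d).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    patch_pe = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, d)
    return torch.cat([cls_pe, patch_pe], dim=1)

# For Ti/16 going from 224 to 256 pixels, the 14x14 grid becomes 16x16.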
79 changes: 36 additions & 43 deletions vision_toolbox/backbones/vit.py
@@ -81,10 +81,22 @@ def from_config(variant: str, patch_size: int, img_size: int) -> ViT:
        return ViT(**configs[variant], patch_size=patch_size, img_size=img_size)

    # weights from https://github.com/google-research/vision_transformer
    @torch.no_grad()
    @staticmethod
    def from_jax_weights(path: str) -> ViT:
        jax_weights: Mapping[str, np.ndarray] = np.load(path)

        def get_w(key: str) -> Tensor:
            return torch.from_numpy(jax_weights[key])

        def copy_layernorm(module: nn.LayerNorm, prefix: str) -> None:
            module.weight.copy_(get_w(prefix + "scale"))
            module.bias.copy_(get_w(prefix + "bias"))

        def copy_linear(module: nn.Linear, prefix: str) -> None:
            module.weight.copy_(get_w(prefix + "kernel").T)
            module.bias.copy_(get_w(prefix + "bias"))

        n_layers = 1
        while True:
            if f"Transformer/encoderblock_{n_layers}/LayerNorm_0/bias" not in jax_weights:
@@ -96,46 +108,27 @@ def from_jax_weights(path: str) -> ViT:
        patch_size = jax_weights["embedding/kernel"].shape[0]
        img_size = int((jax_weights["Transformer/posembed_input/pos_embedding"].shape[1] - 1) ** 0.5) * patch_size

        def _get(key: str) -> Tensor:
            return torch.from_numpy(jax_weights[key])

        torch_weights = dict()

        def _convert_layer_norm(jax_prefix: str, torch_prefix: str) -> None:
            torch_weights[f"{torch_prefix}.weight"] = _get(f"{jax_prefix}/scale")
            torch_weights[f"{torch_prefix}.bias"] = _get(f"{jax_prefix}/bias")

        def _convert_linear(jax_prefix: str, torch_prefix: str) -> None:
            torch_weights[f"{torch_prefix}.weight"] = _get(f"{jax_prefix}/kernel").T
            torch_weights[f"{torch_prefix}.bias"] = _get(f"{jax_prefix}/bias")

        def _convert_mha(jax_prefix: str, torch_prefix: str) -> None:
            w = torch.stack([_get(f"{jax_prefix}/{x}/kernel") for x in ["query", "key", "value"]], 1)
            torch_weights[f"{torch_prefix}.in_proj_weight"] = w.flatten(1).T

            b = torch.stack([_get(f"{jax_prefix}/{x}/bias") for x in ["query", "key", "value"]], 0)
            torch_weights[f"{torch_prefix}.in_proj_bias"] = b.flatten()

            torch_weights[f"{torch_prefix}.out_proj.weight"] = _get(f"{jax_prefix}/out/kernel").flatten(0, 1).T
            torch_weights[f"{torch_prefix}.out_proj.bias"] = _get(f"{jax_prefix}/out/bias")

        torch_weights["cls_token"] = _get("cls")
        torch_weights["patch_embed.weight"] = _get("embedding/kernel").permute(3, 2, 0, 1)
        torch_weights["patch_embed.bias"] = _get("embedding/bias")
        torch_weights["pe"] = _get("Transformer/posembed_input/pos_embedding")

        for idx in range(n_layers):
            jax_prefix = f"Transformer/encoderblock_{idx}"
            torch_prefix = f"encoder.layers.{idx}"

            _convert_layer_norm(f"{jax_prefix}/LayerNorm_0", f"{torch_prefix}.norm1")
            _convert_mha(f"{jax_prefix}/MultiHeadDotProductAttention_1", f"{torch_prefix}.self_attn")
            _convert_layer_norm(f"{jax_prefix}/LayerNorm_2", f"{torch_prefix}.norm2")
            _convert_linear(f"{jax_prefix}/MlpBlock_3/Dense_0", f"{torch_prefix}.linear1")
            _convert_linear(f"{jax_prefix}/MlpBlock_3/Dense_1", f"{torch_prefix}.linear2")

        _convert_layer_norm("Transformer/encoder_norm", "encoder.norm")

        model = ViT(n_layers, d_model, n_heads, patch_size, img_size)
        model.load_state_dict(torch_weights)
        return model
        m = ViT(n_layers, d_model, n_heads, patch_size, img_size)

        m.cls_token.copy_(get_w("cls"))
        m.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
        m.patch_embed.bias.copy_(get_w("embedding/bias"))
        m.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))
        copy_layernorm(m.encoder.norm, "Transformer/encoder_norm/")

        for idx, layer in enumerate(m.encoder.layers):
            prefix = f"Transformer/encoderblock_{idx}/"
            copy_layernorm(layer.norm1, prefix + "LayerNorm_0/")
            copy_layernorm(layer.norm2, prefix + "LayerNorm_2/")
            copy_linear(layer.linear1, prefix + "MlpBlock_3/Dense_0/")
            copy_linear(layer.linear2, prefix + "MlpBlock_3/Dense_1/")

            mha_prefix = prefix + "MultiHeadDotProductAttention_1/"
            w = torch.stack([get_w(mha_prefix + x + "/kernel") for x in ["query", "key", "value"]], 1)
            b = torch.stack([get_w(mha_prefix + x + "/bias") for x in ["query", "key", "value"]], 0)
            layer.self_attn.in_proj_weight.copy_(w.flatten(1).T)
            layer.self_attn.in_proj_bias.copy_(b.flatten())
            layer.self_attn.out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
            layer.self_attn.out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))

        return m
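
To see why the attention repacking works: Flax's MultiHeadDotProductAttention stores each of the query/key/value kernels with shape (d_model, n_heads, head_dim) and the output kernel with shape (n_heads, head_dim, d_model), while PyTorch's nn.MultiheadAttention expects a fused in_proj_weight of shape (3 * d_model, d_model) with q, k, v stacked along the first dimension. A standalone shape check, using ViT-Ti dimensions for illustration:

import torch

d_model, n_heads = 192, 3            # ViT-Ti; any d_model divisible by n_heads works
head_dim = d_model // n_heads

# Flax stores each of query/key/value as (d_model, n_heads, head_dim).
q = k = v = torch.randn(d_model, n_heads, head_dim)
w = torch.stack([q, k, v], 1)        # (d_model, 3, n_heads, head_dim)
in_proj_weight = w.flatten(1).T      # (3 * d_model, d_model)
assert in_proj_weight.shape == (3 * d_model, d_model)

# The output projection kernel is (n_heads, head_dim, d_model).
out = torch.randn(n_heads, head_dim, d_model)
out_proj_weight = out.flatten(0, 1).T  # (d_model, d_model)
assert out_proj_weight.shape == (d_model, d_model)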
10 changes: 10 additions & 0 deletions vision_toolbox/utils.py
@@ -0,0 +1,10 @@
import os

import torch


def torch_hub_download(url: str) -> str:
    save_path = os.path.join(torch.hub.get_dir(), os.path.basename(url))
    if not os.path.exists(save_path):
        torch.hub.download_url_to_file(url, save_path)
    return save_path
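
The helper caches downloads under torch.hub.get_dir(), keyed by the URL's basename, so repeated test runs skip the network; note that two different URLs ending in the same filename would collide in the cache. Usage, reusing the checkpoint URL from the test above:

from vision_toolbox.utils import torch_hub_download

url = (
    "https://storage.googleapis.com/vit_models/augreg/"
    "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
)
path = torch_hub_download(url)  # downloads on first call, hits the cache afterwards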
