Commit 80b87d5

refactor

gau-nernst committed Oct 29, 2023
1 parent b716d9c

Showing 1 changed file: vision_toolbox/backbones/vit.py (12 additions, 8 deletions)
@@ -198,15 +198,8 @@ def load_jax_ckpt(self, ckpt: str, big_vision: bool = False, prefix: str = "") -
 
         for i, layer in enumerate(self.layers):
             jax_prefix = f"Transformer/encoderblock_{i}"
-
             load_jax_ln(layer.mha[0], jax_weights, f"{jax_prefix}/{mha_norm}")
-            for x in ("query", "key", "value"):
-                proj = getattr(layer.mha[1], f"{x[0]}_proj")
-                proj.weight.copy_(jax_weights[f"{jax_prefix}/{mha}/{x}/kernel"].flatten(1).T)
-                proj.bias.copy_(jax_weights[f"{jax_prefix}/{mha}/{x}/bias"].flatten())
-            layer.mha[1].out_proj.weight.copy_(jax_weights[f"{jax_prefix}/{mha}/out/kernel"].flatten(0, 1).T)
-            layer.mha[1].out_proj.bias.copy_(jax_weights[f"{jax_prefix}/{mha}/out/bias"].flatten())
-
+            load_jax_mha(layer.mha[1], jax_weights, f"{jax_prefix}/{mha}")
             load_jax_ln(layer.mlp[0], jax_weights, f"{jax_prefix}/{mlp_norm}")
             load_jax_linear(layer.mlp[1].linear1, jax_weights, f"{jax_prefix}/{mlp}/Dense_0")
             load_jax_linear(layer.mlp[1].linear2, jax_weights, f"{jax_prefix}/{mlp}/Dense_1")
@@ -225,3 +218,14 @@ def load_jax_linear(linear: nn.Linear, weights: dict[str, Tensor], prefix: str)
 def load_jax_conv2d(conv2d: nn.Conv2d, weights: dict[str, Tensor], prefix: str) -> None:
     conv2d.weight.copy_(weights[f"{prefix}/kernel"].permute(3, 2, 0, 1))
     conv2d.bias.copy_(weights[f"{prefix}/bias"])
+
+
+def load_jax_mha(mha: MHA, weights: dict[str, Tensor], prefix: str) -> None:
+    mha.q_proj.weight.copy_(weights[f"{prefix}/query/kernel"].flatten(1).T)
+    mha.q_proj.bias.copy_(weights[f"{prefix}/query/bias"].flatten())
+    mha.k_proj.weight.copy_(weights[f"{prefix}/key/kernel"].flatten(1).T)
+    mha.k_proj.bias.copy_(weights[f"{prefix}/key/bias"].flatten())
+    mha.v_proj.weight.copy_(weights[f"{prefix}/value/kernel"].flatten(1).T)
+    mha.v_proj.bias.copy_(weights[f"{prefix}/value/bias"].flatten())
+    mha.out_proj.weight.copy_(weights[f"{prefix}/out/kernel"].flatten(0, 1).T)
+    mha.out_proj.bias.copy_(weights[f"{prefix}/out/bias"].flatten())
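Context for the transposes (a standalone sketch, not part of the commit): Flax's attention layers store query/key/value kernels as (embed_dim, num_heads, head_dim) and the output kernel as (num_heads, head_dim, embed_dim), while torch.nn.Linear expects (out_features, in_features). The ViT-Base sizes below (768 dim, 12 heads, 64 per head) are assumed for illustration.

import torch

embed_dim, num_heads, head_dim = 768, 12, 64

# Dummy JAX/Flax-style attention kernels, already converted to torch tensors:
#   query/key/value kernel: (embed_dim, num_heads, head_dim)
#   out kernel:             (num_heads, head_dim, embed_dim)
q_kernel = torch.randn(embed_dim, num_heads, head_dim)
out_kernel = torch.randn(num_heads, head_dim, embed_dim)

# Collapse the per-head axes and transpose into nn.Linear's (out, in) layout.
q_weight = q_kernel.flatten(1).T         # (num_heads * head_dim, embed_dim)
out_weight = out_kernel.flatten(0, 1).T  # (embed_dim, num_heads * head_dim)

assert q_weight.shape == (num_heads * head_dim, embed_dim)
assert out_weight.shape == (embed_dim, num_heads * head_dim)

Since q_proj, k_proj, and v_proj are separate nn.Linear modules, the per-projection loop in load_jax_ckpt collapses into a single load_jax_mha call, mirroring the existing load_jax_ln, load_jax_linear, and load_jax_conv2d helpers.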
