diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index 69adb75..a84f21e 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -198,15 +198,8 @@ def load_jax_ckpt(self, ckpt: str, big_vision: bool = False, prefix: str = "") -
 
         for i, layer in enumerate(self.layers):
             jax_prefix = f"Transformer/encoderblock_{i}"
-
             load_jax_ln(layer.mha[0], jax_weights, f"{jax_prefix}/{mha_norm}")
-            for x in ("query", "key", "value"):
-                proj = getattr(layer.mha[1], f"{x[0]}_proj")
-                proj.weight.copy_(jax_weights[f"{jax_prefix}/{mha}/{x}/kernel"].flatten(1).T)
-                proj.bias.copy_(jax_weights[f"{jax_prefix}/{mha}/{x}/bias"].flatten())
-            layer.mha[1].out_proj.weight.copy_(jax_weights[f"{jax_prefix}/{mha}/out/kernel"].flatten(0, 1).T)
-            layer.mha[1].out_proj.bias.copy_(jax_weights[f"{jax_prefix}/{mha}/out/bias"].flatten())
-
+            load_jax_mha(layer.mha[1], jax_weights, f"{jax_prefix}/{mha}")
             load_jax_ln(layer.mlp[0], jax_weights, f"{jax_prefix}/{mlp_norm}")
             load_jax_linear(layer.mlp[1].linear1, jax_weights, f"{jax_prefix}/{mlp}/Dense_0")
             load_jax_linear(layer.mlp[1].linear2, jax_weights, f"{jax_prefix}/{mlp}/Dense_1")
@@ -225,3 +218,14 @@ def load_jax_linear(linear: nn.Linear, weights: dict[str, Tensor], prefix: str)
 def load_jax_conv2d(conv2d: nn.Conv2d, weights: dict[str, Tensor], prefix: str) -> None:
     conv2d.weight.copy_(weights[f"{prefix}/kernel"].permute(3, 2, 0, 1))
     conv2d.bias.copy_(weights[f"{prefix}/bias"])
+
+
+def load_jax_mha(mha: MHA, weights: dict[str, Tensor], prefix: str) -> None:
+    mha.q_proj.weight.copy_(weights[f"{prefix}/query/kernel"].flatten(1).T)
+    mha.q_proj.bias.copy_(weights[f"{prefix}/query/bias"].flatten())
+    mha.k_proj.weight.copy_(weights[f"{prefix}/key/kernel"].flatten(1).T)
+    mha.k_proj.bias.copy_(weights[f"{prefix}/key/bias"].flatten())
+    mha.v_proj.weight.copy_(weights[f"{prefix}/value/kernel"].flatten(1).T)
+    mha.v_proj.bias.copy_(weights[f"{prefix}/value/bias"].flatten())
+    mha.out_proj.weight.copy_(weights[f"{prefix}/out/kernel"].flatten(0, 1).T)
+    mha.out_proj.bias.copy_(weights[f"{prefix}/out/bias"].flatten())
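
Note on the shape conversions: Flax's MultiHeadDotProductAttention stores the query/key/value kernels as (embed_dim, num_heads, head_dim) and the output kernel as (num_heads, head_dim, embed_dim), while PyTorch's nn.Linear keeps weight as (out_features, in_features). A minimal standalone shape check of the flatten/transpose logic (the sizes embed_dim=8, num_heads=2 are illustrative assumptions, not values from the repo):

import torch

embed_dim, num_heads = 8, 2  # hypothetical sizes for illustration
head_dim = embed_dim // num_heads

# Flax stores the q/k/v kernels as (embed_dim, num_heads, head_dim);
# flatten the head dims and transpose to get nn.Linear's (out, in) layout.
q_kernel = torch.randn(embed_dim, num_heads, head_dim)
q_weight = q_kernel.flatten(1).T
assert q_weight.shape == (embed_dim, embed_dim)

# The output-projection kernel is (num_heads, head_dim, embed_dim);
# here the head dims come first, hence flatten(0, 1) before the transpose.
out_kernel = torch.randn(num_heads, head_dim, embed_dim)
out_weight = out_kernel.flatten(0, 1).T
assert out_weight.shape == (embed_dim, embed_dim)

Grouping the eight copy_ calls into load_jax_mha keeps this conversion logic next to the other load_jax_* helpers instead of inline in the encoder-block loop.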