Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Oct 29, 2023
1 parent 80b87d5 commit 13dbac7
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions vision_toolbox/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,11 @@ def load_jax_ckpt(self, ckpt: str, big_vision: bool = False, prefix: str = "") -
load_jax_ln(self.norm, jax_weights, "Transformer/encoder_norm")

for i, layer in enumerate(self.layers):
jax_prefix = f"Transformer/encoderblock_{i}"
load_jax_ln(layer.mha[0], jax_weights, f"{jax_prefix}/{mha_norm}")
load_jax_mha(layer.mha[1], jax_weights, f"{jax_prefix}/{mha}")
load_jax_ln(layer.mlp[0], jax_weights, f"{jax_prefix}/{mlp_norm}")
load_jax_linear(layer.mlp[1].linear1, jax_weights, f"{jax_prefix}/{mlp}/Dense_0")
load_jax_linear(layer.mlp[1].linear2, jax_weights, f"{jax_prefix}/{mlp}/Dense_1")
load_jax_ln(layer.mha[0], jax_weights, f"Transformer/encoderblock_{i}/{mha_norm}")
load_jax_mha(layer.mha[1], jax_weights, f"Transformer/encoderblock_{i}/{mha}")
load_jax_ln(layer.mlp[0], jax_weights, f"Transformer/encoderblock_{i}/{mlp_norm}")
load_jax_linear(layer.mlp[1].linear1, jax_weights, f"Transformer/encoderblock_{i}/{mlp}/Dense_0")
load_jax_linear(layer.mlp[1].linear2, jax_weights, f"Transformer/encoderblock_{i}/{mlp}/Dense_1")


def load_jax_ln(norm: nn.LayerNorm, weights: dict[str, Tensor], prefix: str) -> None:
Expand Down

0 comments on commit 13dbac7

Please sign in to comment.