diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py
index 2b1be67..124f7ce 100644
--- a/vision_toolbox/backbones/mlp_mixer.py
+++ b/vision_toolbox/backbones/mlp_mixer.py
@@ -99,21 +99,25 @@ def get_w(key: str) -> Tensor:
 
         self.patch_embed.weight.copy_(get_w("stem/kernel").permute(3, 2, 0, 1))
         self.patch_embed.bias.copy_(get_w("stem/bias"))
+
         for i, layer in enumerate(self.layers):
             layer: MixerBlock
             prefix = f"MixerBlock_{i}/"
+
             layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
             layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
             layer.token_mixing.linear1.weight.copy_(get_w(prefix + "token_mixing/Dense_0/kernel").T)
             layer.token_mixing.linear1.bias.copy_(get_w(prefix + "token_mixing/Dense_0/bias"))
             layer.token_mixing.linear2.weight.copy_(get_w(prefix + "token_mixing/Dense_1/kernel").T)
             layer.token_mixing.linear2.bias.copy_(get_w(prefix + "token_mixing/Dense_1/bias"))
+
             layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_1/scale"))
             layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_1/bias"))
             layer.channel_mixing.linear1.weight.copy_(get_w(prefix + "channel_mixing/Dense_0/kernel").T)
             layer.channel_mixing.linear1.bias.copy_(get_w(prefix + "channel_mixing/Dense_0/bias"))
             layer.channel_mixing.linear2.weight.copy_(get_w(prefix + "channel_mixing/Dense_1/kernel").T)
             layer.channel_mixing.linear2.bias.copy_(get_w(prefix + "channel_mixing/Dense_1/bias"))
+
         self.norm.weight.copy_(get_w("pre_head_layer_norm/scale"))
         self.norm.bias.copy_(get_w("pre_head_layer_norm/bias"))
         return self
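
For reference (not part of the diff): the `.permute(3, 2, 0, 1)` on the stem kernel and the `.T` on the Dense kernels are there because Flax stores conv kernels as (H, W, C_in, C_out) and Dense kernels as (in_features, out_features), while PyTorch's nn.Conv2d and nn.Linear expect (C_out, C_in, H, W) and (out_features, in_features). A minimal sketch of that layout conversion, using made-up shapes for illustration only:

import torch

# Hypothetical checkpoint tensors; shapes are assumptions, not read from a real .npz.
flax_conv_kernel = torch.randn(16, 16, 3, 768)   # e.g. "stem/kernel" in HWIO layout
flax_dense_kernel = torch.randn(768, 3072)        # e.g. a Dense kernel in (in, out) layout

# Same transposes the loader applies before copy_().
torch_conv_weight = flax_conv_kernel.permute(3, 2, 0, 1)  # -> (C_out, C_in, H, W)
torch_dense_weight = flax_dense_kernel.T                   # -> (out_features, in_features)

assert torch_conv_weight.shape == (768, 3, 16, 16)
assert torch_dense_weight.shape == (3072, 768)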