From 264a07fc76e0e4120c3d6dc2d8d36ca3fe4b7120 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 8 Aug 2023 21:50:34 +0800 Subject: [PATCH] add spacing --- vision_toolbox/backbones/mlp_mixer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py index 2b1be67..124f7ce 100644 --- a/vision_toolbox/backbones/mlp_mixer.py +++ b/vision_toolbox/backbones/mlp_mixer.py @@ -99,21 +99,25 @@ def get_w(key: str) -> Tensor: self.patch_embed.weight.copy_(get_w("stem/kernel").permute(3, 2, 0, 1)) self.patch_embed.bias.copy_(get_w("stem/bias")) + for i, layer in enumerate(self.layers): layer: MixerBlock prefix = f"MixerBlock_{i}/" + layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale")) layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias")) layer.token_mixing.linear1.weight.copy_(get_w(prefix + "token_mixing/Dense_0/kernel").T) layer.token_mixing.linear1.bias.copy_(get_w(prefix + "token_mixing/Dense_0/bias")) layer.token_mixing.linear2.weight.copy_(get_w(prefix + "token_mixing/Dense_1/kernel").T) layer.token_mixing.linear2.bias.copy_(get_w(prefix + "token_mixing/Dense_1/bias")) + layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_1/scale")) layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_1/bias")) layer.channel_mixing.linear1.weight.copy_(get_w(prefix + "channel_mixing/Dense_0/kernel").T) layer.channel_mixing.linear1.bias.copy_(get_w(prefix + "channel_mixing/Dense_0/bias")) layer.channel_mixing.linear2.weight.copy_(get_w(prefix + "channel_mixing/Dense_1/kernel").T) layer.channel_mixing.linear2.bias.copy_(get_w(prefix + "channel_mixing/Dense_1/bias")) + self.norm.weight.copy_(get_w("pre_head_layer_norm/scale")) self.norm.bias.copy_(get_w("pre_head_layer_norm/bias")) return self