From 36e82b7441ef2e97b0c295fb3657d9d61c378685 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sun, 20 Aug 2023 11:08:22 +0800
Subject: [PATCH] separate qkv_proj for simplicity

---
 vision_toolbox/backbones/vit.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index da3f919..561190f 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -20,15 +20,19 @@ class MHA(nn.Module):
     def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
         super().__init__()
-        self.in_proj = nn.Linear(d_model, d_model * 3, bias)
+        self.q_proj = nn.Linear(d_model, d_model, bias)
+        self.k_proj = nn.Linear(d_model, d_model, bias)
+        self.v_proj = nn.Linear(d_model, d_model, bias)
         self.out_proj = nn.Linear(d_model, d_model, bias)
         self.n_heads = n_heads
         self.dropout = dropout
         self.scale = (d_model // n_heads) ** (-0.5)

     def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
-        qkv = self.in_proj(x)
-        q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)  # (B, n_heads, L, head_dim)
+        q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
+        k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+        v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+
         if hasattr(F, "scaled_dot_product_attention"):
             out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
         else:
@@ -190,10 +194,12 @@ def get_w(key: str) -> Tensor:
             layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
             layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
-            w = torch.stack([get_w(mha_prefix + x + "/kernel") for x in ["query", "key", "value"]], 1)
-            b = torch.stack([get_w(mha_prefix + x + "/bias") for x in ["query", "key", "value"]], 0)
-            layer.mha[1].in_proj.weight.copy_(w.flatten(1).T)
-            layer.mha[1].in_proj.bias.copy_(b.flatten())
+            layer.mha[1].q_proj.weight.copy_(get_w(mha_prefix + "query/kernel").flatten(1).T)
+            layer.mha[1].k_proj.weight.copy_(get_w(mha_prefix + "key/kernel").flatten(1).T)
+            layer.mha[1].v_proj.weight.copy_(get_w(mha_prefix + "value/kernel").flatten(1).T)
+            layer.mha[1].q_proj.bias.copy_(get_w(mha_prefix + "query/bias").flatten())
+            layer.mha[1].k_proj.bias.copy_(get_w(mha_prefix + "key/bias").flatten())
+            layer.mha[1].v_proj.bias.copy_(get_w(mha_prefix + "value/bias").flatten())
             layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
             layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))
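
Note (not part of the patch): a minimal sketch checking that the separated q_proj/k_proj/v_proj reproduce the old fused in_proj. It assumes the fused weight stacks the q, k, v blocks along the output dimension, which is what the old unflatten(-1, (3, n_heads, -1)) implied; the sizes (d_model=64, n_heads=4, B=2, L=7) and variable names are illustrative only.

# equivalence sketch: fused in_proj vs separate q/k/v projections (assumed [q; k; v] layout)
import torch
from torch import nn

d_model, n_heads, B, L = 64, 4, 2, 7

fused = nn.Linear(d_model, d_model * 3)   # stand-in for the old in_proj
q_proj = nn.Linear(d_model, d_model)
k_proj = nn.Linear(d_model, d_model)
v_proj = nn.Linear(d_model, d_model)

with torch.no_grad():
    # split the fused weight/bias into thirds along the output dim and load them
    wq, wk, wv = fused.weight.chunk(3, dim=0)  # each (d_model, d_model)
    bq, bk, bv = fused.bias.chunk(3, dim=0)
    q_proj.weight.copy_(wq); q_proj.bias.copy_(bq)
    k_proj.weight.copy_(wk); k_proj.bias.copy_(bk)
    v_proj.weight.copy_(wv); v_proj.bias.copy_(bv)

x = torch.randn(B, L, d_model)

# old fused path: (B, L, 3*d_model) -> 3 x (B, n_heads, L, head_dim)
q0, k0, v0 = fused(x).unflatten(-1, (3, n_heads, -1)).transpose(-2, -4).unbind(-3)
# new separated path, as in the patched forward()
q1 = q_proj(x).unflatten(-1, (n_heads, -1)).transpose(-2, -3)
k1 = k_proj(x).unflatten(-1, (n_heads, -1)).transpose(-2, -3)
v1 = v_proj(x).unflatten(-1, (n_heads, -1)).transpose(-2, -3)

torch.testing.assert_close(q0, q1)
torch.testing.assert_close(k0, k1)
torch.testing.assert_close(v0, v1)

Three separate projections are typically no faster than one fused GEMM, but they make the JAX checkpoint loading in the second hunk direct: each query/key/value kernel and bias is copied into its own nn.Linear without the intermediate torch.stack.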