separate qkv_proj for simplicity
gau-nernst committed Aug 20, 2023
1 parent d7560ca commit 36e82b7
Showing 1 changed file with 13 additions and 7 deletions.
vision_toolbox/backbones/vit.py
@@ -20,15 +20,19 @@
 class MHA(nn.Module):
     def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
         super().__init__()
-        self.in_proj = nn.Linear(d_model, d_model * 3, bias)
+        self.q_proj = nn.Linear(d_model, d_model, bias)
+        self.k_proj = nn.Linear(d_model, d_model, bias)
+        self.v_proj = nn.Linear(d_model, d_model, bias)
         self.out_proj = nn.Linear(d_model, d_model, bias)
         self.n_heads = n_heads
         self.dropout = dropout
         self.scale = (d_model // n_heads) ** (-0.5)

     def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
-        qkv = self.in_proj(x)
-        q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)  # (B, n_heads, L, head_dim)
+        q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
+        k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+        v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)

         if hasattr(F, "scaled_dot_product_attention"):
             out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
         else:
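
Note: the two parametrizations are equivalent. The old fused in_proj laid its output out as [q; k; v] (the unflatten used 3 as the leading factor), so splitting its weight and bias into thirds along the output dimension reproduces the separate projections exactly. A minimal sketch of that check, not part of the commit (the names fused, q_proj, etc. and the sizes are illustrative):

import torch
import torch.nn as nn

d_model = 64
fused = nn.Linear(d_model, d_model * 3)
q_proj, k_proj, v_proj = (nn.Linear(d_model, d_model) for _ in range(3))

with torch.no_grad():
    # copy each third of the fused weight/bias into the matching projection
    for proj, w, b in zip(
        (q_proj, k_proj, v_proj),
        fused.weight.chunk(3, dim=0),
        fused.bias.chunk(3, dim=0),
    ):
        proj.weight.copy_(w)
        proj.bias.copy_(b)

x = torch.randn(2, 10, d_model)
q, k, v = fused(x).chunk(3, dim=-1)
assert torch.allclose(q, q_proj(x))
assert torch.allclose(k, k_proj(x))
assert torch.allclose(v, v_proj(x))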
@@ -190,10 +194,12 @@ def get_w(key: str) -> Tensor:

         layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
         layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
-        w = torch.stack([get_w(mha_prefix + x + "/kernel") for x in ["query", "key", "value"]], 1)
-        b = torch.stack([get_w(mha_prefix + x + "/bias") for x in ["query", "key", "value"]], 0)
-        layer.mha[1].in_proj.weight.copy_(w.flatten(1).T)
-        layer.mha[1].in_proj.bias.copy_(b.flatten())
+        layer.mha[1].q_proj.weight.copy_(get_w(mha_prefix + "query/kernel").flatten(1).T)
+        layer.mha[1].k_proj.weight.copy_(get_w(mha_prefix + "key/kernel").flatten(1).T)
+        layer.mha[1].v_proj.weight.copy_(get_w(mha_prefix + "value/kernel").flatten(1).T)
+        layer.mha[1].q_proj.bias.copy_(get_w(mha_prefix + "query/bias").flatten())
+        layer.mha[1].k_proj.bias.copy_(get_w(mha_prefix + "key/bias").flatten())
+        layer.mha[1].v_proj.bias.copy_(get_w(mha_prefix + "value/bias").flatten())
         layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
         layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))
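
Note: the loading code above follows the JAX/Flax ViT checkpoint layout, where query/key/value kernels are stored per-head as (d_model, n_heads, head_dim) and the output kernel as (n_heads, head_dim, d_model); the flatten/transpose pattern converts them to nn.Linear's (out_features, in_features) convention. A shape walk-through as a sketch (sizes are ViT-Base, for illustration):

import torch

d_model, n_heads = 768, 12
head_dim = d_model // n_heads  # 64

q_kernel = torch.randn(d_model, n_heads, head_dim)  # per-head kernel, as stored in the checkpoint
q_weight = q_kernel.flatten(1).T                    # (768, 768): merge heads, transpose to (out, in)
q_bias = torch.randn(n_heads, head_dim).flatten()   # (768,)

out_kernel = torch.randn(n_heads, head_dim, d_model)
out_weight = out_kernel.flatten(0, 1).T             # (768, 768)

print(q_weight.shape, q_bias.shape, out_weight.shape)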
