
Commit bb68a8c

take care of vlm key / values having a different number of attention heads than the main model
1 parent 322f2fe commit bb68a8c

3 files changed: +10 / -4 lines changed

pyproject.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.6.2"
+version = "2.6.3"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_x_transformers.py (2 additions, 2 deletions)

@@ -1228,8 +1228,8 @@ def test_external_key_values():
     seq = torch.randint(0, 20000, (3, 1024))

     key_values = [
-        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
-        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+        (torch.randn(3, 2, 32, 16), torch.randn(3, 2, 32, 16)),
+        (torch.randn(3, 2, 32, 16), torch.randn(3, 2, 32, 16)),
     ]

     additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()
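The test now feeds external key / value tensors with only 2 heads, so it exercises the new head-expansion path. A minimal sketch of the shape relationship the change relies on; the 8 query heads are an assumption about the test model, inferred from the original tensors:

import torch

# external key tensor as in the updated test: batch 3, 2 kv heads, 32 positions, 16 dims per head
added_k = torch.randn(3, 2, 32, 16)

query_heads = 8                      # assumed query head count of the test model
added_kv_heads = added_k.shape[1]    # 2

# the patch requires the query head count to be divisible by the external kv head count
assert query_heads % added_kv_heads == 0
repeat_factor = query_heads // added_kv_heads   # each external kv head is repeated 4 times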

x_transformers/x_transformers.py (7 additions, 1 deletion)

@@ -1795,14 +1795,20 @@ def forward(
         seq_len = k.shape[-2]

         added_k, added_v = additional_key_values
+        added_kv_heads, added_kv_len = added_k.shape[1], added_k.shape[-2]
+
+        # take care of expanding to query heads if mismatch between key / value heads with the ones coming from vlm
+
+        if added_kv_heads != kv_h:
+            assert divisible_by(h, added_kv_heads)
+            k, v, added_k, added_v = tuple(repeat(t, 'b h ... -> b (r h) ...', r = h // t.shape[1]) for t in (k, v, added_k, added_v))

         k = cat((added_k, k), dim = -2)
         v = cat((added_v, v), dim = -2)

         if (exists(input_mask) or exists(additional_key_value_mask)):

             if not exists(additional_key_value_mask):
-                added_kv_len = added_k.shape[-2]
                 input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
             elif not exists(input_mask):
                 input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
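For reference, a self-contained sketch of what the patched block does, written against plain torch and einops; the function name, the example shapes, and the inline shape comments are assumptions for illustration, not part of the library's API:

import torch
from torch import cat
from einops import repeat

def merge_additional_key_values(k, v, added_k, added_v, query_heads):
    # k, v:              (batch, kv_heads, seq_len, dim_head) from the current attention layer
    # added_k, added_v:  (batch, added_kv_heads, added_len, dim_head), e.g. coming from a VLM
    kv_heads, added_kv_heads = k.shape[1], added_k.shape[1]

    if added_kv_heads != kv_heads:
        # expand every tensor up to the query head count; requires divisibility
        assert query_heads % added_kv_heads == 0
        k, v, added_k, added_v = tuple(
            repeat(t, 'b h ... -> b (r h) ...', r = query_heads // t.shape[1])
            for t in (k, v, added_k, added_v)
        )

    # prepend the external keys / values along the sequence dimension
    k = cat((added_k, k), dim = -2)
    v = cat((added_v, v), dim = -2)
    return k, v

# example matching the updated test: 8 query heads, external key / values with 2 heads
k = torch.randn(3, 8, 1024, 16)
v = torch.randn(3, 8, 1024, 16)
added_k = torch.randn(3, 2, 32, 16)
added_v = torch.randn(3, 2, 32, 16)

k, v = merge_additional_key_values(k, v, added_k, added_v, query_heads = 8)
print(k.shape)  # torch.Size([3, 8, 1056, 16])

The expansion mirrors grouped-query attention: when the external key / values carry fewer heads than the model's queries, each head is repeated until the counts match, and only then are the external entries concatenated in front of the layer's own keys and values.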
