
Commit 322f2fe

always route vlm key values to highest self attention layers

1 parent: 4dd085e

3 files changed: +13 −6 lines

pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.6.1"
+version = "2.6.2"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_x_transformers.py
Lines changed: 1 addition & 1 deletion

@@ -1219,7 +1219,7 @@ def test_external_key_values():
         max_seq_len = 1024,
         attn_layers = Decoder(
             dim = 512,
-            depth = 2,
+            depth = 3,
             heads = 8,
             attn_dim_head = 16
         )

x_transformers/x_transformers.py
Lines changed: 11 additions & 4 deletions

@@ -2435,6 +2435,7 @@ def forward(
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
         self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
         additional_kv_mask = None,
+        route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2544,10 +2545,6 @@ def forward(

         iter_attn_cache = iter(attn_cache)

-        # additional self attn key / values
-
-        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
-
         # handle deep embeds if needed

         deep_embeds = []
@@ -2582,6 +2579,16 @@ def forward(
         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

+        # additional self attn key / values - say coming from vlm
+
+        if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+            num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
+
+            self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
+            self_attn_additional_kv = [None] * (num_self_attns - len(self_attn_additional_kv)) + self_attn_additional_kv
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # derived input for reinjection if needed

         inp_inject = None
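The added block aligns externally supplied key / values (e.g. from a vision-language model) with the last, i.e. highest, self attention layers: the list is first truncated to at most the number of self attention layers, then left padded with None so the supplied entries land on the topmost layers (layers that receive None presumably run plain self attention). A minimal standalone sketch of that alignment, with strings standing in for the (key, value) tensor tuples and route_to_top as a hypothetical helper, not an x-transformers function:

    # sketch of the routing added in this commit; strings stand in for
    # (key, value) tensor tuples, route_to_top is hypothetical
    def route_to_top(additional_kv, num_self_attns):
        # keep at most one entry per self attention layer
        additional_kv = additional_kv[-num_self_attns:]
        # left pad with None so entries land on the last (highest) layers
        return [None] * (num_self_attns - len(additional_kv)) + additional_kv

    # e.g. two external kv pairs routed into three self attention layers
    print(route_to_top(['kv1', 'kv2'], 3))  # [None, 'kv1', 'kv2']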
