@@ -2435,6 +2435,7 @@ def forward(
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
         self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
         additional_kv_mask = None,
+        route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
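The new `route_additional_kv_to_top` flag only changes how caller-supplied key / value pairs are distributed across layers, so the call site stays simple. Below is a hedged sketch of one possible call: `Decoder` is the existing x-transformers class, but the dimensions, the assumed `(batch, heads, kv_seq, dim_head)` shape for the extra key / values, and the idea of exporting them from a VLM are illustrative assumptions, not part of this change.

```python
import torch
from x_transformers import Decoder

# hypothetical configuration, chosen only for illustration
decoder = Decoder(dim = 512, depth = 6, heads = 8)

x = torch.randn(1, 128, 512)  # (batch, seq, dim)

# two additional (key, value) pairs, e.g. taken from a vision-language model;
# the (batch, heads, kv_seq, dim_head) shape is an assumption for this sketch
additional_kv = [
    (torch.randn(1, 8, 64, 64), torch.randn(1, 8, 64, 64))
    for _ in range(2)
]

out = decoder(
    x,
    self_attn_additional_kv = additional_kv,
    route_additional_kv_to_top = True  # new flag: send the pairs to the top-most self attention layers
)
```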
@@ -2544,10 +2545,6 @@ def forward(
 
         iter_attn_cache = iter(attn_cache)
 
-        # additional self attn key / values
-
-        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
-
         # handle deep embeds if needed
 
         deep_embeds = []
@@ -2582,6 +2579,16 @@ def forward(
         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
+        # additional self attn key / values - say coming from vlm
+
+        if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+            num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
+
+            self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
+            self_attn_additional_kv = [None] * (num_self_attns - len(self_attn_additional_kv)) + self_attn_additional_kv
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # derived input for reinjection if needed
 
         inp_inject = None
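To make the routing itself concrete: when fewer (key, value) pairs are supplied than there are self-attention ('a') layers, the pairs are kept for the last layers and the rest are left-padded with `None`. The snippet below is a self-contained sketch of that alignment logic only; `route_to_top` is a hypothetical helper name and the strings stand in for real (Tensor, Tensor) pairs.

```python
# minimal sketch of the alignment performed above (not the library API)

def route_to_top(additional_kv, layer_types, route_additional_kv_to_top = True):
    if additional_kv is None or not route_additional_kv_to_top:
        return additional_kv

    num_self_attns = sum(layer_type == 'a' for layer_type in layer_types)

    additional_kv = additional_kv[-num_self_attns:]             # drop extras from the bottom if too many
    left_pad = [None] * (num_self_attns - len(additional_kv))   # pad so supplied pairs sit at the top layers
    return left_pad + additional_kv

# 4 self attention layers, 2 supplied pairs -> the pairs land on the top 2 layers
layer_types = ('a', 'f', 'a', 'f', 'a', 'f', 'a', 'f')
kv_pairs = ['kv_0', 'kv_1']  # stand-ins for (Tensor, Tensor) tuples

print(route_to_top(kv_pairs, layer_types))
# [None, None, 'kv_0', 'kv_1']
```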