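+# The custom op now takes the pre-projected q/k/v states plus the decay gate g1,
+# the output gate g2, and beta, and writes the kernel result into core_attn_out in place.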
 def kda_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
+    self._forward(
+        q_proj_states=q_proj_states,
+        k_proj_states=k_proj_states,
+        v_proj_states=v_proj_states,
+        g1=g1,
+        g2=g2,
+        beta=beta,
+        core_attn_out=core_attn_out,
+    )


 def kda_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
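+    # Fake (meta) implementation is a no-op: the real op only mutates core_attn_out.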
     return
@@ -60,7 +78,7 @@ def kda_attention_fake(
 direct_register_custom_op(
     op_name="kda_attention",
     op_func=kda_attention,
-    mutates_args=["output"],
+    mutates_args=["core_attn_out"],
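+    # core_attn_out is filled in place by the op, so it is the only mutated argument.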
     fake_impl=kda_attention_fake,
 )

@@ -241,37 +259,56 @@ def forward(
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
         output: torch.Tensor,
-    ) -> None:
-        return torch.ops.vllm.kda_attention(
-            hidden_states,
-            output,
+    ) -> torch.Tensor:
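+        # Q/K/V projections are now computed here, outside the compiled custom op.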
+        num_tokens = hidden_states.size(0)
+        q = self.q_proj(hidden_states)[0]
+        k = self.k_proj(hidden_states)[0]
+        v = self.v_proj(hidden_states)[0]
+
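+        # Gating tensors: beta from b_proj (sigmoid), g1 from the low-rank
+        # f_a_proj/f_b_proj path run through the fused KDA gate.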
+        beta = self.b_proj(hidden_states)[0].float().sigmoid()
+        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
+        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
+        beta = beta.unsqueeze(0)
+        g1 = g1.unsqueeze(0)
+
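+        # g2: output-gate projection, reshaped to (..., num_heads, head_dim) for o_norm.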
+        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
+        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
+
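+        # Pre-allocate the core attention output; the custom op writes into it in place.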
+        core_attn_out = torch.zeros(
+            (1, num_tokens, self.local_num_heads, self.head_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        torch.ops.vllm.kda_attention(
+            q,
+            k,
+            v,
+            g1,
+            g2,
+            beta,
+            core_attn_out,
             self.prefix,
         )
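+        # Gated output norm over g2, then the final output projection.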
+        core_attn_out = self.o_norm(core_attn_out, g2)
+        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
+
+        return self.o_proj(core_attn_out)[0]

     def _forward(
         self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
+        q_proj_states: torch.Tensor,
+        k_proj_states: torch.Tensor,
+        v_proj_states: torch.Tensor,
+        g1: torch.Tensor,
+        g2: torch.Tensor,
+        beta: torch.Tensor,
+        core_attn_out: torch.Tensor,
     ) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata

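+        # Profile runs carry no attention metadata and the output buffer is
+        # already allocated by the caller, so there is nothing to do here.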
         if attn_metadata is None:
-            # V1 profile run
-            # Mimic the memory allocation in the real run
-            q = torch.empty_like(hidden_states)
-            k = torch.empty_like(hidden_states)
-            v = torch.empty_like(hidden_states)
-            g = hidden_states.new_empty(
-                hidden_states.size(0),
-                self.local_num_heads,
-                self.head_dim,
-                dtype=torch.float32,
-            )
-            beta = torch.empty(
-                hidden_states.size(0), self.local_num_heads, dtype=torch.float32
-            )
-            core_attn_out = torch.empty_like(hidden_states)
+            # V1 profile run
             return

         assert isinstance(attn_metadata, dict)
@@ -288,10 +325,6 @@ def _forward(
         conv_state_k = conv_state_k.transpose(-1, -2)
         conv_state_v = conv_state_v.transpose(-1, -2)

-        q_proj_states = self.q_proj(hidden_states)[0]
-        k_proj_states = self.k_proj(hidden_states)[0]
-        v_proj_states = self.v_proj(hidden_states)[0]
-
         q_conv_weights = self.q_conv1d.weight.view(
             self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
         )
@@ -374,14 +407,6 @@ def _forward(
             lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
         )

-        beta = self.b_proj(hidden_states)[0].float().sigmoid()
-
-        g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
-        g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
-
-        beta = beta.unsqueeze(0)
-        g = g.unsqueeze(0)
-
         if attn_metadata.num_prefills > 0:
             zero_idx = non_spec_state_indices_tensor[~has_initial_state]
             recurrent_state[zero_idx] = 0
@@ -393,7 +418,7 @@ def _forward(
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
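+                # g1 is the decay gate precomputed in forward() via fused_kda_gate.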
                 beta=beta,
                 initial_state=initial_state,
                 output_final_state=True,
@@ -410,17 +435,12 @@ def _forward(
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
                 beta=beta,
                 initial_state=recurrent_state,
                 use_qk_l2norm_in_kernel=True,
                 cu_seqlens=non_spec_query_start_loc,
                 ssm_state_indices=non_spec_state_indices_tensor,
             )
-
-        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
-        g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
-        core_attn_out = self.o_norm(core_attn_out_non_spec, g)
-        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
-
-        output[:] = self.o_proj(core_attn_out)[0]
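+        # o_norm and o_proj now run in forward(); here we only copy the kernel
+        # output into the caller-provided buffer.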
+        assert core_attn_out_non_spec.shape == core_attn_out.shape
+        core_attn_out[:] = core_attn_out_non_spec