@@ -53,7 +53,6 @@
 CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH
 
 
-
 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
 
@@ -448,7 +447,16 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        attention_kernel,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
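For orientation, the axis_names_q and axis_names_kv tuples passed above are logical axis names, one per dimension of the query and key/value tensors, which the flash-attention wrapper resolves against the device mesh when sharding its inputs. A minimal sketch of that idea using the generic jax.sharding API (the mesh shape, axis names, and logical-to-mesh mapping below are illustrative assumptions, not this repository's common_types constants or helpers):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2-axis mesh: "data" for batch parallelism, "model" for head parallelism.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("data", "model"))

# Illustrative logical axis names and their mapping onto mesh axes.
logical_to_mesh = {"batch": "data", "heads": "model", "q_length": None, "head_dim": None}
axis_names_q = ("batch", "heads", "q_length", "head_dim")

# Build a PartitionSpec from the logical names and shard a query-shaped array.
# (The batch size must be divisible by the "data" axis size for this to succeed.)
spec_q = P(*(logical_to_mesh[name] for name in axis_names_q))
query = jax.device_put(jnp.ones((8, 16, 256, 64)), NamedSharding(mesh, spec_q))
print(query.sharding)
```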
@@ -733,7 +741,7 @@ def __init__(
     else:
       axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
       axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)
-
+
     self.attention_op = NNXAttentionOp(
         mesh=mesh,
         attention_kernel=attention_kernel,
@@ -1542,4 +1550,4 @@ def setup(self):
   def __call__(self, hidden_states, deterministic=True):
     hidden_states = self.proj(hidden_states)
     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
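The __call__ shown in this last hunk is the GEGLU feed-forward gate: the projection produces twice the output width, the result is split in two, and the linear half is multiplied by GELU of the other half before dropout. A self-contained sketch of the same pattern in plain flax.linen (module and field names here are illustrative, not the repository's classes):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class GEGLU(nn.Module):
  """Gated GELU feed-forward: project to 2x width, split, gate, then dropout."""
  dim_out: int
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, hidden_states, deterministic=True):
    hidden_states = nn.Dense(self.dim_out * 2)(hidden_states)
    hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
    gated = hidden_linear * nn.gelu(hidden_gelu)
    return nn.Dropout(rate=self.dropout_rate)(gated, deterministic=deterministic)

# Example: params = GEGLU(dim_out=128).init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64)))
```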