@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""

@@ -201,8 +202,22 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
+  # Use different sharding strategy for self-attn vs cross-attn
+  if is_self_attention is not None:
+    if is_self_attention:
+      # Self-attention: Context Parallelism (sharding along num_heads)
+      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+    else:
+      # Cross-attention: Sequence Parallelism for Q
+      # Q's sequence is sharded; K/V are replicated
+      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
+      kv_axis_names = PartitionSpec("data", None, None, None)
+  else:
+    # Fallback to original maxdiffusion behavior if the flag isn't provided
+    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

   @functools.partial(
       shard_map.shard_map,
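
For readers less familiar with explicit PartitionSpecs, the sketch below illustrates what the two layouts above do to a (batch, heads, seq, head_dim) activation. The mesh shape, array sizes, and variable names are illustrative assumptions, not values taken from this change; it needs at least 4 devices (on CPU, XLA_FLAGS=--xla_force_host_platform_device_count=4 works).

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical ("data", "fsdp", "tensor") mesh of shape (1, 1, 4).
devices = np.array(jax.devices()[:4]).reshape(1, 1, 4)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

x = jnp.zeros((2, 8, 1024, 128))  # (batch, heads, seq, head_dim)

# Self-attention layout: the heads axis is split across fsdp*tensor, so each
# device keeps the full Q/K/V sequence for a slice of the heads.
self_spec = NamedSharding(mesh, PartitionSpec("data", ("fsdp", "tensor"), None, None))

# Cross-attention layout for Q: the sequence axis is split instead; K/V use
# PartitionSpec("data", None, None, None), i.e. replicated within each data shard.
cross_q_spec = NamedSharding(mesh, PartitionSpec("data", None, ("fsdp", "tensor"), None))

print(jax.device_put(x, self_spec).addressable_shards[0].data.shape)     # (2, 2, 1024, 128)
print(jax.device_put(x, cross_q_spec).addressable_shards[0].data.shape)  # (2, 8, 256, 128)
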
@@ -419,6 +434,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -439,7 +455,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -574,6 +590,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -593,6 +610,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.is_self_attention = is_self_attention

   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        is_self_attention=self.is_self_attention,
     )


@@ -701,6 +720,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -730,6 +750,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        is_self_attention=is_self_attention,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
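
Two wiring details worth noting: is_self_attention defaults to True in the constructors and in _apply_attention, but to None in _tpu_flash_attention itself, so only callers that explicitly pass the flag get the hard-coded PartitionSpecs; and the "ring" branch's call site is untouched by this diff, so it appears to keep taking the None fallback, which resolves sharding through Flax's logical-axis rules. A minimal sketch of that fallback mechanism, with a made-up rule table rather than maxdiffusion's real configuration:

import flax.linen as nn

# Placeholder logical-to-mesh rules; the real ones come from the model config.
rules = (
    ("activation_batch", "data"),
    ("activation_heads", ("fsdp", "tensor")),
    ("activation_length", None),
    ("activation_kv", None),
)

# Maps a tuple of logical axis names to a jax.sharding.PartitionSpec.
spec = nn.logical_to_mesh_axes(
    ("activation_batch", "activation_heads", "activation_length", "activation_kv"),
    rules=rules,
)
print(spec)  # Should print PartitionSpec('data', ('fsdp', 'tensor'), None, None)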