Reminder
System Info
On the sp branch, when the transformers version is > 4.51.0, `sequence_parallel_attention` is registered as the attention implementation. However, in the forward pass, `_update_causal_mask` contains a check like `if self.config._attn_implementation == "flash_attention_2":`. Since the registered implementation name is not `"flash_attention_2"`, execution falls into the other branch and the attention mask is expanded from 2D to 4D. I think this is a bug — can you help?
Reproduction
Just read the code; the issue is visible from inspection alone.
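A minimal sketch of the mismatch described above (a simplified stand-in for transformers' `_update_causal_mask`, not the real implementation): only the `"flash_attention_2"` branch keeps the 2D padding mask, so a custom implementation name such as `sequence_parallel_attention` falls through to the branch that builds a 4D additive causal mask.

```python
import torch

def update_causal_mask_sketch(attn_implementation: str,
                              attention_mask: torch.Tensor,
                              seq_len: int) -> torch.Tensor:
    """Simplified sketch of the branching in _update_causal_mask.

    Only "flash_attention_2" passes the 2D padding mask through;
    any other implementation name gets a 4D additive mask.
    """
    if attn_implementation == "flash_attention_2":
        # flash-attn consumes the 2D (batch, seq_len) padding mask directly
        return attention_mask
    # other implementations: build (batch, 1, seq_len, seq_len) additive mask
    causal = torch.full((seq_len, seq_len), float("-inf")).triu(1)
    pad = torch.zeros_like(attention_mask, dtype=torch.float)
    pad = pad.masked_fill(attention_mask == 0, float("-inf"))
    return causal[None, None, :, :] + pad[:, None, None, :]

mask2d = torch.ones(2, 8)
print(update_causal_mask_sketch("flash_attention_2", mask2d, 8).dim())            # 2
print(update_causal_mask_sketch("sequence_parallel_attention", mask2d, 8).dim())  # 4
```

This shows why a sequence-parallel attention function that expects the flash-attention-style 2D mask instead receives a 4D one when it is registered under a name the string comparison does not recognize.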
Expected behavior
No response
Others
No response