[Bugs] Fix dispatch attn bug #829

Merged · 3 commits · Jul 19, 2024
xtuner/dataset/collate_fns/default_collate_fn.py (2 additions, 2 deletions)

@@ -56,8 +56,8 @@ def default_collate_fn(instances: Sequence[Dict],
         # Some tokenizers have the same eos token and pad token, so input_ids
         # cannot be masked directly based on the pad token id.
         attention_mask = torch.zeros_like(input_ids).bool()
-        for i in ori_length:
-            attention_mask[:i] = True
+        for i, length in enumerate(ori_length):
+            attention_mask[i, :length] = True
 
         bs, seq_len = input_ids.shape
         position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
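The old loop iterated over the lengths themselves and sliced whole rows (`attention_mask[:i] = True`), so it marked the first `i` samples as fully attended and left padding unmasked; the fix indexes row `i` and masks only its first `length` positions. A minimal standalone sketch with illustrative lengths (not xtuner code) shows the difference:

import torch

# ori_length holds the unpadded length of each sample in the batch.
ori_length = [3, 5]
input_ids = torch.zeros(2, 5, dtype=torch.long)  # batch of 2, padded to length 5

# Old behaviour: `attention_mask[:i]` slices rows, so every position of the
# first i samples is set to True and padding is still attended.
buggy_mask = torch.zeros_like(input_ids).bool()
for i in ori_length:
    buggy_mask[:i] = True
print(buggy_mask)
# tensor([[True, True, True, True, True],
#         [True, True, True, True, True]])

# Fixed behaviour: mark the first `length` positions of row `i`.
fixed_mask = torch.zeros_like(input_ids).bool()
for i, length in enumerate(ori_length):
    fixed_mask[i, :length] = True
print(fixed_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])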
xtuner/model/modules/dispatch/cohere.py (6 additions, 1 deletion)

@@ -105,17 +105,22 @@ def cohere_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # self.num_heads is used in self._upad_input method
+        # num_heads has been changed because of sequence parallel
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]
 
     attn_output = self._flash_attention_forward(
         query_states,
         key_states,
         value_states,
         attention_mask,
-        q_len,
+        query_states.shape[1],
         dropout=dropout_rate)
 
     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
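Under sequence parallelism, the pre-processing step leaves each rank with the full sequence but only num_heads // sp_world_size heads, while self.num_heads still reports the global head count that _upad_input (called from _flash_attention_forward) reads; the PR therefore overwrites self.num_heads with the local head count for the duration of the call and restores it afterwards. It also passes query_states.shape[1] as the query length, which after the gather is the full sequence length rather than the local q_len. A minimal sketch of the save/override/restore pattern, using a hypothetical TinyAttn module in place of the real attention class:

import torch

class TinyAttn(torch.nn.Module):
    """Hypothetical stand-in for an attention module whose helpers read
    self.num_heads, as HuggingFace's _upad_input does."""

    def __init__(self, num_heads: int):
        super().__init__()
        self.num_heads = num_heads

    def _helper_that_reads_num_heads(self, q: torch.Tensor) -> None:
        # Would mis-reshape if self.num_heads disagrees with q.shape[-2].
        assert q.shape[-2] == self.num_heads

    def forward(self, q: torch.Tensor, sequence_parallel: bool) -> torch.Tensor:
        if sequence_parallel:
            # This rank only holds a slice of the heads, so temporarily make
            # self.num_heads match the local tensor.
            ori_num_head = self.num_heads
            self.num_heads = q.shape[-2]
        self._helper_that_reads_num_heads(q)
        if sequence_parallel:
            self.num_heads = ori_num_head  # restore for later calls
        return q

attn = TinyAttn(num_heads=8)
local_q = torch.randn(2, 16, 4, 64)  # (b, s, num_heads // sp_world_size, dim)
attn(local_q, sequence_parallel=True)
assert attn.num_heads == 8  # global value restored after the call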
xtuner/model/modules/dispatch/deepseek_v2.py (5 additions, 0 deletions)

@@ -128,6 +128,10 @@ def deepseek_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # self.num_heads is used in self._upad_input method
+        # num_heads has been changed because of sequence parallel
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]
 
     attn_output = self._flash_attention_forward(
         query_states,

@@ -141,6 +145,7 @@
 
     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head
 
     if self.q_head_dim != self.v_head_dim:
         attn_output = attn_output[:, :, :, :self.v_head_dim]
xtuner/model/modules/dispatch/internlm2.py (5 additions, 0 deletions)

@@ -149,6 +149,10 @@ def internlm2_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # self.num_heads is used in self._upad_input method
+        # num_heads has been changed because of sequence parallel
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]
 
     dropout_rate = 0.0
     attn_output = self._flash_attention_forward(

@@ -161,6 +165,7 @@
 
     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.wo(attn_output)
xtuner/model/modules/dispatch/mistral.py (7 additions, 1 deletion)

@@ -214,6 +214,11 @@ def mistral_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # num_heads has been changed because of sequence parallel
+        # `self.num_heads`` is not used in self._flash_attention_forward
+        # in mistral/mixtral, we are doing this to avoid some unnecessary risk
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]
 
     attn_output = self._flash_attention_forward(
         query_states,

@@ -227,6 +232,7 @@
 
     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head
 
     attn_output = attn_output.reshape(bsz, q_len,
                                       self.hidden_size).contiguous()

@@ -311,7 +317,7 @@ def mistral_varlen_attn_forward(
     value_states = value_states.transpose(1, 2)
     # Because the input can be padded, the absolute sequence length
     # depends on the max position id.
-    rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
+    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
     cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids)
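In the varlen forward, position ids restart at every packed sample (and padded batches need not end on the largest position), so position_ids[:, -1] can under-report the maximum position; taking the max over the whole tensor sizes the rotary cache to cover every position actually used. A small standalone sketch with illustrative values:

import torch

# Illustrative packed batch: samples of lengths 6 and 2 packed into one row,
# so position ids restart at the second sample's boundary.
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 0, 1]])
kv_seq_len = 4  # hypothetical value, smaller than the true maximum position

old_rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
new_rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)

print(old_rotary_seq_len)  # 4  -> rotary cache too short for positions up to 5
print(new_rotary_seq_len)  # 6  -> covers every position id in the packed batch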
xtuner/model/modules/dispatch/phi3.py (7 additions, 1 deletion)

@@ -233,6 +233,11 @@ def phi3_attn_forward(
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states,
                 scatter_dim=2, gather_dim=1)
+        # num_heads has been changed because of sequence parallel
+        # `self.num_heads`` is not used in self._flash_attention_forward
+        # in mistral/mixtral, we are doing this to avoid some unnecessary risk
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]
 
     attn_output = self._flash_attention_forward(
         query_states,

@@ -248,6 +253,7 @@
         # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim)
         attn_output = post_process_for_sequence_parallel_attn(
             attn_output, scatter_dim=1, gather_dim=2)
+        self.num_heads = ori_num_head
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)

@@ -333,7 +339,7 @@ def phi3_varlen_attn_forward(
         self.layer_idx)
 
     assert position_ids is not None
-    rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
     cos, sin = self.rotary_emb(
         value_states, position_ids, seq_len=rotary_seq_len)
 
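The phi3 path passes explicit scatter_dim/gather_dim because its q/k/v tensors are laid out as (b, s, num_heads, head_dim): before attention the sequence dimension is gathered across sequence-parallel ranks while heads are scattered, and afterwards the inverse transform is applied, as the in-code comment "(b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim)" notes. Below is a shape-only, single-process sketch of that bookkeeping, with chunk + concat standing in for the distributed all-to-all; the function name is hypothetical and this is not xtuner's actual implementation.

import torch

def shape_only_all_to_all(x, sp_world_size, scatter_dim, gather_dim):
    # Single-process stand-in: split along scatter_dim and concatenate along
    # gather_dim. The real op exchanges chunks between ranks; only the
    # resulting local shape is modelled here.
    chunks = x.chunk(sp_world_size, dim=scatter_dim)
    return torch.cat(chunks, dim=gather_dim)

b, s, nd, dim, sp = 2, 8, 4, 64, 2
local = torch.randn(b, s // sp, nd, dim)  # each rank holds a sequence shard

# Pre-attention: scatter heads (dim 2), gather sequence (dim 1).
pre = shape_only_all_to_all(local, sp, scatter_dim=2, gather_dim=1)
assert pre.shape == (b, s, nd // sp, dim)

# Post-attention: scatter sequence (dim 1), gather heads (dim 2).
post = shape_only_all_to_all(pre, sp, scatter_dim=1, gather_dim=2)
assert post.shape == (b, s // sp, nd, dim)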
xtuner/model/modules/dispatch/qwen2.py (7 additions, 1 deletion)

@@ -151,6 +151,11 @@ def qwen2_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # num_heads has been changed because of sequence parallel
+        # `self.num_heads`` is not used in self._flash_attention_forward
+        # in mistral/mixtral, we are doing this to avoid some unnecessary risk
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]
 
     attn_output = self._flash_attention_forward(
         query_states,

@@ -164,6 +169,7 @@
 
     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)

@@ -227,7 +233,7 @@ def qwen2_varlen_attn_forward(
         self.layer_idx)
 
     assert position_ids is not None
-    rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
+    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
     cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
 
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,