From 6489df456bc7e22fbaa0f735497deb0aefe4f7a6 Mon Sep 17 00:00:00 2001
From: Tlntin <371043382@qq.com>
Date: Fri, 25 Oct 2024 11:01:07 +0800
Subject: [PATCH] Fix a bug

---
 export/modeling_qwen2.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/export/modeling_qwen2.py b/export/modeling_qwen2.py
index 53cdd8d..fcb6de9 100644
--- a/export/modeling_qwen2.py
+++ b/export/modeling_qwen2.py
@@ -682,27 +682,27 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
+        kv_seq_len = key_states.shape[-2] + past_key_value.shape[2]
+        # if past_key_value is not None:
         #     kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-            kv_seq_len += past_key_value.shape[1]
+        #     kv_seq_len += past_key_value.shape[2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         output_cache = (key_states, value_states)
 
-        if past_key_value is not None:
-            # cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-            cache_key = past_key_value[
-                :,
-                self.layer_idx * 2 * self.num_key_value_heads: (self.layer_idx * 2 + 1) * self.num_key_value_heads
-            ]
-            cache_value = past_key_value[
-                :,
-                (self.layer_idx * 2 + 1) * self.num_key_value_heads: (self.layer_idx * 2 + 2) * self.num_key_value_heads
-            ]
-            key_states = torch.cat((cache_key, key_states), dim=2)
-            value_states = torch.cat((cache_value, value_states), dim=2)
+        # if past_key_value is not None:
+        #     cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        cache_key = past_key_value[
+            :,
+            self.layer_idx * 2 * self.num_key_value_heads: (self.layer_idx * 2 + 1) * self.num_key_value_heads
+        ]
+        cache_value = past_key_value[
+            :,
+            (self.layer_idx * 2 + 1) * self.num_key_value_heads: (self.layer_idx * 2 + 2) * self.num_key_value_heads
+        ]
+        key_states = torch.cat((cache_key, key_states), dim=2)
+        value_states = torch.cat((cache_value, value_states), dim=2)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
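
Note on the change, with a small standalone sketch (not part of the patch): judging from the indexing above, the fused KV cache arrives as a single `past_key_value` tensor rather than a Cache object, laid out as [bsz, num_layers * 2 * num_kv_heads, past_len, head_dim], where dim 1 stacks each layer's key block followed by its value block and dim 2 is the sequence axis (the `torch.cat(..., dim=2)` calls confirm this). The old code read the past length from `shape[1]`, the stacked-heads axis, which appears to be the bug the subject refers to; the fix reads `shape[2]`. Dropping the `if past_key_value is not None:` guards presumably keeps the traced graph static for export. The sketch below only checks that the slicing and shapes line up; all sizes and variable names are illustrative, not taken from the repository.

import torch

# Illustrative sizes, not from the repository.
bsz, num_layers, num_kv_heads, head_dim = 1, 24, 4, 64
past_len, q_len = 8, 1

# Fused cache: every layer's K block, then its V block, stacked along dim 1.
past_key_value = torch.randn(bsz, num_layers * 2 * num_kv_heads, past_len, head_dim)
key_states = torch.randn(bsz, num_kv_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_kv_heads, q_len, head_dim)

layer_idx = 3
# Same slicing as the patched code: recover this layer's cached K and V.
cache_key = past_key_value[:, layer_idx * 2 * num_kv_heads:(layer_idx * 2 + 1) * num_kv_heads]
cache_value = past_key_value[:, (layer_idx * 2 + 1) * num_kv_heads:(layer_idx * 2 + 2) * num_kv_heads]

# Concatenation happens along dim 2, the sequence axis, which is why the past
# length must come from past_key_value.shape[2] rather than shape[1].
key_states = torch.cat((cache_key, key_states), dim=2)
value_states = torch.cat((cache_value, value_states), dim=2)

kv_seq_len = q_len + past_key_value.shape[2]
assert key_states.shape == (bsz, num_kv_heads, kv_seq_len, head_dim)
assert value_states.shape == (bsz, num_kv_heads, kv_seq_len, head_dim)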