diff --git a/export/modeling_qwen2.py b/export/modeling_qwen2.py
index 53cdd8d..fcb6de9 100644
--- a/export/modeling_qwen2.py
+++ b/export/modeling_qwen2.py
@@ -682,27 +682,27 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
+        kv_seq_len = key_states.shape[-2] + past_key_value.shape[2]
+        # if past_key_value is not None:
             # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-            kv_seq_len += past_key_value.shape[1]
+            # kv_seq_len += past_key_value.shape[2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         output_cache = (key_states, value_states)
-        if past_key_value is not None:
-            # cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-            cache_key = past_key_value[
-                :,
-                self.layer_idx * 2 * self.num_key_value_heads: (self.layer_idx * 2 + 1) * self.num_key_value_heads
-            ]
-            cache_value = past_key_value[
-                :,
-                (self.layer_idx * 2 + 1) * self.num_key_value_heads: (self.layer_idx * 2 + 2) * self.num_key_value_heads
-            ]
-            key_states = torch.cat((cache_key, key_states), dim=2)
-            value_states = torch.cat((cache_value, value_states), dim=2)
+        # if past_key_value is not None:
+        #     cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        cache_key = past_key_value[
+            :,
+            self.layer_idx * 2 * self.num_key_value_heads: (self.layer_idx * 2 + 1) * self.num_key_value_heads
+        ]
+        cache_value = past_key_value[
+            :,
+            (self.layer_idx * 2 + 1) * self.num_key_value_heads: (self.layer_idx * 2 + 2) * self.num_key_value_heads
+        ]
+        key_states = torch.cat((cache_key, key_states), dim=2)
+        value_states = torch.cat((cache_value, value_states), dim=2)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
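
For context on the slicing arithmetic: the patched forward no longer treats past_key_value as a transformers Cache object but as one pre-built tensor in which every layer's key heads and then value heads are stacked along dim 1, with the past sequence length on dim 2; the data-dependent "if past_key_value is not None" branches are removed so the graph traces cleanly for export. Below is a minimal, self-contained sketch of that layout under these assumptions; the sizes (num_layers, past_len, etc.) are illustrative and not taken from the patch.

    import torch

    # Illustrative sizes only -- not taken from the patch.
    num_layers, num_key_value_heads, head_dim = 2, 4, 64
    bsz, past_len, q_len = 1, 10, 3

    # Flattened cache: each layer contributes a K block then a V block along
    # dim 1, so the whole cache is a single tensor of shape
    # (bsz, num_layers * 2 * num_key_value_heads, past_len, head_dim).
    past_key_value = torch.zeros(bsz, num_layers * 2 * num_key_value_heads, past_len, head_dim)

    layer_idx = 1
    # Same index arithmetic as the patch: this layer's K block, then V block.
    cache_key = past_key_value[:, layer_idx * 2 * num_key_value_heads:(layer_idx * 2 + 1) * num_key_value_heads]
    cache_value = past_key_value[:, (layer_idx * 2 + 1) * num_key_value_heads:(layer_idx * 2 + 2) * num_key_value_heads]

    # Fresh K/V for the current step, shaped (bsz, kv_heads, q_len, head_dim).
    key_states = torch.randn(bsz, num_key_value_heads, q_len, head_dim)
    value_states = torch.randn(bsz, num_key_value_heads, q_len, head_dim)

    # Concatenate along the sequence dim (dim 2), as in the patch; this is why
    # kv_seq_len = key_states.shape[-2] + past_key_value.shape[2].
    key_states = torch.cat((cache_key, key_states), dim=2)
    value_states = torch.cat((cache_value, value_states), dim=2)
    assert key_states.shape == (bsz, num_key_value_heads, past_len + q_len, head_dim)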