Skip to content

Commit

Permalink
fixup a bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
Tlntin committed Oct 25, 2024
1 parent 1a39899 commit 6489df4
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions export/modeling_qwen2.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -682,27 +682,27 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len = key_states.shape[-2] + past_key_value.shape[2]
# if past_key_value is not None:
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len += past_key_value.shape[1]
# kv_seq_len += past_key_value.shape[2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
output_cache = (key_states, value_states)
if past_key_value is not None:
# cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
# key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cache_key = past_key_value[
:,
self.layer_idx * 2 * self.num_key_value_heads: (self.layer_idx * 2 + 1) * self.num_key_value_heads
]
cache_value = past_key_value[
:,
(self.layer_idx * 2 + 1) * self.num_key_value_heads: (self.layer_idx * 2 + 2) * self.num_key_value_heads
]
key_states = torch.cat((cache_key, key_states), dim=2)
value_states = torch.cat((cache_value, value_states), dim=2)
# if past_key_value is not None:
# cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
# key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cache_key = past_key_value[
:,
self.layer_idx * 2 * self.num_key_value_heads: (self.layer_idx * 2 + 1) * self.num_key_value_heads
]
cache_value = past_key_value[
:,
(self.layer_idx * 2 + 1) * self.num_key_value_heads: (self.layer_idx * 2 + 2) * self.num_key_value_heads
]
key_states = torch.cat((cache_key, key_states), dim=2)
value_states = torch.cat((cache_value, value_states), dim=2)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
Expand Down

0 comments on commit 6489df4

Please sign in to comment.