Commit

Fix mistral
michaelbenayoun committed Sep 30, 2024
1 parent 16b3f69 commit ca46e09
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions optimum/neuron/distributed/decoder_models.py
@@ -15,7 +15,6 @@
 """Classes related to `neuronx-distributed` to perform parallelism."""
 
 import math
-import warnings
 from typing import TYPE_CHECKING, Callable, Optional, Tuple
 
 import torch
@@ -724,13 +723,8 @@ def attention_forward(
     past_key_value: Optional[Cache] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    **kwargs,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if "padding_mask" in kwargs:
-        warnings.warn(
-            "Passing `padding_mask` is deprecated and removed since `transformers` v4.37. Please make sure to "
-            "use `attention_mask` instead.`"
-        )
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
     value_states = self.v_proj(hidden_states)
@@ -754,14 +748,9 @@ def attention_forward(
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
-        if self.layer_idx is None:
-            raise ValueError(
-                "The cache structure has changed since `transformers` v4.36. If you are using "
-                f"{self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to "
-                "initialize the attention class with a layer index."
-            )
-        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        kv_seq_len += cache_position[0]
+
+    cos, sin = self.rotary_emb(value_states, position_ids)
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
     if past_key_value is not None:
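
The patch drops the old `layer_idx`-based cache length lookup in favor of the `cache_position` tensor that newer `transformers` releases pass to attention modules. As a minimal sketch of the arithmetic (hypothetical shapes and values, not part of the commit): `cache_position` holds the absolute indices of the tokens being processed, so its first entry equals the number of tokens already stored in the KV cache, and adding it to the length of the incoming key states yields the total key/value sequence length.

import torch

# Hypothetical example: 5 tokens already cached, 3 new tokens arriving.
past_len, new_tokens = 5, 3
key_states = torch.randn(1, 8, new_tokens, 64)                  # (batch, kv_heads, seq, head_dim)
cache_position = torch.arange(past_len, past_len + new_tokens)  # tensor([5, 6, 7])

# Same arithmetic as the patched attention_forward above:
kv_seq_len = key_states.shape[-2]   # 3 new tokens
kv_seq_len += cache_position[0]     # + 5 already-cached tokens
print(int(kv_seq_len))              # 8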
