diff --git a/whisper/decoding.py b/whisper/decoding.py index 49485d009..c2d07fa75 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -732,6 +732,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]: # repeat text tensors by the group size, for beam search or best-of-n sampling tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) + audio_features = audio_features.repeat_interleave(self.n_group, dim=0).to(audio_features.device) # call the main sampling loop tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)