diff --git a/whisper/audio.py b/whisper/audio.py index cf6c66ad9..f28e0fd70 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -24,24 +24,20 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): """ - Open an audio file and read as mono waveform, resampling as necessary + Open an audio file and read as mono waveform, resampling as necessary. Parameters ---------- file: str - The audio file to open + The audio file to open. sr: int - The sample rate to resample the audio if necessary + The sample rate to resample the audio if necessary. Returns ------- A NumPy array containing the audio waveform, in float32 dtype. """ - - # This launches a subprocess to decode audio while down-mixing - # and resampling as necessary. Requires the ffmpeg CLI in PATH. - # fmt: off cmd = [ "ffmpeg", "-nostdin", @@ -53,7 +49,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): "-ar", str(sr), "-" ] - # fmt: on try: out = run(cmd, capture_output=True, check=True).stdout except CalledProcessError as e: @@ -65,6 +60,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): """ Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + + Parameters + ---------- + array: Union[np.ndarray, torch.Tensor] + The audio array to pad or trim. + + length: int + The desired length of the audio array. + + axis: int + The axis along which to pad or trim. + + Returns + ------- + A padded or trimmed array. """ if torch.is_tensor(array): if array.shape[axis] > length: @@ -91,14 +101,20 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): @lru_cache(maxsize=None) def mel_filters(device, n_mels: int) -> torch.Tensor: """ - load the mel filterbank matrix for projecting STFT into a Mel spectrogram. - Allows decoupling librosa dependency; saved using: - - np.savez_compressed( - "mel_filters.npz", - mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), - mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), - ) + Load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + + Parameters + ---------- + device: torch.device + The device to load the filters on. + + n_mels: int + The number of Mel-frequency filters. + + Returns + ------- + torch.Tensor + The Mel filterbank matrix. """ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" @@ -114,44 +130,48 @@ def log_mel_spectrogram( device: Optional[Union[str, torch.device]] = None, ): """ - Compute the log-Mel spectrogram of + Compute the log-Mel spectrogram of the audio. Parameters ---------- - audio: Union[str, np.ndarray, torch.Tensor], shape = (*) - The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + audio: Union[str, np.ndarray, torch.Tensor] + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz. n_mels: int - The number of Mel-frequency filters, only 80 is supported + The number of Mel-frequency filters. padding: int - Number of zero samples to pad to the right + Number of zero samples to pad to the right. device: Optional[Union[str, torch.device]] - If given, the audio tensor is moved to this device before STFT + If given, the audio tensor is moved to this device before STFT. Returns ------- - torch.Tensor, shape = (80, n_frames) - A Tensor that contains the Mel spectrogram + torch.Tensor + A Tensor that contains the Mel spectrogram. """ - if not torch.is_tensor(audio): - if isinstance(audio, str): - audio = load_audio(audio) - audio = torch.from_numpy(audio) - - if device is not None: - audio = audio.to(device) - if padding > 0: - audio = F.pad(audio, (0, padding)) - window = torch.hann_window(N_FFT).to(audio.device) - stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) - magnitudes = stft[..., :-1].abs() ** 2 - - filters = mel_filters(audio.device, n_mels) - mel_spec = filters @ magnitudes - - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - return log_spec + try: + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + except Exception as e: + print(f"Error computing log-mel spectrogram: {e}") + return None diff --git a/whisper/decoding.py b/whisper/decoding.py index 49485d009..1e870a3ce 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -125,6 +125,7 @@ class DecodingResult: no_speech_prob: float = np.nan temperature: float = np.nan compression_ratio: float = np.nan + tokens_probs: list[float] = field(default_factory=list) class Inference: @@ -218,8 +219,8 @@ def reset(self): """Initialize any stateful variables for decoding a new sequence""" def update( - self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor - ) -> Tuple[Tensor, bool]: + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list[list[float]] + ) -> Tuple[Tensor, list, bool]: """Specify how to select the next token, based on the current trace and logits Parameters @@ -275,8 +276,8 @@ def __init__(self, temperature: float, eot: int): self.eot = eot def update( - self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor - ) -> Tuple[Tensor, bool]: + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list + ) -> Tuple[Tensor, list, bool]: if self.temperature == 0: next_tokens = logits.argmax(dim=-1) else: @@ -284,18 +285,25 @@ def update( logprobs = F.log_softmax(logits.float(), dim=-1) current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] + current_probs = torch.exp(current_logprobs) sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) + tokens_probs = [t_p + [c_p.item()] for t_p, c_p in zip(tokens_probs, current_probs)] next_tokens[tokens[:, -1] == self.eot] = self.eot tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) completed = (tokens[:, -1] == self.eot).all() - return tokens, completed + + return tokens, tokens_probs, completed - def finalize(self, tokens: Tensor, sum_logprobs: Tensor): - # make sure each sequence has at least one EOT token at the end + def finalize( + self, tokens: Tensor, tokens_probs: list, sum_logprobs: Tensor + ) -> Tuple[Tensor, list, list]: tokens = F.pad(tokens, (0, 1), value=self.eot) - return tokens, sum_logprobs.tolist() + tokens_probs = [[ t + [1.0] for t in s] for s in tokens_probs] + return tokens, tokens_probs, sum_logprobs.tolist() + + class BeamSearchDecoder(TokenDecoder): @@ -321,37 +329,39 @@ def reset(self): self.finished_sequences = None def update( - self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor - ) -> Tuple[Tensor, bool]: + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list + ) -> Tuple[Tensor, list, bool]: if tokens.shape[0] % self.beam_size != 0: raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") n_audio = tokens.shape[0] // self.beam_size - if self.finished_sequences is None: # for the first update + if self.finished_sequences is None: self.finished_sequences = [{} for _ in range(n_audio)] logprobs = F.log_softmax(logits.float(), dim=-1) next_tokens, source_indices, finished_sequences = [], [], [] for i in range(n_audio): - scores, sources, finished = {}, {}, {} + scores, sources, finished, probs = {}, {}, {}, {} - # STEP 1: calculate the cumulative log probabilities for possible candidates for j in range(self.beam_size): idx = i * self.beam_size + j prefix = tokens[idx].tolist() for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)): + prob = torch.exp(logprob).item() new_logprob = (sum_logprobs[idx] + logprob).item() sequence = tuple(prefix + [token.item()]) scores[sequence] = new_logprob sources[sequence] = idx - # STEP 2: rank the candidates and keep the top beam_size sequences for each audio + probs[sequence] = tokens_probs[idx] + [prob] + saved = 0 for sequence in sorted(scores, key=scores.get, reverse=True): if sequence[-1] == self.eot: - finished[sequence] = scores[sequence] + finished[sequence] = (scores[sequence], probs[sequence]) else: sum_logprobs[len(next_tokens)] = scores[sequence] + tokens_probs[len(next_tokens)] = probs[sequence] next_tokens.append(sequence) source_indices.append(sources[sequence]) @@ -364,44 +374,42 @@ def update( tokens = torch.tensor(next_tokens, device=tokens.device) self.inference.rearrange_kv_cache(source_indices) - # add newly finished sequences to self.finished_sequences assert len(self.finished_sequences) == len(finished_sequences) - for previously_finished, newly_finished in zip( - self.finished_sequences, finished_sequences - ): + for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences): for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): if len(previously_finished) >= self.max_candidates: - break # the candidate list is full + break previously_finished[seq] = newly_finished[seq] - # mark as completed if all audio has enough number of samples completed = all( - len(sequences) >= self.max_candidates - for sequences in self.finished_sequences + len(sequences) >= self.max_candidates for sequences in self.finished_sequences ) - return tokens, completed - def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): - # collect all finished sequences, including patience, and add unfinished ones if not enough + return tokens, tokens_probs, completed + + def finalize( + self, preceding_tokens: Tensor, preceding_tokens_prob: list, sum_logprobs: Tensor + ) -> Tuple[list, list, list]: sum_logprobs = sum_logprobs.cpu() for i, sequences in enumerate(self.finished_sequences): - if ( - len(sequences) < self.beam_size - ): # when not enough sequences are finished + if len(sequences) < self.beam_size: for j in list(np.argsort(sum_logprobs[i]))[::-1]: sequence = preceding_tokens[i, j].tolist() + [self.eot] - sequences[tuple(sequence)] = sum_logprobs[i][j].item() + sequences[tuple(sequence)] = (sum_logprobs[i][j].item(), preceding_tokens_prob[i][j] + [1.0]) if len(sequences) >= self.beam_size: break tokens: List[List[Tensor]] = [ - [torch.tensor(seq) for seq in sequences.keys()] - for sequences in self.finished_sequences + [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences ] sum_logprobs: List[List[float]] = [ - list(sequences.values()) for sequences in self.finished_sequences + [v[0] for v in sequences.values()] for sequences in self.finished_sequences + ] + tokens_probs: list[list[list[float]]] = [ + [v[1] for v in sequences.values()] for sequences in self.finished_sequences ] - return tokens, sum_logprobs + + return tokens, tokens_probs, sum_logprobs class LogitFilter: @@ -700,7 +708,8 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor): logit_filter.apply(logits, tokens) # expand the tokens tensor with the selected next tokens - tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) + tokens, tokens_probs, completed = self.decoder.update(tokens, logits, sum_logprobs, tokens_probs) + if completed or tokens.shape[-1] > self.n_ctx: break @@ -734,7 +743,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]: tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) # call the main sampling loop - tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) + tokens, sum_logprobs, no_speech_probs, tokens_probs = self._main_loop(audio_features, tokens) # reshape the tensors to have (n_audio, n_group) as the first two dimensions audio_features = audio_features[:: self.n_group] @@ -747,8 +756,10 @@ def run(self, mel: Tensor) -> List[DecodingResult]: # get the final candidates for each group, and slice between the first sampled token and EOT tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) tokens: List[List[Tensor]] = [ - [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] - for s in tokens + [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens + ] + tokens_probs: list[list[list[float]]] = [ + [probs[:tokens.shape[0]]for probs, tokens in zip(s, t)] for s, t in zip(tokens_probs, tokens) ] # select the top-ranked sample in each group