Main #2281

Open · wants to merge 3 commits into base: main
108 changes: 64 additions & 44 deletions whisper/audio.py
@@ -24,24 +24,20 @@

def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Open an audio file and read as mono waveform, resampling as necessary.

Parameters
----------
file: str
The audio file to open
The audio file to open.

sr: int
The sample rate to resample the audio if necessary
The sample rate to resample the audio to, if necessary.

Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""

# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
@@ -53,7 +49,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
"-ar", str(sr),
"-"
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
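A minimal usage sketch of the decode path above (it assumes the ffmpeg CLI is on PATH; "speech.wav" is a placeholder file name):

    import numpy as np
    from whisper.audio import SAMPLE_RATE, load_audio

    # Decodes via ffmpeg, down-mixes to mono, and resamples to 16 kHz.
    audio = load_audio("speech.wav", sr=SAMPLE_RATE)
    assert audio.dtype == np.float32
    print(f"{audio.shape[0] / SAMPLE_RATE:.1f} s of audio")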
@@ -65,6 +60,21 @@
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.

Parameters
----------
array: Union[np.ndarray, torch.Tensor]
The audio array to pad or trim.

length: int
The desired length of the audio array.

axis: int
The axis along which to pad or trim.

Returns
-------
A padded or trimmed array.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
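The expanded docstring above documents existing behavior; a small sketch of that behavior (N_SAMPLES is 30 s at 16 kHz, i.e. 480000 samples, in this module):

    import torch
    from whisper.audio import N_SAMPLES, SAMPLE_RATE, pad_or_trim

    short = torch.zeros(SAMPLE_RATE)           # 1 s of silence
    long = torch.zeros(2 * N_SAMPLES)          # 60 s of silence
    assert pad_or_trim(short).shape[-1] == N_SAMPLES   # zero-padded to 30 s
    assert pad_or_trim(long).shape[-1] == N_SAMPLES    # trimmed to 30 s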
@@ -91,14 +101,20 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:

np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.

Parameters
----------
device: torch.device
The device to load the filters on.

n_mels: int
The number of Mel-frequency filters.

Returns
-------
torch.Tensor
The Mel filterbank matrix.
"""
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
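For orientation, a quick sketch of what the cached filterbank looks like (shapes follow from N_FFT = 400 in this module):

    import torch
    from whisper.audio import N_FFT, mel_filters

    filters = mel_filters(torch.device("cpu"), n_mels=80)
    print(filters.shape)   # (80, N_FFT // 2 + 1) == (80, 201)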

@@ -114,44 +130,48 @@ def log_mel_spectrogram(
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of
Compute the log-Mel spectrogram of the audio.

Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
audio: Union[str, np.ndarray, torch.Tensor]
The path to an audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz.

n_mels: int
The number of Mel-frequency filters, only 80 is supported
The number of Mel-frequency filters.

padding: int
Number of zero samples to pad to the right
Number of zero samples to pad to the right.

device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
If given, the audio tensor is moved to this device before STFT.

Returns
-------
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
torch.Tensor
A Tensor that contains the Mel spectrogram.
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)

if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
try:
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)

if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
except Exception as e:
print(f"Error computing log-mel spectrogram: {e}")
return None
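Because the function now returns None instead of propagating the exception, callers have to check the result explicitly; a minimal sketch of the changed contract ("speech.wav" is a placeholder path):

    from whisper.audio import log_mel_spectrogram

    mel = log_mel_spectrogram("speech.wav", n_mels=80)
    if mel is None:
        raise RuntimeError("could not compute the log-Mel spectrogram")
    print(mel.shape)   # (80, n_frames)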
87 changes: 49 additions & 38 deletions whisper/decoding.py
Expand Up @@ -125,6 +125,7 @@ class DecodingResult:
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan
tokens_probs: list[float] = field(default_factory=list)


class Inference:
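With the new tokens_probs field, per-token probabilities can travel alongside the decoded tokens. A sketch of how a caller might read them, assuming the rest of this PR wires the field through DecodingTask ("speech.wav" and the "base" model are placeholders):

    import whisper

    model = whisper.load_model("base")
    audio = whisper.pad_or_trim(whisper.load_audio("speech.wav"))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    result = whisper.decode(model, mel, whisper.DecodingOptions())
    # tokens and tokens_probs are expected to line up one-to-one
    for token, prob in zip(result.tokens, result.tokens_probs):
        print(token, round(prob, 3))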
@@ -218,8 +219,8 @@ def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""

def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list[list[float]]
) -> Tuple[Tensor, list, bool]:
"""Specify how to select the next token, based on the current trace and logits

Parameters
@@ -275,27 +276,34 @@ def __init__(self, temperature: float, eot: int):
self.eot = eot

def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list
) -> Tuple[Tensor, list, bool]:
if self.temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / self.temperature).sample()

logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
current_probs = torch.exp(current_logprobs)
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

tokens_probs = [t_p + [c_p.item()] for t_p, c_p in zip(tokens_probs, current_probs)]
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

completed = (tokens[:, -1] == self.eot).all()
return tokens, completed

return tokens, tokens_probs, completed

def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
# make sure each sequence has at least one EOT token at the end
def finalize(
self, tokens: Tensor, tokens_probs: list, sum_logprobs: Tensor
) -> Tuple[Tensor, list, list]:
tokens = F.pad(tokens, (0, 1), value=self.eot)
return tokens, sum_logprobs.tolist()
tokens_probs = [[t + [1.0] for t in s] for s in tokens_probs]
return tokens, tokens_probs, sum_logprobs.tolist()
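The probability recorded at each greedy step is simply exp of the chosen token's log-softmax score, and finalize appends 1.0 for the EOT padding token. A standalone sketch of that bookkeeping (shapes and the vocabulary size are illustrative):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 51865)                       # (n_sequences, vocab_size)
    next_tokens = logits.argmax(dim=-1)                  # greedy pick at temperature 0
    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_probs = logprobs[torch.arange(2), next_tokens].exp()
    tokens_probs = [[p.item()] for p in current_probs]   # one growing list per sequence
    print(tokens_probs)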




class BeamSearchDecoder(TokenDecoder):
@@ -321,37 +329,39 @@ def reset(self):
self.finished_sequences = None

def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, tokens_probs: list
) -> Tuple[Tensor, list, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None: # for the first update
if self.finished_sequences is None:
self.finished_sequences = [{} for _ in range(n_audio)]

logprobs = F.log_softmax(logits.float(), dim=-1)
next_tokens, source_indices, finished_sequences = [], [], []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
scores, sources, finished, probs = {}, {}, {}, {}

# STEP 1: calculate the cumulative log probabilities for possible candidates
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens[idx].tolist()
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
prob = torch.exp(logprob).item()
new_logprob = (sum_logprobs[idx] + logprob).item()
sequence = tuple(prefix + [token.item()])
scores[sequence] = new_logprob
sources[sequence] = idx

# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
probs[sequence] = tokens_probs[idx] + [prob]

saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
finished[sequence] = (scores[sequence], probs[sequence])
else:
sum_logprobs[len(next_tokens)] = scores[sequence]
tokens_probs[len(next_tokens)] = probs[sequence]
next_tokens.append(sequence)
source_indices.append(sources[sequence])

@@ -364,44 +374,42 @@ def update(
tokens = torch.tensor(next_tokens, device=tokens.device)
self.inference.rearrange_kv_cache(source_indices)

# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
break
previously_finished[seq] = newly_finished[seq]

# mark as completed if all audio has enough number of samples
completed = all(
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
)
return tokens, completed

def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
# collect all finished sequences, including patience, and add unfinished ones if not enough
return tokens, tokens_probs, completed

def finalize(
self, preceding_tokens: Tensor, preceding_tokens_prob: list, sum_logprobs: Tensor
) -> Tuple[list, list, list]:
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if (
len(sequences) < self.beam_size
): # when not enough sequences are finished
if len(sequences) < self.beam_size:
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
sequences[tuple(sequence)] = (sum_logprobs[i][j].item(), preceding_tokens_prob[i][j] + [1.0])
if len(sequences) >= self.beam_size:
break

tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()]
for sequences in self.finished_sequences
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
[v[0] for v in sequences.values()] for sequences in self.finished_sequences
]
tokens_probs: list[list[list[float]]] = [
[v[1] for v in sequences.values()] for sequences in self.finished_sequences
]
return tokens, sum_logprobs

return tokens, tokens_probs, sum_logprobs
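After this change, each entry of finished_sequences maps a candidate sequence to a (score, per-token probabilities) tuple, so the two lists built above stay index-aligned. A toy illustration of the unpacking (all values are made up):

    finished = {
        (50258, 123, 50257): (-4.2, [0.91, 0.88, 1.0]),
        (50258, 456, 50257): (-5.0, [0.80, 0.75, 1.0]),
    }
    sum_logprobs = [v[0] for v in finished.values()]   # candidate scores
    tokens_probs = [v[1] for v in finished.values()]   # per-token probabilities
    assert len(sum_logprobs) == len(tokens_probs) == len(finished)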


class LogitFilter:
@@ -700,7 +708,8 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
logit_filter.apply(logits, tokens)

# expand the tokens tensor with the selected next tokens
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
tokens, tokens_probs, completed = self.decoder.update(tokens, logits, sum_logprobs, tokens_probs)


if completed or tokens.shape[-1] > self.n_ctx:
break
@@ -734,7 +743,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
tokens, sum_logprobs, no_speech_probs, tokens_probs = self._main_loop(audio_features, tokens)

# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
@@ -747,8 +756,10 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
for s in tokens
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
]
tokens_probs: list[list[list[float]]] = [
[probs[: tokens.shape[0]] for probs, tokens in zip(s, t)] for s, t in zip(tokens_probs, tokens)
]
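The slice keeps one probability per retained token: each candidate's tokens were just cut down to the span between the first sampled token and EOT, so trimming the probability list to tokens.shape[0] drops the trailing entries recorded for EOT/padding. A toy check of that invariant (values are made up):

    import torch

    content = torch.tensor([123, 456, 789])   # tokens after slicing off the prompt and EOT
    probs = [0.9, 0.8, 0.7, 1.0]              # one probability per sampled token, EOT last
    assert len(probs[: content.shape[0]]) == content.shape[0]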

# select the top-ranked sample in each group