diff --git a/README.md b/README.md index 1a661d781..696869c1e 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,7 @@ audio = whisper.load_audio("audio.mp3") audio = whisper.pad_or_trim(audio) # make log-Mel spectrogram and move to the same device as the model -mel = whisper.log_mel_spectrogram(audio).to(model.device) +mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device) # detect the spoken language _, probs = model.detect_language(mel) diff --git a/whisper/audio.py b/whisper/audio.py index cf6c66ad9..826250f37 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -122,7 +122,7 @@ def log_mel_spectrogram( The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz n_mels: int - The number of Mel-frequency filters, only 80 is supported + The number of Mel-frequency filters, only 80 and 128 are supported padding: int Number of zero samples to pad to the right @@ -132,7 +132,7 @@ def log_mel_spectrogram( Returns ------- - torch.Tensor, shape = (80, n_frames) + torch.Tensor, shape = (n_mels, n_frames) A Tensor that contains the Mel spectrogram """ if not torch.is_tensor(audio):