diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 1c075a201..d34152832 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -130,7 +130,9 @@ def transcribe( decode_options["fp16"] = False # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) + mel = log_mel_spectrogram( + audio, model.dims.n_mels, padding=N_SAMPLES, device=model.device + ) content_frames = mel.shape[-1] - N_FRAMES content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)