From 2448c6f749670963c473656e339c63175a6662ba Mon Sep 17 00:00:00 2001 From: take0x <89313929+take0x@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:38:39 +0900 Subject: [PATCH 1/2] Transcribe on GPU --- whisper/transcribe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 1c075a201..d34152832 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -130,7 +130,9 @@ def transcribe( decode_options["fp16"] = False # Pad 30-seconds of silence to the input audio, for slicing - mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES) + mel = log_mel_spectrogram( + audio, model.dims.n_mels, padding=N_SAMPLES, device=model.device + ) content_frames = mel.shape[-1] - N_FRAMES content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE) From c1031a5787e7f21b789e9b84309d443d2fc7188a Mon Sep 17 00:00:00 2001 From: take0x <89313929+take0x@users.noreply.github.com> Date: Mon, 23 Sep 2024 08:06:27 +0900 Subject: [PATCH 2/2] Add mel_spectrogram_device parameter --- whisper/transcribe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index d34152832..d3a6283ab 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -51,6 +51,7 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + mel_spectrogram_device: Optional[Union[str, torch.device]] = None, **decode_options, ): """ @@ -113,6 +114,9 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + mel_spectrogram_device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and @@ -131,7 +135,7 @@ def transcribe( # Pad 30-seconds of silence to the input audio, for slicing mel = log_mel_spectrogram( - audio, model.dims.n_mels, padding=N_SAMPLES, device=model.device + audio, model.dims.n_mels, padding=N_SAMPLES, device=mel_spectrogram_device ) content_frames = mel.shape[-1] - N_FRAMES content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)