[FEAT][faster whisperx]
Kye committed Mar 14, 2024
1 parent 66bd097 commit b937a3b
Showing 6 changed files with 83 additions and 10 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -17,4 +17,5 @@ bitsandbytes
peft
accelerate
huggingface_hub[hf_transfer]
-huggingface-hub
+huggingface-hub
+whisperx
4 changes: 3 additions & 1 deletion servers/whisperx/download_mp3.py
@@ -1,21 +1,23 @@
import os
from pytube import YouTube


def download_youtube_audio(video_url, output_folder):
    try:
        yt = YouTube(video_url)
        audio_stream = yt.streams.filter(only_audio=True).first()
        if audio_stream:
            output_file = audio_stream.download(output_path=output_folder)
            # Rename the file to have the .mp3 extension
-            mp3_file = output_file.split('.')[0] + '.mp3'
+            mp3_file = output_file.split(".")[0] + ".mp3"
            os.rename(output_file, mp3_file)
            print("Audio downloaded successfully:", mp3_file)
        else:
            print("No audio stream available for the given URL.")
    except Exception as e:
        print("Error:", e)


# Example usage
video_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"  # Example YouTube video URL
output_folder = "youtube_downloads"  # Output folder where the MP3 will be saved
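The tail of the example-usage block is collapsed in this hunk, so the actual call is not shown. A minimal, hypothetical completion (not part of the diff) that passes the variables above to the downloader would be:

# Hypothetical usage sketch (not shown in the collapsed part of the diff):
# download the example video's audio into the output folder defined above.
if __name__ == "__main__":
    download_youtube_audio(video_url, output_folder)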
4 changes: 1 addition & 3 deletions servers/whisperx/example_api.py
@@ -15,9 +15,7 @@ class WhisperTranscription(BaseModel):


# Construct the request data
-request_data = WhisperTranscription(
-    file = "song.mp3"
-)
+request_data = WhisperTranscription(file="song.mp3")

# Specify the URL of your FastAPI application
url = "https://localhost:8000/v1/audio/transcriptions"
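The rest of example_api.py is collapsed above, so the request call itself is not visible. A hedged sketch of how the payload might be posted with the requests library follows; it assumes the WhisperTranscription model has only a file field and that pydantic v2 is in use (model_dump()), neither of which is confirmed by the shown diff:

import requests
from pydantic import BaseModel


class WhisperTranscription(BaseModel):
    file: str


# Hypothetical sketch: POST the transcription request to the FastAPI endpoint above.
request_data = WhisperTranscription(file="song.mp3")
url = "https://localhost:8000/v1/audio/transcriptions"

response = requests.post(url, json=request_data.model_dump())  # use .dict() on pydantic v1
print(response.status_code, response.json())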
74 changes: 74 additions & 0 deletions servers/whisperx/whisper_model.py
@@ -0,0 +1,74 @@
import logging

from faster_whisper import WhisperModel


class FasterWhisperTranscriber:
    def __init__(
        self,
        model_size="large-v3",
        device="cuda",
        compute_type="float16",
        model_type="faster-whisper",
        **kwargs,
    ):
        """
        Initialize the WhisperModel with the specified configuration.

        :param model_size: Size of the Whisper model (e.g., 'large-v3', 'distil-large-v2')
        :param device: Computation device ('cuda' or 'cpu')
        :param compute_type: Type of computation ('float16', 'int8_float16', 'int8')
        :param model_type: Type of model ('faster-whisper' or 'faster-distil-whisper')
        :param kwargs: Additional arguments for the WhisperModel transcribe method
        """
        self.model_size = model_size
        self.device = device
        self.compute_type = compute_type
        self.model_type = model_type
        self.transcribe_options = kwargs
        self.model = WhisperModel(
            self.model_size, device=self.device, compute_type=self.compute_type
        )

    def run(self, task: str, *args, **kwargs):
        """
        Transcribe the given audio file using the Whisper model.

        :param task: Path to the audio file to be transcribed
        :return: Transcription results
        """
        segments, info = self.model.transcribe(
            task, **self.transcribe_options
        )

        # Print language-detection information
        print(
            f"Detected language '{info.language}' with probability {info.language_probability:.2f}"
        )

        # Handle transcription output based on the model type
        if self.model_type == "faster-whisper":
            for segment in segments:
                print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
        elif (
            self.model_type == "faster-distil-whisper"
            and "word_timestamps" in self.transcribe_options
            and self.transcribe_options["word_timestamps"]
        ):
            for segment in segments:
                for word in segment.words:
                    print(f"[{word.start:.2f}s -> {word.end:.2f}s] {word.word}")
        else:
            for segment in segments:
                print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")


# Example usage
if __name__ == "__main__":
    logging.basicConfig()
    logging.getLogger("faster_whisper").setLevel(logging.DEBUG)

    # Example for faster-whisper with GPU and FP16
    transcriber = FasterWhisperTranscriber(
        model_size="large-v3", device="cuda", compute_type="float16", beam_size=5
    )
    transcriber.run("song.mp3")
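The faster-distil-whisper branch of run() only prints word-level timings when word_timestamps is passed through to transcribe(). A hedged companion sketch that exercises that branch (assuming the distil-large-v2 weights are available to faster-whisper on your GPU; this call is not part of the committed file) could be:

# Hypothetical sketch: distil model with word timestamps, so the
# "faster-distil-whisper" branch of run() prints per-word timings.
distil_transcriber = FasterWhisperTranscriber(
    model_size="distil-large-v2",
    device="cuda",
    compute_type="float16",
    model_type="faster-distil-whisper",
    beam_size=5,
    word_timestamps=True,
    language="en",  # distil-large-v2 is an English-only model
)
distil_transcriber.run("song.mp3")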
3 changes: 1 addition & 2 deletions servers/whisperx/whisperx.py
@@ -17,7 +17,6 @@
hf_token = os.getenv("HF_TOKEN")



class WhisperTranscriber:
"""
A class for transcribing audio using the Whisper ASR system.
@@ -285,7 +284,7 @@ async def create_audio_completion(request: WhisperTranscription):
    # Log the entry into supabase

    transcriber = WhisperTranscriber(
-        device = "cuda" if torch.cuda.is_available() else "cpu",
+        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Run the audio processing pipeline
5 changes: 2 additions & 3 deletions servers/whisperx/whisperx_no_api.py
@@ -4,7 +4,7 @@

import torch
import whisperx
-import os
+import os

from dotenv import load_dotenv

@@ -157,8 +157,7 @@ def run(self, audio_file: str):


# Instantiate the WhisperTranscriber
-model = WhisperTranscriber(
-)
+model = WhisperTranscriber()

# Run the audio processing pipeline
result = model.run("song.mp3")
