Skip to content

Commit

Permalink
Merge pull request #78 from kadirnar/add-test-file
Browse files Browse the repository at this point in the history
πŸ„β€β™‚οΈ Add test code and Runpod platform support
  • Loading branch information
kadirnar authored May 4, 2024
2 parents 3b41a84 + d9e6c3e commit 317bdc9
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 25 deletions.
26 changes: 14 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,9 @@ from transformers import BitsAndBytesConfig, HqqConfig
url = "https://www.youtube.com/watch?v=di3rHkEZuUw"
audio_path = download_and_convert_to_mp3(url)

quant_config = HqqConfig(
nbits=1,
group_size=64,
quant_zero=False,
quant_scale=False, axis=0) #axis=0 is used by default
quant_config = HqqConfig(
nbits=1, group_size=64, quant_zero=False, quant_scale=False, axis=0
) # axis=0 is used by default


bnb_config = BitsAndBytesConfig(
Expand All @@ -53,14 +51,18 @@ bnb_config = BitsAndBytesConfig(
bnb_4bit_use_double_quant=True,
)

pipeline = SpeechToTextPipeline(model_id=distil-whisper/distil-large-v3, quant_config=quant_config) # or bnb_config
pipeline = SpeechToTextPipeline(
model_id="distil-whisper/distil-large-v3", quant_config=quant_config
) # or bnb_config

transcript = pipeline(
audio_path: str = "test.mp3",
chunk_length_s: int = 30,
stride_length_s: int = 5,
max_new_tokens: int = 128,
batch_size: int = 100,
language: str = "english",
audio_path="test.mp3",
chunk_length_s=30,
stride_length_s=5,
max_new_tokens=128,
batch_size=100,
language="english",
return_timestamps=False,
)

print(transcript)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ pre-commit==3.4.0
autollm>=0.1.9
speechbrain>=0.5.16
bitsandbytes
hqq
2 changes: 2 additions & 0 deletions scripts/runpod.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
apt-get update
apt install ffmpeg nvtop -y
21 changes: 8 additions & 13 deletions whisperplus/pipelines/whisper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
Expand All @@ -8,7 +9,7 @@
class SpeechToTextPipeline:
"""Class for converting audio to text using a pre-trained speech recognition model."""

def __init__(self, model_id: str = "openai/whisper-large-v3", quant_config=None):
def __init__(self, model_id: str = "distil-whisper/distil-large-v3", quant_config=None):
self.model = None
self.device = None

Expand All @@ -17,21 +18,16 @@ def __init__(self, model_id: str = "openai/whisper-large-v3", quant_config=None)
else:
logging.info("Model already loaded.")

def load_model(self, model_id: str = "openai/whisper-large-v3", quant_config=None):
"""
Loads the pre-trained speech recognition model and moves it to the specified device.
Args:
model_id (str): Identifier of the pre-trained model to be loaded.
"""
logging.info("Loading model...")
def load_model(self, model_id: str = "distil-whisper/distil-large-v3", quant_config=None):
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
quantization_config=quant_config,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="flash_attention_2",
)
torch_dtype=torch.float16,
device_map='auto',
max_memory={0: "24GiB"})
logging.info("Model loaded successfully.")

processor = AutoProcessor.from_pretrained(model_id)
Expand All @@ -41,14 +37,13 @@ def load_model(self, model_id: str = "openai/whisper-large-v3", quant_config=Non

def __call__(
self,
audio_path: str = "test.mp3",
chunk_length_s: int = 30,
stride_length_s: int = 5,
audio_path: str = "test.mp3",
max_new_tokens: int = 128,
batch_size: int = 100,
language: str = "turkish",
return_timestamps: bool = False
):
return_timestamps: bool = False):
"""
Converts audio to text using the pre-trained speech recognition model.
Expand Down
39 changes: 39 additions & 0 deletions whisperplus/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import time

import torch
from pipelines.whisper import SpeechToTextPipeline
from transformers import BitsAndBytesConfig, HqqConfig
from utils.download_utils import download_and_convert_to_mp3

url = "https://www.youtube.com/watch?v=di3rHkEZuUw"
audio_path = download_and_convert_to_mp3(url)

hqq_config = HqqConfig(
nbits=1, group_size=64, quant_zero=False, quant_scale=False, axis=0) # axis=0 is used by default

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
model = SpeechToTextPipeline(
model_id="distil-whisper/distil-large-v3", quant_config=bnb_config) # or bnb_config

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
transcript = model(
audio_path="testv0.mp3",
chunk_length_s=30,
stride_length_s=5,
max_new_tokens=128,
batch_size=100,
language="english",
return_timestamps=False)
end_event.record()

torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"Execution time: {elapsed_time_ms}ms")

0 comments on commit 317bdc9

Please sign in to comment.