
Commit

Merge pull request #84 from kadirnar/wer
Add optional parameter and wer metric code
kadirnar authored May 5, 2024
2 parents 487bfa0 + 304280a commit e8cbada
Showing 4 changed files with 105 additions and 13 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -1,6 +1,6 @@
<div align="center">
<h2>
- WhisperPlus: Advancing Speech2Text and Text2Speech Feature 🚀
+ WhisperPlus: Faster, Smarter, and More Capable 🚀
</h2>
<div>
<img width="500" alt="teaser" src="doc\openai-whisper.jpg">
@@ -56,8 +56,11 @@ bnb_config = BitsAndBytesConfig(
)

pipeline = SpeechToTextPipeline(
model_id="distil-whisper/distil-large-v3", quant_config=hqq_config
) # or bnb_config
model_id="distil-whisper/distil-large-v3",
quant_config=hqq_config,
hqq=True,
flash_attention_2=True,
)

transcript = pipeline(
audio_path=audio_path,
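The updated README example always passes hqq_config; the removed "# or bnb_config" hint is gone, so the bitsandbytes path is no longer spelled out. A minimal sketch of that alternative under the new signature (not part of this commit), assuming the bnb_config variable defined earlier in the README:

# Hypothetical variant of the README example: quantize with bitsandbytes instead of HQQ.
pipeline = SpeechToTextPipeline(
    model_id="distil-whisper/distil-large-v3",
    quant_config=bnb_config,  # bitsandbytes 4-bit config from the README, instead of hqq_config
    hqq=False,                # new optional flag: skip HQQ backend patching
    flash_attention_2=True,   # new optional flag: keep flash-attention loading
)

The transcription call itself is unchanged from the truncated example above.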
4 changes: 3 additions & 1 deletion whisperplus/__init__.py
@@ -1,3 +1,5 @@
+ from whisperplus.pipelines.whisper_autocaption import WhisperAutoCaptionPipeline
+
from whisperplus.pipelines.autollm_chatbot import AutoLLMChatWithVideo
from whisperplus.pipelines.long_text_summarization import LongTextSummarizationPipeline
from whisperplus.pipelines.summarization import TextSummarizationPipeline
@@ -7,6 +9,6 @@
from whisperplus.utils.download_utils import download_and_convert_to_mp3
from whisperplus.utils.text_utils import format_speech_to_dialogue

- __version__ = '0.3.0'
+ __version__ = '0.3.1'
__author__ = 'kadirnar'
__license__ = 'Apache License 2.0'
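Since WhisperAutoCaptionPipeline is now imported in __init__.py, it should be reachable from the package root after this release. A quick sanity check, assuming whisperplus 0.3.1 and its optional video dependencies are installed:

import whisperplus
from whisperplus import WhisperAutoCaptionPipeline  # re-exported by this commit

print(whisperplus.__version__)  # expected: '0.3.1'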
72 changes: 72 additions & 0 deletions whisperplus/audio_eval.py
@@ -0,0 +1,72 @@
import torch
from datasets import load_dataset
from evaluate import load
from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.utils.patching import prepare_for_inference
from tqdm import tqdm
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, BitsAndBytesConfig, HqqConfig, pipeline
from transformers.pipelines.pt_utils import KeyDataset

from whisperplus.pipelines.whisper import SpeechToTextPipeline

HQQLinear.set_backend(HQQBackend.PYTORCH) # Pytorch backend
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE) # Compiled Pytorch via dynamo
HQQLinear.set_backend(HQQBackend.ATEN) # C++ Aten/CUDA backend (set automatically by default if available)

model_id = "distil-whisper/distil-large-v3"

hqq_config = HqqConfig(
nbits=4,
group_size=64,
quant_zero=False,
quant_scale=False,
axis=0,
offload_meta=False,
) # axis=0 is used by default

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)

model = SpeechToTextPipeline(model_id="distil-whisper/distil-large-v3", quant_config=hqq_config)

model = model.model

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
"automatic-speech-recognition",
model=model,
torch_dtype=torch.bfloat16,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
model_kwargs={"use_flash_attention_2": True},
)

wer_metric = load("wer")

common_voice_test = load_dataset(
    "mozilla-foundation/common_voice_17_0",
    "dv",
    split="test",
)

all_predictions = []

# run streamed inference
for prediction in tqdm(
pipe(
KeyDataset(common_voice_test, "audio"),
max_new_tokens=128,
generate_kwargs={"task": "transcribe"},
batch_size=32,
),
total=len(common_voice_test),
):
all_predictions.append(prediction["text"])

wer_ortho = 100 * wer_metric.compute(references=common_voice_test["sentence"], predictions=all_predictions)

print(f"WER: {wer_ortho:.2f}%")
33 changes: 24 additions & 9 deletions whisperplus/pipelines/whisper.py
@@ -1,38 +1,53 @@
import logging
+ from typing import Optional

import torch
- from hqq.core.quantize import HQQBackend, HQQLinear
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

- HQQLinear.set_backend(HQQBackend.PYTORCH)  # Pytorch backend
- HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)  # Compiled Pytorch via dynamo
- HQQLinear.set_backend(HQQBackend.ATEN)  # C++ Aten/CUDA backend (set automatically by default if available)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class SpeechToTextPipeline:
"""Class for converting audio to text using a pre-trained speech recognition model."""

-     def __init__(self, model_id: str = "distil-whisper/distil-large-v3", quant_config=None):
+     def __init__(
+             self,
+             model_id: str = "distil-whisper/distil-large-v3",
+             quant_config=None,
+             hqq: Optional[bool] = True,
+             flash_attention_2: Optional[bool] = True):
self.model = None
self.device = None
+         self.hqq = hqq
+         self.flash_attention_2 = flash_attention_2

if self.model is None:
self.load_model(model_id)
else:
logging.info("Model already loaded.")

def load_model(self, model_id: str = "distil-whisper/distil-large-v3", quant_config=None):
+         if self.hqq:
+             from hqq.core.quantize import HQQBackend, HQQLinear
+             HQQLinear.set_backend(HQQBackend.PYTORCH)  # Pytorch backend
+             HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)  # Compiled Pytorch via dynamo
+             HQQLinear.set_backend(
+                 HQQBackend.ATEN)  # C++ Aten/CUDA backend (set automatically by default if available)
+
+         if self.flash_attention_2:
+             attn_implementation = "flash_attention_2"
+         else:
+             attn_implementation = "sdpa"
+
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
quantization_config=quant_config,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="flash_attention_2",
attn_implementation=attn_implementation,
torch_dtype=torch.bfloat16,
device_map='auto',
-             max_memory={0: "24GiB"})
+         )
logging.info("Model loaded successfully.")

processor = AutoProcessor.from_pretrained(model_id)
@@ -69,7 +84,7 @@ def __call__(
return_timestamps=return_timestamps,
tokenizer=self.processor.tokenizer,
feature_extractor=self.processor.feature_extractor,
model_kwargs={"use_flash_attention_2": True},
model_kwargs={"use_flash_attention_2": self.flash_attention_2},
generate_kwargs={"language": language},
)
logging.info("Transcribing audio...")
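Both new flags default to True, so existing callers keep the previous behaviour (HQQ backend patching plus flash_attention_2). A minimal sketch of the new fallback path for hardware without FlashAttention 2, with parameter names taken from the diff above; audio_path is assumed to point at a local audio file as in the README example:

from whisperplus.pipelines.whisper import SpeechToTextPipeline

# Skip HQQ backend patching and let load_model() pick attn_implementation="sdpa".
pipeline = SpeechToTextPipeline(
    model_id="distil-whisper/distil-large-v3",
    quant_config=None,  # or an HqqConfig / BitsAndBytesConfig to quantize
    hqq=False,
    flash_attention_2=False,
)

transcript = pipeline(
    audio_path=audio_path,
    language="english",
    return_timestamps=False,
)

Whether language and return_timestamps alone suffice depends on the full __call__ signature, which is truncated in this view.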
