Merge pull request #77 from kadirnar/add-hqq
🥇 Add hqq optimization algorithm
kadirnar authored May 3, 2024
2 parents c0d67f1 + cef7d20 commit 24a7998
Showing 4 changed files with 44 additions and 30 deletions.
31 changes: 27 additions & 4 deletions README.md
@@ -17,7 +17,7 @@
## 🛠️ Installation

```bash
-pip install whisperplus
+pip install whisperplus git+https://github.com/huggingface/transformers
pip install flash-attn --no-build-isolation
```
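`transformers` is installed from GitHub here presumably because `HqqConfig` had not yet reached a tagged PyPI release when this was merged. A quick sanity check that the installed build actually provides it (a one-liner sketch, not part of the commit):

```bash
python -c "from transformers import HqqConfig; print('HqqConfig available')"
```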

@@ -33,12 +33,35 @@ To use the whisperplus library, follow the steps below for different tasks:

```python
from whisperplus import SpeechToTextPipeline, download_and_convert_to_mp3
+from transformers import BitsAndBytesConfig, HqqConfig
+import torch  # needed for torch.bfloat16 below

url = "https://www.youtube.com/watch?v=di3rHkEZuUw"
audio_path = download_and_convert_to_mp3(url)
-pipeline = SpeechToTextPipeline(model_id="openai/whisper-large-v3")
-transcript = pipeline(audio_path, "openai/whisper-large-v3", "english")
+
+quant_config = HqqConfig(
+    nbits=1,
+    group_size=64,
+    quant_zero=False,
+    quant_scale=False,
+    axis=0,  # axis=0 is used by default
+)
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
+)
+
+pipeline = SpeechToTextPipeline(
+    model_id="distil-whisper/distil-large-v3", quant_config=quant_config)  # or bnb_config
+transcript = pipeline(
+    audio_path=audio_path,  # e.g. "test.mp3"
+    chunk_length_s=30,
+    stride_length_s=5,
+    max_new_tokens=128,
+    batch_size=100,
+    language="english",
+)

print(transcript)
```
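As the `whisperplus/pipelines/whisper.py` diff below shows, the pipeline simply forwards `quant_config` to `transformers`' `from_pretrained`, so switching between HQQ and bitsandbytes is just a matter of which config object you pass. For reference, a rough sketch of the equivalent direct `transformers` call (same settings as above; not part of this commit):

```python
from transformers import AutoModelForSpeechSeq2Seq, HqqConfig

quant_config = HqqConfig(nbits=1, group_size=64, quant_zero=False, quant_scale=False, axis=0)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "distil-whisper/distil-large-v3",
    quantization_config=quant_config,  # weights are quantized on the fly at load time
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
```

Note that `nbits=1` is an aggressive 1-bit setting; HQQ also supports 2-, 3-, 4-, and 8-bit widths if transcription quality suffers.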
27 changes: 13 additions & 14 deletions requirements.txt
@@ -1,22 +1,21 @@
torch>=2.0.0
torchvision>=0.15.0
torchaudio>=2.0.0
-gradio==4.14.0
-langchain==0.1.0
-moviepy==1.0.3
-numpy==1.24.1
-pyannote.audio==3.1.0
-pyannote.core==5.0.0
-pyannote.database==5.0.1
-pyannote.metrics==3.2.1
-pyannote.pipeline==3.0.1
-pytube==15.0.0
-Requests==2.31.0
-transformers==4.35.2
+gradio>=4.14.0
+langchain>=0.1.0
+moviepy>=1.0.3
+numpy>=1.24.1
+pyannote.audio>=3.1.0
+pyannote.core>=5.0.0
+pyannote.database>=5.0.1
+pyannote.metrics>=3.2.1
+pyannote.pipeline>=3.0.1
+pytube>=15.0.0
+Requests>=2.31.0
sentence-transformers
ctransformers
accelerate
pre-commit==3.4.0
-autollm==0.1.9
-speechbrain==0.5.16
+autollm>=0.1.9
+speechbrain>=0.5.16
bitsandbytes
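(The `transformers==4.35.2` pin is dropped with no versioned replacement, consistent with the README now installing transformers from GitHub to pick up HQQ support.)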
2 changes: 1 addition & 1 deletion whisperplus/__init__.py
@@ -7,6 +7,6 @@
from whisperplus.utils.download_utils import download_and_convert_to_mp3
from whisperplus.utils.text_utils import format_speech_to_dialogue

-__version__ = '0.2.7'
+__version__ = '0.2.7.2.dev1'
__author__ = 'kadirnar'
__license__ = 'Apache License 2.0'
14 changes: 3 additions & 11 deletions whisperplus/pipelines/whisper.py
@@ -1,7 +1,6 @@
import logging

-import torch
-from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, BitsAndBytesConfig, pipeline
+from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

@@ -18,24 +17,17 @@ def __init__(self, model_id: str = "openai/whisper-large-v3"):
        else:
            logging.info("Model already loaded.")

-    def load_model(self, model_id: str = "openai/whisper-large-v3"):
+    def load_model(self, model_id: str = "openai/whisper-large-v3", quant_config=None):
        """
        Loads the pre-trained speech recognition model and moves it to the specified device.

        Args:
            model_id (str): Identifier of the pre-trained model to be loaded.
        """
        logging.info("Loading model...")

-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.bfloat16,
-            bnb_4bit_use_double_quant=True,
-        )
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
-            quantization_config=bnb_config,
+            quantization_config=quant_config,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="flash_attention_2",
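The net effect of this change is that quantization becomes opt-in rather than always-on: previously every model load went through hard-coded 4-bit bitsandbytes quantization, while with `quant_config=None` (the new default) the model loads unquantized and callers choose HQQ or bitsandbytes explicitly. A sketch of the three modes, using the constructor signature from the README example above (the `quant_config` pass-through from `__init__` to `load_model` is not shown in this diff):

```python
from transformers import BitsAndBytesConfig, HqqConfig
from whisperplus import SpeechToTextPipeline

# 1) No quantization (new default: quant_config=None)
pipe = SpeechToTextPipeline(model_id="openai/whisper-large-v3")

# 2) HQQ quantization, applied on the fly at load time
pipe = SpeechToTextPipeline(
    model_id="openai/whisper-large-v3",
    quant_config=HqqConfig(nbits=4, group_size=64),
)

# 3) bitsandbytes 4-bit NF4 (the old hard-coded behaviour, now explicit)
pipe = SpeechToTextPipeline(
    model_id="openai/whisper-large-v3",
    quant_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4"),
)
```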
