Skip to content

Commit

Permalink
⭐Add BitsAndBytesConfig support
Browse files Browse the repository at this point in the history
  • Loading branch information
kadirnar authored May 2, 2024
1 parent e93c997 commit cea1508
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions whisperplus/pipelines/whisper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, BitsAndBytesConfig

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

Expand All @@ -26,13 +26,21 @@ def load_model(self, model_id: str = "openai/whisper-large-v3"):
model_id (str): Identifier of the pre-trained model to be loaded.
"""
logging.info("Loading model...")

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
quantization_config=bnb_config,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="flash_attention_2",
load_in_4bit=True,
device_map="auto")

logging.info("Model loaded successfully.")

processor = AutoProcessor.from_pretrained(model_id)
Expand All @@ -54,8 +62,9 @@ def __call__(self, audio_path: str, language: str = "turkish"):
"automatic-speech-recognition",
model=self.model,
chunk_length_s=30,
stride_length_s=5,
max_new_tokens=128,
batch_size=24,
batch_size=100,
device_map="auto",
return_timestamps=True,
tokenizer=self.processor.tokenizer,
Expand Down

0 comments on commit cea1508

Please sign in to comment.