diff --git a/scripts/runpod.sh b/scripts/runpod.sh old mode 100644 new mode 100755 diff --git a/whisperplus/pipelines/whisper.py b/whisperplus/pipelines/whisper.py index cd157d8..46d4abf 100644 --- a/whisperplus/pipelines/whisper.py +++ b/whisperplus/pipelines/whisper.py @@ -1,8 +1,13 @@ import logging import torch +from hqq.core.quantize import HQQBackend, HQQLinear from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline +HQQLinear.set_backend(HQQBackend.PYTORCH) # Pytorch backend +HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE) # Compiled Pytorch via dynamo +HQQLinear.set_backend(HQQBackend.ATEN) # C++ Aten/CUDA backend (set automatically by default if available) + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -25,7 +30,7 @@ def load_model(self, model_id: str = "distil-whisper/distil-large-v3", quant_con low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2", - torch_dtype=torch.float16, + torch_dtype=torch.bfloat16, device_map='auto', max_memory={0: "24GiB"}) logging.info("Model loaded successfully.") diff --git a/whisperplus/test.py b/whisperplus/test.py index b7eadbe..f5f3622 100644 --- a/whisperplus/test.py +++ b/whisperplus/test.py @@ -1,15 +1,20 @@ -import time - import torch +from hqq.utils.patching import prepare_for_inference from pipelines.whisper import SpeechToTextPipeline from transformers import BitsAndBytesConfig, HqqConfig from utils.download_utils import download_and_convert_to_mp3 -url = "https://www.youtube.com/watch?v=di3rHkEZuUw" +url = "https://www.youtube.com/watch?v=BpN4hEAvDBg" audio_path = download_and_convert_to_mp3(url) hqq_config = HqqConfig( - nbits=1, group_size=64, quant_zero=False, quant_scale=False, axis=0) # axis=0 is used by default + nbits=1, + group_size=64, + quant_zero=False, + quant_scale=False, + axis=0, + offload_meta=False, +) # axis=0 is used by default bnb_config = BitsAndBytesConfig( load_in_4bit=True, @@ -18,14 +23,14 @@ bnb_4bit_use_double_quant=True, ) model = SpeechToTextPipeline( - model_id="distil-whisper/distil-large-v3", quant_config=bnb_config) # or bnb_config + model_id="distil-whisper/distil-large-v3", quant_config=hqq_config) # or bnb_config start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() transcript = model( - audio_path="testv0.mp3", + audio_path=audio_path, chunk_length_s=30, stride_length_s=5, max_new_tokens=128, @@ -36,4 +41,5 @@ torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) -print(f"Execution time: {elapsed_time_ms}ms") +seconds = elapsed_time_ms / 1000 +print(f"Elapsed time: {seconds} seconds")