From be9d1744744627e23d7b526946448021ba8114f2 Mon Sep 17 00:00:00 2001 From: kadirnar Date: Sat, 4 May 2024 13:17:53 +0000 Subject: [PATCH 1/2] =?UTF-8?q?=F0=9F=92=AC=20Add=20new=20parameters=20for?= =?UTF-8?q?=20hqq=20optimization=20method?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/runpod.sh | 0 whisperplus/pipelines/whisper.py | 7 ++++++- whisperplus/test.py | 19 ++++++++++++------- 3 files changed, 18 insertions(+), 8 deletions(-) mode change 100644 => 100755 scripts/runpod.sh diff --git a/scripts/runpod.sh b/scripts/runpod.sh old mode 100644 new mode 100755 diff --git a/whisperplus/pipelines/whisper.py b/whisperplus/pipelines/whisper.py index cd157d8..46d4abf 100644 --- a/whisperplus/pipelines/whisper.py +++ b/whisperplus/pipelines/whisper.py @@ -1,8 +1,13 @@ import logging import torch +from hqq.core.quantize import HQQBackend, HQQLinear from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline +HQQLinear.set_backend(HQQBackend.PYTORCH) # Pytorch backend +HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE) # Compiled Pytorch via dynamo +HQQLinear.set_backend(HQQBackend.ATEN) # C++ Aten/CUDA backend (set automatically by default if available) + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -25,7 +30,7 @@ def load_model(self, model_id: str = "distil-whisper/distil-large-v3", quant_con low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2", - torch_dtype=torch.float16, + torch_dtype=torch.bfloat16, device_map='auto', max_memory={0: "24GiB"}) logging.info("Model loaded successfully.") diff --git a/whisperplus/test.py b/whisperplus/test.py index b7eadbe..7a5ef79 100644 --- a/whisperplus/test.py +++ b/whisperplus/test.py @@ -1,15 +1,20 @@ -import time - import torch +from hqq.utils.patching import prepare_for_inference from pipelines.whisper import SpeechToTextPipeline from transformers import BitsAndBytesConfig, HqqConfig from utils.download_utils import download_and_convert_to_mp3 -url = "https://www.youtube.com/watch?v=di3rHkEZuUw" +url = "https://www.youtube.com/watch?v=BpN4hEAvDBg" audio_path = download_and_convert_to_mp3(url) hqq_config = HqqConfig( - nbits=1, group_size=64, quant_zero=False, quant_scale=False, axis=0) # axis=0 is used by default + nbits=1, + group_size=64, + quant_zero=False, + quant_scale=False, + axis=0, + offload_meta=False, +) # axis=0 is used by default bnb_config = BitsAndBytesConfig( load_in_4bit=True, @@ -18,14 +23,14 @@ bnb_4bit_use_double_quant=True, ) model = SpeechToTextPipeline( - model_id="distil-whisper/distil-large-v3", quant_config=bnb_config) # or bnb_config + model_id="distil-whisper/distil-large-v3", quant_config=hqq_config) # or bnb_config start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() transcript = model( - audio_path="testv0.mp3", + audio_path=audio_path, chunk_length_s=30, stride_length_s=5, max_new_tokens=128, @@ -36,4 +41,4 @@ torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) -print(f"Execution time: {elapsed_time_ms}ms") +print(f"Execution time: {elapsed_time_ms/1000}s") From e5ceca7cb43196e0c766444f8d6a060250739b48 Mon Sep 17 00:00:00 2001 From: kadirnar Date: Sat, 4 May 2024 13:35:01 +0000 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=90=9B=20Fix=20flake8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- whisperplus/test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/whisperplus/test.py b/whisperplus/test.py index 7a5ef79..f5f3622 100644 --- a/whisperplus/test.py +++ b/whisperplus/test.py @@ -41,4 +41,5 @@ torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) -print(f"Execution time: {elapsed_time_ms/1000}s") +seconds = elapsed_time_ms / 1000 +print(f"Elapsed time: {seconds} seconds")