From 22f1f67eea63f2ed2d9424236b04069babfff5c6 Mon Sep 17 00:00:00 2001 From: kadirnar Date: Mon, 13 May 2024 02:38:50 +0300 Subject: [PATCH] Add 4bit and 8bit models for Lightning-Mlx library --- whisperplus/pipelines/lightning_whisper_mlx/lightning.py | 5 ++++- whisperplus/pipelines/mlx_whisper/__init__.py | 2 +- whisperplus/pipelines/mlx_whisper/requirements.txt | 9 --------- whisperplus/test.py | 3 +-- 4 files changed, 6 insertions(+), 13 deletions(-) delete mode 100644 whisperplus/pipelines/mlx_whisper/requirements.txt diff --git a/whisperplus/pipelines/lightning_whisper_mlx/lightning.py b/whisperplus/pipelines/lightning_whisper_mlx/lightning.py index aa6763e..0f0e50a 100644 --- a/whisperplus/pipelines/lightning_whisper_mlx/lightning.py +++ b/whisperplus/pipelines/lightning_whisper_mlx/lightning.py @@ -49,6 +49,8 @@ }, "distil-large-v3": { "base": "mustafaaljadery/distil-whisper-mlx", + "4bit": "mustafaaljadery/distil-whisper-mlx-4bit", + "8bit": "mustafaaljadery/distil-whisper-mlx-8bit", }, } @@ -91,9 +93,10 @@ def __init__(self, model, batch_size=12, quant=None): hf_hub_download(repo_id=repo_id, filename=filename2, local_dir=local_dir) def transcribe(self, audio_path, language=None): + breakpoint() result = transcribe_audio( audio_path, - path_or_hf_repo=f'./mlx_models/{self.name}', + path_or_hf_repo=f'mlx_models/{self.name}', language=language, batch_size=self.batch_size) return result diff --git a/whisperplus/pipelines/mlx_whisper/__init__.py b/whisperplus/pipelines/mlx_whisper/__init__.py index e1f6540..e6de085 100644 --- a/whisperplus/pipelines/mlx_whisper/__init__.py +++ b/whisperplus/pipelines/mlx_whisper/__init__.py @@ -1,5 +1,5 @@ # Copyright © 2023-2024 Apple Inc. from . import audio, decoding, load_models -from .transcribxe import transcribe +from .transcribe import transcribe from .version import __version__ diff --git a/whisperplus/pipelines/mlx_whisper/requirements.txt b/whisperplus/pipelines/mlx_whisper/requirements.txt deleted file mode 100644 index ca9c148..0000000 --- a/whisperplus/pipelines/mlx_whisper/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -mlx>=0.11 -numba -numpy -torch -tqdm -more-itertools -tiktoken -huggingface_hub -scipy diff --git a/whisperplus/test.py b/whisperplus/test.py index 5d42b0f..ea448aa 100644 --- a/whisperplus/test.py +++ b/whisperplus/test.py @@ -1,5 +1,4 @@ import torch -from hqq.utils.patching import prepare_for_inference from pipelines.whisper import SpeechToTextPipeline from transformers import BitsAndBytesConfig, HqqConfig from utils.download_utils import download_youtube_to_mp3 @@ -8,7 +7,7 @@ audio_path = download_youtube_to_mp3(url) hqq_config = HqqConfig( - nbits=1, + nbits=4, group_size=64, quant_zero=False, quant_scale=False,