From e93c997ad4cdcc6328905f44daf59b1cadc7a270 Mon Sep 17 00:00:00 2001 From: kadirnar Date: Thu, 2 May 2024 03:28:08 +0300 Subject: [PATCH] =?UTF-8?q?=E2=AD=90=20Add=20sdpa=20optimization=20paramet?= =?UTF-8?q?er?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- whisperplus/pipelines/whisper.py | 33 +++++++++++++------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/whisperplus/pipelines/whisper.py b/whisperplus/pipelines/whisper.py index 050915c..0a5612d 100644 --- a/whisperplus/pipelines/whisper.py +++ b/whisperplus/pipelines/whisper.py @@ -18,17 +18,6 @@ def __init__(self, model_id: str = "openai/whisper-large-v3"): else: logging.info("Model already loaded.") - self.set_device() - - def set_device(self): - """Sets the device to be used for inference based on availability.""" - if torch.backends.mps.is_available(): - self.device = "mps" - else: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - logging.info(f"Using device: {self.device}") - def load_model(self, model_id: str = "openai/whisper-large-v3"): """ Loads the pre-trained speech recognition model and moves it to the specified device. @@ -38,35 +27,39 @@ def load_model(self, model_id: str = "openai/whisper-large-v3"): """ logging.info("Loading model...") model = AutoModelForSpeechSeq2Seq.from_pretrained( - model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, use_safetensors=True) - model.to(self.device) + model_id, + low_cpu_mem_usage=True, + use_safetensors=True, + attn_implementation="flash_attention_2", + load_in_4bit=True, + device_map="auto") logging.info("Model loaded successfully.") + processor = AutoProcessor.from_pretrained(model_id) + + self.processor = processor self.model = model - def __call__(self, audio_path: str, model_id: str = "openai/whisper-large-v3", language: str = "turkish"): + def __call__(self, audio_path: str, language: str = "turkish"): """ Converts audio to text using the pre-trained speech recognition model. Args: audio_path (str): Path to the audio file to be transcribed. - model_id (str): Identifier of the pre-trained model to be used for transcription. Returns: str: Transcribed text from the audio. """ - processor = AutoProcessor.from_pretrained(model_id) pipe = pipeline( "automatic-speech-recognition", model=self.model, - torch_dtype=torch.float16, chunk_length_s=30, max_new_tokens=128, batch_size=24, + device_map="auto", return_timestamps=True, - device=self.device, - tokenizer=processor.tokenizer, - feature_extractor=processor.feature_extractor, + tokenizer=self.processor.tokenizer, + feature_extractor=self.processor.feature_extractor, model_kwargs={"use_flash_attention_2": True}, generate_kwargs={"language": language}, )