diff --git a/whisperplus/pipelines/whisper.py b/whisperplus/pipelines/whisper.py
index 050915c..0a5612d 100644
--- a/whisperplus/pipelines/whisper.py
+++ b/whisperplus/pipelines/whisper.py
@@ -18,17 +18,6 @@ def __init__(self, model_id: str = "openai/whisper-large-v3"):
         else:
             logging.info("Model already loaded.")
 
-        self.set_device()
-
-    def set_device(self):
-        """Sets the device to be used for inference based on availability."""
-        if torch.backends.mps.is_available():
-            self.device = "mps"
-        else:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        logging.info(f"Using device: {self.device}")
-
     def load_model(self, model_id: str = "openai/whisper-large-v3"):
         """
         Loads the pre-trained speech recognition model and moves it to the specified device.
@@ -38,35 +27,39 @@ def load_model(self, model_id: str = "openai/whisper-large-v3"):
         """
         logging.info("Loading model...")
         model = AutoModelForSpeechSeq2Seq.from_pretrained(
-            model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, use_safetensors=True)
-        model.to(self.device)
+            model_id,
+            low_cpu_mem_usage=True,
+            use_safetensors=True,
+            attn_implementation="flash_attention_2",
+            load_in_4bit=True,
+            device_map="auto")
         logging.info("Model loaded successfully.")
+        processor = AutoProcessor.from_pretrained(model_id)
+
+        self.processor = processor
         self.model = model
 
-    def __call__(self, audio_path: str, model_id: str = "openai/whisper-large-v3", language: str = "turkish"):
+    def __call__(self, audio_path: str, language: str = "turkish"):
         """
         Converts audio to text using the pre-trained speech recognition model.
 
         Args:
             audio_path (str): Path to the audio file to be transcribed.
-            model_id (str): Identifier of the pre-trained model to be used for transcription.
 
         Returns:
             str: Transcribed text from the audio.
         """
-        processor = AutoProcessor.from_pretrained(model_id)
 
         pipe = pipeline(
             "automatic-speech-recognition",
             model=self.model,
-            torch_dtype=torch.float16,
             chunk_length_s=30,
             max_new_tokens=128,
             batch_size=24,
+            device_map="auto",
             return_timestamps=True,
-            device=self.device,
-            tokenizer=processor.tokenizer,
-            feature_extractor=processor.feature_extractor,
+            tokenizer=self.processor.tokenizer,
+            feature_extractor=self.processor.feature_extractor,
             model_kwargs={"use_flash_attention_2": True},
             generate_kwargs={"language": language},
         )
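
Net effect of the change: device placement, 4-bit quantization, and FlashAttention-2 are now configured once at load time (`device_map="auto"`, `load_in_4bit=True`, `attn_implementation="flash_attention_2"`), replacing the manual `set_device()` MPS/CUDA/CPU selection and the fp16 `model.to(...)` path, and the processor is cached on the instance so `__call__` no longer reloads it on every transcription. A minimal usage sketch, assuming the enclosing class is `SpeechToTextPipeline` (its name sits outside both hunks' context) and that `accelerate`, `bitsandbytes`, and `flash-attn` are installed on a CUDA machine:

```python
# Hedged usage sketch for the updated pipeline; the class name and the
# audio path below are assumptions, not shown in this diff.
from whisperplus.pipelines.whisper import SpeechToTextPipeline

# The constructor now loads the quantized, flash-attention model once,
# with accelerate placing the weights via device_map="auto".
stt = SpeechToTextPipeline(model_id="openai/whisper-large-v3")

# __call__ dropped its model_id parameter: only the audio path and the
# target language are passed per transcription.
transcript = stt(audio_path="sample.mp3", language="turkish")
print(transcript)
```

Since both `load_in_4bit` and FlashAttention-2 require a CUDA GPU, removing `set_device()` also removes the previous MPS and CPU fallbacks.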