Skip to content

Commit

Permalink
Fixed CPU unusable on CUDA machines
Browse files Browse the repository at this point in the history
  • Loading branch information
octimot authored May 4, 2023
1 parent 4885e4b commit 4081e34
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8811,6 +8811,12 @@ def whisper_device_select(self, device):
:return:
'''

allowed_devices = ['cuda', 'CUDA', 'gpu', 'GPU', 'cpu', 'CPU']

# change the whisper device if it was passed as a parameter
if device is not None and device in allowed_devices:
self.whisper_device = device

# if the whisper device is set to cuda
if self.whisper_device in ['cuda', 'CUDA', 'gpu', 'GPU']:
# use CUDA if available
Expand Down Expand Up @@ -10110,6 +10116,10 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,

# what is the name of the audio file
audio_file_name = os.path.basename(audio_file_path)

whisper_device_changed = False
if 'device' in other_whisper_options and self.whisper_device != other_whisper_options['device']:
whisper_device_changed = True

# select the device that was passed (if any)
if 'device' in other_whisper_options:
Expand All @@ -10120,7 +10130,8 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,
# load OpenAI Whisper model
# and hold it loaded for future use (unless another model was passed via other_whisper_options)
if self.whisper_model is None \
or ('model' in other_whisper_options and self.whisper_model_name != other_whisper_options['model']):
or ('model' in other_whisper_options and self.whisper_model_name != other_whisper_options['model'])\
or whisper_device_changed:

# update the status of the item in the transcription log
self.update_transcription_log(unique_id=queue_id, **{'status': 'loading model'})
Expand All @@ -10147,7 +10158,7 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,

logger.info('Loading Whisper {} model.'.format(self.whisper_model_name))
try:
self.whisper_model = whisper.load_model(self.whisper_model_name)
self.whisper_model = whisper.load_model(self.whisper_model_name, device=self.whisper_device)
except Exception as e:
logger.error('Error loading Whisper {} model: {}'.format(self.whisper_model_name, e))

Expand Down

0 comments on commit 4081e34

Please sign in to comment.