From 22bf6da1fa607d8db8b77ce5b0e7ddcc1bce4307 Mon Sep 17 00:00:00 2001 From: oddlama Date: Thu, 20 Jun 2024 17:13:54 +0200 Subject: [PATCH] feat: add new option "return_segments" to allow accessing word probabilities and other meta information --- README.md | 2 + RealtimeSTT/audio_recorder.py | 108 ++++++++++++++++++++++++++++------ 2 files changed, 93 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index cf2349f..4545e8d 100644 --- a/README.md +++ b/README.md @@ -255,6 +255,8 @@ When you initialize the `AudioToTextRecorder` class, you have various options to - **gpu_device_index** (int, default=0): GPU Device Index to use. The model can also be loaded on multiple GPUs by passing a list of IDs (e.g. [0, 1, 2, 3]). +- **return_segments** (bool, default=False): Return/pass a tuple (text, segments) from any function or callback related to transcribed text (like text(), on_realtime_*, ...), which includes the raw transcribed segments instead of just the text. Useful to observe probabilities or segment timings. + - **on_recording_start**: A callable function triggered when recording starts. - **on_recording_stop**: A callable function triggered when recording ends. diff --git a/RealtimeSTT/audio_recorder.py b/RealtimeSTT/audio_recorder.py index 531d2c6..208bb41 100644 --- a/RealtimeSTT/audio_recorder.py +++ b/RealtimeSTT/audio_recorder.py @@ -92,6 +92,7 @@ def __init__(self, compute_type: str = "default", input_device_index: int = 0, gpu_device_index: Union[int, List[int]] = 0, + return_segments=False, on_recording_start=None, on_recording_stop=None, on_transcription_start=None, @@ -173,6 +174,11 @@ def __init__(self, IDs (e.g. [0, 1, 2, 3]). 
In that case, multiple transcriptions can run in parallel when transcribe() is called from multiple Python threads + - return_segments (bool, default=False): Return/pass a tuple (text, segments) + from any function or callback related to transcribed text (like text(), + on_realtime_*, ...), which includes the raw transcribed segments + instead of just the text. Useful to observe probabilities or + segment timings. - on_recording_start (callable, default=None): Callback function to be called when recording of audio to be transcripted starts. - on_recording_stop (callable, default=None): Callback function to be @@ -298,6 +304,7 @@ def __init__(self, self.compute_type = compute_type self.input_device_index = input_device_index self.gpu_device_index = gpu_device_index + self.return_segments = return_segments self.wake_words = wake_words self.wake_word_activation_delay = wake_word_activation_delay self.wake_word_timeout = wake_word_timeout @@ -354,8 +361,10 @@ def __init__(self, self.state = "inactive" self.wakeword_detected = False self.text_storage = [] + self.segment_storage = [] self.realtime_stabilized_text = "" self.realtime_stabilized_safetext = "" + self.realtime_stabilized_safesegments = [] self.is_webrtc_speech_active = False self.is_silero_speech_active = False self.recording_thread = None @@ -418,7 +427,8 @@ def __init__(self, self.interrupt_stop_event, self.beam_size, self.initial_prompt, - self.suppress_tokens + self.suppress_tokens, + self.return_segments # word_timestamps on if return_segments is on ) ) self.transcript_process.start() @@ -571,7 +581,8 @@ def _transcription_worker(conn, interrupt_stop_event, beam_size, initial_prompt, - suppress_tokens + suppress_tokens, + word_timestamps ): """ Worker method that handles the continuous @@ -604,6 +615,7 @@ def _transcription_worker(conn, to the transcription model. suppress_tokens (list of int): Tokens to be suppressed from the transcription output. 
+ word_timestamps (bool): enable word level timestamps Raises: Exception: If there is an error while initializing the transcription model. @@ -638,17 +650,18 @@ def _transcription_worker(conn, if conn.poll(0.5): audio, language = conn.recv() try: - segments = model.transcribe( + segments, _ = model.transcribe( audio, language=language if language else None, beam_size=beam_size, initial_prompt=initial_prompt, - suppress_tokens=suppress_tokens + suppress_tokens=suppress_tokens, + word_timestamps=word_timestamps, ) - segments = segments[0] + segments = list(segments) # Convert generator to list transcription = " ".join(seg.text for seg in segments) transcription = transcription.strip() - conn.send(('success', transcription)) + conn.send(('success', (transcription, segments))) except Exception as e: logging.error(f"General transcription error: {e}") conn.send(('error', str(e))) @@ -846,8 +859,9 @@ def transcribe(self): self._set_state("inactive") if status == 'success': + text, segments = result self.last_transcription_bytes = audio_copy - return self._preprocess_output(result) + return self._preprocess_output(text, segments) else: logging.error(result) raise Exception(result) @@ -873,6 +887,7 @@ def text(self, If omitted, the transcription will be performed synchronously, and the result will be returned. 
+ Returns (if no callback is set): str: The transcription of the recorded audio """ @@ -885,7 +900,7 @@ def text(self, if self.is_shut_down or self.interrupt_stop_event.is_set(): if self.interrupt_stop_event.is_set(): self.was_interrupted.set() - return "" + return ("", []) if self.return_segments else "" if on_transcription_finished: threading.Thread(target=on_transcription_finished, @@ -910,8 +925,10 @@ def start(self): logging.info("recording started") self._set_state("recording") self.text_storage = [] + self.segment_storage = [] self.realtime_stabilized_text = "" self.realtime_stabilized_safetext = "" + self.realtime_stabilized_safesegments = [] self.wakeword_detected = False self.wake_word_detect_time = 0 self.frames = [] @@ -1286,13 +1303,15 @@ def _realtime_worker(self): INT16_MAX_ABS_VALUE # Perform transcription and assemble the text - segments = self.realtime_model_type.transcribe( + segments, _ = self.realtime_model_type.transcribe( audio_array, language=self.language if self.language else None, beam_size=self.beam_size_realtime, initial_prompt=self.initial_prompt, suppress_tokens=self.suppress_tokens, + word_timestamps=self.return_segments, ) + segments = list(segments) # Convert generator to list # double check recording state # because it could have changed mid-transcription @@ -1300,15 +1319,15 @@ def _realtime_worker(self): self.recording_start_time > 0.5: logging.debug('Starting realtime transcription') + self.realtime_transcription_segments = list(segments) self.realtime_transcription_text = " ".join( - seg.text for seg in segments[0] + seg.text for seg in self.realtime_transcription_segments ) self.realtime_transcription_text = \ self.realtime_transcription_text.strip() - self.text_storage.append( - self.realtime_transcription_text - ) + self.text_storage.append(self.realtime_transcription_text) + self.segment_storage.append(self.realtime_transcription_segments) # Take the last two texts in storage, if they exist if len(self.text_storage) >= 2: 
@@ -1330,6 +1349,31 @@ def _realtime_worker(self): # as additional security self.realtime_stabilized_safetext = prefix + # Find the corresponding segments for the prefix + # by incremental reconstruction. Modify the previous + # segment list to achieve this. + prefix_segments = [] + partial_prefix = "" + if self.return_segments: + # Make a copy (this is relatively cheap, not much to copy) + prefix_segments = copy.deepcopy(self.segment_storage[-1]) + for seg in prefix_segments: + # Replace the current word list with an empty one + old_words = list(seg.words) + seg.words.clear() + # Iterate over old words and only append as long as prefix is the same + for word in old_words: + new_partial_prefix = (partial_prefix + word.word) + # As long as the prefix still matches, append the current segment. + if not prefix.startswith(new_partial_prefix.strip()): + # Don't break outer but continue to next segment, + # to ensure all word lists are purged + continue + partial_prefix = new_partial_prefix + seg.words.append(word) + + self.realtime_stabilized_safesegments = prefix_segments + # Find parts of the stabilized text # in the freshly transcripted text matching_pos = self._find_tail_match_in_text( @@ -1342,17 +1386,43 @@ def _realtime_worker(self): self._on_realtime_transcription_stabilized( self._preprocess_output( self.realtime_stabilized_safetext, - True + self.realtime_stabilized_safesegments, + preview=True ) ) else: self._on_realtime_transcription_stabilized( self._preprocess_output( self.realtime_transcription_text, - True + self.realtime_transcription_segments, + preview=True ) ) else: + # Get all segments up to the matching_pos. 
+ pre_match_segments = [] + partial_prefix = "" + if self.return_segments: + # Make a copy (this is relatively cheap, not much to copy) + pre_match_segments = copy.deepcopy(self.realtime_stabilized_safesegments) + for seg in pre_match_segments: + # Replace the current word list with an empty one + old_words = list(seg.words) + seg.words.clear() + # Iterate over old words and only append as long as prefix is the same + for word in old_words: + new_partial_prefix = (partial_prefix + word.word) + # As long as the match pos isn't reached, append the current segment. + if len(new_partial_prefix.strip()) >= matching_pos: + # Don't break outer but continue to next segment, + # to ensure all word lists are purged + continue + partial_prefix = new_partial_prefix + seg.words.append(word) + + # Append the new stuff + all_segments = pre_match_segments + self.realtime_transcription_segments + # We found parts of the stabilized text # in the transcripted text # We now take the stabilized text @@ -1365,13 +1435,14 @@ def _realtime_worker(self): # parts on the first run without the need for # two transcriptions self._on_realtime_transcription_stabilized( - self._preprocess_output(output_text, True) + self._preprocess_output(output_text, all_segments, True) ) # Invoke the callback with the transcribed text self._on_realtime_transcription_update( self._preprocess_output( self.realtime_transcription_text, + self.realtime_transcription_segments, True ) ) @@ -1553,7 +1624,7 @@ def _set_spinner(self, text): else: self.halo.text = text - def _preprocess_output(self, text, preview=False): + def _preprocess_output(self, text, segments, preview=False): """ Preprocesses the output text by removing any leading or trailing whitespace, converting all whitespace sequences to a single space @@ -1578,8 +1649,11 @@ def _preprocess_output(self, text, preview=False): if text and text[-1].isalnum(): text += '.' 
+ if self.return_segments: + return (text, segments) return text + def _find_tail_match_in_text(self, text1, text2, length_of_match=10): """ Find the position where the last 'n' characters of text1