2 changes: 2 additions & 0 deletions README.md
@@ -255,6 +255,8 @@ When you initialize the `AudioToTextRecorder` class, you have various options to

- **gpu_device_index** (int, default=0): GPU Device Index to use. The model can also be loaded on multiple GPUs by passing a list of IDs (e.g. [0, 1, 2, 3]).

- **return_segments** (bool, default=False): Return/pass a tuple (text, segments) from any function or callback that delivers transcribed text (like text(), on_realtime_*, ...), which includes the raw transcribed segments instead of just the text. Useful for inspecting probabilities or segment timings.

- **on_recording_start**: A callable function triggered when recording starts.

- **on_recording_stop**: A callable function triggered when recording ends.
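For reference, a minimal usage sketch of the option described above. The wiring around it is illustrative only; the `return_segments` parameter and the `(text, segments)` return shape are what this change adds, and the segment attributes shown are faster-whisper's.

```python
from RealtimeSTT import AudioToTextRecorder

if __name__ == '__main__':
    # With return_segments=True, text() returns (text, segments) instead of a
    # plain string; segments are the raw faster-whisper segment objects.
    recorder = AudioToTextRecorder(return_segments=True)

    text, segments = recorder.text()
    print(text)
    for seg in segments:
        # start/end timings and avg_logprob come from faster-whisper
        print(f"[{seg.start:.2f}-{seg.end:.2f}] logprob={seg.avg_logprob:.2f} {seg.text}")
```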
108 changes: 91 additions & 17 deletions RealtimeSTT/audio_recorder.py
@@ -92,6 +92,7 @@ def __init__(self,
compute_type: str = "default",
input_device_index: int = 0,
gpu_device_index: Union[int, List[int]] = 0,
return_segments=False,
on_recording_start=None,
on_recording_stop=None,
on_transcription_start=None,
@@ -173,6 +174,11 @@ def __init__(self,
IDs (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can
run in parallel when transcribe() is called from multiple Python
threads
- return_segments (bool, default=False): Return/pass a tuple (text, segments)
from any function or callback that delivers transcribed text (like text(),
on_realtime_*, ...), which includes the raw transcribed segments
instead of just the text. Useful for inspecting probabilities or
segment timings.
- on_recording_start (callable, default=None): Callback function to be
called when recording of audio to be transcripted starts.
- on_recording_stop (callable, default=None): Callback function to be
@@ -298,6 +304,7 @@ def __init__(self,
self.compute_type = compute_type
self.input_device_index = input_device_index
self.gpu_device_index = gpu_device_index
self.return_segments = return_segments
self.wake_words = wake_words
self.wake_word_activation_delay = wake_word_activation_delay
self.wake_word_timeout = wake_word_timeout
@@ -354,8 +361,10 @@ def __init__(self,
self.state = "inactive"
self.wakeword_detected = False
self.text_storage = []
self.segment_storage = []
self.realtime_stabilized_text = ""
self.realtime_stabilized_safetext = ""
self.realtime_stabilized_safesegments = []
self.is_webrtc_speech_active = False
self.is_silero_speech_active = False
self.recording_thread = None
@@ -418,7 +427,8 @@ def __init__(self,
self.interrupt_stop_event,
self.beam_size,
self.initial_prompt,
self.suppress_tokens
self.suppress_tokens,
self.return_segments  # word_timestamps is enabled when return_segments is on
)
)
self.transcript_process.start()
@@ -571,7 +581,8 @@ def _transcription_worker(conn,
interrupt_stop_event,
beam_size,
initial_prompt,
suppress_tokens
suppress_tokens,
word_timestamps
):
"""
Worker method that handles the continuous
@@ -604,6 +615,7 @@ def _transcription_worker(conn,
to the transcription model.
suppress_tokens (list of int): Tokens to be suppressed from the
transcription output.
word_timestamps (bool): enable word-level timestamps
Raises:
Exception: If there is an error while initializing the
transcription model.
@@ -638,17 +650,18 @@ def _transcription_worker(conn,
if conn.poll(0.5):
audio, language = conn.recv()
try:
segments = model.transcribe(
segments, _ = model.transcribe(
audio,
language=language if language else None,
beam_size=beam_size,
initial_prompt=initial_prompt,
suppress_tokens=suppress_tokens
suppress_tokens=suppress_tokens,
word_timestamps=word_timestamps,
)
segments = segments[0]
segments = list(segments) # Convert generator to list
transcription = " ".join(seg.text for seg in segments)
transcription = transcription.strip()
conn.send(('success', transcription))
conn.send(('success', (transcription, segments)))
except Exception as e:
logging.error(f"General transcription error: {e}")
conn.send(('error', str(e)))
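As a rough illustration of what the worker now sends through the pipe when `word_timestamps` is enabled: a `('success', (transcription, segments))` payload whose segments carry word-level detail. `parent_conn` is hypothetical shorthand for the receiving end of `conn`; the attribute names follow faster-whisper's Word objects.

```python
# Sketch of the parent side of the pipe consuming the worker's message.
status, result = parent_conn.recv()
if status == 'success':
    transcription, segments = result
    for seg in segments:
        # seg.words is only populated when word_timestamps was enabled
        for word in (seg.words or []):
            print(word.word, word.start, word.end, word.probability)
```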
@@ -846,8 +859,9 @@ def transcribe(self):

self._set_state("inactive")
if status == 'success':
text, segments = result
self.last_transcription_bytes = audio_copy
return self._preprocess_output(result)
return self._preprocess_output(text, segments)
else:
logging.error(result)
raise Exception(result)
Expand All @@ -873,6 +887,7 @@ def text(self,
If omitted, the transcription will be performed synchronously,
and the result will be returned.


Returns (if no callback is set):
str: The transcription of the recorded audio
"""
@@ -885,7 +900,7 @@
if self.is_shut_down or self.interrupt_stop_event.is_set():
if self.interrupt_stop_event.is_set():
self.was_interrupted.set()
return ""
return ("", []) if self.return_segments else ""

if on_transcription_finished:
threading.Thread(target=on_transcription_finished,
@@ -910,8 +925,10 @@ def start(self):
logging.info("recording started")
self._set_state("recording")
self.text_storage = []
self.segment_storage = []
self.realtime_stabilized_text = ""
self.realtime_stabilized_safetext = ""
self.realtime_stabilized_safesegments = []
self.wakeword_detected = False
self.wake_word_detect_time = 0
self.frames = []
@@ -1286,29 +1303,31 @@ def _realtime_worker(self):
INT16_MAX_ABS_VALUE

# Perform transcription and assemble the text
segments = self.realtime_model_type.transcribe(
segments, _ = self.realtime_model_type.transcribe(
audio_array,
language=self.language if self.language else None,
beam_size=self.beam_size_realtime,
initial_prompt=self.initial_prompt,
suppress_tokens=self.suppress_tokens,
word_timestamps=self.return_segments,
)
segments = list(segments) # Convert generator to list

# double check recording state
# because it could have changed mid-transcription
if self.is_recording and time.time() - \
self.recording_start_time > 0.5:

logging.debug('Starting realtime transcription')
self.realtime_transcription_segments = list(segments)
self.realtime_transcription_text = " ".join(
seg.text for seg in segments[0]
seg.text for seg in self.realtime_transcription_segments
)
self.realtime_transcription_text = \
self.realtime_transcription_text.strip()

self.text_storage.append(
self.realtime_transcription_text
)
self.text_storage.append(self.realtime_transcription_text)
self.segment_storage.append(self.realtime_transcription_segments)

# Take the last two texts in storage, if they exist
if len(self.text_storage) >= 2:
@@ -1330,6 +1349,31 @@
# as additional security
self.realtime_stabilized_safetext = prefix

# Find the corresponding segments for the prefix
# by incremental reconstruction. Modify the previous
# segment list to achieve this.
prefix_segments = []
partial_prefix = ""
if self.return_segments:
# Make a copy (this is relatively cheap, not much to copy)
prefix_segments = copy.deepcopy(self.segment_storage[-1])
for seg in prefix_segments:
# Replace the current word list with an empty one
old_words = list(seg.words)
seg.words.clear()
# Iterate over the old words and only append them while the prefix still matches
for word in old_words:
new_partial_prefix = (partial_prefix + word.word)
# Keep the word only as long as the stabilized prefix still covers it.
if not prefix.startswith(new_partial_prefix.strip()):
# Skip instead of breaking out of the outer loop,
# so every remaining segment's word list is still purged
continue
partial_prefix = new_partial_prefix
seg.words.append(word)

self.realtime_stabilized_safesegments = prefix_segments

# Find parts of the stabilized text
# in the freshly transcripted text
matching_pos = self._find_tail_match_in_text(
@@ -1342,17 +1386,43 @@
self._on_realtime_transcription_stabilized(
self._preprocess_output(
self.realtime_stabilized_safetext,
True
self.realtime_stabilized_safesegments,
preview=True
)
)
else:
self._on_realtime_transcription_stabilized(
self._preprocess_output(
self.realtime_transcription_text,
True
self.realtime_transcription_segments,
preview=True
)
)
else:
# Get all segments up to the matching_pos.
pre_match_segments = []
partial_prefix = ""
if self.return_segments:
# Make a copy (this is relatively cheap, not much to copy)
pre_match_segments = copy.deepcopy(self.realtime_stabilized_safesegments)
for seg in pre_match_segments:
# Replace the current word list with an empty one
old_words = list(seg.words)
seg.words.clear()
# Iterate over the old words and only append them until matching_pos is reached
for word in old_words:
new_partial_prefix = (partial_prefix + word.word)
# Keep the word only while the match position has not been reached yet.
if len(new_partial_prefix.strip()) >= matching_pos:
# Skip instead of breaking out of the outer loop,
# so every remaining segment's word list is still purged
continue
partial_prefix = new_partial_prefix
seg.words.append(word)

# Append the newly transcribed segments
all_segments = pre_match_segments + self.realtime_transcription_segments

# We found parts of the stabilized text
# in the transcripted text
# We now take the stabilized text
@@ -1365,13 +1435,14 @@
# parts on the first run without the need for
# two transcriptions
self._on_realtime_transcription_stabilized(
self._preprocess_output(output_text, True)
self._preprocess_output(output_text, all_segments, True)
)

# Invoke the callback with the transcribed text
self._on_realtime_transcription_update(
self._preprocess_output(
self.realtime_transcription_text,
self.realtime_transcription_segments,
True
)
)
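A sketch of how a realtime callback might handle the new payload. The constructor parameter names are the library's existing ones; the handling itself is illustrative.

```python
from RealtimeSTT import AudioToTextRecorder

def on_update(result):
    # With return_segments=True the realtime callbacks receive (text, segments);
    # without it, a plain string as before.
    if isinstance(result, tuple):
        text, segments = result
        print(f"{text}  ({len(segments)} segments)")
    else:
        print(result)

recorder = AudioToTextRecorder(
    enable_realtime_transcription=True,
    return_segments=True,
    on_realtime_transcription_update=on_update,
)
```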
@@ -1553,7 +1624,7 @@ def _set_spinner(self, text):
else:
self.halo.text = text

def _preprocess_output(self, text, preview=False):
def _preprocess_output(self, text, segments, preview=False):
"""
Preprocesses the output text by removing any leading or trailing
whitespace, converting all whitespace sequences to a single space
@@ -1578,8 +1649,11 @@ def _preprocess_output(self, text, preview=False):
if text and text[-1].isalnum():
text += '.'

if self.return_segments:
return (text, segments)
return text


def _find_tail_match_in_text(self, text1, text2, length_of_match=10):
"""
Find the position where the last 'n' characters of text1
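Finally, a standalone sketch of the prefix-based word trimming used in `_realtime_worker`, on simplified stand-in classes rather than faster-whisper's own segment objects:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Word:
    word: str

@dataclass
class Segment:
    words: List[Word] = field(default_factory=list)

def trim_to_prefix(segments: List[Segment], prefix: str) -> List[Segment]:
    """Keep only the leading words whose concatenation still matches `prefix`."""
    partial = ""
    for seg in segments:
        old_words, seg.words = seg.words, []
        for w in old_words:
            candidate = partial + w.word
            if not prefix.startswith(candidate.strip()):
                continue  # drop words beyond the stabilized prefix
            partial = candidate
            seg.words.append(w)
    return segments

segs = [Segment([Word("Hello"), Word(" world"), Word(" again")])]
print(trim_to_prefix(segs, "Hello world"))  # keeps "Hello" and " world"
```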