diff --git a/whisper/timing.py b/whisper/timing.py
index befcf464e..b695ead0a 100644
--- a/whisper/timing.py
+++ b/whisper/timing.py
@@ -299,6 +299,7 @@ def add_word_timestamps(
     word_durations = np.array([t.end - t.start for t in alignment])
     word_durations = word_durations[word_durations.nonzero()]
     median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+    median_duration = min(0.7, float(median_duration))
     max_duration = median_duration * 2
 
     # hack: truncate long words at sentence boundaries.
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index e80bede1d..1c075a201 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -2,7 +2,7 @@
 import os
 import traceback
 import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -23,6 +23,7 @@
 from .utils import (
     exact_div,
     format_timestamp,
+    get_end,
     get_writer,
     make_safe,
     optional_float,
@@ -48,6 +49,8 @@ def transcribe(
     word_timestamps: bool = False,
     prepend_punctuations: str = "\"'“¿([{-",
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    clip_timestamps: Union[str, List[float]] = "0",
+    hallucination_silence_threshold: Optional[float] = None,
     **decode_options,
 ):
     """
@@ -102,6 +105,14 @@
     decode_options: dict
         Keyword arguments to construct `DecodingOptions` instances
 
+    clip_timestamps: Union[str, List[float]]
+        Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
+        The last end timestamp defaults to the end of the file.
+
+    hallucination_silence_threshold: Optional[float]
+        When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
+        when a possible hallucination is detected
+
     Returns
     -------
     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -121,6 +132,7 @@
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
     content_frames = mel.shape[-1] - N_FRAMES
+    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
 
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
@@ -147,6 +159,19 @@
         task=task,
     )
 
+    if isinstance(clip_timestamps, str):
+        clip_timestamps = [
+            float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
+        ]
+    seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
+    if len(seek_points) == 0:
+        seek_points.append(0)
+    if len(seek_points) % 2 == 1:
+        seek_points.append(content_frames)
+    seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
+
+    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
     if word_timestamps and task == "translate":
         warnings.warn("Word-level timestamps on translations may not be reliable.")
 
@@ -190,7 +215,8 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
 
         return decode_result
 
-    seek = 0
+    clip_idx = 0
+    seek = seek_clips[clip_idx][0]
     input_stride = exact_div(
         N_FRAMES, model.dims.n_audio_ctx
     )  # mel frames per output token: 2
@@ -229,10 +255,23 @@ def new_segment(
         total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
         last_speech_timestamp = 0.0
-        while seek < content_frames:
+        # NOTE: This loop is obscurely flattened to make the diff readable.
+        # A later commit should turn this into a simpler nested loop.
+        # for seek_clip_start, seek_clip_end in seek_clips:
+        #     while seek < seek_clip_end
+        while clip_idx < len(seek_clips):
+            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+            if seek < seek_clip_start:
+                seek = seek_clip_start
+            if seek >= seek_clip_end:
+                clip_idx += 1
+                if clip_idx < len(seek_clips):
+                    seek = seek_clips[clip_idx][0]
+                continue
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek : seek + N_FRAMES]
-            segment_size = min(N_FRAMES, content_frames - seek)
+            window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+            segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
+            mel_segment = mel[:, seek : seek + segment_size]
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
 
@@ -257,6 +296,30 @@ def new_segment(
             previous_seek = seek
             current_segments = []
 
+            # anomalous words are very long/short/improbable
+            def word_anomaly_score(word: dict) -> float:
+                probability = word.get("probability", 0.0)
+                duration = word["end"] - word["start"]
+                score = 0.0
+                if probability < 0.15:
+                    score += 1.0
+                if duration < 0.133:
+                    score += (0.133 - duration) * 15
+                if duration > 2.0:
+                    score += duration - 2.0
+                return score
+
+            def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                if segment is None or not segment["words"]:
+                    return False
+                words = [w for w in segment["words"] if w["word"] not in punctuation]
+                words = words[:8]
+                score = sum(word_anomaly_score(w) for w in words)
+                return score >= 3 or score + 0.01 >= len(words)
+
+            def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                return next((s for s in segments if s["words"]), None)
+
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
 
@@ -330,17 +393,71 @@ def new_segment(
                     append_punctuations=append_punctuations,
                     last_speech_timestamp=last_speech_timestamp,
                 )
-                word_end_timestamps = [
-                    w["end"] for s in current_segments for w in s["words"]
-                ]
-                if len(word_end_timestamps) > 0:
-                    last_speech_timestamp = word_end_timestamps[-1]
-                if not single_timestamp_ending and len(word_end_timestamps) > 0:
-                    seek_shift = round(
-                        (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
-                    )
-                    if seek_shift > 0:
-                        seek = previous_seek + seek_shift
+
+                if not single_timestamp_ending:
+                    last_word_end = get_end(current_segments)
+                    if last_word_end is not None and last_word_end > time_offset:
+                        seek = round(last_word_end * FRAMES_PER_SECOND)
+
+                # skip silence before possible hallucinations
+                if hallucination_silence_threshold is not None:
+                    threshold = hallucination_silence_threshold
+                    if not single_timestamp_ending:
+                        last_word_end = get_end(current_segments)
+                        if last_word_end is not None and last_word_end > time_offset:
+                            remaining_duration = window_end_time - last_word_end
+                            if remaining_duration > threshold:
+                                seek = round(last_word_end * FRAMES_PER_SECOND)
+                            else:
+                                seek = previous_seek + segment_size
+
+                    # if first segment might be a hallucination, skip leading silence
+                    first_segment = next_words_segment(current_segments)
+                    if first_segment is not None and is_segment_anomaly(first_segment):
+                        gap = first_segment["start"] - time_offset
+                        if gap > threshold:
+                            seek = previous_seek + round(gap * FRAMES_PER_SECOND)
+                            continue
+
+                    # skip silence before any possible hallucination that is surrounded
+                    # by silence or more hallucinations
+                    hal_last_end = last_speech_timestamp
+                    for si in range(len(current_segments)):
+                        segment = current_segments[si]
+                        if not segment["words"]:
+                            continue
+                        if is_segment_anomaly(segment):
+                            next_segment = next_words_segment(
+                                current_segments[si + 1 :]
+                            )
+                            if next_segment is not None:
+                                hal_next_start = next_segment["words"][0]["start"]
+                            else:
+                                hal_next_start = time_offset + segment_duration
+                            silence_before = (
+                                segment["start"] - hal_last_end > threshold
+                                or segment["start"] < threshold
+                                or segment["start"] - time_offset < 2.0
+                            )
+                            silence_after = (
+                                hal_next_start - segment["end"] > threshold
+                                or is_segment_anomaly(next_segment)
+                                or window_end_time - segment["end"] < 2.0
+                            )
+                            if silence_before and silence_after:
+                                seek = round(
+                                    max(time_offset + 1, segment["start"])
+                                    * FRAMES_PER_SECOND
+                                )
+                                if content_duration - segment["end"] < threshold:
+                                    seek = content_frames
+                                current_segments[si:] = []
+                                break
+                        hal_last_end = segment["end"]
+
+                last_word_end = get_end(current_segments)
+                if last_word_end is not None:
+                    last_speech_timestamp = last_word_end
 
             if verbose:
                 for segment in current_segments:
@@ -427,6 +544,8 @@ def valid_model_name(name):
     parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
    parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
+    parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
     # fmt: on
 
     args = parser.parse_args().__dict__
diff --git a/whisper/utils.py b/whisper/utils.py
index 7a172c401..9b9b13862 100644
--- a/whisper/utils.py
+++ b/whisper/utils.py
@@ -3,7 +3,7 @@
 import re
 import sys
 import zlib
-from typing import Callable, Optional, TextIO
+from typing import Callable, List, Optional, TextIO
 
 system_encoding = sys.getdefaultencoding()
 
@@ -68,6 +68,20 @@
     )
 
 
+def get_start(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["start"] for s in segments for w in s["words"]),
+        segments[0]["start"] if segments else None,
+    )
+
+
+def get_end(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
+        segments[-1]["end"] if segments else None,
+    )
+
+
 class ResultWriter:
     extension: str
 
@@ -129,8 +143,8 @@ def iterate_subtitles():
             line_len = 0
             line_count = 1
             # the next subtitle to yield (a list of word timings with whitespace)
-            subtitle: list[dict] = []
-            last = result["segments"][0]["words"][0]["start"]
+            subtitle: List[dict] = []
+            last: float = get_start(result["segments"]) or 0.0
             for segment in result["segments"]:
                 chunk_index = 0
                 words_count = max_words_per_line
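
Usage note (not part of the patch): a minimal sketch of how the two options introduced above are meant to be called, assuming this patched whisper package is installed. The model size "base", the file name "audio.wav", and the 2-second threshold are illustrative placeholders, not values prescribed by the diff.

    import whisper

    # Any model size works; "base" is only an example.
    model = whisper.load_model("base")

    # clip_timestamps="0,30,60": decode 0s-30s, then 60s to the end of the file
    # (an odd number of values makes the last clip run to the end, matching the
    # seek_points padding with content_frames in transcribe()).
    # hallucination_silence_threshold=2.0: with word_timestamps=True, silent gaps
    # longer than 2 seconds around anomalous-looking segments are skipped.
    result = model.transcribe(
        "audio.wav",
        word_timestamps=True,
        clip_timestamps="0,30,60",
        hallucination_silence_threshold=2.0,
    )
    print(result["text"])

The equivalent command line uses the two new flags registered in cli():

    whisper audio.wav --word_timestamps True --clip_timestamps 0,30,60 --hallucination_silence_threshold 2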