diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 82a705649..8f34b0ef3 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -10,7 +10,7 @@ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram from .decoding import DecodingOptions, DecodingResult from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer -from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_vtt, write_srt +from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt if TYPE_CHECKING: from .model import Whisper @@ -295,7 +295,7 @@ def cli(): # save TXT with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: - print(result["text"], file=txt) + write_txt(result["segments"], file=txt) # save VTT with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt: diff --git a/whisper/utils.py b/whisper/utils.py index c8f9c0dbe..b63ade7f7 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -44,6 +44,11 @@ def format_timestamp(seconds: float, always_include_hours: bool = False): return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}" +def write_txt(transcript: Iterator[dict], file: TextIO): + for segment in transcript: + print(segment['text'].strip(), file=file, flush=True) + + def write_vtt(transcript: Iterator[dict], file: TextIO): print("WEBVTT\n", file=file) for segment in transcript: