diff --git a/yt_whisper/cli.py b/yt_whisper/cli.py index 77763e9..d87a88b 100644 --- a/yt_whisper/cli.py +++ b/yt_whisper/cli.py @@ -1,4 +1,5 @@ import os +from typing import List import whisper from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE import argparse @@ -6,13 +7,13 @@ import yt_dlp from .utils import slugify, str2bool, write_srt, write_vtt import tempfile - +import ffmpeg def main(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("video", nargs="+", type=str, - help="video URLs to transcribe") + help="video URLs or files to transcribe") parser.add_argument("--model", default="small", choices=whisper.available_models(), help="name of the Whisper model to use") parser.add_argument("--format", default="vtt", @@ -41,31 +42,30 @@ def main(): args["language"] = "en" model = whisper.load_model(model_name) - audios = get_audio(args.pop("video")) - break_lines = args.pop("break_lines") - - for title, audio_path in audios.items(): - warnings.filterwarnings("ignore") - result = model.transcribe(audio_path, **args) - warnings.filterwarnings("default") + with tempfile.TemporaryDirectory() as tmp_dir: + audios = get_audio(args.pop("video"), tmp_dir) + break_lines = args.pop("break_lines") - if (subtitles_format == 'vtt'): - vtt_path = os.path.join(output_dir, f"{slugify(title)}.vtt") - with open(vtt_path, 'w', encoding="utf-8") as vtt: - write_vtt(result["segments"], file=vtt, line_length=break_lines) + for title, audio_path in audios.items(): + warnings.filterwarnings("ignore") + result = model.transcribe(audio_path, **args) + warnings.filterwarnings("default") - print("Saved VTT to", os.path.abspath(vtt_path)) - else: - srt_path = os.path.join(output_dir, f"{slugify(title)}.srt") - with open(srt_path, 'w', encoding="utf-8") as srt: - write_srt(result["segments"], file=srt, line_length=break_lines) + if (subtitles_format == 'vtt'): + vtt_path = os.path.join(output_dir, f"{slugify(title)}.vtt") + with open(vtt_path, 'w', encoding="utf-8") as vtt: + write_vtt(result["segments"], file=vtt, line_length=break_lines) - print("Saved SRT to", os.path.abspath(srt_path)) + print("Saved VTT to", os.path.abspath(vtt_path)) + else: + srt_path = os.path.join(output_dir, f"{slugify(title)}.srt") + with open(srt_path, 'w', encoding="utf-8") as srt: + write_srt(result["segments"], file=srt, line_length=break_lines) + print("Saved SRT to", os.path.abspath(srt_path)) -def get_audio(urls): - temp_dir = tempfile.gettempdir() +def get_audio(video_paths:List[str], temp_dir:tempfile.TemporaryDirectory): ydl = yt_dlp.YoutubeDL({ 'quiet': True, 'verbose': False, @@ -76,12 +76,24 @@ def get_audio(urls): paths = {} - for url in urls: - result = ydl.extract_info(url, download=True) - print( - f"Downloaded video \"{result['title']}\". Generating subtitles..." - ) - paths[result["title"]] = os.path.join(temp_dir, f"{result['id']}.mp3") + for video_path in video_paths: + if os.path.exists(video_path): + title = os.path.basename(video_path).split(".")[0] + audio_path = os.path.join(temp_dir, f"{title}.mp3") + (ffmpeg + .input(video_path) + .audio + .output(audio_path) + .run() + ) + paths[title] = audio_path + else: + result = ydl.extract_info(video_path, download=True) + print( + f"Downloaded video \"{result['title']}\". Generating subtitles..." + ) + audio_path = os.path.join(temp_dir, f"{result['id']}.mp3") + paths[result["title"]] = audio_path return paths