Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 39 additions & 27 deletions yt_whisper/cli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import os
from typing import List
import whisper
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
import argparse
import warnings
import yt_dlp
from .utils import slugify, str2bool, write_srt, write_vtt
import tempfile

import ffmpeg

def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("video", nargs="+", type=str,
help="video URLs to transcribe")
help="video URLs or files to transcribe")
parser.add_argument("--model", default="small",
choices=whisper.available_models(), help="name of the Whisper model to use")
parser.add_argument("--format", default="vtt",
Expand Down Expand Up @@ -41,31 +42,30 @@ def main():
args["language"] = "en"

model = whisper.load_model(model_name)
audios = get_audio(args.pop("video"))
break_lines = args.pop("break_lines")

for title, audio_path in audios.items():
warnings.filterwarnings("ignore")
result = model.transcribe(audio_path, **args)
warnings.filterwarnings("default")
with tempfile.TemporaryDirectory() as tmp_dir:
audios = get_audio(args.pop("video"), tmp_dir)
break_lines = args.pop("break_lines")

if (subtitles_format == 'vtt'):
vtt_path = os.path.join(output_dir, f"{slugify(title)}.vtt")
with open(vtt_path, 'w', encoding="utf-8") as vtt:
write_vtt(result["segments"], file=vtt, line_length=break_lines)
for title, audio_path in audios.items():
warnings.filterwarnings("ignore")
result = model.transcribe(audio_path, **args)
warnings.filterwarnings("default")

print("Saved VTT to", os.path.abspath(vtt_path))
else:
srt_path = os.path.join(output_dir, f"{slugify(title)}.srt")
with open(srt_path, 'w', encoding="utf-8") as srt:
write_srt(result["segments"], file=srt, line_length=break_lines)
if (subtitles_format == 'vtt'):
vtt_path = os.path.join(output_dir, f"{slugify(title)}.vtt")
with open(vtt_path, 'w', encoding="utf-8") as vtt:
write_vtt(result["segments"], file=vtt, line_length=break_lines)

print("Saved SRT to", os.path.abspath(srt_path))
print("Saved VTT to", os.path.abspath(vtt_path))
else:
srt_path = os.path.join(output_dir, f"{slugify(title)}.srt")
with open(srt_path, 'w', encoding="utf-8") as srt:
write_srt(result["segments"], file=srt, line_length=break_lines)

print("Saved SRT to", os.path.abspath(srt_path))

def get_audio(urls):
temp_dir = tempfile.gettempdir()

def get_audio(video_paths:List[str], temp_dir:tempfile.TemporaryDirectory):
ydl = yt_dlp.YoutubeDL({
'quiet': True,
'verbose': False,
Expand All @@ -76,12 +76,24 @@ def get_audio(urls):

paths = {}

for url in urls:
result = ydl.extract_info(url, download=True)
print(
f"Downloaded video \"{result['title']}\". Generating subtitles..."
)
paths[result["title"]] = os.path.join(temp_dir, f"{result['id']}.mp3")
for video_path in video_paths:
if os.path.exists(video_path):
title = os.path.basename(video_path).split(".")[0]
audio_path = os.path.join(temp_dir, f"{title}.mp3")
(ffmpeg
.input(video_path)
.audio
.output(audio_path)
.run()
)
paths[title] = audio_path
else:
result = ydl.extract_info(video_path, download=True)
print(
f"Downloaded video \"{result['title']}\". Generating subtitles..."
)
audio_path = os.path.join(temp_dir, f"{result['id']}.mp3")
paths[result["title"]] = audio_path

return paths

Expand Down