Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions Video-Description-Generation-Query-Retrieval/st_video_rag_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import streamlit as st
from sentence_transformers import SentenceTransformer

# Process-wide diagnostics: suppress every warning and log at INFO level.
warnings.filterwarnings(action="ignore")
logging.basicConfig(level=logging.INFO)

# Absolute path of the directory that dataset folders are confined to.
# It is created at import time so that it always exists.
SAFE_DATASET_ROOT = os.path.abspath("./datasets")
os.makedirs(name=SAFE_DATASET_ROOT, exist_ok=True)

Expand Down Expand Up @@ -149,7 +153,7 @@

# Dataset configuration
st.subheader("📁 Dataset")
dataset_folder = st.text_input("Video Folder", ".")
dataset_folder = st.text_input(f"Video Folder (relative to {SAFE_DATASET_ROOT})", ".")
max_videos = st.slider("Max Videos", 1, 128, 20)

st.markdown("---")
Expand Down Expand Up @@ -241,14 +245,24 @@ def generate_video_description_ollama(video_path, model, max_tokens=100, tempera


def get_video_paths(folder, max_count):
"""Get video file paths from folder."""
"""Get video file paths from folder. Validates that 'folder' is within SAFE_DATASET_ROOT."""
try:
video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv']
video_files = []

folder_path = os.path.abspath(folder)

for root, dirs, files in os.walk(folder_path):
# Normalize and validate to ensure folder stays inside SAFE_DATASET_ROOT
user_folder = folder.strip()
# If the user gives an absolute path, remove its "/"
user_folder = os.path.relpath(user_folder, "/") if os.path.isabs(user_folder) else user_folder
safe_target_path = os.path.normpath(os.path.abspath(os.path.join(SAFE_DATASET_ROOT, user_folder)))
if not safe_target_path.startswith(SAFE_DATASET_ROOT):
logging.error(f"Attempted access to forbidden folder: {safe_target_path}")
return []
if not os.path.isdir(safe_target_path):
logging.error(f"Folder does not exist: {safe_target_path}")
return []

for root, dirs, files in os.walk(safe_target_path):
video_files.extend([
os.path.join(root, f) for f in files
if any(f.lower().endswith(ext) for ext in video_extensions)
Expand Down
Loading