diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..6e4f236 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = server \ No newline at end of file diff --git a/server/audio_processing.py b/server/audio_processing.py index 16c73af..e80b340 100644 --- a/server/audio_processing.py +++ b/server/audio_processing.py @@ -104,7 +104,7 @@ def detect_ads(transcript): raise -def remove_ads(audio, ad_segments): +def remove_ads(audio, ad_segments, flag=None): """Removes ad segments from the audio file with optimized processing.""" if not ad_segments: print("No ads to remove.") @@ -121,7 +121,7 @@ def remove_ads(audio, ad_segments): start, end = segment["start"] * 1000, segment["end"] * 1000 # Convert to milliseconds # Edge Case 1: Adjust ads in the first 5 seconds - if start <= 5000: + if start <= 5000 and flag=="first": start = max(0, start - 1000) # Edge Case 2: Merge close ads @@ -131,7 +131,7 @@ def remove_ads(audio, ad_segments): merged_ads.append({"start": start, "end": end}) # Edge Case 3: Remove everything from the last ad if it's near the end - if merged_ads and merged_ads[-1]["end"] >= total_duration - 10000: + if merged_ads and merged_ads[-1]["end"] >= total_duration - 10000 and flag=="last": merged_ads[-1]["end"] = total_duration # Extract non-ad sections with proper updating of previous_end @@ -189,10 +189,14 @@ def process_audio(audio_segment, url, streaming): logger.info(f"Transcription complete for chunk {i + 1}/{len(chunks)}: {url}") ad_segments = detect_ads(transcription) - logger.info(f"Ad-analysis complete for chunk {i + 1}/{len(chunks)}: {url}") - - processed_chunk = remove_ads(chunk, ad_segments) - logger.info(f"Processing complete for chunk {i + 1}/{len(chunks)}: {url}") + logger.info(f"Ad-analysis complete for chunk {i+1}/{len(chunks)}: {url}") + flag = None + if i == 0: + flag = "first" + elif i == len(chunks) - 1: + flag = "last" + processed_chunk = remove_ads(chunk, ad_segments, flag=flag) + logger.info(f"Processing complete for chunk {i+1}/{len(chunks)}: {url}") if i == 0 and not streaming: processed_chunk = intro + processed_chunk diff --git a/server/tests/__init__.py b/server/enums/__init__.py similarity index 100% rename from server/tests/__init__.py rename to server/enums/__init__.py diff --git a/server/tests/helpers/__init__.py b/server/helpers/__init__.py similarity index 100% rename from server/tests/helpers/__init__.py rename to server/helpers/__init__.py diff --git a/server/pytest.ini b/server/pytest.ini deleted file mode 100644 index 03f586d..0000000 --- a/server/pytest.ini +++ /dev/null @@ -1,2 +0,0 @@ -[pytest] -pythonpath = . \ No newline at end of file diff --git a/server/tests/helpers/test_cache_helpers.py b/server/tests/helpers/test_cache_helpers.py deleted file mode 100644 index 55c2fef..0000000 --- a/server/tests/helpers/test_cache_helpers.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytest -import logging -from unittest.mock import MagicMock, patch -from flask import Flask - -from server.helpers.cache_helpers import ( - setup_cache, initiate_key, cache_audio, - retrieve_audio, cached_rss_url, cached_source_url -) - -@pytest.fixture -def mock_redis(): - """Creates a mock Redis client.""" - mock = MagicMock() - mock.scan_iter.return_value = iter([]) # Default: No cached data - return mock - - -@pytest.fixture -def mock_app(): - """Creates a mock Flask app.""" - app = Flask(__name__) - return app - -def test_setup_cache(mock_app, mock_redis): - """Test that cache is properly set up.""" - with patch("helpers.cache_helpers.redis_client", mock_redis): - mock_app.config["CACHE_TYPE"] = "simple" # ✅ Use in-memory cache for testing - setup_cache(mock_app, mock_redis) - assert mock_redis is not None # Ensure redis_client is set - - -def test_initiate_key(mock_redis, caplog): - """Test that initiate_key sets a value in Redis.""" - with patch("helpers.cache_helpers.redis_client", mock_redis): - with caplog.at_level(logging.ERROR): - initiate_key("test_key") - mock_redis.set.assert_called_with("test_key", "INIT") # Fix: Ensure this call actually happens - assert "Error initializing key in cache" not in caplog.text - - -def test_cache_audio(mock_redis): - """Test that cache_audio stores the file path in Redis.""" - with patch("helpers.cache_helpers.redis_client", mock_redis): - cache_audio("audio_key", "/path/to/audio.mp3") - mock_redis.set.assert_called_with("audio_key", "/path/to/audio.mp3") - - -def test_retrieve_audio_found(mock_redis): - """Test that retrieve_audio returns the correct path if found.""" - with patch("helpers.cache_helpers.redis_client", mock_redis): - mock_redis.scan_iter.side_effect = lambda pattern: iter(["audio_key"]) - mock_redis.get.return_value = b"/path/to/audio.mp3" - - result = retrieve_audio("audio.mp3") - assert result == b"/path/to/audio.mp3" - - -def test_retrieve_audio_not_found(mock_redis): - """Test that retrieve_audio returns None if not found.""" - with patch("helpers.cache_helpers.redis_client", mock_redis): - mock_redis.scan_iter.return_value = iter([]) - - result = retrieve_audio("audio.mp3") - assert result is None - - -def test_cached_rss_url(mock_redis): - """Test that cached_rss_url checks if an RSS URL is cached.""" - with patch("helpers.cache_helpers.redis_client", mock_redis): - mock_redis.scan_iter.return_value = iter(["rss_url::source_url"]) - assert cached_rss_url("rss_url") is True # rss_url should be found - - mock_redis.scan_iter.return_value = iter([]) - assert cached_rss_url("rss_url") is False - - -def test_cached_source_url(mock_redis): - """Test that cached_source_url checks if a source URL is cached.""" - with patch("helpers.cache_helpers.redis_client", mock_redis): - mock_redis.scan_iter.return_value = iter(["rss_url::source_url"]) - assert cached_source_url("source_url") is True # source_url should be found - - mock_redis.scan_iter.return_value = iter([]) - assert cached_source_url("source_url") is False \ No newline at end of file diff --git a/server/tests/helpers/test_file_helpers.py b/server/tests/helpers/test_file_helpers.py deleted file mode 100644 index 9754c1c..0000000 --- a/server/tests/helpers/test_file_helpers.py +++ /dev/null @@ -1,42 +0,0 @@ -import os -import pytest -from server.helpers.file_helpers import allowed_file, save_file, sanitize_filename # Bytt ut "your_module" med riktig filnavn - - -def test_allowed_file(): - """Test allowed_file function with valid and invalid file extensions.""" - allowed_extensions = {"wav", "flacc", "mp3"} - - # Valid file-types - assert allowed_file("image.wav", allowed_extensions) is True - assert allowed_file("audio.mp3", allowed_extensions) is True - - # invalid file-types - assert allowed_file("document.pdf", allowed_extensions) is False - assert allowed_file("script.exe", allowed_extensions) is False - - # files with no extension - assert allowed_file("nofileextension", allowed_extensions) is False - - -def test_save_file(): - """Test that save_file returns the correct file path.""" - upload_folder = "/uploads" - - expected_path = os.path.abspath(os.path.normpath("/uploads/testfile.txt")) - result_path = os.path.abspath(save_file("testfile.txt", upload_folder)) - assert result_path == expected_path - - expected_path = os.path.abspath(os.path.normpath("/uploads/audio.mp3")) - result_path = os.path.abspath(save_file("audio.mp3", upload_folder)) - assert result_path == expected_path -def test_sanitize_filename(): - """Test sanitize_filename function to ensure invalid characters are removed.""" - - # Remove invalid char in files - assert sanitize_filename('test<>file.txt') == 'test__file.txt' - assert sanitize_filename('my|file?.mp3') == 'my_file_.mp3' - - # Not change anything if file name is normal - assert sanitize_filename('normal_file.txt') == 'normal_file.txt' - assert sanitize_filename('audio.mp3') == 'audio.mp3' \ No newline at end of file diff --git a/server/tests/helpers/test_url_helpers.py b/server/tests/helpers/test_url_helpers.py deleted file mode 100644 index 9c37197..0000000 --- a/server/tests/helpers/test_url_helpers.py +++ /dev/null @@ -1,35 +0,0 @@ -from server.helpers.url_helpers import ( - normalize_url, generate_cache_url, - extract_name, extract_title, extract_extension -) - - -def test_normalize_url(): - """Ensures that URLs are normalized to end with .mp3.""" - assert normalize_url("https://example.com/audio.mp3?param=value") == "https://example.com/audio.mp3" - - -def test_generate_cache_url(): - """Tests that the cache URL is generated correctly.""" - assert generate_cache_url("rss_url", "source_url") == "rss_url::source_url" - - -def test_extract_name(): - """Tests that the file extension is extracted correctly from the URL.""" - assert extract_name("https://example.com/audio.mp3") == "mp3" - assert extract_name("https://example.com/path/to/song.wav") == "wav" - assert extract_name("https://example.com/no-extension") == "" - - -def test_extract_title(): - """Tests that the title (filename without extension) is correctly extracted.""" - assert extract_title("./uploads/audio.mp3") == "./uploads/audio" - assert extract_title("./uploads/song.wav") == "./uploads/song" - assert extract_title("./uploads/no-extension") == "./uploads/no-extension" - - -def test_extract_extension(): - """Tests that the file extension is extracted correctly.""" - assert extract_extension("https://example.com/audio.mp3") == ".mp3" - assert extract_extension("https://example.com/song.WAV") == ".wav" - assert extract_extension("https://example.com/no-extension") == "" \ No newline at end of file diff --git a/server/tests/whisper/podblock_test.py b/server/tests/whisper/podblock_test.py index 4ef6d31..95c356d 100644 --- a/server/tests/whisper/podblock_test.py +++ b/server/tests/whisper/podblock_test.py @@ -1,241 +1,242 @@ -import os -import re -import logging -import json - -import time -import openai -import pytest -import textwrap - -from pathlib import Path -from dotenv import load_dotenv -from time import perf_counter -from io import BytesIO - -from pydub import AudioSegment -from faster_whisper import WhisperModel, BatchedInferencePipeline - -env_path = Path(__file__).parents[2] / "api.env" -load_dotenv(dotenv_path=env_path) - -client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]) - -logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") -logger = logging.getLogger(__name__) - -NORWEGIAN_AUDIO_PATH = Path(__file__).parent / "resources" / "norwegian.mp3" -ENGLISH_AUDIO_PATH = Path(__file__).parent / "resources" / "english.mp3" - -WHISPER_MODELS = ["tiny", "base", "small"] -GPT_MODELS = ["gpt-4o-mini", "gpt-4o"] - -RESULTS_DIR = Path(__file__).parent / "results" -RESULTS_DIR.mkdir(exist_ok=True, parents=True) - - -def test_transcribe_and_save(): - if not NORWEGIAN_AUDIO_PATH.exists(): - pytest.skip(f"Audio file not found: {NORWEGIAN_AUDIO_PATH}") - - if not ENGLISH_AUDIO_PATH.exists(): - pytest.skip(f"Audio file not found: {ENGLISH_AUDIO_PATH}") - - audio_segment_norwegian = AudioSegment.from_mp3(NORWEGIAN_AUDIO_PATH) - chunks_norwegian = chunk_audio(audio_segment_norwegian) - - audio_segment_english = AudioSegment.from_mp3(ENGLISH_AUDIO_PATH) - chunks_english = chunk_audio(audio_segment_english) - - if not chunks_norwegian: - pytest.skip(f"Audio could not be chunked: {NORWEGIAN_AUDIO_PATH}") - - if not chunks_english: - pytest.skip(f"Audio could not be chunked: {ENGLISH_AUDIO_PATH}") - - process_chunks(chunks_norwegian, audio_segment_norwegian, NORWEGIAN_AUDIO_PATH, 'norwegian') - process_chunks(chunks_english, audio_segment_english, ENGLISH_AUDIO_PATH, 'english') - - -def process_chunks(chunks, audio_segment, audio_path, language): - for whisper_name in WHISPER_MODELS: - try: - logger.info("Loading model %s", whisper_name) - whisper_model = WhisperModel(whisper_name, device="cpu", compute_type="int8") - batched_model = BatchedInferencePipeline(model=whisper_model) - - times = [] - ads_results = [] - transcripts = [] - - logger.info(f"Transcribing {audio_path} with {whisper_name}!") - for i, chunk in enumerate(chunks): - logger.info(f'Processing chunk {i + 1}/{len(chunks)} with model: {whisper_name}') - duration = chunk.duration_seconds - chunk_start = perf_counter() - transcription = transcribe_audio(chunk, batched_model) - chunk_end = perf_counter() - - transcription_time = chunk_end - chunk_start - times.append({ - "chunk_id": i, - "duration": duration, - "transcription_time": transcription_time - }) - - transcripts.append({ - "chunk_id": i, - "transcription": transcription - }) - - for gpt_name in GPT_MODELS: - logger.info(f'Detecting ads with {gpt_name}') - ad_detection, usage = detect_ads(transcription, gpt_name) - ads_results.append({ - "model": gpt_name, - "usage": usage, - "ads": { - "chunk_id": i, - "ads": ad_detection - } - }) - logger.info( - f'Processed chunk {i + 1}/{len(chunks)} with model: {whisper_name} in {transcription_time:.2f} seconds') - - save_result(language, whisper_name, times, ads_results, transcripts, RESULTS_DIR) - logger.info(f"Completed benchmark for model {whisper_name}") - time_file = RESULTS_DIR / language / whisper_name / "time.json" - assert time_file.exists(), f"{whisper_name} did not produce time.json" - - except Exception as e: - logger.error(f"Error processing model {whisper_name}: {str(e)}") - - -def save_result(language, model_name, times, ads, transcripts, output_dir): - model_dir = output_dir / language /model_name - model_dir.mkdir(exist_ok=True, parents=True) - - # Save timing information - time_file = model_dir / "time.json" - with open(time_file, "w") as f: - time_data = { - "model": model_name, - "total_time": sum(item["transcription_time"] for item in times), - "chunk_times": times - } - json.dump(time_data, f, indent=2) - - # Save ad detection results - ads_file = model_dir / "ads.json" - with open(ads_file, "w") as f: - ads_data = { - "model": model_name, - "ads": ads - } - json.dump(ads_data, f, indent=2) - - # Save plain text transcription - text_file = model_dir / "transcription.txt" - with open(text_file, "w") as f: - for chunk in transcripts: - f.write(f"--- Chunk {chunk['chunk_id']} ---\n") - # Extract all words and join them - if chunk["transcription"]: - if isinstance(chunk["transcription"], str): - # Handle already formatted string - f.write(f"{chunk['transcription']}\n\n") - else: - # Handle structured transcription - full_text = " ".join(item["text"] for item in chunk["transcription"]) - f.write(f"{full_text}\n\n") - - no_ts_file = model_dir / "transcript_plain.txt" - with open(no_ts_file, "w") as f_plain: - for chunk in transcripts: - raw = chunk["transcription"] - cleaned = re.sub(r"\[\d+\.\d+-\d+\.\d+\]\s*", "", raw) - paragraph = " ".join(cleaned.split()) - wrapped = textwrap.fill(paragraph, width=80) - f_plain.write(f"--- Chunk {chunk['chunk_id']} ---\n") - f_plain.write(wrapped + "\n\n") - - logger.info(f"Results for model {model_name} saved to {model_dir}") - - -def chunk_audio(audio, chunk_duration_seconds=240, chunk_duration_ms=240000): - duration_seconds = audio.duration_seconds - if duration_seconds <= chunk_duration_seconds: - chunks = [audio] - else: - chunks = [audio[i:i + chunk_duration_ms] for i in range(0, len(audio), chunk_duration_ms)] - return chunks - - -def detect_ads(transcript, llm_model): - try: - completion = \ - client.chat.completions.create( - model=llm_model, - messages=[ - { - "role": "system", - "content": "You are a system that detects ads in audio transcriptions from podcasts. " - "Based on the word-level timestamps provided, determine the start and end times of any ad segments. " - "For each ad segment, provide a 5-word summary of the ad. " - "Provide ad segments in the format: start: