diff --git a/podcastfy/aiengines/__init__.py b/podcastfy/aiengines/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/podcastfy/aiengines/llm/base.py b/podcastfy/aiengines/llm/base.py
new file mode 100644
index 0000000..071f79f
--- /dev/null
+++ b/podcastfy/aiengines/llm/base.py
@@ -0,0 +1,23 @@
+from abc import ABC, abstractmethod
+from typing import List, Tuple
+
+from podcastfy.core.character import Character
+from podcastfy.core.content import Content
+
+
+class LLMBackend(ABC):
+ """Abstract base class for Language Model backends."""
+ # TODO a nice mixin/helper could be made to load prompt templates from conf file (both podcast settings and character settings)
+
+ @abstractmethod
+ def generate_transcript(self, content: List[Content], characters: List[Character]) -> List[Tuple[Character, str]]:
+ """
+ Generate text based on a given prompt.
+
+ Args:
+ prompt (str): The input prompt for text generation.
+
+ Returns:
+ List[Tuple[Character, str]]: A list of tuples containing speaker and text.
+ """
+ pass
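+
+
+# Illustrative sketch (not part of this module's API): a concrete backend only needs
+# to implement generate_transcript, e.g.
+#
+#     class EchoLLMBackend(LLMBackend):
+#         def generate_transcript(self, content, characters):
+#             # naively attribute every content item to the first character
+#             return [(characters[0], str(c.value)) for c in content]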
diff --git a/podcastfy/aiengines/llm/gemini_langchain.py b/podcastfy/aiengines/llm/gemini_langchain.py
new file mode 100644
index 0000000..0b9084e
--- /dev/null
+++ b/podcastfy/aiengines/llm/gemini_langchain.py
@@ -0,0 +1,152 @@
+"""
+Content Generator Module
+
+This module is responsible for generating Q&A content based on input texts using
+LangChain and Google's Generative AI (Gemini). It handles the interaction with the AI model and
+provides methods to generate and save the generated content.
+"""
+
+import os
+import re
+from typing import Optional, Dict, Any, List, Tuple
+
+from langchain_community.llms.llamafile import Llamafile
+from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_core.output_parsers import StrOutputParser
+from langchain import hub
+
+from podcastfy.content_generator import ContentGenerator
+from podcastfy.core.character import Character
+from podcastfy.aiengines.llm.base import LLMBackend
+from podcastfy.core.content import Content
+from podcastfy.utils.config_conversation import load_conversation_config
+from podcastfy.utils.config import load_config
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class DefaultPodcastifyTranscriptEngine(LLMBackend):
+ def __init__(self, api_key: str, conversation_config: Optional[Dict[str, Any]] = None, is_local: bool = False):
+ """
+ Initialize the DefaultPodcastifyTranscriptEngine.
+
+        Args:
+            api_key (str): API key for Google's Generative AI.
+            conversation_config (Optional[Dict[str, Any]]): Custom conversation configuration.
+            is_local (bool): Whether to use a local LLM instead of the Gemini API. Defaults to False.
+        """
+ self.content_generator = ContentGenerator(api_key, conversation_config)
+ self.is_local = is_local
+
+ def split_qa(self, input_text: str) -> List[Tuple[str, str]]:
+ """
+ Split the input text into question-answer pairs.
+
+ Args:
+ input_text (str): The input text containing Person1 and Person2 dialogues.
+
+ Returns:
+ List[Tuple[str, str]]: A list of tuples containing (Person1, Person2) dialogues.
+ """
+ # Add ending message to the end of input_text
+ input_text += f"{self.content_generator.ending_message}"
+
+ # Regular expression pattern to match Person1 and Person2 dialogues
+        pattern = r'<Person1>(.*?)</Person1>\s*<Person2>(.*?)</Person2>'
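+        # e.g. "<Person1>Hi there.</Person1>\n<Person2>Hello!</Person2>" -> [("Hi there.", "Hello!")]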
+
+ # Find all matches in the input text
+ matches = re.findall(pattern, input_text, re.DOTALL)
+
+ # Process the matches to remove extra whitespace and newlines
+ processed_matches = [
+ (
+ ' '.join(person1.split()).strip(),
+ ' '.join(person2.split()).strip()
+ )
+ for person1, person2 in matches
+ ]
+ return processed_matches
+
+ def generate_transcript(self, content: List[Content], characters: List[Character]) -> List[Tuple[Character, str]]:
+ image_file_paths = [c.value for c in content if c.type == 'image_path']
+ text_content = "\n\n".join(c.value for c in content if c.type == 'text')
+ content = self.content_generator.generate_qa_content(text_content, image_file_paths, is_local=self.is_local) # ideally in the future we pass characters here
+
+ q_a_pairs = self.split_qa(content)
+ transcript = []
+ for q_a_pair in q_a_pairs:
+ # Assign the speakers based on the order of the characters
+ speaker1, speaker2 = characters
+ speaker_1_text, speaker_2_text = q_a_pair
+ transcript.append((speaker1, speaker_1_text))
+ transcript.append((speaker2, speaker_2_text))
+ return transcript
+
+ # def generate_transcript(self, prompt: str, characters: List[Character]) -> List[Tuple[Character, str]]:
+ # content = self.content_generator.generate_qa_content(prompt, output_filepath=None, characters=characters)
+ #
+ # # Parse the generated content into the required format
+ # transcript = []
+ # for line in content.split('\n'):
+ # if ':' in line:
+ # speaker_name, text = line.split(':', 1)
+ # speaker = next((char for char in characters if char.name == speaker_name.strip()), None)
+ # if speaker:
+ # transcript.append((speaker, text.strip()))
+ #
+ # return transcript
+
+
+
+def main(seed: int = 42) -> None:
+ """
+    Generate Q&A content from the text files in the configured transcripts directory using the Gemini API.
+
+ Args:
+ seed (int): Random seed for reproducibility. Defaults to 42.
+
+ Returns:
+ None
+ """
+ try:
+ # Load configuration
+ config = load_config()
+
+ # Get the Gemini API key from the configuration
+ api_key = config.GEMINI_API_KEY
+ if not api_key:
+ raise ValueError("GEMINI_API_KEY not found in configuration")
+
+ # Initialize ContentGenerator
+ content_generator = DefaultPodcastifyTranscriptEngine(api_key)
+
+ # Read input text from file
+ input_text = ""
+ transcript_dir = config.get('output_directories', {}).get('transcripts', 'data/transcripts')
+ for filename in os.listdir(transcript_dir):
+ if filename.endswith('.txt'):
+ with open(os.path.join(transcript_dir, filename), 'r') as file:
+ input_text += file.read() + "\n\n"
+
+ # Generate Q&A content
+ config_conv = load_conversation_config()
+ characters = [
+ Character(name="Speaker 1", role=config_conv.get('roles_person1')),
+ Character(name="Speaker 2", role=config_conv.get('roles_person2')),
+ ]
+        response = content_generator.generate_transcript(
+            [Content(value=input_text, type="text")], characters
+        )
+
+        # Format the generated transcript as "Speaker: text" lines
+        response_text = "\n".join(f"{character.name}: {text}" for character, text in response)
+        print("Generated Q&A Content:")
+        print(response_text)
+
+        # Output response text to file
+        output_file = os.path.join(config.get('output_directories', {}).get('transcripts', 'data/transcripts'), 'response.txt')
+        with open(output_file, 'w') as file:
+            file.write(response_text)
+
+ except Exception as e:
+ logger.error(f"An error occurred while generating Q&A content: {str(e)}")
+ raise
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/podcastfy/aiengines/tts/base.py b/podcastfy/aiengines/tts/base.py
new file mode 100644
index 0000000..a776bd0
--- /dev/null
+++ b/podcastfy/aiengines/tts/base.py
@@ -0,0 +1,116 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Dict, Any, List, Union
+
+import yaml
+
+from podcastfy.core.character import Character
+from podcastfy.core.tts_configs import TTSConfig
+
+TTSBackend = Union["SyncTTSBackend", "AsyncTTSBackend"]
+
+
+class SyncTTSBackend(ABC):
+ """Protocol for synchronous Text-to-Speech backends."""
+
+ name: str
+
+ @abstractmethod
+ def text_to_speech(self, text: str, character: Character, output_path: Path) -> None:
+ """
+ Convert text to speech synchronously.
+
+ Args:
+ text (str): The text to convert to speech.
+ character (Character): The character for which to generate speech.
+ output_path (Path): The path to save the generated audio file.
+
+        Returns:
+            None: The generated audio is written to output_path.
+ """
+ pass
+
+
+class AsyncTTSBackend(ABC):
+ """Protocol for asynchronous Text-to-Speech backends."""
+
+ name: str
+
+ @abstractmethod
+ async def async_text_to_speech(self, text: str, character: Character, output_path: Path) -> None:
+ """
+ Convert text to speech asynchronously.
+
+ Args:
+ text (str): The text to convert to speech.
+ character (Character): The character for which to generate speech.
+ output_path (Path): The path to save the generated audio file.
+
+        Returns:
+            None: The generated audio is written to output_path.
+ """
+        pass
+
+
+class TTSConfigMixin:
+ """Mixin class to manage TTS external configurations."""
+
+ def __init__(self, config_file: str = 'podcastfy/conversation_config.yaml', name: str = "") -> None:
+ self.name = name
+ self.config_file = config_file
+ self.default_configs = self._load_default_configs()
+ self.tts_config_call_count = 0
+ self.character_tts_mapping = {}
+
+ def _load_default_configs(self) -> Dict[str, Any]:
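+        # The conversation config is assumed to contain, per backend name, a section
+        # shaped roughly like this (illustrative values):
+        #
+        #     text_to_speech:
+        #       openai:
+        #         default_voices:
+        #           question: echo
+        #           answer: shimmer
+        #         model: tts-1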
+ with open(self.config_file, 'r') as f:
+ config = yaml.safe_load(f)
+ tts_config = config.get('text_to_speech', {})
+ return tts_config.get(self.name, {})
+
+ def get_default_config(self) -> Dict[str, Any]:
+ return self.default_configs
+
+ def update_default_config(self, new_config: Dict[str, Any]) -> None:
+ self.default_configs.update(new_config)
+
+ def tts_config_for_character(self, character: Character) -> TTSConfig:
+ # note: a bit constrained by the fact that the config has just the question and answer fields
+ if character.name in self.character_tts_mapping:
+ return self.character_tts_mapping[character.name]
+
+ # Check if the character has a TTS config for this backend
+ if self.name in character.tts_configs:
+ tts_config = character.tts_configs[self.name]
+ else:
+ # If not, use the default config
+ default_voices = self.default_configs.get('default_voices', {})
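+            # Heuristic: the first character to request a config gets the 'question'
+            # voice; subsequent characters fall back to the 'answer' voice.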
+ if self.tts_config_call_count == 0:
+ voice = default_voices['question']
+ else:
+ voice = default_voices['answer']
+ model = self.default_configs.get('model')
+ self.tts_config_call_count += 1
+
+ tts_config = TTSConfig(
+ voice=voice,
+ backend=self.name,
+ extra_args={"model": model} if model else {}
+ )
+
+ # Merge the default config with the character-specific config
+ merged_config = TTSConfig(
+ voice=tts_config.voice or self.default_configs.get('default_voices', {}).get('question' if self.tts_config_call_count == 1 else 'answer', ''),
+ backend=self.name,
+ extra_args={**self.default_configs.get('extra_args', {}), **tts_config.extra_args}
+ )
+
+ self.character_tts_mapping[character.name] = merged_config
+ return merged_config
+
+ def preload_character_tts_mapping(self, characters: List[Character]) -> None:
+ for character in characters:
+ self.tts_config_for_character(character)
+
+ def get_character_tts_mapping(self) -> Dict[str, TTSConfig]:
+ return self.character_tts_mapping
diff --git a/podcastfy/aiengines/tts/tts_backends.py b/podcastfy/aiengines/tts/tts_backends.py
new file mode 100644
index 0000000..83e59b3
--- /dev/null
+++ b/podcastfy/aiengines/tts/tts_backends.py
@@ -0,0 +1,108 @@
+import os
+import uuid
+from abc import abstractmethod
+from pathlib import Path
+from tempfile import TemporaryFile, TemporaryDirectory
+from typing import Dict, Any, List, ClassVar, Optional
+import asyncio
+
+import openai
+
+import edge_tts
+from elevenlabs import client as elevenlabs_client
+
+from podcastfy.aiengines.tts.base import SyncTTSBackend, TTSConfigMixin, AsyncTTSBackend
+from podcastfy.core.character import Character
+
+
+class ElevenLabsTTS(SyncTTSBackend, AsyncTTSBackend, TTSConfigMixin):
+ name: str = "elevenlabs"
+
+    def __init__(self, api_key: Optional[str] = None, config_file: str = 'podcastfy/conversation_config.yaml'):
+ TTSConfigMixin.__init__(self, config_file, name=self.name)
+ self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
+
+ def text_to_speech(self, text: str, character: Character, output_path: Path) -> Path:
+ config = self.tts_config_for_character(character)
+        client = elevenlabs_client.ElevenLabs(api_key=self.api_key)  # client could be reused
+ content = client.generate(
+ text=text,
+ voice=config.voice,
+ model=config.extra_args.get('model', self.get_default_config().get('model', 'default'))
+ )
+ with open(output_path, "wb") as out:
+ for chunk in content:
+ if chunk:
+ out.write(chunk)
+ return output_path
+
+ async def async_text_to_speech(self, text: str, character: Character, output_path: Path) -> Path:
+ config = self.tts_config_for_character(character)
+ client = elevenlabs_client.AsyncElevenLabs(api_key=self.api_key)
+ content = await client.generate(
+ text=text,
+ voice=config.voice,
+ model=config.extra_args.get('model', self.get_default_config().get('model', 'default'))
+ )
+ with open(output_path, "wb") as out:
+ async for chunk in content:
+ if chunk:
+ out.write(chunk)
+
+
+class OpenAITTS(SyncTTSBackend, TTSConfigMixin):
+ name: str = "openai"
+
+    def __init__(self, api_key: Optional[str] = None, config_file: str = 'podcastfy/conversation_config.yaml'):
+ TTSConfigMixin.__init__(self, config_file, name=self.name)
+ self.api_key = api_key or os.getenv("OPENAI_API_KEY")
+
+ def text_to_speech(self, text: str, character: Character, output_path: Path) -> None:
+ config = self.tts_config_for_character(character)
+
+ print(f"OpenAI TTS: Converting text to speech for character {character.name} with voice {config.voice} \n text: {text}")
+ model = config.extra_args.get('model', self.get_default_config().get('model', 'tts-1'))
+ response = openai.audio.speech.create(
+ model=model,
+ voice=config.voice,
+ input=text
+ )
+ with open(output_path, "wb") as file:
+ file.write(response.content)
+
+
+
+class EdgeTTS(AsyncTTSBackend, TTSConfigMixin):
+ name: str = "edge"
+
+ def __init__(self, config_file: str = 'podcastfy/conversation_config.yaml'):
+ TTSConfigMixin.__init__(self, config_file, name=self.name)
+
+ async def async_text_to_speech(self, text: str, character: Character, output_path: Path) -> None:
+ config = self.tts_config_for_character(character)
+ communicate = edge_tts.Communicate(text, config.voice)
+ await communicate.save(str(output_path))
+
+# register
+SyncTTSBackend.register(ElevenLabsTTS)
+AsyncTTSBackend.register(ElevenLabsTTS)
+SyncTTSBackend.register(OpenAITTS)
+AsyncTTSBackend.register(EdgeTTS)
+
+
+
+# Example usage:
+if __name__ == "__main__":
+ from podcastfy.utils.config import load_config
+
+ config = load_config()
+    elevenlabs_tts = ElevenLabsTTS(config.ELEVENLABS_API_KEY)
+    openai_tts = OpenAITTS(config.OPENAI_API_KEY)
+    edge_backend = EdgeTTS()  # named to avoid shadowing the imported edge_tts module
+
+ dummy_character1 = Character("character1", "host", {}, "A friendly podcast host")
+ dummy_character2 = Character("character2", "guest", {}, "An expert guest")
+
+ output_dir = Path("output")
+ output_dir.mkdir(exist_ok=True)
+
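+    # Illustrative calls only (the ElevenLabs/OpenAI requests need valid API keys;
+    # Edge TTS needs none):
+    openai_tts.text_to_speech("Hello from the host!", dummy_character1, output_dir / "host_openai.mp3")
+    elevenlabs_tts.text_to_speech("Hello from the guest!", dummy_character2, output_dir / "guest_elevenlabs.mp3")
+    asyncio.run(edge_backend.async_text_to_speech("Hello again!", dummy_character1, output_dir / "host_edge.mp3"))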
diff --git a/podcastfy/client.py b/podcastfy/client.py
index 5b2b764..48ff400 100644
--- a/podcastfy/client.py
+++ b/podcastfy/client.py
@@ -5,121 +5,173 @@
from URLs or existing transcript files. It orchestrates the content extraction,
generation, and text-to-speech conversion processes.
"""
+import copy
import os
import uuid
import typer
import yaml
+
+from podcastfy.aiengines.llm.gemini_langchain import DefaultPodcastifyTranscriptEngine
+from podcastfy.aiengines.tts.base import TTSBackend
+from podcastfy.aiengines.tts.tts_backends import OpenAITTS, ElevenLabsTTS, EdgeTTS
+from podcastfy.core.audio import AudioManager
+from podcastfy.core.character import Character
+from podcastfy.core.content import Content
+from podcastfy.core.podcast import Podcast
+from podcastfy.core.transcript import Transcript
from podcastfy.content_parser.content_extractor import ContentExtractor
-from podcastfy.content_generator import ContentGenerator
-from podcastfy.text_to_speech import TextToSpeech
+from podcastfy.core.tts_configs import TTSConfig
from podcastfy.utils.config import Config, load_config
from podcastfy.utils.config_conversation import (
- ConversationConfig,
load_conversation_config,
)
from podcastfy.utils.logger import setup_logger
from typing import List, Optional, Dict, Any
-import copy
-
logger = setup_logger(__name__)
app = typer.Typer()
+def create_characters(config: Dict[str, Any]) -> List[Character]:
+ # in the future, we should load this from the config file
+ host = Character(
+ name="Person1",
+ role="Podcast host",
+ tts_configs={
+ "openai": TTSConfig(
+ voice=config["text_to_speech"]["openai"]["default_voices"]["question"],
+ backend="openai",
+ ),
+ "elevenlabs": TTSConfig(
+ voice=config["text_to_speech"]["elevenlabs"]["default_voices"][
+ "question"
+ ],
+ backend="elevenlabs",
+ ),
+ },
+ default_description_for_llm="{name} is an enthusiastic podcast host. Speaks clearly and engagingly.",
+ )
+
+ guest = Character(
+ name="Person2",
+ role="Expert guest",
+ tts_configs={
+ "openai": TTSConfig(
+ voice=config["text_to_speech"]["openai"]["default_voices"]["answer"],
+ backend="openai",
+ ),
+ "elevenlabs": TTSConfig(
+ voice=config["text_to_speech"]["elevenlabs"]["default_voices"][
+ "answer"
+ ],
+ backend="elevenlabs",
+ ),
+ },
+ default_description_for_llm="{name} is an expert guest. Shares knowledge in a friendly manner.",
+ )
+
+ return [host, guest]
+
+
+def create_tts_backends(config: Config) -> List[TTSBackend]:
+ return [
+ OpenAITTS(api_key=config.OPENAI_API_KEY),
+ ElevenLabsTTS(api_key=config.ELEVENLABS_API_KEY),
+ EdgeTTS(),
+ ]
-def process_content(
- urls=None,
- transcript_file=None,
- tts_model="openai",
- generate_audio=True,
- config=None,
- conversation_config: Optional[Dict[str, Any]] = None,
- image_paths: Optional[List[str]] = None,
- is_local: bool = False,
-):
- """
- Process URLs, a transcript file, or image paths to generate a podcast or transcript.
- Args:
- urls (Optional[List[str]]): A list of URLs to process.
- transcript_file (Optional[str]): Path to a transcript file.
- tts_model (str): The TTS model to use ('openai', 'elevenlabs' or 'edge'). Defaults to 'openai'.
- generate_audio (bool): Whether to generate audio or just a transcript. Defaults to True.
- config (Config): Configuration object to use. If None, default config will be loaded.
- conversation_config (Optional[Dict[str, Any]]): Custom conversation configuration.
- image_paths (Optional[List[str]]): List of image file paths to process.
- is_local (bool): Whether to use a local LLM. Defaults to False.
- Returns:
- Optional[str]: Path to the final podcast audio file, or None if only generating a transcript.
- """
+def process_content(
+ urls: Optional[List[str]] = None,
+ transcript_file: Optional[str] = None,
+ tts_model: str = "openai", # to be fixed, in case of characters, it should be a list of models
+ generate_audio: bool = True,
+ config: Optional[Config] = None,
+ conversation_config: Optional[Dict[str, Any]] = None,
+ image_paths: Optional[List[str]] = None,
+ is_local: bool = False,
+) -> str:
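+    """
+    Process URLs, a transcript file, or image paths to generate a podcast or transcript.
+
+    Args:
+        urls (Optional[List[str]]): A list of URLs to process.
+        transcript_file (Optional[str]): Path to a transcript file.
+        tts_model (str): The TTS model to use ('openai', 'elevenlabs' or 'edge'). Defaults to 'openai'.
+        generate_audio (bool): Whether to generate audio or just a transcript. Defaults to True.
+        config (Optional[Config]): Configuration object to use. If None, the default config is loaded.
+        conversation_config (Optional[Dict[str, Any]]): Custom conversation configuration.
+        image_paths (Optional[List[str]]): List of image file paths to process.
+        is_local (bool): Whether to use a local LLM. Defaults to False.
+
+    Returns:
+        str: Path to the final podcast audio file, or to the transcript file if generate_audio is False.
+    """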
try:
if config is None:
config = load_config()
-
+ if urls is None:
+ urls = []
# Load default conversation config
conv_config = load_conversation_config()
-
+
# Update with provided config if any
if conversation_config:
conv_config.configure(conversation_config)
-
+ characters = create_characters(conv_config.config_conversation)
+ tts_backends = obtain_tts_backend(config, tts_model)
+ audio_format = conv_config.config_conversation.get('text_to_speech')['audio_format']
+ temp_dir = conv_config.config_conversation.get('text_to_speech').get('temp_audio_dir')
+ audio_manager = AudioManager(tts_backends, audio_format=audio_format, audio_temp_dir=temp_dir, n_jobs=4)
if transcript_file:
logger.info(f"Using transcript file: {transcript_file}")
- with open(transcript_file, "r") as file:
- qa_content = file.read()
+ transcript = Transcript.load(
+ transcript_file, {char.name: char for char in characters}
+ )
+ podcast = Podcast.from_transcript(transcript, audio_manager, characters)
else:
- content_generator = ContentGenerator(
- api_key=config.GEMINI_API_KEY, conversation_config=conv_config.to_dict()
+ logger.info(f"Processing {len(urls)} links")
+ content_extractor = ContentExtractor()
+ content_generator = DefaultPodcastifyTranscriptEngine(
+ config.GEMINI_API_KEY, conversation_config, is_local=is_local
)
- if urls:
- logger.info(f"Processing {len(urls)} links")
- content_extractor = ContentExtractor()
- # Extract content from links
- contents = [content_extractor.extract_content(link) for link in urls]
- # Combine all extracted content
- combined_content = "\n\n".join(contents)
- else:
- combined_content = "" # Empty string if no URLs provided
-
- # Generate Q&A content
- random_filename = f"transcript_{uuid.uuid4().hex}.txt"
- transcript_filepath = os.path.join(
- config.get("output_directories")["transcripts"], random_filename
- )
- qa_content = content_generator.generate_qa_content(
- combined_content,
- image_file_paths=image_paths or [],
- output_filepath=transcript_filepath,
- is_local=is_local,
+ contents = [content_extractor.extract_content(url) for url in urls]
+ llm_contents = []
+ if contents:
+ llm_contents.append(Content(value="\n\n".join(contents), type="text"))
+ if image_paths:
+ llm_contents.extend(
+ [Content(value=image_path, type="image_path") for image_path in image_paths]
+ )
+ podcast = Podcast(
+ content=llm_contents,
+ llm_backend=content_generator,
+ audio_manager=audio_manager,
+ characters=characters,
)
+ directories = config.get("output_directories")
+ random_filename_no_suffix = f"podcast_{uuid.uuid4().hex}"
+ random_filename_mp3 = f"{random_filename_no_suffix}.mp3"
+ random_filename_transcript = f"{random_filename_no_suffix}.txt"
+ transcript_file_path = os.path.join(directories["transcripts"], random_filename_transcript)
if generate_audio:
- api_key = None
- # edge does not require an API key
- if tts_model != "edge":
- api_key = getattr(config, f"{tts_model.upper()}_API_KEY")
-
- text_to_speech = TextToSpeech(model=tts_model, api_key=api_key)
- # Convert text to speech using the specified model
- random_filename = f"podcast_{uuid.uuid4().hex}.mp3"
+ podcast.finalize()
+
+ # for the sake of the tests currently in place, but in the future, we should remove this and return the podcast object
audio_file = os.path.join(
- config.get("output_directories")["audio"], random_filename
+ directories["audio"], random_filename_mp3
)
- text_to_speech.convert_to_speech(qa_content, audio_file)
- logger.info(f"Podcast generated successfully using {tts_model} TTS model")
- return audio_file
+ podcast.transcript.export(transcript_file_path)
+ podcast.save(filepath=audio_file)
+ return audio_file # note: should return the podcast object instead, but for the sake of the tests, we return the audio file
else:
- logger.info(f"Transcript generated successfully: {transcript_filepath}")
- return transcript_filepath
-
+ podcast.build_transcript()
+ podcast.transcript.export(transcript_file_path)
+ logger.info(f"Transcript generated successfully: {random_filename_transcript}")
+ return transcript_file_path
except Exception as e:
logger.error(f"An error occurred in the process_content function: {str(e)}")
raise
+def obtain_tts_backend(config: Config, tts_model: str) -> Dict[str, TTSBackend]:
+    # Temporary solution: build all backends, then keep only the one matching tts_model
+    tts_backends = create_tts_backends(config)
+    tts_backends = {tts.name: tts for tts in tts_backends if tts.name == tts_model}
+ return tts_backends
+
+
@app.command()
def main(
urls: list[str] = typer.Option(None, "--url", "-u", help="URLs to process"),
diff --git a/podcastfy/content_generator.py b/podcastfy/content_generator.py
index 01502aa..5f3c190 100644
--- a/podcastfy/content_generator.py
+++ b/podcastfy/content_generator.py
@@ -71,6 +71,7 @@ def __init__(
self.content_generator_config = self.config.get("content_generator", {})
self.config_conversation = load_conversation_config(conversation_config)
+        self.ending_message = self.config_conversation.get('text_to_speech').get('ending_message', '')
def __compose_prompt(self, num_images: int):
"""
diff --git a/podcastfy/core/__init__.py b/podcastfy/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/podcastfy/core/audio.py b/podcastfy/core/audio.py
new file mode 100644
index 0000000..2591e5d
--- /dev/null
+++ b/podcastfy/core/audio.py
@@ -0,0 +1,106 @@
+import asyncio
+import atexit
+import os
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Optional, Dict, Union, List, cast, Tuple
+
+from pydub import AudioSegment
+
+from podcastfy.aiengines.tts.base import TTSBackend, SyncTTSBackend, AsyncTTSBackend
+from podcastfy.core.transcript import TranscriptSegment, Transcript
+
+
+class PodcastsAudioSegment:
+ """Represents an audio segment of the podcast."""
+
+ def __init__(self, filepath: Path, transcript_segment: Optional[TranscriptSegment] = None) -> None:
+ self.filepath = filepath
+ self.transcript_segment = transcript_segment
+ self._audio: Optional[AudioSegment] = None
+
+ @property
+ def audio(self) -> AudioSegment:
+ """Lazy-load the audio segment."""
+ if self._audio is None:
+ self._audio = AudioSegment.from_file(self.filepath)
+ return self._audio
+
+
+class AudioManager:
+    def __init__(self, tts_backends: Dict[str, TTSBackend], audio_format: str, n_jobs: int = 4, file_prefix: str = "", audio_temp_dir: Optional[str] = None) -> None:
+ self.audio_format = audio_format
+ self.tts_backends = tts_backends
+ self.n_jobs = n_jobs
+ self.has_async_backend = any(isinstance(backend, AsyncTTSBackend) for backend in self.tts_backends.values())
+ self.file_prefix = file_prefix
+ self.final_audio: Optional[AudioSegment] = None
+ if audio_temp_dir:
+ os.makedirs(audio_temp_dir, exist_ok=True)
+ self.temp_dir = Path(audio_temp_dir)
+ else:
+ self._temp_dir = TemporaryDirectory()
+ self.temp_dir = Path(self._temp_dir.name)
+ atexit.register(self._temp_dir.cleanup)
+
+ async def _async_build_audio_segments(self, transcript: Transcript) -> List[PodcastsAudioSegment]:
+ async def process_segment(segment_tuple: Tuple[TranscriptSegment, int]):
+ segment, index = segment_tuple
+ tts_backend = self._get_tts_backend(segment)
+ audio_path = Path(self.temp_dir) / f"{self.file_prefix}{index:04d}.{self.audio_format}"
+ if isinstance(tts_backend, AsyncTTSBackend):
+ await tts_backend.async_text_to_speech(
+ segment.text,
+ segment.speaker,
+ audio_path
+ )
+ else:
+ tts_backend.text_to_speech(
+ segment.text,
+ segment.speaker,
+ audio_path
+ )
+ return PodcastsAudioSegment(audio_path, segment)
+
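+        # Bound concurrency: at most n_jobs TTS requests run at the same time.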
+ semaphore = asyncio.Semaphore(self.n_jobs)
+
+ async def bounded_process_segment(segment_tuple):
+ async with semaphore:
+ return await process_segment(segment_tuple)
+
+ tasks = [asyncio.create_task(bounded_process_segment((segment, i))) for i, segment in enumerate(transcript.segments)]
+ return list(await asyncio.gather(*tasks))
+
+    def _get_tts_backend(self, segment: TranscriptSegment) -> TTSBackend:
+ tts_backend = self.tts_backends.get(segment.speaker.preferred_tts)
+ if tts_backend is None:
+ # Take the first available TTS backend
+ tts_backend = next(iter(self.tts_backends.values()))
+ return tts_backend
+
+ def _sync_build_audio_segments(self, transcript: Transcript) -> List[PodcastsAudioSegment]:
+ def process_segment(segment_tuple: Tuple[TranscriptSegment, int]):
+ segment, index = segment_tuple
+ tts_backend = self._get_tts_backend(segment)
+ filepath = Path(str(self.temp_dir)) / f"{self.file_prefix}{index:04d}.{self.audio_format}"
+ cast(SyncTTSBackend, tts_backend).text_to_speech(
+ segment.text,
+ segment.speaker,
+ filepath
+ )
+ return PodcastsAudioSegment(filepath, segment)
+
+
+ with ThreadPoolExecutor(max_workers=self.n_jobs) as executor:
+ return list(executor.map(process_segment,
+ ((segment, i) for i, segment in enumerate(transcript.segments))))
+
+ def create_audio_segments(self, transcript: Transcript) -> List[PodcastsAudioSegment]:
+ if self.has_async_backend:
+ return asyncio.run(self._async_build_audio_segments(transcript))
+ else:
+ return self._sync_build_audio_segments(transcript)
+
+ # def stitch_audio_segments(self) -> None:
+ # self.final_audio = sum((segment.audio for segment in self.audio_segments), AudioSegment.empty())
diff --git a/podcastfy/core/character.py b/podcastfy/core/character.py
new file mode 100644
index 0000000..ad6cdc2
--- /dev/null
+++ b/podcastfy/core/character.py
@@ -0,0 +1,30 @@
+from typing import Dict, Optional
+
+from podcastfy.core.tts_configs import TTSConfig
+
+
+class Character:
+ """Represents a character in the podcast."""
+
+ def __init__(self, name: str, role: str, tts_configs: Dict[str, TTSConfig] = {},
+ default_description_for_llm: str = ""):
+ self.name = name
+ self.role = role
+ self.tts_configs = tts_configs
+ self.default_description_for_llm = default_description_for_llm
+ self.preferred_tts = next(iter(tts_configs.keys()), None) # Set first TTS as default, can be None
+
+ def set_preferred_tts(self, tts_name: str):
+ if tts_name not in self.tts_configs:
+ raise ValueError(f"TTS backend '{tts_name}' not configured for this character")
+ self.preferred_tts = tts_name
+
+ def to_prompt(self) -> str:
+ """Convert the character information to a prompt for the LLM."""
+ #TODO: could be improved by adding more information than roles
+ return f"Character: {self.name}\nRole: {self.role}\n{self.default_description_for_llm.format(name=self.name)}"
+
+ def get_tts_args(self, tts_name: Optional[str] = None) -> TTSConfig:
+ """Get the TTS arguments for this character."""
+ tts_name = tts_name or self.preferred_tts
+ return self.tts_configs[tts_name]
diff --git a/podcastfy/core/content.py b/podcastfy/core/content.py
new file mode 100644
index 0000000..3fc6d70
--- /dev/null
+++ b/podcastfy/core/content.py
@@ -0,0 +1,9 @@
+from typing import Any
+from pydantic import BaseModel
+
+
+# we can do much better here, but for now, let's keep it simple
+
+class Content(BaseModel):
+ value: Any
+ type: str
\ No newline at end of file
diff --git a/podcastfy/core/podcast.py b/podcastfy/core/podcast.py
new file mode 100644
index 0000000..2b11267
--- /dev/null
+++ b/podcastfy/core/podcast.py
@@ -0,0 +1,361 @@
+from enum import Enum
+from pathlib import Path
+from typing import List, Optional, Dict, Any, Callable, Tuple, Union, Sequence, cast
+from tempfile import TemporaryDirectory
+import atexit
+from pydub import AudioSegment
+from functools import wraps
+from contextlib import contextmanager
+
+from podcastfy.aiengines.llm.base import LLMBackend
+from podcastfy.aiengines.tts.base import SyncTTSBackend, AsyncTTSBackend, TTSBackend
+from podcastfy.core.audio import PodcastsAudioSegment, AudioManager
+from podcastfy.core.character import Character
+from podcastfy.core.content import Content
+from podcastfy.core.transcript import TranscriptSegment, Transcript
+from podcastfy.core.tts_configs import TTSConfig
+
+
+class PodcastState(Enum):
+ """Enum representing the different states of a podcast during creation."""
+ INITIALIZED = 0 # Initial state when the Podcast object is created
+ TRANSCRIPT_BUILT = 1 # State after the transcript has been generated
+ AUDIO_SEGMENTS_BUILT = 2 # State after individual audio segments have been created
+ STITCHED = 3 # Final state after all audio segments have been combined
+
+
+def podcast_stage(func):
+ """Decorator to manage podcast stage transitions."""
+
+    def probably_same_func(method, func):
+        return method.__func__.__name__ == func.__name__
+
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ current_method = self._next_stage_methods[self.state]
+ print(f"Current state: {self.state.name}")
+ print(f"Executing: {func.__name__}")
+        if not probably_same_func(current_method, func) and not self._reworking:
+            raise Exception(f"Cannot execute {func.__name__} in current state {self.state.name}")
+
+ try:
+ result = func(self, *args, **kwargs)
+            self.state = PodcastState(self.state.value + 1)
+            print("Done!")
+ return result
+ except Exception as e:
+ print(f"Error in {func.__name__}: {str(e)}")
+ raise
+
+ return wrapper
+
+
+class Podcast:
+ """Main class for podcast creation and management."""
+
+ def __init__(self, content: List[Content], llm_backend: LLMBackend,
+ audio_manager: AudioManager,
+ characters: Optional[List[Character]] = None):
+ """
+ Initialize a new Podcast instance.
+
+ Args:
+ content (str): The raw content to be processed into a podcast.
+ llm_backend (LLMBackend): The language model backend for generating the transcript.
+ tts_backends (List[TTSBackend]): List of available TTS backends.
+ audio_temp_dir (Optional[str]): Path to a temporary directory for audio files. If None, a temporary
+ directory will be created.
+ characters (List[Character]): List of characters participating in the podcast.
+ default_tts_n_jobs (int, optional): The default number of concurrent jobs for TTS processing.
+ Defaults to 1.
+
+ Raises:
+ ValueError: If a character's preferred TTS backend is not available.
+ """
+ self.content = content
+ self.llm_backend = llm_backend
+ self.characters: Dict[str, Character] = {char.name: char for char in (characters or [Character("Host", "Podcast host", {}), Character("Guest", "Expert guest", {})])}
+ self.state = PodcastState.INITIALIZED
+ self._reworking = False
+ self.audio_manager = audio_manager
+
+ # Initialize attributes with null values
+ self.transcript: Optional[Transcript] = None
+ self.audio_segments: List[PodcastsAudioSegment] = []
+ self.audio: Optional[AudioSegment] = None
+
+ # Define the sequence of methods to be called for each stage
+ self._next_stage_methods: Dict[PodcastState, Callable[[], None]] = {
+ PodcastState.INITIALIZED: self.build_transcript,
+ PodcastState.TRANSCRIPT_BUILT: self.build_audio_segments,
+ PodcastState.AUDIO_SEGMENTS_BUILT: self.stitch_audio_segments,
+ }
+
+ def __del__(self) -> None:
+ if hasattr(self, '_temp_dir'):
+ self._temp_dir.cleanup()
+
+ @classmethod
+ def from_transcript(cls, transcript: Union[Sequence[Tuple[str, str]], Transcript],
+ audio_manager: AudioManager,
+ characters: List[Character]) -> 'Podcast':
+ """
+ Create a Podcast instance from a pre-existing transcript.
+
+ Args:
+ transcript (Union[Sequence[Tuple[str, str]], Transcript]): Pre-existing transcript.
+ audio_manager (AudioManager): The audio manager instance for creating audio segments.
+ characters (List[Character]): List of characters participating in the podcast.
+ Returns:
+ Podcast: A new Podcast instance with the transcript built and ready for audio generation.
+ """
+ if isinstance(transcript, Transcript):
+ podcast = cls("", cast(LLMBackend, None), audio_manager=audio_manager, characters=characters)
+ podcast.transcript = transcript
+ else:
+ raise ValueError("Transcript must be a Transcript instance") # unimplemented
+ podcast.state = PodcastState.TRANSCRIPT_BUILT
+ return podcast
+
+ def reset_to_state(self, state: PodcastState) -> None:
+ """Reset the podcast to a specific state. """
+ self.state = state
+ self.transcript = None if state.value < PodcastState.TRANSCRIPT_BUILT.value else self.transcript
+ self.audio_segments = [] if state.value < PodcastState.AUDIO_SEGMENTS_BUILT.value else self.audio_segments
+ self.audio = None if state.value < PodcastState.STITCHED.value else self.audio
+
+ @contextmanager
+ def rework(self, target_state: PodcastState, auto_finalize: bool = True):
+ """Context manager for reworking the podcast from a specific state."""
+ original_state = self.state
+ self._reworking = True
+
+ if target_state == PodcastState.INITIALIZED and self.llm_backend is None:
+ raise ValueError("Cannot rewind to INITIALIZED state without an LLM backend.")
+
+ if target_state.value < PodcastState.TRANSCRIPT_BUILT.value and self.llm_backend is None:
+ raise ValueError("Cannot rewind past TRANSCRIPT_BUILT state without an LLM backend.")
+
+ if target_state.value < self.state.value:
+ print(f"Rewinding from {self.state.name} to {target_state.name}")
+ self.reset_to_state(target_state)
+
+ try:
+ yield
+ finally:
+ self._reworking = False
+ if self.state.value < original_state.value:
+ print(
+ f"Warning: Podcast is now in an earlier state ({self.state.name}) than before reworking ({original_state.name}). You may want to call finalize() to rebuild.")
+ if auto_finalize:
+ self.finalize()
+
+ @podcast_stage
+ def build_transcript(self) -> None:
+ """Build the podcast transcript using the LLM backend."""
+ generated_segments = self.llm_backend.generate_transcript(self.content, list(self.characters.values()))
+
+ segments = []
+ for segment in generated_segments:
+ if isinstance(segment, tuple) and len(segment) == 2:
+ speaker, text = segment
+ if speaker.name in self.characters and text.strip():
+ tts_config = cast(Dict[str, Any], self.characters[speaker.name].tts_configs.get(self.characters[speaker.name].preferred_tts, {}))
+ segments.append(TranscriptSegment(text, self.characters[speaker.name], tts_config))
+ else:
+ print(f"Invalid segment: {segment}")
+ continue
+ # If the segment doesn't match the expected format, we'll skip it
+
+ self.transcript = Transcript(segments, {"source": "Generated content"})
+
+ @podcast_stage
+ def build_audio_segments(self) -> None:
+ """Build audio segments from the transcript."""
+ if self.transcript is not None:
+ self.audio_segments = self.audio_manager.create_audio_segments(self.transcript)
+ else:
+ print("Error: Transcript is None")
+ raise ValueError("Transcript must be built before creating audio segments")
+
+ @podcast_stage
+ def stitch_audio_segments(self) -> None:
+ """Stitch all audio segments together to form the final podcast audio."""
+ # order segments by filename
+ segments_to_stitch = sorted(self.audio_segments, key=lambda segment: segment.filepath)
+
+ self.audio = sum((segment.audio for segment in segments_to_stitch), AudioSegment.empty())
+
+ def _build_next_stage(self) -> bool:
+ """Build the next stage of the podcast."""
+ print("state: ", self.state)
+ if self.state == PodcastState.STITCHED:
+ return False
+
+ next_method = self._next_stage_methods[self.state]
+ next_method()
+ return True
+
+ def finalize(self) -> None:
+ """Finalize the podcast by building all remaining stages."""
+ while self._build_next_stage():
+ pass
+
+ def save(self, filepath: str) -> None:
+ """Save the finalized podcast audio to a file."""
+ if self.state != PodcastState.STITCHED:
+ raise ValueError("Podcast can only be saved after audio is stitched")
+
+ if self.audio:
+ self.audio.export(filepath, format="mp3")
+ else:
+ raise ValueError("No stitched audio to save")
+
+ def export_transcript(self, filepath: str, format_: str = "plaintext") -> None:
+ """Save the podcast transcript to a file."""
+ if self.state.value < PodcastState.TRANSCRIPT_BUILT.value:
+ raise ValueError("Transcript can only be saved after it is built")
+
+ if self.transcript:
+ self.transcript.export(filepath, format_)
+ else:
+ raise ValueError("No transcript to save")
+
+ def dump_transcript(self, filepath: str) -> None:
+ """Dump the podcast transcript to a JSON file."""
+ if self.state.value < PodcastState.TRANSCRIPT_BUILT.value:
+ raise ValueError("Transcript can only be dumped after it is built")
+
+ if self.transcript:
+ self.transcript.dump(filepath)
+ else:
+ raise ValueError("No transcript to dump")
+
+ @classmethod
+    def load_transcript(cls, filepath: str, audio_manager: AudioManager,
+                        characters: List[Character]) -> 'Podcast':
+        """Load a podcast from a transcript JSON file."""
+        character_dict = {char.name: char for char in characters}
+        transcript = Transcript.load(filepath, character_dict)
+        return cls.from_transcript(transcript, audio_manager, characters)
+
+
+# Usage example: Step-by-step podcast creation
+if __name__ == "__main__":
+ from tempfile import NamedTemporaryFile
+
+
+ class DummyLLMBackend(LLMBackend):
+        def generate_transcript(self, content: List[Content], characters: List[Character]) -> List[Tuple[Character, str]]:
+ return [(characters[0], "Welcome to our podcast!"), (characters[1], "Thanks for having me!")]
+
+
+ class DummyTTSBackend(SyncTTSBackend):
+ def __init__(self, name: str):
+ self.name = name
+
+ def text_to_speech(self, text: str, character: Character, output_path: Path) -> Path:
+ audio = AudioSegment.silent(duration=1000)
+ audio.export(str(output_path), format="mp3")
+ return output_path
+
+
+ # Define TTS backends
+ openai_tts = DummyTTSBackend("openai")
+ elevenlabs_tts = DummyTTSBackend("elevenlabs")
+
+    # Define characters
+ host = Character(
+ name="Host",
+ role="Podcast host",
+ tts_configs={
+ "openai": TTSConfig(voice="en-US-Neural2-F", backend="openai", extra_args={"speaking_rate": 1.0}),
+ "elevenlabs": TTSConfig(voice="Rachel", backend="elevenlabs", extra_args={"stability": 0.5})
+ },
+ default_description_for_llm="{name} is an enthusiastic podcast host. Speaks clearly and engagingly."
+ )
+
+ guest = Character(
+ name="Guest",
+ role="Expert guest",
+ tts_configs={
+ "openai": TTSConfig(voice="en-US-Neural2-D", backend="openai", extra_args={"pitch": -2.0}),
+ "elevenlabs": TTSConfig(voice="Antoni", backend="elevenlabs", extra_args={"stability": 0.8})
+ },
+ default_description_for_llm="{name} is an expert guest. Shares knowledge in a friendly manner."
+ )
+
+    # Initialize the podcast (Podcast expects Content items and an AudioManager)
+    audio_manager = AudioManager(
+        {"openai": openai_tts, "elevenlabs": elevenlabs_tts},
+        audio_format="mp3",
+    )
+    podcast = Podcast(
+        content=[Content(value=(
+            "This is a sample content for our podcast. "
+            "It includes information from multiple sources that have already been parsed."
+        ), type="text")],
+        llm_backend=DummyLLMBackend(),
+        audio_manager=audio_manager,
+        characters=[host, guest],
+    )
+ print(f"Initial state: {podcast.state}")
+
+ # Step 1: Build transcript
+ podcast.build_transcript()
+ print(f"After building transcript: {podcast.state}")
+ print(f"Transcript: {podcast.transcript}")
+
+ # Step 2: Build audio segments
+ podcast.build_audio_segments()
+ print(f"After building audio segments: {podcast.state}")
+ print(f"Number of audio segments: {len(podcast.audio_segments)}")
+
+ # Step 3: Stitch audio segments
+ podcast.stitch_audio_segments()
+ print(f"After stitching audio: {podcast.state}")
+
+ # Rework example: modify the transcript and rebuild (auto_finalize is True by default)
+ with podcast.rework(PodcastState.TRANSCRIPT_BUILT):
+ print(f"Inside rework context, state: {podcast.state}")
+ podcast.transcript.segments.append(
+ TranscriptSegment("This is a new segment", podcast.characters["Host"]))
+ print("Added new segment to transcript")
+
+ # Rebuild audio segments and stitch
+ podcast.build_audio_segments()
+
+ print(f"After rework: {podcast.state}")
+
+ # Add a new audio segment (auto_finalize is True by default)
+ with NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
+ AudioSegment.silent(duration=500).export(temp_file.name, format="mp3")
+
+ with podcast.rework(PodcastState.AUDIO_SEGMENTS_BUILT):
+        new_segment = PodcastsAudioSegment(Path(temp_file.name),
+                                           TranscriptSegment("New audio segment", podcast.characters["Host"]))
+ podcast.audio_segments.insert(0, new_segment)
+
+ # Save the final podcast
+ podcast.save("./final.mp3")
+ podcast.export_transcript("./final.txt", format_="plaintext")
+ print("Saved podcast and transcript")
+
+    # Example with a pre-existing transcript using the from_transcript class method
+    # (from_transcript currently accepts only a Transcript instance)
+    pre_existing_transcript = Transcript([
+        TranscriptSegment("Welcome to our podcast created from a pre-existing transcript!", host),
+        TranscriptSegment("Thank you for having me. I'm excited to be here.", guest),
+    ])
+
+    podcast_from_transcript = Podcast.from_transcript(
+        transcript=pre_existing_transcript,
+        audio_manager=audio_manager,
+        characters=[host, guest]
+    )
+
+ print(f"Podcast created from transcript initial state: {podcast_from_transcript.state}")
+ print(f"Transcript: {podcast_from_transcript.transcript}")
+
+ # Finalize the podcast (this will skip transcript generation and move directly to audio generation)
+ podcast_from_transcript.finalize()
+ podcast_from_transcript.save("./from_transcript.mp3")
+ print("Saved podcast created from transcript")
diff --git a/podcastfy/core/transcript.py b/podcastfy/core/transcript.py
new file mode 100644
index 0000000..785bd55
--- /dev/null
+++ b/podcastfy/core/transcript.py
@@ -0,0 +1,127 @@
+import json
+import re
+from typing import Optional, Dict, Any, List, Tuple
+
+from podcastfy.core.character import Character
+
+
+
+class TranscriptSegment:
+ def __init__(self, text: str, speaker: Character,
+ tts_args: Optional[Dict[str, Any]] = None,
+ auto_clean_markup=True) -> None:
+ self.text = self._clean_markups(text) if auto_clean_markup else text
+ self.speaker = speaker
+ self.tts_args = tts_args or {}
+
+ @staticmethod
+ def _clean_markups(input_text: str) -> str:
+ """
+        Remove unsupported TTS markup tags from the input text while preserving supported SSML tags.
+
+        Args:
+            input_text (str): The input text containing TTS markup tags.
+
+        Returns:
+            str: Cleaned text with unsupported TTS markup tags removed.
+ """
+ # List of SSML tags supported by both OpenAI and ElevenLabs
+ supported_tags = [
+            'speak', 'lang', 'p', 'phoneme',
+            's', 'say-as', 'sub'
+        ]
+        # Create a pattern that matches any tag not in the supported list
+        pattern = r'<(?!/?(?:' + '|'.join(supported_tags) + r')\b)[^>]+>'
+
+ # Remove unsupported tags
+ cleaned_text = re.sub(pattern, '', input_text)
+
+ # Remove any leftover empty lines
+ cleaned_text = re.sub(r'\n\s*\n', '\n', cleaned_text)
+ cleaned_text = cleaned_text.replace('(scratchpad)', '')
+ return cleaned_text
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "text": self.text,
+ "speaker": self.speaker.name,
+ "tts_args": self.tts_args
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any], characters: Dict[str, Character]) -> 'TranscriptSegment':
+ return cls(
+ text=data['text'],
+ speaker=characters[data['speaker']],
+ tts_args=data.get('tts_args', {})
+ )
+
+
+class Transcript:
+ def __init__(self, segments: List[TranscriptSegment], metadata: Dict[str, Any] = {}) -> None:
+ self.segments = segments
+ self.metadata = metadata
+
+ def export(self, filepath: str, format_: str = "plaintext") -> None:
+ """Export the transcript to a file."""
+ with open(filepath, 'w') as f:
+ if format_ == "plaintext":
+ f.write(str(self))
+ elif format_ == "json":
+ json.dump(self.to_dict(), f, indent=2)
+ else:
+ raise ValueError(f"Unsupported format: {format_}")
+
+ def dump(self, filepath: str) -> None:
+ """Dump the transcript to a JSON file."""
+ with open(filepath, 'w') as f:
+ json.dump(self.to_dict(), f, indent=2)
+
+ @staticmethod
+ def _parse_legacy_transcript(content: str) -> List[Tuple[str, str]]:
+        # In the future, "Person" should be replaced by arbitrary character names; for now only
+        # <PersonN>...</PersonN> tags are parsed. This is tricky because we don't want to treat a
+        # random tag as a character name, though assuming the first tag of each line is the speaker
+        # would probably be acceptable.
+        pattern = r'<Person(\d+)>\s*(.*?)\s*</Person\1>'
+ matches = re.findall(pattern, content, re.DOTALL)
+ return [('Person' + person_num, text) for person_num, text in matches]
+
+ @classmethod
+ def load(cls, filepath: str, characters: Dict[str, Character]) -> 'Transcript':
+ """Load a transcript from a JSON file."""
+        # Note: some character information is lost when loading a transcript; is that acceptable?
+ with open(filepath, 'r') as f:
+ content = f.read()
+
+ try:
+ data = json.loads(content)
+ segments = [TranscriptSegment.from_dict(seg, characters) for seg in data['segments']]
+ except json.JSONDecodeError:
+ # If JSON parsing fails, assume it's a legacy transcript
+ parsed_content = cls._parse_legacy_transcript(content)
+ segments = []
+ for speaker, text in parsed_content:
+ if speaker in characters:
+ character = characters[speaker]
+ else:
+ # Create a new character if it doesn't exist
+ character = Character(speaker, f"Character {speaker}", {})
+ characters[speaker] = character
+ segments.append(TranscriptSegment(text, character))
+
+ data = {'segments': segments, 'metadata': {}}
+ return cls(segments, data['metadata'])
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "segments": [segment.to_dict() for segment in self.segments],
+ "metadata": self.metadata
+ }
+
+ def __str__(self) -> str:
+ """Convert the transcript to a xml representation."""
+ lines = []
+ for segment in self.segments:
+            lines.append(f'<{segment.speaker.name}>{segment.text}</{segment.speaker.name}>')
+ return '\n'.join(lines)
+
diff --git a/podcastfy/core/tts_configs.py b/podcastfy/core/tts_configs.py
new file mode 100644
index 0000000..c46ed25
--- /dev/null
+++ b/podcastfy/core/tts_configs.py
@@ -0,0 +1,12 @@
+from typing import Dict, Any
+
+from pydantic import BaseModel
+
+
+class VoiceConfig(BaseModel):
+ voice: str
+ extra_args: Dict[str, Any] = {}
+
+
+class TTSConfig(VoiceConfig):
+ backend: str
diff --git a/podcastfy/text_to_speech.py b/podcastfy/text_to_speech.py
deleted file mode 100644
index 977272e..0000000
--- a/podcastfy/text_to_speech.py
+++ /dev/null
@@ -1,353 +0,0 @@
-"""
-Text-to-Speech Module
-
-This module provides functionality to convert text into speech using various TTS models.
-It supports both ElevenLabs, OpenAI and Edge TTS services and handles the conversion process,
-including cleaning of input text and merging of audio files.
-"""
-
-import logging
-import asyncio
-import edge_tts
-from elevenlabs import client as elevenlabs_client
-from podcastfy.utils.config import load_config
-from podcastfy.utils.config_conversation import load_conversation_config
-from pydub import AudioSegment
-import os
-import re
-import openai
-from typing import List, Tuple, Optional, Union
-
-logger = logging.getLogger(__name__)
-
-class TextToSpeech:
- def __init__(self, model: str = 'openai', api_key: Optional[str] = None):
- """
- Initialize the TextToSpeech class.
-
- Args:
- model (str): The model to use for text-to-speech conversion.
- Options are 'elevenlabs', 'openai' or 'edge'. Defaults to 'openai'.
- api_key (Optional[str]): API key for the selected text-to-speech service.
- If not provided, it will be loaded from the config.
- """
- self.model = model.lower()
- self.config = load_config()
- self.conversation_config = load_conversation_config()
- self.tts_config = self.conversation_config.get('text_to_speech')
-
- if self.model == 'elevenlabs':
- self.api_key = api_key or self.config.ELEVENLABS_API_KEY
- self.client = elevenlabs_client.ElevenLabs(api_key=self.api_key)
- elif self.model == 'openai':
- self.api_key = api_key or self.config.OPENAI_API_KEY
- openai.api_key = self.api_key
- elif self.model == 'edge':
- pass
- else:
- raise ValueError("Invalid model. Choose 'elevenlabs', 'openai' or 'edge'.")
-
- self.audio_format = self.tts_config['audio_format']
- self.temp_audio_dir = self.tts_config['temp_audio_dir']
- self.ending_message = self.tts_config['ending_message']
-
- # Create temp_audio_dir if it doesn't exist
- if not os.path.exists(self.temp_audio_dir):
- os.makedirs(self.temp_audio_dir)
-
- def __merge_audio_files(self, input_dir: str, output_file: str) -> None:
- """
- Merge all audio files in the input directory sequentially and save the result.
-
- Args:
- input_dir (str): Path to the directory containing audio files.
- output_file (str): Path to save the merged audio file.
- """
- try:
- # Function to sort filenames naturally
- def natural_sort_key(filename: str) -> List[Union[int, str]]:
- return [int(text) if text.isdigit() else text for text in re.split(r'(\d+)', filename)]
-
- combined = AudioSegment.empty()
- audio_files = sorted(
- [f for f in os.listdir(input_dir) if f.endswith(f".{self.audio_format}")],
- key=natural_sort_key
- )
- for file in audio_files:
- if file.endswith(f".{self.audio_format}"):
- file_path = os.path.join(input_dir, file)
- combined += AudioSegment.from_file(file_path, format=self.audio_format)
-
- combined.export(output_file, format=self.audio_format)
- logger.info(f"Merged audio saved to {output_file}")
- except Exception as e:
- logger.error(f"Error merging audio files: {str(e)}")
- raise
-
- def convert_to_speech(self, text: str, output_file: str) -> None:
- """
- Convert input text to speech and save as an audio file.
-
- Args:
- text (str): Input text to convert to speech.
- output_file (str): Path to save the output audio file.
-
- Raises:
- Exception: If there's an error in converting text to speech.
- """
- # Clean TSS markup tags from the input text
- cleaned_text = self.clean_tss_markup(text)
-
- if self.model == 'elevenlabs':
- self.__convert_to_speech_elevenlabs(cleaned_text, output_file)
- elif self.model == 'openai':
- self.__convert_to_speech_openai(cleaned_text, output_file)
- elif self.model == 'edge':
- self.__convert_to_speech_edge(cleaned_text, output_file)
-
- def __convert_to_speech_elevenlabs(self, text: str, output_file: str) -> None:
- try:
- qa_pairs = self.split_qa(text)
- audio_files = []
- counter = 0
- for question, answer in qa_pairs:
- question_audio = self.client.generate(
- text=question,
- voice=self.tts_config['elevenlabs']['default_voices']['question'],
- model=self.tts_config['elevenlabs']['model']
- )
- answer_audio = self.client.generate(
- text=answer,
- voice=self.tts_config['elevenlabs']['default_voices']['answer'],
- model=self.tts_config['elevenlabs']['model']
- )
-
- # Save question and answer audio chunks
- for audio in [question_audio, answer_audio]:
- counter += 1
- file_name = f"{self.temp_audio_dir}{counter}.{self.audio_format}"
- with open(file_name, "wb") as out:
- for chunk in audio:
- if chunk:
- out.write(chunk)
- audio_files.append(file_name)
-
- # Merge all audio files and save the result
- self.__merge_audio_files(self.temp_audio_dir, output_file)
-
- # Clean up individual audio files
- for file in audio_files:
- os.remove(file)
-
- logger.info(f"Audio saved to {output_file}")
-
- except Exception as e:
- logger.error(f"Error converting text to speech with ElevenLabs: {str(e)}")
- raise
-
- def __convert_to_speech_openai(self, text: str, output_file: str) -> None:
- try:
- qa_pairs = self.split_qa(text)
- print(qa_pairs)
- audio_files = []
- counter = 0
- for question, answer in qa_pairs:
- for speaker, content in [
- (self.tts_config['openai']['default_voices']['question'], question),
- (self.tts_config['openai']['default_voices']['answer'], answer)
- ]:
- counter += 1
- file_name = f"{self.temp_audio_dir}{counter}.{self.audio_format}"
- response = openai.audio.speech.create(
- model=self.tts_config['openai']['model'],
- voice=speaker,
- input=content
- )
- with open(file_name, "wb") as file:
- file.write(response.content)
-
- audio_files.append(file_name)
-
- # Merge all audio files and save the result
- self.__merge_audio_files(self.temp_audio_dir, output_file)
-
- # Clean up individual audio files
- for file in audio_files:
- os.remove(file)
-
- logger.info(f"Audio saved to {output_file}")
-
- except Exception as e:
- logger.error(f"Error converting text to speech with OpenAI: {str(e)}")
- raise
-
- def get_or_create_eventloop():
- try:
- return asyncio.get_event_loop()
- except RuntimeError as ex:
- if "There is no current event loop in thread" in str(ex):
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- return asyncio.get_event_loop()
-
- import nest_asyncio # type: ignore
- get_or_create_eventloop()
- nest_asyncio.apply()
-
- def __convert_to_speech_edge(self, text: str, output_file: str) -> None:
- """
- Convert text to speech using Edge TTS.
-
- Args:
- text (str): The input text to convert to speech.
- output_file (str): The path to save the output audio file.
- """
- try:
- qa_pairs = self.split_qa(text)
- audio_files = []
- counter = 0
-
- async def edge_tts_conversion(text_chunk: str, output_path: str, voice: str):
- tts = edge_tts.Communicate(text_chunk, voice)
- await tts.save(output_path)
- return
-
- async def process_qa_pairs(qa_pairs):
- nonlocal counter
- tasks = []
- for question, answer in qa_pairs:
- for speaker, content in [
- (self.tts_config['edge']['default_voices']['question'], question),
- (self.tts_config['edge']['default_voices']['answer'], answer)
- ]:
- counter += 1
- file_name = f"{self.temp_audio_dir}{counter}.{self.audio_format}"
- tasks.append(asyncio.ensure_future(edge_tts_conversion(content, file_name, speaker)))
- audio_files.append(file_name)
-
- await asyncio.gather(*tasks)
-
- asyncio.run(process_qa_pairs(qa_pairs))
-
- # Merge all audio files
- self.__merge_audio_files(self.temp_audio_dir, output_file)
-
- # Clean up individual audio files
- for file in audio_files:
- os.remove(file)
- logger.info(f"Audio saved to {output_file}")
-
- except Exception as e:
- logger.error(f"Error converting text to speech with Edge: {str(e)}")
- raise
-
-
- def split_qa(self, input_text: str) -> List[Tuple[str, str]]:
- """
- Split the input text into question-answer pairs.
-
- Args:
- input_text (str): The input text containing Person1 and Person2 dialogues.
-
- Returns:
- List[Tuple[str, str]]: A list of tuples containing (Person1, Person2) dialogues.
- """
- # Add ending message to the end of input_text
- input_text += f"{self.ending_message}"
-
- # Regular expression pattern to match Person1 and Person2 dialogues
-        pattern = r'<Person1>(.*?)</Person1>\s*<Person2>(.*?)</Person2>'
-
- # Find all matches in the input text
- matches = re.findall(pattern, input_text, re.DOTALL)
-
- # Process the matches to remove extra whitespace and newlines
- processed_matches = [
- (
- ' '.join(person1.split()).strip(),
- ' '.join(person2.split()).strip()
- )
- for person1, person2 in matches
- ]
- return processed_matches
-
- # To be done: add support for additional tags dynamically based on the TTS model. Right now it's the intersection of tags supported by OpenAI, MS Edge, and ElevenLabs.
- def clean_tss_markup(self, input_text: str, additional_tags: List[str] = ["Person1", "Person2"]) -> str:
- """
- Remove unsupported TSS markup tags from the input text while preserving supported SSML tags.
-
- Args:
- input_text (str): The input text containing TSS markup tags.
- additional_tags (List[str]): Optional list of additional tags to preserve. Defaults to ["Person1", "Person2"].
-
- Returns:
- str: Cleaned text with unsupported TSS markup tags removed.
- """
- # List of SSML tags supported by both OpenAI and ElevenLabs
- supported_tags = [
- 'speak', 'lang', 'p', 'phoneme',
- 's', 'say-as', 'sub'
- ]
-
- # Append additional tags to the supported tags list
- supported_tags.extend(additional_tags)
-
- # Create a pattern that matches any tag not in the supported list
- pattern = r'</?(?!(?:' + '|'.join(supported_tags) + r')\b)[^>]+>'
-
- # Remove unsupported tags
- cleaned_text = re.sub(pattern, '', input_text)
-
- # Remove any leftover empty lines
- cleaned_text = re.sub(r'\n\s*\n', '\n', cleaned_text)
-
- # Ensure closing tags for additional tags are preserved
- for tag in additional_tags:
- cleaned_text = re.sub(f'<{tag}>(.*?)(?=<(?:{"|".join(additional_tags)})>|$)',
- f'<{tag}>\\1</{tag}>',
- cleaned_text,
- flags=re.DOTALL)
- # Remove '(scratchpad)' from cleaned_text
- cleaned_text = cleaned_text.replace('(scratchpad)', '')
-
- return cleaned_text.strip()
-
-def main(seed: int = 42) -> None:
- """
- Main function to test the TextToSpeech class.
-
- Args:
- seed (int): Random seed for reproducibility. Defaults to 42.
- """
- try:
- # Load configuration
- config = load_config()
-
- # Read input text from file
- with open('tests/data/transcript_336aa9f955cd4019bc1287379a5a2820.txt', 'r') as file:
- input_text = file.read()
-
- # Test ElevenLabs
- tts_elevenlabs = TextToSpeech(model='elevenlabs')
- elevenlabs_output_file = 'tests/data/response_elevenlabs.mp3'
- tts_elevenlabs.convert_to_speech(input_text, elevenlabs_output_file)
- logger.info(f"ElevenLabs TTS completed. Output saved to {elevenlabs_output_file}")
-
- # Test OpenAI
- tts_openai = TextToSpeech(model='openai')
- openai_output_file = 'tests/data/response_openai.mp3'
- tts_openai.convert_to_speech(input_text, openai_output_file)
- logger.info(f"OpenAI TTS completed. Output saved to {openai_output_file}")
-
- # Test Edge
- tts_edge = TextToSpeech(model='edge')
- edge_output_file = 'tests/data/response_edge.mp3'
- tts_edge.convert_to_speech(input_text, edge_output_file)
- logger.info(f"Edge TTS completed. Output saved to {edge_output_file}")
-
- except Exception as e:
- logger.error(f"An error occurred during text-to-speech conversion: {str(e)}")
- raise
-
-if __name__ == "__main__":
- main(seed=42)
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 9fb07aa..4758f2e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,10 +44,12 @@ types-pyyaml = "^6.0.12.20240917"
nest-asyncio = "^1.6.0"
ffmpeg = "^1.4"
pytest = "^8.3.3"
+pytest-asyncio = "^0.24.0"
[tool.poetry.group.dev.dependencies]
pytest = "^8.3.3"
+pytest-asyncio = "^0.24.0"
black = "^24.8.0"
sphinx = ">=8.0.2"
nbsphinx = "0.9.5"
diff --git a/requirements.txt b/requirements.txt
index e24bccf..645987c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -108,6 +108,7 @@ pygments==2.18.0 ; python_version >= "3.11" and python_version < "4.0"
pymupdf==1.24.11 ; python_version >= "3.11" and python_version < "4.0"
pyparsing==3.2.0 ; python_version >= "3.11" and python_version < "4.0"
pytest==8.3.3 ; python_version >= "3.11" and python_version < "4.0"
+pytest-asyncio==0.24.0 ; python_version >= "3.11" and python_version < "4.0"
python-dateutil==2.9.0.post0 ; python_version >= "3.11" and python_version < "4.0"
python-dotenv==1.0.1 ; python_version >= "3.11" and python_version < "4.0"
python-levenshtein==0.26.0 ; python_version >= "3.11" and python_version < "4.0"
diff --git a/tests/test_audio.py b/tests/test_audio.py
index 9e72d04..77fe504 100644
--- a/tests/test_audio.py
+++ b/tests/test_audio.py
@@ -1,50 +1,52 @@
-import unittest
import pytest
import os
-from podcastfy.text_to_speech import TextToSpeech
-
-
-class TestAudio(unittest.TestCase):
- def setUp(self):
- self.test_text = "Hello, how are you?I'm doing great, thanks for asking!"
- self.output_dir = "tests/data/audio"
- os.makedirs(self.output_dir, exist_ok=True)
-
- @pytest.mark.skip(reason="Testing edge only on Github Action as it's free")
- def test_text_to_speech_openai(self):
- tts = TextToSpeech(model="openai")
- output_file = os.path.join(self.output_dir, "test_openai.mp3")
- tts.convert_to_speech(self.test_text, output_file)
-
- self.assertTrue(os.path.exists(output_file))
- self.assertGreater(os.path.getsize(output_file), 0)
-
- # Clean up
- os.remove(output_file)
-
- @pytest.mark.skip(reason="Testing edge only on Github Action as it's free")
- def test_text_to_speech_elevenlabs(self):
- tts = TextToSpeech(model="elevenlabs")
- output_file = os.path.join(self.output_dir, "test_elevenlabs.mp3")
- tts.convert_to_speech(self.test_text, output_file)
-
- self.assertTrue(os.path.exists(output_file))
- self.assertGreater(os.path.getsize(output_file), 0)
-
- # Clean up
- os.remove(output_file)
-
- def test_text_to_speech_edge(self):
- tts = TextToSpeech(model="edge")
- output_file = os.path.join(self.output_dir, "test_edge.mp3")
- tts.convert_to_speech(self.test_text, output_file)
-
- self.assertTrue(os.path.exists(output_file))
- self.assertGreater(os.path.getsize(output_file), 0)
-
- # Clean up
- os.remove(output_file)
-
-
-if __name__ == "__main__":
- unittest.main()
+from pathlib import Path
+from podcastfy.core.character import Character
+from podcastfy.aiengines.tts.tts_backends import ElevenLabsTTS, OpenAITTS, EdgeTTS
+
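+# Shared setup: sample text, an output directory, and a throwaway Character, since the refactored TTS backends take a Character argument.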
+@pytest.fixture
+def test_setup():
+ test_text = "Hello, how are you?I'm doing great, thanks for asking!"
+ output_dir = Path("tests/data/audio")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ dummy_character = Character("test_character", "host", {}, "A test character")
+ return test_text, output_dir, dummy_character
+
+@pytest.mark.skip(reason="Testing Eleven Labs only on Github Action as it requires API key")
+def test_text_to_speech_elevenlabs(test_setup):
+ test_text, output_dir, dummy_character = test_setup
+ tts = ElevenLabsTTS()
+ output_file = output_dir / "test_elevenlabs.mp3"
+ tts.text_to_speech(test_text, dummy_character, output_file)
+
+ assert output_file.exists()
+ assert output_file.stat().st_size > 0
+
+ # Clean up
+ output_file.unlink()
+
+@pytest.mark.skip(reason="Testing OpenAI only on Github Action as it requires API key")
+def test_text_to_speech_openai(test_setup):
+ test_text, output_dir, dummy_character = test_setup
+ tts = OpenAITTS()
+ output_file = output_dir / "test_openai.mp3"
+ tts.text_to_speech(test_text, dummy_character, output_file)
+
+ assert output_file.exists()
+ assert output_file.stat().st_size > 0
+
+ # Clean up
+ output_file.unlink()
+
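+# Edge TTS is exercised through its async entry point, hence the pytest-asyncio marker (pytest-asyncio is added as a dependency in this change).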
+@pytest.mark.asyncio
+async def test_text_to_speech_edge(test_setup):
+ test_text, output_dir, dummy_character = test_setup
+ tts = EdgeTTS()
+ output_file = output_dir / "test_edge.mp3"
+ await tts.async_text_to_speech(test_text, dummy_character, output_file)
+
+ assert output_file.exists()
+ assert output_file.stat().st_size > 0
+
+ # Clean up
+ output_file.unlink()
\ No newline at end of file
diff --git a/tests/test_core_api.py b/tests/test_core_api.py
new file mode 100644
index 0000000..33cf457
--- /dev/null
+++ b/tests/test_core_api.py
@@ -0,0 +1,153 @@
+"""Tests for the core API of the podcastfy package. Not e2e tests as DummyTTSBackend is used to simulate the TTS backend and DummyLLMBackend is used to simulate the LLM backend."""
+import pytest
+from pathlib import Path
+from pydub import AudioSegment
+
+from podcastfy.core.content import Content
+from podcastfy.core.podcast import Podcast, PodcastState
+from podcastfy.aiengines.llm.base import LLMBackend
+from podcastfy.core.character import Character
+from podcastfy.core.tts_configs import TTSConfig
+from podcastfy.core.transcript import TranscriptSegment, Transcript
+from podcastfy.core.audio import AudioManager
+
+class DummyLLMBackend(LLMBackend):
+ def generate_transcript(self, content, characters):
+ return [
+ (characters[0], "Welcome to our podcast!"),
+ (characters[1], "Thanks for having me!")
+ ]
+
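+# Minimal TTS stand-in: writes one second of silence per call so the audio pipeline can run without an external API.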
+class DummyTTSBackend:
+ def __init__(self, name: str):
+ self.name = name
+
+ def text_to_speech(self, text: str, character: Character, output_path: Path) -> Path:
+ audio = AudioSegment.silent(duration=1000)
+ audio.export(str(output_path), format="mp3")
+ return output_path
+
+@pytest.fixture
+def audio_manager(tmp_path):
+ tts_backends = {"openai": DummyTTSBackend("openai"), "elevenlabs": DummyTTSBackend("elevenlabs")}
+ return AudioManager(tts_backends, audio_format="mp3", audio_temp_dir=tmp_path, n_jobs=1)
+
+@pytest.fixture
+def characters():
+ host = Character(
+ name="Person1",
+ role="Podcast host",
+ tts_configs={
+ "openai": TTSConfig(voice="en-US-Neural2-F", backend="openai", extra_args={"speaking_rate": 1.0}),
+ "elevenlabs": TTSConfig(voice="Rachel", backend="elevenlabs", extra_args={"stability": 0.5})
+ },
+ default_description_for_llm="{name} is an enthusiastic podcast host. Speaks clearly and engagingly."
+ )
+
+ guest = Character(
+ name="Person2",
+ role="Expert guest",
+ tts_configs={
+ "openai": TTSConfig(voice="en-US-Neural2-D", backend="openai", extra_args={"pitch": -2.0}),
+ "elevenlabs": TTSConfig(voice="Antoni", backend="elevenlabs", extra_args={"stability": 0.8})
+ },
+ default_description_for_llm="{name} is an expert guest. Shares knowledge in a friendly manner."
+ )
+
+ return [host, guest]
+
+@pytest.fixture
+def podcast(audio_manager, characters):
+ return Podcast(
+ content=[Content(value="This is a sample content for our podcast.", type="text")],
+ llm_backend=DummyLLMBackend(),
+ audio_manager=audio_manager,
+ characters=characters,
+ )
+
+def test_podcast_initialization(podcast):
+ assert podcast.state == PodcastState.INITIALIZED
+ assert podcast.transcript is None
+ assert podcast.audio is None
+
+def test_build_transcript(podcast):
+ podcast.build_transcript()
+ assert podcast.state == PodcastState.TRANSCRIPT_BUILT
+ assert isinstance(podcast.transcript, Transcript)
+ assert len(podcast.transcript.segments) == 2
+
+def test_build_audio_segments(podcast):
+ podcast.build_transcript()
+ podcast.build_audio_segments()
+ assert podcast.state == PodcastState.AUDIO_SEGMENTS_BUILT
+ assert len(podcast.audio_segments) == 2
+
+def test_stitch_audio_segments(podcast):
+ podcast.build_transcript()
+ podcast.build_audio_segments()
+ podcast.stitch_audio_segments()
+ assert podcast.state == PodcastState.STITCHED
+ assert isinstance(podcast.audio, AudioSegment)
+
+def test_finalize(podcast):
+ podcast.finalize()
+ assert podcast.state == PodcastState.STITCHED
+ assert isinstance(podcast.transcript, Transcript)
+ assert len(podcast.audio_segments) > 0
+ assert isinstance(podcast.audio, AudioSegment)
+
+def test_save(podcast, tmp_path):
+ podcast.finalize()
+ output_file = tmp_path / "test_podcast.mp3"
+ podcast.save(str(output_file))
+ assert output_file.exists()
+
+def test_export_transcript(podcast, tmp_path):
+ podcast.finalize()
+ output_file = tmp_path / "test_transcript.txt"
+ podcast.export_transcript(str(output_file), format_="plaintext")
+ assert output_file.exists()
+
+def test_rework(podcast):
+ podcast.finalize()
+
+ with podcast.rework(PodcastState.TRANSCRIPT_BUILT):
+ assert podcast.state == PodcastState.TRANSCRIPT_BUILT
+ podcast.transcript.segments.append(
+ TranscriptSegment("This is a new segment", podcast.characters["Person1"]))
+
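+ # Exiting the rework block is expected to restore the previous STITCHED state, now including the extra segment.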
+ assert podcast.state == PodcastState.STITCHED
+ assert len(podcast.transcript.segments) == 3
+
+def test_from_transcript(audio_manager, characters):
+ pre_existing_transcript = [
+ ("Person1", "Welcome to our podcast created from a pre-existing transcript!"),
+ ("Person2", "Thank you for having me. I'm excited to be here.")
+ ]
+
+ podcast = Podcast.from_transcript(
+ transcript=Transcript([
+ TranscriptSegment(text, characters[0] if speaker == "Person1" else characters[1])
+ for speaker, text in pre_existing_transcript
+ ]),
+ audio_manager=audio_manager,
+ characters=characters
+ )
+
+ assert podcast.state == PodcastState.TRANSCRIPT_BUILT
+ assert len(podcast.transcript.segments) == 2
+
+ podcast.finalize()
+ assert podcast.state == PodcastState.STITCHED
+
+def test_load_transcript(audio_manager, characters, tmp_path):
+ # Create a dummy transcript file
+ transcript_file = tmp_path / "test_transcript.json"
+ Transcript([
+ TranscriptSegment("Welcome to our podcast!", characters[0]),
+ TranscriptSegment("Thank you for having me!", characters[1])
+ ]).dump(str(transcript_file))
+
+ podcast = Podcast.load_transcript(str(transcript_file), audio_manager, characters)
+ assert podcast.state == PodcastState.TRANSCRIPT_BUILT
+ assert len(podcast.transcript.segments) == 2
\ No newline at end of file
diff --git a/tests/test_transcript.py b/tests/test_transcript.py
new file mode 100644
index 0000000..c60ac12
--- /dev/null
+++ b/tests/test_transcript.py
@@ -0,0 +1,87 @@
+import pytest
+from podcastfy.core.transcript import TranscriptSegment, Transcript, Character
+from unittest.mock import patch, mock_open
+
+@pytest.fixture
+def characters():
+ character1 = Character("Person1", "John Doe", {})
+ character2 = Character("Person2", "Jane Smith", {})
+ return {"Person1": character1, "Person2": character2}
+
+def test_clean_markups():
+ input_text = "Hello World. This is a test"
+ expected_output = "Hello World. This is a test"
+ assert TranscriptSegment._clean_markups(input_text) == expected_output
+
+def test_clean_markups_with_scratchpad():
+ input_text = "Hello (scratchpad)World"
+ expected_output = "Hello World"
+ assert TranscriptSegment._clean_markups(input_text) == expected_output
+
+def test_transcript_segment_init(characters):
+ segment = TranscriptSegment("Hello World Test", characters["Person1"])
+ assert segment.text == "Hello World Test"
+ assert segment.speaker == characters["Person1"]
+
+def test_transcript_segment_to_dict(characters):
+ segment = TranscriptSegment("Hello World", characters["Person1"], {"voice_id": "test_voice"})
+ expected_dict = {
+ "text": "Hello World",
+ "speaker": "Person1",
+ "tts_args": {"voice_id": "test_voice"}
+ }
+ assert segment.to_dict() == expected_dict
+
+def test_transcript_segment_from_dict(characters):
+ data = {
+ "text": "Hello World",
+ "speaker": "Person1",
+ "tts_args": {"voice_id": "test_voice"}
+ }
+ segment = TranscriptSegment.from_dict(data, characters)
+ assert segment.text == "Hello World"
+ assert segment.speaker == characters["Person1"]
+ assert segment.tts_args == {"voice_id": "test_voice"}
+
+def test_transcript_init(characters):
+ segments = [
+ TranscriptSegment("Hello", characters["Person1"]),
+ TranscriptSegment("Hi there", characters["Person2"])
+ ]
+ transcript = Transcript(segments, {"title": "Test Transcript"})
+ assert len(transcript.segments) == 2
+ assert transcript.metadata == {"title": "Test Transcript"}
+
+def test_transcript_to_dict(characters):
+ segments = [
+ TranscriptSegment("Hello", characters["Person1"]),
+ TranscriptSegment("Hi there", characters["Person2"])
+ ]
+ transcript = Transcript(segments, {"title": "Test Transcript"})
+ expected_dict = {
+ "segments": [
+ {"text": "Hello", "speaker": "Person1", "tts_args": {}},
+ {"text": "Hi there", "speaker": "Person2", "tts_args": {}}
+ ],
+ "metadata": {"title": "Test Transcript"}
+ }
+ assert transcript.to_dict() == expected_dict
+
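+# Two load formats are exercised here: a JSON transcript with a single segment, and plain text where each
+# line presumably becomes its own segment with speakers assigned in order.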
+@pytest.mark.parametrize("file_content,expected_segments", [
+ ('{"segments": [{"text": "Hello", "speaker": "Person1", "tts_args": {}}], "metadata": {}}', 1),
+ ('Hello\nHi there', 2)
+])
+def test_transcript_load(file_content, expected_segments, characters):
+ with patch('builtins.open', new_callable=mock_open, read_data=file_content):
+ transcript = Transcript.load("fake_path.json", characters)
+ assert len(transcript.segments) == expected_segments
+ assert transcript.segments[0].speaker == characters["Person1"]
+
+def test_transcript_str(characters):
+ segments = [
+ TranscriptSegment("Hello", characters["Person1"]),
+ TranscriptSegment("Hi there", characters["Person2"])
+ ]
+ transcript = Transcript(segments, {"title": "Test Transcript"})
+ expected_str = "Hello\nHi there"
+ assert str(transcript) == expected_str
\ No newline at end of file