From 925bc31fdbdc6cdbed1c2f0a7e720e227d2545a8 Mon Sep 17 00:00:00 2001
From: ben
Date: Wed, 27 Nov 2024 17:54:22 +0800
Subject: [PATCH] feat: add optional json_output param to support JSON
 response in podcast generation

---
 podcastfy/client.py | 52 +++++++++++++++++++++++++++++++++------------
 1 file changed, 38 insertions(+), 14 deletions(-)

diff --git a/podcastfy/client.py b/podcastfy/client.py
index e13f54a..4d8a08e 100644
--- a/podcastfy/client.py
+++ b/podcastfy/client.py
@@ -52,7 +52,8 @@ def process_content(
     model_name: Optional[str] = None,
     api_key_label: Optional[str] = None,
     topic: Optional[str] = None,
-    longform: bool = False
+    longform: bool = False,
+    json_output: bool = False,
 ):
     """
     Process URLs, a transcript file, image paths, or raw text to generate a podcast or transcript.
@@ -71,8 +72,15 @@
         tts_config = conv_config.get("text_to_speech", {})
         output_directories = tts_config.get("output_directories", {})
 
+        # Collect information for the JSON response
+        response = {
+            "transcript_file": None,
+            "audio_file": None,
+        }
+
         if transcript_file:
             logger.info(f"Using transcript file: {transcript_file}")
+            response["transcript_file"] = transcript_file
             with open(transcript_file, "r") as file:
                 qa_content = file.read()
         else:
@@ -85,11 +93,11 @@
                 is_local=is_local,
                 model_name=model_name,
                 api_key_label=api_key_label,
-                conversation_config=conv_config.to_dict()
+                conversation_config=conv_config.to_dict(),
             )
 
             combined_content = ""
-
+
             if urls:
                 logger.info(f"Processing {len(urls)} links")
                 contents = [content_extractor.extract_content(link) for link in urls]
@@ -97,7 +105,9 @@
 
             if text:
                 if longform and len(text.strip()) < 100:
-                    logger.info("Text too short for direct long-form generation. Extracting context...")
+                    logger.info(
+                        "Text too short for direct long-form generation. Extracting context..."
+                    )
                     expanded_content = content_extractor.generate_topic_content(text)
                     combined_content += f"\n\n{expanded_content}"
                 else:
@@ -113,17 +123,20 @@
                 output_directories.get("transcripts", "data/transcripts"),
                 random_filename,
             )
+            response["transcript_file"] = transcript_filepath
             qa_content = content_generator.generate_qa_content(
                 combined_content,
                 image_file_paths=image_paths or [],
                 output_filepath=transcript_filepath,
-                longform=longform
+                longform=longform,
             )
 
         if generate_audio:
             api_key = None
             if tts_model != "edge":
-                api_key = getattr(config, f"{tts_model.upper().replace('MULTI', '')}_API_KEY")
+                api_key = getattr(
+                    config, f"{tts_model.upper().replace('MULTI', '')}_API_KEY"
+                )
 
             text_to_speech = TextToSpeech(
                 model=tts_model,
@@ -137,9 +150,14 @@
             )
             text_to_speech.convert_to_speech(qa_content, audio_file)
             logger.info(f"Podcast generated successfully using {tts_model} TTS model")
+            response["audio_file"] = audio_file
+            if json_output:
+                return response
             return audio_file
         else:
             logger.info(f"Transcript generated successfully: {transcript_filepath}")
+            if json_output:
+                return response
             return transcript_filepath
 
     except Exception as e:
@@ -193,10 +211,10 @@
         None, "--topic", "-tp", help="Topic to generate podcast about"
     ),
     longform: bool = typer.Option(
-        False,
-        "--longform",
-        "-lf",
-        help="Generate long-form content (only available for text input without images)"
+        False,
+        "--longform",
+        "-lf",
+        help="Generate long-form content (only available for text input without images)",
     ),
 ):
     """
@@ -231,7 +249,7 @@
                 model_name=llm_model_name,
                 api_key_label=api_key_label,
                 topic=topic,
-                longform=longform
+                longform=longform,
             )
         else:
             urls_list = urls or []
@@ -255,7 +273,7 @@
                 model_name=llm_model_name,
                 api_key_label=api_key_label,
                 topic=topic,
-                longform=longform
+                longform=longform,
             )
 
         if transcript_only:
@@ -289,6 +307,7 @@
     api_key_label: Optional[str] = None,
     topic: Optional[str] = None,
     longform: bool = False,
+    json_output: bool = False,
 ) -> Optional[str]:
     """
     Generate a podcast or transcript from a list of URLs, a file containing URLs, a transcript file, or image files.
@@ -307,9 +326,12 @@
         llm_model_name (Optional[str]): LLM model name for content generation.
         api_key_label (Optional[str]): Environment variable name for LLM API key.
         topic (Optional[str]): Topic to generate podcast about.
+        json_output (bool): If True, return a dictionary with transcript and audio file paths. Defaults to False.
 
     Returns:
         Optional[str]: Path to the final podcast audio file, or None if only generating a transcript.
+        If json_output is True, returns a dictionary with transcript_file and audio_file paths.
+
     """
     try:
         print("Generating podcast...")
@@ -355,7 +377,8 @@
                 model_name=llm_model_name,
                 api_key_label=api_key_label,
                 topic=topic,
-                longform=longform
+                longform=longform,
+                json_output=json_output,
             )
         else:
             urls_list = urls or []
@@ -381,7 +404,8 @@
                 model_name=llm_model_name,
                 api_key_label=api_key_label,
                 topic=topic,
-                longform=longform
+                longform=longform,
+                json_output=json_output,
             )
 
     except Exception as e:
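
For reference, a minimal usage sketch of the new flag, assuming the patch above is applied; the URL is a placeholder, and the returned keys mirror the response dict built in process_content:

    from podcastfy.client import generate_podcast

    # Default behaviour is unchanged: the path to the generated audio file is returned.
    audio_path = generate_podcast(urls=["https://example.com/article"])

    # With json_output=True, a dict with both output paths is returned instead.
    result = generate_podcast(urls=["https://example.com/article"], json_output=True)
    print(result["transcript_file"])  # path to the saved transcript
    print(result["audio_file"])       # path to the rendered audio file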