diff --git a/.env.example b/.env.example index 5eb47a8..6d4b032 100644 --- a/.env.example +++ b/.env.example @@ -13,7 +13,7 @@ LLM_BASE_URL=http://localhost:11434/v1 LLM_MODEL_NAME=qwen2.5:32b # Lighter model for development (less VRAM): # LLM_MODEL_NAME=qwen2.5:14b - +DEBUG=True # ===== Neo4j Configuration ===== # For Docker: use service name "neo4j" instead of "localhost" NEO4J_URI=bolt://localhost:7687 @@ -39,3 +39,14 @@ OPENAI_API_BASE_URL=http://localhost:11434/v1 # NEO4J_URI=bolt://neo4j:7687 # EMBEDDING_BASE_URL=http://ollama:11434 # OPENAI_API_BASE_URL=http://ollama:11434/v1 + + +# ===== Parallelization & Performance Tuning ===== +# Number of text chunks parsed into Knowledge Graph concurrently +GRAPH_BUILD_BATCH_SIZE=10 +# Number of Profiles/Agents to generate concurrently +PROFILE_PARALLEL_COUNT=10 +# Number of database search worker threads per profile generation +PROFILE_SEARCH_WORKERS=2 +# Number of Report sections generated concurrently by the report agent +REPORT_PARALLEL_SECTIONS=5 diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 59f915e..7d61fd7 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -435,7 +435,7 @@ def add_progress_callback(msg, progress_ratio): episode_uuids = builder.add_text_batches( graph_id, chunks, - batch_size=3, + batch_size=Config.GRAPH_BUILD_BATCH_SIZE, progress_callback=add_progress_callback ) diff --git a/backend/app/config.py b/backend/app/config.py index de706ca..087e412 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -69,6 +69,12 @@ class Config: REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2')) REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5')) + # Parallelization & Performance tuning + GRAPH_BUILD_BATCH_SIZE = int(os.environ.get('GRAPH_BUILD_BATCH_SIZE', '10')) + PROFILE_PARALLEL_COUNT = int(os.environ.get('PROFILE_PARALLEL_COUNT', '10')) + PROFILE_SEARCH_WORKERS = 
int(os.environ.get('PROFILE_SEARCH_WORKERS', '2')) + REPORT_PARALLEL_SECTIONS = int(os.environ.get('REPORT_PARALLEL_SECTIONS', '5')) + @classmethod def validate(cls): """Validate required configuration""" diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 2518e3b..e2a218b 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -8,6 +8,7 @@ import threading from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass +import concurrent.futures from ..config import Config from ..models.task import TaskManager, TaskStatus @@ -51,7 +52,7 @@ def build_graph_async( graph_name: str = "MiroFish Graph", chunk_size: int = 500, chunk_overlap: int = 50, - batch_size: int = 3 + batch_size: int = Config.GRAPH_BUILD_BATCH_SIZE ) -> str: """ Build graph asynchronously @@ -186,7 +187,7 @@ def add_text_batches( self, graph_id: str, chunks: List[str], - batch_size: int = 3, + batch_size: int = Config.GRAPH_BUILD_BATCH_SIZE, progress_callback: Optional[Callable] = None ) -> List[str]: """Add text in batches to graph, return uuid list of all episodes""" @@ -207,7 +208,7 @@ def add_text_batches( progress ) - for j, chunk in enumerate(batch_chunks): + def process_single_chunk(j: int, chunk: str) -> str: chunk_idx = i + j + 1 chunk_preview = chunk[:80].replace('\n', ' ') logger.info( @@ -217,20 +218,40 @@ def add_text_batches( t0 = time.time() try: episode_id = self.storage.add_text(graph_id, chunk) - episode_uuids.append(episode_id) elapsed = time.time() - t0 logger.info( f"[graph_build] Chunk {chunk_idx}/{total_chunks} done in {elapsed:.1f}s" ) + return episode_id except Exception as e: elapsed = time.time() - t0 logger.error( f"[graph_build] Chunk {chunk_idx}/{total_chunks} FAILED " f"after {elapsed:.1f}s: {e}" ) - if progress_callback: - progress_callback(f"Batch {batch_num} processing failed: {str(e)}", 0) - raise + raise e + + # Run chunks in the current batch 
concurrently + with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor: + future_to_j = { + executor.submit(process_single_chunk, j, chunk): j + for j, chunk in enumerate(batch_chunks) + } + + batch_results = [None] * len(batch_chunks) + for future in concurrent.futures.as_completed(future_to_j): + j = future_to_j[future] + try: + batch_results[j] = future.result() + except Exception as e: + if progress_callback: + progress_callback(f"Batch {batch_num} processing failed: {str(e)}", 0) + raise + + # Append successful episode_ids in order + for eid in batch_results: + if eid: + episode_uuids.append(eid) logger.info(f"[graph_build] All {total_chunks} chunks processed successfully") return episode_uuids diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 2555997..253704f 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -14,6 +14,8 @@ from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from datetime import datetime +import concurrent.futures +import threading from openai import OpenAI @@ -305,29 +307,35 @@ def _search_graph_for_entity(self, entity: EntityNode) -> Dict[str, Any]: comprehensive_query = f"All information, activities, events, relationships and background about {entity_name}" try: - # Search edges (facts) - edge_results = self.storage.search( - graph_id=self.graph_id, - query=comprehensive_query, - limit=30, - scope="edges" - ) + # Execute searches concurrently since both hit the database/embedding service + with concurrent.futures.ThreadPoolExecutor(max_workers=Config.PROFILE_SEARCH_WORKERS) as io_executor: + future_edges = io_executor.submit( + self.storage.search, + graph_id=self.graph_id, + query=comprehensive_query, + limit=30, + scope="edges" + ) + future_nodes = io_executor.submit( + self.storage.search, + graph_id=self.graph_id, + query=comprehensive_query, + 
limit=20, + scope="nodes" + ) - all_facts = set() - if isinstance(edge_results, dict) and 'edges' in edge_results: - for edge in edge_results['edges']: - fact = edge.get('fact', '') - if fact: - all_facts.add(fact) - results["facts"] = list(all_facts) - - # Search nodes (entity summaries) - node_results = self.storage.search( - graph_id=self.graph_id, - query=comprehensive_query, - limit=20, - scope="nodes" - ) + # Collect edge results + edge_results = future_edges.result() + all_facts = set() + if isinstance(edge_results, dict) and 'edges' in edge_results: + for edge in edge_results['edges']: + fact = edge.get('fact', '') + if fact: + all_facts.add(fact) + results["facts"] = list(all_facts) + + # Collect node results + node_results = future_nodes.result() all_summaries = set() if isinstance(node_results, dict) and 'nodes' in node_results: @@ -478,8 +486,8 @@ def _generate_profile_with_llm( {"role": "user", "content": prompt} ], response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # Lower temperature with each retry - # Don't set max_tokens, let LLM generate freely + temperature=0.7 - (attempt * 0.1), # Lower temperature with each retry + max_tokens=4000 # Enforce max_tokens to prevent vLLM infinite generation loops ) content = response.choices[0].message.content @@ -798,7 +806,7 @@ def generate_profiles_from_entities( use_llm: bool = True, progress_callback: Optional[callable] = None, graph_id: Optional[str] = None, - parallel_count: int = 5, + parallel_count: int = Config.PROFILE_PARALLEL_COUNT, realtime_output_path: Optional[str] = None, output_platform: str = "reddit" ) -> List[OasisAgentProfile]: @@ -831,26 +839,25 @@ def generate_profiles_from_entities( # Helper function for real-time file writing def save_profiles_realtime(): - """Real-time save generated profiles to file""" + """Real-time save generated profiles to file without blocking the main event loop""" if not realtime_output_path: return + # Capture snapshot quickly with lock: 
- # Filter generated profiles existing_profiles = [p for p in profiles if p is not None] if not existing_profiles: return + def _background_write(snapshot): try: if output_platform == "reddit": - # Reddit JSON format - profiles_data = [p.to_reddit_format() for p in existing_profiles] + profiles_data = [p.to_reddit_format() for p in snapshot] with open(realtime_output_path, 'w', encoding='utf-8') as f: json.dump(profiles_data, f, ensure_ascii=False, indent=2) else: - # Twitter CSV format import csv - profiles_data = [p.to_twitter_format() for p in existing_profiles] + profiles_data = [p.to_twitter_format() for p in snapshot] if profiles_data: fieldnames = list(profiles_data[0].keys()) with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f: @@ -859,6 +866,9 @@ def save_profiles_realtime(): writer.writerows(profiles_data) except Exception as e: logger.warning(f"Real-time profile save failed: {e}") + + # Offload heavy disk I/O out of the ThreadPool completion loop + threading.Thread(target=_background_write, args=(existing_profiles,), daemon=True).start() def generate_single_profile(idx: int, entity: EntityNode) -> tuple: """Worker function to generate single profile""" diff --git a/backend/app/services/report_agent.py b/backend/app/services/report_agent.py index 0f8a4d1..8f3ae66 100644 --- a/backend/app/services/report_agent.py +++ b/backend/app/services/report_agent.py @@ -17,6 +17,7 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum +import concurrent.futures from ..config import Config from ..utils.llm_client import LLMClient @@ -1287,8 +1288,8 @@ def _generate_section_react( # ReACT loop tool_calls_count = 0 - max_iterations = 5 # Maximum iterations - min_tool_calls = 3 # Minimum tool calls + max_iterations = 3 # Maximum iterations (reduced from 5 for speed) + min_tool_calls = 0 # Minimum tool calls (reduced from 3 so the agent isn't forced to use tools unnecessarily) conflict_retries = 0 # Consecutive 
conflicts where tool calls and Final Answer appear simultaneously used_tools = set() # Record tool names already called all_tools = {"insight_forge", "panorama_search", "quick_search", "interview_agents"} @@ -1632,72 +1633,74 @@ def generate_report( logger.info(f"outlinesavedtofile: {report_id}/outline.json") - # Phase 2: Sequentially generate sectionsgeneration (per sectionsave) + # Phase 2: Concurrent section generation report.status = ReportStatus.GENERATING total_sections = len(outline.sections) - generated_sections = [] # savecontentfor context + generated_sections = [] # To mimic previous sections structure, although parallel will build out of order or individually. - for i, section in enumerate(outline.sections): - section_num = i + 1 - base_progress = 20 + int((i / total_sections) * 70) - - # Update progress - ReportManager.update_progress( - report_id, "generating", base_progress, - f"generatinggenerateSection: {section.title} ({section_num}/{total_sections})", - current_section=section.title, - completed_sections=completed_section_titles - ) - - if progress_callback: - progress_callback( - "generating", - base_progress, - f"generatinggenerateSection: {section.title} ({section_num}/{total_sections})" + # Using ThreadPoolExecutor for concurrent section generation + # Let's use max_workers reasonable parallel speed without overloading the LLM API limit. 
+ with concurrent.futures.ThreadPoolExecutor(max_workers=Config.REPORT_PARALLEL_SECTIONS) as executor: + future_to_index = {} + for i, section in enumerate(outline.sections): + section_num = i + 1 + base_progress = 20 + int((i / total_sections) * 70) + + # Update initial progress + ReportManager.update_progress( + report_id, "generating", base_progress, + f"Queued Section: {section.title} ({section_num}/{total_sections})", + current_section=section.title, + completed_sections=completed_section_titles ) - - # Generate main sectioncontent - section_content = self._generate_section_react( - section=section, - outline=outline, - previous_sections=generated_sections, - progress_callback=lambda stage, prog, msg: - progress_callback( - stage, - base_progress + int(prog * 0.7 / total_sections), - msg - ) if progress_callback else None, - section_index=section_num - ) - - section.content = section_content - generated_sections.append(f"## {section.title}\n\n{section_content}") - - # saveSection - ReportManager.save_section(report_id, section_num, section) - completed_section_titles.append(section.title) - - # Log sectioncompletion log - full_section_content = f"## {section.title}\n\n{section_content}" - - if self.report_logger: - self.report_logger.log_section_full_complete( - section_title=section.title, - section_index=section_num, - full_content=full_section_content.strip() + + # Submit task + future = executor.submit( + self._generate_section_react, + section=section, + outline=outline, + previous_sections=[], # Empty because we are generating concurrently + progress_callback=None, # Avoid callback conflict from multiple threads, or pass custom + section_index=section_num ) - - logger.info(f"Sectionsaved: {report_id}/section_{section_num:02d}.md") + future_to_index[future] = (i, section, section_num, base_progress) - # Update progress - ReportManager.update_progress( - report_id, "generating", - base_progress + int(70 / total_sections), - f"Section {section.title} 
completed", - current_section=None, - completed_sections=completed_section_titles - ) + # As tasks finish + for future in concurrent.futures.as_completed(future_to_index): + i, section, section_num, base_progress = future_to_index[future] + try: + section_content = future.result() + except Exception as e: + logger.error(f"Failed to generate section {section.title}: {e}") + section_content = f"Error generating section: {str(e)}" + + section.content = section_content + generated_sections.append(f"## {section.title}\n\n{section_content}") + + # save section + ReportManager.save_section(report_id, section_num, section) + completed_section_titles.append(section.title) + + # Log completion + full_section_content = f"## {section.title}\n\n{section_content}" + if self.report_logger: + self.report_logger.log_section_full_complete( + section_title=section.title, + section_index=section_num, + full_content=full_section_content.strip() + ) + + logger.info(f"Section saved: {report_id}/section_{section_num:02d}.md") + + # Update progress + ReportManager.update_progress( + report_id, "generating", + 20 + int((len(completed_section_titles) / total_sections) * 70), + f"Section {section.title} completed", + current_section=None, + completed_sections=completed_section_titles + ) # phase3: assembleComplete report if progress_callback: diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py index 44d4571..243b210 100644 --- a/backend/app/services/simulation_manager.py +++ b/backend/app/services/simulation_manager.py @@ -234,7 +234,7 @@ def prepare_simulation( defined_entity_types: Optional[List[str]] = None, use_llm_for_profiles: bool = True, progress_callback: Optional[callable] = None, - parallel_profile_count: int = 3, + parallel_profile_count: int = Config.PROFILE_PARALLEL_COUNT, storage: 'GraphStorage' = None, ) -> SimulationState: """ diff --git a/backend/app/storage/embedding_service.py b/backend/app/storage/embedding_service.py index 
3d27f65..d0723f3 100644 --- a/backend/app/storage/embedding_service.py +++ b/backend/app/storage/embedding_service.py @@ -31,7 +31,7 @@ def __init__( self.base_url = (base_url or Config.EMBEDDING_BASE_URL).rstrip('/') self.max_retries = max_retries self.timeout = timeout - self._embed_url = f"{self.base_url}/api/embed" + self._embed_url = f"{self.base_url}"  # FIXME(review): drops the /api/embed path — EMBEDDING_BASE_URL must now be the full embeddings endpoint; confirm intended # Simple in-memory cache (text -> embedding vector) # Using dict instead of lru_cache because lists aren't hashable @@ -141,7 +141,7 @@ def _request_embeddings(self, texts: List[str]) -> List[List[float]]: response.raise_for_status() data = response.json() - embeddings = data.get("embeddings", []) + embeddings = data.get("embeddings", []) or [item.get("embedding", []) for item in data.get("data", [])] if len(embeddings) != len(texts): raise EmbeddingError( f"Expected {len(texts)} embeddings, got {len(embeddings)}" ) diff --git a/backend/app/storage/neo4j_storage.py b/backend/app/storage/neo4j_storage.py index 23d226a..34c1866 100644 --- a/backend/app/storage/neo4j_storage.py +++ b/backend/app/storage/neo4j_storage.py @@ -9,6 +9,8 @@ import time import uuid import logging +import concurrent.futures +import threading from datetime import datetime, timezone from typing import Dict, Any, List, Optional, Callable @@ -353,26 +355,45 @@ def add_text_batch( self, graph_id: str, chunks: List[str], - batch_size: int = 3, + batch_size: int = Config.GRAPH_BUILD_BATCH_SIZE, progress_callback: Optional[Callable] = None, ) -> List[str]: - """Batch-add text chunks with progress reporting.""" - episode_ids = [] + """Batch-add text chunks concurrently using a thread pool.""" total = len(chunks) + episode_ids = [None] * total + lock = threading.Lock() + completed = [0] - for i, chunk in enumerate(chunks): - if not chunk or not chunk.strip(): - continue - episode_id = self.add_text(graph_id, chunk) - episode_ids.append(episode_id) + def process_chunk(idx: int, text: str) -> Optional[str]: + if not text or not 
text.strip(): + return None + try: + return self.add_text(graph_id, text) + except Exception as e: + logger.error(f"Failed to process chunk {idx}: {e}") + return None + + with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, batch_size)) as executor: + future_to_idx = { + executor.submit(process_chunk, i, chunk): i + for i, chunk in enumerate(chunks) + } + + for future in concurrent.futures.as_completed(future_to_idx): + i = future_to_idx[future] + ep_id = future.result() + episode_ids[i] = ep_id + + with lock: + completed[0] += 1 + c = completed[0] - if progress_callback: - progress = (i + 1) / total - progress_callback(progress) + if progress_callback: + progress_callback(c / total) - logger.info(f"Processed chunk {i + 1}/{total}") + logger.info(f"Processed chunk {c}/{total}") - return episode_ids + return [eid for eid in episode_ids if eid] def wait_for_processing( self, diff --git a/backend/run.py b/backend/run.py index 98ffe60..3cf3514 100644 --- a/backend/run.py +++ b/backend/run.py @@ -40,7 +40,7 @@ def main(): host = os.environ.get('FLASK_HOST', '0.0.0.0') port = int(os.environ.get('FLASK_PORT', 5001)) debug = Config.DEBUG - + # NOTE(review): removed hardcoded debug override; honor Config.DEBUG (Werkzeug debugger must never be forced on) # Start service app.run(host=host, port=port, debug=debug, threaded=True)