Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ LLM_BASE_URL=http://localhost:11434/v1
LLM_MODEL_NAME=qwen2.5:32b
# Lighter model for development (less VRAM):
# LLM_MODEL_NAME=qwen2.5:14b

DEBUG=True
# ===== Neo4j Configuration =====
# For Docker: use service name "neo4j" instead of "localhost"
NEO4J_URI=bolt://localhost:7687
Expand All @@ -39,3 +39,14 @@ OPENAI_API_BASE_URL=http://localhost:11434/v1
# NEO4J_URI=bolt://neo4j:7687
# EMBEDDING_BASE_URL=http://ollama:11434
# OPENAI_API_BASE_URL=http://ollama:11434/v1


# ===== Parallelization & Performance Tuning =====
# Number of text chunks per graph-build batch; the chunks in a batch are
# processed concurrently, one worker thread per chunk
GRAPH_BUILD_BATCH_SIZE=10
# Number of Profiles/Agents to generate concurrently
PROFILE_PARALLEL_COUNT=10
# Number of database search worker threads per profile generation.
# Each profile issues two searches (edges and nodes), so values above 2 give no extra benefit.
PROFILE_SEARCH_WORKERS=2
# Number of Report sections generated concurrently by the report agent
REPORT_PARALLEL_SECTIONS=5
2 changes: 1 addition & 1 deletion backend/app/api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def add_progress_callback(msg, progress_ratio):
episode_uuids = builder.add_text_batches(
graph_id,
chunks,
batch_size=3,
batch_size=Config.GRAPH_BUILD_BATCH_SIZE,
progress_callback=add_progress_callback
)

Expand Down
6 changes: 6 additions & 0 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class Config:
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))

# Parallelization & Performance tuning
GRAPH_BUILD_BATCH_SIZE = int(os.environ.get('GRAPH_BUILD_BATCH_SIZE', '10'))
PROFILE_PARALLEL_COUNT = int(os.environ.get('PROFILE_PARALLEL_COUNT', '10'))
PROFILE_SEARCH_WORKERS = int(os.environ.get('PROFILE_SEARCH_WORKERS', '2'))
REPORT_PARALLEL_SECTIONS = int(os.environ.get('REPORT_PARALLEL_SECTIONS', '5'))

@classmethod
def validate(cls):
"""Validate required configuration"""
Expand Down
35 changes: 28 additions & 7 deletions backend/app/services/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import threading
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
import concurrent.futures

from ..config import Config
from ..models.task import TaskManager, TaskStatus
Expand Down Expand Up @@ -51,7 +52,7 @@ def build_graph_async(
graph_name: str = "MiroFish Graph",
chunk_size: int = 500,
chunk_overlap: int = 50,
batch_size: int = 3
batch_size: int = Config.GRAPH_BUILD_BATCH_SIZE
) -> str:
"""
Build graph asynchronously
Expand Down Expand Up @@ -186,7 +187,7 @@ def add_text_batches(
self,
graph_id: str,
chunks: List[str],
batch_size: int = 3,
batch_size: int = Config.GRAPH_BUILD_BATCH_SIZE,
progress_callback: Optional[Callable] = None
) -> List[str]:
"""Add text in batches to graph, return uuid list of all episodes"""
Expand All @@ -207,7 +208,7 @@ def add_text_batches(
progress
)

for j, chunk in enumerate(batch_chunks):
def process_single_chunk(j: int, chunk: str) -> str:
chunk_idx = i + j + 1
chunk_preview = chunk[:80].replace('\n', ' ')
logger.info(
Expand All @@ -217,20 +218,40 @@ def add_text_batches(
t0 = time.time()
try:
episode_id = self.storage.add_text(graph_id, chunk)
episode_uuids.append(episode_id)
elapsed = time.time() - t0
logger.info(
f"[graph_build] Chunk {chunk_idx}/{total_chunks} done in {elapsed:.1f}s"
)
return episode_id
except Exception as e:
elapsed = time.time() - t0
logger.error(
f"[graph_build] Chunk {chunk_idx}/{total_chunks} FAILED "
f"after {elapsed:.1f}s: {e}"
)
if progress_callback:
progress_callback(f"Batch {batch_num} processing failed: {str(e)}", 0)
raise
raise e

# Run chunks in the current batch concurrently
with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
future_to_j = {
executor.submit(process_single_chunk, j, chunk): j
for j, chunk in enumerate(batch_chunks)
}

batch_results = [None] * len(batch_chunks)
for future in concurrent.futures.as_completed(future_to_j):
j = future_to_j[future]
try:
batch_results[j] = future.result()
except Exception as e:
if progress_callback:
progress_callback(f"Batch {batch_num} processing failed: {str(e)}", 0)
raise

# Append successful episode_ids in order
for eid in batch_results:
if eid:
episode_uuids.append(eid)

logger.info(f"[graph_build] All {total_chunks} chunks processed successfully")
return episode_uuids
Expand Down
72 changes: 41 additions & 31 deletions backend/app/services/oasis_profile_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
import concurrent.futures
import threading

from openai import OpenAI

Expand Down Expand Up @@ -305,29 +307,35 @@ def _search_graph_for_entity(self, entity: EntityNode) -> Dict[str, Any]:
comprehensive_query = f"All information, activities, events, relationships and background about {entity_name}"

try:
# Search edges (facts)
edge_results = self.storage.search(
graph_id=self.graph_id,
query=comprehensive_query,
limit=30,
scope="edges"
)
# Execute searches concurrently since both hit the database/embedding service
with concurrent.futures.ThreadPoolExecutor(max_workers=Config.PROFILE_SEARCH_WORKERS) as io_executor:
future_edges = io_executor.submit(
self.storage.search,
graph_id=self.graph_id,
query=comprehensive_query,
limit=30,
scope="edges"
)
future_nodes = io_executor.submit(
self.storage.search,
graph_id=self.graph_id,
query=comprehensive_query,
limit=20,
scope="nodes"
)

all_facts = set()
if isinstance(edge_results, dict) and 'edges' in edge_results:
for edge in edge_results['edges']:
fact = edge.get('fact', '')
if fact:
all_facts.add(fact)
results["facts"] = list(all_facts)

# Search nodes (entity summaries)
node_results = self.storage.search(
graph_id=self.graph_id,
query=comprehensive_query,
limit=20,
scope="nodes"
)
# Collect edge results
edge_results = future_edges.result()
all_facts = set()
if isinstance(edge_results, dict) and 'edges' in edge_results:
for edge in edge_results['edges']:
fact = edge.get('fact', '')
if fact:
all_facts.add(fact)
results["facts"] = list(all_facts)

# Collect node results
node_results = future_nodes.result()

all_summaries = set()
if isinstance(node_results, dict) and 'nodes' in node_results:
Expand Down Expand Up @@ -478,8 +486,8 @@ def _generate_profile_with_llm(
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0.7 - (attempt * 0.1) # Lower temperature with each retry
# Don't set max_tokens, let LLM generate freely
temperature=0.7 - (attempt * 0.1), # Lower temperature with each retry
max_tokens=4000 # Enforce max_tokens to prevent vLLM infinite generation loops
)

content = response.choices[0].message.content
Expand Down Expand Up @@ -798,7 +806,7 @@ def generate_profiles_from_entities(
use_llm: bool = True,
progress_callback: Optional[callable] = None,
graph_id: Optional[str] = None,
parallel_count: int = 5,
parallel_count: int = Config.PROFILE_PARALLEL_COUNT,
realtime_output_path: Optional[str] = None,
output_platform: str = "reddit"
) -> List[OasisAgentProfile]:
Expand Down Expand Up @@ -831,26 +839,25 @@ def generate_profiles_from_entities(

# Helper function for real-time file writing
def save_profiles_realtime():
"""Real-time save generated profiles to file"""
"""Real-time save generated profiles to file without blocking the main event loop"""
if not realtime_output_path:
return

# Capture snapshot quickly
with lock:
# Filter generated profiles
existing_profiles = [p for p in profiles if p is not None]
if not existing_profiles:
return

def _background_write(snapshot):
try:
if output_platform == "reddit":
# Reddit JSON format
profiles_data = [p.to_reddit_format() for p in existing_profiles]
profiles_data = [p.to_reddit_format() for p in snapshot]
with open(realtime_output_path, 'w', encoding='utf-8') as f:
json.dump(profiles_data, f, ensure_ascii=False, indent=2)
else:
# Twitter CSV format
import csv
profiles_data = [p.to_twitter_format() for p in existing_profiles]
profiles_data = [p.to_twitter_format() for p in snapshot]
if profiles_data:
fieldnames = list(profiles_data[0].keys())
with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f:
Expand All @@ -859,6 +866,9 @@ def save_profiles_realtime():
writer.writerows(profiles_data)
except Exception as e:
logger.warning(f"Real-time profile save failed: {e}")

# Offload heavy disk I/O out of the ThreadPool completion loop
threading.Thread(target=_background_write, args=(existing_profiles,), daemon=True).start()

def generate_single_profile(idx: int, entity: EntityNode) -> tuple:
"""Worker function to generate single profile"""
Expand Down
Loading