diff --git a/Dockerfile b/Dockerfile index e65646860..c3592360d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,8 +5,8 @@ RUN apt-get update \ && apt-get install -y --no-install-recommends nodejs npm \ && rm -rf /var/lib/apt/lists/* -# 从 uv 官方镜像复制 uv -COPY --from=ghcr.io/astral-sh/uv:0.9.26 /uv /uvx /bin/ +# 安装 uv +RUN pip install uv==0.9.26 WORKDIR /app @@ -18,7 +18,7 @@ COPY backend/pyproject.toml backend/uv.lock ./backend/ # 安装依赖(Node + Python) RUN npm ci \ && npm ci --prefix frontend \ - && cd backend && uv sync --frozen + && cd backend && uv sync # 复制项目源码 COPY . . diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 12ff1ba2d..b6ed4e94c 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -1,6 +1,6 @@ """ -图谱相关API路由 -采用项目上下文机制,服务端持久化状态 +Knowledge Graph API Routes +Using project context mechanism with server-side persistent state """ import os @@ -18,31 +18,31 @@ from ..models.task import TaskManager, TaskStatus from ..models.project import ProjectManager, ProjectStatus -# 获取日志器 +# Get logger logger = get_logger('mirofish.api') def allowed_file(filename: str) -> bool: - """检查文件扩展名是否允许""" + """Check if file extension is allowed""" if not filename or '.' 
not in filename: return False ext = os.path.splitext(filename)[1].lower().lstrip('.') return ext in Config.ALLOWED_EXTENSIONS -# ============== 项目管理接口 ============== +# ============== Project Management Endpoints ============== @graph_bp.route('/project/', methods=['GET']) def get_project(project_id: str): """ - 获取项目详情 + Get project details """ project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {project_id}" + "error": f"Project not found: {project_id}" }), 404 return jsonify({ @@ -54,7 +54,7 @@ def get_project(project_id: str): @graph_bp.route('/project/list', methods=['GET']) def list_projects(): """ - 列出所有项目 + List all projects """ limit = request.args.get('limit', 50, type=int) projects = ProjectManager.list_projects(limit=limit) @@ -69,36 +69,36 @@ def list_projects(): @graph_bp.route('/project/', methods=['DELETE']) def delete_project(project_id: str): """ - 删除项目 + Delete project """ success = ProjectManager.delete_project(project_id) if not success: return jsonify({ "success": False, - "error": f"项目不存在或删除失败: {project_id}" + "error": f"Project not found or deletion failed: {project_id}" }), 404 return jsonify({ "success": True, - "message": f"项目已删除: {project_id}" + "message": f"Project deleted: {project_id}" }) @graph_bp.route('/project//reset', methods=['POST']) def reset_project(project_id: str): """ - 重置项目状态(用于重新构建图谱) + Reset project state (for rebuilding knowledge graph) """ project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {project_id}" + "error": f"Project not found: {project_id}" }), 404 - # 重置到本体已生成状态 + # Reset to ontology generated state if project.ontology: project.status = ProjectStatus.ONTOLOGY_GENERATED else: @@ -111,27 +111,27 @@ def reset_project(project_id: str): return jsonify({ "success": True, - "message": f"项目已重置: {project_id}", + "message": f"Project reset: {project_id}", "data": project.to_dict() }) 
-# ============== 接口1:上传文件并生成本体 ============== +# ============== Endpoint 1: Upload files and generate ontology ============== @graph_bp.route('/ontology/generate', methods=['POST']) def generate_ontology(): """ - 接口1:上传文件,分析生成本体定义 + Endpoint 1: Upload files, analyze and generate ontology definition - 请求方式:multipart/form-data + Request method: multipart/form-data - 参数: - files: 上传的文件(PDF/MD/TXT),可多个 - simulation_requirement: 模拟需求描述(必填) - project_name: 项目名称(可选) - additional_context: 额外说明(可选) + Parameters: + files: Uploaded files (PDF/MD/TXT), multiple allowed + simulation_requirement: Simulation requirement description (required) + project_name: Project name (optional) + additional_context: Additional context (optional) - 返回: + Returns: { "success": true, "data": { @@ -147,42 +147,42 @@ def generate_ontology(): } """ try: - logger.info("=== 开始生成本体定义 ===") + logger.info("=== Starting ontology definition generation ===") - # 获取参数 + # Get parameters simulation_requirement = request.form.get('simulation_requirement', '') project_name = request.form.get('project_name', 'Unnamed Project') additional_context = request.form.get('additional_context', '') - logger.debug(f"项目名称: {project_name}") - logger.debug(f"模拟需求: {simulation_requirement[:100]}...") + logger.debug(f"Project name: {project_name}") + logger.debug(f"Simulation requirement: {simulation_requirement[:100]}...") if not simulation_requirement: return jsonify({ "success": False, - "error": "请提供模拟需求描述 (simulation_requirement)" + "error": "Please provide simulation requirement description (simulation_requirement)" }), 400 - # 获取上传的文件 + # Get uploaded files uploaded_files = request.files.getlist('files') if not uploaded_files or all(not f.filename for f in uploaded_files): return jsonify({ "success": False, - "error": "请至少上传一个文档文件" + "error": "Please upload at least one document file" }), 400 - # 创建项目 + # Create project project = ProjectManager.create_project(name=project_name) project.simulation_requirement = 
simulation_requirement - logger.info(f"创建项目: {project.project_id}") + logger.info(f"Created project: {project.project_id}") - # 保存文件并提取文本 + # Save files and extract text document_texts = [] all_text = "" for file in uploaded_files: if file and file.filename and allowed_file(file.filename): - # 保存文件到项目目录 + # Save file to project directory file_info = ProjectManager.save_file_to_project( project.project_id, file, @@ -193,7 +193,7 @@ def generate_ontology(): "size": file_info["size"] }) - # 提取文本 + # Extract text text = FileParser.extract_text(file_info["path"]) text = TextProcessor.preprocess_text(text) document_texts.append(text) @@ -203,16 +203,16 @@ def generate_ontology(): ProjectManager.delete_project(project.project_id) return jsonify({ "success": False, - "error": "没有成功处理任何文档,请检查文件格式" + "error": "No documents were successfully processed, please check file format" }), 400 - # 保存提取的文本 + # Save extracted text project.total_text_length = len(all_text) ProjectManager.save_extracted_text(project.project_id, all_text) - logger.info(f"文本提取完成,共 {len(all_text)} 字符") + logger.info(f"Text extraction complete, {len(all_text)} characters") - # 生成本体 - logger.info("调用 LLM 生成本体定义...") + # Generate ontology + logger.info("Calling LLM to generate ontology definition...") generator = OntologyGenerator() ontology = generator.generate( document_texts=document_texts, @@ -220,10 +220,10 @@ def generate_ontology(): additional_context=additional_context if additional_context else None ) - # 保存本体到项目 + # Save ontology to project entity_count = len(ontology.get("entity_types", [])) edge_count = len(ontology.get("edge_types", [])) - logger.info(f"本体生成完成: {entity_count} 个实体类型, {edge_count} 个关系类型") + logger.info(f"Ontology generation complete: {entity_count} entity types, {edge_count} relation types") project.ontology = { "entity_types": ontology.get("entity_types", []), @@ -232,7 +232,7 @@ def generate_ontology(): project.analysis_summary = ontology.get("analysis_summary", "") project.status 
= ProjectStatus.ONTOLOGY_GENERATED ProjectManager.save_project(project) - logger.info(f"=== 本体生成完成 === 项目ID: {project.project_id}") + logger.info(f"=== Ontology generation complete === Project ID: {project.project_id}") return jsonify({ "success": True, @@ -254,140 +254,140 @@ def generate_ontology(): }), 500 -# ============== 接口2:构建图谱 ============== +# ============== Endpoint 2: Build knowledge graph ============== @graph_bp.route('/build', methods=['POST']) def build_graph(): """ - 接口2:根据project_id构建图谱 + Endpoint 2: Build knowledge graph by project_id - 请求(JSON): + Request (JSON): { - "project_id": "proj_xxxx", // 必填,来自接口1 - "graph_name": "图谱名称", // 可选 - "chunk_size": 500, // 可选,默认500 - "chunk_overlap": 50 // 可选,默认50 + "project_id": "proj_xxxx", // Required, from Endpoint 1 + "graph_name": "Graph name", // Optional + "chunk_size": 500, // Optional,default 500 + "chunk_overlap": 50 // Optional,default 50 } - 返回: + Returns: { "success": true, "data": { "project_id": "proj_xxxx", "task_id": "task_xxxx", - "message": "图谱构建任务已启动" + "message": "Knowledge graph build task started" } } """ try: - logger.info("=== 开始构建图谱 ===") + logger.info("=== Starting knowledge graph build ===") - # 检查配置 + # Check configuration errors = [] - if not Config.ZEP_API_KEY: - errors.append("ZEP_API_KEY未配置") + if not Config.LLM_API_KEY: + errors.append("LLM_API_KEY not configured") if errors: - logger.error(f"配置错误: {errors}") + logger.error(f"Configuration error: {errors}") return jsonify({ "success": False, - "error": "配置错误: " + "; ".join(errors) + "error": "Configuration error: " + "; ".join(errors) }), 500 - # 解析请求 + # Parse request data = request.get_json() or {} project_id = data.get('project_id') - logger.debug(f"请求参数: project_id={project_id}") + logger.debug(f"Request parameters: project_id={project_id}") if not project_id: return jsonify({ "success": False, - "error": "请提供 project_id" + "error": "Please provide project_id" }), 400 - # 获取项目 + # Get project project = 
ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {project_id}" + "error": f"Project not found: {project_id}" }), 404 - # 检查项目状态 - force = data.get('force', False) # 强制重新构建 + # Check project status + force = data.get('force', False) # Force rebuild if project.status == ProjectStatus.CREATED: return jsonify({ "success": False, - "error": "项目尚未生成本体,请先调用 /ontology/generate" + "error": "Project has not generated ontology yet, please call /ontology/generate first" }), 400 if project.status == ProjectStatus.GRAPH_BUILDING and not force: return jsonify({ "success": False, - "error": "图谱正在构建中,请勿重复提交。如需强制重建,请添加 force: true", + "error": "Knowledge graph is being built, please do not resubmit. To force rebuild, add force: true", "task_id": project.graph_build_task_id }), 400 - # 如果强制重建,重置状态 + # If force rebuild, reset state if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]: project.status = ProjectStatus.ONTOLOGY_GENERATED project.graph_id = None project.graph_build_task_id = None project.error = None - # 获取配置 + # Get configuration graph_name = data.get('graph_name', project.name or 'MiroFish Graph') chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE) chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP) - # 更新项目配置 + # Update project configuration project.chunk_size = chunk_size project.chunk_overlap = chunk_overlap - # 获取提取的文本 + # Get extracted text text = ProjectManager.get_extracted_text(project_id) if not text: return jsonify({ "success": False, - "error": "未找到提取的文本内容" + "error": "Extracted text content not found" }), 400 - # 获取本体 + # Get ontology ontology = project.ontology if not ontology: return jsonify({ "success": False, - "error": "未找到本体定义" + "error": "Ontology definition not found" }), 400 - # 创建异步任务 + # Create async task task_manager = TaskManager() - task_id = 
task_manager.create_task(f"构建图谱: {graph_name}") - logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}") + task_id = task_manager.create_task(f"Build knowledge graph: {graph_name}") + logger.info(f"Created graph build task: task_id={task_id}, project_id={project_id}") - # 更新项目状态 + # Update project status project.status = ProjectStatus.GRAPH_BUILDING project.graph_build_task_id = task_id ProjectManager.save_project(project) - # 启动后台任务 + # Start background task def build_task(): build_logger = get_logger('mirofish.build') try: - build_logger.info(f"[{task_id}] 开始构建图谱...") + build_logger.info(f"[{task_id}] Starting knowledge graph build...") task_manager.update_task( task_id, status=TaskStatus.PROCESSING, - message="初始化图谱构建服务..." + message="Initializing graph build service..." ) - # 创建图谱构建服务 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + # Create graph build service + builder = GraphBuilderService() - # 分块 + # Chunking task_manager.update_task( task_id, - message="文本分块中...", + message="Chunking text...", progress=5 ) chunks = TextProcessor.split_text( @@ -397,27 +397,27 @@ def build_task(): ) total_chunks = len(chunks) - # 创建图谱 + # Create graph task_manager.update_task( task_id, - message="创建Zep图谱...", + message="Creating graph...", progress=10 ) graph_id = builder.create_graph(name=graph_name) - # 更新项目的graph_id + # Update project graph_id project.graph_id = graph_id ProjectManager.save_project(project) - # 设置本体 + # Set ontology task_manager.update_task( task_id, - message="设置本体定义...", + message="Setting ontology definition...", progress=15 ) builder.set_ontology(graph_id, ontology) - # 添加文本(progress_callback 签名是 (msg, progress_ratio)) + # Add text (progress_callback signature is (msg, progress_ratio)) def add_progress_callback(msg, progress_ratio): progress = 15 + int(progress_ratio * 40) # 15% - 55% task_manager.update_task( @@ -428,7 +428,7 @@ def add_progress_callback(msg, progress_ratio): task_manager.update_task( task_id, - 
message=f"开始添加 {total_chunks} 个文本块...", + message=f"Starting to add {total_chunks} text chunks...", progress=15 ) @@ -439,44 +439,29 @@ def add_progress_callback(msg, progress_ratio): progress_callback=add_progress_callback ) - # 等待Zep处理完成(查询每个episode的processed状态) - task_manager.update_task( - task_id, - message="等待Zep处理数据...", - progress=55 - ) - - def wait_progress_callback(msg, progress_ratio): - progress = 55 + int(progress_ratio * 35) # 55% - 90% - task_manager.update_task( - task_id, - message=msg, - progress=progress - ) - - builder._wait_for_episodes(episode_uuids, wait_progress_callback) - - # 获取图谱数据 + # Graphiti processes episodes synchronously, no waiting needed + + # Retrieve graph data task_manager.update_task( task_id, - message="获取图谱数据...", + message="Retrieving graph data...", progress=95 ) graph_data = builder.get_graph_data(graph_id) - # 更新项目状态 + # Update project status project.status = ProjectStatus.GRAPH_COMPLETED ProjectManager.save_project(project) node_count = graph_data.get("node_count", 0) edge_count = graph_data.get("edge_count", 0) - build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}") + build_logger.info(f"[{task_id}] Knowledge graph build complete: graph_id={graph_id}, nodes={node_count}, edges={edge_count}") - # 完成 + # Complete task_manager.update_task( task_id, status=TaskStatus.COMPLETED, - message="图谱构建完成", + message="Knowledge graph build complete", progress=100, result={ "project_id": project_id, @@ -488,8 +473,8 @@ def wait_progress_callback(msg, progress_ratio): ) except Exception as e: - # 更新项目状态为失败 - build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}") + # Update project status to failed + build_logger.error(f"[{task_id}] Graph build failed: {str(e)}") build_logger.debug(traceback.format_exc()) project.status = ProjectStatus.FAILED @@ -499,11 +484,11 @@ def wait_progress_callback(msg, progress_ratio): task_manager.update_task( task_id, status=TaskStatus.FAILED, - message=f"构建失败: 
{str(e)}", + message=f"Build failed: {str(e)}", error=traceback.format_exc() ) - # 启动后台线程 + # Start background thread thread = threading.Thread(target=build_task, daemon=True) thread.start() @@ -512,7 +497,7 @@ def wait_progress_callback(msg, progress_ratio): "data": { "project_id": project_id, "task_id": task_id, - "message": "图谱构建任务已启动,请通过 /task/{task_id} 查询进度" + "message": "Knowledge graph build task started,query progress via /task/{task_id} " } }) @@ -524,19 +509,19 @@ def wait_progress_callback(msg, progress_ratio): }), 500 -# ============== 任务查询接口 ============== +# ============== Task Query Endpoints ============== @graph_bp.route('/task/', methods=['GET']) def get_task(task_id: str): """ - 查询任务状态 + Query task status """ task = TaskManager().get_task(task_id) if not task: return jsonify({ "success": False, - "error": f"任务不存在: {task_id}" + "error": f"Task not found: {task_id}" }), 404 return jsonify({ @@ -548,7 +533,7 @@ def get_task(task_id: str): @graph_bp.route('/tasks', methods=['GET']) def list_tasks(): """ - 列出所有任务 + List all tasks """ tasks = TaskManager().list_tasks() @@ -559,21 +544,21 @@ def list_tasks(): }) -# ============== 图谱数据接口 ============== +# ============== Graph Data Endpoints ============== @graph_bp.route('/data/', methods=['GET']) def get_graph_data(graph_id: str): """ - 获取图谱数据(节点和边) + Get graph data (nodes and edges) """ try: - if not Config.ZEP_API_KEY: + if not Config.LLM_API_KEY: return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "LLM_API_KEY not configured" }), 500 - - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + + builder = GraphBuilderService() graph_data = builder.get_graph_data(graph_id) return jsonify({ @@ -592,21 +577,21 @@ def get_graph_data(graph_id: str): @graph_bp.route('/delete/', methods=['DELETE']) def delete_graph(graph_id: str): """ - 删除Zep图谱 + Delete graph """ try: - if not Config.ZEP_API_KEY: + if not Config.LLM_API_KEY: return jsonify({ "success": False, - "error": 
"ZEP_API_KEY未配置" + "error": "LLM_API_KEY not configured" }), 500 - - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + + builder = GraphBuilderService() builder.delete_graph(graph_id) return jsonify({ "success": True, - "message": f"图谱已删除: {graph_id}" + "message": f"Graph deleted: {graph_id}" }) except Exception as e: diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 3a0f68168..86c08bd43 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -56,10 +56,10 @@ def get_graph_entities(graph_id: str): enrich: 是否获取相关边信息(默认true) """ try: - if not Config.ZEP_API_KEY: + if not Config.LLM_API_KEY: return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "LLM_API_KEY not configured" }), 500 entity_types_str = request.args.get('entity_types', '') @@ -93,10 +93,10 @@ def get_graph_entities(graph_id: str): def get_entity_detail(graph_id: str, entity_uuid: str): """获取单个实体的详细信息""" try: - if not Config.ZEP_API_KEY: + if not Config.LLM_API_KEY: return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "LLM_API_KEY not configured" }), 500 reader = ZepEntityReader() @@ -126,10 +126,10 @@ def get_entity_detail(graph_id: str, entity_uuid: str): def get_entities_by_type(graph_id: str, entity_type: str): """获取指定类型的所有实体""" try: - if not Config.ZEP_API_KEY: + if not Config.LLM_API_KEY: return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "LLM_API_KEY not configured" }), 500 enrich = request.args.get('enrich', 'true').lower() == 'true' diff --git a/backend/app/config.py b/backend/app/config.py index 953dfa50a..ef0c72a51 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -32,17 +32,20 @@ class Config: LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') - # Zep配置 - ZEP_API_KEY = os.environ.get('ZEP_API_KEY') - + # Graphiti / Neo4j config + GRAPHITI_DB_DIR = 
os.path.join(os.path.dirname(__file__), '../uploads/graphs') + NEO4J_URI = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') + NEO4J_USER = os.environ.get('NEO4J_USER', 'neo4j') + NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD', 'mirofish123') + # 文件上传配置 MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads') ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'} # 文本处理配置 - DEFAULT_CHUNK_SIZE = 500 # 默认切块大小 - DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小 + DEFAULT_CHUNK_SIZE = 2000 # default chunk size (larger = fewer LLM calls, faster) + DEFAULT_CHUNK_OVERLAP = 100 # default overlap # OASIS模拟配置 OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10')) @@ -69,7 +72,5 @@ def validate(cls): errors = [] if not cls.LLM_API_KEY: errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: - errors.append("ZEP_API_KEY 未配置") return errors diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 0e0444bf3..f8652aa81 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -1,6 +1,6 @@ """ -图谱构建服务 -接口2:使用Zep API构建Standalone Graph +Graph Builder Service +Interface 2: Build knowledge graphs using GraphitiAdapter """ import os @@ -10,23 +10,20 @@ from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass -from zep_cloud.client import Zep -from zep_cloud import EpisodeData, EntityEdgeSourceTarget - from ..config import Config from ..models.task import TaskManager, TaskStatus -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges from .text_processor import TextProcessor +from .graphiti_adapter import GraphitiAdapter @dataclass class GraphInfo: - """图谱信息""" + """Graph information""" graph_id: str node_count: int edge_count: int entity_types: List[str] - + def to_dict(self) -> Dict[str, Any]: return { "graph_id": self.graph_id, @@ -38,18 +35,13 @@ def to_dict(self) -> Dict[str, Any]: class 
GraphBuilderService: """ - 图谱构建服务 - 负责调用Zep API构建知识图谱 + Graph Builder Service + Builds knowledge graphs using GraphitiAdapter """ - - def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + + def __init__(self): self.task_manager = TaskManager() - + def build_graph_async( self, text: str, @@ -60,20 +52,20 @@ def build_graph_async( batch_size: int = 3 ) -> str: """ - 异步构建图谱 - + Build graph asynchronously + Args: - text: 输入文本 - ontology: 本体定义(来自接口1的输出) - graph_name: 图谱名称 - chunk_size: 文本块大小 - chunk_overlap: 块重叠大小 - batch_size: 每批发送的块数量 - + text: Input text + ontology: Ontology definition (from interface 1 output) + graph_name: Graph name + chunk_size: Text chunk size + chunk_overlap: Chunk overlap size + batch_size: Number of chunks per batch + Returns: - 任务ID + Task ID """ - # 创建任务 + # Create task task_id = self.task_manager.create_task( task_type="graph_build", metadata={ @@ -82,17 +74,17 @@ def build_graph_async( "text_length": len(text), } ) - - # 在后台线程中执行构建 + + # Execute build in background thread thread = threading.Thread( target=self._build_graph_worker, args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size) ) thread.daemon = True thread.start() - + return task_id - + def _build_graph_worker( self, task_id: str, @@ -103,188 +95,81 @@ def _build_graph_worker( chunk_overlap: int, batch_size: int ): - """图谱构建工作线程""" + """Graph build worker thread""" try: self.task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=5, - message="开始构建图谱..." + message="Starting graph build..." ) - - # 1. 创建图谱 + + # 1. Create graph graph_id = self.create_graph(graph_name) self.task_manager.update_task( task_id, progress=10, - message=f"图谱已创建: {graph_id}" + message=f"Graph created: {graph_id}" ) - - # 2. 设置本体 + + # 2. 
Set ontology self.set_ontology(graph_id, ontology) self.task_manager.update_task( task_id, progress=15, - message="本体已设置" + message="Ontology set" ) - - # 3. 文本分块 + + # 3. Split text into chunks chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap) total_chunks = len(chunks) self.task_manager.update_task( task_id, progress=20, - message=f"文本已分割为 {total_chunks} 个块" + message=f"Text split into {total_chunks} chunks" ) - - # 4. 分批发送数据 - episode_uuids = self.add_text_batches( + + # 4. Send data in batches + self.add_text_batches( graph_id, chunks, batch_size, lambda msg, prog: self.task_manager.update_task( task_id, - progress=20 + int(prog * 0.4), # 20-60% - message=msg - ) - ) - - # 5. 等待Zep处理完成 - self.task_manager.update_task( - task_id, - progress=60, - message="等待Zep处理数据..." - ) - - self._wait_for_episodes( - episode_uuids, - lambda msg, prog: self.task_manager.update_task( - task_id, - progress=60 + int(prog * 0.3), # 60-90% + progress=20 + int(prog * 60), # 20-80% message=msg ) ) - - # 6. 获取图谱信息 + + # 5. Retrieve graph info self.task_manager.update_task( task_id, progress=90, - message="获取图谱信息..." + message="Retrieving graph info..." 
) - + graph_info = self._get_graph_info(graph_id) - - # 完成 + + # Complete self.task_manager.complete_task(task_id, { "graph_id": graph_id, "graph_info": graph_info.to_dict(), "chunks_processed": total_chunks, }) - + except Exception as e: import traceback error_msg = f"{str(e)}\n{traceback.format_exc()}" self.task_manager.fail_task(task_id, error_msg) - + def create_graph(self, name: str) -> str: - """创建Zep图谱(公开方法)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" - - self.client.graph.create( - graph_id=graph_id, - name=name, - description="MiroFish Social Simulation Graph" - ) - + adapter = GraphitiAdapter.get_or_create(graph_id) + adapter.create_graph(name) return graph_id - + def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): - """设置图谱本体(公开方法)""" - import warnings - from typing import Optional - from pydantic import Field - from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel - - # 抑制 Pydantic v2 关于 Field(default=None) 的警告 - # 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略 - warnings.filterwarnings('ignore', category=UserWarning, module='pydantic') - - # Zep 保留名称,不能作为属性名 - RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'} - - def safe_attr_name(attr_name: str) -> str: - """将保留名称转换为安全名称""" - if attr_name.lower() in RESERVED_NAMES: - return f"entity_{attr_name}" - return attr_name - - # 动态创建实体类型 - entity_types = {} - for entity_def in ontology.get("entity_types", []): - name = entity_def["name"] - description = entity_def.get("description", f"A {name} entity.") - - # 创建属性字典和类型注解(Pydantic v2 需要) - attrs = {"__doc__": description} - annotations = {} - - for attr_def in entity_def.get("attributes", []): - attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 - attr_desc = attr_def.get("description", attr_name) - # Zep API 需要 Field 的 description,这是必需的 - attrs[attr_name] = Field(description=attr_desc, default=None) - annotations[attr_name] = Optional[EntityText] # 类型注解 - - attrs["__annotations__"] = 
annotations - - # 动态创建类 - entity_class = type(name, (EntityModel,), attrs) - entity_class.__doc__ = description - entity_types[name] = entity_class - - # 动态创建边类型 - edge_definitions = {} - for edge_def in ontology.get("edge_types", []): - name = edge_def["name"] - description = edge_def.get("description", f"A {name} relationship.") - - # 创建属性字典和类型注解 - attrs = {"__doc__": description} - annotations = {} - - for attr_def in edge_def.get("attributes", []): - attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 - attr_desc = attr_def.get("description", attr_name) - # Zep API 需要 Field 的 description,这是必需的 - attrs[attr_name] = Field(description=attr_desc, default=None) - annotations[attr_name] = Optional[str] # 边属性用str类型 - - attrs["__annotations__"] = annotations - - # 动态创建类 - class_name = ''.join(word.capitalize() for word in name.split('_')) - edge_class = type(class_name, (EdgeModel,), attrs) - edge_class.__doc__ = description - - # 构建source_targets - source_targets = [] - for st in edge_def.get("source_targets", []): - source_targets.append( - EntityEdgeSourceTarget( - source=st.get("source", "Entity"), - target=st.get("target", "Entity") - ) - ) - - if source_targets: - edge_definitions[name] = (edge_class, source_targets) - - # 调用Zep API设置本体 - if entity_types or edge_definitions: - self.client.graph.set_ontology( - graph_ids=[graph_id], - entities=entity_types if entity_types else None, - edges=edge_definitions if edge_definitions else None, - ) - + adapter = GraphitiAdapter.get_or_create(graph_id) + adapter.set_ontology(ontology) + def add_text_batches( self, graph_id: str, @@ -292,123 +177,40 @@ def add_text_batches( batch_size: int = 3, progress_callback: Optional[Callable] = None ) -> List[str]: - """分批添加文本到图谱,返回所有 episode 的 uuid 列表""" - episode_uuids = [] + adapter = GraphitiAdapter.get_or_create(graph_id) total_chunks = len(chunks) - + for i in range(0, total_chunks, batch_size): - batch_chunks = chunks[i:i + batch_size] + batch = chunks[i:i + batch_size] 
batch_num = i // batch_size + 1 total_batches = (total_chunks + batch_size - 1) // batch_size - + if progress_callback: - progress = (i + len(batch_chunks)) / total_chunks + progress = (i + len(batch)) / total_chunks progress_callback( - f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...", + f"Sending batch {batch_num}/{total_batches} ({len(batch)} chunks)...", progress ) - - # 构建episode数据 - episodes = [ - EpisodeData(data=chunk, type="text") - for chunk in batch_chunks - ] - - # 发送到Zep + try: - batch_result = self.client.graph.add_batch( - graph_id=graph_id, - episodes=episodes - ) - - # 收集返回的 episode uuid - if batch_result and isinstance(batch_result, list): - for ep in batch_result: - ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) - if ep_uuid: - episode_uuids.append(ep_uuid) - - # 避免请求过快 - time.sleep(1) - + adapter.add_episodes_bulk(batch) except Exception as e: if progress_callback: - progress_callback(f"批次 {batch_num} 发送失败: {str(e)}", 0) + progress_callback(f"Batch {batch_num} failed: {str(e)}", 0) raise - - return episode_uuids - - def _wait_for_episodes( - self, - episode_uuids: List[str], - progress_callback: Optional[Callable] = None, - timeout: int = 600 - ): - """等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)""" - if not episode_uuids: - if progress_callback: - progress_callback("无需等待(没有 episode)", 1.0) - return - - start_time = time.time() - pending_episodes = set(episode_uuids) - completed_count = 0 - total_episodes = len(episode_uuids) - - if progress_callback: - progress_callback(f"开始等待 {total_episodes} 个文本块处理...", 0) - - while pending_episodes: - if time.time() - start_time > timeout: - if progress_callback: - progress_callback( - f"部分文本块超时,已完成 {completed_count}/{total_episodes}", - completed_count / total_episodes - ) - break - - # 检查每个 episode 的处理状态 - for ep_uuid in list(pending_episodes): - try: - episode = self.client.graph.episode.get(uuid_=ep_uuid) - is_processed = getattr(episode, 'processed', False) - - if 
is_processed: - pending_episodes.remove(ep_uuid) - completed_count += 1 - - except Exception as e: - # 忽略单个查询错误,继续 - pass - - elapsed = int(time.time() - start_time) - if progress_callback: - progress_callback( - f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)", - completed_count / total_episodes if total_episodes > 0 else 0 - ) - - if pending_episodes: - time.sleep(3) # 每3秒检查一次 - - if progress_callback: - progress_callback(f"处理完成: {completed_count}/{total_episodes}", 1.0) - - def _get_graph_info(self, graph_id: str) -> GraphInfo: - """获取图谱信息""" - # 获取节点(分页) - nodes = fetch_all_nodes(self.client, graph_id) - # 获取边(分页) - edges = fetch_all_edges(self.client, graph_id) + return [] # No episode UUIDs needed + + def _get_graph_info(self, graph_id: str) -> GraphInfo: + adapter = GraphitiAdapter.get_or_create(graph_id) + nodes = adapter.get_all_nodes() + edges = adapter.get_all_edges() - # 统计实体类型 entity_types = set() for node in nodes: - if node.labels: - for label in node.labels: - if label not in ["Entity", "Node"]: - entity_types.add(label) + for label in node.get("labels", []): + if label not in ["Entity", "Node"]: + entity_types.add(label) return GraphInfo( graph_id=graph_id, @@ -416,85 +218,28 @@ def _get_graph_info(self, graph_id: str) -> GraphInfo: edge_count=len(edges), entity_types=list(entity_types) ) - + def get_graph_data(self, graph_id: str) -> Dict[str, Any]: - """ - 获取完整图谱数据(包含详细信息) - - Args: - graph_id: 图谱ID - - Returns: - 包含nodes和edges的字典,包括时间信息、属性等详细数据 - """ - nodes = fetch_all_nodes(self.client, graph_id) - edges = fetch_all_edges(self.client, graph_id) + adapter = GraphitiAdapter.get_or_create(graph_id) + nodes = adapter.get_all_nodes() + edges = adapter.get_all_edges() + + node_map = {n["uuid"]: n.get("name", "") for n in nodes} - # 创建节点映射用于获取节点名称 - node_map = {} - for node in nodes: - node_map[node.uuid_] = node.name or "" - - nodes_data = [] - for node in nodes: - # 获取创建时间 - created_at = getattr(node, 
'created_at', None) - if created_at: - created_at = str(created_at) - - nodes_data.append({ - "uuid": node.uuid_, - "name": node.name, - "labels": node.labels or [], - "summary": node.summary or "", - "attributes": node.attributes or {}, - "created_at": created_at, - }) - - edges_data = [] for edge in edges: - # 获取时间信息 - created_at = getattr(edge, 'created_at', None) - valid_at = getattr(edge, 'valid_at', None) - invalid_at = getattr(edge, 'invalid_at', None) - expired_at = getattr(edge, 'expired_at', None) - - # 获取 episodes - episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) - if episodes and not isinstance(episodes, list): - episodes = [str(episodes)] - elif episodes: - episodes = [str(e) for e in episodes] - - # 获取 fact_type - fact_type = getattr(edge, 'fact_type', None) or edge.name or "" - - edges_data.append({ - "uuid": edge.uuid_, - "name": edge.name or "", - "fact": edge.fact or "", - "fact_type": fact_type, - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "source_node_name": node_map.get(edge.source_node_uuid, ""), - "target_node_name": node_map.get(edge.target_node_uuid, ""), - "attributes": edge.attributes or {}, - "created_at": str(created_at) if created_at else None, - "valid_at": str(valid_at) if valid_at else None, - "invalid_at": str(invalid_at) if invalid_at else None, - "expired_at": str(expired_at) if expired_at else None, - "episodes": episodes or [], - }) - + edge["source_node_name"] = node_map.get(edge.get("source_node_uuid", ""), "") + edge["target_node_name"] = node_map.get(edge.get("target_node_uuid", ""), "") + edge["fact_type"] = edge.get("name", "") + return { "graph_id": graph_id, - "nodes": nodes_data, - "edges": edges_data, - "node_count": len(nodes_data), - "edge_count": len(edges_data), + "nodes": nodes, + "edges": edges, + "node_count": len(nodes), + "edge_count": len(edges), } - - def delete_graph(self, graph_id: str): - """删除图谱""" - 
self.client.graph.delete(graph_id=graph_id) + def delete_graph(self, graph_id: str): + adapter = GraphitiAdapter.get_or_create(graph_id) + adapter.delete_graph() + GraphitiAdapter.remove_instance(graph_id) diff --git a/backend/app/services/graphiti_adapter.py b/backend/app/services/graphiti_adapter.py new file mode 100644 index 000000000..fc21840e7 --- /dev/null +++ b/backend/app/services/graphiti_adapter.py @@ -0,0 +1,452 @@ +""" +Graphiti sync adapter +Wraps graphiti-core async API for Flask synchronous use. +Replaces all zep_cloud usage, backed by Neo4j graph database. + +Core design: +- Dedicated daemon thread event loop for async-to-sync bridging +- Neo4j as graph database backend +- Native Gemini support for LLM and embeddings +- API compatible with existing Zep-based code +""" + +import asyncio +import threading +import os +import json +import shutil +import uuid as uuid_mod +from datetime import datetime, timezone +from typing import Dict, Any, List, Optional, Callable + +from graphiti_core import Graphiti +from graphiti_core.llm_client.gemini_client import GeminiClient +from graphiti_core.llm_client.config import LLMConfig +from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig +from graphiti_core.nodes import EpisodeType, EntityNode +from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient +from pydantic import BaseModel, Field + +from ..config import Config +from ..utils.logger import get_logger + +logger = get_logger('mirofish.graphiti_adapter') + +# --------------------------------------------------------------------------- +# Module-level event loop for async operations +# --------------------------------------------------------------------------- + +_loop: Optional[asyncio.AbstractEventLoop] = None +_thread: Optional[threading.Thread] = None +_lock = threading.Lock() + + +def _ensure_loop() -> asyncio.AbstractEventLoop: + global _loop, _thread + with _lock: + if _loop is None or not 
_loop.is_running(): + _loop = asyncio.new_event_loop() + _thread = threading.Thread(target=_loop.run_forever, daemon=True) + _thread.start() + return _loop + + +def _run_async(coro): + loop = _ensure_loop() + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result(timeout=600) + + +# --------------------------------------------------------------------------- +# Adapter +# --------------------------------------------------------------------------- + +class GraphitiAdapter: + """Synchronous adapter for graphiti-core async API backed by Neo4j.""" + + _instances: Dict[str, 'GraphitiAdapter'] = {} + _instances_lock = threading.Lock() + + def __init__(self, graph_id: str): + self.graph_id = graph_id + self.db_dir = os.path.join(Config.UPLOAD_FOLDER, 'graphs', graph_id) + self._graphiti: Optional[Graphiti] = None + self._gemini_api_key: str = Config.LLM_API_KEY or '' + + @classmethod + def get_or_create(cls, graph_id: str) -> 'GraphitiAdapter': + with cls._instances_lock: + if graph_id not in cls._instances: + cls._instances[graph_id] = cls(graph_id) + return cls._instances[graph_id] + + @classmethod + def remove_instance(cls, graph_id: str): + with cls._instances_lock: + inst = cls._instances.pop(graph_id, None) + if inst is not None: + try: + inst.close() + except Exception: + logger.warning("Failed to close GraphitiAdapter: %s", graph_id, exc_info=True) + + # -------------------------------------------------------------------------- + # Lazy Graphiti initialisation (Neo4j + Gemini) + # -------------------------------------------------------------------------- + + def _get_graphiti(self) -> Graphiti: + if self._graphiti is None: + self._graphiti = _run_async(self._init_graphiti()) + return self._graphiti + + async def _init_graphiti(self) -> Graphiti: + os.makedirs(self.db_dir, exist_ok=True) + + llm_client = GeminiClient( + config=LLMConfig( + api_key=self._gemini_api_key, + model="gemini-2.5-flash-lite", + ) + ) + + embedder = GeminiEmbedder( + 
config=GeminiEmbedderConfig( + api_key=self._gemini_api_key, + embedding_model="gemini-embedding-001", + ) + ) + + cross_encoder = OpenAIRerankerClient( + config=LLMConfig( + api_key=self._gemini_api_key, + base_url="https://generativelanguage.googleapis.com/v1beta/openai", + model="gemini-2.5-flash-lite", + ) + ) + + graphiti = Graphiti( + uri=Config.NEO4J_URI, + user=Config.NEO4J_USER, + password=Config.NEO4J_PASSWORD, + llm_client=llm_client, + embedder=embedder, + cross_encoder=cross_encoder, + ) + + await graphiti.build_indices_and_constraints() + logger.info("Graphiti initialized: graph_id=%s, neo4j=%s", self.graph_id, Config.NEO4J_URI) + return graphiti + + # -------------------------------------------------------------------------- + # Graph lifecycle + # -------------------------------------------------------------------------- + + def create_graph(self, name: str = "MiroFish Graph") -> str: + os.makedirs(self.db_dir, exist_ok=True) + + meta_path = os.path.join(self.db_dir, '_meta.json') + meta = { + 'graph_id': self.graph_id, + 'name': name, + 'created_at': datetime.now(timezone.utc).isoformat(), + } + with open(meta_path, 'w', encoding='utf-8') as f: + json.dump(meta, f, ensure_ascii=False, indent=2) + + self._get_graphiti() + logger.info("Graph created: graph_id=%s, name=%s", self.graph_id, name) + return self.graph_id + + def delete_graph(self): + self.close() + if os.path.isdir(self.db_dir): + shutil.rmtree(self.db_dir, ignore_errors=True) + logger.info("Graph deleted: graph_id=%s", self.graph_id) + + def graph_exists(self) -> bool: + return os.path.isdir(self.db_dir) + + # -------------------------------------------------------------------------- + # Ontology (stored as JSON metadata) + # -------------------------------------------------------------------------- + + def _ontology_path(self) -> str: + return os.path.join(self.db_dir, '_ontology.json') + + def set_ontology(self, ontology_dict: Dict[str, Any]): + os.makedirs(self.db_dir, exist_ok=True) 
+ with open(self._ontology_path(), 'w', encoding='utf-8') as f: + json.dump(ontology_dict, f, ensure_ascii=False, indent=2) + logger.info("Ontology saved: graph_id=%s", self.graph_id) + + def get_ontology(self) -> Optional[Dict[str, Any]]: + path = self._ontology_path() + if not os.path.isfile(path): + return None + with open(path, 'r', encoding='utf-8') as f: + return json.load(f) + + def _build_entity_types(self) -> Optional[Dict[str, type]]: + """Convert stored ontology into Pydantic models for Graphiti entity typing.""" + ontology = self.get_ontology() + if not ontology: + return None + + reserved = set(EntityNode.model_fields.keys()) + entity_types = {} + + for et in ontology.get("entity_types", []): + name = et.get("name", "") + if not name: + continue + description = et.get("description", f"A {name} entity") + + # Build fields dict for dynamic Pydantic model + fields = {} + annotations = {} + for attr in et.get("attributes", []): + attr_name = attr.get("name", "") + if not attr_name or attr_name.lower() in reserved: + attr_name = f"entity_{attr_name}" + attr_desc = attr.get("description", attr_name) + fields[attr_name] = Field(default=None, description=attr_desc) + annotations[attr_name] = Optional[str] + + # Create dynamic Pydantic model + model_attrs = {"__annotations__": annotations, "__doc__": description} + model_attrs.update(fields) + model_class = type(name, (BaseModel,), model_attrs) + entity_types[name] = model_class + + return entity_types if entity_types else None + + # -------------------------------------------------------------------------- + # Episodes + # -------------------------------------------------------------------------- + + def add_episode(self, text: str, source_description: str = "") -> str: + g = self._get_graphiti() + episode_id = str(uuid_mod.uuid4()) + source = source_description or "mirofish" + entity_types = self._build_entity_types() + + async def _add(): + await g.add_episode( + name=episode_id, + episode_body=text, + 
source_description=source, + source=EpisodeType.text, + reference_time=datetime.now(timezone.utc), + group_id=self.graph_id, + entity_types=entity_types, + ) + return episode_id + + result = _run_async(_add()) + logger.debug("Episode added: graph_id=%s, len=%d", self.graph_id, len(text)) + return result + + def add_episodes_bulk(self, texts: List[str]): + total = len(texts) + for idx, text in enumerate(texts): + try: + self.add_episode(text, source_description="mirofish_bulk") + except Exception: + logger.error("Episode add failed (%d/%d): graph_id=%s", + idx + 1, total, self.graph_id, exc_info=True) + + # -------------------------------------------------------------------------- + # Node / edge retrieval via Neo4j Cypher + # -------------------------------------------------------------------------- + + def get_all_nodes(self) -> List[Dict[str, Any]]: + g = self._get_graphiti() + + async def _fetch(): + driver = g.driver + records, _, _ = await driver.execute_query( + "MATCH (n:Entity) WHERE n.group_id = $gid " + "RETURN n.uuid AS uuid, n.name AS name, labels(n) AS labels, " + "n.summary AS summary, n.created_at AS created_at", + gid=self.graph_id, + ) + nodes = [] + for r in records: + labels = [l for l in (r['labels'] or []) if l not in ('Entity', '__Entity__')] + nodes.append({ + 'uuid': str(r['uuid'] or ''), + 'name': str(r['name'] or ''), + 'labels': labels, + 'summary': str(r['summary'] or ''), + 'attributes': {}, + 'created_at': str(r['created_at'] or ''), + }) + return nodes + + return _run_async(_fetch()) + + def get_all_edges(self) -> List[Dict[str, Any]]: + g = self._get_graphiti() + + async def _fetch(): + driver = g.driver + records, _, _ = await driver.execute_query( + "MATCH (a:Entity)-[r:RELATES_TO]->(b:Entity) " + "WHERE r.group_id = $gid " + "RETURN r.uuid AS uuid, r.name AS name, r.fact AS fact, " + "a.uuid AS source_uuid, b.uuid AS target_uuid, " + "r.created_at AS created_at, r.valid_at AS valid_at, " + "r.invalid_at AS invalid_at, r.expired_at 
AS expired_at", + gid=self.graph_id, + ) + edges = [] + for r in records: + edges.append({ + 'uuid': str(r['uuid'] or ''), + 'name': str(r['name'] or ''), + 'fact': str(r['fact'] or ''), + 'source_node_uuid': str(r['source_uuid'] or ''), + 'target_node_uuid': str(r['target_uuid'] or ''), + 'attributes': {}, + 'created_at': str(r['created_at'] or ''), + 'valid_at': str(r['valid_at'] or ''), + 'invalid_at': str(r['invalid_at'] or ''), + 'expired_at': str(r['expired_at'] or ''), + 'episodes': [], + }) + return edges + + return _run_async(_fetch()) + + def get_node(self, node_uuid: str) -> Optional[Dict[str, Any]]: + g = self._get_graphiti() + + async def _fetch(): + driver = g.driver + records, _, _ = await driver.execute_query( + "MATCH (n:Entity) WHERE n.uuid = $uuid " + "RETURN n.uuid AS uuid, n.name AS name, labels(n) AS labels, " + "n.summary AS summary, n.created_at AS created_at", + uuid=node_uuid, + ) + if not records: + return None + r = records[0] + labels = [l for l in (r['labels'] or []) if l not in ('Entity', '__Entity__')] + return { + 'uuid': str(r['uuid'] or ''), + 'name': str(r['name'] or ''), + 'labels': labels, + 'summary': str(r['summary'] or ''), + 'attributes': {}, + 'created_at': str(r['created_at'] or ''), + } + + return _run_async(_fetch()) + + def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: + g = self._get_graphiti() + + async def _fetch(): + driver = g.driver + records, _, _ = await driver.execute_query( + "MATCH (a:Entity)-[r:RELATES_TO]->(b:Entity) " + "WHERE a.uuid = $uuid OR b.uuid = $uuid " + "RETURN r.uuid AS uuid, r.name AS name, r.fact AS fact, " + "a.uuid AS source_uuid, b.uuid AS target_uuid, " + "r.created_at AS created_at, r.valid_at AS valid_at, " + "r.invalid_at AS invalid_at, r.expired_at AS expired_at", + uuid=node_uuid, + ) + edges = [] + for r in records: + edges.append({ + 'uuid': str(r['uuid'] or ''), + 'name': str(r['name'] or ''), + 'fact': str(r['fact'] or ''), + 'source_node_uuid': 
str(r['source_uuid'] or ''), + 'target_node_uuid': str(r['target_uuid'] or ''), + 'attributes': {}, + 'created_at': str(r['created_at'] or ''), + 'valid_at': str(r['valid_at'] or ''), + 'invalid_at': str(r['invalid_at'] or ''), + 'expired_at': str(r['expired_at'] or ''), + 'episodes': [], + }) + return edges + + return _run_async(_fetch()) + + # -------------------------------------------------------------------------- + # Search + # -------------------------------------------------------------------------- + + def search(self, query: str, limit: int = 10, scope: str = "edges") -> Dict[str, Any]: + g = self._get_graphiti() + + async def _do_search(): + edges: List[Dict[str, Any]] = [] + nodes: List[Dict[str, Any]] = [] + + try: + raw = await g.search( + query, + num_results=limit, + group_ids=[self.graph_id], + ) + for item in (raw if isinstance(raw, list) else []): + edges.append(_entity_edge_to_dict(item)) + except Exception: + logger.error("Search failed: query=%s", query, exc_info=True) + + return {"edges": edges, "nodes": nodes} + + return _run_async(_do_search()) + + # -------------------------------------------------------------------------- + # Cleanup + # -------------------------------------------------------------------------- + + def close(self): + if self._graphiti is not None: + try: + _run_async(self._graphiti.close()) + except Exception: + logger.debug("Graphiti close error", exc_info=True) + self._graphiti = None + logger.info("GraphitiAdapter closed: graph_id=%s", self.graph_id) + + +# --------------------------------------------------------------------------- +# Conversion helpers +# --------------------------------------------------------------------------- + +def _entity_edge_to_dict(edge) -> Dict[str, Any]: + if isinstance(edge, dict): + return { + 'uuid': str(edge.get('uuid', '')), + 'name': str(edge.get('name', '')), + 'fact': str(edge.get('fact', '')), + 'source_node_uuid': str(edge.get('source_node_uuid', '')), + 'target_node_uuid': 
str(edge.get('target_node_uuid', '')), + 'attributes': edge.get('attributes', {}), + 'created_at': str(edge.get('created_at', '')), + 'valid_at': str(edge.get('valid_at', '')), + 'invalid_at': str(edge.get('invalid_at', '')), + 'expired_at': str(edge.get('expired_at', '')), + 'episodes': list(edge.get('episodes', [])) if edge.get('episodes') else [], + } + + return { + 'uuid': str(getattr(edge, 'uuid', '')), + 'name': str(getattr(edge, 'name', '')), + 'fact': str(getattr(edge, 'fact', '')), + 'source_node_uuid': str(getattr(edge, 'source_node_uuid', '')), + 'target_node_uuid': str(getattr(edge, 'target_node_uuid', '')), + 'attributes': getattr(edge, 'attributes', {}) or {}, + 'created_at': str(getattr(edge, 'created_at', '')), + 'valid_at': str(getattr(edge, 'valid_at', '')), + 'invalid_at': str(getattr(edge, 'invalid_at', '')), + 'expired_at': str(getattr(edge, 'expired_at', '')), + 'episodes': list(getattr(edge, 'episodes', [])) if getattr(edge, 'episodes', None) else [], + } diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 57836c539..8e512b6ac 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -1,11 +1,11 @@ """ -OASIS Agent Profile生成器 -将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式 +OASIS Agent Profile Generator +Convert entities from the knowledge graph to OASIS simulation platform's required Agent Profile format -优化改进: -1. 调用Zep检索功能二次丰富节点信息 -2. 优化提示词生成非常详细的人设 -3. 区分个人实体和抽象群体实体 +Optimization improvements: +1. Call knowledge graph retrieval function to enrich node information +2. Optimize prompts to generate very detailed personas +3. 
Distinguish between individual entities and abstract group entities """ import json @@ -16,34 +16,34 @@ from datetime import datetime from openai import OpenAI -from zep_cloud.client import Zep from ..config import Config from ..utils.logger import get_logger from .zep_entity_reader import EntityNode, ZepEntityReader +from .graphiti_adapter import GraphitiAdapter logger = get_logger('mirofish.oasis_profile') @dataclass class OasisAgentProfile: - """OASIS Agent Profile数据结构""" - # 通用字段 + """OASIS Agent Profile data structure""" + # Common fields user_id: int user_name: str name: str bio: str persona: str - # 可选字段 - Reddit风格 + # Optional fields - Reddit style karma: int = 1000 - # 可选字段 - Twitter风格 + # Optional fields - Twitter style friend_count: int = 100 follower_count: int = 150 statuses_count: int = 500 - # 额外人设信息 + # Additional persona information age: Optional[int] = None gender: Optional[str] = None mbti: Optional[str] = None @@ -51,17 +51,17 @@ class OasisAgentProfile: profession: Optional[str] = None interested_topics: List[str] = field(default_factory=list) - # 来源实体信息 + # Source entity information source_entity_uuid: Optional[str] = None source_entity_type: Optional[str] = None created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d")) def to_reddit_format(self) -> Dict[str, Any]: - """转换为Reddit平台格式""" + """Convert to Reddit platform format""" profile = { "user_id": self.user_id, - "username": self.user_name, # OASIS 库要求字段名为 username(无下划线) + "username": self.user_name, # OASIS library requires field name as username (no underscore) "name": self.name, "bio": self.bio, "persona": self.persona, @@ -69,7 +69,7 @@ def to_reddit_format(self) -> Dict[str, Any]: "created_at": self.created_at, } - # 添加额外人设信息(如果有) + # Add additional persona information (if available) if self.age: profile["age"] = self.age if self.gender: @@ -86,10 +86,10 @@ def to_reddit_format(self) -> Dict[str, Any]: return profile def to_twitter_format(self) -> Dict[str, 
Any]: - """转换为Twitter平台格式""" + """Convert to Twitter platform format""" profile = { "user_id": self.user_id, - "username": self.user_name, # OASIS 库要求字段名为 username(无下划线) + "username": self.user_name, # OASIS library requires field name as username (no underscore) "name": self.name, "bio": self.bio, "persona": self.persona, @@ -99,7 +99,7 @@ def to_twitter_format(self) -> Dict[str, Any]: "created_at": self.created_at, } - # 添加额外人设信息 + # Add additional persona information if self.age: profile["age"] = self.age if self.gender: @@ -116,7 +116,7 @@ def to_twitter_format(self) -> Dict[str, Any]: return profile def to_dict(self) -> Dict[str, Any]: - """转换为完整字典格式""" + """Convert to complete dictionary format""" return { "user_id": self.user_id, "user_name": self.user_name, @@ -141,17 +141,17 @@ def to_dict(self) -> Dict[str, Any]: class OasisProfileGenerator: """ - OASIS Profile生成器 - - 将Zep图谱中的实体转换为OASIS模拟所需的Agent Profile - - 优化特性: - 1. 调用Zep图谱检索功能获取更丰富的上下文 - 2. 生成非常详细的人设(包括基本信息、职业经历、性格特征、社交媒体行为等) - 3. 区分个人实体和抽象群体实体 + OASIS Profile Generator + + Convert entities from the knowledge graph to Agent Profile required by OASIS simulation + + Optimization features: + 1. Call knowledge graph retrieval function to get richer context + 2. Generate very detailed personas (including basic information, career experience, personality traits, social media behavior, etc.) + 3. 
Distinguish between individual entities and abstract group entities """ - # MBTI类型列表 + # MBTI types list MBTI_TYPES = [ "INTJ", "INTP", "ENTJ", "ENTP", "INFJ", "INFP", "ENFJ", "ENFP", @@ -159,30 +159,29 @@ class OasisProfileGenerator: "ISTP", "ISFP", "ESTP", "ESFP" ] - # 常见国家列表 + # Common countries list COUNTRIES = [ - "China", "US", "UK", "Japan", "Germany", "France", + "US", "UK", "Japan", "Germany", "France", "Canada", "Australia", "Brazil", "India", "South Korea" ] - # 个人类型实体(需要生成具体人设) + # Individual type entities (need to generate specific personas) INDIVIDUAL_ENTITY_TYPES = [ "student", "alumni", "professor", "person", "publicfigure", "expert", "faculty", "official", "journalist", "activist" ] - # 群体/机构类型实体(需要生成群体代表人设) + # Group/institutional type entities (need to generate group representative personas) GROUP_ENTITY_TYPES = [ "university", "governmentagency", "organization", "ngo", "mediaoutlet", "company", "institution", "group", "community" ] def __init__( - self, + self, api_key: Optional[str] = None, base_url: Optional[str] = None, model_name: Optional[str] = None, - zep_api_key: Optional[str] = None, graph_id: Optional[str] = None ): self.api_key = api_key or Config.LLM_API_KEY @@ -190,23 +189,21 @@ def __init__( self.model_name = model_name or Config.LLM_MODEL_NAME if not self.api_key: - raise ValueError("LLM_API_KEY 未配置") + raise ValueError("LLM_API_KEY not configured") self.client = OpenAI( api_key=self.api_key, base_url=self.base_url ) - # Zep客户端用于检索丰富上下文 - self.zep_api_key = zep_api_key or Config.ZEP_API_KEY - self.zep_client = None + # Graphiti adapter for knowledge graph search self.graph_id = graph_id - - if self.zep_api_key: + self._graphiti_adapter = None + if self.graph_id: try: - self.zep_client = Zep(api_key=self.zep_api_key) + self._graphiti_adapter = GraphitiAdapter.get_or_create(self.graph_id) except Exception as e: - logger.warning(f"Zep客户端初始化失败: {e}") + logger.warning(f"GraphitiAdapter initialization failed: {e}") def 
generate_profile_from_entity( self, @@ -215,27 +212,27 @@ def generate_profile_from_entity( use_llm: bool = True ) -> OasisAgentProfile: """ - 从Zep实体生成OASIS Agent Profile - + Generate OASIS Agent Profile from knowledge graph entity + Args: - entity: Zep实体节点 - user_id: 用户ID(用于OASIS) - use_llm: 是否使用LLM生成详细人设 - + entity: Knowledge graph entity node + user_id: User ID (for OASIS) + use_llm: Whether to use LLM to generate detailed persona + Returns: OasisAgentProfile """ entity_type = entity.get_entity_type() or "Entity" - # 基础信息 + # Basic information name = entity.name user_name = self._generate_username(name) - # 构建上下文信息 + # Build context information context = self._build_entity_context(entity) if use_llm: - # 使用LLM生成详细人设 + # Use LLM to generate detailed persona profile_data = self._generate_profile_with_llm( entity_name=name, entity_type=entity_type, @@ -244,7 +241,7 @@ def generate_profile_from_entity( context=context ) else: - # 使用规则生成基础人设 + # Use rules to generate basic persona profile_data = self._generate_profile_rule_based( entity_name=name, entity_type=entity_type, @@ -273,168 +270,116 @@ def generate_profile_from_entity( ) def _generate_username(self, name: str) -> str: - """生成用户名""" - # 移除特殊字符,转换为小写 + """Generate username""" + # Remove special characters, convert to lowercase username = name.lower().replace(" ", "_") username = ''.join(c for c in username if c.isalnum() or c == '_') - # 添加随机后缀避免重复 + # Add random suffix to avoid duplicates suffix = random.randint(100, 999) return f"{username}_{suffix}" def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]: """ - 使用Zep图谱混合搜索功能获取实体相关的丰富信息 - - Zep没有内置混合搜索接口,需要分别搜索edges和nodes然后合并结果。 - 使用并行请求同时搜索,提高效率。 - + Use Graphiti graph search to obtain rich information related to entity + Args: - entity: 实体节点对象 - + entity: Entity node object + Returns: - 包含facts, node_summaries, context的字典 + Dictionary containing facts, node_summaries, context """ - import concurrent.futures - - if not self.zep_client: + if 
not self._graphiti_adapter: return {"facts": [], "node_summaries": [], "context": ""} - + entity_name = entity.name - + results = { "facts": [], "node_summaries": [], "context": "" } - - # 必须有graph_id才能进行搜索 + if not self.graph_id: - logger.debug(f"跳过Zep检索:未设置graph_id") + logger.debug(f"Skip knowledge graph search: graph_id not set") return results - - comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景" - - def search_edges(): - """搜索边(事实/关系)- 带重试机制""" - max_retries = 3 - last_exception = None - delay = 2.0 - - for attempt in range(max_retries): - try: - return self.zep_client.graph.search( - query=comprehensive_query, - graph_id=self.graph_id, - limit=30, - scope="edges", - reranker="rrf" - ) - except Exception as e: - last_exception = e - if attempt < max_retries - 1: - logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") - time.sleep(delay) - delay *= 2 - else: - logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}") - return None - - def search_nodes(): - """搜索节点(实体摘要)- 带重试机制""" - max_retries = 3 - last_exception = None - delay = 2.0 - - for attempt in range(max_retries): - try: - return self.zep_client.graph.search( - query=comprehensive_query, - graph_id=self.graph_id, - limit=20, - scope="nodes", - reranker="rrf" - ) - except Exception as e: - last_exception = e - if attempt < max_retries - 1: - logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") - time.sleep(delay) - delay *= 2 - else: - logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}") - return None - + + comprehensive_query = f"All information, activities, events, relationships and background about {entity_name}" + try: - # 并行执行edges和nodes搜索 - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - edge_future = executor.submit(search_edges) - node_future = executor.submit(search_nodes) - - # 获取结果 - edge_result = edge_future.result(timeout=30) - node_result = node_future.result(timeout=30) - - # 处理边搜索结果 + # Search edges (facts/relations) + edge_results = 
self._graphiti_adapter.search( + query=comprehensive_query, + limit=30, + scope="edges" + ) + + # Search nodes (entity summaries) + node_results = self._graphiti_adapter.search( + query=comprehensive_query, + limit=20, + scope="nodes" + ) + + # Process edge results all_facts = set() - if edge_result and hasattr(edge_result, 'edges') and edge_result.edges: - for edge in edge_result.edges: - if hasattr(edge, 'fact') and edge.fact: - all_facts.add(edge.fact) + for edge in edge_results: + fact = edge.get('fact', '') + if fact: + all_facts.add(fact) results["facts"] = list(all_facts) - - # 处理节点搜索结果 + + # Process node results all_summaries = set() - if node_result and hasattr(node_result, 'nodes') and node_result.nodes: - for node in node_result.nodes: - if hasattr(node, 'summary') and node.summary: - all_summaries.add(node.summary) - if hasattr(node, 'name') and node.name and node.name != entity_name: - all_summaries.add(f"相关实体: {node.name}") + for node in node_results: + summary = node.get('summary', '') + if summary: + all_summaries.add(summary) + name = node.get('name', '') + if name and name != entity_name: + all_summaries.add(f"Related Entity: {name}") results["node_summaries"] = list(all_summaries) - - # 构建综合上下文 + + # Build combined context context_parts = [] if results["facts"]: - context_parts.append("事实信息:\n" + "\n".join(f"- {f}" for f in results["facts"][:20])) + context_parts.append("Fact Information:\n" + "\n".join(f"- {f}" for f in results["facts"][:20])) if results["node_summaries"]: - context_parts.append("相关实体:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10])) + context_parts.append("Related Entities:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10])) results["context"] = "\n\n".join(context_parts) - - logger.info(f"Zep混合检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, {len(results['node_summaries'])} 个相关节点") - - except concurrent.futures.TimeoutError: - logger.warning(f"Zep检索超时 ({entity_name})") + + 
logger.info(f"Knowledge graph hybrid search completed: {entity_name}, retrieved {len(results['facts'])} facts, {len(results['node_summaries'])} related nodes") + except Exception as e: - logger.warning(f"Zep检索失败 ({entity_name}): {e}") - + logger.warning(f"Knowledge graph search failed ({entity_name}): {e}") + return results def _build_entity_context(self, entity: EntityNode) -> str: """ - 构建实体的完整上下文信息 - - 包括: - 1. 实体本身的边信息(事实) - 2. 关联节点的详细信息 - 3. Zep混合检索到的丰富信息 + Build complete context information for entity + + Includes: + 1. Edge information of the entity itself (facts) + 2. Detailed information of associated nodes + 3. Rich information retrieved from knowledge graph hybrid search """ context_parts = [] - # 1. 添加实体属性信息 + # 1. Add entity attribute information if entity.attributes: attrs = [] for key, value in entity.attributes.items(): if value and str(value).strip(): attrs.append(f"- {key}: {value}") if attrs: - context_parts.append("### 实体属性\n" + "\n".join(attrs)) + context_parts.append("### Entity Attributes\n" + "\n".join(attrs)) - # 2. 添加相关边信息(事实/关系) + # 2. 
Add related edge information (facts/relationships) existing_facts = set() if entity.related_edges: relationships = [] - for edge in entity.related_edges: # 不限制数量 + for edge in entity.related_edges: # No limit on quantity fact = edge.get("fact", "") edge_name = edge.get("edge_name", "") direction = edge.get("direction", "") @@ -444,22 +389,22 @@ def _build_entity_context(self, entity: EntityNode) -> str: existing_facts.add(fact) elif edge_name: if direction == "outgoing": - relationships.append(f"- {entity.name} --[{edge_name}]--> (相关实体)") + relationships.append(f"- {entity.name} --[{edge_name}]--> (Related Entity)") else: - relationships.append(f"- (相关实体) --[{edge_name}]--> {entity.name}") + relationships.append(f"- (Related Entity) --[{edge_name}]--> {entity.name}") if relationships: - context_parts.append("### 相关事实和关系\n" + "\n".join(relationships)) + context_parts.append("### Related Facts and Relationships\n" + "\n".join(relationships)) - # 3. 添加关联节点的详细信息 + # 3. Add related node details if entity.related_nodes: related_info = [] - for node in entity.related_nodes: # 不限制数量 + for node in entity.related_nodes: # No limit on quantity node_name = node.get("name", "") node_labels = node.get("labels", []) node_summary = node.get("summary", "") - # 过滤掉默认标签 + # Filter out default labels custom_labels = [l for l in node_labels if l not in ["Entity", "Node"]] label_str = f" ({', '.join(custom_labels)})" if custom_labels else "" @@ -469,28 +414,28 @@ def _build_entity_context(self, entity: EntityNode) -> str: related_info.append(f"- **{node_name}**{label_str}") if related_info: - context_parts.append("### 关联实体信息\n" + "\n".join(related_info)) + context_parts.append("### Related Entity Information\n" + "\n".join(related_info)) - # 4. 使用Zep混合检索获取更丰富的信息 + # 4. 
Use knowledge graph hybrid search to get richer information zep_results = self._search_zep_for_entity(entity) - + if zep_results.get("facts"): - # 去重:排除已存在的事实 + # Deduplication: exclude existing facts new_facts = [f for f in zep_results["facts"] if f not in existing_facts] if new_facts: - context_parts.append("### Zep检索到的事实信息\n" + "\n".join(f"- {f}" for f in new_facts[:15])) - + context_parts.append("### Facts Retrieved from Knowledge Graph\n" + "\n".join(f"- {f}" for f in new_facts[:15])) + if zep_results.get("node_summaries"): - context_parts.append("### Zep检索到的相关节点\n" + "\n".join(f"- {s}" for s in zep_results["node_summaries"][:10])) + context_parts.append("### Related Nodes Retrieved from Knowledge Graph\n" + "\n".join(f"- {s}" for s in zep_results["node_summaries"][:10])) return "\n\n".join(context_parts) def _is_individual_entity(self, entity_type: str) -> bool: - """判断是否是个人类型实体""" + """Determine if entity is an individual type""" return entity_type.lower() in self.INDIVIDUAL_ENTITY_TYPES def _is_group_entity(self, entity_type: str) -> bool: - """判断是否是群体/机构类型实体""" + """Determine if entity is a group/institutional type""" return entity_type.lower() in self.GROUP_ENTITY_TYPES def _generate_profile_with_llm( @@ -502,11 +447,11 @@ def _generate_profile_with_llm( context: str ) -> Dict[str, Any]: """ - 使用LLM生成非常详细的人设 - - 根据实体类型区分: - - 个人实体:生成具体的人物设定 - - 群体/机构实体:生成代表性账号设定 + Use LLM to generate very detailed persona + + Based on entity type: + - Individual entities: generate specific character profiles + - Group/institutional entities: generate representative account profiles """ is_individual = self._is_individual_entity(entity_type) @@ -520,7 +465,7 @@ def _generate_profile_with_llm( entity_name, entity_type, entity_summary, entity_attributes, context ) - # 尝试多次生成,直到成功或达到最大重试次数 + # Try multiple times until successful or max retry attempts reached max_attempts = 3 last_error = None @@ -533,34 +478,34 @@ def _generate_profile_with_llm( {"role": "user", "content": 
prompt} ], response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 + temperature=0.7 - (attempt * 0.1) # Lower temperature with each retry + # Don't set max_tokens, let LLM generate freely ) content = response.choices[0].message.content - # 检查是否被截断(finish_reason不是'stop') + # Check if output was truncated (finish_reason is not 'stop') finish_reason = response.choices[0].finish_reason if finish_reason == 'length': - logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...") + logger.warning(f"LLM output truncated (attempt {attempt+1}), attempting to fix...") content = self._fix_truncated_json(content) - # 尝试解析JSON + # Try to parse JSON try: result = json.loads(content) - # 验证必需字段 + # Validate required fields if "bio" not in result or not result["bio"]: result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}" if "persona" not in result or not result["persona"]: - result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。" + result["persona"] = entity_summary or f"{entity_name} is a {entity_type}." 
return result except json.JSONDecodeError as je: - logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}") + logger.warning(f"JSON parsing failed (attempt {attempt+1}): {str(je)[:80]}") - # 尝试修复JSON + # Try to fix JSON result = self._try_fix_json(content, entity_name, entity_type, entity_summary) if result.get("_fixed"): del result["_fixed"] @@ -569,75 +514,75 @@ def _generate_profile_with_llm( last_error = je except Exception as e: - logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") + logger.warning(f"LLM call failed (attempt {attempt+1}): {str(e)[:80]}") last_error = e import time - time.sleep(1 * (attempt + 1)) # 指数退避 + time.sleep(1 * (attempt + 1)) # Exponential backoff - logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成") + logger.warning(f"LLM persona generation failed ({max_attempts} attempts): {last_error}, using rule-based generation") return self._generate_profile_rule_based( entity_name, entity_type, entity_summary, entity_attributes ) def _fix_truncated_json(self, content: str) -> str: - """修复被截断的JSON(输出被max_tokens限制截断)""" + """Fix truncated JSON (output truncated by max_tokens limit)""" import re - - # 如果JSON被截断,尝试闭合它 + + # If JSON is truncated, try to close it content = content.strip() - - # 计算未闭合的括号 + + # Count unclosed parentheses open_braces = content.count('{') - content.count('}') open_brackets = content.count('[') - content.count(']') - - # 检查是否有未闭合的字符串 - # 简单检查:如果最后一个引号后没有逗号或闭合括号,可能是字符串被截断 + + # Check for unclosed strings + # Simple check: if last character is not comma or closing bracket, string might be truncated if content and content[-1] not in '",}]': - # 尝试闭合字符串 + # Try to close the string content += '"' - - # 闭合括号 + + # Close parentheses content += ']' * open_brackets content += '}' * open_braces return content def _try_fix_json(self, content: str, entity_name: str, entity_type: str, entity_summary: str = "") -> Dict[str, Any]: - """尝试修复损坏的JSON""" + """Try to fix corrupted JSON""" import re - # 
1. 首先尝试修复被截断的情况 + # 1. First try to fix truncated case content = self._fix_truncated_json(content) - - # 2. 尝试提取JSON部分 + + # 2. Try to extract JSON portion json_match = re.search(r'\{[\s\S]*\}', content) if json_match: json_str = json_match.group() - # 3. 处理字符串中的换行符问题 - # 找到所有字符串值并替换其中的换行符 + # 3. Handle newline issues in strings + # Find all string values and replace newlines def fix_string_newlines(match): s = match.group(0) - # 替换字符串内的实际换行符为空格 + # Replace actual newlines in string with spaces s = s.replace('\n', ' ').replace('\r', ' ') - # 替换多余空格 + # Replace excess spaces s = re.sub(r'\s+', ' ', s) return s - # 匹配JSON字符串值 + # Match JSON string values json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str) - # 4. 尝试解析 + # 4. Try to parse try: result = json.loads(json_str) result["_fixed"] = True return result except json.JSONDecodeError as e: - # 5. 如果还是失败,尝试更激进的修复 + # 5. If still failed, try more aggressive fix try: - # 移除所有控制字符 + # Remove all control characters json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) - # 替换所有连续空白 + # Replace all consecutive whitespace json_str = re.sub(r'\s+', ' ', json_str) result = json.loads(json_str) result["_fixed"] = True @@ -645,32 +590,32 @@ def fix_string_newlines(match): except: pass - # 6. 尝试从内容中提取部分信息 + # 6. 
Try to extract partial information from content bio_match = re.search(r'"bio"\s*:\s*"([^"]*)"', content) - persona_match = re.search(r'"persona"\s*:\s*"([^"]*)', content) # 可能被截断 + persona_match = re.search(r'"persona"\s*:\s*"([^"]*)', content) # May be truncated bio = bio_match.group(1) if bio_match else (entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}") - persona = persona_match.group(1) if persona_match else (entity_summary or f"{entity_name}是一个{entity_type}。") + persona = persona_match.group(1) if persona_match else (entity_summary or f"{entity_name} is a {entity_type}.") - # 如果提取到了有意义的内容,标记为已修复 + # If extracted meaningful content, mark as fixed if bio_match or persona_match: - logger.info(f"从损坏的JSON中提取了部分信息") + logger.info(f"Extracted partial information from corrupted JSON") return { "bio": bio, "persona": persona, "_fixed": True } - # 7. 完全失败,返回基础结构 - logger.warning(f"JSON修复失败,返回基础结构") + # 7. Complete failure, return basic structure + logger.warning(f"JSON fix failed, returning basic structure") return { "bio": entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}", - "persona": entity_summary or f"{entity_name}是一个{entity_type}。" + "persona": entity_summary or f"{entity_name} is a {entity_type}." } def _get_system_prompt(self, is_individual: bool) -> str: - """获取系统提示词""" - base_prompt = "你是社交媒体用户画像生成专家。生成详细、真实的人设用于舆论模拟,最大程度还原已有现实情况。必须返回有效的JSON格式,所有字符串值不能包含未转义的换行符。使用中文。" + """Get system prompt""" + base_prompt = "You are an expert in generating social media user profiles. Generate detailed, realistic personas for opinion simulation that maximize restoration of existing reality. Must return valid JSON format with all string values containing no unescaped newlines. Use English." 
return base_prompt def _build_individual_persona_prompt( @@ -681,45 +626,45 @@ def _build_individual_persona_prompt( entity_attributes: Dict[str, Any], context: str ) -> str: - """构建个人实体的详细人设提示词""" + """Build detailed persona prompt for individual entities""" - attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "无" - context_str = context[:3000] if context else "无额外上下文" + attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None" + context_str = context[:3000] if context else "No additional context" - return f"""为实体生成详细的社交媒体用户人设,最大程度还原已有现实情况。 + return f"""Generate a detailed social media user persona for the entity, maximizing restoration of existing reality. -实体名称: {entity_name} -实体类型: {entity_type} -实体摘要: {entity_summary} -实体属性: {attrs_str} +Entity Name: {entity_name} +Entity Type: {entity_type} +Entity Summary: {entity_summary} +Entity Attributes: {attrs_str} -上下文信息: +Context Information: {context_str} -请生成JSON,包含以下字段: +Please generate JSON containing the following fields: -1. bio: 社交媒体简介,200字 -2. persona: 详细人设描述(2000字的纯文本),需包含: - - 基本信息(年龄、职业、教育背景、所在地) - - 人物背景(重要经历、与事件的关联、社会关系) - - 性格特征(MBTI类型、核心性格、情绪表达方式) - - 社交媒体行为(发帖频率、内容偏好、互动风格、语言特点) - - 立场观点(对话题的态度、可能被激怒/感动的内容) - - 独特特征(口头禅、特殊经历、个人爱好) - - 个人记忆(人设的重要部分,要介绍这个个体与事件的关联,以及这个个体在事件中的已有动作与反应) -3. age: 年龄数字(必须是整数) -4. gender: 性别,必须是英文: "male" 或 "female" -5. mbti: MBTI类型(如INTJ、ENFP等) -6. country: 国家(使用中文,如"中国") -7. profession: 职业 -8. interested_topics: 感兴趣话题数组 +1. bio: Social media bio, 200 characters +2. 
persona: Detailed persona description (2000 words of pure text), must include: + - Basic information (age, profession, educational background, location) + - Personal background (important experiences, event associations, social relationships) + - Personality traits (MBTI type, core personality, emotional expression) + - Social media behavior (posting frequency, content preferences, interaction style, language characteristics) + - Positions and views (attitudes toward topics, content that may provoke/touch emotions) + - Unique features (catchphrases, special experiences, personal interests) + - Personal memories (important part of persona, introduce this individual's association with events and their existing actions/reactions in events) +3. age: Age as number (must be integer) +4. gender: Gender, must be in English: "male" or "female" +5. mbti: MBTI type (e.g., INTJ, ENFP) +6. country: Country (use English, e.g., "US") +7. profession: Profession +8. interested_topics: Array of interested topics -重要: -- 所有字段值必须是字符串或数字,不要使用换行符 -- persona必须是一段连贯的文字描述 -- 使用中文(除了gender字段必须用英文male/female) -- 内容要与实体信息保持一致 -- age必须是有效的整数,gender必须是"male"或"female" +Important: +- All field values must be strings or numbers, do not use newlines +- persona must be a coherent text description +- Use English +- Content must be consistent with entity information +- age must be a valid integer, gender must be "male" or "female" """ def _build_group_persona_prompt( @@ -730,45 +675,45 @@ def _build_group_persona_prompt( entity_attributes: Dict[str, Any], context: str ) -> str: - """构建群体/机构实体的详细人设提示词""" + """Build detailed persona prompt for group/institutional entities""" - attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "无" - context_str = context[:3000] if context else "无额外上下文" + attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None" + context_str = context[:3000] if context else "No additional context" - return 
f"""为机构/群体实体生成详细的社交媒体账号设定,最大程度还原已有现实情况。 + return f"""Generate detailed social media account profile for institutional/group entity, maximizing restoration of existing reality. -实体名称: {entity_name} -实体类型: {entity_type} -实体摘要: {entity_summary} -实体属性: {attrs_str} +Entity Name: {entity_name} +Entity Type: {entity_type} +Entity Summary: {entity_summary} +Entity Attributes: {attrs_str} -上下文信息: +Context Information: {context_str} -请生成JSON,包含以下字段: +Please generate JSON containing the following fields: -1. bio: 官方账号简介,200字,专业得体 -2. persona: 详细账号设定描述(2000字的纯文本),需包含: - - 机构基本信息(正式名称、机构性质、成立背景、主要职能) - - 账号定位(账号类型、目标受众、核心功能) - - 发言风格(语言特点、常用表达、禁忌话题) - - 发布内容特点(内容类型、发布频率、活跃时间段) - - 立场态度(对核心话题的官方立场、面对争议的处理方式) - - 特殊说明(代表的群体画像、运营习惯) - - 机构记忆(机构人设的重要部分,要介绍这个机构与事件的关联,以及这个机构在事件中的已有动作与反应) -3. age: 固定填30(机构账号的虚拟年龄) -4. gender: 固定填"other"(机构账号使用other表示非个人) -5. mbti: MBTI类型,用于描述账号风格,如ISTJ代表严谨保守 -6. country: 国家(使用中文,如"中国") -7. profession: 机构职能描述 -8. interested_topics: 关注领域数组 +1. bio: Official account bio, 200 characters, professional and appropriate +2. persona: Detailed account profile description (2000 words of pure text), must include: + - Basic institutional information (official name, organizational nature, founding background, main functions) + - Account positioning (account type, target audience, core functions) + - Speaking style (language characteristics, common expressions, taboo topics) + - Content publishing characteristics (content types, publishing frequency, active time periods) + - Position and attitude (official stance on core topics, handling of controversies) + - Special notes (group profiles represented, operational habits) + - Institutional memories (important part of institutional persona, introduce this institution's association with events and their existing actions/reactions in events) +3. age: Fixed at 30 (virtual age of institutional account) +4. gender: Fixed at "other" (institutional account uses other to denote non-individual) +5. 
mbti: MBTI type used to describe account style, e.g., ISTJ represents rigorous conservative +6. country: Country (use English, e.g., "US") +7. profession: Institutional function description +8. interested_topics: Array of focus areas -重要: -- 所有字段值必须是字符串或数字,不允许null值 -- persona必须是一段连贯的文字描述,不要使用换行符 -- 使用中文(除了gender字段必须用英文"other") -- age必须是整数30,gender必须是字符串"other" -- 机构账号发言要符合其身份定位""" +Important: +- All field values must be strings or numbers, no null values allowed +- persona must be a coherent text description, do not use newlines +- Use English +- age must be integer 30, gender must be string "other" +- Institutional account speech must match its identity positioning""" def _generate_profile_rule_based( self, @@ -777,9 +722,9 @@ def _generate_profile_rule_based( entity_summary: str, entity_attributes: Dict[str, Any] ) -> Dict[str, Any]: - """使用规则生成基础人设""" - - # 根据实体类型生成不同的人设 + """Generate basic persona using rules""" + + # Generate different personas based on entity type entity_type_lower = entity_type.lower() if entity_type_lower in ["student", "alumni"]: @@ -810,10 +755,10 @@ def _generate_profile_rule_based( return { "bio": f"Official account for {entity_name}. News and updates.", "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. 
The account shares timely updates and engages with the audience on current events.", - "age": 30, # 机构虚拟年龄 - "gender": "other", # 机构使用other - "mbti": "ISTJ", # 机构风格:严谨保守 - "country": "中国", + "age": 30, # Institutional virtual age + "gender": "other", # Institutional uses other + "mbti": "ISTJ", # Institutional style: rigorous conservative + "country": "US", "profession": "Media", "interested_topics": ["General News", "Current Events", "Public Affairs"], } @@ -822,16 +767,16 @@ def _generate_profile_rule_based( return { "bio": f"Official account of {entity_name}.", "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.", - "age": 30, # 机构虚拟年龄 - "gender": "other", # 机构使用other - "mbti": "ISTJ", # 机构风格:严谨保守 - "country": "中国", + "age": 30, # Institutional virtual age + "gender": "other", # Institutional uses other + "mbti": "ISTJ", # Institutional style: rigorous conservative + "country": "US", "profession": entity_type, "interested_topics": ["Public Policy", "Community", "Official Announcements"], } else: - # 默认人设 + # Default persona return { "bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}", "persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.", @@ -844,7 +789,7 @@ def _generate_profile_rule_based( } def set_graph_id(self, graph_id: str): - """设置图谱ID用于Zep检索""" + """Set knowledge graph ID for knowledge graph search""" self.graph_id = graph_id def generate_profiles_from_entities( @@ -858,52 +803,52 @@ def generate_profiles_from_entities( output_platform: str = "reddit" ) -> List[OasisAgentProfile]: """ - 批量从实体生成Agent Profile(支持并行生成) - + Generate Agent Profiles in batch from entities (supports parallel generation) + Args: - entities: 实体列表 - use_llm: 是否使用LLM生成详细人设 - progress_callback: 进度回调函数 (current, total, message) - graph_id: 图谱ID,用于Zep检索获取更丰富上下文 - parallel_count: 并行生成数量,默认5 - 
realtime_output_path: 实时写入的文件路径(如果提供,每生成一个就写入一次) - output_platform: 输出平台格式 ("reddit" 或 "twitter") - + entities: Entity list + use_llm: Whether to use LLM to generate detailed personas + progress_callback: Progress callback function (current, total, message) + graph_id: Knowledge graph ID for knowledge graph search to get richer context + parallel_count: Number of parallel generations, default 5 + realtime_output_path: Real-time output file path (if provided, write after each generation) + output_platform: Output platform format ("reddit" or "twitter") + Returns: - Agent Profile列表 + List of Agent Profiles """ import concurrent.futures from threading import Lock - # 设置graph_id用于Zep检索 + # Set graph_id for knowledge graph search if graph_id: self.graph_id = graph_id total = len(entities) - profiles = [None] * total # 预分配列表保持顺序 - completed_count = [0] # 使用列表以便在闭包中修改 + profiles = [None] * total # Pre-allocate list to maintain order + completed_count = [0] # Use list for modification in closure lock = Lock() - - # 实时写入文件的辅助函数 + + # Helper function for real-time file writing def save_profiles_realtime(): - """实时保存已生成的 profiles 到文件""" + """Real-time save generated profiles to file""" if not realtime_output_path: return with lock: - # 过滤出已生成的 profiles + # Filter generated profiles existing_profiles = [p for p in profiles if p is not None] if not existing_profiles: return try: if output_platform == "reddit": - # Reddit JSON 格式 + # Reddit JSON format profiles_data = [p.to_reddit_format() for p in existing_profiles] with open(realtime_output_path, 'w', encoding='utf-8') as f: json.dump(profiles_data, f, ensure_ascii=False, indent=2) else: - # Twitter CSV 格式 + # Twitter CSV format import csv profiles_data = [p.to_twitter_format() for p in existing_profiles] if profiles_data: @@ -913,10 +858,10 @@ def save_profiles_realtime(): writer.writeheader() writer.writerows(profiles_data) except Exception as e: - logger.warning(f"实时保存 profiles 失败: {e}") + logger.warning(f"Real-time profile 
save failed: {e}") def generate_single_profile(idx: int, entity: EntityNode) -> tuple: - """生成单个profile的工作函数""" + """Worker function to generate single profile""" entity_type = entity.get_entity_type() or "Entity" try: @@ -926,14 +871,14 @@ def generate_single_profile(idx: int, entity: EntityNode) -> tuple: use_llm=use_llm ) - # 实时输出生成的人设到控制台和日志 + # Real-time output generated persona to console and log self._print_generated_profile(entity.name, entity_type, profile) return idx, profile, None except Exception as e: - logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}") - # 创建一个基础profile + logger.error(f"Failed to generate persona for entity {entity.name}: {str(e)}") + # Create a fallback profile fallback_profile = OasisAgentProfile( user_id=idx, user_name=self._generate_username(entity.name), @@ -945,20 +890,20 @@ def generate_single_profile(idx: int, entity: EntityNode) -> tuple: ) return idx, fallback_profile, str(e) - logger.info(f"开始并行生成 {total} 个Agent人设(并行数: {parallel_count})...") + logger.info(f"Starting parallel generation of {total} agent personas (parallel count: {parallel_count})...") print(f"\n{'='*60}") - print(f"开始生成Agent人设 - 共 {total} 个实体,并行数: {parallel_count}") + print(f"Starting agent persona generation - {total} entities total, parallel count: {parallel_count}") print(f"{'='*60}\n") - # 使用线程池并行执行 + # Use thread pool for parallel execution with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_count) as executor: - # 提交所有任务 + # Submit all tasks future_to_entity = { executor.submit(generate_single_profile, idx, entity): (idx, entity) for idx, entity in enumerate(entities) } - # 收集结果 + # Collect results for future in concurrent.futures.as_completed(future_to_entity): idx, entity = future_to_entity[future] entity_type = entity.get_entity_type() or "Entity" @@ -971,23 +916,23 @@ def generate_single_profile(idx: int, entity: EntityNode) -> tuple: completed_count[0] += 1 current = completed_count[0] - # 实时写入文件 + # Real-time file writing 
save_profiles_realtime() - + if progress_callback: progress_callback( current, total, - f"已完成 {current}/{total}: {entity.name}({entity_type})" + f"Completed {current}/{total}: {entity.name} ({entity_type})" ) if error: - logger.warning(f"[{current}/{total}] {entity.name} 使用备用人设: {error}") + logger.warning(f"[{current}/{total}] {entity.name} using fallback persona: {error}") else: - logger.info(f"[{current}/{total}] 成功生成人设: {entity.name} ({entity_type})") + logger.info(f"[{current}/{total}] Successfully generated persona: {entity.name} ({entity_type})") except Exception as e: - logger.error(f"处理实体 {entity.name} 时发生异常: {str(e)}") + logger.error(f"Exception occurred while processing entity {entity.name}: {str(e)}") with lock: completed_count[0] += 1 profiles[idx] = OasisAgentProfile( @@ -999,44 +944,44 @@ def generate_single_profile(idx: int, entity: EntityNode) -> tuple: source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) - # 实时写入文件(即使是备用人设) + # Real-time file writing (even for fallback personas) save_profiles_realtime() print(f"\n{'='*60}") - print(f"人设生成完成!共生成 {len([p for p in profiles if p])} 个Agent") + print(f"Persona generation complete! 
Generated {len([p for p in profiles if p])} agents") print(f"{'='*60}\n") return profiles def _print_generated_profile(self, entity_name: str, entity_type: str, profile: OasisAgentProfile): - """实时输出生成的人设到控制台(完整内容,不截断)""" + """Real-time output generated persona to console (complete content, not truncated)""" separator = "-" * 70 - # 构建完整输出内容(不截断) - topics_str = ', '.join(profile.interested_topics) if profile.interested_topics else '无' + # Build complete output content (not truncated) + topics_str = ', '.join(profile.interested_topics) if profile.interested_topics else 'None' output_lines = [ f"\n{separator}", - f"[已生成] {entity_name} ({entity_type})", + f"[Generated] {entity_name} ({entity_type})", f"{separator}", - f"用户名: {profile.user_name}", + f"Username: {profile.user_name}", f"", - f"【简介】", + f"[Bio]", f"{profile.bio}", f"", - f"【详细人设】", + f"[Detailed Persona]", f"{profile.persona}", f"", - f"【基本属性】", - f"年龄: {profile.age} | 性别: {profile.gender} | MBTI: {profile.mbti}", - f"职业: {profile.profession} | 国家: {profile.country}", - f"兴趣话题: {topics_str}", + f"[Basic Attributes]", + f"Age: {profile.age} | Gender: {profile.gender} | MBTI: {profile.mbti}", + f"Profession: {profile.profession} | Country: {profile.country}", + f"Interested Topics: {topics_str}", separator ] output = "\n".join(output_lines) - # 只输出到控制台(避免重复,logger不再输出完整内容) + # Only output to console (avoid duplication, logger no longer outputs complete content) print(output) def save_profiles( @@ -1046,16 +991,16 @@ def save_profiles( platform: str = "reddit" ): """ - 保存Profile到文件(根据平台选择正确格式) - - OASIS平台格式要求: - - Twitter: CSV格式 - - Reddit: JSON格式 - + Save profiles to file (choose correct format based on platform) + + OASIS platform format requirements: + - Twitter: CSV format + - Reddit: JSON format + Args: - profiles: Profile列表 - file_path: 文件路径 - platform: 平台类型 ("reddit" 或 "twitter") + profiles: Profile list + file_path: File path + platform: Platform type ("reddit" or "twitter") """ if platform == 
"twitter": self._save_twitter_csv(profiles, file_path) @@ -1064,73 +1009,68 @@ def save_profiles( def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str): """ - 保存Twitter Profile为CSV格式(符合OASIS官方要求) - - OASIS Twitter要求的CSV字段: - - user_id: 用户ID(根据CSV顺序从0开始) - - name: 用户真实姓名 - - username: 系统中的用户名 - - user_char: 详细人设描述(注入到LLM系统提示中,指导Agent行为) - - description: 简短的公开简介(显示在用户资料页面) - - user_char vs description 区别: - - user_char: 内部使用,LLM系统提示,决定Agent如何思考和行动 - - description: 外部显示,其他用户可见的简介 + Save Twitter Profile as CSV format (compliant with OASIS official requirements) + + OASIS Twitter required CSV fields: + - user_id: User ID (starting from 0 based on CSV order) + - name: User real name + - username: Username in the system + - user_char: Detailed persona description (injected into LLM system prompt, guides agent behavior) + - description: Short public bio (displayed on user profile page) + + user_char vs description difference: + - user_char: Internal use, LLM system prompt, determines how agent thinks and acts + - description: External display, visible to other users """ import csv - # 确保文件扩展名是.csv + # Ensure file extension is .csv if not file_path.endswith('.csv'): file_path = file_path.replace('.json', '.csv') with open(file_path, 'w', newline='', encoding='utf-8') as f: writer = csv.writer(f) - # 写入OASIS要求的表头 + # Write OASIS required header headers = ['user_id', 'name', 'username', 'user_char', 'description'] writer.writerow(headers) - # 写入数据行 + # Write data rows for idx, profile in enumerate(profiles): - # user_char: 完整人设(bio + persona),用于LLM系统提示 + # user_char: Complete persona (bio + persona) for LLM system prompt user_char = profile.bio if profile.persona and profile.persona != profile.bio: user_char = f"{profile.bio} {profile.persona}" - # 处理换行符(CSV中用空格替代) + # Handle newlines (replace with space in CSV) user_char = user_char.replace('\n', ' ').replace('\r', ' ') - # description: 简短简介,用于外部显示 + # description: Short bio for external display 
description = profile.bio.replace('\n', ' ').replace('\r', ' ') row = [ - idx, # user_id: 从0开始的顺序ID - profile.name, # name: 真实姓名 - profile.user_name, # username: 用户名 - user_char, # user_char: 完整人设(内部LLM使用) - description # description: 简短简介(外部显示) + idx, # user_id: Sequential ID starting from 0 + profile.name, # name: Real name + profile.user_name, # username: Username + user_char, # user_char: Complete persona (internal LLM use) + description # description: Short bio (external display) ] writer.writerow(row) - logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)") + logger.info(f"Saved {len(profiles)} Twitter profiles to {file_path} (OASIS CSV format)") def _normalize_gender(self, gender: Optional[str]) -> str: """ - 标准化gender字段为OASIS要求的英文格式 + Normalize gender field to OASIS required English format - OASIS要求: male, female, other + OASIS requires: male, female, other """ if not gender: return "other" gender_lower = gender.lower().strip() - # 中文映射 + # Gender mapping gender_map = { - "男": "male", - "女": "female", - "机构": "other", - "其他": "other", - # 英文已有 "male": "male", "female": "female", "other": "other", @@ -1140,41 +1080,41 @@ def _normalize_gender(self, gender: Optional[str]) -> str: def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str): """ - 保存Reddit Profile为JSON格式 - - 使用与 to_reddit_format() 一致的格式,确保 OASIS 能正确读取。 - 必须包含 user_id 字段,这是 OASIS agent_graph.get_agent() 匹配的关键! - - 必需字段: - - user_id: 用户ID(整数,用于匹配 initial_posts 中的 poster_agent_id) - - username: 用户名 - - name: 显示名称 - - bio: 简介 - - persona: 详细人设 - - age: 年龄(整数) - - gender: "male", "female", 或 "other" - - mbti: MBTI类型 - - country: 国家 + Save Reddit Profile as JSON format + + Use format consistent with to_reddit_format() to ensure OASIS can read correctly. + Must include user_id field, which is the key for OASIS agent_graph.get_agent() matching! 
+ + Required fields: + - user_id: User ID (integer, for matching poster_agent_id in initial_posts) + - username: Username + - name: Display name + - bio: Bio + - persona: Detailed persona + - age: Age (integer) + - gender: "male", "female", or "other" + - mbti: MBTI type + - country: Country """ data = [] for idx, profile in enumerate(profiles): - # 使用与 to_reddit_format() 一致的格式 + # Use format consistent with to_reddit_format() item = { - "user_id": profile.user_id if profile.user_id is not None else idx, # 关键:必须包含 user_id + "user_id": profile.user_id if profile.user_id is not None else idx, # Key: must include user_id "username": profile.user_name, "name": profile.name, "bio": profile.bio[:150] if profile.bio else f"{profile.name}", "persona": profile.persona or f"{profile.name} is a participant in social discussions.", "karma": profile.karma if profile.karma else 1000, "created_at": profile.created_at, - # OASIS必需字段 - 确保都有默认值 + # OASIS required fields - ensure all have defaults "age": profile.age if profile.age else 30, "gender": self._normalize_gender(profile.gender), "mbti": profile.mbti if profile.mbti else "ISTJ", - "country": profile.country if profile.country else "中国", + "country": profile.country if profile.country else "US", } - # 可选字段 + # Optional fields if profile.profession: item["profession"] = profile.profession if profile.interested_topics: @@ -1185,16 +1125,16 @@ def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str): with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) - logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON格式,包含user_id字段)") + logger.info(f"Saved {len(profiles)} Reddit profiles to {file_path} (JSON format, includes user_id field)") - # 保留旧方法名作为别名,保持向后兼容 + # Keep old method name as alias for backward compatibility def save_profiles_to_json( self, profiles: List[OasisAgentProfile], file_path: str, platform: str = "reddit" ): - """[已废弃] 请使用 
save_profiles() 方法""" - logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法") + """[Deprecated] Please use save_profiles() method""" + logger.warning("save_profiles_to_json is deprecated, please use save_profiles method") self.save_profiles(profiles, file_path, platform) diff --git a/backend/app/services/ontology_generator.py b/backend/app/services/ontology_generator.py index 2d3e39bd8..4908539df 100644 --- a/backend/app/services/ontology_generator.py +++ b/backend/app/services/ontology_generator.py @@ -197,7 +197,7 @@ def generate( result = self.llm_client.chat_json( messages=messages, temperature=0.3, - max_tokens=4096 + max_tokens=16384 ) # 验证和后处理 @@ -284,7 +284,7 @@ def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: if len(edge.get("description", "")) > 100: edge["description"] = edge["description"][:97] + "..." - # Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型 + # 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型 MAX_ENTITY_TYPES = 10 MAX_EDGE_TYPES = 10 @@ -360,8 +360,12 @@ def generate_python_code(self, ontology: Dict[str, Any]) -> str: '由MiroFish自动生成,用于社会舆论模拟', '"""', '', - 'from pydantic import Field', - 'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel', + 'from pydantic import BaseModel, Field', + '', + '# Base classes for ontology models (Graphiti-compatible)', + 'EntityModel = BaseModel', + 'EntityText = str', + 'EdgeModel = BaseModel', '', '', '# ============== 实体类型定义 ==============', diff --git a/backend/app/services/report_agent.py b/backend/app/services/report_agent.py index 02ca5bdc2..9f9acb649 100644 --- a/backend/app/services/report_agent.py +++ b/backend/app/services/report_agent.py @@ -1,12 +1,12 @@ """ -Report Agent服务 -使用LangChain + Zep实现ReACT模式的模拟报告生成 - -功能: -1. 根据模拟需求和Zep图谱信息生成报告 -2. 先规划目录结构,然后分段生成 -3. 每段采用ReACT多轮思考与反思模式 -4. 支持与用户对话,在对话中自主调用检索工具 +Report Agent Service +Generate simulated reports using ReACT pattern (via GraphStorage / Neo4j) + +Features: +1. 
Generate reports based on simulation requirements and graph information +2. First plan the outline structure, then generate section by section +3. Each section uses ReACT multi-round thinking and reflection pattern +4. Support conversations with users, autonomously call retrieval tools during conversations """ import os @@ -34,18 +34,18 @@ class ReportLogger: """ - Report Agent 详细日志记录器 - - 在报告文件夹中生成 agent_log.jsonl 文件,记录每一步详细动作。 - 每行是一个完整的 JSON 对象,包含时间戳、动作类型、详细内容等。 + Report Agent Detailed Logger + + Generates agent_log.jsonl file in the report folder, recording detailed actions at each step. + Each line is a complete JSON object containing timestamp, action type, details, etc. """ def __init__(self, report_id: str): """ - 初始化日志记录器 - + Initialize the logger + Args: - report_id: 报告ID,用于确定日志文件路径 + report_id: Report ID, used to determine the log file path """ self.report_id = report_id self.log_file_path = os.path.join( @@ -55,12 +55,12 @@ def __init__(self, report_id: str): self._ensure_log_file() def _ensure_log_file(self): - """确保日志文件所在目录存在""" + """Ensure the log file directory exists""" log_dir = os.path.dirname(self.log_file_path) os.makedirs(log_dir, exist_ok=True) def _get_elapsed_time(self) -> float: - """获取从开始到现在的耗时(秒)""" + """Get elapsed time from start to now (in seconds)""" return (datetime.now() - self.start_time).total_seconds() def log( @@ -72,14 +72,14 @@ def log( section_index: int = None ): """ - 记录一条日志 + Record a log entry Args: - action: 动作类型,如 'start', 'tool_call', 'llm_response', 'section_complete' 等 - stage: 当前阶段,如 'planning', 'generating', 'completed' - details: 详细内容字典,不截断 - section_title: 当前章节标题(可选) - section_index: 当前章节索引(可选) + action: Action type, e.g. 'start', 'tool_call', 'llm_response', 'section_complete' + stage: Current stage, e.g. 
'planning', 'generating', 'completed' + details: Details dictionary, not truncated + section_title: Current section title (optional) + section_index: Current section index (optional) """ log_entry = { "timestamp": datetime.now().isoformat(), @@ -92,12 +92,12 @@ def log( "details": details } - # 追加写入 JSONL 文件 + # Append to JSONL file with open(self.log_file_path, 'a', encoding='utf-8') as f: f.write(json.dumps(log_entry, ensure_ascii=False) + '\n') def log_start(self, simulation_id: str, graph_id: str, simulation_requirement: str): - """记录报告生成开始""" + """Log report generation start""" self.log( action="report_start", stage="pending", @@ -105,52 +105,52 @@ def log_start(self, simulation_id: str, graph_id: str, simulation_requirement: s "simulation_id": simulation_id, "graph_id": graph_id, "simulation_requirement": simulation_requirement, - "message": "报告生成任务开始" + "message": "Report generation task started" } ) def log_planning_start(self): - """记录大纲规划开始""" + """Log outline planning start""" self.log( action="planning_start", stage="planning", - details={"message": "开始规划报告大纲"} + details={"message": "Start planning report outline"} ) def log_planning_context(self, context: Dict[str, Any]): - """记录规划时获取的上下文信息""" + """Log context information obtained during planning""" self.log( action="planning_context", stage="planning", details={ - "message": "获取模拟上下文信息", + "message": "Getting simulation context information", "context": context } ) def log_planning_complete(self, outline_dict: Dict[str, Any]): - """记录大纲规划完成""" + """Log outline planning complete""" self.log( action="planning_complete", stage="planning", details={ - "message": "大纲规划完成", + "message": "Outline planning complete", "outline": outline_dict } ) def log_section_start(self, section_title: str, section_index: int): - """记录章节生成开始""" + """Log section generation start""" self.log( action="section_start", stage="generating", section_title=section_title, section_index=section_index, - details={"message": f"开始生成章节: 
{section_title}"} + details={"message": f"Start generating section: {section_title}"} ) def log_react_thought(self, section_title: str, section_index: int, iteration: int, thought: str): - """记录 ReACT 思考过程""" + """Log ReACT thinking process""" self.log( action="react_thought", stage="generating", @@ -159,7 +159,7 @@ def log_react_thought(self, section_title: str, section_index: int, iteration: i details={ "iteration": iteration, "thought": thought, - "message": f"ReACT 第{iteration}轮思考" + "message": f"ReACT round {iteration} thinking" } ) @@ -171,7 +171,7 @@ def log_tool_call( parameters: Dict[str, Any], iteration: int ): - """记录工具调用""" + """Log tool call""" self.log( action="tool_call", stage="generating", @@ -181,7 +181,7 @@ def log_tool_call( "iteration": iteration, "tool_name": tool_name, "parameters": parameters, - "message": f"调用工具: {tool_name}" + "message": f"Tool call: {tool_name}" } ) @@ -193,7 +193,7 @@ def log_tool_result( result: str, iteration: int ): - """记录工具调用结果(完整内容,不截断)""" + """Log tool call result (full content, not truncated)""" self.log( action="tool_result", stage="generating", @@ -202,9 +202,9 @@ def log_tool_result( details={ "iteration": iteration, "tool_name": tool_name, - "result": result, # 完整结果,不截断 + "result": result, # Full result, not truncated "result_length": len(result), - "message": f"工具 {tool_name} 返回结果" + "message": f"Tool {tool_name} returned result" } ) @@ -217,7 +217,7 @@ def log_llm_response( has_tool_calls: bool, has_final_answer: bool ): - """记录 LLM 响应(完整内容,不截断)""" + """Log LLM response (full content, not truncated)""" self.log( action="llm_response", stage="generating", @@ -225,11 +225,11 @@ def log_llm_response( section_index=section_index, details={ "iteration": iteration, - "response": response, # 完整响应,不截断 + "response": response, # Full response, not truncated "response_length": len(response), "has_tool_calls": has_tool_calls, "has_final_answer": has_final_answer, - "message": f"LLM 响应 (工具调用: {has_tool_calls}, 最终答案: 
{has_final_answer})" + "message": f"LLM response (tool calls: {has_tool_calls}, final answer: {has_final_answer})" } ) @@ -240,17 +240,17 @@ def log_section_content( content: str, tool_calls_count: int ): - """记录章节内容生成完成(仅记录内容,不代表整个章节完成)""" + """Log section content generation complete (only records content, does not mean entire section is complete)""" self.log( action="section_content", stage="generating", section_title=section_title, section_index=section_index, details={ - "content": content, # 完整内容,不截断 + "content": content, # Full content, not truncated "content_length": len(content), "tool_calls_count": tool_calls_count, - "message": f"章节 {section_title} 内容生成完成" + "message": f"Section {section_title} content generation complete" } ) @@ -261,9 +261,9 @@ def log_section_full_complete( full_content: str ): """ - 记录章节生成完成 + Log section generation complete - 前端应监听此日志来判断一个章节是否真正完成,并获取完整内容 + Frontend should monitor this log to determine if a section is truly complete and get full content """ self.log( action="section_complete", @@ -273,24 +273,24 @@ def log_section_full_complete( details={ "content": full_content, "content_length": len(full_content), - "message": f"章节 {section_title} 生成完成" + "message": f"Section {section_title} generation complete" } ) def log_report_complete(self, total_sections: int, total_time_seconds: float): - """记录报告生成完成""" + """Log report generation complete""" self.log( action="report_complete", stage="completed", details={ "total_sections": total_sections, "total_time_seconds": round(total_time_seconds, 2), - "message": "报告生成完成" + "message": "Report generation complete" } ) def log_error(self, error_message: str, stage: str, section_title: str = None): - """记录错误""" + """Log error""" self.log( action="error", stage=stage, @@ -298,25 +298,25 @@ def log_error(self, error_message: str, stage: str, section_title: str = None): section_index=None, details={ "error": error_message, - "message": f"发生错误: {error_message}" + "message": f"Error occurred: 
{error_message}" } ) class ReportConsoleLogger: """ - Report Agent 控制台日志记录器 + Report Agent Console Logger - 将控制台风格的日志(INFO、WARNING等)写入报告文件夹中的 console_log.txt 文件。 - 这些日志与 agent_log.jsonl 不同,是纯文本格式的控制台输出。 + Writes console-style logs (INFO, WARNING, etc.) to console_log.txt in the report folder. + These logs differ from agent_log.jsonl as they are plain text console output. """ def __init__(self, report_id: str): """ - 初始化控制台日志记录器 + Initialize console logger Args: - report_id: 报告ID,用于确定日志文件路径 + report_id: Report ID, used to determine log file path """ self.report_id = report_id self.log_file_path = os.path.join( @@ -327,15 +327,15 @@ def __init__(self, report_id: str): self._setup_file_handler() def _ensure_log_file(self): - """确保日志文件所在目录存在""" + """Ensure the log file directory exists""" log_dir = os.path.dirname(self.log_file_path) os.makedirs(log_dir, exist_ok=True) def _setup_file_handler(self): - """设置文件处理器,将日志同时写入文件""" + """Set up file handler to write logs to file""" import logging - # 创建文件处理器 + # Create file handler self._file_handler = logging.FileHandler( self.log_file_path, mode='a', @@ -343,14 +343,14 @@ def _setup_file_handler(self): ) self._file_handler.setLevel(logging.INFO) - # 使用与控制台相同的简洁格式 + # Use same concise format as console formatter = logging.Formatter( '[%(asctime)s] %(levelname)s: %(message)s', datefmt='%H:%M:%S' ) self._file_handler.setFormatter(formatter) - # 添加到 report_agent 相关的 logger + # Add to report_agent related loggers loggers_to_attach = [ 'mirofish.report_agent', 'mirofish.zep_tools', @@ -358,12 +358,12 @@ def _setup_file_handler(self): for logger_name in loggers_to_attach: target_logger = logging.getLogger(logger_name) - # 避免重复添加 + # Avoid duplicate additions if self._file_handler not in target_logger.handlers: target_logger.addHandler(self._file_handler) def close(self): - """关闭文件处理器并从 logger 中移除""" + """Close file handler and remove from loggers""" import logging if self._file_handler: @@ -381,12 +381,12 @@ def close(self): 
self._file_handler = None def __del__(self): - """析构时确保关闭文件处理器""" + """Ensure file handler is closed on destruction""" self.close() class ReportStatus(str, Enum): - """报告状态""" + """Report status""" PENDING = "pending" PLANNING = "planning" GENERATING = "generating" @@ -396,7 +396,7 @@ class ReportStatus(str, Enum): @dataclass class ReportSection: - """报告章节""" + """Report section""" title: str content: str = "" @@ -407,7 +407,7 @@ def to_dict(self) -> Dict[str, Any]: } def to_markdown(self, level: int = 2) -> str: - """转换为Markdown格式""" + """Convert to Markdown format""" md = f"{'#' * level} {self.title}\n\n" if self.content: md += f"{self.content}\n\n" @@ -416,7 +416,7 @@ def to_markdown(self, level: int = 2) -> str: @dataclass class ReportOutline: - """报告大纲""" + """Report outline""" title: str summary: str sections: List[ReportSection] @@ -429,7 +429,7 @@ def to_dict(self) -> Dict[str, Any]: } def to_markdown(self) -> str: - """转换为Markdown格式""" + """Convert to Markdown format""" md = f"# {self.title}\n\n" md += f"> {self.summary}\n\n" for section in self.sections: @@ -439,7 +439,7 @@ def to_markdown(self) -> str: @dataclass class Report: - """完整报告""" + """Complete report""" report_id: str simulation_id: str graph_id: str @@ -467,417 +467,417 @@ def to_dict(self) -> Dict[str, Any]: # ═══════════════════════════════════════════════════════════════ -# Prompt 模板常量 +# Prompt Template Constants # ═══════════════════════════════════════════════════════════════ -# ── 工具描述 ── +# ── Tool Descriptions ── TOOL_DESC_INSIGHT_FORGE = """\ -【深度洞察检索 - 强大的检索工具】 -这是我们强大的检索函数,专为深度分析设计。它会: -1. 自动将你的问题分解为多个子问题 -2. 从多个维度检索模拟图谱中的信息 -3. 整合语义搜索、实体分析、关系链追踪的结果 -4. 返回最全面、最深度的检索内容 - -【使用场景】 -- 需要深入分析某个话题 -- 需要了解事件的多个方面 -- 需要获取支撑报告章节的丰富素材 - -【返回内容】 -- 相关事实原文(可直接引用) -- 核心实体洞察 -- 关系链分析""" +[Deep Insight Retrieval - Powerful retrieval tool] +This is our powerful retrieval function, designed for deep analysis. It will: +1. Automatically decompose your question into multiple sub-questions +2. 
Retrieve information from the simulation graph across multiple dimensions +3. Integrate results from semantic search, entity analysis, and relationship chain tracking +4. Return the most comprehensive, deepest retrieval content + +[Use cases] +- Need to deeply analyze a topic +- Need to understand multiple aspects of an event +- Need rich materials to support report sections + +[Returns] +- Related factual text (can be directly quoted) +- Core entity insights +- Relationship chain analysis""" TOOL_DESC_PANORAMA_SEARCH = """\ -【广度搜索 - 获取全貌视图】 -这个工具用于获取模拟结果的完整全貌,特别适合了解事件演变过程。它会: -1. 获取所有相关节点和关系 -2. 区分当前有效的事实和历史/过期的事实 -3. 帮助你了解舆情是如何演变的 - -【使用场景】 -- 需要了解事件的完整发展脉络 -- 需要对比不同阶段的舆情变化 -- 需要获取全面的实体和关系信息 - -【返回内容】 -- 当前有效事实(模拟最新结果) -- 历史/过期事实(演变记录) -- 所有涉及的实体""" +[Breadth Search - Get panoramic view] +This tool is used to get the complete picture of simulation results, especially suitable for understanding event evolution. It will: +1. Get all related nodes and relationships +2. Distinguish between currently valid facts and historical/expired facts +3. Help you understand how public opinion evolved + +[Use cases] +- Need to understand the complete development trajectory of an event +- Need to compare opinion changes across different stages +- Need comprehensive entity and relationship information + +[Returns] +- Currently valid facts (latest simulation results) +- Historical/expired facts (evolution records) +- All involved entities""" TOOL_DESC_QUICK_SEARCH = """\ -【简单搜索 - 快速检索】 -轻量级的快速检索工具,适合简单、直接的信息查询。 +[Simple Search - Quick retrieval] +Lightweight quick retrieval tool, suitable for simple, direct information queries. -【使用场景】 -- 需要快速查找某个具体信息 -- 需要验证某个事实 -- 简单的信息检索 +[Use cases] +- Need to quickly find specific information +- Need to verify a fact +- Simple information retrieval -【返回内容】 -- 与查询最相关的事实列表""" +[Returns] +- List of facts most relevant to the query""" TOOL_DESC_INTERVIEW_AGENTS = """\ -【深度采访 - 真实Agent采访(双平台)】 -调用OASIS模拟环境的采访API,对正在运行的模拟Agent进行真实采访! 
-这不是LLM模拟,而是调用真实的采访接口获取模拟Agent的原始回答。 -默认在Twitter和Reddit两个平台同时采访,获取更全面的观点。 - -功能流程: -1. 自动读取人设文件,了解所有模拟Agent -2. 智能选择与采访主题最相关的Agent(如学生、媒体、官方等) -3. 自动生成采访问题 -4. 调用 /api/simulation/interview/batch 接口在双平台进行真实采访 -5. 整合所有采访结果,提供多视角分析 - -【使用场景】 -- 需要从不同角色视角了解事件看法(学生怎么看?媒体怎么看?官方怎么说?) -- 需要收集多方意见和立场 -- 需要获取模拟Agent的真实回答(来自OASIS模拟环境) -- 想让报告更生动,包含"采访实录" - -【返回内容】 -- 被采访Agent的身份信息 -- 各Agent在Twitter和Reddit两个平台的采访回答 -- 关键引言(可直接引用) -- 采访摘要和观点对比 - -【重要】需要OASIS模拟环境正在运行才能使用此功能!""" - -# ── 大纲规划 prompt ── +[In-depth Interview - Real Agent Interview (dual platform)] +Calls the OASIS simulation environment interview API to conduct real interviews with running simulation Agents! +This is not LLM simulation, but calling real interview interfaces to get simulation Agents' original responses. +By default interviews on both Twitter and Reddit platforms simultaneously for more comprehensive viewpoints. + +Workflow: +1. Automatically read persona files to understand all simulation Agents +2. Intelligently select Agents most relevant to the interview topic (e.g., students, media, officials) +3. Automatically generate interview questions +4. Call /api/simulation/interview/batch interface for real interviews on both platforms +5. Integrate all interview results, providing multi-perspective analysis + +[Use cases] +- Need to understand event views from different role perspectives (How do students see it? Media? Officials?) 
+- Need to collect opinions and stances from multiple parties +- Need to get real responses from simulation Agents (from OASIS simulation environment) +- Want to make report more vivid, including "interview transcripts" + +[Returns] +- Identity information of interviewed Agents +- Each Agent's interview responses on both Twitter and Reddit platforms +- Key quotes (can be directly cited) +- Interview summary and viewpoint comparison + +[Important] OASIS simulation environment must be running to use this feature!""" + +# ── Outline Planning Prompt ── PLAN_SYSTEM_PROMPT = """\ -你是一个「未来预测报告」的撰写专家,拥有对模拟世界的「上帝视角」——你可以洞察模拟中每一位Agent的行为、言论和互动。 - -【核心理念】 -我们构建了一个模拟世界,并向其中注入了特定的「模拟需求」作为变量。模拟世界的演化结果,就是对未来可能发生情况的预测。你正在观察的不是"实验数据",而是"未来的预演"。 - -【你的任务】 -撰写一份「未来预测报告」,回答: -1. 在我们设定的条件下,未来发生了什么? -2. 各类Agent(人群)是如何反应和行动? -3. 这个模拟揭示了哪些值得关注的未来趋势和风险? - -【报告定位】 -- ✅ 这是一份基于模拟的未来预测报告,揭示"如果这样,未来会怎样" -- ✅ 聚焦于预测结果:事件走向、群体反应、涌现现象、潜在风险 -- ✅ 模拟世界中的Agent言行就是对未来人群行为的预测 -- ❌ 不是对现实世界现状的分析 -- ❌ 不是泛泛而谈的舆情综述 - -【章节数量限制】 -- 最少2个章节,最多5个章节 -- 不需要子章节,每个章节直接撰写完整内容 -- 内容要精炼,聚焦于核心预测发现 -- 章节结构由你根据预测结果自主设计 - -请输出JSON格式的报告大纲,格式如下: +You are an expert in writing "future prediction reports" with a "god's eye view" of the simulated world - you can gain insights into the behavior, statements, and interactions of every agent in the simulation. + +[Core Concept] +We built a simulated world and injected specific "simulation requirements" as variables into it. The evolution result of the simulated world is a prediction of what might happen in the future. What you're observing is not "experimental data" but a "rehearsal of the future". + +[Your Task] +Write a "future prediction report" that answers: +1. What happened in the future under the conditions we set? +2. How do various agents (groups) react and act? +3. What future trends and risks does this simulation reveal that deserve attention? 
+ +[Report Positioning] +- ✅ This is a future prediction report based on simulation, revealing "if this happens, how will the future unfold" +- ✅ Focus on prediction results: event trajectories, group reactions, emergent phenomena, potential risks +- ✅ Agent statements and behaviors in the simulated world are predictions of future human behavior +- ❌ Not an analysis of the current state of the real world +- ❌ Not a general overview of public sentiment + +[Section Number Limit] +- Minimum 2 sections, maximum 5 sections +- No subsections needed, each section directly writes complete content +- Content should be concise, focused on core prediction findings +- Section structure is designed independently based on prediction results + +Please output the report outline in JSON format as follows: { - "title": "报告标题", - "summary": "报告摘要(一句话概括核心预测发现)", + "title": "Report Title", + "summary": "Report Summary (one sentence summarizing core prediction findings)", "sections": [ { - "title": "章节标题", - "description": "章节内容描述" + "title": "Section Title", + "description": "Section Content Description" } ] } -注意:sections数组最少2个,最多5个元素!""" +Note: sections array must have at least 2 and at most 5 elements!""" PLAN_USER_PROMPT_TEMPLATE = """\ -【预测场景设定】 -我们向模拟世界注入的变量(模拟需求):{simulation_requirement} +[Prediction Scenario Settings] +Variable (simulation requirement) injected into the simulated world: {simulation_requirement} -【模拟世界规模】 -- 参与模拟的实体数量: {total_nodes} -- 实体间产生的关系数量: {total_edges} -- 实体类型分布: {entity_types} -- 活跃Agent数量: {total_entities} +[Simulated World Scale] +- Number of entities participating in simulation: {total_nodes} +- Number of relationships generated between entities: {total_edges} +- Entity type distribution: {entity_types} +- Number of active agents: {total_entities} -【模拟预测到的部分未来事实样本】 +[Sample of Some Future Facts Predicted by Simulation] {related_facts_json} -请以「上帝视角」审视这个未来预演: -1. 在我们设定的条件下,未来呈现出了什么样的状态? -2. 各类人群(Agent)是如何反应和行动的? -3. 这个模拟揭示了哪些值得关注的未来趋势? 
+Please examine this future rehearsal from a "god's eye view": +1. What state does the future present under the conditions we set? +2. How do various groups (agents) react and act? +3. What future trends does this simulation reveal that deserve attention? -根据预测结果,设计最合适的报告章节结构。 +Based on the prediction results, design the most appropriate report section structure. -【再次提醒】报告章节数量:最少2个,最多5个,内容要精炼聚焦于核心预测发现。""" +[Reminder] Report section count: minimum 2, maximum 5, content should be concise and focused on core prediction findings.""" -# ── 章节生成 prompt ── +# ── Section Generation Prompt ── SECTION_SYSTEM_PROMPT_TEMPLATE = """\ -你是一个「未来预测报告」的撰写专家,正在撰写报告的一个章节。 +You are an expert in writing "future prediction reports" and are writing a section of the report. -报告标题: {report_title} -报告摘要: {report_summary} -预测场景(模拟需求): {simulation_requirement} +Report Title: {report_title} +Report Summary: {report_summary} +Prediction Scenario (Simulation Requirement): {simulation_requirement} -当前要撰写的章节: {section_title} +Current Section to Write: {section_title} ═══════════════════════════════════════════════════════════════ -【核心理念】 +[Core Concept] ═══════════════════════════════════════════════════════════════ -模拟世界是对未来的预演。我们向模拟世界注入了特定条件(模拟需求), -模拟中Agent的行为和互动,就是对未来人群行为的预测。 +The simulated world is a rehearsal of the future. We injected specific conditions (simulation requirements) into the simulated world. +The behavior and interactions of agents in the simulation are predictions of future human behavior. 
-你的任务是: -- 揭示在设定条件下,未来发生了什么 -- 预测各类人群(Agent)是如何反应和行动的 -- 发现值得关注的未来趋势、风险和机会 +Your task is to: +- Reveal what happens in the future under the set conditions +- Predict how various groups (agents) react and act +- Discover future trends, risks, and opportunities worth paying attention to -❌ 不要写成对现实世界现状的分析 -✅ 要聚焦于"未来会怎样"——模拟结果就是预测的未来 +❌ Don't write it as an analysis of the current state of the real world +✅ Focus on "how the future will unfold" - simulation results are the predicted future ═══════════════════════════════════════════════════════════════ -【最重要的规则 - 必须遵守】 +[Most Important Rules - Must Follow] ═══════════════════════════════════════════════════════════════ -1. 【必须调用工具观察模拟世界】 - - 你正在以「上帝视角」观察未来的预演 - - 所有内容必须来自模拟世界中发生的事件和Agent言行 - - 禁止使用你自己的知识来编写报告内容 - - 每个章节至少调用3次工具(最多5次)来观察模拟的世界,它代表了未来 - -2. 【必须引用Agent的原始言行】 - - Agent的发言和行为是对未来人群行为的预测 - - 在报告中使用引用格式展示这些预测,例如: - > "某类人群会表示:原文内容..." - - 这些引用是模拟预测的核心证据 - -3. 【语言一致性 - 引用内容必须翻译为报告语言】 - - 工具返回的内容可能包含英文或中英文混杂的表述 - - 如果模拟需求和材料原文是中文的,报告必须全部使用中文撰写 - - 当你引用工具返回的英文或中英混杂内容时,必须将其翻译为流畅的中文后再写入报告 - - 翻译时保持原意不变,确保表述自然通顺 - - 这一规则同时适用于正文和引用块(> 格式)中的内容 - -4. 【忠实呈现预测结果】 - - 报告内容必须反映模拟世界中的代表未来的模拟结果 - - 不要添加模拟中不存在的信息 - - 如果某方面信息不足,如实说明 +1. [Must Call Tools to Observe the Simulated World] + - You are observing a rehearsal of the future from a "god's eye view" + - All content must come from events and agent statements/behaviors in the simulated world + - Forbidden to use your own knowledge to write report content + - Each section must call tools at least 3 times (maximum 5 times) to observe the simulated world, which represents the future + +2. [Must Quote Original Agent Statements and Behaviors] + - Agent statements and behaviors are predictions of future human behavior + - Use quote format in the report to display these predictions, for example: + > "Certain groups will state: original content..." + - These quotes are core evidence of simulation predictions + +3. 
[Language Consistency - Quoted Content Must Be Translated to Report Language] + - Tool returned content may contain English or mixed Chinese-English expressions + - If the simulation requirement and source material are in Chinese, the report must be entirely in Chinese + - When you quote English or mixed Chinese-English content from tools, you must translate it to fluent Chinese before including it in the report + - When translating, preserve the original meaning and ensure natural expression + - This rule applies to both regular text and quoted blocks (> format) + +4. [Faithfully Present Prediction Results] + - Report content must reflect simulation results that represent the future in the simulated world + - Don't add information that doesn't exist in the simulation + - If information is insufficient in some aspects, state it truthfully ═══════════════════════════════════════════════════════════════ -【⚠️ 格式规范 - 极其重要!】 +[⚠️ Format Specification - Extremely Important!] ═══════════════════════════════════════════════════════════════ -【一个章节 = 最小内容单位】 -- 每个章节是报告的最小分块单位 -- ❌ 禁止在章节内使用任何 Markdown 标题(#、##、###、#### 等) -- ❌ 禁止在内容开头添加章节主标题 -- ✅ 章节标题由系统自动添加,你只需撰写纯正文内容 -- ✅ 使用**粗体**、段落分隔、引用、列表来组织内容,但不要用标题 +[One Section = Minimum Content Unit] +- Each section is the minimum content unit of the report +- ❌ Forbidden to use any Markdown titles (#, ##, ###, ####, etc.) within the section +- ❌ Forbidden to add section titles at the beginning of content +- ✅ Section titles are added automatically by the system, just write pure body text +- ✅ Use **bold**, paragraph separation, quotes, and lists to organize content, but don't use titles -【正确示例】 +[Correct Example] ``` -本章节分析了事件的舆论传播态势。通过对模拟数据的深入分析,我们发现... +This section analyzes the public sentiment propagation of the event. Through in-depth analysis of simulation data, we found... 
-**首发引爆阶段** +**Initial Explosion Phase** -微博作为舆情的第一现场,承担了信息首发的核心功能: +Weibo, as the first scene of public sentiment, undertook the core function of initial information dissemination: -> "微博贡献了68%的首发声量..." +> "Weibo contributed 68% of initial voice..." -**情绪放大阶段** +**Emotion Amplification Phase** -抖音平台进一步放大了事件影响力: +The TikTok platform further amplified the impact of the event: -- 视觉冲击力强 -- 情绪共鸣度高 +- Strong visual impact +- High emotional resonance ``` -【错误示例】 +[Incorrect Example] ``` -## 执行摘要 ← 错误!不要添加任何标题 -### 一、首发阶段 ← 错误!不要用###分小节 -#### 1.1 详细分析 ← 错误!不要用####细分 +## Executive Summary ← Wrong! Don't add any titles +### 1. Initial Phase ← Wrong! Don't use ### for subsections +#### 1.1 Detailed Analysis ← Wrong! Don't use #### for subdivisions -本章节分析了... +This section analyzes... ``` ═══════════════════════════════════════════════════════════════ -【可用检索工具】(每章节调用3-5次) +[Available Retrieval Tools] (call 3-5 times per section) ═══════════════════════════════════════════════════════════════ {tools_description} -【工具使用建议 - 请混合使用不同工具,不要只用一种】 -- insight_forge: 深度洞察分析,自动分解问题并多维度检索事实和关系 -- panorama_search: 广角全景搜索,了解事件全貌、时间线和演变过程 -- quick_search: 快速验证某个具体信息点 -- interview_agents: 采访模拟Agent,获取不同角色的第一人称观点和真实反应 +[Tool Usage Suggestions - Please Mix Different Tools, Don't Use Only One] +- insight_forge: Deep insight analysis, automatically decompose problems and retrieve facts and relationships from multiple dimensions +- panorama_search: Wide-angle panoramic search, understand complete event view, timeline, and evolution process +- quick_search: Quick verification of specific information points +- interview_agents: Interview simulated agents, get first-person perspectives and real reactions from different roles ═══════════════════════════════════════════════════════════════ -【工作流程】 +[Workflow] ═══════════════════════════════════════════════════════════════ -每次回复你只能做以下两件事之一(不可同时做): +Each reply you can only do one of two things (cannot do both): -选项A - 调用工具: -输出你的思考,然后用以下格式调用一个工具: 
+Option A - Call Tool: +Output your thinking, then call a tool using the following format: -{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} +{{"name": "Tool Name", "parameters": {{"parameter_name": "parameter_value"}}}} -系统会执行工具并把结果返回给你。你不需要也不能自己编写工具返回结果。 +The system will execute the tool and return the result to you. You don't need to and cannot write tool return results yourself. -选项B - 输出最终内容: -当你已通过工具获取了足够信息,以 "Final Answer:" 开头输出章节内容。 +Option B - Output Final Content: +When you have gathered enough information through tools, start with "Final Answer:" and output section content. -⚠️ 严格禁止: -- 禁止在一次回复中同时包含工具调用和 Final Answer -- 禁止自己编造工具返回结果(Observation),所有工具结果由系统注入 -- 每次回复最多调用一个工具 +⚠️ Strictly Forbidden: +- Forbidden to include both tool calls and Final Answer in one reply +- Forbidden to fabricate tool return results (Observation), all tool results are injected by the system +- At most one tool call per reply ═══════════════════════════════════════════════════════════════ -【章节内容要求】 +[Section Content Requirements] ═══════════════════════════════════════════════════════════════ -1. 内容必须基于工具检索到的模拟数据 -2. 大量引用原文来展示模拟效果 -3. 使用Markdown格式(但禁止使用标题): - - 使用 **粗体文字** 标记重点(代替子标题) - - 使用列表(-或1.2.3.)组织要点 - - 使用空行分隔不同段落 - - ❌ 禁止使用 #、##、###、#### 等任何标题语法 -4. 【引用格式规范 - 必须单独成段】 - 引用必须独立成段,前后各有一个空行,不能混在段落中: - - ✅ 正确格式: +1. Content must be based on simulation data retrieved by tools +2. Heavily quote original text to demonstrate simulation effects +3. Use Markdown format (but forbidden to use titles): + - Use **bold text** to mark key points (replacing sub-titles) + - Use lists (- or 1.2.3.) to organize points + - Use blank lines to separate paragraphs + - ❌ Forbidden to use any title syntax like #, ##, ###, #### +4. 
[Quote Format Specification - Must Be Separate Paragraph] + Quotes must be standalone paragraphs with blank lines before and after, cannot be mixed in paragraphs: + + ✅ Correct Format: ``` - 校方的回应被认为缺乏实质内容。 + School officials' response was considered lacking substantive content. - > "校方的应对模式在瞬息万变的社交媒体环境中显得僵化和迟缓。" + > "School's response pattern appears rigid and slow in the rapidly changing social media environment." - 这一评价反映了公众的普遍不满。 + This assessment reflects widespread public dissatisfaction. ``` - ❌ 错误格式: + ❌ Incorrect Format: ``` - 校方的回应被认为缺乏实质内容。> "校方的应对模式..." 这一评价反映了... + School officials' response was considered lacking substantive content.> "School's response pattern..." This assessment reflects... ``` -5. 保持与其他章节的逻辑连贯性 -6. 【避免重复】仔细阅读下方已完成的章节内容,不要重复描述相同的信息 -7. 【再次强调】不要添加任何标题!用**粗体**代替小节标题""" +5. Maintain logical coherence with other sections +6. [Avoid Duplication] Carefully read the completed section content below, don't repeat describing the same information +7. [Emphasis Again] Don't add any titles! Use **bold** instead of section sub-titles""" SECTION_USER_PROMPT_TEMPLATE = """\ -已完成的章节内容(请仔细阅读,避免重复): +Completed Section Content (Please Read Carefully to Avoid Duplication): {previous_content} ═══════════════════════════════════════════════════════════════ -【当前任务】撰写章节: {section_title} +[Current Task] Write Section: {section_title} ═══════════════════════════════════════════════════════════════ -【重要提醒】 -1. 仔细阅读上方已完成的章节,避免重复相同的内容! -2. 开始前必须先调用工具获取模拟数据 -3. 请混合使用不同工具,不要只用一种 -4. 报告内容必须来自检索结果,不要使用自己的知识 +[Important Reminders] +1. Carefully read the completed sections above to avoid repeating the same content! +2. You must call tools to get simulation data before starting +3. Please mix different tools, don't use only one +4. 
Report content must come from retrieval results, don't use your own knowledge -【⚠️ 格式警告 - 必须遵守】 -- ❌ 不要写任何标题(#、##、###、####都不行) -- ❌ 不要写"{section_title}"作为开头 -- ✅ 章节标题由系统自动添加 -- ✅ 直接写正文,用**粗体**代替小节标题 +[⚠️ Format Warning - Must Follow] +- ❌ Don't write any titles (#, ##, ###, #### none allowed) +- ❌ Don't write "{section_title}" as the opening +- ✅ Section titles are added automatically by the system +- ✅ Write the body directly, use **bold** instead of sub-section titles -请开始: -1. 首先思考(Thought)这个章节需要什么信息 -2. 然后调用工具(Action)获取模拟数据 -3. 收集足够信息后输出 Final Answer(纯正文,无任何标题)""" +Please start: +1. First think (Thought) what information this section needs +2. Then call tools (Action) to get simulation data +3. After collecting enough information, output Final Answer (pure body text, no titles)""" -# ── ReACT 循环内消息模板 ── +# ── ReACT Loop Message Templates ── REACT_OBSERVATION_TEMPLATE = """\ -Observation(检索结果): +Observation (Retrieval Result): -═══ 工具 {tool_name} 返回 ═══ +═══ Tool {tool_name} Returned ═══ {result} ═══════════════════════════════════════════════════════════════ -已调用工具 {tool_calls_count}/{max_tool_calls} 次(已用: {used_tools_str}){unused_hint} -- 如果信息充分:以 "Final Answer:" 开头输出章节内容(必须引用上述原文) -- 如果需要更多信息:调用一个工具继续检索 +Called tools {tool_calls_count}/{max_tool_calls} times (Used: {used_tools_str}){unused_hint} +- If information is sufficient: Start with "Final Answer:" and output section content (must quote the above original text) +- If more information is needed: Call a tool to continue retrieving ═══════════════════════════════════════════════════════════════""" REACT_INSUFFICIENT_TOOLS_MSG = ( - "【注意】你只调用了{tool_calls_count}次工具,至少需要{min_tool_calls}次。" - "请再调用工具获取更多模拟数据,然后再输出 Final Answer。{unused_hint}" + "[Notice] You have only called {tool_calls_count} tools, need at least {min_tool_calls}. " + "Please call tools again to get more simulation data, then output Final Answer. 
{unused_hint}" ) REACT_INSUFFICIENT_TOOLS_MSG_ALT = ( - "当前只调用了 {tool_calls_count} 次工具,至少需要 {min_tool_calls} 次。" - "请调用工具获取模拟数据。{unused_hint}" + "Currently called {tool_calls_count} tools, need at least {min_tool_calls}. " + "Please call tools to get simulation data. {unused_hint}" ) REACT_TOOL_LIMIT_MSG = ( - "工具调用次数已达上限({tool_calls_count}/{max_tool_calls}),不能再调用工具。" - '请立即基于已获取的信息,以 "Final Answer:" 开头输出章节内容。' + "Tool call count has reached the limit ({tool_calls_count}/{max_tool_calls}), cannot call tools anymore. " + 'Please immediately start with "Final Answer:" and output section content based on acquired information.' ) -REACT_UNUSED_TOOLS_HINT = "\n💡 你还没有使用过: {unused_list},建议尝试不同工具获取多角度信息" +REACT_UNUSED_TOOLS_HINT = "\n💡 You haven't used yet: {unused_list}, suggest trying different tools to get multi-perspective information" -REACT_FORCE_FINAL_MSG = "已达到工具调用限制,请直接输出 Final Answer: 并生成章节内容。" +REACT_FORCE_FINAL_MSG = "Tool call limit reached, please directly output Final Answer: and generate section content." -# ── Chat prompt ── +# ── Chat Prompt ── CHAT_SYSTEM_PROMPT_TEMPLATE = """\ -你是一个简洁高效的模拟预测助手。 +You are a concise and efficient simulation prediction assistant. -【背景】 -预测条件: {simulation_requirement} +[Background] +Prediction Condition: {simulation_requirement} -【已生成的分析报告】 +[Generated Analysis Report] {report_content} -【规则】 -1. 优先基于上述报告内容回答问题 -2. 直接回答问题,避免冗长的思考论述 -3. 仅在报告内容不足以回答时,才调用工具检索更多数据 -4. 回答要简洁、清晰、有条理 +[Rules] +1. Prioritize answering questions based on the above report content +2. Answer questions directly, avoid lengthy deliberation +3. Only call tools to retrieve more data if the report content is insufficient to answer +4. 
Answers should be concise, clear, and well-organized -【可用工具】(仅在需要时使用,最多调用1-2次) +[Available Tools] (use only when needed, call at most 1-2 times) {tools_description} -【工具调用格式】 +[Tool Call Format] -{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} +{{"name": "Tool Name", "parameters": {{"parameter_name": "parameter_value"}}}} -【回答风格】 -- 简洁直接,不要长篇大论 -- 使用 > 格式引用关键内容 -- 优先给出结论,再解释原因""" +[Answer Style] +- Concise and direct, don't write lengthy passages +- Use > format to quote key content +- Give conclusions first, then explain reasons""" -CHAT_OBSERVATION_SUFFIX = "\n\n请简洁回答问题。" +CHAT_OBSERVATION_SUFFIX = "\n\nPlease answer the question concisely." # ═══════════════════════════════════════════════════════════════ -# ReportAgent 主类 +# ReportAgent Main Class # ═══════════════════════════════════════════════════════════════ class ReportAgent: """ - Report Agent - 模拟报告生成Agent + Report Agent - Simulation report generation Agent - 采用ReACT(Reasoning + Acting)模式: - 1. 规划阶段:分析模拟需求,规划报告目录结构 - 2. 生成阶段:逐章节生成内容,每章节可多次调用工具获取信息 - 3. 反思阶段:检查内容完整性和准确性 + Uses ReACT (Reasoning + Acting) pattern: + 1. Planning phase: Analyze simulation requirements, plan report outline structure + 2. Generation phase: Generate content section by section, each section can call tools multiple times + 3. 
Reflection phase: Check content completeness and accuracy """ - # 最大工具调用次数(每个章节) + # Maximum tool calls per section MAX_TOOL_CALLS_PER_SECTION = 5 - # 最大反思轮数 + # Maximum reflection rounds MAX_REFLECTION_ROUNDS = 3 - # 对话中的最大工具调用次数 + # Maximum tool calls per chat MAX_TOOL_CALLS_PER_CHAT = 2 def __init__( @@ -889,14 +889,14 @@ def __init__( zep_tools: Optional[ZepToolsService] = None ): """ - 初始化Report Agent + Initialize Report Agent Args: - graph_id: 图谱ID - simulation_id: 模拟ID - simulation_requirement: 模拟需求描述 - llm_client: LLM客户端(可选) - zep_tools: Zep工具服务(可选) + graph_id: Graph ID + simulation_id: Simulation ID + simulation_requirement: Simulation requirement description + llm_client: LLM client (optional) + zep_tools: Zep tools service (optional) """ self.graph_id = graph_id self.simulation_id = simulation_id @@ -905,66 +905,66 @@ def __init__( self.llm = llm_client or LLMClient() self.zep_tools = zep_tools or ZepToolsService() - # 工具定义 + # Tool definitions self.tools = self._define_tools() - # 日志记录器(在 generate_report 中初始化) + # Logger (initialized in generate_report) self.report_logger: Optional[ReportLogger] = None - # 控制台日志记录器(在 generate_report 中初始化) + # Console logger (initialized in generate_report) self.console_logger: Optional[ReportConsoleLogger] = None - logger.info(f"ReportAgent 初始化完成: graph_id={graph_id}, simulation_id={simulation_id}") + logger.info(f"ReportAgent initialized: graph_id={graph_id}, simulation_id={simulation_id}") def _define_tools(self) -> Dict[str, Dict[str, Any]]: - """定义可用工具""" + """Define available tools""" return { "insight_forge": { "name": "insight_forge", "description": TOOL_DESC_INSIGHT_FORGE, "parameters": { - "query": "你想深入分析的问题或话题", - "report_context": "当前报告章节的上下文(可选,有助于生成更精准的子问题)" + "query": "The question or topic you want to deeply analyze", + "report_context": "Current report section context (optional, helps generate more precise sub-questions)" } }, "panorama_search": { "name": "panorama_search", "description": 
TOOL_DESC_PANORAMA_SEARCH, "parameters": { - "query": "搜索查询,用于相关性排序", - "include_expired": "是否包含过期/历史内容(默认True)" + "query": "Search query for relevance ranking", + "include_expired": "Whether to include expired/historical content (default True)" } }, "quick_search": { "name": "quick_search", "description": TOOL_DESC_QUICK_SEARCH, "parameters": { - "query": "搜索查询字符串", - "limit": "返回结果数量(可选,默认10)" + "query": "Search query string", + "limit": "Number of results to return (optional, default 10)" } }, "interview_agents": { "name": "interview_agents", "description": TOOL_DESC_INTERVIEW_AGENTS, "parameters": { - "interview_topic": "采访主题或需求描述(如:'了解学生对宿舍甲醛事件的看法')", - "max_agents": "最多采访的Agent数量(可选,默认5,最大10)" + "interview_topic": "Interview topic or requirement description", + "max_agents": "Maximum number of Agents to interview (optional, default 5, max 10)" } } } def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_context: str = "") -> str: """ - 执行工具调用 + Execute tool call Args: - tool_name: 工具名称 - parameters: 工具参数 - report_context: 报告上下文(用于InsightForge) + tool_name: Tool name + parameters: Tool parameters + report_context: Report context (for InsightForge) Returns: - 工具执行结果(文本格式) + Tool execution result (text format) """ - logger.info(f"执行工具: {tool_name}, 参数: {parameters}") + logger.info(f"Execute tool: {tool_name}, parameters: {parameters}") try: if tool_name == "insight_forge": @@ -979,7 +979,7 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return result.to_text() elif tool_name == "panorama_search": - # 广度搜索 - 获取全貌 + # Breadth search - get panoramic view query = parameters.get("query", "") include_expired = parameters.get("include_expired", True) if isinstance(include_expired, str): @@ -992,7 +992,7 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return result.to_text() elif tool_name == "quick_search": - # 简单搜索 - 快速检索 + # Simple search - quick retrieval query = 
parameters.get("query", "") limit = parameters.get("limit", 10) if isinstance(limit, str): @@ -1005,7 +1005,7 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return result.to_text() elif tool_name == "interview_agents": - # 深度采访 - 调用真实的OASIS采访API获取模拟Agent的回答(双平台) + # In-depth interview - call real OASIS interview API for Agent responses (dual platform) interview_topic = parameters.get("interview_topic", parameters.get("query", "")) max_agents = parameters.get("max_agents", 5) if isinstance(max_agents, str): @@ -1019,11 +1019,11 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte ) return result.to_text() - # ========== 向后兼容的旧工具(内部重定向到新工具) ========== + # ========== Backward compatible old tools (internally redirect to new tools) ========== elif tool_name == "search_graph": - # 重定向到 quick_search - logger.info("search_graph 已重定向到 quick_search") + # Redirect to quick_search + logger.info("search_graph redirected to quick_search") return self._execute_tool("quick_search", parameters, report_context) elif tool_name == "get_graph_statistics": @@ -1039,8 +1039,8 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return json.dumps(result, ensure_ascii=False, indent=2) elif tool_name == "get_simulation_context": - # 重定向到 insight_forge,因为它更强大 - logger.info("get_simulation_context 已重定向到 insight_forge") + # Redirect to insight_forge, as it is more powerful + logger.info("get_simulation_context redirected to insight_forge") query = parameters.get("query", self.simulation_requirement) return self._execute_tool("insight_forge", {"query": query}, report_context) @@ -1054,26 +1054,26 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return json.dumps(result, ensure_ascii=False, indent=2) else: - return f"未知工具: {tool_name}。请使用以下工具之一: insight_forge, panorama_search, quick_search" + return f"Unknown tool: {tool_name}. 
Please use one of: insight_forge, panorama_search, quick_search" except Exception as e: - logger.error(f"工具执行失败: {tool_name}, 错误: {str(e)}") - return f"工具执行失败: {str(e)}" + logger.error(f"Tool execution failed: {tool_name}, error: {str(e)}") + return f"Tool execution failed: {str(e)}" - # 合法的工具名称集合,用于裸 JSON 兜底解析时校验 + # Valid tool names set, used for validation when parsing raw JSON fallback VALID_TOOL_NAMES = {"insight_forge", "panorama_search", "quick_search", "interview_agents"} def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: """ - 从LLM响应中解析工具调用 + Parse tool calls from LLM response - 支持的格式(按优先级): + Supported formats (by priority): 1. {"name": "tool_name", "parameters": {...}} - 2. 裸 JSON(响应整体或单行就是一个工具调用 JSON) + 2. Bare JSON (response body or single line is a tool call JSON) """ tool_calls = [] - # 格式1: XML风格(标准格式) + # Format 1: XML style (standard format) xml_pattern = r'\s*(\{.*?\})\s*' for match in re.finditer(xml_pattern, response, re.DOTALL): try: @@ -1085,8 +1085,8 @@ def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: if tool_calls: return tool_calls - # 格式2: 兜底 - LLM 直接输出裸 JSON(没包 标签) - # 只在格式1未匹配时尝试,避免误匹配正文中的 JSON + # Format 2: Fallback - LLM outputs bare JSON (no tags) + # Only try when format 1 did not match, to avoid false matches in body text stripped = response.strip() if stripped.startswith('{') and stripped.endswith('}'): try: @@ -1097,7 +1097,7 @@ def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: except json.JSONDecodeError: pass - # 响应可能包含思考文字 + 裸 JSON,尝试提取最后一个 JSON 对象 + # Response may contain thinking text + bare JSON, try to extract last JSON object json_pattern = r'(\{"(?:name|tool)"\s*:.*?\})\s*$' match = re.search(json_pattern, stripped, re.DOTALL) if match: @@ -1111,11 +1111,11 @@ def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: return tool_calls def _is_valid_tool_call(self, data: dict) -> bool: - """校验解析出的 JSON 是否是合法的工具调用""" - # 支持 {"name": ..., "parameters": ...} 和 
{"tool": ..., "params": ...} 两种键名 + """Validate whether parsed JSON is a valid tool call""" + # Support both {"name": ..., "parameters": ...} and {"tool": ..., "params": ...} key names tool_name = data.get("name") or data.get("tool") if tool_name and tool_name in self.VALID_TOOL_NAMES: - # 统一键名为 name / parameters + # Normalize key names to name / parameters if "tool" in data: data["name"] = data.pop("tool") if "params" in data and "parameters" not in data: @@ -1124,13 +1124,13 @@ def _is_valid_tool_call(self, data: dict) -> bool: return False def _get_tools_description(self) -> str: - """生成工具描述文本""" - desc_parts = ["可用工具:"] + """Generate tool description text""" + desc_parts = ["Available tools:"] for name, tool in self.tools.items(): params_desc = ", ".join([f"{k}: {v}" for k, v in tool["parameters"].items()]) desc_parts.append(f"- {name}: {tool['description']}") if params_desc: - desc_parts.append(f" 参数: {params_desc}") + desc_parts.append(f" Parameters: {params_desc}") return "\n".join(desc_parts) def plan_outline( @@ -1138,29 +1138,29 @@ def plan_outline( progress_callback: Optional[Callable] = None ) -> ReportOutline: """ - 规划报告大纲 - - 使用LLM分析模拟需求,规划报告的目录结构 - + Plan report outline + + Use LLM to analyze simulation requirements and plan the report structure + Args: - progress_callback: 进度回调函数 - + progress_callback: Progress callback function + Returns: - ReportOutline: 报告大纲 + ReportOutline: Report outline """ - logger.info("开始规划报告大纲...") - + logger.info("Starting to plan report outline...") + if progress_callback: - progress_callback("planning", 0, "正在分析模拟需求...") - - # 首先获取模拟上下文 + progress_callback("planning", 0, "Analyzing simulation requirements...") + + # First get simulation context context = self.zep_tools.get_simulation_context( graph_id=self.graph_id, simulation_requirement=self.simulation_requirement ) if progress_callback: - progress_callback("planning", 30, "正在生成报告大纲...") + progress_callback("planning", 30, "Generating report outline...") system_prompt 
= PLAN_SYSTEM_PROMPT user_prompt = PLAN_USER_PROMPT_TEMPLATE.format( @@ -1182,9 +1182,9 @@ def plan_outline( ) if progress_callback: - progress_callback("planning", 80, "正在解析大纲结构...") - - # 解析大纲 + progress_callback("planning", 80, "Parsing outline structure...") + + # Parse outline sections = [] for section_data in response.get("sections", []): sections.append(ReportSection( @@ -1193,27 +1193,27 @@ def plan_outline( )) outline = ReportOutline( - title=response.get("title", "模拟分析报告"), + title=response.get("title", "Simulation Analysis Report"), summary=response.get("summary", ""), sections=sections ) if progress_callback: - progress_callback("planning", 100, "大纲规划完成") - - logger.info(f"大纲规划完成: {len(sections)} 个章节") + progress_callback("planning", 100, "Outline planning completed") + + logger.info(f"Outline planning completed: {len(sections)} sections") return outline except Exception as e: - logger.error(f"大纲规划失败: {str(e)}") - # 返回默认大纲(3个章节,作为fallback) + logger.error(f"Outline planning failed: {str(e)}") + # Return default outline (3 sections as fallback) return ReportOutline( - title="未来预测报告", - summary="基于模拟预测的未来趋势与风险分析", + title="Future Prediction Report", + summary="Future trends and risk analysis based on simulation predictions", sections=[ - ReportSection(title="预测场景与核心发现"), - ReportSection(title="人群行为预测分析"), - ReportSection(title="趋势展望与风险提示") + ReportSection(title="Prediction Scenario and Core Findings"), + ReportSection(title="Crowd Behavior Prediction Analysis"), + ReportSection(title="Trend Outlook and Risk Warning") ] ) @@ -1226,28 +1226,28 @@ def _generate_section_react( section_index: int = 0 ) -> str: """ - 使用ReACT模式生成单个章节内容 - - ReACT循环: - 1. Thought(思考)- 分析需要什么信息 - 2. Action(行动)- 调用工具获取信息 - 3. Observation(观察)- 分析工具返回结果 - 4. 重复直到信息足够或达到最大次数 - 5. Final Answer(最终回答)- 生成章节内容 - + Generate individual section content using ReACT pattern + + ReACT loop: + 1. Thought - Analyze what information is needed + 2. Action - Call tool to get information + 3. 
Observation - Analyze tool return results + 4. Repeat until information is sufficient or maximum iterations reached + 5. Final Answer - Generate section content + Args: - section: 要生成的章节 - outline: 完整大纲 - previous_sections: 之前章节的内容(用于保持连贯性) - progress_callback: 进度回调 - section_index: 章节索引(用于日志记录) - + section: Section to generate + outline: Complete outline + previous_sections: Content of previous sections (for maintaining coherence) + progress_callback: Progress callback + section_index: Section index (for logging) + Returns: - 章节内容(Markdown格式) + Section content (Markdown format) """ - logger.info(f"ReACT生成章节: {section.title}") - - # 记录章节开始日志 + logger.info(f"ReACT generating section: {section.title}") + + # Log section start if self.report_logger: self.report_logger.log_section_start(section.title, section_index) @@ -1259,16 +1259,16 @@ def _generate_section_react( tools_description=self._get_tools_description(), ) - # 构建用户prompt - 每个已完成章节各传入最大4000字 + # Build user prompt - pass maximum 4000 characters for each completed section if previous_sections: previous_parts = [] for sec in previous_sections: - # 每个章节最多4000字 + # Maximum 4000 characters per section truncated = sec[:4000] + "..." 
if len(sec) > 4000 else sec previous_parts.append(truncated) previous_content = "\n\n---\n\n".join(previous_parts) else: - previous_content = "(这是第一个章节)" + previous_content = "(This is the first section)" user_prompt = SECTION_USER_PROMPT_TEMPLATE.format( previous_content=previous_content, @@ -1280,77 +1280,77 @@ def _generate_section_react( {"role": "user", "content": user_prompt} ] - # ReACT循环 + # ReACT loop tool_calls_count = 0 - max_iterations = 5 # 最大迭代轮数 - min_tool_calls = 3 # 最少工具调用次数 - conflict_retries = 0 # 工具调用与Final Answer同时出现的连续冲突次数 - used_tools = set() # 记录已调用过的工具名 + max_iterations = 5 # Maximum iterations + min_tool_calls = 3 # Minimum tool calls + conflict_retries = 0 # Consecutive conflicts where tool calls and Final Answer appear simultaneously + used_tools = set() # Record tool names already called all_tools = {"insight_forge", "panorama_search", "quick_search", "interview_agents"} - # 报告上下文,用于InsightForge的子问题生成 - report_context = f"章节标题: {section.title}\n模拟需求: {self.simulation_requirement}" + # Report context for InsightForge sub-question generation + report_context = f"Section Title: {section.title}\nSimulation Requirement: {self.simulation_requirement}" for iteration in range(max_iterations): if progress_callback: progress_callback( "generating", int((iteration / max_iterations) * 100), - f"深度检索与撰写中 ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})" + f"Deep retrieval and writing in progress ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})" ) - # 调用LLM + # Call LLM response = self.llm.chat( messages=messages, temperature=0.5, max_tokens=4096 ) - # 检查 LLM 返回是否为 None(API 异常或内容为空) + # Check if LLM return is None (API exception or empty content) if response is None: - logger.warning(f"章节 {section.title} 第 {iteration + 1} 次迭代: LLM 返回 None") - # 如果还有迭代次数,添加消息并重试 + logger.warning(f"Section {section.title} round {iteration + 1} iteration: LLM returned None") + # If there are more iterations, add message and retry if iteration < 
max_iterations - 1: - messages.append({"role": "assistant", "content": "(响应为空)"}) - messages.append({"role": "user", "content": "请继续生成内容。"}) + messages.append({"role": "assistant", "content": "(Response empty)"}) + messages.append({"role": "user", "content": "Please continue generating content."}) continue - # 最后一次迭代也返回 None,跳出循环进入强制收尾 + # Last iteration also returned None, exit loop and enter forced conclusion break - logger.debug(f"LLM响应: {response[:200]}...") + logger.debug(f"LLM response: {response[:200]}...") - # 解析一次,复用结果 + # Parse once, reuse result tool_calls = self._parse_tool_calls(response) has_tool_calls = bool(tool_calls) has_final_answer = "Final Answer:" in response - # ── 冲突处理:LLM 同时输出了工具调用和 Final Answer ── + # ── Conflict handling: LLM simultaneously output tool calls and Final Answer ── if has_tool_calls and has_final_answer: conflict_retries += 1 logger.warning( - f"章节 {section.title} 第 {iteration+1} 轮: " - f"LLM 同时输出工具调用和 Final Answer(第 {conflict_retries} 次冲突)" + f"Section {section.title} round {iteration+1}: " + f"LLM simultaneously output tool calls and Final Answer (round {conflict_retries} conflicts)" ) if conflict_retries <= 2: - # 前两次:丢弃本次响应,要求 LLM 重新回复 + # First two times: discard this response and request LLM to reply again messages.append({"role": "assistant", "content": response}) messages.append({ "role": "user", "content": ( - "【格式错误】你在一次回复中同时包含了工具调用和 Final Answer,这是不允许的。\n" - "每次回复只能做以下两件事之一:\n" - "- 调用一个工具(输出一个 块,不要写 Final Answer)\n" - "- 输出最终内容(以 'Final Answer:' 开头,不要包含 )\n" - "请重新回复,只做其中一件事。" + "[Format Error] You cannot include both tool calls and Final Answer in one reply.\n" + "Each reply can only do one of the following:\n" + "- Call a tool (output a block, don't write Final Answer)\n" + "- Output final content (starting with 'Final Answer:', don't include )\n" + "Please reply again and only do one of these." 
), }) continue else: - # 第三次:降级处理,截断到第一个工具调用,强制执行 + # Third time: downgrade, truncate to first tool call, force execution logger.warning( - f"章节 {section.title}: 连续 {conflict_retries} 次冲突," - "降级为截断执行第一个工具调用" + f"Section {section.title}: consecutive {conflict_retries} conflicts, " + "downgraded to truncate and execute first tool call" ) first_tool_end = response.find('') if first_tool_end != -1: @@ -1360,7 +1360,7 @@ def _generate_section_react( has_final_answer = False conflict_retries = 0 - # 记录 LLM 响应日志 + # Log LLM response if self.report_logger: self.report_logger.log_llm_response( section_title=section.title, @@ -1371,13 +1371,13 @@ def _generate_section_react( has_final_answer=has_final_answer ) - # ── 情况1:LLM 输出了 Final Answer ── + # ── Case 1: LLM output Final Answer ── if has_final_answer: - # 工具调用次数不足,拒绝并要求继续调工具 + # Insufficient tool calls, reject and request to continue calling tools if tool_calls_count < min_tool_calls: messages.append({"role": "assistant", "content": response}) unused_tools = all_tools - used_tools - unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else "" + unused_hint = f"(These tools have not been used, recommend using them: {', '.join(unused_tools)})" if unused_tools else "" messages.append({ "role": "user", "content": REACT_INSUFFICIENT_TOOLS_MSG.format( @@ -1388,9 +1388,9 @@ def _generate_section_react( }) continue - # 正常结束 + # Normal completion final_answer = response.split("Final Answer:")[-1].strip() - logger.info(f"章节 {section.title} 生成完成(工具调用: {tool_calls_count}次)") + logger.info(f"Section {section.title} generation completed (tool calls: {tool_calls_count} times)") if self.report_logger: self.report_logger.log_section_content( @@ -1401,9 +1401,9 @@ def _generate_section_react( ) return final_answer - # ── 情况2:LLM 尝试调用工具 ── + # ── Case 2: LLM attempts to call tools ── if has_tool_calls: - # 工具额度已耗尽 → 明确告知,要求输出 Final Answer + # Tool quota exhausted -> inform clearly, request output Final Answer if 
tool_calls_count >= self.MAX_TOOL_CALLS_PER_SECTION: messages.append({"role": "assistant", "content": response}) messages.append({ @@ -1415,10 +1415,10 @@ def _generate_section_react( }) continue - # 只执行第一个工具调用 + # Only execute the first tool call call = tool_calls[0] if len(tool_calls) > 1: - logger.info(f"LLM 尝试调用 {len(tool_calls)} 个工具,只执行第一个: {call['name']}") + logger.info(f"LLM attempted to call {len(tool_calls)} tools, only execute the first: {call['name']}") if self.report_logger: self.report_logger.log_tool_call( @@ -1447,7 +1447,7 @@ def _generate_section_react( tool_calls_count += 1 used_tools.add(call['name']) - # 构建未使用工具提示 + # Build unused tools hint unused_tools = all_tools - used_tools unused_hint = "" if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION: @@ -1467,13 +1467,13 @@ def _generate_section_react( }) continue - # ── 情况3:既没有工具调用,也没有 Final Answer ── + # ── Case 3: Neither tool call nor Final Answer ── messages.append({"role": "assistant", "content": response}) if tool_calls_count < min_tool_calls: - # 工具调用次数不足,推荐未用过的工具 + # Tool call count insufficient, recommend unused tools unused_tools = all_tools - used_tools - unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else "" + unused_hint = f"(These tools have not been used, recommend using them: {', '.join(unused_tools)})" if unused_tools else "" messages.append({ "role": "user", @@ -1485,9 +1485,9 @@ def _generate_section_react( }) continue - # 工具调用已足够,LLM 输出了内容但没带 "Final Answer:" 前缀 - # 直接将这段内容作为最终答案,不再空转 - logger.info(f"章节 {section.title} 未检测到 'Final Answer:' 前缀,直接采纳LLM输出作为最终内容(工具调用: {tool_calls_count}次)") + # Tool calls sufficient, LLM output content but without "Final Answer:" prefix + # Directly adopt this content as final answer, no more waiting + logger.info(f"Section {section.title} did not detect 'Final Answer:' prefix, directly adopting LLM output as final content (tool calls: {tool_calls_count} times)") final_answer = response.strip() if 
self.report_logger: @@ -1499,8 +1499,8 @@ def _generate_section_react( ) return final_answer - # 达到最大迭代次数,强制生成内容 - logger.warning(f"章节 {section.title} 达到最大迭代次数,强制生成") + # Reached maximum iterations, force generate content + logger.warning(f"Section {section.title} reached maximum iterations, force generating") messages.append({"role": "user", "content": REACT_FORCE_FINAL_MSG}) response = self.llm.chat( @@ -1509,16 +1509,16 @@ def _generate_section_react( max_tokens=4096 ) - # 检查强制收尾时 LLM 返回是否为 None + # Check if LLM return is None during forced conclusion if response is None: - logger.error(f"章节 {section.title} 强制收尾时 LLM 返回 None,使用默认错误提示") - final_answer = f"(本章节生成失败:LLM 返回空响应,请稍后重试)" + logger.error(f"Section {section.title} forced conclusion: LLM returned None, using default error message") + final_answer = f"(This section generation failed: LLM returned empty response, please retry later)" elif "Final Answer:" in response: final_answer = response.split("Final Answer:")[-1].strip() else: final_answer = response - # 记录章节内容生成完成日志 + # Log section content generation completion if self.report_logger: self.report_logger.log_section_content( section_title=section.title, @@ -1535,29 +1535,29 @@ def generate_report( report_id: Optional[str] = None ) -> Report: """ - 生成完整报告(分章节实时输出) - - 每个章节生成完成后立即保存到文件夹,不需要等待整个报告完成。 - 文件结构: + Generate complete report (realtime output per section) + + Each section is saved to the folder immediately after generation, no need to wait for the entire report. + File structure: reports/{report_id}/ - meta.json - 报告元信息 - outline.json - 报告大纲 - progress.json - 生成进度 - section_01.md - 第1章节 - section_02.md - 第2章节 + meta.json - Report metadata + outline.json - Report outline + progress.json - Generation progress + section_01.md - Section 1 + section_02.md - Section 2 ... 
- full_report.md - 完整报告 - + full_report.md - Complete report + Args: - progress_callback: 进度回调函数 (stage, progress, message) - report_id: 报告ID(可选,如果不传则自动生成) - + progress_callback: Progress callback function (stage, progress, message) + report_id: Report ID (optional, auto-generate if not provided) + Returns: - Report: 完整报告 + Report: Complete report """ import uuid - # 如果没有传入 report_id,则自动生成 + # If report_id not provided, auto-generate if not report_id: report_id = f"report_{uuid.uuid4().hex[:12]}" start_time = datetime.now() @@ -1571,14 +1571,14 @@ def generate_report( created_at=datetime.now().isoformat() ) - # 已完成的章节标题列表(用于进度追踪) + # Completed section titles list (for progress tracking) completed_section_titles = [] try: - # 初始化:创建报告文件夹并保存初始状态 + # Initialize: Create report folder and save initial state ReportManager._ensure_report_folder(report_id) - # 初始化日志记录器(结构化日志 agent_log.jsonl) + # Initialize logger (structured logs agent_log.jsonl) self.report_logger = ReportLogger(report_id) self.report_logger.log_start( simulation_id=self.simulation_id, @@ -1586,27 +1586,27 @@ def generate_report( simulation_requirement=self.simulation_requirement ) - # 初始化控制台日志记录器(console_log.txt) + # Initialize console logger (console_log.txt) self.console_logger = ReportConsoleLogger(report_id) ReportManager.update_progress( - report_id, "pending", 0, "初始化报告...", + report_id, "pending", 0, "Initializing report...", completed_sections=[] ) ReportManager.save_report(report) - # 阶段1: 规划大纲 + # Phase 1: Plan outline report.status = ReportStatus.PLANNING ReportManager.update_progress( - report_id, "planning", 5, "开始规划报告大纲...", + report_id, "planning", 5, "Starting to plan report outline...", completed_sections=[] ) - # 记录规划开始日志 + # Log planning start self.report_logger.log_planning_start() if progress_callback: - progress_callback("planning", 0, "开始规划报告大纲...") + progress_callback("planning", 0, "Starting to plan report outline...") outline = self.plan_outline( progress_callback=lambda stage, 
prog, msg: @@ -1614,33 +1614,33 @@ def generate_report( ) report.outline = outline - # 记录规划完成日志 + # Log planning completion self.report_logger.log_planning_complete(outline.to_dict()) - # 保存大纲到文件 + # Save outline to file ReportManager.save_outline(report_id, outline) ReportManager.update_progress( - report_id, "planning", 15, f"大纲规划完成,共{len(outline.sections)}个章节", + report_id, "planning", 15, f"Outline planning completed, {len(outline.sections)} sections total", completed_sections=[] ) ReportManager.save_report(report) - logger.info(f"大纲已保存到文件: {report_id}/outline.json") + logger.info(f"Outline saved to file: {report_id}/outline.json") - # 阶段2: 逐章节生成(分章节保存) + # Phase 2: Generate sections sequentially (save per section) report.status = ReportStatus.GENERATING total_sections = len(outline.sections) - generated_sections = [] # 保存内容用于上下文 + generated_sections = [] # Save content for context for i, section in enumerate(outline.sections): section_num = i + 1 base_progress = 20 + int((i / total_sections) * 70) - # 更新进度 + # Update progress ReportManager.update_progress( report_id, "generating", base_progress, - f"正在生成章节: {section.title} ({section_num}/{total_sections})", + f"Generating section: {section.title} ({section_num}/{total_sections})", current_section=section.title, completed_sections=completed_section_titles ) @@ -1649,10 +1649,10 @@ def generate_report( progress_callback( "generating", base_progress, - f"正在生成章节: {section.title} ({section_num}/{total_sections})" + f"Generating section: {section.title} ({section_num}/{total_sections})" ) - # 生成主章节内容 + # Generate main section content section_content = self._generate_section_react( section=section, outline=outline, @@ -1669,11 +1669,11 @@ def generate_report( section.content = section_content generated_sections.append(f"## {section.title}\n\n{section_content}") - # 保存章节 + # Save section ReportManager.save_section(report_id, section_num, section) completed_section_titles.append(section.title) - # 记录章节完成日志 + # Log 
section completion full_section_content = f"## {section.title}\n\n{section_content}" if self.report_logger: @@ -1683,54 +1683,54 @@ def generate_report( full_content=full_section_content.strip() ) - logger.info(f"章节已保存: {report_id}/section_{section_num:02d}.md") - - # 更新进度 + logger.info(f"Section saved: {report_id}/section_{section_num:02d}.md") + + # Update progress ReportManager.update_progress( report_id, "generating", base_progress + int(70 / total_sections), - f"章节 {section.title} 已完成", + f"Section {section.title} completed", current_section=None, completed_sections=completed_section_titles ) - # 阶段3: 组装完整报告 + # Phase 3: Assemble complete report if progress_callback: - progress_callback("generating", 95, "正在组装完整报告...") - + progress_callback("generating", 95, "Assembling complete report...") + ReportManager.update_progress( - report_id, "generating", 95, "正在组装完整报告...", + report_id, "generating", 95, "Assembling complete report...", completed_sections=completed_section_titles ) - # 使用ReportManager组装完整报告 + # Use ReportManager to assemble complete report report.markdown_content = ReportManager.assemble_full_report(report_id, outline) report.status = ReportStatus.COMPLETED report.completed_at = datetime.now().isoformat() - # 计算总耗时 + # Calculate total elapsed time total_time_seconds = (datetime.now() - start_time).total_seconds() - # 记录报告完成日志 + # Log report completion if self.report_logger: self.report_logger.log_report_complete( total_sections=total_sections, total_time_seconds=total_time_seconds ) - # 保存最终报告 + # Save final report ReportManager.save_report(report) ReportManager.update_progress( - report_id, "completed", 100, "报告生成完成", + report_id, "completed", 100, "Report generation completed", completed_sections=completed_section_titles ) if progress_callback: - progress_callback("completed", 100, "报告生成完成") - - logger.info(f"报告生成完成: {report_id}") - - # 关闭控制台日志记录器 + progress_callback("completed", 100, "Report generation completed") + + logger.info(f"Report 
generation completed: {report_id}") + + # Close console logger if self.console_logger: self.console_logger.close() self.console_logger = None @@ -1738,90 +1738,90 @@ def generate_report( return report except Exception as e: - logger.error(f"报告生成失败: {str(e)}") + logger.error(f"Report generation failed: {str(e)}") report.status = ReportStatus.FAILED report.error = str(e) - # 记录错误日志 + # Log error if self.report_logger: self.report_logger.log_error(str(e), "failed") - # 保存失败状态 + # Save failed status try: ReportManager.save_report(report) ReportManager.update_progress( - report_id, "failed", -1, f"报告生成失败: {str(e)}", + report_id, "failed", -1, f"Report generation failed: {str(e)}", completed_sections=completed_section_titles ) except Exception: - pass # 忽略保存失败的错误 + pass # Ignore save failure errors - # 关闭控制台日志记录器 + # Close console logger if self.console_logger: self.console_logger.close() self.console_logger = None - + return report - + def chat( self, message: str, chat_history: List[Dict[str, str]] = None ) -> Dict[str, Any]: """ - 与Report Agent对话 - - 在对话中Agent可以自主调用检索工具来回答问题 - + Chat with Report Agent + + During conversation, Agent can autonomously call retrieval tools to answer questions + Args: - message: 用户消息 - chat_history: 对话历史 - + message: User message + chat_history: Chat history + Returns: { - "response": "Agent回复", - "tool_calls": [调用的工具列表], - "sources": [信息来源] + "response": "Agent response", + "tool_calls": [list of tool calls], + "sources": [information sources] } """ - logger.info(f"Report Agent对话: {message[:50]}...") + logger.info(f"Report Agent chat: {message[:50]}...") chat_history = chat_history or [] - # 获取已生成的报告内容 + # Get already generated report content report_content = "" try: report = ReportManager.get_report_by_simulation(self.simulation_id) if report and report.markdown_content: - # 限制报告长度,避免上下文过长 + # Limit report length to avoid overly long context report_content = report.markdown_content[:15000] if len(report.markdown_content) > 15000: - 
report_content += "\n\n... [报告内容已截断] ..." + report_content += "\n\n... [Report content truncated] ..." except Exception as e: - logger.warning(f"获取报告内容失败: {e}") + logger.warning(f"Failed to get report content: {e}") system_prompt = CHAT_SYSTEM_PROMPT_TEMPLATE.format( simulation_requirement=self.simulation_requirement, - report_content=report_content if report_content else "(暂无报告)", + report_content=report_content if report_content else "(No report available)", tools_description=self._get_tools_description(), ) - # 构建消息 + # Build messages messages = [{"role": "system", "content": system_prompt}] - - # 添加历史对话 - for h in chat_history[-10:]: # 限制历史长度 + + # Add chat history + for h in chat_history[-10:]: # Limit history length messages.append(h) - # 添加用户消息 + # Add user message messages.append({ "role": "user", "content": message }) - # ReACT循环(简化版) + # ReACT loop (simplified version) tool_calls_made = [] - max_iterations = 2 # 减少迭代轮数 + max_iterations = 2 # Reduced iterations for iteration in range(max_iterations): response = self.llm.chat( @@ -1829,11 +1829,11 @@ def chat( temperature=0.5 ) - # 解析工具调用 + # Parse tool calls tool_calls = self._parse_tool_calls(response) if not tool_calls: - # 没有工具调用,直接返回响应 + # No tool calls, directly return response clean_response = re.sub(r'.*?', '', response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) @@ -1843,33 +1843,33 @@ def chat( "sources": [tc.get("parameters", {}).get("query", "") for tc in tool_calls_made] } - # 执行工具调用(限制数量) + # Execute tool calls (limit count) tool_results = [] - for call in tool_calls[:1]: # 每轮最多执行1次工具调用 + for call in tool_calls[:1]: # Execute at most 1 tool call per round if len(tool_calls_made) >= self.MAX_TOOL_CALLS_PER_CHAT: break result = self._execute_tool(call["name"], call.get("parameters", {})) tool_results.append({ "tool": call["name"], - "result": result[:1500] # 限制结果长度 + "result": result[:1500] # Limit result length }) tool_calls_made.append(call) - # 
将结果添加到消息 + # Add results to messages messages.append({"role": "assistant", "content": response}) - observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results]) + observation = "\n".join([f"[{r['tool']} result]\n{r['result']}" for r in tool_results]) messages.append({ "role": "user", "content": observation + CHAT_OBSERVATION_SUFFIX }) - # 达到最大迭代,获取最终响应 + # Reached maximum iterations, get final response final_response = self.llm.chat( messages=messages, temperature=0.5 ) - # 清理响应 + # Clean response clean_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) @@ -1882,95 +1882,95 @@ def chat( class ReportManager: """ - 报告管理器 - - 负责报告的持久化存储和检索 - - 文件结构(分章节输出): + Report Manager + + Responsible for report persistence storage and retrieval + + File structure (per-section output): reports/ {report_id}/ - meta.json - 报告元信息和状态 - outline.json - 报告大纲 - progress.json - 生成进度 - section_01.md - 第1章节 - section_02.md - 第2章节 + meta.json - Report metadata and status + outline.json - Report outline + progress.json - Generation progress + section_01.md - Section 1 + section_02.md - Section 2 ... 
- full_report.md - 完整报告 + full_report.md - Complete report """ - # 报告存储目录 + # Report storage directory REPORTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'reports') @classmethod def _ensure_reports_dir(cls): - """确保报告根目录存在""" + """Ensure report root directory exists""" os.makedirs(cls.REPORTS_DIR, exist_ok=True) @classmethod def _get_report_folder(cls, report_id: str) -> str: - """获取报告文件夹路径""" + """Get report folder path""" return os.path.join(cls.REPORTS_DIR, report_id) @classmethod def _ensure_report_folder(cls, report_id: str) -> str: - """确保报告文件夹存在并返回路径""" + """Ensure report folder exists and return path""" folder = cls._get_report_folder(report_id) os.makedirs(folder, exist_ok=True) return folder @classmethod def _get_report_path(cls, report_id: str) -> str: - """获取报告元信息文件路径""" + """Get report metadata file path""" return os.path.join(cls._get_report_folder(report_id), "meta.json") @classmethod def _get_report_markdown_path(cls, report_id: str) -> str: - """获取完整报告Markdown文件路径""" + """Get complete report Markdown file path""" return os.path.join(cls._get_report_folder(report_id), "full_report.md") @classmethod def _get_outline_path(cls, report_id: str) -> str: - """获取大纲文件路径""" + """Get outline file path""" return os.path.join(cls._get_report_folder(report_id), "outline.json") @classmethod def _get_progress_path(cls, report_id: str) -> str: - """获取进度文件路径""" + """Get progress file path""" return os.path.join(cls._get_report_folder(report_id), "progress.json") @classmethod def _get_section_path(cls, report_id: str, section_index: int) -> str: - """获取章节Markdown文件路径""" + """Get section Markdown file path""" return os.path.join(cls._get_report_folder(report_id), f"section_{section_index:02d}.md") @classmethod def _get_agent_log_path(cls, report_id: str) -> str: - """获取 Agent 日志文件路径""" + """Get Agent log file path""" return os.path.join(cls._get_report_folder(report_id), "agent_log.jsonl") @classmethod def _get_console_log_path(cls, report_id: str) -> str: - 
"""获取控制台日志文件路径""" + """Get console log file path""" return os.path.join(cls._get_report_folder(report_id), "console_log.txt") @classmethod def get_console_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: """ - 获取控制台日志内容 - - 这是报告生成过程中的控制台输出日志(INFO、WARNING等), - 与 agent_log.jsonl 的结构化日志不同。 - + Get console log content + + These are console output logs during report generation (INFO, WARNING, etc.), + different from the structured logs in agent_log.jsonl. + Args: - report_id: 报告ID - from_line: 从第几行开始读取(用于增量获取,0 表示从头开始) - + report_id: Report ID + from_line: Start reading from this line number (for incremental retrieval, 0 means from beginning) + Returns: { - "logs": [日志行列表], - "total_lines": 总行数, - "from_line": 起始行号, - "has_more": 是否还有更多日志 + "logs": [log line list], + "total_lines": total line count, + "from_line": start line number, + "has_more": whether there are more logs } """ log_path = cls._get_console_log_path(report_id) @@ -1990,26 +1990,26 @@ def get_console_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: for i, line in enumerate(f): total_lines = i + 1 if i >= from_line: - # 保留原始日志行,去掉末尾换行符 + # Keep original log line, remove trailing newline logs.append(line.rstrip('\n\r')) return { "logs": logs, "total_lines": total_lines, "from_line": from_line, - "has_more": False # 已读取到末尾 + "has_more": False # Already read to end } - + @classmethod def get_console_log_stream(cls, report_id: str) -> List[str]: """ - 获取完整的控制台日志(一次性获取全部) - + Get complete console log (one-time retrieval of all) + Args: - report_id: 报告ID - + report_id: Report ID + Returns: - 日志行列表 + Log line list """ result = cls.get_console_log(report_id, from_line=0) return result["logs"] @@ -2017,18 +2017,18 @@ def get_console_log_stream(cls, report_id: str) -> List[str]: @classmethod def get_agent_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: """ - 获取 Agent 日志内容 - + Get Agent log content + Args: - report_id: 报告ID - from_line: 从第几行开始读取(用于增量获取,0 表示从头开始) - + 
report_id: Report ID + from_line: Start reading from this line number (for incremental retrieval, 0 means from beginning) + Returns: { - "logs": [日志条目列表], - "total_lines": 总行数, - "from_line": 起始行号, - "has_more": 是否还有更多日志 + "logs": [log entry list], + "total_lines": total line count, + "from_line": start line number, + "has_more": whether there are more logs } """ log_path = cls._get_agent_log_path(report_id) @@ -2052,26 +2052,26 @@ def get_agent_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: log_entry = json.loads(line.strip()) logs.append(log_entry) except json.JSONDecodeError: - # 跳过解析失败的行 + # Skip lines that fail to parse continue return { "logs": logs, "total_lines": total_lines, "from_line": from_line, - "has_more": False # 已读取到末尾 + "has_more": False # Already read to end } - + @classmethod def get_agent_log_stream(cls, report_id: str) -> List[Dict[str, Any]]: """ - 获取完整的 Agent 日志(用于一次性获取全部) - + Get complete Agent log (for one-time retrieval of all) + Args: - report_id: 报告ID - + report_id: Report ID + Returns: - 日志条目列表 + Log entry list """ result = cls.get_agent_log(report_id, from_line=0) return result["logs"] @@ -2079,16 +2079,16 @@ def get_agent_log_stream(cls, report_id: str) -> List[Dict[str, Any]]: @classmethod def save_outline(cls, report_id: str, outline: ReportOutline) -> None: """ - 保存报告大纲 - - 在规划阶段完成后立即调用 + Save report outline + + Called immediately after planning phase completion """ cls._ensure_report_folder(report_id) with open(cls._get_outline_path(report_id), 'w', encoding='utf-8') as f: json.dump(outline.to_dict(), f, ensure_ascii=False, indent=2) - logger.info(f"大纲已保存: {report_id}") + logger.info(f"Outline saved: {report_id}") @classmethod def save_section( @@ -2098,49 +2098,49 @@ def save_section( section: ReportSection ) -> str: """ - 保存单个章节 + Save individual section - 在每个章节生成完成后立即调用,实现分章节输出 + Called immediately after each section generation completes, enabling per-section output Args: - report_id: 报告ID - section_index: 
章节索引(从1开始) - section: 章节对象 + report_id: Report ID + section_index: Section index (starting from 1) + section: Section object Returns: - 保存的文件路径 + Saved file path """ cls._ensure_report_folder(report_id) - # 构建章节Markdown内容 - 清理可能存在的重复标题 + # Build section Markdown content - clean possible duplicate titles cleaned_content = cls._clean_section_content(section.content, section.title) md_content = f"## {section.title}\n\n" if cleaned_content: md_content += f"{cleaned_content}\n\n" - # 保存文件 + # Save file file_suffix = f"section_{section_index:02d}.md" file_path = os.path.join(cls._get_report_folder(report_id), file_suffix) with open(file_path, 'w', encoding='utf-8') as f: f.write(md_content) - logger.info(f"章节已保存: {report_id}/{file_suffix}") + logger.info(f"Section saved: {report_id}/{file_suffix}") return file_path @classmethod def _clean_section_content(cls, content: str, section_title: str) -> str: """ - 清理章节内容 - - 1. 移除内容开头与章节标题重复的Markdown标题行 - 2. 将所有 ### 及以下级别的标题转换为粗体文本 - + Clean section content + + 1. Remove Markdown heading lines at the beginning that duplicate the section title + 2. 
Convert every remaining Markdown heading (any level, # through ######) to bold text + Args: - content: 原始内容 - section_title: 章节标题 - + content: Original content + section_title: Section title + Returns: - 清理后的内容 + Cleaned content """ import re @@ -2155,26 +2155,26 @@ def _clean_section_content(cls, content: str, section_title: str) -> str: for i, line in enumerate(lines): stripped = line.strip() - # 检查是否是Markdown标题行 + # Check if it's a Markdown heading line heading_match = re.match(r'^(#{1,6})\s+(.+)$', stripped) if heading_match: level = len(heading_match.group(1)) title_text = heading_match.group(2).strip() - # 检查是否是与章节标题重复的标题(跳过前5行内的重复) + # Check if it's a duplicate of the section title (skip duplicates within first 5 lines) if i < 5: if title_text == section_title or title_text.replace(' ', '') == section_title.replace(' ', ''): skip_next_empty = True continue - # 将所有级别的标题(#, ##, ###, ####等)转换为粗体 - # 因为章节标题由系统添加,内容中不应有任何标题 + # Convert all level headings (#, ##, ###, #### etc.) to bold + # Because section titles are added by the system, content should not have any headings cleaned_lines.append(f"**{title_text}**") - cleaned_lines.append("") # 添加空行 + cleaned_lines.append("") # Add empty line continue - # 如果上一行是被跳过的标题,且当前行为空,也跳过 + # If previous line was a skipped heading and current line is empty, also skip if skip_next_empty and stripped == '': skip_next_empty = False continue @@ -2182,14 +2182,14 @@ def _clean_section_content(cls, content: str, section_title: str) -> str: skip_next_empty = False cleaned_lines.append(line) - # 移除开头的空行 + # Remove leading empty lines while cleaned_lines and cleaned_lines[0].strip() == '': cleaned_lines.pop(0) - # 移除开头的分隔线 + # Remove leading separator lines while cleaned_lines and cleaned_lines[0].strip() in ['---', '***', '___']: cleaned_lines.pop(0) - # 同时移除分隔线后的空行 + # Also remove empty lines after separator while cleaned_lines and cleaned_lines[0].strip() == '': cleaned_lines.pop(0) @@ -2206,9 +2206,9 @@ def update_progress( completed_sections: List[str] = 
None ) -> None: """ - 更新报告生成进度 - - 前端可以通过读取progress.json获取实时进度 + Update report generation progress + + Frontend can read progress.json to get realtime progress """ cls._ensure_report_folder(report_id) @@ -2226,7 +2226,7 @@ def update_progress( @classmethod def get_progress(cls, report_id: str) -> Optional[Dict[str, Any]]: - """获取报告生成进度""" + """Get report generation progress""" path = cls._get_progress_path(report_id) if not os.path.exists(path): @@ -2238,9 +2238,9 @@ def get_progress(cls, report_id: str) -> Optional[Dict[str, Any]]: @classmethod def get_generated_sections(cls, report_id: str) -> List[Dict[str, Any]]: """ - 获取已生成的章节列表 - - 返回所有已保存的章节文件信息 + Get list of generated sections + + Return all saved section file information """ folder = cls._get_report_folder(report_id) @@ -2254,7 +2254,7 @@ def get_generated_sections(cls, report_id: str) -> List[Dict[str, Any]]: with open(file_path, 'r', encoding='utf-8') as f: content = f.read() - # 从文件名解析章节索引 + # Parse section index from filename parts = filename.replace('.md', '').split('_') section_index = int(parts[1]) @@ -2269,48 +2269,48 @@ def get_generated_sections(cls, report_id: str) -> List[Dict[str, Any]]: @classmethod def assemble_full_report(cls, report_id: str, outline: ReportOutline) -> str: """ - 组装完整报告 - - 从已保存的章节文件组装完整报告,并进行标题清理 + Assemble complete report + + Assemble complete report from saved section files, with heading cleanup """ folder = cls._get_report_folder(report_id) - # 构建报告头部 + # Build report header md_content = f"# {outline.title}\n\n" md_content += f"> {outline.summary}\n\n" md_content += f"---\n\n" - # 按顺序读取所有章节文件 + # Read all section files in order sections = cls.get_generated_sections(report_id) for section_info in sections: md_content += section_info["content"] - # 后处理:清理整个报告的标题问题 + # Post-processing: clean heading issues in the entire report md_content = cls._post_process_report(md_content, outline) - # 保存完整报告 + # Save complete report full_path = cls._get_report_markdown_path(report_id) 
with open(full_path, 'w', encoding='utf-8') as f: f.write(md_content) - logger.info(f"完整报告已组装: {report_id}") + logger.info(f"Complete report assembled: {report_id}") return md_content @classmethod def _post_process_report(cls, content: str, outline: ReportOutline) -> str: """ - 后处理报告内容 - - 1. 移除重复的标题 - 2. 保留报告主标题(#)和章节标题(##),移除其他级别的标题(###, ####等) - 3. 清理多余的空行和分隔线 - + Post-process report content + + 1. Remove duplicate headings + 2. Keep report main title (#) and section titles (##), remove other level headings (###, #### etc.) + 3. Clean redundant empty lines and separator lines + Args: - content: 原始报告内容 - outline: 报告大纲 - + content: Original report content + outline: Report outline + Returns: - 处理后的内容 + Processed content """ import re @@ -2318,7 +2318,7 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: processed_lines = [] prev_was_heading = False - # 收集大纲中的所有章节标题 + # Collect all section titles from outline section_titles = set() for section in outline.sections: section_titles.add(section.title) @@ -2328,14 +2328,14 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: line = lines[i] stripped = line.strip() - # 检查是否是标题行 + # Check if it's a heading line heading_match = re.match(r'^(#{1,6})\s+(.+)$', stripped) if heading_match: level = len(heading_match.group(1)) title = heading_match.group(2).strip() - # 检查是否是重复标题(在连续5行内出现相同内容的标题) + # Check if it's a duplicate heading (same content heading within consecutive 5 lines) is_duplicate = False for j in range(max(0, len(processed_lines) - 5), len(processed_lines)): prev_line = processed_lines[j].strip() @@ -2347,43 +2347,43 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: break if is_duplicate: - # 跳过重复标题及其后的空行 + # Skip duplicate heading and following empty lines i += 1 while i < len(lines) and lines[i].strip() == '': i += 1 continue - # 标题层级处理: - # - # (level=1) 只保留报告主标题 - # - ## (level=2) 保留章节标题 - # - ### 及以下 (level>=3) 转换为粗体文本 + # 
Heading level handling: + # - # (level=1) only keep report main title + # - ## (level=2) keep section titles + # - ### and below (level>=3) convert to bold text if level == 1: if title == outline.title: - # 保留报告主标题 + # Keep report main title processed_lines.append(line) prev_was_heading = True elif title in section_titles: - # 章节标题错误使用了#,修正为## + # Section title incorrectly used #, correct to ## processed_lines.append(f"## {title}") prev_was_heading = True else: - # 其他一级标题转为粗体 + # Other first-level headings convert to bold processed_lines.append(f"**{title}**") processed_lines.append("") prev_was_heading = False elif level == 2: if title in section_titles or title == outline.title: - # 保留章节标题 + # Keep section title processed_lines.append(line) prev_was_heading = True else: - # 非章节的二级标题转为粗体 + # Non-section second-level headings convert to bold processed_lines.append(f"**{title}**") processed_lines.append("") prev_was_heading = False else: - # ### 及以下级别的标题转换为粗体文本 + # ### and below level headings convert to bold text processed_lines.append(f"**{title}**") processed_lines.append("") prev_was_heading = False @@ -2392,12 +2392,12 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: continue elif stripped == '---' and prev_was_heading: - # 跳过标题后紧跟的分隔线 + # Skip separator lines immediately following headings i += 1 continue elif stripped == '' and prev_was_heading: - # 标题后只保留一个空行 + # Keep only one empty line after heading if processed_lines and processed_lines[-1].strip() != '': processed_lines.append(line) prev_was_heading = False @@ -2408,7 +2408,7 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: i += 1 - # 清理连续的多个空行(保留最多2个) + # Clean consecutive multiple empty lines (keep at most 2) result_lines = [] empty_count = 0 for line in processed_lines: @@ -2424,31 +2424,31 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: @classmethod def save_report(cls, report: Report) -> None: - 
"""保存报告元信息和完整报告""" + """Save report metadata and complete report""" cls._ensure_report_folder(report.report_id) - # 保存元信息JSON + # Save metadata JSON with open(cls._get_report_path(report.report_id), 'w', encoding='utf-8') as f: json.dump(report.to_dict(), f, ensure_ascii=False, indent=2) - # 保存大纲 + # Save outline if report.outline: cls.save_outline(report.report_id, report.outline) - # 保存完整Markdown报告 + # Save complete Markdown report if report.markdown_content: with open(cls._get_report_markdown_path(report.report_id), 'w', encoding='utf-8') as f: f.write(report.markdown_content) - logger.info(f"报告已保存: {report.report_id}") + logger.info(f"Report saved: {report.report_id}") @classmethod def get_report(cls, report_id: str) -> Optional[Report]: - """获取报告""" + """Get report""" path = cls._get_report_path(report_id) if not os.path.exists(path): - # 兼容旧格式:检查直接存储在reports目录下的文件 + # Backward compatibility: check files stored directly in reports directory old_path = os.path.join(cls.REPORTS_DIR, f"{report_id}.json") if os.path.exists(old_path): path = old_path @@ -2458,7 +2458,7 @@ def get_report(cls, report_id: str) -> Optional[Report]: with open(path, 'r', encoding='utf-8') as f: data = json.load(f) - # 重建Report对象 + # Rebuild Report object outline = None if data.get('outline'): outline_data = data['outline'] @@ -2474,7 +2474,7 @@ def get_report(cls, report_id: str) -> Optional[Report]: sections=sections ) - # 如果markdown_content为空,尝试从full_report.md读取 + # If markdown_content is empty, try reading from full_report.md markdown_content = data.get('markdown_content', '') if not markdown_content: full_report_path = cls._get_report_markdown_path(report_id) @@ -2497,17 +2497,17 @@ def get_report(cls, report_id: str) -> Optional[Report]: @classmethod def get_report_by_simulation(cls, simulation_id: str) -> Optional[Report]: - """根据模拟ID获取报告""" + """Get report by simulation ID""" cls._ensure_reports_dir() for item in os.listdir(cls.REPORTS_DIR): item_path = 
os.path.join(cls.REPORTS_DIR, item) - # 新格式:文件夹 + # New format: folder if os.path.isdir(item_path): report = cls.get_report(item) if report and report.simulation_id == simulation_id: return report - # 兼容旧格式:JSON文件 + # Backward compatible: JSON file elif item.endswith('.json'): report_id = item[:-5] report = cls.get_report(report_id) @@ -2518,19 +2518,19 @@ def get_report_by_simulation(cls, simulation_id: str) -> Optional[Report]: @classmethod def list_reports(cls, simulation_id: Optional[str] = None, limit: int = 50) -> List[Report]: - """列出报告""" + """List reports""" cls._ensure_reports_dir() reports = [] for item in os.listdir(cls.REPORTS_DIR): item_path = os.path.join(cls.REPORTS_DIR, item) - # 新格式:文件夹 + # New format: folder if os.path.isdir(item_path): report = cls.get_report(item) if report: if simulation_id is None or report.simulation_id == simulation_id: reports.append(report) - # 兼容旧格式:JSON文件 + # Backward compatible: JSON file elif item.endswith('.json'): report_id = item[:-5] report = cls.get_report(report_id) @@ -2538,25 +2538,25 @@ def list_reports(cls, simulation_id: Optional[str] = None, limit: int = 50) -> L if simulation_id is None or report.simulation_id == simulation_id: reports.append(report) - # 按创建时间倒序 + # Sort by creation time descending reports.sort(key=lambda r: r.created_at, reverse=True) return reports[:limit] @classmethod def delete_report(cls, report_id: str) -> bool: - """删除报告(整个文件夹)""" + """Delete report (entire folder)""" import shutil folder_path = cls._get_report_folder(report_id) - # 新格式:删除整个文件夹 + # New format: delete entire folder if os.path.exists(folder_path) and os.path.isdir(folder_path): shutil.rmtree(folder_path) - logger.info(f"报告文件夹已删除: {report_id}") + logger.info(f"Report folder deleted: {report_id}") return True - # 兼容旧格式:删除单独的文件 + # Backward compatible: delete individual files deleted = False old_json_path = os.path.join(cls.REPORTS_DIR, f"{report_id}.json") old_md_path = os.path.join(cls.REPORTS_DIR, f"{report_id}.md") 
diff --git a/backend/app/services/simulation_config_generator.py b/backend/app/services/simulation_config_generator.py index cc362508b..f55b2cb96 100644 --- a/backend/app/services/simulation_config_generator.py +++ b/backend/app/services/simulation_config_generator.py @@ -1,13 +1,13 @@ """ -模拟配置智能生成器 -使用LLM根据模拟需求、文档内容、图谱信息自动生成细致的模拟参数 -实现全程自动化,无需人工设置参数 - -采用分步生成策略,避免一次性生成过长内容导致失败: -1. 生成时间配置 -2. 生成事件配置 -3. 分批生成Agent配置 -4. 生成平台配置 +Simulation Configuration Intelligent Generator +Use LLM to automatically generate detailed simulation parameters based on simulation requirements, document content, and knowledge graph information +Implement full process automation without manual parameter setting + +Adopt step-by-step generation strategy to avoid failures from generating too long content at once: +1. Generate time configuration +2. Generate event configuration +3. Generate agent configurations in batches +4. Generate platform configuration """ import json @@ -24,156 +24,156 @@ logger = get_logger('mirofish.simulation_config') -# 中国作息时间配置(北京时间) +# Time zone configuration for Chinese work schedules (Beijing Time) CHINA_TIMEZONE_CONFIG = { - # 深夜时段(几乎无人活动) + # Dead hours (almost no activity) "dead_hours": [0, 1, 2, 3, 4, 5], - # 早间时段(逐渐醒来) + # Morning hours (gradually waking up) "morning_hours": [6, 7, 8], - # 工作时段 + # Work hours "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], - # 晚间高峰(最活跃) + # Evening peak (most active) "peak_hours": [19, 20, 21, 22], - # 夜间时段(活跃度下降) + # Night hours (activity decreases) "night_hours": [23], - # 活跃度系数 + # Activity multipliers "activity_multipliers": { - "dead": 0.05, # 凌晨几乎无人 - "morning": 0.4, # 早间逐渐活跃 - "work": 0.7, # 工作时段中等 - "peak": 1.5, # 晚间高峰 - "night": 0.5 # 深夜下降 + "dead": 0.05, # Almost no one in early morning + "morning": 0.4, # Gradually active in morning + "work": 0.7, # Medium activity during work hours + "peak": 1.5, # Evening peak + "night": 0.5 # Activity decreases at night } } @dataclass class AgentActivityConfig: - 
"""单个Agent的活动配置""" + """Activity configuration for a single Agent""" agent_id: int entity_uuid: str entity_name: str entity_type: str - # 活跃度配置 (0.0-1.0) - activity_level: float = 0.5 # 整体活跃度 - - # 发言频率(每小时预期发言次数) + # Activity configuration (0.0-1.0) + activity_level: float = 0.5 # Overall activity level + + # Speech frequency (expected posts per hour) posts_per_hour: float = 1.0 comments_per_hour: float = 2.0 - # 活跃时间段(24小时制,0-23) + # Active time periods (24-hour format, 0-23) active_hours: List[int] = field(default_factory=lambda: list(range(8, 23))) - # 响应速度(对热点事件的反应延迟,单位:模拟分钟) + # Response speed (reaction delay to trending events, unit: simulation minutes) response_delay_min: int = 5 response_delay_max: int = 60 - # 情感倾向 (-1.0到1.0,负面到正面) + # Sentiment tendency (-1.0 to 1.0, negative to positive) sentiment_bias: float = 0.0 - # 立场(对特定话题的态度) + # Stance (attitude toward specific topics) stance: str = "neutral" # supportive, opposing, neutral, observer - # 影响力权重(决定其发言被其他Agent看到的概率) + # Influence weight (determines probability of their speech being seen by other agents) influence_weight: float = 1.0 @dataclass class TimeSimulationConfig: - """时间模拟配置(基于中国人作息习惯)""" - # 模拟总时长(模拟小时数) - total_simulation_hours: int = 72 # 默认模拟72小时(3天) - - # 每轮代表的时间(模拟分钟)- 默认60分钟(1小时),加快时间流速 + """Time simulation configuration (based on Chinese work schedule habits)""" + # Total simulation time (simulation hours) + total_simulation_hours: int = 72 # Default 72 hours (3 days) + + # Time represented per round (simulation minutes) - default 60 minutes (1 hour), speed up time minutes_per_round: int = 60 - # 每小时激活的Agent数量范围 + # Range of agents activated per hour agents_per_hour_min: int = 5 agents_per_hour_max: int = 20 - # 高峰时段(晚间19-22点,中国人最活跃的时间) + # Peak hours (evening 19-22, most active time for Chinese people) peak_hours: List[int] = field(default_factory=lambda: [19, 20, 21, 22]) peak_activity_multiplier: float = 1.5 - # 低谷时段(凌晨0-5点,几乎无人活动) + # Off-peak hours (early morning 0-5, almost no 
activity) off_peak_hours: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5]) - off_peak_activity_multiplier: float = 0.05 # 凌晨活跃度极低 - - # 早间时段 + off_peak_activity_multiplier: float = 0.05 # Very low activity in early morning + + # Morning hours morning_hours: List[int] = field(default_factory=lambda: [6, 7, 8]) morning_activity_multiplier: float = 0.4 - # 工作时段 + # Work hours work_hours: List[int] = field(default_factory=lambda: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18]) work_activity_multiplier: float = 0.7 @dataclass class EventConfig: - """事件配置""" - # 初始事件(模拟开始时的触发事件) + """Event configuration""" + # Initial posts (triggering events at the start of simulation) initial_posts: List[Dict[str, Any]] = field(default_factory=list) - # 定时事件(在特定时间触发的事件) + # Scheduled events (events triggered at specific times) scheduled_events: List[Dict[str, Any]] = field(default_factory=list) - # 热点话题关键词 + # Hot topic keywords hot_topics: List[str] = field(default_factory=list) - # 舆论引导方向 + # Opinion narrative direction narrative_direction: str = "" @dataclass class PlatformConfig: - """平台特定配置""" + """Platform-specific configuration""" platform: str # twitter or reddit - # 推荐算法权重 - recency_weight: float = 0.4 # 时间新鲜度 - popularity_weight: float = 0.3 # 热度 - relevance_weight: float = 0.3 # 相关性 - - # 病毒传播阈值(达到多少互动后触发扩散) + # Recommendation algorithm weights + recency_weight: float = 0.4 # Time freshness + popularity_weight: float = 0.3 # Popularity + relevance_weight: float = 0.3 # Relevance + + # Viral threshold (number of interactions before triggering spread) viral_threshold: int = 10 - # 回声室效应强度(相似观点聚集程度) + # Echo chamber effect strength (degree of similar opinion clustering) echo_chamber_strength: float = 0.5 @dataclass class SimulationParameters: - """完整的模拟参数配置""" - # 基础信息 + """Complete simulation parameter configuration""" + # Basic information simulation_id: str project_id: str graph_id: str simulation_requirement: str - # 时间配置 + # Time configuration time_config: 
TimeSimulationConfig = field(default_factory=TimeSimulationConfig) - - # Agent配置列表 + + # Agent configuration list agent_configs: List[AgentActivityConfig] = field(default_factory=list) - - # 事件配置 + + # Event configuration event_config: EventConfig = field(default_factory=EventConfig) - - # 平台配置 + + # Platform configuration twitter_config: Optional[PlatformConfig] = None reddit_config: Optional[PlatformConfig] = None - # LLM配置 + # LLM configuration llm_model: str = "" llm_base_url: str = "" - # 生成元数据 + # Generation metadata generated_at: str = field(default_factory=lambda: datetime.now().isoformat()) - generation_reasoning: str = "" # LLM的推理说明 + generation_reasoning: str = "" # LLM reasoning explanation def to_dict(self) -> Dict[str, Any]: - """转换为字典""" + """Convert to dictionary""" time_dict = asdict(self.time_config) return { "simulation_id": self.simulation_id, @@ -192,34 +192,34 @@ def to_dict(self) -> Dict[str, Any]: } def to_json(self, indent: int = 2) -> str: - """转换为JSON字符串""" + """Convert to JSON string""" return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent) class SimulationConfigGenerator: """ - 模拟配置智能生成器 - - 使用LLM分析模拟需求、文档内容、图谱实体信息, - 自动生成最佳的模拟参数配置 - - 采用分步生成策略: - 1. 生成时间配置和事件配置(轻量级) - 2. 分批生成Agent配置(每批10-20个) - 3. 生成平台配置 + Simulation Configuration Intelligent Generator + + Use LLM to analyze simulation requirements, document content, knowledge graph entity information, + and automatically generate optimal simulation parameter configuration + + Adopt step-by-step generation strategy: + 1. Generate time configuration and event configuration (lightweight) + 2. Generate agent configurations in batches (10-20 per batch) + 3. 
Generate platform configuration """ - - # 上下文最大字符数 + + # Maximum context length in characters MAX_CONTEXT_LENGTH = 50000 - # 每批生成的Agent数量 + # Number of agents per batch AGENTS_PER_BATCH = 15 - - # 各步骤的上下文截断长度(字符数) - TIME_CONFIG_CONTEXT_LENGTH = 10000 # 时间配置 - EVENT_CONFIG_CONTEXT_LENGTH = 8000 # 事件配置 - ENTITY_SUMMARY_LENGTH = 300 # 实体摘要 - AGENT_SUMMARY_LENGTH = 300 # Agent配置中的实体摘要 - ENTITIES_PER_TYPE_DISPLAY = 20 # 每类实体显示数量 + + # Context truncation length for each step (characters) + TIME_CONFIG_CONTEXT_LENGTH = 10000 # Time configuration + EVENT_CONFIG_CONTEXT_LENGTH = 8000 # Event configuration + ENTITY_SUMMARY_LENGTH = 300 # Entity summary + AGENT_SUMMARY_LENGTH = 300 # Entity summary in agent configuration + ENTITIES_PER_TYPE_DISPLAY = 20 # Number of entities to display per type def __init__( self, @@ -232,7 +232,7 @@ def __init__( self.model_name = model_name or Config.LLM_MODEL_NAME if not self.api_key: - raise ValueError("LLM_API_KEY 未配置") + raise ValueError("LLM_API_KEY not configured") self.client = OpenAI( api_key=self.api_key, @@ -252,27 +252,27 @@ def generate_config( progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> SimulationParameters: """ - 智能生成完整的模拟配置(分步生成) - + Intelligently generate complete simulation configuration (step-by-step generation) + Args: - simulation_id: 模拟ID - project_id: 项目ID - graph_id: 图谱ID - simulation_requirement: 模拟需求描述 - document_text: 原始文档内容 - entities: 过滤后的实体列表 - enable_twitter: 是否启用Twitter - enable_reddit: 是否启用Reddit - progress_callback: 进度回调函数(current_step, total_steps, message) - + simulation_id: Simulation ID + project_id: Project ID + graph_id: Knowledge graph ID + simulation_requirement: Simulation requirement description + document_text: Original document content + entities: Filtered entity list + enable_twitter: Whether to enable Twitter + enable_reddit: Whether to enable Reddit + progress_callback: Progress callback function(current_step, total_steps, message) + Returns: - 
SimulationParameters: 完整的模拟参数 + SimulationParameters: Complete simulation parameters """ - logger.info(f"开始智能生成模拟配置: simulation_id={simulation_id}, 实体数={len(entities)}") + logger.info(f"Starting intelligent simulation configuration generation: simulation_id={simulation_id}, entities={len(entities)}") - # 计算总步骤数 + # Calculate total steps num_batches = math.ceil(len(entities) / self.AGENTS_PER_BATCH) - total_steps = 3 + num_batches # 时间配置 + 事件配置 + N批Agent + 平台配置 + total_steps = 3 + num_batches # time config + event config + N batch agents + platform config current_step = 0 def report_progress(step: int, message: str): @@ -282,7 +282,7 @@ def report_progress(step: int, message: str): progress_callback(step, total_steps, message) logger.info(f"[{step}/{total_steps}] {message}") - # 1. 构建基础上下文信息 + # 1. Build basic context information context = self._build_context( simulation_requirement=simulation_requirement, document_text=document_text, @@ -291,20 +291,20 @@ def report_progress(step: int, message: str): reasoning_parts = [] - # ========== 步骤1: 生成时间配置 ========== - report_progress(1, "生成时间配置...") + # ========== Step 1: Generate time configuration ========== + report_progress(1, "Generating time configuration...") num_entities = len(entities) time_config_result = self._generate_time_config(context, num_entities) time_config = self._parse_time_config(time_config_result, num_entities) - reasoning_parts.append(f"时间配置: {time_config_result.get('reasoning', '成功')}") - - # ========== 步骤2: 生成事件配置 ========== - report_progress(2, "生成事件配置和热点话题...") + reasoning_parts.append(f"Time config: {time_config_result.get('reasoning', 'Success')}") + + # ========== Step 2: Generate event configuration ========== + report_progress(2, "Generating event configuration and hot topics...") event_config_result = self._generate_event_config(context, simulation_requirement, entities) event_config = self._parse_event_config(event_config_result) - reasoning_parts.append(f"事件配置: 
{event_config_result.get('reasoning', '成功')}") - - # ========== 步骤3-N: 分批生成Agent配置 ========== + reasoning_parts.append(f"Event config: {event_config_result.get('reasoning', 'Success')}") + + # ========== Step 3-N: Generate agent configurations in batches ========== all_agent_configs = [] for batch_idx in range(num_batches): start_idx = batch_idx * self.AGENTS_PER_BATCH @@ -313,7 +313,7 @@ def report_progress(step: int, message: str): report_progress( 3 + batch_idx, - f"生成Agent配置 ({start_idx + 1}-{end_idx}/{len(entities)})..." + f"Generating agent configuration ({start_idx + 1}-{end_idx}/{len(entities)})..." ) batch_configs = self._generate_agent_configs_batch( @@ -324,16 +324,16 @@ def report_progress(step: int, message: str): ) all_agent_configs.extend(batch_configs) - reasoning_parts.append(f"Agent配置: 成功生成 {len(all_agent_configs)} 个") - - # ========== 为初始帖子分配发布者 Agent ========== - logger.info("为初始帖子分配合适的发布者 Agent...") + reasoning_parts.append(f"Agent config: Successfully generated {len(all_agent_configs)}") + + # ========== Assign initial post agents ========== + logger.info("Assigning appropriate publisher agents to initial posts...") event_config = self._assign_initial_post_agents(event_config, all_agent_configs) assigned_count = len([p for p in event_config.initial_posts if p.get("poster_agent_id") is not None]) - reasoning_parts.append(f"初始帖子分配: {assigned_count} 个帖子已分配发布者") - - # ========== 最后一步: 生成平台配置 ========== - report_progress(total_steps, "生成平台配置...") + reasoning_parts.append(f"Initial posts assigned: {assigned_count} posts assigned publishers") + + # ========== Final step: Generate platform configuration ========== + report_progress(total_steps, "Generating platform configuration...") twitter_config = None reddit_config = None @@ -357,7 +357,7 @@ def report_progress(step: int, message: str): echo_chamber_strength=0.6 ) - # 构建最终参数 + # Build final parameters params = SimulationParameters( simulation_id=simulation_id, project_id=project_id, @@ -373,7 
+373,7 @@ def report_progress(step: int, message: str): generation_reasoning=" | ".join(reasoning_parts) ) - logger.info(f"模拟配置生成完成: {len(params.agent_configs)} 个Agent配置") + logger.info(f"Simulation configuration generation complete: {len(params.agent_configs)} agent configurations") return params @@ -383,33 +383,33 @@ def _build_context( document_text: str, entities: List[EntityNode] ) -> str: - """构建LLM上下文,截断到最大长度""" - - # 实体摘要 + """Build LLM context, truncate to maximum length""" + + # Entity summary entity_summary = self._summarize_entities(entities) - # 构建上下文 + # Build context context_parts = [ - f"## 模拟需求\n{simulation_requirement}", - f"\n## 实体信息 ({len(entities)}个)\n{entity_summary}", + f"## Simulation Requirements\n{simulation_requirement}", + f"\n## Entity Information ({len(entities)})\n{entity_summary}", ] current_length = sum(len(p) for p in context_parts) - remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 # 留500字符余量 + remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 # Reserve 500 characters if remaining_length > 0 and document_text: doc_text = document_text[:remaining_length] if len(document_text) > remaining_length: - doc_text += "\n...(文档已截断)" - context_parts.append(f"\n## 原始文档内容\n{doc_text}") + doc_text += "\n...(document truncated)" + context_parts.append(f"\n## Original Document Content\n{doc_text}") return "\n".join(context_parts) def _summarize_entities(self, entities: List[EntityNode]) -> str: - """生成实体摘要""" + """Generate entity summary""" lines = [] - # 按类型分组 + # Group by type by_type: Dict[str, List[EntityNode]] = {} for e in entities: t = e.get_entity_type() or "Unknown" @@ -418,20 +418,20 @@ def _summarize_entities(self, entities: List[EntityNode]) -> str: by_type[t].append(e) for entity_type, type_entities in by_type.items(): - lines.append(f"\n### {entity_type} ({len(type_entities)}个)") - # 使用配置的显示数量和摘要长度 + lines.append(f"\n### {entity_type} ({len(type_entities)})") + # Use configured display quantity and 
summary length display_count = self.ENTITIES_PER_TYPE_DISPLAY summary_len = self.ENTITY_SUMMARY_LENGTH for e in type_entities[:display_count]: summary_preview = (e.summary[:summary_len] + "...") if len(e.summary) > summary_len else e.summary lines.append(f"- {e.name}: {summary_preview}") if len(type_entities) > display_count: - lines.append(f" ... 还有 {len(type_entities) - display_count} 个") + lines.append(f" ... and {len(type_entities) - display_count} more") return "\n".join(lines) def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]: - """带重试的LLM调用,包含JSON修复逻辑""" + """LLM call with retry, including JSON repair logic""" import re max_attempts = 3 @@ -446,25 +446,25 @@ def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any {"role": "user", "content": prompt} ], response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 + temperature=0.7 - (attempt * 0.1) # Lower temperature with each retry + # Don't set max_tokens, let LLM generate freely ) content = response.choices[0].message.content finish_reason = response.choices[0].finish_reason - # 检查是否被截断 + # Check if output was truncated if finish_reason == 'length': - logger.warning(f"LLM输出被截断 (attempt {attempt+1})") + logger.warning(f"LLM output truncated (attempt {attempt+1})") content = self._fix_truncated_json(content) - # 尝试解析JSON + # Try to parse JSON try: return json.loads(content) except json.JSONDecodeError as e: - logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}") - - # 尝试修复JSON + logger.warning(f"JSON parsing failed (attempt {attempt+1}): {str(e)[:80]}") + + # Try to fix JSON fixed = self._try_fix_config_json(content) if fixed: return fixed @@ -472,44 +472,44 @@ def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any last_error = e except Exception as e: - logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") + logger.warning(f"LLM call failed (attempt 
{attempt+1}): {str(e)[:80]}") last_error = e import time time.sleep(2 * (attempt + 1)) - raise last_error or Exception("LLM调用失败") + raise last_error or Exception("LLM call failed") def _fix_truncated_json(self, content: str) -> str: - """修复被截断的JSON""" + """Fix truncated JSON""" content = content.strip() - # 计算未闭合的括号 + # Count unclosed parentheses open_braces = content.count('{') - content.count('}') open_brackets = content.count('[') - content.count(']') - # 检查是否有未闭合的字符串 + # Check for unclosed strings if content and content[-1] not in '",}]': content += '"' - # 闭合括号 + # Close parentheses content += ']' * open_brackets content += '}' * open_braces return content def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]: - """尝试修复配置JSON""" + """Try to fix configuration JSON""" import re - # 修复被截断的情况 + # Fix truncated case content = self._fix_truncated_json(content) - # 提取JSON部分 + # Extract JSON portion json_match = re.search(r'\{[\s\S]*\}', content) if json_match: json_str = json_match.group() - # 移除字符串中的换行符 + # Remove newlines in strings def fix_string(match): s = match.group(0) s = s.replace('\n', ' ').replace('\r', ' ') @@ -521,7 +521,7 @@ def fix_string(match): try: return json.loads(json_str) except: - # 尝试移除所有控制字符 + # Try removing all control characters json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) json_str = re.sub(r'\s+', ' ', json_str) try: @@ -532,35 +532,35 @@ def fix_string(match): return None def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]: - """生成时间配置""" - # 使用配置的上下文截断长度 + """Generate time configuration""" + # Use configured context truncation length context_truncated = context[:self.TIME_CONFIG_CONTEXT_LENGTH] - # 计算最大允许值(80%的agent数) + # Calculate maximum allowed value (90% of agents) max_agents_allowed = max(1, int(num_entities * 0.9)) - prompt = f"""基于以下模拟需求,生成时间模拟配置。 + prompt = f"""Based on the following simulation requirements, generate time simulation configuration. 
{context_truncated} -## 任务 -请生成时间配置JSON。 +## Task +Please generate time configuration JSON. -### 基本原则(仅供参考,需根据具体事件和参与群体灵活调整): -- 用户群体为中国人,需符合北京时间作息习惯 -- 凌晨0-5点几乎无人活动(活跃度系数0.05) -- 早上6-8点逐渐活跃(活跃度系数0.4) -- 工作时间9-18点中等活跃(活跃度系数0.7) -- 晚间19-22点是高峰期(活跃度系数1.5) -- 23点后活跃度下降(活跃度系数0.5) -- 一般规律:凌晨低活跃、早间渐增、工作时段中等、晚间高峰 -- **重要**:以下示例值仅供参考,你需要根据事件性质、参与群体特点来调整具体时段 - - 例如:学生群体高峰可能是21-23点;媒体全天活跃;官方机构只在工作时间 - - 例如:突发热点可能导致深夜也有讨论,off_peak_hours 可适当缩短 +### Basic principles (for reference only, adjust flexibly based on event nature and participant characteristics): +- User base is Chinese people, must follow Beijing Time work schedule habits +- 0-5am almost no activity (activity coefficient 0.05) +- 6-8am gradually active (activity coefficient 0.4) +- 9-18 work time moderately active (activity coefficient 0.7) +- 19-22 evening is peak period (activity coefficient 1.5) +- After 23 activity decreases (activity coefficient 0.5) +- General rule: low activity early morning, gradually increasing morning, moderate work time, evening peak +- **Important**: Example values below are for reference only, adjust specific time periods based on event nature and participant characteristics + - Example: student peak may be 21-23; media active all day; official institutions only during work hours + - Example: breaking news may cause late night discussions, off_peak_hours can be shortened appropriately -### 返回JSON格式(不要markdown) +### Return JSON format (no markdown) -示例: +Example: {{ "total_simulation_hours": 72, "minutes_per_round": 60, @@ -570,70 +570,70 @@ def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, An "off_peak_hours": [0, 1, 2, 3, 4, 5], "morning_hours": [6, 7, 8], "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], - "reasoning": "针对该事件的时间配置说明" + "reasoning": "Explanation of time configuration for this event" }} -字段说明: -- total_simulation_hours (int): 模拟总时长,24-168小时,突发事件短、持续话题长 -- minutes_per_round (int): 每轮时长,30-120分钟,建议60分钟 -- agents_per_hour_min (int): 
每小时最少激活Agent数(取值范围: 1-{max_agents_allowed}) -- agents_per_hour_max (int): 每小时最多激活Agent数(取值范围: 1-{max_agents_allowed}) -- peak_hours (int数组): 高峰时段,根据事件参与群体调整 -- off_peak_hours (int数组): 低谷时段,通常深夜凌晨 -- morning_hours (int数组): 早间时段 -- work_hours (int数组): 工作时段 -- reasoning (string): 简要说明为什么这样配置""" +Field description: +- total_simulation_hours (int): Total simulation time, 24-168 hours, short for breaking news, long for ongoing topics +- minutes_per_round (int): Time per round, 30-120 minutes, recommend 60 minutes +- agents_per_hour_min (int): Minimum agents activated per hour (range: 1-{max_agents_allowed}) +- agents_per_hour_max (int): Maximum agents activated per hour (range: 1-{max_agents_allowed}) +- peak_hours (int array): Peak hours, adjust based on event participants +- off_peak_hours (int array): Off-peak hours, usually late night/early morning +- morning_hours (int array): Morning hours +- work_hours (int array): Work hours +- reasoning (string): Brief explanation for this configuration""" - system_prompt = "你是社交媒体模拟专家。返回纯JSON格式,时间配置需符合中国人作息习惯。" + system_prompt = "You are a social media simulation expert. Return pure JSON format, time configuration must follow Chinese work schedule habits." 
try: return self._call_llm_with_retry(prompt, system_prompt) except Exception as e: - logger.warning(f"时间配置LLM生成失败: {e}, 使用默认配置") + logger.warning(f"Time config LLM generation failed: {e}, using default configuration") return self._get_default_time_config(num_entities) def _get_default_time_config(self, num_entities: int) -> Dict[str, Any]: - """获取默认时间配置(中国人作息)""" + """Get default time configuration (Chinese work schedule)""" return { "total_simulation_hours": 72, - "minutes_per_round": 60, # 每轮1小时,加快时间流速 + "minutes_per_round": 60, # 1 hour per round, speed up time "agents_per_hour_min": max(1, num_entities // 15), "agents_per_hour_max": max(5, num_entities // 5), "peak_hours": [19, 20, 21, 22], "off_peak_hours": [0, 1, 2, 3, 4, 5], "morning_hours": [6, 7, 8], "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], - "reasoning": "使用默认中国人作息配置(每轮1小时)" + "reasoning": "Using default Chinese work schedule configuration (1 hour per round)" } def _parse_time_config(self, result: Dict[str, Any], num_entities: int) -> TimeSimulationConfig: - """解析时间配置结果,并验证agents_per_hour值不超过总agent数""" - # 获取原始值 + """Parse time configuration result and verify agents_per_hour doesn't exceed total agents""" + # Get original values agents_per_hour_min = result.get("agents_per_hour_min", max(1, num_entities // 15)) agents_per_hour_max = result.get("agents_per_hour_max", max(5, num_entities // 5)) - # 验证并修正:确保不超过总agent数 + # Verify and correct: ensure not exceeding total agents if agents_per_hour_min > num_entities: - logger.warning(f"agents_per_hour_min ({agents_per_hour_min}) 超过总Agent数 ({num_entities}),已修正") + logger.warning(f"agents_per_hour_min ({agents_per_hour_min}) exceeds total agents ({num_entities}), corrected") agents_per_hour_min = max(1, num_entities // 10) if agents_per_hour_max > num_entities: - logger.warning(f"agents_per_hour_max ({agents_per_hour_max}) 超过总Agent数 ({num_entities}),已修正") + logger.warning(f"agents_per_hour_max ({agents_per_hour_max}) exceeds total agents 
({num_entities}), corrected") agents_per_hour_max = max(agents_per_hour_min + 1, num_entities // 2) - # 确保 min < max + # Ensure min < max if agents_per_hour_min >= agents_per_hour_max: agents_per_hour_min = max(1, agents_per_hour_max // 2) - logger.warning(f"agents_per_hour_min >= max,已修正为 {agents_per_hour_min}") + logger.warning(f"agents_per_hour_min >= max, corrected to {agents_per_hour_min}") return TimeSimulationConfig( total_simulation_hours=result.get("total_simulation_hours", 72), - minutes_per_round=result.get("minutes_per_round", 60), # 默认每轮1小时 + minutes_per_round=result.get("minutes_per_round", 60), # Default 1 hour per round agents_per_hour_min=agents_per_hour_min, agents_per_hour_max=agents_per_hour_max, peak_hours=result.get("peak_hours", [19, 20, 21, 22]), off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]), - off_peak_activity_multiplier=0.05, # 凌晨几乎无人 + off_peak_activity_multiplier=0.05, # Almost no one in early morning morning_hours=result.get("morning_hours", [6, 7, 8]), morning_activity_multiplier=0.4, work_hours=result.get("work_hours", list(range(9, 19))), @@ -647,14 +647,14 @@ def _generate_event_config( simulation_requirement: str, entities: List[EntityNode] ) -> Dict[str, Any]: - """生成事件配置""" - - # 获取可用的实体类型列表,供 LLM 参考 + """Generate event configuration""" + + # Get available entity types list for LLM reference entity_types_available = list(set( e.get_entity_type() or "Unknown" for e in entities )) - # 为每种类型列出代表性实体名称 + # List representative entity names for each type type_examples = {} for e in entities: etype = e.get_entity_type() or "Unknown" @@ -668,53 +668,53 @@ def _generate_event_config( for t, examples in type_examples.items() ]) - # 使用配置的上下文截断长度 + # Use configured context truncation length context_truncated = context[:self.EVENT_CONFIG_CONTEXT_LENGTH] - - prompt = f"""基于以下模拟需求,生成事件配置。 -模拟需求: {simulation_requirement} + prompt = f"""Based on the following simulation requirements, generate event configuration. 
+ +Simulation Requirements: {simulation_requirement} {context_truncated} -## 可用实体类型及示例 +## Available Entity Types and Examples {type_info} -## 任务 -请生成事件配置JSON: -- 提取热点话题关键词 -- 描述舆论发展方向 -- 设计初始帖子内容,**每个帖子必须指定 poster_type(发布者类型)** +## Task +Please generate event configuration JSON: +- Extract hot topic keywords +- Describe opinion development direction +- Design initial post content, **each post must specify poster_type (publisher type)** -**重要**: poster_type 必须从上面的"可用实体类型"中选择,这样初始帖子才能分配给合适的 Agent 发布。 -例如:官方声明应由 Official/University 类型发布,新闻由 MediaOutlet 发布,学生观点由 Student 发布。 +**Important**: poster_type must be selected from the "Available Entity Types" above so initial posts can be assigned to appropriate agents for publishing. +Example: Official statements should be published by Official/University type, news by MediaOutlet, student opinions by Student type. -返回JSON格式(不要markdown): +Return JSON format (no markdown): {{ - "hot_topics": ["关键词1", "关键词2", ...], - "narrative_direction": "<舆论发展方向描述>", + "hot_topics": ["keyword1", "keyword2", ...], + "narrative_direction": "", "initial_posts": [ - {{"content": "帖子内容", "poster_type": "实体类型(必须从可用类型中选择)"}}, + {{"content": "post content", "poster_type": "entity type (must select from available types)"}}, ... ], - "reasoning": "<简要说明>" + "reasoning": "" }}""" - system_prompt = "你是舆论分析专家。返回纯JSON格式。注意 poster_type 必须精确匹配可用实体类型。" + system_prompt = "You are an opinion analysis expert. Return pure JSON format. Note poster_type must match available entity types precisely." 
try: return self._call_llm_with_retry(prompt, system_prompt) except Exception as e: - logger.warning(f"事件配置LLM生成失败: {e}, 使用默认配置") + logger.warning(f"Event config LLM generation failed: {e}, using default configuration") return { "hot_topics": [], "narrative_direction": "", "initial_posts": [], - "reasoning": "使用默认配置" + "reasoning": "Using default configuration" } def _parse_event_config(self, result: Dict[str, Any]) -> EventConfig: - """解析事件配置结果""" + """Parse event configuration result""" return EventConfig( initial_posts=result.get("initial_posts", []), scheduled_events=[], @@ -728,14 +728,14 @@ def _assign_initial_post_agents( agent_configs: List[AgentActivityConfig] ) -> EventConfig: """ - 为初始帖子分配合适的发布者 Agent - - 根据每个帖子的 poster_type 匹配最合适的 agent_id + Assign appropriate publisher agents to initial posts + + Match agent_id based on each post's poster_type """ if not event_config.initial_posts: return event_config - # 按实体类型建立 agent 索引 + # Build agent index by entity type agents_by_type: Dict[str, List[AgentActivityConfig]] = {} for agent in agent_configs: etype = agent.entity_type.lower() @@ -743,7 +743,7 @@ def _assign_initial_post_agents( agents_by_type[etype] = [] agents_by_type[etype].append(agent) - # 类型映射表(处理 LLM 可能输出的不同格式) + # Type mapping table (handle different formats LLM might output) type_aliases = { "official": ["official", "university", "governmentagency", "government"], "university": ["university", "official"], @@ -755,7 +755,7 @@ def _assign_initial_post_agents( "person": ["person", "student", "alumni"], } - # 记录每种类型已使用的 agent 索引,避免重复使用同一个 agent + # Track used agent indices for each type to avoid reusing same agent used_indices: Dict[str, int] = {} updated_posts = [] @@ -763,17 +763,17 @@ def _assign_initial_post_agents( poster_type = post.get("poster_type", "").lower() content = post.get("content", "") - # 尝试找到匹配的 agent + # Try to find matching agent matched_agent_id = None - # 1. 直接匹配 + # 1. 
Direct match if poster_type in agents_by_type: agents = agents_by_type[poster_type] idx = used_indices.get(poster_type, 0) % len(agents) matched_agent_id = agents[idx].agent_id used_indices[poster_type] = idx + 1 else: - # 2. 使用别名匹配 + # 2. Match using aliases for alias_key, aliases in type_aliases.items(): if poster_type in aliases or alias_key == poster_type: for alias in aliases: @@ -786,11 +786,11 @@ def _assign_initial_post_agents( if matched_agent_id is not None: break - # 3. 如果仍未找到,使用影响力最高的 agent + # 3. If still not found, use agent with highest influence if matched_agent_id is None: - logger.warning(f"未找到类型 '{poster_type}' 的匹配 Agent,使用影响力最高的 Agent") + logger.warning(f"No matching agent found for type '{poster_type}', using agent with highest influence") if agent_configs: - # 按影响力排序,选择影响力最高的 + # Sort by influence, select highest sorted_agents = sorted(agent_configs, key=lambda a: a.influence_weight, reverse=True) matched_agent_id = sorted_agents[0].agent_id else: @@ -802,7 +802,7 @@ def _assign_initial_post_agents( "poster_agent_id": matched_agent_id }) - logger.info(f"初始帖子分配: poster_type='{poster_type}' -> agent_id={matched_agent_id}") + logger.info(f"Initial post assigned: poster_type='{poster_type}' -> agent_id={matched_agent_id}") event_config.initial_posts = updated_posts return event_config @@ -814,9 +814,9 @@ def _generate_agent_configs_batch( start_idx: int, simulation_requirement: str ) -> List[AgentActivityConfig]: - """分批生成Agent配置""" - - # 构建实体信息(使用配置的摘要长度) + """Generate agent configurations in batch""" + + # Build entity information (using configured summary length) entity_list = [] summary_len = self.AGENT_SUMMARY_LENGTH for i, e in enumerate(entities): @@ -827,58 +827,58 @@ def _generate_agent_configs_batch( "summary": e.summary[:summary_len] if e.summary else "" }) - prompt = f"""基于以下信息,为每个实体生成社交媒体活动配置。 + prompt = f"""Based on the following information, generate social media activity configuration for each entity. 
-模拟需求: {simulation_requirement} +Simulation Requirements: {simulation_requirement} -## 实体列表 +## Entity List ```json {json.dumps(entity_list, ensure_ascii=False, indent=2)} ``` -## 任务 -为每个实体生成活动配置,注意: -- **时间符合中国人作息**:凌晨0-5点几乎不活动,晚间19-22点最活跃 -- **官方机构**(University/GovernmentAgency):活跃度低(0.1-0.3),工作时间(9-17)活动,响应慢(60-240分钟),影响力高(2.5-3.0) -- **媒体**(MediaOutlet):活跃度中(0.4-0.6),全天活动(8-23),响应快(5-30分钟),影响力高(2.0-2.5) -- **个人**(Student/Person/Alumni):活跃度高(0.6-0.9),主要晚间活动(18-23),响应快(1-15分钟),影响力低(0.8-1.2) -- **公众人物/专家**:活跃度中(0.4-0.6),影响力中高(1.5-2.0) +## Task +Generate activity configuration for each entity, noting: +- **Time follows Chinese work schedule**: Almost no activity 0-5am, most active 19-22 +- **Official institutions** (University/GovernmentAgency): Low activity (0.1-0.3), active during work hours (9-17), slow response (60-240 min), high influence (2.5-3.0) +- **Media** (MediaOutlet): Medium activity (0.4-0.6), active all day (8-23), fast response (5-30 min), high influence (2.0-2.5) +- **Individuals** (Student/Person/Alumni): High activity (0.6-0.9), mainly evening activity (18-23), fast response (1-15 min), low influence (0.8-1.2) +- **Public figures/Experts**: Medium activity (0.4-0.6), medium-high influence (1.5-2.0) -返回JSON格式(不要markdown): +Return JSON format (no markdown): {{ "agent_configs": [ {{ - "agent_id": <必须与输入一致>, + "agent_id": , "activity_level": <0.0-1.0>, - "posts_per_hour": <发帖频率>, - "comments_per_hour": <评论频率>, - "active_hours": [<活跃小时列表,考虑中国人作息>], - "response_delay_min": <最小响应延迟分钟>, - "response_delay_max": <最大响应延迟分钟>, - "sentiment_bias": <-1.0到1.0>, + "posts_per_hour": , + "comments_per_hour": , + "active_hours": [], + "response_delay_min": , + "response_delay_max": , + "sentiment_bias": <-1.0 to 1.0>, "stance": "", - "influence_weight": <影响力权重> + "influence_weight": }}, ... ] }}""" - system_prompt = "你是社交媒体行为分析专家。返回纯JSON,配置需符合中国人作息习惯。" + system_prompt = "You are a social media behavior analysis expert. 
Return pure JSON, configuration must follow Chinese work schedule habits." try: result = self._call_llm_with_retry(prompt, system_prompt) llm_configs = {cfg["agent_id"]: cfg for cfg in result.get("agent_configs", [])} except Exception as e: - logger.warning(f"Agent配置批次LLM生成失败: {e}, 使用规则生成") + logger.warning(f"Agent config batch LLM generation failed: {e}, using rule-based generation") llm_configs = {} - # 构建AgentActivityConfig对象 + # Build AgentActivityConfig objects configs = [] for i, entity in enumerate(entities): agent_id = start_idx + i cfg = llm_configs.get(agent_id, {}) - # 如果LLM没有生成,使用规则生成 + # If LLM didn't generate, use rule-based generation if not cfg: cfg = self._generate_agent_config_by_rule(entity) @@ -902,11 +902,11 @@ def _generate_agent_configs_batch( return configs def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: - """基于规则生成单个Agent配置(中国人作息)""" + """Generate single agent configuration based on rules (Chinese work schedule)""" entity_type = (entity.get_entity_type() or "Unknown").lower() if entity_type in ["university", "governmentagency", "ngo"]: - # 官方机构:工作时间活动,低频率,高影响力 + # Official institutions: work hour activity, low frequency, high influence return { "activity_level": 0.2, "posts_per_hour": 0.1, @@ -919,7 +919,7 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 3.0 } elif entity_type in ["mediaoutlet"]: - # 媒体:全天活动,中等频率,高影响力 + # Media: all-day activity, medium frequency, high influence return { "activity_level": 0.5, "posts_per_hour": 0.8, @@ -932,7 +932,7 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 2.5 } elif entity_type in ["professor", "expert", "official"]: - # 专家/教授:工作+晚间活动,中等频率 + # Experts/Professors: work + evening activity, medium frequency return { "activity_level": 0.4, "posts_per_hour": 0.3, @@ -945,12 +945,12 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: 
"influence_weight": 2.0 } elif entity_type in ["student"]: - # 学生:晚间为主,高频率 + # Students: mainly evening, high frequency return { "activity_level": 0.8, "posts_per_hour": 0.6, "comments_per_hour": 1.5, - "active_hours": [8, 9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # 上午+晚间 + "active_hours": [8, 9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # Morning + evening "response_delay_min": 1, "response_delay_max": 15, "sentiment_bias": 0.0, @@ -958,12 +958,12 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 0.8 } elif entity_type in ["alumni"]: - # 校友:晚间为主 + # Alumni: mainly evening return { "activity_level": 0.6, "posts_per_hour": 0.4, "comments_per_hour": 0.8, - "active_hours": [12, 13, 19, 20, 21, 22, 23], # 午休+晚间 + "active_hours": [12, 13, 19, 20, 21, 22, 23], # Lunch break + evening "response_delay_min": 5, "response_delay_max": 30, "sentiment_bias": 0.0, @@ -971,12 +971,12 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 1.0 } else: - # 普通人:晚间高峰 + # Ordinary people: evening peak return { "activity_level": 0.7, "posts_per_hour": 0.5, "comments_per_hour": 1.2, - "active_hours": [9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # 白天+晚间 + "active_hours": [9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # Daytime + evening "response_delay_min": 2, "response_delay_max": 20, "sentiment_bias": 0.0, diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py index 96c496fd4..569d8c350 100644 --- a/backend/app/services/simulation_manager.py +++ b/backend/app/services/simulation_manager.py @@ -1,7 +1,7 @@ """ -OASIS模拟管理器 -管理Twitter和Reddit双平台并行模拟 -使用预设脚本 + LLM智能生成配置参数 +OASIS Simulation Manager +Manage Twitter and Reddit dual-platform parallel simulation +Using preset scripts + LLM intelligent configuration parameter generation """ import os @@ -22,60 +22,60 @@ class SimulationStatus(str, Enum): - """模拟状态""" + """Simulation status""" CREATED = 
"created" PREPARING = "preparing" READY = "ready" RUNNING = "running" PAUSED = "paused" - STOPPED = "stopped" # 模拟被手动停止 - COMPLETED = "completed" # 模拟自然完成 + STOPPED = "stopped" # Simulation manually stopped + COMPLETED = "completed" # Simulation completed naturally FAILED = "failed" class PlatformType(str, Enum): - """平台类型""" + """Platform type""" TWITTER = "twitter" REDDIT = "reddit" @dataclass class SimulationState: - """模拟状态""" + """Simulation state""" simulation_id: str project_id: str graph_id: str - - # 平台启用状态 + + # Platform enable status enable_twitter: bool = True enable_reddit: bool = True - # 状态 + # Status status: SimulationStatus = SimulationStatus.CREATED - - # 准备阶段数据 + + # Preparation stage data entities_count: int = 0 profiles_count: int = 0 entity_types: List[str] = field(default_factory=list) - # 配置生成信息 + # Configuration generation info config_generated: bool = False config_reasoning: str = "" - # 运行时数据 + # Runtime data current_round: int = 0 twitter_status: str = "not_started" reddit_status: str = "not_started" - # 时间戳 + # Timestamps created_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) - # 错误信息 + # Error information error: Optional[str] = None - + def to_dict(self) -> Dict[str, Any]: - """完整状态字典(内部使用)""" + """Full state dictionary (internal use)""" return { "simulation_id": self.simulation_id, "project_id": self.project_id, @@ -97,7 +97,7 @@ def to_dict(self) -> Dict[str, Any]: } def to_simple_dict(self) -> Dict[str, Any]: - """简化状态字典(API返回使用)""" + """Simplified state dictionary (for API response)""" return { "simulation_id": self.simulation_id, "project_id": self.project_id, @@ -113,36 +113,36 @@ def to_simple_dict(self) -> Dict[str, Any]: class SimulationManager: """ - 模拟管理器 - - 核心功能: - 1. 从Zep图谱读取实体并过滤 - 2. 生成OASIS Agent Profile - 3. 使用LLM智能生成模拟配置参数 - 4. 准备预设脚本所需的所有文件 + Simulation Manager + + Core features: + 1. 
Read and filter entities from Zep knowledge graph + 2. Generate OASIS Agent Profiles + 3. Use LLM to intelligently generate simulation configuration parameters + 4. Prepare all files needed for preset scripts """ - - # 模拟数据存储目录 + + # Simulation data storage directory SIMULATION_DATA_DIR = os.path.join( os.path.dirname(__file__), '../../uploads/simulations' ) def __init__(self): - # 确保目录存在 + # Ensure directory exists os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True) - # 内存中的模拟状态缓存 + # In-memory simulation state cache self._simulations: Dict[str, SimulationState] = {} def _get_simulation_dir(self, simulation_id: str) -> str: - """获取模拟数据目录""" + """Get simulation data directory""" sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id) os.makedirs(sim_dir, exist_ok=True) return sim_dir def _save_simulation_state(self, state: SimulationState): - """保存模拟状态到文件""" + """Save simulation state to file""" sim_dir = self._get_simulation_dir(state.simulation_id) state_file = os.path.join(sim_dir, "state.json") @@ -154,7 +154,7 @@ def _save_simulation_state(self, state: SimulationState): self._simulations[state.simulation_id] = state def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]: - """从文件加载模拟状态""" + """Load simulation state from file""" if simulation_id in self._simulations: return self._simulations[simulation_id] @@ -198,13 +198,13 @@ def create_simulation( enable_reddit: bool = True, ) -> SimulationState: """ - 创建新的模拟 + Create new simulation Args: - project_id: 项目ID - graph_id: Zep图谱ID - enable_twitter: 是否启用Twitter模拟 - enable_reddit: 是否启用Reddit模拟 + project_id: Project ID + graph_id: Zep knowledge graph ID + enable_twitter: Whether to enable Twitter simulation + enable_reddit: Whether to enable Reddit simulation Returns: SimulationState @@ -222,7 +222,7 @@ def create_simulation( ) self._save_simulation_state(state) - logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}") + logger.info(f"Created simulation: 
{simulation_id}, project={project_id}, graph={graph_id}") return state @@ -237,30 +237,30 @@ def prepare_simulation( parallel_profile_count: int = 3 ) -> SimulationState: """ - 准备模拟环境(全程自动化) + Prepare simulation environment (fully automated) - 步骤: - 1. 从Zep图谱读取并过滤实体 - 2. 为每个实体生成OASIS Agent Profile(可选LLM增强,支持并行) - 3. 使用LLM智能生成模拟配置参数(时间、活跃度、发言频率等) - 4. 保存配置文件和Profile文件 - 5. 复制预设脚本到模拟目录 + Steps: + 1. Read and filter entities from Zep knowledge graph + 2. Generate OASIS Agent Profile for each entity (optional LLM enhancement, parallel support) + 3. Use LLM to intelligently generate simulation config parameters (time, activity, posting frequency, etc.) + 4. Save configuration and Profile files + 5. Copy preset scripts to simulation directory Args: - simulation_id: 模拟ID - simulation_requirement: 模拟需求描述(用于LLM生成配置) - document_text: 原始文档内容(用于LLM理解背景) - defined_entity_types: 预定义的实体类型(可选) - use_llm_for_profiles: 是否使用LLM生成详细人设 - progress_callback: 进度回调函数 (stage, progress, message) - parallel_profile_count: 并行生成人设的数量,默认3 + simulation_id: Simulation ID + simulation_requirement: Simulation requirement description (for LLM configuration generation) + document_text: Original document content (for LLM background understanding) + defined_entity_types: Predefined entity types (optional) + use_llm_for_profiles: Whether to use LLM to generate detailed personas + progress_callback: Progress callback function (stage, progress, message) + parallel_profile_count: Number of parallel persona generations, default 3 Returns: SimulationState """ state = self._load_simulation_state(simulation_id) if not state: - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation not found: {simulation_id}") try: state.status = SimulationStatus.PREPARING @@ -268,14 +268,14 @@ def prepare_simulation( sim_dir = self._get_simulation_dir(simulation_id) - # ========== 阶段1: 读取并过滤实体 ========== + # ========== Stage 1: Read and filter entities ========== if progress_callback: - 
progress_callback("reading", 0, "正在连接Zep图谱...") + progress_callback("reading", 0, "Connecting to Zep knowledge graph...") reader = ZepEntityReader() if progress_callback: - progress_callback("reading", 30, "正在读取节点数据...") + progress_callback("reading", 30, "Reading node data...") filtered = reader.filter_defined_entities( graph_id=state.graph_id, @@ -289,29 +289,29 @@ def prepare_simulation( if progress_callback: progress_callback( "reading", 100, - f"完成,共 {filtered.filtered_count} 个实体", + f"Complete, {filtered.filtered_count} entities found", current=filtered.filtered_count, total=filtered.filtered_count ) if filtered.filtered_count == 0: state.status = SimulationStatus.FAILED - state.error = "没有找到符合条件的实体,请检查图谱是否正确构建" + state.error = "No matching entities found, please check if the knowledge graph is correctly built" self._save_simulation_state(state) return state - # ========== 阶段2: 生成Agent Profile ========== + # ========== Stage 2: Generate Agent Profiles ========== total_entities = len(filtered.entities) if progress_callback: progress_callback( "generating_profiles", 0, - "开始生成...", + "Starting generation...", current=0, total=total_entities ) - # 传入graph_id以启用Zep检索功能,获取更丰富的上下文 + # Pass graph_id to enable Zep retrieval for richer context generator = OasisProfileGenerator(graph_id=state.graph_id) def profile_progress(current, total, msg): @@ -325,7 +325,7 @@ def profile_progress(current, total, msg): item_name=msg ) - # 设置实时保存的文件路径(优先使用 Reddit JSON 格式) + # Set real-time save file path (prefer Reddit JSON format) realtime_output_path = None realtime_platform = "reddit" if state.enable_reddit: @@ -339,20 +339,20 @@ def profile_progress(current, total, msg): entities=filtered.entities, use_llm=use_llm_for_profiles, progress_callback=profile_progress, - graph_id=state.graph_id, # 传入graph_id用于Zep检索 - parallel_count=parallel_profile_count, # 并行生成数量 - realtime_output_path=realtime_output_path, # 实时保存路径 - output_platform=realtime_platform # 输出格式 + 
graph_id=state.graph_id, # Pass graph_id for Zep retrieval + parallel_count=parallel_profile_count, # Parallel generation count + realtime_output_path=realtime_output_path, # Real-time save path + output_platform=realtime_platform # Output format ) state.profiles_count = len(profiles) - # 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式) - # Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性 + # Save Profile files (note: Twitter uses CSV format, Reddit uses JSON format) + # Reddit was already saved in real-time during generation, save again to ensure completeness if progress_callback: progress_callback( "generating_profiles", 95, - "保存Profile文件...", + "Saving Profile files...", current=total_entities, total=total_entities ) @@ -365,7 +365,7 @@ def profile_progress(current, total, msg): ) if state.enable_twitter: - # Twitter使用CSV格式!这是OASIS的要求 + # Twitter uses CSV format! This is an OASIS requirement generator.save_profiles( profiles=profiles, file_path=os.path.join(sim_dir, "twitter_profiles.csv"), @@ -375,16 +375,16 @@ def profile_progress(current, total, msg): if progress_callback: progress_callback( "generating_profiles", 100, - f"完成,共 {len(profiles)} 个Profile", + f"Complete, {len(profiles)} Profiles generated", current=len(profiles), total=len(profiles) ) - # ========== 阶段3: LLM智能生成模拟配置 ========== + # ========== Stage 3: LLM intelligent simulation configuration generation ========== if progress_callback: progress_callback( "generating_config", 0, - "正在分析模拟需求...", + "Analyzing simulation requirements...", current=0, total=3 ) @@ -394,7 +394,7 @@ def profile_progress(current, total, msg): if progress_callback: progress_callback( "generating_config", 30, - "正在调用LLM生成配置...", + "Calling LLM to generate configuration...", current=1, total=3 ) @@ -413,12 +413,12 @@ def profile_progress(current, total, msg): if progress_callback: progress_callback( "generating_config", 70, - "正在保存配置文件...", + "Saving configuration files...", current=2, total=3 ) - # 保存配置文件 + # Save configuration files 
config_path = os.path.join(sim_dir, "simulation_config.json") with open(config_path, 'w', encoding='utf-8') as f: f.write(sim_params.to_json()) @@ -429,25 +429,25 @@ def profile_progress(current, total, msg): if progress_callback: progress_callback( "generating_config", 100, - "配置生成完成", + "Configuration generation complete", current=3, total=3 ) - # 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录 - # 启动模拟时,simulation_runner 会从 scripts/ 目录运行脚本 + # Note: Run scripts remain in backend/scripts/ directory, no longer copied to simulation directory + # When starting simulation, simulation_runner will run scripts from scripts/ directory - # 更新状态 + # Update status state.status = SimulationStatus.READY self._save_simulation_state(state) - logger.info(f"模拟准备完成: {simulation_id}, " + logger.info(f"Simulation preparation complete: {simulation_id}, " f"entities={state.entities_count}, profiles={state.profiles_count}") return state except Exception as e: - logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}") + logger.error(f"Simulation preparation failed: {simulation_id}, error={str(e)}") import traceback logger.error(traceback.format_exc()) state.status = SimulationStatus.FAILED @@ -456,16 +456,16 @@ def profile_progress(current, total, msg): raise def get_simulation(self, simulation_id: str) -> Optional[SimulationState]: - """获取模拟状态""" + """Get simulation status""" return self._load_simulation_state(simulation_id) def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]: - """列出所有模拟""" + """List all simulations""" simulations = [] if os.path.exists(self.SIMULATION_DATA_DIR): for sim_id in os.listdir(self.SIMULATION_DATA_DIR): - # 跳过隐藏文件(如 .DS_Store)和非目录文件 + # Skip hidden files (e.g. 
.DS_Store) and non-directory files sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id) if sim_id.startswith('.') or not os.path.isdir(sim_path): continue @@ -478,10 +478,10 @@ def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationS return simulations def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]: - """获取模拟的Agent Profile""" + """Get simulation Agent Profiles""" state = self._load_simulation_state(simulation_id) if not state: - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation not found: {simulation_id}") sim_dir = self._get_simulation_dir(simulation_id) profile_path = os.path.join(sim_dir, f"{platform}_profiles.json") @@ -493,7 +493,7 @@ def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dic return json.load(f) def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]: - """获取模拟配置""" + """Get simulation configuration""" sim_dir = self._get_simulation_dir(simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") @@ -504,7 +504,7 @@ def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]: return json.load(f) def get_run_instructions(self, simulation_id: str) -> Dict[str, str]: - """获取运行说明""" + """Get run instructions""" sim_dir = self._get_simulation_dir(simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts')) @@ -519,10 +519,10 @@ def get_run_instructions(self, simulation_id: str) -> Dict[str, str]: "parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}", }, "instructions": ( - f"1. 激活conda环境: conda activate MiroFish\n" - f"2. 
运行模拟 (脚本位于 {scripts_dir}):\n" - f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n" - f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n" - f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}" + f"1. Activate conda environment: conda activate MiroFish\n" + f"2. Run simulation (scripts located at {scripts_dir}):\n" + f" - Run Twitter only: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n" + f" - Run Reddit only: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n" + f" - Run both platforms in parallel: python {scripts_dir}/run_parallel_simulation.py --config {config_path}" ) } diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index 8c35380d1..9794a674e 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -1,6 +1,6 @@ """ -OASIS模拟运行器 -在后台运行模拟并记录每个Agent的动作,支持实时状态监控 +OASIS Simulation Runner +Run simulation in background and record each Agent's actions, with real-time status monitoring """ import os @@ -25,15 +25,15 @@ logger = get_logger('mirofish.simulation_runner') -# 标记是否已注册清理函数 +# Flag whether cleanup function has been registered _cleanup_registered = False -# 平台检测 +# Platform detection IS_WINDOWS = sys.platform == 'win32' class RunnerStatus(str, Enum): - """运行器状态""" + """Runner status""" IDLE = "idle" STARTING = "starting" RUNNING = "running" @@ -46,7 +46,7 @@ class RunnerStatus(str, Enum): @dataclass class AgentAction: - """Agent动作记录""" + """Agent action record""" round_num: int timestamp: str platform: str # twitter / reddit @@ -73,7 +73,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class RoundSummary: - """每轮摘要""" + """Round summary""" round_num: int start_time: str end_time: Optional[str] = None @@ -99,52 +99,52 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class SimulationRunState: - 
"""模拟运行状态(实时)""" + """Simulation run state (real-time)""" simulation_id: str runner_status: RunnerStatus = RunnerStatus.IDLE - # 进度信息 + # Progress information current_round: int = 0 total_rounds: int = 0 simulated_hours: int = 0 total_simulation_hours: int = 0 - # 各平台独立轮次和模拟时间(用于双平台并行显示) + # Per-platform independent rounds and simulation time (for dual-platform parallel display) twitter_current_round: int = 0 reddit_current_round: int = 0 twitter_simulated_hours: int = 0 reddit_simulated_hours: int = 0 - # 平台状态 + # Platform status twitter_running: bool = False reddit_running: bool = False twitter_actions_count: int = 0 reddit_actions_count: int = 0 - # 平台完成状态(通过检测 actions.jsonl 中的 simulation_end 事件) + # Platform completion status (by detecting simulation_end event in actions.jsonl) twitter_completed: bool = False reddit_completed: bool = False - # 每轮摘要 + # Round summaries rounds: List[RoundSummary] = field(default_factory=list) - # 最近动作(用于前端实时展示) + # Recent actions (for frontend real-time display) recent_actions: List[AgentAction] = field(default_factory=list) max_recent_actions: int = 50 - # 时间戳 + # Timestamps started_at: Optional[str] = None updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) completed_at: Optional[str] = None - # 错误信息 + # Error information error: Optional[str] = None - # 进程ID(用于停止) + # Process ID (for stopping) process_pid: Optional[int] = None def add_action(self, action: AgentAction): - """添加动作到最近动作列表""" + """Add action to recent actions list""" self.recent_actions.insert(0, action) if len(self.recent_actions) > self.max_recent_actions: self.recent_actions = self.recent_actions[:self.max_recent_actions] @@ -165,7 +165,7 @@ def to_dict(self) -> Dict[str, Any]: "simulated_hours": self.simulated_hours, "total_simulation_hours": self.total_simulation_hours, "progress_percent": round(self.current_round / max(self.total_rounds, 1) * 100, 1), - # 各平台独立轮次和时间 + # Per-platform independent rounds and time "twitter_current_round": 
self.twitter_current_round, "reddit_current_round": self.reddit_current_round, "twitter_simulated_hours": self.twitter_simulated_hours, @@ -185,7 +185,7 @@ def to_dict(self) -> Dict[str, Any]: } def to_detail_dict(self) -> Dict[str, Any]: - """包含最近动作的详细信息""" + """Detailed information including recent actions""" result = self.to_dict() result["recent_actions"] = [a.to_dict() for a in self.recent_actions] result["rounds_count"] = len(self.rounds) @@ -194,45 +194,45 @@ def to_detail_dict(self) -> Dict[str, Any]: class SimulationRunner: """ - 模拟运行器 + Simulation Runner - 负责: - 1. 在后台进程中运行OASIS模拟 - 2. 解析运行日志,记录每个Agent的动作 - 3. 提供实时状态查询接口 - 4. 支持暂停/停止/恢复操作 + Responsibilities: + 1. Run OASIS simulation in background process + 2. Parse run logs, record each Agent's actions + 3. Provide real-time status query interface + 4. Support pause/stop/resume operations """ - # 运行状态存储目录 + # Run state storage directory RUN_STATE_DIR = os.path.join( os.path.dirname(__file__), '../../uploads/simulations' ) - # 脚本目录 + # Scripts directory SCRIPTS_DIR = os.path.join( os.path.dirname(__file__), '../../scripts' ) - # 内存中的运行状态 + # In-memory run states _run_states: Dict[str, SimulationRunState] = {} _processes: Dict[str, subprocess.Popen] = {} _action_queues: Dict[str, Queue] = {} _monitor_threads: Dict[str, threading.Thread] = {} - _stdout_files: Dict[str, Any] = {} # 存储 stdout 文件句柄 - _stderr_files: Dict[str, Any] = {} # 存储 stderr 文件句柄 + _stdout_files: Dict[str, Any] = {} # Store stdout file handles + _stderr_files: Dict[str, Any] = {} # Store stderr file handles - # 图谱记忆更新配置 + # Graph memory update configuration _graph_memory_enabled: Dict[str, bool] = {} # simulation_id -> enabled @classmethod def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: - """获取运行状态""" + """Get run state""" if simulation_id in cls._run_states: return cls._run_states[simulation_id] - # 尝试从文件加载 + # Try to load from file state = cls._load_run_state(simulation_id) if state: 
cls._run_states[simulation_id] = state @@ -240,7 +240,7 @@ def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: @classmethod def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: - """从文件加载运行状态""" + """Load run state from file""" state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json") if not os.path.exists(state_file): return None @@ -256,7 +256,7 @@ def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: total_rounds=data.get("total_rounds", 0), simulated_hours=data.get("simulated_hours", 0), total_simulation_hours=data.get("total_simulation_hours", 0), - # 各平台独立轮次和时间 + # Per-platform independent rounds and time twitter_current_round=data.get("twitter_current_round", 0), reddit_current_round=data.get("reddit_current_round", 0), twitter_simulated_hours=data.get("twitter_simulated_hours", 0), @@ -274,7 +274,7 @@ def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: process_pid=data.get("process_pid"), ) - # 加载最近动作 + # Load recent actions actions_data = data.get("recent_actions", []) for a in actions_data: state.recent_actions.append(AgentAction( @@ -291,12 +291,12 @@ def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: return state except Exception as e: - logger.error(f"加载运行状态失败: {str(e)}") + logger.error(f"Failed to load run state: {str(e)}") return None @classmethod def _save_run_state(cls, state: SimulationRunState): - """保存运行状态到文件""" + """Save run state to file""" sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id) os.makedirs(sim_dir, exist_ok=True) state_file = os.path.join(sim_dir, "run_state.json") @@ -313,50 +313,50 @@ def start_simulation( cls, simulation_id: str, platform: str = "parallel", # twitter / reddit / parallel - max_rounds: int = None, # 最大模拟轮数(可选,用于截断过长的模拟) - enable_graph_memory_update: bool = False, # 是否将活动更新到Zep图谱 - graph_id: str = None # Zep图谱ID(启用图谱更新时必需) + max_rounds: int = None, # Maximum 
simulation rounds (optional, to truncate long simulations) + enable_graph_memory_update: bool = False, # Whether to update activities to Zep knowledge graph + graph_id: str = None # Zep graph ID (required when graph update is enabled) ) -> SimulationRunState: """ - 启动模拟 + Start simulation Args: - simulation_id: 模拟ID - platform: 运行平台 (twitter/reddit/parallel) - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) - enable_graph_memory_update: 是否将Agent活动动态更新到Zep图谱 - graph_id: Zep图谱ID(启用图谱更新时必需) + simulation_id: Simulation ID + platform: Running platform (twitter/reddit/parallel) + max_rounds: Maximum simulation rounds (optional, to truncate long simulations) + enable_graph_memory_update: Whether to dynamically update Agent activities to Zep knowledge graph + graph_id: Zep graph ID (required when graph update is enabled) Returns: SimulationRunState """ - # 检查是否已在运行 + # Check if already running existing = cls.get_run_state(simulation_id) if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]: - raise ValueError(f"模拟已在运行中: {simulation_id}") + raise ValueError(f"Simulation already running: {simulation_id}") - # 加载模拟配置 + # Load simulation configuration sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") if not os.path.exists(config_path): - raise ValueError(f"模拟配置不存在,请先调用 /prepare 接口") + raise ValueError(f"Simulation configuration not found, please call /prepare endpoint first") with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) - # 初始化运行状态 + # Initialize run state time_config = config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = int(total_hours * 60 / minutes_per_round) - # 如果指定了最大轮数,则截断 + # If max rounds specified, truncate if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if 
total_rounds < original_rounds: - logger.info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") + logger.info(f"Rounds truncated: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") state = SimulationRunState( simulation_id=simulation_id, @@ -368,22 +368,22 @@ def start_simulation( cls._save_run_state(state) - # 如果启用图谱记忆更新,创建更新器 + # If graph memory update enabled, create updater if enable_graph_memory_update: if not graph_id: - raise ValueError("启用图谱记忆更新时必须提供 graph_id") + raise ValueError("graph_id must be provided when graph memory update is enabled") try: ZepGraphMemoryManager.create_updater(simulation_id, graph_id) cls._graph_memory_enabled[simulation_id] = True - logger.info(f"已启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}") + logger.info(f"Graph memory update enabled: simulation_id={simulation_id}, graph_id={graph_id}") except Exception as e: - logger.error(f"创建图谱记忆更新器失败: {e}") + logger.error(f"Failed to create graph memory updater: {e}") cls._graph_memory_enabled[simulation_id] = False else: cls._graph_memory_enabled[simulation_id] = False - # 确定运行哪个脚本(脚本位于 backend/scripts/ 目录) + # Determine which script to run (scripts located in backend/scripts/ directory) if platform == "twitter": script_name = "run_twitter_simulation.py" state.twitter_running = True @@ -398,64 +398,64 @@ def start_simulation( script_path = os.path.join(cls.SCRIPTS_DIR, script_name) if not os.path.exists(script_path): - raise ValueError(f"脚本不存在: {script_path}") + raise ValueError(f"Script not found: {script_path}") - # 创建动作队列 + # Create action queue action_queue = Queue() cls._action_queues[simulation_id] = action_queue - # 启动模拟进程 + # Start simulation process try: - # 构建运行命令,使用完整路径 - # 新的日志结构: - # twitter/actions.jsonl - Twitter 动作日志 - # reddit/actions.jsonl - Reddit 动作日志 - # simulation.log - 主进程日志 + # Build run command, using full paths + # New log structure: + # twitter/actions.jsonl - Twitter action log + # reddit/actions.jsonl - Reddit 
action log + # simulation.log - Main process log cmd = [ - sys.executable, # Python解释器 + sys.executable, # Python interpreter script_path, - "--config", config_path, # 使用完整配置文件路径 + "--config", config_path, # Use full config file path ] - # 如果指定了最大轮数,添加到命令行参数 + # If max rounds specified, add to command line args if max_rounds is not None and max_rounds > 0: cmd.extend(["--max-rounds", str(max_rounds)]) - # 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞 + # Create main log file to avoid stdout/stderr pipe buffer full causing process blocking main_log_path = os.path.join(sim_dir, "simulation.log") main_log_file = open(main_log_path, 'w', encoding='utf-8') - # 设置子进程环境变量,确保 Windows 上使用 UTF-8 编码 - # 这可以修复第三方库(如 OASIS)读取文件时未指定编码的问题 + # Set subprocess env vars to ensure UTF-8 encoding on Windows + # This fixes encoding issues when third-party libraries (e.g. OASIS) read files without specifying encoding env = os.environ.copy() - env['PYTHONUTF8'] = '1' # Python 3.7+ 支持,让所有 open() 默认使用 UTF-8 - env['PYTHONIOENCODING'] = 'utf-8' # 确保 stdout/stderr 使用 UTF-8 + env['PYTHONUTF8'] = '1' # Python 3.7+ support, make all open() use UTF-8 by default + env['PYTHONIOENCODING'] = 'utf-8' # Ensure stdout/stderr use UTF-8 - # 设置工作目录为模拟目录(数据库等文件会生成在此) - # 使用 start_new_session=True 创建新的进程组,确保可以通过 os.killpg 终止所有子进程 + # Set working directory to simulation directory (database files will be generated here) + # Use start_new_session=True to create new process group, ensuring all child processes can be killed via os.killpg process = subprocess.Popen( cmd, cwd=sim_dir, stdout=main_log_file, - stderr=subprocess.STDOUT, # stderr 也写入同一个文件 + stderr=subprocess.STDOUT, # stderr also written to same file text=True, - encoding='utf-8', # 显式指定编码 + encoding='utf-8', # Explicitly specify encoding bufsize=1, - env=env, # 传递带有 UTF-8 设置的环境变量 - start_new_session=True, # 创建新进程组,确保服务器关闭时能终止所有相关进程 + env=env, # Pass env vars with UTF-8 settings + start_new_session=True, # Create new process group, ensure all related processes 
can be terminated when server shuts down ) - # 保存文件句柄以便后续关闭 + # Save file handles for later closing cls._stdout_files[simulation_id] = main_log_file - cls._stderr_files[simulation_id] = None # 不再需要单独的 stderr + cls._stderr_files[simulation_id] = None # No longer need separate stderr state.process_pid = process.pid state.runner_status = RunnerStatus.RUNNING cls._processes[simulation_id] = process cls._save_run_state(state) - # 启动监控线程 + # Start monitoring thread monitor_thread = threading.Thread( target=cls._monitor_simulation, args=(simulation_id,), @@ -464,7 +464,7 @@ def start_simulation( monitor_thread.start() cls._monitor_threads[simulation_id] = monitor_thread - logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}") + logger.info(f"Simulation started successfully: {simulation_id}, pid={process.pid}, platform={platform}") except Exception as e: state.runner_status = RunnerStatus.FAILED @@ -476,10 +476,10 @@ def start_simulation( @classmethod def _monitor_simulation(cls, simulation_id: str): - """监控模拟进程,解析动作日志""" + """Monitor simulation process, parse action logs""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) - # 新的日志结构:分平台的动作日志 + # New log structure: per-platform action logs twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl") reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl") @@ -493,75 +493,75 @@ def _monitor_simulation(cls, simulation_id: str): reddit_position = 0 try: - while process.poll() is None: # 进程仍在运行 - # 读取 Twitter 动作日志 + while process.poll() is None: # Process still running + # Read Twitter action log if os.path.exists(twitter_actions_log): twitter_position = cls._read_action_log( twitter_actions_log, twitter_position, state, "twitter" ) - # 读取 Reddit 动作日志 + # Read Reddit action log if os.path.exists(reddit_actions_log): reddit_position = cls._read_action_log( reddit_actions_log, reddit_position, state, "reddit" ) - # 更新状态 + # Update state cls._save_run_state(state) 
time.sleep(2) - # 进程结束后,最后读取一次日志 + # After process ends, read logs one final time if os.path.exists(twitter_actions_log): cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter") if os.path.exists(reddit_actions_log): cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit") - # 进程结束 + # Process ended exit_code = process.returncode if exit_code == 0: state.runner_status = RunnerStatus.COMPLETED state.completed_at = datetime.now().isoformat() - logger.info(f"模拟完成: {simulation_id}") + logger.info(f"Simulation completed: {simulation_id}") else: state.runner_status = RunnerStatus.FAILED - # 从主日志文件读取错误信息 + # Read error info from main log file main_log_path = os.path.join(sim_dir, "simulation.log") error_info = "" try: if os.path.exists(main_log_path): with open(main_log_path, 'r', encoding='utf-8') as f: - error_info = f.read()[-2000:] # 取最后2000字符 + error_info = f.read()[-2000:] # Get last 2000 characters except Exception: pass - state.error = f"进程退出码: {exit_code}, 错误: {error_info}" - logger.error(f"模拟失败: {simulation_id}, error={state.error}") + state.error = f"Process exit code: {exit_code}, error: {error_info}" + logger.error(f"Simulation failed: {simulation_id}, error={state.error}") state.twitter_running = False state.reddit_running = False cls._save_run_state(state) except Exception as e: - logger.error(f"监控线程异常: {simulation_id}, error={str(e)}") + logger.error(f"Monitor thread exception: {simulation_id}, error={str(e)}") state.runner_status = RunnerStatus.FAILED state.error = str(e) cls._save_run_state(state) finally: - # 停止图谱记忆更新器 + # Stop graph memory updater if cls._graph_memory_enabled.get(simulation_id, False): try: ZepGraphMemoryManager.stop_updater(simulation_id) - logger.info(f"已停止图谱记忆更新: simulation_id={simulation_id}") + logger.info(f"Stopped graph memory update: simulation_id={simulation_id}") except Exception as e: - logger.error(f"停止图谱记忆更新器失败: {e}") + logger.error(f"Failed to stop graph memory updater: {e}") 
cls._graph_memory_enabled.pop(simulation_id, None) - # 清理进程资源 + # Clean up process resources cls._processes.pop(simulation_id, None) cls._action_queues.pop(simulation_id, None) - # 关闭日志文件句柄 + # Close log file handles if simulation_id in cls._stdout_files: try: cls._stdout_files[simulation_id].close() @@ -584,18 +584,18 @@ def _read_action_log( platform: str ) -> int: """ - 读取动作日志文件 + Read action log file Args: - log_path: 日志文件路径 - position: 上次读取位置 - state: 运行状态对象 - platform: 平台名称 (twitter/reddit) + log_path: Log file path + position: Last read position + state: Run state object + platform: Platform name (twitter/reddit) Returns: - 新的读取位置 + New read position """ - # 检查是否启用了图谱记忆更新 + # Check if graph memory update is enabled graph_memory_enabled = cls._graph_memory_enabled.get(state.simulation_id, False) graph_updater = None if graph_memory_enabled: @@ -610,36 +610,36 @@ def _read_action_log( try: action_data = json.loads(line) - # 处理事件类型的条目 + # Process event type entries if "event_type" in action_data: event_type = action_data.get("event_type") - # 检测 simulation_end 事件,标记平台已完成 + # Detect simulation_end event, mark platform as completed if event_type == "simulation_end": if platform == "twitter": state.twitter_completed = True state.twitter_running = False - logger.info(f"Twitter 模拟已完成: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}") + logger.info(f"Twitter simulation completed: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}") elif platform == "reddit": state.reddit_completed = True state.reddit_running = False - logger.info(f"Reddit 模拟已完成: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}") + logger.info(f"Reddit simulation completed: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}") 
- # 检查是否所有启用的平台都已完成 - # 如果只运行了一个平台,只检查那个平台 - # 如果运行了两个平台,需要两个都完成 + # Check if all enabled platforms have completed + # If only one platform was run, only check that platform + # If two platforms were run, both need to be completed all_completed = cls._check_all_platforms_completed(state) if all_completed: state.runner_status = RunnerStatus.COMPLETED state.completed_at = datetime.now().isoformat() - logger.info(f"所有平台模拟已完成: {state.simulation_id}") + logger.info(f"All platform simulations completed: {state.simulation_id}") - # 更新轮次信息(从 round_end 事件) + # Update round info (from round_end event) elif event_type == "round_end": round_num = action_data.get("round", 0) simulated_hours = action_data.get("simulated_hours", 0) - # 更新各平台独立的轮次和时间 + # Update per-platform independent rounds and time if platform == "twitter": if round_num > state.twitter_current_round: state.twitter_current_round = round_num @@ -649,10 +649,10 @@ def _read_action_log( state.reddit_current_round = round_num state.reddit_simulated_hours = simulated_hours - # 总体轮次取两个平台的最大值 + # Overall round takes maximum of both platforms if round_num > state.current_round: state.current_round = round_num - # 总体时间取两个平台的最大值 + # Overall time takes maximum of both platforms state.simulated_hours = max(state.twitter_simulated_hours, state.reddit_simulated_hours) continue @@ -670,11 +670,11 @@ def _read_action_log( ) state.add_action(action) - # 更新轮次 + # Update rounds if action.round_num and action.round_num > state.current_round: state.current_round = action.round_num - # 如果启用了图谱记忆更新,将活动发送到Zep + # If graph memory update enabled, send activities to Zep if graph_updater: graph_updater.add_activity_from_dict(action_data, platform) @@ -682,52 +682,52 @@ def _read_action_log( pass return f.tell() except Exception as e: - logger.warning(f"读取动作日志失败: {log_path}, error={e}") + logger.warning(f"Failed to read action log: {log_path}, error={e}") return position @classmethod def _check_all_platforms_completed(cls, state: 
SimulationRunState) -> bool: """ - 检查所有启用的平台是否都已完成模拟 + Check if all enabled platforms have completed simulation - 通过检查对应的 actions.jsonl 文件是否存在来判断平台是否被启用 + Determine if a platform is enabled by checking whether the corresponding actions.jsonl file exists Returns: - True 如果所有启用的平台都已完成 + True if all enabled platforms have completed """ sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id) twitter_log = os.path.join(sim_dir, "twitter", "actions.jsonl") reddit_log = os.path.join(sim_dir, "reddit", "actions.jsonl") - # 检查哪些平台被启用(通过文件是否存在判断) + # Check which platforms are enabled (by file existence) twitter_enabled = os.path.exists(twitter_log) reddit_enabled = os.path.exists(reddit_log) - # 如果平台被启用但未完成,则返回 False + # If platform is enabled but not completed, return False if twitter_enabled and not state.twitter_completed: return False if reddit_enabled and not state.reddit_completed: return False - # 至少有一个平台被启用且已完成 + # At least one platform is enabled and completed return twitter_enabled or reddit_enabled @classmethod def _terminate_process(cls, process: subprocess.Popen, simulation_id: str, timeout: int = 10): """ - 跨平台终止进程及其子进程 + Cross-platform terminate process and its child processes Args: - process: 要终止的进程 - simulation_id: 模拟ID(用于日志) - timeout: 等待进程退出的超时时间(秒) + process: Process to terminate + simulation_id: Simulation ID (for logging) + timeout: Timeout for waiting process to exit (seconds) """ if IS_WINDOWS: - # Windows: 使用 taskkill 命令终止进程树 - # /F = 强制终止, /T = 终止进程树(包括子进程) - logger.info(f"终止进程树 (Windows): simulation={simulation_id}, pid={process.pid}") + # Windows: Use taskkill command to terminate process tree + # /F = Force terminate, /T = Terminate process tree (including child processes) + logger.info(f"Terminating process tree (Windows): simulation={simulation_id}, pid={process.pid}") try: - # 先尝试优雅终止 + # Try graceful termination first subprocess.run( ['taskkill', '/PID', str(process.pid), '/T'], capture_output=True, @@ -736,8 +736,8 @@ def 
_terminate_process(cls, process: subprocess.Popen, simulation_id: str, timeo try: process.wait(timeout=timeout) except subprocess.TimeoutExpired: - # 强制终止 - logger.warning(f"进程未响应,强制终止: {simulation_id}") + # Force terminate + logger.warning(f"Process not responding, force terminating: {simulation_id}") subprocess.run( ['taskkill', '/F', '/PID', str(process.pid), '/T'], capture_output=True, @@ -745,53 +745,53 @@ def _terminate_process(cls, process: subprocess.Popen, simulation_id: str, timeo ) process.wait(timeout=5) except Exception as e: - logger.warning(f"taskkill 失败,尝试 terminate: {e}") + logger.warning(f"taskkill failed, trying terminate: {e}") process.terminate() try: process.wait(timeout=5) except subprocess.TimeoutExpired: process.kill() else: - # Unix: 使用进程组终止 - # 由于使用了 start_new_session=True,进程组 ID 等于主进程 PID + # Unix: Use process group termination + # Since start_new_session=True was used, process group ID equals main process PID pgid = os.getpgid(process.pid) - logger.info(f"终止进程组 (Unix): simulation={simulation_id}, pgid={pgid}") + logger.info(f"Terminating process group (Unix): simulation={simulation_id}, pgid={pgid}") - # 先发送 SIGTERM 给整个进程组 + # Send SIGTERM to entire process group first os.killpg(pgid, signal.SIGTERM) try: process.wait(timeout=timeout) except subprocess.TimeoutExpired: - # 如果超时后还没结束,强制发送 SIGKILL - logger.warning(f"进程组未响应 SIGTERM,强制终止: {simulation_id}") + # If not ended after timeout, force send SIGKILL + logger.warning(f"Process group not responding to SIGTERM, force terminating: {simulation_id}") os.killpg(pgid, signal.SIGKILL) process.wait(timeout=5) @classmethod def stop_simulation(cls, simulation_id: str) -> SimulationRunState: - """停止模拟""" + """Stop simulation""" state = cls.get_run_state(simulation_id) if not state: - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation not found: {simulation_id}") if state.runner_status not in [RunnerStatus.RUNNING, RunnerStatus.PAUSED]: - raise ValueError(f"模拟未在运行: 
{simulation_id}, status={state.runner_status}") + raise ValueError(f"Simulation not running: {simulation_id}, status={state.runner_status}") state.runner_status = RunnerStatus.STOPPING cls._save_run_state(state) - # 终止进程 + # Terminate process process = cls._processes.get(simulation_id) if process and process.poll() is None: try: cls._terminate_process(process, simulation_id) except ProcessLookupError: - # 进程已经不存在 + # Process no longer exists pass except Exception as e: - logger.error(f"终止进程组失败: {simulation_id}, error={e}") - # 回退到直接终止进程 + logger.error(f"Failed to terminate process group: {simulation_id}, error={e}") + # Fall back to direct process termination try: process.terminate() process.wait(timeout=5) @@ -804,16 +804,16 @@ def stop_simulation(cls, simulation_id: str) -> SimulationRunState: state.completed_at = datetime.now().isoformat() cls._save_run_state(state) - # 停止图谱记忆更新器 + # Stop graph memory updater if cls._graph_memory_enabled.get(simulation_id, False): try: ZepGraphMemoryManager.stop_updater(simulation_id) - logger.info(f"已停止图谱记忆更新: simulation_id={simulation_id}") + logger.info(f"Stopped graph memory update: simulation_id={simulation_id}") except Exception as e: - logger.error(f"停止图谱记忆更新器失败: {e}") + logger.error(f"Failed to stop graph memory updater: {e}") cls._graph_memory_enabled.pop(simulation_id, None) - logger.info(f"模拟已停止: {simulation_id}") + logger.info(f"Simulation stopped: {simulation_id}") return state @classmethod @@ -826,14 +826,14 @@ def _read_actions_from_file( round_num: Optional[int] = None ) -> List[AgentAction]: """ - 从单个动作文件中读取动作 + Read actions from a single action file Args: - file_path: 动作日志文件路径 - default_platform: 默认平台(当动作记录中没有 platform 字段时使用) - platform_filter: 过滤平台 - agent_id: 过滤 Agent ID - round_num: 过滤轮次 + file_path: Action log file path + default_platform: Default platform (used when action record has no platform field) + platform_filter: Platform filter + agent_id: Agent ID filter + round_num: Round filter """ if not 
os.path.exists(file_path): return [] @@ -849,18 +849,18 @@ def _read_actions_from_file( try: data = json.loads(line) - # 跳过非动作记录(如 simulation_start, round_start, round_end 等事件) + # Skip non-action records (e.g. simulation_start, round_start, round_end events) if "event_type" in data: continue - # 跳过没有 agent_id 的记录(非 Agent 动作) + # Skip records without agent_id (non-Agent actions) if "agent_id" not in data: continue - # 获取平台:优先使用记录中的 platform,否则使用默认平台 + # Get platform: prefer platform from record, otherwise use default record_platform = data.get("platform") or default_platform or "" - # 过滤 + # Filter if platform_filter and record_platform != platform_filter: continue if agent_id is not None and data.get("agent_id") != agent_id: @@ -894,54 +894,54 @@ def get_all_actions( round_num: Optional[int] = None ) -> List[AgentAction]: """ - 获取所有平台的完整动作历史(无分页限制) + Get complete action history for all platforms (no pagination limit) Args: - simulation_id: 模拟ID - platform: 过滤平台(twitter/reddit) - agent_id: 过滤Agent - round_num: 过滤轮次 + simulation_id: Simulation ID + platform: Platform filter (twitter/reddit) + agent_id: Agent filter + round_num: Round filter Returns: - 完整的动作列表(按时间戳排序,新的在前) + Complete action list (sorted by timestamp, newest first) """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) actions = [] - # 读取 Twitter 动作文件(根据文件路径自动设置 platform 为 twitter) + # Read Twitter action file (auto-set platform to twitter based on file path) twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl") if not platform or platform == "twitter": actions.extend(cls._read_actions_from_file( twitter_actions_log, - default_platform="twitter", # 自动填充 platform 字段 + default_platform="twitter", # Auto-fill platform field platform_filter=platform, agent_id=agent_id, round_num=round_num )) - # 读取 Reddit 动作文件(根据文件路径自动设置 platform 为 reddit) + # Read Reddit action file (auto-set platform to reddit based on file path) reddit_actions_log = os.path.join(sim_dir, "reddit", 
"actions.jsonl") if not platform or platform == "reddit": actions.extend(cls._read_actions_from_file( reddit_actions_log, - default_platform="reddit", # 自动填充 platform 字段 + default_platform="reddit", # Auto-fill platform field platform_filter=platform, agent_id=agent_id, round_num=round_num )) - # 如果分平台文件不存在,尝试读取旧的单一文件格式 + # If per-platform files don't exist, try reading legacy single file format if not actions: actions_log = os.path.join(sim_dir, "actions.jsonl") actions = cls._read_actions_from_file( actions_log, - default_platform=None, # 旧格式文件中应该有 platform 字段 + default_platform=None, # Legacy format files should have platform field platform_filter=platform, agent_id=agent_id, round_num=round_num ) - # 按时间戳排序(新的在前) + # Sort by timestamp (newest first) actions.sort(key=lambda x: x.timestamp, reverse=True) return actions @@ -957,18 +957,18 @@ def get_actions( round_num: Optional[int] = None ) -> List[AgentAction]: """ - 获取动作历史(带分页) + Get action history (with pagination) Args: - simulation_id: 模拟ID - limit: 返回数量限制 - offset: 偏移量 - platform: 过滤平台 - agent_id: 过滤Agent - round_num: 过滤轮次 + simulation_id: Simulation ID + limit: Return count limit + offset: Offset + platform: Platform filter + agent_id: Agent filter + round_num: Round filter Returns: - 动作列表 + Action list """ actions = cls.get_all_actions( simulation_id=simulation_id, @@ -977,7 +977,7 @@ def get_actions( round_num=round_num ) - # 分页 + # Pagination return actions[offset:offset + limit] @classmethod @@ -988,19 +988,19 @@ def get_timeline( end_round: Optional[int] = None ) -> List[Dict[str, Any]]: """ - 获取模拟时间线(按轮次汇总) + Get simulation timeline (summarized by round) Args: - simulation_id: 模拟ID - start_round: 起始轮次 - end_round: 结束轮次 + simulation_id: Simulation ID + start_round: Start round + end_round: End round Returns: - 每轮的汇总信息 + Summary information per round """ actions = cls.get_actions(simulation_id, limit=10000) - # 按轮次分组 + # Group by round rounds: Dict[int, Dict[str, Any]] = {} for action in actions: @@ 
-1033,7 +1033,7 @@ def get_timeline( r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1 r["last_action_time"] = action.timestamp - # 转换为列表 + # Convert to list result = [] for round_num in sorted(rounds.keys()): r = rounds[round_num] @@ -1054,10 +1054,10 @@ def get_timeline( @classmethod def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]: """ - 获取每个Agent的统计信息 + Get statistics for each Agent Returns: - Agent统计列表 + Agent statistics list """ actions = cls.get_actions(simulation_id, limit=10000) @@ -1089,7 +1089,7 @@ def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]: stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1 stats["last_action_time"] = action.timestamp - # 按总动作数排序 + # Sort by total action count result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True) return result @@ -1097,51 +1097,51 @@ def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]: @classmethod def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]: """ - 清理模拟的运行日志(用于强制重新开始模拟) + Clean up simulation run logs (for force restarting simulation) - 会删除以下文件: + Will delete the following files: - run_state.json - twitter/actions.jsonl - reddit/actions.jsonl - simulation.log - stdout.log / stderr.log - - twitter_simulation.db(模拟数据库) - - reddit_simulation.db(模拟数据库) - - env_status.json(环境状态) + - twitter_simulation.db (simulation database) + - reddit_simulation.db (simulation database) + - env_status.json (environment status) - 注意:不会删除配置文件(simulation_config.json)和 profile 文件 + Note: Will not delete config files (simulation_config.json) and profile files Args: - simulation_id: 模拟ID + simulation_id: Simulation ID Returns: - 清理结果信息 + Cleanup result information """ import shutil sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - return {"success": True, "message": "模拟目录不存在,无需清理"} + return {"success": 
True, "message": "Simulation directory does not exist, no cleanup needed"} cleaned_files = [] errors = [] - # 要删除的文件列表(包括数据库文件) + # Files to delete (including database files) files_to_delete = [ "run_state.json", "simulation.log", "stdout.log", "stderr.log", - "twitter_simulation.db", # Twitter 平台数据库 - "reddit_simulation.db", # Reddit 平台数据库 - "env_status.json", # 环境状态文件 + "twitter_simulation.db", # Twitter platform database + "reddit_simulation.db", # Reddit platform database + "env_status.json", # Environment status file ] - # 要删除的目录列表(包含动作日志) + # Directories to clean (containing action logs) dirs_to_clean = ["twitter", "reddit"] - # 删除文件 + # Delete files for filename in files_to_delete: file_path = os.path.join(sim_dir, filename) if os.path.exists(file_path): @@ -1149,9 +1149,9 @@ def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]: os.remove(file_path) cleaned_files.append(filename) except Exception as e: - errors.append(f"删除 {filename} 失败: {str(e)}") + errors.append(f"Failed to delete {filename}: {str(e)}") - # 清理平台目录中的动作日志 + # Clean up action logs in platform directories for dir_name in dirs_to_clean: dir_path = os.path.join(sim_dir, dir_name) if os.path.exists(dir_path): @@ -1161,13 +1161,13 @@ def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]: os.remove(actions_file) cleaned_files.append(f"{dir_name}/actions.jsonl") except Exception as e: - errors.append(f"删除 {dir_name}/actions.jsonl 失败: {str(e)}") + errors.append(f"Failed to delete {dir_name}/actions.jsonl: {str(e)}") - # 清理内存中的运行状态 + # Clean up in-memory run states if simulation_id in cls._run_states: del cls._run_states[simulation_id] - logger.info(f"清理模拟日志完成: {simulation_id}, 删除文件: {cleaned_files}") + logger.info(f"Simulation log cleanup complete: {simulation_id}, deleted files: {cleaned_files}") return { "success": len(errors) == 0, @@ -1175,71 +1175,71 @@ def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]: "errors": errors if errors else None 
} - # 防止重复清理的标志 + # Flag to prevent duplicate cleanup _cleanup_done = False @classmethod def cleanup_all_simulations(cls): """ - 清理所有运行中的模拟进程 + Clean up all running simulation processes - 在服务器关闭时调用,确保所有子进程被终止 + Called when server shuts down, ensuring all child processes are terminated """ - # 防止重复清理 + # Prevent duplicate cleanup if cls._cleanup_done: return cls._cleanup_done = True - # 检查是否有内容需要清理(避免空进程的进程打印无用日志) + # Check if there is content to clean (avoid printing useless logs for empty processes) has_processes = bool(cls._processes) has_updaters = bool(cls._graph_memory_enabled) if not has_processes and not has_updaters: - return # 没有需要清理的内容,静默返回 + return # Nothing to clean, return silently - logger.info("正在清理所有模拟进程...") + logger.info("Cleaning up all simulation processes...") - # 首先停止所有图谱记忆更新器(stop_all 内部会打印日志) + # First stop all graph memory updaters (stop_all will print logs internally) try: ZepGraphMemoryManager.stop_all() except Exception as e: - logger.error(f"停止图谱记忆更新器失败: {e}") + logger.error(f"Failed to stop graph memory updater: {e}") cls._graph_memory_enabled.clear() - # 复制字典以避免在迭代时修改 + # Copy dict to avoid modification during iteration processes = list(cls._processes.items()) for simulation_id, process in processes: try: - if process.poll() is None: # 进程仍在运行 - logger.info(f"终止模拟进程: {simulation_id}, pid={process.pid}") + if process.poll() is None: # Process still running + logger.info(f"Terminating simulation process: {simulation_id}, pid={process.pid}") try: - # 使用跨平台的进程终止方法 + # Use cross-platform process termination method cls._terminate_process(process, simulation_id, timeout=5) except (ProcessLookupError, OSError): - # 进程可能已经不存在,尝试直接终止 + # Process may no longer exist, try direct termination try: process.terminate() process.wait(timeout=3) except Exception: process.kill() - # 更新 run_state.json + # Update run_state.json state = cls.get_run_state(simulation_id) if state: state.runner_status = RunnerStatus.STOPPED state.twitter_running = False 
state.reddit_running = False state.completed_at = datetime.now().isoformat() - state.error = "服务器关闭,模拟被终止" + state.error = "Server shutdown, simulation terminated" cls._save_run_state(state) - # 同时更新 state.json,将状态设为 stopped + # Also update state.json, set status to stopped try: sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) state_file = os.path.join(sim_dir, "state.json") - logger.info(f"尝试更新 state.json: {state_file}") + logger.info(f"Trying to update state.json: {state_file}") if os.path.exists(state_file): with open(state_file, 'r', encoding='utf-8') as f: state_data = json.load(f) @@ -1247,16 +1247,16 @@ def cleanup_all_simulations(cls): state_data['updated_at'] = datetime.now().isoformat() with open(state_file, 'w', encoding='utf-8') as f: json.dump(state_data, f, indent=2, ensure_ascii=False) - logger.info(f"已更新 state.json 状态为 stopped: {simulation_id}") + logger.info(f"Updated state.json status to stopped: {simulation_id}") else: - logger.warning(f"state.json 不存在: {state_file}") + logger.warning(f"state.json does not exist: {state_file}") except Exception as state_err: - logger.warning(f"更新 state.json 失败: {simulation_id}, error={state_err}") + logger.warning(f"Failed to update state.json: {simulation_id}, error={state_err}") except Exception as e: - logger.error(f"清理进程失败: {simulation_id}, error={e}") + logger.error(f"Failed to clean up process: {simulation_id}, error={e}") - # 清理文件句柄 + # Clean up file handles for simulation_id, file_handle in list(cls._stdout_files.items()): try: if file_handle: @@ -1273,89 +1273,89 @@ def cleanup_all_simulations(cls): pass cls._stderr_files.clear() - # 清理内存中的状态 + # Clean up in-memory states cls._processes.clear() cls._action_queues.clear() - logger.info("模拟进程清理完成") + logger.info("Simulation process cleanup complete") @classmethod def register_cleanup(cls): """ - 注册清理函数 + Register cleanup functions - 在 Flask 应用启动时调用,确保服务器关闭时清理所有模拟进程 + Called when Flask app starts, ensuring all simulation processes are cleaned up 
when server shuts down """ global _cleanup_registered if _cleanup_registered: return - # Flask debug 模式下,只在 reloader 子进程中注册清理(实际运行应用的进程) - # WERKZEUG_RUN_MAIN=true 表示是 reloader 子进程 - # 如果不是 debug 模式,则没有这个环境变量,也需要注册 + # In Flask debug mode, only register cleanup in reloader subprocess (the process actually running the app) + # WERKZEUG_RUN_MAIN=true indicates reloader subprocess + # If not debug mode, this env var doesn't exist, also need to register is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true' is_debug_mode = os.environ.get('FLASK_DEBUG') == '1' or os.environ.get('WERKZEUG_RUN_MAIN') is not None - # 在 debug 模式下,只在 reloader 子进程中注册;非 debug 模式下始终注册 + # In debug mode, only register in reloader subprocess; in non-debug mode always register if is_debug_mode and not is_reloader_process: - _cleanup_registered = True # 标记已注册,防止子进程再次尝试 + _cleanup_registered = True # Mark as registered, prevent subprocess from trying again return - # 保存原有的信号处理器 + # Save original signal handlers original_sigint = signal.getsignal(signal.SIGINT) original_sigterm = signal.getsignal(signal.SIGTERM) - # SIGHUP 只在 Unix 系统存在(macOS/Linux),Windows 没有 + # SIGHUP only exists on Unix systems (macOS/Linux), not Windows original_sighup = None has_sighup = hasattr(signal, 'SIGHUP') if has_sighup: original_sighup = signal.getsignal(signal.SIGHUP) def cleanup_handler(signum=None, frame=None): - """信号处理器:先清理模拟进程,再调用原处理器""" - # 只有在有进程需要清理时才打印日志 + """Signal handler: clean up simulation processes first, then call original handler""" + # Only print logs when there are processes to clean if cls._processes or cls._graph_memory_enabled: - logger.info(f"收到信号 {signum},开始清理...") + logger.info(f"Received signal {signum},starting cleanup...") cls.cleanup_all_simulations() - # 调用原有的信号处理器,让 Flask 正常退出 + # Call original signal handler to let Flask exit normally if signum == signal.SIGINT and callable(original_sigint): original_sigint(signum, frame) elif signum == signal.SIGTERM and 
callable(original_sigterm): original_sigterm(signum, frame) elif has_sighup and signum == signal.SIGHUP: - # SIGHUP: 终端关闭时发送 + # SIGHUP: Sent when terminal closes if callable(original_sighup): original_sighup(signum, frame) else: - # 默认行为:正常退出 + # Default behavior: exit normally sys.exit(0) else: - # 如果原处理器不可调用(如 SIG_DFL),则使用默认行为 + # If original handler is not callable (e.g. SIG_DFL), use default behavior raise KeyboardInterrupt - # 注册 atexit 处理器(作为备用) + # Register atexit handler (as fallback) atexit.register(cls.cleanup_all_simulations) - # 注册信号处理器(仅在主线程中) + # Register signal handlers (only in main thread) try: - # SIGTERM: kill 命令默认信号 + # SIGTERM: Default signal for kill command signal.signal(signal.SIGTERM, cleanup_handler) # SIGINT: Ctrl+C signal.signal(signal.SIGINT, cleanup_handler) - # SIGHUP: 终端关闭(仅 Unix 系统) + # SIGHUP: Terminal close (Unix systems only) if has_sighup: signal.signal(signal.SIGHUP, cleanup_handler) except ValueError: - # 不在主线程中,只能使用 atexit - logger.warning("无法注册信号处理器(不在主线程),仅使用 atexit") + # Not in main thread, can only use atexit + logger.warning("Cannot register signal handlers (not in main thread), using atexit only") _cleanup_registered = True @classmethod def get_running_simulations(cls) -> List[str]: """ - 获取所有正在运行的模拟ID列表 + Get list of all running simulation IDs """ running = [] for sim_id, process in cls._processes.items(): @@ -1363,18 +1363,18 @@ def get_running_simulations(cls) -> List[str]: running.append(sim_id) return running - # ============== Interview 功能 ============== + # ============== Interview Features ============== @classmethod def check_env_alive(cls, simulation_id: str) -> bool: """ - 检查模拟环境是否存活(可以接收Interview命令) + Check if simulation environment is alive (can receive Interview commands) Args: - simulation_id: 模拟ID + simulation_id: Simulation ID Returns: - True 表示环境存活,False 表示环境已关闭 + True means environment is alive, False means environment is closed """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not 
os.path.exists(sim_dir): @@ -1386,13 +1386,13 @@ def check_env_alive(cls, simulation_id: str) -> bool: @classmethod def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]: """ - 获取模拟环境的详细状态信息 + Get detailed status info of simulation environment Args: - simulation_id: 模拟ID + simulation_id: Simulation ID Returns: - 状态详情字典,包含 status, twitter_available, reddit_available, timestamp + Status details dict containing status, twitter_available, reddit_available, timestamp """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) status_file = os.path.join(sim_dir, "env_status.json") @@ -1429,35 +1429,35 @@ def interview_agent( timeout: float = 60.0 ) -> Dict[str, Any]: """ - 采访单个Agent + Interview single Agent Args: - simulation_id: 模拟ID + simulation_id: Simulation ID agent_id: Agent ID - prompt: 采访问题 - platform: 指定平台(可选) - - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 - - None: 双平台模拟时同时采访两个平台,返回整合结果 - timeout: 超时时间(秒) + prompt: Interview question + platform: Specify platform (optional) + - "twitter": Interview Twitter platform only + - "reddit": Interview Reddit platform only + - None: Interview both platforms simultaneously in dual-platform simulation, return integrated results + timeout: Timeout (seconds) Returns: - 采访结果字典 + Interview result dictionary Raises: - ValueError: 模拟不存在或环境未运行 - TimeoutError: 等待响应超时 + ValueError: Simulation not found or environment not running + TimeoutError: Response timeout """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation not found: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): - raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}") + raise ValueError(f"Simulation environment not running or closed, cannot perform Interview: {simulation_id}") - logger.info(f"发送Interview命令: simulation_id={simulation_id}, agent_id={agent_id}, 
platform={platform}") + logger.info(f"Sending Interview command: simulation_id={simulation_id}, agent_id={agent_id}, platform={platform}") response = ipc_client.send_interview( agent_id=agent_id, @@ -1492,34 +1492,34 @@ def interview_agents_batch( timeout: float = 120.0 ) -> Dict[str, Any]: """ - 批量采访多个Agent + Batch interview multiple Agents Args: - simulation_id: 模拟ID - interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)} - platform: 默认平台(可选,会被每个采访项的platform覆盖) - - "twitter": 默认只采访Twitter平台 - - "reddit": 默认只采访Reddit平台 - - None: 双平台模拟时每个Agent同时采访两个平台 - timeout: 超时时间(秒) + simulation_id: Simulation ID + interviews: Interview list, each element contains {"agent_id": int, "prompt": str, "platform": str(optional)} + platform: Default platform (optional, overridden by each interview item's platform) + - "twitter": Default interview Twitter platform only + - "reddit": Default interview Reddit platform only + - None: Interview each Agent on both platforms simultaneously in dual-platform simulation + timeout: Timeout (seconds) Returns: - 批量采访结果字典 + Batch Interview result dictionary Raises: - ValueError: 模拟不存在或环境未运行 - TimeoutError: 等待响应超时 + ValueError: Simulation not found or environment not running + TimeoutError: Response timeout """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation not found: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): - raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}") + raise ValueError(f"Simulation environment not running or closed, cannot perform Interview: {simulation_id}") - logger.info(f"发送批量Interview命令: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}") + logger.info(f"Sending batch Interview command: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}") response = 
ipc_client.send_batch_interview( interviews=interviews, @@ -1551,39 +1551,39 @@ def interview_all_agents( timeout: float = 180.0 ) -> Dict[str, Any]: """ - 采访所有Agent(全局采访) + Interview all Agents (global interview) - 使用相同的问题采访模拟中的所有Agent + Interview all Agents in simulation with the same question Args: - simulation_id: 模拟ID - prompt: 采访问题(所有Agent使用相同问题) - platform: 指定平台(可选) - - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 - - None: 双平台模拟时每个Agent同时采访两个平台 - timeout: 超时时间(秒) + simulation_id: Simulation ID + prompt: Interview question (same question for all Agents) + platform: Specify platform (optional) + - "twitter": Interview Twitter platform only + - "reddit": Interview Reddit platform only + - None: Interview each Agent on both platforms simultaneously in dual-platform simulation + timeout: Timeout (seconds) Returns: - 全局采访结果字典 + Global Interview result dictionary """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation not found: {simulation_id}") - # 从配置文件获取所有Agent信息 + # Get all Agent info from config file config_path = os.path.join(sim_dir, "simulation_config.json") if not os.path.exists(config_path): - raise ValueError(f"模拟配置不存在: {simulation_id}") + raise ValueError(f"Simulation configuration not found: {simulation_id}") with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) agent_configs = config.get("agent_configs", []) if not agent_configs: - raise ValueError(f"模拟配置中没有Agent: {simulation_id}") + raise ValueError(f"No Agents in simulation configuration: {simulation_id}") - # 构建批量采访列表 + # Build batch interview list interviews = [] for agent_config in agent_configs: agent_id = agent_config.get("agent_id") @@ -1593,7 +1593,7 @@ def interview_all_agents( "prompt": prompt }) - logger.info(f"发送全局Interview命令: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}") + logger.info(f"Sending global Interview 
command: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}") return cls.interview_agents_batch( simulation_id=simulation_id, @@ -1609,45 +1609,45 @@ def close_simulation_env( timeout: float = 30.0 ) -> Dict[str, Any]: """ - 关闭模拟环境(而不是停止模拟进程) + Close simulation environment (without stopping simulation process) - 向模拟发送关闭环境命令,使其优雅退出等待命令模式 + Send close environment command to simulation for graceful exit from command waiting mode Args: - simulation_id: 模拟ID - timeout: 超时时间(秒) + simulation_id: Simulation ID + timeout: Timeout (seconds) Returns: - 操作结果字典 + Operation result dictionary """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation not found: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): return { "success": True, - "message": "环境已经关闭" + "message": "Environment already closed" } - logger.info(f"发送关闭环境命令: simulation_id={simulation_id}") + logger.info(f"Sending close environment command: simulation_id={simulation_id}") try: response = ipc_client.send_close_env(timeout=timeout) return { "success": response.status.value == "completed", - "message": "环境关闭命令已发送", + "message": "Environment close command sent", "result": response.result, "timestamp": response.timestamp } except TimeoutError: - # 超时可能是因为环境正在关闭 + # Timeout may be because environment is shutting down return { "success": True, - "message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)" + "message": "Environment close command sent (response timeout, environment may be shutting down)" } @classmethod @@ -1658,7 +1658,7 @@ def _get_interview_history_from_db( agent_id: Optional[int] = None, limit: int = 100 ) -> List[Dict[str, Any]]: - """从单个数据库获取Interview历史""" + """Get Interview history from a single database""" import sqlite3 if not os.path.exists(db_path): @@ -1704,7 +1704,7 @@ def _get_interview_history_from_db( conn.close() except 
Exception as e: - logger.error(f"读取Interview历史失败 ({platform_name}): {e}") + logger.error(f"Failed to read Interview history ({platform_name}): {e}") return results @@ -1717,29 +1717,29 @@ def get_interview_history( limit: int = 100 ) -> List[Dict[str, Any]]: """ - 获取Interview历史记录(从数据库读取) + Get Interview history records (read from database) Args: - simulation_id: 模拟ID - platform: 平台类型(reddit/twitter/None) - - "reddit": 只获取Reddit平台的历史 - - "twitter": 只获取Twitter平台的历史 - - None: 获取两个平台的所有历史 - agent_id: 指定Agent ID(可选,只获取该Agent的历史) - limit: 每个平台返回数量限制 + simulation_id: Simulation ID + platform: Platform type (reddit/twitter/None) + - "reddit": Get Reddit platform history only + - "twitter": Get Twitter platform history only + - None: Get all history from both platforms + agent_id: Specify Agent ID (optional, get history for this Agent only) + limit: Return count limit per platform Returns: - Interview历史记录列表 + Interview history record list """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) results = [] - # 确定要查询的平台 + # Determine platforms to query if platform in ("reddit", "twitter"): platforms = [platform] else: - # 不指定platform时,查询两个平台 + # When platform not specified, query both platforms platforms = ["twitter", "reddit"] for p in platforms: @@ -1752,10 +1752,10 @@ def get_interview_history( ) results.extend(platform_results) - # 按时间降序排序 + # Sort by time descending results.sort(key=lambda x: x.get("timestamp", ""), reverse=True) - # 如果查询了多个平台,限制总数 + # If queried multiple platforms, limit total count if len(platforms) > 1 and len(results) > limit: results = results[:limit] diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index 71661be49..2b3163d47 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -7,11 +7,10 @@ from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field -from zep_cloud.client import Zep +from 
.graphiti_adapter import GraphitiAdapter from ..config import Config from ..utils.logger import get_logger -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges logger = get_logger('mirofish.zep_entity_reader') @@ -79,11 +78,7 @@ class ZepEntityReader: """ def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + pass # No initialization needed, adapter is created per-graph def _call_with_retry( self, @@ -126,7 +121,7 @@ def _call_with_retry( def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: """ - 获取图谱的所有节点(分页获取) + 获取图谱的所有节点 Args: graph_id: 图谱ID @@ -134,26 +129,14 @@ def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: Returns: 节点列表 """ - logger.info(f"获取图谱 {graph_id} 的所有节点...") - - nodes = fetch_all_nodes(self.client, graph_id) - - nodes_data = [] - for node in nodes: - nodes_data.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": node.name or "", - "labels": node.labels or [], - "summary": node.summary or "", - "attributes": node.attributes or {}, - }) - - logger.info(f"共获取 {len(nodes_data)} 个节点") - return nodes_data + adapter = GraphitiAdapter.get_or_create(graph_id) + nodes = adapter.get_all_nodes() + logger.info(f"Got {len(nodes)} nodes") + return nodes def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: """ - 获取图谱的所有边(分页获取) + 获取图谱的所有边 Args: graph_id: 图谱ID @@ -161,56 +144,31 @@ def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: Returns: 边列表 """ - logger.info(f"获取图谱 {graph_id} 的所有边...") - - edges = fetch_all_edges(self.client, graph_id) - - edges_data = [] - for edge in edges: - edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": edge.name or "", - "fact": edge.fact or "", - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - 
"attributes": edge.attributes or {}, - }) - - logger.info(f"共获取 {len(edges_data)} 条边") - return edges_data + adapter = GraphitiAdapter.get_or_create(graph_id) + edges = adapter.get_all_edges() + logger.info(f"Got {len(edges)} edges") + return edges - def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: + def get_node_edges(self, node_uuid: str, graph_id: str = None) -> List[Dict[str, Any]]: """ - 获取指定节点的所有相关边(带重试机制) - + 获取指定节点的所有相关边 + Args: node_uuid: 节点UUID - + graph_id: 图谱ID(可选,如果不提供则搜索所有已缓存的适配器) + Returns: 边列表 """ - try: - # 使用重试机制调用Zep API - edges = self._call_with_retry( - func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), - operation_name=f"获取节点边(node={node_uuid[:8]}...)" - ) - - edges_data = [] - for edge in edges: - edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": edge.name or "", - "fact": edge.fact or "", - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "attributes": edge.attributes or {}, - }) - - return edges_data - except Exception as e: - logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}") - return [] + if graph_id: + adapter = GraphitiAdapter.get_or_create(graph_id) + return adapter.get_node_edges(node_uuid) + # Fallback: try all cached adapters + for gid, adapter in GraphitiAdapter._instances.items(): + edges = adapter.get_node_edges(node_uuid) + if edges: + return edges + return [] def filter_defined_entities( self, @@ -331,83 +289,72 @@ def filter_defined_entities( ) def get_entity_with_context( - self, - graph_id: str, + self, + graph_id: str, entity_uuid: str ) -> Optional[EntityNode]: """ - 获取单个实体及其完整上下文(边和关联节点,带重试机制) - + 获取单个实体及其完整上下文(边和关联节点) + Args: graph_id: 图谱ID entity_uuid: 实体UUID - + Returns: EntityNode或None """ try: - # 使用重试机制获取节点 - node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=entity_uuid), - operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" - ) - + adapter = 
GraphitiAdapter.get_or_create(graph_id) + node = adapter.get_node(entity_uuid) if not node: return None - - # 获取节点的边 - edges = self.get_node_edges(entity_uuid) - - # 获取所有节点用于关联查找 - all_nodes = self.get_all_nodes(graph_id) + + edges = adapter.get_node_edges(entity_uuid) + all_nodes = adapter.get_all_nodes() node_map = {n["uuid"]: n for n in all_nodes} - - # 处理相关边和节点 + related_edges = [] related_node_uuids = set() - for edge in edges: - if edge["source_node_uuid"] == entity_uuid: + if edge.get("source_node_uuid") == entity_uuid: related_edges.append({ "direction": "outgoing", - "edge_name": edge["name"], - "fact": edge["fact"], - "target_node_uuid": edge["target_node_uuid"], + "edge_name": edge.get("name", ""), + "fact": edge.get("fact", ""), + "target_node_uuid": edge.get("target_node_uuid", ""), }) - related_node_uuids.add(edge["target_node_uuid"]) + related_node_uuids.add(edge.get("target_node_uuid", "")) else: related_edges.append({ "direction": "incoming", - "edge_name": edge["name"], - "fact": edge["fact"], - "source_node_uuid": edge["source_node_uuid"], + "edge_name": edge.get("name", ""), + "fact": edge.get("fact", ""), + "source_node_uuid": edge.get("source_node_uuid", ""), }) - related_node_uuids.add(edge["source_node_uuid"]) - - # 获取关联节点信息 + related_node_uuids.add(edge.get("source_node_uuid", "")) + related_nodes = [] - for related_uuid in related_node_uuids: - if related_uuid in node_map: - related_node = node_map[related_uuid] + for ruuid in related_node_uuids: + if ruuid in node_map: + rn = node_map[ruuid] related_nodes.append({ - "uuid": related_node["uuid"], - "name": related_node["name"], - "labels": related_node["labels"], - "summary": related_node.get("summary", ""), + "uuid": rn["uuid"], + "name": rn["name"], + "labels": rn["labels"], + "summary": rn.get("summary", ""), }) - + return EntityNode( - uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - name=node.name or "", - labels=node.labels or [], - summary=node.summary or "", - 
attributes=node.attributes or {}, + uuid=node["uuid"], + name=node["name"], + labels=node.get("labels", []), + summary=node.get("summary", ""), + attributes=node.get("attributes", {}), related_edges=related_edges, related_nodes=related_nodes, ) - except Exception as e: - logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}") + logger.error(f"Failed to get entity {entity_uuid}: {str(e)}") return None def get_entities_by_type( diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index a8f3cecd9..80ce59171 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -1,6 +1,6 @@ """ -Zep图谱记忆更新服务 -将模拟中的Agent活动动态更新到Zep图谱中 +Zep graph memory update service +Dynamically update agent activities from simulations to the Zep graph """ import os @@ -12,7 +12,7 @@ from datetime import datetime from queue import Queue, Empty -from zep_cloud.client import Zep +from .graphiti_adapter import GraphitiAdapter from ..config import Config from ..utils.logger import get_logger @@ -22,7 +22,7 @@ @dataclass class AgentActivity: - """Agent活动记录""" + """Agent activity record""" platform: str # twitter / reddit agent_id: int agent_name: str @@ -33,12 +33,10 @@ class AgentActivity: def to_episode_text(self) -> str: """ - 将活动转换为可以发送给Zep的文本描述 - - 采用自然语言描述格式,让Zep能够从中提取实体和关系 - 不添加模拟相关的前缀,避免误导图谱更新 + Convert activity to natural language text description + + Use natural language description format so NER extractor can extract entities and relationships """ - # 根据不同的动作类型生成不同的描述 action_descriptions = { "CREATE_POST": self._describe_create_post, "LIKE_POST": self._describe_like_post, @@ -57,222 +55,179 @@ def to_episode_text(self) -> str: describe_func = action_descriptions.get(self.action_type, self._describe_generic) description = describe_func() - # 直接返回 "agent名称: 活动描述" 格式,不添加模拟前缀 return f"{self.agent_name}: {description}" def _describe_create_post(self) -> str: content = 
self.action_args.get("content", "") if content: - return f"发布了一条帖子:「{content}」" - return "发布了一条帖子" - + return f"Posted a post: \"{content}\"" + return "Posted a post" + def _describe_like_post(self) -> str: - """点赞帖子 - 包含帖子原文和作者信息""" post_content = self.action_args.get("post_content", "") post_author = self.action_args.get("post_author_name", "") - if post_content and post_author: - return f"点赞了{post_author}的帖子:「{post_content}」" + return f"Liked {post_author}'s post: \"{post_content}\"" elif post_content: - return f"点赞了一条帖子:「{post_content}」" + return f"Liked a post: \"{post_content}\"" elif post_author: - return f"点赞了{post_author}的一条帖子" - return "点赞了一条帖子" - + return f"Liked a post by {post_author}" + return "Liked a post" + def _describe_dislike_post(self) -> str: - """踩帖子 - 包含帖子原文和作者信息""" post_content = self.action_args.get("post_content", "") post_author = self.action_args.get("post_author_name", "") - if post_content and post_author: - return f"踩了{post_author}的帖子:「{post_content}」" + return f"Disliked {post_author}'s post: \"{post_content}\"" elif post_content: - return f"踩了一条帖子:「{post_content}」" + return f"Disliked a post: \"{post_content}\"" elif post_author: - return f"踩了{post_author}的一条帖子" - return "踩了一条帖子" - + return f"Disliked a post by {post_author}" + return "Disliked a post" + def _describe_repost(self) -> str: - """转发帖子 - 包含原帖内容和作者信息""" original_content = self.action_args.get("original_content", "") original_author = self.action_args.get("original_author_name", "") - if original_content and original_author: - return f"转发了{original_author}的帖子:「{original_content}」" + return f"Reposted {original_author}'s post: \"{original_content}\"" elif original_content: - return f"转发了一条帖子:「{original_content}」" + return f"Reposted a post: \"{original_content}\"" elif original_author: - return f"转发了{original_author}的一条帖子" - return "转发了一条帖子" - + return f"Reposted a post by {original_author}" + return "Reposted a post" + def _describe_quote_post(self) -> str: - """引用帖子 - 
包含原帖内容、作者信息和引用评论""" original_content = self.action_args.get("original_content", "") original_author = self.action_args.get("original_author_name", "") quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "") - base = "" if original_content and original_author: - base = f"引用了{original_author}的帖子「{original_content}」" + base = f"Quoted {original_author}'s post \"{original_content}\"" elif original_content: - base = f"引用了一条帖子「{original_content}」" + base = f"Quoted a post \"{original_content}\"" elif original_author: - base = f"引用了{original_author}的一条帖子" + base = f"Quoted a post by {original_author}" else: - base = "引用了一条帖子" - + base = "Quoted a post" if quote_content: - base += f",并评论道:「{quote_content}」" + base += f", commented: \"{quote_content}\"" return base - + def _describe_follow(self) -> str: - """关注用户 - 包含被关注用户的名称""" target_user_name = self.action_args.get("target_user_name", "") - if target_user_name: - return f"关注了用户「{target_user_name}」" - return "关注了一个用户" - + return f"Followed user \"{target_user_name}\"" + return "Followed a user" + def _describe_create_comment(self) -> str: - """发表评论 - 包含评论内容和所评论的帖子信息""" content = self.action_args.get("content", "") post_content = self.action_args.get("post_content", "") post_author = self.action_args.get("post_author_name", "") - if content: if post_content and post_author: - return f"在{post_author}的帖子「{post_content}」下评论道:「{content}」" + return f"Commented on {post_author}'s post \"{post_content}\": \"{content}\"" elif post_content: - return f"在帖子「{post_content}」下评论道:「{content}」" + return f"Commented on post \"{post_content}\": \"{content}\"" elif post_author: - return f"在{post_author}的帖子下评论道:「{content}」" - return f"评论道:「{content}」" - return "发表了评论" - + return f"Commented on {post_author}'s post: \"{content}\"" + return f"Commented: \"{content}\"" + return "Left a comment" + def _describe_like_comment(self) -> str: - """点赞评论 - 包含评论内容和作者信息""" comment_content = 
self.action_args.get("comment_content", "") comment_author = self.action_args.get("comment_author_name", "") - if comment_content and comment_author: - return f"点赞了{comment_author}的评论:「{comment_content}」" + return f"Liked {comment_author}'s comment: \"{comment_content}\"" elif comment_content: - return f"点赞了一条评论:「{comment_content}」" + return f"Liked a comment: \"{comment_content}\"" elif comment_author: - return f"点赞了{comment_author}的一条评论" - return "点赞了一条评论" - + return f"Liked a comment by {comment_author}" + return "Liked a comment" + def _describe_dislike_comment(self) -> str: - """踩评论 - 包含评论内容和作者信息""" comment_content = self.action_args.get("comment_content", "") comment_author = self.action_args.get("comment_author_name", "") - if comment_content and comment_author: - return f"踩了{comment_author}的评论:「{comment_content}」" + return f"Disliked {comment_author}'s comment: \"{comment_content}\"" elif comment_content: - return f"踩了一条评论:「{comment_content}」" + return f"Disliked a comment: \"{comment_content}\"" elif comment_author: - return f"踩了{comment_author}的一条评论" - return "踩了一条评论" - + return f"Disliked a comment by {comment_author}" + return "Disliked a comment" + def _describe_search(self) -> str: - """搜索帖子 - 包含搜索关键词""" query = self.action_args.get("query", "") or self.action_args.get("keyword", "") - return f"搜索了「{query}」" if query else "进行了搜索" - + return f"Searched for \"{query}\""if query else "Performed a search" + def _describe_search_user(self) -> str: - """搜索用户 - 包含搜索关键词""" query = self.action_args.get("query", "") or self.action_args.get("username", "") - return f"搜索了用户「{query}」" if query else "搜索了用户" - + return f"Searched for user \"{query}\""if query else "Searched for user" + def _describe_mute(self) -> str: - """屏蔽用户 - 包含被屏蔽用户的名称""" target_user_name = self.action_args.get("target_user_name", "") - if target_user_name: - return f"屏蔽了用户「{target_user_name}」" - return "屏蔽了一个用户" - + return f"Muted user \"{target_user_name}\"" + return "Muted a user" + def 
_describe_generic(self) -> str: - # 对于未知的动作类型,生成通用描述 - return f"执行了{self.action_type}操作" + return f"Executed {self.action_type} action" class ZepGraphMemoryUpdater: """ - Zep图谱记忆更新器 - - 监控模拟的actions日志文件,将新的agent活动实时更新到Zep图谱中。 - 按平台分组,每累积BATCH_SIZE条活动后批量发送到Zep。 - - 所有有意义的行为都会被更新到Zep,action_args中会包含完整的上下文信息: - - 点赞/踩的帖子原文 - - 转发/引用的帖子原文 - - 关注/屏蔽的用户名 - - 点赞/踩的评论原文 + Zep graph memory updater + + Monitors simulation action logs and sends agent activities to the Zep graph in real-time. + Batches activities by platform, accumulating BATCH_SIZE activities before sending each batch. """ - - # 批量发送大小(每个平台累积多少条后发送) + BATCH_SIZE = 5 - - # 平台名称映射(用于控制台显示) + PLATFORM_DISPLAY_NAMES = { - 'twitter': '世界1', - 'reddit': '世界2', + 'twitter': 'worldinterface1', + 'reddit': 'worldinterface2', } - - # 发送间隔(秒),避免请求过快 + SEND_INTERVAL = 0.5 - - # 重试配置 MAX_RETRIES = 3 - RETRY_DELAY = 2 # 秒 + RETRY_DELAY = 2 def __init__(self, graph_id: str, api_key: Optional[str] = None): """ - 初始化更新器 - + Initialize updater + Args: - graph_id: Zep图谱ID - api_key: Zep API Key(可选,默认从配置读取) + graph_id: Graph ID + api_key: Reserved parameter, no longer used """ self.graph_id = graph_id - self.api_key = api_key or Config.ZEP_API_KEY - - if not self.api_key: - raise ValueError("ZEP_API_KEY未配置") + self.adapter = GraphitiAdapter.get_or_create(graph_id) - self.client = Zep(api_key=self.api_key) - - # 活动队列 self._activity_queue: Queue = Queue() - - # 按平台分组的活动缓冲区(每个平台各自累积到BATCH_SIZE后批量发送) + self._platform_buffers: Dict[str, List[AgentActivity]] = { 'twitter': [], 'reddit': [], } self._buffer_lock = threading.Lock() - - # 控制标志 + self._running = False self._worker_thread: Optional[threading.Thread] = None - - # 统计 - self._total_activities = 0 # 实际添加到队列的活动数 - self._total_sent = 0 # 成功发送到Zep的批次数 - self._total_items_sent = 0 # 成功发送到Zep的活动条数 - self._failed_count = 0 # 发送失败的批次数 - self._skipped_count = 0 # 被过滤跳过的活动数(DO_NOTHING) - - logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}") + + 
self._total_activities = 0 + self._total_sent = 0 + self._total_items_sent = 0 + self._failed_count = 0 + self._skipped_count = 0 + + logger.info(f"ZepGraphMemoryUpdater initialized: graph_id={graph_id}, batch_size={self.BATCH_SIZE}") def _get_platform_display_name(self, platform: str) -> str: - """获取平台的显示名称""" + """Get display name for platform""" return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform) def start(self): - """启动后台工作线程""" + """Start background worker thread""" if self._running: return @@ -283,19 +238,18 @@ def start(self): name=f"ZepMemoryUpdater-{self.graph_id[:8]}" ) self._worker_thread.start() - logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}") + logger.info(f"ZepGraphMemoryUpdater started: graph_id={self.graph_id}") def stop(self): - """停止后台工作线程""" + """Stop background worker thread""" self._running = False - - # 发送剩余的活动 + self._flush_remaining() if self._worker_thread and self._worker_thread.is_alive(): self._worker_thread.join(timeout=10) - logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, " + logger.info(f"ZepGraphMemoryUpdater stopped: graph_id={self.graph_id}, " f"total_activities={self._total_activities}, " f"batches_sent={self._total_sent}, " f"items_sent={self._total_items_sent}, " @@ -303,44 +257,19 @@ def stop(self): f"skipped={self._skipped_count}") def add_activity(self, activity: AgentActivity): - """ - 添加一个agent活动到队列 - - 所有有意义的行为都会被添加到队列,包括: - - CREATE_POST(发帖) - - CREATE_COMMENT(评论) - - QUOTE_POST(引用帖子) - - SEARCH_POSTS(搜索帖子) - - SEARCH_USER(搜索用户) - - LIKE_POST/DISLIKE_POST(点赞/踩帖子) - - REPOST(转发) - - FOLLOW(关注) - - MUTE(屏蔽) - - LIKE_COMMENT/DISLIKE_COMMENT(点赞/踩评论) - - action_args中会包含完整的上下文信息(如帖子原文、用户名等)。 - - Args: - activity: Agent活动记录 - """ - # 跳过DO_NOTHING类型的活动 + """Add an agent activity to queue""" + # Skip DO_NOTHING activities if activity.action_type == "DO_NOTHING": self._skipped_count += 1 return self._activity_queue.put(activity) self._total_activities += 1 - 
logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}") + logger.debug(f"Add activity to queue: {activity.agent_name} - {activity.action_type}") def add_activity_from_dict(self, data: Dict[str, Any], platform: str): - """ - 从字典数据添加活动 - - Args: - data: 从actions.jsonl解析的字典数据 - platform: 平台名称 (twitter/reddit) - """ - # 跳过事件类型的条目 + """Add activity from dict data""" + # Skip event type entries if "event_type" in data: return @@ -357,78 +286,69 @@ def add_activity_from_dict(self, data: Dict[str, Any], platform: str): self.add_activity(activity) def _worker_loop(self): - """后台工作循环 - 按平台批量发送活动到Zep""" + """Background worker loop - batch send activities to Zep by platform""" while self._running or not self._activity_queue.empty(): try: - # 尝试从队列获取活动(超时1秒) + # Try to get activity from queue (1 second timeout) try: activity = self._activity_queue.get(timeout=1) - # 将活动添加到对应平台的缓冲区 + # Add activity to the corresponding platform buffer platform = activity.platform.lower() with self._buffer_lock: if platform not in self._platform_buffers: self._platform_buffers[platform] = [] self._platform_buffers[platform].append(activity) - # 检查该平台是否达到批量大小 + # Check if platform has reached batch size if len(self._platform_buffers[platform]) >= self.BATCH_SIZE: batch = self._platform_buffers[platform][:self.BATCH_SIZE] self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:] - # 释放锁后再发送 + # Send after releasing lock self._send_batch_activities(batch, platform) - # 发送间隔,避免请求过快 + # Send interval to avoid too frequent requests time.sleep(self.SEND_INTERVAL) except Empty: pass except Exception as e: - logger.error(f"工作循环异常: {e}") + logger.error(f"Worker loop exception: {e}") time.sleep(1) def _send_batch_activities(self, activities: List[AgentActivity], platform: str): """ - 批量发送活动到Zep图谱(合并为一条文本) - - Args: - activities: Agent活动列表 - platform: 平台名称 + Send batched activities to the Zep graph (merged as one text) """ if not activities: return - # 
将多条活动合并为一条文本,用换行分隔 + # Merge multiple activities into one text, separated by newlines episode_texts = [activity.to_episode_text() for activity in activities] combined_text = "\n".join(episode_texts) - # 带重试的发送 + # Send with retry for attempt in range(self.MAX_RETRIES): try: - self.client.graph.add( - graph_id=self.graph_id, - type="text", - data=combined_text - ) + self.adapter.add_episode(combined_text, source_description="simulation_activity") self._total_sent += 1 self._total_items_sent += len(activities) display_name = self._get_platform_display_name(platform) - logger.info(f"成功批量发送 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}") - logger.debug(f"批量内容预览: {combined_text[:200]}...") + logger.info(f"Successfully batch sent {len(activities)} {display_name} activities to graph {self.graph_id}") + logger.debug(f"Batch preview: {combined_text[:200]}...") return except Exception as e: if attempt < self.MAX_RETRIES - 1: - logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}") + logger.warning(f"Batch send to Zep failed (attempt {attempt + 1}/{self.MAX_RETRIES}): {e}") time.sleep(self.RETRY_DELAY * (attempt + 1)) else: - logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}") + logger.error(f"Batch send to Zep failed after {self.MAX_RETRIES} retries: {e}") self._failed_count += 1 def _flush_remaining(self): - """发送队列和缓冲区中剩余的活动""" - # 首先处理队列中剩余的活动,添加到缓冲区 + """Send remaining activities in queue and buffers""" while not self._activity_queue.empty(): try: activity = self._activity_queue.get_nowait() @@ -440,41 +360,41 @@ def _flush_remaining(self): except Empty: break - # 然后发送各平台缓冲区中剩余的活动(即使不足BATCH_SIZE条) + # Send remaining activities from each platform buffer with self._buffer_lock: for platform, buffer in self._platform_buffers.items(): if buffer: display_name = self._get_platform_display_name(platform) - logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动") + logger.info(f"Send remaining {len(buffer)} {display_name} platform activities") 
self._send_batch_activities(buffer, platform) - # 清空所有缓冲区 + # Clear all buffers for platform in self._platform_buffers: self._platform_buffers[platform] = [] def get_stats(self) -> Dict[str, Any]: - """获取统计信息""" + """Get statistics""" with self._buffer_lock: buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()} return { "graph_id": self.graph_id, "batch_size": self.BATCH_SIZE, - "total_activities": self._total_activities, # 添加到队列的活动总数 - "batches_sent": self._total_sent, # 成功发送的批次数 - "items_sent": self._total_items_sent, # 成功发送的活动条数 - "failed_count": self._failed_count, # 发送失败的批次数 - "skipped_count": self._skipped_count, # 被过滤跳过的活动数(DO_NOTHING) + "total_activities": self._total_activities, + "batches_sent": self._total_sent, + "items_sent": self._total_items_sent, + "failed_count": self._failed_count, + "skipped_count": self._skipped_count, "queue_size": self._activity_queue.qsize(), - "buffer_sizes": buffer_sizes, # 各平台缓冲区大小 + "buffer_sizes": buffer_sizes, "running": self._running, } class ZepGraphMemoryManager: """ - 管理多个模拟的Zep图谱记忆更新器 - - 每个模拟可以有自己的更新器实例 + Manages graph memory updaters for multiple simulations. + + Each simulation can have its own independent updater instance. """ _updaters: Dict[str, ZepGraphMemoryUpdater] = {} @@ -483,17 +403,14 @@ class ZepGraphMemoryManager: @classmethod def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater: """ - 为模拟创建图谱记忆更新器 - + Create a graph memory updater for a simulation. 
+ Args: - simulation_id: 模拟ID - graph_id: Zep图谱ID - - Returns: - ZepGraphMemoryUpdater实例 + simulation_id: Simulation ID + graph_id: Graph ID """ with cls._lock: - # 如果已存在,先停止旧的 + # If already exists, stop the old one if simulation_id in cls._updaters: cls._updaters[simulation_id].stop() @@ -501,30 +418,29 @@ def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpda updater.start() cls._updaters[simulation_id] = updater - logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}") + logger.info(f"Create graph memory updater: simulation_id={simulation_id}, graph_id={graph_id}") return updater @classmethod def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]: - """获取模拟的更新器""" + """Get updater for simulation""" return cls._updaters.get(simulation_id) @classmethod def stop_updater(cls, simulation_id: str): - """停止并移除模拟的更新器""" + """Stop and remove updater for simulation""" with cls._lock: if simulation_id in cls._updaters: cls._updaters[simulation_id].stop() del cls._updaters[simulation_id] - logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}") + logger.info(f"Stopped graph memory updater: simulation_id={simulation_id}") - # 防止 stop_all 重复调用的标志 + # Flag to prevent duplicate stop_all calls _stop_all_done = False @classmethod def stop_all(cls): - """停止所有更新器""" - # 防止重复调用 + """Stop all updaters""" if cls._stop_all_done: return cls._stop_all_done = True @@ -535,13 +451,13 @@ def stop_all(cls): try: updater.stop() except Exception as e: - logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}") + logger.error(f"Failed to stop updater: simulation_id={simulation_id}, error={e}") cls._updaters.clear() - logger.info("已停止所有图谱记忆更新器") + logger.info("Stopped all graph memory updaters") @classmethod def get_all_stats(cls) -> Dict[str, Dict[str, Any]]: - """获取所有更新器的统计信息""" + """Get statistics for all updaters""" return { sim_id: updater.get_stats() for sim_id, updater in cls._updaters.items() diff --git 
a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 384cf540f..e87564909 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -1,11 +1,11 @@ """ -Zep检索工具服务 -封装图谱搜索、节点读取、边查询等工具,供Report Agent使用 +Zep Retrieval Tools Service +Encapsulates graph search, node retrieval, edge queries, and other tools for use by Report Agent. -核心检索工具(优化后): -1. InsightForge(深度洞察检索)- 最强大的混合检索,自动生成子问题并多维度检索 -2. PanoramaSearch(广度搜索)- 获取全貌,包括过期内容 -3. QuickSearch(简单搜索)- 快速检索 +Core Retrieval Tools (Optimized): +1. InsightForge (Deep Insight Retrieval) - Most powerful hybrid search, automatically generates sub-questions and multi-dimensional retrieval +2. PanoramaSearch (Breadth Search) - Get comprehensive view, including expired content +3. QuickSearch (Simple Search) - Quick retrieval """ import time @@ -13,19 +13,17 @@ from typing import Dict, Any, List, Optional from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger from ..utils.llm_client import LLMClient -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .graphiti_adapter import GraphitiAdapter logger = get_logger('mirofish.zep_tools') @dataclass class SearchResult: - """搜索结果""" + """Search Result""" facts: List[str] edges: List[Dict[str, Any]] nodes: List[Dict[str, Any]] @@ -42,11 +40,11 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为文本格式,供LLM理解""" - text_parts = [f"搜索查询: {self.query}", f"找到 {self.total_count} 条相关信息"] - + """Convert to text format for LLM understanding""" + text_parts = [f"Search Query: {self.query}", f"Found {self.total_count} related results"] + if self.facts: - text_parts.append("\n### 相关事实:") + text_parts.append("\n### Related Facts:") for i, fact in enumerate(self.facts, 1): text_parts.append(f"{i}. 
{fact}") @@ -55,7 +53,7 @@ def to_text(self) -> str: @dataclass class NodeInfo: - """节点信息""" + """Node Information""" uuid: str name: str labels: List[str] @@ -72,14 +70,14 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为文本格式""" - entity_type = next((l for l in self.labels if l not in ["Entity", "Node"]), "未知类型") - return f"实体: {self.name} (类型: {entity_type})\n摘要: {self.summary}" + """Convert to text format""" + entity_type = next((l for l in self.labels if l not in ["Entity", "Node"]), "Unknown type") + return f"Entity: {self.name} (Type: {entity_type})\nSummary: {self.summary}" @dataclass class EdgeInfo: - """边信息""" + """Edge Information""" uuid: str name: str fact: str @@ -87,7 +85,7 @@ class EdgeInfo: target_node_uuid: str source_node_name: Optional[str] = None target_node_name: Optional[str] = None - # 时间信息 + # Temporal information created_at: Optional[str] = None valid_at: Optional[str] = None invalid_at: Optional[str] = None @@ -109,47 +107,47 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self, include_temporal: bool = False) -> str: - """转换为文本格式""" + """Convert to text format""" source = self.source_node_name or self.source_node_uuid[:8] target = self.target_node_name or self.target_node_uuid[:8] - base_text = f"关系: {source} --[{self.name}]--> {target}\n事实: {self.fact}" - + base_text = f"Relationship: {source} --[{self.name}]--> {target}\nFact: {self.fact}" + if include_temporal: - valid_at = self.valid_at or "未知" - invalid_at = self.invalid_at or "至今" - base_text += f"\n时效: {valid_at} - {invalid_at}" + valid_at = self.valid_at or "Unknown" + invalid_at = self.invalid_at or "Present" + base_text += f"\nTime Range: {valid_at} - {invalid_at}" if self.expired_at: - base_text += f" (已过期: {self.expired_at})" + base_text += f" (Expired: {self.expired_at})" return base_text @property def is_expired(self) -> bool: - """是否已过期""" + """Whether already expired""" return self.expired_at is not None @property def is_invalid(self) -> 
bool: - """是否已失效""" + """Whether already invalid""" return self.invalid_at is not None @dataclass class InsightForgeResult: """ - 深度洞察检索结果 (InsightForge) - 包含多个子问题的检索结果,以及综合分析 + Deep Insight Retrieval Result (InsightForge) + Contains retrieval results from multiple sub-questions and integrated analysis """ query: str simulation_requirement: str sub_queries: List[str] - # 各维度检索结果 - semantic_facts: List[str] = field(default_factory=list) # 语义搜索结果 - entity_insights: List[Dict[str, Any]] = field(default_factory=list) # 实体洞察 - relationship_chains: List[str] = field(default_factory=list) # 关系链 - - # 统计信息 + # Retrieval results by dimension + semantic_facts: List[str] = field(default_factory=list) + entity_insights: List[Dict[str, Any]] = field(default_factory=list) + relationship_chains: List[str] = field(default_factory=list) + + # Statistical information total_facts: int = 0 total_entities: int = 0 total_relationships: int = 0 @@ -168,42 +166,38 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为详细的文本格式,供LLM理解""" + """Convert to detailed text format for LLM understanding""" text_parts = [ - f"## 未来预测深度分析", - f"分析问题: {self.query}", - f"预测场景: {self.simulation_requirement}", - f"\n### 预测数据统计", - f"- 相关预测事实: {self.total_facts}条", - f"- 涉及实体: {self.total_entities}个", - f"- 关系链: {self.total_relationships}条" + f"## Future Prediction Deep Analysis", + f"Analysis Query: {self.query}", + f"Prediction Scenario: {self.simulation_requirement}", + f"\n### Prediction Data Statistics", + f"- Related Prediction Facts: {self.total_facts}", + f"- Involved Entities: {self.total_entities}", + f"- Relationship Chains: {self.total_relationships}" ] - - # 子问题 + if self.sub_queries: - text_parts.append(f"\n### 分析的子问题") + text_parts.append(f"\n### Analysis Sub-Questions") for i, sq in enumerate(self.sub_queries, 1): text_parts.append(f"{i}. 
{sq}") - - # 语义搜索结果 + if self.semantic_facts: - text_parts.append(f"\n### 【关键事实】(请在报告中引用这些原文)") + text_parts.append(f"\n### Key Facts (Please quote these verbatim in the report)") for i, fact in enumerate(self.semantic_facts, 1): - text_parts.append(f"{i}. \"{fact}\"") - - # 实体洞察 + text_parts.append(f'{i}. "{fact}"') + if self.entity_insights: - text_parts.append(f"\n### 【核心实体】") + text_parts.append(f"\n### Core Entities") for entity in self.entity_insights: - text_parts.append(f"- **{entity.get('name', '未知')}** ({entity.get('type', '实体')})") + text_parts.append(f"- **{entity.get('name', 'Unknown')}** ({entity.get('type', 'Entity')})") if entity.get('summary'): - text_parts.append(f" 摘要: \"{entity.get('summary')}\"") + text_parts.append(f" Summary: \"{entity.get('summary')}\"") if entity.get('related_facts'): - text_parts.append(f" 相关事实: {len(entity.get('related_facts', []))}条") - - # 关系链 + text_parts.append(f" Related Facts: {len(entity.get('related_facts', []))} facts") + if self.relationship_chains: - text_parts.append(f"\n### 【关系链】") + text_parts.append(f"\n### Relationship Chains") for chain in self.relationship_chains: text_parts.append(f"- {chain}") @@ -213,21 +207,16 @@ def to_text(self) -> str: @dataclass class PanoramaResult: """ - 广度搜索结果 (Panorama) - 包含所有相关信息,包括过期内容 + Breadth Search Result (Panorama) + Contains all related information, including expired content """ query: str - # 全部节点 all_nodes: List[NodeInfo] = field(default_factory=list) - # 全部边(包括过期的) all_edges: List[EdgeInfo] = field(default_factory=list) - # 当前有效的事实 active_facts: List[str] = field(default_factory=list) - # 已过期/失效的事实(历史记录) historical_facts: List[str] = field(default_factory=list) - - # 统计 + total_nodes: int = 0 total_edges: int = 0 active_count: int = 0 @@ -247,34 +236,31 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为文本格式(完整版本,不截断)""" + """Convert to text format (complete version, no truncation)""" text_parts = [ - f"## 广度搜索结果(未来全景视图)", - f"查询: 
{self.query}", - f"\n### 统计信息", - f"- 总节点数: {self.total_nodes}", - f"- 总边数: {self.total_edges}", - f"- 当前有效事实: {self.active_count}条", - f"- 历史/过期事实: {self.historical_count}条" + f"## Breadth Search Results (Future Panoramic View)", + f"Query: {self.query}", + f"\n### Statistics", + f"- Total Nodes: {self.total_nodes}", + f"- Total Edges: {self.total_edges}", + f"- Current Valid Facts: {self.active_count}", + f"- Historical/Expired Facts: {self.historical_count}" ] - - # 当前有效的事实(完整输出,不截断) + if self.active_facts: - text_parts.append(f"\n### 【当前有效事实】(模拟结果原文)") + text_parts.append(f"\n### Current Valid Facts (Simulation Results Verbatim)") for i, fact in enumerate(self.active_facts, 1): - text_parts.append(f"{i}. \"{fact}\"") - - # 历史/过期事实(完整输出,不截断) + text_parts.append(f'{i}. "{fact}"') + if self.historical_facts: - text_parts.append(f"\n### 【历史/过期事实】(演变过程记录)") + text_parts.append(f"\n### Historical/Expired Facts (Evolution Record)") for i, fact in enumerate(self.historical_facts, 1): - text_parts.append(f"{i}. \"{fact}\"") - - # 关键实体(完整输出,不截断) + text_parts.append(f'{i}. 
"{fact}"') + if self.all_nodes: - text_parts.append(f"\n### 【涉及实体】") + text_parts.append(f"\n### Involved Entities") for node in self.all_nodes: - entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体") + entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "Entity") text_parts.append(f"- **{node.name}** ({entity_type})") return "\n".join(text_parts) @@ -282,13 +268,13 @@ def to_text(self) -> str: @dataclass class AgentInterview: - """单个Agent的采访结果""" + """Single Agent Interview Result""" agent_name: str - agent_role: str # 角色类型(如:学生、教师、媒体等) - agent_bio: str # 简介 - question: str # 采访问题 - response: str # 采访回答 - key_quotes: List[str] = field(default_factory=list) # 关键引言 + agent_role: str + agent_bio: str + question: str + response: str + key_quotes: List[str] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { @@ -302,21 +288,17 @@ def to_dict(self) -> Dict[str, Any]: def to_text(self) -> str: text = f"**{self.agent_name}** ({self.agent_role})\n" - # 显示完整的agent_bio,不截断 - text += f"_简介: {self.agent_bio}_\n\n" + text += f"_Bio: {self.agent_bio}_\n\n" text += f"**Q:** {self.question}\n\n" text += f"**A:** {self.response}\n" if self.key_quotes: - text += "\n**关键引言:**\n" + text += "\n**Key Quotes:**\n" for quote in self.key_quotes: - # 清理各种引号 clean_quote = quote.replace('\u201c', '').replace('\u201d', '').replace('"', '') clean_quote = clean_quote.replace('\u300c', '').replace('\u300d', '') clean_quote = clean_quote.strip() - # 去掉开头的标点 while clean_quote and clean_quote[0] in ',,;;::、。!?\n\r\t ': clean_quote = clean_quote[1:] - # 过滤包含问题编号的垃圾内容(问题1-9) skip = False for d in '123456789': if f'\u95ee\u9898{d}' in clean_quote: @@ -324,7 +306,7 @@ def to_text(self) -> str: break if skip: continue - # 截断过长内容(按句号截断,而非硬截断) + # Truncate long content (by period, not hard truncation) if len(clean_quote) > 150: dot_pos = clean_quote.find('\u3002', 80) if dot_pos > 0: @@ -339,23 +321,17 @@ def to_text(self) -> str: 
@dataclass class InterviewResult: """ - 采访结果 (Interview) - 包含多个模拟Agent的采访回答 + Interview Result + Contains interview responses from multiple simulated Agents """ - interview_topic: str # 采访主题 - interview_questions: List[str] # 采访问题列表 - - # 采访选择的Agent + interview_topic: str + interview_questions: List[str] + selected_agents: List[Dict[str, Any]] = field(default_factory=list) - # 各Agent的采访回答 interviews: List[AgentInterview] = field(default_factory=list) - - # 选择Agent的理由 + selection_reasoning: str = "" - # 整合后的采访摘要 summary: str = "" - - # 统计 total_agents: int = 0 interviewed_count: int = 0 @@ -372,74 +348,68 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为详细的文本格式,供LLM理解和报告引用""" + """Convert to detailed text format for LLM understanding and report reference""" text_parts = [ - "## 深度采访报告", - f"**采访主题:** {self.interview_topic}", - f"**采访人数:** {self.interviewed_count} / {self.total_agents} 位模拟Agent", - "\n### 采访对象选择理由", - self.selection_reasoning or "(自动选择)", + "## Deep Interview Report", + f"**Interview Topic:** {self.interview_topic}", + f"**Interviewees:** {self.interviewed_count} / {self.total_agents} Simulated Agents", + "\n### Selection Rationale", + self.selection_reasoning or "(Automatic Selection)", "\n---", - "\n### 采访实录", + "\n### Interview Transcripts", ] if self.interviews: for i, interview in enumerate(self.interviews, 1): - text_parts.append(f"\n#### 采访 #{i}: {interview.agent_name}") + text_parts.append(f"\n#### Interview #{i}: {interview.agent_name}") text_parts.append(interview.to_text()) text_parts.append("\n---") else: - text_parts.append("(无采访记录)\n\n---") + text_parts.append("(No interview records)\n\n---") - text_parts.append("\n### 采访摘要与核心观点") - text_parts.append(self.summary or "(无摘要)") + text_parts.append("\n### Interview Summary & Key Insights") + text_parts.append(self.summary or "(No summary)") return "\n".join(text_parts) class ZepToolsService: """ - Zep检索工具服务 - - 【核心检索工具 - 优化后】 - 1. 
insight_forge - 深度洞察检索(最强大,自动生成子问题,多维度检索) - 2. panorama_search - 广度搜索(获取全貌,包括过期内容) - 3. quick_search - 简单搜索(快速检索) - 4. interview_agents - 深度采访(采访模拟Agent,获取多视角观点) - - 【基础工具】 - - search_graph - 图谱语义搜索 - - get_all_nodes - 获取图谱所有节点 - - get_all_edges - 获取图谱所有边(含时间信息) - - get_node_detail - 获取节点详细信息 - - get_node_edges - 获取节点相关的边 - - get_entities_by_type - 按类型获取实体 - - get_entity_summary - 获取实体的关系摘要 + Zep Retrieval Tools Service + + [Core Retrieval Tools - Optimized] + 1. insight_forge - Deep Insight Retrieval (Most powerful, auto-generates sub-questions, multi-dimensional retrieval) + 2. panorama_search - Breadth Search (Get comprehensive view, including expired content) + 3. quick_search - Simple Search (Quick retrieval) + 4. interview_agents - Deep Interview (Interview simulated Agents, obtain multi-perspective insights) + + [Basic Tools] + - search_graph - Graph semantic search + - get_all_nodes - Get all nodes in graph + - get_all_edges - Get all edges in graph (with temporal information) + - get_node_detail - Get detailed node information + - get_node_edges - Get edges related to a node + - get_entities_by_type - Get entities by type + - get_entity_summary - Get entity relationship summary """ - # 重试配置 + # Retry configuration MAX_RETRIES = 3 RETRY_DELAY = 2.0 def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) - # LLM客户端用于InsightForge生成子问题 self._llm_client = llm_client - logger.info("ZepToolsService 初始化完成") + logger.info("ZepToolsService initialized") @property def llm(self) -> LLMClient: - """延迟初始化LLM客户端""" + """Lazy initialization of LLM client""" if self._llm_client is None: self._llm_client = LLMClient() return self._llm_client def _call_with_retry(self, func, operation_name: str, max_retries: int = None): - """带重试机制的API调用""" + """API call with retry mechanism""" max_retries = 
max_retries or self.MAX_RETRIES last_exception = None delay = self.RETRY_DELAY @@ -451,13 +421,13 @@ def _call_with_retry(self, func, operation_name: str, max_retries: int = None): last_exception = e if attempt < max_retries - 1: logger.warning( - f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, " - f"{delay:.1f}秒后重试..." + f"Zep {operation_name} attempt {attempt + 1} failed: {str(e)[:100]}, " + f"retrying in {delay:.1f}s..." ) time.sleep(delay) delay *= 2 else: - logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}") + logger.error(f"Zep {operation_name} failed after {max_retries} attempts: {str(e)}") raise last_exception @@ -469,78 +439,49 @@ def search_graph( scope: str = "edges" ) -> SearchResult: """ - 图谱语义搜索 - - 使用混合搜索(语义+BM25)在图谱中搜索相关信息。 - 如果Zep Cloud的search API不可用,则降级为本地关键词匹配。 - + Graph semantic search + + Uses hybrid search (semantic + BM25) to search for related information in the graph. + Falls back to local keyword matching if the search API is unavailable. 
+ Args: - graph_id: 图谱ID (Standalone Graph) - query: 搜索查询 - limit: 返回结果数量 - scope: 搜索范围,"edges" 或 "nodes" - + graph_id: Graph ID (Standalone Graph) + query: Search query + limit: Number of results to return + scope: Search scope, "edges" or "nodes" + Returns: - SearchResult: 搜索结果 + SearchResult """ - logger.info(f"图谱搜索: graph_id={graph_id}, query={query[:50]}...") - - # 尝试使用Zep Cloud Search API + logger.info(f"Graph search: graph_id={graph_id}, query={query[:50]}...") + try: - search_results = self._call_with_retry( - func=lambda: self.client.graph.search( - graph_id=graph_id, - query=query, - limit=limit, - scope=scope, - reranker="cross_encoder" - ), - operation_name=f"图谱搜索(graph={graph_id})" - ) - + adapter = GraphitiAdapter.get_or_create(graph_id) + result = adapter.search(query, limit=limit, scope=scope) + facts = [] - edges = [] - nodes = [] - - # 解析边搜索结果 - if hasattr(search_results, 'edges') and search_results.edges: - for edge in search_results.edges: - if hasattr(edge, 'fact') and edge.fact: - facts.append(edge.fact) - edges.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": getattr(edge, 'name', ''), - "fact": getattr(edge, 'fact', ''), - "source_node_uuid": getattr(edge, 'source_node_uuid', ''), - "target_node_uuid": getattr(edge, 'target_node_uuid', ''), - }) - - # 解析节点搜索结果 - if hasattr(search_results, 'nodes') and search_results.nodes: - for node in search_results.nodes: - nodes.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": getattr(node, 'name', ''), - "labels": getattr(node, 'labels', []), - "summary": getattr(node, 'summary', ''), - }) - # 节点摘要也算作事实 - if hasattr(node, 'summary') and node.summary: - facts.append(f"[{node.name}]: {node.summary}") - - logger.info(f"搜索完成: 找到 {len(facts)} 条相关事实") - + edges_data = [] + nodes_data = [] + + if scope == "edges": + for edge in result.get("edges", []): + fact = edge.get("fact", "") + if fact: + facts.append(fact) + 
edges_data.append(edge) + else: + for node in result.get("nodes", []): + nodes_data.append(node) + return SearchResult( facts=facts, - edges=edges, - nodes=nodes, + edges=edges_data, + nodes=nodes_data, query=query, - total_count=len(facts) + total_count=len(facts) + len(nodes_data) ) - except Exception as e: - logger.warning(f"Zep Search API失败,降级为本地搜索: {str(e)}") - # 降级:使用本地关键词匹配搜索 + logger.warning(f"Search failed, falling back to local search: {e}") return self._local_search(graph_id, query, limit, scope) def _local_search( @@ -551,38 +492,38 @@ def _local_search( scope: str = "edges" ) -> SearchResult: """ - 本地关键词匹配搜索(作为Zep Search API的降级方案) - - 获取所有边/节点,然后在本地进行关键词匹配 - + Local keyword matching search (fallback approach) + + Gets all edges/nodes and performs local keyword matching. + Args: - graph_id: 图谱ID - query: 搜索查询 - limit: 返回结果数量 - scope: 搜索范围 - + graph_id: Graph ID + query: Search query + limit: Number of results to return + scope: Search scope + Returns: - SearchResult: 搜索结果 + SearchResult """ - logger.info(f"使用本地搜索: query={query[:30]}...") + logger.info(f"Using local search: query={query[:30]}...") facts = [] edges_result = [] nodes_result = [] - # 提取查询关键词(简单分词) + # Extract query keywords (simple tokenization) query_lower = query.lower() keywords = [w.strip() for w in query_lower.replace(',', ' ').replace(',', ' ').split() if len(w.strip()) > 1] def match_score(text: str) -> int: - """计算文本与查询的匹配分数""" + """Calculate match score between text and query""" if not text: return 0 text_lower = text.lower() - # 完全匹配查询 + # Full query match if query_lower in text_lower: return 100 - # 关键词匹配 + # Keyword match score = 0 for keyword in keywords: if keyword in text_lower: @@ -591,7 +532,7 @@ def match_score(text: str) -> int: try: if scope in ["edges", "both"]: - # 获取所有边并匹配 + # Get all edges and match all_edges = self.get_all_edges(graph_id) scored_edges = [] for edge in all_edges: @@ -599,7 +540,7 @@ def match_score(text: str) -> int: if score > 0: 
scored_edges.append((score, edge)) - # 按分数排序 + # Sort by score scored_edges.sort(key=lambda x: x[0], reverse=True) for score, edge in scored_edges[:limit]: @@ -614,7 +555,7 @@ def match_score(text: str) -> int: }) if scope in ["nodes", "both"]: - # 获取所有节点并匹配 + # Get all nodes and match all_nodes = self.get_all_nodes(graph_id) scored_nodes = [] for node in all_nodes: @@ -634,10 +575,10 @@ def match_score(text: str) -> int: if node.summary: facts.append(f"[{node.name}]: {node.summary}") - logger.info(f"本地搜索完成: 找到 {len(facts)} 条相关事实") + logger.info(f"Local search complete: Found {len(facts)} related facts") except Exception as e: - logger.error(f"本地搜索失败: {str(e)}") + logger.error(f"Local search failed: {str(e)}") return SearchResult( facts=facts, @@ -649,132 +590,126 @@ def match_score(text: str) -> int: def get_all_nodes(self, graph_id: str) -> List[NodeInfo]: """ - 获取图谱的所有节点(分页获取) + Get all nodes in the graph Args: - graph_id: 图谱ID + graph_id: Graph ID Returns: - 节点列表 + List of nodes """ - logger.info(f"获取图谱 {graph_id} 的所有节点...") + logger.info(f"Getting all nodes in graph {graph_id}...") - nodes = fetch_all_nodes(self.client, graph_id) + adapter = GraphitiAdapter.get_or_create(graph_id) + nodes = adapter.get_all_nodes() - result = [] - for node in nodes: - node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or "" - result.append(NodeInfo( - uuid=str(node_uuid) if node_uuid else "", - name=node.name or "", - labels=node.labels or [], - summary=node.summary or "", - attributes=node.attributes or {} - )) - - logger.info(f"获取到 {len(result)} 个节点") + result = [NodeInfo( + uuid=n.get("uuid", ""), + name=n.get("name", ""), + labels=n.get("labels", []), + summary=n.get("summary", ""), + attributes=n.get("attributes", {}) + ) for n in nodes] + + logger.info(f"Retrieved {len(result)} nodes") return result def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]: """ - 获取图谱的所有边(分页获取,包含时间信息) + Get all edges in the graph (with 
temporal information) Args: - graph_id: 图谱ID - include_temporal: 是否包含时间信息(默认True) + graph_id: Graph ID + include_temporal: Whether to include temporal information (default True) Returns: - 边列表(包含created_at, valid_at, invalid_at, expired_at) + List of edges (with created_at, valid_at, invalid_at, expired_at) """ - logger.info(f"获取图谱 {graph_id} 的所有边...") - - edges = fetch_all_edges(self.client, graph_id) - - result = [] - for edge in edges: - edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or "" - edge_info = EdgeInfo( - uuid=str(edge_uuid) if edge_uuid else "", - name=edge.name or "", - fact=edge.fact or "", - source_node_uuid=edge.source_node_uuid or "", - target_node_uuid=edge.target_node_uuid or "" - ) - - # 添加时间信息 - if include_temporal: - edge_info.created_at = getattr(edge, 'created_at', None) - edge_info.valid_at = getattr(edge, 'valid_at', None) - edge_info.invalid_at = getattr(edge, 'invalid_at', None) - edge_info.expired_at = getattr(edge, 'expired_at', None) - - result.append(edge_info) - - logger.info(f"获取到 {len(result)} 条边") + logger.info(f"Getting all edges in graph {graph_id}...") + + adapter = GraphitiAdapter.get_or_create(graph_id) + edges = adapter.get_all_edges() + + result = [EdgeInfo( + uuid=e.get("uuid", ""), + name=e.get("name", ""), + fact=e.get("fact", ""), + source_node_uuid=e.get("source_node_uuid", ""), + target_node_uuid=e.get("target_node_uuid", ""), + created_at=e.get("created_at"), + valid_at=e.get("valid_at"), + invalid_at=e.get("invalid_at"), + expired_at=e.get("expired_at"), + ) for e in edges] + + logger.info(f"Retrieved {len(result)} edges") return result - def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: + def get_node_detail(self, graph_id: str, node_uuid: str) -> Optional[NodeInfo]: """ - 获取单个节点的详细信息 - + Get detailed information about a single node + Args: - node_uuid: 节点UUID - + graph_id: Graph ID + node_uuid: Node UUID + Returns: - 节点信息或None + Node info or None """ - 
logger.info(f"获取节点详情: {node_uuid[:8]}...") - + logger.info(f"Getting node details: {node_uuid[:8]}...") + try: - node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=node_uuid), - operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)" - ) - - if not node: - return None - - return NodeInfo( - uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - name=node.name or "", - labels=node.labels or [], - summary=node.summary or "", - attributes=node.attributes or {} - ) + adapter = GraphitiAdapter.get_or_create(graph_id) + node = adapter.get_node(node_uuid) + if node: + return NodeInfo( + uuid=node.get("uuid", ""), + name=node.get("name", ""), + labels=node.get("labels", []), + summary=node.get("summary", ""), + attributes=node.get("attributes", {}) + ) + return None except Exception as e: - logger.error(f"获取节点详情失败: {str(e)}") + logger.error(f"Failed to get node details: {str(e)}") return None def get_node_edges(self, graph_id: str, node_uuid: str) -> List[EdgeInfo]: """ - 获取节点相关的所有边 - - 通过获取图谱所有边,然后过滤出与指定节点相关的边 - + Get all edges related to a node + + Gets all edges in the graph, then filters those related to the specified node. + Args: - graph_id: 图谱ID - node_uuid: 节点UUID - + graph_id: Graph ID + node_uuid: Node UUID + Returns: - 边列表 + List of edges """ - logger.info(f"获取节点 {node_uuid[:8]}... 
的相关边") - + logger.info(f"Getting edges related to node {node_uuid[:8]}...") + try: - # 获取图谱所有边,然后过滤 - all_edges = self.get_all_edges(graph_id) - - result = [] - for edge in all_edges: - # 检查边是否与指定节点相关(作为源或目标) - if edge.source_node_uuid == node_uuid or edge.target_node_uuid == node_uuid: - result.append(edge) - - logger.info(f"找到 {len(result)} 条与节点相关的边") + adapter = GraphitiAdapter.get_or_create(graph_id) + all_edges = adapter.get_all_edges() + # Filter locally + result = [EdgeInfo( + uuid=e.get("uuid", ""), + name=e.get("name", ""), + fact=e.get("fact", ""), + source_node_uuid=e.get("source_node_uuid", ""), + target_node_uuid=e.get("target_node_uuid", ""), + created_at=e.get("created_at"), + valid_at=e.get("valid_at"), + invalid_at=e.get("invalid_at"), + expired_at=e.get("expired_at"), + ) for e in all_edges if e.get("source_node_uuid") == node_uuid or e.get("target_node_uuid") == node_uuid] + + logger.info(f"Found {len(result)} edges related to the node") return result - + except Exception as e: - logger.warning(f"获取节点边失败: {str(e)}") + logger.warning(f"Failed to get node edges: {str(e)}") return [] def get_entities_by_type( @@ -783,26 +718,26 @@ def get_entities_by_type( entity_type: str ) -> List[NodeInfo]: """ - 按类型获取实体 - + Get entities by type + Args: - graph_id: 图谱ID - entity_type: 实体类型(如 Student, PublicFigure 等) - + graph_id: Graph ID + entity_type: Entity type (e.g. Student, PublicFigure, etc.) 
+ Returns: - 符合类型的实体列表 + List of entities matching the type """ - logger.info(f"获取类型为 {entity_type} 的实体...") + logger.info(f"Getting entities of type {entity_type}...") all_nodes = self.get_all_nodes(graph_id) filtered = [] for node in all_nodes: - # 检查labels是否包含指定类型 + # Check if labels contain the specified type if entity_type in node.labels: filtered.append(node) - logger.info(f"找到 {len(filtered)} 个 {entity_type} 类型的实体") + logger.info(f"Found {len(filtered)} entities of type {entity_type}") return filtered def get_entity_summary( @@ -811,27 +746,27 @@ def get_entity_summary( entity_name: str ) -> Dict[str, Any]: """ - 获取指定实体的关系摘要 - - 搜索与该实体相关的所有信息,并生成摘要 - + Get relationship summary for a specific entity + + Searches all information related to the entity and generates a summary. + Args: - graph_id: 图谱ID - entity_name: 实体名称 - + graph_id: Graph ID + entity_name: Entity name + Returns: - 实体摘要信息 + Entity summary information """ - logger.info(f"获取实体 {entity_name} 的关系摘要...") + logger.info(f"Getting relationship summary for entity {entity_name}...") - # 先搜索该实体相关的信息 + # First search for information related to the entity search_result = self.search_graph( graph_id=graph_id, query=entity_name, limit=20 ) - # 尝试在所有节点中找到该实体 + # Try to find the entity among all nodes all_nodes = self.get_all_nodes(graph_id) entity_node = None for node in all_nodes: @@ -841,7 +776,7 @@ def get_entity_summary( related_edges = [] if entity_node: - # 传入graph_id参数 + # Pass graph_id parameter related_edges = self.get_node_edges(graph_id, entity_node.uuid) return { @@ -854,27 +789,27 @@ def get_entity_summary( def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]: """ - 获取图谱的统计信息 - + Get statistics for the graph + Args: - graph_id: 图谱ID - + graph_id: Graph ID + Returns: - 统计信息 + Statistics information """ - logger.info(f"获取图谱 {graph_id} 的统计信息...") + logger.info(f"Getting statistics for graph {graph_id}...") nodes = self.get_all_nodes(graph_id) edges = self.get_all_edges(graph_id) - # 
统计实体类型分布 + # Count entity type distribution entity_types = {} for node in nodes: for label in node.labels: if label not in ["Entity", "Node"]: entity_types[label] = entity_types.get(label, 0) + 1 - # 统计关系类型分布 + # Count relation type distribution relation_types = {} for edge in edges: relation_types[edge.name] = relation_types.get(edge.name, 0) + 1 @@ -894,34 +829,34 @@ def get_simulation_context( limit: int = 30 ) -> Dict[str, Any]: """ - 获取模拟相关的上下文信息 - - 综合搜索与模拟需求相关的所有信息 - + Get simulation-related context information + + Comprehensively searches all information related to the simulation requirement. + Args: - graph_id: 图谱ID - simulation_requirement: 模拟需求描述 - limit: 每类信息的数量限制 - + graph_id: Graph ID + simulation_requirement: Simulation requirement description + limit: Quantity limit per category + Returns: - 模拟上下文信息 + Simulation context information """ - logger.info(f"获取模拟上下文: {simulation_requirement[:50]}...") + logger.info(f"Getting simulation context: {simulation_requirement[:50]}...") - # 搜索与模拟需求相关的信息 + # Search for information related to the simulation requirement search_result = self.search_graph( graph_id=graph_id, query=simulation_requirement, limit=limit ) - # 获取图谱统计 + # Get graph statistics stats = self.get_graph_statistics(graph_id) - # 获取所有实体节点 + # Get all entity nodes all_nodes = self.get_all_nodes(graph_id) - # 筛选有实际类型的实体(非纯Entity节点) + # Filter entities with actual types (not pure Entity nodes) entities = [] for node in all_nodes: custom_labels = [l for l in node.labels if l not in ["Entity", "Node"]] @@ -936,11 +871,11 @@ def get_simulation_context( "simulation_requirement": simulation_requirement, "related_facts": search_result.facts, "graph_statistics": stats, - "entities": entities[:limit], # 限制数量 + "entities": entities[:limit], "total_entities": len(entities) } - # ========== 核心检索工具(优化后) ========== + # ========== Core Retrieval Tools (Optimized) ========== def insight_forge( self, @@ -951,26 +886,26 @@ def insight_forge( max_sub_queries: int = 5 ) 
-> InsightForgeResult: """ - 【InsightForge - 深度洞察检索】 - - 最强大的混合检索函数,自动分解问题并多维度检索: - 1. 使用LLM将问题分解为多个子问题 - 2. 对每个子问题进行语义搜索 - 3. 提取相关实体并获取其详细信息 - 4. 追踪关系链 - 5. 整合所有结果,生成深度洞察 - + [InsightForge - Deep Insight Retrieval] + + The most powerful hybrid retrieval function, automatically decomposes problems and performs multi-dimensional retrieval: + 1. Use LLM to decompose the problem into multiple sub-questions + 2. Perform semantic search on each sub-question + 3. Extract related entities and get their detailed information + 4. Trace relationship chains + 5. Integrate all results and generate deep insights + Args: - graph_id: 图谱ID - query: 用户问题 - simulation_requirement: 模拟需求描述 - report_context: 报告上下文(可选,用于更精准的子问题生成) - max_sub_queries: 最大子问题数量 - + graph_id: Graph ID + query: User question + simulation_requirement: Simulation requirement description + report_context: Report context (optional, for more precise sub-question generation) + max_sub_queries: Maximum number of sub-questions + Returns: - InsightForgeResult: 深度洞察检索结果 + InsightForgeResult: Deep insight retrieval result """ - logger.info(f"InsightForge 深度洞察检索: {query[:50]}...") + logger.info(f"InsightForge deep insight retrieval: {query[:50]}...") result = InsightForgeResult( query=query, @@ -978,7 +913,7 @@ def insight_forge( sub_queries=[] ) - # Step 1: 使用LLM生成子问题 + # Step 1: Use LLM to generate sub-questions sub_queries = self._generate_sub_queries( query=query, simulation_requirement=simulation_requirement, @@ -986,9 +921,9 @@ def insight_forge( max_queries=max_sub_queries ) result.sub_queries = sub_queries - logger.info(f"生成 {len(sub_queries)} 个子问题") + logger.info(f"Generated {len(sub_queries)} sub-questions") - # Step 2: 对每个子问题进行语义搜索 + # Step 2: Perform semantic search on each sub-question all_facts = [] all_edges = [] seen_facts = set() @@ -1008,7 +943,7 @@ def insight_forge( all_edges.extend(search_result.edges) - # 对原始问题也进行搜索 + # Also search for the original question main_search = self.search_graph( 
graph_id=graph_id, query=query, @@ -1023,7 +958,7 @@ def insight_forge( result.semantic_facts = all_facts result.total_facts = len(all_facts) - # Step 3: 从边中提取相关实体UUID,只获取这些实体的信息(不获取全部节点) + # Step 3: Extract related entity UUIDs from edges entity_uuids = set() for edge_data in all_edges: if isinstance(edge_data, dict): @@ -1034,21 +969,18 @@ def insight_forge( if target_uuid: entity_uuids.add(target_uuid) - # 获取所有相关实体的详情(不限制数量,完整输出) + # Get related entity details entity_insights = [] - node_map = {} # 用于后续关系链构建 - - for uuid in list(entity_uuids): # 处理所有实体,不截断 + node_map = {} + + for uuid in list(entity_uuids): if not uuid: continue try: - # 单独获取每个相关节点的信息 - node = self.get_node_detail(uuid) + node = self.get_node_detail(graph_id, uuid) if node: node_map[uuid] = node - entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体") - - # 获取该实体相关的所有事实(不截断) + entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "Entity") related_facts = [ f for f in all_facts if node.name.lower() in f.lower() @@ -1059,18 +991,18 @@ def insight_forge( "name": node.name, "type": entity_type, "summary": node.summary, - "related_facts": related_facts # 完整输出,不截断 + "related_facts": related_facts }) except Exception as e: - logger.debug(f"获取节点 {uuid} 失败: {e}") + logger.debug(f"Failed to get node {uuid}: {e}") continue result.entity_insights = entity_insights result.total_entities = len(entity_insights) - # Step 4: 构建所有关系链(不限制数量) + # Step 4: Build relationship chains relationship_chains = [] - for edge_data in all_edges: # 处理所有边,不截断 + for edge_data in all_edges: if isinstance(edge_data, dict): source_uuid = edge_data.get('source_node_uuid', '') target_uuid = edge_data.get('target_node_uuid', '') @@ -1086,7 +1018,7 @@ def insight_forge( result.relationship_chains = relationship_chains result.total_relationships = len(relationship_chains) - logger.info(f"InsightForge完成: {result.total_facts}条事实, {result.total_entities}个实体, {result.total_relationships}条关系") 
+ logger.info(f"InsightForge complete: {result.total_facts} facts, {result.total_entities} entities, {result.total_relationships} relationships") return result def _generate_sub_queries( @@ -1096,28 +1028,24 @@ def _generate_sub_queries( report_context: str = "", max_queries: int = 5 ) -> List[str]: - """ - 使用LLM生成子问题 - - 将复杂问题分解为多个可以独立检索的子问题 - """ - system_prompt = """你是一个专业的问题分析专家。你的任务是将一个复杂问题分解为多个可以在模拟世界中独立观察的子问题。 + """Use LLM to generate sub-questions""" + system_prompt = """You are a professional question analysis expert. Your task is to decompose a complex question into multiple sub-questions that can be independently observed in a simulated world. -要求: -1. 每个子问题应该足够具体,可以在模拟世界中找到相关的Agent行为或事件 -2. 子问题应该覆盖原问题的不同维度(如:谁、什么、为什么、怎么样、何时、何地) -3. 子问题应该与模拟场景相关 -4. 返回JSON格式:{"sub_queries": ["子问题1", "子问题2", ...]}""" +Requirements: +1. Each sub-question should be specific enough to find related Agent behavior or events in the simulated world +2. Sub-questions should cover different dimensions of the original question (e.g., who, what, why, how, when, where) +3. Sub-questions should be relevant to the simulation scenario +4. 
Return in JSON format: {"sub_queries": ["sub-question 1", "sub-question 2", ...]}""" - user_prompt = f"""模拟需求背景: + user_prompt = f"""Simulation requirement background: {simulation_requirement} -{f"报告上下文:{report_context[:500]}" if report_context else ""} +{f"Report context: {report_context[:500]}" if report_context else ""} -请将以下问题分解为{max_queries}个子问题: +Please decompose the following question into {max_queries} sub-questions: {query} -返回JSON格式的子问题列表。""" +Return the sub-questions as a JSON list.""" try: response = self.llm.chat_json( @@ -1129,17 +1057,15 @@ def _generate_sub_queries( ) sub_queries = response.get("sub_queries", []) - # 确保是字符串列表 return [str(sq) for sq in sub_queries[:max_queries]] - + except Exception as e: - logger.warning(f"生成子问题失败: {str(e)},使用默认子问题") - # 降级:返回基于原问题的变体 + logger.warning(f"Failed to generate sub-questions: {str(e)}, using default sub-questions") return [ query, - f"{query} 的主要参与者", - f"{query} 的原因和影响", - f"{query} 的发展过程" + f"Main participants in {query}", + f"Causes and impacts of {query}", + f"Development process of {query}" ][:max_queries] def panorama_search( @@ -1150,40 +1076,35 @@ def panorama_search( limit: int = 50 ) -> PanoramaResult: """ - 【PanoramaSearch - 广度搜索】 - - 获取全貌视图,包括所有相关内容和历史/过期信息: - 1. 获取所有相关节点 - 2. 获取所有边(包括已过期/失效的) - 3. 分类整理当前有效和历史信息 - - 这个工具适用于需要了解事件全貌、追踪演变过程的场景。 - + [PanoramaSearch - Breadth Search] + + Get a comprehensive panoramic view, including all related content and historical/expired information. 
+ Args: - graph_id: 图谱ID - query: 搜索查询(用于相关性排序) - include_expired: 是否包含过期内容(默认True) - limit: 返回结果数量限制 - + graph_id: Graph ID + query: Search query (for relevance sorting) + include_expired: Whether to include expired content (default True) + limit: Result quantity limit + Returns: - PanoramaResult: 广度搜索结果 + PanoramaResult: Breadth search result """ - logger.info(f"PanoramaSearch 广度搜索: {query[:50]}...") + logger.info(f"PanoramaSearch breadth search: {query[:50]}...") result = PanoramaResult(query=query) - # 获取所有节点 + # Get all nodes all_nodes = self.get_all_nodes(graph_id) node_map = {n.uuid: n for n in all_nodes} result.all_nodes = all_nodes result.total_nodes = len(all_nodes) - # 获取所有边(包含时间信息) + # Get all edges (including temporal information) all_edges = self.get_all_edges(graph_id, include_temporal=True) result.all_edges = all_edges result.total_edges = len(all_edges) - # 分类事实 + # Categorize facts active_facts = [] historical_facts = [] @@ -1191,24 +1112,22 @@ def panorama_search( if not edge.fact: continue - # 为事实添加实体名称 + # Add entity names to facts source_name = node_map.get(edge.source_node_uuid, NodeInfo('', '', [], '', {})).name or edge.source_node_uuid[:8] target_name = node_map.get(edge.target_node_uuid, NodeInfo('', '', [], '', {})).name or edge.target_node_uuid[:8] - # 判断是否过期/失效 + # Check if expired/invalid is_historical = edge.is_expired or edge.is_invalid if is_historical: - # 历史/过期事实,添加时间标记 - valid_at = edge.valid_at or "未知" - invalid_at = edge.invalid_at or edge.expired_at or "未知" + valid_at = edge.valid_at or "Unknown" + invalid_at = edge.invalid_at or edge.expired_at or "Unknown" fact_with_time = f"[{valid_at} - {invalid_at}] {edge.fact}" historical_facts.append(fact_with_time) else: - # 当前有效事实 active_facts.append(edge.fact) - # 基于查询进行相关性排序 + # Sort by relevance based on query query_lower = query.lower() keywords = [w.strip() for w in query_lower.replace(',', ' ').replace(',', ' ').split() if len(w.strip()) > 1] @@ -1222,7 +1141,7 @@ def 
relevance_score(fact: str) -> int: score += 10 return score - # 排序并限制数量 + # Sort and limit quantity active_facts.sort(key=relevance_score, reverse=True) historical_facts.sort(key=relevance_score, reverse=True) @@ -1231,7 +1150,7 @@ def relevance_score(fact: str) -> int: result.active_count = len(active_facts) result.historical_count = len(historical_facts) - logger.info(f"PanoramaSearch完成: {result.active_count}条有效, {result.historical_count}条历史") + logger.info(f"PanoramaSearch complete: {result.active_count} valid, {result.historical_count} historical") return result def quick_search( @@ -1241,24 +1160,21 @@ def quick_search( limit: int = 10 ) -> SearchResult: """ - 【QuickSearch - 简单搜索】 - - 快速、轻量级的检索工具: - 1. 直接调用Zep语义搜索 - 2. 返回最相关的结果 - 3. 适用于简单、直接的检索需求 - + [QuickSearch - Simple Search] + + Fast and lightweight retrieval tool. + Args: - graph_id: 图谱ID - query: 搜索查询 - limit: 返回结果数量 - + graph_id: Graph ID + query: Search query + limit: Number of results to return + Returns: - SearchResult: 搜索结果 + SearchResult """ - logger.info(f"QuickSearch 简单搜索: {query[:50]}...") + logger.info(f"QuickSearch simple search: {query[:50]}...") - # 直接调用现有的search_graph方法 + # Directly call existing search_graph method result = self.search_graph( graph_id=graph_id, query=query, @@ -1266,7 +1182,7 @@ def quick_search( scope="edges" ) - logger.info(f"QuickSearch完成: {result.total_count}条结果") + logger.info(f"QuickSearch complete: {result.total_count} results") return result def interview_agents( @@ -1278,53 +1194,48 @@ def interview_agents( custom_questions: List[str] = None ) -> InterviewResult: """ - 【InterviewAgents - 深度采访】 - - 调用真实的OASIS采访API,采访模拟中正在运行的Agent: - 1. 自动读取人设文件,了解所有模拟Agent - 2. 使用LLM分析采访需求,智能选择最相关的Agent - 3. 使用LLM生成采访问题 - 4. 调用 /api/simulation/interview/batch 接口进行真实采访(双平台同时采访) - 5. 
整合所有采访结果,生成采访报告 - - 【重要】此功能需要模拟环境处于运行状态(OASIS环境未关闭) - - 【使用场景】 - - 需要从不同角色视角了解事件看法 - - 需要收集多方意见和观点 - - 需要获取模拟Agent的真实回答(非LLM模拟) - + [InterviewAgents - Deep Interview] + + Call the real OASIS interview API to interview Agents running in the simulation: + 1. Automatically read profile files to understand all simulated Agents + 2. Use LLM to analyze interview requirements and intelligently select the most relevant Agents + 3. Use LLM to generate interview questions + 4. Call /api/simulation/interview/batch endpoint for real interviews (dual platform) + 5. Integrate all interview results and generate interview report + + [Important] This feature requires the simulation environment to be running (OASIS environment not closed) + Args: - simulation_id: 模拟ID(用于定位人设文件和调用采访API) - interview_requirement: 采访需求描述(非结构化,如"了解学生对事件的看法") - simulation_requirement: 模拟需求背景(可选) - max_agents: 最多采访的Agent数量 - custom_questions: 自定义采访问题(可选,若不提供则自动生成) - + simulation_id: Simulation ID (for locating profile files and calling interview API) + interview_requirement: Interview requirement description + simulation_requirement: Simulation requirement background (optional) + max_agents: Maximum number of Agents to interview + custom_questions: Custom interview questions (optional, auto-generated if not provided) + Returns: - InterviewResult: 采访结果 + InterviewResult: Interview result """ from .simulation_runner import SimulationRunner - logger.info(f"InterviewAgents 深度采访(真实API): {interview_requirement[:50]}...") + logger.info(f"InterviewAgents deep interview (real API): {interview_requirement[:50]}...") result = InterviewResult( interview_topic=interview_requirement, interview_questions=custom_questions or [] ) - # Step 1: 读取人设文件 + # Step 1: Read agent profile files profiles = self._load_agent_profiles(simulation_id) if not profiles: - logger.warning(f"未找到模拟 {simulation_id} 的人设文件") - result.summary = "未找到可采访的Agent人设文件" + logger.warning(f"No profile files found for simulation {simulation_id}") + 
result.summary = "No Agent profile files found for interview" return result result.total_agents = len(profiles) - logger.info(f"加载到 {len(profiles)} 个Agent人设") + logger.info(f"Loaded {len(profiles)} Agent profiles") - # Step 2: 使用LLM选择要采访的Agent(返回agent_id列表) + # Step 2: Use LLM to select Agents for interview selected_agents, selected_indices, selection_reasoning = self._select_agents_for_interview( profiles=profiles, interview_requirement=interview_requirement, @@ -1334,156 +1245,142 @@ def interview_agents( result.selected_agents = selected_agents result.selection_reasoning = selection_reasoning - logger.info(f"选择了 {len(selected_agents)} 个Agent进行采访: {selected_indices}") + logger.info(f"Selected {len(selected_agents)} Agents for interview: {selected_indices}") - # Step 3: 生成采访问题(如果没有提供) + # Step 3: Generate interview questions if not result.interview_questions: result.interview_questions = self._generate_interview_questions( interview_requirement=interview_requirement, simulation_requirement=simulation_requirement, selected_agents=selected_agents ) - logger.info(f"生成了 {len(result.interview_questions)} 个采访问题") + logger.info(f"Generated {len(result.interview_questions)} interview questions") - # 将问题合并为一个采访prompt + # Combine questions into a single interview prompt combined_prompt = "\n".join([f"{i+1}. {q}" for i, q in enumerate(result.interview_questions)]) - # 添加优化前缀,约束Agent回复格式 INTERVIEW_PROMPT_PREFIX = ( - "你正在接受一次采访。请结合你的人设、所有的过往记忆与行动," - "以纯文本方式直接回答以下问题。\n" - "回复要求:\n" - "1. 直接用自然语言回答,不要调用任何工具\n" - "2. 不要返回JSON格式或工具调用格式\n" - "3. 不要使用Markdown标题(如#、##、###)\n" - "4. 按问题编号逐一回答,每个回答以「问题X:」开头(X为问题编号)\n" - "5. 每个问题的回答之间用空行分隔\n" - "6. 回答要有实质内容,每个问题至少回答2-3句话\n\n" + "You are being interviewed. Please combine your character profile, all past memories and actions, " + "and directly answer the following questions in plain text.\n" + "Response requirements:\n" + "1. Answer directly in natural language, do not call any tools\n" + "2. 
Do not return JSON format or tool call format\n" + "3. Do not use Markdown headings (e.g., #, ##, ###)\n" + "4. Answer the questions in order, with each answer starting with 'Question X:' (X is the question number)\n" + "5. Separate each answer with a blank line\n" + "6. Provide substantive answers, at least 2-3 sentences per question\n\n" ) optimized_prompt = f"{INTERVIEW_PROMPT_PREFIX}{combined_prompt}" - - # Step 4: 调用真实的采访API(不指定platform,默认双平台同时采访) + + # Step 4: Call the real interview API try: - # 构建批量采访列表(不指定platform,双平台采访) interviews_request = [] for agent_idx in selected_indices: interviews_request.append({ "agent_id": agent_idx, - "prompt": optimized_prompt # 使用优化后的prompt - # 不指定platform,API会在twitter和reddit两个平台都采访 + "prompt": optimized_prompt }) - - logger.info(f"调用批量采访API(双平台): {len(interviews_request)} 个Agent") - - # 调用 SimulationRunner 的批量采访方法(不传platform,双平台采访) + + logger.info(f"Calling batch interview API (dual platform): {len(interviews_request)} Agents") + api_result = SimulationRunner.interview_agents_batch( simulation_id=simulation_id, interviews=interviews_request, - platform=None, # 不指定platform,双平台采访 - timeout=180.0 # 双平台需要更长超时 + platform=None, + timeout=180.0 ) - - logger.info(f"采访API返回: {api_result.get('interviews_count', 0)} 个结果, success={api_result.get('success')}") - - # 检查API调用是否成功 + + logger.info(f"Interview API returned: {api_result.get('interviews_count', 0)} results, success={api_result.get('success')}") + if not api_result.get("success", False): - error_msg = api_result.get("error", "未知错误") - logger.warning(f"采访API返回失败: {error_msg}") - result.summary = f"采访API调用失败:{error_msg}。请检查OASIS模拟环境状态。" + error_msg = api_result.get("error", "Unknown error") + logger.warning(f"Interview API call failed: {error_msg}") + result.summary = f"Interview API call failed: {error_msg}. Please check the OASIS simulation environment status." 
return result - - # Step 5: 解析API返回结果,构建AgentInterview对象 - # 双平台模式返回格式: {"twitter_0": {...}, "reddit_0": {...}, "twitter_1": {...}, ...} + + # Step 5: Parse API response api_data = api_result.get("result", {}) results_dict = api_data.get("results", {}) if isinstance(api_data, dict) else {} - + for i, agent_idx in enumerate(selected_indices): agent = selected_agents[i] agent_name = agent.get("realname", agent.get("username", f"Agent_{agent_idx}")) - agent_role = agent.get("profession", "未知") + agent_role = agent.get("profession", "Unknown") agent_bio = agent.get("bio", "") - - # 获取该Agent在两个平台的采访结果 + twitter_result = results_dict.get(f"twitter_{agent_idx}", {}) reddit_result = results_dict.get(f"reddit_{agent_idx}", {}) twitter_response = twitter_result.get("response", "") reddit_response = reddit_result.get("response", "") - # 清理可能的工具调用 JSON 包裹 twitter_response = self._clean_tool_call_response(twitter_response) reddit_response = self._clean_tool_call_response(reddit_response) - # 始终输出双平台标记 - twitter_text = twitter_response if twitter_response else "(该平台未获得回复)" - reddit_text = reddit_response if reddit_response else "(该平台未获得回复)" - response_text = f"【Twitter平台回答】\n{twitter_text}\n\n【Reddit平台回答】\n{reddit_text}" + twitter_text = twitter_response if twitter_response else "(No response from this platform)" + reddit_text = reddit_response if reddit_response else "(No response from this platform)" + response_text = f"[Twitter Platform Response]\n{twitter_text}\n\n[Reddit Platform Response]\n{reddit_text}" - # 提取关键引言(从两个平台的回答中) import re combined_responses = f"{twitter_response} {reddit_response}" - # 清理响应文本:去掉标记、编号、Markdown 等干扰 clean_text = re.sub(r'#{1,6}\s+', '', combined_responses) clean_text = re.sub(r'\{[^}]*tool_name[^}]*\}', '', clean_text) clean_text = re.sub(r'[*_`|>~\-]{2,}', '', clean_text) - clean_text = re.sub(r'问题\d+[::]\s*', '', clean_text) + clean_text = re.sub(r'Question\d+[::]\s*', '', clean_text) clean_text = re.sub(r'【[^】]+】', '', clean_text) - # 策略1(主): 
提取完整的有实质内容的句子 sentences = re.split(r'[。!?]', clean_text) meaningful = [ s.strip() for s in sentences if 20 <= len(s.strip()) <= 150 and not re.match(r'^[\s\W,,;;::、]+', s.strip()) - and not s.strip().startswith(('{', '问题')) + and not s.strip().startswith(('{', 'Question')) ] meaningful.sort(key=len, reverse=True) key_quotes = [s + "。" for s in meaningful[:3]] - # 策略2(补充): 正确配对的中文引号「」内长文本 if not key_quotes: paired = re.findall(r'\u201c([^\u201c\u201d]{15,100})\u201d', clean_text) paired += re.findall(r'\u300c([^\u300c\u300d]{15,100})\u300d', clean_text) key_quotes = [q for q in paired if not re.match(r'^[,,;;::、]', q)][:3] - + interview = AgentInterview( agent_name=agent_name, agent_role=agent_role, - agent_bio=agent_bio[:1000], # 扩大bio长度限制 + agent_bio=agent_bio[:1000], question=combined_prompt, response=response_text, key_quotes=key_quotes[:5] ) result.interviews.append(interview) - + result.interviewed_count = len(result.interviews) - + except ValueError as e: - # 模拟环境未运行 - logger.warning(f"采访API调用失败(环境未运行?): {e}") - result.summary = f"采访失败:{str(e)}。模拟环境可能已关闭,请确保OASIS环境正在运行。" + logger.warning(f"Interview API call failed (environment not running?): {e}") + result.summary = f"Interview failed: {str(e)}. The simulation environment may be closed. Please ensure the OASIS environment is running." 
return result except Exception as e: - logger.error(f"采访API调用异常: {e}") + logger.error(f"Interview API call exception: {e}") import traceback logger.error(traceback.format_exc()) - result.summary = f"采访过程发生错误:{str(e)}" + result.summary = f"An error occurred during the interview process: {str(e)}" return result - # Step 6: 生成采访摘要 + # Step 6: Generate interview summary if result.interviews: result.summary = self._generate_interview_summary( interviews=result.interviews, interview_requirement=interview_requirement ) - logger.info(f"InterviewAgents完成: 采访了 {result.interviewed_count} 个Agent(双平台)") + logger.info(f"InterviewAgents complete: Interviewed {result.interviewed_count} Agents (dual platform)") return result @staticmethod def _clean_tool_call_response(response: str) -> str: - """清理 Agent 回复中的 JSON 工具调用包裹,提取实际内容""" + """Clean JSON tool call wrappers in Agent responses and extract actual content""" if not response or not response.strip().startswith('{'): return response text = response.strip() @@ -1503,11 +1400,11 @@ def _clean_tool_call_response(response: str) -> str: return response def _load_agent_profiles(self, simulation_id: str) -> List[Dict[str, Any]]: - """加载模拟的Agent人设文件""" + """Load Agent profile files for simulation""" import os import csv - # 构建人设文件路径 + # Build profile file path sim_dir = os.path.join( os.path.dirname(__file__), f'../../uploads/simulations/{simulation_id}' @@ -1515,36 +1412,35 @@ def _load_agent_profiles(self, simulation_id: str) -> List[Dict[str, Any]]: profiles = [] - # 优先尝试读取Reddit JSON格式 + # Preferentially try to read Reddit JSON format reddit_profile_path = os.path.join(sim_dir, "reddit_profiles.json") if os.path.exists(reddit_profile_path): try: with open(reddit_profile_path, 'r', encoding='utf-8') as f: profiles = json.load(f) - logger.info(f"从 reddit_profiles.json 加载了 {len(profiles)} 个人设") + logger.info(f"Loaded {len(profiles)} profiles from reddit_profiles.json") return profiles except Exception as e: - logger.warning(f"读取 
reddit_profiles.json 失败: {e}") + logger.warning(f"Failed to read reddit_profiles.json: {e}") - # 尝试读取Twitter CSV格式 + # Try to read Twitter CSV format twitter_profile_path = os.path.join(sim_dir, "twitter_profiles.csv") if os.path.exists(twitter_profile_path): try: with open(twitter_profile_path, 'r', encoding='utf-8') as f: reader = csv.DictReader(f) for row in reader: - # CSV格式转换为统一格式 - profiles.append({ + profiles.append({ "realname": row.get("name", ""), "username": row.get("username", ""), "bio": row.get("description", ""), "persona": row.get("user_char", ""), - "profession": "未知" + "profession": "Unknown" }) - logger.info(f"从 twitter_profiles.csv 加载了 {len(profiles)} 个人设") + logger.info(f"Loaded {len(profiles)} profiles from twitter_profiles.csv") return profiles except Exception as e: - logger.warning(f"读取 twitter_profiles.csv 失败: {e}") + logger.warning(f"Failed to read twitter_profiles.csv: {e}") return profiles @@ -1555,52 +1451,44 @@ def _select_agents_for_interview( simulation_requirement: str, max_agents: int ) -> tuple: - """ - 使用LLM选择要采访的Agent - - Returns: - tuple: (selected_agents, selected_indices, reasoning) - - selected_agents: 选中Agent的完整信息列表 - - selected_indices: 选中Agent的索引列表(用于API调用) - - reasoning: 选择理由 - """ - - # 构建Agent摘要列表 + """Use LLM to select Agents for interview""" + + # Build agent summaries agent_summaries = [] for i, profile in enumerate(profiles): summary = { "index": i, "name": profile.get("realname", profile.get("username", f"Agent_{i}")), - "profession": profile.get("profession", "未知"), + "profession": profile.get("profession", "Unknown"), "bio": profile.get("bio", "")[:200], "interested_topics": profile.get("interested_topics", []) } agent_summaries.append(summary) - system_prompt = """你是一个专业的采访策划专家。你的任务是根据采访需求,从模拟Agent列表中选择最适合采访的对象。 + system_prompt = """You are a professional interview planning expert. Your task is to select the most suitable Agents for interview from the simulated Agent list based on the interview requirements. 
-选择标准: -1. Agent的身份/职业与采访主题相关 -2. Agent可能持有独特或有价值的观点 -3. 选择多样化的视角(如:支持方、反对方、中立方、专业人士等) -4. 优先选择与事件直接相关的角色 +Selection Criteria: +1. Agent's identity/profession is relevant to the interview topic +2. Agent may hold unique or valuable perspectives +3. Select diverse perspectives (e.g., supporters, opposers, neutral, experts, etc.) +4. Prioritize roles directly related to the event -返回JSON格式: +Return JSON format: { - "selected_indices": [选中Agent的索引列表], - "reasoning": "选择理由说明" + "selected_indices": [List of indices of selected Agents], + "reasoning": "Explanation of selection rationale" }""" - user_prompt = f"""采访需求: + user_prompt = f"""Interview Requirement: {interview_requirement} -模拟背景: -{simulation_requirement if simulation_requirement else "未提供"} +Simulation Background: +{simulation_requirement if simulation_requirement else "Not provided"} -可选择的Agent列表(共{len(agent_summaries)}个): +Available Agent List ({len(agent_summaries)} total): {json.dumps(agent_summaries, ensure_ascii=False, indent=2)} -请选择最多{max_agents}个最适合采访的Agent,并说明选择理由。""" +Please select up to {max_agents} most suitable Agents for interview and explain your selection rationale.""" try: response = self.llm.chat_json( @@ -1612,9 +1500,9 @@ def _select_agents_for_interview( ) selected_indices = response.get("selected_indices", [])[:max_agents] - reasoning = response.get("reasoning", "基于相关性自动选择") + reasoning = response.get("reasoning", "Automatically selected based on relevance") - # 获取选中的Agent完整信息 + # Get full info for selected agents selected_agents = [] valid_indices = [] for idx in selected_indices: @@ -1625,11 +1513,10 @@ def _select_agents_for_interview( return selected_agents, valid_indices, reasoning except Exception as e: - logger.warning(f"LLM选择Agent失败,使用默认选择: {e}") - # 降级:选择前N个 + logger.warning(f"LLM agent selection failed, using default selection: {e}") selected = profiles[:max_agents] indices = list(range(min(max_agents, len(profiles)))) - return selected, indices, "使用默认选择策略" + return selected, 
indices, "Using default selection strategy" def _generate_interview_questions( self, @@ -1637,29 +1524,29 @@ def _generate_interview_questions( simulation_requirement: str, selected_agents: List[Dict[str, Any]] ) -> List[str]: - """使用LLM生成采访问题""" - - agent_roles = [a.get("profession", "未知") for a in selected_agents] - - system_prompt = """你是一个专业的记者/采访者。根据采访需求,生成3-5个深度采访问题。 + """Use LLM to generate interview questions""" + + agent_roles = [a.get("profession", "Unknown") for a in selected_agents] + + system_prompt = """You are a professional journalist/interviewer. Based on the interview requirements, generate 3-5 deep interview questions. -问题要求: -1. 开放性问题,鼓励详细回答 -2. 针对不同角色可能有不同答案 -3. 涵盖事实、观点、感受等多个维度 -4. 语言自然,像真实采访一样 -5. 每个问题控制在50字以内,简洁明了 -6. 直接提问,不要包含背景说明或前缀 +Question Requirements: +1. Open-ended questions that encourage detailed answers +2. Questions that may have different answers for different roles +3. Cover multiple dimensions: facts, viewpoints, feelings, etc. +4. Natural language, like real interviews +5. Keep each question under 50 characters, concise and clear +6. 
Ask directly, do not include background explanation or prefix -返回JSON格式:{"questions": ["问题1", "问题2", ...]}""" +Return JSON format: {"questions": ["question1", "question2", ...]}""" - user_prompt = f"""采访需求:{interview_requirement} + user_prompt = f"""Interview Requirement: {interview_requirement} -模拟背景:{simulation_requirement if simulation_requirement else "未提供"} +Simulation Background: {simulation_requirement if simulation_requirement else "Not provided"} -采访对象角色:{', '.join(agent_roles)} +Interview Subject Roles: {', '.join(agent_roles)} -请生成3-5个采访问题。""" +Please generate 3-5 interview questions.""" try: response = self.llm.chat_json( @@ -1670,14 +1557,14 @@ def _generate_interview_questions( temperature=0.5 ) - return response.get("questions", [f"关于{interview_requirement},您有什么看法?"]) - + return response.get("questions", [f"What is your perspective on {interview_requirement}?"]) + except Exception as e: - logger.warning(f"生成采访问题失败: {e}") + logger.warning(f"Failed to generate interview questions: {e}") return [ - f"关于{interview_requirement},您的观点是什么?", - "这件事对您或您所代表的群体有什么影响?", - "您认为应该如何解决或改进这个问题?" + f"What is your perspective on {interview_requirement}?", + "What impact does this have on you or the group you represent?", + "How do you think this issue should be solved or improved?" ] def _generate_interview_summary( @@ -1685,38 +1572,37 @@ def _generate_interview_summary( interviews: List[AgentInterview], interview_requirement: str ) -> str: - """生成采访摘要""" - + """Generate interview summary""" + if not interviews: - return "未完成任何采访" - - # 收集所有采访内容 + return "No interviews completed" + interview_texts = [] for interview in interviews: - interview_texts.append(f"【{interview.agent_name}({interview.agent_role})】\n{interview.response[:500]}") - - system_prompt = """你是一个专业的新闻编辑。请根据多位受访者的回答,生成一份采访摘要。 + interview_texts.append(f"[{interview.agent_name} ({interview.agent_role})]\n{interview.response[:500]}") + + system_prompt = """You are a professional news editor. 
Please generate an interview summary based on the responses from multiple interviewees. -摘要要求: -1. 提炼各方主要观点 -2. 指出观点的共识和分歧 -3. 突出有价值的引言 -4. 客观中立,不偏袒任何一方 -5. 控制在1000字内 +Summary Requirements: +1. Extract main viewpoints from all parties +2. Point out consensus and disagreement among viewpoints +3. Highlight valuable quotes +4. Remain objective and neutral, do not favor any side +5. Keep it under 1000 words -格式约束(必须遵守): -- 使用纯文本段落,用空行分隔不同部分 -- 不要使用Markdown标题(如#、##、###) -- 不要使用分割线(如---、***) -- 引用受访者原话时使用中文引号「」 -- 可以使用**加粗**标记关键词,但不要使用其他Markdown语法""" +Format Constraints (Must Follow): +- Use plain text paragraphs, separated by blank lines +- Do not use Markdown headings (e.g., #, ##, ###) +- Do not use dividers (e.g., ---, ***) +- Use appropriate quotes when citing interviewees +- Can use **bold** to mark keywords, but do not use other Markdown syntax""" - user_prompt = f"""采访主题:{interview_requirement} + user_prompt = f"""Interview Topic: {interview_requirement} -采访内容: +Interview Content: {"".join(interview_texts)} -请生成采访摘要。""" +Please generate an interview summary.""" try: summary = self.llm.chat( @@ -1730,6 +1616,5 @@ def _generate_interview_summary( return summary except Exception as e: - logger.warning(f"生成采访摘要失败: {e}") - # 降级:简单拼接 - return f"共采访了{len(interviews)}位受访者,包括:" + "、".join([i.agent_name for i in interviews]) + logger.warning(f"Failed to generate interview summary: {e}") + return f"Interviewed {len(interviews)} interviewees, including: " + ", ".join([i.agent_name for i in interviews]) diff --git a/backend/app/utils/zep_paging.py b/backend/app/utils/zep_paging.py deleted file mode 100644 index 943cd1ae2..000000000 --- a/backend/app/utils/zep_paging.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Zep Graph 分页读取工具。 - -Zep 的 node/edge 列表接口使用 UUID cursor 分页, -本模块封装自动翻页逻辑(含单页重试),对调用方透明地返回完整列表。 -""" - -from __future__ import annotations - -import time -from collections.abc import Callable -from typing import Any - -from zep_cloud import InternalServerError -from 
zep_cloud.client import Zep - -from .logger import get_logger - -logger = get_logger('mirofish.zep_paging') - -_DEFAULT_PAGE_SIZE = 100 -_MAX_NODES = 2000 -_DEFAULT_MAX_RETRIES = 3 -_DEFAULT_RETRY_DELAY = 2.0 # seconds, doubles each retry - - -def _fetch_page_with_retry( - api_call: Callable[..., list[Any]], - *args: Any, - max_retries: int = _DEFAULT_MAX_RETRIES, - retry_delay: float = _DEFAULT_RETRY_DELAY, - page_description: str = "page", - **kwargs: Any, -) -> list[Any]: - """单页请求,失败时指数退避重试。仅重试网络/IO类瞬态错误。""" - if max_retries < 1: - raise ValueError("max_retries must be >= 1") - - last_exception: Exception | None = None - delay = retry_delay - - for attempt in range(max_retries): - try: - return api_call(*args, **kwargs) - except (ConnectionError, TimeoutError, OSError, InternalServerError) as e: - last_exception = e - if attempt < max_retries - 1: - logger.warning( - f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {delay:.1f}s..." - ) - time.sleep(delay) - delay *= 2 - else: - logger.error(f"Zep {page_description} failed after {max_retries} attempts: {str(e)}") - - assert last_exception is not None - raise last_exception - - -def fetch_all_nodes( - client: Zep, - graph_id: str, - page_size: int = _DEFAULT_PAGE_SIZE, - max_items: int = _MAX_NODES, - max_retries: int = _DEFAULT_MAX_RETRIES, - retry_delay: float = _DEFAULT_RETRY_DELAY, -) -> list[Any]: - """分页获取图谱节点,最多返回 max_items 条(默认 2000)。每页请求自带重试。""" - all_nodes: list[Any] = [] - cursor: str | None = None - page_num = 0 - - while True: - kwargs: dict[str, Any] = {"limit": page_size} - if cursor is not None: - kwargs["uuid_cursor"] = cursor - - page_num += 1 - batch = _fetch_page_with_retry( - client.graph.node.get_by_graph_id, - graph_id, - max_retries=max_retries, - retry_delay=retry_delay, - page_description=f"fetch nodes page {page_num} (graph={graph_id})", - **kwargs, - ) - if not batch: - break - - all_nodes.extend(batch) - if len(all_nodes) >= max_items: - all_nodes = 
all_nodes[:max_items] - logger.warning(f"Node count reached limit ({max_items}), stopping pagination for graph {graph_id}") - break - if len(batch) < page_size: - break - - cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None) - if cursor is None: - logger.warning(f"Node missing uuid field, stopping pagination at {len(all_nodes)} nodes") - break - - return all_nodes - - -def fetch_all_edges( - client: Zep, - graph_id: str, - page_size: int = _DEFAULT_PAGE_SIZE, - max_retries: int = _DEFAULT_MAX_RETRIES, - retry_delay: float = _DEFAULT_RETRY_DELAY, -) -> list[Any]: - """分页获取图谱所有边,返回完整列表。每页请求自带重试。""" - all_edges: list[Any] = [] - cursor: str | None = None - page_num = 0 - - while True: - kwargs: dict[str, Any] = {"limit": page_size} - if cursor is not None: - kwargs["uuid_cursor"] = cursor - - page_num += 1 - batch = _fetch_page_with_retry( - client.graph.edge.get_by_graph_id, - graph_id, - max_retries=max_retries, - retry_delay=retry_delay, - page_description=f"fetch edges page {page_num} (graph={graph_id})", - **kwargs, - ) - if not batch: - break - - all_edges.extend(batch) - if len(batch) < page_size: - break - - cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None) - if cursor is None: - logger.warning(f"Edge missing uuid field, stopping pagination at {len(all_edges)} edges") - break - - return all_edges diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4f5361d53..26092a83f 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -16,8 +16,8 @@ dependencies = [ # LLM 相关 "openai>=1.0.0", - # Zep Cloud - "zep-cloud==3.13.0", + # Graphiti (local graph DB) + "graphiti-core[google-genai,kuzu]>=0.5.0", # OASIS 社交媒体模拟 "camel-oasis==0.2.5", diff --git a/backend/uv.lock b/backend/uv.lock index f1ce4b60e..5dc2cd1a2 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -475,6 +475,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] +[[package]] +name = "diskcache" +version = "5.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -592,6 +601,68 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl", hash = "sha256:8bf1fe301b7d8acfa6e8571e3b1c3d158f909666642431cc78a1b7b4dbc5ec5b", size = 201422, upload-time = "2025-12-03T15:23:41.434Z" }, ] +[[package]] +name = "google-auth" +version = "2.49.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "pyasn1-modules" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ea/80/6a696a07d3d3b0a92488933532f03dbefa4a24ab80fb231395b9a2a1be77/google_auth-2.49.1.tar.gz", hash = "sha256:16d40da1c3c5a0533f57d268fe72e0ebb0ae1cc3b567024122651c045d879b64", size = 333825, upload-time = "2026-03-12T19:30:58.135Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e9/eb/c6c2478d8a8d633460be40e2a8a6f8f429171997a35a96f81d3b680dec83/google_auth-2.49.1-py3-none-any.whl", hash = "sha256:195ebe3dca18eddd1b3db5edc5189b76c13e96f29e73043b923ebcf3f1a860f7", size = 240737, upload-time = "2026-03-12T19:30:53.159Z" }, +] + +[package.optional-dependencies] +requests = [ + { name = "requests" }, +] + +[[package]] +name = "google-genai" +version = "1.68.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "google-auth", extra = ["requests"] }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "sniffio" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9c/2c/f059982dbcb658cc535c81bbcbe7e2c040d675f4b563b03cdb01018a4bc3/google_genai-1.68.0.tar.gz", hash = "sha256:ac30c0b8bc630f9372993a97e4a11dae0e36f2e10d7c55eacdca95a9fa14ca96", size = 511285, upload-time = "2026-03-18T01:03:18.243Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/de/7d3ee9c94b74c3578ea4f88d45e8de9405902f857932334d81e89bce3dfa/google_genai-1.68.0-py3-none-any.whl", hash = "sha256:a1bc9919c0e2ea2907d1e319b65471d3d6d58c54822039a249fe1323e4178d15", size = 750912, upload-time = "2026-03-18T01:03:15.983Z" }, +] + +[[package]] +name = "graphiti-core" +version = "0.11.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "diskcache" }, + { name = "neo4j" }, + { name = "numpy" }, + { name = "openai" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "tenacity" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/94/3f84400e5f02ea8e9dc79784202de4173cbc16f4b3ad1bd4302da888e4d8/graphiti_core-0.11.6.tar.gz", hash = "sha256:31d26621834d7d4b8865059ab749feb18af15937b59c69598a640a5dfabea331", size = 71928, upload-time = "2025-05-15T17:58:02.304Z" } +wheels = [ + { url 
= "https://files.pythonhosted.org/packages/ac/2e/c8f22f01585bf173d1c82f6d4615511aebc75aeda764c69aa394446fa93c/graphiti_core-0.11.6-py3-none-any.whl", hash = "sha256:6ec4807a884f5ea88b942d0c8b7bcd2e107c7358ab4f98ef2a2092c229929707", size = 111001, upload-time = "2025-05-15T17:58:00.542Z" }, +] + +[package.optional-dependencies] +google-genai = [ + { name = "google-genai" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -1248,11 +1319,11 @@ dependencies = [ { name = "charset-normalizer" }, { name = "flask" }, { name = "flask-cors" }, + { name = "graphiti-core", extra = ["google-genai"] }, { name = "openai" }, { name = "pydantic" }, { name = "pymupdf" }, { name = "python-dotenv" }, - { name = "zep-cloud" }, ] [package.optional-dependencies] @@ -1276,6 +1347,7 @@ requires-dist = [ { name = "charset-normalizer", specifier = ">=3.0.0" }, { name = "flask", specifier = ">=3.0.0" }, { name = "flask-cors", specifier = ">=6.0.0" }, + { name = "graphiti-core", extras = ["google-genai", "kuzu"], specifier = ">=0.5.0" }, { name = "openai", specifier = ">=1.0.0" }, { name = "pipreqs", marker = "extra == 'dev'", specifier = ">=0.5.0" }, { name = "pydantic", specifier = ">=2.0.0" }, @@ -1283,7 +1355,6 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, - { name = "zep-cloud", specifier = "==3.13.0" }, ] provides-extras = ["dev"] @@ -1916,6 +1987,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, ] +[[package]] +name = "pyasn1" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/5c/5f/6583902b6f79b399c9c40674ac384fd9cd77805f9e6205075f828ef11fb2/pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf", size = 148685, upload-time = "2026-03-17T01:06:53.382Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/a0/7d793dce3fa811fe047d6ae2431c672364b462850c6235ae306c0efd025f/pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde", size = 83997, upload-time = "2026-03-17T01:06:52.036Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + [[package]] name = "pycparser" version = "2.23" @@ -2987,6 +3079,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = 
"sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, +] + [[package]] name = "texttable" version = "1.7.0" @@ -3488,19 +3589,3 @@ sdist = { url = "https://files.pythonhosted.org/packages/d4/c8/cc640404a0981e6c1 wheels = [ { url = "https://files.pythonhosted.org/packages/8b/90/89a2ff242ccab6a24fbab18dbbabc67c51a6f0ed01f9a0f41689dc177419/yarg-0.1.9-py2.py3-none-any.whl", hash = "sha256:4f9cebdc00fac946c9bf2783d634e538a71c7d280a4d806d45fd4dc0ef441492", size = 19162, upload-time = "2014-08-11T22:01:41.104Z" }, ] - -[[package]] -name = "zep-cloud" -version = "3.13.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "httpx" }, - { name = "pydantic" }, - { name = "pydantic-core" }, - { name = "python-dateutil" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/32/c7/c835debf13302f8aaf8d0561ac6ff5a9bc15cc140cd692a1330fb1900c55/zep_cloud-3.13.0.tar.gz", hash = "sha256:c55d9c511773bb2177ae8e08546141404f87d2099affafabd7ec4b4505763e48", size = 63116, upload-time = "2025-11-20T15:25:40.745Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/e1/bbf03c6c8007c0cb238780e7fc6d8e1a52633893933a41aa09678618985a/zep_cloud-3.13.0-py3-none-any.whl", hash = "sha256:b2fbdeef73e262194c8f67b58f76471de6ee87e1a629541a09d8f7bbf475f12b", size = 110601, upload-time = "2025-11-20T15:25:38.484Z" }, -] diff --git a/docker-compose.yml b/docker-compose.yml index 637f1dfae..53d52709f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,8 +1,7 @@ services: mirofish: - image: ghcr.io/666ghj/mirofish:latest - # 
加速镜像(如拉取缓慢可替换上方地址) - # image: ghcr.nju.edu.cn/666ghj/mirofish:latest + build: . + image: mirofish:local container_name: mirofish env_file: - .env diff --git a/frontend/src/components/GraphPanel.vue b/frontend/src/components/GraphPanel.vue index 314c966e4..83c4ae028 100644 --- a/frontend/src/components/GraphPanel.vue +++ b/frontend/src/components/GraphPanel.vue @@ -2,24 +2,24 @@
Graph Relationship Visualization - +
- -
- +
- +
@@ -27,10 +27,10 @@
- {{ isSimulating ? 'GraphRAG长短期记忆实时更新中' : '实时更新中...' }} + {{ isSimulating ? 'GraphRAG long/short-term memory updating in real-time' : 'Updating in real-time...' }}
- +
@@ -39,8 +39,8 @@
- 还有少量内容处理中,建议稍后手动刷新图谱 -
- +
{{ selectedItem.type === 'node' ? 'Node Details' : 'Relationship' }} @@ -58,7 +58,7 @@
- +
Name: @@ -101,9 +101,9 @@
- +
- + - +