diff --git a/.env.example b/.env.example index 78a3b72c0..8300e4254 100644 --- a/.env.example +++ b/.env.example @@ -1,16 +1,45 @@ -# LLM API配置(支持 OpenAI SDK 格式的任意 LLM API) -# 推荐使用阿里百炼平台qwen-plus模型:https://bailian.console.aliyun.com/ -# 注意消耗较大,可先进行小于40轮的模拟尝试 +# LLM API configuration (supports any LLM API compatible with the OpenAI SDK format) +# Recommended: use the qwen-plus model on Alibaba Bailian: https://bailian.console.aliyun.com/ +# Note: usage can be expensive, so try simulations with fewer than 40 rounds first LLM_API_KEY=your_api_key_here -LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 -LLM_MODEL_NAME=qwen-plus +LLM_BASE_URL=https://api.openai.com/v1 +LLM_MODEL_NAME=gpt-4o -# ===== ZEP记忆图谱配置 ===== -# 每月免费额度即可支撑简单使用:https://app.getzep.com/ +# ===== Graph backend selection ===== +# Use zep_cloud for hosted Zep, or graphiti_local for local Neo4j + Graphiti +GRAPH_BACKEND=zep_cloud + +# ===== Zep Cloud configuration ===== +# Required only when GRAPH_BACKEND=zep_cloud ZEP_API_KEY=your_zep_api_key_here -# ===== 加速 LLM 配置(可选)===== -# 注意如果不使用加速配置,env文件中就不要出现下面的配置项 +# ===== Local Graphiti + Neo4j configuration ===== +# Required only when GRAPH_BACKEND=graphiti_local +# Note: the local Graphiti backend stores all graphs in one Neo4j database +# and isolates each MiroFish graph by Graphiti `group_id`. +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=your_neo4j_password_here +NEO4J_DATABASE=neo4j +GRAPHITI_AUTO_INIT=true +GRAPHITI_TELEMETRY_ENABLED=false +GRAPHITI_MAX_COROUTINES=10 +GRAPHITI_SEARCH_RERANKER=rrf + +# Optional: override Graphiti model settings +# If omitted, Graphiti falls back to the main LLM settings above +GRAPHITI_LLM_API_KEY= +GRAPHITI_LLM_BASE_URL= +GRAPHITI_LLM_MODEL= +GRAPHITI_EMBEDDER_API_KEY= +GRAPHITI_EMBEDDER_BASE_URL= +GRAPHITI_EMBEDDER_MODEL=text-embedding-3-small +GRAPHITI_RERANKER_API_KEY= +GRAPHITI_RERANKER_BASE_URL= +GRAPHITI_RERANKER_MODEL= + +# ===== Accelerated LLM configuration (optional) ===== +# If you are not using accelerated configuration, do not include the fields below in your env file LLM_BOOST_API_KEY=your_api_key_here LLM_BOOST_BASE_URL=your_base_url_here -LLM_BOOST_MODEL_NAME=your_model_name_here \ No newline at end of file +LLM_BOOST_MODEL_NAME=your_model_name_here diff --git a/Dockerfile b/Dockerfile index e65646860..b635d4795 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,29 +1,30 @@ FROM python:3.11 -# 安装 Node.js (满足 >=18)及必要工具 +# Install Node.js (version 18 or later) and required tools RUN apt-get update \ && apt-get install -y --no-install-recommends nodejs npm \ && rm -rf /var/lib/apt/lists/* -# 从 uv 官方镜像复制 uv +# Copy `uv` from the official uv image COPY --from=ghcr.io/astral-sh/uv:0.9.26 /uv /uvx /bin/ WORKDIR /app -# 先复制依赖描述文件以利用缓存 +# Copy dependency manifests first to take advantage of layer caching COPY package.json package-lock.json ./ COPY frontend/package.json frontend/package-lock.json ./frontend/ COPY backend/pyproject.toml backend/uv.lock ./backend/ -# 安装依赖(Node + Python) +# Install dependencies (Node + Python) RUN npm ci \ && npm ci --prefix frontend \ - && cd backend && uv sync --frozen + && cd backend && uv sync --frozen \ + && uv pip install --python .venv/bin/python --no-deps graphiti-core==0.28.2 -# 复制项目源码 +# Copy the project source COPY . . EXPOSE 3000 5001 -# 同时启动前后端(开发模式) -CMD ["npm", "run", "dev"] \ No newline at end of file +# Start both frontend and backend services (development mode) +CMD ["npm", "run", "dev"] diff --git a/README-EN.md b/README-EN.md index 4b003a63f..fc58b26ef 100644 --- a/README-EN.md +++ b/README-EN.md @@ -4,7 +4,7 @@ 666ghj%2FMiroFish | Trendshift -简洁通用的群体智能引擎,预测万物 +A simple, universal swarm intelligence engine for predicting anything
A Simple and Universal Swarm Intelligence Engine, Predicting Anything @@ -20,7 +20,7 @@ [![X](https://img.shields.io/badge/X-Follow-000000?style=flat-square&logo=x&logoColor=white)](https://x.com/mirofish_ai) [![Instagram](https://img.shields.io/badge/Instagram-Follow-E4405F?style=flat-square&logo=instagram&logoColor=white)](https://www.instagram.com/mirofish_ai/) -[English](./README-EN.md) | [中文文档](./README.md) +[README](./README.md) | [English Copy](./README-EN.md) @@ -49,16 +49,16 @@ Welcome to visit our online demo environment and experience a prediction simulat
- - + + - - + + - - + +
Screenshot 1Screenshot 2Screenshot 1Screenshot 2
Screenshot 3Screenshot 4Screenshot 3Screenshot 4
Screenshot 5Screenshot 6Screenshot 5Screenshot 6
@@ -68,7 +68,7 @@ Welcome to visit our online demo environment and experience a prediction simulat ### 1. Wuhan University Public Opinion Simulation + MiroFish Project Introduction
-MiroFish Demo Video +MiroFish Demo Video Click the image to watch the complete demo video for prediction using BettaFish-generated "Wuhan University Public Opinion Report"
@@ -76,7 +76,7 @@ Click the image to watch the complete demo video for prediction using BettaFish- ### 2. Dream of the Red Chamber Lost Ending Simulation
-MiroFish Demo Video +MiroFish Demo Video Click the image to watch MiroFish's deep prediction of the lost ending based on hundreds of thousands of words from the first 80 chapters of "Dream of the Red Chamber"
@@ -122,9 +122,21 @@ LLM_API_KEY=your_api_key LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus -# Zep Cloud Configuration -# Free monthly quota is sufficient for simple usage: https://app.getzep.com/ +# Graph backend selection +# Use zep_cloud for hosted Zep, or graphiti_local for local Neo4j + Graphiti +GRAPH_BACKEND=zep_cloud + +# Zep Cloud configuration +# Required only when GRAPH_BACKEND=zep_cloud ZEP_API_KEY=your_zep_api_key + +# Local Graphiti + Neo4j configuration +# Required only when GRAPH_BACKEND=graphiti_local +# Note: the local Graphiti backend stores all graphs in one Neo4j database +# and isolates each MiroFish graph by Graphiti `group_id`. +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=your_neo4j_password ``` #### 2. Install Dependencies @@ -151,6 +163,17 @@ npm run setup:backend npm run dev ``` +If you use `GRAPH_BACKEND=graphiti_local`, start Neo4j too: + +```bash +docker compose up -d neo4j +``` + +The bundled `docker-compose.yml` uses `neo4j:5.26.22-enterprise` with +`NEO4J_ACCEPT_LICENSE_AGREEMENT=yes` as the safe local default. +The current local backend still keeps all graphs in the default Neo4j database +and maps each MiroFish `graph_id` directly to a Graphiti `group_id`. + **Service URLs:** - Frontend: `http://localhost:3000` - Backend API: `http://localhost:5001` @@ -175,11 +198,12 @@ docker compose up -d Reads `.env` from root directory by default, maps ports `3000 (frontend) / 5001 (backend)` > Mirror address for faster pulling is provided as comments in `docker-compose.yml`, replace if needed. +> When `GRAPH_BACKEND=graphiti_local`, the bundled compose stack starts a local Neo4j instance for Graphiti storage. The repo keeps the enterprise image as the default compose target because existing local stores may use the block format. ## 📬 Join the Conversation
-QQ Group +QQ Group
  @@ -200,4 +224,4 @@ MiroFish's simulation engine is powered by **[OASIS (Open Agent Social Interacti Star History Chart - \ No newline at end of file + diff --git a/README-SETUP.md b/README-SETUP.md new file mode 100644 index 000000000..fc7d55cfd --- /dev/null +++ b/README-SETUP.md @@ -0,0 +1,266 @@ +# MiroFish Setup Guide + +This file is a practical setup guide for the current state of this fork. +It is based on the main README, but focuses on the startup paths that are +working in this repository today. + +## What Changed + +MiroFish now supports two graph backends: + +- `zep_cloud`: hosted Zep Cloud +- `graphiti_local`: local Graphiti + Neo4j + +The local backend keeps all project graphs inside one Neo4j database and +isolates them with Graphiti `group_id`. + +## Recommended Paths + +Choose one of these: + +- Docker: run frontend, backend, and Neo4j with `docker compose` +- Local development: run frontend/backend locally and Neo4j in Docker + +## Prerequisites + +For Docker: + +- Docker Desktop or Docker Engine with Compose support + +For local development: + +- Node.js 18+ +- Python 3.11 or 3.12 +- `uv` +- Docker, if you want the local Neo4j service + +## Environment File + +Create the env file from the example: + +```bash +cp .env.example .env +``` + +## Option 1: Docker Startup + +This is the easiest way to run the full stack. + +### 1. Configure `.env` + +For the local Graphiti backend, a minimal working config looks like this: + +```env +GRAPH_BACKEND=graphiti_local + +LLM_API_KEY=your_llm_api_key +LLM_BASE_URL=https://api.openai.com/v1 +LLM_MODEL_NAME=gpt-4o-mini + +NEO4J_URI=bolt://neo4j:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=mirofish-local-password +NEO4J_DATABASE=neo4j +``` + +Notes: + +- `GRAPHITI_LLM_*`, `GRAPHITI_EMBEDDER_*`, and `GRAPHITI_RERANKER_*` are optional +- if they are omitted, the backend falls back to the main `LLM_*` settings + +If you want to keep using hosted Zep Cloud instead, use: + +```env +GRAPH_BACKEND=zep_cloud + +LLM_API_KEY=your_llm_api_key +LLM_BASE_URL=https://api.openai.com/v1 +LLM_MODEL_NAME=gpt-4o-mini + +ZEP_API_KEY=your_zep_api_key +``` + +### 2. Build and start + +```bash +docker compose up -d --build +``` + +### 3. Check status + +```bash +docker compose ps +docker compose logs -f +curl http://localhost:5001/health +``` + +When healthy, the backend should answer with a payload that includes: + +```json +{ + "status": "ok", + "service": "MiroFish Backend", + "graph_backend": "graphiti_local" +} +``` + +### 4. Open the app + +- Frontend: `http://localhost:3000` +- Backend: `http://localhost:5001` +- Neo4j Browser: `http://localhost:7474` + +### Useful Docker commands + +Stop the stack: + +```bash +docker compose down +``` + +Stop and remove volumes too: + +```bash +docker compose down -v +``` + +Rebuild after dependency or Dockerfile changes: + +```bash +docker compose up -d --build +``` + +Restart only Neo4j: + +```bash +docker compose up -d neo4j +``` + +## Option 2: Local Development Startup + +Use this when you want hot reload or easier debugging. + +### 1. Configure `.env` + +For local Graphiti, use: + +```env +GRAPH_BACKEND=graphiti_local + +LLM_API_KEY=your_llm_api_key +LLM_BASE_URL=https://api.openai.com/v1 +LLM_MODEL_NAME=gpt-4o-mini + +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=mirofish-local-password +NEO4J_DATABASE=neo4j +``` + +### 2. Install dependencies + +```bash +npm run setup:all +``` + +This does all of the following: + +- installs root Node dependencies +- installs frontend dependencies +- creates and syncs the backend `uv` environment +- installs `graphiti-core==0.28.2` separately into the backend venv + +### 3. Start Neo4j + +```bash +docker compose up -d neo4j +``` + +### 4. Start frontend and backend + +```bash +npm run dev +``` + +Or individually: + +```bash +npm run backend +npm run frontend +``` + +## Current Neo4j Note + +The local compose stack uses: + +- `neo4j:5.26.22-enterprise` + +This repo keeps the enterprise image as the default compose target because +existing local volumes may already use Neo4j block format. The application +logic itself is using a single Neo4j database plus Graphiti `group_id` +isolation, not one database per project. + +## Troubleshooting + +### Backend health is failing + +Check: + +- `LLM_API_KEY` is set +- `GRAPH_BACKEND` is correct +- if `GRAPH_BACKEND=graphiti_local`, `NEO4J_PASSWORD` is set +- Neo4j is running + +### Docker app builds but does not start correctly + +Watch logs: + +```bash +docker compose logs -f mirofish neo4j +``` + +### Neo4j starts but the backend cannot connect + +For Docker: + +- use `NEO4J_URI=bolt://neo4j:7687` + +For local development: + +- use `NEO4J_URI=bolt://localhost:7687` + +### You are on x86_64 and Docker build fails + +The app service currently pins: + +- `platform: linux/arm64` + +in `docker-compose.yml`. + +If your machine is not ARM64, remove or change that line before building. + +## Fast Start + +If you just want the shortest path for local Graphiti in Docker: + +```bash +cp .env.example .env +``` + +Put this in `.env`: + +```env +GRAPH_BACKEND=graphiti_local +LLM_API_KEY=your_llm_api_key +NEO4J_PASSWORD=mirofish-local-password +NEO4J_URI=bolt://neo4j:7687 +NEO4J_USER=neo4j +NEO4J_DATABASE=neo4j +``` + +Then run: + +```bash +docker compose up -d --build +curl http://localhost:5001/health +``` diff --git a/README.md b/README.md index 4f5cffe74..7013265eb 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ 666ghj%2FMiroFish | Trendshift -简洁通用的群体智能引擎,预测万物 +A simple, universal swarm intelligence engine for predicting anything
A Simple and Universal Swarm Intelligence Engine, Predicting Anything @@ -20,179 +20,203 @@ [![X](https://img.shields.io/badge/X-Follow-000000?style=flat-square&logo=x&logoColor=white)](https://x.com/mirofish_ai) [![Instagram](https://img.shields.io/badge/Instagram-Follow-E4405F?style=flat-square&logo=instagram&logoColor=white)](https://www.instagram.com/mirofish_ai/) -[English](./README-EN.md) | [中文文档](./README.md) +[README](./README.md) | [English Copy](./README-EN.md) -## ⚡ 项目概述 +## ⚡ Overview -**MiroFish** 是一款基于多智能体技术的新一代 AI 预测引擎。通过提取现实世界的种子信息(如突发新闻、政策草案、金融信号),自动构建出高保真的平行数字世界。在此空间内,成千上万个具备独立人格、长期记忆与行为逻辑的智能体进行自由交互与社会演化。你可透过「上帝视角」动态注入变量,精准推演未来走向——**让未来在数字沙盘中预演,助决策在百战模拟后胜出**。 +**MiroFish** is a next-generation AI prediction engine powered by multi-agent technology. By extracting seed information from the real world (such as breaking news, policy drafts, or financial signals), it automatically constructs a high-fidelity parallel digital world. Within this space, thousands of intelligent agents with independent personalities, long-term memory, and behavioral logic freely interact and undergo social evolution. You can inject variables dynamically from a "God's-eye view" to precisely deduce future trajectories — **rehearse the future in a digital sandbox, and win decisions after countless simulations**. -> 你只需:上传种子材料(数据分析报告或者有趣的小说故事),并用自然语言描述预测需求
-> MiroFish 将返回:一份详尽的预测报告,以及一个可深度交互的高保真数字世界 +> You only need to: upload seed materials (data analysis reports or interesting novel stories) and describe your prediction requirements in natural language
+> MiroFish will return: a detailed prediction report and a deeply interactive high-fidelity digital world -### 我们的愿景 +### Our Vision -MiroFish 致力于打造映射现实的群体智能镜像,通过捕捉个体互动引发的群体涌现,突破传统预测的局限: +MiroFish is dedicated to creating a swarm intelligence mirror that maps reality. By capturing the collective emergence triggered by individual interactions, we break through the limitations of traditional prediction: -- **于宏观**:我们是决策者的预演实验室,让政策与公关在零风险中试错 -- **于微观**:我们是个人用户的创意沙盘,无论是推演小说结局还是探索脑洞,皆可有趣、好玩、触手可及 +- **At the Macro Level**: We are a rehearsal laboratory for decision-makers, allowing policies and public relations to be tested at zero risk +- **At the Micro Level**: We are a creative sandbox for individual users, whether deducing novel endings or exploring imaginative scenarios, everything can be fun, playful, and accessible -从严肃预测到趣味仿真,我们让每一个如果都能看见结果,让预测万物成为可能。 +From serious predictions to playful simulations, we let every "what if" see its outcome, making it possible to predict anything. -## 🌐 在线体验 +## 🌐 Live Demo -欢迎访问在线 Demo 演示环境,体验我们为你准备的一次关于热点舆情事件的推演预测:[mirofish-live-demo](https://666ghj.github.io/mirofish-demo/) +Visit our online demo environment and experience a prediction simulation around a trending public-opinion event: [mirofish-live-demo](https://666ghj.github.io/mirofish-demo/) -## 📸 系统截图 +## 📸 Screenshots
- - + + - - + + - - + +
截图1截图2Screenshot 1Screenshot 2
截图3截图4Screenshot 3Screenshot 4
截图5截图6Screenshot 5Screenshot 6
-## 🎬 演示视频 +## 🎬 Demo Videos -### 1. 武汉大学舆情推演预测 + MiroFish项目讲解 +### 1. Wuhan University Public Opinion Simulation + MiroFish Project Introduction
-MiroFish Demo Video +MiroFish Demo Video -点击图片查看使用微舆BettaFish生成的《武大舆情报告》进行预测的完整演示视频 +Click the image to watch the complete demo video for prediction using the BettaFish-generated "Wuhan University Public Opinion Report."
-### 2. 《红楼梦》失传结局推演预测 +### 2. Dream of the Red Chamber Lost Ending Simulation
-MiroFish Demo Video +MiroFish Demo Video -点击图片查看基于《红楼梦》前80回数十万字,MiroFish深度预测失传结局 +Click the image to watch MiroFish predict the lost ending based on the first 80 chapters of *Dream of the Red Chamber*.
-> **金融方向推演预测**、**时政要闻推演预测**等示例陆续更新中... +> **Financial prediction**, **current-events forecasting**, and more examples are coming soon. -## 🔄 工作流程 +## 🔄 Workflow -1. **图谱构建**:现实种子提取 & 个体与群体记忆注入 & GraphRAG构建 -2. **环境搭建**:实体关系抽取 & 人设生成 & 环境配置Agent注入仿真参数 -3. **开始模拟**:双平台并行模拟 & 自动解析预测需求 & 动态更新时序记忆 -4. **报告生成**:ReportAgent拥有丰富的工具集与模拟后环境进行深度交互 -5. **深度互动**:与模拟世界中的任意一位进行对话 & 与ReportAgent进行对话 +1. **Graph Building**: Seed extraction, individual and collective memory injection, and GraphRAG construction +2. **Environment Setup**: Entity relationship extraction, persona generation, and agent configuration injection +3. **Simulation**: Dual-platform parallel simulation, automatic prediction-requirement parsing, and dynamic temporal memory updates +4. **Report Generation**: ReportAgent uses a rich toolset to interact deeply with the post-simulation environment +5. **Deep Interaction**: Chat with any agent in the simulated world and continue the conversation with ReportAgent -## 🚀 快速开始 +## 🚀 Quick Start -### 一、源码部署(推荐) +### Option 1: Source Deployment (Recommended) -#### 前置要求 +#### Prerequisites -| 工具 | 版本要求 | 说明 | 安装检查 | -|------|---------|------|---------| -| **Node.js** | 18+ | 前端运行环境,包含 npm | `node -v` | -| **Python** | ≥3.11, ≤3.12 | 后端运行环境 | `python --version` | -| **uv** | 最新版 | Python 包管理器 | `uv --version` | +| Tool | Version | Description | Check Installation | +|------|---------|-------------|-------------------| +| **Node.js** | 18+ | Frontend runtime, includes npm | `node -v` | +| **Python** | ≥3.11, ≤3.12 | Backend runtime | `python --version` | +| **uv** | Latest | Python package manager | `uv --version` | -#### 1. 配置环境变量 +#### 1. Configure Environment Variables ```bash -# 复制示例配置文件 +# Copy the example configuration file cp .env.example .env -# 编辑 .env 文件,填入必要的 API 密钥 +# Edit the .env file and fill in the required API keys ``` -**必需的环境变量:** +**Required Environment Variables:** ```env -# LLM API配置(支持 OpenAI SDK 格式的任意 LLM API) -# 推荐使用阿里百炼平台qwen-plus模型:https://bailian.console.aliyun.com/ -# 注意消耗较大,可先进行小于40轮的模拟尝试 +# LLM API configuration (supports any LLM API compatible with the OpenAI SDK format) +# Recommended: use the qwen-plus model on Alibaba Bailian: https://bailian.console.aliyun.com/ +# Note: usage can be expensive, so try simulations with fewer than 40 rounds first LLM_API_KEY=your_api_key LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus -# Zep Cloud 配置 -# 每月免费额度即可支撑简单使用:https://app.getzep.com/ +# Graph backend selection +# Use zep_cloud for hosted Zep, or graphiti_local for local Neo4j + Graphiti +GRAPH_BACKEND=zep_cloud + +# Zep Cloud configuration +# Required only when GRAPH_BACKEND=zep_cloud ZEP_API_KEY=your_zep_api_key + +# Local Graphiti + Neo4j configuration +# Required only when GRAPH_BACKEND=graphiti_local +# Note: the local Graphiti backend stores all graphs in one Neo4j database +# and isolates each MiroFish graph by Graphiti `group_id`. +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=your_neo4j_password ``` -#### 2. 安装依赖 +#### 2. Install Dependencies ```bash -# 一键安装所有依赖(根目录 + 前端 + 后端) +# One-click installation of all dependencies (root + frontend + backend) npm run setup:all ``` -或者分步安装: +Or install them step by step: ```bash -# 安装 Node 依赖(根目录 + 前端) +# Install Node dependencies (root + frontend) npm run setup -# 安装 Python 依赖(后端,自动创建虚拟环境) +# Install Python dependencies (backend, auto-creates virtual environment) npm run setup:backend ``` -#### 3. 启动服务 +#### 3. Start Services ```bash -# 同时启动前后端(在项目根目录执行) +# Start both frontend and backend (run from the project root) npm run dev ``` -**服务地址:** -- 前端:`http://localhost:3000` -- 后端 API:`http://localhost:5001` +If you use `GRAPH_BACKEND=graphiti_local`, start Neo4j too: + +```bash +docker compose up -d neo4j +``` + +The bundled `docker-compose.yml` uses `neo4j:5.26.22-enterprise` with +`NEO4J_ACCEPT_LICENSE_AGREEMENT=yes` as the safe default for local compatibility. +The current local backend still keeps all graphs in the default Neo4j database +and maps each MiroFish `graph_id` directly to a Graphiti `group_id`. + +**Service URLs:** +- Frontend: `http://localhost:3000` +- Backend API: `http://localhost:5001` -**单独启动:** +**Start Individually:** ```bash -npm run backend # 仅启动后端 -npm run frontend # 仅启动前端 +npm run backend # Start the backend only +npm run frontend # Start the frontend only ``` -### 二、Docker 部署 +### Option 2: Docker Deployment ```bash -# 1. 配置环境变量(同源码部署) +# 1. Configure environment variables (same as source deployment) cp .env.example .env -# 2. 拉取镜像并启动 +# 2. Pull the image and start docker compose up -d ``` -默认会读取根目录下的 `.env`,并映射端口 `3000(前端)/5001(后端)` +Docker reads `.env` from the project root by default and maps ports `3000 (frontend) / 5001 (backend)`. -> 在 `docker-compose.yml` 中已通过注释提供加速镜像地址,可按需替换 +> A mirror image URL is provided as a comment in `docker-compose.yml` if you need a faster pull source. +> When `GRAPH_BACKEND=graphiti_local`, the bundled compose stack starts a local Neo4j instance for Graphiti storage. The repo keeps the enterprise image as the default compose target because existing local stores may use the block format. -## 📬 更多交流 +## 📬 Join the Conversation
-QQ交流群 +QQ Group
  -MiroFish团队长期招募全职/实习,如果你对多Agent应用感兴趣,欢迎投递简历至:**mirofish@shanda.com** +The MiroFish team is recruiting for full-time and internship roles. If you are interested in multi-agent simulation and LLM applications, send your resume to: **mirofish@shanda.com** -## 📄 致谢 +## 📄 Acknowledgments -**MiroFish 得到了盛大集团的战略支持和孵化!** +**MiroFish has received strategic support and incubation from Shanda Group.** -MiroFish 的仿真引擎由 **[OASIS](https://github.com/camel-ai/oasis)** 驱动,我们衷心感谢 CAMEL-AI 团队的开源贡献! +MiroFish's simulation engine is powered by **[OASIS](https://github.com/camel-ai/oasis)**, and we sincerely thank the CAMEL-AI team for their open-source contributions. -## 📈 项目统计 +## 📈 Project Statistics diff --git a/backend/app/__init__.py b/backend/app/__init__.py index aba624bba..f172f5806 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -1,80 +1,87 @@ -""" -MiroFish Backend - Flask应用工厂 -""" +"""MiroFish backend Flask application factory.""" import os import warnings -# 抑制 multiprocessing resource_tracker 的警告(来自第三方库如 transformers) -# 需要在所有其他导入之前设置 + + warnings.filterwarnings("ignore", message=".*resource_tracker.*") from flask import Flask, request from flask_cors import CORS from .config import Config +from .services.graph_provider import initialize_selected_graph_backend from .utils.logger import setup_logger, get_logger def create_app(config_class=Config): - """Flask应用工厂函数""" + """Create app.""" app = Flask(__name__) app.config.from_object(config_class) - # 设置JSON编码:确保中文直接显示(而不是 \uXXXX 格式) - # Flask >= 2.3 使用 app.json.ensure_ascii,旧版本使用 JSON_AS_ASCII 配置 + + if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'): app.json.ensure_ascii = False - # 设置日志 + logger = setup_logger('mirofish') - # 只在 reloader 子进程中打印启动信息(避免 debug 模式下打印两次) + is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true' debug_mode = app.config.get('DEBUG', False) should_log_startup = not debug_mode or is_reloader_process if should_log_startup: logger.info("=" * 50) - logger.info("MiroFish Backend 启动中...") + logger.info("Starting MiroFish Backend...") logger.info("=" * 50) - # 启用CORS + CORS(app, resources={r"/api/*": {"origins": "*"}}) + + + initialize_selected_graph_backend() + if should_log_startup: + logger.info(f"Graph backend initialized: {Config.GRAPH_BACKEND}") + - # 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程) from .services.simulation_runner import SimulationRunner SimulationRunner.register_cleanup() if should_log_startup: - logger.info("已注册模拟进程清理函数") + logger.info("Registered simulation process cleanup") + - # 请求日志中间件 @app.before_request def log_request(): logger = get_logger('mirofish.request') - logger.debug(f"请求: {request.method} {request.path}") + logger.debug(f"Request: {request.method} {request.path}") if request.content_type and 'json' in request.content_type: - logger.debug(f"请求体: {request.get_json(silent=True)}") + logger.debug(f"Request body: {request.get_json(silent=True)}") @app.after_request def log_response(response): logger = get_logger('mirofish.request') - logger.debug(f"响应: {response.status_code}") + logger.debug(f"Response: {response.status_code}") return response - # 注册蓝图 + from .api import graph_bp, simulation_bp, report_bp app.register_blueprint(graph_bp, url_prefix='/api/graph') app.register_blueprint(simulation_bp, url_prefix='/api/simulation') app.register_blueprint(report_bp, url_prefix='/api/report') - # 健康检查 + @app.route('/health') def health(): - return {'status': 'ok', 'service': 'MiroFish Backend'} + return { + 'status': 'ok', + 'service': 'MiroFish Backend', + 'graph_backend': Config.GRAPH_BACKEND, + } if should_log_startup: - logger.info("MiroFish Backend 启动完成") + logger.info("MiroFish Backend started") return app - diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index ffda743a3..81c5d2c07 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -1,6 +1,4 @@ -""" -API路由模块 -""" +"""API route modules.""" from flask import Blueprint diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 12ff1ba2d..486f013a0 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -1,7 +1,4 @@ -""" -图谱相关API路由 -采用项目上下文机制,服务端持久化状态 -""" +"""Graph-related API routes.""" import os import traceback @@ -18,31 +15,29 @@ from ..models.task import TaskManager, TaskStatus from ..models.project import ProjectManager, ProjectStatus -# 获取日志器 + logger = get_logger('mirofish.api') def allowed_file(filename: str) -> bool: - """检查文件扩展名是否允许""" + """Check whether the file extension is allowed.""" if not filename or '.' not in filename: return False ext = os.path.splitext(filename)[1].lower().lstrip('.') return ext in Config.ALLOWED_EXTENSIONS -# ============== 项目管理接口 ============== + @graph_bp.route('/project/', methods=['GET']) def get_project(project_id: str): - """ - 获取项目详情 - """ + """Get project.""" project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {project_id}" + "error": f"Project not found: {project_id}" }), 404 return jsonify({ @@ -53,9 +48,7 @@ def get_project(project_id: str): @graph_bp.route('/project/list', methods=['GET']) def list_projects(): - """ - 列出所有项目 - """ + """List projects.""" limit = request.args.get('limit', 50, type=int) projects = ProjectManager.list_projects(limit=limit) @@ -68,37 +61,33 @@ def list_projects(): @graph_bp.route('/project/', methods=['DELETE']) def delete_project(project_id: str): - """ - 删除项目 - """ + """Delete project.""" success = ProjectManager.delete_project(project_id) if not success: return jsonify({ "success": False, - "error": f"项目不存在或删除失败: {project_id}" + "error": f"Project not found or could not be deleted: {project_id}" }), 404 return jsonify({ "success": True, - "message": f"项目已删除: {project_id}" + "message": f"Project deleted: {project_id}" }) @graph_bp.route('/project//reset', methods=['POST']) def reset_project(project_id: str): - """ - 重置项目状态(用于重新构建图谱) - """ + """Reset project.""" project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {project_id}" + "error": f"Project not found: {project_id}" }), 404 - # 重置到本体已生成状态 + if project.ontology: project.status = ProjectStatus.ONTOLOGY_GENERATED else: @@ -111,78 +100,53 @@ def reset_project(project_id: str): return jsonify({ "success": True, - "message": f"项目已重置: {project_id}", + "message": f"Project reset: {project_id}", "data": project.to_dict() }) -# ============== 接口1:上传文件并生成本体 ============== + @graph_bp.route('/ontology/generate', methods=['POST']) def generate_ontology(): - """ - 接口1:上传文件,分析生成本体定义 - - 请求方式:multipart/form-data - - 参数: - files: 上传的文件(PDF/MD/TXT),可多个 - simulation_requirement: 模拟需求描述(必填) - project_name: 项目名称(可选) - additional_context: 额外说明(可选) - - 返回: - { - "success": true, - "data": { - "project_id": "proj_xxxx", - "ontology": { - "entity_types": [...], - "edge_types": [...], - "analysis_summary": "..." - }, - "files": [...], - "total_text_length": 12345 - } - } - """ + """Generate ontology.""" try: - logger.info("=== 开始生成本体定义 ===") + logger.info("=== Starting ontology generation ===") + - # 获取参数 simulation_requirement = request.form.get('simulation_requirement', '') project_name = request.form.get('project_name', 'Unnamed Project') additional_context = request.form.get('additional_context', '') - logger.debug(f"项目名称: {project_name}") - logger.debug(f"模拟需求: {simulation_requirement[:100]}...") + logger.debug(f"Project name: {project_name}") + logger.debug(f"Simulation requirement: {simulation_requirement[:100]}...") if not simulation_requirement: return jsonify({ "success": False, - "error": "请提供模拟需求描述 (simulation_requirement)" + "error": "Please provide a simulation requirement description (simulation_requirement)." }), 400 - # 获取上传的文件 + uploaded_files = request.files.getlist('files') if not uploaded_files or all(not f.filename for f in uploaded_files): return jsonify({ "success": False, - "error": "请至少上传一个文档文件" + "error": "Please upload at least one document file." }), 400 - # 创建项目 + project = ProjectManager.create_project(name=project_name) project.simulation_requirement = simulation_requirement - logger.info(f"创建项目: {project.project_id}") + logger.info(f"Created project: {project.project_id}") + - # 保存文件并提取文本 document_texts = [] all_text = "" for file in uploaded_files: if file and file.filename and allowed_file(file.filename): - # 保存文件到项目目录 + file_info = ProjectManager.save_file_to_project( project.project_id, file, @@ -193,7 +157,7 @@ def generate_ontology(): "size": file_info["size"] }) - # 提取文本 + text = FileParser.extract_text(file_info["path"]) text = TextProcessor.preprocess_text(text) document_texts.append(text) @@ -203,16 +167,16 @@ def generate_ontology(): ProjectManager.delete_project(project.project_id) return jsonify({ "success": False, - "error": "没有成功处理任何文档,请检查文件格式" + "error": "No documents were processed successfully. Please check the file format." }), 400 - # 保存提取的文本 + project.total_text_length = len(all_text) ProjectManager.save_extracted_text(project.project_id, all_text) - logger.info(f"文本提取完成,共 {len(all_text)} 字符") + logger.info(f"Text extraction completed: {len(all_text)} characters") + - # 生成本体 - logger.info("调用 LLM 生成本体定义...") + logger.info("Calling the LLM to generate the ontology...") generator = OntologyGenerator() ontology = generator.generate( document_texts=document_texts, @@ -220,10 +184,10 @@ def generate_ontology(): additional_context=additional_context if additional_context else None ) - # 保存本体到项目 + entity_count = len(ontology.get("entity_types", [])) edge_count = len(ontology.get("edge_types", [])) - logger.info(f"本体生成完成: {entity_count} 个实体类型, {edge_count} 个关系类型") + logger.info(f"Ontology generation completed: {entity_count} entity types, {edge_count} edge types") project.ontology = { "entity_types": ontology.get("entity_types", []), @@ -232,7 +196,7 @@ def generate_ontology(): project.analysis_summary = ontology.get("analysis_summary", "") project.status = ProjectStatus.ONTOLOGY_GENERATED ProjectManager.save_project(project) - logger.info(f"=== 本体生成完成 === 项目ID: {project.project_id}") + logger.info(f"=== Ontology generation completed === project_id: {project.project_id}") return jsonify({ "success": True, @@ -254,140 +218,118 @@ def generate_ontology(): }), 500 -# ============== 接口2:构建图谱 ============== + @graph_bp.route('/build', methods=['POST']) def build_graph(): - """ - 接口2:根据project_id构建图谱 - - 请求(JSON): - { - "project_id": "proj_xxxx", // 必填,来自接口1 - "graph_name": "图谱名称", // 可选 - "chunk_size": 500, // 可选,默认500 - "chunk_overlap": 50 // 可选,默认50 - } - - 返回: - { - "success": true, - "data": { - "project_id": "proj_xxxx", - "task_id": "task_xxxx", - "message": "图谱构建任务已启动" - } - } - """ + """Build graph.""" try: - logger.info("=== 开始构建图谱 ===") + logger.info("=== Starting graph build ===") + - # 检查配置 - errors = [] - if not Config.ZEP_API_KEY: - errors.append("ZEP_API_KEY未配置") + errors = Config.validate_graph_backend() if errors: - logger.error(f"配置错误: {errors}") + logger.error(f"Configuration error: {errors}") return jsonify({ "success": False, - "error": "配置错误: " + "; ".join(errors) + "error": "Configuration error: " + "; ".join(errors) }), 500 - # 解析请求 + data = request.get_json() or {} project_id = data.get('project_id') - logger.debug(f"请求参数: project_id={project_id}") + logger.debug(f"Request parameters: project_id={project_id}") if not project_id: return jsonify({ "success": False, - "error": "请提供 project_id" + "error": "Please provide project_id." }), 400 - # 获取项目 + project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {project_id}" + "error": f"Project not found: {project_id}" }), 404 - # 检查项目状态 - force = data.get('force', False) # 强制重新构建 + + force = data.get('force', False) if project.status == ProjectStatus.CREATED: return jsonify({ "success": False, - "error": "项目尚未生成本体,请先调用 /ontology/generate" + "error": "The project does not have an ontology yet. Call /ontology/generate first." }), 400 if project.status == ProjectStatus.GRAPH_BUILDING and not force: return jsonify({ "success": False, - "error": "图谱正在构建中,请勿重复提交。如需强制重建,请添加 force: true", + "error": "A graph build is already in progress. To force a rebuild, set force: true.", "task_id": project.graph_build_task_id }), 400 - # 如果强制重建,重置状态 + if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]: project.status = ProjectStatus.ONTOLOGY_GENERATED project.graph_id = None project.graph_build_task_id = None project.error = None - # 获取配置 + graph_name = data.get('graph_name', project.name or 'MiroFish Graph') chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE) chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP) - # 更新项目配置 + project.chunk_size = chunk_size project.chunk_overlap = chunk_overlap - # 获取提取的文本 + text = ProjectManager.get_extracted_text(project_id) if not text: return jsonify({ "success": False, - "error": "未找到提取的文本内容" + "error": "Extracted text content was not found." }), 400 - # 获取本体 + ontology = project.ontology if not ontology: return jsonify({ "success": False, - "error": "未找到本体定义" + "error": "Ontology definition was not found." }), 400 - # 创建异步任务 + task_manager = TaskManager() - task_id = task_manager.create_task(f"构建图谱: {graph_name}") - logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}") + task_id = task_manager.create_task(f"Build graph: {graph_name}") + logger.info(f"Created graph build task: task_id={task_id}, project_id={project_id}") + - # 更新项目状态 project.status = ProjectStatus.GRAPH_BUILDING project.graph_build_task_id = task_id ProjectManager.save_project(project) - # 启动后台任务 + def build_task(): build_logger = get_logger('mirofish.build') try: - build_logger.info(f"[{task_id}] 开始构建图谱...") + build_logger.info(f"[{task_id}] Starting graph build...") task_manager.update_task( task_id, status=TaskStatus.PROCESSING, - message="初始化图谱构建服务..." + message="Initializing the graph build service..." ) - # 创建图谱构建服务 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) - # 分块 + builder = GraphBuilderService() + + task_manager.update_task( task_id, - message="文本分块中...", + message="Chunking text...", progress=5 ) chunks = TextProcessor.split_text( @@ -397,27 +339,27 @@ def build_task(): ) total_chunks = len(chunks) - # 创建图谱 + task_manager.update_task( task_id, - message="创建Zep图谱...", + message="Creating graph namespace...", progress=10 ) graph_id = builder.create_graph(name=graph_name) - # 更新项目的graph_id + project.graph_id = graph_id ProjectManager.save_project(project) - # 设置本体 + task_manager.update_task( task_id, - message="设置本体定义...", + message="Applying the ontology definition...", progress=15 ) builder.set_ontology(graph_id, ontology) - # 添加文本(progress_callback 签名是 (msg, progress_ratio)) + def add_progress_callback(msg, progress_ratio): progress = 15 + int(progress_ratio * 40) # 15% - 55% task_manager.update_task( @@ -428,7 +370,7 @@ def add_progress_callback(msg, progress_ratio): task_manager.update_task( task_id, - message=f"开始添加 {total_chunks} 个文本块...", + message=f"Adding {total_chunks} text chunks...", progress=15 ) @@ -439,10 +381,10 @@ def add_progress_callback(msg, progress_ratio): progress_callback=add_progress_callback ) - # 等待Zep处理完成(查询每个episode的processed状态) + task_manager.update_task( task_id, - message="等待Zep处理数据...", + message="Waiting for graph ingestion to complete...", progress=55 ) @@ -454,29 +396,29 @@ def wait_progress_callback(msg, progress_ratio): progress=progress ) - builder._wait_for_episodes(episode_uuids, wait_progress_callback) + builder._wait_for_episodes(graph_id, episode_uuids, wait_progress_callback) + - # 获取图谱数据 task_manager.update_task( task_id, - message="获取图谱数据...", + message="Fetching graph data...", progress=95 ) graph_data = builder.get_graph_data(graph_id) - # 更新项目状态 + project.status = ProjectStatus.GRAPH_COMPLETED ProjectManager.save_project(project) node_count = graph_data.get("node_count", 0) edge_count = graph_data.get("edge_count", 0) - build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}") + build_logger.info(f"[{task_id}] Graph build completed: graph_id={graph_id}, nodes={node_count}, edges={edge_count}") + - # 完成 task_manager.update_task( task_id, status=TaskStatus.COMPLETED, - message="图谱构建完成", + message="Graph build completed", progress=100, result={ "project_id": project_id, @@ -488,8 +430,8 @@ def wait_progress_callback(msg, progress_ratio): ) except Exception as e: - # 更新项目状态为失败 - build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}") + + build_logger.error(f"[{task_id}] Graph build failed: {str(e)}") build_logger.debug(traceback.format_exc()) project.status = ProjectStatus.FAILED @@ -499,11 +441,11 @@ def wait_progress_callback(msg, progress_ratio): task_manager.update_task( task_id, status=TaskStatus.FAILED, - message=f"构建失败: {str(e)}", + message=f"Build failed: {str(e)}", error=traceback.format_exc() ) - # 启动后台线程 + thread = threading.Thread(target=build_task, daemon=True) thread.start() @@ -512,7 +454,7 @@ def wait_progress_callback(msg, progress_ratio): "data": { "project_id": project_id, "task_id": task_id, - "message": "图谱构建任务已启动,请通过 /task/{task_id} 查询进度" + "message": "Graph build task started. Check progress via /task/{task_id}." } }) @@ -524,19 +466,17 @@ def wait_progress_callback(msg, progress_ratio): }), 500 -# ============== 任务查询接口 ============== + @graph_bp.route('/task/', methods=['GET']) def get_task(task_id: str): - """ - 查询任务状态 - """ + """Get task.""" task = TaskManager().get_task(task_id) if not task: return jsonify({ "success": False, - "error": f"任务不存在: {task_id}" + "error": f"Task not found: {task_id}" }), 404 return jsonify({ @@ -547,9 +487,7 @@ def get_task(task_id: str): @graph_bp.route('/tasks', methods=['GET']) def list_tasks(): - """ - 列出所有任务 - """ + """List tasks.""" tasks = TaskManager().list_tasks() return jsonify({ @@ -559,21 +497,19 @@ def list_tasks(): }) -# ============== 图谱数据接口 ============== + @graph_bp.route('/data/', methods=['GET']) def get_graph_data(graph_id: str): - """ - 获取图谱数据(节点和边) - """ + """Get graph data.""" try: - if not Config.ZEP_API_KEY: + if Config.validate_graph_backend(): return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "Graph backend is not configured correctly" }), 500 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + builder = GraphBuilderService() graph_data = builder.get_graph_data(graph_id) return jsonify({ @@ -591,22 +527,20 @@ def get_graph_data(graph_id: str): @graph_bp.route('/delete/', methods=['DELETE']) def delete_graph(graph_id: str): - """ - 删除Zep图谱 - """ + """Delete graph.""" try: - if not Config.ZEP_API_KEY: + if Config.validate_graph_backend(): return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "Graph backend is not configured correctly" }), 500 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + builder = GraphBuilderService() builder.delete_graph(graph_id) return jsonify({ "success": True, - "message": f"图谱已删除: {graph_id}" + "message": f"Graph deleted: {graph_id}" }) except Exception as e: diff --git a/backend/app/api/report.py b/backend/app/api/report.py index e05c73c39..fa7743839 100644 --- a/backend/app/api/report.py +++ b/backend/app/api/report.py @@ -1,7 +1,4 @@ -""" -Report API路由 -提供模拟报告生成、获取、对话等接口 -""" +"""Report API routes.""" import os import traceback @@ -19,33 +16,11 @@ logger = get_logger('mirofish.api.report') -# ============== 报告生成接口 ============== + @report_bp.route('/generate', methods=['POST']) def generate_report(): - """ - 生成模拟分析报告(异步任务) - - 这是一个耗时操作,接口会立即返回task_id, - 使用 GET /api/report/generate/status 查询进度 - - 请求(JSON): - { - "simulation_id": "sim_xxxx", // 必填,模拟ID - "force_regenerate": false // 可选,强制重新生成 - } - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "task_id": "task_xxxx", - "status": "generating", - "message": "报告生成任务已启动" - } - } - """ + """Generate report.""" try: data = request.get_json() or {} @@ -53,22 +28,22 @@ def generate_report(): if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "simulation_id is required" }), 400 force_regenerate = data.get('force_regenerate', False) - # 获取模拟信息 + manager = SimulationManager() state = manager.get_simulation(simulation_id) if not state: return jsonify({ "success": False, - "error": f"模拟不存在: {simulation_id}" + "error": f"Simulation not found: {simulation_id}" }), 404 - # 检查是否已有报告 + if not force_regenerate: existing_report = ReportManager.get_report_by_simulation(simulation_id) if existing_report and existing_report.status == ReportStatus.COMPLETED: @@ -78,38 +53,38 @@ def generate_report(): "simulation_id": simulation_id, "report_id": existing_report.report_id, "status": "completed", - "message": "报告已存在", + "message": "Report already exists", "already_generated": True } }) - # 获取项目信息 + project = ProjectManager.get_project(state.project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {state.project_id}" + "error": f"Project not found: {state.project_id}" }), 404 graph_id = state.graph_id or project.graph_id if not graph_id: return jsonify({ "success": False, - "error": "缺少图谱ID,请确保已构建图谱" + "error": "Missing graph ID. Please make sure the graph has been built." }), 400 simulation_requirement = project.simulation_requirement if not simulation_requirement: return jsonify({ "success": False, - "error": "缺少模拟需求描述" + "error": "Missing simulation requirement description" }), 400 - # 提前生成 report_id,以便立即返回给前端 + import uuid report_id = f"report_{uuid.uuid4().hex[:12]}" - # 创建异步任务 + task_manager = TaskManager() task_id = task_manager.create_task( task_type="report_generate", @@ -120,24 +95,24 @@ def generate_report(): } ) - # 定义后台任务 + def run_generate(): try: task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=0, - message="初始化Report Agent..." + message="Initializing Report Agent..." ) - # 创建Report Agent + agent = ReportAgent( graph_id=graph_id, simulation_id=simulation_id, simulation_requirement=simulation_requirement ) - # 进度回调 + def progress_callback(stage, progress, message): task_manager.update_task( task_id, @@ -145,13 +120,13 @@ def progress_callback(stage, progress, message): message=f"[{stage}] {message}" ) - # 生成报告(传入预先生成的 report_id) + report = agent.generate_report( progress_callback=progress_callback, report_id=report_id ) - # 保存报告 + ReportManager.save_report(report) if report.status == ReportStatus.COMPLETED: @@ -164,13 +139,13 @@ def progress_callback(stage, progress, message): } ) else: - task_manager.fail_task(task_id, report.error or "报告生成失败") + task_manager.fail_task(task_id, report.error or "Report generation failed") except Exception as e: - logger.error(f"报告生成失败: {str(e)}") + logger.error(f"Report generation failed: {str(e)}") task_manager.fail_task(task_id, str(e)) - # 启动后台线程 + thread = threading.Thread(target=run_generate, daemon=True) thread.start() @@ -181,13 +156,13 @@ def progress_callback(stage, progress, message): "report_id": report_id, "task_id": task_id, "status": "generating", - "message": "报告生成任务已启动,请通过 /api/report/generate/status 查询进度", + "message": "Report generation started. Check progress via /api/report/generate/status", "already_generated": False } }) except Exception as e: - logger.error(f"启动报告生成任务失败: {str(e)}") + logger.error(f"Failed to start report generation: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -197,33 +172,14 @@ def progress_callback(stage, progress, message): @report_bp.route('/generate/status', methods=['POST']) def get_generate_status(): - """ - 查询报告生成任务进度 - - 请求(JSON): - { - "task_id": "task_xxxx", // 可选,generate返回的task_id - "simulation_id": "sim_xxxx" // 可选,模拟ID - } - - 返回: - { - "success": true, - "data": { - "task_id": "task_xxxx", - "status": "processing|completed|failed", - "progress": 45, - "message": "..." - } - } - """ + """Get generate status.""" try: data = request.get_json() or {} task_id = data.get('task_id') simulation_id = data.get('simulation_id') - # 如果提供了simulation_id,先检查是否已有完成的报告 + if simulation_id: existing_report = ReportManager.get_report_by_simulation(simulation_id) if existing_report and existing_report.status == ReportStatus.COMPLETED: @@ -234,7 +190,7 @@ def get_generate_status(): "report_id": existing_report.report_id, "status": "completed", "progress": 100, - "message": "报告已生成", + "message": "Report already generated", "already_completed": True } }) @@ -242,7 +198,7 @@ def get_generate_status(): if not task_id: return jsonify({ "success": False, - "error": "请提供 task_id 或 simulation_id" + "error": "task_id or simulation_id is required" }), 400 task_manager = TaskManager() @@ -251,7 +207,7 @@ def get_generate_status(): if not task: return jsonify({ "success": False, - "error": f"任务不存在: {task_id}" + "error": f"Task not found: {task_id}" }), 404 return jsonify({ @@ -260,41 +216,25 @@ def get_generate_status(): }) except Exception as e: - logger.error(f"查询任务状态失败: {str(e)}") + logger.error(f"Failed to query task status: {str(e)}") return jsonify({ "success": False, "error": str(e) }), 500 -# ============== 报告获取接口 ============== + @report_bp.route('/', methods=['GET']) def get_report(report_id: str): - """ - 获取报告详情 - - 返回: - { - "success": true, - "data": { - "report_id": "report_xxxx", - "simulation_id": "sim_xxxx", - "status": "completed", - "outline": {...}, - "markdown_content": "...", - "created_at": "...", - "completed_at": "..." - } - } - """ + """Get report.""" try: report = ReportManager.get_report(report_id) if not report: return jsonify({ "success": False, - "error": f"报告不存在: {report_id}" + "error": f"Report not found: {report_id}" }), 404 return jsonify({ @@ -303,7 +243,7 @@ def get_report(report_id: str): }) except Exception as e: - logger.error(f"获取报告失败: {str(e)}") + logger.error(f"Failed to get report: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -313,25 +253,14 @@ def get_report(report_id: str): @report_bp.route('/by-simulation/', methods=['GET']) def get_report_by_simulation(simulation_id: str): - """ - 根据模拟ID获取报告 - - 返回: - { - "success": true, - "data": { - "report_id": "report_xxxx", - ... - } - } - """ + """Get report by simulation.""" try: report = ReportManager.get_report_by_simulation(simulation_id) if not report: return jsonify({ "success": False, - "error": f"该模拟暂无报告: {simulation_id}", + "error": f"No report found for simulation: {simulation_id}", "has_report": False }), 404 @@ -342,7 +271,7 @@ def get_report_by_simulation(simulation_id: str): }) except Exception as e: - logger.error(f"获取报告失败: {str(e)}") + logger.error(f"Failed to get report: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -352,20 +281,7 @@ def get_report_by_simulation(simulation_id: str): @report_bp.route('/list', methods=['GET']) def list_reports(): - """ - 列出所有报告 - - Query参数: - simulation_id: 按模拟ID过滤(可选) - limit: 返回数量限制(默认50) - - 返回: - { - "success": true, - "data": [...], - "count": 10 - } - """ + """List reports.""" try: simulation_id = request.args.get('simulation_id') limit = request.args.get('limit', 50, type=int) @@ -382,7 +298,7 @@ def list_reports(): }) except Exception as e: - logger.error(f"列出报告失败: {str(e)}") + logger.error(f"Failed to list reports: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -392,24 +308,20 @@ def list_reports(): @report_bp.route('//download', methods=['GET']) def download_report(report_id: str): - """ - 下载报告(Markdown格式) - - 返回Markdown文件 - """ + """Download report.""" try: report = ReportManager.get_report(report_id) if not report: return jsonify({ "success": False, - "error": f"报告不存在: {report_id}" + "error": f"Report not found: {report_id}" }), 404 md_path = ReportManager._get_report_markdown_path(report_id) if not os.path.exists(md_path): - # 如果MD文件不存在,生成一个临时文件 + import tempfile with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f: f.write(report.markdown_content) @@ -428,7 +340,7 @@ def download_report(report_id: str): ) except Exception as e: - logger.error(f"下载报告失败: {str(e)}") + logger.error(f"Failed to download report: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -438,23 +350,23 @@ def download_report(report_id: str): @report_bp.route('/', methods=['DELETE']) def delete_report(report_id: str): - """删除报告""" + """Delete report.""" try: success = ReportManager.delete_report(report_id) if not success: return jsonify({ "success": False, - "error": f"报告不存在: {report_id}" + "error": f"Report not found: {report_id}" }), 404 return jsonify({ "success": True, - "message": f"报告已删除: {report_id}" + "message": f"Report deleted: {report_id}" }) except Exception as e: - logger.error(f"删除报告失败: {str(e)}") + logger.error(f"Failed to delete report: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -462,35 +374,11 @@ def delete_report(report_id: str): }), 500 -# ============== Report Agent对话接口 ============== + @report_bp.route('/chat', methods=['POST']) def chat_with_report_agent(): - """ - 与Report Agent对话 - - Report Agent可以在对话中自主调用检索工具来回答问题 - - 请求(JSON): - { - "simulation_id": "sim_xxxx", // 必填,模拟ID - "message": "请解释一下舆情走向", // 必填,用户消息 - "chat_history": [ // 可选,对话历史 - {"role": "user", "content": "..."}, - {"role": "assistant", "content": "..."} - ] - } - - 返回: - { - "success": true, - "data": { - "response": "Agent回复...", - "tool_calls": [调用的工具列表], - "sources": [信息来源] - } - } - """ + """Chat With Report Agent.""" try: data = request.get_json() or {} @@ -501,42 +389,42 @@ def chat_with_report_agent(): if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "simulation_id is required" }), 400 if not message: return jsonify({ "success": False, - "error": "请提供 message" + "error": "message is required" }), 400 - # 获取模拟和项目信息 + manager = SimulationManager() state = manager.get_simulation(simulation_id) if not state: return jsonify({ "success": False, - "error": f"模拟不存在: {simulation_id}" + "error": f"Simulation not found: {simulation_id}" }), 404 project = ProjectManager.get_project(state.project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {state.project_id}" + "error": f"Project not found: {state.project_id}" }), 404 graph_id = state.graph_id or project.graph_id if not graph_id: return jsonify({ "success": False, - "error": "缺少图谱ID" + "error": "Missing graph ID" }), 400 simulation_requirement = project.simulation_requirement or "" - # 创建Agent并进行对话 + agent = ReportAgent( graph_id=graph_id, simulation_id=simulation_id, @@ -551,7 +439,7 @@ def chat_with_report_agent(): }) except Exception as e: - logger.error(f"对话失败: {str(e)}") + logger.error(f"Chat failed: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -559,33 +447,18 @@ def chat_with_report_agent(): }), 500 -# ============== 报告进度与分章节接口 ============== + @report_bp.route('//progress', methods=['GET']) def get_report_progress(report_id: str): - """ - 获取报告生成进度(实时) - - 返回: - { - "success": true, - "data": { - "status": "generating", - "progress": 45, - "message": "正在生成章节: 关键发现", - "current_section": "关键发现", - "completed_sections": ["执行摘要", "模拟背景"], - "updated_at": "2025-12-09T..." - } - } - """ + """Get report progress.""" try: progress = ReportManager.get_progress(report_id) if not progress: return jsonify({ "success": False, - "error": f"报告不存在或进度信息不可用: {report_id}" + "error": f"Report not found or progress unavailable: {report_id}" }), 404 return jsonify({ @@ -594,7 +467,7 @@ def get_report_progress(report_id: str): }) except Exception as e: - logger.error(f"获取报告进度失败: {str(e)}") + logger.error(f"Failed to get report progress: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -604,33 +477,11 @@ def get_report_progress(report_id: str): @report_bp.route('//sections', methods=['GET']) def get_report_sections(report_id: str): - """ - 获取已生成的章节列表(分章节输出) - - 前端可以轮询此接口获取已生成的章节内容,无需等待整个报告完成 - - 返回: - { - "success": true, - "data": { - "report_id": "report_xxxx", - "sections": [ - { - "filename": "section_01.md", - "section_index": 1, - "content": "## 执行摘要\\n\\n..." - }, - ... - ], - "total_sections": 3, - "is_complete": false - } - } - """ + """Get report sections.""" try: sections = ReportManager.get_generated_sections(report_id) - # 获取报告状态 + report = ReportManager.get_report(report_id) is_complete = report is not None and report.status == ReportStatus.COMPLETED @@ -645,7 +496,7 @@ def get_report_sections(report_id: str): }) except Exception as e: - logger.error(f"获取章节列表失败: {str(e)}") + logger.error(f"Failed to get section list: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -655,25 +506,14 @@ def get_report_sections(report_id: str): @report_bp.route('//section/', methods=['GET']) def get_single_section(report_id: str, section_index: int): - """ - 获取单个章节内容 - - 返回: - { - "success": true, - "data": { - "filename": "section_01.md", - "content": "## 执行摘要\\n\\n..." - } - } - """ + """Get single section.""" try: section_path = ReportManager._get_section_path(report_id, section_index) if not os.path.exists(section_path): return jsonify({ "success": False, - "error": f"章节不存在: section_{section_index:02d}.md" + "error": f"Section not found: section_{section_index:02d}.md" }), 404 with open(section_path, 'r', encoding='utf-8') as f: @@ -689,7 +529,7 @@ def get_single_section(report_id: str, section_index: int): }) except Exception as e: - logger.error(f"获取章节内容失败: {str(e)}") + logger.error(f"Failed to get section content: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -697,27 +537,11 @@ def get_single_section(report_id: str, section_index: int): }), 500 -# ============== 报告状态检查接口 ============== + @report_bp.route('/check/', methods=['GET']) def check_report_status(simulation_id: str): - """ - 检查模拟是否有报告,以及报告状态 - - 用于前端判断是否解锁Interview功能 - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "has_report": true, - "report_status": "completed", - "report_id": "report_xxxx", - "interview_unlocked": true - } - } - """ + """Check report status.""" try: report = ReportManager.get_report_by_simulation(simulation_id) @@ -725,7 +549,7 @@ def check_report_status(simulation_id: str): report_status = report.status.value if report else None report_id = report.report_id if report else None - # 只有报告完成后才解锁interview + interview_unlocked = has_report and report.status == ReportStatus.COMPLETED return jsonify({ @@ -740,7 +564,7 @@ def check_report_status(simulation_id: str): }) except Exception as e: - logger.error(f"检查报告状态失败: {str(e)}") + logger.error(f"Failed to check report status: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -748,48 +572,11 @@ def check_report_status(simulation_id: str): }), 500 -# ============== Agent 日志接口 ============== + @report_bp.route('//agent-log', methods=['GET']) def get_agent_log(report_id: str): - """ - 获取 Report Agent 的详细执行日志 - - 实时获取报告生成过程中的每一步动作,包括: - - 报告开始、规划开始/完成 - - 每个章节的开始、工具调用、LLM响应、完成 - - 报告完成或失败 - - Query参数: - from_line: 从第几行开始读取(可选,默认0,用于增量获取) - - 返回: - { - "success": true, - "data": { - "logs": [ - { - "timestamp": "2025-12-13T...", - "elapsed_seconds": 12.5, - "report_id": "report_xxxx", - "action": "tool_call", - "stage": "generating", - "section_title": "执行摘要", - "section_index": 1, - "details": { - "tool_name": "insight_forge", - "parameters": {...}, - ... - } - }, - ... - ], - "total_lines": 25, - "from_line": 0, - "has_more": false - } - } - """ + """Get agent log.""" try: from_line = request.args.get('from_line', 0, type=int) @@ -801,7 +588,7 @@ def get_agent_log(report_id: str): }) except Exception as e: - logger.error(f"获取Agent日志失败: {str(e)}") + logger.error(f"Failed to get agent log: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -811,18 +598,7 @@ def get_agent_log(report_id: str): @report_bp.route('//agent-log/stream', methods=['GET']) def stream_agent_log(report_id: str): - """ - 获取完整的 Agent 日志(一次性获取全部) - - 返回: - { - "success": true, - "data": { - "logs": [...], - "count": 25 - } - } - """ + """Stream Agent Log.""" try: logs = ReportManager.get_agent_log_stream(report_id) @@ -835,7 +611,7 @@ def stream_agent_log(report_id: str): }) except Exception as e: - logger.error(f"获取Agent日志失败: {str(e)}") + logger.error(f"Failed to get agent log: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -843,35 +619,11 @@ def stream_agent_log(report_id: str): }), 500 -# ============== 控制台日志接口 ============== + @report_bp.route('//console-log', methods=['GET']) def get_console_log(report_id: str): - """ - 获取 Report Agent 的控制台输出日志 - - 实时获取报告生成过程中的控制台输出(INFO、WARNING等), - 这与 agent-log 接口返回的结构化 JSON 日志不同, - 是纯文本格式的控制台风格日志。 - - Query参数: - from_line: 从第几行开始读取(可选,默认0,用于增量获取) - - 返回: - { - "success": true, - "data": { - "logs": [ - "[19:46:14] INFO: 搜索完成: 找到 15 条相关事实", - "[19:46:14] INFO: 图谱搜索: graph_id=xxx, query=...", - ... - ], - "total_lines": 100, - "from_line": 0, - "has_more": false - } - } - """ + """Get console log.""" try: from_line = request.args.get('from_line', 0, type=int) @@ -883,7 +635,7 @@ def get_console_log(report_id: str): }) except Exception as e: - logger.error(f"获取控制台日志失败: {str(e)}") + logger.error(f"Failed to get console log: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -893,18 +645,7 @@ def get_console_log(report_id: str): @report_bp.route('//console-log/stream', methods=['GET']) def stream_console_log(report_id: str): - """ - 获取完整的控制台日志(一次性获取全部) - - 返回: - { - "success": true, - "data": { - "logs": [...], - "count": 100 - } - } - """ + """Stream Console Log.""" try: logs = ReportManager.get_console_log_stream(report_id) @@ -917,7 +658,7 @@ def stream_console_log(report_id: str): }) except Exception as e: - logger.error(f"获取控制台日志失败: {str(e)}") + logger.error(f"Failed to get console log: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -925,20 +666,11 @@ def stream_console_log(report_id: str): }), 500 -# ============== 工具调用接口(供调试使用)============== + @report_bp.route('/tools/search', methods=['POST']) def search_graph_tool(): - """ - 图谱搜索工具接口(供调试使用) - - 请求(JSON): - { - "graph_id": "mirofish_xxxx", - "query": "搜索查询", - "limit": 10 - } - """ + """Search graph tool.""" try: data = request.get_json() or {} @@ -949,7 +681,7 @@ def search_graph_tool(): if not graph_id or not query: return jsonify({ "success": False, - "error": "请提供 graph_id 和 query" + "error": "graph_id and query are required" }), 400 from ..services.zep_tools import ZepToolsService @@ -967,7 +699,7 @@ def search_graph_tool(): }) except Exception as e: - logger.error(f"图谱搜索失败: {str(e)}") + logger.error(f"Graph search failed: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -977,14 +709,7 @@ def search_graph_tool(): @report_bp.route('/tools/statistics', methods=['POST']) def get_graph_statistics_tool(): - """ - 图谱统计工具接口(供调试使用) - - 请求(JSON): - { - "graph_id": "mirofish_xxxx" - } - """ + """Get graph statistics tool.""" try: data = request.get_json() or {} @@ -993,7 +718,7 @@ def get_graph_statistics_tool(): if not graph_id: return jsonify({ "success": False, - "error": "请提供 graph_id" + "error": "graph_id is required" }), 400 from ..services.zep_tools import ZepToolsService @@ -1007,7 +732,7 @@ def get_graph_statistics_tool(): }) except Exception as e: - logger.error(f"获取图谱统计失败: {str(e)}") + logger.error(f"Failed to get graph statistics: {str(e)}") return jsonify({ "success": False, "error": str(e), diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 3a0f68168..0a08d8b7c 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -1,7 +1,4 @@ -""" -模拟相关API路由 -Step2: Zep实体读取与过滤、OASIS模拟准备与运行(全程自动化) -""" +"""Simulation API routes.""" import os import traceback @@ -19,54 +16,41 @@ logger = get_logger('mirofish.api.simulation') -# Interview prompt 优化前缀 -# 添加此前缀可以避免Agent调用工具,直接用文本回复 -INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:" + + +INTERVIEW_PROMPT_PREFIX = ( + "Based on your persona and all of your past memories and actions, " + "reply directly in plain text without calling any tools: " +) def optimize_interview_prompt(prompt: str) -> str: - """ - 优化Interview提问,添加前缀避免Agent调用工具 - - Args: - prompt: 原始提问 - - Returns: - 优化后的提问 - """ + """Optimize interview prompt.""" if not prompt: return prompt - # 避免重复添加前缀 + if prompt.startswith(INTERVIEW_PROMPT_PREFIX): return prompt return f"{INTERVIEW_PROMPT_PREFIX}{prompt}" -# ============== 实体读取接口 ============== + @simulation_bp.route('/entities/', methods=['GET']) def get_graph_entities(graph_id: str): - """ - 获取图谱中的所有实体(已过滤) - - 只返回符合预定义实体类型的节点(Labels不只是Entity的节点) - - Query参数: - entity_types: 逗号分隔的实体类型列表(可选,用于进一步过滤) - enrich: 是否获取相关边信息(默认true) - """ + """Get graph entities.""" try: - if not Config.ZEP_API_KEY: + if Config.validate_graph_backend(): return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "Graph backend is not configured correctly" }), 500 entity_types_str = request.args.get('entity_types', '') entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None enrich = request.args.get('enrich', 'true').lower() == 'true' - logger.info(f"获取图谱实体: graph_id={graph_id}, entity_types={entity_types}, enrich={enrich}") + logger.info(f"Fetching graph entities: graph_id={graph_id}, entity_types={entity_types}, enrich={enrich}") reader = ZepEntityReader() result = reader.filter_defined_entities( @@ -81,7 +65,7 @@ def get_graph_entities(graph_id: str): }) except Exception as e: - logger.error(f"获取图谱实体失败: {str(e)}") + logger.error(f"Failed to fetch graph entities: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -91,12 +75,12 @@ def get_graph_entities(graph_id: str): @simulation_bp.route('/entities//', methods=['GET']) def get_entity_detail(graph_id: str, entity_uuid: str): - """获取单个实体的详细信息""" + """Get entity detail.""" try: - if not Config.ZEP_API_KEY: + if Config.validate_graph_backend(): return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "Graph backend is not configured correctly" }), 500 reader = ZepEntityReader() @@ -105,7 +89,7 @@ def get_entity_detail(graph_id: str, entity_uuid: str): if not entity: return jsonify({ "success": False, - "error": f"实体不存在: {entity_uuid}" + "error": f"Entity does not exist: {entity_uuid}" }), 404 return jsonify({ @@ -114,7 +98,7 @@ def get_entity_detail(graph_id: str, entity_uuid: str): }) except Exception as e: - logger.error(f"获取实体详情失败: {str(e)}") + logger.error(f"Failed to fetch entity details: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -124,12 +108,12 @@ def get_entity_detail(graph_id: str, entity_uuid: str): @simulation_bp.route('/entities//by-type/', methods=['GET']) def get_entities_by_type(graph_id: str, entity_type: str): - """获取指定类型的所有实体""" + """Get entities by type.""" try: - if not Config.ZEP_API_KEY: + if Config.validate_graph_backend(): return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "Graph backend is not configured correctly" }), 500 enrich = request.args.get('enrich', 'true').lower() == 'true' @@ -151,7 +135,7 @@ def get_entities_by_type(graph_id: str, entity_type: str): }) except Exception as e: - logger.error(f"获取实体失败: {str(e)}") + logger.error(f"Failed to fetch entities: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -159,37 +143,11 @@ def get_entities_by_type(graph_id: str, entity_type: str): }), 500 -# ============== 模拟管理接口 ============== + @simulation_bp.route('/create', methods=['POST']) def create_simulation(): - """ - 创建新的模拟 - - 注意:max_rounds等参数由LLM智能生成,无需手动设置 - - 请求(JSON): - { - "project_id": "proj_xxxx", // 必填 - "graph_id": "mirofish_xxxx", // 可选,如不提供则从project获取 - "enable_twitter": true, // 可选,默认true - "enable_reddit": true // 可选,默认true - } - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "project_id": "proj_xxxx", - "graph_id": "mirofish_xxxx", - "status": "created", - "enable_twitter": true, - "enable_reddit": true, - "created_at": "2025-12-01T10:00:00" - } - } - """ + """Create simulation.""" try: data = request.get_json() or {} @@ -197,21 +155,21 @@ def create_simulation(): if not project_id: return jsonify({ "success": False, - "error": "请提供 project_id" + "error": "Please provide project_id" }), 400 project = ProjectManager.get_project(project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {project_id}" + "error": f"Project does not exist: {project_id}" }), 404 graph_id = data.get('graph_id') or project.graph_id if not graph_id: return jsonify({ "success": False, - "error": "项目尚未构建图谱,请先调用 /api/graph/build" + "error": "The project graph has not been built yet. Call /api/graph/build first" }), 400 manager = SimulationManager() @@ -228,7 +186,7 @@ def create_simulation(): }) except Exception as e: - logger.error(f"创建模拟失败: {str(e)}") + logger.error(f"Failed to create simulation: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -237,31 +195,17 @@ def create_simulation(): def _check_simulation_prepared(simulation_id: str) -> tuple: - """ - 检查模拟是否已经准备完成 - - 检查条件: - 1. state.json 存在且 status 为 "ready" - 2. 必要文件存在:reddit_profiles.json, twitter_profiles.csv, simulation_config.json - - 注意:运行脚本(run_*.py)保留在 backend/scripts/ 目录,不再复制到模拟目录 - - Args: - simulation_id: 模拟ID - - Returns: - (is_prepared: bool, info: dict) - """ + """Check simulation prepared.""" import os from ..config import Config simulation_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) - # 检查目录是否存在 + if not os.path.exists(simulation_dir): - return False, {"reason": "模拟目录不存在"} + return False, {"reason": "Simulation directory does not exist"} + - # 必要文件列表(不包括脚本,脚本位于 backend/scripts/) required_files = [ "state.json", "simulation_config.json", @@ -269,7 +213,7 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: "twitter_profiles.csv" ] - # 检查文件是否存在 + existing_files = [] missing_files = [] for f in required_files: @@ -281,12 +225,12 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: if missing_files: return False, { - "reason": "缺少必要文件", + "reason": "Missing required files", "missing_files": missing_files, "existing_files": existing_files } - # 检查state.json中的状态 + state_file = os.path.join(simulation_dir, "state.json") try: import json @@ -296,20 +240,23 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: status = state_data.get("status", "") config_generated = state_data.get("config_generated", False) - # 详细日志 - logger.debug(f"检测模拟准备状态: {simulation_id}, status={status}, config_generated={config_generated}") - - # 如果 config_generated=True 且文件存在,认为准备完成 - # 以下状态都说明准备工作已完成: - # - ready: 准备完成,可以运行 - # - preparing: 如果 config_generated=True 说明已完成 - # - running: 正在运行,说明准备早就完成了 - # - completed: 运行完成,说明准备早就完成了 - # - stopped: 已停止,说明准备早就完成了 - # - failed: 运行失败(但准备是完成的) + + logger.debug( + f"Checking simulation prepared state: {simulation_id}, " + f"status={status}, config_generated={config_generated}" + ) + + + + + + + + + prepared_statuses = ["ready", "preparing", "running", "completed", "stopped", "failed"] if status in prepared_statuses and config_generated: - # 获取文件统计信息 + profiles_file = os.path.join(simulation_dir, "reddit_profiles.json") config_file = os.path.join(simulation_dir, "simulation_config.json") @@ -319,7 +266,7 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: profiles_data = json.load(f) profiles_count = len(profiles_data) if isinstance(profiles_data, list) else 0 - # 如果状态是preparing但文件已完成,自动更新状态为ready + if status == "preparing": try: state_data["status"] = "ready" @@ -327,12 +274,15 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: state_data["updated_at"] = datetime.now().isoformat() with open(state_file, 'w', encoding='utf-8') as f: json.dump(state_data, f, ensure_ascii=False, indent=2) - logger.info(f"自动更新模拟状态: {simulation_id} preparing -> ready") + logger.info(f"Auto-updated simulation status: {simulation_id} preparing -> ready") status = "ready" except Exception as e: - logger.warning(f"自动更新状态失败: {e}") + logger.warning(f"Failed to auto-update status: {e}") - logger.info(f"模拟 {simulation_id} 检测结果: 已准备完成 (status={status}, config_generated={config_generated})") + logger.info( + f"Simulation {simulation_id} check result: prepared " + f"(status={status}, config_generated={config_generated})" + ) return True, { "status": status, "entities_count": state_data.get("entities_count", 0), @@ -344,58 +294,26 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: "existing_files": existing_files } else: - logger.warning(f"模拟 {simulation_id} 检测结果: 未准备完成 (status={status}, config_generated={config_generated})") + logger.warning( + f"Simulation {simulation_id} check result: not prepared " + f"(status={status}, config_generated={config_generated})" + ) return False, { - "reason": f"状态不在已准备列表中或config_generated为false: status={status}, config_generated={config_generated}", + "reason": ( + f"Status is not in the prepared-state list or config_generated is false: " + f"status={status}, config_generated={config_generated}" + ), "status": status, "config_generated": config_generated } except Exception as e: - return False, {"reason": f"读取状态文件失败: {str(e)}"} + return False, {"reason": f"Failed to read state file: {str(e)}"} @simulation_bp.route('/prepare', methods=['POST']) def prepare_simulation(): - """ - 准备模拟环境(异步任务,LLM智能生成所有参数) - - 这是一个耗时操作,接口会立即返回task_id, - 使用 GET /api/simulation/prepare/status 查询进度 - - 特性: - - 自动检测已完成的准备工作,避免重复生成 - - 如果已准备完成,直接返回已有结果 - - 支持强制重新生成(force_regenerate=true) - - 步骤: - 1. 检查是否已有完成的准备工作 - 2. 从Zep图谱读取并过滤实体 - 3. 为每个实体生成OASIS Agent Profile(带重试机制) - 4. LLM智能生成模拟配置(带重试机制) - 5. 保存配置文件和预设脚本 - - 请求(JSON): - { - "simulation_id": "sim_xxxx", // 必填,模拟ID - "entity_types": ["Student", "PublicFigure"], // 可选,指定实体类型 - "use_llm_for_profiles": true, // 可选,是否用LLM生成人设 - "parallel_profile_count": 5, // 可选,并行生成人设数量,默认5 - "force_regenerate": false // 可选,强制重新生成,默认false - } - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "task_id": "task_xxxx", // 新任务时返回 - "status": "preparing|ready", - "message": "准备任务已启动|已有完成的准备工作", - "already_prepared": true|false // 是否已准备完成 - } - } - """ + """Prepare simulation.""" import threading import os from ..models.task import TaskManager, TaskStatus @@ -408,7 +326,7 @@ def prepare_simulation(): if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "Please provide simulation_id" }), 400 manager = SimulationManager() @@ -417,76 +335,82 @@ def prepare_simulation(): if not state: return jsonify({ "success": False, - "error": f"模拟不存在: {simulation_id}" + "error": f"Simulation does not exist: {simulation_id}" }), 404 - # 检查是否强制重新生成 + force_regenerate = data.get('force_regenerate', False) - logger.info(f"开始处理 /prepare 请求: simulation_id={simulation_id}, force_regenerate={force_regenerate}") + logger.info( + f"Starting /prepare request: simulation_id={simulation_id}, " + f"force_regenerate={force_regenerate}" + ) + - # 检查是否已经准备完成(避免重复生成) if not force_regenerate: - logger.debug(f"检查模拟 {simulation_id} 是否已准备完成...") + logger.debug(f"Checking whether simulation {simulation_id} is already prepared...") is_prepared, prepare_info = _check_simulation_prepared(simulation_id) - logger.debug(f"检查结果: is_prepared={is_prepared}, prepare_info={prepare_info}") + logger.debug(f"Check result: is_prepared={is_prepared}, prepare_info={prepare_info}") if is_prepared: - logger.info(f"模拟 {simulation_id} 已准备完成,跳过重复生成") + logger.info(f"Simulation {simulation_id} is already prepared; skipping duplicate generation") return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "status": "ready", - "message": "已有完成的准备工作,无需重复生成", + "message": "Preparation has already been completed; no regeneration is needed", "already_prepared": True, "prepare_info": prepare_info } }) else: - logger.info(f"模拟 {simulation_id} 未准备完成,将启动准备任务") + logger.info(f"Simulation {simulation_id} is not prepared; starting preparation task") + - # 从项目获取必要信息 project = ProjectManager.get_project(state.project_id) if not project: return jsonify({ "success": False, - "error": f"项目不存在: {state.project_id}" + "error": f"Project does not exist: {state.project_id}" }), 404 - # 获取模拟需求 + simulation_requirement = project.simulation_requirement or "" if not simulation_requirement: return jsonify({ "success": False, - "error": "项目缺少模拟需求描述 (simulation_requirement)" + "error": "Project is missing simulation_requirement" }), 400 - # 获取文档文本 + document_text = ProjectManager.get_extracted_text(state.project_id) or "" entity_types_list = data.get('entity_types') use_llm_for_profiles = data.get('use_llm_for_profiles', True) parallel_profile_count = data.get('parallel_profile_count', 5) - # ========== 同步获取实体数量(在后台任务启动前) ========== - # 这样前端在调用prepare后立即就能获取到预期Agent总数 + + try: - logger.info(f"同步获取实体数量: graph_id={state.graph_id}") + logger.info(f"Synchronously fetching entity count: graph_id={state.graph_id}") reader = ZepEntityReader() - # 快速读取实体(不需要边信息,只统计数量) + filtered_preview = reader.filter_defined_entities( graph_id=state.graph_id, defined_entity_types=entity_types_list, - enrich_with_edges=False # 不获取边信息,加快速度 + enrich_with_edges=False ) - # 保存实体数量到状态(供前端立即获取) + state.entities_count = filtered_preview.filtered_count state.entity_types = list(filtered_preview.entity_types) - logger.info(f"预期实体数量: {filtered_preview.filtered_count}, 类型: {filtered_preview.entity_types}") + logger.info( + f"Expected entity count: {filtered_preview.filtered_count}, " + f"types: {filtered_preview.entity_types}" + ) except Exception as e: - logger.warning(f"同步获取实体数量失败(将在后台任务中重试): {e}") - # 失败不影响后续流程,后台任务会重新获取 + logger.warning(f"Failed to fetch entity count synchronously (will retry in background task): {e}") + + - # 创建异步任务 task_manager = TaskManager() task_id = task_manager.create_task( task_type="simulation_prepare", @@ -496,26 +420,26 @@ def prepare_simulation(): } ) - # 更新模拟状态(包含预先获取的实体数量) + state.status = SimulationStatus.PREPARING manager._save_simulation_state(state) - # 定义后台任务 + def run_prepare(): try: task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=0, - message="开始准备模拟环境..." + message="Starting simulation environment preparation..." ) - # 准备模拟(带进度回调) - # 存储阶段进度详情 + + stage_details = {} def progress_callback(stage, progress, message, **kwargs): - # 计算总进度 + stage_weights = { "reading": (0, 20), # 0-20% "generating_profiles": (20, 70), # 20-70% @@ -526,18 +450,18 @@ def progress_callback(stage, progress, message, **kwargs): start, end = stage_weights.get(stage, (0, 100)) current_progress = int(start + (end - start) * progress / 100) - # 构建详细进度信息 + stage_names = { - "reading": "读取图谱实体", - "generating_profiles": "生成Agent人设", - "generating_config": "生成模拟配置", - "copying_scripts": "准备模拟脚本" + "reading": "Reading graph entities", + "generating_profiles": "Generating agent personas", + "generating_config": "Generating simulation config", + "copying_scripts": "Preparing simulation scripts" } stage_index = list(stage_weights.keys()).index(stage) + 1 if stage in stage_weights else 1 total_stages = len(stage_weights) - # 更新阶段详情 + stage_details[stage] = { "stage_name": stage_names.get(stage, stage), "stage_progress": progress, @@ -546,7 +470,7 @@ def progress_callback(stage, progress, message, **kwargs): "item_name": kwargs.get("item_name", "") } - # 构建详细进度信息 + detail = stage_details[stage] progress_detail_data = { "current_stage": stage, @@ -559,7 +483,7 @@ def progress_callback(stage, progress, message, **kwargs): "item_description": message } - # 构建简洁消息 + if detail["total"] > 0: detailed_message = ( f"[{stage_index}/{total_stages}] {stage_names.get(stage, stage)}: " @@ -585,24 +509,24 @@ def progress_callback(stage, progress, message, **kwargs): parallel_profile_count=parallel_profile_count ) - # 任务完成 + task_manager.complete_task( task_id, result=result_state.to_simple_dict() ) except Exception as e: - logger.error(f"准备模拟失败: {str(e)}") + logger.error(f"Simulation preparation failed: {str(e)}") task_manager.fail_task(task_id, str(e)) - # 更新模拟状态为失败 + state = manager.get_simulation(simulation_id) if state: state.status = SimulationStatus.FAILED state.error = str(e) manager._save_simulation_state(state) - # 启动后台线程 + thread = threading.Thread(target=run_prepare, daemon=True) thread.start() @@ -612,10 +536,10 @@ def progress_callback(stage, progress, message, **kwargs): "simulation_id": simulation_id, "task_id": task_id, "status": "preparing", - "message": "准备任务已启动,请通过 /api/simulation/prepare/status 查询进度", + "message": "Preparation task started. Query progress via /api/simulation/prepare/status", "already_prepared": False, - "expected_entities_count": state.entities_count, # 预期的Agent总数 - "entity_types": state.entity_types # 实体类型列表 + "expected_entities_count": state.entities_count, + "entity_types": state.entity_types } }) @@ -626,7 +550,7 @@ def progress_callback(stage, progress, message, **kwargs): }), 404 except Exception as e: - logger.error(f"启动准备任务失败: {str(e)}") + logger.error(f"Failed to start preparation task: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -636,32 +560,7 @@ def progress_callback(stage, progress, message, **kwargs): @simulation_bp.route('/prepare/status', methods=['POST']) def get_prepare_status(): - """ - 查询准备任务进度 - - 支持两种查询方式: - 1. 通过task_id查询正在进行的任务进度 - 2. 通过simulation_id检查是否已有完成的准备工作 - - 请求(JSON): - { - "task_id": "task_xxxx", // 可选,prepare返回的task_id - "simulation_id": "sim_xxxx" // 可选,模拟ID(用于检查已完成的准备) - } - - 返回: - { - "success": true, - "data": { - "task_id": "task_xxxx", - "status": "processing|completed|ready", - "progress": 45, - "message": "...", - "already_prepared": true|false, // 是否已有完成的准备 - "prepare_info": {...} // 已准备完成时的详细信息 - } - } - """ + """Get prepare status.""" from ..models.task import TaskManager try: @@ -670,7 +569,7 @@ def get_prepare_status(): task_id = data.get('task_id') simulation_id = data.get('simulation_id') - # 如果提供了simulation_id,先检查是否已准备完成 + if simulation_id: is_prepared, prepare_info = _check_simulation_prepared(simulation_id) if is_prepared: @@ -680,36 +579,36 @@ def get_prepare_status(): "simulation_id": simulation_id, "status": "ready", "progress": 100, - "message": "已有完成的准备工作", + "message": "Preparation has already been completed", "already_prepared": True, "prepare_info": prepare_info } }) - # 如果没有task_id,返回错误 + if not task_id: if simulation_id: - # 有simulation_id但未准备完成 + return jsonify({ "success": True, "data": { "simulation_id": simulation_id, "status": "not_started", "progress": 0, - "message": "尚未开始准备,请调用 /api/simulation/prepare 开始", + "message": "Preparation has not started yet. Call /api/simulation/prepare to begin", "already_prepared": False } }) return jsonify({ "success": False, - "error": "请提供 task_id 或 simulation_id" + "error": "Please provide task_id or simulation_id" }), 400 task_manager = TaskManager() task = task_manager.get_task(task_id) if not task: - # 任务不存在,但如果有simulation_id,检查是否已准备完成 + if simulation_id: is_prepared, prepare_info = _check_simulation_prepared(simulation_id) if is_prepared: @@ -720,7 +619,7 @@ def get_prepare_status(): "task_id": task_id, "status": "ready", "progress": 100, - "message": "任务已完成(准备工作已存在)", + "message": "Task completed (preparation already existed)", "already_prepared": True, "prepare_info": prepare_info } @@ -728,7 +627,7 @@ def get_prepare_status(): return jsonify({ "success": False, - "error": f"任务不存在: {task_id}" + "error": f"Task does not exist: {task_id}" }), 404 task_dict = task.to_dict() @@ -740,7 +639,7 @@ def get_prepare_status(): }) except Exception as e: - logger.error(f"查询任务状态失败: {str(e)}") + logger.error(f"Failed to query task status: {str(e)}") return jsonify({ "success": False, "error": str(e) @@ -749,7 +648,7 @@ def get_prepare_status(): @simulation_bp.route('/', methods=['GET']) def get_simulation(simulation_id: str): - """获取模拟状态""" + """Get simulation.""" try: manager = SimulationManager() state = manager.get_simulation(simulation_id) @@ -757,12 +656,12 @@ def get_simulation(simulation_id: str): if not state: return jsonify({ "success": False, - "error": f"模拟不存在: {simulation_id}" + "error": f"Simulation does not exist: {simulation_id}" }), 404 result = state.to_dict() - # 如果模拟已准备好,附加运行说明 + if state.status == SimulationStatus.READY: result["run_instructions"] = manager.get_run_instructions(simulation_id) @@ -772,7 +671,7 @@ def get_simulation(simulation_id: str): }) except Exception as e: - logger.error(f"获取模拟状态失败: {str(e)}") + logger.error(f"Failed to fetch simulation status: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -782,12 +681,7 @@ def get_simulation(simulation_id: str): @simulation_bp.route('/list', methods=['GET']) def list_simulations(): - """ - 列出所有模拟 - - Query参数: - project_id: 按项目ID过滤(可选) - """ + """List simulations.""" try: project_id = request.args.get('project_id') @@ -801,7 +695,7 @@ def list_simulations(): }) except Exception as e: - logger.error(f"列出模拟失败: {str(e)}") + logger.error(f"Failed to list simulations: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -810,23 +704,12 @@ def list_simulations(): def _get_report_id_for_simulation(simulation_id: str) -> str: - """ - 获取 simulation 对应的最新 report_id - - 遍历 reports 目录,找出 simulation_id 匹配的 report, - 如果有多个则返回最新的(按 created_at 排序) - - Args: - simulation_id: 模拟ID - - Returns: - report_id 或 None - """ + """Get report id for simulation.""" import json from datetime import datetime - # reports 目录路径:backend/uploads/reports - # __file__ 是 app/api/simulation.py,需要向上两级到 backend/ + + reports_dir = os.path.join(os.path.dirname(__file__), '../../uploads/reports') if not os.path.exists(reports_dir): return None @@ -859,68 +742,36 @@ def _get_report_id_for_simulation(simulation_id: str) -> str: if not matching_reports: return None - # 按创建时间倒序排序,返回最新的 + matching_reports.sort(key=lambda x: x.get("created_at", ""), reverse=True) return matching_reports[0].get("report_id") except Exception as e: - logger.warning(f"查找 simulation {simulation_id} 的 report 失败: {e}") + logger.warning(f"Failed to find report for simulation {simulation_id}: {e}") return None @simulation_bp.route('/history', methods=['GET']) def get_simulation_history(): - """ - 获取历史模拟列表(带项目详情) - - 用于首页历史项目展示,返回包含项目名称、描述等丰富信息的模拟列表 - - Query参数: - limit: 返回数量限制(默认20) - - 返回: - { - "success": true, - "data": [ - { - "simulation_id": "sim_xxxx", - "project_id": "proj_xxxx", - "project_name": "武大舆情分析", - "simulation_requirement": "如果武汉大学发布...", - "status": "completed", - "entities_count": 68, - "profiles_count": 68, - "entity_types": ["Student", "Professor", ...], - "created_at": "2024-12-10", - "updated_at": "2024-12-10", - "total_rounds": 120, - "current_round": 120, - "report_id": "report_xxxx", - "version": "v1.0.2" - }, - ... - ], - "count": 7 - } - """ + """Get simulation history.""" try: limit = request.args.get('limit', 20, type=int) manager = SimulationManager() simulations = manager.list_simulations()[:limit] - # 增强模拟数据,只从 Simulation 文件读取 + enriched_simulations = [] for sim in simulations: sim_dict = sim.to_dict() - # 获取模拟配置信息(从 simulation_config.json 读取 simulation_requirement) + config = manager.get_simulation_config(sim.simulation_id) if config: sim_dict["simulation_requirement"] = config.get("simulation_requirement", "") time_config = config.get("time_config", {}) sim_dict["total_simulation_hours"] = time_config.get("total_simulation_hours", 0) - # 推荐轮数(后备值) + recommended_rounds = int( time_config.get("total_simulation_hours", 0) * 60 / max(time_config.get("minutes_per_round", 60), 1) @@ -930,35 +781,35 @@ def get_simulation_history(): sim_dict["total_simulation_hours"] = 0 recommended_rounds = 0 - # 获取运行状态(从 run_state.json 读取用户设置的实际轮数) + run_state = SimulationRunner.get_run_state(sim.simulation_id) if run_state: sim_dict["current_round"] = run_state.current_round sim_dict["runner_status"] = run_state.runner_status.value - # 使用用户设置的 total_rounds,若无则使用推荐轮数 + sim_dict["total_rounds"] = run_state.total_rounds if run_state.total_rounds > 0 else recommended_rounds else: sim_dict["current_round"] = 0 sim_dict["runner_status"] = "idle" sim_dict["total_rounds"] = recommended_rounds - # 获取关联项目的文件列表(最多3个) + project = ProjectManager.get_project(sim.project_id) if project and hasattr(project, 'files') and project.files: sim_dict["files"] = [ - {"filename": f.get("filename", "未知文件")} + {"filename": f.get("filename", "Unknown file")} for f in project.files[:3] ] else: sim_dict["files"] = [] - # 获取关联的 report_id(查找该 simulation 最新的 report) + sim_dict["report_id"] = _get_report_id_for_simulation(sim.simulation_id) - # 添加版本号 + sim_dict["version"] = "v1.0.2" - # 格式化日期 + try: created_date = sim_dict.get("created_at", "")[:10] sim_dict["created_date"] = created_date @@ -974,7 +825,7 @@ def get_simulation_history(): }) except Exception as e: - logger.error(f"获取历史模拟失败: {str(e)}") + logger.error(f"Failed to fetch simulation history: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -984,12 +835,7 @@ def get_simulation_history(): @simulation_bp.route('//profiles', methods=['GET']) def get_simulation_profiles(simulation_id: str): - """ - 获取模拟的Agent Profile - - Query参数: - platform: 平台类型(reddit/twitter,默认reddit) - """ + """Get simulation profiles.""" try: platform = request.args.get('platform', 'reddit') @@ -1012,7 +858,7 @@ def get_simulation_profiles(simulation_id: str): }), 404 except Exception as e: - logger.error(f"获取Profile失败: {str(e)}") + logger.error(f"Failed to fetch profiles: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1022,32 +868,7 @@ def get_simulation_profiles(simulation_id: str): @simulation_bp.route('//profiles/realtime', methods=['GET']) def get_simulation_profiles_realtime(simulation_id: str): - """ - 实时获取模拟的Agent Profile(用于在生成过程中实时查看进度) - - 与 /profiles 接口的区别: - - 直接读取文件,不经过 SimulationManager - - 适用于生成过程中的实时查看 - - 返回额外的元数据(如文件修改时间、是否正在生成等) - - Query参数: - platform: 平台类型(reddit/twitter,默认reddit) - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "platform": "reddit", - "count": 15, - "total_expected": 93, // 预期总数(如果有) - "is_generating": true, // 是否正在生成 - "file_exists": true, - "file_modified_at": "2025-12-04T18:20:00", - "profiles": [...] - } - } - """ + """Get simulation profiles realtime.""" import json import csv from datetime import datetime @@ -1055,28 +876,28 @@ def get_simulation_profiles_realtime(simulation_id: str): try: platform = request.args.get('platform', 'reddit') - # 获取模拟目录 + sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) if not os.path.exists(sim_dir): return jsonify({ "success": False, - "error": f"模拟不存在: {simulation_id}" + "error": f"Simulation does not exist: {simulation_id}" }), 404 - # 确定文件路径 + if platform == "reddit": profiles_file = os.path.join(sim_dir, "reddit_profiles.json") else: profiles_file = os.path.join(sim_dir, "twitter_profiles.csv") - # 检查文件是否存在 + file_exists = os.path.exists(profiles_file) profiles = [] file_modified_at = None if file_exists: - # 获取文件修改时间 + file_stat = os.stat(profiles_file) file_modified_at = datetime.fromtimestamp(file_stat.st_mtime).isoformat() @@ -1089,10 +910,10 @@ def get_simulation_profiles_realtime(simulation_id: str): reader = csv.DictReader(f) profiles = list(reader) except (json.JSONDecodeError, Exception) as e: - logger.warning(f"读取 profiles 文件失败(可能正在写入中): {e}") + logger.warning(f"Failed to read profiles file (it may still be being written): {e}") profiles = [] - # 检查是否正在生成(通过 state.json 判断) + is_generating = False total_expected = None @@ -1122,7 +943,7 @@ def get_simulation_profiles_realtime(simulation_id: str): }) except Exception as e: - logger.error(f"实时获取Profile失败: {str(e)}") + logger.error(f"Failed to fetch realtime profiles: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1132,51 +953,30 @@ def get_simulation_profiles_realtime(simulation_id: str): @simulation_bp.route('//config/realtime', methods=['GET']) def get_simulation_config_realtime(simulation_id: str): - """ - 实时获取模拟配置(用于在生成过程中实时查看进度) - - 与 /config 接口的区别: - - 直接读取文件,不经过 SimulationManager - - 适用于生成过程中的实时查看 - - 返回额外的元数据(如文件修改时间、是否正在生成等) - - 即使配置还没生成完也能返回部分信息 - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "file_exists": true, - "file_modified_at": "2025-12-04T18:20:00", - "is_generating": true, // 是否正在生成 - "generation_stage": "generating_config", // 当前生成阶段 - "config": {...} // 配置内容(如果存在) - } - } - """ + """Get simulation config realtime.""" import json from datetime import datetime try: - # 获取模拟目录 + sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) if not os.path.exists(sim_dir): return jsonify({ "success": False, - "error": f"模拟不存在: {simulation_id}" + "error": f"Simulation does not exist: {simulation_id}" }), 404 - # 配置文件路径 + config_file = os.path.join(sim_dir, "simulation_config.json") - # 检查文件是否存在 + file_exists = os.path.exists(config_file) config = None file_modified_at = None if file_exists: - # 获取文件修改时间 + file_stat = os.stat(config_file) file_modified_at = datetime.fromtimestamp(file_stat.st_mtime).isoformat() @@ -1184,10 +984,10 @@ def get_simulation_config_realtime(simulation_id: str): with open(config_file, 'r', encoding='utf-8') as f: config = json.load(f) except (json.JSONDecodeError, Exception) as e: - logger.warning(f"读取 config 文件失败(可能正在写入中): {e}") + logger.warning(f"Failed to read config file (it may still be being written): {e}") config = None - # 检查是否正在生成(通过 state.json 判断) + is_generating = False generation_stage = None config_generated = False @@ -1201,7 +1001,7 @@ def get_simulation_config_realtime(simulation_id: str): is_generating = status == "preparing" config_generated = state_data.get("config_generated", False) - # 判断当前阶段 + if is_generating: if state_data.get("profiles_generated", False): generation_stage = "generating_config" @@ -1212,7 +1012,7 @@ def get_simulation_config_realtime(simulation_id: str): except Exception: pass - # 构建返回数据 + response_data = { "simulation_id": simulation_id, "file_exists": file_exists, @@ -1223,7 +1023,7 @@ def get_simulation_config_realtime(simulation_id: str): "config": config } - # 如果配置存在,提取一些关键统计信息 + if config: response_data["summary"] = { "total_agents": len(config.get("agent_configs", [])), @@ -1242,7 +1042,7 @@ def get_simulation_config_realtime(simulation_id: str): }) except Exception as e: - logger.error(f"实时获取Config失败: {str(e)}") + logger.error(f"Failed to fetch realtime config: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1252,16 +1052,7 @@ def get_simulation_config_realtime(simulation_id: str): @simulation_bp.route('//config', methods=['GET']) def get_simulation_config(simulation_id: str): - """ - 获取模拟配置(LLM智能生成的完整配置) - - 返回包含: - - time_config: 时间配置(模拟时长、轮次、高峰/低谷时段) - - agent_configs: 每个Agent的活动配置(活跃度、发言频率、立场等) - - event_config: 事件配置(初始帖子、热点话题) - - platform_configs: 平台配置 - - generation_reasoning: LLM的配置推理说明 - """ + """Get simulation config.""" try: manager = SimulationManager() config = manager.get_simulation_config(simulation_id) @@ -1269,7 +1060,7 @@ def get_simulation_config(simulation_id: str): if not config: return jsonify({ "success": False, - "error": f"模拟配置不存在,请先调用 /prepare 接口" + "error": "Simulation config does not exist. Call /prepare first" }), 404 return jsonify({ @@ -1278,7 +1069,7 @@ def get_simulation_config(simulation_id: str): }) except Exception as e: - logger.error(f"获取配置失败: {str(e)}") + logger.error(f"Failed to fetch config: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1288,7 +1079,7 @@ def get_simulation_config(simulation_id: str): @simulation_bp.route('//config/download', methods=['GET']) def download_simulation_config(simulation_id: str): - """下载模拟配置文件""" + """Download simulation config.""" try: manager = SimulationManager() sim_dir = manager._get_simulation_dir(simulation_id) @@ -1297,7 +1088,7 @@ def download_simulation_config(simulation_id: str): if not os.path.exists(config_path): return jsonify({ "success": False, - "error": "配置文件不存在,请先调用 /prepare 接口" + "error": "Config file does not exist. Call /prepare first" }), 404 return send_file( @@ -1307,7 +1098,7 @@ def download_simulation_config(simulation_id: str): ) except Exception as e: - logger.error(f"下载配置失败: {str(e)}") + logger.error(f"Failed to download config: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1317,20 +1108,12 @@ def download_simulation_config(simulation_id: str): @simulation_bp.route('/script//download', methods=['GET']) def download_simulation_script(script_name: str): - """ - 下载模拟运行脚本文件(通用脚本,位于 backend/scripts/) - - script_name可选值: - - run_twitter_simulation.py - - run_reddit_simulation.py - - run_parallel_simulation.py - - action_logger.py - """ + """Download simulation script.""" try: - # 脚本位于 backend/scripts/ 目录 + scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts')) - # 验证脚本名称 + allowed_scripts = [ "run_twitter_simulation.py", "run_reddit_simulation.py", @@ -1341,7 +1124,7 @@ def download_simulation_script(script_name: str): if script_name not in allowed_scripts: return jsonify({ "success": False, - "error": f"未知脚本: {script_name},可选: {allowed_scripts}" + "error": f"Unknown script: {script_name}. Allowed: {allowed_scripts}" }), 400 script_path = os.path.join(scripts_dir, script_name) @@ -1349,7 +1132,7 @@ def download_simulation_script(script_name: str): if not os.path.exists(script_path): return jsonify({ "success": False, - "error": f"脚本文件不存在: {script_name}" + "error": f"Script file does not exist: {script_name}" }), 404 return send_file( @@ -1359,7 +1142,7 @@ def download_simulation_script(script_name: str): ) except Exception as e: - logger.error(f"下载脚本失败: {str(e)}") + logger.error(f"Failed to download script: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1367,21 +1150,11 @@ def download_simulation_script(script_name: str): }), 500 -# ============== Profile生成接口(独立使用) ============== + @simulation_bp.route('/generate-profiles', methods=['POST']) def generate_profiles(): - """ - 直接从图谱生成OASIS Agent Profile(不创建模拟) - - 请求(JSON): - { - "graph_id": "mirofish_xxxx", // 必填 - "entity_types": ["Student"], // 可选 - "use_llm": true, // 可选 - "platform": "reddit" // 可选 - } - """ + """Generate profiles.""" try: data = request.get_json() or {} @@ -1389,7 +1162,7 @@ def generate_profiles(): if not graph_id: return jsonify({ "success": False, - "error": "请提供 graph_id" + "error": "Please provide graph_id" }), 400 entity_types = data.get('entity_types') @@ -1406,7 +1179,7 @@ def generate_profiles(): if filtered.filtered_count == 0: return jsonify({ "success": False, - "error": "没有找到符合条件的实体" + "error": "No entities matched the requested filters" }), 400 generator = OasisProfileGenerator() @@ -1433,7 +1206,7 @@ def generate_profiles(): }) except Exception as e: - logger.error(f"生成Profile失败: {str(e)}") + logger.error(f"Failed to generate profiles: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1441,49 +1214,11 @@ def generate_profiles(): }), 500 -# ============== 模拟运行控制接口 ============== + @simulation_bp.route('/start', methods=['POST']) def start_simulation(): - """ - 开始运行模拟 - - 请求(JSON): - { - "simulation_id": "sim_xxxx", // 必填,模拟ID - "platform": "parallel", // 可选: twitter / reddit / parallel (默认) - "max_rounds": 100, // 可选: 最大模拟轮数,用于截断过长的模拟 - "enable_graph_memory_update": false, // 可选: 是否将Agent活动动态更新到Zep图谱记忆 - "force": false // 可选: 强制重新开始(会停止运行中的模拟并清理日志) - } - - 关于 force 参数: - - 启用后,如果模拟正在运行或已完成,会先停止并清理运行日志 - - 清理的内容包括:run_state.json, actions.jsonl, simulation.log 等 - - 不会清理配置文件(simulation_config.json)和 profile 文件 - - 适用于需要重新运行模拟的场景 - - 关于 enable_graph_memory_update: - - 启用后,模拟中所有Agent的活动(发帖、评论、点赞等)都会实时更新到Zep图谱 - - 这可以让图谱"记住"模拟过程,用于后续分析或AI对话 - - 需要模拟关联的项目有有效的 graph_id - - 采用批量更新机制,减少API调用次数 - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "runner_status": "running", - "process_pid": 12345, - "twitter_running": true, - "reddit_running": true, - "started_at": "2025-12-01T10:00:00", - "graph_memory_update_enabled": true, // 是否启用了图谱记忆更新 - "force_restarted": true // 是否是强制重新开始 - } - } - """ + """Start simulation.""" try: data = request.get_json() or {} @@ -1491,98 +1226,107 @@ def start_simulation(): if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "Please provide simulation_id" }), 400 platform = data.get('platform', 'parallel') - max_rounds = data.get('max_rounds') # 可选:最大模拟轮数 - enable_graph_memory_update = data.get('enable_graph_memory_update', False) # 可选:是否启用图谱记忆更新 - force = data.get('force', False) # 可选:强制重新开始 + max_rounds = data.get('max_rounds') + enable_graph_memory_update = data.get('enable_graph_memory_update', False) + force = data.get('force', False) - # 验证 max_rounds 参数 + if max_rounds is not None: try: max_rounds = int(max_rounds) if max_rounds <= 0: return jsonify({ "success": False, - "error": "max_rounds 必须是正整数" + "error": "max_rounds must be a positive integer" }), 400 except (ValueError, TypeError): return jsonify({ "success": False, - "error": "max_rounds 必须是有效的整数" + "error": "max_rounds must be a valid integer" }), 400 if platform not in ['twitter', 'reddit', 'parallel']: return jsonify({ "success": False, - "error": f"无效的平台类型: {platform},可选: twitter/reddit/parallel" + "error": f"Invalid platform type: {platform}. Allowed: twitter/reddit/parallel" }), 400 - # 检查模拟是否已准备好 + manager = SimulationManager() state = manager.get_simulation(simulation_id) if not state: return jsonify({ "success": False, - "error": f"模拟不存在: {simulation_id}" + "error": f"Simulation does not exist: {simulation_id}" }), 404 force_restarted = False - # 智能处理状态:如果准备工作已完成,允许重新启动 + if state.status != SimulationStatus.READY: - # 检查准备工作是否已完成 + is_prepared, prepare_info = _check_simulation_prepared(simulation_id) if is_prepared: - # 准备工作已完成,检查是否有正在运行的进程 + if state.status == SimulationStatus.RUNNING: - # 检查模拟进程是否真的在运行 + run_state = SimulationRunner.get_run_state(simulation_id) if run_state and run_state.runner_status.value == "running": - # 进程确实在运行 + if force: - # 强制模式:停止运行中的模拟 - logger.info(f"强制模式:停止运行中的模拟 {simulation_id}") + + logger.info(f"Force mode: stopping running simulation {simulation_id}") try: SimulationRunner.stop_simulation(simulation_id) except Exception as e: - logger.warning(f"停止模拟时出现警告: {str(e)}") + logger.warning(f"Warning while stopping simulation: {str(e)}") else: return jsonify({ "success": False, - "error": f"模拟正在运行中,请先调用 /stop 接口停止,或使用 force=true 强制重新开始" + "error": ( + "Simulation is currently running. Call /stop first, " + "or use force=true to restart anyway" + ) }), 400 - # 如果是强制模式,清理运行日志 + if force: - logger.info(f"强制模式:清理模拟日志 {simulation_id}") + logger.info(f"Force mode: cleaning simulation logs for {simulation_id}") cleanup_result = SimulationRunner.cleanup_simulation_logs(simulation_id) if not cleanup_result.get("success"): - logger.warning(f"清理日志时出现警告: {cleanup_result.get('errors')}") + logger.warning(f"Warning while cleaning logs: {cleanup_result.get('errors')}") force_restarted = True - # 进程不存在或已结束,重置状态为 ready - logger.info(f"模拟 {simulation_id} 准备工作已完成,重置状态为 ready(原状态: {state.status.value})") + + logger.info( + f"Simulation {simulation_id} preparation is complete; resetting status to ready " + f"(previous status: {state.status.value})" + ) state.status = SimulationStatus.READY manager._save_simulation_state(state) else: - # 准备工作未完成 + return jsonify({ "success": False, - "error": f"模拟未准备好,当前状态: {state.status.value},请先调用 /prepare 接口" + "error": ( + f"Simulation is not ready. Current status: {state.status.value}. " + "Call /prepare first" + ) }), 400 - # 获取图谱ID(用于图谱记忆更新) + graph_id = None if enable_graph_memory_update: - # 从模拟状态或项目中获取 graph_id + graph_id = state.graph_id if not graph_id: - # 尝试从项目中获取 + project = ProjectManager.get_project(state.project_id) if project: graph_id = project.graph_id @@ -1590,12 +1334,15 @@ def start_simulation(): if not graph_id: return jsonify({ "success": False, - "error": "启用图谱记忆更新需要有效的 graph_id,请确保项目已构建图谱" + "error": ( + "A valid graph_id is required when graph memory updates are enabled. " + "Make sure the project graph has been built" + ) }), 400 - logger.info(f"启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}") + logger.info(f"Graph memory updates enabled: simulation_id={simulation_id}, graph_id={graph_id}") + - # 启动模拟 run_state = SimulationRunner.start_simulation( simulation_id=simulation_id, platform=platform, @@ -1604,7 +1351,7 @@ def start_simulation(): graph_id=graph_id ) - # 更新模拟状态 + state.status = SimulationStatus.RUNNING manager._save_simulation_state(state) @@ -1628,7 +1375,7 @@ def start_simulation(): }), 400 except Exception as e: - logger.error(f"启动模拟失败: {str(e)}") + logger.error(f"Failed to start simulation: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1638,24 +1385,7 @@ def start_simulation(): @simulation_bp.route('/stop', methods=['POST']) def stop_simulation(): - """ - 停止模拟 - - 请求(JSON): - { - "simulation_id": "sim_xxxx" // 必填,模拟ID - } - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "runner_status": "stopped", - "completed_at": "2025-12-01T12:00:00" - } - } - """ + """Stop simulation.""" try: data = request.get_json() or {} @@ -1663,12 +1393,12 @@ def stop_simulation(): if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "Please provide simulation_id" }), 400 run_state = SimulationRunner.stop_simulation(simulation_id) - # 更新模拟状态 + manager = SimulationManager() state = manager.get_simulation(simulation_id) if state: @@ -1687,7 +1417,7 @@ def stop_simulation(): }), 400 except Exception as e: - logger.error(f"停止模拟失败: {str(e)}") + logger.error(f"Failed to stop simulation: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1695,34 +1425,11 @@ def stop_simulation(): }), 500 -# ============== 实时状态监控接口 ============== + @simulation_bp.route('//run-status', methods=['GET']) def get_run_status(simulation_id: str): - """ - 获取模拟运行实时状态(用于前端轮询) - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "runner_status": "running", - "current_round": 5, - "total_rounds": 144, - "progress_percent": 3.5, - "simulated_hours": 2, - "total_simulation_hours": 72, - "twitter_running": true, - "reddit_running": true, - "twitter_actions_count": 150, - "reddit_actions_count": 200, - "total_actions_count": 350, - "started_at": "2025-12-01T10:00:00", - "updated_at": "2025-12-01T10:30:00" - } - } - """ + """Get run status.""" try: run_state = SimulationRunner.get_run_state(simulation_id) @@ -1747,7 +1454,7 @@ def get_run_status(simulation_id: str): }) except Exception as e: - logger.error(f"获取运行状态失败: {str(e)}") + logger.error(f"Failed to fetch run status: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1757,41 +1464,7 @@ def get_run_status(simulation_id: str): @simulation_bp.route('//run-status/detail', methods=['GET']) def get_run_status_detail(simulation_id: str): - """ - 获取模拟运行详细状态(包含所有动作) - - 用于前端展示实时动态 - - Query参数: - platform: 过滤平台(twitter/reddit,可选) - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "runner_status": "running", - "current_round": 5, - ... - "all_actions": [ - { - "round_num": 5, - "timestamp": "2025-12-01T10:30:00", - "platform": "twitter", - "agent_id": 3, - "agent_name": "Agent Name", - "action_type": "CREATE_POST", - "action_args": {"content": "..."}, - "result": null, - "success": true - }, - ... - ], - "twitter_actions": [...], # Twitter 平台的所有动作 - "reddit_actions": [...] # Reddit 平台的所有动作 - } - } - """ + """Get run status detail.""" try: run_state = SimulationRunner.get_run_state(simulation_id) platform_filter = request.args.get('platform') @@ -1808,13 +1481,13 @@ def get_run_status_detail(simulation_id: str): } }) - # 获取完整的动作列表 + all_actions = SimulationRunner.get_all_actions( simulation_id=simulation_id, platform=platform_filter ) - # 分平台获取动作 + twitter_actions = SimulationRunner.get_all_actions( simulation_id=simulation_id, platform="twitter" @@ -1825,7 +1498,7 @@ def get_run_status_detail(simulation_id: str): platform="reddit" ) if not platform_filter or platform_filter == "reddit" else [] - # 获取当前轮次的动作(recent_actions 只展示最新一轮) + current_round = run_state.current_round recent_actions = SimulationRunner.get_all_actions( simulation_id=simulation_id, @@ -1833,13 +1506,13 @@ def get_run_status_detail(simulation_id: str): round_num=current_round ) if current_round > 0 else [] - # 获取基础状态信息 + result = run_state.to_dict() result["all_actions"] = [a.to_dict() for a in all_actions] result["twitter_actions"] = [a.to_dict() for a in twitter_actions] result["reddit_actions"] = [a.to_dict() for a in reddit_actions] result["rounds_count"] = len(run_state.rounds) - # recent_actions 只展示当前最新一轮两个平台的内容 + result["recent_actions"] = [a.to_dict() for a in recent_actions] return jsonify({ @@ -1848,7 +1521,7 @@ def get_run_status_detail(simulation_id: str): }) except Exception as e: - logger.error(f"获取详细状态失败: {str(e)}") + logger.error(f"Failed to fetch detailed status: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1858,25 +1531,7 @@ def get_run_status_detail(simulation_id: str): @simulation_bp.route('//actions', methods=['GET']) def get_simulation_actions(simulation_id: str): - """ - 获取模拟中的Agent动作历史 - - Query参数: - limit: 返回数量(默认100) - offset: 偏移量(默认0) - platform: 过滤平台(twitter/reddit) - agent_id: 过滤Agent ID - round_num: 过滤轮次 - - 返回: - { - "success": true, - "data": { - "count": 100, - "actions": [...] - } - } - """ + """Get simulation actions.""" try: limit = request.args.get('limit', 100, type=int) offset = request.args.get('offset', 0, type=int) @@ -1902,7 +1557,7 @@ def get_simulation_actions(simulation_id: str): }) except Exception as e: - logger.error(f"获取动作历史失败: {str(e)}") + logger.error(f"Failed to fetch action history: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1912,17 +1567,7 @@ def get_simulation_actions(simulation_id: str): @simulation_bp.route('//timeline', methods=['GET']) def get_simulation_timeline(simulation_id: str): - """ - 获取模拟时间线(按轮次汇总) - - 用于前端展示进度条和时间线视图 - - Query参数: - start_round: 起始轮次(默认0) - end_round: 结束轮次(默认全部) - - 返回每轮的汇总信息 - """ + """Get simulation timeline.""" try: start_round = request.args.get('start_round', 0, type=int) end_round = request.args.get('end_round', type=int) @@ -1942,7 +1587,7 @@ def get_simulation_timeline(simulation_id: str): }) except Exception as e: - logger.error(f"获取时间线失败: {str(e)}") + logger.error(f"Failed to fetch timeline: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1952,11 +1597,7 @@ def get_simulation_timeline(simulation_id: str): @simulation_bp.route('//agent-stats', methods=['GET']) def get_agent_stats(simulation_id: str): - """ - 获取每个Agent的统计信息 - - 用于前端展示Agent活跃度排行、动作分布等 - """ + """Get agent stats.""" try: stats = SimulationRunner.get_agent_stats(simulation_id) @@ -1969,7 +1610,7 @@ def get_agent_stats(simulation_id: str): }) except Exception as e: - logger.error(f"获取Agent统计失败: {str(e)}") + logger.error(f"Failed to fetch agent stats: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -1977,20 +1618,11 @@ def get_agent_stats(simulation_id: str): }), 500 -# ============== 数据库查询接口 ============== + @simulation_bp.route('//posts', methods=['GET']) def get_simulation_posts(simulation_id: str): - """ - 获取模拟中的帖子 - - Query参数: - platform: 平台类型(twitter/reddit) - limit: 返回数量(默认50) - offset: 偏移量 - - 返回帖子列表(从SQLite数据库读取) - """ + """Get simulation posts.""" try: platform = request.args.get('platform', 'reddit') limit = request.args.get('limit', 50, type=int) @@ -2011,7 +1643,7 @@ def get_simulation_posts(simulation_id: str): "platform": platform, "count": 0, "posts": [], - "message": "数据库不存在,模拟可能尚未运行" + "message": "Database does not exist yet; the simulation may not have run" } }) @@ -2049,7 +1681,7 @@ def get_simulation_posts(simulation_id: str): }) except Exception as e: - logger.error(f"获取帖子失败: {str(e)}") + logger.error(f"Failed to fetch posts: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -2059,14 +1691,7 @@ def get_simulation_posts(simulation_id: str): @simulation_bp.route('//comments', methods=['GET']) def get_simulation_comments(simulation_id: str): - """ - 获取模拟中的评论(仅Reddit) - - Query参数: - post_id: 过滤帖子ID(可选) - limit: 返回数量 - offset: 偏移量 - """ + """Get simulation comments.""" try: post_id = request.args.get('post_id') limit = request.args.get('limit', 50, type=int) @@ -2124,7 +1749,7 @@ def get_simulation_comments(simulation_id: str): }) except Exception as e: - logger.error(f"获取评论失败: {str(e)}") + logger.error(f"Failed to fetch comments: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -2132,101 +1757,56 @@ def get_simulation_comments(simulation_id: str): }), 500 -# ============== Interview 采访接口 ============== + @simulation_bp.route('/interview', methods=['POST']) def interview_agent(): - """ - 采访单个Agent - - 注意:此功能需要模拟环境处于运行状态(完成模拟循环后进入等待命令模式) - - 请求(JSON): - { - "simulation_id": "sim_xxxx", // 必填,模拟ID - "agent_id": 0, // 必填,Agent ID - "prompt": "你对这件事有什么看法?", // 必填,采访问题 - "platform": "twitter", // 可选,指定平台(twitter/reddit) - // 不指定时:双平台模拟同时采访两个平台 - "timeout": 60 // 可选,超时时间(秒),默认60 - } - - 返回(不指定platform,双平台模式): - { - "success": true, - "data": { - "agent_id": 0, - "prompt": "你对这件事有什么看法?", - "result": { - "agent_id": 0, - "prompt": "...", - "platforms": { - "twitter": {"agent_id": 0, "response": "...", "platform": "twitter"}, - "reddit": {"agent_id": 0, "response": "...", "platform": "reddit"} - } - }, - "timestamp": "2025-12-08T10:00:01" - } - } - - 返回(指定platform): - { - "success": true, - "data": { - "agent_id": 0, - "prompt": "你对这件事有什么看法?", - "result": { - "agent_id": 0, - "response": "我认为...", - "platform": "twitter", - "timestamp": "2025-12-08T10:00:00" - }, - "timestamp": "2025-12-08T10:00:01" - } - } - """ + """Interview Agent.""" try: data = request.get_json() or {} simulation_id = data.get('simulation_id') agent_id = data.get('agent_id') prompt = data.get('prompt') - platform = data.get('platform') # 可选:twitter/reddit/None + platform = data.get('platform') timeout = data.get('timeout', 60) if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "Please provide simulation_id" }), 400 if agent_id is None: return jsonify({ "success": False, - "error": "请提供 agent_id" + "error": "Please provide agent_id" }), 400 if not prompt: return jsonify({ "success": False, - "error": "请提供 prompt(采访问题)" + "error": "Please provide prompt (the interview question)" }), 400 - # 验证platform参数 + if platform and platform not in ("twitter", "reddit"): return jsonify({ "success": False, - "error": "platform 参数只能是 'twitter' 或 'reddit'" + "error": "platform must be either 'twitter' or 'reddit'" }), 400 - # 检查环境状态 + if not SimulationRunner.check_env_alive(simulation_id): return jsonify({ "success": False, - "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。" + "error": ( + "The simulation environment is not running or has been closed. " + "Make sure the simulation has completed and entered command-wait mode." + ) }), 400 - # 优化prompt,添加前缀避免Agent调用工具 + optimized_prompt = optimize_interview_prompt(prompt) result = SimulationRunner.interview_agent( @@ -2251,11 +1831,11 @@ def interview_agent(): except TimeoutError as e: return jsonify({ "success": False, - "error": f"等待Interview响应超时: {str(e)}" + "error": f"Timed out waiting for interview response: {str(e)}" }), 504 except Exception as e: - logger.error(f"Interview失败: {str(e)}") + logger.error(f"Interview failed: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -2265,103 +1845,65 @@ def interview_agent(): @simulation_bp.route('/interview/batch', methods=['POST']) def interview_agents_batch(): - """ - 批量采访多个Agent - - 注意:此功能需要模拟环境处于运行状态 - - 请求(JSON): - { - "simulation_id": "sim_xxxx", // 必填,模拟ID - "interviews": [ // 必填,采访列表 - { - "agent_id": 0, - "prompt": "你对A有什么看法?", - "platform": "twitter" // 可选,指定该Agent的采访平台 - }, - { - "agent_id": 1, - "prompt": "你对B有什么看法?" // 不指定platform则使用默认值 - } - ], - "platform": "reddit", // 可选,默认平台(被每项的platform覆盖) - // 不指定时:双平台模拟每个Agent同时采访两个平台 - "timeout": 120 // 可选,超时时间(秒),默认120 - } - - 返回: - { - "success": true, - "data": { - "interviews_count": 2, - "result": { - "interviews_count": 4, - "results": { - "twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"}, - "reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"}, - "twitter_1": {"agent_id": 1, "response": "...", "platform": "twitter"}, - "reddit_1": {"agent_id": 1, "response": "...", "platform": "reddit"} - } - }, - "timestamp": "2025-12-08T10:00:01" - } - } - """ + """Interview Agents Batch.""" try: data = request.get_json() or {} simulation_id = data.get('simulation_id') interviews = data.get('interviews') - platform = data.get('platform') # 可选:twitter/reddit/None + platform = data.get('platform') timeout = data.get('timeout', 120) if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "Please provide simulation_id" }), 400 if not interviews or not isinstance(interviews, list): return jsonify({ "success": False, - "error": "请提供 interviews(采访列表)" + "error": "Please provide interviews (the interview list)" }), 400 - # 验证platform参数 + if platform and platform not in ("twitter", "reddit"): return jsonify({ "success": False, - "error": "platform 参数只能是 'twitter' 或 'reddit'" + "error": "platform must be either 'twitter' or 'reddit'" }), 400 - # 验证每个采访项 + for i, interview in enumerate(interviews): if 'agent_id' not in interview: return jsonify({ "success": False, - "error": f"采访列表第{i+1}项缺少 agent_id" + "error": f"Interview list item {i+1} is missing agent_id" }), 400 if 'prompt' not in interview: return jsonify({ "success": False, - "error": f"采访列表第{i+1}项缺少 prompt" + "error": f"Interview list item {i+1} is missing prompt" }), 400 - # 验证每项的platform(如果有) + item_platform = interview.get('platform') if item_platform and item_platform not in ("twitter", "reddit"): return jsonify({ "success": False, - "error": f"采访列表第{i+1}项的platform只能是 'twitter' 或 'reddit'" + "error": f"Interview list item {i+1} has an invalid platform; only 'twitter' or 'reddit' are allowed" }), 400 - # 检查环境状态 + if not SimulationRunner.check_env_alive(simulation_id): return jsonify({ "success": False, - "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。" + "error": ( + "The simulation environment is not running or has been closed. " + "Make sure the simulation has completed and entered command-wait mode." + ) }), 400 - # 优化每个采访项的prompt,添加前缀避免Agent调用工具 + optimized_interviews = [] for interview in interviews: optimized_interview = interview.copy() @@ -2389,11 +1931,11 @@ def interview_agents_batch(): except TimeoutError as e: return jsonify({ "success": False, - "error": f"等待批量Interview响应超时: {str(e)}" + "error": f"Timed out waiting for batch interview response: {str(e)}" }), 504 except Exception as e: - logger.error(f"批量Interview失败: {str(e)}") + logger.error(f"Batch interview failed: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -2403,72 +1945,45 @@ def interview_agents_batch(): @simulation_bp.route('/interview/all', methods=['POST']) def interview_all_agents(): - """ - 全局采访 - 使用相同问题采访所有Agent - - 注意:此功能需要模拟环境处于运行状态 - - 请求(JSON): - { - "simulation_id": "sim_xxxx", // 必填,模拟ID - "prompt": "你对这件事整体有什么看法?", // 必填,采访问题(所有Agent使用相同问题) - "platform": "reddit", // 可选,指定平台(twitter/reddit) - // 不指定时:双平台模拟每个Agent同时采访两个平台 - "timeout": 180 // 可选,超时时间(秒),默认180 - } - - 返回: - { - "success": true, - "data": { - "interviews_count": 50, - "result": { - "interviews_count": 100, - "results": { - "twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"}, - "reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"}, - ... - } - }, - "timestamp": "2025-12-08T10:00:01" - } - } - """ + """Interview All Agents.""" try: data = request.get_json() or {} simulation_id = data.get('simulation_id') prompt = data.get('prompt') - platform = data.get('platform') # 可选:twitter/reddit/None + platform = data.get('platform') timeout = data.get('timeout', 180) if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "Please provide simulation_id" }), 400 if not prompt: return jsonify({ "success": False, - "error": "请提供 prompt(采访问题)" + "error": "Please provide prompt (the interview question)" }), 400 - # 验证platform参数 + if platform and platform not in ("twitter", "reddit"): return jsonify({ "success": False, - "error": "platform 参数只能是 'twitter' 或 'reddit'" + "error": "platform must be either 'twitter' or 'reddit'" }), 400 - # 检查环境状态 + if not SimulationRunner.check_env_alive(simulation_id): return jsonify({ "success": False, - "error": "模拟环境未运行或已关闭。请确保模拟已完成并进入等待命令模式。" + "error": ( + "The simulation environment is not running or has been closed. " + "Make sure the simulation has completed and entered command-wait mode." + ) }), 400 - # 优化prompt,添加前缀避免Agent调用工具 + optimized_prompt = optimize_interview_prompt(prompt) result = SimulationRunner.interview_all_agents( @@ -2492,11 +2007,11 @@ def interview_all_agents(): except TimeoutError as e: return jsonify({ "success": False, - "error": f"等待全局Interview响应超时: {str(e)}" + "error": f"Timed out waiting for global interview response: {str(e)}" }), 504 except Exception as e: - logger.error(f"全局Interview失败: {str(e)}") + logger.error(f"Global interview failed: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -2506,50 +2021,19 @@ def interview_all_agents(): @simulation_bp.route('/interview/history', methods=['POST']) def get_interview_history(): - """ - 获取Interview历史记录 - - 从模拟数据库中读取所有Interview记录 - - 请求(JSON): - { - "simulation_id": "sim_xxxx", // 必填,模拟ID - "platform": "reddit", // 可选,平台类型(reddit/twitter) - // 不指定则返回两个平台的所有历史 - "agent_id": 0, // 可选,只获取该Agent的采访历史 - "limit": 100 // 可选,返回数量,默认100 - } - - 返回: - { - "success": true, - "data": { - "count": 10, - "history": [ - { - "agent_id": 0, - "response": "我认为...", - "prompt": "你对这件事有什么看法?", - "timestamp": "2025-12-08T10:00:00", - "platform": "reddit" - }, - ... - ] - } - } - """ + """Get interview history.""" try: data = request.get_json() or {} simulation_id = data.get('simulation_id') - platform = data.get('platform') # 不指定则返回两个平台的历史 + platform = data.get('platform') agent_id = data.get('agent_id') limit = data.get('limit', 100) if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "Please provide simulation_id" }), 400 history = SimulationRunner.get_interview_history( @@ -2568,7 +2052,7 @@ def get_interview_history(): }) except Exception as e: - logger.error(f"获取Interview历史失败: {str(e)}") + logger.error(f"Failed to fetch interview history: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -2578,28 +2062,7 @@ def get_interview_history(): @simulation_bp.route('/env-status', methods=['POST']) def get_env_status(): - """ - 获取模拟环境状态 - - 检查模拟环境是否存活(可以接收Interview命令) - - 请求(JSON): - { - "simulation_id": "sim_xxxx" // 必填,模拟ID - } - - 返回: - { - "success": true, - "data": { - "simulation_id": "sim_xxxx", - "env_alive": true, - "twitter_available": true, - "reddit_available": true, - "message": "环境正在运行,可以接收Interview命令" - } - } - """ + """Get env status.""" try: data = request.get_json() or {} @@ -2608,18 +2071,18 @@ def get_env_status(): if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "Please provide simulation_id" }), 400 env_alive = SimulationRunner.check_env_alive(simulation_id) - # 获取更详细的状态信息 + env_status = SimulationRunner.get_env_status_detail(simulation_id) if env_alive: - message = "环境正在运行,可以接收Interview命令" + message = "Environment is running and can accept interview commands" else: - message = "环境未运行或已关闭" + message = "Environment is not running or has been closed" return jsonify({ "success": True, @@ -2633,7 +2096,7 @@ def get_env_status(): }) except Exception as e: - logger.error(f"获取环境状态失败: {str(e)}") + logger.error(f"Failed to fetch environment status: {str(e)}") return jsonify({ "success": False, "error": str(e), @@ -2643,30 +2106,7 @@ def get_env_status(): @simulation_bp.route('/close-env', methods=['POST']) def close_simulation_env(): - """ - 关闭模拟环境 - - 向模拟发送关闭环境命令,使其优雅退出等待命令模式。 - - 注意:这不同于 /stop 接口,/stop 会强制终止进程, - 而此接口会让模拟优雅地关闭环境并退出。 - - 请求(JSON): - { - "simulation_id": "sim_xxxx", // 必填,模拟ID - "timeout": 30 // 可选,超时时间(秒),默认30 - } - - 返回: - { - "success": true, - "data": { - "message": "环境关闭命令已发送", - "result": {...}, - "timestamp": "2025-12-08T10:00:01" - } - } - """ + """Close simulation env.""" try: data = request.get_json() or {} @@ -2676,7 +2116,7 @@ def close_simulation_env(): if not simulation_id: return jsonify({ "success": False, - "error": "请提供 simulation_id" + "error": "Please provide simulation_id" }), 400 result = SimulationRunner.close_simulation_env( @@ -2684,7 +2124,7 @@ def close_simulation_env(): timeout=timeout ) - # 更新模拟状态 + manager = SimulationManager() state = manager.get_simulation(simulation_id) if state: @@ -2703,7 +2143,7 @@ def close_simulation_env(): }), 400 except Exception as e: - logger.error(f"关闭环境失败: {str(e)}") + logger.error(f"Failed to close environment: {str(e)}") return jsonify({ "success": False, "error": str(e), diff --git a/backend/app/config.py b/backend/app/config.py index 953dfa50a..d778cf076 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,54 +1,77 @@ -""" -配置管理 -统一从项目根目录的 .env 文件加载配置 -""" +"""Configuration management utilities.""" import os from dotenv import load_dotenv -# 加载项目根目录的 .env 文件 -# 路径: MiroFish/.env (相对于 backend/app/config.py) + + project_root_env = os.path.join(os.path.dirname(__file__), '../../.env') if os.path.exists(project_root_env): load_dotenv(project_root_env, override=True) else: - # 如果根目录没有 .env,尝试加载环境变量(用于生产环境) + load_dotenv(override=True) class Config: - """Flask配置类""" + """Config.""" + - # Flask配置 SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key') DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true' - # JSON配置 - 禁用ASCII转义,让中文直接显示(而不是 \uXXXX 格式) + JSON_AS_ASCII = False - # LLM配置(统一使用OpenAI格式) + LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') - # Zep配置 + + GRAPH_BACKEND = os.environ.get('GRAPH_BACKEND', 'zep_cloud').strip().lower() + + ZEP_API_KEY = os.environ.get('ZEP_API_KEY') + + + NEO4J_URI = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') + NEO4J_USER = os.environ.get('NEO4J_USER', 'neo4j') + NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD') + NEO4J_DATABASE = os.environ.get('NEO4J_DATABASE', 'neo4j') + + GRAPHITI_AUTO_INIT = os.environ.get('GRAPHITI_AUTO_INIT', 'True').lower() == 'true' + GRAPHITI_TELEMETRY_ENABLED = os.environ.get('GRAPHITI_TELEMETRY_ENABLED', 'False').lower() == 'true' + GRAPHITI_MAX_COROUTINES = int(os.environ.get('GRAPHITI_MAX_COROUTINES', '10')) + GRAPHITI_SEARCH_RERANKER = os.environ.get('GRAPHITI_SEARCH_RERANKER', 'rrf').strip().lower() + + GRAPHITI_LLM_API_KEY = os.environ.get('GRAPHITI_LLM_API_KEY') or LLM_API_KEY + GRAPHITI_LLM_BASE_URL = os.environ.get('GRAPHITI_LLM_BASE_URL') or LLM_BASE_URL + GRAPHITI_LLM_MODEL = os.environ.get('GRAPHITI_LLM_MODEL') or LLM_MODEL_NAME + + GRAPHITI_EMBEDDER_API_KEY = os.environ.get('GRAPHITI_EMBEDDER_API_KEY') or LLM_API_KEY + GRAPHITI_EMBEDDER_BASE_URL = os.environ.get('GRAPHITI_EMBEDDER_BASE_URL') or LLM_BASE_URL + GRAPHITI_EMBEDDER_MODEL = os.environ.get('GRAPHITI_EMBEDDER_MODEL', 'text-embedding-3-small') + + GRAPHITI_RERANKER_API_KEY = os.environ.get('GRAPHITI_RERANKER_API_KEY') or LLM_API_KEY + GRAPHITI_RERANKER_BASE_URL = os.environ.get('GRAPHITI_RERANKER_BASE_URL') or LLM_BASE_URL + GRAPHITI_RERANKER_MODEL = os.environ.get('GRAPHITI_RERANKER_MODEL') or LLM_MODEL_NAME + - # 文件上传配置 MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads') ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'} - # 文本处理配置 - DEFAULT_CHUNK_SIZE = 500 # 默认切块大小 - DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小 - # OASIS模拟配置 + DEFAULT_CHUNK_SIZE = 500 + DEFAULT_CHUNK_OVERLAP = 50 + + OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10')) OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations') - # OASIS平台可用动作配置 + OASIS_TWITTER_ACTIONS = [ 'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST' ] @@ -58,18 +81,40 @@ class Config: 'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE' ] - # Report Agent配置 + REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5')) REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2')) REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5')) + @classmethod + def validate_graph_backend(cls): + """Validate Graph Backend.""" + errors = [] + + if cls.GRAPH_BACKEND == 'zep_cloud': + if not cls.ZEP_API_KEY: + errors.append("ZEP_API_KEY is not configured") + elif cls.GRAPH_BACKEND == 'graphiti_local': + if not cls.NEO4J_URI: + errors.append("NEO4J_URI is not configured") + if not cls.NEO4J_USER: + errors.append("NEO4J_USER is not configured") + if not cls.NEO4J_PASSWORD: + errors.append("NEO4J_PASSWORD is not configured") + if not cls.GRAPHITI_LLM_API_KEY: + errors.append("GRAPHITI_LLM_API_KEY/LLM_API_KEY is not configured") + if not cls.GRAPHITI_EMBEDDER_API_KEY: + errors.append("GRAPHITI_EMBEDDER_API_KEY/LLM_API_KEY is not configured") + else: + errors.append(f"Unsupported GRAPH_BACKEND: {cls.GRAPH_BACKEND}") + + return errors + @classmethod def validate(cls): - """验证必要配置""" + """Validate the configuration.""" errors = [] if not cls.LLM_API_KEY: - errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: - errors.append("ZEP_API_KEY 未配置") + errors.append("LLM_API_KEY is not configured") + errors.extend(cls.validate_graph_backend()) return errors - diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 55bec6195..aebc6f6e4 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,6 +1,4 @@ -""" -数据模型模块 -""" +"""Data model modules.""" from .task import TaskManager, TaskStatus from .project import Project, ProjectStatus, ProjectManager diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 089789374..d04ff0909 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -1,7 +1,4 @@ -""" -项目上下文管理 -用于在服务端持久化项目状态,避免前端在接口间传递大量数据 -""" +"""Project context management.""" import os import json @@ -15,45 +12,45 @@ class ProjectStatus(str, Enum): - """项目状态""" - CREATED = "created" # 刚创建,文件已上传 - ONTOLOGY_GENERATED = "ontology_generated" # 本体已生成 - GRAPH_BUILDING = "graph_building" # 图谱构建中 - GRAPH_COMPLETED = "graph_completed" # 图谱构建完成 - FAILED = "failed" # 失败 + """Project Status.""" + CREATED = "created" + ONTOLOGY_GENERATED = "ontology_generated" + GRAPH_BUILDING = "graph_building" + GRAPH_COMPLETED = "graph_completed" + FAILED = "failed" @dataclass class Project: - """项目数据模型""" + """Project.""" project_id: str name: str status: ProjectStatus created_at: str updated_at: str - # 文件信息 + files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}] total_text_length: int = 0 - # 本体信息(接口1生成后填充) + ontology: Optional[Dict[str, Any]] = None analysis_summary: Optional[str] = None - # 图谱信息(接口2完成后填充) + graph_id: Optional[str] = None graph_build_task_id: Optional[str] = None - # 配置 + simulation_requirement: Optional[str] = None chunk_size: int = 500 chunk_overlap: int = 50 - # 错误信息 + error: Optional[str] = None def to_dict(self) -> Dict[str, Any]: - """转换为字典""" + """Convert the object to a dictionary.""" return { "project_id": self.project_id, "name": self.name, @@ -74,7 +71,7 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'Project': - """从字典创建""" + """Create an instance from a dictionary.""" status = data.get('status', 'created') if isinstance(status, str): status = ProjectStatus(status) @@ -99,47 +96,39 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Project': class ProjectManager: - """项目管理器 - 负责项目的持久化存储和检索""" + """Project Manager.""" + - # 项目存储根目录 PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects') @classmethod def _ensure_projects_dir(cls): - """确保项目目录存在""" + """Ensure projects dir.""" os.makedirs(cls.PROJECTS_DIR, exist_ok=True) @classmethod def _get_project_dir(cls, project_id: str) -> str: - """获取项目目录路径""" + """Get project dir.""" return os.path.join(cls.PROJECTS_DIR, project_id) @classmethod def _get_project_meta_path(cls, project_id: str) -> str: - """获取项目元数据文件路径""" + """Get project meta path.""" return os.path.join(cls._get_project_dir(project_id), 'project.json') @classmethod def _get_project_files_dir(cls, project_id: str) -> str: - """获取项目文件存储目录""" + """Get project files dir.""" return os.path.join(cls._get_project_dir(project_id), 'files') @classmethod def _get_project_text_path(cls, project_id: str) -> str: - """获取项目提取文本存储路径""" + """Get project text path.""" return os.path.join(cls._get_project_dir(project_id), 'extracted_text.txt') @classmethod def create_project(cls, name: str = "Unnamed Project") -> Project: - """ - 创建新项目 - - Args: - name: 项目名称 - - Returns: - 新创建的Project对象 - """ + """Create project.""" cls._ensure_projects_dir() project_id = f"proj_{uuid.uuid4().hex[:12]}" @@ -153,20 +142,20 @@ def create_project(cls, name: str = "Unnamed Project") -> Project: updated_at=now ) - # 创建项目目录结构 + project_dir = cls._get_project_dir(project_id) files_dir = cls._get_project_files_dir(project_id) os.makedirs(project_dir, exist_ok=True) os.makedirs(files_dir, exist_ok=True) - # 保存项目元数据 + cls.save_project(project) return project @classmethod def save_project(cls, project: Project) -> None: - """保存项目元数据""" + """Save project.""" project.updated_at = datetime.now().isoformat() meta_path = cls._get_project_meta_path(project.project_id) @@ -175,15 +164,7 @@ def save_project(cls, project: Project) -> None: @classmethod def get_project(cls, project_id: str) -> Optional[Project]: - """ - 获取项目 - - Args: - project_id: 项目ID - - Returns: - Project对象,如果不存在返回None - """ + """Get project.""" meta_path = cls._get_project_meta_path(project_id) if not os.path.exists(meta_path): @@ -196,15 +177,7 @@ def get_project(cls, project_id: str) -> Optional[Project]: @classmethod def list_projects(cls, limit: int = 50) -> List[Project]: - """ - 列出所有项目 - - Args: - limit: 返回数量限制 - - Returns: - 项目列表,按创建时间倒序 - """ + """List projects.""" cls._ensure_projects_dir() projects = [] @@ -213,22 +186,14 @@ def list_projects(cls, limit: int = 50) -> List[Project]: if project: projects.append(project) - # 按创建时间倒序排序 + projects.sort(key=lambda p: p.created_at, reverse=True) return projects[:limit] @classmethod def delete_project(cls, project_id: str) -> bool: - """ - 删除项目及其所有文件 - - Args: - project_id: 项目ID - - Returns: - 是否删除成功 - """ + """Delete project.""" project_dir = cls._get_project_dir(project_id) if not os.path.exists(project_dir): @@ -239,29 +204,19 @@ def delete_project(cls, project_id: str) -> bool: @classmethod def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]: - """ - 保存上传的文件到项目目录 - - Args: - project_id: 项目ID - file_storage: Flask的FileStorage对象 - original_filename: 原始文件名 - - Returns: - 文件信息字典 {filename, path, size} - """ + """Save file to project.""" files_dir = cls._get_project_files_dir(project_id) os.makedirs(files_dir, exist_ok=True) - # 生成安全的文件名 + ext = os.path.splitext(original_filename)[1].lower() safe_filename = f"{uuid.uuid4().hex[:8]}{ext}" file_path = os.path.join(files_dir, safe_filename) - # 保存文件 + file_storage.save(file_path) - # 获取文件大小 + file_size = os.path.getsize(file_path) return { @@ -273,14 +228,14 @@ def save_file_to_project(cls, project_id: str, file_storage, original_filename: @classmethod def save_extracted_text(cls, project_id: str, text: str) -> None: - """保存提取的文本""" + """Save extracted text.""" text_path = cls._get_project_text_path(project_id) with open(text_path, 'w', encoding='utf-8') as f: f.write(text) @classmethod def get_extracted_text(cls, project_id: str) -> Optional[str]: - """获取提取的文本""" + """Get extracted text.""" text_path = cls._get_project_text_path(project_id) if not os.path.exists(text_path): @@ -291,7 +246,7 @@ def get_extracted_text(cls, project_id: str) -> Optional[str]: @classmethod def get_project_files(cls, project_id: str) -> List[str]: - """获取项目的所有文件路径""" + """Get project files.""" files_dir = cls._get_project_files_dir(project_id) if not os.path.exists(files_dir): diff --git a/backend/app/models/task.py b/backend/app/models/task.py index e15f35fbd..d1c87b8d7 100644 --- a/backend/app/models/task.py +++ b/backend/app/models/task.py @@ -1,7 +1,4 @@ -""" -任务状态管理 -用于跟踪长时间运行的任务(如图谱构建) -""" +"""Task state management.""" import uuid import threading @@ -12,30 +9,30 @@ class TaskStatus(str, Enum): - """任务状态枚举""" - PENDING = "pending" # 等待中 - PROCESSING = "processing" # 处理中 - COMPLETED = "completed" # 已完成 - FAILED = "failed" # 失败 + """Task Status.""" + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" @dataclass class Task: - """任务数据类""" + """Task.""" task_id: str task_type: str status: TaskStatus created_at: datetime updated_at: datetime - progress: int = 0 # 总进度百分比 0-100 - message: str = "" # 状态消息 - result: Optional[Dict] = None # 任务结果 - error: Optional[str] = None # 错误信息 - metadata: Dict = field(default_factory=dict) # 额外元数据 - progress_detail: Dict = field(default_factory=dict) # 详细进度信息 + progress: int = 0 + message: str = "" + result: Optional[Dict] = None + error: Optional[str] = None + metadata: Dict = field(default_factory=dict) + progress_detail: Dict = field(default_factory=dict) def to_dict(self) -> Dict[str, Any]: - """转换为字典""" + """Convert the object to a dictionary.""" return { "task_id": self.task_id, "task_type": self.task_type, @@ -52,16 +49,13 @@ def to_dict(self) -> Dict[str, Any]: class TaskManager: - """ - 任务管理器 - 线程安全的任务状态管理 - """ + """Task Manager.""" _instance = None _lock = threading.Lock() def __new__(cls): - """单例模式""" + """Create the singleton instance.""" if cls._instance is None: with cls._lock: if cls._instance is None: @@ -71,16 +65,7 @@ def __new__(cls): return cls._instance def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str: - """ - 创建新任务 - - Args: - task_type: 任务类型 - metadata: 额外元数据 - - Returns: - 任务ID - """ + """Create task.""" task_id = str(uuid.uuid4()) now = datetime.now() @@ -99,7 +84,7 @@ def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str: return task_id def get_task(self, task_id: str) -> Optional[Task]: - """获取任务""" + """Get task.""" with self._task_lock: return self._tasks.get(task_id) @@ -113,18 +98,7 @@ def update_task( error: Optional[str] = None, progress_detail: Optional[Dict] = None ): - """ - 更新任务状态 - - Args: - task_id: 任务ID - status: 新状态 - progress: 进度 - message: 消息 - result: 结果 - error: 错误信息 - progress_detail: 详细进度信息 - """ + """Update task.""" with self._task_lock: task = self._tasks.get(task_id) if task: @@ -143,26 +117,26 @@ def update_task( task.progress_detail = progress_detail def complete_task(self, task_id: str, result: Dict): - """标记任务完成""" + """Mark task as complete.""" self.update_task( task_id, status=TaskStatus.COMPLETED, progress=100, - message="任务完成", + message="Task completed", result=result ) def fail_task(self, task_id: str, error: str): - """标记任务失败""" + """Mark task as failed.""" self.update_task( task_id, status=TaskStatus.FAILED, - message="任务失败", + message="Task failed", error=error ) def list_tasks(self, task_type: Optional[str] = None) -> list: - """列出任务""" + """List tasks.""" with self._task_lock: tasks = list(self._tasks.values()) if task_type: @@ -170,7 +144,7 @@ def list_tasks(self, task_type: Optional[str] = None) -> list: return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)] def cleanup_old_tasks(self, max_age_hours: int = 24): - """清理旧任务""" + """Cleanup Old Tasks.""" from datetime import timedelta cutoff = datetime.now() - timedelta(hours=max_age_hours) @@ -181,4 +155,3 @@ def cleanup_old_tasks(self, max_age_hours: int = 24): ] for tid in old_ids: del self._tasks[tid] - diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index 8db85d86f..294074318 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -1,6 +1,4 @@ -""" -业务服务模块 -""" +"""Business service modules.""" from .ontology_generator import OntologyGenerator from .graph_builder import GraphBuilderService diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 0e0444bf3..3bef7b36e 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -1,27 +1,24 @@ -""" -图谱构建服务 -接口2:使用Zep API构建Standalone Graph -""" +"""Graph building service.""" import os import uuid -import time import threading +import logging from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass -from zep_cloud.client import Zep -from zep_cloud import EpisodeData, EntityEdgeSourceTarget - from ..config import Config from ..models.task import TaskManager, TaskStatus -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .graph_provider import create_graph_provider from .text_processor import TextProcessor +logger = logging.getLogger(__name__) + + @dataclass class GraphInfo: - """图谱信息""" + """Graph Info.""" graph_id: str node_count: int edge_count: int @@ -37,17 +34,11 @@ def to_dict(self) -> Dict[str, Any]: class GraphBuilderService: - """ - 图谱构建服务 - 负责调用Zep API构建知识图谱 - """ + """Graph Builder Service.""" def __init__(self, api_key: Optional[str] = None): self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.provider = create_graph_provider() self.task_manager = TaskManager() def build_graph_async( @@ -59,21 +50,8 @@ def build_graph_async( chunk_overlap: int = 50, batch_size: int = 3 ) -> str: - """ - 异步构建图谱 + """Build graph async.""" - Args: - text: 输入文本 - ontology: 本体定义(来自接口1的输出) - graph_name: 图谱名称 - chunk_size: 文本块大小 - chunk_overlap: 块重叠大小 - batch_size: 每批发送的块数量 - - Returns: - 任务ID - """ - # 创建任务 task_id = self.task_manager.create_task( task_type="graph_build", metadata={ @@ -83,7 +61,7 @@ def build_graph_async( } ) - # 在后台线程中执行构建 + thread = threading.Thread( target=self._build_graph_worker, args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size) @@ -103,41 +81,41 @@ def _build_graph_worker( chunk_overlap: int, batch_size: int ): - """图谱构建工作线程""" + """Build graph worker.""" try: self.task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=5, - message="开始构建图谱..." + message="Starting graph build..." ) - # 1. 创建图谱 + graph_id = self.create_graph(graph_name) self.task_manager.update_task( task_id, progress=10, - message=f"图谱已创建: {graph_id}" + message=f"Graph created: {graph_id}" ) - # 2. 设置本体 + self.set_ontology(graph_id, ontology) self.task_manager.update_task( task_id, progress=15, - message="本体已设置" + message="Ontology configured" ) - # 3. 文本分块 + chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap) total_chunks = len(chunks) self.task_manager.update_task( task_id, progress=20, - message=f"文本已分割为 {total_chunks} 个块" + message=f"Text split into {total_chunks} chunks" ) - # 4. 分批发送数据 + episode_uuids = self.add_text_batches( graph_id, chunks, batch_size, lambda msg, prog: self.task_manager.update_task( @@ -147,14 +125,15 @@ def _build_graph_worker( ) ) - # 5. 等待Zep处理完成 + self.task_manager.update_task( task_id, progress=60, - message="等待Zep处理数据..." + message="Waiting for Zep to process data..." ) self._wait_for_episodes( + graph_id, episode_uuids, lambda msg, prog: self.task_manager.update_task( task_id, @@ -163,16 +142,16 @@ def _build_graph_worker( ) ) - # 6. 获取图谱信息 + self.task_manager.update_task( task_id, progress=90, - message="获取图谱信息..." + message="Fetching graph information..." ) graph_info = self._get_graph_info(graph_id) - # 完成 + self.task_manager.complete_task(task_id, { "graph_id": graph_id, "graph_info": graph_info.to_dict(), @@ -185,105 +164,12 @@ def _build_graph_worker( self.task_manager.fail_task(task_id, error_msg) def create_graph(self, name: str) -> str: - """创建Zep图谱(公开方法)""" - graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" - - self.client.graph.create( - graph_id=graph_id, - name=name, - description="MiroFish Social Simulation Graph" - ) - - return graph_id + """Create graph.""" + return self.provider.create_graph(name) def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): - """设置图谱本体(公开方法)""" - import warnings - from typing import Optional - from pydantic import Field - from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel - - # 抑制 Pydantic v2 关于 Field(default=None) 的警告 - # 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略 - warnings.filterwarnings('ignore', category=UserWarning, module='pydantic') - - # Zep 保留名称,不能作为属性名 - RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'} - - def safe_attr_name(attr_name: str) -> str: - """将保留名称转换为安全名称""" - if attr_name.lower() in RESERVED_NAMES: - return f"entity_{attr_name}" - return attr_name - - # 动态创建实体类型 - entity_types = {} - for entity_def in ontology.get("entity_types", []): - name = entity_def["name"] - description = entity_def.get("description", f"A {name} entity.") - - # 创建属性字典和类型注解(Pydantic v2 需要) - attrs = {"__doc__": description} - annotations = {} - - for attr_def in entity_def.get("attributes", []): - attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 - attr_desc = attr_def.get("description", attr_name) - # Zep API 需要 Field 的 description,这是必需的 - attrs[attr_name] = Field(description=attr_desc, default=None) - annotations[attr_name] = Optional[EntityText] # 类型注解 - - attrs["__annotations__"] = annotations - - # 动态创建类 - entity_class = type(name, (EntityModel,), attrs) - entity_class.__doc__ = description - entity_types[name] = entity_class - - # 动态创建边类型 - edge_definitions = {} - for edge_def in ontology.get("edge_types", []): - name = edge_def["name"] - description = edge_def.get("description", f"A {name} relationship.") - - # 创建属性字典和类型注解 - attrs = {"__doc__": description} - annotations = {} - - for attr_def in edge_def.get("attributes", []): - attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 - attr_desc = attr_def.get("description", attr_name) - # Zep API 需要 Field 的 description,这是必需的 - attrs[attr_name] = Field(description=attr_desc, default=None) - annotations[attr_name] = Optional[str] # 边属性用str类型 - - attrs["__annotations__"] = annotations - - # 动态创建类 - class_name = ''.join(word.capitalize() for word in name.split('_')) - edge_class = type(class_name, (EdgeModel,), attrs) - edge_class.__doc__ = description - - # 构建source_targets - source_targets = [] - for st in edge_def.get("source_targets", []): - source_targets.append( - EntityEdgeSourceTarget( - source=st.get("source", "Entity"), - target=st.get("target", "Entity") - ) - ) - - if source_targets: - edge_definitions[name] = (edge_class, source_targets) - - # 调用Zep API设置本体 - if entity_types or edge_definitions: - self.client.graph.set_ontology( - graph_ids=[graph_id], - entities=entity_types if entity_types else None, - edges=edge_definitions if edge_definitions else None, - ) + """Set ontology.""" + self.provider.set_ontology(graph_id, ontology) def add_text_batches( self, @@ -292,117 +178,35 @@ def add_text_batches( batch_size: int = 3, progress_callback: Optional[Callable] = None ) -> List[str]: - """分批添加文本到图谱,返回所有 episode 的 uuid 列表""" - episode_uuids = [] - total_chunks = len(chunks) - - for i in range(0, total_chunks, batch_size): - batch_chunks = chunks[i:i + batch_size] - batch_num = i // batch_size + 1 - total_batches = (total_chunks + batch_size - 1) // batch_size - - if progress_callback: - progress = (i + len(batch_chunks)) / total_chunks - progress_callback( - f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...", - progress - ) - - # 构建episode数据 - episodes = [ - EpisodeData(data=chunk, type="text") - for chunk in batch_chunks - ] - - # 发送到Zep - try: - batch_result = self.client.graph.add_batch( - graph_id=graph_id, - episodes=episodes - ) - - # 收集返回的 episode uuid - if batch_result and isinstance(batch_result, list): - for ep in batch_result: - ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) - if ep_uuid: - episode_uuids.append(ep_uuid) - - # 避免请求过快 - time.sleep(1) - - except Exception as e: - if progress_callback: - progress_callback(f"批次 {batch_num} 发送失败: {str(e)}", 0) - raise - - return episode_uuids + """Add text batches.""" + return self.provider.add_text_batches( + graph_id=graph_id, + chunks=chunks, + batch_size=batch_size, + progress_callback=progress_callback, + ) def _wait_for_episodes( self, + graph_id: str, episode_uuids: List[str], progress_callback: Optional[Callable] = None, timeout: int = 600 ): - """等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)""" - if not episode_uuids: - if progress_callback: - progress_callback("无需等待(没有 episode)", 1.0) - return - - start_time = time.time() - pending_episodes = set(episode_uuids) - completed_count = 0 - total_episodes = len(episode_uuids) - - if progress_callback: - progress_callback(f"开始等待 {total_episodes} 个文本块处理...", 0) - - while pending_episodes: - if time.time() - start_time > timeout: - if progress_callback: - progress_callback( - f"部分文本块超时,已完成 {completed_count}/{total_episodes}", - completed_count / total_episodes - ) - break - - # 检查每个 episode 的处理状态 - for ep_uuid in list(pending_episodes): - try: - episode = self.client.graph.episode.get(uuid_=ep_uuid) - is_processed = getattr(episode, 'processed', False) - - if is_processed: - pending_episodes.remove(ep_uuid) - completed_count += 1 - - except Exception as e: - # 忽略单个查询错误,继续 - pass - - elapsed = int(time.time() - start_time) - if progress_callback: - progress_callback( - f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)", - completed_count / total_episodes if total_episodes > 0 else 0 - ) - - if pending_episodes: - time.sleep(3) # 每3秒检查一次 - - if progress_callback: - progress_callback(f"处理完成: {completed_count}/{total_episodes}", 1.0) + """Wait For Episodes.""" + self.provider.wait_for_episodes( + graph_id=graph_id, + episode_uuids=episode_uuids, + progress_callback=progress_callback, + timeout=timeout, + ) def _get_graph_info(self, graph_id: str) -> GraphInfo: - """获取图谱信息""" - # 获取节点(分页) - nodes = fetch_all_nodes(self.client, graph_id) - - # 获取边(分页) - edges = fetch_all_edges(self.client, graph_id) + """Get graph info.""" + nodes = self.provider.get_all_nodes(graph_id) + edges = self.provider.get_all_edges(graph_id) - # 统计实体类型 + entity_types = set() for node in nodes: if node.labels: @@ -418,72 +222,41 @@ def _get_graph_info(self, graph_id: str) -> GraphInfo: ) def get_graph_data(self, graph_id: str) -> Dict[str, Any]: - """ - 获取完整图谱数据(包含详细信息) - - Args: - graph_id: 图谱ID - - Returns: - 包含nodes和edges的字典,包括时间信息、属性等详细数据 - """ - nodes = fetch_all_nodes(self.client, graph_id) - edges = fetch_all_edges(self.client, graph_id) + """Get graph data.""" + nodes = self.provider.get_all_nodes(graph_id) + edges = self.provider.get_all_edges(graph_id) - # 创建节点映射用于获取节点名称 - node_map = {} - for node in nodes: - node_map[node.uuid_] = node.name or "" + + node_map = {node.uuid: node.name or "" for node in nodes} nodes_data = [] for node in nodes: - # 获取创建时间 - created_at = getattr(node, 'created_at', None) - if created_at: - created_at = str(created_at) - nodes_data.append({ - "uuid": node.uuid_, + "uuid": node.uuid, "name": node.name, "labels": node.labels or [], "summary": node.summary or "", "attributes": node.attributes or {}, - "created_at": created_at, + "created_at": node.created_at, }) edges_data = [] for edge in edges: - # 获取时间信息 - created_at = getattr(edge, 'created_at', None) - valid_at = getattr(edge, 'valid_at', None) - invalid_at = getattr(edge, 'invalid_at', None) - expired_at = getattr(edge, 'expired_at', None) - - # 获取 episodes - episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) - if episodes and not isinstance(episodes, list): - episodes = [str(episodes)] - elif episodes: - episodes = [str(e) for e in episodes] - - # 获取 fact_type - fact_type = getattr(edge, 'fact_type', None) or edge.name or "" - edges_data.append({ - "uuid": edge.uuid_, + "uuid": edge.uuid, "name": edge.name or "", "fact": edge.fact or "", - "fact_type": fact_type, + "fact_type": edge.name or "", "source_node_uuid": edge.source_node_uuid, "target_node_uuid": edge.target_node_uuid, "source_node_name": node_map.get(edge.source_node_uuid, ""), "target_node_name": node_map.get(edge.target_node_uuid, ""), "attributes": edge.attributes or {}, - "created_at": str(created_at) if created_at else None, - "valid_at": str(valid_at) if valid_at else None, - "invalid_at": str(invalid_at) if invalid_at else None, - "expired_at": str(expired_at) if expired_at else None, - "episodes": episodes or [], + "created_at": edge.created_at, + "valid_at": edge.valid_at, + "invalid_at": edge.invalid_at, + "expired_at": edge.expired_at, + "episodes": edge.episodes or [], }) return { @@ -495,6 +268,5 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]: } def delete_graph(self, graph_id: str): - """删除图谱""" - self.client.graph.delete(graph_id=graph_id) - + """Delete graph.""" + self.provider.delete_graph(graph_id) diff --git a/backend/app/services/graph_provider/__init__.py b/backend/app/services/graph_provider/__init__.py new file mode 100644 index 000000000..673e8fe1b --- /dev/null +++ b/backend/app/services/graph_provider/__init__.py @@ -0,0 +1,16 @@ +""" +Graph provider exports. +""" + +from .base import BaseGraphProvider +from .factory import create_graph_provider, initialize_selected_graph_backend +from .models import GraphEdgeRecord, GraphNodeRecord, GraphSearchResult + +__all__ = [ + 'BaseGraphProvider', + 'GraphEdgeRecord', + 'GraphNodeRecord', + 'GraphSearchResult', + 'create_graph_provider', + 'initialize_selected_graph_backend', +] diff --git a/backend/app/services/graph_provider/base.py b/backend/app/services/graph_provider/base.py new file mode 100644 index 000000000..e12b5ae09 --- /dev/null +++ b/backend/app/services/graph_provider/base.py @@ -0,0 +1,89 @@ +""" +Abstract graph provider interface. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any, Optional + +from .models import GraphEdgeRecord, GraphNodeRecord, GraphSearchResult + +ProgressCallback = Callable[[str, float], None] + + +class BaseGraphProvider(ABC): + """Provider-neutral graph backend interface.""" + + def ensure_initialized(self) -> None: + """Perform one-time backend initialization when needed.""" + + @abstractmethod + def create_graph(self, name: str) -> str: + raise NotImplementedError + + @abstractmethod + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> None: + raise NotImplementedError + + @abstractmethod + def add_text_batches( + self, + graph_id: str, + chunks: list[str], + batch_size: int = 3, + progress_callback: Optional[ProgressCallback] = None, + ) -> list[str]: + raise NotImplementedError + + @abstractmethod + def wait_for_episodes( + self, + graph_id: str, + episode_uuids: list[str], + progress_callback: Optional[ProgressCallback] = None, + timeout: int = 600, + ) -> None: + raise NotImplementedError + + @abstractmethod + def get_all_nodes(self, graph_id: str) -> list[GraphNodeRecord]: + raise NotImplementedError + + @abstractmethod + def get_all_edges(self, graph_id: str) -> list[GraphEdgeRecord]: + raise NotImplementedError + + @abstractmethod + def get_node(self, graph_id: str, node_uuid: str) -> GraphNodeRecord | None: + raise NotImplementedError + + @abstractmethod + def get_node_edges(self, graph_id: str, node_uuid: str) -> list[GraphEdgeRecord]: + raise NotImplementedError + + @abstractmethod + def search( + self, + graph_id: str, + query: str, + limit: int = 10, + scope: str = "edges", + reranker: str = "cross_encoder", + ) -> GraphSearchResult: + raise NotImplementedError + + @abstractmethod + def add_text( + self, + graph_id: str, + data: str, + source_description: str = "MiroFish", + ) -> str | None: + raise NotImplementedError + + @abstractmethod + def delete_graph(self, graph_id: str) -> None: + raise NotImplementedError + diff --git a/backend/app/services/graph_provider/factory.py b/backend/app/services/graph_provider/factory.py new file mode 100644 index 000000000..3b2a5cc1f --- /dev/null +++ b/backend/app/services/graph_provider/factory.py @@ -0,0 +1,33 @@ +""" +Graph provider factory and backend bootstrap helpers. +""" + +from __future__ import annotations + +from functools import lru_cache + +from ...config import Config + + +@lru_cache(maxsize=2) +def _create_graph_provider_for_backend(backend: str): + if backend == "zep_cloud": + from .zep_cloud_provider import ZepCloudGraphProvider + + return ZepCloudGraphProvider() + + if backend == "graphiti_local": + from .graphiti_local_provider import GraphitiLocalGraphProvider + + return GraphitiLocalGraphProvider() + + raise ValueError(f"Unsupported GRAPH_BACKEND: {backend}") + + +def create_graph_provider(): + return _create_graph_provider_for_backend(Config.GRAPH_BACKEND) + + +def initialize_selected_graph_backend() -> None: + provider = create_graph_provider() + provider.ensure_initialized() diff --git a/backend/app/services/graph_provider/graphiti_local_provider.py b/backend/app/services/graph_provider/graphiti_local_provider.py new file mode 100644 index 000000000..a8f1ef4e8 --- /dev/null +++ b/backend/app/services/graph_provider/graphiti_local_provider.py @@ -0,0 +1,640 @@ +""" +Local Graphiti + Neo4j graph provider implementation. +""" + +from __future__ import annotations + +import atexit +import asyncio +import json +import os +import re +import threading +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from ...config import Config +from ...utils.logger import get_logger +from ...utils.ontology_normalizer import normalize_ontology_for_zep +from .base import BaseGraphProvider, ProgressCallback +from .models import GraphEdgeRecord, GraphNodeRecord, GraphSearchResult + +logger = get_logger('mirofish.graph_provider.graphiti_local') + + +class _AsyncRunner: + """Run all Graphiti/Neo4j async work on one dedicated event loop thread.""" + + def __init__(self): + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread(target=self._run_loop, name="graphiti-local-loop", daemon=True) + self._started = threading.Event() + self._closed = False + self._thread.start() + self._started.wait() + + def _run_loop(self) -> None: + asyncio.set_event_loop(self._loop) + self._started.set() + self._loop.run_forever() + + def run(self, coro): + if self._closed: + raise RuntimeError("Async runner is already closed") + + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + return future.result() + + def close(self) -> None: + if self._closed: + return + + self._closed = True + self._loop.call_soon_threadsafe(self._loop.stop) + self._thread.join(timeout=5) + self._loop.close() + + +_ASYNC_RUNNER = _AsyncRunner() +atexit.register(_ASYNC_RUNNER.close) + + +def _run_async(coro): + return _ASYNC_RUNNER.run(coro) + + +@dataclass +class _OntologyBundle: + entity_types: dict[str, type[BaseModel]] + edge_types: dict[str, type[BaseModel]] + edge_type_map: dict[tuple[str, str], list[str]] + attribute_free_entity_types: dict[str, type[BaseModel]] + attribute_free_edge_types: dict[str, type[BaseModel]] + + +class GraphitiLocalGraphProvider(BaseGraphProvider): + """Graphiti + Neo4j backed graph provider.""" + + _initialized = False + # Startup can flow through ensure_initialized() -> _ensure_client_ready(), so this + # lock must be re-entrant to avoid self-deadlocking during app bootstrap. + _init_lock = threading.RLock() + + def __init__(self): + try: + from graphiti_core import Graphiti + from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient + from graphiti_core.driver.neo4j_driver import Neo4jDriver + from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig + from graphiti_core.errors import GroupsEdgesNotFoundError, GroupsNodesNotFoundError, NodeNotFoundError + from graphiti_core.llm_client.config import LLMConfig + from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient + from neo4j.exceptions import ClientError + except ImportError as exc: # pragma: no cover - depends on installed extras + raise ImportError( + "graphiti-core and neo4j must be installed to use GRAPH_BACKEND=graphiti_local" + ) from exc + + self._Graphiti = Graphiti + self._Neo4jDriver = Neo4jDriver + self._OpenAIEmbedder = OpenAIEmbedder + self._OpenAIEmbedderConfig = OpenAIEmbedderConfig + self._OpenAIRerankerClient = OpenAIRerankerClient + self._OpenAIGenericClient = OpenAIGenericClient + self._LLMConfig = LLMConfig + self._GroupsEdgesNotFoundError = GroupsEdgesNotFoundError + self._GroupsNodesNotFoundError = GroupsNodesNotFoundError + self._NodeNotFoundError = NodeNotFoundError + self._ClientError = ClientError + + # Graphiti reads this env var directly. + os.environ.setdefault('GRAPHITI_TELEMETRY_ENABLED', str(Config.GRAPHITI_TELEMETRY_ENABLED).lower()) + + self._llm_config = self._LLMConfig( + api_key=Config.GRAPHITI_LLM_API_KEY, + base_url=Config.GRAPHITI_LLM_BASE_URL, + model=Config.GRAPHITI_LLM_MODEL, + ) + self._reranker_config = self._LLMConfig( + api_key=Config.GRAPHITI_RERANKER_API_KEY, + base_url=Config.GRAPHITI_RERANKER_BASE_URL, + model=Config.GRAPHITI_RERANKER_MODEL, + ) + self._embedder_config = self._OpenAIEmbedderConfig( + api_key=Config.GRAPHITI_EMBEDDER_API_KEY, + base_url=Config.GRAPHITI_EMBEDDER_BASE_URL, + embedding_model=Config.GRAPHITI_EMBEDDER_MODEL, + ) + + self.driver = self._Neo4jDriver( + uri=Config.NEO4J_URI, + user=Config.NEO4J_USER, + password=Config.NEO4J_PASSWORD, + database=Config.NEO4J_DATABASE, + ) + self.client = self._make_graphiti_client(self.driver) + self._ontology_cache: dict[str, _OntologyBundle] = {} + self._client_ready = False + + def ensure_initialized(self) -> None: + if GraphitiLocalGraphProvider._initialized or not Config.GRAPHITI_AUTO_INIT: + return + + with GraphitiLocalGraphProvider._init_lock: + if GraphitiLocalGraphProvider._initialized: + return + self._ensure_client_ready() + GraphitiLocalGraphProvider._initialized = True + + def _ensure_client_ready(self) -> None: + if self._client_ready: + return + + with GraphitiLocalGraphProvider._init_lock: + if self._client_ready: + return + logger.info("Checking local Neo4j connectivity...") + _run_async(self.driver.health_check()) + logger.info("Local Neo4j connectivity confirmed") + logger.info("Initializing local Graphiti indices and constraints...") + _run_async(self.client.build_indices_and_constraints()) + self._client_ready = True + logger.info("Local Graphiti initialization completed") + + def create_graph(self, name: str) -> str: + self._ensure_client_ready() + graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" + logger.info("Created local Graphiti graph namespace %s (%s)", graph_id, name) + return graph_id + + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> None: + self._ontology_cache[graph_id] = self._build_ontology_bundle(ontology) + + def add_text_batches( + self, + graph_id: str, + chunks: list[str], + batch_size: int = 3, + progress_callback: Optional[ProgressCallback] = None, + ) -> list[str]: + self._ensure_client_ready() + client = self._get_graphiti_client(graph_id) + bundle = self._ontology_cache.get(graph_id) + episode_uuids: list[str] = [] + total_chunks = len(chunks) + + from graphiti_core.nodes import EpisodeType + + base_time = datetime.now(timezone.utc) + + for i in range(0, total_chunks, batch_size): + batch_chunks = chunks[i:i + batch_size] + batch_num = i // batch_size + 1 + total_batches = (total_chunks + batch_size - 1) // batch_size + + if progress_callback: + progress_callback( + f"Sending local batch {batch_num}/{total_batches} ({len(batch_chunks)} chunks)...", + (i + len(batch_chunks)) / total_chunks if total_chunks else 1.0, + ) + + for index, chunk in enumerate(batch_chunks): + result = _run_async( + self._add_episode( + client=client, + graph_id=graph_id, + name=f"{graph_id}_chunk_{i + index + 1}", + episode_body=chunk, + source_description="MiroFish document chunk", + reference_time=base_time + timedelta(seconds=i + index), + source=EpisodeType.text, + bundle=bundle, + ) + ) + self._persist_graph_result(client, result) + + episode = getattr(result, 'episode', None) + episode_uuid = getattr(episode, 'uuid', None) or getattr(episode, 'uuid_', None) + if episode_uuid: + episode_uuids.append(str(episode_uuid)) + + return episode_uuids + + def wait_for_episodes( + self, + graph_id: str, + episode_uuids: list[str], + progress_callback: Optional[ProgressCallback] = None, + timeout: int = 600, + ) -> None: + if progress_callback: + progress_callback( + "Local Graphiti ingestion completed", + 1.0, + ) + + def get_all_nodes(self, graph_id: str) -> list[GraphNodeRecord]: + self._ensure_client_ready() + from graphiti_core.nodes import EntityNode + + return [ + self._normalize_node(node) + for node in self._fetch_group_records(EntityNode.get_by_group_ids, graph_id) + ] + + def get_all_edges(self, graph_id: str) -> list[GraphEdgeRecord]: + self._ensure_client_ready() + from graphiti_core.edges import EntityEdge + + return [ + self._normalize_edge(edge) + for edge in self._fetch_group_records(EntityEdge.get_by_group_ids, graph_id) + ] + + def get_node(self, graph_id: str, node_uuid: str) -> GraphNodeRecord | None: + self._ensure_client_ready() + from graphiti_core.nodes import EntityNode + + graph_driver = self._get_graph_driver(graph_id) + try: + node = _run_async(EntityNode.get_by_uuid(graph_driver, node_uuid)) + except self._NodeNotFoundError: + return None + + if graph_id and getattr(node, 'group_id', None) not in (None, *self._graph_namespaces(graph_id)): + return None + return self._normalize_node(node) + + def get_node_edges(self, graph_id: str, node_uuid: str) -> list[GraphEdgeRecord]: + self._ensure_client_ready() + from graphiti_core.edges import EntityEdge + + graph_driver = self._get_graph_driver(graph_id) + edges = _run_async(EntityEdge.get_by_node_uuid(graph_driver, node_uuid)) + return [ + self._normalize_edge(edge) + for edge in edges + if not graph_id or getattr(edge, 'group_id', None) in (None, *self._graph_namespaces(graph_id)) + ] + + def search( + self, + graph_id: str, + query: str, + limit: int = 10, + scope: str = "edges", + reranker: str = "cross_encoder", + ) -> GraphSearchResult: + self._ensure_client_ready() + client = self._get_graphiti_client(graph_id) + from graphiti_core.search.search_config_recipes import ( + EDGE_HYBRID_SEARCH_CROSS_ENCODER, + EDGE_HYBRID_SEARCH_RRF, + NODE_HYBRID_SEARCH_CROSS_ENCODER, + NODE_HYBRID_SEARCH_RRF, + ) + + effective_reranker = Config.GRAPHITI_SEARCH_RERANKER or reranker or "rrf" + + if scope == "nodes": + config = ( + NODE_HYBRID_SEARCH_CROSS_ENCODER.model_copy(deep=True) + if effective_reranker == "cross_encoder" + else NODE_HYBRID_SEARCH_RRF.model_copy(deep=True) + ) + else: + config = ( + EDGE_HYBRID_SEARCH_CROSS_ENCODER.model_copy(deep=True) + if effective_reranker == "cross_encoder" + else EDGE_HYBRID_SEARCH_RRF.model_copy(deep=True) + ) + config.limit = limit + + results = _run_async( + client.search_( + query=query, + config=config, + group_ids=self._graph_namespaces(graph_id), + ) + ) + + edges = [self._normalize_edge(edge) for edge in results.edges] + nodes = [self._normalize_node(node) for node in results.nodes] + + facts = [edge.fact for edge in edges if edge.fact] + if scope == "nodes": + facts.extend(f"[{node.name}]: {node.summary}" for node in nodes if node.summary) + + return GraphSearchResult(facts=facts, edges=edges, nodes=nodes) + + def add_text( + self, + graph_id: str, + data: str, + source_description: str = "MiroFish", + ) -> str | None: + self._ensure_client_ready() + client = self._get_graphiti_client(graph_id) + from graphiti_core.nodes import EpisodeType + + result = _run_async( + self._add_episode( + client=client, + graph_id=graph_id, + name=f"{graph_id}_activity_{uuid.uuid4().hex[:8]}", + episode_body=data, + source_description=source_description, + reference_time=datetime.now(timezone.utc), + source=EpisodeType.text, + ) + ) + self._persist_graph_result(client, result) + episode = getattr(result, 'episode', None) + episode_uuid = getattr(episode, 'uuid', None) if episode else None + return str(episode_uuid) if episode_uuid else None + + def delete_graph(self, graph_id: str) -> None: + self._ensure_client_ready() + from graphiti_core.edges import EntityEdge, EpisodicEdge + from graphiti_core.nodes import EntityNode, EpisodicNode + + graph_driver = self._get_graph_driver(graph_id) + entity_edges = self._fetch_group_records(EntityEdge.get_by_group_ids, graph_id) + episodic_edges = self._fetch_group_records(EpisodicEdge.get_by_group_ids, graph_id) + episodic_nodes = self._fetch_group_records(EpisodicNode.get_by_group_ids, graph_id) + entity_nodes = self._fetch_group_records(EntityNode.get_by_group_ids, graph_id) + + if episodic_edges: + _run_async(EpisodicEdge.delete_by_uuids(graph_driver, [edge.uuid for edge in episodic_edges])) + if entity_edges: + _run_async(EntityEdge.delete_by_uuids(graph_driver, [edge.uuid for edge in entity_edges])) + if episodic_nodes: + _run_async(EpisodicNode.delete_by_uuids(graph_driver, [node.uuid for node in episodic_nodes])) + if entity_nodes: + _run_async(EntityNode.delete_by_uuids(graph_driver, [node.uuid for node in entity_nodes])) + + self._ontology_cache.pop(graph_id, None) + + def _fetch_group_records(self, fetcher, graph_id: str, page_size: int = 100) -> list[Any]: + graph_driver = self._get_graph_driver(graph_id) + graph_namespaces = self._graph_namespaces(graph_id) + records: list[Any] = [] + cursor: str | None = None + + while True: + try: + batch = _run_async( + fetcher( + graph_driver, + graph_namespaces, + limit=page_size, + uuid_cursor=cursor, + ) + ) + except (self._GroupsEdgesNotFoundError, self._GroupsNodesNotFoundError): + break + if not batch: + break + + records.extend(batch) + if len(batch) < page_size: + break + + cursor = getattr(batch[-1], 'uuid', None) or getattr(batch[-1], 'uuid_', None) + if cursor is None: + break + + return records + + def _graph_namespace(self, graph_id: str) -> str: + if not graph_id or not re.fullmatch(r'[A-Za-z0-9_-]+', graph_id): + raise ValueError(f"Invalid graph_id for local Graphiti backend: {graph_id}") + return graph_id + + def _graph_namespaces(self, graph_id: str) -> list[str]: + primary = self._graph_namespace(graph_id) + namespaces = [primary] + legacy = primary.replace('_', '-') + if legacy != primary: + namespaces.append(legacy) + return namespaces + + def _make_graphiti_client(self, graph_driver) -> Any: + return self._Graphiti( + graph_driver=graph_driver, + llm_client=self._OpenAIGenericClient(config=self._llm_config), + embedder=self._OpenAIEmbedder(config=self._embedder_config), + cross_encoder=self._OpenAIRerankerClient(config=self._reranker_config), + max_coroutines=Config.GRAPHITI_MAX_COROUTINES, + ) + + def _get_graphiti_client(self, graph_id: str): + self._graph_namespace(graph_id) + self._ensure_client_ready() + return self.client + + def _get_graph_driver(self, graph_id: str): + return self._get_graphiti_client(graph_id).driver + + async def _add_episode( + self, + client, + graph_id: str, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source, + bundle: _OntologyBundle | None = None, + ): + episode_kwargs = { + "name": name, + "episode_body": episode_body, + "source_description": source_description, + "reference_time": reference_time, + "source": source, + "group_id": self._graph_namespace(graph_id), + "entity_types": bundle.entity_types if bundle else None, + "edge_types": bundle.edge_types if bundle else None, + "edge_type_map": bundle.edge_type_map if bundle else None, + } + + try: + return await client.add_episode(**episode_kwargs) + except Exception as exc: + if not bundle or not self._is_non_primitive_property_error(exc): + raise + + logger.warning( + "Local Graphiti ontology extraction returned non-primitive Neo4j properties for %s; retrying without ontology attributes. Error: %s", + graph_id, + exc, + ) + fallback_kwargs = dict(episode_kwargs) + fallback_kwargs.update( + entity_types=bundle.attribute_free_entity_types, + edge_types=bundle.attribute_free_edge_types, + edge_type_map=bundle.edge_type_map, + ) + return await client.add_episode(**fallback_kwargs) + + def _persist_graph_result(self, client, result: Any) -> None: + for node in getattr(result, 'nodes', []) or []: + node.attributes = self._sanitize_attributes(getattr(node, 'attributes', {}) or {}) + if getattr(node, 'name_embedding', None) is None: + _run_async(node.generate_name_embedding(client.embedder)) + _run_async(node.save(client.driver)) + + for edge in getattr(result, 'edges', []) or []: + edge.attributes = self._sanitize_attributes(getattr(edge, 'attributes', {}) or {}) + if getattr(edge, 'fact_embedding', None) is None: + _run_async(edge.generate_embedding(client.embedder)) + _run_async(edge.save(client.driver)) + + @staticmethod + def _is_non_primitive_property_error(exc: Exception) -> bool: + return 'Property values can only be of primitive types or arrays thereof' in str(exc) + + def _sanitize_attributes(self, attributes: dict[str, Any]) -> dict[str, Any]: + sanitized: dict[str, Any] = {} + for key, value in attributes.items(): + sanitized[key] = self._sanitize_property_value(key, value) + return sanitized + + def _sanitize_property_value(self, key: str, value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + + if isinstance(value, (list, tuple)): + return [ + item + if isinstance(item, (str, int, float, bool)) or item is None + else json.dumps(item, ensure_ascii=False, default=str) + for item in value + ] + + if isinstance(value, dict): + if key in value: + return self._sanitize_property_value(key, value[key]) + if len(value) == 1: + return self._sanitize_property_value(key, next(iter(value.values()))) + return json.dumps(value, ensure_ascii=False, default=str) + + return str(value) + + def _build_ontology_bundle(self, ontology: dict[str, Any]) -> _OntologyBundle: + ontology, _ = normalize_ontology_for_zep(ontology) + reserved_names = { + 'uuid', + 'name', + 'group_id', + 'labels', + 'created_at', + 'summary', + 'attributes', + 'name_embedding', + } + + def safe_attr_name(attr_name: str) -> str: + if attr_name.lower() in reserved_names: + return f"entity_{attr_name}" + return attr_name + + entity_types: dict[str, type[BaseModel]] = {} + for entity_def in ontology.get("entity_types", []): + name = entity_def["name"] + description = entity_def.get("description", f"A {name} entity.") + attrs: dict[str, Any] = {"__doc__": description} + annotations: dict[str, Any] = {} + for attr_def in entity_def.get("attributes", []): + attr_name = safe_attr_name(attr_def["name"]) + attr_desc = attr_def.get("description", attr_name) + attrs[attr_name] = Field(default=None, description=attr_desc) + annotations[attr_name] = Optional[str] + attrs["__annotations__"] = annotations + entity_class = type(name, (BaseModel,), attrs) + entity_class.__doc__ = description + entity_types[name] = entity_class + + edge_types: dict[str, type[BaseModel]] = {} + edge_type_map: dict[tuple[str, str], list[str]] = {} + for edge_def in ontology.get("edge_types", []): + name = edge_def["name"] + description = edge_def.get("description", f"A {name} relationship.") + attrs = {"__doc__": description} + annotations = {} + for attr_def in edge_def.get("attributes", []): + attr_name = safe_attr_name(attr_def["name"]) + attr_desc = attr_def.get("description", attr_name) + attrs[attr_name] = Field(default=None, description=attr_desc) + annotations[attr_name] = Optional[str] + attrs["__annotations__"] = annotations + edge_class = type(name, (BaseModel,), attrs) + edge_class.__doc__ = description + edge_types[name] = edge_class + + source_targets = edge_def.get("source_targets", []) or [{"source": "Entity", "target": "Entity"}] + for source_target in source_targets: + signature = ( + source_target.get("source", "Entity"), + source_target.get("target", "Entity"), + ) + edge_type_map.setdefault(signature, []).append(name) + + return _OntologyBundle( + entity_types=entity_types, + edge_types=edge_types, + edge_type_map=edge_type_map, + attribute_free_entity_types=self._build_attribute_free_models(entity_types), + attribute_free_edge_types=self._build_attribute_free_models(edge_types), + ) + + @staticmethod + def _build_attribute_free_models( + typed_models: dict[str, type[BaseModel]] + ) -> dict[str, type[BaseModel]]: + stripped_models: dict[str, type[BaseModel]] = {} + for model_name, model_type in typed_models.items(): + attrs: dict[str, Any] = { + "__doc__": model_type.__doc__ or f"A {model_name} type.", + "__annotations__": {}, + } + stripped_model = type(model_name, (BaseModel,), attrs) + stripped_model.__doc__ = model_type.__doc__ + stripped_models[model_name] = stripped_model + return stripped_models + + @staticmethod + def _normalize_node(node: Any) -> GraphNodeRecord: + created_at = getattr(node, 'created_at', None) + return GraphNodeRecord( + uuid=str(getattr(node, 'uuid', None) or getattr(node, 'uuid_', None) or ""), + name=getattr(node, 'name', '') or "", + labels=getattr(node, 'labels', []) or [], + summary=getattr(node, 'summary', '') or "", + attributes=getattr(node, 'attributes', {}) or {}, + created_at=str(created_at) if created_at else None, + ) + + @staticmethod + def _normalize_edge(edge: Any) -> GraphEdgeRecord: + episodes = getattr(edge, 'episodes', None) or [] + if not isinstance(episodes, list): + episodes = [str(episodes)] + return GraphEdgeRecord( + uuid=str(getattr(edge, 'uuid', None) or getattr(edge, 'uuid_', None) or ""), + name=getattr(edge, 'name', '') or "", + fact=getattr(edge, 'fact', '') or "", + source_node_uuid=getattr(edge, 'source_node_uuid', '') or "", + target_node_uuid=getattr(edge, 'target_node_uuid', '') or "", + attributes=getattr(edge, 'attributes', {}) or {}, + created_at=str(getattr(edge, 'created_at', None)) if getattr(edge, 'created_at', None) else None, + valid_at=str(getattr(edge, 'valid_at', None)) if getattr(edge, 'valid_at', None) else None, + invalid_at=str(getattr(edge, 'invalid_at', None)) if getattr(edge, 'invalid_at', None) else None, + expired_at=str(getattr(edge, 'expired_at', None)) if getattr(edge, 'expired_at', None) else None, + episodes=[str(episode) for episode in episodes], + ) diff --git a/backend/app/services/graph_provider/models.py b/backend/app/services/graph_provider/models.py new file mode 100644 index 000000000..7532f3ee8 --- /dev/null +++ b/backend/app/services/graph_provider/models.py @@ -0,0 +1,41 @@ +""" +Provider-neutral graph data models. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class GraphNodeRecord: + uuid: str + name: str + labels: List[str] = field(default_factory=list) + summary: str = "" + attributes: Dict[str, Any] = field(default_factory=dict) + created_at: Optional[str] = None + + +@dataclass +class GraphEdgeRecord: + uuid: str + name: str + fact: str + source_node_uuid: str + target_node_uuid: str + attributes: Dict[str, Any] = field(default_factory=dict) + created_at: Optional[str] = None + valid_at: Optional[str] = None + invalid_at: Optional[str] = None + expired_at: Optional[str] = None + episodes: List[str] = field(default_factory=list) + + +@dataclass +class GraphSearchResult: + facts: List[str] = field(default_factory=list) + edges: List[GraphEdgeRecord] = field(default_factory=list) + nodes: List[GraphNodeRecord] = field(default_factory=list) + diff --git a/backend/app/services/graph_provider/zep_cloud_provider.py b/backend/app/services/graph_provider/zep_cloud_provider.py new file mode 100644 index 000000000..15fbde4e1 --- /dev/null +++ b/backend/app/services/graph_provider/zep_cloud_provider.py @@ -0,0 +1,297 @@ +""" +Zep Cloud graph provider implementation. +""" + +from __future__ import annotations + +import time +import uuid +from typing import Any, Optional + +from zep_cloud import EpisodeData, EntityEdgeSourceTarget +from zep_cloud.client import Zep + +from ...config import Config +from ...utils.logger import get_logger +from ...utils.ontology_normalizer import normalize_ontology_for_zep +from ...utils.zep_paging import fetch_all_edges, fetch_all_nodes +from .base import BaseGraphProvider, ProgressCallback +from .models import GraphEdgeRecord, GraphNodeRecord, GraphSearchResult + +logger = get_logger('mirofish.graph_provider.zep_cloud') + + +class ZepCloudGraphProvider(BaseGraphProvider): + """Zep Cloud backed graph provider.""" + + def __init__(self, api_key: Optional[str] = None): + self.api_key = api_key or Config.ZEP_API_KEY + if not self.api_key: + raise ValueError("ZEP_API_KEY is not configured") + + self.client = Zep(api_key=self.api_key) + + def create_graph(self, name: str) -> str: + graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" + self.client.graph.create( + graph_id=graph_id, + name=name, + description="MiroFish Social Simulation Graph", + ) + return graph_id + + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> None: + import warnings + from pydantic import Field + from zep_cloud.external_clients.ontology import EdgeModel, EntityModel, EntityText + + warnings.filterwarnings('ignore', category=UserWarning, module='pydantic') + + ontology, entity_name_mapping = normalize_ontology_for_zep(ontology) + renamed_entities = { + original: normalized + for original, normalized in entity_name_mapping.items() + if original != normalized + } + if renamed_entities: + logger.info("Normalized ontology entity names for Zep compatibility: %s", renamed_entities) + + reserved_names = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'} + + def safe_attr_name(attr_name: str) -> str: + if attr_name.lower() in reserved_names: + return f"entity_{attr_name}" + return attr_name + + entity_types: dict[str, type[EntityModel]] = {} + for entity_def in ontology.get("entity_types", []): + name = entity_def["name"] + description = entity_def.get("description", f"A {name} entity.") + attrs: dict[str, Any] = {"__doc__": description} + annotations: dict[str, Any] = {} + + for attr_def in entity_def.get("attributes", []): + attr_name = safe_attr_name(attr_def["name"]) + attr_desc = attr_def.get("description", attr_name) + attrs[attr_name] = Field(description=attr_desc, default=None) + annotations[attr_name] = Optional[EntityText] + + attrs["__annotations__"] = annotations + entity_class = type(name, (EntityModel,), attrs) + entity_class.__doc__ = description + entity_types[name] = entity_class + + edge_definitions = {} + for edge_def in ontology.get("edge_types", []): + name = edge_def["name"] + description = edge_def.get("description", f"A {name} relationship.") + attrs = {"__doc__": description} + annotations = {} + + for attr_def in edge_def.get("attributes", []): + attr_name = safe_attr_name(attr_def["name"]) + attr_desc = attr_def.get("description", attr_name) + attrs[attr_name] = Field(description=attr_desc, default=None) + annotations[attr_name] = Optional[str] + + attrs["__annotations__"] = annotations + class_name = ''.join(word.capitalize() for word in name.split('_')) + edge_class = type(class_name, (EdgeModel,), attrs) + edge_class.__doc__ = description + + source_targets = [] + for st in edge_def.get("source_targets", []): + source_targets.append( + EntityEdgeSourceTarget( + source=st.get("source", "Entity"), + target=st.get("target", "Entity"), + ) + ) + + if source_targets: + edge_definitions[name] = (edge_class, source_targets) + + if entity_types or edge_definitions: + self.client.graph.set_ontology( + graph_ids=[graph_id], + entities=entity_types if entity_types else None, + edges=edge_definitions if edge_definitions else None, + ) + + def add_text_batches( + self, + graph_id: str, + chunks: list[str], + batch_size: int = 3, + progress_callback: Optional[ProgressCallback] = None, + ) -> list[str]: + episode_uuids: list[str] = [] + total_chunks = len(chunks) + + for i in range(0, total_chunks, batch_size): + batch_chunks = chunks[i:i + batch_size] + batch_num = i // batch_size + 1 + total_batches = (total_chunks + batch_size - 1) // batch_size + + if progress_callback: + progress_callback( + f"Sending batch {batch_num}/{total_batches} ({len(batch_chunks)} chunks)...", + (i + len(batch_chunks)) / total_chunks if total_chunks else 1.0, + ) + + episodes = [EpisodeData(data=chunk, type="text") for chunk in batch_chunks] + batch_result = self.client.graph.add_batch(graph_id=graph_id, episodes=episodes) + + if batch_result and isinstance(batch_result, list): + for episode in batch_result: + episode_uuid = getattr(episode, 'uuid_', None) or getattr(episode, 'uuid', None) + if episode_uuid: + episode_uuids.append(str(episode_uuid)) + + time.sleep(1) + + return episode_uuids + + def wait_for_episodes( + self, + graph_id: str, + episode_uuids: list[str], + progress_callback: Optional[ProgressCallback] = None, + timeout: int = 600, + ) -> None: + if not episode_uuids: + if progress_callback: + progress_callback("No wait needed (no episodes)", 1.0) + return + + start_time = time.time() + pending_episodes = set(episode_uuids) + completed_count = 0 + total_episodes = len(episode_uuids) + + if progress_callback: + progress_callback(f"Waiting for {total_episodes} text chunks to be processed...", 0) + + while pending_episodes: + if time.time() - start_time > timeout: + if progress_callback: + progress_callback( + f"Some text chunks timed out; completed {completed_count}/{total_episodes}", + completed_count / total_episodes if total_episodes else 1.0, + ) + break + + for episode_uuid in list(pending_episodes): + try: + episode = self.client.graph.episode.get(uuid_=episode_uuid) + except Exception: + continue + + if getattr(episode, 'processed', False): + pending_episodes.remove(episode_uuid) + completed_count += 1 + + if progress_callback: + elapsed = int(time.time() - start_time) + progress_callback( + f"Zep processing... {completed_count}/{total_episodes} complete, " + f"{len(pending_episodes)} pending ({elapsed}s)", + completed_count / total_episodes if total_episodes else 1.0, + ) + + if pending_episodes: + time.sleep(3) + + if progress_callback: + progress_callback(f"Processing complete: {completed_count}/{total_episodes}", 1.0) + + def get_all_nodes(self, graph_id: str) -> list[GraphNodeRecord]: + return [self._normalize_node(node) for node in fetch_all_nodes(self.client, graph_id)] + + def get_all_edges(self, graph_id: str) -> list[GraphEdgeRecord]: + return [self._normalize_edge(edge) for edge in fetch_all_edges(self.client, graph_id)] + + def get_node(self, graph_id: str, node_uuid: str) -> GraphNodeRecord | None: + node = self.client.graph.node.get(uuid_=node_uuid) + return self._normalize_node(node) if node else None + + def get_node_edges(self, graph_id: str, node_uuid: str) -> list[GraphEdgeRecord]: + edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid) + return [self._normalize_edge(edge) for edge in edges] + + def search( + self, + graph_id: str, + query: str, + limit: int = 10, + scope: str = "edges", + reranker: str = "cross_encoder", + ) -> GraphSearchResult: + search_results = self.client.graph.search( + graph_id=graph_id, + query=query, + limit=limit, + scope=scope, + reranker=reranker, + ) + + edges = [ + self._normalize_edge(edge) + for edge in getattr(search_results, 'edges', []) or [] + ] + nodes = [ + self._normalize_node(node) + for node in getattr(search_results, 'nodes', []) or [] + ] + + facts = [edge.fact for edge in edges if edge.fact] + if scope == "nodes": + facts.extend(f"[{node.name}]: {node.summary}" for node in nodes if node.summary) + + return GraphSearchResult(facts=facts, edges=edges, nodes=nodes) + + def add_text( + self, + graph_id: str, + data: str, + source_description: str = "MiroFish", + ) -> str | None: + result = self.client.graph.add(graph_id=graph_id, type="text", data=data) + episode_uuid = getattr(result, 'uuid_', None) or getattr(result, 'uuid', None) + return str(episode_uuid) if episode_uuid else None + + def delete_graph(self, graph_id: str) -> None: + self.client.graph.delete(graph_id=graph_id) + + @staticmethod + def _normalize_node(node: Any) -> GraphNodeRecord: + node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or "" + created_at = getattr(node, 'created_at', None) + return GraphNodeRecord( + uuid=str(node_uuid), + name=getattr(node, 'name', '') or "", + labels=getattr(node, 'labels', []) or [], + summary=getattr(node, 'summary', '') or "", + attributes=getattr(node, 'attributes', {}) or {}, + created_at=str(created_at) if created_at else None, + ) + + @staticmethod + def _normalize_edge(edge: Any) -> GraphEdgeRecord: + edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or "" + episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) or [] + if not isinstance(episodes, list): + episodes = [str(episodes)] + return GraphEdgeRecord( + uuid=str(edge_uuid), + name=getattr(edge, 'name', '') or "", + fact=getattr(edge, 'fact', '') or "", + source_node_uuid=getattr(edge, 'source_node_uuid', '') or "", + target_node_uuid=getattr(edge, 'target_node_uuid', '') or "", + attributes=getattr(edge, 'attributes', {}) or {}, + created_at=str(getattr(edge, 'created_at', None)) if getattr(edge, 'created_at', None) else None, + valid_at=str(getattr(edge, 'valid_at', None)) if getattr(edge, 'valid_at', None) else None, + invalid_at=str(getattr(edge, 'invalid_at', None)) if getattr(edge, 'invalid_at', None) else None, + expired_at=str(getattr(edge, 'expired_at', None)) if getattr(edge, 'expired_at', None) else None, + episodes=[str(episode) for episode in episodes], + ) diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 57836c539..57fc46d1b 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -1,25 +1,18 @@ -""" -OASIS Agent Profile生成器 -将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式 - -优化改进: -1. 调用Zep检索功能二次丰富节点信息 -2. 优化提示词生成非常详细的人设 -3. 区分个人实体和抽象群体实体 -""" +"""OASIS agent profile generator.""" import json import random +import re import time from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from datetime import datetime from openai import OpenAI -from zep_cloud.client import Zep from ..config import Config from ..utils.logger import get_logger +from .graph_provider import create_graph_provider from .zep_entity_reader import EntityNode, ZepEntityReader logger = get_logger('mirofish.oasis_profile') @@ -27,23 +20,23 @@ @dataclass class OasisAgentProfile: - """OASIS Agent Profile数据结构""" - # 通用字段 + """OASIS Agent Profile.""" + user_id: int user_name: str name: str bio: str persona: str - # 可选字段 - Reddit风格 + karma: int = 1000 - # 可选字段 - Twitter风格 + friend_count: int = 100 follower_count: int = 150 statuses_count: int = 500 - # 额外人设信息 + age: Optional[int] = None gender: Optional[str] = None mbti: Optional[str] = None @@ -51,17 +44,17 @@ class OasisAgentProfile: profession: Optional[str] = None interested_topics: List[str] = field(default_factory=list) - # 来源实体信息 + source_entity_uuid: Optional[str] = None source_entity_type: Optional[str] = None created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d")) def to_reddit_format(self) -> Dict[str, Any]: - """转换为Reddit平台格式""" + """Convert the object to Reddit Format.""" profile = { "user_id": self.user_id, - "username": self.user_name, # OASIS 库要求字段名为 username(无下划线) + "username": self.user_name, "name": self.name, "bio": self.bio, "persona": self.persona, @@ -69,7 +62,7 @@ def to_reddit_format(self) -> Dict[str, Any]: "created_at": self.created_at, } - # 添加额外人设信息(如果有) + if self.age: profile["age"] = self.age if self.gender: @@ -86,10 +79,10 @@ def to_reddit_format(self) -> Dict[str, Any]: return profile def to_twitter_format(self) -> Dict[str, Any]: - """转换为Twitter平台格式""" + """Convert the object to Twitter Format.""" profile = { "user_id": self.user_id, - "username": self.user_name, # OASIS 库要求字段名为 username(无下划线) + "username": self.user_name, "name": self.name, "bio": self.bio, "persona": self.persona, @@ -99,7 +92,7 @@ def to_twitter_format(self) -> Dict[str, Any]: "created_at": self.created_at, } - # 添加额外人设信息 + if self.age: profile["age"] = self.age if self.gender: @@ -116,7 +109,7 @@ def to_twitter_format(self) -> Dict[str, Any]: return profile def to_dict(self) -> Dict[str, Any]: - """转换为完整字典格式""" + """Convert the object to a dictionary.""" return { "user_id": self.user_id, "user_name": self.user_name, @@ -140,18 +133,9 @@ def to_dict(self) -> Dict[str, Any]: class OasisProfileGenerator: - """ - OASIS Profile生成器 + """OASIS Profile Generator.""" - 将Zep图谱中的实体转换为OASIS模拟所需的Agent Profile - 优化特性: - 1. 调用Zep图谱检索功能获取更丰富的上下文 - 2. 生成非常详细的人设(包括基本信息、职业经历、性格特征、社交媒体行为等) - 3. 区分个人实体和抽象群体实体 - """ - - # MBTI类型列表 MBTI_TYPES = [ "INTJ", "INTP", "ENTJ", "ENTP", "INFJ", "INFP", "ENFJ", "ENFP", @@ -159,19 +143,19 @@ class OasisProfileGenerator: "ISTP", "ISFP", "ESTP", "ESFP" ] - # 常见国家列表 + COUNTRIES = [ "China", "US", "UK", "Japan", "Germany", "France", "Canada", "Australia", "Brazil", "India", "South Korea" ] - # 个人类型实体(需要生成具体人设) + INDIVIDUAL_ENTITY_TYPES = [ "student", "alumni", "professor", "person", "publicfigure", "expert", "faculty", "official", "journalist", "activist" ] - # 群体/机构类型实体(需要生成群体代表人设) + GROUP_ENTITY_TYPES = [ "university", "governmentagency", "organization", "ngo", "mediaoutlet", "company", "institution", "group", "community" @@ -190,23 +174,16 @@ def __init__( self.model_name = model_name or Config.LLM_MODEL_NAME if not self.api_key: - raise ValueError("LLM_API_KEY 未配置") + raise ValueError("LLM_API_KEY is not configured") self.client = OpenAI( api_key=self.api_key, base_url=self.base_url ) - # Zep客户端用于检索丰富上下文 self.zep_api_key = zep_api_key or Config.ZEP_API_KEY - self.zep_client = None + self.graph_provider = create_graph_provider() self.graph_id = graph_id - - if self.zep_api_key: - try: - self.zep_client = Zep(api_key=self.zep_api_key) - except Exception as e: - logger.warning(f"Zep客户端初始化失败: {e}") def generate_profile_from_entity( self, @@ -214,28 +191,18 @@ def generate_profile_from_entity( user_id: int, use_llm: bool = True ) -> OasisAgentProfile: - """ - 从Zep实体生成OASIS Agent Profile - - Args: - entity: Zep实体节点 - user_id: 用户ID(用于OASIS) - use_llm: 是否使用LLM生成详细人设 - - Returns: - OasisAgentProfile - """ + """Generate profile from entity.""" entity_type = entity.get_entity_type() or "Entity" - # 基础信息 + name = entity.name user_name = self._generate_username(name) - # 构建上下文信息 + context = self._build_entity_context(entity) if use_llm: - # 使用LLM生成详细人设 + profile_data = self._generate_profile_with_llm( entity_name=name, entity_type=entity_type, @@ -244,13 +211,15 @@ def generate_profile_from_entity( context=context ) else: - # 使用规则生成基础人设 + profile_data = self._generate_profile_rule_based( entity_name=name, entity_type=entity_type, entity_summary=entity.summary, entity_attributes=entity.attributes ) + + profile_data = self._ensure_profile_english(profile_data) return OasisAgentProfile( user_id=user_id, @@ -273,33 +242,19 @@ def generate_profile_from_entity( ) def _generate_username(self, name: str) -> str: - """生成用户名""" - # 移除特殊字符,转换为小写 + """Generate username.""" + username = name.lower().replace(" ", "_") username = ''.join(c for c in username if c.isalnum() or c == '_') - # 添加随机后缀避免重复 + suffix = random.randint(100, 999) return f"{username}_{suffix}" def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]: - """ - 使用Zep图谱混合搜索功能获取实体相关的丰富信息 - - Zep没有内置混合搜索接口,需要分别搜索edges和nodes然后合并结果。 - 使用并行请求同时搜索,提高效率。 - - Args: - entity: 实体节点对象 - - Returns: - 包含facts, node_summaries, context的字典 - """ + """Search zep for entity.""" import concurrent.futures - if not self.zep_client: - return {"facts": [], "node_summaries": [], "context": ""} - entity_name = entity.name results = { @@ -308,22 +263,25 @@ def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]: "context": "" } - # 必须有graph_id才能进行搜索 + if not self.graph_id: - logger.debug(f"跳过Zep检索:未设置graph_id") + logger.debug("Skipping Zep retrieval: graph_id is not set") return results - comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景" + comprehensive_query = ( + f"All available information, activities, events, relationships, " + f"and background about {entity_name}" + ) def search_edges(): - """搜索边(事实/关系)- 带重试机制""" + """Search edges.""" max_retries = 3 last_exception = None delay = 2.0 for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.graph_provider.search( query=comprehensive_query, graph_id=self.graph_id, limit=30, @@ -333,22 +291,22 @@ def search_edges(): except Exception as e: last_exception = e if attempt < max_retries - 1: - logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") + logger.debug(f"Zep edge search failed on attempt {attempt + 1}: {str(e)[:80]}; retrying...") time.sleep(delay) delay *= 2 else: - logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}") + logger.debug(f"Zep edge search still failed after {max_retries} attempts: {e}") return None def search_nodes(): - """搜索节点(实体摘要)- 带重试机制""" + """Search nodes.""" max_retries = 3 last_exception = None delay = 2.0 for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.graph_provider.search( query=comprehensive_query, graph_id=self.graph_id, limit=20, @@ -358,83 +316,79 @@ def search_nodes(): except Exception as e: last_exception = e if attempt < max_retries - 1: - logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...") + logger.debug(f"Zep node search failed on attempt {attempt + 1}: {str(e)[:80]}; retrying...") time.sleep(delay) delay *= 2 else: - logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}") + logger.debug(f"Zep node search still failed after {max_retries} attempts: {e}") return None try: - # 并行执行edges和nodes搜索 + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: edge_future = executor.submit(search_edges) node_future = executor.submit(search_nodes) - # 获取结果 + edge_result = edge_future.result(timeout=30) node_result = node_future.result(timeout=30) - # 处理边搜索结果 + all_facts = set() - if edge_result and hasattr(edge_result, 'edges') and edge_result.edges: + if edge_result and edge_result.edges: for edge in edge_result.edges: - if hasattr(edge, 'fact') and edge.fact: + if edge.fact: all_facts.add(edge.fact) results["facts"] = list(all_facts) - # 处理节点搜索结果 + all_summaries = set() - if node_result and hasattr(node_result, 'nodes') and node_result.nodes: + if node_result and node_result.nodes: for node in node_result.nodes: - if hasattr(node, 'summary') and node.summary: + if node.summary: all_summaries.add(node.summary) - if hasattr(node, 'name') and node.name and node.name != entity_name: - all_summaries.add(f"相关实体: {node.name}") + if node.name and node.name != entity_name: + all_summaries.add(f"Related entity: {node.name}") results["node_summaries"] = list(all_summaries) - # 构建综合上下文 + context_parts = [] if results["facts"]: - context_parts.append("事实信息:\n" + "\n".join(f"- {f}" for f in results["facts"][:20])) + context_parts.append("Facts:\n" + "\n".join(f"- {f}" for f in results["facts"][:20])) if results["node_summaries"]: - context_parts.append("相关实体:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10])) + context_parts.append("Related entities:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10])) results["context"] = "\n\n".join(context_parts) - logger.info(f"Zep混合检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, {len(results['node_summaries'])} 个相关节点") + logger.info( + f"Zep hybrid retrieval completed: {entity_name}, " + f"retrieved {len(results['facts'])} facts and {len(results['node_summaries'])} related nodes" + ) except concurrent.futures.TimeoutError: - logger.warning(f"Zep检索超时 ({entity_name})") + logger.warning(f"Zep retrieval timed out ({entity_name})") except Exception as e: - logger.warning(f"Zep检索失败 ({entity_name}): {e}") + logger.warning(f"Zep retrieval failed ({entity_name}): {e}") return results def _build_entity_context(self, entity: EntityNode) -> str: - """ - 构建实体的完整上下文信息 - - 包括: - 1. 实体本身的边信息(事实) - 2. 关联节点的详细信息 - 3. Zep混合检索到的丰富信息 - """ + """Build entity context.""" context_parts = [] - # 1. 添加实体属性信息 + if entity.attributes: attrs = [] for key, value in entity.attributes.items(): if value and str(value).strip(): attrs.append(f"- {key}: {value}") if attrs: - context_parts.append("### 实体属性\n" + "\n".join(attrs)) + context_parts.append("### Entity attributes\n" + "\n".join(attrs)) + - # 2. 添加相关边信息(事实/关系) existing_facts = set() if entity.related_edges: relationships = [] - for edge in entity.related_edges: # 不限制数量 + for edge in entity.related_edges: fact = edge.get("fact", "") edge_name = edge.get("edge_name", "") direction = edge.get("direction", "") @@ -444,22 +398,22 @@ def _build_entity_context(self, entity: EntityNode) -> str: existing_facts.add(fact) elif edge_name: if direction == "outgoing": - relationships.append(f"- {entity.name} --[{edge_name}]--> (相关实体)") + relationships.append(f"- {entity.name} --[{edge_name}]--> (related entity)") else: - relationships.append(f"- (相关实体) --[{edge_name}]--> {entity.name}") + relationships.append(f"- (related entity) --[{edge_name}]--> {entity.name}") if relationships: - context_parts.append("### 相关事实和关系\n" + "\n".join(relationships)) + context_parts.append("### Related facts and relationships\n" + "\n".join(relationships)) + - # 3. 添加关联节点的详细信息 if entity.related_nodes: related_info = [] - for node in entity.related_nodes: # 不限制数量 + for node in entity.related_nodes: node_name = node.get("name", "") node_labels = node.get("labels", []) node_summary = node.get("summary", "") - # 过滤掉默认标签 + custom_labels = [l for l in node_labels if l not in ["Entity", "Node"]] label_str = f" ({', '.join(custom_labels)})" if custom_labels else "" @@ -469,28 +423,28 @@ def _build_entity_context(self, entity: EntityNode) -> str: related_info.append(f"- **{node_name}**{label_str}") if related_info: - context_parts.append("### 关联实体信息\n" + "\n".join(related_info)) + context_parts.append("### Related entity details\n" + "\n".join(related_info)) + - # 4. 使用Zep混合检索获取更丰富的信息 zep_results = self._search_zep_for_entity(entity) if zep_results.get("facts"): - # 去重:排除已存在的事实 + new_facts = [f for f in zep_results["facts"] if f not in existing_facts] if new_facts: - context_parts.append("### Zep检索到的事实信息\n" + "\n".join(f"- {f}" for f in new_facts[:15])) + context_parts.append("### Facts retrieved from Zep\n" + "\n".join(f"- {f}" for f in new_facts[:15])) if zep_results.get("node_summaries"): - context_parts.append("### Zep检索到的相关节点\n" + "\n".join(f"- {s}" for s in zep_results["node_summaries"][:10])) + context_parts.append("### Related nodes retrieved from Zep\n" + "\n".join(f"- {s}" for s in zep_results["node_summaries"][:10])) return "\n\n".join(context_parts) def _is_individual_entity(self, entity_type: str) -> bool: - """判断是否是个人类型实体""" + """Return whether individual entity.""" return entity_type.lower() in self.INDIVIDUAL_ENTITY_TYPES def _is_group_entity(self, entity_type: str) -> bool: - """判断是否是群体/机构类型实体""" + """Return whether group entity.""" return entity_type.lower() in self.GROUP_ENTITY_TYPES def _generate_profile_with_llm( @@ -501,13 +455,7 @@ def _generate_profile_with_llm( entity_attributes: Dict[str, Any], context: str ) -> Dict[str, Any]: - """ - 使用LLM生成非常详细的人设 - - 根据实体类型区分: - - 个人实体:生成具体的人物设定 - - 群体/机构实体:生成代表性账号设定 - """ + """Generate profile with llm.""" is_individual = self._is_individual_entity(entity_type) @@ -520,7 +468,7 @@ def _generate_profile_with_llm( entity_name, entity_type, entity_summary, entity_attributes, context ) - # 尝试多次生成,直到成功或达到最大重试次数 + max_attempts = 3 last_error = None @@ -533,34 +481,34 @@ def _generate_profile_with_llm( {"role": "user", "content": prompt} ], response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 + temperature=0.7 - (attempt * 0.1) + ) content = response.choices[0].message.content - # 检查是否被截断(finish_reason不是'stop') + finish_reason = response.choices[0].finish_reason if finish_reason == 'length': - logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...") + logger.warning(f"LLM output was truncated (attempt {attempt+1}); attempting recovery...") content = self._fix_truncated_json(content) - # 尝试解析JSON + try: result = json.loads(content) - # 验证必需字段 + if "bio" not in result or not result["bio"]: result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}" if "persona" not in result or not result["persona"]: - result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。" + result["persona"] = entity_summary or f"{entity_name} is a {entity_type}." return result except json.JSONDecodeError as je: - logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}") + logger.warning(f"JSON parsing failed (attempt {attempt+1}): {str(je)[:80]}") + - # 尝试修复JSON result = self._try_fix_json(content, entity_name, entity_type, entity_summary) if result.get("_fixed"): del result["_fixed"] @@ -569,75 +517,78 @@ def _generate_profile_with_llm( last_error = je except Exception as e: - logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") + logger.warning(f"LLM call failed (attempt {attempt+1}): {str(e)[:80]}") last_error = e import time - time.sleep(1 * (attempt + 1)) # 指数退避 + time.sleep(1 * (attempt + 1)) - logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成") + logger.warning( + f"LLM persona generation failed after {max_attempts} attempts: " + f"{last_error}. Falling back to rule-based generation" + ) return self._generate_profile_rule_based( entity_name, entity_type, entity_summary, entity_attributes ) def _fix_truncated_json(self, content: str) -> str: - """修复被截断的JSON(输出被max_tokens限制截断)""" + """Fix truncated json.""" import re - # 如果JSON被截断,尝试闭合它 + content = content.strip() - # 计算未闭合的括号 + open_braces = content.count('{') - content.count('}') open_brackets = content.count('[') - content.count(']') - # 检查是否有未闭合的字符串 - # 简单检查:如果最后一个引号后没有逗号或闭合括号,可能是字符串被截断 + + if content and content[-1] not in '",}]': - # 尝试闭合字符串 + content += '"' - # 闭合括号 + content += ']' * open_brackets content += '}' * open_braces return content def _try_fix_json(self, content: str, entity_name: str, entity_type: str, entity_summary: str = "") -> Dict[str, Any]: - """尝试修复损坏的JSON""" + """Try fix json.""" import re - # 1. 首先尝试修复被截断的情况 + content = self._fix_truncated_json(content) - # 2. 尝试提取JSON部分 + json_match = re.search(r'\{[\s\S]*\}', content) if json_match: json_str = json_match.group() - # 3. 处理字符串中的换行符问题 - # 找到所有字符串值并替换其中的换行符 + + def fix_string_newlines(match): s = match.group(0) - # 替换字符串内的实际换行符为空格 + s = s.replace('\n', ' ').replace('\r', ' ') - # 替换多余空格 + s = re.sub(r'\s+', ' ', s) return s - # 匹配JSON字符串值 + json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str) - # 4. 尝试解析 + try: result = json.loads(json_str) result["_fixed"] = True return result except json.JSONDecodeError as e: - # 5. 如果还是失败,尝试更激进的修复 + try: - # 移除所有控制字符 + json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) - # 替换所有连续空白 + json_str = re.sub(r'\s+', ' ', json_str) result = json.loads(json_str) result["_fixed"] = True @@ -645,33 +596,88 @@ def fix_string_newlines(match): except: pass - # 6. 尝试从内容中提取部分信息 + bio_match = re.search(r'"bio"\s*:\s*"([^"]*)"', content) - persona_match = re.search(r'"persona"\s*:\s*"([^"]*)', content) # 可能被截断 + persona_match = re.search(r'"persona"\s*:\s*"([^"]*)', content) bio = bio_match.group(1) if bio_match else (entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}") - persona = persona_match.group(1) if persona_match else (entity_summary or f"{entity_name}是一个{entity_type}。") + persona = persona_match.group(1) if persona_match else (entity_summary or f"{entity_name} is a {entity_type}.") + - # 如果提取到了有意义的内容,标记为已修复 if bio_match or persona_match: - logger.info(f"从损坏的JSON中提取了部分信息") + logger.info("Recovered partial information from malformed JSON") return { "bio": bio, "persona": persona, "_fixed": True } - # 7. 完全失败,返回基础结构 - logger.warning(f"JSON修复失败,返回基础结构") + + logger.warning("JSON repair failed; returning base structure") return { "bio": entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}", - "persona": entity_summary or f"{entity_name}是一个{entity_type}。" + "persona": entity_summary or f"{entity_name} is a {entity_type}." } + + @staticmethod + def _contains_cjk(value: Any) -> bool: + if isinstance(value, str): + return bool(re.search(r'[\u4e00-\u9fff]', value)) + if isinstance(value, list): + return any(OasisProfileGenerator._contains_cjk(item) for item in value) + return False + + def _ensure_profile_english(self, profile_data: Dict[str, Any]) -> Dict[str, Any]: + fields_to_check = ["bio", "persona", "country", "profession", "interested_topics"] + if not any(self._contains_cjk(profile_data.get(field)) for field in fields_to_check): + return profile_data + + payload = { + "bio": profile_data.get("bio", ""), + "persona": profile_data.get("persona", ""), + "country": profile_data.get("country", ""), + "profession": profile_data.get("profession", ""), + "interested_topics": profile_data.get("interested_topics", []), + } + + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=[ + { + "role": "system", + "content": ( + "Translate the provided social-media profile fields into natural English. " + "Return valid JSON only. Preserve the original meaning, tone, and structure. " + "Keep list fields as arrays, keep empty values empty, and render country names in English." + ), + }, + { + "role": "user", + "content": json.dumps(payload, ensure_ascii=False), + }, + ], + response_format={"type": "json_object"}, + temperature=0.1, + ) + + translated = json.loads(response.choices[0].message.content) + for field in fields_to_check: + if field in translated and translated[field] not in (None, ""): + profile_data[field] = translated[field] + except Exception as e: + logger.warning(f"Failed to normalize generated profile text to English: {e}") + + return profile_data def _get_system_prompt(self, is_individual: bool) -> str: - """获取系统提示词""" - base_prompt = "你是社交媒体用户画像生成专家。生成详细、真实的人设用于舆论模拟,最大程度还原已有现实情况。必须返回有效的JSON格式,所有字符串值不能包含未转义的换行符。使用中文。" - return base_prompt + """Get system prompt.""" + return ( + "You are an expert at generating realistic social-media personas for simulation. " + "Return valid JSON only. All free-text fields must be written in English, " + "all string values must be single-line strings without unescaped newlines, " + "and the output should stay faithful to the provided entity context." + ) def _build_individual_persona_prompt( self, @@ -681,45 +687,44 @@ def _build_individual_persona_prompt( entity_attributes: Dict[str, Any], context: str ) -> str: - """构建个人实体的详细人设提示词""" + """Build individual persona prompt.""" - attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "无" - context_str = context[:3000] if context else "无额外上下文" + attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None" + context_str = context[:3000] if context else "No additional context" - return f"""为实体生成详细的社交媒体用户人设,最大程度还原已有现实情况。 + return f"""Generate a detailed social-media persona for the following entity and keep it as faithful as possible to the source context. -实体名称: {entity_name} -实体类型: {entity_type} -实体摘要: {entity_summary} -实体属性: {attrs_str} +Entity name: {entity_name} +Entity type: {entity_type} +Entity summary: {entity_summary} +Entity attributes: {attrs_str} -上下文信息: +Context: {context_str} -请生成JSON,包含以下字段: - -1. bio: 社交媒体简介,200字 -2. persona: 详细人设描述(2000字的纯文本),需包含: - - 基本信息(年龄、职业、教育背景、所在地) - - 人物背景(重要经历、与事件的关联、社会关系) - - 性格特征(MBTI类型、核心性格、情绪表达方式) - - 社交媒体行为(发帖频率、内容偏好、互动风格、语言特点) - - 立场观点(对话题的态度、可能被激怒/感动的内容) - - 独特特征(口头禅、特殊经历、个人爱好) - - 个人记忆(人设的重要部分,要介绍这个个体与事件的关联,以及这个个体在事件中的已有动作与反应) -3. age: 年龄数字(必须是整数) -4. gender: 性别,必须是英文: "male" 或 "female" -5. mbti: MBTI类型(如INTJ、ENFP等) -6. country: 国家(使用中文,如"中国") -7. profession: 职业 -8. interested_topics: 感兴趣话题数组 +Return JSON with these fields: +1. bio: an English social-media bio, concise and natural +2. persona: a detailed English plain-text persona profile that covers: + - basic information (age, profession, education, location) + - background and relationship to the event + - personality traits, emotional style, and MBTI + - social-media behavior (posting habits, content preferences, interaction style, language style) + - stance and likely reactions to relevant topics + - distinctive traits, habits, or interests + - personal memory relevant to the event and prior reactions/actions +3. age: an integer +4. gender: "male" or "female" +5. mbti: an MBTI type such as INTJ or ENFP +6. country: country name in English, for example "China" +7. profession: profession in English +8. interested_topics: an array of topics in English -重要: -- 所有字段值必须是字符串或数字,不要使用换行符 -- persona必须是一段连贯的文字描述 -- 使用中文(除了gender字段必须用英文male/female) -- 内容要与实体信息保持一致 -- age必须是有效的整数,gender必须是"male"或"female" +Requirements: +- All free-text fields must be written in English +- All values must be strings, numbers, or arrays, with no unescaped newlines inside strings +- persona must be a single coherent paragraph +- Keep the content aligned with the provided entity information +- age must be a valid integer and gender must be either "male" or "female" """ def _build_group_persona_prompt( @@ -730,45 +735,44 @@ def _build_group_persona_prompt( entity_attributes: Dict[str, Any], context: str ) -> str: - """构建群体/机构实体的详细人设提示词""" + """Build group persona prompt.""" - attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "无" - context_str = context[:3000] if context else "无额外上下文" + attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None" + context_str = context[:3000] if context else "No additional context" - return f"""为机构/群体实体生成详细的社交媒体账号设定,最大程度还原已有现实情况。 + return f"""Generate a detailed social-media account profile for an organization or group entity and keep it faithful to the available context. -实体名称: {entity_name} -实体类型: {entity_type} -实体摘要: {entity_summary} -实体属性: {attrs_str} +Entity name: {entity_name} +Entity type: {entity_type} +Entity summary: {entity_summary} +Entity attributes: {attrs_str} -上下文信息: +Context: {context_str} -请生成JSON,包含以下字段: +Return JSON with these fields: +1. bio: a professional English account bio +2. persona: a detailed English plain-text account profile that covers: + - core organization information and mission + - account positioning, audience, and purpose + - communication style, common phrasing, and sensitive topics + - publishing habits, active windows, and content patterns + - institutional stance on core issues and conflict handling style + - operational habits or representative group traits + - institutional memory relevant to the event and prior reactions/actions +3. age: fixed integer 30 +4. gender: fixed string "other" +5. mbti: an MBTI type that reflects the account style +6. country: country name in English, for example "China" +7. profession: organization role or function in English +8. interested_topics: an array of focus areas in English -1. bio: 官方账号简介,200字,专业得体 -2. persona: 详细账号设定描述(2000字的纯文本),需包含: - - 机构基本信息(正式名称、机构性质、成立背景、主要职能) - - 账号定位(账号类型、目标受众、核心功能) - - 发言风格(语言特点、常用表达、禁忌话题) - - 发布内容特点(内容类型、发布频率、活跃时间段) - - 立场态度(对核心话题的官方立场、面对争议的处理方式) - - 特殊说明(代表的群体画像、运营习惯) - - 机构记忆(机构人设的重要部分,要介绍这个机构与事件的关联,以及这个机构在事件中的已有动作与反应) -3. age: 固定填30(机构账号的虚拟年龄) -4. gender: 固定填"other"(机构账号使用other表示非个人) -5. mbti: MBTI类型,用于描述账号风格,如ISTJ代表严谨保守 -6. country: 国家(使用中文,如"中国") -7. profession: 机构职能描述 -8. interested_topics: 关注领域数组 - -重要: -- 所有字段值必须是字符串或数字,不允许null值 -- persona必须是一段连贯的文字描述,不要使用换行符 -- 使用中文(除了gender字段必须用英文"other") -- age必须是整数30,gender必须是字符串"other" -- 机构账号发言要符合其身份定位""" +Requirements: +- All free-text fields must be written in English +- No null values and no unescaped newlines inside strings +- persona must be a single coherent paragraph +- age must be 30 and gender must be "other" +- The account voice must match the institution's identity and role""" def _generate_profile_rule_based( self, @@ -777,9 +781,9 @@ def _generate_profile_rule_based( entity_summary: str, entity_attributes: Dict[str, Any] ) -> Dict[str, Any]: - """使用规则生成基础人设""" + """Generate profile rule based.""" + - # 根据实体类型生成不同的人设 entity_type_lower = entity_type.lower() if entity_type_lower in ["student", "alumni"]: @@ -810,10 +814,10 @@ def _generate_profile_rule_based( return { "bio": f"Official account for {entity_name}. News and updates.", "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.", - "age": 30, # 机构虚拟年龄 - "gender": "other", # 机构使用other - "mbti": "ISTJ", # 机构风格:严谨保守 - "country": "中国", + "age": 30, + "gender": "other", + "mbti": "ISTJ", + "country": "China", "profession": "Media", "interested_topics": ["General News", "Current Events", "Public Affairs"], } @@ -822,16 +826,16 @@ def _generate_profile_rule_based( return { "bio": f"Official account of {entity_name}.", "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.", - "age": 30, # 机构虚拟年龄 - "gender": "other", # 机构使用other - "mbti": "ISTJ", # 机构风格:严谨保守 - "country": "中国", + "age": 30, + "gender": "other", + "mbti": "ISTJ", + "country": "China", "profession": entity_type, "interested_topics": ["Public Policy", "Community", "Official Announcements"], } else: - # 默认人设 + return { "bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}", "persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.", @@ -844,7 +848,7 @@ def _generate_profile_rule_based( } def set_graph_id(self, graph_id: str): - """设置图谱ID用于Zep检索""" + """Set graph id.""" self.graph_id = graph_id def generate_profiles_from_entities( @@ -857,53 +861,39 @@ def generate_profiles_from_entities( realtime_output_path: Optional[str] = None, output_platform: str = "reddit" ) -> List[OasisAgentProfile]: - """ - 批量从实体生成Agent Profile(支持并行生成) - - Args: - entities: 实体列表 - use_llm: 是否使用LLM生成详细人设 - progress_callback: 进度回调函数 (current, total, message) - graph_id: 图谱ID,用于Zep检索获取更丰富上下文 - parallel_count: 并行生成数量,默认5 - realtime_output_path: 实时写入的文件路径(如果提供,每生成一个就写入一次) - output_platform: 输出平台格式 ("reddit" 或 "twitter") - - Returns: - Agent Profile列表 - """ + """Generate profiles from entities.""" import concurrent.futures from threading import Lock - # 设置graph_id用于Zep检索 + if graph_id: self.graph_id = graph_id total = len(entities) - profiles = [None] * total # 预分配列表保持顺序 - completed_count = [0] # 使用列表以便在闭包中修改 + profiles = [None] * total + completed_count = [0] lock = Lock() - # 实时写入文件的辅助函数 + def save_profiles_realtime(): - """实时保存已生成的 profiles 到文件""" + """Save profiles realtime.""" if not realtime_output_path: return with lock: - # 过滤出已生成的 profiles + existing_profiles = [p for p in profiles if p is not None] if not existing_profiles: return try: if output_platform == "reddit": - # Reddit JSON 格式 + profiles_data = [p.to_reddit_format() for p in existing_profiles] with open(realtime_output_path, 'w', encoding='utf-8') as f: json.dump(profiles_data, f, ensure_ascii=False, indent=2) else: - # Twitter CSV 格式 + import csv profiles_data = [p.to_twitter_format() for p in existing_profiles] if profiles_data: @@ -913,10 +903,10 @@ def save_profiles_realtime(): writer.writeheader() writer.writerows(profiles_data) except Exception as e: - logger.warning(f"实时保存 profiles 失败: {e}") + logger.warning(f"Realtime profile save failed: {e}") def generate_single_profile(idx: int, entity: EntityNode) -> tuple: - """生成单个profile的工作函数""" + """Generate single profile.""" entity_type = entity.get_entity_type() or "Entity" try: @@ -926,14 +916,14 @@ def generate_single_profile(idx: int, entity: EntityNode) -> tuple: use_llm=use_llm ) - # 实时输出生成的人设到控制台和日志 + self._print_generated_profile(entity.name, entity_type, profile) return idx, profile, None except Exception as e: - logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}") - # 创建一个基础profile + logger.error(f"Failed to generate persona for entity {entity.name}: {str(e)}") + fallback_profile = OasisAgentProfile( user_id=idx, user_name=self._generate_username(entity.name), @@ -945,20 +935,20 @@ def generate_single_profile(idx: int, entity: EntityNode) -> tuple: ) return idx, fallback_profile, str(e) - logger.info(f"开始并行生成 {total} 个Agent人设(并行数: {parallel_count})...") + logger.info(f"Starting parallel generation of {total} agent personas (parallelism: {parallel_count})...") print(f"\n{'='*60}") - print(f"开始生成Agent人设 - 共 {total} 个实体,并行数: {parallel_count}") + print(f"Starting agent persona generation - {total} entities total, parallelism: {parallel_count}") print(f"{'='*60}\n") - # 使用线程池并行执行 + with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_count) as executor: - # 提交所有任务 + future_to_entity = { executor.submit(generate_single_profile, idx, entity): (idx, entity) for idx, entity in enumerate(entities) } - # 收集结果 + for future in concurrent.futures.as_completed(future_to_entity): idx, entity = future_to_entity[future] entity_type = entity.get_entity_type() or "Entity" @@ -971,23 +961,23 @@ def generate_single_profile(idx: int, entity: EntityNode) -> tuple: completed_count[0] += 1 current = completed_count[0] - # 实时写入文件 + save_profiles_realtime() if progress_callback: progress_callback( current, total, - f"已完成 {current}/{total}: {entity.name}({entity_type})" + f"Completed {current}/{total}: {entity.name} ({entity_type})" ) if error: - logger.warning(f"[{current}/{total}] {entity.name} 使用备用人设: {error}") + logger.warning(f"[{current}/{total}] {entity.name} used fallback persona: {error}") else: - logger.info(f"[{current}/{total}] 成功生成人设: {entity.name} ({entity_type})") + logger.info(f"[{current}/{total}] Successfully generated persona: {entity.name} ({entity_type})") except Exception as e: - logger.error(f"处理实体 {entity.name} 时发生异常: {str(e)}") + logger.error(f"Unexpected error while processing entity {entity.name}: {str(e)}") with lock: completed_count[0] += 1 profiles[idx] = OasisAgentProfile( @@ -999,44 +989,44 @@ def generate_single_profile(idx: int, entity: EntityNode) -> tuple: source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) - # 实时写入文件(即使是备用人设) + save_profiles_realtime() print(f"\n{'='*60}") - print(f"人设生成完成!共生成 {len([p for p in profiles if p])} 个Agent") + print(f"Persona generation complete. Generated {len([p for p in profiles if p])} agents in total") print(f"{'='*60}\n") return profiles def _print_generated_profile(self, entity_name: str, entity_type: str, profile: OasisAgentProfile): - """实时输出生成的人设到控制台(完整内容,不截断)""" + """Print Generated Profile.""" separator = "-" * 70 - # 构建完整输出内容(不截断) - topics_str = ', '.join(profile.interested_topics) if profile.interested_topics else '无' + + topics_str = ', '.join(profile.interested_topics) if profile.interested_topics else 'None' output_lines = [ f"\n{separator}", - f"[已生成] {entity_name} ({entity_type})", + f"[Generated] {entity_name} ({entity_type})", f"{separator}", - f"用户名: {profile.user_name}", + f"Username: {profile.user_name}", f"", - f"【简介】", + f"[Bio]", f"{profile.bio}", f"", - f"【详细人设】", + f"[Detailed Persona]", f"{profile.persona}", f"", - f"【基本属性】", - f"年龄: {profile.age} | 性别: {profile.gender} | MBTI: {profile.mbti}", - f"职业: {profile.profession} | 国家: {profile.country}", - f"兴趣话题: {topics_str}", + f"[Basic Attributes]", + f"Age: {profile.age} | Gender: {profile.gender} | MBTI: {profile.mbti}", + f"Profession: {profile.profession} | Country: {profile.country}", + f"Interested topics: {topics_str}", separator ] output = "\n".join(output_lines) - # 只输出到控制台(避免重复,logger不再输出完整内容) + print(output) def save_profiles( @@ -1045,92 +1035,64 @@ def save_profiles( file_path: str, platform: str = "reddit" ): - """ - 保存Profile到文件(根据平台选择正确格式) - - OASIS平台格式要求: - - Twitter: CSV格式 - - Reddit: JSON格式 - - Args: - profiles: Profile列表 - file_path: 文件路径 - platform: 平台类型 ("reddit" 或 "twitter") - """ + """Save profiles.""" if platform == "twitter": self._save_twitter_csv(profiles, file_path) else: self._save_reddit_json(profiles, file_path) def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str): - """ - 保存Twitter Profile为CSV格式(符合OASIS官方要求) - - OASIS Twitter要求的CSV字段: - - user_id: 用户ID(根据CSV顺序从0开始) - - name: 用户真实姓名 - - username: 系统中的用户名 - - user_char: 详细人设描述(注入到LLM系统提示中,指导Agent行为) - - description: 简短的公开简介(显示在用户资料页面) - - user_char vs description 区别: - - user_char: 内部使用,LLM系统提示,决定Agent如何思考和行动 - - description: 外部显示,其他用户可见的简介 - """ + """Save twitter csv.""" import csv - # 确保文件扩展名是.csv + if not file_path.endswith('.csv'): file_path = file_path.replace('.json', '.csv') with open(file_path, 'w', newline='', encoding='utf-8') as f: writer = csv.writer(f) - # 写入OASIS要求的表头 + headers = ['user_id', 'name', 'username', 'user_char', 'description'] writer.writerow(headers) - # 写入数据行 + for idx, profile in enumerate(profiles): - # user_char: 完整人设(bio + persona),用于LLM系统提示 + user_char = profile.bio if profile.persona and profile.persona != profile.bio: user_char = f"{profile.bio} {profile.persona}" - # 处理换行符(CSV中用空格替代) + user_char = user_char.replace('\n', ' ').replace('\r', ' ') - # description: 简短简介,用于外部显示 + description = profile.bio.replace('\n', ' ').replace('\r', ' ') row = [ - idx, # user_id: 从0开始的顺序ID - profile.name, # name: 真实姓名 - profile.user_name, # username: 用户名 - user_char, # user_char: 完整人设(内部LLM使用) - description # description: 简短简介(外部显示) + idx, + profile.name, + profile.user_name, + user_char, + description ] writer.writerow(row) - logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)") + logger.info(f"Saved {len(profiles)} Twitter profiles to {file_path} (OASIS CSV format)") def _normalize_gender(self, gender: Optional[str]) -> str: - """ - 标准化gender字段为OASIS要求的英文格式 - - OASIS要求: male, female, other - """ + """Normalize Gender.""" if not gender: return "other" gender_lower = gender.lower().strip() - # 中文映射 + gender_map = { - "男": "male", - "女": "female", - "机构": "other", - "其他": "other", - # 英文已有 + "\u7537": "male", + "\u5973": "female", + "\u673a\u6784": "other", + "\u5176\u4ed6": "other", + "male": "male", "female": "female", "other": "other", @@ -1139,42 +1101,26 @@ def _normalize_gender(self, gender: Optional[str]) -> str: return gender_map.get(gender_lower, "other") def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str): - """ - 保存Reddit Profile为JSON格式 - - 使用与 to_reddit_format() 一致的格式,确保 OASIS 能正确读取。 - 必须包含 user_id 字段,这是 OASIS agent_graph.get_agent() 匹配的关键! - - 必需字段: - - user_id: 用户ID(整数,用于匹配 initial_posts 中的 poster_agent_id) - - username: 用户名 - - name: 显示名称 - - bio: 简介 - - persona: 详细人设 - - age: 年龄(整数) - - gender: "male", "female", 或 "other" - - mbti: MBTI类型 - - country: 国家 - """ + """Save reddit json.""" data = [] for idx, profile in enumerate(profiles): - # 使用与 to_reddit_format() 一致的格式 + item = { - "user_id": profile.user_id if profile.user_id is not None else idx, # 关键:必须包含 user_id + "user_id": profile.user_id if profile.user_id is not None else idx, "username": profile.user_name, "name": profile.name, "bio": profile.bio[:150] if profile.bio else f"{profile.name}", "persona": profile.persona or f"{profile.name} is a participant in social discussions.", "karma": profile.karma if profile.karma else 1000, "created_at": profile.created_at, - # OASIS必需字段 - 确保都有默认值 + "age": profile.age if profile.age else 30, "gender": self._normalize_gender(profile.gender), "mbti": profile.mbti if profile.mbti else "ISTJ", - "country": profile.country if profile.country else "中国", + "country": profile.country if profile.country else "China", } - # 可选字段 + if profile.profession: item["profession"] = profile.profession if profile.interested_topics: @@ -1185,16 +1131,15 @@ def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str): with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) - logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON格式,包含user_id字段)") + logger.info(f"Saved {len(profiles)} Reddit profiles to {file_path} (JSON format with user_id field)") + - # 保留旧方法名作为别名,保持向后兼容 def save_profiles_to_json( self, profiles: List[OasisAgentProfile], file_path: str, platform: str = "reddit" ): - """[已废弃] 请使用 save_profiles() 方法""" - logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法") + """Save profiles to json.""" + logger.warning("save_profiles_to_json is deprecated; use save_profiles instead") self.save_profiles(profiles, file_path, platform) - diff --git a/backend/app/services/ontology_generator.py b/backend/app/services/ontology_generator.py index 2d3e39bd8..ef25adf5f 100644 --- a/backend/app/services/ontology_generator.py +++ b/backend/app/services/ontology_generator.py @@ -1,165 +1,160 @@ -""" -本体生成服务 -接口1:分析文本内容,生成适合社会模拟的实体和关系类型定义 -""" +"""Ontology generation service.""" import json from typing import Dict, Any, List, Optional from ..utils.llm_client import LLMClient +from ..utils.ontology_normalizer import normalize_ontology_for_zep -# 本体生成的系统提示词 -ONTOLOGY_SYSTEM_PROMPT = """你是一个专业的知识图谱本体设计专家。你的任务是分析给定的文本内容和模拟需求,设计适合**社交媒体舆论模拟**的实体类型和关系类型。 -**重要:你必须输出有效的JSON格式数据,不要输出任何其他内容。** +ONTOLOGY_SYSTEM_PROMPT = """You are an expert knowledge-graph ontology designer. Your task is to analyze the provided text and simulation requirement, then design entity types and relationship types suitable for a **social media opinion simulation**. -## 核心任务背景 +**Important: you must output valid JSON only. Do not output any additional text.** -我们正在构建一个**社交媒体舆论模拟系统**。在这个系统中: -- 每个实体都是一个可以在社交媒体上发声、互动、传播信息的"账号"或"主体" -- 实体之间会相互影响、转发、评论、回应 -- 我们需要模拟舆论事件中各方的反应和信息传播路径 +## Core task background -因此,**实体必须是现实中真实存在的、可以在社媒上发声和互动的主体**: +We are building a **social media opinion simulation system**. In this system: +- Each entity is an account or actor that can speak, interact, and spread information on social media. +- Entities influence one another, repost, comment, and respond. +- We need to simulate how different parties react during a public-opinion event and how information spreads. -**可以是**: -- 具体的个人(公众人物、当事人、意见领袖、专家学者、普通人) -- 公司、企业(包括其官方账号) -- 组织机构(大学、协会、NGO、工会等) -- 政府部门、监管机构 -- 媒体机构(报纸、电视台、自媒体、网站) -- 社交媒体平台本身 -- 特定群体代表(如校友会、粉丝团、维权群体等) +Therefore, **entities must be real-world actors that can speak and interact on social media**: -**不可以是**: -- 抽象概念(如"舆论"、"情绪"、"趋势") -- 主题/话题(如"学术诚信"、"教育改革") -- 观点/态度(如"支持方"、"反对方") +**Allowed**: +- Specific individuals such as public figures, people directly involved, opinion leaders, scholars, experts, or ordinary people +- Companies and businesses, including their official accounts +- Organizations and institutions, such as universities, associations, NGOs, and unions +- Government departments and regulators +- Media organizations, such as newspapers, TV stations, self-media accounts, and websites +- Social media platforms themselves +- Representatives of specific groups, such as alumni associations, fan groups, or rights-advocacy groups -## 输出格式 +**Not allowed**: +- Abstract concepts such as "public opinion", "emotion", or "trend" +- Topics or themes such as "academic integrity" or "education reform" +- Positions or attitudes such as "supporters" or "opponents" -请输出JSON格式,包含以下结构: +## Output format + +Output JSON in the following structure: ```json { "entity_types": [ { - "name": "实体类型名称(英文,PascalCase)", - "description": "简短描述(英文,不超过100字符)", + "name": "Entity type name (English, PascalCase)", + "description": "Short description (English, under 100 characters)", "attributes": [ { - "name": "属性名(英文,snake_case)", + "name": "Attribute name (English, snake_case)", "type": "text", - "description": "属性描述" + "description": "Attribute description" } ], - "examples": ["示例实体1", "示例实体2"] + "examples": ["Example entity 1", "Example entity 2"] } ], "edge_types": [ { - "name": "关系类型名称(英文,UPPER_SNAKE_CASE)", - "description": "简短描述(英文,不超过100字符)", + "name": "Relationship type name (English, UPPER_SNAKE_CASE)", + "description": "Short description (English, under 100 characters)", "source_targets": [ - {"source": "源实体类型", "target": "目标实体类型"} + {"source": "Source entity type (must exactly match an entity type name)", "target": "Target entity type (must exactly match an entity type name)"} ], "attributes": [] } ], - "analysis_summary": "对文本内容的简要分析说明(中文)" + "analysis_summary": "Brief analysis of the content (English)" } ``` -## 设计指南(极其重要!) +## Design guidelines (very important) -### 1. 实体类型设计 - 必须严格遵守 +### 1. Entity type design - must be followed strictly -**数量要求:必须正好10个实体类型** +**Quantity requirement: exactly 10 entity types** -**层次结构要求(必须同时包含具体类型和兜底类型)**: +**Hierarchy requirement (must include both specific and fallback types)**: -你的10个实体类型必须包含以下层次: +Your 10 entity types must follow this structure: -A. **兜底类型(必须包含,放在列表最后2个)**: - - `Person`: 任何自然人个体的兜底类型。当一个人不属于其他更具体的人物类型时,归入此类。 - - `Organization`: 任何组织机构的兜底类型。当一个组织不属于其他更具体的组织类型时,归入此类。 +A. **Fallback types (must be included and placed as the last 2 items)**: + - `Person`: Fallback type for any individual human being. Use this when a person does not belong to a more specific person category. + - `Organization`: Fallback type for any organization or institution. Use this when an organization does not belong to a more specific organization category. -B. **具体类型(8个,根据文本内容设计)**: - - 针对文本中出现的主要角色,设计更具体的类型 - - 例如:如果文本涉及学术事件,可以有 `Student`, `Professor`, `University` - - 例如:如果文本涉及商业事件,可以有 `Company`, `CEO`, `Employee` +B. **Specific types (8, designed based on the text)**: + - Design more specific types for the main roles that appear in the text. + - Example: for an academic event, types might include `Student`, `Professor`, and `University`. + - Example: for a business event, types might include `Company`, `CEO`, and `Employee`. -**为什么需要兜底类型**: -- 文本中会出现各种人物,如"中小学教师"、"路人甲"、"某位网友" -- 如果没有专门的类型匹配,他们应该被归入 `Person` -- 同理,小型组织、临时团体等应该归入 `Organization` +**Why fallback types are needed**: +- The text may mention many kinds of people, such as school teachers, bystanders, or anonymous netizens. +- If no specific type fits them, they should be classified as `Person`. +- Likewise, small organizations or temporary groups should fall under `Organization`. -**具体类型的设计原则**: -- 从文本中识别出高频出现或关键的角色类型 -- 每个具体类型应该有明确的边界,避免重叠 -- description 必须清晰说明这个类型和兜底类型的区别 +**Specific type design principles**: +- Identify high-frequency or important role categories from the text. +- Each specific type should have a clear boundary and avoid overlap. +- The description must clearly explain how this type differs from the fallback type. -### 2. 关系类型设计 +### 2. Relationship type design -- 数量:6-10个 -- 关系应该反映社媒互动中的真实联系 -- 确保关系的 source_targets 涵盖你定义的实体类型 +- Quantity: 6-10 relationship types +- Relationships should reflect realistic social-media interactions and ties +- Make sure the `source_targets` cover the entity types you define -### 3. 属性设计 +### 3. Attribute design -- 每个实体类型1-3个关键属性 -- **注意**:属性名不能使用 `name`、`uuid`、`group_id`、`created_at`、`summary`(这些是系统保留字) -- 推荐使用:`full_name`, `title`, `role`, `position`, `location`, `description` 等 +- Each entity type should have 1-3 key attributes +- **Important**: attribute names cannot use `name`, `uuid`, `group_id`, `created_at`, or `summary` because these are reserved system fields +- Recommended names include `full_name`, `title`, `role`, `position`, `location`, and `description` -## 实体类型参考 +## Entity type references -**个人类(具体)**: -- Student: 学生 -- Professor: 教授/学者 -- Journalist: 记者 -- Celebrity: 明星/网红 -- Executive: 高管 -- Official: 政府官员 -- Lawyer: 律师 -- Doctor: 医生 +**Person types (specific)**: +- Student: students +- Professor: professors or scholars +- Journalist: journalists +- Celebrity: celebrities or influencers +- Executive: business executives +- Official: government officials +- Lawyer: lawyers +- Doctor: doctors -**个人类(兜底)**: -- Person: 任何自然人(不属于上述具体类型时使用) +**Person type (fallback)**: +- Person: any human individual not covered by a more specific person type -**组织类(具体)**: -- University: 高校 -- Company: 公司企业 -- GovernmentAgency: 政府机构 -- MediaOutlet: 媒体机构 -- Hospital: 医院 -- School: 中小学 -- NGO: 非政府组织 +**Organization types (specific)**: +- University: universities +- Company: companies and businesses +- GovernmentAgency: government agencies +- MediaOutlet: media organizations +- Hospital: hospitals +- School: primary and secondary schools +- NGO: non-governmental organizations -**组织类(兜底)**: -- Organization: 任何组织机构(不属于上述具体类型时使用) +**Organization type (fallback)**: +- Organization: any organization not covered by a more specific organization type -## 关系类型参考 +## Relationship type references -- WORKS_FOR: 工作于 -- STUDIES_AT: 就读于 -- AFFILIATED_WITH: 隶属于 -- REPRESENTS: 代表 -- REGULATES: 监管 -- REPORTS_ON: 报道 -- COMMENTS_ON: 评论 -- RESPONDS_TO: 回应 -- SUPPORTS: 支持 -- OPPOSES: 反对 -- COLLABORATES_WITH: 合作 -- COMPETES_WITH: 竞争 +- WORKS_FOR: works for +- STUDIES_AT: studies at +- AFFILIATED_WITH: is affiliated with +- REPRESENTS: represents +- REGULATES: regulates +- REPORTS_ON: reports on +- COMMENTS_ON: comments on +- RESPONDS_TO: responds to +- SUPPORTS: supports +- OPPOSES: opposes +- COLLABORATES_WITH: collaborates with +- COMPETES_WITH: competes with """ class OntologyGenerator: - """ - 本体生成器 - 分析文本内容,生成实体和关系类型定义 - """ + """Ontology Generator.""" def __init__(self, llm_client: Optional[LLMClient] = None): self.llm_client = llm_client or LLMClient() @@ -170,18 +165,8 @@ def generate( simulation_requirement: str, additional_context: Optional[str] = None ) -> Dict[str, Any]: - """ - 生成本体定义 + """Generate the requested object.""" - Args: - document_texts: 文档文本列表 - simulation_requirement: 模拟需求描述 - additional_context: 额外上下文 - - Returns: - 本体定义(entity_types, edge_types等) - """ - # 构建用户消息 user_message = self._build_user_message( document_texts, simulation_requirement, @@ -193,19 +178,19 @@ def generate( {"role": "user", "content": user_message} ] - # 调用LLM + result = self.llm_client.chat_json( messages=messages, temperature=0.3, max_tokens=4096 ) - # 验证和后处理 + result = self._validate_and_process(result) return result - # 传给 LLM 的文本最大长度(5万字) + MAX_TEXT_LENGTH_FOR_LLM = 50000 def _build_user_message( @@ -214,50 +199,54 @@ def _build_user_message( simulation_requirement: str, additional_context: Optional[str] ) -> str: - """构建用户消息""" + """Build user message.""" + - # 合并文本 combined_text = "\n\n---\n\n".join(document_texts) original_length = len(combined_text) - # 如果文本超过5万字,截断(仅影响传给LLM的内容,不影响图谱构建) + if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM: combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM] - combined_text += f"\n\n...(原文共{original_length}字,已截取前{self.MAX_TEXT_LENGTH_FOR_LLM}字用于本体分析)..." + combined_text += ( + f"\n\n...(Original text length: {original_length} characters. " + f"Only the first {self.MAX_TEXT_LENGTH_FOR_LLM} characters were used for ontology analysis.)..." + ) - message = f"""## 模拟需求 + message = f"""## Simulation Requirement {simulation_requirement} -## 文档内容 +## Document Content {combined_text} """ if additional_context: message += f""" -## 额外说明 +## Additional Notes {additional_context} """ message += """ -请根据以上内容,设计适合社会舆论模拟的实体类型和关系类型。 - -**必须遵守的规则**: -1. 必须正好输出10个实体类型 -2. 最后2个必须是兜底类型:Person(个人兜底)和 Organization(组织兜底) -3. 前8个是根据文本内容设计的具体类型 -4. 所有实体类型必须是现实中可以发声的主体,不能是抽象概念 -5. 属性名不能使用 name、uuid、group_id 等保留字,用 full_name、org_name 等替代 +Based on the content above, design entity types and relationship types suitable for a social-opinion simulation. + +**Rules you must follow**: +1. You must output exactly 10 entity types. +2. The last 2 entity types must be the fallback types: `Person` and `Organization`. +3. The first 8 entity types must be specific types designed from the text. +4. All entity types must be real-world actors that can speak or interact, not abstract concepts. +5. Attribute names cannot use reserved fields such as `name`, `uuid`, or `group_id`; use alternatives like `full_name` or `org_name`. +6. Entity type names must contain only letters and numbers. They cannot contain underscores, spaces, or hyphens. For example, `StudentLeader` is valid but `Student_Leader` is not. """ return message def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: - """验证和后处理结果""" + """Validate And Process.""" + - # 确保必要字段存在 if "entity_types" not in result: result["entity_types"] = [] if "edge_types" not in result: @@ -265,17 +254,17 @@ def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: if "analysis_summary" not in result: result["analysis_summary"] = "" - # 验证实体类型 + for entity in result["entity_types"]: if "attributes" not in entity: entity["attributes"] = [] if "examples" not in entity: entity["examples"] = [] - # 确保description不超过100字符 + if len(entity.get("description", "")) > 100: entity["description"] = entity["description"][:97] + "..." - # 验证关系类型 + for edge in result["edge_types"]: if "source_targets" not in edge: edge["source_targets"] = [] @@ -284,11 +273,11 @@ def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: if len(edge.get("description", "")) > 100: edge["description"] = edge["description"][:97] + "..." - # Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型 + MAX_ENTITY_TYPES = 10 MAX_EDGE_TYPES = 10 - # 兜底类型定义 + person_fallback = { "name": "Person", "description": "Any individual person not fitting other specific person types.", @@ -309,12 +298,12 @@ def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: "examples": ["small business", "community group"] } - # 检查是否已有兜底类型 + entity_names = {e["name"] for e in result["entity_types"]} has_person = "Person" in entity_names has_organization = "Organization" in entity_names - # 需要添加的兜底类型 + fallbacks_to_add = [] if not has_person: fallbacks_to_add.append(person_fallback) @@ -325,50 +314,43 @@ def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: current_count = len(result["entity_types"]) needed_slots = len(fallbacks_to_add) - # 如果添加后会超过 10 个,需要移除一些现有类型 + if current_count + needed_slots > MAX_ENTITY_TYPES: - # 计算需要移除多少个 + to_remove = current_count + needed_slots - MAX_ENTITY_TYPES - # 从末尾移除(保留前面更重要的具体类型) + result["entity_types"] = result["entity_types"][:-to_remove] - # 添加兜底类型 + result["entity_types"].extend(fallbacks_to_add) - # 最终确保不超过限制(防御性编程) + if len(result["entity_types"]) > MAX_ENTITY_TYPES: result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES] if len(result["edge_types"]) > MAX_EDGE_TYPES: result["edge_types"] = result["edge_types"][:MAX_EDGE_TYPES] - - return result + + normalized_result, _ = normalize_ontology_for_zep(result) + return normalized_result def generate_python_code(self, ontology: Dict[str, Any]) -> str: - """ - 将本体定义转换为Python代码(类似ontology.py) - - Args: - ontology: 本体定义 - - Returns: - Python代码字符串 - """ + """Generate python code.""" code_lines = [ '"""', - '自定义实体类型定义', - '由MiroFish自动生成,用于社会舆论模拟', + 'Custom entity type definitions', + 'Auto-generated by MiroFish for social-opinion simulation', '"""', '', 'from pydantic import Field', 'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel', '', '', - '# ============== 实体类型定义 ==============', + '# ============== Entity Type Definitions ==============', '', ] - # 生成实体类型 + for entity in ontology.get("entity_types", []): name = entity["name"] desc = entity.get("description", f"A {name} entity.") @@ -391,13 +373,13 @@ def generate_python_code(self, ontology: Dict[str, Any]) -> str: code_lines.append('') code_lines.append('') - code_lines.append('# ============== 关系类型定义 ==============') + code_lines.append('# ============== Relationship Type Definitions ==============') code_lines.append('') - # 生成关系类型 + for edge in ontology.get("edge_types", []): name = edge["name"] - # 转换为PascalCase类名 + class_name = ''.join(word.capitalize() for word in name.split('_')) desc = edge.get("description", f"A {name} relationship.") @@ -419,8 +401,8 @@ def generate_python_code(self, ontology: Dict[str, Any]) -> str: code_lines.append('') code_lines.append('') - # 生成类型字典 - code_lines.append('# ============== 类型配置 ==============') + + code_lines.append('# ============== Type Configuration ==============') code_lines.append('') code_lines.append('ENTITY_TYPES = {') for entity in ontology.get("entity_types", []): @@ -436,7 +418,7 @@ def generate_python_code(self, ontology: Dict[str, Any]) -> str: code_lines.append('}') code_lines.append('') - # 生成边的source_targets映射 + code_lines.append('EDGE_SOURCE_TARGETS = {') for edge in ontology.get("edge_types", []): name = edge["name"] @@ -450,4 +432,3 @@ def generate_python_code(self, ontology: Dict[str, Any]) -> str: code_lines.append('}') return '\n'.join(code_lines) - diff --git a/backend/app/services/report_agent.py b/backend/app/services/report_agent.py index 02ca5bdc2..641296c69 100644 --- a/backend/app/services/report_agent.py +++ b/backend/app/services/report_agent.py @@ -1,13 +1,4 @@ -""" -Report Agent服务 -使用LangChain + Zep实现ReACT模式的模拟报告生成 - -功能: -1. 根据模拟需求和Zep图谱信息生成报告 -2. 先规划目录结构,然后分段生成 -3. 每段采用ReACT多轮思考与反思模式 -4. 支持与用户对话,在对话中自主调用检索工具 -""" +"""Report Agent service.""" import os import json @@ -33,20 +24,10 @@ class ReportLogger: - """ - Report Agent 详细日志记录器 - - 在报告文件夹中生成 agent_log.jsonl 文件,记录每一步详细动作。 - 每行是一个完整的 JSON 对象,包含时间戳、动作类型、详细内容等。 - """ + """Report Logger.""" def __init__(self, report_id: str): - """ - 初始化日志记录器 - - Args: - report_id: 报告ID,用于确定日志文件路径 - """ + """Initialize the instance.""" self.report_id = report_id self.log_file_path = os.path.join( Config.UPLOAD_FOLDER, 'reports', report_id, 'agent_log.jsonl' @@ -55,12 +36,12 @@ def __init__(self, report_id: str): self._ensure_log_file() def _ensure_log_file(self): - """确保日志文件所在目录存在""" + """Ensure log file.""" log_dir = os.path.dirname(self.log_file_path) os.makedirs(log_dir, exist_ok=True) def _get_elapsed_time(self) -> float: - """获取从开始到现在的耗时(秒)""" + """Get elapsed time.""" return (datetime.now() - self.start_time).total_seconds() def log( @@ -71,16 +52,7 @@ def log( section_title: str = None, section_index: int = None ): - """ - 记录一条日志 - - Args: - action: 动作类型,如 'start', 'tool_call', 'llm_response', 'section_complete' 等 - stage: 当前阶段,如 'planning', 'generating', 'completed' - details: 详细内容字典,不截断 - section_title: 当前章节标题(可选) - section_index: 当前章节索引(可选) - """ + """Log.""" log_entry = { "timestamp": datetime.now().isoformat(), "elapsed_seconds": round(self._get_elapsed_time(), 2), @@ -92,12 +64,12 @@ def log( "details": details } - # 追加写入 JSONL 文件 + with open(self.log_file_path, 'a', encoding='utf-8') as f: f.write(json.dumps(log_entry, ensure_ascii=False) + '\n') def log_start(self, simulation_id: str, graph_id: str, simulation_requirement: str): - """记录报告生成开始""" + """Log Start.""" self.log( action="report_start", stage="pending", @@ -105,52 +77,52 @@ def log_start(self, simulation_id: str, graph_id: str, simulation_requirement: s "simulation_id": simulation_id, "graph_id": graph_id, "simulation_requirement": simulation_requirement, - "message": "报告生成任务开始" + "message": "Report generation started" } ) def log_planning_start(self): - """记录大纲规划开始""" + """Log Planning Start.""" self.log( action="planning_start", stage="planning", - details={"message": "开始规划报告大纲"} + details={"message": "Starting report outline planning"} ) def log_planning_context(self, context: Dict[str, Any]): - """记录规划时获取的上下文信息""" + """Log Planning Context.""" self.log( action="planning_context", stage="planning", details={ - "message": "获取模拟上下文信息", + "message": "Loaded simulation context", "context": context } ) def log_planning_complete(self, outline_dict: Dict[str, Any]): - """记录大纲规划完成""" + """Log Planning Complete.""" self.log( action="planning_complete", stage="planning", details={ - "message": "大纲规划完成", + "message": "Report outline planning completed", "outline": outline_dict } ) def log_section_start(self, section_title: str, section_index: int): - """记录章节生成开始""" + """Log Section Start.""" self.log( action="section_start", stage="generating", section_title=section_title, section_index=section_index, - details={"message": f"开始生成章节: {section_title}"} + details={"message": f"Starting section: {section_title}"} ) def log_react_thought(self, section_title: str, section_index: int, iteration: int, thought: str): - """记录 ReACT 思考过程""" + """Log React Thought.""" self.log( action="react_thought", stage="generating", @@ -159,7 +131,7 @@ def log_react_thought(self, section_title: str, section_index: int, iteration: i details={ "iteration": iteration, "thought": thought, - "message": f"ReACT 第{iteration}轮思考" + "message": f"ReACT thought round {iteration}" } ) @@ -171,7 +143,7 @@ def log_tool_call( parameters: Dict[str, Any], iteration: int ): - """记录工具调用""" + """Log Tool Call.""" self.log( action="tool_call", stage="generating", @@ -181,7 +153,7 @@ def log_tool_call( "iteration": iteration, "tool_name": tool_name, "parameters": parameters, - "message": f"调用工具: {tool_name}" + "message": f"Calling tool: {tool_name}" } ) @@ -193,7 +165,7 @@ def log_tool_result( result: str, iteration: int ): - """记录工具调用结果(完整内容,不截断)""" + """Log Tool Result.""" self.log( action="tool_result", stage="generating", @@ -202,9 +174,9 @@ def log_tool_result( details={ "iteration": iteration, "tool_name": tool_name, - "result": result, # 完整结果,不截断 + "result": result, "result_length": len(result), - "message": f"工具 {tool_name} 返回结果" + "message": f"Tool result returned: {tool_name}" } ) @@ -217,7 +189,7 @@ def log_llm_response( has_tool_calls: bool, has_final_answer: bool ): - """记录 LLM 响应(完整内容,不截断)""" + """Log LLM Response.""" self.log( action="llm_response", stage="generating", @@ -225,11 +197,11 @@ def log_llm_response( section_index=section_index, details={ "iteration": iteration, - "response": response, # 完整响应,不截断 + "response": response, "response_length": len(response), "has_tool_calls": has_tool_calls, "has_final_answer": has_final_answer, - "message": f"LLM 响应 (工具调用: {has_tool_calls}, 最终答案: {has_final_answer})" + "message": f"LLM response (tool_calls={has_tool_calls}, final_answer={has_final_answer})" } ) @@ -240,17 +212,17 @@ def log_section_content( content: str, tool_calls_count: int ): - """记录章节内容生成完成(仅记录内容,不代表整个章节完成)""" + """Log Section Content.""" self.log( action="section_content", stage="generating", section_title=section_title, section_index=section_index, details={ - "content": content, # 完整内容,不截断 + "content": content, "content_length": len(content), "tool_calls_count": tool_calls_count, - "message": f"章节 {section_title} 内容生成完成" + "message": f"Section content completed: {section_title}" } ) @@ -260,11 +232,7 @@ def log_section_full_complete( section_index: int, full_content: str ): - """ - 记录章节生成完成 - - 前端应监听此日志来判断一个章节是否真正完成,并获取完整内容 - """ + """Log Section Full Complete.""" self.log( action="section_complete", stage="generating", @@ -273,24 +241,24 @@ def log_section_full_complete( details={ "content": full_content, "content_length": len(full_content), - "message": f"章节 {section_title} 生成完成" + "message": f"Section completed: {section_title}" } ) def log_report_complete(self, total_sections: int, total_time_seconds: float): - """记录报告生成完成""" + """Log Report Complete.""" self.log( action="report_complete", stage="completed", details={ "total_sections": total_sections, "total_time_seconds": round(total_time_seconds, 2), - "message": "报告生成完成" + "message": "Report generation completed" } ) def log_error(self, error_message: str, stage: str, section_title: str = None): - """记录错误""" + """Log Error.""" self.log( action="error", stage=stage, @@ -298,26 +266,16 @@ def log_error(self, error_message: str, stage: str, section_title: str = None): section_index=None, details={ "error": error_message, - "message": f"发生错误: {error_message}" + "message": f"Error: {error_message}" } ) class ReportConsoleLogger: - """ - Report Agent 控制台日志记录器 - - 将控制台风格的日志(INFO、WARNING等)写入报告文件夹中的 console_log.txt 文件。 - 这些日志与 agent_log.jsonl 不同,是纯文本格式的控制台输出。 - """ + """Report Console Logger.""" def __init__(self, report_id: str): - """ - 初始化控制台日志记录器 - - Args: - report_id: 报告ID,用于确定日志文件路径 - """ + """Initialize the instance.""" self.report_id = report_id self.log_file_path = os.path.join( Config.UPLOAD_FOLDER, 'reports', report_id, 'console_log.txt' @@ -327,15 +285,15 @@ def __init__(self, report_id: str): self._setup_file_handler() def _ensure_log_file(self): - """确保日志文件所在目录存在""" + """Ensure log file.""" log_dir = os.path.dirname(self.log_file_path) os.makedirs(log_dir, exist_ok=True) def _setup_file_handler(self): - """设置文件处理器,将日志同时写入文件""" + """Setup File Handler.""" import logging - # 创建文件处理器 + self._file_handler = logging.FileHandler( self.log_file_path, mode='a', @@ -343,14 +301,14 @@ def _setup_file_handler(self): ) self._file_handler.setLevel(logging.INFO) - # 使用与控制台相同的简洁格式 + formatter = logging.Formatter( '[%(asctime)s] %(levelname)s: %(message)s', datefmt='%H:%M:%S' ) self._file_handler.setFormatter(formatter) - # 添加到 report_agent 相关的 logger + loggers_to_attach = [ 'mirofish.report_agent', 'mirofish.zep_tools', @@ -358,12 +316,12 @@ def _setup_file_handler(self): for logger_name in loggers_to_attach: target_logger = logging.getLogger(logger_name) - # 避免重复添加 + if self._file_handler not in target_logger.handlers: target_logger.addHandler(self._file_handler) def close(self): - """关闭文件处理器并从 logger 中移除""" + """Close the requested object.""" import logging if self._file_handler: @@ -381,12 +339,12 @@ def close(self): self._file_handler = None def __del__(self): - """析构时确保关闭文件处理器""" + """Clean up resources when the instance is destroyed.""" self.close() class ReportStatus(str, Enum): - """报告状态""" + """Report Status.""" PENDING = "pending" PLANNING = "planning" GENERATING = "generating" @@ -396,7 +354,7 @@ class ReportStatus(str, Enum): @dataclass class ReportSection: - """报告章节""" + """Report Section.""" title: str content: str = "" @@ -407,7 +365,7 @@ def to_dict(self) -> Dict[str, Any]: } def to_markdown(self, level: int = 2) -> str: - """转换为Markdown格式""" + """Convert the object to Markdown.""" md = f"{'#' * level} {self.title}\n\n" if self.content: md += f"{self.content}\n\n" @@ -416,7 +374,7 @@ def to_markdown(self, level: int = 2) -> str: @dataclass class ReportOutline: - """报告大纲""" + """Report Outline.""" title: str summary: str sections: List[ReportSection] @@ -429,7 +387,7 @@ def to_dict(self) -> Dict[str, Any]: } def to_markdown(self) -> str: - """转换为Markdown格式""" + """Convert the object to Markdown.""" md = f"# {self.title}\n\n" md += f"> {self.summary}\n\n" for section in self.sections: @@ -439,7 +397,7 @@ def to_markdown(self) -> str: @dataclass class Report: - """完整报告""" + """Report.""" report_id: str simulation_id: str graph_id: str @@ -467,417 +425,397 @@ def to_dict(self) -> Dict[str, Any]: # ═══════════════════════════════════════════════════════════════ -# Prompt 模板常量 + # ═══════════════════════════════════════════════════════════════ -# ── 工具描述 ── + TOOL_DESC_INSIGHT_FORGE = """\ -【深度洞察检索 - 强大的检索工具】 -这是我们强大的检索函数,专为深度分析设计。它会: -1. 自动将你的问题分解为多个子问题 -2. 从多个维度检索模拟图谱中的信息 -3. 整合语义搜索、实体分析、关系链追踪的结果 -4. 返回最全面、最深度的检索内容 - -【使用场景】 -- 需要深入分析某个话题 -- 需要了解事件的多个方面 -- 需要获取支撑报告章节的丰富素材 - -【返回内容】 -- 相关事实原文(可直接引用) -- 核心实体洞察 -- 关系链分析""" +[Deep Forecast Retrieval] +A strong retrieval tool for section-level analysis. It: +1. Breaks a question into multiple sub-questions +2. Searches the simulation graph from different angles +3. Combines semantic facts, entity insights, and relationship chains +4. Returns dense supporting evidence for report writing + +Best for: +- deep analysis of one topic +- understanding multiple sides of an event +- collecting rich evidence for a report section + +Returns: +- source facts you can quote +- core entity insights +- relationship chains""" TOOL_DESC_PANORAMA_SEARCH = """\ -【广度搜索 - 获取全貌视图】 -这个工具用于获取模拟结果的完整全貌,特别适合了解事件演变过程。它会: -1. 获取所有相关节点和关系 -2. 区分当前有效的事实和历史/过期的事实 -3. 帮助你了解舆情是如何演变的 - -【使用场景】 -- 需要了解事件的完整发展脉络 -- 需要对比不同阶段的舆情变化 -- 需要获取全面的实体和关系信息 - -【返回内容】 -- 当前有效事实(模拟最新结果) -- 历史/过期事实(演变记录) -- 所有涉及的实体""" +[Panorama Search] +Use this tool to understand the full picture of the simulation outcome. It: +1. Retrieves all relevant nodes and relationships +2. Separates active facts from historical or expired facts +3. Helps you understand how the situation evolved over time + +Best for: +- reconstructing the full event arc +- comparing different stages of evolution +- gathering broad entity and relationship coverage + +Returns: +- active facts from the latest simulation state +- historical or expired facts +- involved entities""" TOOL_DESC_QUICK_SEARCH = """\ -【简单搜索 - 快速检索】 -轻量级的快速检索工具,适合简单、直接的信息查询。 +[Quick Search] +A lightweight retrieval tool for simple, direct fact lookup. -【使用场景】 -- 需要快速查找某个具体信息 -- 需要验证某个事实 -- 简单的信息检索 +Best for: +- checking one specific detail +- validating a fact +- fast retrieval with minimal scope -【返回内容】 -- 与查询最相关的事实列表""" +Returns: +- the most relevant facts for the query""" TOOL_DESC_INTERVIEW_AGENTS = """\ -【深度采访 - 真实Agent采访(双平台)】 -调用OASIS模拟环境的采访API,对正在运行的模拟Agent进行真实采访! -这不是LLM模拟,而是调用真实的采访接口获取模拟Agent的原始回答。 -默认在Twitter和Reddit两个平台同时采访,获取更全面的观点。 - -功能流程: -1. 自动读取人设文件,了解所有模拟Agent -2. 智能选择与采访主题最相关的Agent(如学生、媒体、官方等) -3. 自动生成采访问题 -4. 调用 /api/simulation/interview/batch 接口在双平台进行真实采访 -5. 整合所有采访结果,提供多视角分析 - -【使用场景】 -- 需要从不同角色视角了解事件看法(学生怎么看?媒体怎么看?官方怎么说?) -- 需要收集多方意见和立场 -- 需要获取模拟Agent的真实回答(来自OASIS模拟环境) -- 想让报告更生动,包含"采访实录" - -【返回内容】 -- 被采访Agent的身份信息 -- 各Agent在Twitter和Reddit两个平台的采访回答 -- 关键引言(可直接引用) -- 采访摘要和观点对比 - -【重要】需要OASIS模拟环境正在运行才能使用此功能!""" - -# ── 大纲规划 prompt ── +[Live Agent Interviews] +Calls the OASIS interview API to interview live simulated agents. +This is not an LLM roleplay. It retrieves raw answers from the running +simulation, and by default interviews agents across both Twitter and Reddit. + +Workflow: +1. Load available agent profiles +2. Select the most relevant agents for the topic +3. Generate interview questions +4. Call /api/simulation/interview/batch for both platforms +5. Merge the results into a multi-perspective interview report + +Best for: +- understanding how different roles react +- collecting multiple viewpoints and tensions +- obtaining first-person answers from the simulation +- making the report more vivid with interview material + +Returns: +- interviewed agent identities +- Twitter and Reddit answers +- key quotes you can cite +- an interview summary and comparison + +Important: the OASIS simulation environment must be running.""" + PLAN_SYSTEM_PROMPT = """\ -你是一个「未来预测报告」的撰写专家,拥有对模拟世界的「上帝视角」——你可以洞察模拟中每一位Agent的行为、言论和互动。 - -【核心理念】 -我们构建了一个模拟世界,并向其中注入了特定的「模拟需求」作为变量。模拟世界的演化结果,就是对未来可能发生情况的预测。你正在观察的不是"实验数据",而是"未来的预演"。 - -【你的任务】 -撰写一份「未来预测报告」,回答: -1. 在我们设定的条件下,未来发生了什么? -2. 各类Agent(人群)是如何反应和行动? -3. 这个模拟揭示了哪些值得关注的未来趋势和风险? - -【报告定位】 -- ✅ 这是一份基于模拟的未来预测报告,揭示"如果这样,未来会怎样" -- ✅ 聚焦于预测结果:事件走向、群体反应、涌现现象、潜在风险 -- ✅ 模拟世界中的Agent言行就是对未来人群行为的预测 -- ❌ 不是对现实世界现状的分析 -- ❌ 不是泛泛而谈的舆情综述 - -【章节数量限制】 -- 最少2个章节,最多5个章节 -- 不需要子章节,每个章节直接撰写完整内容 -- 内容要精炼,聚焦于核心预测发现 -- 章节结构由你根据预测结果自主设计 - -请输出JSON格式的报告大纲,格式如下: +You are an expert writer of future-facing forecast reports with a God's-eye +view of the simulation world. You can observe every agent's behavior, speech, +and interaction. + +[Core idea] +We injected a specific simulation requirement into a synthetic world. The way +that world evolved is our forecast of what may happen next. You are not looking +at ordinary experiment data. You are looking at a rehearsal of the future. + +[Your task] +Write a forecast report that answers: +1. What happened in the simulated future under the given condition? +2. How did different agents or groups react and act? +3. What future trends, risks, or opportunities does this simulation reveal? + +[Report positioning] +- This is a simulation-based forecast report about what may happen next +- Focus on outcomes, reactions, emergence, trends, and risks +- Agent behavior in the simulation is evidence about likely future behavior +- Do not write a current-state analysis of the real world +- Do not write a generic public-opinion recap + +[Section count] +- Minimum 2 sections, maximum 5 sections +- No subsections are needed +- Keep the structure focused and concise +- Design the section structure based on the forecast itself + +Return the outline as JSON in this format: { - "title": "报告标题", - "summary": "报告摘要(一句话概括核心预测发现)", + "title": "Report title", + "summary": "One-sentence summary of the core forecast finding", "sections": [ { - "title": "章节标题", - "description": "章节内容描述" + "title": "Section title", + "description": "What this section covers" } ] } -注意:sections数组最少2个,最多5个元素!""" +The sections array must contain between 2 and 5 items.""" PLAN_USER_PROMPT_TEMPLATE = """\ -【预测场景设定】 -我们向模拟世界注入的变量(模拟需求):{simulation_requirement} +[Forecast scenario] +Injected simulation requirement: {simulation_requirement} -【模拟世界规模】 -- 参与模拟的实体数量: {total_nodes} -- 实体间产生的关系数量: {total_edges} -- 实体类型分布: {entity_types} -- 活跃Agent数量: {total_entities} +[Simulation scale] +- Total entities in the simulation: {total_nodes} +- Total relationships created: {total_edges} +- Entity type distribution: {entity_types} +- Active agents: {total_entities} -【模拟预测到的部分未来事实样本】 +[Sample future facts from the simulation] {related_facts_json} -请以「上帝视角」审视这个未来预演: -1. 在我们设定的条件下,未来呈现出了什么样的状态? -2. 各类人群(Agent)是如何反应和行动的? -3. 这个模拟揭示了哪些值得关注的未来趋势? +Review this future rehearsal from a God's-eye view: +1. What kind of future state emerged under the given condition? +2. How did different agents or groups react and act? +3. What trends, risks, or opportunities stand out? -根据预测结果,设计最合适的报告章节结构。 +Design the most appropriate report section structure based on the forecast. -【再次提醒】报告章节数量:最少2个,最多5个,内容要精炼聚焦于核心预测发现。""" +Reminder: keep the report between 2 and 5 sections and focus on the strongest +forecast findings.""" -# ── 章节生成 prompt ── SECTION_SYSTEM_PROMPT_TEMPLATE = """\ -你是一个「未来预测报告」的撰写专家,正在撰写报告的一个章节。 - -报告标题: {report_title} -报告摘要: {report_summary} -预测场景(模拟需求): {simulation_requirement} - -当前要撰写的章节: {section_title} - -═══════════════════════════════════════════════════════════════ -【核心理念】 -═══════════════════════════════════════════════════════════════ - -模拟世界是对未来的预演。我们向模拟世界注入了特定条件(模拟需求), -模拟中Agent的行为和互动,就是对未来人群行为的预测。 - -你的任务是: -- 揭示在设定条件下,未来发生了什么 -- 预测各类人群(Agent)是如何反应和行动的 -- 发现值得关注的未来趋势、风险和机会 - -❌ 不要写成对现实世界现状的分析 -✅ 要聚焦于"未来会怎样"——模拟结果就是预测的未来 - -═══════════════════════════════════════════════════════════════ -【最重要的规则 - 必须遵守】 -═══════════════════════════════════════════════════════════════ - -1. 【必须调用工具观察模拟世界】 - - 你正在以「上帝视角」观察未来的预演 - - 所有内容必须来自模拟世界中发生的事件和Agent言行 - - 禁止使用你自己的知识来编写报告内容 - - 每个章节至少调用3次工具(最多5次)来观察模拟的世界,它代表了未来 - -2. 【必须引用Agent的原始言行】 - - Agent的发言和行为是对未来人群行为的预测 - - 在报告中使用引用格式展示这些预测,例如: - > "某类人群会表示:原文内容..." - - 这些引用是模拟预测的核心证据 - -3. 【语言一致性 - 引用内容必须翻译为报告语言】 - - 工具返回的内容可能包含英文或中英文混杂的表述 - - 如果模拟需求和材料原文是中文的,报告必须全部使用中文撰写 - - 当你引用工具返回的英文或中英混杂内容时,必须将其翻译为流畅的中文后再写入报告 - - 翻译时保持原意不变,确保表述自然通顺 - - 这一规则同时适用于正文和引用块(> 格式)中的内容 - -4. 【忠实呈现预测结果】 - - 报告内容必须反映模拟世界中的代表未来的模拟结果 - - 不要添加模拟中不存在的信息 - - 如果某方面信息不足,如实说明 - -═══════════════════════════════════════════════════════════════ -【⚠️ 格式规范 - 极其重要!】 -═══════════════════════════════════════════════════════════════ - -【一个章节 = 最小内容单位】 -- 每个章节是报告的最小分块单位 -- ❌ 禁止在章节内使用任何 Markdown 标题(#、##、###、#### 等) -- ❌ 禁止在内容开头添加章节主标题 -- ✅ 章节标题由系统自动添加,你只需撰写纯正文内容 -- ✅ 使用**粗体**、段落分隔、引用、列表来组织内容,但不要用标题 - -【正确示例】 +You are writing one section of a future forecast report. + +Report title: {report_title} +Report summary: {report_summary} +Simulation requirement: {simulation_requirement} +Current section: {section_title} + +================================================================ +[Core idea] +================================================================ + +The simulation is a rehearsal of the future. The requirement injected into the +simulation acts as the condition, and agent behavior is evidence of how people +may respond in that future. + +Your job is to: +- explain what happened in the simulated future +- forecast how different groups reacted and acted +- surface important trends, risks, and opportunities + +Do not write a present-day real-world analysis. +Focus on what the simulation suggests will happen next. + +================================================================ +[Critical rules] +================================================================ + +1. You must observe the simulation through tools + - All content must come from events, facts, and agent behavior in the simulation + - Do not use your own outside knowledge + - Call tools at least 3 times and at most 5 times for each section + +2. You must quote agent behavior and language + - Agent actions and statements are core forecast evidence + - Use block quotes to show representative evidence, for example: + > "One group may say: translated or paraphrased source content..." + +3. The report must be written entirely in English + - Tool output may contain Chinese, English, or mixed-language content + - Translate all quoted or paraphrased source material into fluent English + - Preserve the original meaning when translating + - This rule applies to both body text and block quotes + +4. Stay faithful to the forecast + - Reflect only what exists in the simulation results + - Do not add unsupported claims + - If evidence is missing, say so plainly + +================================================================ +[Formatting rules] +================================================================ + +[One section is the smallest unit] +- Each section is a single content block +- Do not use Markdown headings inside the section +- Do not repeat the section title at the top +- The system inserts the section title automatically +- Use bold text, paragraphs, quotes, and lists to organize content + +[Correct example] ``` -本章节分析了事件的舆论传播态势。通过对模拟数据的深入分析,我们发现... +This section examines how the issue spread across the simulated platforms. +The simulation shows a sharp shift from curiosity to coordinated criticism. -**首发引爆阶段** +**Early ignition** -微博作为舆情的第一现场,承担了信息首发的核心功能: +The first wave of attention came from a small but highly active cluster. -> "微博贡献了68%的首发声量..." +> "Early participants frame the event as proof of institutional weakness." -**情绪放大阶段** +**Escalation** -抖音平台进一步放大了事件影响力: +As the issue spreads, more agents amplify the risk: -- 视觉冲击力强 -- 情绪共鸣度高 +- reputational pressure rises +- defensive responses draw more scrutiny ``` -【错误示例】 +[Incorrect example] ``` -## 执行摘要 ← 错误!不要添加任何标题 -### 一、首发阶段 ← 错误!不要用###分小节 -#### 1.1 详细分析 ← 错误!不要用####细分 - -本章节分析了... +## Executive Summary +### Phase One +#### 1.1 Detail ``` -═══════════════════════════════════════════════════════════════ -【可用检索工具】(每章节调用3-5次) -═══════════════════════════════════════════════════════════════ +================================================================ +[Available tools] +================================================================ {tools_description} -【工具使用建议 - 请混合使用不同工具,不要只用一种】 -- insight_forge: 深度洞察分析,自动分解问题并多维度检索事实和关系 -- panorama_search: 广角全景搜索,了解事件全貌、时间线和演变过程 -- quick_search: 快速验证某个具体信息点 -- interview_agents: 采访模拟Agent,获取不同角色的第一人称观点和真实反应 +[Tool usage advice] +- insight_forge: deep analysis across facts, entities, and relationships +- panorama_search: full-picture retrieval and evolution over time +- quick_search: fast validation of a specific point +- interview_agents: first-person reactions from simulated agents -═══════════════════════════════════════════════════════════════ -【工作流程】 -═══════════════════════════════════════════════════════════════ +================================================================ +[Workflow] +================================================================ -每次回复你只能做以下两件事之一(不可同时做): +Each reply may do exactly one of the following: -选项A - 调用工具: -输出你的思考,然后用以下格式调用一个工具: +Option A - Call one tool +Write your thought, then call exactly one tool in this format: -{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} +{{"name": "tool_name", "parameters": {{"param": "value"}}}} -系统会执行工具并把结果返回给你。你不需要也不能自己编写工具返回结果。 - -选项B - 输出最终内容: -当你已通过工具获取了足够信息,以 "Final Answer:" 开头输出章节内容。 - -⚠️ 严格禁止: -- 禁止在一次回复中同时包含工具调用和 Final Answer -- 禁止自己编造工具返回结果(Observation),所有工具结果由系统注入 -- 每次回复最多调用一个工具 - -═══════════════════════════════════════════════════════════════ -【章节内容要求】 -═══════════════════════════════════════════════════════════════ - -1. 内容必须基于工具检索到的模拟数据 -2. 大量引用原文来展示模拟效果 -3. 使用Markdown格式(但禁止使用标题): - - 使用 **粗体文字** 标记重点(代替子标题) - - 使用列表(-或1.2.3.)组织要点 - - 使用空行分隔不同段落 - - ❌ 禁止使用 #、##、###、#### 等任何标题语法 -4. 【引用格式规范 - 必须单独成段】 - 引用必须独立成段,前后各有一个空行,不能混在段落中: - - ✅ 正确格式: - ``` - 校方的回应被认为缺乏实质内容。 - - > "校方的应对模式在瞬息万变的社交媒体环境中显得僵化和迟缓。" - - 这一评价反映了公众的普遍不满。 - ``` - - ❌ 错误格式: - ``` - 校方的回应被认为缺乏实质内容。> "校方的应对模式..." 这一评价反映了... - ``` -5. 保持与其他章节的逻辑连贯性 -6. 【避免重复】仔细阅读下方已完成的章节内容,不要重复描述相同的信息 -7. 【再次强调】不要添加任何标题!用**粗体**代替小节标题""" + +Option B - Produce final content +Once you have enough evidence, output the section starting with "Final Answer:" + +Strictly forbidden: +- mixing a tool call and Final Answer in the same reply +- inventing tool observations yourself +- calling more than one tool per reply + +================================================================ +[Section content requirements] +================================================================ + +1. Base the section on retrieved simulation evidence +2. Use direct evidence and quotes generously +3. Use Markdown, but no headings + - use **bold** for emphasis + - use lists where helpful + - use blank lines between paragraphs + - do not use #, ##, ###, or similar heading syntax +4. Quotes must stand alone as separate blocks +5. Keep continuity with earlier sections +6. Read prior sections carefully and avoid repeating the same points +7. Do not add any heading at the top of the section""" SECTION_USER_PROMPT_TEMPLATE = """\ -已完成的章节内容(请仔细阅读,避免重复): +Completed sections so far (read carefully and avoid repetition): {previous_content} -═══════════════════════════════════════════════════════════════ -【当前任务】撰写章节: {section_title} -═══════════════════════════════════════════════════════════════ +================================================================ +[Current task] Write section: {section_title} +================================================================ -【重要提醒】 -1. 仔细阅读上方已完成的章节,避免重复相同的内容! -2. 开始前必须先调用工具获取模拟数据 -3. 请混合使用不同工具,不要只用一种 -4. 报告内容必须来自检索结果,不要使用自己的知识 +[Important reminders] +1. Read the completed sections first so you do not repeat them +2. You must call a tool before writing the section +3. Mix different tools rather than relying on just one +4. All claims must come from retrieved simulation evidence -【⚠️ 格式警告 - 必须遵守】 -- ❌ 不要写任何标题(#、##、###、####都不行) -- ❌ 不要写"{section_title}"作为开头 -- ✅ 章节标题由系统自动添加 -- ✅ 直接写正文,用**粗体**代替小节标题 +[Formatting warning] +- Do not write any heading +- Do not start with "{section_title}" +- The system adds the section title automatically +- Write body text directly and use **bold** instead of headings -请开始: -1. 首先思考(Thought)这个章节需要什么信息 -2. 然后调用工具(Action)获取模拟数据 -3. 收集足够信息后输出 Final Answer(纯正文,无任何标题)""" +Please begin: +1. Think about what information you need +2. Call one tool to retrieve simulation evidence +3. After collecting enough evidence, output Final Answer with body text only""" -# ── ReACT 循环内消息模板 ── REACT_OBSERVATION_TEMPLATE = """\ -Observation(检索结果): +Observation: -═══ 工具 {tool_name} 返回 ═══ +=== Tool {tool_name} returned === {result} -═══════════════════════════════════════════════════════════════ -已调用工具 {tool_calls_count}/{max_tool_calls} 次(已用: {used_tools_str}){unused_hint} -- 如果信息充分:以 "Final Answer:" 开头输出章节内容(必须引用上述原文) -- 如果需要更多信息:调用一个工具继续检索 -═══════════════════════════════════════════════════════════════""" +================================================================ +Tools used: {tool_calls_count}/{max_tool_calls} (used: {used_tools_str}){unused_hint} +- If you have enough evidence, output the section starting with "Final Answer:" +- If you still need more evidence, call one more tool +================================================================""" REACT_INSUFFICIENT_TOOLS_MSG = ( - "【注意】你只调用了{tool_calls_count}次工具,至少需要{min_tool_calls}次。" - "请再调用工具获取更多模拟数据,然后再输出 Final Answer。{unused_hint}" + "You have only called tools {tool_calls_count} time(s), but at least " + "{min_tool_calls} calls are required before Final Answer. " + "Please retrieve more simulation evidence first.{unused_hint}" ) REACT_INSUFFICIENT_TOOLS_MSG_ALT = ( - "当前只调用了 {tool_calls_count} 次工具,至少需要 {min_tool_calls} 次。" - "请调用工具获取模拟数据。{unused_hint}" + "Only {tool_calls_count} tool call(s) have been made so far; at least " + "{min_tool_calls} are required. Please call another tool.{unused_hint}" ) REACT_TOOL_LIMIT_MSG = ( - "工具调用次数已达上限({tool_calls_count}/{max_tool_calls}),不能再调用工具。" - '请立即基于已获取的信息,以 "Final Answer:" 开头输出章节内容。' + "The tool-call limit has been reached ({tool_calls_count}/{max_tool_calls}). " + 'Please output the section immediately starting with "Final Answer:".' ) -REACT_UNUSED_TOOLS_HINT = "\n💡 你还没有使用过: {unused_list},建议尝试不同工具获取多角度信息" +REACT_UNUSED_TOOLS_HINT = "\nHint: you have not used these tools yet: {unused_list}" -REACT_FORCE_FINAL_MSG = "已达到工具调用限制,请直接输出 Final Answer: 并生成章节内容。" +REACT_FORCE_FINAL_MSG = ( + 'The tool-call budget is exhausted. Output the section now starting with ' + '"Final Answer:".' +) -# ── Chat prompt ── +# -- Chat prompt -- CHAT_SYSTEM_PROMPT_TEMPLATE = """\ -你是一个简洁高效的模拟预测助手。 +You are a concise simulation forecast assistant. -【背景】 -预测条件: {simulation_requirement} +[Context] +Simulation requirement: {simulation_requirement} -【已生成的分析报告】 +[Generated report] {report_content} -【规则】 -1. 优先基于上述报告内容回答问题 -2. 直接回答问题,避免冗长的思考论述 -3. 仅在报告内容不足以回答时,才调用工具检索更多数据 -4. 回答要简洁、清晰、有条理 +[Rules] +1. Prefer answering from the existing report content +2. Answer directly and avoid long reasoning monologues +3. Call tools only when the report is not enough +4. Keep the answer concise, clear, and well structured +5. Write in English -【可用工具】(仅在需要时使用,最多调用1-2次) +[Available tools] {tools_description} -【工具调用格式】 +[Tool call format] -{{"name": "工具名称", "parameters": {{"参数名": "参数值"}}}} +{{"name": "tool_name", "parameters": {{"param": "value"}}}} -【回答风格】 -- 简洁直接,不要长篇大论 -- 使用 > 格式引用关键内容 -- 优先给出结论,再解释原因""" +[Answer style] +- lead with the conclusion +- use > quotes for key evidence when useful +- stay brief""" -CHAT_OBSERVATION_SUFFIX = "\n\n请简洁回答问题。" +CHAT_OBSERVATION_SUFFIX = "\n\nPlease answer briefly in English." # ═══════════════════════════════════════════════════════════════ -# ReportAgent 主类 + # ═══════════════════════════════════════════════════════════════ class ReportAgent: - """ - Report Agent - 模拟报告生成Agent - - 采用ReACT(Reasoning + Acting)模式: - 1. 规划阶段:分析模拟需求,规划报告目录结构 - 2. 生成阶段:逐章节生成内容,每章节可多次调用工具获取信息 - 3. 反思阶段:检查内容完整性和准确性 - """ + """Report Agent.""" + - # 最大工具调用次数(每个章节) MAX_TOOL_CALLS_PER_SECTION = 5 - # 最大反思轮数 + MAX_REFLECTION_ROUNDS = 3 - # 对话中的最大工具调用次数 + MAX_TOOL_CALLS_PER_CHAT = 2 def __init__( @@ -888,16 +826,7 @@ def __init__( llm_client: Optional[LLMClient] = None, zep_tools: Optional[ZepToolsService] = None ): - """ - 初始化Report Agent - - Args: - graph_id: 图谱ID - simulation_id: 模拟ID - simulation_requirement: 模拟需求描述 - llm_client: LLM客户端(可选) - zep_tools: Zep工具服务(可选) - """ + """Initialize the instance.""" self.graph_id = graph_id self.simulation_id = simulation_id self.simulation_requirement = simulation_requirement @@ -905,66 +834,56 @@ def __init__( self.llm = llm_client or LLMClient() self.zep_tools = zep_tools or ZepToolsService() - # 工具定义 + self.tools = self._define_tools() - # 日志记录器(在 generate_report 中初始化) + self.report_logger: Optional[ReportLogger] = None - # 控制台日志记录器(在 generate_report 中初始化) + self.console_logger: Optional[ReportConsoleLogger] = None - logger.info(f"ReportAgent 初始化完成: graph_id={graph_id}, simulation_id={simulation_id}") + logger.info(f"ReportAgent initialized: graph_id={graph_id}, simulation_id={simulation_id}") def _define_tools(self) -> Dict[str, Dict[str, Any]]: - """定义可用工具""" + """Define Tools.""" return { "insight_forge": { "name": "insight_forge", "description": TOOL_DESC_INSIGHT_FORGE, "parameters": { - "query": "你想深入分析的问题或话题", - "report_context": "当前报告章节的上下文(可选,有助于生成更精准的子问题)" + "query": "Question or topic to analyze deeply", + "report_context": "Current report-section context (optional)" } }, "panorama_search": { "name": "panorama_search", "description": TOOL_DESC_PANORAMA_SEARCH, "parameters": { - "query": "搜索查询,用于相关性排序", - "include_expired": "是否包含过期/历史内容(默认True)" + "query": "Search query used for relevance sorting", + "include_expired": "Whether to include historical or expired facts (default True)" } }, "quick_search": { "name": "quick_search", "description": TOOL_DESC_QUICK_SEARCH, "parameters": { - "query": "搜索查询字符串", - "limit": "返回结果数量(可选,默认10)" + "query": "Search query string", + "limit": "Maximum number of results to return (optional, default 10)" } }, "interview_agents": { "name": "interview_agents", "description": TOOL_DESC_INTERVIEW_AGENTS, "parameters": { - "interview_topic": "采访主题或需求描述(如:'了解学生对宿舍甲醛事件的看法')", - "max_agents": "最多采访的Agent数量(可选,默认5,最大10)" + "interview_topic": "Interview topic or request description", + "max_agents": "Maximum number of agents to interview (optional, default 5, max 10)" } } } def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_context: str = "") -> str: - """ - 执行工具调用 - - Args: - tool_name: 工具名称 - parameters: 工具参数 - report_context: 报告上下文(用于InsightForge) - - Returns: - 工具执行结果(文本格式) - """ - logger.info(f"执行工具: {tool_name}, 参数: {parameters}") + """Execute Tool.""" + logger.info(f"Executing tool: {tool_name}, parameters: {parameters}") try: if tool_name == "insight_forge": @@ -979,7 +898,7 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return result.to_text() elif tool_name == "panorama_search": - # 广度搜索 - 获取全貌 + query = parameters.get("query", "") include_expired = parameters.get("include_expired", True) if isinstance(include_expired, str): @@ -992,7 +911,7 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return result.to_text() elif tool_name == "quick_search": - # 简单搜索 - 快速检索 + query = parameters.get("query", "") limit = parameters.get("limit", 10) if isinstance(limit, str): @@ -1005,7 +924,7 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return result.to_text() elif tool_name == "interview_agents": - # 深度采访 - 调用真实的OASIS采访API获取模拟Agent的回答(双平台) + interview_topic = parameters.get("interview_topic", parameters.get("query", "")) max_agents = parameters.get("max_agents", 5) if isinstance(max_agents, str): @@ -1019,11 +938,11 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte ) return result.to_text() - # ========== 向后兼容的旧工具(内部重定向到新工具) ========== + elif tool_name == "search_graph": - # 重定向到 quick_search - logger.info("search_graph 已重定向到 quick_search") + + logger.info("search_graph redirected to quick_search") return self._execute_tool("quick_search", parameters, report_context) elif tool_name == "get_graph_statistics": @@ -1039,8 +958,8 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return json.dumps(result, ensure_ascii=False, indent=2) elif tool_name == "get_simulation_context": - # 重定向到 insight_forge,因为它更强大 - logger.info("get_simulation_context 已重定向到 insight_forge") + + logger.info("get_simulation_context redirected to insight_forge") query = parameters.get("query", self.simulation_requirement) return self._execute_tool("insight_forge", {"query": query}, report_context) @@ -1054,26 +973,23 @@ def _execute_tool(self, tool_name: str, parameters: Dict[str, Any], report_conte return json.dumps(result, ensure_ascii=False, indent=2) else: - return f"未知工具: {tool_name}。请使用以下工具之一: insight_forge, panorama_search, quick_search" + return ( + f"Unknown tool: {tool_name}. Use one of: " + "insight_forge, panorama_search, quick_search, interview_agents" + ) except Exception as e: - logger.error(f"工具执行失败: {tool_name}, 错误: {str(e)}") - return f"工具执行失败: {str(e)}" + logger.error(f"Tool execution failed: {tool_name}, error: {str(e)}") + return f"Tool execution failed: {str(e)}" + - # 合法的工具名称集合,用于裸 JSON 兜底解析时校验 VALID_TOOL_NAMES = {"insight_forge", "panorama_search", "quick_search", "interview_agents"} def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: - """ - 从LLM响应中解析工具调用 - - 支持的格式(按优先级): - 1. {"name": "tool_name", "parameters": {...}} - 2. 裸 JSON(响应整体或单行就是一个工具调用 JSON) - """ + """Parse tool calls.""" tool_calls = [] - # 格式1: XML风格(标准格式) + xml_pattern = r'\s*(\{.*?\})\s*' for match in re.finditer(xml_pattern, response, re.DOTALL): try: @@ -1085,8 +1001,8 @@ def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: if tool_calls: return tool_calls - # 格式2: 兜底 - LLM 直接输出裸 JSON(没包 标签) - # 只在格式1未匹配时尝试,避免误匹配正文中的 JSON + + stripped = response.strip() if stripped.startswith('{') and stripped.endswith('}'): try: @@ -1097,7 +1013,7 @@ def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: except json.JSONDecodeError: pass - # 响应可能包含思考文字 + 裸 JSON,尝试提取最后一个 JSON 对象 + json_pattern = r'(\{"(?:name|tool)"\s*:.*?\})\s*$' match = re.search(json_pattern, stripped, re.DOTALL) if match: @@ -1111,11 +1027,11 @@ def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]: return tool_calls def _is_valid_tool_call(self, data: dict) -> bool: - """校验解析出的 JSON 是否是合法的工具调用""" - # 支持 {"name": ..., "parameters": ...} 和 {"tool": ..., "params": ...} 两种键名 + """Return whether valid tool call.""" + tool_name = data.get("name") or data.get("tool") if tool_name and tool_name in self.VALID_TOOL_NAMES: - # 统一键名为 name / parameters + if "tool" in data: data["name"] = data.pop("tool") if "params" in data and "parameters" not in data: @@ -1124,43 +1040,33 @@ def _is_valid_tool_call(self, data: dict) -> bool: return False def _get_tools_description(self) -> str: - """生成工具描述文本""" - desc_parts = ["可用工具:"] + """Get tools description.""" + desc_parts = ["Available tools:"] for name, tool in self.tools.items(): params_desc = ", ".join([f"{k}: {v}" for k, v in tool["parameters"].items()]) desc_parts.append(f"- {name}: {tool['description']}") if params_desc: - desc_parts.append(f" 参数: {params_desc}") + desc_parts.append(f" Parameters: {params_desc}") return "\n".join(desc_parts) def plan_outline( self, progress_callback: Optional[Callable] = None ) -> ReportOutline: - """ - 规划报告大纲 - - 使用LLM分析模拟需求,规划报告的目录结构 - - Args: - progress_callback: 进度回调函数 - - Returns: - ReportOutline: 报告大纲 - """ - logger.info("开始规划报告大纲...") + """Plan Outline.""" + logger.info("Starting report outline planning...") if progress_callback: - progress_callback("planning", 0, "正在分析模拟需求...") + progress_callback("planning", 0, "Analyzing simulation requirement...") + - # 首先获取模拟上下文 context = self.zep_tools.get_simulation_context( graph_id=self.graph_id, simulation_requirement=self.simulation_requirement ) if progress_callback: - progress_callback("planning", 30, "正在生成报告大纲...") + progress_callback("planning", 30, "Generating report outline...") system_prompt = PLAN_SYSTEM_PROMPT user_prompt = PLAN_USER_PROMPT_TEMPLATE.format( @@ -1182,9 +1088,9 @@ def plan_outline( ) if progress_callback: - progress_callback("planning", 80, "正在解析大纲结构...") + progress_callback("planning", 80, "Parsing outline structure...") + - # 解析大纲 sections = [] for section_data in response.get("sections", []): sections.append(ReportSection( @@ -1193,27 +1099,27 @@ def plan_outline( )) outline = ReportOutline( - title=response.get("title", "模拟分析报告"), + title=response.get("title", "Simulation Forecast Report"), summary=response.get("summary", ""), sections=sections ) if progress_callback: - progress_callback("planning", 100, "大纲规划完成") + progress_callback("planning", 100, "Outline planning completed") - logger.info(f"大纲规划完成: {len(sections)} 个章节") + logger.info(f"Outline planning completed: {len(sections)} sections") return outline except Exception as e: - logger.error(f"大纲规划失败: {str(e)}") - # 返回默认大纲(3个章节,作为fallback) + logger.error(f"Outline planning failed: {str(e)}") + return ReportOutline( - title="未来预测报告", - summary="基于模拟预测的未来趋势与风险分析", + title="Future Forecast Report", + summary="Forecast analysis of future trends and risks based on the simulation", sections=[ - ReportSection(title="预测场景与核心发现"), - ReportSection(title="人群行为预测分析"), - ReportSection(title="趋势展望与风险提示") + ReportSection(title="Forecast Scenario and Core Findings"), + ReportSection(title="Predicted Group Reactions"), + ReportSection(title="Outlook, Risks, and Opportunities") ] ) @@ -1225,29 +1131,10 @@ def _generate_section_react( progress_callback: Optional[Callable] = None, section_index: int = 0 ) -> str: - """ - 使用ReACT模式生成单个章节内容 - - ReACT循环: - 1. Thought(思考)- 分析需要什么信息 - 2. Action(行动)- 调用工具获取信息 - 3. Observation(观察)- 分析工具返回结果 - 4. 重复直到信息足够或达到最大次数 - 5. Final Answer(最终回答)- 生成章节内容 - - Args: - section: 要生成的章节 - outline: 完整大纲 - previous_sections: 之前章节的内容(用于保持连贯性) - progress_callback: 进度回调 - section_index: 章节索引(用于日志记录) - - Returns: - 章节内容(Markdown格式) - """ - logger.info(f"ReACT生成章节: {section.title}") - - # 记录章节开始日志 + """Generate section react.""" + logger.info(f"Generating section with ReACT: {section.title}") + + if self.report_logger: self.report_logger.log_section_start(section.title, section_index) @@ -1259,16 +1146,16 @@ def _generate_section_react( tools_description=self._get_tools_description(), ) - # 构建用户prompt - 每个已完成章节各传入最大4000字 + if previous_sections: previous_parts = [] for sec in previous_sections: - # 每个章节最多4000字 + truncated = sec[:4000] + "..." if len(sec) > 4000 else sec previous_parts.append(truncated) previous_content = "\n\n---\n\n".join(previous_parts) else: - previous_content = "(这是第一个章节)" + previous_content = "(This is the first section.)" user_prompt = SECTION_USER_PROMPT_TEMPLATE.format( previous_content=previous_content, @@ -1280,77 +1167,78 @@ def _generate_section_react( {"role": "user", "content": user_prompt} ] - # ReACT循环 + tool_calls_count = 0 - max_iterations = 5 # 最大迭代轮数 - min_tool_calls = 3 # 最少工具调用次数 - conflict_retries = 0 # 工具调用与Final Answer同时出现的连续冲突次数 - used_tools = set() # 记录已调用过的工具名 + max_iterations = 5 + min_tool_calls = 3 + conflict_retries = 0 + used_tools = set() all_tools = {"insight_forge", "panorama_search", "quick_search", "interview_agents"} - # 报告上下文,用于InsightForge的子问题生成 - report_context = f"章节标题: {section.title}\n模拟需求: {self.simulation_requirement}" + + report_context = f"Section title: {section.title}\nSimulation requirement: {self.simulation_requirement}" for iteration in range(max_iterations): if progress_callback: progress_callback( "generating", int((iteration / max_iterations) * 100), - f"深度检索与撰写中 ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})" + f"Retrieving evidence and drafting ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})" ) - # 调用LLM + response = self.llm.chat( messages=messages, temperature=0.5, max_tokens=4096 ) - # 检查 LLM 返回是否为 None(API 异常或内容为空) + if response is None: - logger.warning(f"章节 {section.title} 第 {iteration + 1} 次迭代: LLM 返回 None") - # 如果还有迭代次数,添加消息并重试 + logger.warning(f"Section {section.title} iteration {iteration + 1}: LLM returned None") + if iteration < max_iterations - 1: - messages.append({"role": "assistant", "content": "(响应为空)"}) - messages.append({"role": "user", "content": "请继续生成内容。"}) + messages.append({"role": "assistant", "content": "(Empty response)"}) + messages.append({"role": "user", "content": "Please continue generating content."}) continue - # 最后一次迭代也返回 None,跳出循环进入强制收尾 + break - logger.debug(f"LLM响应: {response[:200]}...") + logger.debug(f"LLM response: {response[:200]}...") - # 解析一次,复用结果 + tool_calls = self._parse_tool_calls(response) has_tool_calls = bool(tool_calls) has_final_answer = "Final Answer:" in response - # ── 冲突处理:LLM 同时输出了工具调用和 Final Answer ── + if has_tool_calls and has_final_answer: conflict_retries += 1 logger.warning( - f"章节 {section.title} 第 {iteration+1} 轮: " - f"LLM 同时输出工具调用和 Final Answer(第 {conflict_retries} 次冲突)" + f"Section {section.title} round {iteration + 1}: " + f"LLM produced both a tool call and Final Answer (conflict {conflict_retries})" ) if conflict_retries <= 2: - # 前两次:丢弃本次响应,要求 LLM 重新回复 + messages.append({"role": "assistant", "content": response}) messages.append({ "role": "user", "content": ( - "【格式错误】你在一次回复中同时包含了工具调用和 Final Answer,这是不允许的。\n" - "每次回复只能做以下两件事之一:\n" - "- 调用一个工具(输出一个 块,不要写 Final Answer)\n" - "- 输出最终内容(以 'Final Answer:' 开头,不要包含 )\n" - "请重新回复,只做其中一件事。" + "[Format error] You included both a tool call and Final Answer in the " + "same reply, which is not allowed.\n" + "Each reply may do only one of the following:\n" + "- Call one tool by outputting a block with no Final Answer\n" + "- Output the final content starting with 'Final Answer:' with no \n" + "Please reply again and do only one of them." ), }) continue else: - # 第三次:降级处理,截断到第一个工具调用,强制执行 + logger.warning( - f"章节 {section.title}: 连续 {conflict_retries} 次冲突," - "降级为截断执行第一个工具调用" + f"Section {section.title}: {conflict_retries} consecutive conflicts; " + "falling back to executing only the first tool call" ) first_tool_end = response.find('') if first_tool_end != -1: @@ -1360,7 +1248,7 @@ def _generate_section_react( has_final_answer = False conflict_retries = 0 - # 记录 LLM 响应日志 + if self.report_logger: self.report_logger.log_llm_response( section_title=section.title, @@ -1371,13 +1259,16 @@ def _generate_section_react( has_final_answer=has_final_answer ) - # ── 情况1:LLM 输出了 Final Answer ── + if has_final_answer: - # 工具调用次数不足,拒绝并要求继续调工具 + if tool_calls_count < min_tool_calls: messages.append({"role": "assistant", "content": response}) unused_tools = all_tools - used_tools - unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else "" + unused_hint = ( + f" (Unused tools you should consider: {', '.join(unused_tools)})" + if unused_tools else "" + ) messages.append({ "role": "user", "content": REACT_INSUFFICIENT_TOOLS_MSG.format( @@ -1388,9 +1279,9 @@ def _generate_section_react( }) continue - # 正常结束 + final_answer = response.split("Final Answer:")[-1].strip() - logger.info(f"章节 {section.title} 生成完成(工具调用: {tool_calls_count}次)") + logger.info(f"Section {section.title} completed (tool calls: {tool_calls_count})") if self.report_logger: self.report_logger.log_section_content( @@ -1401,9 +1292,9 @@ def _generate_section_react( ) return final_answer - # ── 情况2:LLM 尝试调用工具 ── + if has_tool_calls: - # 工具额度已耗尽 → 明确告知,要求输出 Final Answer + if tool_calls_count >= self.MAX_TOOL_CALLS_PER_SECTION: messages.append({"role": "assistant", "content": response}) messages.append({ @@ -1415,10 +1306,12 @@ def _generate_section_react( }) continue - # 只执行第一个工具调用 + call = tool_calls[0] if len(tool_calls) > 1: - logger.info(f"LLM 尝试调用 {len(tool_calls)} 个工具,只执行第一个: {call['name']}") + logger.info( + f"LLM tried to call {len(tool_calls)} tools; executing only the first: {call['name']}" + ) if self.report_logger: self.report_logger.log_tool_call( @@ -1447,11 +1340,11 @@ def _generate_section_react( tool_calls_count += 1 used_tools.add(call['name']) - # 构建未使用工具提示 + unused_tools = all_tools - used_tools unused_hint = "" if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION: - unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="、".join(unused_tools)) + unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list=", ".join(unused_tools)) messages.append({"role": "assistant", "content": response}) messages.append({ @@ -1467,13 +1360,16 @@ def _generate_section_react( }) continue - # ── 情况3:既没有工具调用,也没有 Final Answer ── + messages.append({"role": "assistant", "content": response}) if tool_calls_count < min_tool_calls: - # 工具调用次数不足,推荐未用过的工具 + unused_tools = all_tools - used_tools - unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else "" + unused_hint = ( + f" (Unused tools you should consider: {', '.join(unused_tools)})" + if unused_tools else "" + ) messages.append({ "role": "user", @@ -1485,9 +1381,12 @@ def _generate_section_react( }) continue - # 工具调用已足够,LLM 输出了内容但没带 "Final Answer:" 前缀 - # 直接将这段内容作为最终答案,不再空转 - logger.info(f"章节 {section.title} 未检测到 'Final Answer:' 前缀,直接采纳LLM输出作为最终内容(工具调用: {tool_calls_count}次)") + + + logger.info( + f"Section {section.title} had no 'Final Answer:' prefix; " + f"accepting raw LLM output as final content (tool calls: {tool_calls_count})" + ) final_answer = response.strip() if self.report_logger: @@ -1499,8 +1398,8 @@ def _generate_section_react( ) return final_answer - # 达到最大迭代次数,强制生成内容 - logger.warning(f"章节 {section.title} 达到最大迭代次数,强制生成") + + logger.warning(f"Section {section.title} hit the max iterations; forcing final output") messages.append({"role": "user", "content": REACT_FORCE_FINAL_MSG}) response = self.llm.chat( @@ -1509,16 +1408,18 @@ def _generate_section_react( max_tokens=4096 ) - # 检查强制收尾时 LLM 返回是否为 None + if response is None: - logger.error(f"章节 {section.title} 强制收尾时 LLM 返回 None,使用默认错误提示") - final_answer = f"(本章节生成失败:LLM 返回空响应,请稍后重试)" + logger.error( + f"Section {section.title} force-finish returned None from the LLM; using fallback text" + ) + final_answer = "(This section failed to generate because the LLM returned an empty response. Please retry.)" elif "Final Answer:" in response: final_answer = response.split("Final Answer:")[-1].strip() else: final_answer = response - # 记录章节内容生成完成日志 + if self.report_logger: self.report_logger.log_section_content( section_title=section.title, @@ -1534,30 +1435,10 @@ def generate_report( progress_callback: Optional[Callable[[str, int, str], None]] = None, report_id: Optional[str] = None ) -> Report: - """ - 生成完整报告(分章节实时输出) - - 每个章节生成完成后立即保存到文件夹,不需要等待整个报告完成。 - 文件结构: - reports/{report_id}/ - meta.json - 报告元信息 - outline.json - 报告大纲 - progress.json - 生成进度 - section_01.md - 第1章节 - section_02.md - 第2章节 - ... - full_report.md - 完整报告 - - Args: - progress_callback: 进度回调函数 (stage, progress, message) - report_id: 报告ID(可选,如果不传则自动生成) - - Returns: - Report: 完整报告 - """ + """Generate report.""" import uuid - # 如果没有传入 report_id,则自动生成 + if not report_id: report_id = f"report_{uuid.uuid4().hex[:12]}" start_time = datetime.now() @@ -1571,14 +1452,14 @@ def generate_report( created_at=datetime.now().isoformat() ) - # 已完成的章节标题列表(用于进度追踪) + completed_section_titles = [] try: - # 初始化:创建报告文件夹并保存初始状态 + ReportManager._ensure_report_folder(report_id) - # 初始化日志记录器(结构化日志 agent_log.jsonl) + self.report_logger = ReportLogger(report_id) self.report_logger.log_start( simulation_id=self.simulation_id, @@ -1586,27 +1467,27 @@ def generate_report( simulation_requirement=self.simulation_requirement ) - # 初始化控制台日志记录器(console_log.txt) + self.console_logger = ReportConsoleLogger(report_id) ReportManager.update_progress( - report_id, "pending", 0, "初始化报告...", + report_id, "pending", 0, "Initializing report...", completed_sections=[] ) ReportManager.save_report(report) - # 阶段1: 规划大纲 + report.status = ReportStatus.PLANNING ReportManager.update_progress( - report_id, "planning", 5, "开始规划报告大纲...", + report_id, "planning", 5, "Starting report outline planning...", completed_sections=[] ) - # 记录规划开始日志 + self.report_logger.log_planning_start() if progress_callback: - progress_callback("planning", 0, "开始规划报告大纲...") + progress_callback("planning", 0, "Starting report outline planning...") outline = self.plan_outline( progress_callback=lambda stage, prog, msg: @@ -1614,33 +1495,33 @@ def generate_report( ) report.outline = outline - # 记录规划完成日志 + self.report_logger.log_planning_complete(outline.to_dict()) - # 保存大纲到文件 + ReportManager.save_outline(report_id, outline) ReportManager.update_progress( - report_id, "planning", 15, f"大纲规划完成,共{len(outline.sections)}个章节", + report_id, "planning", 15, f"Outline planning completed with {len(outline.sections)} sections", completed_sections=[] ) ReportManager.save_report(report) - logger.info(f"大纲已保存到文件: {report_id}/outline.json") + logger.info(f"Outline saved to file: {report_id}/outline.json") + - # 阶段2: 逐章节生成(分章节保存) report.status = ReportStatus.GENERATING total_sections = len(outline.sections) - generated_sections = [] # 保存内容用于上下文 + generated_sections = [] for i, section in enumerate(outline.sections): section_num = i + 1 base_progress = 20 + int((i / total_sections) * 70) - # 更新进度 + ReportManager.update_progress( report_id, "generating", base_progress, - f"正在生成章节: {section.title} ({section_num}/{total_sections})", + f"Generating section: {section.title} ({section_num}/{total_sections})", current_section=section.title, completed_sections=completed_section_titles ) @@ -1649,10 +1530,10 @@ def generate_report( progress_callback( "generating", base_progress, - f"正在生成章节: {section.title} ({section_num}/{total_sections})" + f"Generating section: {section.title} ({section_num}/{total_sections})" ) - # 生成主章节内容 + section_content = self._generate_section_react( section=section, outline=outline, @@ -1669,11 +1550,11 @@ def generate_report( section.content = section_content generated_sections.append(f"## {section.title}\n\n{section_content}") - # 保存章节 + ReportManager.save_section(report_id, section_num, section) completed_section_titles.append(section.title) - # 记录章节完成日志 + full_section_content = f"## {section.title}\n\n{section_content}" if self.report_logger: @@ -1683,54 +1564,54 @@ def generate_report( full_content=full_section_content.strip() ) - logger.info(f"章节已保存: {report_id}/section_{section_num:02d}.md") + logger.info(f"Section saved: {report_id}/section_{section_num:02d}.md") + - # 更新进度 ReportManager.update_progress( report_id, "generating", base_progress + int(70 / total_sections), - f"章节 {section.title} 已完成", + f"Section completed: {section.title}", current_section=None, completed_sections=completed_section_titles ) - # 阶段3: 组装完整报告 + if progress_callback: - progress_callback("generating", 95, "正在组装完整报告...") + progress_callback("generating", 95, "Assembling full report...") ReportManager.update_progress( - report_id, "generating", 95, "正在组装完整报告...", + report_id, "generating", 95, "Assembling full report...", completed_sections=completed_section_titles ) - # 使用ReportManager组装完整报告 + report.markdown_content = ReportManager.assemble_full_report(report_id, outline) report.status = ReportStatus.COMPLETED report.completed_at = datetime.now().isoformat() - # 计算总耗时 + total_time_seconds = (datetime.now() - start_time).total_seconds() - # 记录报告完成日志 + if self.report_logger: self.report_logger.log_report_complete( total_sections=total_sections, total_time_seconds=total_time_seconds ) - # 保存最终报告 + ReportManager.save_report(report) ReportManager.update_progress( - report_id, "completed", 100, "报告生成完成", + report_id, "completed", 100, "Report generation completed", completed_sections=completed_section_titles ) if progress_callback: - progress_callback("completed", 100, "报告生成完成") + progress_callback("completed", 100, "Report generation completed") + + logger.info(f"Report generation completed: {report_id}") - logger.info(f"报告生成完成: {report_id}") - # 关闭控制台日志记录器 if self.console_logger: self.console_logger.close() self.console_logger = None @@ -1738,25 +1619,25 @@ def generate_report( return report except Exception as e: - logger.error(f"报告生成失败: {str(e)}") + logger.error(f"Report generation failed: {str(e)}") report.status = ReportStatus.FAILED report.error = str(e) - # 记录错误日志 + if self.report_logger: self.report_logger.log_error(str(e), "failed") - # 保存失败状态 + try: ReportManager.save_report(report) ReportManager.update_progress( - report_id, "failed", -1, f"报告生成失败: {str(e)}", + report_id, "failed", -1, f"Report generation failed: {str(e)}", completed_sections=completed_section_titles ) except Exception: - pass # 忽略保存失败的错误 + pass + - # 关闭控制台日志记录器 if self.console_logger: self.console_logger.close() self.console_logger = None @@ -1768,60 +1649,45 @@ def chat( message: str, chat_history: List[Dict[str, str]] = None ) -> Dict[str, Any]: - """ - 与Report Agent对话 - - 在对话中Agent可以自主调用检索工具来回答问题 - - Args: - message: 用户消息 - chat_history: 对话历史 - - Returns: - { - "response": "Agent回复", - "tool_calls": [调用的工具列表], - "sources": [信息来源] - } - """ - logger.info(f"Report Agent对话: {message[:50]}...") + """Send a chat request.""" + logger.info(f"Report Agent chat: {message[:50]}...") chat_history = chat_history or [] - # 获取已生成的报告内容 + report_content = "" try: report = ReportManager.get_report_by_simulation(self.simulation_id) if report and report.markdown_content: - # 限制报告长度,避免上下文过长 + report_content = report.markdown_content[:15000] if len(report.markdown_content) > 15000: - report_content += "\n\n... [报告内容已截断] ..." + report_content += "\n\n... [report content truncated] ..." except Exception as e: - logger.warning(f"获取报告内容失败: {e}") + logger.warning(f"Failed to load report content: {e}") system_prompt = CHAT_SYSTEM_PROMPT_TEMPLATE.format( simulation_requirement=self.simulation_requirement, - report_content=report_content if report_content else "(暂无报告)", + report_content=report_content if report_content else "(No report yet)", tools_description=self._get_tools_description(), ) - # 构建消息 + messages = [{"role": "system", "content": system_prompt}] - # 添加历史对话 - for h in chat_history[-10:]: # 限制历史长度 + + for h in chat_history[-10:]: messages.append(h) - # 添加用户消息 + messages.append({ "role": "user", "content": message }) - # ReACT循环(简化版) + tool_calls_made = [] - max_iterations = 2 # 减少迭代轮数 + max_iterations = 2 for iteration in range(max_iterations): response = self.llm.chat( @@ -1829,11 +1695,11 @@ def chat( temperature=0.5 ) - # 解析工具调用 + tool_calls = self._parse_tool_calls(response) if not tool_calls: - # 没有工具调用,直接返回响应 + clean_response = re.sub(r'.*?', '', response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) @@ -1843,33 +1709,33 @@ def chat( "sources": [tc.get("parameters", {}).get("query", "") for tc in tool_calls_made] } - # 执行工具调用(限制数量) + tool_results = [] - for call in tool_calls[:1]: # 每轮最多执行1次工具调用 + for call in tool_calls[:1]: if len(tool_calls_made) >= self.MAX_TOOL_CALLS_PER_CHAT: break result = self._execute_tool(call["name"], call.get("parameters", {})) tool_results.append({ "tool": call["name"], - "result": result[:1500] # 限制结果长度 + "result": result[:1500] }) tool_calls_made.append(call) - # 将结果添加到消息 + messages.append({"role": "assistant", "content": response}) - observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results]) + observation = "\n".join([f"[{r['tool']} result]\n{r['result']}" for r in tool_results]) messages.append({ "role": "user", "content": observation + CHAT_OBSERVATION_SUFFIX }) - # 达到最大迭代,获取最终响应 + final_response = self.llm.chat( messages=messages, temperature=0.5 ) - # 清理响应 + clean_response = re.sub(r'.*?', '', final_response, flags=re.DOTALL) clean_response = re.sub(r'\[TOOL_CALL\].*?\)', '', clean_response) @@ -1881,98 +1747,66 @@ def chat( class ReportManager: - """ - 报告管理器 - - 负责报告的持久化存储和检索 - - 文件结构(分章节输出): - reports/ - {report_id}/ - meta.json - 报告元信息和状态 - outline.json - 报告大纲 - progress.json - 生成进度 - section_01.md - 第1章节 - section_02.md - 第2章节 - ... - full_report.md - 完整报告 - """ - - # 报告存储目录 + """Report Manager.""" + + REPORTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'reports') @classmethod def _ensure_reports_dir(cls): - """确保报告根目录存在""" + """Ensure reports dir.""" os.makedirs(cls.REPORTS_DIR, exist_ok=True) @classmethod def _get_report_folder(cls, report_id: str) -> str: - """获取报告文件夹路径""" + """Get report folder.""" return os.path.join(cls.REPORTS_DIR, report_id) @classmethod def _ensure_report_folder(cls, report_id: str) -> str: - """确保报告文件夹存在并返回路径""" + """Ensure report folder.""" folder = cls._get_report_folder(report_id) os.makedirs(folder, exist_ok=True) return folder @classmethod def _get_report_path(cls, report_id: str) -> str: - """获取报告元信息文件路径""" + """Get report path.""" return os.path.join(cls._get_report_folder(report_id), "meta.json") @classmethod def _get_report_markdown_path(cls, report_id: str) -> str: - """获取完整报告Markdown文件路径""" + """Get report markdown path.""" return os.path.join(cls._get_report_folder(report_id), "full_report.md") @classmethod def _get_outline_path(cls, report_id: str) -> str: - """获取大纲文件路径""" + """Get outline path.""" return os.path.join(cls._get_report_folder(report_id), "outline.json") @classmethod def _get_progress_path(cls, report_id: str) -> str: - """获取进度文件路径""" + """Get progress path.""" return os.path.join(cls._get_report_folder(report_id), "progress.json") @classmethod def _get_section_path(cls, report_id: str, section_index: int) -> str: - """获取章节Markdown文件路径""" + """Get section path.""" return os.path.join(cls._get_report_folder(report_id), f"section_{section_index:02d}.md") @classmethod def _get_agent_log_path(cls, report_id: str) -> str: - """获取 Agent 日志文件路径""" + """Get agent log path.""" return os.path.join(cls._get_report_folder(report_id), "agent_log.jsonl") @classmethod def _get_console_log_path(cls, report_id: str) -> str: - """获取控制台日志文件路径""" + """Get console log path.""" return os.path.join(cls._get_report_folder(report_id), "console_log.txt") @classmethod def get_console_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: - """ - 获取控制台日志内容 - - 这是报告生成过程中的控制台输出日志(INFO、WARNING等), - 与 agent_log.jsonl 的结构化日志不同。 - - Args: - report_id: 报告ID - from_line: 从第几行开始读取(用于增量获取,0 表示从头开始) - - Returns: - { - "logs": [日志行列表], - "total_lines": 总行数, - "from_line": 起始行号, - "has_more": 是否还有更多日志 - } - """ + """Get console log.""" log_path = cls._get_console_log_path(report_id) if not os.path.exists(log_path): @@ -1990,47 +1824,25 @@ def get_console_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: for i, line in enumerate(f): total_lines = i + 1 if i >= from_line: - # 保留原始日志行,去掉末尾换行符 + logs.append(line.rstrip('\n\r')) return { "logs": logs, "total_lines": total_lines, "from_line": from_line, - "has_more": False # 已读取到末尾 + "has_more": False } @classmethod def get_console_log_stream(cls, report_id: str) -> List[str]: - """ - 获取完整的控制台日志(一次性获取全部) - - Args: - report_id: 报告ID - - Returns: - 日志行列表 - """ + """Get console log stream.""" result = cls.get_console_log(report_id, from_line=0) return result["logs"] @classmethod def get_agent_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: - """ - 获取 Agent 日志内容 - - Args: - report_id: 报告ID - from_line: 从第几行开始读取(用于增量获取,0 表示从头开始) - - Returns: - { - "logs": [日志条目列表], - "total_lines": 总行数, - "from_line": 起始行号, - "has_more": 是否还有更多日志 - } - """ + """Get agent log.""" log_path = cls._get_agent_log_path(report_id) if not os.path.exists(log_path): @@ -2052,43 +1864,31 @@ def get_agent_log(cls, report_id: str, from_line: int = 0) -> Dict[str, Any]: log_entry = json.loads(line.strip()) logs.append(log_entry) except json.JSONDecodeError: - # 跳过解析失败的行 + continue return { "logs": logs, "total_lines": total_lines, "from_line": from_line, - "has_more": False # 已读取到末尾 + "has_more": False } @classmethod def get_agent_log_stream(cls, report_id: str) -> List[Dict[str, Any]]: - """ - 获取完整的 Agent 日志(用于一次性获取全部) - - Args: - report_id: 报告ID - - Returns: - 日志条目列表 - """ + """Get agent log stream.""" result = cls.get_agent_log(report_id, from_line=0) return result["logs"] @classmethod def save_outline(cls, report_id: str, outline: ReportOutline) -> None: - """ - 保存报告大纲 - - 在规划阶段完成后立即调用 - """ + """Save outline.""" cls._ensure_report_folder(report_id) with open(cls._get_outline_path(report_id), 'w', encoding='utf-8') as f: json.dump(outline.to_dict(), f, ensure_ascii=False, indent=2) - logger.info(f"大纲已保存: {report_id}") + logger.info(f"Outline saved: {report_id}") @classmethod def save_section( @@ -2097,51 +1897,27 @@ def save_section( section_index: int, section: ReportSection ) -> str: - """ - 保存单个章节 - - 在每个章节生成完成后立即调用,实现分章节输出 - - Args: - report_id: 报告ID - section_index: 章节索引(从1开始) - section: 章节对象 - - Returns: - 保存的文件路径 - """ + """Save section.""" cls._ensure_report_folder(report_id) - # 构建章节Markdown内容 - 清理可能存在的重复标题 + cleaned_content = cls._clean_section_content(section.content, section.title) md_content = f"## {section.title}\n\n" if cleaned_content: md_content += f"{cleaned_content}\n\n" - # 保存文件 + file_suffix = f"section_{section_index:02d}.md" file_path = os.path.join(cls._get_report_folder(report_id), file_suffix) with open(file_path, 'w', encoding='utf-8') as f: f.write(md_content) - logger.info(f"章节已保存: {report_id}/{file_suffix}") + logger.info(f"Section saved: {report_id}/{file_suffix}") return file_path @classmethod def _clean_section_content(cls, content: str, section_title: str) -> str: - """ - 清理章节内容 - - 1. 移除内容开头与章节标题重复的Markdown标题行 - 2. 将所有 ### 及以下级别的标题转换为粗体文本 - - Args: - content: 原始内容 - section_title: 章节标题 - - Returns: - 清理后的内容 - """ + """Clean section content.""" import re if not content: @@ -2155,26 +1931,26 @@ def _clean_section_content(cls, content: str, section_title: str) -> str: for i, line in enumerate(lines): stripped = line.strip() - # 检查是否是Markdown标题行 + heading_match = re.match(r'^(#{1,6})\s+(.+)$', stripped) if heading_match: level = len(heading_match.group(1)) title_text = heading_match.group(2).strip() - # 检查是否是与章节标题重复的标题(跳过前5行内的重复) + if i < 5: if title_text == section_title or title_text.replace(' ', '') == section_title.replace(' ', ''): skip_next_empty = True continue - # 将所有级别的标题(#, ##, ###, ####等)转换为粗体 - # 因为章节标题由系统添加,内容中不应有任何标题 + + cleaned_lines.append(f"**{title_text}**") - cleaned_lines.append("") # 添加空行 + cleaned_lines.append("") continue - # 如果上一行是被跳过的标题,且当前行为空,也跳过 + if skip_next_empty and stripped == '': skip_next_empty = False continue @@ -2182,14 +1958,14 @@ def _clean_section_content(cls, content: str, section_title: str) -> str: skip_next_empty = False cleaned_lines.append(line) - # 移除开头的空行 + while cleaned_lines and cleaned_lines[0].strip() == '': cleaned_lines.pop(0) - # 移除开头的分隔线 + while cleaned_lines and cleaned_lines[0].strip() in ['---', '***', '___']: cleaned_lines.pop(0) - # 同时移除分隔线后的空行 + while cleaned_lines and cleaned_lines[0].strip() == '': cleaned_lines.pop(0) @@ -2205,11 +1981,7 @@ def update_progress( current_section: str = None, completed_sections: List[str] = None ) -> None: - """ - 更新报告生成进度 - - 前端可以通过读取progress.json获取实时进度 - """ + """Update progress.""" cls._ensure_report_folder(report_id) progress_data = { @@ -2226,7 +1998,7 @@ def update_progress( @classmethod def get_progress(cls, report_id: str) -> Optional[Dict[str, Any]]: - """获取报告生成进度""" + """Get progress.""" path = cls._get_progress_path(report_id) if not os.path.exists(path): @@ -2237,11 +2009,7 @@ def get_progress(cls, report_id: str) -> Optional[Dict[str, Any]]: @classmethod def get_generated_sections(cls, report_id: str) -> List[Dict[str, Any]]: - """ - 获取已生成的章节列表 - - 返回所有已保存的章节文件信息 - """ + """Get generated sections.""" folder = cls._get_report_folder(report_id) if not os.path.exists(folder): @@ -2254,7 +2022,7 @@ def get_generated_sections(cls, report_id: str) -> List[Dict[str, Any]]: with open(file_path, 'r', encoding='utf-8') as f: content = f.read() - # 从文件名解析章节索引 + parts = filename.replace('.md', '').split('_') section_index = int(parts[1]) @@ -2268,57 +2036,40 @@ def get_generated_sections(cls, report_id: str) -> List[Dict[str, Any]]: @classmethod def assemble_full_report(cls, report_id: str, outline: ReportOutline) -> str: - """ - 组装完整报告 - - 从已保存的章节文件组装完整报告,并进行标题清理 - """ + """Assemble full report.""" folder = cls._get_report_folder(report_id) - # 构建报告头部 + md_content = f"# {outline.title}\n\n" md_content += f"> {outline.summary}\n\n" md_content += f"---\n\n" - # 按顺序读取所有章节文件 + sections = cls.get_generated_sections(report_id) for section_info in sections: md_content += section_info["content"] - # 后处理:清理整个报告的标题问题 + md_content = cls._post_process_report(md_content, outline) - # 保存完整报告 + full_path = cls._get_report_markdown_path(report_id) with open(full_path, 'w', encoding='utf-8') as f: f.write(md_content) - logger.info(f"完整报告已组装: {report_id}") + logger.info(f"Full report assembled: {report_id}") return md_content @classmethod def _post_process_report(cls, content: str, outline: ReportOutline) -> str: - """ - 后处理报告内容 - - 1. 移除重复的标题 - 2. 保留报告主标题(#)和章节标题(##),移除其他级别的标题(###, ####等) - 3. 清理多余的空行和分隔线 - - Args: - content: 原始报告内容 - outline: 报告大纲 - - Returns: - 处理后的内容 - """ + """Post Process Report.""" import re lines = content.split('\n') processed_lines = [] prev_was_heading = False - # 收集大纲中的所有章节标题 + section_titles = set() for section in outline.sections: section_titles.add(section.title) @@ -2328,14 +2079,14 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: line = lines[i] stripped = line.strip() - # 检查是否是标题行 + heading_match = re.match(r'^(#{1,6})\s+(.+)$', stripped) if heading_match: level = len(heading_match.group(1)) title = heading_match.group(2).strip() - # 检查是否是重复标题(在连续5行内出现相同内容的标题) + is_duplicate = False for j in range(max(0, len(processed_lines) - 5), len(processed_lines)): prev_line = processed_lines[j].strip() @@ -2347,43 +2098,43 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: break if is_duplicate: - # 跳过重复标题及其后的空行 + i += 1 while i < len(lines) and lines[i].strip() == '': i += 1 continue - # 标题层级处理: - # - # (level=1) 只保留报告主标题 - # - ## (level=2) 保留章节标题 - # - ### 及以下 (level>=3) 转换为粗体文本 + + + + if level == 1: if title == outline.title: - # 保留报告主标题 + processed_lines.append(line) prev_was_heading = True elif title in section_titles: - # 章节标题错误使用了#,修正为## + processed_lines.append(f"## {title}") prev_was_heading = True else: - # 其他一级标题转为粗体 + processed_lines.append(f"**{title}**") processed_lines.append("") prev_was_heading = False elif level == 2: if title in section_titles or title == outline.title: - # 保留章节标题 + processed_lines.append(line) prev_was_heading = True else: - # 非章节的二级标题转为粗体 + processed_lines.append(f"**{title}**") processed_lines.append("") prev_was_heading = False else: - # ### 及以下级别的标题转换为粗体文本 + processed_lines.append(f"**{title}**") processed_lines.append("") prev_was_heading = False @@ -2392,12 +2143,12 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: continue elif stripped == '---' and prev_was_heading: - # 跳过标题后紧跟的分隔线 + i += 1 continue elif stripped == '' and prev_was_heading: - # 标题后只保留一个空行 + if processed_lines and processed_lines[-1].strip() != '': processed_lines.append(line) prev_was_heading = False @@ -2408,7 +2159,7 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: i += 1 - # 清理连续的多个空行(保留最多2个) + result_lines = [] empty_count = 0 for line in processed_lines: @@ -2424,31 +2175,31 @@ def _post_process_report(cls, content: str, outline: ReportOutline) -> str: @classmethod def save_report(cls, report: Report) -> None: - """保存报告元信息和完整报告""" + """Save report.""" cls._ensure_report_folder(report.report_id) - # 保存元信息JSON + with open(cls._get_report_path(report.report_id), 'w', encoding='utf-8') as f: json.dump(report.to_dict(), f, ensure_ascii=False, indent=2) - # 保存大纲 + if report.outline: cls.save_outline(report.report_id, report.outline) - # 保存完整Markdown报告 + if report.markdown_content: with open(cls._get_report_markdown_path(report.report_id), 'w', encoding='utf-8') as f: f.write(report.markdown_content) - logger.info(f"报告已保存: {report.report_id}") + logger.info(f"Report saved: {report.report_id}") @classmethod def get_report(cls, report_id: str) -> Optional[Report]: - """获取报告""" + """Get report.""" path = cls._get_report_path(report_id) if not os.path.exists(path): - # 兼容旧格式:检查直接存储在reports目录下的文件 + old_path = os.path.join(cls.REPORTS_DIR, f"{report_id}.json") if os.path.exists(old_path): path = old_path @@ -2458,7 +2209,7 @@ def get_report(cls, report_id: str) -> Optional[Report]: with open(path, 'r', encoding='utf-8') as f: data = json.load(f) - # 重建Report对象 + outline = None if data.get('outline'): outline_data = data['outline'] @@ -2474,7 +2225,7 @@ def get_report(cls, report_id: str) -> Optional[Report]: sections=sections ) - # 如果markdown_content为空,尝试从full_report.md读取 + markdown_content = data.get('markdown_content', '') if not markdown_content: full_report_path = cls._get_report_markdown_path(report_id) @@ -2497,17 +2248,17 @@ def get_report(cls, report_id: str) -> Optional[Report]: @classmethod def get_report_by_simulation(cls, simulation_id: str) -> Optional[Report]: - """根据模拟ID获取报告""" + """Get report by simulation.""" cls._ensure_reports_dir() for item in os.listdir(cls.REPORTS_DIR): item_path = os.path.join(cls.REPORTS_DIR, item) - # 新格式:文件夹 + if os.path.isdir(item_path): report = cls.get_report(item) if report and report.simulation_id == simulation_id: return report - # 兼容旧格式:JSON文件 + elif item.endswith('.json'): report_id = item[:-5] report = cls.get_report(report_id) @@ -2518,19 +2269,19 @@ def get_report_by_simulation(cls, simulation_id: str) -> Optional[Report]: @classmethod def list_reports(cls, simulation_id: Optional[str] = None, limit: int = 50) -> List[Report]: - """列出报告""" + """List reports.""" cls._ensure_reports_dir() reports = [] for item in os.listdir(cls.REPORTS_DIR): item_path = os.path.join(cls.REPORTS_DIR, item) - # 新格式:文件夹 + if os.path.isdir(item_path): report = cls.get_report(item) if report: if simulation_id is None or report.simulation_id == simulation_id: reports.append(report) - # 兼容旧格式:JSON文件 + elif item.endswith('.json'): report_id = item[:-5] report = cls.get_report(report_id) @@ -2538,25 +2289,25 @@ def list_reports(cls, simulation_id: Optional[str] = None, limit: int = 50) -> L if simulation_id is None or report.simulation_id == simulation_id: reports.append(report) - # 按创建时间倒序 + reports.sort(key=lambda r: r.created_at, reverse=True) return reports[:limit] @classmethod def delete_report(cls, report_id: str) -> bool: - """删除报告(整个文件夹)""" + """Delete report.""" import shutil folder_path = cls._get_report_folder(report_id) - # 新格式:删除整个文件夹 + if os.path.exists(folder_path) and os.path.isdir(folder_path): shutil.rmtree(folder_path) - logger.info(f"报告文件夹已删除: {report_id}") + logger.info(f"Report folder deleted: {report_id}") return True - # 兼容旧格式:删除单独的文件 + deleted = False old_json_path = os.path.join(cls.REPORTS_DIR, f"{report_id}.json") old_md_path = os.path.join(cls.REPORTS_DIR, f"{report_id}.md") diff --git a/backend/app/services/simulation_config_generator.py b/backend/app/services/simulation_config_generator.py index cc362508b..5feddb21c 100644 --- a/backend/app/services/simulation_config_generator.py +++ b/backend/app/services/simulation_config_generator.py @@ -1,17 +1,8 @@ -""" -模拟配置智能生成器 -使用LLM根据模拟需求、文档内容、图谱信息自动生成细致的模拟参数 -实现全程自动化,无需人工设置参数 - -采用分步生成策略,避免一次性生成过长内容导致失败: -1. 生成时间配置 -2. 生成事件配置 -3. 分批生成Agent配置 -4. 生成平台配置 -""" +"""Intelligent simulation configuration generator.""" import json import math +import re from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass, field, asdict from datetime import datetime @@ -24,156 +15,156 @@ logger = get_logger('mirofish.simulation_config') -# 中国作息时间配置(北京时间) + CHINA_TIMEZONE_CONFIG = { - # 深夜时段(几乎无人活动) + "dead_hours": [0, 1, 2, 3, 4, 5], - # 早间时段(逐渐醒来) + "morning_hours": [6, 7, 8], - # 工作时段 + "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], - # 晚间高峰(最活跃) + "peak_hours": [19, 20, 21, 22], - # 夜间时段(活跃度下降) + "night_hours": [23], - # 活跃度系数 + "activity_multipliers": { - "dead": 0.05, # 凌晨几乎无人 - "morning": 0.4, # 早间逐渐活跃 - "work": 0.7, # 工作时段中等 - "peak": 1.5, # 晚间高峰 - "night": 0.5 # 深夜下降 + "dead": 0.05, + "morning": 0.4, + "work": 0.7, + "peak": 1.5, + "night": 0.5 } } @dataclass class AgentActivityConfig: - """单个Agent的活动配置""" + """Agent Activity Config.""" agent_id: int entity_uuid: str entity_name: str entity_type: str - # 活跃度配置 (0.0-1.0) - activity_level: float = 0.5 # 整体活跃度 - # 发言频率(每小时预期发言次数) + activity_level: float = 0.5 + + posts_per_hour: float = 1.0 comments_per_hour: float = 2.0 - # 活跃时间段(24小时制,0-23) + active_hours: List[int] = field(default_factory=lambda: list(range(8, 23))) - # 响应速度(对热点事件的反应延迟,单位:模拟分钟) + response_delay_min: int = 5 response_delay_max: int = 60 - # 情感倾向 (-1.0到1.0,负面到正面) + sentiment_bias: float = 0.0 - # 立场(对特定话题的态度) + stance: str = "neutral" # supportive, opposing, neutral, observer - # 影响力权重(决定其发言被其他Agent看到的概率) + influence_weight: float = 1.0 @dataclass class TimeSimulationConfig: - """时间模拟配置(基于中国人作息习惯)""" - # 模拟总时长(模拟小时数) - total_simulation_hours: int = 72 # 默认模拟72小时(3天) + """Time Simulation Config.""" + + total_simulation_hours: int = 72 + - # 每轮代表的时间(模拟分钟)- 默认60分钟(1小时),加快时间流速 minutes_per_round: int = 60 - # 每小时激活的Agent数量范围 + agents_per_hour_min: int = 5 agents_per_hour_max: int = 20 - # 高峰时段(晚间19-22点,中国人最活跃的时间) + peak_hours: List[int] = field(default_factory=lambda: [19, 20, 21, 22]) peak_activity_multiplier: float = 1.5 - # 低谷时段(凌晨0-5点,几乎无人活动) + off_peak_hours: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5]) - off_peak_activity_multiplier: float = 0.05 # 凌晨活跃度极低 + off_peak_activity_multiplier: float = 0.05 + - # 早间时段 morning_hours: List[int] = field(default_factory=lambda: [6, 7, 8]) morning_activity_multiplier: float = 0.4 - # 工作时段 + work_hours: List[int] = field(default_factory=lambda: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18]) work_activity_multiplier: float = 0.7 @dataclass class EventConfig: - """事件配置""" - # 初始事件(模拟开始时的触发事件) + """Event Config.""" + initial_posts: List[Dict[str, Any]] = field(default_factory=list) - # 定时事件(在特定时间触发的事件) + scheduled_events: List[Dict[str, Any]] = field(default_factory=list) - # 热点话题关键词 + hot_topics: List[str] = field(default_factory=list) - # 舆论引导方向 + narrative_direction: str = "" @dataclass class PlatformConfig: - """平台特定配置""" + """Platform Config.""" platform: str # twitter or reddit - # 推荐算法权重 - recency_weight: float = 0.4 # 时间新鲜度 - popularity_weight: float = 0.3 # 热度 - relevance_weight: float = 0.3 # 相关性 - # 病毒传播阈值(达到多少互动后触发扩散) + recency_weight: float = 0.4 + popularity_weight: float = 0.3 + relevance_weight: float = 0.3 + + viral_threshold: int = 10 - # 回声室效应强度(相似观点聚集程度) + echo_chamber_strength: float = 0.5 @dataclass class SimulationParameters: - """完整的模拟参数配置""" - # 基础信息 + """Simulation Parameters.""" + simulation_id: str project_id: str graph_id: str simulation_requirement: str - # 时间配置 + time_config: TimeSimulationConfig = field(default_factory=TimeSimulationConfig) - # Agent配置列表 + agent_configs: List[AgentActivityConfig] = field(default_factory=list) - # 事件配置 + event_config: EventConfig = field(default_factory=EventConfig) - # 平台配置 + twitter_config: Optional[PlatformConfig] = None reddit_config: Optional[PlatformConfig] = None - # LLM配置 + llm_model: str = "" llm_base_url: str = "" - # 生成元数据 + generated_at: str = field(default_factory=lambda: datetime.now().isoformat()) - generation_reasoning: str = "" # LLM的推理说明 + generation_reasoning: str = "" def to_dict(self) -> Dict[str, Any]: - """转换为字典""" + """Convert the object to a dictionary.""" time_dict = asdict(self.time_config) return { "simulation_id": self.simulation_id, @@ -192,34 +183,24 @@ def to_dict(self) -> Dict[str, Any]: } def to_json(self, indent: int = 2) -> str: - """转换为JSON字符串""" + """Convert the object to a JSON string.""" return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent) class SimulationConfigGenerator: - """ - 模拟配置智能生成器 + """Simulation Config Generator.""" - 使用LLM分析模拟需求、文档内容、图谱实体信息, - 自动生成最佳的模拟参数配置 - 采用分步生成策略: - 1. 生成时间配置和事件配置(轻量级) - 2. 分批生成Agent配置(每批10-20个) - 3. 生成平台配置 - """ - - # 上下文最大字符数 MAX_CONTEXT_LENGTH = 50000 - # 每批生成的Agent数量 + AGENTS_PER_BATCH = 15 - # 各步骤的上下文截断长度(字符数) - TIME_CONFIG_CONTEXT_LENGTH = 10000 # 时间配置 - EVENT_CONFIG_CONTEXT_LENGTH = 8000 # 事件配置 - ENTITY_SUMMARY_LENGTH = 300 # 实体摘要 - AGENT_SUMMARY_LENGTH = 300 # Agent配置中的实体摘要 - ENTITIES_PER_TYPE_DISPLAY = 20 # 每类实体显示数量 + + TIME_CONFIG_CONTEXT_LENGTH = 10000 + EVENT_CONFIG_CONTEXT_LENGTH = 8000 + ENTITY_SUMMARY_LENGTH = 300 + AGENT_SUMMARY_LENGTH = 300 + ENTITIES_PER_TYPE_DISPLAY = 20 def __init__( self, @@ -232,7 +213,7 @@ def __init__( self.model_name = model_name or Config.LLM_MODEL_NAME if not self.api_key: - raise ValueError("LLM_API_KEY 未配置") + raise ValueError("LLM_API_KEY is not configured") self.client = OpenAI( api_key=self.api_key, @@ -251,28 +232,15 @@ def generate_config( enable_reddit: bool = True, progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> SimulationParameters: - """ - 智能生成完整的模拟配置(分步生成) - - Args: - simulation_id: 模拟ID - project_id: 项目ID - graph_id: 图谱ID - simulation_requirement: 模拟需求描述 - document_text: 原始文档内容 - entities: 过滤后的实体列表 - enable_twitter: 是否启用Twitter - enable_reddit: 是否启用Reddit - progress_callback: 进度回调函数(current_step, total_steps, message) - - Returns: - SimulationParameters: 完整的模拟参数 - """ - logger.info(f"开始智能生成模拟配置: simulation_id={simulation_id}, 实体数={len(entities)}") + """Generate config.""" + logger.info( + f"Starting intelligent simulation config generation: " + f"simulation_id={simulation_id}, entity_count={len(entities)}" + ) + - # 计算总步骤数 num_batches = math.ceil(len(entities) / self.AGENTS_PER_BATCH) - total_steps = 3 + num_batches # 时间配置 + 事件配置 + N批Agent + 平台配置 + total_steps = 3 + num_batches current_step = 0 def report_progress(step: int, message: str): @@ -282,7 +250,7 @@ def report_progress(step: int, message: str): progress_callback(step, total_steps, message) logger.info(f"[{step}/{total_steps}] {message}") - # 1. 构建基础上下文信息 + context = self._build_context( simulation_requirement=simulation_requirement, document_text=document_text, @@ -291,20 +259,20 @@ def report_progress(step: int, message: str): reasoning_parts = [] - # ========== 步骤1: 生成时间配置 ========== - report_progress(1, "生成时间配置...") + + report_progress(1, "Generating time configuration...") num_entities = len(entities) time_config_result = self._generate_time_config(context, num_entities) time_config = self._parse_time_config(time_config_result, num_entities) - reasoning_parts.append(f"时间配置: {time_config_result.get('reasoning', '成功')}") + reasoning_parts.append(f"Time config: {time_config_result.get('reasoning', 'success')}") + - # ========== 步骤2: 生成事件配置 ========== - report_progress(2, "生成事件配置和热点话题...") + report_progress(2, "Generating event configuration and hot topics...") event_config_result = self._generate_event_config(context, simulation_requirement, entities) event_config = self._parse_event_config(event_config_result) - reasoning_parts.append(f"事件配置: {event_config_result.get('reasoning', '成功')}") + reasoning_parts.append(f"Event config: {event_config_result.get('reasoning', 'success')}") + - # ========== 步骤3-N: 分批生成Agent配置 ========== all_agent_configs = [] for batch_idx in range(num_batches): start_idx = batch_idx * self.AGENTS_PER_BATCH @@ -313,7 +281,7 @@ def report_progress(step: int, message: str): report_progress( 3 + batch_idx, - f"生成Agent配置 ({start_idx + 1}-{end_idx}/{len(entities)})..." + f"Generating agent configs ({start_idx + 1}-{end_idx}/{len(entities)})..." ) batch_configs = self._generate_agent_configs_batch( @@ -324,16 +292,16 @@ def report_progress(step: int, message: str): ) all_agent_configs.extend(batch_configs) - reasoning_parts.append(f"Agent配置: 成功生成 {len(all_agent_configs)} 个") + reasoning_parts.append(f"Agent configs: generated {len(all_agent_configs)} successfully") + - # ========== 为初始帖子分配发布者 Agent ========== - logger.info("为初始帖子分配合适的发布者 Agent...") + logger.info("Assigning appropriate poster agents to initial posts...") event_config = self._assign_initial_post_agents(event_config, all_agent_configs) assigned_count = len([p for p in event_config.initial_posts if p.get("poster_agent_id") is not None]) - reasoning_parts.append(f"初始帖子分配: {assigned_count} 个帖子已分配发布者") + reasoning_parts.append(f"Initial post assignment: {assigned_count} posts assigned to poster agents") - # ========== 最后一步: 生成平台配置 ========== - report_progress(total_steps, "生成平台配置...") + + report_progress(total_steps, "Generating platform configuration...") twitter_config = None reddit_config = None @@ -357,7 +325,7 @@ def report_progress(step: int, message: str): echo_chamber_strength=0.6 ) - # 构建最终参数 + params = SimulationParameters( simulation_id=simulation_id, project_id=project_id, @@ -373,7 +341,10 @@ def report_progress(step: int, message: str): generation_reasoning=" | ".join(reasoning_parts) ) - logger.info(f"模拟配置生成完成: {len(params.agent_configs)} 个Agent配置") + logger.info( + f"Simulation configuration generation completed: " + f"{len(params.agent_configs)} agent configs" + ) return params @@ -383,33 +354,33 @@ def _build_context( document_text: str, entities: List[EntityNode] ) -> str: - """构建LLM上下文,截断到最大长度""" + """Build context.""" + - # 实体摘要 entity_summary = self._summarize_entities(entities) - # 构建上下文 + context_parts = [ - f"## 模拟需求\n{simulation_requirement}", - f"\n## 实体信息 ({len(entities)}个)\n{entity_summary}", + f"## Simulation requirement\n{simulation_requirement}", + f"\n## Entity information ({len(entities)})\n{entity_summary}", ] current_length = sum(len(p) for p in context_parts) - remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 # 留500字符余量 + remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 if remaining_length > 0 and document_text: doc_text = document_text[:remaining_length] if len(document_text) > remaining_length: - doc_text += "\n...(文档已截断)" - context_parts.append(f"\n## 原始文档内容\n{doc_text}") + doc_text += "\n...(document truncated)" + context_parts.append(f"\n## Source document content\n{doc_text}") return "\n".join(context_parts) def _summarize_entities(self, entities: List[EntityNode]) -> str: - """生成实体摘要""" + """Summarize Entities.""" lines = [] - # 按类型分组 + by_type: Dict[str, List[EntityNode]] = {} for e in entities: t = e.get_entity_type() or "Unknown" @@ -418,20 +389,20 @@ def _summarize_entities(self, entities: List[EntityNode]) -> str: by_type[t].append(e) for entity_type, type_entities in by_type.items(): - lines.append(f"\n### {entity_type} ({len(type_entities)}个)") - # 使用配置的显示数量和摘要长度 + lines.append(f"\n### {entity_type} ({len(type_entities)})") + display_count = self.ENTITIES_PER_TYPE_DISPLAY summary_len = self.ENTITY_SUMMARY_LENGTH for e in type_entities[:display_count]: summary_preview = (e.summary[:summary_len] + "...") if len(e.summary) > summary_len else e.summary lines.append(f"- {e.name}: {summary_preview}") if len(type_entities) > display_count: - lines.append(f" ... 还有 {len(type_entities) - display_count} 个") + lines.append(f" ... {len(type_entities) - display_count} more") return "\n".join(lines) def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]: - """带重试的LLM调用,包含JSON修复逻辑""" + """Call llm with retry.""" import re max_attempts = 3 @@ -446,25 +417,25 @@ def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any {"role": "user", "content": prompt} ], response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 + temperature=0.7 - (attempt * 0.1) + ) content = response.choices[0].message.content finish_reason = response.choices[0].finish_reason - # 检查是否被截断 + if finish_reason == 'length': - logger.warning(f"LLM输出被截断 (attempt {attempt+1})") + logger.warning(f"LLM output was truncated (attempt {attempt+1})") content = self._fix_truncated_json(content) - # 尝试解析JSON + try: return json.loads(content) except json.JSONDecodeError as e: - logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}") + logger.warning(f"JSON parsing failed (attempt {attempt+1}): {str(e)[:80]}") + - # 尝试修复JSON fixed = self._try_fix_config_json(content) if fixed: return fixed @@ -472,44 +443,44 @@ def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any last_error = e except Exception as e: - logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") + logger.warning(f"LLM call failed (attempt {attempt+1}): {str(e)[:80]}") last_error = e import time time.sleep(2 * (attempt + 1)) - raise last_error or Exception("LLM调用失败") + raise last_error or Exception("LLM call failed") def _fix_truncated_json(self, content: str) -> str: - """修复被截断的JSON""" + """Fix truncated json.""" content = content.strip() - # 计算未闭合的括号 + open_braces = content.count('{') - content.count('}') open_brackets = content.count('[') - content.count(']') - # 检查是否有未闭合的字符串 + if content and content[-1] not in '",}]': content += '"' - # 闭合括号 + content += ']' * open_brackets content += '}' * open_braces return content def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]: - """尝试修复配置JSON""" + """Try fix config json.""" import re - # 修复被截断的情况 + content = self._fix_truncated_json(content) - # 提取JSON部分 + json_match = re.search(r'\{[\s\S]*\}', content) if json_match: json_str = json_match.group() - # 移除字符串中的换行符 + def fix_string(match): s = match.group(0) s = s.replace('\n', ' ').replace('\r', ' ') @@ -521,7 +492,7 @@ def fix_string(match): try: return json.loads(json_str) except: - # 尝试移除所有控制字符 + json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) json_str = re.sub(r'\s+', ' ', json_str) try: @@ -532,35 +503,34 @@ def fix_string(match): return None def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]: - """生成时间配置""" - # 使用配置的上下文截断长度 + """Generate time config.""" + context_truncated = context[:self.TIME_CONFIG_CONTEXT_LENGTH] - # 计算最大允许值(80%的agent数) + max_agents_allowed = max(1, int(num_entities * 0.9)) - prompt = f"""基于以下模拟需求,生成时间模拟配置。 + prompt = f"""Generate a time-simulation configuration for the following scenario. {context_truncated} -## 任务 -请生成时间配置JSON。 +## Task +Return time-configuration JSON. -### 基本原则(仅供参考,需根据具体事件和参与群体灵活调整): -- 用户群体为中国人,需符合北京时间作息习惯 -- 凌晨0-5点几乎无人活动(活跃度系数0.05) -- 早上6-8点逐渐活跃(活跃度系数0.4) -- 工作时间9-18点中等活跃(活跃度系数0.7) -- 晚间19-22点是高峰期(活跃度系数1.5) -- 23点后活跃度下降(活跃度系数0.5) -- 一般规律:凌晨低活跃、早间渐增、工作时段中等、晚间高峰 -- **重要**:以下示例值仅供参考,你需要根据事件性质、参与群体特点来调整具体时段 - - 例如:学生群体高峰可能是21-23点;媒体全天活跃;官方机构只在工作时间 - - 例如:突发热点可能导致深夜也有讨论,off_peak_hours 可适当缩短 +### Baseline heuristics +- Assume participant behavior broadly follows a China Standard Time daily rhythm unless the scenario strongly suggests otherwise +- Activity is minimal from 00:00-05:00 (activity multiplier about 0.05) +- Activity increases from 06:00-08:00 (about 0.4) +- Work hours from 09:00-18:00 are moderately active (about 0.7) +- Evening from 19:00-22:00 is typically the peak (about 1.5) +- Activity declines after 23:00 (about 0.5) +- These are heuristics only; adapt the exact schedule to the event and participant mix + - Example: students may peak around 21:00-23:00, media may be active all day, official institutions may stay within work hours + - Example: breaking-news events may still attract late-night discussion, so off_peak_hours can be shorter -### 返回JSON格式(不要markdown) +### Return JSON only, no Markdown -示例: +Example: {{ "total_simulation_hours": 72, "minutes_per_round": 60, @@ -570,76 +540,162 @@ def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, An "off_peak_hours": [0, 1, 2, 3, 4, 5], "morning_hours": [6, 7, 8], "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], - "reasoning": "针对该事件的时间配置说明" + "reasoning": "Brief explanation for this schedule" }} -字段说明: -- total_simulation_hours (int): 模拟总时长,24-168小时,突发事件短、持续话题长 -- minutes_per_round (int): 每轮时长,30-120分钟,建议60分钟 -- agents_per_hour_min (int): 每小时最少激活Agent数(取值范围: 1-{max_agents_allowed}) -- agents_per_hour_max (int): 每小时最多激活Agent数(取值范围: 1-{max_agents_allowed}) -- peak_hours (int数组): 高峰时段,根据事件参与群体调整 -- off_peak_hours (int数组): 低谷时段,通常深夜凌晨 -- morning_hours (int数组): 早间时段 -- work_hours (int数组): 工作时段 -- reasoning (string): 简要说明为什么这样配置""" +Field notes: +- total_simulation_hours (int): total duration, usually 24-168 hours +- minutes_per_round (int): duration per round, usually 30-120 minutes, commonly 60 +- agents_per_hour_min (int): minimum active agents per hour, range 1-{max_agents_allowed} +- agents_per_hour_max (int): maximum active agents per hour, range 1-{max_agents_allowed} +- peak_hours (int array): peak activity windows +- off_peak_hours (int array): low-activity windows, usually late night and early morning +- morning_hours (int array): early-day activity window +- work_hours (int array): workday activity window +- reasoning (string): brief explanation for the configuration""" - system_prompt = "你是社交媒体模拟专家。返回纯JSON格式,时间配置需符合中国人作息习惯。" + system_prompt = ( + "You are an expert in social-media simulation design. " + "Return pure JSON only. All free-text fields must be written in English." + ) try: return self._call_llm_with_retry(prompt, system_prompt) except Exception as e: - logger.warning(f"时间配置LLM生成失败: {e}, 使用默认配置") + logger.warning(f"Time-config LLM generation failed: {e}, using default config") return self._get_default_time_config(num_entities) def _get_default_time_config(self, num_entities: int) -> Dict[str, Any]: - """获取默认时间配置(中国人作息)""" + """Get default time config.""" return { "total_simulation_hours": 72, - "minutes_per_round": 60, # 每轮1小时,加快时间流速 + "minutes_per_round": 60, "agents_per_hour_min": max(1, num_entities // 15), "agents_per_hour_max": max(5, num_entities // 5), "peak_hours": [19, 20, 21, 22], "off_peak_hours": [0, 1, 2, 3, 4, 5], "morning_hours": [6, 7, 8], "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], - "reasoning": "使用默认中国人作息配置(每轮1小时)" + "reasoning": "Using the default China Standard Time activity rhythm with 1-hour rounds" } def _parse_time_config(self, result: Dict[str, Any], num_entities: int) -> TimeSimulationConfig: - """解析时间配置结果,并验证agents_per_hour值不超过总agent数""" - # 获取原始值 + """Parse time config.""" + agents_per_hour_min = result.get("agents_per_hour_min", max(1, num_entities // 15)) agents_per_hour_max = result.get("agents_per_hour_max", max(5, num_entities // 5)) - # 验证并修正:确保不超过总agent数 + if agents_per_hour_min > num_entities: - logger.warning(f"agents_per_hour_min ({agents_per_hour_min}) 超过总Agent数 ({num_entities}),已修正") + logger.warning( + f"agents_per_hour_min ({agents_per_hour_min}) exceeds total agent count " + f"({num_entities}); corrected automatically" + ) agents_per_hour_min = max(1, num_entities // 10) if agents_per_hour_max > num_entities: - logger.warning(f"agents_per_hour_max ({agents_per_hour_max}) 超过总Agent数 ({num_entities}),已修正") + logger.warning( + f"agents_per_hour_max ({agents_per_hour_max}) exceeds total agent count " + f"({num_entities}); corrected automatically" + ) agents_per_hour_max = max(agents_per_hour_min + 1, num_entities // 2) - # 确保 min < max + if agents_per_hour_min >= agents_per_hour_max: agents_per_hour_min = max(1, agents_per_hour_max // 2) - logger.warning(f"agents_per_hour_min >= max,已修正为 {agents_per_hour_min}") + logger.warning(f"agents_per_hour_min >= max; corrected to {agents_per_hour_min}") return TimeSimulationConfig( total_simulation_hours=result.get("total_simulation_hours", 72), - minutes_per_round=result.get("minutes_per_round", 60), # 默认每轮1小时 + minutes_per_round=result.get("minutes_per_round", 60), agents_per_hour_min=agents_per_hour_min, agents_per_hour_max=agents_per_hour_max, peak_hours=result.get("peak_hours", [19, 20, 21, 22]), off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]), - off_peak_activity_multiplier=0.05, # 凌晨几乎无人 + off_peak_activity_multiplier=0.05, morning_hours=result.get("morning_hours", [6, 7, 8]), morning_activity_multiplier=0.4, work_hours=result.get("work_hours", list(range(9, 19))), work_activity_multiplier=0.7, peak_activity_multiplier=1.5 ) + + @staticmethod + def _contains_cjk(value: Any) -> bool: + if isinstance(value, str): + return bool(re.search(r'[\u4e00-\u9fff]', value)) + if isinstance(value, list): + return any(SimulationConfigGenerator._contains_cjk(item) for item in value) + if isinstance(value, dict): + return any(SimulationConfigGenerator._contains_cjk(item) for item in value.values()) + return False + + def _ensure_event_config_english(self, result: Dict[str, Any]) -> Dict[str, Any]: + payload = { + "hot_topics": result.get("hot_topics", []), + "narrative_direction": result.get("narrative_direction", ""), + "initial_posts": [ + { + "content": post.get("content", ""), + "poster_type": post.get("poster_type", ""), + } + for post in result.get("initial_posts", []) + ], + "reasoning": result.get("reasoning", ""), + } + + text_fields = [ + payload["hot_topics"], + payload["narrative_direction"], + [post.get("content", "") for post in payload["initial_posts"]], + payload["reasoning"], + ] + if not any(self._contains_cjk(field) for field in text_fields): + return result + + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=[ + { + "role": "system", + "content": ( + "Translate the provided event-configuration text into natural English. " + "Return valid JSON only. Preserve the JSON shape. " + "Do not change any poster_type values. " + "Translate only hot_topics, narrative_direction, initial_posts[].content, and reasoning." + ), + }, + { + "role": "user", + "content": json.dumps(payload, ensure_ascii=False), + }, + ], + response_format={"type": "json_object"}, + temperature=0.1, + ) + + translated = json.loads(response.choices[0].message.content) + result["hot_topics"] = translated.get("hot_topics", result.get("hot_topics", [])) + result["narrative_direction"] = translated.get("narrative_direction", result.get("narrative_direction", "")) + result["reasoning"] = translated.get("reasoning", result.get("reasoning", "")) + + translated_posts = translated.get("initial_posts", []) + if isinstance(translated_posts, list): + merged_posts = [] + original_posts = result.get("initial_posts", []) + for idx, original_post in enumerate(original_posts): + translated_post = translated_posts[idx] if idx < len(translated_posts) and isinstance(translated_posts[idx], dict) else {} + merged_posts.append({ + **original_post, + "content": translated_post.get("content", original_post.get("content", "")), + "poster_type": original_post.get("poster_type", translated_post.get("poster_type", "")), + }) + result["initial_posts"] = merged_posts + except Exception as e: + logger.warning(f"Failed to normalize event config text to English: {e}") + + return result def _generate_event_config( self, @@ -647,14 +703,14 @@ def _generate_event_config( simulation_requirement: str, entities: List[EntityNode] ) -> Dict[str, Any]: - """生成事件配置""" + """Generate event config.""" + - # 获取可用的实体类型列表,供 LLM 参考 entity_types_available = list(set( e.get_entity_type() or "Unknown" for e in entities )) - # 为每种类型列出代表性实体名称 + type_examples = {} for e in entities: etype = e.get_entity_type() or "Unknown" @@ -668,53 +724,60 @@ def _generate_event_config( for t, examples in type_examples.items() ]) - # 使用配置的上下文截断长度 + context_truncated = context[:self.EVENT_CONFIG_CONTEXT_LENGTH] - prompt = f"""基于以下模拟需求,生成事件配置。 + prompt = f"""Generate an event configuration for the following simulation. -模拟需求: {simulation_requirement} +Simulation requirement: {simulation_requirement} {context_truncated} -## 可用实体类型及示例 +## Available entity types and examples {type_info} -## 任务 -请生成事件配置JSON: -- 提取热点话题关键词 -- 描述舆论发展方向 -- 设计初始帖子内容,**每个帖子必须指定 poster_type(发布者类型)** +## Task +Return event configuration JSON that: +- extracts hot-topic keywords +- describes the likely narrative direction +- drafts initial post content, where **every post must include a poster_type** -**重要**: poster_type 必须从上面的"可用实体类型"中选择,这样初始帖子才能分配给合适的 Agent 发布。 -例如:官方声明应由 Official/University 类型发布,新闻由 MediaOutlet 发布,学生观点由 Student 发布。 +Important: +- poster_type must be selected from the available entity types above so each initial post can be assigned to an appropriate agent +- all free-text output must be in English +- poster_type values must remain exact entity-type names -返回JSON格式(不要markdown): +Return JSON only, no Markdown: {{ - "hot_topics": ["关键词1", "关键词2", ...], - "narrative_direction": "<舆论发展方向描述>", + "hot_topics": ["keyword1", "keyword2", "..."], + "narrative_direction": "", "initial_posts": [ - {{"content": "帖子内容", "poster_type": "实体类型(必须从可用类型中选择)"}}, + {{"content": "", "poster_type": ""}}, ... ], - "reasoning": "<简要说明>" + "reasoning": "" }}""" - system_prompt = "你是舆论分析专家。返回纯JSON格式。注意 poster_type 必须精确匹配可用实体类型。" + system_prompt = ( + "You are an expert in social-media narrative analysis. " + "Return pure JSON only. All free-text fields must be in English. " + "poster_type values must exactly match one of the available entity types." + ) try: - return self._call_llm_with_retry(prompt, system_prompt) + result = self._call_llm_with_retry(prompt, system_prompt) + return self._ensure_event_config_english(result) except Exception as e: - logger.warning(f"事件配置LLM生成失败: {e}, 使用默认配置") + logger.warning(f"Event-config LLM generation failed: {e}, using default config") return { "hot_topics": [], "narrative_direction": "", "initial_posts": [], - "reasoning": "使用默认配置" + "reasoning": "Using default configuration" } def _parse_event_config(self, result: Dict[str, Any]) -> EventConfig: - """解析事件配置结果""" + """Parse event config.""" return EventConfig( initial_posts=result.get("initial_posts", []), scheduled_events=[], @@ -727,15 +790,11 @@ def _assign_initial_post_agents( event_config: EventConfig, agent_configs: List[AgentActivityConfig] ) -> EventConfig: - """ - 为初始帖子分配合适的发布者 Agent - - 根据每个帖子的 poster_type 匹配最合适的 agent_id - """ + """Assign Initial Post Agents.""" if not event_config.initial_posts: return event_config - # 按实体类型建立 agent 索引 + agents_by_type: Dict[str, List[AgentActivityConfig]] = {} for agent in agent_configs: etype = agent.entity_type.lower() @@ -743,7 +802,7 @@ def _assign_initial_post_agents( agents_by_type[etype] = [] agents_by_type[etype].append(agent) - # 类型映射表(处理 LLM 可能输出的不同格式) + type_aliases = { "official": ["official", "university", "governmentagency", "government"], "university": ["university", "official"], @@ -755,7 +814,7 @@ def _assign_initial_post_agents( "person": ["person", "student", "alumni"], } - # 记录每种类型已使用的 agent 索引,避免重复使用同一个 agent + used_indices: Dict[str, int] = {} updated_posts = [] @@ -763,17 +822,17 @@ def _assign_initial_post_agents( poster_type = post.get("poster_type", "").lower() content = post.get("content", "") - # 尝试找到匹配的 agent + matched_agent_id = None - # 1. 直接匹配 + if poster_type in agents_by_type: agents = agents_by_type[poster_type] idx = used_indices.get(poster_type, 0) % len(agents) matched_agent_id = agents[idx].agent_id used_indices[poster_type] = idx + 1 else: - # 2. 使用别名匹配 + for alias_key, aliases in type_aliases.items(): if poster_type in aliases or alias_key == poster_type: for alias in aliases: @@ -786,11 +845,14 @@ def _assign_initial_post_agents( if matched_agent_id is not None: break - # 3. 如果仍未找到,使用影响力最高的 agent + if matched_agent_id is None: - logger.warning(f"未找到类型 '{poster_type}' 的匹配 Agent,使用影响力最高的 Agent") + logger.warning( + f"No matching agent found for type '{poster_type}'; " + f"using the highest-influence agent" + ) if agent_configs: - # 按影响力排序,选择影响力最高的 + sorted_agents = sorted(agent_configs, key=lambda a: a.influence_weight, reverse=True) matched_agent_id = sorted_agents[0].agent_id else: @@ -802,7 +864,7 @@ def _assign_initial_post_agents( "poster_agent_id": matched_agent_id }) - logger.info(f"初始帖子分配: poster_type='{poster_type}' -> agent_id={matched_agent_id}") + logger.info(f"Initial post assignment: poster_type='{poster_type}' -> agent_id={matched_agent_id}") event_config.initial_posts = updated_posts return event_config @@ -814,9 +876,9 @@ def _generate_agent_configs_batch( start_idx: int, simulation_requirement: str ) -> List[AgentActivityConfig]: - """分批生成Agent配置""" + """Generate agent configs batch.""" + - # 构建实体信息(使用配置的摘要长度) entity_list = [] summary_len = self.AGENT_SUMMARY_LENGTH for i, e in enumerate(entities): @@ -827,58 +889,63 @@ def _generate_agent_configs_batch( "summary": e.summary[:summary_len] if e.summary else "" }) - prompt = f"""基于以下信息,为每个实体生成社交媒体活动配置。 + prompt = f"""Generate social-media activity settings for each entity below. -模拟需求: {simulation_requirement} +Simulation requirement: {simulation_requirement} -## 实体列表 +## Entity list ```json {json.dumps(entity_list, ensure_ascii=False, indent=2)} ``` -## 任务 -为每个实体生成活动配置,注意: -- **时间符合中国人作息**:凌晨0-5点几乎不活动,晚间19-22点最活跃 -- **官方机构**(University/GovernmentAgency):活跃度低(0.1-0.3),工作时间(9-17)活动,响应慢(60-240分钟),影响力高(2.5-3.0) -- **媒体**(MediaOutlet):活跃度中(0.4-0.6),全天活动(8-23),响应快(5-30分钟),影响力高(2.0-2.5) -- **个人**(Student/Person/Alumni):活跃度高(0.6-0.9),主要晚间活动(18-23),响应快(1-15分钟),影响力低(0.8-1.2) -- **公众人物/专家**:活跃度中(0.4-0.6),影响力中高(1.5-2.0) +## Task +Generate JSON activity settings for every entity. + +Guidance: +- Assume activity generally follows a China Standard Time rhythm unless the scenario suggests otherwise: minimal activity from 00:00-05:00 and peak activity around 19:00-22:00 +- Official institutions (University/GovernmentAgency): lower activity (0.1-0.3), mostly active during work hours (09:00-17:00), slower response (60-240 min), higher influence (2.5-3.0) +- Media outlets (MediaOutlet): medium activity (0.4-0.6), active most of the day (08:00-23:00), fast response (5-30 min), higher influence (2.0-2.5) +- Individuals (Student/Person/Alumni): higher activity (0.6-0.9), mostly active in the evening (18:00-23:00), fast response (1-15 min), lower influence (0.8-1.2) +- Public figures and experts: medium activity (0.4-0.6), medium-to-high influence (1.5-2.0) -返回JSON格式(不要markdown): +Return JSON only, no Markdown: {{ "agent_configs": [ {{ - "agent_id": <必须与输入一致>, + "agent_id": , "activity_level": <0.0-1.0>, - "posts_per_hour": <发帖频率>, - "comments_per_hour": <评论频率>, - "active_hours": [<活跃小时列表,考虑中国人作息>], - "response_delay_min": <最小响应延迟分钟>, - "response_delay_max": <最大响应延迟分钟>, - "sentiment_bias": <-1.0到1.0>, + "posts_per_hour": , + "comments_per_hour": , + "active_hours": [], + "response_delay_min": , + "response_delay_max": , + "sentiment_bias": <-1.0 to 1.0>, "stance": "", - "influence_weight": <影响力权重> + "influence_weight": }}, ... ] }}""" - system_prompt = "你是社交媒体行为分析专家。返回纯JSON,配置需符合中国人作息习惯。" + system_prompt = ( + "You are an expert in social-media behavior modeling. " + "Return pure JSON only. All free-text fields must be written in English." + ) try: result = self._call_llm_with_retry(prompt, system_prompt) llm_configs = {cfg["agent_id"]: cfg for cfg in result.get("agent_configs", [])} except Exception as e: - logger.warning(f"Agent配置批次LLM生成失败: {e}, 使用规则生成") + logger.warning(f"Agent-config batch LLM generation failed: {e}, using rule-based generation") llm_configs = {} - # 构建AgentActivityConfig对象 + configs = [] for i, entity in enumerate(entities): agent_id = start_idx + i cfg = llm_configs.get(agent_id, {}) - # 如果LLM没有生成,使用规则生成 + if not cfg: cfg = self._generate_agent_config_by_rule(entity) @@ -902,11 +969,11 @@ def _generate_agent_configs_batch( return configs def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: - """基于规则生成单个Agent配置(中国人作息)""" + """Generate agent config by rule.""" entity_type = (entity.get_entity_type() or "Unknown").lower() if entity_type in ["university", "governmentagency", "ngo"]: - # 官方机构:工作时间活动,低频率,高影响力 + return { "activity_level": 0.2, "posts_per_hour": 0.1, @@ -919,7 +986,7 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 3.0 } elif entity_type in ["mediaoutlet"]: - # 媒体:全天活动,中等频率,高影响力 + return { "activity_level": 0.5, "posts_per_hour": 0.8, @@ -932,7 +999,7 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 2.5 } elif entity_type in ["professor", "expert", "official"]: - # 专家/教授:工作+晚间活动,中等频率 + return { "activity_level": 0.4, "posts_per_hour": 0.3, @@ -945,12 +1012,12 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 2.0 } elif entity_type in ["student"]: - # 学生:晚间为主,高频率 + return { "activity_level": 0.8, "posts_per_hour": 0.6, "comments_per_hour": 1.5, - "active_hours": [8, 9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # 上午+晚间 + "active_hours": [8, 9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], "response_delay_min": 1, "response_delay_max": 15, "sentiment_bias": 0.0, @@ -958,12 +1025,12 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 0.8 } elif entity_type in ["alumni"]: - # 校友:晚间为主 + return { "activity_level": 0.6, "posts_per_hour": 0.4, "comments_per_hour": 0.8, - "active_hours": [12, 13, 19, 20, 21, 22, 23], # 午休+晚间 + "active_hours": [12, 13, 19, 20, 21, 22, 23], "response_delay_min": 5, "response_delay_max": 30, "sentiment_bias": 0.0, @@ -971,12 +1038,12 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 1.0 } else: - # 普通人:晚间高峰 + return { "activity_level": 0.7, "posts_per_hour": 0.5, "comments_per_hour": 1.2, - "active_hours": [9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # 白天+晚间 + "active_hours": [9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], "response_delay_min": 2, "response_delay_max": 20, "sentiment_bias": 0.0, @@ -984,4 +1051,3 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 1.0 } - diff --git a/backend/app/services/simulation_ipc.py b/backend/app/services/simulation_ipc.py index 9d70d0bea..2eb4d95f9 100644 --- a/backend/app/services/simulation_ipc.py +++ b/backend/app/services/simulation_ipc.py @@ -1,12 +1,4 @@ -""" -模拟IPC通信模块 -用于Flask后端和模拟脚本之间的进程间通信 - -通过文件系统实现简单的命令/响应模式: -1. Flask写入命令到 commands/ 目录 -2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录 -3. Flask轮询响应目录获取结果 -""" +"""Simulation IPC communication module.""" import os import json @@ -23,14 +15,14 @@ class CommandType(str, Enum): - """命令类型""" - INTERVIEW = "interview" # 单个Agent采访 - BATCH_INTERVIEW = "batch_interview" # 批量采访 - CLOSE_ENV = "close_env" # 关闭环境 + """Command Type.""" + INTERVIEW = "interview" + BATCH_INTERVIEW = "batch_interview" + CLOSE_ENV = "close_env" class CommandStatus(str, Enum): - """命令状态""" + """Command Status.""" PENDING = "pending" PROCESSING = "processing" COMPLETED = "completed" @@ -39,7 +31,7 @@ class CommandStatus(str, Enum): @dataclass class IPCCommand: - """IPC命令""" + """IPC Command.""" command_id: str command_type: CommandType args: Dict[str, Any] @@ -65,7 +57,7 @@ def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand': @dataclass class IPCResponse: - """IPC响应""" + """IPC Response.""" command_id: str status: CommandStatus result: Optional[Dict[str, Any]] = None @@ -93,24 +85,15 @@ def from_dict(cls, data: Dict[str, Any]) -> 'IPCResponse': class SimulationIPCClient: - """ - 模拟IPC客户端(Flask端使用) - - 用于向模拟进程发送命令并等待响应 - """ + """Simulation IPC Client.""" def __init__(self, simulation_dir: str): - """ - 初始化IPC客户端 - - Args: - simulation_dir: 模拟数据目录 - """ + """Initialize the instance.""" self.simulation_dir = simulation_dir self.commands_dir = os.path.join(simulation_dir, "ipc_commands") self.responses_dir = os.path.join(simulation_dir, "ipc_responses") - # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) @@ -121,21 +104,7 @@ def send_command( timeout: float = 60.0, poll_interval: float = 0.5 ) -> IPCResponse: - """ - 发送命令并等待响应 - - Args: - command_type: 命令类型 - args: 命令参数 - timeout: 超时时间(秒) - poll_interval: 轮询间隔(秒) - - Returns: - IPCResponse - - Raises: - TimeoutError: 等待响应超时 - """ + """Send command.""" command_id = str(uuid.uuid4()) command = IPCCommand( command_id=command_id, @@ -143,14 +112,14 @@ def send_command( args=args ) - # 写入命令文件 + command_file = os.path.join(self.commands_dir, f"{command_id}.json") with open(command_file, 'w', encoding='utf-8') as f: json.dump(command.to_dict(), f, ensure_ascii=False, indent=2) - logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}") + logger.info(f"Sending IPC command: {command_type.value}, command_id={command_id}") + - # 等待响应 response_file = os.path.join(self.responses_dir, f"{command_id}.json") start_time = time.time() @@ -161,30 +130,30 @@ def send_command( response_data = json.load(f) response = IPCResponse.from_dict(response_data) - # 清理命令和响应文件 + try: os.remove(command_file) os.remove(response_file) except OSError: pass - logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}") + logger.info(f"Received IPC response: command_id={command_id}, status={response.status.value}") return response except (json.JSONDecodeError, KeyError) as e: - logger.warning(f"解析响应失败: {e}") + logger.warning(f"Failed to parse IPC response: {e}") time.sleep(poll_interval) - # 超时 - logger.error(f"等待IPC响应超时: command_id={command_id}") - # 清理命令文件 + logger.error(f"Timed out waiting for IPC response: command_id={command_id}") + + try: os.remove(command_file) except OSError: pass - raise TimeoutError(f"等待命令响应超时 ({timeout}秒)") + raise TimeoutError(f"Timed out waiting for command response ({timeout}s)") def send_interview( self, @@ -193,21 +162,7 @@ def send_interview( platform: str = None, timeout: float = 60.0 ) -> IPCResponse: - """ - 发送单个Agent采访命令 - - Args: - agent_id: Agent ID - prompt: 采访问题 - platform: 指定平台(可选) - - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 - - None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台 - timeout: 超时时间 - - Returns: - IPCResponse,result字段包含采访结果 - """ + """Send interview.""" args = { "agent_id": agent_id, "prompt": prompt @@ -227,20 +182,7 @@ def send_batch_interview( platform: str = None, timeout: float = 120.0 ) -> IPCResponse: - """ - 发送批量采访命令 - - Args: - interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)} - platform: 默认平台(可选,会被每个采访项的platform覆盖) - - "twitter": 默认只采访Twitter平台 - - "reddit": 默认只采访Reddit平台 - - None: 双平台模拟时每个Agent同时采访两个平台 - timeout: 超时时间 - - Returns: - IPCResponse,result字段包含所有采访结果 - """ + """Send batch interview.""" args = {"interviews": interviews} if platform: args["platform"] = platform @@ -252,15 +194,7 @@ def send_batch_interview( ) def send_close_env(self, timeout: float = 30.0) -> IPCResponse: - """ - 发送关闭环境命令 - - Args: - timeout: 超时时间 - - Returns: - IPCResponse - """ + """Send close env.""" return self.send_command( command_type=CommandType.CLOSE_ENV, args={}, @@ -268,11 +202,7 @@ def send_close_env(self, timeout: float = 30.0) -> IPCResponse: ) def check_env_alive(self) -> bool: - """ - 检查模拟环境是否存活 - - 通过检查 env_status.json 文件来判断 - """ + """Check env alive.""" status_file = os.path.join(self.simulation_dir, "env_status.json") if not os.path.exists(status_file): return False @@ -286,42 +216,33 @@ def check_env_alive(self) -> bool: class SimulationIPCServer: - """ - 模拟IPC服务器(模拟脚本端使用) - - 轮询命令目录,执行命令并返回响应 - """ + """Simulation IPC Server.""" def __init__(self, simulation_dir: str): - """ - 初始化IPC服务器 - - Args: - simulation_dir: 模拟数据目录 - """ + """Initialize the instance.""" self.simulation_dir = simulation_dir self.commands_dir = os.path.join(simulation_dir, "ipc_commands") self.responses_dir = os.path.join(simulation_dir, "ipc_responses") - # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) - # 环境状态 + self._running = False def start(self): - """标记服务器为运行状态""" + """Start the requested object.""" self._running = True self._update_env_status("alive") def stop(self): - """标记服务器为停止状态""" + """Stop the requested object.""" self._running = False self._update_env_status("stopped") def _update_env_status(self, status: str): - """更新环境状态文件""" + """Update env status.""" status_file = os.path.join(self.simulation_dir, "env_status.json") with open(status_file, 'w', encoding='utf-8') as f: json.dump({ @@ -330,16 +251,11 @@ def _update_env_status(self, status: str): }, f, ensure_ascii=False, indent=2) def poll_commands(self) -> Optional[IPCCommand]: - """ - 轮询命令目录,返回第一个待处理的命令 - - Returns: - IPCCommand 或 None - """ + """Poll commands.""" if not os.path.exists(self.commands_dir): return None - # 按时间排序获取命令文件 + command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): @@ -354,23 +270,18 @@ def poll_commands(self) -> Optional[IPCCommand]: data = json.load(f) return IPCCommand.from_dict(data) except (json.JSONDecodeError, KeyError, OSError) as e: - logger.warning(f"读取命令文件失败: {filepath}, {e}") + logger.warning(f"Failed to read command file: {filepath}, {e}") continue return None def send_response(self, response: IPCResponse): - """ - 发送响应 - - Args: - response: IPC响应 - """ + """Send response.""" response_file = os.path.join(self.responses_dir, f"{response.command_id}.json") with open(response_file, 'w', encoding='utf-8') as f: json.dump(response.to_dict(), f, ensure_ascii=False, indent=2) - # 删除命令文件 + command_file = os.path.join(self.commands_dir, f"{response.command_id}.json") try: os.remove(command_file) @@ -378,7 +289,7 @@ def send_response(self, response: IPCResponse): pass def send_success(self, command_id: str, result: Dict[str, Any]): - """发送成功响应""" + """Send success.""" self.send_response(IPCResponse( command_id=command_id, status=CommandStatus.COMPLETED, @@ -386,7 +297,7 @@ def send_success(self, command_id: str, result: Dict[str, Any]): )) def send_error(self, command_id: str, error: str): - """发送错误响应""" + """Send error.""" self.send_response(IPCResponse( command_id=command_id, status=CommandStatus.FAILED, diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py index 96c496fd4..73c53ddb4 100644 --- a/backend/app/services/simulation_manager.py +++ b/backend/app/services/simulation_manager.py @@ -1,8 +1,4 @@ -""" -OASIS模拟管理器 -管理Twitter和Reddit双平台并行模拟 -使用预设脚本 + LLM智能生成配置参数 -""" +"""OASIS simulation manager.""" import os import json @@ -22,60 +18,60 @@ class SimulationStatus(str, Enum): - """模拟状态""" + """Simulation Status.""" CREATED = "created" PREPARING = "preparing" READY = "ready" RUNNING = "running" PAUSED = "paused" - STOPPED = "stopped" # 模拟被手动停止 - COMPLETED = "completed" # 模拟自然完成 + STOPPED = "stopped" + COMPLETED = "completed" FAILED = "failed" class PlatformType(str, Enum): - """平台类型""" + """Platform Type.""" TWITTER = "twitter" REDDIT = "reddit" @dataclass class SimulationState: - """模拟状态""" + """Simulation State.""" simulation_id: str project_id: str graph_id: str - # 平台启用状态 + enable_twitter: bool = True enable_reddit: bool = True - # 状态 + status: SimulationStatus = SimulationStatus.CREATED - # 准备阶段数据 + entities_count: int = 0 profiles_count: int = 0 entity_types: List[str] = field(default_factory=list) - # 配置生成信息 + config_generated: bool = False config_reasoning: str = "" - # 运行时数据 + current_round: int = 0 twitter_status: str = "not_started" reddit_status: str = "not_started" - # 时间戳 + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) - # 错误信息 + error: Optional[str] = None def to_dict(self) -> Dict[str, Any]: - """完整状态字典(内部使用)""" + """Convert the object to a dictionary.""" return { "simulation_id": self.simulation_id, "project_id": self.project_id, @@ -97,7 +93,7 @@ def to_dict(self) -> Dict[str, Any]: } def to_simple_dict(self) -> Dict[str, Any]: - """简化状态字典(API返回使用)""" + """Convert the object to Simple Dict.""" return { "simulation_id": self.simulation_id, "project_id": self.project_id, @@ -112,37 +108,29 @@ def to_simple_dict(self) -> Dict[str, Any]: class SimulationManager: - """ - 模拟管理器 + """Simulation Manager.""" - 核心功能: - 1. 从Zep图谱读取实体并过滤 - 2. 生成OASIS Agent Profile - 3. 使用LLM智能生成模拟配置参数 - 4. 准备预设脚本所需的所有文件 - """ - # 模拟数据存储目录 SIMULATION_DATA_DIR = os.path.join( os.path.dirname(__file__), '../../uploads/simulations' ) def __init__(self): - # 确保目录存在 + os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True) - # 内存中的模拟状态缓存 + self._simulations: Dict[str, SimulationState] = {} def _get_simulation_dir(self, simulation_id: str) -> str: - """获取模拟数据目录""" + """Get simulation dir.""" sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id) os.makedirs(sim_dir, exist_ok=True) return sim_dir def _save_simulation_state(self, state: SimulationState): - """保存模拟状态到文件""" + """Save simulation state.""" sim_dir = self._get_simulation_dir(state.simulation_id) state_file = os.path.join(sim_dir, "state.json") @@ -154,7 +142,7 @@ def _save_simulation_state(self, state: SimulationState): self._simulations[state.simulation_id] = state def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]: - """从文件加载模拟状态""" + """Load simulation state.""" if simulation_id in self._simulations: return self._simulations[simulation_id] @@ -197,18 +185,7 @@ def create_simulation( enable_twitter: bool = True, enable_reddit: bool = True, ) -> SimulationState: - """ - 创建新的模拟 - - Args: - project_id: 项目ID - graph_id: Zep图谱ID - enable_twitter: 是否启用Twitter模拟 - enable_reddit: 是否启用Reddit模拟 - - Returns: - SimulationState - """ + """Create simulation.""" import uuid simulation_id = f"sim_{uuid.uuid4().hex[:12]}" @@ -222,7 +199,7 @@ def create_simulation( ) self._save_simulation_state(state) - logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}") + logger.info(f"Created simulation: {simulation_id}, project={project_id}, graph={graph_id}") return state @@ -236,31 +213,10 @@ def prepare_simulation( progress_callback: Optional[callable] = None, parallel_profile_count: int = 3 ) -> SimulationState: - """ - 准备模拟环境(全程自动化) - - 步骤: - 1. 从Zep图谱读取并过滤实体 - 2. 为每个实体生成OASIS Agent Profile(可选LLM增强,支持并行) - 3. 使用LLM智能生成模拟配置参数(时间、活跃度、发言频率等) - 4. 保存配置文件和Profile文件 - 5. 复制预设脚本到模拟目录 - - Args: - simulation_id: 模拟ID - simulation_requirement: 模拟需求描述(用于LLM生成配置) - document_text: 原始文档内容(用于LLM理解背景) - defined_entity_types: 预定义的实体类型(可选) - use_llm_for_profiles: 是否使用LLM生成详细人设 - progress_callback: 进度回调函数 (stage, progress, message) - parallel_profile_count: 并行生成人设的数量,默认3 - - Returns: - SimulationState - """ + """Prepare simulation.""" state = self._load_simulation_state(simulation_id) if not state: - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation not found: {simulation_id}") try: state.status = SimulationStatus.PREPARING @@ -268,14 +224,14 @@ def prepare_simulation( sim_dir = self._get_simulation_dir(simulation_id) - # ========== 阶段1: 读取并过滤实体 ========== + if progress_callback: - progress_callback("reading", 0, "正在连接Zep图谱...") + progress_callback("reading", 0, "Connecting to the Zep graph...") reader = ZepEntityReader() if progress_callback: - progress_callback("reading", 30, "正在读取节点数据...") + progress_callback("reading", 30, "Reading node data...") filtered = reader.filter_defined_entities( graph_id=state.graph_id, @@ -289,29 +245,29 @@ def prepare_simulation( if progress_callback: progress_callback( "reading", 100, - f"完成,共 {filtered.filtered_count} 个实体", + f"Completed. {filtered.filtered_count} entities found.", current=filtered.filtered_count, total=filtered.filtered_count ) if filtered.filtered_count == 0: state.status = SimulationStatus.FAILED - state.error = "没有找到符合条件的实体,请检查图谱是否正确构建" + state.error = "No matching entities were found. Please verify that the graph was built correctly." self._save_simulation_state(state) return state - # ========== 阶段2: 生成Agent Profile ========== + total_entities = len(filtered.entities) if progress_callback: progress_callback( "generating_profiles", 0, - "开始生成...", + "Starting profile generation...", current=0, total=total_entities ) - # 传入graph_id以启用Zep检索功能,获取更丰富的上下文 + generator = OasisProfileGenerator(graph_id=state.graph_id) def profile_progress(current, total, msg): @@ -325,7 +281,7 @@ def profile_progress(current, total, msg): item_name=msg ) - # 设置实时保存的文件路径(优先使用 Reddit JSON 格式) + realtime_output_path = None realtime_platform = "reddit" if state.enable_reddit: @@ -339,20 +295,20 @@ def profile_progress(current, total, msg): entities=filtered.entities, use_llm=use_llm_for_profiles, progress_callback=profile_progress, - graph_id=state.graph_id, # 传入graph_id用于Zep检索 - parallel_count=parallel_profile_count, # 并行生成数量 - realtime_output_path=realtime_output_path, # 实时保存路径 - output_platform=realtime_platform # 输出格式 + graph_id=state.graph_id, + parallel_count=parallel_profile_count, + realtime_output_path=realtime_output_path, + output_platform=realtime_platform ) state.profiles_count = len(profiles) - # 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式) - # Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性 + + if progress_callback: progress_callback( "generating_profiles", 95, - "保存Profile文件...", + "Saving profile files...", current=total_entities, total=total_entities ) @@ -365,7 +321,7 @@ def profile_progress(current, total, msg): ) if state.enable_twitter: - # Twitter使用CSV格式!这是OASIS的要求 + generator.save_profiles( profiles=profiles, file_path=os.path.join(sim_dir, "twitter_profiles.csv"), @@ -375,16 +331,16 @@ def profile_progress(current, total, msg): if progress_callback: progress_callback( "generating_profiles", 100, - f"完成,共 {len(profiles)} 个Profile", + f"Completed. {len(profiles)} profiles generated.", current=len(profiles), total=len(profiles) ) - # ========== 阶段3: LLM智能生成模拟配置 ========== + if progress_callback: progress_callback( "generating_config", 0, - "正在分析模拟需求...", + "Analyzing the simulation requirement...", current=0, total=3 ) @@ -394,7 +350,7 @@ def profile_progress(current, total, msg): if progress_callback: progress_callback( "generating_config", 30, - "正在调用LLM生成配置...", + "Generating configuration with the LLM...", current=1, total=3 ) @@ -413,12 +369,12 @@ def profile_progress(current, total, msg): if progress_callback: progress_callback( "generating_config", 70, - "正在保存配置文件...", + "Saving the configuration file...", current=2, total=3 ) - # 保存配置文件 + config_path = os.path.join(sim_dir, "simulation_config.json") with open(config_path, 'w', encoding='utf-8') as f: f.write(sim_params.to_json()) @@ -429,25 +385,27 @@ def profile_progress(current, total, msg): if progress_callback: progress_callback( "generating_config", 100, - "配置生成完成", + "Configuration generation completed.", current=3, total=3 ) - # 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录 - # 启动模拟时,simulation_runner 会从 scripts/ 目录运行脚本 - # 更新状态 + + + state.status = SimulationStatus.READY self._save_simulation_state(state) - logger.info(f"模拟准备完成: {simulation_id}, " - f"entities={state.entities_count}, profiles={state.profiles_count}") + logger.info( + f"Simulation preparation completed: {simulation_id}, " + f"entities={state.entities_count}, profiles={state.profiles_count}" + ) return state except Exception as e: - logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}") + logger.error(f"Simulation preparation failed: {simulation_id}, error={str(e)}") import traceback logger.error(traceback.format_exc()) state.status = SimulationStatus.FAILED @@ -456,16 +414,16 @@ def profile_progress(current, total, msg): raise def get_simulation(self, simulation_id: str) -> Optional[SimulationState]: - """获取模拟状态""" + """Get simulation.""" return self._load_simulation_state(simulation_id) def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]: - """列出所有模拟""" + """List simulations.""" simulations = [] if os.path.exists(self.SIMULATION_DATA_DIR): for sim_id in os.listdir(self.SIMULATION_DATA_DIR): - # 跳过隐藏文件(如 .DS_Store)和非目录文件 + sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id) if sim_id.startswith('.') or not os.path.isdir(sim_path): continue @@ -478,10 +436,10 @@ def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationS return simulations def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]: - """获取模拟的Agent Profile""" + """Get profiles.""" state = self._load_simulation_state(simulation_id) if not state: - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation not found: {simulation_id}") sim_dir = self._get_simulation_dir(simulation_id) profile_path = os.path.join(sim_dir, f"{platform}_profiles.json") @@ -493,7 +451,7 @@ def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dic return json.load(f) def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]: - """获取模拟配置""" + """Get simulation config.""" sim_dir = self._get_simulation_dir(simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") @@ -504,7 +462,7 @@ def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]: return json.load(f) def get_run_instructions(self, simulation_id: str) -> Dict[str, str]: - """获取运行说明""" + """Get run instructions.""" sim_dir = self._get_simulation_dir(simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts')) @@ -519,10 +477,10 @@ def get_run_instructions(self, simulation_id: str) -> Dict[str, str]: "parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}", }, "instructions": ( - f"1. 激活conda环境: conda activate MiroFish\n" - f"2. 运行模拟 (脚本位于 {scripts_dir}):\n" - f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n" - f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n" - f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}" + f"1. Activate the conda environment: conda activate MiroFish\n" + f"2. Run the simulation (scripts are in {scripts_dir}):\n" + f" - Twitter only: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n" + f" - Reddit only: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n" + f" - Run both platforms in parallel: python {scripts_dir}/run_parallel_simulation.py --config {config_path}" ) } diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index 8c35380d1..30d852fdb 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -1,7 +1,4 @@ -""" -OASIS模拟运行器 -在后台运行模拟并记录每个Agent的动作,支持实时状态监控 -""" +"""OASIS simulation runner.""" import os import sys @@ -25,15 +22,15 @@ logger = get_logger('mirofish.simulation_runner') -# 标记是否已注册清理函数 + _cleanup_registered = False -# 平台检测 + IS_WINDOWS = sys.platform == 'win32' class RunnerStatus(str, Enum): - """运行器状态""" + """Runner Status.""" IDLE = "idle" STARTING = "starting" RUNNING = "running" @@ -46,7 +43,7 @@ class RunnerStatus(str, Enum): @dataclass class AgentAction: - """Agent动作记录""" + """Agent Action.""" round_num: int timestamp: str platform: str # twitter / reddit @@ -73,7 +70,7 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class RoundSummary: - """每轮摘要""" + """Round Summary.""" round_num: int start_time: str end_time: Optional[str] = None @@ -99,52 +96,52 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class SimulationRunState: - """模拟运行状态(实时)""" + """Simulation Run State.""" simulation_id: str runner_status: RunnerStatus = RunnerStatus.IDLE - # 进度信息 + current_round: int = 0 total_rounds: int = 0 simulated_hours: int = 0 total_simulation_hours: int = 0 - # 各平台独立轮次和模拟时间(用于双平台并行显示) + twitter_current_round: int = 0 reddit_current_round: int = 0 twitter_simulated_hours: int = 0 reddit_simulated_hours: int = 0 - # 平台状态 + twitter_running: bool = False reddit_running: bool = False twitter_actions_count: int = 0 reddit_actions_count: int = 0 - # 平台完成状态(通过检测 actions.jsonl 中的 simulation_end 事件) + twitter_completed: bool = False reddit_completed: bool = False - # 每轮摘要 + rounds: List[RoundSummary] = field(default_factory=list) - # 最近动作(用于前端实时展示) + recent_actions: List[AgentAction] = field(default_factory=list) max_recent_actions: int = 50 - # 时间戳 + started_at: Optional[str] = None updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) completed_at: Optional[str] = None - # 错误信息 + error: Optional[str] = None - # 进程ID(用于停止) + process_pid: Optional[int] = None def add_action(self, action: AgentAction): - """添加动作到最近动作列表""" + """Add action.""" self.recent_actions.insert(0, action) if len(self.recent_actions) > self.max_recent_actions: self.recent_actions = self.recent_actions[:self.max_recent_actions] @@ -165,7 +162,7 @@ def to_dict(self) -> Dict[str, Any]: "simulated_hours": self.simulated_hours, "total_simulation_hours": self.total_simulation_hours, "progress_percent": round(self.current_round / max(self.total_rounds, 1) * 100, 1), - # 各平台独立轮次和时间 + "twitter_current_round": self.twitter_current_round, "reddit_current_round": self.reddit_current_round, "twitter_simulated_hours": self.twitter_simulated_hours, @@ -185,7 +182,7 @@ def to_dict(self) -> Dict[str, Any]: } def to_detail_dict(self) -> Dict[str, Any]: - """包含最近动作的详细信息""" + """Convert the object to Detail Dict.""" result = self.to_dict() result["recent_actions"] = [a.to_dict() for a in self.recent_actions] result["rounds_count"] = len(self.rounds) @@ -193,46 +190,38 @@ def to_detail_dict(self) -> Dict[str, Any]: class SimulationRunner: - """ - 模拟运行器 + """Simulation Runner.""" - 负责: - 1. 在后台进程中运行OASIS模拟 - 2. 解析运行日志,记录每个Agent的动作 - 3. 提供实时状态查询接口 - 4. 支持暂停/停止/恢复操作 - """ - # 运行状态存储目录 RUN_STATE_DIR = os.path.join( os.path.dirname(__file__), '../../uploads/simulations' ) - # 脚本目录 + SCRIPTS_DIR = os.path.join( os.path.dirname(__file__), '../../scripts' ) - # 内存中的运行状态 + _run_states: Dict[str, SimulationRunState] = {} _processes: Dict[str, subprocess.Popen] = {} _action_queues: Dict[str, Queue] = {} _monitor_threads: Dict[str, threading.Thread] = {} - _stdout_files: Dict[str, Any] = {} # 存储 stdout 文件句柄 - _stderr_files: Dict[str, Any] = {} # 存储 stderr 文件句柄 + _stdout_files: Dict[str, Any] = {} + _stderr_files: Dict[str, Any] = {} + - # 图谱记忆更新配置 _graph_memory_enabled: Dict[str, bool] = {} # simulation_id -> enabled @classmethod def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: - """获取运行状态""" + """Get run state.""" if simulation_id in cls._run_states: return cls._run_states[simulation_id] - # 尝试从文件加载 + state = cls._load_run_state(simulation_id) if state: cls._run_states[simulation_id] = state @@ -240,7 +229,7 @@ def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: @classmethod def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: - """从文件加载运行状态""" + """Load run state.""" state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json") if not os.path.exists(state_file): return None @@ -256,7 +245,7 @@ def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: total_rounds=data.get("total_rounds", 0), simulated_hours=data.get("simulated_hours", 0), total_simulation_hours=data.get("total_simulation_hours", 0), - # 各平台独立轮次和时间 + twitter_current_round=data.get("twitter_current_round", 0), reddit_current_round=data.get("reddit_current_round", 0), twitter_simulated_hours=data.get("twitter_simulated_hours", 0), @@ -274,7 +263,7 @@ def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: process_pid=data.get("process_pid"), ) - # 加载最近动作 + actions_data = data.get("recent_actions", []) for a in actions_data: state.recent_actions.append(AgentAction( @@ -291,12 +280,12 @@ def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: return state except Exception as e: - logger.error(f"加载运行状态失败: {str(e)}") + logger.error(f"Failed to load run state: {str(e)}") return None @classmethod def _save_run_state(cls, state: SimulationRunState): - """保存运行状态到文件""" + """Save run state.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id) os.makedirs(sim_dir, exist_ok=True) state_file = os.path.join(sim_dir, "run_state.json") @@ -313,50 +302,38 @@ def start_simulation( cls, simulation_id: str, platform: str = "parallel", # twitter / reddit / parallel - max_rounds: int = None, # 最大模拟轮数(可选,用于截断过长的模拟) - enable_graph_memory_update: bool = False, # 是否将活动更新到Zep图谱 - graph_id: str = None # Zep图谱ID(启用图谱更新时必需) + max_rounds: int = None, + enable_graph_memory_update: bool = False, + graph_id: str = None ) -> SimulationRunState: - """ - 启动模拟 - - Args: - simulation_id: 模拟ID - platform: 运行平台 (twitter/reddit/parallel) - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) - enable_graph_memory_update: 是否将Agent活动动态更新到Zep图谱 - graph_id: Zep图谱ID(启用图谱更新时必需) - - Returns: - SimulationRunState - """ - # 检查是否已在运行 + """Start simulation.""" + existing = cls.get_run_state(simulation_id) if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]: - raise ValueError(f"模拟已在运行中: {simulation_id}") + raise ValueError(f"Simulation is already running: {simulation_id}") + - # 加载模拟配置 sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") if not os.path.exists(config_path): - raise ValueError(f"模拟配置不存在,请先调用 /prepare 接口") + raise ValueError("Simulation config does not exist. Call /prepare first") with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) - # 初始化运行状态 + time_config = config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = int(total_hours * 60 / minutes_per_round) - # 如果指定了最大轮数,则截断 + if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: - logger.info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") + logger.info(f"Round count truncated: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") state = SimulationRunState( simulation_id=simulation_id, @@ -368,22 +345,22 @@ def start_simulation( cls._save_run_state(state) - # 如果启用图谱记忆更新,创建更新器 + if enable_graph_memory_update: if not graph_id: - raise ValueError("启用图谱记忆更新时必须提供 graph_id") + raise ValueError("graph_id is required when graph memory updates are enabled") try: ZepGraphMemoryManager.create_updater(simulation_id, graph_id) cls._graph_memory_enabled[simulation_id] = True - logger.info(f"已启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}") + logger.info(f"Graph memory updates enabled: simulation_id={simulation_id}, graph_id={graph_id}") except Exception as e: - logger.error(f"创建图谱记忆更新器失败: {e}") + logger.error(f"Failed to create graph memory updater: {e}") cls._graph_memory_enabled[simulation_id] = False else: cls._graph_memory_enabled[simulation_id] = False - # 确定运行哪个脚本(脚本位于 backend/scripts/ 目录) + if platform == "twitter": script_name = "run_twitter_simulation.py" state.twitter_running = True @@ -398,64 +375,64 @@ def start_simulation( script_path = os.path.join(cls.SCRIPTS_DIR, script_name) if not os.path.exists(script_path): - raise ValueError(f"脚本不存在: {script_path}") + raise ValueError(f"Script does not exist: {script_path}") + - # 创建动作队列 action_queue = Queue() cls._action_queues[simulation_id] = action_queue - # 启动模拟进程 + try: - # 构建运行命令,使用完整路径 - # 新的日志结构: - # twitter/actions.jsonl - Twitter 动作日志 - # reddit/actions.jsonl - Reddit 动作日志 - # simulation.log - 主进程日志 + + + + + cmd = [ - sys.executable, # Python解释器 + sys.executable, script_path, - "--config", config_path, # 使用完整配置文件路径 + "--config", config_path, ] - # 如果指定了最大轮数,添加到命令行参数 + if max_rounds is not None and max_rounds > 0: cmd.extend(["--max-rounds", str(max_rounds)]) - # 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞 + main_log_path = os.path.join(sim_dir, "simulation.log") main_log_file = open(main_log_path, 'w', encoding='utf-8') - # 设置子进程环境变量,确保 Windows 上使用 UTF-8 编码 - # 这可以修复第三方库(如 OASIS)读取文件时未指定编码的问题 + + env = os.environ.copy() - env['PYTHONUTF8'] = '1' # Python 3.7+ 支持,让所有 open() 默认使用 UTF-8 - env['PYTHONIOENCODING'] = 'utf-8' # 确保 stdout/stderr 使用 UTF-8 + env['PYTHONUTF8'] = '1' + env['PYTHONIOENCODING'] = 'utf-8' + + - # 设置工作目录为模拟目录(数据库等文件会生成在此) - # 使用 start_new_session=True 创建新的进程组,确保可以通过 os.killpg 终止所有子进程 process = subprocess.Popen( cmd, cwd=sim_dir, stdout=main_log_file, - stderr=subprocess.STDOUT, # stderr 也写入同一个文件 + stderr=subprocess.STDOUT, text=True, - encoding='utf-8', # 显式指定编码 + encoding='utf-8', bufsize=1, - env=env, # 传递带有 UTF-8 设置的环境变量 - start_new_session=True, # 创建新进程组,确保服务器关闭时能终止所有相关进程 + env=env, + start_new_session=True, ) - # 保存文件句柄以便后续关闭 + cls._stdout_files[simulation_id] = main_log_file - cls._stderr_files[simulation_id] = None # 不再需要单独的 stderr + cls._stderr_files[simulation_id] = None state.process_pid = process.pid state.runner_status = RunnerStatus.RUNNING cls._processes[simulation_id] = process cls._save_run_state(state) - # 启动监控线程 + monitor_thread = threading.Thread( target=cls._monitor_simulation, args=(simulation_id,), @@ -464,7 +441,7 @@ def start_simulation( monitor_thread.start() cls._monitor_threads[simulation_id] = monitor_thread - logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}") + logger.info(f"Simulation started successfully: {simulation_id}, pid={process.pid}, platform={platform}") except Exception as e: state.runner_status = RunnerStatus.FAILED @@ -476,10 +453,10 @@ def start_simulation( @classmethod def _monitor_simulation(cls, simulation_id: str): - """监控模拟进程,解析动作日志""" + """Monitor Simulation.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) - # 新的日志结构:分平台的动作日志 + twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl") reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl") @@ -493,75 +470,75 @@ def _monitor_simulation(cls, simulation_id: str): reddit_position = 0 try: - while process.poll() is None: # 进程仍在运行 - # 读取 Twitter 动作日志 + while process.poll() is None: + if os.path.exists(twitter_actions_log): twitter_position = cls._read_action_log( twitter_actions_log, twitter_position, state, "twitter" ) - # 读取 Reddit 动作日志 + if os.path.exists(reddit_actions_log): reddit_position = cls._read_action_log( reddit_actions_log, reddit_position, state, "reddit" ) - # 更新状态 + cls._save_run_state(state) time.sleep(2) - # 进程结束后,最后读取一次日志 + if os.path.exists(twitter_actions_log): cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter") if os.path.exists(reddit_actions_log): cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit") - # 进程结束 + exit_code = process.returncode if exit_code == 0: state.runner_status = RunnerStatus.COMPLETED state.completed_at = datetime.now().isoformat() - logger.info(f"模拟完成: {simulation_id}") + logger.info(f"Simulation completed: {simulation_id}") else: state.runner_status = RunnerStatus.FAILED - # 从主日志文件读取错误信息 + main_log_path = os.path.join(sim_dir, "simulation.log") error_info = "" try: if os.path.exists(main_log_path): with open(main_log_path, 'r', encoding='utf-8') as f: - error_info = f.read()[-2000:] # 取最后2000字符 + error_info = f.read()[-2000:] except Exception: pass - state.error = f"进程退出码: {exit_code}, 错误: {error_info}" - logger.error(f"模拟失败: {simulation_id}, error={state.error}") + state.error = f"Process exit code: {exit_code}, error: {error_info}" + logger.error(f"Simulation failed: {simulation_id}, error={state.error}") state.twitter_running = False state.reddit_running = False cls._save_run_state(state) except Exception as e: - logger.error(f"监控线程异常: {simulation_id}, error={str(e)}") + logger.error(f"Monitor thread exception: {simulation_id}, error={str(e)}") state.runner_status = RunnerStatus.FAILED state.error = str(e) cls._save_run_state(state) finally: - # 停止图谱记忆更新器 + if cls._graph_memory_enabled.get(simulation_id, False): try: ZepGraphMemoryManager.stop_updater(simulation_id) - logger.info(f"已停止图谱记忆更新: simulation_id={simulation_id}") + logger.info(f"Stopped graph memory updates: simulation_id={simulation_id}") except Exception as e: - logger.error(f"停止图谱记忆更新器失败: {e}") + logger.error(f"Failed to stop graph memory updater: {e}") cls._graph_memory_enabled.pop(simulation_id, None) - # 清理进程资源 + cls._processes.pop(simulation_id, None) cls._action_queues.pop(simulation_id, None) - # 关闭日志文件句柄 + if simulation_id in cls._stdout_files: try: cls._stdout_files[simulation_id].close() @@ -583,19 +560,8 @@ def _read_action_log( state: SimulationRunState, platform: str ) -> int: - """ - 读取动作日志文件 - - Args: - log_path: 日志文件路径 - position: 上次读取位置 - state: 运行状态对象 - platform: 平台名称 (twitter/reddit) - - Returns: - 新的读取位置 - """ - # 检查是否启用了图谱记忆更新 + """Read action log.""" + graph_memory_enabled = cls._graph_memory_enabled.get(state.simulation_id, False) graph_updater = None if graph_memory_enabled: @@ -610,36 +576,36 @@ def _read_action_log( try: action_data = json.loads(line) - # 处理事件类型的条目 + if "event_type" in action_data: event_type = action_data.get("event_type") - # 检测 simulation_end 事件,标记平台已完成 + if event_type == "simulation_end": if platform == "twitter": state.twitter_completed = True state.twitter_running = False - logger.info(f"Twitter 模拟已完成: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}") + logger.info(f"Twitter simulation completed: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}") elif platform == "reddit": state.reddit_completed = True state.reddit_running = False - logger.info(f"Reddit 模拟已完成: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}") + logger.info(f"Reddit simulation completed: {state.simulation_id}, total_rounds={action_data.get('total_rounds')}, total_actions={action_data.get('total_actions')}") + + + - # 检查是否所有启用的平台都已完成 - # 如果只运行了一个平台,只检查那个平台 - # 如果运行了两个平台,需要两个都完成 all_completed = cls._check_all_platforms_completed(state) if all_completed: state.runner_status = RunnerStatus.COMPLETED state.completed_at = datetime.now().isoformat() - logger.info(f"所有平台模拟已完成: {state.simulation_id}") + logger.info(f"All platform simulations completed: {state.simulation_id}") + - # 更新轮次信息(从 round_end 事件) elif event_type == "round_end": round_num = action_data.get("round", 0) simulated_hours = action_data.get("simulated_hours", 0) - # 更新各平台独立的轮次和时间 + if platform == "twitter": if round_num > state.twitter_current_round: state.twitter_current_round = round_num @@ -649,10 +615,10 @@ def _read_action_log( state.reddit_current_round = round_num state.reddit_simulated_hours = simulated_hours - # 总体轮次取两个平台的最大值 + if round_num > state.current_round: state.current_round = round_num - # 总体时间取两个平台的最大值 + state.simulated_hours = max(state.twitter_simulated_hours, state.reddit_simulated_hours) continue @@ -670,11 +636,11 @@ def _read_action_log( ) state.add_action(action) - # 更新轮次 + if action.round_num and action.round_num > state.current_round: state.current_round = action.round_num - # 如果启用了图谱记忆更新,将活动发送到Zep + if graph_updater: graph_updater.add_activity_from_dict(action_data, platform) @@ -682,52 +648,38 @@ def _read_action_log( pass return f.tell() except Exception as e: - logger.warning(f"读取动作日志失败: {log_path}, error={e}") + logger.warning(f"Failed to read action log: {log_path}, error={e}") return position @classmethod def _check_all_platforms_completed(cls, state: SimulationRunState) -> bool: - """ - 检查所有启用的平台是否都已完成模拟 - - 通过检查对应的 actions.jsonl 文件是否存在来判断平台是否被启用 - - Returns: - True 如果所有启用的平台都已完成 - """ + """Check all platforms completed.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id) twitter_log = os.path.join(sim_dir, "twitter", "actions.jsonl") reddit_log = os.path.join(sim_dir, "reddit", "actions.jsonl") - # 检查哪些平台被启用(通过文件是否存在判断) + twitter_enabled = os.path.exists(twitter_log) reddit_enabled = os.path.exists(reddit_log) - # 如果平台被启用但未完成,则返回 False + if twitter_enabled and not state.twitter_completed: return False if reddit_enabled and not state.reddit_completed: return False - # 至少有一个平台被启用且已完成 + return twitter_enabled or reddit_enabled @classmethod def _terminate_process(cls, process: subprocess.Popen, simulation_id: str, timeout: int = 10): - """ - 跨平台终止进程及其子进程 - - Args: - process: 要终止的进程 - simulation_id: 模拟ID(用于日志) - timeout: 等待进程退出的超时时间(秒) - """ + """Terminate Process.""" if IS_WINDOWS: - # Windows: 使用 taskkill 命令终止进程树 - # /F = 强制终止, /T = 终止进程树(包括子进程) - logger.info(f"终止进程树 (Windows): simulation={simulation_id}, pid={process.pid}") + + + logger.info(f"Terminating process tree (Windows): simulation={simulation_id}, pid={process.pid}") try: - # 先尝试优雅终止 + subprocess.run( ['taskkill', '/PID', str(process.pid), '/T'], capture_output=True, @@ -736,8 +688,8 @@ def _terminate_process(cls, process: subprocess.Popen, simulation_id: str, timeo try: process.wait(timeout=timeout) except subprocess.TimeoutExpired: - # 强制终止 - logger.warning(f"进程未响应,强制终止: {simulation_id}") + + logger.warning(f"Process did not respond; force-killing: {simulation_id}") subprocess.run( ['taskkill', '/F', '/PID', str(process.pid), '/T'], capture_output=True, @@ -745,53 +697,53 @@ def _terminate_process(cls, process: subprocess.Popen, simulation_id: str, timeo ) process.wait(timeout=5) except Exception as e: - logger.warning(f"taskkill 失败,尝试 terminate: {e}") + logger.warning(f"taskkill failed; trying terminate instead: {e}") process.terminate() try: process.wait(timeout=5) except subprocess.TimeoutExpired: process.kill() else: - # Unix: 使用进程组终止 - # 由于使用了 start_new_session=True,进程组 ID 等于主进程 PID + + pgid = os.getpgid(process.pid) - logger.info(f"终止进程组 (Unix): simulation={simulation_id}, pgid={pgid}") + logger.info(f"Terminating process group (Unix): simulation={simulation_id}, pgid={pgid}") + - # 先发送 SIGTERM 给整个进程组 os.killpg(pgid, signal.SIGTERM) try: process.wait(timeout=timeout) except subprocess.TimeoutExpired: - # 如果超时后还没结束,强制发送 SIGKILL - logger.warning(f"进程组未响应 SIGTERM,强制终止: {simulation_id}") + + logger.warning(f"Process group did not respond to SIGTERM; force-killing: {simulation_id}") os.killpg(pgid, signal.SIGKILL) process.wait(timeout=5) @classmethod def stop_simulation(cls, simulation_id: str) -> SimulationRunState: - """停止模拟""" + """Stop simulation.""" state = cls.get_run_state(simulation_id) if not state: - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation does not exist: {simulation_id}") if state.runner_status not in [RunnerStatus.RUNNING, RunnerStatus.PAUSED]: - raise ValueError(f"模拟未在运行: {simulation_id}, status={state.runner_status}") + raise ValueError(f"Simulation is not running: {simulation_id}, status={state.runner_status}") state.runner_status = RunnerStatus.STOPPING cls._save_run_state(state) - # 终止进程 + process = cls._processes.get(simulation_id) if process and process.poll() is None: try: cls._terminate_process(process, simulation_id) except ProcessLookupError: - # 进程已经不存在 + pass except Exception as e: - logger.error(f"终止进程组失败: {simulation_id}, error={e}") - # 回退到直接终止进程 + logger.error(f"Failed to terminate process group: {simulation_id}, error={e}") + try: process.terminate() process.wait(timeout=5) @@ -804,16 +756,16 @@ def stop_simulation(cls, simulation_id: str) -> SimulationRunState: state.completed_at = datetime.now().isoformat() cls._save_run_state(state) - # 停止图谱记忆更新器 + if cls._graph_memory_enabled.get(simulation_id, False): try: ZepGraphMemoryManager.stop_updater(simulation_id) - logger.info(f"已停止图谱记忆更新: simulation_id={simulation_id}") + logger.info(f"Stopped graph memory updates: simulation_id={simulation_id}") except Exception as e: - logger.error(f"停止图谱记忆更新器失败: {e}") + logger.error(f"Failed to stop graph memory updater: {e}") cls._graph_memory_enabled.pop(simulation_id, None) - logger.info(f"模拟已停止: {simulation_id}") + logger.info(f"Simulation stopped: {simulation_id}") return state @classmethod @@ -825,16 +777,7 @@ def _read_actions_from_file( agent_id: Optional[int] = None, round_num: Optional[int] = None ) -> List[AgentAction]: - """ - 从单个动作文件中读取动作 - - Args: - file_path: 动作日志文件路径 - default_platform: 默认平台(当动作记录中没有 platform 字段时使用) - platform_filter: 过滤平台 - agent_id: 过滤 Agent ID - round_num: 过滤轮次 - """ + """Read actions from file.""" if not os.path.exists(file_path): return [] @@ -849,18 +792,18 @@ def _read_actions_from_file( try: data = json.loads(line) - # 跳过非动作记录(如 simulation_start, round_start, round_end 等事件) + if "event_type" in data: continue - # 跳过没有 agent_id 的记录(非 Agent 动作) + if "agent_id" not in data: continue - # 获取平台:优先使用记录中的 platform,否则使用默认平台 + record_platform = data.get("platform") or default_platform or "" - # 过滤 + if platform_filter and record_platform != platform_filter: continue if agent_id is not None and data.get("agent_id") != agent_id: @@ -893,55 +836,44 @@ def get_all_actions( agent_id: Optional[int] = None, round_num: Optional[int] = None ) -> List[AgentAction]: - """ - 获取所有平台的完整动作历史(无分页限制) - - Args: - simulation_id: 模拟ID - platform: 过滤平台(twitter/reddit) - agent_id: 过滤Agent - round_num: 过滤轮次 - - Returns: - 完整的动作列表(按时间戳排序,新的在前) - """ + """Get all actions.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) actions = [] - # 读取 Twitter 动作文件(根据文件路径自动设置 platform 为 twitter) + twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl") if not platform or platform == "twitter": actions.extend(cls._read_actions_from_file( twitter_actions_log, - default_platform="twitter", # 自动填充 platform 字段 + default_platform="twitter", platform_filter=platform, agent_id=agent_id, round_num=round_num )) - # 读取 Reddit 动作文件(根据文件路径自动设置 platform 为 reddit) + reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl") if not platform or platform == "reddit": actions.extend(cls._read_actions_from_file( reddit_actions_log, - default_platform="reddit", # 自动填充 platform 字段 + default_platform="reddit", platform_filter=platform, agent_id=agent_id, round_num=round_num )) - # 如果分平台文件不存在,尝试读取旧的单一文件格式 + if not actions: actions_log = os.path.join(sim_dir, "actions.jsonl") actions = cls._read_actions_from_file( actions_log, - default_platform=None, # 旧格式文件中应该有 platform 字段 + default_platform=None, platform_filter=platform, agent_id=agent_id, round_num=round_num ) - # 按时间戳排序(新的在前) + actions.sort(key=lambda x: x.timestamp, reverse=True) return actions @@ -956,20 +888,7 @@ def get_actions( agent_id: Optional[int] = None, round_num: Optional[int] = None ) -> List[AgentAction]: - """ - 获取动作历史(带分页) - - Args: - simulation_id: 模拟ID - limit: 返回数量限制 - offset: 偏移量 - platform: 过滤平台 - agent_id: 过滤Agent - round_num: 过滤轮次 - - Returns: - 动作列表 - """ + """Get actions.""" actions = cls.get_all_actions( simulation_id=simulation_id, platform=platform, @@ -977,7 +896,7 @@ def get_actions( round_num=round_num ) - # 分页 + return actions[offset:offset + limit] @classmethod @@ -987,20 +906,10 @@ def get_timeline( start_round: int = 0, end_round: Optional[int] = None ) -> List[Dict[str, Any]]: - """ - 获取模拟时间线(按轮次汇总) - - Args: - simulation_id: 模拟ID - start_round: 起始轮次 - end_round: 结束轮次 - - Returns: - 每轮的汇总信息 - """ + """Get timeline.""" actions = cls.get_actions(simulation_id, limit=10000) - # 按轮次分组 + rounds: Dict[int, Dict[str, Any]] = {} for action in actions: @@ -1033,7 +942,7 @@ def get_timeline( r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1 r["last_action_time"] = action.timestamp - # 转换为列表 + result = [] for round_num in sorted(rounds.keys()): r = rounds[round_num] @@ -1053,12 +962,7 @@ def get_timeline( @classmethod def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]: - """ - 获取每个Agent的统计信息 - - Returns: - Agent统计列表 - """ + """Get agent stats.""" actions = cls.get_actions(simulation_id, limit=10000) agent_stats: Dict[int, Dict[str, Any]] = {} @@ -1089,59 +993,39 @@ def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]: stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1 stats["last_action_time"] = action.timestamp - # 按总动作数排序 + result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True) return result @classmethod def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]: - """ - 清理模拟的运行日志(用于强制重新开始模拟) - - 会删除以下文件: - - run_state.json - - twitter/actions.jsonl - - reddit/actions.jsonl - - simulation.log - - stdout.log / stderr.log - - twitter_simulation.db(模拟数据库) - - reddit_simulation.db(模拟数据库) - - env_status.json(环境状态) - - 注意:不会删除配置文件(simulation_config.json)和 profile 文件 - - Args: - simulation_id: 模拟ID - - Returns: - 清理结果信息 - """ + """Cleanup Simulation Logs.""" import shutil sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - return {"success": True, "message": "模拟目录不存在,无需清理"} + return {"success": True, "message": "Simulation directory does not exist; nothing to clean up"} cleaned_files = [] errors = [] - # 要删除的文件列表(包括数据库文件) + files_to_delete = [ "run_state.json", "simulation.log", "stdout.log", "stderr.log", - "twitter_simulation.db", # Twitter 平台数据库 - "reddit_simulation.db", # Reddit 平台数据库 - "env_status.json", # 环境状态文件 + "twitter_simulation.db", + "reddit_simulation.db", + "env_status.json", ] - # 要删除的目录列表(包含动作日志) + dirs_to_clean = ["twitter", "reddit"] - # 删除文件 + for filename in files_to_delete: file_path = os.path.join(sim_dir, filename) if os.path.exists(file_path): @@ -1149,9 +1033,9 @@ def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]: os.remove(file_path) cleaned_files.append(filename) except Exception as e: - errors.append(f"删除 {filename} 失败: {str(e)}") + errors.append(f"Failed to delete {filename}: {str(e)}") + - # 清理平台目录中的动作日志 for dir_name in dirs_to_clean: dir_path = os.path.join(sim_dir, dir_name) if os.path.exists(dir_path): @@ -1161,13 +1045,13 @@ def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]: os.remove(actions_file) cleaned_files.append(f"{dir_name}/actions.jsonl") except Exception as e: - errors.append(f"删除 {dir_name}/actions.jsonl 失败: {str(e)}") + errors.append(f"Failed to delete {dir_name}/actions.jsonl: {str(e)}") + - # 清理内存中的运行状态 if simulation_id in cls._run_states: del cls._run_states[simulation_id] - logger.info(f"清理模拟日志完成: {simulation_id}, 删除文件: {cleaned_files}") + logger.info(f"Simulation log cleanup completed: {simulation_id}, deleted files: {cleaned_files}") return { "success": len(errors) == 0, @@ -1175,71 +1059,67 @@ def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]: "errors": errors if errors else None } - # 防止重复清理的标志 + _cleanup_done = False @classmethod def cleanup_all_simulations(cls): - """ - 清理所有运行中的模拟进程 + """Cleanup All Simulations.""" - 在服务器关闭时调用,确保所有子进程被终止 - """ - # 防止重复清理 if cls._cleanup_done: return cls._cleanup_done = True - # 检查是否有内容需要清理(避免空进程的进程打印无用日志) + has_processes = bool(cls._processes) has_updaters = bool(cls._graph_memory_enabled) if not has_processes and not has_updaters: - return # 没有需要清理的内容,静默返回 + return + + logger.info("Cleaning up all simulation processes...") - logger.info("正在清理所有模拟进程...") - # 首先停止所有图谱记忆更新器(stop_all 内部会打印日志) try: ZepGraphMemoryManager.stop_all() except Exception as e: - logger.error(f"停止图谱记忆更新器失败: {e}") + logger.error(f"Failed to stop graph memory updater: {e}") cls._graph_memory_enabled.clear() - # 复制字典以避免在迭代时修改 + processes = list(cls._processes.items()) for simulation_id, process in processes: try: - if process.poll() is None: # 进程仍在运行 - logger.info(f"终止模拟进程: {simulation_id}, pid={process.pid}") + if process.poll() is None: + logger.info(f"Terminating simulation process: {simulation_id}, pid={process.pid}") try: - # 使用跨平台的进程终止方法 + cls._terminate_process(process, simulation_id, timeout=5) except (ProcessLookupError, OSError): - # 进程可能已经不存在,尝试直接终止 + try: process.terminate() process.wait(timeout=3) except Exception: process.kill() - # 更新 run_state.json + state = cls.get_run_state(simulation_id) if state: state.runner_status = RunnerStatus.STOPPED state.twitter_running = False state.reddit_running = False state.completed_at = datetime.now().isoformat() - state.error = "服务器关闭,模拟被终止" + state.error = "Server shutdown terminated the simulation" cls._save_run_state(state) - # 同时更新 state.json,将状态设为 stopped + try: sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) state_file = os.path.join(sim_dir, "state.json") - logger.info(f"尝试更新 state.json: {state_file}") + logger.info(f"Attempting to update state.json: {state_file}") if os.path.exists(state_file): with open(state_file, 'r', encoding='utf-8') as f: state_data = json.load(f) @@ -1247,16 +1127,16 @@ def cleanup_all_simulations(cls): state_data['updated_at'] = datetime.now().isoformat() with open(state_file, 'w', encoding='utf-8') as f: json.dump(state_data, f, indent=2, ensure_ascii=False) - logger.info(f"已更新 state.json 状态为 stopped: {simulation_id}") + logger.info(f"Updated state.json status to stopped: {simulation_id}") else: - logger.warning(f"state.json 不存在: {state_file}") + logger.warning(f"state.json does not exist: {state_file}") except Exception as state_err: - logger.warning(f"更新 state.json 失败: {simulation_id}, error={state_err}") + logger.warning(f"Failed to update state.json: {simulation_id}, error={state_err}") except Exception as e: - logger.error(f"清理进程失败: {simulation_id}, error={e}") + logger.error(f"Failed to clean up process: {simulation_id}, error={e}") + - # 清理文件句柄 for simulation_id, file_handle in list(cls._stdout_files.items()): try: if file_handle: @@ -1273,109 +1153,95 @@ def cleanup_all_simulations(cls): pass cls._stderr_files.clear() - # 清理内存中的状态 + cls._processes.clear() cls._action_queues.clear() - logger.info("模拟进程清理完成") + logger.info("Simulation process cleanup completed") @classmethod def register_cleanup(cls): - """ - 注册清理函数 - - 在 Flask 应用启动时调用,确保服务器关闭时清理所有模拟进程 - """ + """Register cleanup.""" global _cleanup_registered if _cleanup_registered: return - # Flask debug 模式下,只在 reloader 子进程中注册清理(实际运行应用的进程) - # WERKZEUG_RUN_MAIN=true 表示是 reloader 子进程 - # 如果不是 debug 模式,则没有这个环境变量,也需要注册 + + + is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true' is_debug_mode = os.environ.get('FLASK_DEBUG') == '1' or os.environ.get('WERKZEUG_RUN_MAIN') is not None - # 在 debug 模式下,只在 reloader 子进程中注册;非 debug 模式下始终注册 + if is_debug_mode and not is_reloader_process: - _cleanup_registered = True # 标记已注册,防止子进程再次尝试 + _cleanup_registered = True return - # 保存原有的信号处理器 + original_sigint = signal.getsignal(signal.SIGINT) original_sigterm = signal.getsignal(signal.SIGTERM) - # SIGHUP 只在 Unix 系统存在(macOS/Linux),Windows 没有 + original_sighup = None has_sighup = hasattr(signal, 'SIGHUP') if has_sighup: original_sighup = signal.getsignal(signal.SIGHUP) def cleanup_handler(signum=None, frame=None): - """信号处理器:先清理模拟进程,再调用原处理器""" - # 只有在有进程需要清理时才打印日志 + """Cleanup Handler.""" + if cls._processes or cls._graph_memory_enabled: - logger.info(f"收到信号 {signum},开始清理...") + logger.info(f"Received signal {signum}; starting cleanup...") cls.cleanup_all_simulations() - # 调用原有的信号处理器,让 Flask 正常退出 + if signum == signal.SIGINT and callable(original_sigint): original_sigint(signum, frame) elif signum == signal.SIGTERM and callable(original_sigterm): original_sigterm(signum, frame) elif has_sighup and signum == signal.SIGHUP: - # SIGHUP: 终端关闭时发送 + if callable(original_sighup): original_sighup(signum, frame) else: - # 默认行为:正常退出 + sys.exit(0) else: - # 如果原处理器不可调用(如 SIG_DFL),则使用默认行为 + raise KeyboardInterrupt - # 注册 atexit 处理器(作为备用) + atexit.register(cls.cleanup_all_simulations) - # 注册信号处理器(仅在主线程中) + try: - # SIGTERM: kill 命令默认信号 + signal.signal(signal.SIGTERM, cleanup_handler) # SIGINT: Ctrl+C signal.signal(signal.SIGINT, cleanup_handler) - # SIGHUP: 终端关闭(仅 Unix 系统) + if has_sighup: signal.signal(signal.SIGHUP, cleanup_handler) except ValueError: - # 不在主线程中,只能使用 atexit - logger.warning("无法注册信号处理器(不在主线程),仅使用 atexit") + + logger.warning("Unable to register signal handler (not in main thread); using atexit only") _cleanup_registered = True @classmethod def get_running_simulations(cls) -> List[str]: - """ - 获取所有正在运行的模拟ID列表 - """ + """Get running simulations.""" running = [] for sim_id, process in cls._processes.items(): if process.poll() is None: running.append(sim_id) return running - # ============== Interview 功能 ============== + @classmethod def check_env_alive(cls, simulation_id: str) -> bool: - """ - 检查模拟环境是否存活(可以接收Interview命令) - - Args: - simulation_id: 模拟ID - - Returns: - True 表示环境存活,False 表示环境已关闭 - """ + """Check env alive.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): return False @@ -1385,15 +1251,7 @@ def check_env_alive(cls, simulation_id: str) -> bool: @classmethod def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]: - """ - 获取模拟环境的详细状态信息 - - Args: - simulation_id: 模拟ID - - Returns: - 状态详情字典,包含 status, twitter_available, reddit_available, timestamp - """ + """Get env status detail.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) status_file = os.path.join(sim_dir, "env_status.json") @@ -1428,36 +1286,17 @@ def interview_agent( platform: str = None, timeout: float = 60.0 ) -> Dict[str, Any]: - """ - 采访单个Agent - - Args: - simulation_id: 模拟ID - agent_id: Agent ID - prompt: 采访问题 - platform: 指定平台(可选) - - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 - - None: 双平台模拟时同时采访两个平台,返回整合结果 - timeout: 超时时间(秒) - - Returns: - 采访结果字典 - - Raises: - ValueError: 模拟不存在或环境未运行 - TimeoutError: 等待响应超时 - """ + """Interview Agent.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation does not exist: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): - raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}") + raise ValueError(f"Simulation environment is not running or has been closed; cannot execute interview: {simulation_id}") - logger.info(f"发送Interview命令: simulation_id={simulation_id}, agent_id={agent_id}, platform={platform}") + logger.info(f"Sending interview command: simulation_id={simulation_id}, agent_id={agent_id}, platform={platform}") response = ipc_client.send_interview( agent_id=agent_id, @@ -1491,35 +1330,17 @@ def interview_agents_batch( platform: str = None, timeout: float = 120.0 ) -> Dict[str, Any]: - """ - 批量采访多个Agent - - Args: - simulation_id: 模拟ID - interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)} - platform: 默认平台(可选,会被每个采访项的platform覆盖) - - "twitter": 默认只采访Twitter平台 - - "reddit": 默认只采访Reddit平台 - - None: 双平台模拟时每个Agent同时采访两个平台 - timeout: 超时时间(秒) - - Returns: - 批量采访结果字典 - - Raises: - ValueError: 模拟不存在或环境未运行 - TimeoutError: 等待响应超时 - """ + """Interview Agents Batch.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation does not exist: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): - raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}") + raise ValueError(f"Simulation environment is not running or has been closed; cannot execute interview: {simulation_id}") - logger.info(f"发送批量Interview命令: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}") + logger.info(f"Sending batch interview command: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}") response = ipc_client.send_batch_interview( interviews=interviews, @@ -1550,40 +1371,24 @@ def interview_all_agents( platform: str = None, timeout: float = 180.0 ) -> Dict[str, Any]: - """ - 采访所有Agent(全局采访) - - 使用相同的问题采访模拟中的所有Agent - - Args: - simulation_id: 模拟ID - prompt: 采访问题(所有Agent使用相同问题) - platform: 指定平台(可选) - - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 - - None: 双平台模拟时每个Agent同时采访两个平台 - timeout: 超时时间(秒) - - Returns: - 全局采访结果字典 - """ + """Interview All Agents.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation does not exist: {simulation_id}") - # 从配置文件获取所有Agent信息 + config_path = os.path.join(sim_dir, "simulation_config.json") if not os.path.exists(config_path): - raise ValueError(f"模拟配置不存在: {simulation_id}") + raise ValueError(f"Simulation config does not exist: {simulation_id}") with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) agent_configs = config.get("agent_configs", []) if not agent_configs: - raise ValueError(f"模拟配置中没有Agent: {simulation_id}") + raise ValueError(f"Simulation config contains no agents: {simulation_id}") - # 构建批量采访列表 + interviews = [] for agent_config in agent_configs: agent_id = agent_config.get("agent_id") @@ -1593,7 +1398,7 @@ def interview_all_agents( "prompt": prompt }) - logger.info(f"发送全局Interview命令: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}") + logger.info(f"Sending global interview command: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}") return cls.interview_agents_batch( simulation_id=simulation_id, @@ -1608,46 +1413,35 @@ def close_simulation_env( simulation_id: str, timeout: float = 30.0 ) -> Dict[str, Any]: - """ - 关闭模拟环境(而不是停止模拟进程) - - 向模拟发送关闭环境命令,使其优雅退出等待命令模式 - - Args: - simulation_id: 模拟ID - timeout: 超时时间(秒) - - Returns: - 操作结果字典 - """ + """Close simulation env.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): - raise ValueError(f"模拟不存在: {simulation_id}") + raise ValueError(f"Simulation does not exist: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): return { "success": True, - "message": "环境已经关闭" + "message": "Environment is already closed" } - logger.info(f"发送关闭环境命令: simulation_id={simulation_id}") + logger.info(f"Sending close-environment command: simulation_id={simulation_id}") try: response = ipc_client.send_close_env(timeout=timeout) return { "success": response.status.value == "completed", - "message": "环境关闭命令已发送", + "message": "Close-environment command sent", "result": response.result, "timestamp": response.timestamp } except TimeoutError: - # 超时可能是因为环境正在关闭 + return { "success": True, - "message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)" + "message": "Close-environment command sent (response timed out; environment may already be closing)" } @classmethod @@ -1658,7 +1452,7 @@ def _get_interview_history_from_db( agent_id: Optional[int] = None, limit: int = 100 ) -> List[Dict[str, Any]]: - """从单个数据库获取Interview历史""" + """Get interview history from db.""" import sqlite3 if not os.path.exists(db_path): @@ -1704,7 +1498,7 @@ def _get_interview_history_from_db( conn.close() except Exception as e: - logger.error(f"读取Interview历史失败 ({platform_name}): {e}") + logger.error(f"Failed to read interview history ({platform_name}): {e}") return results @@ -1716,30 +1510,16 @@ def get_interview_history( agent_id: Optional[int] = None, limit: int = 100 ) -> List[Dict[str, Any]]: - """ - 获取Interview历史记录(从数据库读取) - - Args: - simulation_id: 模拟ID - platform: 平台类型(reddit/twitter/None) - - "reddit": 只获取Reddit平台的历史 - - "twitter": 只获取Twitter平台的历史 - - None: 获取两个平台的所有历史 - agent_id: 指定Agent ID(可选,只获取该Agent的历史) - limit: 每个平台返回数量限制 - - Returns: - Interview历史记录列表 - """ + """Get interview history.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) results = [] - # 确定要查询的平台 + if platform in ("reddit", "twitter"): platforms = [platform] else: - # 不指定platform时,查询两个平台 + platforms = ["twitter", "reddit"] for p in platforms: @@ -1752,12 +1532,11 @@ def get_interview_history( ) results.extend(platform_results) - # 按时间降序排序 + results.sort(key=lambda x: x.get("timestamp", ""), reverse=True) - # 如果查询了多个平台,限制总数 + if len(platforms) > 1 and len(results) > limit: results = results[:limit] return results - diff --git a/backend/app/services/text_processor.py b/backend/app/services/text_processor.py index 91e32acc5..d3073eb96 100644 --- a/backend/app/services/text_processor.py +++ b/backend/app/services/text_processor.py @@ -1,17 +1,15 @@ -""" -文本处理服务 -""" +"""Text processing service.""" from typing import List, Optional from ..utils.file_parser import FileParser, split_text_into_chunks class TextProcessor: - """文本处理器""" + """Text Processor.""" @staticmethod def extract_from_files(file_paths: List[str]) -> str: - """从多个文件提取文本""" + """Extract from files.""" return FileParser.extract_from_multiple(file_paths) @staticmethod @@ -20,41 +18,21 @@ def split_text( chunk_size: int = 500, overlap: int = 50 ) -> List[str]: - """ - 分割文本 - - Args: - text: 原始文本 - chunk_size: 块大小 - overlap: 重叠大小 - - Returns: - 文本块列表 - """ + """Split text.""" return split_text_into_chunks(text, chunk_size, overlap) @staticmethod def preprocess_text(text: str) -> str: - """ - 预处理文本 - - 移除多余空白 - - 标准化换行 - - Args: - text: 原始文本 - - Returns: - 处理后的文本 - """ + """Preprocess text.""" import re - # 标准化换行 + text = text.replace('\r\n', '\n').replace('\r', '\n') - # 移除连续空行(保留最多两个换行) + text = re.sub(r'\n{3,}', '\n\n', text) - # 移除行首行尾空白 + lines = [line.strip() for line in text.split('\n')] text = '\n'.join(lines) @@ -62,7 +40,7 @@ def preprocess_text(text: str) -> str: @staticmethod def get_text_stats(text: str) -> dict: - """获取文本统计信息""" + """Get text stats.""" return { "total_chars": len(text), "total_lines": text.count('\n') + 1, diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index 71661be49..8a15109af 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -1,35 +1,30 @@ -""" -Zep实体读取与过滤服务 -从Zep图谱中读取节点,筛选出符合预定义实体类型的节点 -""" +"""Zep entity reading and filtering service.""" import time from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config +from .graph_provider import create_graph_provider from ..utils.logger import get_logger -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges logger = get_logger('mirofish.zep_entity_reader') -# 用于泛型返回类型 + T = TypeVar('T') @dataclass class EntityNode: - """实体节点数据结构""" + """Entity Node.""" uuid: str name: str labels: List[str] summary: str attributes: Dict[str, Any] - # 相关的边信息 + related_edges: List[Dict[str, Any]] = field(default_factory=list) - # 相关的其他节点信息 + related_nodes: List[Dict[str, Any]] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: @@ -44,7 +39,7 @@ def to_dict(self) -> Dict[str, Any]: } def get_entity_type(self) -> Optional[str]: - """获取实体类型(排除默认的Entity标签)""" + """Get entity type.""" for label in self.labels: if label not in ["Entity", "Node"]: return label @@ -53,7 +48,7 @@ def get_entity_type(self) -> Optional[str]: @dataclass class FilteredEntities: - """过滤后的实体集合""" + """Filtered Entities.""" entities: List[EntityNode] entity_types: Set[str] total_count: int @@ -69,21 +64,11 @@ def to_dict(self) -> Dict[str, Any]: class ZepEntityReader: - """ - Zep实体读取与过滤服务 - - 主要功能: - 1. 从Zep图谱读取所有节点 - 2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点) - 3. 获取每个实体的相关边和关联节点信息 - """ + """Zep Entity Reader.""" def __init__(self, api_key: Optional[str] = None): self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.provider = create_graph_provider() def _call_with_retry( self, @@ -92,18 +77,7 @@ def _call_with_retry( max_retries: int = 3, initial_delay: float = 2.0 ) -> T: - """ - 带重试机制的Zep API调用 - - Args: - func: 要执行的函数(无参数的lambda或callable) - operation_name: 操作名称,用于日志 - max_retries: 最大重试次数(默认3次,即最多尝试3次) - initial_delay: 初始延迟秒数 - - Returns: - API调用结果 - """ + """Call with retry.""" last_exception = None delay = initial_delay @@ -114,61 +88,45 @@ def _call_with_retry( last_exception = e if attempt < max_retries - 1: logger.warning( - f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, " - f"{delay:.1f}秒后重试..." + f"Zep {operation_name} attempt {attempt + 1} failed: {str(e)[:100]}. " + f"Retrying in {delay:.1f}s..." ) time.sleep(delay) - delay *= 2 # 指数退避 + delay *= 2 else: - logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}") + logger.error(f"Zep {operation_name} still failed after {max_retries} attempts: {str(e)}") raise last_exception def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: - """ - 获取图谱的所有节点(分页获取) - - Args: - graph_id: 图谱ID - - Returns: - 节点列表 - """ - logger.info(f"获取图谱 {graph_id} 的所有节点...") + """Get all nodes.""" + logger.info(f"Loading all nodes for graph {graph_id}...") - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.provider.get_all_nodes(graph_id) nodes_data = [] for node in nodes: nodes_data.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + "uuid": node.uuid, "name": node.name or "", "labels": node.labels or [], "summary": node.summary or "", "attributes": node.attributes or {}, }) - logger.info(f"共获取 {len(nodes_data)} 个节点") + logger.info(f"Loaded {len(nodes_data)} nodes") return nodes_data def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: - """ - 获取图谱的所有边(分页获取) + """Get all edges.""" + logger.info(f"Loading all edges for graph {graph_id}...") - Args: - graph_id: 图谱ID - - Returns: - 边列表 - """ - logger.info(f"获取图谱 {graph_id} 的所有边...") - - edges = fetch_all_edges(self.client, graph_id) + edges = self.provider.get_all_edges(graph_id) edges_data = [] for edge in edges: edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "uuid": edge.uuid, "name": edge.name or "", "fact": edge.fact or "", "source_node_uuid": edge.source_node_uuid, @@ -176,30 +134,22 @@ def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: "attributes": edge.attributes or {}, }) - logger.info(f"共获取 {len(edges_data)} 条边") + logger.info(f"Loaded {len(edges_data)} edges") return edges_data - def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: - """ - 获取指定节点的所有相关边(带重试机制) - - Args: - node_uuid: 节点UUID - - Returns: - 边列表 - """ + def get_node_edges(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]: + """Get node edges.""" try: - # 使用重试机制调用Zep API + edges = self._call_with_retry( - func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), - operation_name=f"获取节点边(node={node_uuid[:8]}...)" + func=lambda: self.provider.get_node_edges(graph_id, node_uuid), + operation_name=f"get node edges (node={node_uuid[:8]}...)" ) edges_data = [] for edge in edges: edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "uuid": edge.uuid, "name": edge.name or "", "fact": edge.fact or "", "source_node_uuid": edge.source_node_uuid, @@ -209,7 +159,7 @@ def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: return edges_data except Exception as e: - logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}") + logger.warning(f"Failed to load edges for node {node_uuid}: {str(e)}") return [] def filter_defined_entities( @@ -218,48 +168,34 @@ def filter_defined_entities( defined_entity_types: Optional[List[str]] = None, enrich_with_edges: bool = True ) -> FilteredEntities: - """ - 筛选出符合预定义实体类型的节点 + """Filter defined entities.""" + logger.info(f"Starting entity filtering for graph {graph_id}...") - 筛选逻辑: - - 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过 - - 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留 - Args: - graph_id: 图谱ID - defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型) - enrich_with_edges: 是否获取每个实体的相关边信息 - - Returns: - FilteredEntities: 过滤后的实体集合 - """ - logger.info(f"开始筛选图谱 {graph_id} 的实体...") - - # 获取所有节点 all_nodes = self.get_all_nodes(graph_id) total_count = len(all_nodes) - # 获取所有边(用于后续关联查找) + all_edges = self.get_all_edges(graph_id) if enrich_with_edges else [] - # 构建节点UUID到节点数据的映射 + node_map = {n["uuid"]: n for n in all_nodes} - # 筛选符合条件的实体 + filtered_entities = [] entity_types_found = set() for node in all_nodes: labels = node.get("labels", []) - # 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签 + custom_labels = [l for l in labels if l not in ["Entity", "Node"]] if not custom_labels: - # 只有默认标签,跳过 + continue - # 如果指定了预定义类型,检查是否匹配 + if defined_entity_types: matching_labels = [l for l in custom_labels if l in defined_entity_types] if not matching_labels: @@ -270,7 +206,7 @@ def filter_defined_entities( entity_types_found.add(entity_type) - # 创建实体节点对象 + entity = EntityNode( uuid=node["uuid"], name=node["name"], @@ -279,7 +215,7 @@ def filter_defined_entities( attributes=node["attributes"], ) - # 获取相关边和节点 + if enrich_with_edges: related_edges = [] related_node_uuids = set() @@ -304,7 +240,7 @@ def filter_defined_entities( entity.related_edges = related_edges - # 获取关联节点的基本信息 + related_nodes = [] for related_uuid in related_node_uuids: if related_uuid in node_map: @@ -320,8 +256,10 @@ def filter_defined_entities( filtered_entities.append(entity) - logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, " - f"实体类型: {entity_types_found}") + logger.info( + f"Entity filtering completed: total_nodes={total_count}, " + f"matched={len(filtered_entities)}, entity_types={entity_types_found}" + ) return FilteredEntities( entities=filtered_entities, @@ -335,34 +273,25 @@ def get_entity_with_context( graph_id: str, entity_uuid: str ) -> Optional[EntityNode]: - """ - 获取单个实体及其完整上下文(边和关联节点,带重试机制) - - Args: - graph_id: 图谱ID - entity_uuid: 实体UUID - - Returns: - EntityNode或None - """ + """Get entity with context.""" try: - # 使用重试机制获取节点 + node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=entity_uuid), - operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" + func=lambda: self.provider.get_node(graph_id, entity_uuid), + operation_name=f"Fetch node details (uuid={entity_uuid[:8]}...)" ) if not node: return None - # 获取节点的边 - edges = self.get_node_edges(entity_uuid) - # 获取所有节点用于关联查找 + edges = self.get_node_edges(graph_id, entity_uuid) + + all_nodes = self.get_all_nodes(graph_id) node_map = {n["uuid"]: n for n in all_nodes} - # 处理相关边和节点 + related_edges = [] related_node_uuids = set() @@ -384,7 +313,7 @@ def get_entity_with_context( }) related_node_uuids.add(edge["source_node_uuid"]) - # 获取关联节点信息 + related_nodes = [] for related_uuid in related_node_uuids: if related_uuid in node_map: @@ -397,7 +326,7 @@ def get_entity_with_context( }) return EntityNode( - uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + uuid=node.uuid, name=node.name or "", labels=node.labels or [], summary=node.summary or "", @@ -407,7 +336,7 @@ def get_entity_with_context( ) except Exception as e: - logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}") + logger.error(f"Failed to fetch entity {entity_uuid}: {str(e)}") return None def get_entities_by_type( @@ -416,22 +345,10 @@ def get_entities_by_type( entity_type: str, enrich_with_edges: bool = True ) -> List[EntityNode]: - """ - 获取指定类型的所有实体 - - Args: - graph_id: 图谱ID - entity_type: 实体类型(如 "Student", "PublicFigure" 等) - enrich_with_edges: 是否获取相关边信息 - - Returns: - 实体列表 - """ + """Get entities by type.""" result = self.filter_defined_entities( graph_id=graph_id, defined_entity_types=[entity_type], enrich_with_edges=enrich_with_edges ) return result.entities - - diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index a8f3cecd9..dc0431a5d 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -1,7 +1,4 @@ -""" -Zep图谱记忆更新服务 -将模拟中的Agent活动动态更新到Zep图谱中 -""" +"""Zep graph memory update service.""" import os import time @@ -12,9 +9,8 @@ from datetime import datetime from queue import Queue, Empty -from zep_cloud.client import Zep - from ..config import Config +from .graph_provider import create_graph_provider from ..utils.logger import get_logger logger = get_logger('mirofish.zep_graph_memory_updater') @@ -22,7 +18,7 @@ @dataclass class AgentActivity: - """Agent活动记录""" + """Agent Activity.""" platform: str # twitter / reddit agent_id: int agent_name: str @@ -32,13 +28,8 @@ class AgentActivity: timestamp: str def to_episode_text(self) -> str: - """ - 将活动转换为可以发送给Zep的文本描述 + """Convert the object to Episode Text.""" - 采用自然语言描述格式,让Zep能够从中提取实体和关系 - 不添加模拟相关的前缀,避免误导图谱更新 - """ - # 根据不同的动作类型生成不同的描述 action_descriptions = { "CREATE_POST": self._describe_create_post, "LIKE_POST": self._describe_like_post, @@ -57,222 +48,201 @@ def to_episode_text(self) -> str: describe_func = action_descriptions.get(self.action_type, self._describe_generic) description = describe_func() - # 直接返回 "agent名称: 活动描述" 格式,不添加模拟前缀 + return f"{self.agent_name}: {description}" def _describe_create_post(self) -> str: content = self.action_args.get("content", "") if content: - return f"发布了一条帖子:「{content}」" - return "发布了一条帖子" + return f"published a post: \"{content}\"" + return "published a post" def _describe_like_post(self) -> str: - """点赞帖子 - 包含帖子原文和作者信息""" + """Describe Like Post.""" post_content = self.action_args.get("post_content", "") post_author = self.action_args.get("post_author_name", "") if post_content and post_author: - return f"点赞了{post_author}的帖子:「{post_content}」" + return f"liked {post_author}'s post: \"{post_content}\"" elif post_content: - return f"点赞了一条帖子:「{post_content}」" + return f"liked a post: \"{post_content}\"" elif post_author: - return f"点赞了{post_author}的一条帖子" - return "点赞了一条帖子" + return f"liked a post from {post_author}" + return "liked a post" def _describe_dislike_post(self) -> str: - """踩帖子 - 包含帖子原文和作者信息""" + """Describe Dislike Post.""" post_content = self.action_args.get("post_content", "") post_author = self.action_args.get("post_author_name", "") if post_content and post_author: - return f"踩了{post_author}的帖子:「{post_content}」" + return f"downvoted {post_author}'s post: \"{post_content}\"" elif post_content: - return f"踩了一条帖子:「{post_content}」" + return f"downvoted a post: \"{post_content}\"" elif post_author: - return f"踩了{post_author}的一条帖子" - return "踩了一条帖子" + return f"downvoted a post from {post_author}" + return "downvoted a post" def _describe_repost(self) -> str: - """转发帖子 - 包含原帖内容和作者信息""" + """Describe Repost.""" original_content = self.action_args.get("original_content", "") original_author = self.action_args.get("original_author_name", "") if original_content and original_author: - return f"转发了{original_author}的帖子:「{original_content}」" + return f"reposted {original_author}'s post: \"{original_content}\"" elif original_content: - return f"转发了一条帖子:「{original_content}」" + return f"reposted a post: \"{original_content}\"" elif original_author: - return f"转发了{original_author}的一条帖子" - return "转发了一条帖子" + return f"reposted a post from {original_author}" + return "reposted a post" def _describe_quote_post(self) -> str: - """引用帖子 - 包含原帖内容、作者信息和引用评论""" + """Describe Quote Post.""" original_content = self.action_args.get("original_content", "") original_author = self.action_args.get("original_author_name", "") quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "") base = "" if original_content and original_author: - base = f"引用了{original_author}的帖子「{original_content}」" + base = f"quoted {original_author}'s post \"{original_content}\"" elif original_content: - base = f"引用了一条帖子「{original_content}」" + base = f"quoted a post \"{original_content}\"" elif original_author: - base = f"引用了{original_author}的一条帖子" + base = f"quoted a post from {original_author}" else: - base = "引用了一条帖子" + base = "quoted a post" if quote_content: - base += f",并评论道:「{quote_content}」" + base += f", adding the comment: \"{quote_content}\"" return base def _describe_follow(self) -> str: - """关注用户 - 包含被关注用户的名称""" + """Describe Follow.""" target_user_name = self.action_args.get("target_user_name", "") if target_user_name: - return f"关注了用户「{target_user_name}」" - return "关注了一个用户" + return f"followed user \"{target_user_name}\"" + return "followed a user" def _describe_create_comment(self) -> str: - """发表评论 - 包含评论内容和所评论的帖子信息""" + """Describe Create Comment.""" content = self.action_args.get("content", "") post_content = self.action_args.get("post_content", "") post_author = self.action_args.get("post_author_name", "") if content: if post_content and post_author: - return f"在{post_author}的帖子「{post_content}」下评论道:「{content}」" + return f"commented on {post_author}'s post \"{post_content}\": \"{content}\"" elif post_content: - return f"在帖子「{post_content}」下评论道:「{content}」" + return f"commented on a post \"{post_content}\": \"{content}\"" elif post_author: - return f"在{post_author}的帖子下评论道:「{content}」" - return f"评论道:「{content}」" - return "发表了评论" + return f"commented on a post from {post_author}: \"{content}\"" + return f"commented: \"{content}\"" + return "posted a comment" def _describe_like_comment(self) -> str: - """点赞评论 - 包含评论内容和作者信息""" + """Describe Like Comment.""" comment_content = self.action_args.get("comment_content", "") comment_author = self.action_args.get("comment_author_name", "") if comment_content and comment_author: - return f"点赞了{comment_author}的评论:「{comment_content}」" + return f"liked {comment_author}'s comment: \"{comment_content}\"" elif comment_content: - return f"点赞了一条评论:「{comment_content}」" + return f"liked a comment: \"{comment_content}\"" elif comment_author: - return f"点赞了{comment_author}的一条评论" - return "点赞了一条评论" + return f"liked a comment from {comment_author}" + return "liked a comment" def _describe_dislike_comment(self) -> str: - """踩评论 - 包含评论内容和作者信息""" + """Describe Dislike Comment.""" comment_content = self.action_args.get("comment_content", "") comment_author = self.action_args.get("comment_author_name", "") if comment_content and comment_author: - return f"踩了{comment_author}的评论:「{comment_content}」" + return f"downvoted {comment_author}'s comment: \"{comment_content}\"" elif comment_content: - return f"踩了一条评论:「{comment_content}」" + return f"downvoted a comment: \"{comment_content}\"" elif comment_author: - return f"踩了{comment_author}的一条评论" - return "踩了一条评论" + return f"downvoted a comment from {comment_author}" + return "downvoted a comment" def _describe_search(self) -> str: - """搜索帖子 - 包含搜索关键词""" + """Describe Search.""" query = self.action_args.get("query", "") or self.action_args.get("keyword", "") - return f"搜索了「{query}」" if query else "进行了搜索" + return f"searched for \"{query}\"" if query else "performed a search" def _describe_search_user(self) -> str: - """搜索用户 - 包含搜索关键词""" + """Describe Search User.""" query = self.action_args.get("query", "") or self.action_args.get("username", "") - return f"搜索了用户「{query}」" if query else "搜索了用户" + return f"searched for user \"{query}\"" if query else "searched for a user" def _describe_mute(self) -> str: - """屏蔽用户 - 包含被屏蔽用户的名称""" + """Describe Mute.""" target_user_name = self.action_args.get("target_user_name", "") if target_user_name: - return f"屏蔽了用户「{target_user_name}」" - return "屏蔽了一个用户" + return f"muted user \"{target_user_name}\"" + return "muted a user" def _describe_generic(self) -> str: - # 对于未知的动作类型,生成通用描述 - return f"执行了{self.action_type}操作" + + return f"performed the action {self.action_type}" class ZepGraphMemoryUpdater: - """ - Zep图谱记忆更新器 + """Zep Graph Memory Updater.""" - 监控模拟的actions日志文件,将新的agent活动实时更新到Zep图谱中。 - 按平台分组,每累积BATCH_SIZE条活动后批量发送到Zep。 - 所有有意义的行为都会被更新到Zep,action_args中会包含完整的上下文信息: - - 点赞/踩的帖子原文 - - 转发/引用的帖子原文 - - 关注/屏蔽的用户名 - - 点赞/踩的评论原文 - """ - - # 批量发送大小(每个平台累积多少条后发送) BATCH_SIZE = 5 - # 平台名称映射(用于控制台显示) + PLATFORM_DISPLAY_NAMES = { - 'twitter': '世界1', - 'reddit': '世界2', + 'twitter': 'Twitter', + 'reddit': 'Reddit', } - # 发送间隔(秒),避免请求过快 + SEND_INTERVAL = 0.5 - # 重试配置 + MAX_RETRIES = 3 - RETRY_DELAY = 2 # 秒 + RETRY_DELAY = 2 def __init__(self, graph_id: str, api_key: Optional[str] = None): - """ - 初始化更新器 - - Args: - graph_id: Zep图谱ID - api_key: Zep API Key(可选,默认从配置读取) - """ + """Initialize the instance.""" self.graph_id = graph_id self.api_key = api_key or Config.ZEP_API_KEY + self.provider = create_graph_provider() - if not self.api_key: - raise ValueError("ZEP_API_KEY未配置") - - self.client = Zep(api_key=self.api_key) - # 活动队列 self._activity_queue: Queue = Queue() - # 按平台分组的活动缓冲区(每个平台各自累积到BATCH_SIZE后批量发送) + self._platform_buffers: Dict[str, List[AgentActivity]] = { 'twitter': [], 'reddit': [], } self._buffer_lock = threading.Lock() - # 控制标志 + self._running = False self._worker_thread: Optional[threading.Thread] = None - # 统计 - self._total_activities = 0 # 实际添加到队列的活动数 - self._total_sent = 0 # 成功发送到Zep的批次数 - self._total_items_sent = 0 # 成功发送到Zep的活动条数 - self._failed_count = 0 # 发送失败的批次数 - self._skipped_count = 0 # 被过滤跳过的活动数(DO_NOTHING) - logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}") + self._total_activities = 0 + self._total_sent = 0 + self._total_items_sent = 0 + self._failed_count = 0 + self._skipped_count = 0 + + logger.info(f"ZepGraphMemoryUpdater initialized: graph_id={graph_id}, batch_size={self.BATCH_SIZE}") def _get_platform_display_name(self, platform: str) -> str: - """获取平台的显示名称""" + """Get platform display name.""" return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform) def start(self): - """启动后台工作线程""" + """Start the requested object.""" if self._running: return @@ -283,64 +253,41 @@ def start(self): name=f"ZepMemoryUpdater-{self.graph_id[:8]}" ) self._worker_thread.start() - logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}") + logger.info(f"ZepGraphMemoryUpdater started: graph_id={self.graph_id}") def stop(self): - """停止后台工作线程""" + """Stop the requested object.""" self._running = False - # 发送剩余的活动 + self._flush_remaining() if self._worker_thread and self._worker_thread.is_alive(): self._worker_thread.join(timeout=10) - logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, " - f"total_activities={self._total_activities}, " - f"batches_sent={self._total_sent}, " - f"items_sent={self._total_items_sent}, " - f"failed={self._failed_count}, " - f"skipped={self._skipped_count}") + logger.info( + f"ZepGraphMemoryUpdater stopped: graph_id={self.graph_id}, " + f"total_activities={self._total_activities}, " + f"batches_sent={self._total_sent}, " + f"items_sent={self._total_items_sent}, " + f"failed={self._failed_count}, " + f"skipped={self._skipped_count}" + ) def add_activity(self, activity: AgentActivity): - """ - 添加一个agent活动到队列 - - 所有有意义的行为都会被添加到队列,包括: - - CREATE_POST(发帖) - - CREATE_COMMENT(评论) - - QUOTE_POST(引用帖子) - - SEARCH_POSTS(搜索帖子) - - SEARCH_USER(搜索用户) - - LIKE_POST/DISLIKE_POST(点赞/踩帖子) - - REPOST(转发) - - FOLLOW(关注) - - MUTE(屏蔽) - - LIKE_COMMENT/DISLIKE_COMMENT(点赞/踩评论) - - action_args中会包含完整的上下文信息(如帖子原文、用户名等)。 - - Args: - activity: Agent活动记录 - """ - # 跳过DO_NOTHING类型的活动 + """Add activity.""" + if activity.action_type == "DO_NOTHING": self._skipped_count += 1 return self._activity_queue.put(activity) self._total_activities += 1 - logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}") + logger.debug(f"Queued activity for Zep: {activity.agent_name} - {activity.action_type}") def add_activity_from_dict(self, data: Dict[str, Any], platform: str): - """ - 从字典数据添加活动 - - Args: - data: 从actions.jsonl解析的字典数据 - platform: 平台名称 (twitter/reddit) - """ - # 跳过事件类型的条目 + """Add activity from dict.""" + if "event_type" in data: return @@ -357,78 +304,72 @@ def add_activity_from_dict(self, data: Dict[str, Any], platform: str): self.add_activity(activity) def _worker_loop(self): - """后台工作循环 - 按平台批量发送活动到Zep""" + """Worker Loop.""" while self._running or not self._activity_queue.empty(): try: - # 尝试从队列获取活动(超时1秒) + try: activity = self._activity_queue.get(timeout=1) - # 将活动添加到对应平台的缓冲区 + platform = activity.platform.lower() with self._buffer_lock: if platform not in self._platform_buffers: self._platform_buffers[platform] = [] self._platform_buffers[platform].append(activity) - # 检查该平台是否达到批量大小 + if len(self._platform_buffers[platform]) >= self.BATCH_SIZE: batch = self._platform_buffers[platform][:self.BATCH_SIZE] self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:] - # 释放锁后再发送 + self._send_batch_activities(batch, platform) - # 发送间隔,避免请求过快 + time.sleep(self.SEND_INTERVAL) except Empty: pass except Exception as e: - logger.error(f"工作循环异常: {e}") + logger.error(f"Worker loop error: {e}") time.sleep(1) def _send_batch_activities(self, activities: List[AgentActivity], platform: str): - """ - 批量发送活动到Zep图谱(合并为一条文本) - - Args: - activities: Agent活动列表 - platform: 平台名称 - """ + """Send batch activities.""" if not activities: return - # 将多条活动合并为一条文本,用换行分隔 + episode_texts = [activity.to_episode_text() for activity in activities] combined_text = "\n".join(episode_texts) - # 带重试的发送 + for attempt in range(self.MAX_RETRIES): try: - self.client.graph.add( + display_name = self._get_platform_display_name(platform) + self.provider.add_text( graph_id=self.graph_id, - type="text", - data=combined_text + data=combined_text, + source_description=f"MiroFish {display_name} activity", ) self._total_sent += 1 self._total_items_sent += len(activities) - display_name = self._get_platform_display_name(platform) - logger.info(f"成功批量发送 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}") - logger.debug(f"批量内容预览: {combined_text[:200]}...") + logger.info(f"Sent {len(activities)} {display_name} activities to graph {self.graph_id}") + logger.debug(f"Batch preview: {combined_text[:200]}...") return except Exception as e: if attempt < self.MAX_RETRIES - 1: - logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}") + logger.warning(f"Failed to send batch to Zep (attempt {attempt + 1}/{self.MAX_RETRIES}): {e}") time.sleep(self.RETRY_DELAY * (attempt + 1)) else: - logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}") + logger.error(f"Failed to send batch to Zep after {self.MAX_RETRIES} retries: {e}") self._failed_count += 1 def _flush_remaining(self): - """发送队列和缓冲区中剩余的活动""" - # 首先处理队列中剩余的活动,添加到缓冲区 + """Flush Remaining.""" + while not self._activity_queue.empty(): try: activity = self._activity_queue.get_nowait() @@ -440,60 +381,47 @@ def _flush_remaining(self): except Empty: break - # 然后发送各平台缓冲区中剩余的活动(即使不足BATCH_SIZE条) + with self._buffer_lock: for platform, buffer in self._platform_buffers.items(): if buffer: display_name = self._get_platform_display_name(platform) - logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动") + logger.info(f"Sending remaining {len(buffer)} activities for {display_name}") self._send_batch_activities(buffer, platform) - # 清空所有缓冲区 + for platform in self._platform_buffers: self._platform_buffers[platform] = [] def get_stats(self) -> Dict[str, Any]: - """获取统计信息""" + """Get stats.""" with self._buffer_lock: buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()} return { "graph_id": self.graph_id, "batch_size": self.BATCH_SIZE, - "total_activities": self._total_activities, # 添加到队列的活动总数 - "batches_sent": self._total_sent, # 成功发送的批次数 - "items_sent": self._total_items_sent, # 成功发送的活动条数 - "failed_count": self._failed_count, # 发送失败的批次数 - "skipped_count": self._skipped_count, # 被过滤跳过的活动数(DO_NOTHING) + "total_activities": self._total_activities, + "batches_sent": self._total_sent, + "items_sent": self._total_items_sent, + "failed_count": self._failed_count, + "skipped_count": self._skipped_count, "queue_size": self._activity_queue.qsize(), - "buffer_sizes": buffer_sizes, # 各平台缓冲区大小 + "buffer_sizes": buffer_sizes, "running": self._running, } class ZepGraphMemoryManager: - """ - 管理多个模拟的Zep图谱记忆更新器 - - 每个模拟可以有自己的更新器实例 - """ + """Zep Graph Memory Manager.""" _updaters: Dict[str, ZepGraphMemoryUpdater] = {} _lock = threading.Lock() @classmethod def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater: - """ - 为模拟创建图谱记忆更新器 - - Args: - simulation_id: 模拟ID - graph_id: Zep图谱ID - - Returns: - ZepGraphMemoryUpdater实例 - """ + """Create updater.""" with cls._lock: - # 如果已存在,先停止旧的 + if simulation_id in cls._updaters: cls._updaters[simulation_id].stop() @@ -501,30 +429,30 @@ def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpda updater.start() cls._updaters[simulation_id] = updater - logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}") + logger.info(f"Created graph memory updater: simulation_id={simulation_id}, graph_id={graph_id}") return updater @classmethod def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]: - """获取模拟的更新器""" + """Get updater.""" return cls._updaters.get(simulation_id) @classmethod def stop_updater(cls, simulation_id: str): - """停止并移除模拟的更新器""" + """Stop updater.""" with cls._lock: if simulation_id in cls._updaters: cls._updaters[simulation_id].stop() del cls._updaters[simulation_id] - logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}") + logger.info(f"Stopped graph memory updater: simulation_id={simulation_id}") + - # 防止 stop_all 重复调用的标志 _stop_all_done = False @classmethod def stop_all(cls): - """停止所有更新器""" - # 防止重复调用 + """Stop all.""" + if cls._stop_all_done: return cls._stop_all_done = True @@ -535,13 +463,13 @@ def stop_all(cls): try: updater.stop() except Exception as e: - logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}") + logger.error(f"Failed to stop updater: simulation_id={simulation_id}, error={e}") cls._updaters.clear() - logger.info("已停止所有图谱记忆更新器") + logger.info("Stopped all graph memory updaters") @classmethod def get_all_stats(cls) -> Dict[str, Dict[str, Any]]: - """获取所有更新器的统计信息""" + """Get all stats.""" return { sim_id: updater.get_stats() for sim_id, updater in cls._updaters.items() diff --git a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 384cf540f..924653ffc 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -1,31 +1,21 @@ -""" -Zep检索工具服务 -封装图谱搜索、节点读取、边查询等工具,供Report Agent使用 - -核心检索工具(优化后): -1. InsightForge(深度洞察检索)- 最强大的混合检索,自动生成子问题并多维度检索 -2. PanoramaSearch(广度搜索)- 获取全貌,包括过期内容 -3. QuickSearch(简单搜索)- 快速检索 -""" +"""Zep retrieval tools service.""" import time import json from typing import Dict, Any, List, Optional from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config +from .graph_provider import create_graph_provider from ..utils.logger import get_logger from ..utils.llm_client import LLMClient -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges logger = get_logger('mirofish.zep_tools') @dataclass class SearchResult: - """搜索结果""" + """Search Result.""" facts: List[str] edges: List[Dict[str, Any]] nodes: List[Dict[str, Any]] @@ -42,11 +32,11 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为文本格式,供LLM理解""" - text_parts = [f"搜索查询: {self.query}", f"找到 {self.total_count} 条相关信息"] + """Convert the object to text.""" + text_parts = [f"Search Query: {self.query}", f"Found {self.total_count} relevant items"] if self.facts: - text_parts.append("\n### 相关事实:") + text_parts.append("\n### Related Facts:") for i, fact in enumerate(self.facts, 1): text_parts.append(f"{i}. {fact}") @@ -55,7 +45,7 @@ def to_text(self) -> str: @dataclass class NodeInfo: - """节点信息""" + """Node Info.""" uuid: str name: str labels: List[str] @@ -72,14 +62,14 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为文本格式""" - entity_type = next((l for l in self.labels if l not in ["Entity", "Node"]), "未知类型") - return f"实体: {self.name} (类型: {entity_type})\n摘要: {self.summary}" + """Convert the object to text.""" + entity_type = next((l for l in self.labels if l not in ["Entity", "Node"]), "Unknown Type") + return f"Entity: {self.name} (Type: {entity_type})\nSummary: {self.summary}" @dataclass class EdgeInfo: - """边信息""" + """Edge Info.""" uuid: str name: str fact: str @@ -87,7 +77,7 @@ class EdgeInfo: target_node_uuid: str source_node_name: Optional[str] = None target_node_name: Optional[str] = None - # 时间信息 + created_at: Optional[str] = None valid_at: Optional[str] = None invalid_at: Optional[str] = None @@ -109,47 +99,44 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self, include_temporal: bool = False) -> str: - """转换为文本格式""" + """Convert the object to text.""" source = self.source_node_name or self.source_node_uuid[:8] target = self.target_node_name or self.target_node_uuid[:8] - base_text = f"关系: {source} --[{self.name}]--> {target}\n事实: {self.fact}" + base_text = f"Relation: {source} --[{self.name}]--> {target}\nFact: {self.fact}" if include_temporal: - valid_at = self.valid_at or "未知" - invalid_at = self.invalid_at or "至今" - base_text += f"\n时效: {valid_at} - {invalid_at}" + valid_at = self.valid_at or "Unknown" + invalid_at = self.invalid_at or "Present" + base_text += f"\nValidity: {valid_at} - {invalid_at}" if self.expired_at: - base_text += f" (已过期: {self.expired_at})" + base_text += f" (Expired: {self.expired_at})" return base_text @property def is_expired(self) -> bool: - """是否已过期""" + """Return whether expired.""" return self.expired_at is not None @property def is_invalid(self) -> bool: - """是否已失效""" + """Return whether invalid.""" return self.invalid_at is not None @dataclass class InsightForgeResult: - """ - 深度洞察检索结果 (InsightForge) - 包含多个子问题的检索结果,以及综合分析 - """ + """Insight Forge Result.""" query: str simulation_requirement: str sub_queries: List[str] - # 各维度检索结果 - semantic_facts: List[str] = field(default_factory=list) # 语义搜索结果 - entity_insights: List[Dict[str, Any]] = field(default_factory=list) # 实体洞察 - relationship_chains: List[str] = field(default_factory=list) # 关系链 - # 统计信息 + semantic_facts: List[str] = field(default_factory=list) + entity_insights: List[Dict[str, Any]] = field(default_factory=list) + relationship_chains: List[str] = field(default_factory=list) + + total_facts: int = 0 total_entities: int = 0 total_relationships: int = 0 @@ -168,42 +155,42 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为详细的文本格式,供LLM理解""" + """Convert the object to text.""" text_parts = [ - f"## 未来预测深度分析", - f"分析问题: {self.query}", - f"预测场景: {self.simulation_requirement}", - f"\n### 预测数据统计", - f"- 相关预测事实: {self.total_facts}条", - f"- 涉及实体: {self.total_entities}个", - f"- 关系链: {self.total_relationships}条" + "## Deep Forecast Analysis", + f"Analysis Question: {self.query}", + f"Prediction Scenario: {self.simulation_requirement}", + "\n### Forecast Statistics", + f"- Relevant Forecast Facts: {self.total_facts}", + f"- Entities Involved: {self.total_entities}", + f"- Relationship Chains: {self.total_relationships}" ] - # 子问题 + if self.sub_queries: - text_parts.append(f"\n### 分析的子问题") + text_parts.append("\n### Analysis Sub-questions") for i, sq in enumerate(self.sub_queries, 1): text_parts.append(f"{i}. {sq}") - # 语义搜索结果 + if self.semantic_facts: - text_parts.append(f"\n### 【关键事实】(请在报告中引用这些原文)") + text_parts.append("\n### Key Facts") for i, fact in enumerate(self.semantic_facts, 1): text_parts.append(f"{i}. \"{fact}\"") - # 实体洞察 + if self.entity_insights: - text_parts.append(f"\n### 【核心实体】") + text_parts.append("\n### Core Entities") for entity in self.entity_insights: - text_parts.append(f"- **{entity.get('name', '未知')}** ({entity.get('type', '实体')})") + text_parts.append(f"- **{entity.get('name', 'Unknown')}** ({entity.get('type', 'Entity')})") if entity.get('summary'): - text_parts.append(f" 摘要: \"{entity.get('summary')}\"") + text_parts.append(f" Summary: \"{entity.get('summary')}\"") if entity.get('related_facts'): - text_parts.append(f" 相关事实: {len(entity.get('related_facts', []))}条") + text_parts.append(f" Related Facts: {len(entity.get('related_facts', []))}") + - # 关系链 if self.relationship_chains: - text_parts.append(f"\n### 【关系链】") + text_parts.append("\n### Relationship Chains") for chain in self.relationship_chains: text_parts.append(f"- {chain}") @@ -212,22 +199,19 @@ def to_text(self) -> str: @dataclass class PanoramaResult: - """ - 广度搜索结果 (Panorama) - 包含所有相关信息,包括过期内容 - """ + """Panorama Result.""" query: str - # 全部节点 + all_nodes: List[NodeInfo] = field(default_factory=list) - # 全部边(包括过期的) + all_edges: List[EdgeInfo] = field(default_factory=list) - # 当前有效的事实 + active_facts: List[str] = field(default_factory=list) - # 已过期/失效的事实(历史记录) + historical_facts: List[str] = field(default_factory=list) - # 统计 + total_nodes: int = 0 total_edges: int = 0 active_count: int = 0 @@ -247,34 +231,34 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为文本格式(完整版本,不截断)""" + """Convert the object to text.""" text_parts = [ - f"## 广度搜索结果(未来全景视图)", - f"查询: {self.query}", - f"\n### 统计信息", - f"- 总节点数: {self.total_nodes}", - f"- 总边数: {self.total_edges}", - f"- 当前有效事实: {self.active_count}条", - f"- 历史/过期事实: {self.historical_count}条" + "## Panorama Search Results", + f"Query: {self.query}", + "\n### Statistics", + f"- Total Nodes: {self.total_nodes}", + f"- Total Edges: {self.total_edges}", + f"- Active Facts: {self.active_count}", + f"- Historical / Expired Facts: {self.historical_count}" ] - # 当前有效的事实(完整输出,不截断) + if self.active_facts: - text_parts.append(f"\n### 【当前有效事实】(模拟结果原文)") + text_parts.append("\n### Active Facts") for i, fact in enumerate(self.active_facts, 1): text_parts.append(f"{i}. \"{fact}\"") - # 历史/过期事实(完整输出,不截断) + if self.historical_facts: - text_parts.append(f"\n### 【历史/过期事实】(演变过程记录)") + text_parts.append("\n### Historical / Expired Facts") for i, fact in enumerate(self.historical_facts, 1): text_parts.append(f"{i}. \"{fact}\"") - # 关键实体(完整输出,不截断) + if self.all_nodes: - text_parts.append(f"\n### 【涉及实体】") + text_parts.append("\n### Entities Involved") for node in self.all_nodes: - entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体") + entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "Entity") text_parts.append(f"- **{node.name}** ({entity_type})") return "\n".join(text_parts) @@ -282,13 +266,13 @@ def to_text(self) -> str: @dataclass class AgentInterview: - """单个Agent的采访结果""" + """Agent Interview.""" agent_name: str - agent_role: str # 角色类型(如:学生、教师、媒体等) - agent_bio: str # 简介 - question: str # 采访问题 - response: str # 采访回答 - key_quotes: List[str] = field(default_factory=list) # 关键引言 + agent_role: str + agent_bio: str + question: str + response: str + key_quotes: List[str] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { @@ -302,29 +286,29 @@ def to_dict(self) -> Dict[str, Any]: def to_text(self) -> str: text = f"**{self.agent_name}** ({self.agent_role})\n" - # 显示完整的agent_bio,不截断 - text += f"_简介: {self.agent_bio}_\n\n" + + text += f"_Bio: {self.agent_bio}_\n\n" text += f"**Q:** {self.question}\n\n" text += f"**A:** {self.response}\n" if self.key_quotes: - text += "\n**关键引言:**\n" + text += "\n**Key Quotes:**\n" for quote in self.key_quotes: - # 清理各种引号 + clean_quote = quote.replace('\u201c', '').replace('\u201d', '').replace('"', '') clean_quote = clean_quote.replace('\u300c', '').replace('\u300d', '') clean_quote = clean_quote.strip() - # 去掉开头的标点 - while clean_quote and clean_quote[0] in ',,;;::、。!?\n\r\t ': + + while clean_quote and clean_quote[0] in '\uff0c,\uff1b;: \u3001\u3002\uff01\uff1f\n\r\t ': clean_quote = clean_quote[1:] - # 过滤包含问题编号的垃圾内容(问题1-9) + skip = False for d in '123456789': - if f'\u95ee\u9898{d}' in clean_quote: + if f'\u95ee\u9898{d}' in clean_quote or f'Question {d}' in clean_quote: skip = True break if skip: continue - # 截断过长内容(按句号截断,而非硬截断) + if len(clean_quote) > 150: dot_pos = clean_quote.find('\u3002', 80) if dot_pos > 0: @@ -338,24 +322,21 @@ def to_text(self) -> str: @dataclass class InterviewResult: - """ - 采访结果 (Interview) - 包含多个模拟Agent的采访回答 - """ - interview_topic: str # 采访主题 - interview_questions: List[str] # 采访问题列表 + """Interview Result.""" + interview_topic: str + interview_questions: List[str] + - # 采访选择的Agent selected_agents: List[Dict[str, Any]] = field(default_factory=list) - # 各Agent的采访回答 + interviews: List[AgentInterview] = field(default_factory=list) - # 选择Agent的理由 + selection_reasoning: str = "" - # 整合后的采访摘要 + summary: str = "" - # 统计 + total_agents: int = 0 interviewed_count: int = 0 @@ -372,74 +353,54 @@ def to_dict(self) -> Dict[str, Any]: } def to_text(self) -> str: - """转换为详细的文本格式,供LLM理解和报告引用""" + """Convert the object to text.""" text_parts = [ - "## 深度采访报告", - f"**采访主题:** {self.interview_topic}", - f"**采访人数:** {self.interviewed_count} / {self.total_agents} 位模拟Agent", - "\n### 采访对象选择理由", - self.selection_reasoning or "(自动选择)", + "## In-Depth Interview Report", + f"**Interview Topic:** {self.interview_topic}", + f"**Interview Count:** {self.interviewed_count} / {self.total_agents} simulated agents", + "\n### Why These Agents Were Selected", + self.selection_reasoning or "(Selected automatically)", "\n---", - "\n### 采访实录", + "\n### Interview Transcript", ] if self.interviews: for i, interview in enumerate(self.interviews, 1): - text_parts.append(f"\n#### 采访 #{i}: {interview.agent_name}") + text_parts.append(f"\n#### Interview #{i}: {interview.agent_name}") text_parts.append(interview.to_text()) text_parts.append("\n---") else: - text_parts.append("(无采访记录)\n\n---") + text_parts.append("(No interview records)\n\n---") - text_parts.append("\n### 采访摘要与核心观点") - text_parts.append(self.summary or "(无摘要)") + text_parts.append("\n### Interview Summary & Key Takeaways") + text_parts.append(self.summary or "(No summary)") return "\n".join(text_parts) class ZepToolsService: - """ - Zep检索工具服务 + """Zep Tools Service.""" - 【核心检索工具 - 优化后】 - 1. insight_forge - 深度洞察检索(最强大,自动生成子问题,多维度检索) - 2. panorama_search - 广度搜索(获取全貌,包括过期内容) - 3. quick_search - 简单搜索(快速检索) - 4. interview_agents - 深度采访(采访模拟Agent,获取多视角观点) - 【基础工具】 - - search_graph - 图谱语义搜索 - - get_all_nodes - 获取图谱所有节点 - - get_all_edges - 获取图谱所有边(含时间信息) - - get_node_detail - 获取节点详细信息 - - get_node_edges - 获取节点相关的边 - - get_entities_by_type - 按类型获取实体 - - get_entity_summary - 获取实体的关系摘要 - """ - - # 重试配置 MAX_RETRIES = 3 RETRY_DELAY = 2.0 def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None): self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") + self.provider = create_graph_provider() - self.client = Zep(api_key=self.api_key) - # LLM客户端用于InsightForge生成子问题 self._llm_client = llm_client - logger.info("ZepToolsService 初始化完成") + logger.info("ZepToolsService initialized") @property def llm(self) -> LLMClient: - """延迟初始化LLM客户端""" + """LLM.""" if self._llm_client is None: self._llm_client = LLMClient() return self._llm_client def _call_with_retry(self, func, operation_name: str, max_retries: int = None): - """带重试机制的API调用""" + """Call with retry.""" max_retries = max_retries or self.MAX_RETRIES last_exception = None delay = self.RETRY_DELAY @@ -451,13 +412,13 @@ def _call_with_retry(self, func, operation_name: str, max_retries: int = None): last_exception = e if attempt < max_retries - 1: logger.warning( - f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, " - f"{delay:.1f}秒后重试..." + f"Zep {operation_name} attempt {attempt + 1} failed: {str(e)[:100]}. " + f"Retrying in {delay:.1f}s..." ) time.sleep(delay) delay *= 2 else: - logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}") + logger.error(f"Zep {operation_name} still failed after {max_retries} attempts: {str(e)}") raise last_exception @@ -468,67 +429,43 @@ def search_graph( limit: int = 10, scope: str = "edges" ) -> SearchResult: - """ - 图谱语义搜索 - - 使用混合搜索(语义+BM25)在图谱中搜索相关信息。 - 如果Zep Cloud的search API不可用,则降级为本地关键词匹配。 + """Search graph.""" + logger.info(f"Graph search: graph_id={graph_id}, query={query[:50]}...") - Args: - graph_id: 图谱ID (Standalone Graph) - query: 搜索查询 - limit: 返回结果数量 - scope: 搜索范围,"edges" 或 "nodes" - - Returns: - SearchResult: 搜索结果 - """ - logger.info(f"图谱搜索: graph_id={graph_id}, query={query[:50]}...") - - # 尝试使用Zep Cloud Search API try: search_results = self._call_with_retry( - func=lambda: self.client.graph.search( + func=lambda: self.provider.search( graph_id=graph_id, query=query, limit=limit, scope=scope, - reranker="cross_encoder" + reranker="cross_encoder", ), - operation_name=f"图谱搜索(graph={graph_id})" + operation_name=f"graph search (graph={graph_id})" ) - facts = [] - edges = [] - nodes = [] - - # 解析边搜索结果 - if hasattr(search_results, 'edges') and search_results.edges: - for edge in search_results.edges: - if hasattr(edge, 'fact') and edge.fact: - facts.append(edge.fact) - edges.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": getattr(edge, 'name', ''), - "fact": getattr(edge, 'fact', ''), - "source_node_uuid": getattr(edge, 'source_node_uuid', ''), - "target_node_uuid": getattr(edge, 'target_node_uuid', ''), - }) - - # 解析节点搜索结果 - if hasattr(search_results, 'nodes') and search_results.nodes: - for node in search_results.nodes: - nodes.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": getattr(node, 'name', ''), - "labels": getattr(node, 'labels', []), - "summary": getattr(node, 'summary', ''), - }) - # 节点摘要也算作事实 - if hasattr(node, 'summary') and node.summary: - facts.append(f"[{node.name}]: {node.summary}") + facts = list(search_results.facts) + edges = [ + { + "uuid": edge.uuid, + "name": edge.name, + "fact": edge.fact, + "source_node_uuid": edge.source_node_uuid, + "target_node_uuid": edge.target_node_uuid, + } + for edge in search_results.edges + ] + nodes = [ + { + "uuid": node.uuid, + "name": node.name, + "labels": node.labels, + "summary": node.summary, + } + for node in search_results.nodes + ] - logger.info(f"搜索完成: 找到 {len(facts)} 条相关事实") + logger.info(f"Search completed: found {len(facts)} related facts") return SearchResult( facts=facts, @@ -539,8 +476,8 @@ def search_graph( ) except Exception as e: - logger.warning(f"Zep Search API失败,降级为本地搜索: {str(e)}") - # 降级:使用本地关键词匹配搜索 + logger.warning(f"Zep Search API failed, falling back to local search: {str(e)}") + return self._local_search(graph_id, query, limit, scope) def _local_search( @@ -550,39 +487,26 @@ def _local_search( limit: int = 10, scope: str = "edges" ) -> SearchResult: - """ - 本地关键词匹配搜索(作为Zep Search API的降级方案) - - 获取所有边/节点,然后在本地进行关键词匹配 - - Args: - graph_id: 图谱ID - query: 搜索查询 - limit: 返回结果数量 - scope: 搜索范围 - - Returns: - SearchResult: 搜索结果 - """ - logger.info(f"使用本地搜索: query={query[:30]}...") + """Local Search.""" + logger.info(f"Using local search: query={query[:30]}...") facts = [] edges_result = [] nodes_result = [] - # 提取查询关键词(简单分词) + query_lower = query.lower() keywords = [w.strip() for w in query_lower.replace(',', ' ').replace(',', ' ').split() if len(w.strip()) > 1] def match_score(text: str) -> int: - """计算文本与查询的匹配分数""" + """Compute the text match score.""" if not text: return 0 text_lower = text.lower() - # 完全匹配查询 + if query_lower in text_lower: return 100 - # 关键词匹配 + score = 0 for keyword in keywords: if keyword in text_lower: @@ -591,7 +515,7 @@ def match_score(text: str) -> int: try: if scope in ["edges", "both"]: - # 获取所有边并匹配 + all_edges = self.get_all_edges(graph_id) scored_edges = [] for edge in all_edges: @@ -599,7 +523,7 @@ def match_score(text: str) -> int: if score > 0: scored_edges.append((score, edge)) - # 按分数排序 + scored_edges.sort(key=lambda x: x[0], reverse=True) for score, edge in scored_edges[:limit]: @@ -614,7 +538,7 @@ def match_score(text: str) -> int: }) if scope in ["nodes", "both"]: - # 获取所有节点并匹配 + all_nodes = self.get_all_nodes(graph_id) scored_nodes = [] for node in all_nodes: @@ -634,10 +558,10 @@ def match_score(text: str) -> int: if node.summary: facts.append(f"[{node.name}]: {node.summary}") - logger.info(f"本地搜索完成: 找到 {len(facts)} 条相关事实") + logger.info(f"Local search completed: found {len(facts)} related facts") except Exception as e: - logger.error(f"本地搜索失败: {str(e)}") + logger.error(f"Local search failed: {str(e)}") return SearchResult( facts=facts, @@ -648,60 +572,41 @@ def match_score(text: str) -> int: ) def get_all_nodes(self, graph_id: str) -> List[NodeInfo]: - """ - 获取图谱的所有节点(分页获取) + """Get all nodes.""" + logger.info(f"Loading all nodes for graph {graph_id}...") - Args: - graph_id: 图谱ID - - Returns: - 节点列表 - """ - logger.info(f"获取图谱 {graph_id} 的所有节点...") - - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.provider.get_all_nodes(graph_id) result = [] for node in nodes: - node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or "" result.append(NodeInfo( - uuid=str(node_uuid) if node_uuid else "", + uuid=node.uuid, name=node.name or "", labels=node.labels or [], summary=node.summary or "", attributes=node.attributes or {} )) - logger.info(f"获取到 {len(result)} 个节点") + logger.info(f"Loaded {len(result)} nodes") return result def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]: - """ - 获取图谱的所有边(分页获取,包含时间信息) - - Args: - graph_id: 图谱ID - include_temporal: 是否包含时间信息(默认True) - - Returns: - 边列表(包含created_at, valid_at, invalid_at, expired_at) - """ - logger.info(f"获取图谱 {graph_id} 的所有边...") + """Get all edges.""" + logger.info(f"Loading all edges for graph {graph_id}...") - edges = fetch_all_edges(self.client, graph_id) + edges = self.provider.get_all_edges(graph_id) result = [] for edge in edges: - edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or "" edge_info = EdgeInfo( - uuid=str(edge_uuid) if edge_uuid else "", + uuid=edge.uuid, name=edge.name or "", fact=edge.fact or "", source_node_uuid=edge.source_node_uuid or "", target_node_uuid=edge.target_node_uuid or "" ) - # 添加时间信息 + if include_temporal: edge_info.created_at = getattr(edge, 'created_at', None) edge_info.valid_at = getattr(edge, 'valid_at', None) @@ -710,25 +615,17 @@ def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[Ed result.append(edge_info) - logger.info(f"获取到 {len(result)} 条边") + logger.info(f"Loaded {len(result)} edges") return result - def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: - """ - 获取单个节点的详细信息 - - Args: - node_uuid: 节点UUID - - Returns: - 节点信息或None - """ - logger.info(f"获取节点详情: {node_uuid[:8]}...") + def get_node_detail(self, graph_id: str, node_uuid: str) -> Optional[NodeInfo]: + """Get node detail.""" + logger.info(f"Loading node details: {node_uuid[:8]}...") try: node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=node_uuid), - operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)" + func=lambda: self.provider.get_node(graph_id, node_uuid), + operation_name=f"get node details (uuid={node_uuid[:8]}...)" ) if not node: @@ -742,39 +639,35 @@ def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: attributes=node.attributes or {} ) except Exception as e: - logger.error(f"获取节点详情失败: {str(e)}") + logger.error(f"Failed to load node details: {str(e)}") return None def get_node_edges(self, graph_id: str, node_uuid: str) -> List[EdgeInfo]: - """ - 获取节点相关的所有边 - - 通过获取图谱所有边,然后过滤出与指定节点相关的边 - - Args: - graph_id: 图谱ID - node_uuid: 节点UUID - - Returns: - 边列表 - """ - logger.info(f"获取节点 {node_uuid[:8]}... 的相关边") + """Get node edges.""" + logger.info(f"Loading edges for node {node_uuid[:8]}...") try: - # 获取图谱所有边,然后过滤 - all_edges = self.get_all_edges(graph_id) - + provider_edges = self.provider.get_node_edges(graph_id, node_uuid) result = [] - for edge in all_edges: - # 检查边是否与指定节点相关(作为源或目标) - if edge.source_node_uuid == node_uuid or edge.target_node_uuid == node_uuid: - result.append(edge) + for edge in provider_edges: + edge_info = EdgeInfo( + uuid=edge.uuid, + name=edge.name or "", + fact=edge.fact or "", + source_node_uuid=edge.source_node_uuid or "", + target_node_uuid=edge.target_node_uuid or "", + created_at=edge.created_at, + valid_at=edge.valid_at, + invalid_at=edge.invalid_at, + expired_at=edge.expired_at, + ) + result.append(edge_info) - logger.info(f"找到 {len(result)} 条与节点相关的边") + logger.info(f"Found {len(result)} edges related to the node") return result except Exception as e: - logger.warning(f"获取节点边失败: {str(e)}") + logger.warning(f"Failed to load node edges: {str(e)}") return [] def get_entities_by_type( @@ -782,27 +675,18 @@ def get_entities_by_type( graph_id: str, entity_type: str ) -> List[NodeInfo]: - """ - 按类型获取实体 - - Args: - graph_id: 图谱ID - entity_type: 实体类型(如 Student, PublicFigure 等) - - Returns: - 符合类型的实体列表 - """ - logger.info(f"获取类型为 {entity_type} 的实体...") + """Get entities by type.""" + logger.info(f"Loading entities of type {entity_type}...") all_nodes = self.get_all_nodes(graph_id) filtered = [] for node in all_nodes: - # 检查labels是否包含指定类型 + if entity_type in node.labels: filtered.append(node) - logger.info(f"找到 {len(filtered)} 个 {entity_type} 类型的实体") + logger.info(f"Found {len(filtered)} entities of type {entity_type}") return filtered def get_entity_summary( @@ -810,28 +694,17 @@ def get_entity_summary( graph_id: str, entity_name: str ) -> Dict[str, Any]: - """ - 获取指定实体的关系摘要 + """Get entity summary.""" + logger.info(f"Loading relationship summary for entity {entity_name}...") - 搜索与该实体相关的所有信息,并生成摘要 - - Args: - graph_id: 图谱ID - entity_name: 实体名称 - - Returns: - 实体摘要信息 - """ - logger.info(f"获取实体 {entity_name} 的关系摘要...") - # 先搜索该实体相关的信息 search_result = self.search_graph( graph_id=graph_id, query=entity_name, limit=20 ) - # 尝试在所有节点中找到该实体 + all_nodes = self.get_all_nodes(graph_id) entity_node = None for node in all_nodes: @@ -841,7 +714,7 @@ def get_entity_summary( related_edges = [] if entity_node: - # 传入graph_id参数 + related_edges = self.get_node_edges(graph_id, entity_node.uuid) return { @@ -853,28 +726,20 @@ def get_entity_summary( } def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]: - """ - 获取图谱的统计信息 - - Args: - graph_id: 图谱ID - - Returns: - 统计信息 - """ - logger.info(f"获取图谱 {graph_id} 的统计信息...") + """Get graph statistics.""" + logger.info(f"Loading graph statistics for {graph_id}...") nodes = self.get_all_nodes(graph_id) edges = self.get_all_edges(graph_id) - # 统计实体类型分布 + entity_types = {} for node in nodes: for label in node.labels: if label not in ["Entity", "Node"]: entity_types[label] = entity_types.get(label, 0) + 1 - # 统计关系类型分布 + relation_types = {} for edge in edges: relation_types[edge.name] = relation_types.get(edge.name, 0) + 1 @@ -893,35 +758,23 @@ def get_simulation_context( simulation_requirement: str, limit: int = 30 ) -> Dict[str, Any]: - """ - 获取模拟相关的上下文信息 + """Get simulation context.""" + logger.info(f"Loading simulation context: {simulation_requirement[:50]}...") - 综合搜索与模拟需求相关的所有信息 - Args: - graph_id: 图谱ID - simulation_requirement: 模拟需求描述 - limit: 每类信息的数量限制 - - Returns: - 模拟上下文信息 - """ - logger.info(f"获取模拟上下文: {simulation_requirement[:50]}...") - - # 搜索与模拟需求相关的信息 search_result = self.search_graph( graph_id=graph_id, query=simulation_requirement, limit=limit ) - # 获取图谱统计 + stats = self.get_graph_statistics(graph_id) - # 获取所有实体节点 + all_nodes = self.get_all_nodes(graph_id) - # 筛选有实际类型的实体(非纯Entity节点) + entities = [] for node in all_nodes: custom_labels = [l for l in node.labels if l not in ["Entity", "Node"]] @@ -936,11 +789,11 @@ def get_simulation_context( "simulation_requirement": simulation_requirement, "related_facts": search_result.facts, "graph_statistics": stats, - "entities": entities[:limit], # 限制数量 + "entities": entities[:limit], "total_entities": len(entities) } - # ========== 核心检索工具(优化后) ========== + def insight_forge( self, @@ -950,27 +803,8 @@ def insight_forge( report_context: str = "", max_sub_queries: int = 5 ) -> InsightForgeResult: - """ - 【InsightForge - 深度洞察检索】 - - 最强大的混合检索函数,自动分解问题并多维度检索: - 1. 使用LLM将问题分解为多个子问题 - 2. 对每个子问题进行语义搜索 - 3. 提取相关实体并获取其详细信息 - 4. 追踪关系链 - 5. 整合所有结果,生成深度洞察 - - Args: - graph_id: 图谱ID - query: 用户问题 - simulation_requirement: 模拟需求描述 - report_context: 报告上下文(可选,用于更精准的子问题生成) - max_sub_queries: 最大子问题数量 - - Returns: - InsightForgeResult: 深度洞察检索结果 - """ - logger.info(f"InsightForge 深度洞察检索: {query[:50]}...") + """Insight Forge.""" + logger.info(f"InsightForge retrieval: {query[:50]}...") result = InsightForgeResult( query=query, @@ -978,7 +812,7 @@ def insight_forge( sub_queries=[] ) - # Step 1: 使用LLM生成子问题 + sub_queries = self._generate_sub_queries( query=query, simulation_requirement=simulation_requirement, @@ -986,9 +820,9 @@ def insight_forge( max_queries=max_sub_queries ) result.sub_queries = sub_queries - logger.info(f"生成 {len(sub_queries)} 个子问题") + logger.info(f"Generated {len(sub_queries)} sub-queries") + - # Step 2: 对每个子问题进行语义搜索 all_facts = [] all_edges = [] seen_facts = set() @@ -1008,7 +842,7 @@ def insight_forge( all_edges.extend(search_result.edges) - # 对原始问题也进行搜索 + main_search = self.search_graph( graph_id=graph_id, query=query, @@ -1023,7 +857,7 @@ def insight_forge( result.semantic_facts = all_facts result.total_facts = len(all_facts) - # Step 3: 从边中提取相关实体UUID,只获取这些实体的信息(不获取全部节点) + entity_uuids = set() for edge_data in all_edges: if isinstance(edge_data, dict): @@ -1034,21 +868,21 @@ def insight_forge( if target_uuid: entity_uuids.add(target_uuid) - # 获取所有相关实体的详情(不限制数量,完整输出) + entity_insights = [] - node_map = {} # 用于后续关系链构建 + node_map = {} - for uuid in list(entity_uuids): # 处理所有实体,不截断 + for uuid in list(entity_uuids): if not uuid: continue try: - # 单独获取每个相关节点的信息 - node = self.get_node_detail(uuid) + + node = self.get_node_detail(graph_id, uuid) if node: node_map[uuid] = node - entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体") + entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "Entity") + - # 获取该实体相关的所有事实(不截断) related_facts = [ f for f in all_facts if node.name.lower() in f.lower() @@ -1059,18 +893,18 @@ def insight_forge( "name": node.name, "type": entity_type, "summary": node.summary, - "related_facts": related_facts # 完整输出,不截断 + "related_facts": related_facts }) except Exception as e: - logger.debug(f"获取节点 {uuid} 失败: {e}") + logger.debug(f"Failed to load node {uuid}: {e}") continue result.entity_insights = entity_insights result.total_entities = len(entity_insights) - # Step 4: 构建所有关系链(不限制数量) + relationship_chains = [] - for edge_data in all_edges: # 处理所有边,不截断 + for edge_data in all_edges: if isinstance(edge_data, dict): source_uuid = edge_data.get('source_node_uuid', '') target_uuid = edge_data.get('target_node_uuid', '') @@ -1086,7 +920,10 @@ def insight_forge( result.relationship_chains = relationship_chains result.total_relationships = len(relationship_chains) - logger.info(f"InsightForge完成: {result.total_facts}条事实, {result.total_entities}个实体, {result.total_relationships}条关系") + logger.info( + f"InsightForge completed: {result.total_facts} facts, " + f"{result.total_entities} entities, {result.total_relationships} relationships" + ) return result def _generate_sub_queries( @@ -1096,28 +933,24 @@ def _generate_sub_queries( report_context: str = "", max_queries: int = 5 ) -> List[str]: - """ - 使用LLM生成子问题 - - 将复杂问题分解为多个可以独立检索的子问题 - """ - system_prompt = """你是一个专业的问题分析专家。你的任务是将一个复杂问题分解为多个可以在模拟世界中独立观察的子问题。 + """Generate sub queries.""" + system_prompt = """You are an expert at decomposing research questions for simulated worlds. -要求: -1. 每个子问题应该足够具体,可以在模拟世界中找到相关的Agent行为或事件 -2. 子问题应该覆盖原问题的不同维度(如:谁、什么、为什么、怎么样、何时、何地) -3. 子问题应该与模拟场景相关 -4. 返回JSON格式:{"sub_queries": ["子问题1", "子问题2", ...]}""" +Requirements: +1. Each sub-question must be specific enough to retrieve evidence about agent behavior or events +2. Cover different angles of the original question +3. Keep the sub-questions relevant to the simulation scenario +4. Return JSON only in this format: {"sub_queries": ["sub-question 1", "sub-question 2"]}""" - user_prompt = f"""模拟需求背景: + user_prompt = f"""Simulation background: {simulation_requirement} -{f"报告上下文:{report_context[:500]}" if report_context else ""} +{f"Report context: {report_context[:500]}" if report_context else ""} -请将以下问题分解为{max_queries}个子问题: +Break the following question into up to {max_queries} sub-questions: {query} -返回JSON格式的子问题列表。""" +Return the sub-questions as JSON.""" try: response = self.llm.chat_json( @@ -1129,17 +962,17 @@ def _generate_sub_queries( ) sub_queries = response.get("sub_queries", []) - # 确保是字符串列表 + return [str(sq) for sq in sub_queries[:max_queries]] except Exception as e: - logger.warning(f"生成子问题失败: {str(e)},使用默认子问题") - # 降级:返回基于原问题的变体 + logger.warning(f"Failed to generate sub-queries: {str(e)}. Using fallback queries.") + return [ query, - f"{query} 的主要参与者", - f"{query} 的原因和影响", - f"{query} 的发展过程" + f"Main actors related to {query}", + f"Causes and effects of {query}", + f"How {query} evolves over time" ][:max_queries] def panorama_search( @@ -1149,41 +982,23 @@ def panorama_search( include_expired: bool = True, limit: int = 50 ) -> PanoramaResult: - """ - 【PanoramaSearch - 广度搜索】 - - 获取全貌视图,包括所有相关内容和历史/过期信息: - 1. 获取所有相关节点 - 2. 获取所有边(包括已过期/失效的) - 3. 分类整理当前有效和历史信息 - - 这个工具适用于需要了解事件全貌、追踪演变过程的场景。 - - Args: - graph_id: 图谱ID - query: 搜索查询(用于相关性排序) - include_expired: 是否包含过期内容(默认True) - limit: 返回结果数量限制 - - Returns: - PanoramaResult: 广度搜索结果 - """ - logger.info(f"PanoramaSearch 广度搜索: {query[:50]}...") + """Panorama Search.""" + logger.info(f"PanoramaSearch retrieval: {query[:50]}...") result = PanoramaResult(query=query) - # 获取所有节点 + all_nodes = self.get_all_nodes(graph_id) node_map = {n.uuid: n for n in all_nodes} result.all_nodes = all_nodes result.total_nodes = len(all_nodes) - # 获取所有边(包含时间信息) + all_edges = self.get_all_edges(graph_id, include_temporal=True) result.all_edges = all_edges result.total_edges = len(all_edges) - # 分类事实 + active_facts = [] historical_facts = [] @@ -1191,24 +1006,24 @@ def panorama_search( if not edge.fact: continue - # 为事实添加实体名称 + source_name = node_map.get(edge.source_node_uuid, NodeInfo('', '', [], '', {})).name or edge.source_node_uuid[:8] target_name = node_map.get(edge.target_node_uuid, NodeInfo('', '', [], '', {})).name or edge.target_node_uuid[:8] - # 判断是否过期/失效 + is_historical = edge.is_expired or edge.is_invalid if is_historical: - # 历史/过期事实,添加时间标记 - valid_at = edge.valid_at or "未知" - invalid_at = edge.invalid_at or edge.expired_at or "未知" + + valid_at = edge.valid_at or "Unknown" + invalid_at = edge.invalid_at or edge.expired_at or "Unknown" fact_with_time = f"[{valid_at} - {invalid_at}] {edge.fact}" historical_facts.append(fact_with_time) else: - # 当前有效事实 + active_facts.append(edge.fact) - # 基于查询进行相关性排序 + query_lower = query.lower() keywords = [w.strip() for w in query_lower.replace(',', ' ').replace(',', ' ').split() if len(w.strip()) > 1] @@ -1222,7 +1037,7 @@ def relevance_score(fact: str) -> int: score += 10 return score - # 排序并限制数量 + active_facts.sort(key=relevance_score, reverse=True) historical_facts.sort(key=relevance_score, reverse=True) @@ -1231,7 +1046,10 @@ def relevance_score(fact: str) -> int: result.active_count = len(active_facts) result.historical_count = len(historical_facts) - logger.info(f"PanoramaSearch完成: {result.active_count}条有效, {result.historical_count}条历史") + logger.info( + f"PanoramaSearch completed: {result.active_count} active facts, " + f"{result.historical_count} historical facts" + ) return result def quick_search( @@ -1240,25 +1058,10 @@ def quick_search( query: str, limit: int = 10 ) -> SearchResult: - """ - 【QuickSearch - 简单搜索】 - - 快速、轻量级的检索工具: - 1. 直接调用Zep语义搜索 - 2. 返回最相关的结果 - 3. 适用于简单、直接的检索需求 - - Args: - graph_id: 图谱ID - query: 搜索查询 - limit: 返回结果数量 - - Returns: - SearchResult: 搜索结果 - """ - logger.info(f"QuickSearch 简单搜索: {query[:50]}...") + """Quick Search.""" + logger.info(f"QuickSearch retrieval: {query[:50]}...") + - # 直接调用现有的search_graph方法 result = self.search_graph( graph_id=graph_id, query=query, @@ -1266,7 +1069,7 @@ def quick_search( scope="edges" ) - logger.info(f"QuickSearch完成: {result.total_count}条结果") + logger.info(f"QuickSearch completed: {result.total_count} results") return result def interview_agents( @@ -1277,54 +1080,28 @@ def interview_agents( max_agents: int = 5, custom_questions: List[str] = None ) -> InterviewResult: - """ - 【InterviewAgents - 深度采访】 - - 调用真实的OASIS采访API,采访模拟中正在运行的Agent: - 1. 自动读取人设文件,了解所有模拟Agent - 2. 使用LLM分析采访需求,智能选择最相关的Agent - 3. 使用LLM生成采访问题 - 4. 调用 /api/simulation/interview/batch 接口进行真实采访(双平台同时采访) - 5. 整合所有采访结果,生成采访报告 - - 【重要】此功能需要模拟环境处于运行状态(OASIS环境未关闭) - - 【使用场景】 - - 需要从不同角色视角了解事件看法 - - 需要收集多方意见和观点 - - 需要获取模拟Agent的真实回答(非LLM模拟) - - Args: - simulation_id: 模拟ID(用于定位人设文件和调用采访API) - interview_requirement: 采访需求描述(非结构化,如"了解学生对事件的看法") - simulation_requirement: 模拟需求背景(可选) - max_agents: 最多采访的Agent数量 - custom_questions: 自定义采访问题(可选,若不提供则自动生成) - - Returns: - InterviewResult: 采访结果 - """ + """Interview Agents.""" from .simulation_runner import SimulationRunner - logger.info(f"InterviewAgents 深度采访(真实API): {interview_requirement[:50]}...") + logger.info(f"InterviewAgents live API call: {interview_requirement[:50]}...") result = InterviewResult( interview_topic=interview_requirement, interview_questions=custom_questions or [] ) - # Step 1: 读取人设文件 + profiles = self._load_agent_profiles(simulation_id) if not profiles: - logger.warning(f"未找到模拟 {simulation_id} 的人设文件") - result.summary = "未找到可采访的Agent人设文件" + logger.warning(f"No agent profiles found for simulation {simulation_id}") + result.summary = "No interviewable agent profiles were found." return result result.total_agents = len(profiles) - logger.info(f"加载到 {len(profiles)} 个Agent人设") + logger.info(f"Loaded {len(profiles)} agent profiles") + - # Step 2: 使用LLM选择要采访的Agent(返回agent_id列表) selected_agents, selected_indices, selection_reasoning = self._select_agents_for_interview( profiles=profiles, interview_requirement=interview_requirement, @@ -1334,123 +1111,131 @@ def interview_agents( result.selected_agents = selected_agents result.selection_reasoning = selection_reasoning - logger.info(f"选择了 {len(selected_agents)} 个Agent进行采访: {selected_indices}") + logger.info(f"Selected {len(selected_agents)} agents for interview: {selected_indices}") + - # Step 3: 生成采访问题(如果没有提供) if not result.interview_questions: result.interview_questions = self._generate_interview_questions( interview_requirement=interview_requirement, simulation_requirement=simulation_requirement, selected_agents=selected_agents ) - logger.info(f"生成了 {len(result.interview_questions)} 个采访问题") + logger.info(f"Generated {len(result.interview_questions)} interview questions") + - # 将问题合并为一个采访prompt combined_prompt = "\n".join([f"{i+1}. {q}" for i, q in enumerate(result.interview_questions)]) - # 添加优化前缀,约束Agent回复格式 + INTERVIEW_PROMPT_PREFIX = ( - "你正在接受一次采访。请结合你的人设、所有的过往记忆与行动," - "以纯文本方式直接回答以下问题。\n" - "回复要求:\n" - "1. 直接用自然语言回答,不要调用任何工具\n" - "2. 不要返回JSON格式或工具调用格式\n" - "3. 不要使用Markdown标题(如#、##、###)\n" - "4. 按问题编号逐一回答,每个回答以「问题X:」开头(X为问题编号)\n" - "5. 每个问题的回答之间用空行分隔\n" - "6. 回答要有实质内容,每个问题至少回答2-3句话\n\n" + "You are being interviewed. Use your persona, memories, and prior actions to answer the " + "following questions in English plain text.\n" + "Requirements:\n" + "1. Answer naturally and do not call any tools\n" + "2. Do not return JSON or tool-call formats\n" + "3. Do not use Markdown headings such as #, ##, or ###\n" + "4. Answer each question in order and start each answer with 'Question X:'\n" + "5. Separate answers with blank lines\n" + "6. Each answer should contain substantive content with at least 2-3 sentences\n\n" ) optimized_prompt = f"{INTERVIEW_PROMPT_PREFIX}{combined_prompt}" - # Step 4: 调用真实的采访API(不指定platform,默认双平台同时采访) + try: - # 构建批量采访列表(不指定platform,双平台采访) + interviews_request = [] for agent_idx in selected_indices: interviews_request.append({ "agent_id": agent_idx, - "prompt": optimized_prompt # 使用优化后的prompt - # 不指定platform,API会在twitter和reddit两个平台都采访 + "prompt": optimized_prompt + }) - logger.info(f"调用批量采访API(双平台): {len(interviews_request)} 个Agent") + logger.info(f"Calling batch interview API across both platforms for {len(interviews_request)} agents") + - # 调用 SimulationRunner 的批量采访方法(不传platform,双平台采访) api_result = SimulationRunner.interview_agents_batch( simulation_id=simulation_id, interviews=interviews_request, - platform=None, # 不指定platform,双平台采访 - timeout=180.0 # 双平台需要更长超时 + platform=None, + timeout=180.0 + ) + + logger.info( + f"Interview API returned {api_result.get('interviews_count', 0)} results, " + f"success={api_result.get('success')}" ) - logger.info(f"采访API返回: {api_result.get('interviews_count', 0)} 个结果, success={api_result.get('success')}") - # 检查API调用是否成功 if not api_result.get("success", False): - error_msg = api_result.get("error", "未知错误") - logger.warning(f"采访API返回失败: {error_msg}") - result.summary = f"采访API调用失败:{error_msg}。请检查OASIS模拟环境状态。" + error_msg = api_result.get("error", "Unknown error") + logger.warning(f"Interview API returned failure: {error_msg}") + result.summary = ( + f"Interview API call failed: {error_msg}. " + "Please verify that the OASIS simulation environment is running." + ) return result - # Step 5: 解析API返回结果,构建AgentInterview对象 - # 双平台模式返回格式: {"twitter_0": {...}, "reddit_0": {...}, "twitter_1": {...}, ...} + + api_data = api_result.get("result", {}) results_dict = api_data.get("results", {}) if isinstance(api_data, dict) else {} for i, agent_idx in enumerate(selected_indices): agent = selected_agents[i] agent_name = agent.get("realname", agent.get("username", f"Agent_{agent_idx}")) - agent_role = agent.get("profession", "未知") + agent_role = agent.get("profession", "Unknown") agent_bio = agent.get("bio", "") - # 获取该Agent在两个平台的采访结果 + twitter_result = results_dict.get(f"twitter_{agent_idx}", {}) reddit_result = results_dict.get(f"reddit_{agent_idx}", {}) twitter_response = twitter_result.get("response", "") reddit_response = reddit_result.get("response", "") - # 清理可能的工具调用 JSON 包裹 + twitter_response = self._clean_tool_call_response(twitter_response) reddit_response = self._clean_tool_call_response(reddit_response) - # 始终输出双平台标记 - twitter_text = twitter_response if twitter_response else "(该平台未获得回复)" - reddit_text = reddit_response if reddit_response else "(该平台未获得回复)" - response_text = f"【Twitter平台回答】\n{twitter_text}\n\n【Reddit平台回答】\n{reddit_text}" + + twitter_text = twitter_response if twitter_response else "(No response from Twitter)" + reddit_text = reddit_response if reddit_response else "(No response from Reddit)" + response_text = f"Twitter response:\n{twitter_text}\n\nReddit response:\n{reddit_text}" - # 提取关键引言(从两个平台的回答中) + import re combined_responses = f"{twitter_response} {reddit_response}" - # 清理响应文本:去掉标记、编号、Markdown 等干扰 + clean_text = re.sub(r'#{1,6}\s+', '', combined_responses) clean_text = re.sub(r'\{[^}]*tool_name[^}]*\}', '', clean_text) clean_text = re.sub(r'[*_`|>~\-]{2,}', '', clean_text) - clean_text = re.sub(r'问题\d+[::]\s*', '', clean_text) - clean_text = re.sub(r'【[^】]+】', '', clean_text) + clean_text = re.sub( + r'(?:\u95ee\u9898|Question)\s*\d+[\uff1a:]\s*', '', clean_text, flags=re.IGNORECASE + ) + clean_text = re.sub(r'\u3010[^\u3011]+\u3011', '', clean_text) - # 策略1(主): 提取完整的有实质内容的句子 - sentences = re.split(r'[。!?]', clean_text) + + sentences = re.split(r'[.!?\u3002\uff01\uff1f]', clean_text) meaningful = [ s.strip() for s in sentences if 20 <= len(s.strip()) <= 150 - and not re.match(r'^[\s\W,,;;::、]+', s.strip()) - and not s.strip().startswith(('{', '问题')) + and not re.match(r'^[\s\W\uff0c,\uff1b;\uff1a:\u3001]+', s.strip()) + and not s.strip().startswith(('{', 'Question', 'question', '\u95ee\u9898')) ] meaningful.sort(key=len, reverse=True) - key_quotes = [s + "。" for s in meaningful[:3]] + key_quotes = [s + "." for s in meaningful[:3]] - # 策略2(补充): 正确配对的中文引号「」内长文本 + if not key_quotes: paired = re.findall(r'\u201c([^\u201c\u201d]{15,100})\u201d', clean_text) paired += re.findall(r'\u300c([^\u300c\u300d]{15,100})\u300d', clean_text) - key_quotes = [q for q in paired if not re.match(r'^[,,;;::、]', q)][:3] + key_quotes = [q for q in paired if not re.match(r'^[\uff0c,\uff1b;\uff1a:\u3001]', q)][:3] interview = AgentInterview( agent_name=agent_name, agent_role=agent_role, - agent_bio=agent_bio[:1000], # 扩大bio长度限制 + agent_bio=agent_bio[:1000], question=combined_prompt, response=response_text, key_quotes=key_quotes[:5] @@ -1460,30 +1245,33 @@ def interview_agents( result.interviewed_count = len(result.interviews) except ValueError as e: - # 模拟环境未运行 - logger.warning(f"采访API调用失败(环境未运行?): {e}") - result.summary = f"采访失败:{str(e)}。模拟环境可能已关闭,请确保OASIS环境正在运行。" + + logger.warning(f"Interview API call failed (environment not running?): {e}") + result.summary = ( + f"Interview failed: {str(e)}. The simulation environment may be offline. " + "Please make sure OASIS is running." + ) return result except Exception as e: - logger.error(f"采访API调用异常: {e}") + logger.error(f"Interview API exception: {e}") import traceback logger.error(traceback.format_exc()) - result.summary = f"采访过程发生错误:{str(e)}" + result.summary = f"An error occurred during interviewing: {str(e)}" return result - # Step 6: 生成采访摘要 + if result.interviews: result.summary = self._generate_interview_summary( interviews=result.interviews, interview_requirement=interview_requirement ) - logger.info(f"InterviewAgents完成: 采访了 {result.interviewed_count} 个Agent(双平台)") + logger.info(f"InterviewAgents completed: interviewed {result.interviewed_count} agents across both platforms") return result @staticmethod def _clean_tool_call_response(response: str) -> str: - """清理 Agent 回复中的 JSON 工具调用包裹,提取实际内容""" + """Clean tool call response.""" if not response or not response.strip().startswith('{'): return response text = response.strip() @@ -1503,11 +1291,11 @@ def _clean_tool_call_response(response: str) -> str: return response def _load_agent_profiles(self, simulation_id: str) -> List[Dict[str, Any]]: - """加载模拟的Agent人设文件""" + """Load agent profiles.""" import os import csv - # 构建人设文件路径 + sim_dir = os.path.join( os.path.dirname(__file__), f'../../uploads/simulations/{simulation_id}' @@ -1515,36 +1303,36 @@ def _load_agent_profiles(self, simulation_id: str) -> List[Dict[str, Any]]: profiles = [] - # 优先尝试读取Reddit JSON格式 + reddit_profile_path = os.path.join(sim_dir, "reddit_profiles.json") if os.path.exists(reddit_profile_path): try: with open(reddit_profile_path, 'r', encoding='utf-8') as f: profiles = json.load(f) - logger.info(f"从 reddit_profiles.json 加载了 {len(profiles)} 个人设") + logger.info(f"Loaded {len(profiles)} profiles from reddit_profiles.json") return profiles except Exception as e: - logger.warning(f"读取 reddit_profiles.json 失败: {e}") + logger.warning(f"Failed to read reddit_profiles.json: {e}") + - # 尝试读取Twitter CSV格式 twitter_profile_path = os.path.join(sim_dir, "twitter_profiles.csv") if os.path.exists(twitter_profile_path): try: with open(twitter_profile_path, 'r', encoding='utf-8') as f: reader = csv.DictReader(f) for row in reader: - # CSV格式转换为统一格式 + profiles.append({ "realname": row.get("name", ""), "username": row.get("username", ""), "bio": row.get("description", ""), "persona": row.get("user_char", ""), - "profession": "未知" + "profession": "Unknown" }) - logger.info(f"从 twitter_profiles.csv 加载了 {len(profiles)} 个人设") + logger.info(f"Loaded {len(profiles)} profiles from twitter_profiles.csv") return profiles except Exception as e: - logger.warning(f"读取 twitter_profiles.csv 失败: {e}") + logger.warning(f"Failed to read twitter_profiles.csv: {e}") return profiles @@ -1555,52 +1343,45 @@ def _select_agents_for_interview( simulation_requirement: str, max_agents: int ) -> tuple: - """ - 使用LLM选择要采访的Agent + """Select Agents For Interview.""" - Returns: - tuple: (selected_agents, selected_indices, reasoning) - - selected_agents: 选中Agent的完整信息列表 - - selected_indices: 选中Agent的索引列表(用于API调用) - - reasoning: 选择理由 - """ - # 构建Agent摘要列表 agent_summaries = [] for i, profile in enumerate(profiles): summary = { "index": i, "name": profile.get("realname", profile.get("username", f"Agent_{i}")), - "profession": profile.get("profession", "未知"), + "profession": profile.get("profession", "Unknown"), "bio": profile.get("bio", "")[:200], "interested_topics": profile.get("interested_topics", []) } agent_summaries.append(summary) - system_prompt = """你是一个专业的采访策划专家。你的任务是根据采访需求,从模拟Agent列表中选择最适合采访的对象。 + system_prompt = """You are an expert interview planner. Select the best agents to interview +for the given request. -选择标准: -1. Agent的身份/职业与采访主题相关 -2. Agent可能持有独特或有价值的观点 -3. 选择多样化的视角(如:支持方、反对方、中立方、专业人士等) -4. 优先选择与事件直接相关的角色 +Selection criteria: +1. The agent's identity or role should be relevant to the interview topic +2. The agent may hold a distinctive or valuable perspective +3. Prefer a diverse set of viewpoints +4. Prefer agents directly connected to the event when possible -返回JSON格式: +Return JSON: { - "selected_indices": [选中Agent的索引列表], - "reasoning": "选择理由说明" + "selected_indices": [0, 1], + "reasoning": "Why these agents were selected" }""" - user_prompt = f"""采访需求: + user_prompt = f"""Interview request: {interview_requirement} -模拟背景: -{simulation_requirement if simulation_requirement else "未提供"} +Simulation background: +{simulation_requirement if simulation_requirement else "Not provided"} -可选择的Agent列表(共{len(agent_summaries)}个): +Available agents ({len(agent_summaries)} total): {json.dumps(agent_summaries, ensure_ascii=False, indent=2)} -请选择最多{max_agents}个最适合采访的Agent,并说明选择理由。""" +Select up to {max_agents} agents and explain why.""" try: response = self.llm.chat_json( @@ -1612,9 +1393,9 @@ def _select_agents_for_interview( ) selected_indices = response.get("selected_indices", [])[:max_agents] - reasoning = response.get("reasoning", "基于相关性自动选择") + reasoning = response.get("reasoning", "Selected automatically based on relevance") + - # 获取选中的Agent完整信息 selected_agents = [] valid_indices = [] for idx in selected_indices: @@ -1625,11 +1406,11 @@ def _select_agents_for_interview( return selected_agents, valid_indices, reasoning except Exception as e: - logger.warning(f"LLM选择Agent失败,使用默认选择: {e}") - # 降级:选择前N个 + logger.warning(f"LLM failed to select agents, using fallback selection: {e}") + selected = profiles[:max_agents] indices = list(range(min(max_agents, len(profiles)))) - return selected, indices, "使用默认选择策略" + return selected, indices, "Used fallback selection strategy" def _generate_interview_questions( self, @@ -1637,29 +1418,29 @@ def _generate_interview_questions( simulation_requirement: str, selected_agents: List[Dict[str, Any]] ) -> List[str]: - """使用LLM生成采访问题""" + """Generate interview questions.""" - agent_roles = [a.get("profession", "未知") for a in selected_agents] + agent_roles = [a.get("profession", "Unknown") for a in selected_agents] - system_prompt = """你是一个专业的记者/采访者。根据采访需求,生成3-5个深度采访问题。 + system_prompt = """You are an expert interviewer. Generate 3-5 in-depth interview questions. -问题要求: -1. 开放性问题,鼓励详细回答 -2. 针对不同角色可能有不同答案 -3. 涵盖事实、观点、感受等多个维度 -4. 语言自然,像真实采访一样 -5. 每个问题控制在50字以内,简洁明了 -6. 直接提问,不要包含背景说明或前缀 +Requirements: +1. Ask open-ended questions that encourage detailed answers +2. Questions should allow different roles to answer differently +3. Cover facts, opinions, emotions, and implications +4. Keep the phrasing natural and interview-like +5. Keep each question concise +6. Ask directly without extra background framing -返回JSON格式:{"questions": ["问题1", "问题2", ...]}""" +Return JSON only: {"questions": ["Question 1", "Question 2"]}""" - user_prompt = f"""采访需求:{interview_requirement} + user_prompt = f"""Interview request: {interview_requirement} -模拟背景:{simulation_requirement if simulation_requirement else "未提供"} +Simulation background: {simulation_requirement if simulation_requirement else "Not provided"} -采访对象角色:{', '.join(agent_roles)} +Interviewee roles: {', '.join(agent_roles)} -请生成3-5个采访问题。""" +Generate 3-5 interview questions in English.""" try: response = self.llm.chat_json( @@ -1670,14 +1451,14 @@ def _generate_interview_questions( temperature=0.5 ) - return response.get("questions", [f"关于{interview_requirement},您有什么看法?"]) + return response.get("questions", [f"What is your view on {interview_requirement}?"]) except Exception as e: - logger.warning(f"生成采访问题失败: {e}") + logger.warning(f"Failed to generate interview questions: {e}") return [ - f"关于{interview_requirement},您的观点是什么?", - "这件事对您或您所代表的群体有什么影响?", - "您认为应该如何解决或改进这个问题?" + f"What is your perspective on {interview_requirement}?", + "How does this affect you or the group you represent?", + "What response or improvement would you most want to see?" ] def _generate_interview_summary( @@ -1685,38 +1466,38 @@ def _generate_interview_summary( interviews: List[AgentInterview], interview_requirement: str ) -> str: - """生成采访摘要""" + """Generate interview summary.""" if not interviews: - return "未完成任何采访" + return "No interviews were completed." + - # 收集所有采访内容 interview_texts = [] for interview in interviews: - interview_texts.append(f"【{interview.agent_name}({interview.agent_role})】\n{interview.response[:500]}") + interview_texts.append(f"[{interview.agent_name} ({interview.agent_role})]\n{interview.response[:500]}") - system_prompt = """你是一个专业的新闻编辑。请根据多位受访者的回答,生成一份采访摘要。 + system_prompt = """You are a professional editor. Summarize the interview responses in English. -摘要要求: -1. 提炼各方主要观点 -2. 指出观点的共识和分歧 -3. 突出有价值的引言 -4. 客观中立,不偏袒任何一方 -5. 控制在1000字内 +Summary requirements: +1. Capture each side's main viewpoint +2. Highlight areas of consensus and disagreement +3. Surface especially valuable quotes +4. Stay neutral and evidence-based +5. Keep the summary under 1000 words -格式约束(必须遵守): -- 使用纯文本段落,用空行分隔不同部分 -- 不要使用Markdown标题(如#、##、###) -- 不要使用分割线(如---、***) -- 引用受访者原话时使用中文引号「」 -- 可以使用**加粗**标记关键词,但不要使用其他Markdown语法""" +Formatting constraints: +- use plain text paragraphs separated by blank lines +- do not use Markdown headings +- do not use divider lines such as --- or *** +- use standard English quotation marks for quotes +- you may use **bold** for emphasis if helpful""" - user_prompt = f"""采访主题:{interview_requirement} + user_prompt = f"""Interview topic: {interview_requirement} -采访内容: +Interview content: {"".join(interview_texts)} -请生成采访摘要。""" +Generate the interview summary in English.""" try: summary = self.llm.chat( @@ -1730,6 +1511,9 @@ def _generate_interview_summary( return summary except Exception as e: - logger.warning(f"生成采访摘要失败: {e}") - # 降级:简单拼接 - return f"共采访了{len(interviews)}位受访者,包括:" + "、".join([i.agent_name for i in interviews]) + logger.warning(f"Failed to generate interview summary: {e}") + + return ( + f"Interviewed {len(interviews)} participants, including: " + + ", ".join([i.agent_name for i in interviews]) + ) diff --git a/backend/app/utils/__init__.py b/backend/app/utils/__init__.py index 5848792b8..11c4d4a80 100644 --- a/backend/app/utils/__init__.py +++ b/backend/app/utils/__init__.py @@ -1,6 +1,4 @@ -""" -工具模块 -""" +"""Utility modules.""" from .file_parser import FileParser from .llm_client import LLMClient diff --git a/backend/app/utils/file_parser.py b/backend/app/utils/file_parser.py index 3f1d8ed2e..326df5acd 100644 --- a/backend/app/utils/file_parser.py +++ b/backend/app/utils/file_parser.py @@ -1,7 +1,4 @@ -""" -文件解析工具 -支持PDF、Markdown、TXT文件的文本提取 -""" +"""File parsing utilities.""" import os from pathlib import Path @@ -9,30 +6,16 @@ def _read_text_with_fallback(file_path: str) -> str: - """ - 读取文本文件,UTF-8失败时自动探测编码。 - - 采用多级回退策略: - 1. 首先尝试 UTF-8 解码 - 2. 使用 charset_normalizer 检测编码 - 3. 回退到 chardet 检测编码 - 4. 最终使用 UTF-8 + errors='replace' 兜底 - - Args: - file_path: 文件路径 - - Returns: - 解码后的文本内容 - """ + """Read text with fallback.""" data = Path(file_path).read_bytes() - # 首先尝试 UTF-8 + try: return data.decode('utf-8') except UnicodeDecodeError: pass - # 尝试使用 charset_normalizer 检测编码 + encoding = None try: from charset_normalizer import from_bytes @@ -42,7 +25,7 @@ def _read_text_with_fallback(file_path: str) -> str: except Exception: pass - # 回退到 chardet + if not encoding: try: import chardet @@ -51,7 +34,7 @@ def _read_text_with_fallback(file_path: str) -> str: except Exception: pass - # 最终兜底:使用 UTF-8 + replace + if not encoding: encoding = 'utf-8' @@ -59,30 +42,22 @@ def _read_text_with_fallback(file_path: str) -> str: class FileParser: - """文件解析器""" + """File Parser.""" SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'} @classmethod def extract_text(cls, file_path: str) -> str: - """ - 从文件中提取文本 - - Args: - file_path: 文件路径 - - Returns: - 提取的文本内容 - """ + """Extract text.""" path = Path(file_path) if not path.exists(): - raise FileNotFoundError(f"文件不存在: {file_path}") + raise FileNotFoundError(f"File not found: {file_path}") suffix = path.suffix.lower() if suffix not in cls.SUPPORTED_EXTENSIONS: - raise ValueError(f"不支持的文件格式: {suffix}") + raise ValueError(f"Unsupported file format: {suffix}") if suffix == '.pdf': return cls._extract_from_pdf(file_path) @@ -91,15 +66,15 @@ def extract_text(cls, file_path: str) -> str: elif suffix == '.txt': return cls._extract_from_txt(file_path) - raise ValueError(f"无法处理的文件格式: {suffix}") + raise ValueError(f"Unable to process file format: {suffix}") @staticmethod def _extract_from_pdf(file_path: str) -> str: - """从PDF提取文本""" + """Extract from pdf.""" try: import fitz # PyMuPDF except ImportError: - raise ImportError("需要安装PyMuPDF: pip install PyMuPDF") + raise ImportError("PyMuPDF is required: pip install PyMuPDF") text_parts = [] with fitz.open(file_path) as doc: @@ -112,34 +87,26 @@ def _extract_from_pdf(file_path: str) -> str: @staticmethod def _extract_from_md(file_path: str) -> str: - """从Markdown提取文本,支持自动编码检测""" + """Extract from markdown.""" return _read_text_with_fallback(file_path) @staticmethod def _extract_from_txt(file_path: str) -> str: - """从TXT提取文本,支持自动编码检测""" + """Extract from txt.""" return _read_text_with_fallback(file_path) @classmethod def extract_from_multiple(cls, file_paths: List[str]) -> str: - """ - 从多个文件提取文本并合并 - - Args: - file_paths: 文件路径列表 - - Returns: - 合并后的文本 - """ + """Extract from multiple.""" all_texts = [] for i, file_path in enumerate(file_paths, 1): try: text = cls.extract_text(file_path) filename = Path(file_path).name - all_texts.append(f"=== 文档 {i}: {filename} ===\n{text}") + all_texts.append(f"=== Document {i}: {filename} ===\n{text}") except Exception as e: - all_texts.append(f"=== 文档 {i}: {file_path} (提取失败: {str(e)}) ===") + all_texts.append(f"=== Document {i}: {file_path} (extraction failed: {str(e)}) ===") return "\n\n".join(all_texts) @@ -149,17 +116,7 @@ def split_text_into_chunks( chunk_size: int = 500, overlap: int = 50 ) -> List[str]: - """ - 将文本分割成小块 - - Args: - text: 原始文本 - chunk_size: 每块的字符数 - overlap: 重叠字符数 - - Returns: - 文本块列表 - """ + """Split text into chunks.""" if len(text) <= chunk_size: return [text] if text.strip() else [] @@ -169,10 +126,10 @@ def split_text_into_chunks( while start < len(text): end = start + chunk_size - # 尝试在句子边界处分割 + if end < len(text): - # 查找最近的句子结束符 - for sep in ['。', '!', '?', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']: + + for sep in ['\u3002', '\uff01', '\uff1f', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']: last_sep = text[start:end].rfind(sep) if last_sep != -1 and last_sep > chunk_size * 0.3: end = start + last_sep + len(sep) @@ -182,8 +139,7 @@ def split_text_into_chunks( if chunk: chunks.append(chunk) - # 下一个块从重叠位置开始 + start = end - overlap if end < len(text) else len(text) return chunks - diff --git a/backend/app/utils/llm_client.py b/backend/app/utils/llm_client.py index 6c1a81f49..7cdfafaca 100644 --- a/backend/app/utils/llm_client.py +++ b/backend/app/utils/llm_client.py @@ -1,7 +1,4 @@ -""" -LLM客户端封装 -统一使用OpenAI格式调用 -""" +"""LLM client wrapper.""" import json import re @@ -12,7 +9,7 @@ class LLMClient: - """LLM客户端""" + """LLM Client.""" def __init__( self, @@ -25,7 +22,7 @@ def __init__( self.model = model or Config.LLM_MODEL_NAME if not self.api_key: - raise ValueError("LLM_API_KEY 未配置") + raise ValueError("LLM_API_KEY is not configured") self.client = OpenAI( api_key=self.api_key, @@ -39,18 +36,7 @@ def chat( max_tokens: int = 4096, response_format: Optional[Dict] = None ) -> str: - """ - 发送聊天请求 - - Args: - messages: 消息列表 - temperature: 温度参数 - max_tokens: 最大token数 - response_format: 响应格式(如JSON模式) - - Returns: - 模型响应文本 - """ + """Send a chat request.""" kwargs = { "model": self.model, "messages": messages, @@ -63,7 +49,7 @@ def chat( response = self.client.chat.completions.create(**kwargs) content = response.choices[0].message.content - # 部分模型(如MiniMax M2.5)会在content中包含思考内容,需要移除 + content = re.sub(r'[\s\S]*?', '', content).strip() return content @@ -73,24 +59,14 @@ def chat_json( temperature: float = 0.3, max_tokens: int = 4096 ) -> Dict[str, Any]: - """ - 发送聊天请求并返回JSON - - Args: - messages: 消息列表 - temperature: 温度参数 - max_tokens: 最大token数 - - Returns: - 解析后的JSON对象 - """ + """Send a chat request and return JSON.""" response = self.chat( messages=messages, temperature=temperature, max_tokens=max_tokens, response_format={"type": "json_object"} ) - # 清理markdown代码块标记 + cleaned_response = response.strip() cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE) cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response) @@ -99,5 +75,4 @@ def chat_json( try: return json.loads(cleaned_response) except json.JSONDecodeError: - raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}") - + raise ValueError(f"LLM returned invalid JSON: {cleaned_response}") diff --git a/backend/app/utils/logger.py b/backend/app/utils/logger.py index 1978c0b84..6bbc51063 100644 --- a/backend/app/utils/logger.py +++ b/backend/app/utils/logger.py @@ -1,7 +1,4 @@ -""" -日志配置模块 -提供统一的日志管理,同时输出到控制台和文件 -""" +"""Logging configuration utilities.""" import os import sys @@ -11,48 +8,36 @@ def _ensure_utf8_stdout(): - """ - 确保 stdout/stderr 使用 UTF-8 编码 - 解决 Windows 控制台中文乱码问题 - """ + """Ensure utf8 stdout.""" if sys.platform == 'win32': - # Windows 下重新配置标准输出为 UTF-8 + if hasattr(sys.stdout, 'reconfigure'): sys.stdout.reconfigure(encoding='utf-8', errors='replace') if hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(encoding='utf-8', errors='replace') -# 日志目录 + LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs') def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger: - """ - 设置日志器 + """Setup Logger.""" - Args: - name: 日志器名称 - level: 日志级别 - - Returns: - 配置好的日志器 - """ - # 确保日志目录存在 os.makedirs(LOG_DIR, exist_ok=True) - # 创建日志器 + logger = logging.getLogger(name) logger.setLevel(level) - # 阻止日志向上传播到根 logger,避免重复输出 + logger.propagate = False - # 如果已经有处理器,不重复添加 + if logger.handlers: return logger - # 日志格式 + detailed_formatter = logging.Formatter( '[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S' @@ -63,7 +48,7 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging. datefmt='%H:%M:%S' ) - # 1. 文件处理器 - 详细日志(按日期命名,带轮转) + log_filename = datetime.now().strftime('%Y-%m-%d') + '.log' file_handler = RotatingFileHandler( os.path.join(LOG_DIR, log_filename), @@ -74,14 +59,14 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging. file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(detailed_formatter) - # 2. 控制台处理器 - 简洁日志(INFO及以上) - # 确保 Windows 下使用 UTF-8 编码,避免中文乱码 + + _ensure_utf8_stdout() console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(logging.INFO) console_handler.setFormatter(simple_formatter) - # 添加处理器 + logger.addHandler(file_handler) logger.addHandler(console_handler) @@ -89,26 +74,18 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging. def get_logger(name: str = 'mirofish') -> logging.Logger: - """ - 获取日志器(如果不存在则创建) - - Args: - name: 日志器名称 - - Returns: - 日志器实例 - """ + """Get logger.""" logger = logging.getLogger(name) if not logger.handlers: return setup_logger(name) return logger -# 创建默认日志器 + logger = setup_logger() -# 便捷方法 + def debug(msg, *args, **kwargs): logger.debug(msg, *args, **kwargs) diff --git a/backend/app/utils/ontology_normalizer.py b/backend/app/utils/ontology_normalizer.py new file mode 100644 index 000000000..eae0c8b20 --- /dev/null +++ b/backend/app/utils/ontology_normalizer.py @@ -0,0 +1,119 @@ +""" +Utilities for normalizing ontology names before sending them to Zep. +""" + +from __future__ import annotations + +import copy +import re +from typing import Any, Dict, Tuple + + +PASCAL_CASE_PATTERN = re.compile(r"^[A-Z][A-Za-z0-9]*$") + + +def _split_name_parts(raw_name: str) -> list[str]: + text = str(raw_name or "").strip() + if not text: + return [] + + text = re.sub(r"[^A-Za-z0-9]+", " ", text) + text = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", " ", text) + text = re.sub(r"(?<=[A-Z])(?=[A-Z][a-z])", " ", text) + text = re.sub(r"(?<=[A-Za-z])(?=[0-9])", " ", text) + text = re.sub(r"(?<=[0-9])(?=[A-Za-z])", " ", text) + return [part for part in text.split() if part] + + +def normalize_pascal_case_name(raw_name: str, default_prefix: str = "Entity") -> str: + """ + Convert an arbitrary label into Zep-safe PascalCase. + """ + text = str(raw_name or "").strip() + if text and PASCAL_CASE_PATTERN.match(text): + return text + + parts = _split_name_parts(text) + if not parts: + return default_prefix + + normalized_parts = [] + for part in parts: + if part.isdigit(): + normalized_parts.append(part) + elif part.isupper() and len(part) > 1: + normalized_parts.append(part) + else: + normalized_parts.append(part[0].upper() + part[1:].lower()) + + normalized = "".join(normalized_parts) + + if not normalized: + normalized = default_prefix + elif not normalized[0].isalpha(): + normalized = f"{default_prefix}{normalized}" + + return normalized + + +def _ensure_unique_name(base_name: str, used_names: set[str]) -> str: + candidate = base_name + suffix = 2 + + while candidate in used_names: + candidate = f"{base_name}{suffix}" + suffix += 1 + + used_names.add(candidate) + return candidate + + +def normalize_ontology_for_zep(ontology: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, str]]: + """ + Normalize ontology entity names and source/target references for Zep validation. + + Returns: + A tuple of (normalized_ontology, entity_name_mapping) + """ + normalized = copy.deepcopy(ontology or {}) + entity_types = normalized.setdefault("entity_types", []) + edge_types = normalized.setdefault("edge_types", []) + + used_entity_names: set[str] = set() + entity_name_mapping: Dict[str, str] = {} + + for entity in entity_types: + raw_name = str(entity.get("name", "")).strip() + safe_name = normalize_pascal_case_name(raw_name, default_prefix="Entity") + safe_name = _ensure_unique_name(safe_name, used_entity_names) + + entity["name"] = safe_name + + if raw_name: + entity_name_mapping[raw_name] = safe_name + entity_name_mapping[raw_name.strip()] = safe_name + entity_name_mapping[safe_name] = safe_name + + for edge in edge_types: + source_targets = edge.setdefault("source_targets", []) + for source_target in source_targets: + raw_source = str(source_target.get("source", "")).strip() + raw_target = str(source_target.get("target", "")).strip() + + if raw_source: + source_target["source"] = entity_name_mapping.get( + raw_source, + normalize_pascal_case_name(raw_source, default_prefix="Entity"), + ) + else: + source_target["source"] = "Entity" + + if raw_target: + source_target["target"] = entity_name_mapping.get( + raw_target, + normalize_pascal_case_name(raw_target, default_prefix="Entity"), + ) + else: + source_target["target"] = "Entity" + + return normalized, entity_name_mapping diff --git a/backend/app/utils/retry.py b/backend/app/utils/retry.py index 819b1cfcf..b5928a446 100644 --- a/backend/app/utils/retry.py +++ b/backend/app/utils/retry.py @@ -1,7 +1,4 @@ -""" -API调用重试机制 -用于处理LLM等外部API调用的重试逻辑 -""" +"""API retry utilities.""" import time import random @@ -21,23 +18,7 @@ def retry_with_backoff( exceptions: Tuple[Type[Exception], ...] = (Exception,), on_retry: Optional[Callable[[Exception, int], None]] = None ): - """ - 带指数退避的重试装饰器 - - Args: - max_retries: 最大重试次数 - initial_delay: 初始延迟(秒) - max_delay: 最大延迟(秒) - backoff_factor: 退避因子 - jitter: 是否添加随机抖动 - exceptions: 需要重试的异常类型 - on_retry: 重试时的回调函数 (exception, retry_count) - - Usage: - @retry_with_backoff(max_retries=3) - def call_llm_api(): - ... - """ + """Retry With Backoff.""" def decorator(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs) -> Any: @@ -52,17 +33,17 @@ def wrapper(*args, **kwargs) -> Any: last_exception = e if attempt == max_retries: - logger.error(f"函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}") + logger.error(f"Function {func.__name__} still failed after {max_retries} retries: {str(e)}") raise - # 计算延迟 + current_delay = min(delay, max_delay) if jitter: current_delay = current_delay * (0.5 + random.random()) logger.warning( - f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, " - f"{current_delay:.1f}秒后重试..." + f"Function {func.__name__} attempt {attempt + 1} failed: {str(e)}. " + f"Retrying in {current_delay:.1f}s..." ) if on_retry: @@ -86,9 +67,7 @@ def retry_with_backoff_async( exceptions: Tuple[Type[Exception], ...] = (Exception,), on_retry: Optional[Callable[[Exception, int], None]] = None ): - """ - 异步版本的重试装饰器 - """ + """Retry With Backoff Async.""" import asyncio def decorator(func: Callable) -> Callable: @@ -105,7 +84,9 @@ async def wrapper(*args, **kwargs) -> Any: last_exception = e if attempt == max_retries: - logger.error(f"异步函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}") + logger.error( + f"Async function {func.__name__} still failed after {max_retries} retries: {str(e)}" + ) raise current_delay = min(delay, max_delay) @@ -113,8 +94,8 @@ async def wrapper(*args, **kwargs) -> Any: current_delay = current_delay * (0.5 + random.random()) logger.warning( - f"异步函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, " - f"{current_delay:.1f}秒后重试..." + f"Async function {func.__name__} attempt {attempt + 1} failed: {str(e)}. " + f"Retrying in {current_delay:.1f}s..." ) if on_retry: @@ -130,9 +111,7 @@ async def wrapper(*args, **kwargs) -> Any: class RetryableAPIClient: - """ - 可重试的API客户端封装 - """ + """Retryable API Client.""" def __init__( self, @@ -153,18 +132,7 @@ def call_with_retry( exceptions: Tuple[Type[Exception], ...] = (Exception,), **kwargs ) -> Any: - """ - 执行函数调用并在失败时重试 - - Args: - func: 要调用的函数 - *args: 函数参数 - exceptions: 需要重试的异常类型 - **kwargs: 函数关键字参数 - - Returns: - 函数返回值 - """ + """Call with retry.""" last_exception = None delay = self.initial_delay @@ -176,15 +144,15 @@ def call_with_retry( last_exception = e if attempt == self.max_retries: - logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(e)}") + logger.error(f"API call still failed after {self.max_retries} retries: {str(e)}") raise current_delay = min(delay, self.max_delay) current_delay = current_delay * (0.5 + random.random()) logger.warning( - f"API调用第 {attempt + 1} 次尝试失败: {str(e)}, " - f"{current_delay:.1f}秒后重试..." + f"API call attempt {attempt + 1} failed: {str(e)}. " + f"Retrying in {current_delay:.1f}s..." ) time.sleep(current_delay) @@ -199,18 +167,7 @@ def call_batch_with_retry( exceptions: Tuple[Type[Exception], ...] = (Exception,), continue_on_failure: bool = True ) -> Tuple[list, list]: - """ - 批量调用并对每个失败项单独重试 - - Args: - items: 要处理的项目列表 - process_func: 处理函数,接收单个item作为参数 - exceptions: 需要重试的异常类型 - continue_on_failure: 单项失败后是否继续处理其他项 - - Returns: - (成功结果列表, 失败项列表) - """ + """Call batch with retry.""" results = [] failures = [] @@ -224,7 +181,7 @@ def call_batch_with_retry( results.append(result) except Exception as e: - logger.error(f"处理第 {idx + 1} 项失败: {str(e)}") + logger.error(f"Failed to process item {idx + 1}: {str(e)}") failures.append({ "index": idx, "item": item, @@ -235,4 +192,3 @@ def call_batch_with_retry( raise return results, failures - diff --git a/backend/app/utils/zep_paging.py b/backend/app/utils/zep_paging.py index 943cd1ae2..bb3b2c7f6 100644 --- a/backend/app/utils/zep_paging.py +++ b/backend/app/utils/zep_paging.py @@ -1,8 +1,4 @@ -"""Zep Graph 分页读取工具。 - -Zep 的 node/edge 列表接口使用 UUID cursor 分页, -本模块封装自动翻页逻辑(含单页重试),对调用方透明地返回完整列表。 -""" +"""Zep graph pagination utilities.""" from __future__ import annotations @@ -31,7 +27,7 @@ def _fetch_page_with_retry( page_description: str = "page", **kwargs: Any, ) -> list[Any]: - """单页请求,失败时指数退避重试。仅重试网络/IO类瞬态错误。""" + """Fetch page with retry.""" if max_retries < 1: raise ValueError("max_retries must be >= 1") @@ -64,7 +60,7 @@ def fetch_all_nodes( max_retries: int = _DEFAULT_MAX_RETRIES, retry_delay: float = _DEFAULT_RETRY_DELAY, ) -> list[Any]: - """分页获取图谱节点,最多返回 max_items 条(默认 2000)。每页请求自带重试。""" + """Fetch all nodes.""" all_nodes: list[Any] = [] cursor: str | None = None page_num = 0 @@ -109,7 +105,7 @@ def fetch_all_edges( max_retries: int = _DEFAULT_MAX_RETRIES, retry_delay: float = _DEFAULT_RETRY_DELAY, ) -> list[Any]: - """分页获取图谱所有边,返回完整列表。每页请求自带重试。""" + """Fetch all edges.""" all_edges: list[Any] = [] cursor: str | None = None page_num = 0 diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4f5361d53..9dc96ac56 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "mirofish-backend" version = "0.1.0" -description = "MiroFish - 简洁通用的群体智能引擎,预测万物" +description = "MiroFish - A simple, universal swarm intelligence engine for predicting anything" requires-python = ">=3.11" license = { text = "AGPL-3.0" } authors = [ @@ -9,27 +9,30 @@ authors = [ ] dependencies = [ - # 核心框架 + # Core framework "flask>=3.0.0", "flask-cors>=6.0.0", - # LLM 相关 - "openai>=1.0.0", + # LLM support + "openai>=1.91.0", - # Zep Cloud + # Graph backends "zep-cloud==3.13.0", + "numpy>=1.0.0", + "posthog>=3.0.0", + "tenacity>=9.0.0", - # OASIS 社交媒体模拟 + # OASIS social media simulation "camel-oasis==0.2.5", "camel-ai==0.2.78", - # 文件处理 + # File processing "PyMuPDF>=1.24.0", - # 编码检测(支持非UTF-8编码的文本文件) + # Encoding detection (supports text files that are not UTF-8) "charset-normalizer>=3.0.0", "chardet>=5.0.0", - # 工具库 + # Utilities "python-dotenv>=1.0.0", "pydantic>=2.0.0", ] diff --git a/backend/requirements.txt b/backend/requirements.txt index 4f146296b..a5a053c42 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -5,31 +5,37 @@ # Install: pip install -r requirements.txt # =========================================== -# ============= 核心框架 ============= +# ============= Core Framework ============= flask>=3.0.0 flask-cors>=6.0.0 -# ============= LLM 相关 ============= -# OpenAI SDK(统一使用 OpenAI 格式调用 LLM) -openai>=1.0.0 +# ============= LLM Support ============= +# OpenAI SDK (all LLM calls use the OpenAI-compatible format) +openai>=1.91.0 -# ============= Zep Cloud ============= +# ============= Graph Backends ============= zep-cloud==3.13.0 +# Local Graphiti support is installed separately with: +# uv pip install --python .venv/bin/python --no-deps graphiti-core==0.28.2 +# This avoids conflicting Neo4j driver pins between Graphiti and camel-oasis. +numpy>=1.0.0 +posthog>=3.0.0 +tenacity>=9.0.0 -# ============= OASIS 社交媒体模拟 ============= -# OASIS 社交模拟框架 +# ============= OASIS Social Media Simulation ============= +# OASIS social simulation framework camel-oasis==0.2.5 camel-ai==0.2.78 -# ============= 文件处理 ============= +# ============= File Processing ============= PyMuPDF>=1.24.0 -# 编码检测(支持非UTF-8编码的文本文件) +# Encoding detection (supports text files that are not UTF-8) charset-normalizer>=3.0.0 chardet>=5.0.0 -# ============= 工具库 ============= -# 环境变量加载 +# ============= Utilities ============= +# Environment variable loading python-dotenv>=1.0.0 -# 数据验证 +# Data validation pydantic>=2.0.0 diff --git a/backend/run.py b/backend/run.py index 4e3b04fa9..28ec415d5 100644 --- a/backend/run.py +++ b/backend/run.py @@ -1,21 +1,19 @@ -""" -MiroFish Backend 启动入口 -""" +"""MiroFish backend entry point.""" import os import sys -# 解决 Windows 控制台中文乱码问题:在所有导入之前设置 UTF-8 编码 + if sys.platform == 'win32': - # 设置环境变量确保 Python 使用 UTF-8 + os.environ.setdefault('PYTHONIOENCODING', 'utf-8') - # 重新配置标准输出流为 UTF-8 + if hasattr(sys.stdout, 'reconfigure'): sys.stdout.reconfigure(encoding='utf-8', errors='replace') if hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(encoding='utf-8', errors='replace') -# 添加项目根目录到路径 + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from app import create_app @@ -23,28 +21,27 @@ def main(): - """主函数""" - # 验证配置 + """Run the main entry point.""" + errors = Config.validate() if errors: - print("配置错误:") + print("Configuration errors:") for err in errors: print(f" - {err}") - print("\n请检查 .env 文件中的配置") + print("\nPlease check the configuration in your .env file") sys.exit(1) - # 创建应用 + app = create_app() - # 获取运行配置 + host = os.environ.get('FLASK_HOST', '0.0.0.0') port = int(os.environ.get('FLASK_PORT', 5001)) debug = Config.DEBUG - # 启动服务 + app.run(host=host, port=port, debug=debug, threaded=True) if __name__ == '__main__': main() - diff --git a/backend/scripts/action_logger.py b/backend/scripts/action_logger.py index 38d025a6c..ed40a1cea 100644 --- a/backend/scripts/action_logger.py +++ b/backend/scripts/action_logger.py @@ -1,16 +1,4 @@ -""" -动作日志记录器 -用于记录OASIS模拟中每个Agent的动作,供后端监控使用 - -日志结构: - sim_xxx/ - ├── twitter/ - │ └── actions.jsonl # Twitter 平台动作日志 - ├── reddit/ - │ └── actions.jsonl # Reddit 平台动作日志 - ├── simulation.log # 主模拟进程日志 - └── run_state.json # 运行状态(API 查询用) -""" +"""Action logging utilities.""" import json import os @@ -20,16 +8,10 @@ class PlatformActionLogger: - """单平台动作日志记录器""" + """Platform Action Logger.""" def __init__(self, platform: str, base_dir: str): - """ - 初始化日志记录器 - - Args: - platform: 平台名称 (twitter/reddit) - base_dir: 模拟目录的基础路径 - """ + """Initialize the instance.""" self.platform = platform self.base_dir = base_dir self.log_dir = os.path.join(base_dir, platform) @@ -37,7 +19,7 @@ def __init__(self, platform: str, base_dir: str): self._ensure_dir() def _ensure_dir(self): - """确保目录存在""" + """Ensure dir.""" os.makedirs(self.log_dir, exist_ok=True) def log_action( @@ -50,7 +32,7 @@ def log_action( result: Optional[str] = None, success: bool = True ): - """记录一个动作""" + """Log Action.""" entry = { "round": round_num, "timestamp": datetime.now().isoformat(), @@ -66,7 +48,7 @@ def log_action( f.write(json.dumps(entry, ensure_ascii=False) + '\n') def log_round_start(self, round_num: int, simulated_hour: int): - """记录轮次开始""" + """Log Round Start.""" entry = { "round": round_num, "timestamp": datetime.now().isoformat(), @@ -78,7 +60,7 @@ def log_round_start(self, round_num: int, simulated_hour: int): f.write(json.dumps(entry, ensure_ascii=False) + '\n') def log_round_end(self, round_num: int, actions_count: int): - """记录轮次结束""" + """Log Round End.""" entry = { "round": round_num, "timestamp": datetime.now().isoformat(), @@ -90,7 +72,7 @@ def log_round_end(self, round_num: int, actions_count: int): f.write(json.dumps(entry, ensure_ascii=False) + '\n') def log_simulation_start(self, config: Dict[str, Any]): - """记录模拟开始""" + """Log Simulation Start.""" entry = { "timestamp": datetime.now().isoformat(), "event_type": "simulation_start", @@ -103,7 +85,7 @@ def log_simulation_start(self, config: Dict[str, Any]): f.write(json.dumps(entry, ensure_ascii=False) + '\n') def log_simulation_end(self, total_rounds: int, total_actions: int): - """记录模拟结束""" + """Log Simulation End.""" entry = { "timestamp": datetime.now().isoformat(), "event_type": "simulation_end", @@ -117,36 +99,28 @@ def log_simulation_end(self, total_rounds: int, total_actions: int): class SimulationLogManager: - """ - 模拟日志管理器 - 统一管理所有日志文件,按平台分离 - """ + """Simulation Log Manager.""" def __init__(self, simulation_dir: str): - """ - 初始化日志管理器 - - Args: - simulation_dir: 模拟目录路径 - """ + """Initialize the instance.""" self.simulation_dir = simulation_dir self.twitter_logger: Optional[PlatformActionLogger] = None self.reddit_logger: Optional[PlatformActionLogger] = None self._main_logger: Optional[logging.Logger] = None - # 设置主日志 + self._setup_main_logger() def _setup_main_logger(self): - """设置主模拟日志""" + """Setup Main Logger.""" log_path = os.path.join(self.simulation_dir, "simulation.log") - # 创建 logger + self._main_logger = logging.getLogger(f"simulation.{os.path.basename(self.simulation_dir)}") self._main_logger.setLevel(logging.INFO) self._main_logger.handlers.clear() - # 文件处理器 + file_handler = logging.FileHandler(log_path, encoding='utf-8', mode='w') file_handler.setLevel(logging.INFO) file_handler.setFormatter(logging.Formatter( @@ -155,7 +129,7 @@ def _setup_main_logger(self): )) self._main_logger.addHandler(file_handler) - # 控制台处理器 + console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) console_handler.setFormatter(logging.Formatter( @@ -167,19 +141,19 @@ def _setup_main_logger(self): self._main_logger.propagate = False def get_twitter_logger(self) -> PlatformActionLogger: - """获取 Twitter 平台日志记录器""" + """Get twitter logger.""" if self.twitter_logger is None: self.twitter_logger = PlatformActionLogger("twitter", self.simulation_dir) return self.twitter_logger def get_reddit_logger(self) -> PlatformActionLogger: - """获取 Reddit 平台日志记录器""" + """Get reddit logger.""" if self.reddit_logger is None: self.reddit_logger = PlatformActionLogger("reddit", self.simulation_dir) return self.reddit_logger def log(self, message: str, level: str = "info"): - """记录主日志""" + """Log.""" if self._main_logger: getattr(self._main_logger, level.lower(), self._main_logger.info)(message) @@ -196,13 +170,10 @@ def debug(self, message: str): self.log(message, "debug") -# ============ 兼容旧接口 ============ + class ActionLogger: - """ - 动作日志记录器(兼容旧接口) - 建议使用 SimulationLogManager 代替 - """ + """Action Logger.""" def __init__(self, log_path: str): self.log_path = log_path @@ -288,12 +259,12 @@ def log_simulation_end(self, platform: str, total_rounds: int, total_actions: in f.write(json.dumps(entry, ensure_ascii=False) + '\n') -# 全局日志实例(兼容旧接口) + _global_logger: Optional[ActionLogger] = None def get_logger(log_path: Optional[str] = None) -> ActionLogger: - """获取全局日志实例(兼容旧接口)""" + """Get logger.""" global _global_logger if log_path: diff --git a/backend/scripts/run_parallel_simulation.py b/backend/scripts/run_parallel_simulation.py index 2a627ffd0..431f716af 100644 --- a/backend/scripts/run_parallel_simulation.py +++ b/backend/scripts/run_parallel_simulation.py @@ -1,62 +1,34 @@ -""" -OASIS 双平台并行模拟预设脚本 -同时运行Twitter和Reddit模拟,读取相同的配置文件 - -功能特性: -- 双平台(Twitter + Reddit)并行模拟 -- 完成模拟后不立即关闭环境,进入等待命令模式 -- 支持通过IPC接收Interview命令 -- 支持单个Agent采访和批量采访 -- 支持远程关闭环境命令 - -使用方式: - python run_parallel_simulation.py --config simulation_config.json - python run_parallel_simulation.py --config simulation_config.json --no-wait # 完成后立即关闭 - python run_parallel_simulation.py --config simulation_config.json --twitter-only - python run_parallel_simulation.py --config simulation_config.json --reddit-only - -日志结构: - sim_xxx/ - ├── twitter/ - │ └── actions.jsonl # Twitter 平台动作日志 - ├── reddit/ - │ └── actions.jsonl # Reddit 平台动作日志 - ├── simulation.log # 主模拟进程日志 - └── run_state.json # 运行状态(API 查询用) -""" +"""OASIS dual-platform parallel simulation runner.""" # ============================================================ -# 解决 Windows 编码问题:在所有 import 之前设置 UTF-8 编码 -# 这是为了修复 OASIS 第三方库读取文件时未指定编码的问题 + + # ============================================================ import sys import os if sys.platform == 'win32': - # 设置 Python 默认 I/O 编码为 UTF-8 - # 这会影响所有未指定编码的 open() 调用 + + os.environ.setdefault('PYTHONUTF8', '1') os.environ.setdefault('PYTHONIOENCODING', 'utf-8') - # 重新配置标准输出流为 UTF-8(解决控制台中文乱码) + if hasattr(sys.stdout, 'reconfigure'): sys.stdout.reconfigure(encoding='utf-8', errors='replace') if hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(encoding='utf-8', errors='replace') - # 强制设置默认编码(影响 open() 函数的默认编码) - # 注意:这需要在 Python 启动时就设置,运行时设置可能不生效 - # 所以我们还需要 monkey-patch 内置的 open 函数 + + + import builtins _original_open = builtins.open def _utf8_open(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None): - """ - 包装 open() 函数,对于文本模式默认使用 UTF-8 编码 - 这可以修复第三方库(如 OASIS)读取文件时未指定编码的问题 - """ - # 只对文本模式(非二进制)且未指定编码的情况设置默认编码 + """Utf8 Open.""" + if encoding is None and 'b' not in mode: encoding = 'utf-8' return _original_open(file, mode, buffering, encoding, errors, @@ -77,52 +49,49 @@ def _utf8_open(file, mode='r', buffering=-1, encoding=None, errors=None, from typing import Dict, Any, List, Optional, Tuple -# 全局变量:用于信号处理 + _shutdown_event = None _cleanup_done = False -# 添加 backend 目录到路径 -# 脚本固定位于 backend/scripts/ 目录 + + _scripts_dir = os.path.dirname(os.path.abspath(__file__)) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..')) sys.path.insert(0, _scripts_dir) sys.path.insert(0, _backend_dir) -# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) + from dotenv import load_dotenv _env_file = os.path.join(_project_root, '.env') if os.path.exists(_env_file): load_dotenv(_env_file) - print(f"已加载环境配置: {_env_file}") + print(f"Loaded environment config: {_env_file}") else: - # 尝试加载 backend/.env + _backend_env = os.path.join(_backend_dir, '.env') if os.path.exists(_backend_env): load_dotenv(_backend_env) - print(f"已加载环境配置: {_backend_env}") + print(f"Loaded environment config: {_backend_env}") class MaxTokensWarningFilter(logging.Filter): - """过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)""" + """Max Tokens Warning Filter.""" def filter(self, record): - # 过滤掉包含 max_tokens 警告的日志 + if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): return False return True -# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效 + logging.getLogger().addFilter(MaxTokensWarningFilter()) def disable_oasis_logging(): - """ - 禁用 OASIS 库的详细日志输出 - OASIS 的日志太冗余(记录每个 agent 的观察和动作),我们使用自己的 action_logger - """ - # 禁用 OASIS 的所有日志器 + """Disable oasis logging.""" + oasis_loggers = [ "social.agent", "social.twitter", @@ -133,22 +102,17 @@ def disable_oasis_logging(): for logger_name in oasis_loggers: logger = logging.getLogger(logger_name) - logger.setLevel(logging.CRITICAL) # 只记录严重错误 + logger.setLevel(logging.CRITICAL) logger.handlers.clear() logger.propagate = False def init_logging_for_simulation(simulation_dir: str): - """ - 初始化模拟的日志配置 + """Init Logging For Simulation.""" - Args: - simulation_dir: 模拟目录路径 - """ - # 禁用 OASIS 的详细日志 disable_oasis_logging() - # 清理旧的 log 目录(如果存在) + old_log_dir = os.path.join(simulation_dir, "log") if os.path.exists(old_log_dir): import shutil @@ -169,12 +133,12 @@ def init_logging_for_simulation(simulation_dir: str): generate_reddit_agent_graph ) except ImportError as e: - print(f"错误: 缺少依赖 {e}") - print("请先安装: pip install oasis-ai camel-ai") + print(f"Error: missing dependency {e}") + print("Install the dependencies first: pip install oasis-ai camel-ai") sys.exit(1) -# Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) + TWITTER_ACTIONS = [ ActionType.CREATE_POST, ActionType.LIKE_POST, @@ -184,7 +148,7 @@ def init_logging_for_simulation(simulation_dir: str): ActionType.QUOTE_POST, ] -# Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) + REDDIT_ACTIONS = [ ActionType.LIKE_POST, ActionType.DISLIKE_POST, @@ -202,24 +166,20 @@ def init_logging_for_simulation(simulation_dir: str): ] -# IPC相关常量 + IPC_COMMANDS_DIR = "ipc_commands" IPC_RESPONSES_DIR = "ipc_responses" ENV_STATUS_FILE = "env_status.json" class CommandType: - """命令类型常量""" + """Command Type.""" INTERVIEW = "interview" BATCH_INTERVIEW = "batch_interview" CLOSE_ENV = "close_env" class ParallelIPCHandler: - """ - 双平台IPC命令处理器 - - 管理两个平台的环境,处理Interview命令 - """ + """Parallel IPC Handler.""" def __init__( self, @@ -239,12 +199,12 @@ def __init__( self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) - # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) def update_status(self, status: str): - """更新环境状态""" + """Update status.""" with open(self.status_file, 'w', encoding='utf-8') as f: json.dump({ "status": status, @@ -254,11 +214,11 @@ def update_status(self, status: str): }, f, ensure_ascii=False, indent=2) def poll_command(self) -> Optional[Dict[str, Any]]: - """轮询获取待处理命令""" + """Poll command.""" if not os.path.exists(self.commands_dir): return None - # 获取命令文件(按时间排序) + command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): @@ -277,7 +237,7 @@ def poll_command(self) -> Optional[Dict[str, Any]]: return None def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): - """发送响应""" + """Send response.""" response = { "command_id": command_id, "status": status, @@ -290,7 +250,7 @@ def send_response(self, command_id: str, status: str, result: Dict = None, error with open(response_file, 'w', encoding='utf-8') as f: json.dump(response, f, ensure_ascii=False, indent=2) - # 删除命令文件 + command_file = os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) @@ -298,15 +258,7 @@ def send_response(self, command_id: str, status: str, result: Dict = None, error pass def _get_env_and_graph(self, platform: str): - """ - 获取指定平台的环境和agent_graph - - Args: - platform: 平台名称 ("twitter" 或 "reddit") - - Returns: - (env, agent_graph, platform_name) 或 (None, None, None) - """ + """Get env and graph.""" if platform == "twitter" and self.twitter_env: return self.twitter_env, self.twitter_agent_graph, "twitter" elif platform == "reddit" and self.reddit_env: @@ -315,16 +267,11 @@ def _get_env_and_graph(self, platform: str): return None, None, None async def _interview_single_platform(self, agent_id: int, prompt: str, platform: str) -> Dict[str, Any]: - """ - 在单个平台上执行Interview - - Returns: - 包含结果的字典,或包含error的字典 - """ + """Interview Single Platform.""" env, agent_graph, actual_platform = self._get_env_and_graph(platform) if not env or not agent_graph: - return {"platform": platform, "error": f"{platform}平台不可用"} + return {"platform": platform, "error": f"{platform} platform is unavailable"} try: agent = agent_graph.get_agent(agent_id) @@ -343,37 +290,23 @@ async def _interview_single_platform(self, agent_id: int, prompt: str, platform: return {"platform": platform, "error": str(e)} async def handle_interview(self, command_id: str, agent_id: int, prompt: str, platform: str = None) -> bool: - """ - 处理单个Agent采访命令 - - Args: - command_id: 命令ID - agent_id: Agent ID - prompt: 采访问题 - platform: 指定平台(可选) - - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 - - None/不指定: 同时采访两个平台,返回整合结果 - - Returns: - True 表示成功,False 表示失败 - """ - # 如果指定了平台,只采访该平台 + """Handle interview.""" + if platform in ("twitter", "reddit"): result = await self._interview_single_platform(agent_id, prompt, platform) if "error" in result: self.send_response(command_id, "failed", error=result["error"]) - print(f" Interview失败: agent_id={agent_id}, platform={platform}, error={result['error']}") + print(f" Interview failed: agent_id={agent_id}, platform={platform}, error={result['error']}") return False else: self.send_response(command_id, "completed", result=result) - print(f" Interview完成: agent_id={agent_id}, platform={platform}") + print(f" Interview completed: agent_id={agent_id}, platform={platform}") return True - # 未指定平台:同时采访两个平台 + if not self.twitter_env and not self.reddit_env: - self.send_response(command_id, "failed", error="没有可用的模拟环境") + self.send_response(command_id, "failed", error="No simulation environments are available") return False results = { @@ -383,7 +316,7 @@ async def handle_interview(self, command_id: str, agent_id: int, prompt: str, pl } success_count = 0 - # 并行采访两个平台 + tasks = [] platforms_to_interview = [] @@ -395,7 +328,7 @@ async def handle_interview(self, command_id: str, agent_id: int, prompt: str, pl tasks.append(self._interview_single_platform(agent_id, prompt, "reddit")) platforms_to_interview.append("reddit") - # 并行执行 + platform_results = await asyncio.gather(*tasks) for platform_name, platform_result in zip(platforms_to_interview, platform_results): @@ -405,30 +338,23 @@ async def handle_interview(self, command_id: str, agent_id: int, prompt: str, pl if success_count > 0: self.send_response(command_id, "completed", result=results) - print(f" Interview完成: agent_id={agent_id}, 成功平台数={success_count}/{len(platforms_to_interview)}") + print( + f" Interview completed: agent_id={agent_id}, " + f"successful_platforms={success_count}/{len(platforms_to_interview)}" + ) return True else: - errors = [f"{p}: {r.get('error', '未知错误')}" for p, r in results["platforms"].items()] + errors = [f"{p}: {r.get('error', 'Unknown error')}" for p, r in results["platforms"].items()] self.send_response(command_id, "failed", error="; ".join(errors)) - print(f" Interview失败: agent_id={agent_id}, 所有平台都失败") + print(f" Interview failed: agent_id={agent_id}, all platforms failed") return False async def handle_batch_interview(self, command_id: str, interviews: List[Dict], platform: str = None) -> bool: - """ - 处理批量采访命令 - - Args: - command_id: 命令ID - interviews: [{"agent_id": int, "prompt": str, "platform": str(optional)}, ...] - platform: 默认平台(可被每个interview项覆盖) - - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 - - None/不指定: 每个Agent同时采访两个平台 - """ - # 按平台分组 + """Handle batch interview.""" + twitter_interviews = [] reddit_interviews = [] - both_platforms_interviews = [] # 需要同时采访两个平台的 + both_platforms_interviews = [] for interview in interviews: item_platform = interview.get("platform", platform) @@ -437,10 +363,10 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict], elif item_platform == "reddit": reddit_interviews.append(interview) else: - # 未指定平台:两个平台都采访 + both_platforms_interviews.append(interview) - # 把 both_platforms_interviews 拆分到两个平台 + if both_platforms_interviews: if self.twitter_env: twitter_interviews.extend(both_platforms_interviews) @@ -449,7 +375,7 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict], results = {} - # 处理Twitter平台的采访 + if twitter_interviews and self.twitter_env: try: twitter_actions = {} @@ -463,7 +389,7 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict], action_args={"prompt": prompt} ) except Exception as e: - print(f" 警告: 无法获取Twitter Agent {agent_id}: {e}") + print(f" Warning: unable to get Twitter agent {agent_id}: {e}") if twitter_actions: await self.twitter_env.step(twitter_actions) @@ -474,9 +400,9 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict], result["platform"] = "twitter" results[f"twitter_{agent_id}"] = result except Exception as e: - print(f" Twitter批量Interview失败: {e}") + print(f" Twitter batch interview failed: {e}") + - # 处理Reddit平台的采访 if reddit_interviews and self.reddit_env: try: reddit_actions = {} @@ -490,7 +416,7 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict], action_args={"prompt": prompt} ) except Exception as e: - print(f" 警告: 无法获取Reddit Agent {agent_id}: {e}") + print(f" Warning: unable to get Reddit agent {agent_id}: {e}") if reddit_actions: await self.reddit_env.step(reddit_actions) @@ -501,21 +427,21 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict], result["platform"] = "reddit" results[f"reddit_{agent_id}"] = result except Exception as e: - print(f" Reddit批量Interview失败: {e}") + print(f" Reddit batch interview failed: {e}") if results: self.send_response(command_id, "completed", result={ "interviews_count": len(results), "results": results }) - print(f" 批量Interview完成: {len(results)} 个Agent") + print(f" Batch interview completed: {len(results)} agents") return True else: - self.send_response(command_id, "failed", error="没有成功的采访") + self.send_response(command_id, "failed", error="No interviews succeeded") return False def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]: - """从数据库获取最新的Interview结果""" + """Get interview result.""" db_path = os.path.join(self.simulation_dir, f"{platform}_simulation.db") result = { @@ -531,7 +457,7 @@ def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]: conn = sqlite3.connect(db_path) cursor = conn.cursor() - # 查询最新的Interview记录 + cursor.execute(""" SELECT user_id, info, created_at FROM trace @@ -553,17 +479,12 @@ def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]: conn.close() except Exception as e: - print(f" 读取Interview结果失败: {e}") + print(f" Failed to read interview result: {e}") return result async def process_commands(self) -> bool: - """ - 处理所有待处理命令 - - Returns: - True 表示继续运行,False 表示应该退出 - """ + """Process commands.""" command = self.poll_command() if not command: return True @@ -572,7 +493,7 @@ async def process_commands(self) -> bool: command_type = command.get("command_type") args = command.get("args", {}) - print(f"\n收到IPC命令: {command_type}, id={command_id}") + print(f"\nReceived IPC command: {command_type}, id={command_id}") if command_type == CommandType.INTERVIEW: await self.handle_interview( @@ -592,25 +513,25 @@ async def process_commands(self) -> bool: return True elif command_type == CommandType.CLOSE_ENV: - print("收到关闭环境命令") - self.send_response(command_id, "completed", result={"message": "环境即将关闭"}) + print("Received close-environment command") + self.send_response(command_id, "completed", result={"message": "Environment will close shortly"}) return False else: - self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}") + self.send_response(command_id, "failed", error=f"Unknown command type: {command_type}") return True def load_config(config_path: str) -> Dict[str, Any]: - """加载配置文件""" + """Load config.""" with open(config_path, 'r', encoding='utf-8') as f: return json.load(f) -# 需要过滤掉的非核心动作类型(这些动作对分析价值较低) + FILTERED_ACTIONS = {'refresh', 'sign_up'} -# 动作类型映射表(数据库中的名称 -> 标准名称) + ACTION_TYPE_MAP = { 'create_post': 'CREATE_POST', 'like_post': 'LIKE_POST', @@ -631,17 +552,7 @@ def load_config(config_path: str) -> Dict[str, Any]: def get_agent_names_from_config(config: Dict[str, Any]) -> Dict[int, str]: - """ - 从 simulation_config 中获取 agent_id -> entity_name 的映射 - - 这样可以在 actions.jsonl 中显示真实的实体名称,而不是 "Agent_0" 这样的代号 - - Args: - config: simulation_config.json 的内容 - - Returns: - agent_id -> entity_name 的映射字典 - """ + """Get agent names from config.""" agent_names = {} agent_configs = config.get("agent_configs", []) @@ -659,19 +570,7 @@ def fetch_new_actions_from_db( last_rowid: int, agent_names: Dict[int, str] ) -> Tuple[List[Dict[str, Any]], int]: - """ - 从数据库中获取新的动作记录,并补充完整的上下文信息 - - Args: - db_path: 数据库文件路径 - last_rowid: 上次读取的最大 rowid 值(使用 rowid 而不是 created_at,因为不同平台的 created_at 格式不同) - agent_names: agent_id -> agent_name 映射 - - Returns: - (actions_list, new_last_rowid) - - actions_list: 动作列表,每个元素包含 agent_id, agent_name, action_type, action_args(含上下文信息) - - new_last_rowid: 新的最大 rowid 值 - """ + """Fetch new actions from db.""" actions = [] new_last_rowid = last_rowid @@ -682,8 +581,8 @@ def fetch_new_actions_from_db( conn = sqlite3.connect(db_path) cursor = conn.cursor() - # 使用 rowid 来追踪已处理的记录(rowid 是 SQLite 的内置自增字段) - # 这样可以避免 created_at 格式差异问题(Twitter 用整数,Reddit 用日期时间字符串) + + cursor.execute(""" SELECT rowid, user_id, action, info FROM trace @@ -692,20 +591,20 @@ def fetch_new_actions_from_db( """, (last_rowid,)) for rowid, user_id, action, info_json in cursor.fetchall(): - # 更新最大 rowid + new_last_rowid = rowid - # 过滤非核心动作 + if action in FILTERED_ACTIONS: continue - # 解析动作参数 + try: action_args = json.loads(info_json) if info_json else {} except json.JSONDecodeError: action_args = {} - # 精简 action_args,只保留关键字段(保留完整内容,不截断) + simplified_args = {} if 'content' in action_args: simplified_args['content'] = action_args['content'] @@ -726,10 +625,10 @@ def fetch_new_actions_from_db( if 'dislike_id' in action_args: simplified_args['dislike_id'] = action_args['dislike_id'] - # 转换动作类型名称 + action_type = ACTION_TYPE_MAP.get(action, action.upper()) - # 补充上下文信息(帖子内容、用户名等) + _enrich_action_context(cursor, action_type, simplified_args, agent_names) actions.append({ @@ -741,7 +640,7 @@ def fetch_new_actions_from_db( conn.close() except Exception as e: - print(f"读取数据库动作失败: {e}") + print(f"Failed to read database actions: {e}") return actions, new_last_rowid @@ -752,17 +651,9 @@ def _enrich_action_context( action_args: Dict[str, Any], agent_names: Dict[int, str] ) -> None: - """ - 为动作补充上下文信息(帖子内容、用户名等) - - Args: - cursor: 数据库游标 - action_type: 动作类型 - action_args: 动作参数(会被修改) - agent_names: agent_id -> agent_name 映射 - """ + """Enrich Action Context.""" try: - # 点赞/踩帖子:补充帖子内容和作者 + if action_type in ('LIKE_POST', 'DISLIKE_POST'): post_id = action_args.get('post_id') if post_id: @@ -771,11 +662,11 @@ def _enrich_action_context( action_args['post_content'] = post_info.get('content', '') action_args['post_author_name'] = post_info.get('author_name', '') - # 转发帖子:补充原帖内容和作者 + elif action_type == 'REPOST': new_post_id = action_args.get('new_post_id') if new_post_id: - # 转发帖子的 original_post_id 指向原帖 + cursor.execute(""" SELECT original_post_id FROM post WHERE post_id = ? """, (new_post_id,)) @@ -787,7 +678,7 @@ def _enrich_action_context( action_args['original_content'] = original_info.get('content', '') action_args['original_author_name'] = original_info.get('author_name', '') - # 引用帖子:补充原帖内容、作者和引用评论 + elif action_type == 'QUOTE_POST': quoted_id = action_args.get('quoted_id') new_post_id = action_args.get('new_post_id') @@ -798,7 +689,7 @@ def _enrich_action_context( action_args['original_content'] = original_info.get('content', '') action_args['original_author_name'] = original_info.get('author_name', '') - # 获取引用帖子的评论内容(quote_content) + if new_post_id: cursor.execute(""" SELECT quote_content FROM post WHERE post_id = ? @@ -807,11 +698,11 @@ def _enrich_action_context( if row and row[0]: action_args['quote_content'] = row[0] - # 关注用户:补充被关注用户的名称 + elif action_type == 'FOLLOW': follow_id = action_args.get('follow_id') if follow_id: - # 从 follow 表获取 followee_id + cursor.execute(""" SELECT followee_id FROM follow WHERE follow_id = ? """, (follow_id,)) @@ -822,16 +713,16 @@ def _enrich_action_context( if target_name: action_args['target_user_name'] = target_name - # 屏蔽用户:补充被屏蔽用户的名称 + elif action_type == 'MUTE': - # 从 action_args 中获取 user_id 或 target_id + target_id = action_args.get('user_id') or action_args.get('target_id') if target_id: target_name = _get_user_name(cursor, target_id, agent_names) if target_name: action_args['target_user_name'] = target_name - # 点赞/踩评论:补充评论内容和作者 + elif action_type in ('LIKE_COMMENT', 'DISLIKE_COMMENT'): comment_id = action_args.get('comment_id') if comment_id: @@ -840,7 +731,7 @@ def _enrich_action_context( action_args['comment_content'] = comment_info.get('content', '') action_args['comment_author_name'] = comment_info.get('author_name', '') - # 发表评论:补充所评论的帖子信息 + elif action_type == 'CREATE_COMMENT': post_id = action_args.get('post_id') if post_id: @@ -850,8 +741,8 @@ def _enrich_action_context( action_args['post_author_name'] = post_info.get('author_name', '') except Exception as e: - # 补充上下文失败不影响主流程 - print(f"补充动作上下文失败: {e}") + + print(f"Failed to enrich action context: {e}") def _get_post_info( @@ -859,17 +750,7 @@ def _get_post_info( post_id: int, agent_names: Dict[int, str] ) -> Optional[Dict[str, str]]: - """ - 获取帖子信息 - - Args: - cursor: 数据库游标 - post_id: 帖子ID - agent_names: agent_id -> agent_name 映射 - - Returns: - 包含 content 和 author_name 的字典,或 None - """ + """Get post info.""" try: cursor.execute(""" SELECT p.content, p.user_id, u.agent_id @@ -883,12 +764,12 @@ def _get_post_info( user_id = row[1] agent_id = row[2] - # 优先使用 agent_names 中的名称 + author_name = '' if agent_id is not None and agent_id in agent_names: author_name = agent_names[agent_id] elif user_id: - # 从 user 表获取名称 + cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,)) user_row = cursor.fetchone() if user_row: @@ -905,17 +786,7 @@ def _get_user_name( user_id: int, agent_names: Dict[int, str] ) -> Optional[str]: - """ - 获取用户名称 - - Args: - cursor: 数据库游标 - user_id: 用户ID - agent_names: agent_id -> agent_name 映射 - - Returns: - 用户名称,或 None - """ + """Get user name.""" try: cursor.execute(""" SELECT agent_id, name, user_name FROM user WHERE user_id = ? @@ -926,7 +797,7 @@ def _get_user_name( name = row[1] user_name = row[2] - # 优先使用 agent_names 中的名称 + if agent_id is not None and agent_id in agent_names: return agent_names[agent_id] return name or user_name or '' @@ -940,17 +811,7 @@ def _get_comment_info( comment_id: int, agent_names: Dict[int, str] ) -> Optional[Dict[str, str]]: - """ - 获取评论信息 - - Args: - cursor: 数据库游标 - comment_id: 评论ID - agent_names: agent_id -> agent_name 映射 - - Returns: - 包含 content 和 author_name 的字典,或 None - """ + """Get comment info.""" try: cursor.execute(""" SELECT c.content, c.user_id, u.agent_id @@ -964,12 +825,12 @@ def _get_comment_info( user_id = row[1] agent_id = row[2] - # 优先使用 agent_names 中的名称 + author_name = '' if agent_id is not None and agent_id in agent_names: author_name = agent_names[agent_id] elif user_id: - # 从 user 表获取名称 + cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,)) user_row = cursor.fetchone() if user_row: @@ -982,54 +843,42 @@ def _get_comment_info( def create_model(config: Dict[str, Any], use_boost: bool = False): - """ - 创建LLM模型 - - 支持双 LLM 配置,用于并行模拟时提速: - - 通用配置:LLM_API_KEY, LLM_BASE_URL, LLM_MODEL_NAME - - 加速配置(可选):LLM_BOOST_API_KEY, LLM_BOOST_BASE_URL, LLM_BOOST_MODEL_NAME - - 如果配置了加速 LLM,并行模拟时可以让不同平台使用不同的 API 服务商,提高并发能力。 + """Create model.""" - Args: - config: 模拟配置字典 - use_boost: 是否使用加速 LLM 配置(如果可用) - """ - # 检查是否有加速配置 boost_api_key = os.environ.get("LLM_BOOST_API_KEY", "") boost_base_url = os.environ.get("LLM_BOOST_BASE_URL", "") boost_model = os.environ.get("LLM_BOOST_MODEL_NAME", "") has_boost_config = bool(boost_api_key) - # 根据参数和配置情况选择使用哪个 LLM + if use_boost and has_boost_config: - # 使用加速配置 + llm_api_key = boost_api_key llm_base_url = boost_base_url llm_model = boost_model or os.environ.get("LLM_MODEL_NAME", "") - config_label = "[加速LLM]" + config_label = "[Boost LLM]" else: - # 使用通用配置 + llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") - config_label = "[通用LLM]" + config_label = "[Standard LLM]" + - # 如果 .env 中没有模型名,则使用 config 作为备用 if not llm_model: llm_model = config.get("llm_model", "gpt-4o-mini") - # 设置 camel-ai 所需的环境变量 + if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key if not os.environ.get("OPENAI_API_KEY"): - raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") + raise ValueError("Missing API key configuration. Set LLM_API_KEY in the project root .env file") if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url - print(f"{config_label} model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") + print(f"{config_label} model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else 'default'}...") return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, @@ -1043,7 +892,7 @@ def get_active_agents_for_round( current_hour: int, round_num: int ) -> List: - """根据时间和配置决定本轮激活哪些Agent""" + """Get active agents for round.""" time_config = config.get("time_config", {}) agent_configs = config.get("agent_configs", []) @@ -1091,7 +940,7 @@ def get_active_agents_for_round( class PlatformSimulation: - """平台模拟结果容器""" + """Platform Simulation.""" def __init__(self): self.env = None self.agent_graph = None @@ -1105,18 +954,7 @@ async def run_twitter_simulation( main_logger: Optional[SimulationLogManager] = None, max_rounds: Optional[int] = None ) -> PlatformSimulation: - """运行Twitter模拟 - - Args: - config: 模拟配置 - simulation_dir: 模拟目录 - action_logger: 动作日志记录器 - main_logger: 主日志管理器 - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) - - Returns: - PlatformSimulation: 包含env和agent_graph的结果对象 - """ + """Run twitter simulation.""" result = PlatformSimulation() def log_info(msg): @@ -1124,15 +962,15 @@ def log_info(msg): main_logger.info(f"[Twitter] {msg}") print(f"[Twitter] {msg}") - log_info("初始化...") + log_info("Initializing...") + - # Twitter 使用通用 LLM 配置 model = create_model(config, use_boost=False) - # OASIS Twitter使用CSV格式 + profile_path = os.path.join(simulation_dir, "twitter_profiles.csv") if not os.path.exists(profile_path): - log_info(f"错误: Profile文件不存在: {profile_path}") + log_info(f"Error: profile file does not exist: {profile_path}") return result result.agent_graph = await generate_twitter_agent_graph( @@ -1141,9 +979,9 @@ def log_info(msg): available_actions=TWITTER_ACTIONS, ) - # 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X) + agent_names = get_agent_names_from_config(config) - # 如果配置中没有某个 agent,则使用 OASIS 的默认名称 + for agent_id, agent in result.agent_graph.get_agents(): if agent_id not in agent_names: agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}') @@ -1156,23 +994,23 @@ def log_info(msg): agent_graph=result.agent_graph, platform=oasis.DefaultPlatformType.TWITTER, database_path=db_path, - semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 + semaphore=30, ) await result.env.reset() - log_info("环境已启动") + log_info("Environment started") if action_logger: action_logger.log_simulation_start(config) total_actions = 0 - last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异) + last_rowid = 0 + - # 执行初始事件 event_config = config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) - # 记录 round 0 开始(初始事件阶段) + if action_logger: action_logger.log_round_start(0, 0) # round 0, simulated_hour 0 @@ -1204,32 +1042,32 @@ def log_info(msg): if initial_actions: await result.env.step(initial_actions) - log_info(f"已发布 {len(initial_actions)} 条初始帖子") + log_info(f"Published {len(initial_actions)} initial posts") + - # 记录 round 0 结束 if action_logger: action_logger.log_round_end(0, initial_action_count) - # 主模拟循环 + time_config = config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round - # 如果指定了最大轮数,则截断 + if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: - log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") + log_info(f"Round count truncated: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") start_time = datetime.now() for round_num in range(total_rounds): - # 检查是否收到退出信号 + if _shutdown_event and _shutdown_event.is_set(): if main_logger: - main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟") + main_logger.info(f"Received shutdown signal; stopping simulation at round {round_num + 1}") break simulated_minutes = round_num * minutes_per_round @@ -1240,12 +1078,12 @@ def log_info(msg): result.env, config, simulated_hour, round_num ) - # 无论是否有活跃agent,都记录round开始 + if action_logger: action_logger.log_round_start(round_num + 1, simulated_hour) if not active_agents: - # 没有活跃agent时也记录round结束(actions_count=0) + if action_logger: action_logger.log_round_end(round_num + 1, 0) continue @@ -1253,7 +1091,7 @@ def log_info(msg): actions = {agent: LLMAction() for _, agent in active_agents} await result.env.step(actions) - # 从数据库获取实际执行的动作并记录 + actual_actions, last_rowid = fetch_new_actions_from_db( db_path, last_rowid, agent_names ) @@ -1278,14 +1116,14 @@ def log_info(msg): progress = (round_num + 1) / total_rounds * 100 log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)") - # 注意:不关闭环境,保留给Interview使用 + if action_logger: action_logger.log_simulation_end(total_rounds, total_actions) result.total_actions = total_actions elapsed = (datetime.now() - start_time).total_seconds() - log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}") + log_info(f"Simulation loop completed! Elapsed: {elapsed:.1f}s, total actions: {total_actions}") return result @@ -1297,18 +1135,7 @@ async def run_reddit_simulation( main_logger: Optional[SimulationLogManager] = None, max_rounds: Optional[int] = None ) -> PlatformSimulation: - """运行Reddit模拟 - - Args: - config: 模拟配置 - simulation_dir: 模拟目录 - action_logger: 动作日志记录器 - main_logger: 主日志管理器 - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) - - Returns: - PlatformSimulation: 包含env和agent_graph的结果对象 - """ + """Run reddit simulation.""" result = PlatformSimulation() def log_info(msg): @@ -1316,14 +1143,14 @@ def log_info(msg): main_logger.info(f"[Reddit] {msg}") print(f"[Reddit] {msg}") - log_info("初始化...") + log_info("Initializing...") + - # Reddit 使用加速 LLM 配置(如果有的话,否则回退到通用配置) model = create_model(config, use_boost=True) profile_path = os.path.join(simulation_dir, "reddit_profiles.json") if not os.path.exists(profile_path): - log_info(f"错误: Profile文件不存在: {profile_path}") + log_info(f"Error: profile file does not exist: {profile_path}") return result result.agent_graph = await generate_reddit_agent_graph( @@ -1332,9 +1159,9 @@ def log_info(msg): available_actions=REDDIT_ACTIONS, ) - # 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X) + agent_names = get_agent_names_from_config(config) - # 如果配置中没有某个 agent,则使用 OASIS 的默认名称 + for agent_id, agent in result.agent_graph.get_agents(): if agent_id not in agent_names: agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}') @@ -1347,23 +1174,23 @@ def log_info(msg): agent_graph=result.agent_graph, platform=oasis.DefaultPlatformType.REDDIT, database_path=db_path, - semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 + semaphore=30, ) await result.env.reset() - log_info("环境已启动") + log_info("Environment started") if action_logger: action_logger.log_simulation_start(config) total_actions = 0 - last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异) + last_rowid = 0 + - # 执行初始事件 event_config = config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) - # 记录 round 0 开始(初始事件阶段) + if action_logger: action_logger.log_round_start(0, 0) # round 0, simulated_hour 0 @@ -1403,32 +1230,32 @@ def log_info(msg): if initial_actions: await result.env.step(initial_actions) - log_info(f"已发布 {len(initial_actions)} 条初始帖子") + log_info(f"Published {len(initial_actions)} initial posts") + - # 记录 round 0 结束 if action_logger: action_logger.log_round_end(0, initial_action_count) - # 主模拟循环 + time_config = config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round - # 如果指定了最大轮数,则截断 + if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: - log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") + log_info(f"Round count truncated: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") start_time = datetime.now() for round_num in range(total_rounds): - # 检查是否收到退出信号 + if _shutdown_event and _shutdown_event.is_set(): if main_logger: - main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟") + main_logger.info(f"Received shutdown signal; stopping simulation at round {round_num + 1}") break simulated_minutes = round_num * minutes_per_round @@ -1439,12 +1266,12 @@ def log_info(msg): result.env, config, simulated_hour, round_num ) - # 无论是否有活跃agent,都记录round开始 + if action_logger: action_logger.log_round_start(round_num + 1, simulated_hour) if not active_agents: - # 没有活跃agent时也记录round结束(actions_count=0) + if action_logger: action_logger.log_round_end(round_num + 1, 0) continue @@ -1452,7 +1279,7 @@ def log_info(msg): actions = {agent: LLMAction() for _, agent in active_agents} await result.env.step(actions) - # 从数据库获取实际执行的动作并记录 + actual_actions, last_rowid = fetch_new_actions_from_db( db_path, last_rowid, agent_names ) @@ -1477,76 +1304,76 @@ def log_info(msg): progress = (round_num + 1) / total_rounds * 100 log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)") - # 注意:不关闭环境,保留给Interview使用 + if action_logger: action_logger.log_simulation_end(total_rounds, total_actions) result.total_actions = total_actions elapsed = (datetime.now() - start_time).total_seconds() - log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}") + log_info(f"Simulation loop completed! Elapsed: {elapsed:.1f}s, total actions: {total_actions}") return result async def main(): - parser = argparse.ArgumentParser(description='OASIS双平台并行模拟') + parser = argparse.ArgumentParser(description='OASIS dual-platform parallel simulation') parser.add_argument( '--config', type=str, required=True, - help='配置文件路径 (simulation_config.json)' + help='Path to the config file (simulation_config.json)' ) parser.add_argument( '--twitter-only', action='store_true', - help='只运行Twitter模拟' + help='Run only the Twitter simulation' ) parser.add_argument( '--reddit-only', action='store_true', - help='只运行Reddit模拟' + help='Run only the Reddit simulation' ) parser.add_argument( '--max-rounds', type=int, default=None, - help='最大模拟轮数(可选,用于截断过长的模拟)' + help='Maximum number of simulation rounds (optional, used to truncate long simulations)' ) parser.add_argument( '--no-wait', action='store_true', default=False, - help='模拟完成后立即关闭环境,不进入等待命令模式' + help='Close the environment immediately after the simulation instead of entering command-wait mode' ) args = parser.parse_args() - # 在 main 函数开始时创建 shutdown 事件,确保整个程序都能响应退出信号 + global _shutdown_event _shutdown_event = asyncio.Event() if not os.path.exists(args.config): - print(f"错误: 配置文件不存在: {args.config}") + print(f"Error: config file does not exist: {args.config}") sys.exit(1) config = load_config(args.config) simulation_dir = os.path.dirname(args.config) or "." wait_for_commands = not args.no_wait - # 初始化日志配置(禁用 OASIS 日志,清理旧文件) + init_logging_for_simulation(simulation_dir) - # 创建日志管理器 + log_manager = SimulationLogManager(simulation_dir) twitter_logger = log_manager.get_twitter_logger() reddit_logger = log_manager.get_reddit_logger() log_manager.info("=" * 60) - log_manager.info("OASIS 双平台并行模拟") - log_manager.info(f"配置文件: {args.config}") - log_manager.info(f"模拟ID: {config.get('simulation_id', 'unknown')}") - log_manager.info(f"等待命令模式: {'启用' if wait_for_commands else '禁用'}") + log_manager.info("OASIS Dual-Platform Parallel Simulation") + log_manager.info(f"Config file: {args.config}") + log_manager.info(f"Simulation ID: {config.get('simulation_id', 'unknown')}") + log_manager.info(f"Command wait mode: {'enabled' if wait_for_commands else 'disabled'}") log_manager.info("=" * 60) time_config = config.get("time_config", {}) @@ -1554,25 +1381,25 @@ async def main(): minutes_per_round = time_config.get('minutes_per_round', 30) config_total_rounds = (total_hours * 60) // minutes_per_round - log_manager.info(f"模拟参数:") - log_manager.info(f" - 总模拟时长: {total_hours}小时") - log_manager.info(f" - 每轮时间: {minutes_per_round}分钟") - log_manager.info(f" - 配置总轮数: {config_total_rounds}") + log_manager.info("Simulation parameters:") + log_manager.info(f" - Total simulation duration: {total_hours} hours") + log_manager.info(f" - Minutes per round: {minutes_per_round}") + log_manager.info(f" - Configured total rounds: {config_total_rounds}") if args.max_rounds: - log_manager.info(f" - 最大轮数限制: {args.max_rounds}") + log_manager.info(f" - Max rounds limit: {args.max_rounds}") if args.max_rounds < config_total_rounds: - log_manager.info(f" - 实际执行轮数: {args.max_rounds} (已截断)") - log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}") + log_manager.info(f" - Actual executed rounds: {args.max_rounds} (truncated)") + log_manager.info(f" - Agent count: {len(config.get('agent_configs', []))}") - log_manager.info("日志结构:") - log_manager.info(f" - 主日志: simulation.log") - log_manager.info(f" - Twitter动作: twitter/actions.jsonl") - log_manager.info(f" - Reddit动作: reddit/actions.jsonl") + log_manager.info("Log structure:") + log_manager.info(" - Main log: simulation.log") + log_manager.info(" - Twitter actions: twitter/actions.jsonl") + log_manager.info(" - Reddit actions: reddit/actions.jsonl") log_manager.info("=" * 60) start_time = datetime.now() - # 存储两个平台的模拟结果 + twitter_result: Optional[PlatformSimulation] = None reddit_result: Optional[PlatformSimulation] = None @@ -1581,7 +1408,7 @@ async def main(): elif args.reddit_only: reddit_result = await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds) else: - # 并行运行(每个平台使用独立的日志记录器) + results = await asyncio.gather( run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds), run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds), @@ -1590,17 +1417,17 @@ async def main(): total_elapsed = (datetime.now() - start_time).total_seconds() log_manager.info("=" * 60) - log_manager.info(f"模拟循环完成! 总耗时: {total_elapsed:.1f}秒") + log_manager.info(f"Simulation loop completed! Total elapsed time: {total_elapsed:.1f}s") + - # 是否进入等待命令模式 if wait_for_commands: log_manager.info("") log_manager.info("=" * 60) - log_manager.info("进入等待命令模式 - 环境保持运行") - log_manager.info("支持的命令: interview, batch_interview, close_env") + log_manager.info("Entering command-wait mode - environment stays running") + log_manager.info("Supported commands: interview, batch_interview, close_env") log_manager.info("=" * 60) - # 创建IPC处理器 + ipc_handler = ParallelIPCHandler( simulation_dir=simulation_dir, twitter_env=twitter_result.env if twitter_result else None, @@ -1610,40 +1437,40 @@ async def main(): ) ipc_handler.update_status("alive") - # 等待命令循环(使用全局 _shutdown_event) + try: while not _shutdown_event.is_set(): should_continue = await ipc_handler.process_commands() if not should_continue: break - # 使用 wait_for 替代 sleep,这样可以响应 shutdown_event + try: await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) - break # 收到退出信号 + break except asyncio.TimeoutError: - pass # 超时继续循环 + pass except KeyboardInterrupt: - print("\n收到中断信号") + print("\nReceived interrupt signal") except asyncio.CancelledError: - print("\n任务被取消") + print("\nTask was cancelled") except Exception as e: - print(f"\n命令处理出错: {e}") + print(f"\nCommand processing error: {e}") - log_manager.info("\n关闭环境...") + log_manager.info("\nClosing environment...") ipc_handler.update_status("stopped") - # 关闭环境 + if twitter_result and twitter_result.env: await twitter_result.env.close() - log_manager.info("[Twitter] 环境已关闭") + log_manager.info("[Twitter] Environment closed") if reddit_result and reddit_result.env: await reddit_result.env.close() - log_manager.info("[Reddit] 环境已关闭") + log_manager.info("[Reddit] Environment closed") log_manager.info("=" * 60) - log_manager.info(f"全部完成!") - log_manager.info(f"日志文件:") + log_manager.info("All done!") + log_manager.info("Log files:") log_manager.info(f" - {os.path.join(simulation_dir, 'simulation.log')}") log_manager.info(f" - {os.path.join(simulation_dir, 'twitter', 'actions.jsonl')}") log_manager.info(f" - {os.path.join(simulation_dir, 'reddit', 'actions.jsonl')}") @@ -1651,30 +1478,22 @@ async def main(): def setup_signal_handlers(loop=None): - """ - 设置信号处理器,确保收到 SIGTERM/SIGINT 时能够正确退出 - - 持久化模拟场景:模拟完成后不退出,等待 interview 命令 - 当收到终止信号时,需要: - 1. 通知 asyncio 循环退出等待 - 2. 让程序有机会正常清理资源(关闭数据库、环境等) - 3. 然后才退出 - """ + """Setup Signal Handlers.""" def signal_handler(signum, frame): global _cleanup_done sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT" - print(f"\n收到 {sig_name} 信号,正在退出...") + print(f"\nReceived {sig_name}; exiting...") if not _cleanup_done: _cleanup_done = True - # 设置事件通知 asyncio 循环退出(让循环有机会清理资源) + if _shutdown_event: _shutdown_event.set() - # 不要直接 sys.exit(),让 asyncio 循环正常退出并清理资源 - # 如果是重复收到信号,才强制退出 + + else: - print("强制退出...") + print("Force exiting...") sys.exit(1) signal.signal(signal.SIGTERM, signal_handler) @@ -1686,14 +1505,14 @@ def signal_handler(signum, frame): try: asyncio.run(main()) except KeyboardInterrupt: - print("\n程序被中断") + print("\nProgram interrupted") except SystemExit: pass finally: - # 清理 multiprocessing 资源跟踪器(防止退出时的警告) + try: from multiprocessing import resource_tracker resource_tracker._resource_tracker._stop() except Exception: pass - print("模拟进程已退出") + print("Simulation process exited") diff --git a/backend/scripts/run_reddit_simulation.py b/backend/scripts/run_reddit_simulation.py index 14907cbda..52eec1f14 100644 --- a/backend/scripts/run_reddit_simulation.py +++ b/backend/scripts/run_reddit_simulation.py @@ -1,17 +1,4 @@ -""" -OASIS Reddit模拟预设脚本 -此脚本读取配置文件中的参数来执行模拟,实现全程自动化 - -功能特性: -- 完成模拟后不立即关闭环境,进入等待命令模式 -- 支持通过IPC接收Interview命令 -- 支持单个Agent采访和批量采访 -- 支持远程关闭环境命令 - -使用方式: - python run_reddit_simulation.py --config /path/to/simulation_config.json - python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭 -""" +"""OASIS Reddit simulation runner.""" import argparse import asyncio @@ -25,18 +12,18 @@ from datetime import datetime from typing import Dict, Any, List, Optional -# 全局变量:用于信号处理 + _shutdown_event = None _cleanup_done = False -# 添加项目路径 + _scripts_dir = os.path.dirname(os.path.abspath(__file__)) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..')) sys.path.insert(0, _scripts_dir) sys.path.insert(0, _backend_dir) -# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) + from dotenv import load_dotenv _env_file = os.path.join(_project_root, '.env') if os.path.exists(_env_file): @@ -51,7 +38,7 @@ class UnicodeFormatter(logging.Formatter): - """自定义格式化器,将 Unicode 转义序列转换为可读字符""" + """Unicode Formatter.""" UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') @@ -68,24 +55,24 @@ def replace_unicode(match): class MaxTokensWarningFilter(logging.Filter): - """过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)""" + """Max Tokens Warning Filter.""" def filter(self, record): - # 过滤掉包含 max_tokens 警告的日志 + if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): return False return True -# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效 + logging.getLogger().addFilter(MaxTokensWarningFilter()) def setup_oasis_logging(log_dir: str): - """配置 OASIS 的日志,使用固定名称的日志文件""" + """Setup OASIS Logging.""" os.makedirs(log_dir, exist_ok=True) - # 清理旧的日志文件 + for f in os.listdir(log_dir): old_log = os.path.join(log_dir, f) if os.path.isfile(old_log) and f.endswith('.log'): @@ -126,25 +113,25 @@ def setup_oasis_logging(log_dir: str): generate_reddit_agent_graph ) except ImportError as e: - print(f"错误: 缺少依赖 {e}") - print("请先安装: pip install oasis-ai camel-ai") + print(f"Error: missing dependency {e}") + print("Install the dependencies first: pip install oasis-ai camel-ai") sys.exit(1) -# IPC相关常量 + IPC_COMMANDS_DIR = "ipc_commands" IPC_RESPONSES_DIR = "ipc_responses" ENV_STATUS_FILE = "env_status.json" class CommandType: - """命令类型常量""" + """Command Type.""" INTERVIEW = "interview" BATCH_INTERVIEW = "batch_interview" CLOSE_ENV = "close_env" class IPCHandler: - """IPC命令处理器""" + """IPC Handler.""" def __init__(self, simulation_dir: str, env, agent_graph): self.simulation_dir = simulation_dir @@ -155,12 +142,12 @@ def __init__(self, simulation_dir: str, env, agent_graph): self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) self._running = True - # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) def update_status(self, status: str): - """更新环境状态""" + """Update status.""" with open(self.status_file, 'w', encoding='utf-8') as f: json.dump({ "status": status, @@ -168,11 +155,11 @@ def update_status(self, status: str): }, f, ensure_ascii=False, indent=2) def poll_command(self) -> Optional[Dict[str, Any]]: - """轮询获取待处理命令""" + """Poll command.""" if not os.path.exists(self.commands_dir): return None - # 获取命令文件(按时间排序) + command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): @@ -191,7 +178,7 @@ def poll_command(self) -> Optional[Dict[str, Any]]: return None def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): - """发送响应""" + """Send response.""" response = { "command_id": command_id, "status": status, @@ -204,7 +191,7 @@ def send_response(self, command_id: str, status: str, result: Dict = None, error with open(response_file, 'w', encoding='utf-8') as f: json.dump(response, f, ensure_ascii=False, indent=2) - # 删除命令文件 + command_file = os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) @@ -212,50 +199,40 @@ def send_response(self, command_id: str, status: str, result: Dict = None, error pass async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: - """ - 处理单个Agent采访命令 - - Returns: - True 表示成功,False 表示失败 - """ + """Handle interview.""" try: - # 获取Agent + agent = self.agent_graph.get_agent(agent_id) - # 创建Interview动作 + interview_action = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) - # 执行Interview + actions = {agent: interview_action} await self.env.step(actions) - # 从数据库获取结果 + result = self._get_interview_result(agent_id) self.send_response(command_id, "completed", result=result) - print(f" Interview完成: agent_id={agent_id}") + print(f" Interview completed: agent_id={agent_id}") return True except Exception as e: error_msg = str(e) - print(f" Interview失败: agent_id={agent_id}, error={error_msg}") + print(f" Interview failed: agent_id={agent_id}, error={error_msg}") self.send_response(command_id, "failed", error=error_msg) return False async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: - """ - 处理批量采访命令 - - Args: - interviews: [{"agent_id": int, "prompt": str}, ...] - """ + """Handle batch interview.""" try: - # 构建动作字典 + actions = {} - agent_prompts = {} # 记录每个agent的prompt + agent_prompts = {} for interview in interviews: agent_id = interview.get("agent_id") @@ -269,16 +246,16 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) ) agent_prompts[agent_id] = prompt except Exception as e: - print(f" 警告: 无法获取Agent {agent_id}: {e}") + print(f" Warning: unable to get agent {agent_id}: {e}") if not actions: - self.send_response(command_id, "failed", error="没有有效的Agent") + self.send_response(command_id, "failed", error="No valid agents found") return False - # 执行批量Interview + await self.env.step(actions) - # 获取所有结果 + results = {} for agent_id in agent_prompts.keys(): result = self._get_interview_result(agent_id) @@ -288,17 +265,17 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) "interviews_count": len(results), "results": results }) - print(f" 批量Interview完成: {len(results)} 个Agent") + print(f" Batch interview completed: {len(results)} agents") return True except Exception as e: error_msg = str(e) - print(f" 批量Interview失败: {error_msg}") + print(f" Batch interview failed: {error_msg}") self.send_response(command_id, "failed", error=error_msg) return False def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: - """从数据库获取最新的Interview结果""" + """Get interview result.""" db_path = os.path.join(self.simulation_dir, "reddit_simulation.db") result = { @@ -314,7 +291,7 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: conn = sqlite3.connect(db_path) cursor = conn.cursor() - # 查询最新的Interview记录 + cursor.execute(""" SELECT user_id, info, created_at FROM trace @@ -336,17 +313,12 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: conn.close() except Exception as e: - print(f" 读取Interview结果失败: {e}") + print(f" Failed to read interview result: {e}") return result async def process_commands(self) -> bool: - """ - 处理所有待处理命令 - - Returns: - True 表示继续运行,False 表示应该退出 - """ + """Process commands.""" command = self.poll_command() if not command: return True @@ -355,7 +327,7 @@ async def process_commands(self) -> bool: command_type = command.get("command_type") args = command.get("args", {}) - print(f"\n收到IPC命令: {command_type}, id={command_id}") + print(f"\nReceived IPC command: {command_type}, id={command_id}") if command_type == CommandType.INTERVIEW: await self.handle_interview( @@ -373,19 +345,19 @@ async def process_commands(self) -> bool: return True elif command_type == CommandType.CLOSE_ENV: - print("收到关闭环境命令") - self.send_response(command_id, "completed", result={"message": "环境即将关闭"}) + print("Received close-environment command") + self.send_response(command_id, "completed", result={"message": "Environment will close shortly"}) return False else: - self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}") + self.send_response(command_id, "failed", error=f"Unknown command type: {command_type}") return True class RedditSimulationRunner: - """Reddit模拟运行器""" + """Reddit Simulation Runner.""" + - # Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) AVAILABLE_ACTIONS = [ ActionType.LIKE_POST, ActionType.DISLIKE_POST, @@ -403,13 +375,7 @@ class RedditSimulationRunner: ] def __init__(self, config_path: str, wait_for_commands: bool = True): - """ - 初始化模拟运行器 - - Args: - config_path: 配置文件路径 (simulation_config.json) - wait_for_commands: 模拟完成后是否等待命令(默认True) - """ + """Initialize the instance.""" self.config_path = config_path self.config = self._load_config() self.simulation_dir = os.path.dirname(config_path) @@ -419,47 +385,40 @@ def __init__(self, config_path: str, wait_for_commands: bool = True): self.ipc_handler = None def _load_config(self) -> Dict[str, Any]: - """加载配置文件""" + """Load config.""" with open(self.config_path, 'r', encoding='utf-8') as f: return json.load(f) def _get_profile_path(self) -> str: - """获取Profile文件路径""" + """Get profile path.""" return os.path.join(self.simulation_dir, "reddit_profiles.json") def _get_db_path(self) -> str: - """获取数据库路径""" + """Get db path.""" return os.path.join(self.simulation_dir, "reddit_simulation.db") def _create_model(self): - """ - 创建LLM模型 - - 统一使用项目根目录 .env 文件中的配置(优先级最高): - - LLM_API_KEY: API密钥 - - LLM_BASE_URL: API基础URL - - LLM_MODEL_NAME: 模型名称 - """ - # 优先从 .env 读取配置 + """Create model.""" + llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") - # 如果 .env 中没有,则使用 config 作为备用 + if not llm_model: llm_model = self.config.get("llm_model", "gpt-4o-mini") - # 设置 camel-ai 所需的环境变量 + if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key if not os.environ.get("OPENAI_API_KEY"): - raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") + raise ValueError("Missing API key configuration. Set LLM_API_KEY in the project root .env file") if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url - print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") + print(f"LLM config: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else 'default'}...") return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, @@ -472,9 +431,7 @@ def _get_active_agents_for_round( current_hour: int, round_num: int ) -> List: - """ - 根据时间和配置决定本轮激活哪些Agent - """ + """Get active agents for round.""" time_config = self.config.get("time_config", {}) agent_configs = self.config.get("agent_configs", []) @@ -521,16 +478,12 @@ def _get_active_agents_for_round( return active_agents async def run(self, max_rounds: int = None): - """运行Reddit模拟 - - Args: - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) - """ + """Run the requested object.""" print("=" * 60) - print("OASIS Reddit模拟") - print(f"配置文件: {self.config_path}") - print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") - print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") + print("OASIS Reddit Simulation") + print(f"Config file: {self.config_path}") + print(f"Simulation ID: {self.config.get('simulation_id', 'unknown')}") + print(f"Command wait mode: {'enabled' if self.wait_for_commands else 'disabled'}") print("=" * 60) time_config = self.config.get("time_config", {}) @@ -538,28 +491,28 @@ async def run(self, max_rounds: int = None): minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round - # 如果指定了最大轮数,则截断 + if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: - print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") + print(f"\nRound count truncated: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") - print(f"\n模拟参数:") - print(f" - 总模拟时长: {total_hours}小时") - print(f" - 每轮时间: {minutes_per_round}分钟") - print(f" - 总轮数: {total_rounds}") + print(f"\nSimulation parameters:") + print(f" - Total simulation duration: {total_hours} hours") + print(f" - Minutes per round: {minutes_per_round}") + print(f" - Total rounds: {total_rounds}") if max_rounds: - print(f" - 最大轮数限制: {max_rounds}") - print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") + print(f" - Max rounds limit: {max_rounds}") + print(f" - Agent count: {len(self.config.get('agent_configs', []))}") - print("\n初始化LLM模型...") + print("\nInitializing LLM model...") model = self._create_model() - print("加载Agent Profile...") + print("Loading agent profiles...") profile_path = self._get_profile_path() if not os.path.exists(profile_path): - print(f"错误: Profile文件不存在: {profile_path}") + print(f"Error: profile file does not exist: {profile_path}") return self.agent_graph = await generate_reddit_agent_graph( @@ -571,29 +524,29 @@ async def run(self, max_rounds: int = None): db_path = self._get_db_path() if os.path.exists(db_path): os.remove(db_path) - print(f"已删除旧数据库: {db_path}") + print(f"Deleted old database: {db_path}") - print("创建OASIS环境...") + print("Creating OASIS environment...") self.env = oasis.make( agent_graph=self.agent_graph, platform=oasis.DefaultPlatformType.REDDIT, database_path=db_path, - semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 + semaphore=30, ) await self.env.reset() - print("环境初始化完成\n") + print("Environment initialization completed\n") + - # 初始化IPC处理器 self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler.update_status("running") - # 执行初始事件 + event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) if initial_posts: - print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...") + print(f"Executing initial events ({len(initial_posts)} initial posts)...") initial_actions = {} for post in initial_posts: agent_id = post.get("poster_agent_id", 0) @@ -613,14 +566,14 @@ async def run(self, max_rounds: int = None): action_args={"content": content} ) except Exception as e: - print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}") + print(f" Warning: unable to create initial post for agent {agent_id}: {e}") if initial_actions: await self.env.step(initial_actions) - print(f" 已发布 {len(initial_actions)} 条初始帖子") + print(f" Published {len(initial_actions)} initial posts") + - # 主模拟循环 - print("\n开始模拟循环...") + print("\nStarting simulation loop...") start_time = datetime.now() for round_num in range(total_rounds): @@ -651,20 +604,20 @@ async def run(self, max_rounds: int = None): f"- elapsed: {elapsed:.1f}s") total_elapsed = (datetime.now() - start_time).total_seconds() - print(f"\n模拟循环完成!") - print(f" - 总耗时: {total_elapsed:.1f}秒") - print(f" - 数据库: {db_path}") + print(f"\nSimulation loop completed!") + print(f" - Total elapsed time: {total_elapsed:.1f}s") + print(f" - Database: {db_path}") + - # 是否进入等待命令模式 if self.wait_for_commands: print("\n" + "=" * 60) - print("进入等待命令模式 - 环境保持运行") - print("支持的命令: interview, batch_interview, close_env") + print("Entering command-wait mode - environment stays running") + print("Supported commands: interview, batch_interview, close_env") print("=" * 60) self.ipc_handler.update_status("alive") - # 等待命令循环(使用全局 _shutdown_event) + try: while not _shutdown_event.is_set(): should_continue = await self.ipc_handler.process_commands() @@ -672,58 +625,58 @@ async def run(self, max_rounds: int = None): break try: await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) - break # 收到退出信号 + break except asyncio.TimeoutError: pass except KeyboardInterrupt: - print("\n收到中断信号") + print("\nReceived interrupt signal") except asyncio.CancelledError: - print("\n任务被取消") + print("\nTask was cancelled") except Exception as e: - print(f"\n命令处理出错: {e}") + print(f"\nCommand processing error: {e}") - print("\n关闭环境...") + print("\nClosing environment...") + - # 关闭环境 self.ipc_handler.update_status("stopped") await self.env.close() - print("环境已关闭") + print("Environment closed") print("=" * 60) async def main(): - parser = argparse.ArgumentParser(description='OASIS Reddit模拟') + parser = argparse.ArgumentParser(description='OASIS Reddit simulation') parser.add_argument( '--config', type=str, required=True, - help='配置文件路径 (simulation_config.json)' + help='Path to the config file (simulation_config.json)' ) parser.add_argument( '--max-rounds', type=int, default=None, - help='最大模拟轮数(可选,用于截断过长的模拟)' + help='Maximum number of simulation rounds (optional, used to truncate long simulations)' ) parser.add_argument( '--no-wait', action='store_true', default=False, - help='模拟完成后立即关闭环境,不进入等待命令模式' + help='Close the environment immediately after the simulation instead of entering command-wait mode' ) args = parser.parse_args() - # 在 main 函数开始时创建 shutdown 事件 + global _shutdown_event _shutdown_event = asyncio.Event() if not os.path.exists(args.config): - print(f"错误: 配置文件不存在: {args.config}") + print(f"Error: config file does not exist: {args.config}") sys.exit(1) - # 初始化日志配置(使用固定文件名,清理旧日志) + simulation_dir = os.path.dirname(args.config) or "." setup_oasis_logging(os.path.join(simulation_dir, "log")) @@ -735,21 +688,18 @@ async def main(): def setup_signal_handlers(): - """ - 设置信号处理器,确保收到 SIGTERM/SIGINT 时能够正确退出 - 让程序有机会正常清理资源(关闭数据库、环境等) - """ + """Setup Signal Handlers.""" def signal_handler(signum, frame): global _cleanup_done sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT" - print(f"\n收到 {sig_name} 信号,正在退出...") + print(f"\nReceived {sig_name}; exiting...") if not _cleanup_done: _cleanup_done = True if _shutdown_event: _shutdown_event.set() else: - # 重复收到信号才强制退出 - print("强制退出...") + + print("Force exiting...") sys.exit(1) signal.signal(signal.SIGTERM, signal_handler) @@ -761,9 +711,8 @@ def signal_handler(signum, frame): try: asyncio.run(main()) except KeyboardInterrupt: - print("\n程序被中断") + print("\nProgram interrupted") except SystemExit: pass finally: - print("模拟进程已退出") - + print("Simulation process exited") diff --git a/backend/scripts/run_twitter_simulation.py b/backend/scripts/run_twitter_simulation.py index caab9e9d3..77a86b22c 100644 --- a/backend/scripts/run_twitter_simulation.py +++ b/backend/scripts/run_twitter_simulation.py @@ -1,17 +1,4 @@ -""" -OASIS Twitter模拟预设脚本 -此脚本读取配置文件中的参数来执行模拟,实现全程自动化 - -功能特性: -- 完成模拟后不立即关闭环境,进入等待命令模式 -- 支持通过IPC接收Interview命令 -- 支持单个Agent采访和批量采访 -- 支持远程关闭环境命令 - -使用方式: - python run_twitter_simulation.py --config /path/to/simulation_config.json - python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭 -""" +"""OASIS Twitter simulation runner.""" import argparse import asyncio @@ -25,18 +12,18 @@ from datetime import datetime from typing import Dict, Any, List, Optional -# 全局变量:用于信号处理 + _shutdown_event = None _cleanup_done = False -# 添加项目路径 + _scripts_dir = os.path.dirname(os.path.abspath(__file__)) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..')) sys.path.insert(0, _scripts_dir) sys.path.insert(0, _backend_dir) -# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) + from dotenv import load_dotenv _env_file = os.path.join(_project_root, '.env') if os.path.exists(_env_file): @@ -51,7 +38,7 @@ class UnicodeFormatter(logging.Formatter): - """自定义格式化器,将 Unicode 转义序列转换为可读字符""" + """Unicode Formatter.""" UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') @@ -68,24 +55,24 @@ def replace_unicode(match): class MaxTokensWarningFilter(logging.Filter): - """过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)""" + """Max Tokens Warning Filter.""" def filter(self, record): - # 过滤掉包含 max_tokens 警告的日志 + if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): return False return True -# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效 + logging.getLogger().addFilter(MaxTokensWarningFilter()) def setup_oasis_logging(log_dir: str): - """配置 OASIS 的日志,使用固定名称的日志文件""" + """Setup OASIS Logging.""" os.makedirs(log_dir, exist_ok=True) - # 清理旧的日志文件 + for f in os.listdir(log_dir): old_log = os.path.join(log_dir, f) if os.path.isfile(old_log) and f.endswith('.log'): @@ -126,25 +113,25 @@ def setup_oasis_logging(log_dir: str): generate_twitter_agent_graph ) except ImportError as e: - print(f"错误: 缺少依赖 {e}") - print("请先安装: pip install oasis-ai camel-ai") + print(f"Error: missing dependency {e}") + print("Install the dependencies first: pip install oasis-ai camel-ai") sys.exit(1) -# IPC相关常量 + IPC_COMMANDS_DIR = "ipc_commands" IPC_RESPONSES_DIR = "ipc_responses" ENV_STATUS_FILE = "env_status.json" class CommandType: - """命令类型常量""" + """Command Type.""" INTERVIEW = "interview" BATCH_INTERVIEW = "batch_interview" CLOSE_ENV = "close_env" class IPCHandler: - """IPC命令处理器""" + """IPC Handler.""" def __init__(self, simulation_dir: str, env, agent_graph): self.simulation_dir = simulation_dir @@ -155,12 +142,12 @@ def __init__(self, simulation_dir: str, env, agent_graph): self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) self._running = True - # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) def update_status(self, status: str): - """更新环境状态""" + """Update status.""" with open(self.status_file, 'w', encoding='utf-8') as f: json.dump({ "status": status, @@ -168,11 +155,11 @@ def update_status(self, status: str): }, f, ensure_ascii=False, indent=2) def poll_command(self) -> Optional[Dict[str, Any]]: - """轮询获取待处理命令""" + """Poll command.""" if not os.path.exists(self.commands_dir): return None - # 获取命令文件(按时间排序) + command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): @@ -191,7 +178,7 @@ def poll_command(self) -> Optional[Dict[str, Any]]: return None def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): - """发送响应""" + """Send response.""" response = { "command_id": command_id, "status": status, @@ -204,7 +191,7 @@ def send_response(self, command_id: str, status: str, result: Dict = None, error with open(response_file, 'w', encoding='utf-8') as f: json.dump(response, f, ensure_ascii=False, indent=2) - # 删除命令文件 + command_file = os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) @@ -212,50 +199,40 @@ def send_response(self, command_id: str, status: str, result: Dict = None, error pass async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: - """ - 处理单个Agent采访命令 - - Returns: - True 表示成功,False 表示失败 - """ + """Handle interview.""" try: - # 获取Agent + agent = self.agent_graph.get_agent(agent_id) - # 创建Interview动作 + interview_action = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) - # 执行Interview + actions = {agent: interview_action} await self.env.step(actions) - # 从数据库获取结果 + result = self._get_interview_result(agent_id) self.send_response(command_id, "completed", result=result) - print(f" Interview完成: agent_id={agent_id}") + print(f" Interview completed: agent_id={agent_id}") return True except Exception as e: error_msg = str(e) - print(f" Interview失败: agent_id={agent_id}, error={error_msg}") + print(f" Interview failed: agent_id={agent_id}, error={error_msg}") self.send_response(command_id, "failed", error=error_msg) return False async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: - """ - 处理批量采访命令 - - Args: - interviews: [{"agent_id": int, "prompt": str}, ...] - """ + """Handle batch interview.""" try: - # 构建动作字典 + actions = {} - agent_prompts = {} # 记录每个agent的prompt + agent_prompts = {} for interview in interviews: agent_id = interview.get("agent_id") @@ -269,16 +246,16 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) ) agent_prompts[agent_id] = prompt except Exception as e: - print(f" 警告: 无法获取Agent {agent_id}: {e}") + print(f" Warning: unable to get agent {agent_id}: {e}") if not actions: - self.send_response(command_id, "failed", error="没有有效的Agent") + self.send_response(command_id, "failed", error="No valid agents found") return False - # 执行批量Interview + await self.env.step(actions) - # 获取所有结果 + results = {} for agent_id in agent_prompts.keys(): result = self._get_interview_result(agent_id) @@ -288,17 +265,17 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) "interviews_count": len(results), "results": results }) - print(f" 批量Interview完成: {len(results)} 个Agent") + print(f" Batch interview completed: {len(results)} agents") return True except Exception as e: error_msg = str(e) - print(f" 批量Interview失败: {error_msg}") + print(f" Batch interview failed: {error_msg}") self.send_response(command_id, "failed", error=error_msg) return False def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: - """从数据库获取最新的Interview结果""" + """Get interview result.""" db_path = os.path.join(self.simulation_dir, "twitter_simulation.db") result = { @@ -314,7 +291,7 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: conn = sqlite3.connect(db_path) cursor = conn.cursor() - # 查询最新的Interview记录 + cursor.execute(""" SELECT user_id, info, created_at FROM trace @@ -336,17 +313,12 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: conn.close() except Exception as e: - print(f" 读取Interview结果失败: {e}") + print(f" Failed to read interview result: {e}") return result async def process_commands(self) -> bool: - """ - 处理所有待处理命令 - - Returns: - True 表示继续运行,False 表示应该退出 - """ + """Process commands.""" command = self.poll_command() if not command: return True @@ -355,7 +327,7 @@ async def process_commands(self) -> bool: command_type = command.get("command_type") args = command.get("args", {}) - print(f"\n收到IPC命令: {command_type}, id={command_id}") + print(f"\nReceived IPC command: {command_type}, id={command_id}") if command_type == CommandType.INTERVIEW: await self.handle_interview( @@ -373,19 +345,19 @@ async def process_commands(self) -> bool: return True elif command_type == CommandType.CLOSE_ENV: - print("收到关闭环境命令") - self.send_response(command_id, "completed", result={"message": "环境即将关闭"}) + print("Received close-environment command") + self.send_response(command_id, "completed", result={"message": "Environment will close shortly"}) return False else: - self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}") + self.send_response(command_id, "failed", error=f"Unknown command type: {command_type}") return True class TwitterSimulationRunner: - """Twitter模拟运行器""" + """Twitter Simulation Runner.""" + - # Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) AVAILABLE_ACTIONS = [ ActionType.CREATE_POST, ActionType.LIKE_POST, @@ -396,13 +368,7 @@ class TwitterSimulationRunner: ] def __init__(self, config_path: str, wait_for_commands: bool = True): - """ - 初始化模拟运行器 - - Args: - config_path: 配置文件路径 (simulation_config.json) - wait_for_commands: 模拟完成后是否等待命令(默认True) - """ + """Initialize the instance.""" self.config_path = config_path self.config = self._load_config() self.simulation_dir = os.path.dirname(config_path) @@ -412,47 +378,40 @@ def __init__(self, config_path: str, wait_for_commands: bool = True): self.ipc_handler = None def _load_config(self) -> Dict[str, Any]: - """加载配置文件""" + """Load config.""" with open(self.config_path, 'r', encoding='utf-8') as f: return json.load(f) def _get_profile_path(self) -> str: - """获取Profile文件路径(OASIS Twitter使用CSV格式)""" + """Get profile path.""" return os.path.join(self.simulation_dir, "twitter_profiles.csv") def _get_db_path(self) -> str: - """获取数据库路径""" + """Get db path.""" return os.path.join(self.simulation_dir, "twitter_simulation.db") def _create_model(self): - """ - 创建LLM模型 - - 统一使用项目根目录 .env 文件中的配置(优先级最高): - - LLM_API_KEY: API密钥 - - LLM_BASE_URL: API基础URL - - LLM_MODEL_NAME: 模型名称 - """ - # 优先从 .env 读取配置 + """Create model.""" + llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") - # 如果 .env 中没有,则使用 config 作为备用 + if not llm_model: llm_model = self.config.get("llm_model", "gpt-4o-mini") - # 设置 camel-ai 所需的环境变量 + if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key if not os.environ.get("OPENAI_API_KEY"): - raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") + raise ValueError("Missing API key configuration. Set LLM_API_KEY in the project root .env file") if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url - print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") + print(f"LLM config: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else 'default'}...") return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, @@ -465,25 +424,15 @@ def _get_active_agents_for_round( current_hour: int, round_num: int ) -> List: - """ - 根据时间和配置决定本轮激活哪些Agent - - Args: - env: OASIS环境 - current_hour: 当前模拟小时(0-23) - round_num: 当前轮数 - - Returns: - 激活的Agent列表 - """ + """Get active agents for round.""" time_config = self.config.get("time_config", {}) agent_configs = self.config.get("agent_configs", []) - # 基础激活数量 + base_min = time_config.get("agents_per_hour_min", 5) base_max = time_config.get("agents_per_hour_max", 20) - # 根据时段调整 + peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]) off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5]) @@ -496,28 +445,28 @@ def _get_active_agents_for_round( target_count = int(random.uniform(base_min, base_max) * multiplier) - # 根据每个Agent的配置计算激活概率 + candidates = [] for cfg in agent_configs: agent_id = cfg.get("agent_id", 0) active_hours = cfg.get("active_hours", list(range(8, 23))) activity_level = cfg.get("activity_level", 0.5) - # 检查是否在活跃时间 + if current_hour not in active_hours: continue - # 根据活跃度计算概率 + if random.random() < activity_level: candidates.append(agent_id) - # 随机选择 + selected_ids = random.sample( candidates, min(target_count, len(candidates)) ) if candidates else [] - # 转换为Agent对象 + active_agents = [] for agent_id in selected_ids: try: @@ -529,50 +478,46 @@ def _get_active_agents_for_round( return active_agents async def run(self, max_rounds: int = None): - """运行Twitter模拟 - - Args: - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) - """ + """Run the requested object.""" print("=" * 60) - print("OASIS Twitter模拟") - print(f"配置文件: {self.config_path}") - print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") - print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") + print("OASIS Twitter Simulation") + print(f"Config file: {self.config_path}") + print(f"Simulation ID: {self.config.get('simulation_id', 'unknown')}") + print(f"Command wait mode: {'enabled' if self.wait_for_commands else 'disabled'}") print("=" * 60) - # 加载时间配置 + time_config = self.config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) - # 计算总轮数 + total_rounds = (total_hours * 60) // minutes_per_round - # 如果指定了最大轮数,则截断 + if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: - print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") + print(f"\nRound count truncated: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") - print(f"\n模拟参数:") - print(f" - 总模拟时长: {total_hours}小时") - print(f" - 每轮时间: {minutes_per_round}分钟") - print(f" - 总轮数: {total_rounds}") + print(f"\nSimulation parameters:") + print(f" - Total simulation duration: {total_hours} hours") + print(f" - Minutes per round: {minutes_per_round}") + print(f" - Total rounds: {total_rounds}") if max_rounds: - print(f" - 最大轮数限制: {max_rounds}") - print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") + print(f" - Max rounds limit: {max_rounds}") + print(f" - Agent count: {len(self.config.get('agent_configs', []))}") + - # 创建模型 - print("\n初始化LLM模型...") + print("\nInitializing LLM model...") model = self._create_model() - # 加载Agent图 - print("加载Agent Profile...") + + print("Loading agent profiles...") profile_path = self._get_profile_path() if not os.path.exists(profile_path): - print(f"错误: Profile文件不存在: {profile_path}") + print(f"Error: profile file does not exist: {profile_path}") return self.agent_graph = await generate_twitter_agent_graph( @@ -581,34 +526,34 @@ async def run(self, max_rounds: int = None): available_actions=self.AVAILABLE_ACTIONS, ) - # 数据库路径 + db_path = self._get_db_path() if os.path.exists(db_path): os.remove(db_path) - print(f"已删除旧数据库: {db_path}") + print(f"Deleted old database: {db_path}") + - # 创建环境 - print("创建OASIS环境...") + print("Creating OASIS environment...") self.env = oasis.make( agent_graph=self.agent_graph, platform=oasis.DefaultPlatformType.TWITTER, database_path=db_path, - semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 + semaphore=30, ) await self.env.reset() - print("环境初始化完成\n") + print("Environment initialization completed\n") + - # 初始化IPC处理器 self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler.update_status("running") - # 执行初始事件 + event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) if initial_posts: - print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...") + print(f"Executing initial events ({len(initial_posts)} initial posts)...") initial_actions = {} for post in initial_posts: agent_id = post.get("poster_agent_id", 0) @@ -620,23 +565,23 @@ async def run(self, max_rounds: int = None): action_args={"content": content} ) except Exception as e: - print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}") + print(f" Warning: unable to create initial post for agent {agent_id}: {e}") if initial_actions: await self.env.step(initial_actions) - print(f" 已发布 {len(initial_actions)} 条初始帖子") + print(f" Published {len(initial_actions)} initial posts") - # 主模拟循环 - print("\n开始模拟循环...") + + print("\nStarting simulation loop...") start_time = datetime.now() for round_num in range(total_rounds): - # 计算当前模拟时间 + simulated_minutes = round_num * minutes_per_round simulated_hour = (simulated_minutes // 60) % 24 simulated_day = simulated_minutes // (60 * 24) + 1 - # 获取本轮激活的Agent + active_agents = self._get_active_agents_for_round( self.env, simulated_hour, round_num ) @@ -644,16 +589,16 @@ async def run(self, max_rounds: int = None): if not active_agents: continue - # 构建动作 + actions = { agent: LLMAction() for _, agent in active_agents } - # 执行动作 + await self.env.step(actions) - # 打印进度 + if (round_num + 1) % 10 == 0 or round_num == 0: elapsed = (datetime.now() - start_time).total_seconds() progress = (round_num + 1) / total_rounds * 100 @@ -663,20 +608,20 @@ async def run(self, max_rounds: int = None): f"- elapsed: {elapsed:.1f}s") total_elapsed = (datetime.now() - start_time).total_seconds() - print(f"\n模拟循环完成!") - print(f" - 总耗时: {total_elapsed:.1f}秒") - print(f" - 数据库: {db_path}") + print(f"\nSimulation loop completed!") + print(f" - Total elapsed time: {total_elapsed:.1f}s") + print(f" - Database: {db_path}") + - # 是否进入等待命令模式 if self.wait_for_commands: print("\n" + "=" * 60) - print("进入等待命令模式 - 环境保持运行") - print("支持的命令: interview, batch_interview, close_env") + print("Entering command-wait mode - environment stays running") + print("Supported commands: interview, batch_interview, close_env") print("=" * 60) self.ipc_handler.update_status("alive") - # 等待命令循环(使用全局 _shutdown_event) + try: while not _shutdown_event.is_set(): should_continue = await self.ipc_handler.process_commands() @@ -684,58 +629,58 @@ async def run(self, max_rounds: int = None): break try: await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) - break # 收到退出信号 + break except asyncio.TimeoutError: pass except KeyboardInterrupt: - print("\n收到中断信号") + print("\nReceived interrupt signal") except asyncio.CancelledError: - print("\n任务被取消") + print("\nTask was cancelled") except Exception as e: - print(f"\n命令处理出错: {e}") + print(f"\nCommand processing error: {e}") - print("\n关闭环境...") + print("\nClosing environment...") + - # 关闭环境 self.ipc_handler.update_status("stopped") await self.env.close() - print("环境已关闭") + print("Environment closed") print("=" * 60) async def main(): - parser = argparse.ArgumentParser(description='OASIS Twitter模拟') + parser = argparse.ArgumentParser(description='OASIS Twitter simulation') parser.add_argument( '--config', type=str, required=True, - help='配置文件路径 (simulation_config.json)' + help='Path to the config file (simulation_config.json)' ) parser.add_argument( '--max-rounds', type=int, default=None, - help='最大模拟轮数(可选,用于截断过长的模拟)' + help='Maximum number of simulation rounds (optional, used to truncate long simulations)' ) parser.add_argument( '--no-wait', action='store_true', default=False, - help='模拟完成后立即关闭环境,不进入等待命令模式' + help='Close the environment immediately after the simulation instead of entering command-wait mode' ) args = parser.parse_args() - # 在 main 函数开始时创建 shutdown 事件 + global _shutdown_event _shutdown_event = asyncio.Event() if not os.path.exists(args.config): - print(f"错误: 配置文件不存在: {args.config}") + print(f"Error: config file does not exist: {args.config}") sys.exit(1) - # 初始化日志配置(使用固定文件名,清理旧日志) + simulation_dir = os.path.dirname(args.config) or "." setup_oasis_logging(os.path.join(simulation_dir, "log")) @@ -747,21 +692,18 @@ async def main(): def setup_signal_handlers(): - """ - 设置信号处理器,确保收到 SIGTERM/SIGINT 时能够正确退出 - 让程序有机会正常清理资源(关闭数据库、环境等) - """ + """Setup Signal Handlers.""" def signal_handler(signum, frame): global _cleanup_done sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT" - print(f"\n收到 {sig_name} 信号,正在退出...") + print(f"\nReceived {sig_name}; exiting...") if not _cleanup_done: _cleanup_done = True if _shutdown_event: _shutdown_event.set() else: - # 重复收到信号才强制退出 - print("强制退出...") + + print("Force exiting...") sys.exit(1) signal.signal(signal.SIGTERM, signal_handler) @@ -773,8 +715,8 @@ def signal_handler(signum, frame): try: asyncio.run(main()) except KeyboardInterrupt: - print("\n程序被中断") + print("\nProgram interrupted") except SystemExit: pass finally: - print("模拟进程已退出") + print("Simulation process exited") diff --git a/backend/scripts/test_profile_format.py b/backend/scripts/test_profile_format.py index 354e8b5ca..11093ccd9 100644 --- a/backend/scripts/test_profile_format.py +++ b/backend/scripts/test_profile_format.py @@ -1,9 +1,4 @@ -""" -测试Profile格式生成是否符合OASIS要求 -验证: -1. Twitter Profile生成CSV格式 -2. Reddit Profile生成JSON详细格式 -""" +"""Validate that generated profile formats match OASIS expectations.""" import os import sys @@ -11,19 +6,19 @@ import csv import tempfile -# 添加项目路径 + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from app.services.oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile def test_profile_formats(): - """测试Profile格式""" + """Test Profile Formats.""" print("=" * 60) - print("OASIS Profile格式测试") + print("OASIS Profile Format Test") print("=" * 60) - # 创建测试Profile数据 + test_profiles = [ OasisAgentProfile( user_id=0, @@ -63,84 +58,84 @@ def test_profile_formats(): generator = OasisProfileGenerator.__new__(OasisProfileGenerator) - # 使用临时目录 + with tempfile.TemporaryDirectory() as temp_dir: twitter_path = os.path.join(temp_dir, "twitter_profiles.csv") reddit_path = os.path.join(temp_dir, "reddit_profiles.json") - # 测试Twitter CSV格式 - print("\n1. 测试Twitter Profile (CSV格式)") + + print("\n1. Test Twitter Profile (CSV format)") print("-" * 40) generator._save_twitter_csv(test_profiles, twitter_path) - # 读取并验证CSV + with open(twitter_path, 'r', encoding='utf-8') as f: reader = csv.DictReader(f) rows = list(reader) - print(f" 文件: {twitter_path}") - print(f" 行数: {len(rows)}") - print(f" 表头: {list(rows[0].keys())}") - print(f"\n 示例数据 (第1行):") + print(f" File: {twitter_path}") + print(f" Row count: {len(rows)}") + print(f" Headers: {list(rows[0].keys())}") + print(f"\n Sample data (first row):") for key, value in rows[0].items(): print(f" {key}: {value}") - # 验证必需字段 + required_twitter_fields = ['user_id', 'user_name', 'name', 'bio', 'friend_count', 'follower_count', 'statuses_count', 'created_at'] missing = set(required_twitter_fields) - set(rows[0].keys()) if missing: - print(f"\n [错误] 缺少字段: {missing}") + print(f"\n [Error] Missing fields: {missing}") else: - print(f"\n [通过] 所有必需字段都存在") + print(f"\n [Pass] All required fields are present") + - # 测试Reddit JSON格式 - print("\n2. 测试Reddit Profile (JSON详细格式)") + print("\n2. Test Reddit Profile (detailed JSON format)") print("-" * 40) generator._save_reddit_json(test_profiles, reddit_path) - # 读取并验证JSON + with open(reddit_path, 'r', encoding='utf-8') as f: reddit_data = json.load(f) - print(f" 文件: {reddit_path}") - print(f" 条目数: {len(reddit_data)}") - print(f" 字段: {list(reddit_data[0].keys())}") - print(f"\n 示例数据 (第1条):") + print(f" File: {reddit_path}") + print(f" Item count: {len(reddit_data)}") + print(f" Fields: {list(reddit_data[0].keys())}") + print(f"\n Sample data (first item):") print(json.dumps(reddit_data[0], ensure_ascii=False, indent=4)) - # 验证详细格式字段 + required_reddit_fields = ['realname', 'username', 'bio', 'persona'] optional_reddit_fields = ['age', 'gender', 'mbti', 'country', 'profession', 'interested_topics'] missing = set(required_reddit_fields) - set(reddit_data[0].keys()) if missing: - print(f"\n [错误] 缺少必需字段: {missing}") + print(f"\n [Error] Missing required fields: {missing}") else: - print(f"\n [通过] 所有必需字段都存在") + print(f"\n [Pass] All required fields are present") present_optional = set(optional_reddit_fields) & set(reddit_data[0].keys()) - print(f" [信息] 可选字段: {present_optional}") + print(f" [Info] Optional fields present: {present_optional}") print("\n" + "=" * 60) - print("测试完成!") + print("Test complete!") print("=" * 60) def show_expected_formats(): - """显示OASIS期望的格式""" + """Show Expected Formats.""" print("\n" + "=" * 60) - print("OASIS 期望的Profile格式参考") + print("Reference: expected OASIS profile formats") print("=" * 60) - print("\n1. Twitter Profile (CSV格式)") + print("\n1. Twitter Profile (CSV format)") print("-" * 40) twitter_example = """user_id,user_name,name,bio,friend_count,follower_count,statuses_count,created_at 0,user0,User Zero,I am user zero with interests in technology.,100,150,500,2023-01-01 1,user1,User One,Tech enthusiast and coffee lover.,200,250,1000,2023-01-02""" print(twitter_example) - print("\n2. Reddit Profile (JSON详细格式)") + print("\n2. Reddit Profile (detailed JSON format)") print("-" * 40) reddit_example = [ { @@ -163,4 +158,3 @@ def show_expected_formats(): test_profile_formats() show_expected_formats() - diff --git a/backend/tests/test_ontology_normalizer.py b/backend/tests/test_ontology_normalizer.py new file mode 100644 index 000000000..6e6402e74 --- /dev/null +++ b/backend/tests/test_ontology_normalizer.py @@ -0,0 +1,38 @@ +from app.utils.ontology_normalizer import normalize_ontology_for_zep + + +def test_normalize_ontology_entity_names_and_source_targets(): + ontology = { + "entity_types": [ + { + "name": "IH_Team", + "description": "Escalation team", + "attributes": [], + }, + { + "name": "billing department", + "description": "Billing org", + "attributes": [], + }, + ], + "edge_types": [ + { + "name": "LEADS", + "description": "Leadership relation", + "source_targets": [ + {"source": "IH_Team", "target": "billing department"}, + ], + "attributes": [], + } + ], + } + + normalized, entity_name_mapping = normalize_ontology_for_zep(ontology) + + assert entity_name_mapping["IH_Team"] == "IHTeam" + assert entity_name_mapping["billing department"] == "BillingDepartment" + assert normalized["entity_types"][0]["name"] == "IHTeam" + assert normalized["entity_types"][1]["name"] == "BillingDepartment" + assert normalized["edge_types"][0]["source_targets"] == [ + {"source": "IHTeam", "target": "BillingDepartment"} + ] diff --git a/backend/uv.lock b/backend/uv.lock index f1ce4b60e..303dbb21c 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -1248,10 +1248,13 @@ dependencies = [ { name = "charset-normalizer" }, { name = "flask" }, { name = "flask-cors" }, + { name = "numpy" }, { name = "openai" }, + { name = "posthog" }, { name = "pydantic" }, { name = "pymupdf" }, { name = "python-dotenv" }, + { name = "tenacity" }, { name = "zep-cloud" }, ] @@ -1276,13 +1279,16 @@ requires-dist = [ { name = "charset-normalizer", specifier = ">=3.0.0" }, { name = "flask", specifier = ">=3.0.0" }, { name = "flask-cors", specifier = ">=6.0.0" }, - { name = "openai", specifier = ">=1.0.0" }, + { name = "numpy", specifier = ">=1.0.0" }, + { name = "openai", specifier = ">=1.91.0" }, { name = "pipreqs", marker = "extra == 'dev'", specifier = ">=0.5.0" }, + { name = "posthog", specifier = ">=3.0.0" }, { name = "pydantic", specifier = ">=2.0.0" }, { name = "pymupdf", specifier = ">=1.24.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "tenacity", specifier = ">=9.0.0" }, { name = "zep-cloud", specifier = "==3.13.0" }, ] provides-extras = ["dev"] @@ -1840,6 +1846,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "posthog" +version = "7.9.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "backoff" }, + { name = "distro" }, + { name = "python-dateutil" }, + { name = "requests" }, + { name = "six" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1c/a7/2865487853061fbd62383492237b546d2d8f7c1846272350d2b9e14138cd/posthog-7.9.12.tar.gz", hash = "sha256:ebabf2eb2e1c1fbf22b0759df4644623fa43cc6c9dcbe9fd429b7937d14251ec", size = 176828, upload-time = "2026-03-12T09:01:15.184Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/a9/7a803aed5a5649cf78ea7b31e90d0080181ba21f739243e1741a1e607f1f/posthog-7.9.12-py3-none-any.whl", hash = "sha256:7175bd1698a566bfea98a016c64e3456399f8046aeeca8f1d04ae5bf6c5a38d0", size = 202469, upload-time = "2026-03-12T09:01:13.38Z" }, +] + [[package]] name = "prance" version = "23.6.21.0" @@ -2987,6 +3010,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, +] + [[package]] name = "texttable" version = "1.7.0" diff --git a/docker-compose.yml b/docker-compose.yml index 637f1dfae..e8035489d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,10 @@ services: mirofish: - image: ghcr.io/666ghj/mirofish:latest - # 加速镜像(如拉取缓慢可替换上方地址) + image: mirofish-local + build: + context: . + platform: linux/arm64 + # Mirror image for faster pulls if needed # image: ghcr.nju.edu.cn/666ghj/mirofish:latest container_name: mirofish env_file: @@ -9,6 +12,25 @@ services: ports: - "3000:3000" - "5001:5001" + depends_on: + - neo4j restart: unless-stopped volumes: - - ./backend/uploads:/app/backend/uploads \ No newline at end of file + - ./backend/uploads:/app/backend/uploads + neo4j: + image: neo4j:5.26.22-enterprise + container_name: mirofish-neo4j + environment: + NEO4J_AUTH: ${NEO4J_USER:-neo4j}/${NEO4J_PASSWORD:-mirofish-local-password} + NEO4J_ACCEPT_LICENSE_AGREEMENT: "yes" + ports: + - "7474:7474" + - "7687:7687" + restart: unless-stopped + volumes: + - neo4j_data:/data + - neo4j_logs:/logs + +volumes: + neo4j_data: + neo4j_logs: diff --git a/docs/zep-cloud-to-local-migration-plan.md b/docs/zep-cloud-to-local-migration-plan.md new file mode 100644 index 000000000..51cbf08d5 --- /dev/null +++ b/docs/zep-cloud-to-local-migration-plan.md @@ -0,0 +1,663 @@ +# Zep Cloud to Local Migration Plan + +Date: 2026-03-28 + +Status: Planning document only. No runtime changes are included in this document. + +## Executive summary + +MiroFish is not using Zep Cloud in just one place. It depends on Zep Cloud for: + +- Graph creation and deletion +- Ontology registration +- Text ingestion and episode processing +- Full graph reads for visualization and simulation prep +- Semantic and hybrid search for the report agent +- Live simulation memory updates back into the graph + +Because of that, migrating from Zep Cloud to a local setup is not a simple API-key swap. + +As of 2026-03-28, Zep’s current docs say that Zep Community Edition is deprecated and no longer supported, and point self-hosted users to Graphiti or BYOC instead. That means the safest “Zep local” target for this repo is: + +- `graphiti-core` running inside the backend +- A local Neo4j instance for graph storage +- A provider abstraction layer so MiroFish can run `zep_cloud` and `graphiti_local` side by side during rollout + +This plan assumes that target. + +## Official references + +- Zep FAQ: https://help.getzep.com/faq +- Zep open-source direction announcement, published April 2, 2025: https://blog.getzep.com/announcing-a-new-direction-for-zeps-open-source-strategy/ +- Graphiti quick start: https://help.getzep.com/graphiti/getting-started/quick-start +- Graphiti Neo4j configuration: https://help.getzep.com/graphiti/configuration/neo-4-j-configuration +- Graphiti custom entity and edge types: https://help.getzep.com/graphiti/core-concepts/custom-entity-and-edge-types +- Graphiti graph namespacing: https://help.getzep.com/graphiti/core-concepts/graph-namespacing +- Graphiti search: https://help.getzep.com/graphiti/working-with-data/searching +- Graphiti CRUD operations: https://help.getzep.com/graphiti/working-with-data/crud-operations +- Graphiti fact triples: https://help.getzep.com/graphiti/working-with-data/adding-fact-triples + +## Current MiroFish dependency map + +These files are the main Zep Cloud touch points in the repo today: + +- `backend/app/config.py` + - Requires `ZEP_API_KEY` +- `backend/requirements.txt` + - Pins `zep-cloud==3.13.0` +- `backend/pyproject.toml` + - Pins `zep-cloud==3.13.0` +- `backend/app/services/graph_builder.py` + - Creates graphs, sets ontology, ingests batches, waits for episode processing, reads graph data, deletes graphs +- `backend/app/services/zep_entity_reader.py` + - Reads all nodes and edges, enriches entities with edge context +- `backend/app/services/zep_tools.py` + - Runs graph search and powers report-agent retrieval +- `backend/app/services/zep_graph_memory_updater.py` + - Pushes live simulation activity back into the graph +- `backend/app/services/oasis_profile_generator.py` + - Uses graph search to enrich agent profiles +- `backend/app/utils/zep_paging.py` + - Implements graph-wide pagination on Zep node and edge list APIs +- `backend/app/api/graph.py` + - Exposes graph build, read, and delete flows +- `backend/app/api/simulation.py` + - Exposes graph entity read flows used by simulation setup +- `backend/app/api/report.py` + - Exposes report search/statistics endpoints +- `.env.example`, `README.md`, `README-EN.md` + - Document Zep Cloud setup +- `docker-compose.yml` + - Does not currently start any local graph database + +## What must keep working after migration + +The migration is successful only if these user-visible flows still work: + +- Build a graph from uploaded source material +- Store and reuse a project-level `graph_id` +- Load graph nodes and edges in the UI +- Read entities from the graph to prepare simulation agents +- Search the graph for report generation +- Keep live simulation memory updates enabled +- Delete graph data when a project is removed or rebuilt + +## Recommended target architecture + +### 1. Keep the existing product-level `graph_id` + +Do not change the frontend contract if we can avoid it. + +- Keep storing `graph_id` in project and simulation JSON +- In the new local provider, keep `graph_id` as the app-level identifier +- Reuse `graph_id` as the primary Graphiti `group_id` +- This preserves almost all frontend behavior and avoids a database migration in the UI layer + +Reason: + +- Graphiti uses `group_id` for isolated graph namespaces +- In the currently targeted `graphiti-core` release, MiroFish can stay on a single Neo4j database and isolate graphs by `group_id` +- MiroFish already thinks in terms of one graph per project +- `graph_id -> group_id` is the cleanest compatibility bridge for this repo + +### 2. Add a graph provider abstraction in the backend + +Create a small internal interface, for example: + +- `create_graph(name) -> graph_id` +- `delete_graph(graph_id)` +- `set_ontology(graph_id, ontology)` +- `add_text_batch(graph_id, chunks)` +- `wait_for_ingestion(graph_id, job_ref)` +- `get_all_nodes(graph_id)` +- `get_all_edges(graph_id)` +- `get_node(graph_id, node_uuid)` +- `get_node_edges(graph_id, node_uuid)` +- `search_graph(graph_id, query, scope, limit)` +- `append_activity(graph_id, text)` +- `get_graph_statistics(graph_id)` + +Implement two providers: + +- `ZepCloudGraphProvider` +- `GraphitiLocalGraphProvider` + +Add a factory selected by an env var such as: + +- `GRAPH_BACKEND=zep_cloud` +- `GRAPH_BACKEND=graphiti_local` + +### 3. Embed Graphiti in the backend first + +For the first migration, do not introduce a second application service unless it becomes necessary. + +Recommended first version: + +- Flask backend remains the API server +- Graphiti runs as a Python library inside the backend process +- Neo4j runs as a local service in Docker Compose + +Why this is the safest first move: + +- Fewer moving parts +- Easier local debugging +- Lower operational complexity +- Faster path to dual-run and rollback + +### 4. Treat async Graphiti calls as a real design task + +This is an important implementation detail. + +Current MiroFish backend code is mostly synchronous Flask code. Graphiti’s documented API is async-first. That means we need one of these approaches: + +- Wrap Graphiti calls in a sync adapter using a controlled event-loop helper +- Move graph operations into a worker or async service + +Recommendation: + +- Phase 1 and Phase 2 should use a thin sync adapter around Graphiti +- Only move to a separate async service if latency or concurrency becomes a real problem + +## API and data-model mapping + +### Current Zep Cloud behavior to replace + +| Current behavior | Current MiroFish usage | Local replacement strategy | +| --- | --- | --- | +| `graph.create` | Creates one graph per project | No separate graph object in Graphiti; create and reserve a `group_id` namespace | +| `graph.set_ontology` | Registers project ontology before ingestion | Pass custom entity types, edge types, and edge maps during episode ingestion | +| `graph.add_batch` | Sends document chunks for extraction | Loop over `graphiti.add_episode(...)` or bulk helper if adopted | +| `episode.get(...processed)` | Polls until ingestion finishes | Track ingestion at app level; treat Graphiti call completion as local completion, or store background job state | +| `graph.node.get_by_graph_id` | Reads all nodes | Query by `group_id`, using Graphiti CRUD utilities or direct Neo4j Cypher | +| `graph.edge.get_by_graph_id` | Reads all edges | Query by `group_id`, using Graphiti CRUD utilities or direct Neo4j Cypher | +| `graph.search(scope="edges" / "nodes")` | Report and simulation retrieval | Use `graphiti.search()` or `graphiti._search()` recipes for edge-only and node-only search | +| `graph.add(type="text", data=...)` | Live simulation memory writeback | Convert activity batches into Graphiti text episodes | +| `graph.delete` | Removes a graph | Delete all nodes, edges, and episodes for the namespace | + +### Ontology compatibility notes + +Graphiti supports custom entity and edge types using Pydantic models, which matches MiroFish’s current ontology-generation approach well. It also has similar protected attribute names, including: + +- `uuid` +- `name` +- `group_id` +- `labels` +- `created_at` +- `summary` +- `attributes` +- `name_embedding` + +This is good news for MiroFish because `backend/app/services/graph_builder.py` already normalizes ontology names and protected attributes for Zep-style constraints. That logic should be reused, not rewritten from scratch. + +### Search compatibility notes + +Graphiti supports: + +- Hybrid search +- Node distance reranking +- Configurable `_search()` recipes for node-only and edge-only search + +The main migration work is not “can Graphiti search,” but: + +- matching the result shape expected by `zep_tools.py` +- preserving edge facts, node summaries, and score ordering closely enough for report quality + +### Full graph reads + +The current UI and simulation setup rely on full-graph reads, not just search. + +That means the local provider must expose: + +- list all nodes for a namespace +- list all edges for a namespace +- get one node by UUID +- get related edges for a node + +Recommendation: + +- Use direct Neo4j reads by `group_id` for graph-wide list endpoints +- Keep Graphiti itself focused on ingestion and search + +This is simpler than trying to force every UI read through search APIs. + +## Migration phases + +## Phase 0: Lock the target and scope + +Goal: + +- Avoid starting implementation against the wrong “local Zep” target + +Tasks: + +- Confirm the target is `Graphiti + Neo4j`, not deprecated Zep Community Edition +- Decide whether the migration must cover: + - existing historical graphs + - only newly created graphs + - both +- Decide whether local graph storage is the only goal, or whether local LLM/embedding providers are also required +- Define feature parity as: + - required for launch + - acceptable with minor quality drift + - allowed to defer + +Exit criteria: + +- One agreed target architecture +- One agreed feature-parity list +- One agreed data-migration scope + +## Phase 1: Prepare infrastructure + +Goal: + +- Make the repo able to run a local graph backend + +Tasks: + +- Add Neo4j service to `docker-compose.yml` +- Add persistent Neo4j volume +- Add backend env vars: + - `GRAPH_BACKEND` + - `NEO4J_URI` + - `NEO4J_USER` + - `NEO4J_PASSWORD` + - optional `GRAPHITI_TELEMETRY_ENABLED=false` + - optional `SEMAPHORE_LIMIT` +- Add Graphiti-compatible runtime deps to `backend/requirements.txt` and `backend/pyproject.toml` +- Install `graphiti-core` separately with `uv pip install --no-deps ...` to avoid the `neo4j` version conflict with `camel-oasis` +- Keep `zep-cloud` installed during dual-run +- Add backend startup initialization: + - connect to Graphiti + - call `build_indices_and_constraints()` once +- Update `.env.example`, `README.md`, and `README-EN.md` + +Recommended first Docker addition: + +- Neo4j 5.26+ image +- For this repo, keeping the enterprise variant in `docker-compose.yml` is the safer default if you may reuse volumes that were created with Neo4j block format +- Ports `7474` and `7687` +- Persistent `/data` and `/logs` volumes +- One shared Neo4j database with Graphiti `group_id` isolation per MiroFish graph + +Exit criteria: + +- Local `docker compose up` starts Neo4j and MiroFish +- Backend can connect to Neo4j +- Graphiti indices are created successfully + +## Phase 2: Add the provider abstraction without changing behavior + +Goal: + +- Isolate Zep-specific code before swapping implementations + +Tasks: + +- Create `graph_provider` interface and factory +- Move Zep client construction behind the provider +- Keep existing Zep behavior as the default implementation +- Refactor these services to depend on the provider instead of importing `zep_cloud` directly: + - `graph_builder.py` + - `zep_entity_reader.py` + - `zep_tools.py` + - `zep_graph_memory_updater.py` + - `oasis_profile_generator.py` +- Refactor `zep_paging.py` into provider-neutral graph read helpers or retire it in favor of provider methods + +Important rule: + +- Do not change API response shapes in this phase + +Exit criteria: + +- App still works exactly as before with `GRAPH_BACKEND=zep_cloud` +- Zep-specific imports are limited to the Zep provider module + +## Phase 3: Implement the local Graphiti provider + +Goal: + +- Support the same MiroFish workflows on a local graph backend + +Tasks: + +- Map `graph_id` to Graphiti `group_id` +- Implement `create_graph` as namespace bootstrap +- Reuse ontology normalization logic from current graph builder +- Convert MiroFish ontology into: + - Graphiti custom entity types + - Graphiti custom edge types + - Graphiti edge type map +- Implement chunk ingestion with `add_episode` +- Implement live simulation memory writes with `add_episode` +- Implement search using: + - edge-focused `_search()` recipe for fact retrieval + - node-focused `_search()` recipe for entity retrieval +- Implement graph-wide reads by `group_id` +- Implement delete-by-namespace logic + +Important local-behavior differences to handle: + +- Graphiti is namespace-based, not graph-object-based +- Ingestion lifecycle is different from Zep Cloud polling +- Search result objects will not be identical to Zep Cloud result objects + +Exit criteria: + +- A new graph can be built locally +- The UI can render graph nodes and edges +- Simulation prep can read filtered entities +- Report agent search returns usable facts + +## Phase 4: Data migration and backfill + +Goal: + +- Move old project graphs, not just new ones + +Recommended order of preference: + +### Option A: Re-ingest from original source documents + +This is the best option when the original uploaded material still exists. + +Why it is preferred: + +- It preserves the intended extraction pipeline +- It preserves ontology-guided classification +- It avoids lossy conversion from already-extracted node and edge summaries back into raw text + +Use this when: + +- the original uploaded text or PDF is still available +- the project ontology is still stored + +### Option B: Rebuild from exported facts and nodes + +Use this only for projects where the original source text is missing. + +Approach: + +- Export Zep Cloud nodes and edges with existing MiroFish read code +- Convert important edges into fact triples or synthetic episodes +- Rebuild the namespace in Graphiti + +Tradeoff: + +- Faster for stranded data +- Lower fidelity than re-ingesting original source material + +### Required migration script + +Create a script such as: + +- `backend/scripts/migrate_zep_cloud_to_graphiti.py` + +Suggested responsibilities: + +- list existing projects with `graph_id` +- detect whether original source text is available +- choose migration mode per project +- create local namespace +- ingest data +- validate node and edge counts +- write migration report JSON + +Suggested validation fields per project: + +- old graph id +- new group id +- migration mode used +- old node count +- new node count +- old edge count +- new edge count +- top 10 search comparison queries +- status +- error details if failed + +Exit criteria: + +- A representative sample of old projects has been migrated and validated + +## Phase 5: Dual-run and comparison + +Goal: + +- Prove that local results are good enough before cutover + +Tasks: + +- Add a temporary comparison mode +- For selected projects: + - build or migrate graph in both backends + - run the same search queries against both + - compare: + - returned facts + - node summaries + - graph statistics + - report quality +- Log mismatches for manual review + +Recommended comparison set: + +- graph build from a medium-size document +- entity list for simulation setup +- 10 report-agent queries from real historical runs +- live memory update during a short simulation + +Exit criteria: + +- No blocker regressions in core flows +- Search quality is acceptable for report generation + +## Phase 6: Cutover + +Goal: + +- Move production behavior to the local backend with low risk + +Tasks: + +- Keep both backends available behind `GRAPH_BACKEND` +- Start with local backend in dev only +- Then test on a small set of staging projects +- Then switch default backend for new graphs only +- After confidence is high, migrate old projects and switch all reads to local + +Safe cutover order: + +1. New graph builds go to local +2. New simulation live updates go to local +3. Report/search reads go to local +4. Historical projects are backfilled +5. Zep Cloud becomes fallback only + +Rollback: + +- Flip `GRAPH_BACKEND` back to `zep_cloud` +- Leave dual-write or dual-read disabled unless specifically needed + +## Phase 7: Cleanup + +Goal: + +- Remove cloud-only assumptions after the local backend is stable + +Tasks: + +- Remove `ZEP_API_KEY` from required config if no longer needed +- Remove `zep-cloud` dependency +- Remove Zep-specific code paths and helpers +- Rename files and classes so they are provider-neutral + - example: `zep_tools.py` -> `graph_tools.py` + - example: `ZepEntityReader` -> `GraphEntityReader` +- Update docs to describe local-first setup + +Exit criteria: + +- No runtime path depends on Zep Cloud +- Local setup is the documented default + +## File-level implementation plan + +### Config and infra + +- `backend/app/config.py` + - add `GRAPH_BACKEND` + - add `NEO4J_URI` + - add `NEO4J_USER` + - add `NEO4J_PASSWORD` + - stop hard-failing on missing `ZEP_API_KEY` when local backend is selected +- `.env.example` + - document both cloud and local modes +- `docker-compose.yml` + - add Neo4j service and volume +- `backend/requirements.txt` + - add Graphiti dependency +- `backend/pyproject.toml` + - add Graphiti dependency + +### New provider layer + +Recommended new files: + +- `backend/app/services/graph_provider/base.py` +- `backend/app/services/graph_provider/factory.py` +- `backend/app/services/graph_provider/zep_cloud_provider.py` +- `backend/app/services/graph_provider/graphiti_local_provider.py` +- `backend/app/services/graph_provider/models.py` + +### Existing service refactors + +- `backend/app/services/graph_builder.py` + - use provider for graph lifecycle and ingestion +- `backend/app/services/zep_entity_reader.py` + - make provider-neutral +- `backend/app/services/zep_tools.py` + - make provider-neutral +- `backend/app/services/zep_graph_memory_updater.py` + - write activities through provider +- `backend/app/services/oasis_profile_generator.py` + - search via provider +- `backend/app/api/graph.py` + - leave API shape stable +- `backend/app/api/simulation.py` + - leave API shape stable +- `backend/app/api/report.py` + - leave API shape stable + +### Frontend impact + +Frontend changes should be minimal if backend response shapes stay stable. + +Likely no required frontend changes beyond wording updates in docs or setup screens. + +## Testing plan + +### Unit tests + +- provider factory selection +- ontology normalization compatibility +- graph-id to group-id mapping +- node and edge shape normalization +- search result shape normalization + +### Integration tests + +- create graph locally +- ingest text chunks +- read all nodes and edges +- retrieve filtered entities +- run report search +- push live simulation activity and confirm graph updates +- delete namespace and confirm cleanup + +### Regression tests + +Use at least one real project fixture and compare: + +- node count difference stays within an agreed threshold +- edge count difference stays within an agreed threshold +- top search results are semantically comparable +- report output remains acceptable to a human reviewer + +## Risks and mitigation + +### Risk 1: Search quality differs from Zep Cloud + +Mitigation: + +- dual-run search comparisons +- tune Graphiti search recipes +- add fallback local keyword search only as backup, not primary behavior + +### Risk 2: Full graph reads are harder than search + +Mitigation: + +- use direct Neo4j namespace queries for UI graph rendering and entity listing + +### Risk 3: Async Graphiti calls complicate Flask integration + +Mitigation: + +- start with a sync adapter +- isolate async logic inside the provider + +### Risk 4: Old graphs cannot be migrated losslessly from summaries alone + +Mitigation: + +- prefer original source-document re-ingestion +- use fact-triple fallback only where necessary + +### Risk 5: Config sprawl during rollout + +Mitigation: + +- one `GRAPH_BACKEND` switch +- one local env block +- keep cloud env vars optional once local mode is supported + +## Acceptance criteria + +The migration can be called complete when all of the following are true: + +- MiroFish can build a project graph locally without Zep Cloud +- The graph viewer loads local nodes and edges correctly +- Simulation setup reads local graph entities correctly +- Report generation can retrieve relevant facts from the local graph +- Live simulation memory updates work against the local graph +- Existing important projects are migrated or rebuildable +- Zep Cloud is no longer required for normal operation + +## Recommended execution order + +If this work is implemented as an engineering project, the lowest-risk order is: + +1. Add infra and config +2. Add provider abstraction +3. Keep Zep Cloud as the default provider +4. Implement local Graphiti provider +5. Validate new graph creation locally +6. Validate report search locally +7. Validate simulation entity loading and live updates locally +8. Add migration/backfill script +9. Dual-run and compare +10. Cut over new graphs +11. Migrate old graphs +12. Remove Zep Cloud dependency + +## Bottom line + +This migration is very feasible, but it should be treated as a backend provider replacement project, not a config tweak. + +The key decisions that make it safe are: + +- use Graphiti plus local Neo4j as the supported local target +- keep `graph_id` as the app-level identifier and reuse it as Graphiti `group_id` +- add a provider abstraction before changing behavior +- prefer re-ingesting original source documents for old projects +- dual-run before cutover diff --git a/frontend/index.html b/frontend/index.html index 009c924a4..72f28baec 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -1,5 +1,5 @@ - + @@ -7,8 +7,8 @@ - - MiroFish - 预测万物 + + MiroFish - Predict Anything
diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 8c4fa710d..fee02cad8 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -1331,7 +1331,6 @@ "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz", "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", "license": "ISC", - "peer": true, "engines": { "node": ">=12" } @@ -1809,7 +1808,6 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -1943,7 +1941,6 @@ "integrity": "sha512-ITcnkFeR3+fI8P1wMgItjGrR10170d8auB4EpMLPqmx6uxElH3a/hHGQabSHKdqd4FXWO1nFIp9rRn7JQ34ACQ==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.5.0", @@ -2018,7 +2015,6 @@ "resolved": "https://registry.npmjs.org/vue/-/vue-3.5.25.tgz", "integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==", "license": "MIT", - "peer": true, "dependencies": { "@vue/compiler-dom": "3.5.25", "@vue/compiler-sfc": "3.5.25", diff --git a/frontend/src/App.vue b/frontend/src/App.vue index b7cd71ca6..a76fb0f97 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -3,11 +3,11 @@ \ No newline at end of file + diff --git a/frontend/src/components/Step4Report.vue b/frontend/src/components/Step4Report.vue index 22f2bdcfd..d4586fecd 100644 --- a/frontend/src/components/Step4Report.vue +++ b/frontend/src/components/Step4Report.vue @@ -58,7 +58,7 @@ - 正在生成{{ section.title }}... + Generating {{ section.title }}... @@ -127,9 +127,9 @@ - +