diff --git a/.env.example b/.env.example index 78a3b72c0..c2890f25e 100644 --- a/.env.example +++ b/.env.example @@ -1,16 +1,47 @@ -# LLM API配置(支持 OpenAI SDK 格式的任意 LLM API) -# 推荐使用阿里百炼平台qwen-plus模型:https://bailian.console.aliyun.com/ -# 注意消耗较大,可先进行小于40轮的模拟尝试 +# ===== LLM Provider ===== +# Supported values: auto | openai | anthropic | github-copilot | ollama | claude +# 'auto' detects the provider from the model name and available env vars. +LLM_PROVIDER=auto + +# ===== Option 1: Anthropic Claude (recommended for quality) ===== +# Get your key at: https://console.anthropic.com/settings/keys +# LLM_API_KEY=sk-ant-api03-... +# LLM_MODEL_NAME=claude-sonnet-4-20250514 +# Notes: +# - LLM_BASE_URL is not needed; the SDK uses the official endpoint automatically. +# - Model names starting with 'claude-' auto-select the anthropic provider. + +# ===== Option 2: Alibaba Bailian / Qwen (recommended for cost) ===== +# Get your key at: https://bailian.console.aliyun.com/ +# Note: high simulation rounds consume significant tokens — start with <40 rounds. LLM_API_KEY=your_api_key_here LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus -# ===== ZEP记忆图谱配置 ===== -# 每月免费额度即可支撑简单使用:https://app.getzep.com/ -ZEP_API_KEY=your_zep_api_key_here +# ===== Option 3: GitHub Copilot (no separate API key needed) ===== +# Uses your existing Copilot subscription. Rate limits are lower than direct APIs. 
+# LLM_PROVIDER=github-copilot +# GITHUB_TOKEN=ghp_your_github_pat_here # or use GH_TOKEN / COPILOT_GITHUB_TOKEN +# LLM_MODEL_NAME=gpt-4o + +# ===== Option 4: Ollama (fully local, no API key) ===== +# Requires Ollama running locally: https://ollama.ai +# LLM_PROVIDER=ollama +# LLM_BASE_URL=http://localhost:11434/v1 +# LLM_MODEL_NAME=llama3.2 -# ===== 加速 LLM 配置(可选)===== -# 注意如果不使用加速配置,env文件中就不要出现下面的配置项 -LLM_BOOST_API_KEY=your_api_key_here -LLM_BOOST_BASE_URL=your_base_url_here -LLM_BOOST_MODEL_NAME=your_model_name_here \ No newline at end of file +# ===== Option 5: Claude CLI subprocess ===== +# Uses the `claude` CLI — no API key needed beyond your Claude account login. +# LLM_PROVIDER=claude +# CLAUDE_CLI_PATH=claude +# LLM_MODEL_NAME=claude-sonnet-4-20250514 + +# ===== Optional: Boost LLM (faster/cheaper secondary model) ===== +# If unset, the primary LLM is used for all calls. +# LLM_BOOST_API_KEY=your_api_key_here +# LLM_BOOST_BASE_URL=your_base_url_here +# LLM_BOOST_MODEL_NAME=your_model_name_here + +# ===== Zep memory graph ===== +# Free tier is sufficient for basic use: https://app.getzep.com/ +ZEP_API_KEY=your_zep_api_key_here diff --git a/backend/app/__init__.py b/backend/app/__init__.py index aba624bba..4e891fee4 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -1,12 +1,12 @@ """ -MiroFish Backend - Flask应用工厂 +MiroFish Backend — Flask application factory. """ import os import warnings -# 抑制 multiprocessing resource_tracker 的警告(来自第三方库如 transformers) -# 需要在所有其他导入之前设置 +# Suppress multiprocessing resource_tracker warnings from third-party libs (e.g. transformers). +# Must be set before all other imports. 
warnings.filterwarnings("ignore", message=".*resource_tracker.*") from flask import Flask, request @@ -17,64 +17,62 @@ def create_app(config_class=Config): - """Flask应用工厂函数""" + """Create and configure the Flask application.""" app = Flask(__name__) app.config.from_object(config_class) - - # 设置JSON编码:确保中文直接显示(而不是 \uXXXX 格式) - # Flask >= 2.3 使用 app.json.ensure_ascii,旧版本使用 JSON_AS_ASCII 配置 + + # Ensure non-ASCII characters (e.g. Chinese) render as-is in JSON responses + # Flask >= 2.3 uses app.json.ensure_ascii; older versions use JSON_AS_ASCII config key if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'): app.json.ensure_ascii = False - - # 设置日志 + logger = setup_logger('mirofish') - - # 只在 reloader 子进程中打印启动信息(避免 debug 模式下打印两次) + + # Only log startup info once — avoid duplicate output in debug/reloader mode is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true' debug_mode = app.config.get('DEBUG', False) should_log_startup = not debug_mode or is_reloader_process - + if should_log_startup: logger.info("=" * 50) - logger.info("MiroFish Backend 启动中...") + logger.info("MiroFish Backend starting...") logger.info("=" * 50) - - # 启用CORS + + # Enable CORS for all API routes CORS(app, resources={r"/api/*": {"origins": "*"}}) - - # 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程) + + # Register simulation process cleanup on server shutdown from .services.simulation_runner import SimulationRunner SimulationRunner.register_cleanup() if should_log_startup: - logger.info("已注册模拟进程清理函数") - - # 请求日志中间件 + logger.info("Simulation process cleanup registered") + + # Request/response logging middleware @app.before_request def log_request(): - logger = get_logger('mirofish.request') - logger.debug(f"请求: {request.method} {request.path}") + req_logger = get_logger('mirofish.request') + req_logger.debug(f"Request: {request.method} {request.path}") if request.content_type and 'json' in request.content_type: - logger.debug(f"请求体: {request.get_json(silent=True)}") - + 
req_logger.debug(f"Body: {request.get_json(silent=True)}") + @app.after_request def log_response(response): - logger = get_logger('mirofish.request') - logger.debug(f"响应: {response.status_code}") + req_logger = get_logger('mirofish.request') + req_logger.debug(f"Response: {response.status_code}") return response - - # 注册蓝图 + + # Register blueprints from .api import graph_bp, simulation_bp, report_bp app.register_blueprint(graph_bp, url_prefix='/api/graph') app.register_blueprint(simulation_bp, url_prefix='/api/simulation') app.register_blueprint(report_bp, url_prefix='/api/report') - - # 健康检查 + + # Health check @app.route('/health') def health(): return {'status': 'ok', 'service': 'MiroFish Backend'} - + if should_log_startup: - logger.info("MiroFish Backend 启动完成") - - return app + logger.info("MiroFish Backend ready") + return app diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index ffda743a3..7e8ba0265 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -1,5 +1,5 @@ """ -API路由模块 +API route modules. """ from flask import Blueprint diff --git a/backend/app/config.py b/backend/app/config.py index 953dfa50a..c069d2aa8 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,75 +1,159 @@ """ -配置管理 -统一从项目根目录的 .env 文件加载配置 +Configuration management. +Loads settings from the .env file at the project root (MiroFish/.env). 
""" import os +import shutil from dotenv import load_dotenv -# 加载项目根目录的 .env 文件 -# 路径: MiroFish/.env (相对于 backend/app/config.py) -project_root_env = os.path.join(os.path.dirname(__file__), '../../.env') +# Load .env from project root — path relative to backend/app/config.py +_project_root_env = os.path.join(os.path.dirname(__file__), '../../.env') -if os.path.exists(project_root_env): - load_dotenv(project_root_env, override=True) +if os.path.exists(_project_root_env): + load_dotenv(_project_root_env, override=True) else: - # 如果根目录没有 .env,尝试加载环境变量(用于生产环境) + # No .env at root — fall back to environment variables (production mode) load_dotenv(override=True) class Config: - """Flask配置类""" - - # Flask配置 + """Flask configuration class.""" + + # ── Flask ──────────────────────────────────────────────────────────────── SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key') DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true' - - # JSON配置 - 禁用ASCII转义,让中文直接显示(而不是 \uXXXX 格式) + + # Disable ASCII escaping so non-ASCII characters display correctly in JSON JSON_AS_ASCII = False - - # LLM配置(统一使用OpenAI格式) + + # ── LLM provider selection ──────────────────────────────────────────────── + # Supported values: + # auto (default) — infer from LLM_MODEL_NAME and available env vars + # openai — OpenAI SDK (works with any OpenAI-compatible API) + # anthropic — Anthropic Claude SDK (direct API, requires LLM_API_KEY) + # github-copilot — GitHub Copilot via token exchange (requires GITHUB_TOKEN) + # ollama — Local Ollama via OpenAI-compatible endpoint + # claude — Claude CLI subprocess (requires `claude` CLI installed) + LLM_PROVIDER = os.environ.get('LLM_PROVIDER', 'auto') + + # ── LLM credentials & model ────────────────────────────────────────────── LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') - - # Zep配置 + + # Optional: 
faster/cheaper secondary LLM for non-critical calls + LLM_BOOST_API_KEY = os.environ.get('LLM_BOOST_API_KEY') + LLM_BOOST_BASE_URL = os.environ.get('LLM_BOOST_BASE_URL') + LLM_BOOST_MODEL_NAME = os.environ.get('LLM_BOOST_MODEL_NAME') + + # ── Provider-specific settings ─────────────────────────────────────────── + # Claude CLI: path to the `claude` executable (used when LLM_PROVIDER=claude) + CLAUDE_CLI_PATH = os.environ.get('CLAUDE_CLI_PATH', 'claude') + + # GitHub Copilot: token env vars (checked in priority order) + # COPILOT_GITHUB_TOKEN > GH_TOKEN > GITHUB_TOKEN + GITHUB_TOKEN = ( + os.environ.get('COPILOT_GITHUB_TOKEN') + or os.environ.get('GH_TOKEN') + or os.environ.get('GITHUB_TOKEN') + ) + + @classmethod + def get_resolved_provider(cls) -> str: + """ + Resolve the effective LLM provider. + + In 'auto' mode the provider is inferred: + 1. Model name starts with 'claude-' → anthropic + 2. GitHub token present, no API key → github-copilot + 3. Otherwise → openai + """ + provider = (cls.LLM_PROVIDER or 'auto').lower() + if provider != 'auto': + return provider + + model = (cls.LLM_MODEL_NAME or '').lower() + if model.startswith('claude-'): + return 'anthropic' + if cls.GITHUB_TOKEN and not cls.LLM_API_KEY: + return 'github-copilot' + return 'openai' + + # ── Zep memory graph ───────────────────────────────────────────────────── ZEP_API_KEY = os.environ.get('ZEP_API_KEY') - - # 文件上传配置 - MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB + + # ── File uploads ───────────────────────────────────────────────────────── + MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50 MB UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads') ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'} - - # 文本处理配置 - DEFAULT_CHUNK_SIZE = 500 # 默认切块大小 - DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小 - - # OASIS模拟配置 + + # ── Text chunking ──────────────────────────────────────────────────────── + DEFAULT_CHUNK_SIZE = 500 + DEFAULT_CHUNK_OVERLAP = 50 + + # ── OASIS simulation 
───────────────────────────────────────────────────── OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10')) - OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations') - - # OASIS平台可用动作配置 + OASIS_SIMULATION_DATA_DIR = os.path.join( + os.path.dirname(__file__), '../uploads/simulations' + ) + OASIS_TWITTER_ACTIONS = [ - 'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST' + 'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST', ] OASIS_REDDIT_ACTIONS = [ 'LIKE_POST', 'DISLIKE_POST', 'CREATE_POST', 'CREATE_COMMENT', 'LIKE_COMMENT', 'DISLIKE_COMMENT', 'SEARCH_POSTS', 'SEARCH_USER', - 'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE' + 'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE', ] - - # Report Agent配置 + + # ── Report Agent ───────────────────────────────────────────────────────── REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5')) - REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2')) + REPORT_AGENT_MAX_REFLECTION_ROUNDS = int( + os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2') + ) REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5')) - + @classmethod def validate(cls): - """验证必要配置""" + """Validate required configuration and return a list of error strings.""" + import logging + log = logging.getLogger('mirofish.config') errors = [] - if not cls.LLM_API_KEY: - errors.append("LLM_API_KEY 未配置") + provider = cls.get_resolved_provider() + + if provider == 'anthropic': + if not cls.LLM_API_KEY: + errors.append( + "LLM_API_KEY is required for the anthropic provider. 
" + "Get a key at https://console.anthropic.com/settings/keys" + ) + elif provider == 'openai': + if not cls.LLM_API_KEY: + errors.append("LLM_API_KEY is not configured (required for openai provider)") + elif provider == 'github-copilot': + if not cls.GITHUB_TOKEN: + errors.append( + "github-copilot provider requires a GitHub token. " + "Set COPILOT_GITHUB_TOKEN, GH_TOKEN, or GITHUB_TOKEN in .env." + ) + elif provider == 'ollama': + if cls.LLM_BASE_URL and '11434' not in cls.LLM_BASE_URL: + log.warning( + "LLM_BASE_URL (%s) doesn't include port 11434; " + "make sure Ollama is reachable at that address.", + cls.LLM_BASE_URL, + ) + elif provider == 'claude': + if shutil.which(cls.CLAUDE_CLI_PATH) is None: + errors.append( + f"Claude CLI not found at '{cls.CLAUDE_CLI_PATH}'. " + "Install it (https://claude.ai/download) or set CLAUDE_CLI_PATH in .env." + ) + if not cls.ZEP_API_KEY: - errors.append("ZEP_API_KEY 未配置") + errors.append( + "ZEP_API_KEY is not configured. " + "Get a free key at https://app.getzep.com/" + ) return errors - diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 57836c539..f4e24d978 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -15,10 +15,10 @@ from dataclasses import dataclass, field from datetime import datetime -from openai import OpenAI from zep_cloud.client import Zep from ..config import Config +from ..utils.llm_client import LLMClient from ..utils.logger import get_logger from .zep_entity_reader import EntityNode, ZepEntityReader @@ -188,13 +188,14 @@ def __init__( self.api_key = api_key or Config.LLM_API_KEY self.base_url = base_url or Config.LLM_BASE_URL self.model_name = model_name or Config.LLM_MODEL_NAME - + if not self.api_key: raise ValueError("LLM_API_KEY 未配置") - - self.client = OpenAI( + + self.client = LLMClient( api_key=self.api_key, - base_url=self.base_url + base_url=self.base_url, + 
model=self.model_name ) # Zep客户端用于检索丰富上下文 @@ -526,8 +527,7 @@ def _generate_profile_with_llm( for attempt in range(max_attempts): try: - response = self.client.chat.completions.create( - model=self.model_name, + content, finish_reason = self.client.chat_with_finish_reason( messages=[ {"role": "system", "content": self._get_system_prompt(is_individual)}, {"role": "user", "content": prompt} @@ -536,11 +536,8 @@ def _generate_profile_with_llm( temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 # 不设置max_tokens,让LLM自由发挥 ) - - content = response.choices[0].message.content - + # 检查是否被截断(finish_reason不是'stop') - finish_reason = response.choices[0].finish_reason if finish_reason == 'length': logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...") content = self._fix_truncated_json(content) diff --git a/backend/app/services/simulation_config_generator.py b/backend/app/services/simulation_config_generator.py index cc362508b..be8b0bcab 100644 --- a/backend/app/services/simulation_config_generator.py +++ b/backend/app/services/simulation_config_generator.py @@ -16,9 +16,8 @@ from dataclasses import dataclass, field, asdict from datetime import datetime -from openai import OpenAI - from ..config import Config +from ..utils.llm_client import LLMClient from ..utils.logger import get_logger from .zep_entity_reader import EntityNode, ZepEntityReader @@ -230,13 +229,14 @@ def __init__( self.api_key = api_key or Config.LLM_API_KEY self.base_url = base_url or Config.LLM_BASE_URL self.model_name = model_name or Config.LLM_MODEL_NAME - + if not self.api_key: raise ValueError("LLM_API_KEY 未配置") - - self.client = OpenAI( + + self.client = LLMClient( api_key=self.api_key, - base_url=self.base_url + base_url=self.base_url, + model=self.model_name ) def generate_config( @@ -439,8 +439,7 @@ def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any for attempt in range(max_attempts): try: - response = self.client.chat.completions.create( - model=self.model_name, + content, 
finish_reason = self.client.chat_with_finish_reason( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt} @@ -450,9 +449,6 @@ def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any # 不设置max_tokens,让LLM自由发挥 ) - content = response.choices[0].message.content - finish_reason = response.choices[0].finish_reason - # 检查是否被截断 if finish_reason == 'length': logger.warning(f"LLM输出被截断 (attempt {attempt+1})") diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index 8c35380d1..929087945 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -215,27 +215,30 @@ class SimulationRunner: '../../scripts' ) - # 内存中的运行状态 + # In-memory run state — all shared dicts protected by _state_lock + _state_lock: threading.RLock = threading.RLock() _run_states: Dict[str, SimulationRunState] = {} _processes: Dict[str, subprocess.Popen] = {} _action_queues: Dict[str, Queue] = {} _monitor_threads: Dict[str, threading.Thread] = {} - _stdout_files: Dict[str, Any] = {} # 存储 stdout 文件句柄 - _stderr_files: Dict[str, Any] = {} # 存储 stderr 文件句柄 - - # 图谱记忆更新配置 + _stdout_files: Dict[str, Any] = {} + _stderr_files: Dict[str, Any] = {} + + # Graph memory update config _graph_memory_enabled: Dict[str, bool] = {} # simulation_id -> enabled @classmethod def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: - """获取运行状态""" - if simulation_id in cls._run_states: - return cls._run_states[simulation_id] - - # 尝试从文件加载 + """Get run state for a simulation, loading from disk if not in memory.""" + with cls._state_lock: + if simulation_id in cls._run_states: + return cls._run_states[simulation_id] + + # File I/O done outside the lock state = cls._load_run_state(simulation_id) if state: - cls._run_states[simulation_id] = state + with cls._state_lock: + cls._run_states[simulation_id] = state return state @classmethod @@ -305,8 +308,10 @@ def 
_save_run_state(cls, state: SimulationRunState): with open(state_file, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) - - cls._run_states[state.simulation_id] = state + + # Update in-memory cache after file write + with cls._state_lock: + cls._run_states[state.simulation_id] = state @classmethod def start_simulation( @@ -446,23 +451,23 @@ def start_simulation( start_new_session=True, # 创建新进程组,确保服务器关闭时能终止所有相关进程 ) - # 保存文件句柄以便后续关闭 - cls._stdout_files[simulation_id] = main_log_file - cls._stderr_files[simulation_id] = None # 不再需要单独的 stderr - state.process_pid = process.pid state.runner_status = RunnerStatus.RUNNING - cls._processes[simulation_id] = process + with cls._state_lock: + cls._stdout_files[simulation_id] = main_log_file + cls._stderr_files[simulation_id] = None + cls._processes[simulation_id] = process cls._save_run_state(state) - - # 启动监控线程 + + # Start monitor thread monitor_thread = threading.Thread( target=cls._monitor_simulation, args=(simulation_id,), daemon=True ) monitor_thread.start() - cls._monitor_threads[simulation_id] = monitor_thread + with cls._state_lock: + cls._monitor_threads[simulation_id] = monitor_thread logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}") @@ -483,9 +488,10 @@ def _monitor_simulation(cls, simulation_id: str): twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl") reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl") - process = cls._processes.get(simulation_id) + with cls._state_lock: + process = cls._processes.get(simulation_id) state = cls.get_run_state(simulation_id) - + if not process or not state: return @@ -548,32 +554,30 @@ def _monitor_simulation(cls, simulation_id: str): cls._save_run_state(state) finally: - # 停止图谱记忆更新器 - if cls._graph_memory_enabled.get(simulation_id, False): + # Stop graph memory updater (outside lock — may do I/O) + with cls._state_lock: + graph_mem_on = cls._graph_memory_enabled.pop(simulation_id, 
False) + if graph_mem_on: try: ZepGraphMemoryManager.stop_updater(simulation_id) - logger.info(f"已停止图谱记忆更新: simulation_id={simulation_id}") + logger.info(f"Graph memory updater stopped: simulation_id={simulation_id}") except Exception as e: - logger.error(f"停止图谱记忆更新器失败: {e}") - cls._graph_memory_enabled.pop(simulation_id, None) - - # 清理进程资源 - cls._processes.pop(simulation_id, None) - cls._action_queues.pop(simulation_id, None) - - # 关闭日志文件句柄 - if simulation_id in cls._stdout_files: - try: - cls._stdout_files[simulation_id].close() - except Exception: - pass - cls._stdout_files.pop(simulation_id, None) - if simulation_id in cls._stderr_files and cls._stderr_files[simulation_id]: - try: - cls._stderr_files[simulation_id].close() - except Exception: - pass - cls._stderr_files.pop(simulation_id, None) + logger.error(f"Failed to stop graph memory updater: {e}") + + # Release process/queue refs and grab file handles — all under lock + with cls._state_lock: + cls._processes.pop(simulation_id, None) + cls._action_queues.pop(simulation_id, None) + stdout_fh = cls._stdout_files.pop(simulation_id, None) + stderr_fh = cls._stderr_files.pop(simulation_id, None) + + # Close file handles outside the lock + for fh in (stdout_fh, stderr_fh): + if fh: + try: + fh.close() + except Exception: + pass @classmethod def _read_action_log( diff --git a/backend/app/utils/copilot_auth.py b/backend/app/utils/copilot_auth.py new file mode 100644 index 000000000..eaaf22493 --- /dev/null +++ b/backend/app/utils/copilot_auth.py @@ -0,0 +1,191 @@ +""" +GitHub Copilot token exchange. + +Auth flow: + 1. Take a GitHub PAT (from COPILOT_GITHUB_TOKEN / GH_TOKEN / GITHUB_TOKEN) + 2. POST to api.github.com/copilot_internal/v2/token → short-lived Copilot token + 3. Derive the OpenAI-compatible base URL from the token's proxy-ep field + 4. Cache the token in memory and on disk; auto-refresh before expiry + +Note: Uses the same undocumented endpoint as VS Code Copilot. Tokens expire in ~30 min. 
+""" + +import json +import os +import re +import threading +import time +from typing import Optional + +from .logger import get_logger + +logger = get_logger('mirofish.copilot_auth') + +COPILOT_TOKEN_URL = "https://api.github.com/copilot_internal/v2/token" +DEFAULT_COPILOT_BASE_URL = "https://api.individual.githubcopilot.com" +TOKEN_REFRESH_MARGIN = 5 * 60 # refresh 5 min before expiry + +# Headers required to pass GitHub's Copilot client check +COPILOT_REQUEST_HEADERS = { + "Accept": "application/json", + "Editor-Version": "vscode/1.96.2", + "User-Agent": "GitHubCopilotChat/0.26.7", + "X-Github-Api-Version": "2025-04-01", +} + + +class CopilotTokenManager: + """ + Manages GitHub Copilot API token lifecycle. + Thread-safe; caches token in memory and optionally on disk. + + Usage: + mgr = CopilotTokenManager(github_token="ghp_xxx") + api_key = mgr.get_api_key() + base_url = mgr.get_base_url() + """ + + def __init__(self, github_token: str, cache_dir: Optional[str] = None): + self._github_token = github_token + self._lock = threading.Lock() + self._cached_token: Optional[str] = None + self._cached_base_url: str = DEFAULT_COPILOT_BASE_URL + self._expires_at: float = 0 + + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + self._cache_path = os.path.join(cache_dir, "copilot-token.json") + else: + self._cache_path = None + + self._load_disk_cache() + + def get_api_key(self) -> str: + """Return a valid Copilot API token, refreshing if necessary. Thread-safe.""" + with self._lock: + if self._is_valid(): + return self._cached_token + return self._refresh() + + def get_base_url(self) -> str: + """Return the Copilot API base URL. 
Ensures a valid token exists first.""" + self.get_api_key() + with self._lock: + return self._cached_base_url + + def _is_valid(self) -> bool: + return bool(self._cached_token) and time.time() < (self._expires_at - TOKEN_REFRESH_MARGIN) + + def _refresh(self) -> str: + import urllib.request + import urllib.error + + logger.info("Exchanging GitHub token for Copilot API token...") + headers = {**COPILOT_REQUEST_HEADERS, "Authorization": f"Bearer {self._github_token}"} + req = urllib.request.Request(COPILOT_TOKEN_URL, method="GET", headers=headers) + + try: + with urllib.request.urlopen(req, timeout=15) as resp: + data = json.loads(resp.read().decode()) + except urllib.error.HTTPError as e: + body = (e.read() or b"").decode("utf-8", errors="replace") + raise RuntimeError( + f"Copilot token exchange failed (HTTP {e.code}): {body}. " + "Check your GitHub token and Copilot subscription." + ) from e + except Exception as e: + raise RuntimeError(f"Copilot token exchange error: {e}") from e + + token = data.get("token", "") + if not token: + raise RuntimeError( + "Copilot token exchange returned empty token. " + "Check your GitHub Copilot subscription status." 
+ ) + + expires_at = float(data.get("expires_at", time.time() + 1800)) + base_url = self._derive_base_url(token) + + with self._lock: + self._cached_token = token + self._expires_at = expires_at + self._cached_base_url = base_url + + self._save_disk_cache(token, expires_at) + logger.info( + "Copilot token acquired (expires in %d min), base_url=%s", + int((expires_at - time.time()) / 60), + base_url, + ) + return token + + @staticmethod + def _derive_base_url(token: str) -> str: + """Derive API base URL from the proxy-ep field embedded in the token.""" + match = re.search(r"(?:^|;)\s*proxy-ep=([^;\s]+)", token, re.IGNORECASE) + if not match: + return DEFAULT_COPILOT_BASE_URL + host = re.sub(r"^https?://", "", match.group(1).strip()) + host = re.sub(r"^proxy\.", "api.", host, flags=re.IGNORECASE) + return f"https://{host}" if host else DEFAULT_COPILOT_BASE_URL + + def _load_disk_cache(self): + if not self._cache_path or not os.path.exists(self._cache_path): + return + try: + with open(self._cache_path) as f: + data = json.load(f) + token = data.get("token", "") + expires_at = data.get("expires_at", 0) + if token and time.time() < (expires_at - TOKEN_REFRESH_MARGIN): + self._cached_token = token + self._expires_at = expires_at + self._cached_base_url = self._derive_base_url(token) + logger.debug("Loaded Copilot token from disk cache") + except Exception as e: + logger.debug("Could not load Copilot token cache: %s", e) + + def _save_disk_cache(self, token: str, expires_at: float): + if not self._cache_path: + return + try: + with open(self._cache_path, "w") as f: + json.dump({"token": token, "expires_at": expires_at, "updated_at": time.time()}, f) + except Exception as e: + logger.debug("Could not save Copilot token cache: %s", e) + + +def resolve_github_token() -> Optional[str]: + """Check COPILOT_GITHUB_TOKEN, GH_TOKEN, GITHUB_TOKEN in priority order.""" + for var in ("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN"): + token = os.environ.get(var, "").strip() + if 
token: + logger.debug("GitHub token resolved from %s", var) + return token + return None + + +# Module-level singleton +_manager: Optional[CopilotTokenManager] = None +_manager_lock = threading.Lock() + + +def get_copilot_token_manager() -> CopilotTokenManager: + """Return the singleton CopilotTokenManager, creating it on first call.""" + global _manager + if _manager is not None: + return _manager + with _manager_lock: + if _manager is not None: + return _manager + github_token = resolve_github_token() + if not github_token: + raise RuntimeError( + "github-copilot provider requires a GitHub token. " + "Set COPILOT_GITHUB_TOKEN, GH_TOKEN, or GITHUB_TOKEN in .env." + ) + cache_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "uploads", ".copilot-cache" + ) + _manager = CopilotTokenManager(github_token=github_token, cache_dir=cache_dir) + return _manager diff --git a/backend/app/utils/llm_client.py b/backend/app/utils/llm_client.py index 6c1a81f49..9acbf76e7 100644 --- a/backend/app/utils/llm_client.py +++ b/backend/app/utils/llm_client.py @@ -1,103 +1,594 @@ """ -LLM客户端封装 -统一使用OpenAI格式调用 +LLM client wrapper. +Supports five providers: openai, anthropic, github-copilot, ollama, claude (CLI). """ import json import re -from typing import Optional, Dict, Any, List -from openai import OpenAI +import subprocess +from typing import Optional, Dict, Any, List, Tuple from ..config import Config +from .logger import get_logger +logger = get_logger('mirofish.llm_client') + + +# ── Token estimation (no tiktoken dependency) ──────────────────────────────── + +def estimate_tokens(text: str) -> int: + """ + Estimate token count without tiktoken. 
+ + Rules: + - CJK characters: ~1.5 tokens each (UTF-8 Han chars are typically 1-2 tokens) + - ASCII words: ~1.3 tokens each + - +10% safety buffer + """ + if not text: + return 0 + cjk_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') + rest = ''.join(c if c < '\u4e00' or c > '\u9fff' else ' ' for c in text) + ascii_words = len(rest.split()) + return int((cjk_chars * 1.5 + ascii_words * 1.3) * 1.1) + + +def estimate_messages_tokens(messages: List[Dict]) -> int: + """Estimate total token count for a message list.""" + total = 0 + for msg in messages: + content = msg.get("content") or "" + if isinstance(content, list): + content = " ".join( + c.get("text", "") for c in content if isinstance(c, dict) + ) + total += estimate_tokens(content) + 4 # role + per-message overhead + return total + 2 # conversation overhead + + +# ── Exceptions ─────────────────────────────────────────────────────────────── + +class LLMError(ValueError): + """LLM call error. Extends ValueError for backward compatibility.""" + pass + + +# ── Main client ────────────────────────────────────────────────────────────── class LLMClient: - """LLM客户端""" - + """ + Unified LLM client. Provider is resolved at construction time. + + Supported providers (set LLM_PROVIDER in .env): + openai — Any OpenAI-compatible API (OpenAI, Qwen, DeepSeek, etc.) 
+ anthropic — Anthropic Claude SDK (direct API) + github-copilot — GitHub Copilot subscription via token exchange + ollama — Local Ollama via OpenAI-compatible endpoint + claude — Claude CLI subprocess (claude -p) + auto (default)— Auto-detect from model name and environment + """ + def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, - model: Optional[str] = None + model: Optional[str] = None, + provider: Optional[str] = None, ): + self.model = model or Config.LLM_MODEL_NAME + self.provider = (provider or Config.get_resolved_provider()).lower() + + if self.provider == 'github-copilot': + self._init_copilot(base_url) + elif self.provider == 'anthropic': + self._init_anthropic(api_key) + elif self.provider == 'ollama': + self._init_ollama(base_url) + elif self.provider == 'claude': + self._init_claude_cli() + else: + self._init_openai(api_key, base_url) + + # ── Initializers ───────────────────────────────────────────────────────── + + def _init_openai(self, api_key: Optional[str], base_url: Optional[str]): + from openai import OpenAI self.api_key = api_key or Config.LLM_API_KEY self.base_url = base_url or Config.LLM_BASE_URL - self.model = model or Config.LLM_MODEL_NAME - if not self.api_key: - raise ValueError("LLM_API_KEY 未配置") - - self.client = OpenAI( + raise LLMError( + "LLM_API_KEY is not configured. " + "Set it in .env or choose a different LLM_PROVIDER." + ) + self._client = OpenAI(api_key=self.api_key, base_url=self.base_url) + + def _init_anthropic(self, api_key: Optional[str]): + from anthropic import Anthropic + self.api_key = api_key or Config.LLM_API_KEY + if not self.api_key: + raise LLMError( + "LLM_API_KEY is not configured for Anthropic provider. 
" + "Get a key at https://console.anthropic.com/settings/keys" + ) + self._client = Anthropic(api_key=self.api_key) + + def _init_copilot(self, base_url: Optional[str]): + from openai import OpenAI + from .copilot_auth import ( + get_copilot_token_manager, + COPILOT_REQUEST_HEADERS, + ) + self._copilot_mgr = get_copilot_token_manager() + self.api_key = self._copilot_mgr.get_api_key() + self.base_url = base_url or self._copilot_mgr.get_base_url() + self._copilot_headers = { + "Editor-Version": COPILOT_REQUEST_HEADERS["Editor-Version"], + "User-Agent": COPILOT_REQUEST_HEADERS["User-Agent"], + "X-Github-Api-Version": COPILOT_REQUEST_HEADERS["X-Github-Api-Version"], + } + self._client = OpenAI( api_key=self.api_key, - base_url=self.base_url + base_url=self.base_url, + default_headers=self._copilot_headers, ) - + + def _init_ollama(self, base_url: Optional[str]): + from openai import OpenAI + self.api_key = "ollama" # Ollama doesn't require auth + self.base_url = base_url or Config.LLM_BASE_URL or "http://localhost:11434/v1" + self._client = OpenAI(api_key=self.api_key, base_url=self.base_url) + + def _init_claude_cli(self): + self.api_key = None + self.base_url = None + self._client = None # no SDK client needed + + # ── Internal helpers ───────────────────────────────────────────────────── + + @staticmethod + def _clean(content: str) -> str: + """Remove tags and strip whitespace.""" + return re.sub(r'[\s\S]*?', '', content).strip() + + @staticmethod + def _clean_json(content: str) -> str: + """Strip markdown code fences from a JSON response.""" + s = content.strip() + s = re.sub(r'^```(?:json)?\s*\n?', '', s, flags=re.IGNORECASE) + s = re.sub(r'\n?```\s*$', '', s) + return s.strip() + + def _extract_system( + self, messages: List[Dict[str, str]] + ) -> Tuple[Optional[str], List[Dict[str, str]]]: + """Split system messages from the rest (needed for Anthropic API).""" + system_parts, rest = [], [] + for m in messages: + if m.get("role") == "system": + 
system_parts.append(m["content"]) + else: + rest.append(m) + return ("\n\n".join(system_parts) if system_parts else None), rest + + def _flatten_messages( + self, messages: List[Dict[str, str]] + ) -> Tuple[str, str]: + """Flatten multi-turn messages into (system_prompt, user_prompt) strings for Claude CLI.""" + system_parts, convo_parts = [], [] + for m in messages: + role = m.get("role", "user") + content = m.get("content", "") + if role == "system": + system_parts.append(content) + elif role == "assistant": + convo_parts.append(f"[Assistant]\n{content}") + else: + convo_parts.append(f"[User]\n{content}") + return "\n\n".join(system_parts), "\n\n".join(convo_parts) + + def _refresh_copilot_client(self): + """Refresh the Copilot token if expired and rebuild the OpenAI client.""" + from openai import OpenAI + new_key = self._copilot_mgr.get_api_key() + if new_key != self.api_key: + self.api_key = new_key + self.base_url = self._copilot_mgr.get_base_url() + self._client = OpenAI( + api_key=self.api_key, + base_url=self.base_url, + default_headers=self._copilot_headers, + ) + + # ── Provider-specific call implementations ─────────────────────────────── + + def _call_openai_sdk( + self, + messages: List[Dict[str, str]], + temperature: float, + max_tokens: Optional[int], + response_format: Optional[Dict] = None, + ) -> str: + """Shared call path for openai / ollama / github-copilot.""" + if self.provider == 'github-copilot': + self._refresh_copilot_client() + + kwargs: Dict[str, Any] = { + "model": self.model, + "messages": messages, + "temperature": temperature, + } + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens + if response_format: + kwargs["response_format"] = response_format + + response = self._client.chat.completions.create(**kwargs) + return response.choices[0].message.content + + def _call_openai_sdk_with_finish( + self, + messages: List[Dict[str, str]], + temperature: float, + max_tokens: Optional[int], + response_format: Optional[Dict] = 
None, + ) -> Tuple[str, str]: + if self.provider == 'github-copilot': + self._refresh_copilot_client() + + kwargs: Dict[str, Any] = { + "model": self.model, + "messages": messages, + "temperature": temperature, + } + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens + if response_format: + kwargs["response_format"] = response_format + + response = self._client.chat.completions.create(**kwargs) + choice = response.choices[0] + return choice.message.content, (choice.finish_reason or 'stop') + + def _call_anthropic_sdk( + self, + messages: List[Dict[str, str]], + temperature: float, + max_tokens: int, + json_mode: bool = False, + ) -> str: + system_text, user_messages = self._extract_system(messages) + if json_mode: + suffix = "\n\nIMPORTANT: Respond with valid JSON only — no markdown, no explanation." + system_text = (system_text + suffix) if system_text else suffix.strip() + + kwargs: Dict[str, Any] = { + "model": self.model, + "messages": user_messages, + "temperature": temperature, + "max_tokens": max_tokens, + } + if system_text: + kwargs["system"] = system_text + + response = self._client.messages.create(**kwargs) + return response.content[0].text + + def _call_anthropic_sdk_with_finish( + self, + messages: List[Dict[str, str]], + temperature: float, + max_tokens: int, + json_mode: bool = False, + ) -> Tuple[str, str]: + content = self._call_anthropic_sdk(messages, temperature, max_tokens, json_mode) + # Anthropic stop_reason: 'end_turn' → 'stop', 'max_tokens' → 'length' + return content, 'stop' # finish reason from _call_anthropic_sdk always ends normally + + def _call_claude_cli( + self, + messages: List[Dict[str, str]], + max_tokens: int = 4096, + ) -> str: + system_prompt, user_prompt = self._flatten_messages(messages) + + cmd = [ + Config.CLAUDE_CLI_PATH, + "-p", "--bare", + "--output-format", "json", + ] + if system_prompt: + cmd.extend(["--system-prompt", system_prompt]) + if self.model: + cmd.extend(["--model", self.model]) + + try: + result 
= subprocess.run( + cmd, + input=user_prompt, + capture_output=True, + text=True, + timeout=300, + ) + except subprocess.TimeoutExpired: + raise LLMError("Claude CLI timed out after 300 seconds") + except FileNotFoundError: + raise LLMError( + f"Claude CLI not found at '{Config.CLAUDE_CLI_PATH}'. " + "Install claude CLI or set CLAUDE_CLI_PATH in .env." + ) + + if result.returncode != 0: + stderr = (result.stderr or "").strip() or "(no stderr)" + raise LLMError(f"Claude CLI exit code {result.returncode}: {stderr}") + + stdout = result.stdout.strip() + if not stdout: + raise LLMError("Claude CLI returned empty output") + + try: + envelope = json.loads(stdout) + if isinstance(envelope, dict): + if envelope.get("is_error"): + raise LLMError(f"Claude CLI error: {envelope.get('result', 'unknown')}") + return str(envelope.get("result", stdout)) + except json.JSONDecodeError: + pass + + return stdout + + # ── Public API ──────────────────────────────────────────────────────────── + def chat( self, messages: List[Dict[str, str]], temperature: float = 0.7, max_tokens: int = 4096, - response_format: Optional[Dict] = None + response_format: Optional[Dict] = None, ) -> str: """ - 发送聊天请求 - + Send a chat request and return the response text. + Args: - messages: 消息列表 - temperature: 温度参数 - max_tokens: 最大token数 - response_format: 响应格式(如JSON模式) - + messages: Message list (OpenAI format — system/user/assistant roles) + temperature: Sampling temperature + max_tokens: Maximum output tokens + response_format: {"type": "json_object"} for JSON mode (OpenAI providers). + Anthropic and Claude CLI adapt automatically. + + Returns: + Model response as a string. 
+ """ + json_mode = bool(response_format and response_format.get("type") == "json_object") + + if self.provider == 'anthropic': + content = self._call_anthropic_sdk(messages, temperature, max_tokens, json_mode) + elif self.provider == 'claude': + content = self._call_claude_cli(messages, max_tokens) + else: + content = self._call_openai_sdk(messages, temperature, max_tokens, response_format) + + return self._clean(content) + + def chat_with_finish_reason( + self, + messages: List[Dict[str, str]], + temperature: float = 0.7, + max_tokens: Optional[int] = None, + response_format: Optional[Dict] = None, + ) -> Tuple[str, str]: + """ + Like chat(), but also returns the finish reason. + + finish_reason values (normalised to OpenAI format): + 'stop' — completed normally + 'length' — truncated at max_tokens + Returns: - 模型响应文本 + (response_text, finish_reason) + """ + json_mode = bool(response_format and response_format.get("type") == "json_object") + effective_max = max_tokens or 4096 + + if self.provider == 'anthropic': + system_text, user_messages = self._extract_system(messages) + if json_mode: + suffix = "\n\nIMPORTANT: Respond with valid JSON only — no markdown, no explanation." 
+ system_text = (system_text + suffix) if system_text else suffix.strip() + + kwargs: Dict[str, Any] = { + "model": self.model, + "messages": user_messages, + "temperature": temperature, + "max_tokens": effective_max, + } + if system_text: + kwargs["system"] = system_text + + response = self._client.messages.create(**kwargs) + content = response.content[0].text + stop_reason = getattr(response, 'stop_reason', 'end_turn') + finish_reason = 'length' if stop_reason == 'max_tokens' else 'stop' + elif self.provider == 'claude': + content = self._call_claude_cli(messages, effective_max) + finish_reason = 'stop' + else: + if self.provider == 'github-copilot': + self._refresh_copilot_client() + kwargs = { + "model": self.model, + "messages": messages, + "temperature": temperature, + } + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens + if response_format: + kwargs["response_format"] = response_format + response = self._client.chat.completions.create(**kwargs) + choice = response.choices[0] + content = choice.message.content + finish_reason = choice.finish_reason or 'stop' + + return self._clean(content), finish_reason + + def chat_with_tools( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + temperature: float = 0.7, + max_tokens: int = 4096, + ) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]: """ - kwargs = { + Send a chat request with native function/tool calling. 
+ + Args: + messages: Message list (supports role=tool messages for tool results) + tools: Tool definitions in OpenAI format: + [{"type": "function", "function": {"name", "description", "parameters"}}] + temperature: Sampling temperature + max_tokens: Maximum output tokens + + Returns: + (content, tool_calls) + - content: Text content (None or thinking text when tool calls are present) + - tool_calls: List of {"id", "name", "parameters"} dicts, or None if no calls + """ + if self.provider == 'anthropic': + return self._chat_with_tools_anthropic(messages, tools, temperature, max_tokens) + elif self.provider == 'claude': + # Claude CLI doesn't support native tool calling — fall back to text + logger.warning("chat_with_tools not supported for claude CLI provider; falling back to text") + content = self._call_claude_cli(messages, max_tokens) + return self._clean(content), None + else: + return self._chat_with_tools_openai(messages, tools, temperature, max_tokens) + + def _chat_with_tools_openai( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + temperature: float, + max_tokens: int, + ) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]: + if self.provider == 'github-copilot': + self._refresh_copilot_client() + + response = self._client.chat.completions.create( + model=self.model, + messages=messages, + tools=tools, + tool_choice="auto", + temperature=temperature, + max_tokens=max_tokens, + ) + message = response.choices[0].message + content = self._clean(message.content or "") + + if not message.tool_calls: + return content, None + + parsed = [] + for tc in message.tool_calls: + try: + params = json.loads(tc.function.arguments) + except (json.JSONDecodeError, AttributeError): + params = {} + parsed.append({"id": tc.id, "name": tc.function.name, "parameters": params}) + return content, parsed + + def _chat_with_tools_anthropic( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + temperature: float, + max_tokens: int, 
+    ) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
+        """Translate OpenAI tool format to Anthropic tool_use format."""
+        system_text, user_messages = self._extract_system(messages)
+
+        # Convert OpenAI tools → Anthropic tools
+        anthropic_tools = []
+        for t in tools:
+            fn = t.get("function", {})
+            anthropic_tools.append({
+                "name": fn.get("name", ""),
+                "description": fn.get("description", ""),
+                "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
+            })
+
+        # Convert tool result messages (role=tool → role=user with tool_result content)
+        converted = []
+        for m in user_messages:
+            if m.get("role") == "tool":
+                converted.append({
+                    "role": "user",
+                    "content": [{
+                        "type": "tool_result",
+                        "tool_use_id": m.get("tool_call_id", ""),
+                        "content": m.get("content", ""),
+                    }],
+                })
+            else:
+                converted.append(m)
+
+        kwargs: Dict[str, Any] = {
             "model": self.model,
-            "messages": messages,
+            "messages": converted,
+            "tools": anthropic_tools,
             "temperature": temperature,
             "max_tokens": max_tokens,
         }
-
-        if response_format:
-            kwargs["response_format"] = response_format
-
-        response = self.client.chat.completions.create(**kwargs)
-        content = response.choices[0].message.content
-        # 部分模型(如MiniMax M2.5)会在content中包含思考内容,需要移除
-        content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
-        return content
-
+        if system_text:
+            kwargs["system"] = system_text
+
+        response = self._client.messages.create(**kwargs)
+
+        text_parts, tool_calls = [], []
+        for block in response.content:
+            if getattr(block, "type", None) == "text":
+                text_parts.append(block.text)
+            elif getattr(block, "type", None) == "tool_use":
+                tool_calls.append({
+                    "id": block.id,
+                    "name": block.name,
+                    "parameters": block.input or {},
+                })
+
+        content = self._clean(" ".join(text_parts)) or None
+        return content, (tool_calls if tool_calls else None)
+
     def chat_json(
         self,
         messages: List[Dict[str, str]],
         temperature: float = 0.3,
-        max_tokens: int = 4096
+        max_tokens: int = 4096,
     ) -> Dict[str, Any]:
         """
- 发送聊天请求并返回JSON - - Args: - messages: 消息列表 - temperature: 温度参数 - max_tokens: 最大token数 - + Send a chat request and parse the response as JSON. + Returns: - 解析后的JSON对象 + Parsed JSON dict. + + Raises: + LLMError: If the response cannot be parsed as valid JSON. """ - response = self.chat( - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - response_format={"type": "json_object"} - ) - # 清理markdown代码块标记 - cleaned_response = response.strip() - cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE) - cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response) - cleaned_response = cleaned_response.strip() + if self.provider == 'claude': + # Inject JSON instruction into the last non-system message + msgs = [m.copy() for m in messages] + for i in range(len(msgs) - 1, -1, -1): + if msgs[i].get("role") != "system": + msgs[i]["content"] += "\n\nRespond with valid JSON only, no markdown code fences." + break + raw = self._call_claude_cli(msgs, max_tokens) + else: + raw = self.chat( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + response_format={"type": "json_object"}, + ) + cleaned = self._clean_json(raw) try: - return json.loads(cleaned_response) + return json.loads(cleaned) except json.JSONDecodeError: - raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}") - + raise LLMError(f"LLM returned invalid JSON: {cleaned[:500]}") diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4f5361d53..e9f534d76 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ # LLM 相关 "openai>=1.0.0", + "anthropic>=0.40.0", # Zep Cloud "zep-cloud==3.13.0", diff --git a/backend/requirements.txt b/backend/requirements.txt index 4f146296b..d044d8825 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -12,6 +12,8 @@ flask-cors>=6.0.0 # ============= LLM 相关 ============= # OpenAI SDK(统一使用 OpenAI 格式调用 LLM) openai>=1.0.0 +# Anthropic SDK(支持 
Claude 模型直接调用) +anthropic>=0.40.0 # ============= Zep Cloud ============= zep-cloud==3.13.0 diff --git a/frontend/src/components/GraphPanel.vue b/frontend/src/components/GraphPanel.vue index 314c966e4..06f5ce79d 100644 --- a/frontend/src/components/GraphPanel.vue +++ b/frontend/src/components/GraphPanel.vue @@ -4,11 +4,11 @@ Graph Relationship Visualization
- -
@@ -27,7 +27,7 @@ - {{ isSimulating ? 'GraphRAG长短期记忆实时更新中' : '实时更新中...' }} + {{ isSimulating ? t('graph.simulationUpdating') : t('graph.updating') }} @@ -39,8 +39,8 @@ - 还有少量内容处理中,建议稍后手动刷新图谱 -