From c69e7473f714787aa1ea410c3dfd47da8a3525c4 Mon Sep 17 00:00:00 2001 From: futuremeng Date: Mon, 16 Mar 2026 10:18:01 +0800 Subject: [PATCH 01/68] fix(ollama): use 127.0.0.1 as default host and improve error diagnostics (#1480) align default host with Ollama server and SDK default (127.0.0.1) --- src/copaw/cli/providers_cmd.py | 2 +- src/copaw/providers/ollama_provider.py | 10 +++++----- tests/unit/providers/test_ollama_provider.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/copaw/cli/providers_cmd.py b/src/copaw/cli/providers_cmd.py index ef72d0ed9..85ee03710 100644 --- a/src/copaw/cli/providers_cmd.py +++ b/src/copaw/cli/providers_cmd.py @@ -60,7 +60,7 @@ def _get_ollama_host() -> str: manager = _manager() provider = manager.get_provider("ollama") if provider is None or not provider.base_url: - return "http://localhost:11434" + return "http://127.0.0.1:11434" return provider.base_url diff --git a/src/copaw/providers/ollama_provider.py b/src/copaw/providers/ollama_provider.py index 159960671..893f46ad1 100644 --- a/src/copaw/providers/ollama_provider.py +++ b/src/copaw/providers/ollama_provider.py @@ -22,7 +22,7 @@ class OllamaProvider(Provider): def model_post_init(self, __context: Any) -> None: if not self.base_url: # type: ignore self.base_url = ( - os.environ.get("OLLAMA_HOST") or "http://localhost:11434" + os.environ.get("OLLAMA_HOST") or "http://127.0.0.1:11434" ) if self.base_url.endswith("/v1"): # For backwards compatibility, if the URL ends with /v1, @@ -76,10 +76,10 @@ async def check_connection(self, timeout: float = 5) -> tuple[bool, str]: return False, "Ollama Python SDK is not installed" except ConnectionError: return False, f"Failed to connect to Ollama at `{self.base_url}`" - except Exception: + except Exception as exc: return ( False, - f"Unknown exception when connecting to `{self.base_url}`", + f"Failed to connect to Ollama at `{self.base_url}`: {exc}", ) async def fetch_models(self, timeout: float = 5) -> 
List[ModelInfo]: @@ -113,8 +113,8 @@ async def check_model_connection( return False, "Ollama Python SDK is not installed" except ConnectionError: return False, f"Failed to connect to Ollama at `{self.base_url}`" - except Exception: - return False, f"Unknown exception when connecting to `{target}`" + except Exception as exc: + return False, f"Model connection failed for `{target}`: {exc}" async def add_model( self, diff --git a/tests/unit/providers/test_ollama_provider.py b/tests/unit/providers/test_ollama_provider.py index 3b9c2919a..89c40b9f6 100644 --- a/tests/unit/providers/test_ollama_provider.py +++ b/tests/unit/providers/test_ollama_provider.py @@ -79,7 +79,7 @@ async def list(self): ok, msg = await provider.check_connection(timeout=1.0) assert ok is False - assert msg == f"Unknown exception when connecting to `{provider.base_url}`" + assert msg == f"Failed to connect to Ollama at `{provider.base_url}`: boom" async def test_fetch_models_normalizes_and_deduplicates(monkeypatch) -> None: @@ -171,7 +171,7 @@ async def chat(self, **kwargs): ok, msg = await provider.check_model_connection("qwen2:7b", timeout=4.0) assert ok is False - assert msg == "Unknown exception when connecting to `qwen2:7b`" + assert msg == "Model connection failed for `qwen2:7b`: failed" async def test_update_config_updates_non_none_values_and_get_info( From 84e4fdc3d29bc5cbb5c816116c7fb8c188344a74 Mon Sep 17 00:00:00 2001 From: toby <38551968+toby1123yjh@users.noreply.github.com> Date: Mon, 16 Mar 2026 10:56:23 +0800 Subject: [PATCH 02/68] fix(skills): add missing channels to cron SKILL.md --channel options (#1541) --- src/copaw/agents/skills/cron/SKILL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/copaw/agents/skills/cron/SKILL.md b/src/copaw/agents/skills/cron/SKILL.md index 7f0a93d98..ee6464708 100644 --- a/src/copaw/agents/skills/cron/SKILL.md +++ b/src/copaw/agents/skills/cron/SKILL.md @@ -67,7 +67,7 @@ copaw cron create \ - `--type`:任务类型(text 或 agent) - 
`--name`:任务名称 - `--cron`:cron 表达式(**UTC 时间**,如用户在 UTC+8 希望每天 9:00 执行,需填 `"0 1 * * *"`) -- `--channel`:目标频道(imessage / discord / dingtalk / qq / console) +- `--channel`:目标频道(console / feishu / dingtalk / discord / qq / telegram / imessage / matrix / mattermost 等)。用户未指定时,使用"当前的channel"的值 - `--target-user`:用户标识 - `--target-session`:会话标识 - `--text`:消息内容(text 类型)或提问内容(agent 类型) From 6f8e02fdb0b8cf9b2eee981190d63c8a78b3a1db Mon Sep 17 00:00:00 2001 From: Yuchang Sun <52027540+hiyuchang@users.noreply.github.com> Date: Mon, 16 Mar 2026 10:58:17 +0800 Subject: [PATCH 03/68] feat(skills): add a guidance skill for documentation of copaw (#1522) --- src/copaw/agents/skills/guidance/SKILL.md | 140 ++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 src/copaw/agents/skills/guidance/SKILL.md diff --git a/src/copaw/agents/skills/guidance/SKILL.md b/src/copaw/agents/skills/guidance/SKILL.md new file mode 100644 index 000000000..bbe2f7530 --- /dev/null +++ b/src/copaw/agents/skills/guidance/SKILL.md @@ -0,0 +1,140 @@ +--- +name: guidance +description: "回答用户关于 CoPaw 安装与配置的问题:优先定位并阅读本地文档,再提炼答案;若本地信息不足,兜底访问官网文档。" +metadata: + { + "copaw": + { + "emoji": "🧭", + "requires": {} + } + } +--- + +# CoPaw 安装与配置问答指南 + +当用户询问 **CoPaw 的安装、初始化、环境配置、依赖要求、常见配置项** 时,使用本 skill。 + +核心原则: + +- 先查本地文档,再回答 +- 回答要基于已读到的内容,不臆测 +- 回答语言与用户提问语言保持一致 + +## 标准流程 + + +### 第一步:定位文档位置 + +**查找记忆中的文档目录** + +首先你可以查看memory中是否有文档目录,如果有则直接使用,如果没有则继续执行下一步。 + +```bash +# 获取memory中的文档目录 +DOC_DIR=$(find ~/.copaw/memory/ -type d -name "docs") +``` + +如果 memory 中没有文档目录,则继续执行下面的逻辑。 + +**检查项目源码中的文档目录** + +执行以下脚本逻辑来获取变量 $COPAW_ROOT: + +```bash +# 获取二进制绝对路径 +COP_PATH=$(which copaw 2>/dev/null || whereis copaw | awk '{print $2}') + +# 逻辑推导:如果路径包含 .copaw/bin/copaw,则根目录在其上三层 +# 例如:/path/to/CoPaw/.copaw/bin/copaw -> /path/to/CoPaw +if [[ "$COP_PATH" == *".copaw/bin/copaw" ]]; then + COPAW_ROOT=$(echo "$COP_PATH" | sed 's/\/\.copaw\/bin\/copaw//') +else + # 兜底:尝试获取所在目录的父目录 + COPAW_ROOT=$(dirname $(dirname 
"$COP_PATH") 2>/dev/null || echo ".") +fi + +echo "Detected CoPaw Root: $COPAW_ROOT" +``` + +验证并列出文档目录: +使用推导出的 $COPAW_ROOT 定位文档: + +```bash +# 组合标准文档路径 +DOC_DIR="$COPAW_ROOT/website/public/docs/" + +# 检查路径是否存在并列出文件 +if [ -d "$DOC_DIR" ]; then + find "$DOC_DIR" -type f -name "*.md" | head -n 100 +else + # 如果推导路径不对,执行全局模糊搜索 + find "$COPAW_ROOT" -type d -name "docs" | grep "website/public/docs" +fi +``` +**如果项目文档不存在,搜索工作目录** + +如果还是找不到文档,搜索 copaw 安装路径下的可用文档内容: + +```bash +# 寻找 faq.en.md 或 config.zh.md 等特征文件 +FILE_PATH=$(find . -type f -name "faq.en.md" -o -name "config.zh.md" | head -n 1) +if [ -n "$FILE_PATH" ]; then + # 使用 dirname 获取该文件所在的目录 + DOC_DIR=$(dirname "$FILE_PATH") +fi +``` +如果找到了文档目录,请你记录在 memory 中,格式为: + +```markdown +# 文档目录 +$DOC_DIR = <doc_dir_path> +``` + +### 第二步:文档检索与匹配 + +文档文件命名格式为 `<name>.<lang>.md`(如 `config.zh.md`、`config.en.md`、`quickstart.zh.md`)。 + +使用 find 命令在目标目录中列出所有符合后缀的文档,并根据文件名关键字(如 install, env, setup)锁定目标作为 `<doc_path>`。 + +```bash +# 列出所有符合后缀的文档 +find $DOC_DIR -type f -name "*.md" +``` + +如果没有合适的文档,则在下一步阅读所有文档内容。 + + +### 第三步:阅读文档内容 + +找到候选文档后,读取并确认与问题相关的段落。可使用: + +- `cat <doc_path>` +- `file_reader` skill(推荐用于更长文档或分段读取) + +如果文档很长,优先读取和问题最相关的章节(安装步骤、配置项、示例命令、注意事项、版本要求)。 + +### 第四步:提取信息并作答 + +从文档中提取关键信息,组织成可执行答案: + +- 先给直接结论 +- 再给步骤/命令/配置示例 +- 补充必要前置条件与常见坑 + +语言要求:回答语言必须与用户提问语言一致(中文问就中文答,英文问就英文答)。 + +### 第五步(可选):官网检索 + +若前面步骤无法完成(本地无文档、文档缺失、信息不足),使用官网作为兜底: + +- http://copaw.agentscope.io/ + +基于官网可获得内容继续回答,并在答案中明确说明该结论来自官网文档。 + +## 输出质量要求 + +- 不编造不存在的配置项或命令 +- 遇到版本差异时,明确标注“需以当前文档版本为准” +- 涉及路径、命令、配置键时,尽量给可复制的原文片段 +- 若信息仍不足,明确缺口并告诉用户还需要哪类信息(例如操作系统、安装方式、报错日志) From dd10136fda918473fd796f57eb9350cc84827fca Mon Sep 17 00:00:00 2001 From: wwx814 <13960486856@163.com> Date: Mon, 16 Mar 2026 11:18:50 +0800 Subject: [PATCH 04/68] feat(channels): add XiaoYi channel support (#1213) --- console/src/api/types/channel.ts | 12 +- console/src/locales/en.json | 1 + console/src/locales/zh.json | 1 + .../Channels/components/ChannelDrawer.tsx | 38 + .../Control/Channels/components/constants.ts | 1 +
.../src/pages/Control/Channels/useChannels.ts | 1 + src/copaw/app/channels/registry.py | 1 + src/copaw/app/channels/schema.py | 1 + src/copaw/app/channels/xiaoyi/__init__.py | 17 + src/copaw/app/channels/xiaoyi/auth.py | 67 + src/copaw/app/channels/xiaoyi/channel.py | 1456 +++++++++++++++++ src/copaw/app/channels/xiaoyi/constants.py | 22 + src/copaw/config/config.py | 12 + website/public/docs/channels.en.md | 23 + website/public/docs/channels.zh.md | 23 + 15 files changed, 1675 insertions(+), 1 deletion(-) create mode 100644 src/copaw/app/channels/xiaoyi/__init__.py create mode 100644 src/copaw/app/channels/xiaoyi/auth.py create mode 100644 src/copaw/app/channels/xiaoyi/channel.py create mode 100644 src/copaw/app/channels/xiaoyi/constants.py diff --git a/console/src/api/types/channel.ts b/console/src/api/types/channel.ts index 8c96659cd..f6b1393fe 100644 --- a/console/src/api/types/channel.ts +++ b/console/src/api/types/channel.ts @@ -81,6 +81,14 @@ export interface VoiceChannelConfig extends BaseChannelConfig { welcome_greeting: string; } +export interface XiaoYiConfig extends BaseChannelConfig { + ak: string; + sk: string; + agent_id: string; + ws_url: string; + task_timeout_ms?: number; +} + export interface ChannelConfig { imessage: IMessageChannelConfig; discord: DiscordConfig; @@ -92,6 +100,7 @@ export interface ChannelConfig { matrix: MatrixConfig; console: ConsoleConfig; voice: VoiceChannelConfig; + xiaoyi: XiaoYiConfig; } export type SingleChannelConfig = @@ -104,4 +113,5 @@ export type SingleChannelConfig = | TelegramConfig | MQTTConfig | MatrixConfig - | VoiceChannelConfig; + | VoiceChannelConfig + | XiaoYiConfig; diff --git a/console/src/locales/en.json b/console/src/locales/en.json index c235a5164..0c4bfc218 100644 --- a/console/src/locales/en.json +++ b/console/src/locales/en.json @@ -280,6 +280,7 @@ "welcomeGreeting": "Welcome Greeting", "voiceSetupGuide": "Set up a Twilio account, purchase a phone number, then enter your credentials below. 
Find your Account SID and Auth Token on the Twilio Console dashboard. The Phone Number SID is listed under Phone Numbers → Active Numbers.", "voiceSetupLink": "Open Twilio Console", + "xiaoyiSetupGuide": "Please create an agent on Huawei Developer Platform and get AK/SK and Agent ID. AK/SK can be found in the credential management page.", "filterAll": "All", "builtin": "Built-in", "custom": "Custom" diff --git a/console/src/locales/zh.json b/console/src/locales/zh.json index 208191bbf..168aadff4 100644 --- a/console/src/locales/zh.json +++ b/console/src/locales/zh.json @@ -280,6 +280,7 @@ "welcomeGreeting": "欢迎语", "voiceSetupGuide": "请先注册 Twilio 账户并购买电话号码,然后在下方填写凭据。Account SID 和 Auth Token 可在 Twilio 控制台首页找到。Phone Number SID 在 Phone Numbers → Active Numbers 中查看。", "voiceSetupLink": "打开 Twilio 控制台", + "xiaoyiSetupGuide": "请在华为开发者平台创建智能体并获取 AK/SK 和 Agent ID。AK/SK 可在凭证管理页面找到。", "filterAll": "全部", "builtin": "内置", "custom": "自定义" diff --git a/console/src/pages/Control/Channels/components/ChannelDrawer.tsx b/console/src/pages/Control/Channels/components/ChannelDrawer.tsx index 193ea0666..1b0d7e511 100644 --- a/console/src/pages/Control/Channels/components/ChannelDrawer.tsx +++ b/console/src/pages/Control/Channels/components/ChannelDrawer.tsx @@ -21,6 +21,7 @@ const CHANNELS_WITH_ACCESS_CONTROL: ChannelKey[] = [ "feishu", "mattermost", "matrix", + "xiaoyi", ]; interface ChannelDrawerProps { @@ -48,6 +49,8 @@ const CHANNEL_DOC_URLS: Partial> = { mqtt: "https://copaw.agentscope.io/docs/channels/#MQTT", mattermost: "https://copaw.agentscope.io/docs/channels/#Mattermost", matrix: "https://copaw.agentscope.io/docs/channels/#Matrix", + xiaoyi: + "https://developer.huawei.com/consumer/cn/doc/service/openclaw-0000002518410344", }; const twilioConsoleUrl = "https://console.twilio.com"; @@ -434,6 +437,41 @@ export function ChannelDrawer({ ); + case "xiaoyi": + return ( + <> + + + + + + + + + + + + + + + ); default: return null; } diff --git 
a/console/src/pages/Control/Channels/components/constants.ts b/console/src/pages/Control/Channels/components/constants.ts index 06b6f42f9..0acf90b72 100644 --- a/console/src/pages/Control/Channels/components/constants.ts +++ b/console/src/pages/Control/Channels/components/constants.ts @@ -14,6 +14,7 @@ export const CHANNEL_LABELS: Record = { matrix: "Matrix", console: "Console", voice: "Twilio", + xiaoyi: "XiaoYi", }; // Get channel label - returns built-in label or formatted custom name diff --git a/console/src/pages/Control/Channels/useChannels.ts b/console/src/pages/Control/Channels/useChannels.ts index 323248e80..812a2a42d 100644 --- a/console/src/pages/Control/Channels/useChannels.ts +++ b/console/src/pages/Control/Channels/useChannels.ts @@ -40,6 +40,7 @@ export function useChannels() { "telegram", "qq", "matrix", + "xiaoyi", ], [], ); diff --git a/src/copaw/app/channels/registry.py b/src/copaw/app/channels/registry.py index 8cbd9d1b6..98d42e965 100644 --- a/src/copaw/app/channels/registry.py +++ b/src/copaw/app/channels/registry.py @@ -28,6 +28,7 @@ "console": (".console", "ConsoleChannel"), "matrix": (".matrix", "MatrixChannel"), "voice": (".voice", "VoiceChannel"), + "xiaoyi": (".xiaoyi", "XiaoYiChannel"), } # Required channels must load; failures are raised, not skipped. diff --git a/src/copaw/app/channels/schema.py b/src/copaw/app/channels/schema.py index 11dd48157..f7dedd59e 100644 --- a/src/copaw/app/channels/schema.py +++ b/src/copaw/app/channels/schema.py @@ -38,6 +38,7 @@ def to_handle(self) -> str: "mqtt", "console", "voice", + "xiaoyi", ) # ChannelType is str to allow plugin channels; built-in set above. diff --git a/src/copaw/app/channels/xiaoyi/__init__.py b/src/copaw/app/channels/xiaoyi/__init__.py new file mode 100644 index 000000000..77a4fc983 --- /dev/null +++ b/src/copaw/app/channels/xiaoyi/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +"""XiaoYi channel module. + +XiaoYi (小艺) is Huawei's voice assistant platform. 
+This module implements A2A (Agent-to-Agent) protocol support. +""" + +from .channel import XiaoYiChannel +from .auth import generate_auth_headers, XiaoYiAuth +from .constants import DEFAULT_WS_URL + +__all__ = [ + "XiaoYiChannel", + "generate_auth_headers", + "XiaoYiAuth", + "DEFAULT_WS_URL", +] diff --git a/src/copaw/app/channels/xiaoyi/auth.py b/src/copaw/app/channels/xiaoyi/auth.py new file mode 100644 index 000000000..64e65babb --- /dev/null +++ b/src/copaw/app/channels/xiaoyi/auth.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +""" +XiaoYi authentication using AK/SK mechanism. +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import time +from typing import Dict + + +def generate_signature(sk: str, timestamp: str) -> str: + """Generate HMAC-SHA256 signature. + + Format: Base64(HMAC-SHA256(secretKey, timestamp)) + + Args: + sk: Secret Key + timestamp: Timestamp as string (milliseconds) + + Returns: + Base64 encoded signature + """ + hmac_obj = hmac.new(sk.encode(), timestamp.encode(), hashlib.sha256) + return base64.b64encode(hmac_obj.digest()).decode() + + +def generate_auth_headers(ak: str, sk: str, agent_id: str) -> Dict[str, str]: + """Generate WebSocket authentication headers. 
+ + Args: + ak: Access Key + sk: Secret Key + agent_id: Agent ID + + Returns: + Dict of headers for WebSocket connection + """ + timestamp = str(int(time.time() * 1000)) + signature = generate_signature(sk, timestamp) + + return { + "x-access-key": ak, + "x-sign": signature, + "x-ts": timestamp, + "x-agent-id": agent_id, + } + + +class XiaoYiAuth: + """XiaoYi authentication helper class.""" + + def __init__(self, ak: str, sk: str, agent_id: str): + self.ak = ak + self.sk = sk + self.agent_id = agent_id + + def get_auth_headers(self) -> Dict[str, str]: + """Get authentication headers for WebSocket connection.""" + return generate_auth_headers(self.ak, self.sk, self.agent_id) + + def generate_signature(self, timestamp: str) -> str: + """Generate signature for given timestamp.""" + return generate_signature(self.sk, timestamp) diff --git a/src/copaw/app/channels/xiaoyi/channel.py b/src/copaw/app/channels/xiaoyi/channel.py new file mode 100644 index 000000000..8c39e980e --- /dev/null +++ b/src/copaw/app/channels/xiaoyi/channel.py @@ -0,0 +1,1456 @@ +# -*- coding: utf-8 -*- +"""XiaoYi Channel implementation. + +XiaoYi uses A2A (Agent-to-Agent) protocol over WebSocket. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import uuid +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +import aiohttp + +from agentscope_runtime.engine.schemas.agent_schemas import ( + ContentType, + TextContent, +) + +from ....config.config import XiaoYiConfig as XiaoYiChannelConfig +from ..base import ( + BaseChannel, + OnReplySent, + OutgoingContentPart, + ProcessHandler, +) +from ..renderer import MessageRenderer, RenderStyle +from .auth import generate_auth_headers +from .constants import ( + CONNECTION_TIMEOUT, + DEFAULT_TASK_TIMEOUT_MS, + HEARTBEAT_INTERVAL, + MAX_RECONNECT_ATTEMPTS, + RECONNECT_DELAYS, + TEXT_CHUNK_LIMIT, +) + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from agentscope_runtime.engine.schemas.agent_schemas import AgentRequest + + +# Class-level registry to track active connections per agent_id +# This prevents multiple channel instances with same agent_id from conflicting +_active_connections: Dict[str, "XiaoYiChannel"] = {} +_active_connections_lock = asyncio.Lock() + + +class XiaoYiChannel(BaseChannel): + """XiaoYi channel using A2A protocol over WebSocket. + + This channel connects to XiaoYi server as a WebSocket client + and handles A2A (Agent-to-Agent) protocol messages. 
+ """ + + channel = "xiaoyi" + uses_manager_queue = True + + def __init__( + self, + process: ProcessHandler, + enabled: bool, + ak: str, + sk: str, + agent_id: str, + ws_url: str, + task_timeout_ms: int = DEFAULT_TASK_TIMEOUT_MS, + on_reply_sent: OnReplySent = None, + show_tool_details: bool = True, + filter_tool_messages: bool = False, + filter_thinking: bool = False, + bot_prefix: str = "", + dm_policy: str = "open", + group_policy: str = "open", + allow_from: Optional[List[str]] = None, + deny_message: str = "", + ): + super().__init__( + process, + on_reply_sent=on_reply_sent, + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + dm_policy=dm_policy, + group_policy=group_policy, + allow_from=allow_from, + deny_message=deny_message, + ) + + # XiaoYi platform supports markdown and code fences + # Tool call arguments should be in code blocks for better readability + self._render_style = RenderStyle( + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + supports_markdown=True, + supports_code_fence=True, + use_emoji=True, + ) + self._renderer = MessageRenderer(self._render_style) + + self.enabled = enabled + self.ak = ak + self.sk = sk + self.agent_id = agent_id + self.ws_url = ws_url + self.task_timeout_ms = task_timeout_ms + self.bot_prefix = bot_prefix + + # WebSocket state + self._ws: Optional[aiohttp.ClientWebSocketResponse] = None + self._session: Optional[aiohttp.ClientSession] = None + self._connected = False + self._reconnect_attempts = 0 + self._stopping = False # Flag to prevent reconnect during stop + + # Session -> task_id mapping + self._session_task_map: Dict[str, str] = {} + + # Heartbeat task + self._heartbeat_task: Optional[asyncio.Task] = None + + # Receive loop task + self._receive_task: Optional[asyncio.Task] = None + + @classmethod + def from_env( + cls, + process: ProcessHandler, + on_reply_sent: OnReplySent = 
None, + ) -> "XiaoYiChannel": + """Create channel from environment variables.""" + import os + + return cls( + process=process, + enabled=os.getenv("XIAOYI_CHANNEL_ENABLED", "0") == "1", + ak=os.getenv("XIAOYI_AK", ""), + sk=os.getenv("XIAOYI_SK", ""), + agent_id=os.getenv("XIAOYI_AGENT_ID", ""), + ws_url=os.getenv( + "XIAOYI_WS_URL", + "wss://hag.cloud.huawei.com/openclaw/v1/ws/link", + ), + on_reply_sent=on_reply_sent, + ) + + @classmethod + def from_config( + cls, + process: ProcessHandler, + config: XiaoYiChannelConfig, + on_reply_sent: OnReplySent = None, + show_tool_details: bool = True, + filter_tool_messages: bool = False, + filter_thinking: bool = False, + ) -> "XiaoYiChannel": + """Create channel from config object.""" + if isinstance(config, dict): + return cls( + process=process, + enabled=config.get("enabled", False), + ak=config.get("ak", ""), + sk=config.get("sk", ""), + agent_id=config.get("agent_id", ""), + ws_url=config.get( + "ws_url", + "wss://hag.cloud.huawei.com/openclaw/v1/ws/link", + ), + task_timeout_ms=config.get( + "task_timeout_ms", + DEFAULT_TASK_TIMEOUT_MS, + ), + on_reply_sent=on_reply_sent, + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + bot_prefix=config.get("bot_prefix", ""), + dm_policy=config.get("dm_policy", "open"), + group_policy=config.get("group_policy", "open"), + allow_from=config.get("allow_from"), + deny_message=config.get("deny_message", ""), + ) + + return cls( + process=process, + enabled=config.enabled, + ak=config.ak, + sk=config.sk, + agent_id=config.agent_id, + ws_url=config.ws_url, + task_timeout_ms=config.task_timeout_ms, + on_reply_sent=on_reply_sent, + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + bot_prefix=config.bot_prefix, + dm_policy=config.dm_policy, + group_policy=config.group_policy, + allow_from=list(config.allow_from) if config.allow_from else None, + 
deny_message=config.deny_message, + ) + + def _validate_config(self) -> None: + """Validate required configuration.""" + if not self.ak: + raise ValueError("XiaoYi AK (Access Key) is required") + if not self.sk: + raise ValueError("XiaoYi SK (Secret Key) is required") + if not self.agent_id: + raise ValueError("XiaoYi Agent ID is required") + + async def start(self) -> None: + """Start WebSocket connection.""" + if not self.enabled: + logger.debug("XiaoYi: start() skipped (enabled=false)") + return + + try: + self._validate_config() + except ValueError as e: + logger.error(f"XiaoYi config validation failed: {e}") + return + + # Check if there's already an active connection for this agent_id + # and reuse it if only filter settings changed + global _active_connections + should_connect = True + async with _active_connections_lock: + existing = _active_connections.get(self.agent_id) + if ( + existing is not None + and existing is not self + and existing._connected # pylint: disable=protected-access + ): + # pylint: disable=protected-access + # Found active connection - update settings + logger.info( + "XiaoYi: Updating settings for existing " + f"connection agent_id={self.agent_id}", + ) + # Update render style settings on the existing channel + existing._render_style.filter_tool_messages = ( + self._render_style.filter_tool_messages + ) + existing._render_style.filter_thinking = ( + self._render_style.filter_thinking + ) + existing._render_style.show_tool_details = ( + self._render_style.show_tool_details + ) + # Re-register this instance + # (so the new instance becomes the active one) + _active_connections[self.agent_id] = self + # Copy the WebSocket state to this instance + self._ws = existing._ws + self._session = existing._session + self._connected = existing._connected + self._heartbeat_task = existing._heartbeat_task + self._receive_task = existing._receive_task + self._session_task_map = existing._session_task_map + # Mark old instance as not owning the 
connection anymore + existing._ws = None + existing._session = None + existing._connected = False + existing._heartbeat_task = None + existing._receive_task = None + should_connect = False + + if not should_connect: + logger.info( + "XiaoYi: Reused existing connection with updated settings", + ) + return + + # No existing connection or can't reuse - start new connection + await self._wait_and_register_connection() + + logger.info(f"XiaoYi: Connecting to {self.ws_url}...") + + try: + await self._connect() + except Exception as e: + logger.error(f"XiaoYi connection failed: {e}") + # Unregister on failure + await self._unregister_connection() + self._schedule_reconnect() + + async def _wait_and_register_connection(self) -> None: + """Stop any existing connection with same agent_id, then register.""" + global _active_connections + + # First, get existing connection and remove it from registry + existing = None + async with _active_connections_lock: + existing = _active_connections.get(self.agent_id) + if existing is not None and existing is not self: + # Remove from registry immediately + _active_connections.pop(self.agent_id, None) + # Register this instance + _active_connections[self.agent_id] = self + + # Now stop the old connection outside the lock + if existing is not None and existing is not self: + # pylint: disable=protected-access + logger.info( + "XiaoYi: Stopping old connection for " + f"agent_id={self.agent_id}", + ) + try: + # Set stopping flag FIRST to prevent any reconnect + existing._stopping = True + existing._connected = False + # Cancel tasks and wait for them to finish + if existing._heartbeat_task: + existing._heartbeat_task.cancel() + try: + await existing._heartbeat_task + except asyncio.CancelledError: + pass + if existing._receive_task: + existing._receive_task.cancel() + try: + await existing._receive_task + except asyncio.CancelledError: + pass + # Close WebSocket + if existing._ws: + await existing._ws.close() + if existing._session: + await 
existing._session.close() + logger.debug("XiaoYi: Old connection stopped") + except Exception as e: + logger.debug(f"XiaoYi: Error stopping old connection: {e}") + + logger.debug( + f"XiaoYi: Registered connection for agent_id={self.agent_id}", + ) + + async def _unregister_connection(self) -> None: + """Unregister this connection from active connections.""" + global _active_connections + async with _active_connections_lock: + if _active_connections.get(self.agent_id) is self: + _active_connections.pop(self.agent_id, None) + logger.debug( + "XiaoYi: Unregistered connection for " + f"agent_id={self.agent_id}", + ) + + async def _connect(self) -> None: + """Establish WebSocket connection.""" + headers = generate_auth_headers(self.ak, self.sk, self.agent_id) + + # Clean up any existing session first + await self._cleanup_session() + + self._session = aiohttp.ClientSession() + ws_timeout = aiohttp.ClientWSTimeout(ws_close=CONNECTION_TIMEOUT) + + try: + self._ws = await self._session.ws_connect( + self.ws_url, + headers=headers, + timeout=ws_timeout, + ) + + self._connected = True + self._reconnect_attempts = 0 + logger.info("XiaoYi: WebSocket connected") + + # Send init message + await self._send_init_message() + + # Start heartbeat + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + # Start receive loop + self._receive_task = asyncio.create_task(self._receive_loop()) + + except Exception as e: + logger.error(f"XiaoYi: WebSocket connection error: {e}") + self._connected = False + raise + + async def _send_init_message(self) -> None: + """Send init message to server.""" + if not self._ws: + return + + init_msg = { + "msgType": "clawd_bot_init", + "agentId": self.agent_id, + } + + try: + await self._ws.send_json(init_msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send init message: {e}") + + async def _heartbeat_loop(self) -> None: + """Send heartbeat messages periodically.""" + while self._connected and self._ws: + try: + await 
asyncio.sleep(HEARTBEAT_INTERVAL) + + if not self._connected or not self._ws: + break + + heartbeat_msg = { + "msgType": "heartbeat", + "agentId": self.agent_id, + } + + await self._ws.send_json(heartbeat_msg) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"XiaoYi: Heartbeat error: {e}") + break + + async def _receive_loop(self) -> None: + """Receive and process messages from WebSocket.""" + if not self._ws: + return + + try: + async for msg in self._ws: + if msg.type == aiohttp.WSMsgType.TEXT: + await self._handle_message(msg.data) + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error( + f"XiaoYi: WebSocket error: {self._ws.exception()}", + ) + break + elif msg.type == aiohttp.WSMsgType.CLOSED: + logger.info("XiaoYi: WebSocket closed") + break + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"XiaoYi: Receive loop error: {e}") + finally: + self._connected = False + # Only reconnect if not stopping + if not self._stopping: + self._schedule_reconnect() + + async def _handle_message(self, data: str) -> None: + """Handle incoming WebSocket message.""" + try: + message = json.loads(data) + logger.debug( + "XiaoYi: Received message: " + f"{json.dumps(message, indent=2)}", + ) + + # Validate agent_id + if message.get("agentId") and message["agentId"] != self.agent_id: + logger.warning( + "XiaoYi: Mismatched agentId " + f"{message['agentId']}, expected {self.agent_id}", + ) + return + + # Handle clear context + if ( + message.get("method") == "clearContext" + or message.get("action") == "clear" + ): + await self._handle_clear_context(message) + return + + # Handle tasks cancel + if ( + message.get("method") == "tasks/cancel" + or message.get("action") == "tasks/cancel" + ): + await self._handle_tasks_cancel(message) + return + + # Handle A2A request + if message.get("method") == "message/stream": + await self._handle_a2a_request(message) + + except json.JSONDecodeError as e: + logger.error(f"XiaoYi: 
Failed to parse message: {e}") + except Exception as e: + logger.error(f"XiaoYi: Error handling message: {e}", exc_info=True) + + async def _handle_a2a_request(self, message: Dict[str, Any]) -> None: + """Handle A2A request message.""" + try: + # Extract session ID + # (prefer params.sessionId, fallback to top-level) + session_id = message.get("params", {}).get( + "sessionId", + ) or message.get("sessionId") + task_id = message.get("params", {}).get("id") or message.get("id") + + if not session_id: + logger.warning("XiaoYi: No sessionId in message") + return + + # Store session -> task mapping + self._session_task_map[session_id] = task_id + + # Extract text from message parts + text_parts = [] + params = message.get("params", {}) + msg = params.get("message", {}) + parts = msg.get("parts", []) + + for part in parts: + if part.get("kind") == "text" and part.get("text"): + text_parts.append(part["text"]) + + content = " ".join(text_parts) + if not content.strip(): + logger.debug("XiaoYi: Empty message content, skipping") + return + + # Build native payload + content_parts = [TextContent(type=ContentType.TEXT, text=content)] + native = { + "channel_id": self.channel, + "sender_id": session_id, + "content_parts": content_parts, + "meta": { + "session_id": session_id, + "task_id": task_id, + "message_id": message.get("id"), + }, + } + + if self._enqueue: + self._enqueue(native) + else: + logger.warning("XiaoYi: _enqueue not set, message dropped") + + except Exception as e: + logger.error( + f"XiaoYi: Error handling A2A request: {e}", + exc_info=True, + ) + + async def _handle_clear_context(self, message: Dict[str, Any]) -> None: + """Handle clear context message.""" + session_id = message.get("sessionId") or "" + request_id = message.get("id") or "" + + logger.info(f"XiaoYi: Clear context for session {session_id}") + + # Send clear response + await self._send_clear_context_response(request_id, session_id) + + # Clean up session + if session_id: + 
self._session_task_map.pop(session_id, None) + + async def _handle_tasks_cancel(self, message: Dict[str, Any]) -> None: + """Handle tasks cancel message.""" + session_id = message.get("sessionId") or "" + request_id = message.get("id") or "" + task_id = message.get("taskId") or request_id + + logger.info(f"XiaoYi: Cancel task {task_id} for session {session_id}") + + # Send cancel response + await self._send_tasks_cancel_response(request_id, session_id) + + async def _send_clear_context_response( + self, + request_id: str, + session_id: str, + success: bool = True, + ) -> None: + """Send clear context response.""" + if not self._ws or not self._connected: + return + + json_rpc_response = { + "jsonrpc": "2.0", + "id": request_id, + "result": { + "status": {"state": "cleared" if success else "failed"}, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": request_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send clear context response: {e}") + + async def _send_tasks_cancel_response( + self, + request_id: str, + session_id: str, + success: bool = True, + ) -> None: + """Send tasks cancel response.""" + if not self._ws or not self._connected: + return + + json_rpc_response = { + "jsonrpc": "2.0", + "id": request_id, + "result": { + "id": request_id, + "status": {"state": "canceled" if success else "failed"}, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": request_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send cancel response: {e}") + + def _schedule_reconnect(self) -> None: + """Schedule reconnection attempt.""" + if self._stopping: + return + + if self._reconnect_attempts >= MAX_RECONNECT_ATTEMPTS: + 
logger.error("XiaoYi: Max reconnect attempts reached") + return + + delay_idx = min(self._reconnect_attempts, len(RECONNECT_DELAYS) - 1) + delay = RECONNECT_DELAYS[delay_idx] + self._reconnect_attempts += 1 + + logger.info( + "XiaoYi: Reconnecting in " + f"{delay}s (attempt {self._reconnect_attempts})", + ) + + async def reconnect(): + await asyncio.sleep(delay) + if self._stopping or self._connected: + return + + # Clean up old session before reconnecting + await self._cleanup_session() + + try: + await self._connect() + logger.info("XiaoYi: Reconnected successfully") + except Exception as e: + logger.error(f"XiaoYi: Reconnect failed: {e}") + self._schedule_reconnect() + + asyncio.create_task(reconnect()) + + async def _cleanup_session(self) -> None: + """Clean up WebSocket and session.""" + if self._ws: + try: + await self._ws.close() + except Exception: + pass + self._ws = None + + if self._session: + try: + await self._session.close() + except Exception: + pass + self._session = None + + async def stop(self) -> None: + """Stop WebSocket connection.""" + logger.info("XiaoYi: Stopping channel...") + + self._stopping = True # Prevent reconnect during stop + self._connected = False + + # Cancel tasks + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + self._heartbeat_task = None + + if self._receive_task: + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass + self._receive_task = None + + # Close WebSocket + if self._ws: + await self._ws.close() + self._ws = None + + # Close session + if self._session: + await self._session.close() + self._session = None + + # Unregister from active connections + await self._unregister_connection() + + # Keep _stopping = True to prevent any reconnection attempts + # This channel instance will not be reused after stop + logger.info("XiaoYi: Channel stopped") + + async def send( + self, + to_handle: 
str, + text: str, + meta: Optional[Dict[str, Any]] = None, + ) -> None: + """Send text message via WebSocket. + + For A2A protocol with append=true, messages are chunked + at TEXT_CHUNK_LIMIT characters to avoid WebSocket disconnection + on large messages. + """ + if not self.enabled or not self._ws or not self._connected: + logger.warning("XiaoYi: Cannot send - not connected") + return + + meta = meta or {} + session_id = meta.get("session_id") or to_handle + task_id = meta.get("task_id") or self._session_task_map.get(session_id) + + if not task_id: + logger.warning(f"XiaoYi: No task_id for session {session_id}") + return + + # Don't send empty text + if not text or not text.strip(): + return + + # Get or create message ID for this session + message_id = meta.get("message_id", str(uuid.uuid4())) + + # Chunk text if too large + chunks = self._chunk_text(text) + + for chunk in chunks: + await self._send_chunk(session_id, task_id, message_id, chunk) + + def _chunk_text(self, text: str) -> List[str]: + """Split text into chunks of TEXT_CHUNK_LIMIT size.""" + if len(text) <= TEXT_CHUNK_LIMIT: + return [text] + + chunks = [] + # Try to split at newlines for better readability + lines = text.split("\n") + current_chunk = "" + + for line in lines: + # If single line is too long, split it + if len(line) > TEXT_CHUNK_LIMIT: + # First add any accumulated chunk + if current_chunk: + chunks.append(current_chunk.rstrip("\n")) + current_chunk = "" + + # Split long line into chunks + for i in range(0, len(line), TEXT_CHUNK_LIMIT): + chunks.append(line[i : i + TEXT_CHUNK_LIMIT]) + else: + # Check if adding this line would exceed limit + test_chunk = ( + current_chunk + "\n" + line if current_chunk else line + ) + if len(test_chunk) > TEXT_CHUNK_LIMIT: + if current_chunk: + chunks.append(current_chunk) + current_chunk = line + else: + current_chunk = test_chunk + + if current_chunk: + chunks.append(current_chunk) + + return chunks + + async def _send_chunk( + self, + session_id: 
str, + task_id: str, + message_id: str, + text: str, + ) -> None: + """Send a single text chunk via WebSocket.""" + if not self._ws or not self._connected: + logger.warning("XiaoYi: Cannot send chunk - not connected") + return + + artifact_id = f"artifact_{uuid.uuid4().hex[:16]}" + + json_rpc_response = { + "jsonrpc": "2.0", + "id": message_id, + "result": { + "taskId": task_id, + "kind": "artifact-update", + "append": True, # Append to previous messages + "lastChunk": False, # Not the last chunk + "final": False, # Not final, more content may come + "artifact": { + "artifactId": artifact_id, + "parts": [{"kind": "text", "text": text}], + }, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": task_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send message: {e}") + + async def _send_reasoning_chunk( + self, + session_id: str, + task_id: str, + message_id: str, + reasoning_text: str, + ) -> None: + """Send a single reasoning/thinking chunk via WebSocket.""" + if not self._ws or not self._connected: + logger.warning( + "XiaoYi: Cannot send reasoning chunk - not connected", + ) + return + + artifact_id = f"artifact_{uuid.uuid4().hex[:16]}" + + json_rpc_response = { + "jsonrpc": "2.0", + "id": message_id, + "result": { + "taskId": task_id, + "kind": "artifact-update", + "append": True, + "lastChunk": False, + "final": False, + "artifact": { + "artifactId": artifact_id, + "parts": [ + { + "kind": "reasoningText", + "reasoningText": reasoning_text, + }, + ], + }, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": task_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send reasoning chunk: {e}") + + async def send_final_message( + 
self, + session_id: str, + task_id: str, + message_id: str, + ) -> None: + """Send final empty message to end the stream.""" + if not self.enabled or not self._ws or not self._connected: + return + + artifact_id = f"artifact_{uuid.uuid4().hex[:16]}" + + json_rpc_response = { + "jsonrpc": "2.0", + "id": message_id, + "result": { + "taskId": task_id, + "kind": "artifact-update", + "append": True, + "lastChunk": True, + "final": True, + "artifact": { + "artifactId": artifact_id, + "parts": [ + {"kind": "text", "text": ""}, + ], # Empty text for final + }, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": task_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send final message: {e}") + + async def send_media( + self, + to_handle: str, + part: OutgoingContentPart, + meta: Optional[Dict[str, Any]] = None, + ) -> None: + """Send media message via WebSocket.""" + if not self.enabled or not self._ws or not self._connected: + return + + meta = meta or {} + session_id = meta.get("session_id") or to_handle + task_id = meta.get("task_id") or self._session_task_map.get(session_id) + + if not task_id: + return + + part_type = getattr(part, "type", None) + + # Build artifact part based on content type + artifact_part: Dict[str, Any] = {"kind": "text"} + + if part_type == ContentType.IMAGE: + img_url = getattr(part, "image_url", "") + artifact_part = { + "kind": "file", + "file": { + "name": "image", + "mimeType": "image/png", + "uri": img_url, + }, + } + elif part_type == ContentType.VIDEO: + vid_url = getattr(part, "video_url", "") + artifact_part = { + "kind": "file", + "file": { + "name": "video", + "mimeType": "video/mp4", + "uri": vid_url, + }, + } + elif part_type == ContentType.FILE: + file_url = getattr(part, "file_url", "") + artifact_part = { + "kind": "file", + "file": { + "name": getattr(part, 
"file_name", "file"), + "mimeType": "application/octet-stream", + "uri": file_url, + }, + } + + artifact_id = f"artifact_{uuid.uuid4().hex[:16]}" + + json_rpc_response = { + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "result": { + "taskId": task_id, + "kind": "artifact-update", + "append": True, + "lastChunk": True, + "final": True, + "artifact": { + "artifactId": artifact_id, + "parts": [artifact_part], + }, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": task_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send media: {e}") + + def _extract_xiaoyi_parts( + self, + message: Any, + ) -> List[Dict[str, Any]]: + # pylint: disable=too-many-branches,too-many-statements + # pylint: disable=too-many-nested-blocks + """Extract parts from message with proper XiaoYi kinds. + + XiaoYi supports: + - kind="reasoningText": For thinking/reasoning content + - kind="text": For regular text content + """ + from agentscope_runtime.engine.schemas.agent_schemas import ( + MessageType, + ) + + msg_type = getattr(message, "type", None) + content = getattr(message, "content", None) or [] + parts = [] + + # Check if this is a reasoning/thinking message type + if msg_type == MessageType.REASONING: + # Check if thinking is filtered + if self._render_style.filter_thinking: + return [] + for c in content: + text = getattr(c, "text", None) + if text: + # Add newline separator for each thinking content + parts.append( + { + "kind": "reasoningText", + "reasoningText": text + "\n", + }, + ) + return parts + + # Process each content item + for c in content: + ctype = getattr(c, "type", None) + + # Handle thinking blocks (inside DATA content as dict) + if ctype == ContentType.DATA: + data = getattr(c, "data", None) + if isinstance(data, dict): + # Check for thinking content in blocks + blocks = data.get("blocks", []) + 
if ( + isinstance(blocks, list) + and not self._render_style.filter_thinking + ): + for block in blocks: + if ( + isinstance(block, dict) + and block.get("type") == "thinking" + ): + thinking_text = block.get("thinking", "") + if thinking_text: + # Add newline separator + parts.append( + { + "kind": "reasoningText", + "reasoningText": thinking_text + + "\n", + }, + ) + + # Handle TEXT type (regular message content) + # Add leading newline to separate from previous content + if ctype == ContentType.TEXT and getattr(c, "text", None): + text = c.text + # Add leading newlines if not already present + if not text.startswith("\n"): + text = "\n\n" + text + parts.append({"kind": "text", "text": text}) + + # Handle REFUSAL type + elif ctype == ContentType.REFUSAL and getattr(c, "refusal", None): + parts.append({"kind": "text", "text": c.refusal}) + + # Handle tool call/output messages + # with complete, independent formatting + # Check if tool messages should be filtered + if self._render_style.filter_tool_messages: + if msg_type in ( + MessageType.FUNCTION_CALL, + MessageType.PLUGIN_CALL, + MessageType.MCP_TOOL_CALL, + MessageType.FUNCTION_CALL_OUTPUT, + MessageType.PLUGIN_CALL_OUTPUT, + MessageType.MCP_TOOL_CALL_OUTPUT, + ): + return [] + + if msg_type in ( + MessageType.FUNCTION_CALL, + MessageType.PLUGIN_CALL, + MessageType.MCP_TOOL_CALL, + ): + # Tool call: format as "🔧 **name**" + code block with args + for c in content: + if getattr(c, "type", None) != ContentType.DATA: + continue + data = getattr(c, "data", None) + if not isinstance(data, dict): + continue + name = data.get("name") or "tool" + args = data.get("arguments") or "{}" + # Complete, independent formatting for each tool call + formatted = f"\n\n🔧 **{name}**\n```\n{args}\n```\n" + parts.append({"kind": "text", "text": formatted}) + return parts + + if msg_type in ( + MessageType.FUNCTION_CALL_OUTPUT, + MessageType.PLUGIN_CALL_OUTPUT, + MessageType.MCP_TOOL_CALL_OUTPUT, + ): + # Tool output: format as "✅ 
**name**" + code block with result + for c in content: + if getattr(c, "type", None) != ContentType.DATA: + continue + data = getattr(c, "data", None) + if not isinstance(data, dict): + continue + name = data.get("name") or "tool" + output = data.get("output", "") + + # Parse output and format as JSON + try: + if isinstance(output, str): + parsed = json.loads(output) + else: + parsed = output + + # Handle list format like [{'type': 'text', 'text': '...'}] + if isinstance(parsed, list): + texts = [] + for item in parsed: + if ( + isinstance(item, dict) + and item.get("type") == "text" + ): + texts.append(item.get("text", "")) + output_str = "\n".join(texts) if texts else str(parsed) + elif isinstance(parsed, dict): + output_str = json.dumps( + parsed, + ensure_ascii=False, + indent=2, + ) + else: + output_str = str(parsed) + except (json.JSONDecodeError, TypeError): + output_str = str(output) if output else "" + + # Truncate if too long + if len(output_str) > 500: + output_str = output_str[:500] + "..." + + # Escape backticks in output + # to avoid breaking code blocks + output_str = output_str.replace("```", "\\`\\`\\`") + + # Complete, independent formatting + # for each tool output + # Ensure code block is properly closed + formatted = f"\n\n✅ **{name}**\n```\n{output_str}\n```\n" + parts.append({"kind": "text", "text": formatted}) + return parts + + # If no parts extracted, use renderer as fallback + if not parts: + rendered_parts = self._renderer.message_to_parts(message) + for rp in rendered_parts: + if getattr(rp, "type", None) == ContentType.TEXT: + text = getattr(rp, "text", "") + if text: + parts.append({"kind": "text", "text": text}) + + return parts + + async def send_xiaoyi_parts( + self, + to_handle: str, + parts: List[Dict[str, Any]], + meta: Optional[Dict[str, Any]] = None, + ) -> None: + # pylint: disable=too-many-branches,too-many-nested-blocks + """Send parts with XiaoYi-specific format. 
+ + Each part is a dict with: + - kind: "text" or "reasoningText" + - text/reasoningText: the content string + """ + if not self.enabled or not self._ws or not self._connected: + logger.warning("XiaoYi: Cannot send - not connected") + return + + meta = meta or {} + session_id = meta.get("session_id") or to_handle + task_id = meta.get("task_id") or self._session_task_map.get(session_id) + + if not task_id: + logger.warning(f"XiaoYi: No task_id for session {session_id}") + return + + message_id = meta.get("message_id", str(uuid.uuid4())) + + # Build artifact parts for XiaoYi + artifact_parts = [] + for part in parts: + kind = part.get("kind", "text") + if kind == "reasoningText": + artifact_parts.append( + { + "kind": "reasoningText", + "reasoningText": part.get("reasoningText", ""), + }, + ) + elif kind == "text": + artifact_parts.append( + { + "kind": "text", + "text": part.get("text", ""), + }, + ) + + if not artifact_parts: + return + + # Check if any part exceeds chunk limit + max_part_len = max( + len(p.get("text", "") or p.get("reasoningText", "")) + for p in artifact_parts + ) + + if max_part_len > TEXT_CHUNK_LIMIT: + # Chunk each part separately, preserving kind + for part in artifact_parts: + kind = part.get("kind", "text") + content = part.get("text", "") or part.get("reasoningText", "") + if len(content) > TEXT_CHUNK_LIMIT: + chunks = self._chunk_text(content) + for chunk in chunks: + if kind == "reasoningText": + await self._send_reasoning_chunk( + session_id, + task_id, + message_id, + chunk, + ) + else: + await self._send_chunk( + session_id, + task_id, + message_id, + chunk, + ) + else: + # Send small parts as-is + if kind == "reasoningText": + await self._send_reasoning_chunk( + session_id, + task_id, + message_id, + content, + ) + else: + await self._send_chunk( + session_id, + task_id, + message_id, + content, + ) + return + + # Send as single message with proper parts + artifact_id = f"artifact_{uuid.uuid4().hex[:16]}" + + json_rpc_response = { + 
"jsonrpc": "2.0", + "id": message_id, + "result": { + "taskId": task_id, + "kind": "artifact-update", + "append": True, + "lastChunk": False, + "final": False, + "artifact": { + "artifactId": artifact_id, + "parts": artifact_parts, + }, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": task_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send parts: {e}") + + async def on_event_message_completed( + self, + request: "AgentRequest", + to_handle: str, + event: Any, + send_meta: Dict[str, Any], + ) -> None: + """Override to handle XiaoYi-specific message formatting. + + Separates thinking/reasoning content from regular text. + """ + # Extract parts with proper kinds + parts = self._extract_xiaoyi_parts(event) + + if not parts: + logger.debug("XiaoYi: No parts to send for message") + return + + # Send with XiaoYi format + await self.send_xiaoyi_parts(to_handle, parts, send_meta) + + def resolve_session_id( + self, + sender_id: str, + channel_meta: Optional[Dict[str, Any]] = None, + ) -> str: + """Resolve session ID from sender and meta.""" + if channel_meta and channel_meta.get("session_id"): + return f"xiaoyi:{channel_meta['session_id']}" + return f"xiaoyi:{sender_id}" + + def get_to_handle_from_request(self, request: "AgentRequest") -> str: + """Get send target from request.""" + meta = getattr(request, "channel_meta", None) or {} + if meta.get("session_id"): + return meta["session_id"] + return getattr(request, "user_id", "") or "" + + def build_agent_request_from_native( + self, + native_payload: Any, + ) -> "AgentRequest": + """Build AgentRequest from native payload.""" + payload = native_payload if isinstance(native_payload, dict) else {} + + channel_id = payload.get("channel_id") or self.channel + sender_id = payload.get("sender_id") or "" + content_parts = payload.get("content_parts") or 
[] + meta = payload.get("meta") or {} + + session_id = self.resolve_session_id(sender_id, meta) + + request = self.build_agent_request_from_user_content( + channel_id=channel_id, + sender_id=sender_id, + session_id=session_id, + content_parts=content_parts, + channel_meta=meta, + ) + request.user_id = sender_id + request.channel_meta = meta + return request + + def to_handle_from_target(self, *, user_id: str, session_id: str) -> str: + """Map dispatch target to channel-specific to_handle.""" + if session_id.startswith("xiaoyi:"): + return session_id.split(":", 1)[-1] + return user_id + + async def _run_process_loop( + self, + request: "AgentRequest", + to_handle: str, + send_meta: Dict[str, Any], + ) -> None: + """Run process and send events. Override to send final message.""" + from agentscope_runtime.engine.schemas.agent_schemas import RunStatus + + last_response = None + session_id = send_meta.get("session_id") or to_handle + + try: + async for event in self._process(request): + obj = getattr(event, "object", None) + status = getattr(event, "status", None) + if obj == "message" and status == RunStatus.Completed: + await self.on_event_message_completed( + request, + to_handle, + event, + send_meta, + ) + elif obj == "response": + last_response = event + await self.on_event_response(request, event) + + # Send final message to end the stream + task_id = send_meta.get("task_id") or self._session_task_map.get( + session_id, + ) + message_id = str(uuid.uuid4()) + + if task_id and session_id: + await self.send_final_message(session_id, task_id, message_id) + + err_msg = self._get_response_error_message(last_response) + if err_msg: + await self._on_consume_error( + request, + to_handle, + f"Error: {err_msg}", + ) + if self._on_reply_sent: + args = self.get_on_reply_sent_args(request, to_handle) + self._on_reply_sent(self.channel, *args) + except Exception: + logger.exception("XiaoYi channel consume_one failed") + await self._on_consume_error( + request, + to_handle, + 
"An error occurred while processing your request.", + ) diff --git a/src/copaw/app/channels/xiaoyi/constants.py b/src/copaw/app/channels/xiaoyi/constants.py new file mode 100644 index 000000000..3f6649a8d --- /dev/null +++ b/src/copaw/app/channels/xiaoyi/constants.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +"""XiaoYi channel constants.""" + +# Default WebSocket URL +DEFAULT_WS_URL = "wss://hag.cloud.huawei.com/openclaw/v1/ws/link" + +# Heartbeat interval (seconds) +HEARTBEAT_INTERVAL = 30 + +# Reconnect delays (seconds) +RECONNECT_DELAYS = [1, 2, 5, 10, 30, 60] +MAX_RECONNECT_ATTEMPTS = 50 + +# Connection timeout (seconds) +CONNECTION_TIMEOUT = 30 + +# Task timeout (milliseconds) +DEFAULT_TASK_TIMEOUT_MS = 3600000 # 1 hour + +# Maximum text chunk size (characters) +# Larger messages will be split to avoid WebSocket disconnection +TEXT_CHUNK_LIMIT = 4000 diff --git a/src/copaw/config/config.py b/src/copaw/config/config.py index 2e8757464..be6595131 100644 --- a/src/copaw/config/config.py +++ b/src/copaw/config/config.py @@ -124,6 +124,16 @@ class VoiceChannelConfig(BaseChannelConfig): welcome_greeting: str = "Hi! This is CoPaw. How can I help you?" 
+class XiaoYiConfig(BaseChannelConfig): + """XiaoYi channel: Huawei A2A protocol via WebSocket.""" + + ak: str = "" # Access Key + sk: str = "" # Secret Key + agent_id: str = "" # Agent ID from XiaoYi platform + ws_url: str = "wss://hag.cloud.huawei.com/openclaw/v1/ws/link" + task_timeout_ms: int = 3600000 # 1 hour task timeout + + class ChannelConfig(BaseModel): """Built-in channel configs; extra keys allowed for plugin channels.""" @@ -140,6 +150,7 @@ class ChannelConfig(BaseModel): console: ConsoleConfig = ConsoleConfig() matrix: MatrixConfig = MatrixConfig() voice: VoiceChannelConfig = VoiceChannelConfig() + xiaoyi: XiaoYiConfig = XiaoYiConfig() class LastApiConfig(BaseModel): @@ -501,4 +512,5 @@ class Config(BaseModel): ConsoleConfig, MatrixConfig, VoiceChannelConfig, + XiaoYiConfig, ] diff --git a/website/public/docs/channels.en.md b/website/public/docs/channels.en.md index c2681bdc5..f32c98a4c 100644 --- a/website/public/docs/channels.en.md +++ b/website/public/docs/channels.en.md @@ -641,6 +641,26 @@ Invite the bot to a room or send it a direct message from any Matrix client (e.g --- +## XiaoYi + +The XiaoYi channel connects CoPaw via **A2A (Agent-to-Agent) protocol** over WebSocket to Huawei's AI assistant platform. + +### Get credentials + +1. Create an agent in the XiaoYi Open Platform. +2. Obtain **AK** (Access Key), **SK** (Secret Key), and **Agent ID**. 
+ +### Core Config + +| Field | Description | Default | +| ------------ | ----------------------- | ------------------------------------------------ | +| **ak** | Access Key | - | +| **sk** | Secret Key | - | +| **agent_id** | Agent unique identifier | - | +| **ws_url** | WebSocket URL | `wss://hag.cloud.huawei.com/openclaw/v1/ws/link` | + +--- + ## Appendix ### Config overview @@ -655,6 +675,7 @@ Invite the bot to a room or send it a direct message from any Matrix client (e.g | Telegram | telegram | bot_token; optional http_proxy, http_proxy_auth | | Mattermost | mattermost | url, bot_token; optional show_typing, dm_policy, allow_from | | Matrix | matrix | homeserver, user_id, access_token | +| XiaoYi | xiaoyi | ak, sk, agent_id; optional ws_url | Field details and structure are in the tables above and [Config & working dir](./config). @@ -675,6 +696,7 @@ done). **✗** = not supported (not possible on this channel). | Telegram | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | Mattermost | ✓ | ✓ | 🚧 | 🚧 | ✓ | ✓ | ✓ | 🚧 | 🚧 | ✓ | | Matrix | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | +| XiaoYi | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | Notes: @@ -691,6 +713,7 @@ Notes: currently text + link-only. - **Telegram**: Attachments are parsed as files on receive and can be opened in the corresponding format (image / voice / video / file) within the Telegram chat interface. - **Matrix**: Receives image, video, audio, and file attachments via `mxc://` media URLs. Sends media by uploading to the homeserver and sending native Matrix media messages (`m.image`, `m.video`, `m.audio`, `m.file`). +- **XiaoYi**: Text only; media support is 🚧. 
### Changing config via HTTP diff --git a/website/public/docs/channels.zh.md b/website/public/docs/channels.zh.md index f970e9cc6..600839c81 100644 --- a/website/public/docs/channels.zh.md +++ b/website/public/docs/channels.zh.md @@ -637,6 +637,26 @@ Matrix 频道通过 [matrix-nio](https://github.com/poljar/matrix-nio) 库将 Co --- +## 小艺(XiaoYi) + +小艺通道通过 **A2A (Agent-to-Agent) 协议** 基于 WebSocket 连接华为小艺平台。 + +### 获取凭证 + +1. 在小艺开放平台创建Agent。 +2. 获取 **AK** (Access Key)、**SK** (Secret Key) 和 **Agent ID**。 + +### 核心配置 + +| 字段 | 说明 | 默认值 | +| ------------ | -------------- | ------------------------------------------------ | +| **ak** | 访问密钥 | - | +| **sk** | 密钥 | - | +| **agent_id** | 代理唯一标识 | - | +| **ws_url** | WebSocket 地址 | `wss://hag.cloud.huawei.com/openclaw/v1/ws/link` | + +--- + ## 附录 ### 配置总览 @@ -651,6 +671,7 @@ Matrix 频道通过 [matrix-nio](https://github.com/poljar/matrix-nio) 库将 Co | Telegram | telegram | bot_token;可选 http_proxy, http_proxy_auth | | Mattermost | mattermost | url, bot_token; 可选 show_typing, dm_policy, allow_from | | Matrix | matrix | homeserver, user_id, access_token | +| 小艺 | xiaoyi | ak, sk, agent_id;可选 ws_url | 各频道字段与完整结构见上文表格及 [配置与工作目录](./config)。 @@ -669,6 +690,7 @@ Matrix 频道通过 [matrix-nio](https://github.com/poljar/matrix-nio) 库将 Co | Telegram | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | Mattermost | ✓ | ✓ | 🚧 | 🚧 | ✓ | ✓ | ✓ | 🚧 | 🚧 | ✓ | | Matrix | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | +| 小艺 | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | 说明: @@ -679,6 +701,7 @@ Matrix 频道通过 [matrix-nio](https://github.com/poljar/matrix-nio) 库将 Co - **QQ**:接收侧附件解析为多模态、发送侧真实媒体均为 🚧 施工中,当前仅文本 + 链接形式。 - **Telegram**:接收时附件会解析为文件并传入,可在telegram对话界面以对应格式打开(图片 / 语音 / 视频 / 文件) - **Matrix**:接收图片 / 视频 / 音频 / 文件(通过 `mxc://` 媒体 URL);发送时将文件上传至服务器后以原生 Matrix 媒体消息(`m.image`、`m.video`、`m.audio`、`m.file`)发出。 +- **小艺**:当前仅支持文本。 ### 通过 HTTP 修改配置 From f8aef03926f6513c3dcc188c1abde665dea4cdcd Mon Sep 17 00:00:00 2001 From: hongxicheng <1003394729@qq.com> Date: Mon, 16 Mar 2026 11:37:16 +0800 
Subject: [PATCH 05/68] fest(channel): add WeCom channel (#1407) --- console/src/locales/en.json | 3 + console/src/locales/zh.json | 3 + .../Channels/components/ChannelDrawer.tsx | 30 + .../Control/Channels/components/constants.ts | 1 + pyproject.toml | 1 + src/copaw/app/channels/registry.py | 1 + src/copaw/app/channels/wecom/__init__.py | 6 + src/copaw/app/channels/wecom/channel.py | 812 ++++++++++++++++++ src/copaw/app/channels/wecom/utils.py | 108 +++ src/copaw/config/config.py | 12 + website/public/docs/channels.en.md | 55 ++ website/public/docs/channels.zh.md | 48 ++ 12 files changed, 1080 insertions(+) create mode 100644 src/copaw/app/channels/wecom/__init__.py create mode 100644 src/copaw/app/channels/wecom/channel.py create mode 100644 src/copaw/app/channels/wecom/utils.py diff --git a/console/src/locales/en.json b/console/src/locales/en.json index 0c4bfc218..0fffc7006 100644 --- a/console/src/locales/en.json +++ b/console/src/locales/en.json @@ -278,6 +278,9 @@ "sttProvider": "STT Provider", "language": "Language", "welcomeGreeting": "Welcome Greeting", + "welcomeText": "Welcome Message", + "welcomeTextTooltip": "Message automatically sent when a user enters a single chat session with the bot for the first time that day", + "welcomeTextPlaceholder": "e.g. Hello! I'm CoPaw. How can I help you?", "voiceSetupGuide": "Set up a Twilio account, purchase a phone number, then enter your credentials below. Find your Account SID and Auth Token on the Twilio Console dashboard. The Phone Number SID is listed under Phone Numbers → Active Numbers.", "voiceSetupLink": "Open Twilio Console", "xiaoyiSetupGuide": "Please create an agent on Huawei Developer Platform and get AK/SK and Agent ID. 
AK/SK can be found in the credential management page.", diff --git a/console/src/locales/zh.json b/console/src/locales/zh.json index 168aadff4..c10b3e48a 100644 --- a/console/src/locales/zh.json +++ b/console/src/locales/zh.json @@ -278,6 +278,9 @@ "sttProvider": "STT 提供商", "language": "语言", "welcomeGreeting": "欢迎语", + "welcomeText": "欢迎消息", + "welcomeTextTooltip": "用户当天首次进入机器人单聊会话时机器人自动发送的消息", + "welcomeTextPlaceholder": "例如:你好!我是 CoPaw,有什么可以帮你的?", "voiceSetupGuide": "请先注册 Twilio 账户并购买电话号码,然后在下方填写凭据。Account SID 和 Auth Token 可在 Twilio 控制台首页找到。Phone Number SID 在 Phone Numbers → Active Numbers 中查看。", "voiceSetupLink": "打开 Twilio 控制台", "xiaoyiSetupGuide": "请在华为开发者平台创建智能体并获取 AK/SK 和 Agent ID。AK/SK 可在凭证管理页面找到。", diff --git a/console/src/pages/Control/Channels/components/ChannelDrawer.tsx b/console/src/pages/Control/Channels/components/ChannelDrawer.tsx index 1b0d7e511..03eb1e2e0 100644 --- a/console/src/pages/Control/Channels/components/ChannelDrawer.tsx +++ b/console/src/pages/Control/Channels/components/ChannelDrawer.tsx @@ -19,6 +19,7 @@ const CHANNELS_WITH_ACCESS_CONTROL: ChannelKey[] = [ "dingtalk", "discord", "feishu", + "wecom", "mattermost", "matrix", "xiaoyi", @@ -437,6 +438,35 @@ export function ChannelDrawer({ ); + case "wecom": + return ( + <> + + + + + + + + + + + + + + ); case "xiaoyi": return ( <> diff --git a/console/src/pages/Control/Channels/components/constants.ts b/console/src/pages/Control/Channels/components/constants.ts index 0acf90b72..0dae560e9 100644 --- a/console/src/pages/Control/Channels/components/constants.ts +++ b/console/src/pages/Control/Channels/components/constants.ts @@ -14,6 +14,7 @@ export const CHANNEL_LABELS: Record = { matrix: "Matrix", console: "Console", voice: "Twilio", + wecom: "WeCom", xiaoyi: "XiaoYi", }; diff --git a/pyproject.toml b/pyproject.toml index 2ab5bf624..2fc13f59c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "pywebview>=4.0", "aiofiles>=24.1.0", "paho-mqtt>=2.0.0", + 
"wecom-aibot-sdk @ https://agentscope.oss-cn-zhangjiakou.aliyuncs.com/pre_whl/wecom_aibot_sdk-1.0.0-py3-none-any.whl", "matrix-nio>=0.24.0", ] diff --git a/src/copaw/app/channels/registry.py b/src/copaw/app/channels/registry.py index 98d42e965..cddae6ce9 100644 --- a/src/copaw/app/channels/registry.py +++ b/src/copaw/app/channels/registry.py @@ -28,6 +28,7 @@ "console": (".console", "ConsoleChannel"), "matrix": (".matrix", "MatrixChannel"), "voice": (".voice", "VoiceChannel"), + "wecom": (".wecom", "WecomChannel"), "xiaoyi": (".xiaoyi", "XiaoYiChannel"), } diff --git a/src/copaw/app/channels/wecom/__init__.py b/src/copaw/app/channels/wecom/__init__.py new file mode 100644 index 000000000..32a770d63 --- /dev/null +++ b/src/copaw/app/channels/wecom/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""WeCom (Enterprise WeChat) channel package.""" + +from .channel import WecomChannel + +__all__ = ["WecomChannel"] diff --git a/src/copaw/app/channels/wecom/channel.py b/src/copaw/app/channels/wecom/channel.py new file mode 100644 index 000000000..0aaf48792 --- /dev/null +++ b/src/copaw/app/channels/wecom/channel.py @@ -0,0 +1,812 @@ +# -*- coding: utf-8 -*- +# pylint: disable=too-many-statements,too-many-branches +# pylint: disable=too-many-return-statements,too-many-instance-attributes +# pylint: disable=too-many-nested-blocks +"""WeCom (Enterprise WeChat) Channel. + +Uses the aibot WebSocket SDK to receive messages from WeCom AI Bot. +Sends replies via the same WebSocket channel using stream mode +(reply_stream). Supports text, image, voice, file, and mixed messages. 
+""" + +from __future__ import annotations + +import asyncio +import hashlib +import logging +import os +import sys +import threading +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, List, Optional + +from agentscope_runtime.engine.schemas.agent_schemas import ( + AgentRequest, + FileContent, + ImageContent, + TextContent, +) +from aibot import WSClient, WSClientOptions, generate_req_id + +from ..base import ( + BaseChannel, + ContentType, + OnReplySent, + OutgoingContentPart, + ProcessHandler, +) +from .utils import format_markdown_tables + +logger = logging.getLogger(__name__) + +# Max number of processed message_ids to keep for dedup. +_WECOM_PROCESSED_IDS_MAX = 2000 + + +class WecomChannel(BaseChannel): + """WeCom AI Bot channel: WebSocket receive and send. + + Session: for single-chat session_id = wecom:, for group-chat + wecom:group:. The frame from the SDK is stored in meta so + we can call reply_stream back through the same connection. 
+ """ + + channel = "wecom" + + def __init__( + self, + process: ProcessHandler, + enabled: bool, + bot_id: str, + secret: str, + bot_prefix: str = "[BOT] ", + media_dir: str = "~/.copaw/media", + welcome_text: str = "", + on_reply_sent: OnReplySent = None, + show_tool_details: bool = True, + filter_tool_messages: bool = False, + filter_thinking: bool = False, + dm_policy: str = "open", + group_policy: str = "open", + allow_from: Optional[List[str]] = None, + deny_message: str = "", + max_reconnect_attempts: int = -1, + ): + super().__init__( + process, + on_reply_sent=on_reply_sent, + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + dm_policy=dm_policy, + group_policy=group_policy, + allow_from=allow_from, + deny_message=deny_message, + ) + self.enabled = enabled + self.bot_id = bot_id + self.secret = secret + self.bot_prefix = bot_prefix + self.welcome_text = welcome_text + self._media_dir = Path(media_dir).expanduser() + self._max_reconnect_attempts = max_reconnect_attempts + + self._client: Any = None + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._ws_thread: Optional[threading.Thread] = None + + # message_id dedup (ordered dict, trimmed when over limit) + self._processed_message_ids: OrderedDict[str, None] = OrderedDict() + self._processed_ids_lock = threading.Lock() + + @classmethod + def from_env( + cls, + process: ProcessHandler, + on_reply_sent: OnReplySent = None, + ) -> "WecomChannel": + allow_from_env = os.getenv("WECOM_ALLOW_FROM", "") + allow_from = ( + [s.strip() for s in allow_from_env.split(",") if s.strip()] + if allow_from_env + else [] + ) + return cls( + process=process, + enabled=os.getenv("WECOM_CHANNEL_ENABLED", "0") == "1", + bot_id=os.getenv("WECOM_BOT_ID", ""), + secret=os.getenv("WECOM_SECRET", ""), + bot_prefix=os.getenv("WECOM_BOT_PREFIX", "[BOT] "), + media_dir=os.getenv("WECOM_MEDIA_DIR", "~/.copaw/media"), + on_reply_sent=on_reply_sent, + 
dm_policy=os.getenv("WECOM_DM_POLICY", "open"), + group_policy=os.getenv("WECOM_GROUP_POLICY", "open"), + allow_from=allow_from, + deny_message=os.getenv("WECOM_DENY_MESSAGE", ""), + max_reconnect_attempts=int( + os.getenv("WECOM_MAX_RECONNECT_ATTEMPTS", "-1"), + ), + ) + + @classmethod + def from_config( + cls, + process: ProcessHandler, + config: Any, + on_reply_sent: OnReplySent = None, + show_tool_details: bool = True, + filter_tool_messages: bool = False, + filter_thinking: bool = False, + ) -> "WecomChannel": + return cls( + process=process, + enabled=getattr(config, "enabled", False), + bot_id=getattr(config, "bot_id", "") or "", + secret=getattr(config, "secret", "") or "", + bot_prefix=getattr(config, "bot_prefix", "[BOT] ") or "[BOT] ", + media_dir=( + getattr(config, "media_dir", "~/.copaw/media") + or "~/.copaw/media" + ), + welcome_text=getattr(config, "welcome_text", "") or "", + on_reply_sent=on_reply_sent, + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + dm_policy=getattr(config, "dm_policy", "open") or "open", + group_policy=getattr(config, "group_policy", "open") or "open", + allow_from=getattr(config, "allow_from", []) or [], + deny_message=getattr(config, "deny_message", "") or "", + max_reconnect_attempts=int( + -1 + if getattr(config, "max_reconnect_attempts", None) is None + else getattr(config, "max_reconnect_attempts"), + ), + ) + + # ------------------------------------------------------------------ + # Session / handle helpers + # ------------------------------------------------------------------ + + def resolve_session_id( + self, + sender_id: str, + channel_meta: Optional[Dict[str, Any]] = None, + ) -> str: + """Build session_id from meta or sender_id.""" + meta = channel_meta or {} + chatid = (meta.get("wecom_chatid") or "").strip() + chat_type = (meta.get("wecom_chat_type") or "single").strip() + if chat_type == "group" and chatid: + return f"wecom:group:{chatid}" 
+ if sender_id: + return f"wecom:{sender_id}" + return f"wecom:{chatid or 'unknown'}" + + @staticmethod + def _parse_chatid_from_handle(to_handle: str) -> str: + """Extract chatid/userid from a to_handle string. + + - ``wecom:group:`` → ```` + - ``wecom:`` → ```` + """ + h = (to_handle or "").strip() + if h.startswith("wecom:group:"): + return h[len("wecom:group:") :] + if h.startswith("wecom:"): + return h[len("wecom:") :] + return h + + def to_handle_from_target(self, *, user_id: str, session_id: str) -> str: + """Return send handle; session_id takes priority.""" + return session_id or f"wecom:{user_id}" + + def get_to_handle_from_request(self, request: Any) -> str: + session_id = getattr(request, "session_id", "") or "" + user_id = getattr(request, "user_id", "") or "" + return session_id or f"wecom:{user_id}" + + def get_on_reply_sent_args( + self, + request: Any, + to_handle: str, + ) -> tuple: + return ( + getattr(request, "user_id", "") or "", + getattr(request, "session_id", "") or "", + ) + + def build_agent_request_from_native( + self, + native_payload: Any, + ) -> "AgentRequest": + """Build AgentRequest from a wecom native dict.""" + payload = native_payload if isinstance(native_payload, dict) else {} + channel_id = payload.get("channel_id") or self.channel + sender_id = payload.get("sender_id") or "" + content_parts = payload.get("content_parts") or [] + meta = payload.get("meta") or {} + session_id = payload.get("session_id") or self.resolve_session_id( + sender_id, + meta, + ) + user_id = payload["user_id"] if "user_id" in payload else sender_id + request = self.build_agent_request_from_user_content( + channel_id=channel_id, + sender_id=user_id, + session_id=session_id, + content_parts=content_parts, + channel_meta=meta, + ) + setattr(request, "channel_meta", meta) + return request + + def merge_native_items(self, items: List[Any]) -> Any: + """Merge same-session native payloads: concat content_parts.""" + if not items: + return None + first = 
items[0] if isinstance(items[0], dict) else {} + merged_parts: List[Any] = [] + for it in items: + p = it if isinstance(it, dict) else {} + merged_parts.extend(p.get("content_parts") or []) + last = items[-1] if isinstance(items[-1], dict) else {} + return { + "channel_id": first.get("channel_id") or self.channel, + "sender_id": last.get( + "sender_id", + first.get("sender_id", ""), + ), + "user_id": last.get("user_id", first.get("user_id", "")), + "session_id": last.get( + "session_id", + first.get("session_id", ""), + ), + "content_parts": merged_parts, + "meta": dict(last.get("meta") or {}), + } + + # ------------------------------------------------------------------ + # Message dedup helper + # ------------------------------------------------------------------ + + def _is_duplicate(self, msg_id: str) -> bool: + """Return True if msg_id was already seen; record it.""" + with self._processed_ids_lock: + if msg_id in self._processed_message_ids: + return True + self._processed_message_ids[msg_id] = None + while len(self._processed_message_ids) > _WECOM_PROCESSED_IDS_MAX: + self._processed_message_ids.popitem(last=False) + return False + + # ------------------------------------------------------------------ + # Incoming message handlers (called from WS thread, dispatch to loop) + # ------------------------------------------------------------------ + + def _on_message_sync(self, frame: Any) -> None: + """Sync handler called from SDK event; dispatches to async loop.""" + if not self._loop or not self._loop.is_running(): + logger.warning("wecom: main loop not set/running, drop message") + return + asyncio.run_coroutine_threadsafe( + self._on_message(frame), + self._loop, + ) + + async def _on_message(self, frame: Any) -> None: + """Parse and enqueue one incoming message.""" + try: + body = frame.get("body") or {} + msgtype = body.get("msgtype") or "" + sender_id = (body.get("from") or {}).get("userid", "") + chatid = body.get("chatid", "") + chat_type = 
body.get("chattype", "single") + + # Build unique message id for dedup + msg_id = ( + body.get("msgid") or "" + ) or f"{sender_id}_{body.get('send_time', '')}" + if msg_id and self._is_duplicate(msg_id): + return + + content_parts: List[Any] = [] + text_parts: List[str] = [] + + if msgtype == "text": + text = (body.get("text") or {}).get("content", "").strip() + if text: + text_parts.append(text) + + elif msgtype == "image": + img_info = body.get("image") or {} + url = img_info.get("url") or "" + aes_key = img_info.get("aeskey") or "" + if url: + path = await self._download_media( + url, + aes_key=aes_key, + filename_hint="image.jpg", + ) + if path: + content_parts.append( + ImageContent( + type=ContentType.IMAGE, + image_url=path, + ), + ) + else: + text_parts.append("[image: download failed]") + else: + text_parts.append("[image: no url]") + + elif msgtype == "voice": + voice_info = body.get("voice") or {} + # Use ASR text from WeCom; no need to download audio + asr_text = voice_info.get("content", "").strip() + if asr_text: + text_parts.append(asr_text) + else: + text_parts.append("[voice: no text]") + + elif msgtype == "file": + file_info = body.get("file") or {} + url = file_info.get("url") or "" + aes_key = file_info.get("aeskey") or "" + filename = file_info.get("filename") or "file.bin" + if url: + path = await self._download_media( + url, + aes_key=aes_key, + filename_hint=filename, + ) + if path: + content_parts.append( + FileContent( + type=ContentType.FILE, + file_url=path, + ), + ) + else: + text_parts.append("[file: download failed]") + else: + text_parts.append("[file: no url]") + + elif msgtype == "mixed": + # Mixed: list of items, each has msgtype, text or image + mixed_items = body.get("mixed", {}).get("msg_item", []) + for item in mixed_items: + itype = item.get("msgtype") or "" + if itype == "text": + t = item.get("text", {}).get("content", "").strip() + if t: + text_parts.append(t) + elif itype == "image": + img = item.get("image") or {} + url 
= img.get("url") or "" + aes_key = img.get("aeskey") or "" + if url: + path = await self._download_media( + url, + aes_key=aes_key, + filename_hint="image.jpg", + ) + if path: + content_parts.append( + ImageContent( + type=ContentType.IMAGE, + image_url=path, + ), + ) + else: + text_parts.append("[image: download failed]") + else: + text_parts.append(f"[{msgtype}]") + + text = "\n".join(text_parts).strip() + if text: + content_parts.insert( + 0, + TextContent(type=ContentType.TEXT, text=text), + ) + if not content_parts: + return + + is_group = chat_type == "group" + meta: Dict[str, Any] = { + "wecom_sender_id": sender_id, + "wecom_chatid": chatid, + "wecom_chat_type": chat_type, + "wecom_frame": frame, + "is_group": is_group, + } + + allowed, error_msg = self._check_allowlist(sender_id, is_group) + if not allowed: + logger.info( + "wecom allowlist blocked: sender=%s is_group=%s", + sender_id, + is_group, + ) + await self._send_text_via_frame( + frame, + error_msg or "Access denied.", + ) + return + + # Send "processing" indicator only if message has text content + processing_stream_id = "" + if text_parts and self._client: + processing_stream_id = generate_req_id("stream") + try: + await self._client.reply_stream( + frame, + stream_id=processing_stream_id, + content="🤔 思考中...", + finish=False, + ) + except Exception: + logger.debug("wecom failed to send processing indicator") + + session_id = self.resolve_session_id(sender_id, meta) + if processing_stream_id: + meta["wecom_processing_stream_id"] = processing_stream_id + native = { + "channel_id": self.channel, + "sender_id": sender_id, + # Group chats share one session; omit user_id so the + # session file is keyed by session_id only. 
+ "user_id": "" if is_group else sender_id, + "session_id": session_id, + "content_parts": content_parts, + "meta": meta, + } + logger.info( + "wecom recv: sender=%s chatid=%s msgtype=%s text_len=%s", + sender_id[:20], + (chatid or "")[:20], + msgtype, + len(text), + ) + if self._enqueue is not None: + self._enqueue(native) + except Exception: + logger.exception("wecom _on_message failed") + + def _on_enter_chat_sync(self, frame: Any) -> None: + """Sync handler called from SDK event; dispatches to async loop.""" + if not self._loop or not self._loop.is_running(): + logger.warning("wecom: main loop not set/running, drop enter_chat") + return + asyncio.run_coroutine_threadsafe( + self._on_enter_chat(frame), + self._loop, + ) + + async def _on_enter_chat(self, frame: Any) -> None: + """Handle enter_chat event; send welcome reply if configured.""" + logger.info("wecom enter_chat event") + if not self.welcome_text or not self._client: + return + await self._client.reply_welcome( + frame, + {"msgtype": "text", "text": {"content": self.welcome_text}}, + ) + + # ------------------------------------------------------------------ + # File download helper + # ------------------------------------------------------------------ + + async def _download_media( + self, + url: str, + aes_key: str = "", + filename_hint: str = "file.bin", + ) -> Optional[str]: + """Download (and optionally decrypt) media; return local path.""" + if not self._client: + return None + try: + data, filename = await self._client.download_file( + url, + aes_key or None, + ) + fn = filename or filename_hint + # Determine extension from hint if file has none + hint_ext = Path(filename_hint).suffix + if hint_ext and Path(fn).suffix in ("", ".bin", ".file"): + fn = (Path(fn).stem or "file") + hint_ext + self._media_dir.mkdir(parents=True, exist_ok=True) + safe_name = ( + "".join(c for c in fn if c.isalnum() or c in "-_.") or "media" + ) + url_hash = hashlib.md5(url.encode()).hexdigest()[:8] + path = 
self._media_dir / f"wecom_{url_hash}_{safe_name}" + path.write_bytes(data) + return str(path) + except Exception: + logger.exception("wecom _download_media failed url=%s", url[:60]) + return None + + # ------------------------------------------------------------------ + # Send helpers + # ------------------------------------------------------------------ + + async def _send_text_via_frame( + self, + frame: Any, + text: str, + stream_id: str = "", + ) -> None: + """Send a text reply using the SDK reply method (stream finish). + + Args: + frame: WebSocket frame from the incoming message. + text: Content to send. + stream_id: Optional stream ID to overwrite existing message. + If empty, a new UUID is generated. + """ + if not self._client or not text: + return + try: + sid = stream_id or generate_req_id("stream") + await self._client.reply_stream( + frame, + stream_id=sid, + content=text, + finish=True, + ) + except Exception: + logger.exception("wecom _send_text_via_frame failed") + + async def _send_image_via_send_message( + self, + chatid: str, + part: OutgoingContentPart, + ) -> None: + """Send image as markdown inline (best-effort via send_message).""" + if not self._client or not chatid: + return + image_url = getattr(part, "image_url", "") or "" + if not image_url: + return + # WeCom does not support uploading images via WS; use markdown link + try: + await self._client.send_message( + chatid, + { + "msgtype": "markdown", + "markdown": {"content": f"![image]({image_url})"}, + }, + ) + except Exception: + logger.exception("wecom _send_image_via_send_message failed") + + async def send_content_parts( + self, + to_handle: str, + parts: List[OutgoingContentPart], + meta: Optional[Dict[str, Any]] = None, + ) -> None: + """Send text (stream) and media parts back to WeCom.""" + if not self.enabled: + return + m = meta or {} + frame = m.get("wecom_frame") + chatid = ( + m.get("wecom_chatid") + or self._parse_chatid_from_handle(to_handle) + or "" + ) + + prefix = 
m.get("bot_prefix", "") or self.bot_prefix or "" + text_parts: List[str] = [] + media_parts: List[OutgoingContentPart] = [] + + for p in parts: + t = getattr(p, "type", None) or ( + p.get("type") if isinstance(p, dict) else None + ) + text_val = getattr(p, "text", None) or ( + p.get("text") if isinstance(p, dict) else None + ) + refusal_val = getattr(p, "refusal", None) or ( + p.get("refusal") if isinstance(p, dict) else None + ) + if t == ContentType.TEXT and text_val: + text_parts.append(text_val) + elif t == ContentType.REFUSAL and refusal_val: + text_parts.append(refusal_val) + elif t in ( + ContentType.IMAGE, + ContentType.FILE, + ContentType.VIDEO, + ContentType.AUDIO, + ): + media_parts.append(p) + + body = "\n".join(text_parts).strip() + if prefix and body: + body = prefix + body + + # Format markdown tables for WeCom compatibility + body = format_markdown_tables(body) + + # Use processing stream_id to overwrite "thinking..." indicator + # Only first reply uses it; subsequent replies get new stream_id + processing_sid = m.pop("wecom_processing_stream_id", "") + + if body and frame: + await self._send_text_via_frame(frame, body, processing_sid) + elif body and chatid: + # Proactive send without an inbound frame + try: + await self._client.send_message( + chatid, + { + "msgtype": "markdown", + "markdown": {"content": body}, + }, + ) + except Exception: + logger.exception("wecom send_content_parts proactive failed") + + # # the SDK does not support sending media files. 
+ # for part in media_parts: + # pt = getattr(part, "type", None) + # if pt == ContentType.IMAGE and chatid: + # await self._send_image_via_send_message(chatid, part) + # elif pt in ( + # ContentType.FILE, ContentType.AUDIO, ContentType.VIDEO + # ): + # # Send file path/url as markdown link (WS channel limitation) + # file_url = ( + # getattr(part, "file_url", "") + # or getattr(part, "video_url", "") + # or "" + # ) + # if file_url and chatid: + # filename = Path(file_url).name or "file" + # try: + # await self._client.send_message( + # chatid, + # { + # "msgtype": "markdown", + # "markdown": { + # "content": f"[{filename}]({file_url})" + # }, + # }, + # ) + # except Exception: + # logger.exception( + # "wecom send_content_parts file link failed" + # ) + + async def send( + self, + to_handle: str, + text: str, + meta: Optional[Dict[str, Any]] = None, + ) -> None: + """Proactive send: use send_message with markdown body.""" + if not self.enabled: + return + m = meta or {} + chatid = ( + m.get("wecom_chatid") + or self._parse_chatid_from_handle(to_handle) + or "" + ) + frame = m.get("wecom_frame") + prefix = m.get("bot_prefix", "") or self.bot_prefix or "" + body = (prefix + text) if text else prefix + + if not body: + return + + if frame: + await self._send_text_via_frame(frame, body) + elif chatid and self._client: + try: + await self._client.send_message( + chatid, + { + "msgtype": "markdown", + "markdown": {"content": body}, + }, + ) + except Exception: + logger.exception( + "wecom send proactive failed chatid=%s", + chatid, + ) + else: + logger.warning( + "wecom send: no frame/chatid for to_handle=%s", + (to_handle or "")[:40], + ) + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _run_ws_forever(self) -> None: + """Background thread: run SDK event loop forever.""" + # macOS/Python 3.12+ fix: use SelectorEventLoop explicitly + if sys.platform == 
"darwin": + ws_loop = asyncio.SelectorEventLoop() + else: + ws_loop = asyncio.new_event_loop() + asyncio.set_event_loop(ws_loop) + + # Set thread name for debugging + threading.current_thread().name = "wecom-ws" + + try: + # Run connection in the new loop + ws_loop.run_until_complete(self._client.connect()) + ws_loop.run_forever() + except Exception: + logger.exception("wecom WebSocket thread failed") + finally: + try: + # Cancel all pending tasks + pending = asyncio.all_tasks(ws_loop) + for task in pending: + task.cancel() + if pending: + ws_loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True), + ) + ws_loop.run_until_complete(ws_loop.shutdown_asyncgens()) + ws_loop.close() + except Exception: + pass + + async def start(self) -> None: + if not self.enabled: + logger.debug("wecom channel disabled") + return + + if not self.bot_id or not self.secret: + raise RuntimeError( + "WECOM_BOT_ID and WECOM_SECRET are required when " + "the wecom channel is enabled.", + ) + + self._loop = asyncio.get_running_loop() + options = WSClientOptions( + bot_id=self.bot_id, + secret=self.secret, + max_reconnect_attempts=self._max_reconnect_attempts, + ) + self._client = WSClient(options) + + # Register event handlers + self._client.on("message", self._on_message_sync) + self._client.on("event.enter_chat", self._on_enter_chat_sync) + + self._ws_thread = threading.Thread( + target=self._run_ws_forever, + daemon=True, + name="wecom-ws", + ) + self._ws_thread.start() + logger.info( + "wecom channel started (bot_id=%s)", + (self.bot_id or "")[:12], + ) + + async def stop(self) -> None: + if not self.enabled: + return + if self._client: + try: + self._client.disconnect() + except Exception: + pass + if self._ws_thread: + self._ws_thread.join(timeout=5) + self._client = None + logger.info("wecom channel stopped") diff --git a/src/copaw/app/channels/wecom/utils.py b/src/copaw/app/channels/wecom/utils.py new file mode 100644 index 000000000..b32ed5eb2 --- /dev/null +++ 
b/src/copaw/app/channels/wecom/utils.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +"""WeCom channel utilities.""" +from __future__ import annotations + +import re +from typing import List + + +def format_markdown_tables(text: str) -> str: + """Format GFM markdown tables for WeCom compatibility. + + WeCom requires table columns to be properly aligned. + This function normalizes table formatting. + + Args: + text: Input markdown text possibly containing tables. + + Returns: + Text with formatted tables. + """ + lines = text.split("\n") + result: List[str] = [] + i = 0 + in_code_fence = False + while i < len(lines): + line = lines[i] + stripped = line.strip() + # Track fenced code blocks (```), pass through inside lines unchanged. + if stripped.startswith("```"): + in_code_fence = not in_code_fence + result.append(line) + i += 1 + continue + if in_code_fence: + result.append(line) + i += 1 + continue + # Detect table start (line with |) when not inside a code fence + if "|" in line: + # Collect table lines + table_lines: List[str] = [] + while ( + i < len(lines) + and "|" in lines[i] + and not lines[i].strip().startswith("```") + ): + table_lines.append(lines[i]) + i += 1 + # Format and add table + if table_lines: + result.extend(_format_table(table_lines)) + continue + result.append(line) + i += 1 + return "\n".join(result) + + +def _format_table(lines: List[str]) -> List[str]: + """Format a single markdown table.""" + if not lines: + return lines + + # Check if second row is separator (contains only -, :, |, spaces) + sep_pattern = re.compile(r"^[\s\-:|]+$") + has_separator = len(lines) >= 2 and sep_pattern.match(lines[1]) is not None + + # Parse cells, skipping the separator row (it will be rebuilt) + rows: List[List[str]] = [] + for idx, line in enumerate(lines): + if has_separator and idx == 1: + continue # Skip separator row; rebuild it from column widths + cells = [c.strip() for c in line.split("|")] + # Remove empty first/last cells from leading/trailing | + 
if cells and not cells[0]: + cells = cells[1:] + if cells and not cells[-1]: + cells = cells[:-1] + if cells: + rows.append(cells) + + if not rows: + return lines + + # Calculate column widths + col_count = max(len(r) for r in rows) + widths: List[int] = [0] * col_count + for row in rows: + for j in range(col_count): + cell = row[j] if j < len(row) else "" + widths[j] = max(widths[j], len(cell)) + + # Format rows with proper padding, inserting separator after header + formatted: List[str] = [] + for idx, row in enumerate(rows): + padded = [ + (row[j] if j < len(row) else "").ljust(widths[j]) + for j in range(col_count) + ] + formatted.append("| " + " | ".join(padded) + " |") + if idx == 0: + sep = ( + "| " + + " | ".join("-" * max(3, widths[j]) for j in range(col_count)) + + " |" + ) + formatted.append(sep) + + return formatted diff --git a/src/copaw/config/config.py b/src/copaw/config/config.py index be6595131..4fb7ede74 100644 --- a/src/copaw/config/config.py +++ b/src/copaw/config/config.py @@ -102,6 +102,16 @@ class ConsoleConfig(BaseChannelConfig): enabled: bool = True +class WecomConfig(BaseChannelConfig): + """WeCom (Enterprise WeChat) AI Bot channel config.""" + + bot_id: str = "" + secret: str = "" + media_dir: str = "~/.copaw/media" + welcome_text: str = "" + max_reconnect_attempts: int = -1 + + class MatrixConfig(BaseChannelConfig): """Matrix channel configuration.""" @@ -150,6 +160,7 @@ class ChannelConfig(BaseModel): console: ConsoleConfig = ConsoleConfig() matrix: MatrixConfig = MatrixConfig() voice: VoiceChannelConfig = VoiceChannelConfig() + wecom: WecomConfig = WecomConfig() xiaoyi: XiaoYiConfig = XiaoYiConfig() @@ -512,5 +523,6 @@ class Config(BaseModel): ConsoleConfig, MatrixConfig, VoiceChannelConfig, + WecomConfig, XiaoYiConfig, ] diff --git a/website/public/docs/channels.en.md b/website/public/docs/channels.en.md index f32c98a4c..8cc2524e4 100644 --- a/website/public/docs/channels.en.md +++ b/website/public/docs/channels.en.md @@ -431,6 
+431,58 @@ You can also fill them in the Console UI. --- +## WeCom (WeChat Work) + +### Create a new enterprise + +Individual users can first register an account, create a new enterprise, and become an enterprise administrator. + +![Create enterprise](https://img.alicdn.com/imgextra/i2/O1CN01Xg8B3i1EQWAKt5xj0_!!6000000000346-2-tps-2938-1588.png) + +![New account](https://img.alicdn.com/imgextra/i2/O1CN01QzuScv26w6je9Yypg_!!6000000007725-2-tps-2938-1592.png) + +If you already have a WeCom account or are a regular employee of an enterprise, you can directly create an API-mode robot in your current enterprise. + +### Create a bot + +You can create a bot in the admin console by clicking Management Tools → Smart Robot → Create Robot, and select API Mode → Configure via Long Connection. + +![Create robot 1](https://img.alicdn.com/imgextra/i2/O1CN01n4qAEI1deajLveo2B_!!6000000003761-2-tps-2938-1590.png) + +![Create robot 2](https://img.alicdn.com/imgextra/i4/O1CN01kZDNVk1ugHf73ybs2_!!6000000006066-2-tps-2938-1594.png) + +![Create robot 3](https://img.alicdn.com/imgextra/i1/O1CN01Znm7aQ1Tfpe5Ha9WL_!!6000000002410-2-tps-1482-992.png) + +### Bind the bot + +You can bind the bot by filling in the Bot ID and Secret in the Console or `config.json`. 
+ +**Method 1:** Fill in the Console + +![Bind robot](https://img.alicdn.com/imgextra/i2/O1CN01X8NcEj1NrqL0e3AMS_!!6000000001624-2-tps-2732-1390.png) + +**Method 2:** Fill in `config.json` (default file path is `~/.copaw/config.json`) + +Find `wecom` and fill in the corresponding information, for example: + +```json +"wecom": { + "enabled": true, + "dm_policy": "open", + "group_policy": "open", + "bot_id": "your bot_id", + "secret": "your secret", + "media_dir": "~/.copaw/media", + "max_reconnect_attempts": -1 +} +``` + +### Start chatting with the bot in WeCom + +![Start using](https://img.alicdn.com/imgextra/i3/O1CN01ZsmpYr1tq4ViIbO80_!!6000000005952-2-tps-1308-1130.png) + +--- + ## Telegram ### Get Telegram bot credentials @@ -672,6 +724,7 @@ The XiaoYi channel connects CoPaw via **A2A (Agent-to-Agent) protocol** over Web | iMessage | imessage | db_path, poll_sec (macOS only) | | Discord | discord | bot_token; optional http_proxy, http_proxy_auth | | QQ | qq | app_id, client_secret | +| WeCom | wecom | bot_id, secret; optional media_dir, max_reconnect_attempts | | Telegram | telegram | bot_token; optional http_proxy, http_proxy_auth | | Mattermost | mattermost | url, bot_token; optional show_typing, dm_policy, allow_from | | Matrix | matrix | homeserver, user_id, access_token | @@ -693,6 +746,7 @@ done). **✗** = not supported (not possible on this channel). | Discord | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | | iMessage | ✓ | ✗ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | | QQ | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | +| WeCom | ✓ | ✓ | 🚧 | ✓ | ✓ | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | | Telegram | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | Mattermost | ✓ | ✓ | 🚧 | 🚧 | ✓ | ✓ | ✓ | 🚧 | 🚧 | ✓ | | Matrix | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | @@ -712,6 +766,7 @@ Notes: - **QQ**: Receiving attachments as multimodal and sending real media are 🚧; currently text + link-only. 
- **Telegram**: Attachments are parsed as files on receive and can be opened in the corresponding format (image / voice / video / file) within the Telegram chat interface. +- **WeCom**: WebSocket long connection for receiving; markdown/template_card for sending. Supports text, image, voice, and file receiving; sending media is not supported by the SDK (only text via markdown). - **Matrix**: Receives image, video, audio, and file attachments via `mxc://` media URLs. Sends media by uploading to the homeserver and sending native Matrix media messages (`m.image`, `m.video`, `m.audio`, `m.file`). - **XiaoYi**: Text only; media support is 🚧. diff --git a/website/public/docs/channels.zh.md b/website/public/docs/channels.zh.md index 600839c81..fe995623c 100644 --- a/website/public/docs/channels.zh.md +++ b/website/public/docs/channels.zh.md @@ -427,6 +427,51 @@ --- +## 企业微信 + +### 创建新企业 + +个人使用者可先注册账号,创建新企业,成为企业管理员。 +![创建企业](https://img.alicdn.com/imgextra/i2/O1CN01Xg8B3i1EQWAKt5xj0_!!6000000000346-2-tps-2938-1588.png) +![新建账号](https://img.alicdn.com/imgextra/i2/O1CN01QzuScv26w6je9Yypg_!!6000000007725-2-tps-2938-1592.png) + +若已经有企业微信账号或是企业普通员工,可以直接在当前企业创建API模式机器人。 + +### 创建机器人 + +可在管理后台点击管理工具-智能机器人-创建机器人,选择API模式创建-通过长连接配置 +![创建机器人1](https://img.alicdn.com/imgextra/i2/O1CN01n4qAEI1deajLveo2B_!!6000000003761-2-tps-2938-1590.png) +![新建机器人2](https://img.alicdn.com/imgextra/i4/O1CN01kZDNVk1ugHf73ybs2_!!6000000006066-2-tps-2938-1594.png) +![新建机器人3](https://img.alicdn.com/imgextra/i1/O1CN01Znm7aQ1Tfpe5Ha9WL_!!6000000002410-2-tps-1482-992.png) + +### 绑定bot + +可以在Console或是`config.json`填写Bot ID和Secret绑定bot + +**方法一:**在Console填写 +![绑定机器人](https://img.alicdn.com/imgextra/i2/O1CN01X8NcEj1NrqL0e3AMS_!!6000000001624-2-tps-2732-1390.png) + +**方法二:**在`config.json`填写(默认文件路径为`~/.copaw/config.json`) +找到`wecom`,填写对应信息,例如: + +```json +"wecom": { + "enabled": true, + "dm_policy": "open", + "group_policy": "open", + "bot_id": "your bot_id", + "secret": "your secret", + "media_dir":
"~/.copaw/media", + "max_reconnect_attempts": -1 + } +``` + +### 在企业微信开始与机器人聊天 + +![开始使用](https://img.alicdn.com/imgextra/i3/O1CN01ZsmpYr1tq4ViIbO80_!!6000000005952-2-tps-1308-1130.png) + +--- + ## Telegram ### 获取 Telegram 机器人凭证 @@ -668,6 +713,7 @@ Matrix 频道通过 [matrix-nio](https://github.com/poljar/matrix-nio) 库将 Co | iMessage | imessage | db_path, poll_sec(仅 macOS) | | Discord | discord | bot_token;可选 http_proxy, http_proxy_auth | | QQ | qq | app_id, client_secret | +| 企业微信 | wecom | bot_id, secret;可选 media_dir, max_reconnect_attempts | | Telegram | telegram | bot_token;可选 http_proxy, http_proxy_auth | | Mattermost | mattermost | url, bot_token; 可选 show_typing, dm_policy, allow_from | | Matrix | matrix | homeserver, user_id, access_token | @@ -687,6 +733,7 @@ Matrix 频道通过 [matrix-nio](https://github.com/poljar/matrix-nio) 库将 Co | Discord | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | | iMessage | ✓ | ✗ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | | QQ | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | +| 企业微信 | ✓ | ✓ | 🚧 | ✓ | ✓ | ✓ | 🚧 | 🚧 | 🚧 | 🚧 | | Telegram | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | Mattermost | ✓ | ✓ | 🚧 | 🚧 | ✓ | ✓ | ✓ | 🚧 | 🚧 | ✓ | | Matrix | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | @@ -700,6 +747,7 @@ Matrix 频道通过 [matrix-nio](https://github.com/poljar/matrix-nio) 库将 Co - **iMessage**:基于本地 imsg + 数据库轮询,仅支持文本收发;平台/实现限制,无法支持附件(✗)。 - **QQ**:接收侧附件解析为多模态、发送侧真实媒体均为 🚧 施工中,当前仅文本 + 链接形式。 - **Telegram**:接收时附件会解析为文件并传入,可在telegram对话界面以对应格式打开(图片 / 语音 / 视频 / 文件) +- **企业微信**:WebSocket 长连接接收,markdown/template_card 发送;支持接收文本、图片、语音和文件;发送媒体暂不支持(SDK 限制,仅支持通过 markdown 发送文本)。 - **Matrix**:接收图片 / 视频 / 音频 / 文件(通过 `mxc://` 媒体 URL);发送时将文件上传至服务器后以原生 Matrix 媒体消息(`m.image`、`m.video`、`m.audio`、`m.file`)发出。 - **小艺**:当前仅支持文本。 From fdd000bb5b03254607de84db4ce420cb9945be3f Mon Sep 17 00:00:00 2001 From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com> Date: Mon, 16 Mar 2026 13:58:52 +0800 Subject: [PATCH 06/68] feat(CLI): Add `copaw update` command to update CoPaw automatically (#1278) --- 
pyproject.toml | 1 + src/copaw/cli/main.py | 12 + src/copaw/cli/process_utils.py | 236 ++++++ src/copaw/cli/shutdown_cmd.py | 382 ++++++++++ src/copaw/cli/update_cmd.py | 729 ++++++++++++++++++ tests/unit/cli/test_cli_shutdown.py | 238 ++++++ tests/unit/cli/test_cli_update.py | 910 +++++++++++++++++++++++ tests/{ => unit/cli}/test_cli_version.py | 0 8 files changed, 2508 insertions(+) create mode 100644 src/copaw/cli/process_utils.py create mode 100644 src/copaw/cli/shutdown_cmd.py create mode 100644 src/copaw/cli/update_cmd.py create mode 100644 tests/unit/cli/test_cli_shutdown.py create mode 100644 tests/unit/cli/test_cli_update.py rename tests/{ => unit/cli}/test_cli_version.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 2fc13f59c..5afaf9452 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "agentscope==1.0.16.dev0", "agentscope-runtime==1.1.0", "httpx>=0.27.0", + "packaging>=24.0", "discord-py>=2.3", "dingtalk-stream>=0.24.3", "uvicorn>=0.40.0", diff --git a/src/copaw/cli/main.py b/src/copaw/cli/main.py index 04fcb9c8e..9491f5c13 100644 --- a/src/copaw/cli/main.py +++ b/src/copaw/cli/main.py @@ -101,6 +101,16 @@ def _record(label: str, elapsed: float) -> None: _record(".desktop_cmd", time.perf_counter() - _t) +_t = time.perf_counter() +from .update_cmd import update_cmd # noqa: E402 + +_record(".update_cmd", time.perf_counter() - _t) + +_t = time.perf_counter() +from .shutdown_cmd import shutdown_cmd # noqa: E402 + +_record(".shutdown_cmd", time.perf_counter() - _t) + _total = time.perf_counter() - _t0_main _init_timings.append(("(total imports)", _total)) logger.debug("%.3fs (total imports)", _total) @@ -152,3 +162,5 @@ def cli(ctx: click.Context, host: str | None, port: int | None) -> None: cli.add_command(skills_group) cli.add_command(uninstall_cmd) cli.add_command(desktop_cmd) +cli.add_command(update_cmd) +cli.add_command(shutdown_cmd) diff --git a/src/copaw/cli/process_utils.py b/src/copaw/cli/process_utils.py 
new file mode 100644 index 000000000..fabbdc100 --- /dev/null +++ b/src/copaw/cli/process_utils.py @@ -0,0 +1,236 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import csv +import io +import json +import re +import subprocess +import sys +from typing import Optional + + +_PORT_ARG_PATTERN = re.compile(r"(?:^|\s)--port(?:=|\s+)(\d+)(?=\s|$)") + + +def _coerce_optional_int(value: object) -> Optional[int]: + """Best-effort conversion of JSON-decoded values to integers.""" + if value is None: + return None + if isinstance(value, int): + return value + if isinstance(value, str): + try: + return int(value) + except ValueError: + return None + return None + + +def _parse_windows_process_snapshot_json( + payload: str, +) -> dict[int, tuple[Optional[int], str, str]]: + """Parse PowerShell JSON process snapshot output.""" + if not payload.strip(): + return {} + + try: + data = json.loads(payload) + except json.JSONDecodeError: + return {} + + rows = data if isinstance(data, list) else [data] + snapshot: dict[int, tuple[Optional[int], str, str]] = {} + for row in rows: + if not isinstance(row, dict): + continue + + pid_value = row.get("ProcessId") + parent_value = row.get("ParentProcessId") + pid = _coerce_optional_int(pid_value) + if pid is None: + continue + + parent_pid = _coerce_optional_int(parent_value) + + name = str(row.get("Name") or "") + command = str(row.get("CommandLine") or "") + snapshot[pid] = (parent_pid, name, command) + return snapshot + + +def _parse_windows_process_snapshot_csv( + payload: str, +) -> dict[int, tuple[Optional[int], str, str]]: + """Parse WMIC CSV process snapshot output.""" + if not payload.strip(): + return {} + + snapshot: dict[int, tuple[Optional[int], str, str]] = {} + reader = csv.DictReader(io.StringIO(payload)) + for row in reader: + pid_value = (row.get("ProcessId") or "").strip() + if not pid_value.isdigit(): + continue + + parent_value = (row.get("ParentProcessId") or "").strip() + parent_pid = 
int(parent_value) if parent_value.isdigit() else None + pid = int(pid_value) + name = (row.get("Name") or "").strip() + command = (row.get("CommandLine") or "").strip() + snapshot[pid] = (parent_pid, name, command) + return snapshot + + +def _windows_process_snapshot() -> dict[int, tuple[Optional[int], str, str]]: + """Return Windows process info as pid -> (parent_pid, name, cmdline).""" + commands = ( + ( + [ + "powershell", + "-NoProfile", + "-Command", + ( + "Get-CimInstance Win32_Process | " + "Select-Object ProcessId,ParentProcessId,Name," + "CommandLine | ConvertTo-Json -Compress" + ), + ], + _parse_windows_process_snapshot_json, + ), + ( + [ + "wmic", + "process", + "get", + "ProcessId,ParentProcessId,Name,CommandLine", + "/FORMAT:CSV", + ], + _parse_windows_process_snapshot_csv, + ), + ) + + for command, parser in commands: + try: + result = subprocess.run( + command, + capture_output=True, + text=True, + timeout=15, + check=False, + ) + except (OSError, subprocess.TimeoutExpired): + continue + + snapshot = parser(result.stdout or "") + if snapshot: + return snapshot + return {} + + +def _process_table() -> list[tuple[int, str]]: + """Return a best-effort process table as (pid, command line).""" + if sys.platform == "win32": + return [ + (pid, command or name or "") + for pid, ( + _parent_pid, + name, + command, + ) in _windows_process_snapshot().items() + ] + + try: + result = subprocess.run( + ["ps", "-ax", "-o", "pid=", "-o", "command="], + capture_output=True, + text=True, + timeout=10, + check=False, + ) + except (OSError, subprocess.TimeoutExpired): + return [] + + rows: list[tuple[int, str]] = [] + for line in (result.stdout or "").splitlines(): + stripped = line.strip() + if not stripped: + continue + parts = stripped.split(None, 1) + if not parts or not parts[0].isdigit(): + continue + command = parts[1] if len(parts) > 1 else "" + rows.append((int(parts[0]), command)) + return rows + + +def _matches_copaw_cli_command(command: str, *subcommands: 
str) -> bool: + """Return whether command line looks like a CoPaw CLI invocation.""" + lowered = f" {command.lower()}" + return any( + pattern in lowered + for subcommand in subcommands + for pattern in ( + f" -m copaw {subcommand}", + f" copaw {subcommand}", + f"__main__.py {subcommand}", + f'copaw.exe" {subcommand}', + f"copaw.exe {subcommand}", + ) + ) + + +def _is_copaw_service_command(command: str) -> bool: + """Return whether the command line looks like a local CoPaw app.""" + return _matches_copaw_cli_command(command, "app") + + +def _is_copaw_wrapper_process(name: str, command: str) -> bool: + """Return whether the process looks like a CoPaw CLI wrapper.""" + lowered_name = name.lower().removesuffix(".exe") + return lowered_name == "copaw" or _matches_copaw_cli_command( + command, + "app", + "desktop", + ) + + +def _extract_port_from_command(command: str, default: int = 8088) -> int: + """Extract `--port` from a command line when present.""" + match = _PORT_ARG_PATTERN.search(command) + return int(match.group(1)) if match else default + + +def _base_url(host: str, port: int) -> str: + """Build a base URL from host and port.""" + normalized_host = host.strip() + if ":" in normalized_host and not normalized_host.startswith("["): + normalized_host = f"[{normalized_host}]" + return f"http://{normalized_host}:{port}" + + +def _candidate_hosts(host: str | None) -> list[str]: + """Return host variants that can reach a local CoPaw service.""" + if not host: + return [] + + normalized = host.strip() + lowered = normalized.lower().strip("[]") + candidates: list[str] = [] + + def _add(value: str) -> None: + if value and value not in candidates: + candidates.append(value) + + if lowered in {"0.0.0.0", "::"}: + _add("127.0.0.1") + _add("localhost") + if lowered == "::": + _add("::1") + elif lowered == "localhost": + _add("localhost") + _add("127.0.0.1") + _add("::1") + + _add(normalized) + return candidates diff --git a/src/copaw/cli/shutdown_cmd.py 
b/src/copaw/cli/shutdown_cmd.py new file mode 100644 index 000000000..127cbdda3 --- /dev/null +++ b/src/copaw/cli/shutdown_cmd.py @@ -0,0 +1,382 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import os +import signal +import subprocess +import sys +import time +from pathlib import Path +from typing import Optional + +import click + +from .process_utils import ( + _is_copaw_wrapper_process, + _process_table, + _windows_process_snapshot, +) + + +_PROJECT_ROOT = Path(__file__).resolve().parents[3] +_CONSOLE_DIR = (_PROJECT_ROOT / "console").resolve() +_SIGTERM = signal.SIGTERM +_SIGKILL = getattr(signal, "SIGKILL", _SIGTERM) + + +def _backend_port(ctx: click.Context, port: Optional[int]) -> int: + """Resolve backend port from explicit option or global CLI context.""" + if port is not None: + return port + return int((ctx.obj or {}).get("port", 8088)) + + +def _listening_pids_for_port(port: int) -> set[int]: + """Return PIDs currently listening on the given TCP port.""" + if sys.platform == "win32": + try: + result = subprocess.run( + ["netstat", "-ano", "-p", "tcp"], + capture_output=True, + text=True, + timeout=10, + check=False, + ) + except (OSError, subprocess.TimeoutExpired): + return set() + + pids: set[int] = set() + suffix = f":{port}" + for line in (result.stdout or "").splitlines(): + parts = line.split() + if len(parts) < 5: + continue + local_addr = parts[1] + state = parts[3].upper() + if not local_addr.endswith(suffix) or state != "LISTENING": + continue + try: + pids.add(int(parts[4])) + except ValueError: + continue + return pids + + commands = ( + ["lsof", "-nP", f"-iTCP:{port}", "-sTCP:LISTEN", "-t"], + ["fuser", f"{port}/tcp"], + ) + for command in commands: + try: + result = subprocess.run( + command, + capture_output=True, + text=True, + timeout=10, + check=False, + ) + except (OSError, subprocess.TimeoutExpired): + continue + + pids = { + int(token) + for token in (result.stdout or "").split() + if token.isdigit() + } + if 
pids: + return pids + return set() + + +def _find_frontend_dev_pids() -> set[int]: + """Find Vite dev-server processes for this repository's console app.""" + console_dir = str(_CONSOLE_DIR).lower() + matches: set[int] = set() + for pid, command in _process_table(): + lowered = command.lower() + if "vite" in lowered and console_dir in lowered: + matches.add(pid) + continue + if "copaw-console" in lowered and ( + "npm" in lowered + or "pnpm" in lowered + or "yarn" in lowered + or "node" in lowered + ): + matches.add(pid) + return matches + + +def _find_desktop_wrapper_pids() -> set[int]: + """Find `copaw desktop` wrapper processes for this project.""" + matches: set[int] = set() + patterns = ( + " -m copaw desktop", + " copaw desktop", + "__main__.py desktop", + ) + for pid, command in _process_table(): + lowered = f" {command.lower()}" + if any(pattern in lowered for pattern in patterns): + matches.add(pid) + return matches + + +def _find_windows_wrapper_ancestor_pids(pids: set[int]) -> set[int]: + """Find CoPaw wrapper/supervisor ancestors for Windows backend PIDs.""" + if sys.platform != "win32" or not pids: + return set() + + snapshot = _windows_process_snapshot() + matches: set[int] = set() + for pid in pids: + visited: set[int] = set() + current_pid = pid + while True: + info = snapshot.get(current_pid) + if info is None: + break + + parent_pid = info[0] + if parent_pid in (None, 0) or parent_pid in visited: + break + visited.add(parent_pid) + + parent_info = snapshot.get(parent_pid) + if parent_info is None: + break + + if _is_copaw_wrapper_process(parent_info[1], parent_info[2]): + matches.add(parent_pid) + + current_pid = parent_pid + return matches + + +def _child_pids_unix(pid: int) -> set[int]: + """Recursively collect child PIDs for Unix-like systems.""" + children: set[int] = set() + stack = [pid] + while stack: + current = stack.pop() + try: + result = subprocess.run( + ["pgrep", "-P", str(current)], + capture_output=True, + text=True, + timeout=5, + 
check=False, + ) + except (OSError, subprocess.TimeoutExpired): + continue + for token in (result.stdout or "").split(): + if not token.isdigit(): + continue + child = int(token) + if child in children: + continue + children.add(child) + stack.append(child) + return children + + +def _pid_exists(pid: int) -> bool: + """Return whether the PID still exists.""" + if pid <= 0: + return False + if sys.platform == "win32": + return pid in _windows_process_snapshot() + try: + os.kill(pid, 0) + except OSError: + return False + return True + + +def _wait_for_pid_exit( + pid: int, + timeout_sec: float, + interval_sec: float, +) -> bool: + """Wait until a PID exits within the given timeout.""" + deadline = time.monotonic() + timeout_sec + while time.monotonic() < deadline: + if not _pid_exists(pid): + return True + time.sleep(interval_sec) + return not _pid_exists(pid) + + +def _signal_process_tree_unix(pid: int, sig: signal.Signals) -> None: + """Send a signal to a Unix process and its descendants.""" + descendants = sorted(_child_pids_unix(pid), reverse=True) + for child_pid in descendants: + try: + os.kill(child_pid, sig) + except OSError: + continue + try: + os.kill(pid, sig) + except OSError: + pass + + +def _terminate_process_tree_windows(pid: int, force: bool = False) -> None: + """Terminate a Windows process tree.""" + command = ["taskkill", "/T", "/PID", str(pid)] + if force: + command.insert(1, "/F") + try: + subprocess.run( + command, + capture_output=True, + text=True, + timeout=10, + check=False, + ) + except (OSError, subprocess.TimeoutExpired): + pass + + +def _force_terminate_windows_process(pid: int) -> None: + """Force terminate a Windows process as a fallback.""" + commands = ( + [ + "powershell", + "-NoProfile", + "-Command", + ( + "$ErrorActionPreference='SilentlyContinue'; " + f"Stop-Process -Id {pid} -Force" + ), + ], + ["taskkill", "/F", "/PID", str(pid)], + ) + for command in commands: + try: + subprocess.run( + command, + capture_output=True, + 
text=True, + timeout=10, + check=False, + ) + except (OSError, subprocess.TimeoutExpired): + continue + + +def _terminate_pid(pid: int, timeout_sec: float = 5.0) -> bool: + """Terminate a process tree gracefully, then force kill if needed.""" + if not _pid_exists(pid): + return True + + if sys.platform == "win32": + _terminate_process_tree_windows(pid) + else: + _signal_process_tree_unix(pid, _SIGTERM) + + if _wait_for_pid_exit(pid, timeout_sec, 0.2): + return True + + if sys.platform == "win32": + _terminate_process_tree_windows(pid, force=True) + if _wait_for_pid_exit(pid, 2.0, 0.1): + return True + _force_terminate_windows_process(pid) + else: + _signal_process_tree_unix(pid, _SIGKILL) + + return _wait_for_pid_exit(pid, 2.0, 0.1) + + +def _stop_pid_set(pids: set[int]) -> tuple[list[int], list[int]]: + """Stop a set of PIDs and return (stopped, failed).""" + stopped: list[int] = [] + failed: list[int] = [] + for pid in sorted(pids): + if _terminate_pid(pid): + stopped.append(pid) + else: + failed.append(pid) + return stopped, failed + + +@click.command("shutdown", help="Force stop the running CoPaw app processes.") +@click.option( + "--port", + default=None, + type=int, + help="Backend port to stop. Defaults to global --port from config.", +) +@click.pass_context +def shutdown_cmd(ctx: click.Context, port: Optional[int]) -> None: + """Stop the running CoPaw app processes. + + `copaw app` only starts the backend process. The web console is normally + static files served by that backend. During frontend development, a + separate Vite process may also be running from the repository's + `console/` directory, and this command will stop that as well. + """ + backend_port = _backend_port(ctx, port) + backend_pids = _listening_pids_for_port(backend_port) + frontend_pids = _find_frontend_dev_pids() + desktop_pids = _find_desktop_wrapper_pids() + wrapper_pids = _find_windows_wrapper_ancestor_pids(backend_pids) + + # Build a process table for logging. 
+ proc_table = dict(_process_table()) + + def log_pid_set(title, pids): + if not pids: + click.echo(f"{title}: nothing to stop") + return + click.echo(f"{title} ({len(pids)} total):") + for pid in sorted(pids): + cmd = proc_table.get(pid, "") + click.echo(f" PID {pid}: {cmd}") + + log_pid_set("Backend listener processes", backend_pids) + log_pid_set("Frontend development processes", frontend_pids) + log_pid_set("Desktop wrapper processes", desktop_pids) + log_pid_set("Related wrapper processes", wrapper_pids) + + all_targets = backend_pids | frontend_pids | desktop_pids | wrapper_pids + if not all_targets: + raise click.ClickException( + "No running CoPaw backend/frontend process was found.", + ) + + wrapper_stopped, wrapper_failed = _stop_pid_set(wrapper_pids) + frontend_stopped, frontend_failed = _stop_pid_set(frontend_pids) + desktop_stopped, desktop_failed = _stop_pid_set( + desktop_pids - set(wrapper_stopped) - set(frontend_stopped), + ) + backend_stopped, backend_failed = _stop_pid_set( + backend_pids + - set(wrapper_stopped) + - set(frontend_stopped) + - set(desktop_stopped), + ) + + stopped = ( + wrapper_stopped + frontend_stopped + desktop_stopped + backend_stopped + ) + failed = list( + set( + wrapper_failed + frontend_failed + desktop_failed + backend_failed, + ), + ) + + if stopped: + click.echo( + "Stopped CoPaw processes: " + + ", ".join(str(pid) for pid in sorted(stopped)), + ) + if failed: + click.echo("Failed to stop the following processes:") + for pid in sorted(failed): + cmd = proc_table.get(pid, "") + click.echo(f" PID {pid}: {cmd}") + raise click.ClickException( + "Failed to shutdown process(es): " + + ", ".join(str(pid) for pid in sorted(failed)), + ) diff --git a/src/copaw/cli/update_cmd.py b/src/copaw/cli/update_cmd.py new file mode 100644 index 000000000..c53d2881b --- /dev/null +++ b/src/copaw/cli/update_cmd.py @@ -0,0 +1,729 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import json +import os +import signal +import 
shutil +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from importlib import metadata +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import click +import httpx +from packaging.version import InvalidVersion, Version + +from ..__version__ import __version__ +from ..constant import WORKING_DIR +from ..config.utils import read_last_api +from .process_utils import ( + _base_url, + _candidate_hosts, + _extract_port_from_command, + _is_copaw_service_command, + _process_table, +) + +_PYPI_JSON_URL = "https://pypi.org/pypi/copaw/json" + + +def _subprocess_text_kwargs() -> dict[str, Any]: + """Return robust text-decoding settings for subprocess output. + + Package installers may emit UTF-8 regardless of the active Windows code + page. Using replacement for undecodable bytes prevents the update worker + from crashing while streaming output. + """ + return { + "text": True, + "encoding": "utf-8", + "errors": "replace", + } + + +@dataclass(frozen=True) +class InstallInfo: + """Information about the current CoPaw installation.""" + + package_dir: str + python_executable: str + environment_root: str + environment_kind: str + installer: str + source_type: str + source_url: str | None = None + + +@dataclass(frozen=True) +class RunningServiceInfo: + """Detected CoPaw service endpoint state.""" + + is_running: bool + base_url: str | None = None + version: str | None = None + + +def _version_obj(version: str) -> Any: + """Parse version when possible; otherwise keep the raw string.""" + try: + return Version(version) + except InvalidVersion: + return version + + +def _is_newer_version(latest: str, current: str) -> bool | None: + """Return whether latest is newer than current. + + Returns `None` when either version cannot be compared reliably. 
+ """ + parsed_latest = _version_obj(latest) + parsed_current = _version_obj(current) + if isinstance(parsed_latest, str) or isinstance(parsed_current, str): + if latest == current: + return False + return None + return parsed_latest > parsed_current + + +def _fetch_latest_version() -> str: + """Fetch the latest published CoPaw version from PyPI.""" + try: + resp = httpx.get( + _PYPI_JSON_URL, + timeout=10.0, + headers={"Accept": "application/json"}, + ) + resp.raise_for_status() + data = resp.json() + except httpx.HTTPError as exc: + raise click.ClickException( + f"Failed to fetch the latest CoPaw version from PyPI: {exc}", + ) from exc + except json.JSONDecodeError as exc: + raise click.ClickException( + "Received an invalid response from PyPI when checking for the " + f"latest CoPaw version: {exc}", + ) from exc + version = str(data.get("info", {}).get("version", "")).strip() + if not version: + raise click.ClickException( + "Unable to determine the latest CoPaw version.", + ) + return version + + +def _detect_source_type( + direct_url: dict[str, Any] | None, +) -> tuple[str, str | None]: + """Classify the current installation origin.""" + if not direct_url: + return ("pypi", None) + + url = direct_url.get("url") + dir_info = direct_url.get("dir_info") or {} + if dir_info.get("editable"): + return ("editable", url) + if direct_url.get("vcs_info"): + return ("vcs", url) + if isinstance(url, str) and url.startswith("file://"): + return ("local", url) + return ("direct-url", url if isinstance(url, str) else None) + + +def _detect_installation() -> InstallInfo: + """Inspect the current Python environment and installation style.""" + dist = metadata.distribution("copaw") + # if installed through uv, installer will be `uv` + installer = (dist.read_text("INSTALLER") or "pip").strip() or "pip" + + direct_url: dict[str, Any] | None = None + direct_url_text = dist.read_text("direct_url.json") + if direct_url_text: + try: + direct_url = json.loads(direct_url_text) + except 
json.JSONDecodeError: + direct_url = None + + source_type, source_url = _detect_source_type(direct_url) + package_dir = Path(__file__).resolve().parent.parent + python_executable = sys.executable + environment_root = Path(sys.prefix).resolve() + environment_kind = ( + "virtualenv" if sys.prefix != sys.base_prefix else "system" + ) + + return InstallInfo( + package_dir=str(package_dir), + python_executable=str(python_executable), + environment_root=str(environment_root), + environment_kind=environment_kind, + installer=installer, + source_type=source_type, + source_url=source_url, + ) + + +def _probe_service(base_url: str) -> RunningServiceInfo: + """Probe a possible running CoPaw HTTP service.""" + try: + resp = httpx.get( + f"{base_url.rstrip('/')}/api/version", + timeout=2.0, + headers={"Accept": "application/json"}, + trust_env=False, + ) + resp.raise_for_status() + payload = resp.json() + except (httpx.HTTPError, ValueError): + return RunningServiceInfo(is_running=False) + + version = payload.get("version") if isinstance(payload, dict) else None + return RunningServiceInfo( + is_running=True, + base_url=base_url.rstrip("/"), + version=str(version) if version else None, + ) + + +def _process_candidate_ports() -> list[int]: + """Infer candidate local CoPaw service ports from running processes.""" + ports: list[int] = [] + for _pid, command in _process_table(): + if not _is_copaw_service_command(command): + continue + + port = _extract_port_from_command(command) + if port not in ports: + ports.append(port) + return ports + + +def _detect_running_service_from_processes( + preferred_hosts: list[str], +) -> RunningServiceInfo: + """Best-effort local process fallback for service detection.""" + for port in _process_candidate_ports(): + hosts = preferred_hosts or ["127.0.0.1", "localhost"] + for host in hosts: + result = _probe_service(_base_url(host, port)) + if result.is_running: + return result + + fallback_host = next(iter(hosts), "127.0.0.1") + return 
RunningServiceInfo( + is_running=True, + base_url=_base_url(fallback_host, port), + ) + + return RunningServiceInfo(is_running=False) + + +def _detect_running_service( + host: str | None, + port: int | None, +) -> RunningServiceInfo: + """Detect whether a CoPaw HTTP service is currently running.""" + candidates: list[str] = [] + seen: set[str] = set() + preferred_hosts: list[str] = [] + + def _remember_hosts(candidate_host: str | None) -> None: + for item in _candidate_hosts(candidate_host): + if item not in preferred_hosts: + preferred_hosts.append(item) + + def _add_candidate( + candidate_host: str | None, + candidate_port: int | None, + ) -> None: + if not candidate_host or candidate_port is None: + return + _remember_hosts(candidate_host) + for resolved_host in _candidate_hosts(candidate_host): + base_url = _base_url(resolved_host, candidate_port) + if base_url in seen: + continue + seen.add(base_url) + candidates.append(base_url) + + _add_candidate(host, port) + last = read_last_api() + if last: + _add_candidate(last[0], last[1]) + _add_candidate("127.0.0.1", 8088) + + for base_url in candidates: + result = _probe_service(base_url) + if result.is_running: + return result + + return _detect_running_service_from_processes(preferred_hosts) + + +def _running_service_display(running: RunningServiceInfo) -> str: + """Build a concise running-service description for user prompts.""" + if not running.base_url: + return "a running CoPaw service" + version_suffix = f" (version {running.version})" if running.version else "" + return f"CoPaw service at {running.base_url}{version_suffix}" + + +def _confirm_force_shutdown(running: RunningServiceInfo) -> bool: + """Ask whether `copaw shutdown` should be used before updating.""" + click.echo("") + click.secho("!" * 72, fg="yellow", bold=True) + click.secho( + "WARNING: RUNNING COPAW SERVICE DETECTED", + fg="yellow", + bold=True, + ) + click.secho("!" 
* 72, fg="yellow", bold=True) + click.secho( + f"Detected {_running_service_display(running)}.", + fg="yellow", + bold=True, + ) + click.secho( + "Running `copaw shutdown` will forcibly terminate the current " + "CoPaw backend/frontend processes.", + fg="red", + bold=True, + ) + click.secho( + "Active requests, background tasks, or unsaved work may be " + "interrupted immediately.", + fg="red", + bold=True, + ) + click.echo("") + return click.confirm( + "Run `copaw shutdown` now and continue with the update?", + default=False, + ) + + +def _run_shutdown_for_update( + info: InstallInfo, + running: RunningServiceInfo, +) -> None: + """Run `copaw shutdown` in the current environment before updating.""" + command = [info.python_executable, "-m", "copaw"] + parsed = urlparse(running.base_url or "") + if parsed.port is not None: + command.extend(["--port", str(parsed.port)]) + command.append("shutdown") + + click.echo("") + click.echo("Running `copaw shutdown` before updating...") + + try: + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + **_subprocess_text_kwargs(), + check=False, + ) + except OSError as exc: + raise click.ClickException( + "Failed to run `copaw shutdown`: " f"{exc}", + ) from exc + + output = (result.stdout or "").strip() + if output: + click.echo(output) + + if result.returncode != 0: + raise click.ClickException( + "`copaw shutdown` failed. 
Please stop the running CoPaw " + "service manually before running `copaw update`.", + ) + + +def _build_upgrade_command( + info: InstallInfo, + latest_version: str, +) -> tuple[list[str], str]: + """Build the installer command used by the detached update worker.""" + package_spec = f"copaw=={latest_version}" + installer = info.installer.lower() + if installer.startswith("uv") and shutil.which("uv"): + return ( + [ + "uv", + "pip", + "install", + "--python", + info.python_executable, + "--upgrade", + package_spec, + "--prerelease=allow", + ], + "uv pip", + ) + return ( + [ + info.python_executable, + "-m", + "pip", + "install", + "--upgrade", + package_spec, + "--disable-pip-version-check", + ], + "pip", + ) + + +def _plan_dir() -> Path: + """Directory used to persist short-lived update worker plans.""" + return WORKING_DIR / "updates" + + +def _write_worker_plan(plan: dict[str, Any]) -> Path: + """Persist a worker plan for the detached process.""" + plan_dir = _plan_dir() + plan_dir.mkdir(parents=True, exist_ok=True) + plan_path = plan_dir / f"update-{int(time.time() * 1000)}.json" + plan_path.write_text( + json.dumps(plan, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + return plan_path + + +def _spawn_update_worker( + plan_path: Path, + *, + capture_output: bool = True, +) -> subprocess.Popen[str]: + """Spawn the worker that performs the actual package upgrade.""" + worker_code = ( + "from copaw.cli.update_cmd import run_update_worker; " + "import sys; " + "sys.exit(run_update_worker(sys.argv[1]))" + ) + kwargs: dict[str, Any] = {"stdin": subprocess.DEVNULL} + if capture_output: + kwargs.update( + { + "stdout": subprocess.PIPE, + "stderr": subprocess.STDOUT, + **_subprocess_text_kwargs(), + "bufsize": 1, + }, + ) + if sys.platform == "win32": + kwargs["creationflags"] = getattr( + subprocess, + "CREATE_NEW_PROCESS_GROUP", + 0, + ) + else: + kwargs["start_new_session"] = True + + return subprocess.Popen( # pylint: disable=consider-using-with + 
[sys.executable, "-u", "-c", worker_code, str(plan_path)], + **kwargs, + ) + + +def _terminate_update_worker(proc: subprocess.Popen[str]) -> None: + """Best-effort termination for the worker and its installer child.""" + if proc.poll() is not None: + return + + try: + if sys.platform == "win32": + ctrl_break = getattr(signal, "CTRL_BREAK_EVENT", None) + if ctrl_break is not None: + proc.send_signal(ctrl_break) + try: + proc.wait(timeout=5) + return + except subprocess.TimeoutExpired: + pass + proc.terminate() + else: + os.killpg(proc.pid, signal.SIGTERM) + except (OSError, ProcessLookupError, ValueError): + return + + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + try: + proc.kill() + except OSError: + return + + +def _wait_for_process_exit(pid: int | None, timeout: float = 15.0) -> None: + """Wait briefly for another process to exit before updating files.""" + if pid is None or pid <= 0: + return + + if sys.platform == "win32": + try: + import ctypes + + kernel32 = ctypes.windll.kernel32 + synchronize = 0x00100000 + wait_timeout = 0x00000102 + handle = kernel32.OpenProcess(synchronize, False, pid) + if not handle: + return + try: + result = kernel32.WaitForSingleObject( + handle, + max(0, int(timeout * 1000)), + ) + if result == wait_timeout: + time.sleep(1.0) + finally: + kernel32.CloseHandle(handle) + except (AttributeError, ImportError, OSError): + time.sleep(min(timeout, 2.0)) + return + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + os.kill(pid, 0) + except OSError: + return + time.sleep(0.1) + + +def _run_update_worker_foreground(plan_path: Path) -> int: + """Run the update worker in a child process and wait for completion.""" + try: + proc = _spawn_update_worker(plan_path) + except OSError as exc: + raise click.ClickException( + "Failed to start update worker: " f"{exc}", + ) from exc + + try: + with proc: + if proc.stdout is not None: + for line in proc.stdout: + click.echo(line.rstrip()) + return 
proc.wait() + except KeyboardInterrupt: + click.echo("") + click.echo("[copaw] Update interrupted. Stopping installer...") + _terminate_update_worker(proc) + return 130 + + +def _run_update_worker_detached(plan_path: Path) -> None: + """Launch the update worker and return immediately.""" + try: + _spawn_update_worker(plan_path, capture_output=False) + except OSError as exc: + raise click.ClickException( + "Failed to start update worker: " f"{exc}", + ) from exc + + +def _load_worker_plan(plan_path: str | Path) -> dict[str, Any]: + """Load a persisted worker plan.""" + return json.loads(Path(plan_path).read_text(encoding="utf-8")) + + +def run_update_worker(plan_path: str | Path) -> int: + """Run the update worker and stream installer output.""" + path = Path(plan_path) + plan = _load_worker_plan(path) + command = [str(part) for part in plan["command"]] + + _wait_for_process_exit(plan.get("launcher_pid")) + + click.echo("") + click.echo( + "[copaw] Updating CoPaw " + f"{plan['current_version']} -> {plan['latest_version']}...", + ) + click.echo(f"[copaw] Using installer: {plan['installer_label']}") + + try: + with subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + **_subprocess_text_kwargs(), + bufsize=1, + ) as proc: + if proc.stdout is not None: + for line in proc.stdout: + click.echo(line.rstrip()) + return_code = proc.wait() + except FileNotFoundError as exc: + click.echo(f"[copaw] Update failed: {exc}") + return_code = 1 + finally: + try: + path.unlink(missing_ok=True) + except OSError: + pass + + if return_code == 0: + click.echo("[copaw] Update completed successfully.") + click.echo( + "[copaw] Please restart any running CoPaw service " + "to use the new version.", + ) + else: + click.echo(f"[copaw] Update failed with exit code {return_code}.") + click.echo( + "[copaw] Please fix the error above and run " + "`copaw update` again.", + ) + + return return_code + + +def _echo_install_summary(info: InstallInfo, latest_version: 
str) -> None: + """Print the update summary shown before launching the worker.""" + click.echo(f"Current version: {__version__}") + click.echo(f"Latest version: {latest_version}") + click.echo(f"Python: {info.python_executable}") + click.echo( + f"Environment: {info.environment_kind} " + f"({info.environment_root})", + ) + click.echo(f"Install path: {info.package_dir}") + click.echo(f"Installer: {info.installer}") + + +def _confirm_source_override(info: InstallInfo, yes: bool) -> bool: + """Confirm whether a non-PyPI installation should be overwritten.""" + if info.source_type == "pypi": + return True + + detail = f" ({info.source_url})" if info.source_url else "" + message = ( + "Detected a non-PyPI installation source: " + f"{info.source_type}{detail}. Updating will overwrite the current " + "installation with the PyPI release for this environment." + ) + + if yes: + click.echo( + f"Warning: {message} Proceeding because `--yes` was provided.", + ) + return True + + click.echo(f"Warning: {message}") + return click.confirm( + "Continue and replace the current installation with the PyPI " + "version?", + default=False, + ) + + +@click.command("update") +@click.option( + "--yes", + is_flag=True, + help="Do not prompt before starting the update", +) +@click.pass_context +def update_cmd(ctx: click.Context, yes: bool) -> None: + """Upgrade CoPaw in the current Python environment.""" + info = _detect_installation() + latest_version = _fetch_latest_version() + + _echo_install_summary(info, latest_version) + + version_check = _is_newer_version(latest_version, __version__) + if version_check is False: + click.echo("CoPaw is already up to date.") + return + + if not _confirm_source_override(info, yes): + click.echo("Cancelled.") + return + + if version_check is None: + if yes: + click.echo( + "Warning: unable to compare the current version" + f"({__version__}) with the latest version ({latest_version})" + " automatically. 
Proceeding because `--yes` was provided.", + ) + elif not click.confirm( + f"Unable to compare the current version ({__version__}) with the " + f"latest version ({latest_version}) automatically. Continue with " + "update anyway?", + default=False, + ): + click.echo("Cancelled.") + return + + running = _detect_running_service( + ctx.obj.get("host") if ctx.obj else None, + ctx.obj.get("port") if ctx.obj else None, + ) + if running.is_running: + if yes: + raise click.ClickException( + "Detected " + f"{_running_service_display(running)}. " + "Please stop it before running `copaw update`, or rerun " + "without `--yes` to confirm a forced `copaw shutdown`.", + ) + if not _confirm_force_shutdown(running): + click.echo("Cancelled.") + return + _run_shutdown_for_update(info, running) + running = _detect_running_service( + ctx.obj.get("host") if ctx.obj else None, + ctx.obj.get("port") if ctx.obj else None, + ) + if running.is_running: + raise click.ClickException( + "Detected " + f"{_running_service_display(running)} after `copaw shutdown`. 
" + "Please stop it manually before running `copaw update`.", + ) + + if not yes and not click.confirm( + f"Update CoPaw to {latest_version} in the current environment?", + default=True, + ): + click.echo("Cancelled.") + return + + command, installer_label = _build_upgrade_command(info, latest_version) + plan = { + "current_version": __version__, + "latest_version": latest_version, + "installer_label": installer_label, + "command": command, + "install": asdict(info), + "launcher_pid": os.getpid() if sys.platform == "win32" else None, + } + plan_path = _write_worker_plan(plan) + click.echo("") + click.echo("Starting CoPaw update...") + + if sys.platform == "win32": + _run_update_worker_detached(plan_path) + click.echo( + "On Windows, the update will continue after this command exits " + "to avoid locking `copaw.exe`.", + ) + click.echo("Keep this terminal open until the update completes.") + return + + return_code = _run_update_worker_foreground(plan_path) + + if return_code != 0: + ctx.exit(return_code) diff --git a/tests/unit/cli/test_cli_shutdown.py b/tests/unit/cli/test_cli_shutdown.py new file mode 100644 index 000000000..a99c36cfc --- /dev/null +++ b/tests/unit/cli/test_cli_shutdown.py @@ -0,0 +1,238 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from click.testing import CliRunner + +from copaw.cli.main import cli +from copaw.cli import shutdown_cmd as shutdown_cmd_module +from copaw.cli.shutdown_cmd import ( + _find_windows_wrapper_ancestor_pids, + _terminate_pid, +) + + +def test_shutdown_command_stops_backend_and_frontend(monkeypatch) -> None: + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._listening_pids_for_port", + lambda _port: {1001}, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_frontend_dev_pids", + lambda: {2002}, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_desktop_wrapper_pids", + set, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_windows_wrapper_ancestor_pids", + lambda _pids: set(), 
+ ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._terminate_pid", + lambda _pid: True, + ) + + result = CliRunner().invoke(cli, ["shutdown"]) + + assert result.exit_code == 0 + assert "1001" in result.output + assert "2002" in result.output + + +def test_shutdown_command_reports_failure(monkeypatch) -> None: + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._listening_pids_for_port", + lambda _port: {1001}, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_frontend_dev_pids", + set, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_desktop_wrapper_pids", + set, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_windows_wrapper_ancestor_pids", + lambda _pids: set(), + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._terminate_pid", + lambda _pid: False, + ) + + result = CliRunner().invoke(cli, ["shutdown"]) + + assert result.exit_code != 0 + assert "Failed to shutdown process" in result.output + + +def test_shutdown_command_reports_nothing_found(monkeypatch) -> None: + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._listening_pids_for_port", + lambda _port: set(), + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_frontend_dev_pids", + set, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_desktop_wrapper_pids", + set, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_windows_wrapper_ancestor_pids", + lambda _pids: set(), + ) + + result = CliRunner().invoke(cli, ["shutdown"]) + + assert result.exit_code != 0 + assert "No running CoPaw" in result.output + + +def test_shutdown_command_stops_windows_wrapper_ancestors(monkeypatch) -> None: + monkeypatch.setattr("copaw.cli.shutdown_cmd.sys.platform", "win32") + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._listening_pids_for_port", + lambda _port: {24692}, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_frontend_dev_pids", + set, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._find_desktop_wrapper_pids", + set, + ) + monkeypatch.setattr( + 
"copaw.cli.shutdown_cmd._find_windows_wrapper_ancestor_pids", + lambda _pids: {1052}, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._terminate_pid", + lambda _pid: True, + ) + + result = CliRunner().invoke(cli, ["shutdown"]) + + assert result.exit_code == 0 + assert "1052" in result.output + assert "24692" in result.output + + +def test_terminate_pid_force_kills_on_windows(monkeypatch) -> None: + calls: list[tuple[int, bool]] = [] + waits = iter([False, True]) + + monkeypatch.setattr("copaw.cli.shutdown_cmd.sys.platform", "win32") + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._pid_exists", + lambda _pid: True, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._terminate_process_tree_windows", + lambda pid, force=False: calls.append((pid, force)), + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._wait_for_pid_exit", + lambda _pid, _timeout, _interval: next(waits), + ) + + assert _terminate_pid(17944) is True + assert calls == [(17944, False), (17944, True)] + + +def test_terminate_pid_uses_windows_fallback(monkeypatch) -> None: + calls: list[tuple[int, bool]] = [] + waits = iter([False, False, True]) + fallback_calls: list[int] = [] + + monkeypatch.setattr("copaw.cli.shutdown_cmd.sys.platform", "win32") + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._pid_exists", + lambda _pid: True, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._terminate_process_tree_windows", + lambda pid, force=False: calls.append((pid, force)), + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._force_terminate_windows_process", + fallback_calls.append, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._wait_for_pid_exit", + lambda _pid, _timeout, _interval: next(waits), + ) + + assert _terminate_pid(17944) is True + assert calls == [(17944, False), (17944, True)] + assert fallback_calls == [17944] + + +def test_pid_exists_uses_windows_snapshot(monkeypatch) -> None: + monkeypatch.setattr("copaw.cli.shutdown_cmd.sys.platform", "win32") + monkeypatch.setattr( + 
"copaw.cli.shutdown_cmd._windows_process_snapshot", + lambda: {29104: (1, "copaw.exe", "copaw app")}, + ) + + assert ( + shutdown_cmd_module._pid_exists( # pylint: disable=protected-access + 29104, + ) + is True + ) + assert ( + shutdown_cmd_module._pid_exists( # pylint: disable=protected-access + 99999, + ) + is False + ) + + +def test_find_windows_wrapper_ancestor_pids(monkeypatch) -> None: + monkeypatch.setattr("copaw.cli.shutdown_cmd.sys.platform", "win32") + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._windows_process_snapshot", + lambda: { + 24692: (1052, "python.exe", "python -m uvicorn copaw.app"), + 1052: (900, "copaw.exe", ""), + 900: (4, "powershell.exe", "powershell"), + }, + ) + + assert _find_windows_wrapper_ancestor_pids({24692}) == {1052} + + +def test_terminate_pid_force_kills_on_unix(monkeypatch) -> None: + calls: list[tuple[int, object]] = [] + waits = iter([False, True]) + + monkeypatch.setattr("copaw.cli.shutdown_cmd.sys.platform", "darwin") + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._pid_exists", + lambda _pid: True, + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._signal_process_tree_unix", + lambda pid, sig: calls.append((pid, sig)), + ) + monkeypatch.setattr( + "copaw.cli.shutdown_cmd._wait_for_pid_exit", + lambda _pid, _timeout, _interval: next(waits), + ) + + assert _terminate_pid(4242) is True + assert calls == [ + ( + 4242, + shutdown_cmd_module._SIGTERM, # pylint: disable=protected-access + ), + ( + 4242, + shutdown_cmd_module._SIGKILL, # pylint: disable=protected-access + ), + ] diff --git a/tests/unit/cli/test_cli_update.py b/tests/unit/cli/test_cli_update.py new file mode 100644 index 000000000..181719624 --- /dev/null +++ b/tests/unit/cli/test_cli_update.py @@ -0,0 +1,910 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import sys +import json +from pathlib import Path + +import httpx +import pytest +from click.testing import CliRunner + +from copaw.__version__ import __version__ +from copaw.cli.main 
import cli +from copaw.cli.update_cmd import ( + InstallInfo, + RunningServiceInfo, + _detect_running_service, + _detect_installation, + _is_newer_version, + _probe_service, + _detect_source_type, + _run_update_worker_detached, + _run_update_worker_foreground, + run_update_worker, +) + + +def _install_info( + *, + source_type: str = "pypi", + installer: str = "pip", +) -> InstallInfo: + return InstallInfo( + package_dir="/tmp/site-packages/copaw", + python_executable="/tmp/venv/bin/python", + environment_root="/tmp/venv", + environment_kind="virtualenv", + installer=installer, + source_type=source_type, + source_url=None, + ) + + +@pytest.mark.parametrize( + ("latest", "current", "expected"), + [ + ("1.2.4", "1.2.3", True), + ("1.2.3", "1.2.3", False), + ("1.2.2", "1.2.3", False), + ("1.2.3", "1.2.3rc1", True), + ("1.2.3rc1", "1.2.3", False), + ("main", "main", False), + ("main", "feature", None), + ("main", "1.2.3", None), + ], +) +def test_is_newer_version( + latest: str, + current: str, + expected: bool | None, +) -> None: + assert _is_newer_version(latest, current) is expected + + +@pytest.mark.parametrize( + ("direct_url", "expected"), + [ + (None, ("pypi", None)), + ( + { + "url": "file:///Users/test/CoPaw", + "dir_info": {"editable": True}, + }, + ("editable", "file:///Users/test/CoPaw"), + ), + ( + { + "url": "https://github.com/agentscope-ai/CoPaw.git", + "vcs_info": {"vcs": "git", "commit_id": "abc123"}, + }, + ("vcs", "https://github.com/agentscope-ai/CoPaw.git"), + ), + ( + {"url": "file:///tmp/copaw.whl"}, + ("local", "file:///tmp/copaw.whl"), + ), + ( + {"url": "https://example.com/copaw.whl"}, + ("direct-url", "https://example.com/copaw.whl"), + ), + ], +) +def test_detect_source_type( + direct_url: dict[str, object] | None, + expected: tuple[str, str | None], +) -> None: + assert _detect_source_type(direct_url) == expected + + +@pytest.mark.parametrize( + ( + "installer_text", + "direct_url_text", + "expected_installer", + "expected_source_type", + 
"expected_source_url", + ), + [ + (None, None, "pip", "pypi", None), + ( + "uv\n", + json.dumps( + { + "url": "file:///Users/test/CoPaw", + "dir_info": {"editable": True}, + }, + ), + "uv", + "editable", + "file:///Users/test/CoPaw", + ), + ], +) +def test_detect_installation( + monkeypatch, + installer_text: str | None, + direct_url_text: str | None, + expected_installer: str, + expected_source_type: str, + expected_source_url: str | None, +) -> None: + from copaw.cli import update_cmd as update_cmd_module + + class _FakeDistribution: + def read_text(self, name: str) -> str | None: + mapping = { + "INSTALLER": installer_text, + "direct_url.json": direct_url_text, + } + return mapping.get(name) + + expected_python_executable = "/tmp/test-venv/bin/python" + expected_environment_root = str(Path("/tmp/test-venv").resolve()) + expected_package_dir = str( + Path(update_cmd_module.__file__).resolve().parent.parent, + ) + + monkeypatch.setattr( + update_cmd_module.metadata, + "distribution", + lambda name: _FakeDistribution(), + ) + monkeypatch.setattr( + update_cmd_module.sys, + "executable", + "/tmp/test-venv/bin/python", + ) + monkeypatch.setattr(update_cmd_module.sys, "prefix", "/tmp/test-venv") + monkeypatch.setattr(update_cmd_module.sys, "base_prefix", "/usr/local") + + result = _detect_installation() + + assert result.installer == expected_installer + assert result.source_type == expected_source_type + assert result.source_url == expected_source_url + assert result.python_executable == expected_python_executable + assert result.environment_root == expected_environment_root + assert result.environment_kind == "virtualenv" + assert result.package_dir == expected_package_dir + + +def test_update_reports_up_to_date(monkeypatch) -> None: + from copaw.cli import update_cmd as update_cmd_module + + install_info = _install_info() + + def _detect_installation() -> InstallInfo: + return install_info + + def _fetch_latest_version() -> str: + return __version__ + + 
monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + _detect_installation, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + _fetch_latest_version, + ) + + result = CliRunner().invoke(cli, ["update", "--yes"]) + + assert result.exit_code == 0 + assert "CoPaw is already up to date." in result.output + + +def test_probe_service_ignores_proxy_env(monkeypatch) -> None: + captured: dict[str, object] = {} + + class _Response: + def raise_for_status(self) -> None: + return None + + def json(self) -> dict[str, str]: + return {"version": "1.2.3"} + + def _fake_get(url: str, **kwargs): + captured["url"] = url + captured.update(kwargs) + return _Response() + + monkeypatch.setattr("copaw.cli.update_cmd.httpx.get", _fake_get) + + result = _probe_service("http://127.0.0.1:8088") + + assert result.is_running is True + assert result.base_url == "http://127.0.0.1:8088" + assert result.version == "1.2.3" + assert captured["trust_env"] is False + + +def test_probe_service_returns_not_running_on_http_error(monkeypatch) -> None: + def _fake_get(_url: str, **_kwargs): + raise httpx.HTTPError("bad gateway") + + monkeypatch.setattr("copaw.cli.update_cmd.httpx.get", _fake_get) + + result = _probe_service("http://127.0.0.1:8088") + + assert result.is_running is False + + +def test_detect_running_service_handles_wildcard_host(monkeypatch) -> None: + from copaw.cli import update_cmd as update_cmd_module + + monkeypatch.setattr(update_cmd_module, "read_last_api", lambda: None) + monkeypatch.setattr( + update_cmd_module, + "_probe_service", + lambda base_url: RunningServiceInfo( + is_running=base_url == "http://127.0.0.1:9090", + base_url=base_url if base_url == "http://127.0.0.1:9090" else None, + ), + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service_from_processes", + lambda _hosts: RunningServiceInfo(is_running=False), + ) + + result = _detect_running_service("0.0.0.0", 9090) + + assert result.is_running is True + assert 
result.base_url == "http://127.0.0.1:9090" + + +def test_detect_running_service_falls_back_to_process_ports( + monkeypatch, +) -> None: + from copaw.cli import update_cmd as update_cmd_module + + monkeypatch.setattr(update_cmd_module, "read_last_api", lambda: None) + monkeypatch.setattr( + update_cmd_module, + "_probe_service", + lambda _base_url: RunningServiceInfo(is_running=False), + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service_from_processes", + lambda hosts: RunningServiceInfo( + is_running=True, + base_url=f"http://{hosts[0]}:8088", + ), + ) + + result = _detect_running_service(None, None) + + assert result.is_running is True + assert result.base_url == "http://127.0.0.1:8088" + + +def test_update_blocks_running_service(monkeypatch) -> None: + from copaw.cli import update_cmd as update_cmd_module + + install_info = _install_info() + + def _detect_installation() -> InstallInfo: + return install_info + + def _fetch_latest_version() -> str: + return "9.9.9" + + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + _detect_installation, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + _fetch_latest_version, + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service", + lambda host, port: RunningServiceInfo( + is_running=True, + base_url="http://127.0.0.1:8088", + version=__version__, + ), + ) + + result = CliRunner().invoke(cli, ["update", "--yes"]) + + assert result.exit_code != 0 + assert "Please stop it before running `copaw update`" in result.output + assert ( + "without `--yes` to confirm a forced `copaw shutdown`" in result.output + ) + + +def test_update_can_cancel_forced_shutdown(monkeypatch) -> None: + from copaw.cli import update_cmd as update_cmd_module + + install_info = _install_info() + + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + lambda: install_info, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + lambda: 
"9.9.9", + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service", + lambda host, port: RunningServiceInfo( + is_running=True, + base_url="http://127.0.0.1:8088", + version=__version__, + ), + ) + + result = CliRunner().invoke(cli, ["update"], input="n\n") + + assert result.exit_code == 0 + assert ( + "forcibly terminate the current CoPaw backend/frontend " + "processes" in result.output + ) + assert ( + "Run `copaw shutdown` now and continue with the update?" + in result.output + ) + assert "Cancelled." in result.output + + +def test_update_can_force_shutdown_running_service( + monkeypatch, + tmp_path: Path, +) -> None: + from copaw.cli import update_cmd as update_cmd_module + + install_info = _install_info() + spawned: dict[str, object] = {} + service_checks = iter( + [ + RunningServiceInfo( + is_running=True, + base_url="http://127.0.0.1:8088", + version=__version__, + ), + RunningServiceInfo(is_running=False), + ], + ) + + monkeypatch.setattr(update_cmd_module, "WORKING_DIR", tmp_path) + monkeypatch.setattr(update_cmd_module.sys, "platform", "darwin") + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + lambda: install_info, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + lambda: "9.9.9", + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service", + lambda host, port: next(service_checks), + ) + + def _fake_shutdown( + command, + stdout, + stderr, + text, + encoding, + errors, + check, + ): + del stdout, stderr, text, encoding, errors, check + assert command == [ + "/tmp/venv/bin/python", + "-m", + "copaw", + "--port", + "8088", + "shutdown", + ] + + class _Result: + returncode = 0 + stdout = "Stopped CoPaw processes: 1234\n" + + return _Result() + + def _fake_run_worker(plan_path: Path) -> int: + spawned["path"] = plan_path + spawned["plan"] = json.loads(plan_path.read_text(encoding="utf-8")) + return 0 + + monkeypatch.setattr(update_cmd_module.subprocess, "run", 
_fake_shutdown) + monkeypatch.setattr( + update_cmd_module, + "_run_update_worker_foreground", + _fake_run_worker, + ) + + result = CliRunner().invoke(cli, ["update"], input="y\ny\n") + + assert result.exit_code == 0 + assert "Running `copaw shutdown` before updating..." in result.output + assert "Stopped CoPaw processes: 1234" in result.output + assert "Starting CoPaw update..." in result.output + assert isinstance(spawned["path"], Path) + + +def test_update_can_cancel_non_pypi_override(monkeypatch) -> None: + from copaw.cli import update_cmd as update_cmd_module + + install_info = _install_info(source_type="editable") + + def _detect_installation() -> InstallInfo: + return install_info + + def _fetch_latest_version() -> str: + return "9.9.9" + + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + _detect_installation, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + _fetch_latest_version, + ) + + result = CliRunner().invoke(cli, ["update"], input="n\n") + + assert result.exit_code == 0 + assert "Detected a non-PyPI installation source: editable" in result.output + assert "Continue and replace the current installation" in result.output + assert "Cancelled." 
in result.output + + +def test_update_can_override_non_pypi_install_with_yes( + monkeypatch, + tmp_path: Path, +) -> None: + from copaw.cli import update_cmd as update_cmd_module + + spawned: dict[str, object] = {} + install_info = _install_info(source_type="editable") + + def _detect_installation() -> InstallInfo: + return install_info + + def _fetch_latest_version() -> str: + return "9.9.9" + + def _detect_running_service( + host: str | None, + port: int | None, + ) -> RunningServiceInfo: + del host, port + return RunningServiceInfo(is_running=False) + + monkeypatch.setattr(update_cmd_module, "WORKING_DIR", tmp_path) + monkeypatch.setattr(update_cmd_module.sys, "platform", "darwin") + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + _detect_installation, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + _fetch_latest_version, + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service", + _detect_running_service, + ) + + def _fake_run_worker(plan_path: Path) -> int: + spawned["path"] = plan_path + spawned["plan"] = json.loads(plan_path.read_text(encoding="utf-8")) + return 0 + + monkeypatch.setattr( + update_cmd_module, + "_run_update_worker_foreground", + _fake_run_worker, + ) + + result = CliRunner().invoke(cli, ["update", "--yes"]) + + assert result.exit_code == 0 + assert "Proceeding because `--yes` was provided." in result.output + assert "Starting CoPaw update..." 
in result.output + assert isinstance(spawned["path"], Path) + + +def test_update_spawns_worker(monkeypatch, tmp_path: Path) -> None: + from copaw.cli import update_cmd as update_cmd_module + + spawned: dict[str, object] = {} + install_info = _install_info() + + def _detect_installation() -> InstallInfo: + return install_info + + def _fetch_latest_version() -> str: + return "9.9.9" + + def _detect_running_service( + host: str | None, + port: int | None, + ) -> RunningServiceInfo: + del host, port + return RunningServiceInfo(is_running=False) + + monkeypatch.setattr(update_cmd_module, "WORKING_DIR", tmp_path) + monkeypatch.setattr(update_cmd_module.sys, "platform", "darwin") + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + _detect_installation, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + _fetch_latest_version, + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service", + _detect_running_service, + ) + + def _fake_run_worker(plan_path: Path) -> int: + spawned["path"] = plan_path + spawned["plan"] = json.loads(plan_path.read_text(encoding="utf-8")) + return 0 + + monkeypatch.setattr( + update_cmd_module, + "_run_update_worker_foreground", + _fake_run_worker, + ) + + result = CliRunner().invoke(cli, ["update", "--yes"]) + + assert result.exit_code == 0 + assert "Starting CoPaw update..." 
in result.output + assert isinstance(spawned["path"], Path) + plan = spawned["plan"] + assert plan["latest_version"] == "9.9.9" # type: ignore [index] + assert plan["installer_label"] == "pip" # type: ignore [index] + assert plan["command"][:5] == [ # type: ignore [index] + "/tmp/venv/bin/python", + "-m", + "pip", + "install", + "--upgrade", + ] + + +def test_update_prompts_when_version_is_not_comparable( + monkeypatch, +) -> None: + from copaw.cli import update_cmd as update_cmd_module + + install_info = _install_info() + + def _detect_installation() -> InstallInfo: + return install_info + + def _fetch_latest_version() -> str: + return "main" + + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + _detect_installation, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + _fetch_latest_version, + ) + + result = CliRunner().invoke(cli, ["update"], input="n\n") + + assert result.exit_code == 0 + assert "Unable to compare the current version" in result.output + assert "Cancelled." 
in result.output + + +def test_update_can_continue_when_version_is_not_comparable( + monkeypatch, + tmp_path: Path, +) -> None: + from copaw.cli import update_cmd as update_cmd_module + + spawned: dict[str, object] = {} + install_info = _install_info() + + def _detect_installation() -> InstallInfo: + return install_info + + def _fetch_latest_version() -> str: + return "main" + + def _detect_running_service( + host: str | None, + port: int | None, + ) -> RunningServiceInfo: + del host, port + return RunningServiceInfo(is_running=False) + + monkeypatch.setattr(update_cmd_module, "WORKING_DIR", tmp_path) + monkeypatch.setattr(update_cmd_module.sys, "platform", "darwin") + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + _detect_installation, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + _fetch_latest_version, + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service", + _detect_running_service, + ) + + def _fake_run_worker(plan_path: Path) -> int: + spawned["path"] = plan_path + spawned["plan"] = json.loads(plan_path.read_text(encoding="utf-8")) + return 0 + + monkeypatch.setattr( + update_cmd_module, + "_run_update_worker_foreground", + _fake_run_worker, + ) + + result = CliRunner().invoke(cli, ["update"], input="y\ny\n") + + assert result.exit_code == 0 + assert isinstance(spawned["path"], Path) + assert "Starting CoPaw update..." 
in result.output + + +def test_update_returns_worker_exit_code(monkeypatch, tmp_path: Path) -> None: + from copaw.cli import update_cmd as update_cmd_module + + install_info = _install_info() + + monkeypatch.setattr(update_cmd_module, "WORKING_DIR", tmp_path) + monkeypatch.setattr(update_cmd_module.sys, "platform", "darwin") + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + lambda: install_info, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + lambda: "9.9.9", + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service", + lambda host, port: RunningServiceInfo(is_running=False), + ) + monkeypatch.setattr( + update_cmd_module, + "_run_update_worker_foreground", + lambda plan_path: 2, + ) + + result = CliRunner().invoke(cli, ["update", "--yes"]) + + assert result.exit_code == 2 + assert "Starting CoPaw update..." in result.output + + +def test_update_detaches_worker_on_windows( + monkeypatch, + tmp_path: Path, +) -> None: + from copaw.cli import update_cmd as update_cmd_module + + install_info = _install_info() + spawned: dict[str, object] = {} + + monkeypatch.setattr(update_cmd_module, "WORKING_DIR", tmp_path) + monkeypatch.setattr(update_cmd_module.sys, "platform", "win32") + monkeypatch.setattr( + update_cmd_module, + "_detect_installation", + lambda: install_info, + ) + monkeypatch.setattr( + update_cmd_module, + "_fetch_latest_version", + lambda: "9.9.9", + ) + monkeypatch.setattr( + update_cmd_module, + "_detect_running_service", + lambda host, port: RunningServiceInfo(is_running=False), + ) + + def _fake_run_detached(plan_path: Path) -> None: + spawned["path"] = plan_path + spawned["plan"] = json.loads(plan_path.read_text(encoding="utf-8")) + + monkeypatch.setattr( + update_cmd_module, + "_run_update_worker_detached", + _fake_run_detached, + ) + monkeypatch.setattr( + update_cmd_module, + "_run_update_worker_foreground", + lambda plan_path: pytest.fail("foreground worker should not run"), + ) + + 
result = CliRunner().invoke(cli, ["update", "--yes"]) + + assert result.exit_code == 0 + assert "Starting CoPaw update..." in result.output + assert "continue after this command exits" in result.output + assert isinstance(spawned["path"], Path) + assert spawned["plan"]["launcher_pid"] is not None # type: ignore[index] + + +def test_update_worker_waits_for_launcher_exit( + monkeypatch, + tmp_path: Path, + capsys, +) -> None: + plan_path = tmp_path / "update-plan-wait.json" + _write_plan( + plan_path, + command=[ + sys.executable, + "-u", + "-c", + "print('installer: done', flush=True)", + ], + ) + plan = json.loads(plan_path.read_text(encoding="utf-8")) + plan["launcher_pid"] = 4321 + plan_path.write_text(json.dumps(plan), encoding="utf-8") + + waited: list[tuple[int | None, float]] = [] + monkeypatch.setattr( + "copaw.cli.update_cmd._wait_for_process_exit", + lambda pid, timeout=15.0: waited.append((pid, timeout)), + ) + + return_code = run_update_worker(plan_path) + captured = capsys.readouterr() + + assert return_code == 0 + assert waited == [(4321, 15.0)] + assert "installer: done" in captured.out + + +def test_run_update_worker_detached_spawns_without_capture( + monkeypatch, + tmp_path: Path, +) -> None: + captured: dict[str, object] = {} + + def _fake_spawn(plan_path: Path, *, capture_output: bool = True): + captured["path"] = plan_path + captured["capture_output"] = capture_output + return object() + + monkeypatch.setattr( + "copaw.cli.update_cmd._spawn_update_worker", + _fake_spawn, + ) + + plan_path = tmp_path / "update-plan.json" + plan_path.write_text("{}", encoding="utf-8") + + _run_update_worker_detached(plan_path) + + assert captured == { + "path": plan_path, + "capture_output": False, + } + + +def _write_plan( + plan_path: Path, + *, + command: list[str], + current_version: str = "1.0.0", + latest_version: str = "9.9.9", +) -> None: + plan = { + "current_version": current_version, + "latest_version": latest_version, + "installer_label": 
"integration-test", + "command": command, + "install": {}, + } + plan_path.write_text(json.dumps(plan), encoding="utf-8") + + +def test_update_worker_foreground_streams_output_and_cleans_plan( + tmp_path: Path, + capsys, +) -> None: + """Test that the foreground worker streams child output and cleans up.""" + plan_path = tmp_path / "update-plan.json" + _write_plan( + plan_path, + command=[ + sys.executable, + "-u", + "-c", + ( + "print('installer: preparing', flush=True);" + "print('installer: done', flush=True)" + ), + ], + ) + + return_code = _run_update_worker_foreground(plan_path) + captured = capsys.readouterr() + + assert return_code == 0 + assert "[copaw] Updating CoPaw 1.0.0 -> 9.9.9..." in captured.out + assert "[copaw] Using installer: integration-test" in captured.out + assert "installer: preparing" in captured.out + assert "installer: done" in captured.out + assert "[copaw] Update completed successfully." in captured.out + assert not plan_path.exists() + + +def test_update_worker_foreground_propagates_failure_exit_code( + tmp_path: Path, + capsys, +) -> None: + """Test that the foreground worker returns the installer exit code.""" + plan_path = tmp_path / "update-plan-fail.json" + _write_plan( + plan_path, + command=[ + sys.executable, + "-u", + "-c", + ( + "import sys;" + "print('installer: failing', flush=True);" + "sys.exit(7)" + ), + ], + ) + + return_code = _run_update_worker_foreground(plan_path) + captured = capsys.readouterr() + + assert return_code == 7 + assert "installer: failing" in captured.out + assert "[copaw] Update failed with exit code 7." 
in captured.out + assert not plan_path.exists() diff --git a/tests/test_cli_version.py b/tests/unit/cli/test_cli_version.py similarity index 100% rename from tests/test_cli_version.py rename to tests/unit/cli/test_cli_version.py From c9baa183b8743780da4e6d246c6254dc8842ced9 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Mon, 16 Mar 2026 15:58:34 +0800 Subject: [PATCH 07/68] feat(channels): normalize prefixed discord ids for cron sends Co-authored-by: Runlin Lei --- src/copaw/app/channels/discord_/channel.py | 24 +++++++++------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/copaw/app/channels/discord_/channel.py b/src/copaw/app/channels/discord_/channel.py index dbbd533df..6400ff2fd 100644 --- a/src/copaw/app/channels/discord_/channel.py +++ b/src/copaw/app/channels/discord_/channel.py @@ -283,26 +283,22 @@ def from_config( require_mention=config.require_mention, ) - async def _resolve_target(self, to_handle, meta): + async def _resolve_target(self, to_handle, _meta): """Resolve a Discord Messageable from meta or to_handle.""" - meta = meta or {} - if not meta.get("channel_id") and not meta.get("user_id"): - meta.update(self._route_from_handle(to_handle)) - channel_id = meta.get("channel_id") - user_id = meta.get("user_id") + route = self._route_from_handle(to_handle) + channel_id = route.get("channel_id") + user_id = route.get("user_id") if channel_id: - ch = self._client.get_channel(int(channel_id)) + cid = int(channel_id) + ch = self._client.get_channel(cid) if ch is None: - ch = await self._client.fetch_channel( - int(channel_id), - ) + ch = await self._client.fetch_channel(cid) return ch if user_id: - user = self._client.get_user(int(user_id)) + uid = int(user_id) + user = self._client.get_user(uid) if user is None: - user = await self._client.fetch_user( - int(user_id), - ) + user = await self._client.fetch_user(uid) return user.dm_channel or await user.create_dm() return None From 
7b8d1e07cc4494da2e91a981f8c13262e1a7d4d7 Mon Sep 17 00:00:00 2001 From: Atletico1999 <53865833+Atletico1999@users.noreply.github.com> Date: Mon, 16 Mar 2026 15:59:34 +0800 Subject: [PATCH 08/68] feat(discord): prevent cross-channel message merging due to sender_id debounce key (#1002) Co-authored-by: Antigravity --- src/copaw/app/channels/discord_/channel.py | 56 +++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/src/copaw/app/channels/discord_/channel.py b/src/copaw/app/channels/discord_/channel.py index 6400ff2fd..bafe41eec 100644 --- a/src/copaw/app/channels/discord_/channel.py +++ b/src/copaw/app/channels/discord_/channel.py @@ -9,7 +9,7 @@ import tempfile from pathlib import Path from urllib.parse import urlparse -from typing import Any, Optional +from typing import Any, Dict, List, Optional import aiohttp from agentscope_runtime.engine.schemas.agent_schemas import ( @@ -515,6 +515,60 @@ async def stop(self) -> None: if self._client: await self._client.close() + # ------------------------------------------------------------------ + # Debounce: use per-channel keys so concurrent messages from the same + # user in different channels/threads are NOT merged together. + # ------------------------------------------------------------------ + + def get_debounce_key(self, payload: Any) -> str: + """Return a debounce key scoped to the Discord channel or DM. + + The base class falls back to ``sender_id``, which causes + ``ChannelManager._drain_same_key()`` to incorrectly merge + messages when the same user sends to multiple channels at the + same time. This override uses ``resolve_session_id`` so each + channel/thread gets its own isolated debounce bucket. 
+ """ + if isinstance(payload, dict): + meta = payload.get("meta") or {} + sender_id = payload.get("sender_id") or "" + return self.resolve_session_id(sender_id, meta) + return getattr(payload, "session_id", "") or "" + + def merge_native_items(self, items: List[Any]) -> Any: + """Merge native payloads while preserving Discord metadata. + + Extends the base implementation to also carry over + Discord-specific meta keys (``channel_id``, ``message_id``, + ``guild_id``, ``is_dm``, ``is_group``) from the first item. + """ + if not items: + return None + first = items[0] if isinstance(items[0], dict) else {} + merged_parts: List[Any] = [] + merged_meta: Dict[str, Any] = dict(first.get("meta") or {}) + for it in items: + p = it if isinstance(it, dict) else {} + merged_parts.extend(p.get("content_parts") or []) + m = p.get("meta") or {} + for k in ( + "reply_future", + "reply_loop", + "incoming_message", + "conversation_id", + "message_id", + ): + if k in m: + merged_meta[k] = m[k] + return { + "channel_id": first.get("channel_id") or self.channel, + "sender_id": first.get("sender_id") or "", + "content_parts": merged_parts, + "meta": merged_meta, + } + + # ------------------------------------------------------------------ + def resolve_session_id( self, sender_id: str, From 699c6bfeabe051ba989263afada448bc534bb6be Mon Sep 17 00:00:00 2001 From: Yuexiang XIE Date: Mon, 16 Mar 2026 16:01:33 +0800 Subject: [PATCH 09/68] feat(timezone): add user timezone configuration (#1535) --- console/src/api/index.ts | 4 + console/src/api/modules/userTimezone.ts | 15 ++++ console/src/constants/timezone.ts | 21 +++++ console/src/locales/en.json | 6 ++ console/src/locales/ja.json | 6 ++ console/src/locales/ru.json | 6 ++ console/src/locales/zh.json | 6 ++ .../Config/components/ReactAgentCard.tsx | 28 +++++++ console/src/pages/Agent/Config/index.tsx | 6 ++ .../src/pages/Agent/Config/useAgentConfig.tsx | 30 ++++++- .../Control/CronJobs/components/constants.ts | 22 +----- 
console/src/pages/Control/CronJobs/index.tsx | 21 ++++- pyproject.toml | 1 + src/copaw/agents/md_files/en/BOOTSTRAP.md | 2 +- src/copaw/agents/md_files/en/PROFILE.md | 1 - src/copaw/agents/md_files/ru/BOOTSTRAP.md | 2 +- src/copaw/agents/md_files/ru/PROFILE.md | 1 - src/copaw/agents/md_files/zh/BOOTSTRAP.md | 2 +- src/copaw/agents/md_files/zh/PROFILE.md | 1 - src/copaw/agents/react_agent.py | 4 +- src/copaw/agents/skills/cron/SKILL.md | 17 ++-- src/copaw/agents/tools/__init__.py | 3 +- src/copaw/agents/tools/get_current_time.py | 79 ++++++++++++++++--- src/copaw/app/crons/heartbeat.py | 18 ++++- src/copaw/app/crons/manager.py | 6 +- src/copaw/app/routers/config.py | 30 +++++++ src/copaw/app/runner/query_error_dump.py | 6 +- src/copaw/app/runner/utils.py | 22 +++++- src/copaw/cli/cron_cmd.py | 13 ++- src/copaw/config/config.py | 11 +++ src/copaw/config/timezone.py | 47 +++++++++++ website/public/docs/cli.en.md | 12 +-- website/public/docs/cli.zh.md | 12 +-- website/public/docs/config.en.md | 20 ++++- website/public/docs/config.zh.md | 20 ++++- website/public/docs/console.en.md | 2 +- website/public/docs/console.zh.md | 2 +- 37 files changed, 425 insertions(+), 80 deletions(-) create mode 100644 console/src/api/modules/userTimezone.ts create mode 100644 console/src/constants/timezone.ts create mode 100644 src/copaw/config/timezone.py diff --git a/console/src/api/index.ts b/console/src/api/index.ts index fc3515551..d137b9423 100644 --- a/console/src/api/index.ts +++ b/console/src/api/index.ts @@ -20,6 +20,7 @@ import { mcpApi } from "./modules/mcp"; import { tokenUsageApi } from "./modules/tokenUsage"; import { toolsApi } from "./modules/tools"; import { securityApi } from "./modules/security"; +import { userTimezoneApi } from "./modules/userTimezone"; export const api = { // Root @@ -71,6 +72,9 @@ export const api = { // Security ...securityApi, + + // User Timezone + ...userTimezoneApi, }; export default api; diff --git a/console/src/api/modules/userTimezone.ts 
b/console/src/api/modules/userTimezone.ts new file mode 100644 index 000000000..dc56d055f --- /dev/null +++ b/console/src/api/modules/userTimezone.ts @@ -0,0 +1,15 @@ +import { request } from "../request"; + +export interface UserTimezoneConfig { + timezone: string; +} + +export const userTimezoneApi = { + getUserTimezone: () => request("/config/user-timezone"), + + updateUserTimezone: (timezone: string) => + request("/config/user-timezone", { + method: "PUT", + body: JSON.stringify({ timezone }), + }), +}; diff --git a/console/src/constants/timezone.ts b/console/src/constants/timezone.ts new file mode 100644 index 000000000..d88d5096f --- /dev/null +++ b/console/src/constants/timezone.ts @@ -0,0 +1,21 @@ +export const TIMEZONE_OPTIONS = [ + { value: "America/Los_Angeles", label: "America/Los_Angeles (UTC-8)" }, + { value: "America/Denver", label: "America/Denver (UTC-7)" }, + { value: "America/Chicago", label: "America/Chicago (UTC-6)" }, + { value: "America/New_York", label: "America/New_York (UTC-5)" }, + { value: "America/Toronto", label: "America/Toronto (UTC-5)" }, + { value: "UTC", label: "UTC" }, + { value: "Europe/London", label: "Europe/London (UTC+0)" }, + { value: "Europe/Paris", label: "Europe/Paris (UTC+1)" }, + { value: "Europe/Berlin", label: "Europe/Berlin (UTC+1)" }, + { value: "Europe/Moscow", label: "Europe/Moscow (UTC+3)" }, + { value: "Asia/Dubai", label: "Asia/Dubai (UTC+4)" }, + { value: "Asia/Shanghai", label: "Asia/Shanghai (UTC+8)" }, + { value: "Asia/Hong_Kong", label: "Asia/Hong_Kong (UTC+8)" }, + { value: "Asia/Singapore", label: "Asia/Singapore (UTC+8)" }, + { value: "Asia/Tokyo", label: "Asia/Tokyo (UTC+9)" }, + { value: "Asia/Seoul", label: "Asia/Seoul (UTC+9)" }, + { value: "Australia/Sydney", label: "Australia/Sydney (UTC+10)" }, + { value: "Australia/Melbourne", label: "Australia/Melbourne (UTC+10)" }, + { value: "Pacific/Auckland", label: "Pacific/Auckland (UTC+12)" }, +]; diff --git a/console/src/locales/en.json 
b/console/src/locales/en.json index 0fffc7006..94d1e4cc0 100644 --- a/console/src/locales/en.json +++ b/console/src/locales/en.json @@ -540,6 +540,12 @@ "languageSaveFailed": "Failed to update language", "saveSuccess": "Configuration saved successfully", "saveFailed": "Failed to save configuration", + "timezone": "User Timezone", + "timezoneTooltip": "Used for scheduled tasks, time display, and agent context. Defaults to your system timezone.", + "timezoneRequired": "Please select a timezone", + "selectTimezone": "Select timezone", + "timezoneSaveSuccess": "Timezone updated successfully", + "timezoneSaveFailed": "Failed to update timezone", "loadFailed": "Failed to load configuration" }, "tools": { diff --git a/console/src/locales/ja.json b/console/src/locales/ja.json index 46cbfbbd7..72b64bf1d 100644 --- a/console/src/locales/ja.json +++ b/console/src/locales/ja.json @@ -528,6 +528,12 @@ "languageSaveSuccess": "言語を更新しました", "languageSaveSuccessWithFiles": "言語を更新し、{{count}} 個のMDファイルをコピーしました", "languageSaveFailed": "言語の更新に失敗しました", + "timezone": "ユーザータイムゾーン", + "timezoneTooltip": "スケジュールタスク、時刻表示、エージェントコンテキストに使用されます。デフォルトはシステムのタイムゾーンです。", + "timezoneRequired": "タイムゾーンを選択してください", + "selectTimezone": "タイムゾーンを選択", + "timezoneSaveSuccess": "タイムゾーンを更新しました", + "timezoneSaveFailed": "タイムゾーンの更新に失敗しました", "saveSuccess": "設定を保存しました", "saveFailed": "設定の保存に失敗しました", "loadFailed": "設定の読み込みに失敗しました" diff --git a/console/src/locales/ru.json b/console/src/locales/ru.json index dbafd670f..0c252d5fb 100644 --- a/console/src/locales/ru.json +++ b/console/src/locales/ru.json @@ -533,6 +533,12 @@ "languageSaveSuccess": "Язык успешно обновлён", "languageSaveSuccessWithFiles": "Язык обновлён, скопировано файлов: {{count}}", "languageSaveFailed": "Не удалось обновить язык", + "timezone": "Часовой пояс пользователя", + "timezoneTooltip": "Используется для запланированных задач, отображения времени и контекста агента. 
По умолчанию — системный часовой пояс.", + "timezoneRequired": "Пожалуйста, выберите часовой пояс", + "selectTimezone": "Выберите часовой пояс", + "timezoneSaveSuccess": "Часовой пояс обновлён", + "timezoneSaveFailed": "Не удалось обновить часовой пояс", "saveSuccess": "Конфигурация успешно сохранена", "saveFailed": "Не удалось сохранить конфигурацию", "loadFailed": "Не удалось загрузить конфигурацию" diff --git a/console/src/locales/zh.json b/console/src/locales/zh.json index c10b3e48a..7eea3584c 100644 --- a/console/src/locales/zh.json +++ b/console/src/locales/zh.json @@ -538,6 +538,12 @@ "languageSaveSuccess": "语言更新成功", "languageSaveSuccessWithFiles": "语言已更新,已复制 {{count}} 个 MD 文件", "languageSaveFailed": "语言更新失败", + "timezone": "用户时区", + "timezoneTooltip": "用于定时任务、时间显示和 Agent 上下文,默认使用系统时区。", + "timezoneRequired": "请选择时区", + "selectTimezone": "选择时区", + "timezoneSaveSuccess": "时区更新成功", + "timezoneSaveFailed": "时区更新失败", "saveSuccess": "配置保存成功", "saveFailed": "配置保存失败", "loadFailed": "配置加载失败" diff --git a/console/src/pages/Agent/Config/components/ReactAgentCard.tsx b/console/src/pages/Agent/Config/components/ReactAgentCard.tsx index 0dbd6d474..4e650c0d8 100644 --- a/console/src/pages/Agent/Config/components/ReactAgentCard.tsx +++ b/console/src/pages/Agent/Config/components/ReactAgentCard.tsx @@ -1,5 +1,6 @@ import { Form, InputNumber, Select, Card } from "@agentscope-ai/design"; import { useTranslation } from "react-i18next"; +import { TIMEZONE_OPTIONS } from "../../../../constants/timezone"; import styles from "../index.module.less"; const LANGUAGE_OPTIONS = [ @@ -12,12 +13,18 @@ interface ReactAgentCardProps { language: string; savingLang: boolean; onLanguageChange: (value: string) => void; + timezone: string; + savingTimezone: boolean; + onTimezoneChange: (value: string) => void; } export function ReactAgentCard({ language, savingLang, onLanguageChange, + timezone, + savingTimezone, + onTimezoneChange, }: ReactAgentCardProps) { const { t } = useTranslation(); 
return ( @@ -36,6 +43,27 @@ export function ReactAgentCard({ /> + + + + + } + > + {agents.map((agent) => ( + + + {agent.name} + + } + > +
+
+
+ +
+
+
+ {agent.name} + {agent.id === selectedAgent && ( + + )} +
+ {agent.description && ( +
+ {agent.description} +
+ )} +
+
+
ID: {agent.id}
+
+
+ ))} + + + ); +} diff --git a/console/src/layouts/Header.tsx b/console/src/layouts/Header.tsx index 94975d310..e0ea80e81 100644 --- a/console/src/layouts/Header.tsx +++ b/console/src/layouts/Header.tsx @@ -1,5 +1,6 @@ import { Layout, Space } from "antd"; import LanguageSwitcher from "../components/LanguageSwitcher"; +import AgentSelector from "../components/AgentSelector"; import { useTranslation } from "react-i18next"; import { FileTextOutlined, @@ -64,6 +65,7 @@ export default function Header({ selectedKey }: HeaderProps) { {t(keyToLabel[selectedKey] || "nav.chat")} + + handleDelete(record.id)} + disabled={record.id === "default"} + okText={t("common.confirm")} + cancelText={t("common.cancel")} + > + + + + ), + }, + ]; + + return ( +
+ + + {t("agent.management")} + + } + extra={ + + } + > + t("common.total", { count: total }), + }} + /> + + + setModalVisible(false)} + width={600} + okText={t("common.save")} + cancelText={t("common.cancel")} + > +
+ {editingAgent && ( + + + + )} + + + + + + + + + + +
+ + ); +} diff --git a/console/src/pages/Settings/Models/index.tsx b/console/src/pages/Settings/Models/index.tsx index 90b507c0c..8eda32bda 100644 --- a/console/src/pages/Settings/Models/index.tsx +++ b/console/src/pages/Settings/Models/index.tsx @@ -6,7 +6,6 @@ import { PageHeader, LoadingState, ProviderCard, - ModelsSection, CustomProviderModal, } from "./components"; import { useTranslation } from "react-i18next"; @@ -64,18 +63,7 @@ function ModelsPage() { ) : ( <> - {/* ---- LLM Section (top) ---- */} - - - - {/* ---- Providers Section (below) ---- */} + {/* ---- Providers Section ---- */}
([]); @@ -9,6 +10,7 @@ export function useProviders() { ); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); + const { selectedAgent } = useAgentStore(); const fetchAll = useCallback(async (showLoading = true) => { if (showLoading) { @@ -41,7 +43,7 @@ export function useProviders() { useEffect(() => { fetchAll(); - }, [fetchAll]); + }, [fetchAll, selectedAgent]); return { providers, diff --git a/console/src/stores/agentStore.ts b/console/src/stores/agentStore.ts new file mode 100644 index 000000000..bfb800e61 --- /dev/null +++ b/console/src/stores/agentStore.ts @@ -0,0 +1,46 @@ +import { create } from "zustand"; +import { persist } from "zustand/middleware"; +import type { AgentSummary } from "../api/types/agents"; + +interface AgentStore { + selectedAgent: string; + agents: AgentSummary[]; + setSelectedAgent: (agentId: string) => void; + setAgents: (agents: AgentSummary[]) => void; + addAgent: (agent: AgentSummary) => void; + removeAgent: (agentId: string) => void; + updateAgent: (agentId: string, updates: Partial) => void; +} + +export const useAgentStore = create()( + persist( + (set) => ({ + selectedAgent: "default", + agents: [], + + setSelectedAgent: (agentId) => set({ selectedAgent: agentId }), + + setAgents: (agents) => set({ agents }), + + addAgent: (agent) => + set((state) => ({ + agents: [...state.agents, agent], + })), + + removeAgent: (agentId) => + set((state) => ({ + agents: state.agents.filter((a) => a.id !== agentId), + })), + + updateAgent: (agentId, updates) => + set((state) => ({ + agents: state.agents.map((a) => + a.id === agentId ? 
{ ...a, ...updates } : a, + ), + })), + }), + { + name: "copaw-agent-storage", + }, + ), +); diff --git a/pyproject.toml b/pyproject.toml index 6933ca12c..8e124d0ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "paho-mqtt>=2.0.0", "wecom-aibot-sdk @ https://agentscope.oss-cn-zhangjiakou.aliyuncs.com/pre_whl/wecom_aibot_sdk-1.0.0-py3-none-any.whl", "matrix-nio>=0.24.0", + "shortuuid>=1.0.0", "google-genai>=1.67.0", "tzdata>=2024.1", ] diff --git a/src/copaw/agents/command_handler.py b/src/copaw/agents/command_handler.py index a039cccfb..0aa0402f7 100644 --- a/src/copaw/agents/command_handler.py +++ b/src/copaw/agents/command_handler.py @@ -9,8 +9,6 @@ from agentscope.agent._react_agent import _MemoryMark from agentscope.message import Msg, TextBlock -from copaw.config import load_config - if TYPE_CHECKING: from .memory import MemoryManager from reme.memory.file_based import ReMeInMemoryMemory @@ -61,6 +59,7 @@ def __init__( memory: "ReMeInMemoryMemory", memory_manager: "MemoryManager | None" = None, enable_memory_manager: bool = True, + max_input_length: int = 128 * 1024, ): """Initialize command handler. 
@@ -69,9 +68,12 @@ def __init__( memory: Agent's ReMeInMemoryMemory instance memory_manager: Optional memory manager instance enable_memory_manager: Whether memory manager is enabled + max_input_length: Maximum input length in tokens for context + window (default: 128K = 131072) """ self.agent_name = agent_name self.memory = memory + self.max_input_length = max_input_length self.memory_manager = memory_manager self._enable_memory_manager = enable_memory_manager @@ -199,10 +201,8 @@ async def _process_history( _args: str = "", ) -> Msg: """Process /history command.""" - config = load_config() - max_input_length = config.agents.running.max_input_length history_str = await self.memory.get_history_str( - max_input_length=max_input_length, + max_input_length=self.max_input_length, ) return await self._make_system_msg(history_str) diff --git a/src/copaw/agents/hooks/memory_compaction.py b/src/copaw/agents/hooks/memory_compaction.py index 6f145cb36..f91b555f2 100644 --- a/src/copaw/agents/hooks/memory_compaction.py +++ b/src/copaw/agents/hooks/memory_compaction.py @@ -10,7 +10,6 @@ from agentscope.agent._react_agent import _MemoryMark, ReActAgent -from copaw.config import load_config from copaw.constant import MEMORY_COMPACT_KEEP_RECENT from ..utils import ( check_valid_messages, @@ -32,13 +31,28 @@ class MemoryCompactionHook: messages while summarizing older conversation history. """ - def __init__(self, memory_manager: "MemoryManager"): + def __init__( + self, + memory_manager: "MemoryManager", + memory_compact_threshold: int | None = None, + memory_compact_reserve: int | None = None, + enable_tool_result_compact: bool = False, + tool_result_compact_keep_n: int = 5, + ): """Initialize memory compaction hook. 
Args: memory_manager: Memory manager instance for compaction + memory_compact_threshold: Token threshold for compaction + memory_compact_reserve: Reserve tokens for recent messages + enable_tool_result_compact: Enable tool result compaction + tool_result_compact_keep_n: Number of tool results to keep """ self.memory_manager = memory_manager + self.memory_compact_threshold = memory_compact_threshold + self.memory_compact_reserve = memory_compact_reserve + self.enable_tool_result_compact = enable_tool_result_compact + self.tool_result_compact_keep_n = tool_result_compact_keep_n async def __call__( self, @@ -72,10 +86,14 @@ async def __call__( system_prompt + compressed_summary, ) - config = load_config() - memory_compact_threshold = ( - config.agents.running.memory_compact_threshold - ) + # memory_compact_threshold must be provided + if self.memory_compact_threshold is None: + raise ValueError( + "memory_compact_threshold is required but not provided " + "to MemoryCompactionHook", + ) + memory_compact_threshold = self.memory_compact_threshold + left_compact_threshold = memory_compact_threshold - str_token_count if left_compact_threshold <= 0: @@ -91,19 +109,20 @@ async def __call__( messages = await memory.get_memory(prepend_summary=False) - enable_tool_result_compact = ( - config.agents.running.enable_tool_result_compact - ) - tool_result_compact_keep_n = ( - config.agents.running.tool_result_compact_keep_n - ) + # Use configured values + enable_tool_result_compact = self.enable_tool_result_compact + tool_result_compact_keep_n = self.tool_result_compact_keep_n if enable_tool_result_compact and tool_result_compact_keep_n > 0: compact_msgs = messages[:-tool_result_compact_keep_n] await self.memory_manager.compact_tool_result(compact_msgs) - memory_compact_reserve = ( - config.agents.running.memory_compact_reserve - ) + # memory_compact_reserve must be provided + if self.memory_compact_reserve is None: + raise ValueError( + "memory_compact_reserve is required but not 
provided " + "to MemoryCompactionHook", + ) + memory_compact_reserve = self.memory_compact_reserve ( messages_to_compact, _, diff --git a/src/copaw/agents/memory/memory_manager.py b/src/copaw/agents/memory/memory_manager.py index 6e30c2ff1..37fc33402 100644 --- a/src/copaw/agents/memory/memory_manager.py +++ b/src/copaw/agents/memory/memory_manager.py @@ -19,7 +19,6 @@ from copaw.agents.model_factory import create_model_and_formatter from copaw.agents.tools import read_file, write_file, edit_file from copaw.agents.utils import _get_token_counter -from copaw.config import load_config logger = logging.getLogger(__name__) @@ -36,6 +35,9 @@ class ReMeLight: # type: ignore """Placeholder when reme is not available.""" + async def start(self) -> None: + """No-op start when reme is unavailable.""" + class MemoryManager(ReMeLight): """Memory manager that extends ReMeLight for CoPaw agents. @@ -47,11 +49,25 @@ class MemoryManager(ReMeLight): - Configurable vector search and full-text search backends """ - def __init__(self, working_dir: str): + def __init__( + self, + working_dir: str, + max_input_length: int = 128 * 1024, + memory_compact_ratio: float = 0.7, + memory_reserve_ratio: float = 0.1, + language: str = "zh", + ): """Initialize MemoryManager with ReMeLight configuration. Args: working_dir: Working directory path for memory storage + max_input_length: Maximum input length in tokens for context + window (default: 128K = 131072) + memory_compact_ratio: Ratio for memory compaction + (default: 0.7) + memory_reserve_ratio: Ratio for memory reserve + (default: 0.1) + language: Language for memory operations (default: "zh") Environment Variables: EMBEDDING_API_KEY: API key for embedding service @@ -72,8 +88,17 @@ def __init__(self, working_dir: str): Vector search is enabled only when both EMBEDDING_API_KEY and EMBEDDING_MODEL_NAME are configured. 
""" + # Store configuration parameters + self._max_input_length = max_input_length + self._memory_compact_ratio = memory_compact_ratio + self._memory_reserve_ratio = memory_reserve_ratio + self._language = language + if not _REME_AVAILABLE: - raise RuntimeError("reme package not installed.") + logger.warning( + "reme package not available, memory features will be limited", + ) + return embedding_api_key = self._safe_str("EMBEDDING_API_KEY", "") embedding_base_url = self._safe_str( @@ -234,19 +259,14 @@ async def compact_memory( """ self.prepare_model_formatter() - config = load_config() - max_input_length = config.agents.running.max_input_length - memory_compact_ratio = config.agents.running.memory_compact_ratio - language = config.agents.language - return await super().compact_memory( messages=messages, as_llm=self.chat_model, as_llm_formatter=self.formatter, token_counter=self.token_counter, - language=language, - max_input_length=max_input_length, - compact_ratio=memory_compact_ratio, + language=self._language, + max_input_length=self._max_input_length, + compact_ratio=self._memory_compact_ratio, previous_summary=previous_summary, ) @@ -263,20 +283,15 @@ async def summary_memory(self, messages: list[Msg], **_kwargs) -> str: Returns: str: Comprehensive summary of the messages """ - config = load_config() - max_input_length = config.agents.running.max_input_length - memory_compact_ratio = config.agents.running.memory_compact_ratio - language = config.agents.language - return await super().summary_memory( messages=messages, as_llm=self.chat_model, as_llm_formatter=self.formatter, token_counter=self.token_counter, toolkit=self.summary_toolkit, - language=language, - max_input_length=max_input_length, - compact_ratio=memory_compact_ratio, + language=self._language, + max_input_length=self._max_input_length, + compact_ratio=self._memory_compact_ratio, ) def get_in_memory_memory(self, **_kwargs): diff --git a/src/copaw/agents/model_factory.py 
b/src/copaw/agents/model_factory.py index 4edbe42d1..3feb4a822 100644 --- a/src/copaw/agents/model_factory.py +++ b/src/copaw/agents/model_factory.py @@ -11,7 +11,7 @@ import logging -from typing import Sequence, Tuple, Type, Any +from typing import Any, Optional, Sequence, Tuple, Type from functools import wraps from agentscope.formatter import FormatterBase, OpenAIChatFormatter @@ -282,15 +282,17 @@ def _strip_top_level_message_name( return messages -def create_model_and_formatter() -> Tuple[ChatModelBase, FormatterBase]: +def create_model_and_formatter( + agent_id: Optional[str] = None, +) -> Tuple[ChatModelBase, FormatterBase]: """Factory method to create model and formatter instances. This method handles both local and remote models, selecting the appropriate chat model class and formatter based on configuration. Args: - llm_cfg: Resolved model configuration. If None, will call - get_active_llm_config() to fetch the active configuration. + agent_id: Optional agent ID to load agent-specific model config. + If None, tries to get from context, then falls back to global. 
Returns: Tuple of (model_instance, formatter_instance) @@ -298,14 +300,56 @@ def create_model_and_formatter() -> Tuple[ChatModelBase, FormatterBase]: Example: >>> model, formatter = create_model_and_formatter() """ - # Fetch config if not provided - model = ProviderManager.get_active_chat_model() + from ..app.agent_context import get_current_agent_id + from ..config.config import load_agent_config + + # Determine agent_id (parameter > context > None) + if agent_id is None: + try: + agent_id = get_current_agent_id() + except Exception: + pass + + # Try to get agent-specific model first + model_slot = None + if agent_id: + try: + agent_config = load_agent_config(agent_id) + model_slot = agent_config.active_model + except Exception: + pass + + # Create chat model from agent-specific or global config + if model_slot and model_slot.provider_id and model_slot.model: + # Use agent-specific model + manager = ProviderManager.get_instance() + provider = manager.get_provider(model_slot.provider_id) + if provider is None: + raise ValueError( + f"Provider '{model_slot.provider_id}' not found.", + ) + if provider.is_local: + from agentscope.model import create_local_chat_model + + model = create_local_chat_model( + model_id=model_slot.model, + stream=True, + generate_kwargs={"max_tokens": None}, + ) + else: + model = provider.get_chat_model_instance(model_slot.model) + provider_id = model_slot.provider_id + else: + # Fallback to global active model + model = ProviderManager.get_active_chat_model() + provider_id = ( + ProviderManager.get_instance().get_active_model().provider_id + ) # Create the formatter based on the real model class formatter = _create_formatter_instance(model.__class__) # Wrap with retry logic for transient LLM API errors - provider_id = ProviderManager.get_instance().get_active_model().provider_id wrapped_model = TokenRecordingModelWrapper(provider_id, model) wrapped_model = RetryChatModel(wrapped_model) diff --git a/src/copaw/agents/prompt.py 
b/src/copaw/agents/prompt.py index fa17aa674..df049ad7c 100644 --- a/src/copaw/agents/prompt.py +++ b/src/copaw/agents/prompt.py @@ -128,16 +128,21 @@ def build(self) -> str: return final_prompt -def build_system_prompt_from_working_dir() -> str: +def build_system_prompt_from_working_dir( + working_dir: Path | None = None, + enabled_files: list[str] | None = None, + agent_id: str | None = None, +) -> str: """ Build system prompt by reading markdown files from working directory. This function constructs the system prompt by loading markdown files from - WORKING_DIR (~/.copaw by default). These files define the agent's behavior, - personality, and operational guidelines. + the specified working directory (workspace_dir for multi-agent setup). + These files define the agent's behavior, personality, and operational guidelines. - The files to load are determined by the agents.system_prompt_files configuration. - If not configured, falls back to default files: + The files to load are determined by the enabled_files parameter or + agents.system_prompt_files configuration. If not configured, falls back to + default files: - AGENTS.md - Detailed workflows, rules, and guidelines - SOUL.md - Core identity and behavioral principles - PROFILE.md - Agent identity and user profile @@ -145,6 +150,12 @@ def build_system_prompt_from_working_dir() -> str: All files are optional. If a file doesn't exist or can't be read, it will be skipped. If no files can be loaded, returns the default prompt. + Args: + working_dir: Directory to read markdown files from (if None, uses + global WORKING_DIR for backward compatibility) + enabled_files: List of filenames to load (if None, uses config or defaults) + agent_id: Agent identifier to include in system prompt (optional) + Returns: str: Constructed system prompt from markdown files. If no files exist, returns the default prompt. 
@@ -156,19 +167,44 @@ def build_system_prompt_from_working_dir() -> str: from ..constant import WORKING_DIR from ..config import load_config - # Load enabled files from config - config = load_config() - enabled_files = ( - config.agents.system_prompt_files - if config.agents.system_prompt_files is not None - else None - ) + # Use provided working_dir or fallback to global WORKING_DIR + if working_dir is None: + working_dir = Path(WORKING_DIR) + + # Load enabled files from parameter or config + if enabled_files is None: + # Use agent-specific config if agent_id provided + if agent_id: + from ..config.config import load_agent_config + + try: + agent_config = load_agent_config(agent_id) + enabled_files = agent_config.system_prompt_files + except (ValueError, FileNotFoundError): + # Agent not found in config, fallback to global config + config = load_config() + enabled_files = config.agents.system_prompt_files + else: + # Fallback to global config for backward compatibility + config = load_config() + enabled_files = config.agents.system_prompt_files builder = PromptBuilder( - working_dir=Path(WORKING_DIR), + working_dir=working_dir, enabled_files=enabled_files, ) - return builder.build() + prompt = builder.build() + + # Add agent identity information at the beginning of the prompt + if agent_id and agent_id != "default": + identity_header = ( + f"# Agent Identity\n\n" + f"Your agent id is `{agent_id}`. 
" + f"This is your unique identifier in the multi-agent system.\n\n" + ) + prompt = identity_header + prompt + + return prompt def build_bootstrap_guidance( diff --git a/src/copaw/agents/react_agent.py b/src/copaw/agents/react_agent.py index 93fb95efc..c1be58ea3 100644 --- a/src/copaw/agents/react_agent.py +++ b/src/copaw/agents/react_agent.py @@ -7,6 +7,7 @@ import asyncio import logging import os +from pathlib import Path from typing import Any, List, Literal, Optional, Type from agentscope.agent import ReActAgent @@ -43,7 +44,7 @@ ) from .utils import process_file_and_media_blocks_in_message from ..agents.memory import MemoryManager -from ..config import load_config +from ..config.config import load_agent_config from ..constant import ( MEMORY_COMPACT_RATIO, WORKING_DIR, @@ -85,6 +86,12 @@ def __init__( max_iters: int = 50, max_input_length: int = 128 * 1024, # 128K = 131072 tokens namesake_strategy: NamesakeStrategy = "skip", + memory_compact_threshold: int | None = None, + memory_compact_reserve: int | None = None, + enable_tool_result_compact: bool = False, + tool_result_compact_keep_n: int = 5, + language: str = "zh", + workspace_dir: Path | None = None, ): """Initialize CoPawAgent. @@ -102,17 +109,37 @@ def __init__( namesake_strategy: Strategy to handle namesake tool functions. 
Options: "override", "skip", "raise", "rename" (default: "skip") + memory_compact_threshold: Token threshold for memory + compaction (optional, uses default ratio if not set) + memory_compact_reserve: Reserve tokens for recent messages + enable_tool_result_compact: Enable tool result compaction + tool_result_compact_keep_n: Number of tool results to keep + language: Language setting for agent (default: "zh") + workspace_dir: Workspace directory for reading prompt files + (if None, uses global WORKING_DIR) """ self._env_context = env_context self._request_context = dict(request_context or {}) self._max_input_length = max_input_length self._mcp_clients = mcp_clients or [] self._namesake_strategy = namesake_strategy - - # Memory compaction threshold: configurable ratio of max_input_length - self._memory_compact_threshold = int( - max_input_length * MEMORY_COMPACT_RATIO, + self._language = language + self._workspace_dir = workspace_dir + + # Memory compaction settings: use provided or calculate defaults + self._memory_compact_threshold = ( + memory_compact_threshold + if memory_compact_threshold is not None + else int(max_input_length * MEMORY_COMPACT_RATIO) + ) + # Calculate reserve as 40% of max_input_length if not provided + self._memory_compact_reserve = ( + memory_compact_reserve + if memory_compact_reserve is not None + else int(max_input_length * 0.4) ) + self._enable_tool_result_compact = enable_tool_result_compact + self._tool_result_compact_keep_n = tool_result_compact_keep_n # Initialize toolkit with built-in tools toolkit = self._create_toolkit(namesake_strategy=namesake_strategy) @@ -150,6 +177,7 @@ def __init__( memory=self.memory, memory_manager=self.memory_manager, enable_memory_manager=self._enable_memory_manager, + max_input_length=max_input_length, ) # Register hooks @@ -171,14 +199,29 @@ def _create_toolkit( """ toolkit = Toolkit() - # Load config to check which tools are enabled - config = load_config() + # Check which tools are enabled (tools 
config is agent-specific) enabled_tools = {} - if hasattr(config, "tools") and hasattr(config.tools, "builtin_tools"): - enabled_tools = { - name: tool_config.enabled - for name, tool_config in config.tools.builtin_tools.items() - } + try: + # Get agent_id from request_context + agent_id = ( + self._request_context.get("agent_id", "default") + if self._request_context + else "default" + ) + agent_config = load_agent_config(agent_id) + if hasattr(agent_config, "tools") and hasattr( + agent_config.tools, + "builtin_tools", + ): + enabled_tools = { + name: tool.enabled + for name, tool in agent_config.tools.builtin_tools.items() + } + except Exception as e: + logger.warning( + f"Failed to load agent tools config: {e}, " + "all tools will be disabled", + ) # Map of tool functions tool_functions = { @@ -210,16 +253,18 @@ def _create_toolkit( return toolkit def _register_skills(self, toolkit: Toolkit) -> None: - """Load and register skills from working directory. + """Load and register skills from workspace directory. 
Args: toolkit: Toolkit to register skills to """ + workspace_dir = self._workspace_dir or WORKING_DIR + # Check skills initialization - ensure_skills_initialized() + ensure_skills_initialized(workspace_dir) - working_skills_dir = get_working_skills_dir() - available_skills = list_available_skills() + working_skills_dir = get_working_skills_dir(workspace_dir) + available_skills = list_available_skills(workspace_dir) for skill_name in available_skills: skill_dir = working_skills_dir / skill_name @@ -240,7 +285,17 @@ def _build_sys_prompt(self) -> str: Returns: Complete system prompt string """ - sys_prompt = build_system_prompt_from_working_dir() + # Get agent_id from request_context + agent_id = ( + self._request_context.get("agent_id") + if self._request_context + else None + ) + + sys_prompt = build_system_prompt_from_working_dir( + working_dir=self._workspace_dir, + agent_id=agent_id, + ) if self._env_context is not None: sys_prompt = sys_prompt + "\n\n" + self._env_context return sys_prompt @@ -283,10 +338,13 @@ def _setup_memory_manager( def _register_hooks(self) -> None: """Register pre-reasoning and pre-acting hooks.""" # Bootstrap hook - checks BOOTSTRAP.md on first interaction - config = load_config() + # Use workspace_dir if available, else fallback to WORKING_DIR + working_dir = ( + self._workspace_dir if self._workspace_dir else WORKING_DIR + ) bootstrap_hook = BootstrapHook( - working_dir=WORKING_DIR, - language=config.agents.language, + working_dir=working_dir, + language=self._language, ) self.register_instance_hook( hook_type="pre_reasoning", @@ -299,6 +357,10 @@ def _register_hooks(self) -> None: if self._enable_memory_manager and self.memory_manager is not None: memory_compact_hook = MemoryCompactionHook( memory_manager=self.memory_manager, + memory_compact_threshold=self._memory_compact_threshold, + memory_compact_reserve=self._memory_compact_reserve, + enable_tool_result_compact=self._enable_tool_result_compact, + 
tool_result_compact_keep_n=self._tool_result_compact_keep_n, ) self.register_instance_hook( hook_type="pre_reasoning", @@ -513,6 +575,11 @@ async def reply( Returns: Response message """ + # Set workspace_dir in context for tool functions + from ..config.context import set_current_workspace_dir + + set_current_workspace_dir(self._workspace_dir) + # Process file and media blocks in messages if msg is not None: await process_file_and_media_blocks_in_message(msg) diff --git a/src/copaw/agents/skills/cron/SKILL.md b/src/copaw/agents/skills/cron/SKILL.md index 628303e0c..ca9995ab1 100644 --- a/src/copaw/agents/skills/cron/SKILL.md +++ b/src/copaw/agents/skills/cron/SKILL.md @@ -11,9 +11,12 @@ metadata: { "copaw": { "emoji": "⏰" } } ## 常用命令 ```bash -# 列出所有任务 +# 列出所有任务(默认操作 default agent) copaw cron list +# 为特定 agent 列出任务 +copaw cron list --agent-id abc123 + # 查看任务详情 copaw cron get @@ -31,6 +34,8 @@ copaw cron resume copaw cron run ``` +**注意**:所有命令都支持 `--agent-id` 参数,默认为 `default`。如果需要操作特定 agent 的任务,请指定对应的 agent ID。 + ## 创建任务 支持两种任务类型: @@ -40,7 +45,7 @@ copaw cron run ### 快速创建 ```bash -# 每天 9:00 发送文本消息 +# 每天 9:00 发送文本消息(默认 agent) copaw cron create \ --type text \ --name "每日早安" \ @@ -50,8 +55,9 @@ copaw cron create \ --target-session "CHANGEME" \ --text "早上好!" 
-# 每 2 小时向 Agent 提问 +# 为特定 agent 创建任务 copaw cron create \ + --agent-id abc123 \ --type agent \ --name "检查待办" \ --cron "0 */2 * * *" \ @@ -72,6 +78,10 @@ copaw cron create \ - `--target-session`:会话标识 - `--text`:消息内容(text 类型)或提问内容(agent 类型) +### 可选参数 + +- `--agent-id`:指定 agent ID(默认:default)。用于多 agent 场景。 + ### 从 JSON 创建(复杂配置) ```bash @@ -93,4 +103,5 @@ copaw cron create -f job_spec.json - 缺少参数时,询问用户补充后再创建 - 暂停/删除/恢复前,用 `copaw cron list` 查找 job_id - 排查问题时,用 `copaw cron state ` 查看状态 -- 给用户的命令要完整、可直接复制执行 \ No newline at end of file +- 给用户的命令要完整、可直接复制执行 +- 记得指定 `--agent-id` 参数 diff --git a/src/copaw/agents/skills_hub.py b/src/copaw/agents/skills_hub.py index 39e15888f..031e8c828 100644 --- a/src/copaw/agents/skills_hub.py +++ b/src/copaw/agents/skills_hub.py @@ -11,6 +11,7 @@ import io import zipfile from dataclasses import dataclass +from pathlib import Path from typing import Any from urllib.parse import urlencode, urlparse, unquote from urllib.error import HTTPError, URLError @@ -1352,6 +1353,7 @@ def search_hub_skills(query: str, limit: int = 20) -> list[HubSkillResult]: # pylint: disable-next=too-many-branches def install_skill_from_hub( *, + workspace_dir: Path, bundle_url: str, version: str = "", enable: bool = True, @@ -1409,7 +1411,8 @@ def install_skill_from_hub( # Sanitize: "Excel / XLSX" etc. 
must not be used as dir name name = _sanitize_skill_dir_name(name) - created = SkillService.create_skill( + skill_service = SkillService(workspace_dir) + created = skill_service.create_skill( name=name, content=content, overwrite=overwrite, @@ -1425,7 +1428,7 @@ def install_skill_from_hub( enabled = False if enable: - enabled = SkillService.enable_skill(name, force=True) + enabled = skill_service.enable_skill(name, force=True) if not enabled: logger.warning("Skill '%s' imported but enable failed", name) diff --git a/src/copaw/agents/skills_manager.py b/src/copaw/agents/skills_manager.py index 76ba47d81..db8b9b67b 100644 --- a/src/copaw/agents/skills_manager.py +++ b/src/copaw/agents/skills_manager.py @@ -11,7 +11,6 @@ from pydantic import BaseModel import frontmatter -from ..constant import ACTIVE_SKILLS_DIR, CUSTOMIZED_SKILLS_DIR logger = logging.getLogger(__name__) @@ -143,23 +142,23 @@ def get_builtin_skills_dir() -> Path: return Path(__file__).parent / "skills" -def get_customized_skills_dir() -> Path: - """Get the path to customized skills directory in working_dir.""" - return CUSTOMIZED_SKILLS_DIR +def get_customized_skills_dir(workspace_dir: Path) -> Path: + """Get the path to customized skills directory in workspace_dir.""" + return workspace_dir / "customized_skills" -def get_active_skills_dir() -> Path: - """Get the path to active skills directory in working_dir.""" - return ACTIVE_SKILLS_DIR +def get_active_skills_dir(workspace_dir: Path) -> Path: + """Get the path to active skills directory in workspace_dir.""" + return workspace_dir / "active_skills" -def get_working_skills_dir() -> Path: +def get_working_skills_dir(workspace_dir: Path) -> Path: """ - Get the path to skills directory in working_dir. + Get the path to skills directory in workspace_dir. Deprecated: Use get_active_skills_dir() instead. 
""" - return get_active_skills_dir() + return get_active_skills_dir(workspace_dir) def _build_directory_tree(directory: Path) -> dict[str, Any]: @@ -218,6 +217,7 @@ def _collect_skills_from_dir(directory: Path) -> dict[str, Path]: def sync_skills_to_working_dir( + workspace_dir: Path, skill_names: list[str] | None = None, force: bool = False, ) -> tuple[int, int]: @@ -225,6 +225,7 @@ def sync_skills_to_working_dir( Sync skills from builtin and customized to active_skills directory. Args: + workspace_dir: Workspace directory path. skill_names: List of skill names to sync. If None, sync all skills. force: If True, overwrite existing skills in active_skills. @@ -232,8 +233,8 @@ def sync_skills_to_working_dir( Tuple of (synced_count, skipped_count). """ builtin_skills = get_builtin_skills_dir() - customized_skills = get_customized_skills_dir() - active_skills = get_active_skills_dir() + customized_skills = get_customized_skills_dir(workspace_dir) + active_skills = get_active_skills_dir(workspace_dir) # Ensure active skills directory exists active_skills.mkdir(parents=True, exist_ok=True) @@ -296,19 +297,21 @@ def sync_skills_to_working_dir( def sync_skills_from_active_to_customized( + workspace_dir: Path, skill_names: list[str] | None = None, ) -> tuple[int, int]: """ Sync skills from active_skills to customized_skills directory. Args: + workspace_dir: Workspace directory path. skill_names: List of skill names to sync. If None, sync all skills. Returns: Tuple of (synced_count, skipped_count). 
""" - active_skills = get_active_skills_dir() - customized_skills = get_customized_skills_dir() + active_skills = get_active_skills_dir(workspace_dir) + customized_skills = get_customized_skills_dir(workspace_dir) builtin_skills = get_builtin_skills_dir() customized_skills.mkdir(parents=True, exist_ok=True) @@ -357,14 +360,17 @@ def sync_skills_from_active_to_customized( return synced_count, skipped_count -def list_available_skills() -> list[str]: +def list_available_skills(workspace_dir: Path) -> list[str]: """ List all available skills in active_skills directory. + Args: + workspace_dir: Workspace directory path. + Returns: List of skill names. """ - active_skills = get_active_skills_dir() + active_skills = get_active_skills_dir(workspace_dir) if not active_skills.exists(): return [] @@ -376,16 +382,19 @@ def list_available_skills() -> list[str]: ] -def ensure_skills_initialized() -> None: +def ensure_skills_initialized(workspace_dir: Path) -> None: """ Check if skills are initialized in active_skills directory. + Args: + workspace_dir: Workspace directory path. + Logs a warning if no skills are found, or info about loaded skills. Skills should be configured via `copaw init` or `copaw skills config`. """ - active_skills = get_active_skills_dir() - available = list_available_skills() + active_skills = get_active_skills_dir(workspace_dir) + available = list_available_skills(workspace_dir) if not active_skills.exists() or not available: logger.warning( @@ -532,11 +541,20 @@ class SkillService: """ Service for managing skills. - Manages skills across builtin, customized, and active directories. + Manages skills across builtin, customized, and active directories + for a specific workspace. """ - @staticmethod - def list_all_skills() -> list[SkillInfo]: + def __init__(self, workspace_dir: Path): + """ + Initialize SkillService for a specific workspace. + + Args: + workspace_dir: Path to the workspace directory. 
+ """ + self.workspace_dir = workspace_dir + + def list_all_skills(self) -> list[SkillInfo]: """ List all skills from builtin and customized directories. @@ -544,7 +562,9 @@ def list_all_skills() -> list[SkillInfo]: List of SkillInfo with name, content, source, and path. """ try: - synced, _ = sync_skills_from_active_to_customized() + synced, _ = sync_skills_from_active_to_customized( + self.workspace_dir, + ) if synced > 0: logger.debug( "Synced %d skill(s) from active_skills", @@ -564,23 +584,28 @@ def list_all_skills() -> list[SkillInfo]: _read_skills_from_dir(get_builtin_skills_dir(), "builtin"), ) skills.extend( - _read_skills_from_dir(get_customized_skills_dir(), "customized"), + _read_skills_from_dir( + get_customized_skills_dir(self.workspace_dir), + "customized", + ), ) return _dedupe_skills_by_name(skills) - @staticmethod - def list_available_skills() -> list[SkillInfo]: + def list_available_skills(self) -> list[SkillInfo]: """ List all available (active) skills in active_skills directory. Returns: List of SkillInfo with name, content, source, and path. """ - return _read_skills_from_dir(get_active_skills_dir(), "active") + return _read_skills_from_dir( + get_active_skills_dir(self.workspace_dir), + "active", + ) - @staticmethod def create_skill( + self, name: str, content: str, overwrite: bool = False, @@ -657,7 +682,7 @@ def create_skill( ) return False - customized_dir = get_customized_skills_dir() + customized_dir = get_customized_skills_dir(self.workspace_dir) customized_dir.mkdir(parents=True, exist_ok=True) skill_dir = customized_dir / name @@ -719,8 +744,7 @@ def create_skill( ) return False - @staticmethod - def disable_skill(name: str) -> bool: + def disable_skill(self, name: str) -> bool: """ Disable a skill by removing it from active_skills directory. @@ -730,7 +754,7 @@ def disable_skill(name: str) -> bool: Returns: True if skill was disabled successfully, False otherwise. 
""" - active_dir = get_active_skills_dir() + active_dir = get_active_skills_dir(self.workspace_dir) skill_dir = active_dir / name if not skill_dir.exists(): @@ -752,8 +776,7 @@ def disable_skill(name: str) -> bool: ) return False - @staticmethod - def enable_skill(name: str, force: bool = False) -> bool: + def enable_skill(self, name: str, force: bool = False) -> bool: """ Enable a skill by syncing it to active_skills directory. @@ -764,13 +787,16 @@ def enable_skill(name: str, force: bool = False) -> bool: Returns: True if skill was enabled successfully, False otherwise. """ - sync_skills_to_working_dir(skill_names=[name], force=force) + sync_skills_to_working_dir( + self.workspace_dir, + skill_names=[name], + force=force, + ) # Check if skill was actually synced - active_dir = get_active_skills_dir() + active_dir = get_active_skills_dir(self.workspace_dir) return (active_dir / name).exists() - @staticmethod - def delete_skill(name: str) -> bool: + def delete_skill(self, name: str) -> bool: """ Delete a skill from customized_skills directory permanently. @@ -785,7 +811,7 @@ def delete_skill(name: str) -> bool: Returns: True if skill was deleted successfully, False otherwise. """ - customized_dir = get_customized_skills_dir() + customized_dir = get_customized_skills_dir(self.workspace_dir) skill_dir = customized_dir / name if not skill_dir.exists(): @@ -810,8 +836,8 @@ def delete_skill(name: str) -> bool: ) return False - @staticmethod def sync_from_active_to_customized( + self, skill_names: list[str] | None = None, ) -> tuple[int, int]: """ @@ -824,11 +850,12 @@ def sync_from_active_to_customized( Tuple of (synced_count, skipped_count). 
""" return sync_skills_from_active_to_customized( + self.workspace_dir, skill_names=skill_names, ) - @staticmethod def load_skill_file( # pylint: disable=too-many-return-statements + self, skill_name: str, file_path: str, source: str, @@ -892,7 +919,7 @@ def load_skill_file( # pylint: disable=too-many-return-statements # Get source directory if source == "customized": - base_dir = get_customized_skills_dir() + base_dir = get_customized_skills_dir(self.workspace_dir) else: # builtin base_dir = get_builtin_skills_dir() diff --git a/src/copaw/agents/tools/file_io.py b/src/copaw/agents/tools/file_io.py index 1e8040412..237d38ad3 100644 --- a/src/copaw/agents/tools/file_io.py +++ b/src/copaw/agents/tools/file_io.py @@ -9,12 +9,13 @@ from agentscope.tool import ToolResponse from ...constant import WORKING_DIR +from ...config.context import get_current_workspace_dir from .utils import truncate_file_output, read_file_safe def _resolve_file_path(file_path: str) -> str: """Resolve file path: use absolute path as-is, - resolve relative path from WORKING_DIR. + resolve relative path from current workspace or WORKING_DIR. Args: file_path: The input file path (absolute or relative). 
@@ -26,7 +27,9 @@ def _resolve_file_path(file_path: str) -> str: if path.is_absolute(): return str(path) else: - return str(WORKING_DIR / file_path) + # Use current workspace_dir from context, fallback to WORKING_DIR + workspace_dir = get_current_workspace_dir() or WORKING_DIR + return str(workspace_dir / file_path) async def read_file( # pylint: disable=too-many-return-statements diff --git a/src/copaw/agents/tools/file_search.py b/src/copaw/agents/tools/file_search.py index 12bcdae0d..330c5a722 100644 --- a/src/copaw/agents/tools/file_search.py +++ b/src/copaw/agents/tools/file_search.py @@ -11,6 +11,7 @@ from agentscope.tool import ToolResponse from ...constant import WORKING_DIR +from ...config.context import get_current_workspace_dir from .file_io import _resolve_file_path # Skip binary / large files @@ -112,7 +113,11 @@ async def grep_search( # pylint: disable=too-many-branches ], ) - search_root = Path(_resolve_file_path(path)) if path else WORKING_DIR + search_root = ( + Path(_resolve_file_path(path)) + if path + else (get_current_workspace_dir() or WORKING_DIR) + ) if not search_root.exists(): return ToolResponse( @@ -236,7 +241,11 @@ async def glob_search( ], ) - search_root = Path(_resolve_file_path(path)) if path else WORKING_DIR + search_root = ( + Path(_resolve_file_path(path)) + if path + else (get_current_workspace_dir() or WORKING_DIR) + ) if not search_root.exists(): return ToolResponse( diff --git a/src/copaw/agents/tools/send_file.py b/src/copaw/agents/tools/send_file.py index c6fb13773..e48a8b52a 100644 --- a/src/copaw/agents/tools/send_file.py +++ b/src/copaw/agents/tools/send_file.py @@ -4,7 +4,6 @@ import os import mimetypes import unicodedata -from pathlib import Path from agentscope.tool import ToolResponse from agentscope.message import ( @@ -15,24 +14,6 @@ ) from ..schema import FileBlock -from ...constant import WORKING_DIR - -# Only allow files under this directory, mirroring message_processing.py -_ALLOWED_MEDIA_ROOT = WORKING_DIR / 
"media" - - -def _is_allowed_media_path(path: str) -> bool: - """True if path is a file under _ALLOWED_MEDIA_ROOT. - - Returns False when the path is invalid, cannot be resolved, - is not a regular file, or lies outside the allowed media root. - """ - try: - resolved = Path(path).expanduser().resolve() - root = _ALLOWED_MEDIA_ROOT.resolve() - return resolved.is_file() and resolved.is_relative_to(root) - except Exception: - return False def _auto_as_type(mt: str) -> str: @@ -50,13 +31,9 @@ async def send_file_to_user( ) -> ToolResponse: """Send a file to the user. - The file must be located inside the allowed media directory - (``/media``). Attempting to send a file from outside that - directory will return an error so that the agent is aware of the failure. - Args: file_path (`str`): - Path to the file to send. Must be inside the media directory. + Path to the file to send. Returns: `ToolResponse`: @@ -98,17 +75,6 @@ async def send_file_to_user( try: # Use local file URL instead of base64 absolute_path = os.path.abspath(file_path) - - if not _is_allowed_media_path(absolute_path): - return ToolResponse( - content=[ - TextBlock( - type="text", - text=f"Error: Media file outside allowed directory: {os.path.basename(file_path)}", - ), - ], - ) - file_url = f"file://{absolute_path}" source = {"type": "url", "url": file_url} diff --git a/src/copaw/agents/tools/shell.py b/src/copaw/agents/tools/shell.py index babcaa3b4..5c3ae9f06 100644 --- a/src/copaw/agents/tools/shell.py +++ b/src/copaw/agents/tools/shell.py @@ -14,7 +14,8 @@ from agentscope.message import TextBlock from agentscope.tool import ToolResponse -from copaw.constant import WORKING_DIR +from ...constant import WORKING_DIR +from ...config.context import get_current_workspace_dir from .utils import truncate_shell_output @@ -156,7 +157,11 @@ async def execute_shell_command( cmd = (command or "").strip() # Set working directory - working_dir = cwd if cwd is not None else WORKING_DIR + # Use current workspace_dir 
from context, fallback to WORKING_DIR + if cwd is not None: + working_dir = cwd + else: + working_dir = get_current_workspace_dir() or WORKING_DIR # Ensure the venv Python is on PATH for subprocesses env = os.environ.copy() diff --git a/src/copaw/agents/utils/setup_utils.py b/src/copaw/agents/utils/setup_utils.py index f9b584c28..4caf82b0d 100644 --- a/src/copaw/agents/utils/setup_utils.py +++ b/src/copaw/agents/utils/setup_utils.py @@ -14,18 +14,23 @@ def copy_md_files( language: str, skip_existing: bool = False, + workspace_dir: Path | None = None, ) -> list[str]: """Copy md files from agents/md_files to working directory. Args: language: Language code (e.g. 'en', 'zh') skip_existing: If True, skip files that already exist in working dir. + workspace_dir: Target workspace directory. If None, uses WORKING_DIR. Returns: List of copied file names. """ from ...constant import WORKING_DIR + # Use provided workspace_dir or default to WORKING_DIR + target_dir = workspace_dir if workspace_dir is not None else WORKING_DIR + # Get md_files directory path with language subdirectory md_files_dir = Path(__file__).parent.parent / "md_files" / language @@ -40,13 +45,13 @@ def copy_md_files( logger.error("Default 'en' md files not found either") return [] - # Ensure working directory exists - WORKING_DIR.mkdir(parents=True, exist_ok=True) + # Ensure target directory exists + target_dir.mkdir(parents=True, exist_ok=True) - # Copy all .md files to working directory + # Copy all .md files to target directory copied_files: list[str] = [] for md_file in md_files_dir.glob("*.md"): - target_file = WORKING_DIR / md_file.name + target_file = target_dir / md_file.name if skip_existing and target_file.exists(): logger.debug("Skipped existing md file: %s", md_file.name) continue @@ -66,7 +71,7 @@ def copy_md_files( "Copied %d md file(s) [%s] to %s", len(copied_files), language, - WORKING_DIR, + target_dir, ) return copied_files diff --git a/src/copaw/app/_app.py b/src/copaw/app/_app.py 
index 8c908272b..25ad975cb 100644 --- a/src/copaw/app/_app.py +++ b/src/copaw/app/_app.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- # pylint: disable=redefined-outer-name,unused-argument -import asyncio import mimetypes import os import time @@ -13,27 +12,21 @@ from fastapi.responses import FileResponse from agentscope_runtime.engine.app import AgentApp -from .runner import AgentRunner -from ..config import ( # pylint: disable=no-name-in-module - load_config, - update_last_dispatch, - ConfigWatcher, -) -from ..config.utils import get_jobs_path, get_chats_path, get_config_path +from ..config import load_config # pylint: disable=no-name-in-module +from ..config.utils import get_config_path from ..constant import DOCS_ENABLED, LOG_LEVEL_ENV, CORS_ORIGINS, WORKING_DIR from ..__version__ import __version__ from ..utils.logging import setup_logger, add_copaw_file_handler -from .channels import ChannelManager # pylint: disable=no-name-in-module -from .channels.utils import make_process_from_runner -from .mcp import MCPClientManager, MCPConfigWatcher # MCP hot-reload support -from .runner.repo.json_repo import JsonChatRepository -from .crons.repo.json_repo import JsonJobRepository -from .crons.manager import CronManager -from .runner.manager import ChatManager -from .routers import router as api_router +from .routers import router as api_router, create_agent_scoped_router +from .routers.agent_scoped import AgentContextMiddleware from .routers.voice import voice_router from ..envs import load_envs_into_environ from ..providers.provider_manager import ProviderManager +from .multi_agent_manager import MultiAgentManager +from .migration import ( + migrate_legacy_workspace_to_default_agent, + ensure_default_agent_exists, +) # Apply log level on load so reload child process gets same level as CLI. logger = setup_logger(os.environ.get(LOG_LEVEL_ENV, "info")) @@ -50,7 +43,98 @@ # so they are available before the lifespan starts. 
load_envs_into_environ() -runner = AgentRunner() + +# Dynamic runner that selects the correct workspace runner based on request +class DynamicMultiAgentRunner: + """Runner wrapper that dynamically routes to the correct workspace runner. + + This allows AgentApp to work with multiple agents by inspecting + the X-Agent-Id header on each request. + """ + + def __init__(self): + self.framework_type = "agentscope" + self._multi_agent_manager = None + + def set_multi_agent_manager(self, manager): + """Set the MultiAgentManager instance after initialization.""" + self._multi_agent_manager = manager + + async def _get_workspace_runner(self, request): + """Get the correct workspace runner based on request.""" + from .agent_context import get_current_agent_id + + # Get agent_id from context (set by middleware or header) + agent_id = get_current_agent_id() + + logger.debug(f"_get_workspace_runner: agent_id={agent_id}") + + # Get the correct workspace runner + if not self._multi_agent_manager: + raise RuntimeError("MultiAgentManager not initialized") + + try: + workspace = await self._multi_agent_manager.get_agent(agent_id) + logger.debug( + f"Got workspace: {workspace.agent_id}, " + f"runner: {workspace.runner}", + ) + return workspace.runner + except ValueError as e: + logger.error(f"Agent not found: {e}") + raise + except Exception as e: + logger.error( + f"Error getting workspace runner: {e}", + exc_info=True, + ) + raise + + async def stream_query(self, request, *args, **kwargs): + """Dynamically route to the correct workspace runner.""" + logger.debug("DynamicMultiAgentRunner.stream_query called") + try: + runner = await self._get_workspace_runner(request) + logger.debug(f"Got runner: {runner}, type: {type(runner)}") + # Delegate to the actual runner's stream_query generator + count = 0 + async for item in runner.stream_query(request, *args, **kwargs): + count += 1 + logger.debug(f"Yielding item #{count}: {type(item)}") + yield item + logger.debug(f"stream_query 
completed, yielded {count} items") + except Exception as e: + logger.error( + f"Error in stream_query: {e}", + exc_info=True, + ) + # Yield error message to client + yield { + "error": str(e), + "type": "error", + } + + async def query_handler(self, request, *args, **kwargs): + """Dynamically route to the correct workspace runner.""" + runner = await self._get_workspace_runner(request) + # Delegate to the actual runner's query_handler generator + async for item in runner.query_handler(request, *args, **kwargs): + yield item + + # Async context manager support for AgentApp lifecycle + async def __aenter__(self): + """ + No-op context manager entry (workspaces manage their own runners). + """ + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """No-op context manager exit (workspaces manage their own runners).""" + return None + + +# Use dynamic runner for AgentApp +runner = DynamicMultiAgentRunner() agent_app = AgentApp( app_name="Friday", @@ -65,343 +149,50 @@ async def lifespan( ): # pylint: disable=too-many-statements,too-many-branches startup_start_time = time.time() add_copaw_file_handler(WORKING_DIR / "copaw.log") - await runner.start() - - # --- MCP client manager init (independent module, hot-reloadable) --- - config = load_config() - mcp_manager = MCPClientManager() - if hasattr(config, "mcp"): - try: - await mcp_manager.init_from_config(config.mcp) - logger.debug("MCP client manager initialized") - except BaseException as e: - if isinstance(e, (KeyboardInterrupt, SystemExit)): - raise - logger.exception("Failed to initialize MCP manager") - runner.set_mcp_manager(mcp_manager) - - # --- channel connector init/start (from config.json) --- - channel_manager = ChannelManager.from_config( - process=make_process_from_runner(runner), - config=config, - on_last_dispatch=update_last_dispatch, - ) - await channel_manager.start_all() - - # --- cron init/start --- - repo = JsonJobRepository(get_jobs_path()) - cron_manager = CronManager( - 
repo=repo, - runner=runner, - channel_manager=channel_manager, - timezone="UTC", - ) - await cron_manager.start() - # --- chat manager init and connect to runner.session --- - chat_repo = JsonChatRepository(get_chats_path()) - chat_manager = ChatManager( - repo=chat_repo, - ) + # --- Multi-agent migration and initialization --- + logger.info("Checking for legacy config migration...") + migrate_legacy_workspace_to_default_agent() + ensure_default_agent_exists() - runner.set_chat_manager(chat_manager) + # --- Multi-agent manager initialization --- + logger.info("Initializing MultiAgentManager...") + multi_agent_manager = MultiAgentManager() - # --- config file watcher (channels + heartbeat hot-reload on change) --- - config_watcher = ConfigWatcher( - channel_manager=channel_manager, - cron_manager=cron_manager, - ) - await config_watcher.start() - - # --- MCP config watcher (auto-reload MCP clients on change) --- - mcp_watcher = None - if hasattr(config, "mcp"): - try: - mcp_watcher = MCPConfigWatcher( - mcp_manager=mcp_manager, - config_loader=load_config, - config_path=get_config_path(), - ) - await mcp_watcher.start() - logger.debug("MCP config watcher started") - except BaseException as e: - if isinstance(e, (KeyboardInterrupt, SystemExit)): - raise - logger.exception("Failed to start MCP watcher") - - # Inject channel_manager into approval service so it can - # proactively push approval messages to channels like DingTalk. 
- from .approvals import get_approval_service - - get_approval_service().set_channel_manager(channel_manager) + # Start all configured agents (handled by manager) + await multi_agent_manager.start_all_configured_agents() # --- Model provider manager (non-reloadable, in-memory) --- provider_manager = ProviderManager.get_instance() - # expose to endpoints - app.state.runner = runner - app.state.channel_manager = channel_manager - app.state.cron_manager = cron_manager - app.state.chat_manager = chat_manager - app.state.config_watcher = config_watcher - app.state.mcp_manager = mcp_manager - app.state.mcp_watcher = mcp_watcher - app.state.provider_manager = provider_manager + # Expose to endpoints - multi-agent manager + app.state.multi_agent_manager = multi_agent_manager - _restart_task: asyncio.Task | None = None + # Connect DynamicMultiAgentRunner to MultiAgentManager + if isinstance(runner, DynamicMultiAgentRunner): + runner.set_multi_agent_manager(multi_agent_manager) - async def _restart_services() -> None: - """Stop all managers, then rebuild from config (no exit). + # Helper function to get agent instance by ID (async) + async def _get_agent_by_id(agent_id: str = None): + """Get agent instance by ID, or active agent if not specified.""" + if agent_id is None: + config = load_config(get_config_path()) + agent_id = config.agents.active_agent or "default" + return await multi_agent_manager.get_agent(agent_id) - Single-flight: only one restart runs at a time. Concurrent or - duplicate callers wait for the in-progress restart and return - successfully. Uses asyncio.shield() so that when the caller - (e.g. channel request) is cancelled, the restart task keeps - running and does not propagate cancellation into deep task - trees (avoids RecursionError on cancel). - """ - # pylint: disable=too-many-statements - nonlocal _restart_task - # Caller task (in _local_tasks) must not be cancelled so it can - # yield the final "Restart completed" message. 
- restart_requester_task = asyncio.current_task() + app.state.get_agent_by_id = _get_agent_by_id - async def _run_then_clear() -> None: - try: - await _do_restart_services( - restart_requester_task=restart_requester_task, - ) - finally: - nonlocal _restart_task - _restart_task = None - - if _restart_task is not None and not _restart_task.done(): - logger.info( - "_restart_services: waiting for in-progress restart to finish", - ) - await asyncio.shield(_restart_task) - return - if _restart_task is not None and _restart_task.done(): - _restart_task = None - logger.info("_restart_services: starting restart") - _restart_task = asyncio.create_task(_run_then_clear()) - await asyncio.shield(_restart_task) - - async def _teardown_new_stack( - mcp_watcher=None, - config_watcher=None, - cron_mgr=None, - ch_mgr=None, - mcp_mgr=None, - ) -> None: - """Stop new stack in reverse start order (for rollback on failure).""" - if mcp_watcher is not None: - try: - await mcp_watcher.stop() - except Exception: - logger.debug( - "rollback: mcp_watcher.stop failed", - exc_info=True, - ) - if config_watcher is not None: - try: - await config_watcher.stop() - except Exception: - logger.debug( - "rollback: config_watcher.stop failed", - exc_info=True, - ) - if cron_mgr is not None: - try: - await cron_mgr.stop() - except Exception: - logger.debug( - "rollback: cron_manager.stop failed", - exc_info=True, - ) - if ch_mgr is not None: - try: - await ch_mgr.stop_all() - except Exception: - logger.debug( - "rollback: channel_manager.stop_all failed", - exc_info=True, - ) - if mcp_mgr is not None: - try: - await mcp_mgr.close_all() - except Exception: - logger.debug( - "rollback: mcp_manager.close_all failed", - exc_info=True, - ) - - async def _do_restart_services( - restart_requester_task: asyncio.Task | None = None, - ) -> None: - """Cancel in-flight agent requests first (so they can send error to - channel), then stop old stack, then start new stack and swap. 
- """ - # pylint: disable=too-many-statements - try: - config = load_config(get_config_path()) - except Exception: - logger.exception("restart_services: load_config failed") - return - - # 1) Cancel in-flight agent requests. Do not wait for them so the - # console restart task never blocks (avoid deadlock when cancelled - # task is slow to exit). - local_tasks = getattr(agent_app, "_local_tasks", None) - if local_tasks: - to_cancel = [ - t - for t in list(local_tasks.values()) - if t is not restart_requester_task and not t.done() - ] - for t in to_cancel: - t.cancel() - if to_cancel: - logger.info( - "restart: cancelled %s in-flight task(s), not waiting", - len(to_cancel), - ) - - # 2) Stop old stack - cfg_w = app.state.config_watcher - mcp_w = getattr(app.state, "mcp_watcher", None) - cron_mgr = app.state.cron_manager - ch_mgr = app.state.channel_manager - mcp_mgr = app.state.mcp_manager - try: - await cfg_w.stop() - except Exception: - logger.exception( - "restart_services: old config_watcher.stop failed", - ) - if mcp_w is not None: - try: - await mcp_w.stop() - except Exception: - logger.exception( - "restart_services: old mcp_watcher.stop failed", - ) - try: - await cron_mgr.stop() - except Exception: - logger.exception( - "restart_services: old cron_manager.stop failed", - ) - try: - await ch_mgr.stop_all() - except Exception: - logger.exception( - "restart_services: old channel_manager.stop_all failed", - ) - if mcp_mgr is not None: - try: - await mcp_mgr.close_all() - except Exception: - logger.exception( - "restart_services: old mcp_manager.close_all failed", - ) - - # 3) Build and start new stack - new_mcp_manager = MCPClientManager() - if hasattr(config, "mcp"): - try: - await new_mcp_manager.init_from_config(config.mcp) - except Exception: - logger.exception( - "restart_services: mcp init_from_config failed", - ) - return - - new_channel_manager = ChannelManager.from_config( - process=make_process_from_runner(runner), - config=config, - 
on_last_dispatch=update_last_dispatch, - ) - try: - await new_channel_manager.start_all() - except Exception: - logger.exception( - "restart_services: channel_manager.start_all failed", - ) - await _teardown_new_stack(mcp_mgr=new_mcp_manager) - return - - job_repo = JsonJobRepository(get_jobs_path()) - new_cron_manager = CronManager( - repo=job_repo, - runner=runner, - channel_manager=new_channel_manager, - timezone="UTC", - ) - try: - await new_cron_manager.start() - except Exception: - logger.exception( - "restart_services: cron_manager.start failed", - ) - await _teardown_new_stack( - ch_mgr=new_channel_manager, - mcp_mgr=new_mcp_manager, - ) - return + # Global managers (shared across all agents) + app.state.provider_manager = provider_manager - new_config_watcher = ConfigWatcher( - channel_manager=new_channel_manager, - cron_manager=new_cron_manager, - ) - try: - await new_config_watcher.start() - except Exception: - logger.exception( - "restart_services: config_watcher.start failed", - ) - await _teardown_new_stack( - cron_mgr=new_cron_manager, - ch_mgr=new_channel_manager, - mcp_mgr=new_mcp_manager, - ) - return + # Setup approval service with default agent's channel_manager + default_agent = await multi_agent_manager.get_agent("default") + if default_agent.channel_manager: + from .approvals import get_approval_service - new_mcp_watcher = None - if hasattr(config, "mcp"): - try: - new_mcp_watcher = MCPConfigWatcher( - mcp_manager=new_mcp_manager, - config_loader=load_config, - config_path=get_config_path(), - ) - await new_mcp_watcher.start() - except Exception: - logger.exception( - "restart_services: mcp_watcher.start failed", - ) - await _teardown_new_stack( - config_watcher=new_config_watcher, - cron_mgr=new_cron_manager, - ch_mgr=new_channel_manager, - mcp_mgr=new_mcp_manager, - ) - return - - if hasattr(config, "mcp"): - runner.set_mcp_manager(new_mcp_manager) - app.state.mcp_manager = new_mcp_manager - app.state.mcp_watcher = new_mcp_watcher - else: - 
runner.set_mcp_manager(None) - app.state.mcp_manager = None - app.state.mcp_watcher = None - app.state.channel_manager = new_channel_manager - app.state.cron_manager = new_cron_manager - app.state.config_watcher = new_config_watcher - logger.info("Daemon restart (in-process) completed: managers rebuilt") - - setattr(runner, "_restart_callback", _restart_services) + get_approval_service().set_channel_manager( + default_agent.channel_manager, + ) startup_elapsed = time.time() - startup_start_time logger.debug( @@ -411,39 +202,16 @@ async def _do_restart_services( try: yield finally: - # Stop current app.state refs (post-restart instances if any) - cfg_w = getattr(app.state, "config_watcher", None) - mcp_w = getattr(app.state, "mcp_watcher", None) - cron_mgr = getattr(app.state, "cron_manager", None) - ch_mgr = getattr(app.state, "channel_manager", None) - mcp_mgr = getattr(app.state, "mcp_manager", None) - # stop order: watchers -> cron -> channels -> mcp -> runner - if cfg_w is not None: - try: - await cfg_w.stop() - except Exception: - pass - if mcp_w is not None: - try: - await mcp_w.stop() - except Exception: - pass - if cron_mgr is not None: - try: - await cron_mgr.stop() - except Exception: - pass - if ch_mgr is not None: - try: - await ch_mgr.stop_all() - except Exception: - pass - if mcp_mgr is not None: + # Stop multi-agent manager (stops all agents and their components) + multi_agent_mgr = getattr(app.state, "multi_agent_manager", None) + if multi_agent_mgr is not None: + logger.info("Stopping MultiAgentManager...") try: - await mcp_mgr.close_all() - except Exception: - pass - await runner.stop() + await multi_agent_mgr.stop_all() + except Exception as e: + logger.error(f"Error stopping MultiAgentManager: {e}") + + logger.info("Application shutdown complete") app = FastAPI( @@ -453,6 +221,9 @@ async def _do_restart_services( openapi_url="/openapi.json" if DOCS_ENABLED else None, ) +# Add agent context middleware for agent-scoped routes 
+app.add_middleware(AgentContextMiddleware) + # Apply CORS middleware if CORS_ORIGINS is set if CORS_ORIGINS: origins = [o.strip() for o in CORS_ORIGINS.split(",") if o.strip()] @@ -504,7 +275,8 @@ def read_root(): "CoPaw Web Console is not available. " "If you installed CoPaw from source code, please run " "`npm ci && npm run build` in CoPaw's `console/` " - "directory, and restart CoPaw to enable the web console." + "directory, and restart CoPaw to enable the " + "web console." ), } @@ -517,6 +289,11 @@ def get_version(): app.include_router(api_router, prefix="/api") +# Agent-scoped router: /api/agents/{agentId}/chats, etc. +agent_scoped_router = create_agent_scoped_router() +app.include_router(agent_scoped_router, prefix="/api") + + app.include_router( agent_app.router, prefix="/api/agent", diff --git a/src/copaw/config/watcher.py b/src/copaw/app/agent_config_watcher.py similarity index 54% rename from src/copaw/config/watcher.py rename to src/copaw/app/agent_config_watcher.py index 8c5dded31..8ccb87463 100644 --- a/src/copaw/config/watcher.py +++ b/src/copaw/app/agent_config_watcher.py @@ -1,16 +1,23 @@ # -*- coding: utf-8 -*- -"""Watch config.json for changes and auto-reload channels and heartbeat.""" +"""Watch agent.json for changes and auto-reload agent components. + +This watcher monitors an agent's workspace/agent.json file for changes +and automatically reloads channels, heartbeat, and other configurations +without requiring manual restart. 
+""" from __future__ import annotations import asyncio import logging from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING + +from ..config.config import load_agent_config +from ..config.utils import get_available_channels -from .utils import load_config, get_config_path, get_available_channels -from .config import ChannelConfig, HeartbeatConfig -from ..app.channels import ChannelManager # pylint: disable=no-name-in-module +if TYPE_CHECKING: + from ..config.config import ChannelConfig, HeartbeatConfig logger = logging.getLogger(__name__) @@ -25,27 +32,43 @@ def _heartbeat_hash(hb: Optional[HeartbeatConfig]) -> int: return hash(str(hb.model_dump(mode="json"))) -class ConfigWatcher: - """Poll config.json mtime; reload only changed channels automatically.""" +class AgentConfigWatcher: + """Poll agent.json mtime and reload changed configs automatically. + + This watcher is agent-scoped and monitors a specific agent's + workspace/agent.json file for configuration changes. + """ def __init__( self, - channel_manager: ChannelManager, - poll_interval: float = DEFAULT_POLL_INTERVAL, - config_path: Optional[Path] = None, + agent_id: str, + workspace_dir: Path, + channel_manager: Any, cron_manager: Any = None, + poll_interval: float = DEFAULT_POLL_INTERVAL, ): + """Initialize agent config watcher. 
+ + Args: + agent_id: Agent ID to monitor + workspace_dir: Path to agent's workspace directory + channel_manager: ChannelManager instance for this agent + cron_manager: CronManager instance for this agent (optional) + poll_interval: How often to check for changes (seconds) + """ + self._agent_id = agent_id + self._workspace_dir = workspace_dir + self._config_path = workspace_dir / "agent.json" self._channel_manager = channel_manager - self._poll_interval = poll_interval - self._config_path = config_path or get_config_path() self._cron_manager = cron_manager + self._poll_interval = poll_interval self._task: Optional[asyncio.Task] = None - # Snapshot of the last known channel config (for diffing) + # Snapshot of the last known config (for diffing) self._last_channels: Optional[ChannelConfig] = None self._last_channels_hash: Optional[int] = None self._last_heartbeat_hash: Optional[int] = None - # mtime of config.json at last check + # mtime of agent.json at last check self._last_mtime: float = 0.0 async def start(self) -> None: @@ -53,15 +76,15 @@ async def start(self) -> None: self._snapshot() self._task = asyncio.create_task( self._poll_loop(), - name="config_watcher", + name=f"agent_config_watcher_{self._agent_id}", ) logger.info( - "ConfigWatcher started (poll=%.1fs, path=%s)", - self._poll_interval, - self._config_path, + f"AgentConfigWatcher started for agent {self._agent_id} " + f"(poll={self._poll_interval}s, path={self._config_path})", ) async def stop(self) -> None: + """Stop the polling task.""" if self._task: self._task.cancel() try: @@ -69,28 +92,40 @@ async def stop(self) -> None: except asyncio.CancelledError: pass self._task = None - logger.info("ConfigWatcher stopped") + logger.info(f"AgentConfigWatcher stopped for agent {self._agent_id}") # ------------------------------------------------------------------ + # Internal methods + # ------------------------------------------------------------------ def _snapshot(self) -> None: - """Load current config; 
record mtime, channels hash, heartbeat hash.""" + """Load current agent config; record mtime and hashes.""" try: self._last_mtime = self._config_path.stat().st_mtime except FileNotFoundError: self._last_mtime = 0.0 + try: - config = load_config(self._config_path) - self._last_channels = config.channels.model_copy(deep=True) - self._last_channels_hash = self._channels_hash(config.channels) - hb = getattr( - config.agents.defaults, - "heartbeat", - None, + agent_config = load_agent_config(self._agent_id) + if agent_config.channels: + self._last_channels = agent_config.channels.model_copy( + deep=True, + ) + self._last_channels_hash = self._channels_hash( + agent_config.channels, + ) + else: + self._last_channels = None + self._last_channels_hash = None + + self._last_heartbeat_hash = _heartbeat_hash( + agent_config.heartbeat, ) - self._last_heartbeat_hash = _heartbeat_hash(hb) except Exception: - logger.exception("ConfigWatcher: failed to load initial config") + logger.exception( + f"AgentConfigWatcher: failed to load initial config " + f"for agent {self._agent_id}", + ) self._last_channels = None self._last_channels_hash = None self._last_heartbeat_hash = None @@ -123,26 +158,33 @@ async def _reload_one_channel( old_channel = await self._channel_manager.get_channel(name) if old_channel is None: logger.warning( - "ConfigWatcher: channel '%s' not found, skip", - name, + f"AgentConfigWatcher ({self._agent_id}): " + f"channel '{name}' not found, skip", ) return new_channel = old_channel.clone(new_ch) await self._channel_manager.replace_channel(new_channel) - logger.info("ConfigWatcher: channel '%s' reloaded", name) + logger.info( + f"AgentConfigWatcher ({self._agent_id}): " + f"channel '{name}' reloaded", + ) except Exception: logger.exception( - "ConfigWatcher: failed to reload channel '%s'", - name, + f"AgentConfigWatcher ({self._agent_id}): " + f"failed to reload channel '{name}'", ) setattr(new_channels, name, old_ch if old_ch else new_ch) - async def 
_apply_channel_changes(self, loaded_config: Any) -> None: + async def _apply_channel_changes(self, agent_config: Any) -> None: """Diff channels and reload changed ones; update snapshot.""" - new_hash = self._channels_hash(loaded_config.channels) + if not agent_config.channels: + return + + new_hash = self._channels_hash(agent_config.channels) if new_hash == self._last_channels_hash: return - new_channels = loaded_config.channels + + new_channels = agent_config.channels old_channels = self._last_channels extra_new = getattr(new_channels, "__pydantic_extra__", None) or {} extra_old = ( @@ -150,6 +192,7 @@ async def _apply_channel_changes(self, loaded_config: Any) -> None: if old_channels else {} ) + for name in get_available_channels(): new_ch = getattr(new_channels, name, None) or extra_new.get(name) old_ch = ( @@ -164,17 +207,17 @@ async def _apply_channel_changes(self, loaded_config: Any) -> None: if new_dump is not None and new_dump == old_dump: continue logger.info( - "ConfigWatcher: channel '%s' config changed, reloading", - name, + f"AgentConfigWatcher ({self._agent_id}): " + f"channel '{name}' config changed, reloading", ) await self._reload_one_channel(name, new_ch, new_channels, old_ch) + self._last_channels = new_channels.model_copy(deep=True) self._last_channels_hash = self._channels_hash(new_channels) - async def _apply_heartbeat_change(self, loaded_config: Any) -> None: + async def _apply_heartbeat_change(self, agent_config: Any) -> None: """Update heartbeat hash and reschedule if changed.""" - hb = getattr(loaded_config.agents.defaults, "heartbeat", None) - new_hb_hash = _heartbeat_hash(hb) + new_hb_hash = _heartbeat_hash(agent_config.heartbeat) if ( self._cron_manager is not None and new_hb_hash != self._last_heartbeat_hash @@ -182,34 +225,53 @@ async def _apply_heartbeat_change(self, loaded_config: Any) -> None: self._last_heartbeat_hash = new_hb_hash try: await self._cron_manager.reschedule_heartbeat() - logger.info("ConfigWatcher: heartbeat 
rescheduled") + logger.info( + f"AgentConfigWatcher ({self._agent_id}): " + f"heartbeat rescheduled", + ) except Exception: logger.exception( - "ConfigWatcher: failed to reschedule heartbeat", + f"AgentConfigWatcher ({self._agent_id}): " + f"failed to reschedule heartbeat", ) else: self._last_heartbeat_hash = new_hb_hash async def _poll_loop(self) -> None: + """Main polling loop.""" while True: try: await asyncio.sleep(self._poll_interval) await self._check() except Exception: - logger.exception("ConfigWatcher: poll iteration failed") + logger.exception( + f"AgentConfigWatcher ({self._agent_id}): " + f"poll iteration failed", + ) async def _check(self) -> None: + """Check for config changes and reload if needed.""" try: mtime = self._config_path.stat().st_mtime except FileNotFoundError: return + if mtime == self._last_mtime: return + self._last_mtime = mtime + try: - loaded = load_config(self._config_path) + agent_config = load_agent_config(self._agent_id) except Exception: - logger.exception("ConfigWatcher: failed to parse config.json") + logger.exception( + f"AgentConfigWatcher ({self._agent_id}): " + f"failed to parse agent.json", + ) return - await self._apply_channel_changes(loaded) - await self._apply_heartbeat_change(loaded) + + # Apply changes + if self._channel_manager: + await self._apply_channel_changes(agent_config) + if self._cron_manager: + await self._apply_heartbeat_change(agent_config) diff --git a/src/copaw/app/agent_context.py b/src/copaw/app/agent_context.py new file mode 100644 index 000000000..d263ee5bf --- /dev/null +++ b/src/copaw/app/agent_context.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +"""Agent context utilities for multi-agent support. + +Provides utilities to get the correct agent instance for each request. 
+""" +from contextvars import ContextVar +from typing import Optional, TYPE_CHECKING +from fastapi import Request +from .multi_agent_manager import MultiAgentManager +from ..config.utils import load_config + +if TYPE_CHECKING: + from .workspace import Workspace + +# Context variable to store current agent ID across async calls +_current_agent_id: ContextVar[Optional[str]] = ContextVar( + "current_agent_id", + default=None, +) + + +async def get_agent_for_request( + request: Request, + agent_id: Optional[str] = None, +) -> "Workspace": + """Get agent workspace for current request. + + Priority: + 1. agent_id parameter (explicit override) + 2. request.state.agent_id (from agent-scoped router) + 3. X-Agent-Id header (from frontend) + 4. Active agent from config + + Args: + request: FastAPI request object + agent_id: Agent ID override (highest priority) + + Returns: + Workspace for the specified or active agent + + Raises: + HTTPException: If agent not found + """ + from fastapi import HTTPException + + # Determine which agent to use + target_agent_id = agent_id + + # Check request.state.agent_id (set by agent-scoped router) + if not target_agent_id and hasattr(request.state, "agent_id"): + target_agent_id = request.state.agent_id + + # Check X-Agent-Id header + if not target_agent_id: + target_agent_id = request.headers.get("X-Agent-Id") + + if not target_agent_id: + # Fallback to active agent from config + config = load_config() + target_agent_id = config.agents.active_agent or "default" + + # Get MultiAgentManager + if not hasattr(request.app.state, "multi_agent_manager"): + raise HTTPException( + status_code=500, + detail="MultiAgentManager not initialized", + ) + + manager: MultiAgentManager = request.app.state.multi_agent_manager + + try: + workspace = await manager.get_agent(target_agent_id) + if not workspace: + raise HTTPException( + status_code=404, + detail=f"Agent '{target_agent_id}' not found", + ) + return workspace + except Exception as e: + raise 
HTTPException( + status_code=500, + detail=f"Failed to get agent: {str(e)}", + ) from e + + +def get_active_agent_id() -> str: + """Get current active agent ID from config. + + Returns: + Active agent ID, defaults to "default" + """ + try: + config = load_config() + return config.agents.active_agent or "default" + except Exception: + return "default" + + +def set_current_agent_id(agent_id: str) -> None: + """Set current agent ID in context. + + Args: + agent_id: Agent ID to set + """ + _current_agent_id.set(agent_id) + + +def get_current_agent_id() -> str: + """Get current agent ID from context or config fallback. + + Returns: + Current agent ID, defaults to active agent or "default" + """ + agent_id = _current_agent_id.get() + if agent_id: + return agent_id + return get_active_agent_id() diff --git a/src/copaw/app/channels/dingtalk/channel.py b/src/copaw/app/channels/dingtalk/channel.py index 8491df400..7964c9224 100644 --- a/src/copaw/app/channels/dingtalk/channel.py +++ b/src/copaw/app/channels/dingtalk/channel.py @@ -34,6 +34,7 @@ from ..utils import file_url_to_local_path from ....config.config import DingTalkConfig as DingTalkChannelConfig from ....config.utils import get_config_path +from ....constant import DEFAULT_MEDIA_DIR from ..base import ( BaseChannel, @@ -85,7 +86,8 @@ def __init__( client_id: str, client_secret: str, bot_prefix: str, - media_dir: str = "~/.copaw/media", + media_dir: str = "", + workspace_dir: Path | None = None, on_reply_sent: OnReplySent = None, show_tool_details: bool = True, filter_tool_messages: bool = False, @@ -112,7 +114,17 @@ def __init__( self.client_id = client_id self.client_secret = client_secret self.bot_prefix = bot_prefix - self._media_dir = Path(media_dir).expanduser() + self._workspace_dir = ( + Path(workspace_dir).expanduser() if workspace_dir else None + ) + # Use workspace-specific media dir if workspace_dir is provided + if not media_dir and self._workspace_dir: + self._media_dir = self._workspace_dir / "media" 
+ elif media_dir: + self._media_dir = Path(media_dir).expanduser() + else: + self._media_dir = DEFAULT_MEDIA_DIR + self._media_dir.mkdir(parents=True, exist_ok=True) self._client: Optional[dingtalk_stream.DingTalkStreamClient] = None self._loop: Optional[asyncio.AbstractEventLoop] = None @@ -156,7 +168,7 @@ def from_env( client_id=os.getenv("DINGTALK_CLIENT_ID", ""), client_secret=os.getenv("DINGTALK_CLIENT_SECRET", ""), bot_prefix=os.getenv("DINGTALK_BOT_PREFIX", "[BOT] "), - media_dir=os.getenv("DINGTALK_MEDIA_DIR", "~/.copaw/media"), + media_dir=os.getenv("DINGTALK_MEDIA_DIR", ""), on_reply_sent=on_reply_sent, dm_policy=os.getenv("DINGTALK_DM_POLICY", "open"), group_policy=os.getenv("DINGTALK_GROUP_POLICY", "open"), @@ -174,6 +186,7 @@ def from_config( show_tool_details: bool = True, filter_tool_messages: bool = False, filter_thinking: bool = False, + workspace_dir: Path | None = None, ) -> "DingTalkChannel": return cls( process=process, @@ -181,7 +194,8 @@ def from_config( client_id=config.client_id or "", client_secret=config.client_secret or "", bot_prefix=config.bot_prefix or "[BOT] ", - media_dir=config.media_dir or "~/.copaw/media", + media_dir=config.media_dir or "", + workspace_dir=workspace_dir, on_reply_sent=on_reply_sent, show_tool_details=show_tool_details, filter_tool_messages=filter_tool_messages, @@ -257,7 +271,13 @@ def _route_from_handle(self, to_handle: str) -> dict: return {"webhook_key": s} if s else {} def _session_webhook_store_path(self) -> Path: - """Path to persist session webhook mapping (for cron after restart).""" + """Path to persist session webhook mapping (for cron after restart). + + Uses agent workspace directory if available, otherwise falls back + to global config directory for backward compatibility. 
+ """ + if self._workspace_dir: + return self._workspace_dir / "dingtalk_session_webhooks.json" return get_config_path().parent / "dingtalk_session_webhooks.json" def _load_session_webhook_store_from_disk(self) -> None: diff --git a/src/copaw/app/channels/feishu/channel.py b/src/copaw/app/channels/feishu/channel.py index 106287757..54051a077 100644 --- a/src/copaw/app/channels/feishu/channel.py +++ b/src/copaw/app/channels/feishu/channel.py @@ -37,6 +37,7 @@ from ....config.config import FeishuConfig as FeishuChannelConfig from ....config.utils import get_config_path +from ....constant import DEFAULT_MEDIA_DIR from ..base import ( BaseChannel, ContentType, @@ -161,7 +162,8 @@ def __init__( bot_prefix: str, encrypt_key: str = "", verification_token: str = "", - media_dir: str = "~/.copaw/media", + media_dir: str = "", + workspace_dir: Path | None = None, on_reply_sent: OnReplySent = None, show_tool_details: bool = True, filter_tool_messages: bool = False, @@ -190,7 +192,17 @@ def __init__( self.bot_prefix = bot_prefix self.encrypt_key = encrypt_key or "" self.verification_token = verification_token or "" - self._media_dir = Path(media_dir).expanduser() + self._workspace_dir = ( + Path(workspace_dir).expanduser() if workspace_dir else None + ) + # Use workspace-specific media dir if workspace_dir is provided + if not media_dir and self._workspace_dir: + self._media_dir = self._workspace_dir / "media" + elif media_dir: + self._media_dir = Path(media_dir).expanduser() + else: + self._media_dir = DEFAULT_MEDIA_DIR + self._media_dir.mkdir(parents=True, exist_ok=True) self._client: Any = None self._ws_client: Any = None @@ -235,7 +247,7 @@ def from_env( bot_prefix=os.getenv("FEISHU_BOT_PREFIX", "[BOT] "), encrypt_key=os.getenv("FEISHU_ENCRYPT_KEY", ""), verification_token=os.getenv("FEISHU_VERIFICATION_TOKEN", ""), - media_dir=os.getenv("FEISHU_MEDIA_DIR", "~/.copaw/media"), + media_dir=os.getenv("FEISHU_MEDIA_DIR", ""), on_reply_sent=on_reply_sent, 
dm_policy=os.getenv("FEISHU_DM_POLICY", "open"), group_policy=os.getenv("FEISHU_GROUP_POLICY", "open"), @@ -253,6 +265,7 @@ def from_config( show_tool_details: bool = True, filter_tool_messages: bool = False, filter_thinking: bool = False, + workspace_dir: Path | None = None, ) -> "FeishuChannel": return cls( process=process, @@ -262,7 +275,8 @@ def from_config( bot_prefix=config.bot_prefix or "[BOT] ", encrypt_key=config.encrypt_key or "", verification_token=config.verification_token or "", - media_dir=config.media_dir or "~/.copaw/media", + media_dir=config.media_dir or "", + workspace_dir=workspace_dir, on_reply_sent=on_reply_sent, show_tool_details=show_tool_details, filter_tool_messages=filter_tool_messages, @@ -998,7 +1012,12 @@ async def _download_file_resource( def _receive_id_store_path(self) -> Path: """ Path to persist receive_id mapping (for cron to resolve after restart). + + Uses agent workspace directory if available, otherwise falls back + to global config directory for backward compatibility. 
""" + if self._workspace_dir: + return self._workspace_dir / "feishu_receive_ids.json" return get_config_path().parent / "feishu_receive_ids.json" def _load_receive_id_store_from_disk(self) -> None: diff --git a/src/copaw/app/channels/imessage/channel.py b/src/copaw/app/channels/imessage/channel.py index bf4d4f0fd..e4f312208 100644 --- a/src/copaw/app/channels/imessage/channel.py +++ b/src/copaw/app/channels/imessage/channel.py @@ -21,6 +21,7 @@ ) from ....config.config import IMessageChannelConfig +from ....constant import DEFAULT_MEDIA_DIR from ..utils import file_url_to_local_path from ....agents.utils.file_handling import download_file_from_url @@ -44,7 +45,7 @@ def __init__( db_path: str, poll_sec: float, bot_prefix: str, - media_dir: str = "~/.copaw/media", + media_dir: str = "", max_decoded_size: int = 10 * 1024 * 1024, # 10MB default on_reply_sent: OnReplySent = None, show_tool_details: bool = True, @@ -64,7 +65,9 @@ def __init__( self.bot_prefix = bot_prefix # Create media directory for downloaded files - self._media_dir = Path(media_dir).expanduser() + self._media_dir = ( + Path(media_dir).expanduser() if media_dir else DEFAULT_MEDIA_DIR + ) self._media_dir.mkdir(parents=True, exist_ok=True) # Base64 data size limit @@ -89,7 +92,7 @@ def from_env( ), poll_sec=float(os.getenv("IMESSAGE_POLL_SEC", "1.0")), bot_prefix=os.getenv("IMESSAGE_BOT_PREFIX", "[BOT] "), - media_dir=os.getenv("IMESSAGE_MEDIA_DIR", "~/.copaw/media"), + media_dir=os.getenv("IMESSAGE_MEDIA_DIR", ""), max_decoded_size=int( os.getenv("IMESSAGE_MAX_DECODED_SIZE", "10485760"), ), # 10MB @@ -112,7 +115,7 @@ def from_config( db_path=config.db_path or "~/Library/Messages/chat.db", poll_sec=config.poll_sec, bot_prefix=config.bot_prefix or "[BOT] ", - media_dir=config.media_dir or "~/.copaw/media", + media_dir=config.media_dir if config.media_dir else "", max_decoded_size=config.max_decoded_size, on_reply_sent=on_reply_sent, show_tool_details=show_tool_details, diff --git 
a/src/copaw/app/channels/manager.py b/src/copaw/app/channels/manager.py index 41184b08b..02d4b6306 100644 --- a/src/copaw/app/channels/manager.py +++ b/src/copaw/app/channels/manager.py @@ -7,6 +7,7 @@ import asyncio import logging +from pathlib import Path from typing import ( Callable, @@ -159,8 +160,16 @@ def from_config( process: ProcessHandler, config: "Config", on_last_dispatch: OnLastDispatch = None, + workspace_dir: Path | None = None, ) -> "ChannelManager": - """Create channels from config (config.json).""" + """Create channels from config (config.json or agent.json). + + Args: + process: Process handler for agent communication + config: Configuration object with channels + on_last_dispatch: Callback for dispatch events + workspace_dir: Agent workspace directory for channel state files + """ available = get_available_channels() ch = config.channels show_tool_details = getattr(config, "show_tool_details", True) @@ -180,6 +189,12 @@ def from_config( ) if ch_cfg is None: continue + + # Check if channel is enabled + enabled = getattr(ch_cfg, "enabled", False) + if not enabled: + continue + filter_tool_messages = getattr( ch_cfg, "filter_tool_messages", @@ -190,16 +205,26 @@ def from_config( "filter_thinking", False, ) - channels.append( - ch_cls.from_config( - process, - ch_cfg, - on_reply_sent=on_last_dispatch, - show_tool_details=show_tool_details, - filter_tool_messages=filter_tool_messages, - filter_thinking=filter_thinking, - ), - ) + + # Pass workspace_dir to channel if supported + from_config_kwargs = { + "process": process, + "config": ch_cfg, + "on_reply_sent": on_last_dispatch, + "show_tool_details": show_tool_details, + "filter_tool_messages": filter_tool_messages, + "filter_thinking": filter_thinking, + } + + # Only pass workspace_dir to channels that support it + import inspect + + sig = inspect.signature(ch_cls.from_config) + if "workspace_dir" in sig.parameters: + from_config_kwargs["workspace_dir"] = workspace_dir + + 
channels.append(ch_cls.from_config(**from_config_kwargs)) + return cls(channels) def _make_enqueue_cb(self, channel_id: str) -> Callable[[Any], None]: diff --git a/src/copaw/app/channels/mattermost/channel.py b/src/copaw/app/channels/mattermost/channel.py index 3d9f4954b..a77ef88ed 100644 --- a/src/copaw/app/channels/mattermost/channel.py +++ b/src/copaw/app/channels/mattermost/channel.py @@ -20,6 +20,7 @@ ) from ....config.config import MattermostConfig as MattermostChannelConfig +from ....constant import WORKING_DIR from ..base import ( BaseChannel, OnReplySent, @@ -31,7 +32,7 @@ MATTERMOST_POST_CHUNK_SIZE = 4000 # chars per post (hard limit ~16383) -_DEFAULT_MEDIA_DIR = Path("~/.copaw/media/mattermost").expanduser() +_DEFAULT_MEDIA_DIR = WORKING_DIR / "media" / "mattermost" _TYPING_TIMEOUT_S = 180 _IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff"} diff --git a/src/copaw/app/channels/qq/channel.py b/src/copaw/app/channels/qq/channel.py index 3314139bd..9f2c426d7 100644 --- a/src/copaw/app/channels/qq/channel.py +++ b/src/copaw/app/channels/qq/channel.py @@ -35,6 +35,7 @@ ) from ....config.config import QQConfig as QQChannelConfig +from ....constant import WORKING_DIR from ..base import ( BaseChannel, @@ -73,7 +74,7 @@ _IMAGE_TAG_PATTERN = re.compile(r"\[Image: (https?://[^\]]+)\]", re.IGNORECASE) # Rich media paths -_DEFAULT_MEDIA_DIR = Path("~/.copaw/media/qq").expanduser() +_DEFAULT_MEDIA_DIR = WORKING_DIR / "media" / "qq" class QQApiError(RuntimeError): diff --git a/src/copaw/app/channels/telegram/channel.py b/src/copaw/app/channels/telegram/channel.py index de79b91f4..042f665b2 100644 --- a/src/copaw/app/channels/telegram/channel.py +++ b/src/copaw/app/channels/telegram/channel.py @@ -51,7 +51,7 @@ 50 * 1024 * 1024 ) # 50 MB – Telegram bot upload limit -_DEFAULT_MEDIA_DIR = Path("~/.copaw/media/telegram").expanduser() +_DEFAULT_MEDIA_DIR = WORKING_DIR / "media" / "telegram" _TYPING_TIMEOUT_S = 180 _RECONNECT_INITIAL_S = 2.0 @@ 
-278,6 +278,7 @@ def __init__( on_reply_sent: OnReplySent = None, show_tool_details: bool = True, media_dir: str = "", + workspace_dir: Path | None = None, show_typing: bool = True, filter_tool_messages: bool = False, filter_thinking: bool = False, @@ -307,6 +308,9 @@ def __init__( self._media_dir = ( Path(media_dir).expanduser() if media_dir else _DEFAULT_MEDIA_DIR ) + self._workspace_dir = ( + Path(workspace_dir).expanduser() if workspace_dir else None + ) self._show_typing = show_typing self._typing_tasks: dict[str, asyncio.Task] = {} self._task: Optional[asyncio.Task] = None @@ -485,6 +489,7 @@ def from_config( show_tool_details: bool = True, filter_tool_messages: bool = False, filter_thinking: bool = False, + workspace_dir: Path | None = None, ) -> "TelegramChannel": if isinstance(config, dict): c = config @@ -509,6 +514,7 @@ def _get_str(key: str) -> str: show_tool_details=show_tool_details, filter_tool_messages=filter_tool_messages, filter_thinking=filter_thinking, + workspace_dir=workspace_dir, show_typing=show_typing, dm_policy=c.get("dm_policy") or "open", group_policy=c.get("group_policy") or "open", @@ -784,7 +790,11 @@ async def _send_media_value( "Could not resolve media file from URL.", ) local_path = Path(raw_path).resolve() - allowed_root = (WORKING_DIR / "media").resolve() + allowed_root = ( + (self._workspace_dir / "media").resolve() + if self._workspace_dir + else (WORKING_DIR / "media").resolve() + ) if not local_path.is_relative_to(allowed_root): logger.error( "telegram: blocked media outside allowed directory: %s", diff --git a/src/copaw/app/channels/wecom/channel.py b/src/copaw/app/channels/wecom/channel.py index 0aaf48792..87923f14a 100644 --- a/src/copaw/app/channels/wecom/channel.py +++ b/src/copaw/app/channels/wecom/channel.py @@ -29,6 +29,7 @@ ) from aibot import WSClient, WSClientOptions, generate_req_id +from ....constant import DEFAULT_MEDIA_DIR from ..base import ( BaseChannel, ContentType, @@ -61,7 +62,7 @@ def __init__( bot_id: 
str, secret: str, bot_prefix: str = "[BOT] ", - media_dir: str = "~/.copaw/media", + media_dir: str = "", welcome_text: str = "", on_reply_sent: OnReplySent = None, show_tool_details: bool = True, @@ -89,7 +90,9 @@ def __init__( self.secret = secret self.bot_prefix = bot_prefix self.welcome_text = welcome_text - self._media_dir = Path(media_dir).expanduser() + self._media_dir = ( + Path(media_dir).expanduser() if media_dir else DEFAULT_MEDIA_DIR + ) self._max_reconnect_attempts = max_reconnect_attempts self._client: Any = None @@ -118,7 +121,7 @@ def from_env( bot_id=os.getenv("WECOM_BOT_ID", ""), secret=os.getenv("WECOM_SECRET", ""), bot_prefix=os.getenv("WECOM_BOT_PREFIX", "[BOT] "), - media_dir=os.getenv("WECOM_MEDIA_DIR", "~/.copaw/media"), + media_dir=os.getenv("WECOM_MEDIA_DIR", ""), on_reply_sent=on_reply_sent, dm_policy=os.getenv("WECOM_DM_POLICY", "open"), group_policy=os.getenv("WECOM_GROUP_POLICY", "open"), @@ -145,10 +148,7 @@ def from_config( bot_id=getattr(config, "bot_id", "") or "", secret=getattr(config, "secret", "") or "", bot_prefix=getattr(config, "bot_prefix", "[BOT] ") or "[BOT] ", - media_dir=( - getattr(config, "media_dir", "~/.copaw/media") - or "~/.copaw/media" - ), + media_dir=getattr(config, "media_dir", None) or "", welcome_text=getattr(config, "welcome_text", "") or "", on_reply_sent=on_reply_sent, show_tool_details=show_tool_details, diff --git a/src/copaw/app/crons/api.py b/src/copaw/app/crons/api.py index a52b44399..21ca65138 100644 --- a/src/copaw/app/crons/api.py +++ b/src/copaw/app/crons/api.py @@ -10,14 +10,19 @@ router = APIRouter(prefix="/cron", tags=["cron"]) -def get_cron_manager(request: Request) -> CronManager: - mgr = getattr(request.app.state, "cron_manager", None) - if mgr is None: +async def get_cron_manager( + request: Request, +) -> CronManager: + """Get cron manager for the active agent.""" + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + if 
workspace.cron_manager is None: raise HTTPException( - status_code=503, - detail="cron manager not initialized", + status_code=500, + detail="CronManager not initialized", ) - return mgr + return workspace.cron_manager @router.get("/jobs", response_model=list[CronJobSpec]) diff --git a/src/copaw/app/crons/repo/json_repo.py b/src/copaw/app/crons/repo/json_repo.py index fc1c47025..ecc40de11 100644 --- a/src/copaw/app/crons/repo/json_repo.py +++ b/src/copaw/app/crons/repo/json_repo.py @@ -17,7 +17,9 @@ class JsonJobRepository(BaseJobRepository): - Atomic write: write tmp then replace. """ - def __init__(self, path: Path): + def __init__(self, path: Path | str): + if isinstance(path, str): + path = Path(path) self._path = path.expanduser() @property diff --git a/src/copaw/app/migration.py b/src/copaw/app/migration.py new file mode 100644 index 000000000..10653c38d --- /dev/null +++ b/src/copaw/app/migration.py @@ -0,0 +1,319 @@ +# -*- coding: utf-8 -*- +"""Configuration migration utilities for multi-agent support. + +Handles migration from legacy single-agent config to new multi-agent structure. +""" +import json +import logging +import shutil +from pathlib import Path + +from ..config.config import ( + AgentProfileConfig, + AgentProfileRef, + AgentsConfig, + AgentsRunningConfig, + AgentsLLMRoutingConfig, +) +from ..config.utils import load_config, save_config + +logger = logging.getLogger(__name__) + + +def migrate_legacy_workspace_to_default_agent() -> bool: + """Migrate legacy single-agent workspace to default agent workspace. + + This function: + 1. Checks if migration is needed + 2. Creates default agent workspace + 3. Migrates sessions, memory, and markdown files + 4. Creates agent.json with legacy configuration + 5. 
Updates root config.json to new structure + + Returns: + bool: True if migration was performed, False if already migrated + """ + try: + config = load_config() + except Exception as e: + logger.error(f"Failed to load config: {e}") + return False + + # Check if already migrated + # Skip if: + # 1. Multiple agents already exist (multi-agent config), OR + # 2. Default agent has agent.json (already migrated) + if len(config.agents.profiles) > 1: + logger.debug( + f"Multi-agent config already exists " + f"({len(config.agents.profiles)} agents), skipping migration", + ) + return False + + if "default" in config.agents.profiles: + agent_ref = config.agents.profiles["default"] + if isinstance(agent_ref, AgentProfileRef): + workspace_dir = Path(agent_ref.workspace_dir).expanduser() + agent_config_path = workspace_dir / "agent.json" + if agent_config_path.exists(): + logger.debug( + "Default agent already migrated, skipping migration", + ) + return False + + logger.info("=" * 60) + logger.info("Migrating legacy config to multi-agent structure...") + logger.info("=" * 60) + + # Extract legacy agent configuration + legacy_agents = config.agents + + # Create default agent workspace + default_workspace = Path("~/.copaw/workspaces/default").expanduser() + default_workspace.mkdir(parents=True, exist_ok=True) + logger.info(f"Created default agent workspace: {default_workspace}") + + # Build default agent configuration from legacy settings + default_agent_config = AgentProfileConfig( + id="default", + name="Default Agent", + description="Default CoPaw agent (migrated from legacy config)", + workspace_dir=str(default_workspace), + channels=config.channels if hasattr(config, "channels") else None, + mcp=config.mcp if hasattr(config, "mcp") else None, + heartbeat=( + legacy_agents.defaults.heartbeat + if hasattr(legacy_agents, "defaults") and legacy_agents.defaults + else None + ), + running=( + legacy_agents.running + if hasattr(legacy_agents, "running") and legacy_agents.running + 
else AgentsRunningConfig() + ), + llm_routing=( + legacy_agents.llm_routing + if hasattr(legacy_agents, "llm_routing") + and legacy_agents.llm_routing + else AgentsLLMRoutingConfig() + ), + system_prompt_files=( + legacy_agents.system_prompt_files + if hasattr(legacy_agents, "system_prompt_files") + and legacy_agents.system_prompt_files + else ["AGENTS.md", "SOUL.md", "PROFILE.md"] + ), + tools=config.tools if hasattr(config, "tools") else None, + security=config.security if hasattr(config, "security") else None, + ) + + # Save default agent configuration to workspace/agent.json + agent_config_path = default_workspace / "agent.json" + with open(agent_config_path, "w", encoding="utf-8") as f: + json.dump( + default_agent_config.model_dump(exclude_none=True), + f, + ensure_ascii=False, + indent=2, + ) + logger.info(f"Created agent config: {agent_config_path}") + + # Migrate existing workspace files to default agent workspace + old_workspace = Path("~/.copaw").expanduser() + + migrated_items = [] + + # Migrate sessions directory + _migrate_workspace_item( + old_workspace / "sessions", + default_workspace / "sessions", + "sessions", + migrated_items, + ) + + # Migrate memory directory + _migrate_workspace_item( + old_workspace / "memory", + default_workspace / "memory", + "memory", + migrated_items, + ) + + # Migrate chats.json + _migrate_workspace_item( + old_workspace / "chats.json", + default_workspace / "chats.json", + "chats.json", + migrated_items, + ) + + # Migrate jobs.json + _migrate_workspace_item( + old_workspace / "jobs.json", + default_workspace / "jobs.json", + "jobs.json", + migrated_items, + ) + + # Migrate markdown files + for md_file in ["AGENTS.md", "SOUL.md", "PROFILE.md", "HEARTBEAT.md"]: + _migrate_workspace_item( + old_workspace / md_file, + default_workspace / md_file, + md_file, + migrated_items, + ) + + # Migrate channel-specific configuration files + _migrate_workspace_item( + old_workspace / "feishu_receive_ids.json", + default_workspace / 
"feishu_receive_ids.json", + "feishu_receive_ids.json", + migrated_items, + ) + + _migrate_workspace_item( + old_workspace / "dingtalk_session_webhooks.json", + default_workspace / "dingtalk_session_webhooks.json", + "dingtalk_session_webhooks.json", + migrated_items, + ) + + if migrated_items: + logger.info(f"Migrated workspace items: {', '.join(migrated_items)}") + + # Update root config.json to new structure + # CRITICAL: Preserve legacy agent fields in root config for downgrade + # compatibility. Old versions expect these fields to have valid values. + config.agents = AgentsConfig( + active_agent="default", + profiles={ + "default": AgentProfileRef( + id="default", + workspace_dir=str(default_workspace), + ), + }, + # Preserve legacy fields with values from migrated agent config + running=default_agent_config.running, + llm_routing=default_agent_config.llm_routing, + language=default_agent_config.language, + system_prompt_files=default_agent_config.system_prompt_files, + ) + + # IMPORTANT: Keep original config fields in root config.json for + # backward compatibility. If user downgrades, old version can still + # use these fields. New version will prioritize agent.json. + # DO NOT clear: channels, mcp, tools, security fields + + save_config(config) + logger.info( + "Updated root config.json to multi-agent structure " + "(kept original fields for backward compatibility)", + ) + + logger.info("=" * 60) + logger.info("Migration completed successfully!") + logger.info(f" Default agent workspace: {default_workspace}") + logger.info(f" Default agent config: {agent_config_path}") + logger.info("=" * 60) + + return True + + +def _migrate_workspace_item( + old_path: Path, + new_path: Path, + item_name: str, + migrated_items: list, +) -> None: + """Migrate a single workspace item (file or directory). 
+ + Args: + old_path: Source path + new_path: Destination path + item_name: Name for logging + migrated_items: List to append migrated item names + """ + if not old_path.exists(): + return + + if new_path.exists(): + logger.debug(f"Skipping {item_name} (already exists in new location)") + return + + try: + if old_path.is_dir(): + shutil.copytree(old_path, new_path) + else: + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(old_path, new_path) + + migrated_items.append(item_name) + logger.debug(f"Migrated {item_name}") + except Exception as e: + logger.warning(f"Failed to migrate {item_name}: {e}") + + +def ensure_default_agent_exists() -> None: + """Ensure that the default agent exists in config. + + This function is called on startup to verify the default agent + is properly configured. If not, it will be created. + Also ensures necessary workspace files exist (chats.json, jobs.json). + """ + config = load_config() + + # Get or determine default workspace path + if "default" in config.agents.profiles: + agent_ref = config.agents.profiles["default"] + default_workspace = Path(agent_ref.workspace_dir).expanduser() + agent_existed = True + else: + default_workspace = Path("~/.copaw/workspaces/default").expanduser() + agent_existed = False + + # Ensure workspace directory exists + default_workspace.mkdir(parents=True, exist_ok=True) + + # Always ensure chats.json exists (even if agent already registered) + chats_file = default_workspace / "chats.json" + if not chats_file.exists(): + with open(chats_file, "w", encoding="utf-8") as f: + json.dump( + {"version": 1, "chats": []}, + f, + ensure_ascii=False, + indent=2, + ) + logger.debug("Created chats.json for default agent") + + # Always ensure jobs.json exists (even if agent already registered) + jobs_file = default_workspace / "jobs.json" + if not jobs_file.exists(): + with open(jobs_file, "w", encoding="utf-8") as f: + json.dump( + {"version": 1, "jobs": []}, + f, + ensure_ascii=False, + indent=2, + ) 
+ logger.debug("Created jobs.json for default agent") + + # Only update config if agent didn't exist + if not agent_existed: + logger.info("Creating default agent...") + + # Add default agent reference to config + config.agents.profiles["default"] = AgentProfileRef( + id="default", + workspace_dir=str(default_workspace), + ) + + # Set as active if no active agent + if not config.agents.active_agent: + config.agents.active_agent = "default" + + save_config(config) + logger.info( + f"Created default agent with workspace: {default_workspace}", + ) diff --git a/src/copaw/app/multi_agent_manager.py b/src/copaw/app/multi_agent_manager.py new file mode 100644 index 000000000..1ef6d0b00 --- /dev/null +++ b/src/copaw/app/multi_agent_manager.py @@ -0,0 +1,239 @@ +# -*- coding: utf-8 -*- +"""MultiAgentManager: Manages multiple agent workspaces with lazy loading. + +Provides centralized management for multiple Workspace objects, +including lazy loading, lifecycle management, and hot reloading. +""" +import asyncio +import logging +from typing import Dict + +from .workspace import Workspace +from ..config.utils import load_config + +logger = logging.getLogger(__name__) + + +class MultiAgentManager: + """Manages multiple agent workspaces. + + Features: + - Lazy loading: Workspaces are created only when first requested + - Lifecycle management: Start, stop, reload workspaces + - Thread-safe: Uses async lock for concurrent access + - Hot reload: Reload individual workspaces without affecting others + """ + + def __init__(self): + """Initialize multi-agent manager.""" + self.agents: Dict[str, Workspace] = {} + self._lock = asyncio.Lock() + logger.debug("MultiAgentManager initialized") + + async def get_agent(self, agent_id: str) -> Workspace: + """Get agent workspace by ID (lazy loading). + + If workspace doesn't exist in memory, it will be created and started. + Thread-safe using async lock. 
+ + Args: + agent_id: Agent ID to retrieve + + Returns: + Workspace: The requested workspace instance + + Raises: + ValueError: If agent ID not found in configuration + """ + async with self._lock: + # Return existing agent if already loaded + if agent_id in self.agents: + logger.debug(f"Returning cached agent: {agent_id}") + return self.agents[agent_id] + + # Load configuration to get agent reference + config = load_config() + + if agent_id not in config.agents.profiles: + raise ValueError( + f"Agent '{agent_id}' not found in configuration. " + f"Available agents: {list(config.agents.profiles.keys())}", + ) + + agent_ref = config.agents.profiles[agent_id] + + # Create and start new workspace + logger.info(f"Creating new workspace: {agent_id}") + instance = Workspace( + agent_id=agent_id, + workspace_dir=agent_ref.workspace_dir, + ) + + try: + await instance.start() + self.agents[agent_id] = instance + logger.info(f"Workspace created and started: {agent_id}") + return instance + except Exception as e: + logger.error(f"Failed to start workspace {agent_id}: {e}") + raise + + async def stop_agent(self, agent_id: str) -> bool: + """Stop a specific agent instance. + + Args: + agent_id: Agent ID to stop + + Returns: + bool: True if agent was stopped, False if not running + """ + async with self._lock: + if agent_id not in self.agents: + logger.warning(f"Agent not running: {agent_id}") + return False + + instance = self.agents[agent_id] + await instance.stop() + del self.agents[agent_id] + logger.info(f"Agent stopped and removed: {agent_id}") + return True + + async def reload_agent(self, agent_id: str) -> bool: + """Reload a specific agent instance. + + This stops the agent, removes it from cache, so it will be + recreated with fresh configuration on next access. 
+ + Args: + agent_id: Agent ID to reload + + Returns: + bool: True if agent was reloaded, False if not running + """ + async with self._lock: + if agent_id not in self.agents: + logger.debug( + f"Agent not running, will be loaded on next " + f"request: {agent_id}", + ) + return False + + logger.info(f"Reloading agent: {agent_id}") + instance = self.agents[agent_id] + await instance.stop() + del self.agents[agent_id] + logger.info( + f"Agent stopped and removed from cache " + f"(will be reloaded on next request): {agent_id}", + ) + return True + + async def stop_all(self): + """Stop all agent instances. + + Called during application shutdown to clean up resources. + """ + logger.info(f"Stopping all agents ({len(self.agents)} running)...") + + # Create list of agent IDs to avoid modifying dict during iteration + agent_ids = list(self.agents.keys()) + + for agent_id in agent_ids: + try: + instance = self.agents[agent_id] + await instance.stop() + logger.debug(f"Agent stopped: {agent_id}") + except Exception as e: + logger.error(f"Error stopping agent {agent_id}: {e}") + + self.agents.clear() + logger.info("All agents stopped") + + def list_loaded_agents(self) -> list[str]: + """List currently loaded agent IDs. + + Returns: + list[str]: List of loaded agent IDs + """ + return list(self.agents.keys()) + + def is_agent_loaded(self, agent_id: str) -> bool: + """Check if agent is currently loaded. + + Args: + agent_id: Agent ID to check + + Returns: + bool: True if agent is loaded and running + """ + return agent_id in self.agents + + async def preload_agent(self, agent_id: str) -> bool: + """Preload an agent instance during startup. 
+ + Args: + agent_id: Agent ID to preload + + Returns: + bool: True if successfully preloaded, False if failed + """ + try: + await self.get_agent(agent_id) + logger.info(f"Successfully preloaded agent: {agent_id}") + return True + except Exception as e: + logger.error(f"Failed to preload agent {agent_id}: {e}") + return False + + async def start_all_configured_agents(self) -> dict[str, bool]: + """Start all agents defined in configuration concurrently. + + This method loads the current configuration and starts all + configured agents in parallel for optimal performance. + + Returns: + dict[str, bool]: Mapping of agent_id to success status + """ + config = load_config() + agent_ids = list(config.agents.profiles.keys()) + + if not agent_ids: + logger.warning("No agents configured in config") + return {} + + logger.info(f"Starting {len(agent_ids)} configured agent(s)") + + async def start_single_agent(agent_id: str) -> tuple[str, bool]: + """Start a single agent with error handling.""" + try: + logger.info(f"Starting agent: {agent_id}") + await self.preload_agent(agent_id) + logger.info(f"Agent started successfully: {agent_id}") + return (agent_id, True) + except Exception as e: + logger.error( + f"Failed to start agent {agent_id}: {e}. 
" + f"Continuing with other agents...", + ) + return (agent_id, False) + + # Start all agents concurrently + results = await asyncio.gather( + *[start_single_agent(agent_id) for agent_id in agent_ids], + return_exceptions=False, + ) + + # Build result mapping + result_map = dict(results) + success_count = sum(1 for success in result_map.values() if success) + logger.info( + f"Agent startup complete: {success_count}/{len(agent_ids)} " + f"agents started successfully", + ) + + return result_map + + def __repr__(self) -> str: + """String representation of manager.""" + loaded = list(self.agents.keys()) + return f"MultiAgentManager(loaded_agents={loaded})" diff --git a/src/copaw/app/routers/__init__.py b/src/copaw/app/routers/__init__.py index 3035b6481..c94d416f0 100644 --- a/src/copaw/app/routers/__init__.py +++ b/src/copaw/app/routers/__init__.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- +"""API routers.""" from fastapi import APIRouter from .agent import router as agent_router +from .agents import router as agents_router from .config import router as config_router from .local_models import router as local_models_router from .providers import router as providers_router @@ -17,9 +19,9 @@ from .console import router as console_router from .token_usage import router as token_usage_router - router = APIRouter() +router.include_router(agents_router) router.include_router(agent_router) router.include_router(config_router) router.include_router(console_router) @@ -36,4 +38,16 @@ router.include_router(envs_router) router.include_router(token_usage_router) -__all__ = ["router"] + +def create_agent_scoped_router() -> APIRouter: + """Create agent-scoped router that wraps existing routers. 
+ + Returns: + APIRouter with all routers mounted under /agents/{agentId}/ + """ + from .agent_scoped import create_agent_scoped_router as _create + + return _create() + + +__all__ = ["router", "create_agent_scoped_router"] diff --git a/src/copaw/app/routers/agent.py b/src/copaw/app/routers/agent.py index 2afc9997b..0b35aa5fb 100644 --- a/src/copaw/app/routers/agent.py +++ b/src/copaw/app/routers/agent.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """Agent file management API.""" -from fastapi import APIRouter, Body, HTTPException +from fastapi import APIRouter, Body, HTTPException, Request from pydantic import BaseModel, Field from ...config import ( @@ -10,7 +10,8 @@ AgentsRunningConfig, ) -from ...agents.memory.agent_md_manager import AGENT_MD_MANAGER +from ...agents.memory.agent_md_manager import AgentMdManager +from ..agent_context import get_agent_for_request router = APIRouter(prefix="/agent", tags=["agent"]) @@ -35,14 +36,20 @@ class MdFileContent(BaseModel): "/files", response_model=list[MdFileInfo], summary="List working files", - description="List all working files", + description="List all working files (uses active agent)", ) -async def list_working_files() -> list[MdFileInfo]: +async def list_working_files( + request: Request, +) -> list[MdFileInfo]: """List working directory markdown files.""" try: + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) files = [ MdFileInfo.model_validate(file) - for file in AGENT_MD_MANAGER.list_working_mds() + for file in workspace_manager.list_working_mds() ] return files except Exception as exc: @@ -53,14 +60,19 @@ async def list_working_files() -> list[MdFileInfo]: "/files/{md_name}", response_model=MdFileContent, summary="Read a working file", - description="Read a working markdown file", + description="Read a working markdown file (uses active agent)", ) async def read_working_file( md_name: str, + request: Request, ) -> MdFileContent: """Read a 
working directory markdown file.""" try: - content = AGENT_MD_MANAGER.read_working_md(md_name) + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) + content = workspace_manager.read_working_md(md_name) return MdFileContent(content=content) except FileNotFoundError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc @@ -72,15 +84,20 @@ async def read_working_file( "/files/{md_name}", response_model=dict, summary="Write a working file", - description="Create or update a working file", + description="Create or update a working file (uses active agent)", ) async def write_working_file( md_name: str, - request: MdFileContent, + body: MdFileContent, + request: Request, ) -> dict: """Write a working directory markdown file.""" try: - AGENT_MD_MANAGER.write_working_md(md_name, request.content) + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) + workspace_manager.write_working_md(md_name, body.content) return {"written": True} except Exception as exc: raise HTTPException(status_code=500, detail=str(exc)) from exc @@ -90,14 +107,20 @@ async def write_working_file( "/memory", response_model=list[MdFileInfo], summary="List memory files", - description="List all memory files", + description="List all memory files (uses active agent)", ) -async def list_memory_files() -> list[MdFileInfo]: +async def list_memory_files( + request: Request, +) -> list[MdFileInfo]: """List memory directory markdown files.""" try: + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) files = [ MdFileInfo.model_validate(file) - for file in AGENT_MD_MANAGER.list_memory_mds() + for file in workspace_manager.list_memory_mds() ] return files except Exception as exc: @@ -108,14 +131,19 @@ async def list_memory_files() -> list[MdFileInfo]: "/memory/{md_name}", 
response_model=MdFileContent, summary="Read a memory file", - description="Read a memory markdown file", + description="Read a memory markdown file (uses active agent)", ) async def read_memory_file( md_name: str, + request: Request, ) -> MdFileContent: """Read a memory directory markdown file.""" try: - content = AGENT_MD_MANAGER.read_memory_md(md_name) + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) + content = workspace_manager.read_memory_md(md_name) return MdFileContent(content=content) except FileNotFoundError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc @@ -127,15 +155,20 @@ async def read_memory_file( "/memory/{md_name}", response_model=dict, summary="Write a memory file", - description="Create or update a memory file", + description="Create or update a memory file (uses active agent)", ) async def write_memory_file( md_name: str, - request: MdFileContent, + body: MdFileContent, + request: Request, ) -> dict: """Write a memory directory markdown file.""" try: - AGENT_MD_MANAGER.write_memory_md(md_name, request.content) + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) + workspace_manager.write_memory_md(md_name, body.content) return {"written": True} except Exception as exc: raise HTTPException(status_code=500, detail=str(exc)) from exc @@ -202,30 +235,56 @@ async def put_agent_language( "/running-config", response_model=AgentsRunningConfig, summary="Get agent running config", - description="Retrieve agent runtime behavior configuration", + description="Get running configuration for active agent", ) -async def get_agents_running_config() -> AgentsRunningConfig: +async def get_agents_running_config( + request: Request, +) -> AgentsRunningConfig: """Get agent running configuration.""" - config = load_config() - return config.agents.running + workspace = await 
get_agent_for_request(request) + from ...config.config import load_agent_config + + agent_config = load_agent_config(workspace.agent_id) + return agent_config.running or AgentsRunningConfig() @router.put( "/running-config", response_model=AgentsRunningConfig, summary="Update agent running config", - description="Update agent runtime behavior configuration", + description="Update running configuration for active agent", ) async def put_agents_running_config( running_config: AgentsRunningConfig = Body( ..., description="Updated agent running configuration", ), + request: Request = None, ) -> AgentsRunningConfig: """Update agent running configuration.""" - config = load_config() - config.agents.running = running_config - save_config(config) + workspace = await get_agent_for_request(request) + from ...config.config import load_agent_config, save_agent_config + + agent_config = load_agent_config(workspace.agent_id) + agent_config.running = running_config + save_agent_config(workspace.agent_id, agent_config) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + manager = request.app.state.multi_agent_manager + await manager.reload_agent(workspace.agent_id) + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) + return running_config @@ -233,28 +292,54 @@ async def put_agents_running_config( "/system-prompt-files", response_model=list[str], summary="Get system prompt files", - description="Get list of markdown files enabled for system prompt", + description="Get system prompt files for active agent", ) -async def get_system_prompt_files() -> list[str]: +async def get_system_prompt_files( + request: Request, +) -> list[str]: """Get list of enabled system prompt files.""" - config = load_config() - return config.agents.system_prompt_files + workspace = await get_agent_for_request(request) + from 
...config.config import load_agent_config + + agent_config = load_agent_config(workspace.agent_id) + return agent_config.system_prompt_files or [] @router.put( "/system-prompt-files", response_model=list[str], summary="Update system prompt files", - description="Update list of markdown files enabled for system prompt", + description="Update system prompt files for active agent", ) async def put_system_prompt_files( files: list[str] = Body( ..., - description="List of markdown filenames to load into system prompt", + description="Markdown filenames to load into system prompt", ), + request: Request = None, ) -> list[str]: """Update list of enabled system prompt files.""" - config = load_config() - config.agents.system_prompt_files = files - save_config(config) + workspace = await get_agent_for_request(request) + from ...config.config import load_agent_config, save_agent_config + + agent_config = load_agent_config(workspace.agent_id) + agent_config.system_prompt_files = files + save_agent_config(workspace.agent_id, agent_config) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + manager = request.app.state.multi_agent_manager + await manager.reload_agent(workspace.agent_id) + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) + return files diff --git a/src/copaw/app/routers/agent_scoped.py b/src/copaw/app/routers/agent_scoped.py new file mode 100644 index 000000000..4b127775d --- /dev/null +++ b/src/copaw/app/routers/agent_scoped.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +"""Agent-scoped router that wraps existing routers under /agents/{agentId}/ + +This provides agent isolation by injecting agentId into request.state, +allowing downstream APIs to access the correct agent context. 
+""" +from fastapi import APIRouter, Request +from starlette.middleware.base import ( + BaseHTTPMiddleware, + RequestResponseEndpoint, +) +from starlette.responses import Response + + +class AgentContextMiddleware(BaseHTTPMiddleware): + """Middleware to inject agentId into request.state.""" + + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + """Extract agentId from path/header and inject into context.""" + import logging + from ..agent_context import set_current_agent_id + + logger = logging.getLogger(__name__) + agent_id = None + + # Priority 1: Extract agentId from path: /api/agents/{agentId}/... + path_parts = request.url.path.split("/") + if len(path_parts) >= 4 and path_parts[1] == "api": + if path_parts[2] == "agents": + agent_id = path_parts[3] + request.state.agent_id = agent_id + logger.debug( + f"AgentContextMiddleware: agent_id={agent_id} " + f"from path={request.url.path}", + ) + + # Priority 2: Check X-Agent-Id header + if not agent_id: + agent_id = request.headers.get("X-Agent-Id") + + # Set agent_id in context variable for use by runners + if agent_id: + set_current_agent_id(agent_id) + + response = await call_next(request) + return response + + +def create_agent_scoped_router() -> APIRouter: + """Create router that wraps all existing routers under /{agentId}/ + + Returns: + APIRouter with all sub-routers mounted under /{agentId}/ + """ + from .agent import router as agent_router + from .skills import router as skills_router + from .tools import router as tools_router + from .config import router as config_router + from .mcp import router as mcp_router + from .workspace import router as workspace_router + from ..crons.api import router as cron_router + from ..runner.api import router as chats_router + + # Create parent router with agentId parameter + router = APIRouter(prefix="/agents/{agentId}", tags=["agent-scoped"]) + + # Include all agent-specific sub-routers (they keep their own prefixes) 
+ # /agents/{agentId}/agent/* -> agent_router + # /agents/{agentId}/chats/* -> chats_router + # /agents/{agentId}/config/* -> config_router (channels, heartbeat) + # /agents/{agentId}/cron/* -> cron_router + # /agents/{agentId}/mcp/* -> mcp_router + # /agents/{agentId}/skills/* -> skills_router + # /agents/{agentId}/tools/* -> tools_router + # /agents/{agentId}/workspace/* -> workspace_router + router.include_router(agent_router) + router.include_router(chats_router) + router.include_router(config_router) + router.include_router(cron_router) + router.include_router(mcp_router) + router.include_router(skills_router) + router.include_router(tools_router) + router.include_router(workspace_router) + + return router diff --git a/src/copaw/app/routers/agents.py b/src/copaw/app/routers/agents.py new file mode 100644 index 000000000..fa300254f --- /dev/null +++ b/src/copaw/app/routers/agents.py @@ -0,0 +1,515 @@ +# -*- coding: utf-8 -*- +"""Multi-agent management API. + +Provides RESTful API for managing multiple agent instances. 
+""" +import json +import logging +from pathlib import Path +from fastapi import APIRouter, Body, HTTPException, Request +from fastapi import Path as PathParam +from pydantic import BaseModel + +from ...config.config import ( + AgentProfileConfig, + AgentProfileRef, + load_agent_config, + save_agent_config, + generate_short_agent_id, +) +from ...config.utils import load_config, save_config +from ...agents.memory.agent_md_manager import AgentMdManager +from ..multi_agent_manager import MultiAgentManager + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/agents", tags=["agents"]) + + +class AgentSummary(BaseModel): + """Agent summary information.""" + + id: str + name: str + description: str + workspace_dir: str + + +class AgentListResponse(BaseModel): + """Response for listing agents.""" + + agents: list[AgentSummary] + + +class CreateAgentRequest(BaseModel): + """Request model for creating a new agent (id is auto-generated).""" + + name: str + description: str = "" + workspace_dir: str | None = None + language: str = "en" + + +class MdFileInfo(BaseModel): + """Markdown file metadata.""" + + filename: str + path: str + size: int + created_time: str + modified_time: str + + +class MdFileContent(BaseModel): + """Markdown file content.""" + + content: str + + +def _get_multi_agent_manager(request: Request) -> MultiAgentManager: + """Get MultiAgentManager from app state.""" + if not hasattr(request.app.state, "multi_agent_manager"): + raise HTTPException( + status_code=500, + detail="MultiAgentManager not initialized", + ) + return request.app.state.multi_agent_manager + + +@router.get( + "", + response_model=AgentListResponse, + summary="List all agents", + description="Get list of all configured agents", +) +async def list_agents() -> AgentListResponse: + """List all configured agents.""" + config = load_config() + + agents = [] + for agent_id, agent_ref in config.agents.profiles.items(): + # Load agent config to get name and description + try: + 
agent_config = load_agent_config(agent_id) + agents.append( + AgentSummary( + id=agent_id, + name=agent_config.name, + description=agent_config.description, + workspace_dir=agent_ref.workspace_dir, + ), + ) + except Exception: # noqa: E722 + # If agent config load fails, use basic info + agents.append( + AgentSummary( + id=agent_id, + name=agent_id.title(), + description="", + workspace_dir=agent_ref.workspace_dir, + ), + ) + + return AgentListResponse( + agents=agents, + ) + + +@router.get( + "/{agentId}", + response_model=AgentProfileConfig, + summary="Get agent details", + description="Get complete configuration for a specific agent", +) +async def get_agent(agentId: str = PathParam(...)) -> AgentProfileConfig: + """Get agent configuration.""" + try: + agent_config = load_agent_config(agentId) + return agent_config + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post( + "", + response_model=AgentProfileRef, + status_code=201, + summary="Create new agent", + description="Create a new agent (ID is auto-generated by server)", +) +async def create_agent( + request: CreateAgentRequest = Body(...), +) -> AgentProfileRef: + """Create a new agent with auto-generated ID.""" + config = load_config() + + # Always generate a unique short UUID (6 characters) + max_attempts = 10 + new_id = None + for _ in range(max_attempts): + candidate_id = generate_short_agent_id() + if candidate_id not in config.agents.profiles: + new_id = candidate_id + break + + if new_id is None: + raise HTTPException( + status_code=500, + detail="Failed to generate unique agent ID after 10 attempts", + ) + + # Create workspace directory + workspace_dir = Path( + request.workspace_dir or f"~/.copaw/workspaces/{new_id}", + ).expanduser() + workspace_dir.mkdir(parents=True, exist_ok=True) + + # Build complete agent config with generated ID + from ...config.config 
import ( + ChannelConfig, + MCPConfig, + HeartbeatConfig, + ToolsConfig, + ) + + agent_config = AgentProfileConfig( + id=new_id, + name=request.name, + description=request.description, + workspace_dir=str(workspace_dir), + language=request.language, + channels=ChannelConfig(), + mcp=MCPConfig(clients={}), + heartbeat=HeartbeatConfig(), + tools=ToolsConfig(), + ) + + # Initialize workspace with default files + _initialize_agent_workspace(workspace_dir, agent_config) + + # Save agent configuration to workspace/agent.json + agent_ref = AgentProfileRef( + id=new_id, + workspace_dir=str(workspace_dir), + ) + + # Add to root config + config.agents.profiles[new_id] = agent_ref + save_config(config) + + # Save agent config to workspace + save_agent_config(new_id, agent_config) + + logger.info(f"Created new agent: {new_id} (name={request.name})") + + return agent_ref + + +@router.put( + "/{agentId}", + response_model=AgentProfileConfig, + summary="Update agent", + description="Update agent configuration and trigger reload", +) +async def update_agent( + agentId: str = PathParam(...), + agent_config: AgentProfileConfig = Body(...), + request: Request = None, +) -> AgentProfileConfig: + """Update agent configuration.""" + config = load_config() + + if agentId not in config.agents.profiles: + raise HTTPException( + status_code=404, + detail=f"Agent '{agentId}' not found", + ) + + # Ensure ID doesn't change + agent_config.id = agentId + + # Save agent configuration + save_agent_config(agentId, agent_config) + + # Trigger hot reload if agent is running + manager = _get_multi_agent_manager(request) + await manager.reload_agent(agentId) + + return agent_config + + +@router.delete( + "/{agentId}", + summary="Delete agent", + description="Delete agent and workspace (cannot delete default agent)", +) +async def delete_agent( + agentId: str = PathParam(...), + request: Request = None, +) -> dict: + """Delete an agent.""" + config = load_config() + + if agentId not in 
config.agents.profiles: + raise HTTPException( + status_code=404, + detail=f"Agent '{agentId}' not found", + ) + + if agentId == "default": + raise HTTPException( + status_code=400, + detail="Cannot delete the default agent", + ) + + # Stop agent instance if running + manager = _get_multi_agent_manager(request) + await manager.stop_agent(agentId) + + # Remove from config + del config.agents.profiles[agentId] + save_config(config) + + # Note: We don't delete the workspace directory for safety + # Users can manually delete it if needed + + return {"success": True, "agent_id": agentId} + + +@router.get( + "/{agentId}/files", + response_model=list[MdFileInfo], + summary="List agent workspace files", + description="List all markdown files in agent's workspace", +) +async def list_agent_files( + agentId: str = PathParam(...), + request: Request = None, +) -> list[MdFileInfo]: + """List agent workspace files.""" + manager = _get_multi_agent_manager(request) + + try: + workspace = await manager.get_agent(agentId) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + workspace_manager = AgentMdManager(str(workspace.workspace_dir)) + + try: + files = [ + MdFileInfo.model_validate(file) + for file in workspace_manager.list_working_mds() + ] + return files + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/{agentId}/files/{filename}", + response_model=MdFileContent, + summary="Read agent workspace file", + description="Read a markdown file from agent's workspace", +) +async def read_agent_file( + agentId: str = PathParam(...), + filename: str = PathParam(...), + request: Request = None, +) -> MdFileContent: + """Read agent workspace file.""" + manager = _get_multi_agent_manager(request) + + try: + workspace = await manager.get_agent(agentId) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + workspace_manager = 
AgentMdManager(str(workspace.workspace_dir)) + + try: + content = workspace_manager.read_working_md(filename) + return MdFileContent(content=content) + except FileNotFoundError as exc: + raise HTTPException( + status_code=404, + detail=f"File '{filename}' not found", + ) from exc + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.put( + "/{agentId}/files/{filename}", + response_model=dict, + summary="Write agent workspace file", + description="Create or update a markdown file in agent's workspace", +) +async def write_agent_file( + agentId: str = PathParam(...), + filename: str = PathParam(...), + file_content: MdFileContent = Body(...), + request: Request = None, +) -> dict: + """Write agent workspace file.""" + manager = _get_multi_agent_manager(request) + + try: + workspace = await manager.get_agent(agentId) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + workspace_manager = AgentMdManager(str(workspace.workspace_dir)) + + try: + workspace_manager.write_working_md(filename, file_content.content) + return {"written": True, "filename": filename} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/{agentId}/memory", + response_model=list[MdFileInfo], + summary="List agent memory files", + description="List all memory files for an agent", +) +async def list_agent_memory( + agentId: str = PathParam(...), + request: Request = None, +) -> list[MdFileInfo]: + """List agent memory files.""" + manager = _get_multi_agent_manager(request) + + try: + workspace = await manager.get_agent(agentId) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + workspace_manager = AgentMdManager(str(workspace.workspace_dir)) + + try: + files = [ + MdFileInfo.model_validate(file) + for file in workspace_manager.list_memory_mds() + ] + return files + except Exception as e: + raise 
HTTPException(status_code=500, detail=str(e)) from e + + +def _initialize_agent_workspace( # pylint: disable=too-many-branches + workspace_dir: Path, + agent_config: AgentProfileConfig, # pylint: disable=unused-argument +) -> None: + """Initialize agent workspace (similar to copaw init --defaults). + + Args: + workspace_dir: Path to agent workspace + agent_config: Agent configuration (reserved for future use) + """ + import shutil + from ...config import load_config as load_global_config + + # Create essential subdirectories + (workspace_dir / "sessions").mkdir(exist_ok=True) + (workspace_dir / "memory").mkdir(exist_ok=True) + (workspace_dir / "active_skills").mkdir(exist_ok=True) + (workspace_dir / "customized_skills").mkdir(exist_ok=True) + + # Get language from global config + config = load_global_config() + language = config.agents.language or "zh" + + # Copy MD files from agents/md_files/{language}/ to workspace + md_files_dir = ( + Path(__file__).parent.parent.parent / "agents" / "md_files" / language + ) + if md_files_dir.exists(): + for md_file in md_files_dir.glob("*.md"): + target_file = workspace_dir / md_file.name + if not target_file.exists(): + try: + shutil.copy2(md_file, target_file) + except Exception as e: + logger.warning( + f"Failed to copy {md_file.name}: {e}", + ) + + # Create HEARTBEAT.md if not exists + heartbeat_file = workspace_dir / "HEARTBEAT.md" + if not heartbeat_file.exists(): + DEFAULT_HEARTBEAT_MDS = { + "zh": """# Heartbeat checklist +- 扫描收件箱紧急邮件 +- 查看未来 2h 的日历 +- 检查待办是否卡住 +- 若安静超过 8h,轻量 check-in +""", + "en": """# Heartbeat checklist +- Scan inbox for urgent email +- Check calendar for next 2h +- Check tasks for blockers +- Light check-in if quiet for 8h +""", + "ru": """# Heartbeat checklist +- Проверить входящие на срочные письма +- Просмотреть календарь на ближайшие 2 часа +- Проверить задачи на наличие блокировок +- Лёгкая проверка при отсутствии активности более 8 часов +""", + } + heartbeat_content = 
DEFAULT_HEARTBEAT_MDS.get( + language, + DEFAULT_HEARTBEAT_MDS["en"], + ) + with open(heartbeat_file, "w", encoding="utf-8") as f: + f.write(heartbeat_content.strip()) + + # Copy builtin skills to agent's active_skills directory + builtin_skills_dir = ( + Path(__file__).parent.parent.parent / "agents" / "skills" + ) + if builtin_skills_dir.exists(): + for skill_dir in builtin_skills_dir.iterdir(): + if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists(): + target_skill_dir = ( + workspace_dir / "active_skills" / skill_dir.name + ) + if not target_skill_dir.exists(): + try: + shutil.copytree(skill_dir, target_skill_dir) + except Exception as e: + logger.warning( + f"Failed to copy skill {skill_dir.name}: {e}", + ) + + # Create empty jobs.json for cron jobs + jobs_file = workspace_dir / "jobs.json" + if not jobs_file.exists(): + with open(jobs_file, "w", encoding="utf-8") as f: + json.dump( + {"version": 1, "jobs": []}, + f, + ensure_ascii=False, + indent=2, + ) + + # Create empty chats.json for chat history + chats_file = workspace_dir / "chats.json" + if not chats_file.exists(): + with open(chats_file, "w", encoding="utf-8") as f: + json.dump( + {"version": 1, "chats": []}, + f, + ensure_ascii=False, + indent=2, + ) + + # Create empty token_usage.json + token_usage_file = workspace_dir / "token_usage.json" + if not token_usage_file.exists(): + with open(token_usage_file, "w", encoding="utf-8") as f: + f.write("[]") diff --git a/src/copaw/app/routers/config.py b/src/copaw/app/routers/config.py index efb43a612..cf86f460f 100644 --- a/src/copaw/app/routers/config.py +++ b/src/copaw/app/routers/config.py @@ -7,7 +7,6 @@ from ...config import ( load_config, save_config, - get_heartbeat_config, ChannelConfig, ChannelConfigUnion, get_available_channels, @@ -56,15 +55,23 @@ summary="List all channels", description="Retrieve configuration for all available channels", ) -async def list_channels() -> dict: +async def list_channels(request: Request) -> dict: """List all 
channel configs (filtered by available channels).""" - config = load_config() + from ..agent_context import get_agent_for_request + + agent = await get_agent_for_request(request) + agent_config = agent.config available = get_available_channels() - # Get all channel configs from model_dump and __pydantic_extra__ - all_configs = config.channels.model_dump() - extra = getattr(config.channels, "__pydantic_extra__", None) or {} - all_configs.update(extra) + # Get channel configs from agent's config (with fallback to empty) + channels_config = agent_config.channels + if channels_config is None: + # No channels config yet, use empty defaults + all_configs = {} + else: + all_configs = channels_config.model_dump() + extra = getattr(channels_config, "__pydantic_extra__", None) or {} + all_configs.update(extra) # Return all available channels (use default config if not saved) result = {} @@ -102,15 +109,35 @@ async def list_channel_types() -> List[str]: description="Update configuration for all channels at once", ) async def put_channels( + request: Request, channels_config: ChannelConfig = Body( ..., description="Complete channel configuration", ), ) -> ChannelConfig: """Update all channel configs.""" - config = load_config() - config.channels = channels_config - save_config(config) + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config + + agent = await get_agent_for_request(request) + agent.config.channels = channels_config + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + await agent.reload() + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) + return channels_config @@ -121,6 +148,7 @@ async def put_channels( description="Retrieve configuration for a specific channel by name", ) async def 
get_channel( + request: Request, channel_name: str = Path( ..., description="Name of the channel to retrieve", @@ -128,16 +156,26 @@ async def get_channel( ), ) -> ChannelConfigUnion: """Get a specific channel config by name.""" + from ..agent_context import get_agent_for_request + available = get_available_channels() if channel_name not in available: raise HTTPException( status_code=404, detail=f"Channel '{channel_name}' not found", ) - config = load_config() - single_channel_config = getattr(config.channels, channel_name, None) + + agent = await get_agent_for_request(request) + channels = agent.config.channels + if channels is None: + raise HTTPException( + status_code=404, + detail=f"Channel '{channel_name}' not configured", + ) + + single_channel_config = getattr(channels, channel_name, None) if single_channel_config is None: - extra = getattr(config.channels, "__pydantic_extra__", None) or {} + extra = getattr(channels, "__pydantic_extra__", None) or {} single_channel_config = extra.get(channel_name) if single_channel_config is None: raise HTTPException( @@ -154,6 +192,7 @@ async def get_channel( description="Update configuration for a specific channel by name", ) async def put_channel( + request: Request, channel_name: str = Path( ..., description="Name of the channel to update", @@ -165,13 +204,21 @@ async def put_channel( ), ) -> ChannelConfigUnion: """Update a specific channel config by name.""" + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config + available = get_available_channels() if channel_name not in available: raise HTTPException( status_code=404, detail=f"Channel '{channel_name}' not found", ) - config = load_config() + + agent = await get_agent_for_request(request) + + # Initialize channels if not exists + if agent.config.channels is None: + agent.config.channels = ChannelConfig() config_class = _CHANNEL_CONFIG_CLASS_MAP.get(channel_name) if config_class is not None: @@ -180,9 +227,25 @@ async def 
put_channel( # For custom channels, just use the dict channel_config = single_channel_config - # Allow setting extra (plugin) channel config - setattr(config.channels, channel_name, channel_config) - save_config(config) + # Set channel config in agent's config + setattr(agent.config.channels, channel_name, channel_config) + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + await agent.reload() + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) + return channel_config @@ -191,9 +254,16 @@ async def put_channel( summary="Get heartbeat config", description="Return current heartbeat config (interval, target, etc.)", ) -async def get_heartbeat() -> Any: +async def get_heartbeat(request: Request) -> Any: """Return effective heartbeat config (from file or default).""" - hb = get_heartbeat_config() + from ..agent_context import get_agent_for_request + from ...config.config import HeartbeatConfig as HeartbeatConfigModel + + agent = await get_agent_for_request(request) + hb = agent.config.heartbeat + if hb is None: + # Use default if not configured + hb = HeartbeatConfigModel() return hb.model_dump(mode="json", by_alias=True) @@ -207,19 +277,34 @@ async def put_heartbeat( body: HeartbeatBody = Body(..., description="Heartbeat configuration"), ) -> Any: """Update heartbeat config and reschedule the heartbeat job.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config + + agent = await get_agent_for_request(request) hb = HeartbeatConfig( enabled=body.enabled, every=body.every, target=body.target, active_hours=body.active_hours, ) - config.agents.defaults.heartbeat = hb - save_config(config) + agent.config.heartbeat = hb + save_agent_config(agent.agent_id, agent.config) + + # 
Reschedule heartbeat (async, non-blocking) + import asyncio + + async def reschedule_in_background(): + try: + if agent.cron_manager is not None: + await agent.cron_manager.reschedule_heartbeat() + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reschedule failed: {e}", + ) - cron_manager = getattr(request.app.state, "cron_manager", None) - if cron_manager is not None: - await cron_manager.reschedule_heartbeat() + asyncio.create_task(reschedule_in_background()) return hb.model_dump(mode="json", by_alias=True) diff --git a/src/copaw/app/routers/mcp.py b/src/copaw/app/routers/mcp.py index e647519be..05d42af7e 100644 --- a/src/copaw/app/routers/mcp.py +++ b/src/copaw/app/routers/mcp.py @@ -5,10 +5,9 @@ from typing import Dict, List, Optional, Literal -from fastapi import APIRouter, Body, HTTPException, Path +from fastapi import APIRouter, Body, HTTPException, Path, Request from pydantic import BaseModel, Field -from ...config import load_config, save_config from ...config.config import MCPClientConfig router = APIRouter(prefix="/mcp", tags=["mcp"]) @@ -194,12 +193,18 @@ def _build_client_info(key: str, client: MCPClientConfig) -> MCPClientInfo: response_model=List[MCPClientInfo], summary="List all MCP clients", ) -async def list_mcp_clients() -> List[MCPClientInfo]: +async def list_mcp_clients(request: Request) -> List[MCPClientInfo]: """Get list of all configured MCP clients.""" - config = load_config() + from ..agent_context import get_agent_for_request + + agent = await get_agent_for_request(request) + mcp_config = agent.config.mcp + if mcp_config is None or not mcp_config.clients: + return [] + return [ _build_client_info(key, client) - for key, client in config.mcp.clients.items() + for key, client in mcp_config.clients.items() ] @@ -208,10 +213,19 @@ async def list_mcp_clients() -> List[MCPClientInfo]: response_model=MCPClientInfo, summary="Get MCP client details", ) -async def get_mcp_client(client_key: str = 
Path(...)) -> MCPClientInfo: +async def get_mcp_client( + request: Request, + client_key: str = Path(...), +) -> MCPClientInfo: """Get details of a specific MCP client.""" - config = load_config() - client = config.mcp.clients.get(client_key) + from ..agent_context import get_agent_for_request + + agent = await get_agent_for_request(request) + mcp_config = agent.config.mcp + if mcp_config is None: + raise HTTPException(404, detail=f"MCP client '{client_key}' not found") + + client = mcp_config.clients.get(client_key) if client is None: raise HTTPException(404, detail=f"MCP client '{client_key}' not found") return _build_client_info(client_key, client) @@ -224,14 +238,22 @@ async def get_mcp_client(client_key: str = Path(...)) -> MCPClientInfo: status_code=201, ) async def create_mcp_client( + request: Request, client_key: str = Body(..., embed=True), client: MCPClientCreateRequest = Body(..., embed=True), ) -> MCPClientInfo: """Create a new MCP client configuration.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config, MCPConfig + + agent = await get_agent_for_request(request) + + # Initialize mcp config if not exists + if agent.config.mcp is None: + agent.config.mcp = MCPConfig(clients={}) # Check if client already exists - if client_key in config.mcp.clients: + if client_key in agent.config.mcp.clients: raise HTTPException( 400, detail=f"MCP client '{client_key}' already exists. 
Use PUT to " @@ -252,9 +274,24 @@ async def create_mcp_client( cwd=client.cwd, ) - # Add to config and save - config.mcp.clients[client_key] = new_client - save_config(config) + # Add to agent's config and save + agent.config.mcp.clients[client_key] = new_client + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + await agent.reload() + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) return _build_client_info(client_key, new_client) @@ -265,17 +302,21 @@ async def create_mcp_client( summary="Update an MCP client", ) async def update_mcp_client( + request: Request, client_key: str = Path(...), updates: MCPClientUpdateRequest = Body(...), ) -> MCPClientInfo: """Update an existing MCP client configuration.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config + + agent = await get_agent_for_request(request) - # Check if client exists - existing = config.mcp.clients.get(client_key) - if existing is None: + if agent.config.mcp is None or client_key not in agent.config.mcp.clients: raise HTTPException(404, detail=f"MCP client '{client_key}' not found") + existing = agent.config.mcp.clients[client_key] + # Update fields if provided update_data = updates.model_dump(exclude_unset=True) @@ -288,10 +329,25 @@ async def update_mcp_client( merged_data = existing.model_dump(mode="json") merged_data.update(update_data) updated_client = MCPClientConfig.model_validate(merged_data) - config.mcp.clients[client_key] = updated_client + agent.config.mcp.clients[client_key] = updated_client # Save updated config - save_config(config) + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + 
await agent.reload() + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) return _build_client_info(client_key, updated_client) @@ -302,18 +358,38 @@ async def update_mcp_client( summary="Toggle MCP client enabled status", ) async def toggle_mcp_client( + request: Request, client_key: str = Path(...), ) -> MCPClientInfo: """Toggle the enabled status of an MCP client.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config - client = config.mcp.clients.get(client_key) - if client is None: + agent = await get_agent_for_request(request) + + if agent.config.mcp is None or client_key not in agent.config.mcp.clients: raise HTTPException(404, detail=f"MCP client '{client_key}' not found") + client = agent.config.mcp.clients[client_key] + # Toggle enabled status client.enabled = not client.enabled - save_config(config) + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + await agent.reload() + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) return _build_client_info(client_key, client) @@ -324,16 +400,35 @@ async def toggle_mcp_client( summary="Delete an MCP client", ) async def delete_mcp_client( + request: Request, client_key: str = Path(...), ) -> Dict[str, str]: """Delete an MCP client configuration.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config - if client_key not in config.mcp.clients: + agent = await get_agent_for_request(request) + + if agent.config.mcp is None or client_key not in agent.config.mcp.clients: raise HTTPException(404, detail=f"MCP client '{client_key}' not 
found") # Remove client - del config.mcp.clients[client_key] - save_config(config) + del agent.config.mcp.clients[client_key] + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + await agent.reload() + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) return {"message": f"MCP client '{client_key}' deleted successfully"} diff --git a/src/copaw/app/routers/providers.py b/src/copaw/app/routers/providers.py index 2a912e98d..c6af4c1f8 100644 --- a/src/copaw/app/routers/providers.py +++ b/src/copaw/app/routers/providers.py @@ -3,14 +3,21 @@ from __future__ import annotations +import logging from typing import List, Literal, Optional from copy import deepcopy from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request from pydantic import BaseModel, Field +from ..agent_context import get_agent_for_request +from ...config.config import load_agent_config, save_agent_config from ...providers.provider import ProviderInfo, ModelInfo from ...providers.provider_manager import ActiveModelsInfo, ProviderManager +from ...providers.models import ModelSlotConfig + + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/models", tags=["models"]) @@ -357,9 +364,37 @@ async def remove_model_endpoint( summary="Get active LLM", ) async def get_active_models( + request: Request, manager: ProviderManager = Depends(get_provider_manager), ) -> ActiveModelsInfo: - return ActiveModelsInfo(active_llm=manager.get_active_model()) + """Get active model (agent-specific or global fallback).""" + # Try to get agent-specific active model + try: + workspace = await get_agent_for_request(request) + logger.debug( + f"get_active_models: got workspace.agent_id={workspace.agent_id}", + ) + agent_config = load_agent_config(workspace.agent_id) + logger.debug( + 
f"get_active_models: agent_config.active_model=" + f"{agent_config.active_model}", + ) + if agent_config.active_model: + logger.info( + f"Returning agent-specific model for {workspace.agent_id}: " + f"{agent_config.active_model}", + ) + return ActiveModelsInfo(active_llm=agent_config.active_model) + except Exception as e: + logger.warning( + f"Failed to get agent-specific model: {e}", + exc_info=True, + ) + + # Fallback to global active model + global_model = manager.get_active_model() + logger.info(f"Returning global model: {global_model}") + return ActiveModelsInfo(active_llm=global_model) @router.put( @@ -368,17 +403,34 @@ async def get_active_models( summary="Set active LLM", ) async def set_active_model( + request: Request, manager: ProviderManager = Depends(get_provider_manager), body: ModelSlotRequest = Body(...), ) -> ActiveModelsInfo: + """Set active model for current agent.""" + # Validate provider and model exist try: await manager.activate_model(body.provider_id, body.model) except ValueError as exc: message = str(exc) lower_msg = message.lower() if "provider" in lower_msg and "not found" in lower_msg: - # Missing provider raise HTTPException(status_code=404, detail=message) from exc - # Invalid model, unreachable provider, or other configuration error raise HTTPException(status_code=400, detail=message) from exc + + # Save to agent config + try: + workspace = await get_agent_for_request(request) + agent_config = load_agent_config(workspace.agent_id) + agent_config.active_model = ModelSlotConfig( + provider_id=body.provider_id, + model=body.model, + ) + save_agent_config(workspace.agent_id, agent_config) + except Exception as e: + # Log warning but don't fail the request + logger.warning( + f"Failed to save active model to agent config: {e}", + ) + return ActiveModelsInfo(active_llm=manager.get_active_model()) diff --git a/src/copaw/app/routers/skills.py b/src/copaw/app/routers/skills.py index 96f2d312f..a4f27d4b5 100644 --- 
a/src/copaw/app/routers/skills.py +++ b/src/copaw/app/routers/skills.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- import logging from typing import Any -from fastapi import APIRouter, HTTPException +from pathlib import Path +from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field from ...agents.skills_manager import ( SkillService, SkillInfo, - list_available_skills, ) from ...agents.skills_hub import ( search_hub_skills, @@ -60,32 +60,74 @@ class HubInstallRequest(BaseModel): @router.get("") -async def list_skills() -> list[SkillSpec]: - all_skills = SkillService.list_all_skills() - - available_skills = list_available_skills() - skills_spec = [] - for skill in all_skills: - skills_spec.append( - SkillSpec( - **skill.model_dump(), - enabled=skill.name in available_skills, - ), +async def list_skills( + request: Request, +) -> list[SkillSpec]: + """List all skills for active agent.""" + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + # Get all skills (builtin + customized) + all_skills = skill_service.list_all_skills() + + # Get active skills to determine enabled status + active_skills_dir = workspace_dir / "active_skills" + active_skill_names = set() + if active_skills_dir.exists(): + active_skill_names = { + d.name + for d in active_skills_dir.iterdir() + if d.is_dir() and (d / "SKILL.md").exists() + } + + # Convert to SkillSpec with enabled status + skills_spec = [ + SkillSpec( + name=skill.name, + content=skill.content, + source=skill.source, + path=skill.path, + references=skill.references, + scripts=skill.scripts, + enabled=skill.name in active_skill_names, ) + for skill in all_skills + ] + return skills_spec @router.get("/available") -async def get_available_skills() -> list[SkillSpec]: - available_skills = SkillService.list_available_skills() - skills_spec = [] - for skill 
in available_skills: - skills_spec.append( - SkillSpec( - **skill.model_dump(), - enabled=True, - ), +async def get_available_skills( + request: Request, +) -> list[SkillSpec]: + """List available (enabled) skills for active agent.""" + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + # Get available (active) skills + available_skills = skill_service.list_available_skills() + + # Convert to SkillSpec + skills_spec = [ + SkillSpec( + name=skill.name, + content=skill.content, + source=skill.source, + path=skill.path, + references=skill.references, + scripts=skill.scripts, + enabled=True, ) + for skill in available_skills + ] + return skills_spec @@ -118,25 +160,34 @@ def _github_token_hint(bundle_url: str) -> str: @router.post("/hub/install") -async def install_from_hub(request: HubInstallRequest): +async def install_from_hub( + request_body: HubInstallRequest, + request: Request, +): + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + try: result = install_skill_from_hub( - bundle_url=request.bundle_url, - version=request.version, - enable=request.enable, - overwrite=request.overwrite, + workspace_dir=workspace_dir, + bundle_url=request_body.bundle_url, + version=request_body.version, + enable=request_body.enable, + overwrite=request_body.overwrite, ) except ValueError as e: detail = str(e) logger.warning( "Skill hub install 400: bundle_url=%s detail=%s", - (request.bundle_url or "")[:80], + (request_body.bundle_url or "")[:80], detail, ) raise HTTPException(status_code=400, detail=detail) from e except RuntimeError as e: # Upstream hub is flaky/rate-limited sometimes; surface as bad gateway. 
- detail = str(e) + _github_token_hint(request.bundle_url) + detail = str(e) + _github_token_hint(request_body.bundle_url) logger.exception( "Skill hub install failed (upstream/rate limit): %s", e, @@ -144,7 +195,7 @@ async def install_from_hub(request: HubInstallRequest): raise HTTPException(status_code=502, detail=detail) from e except Exception as e: detail = f"Skill hub import failed: {e}" + _github_token_hint( - request.bundle_url, + request_body.bundle_url, ) logger.exception("Skill hub import failed: %s", e) raise HTTPException(status_code=502, detail=detail) from e @@ -157,48 +208,157 @@ async def install_from_hub(request: HubInstallRequest): @router.post("/batch-disable") -async def batch_disable_skills(skill_name: list[str]) -> None: +async def batch_disable_skills( + skill_name: list[str], + request: Request, +) -> None: + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + for skill in skill_name: - SkillService.disable_skill(skill) + skill_service.disable_skill(skill) @router.post("/batch-enable") -async def batch_enable_skills(skill_name: list[str]) -> None: +async def batch_enable_skills( + skill_name: list[str], + request: Request, +) -> None: + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + for skill in skill_name: - SkillService.enable_skill(skill) + skill_service.enable_skill(skill) @router.post("") -async def create_skill(request: CreateSkillRequest): - result = SkillService.create_skill( - name=request.name, - content=request.content, - references=request.references, - scripts=request.scripts, +async def create_skill( + request_body: CreateSkillRequest, + request: Request, +): + from ..agent_context import get_agent_for_request + + workspace = await 
get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + result = skill_service.create_skill( + name=request_body.name, + content=request_body.content, + references=request_body.references, + scripts=request_body.scripts, ) return {"created": result} @router.post("/{skill_name}/disable") -async def disable_skill(skill_name: str): - result = SkillService.disable_skill(skill_name) - return {"disabled": result} +async def disable_skill( + skill_name: str, + request: Request = None, +): + """Disable skill for active agent.""" + from ..agent_context import get_agent_for_request + import shutil + import asyncio + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + active_skill_dir = workspace_dir / "active_skills" / skill_name + + if active_skill_dir.exists(): + shutil.rmtree(active_skill_dir) + + # Hot reload config (async, non-blocking) + async def reload_in_background(): + try: + manager = request.app.state.multi_agent_manager + await manager.reload_agent(workspace.agent_id) + except Exception as e: + logger.warning(f"Background reload failed: {e}") + + asyncio.create_task(reload_in_background()) + + return {"disabled": True} + + return {"disabled": False} @router.post("/{skill_name}/enable") -async def enable_skill(skill_name: str): - result = SkillService.enable_skill(skill_name) - return {"enabled": result} +async def enable_skill( + skill_name: str, + request: Request = None, +): + """Enable skill for active agent.""" + from ..agent_context import get_agent_for_request + import shutil + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + active_skill_dir = workspace_dir / "active_skills" / skill_name + + # If already enabled, skip + if active_skill_dir.exists(): + return {"enabled": True} + + # Find skill from builtin or customized + builtin_skill_dir = ( + Path(__file__).parent.parent.parent / 
"agents" / "skills" / skill_name + ) + customized_skill_dir = workspace_dir / "customized_skills" / skill_name + + source_dir = None + if customized_skill_dir.exists(): + source_dir = customized_skill_dir + elif builtin_skill_dir.exists(): + source_dir = builtin_skill_dir + + if not source_dir or not (source_dir / "SKILL.md").exists(): + raise HTTPException( + status_code=404, + detail=f"Skill '{skill_name}' not found", + ) + + # Copy to active_skills + shutil.copytree(source_dir, active_skill_dir) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + manager = request.app.state.multi_agent_manager + await manager.reload_agent(workspace.agent_id) + except Exception as e: + logger.warning(f"Background reload failed: {e}") + + asyncio.create_task(reload_in_background()) + + return {"enabled": True} @router.delete("/{skill_name}") -async def delete_skill(skill_name: str): +async def delete_skill( + skill_name: str, + request: Request, +): """Delete a skill from customized_skills directory permanently. This only deletes skills from customized_skills directory. Built-in skills cannot be deleted. """ - result = SkillService.delete_skill(skill_name) + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + result = skill_service.delete_skill(skill_name) return {"deleted": result} @@ -207,6 +367,7 @@ async def load_skill_file( skill_name: str, source: str, file_path: str, + request: Request, ): """Load a specific file from a skill's references or scripts directory. 
@@ -226,7 +387,13 @@ async def load_skill_file( GET /skills/builtin_skill/files/builtin/scripts/utils/helper.py """ - content = SkillService.load_skill_file( + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + content = skill_service.load_skill_file( skill_name=skill_name, file_path=file_path, source=source, diff --git a/src/copaw/app/routers/tools.py b/src/copaw/app/routers/tools.py index 1f334b3f5..0f7b6d524 100644 --- a/src/copaw/app/routers/tools.py +++ b/src/copaw/app/routers/tools.py @@ -5,10 +5,15 @@ from typing import List -from fastapi import APIRouter, HTTPException, Path +from fastapi import ( + APIRouter, + HTTPException, + Path, + Request, +) from pydantic import BaseModel, Field -from ...config import load_config, save_config +from ...config import load_config router = APIRouter(prefix="/tools", tags=["tools"]) @@ -22,20 +27,33 @@ class ToolInfo(BaseModel): @router.get("", response_model=List[ToolInfo]) -async def list_tools() -> List[ToolInfo]: - """List all built-in tools and their enabled status. +async def list_tools( + request: Request, +) -> List[ToolInfo]: + """List all built-in tools and enabled status for active agent. 
Returns: List of tool information """ - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import load_agent_config + + workspace = await get_agent_for_request(request) + agent_config = load_agent_config(workspace.agent_id) # Ensure tools config exists with defaults - if not hasattr(config, "tools"): - config.tools = {} + if not agent_config.tools or not agent_config.tools.builtin_tools: + # Fallback to global config if agent config has no tools + config = load_config() + tools_config = config.tools if hasattr(config, "tools") else None + if not tools_config: + return [] + builtin_tools = tools_config.builtin_tools + else: + builtin_tools = agent_config.tools.builtin_tools tools_list = [] - for tool_config in config.tools.builtin_tools.values(): + for tool_config in builtin_tools.values(): tools_list.append( ToolInfo( name=tool_config.name, @@ -48,11 +66,15 @@ async def list_tools() -> List[ToolInfo]: @router.patch("/{tool_name}/toggle", response_model=ToolInfo) -async def toggle_tool(tool_name: str = Path(...)) -> ToolInfo: - """Toggle tool enabled status. +async def toggle_tool( + tool_name: str = Path(...), + request: Request = None, +) -> ToolInfo: + """Toggle tool enabled status for active agent. 
Args: tool_name: Tool function name + request: FastAPI request Returns: Updated tool information @@ -60,21 +82,45 @@ async def toggle_tool(tool_name: str = Path(...)) -> ToolInfo: Raises: HTTPException: If tool not found """ - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import load_agent_config, save_agent_config + + workspace = await get_agent_for_request(request) + agent_config = load_agent_config(workspace.agent_id) - if tool_name not in config.tools.builtin_tools: + if ( + not agent_config.tools + or tool_name not in agent_config.tools.builtin_tools + ): raise HTTPException( status_code=404, detail=f"Tool '{tool_name}' not found", ) # Toggle enabled status - tool_config = config.tools.builtin_tools[tool_name] + tool_config = agent_config.tools.builtin_tools[tool_name] tool_config.enabled = not tool_config.enabled - # Save config - save_config(config) + # Save agent config + save_agent_config(workspace.agent_id, agent_config) + + # Hot reload config (async, non-blocking) + import asyncio + + async def reload_in_background(): + try: + manager = request.app.state.multi_agent_manager + await manager.reload_agent(workspace.agent_id) + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) + # Return immediately (optimistic update) return ToolInfo( name=tool_config.name, enabled=tool_config.enabled, diff --git a/src/copaw/app/routers/workspace.py b/src/copaw/app/routers/workspace.py index e93f9e5b2..94bf7e0a9 100644 --- a/src/copaw/app/routers/workspace.py +++ b/src/copaw/app/routers/workspace.py @@ -11,10 +11,9 @@ from datetime import datetime, timezone from pathlib import Path -from fastapi import APIRouter, HTTPException, UploadFile, File +from fastapi import APIRouter, HTTPException, UploadFile, File, Request from fastapi.responses import StreamingResponse -from ...constant import WORKING_DIR 
router = APIRouter(prefix="/workspace", tags=["workspace"]) @@ -54,7 +53,7 @@ def _zip_directory(root: Path) -> io.BytesIO: # --------------------------------------------------------------------------- -def _validate_zip_data(data: bytes) -> None: +def _validate_zip_data(data: bytes, workspace_dir: Path) -> None: """Ensure *data* is a valid zip without path-traversal entries.""" if not zipfile.is_zipfile(io.BytesIO(data)): raise HTTPException( @@ -63,16 +62,16 @@ def _validate_zip_data(data: bytes) -> None: ) with zipfile.ZipFile(io.BytesIO(data)) as zf: for name in zf.namelist(): - resolved = (WORKING_DIR / name).resolve() - if not str(resolved).startswith(str(WORKING_DIR)): + resolved = (workspace_dir / name).resolve() + if not str(resolved).startswith(str(workspace_dir)): raise HTTPException( status_code=400, detail=f"Zip contains unsafe path: {name}", ) -def _extract_and_merge_zip(data: bytes) -> None: - """Extract zip data and merge into WORKING_DIR (blocking operation).""" +def _extract_and_merge_zip(data: bytes, workspace_dir: Path) -> None: + """Extract zip data and merge into workspace_dir (blocking operation).""" tmp_dir = None try: tmp_dir = Path(tempfile.mkdtemp(prefix="copaw_upload_")) @@ -84,10 +83,10 @@ def _extract_and_merge_zip(data: bytes) -> None: if len(top_entries) == 1 and top_entries[0].is_dir(): extract_root = top_entries[0] - WORKING_DIR.mkdir(parents=True, exist_ok=True) + workspace_dir.mkdir(parents=True, exist_ok=True) for item in extract_root.iterdir(): - dest = WORKING_DIR / item.name + dest = workspace_dir / item.name if item.is_file(): shutil.copy2(item, dest) else: @@ -99,10 +98,10 @@ def _extract_and_merge_zip(data: bytes) -> None: shutil.rmtree(tmp_dir, ignore_errors=True) -def _validate_and_extract_zip(data: bytes) -> None: +def _validate_and_extract_zip(data: bytes, workspace_dir: Path) -> None: """Validate and extract zip data (blocking operation).""" - _validate_zip_data(data) - _extract_and_merge_zip(data) + 
_validate_zip_data(data, workspace_dir) + _extract_and_merge_zip(data, workspace_dir) # --------------------------------------------------------------------------- @@ -114,28 +113,33 @@ def _validate_and_extract_zip(data: bytes) -> None: "/download", summary="Download workspace as zip", description=( - "Package the entire WORKING_DIR into a zip archive and stream it " - "back as a downloadable file." + "Package the entire agent workspace into a zip archive and stream " + "it back as a downloadable file." ), responses={ 200: { "content": {"application/zip": {}}, - "description": "Zip archive of WORKING_DIR", + "description": "Zip archive of agent workspace", }, }, ) -async def download_workspace(): - """Stream WORKING_DIR as a zip file.""" - if not WORKING_DIR.is_dir(): +async def download_workspace(request: Request): + """Stream agent workspace as a zip file.""" + from ..agent_context import get_agent_for_request + + agent = await get_agent_for_request(request) + workspace_dir = agent.workspace_dir + + if not workspace_dir.is_dir(): raise HTTPException( status_code=404, - detail=f"WORKING_DIR does not exist: {WORKING_DIR}", + detail=f"Workspace does not exist: {workspace_dir}", ) - buf = await asyncio.to_thread(_zip_directory, WORKING_DIR) + buf = await asyncio.to_thread(_zip_directory, workspace_dir) timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") - filename = f"copaw_workspace_{timestamp}.zip" + filename = f"copaw_workspace_{agent.agent_id}_{timestamp}.zip" return StreamingResponse( buf, @@ -152,20 +156,24 @@ async def download_workspace(): summary="Upload zip and merge into workspace", description=( "Upload a zip archive. Paths present in the zip are merged into " - "WORKING_DIR (files overwritten, dirs merged). Paths not in the zip " - "are left unchanged (e.g. copaw.db, runtime dirs). Download packs " - "the entire WORKING_DIR; upload only overwrites/merges zip contents." + "agent workspace (files overwritten, dirs merged). 
Paths not in " + "the zip are left unchanged (e.g. copaw.db, runtime dirs). " + "Download packs the entire workspace; upload only " + "overwrites/merges zip contents." ), ) async def upload_workspace( + request: Request, file: UploadFile = File( ..., - description="Zip archive to merge into WORKING_DIR", + description="Zip archive to merge into agent workspace", ), ) -> dict: """ - Merge uploaded zip contents into WORKING_DIR (overwrite, do not clear). + Merge uploaded zip contents into agent workspace (overwrite, not clear). """ + from ..agent_context import get_agent_for_request + if file.content_type and file.content_type not in ( "application/zip", "application/x-zip-compressed", @@ -178,10 +186,12 @@ async def upload_workspace( ), ) + agent = await get_agent_for_request(request) + workspace_dir = agent.workspace_dir data = await file.read() try: - await asyncio.to_thread(_validate_and_extract_zip, data) + await asyncio.to_thread(_validate_and_extract_zip, data, workspace_dir) return {"success": True} except HTTPException: raise diff --git a/src/copaw/app/runner/api.py b/src/copaw/app/runner/api.py index 6862b7849..753cd1750 100644 --- a/src/copaw/app/runner/api.py +++ b/src/copaw/app/runner/api.py @@ -18,46 +18,44 @@ router = APIRouter(prefix="/chats", tags=["chats"]) -def get_chat_manager(request: Request) -> ChatManager: - """Get the chat manager from app state. +async def get_chat_manager( + request: Request, +) -> ChatManager: + """Get the chat manager for the active agent. 
Args: request: FastAPI request object Returns: - ChatManager instance + ChatManager instance for the specified agent Raises: HTTPException: If manager is not initialized """ - mgr = getattr(request.app.state, "chat_manager", None) - if mgr is None: - raise HTTPException( - status_code=503, - detail="Chat manager not initialized", - ) - return mgr + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + return workspace.chat_manager -def get_session(request: Request) -> SafeJSONSession: - """Get the session from app state. +async def get_session( + request: Request, +) -> SafeJSONSession: + """Get the session for the active agent. Args: request: FastAPI request object Returns: - SafeJSONSession instance + SafeJSONSession instance for the specified agent Raises: HTTPException: If session is not initialized """ - runner = getattr(request.app.state, "runner", None) - if runner is None: - raise HTTPException( - status_code=503, - detail="Session not initialized", - ) - return runner.session + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + return workspace.runner.session @router.get("", response_model=list[ChatSpec]) diff --git a/src/copaw/app/runner/command_dispatch.py b/src/copaw/app/runner/command_dispatch.py index 585dde2ef..25f6384a5 100644 --- a/src/copaw/app/runner/command_dispatch.py +++ b/src/copaw/app/runner/command_dispatch.py @@ -9,7 +9,6 @@ from typing import AsyncIterator from agentscope.message import Msg, TextBlock -from reme.memory.file_based.reme_in_memory_memory import ReMeInMemoryMemory from .daemon_commands import ( DaemonContext, @@ -118,6 +117,11 @@ async def run_command_path( return # Conversation path: lightweight memory + CommandHandler + # Lazy import to avoid module-level dependency errors + from reme.memory.file_based.reme_in_memory_memory import ( + ReMeInMemoryMemory, + ) + memory = 
ReMeInMemoryMemory(token_counter=_get_token_counter()) session_state = await runner.session.get_session_state_dict( session_id=session_id, diff --git a/src/copaw/app/runner/manager.py b/src/copaw/app/runner/manager.py index 19b59e408..5415ab5dd 100644 --- a/src/copaw/app/runner/manager.py +++ b/src/copaw/app/runner/manager.py @@ -35,6 +35,9 @@ def __init__( """ self._repo = repo self._lock = asyncio.Lock() + logger.info( + f"ChatManager created with repo path: {repo.path}", + ) # ----- Read Operations ----- @@ -53,6 +56,10 @@ async def list_chats( List of chat specifications """ async with self._lock: + logger.debug( + f"list_chats: repo path={self._repo.path}, " + f"filters: user_id={user_id}, channel={channel}", + ) return await self._repo.filter_chats( user_id=user_id, channel=channel, @@ -92,15 +99,27 @@ async def get_or_create_chat( """ async with self._lock: # Try to find existing by session_id + logger.debug( + f"get_or_create_chat: Searching for existing chat: " + f"session_id={session_id}, user_id={user_id}, " + f"channel={channel}", + ) existing = await self._repo.get_chat_by_id( session_id, user_id, channel, ) if existing: + logger.debug( + f"get_or_create_chat: Found existing chat: {existing.id}", + ) return existing # Create new + logger.debug( + f"get_or_create_chat: Creating new chat for " + f"session_id={session_id}", + ) spec = ChatSpec( session_id=session_id, user_id=user_id, diff --git a/src/copaw/app/runner/repo/base.py b/src/copaw/app/runner/repo/base.py index a0daa0afd..2ade1ffc8 100644 --- a/src/copaw/app/runner/repo/base.py +++ b/src/copaw/app/runner/repo/base.py @@ -60,14 +60,28 @@ async def get_chat_by_id( Returns: ChatSpec or None if not found """ + import logging + + logger = logging.getLogger(__name__) + cf = await self.load() + + logger.debug( + f"get_chat_by_id: Searching in {len(cf.chats)} chats for " + f"session_id={session_id}, user_id={user_id}, " + f"channel={channel}", + ) + for chat in cf.chats: if ( chat.session_id == 
session_id and chat.user_id == user_id and chat.channel == channel ): + logger.debug(f"get_chat_by_id: Found match: {chat.id}") return chat + + logger.debug("get_chat_by_id: No match found") return None async def upsert_chat(self, spec: ChatSpec) -> None: diff --git a/src/copaw/app/runner/runner.py b/src/copaw/app/runner/runner.py index 7db622279..4b5640228 100644 --- a/src/copaw/app/runner/runner.py +++ b/src/copaw/app/runner/runner.py @@ -26,7 +26,7 @@ from ...agents.memory import MemoryManager from ...agents.react_agent import CoPawAgent from ...security.tool_guard.models import TOOL_GUARD_DENIED_MARK -from ...config import load_config +from ...config.config import load_agent_config, AgentsRunningConfig from ...constant import ( TOOL_GUARD_APPROVAL_TIMEOUT_SECONDS, WORKING_DIR, @@ -37,9 +37,17 @@ class AgentRunner(Runner): - def __init__(self) -> None: + def __init__( + self, + agent_id: str = "default", + workspace_dir: Path | None = None, + ) -> None: super().__init__() self.framework_type = "agentscope" + self.agent_id = agent_id # Store agent_id for config loading + self.workspace_dir = ( + workspace_dir # Store workspace_dir for prompt building + ) self._chat_manager = None # Store chat_manager reference self._mcp_manager = None # MCP client manager for hot-reload self.memory_manager: MemoryManager | None = None @@ -149,6 +157,10 @@ async def query_handler( """ Handle agent query. 
""" + logger.debug( + f"AgentRunner.query_handler called: agent_id={self.agent_id}, " + f"msgs={msgs}, request={request}", + ) query = _get_last_user_text(msgs) session_id = getattr(request, "session_id", "") or "" @@ -172,6 +184,16 @@ async def query_handler( yield msg, last return + logger.debug( + f"AgentRunner.stream_query: request={request}, " + f"agent_id={self.agent_id}", + ) + + # Set agent context for model creation + from ..agent_context import set_current_agent_id + + set_current_agent_id(self.agent_id) + agent = None chat = None session_state_loaded = False @@ -199,7 +221,11 @@ async def query_handler( session_id=session_id, user_id=user_id, channel=channel, - working_dir=str(WORKING_DIR), + working_dir=( + str(self.workspace_dir) + if self.workspace_dir + else str(WORKING_DIR) + ), ) # Get MCP clients from manager (hot-reloadable) @@ -207,9 +233,17 @@ async def query_handler( if self._mcp_manager is not None: mcp_clients = await self._mcp_manager.get_clients() - config = load_config() - max_iters = config.agents.running.max_iters - max_input_length = config.agents.running.max_input_length + # Load agent-specific configuration + agent_config = load_agent_config(self.agent_id) + + # Get running config with defaults + running_config = agent_config.running + if running_config is None: + running_config = AgentsRunningConfig() + + max_iters = running_config.max_iters + max_input_length = running_config.max_input_length + language = agent_config.language agent = CoPawAgent( env_context=env_context, @@ -219,9 +253,22 @@ async def query_handler( "session_id": session_id, "user_id": user_id, "channel": channel, + "agent_id": self.agent_id, }, max_iters=max_iters, max_input_length=max_input_length, + memory_compact_threshold=( + running_config.memory_compact_threshold + ), + memory_compact_reserve=running_config.memory_compact_reserve, + enable_tool_result_compact=( + running_config.enable_tool_result_compact + ), + tool_result_compact_keep_n=( + 
running_config.tool_result_compact_keep_n + ), + language=language, + workspace_dir=self.workspace_dir, ) await agent.register_mcp_clients() agent.set_console_output_enabled(enabled=False) @@ -238,13 +285,31 @@ async def query_handler( else: name = "Media Message" + logger.debug( + f"DEBUG chat_manager status: " + f"_chat_manager={self._chat_manager}, " + f"is_none={self._chat_manager is None}, " + f"agent_id={self.agent_id}", + ) + if self._chat_manager is not None: + logger.debug( + f"Runner: Calling get_or_create_chat for " + f"session_id={session_id}, user_id={user_id}, " + f"channel={channel}, name={name}", + ) chat = await self._chat_manager.get_or_create_chat( session_id, user_id, channel, name=name, ) + logger.debug(f"Runner: Got chat: {chat.id}") + else: + logger.warning( + f"ChatManager is None! Cannot auto-register chat for " + f"session_id={session_id}", + ) try: await self.session.load_session_state( @@ -433,15 +498,23 @@ async def init_handler(self, *args, **kwargs): "using existing environment variables", ) - session_dir = str(WORKING_DIR / "sessions") + session_dir = str( + (self.workspace_dir if self.workspace_dir else WORKING_DIR) + / "sessions", + ) self.session = SafeJSONSession(save_dir=session_dir) + # Only create and start MemoryManager if not already set by Workspace try: if self.memory_manager is None: self.memory_manager = MemoryManager( - working_dir=str(WORKING_DIR), + working_dir=( + str(self.workspace_dir) + if self.workspace_dir + else str(WORKING_DIR) + ), ) - await self.memory_manager.start() + await self.memory_manager.start() except Exception as e: logger.exception(f"MemoryManager start failed: {e}") diff --git a/src/copaw/app/workspace.py b/src/copaw/app/workspace.py new file mode 100644 index 000000000..1d2fc9e4d --- /dev/null +++ b/src/copaw/app/workspace.py @@ -0,0 +1,415 @@ +# -*- coding: utf-8 -*- +"""Workspace: Encapsulates a complete independent agent runtime. 
+ +Each Workspace represents a standalone agent workspace with its own: +- Runner (request processing) +- ChannelManager (communication channels) +- MemoryManager (conversation memory) +- MCPClientManager (MCP tool clients) +- CronManager (scheduled tasks) + +All existing single-agent components are reused without modification. +""" +import asyncio +import logging +from pathlib import Path +from typing import Optional, TYPE_CHECKING + +from .runner import AgentRunner +from .channels.utils import make_process_from_runner +from .mcp import MCPClientManager +from .crons.manager import CronManager +from .crons.repo.json_repo import JsonJobRepository +from ..agents.memory import MemoryManager +from ..config.config import load_agent_config, AgentsRunningConfig + +if TYPE_CHECKING: + from .channels.base import BaseChannel + +logger = logging.getLogger(__name__) + + +class Workspace: + """Single agent workspace with complete runtime components. + + Each Workspace is an independent agent instance with its own: + - Runner: Processes agent requests + - ChannelManager: Manages communication channels + - MemoryManager: Manages conversation memory + - MCPClientManager: Manages MCP tool clients + - CronManager: Manages scheduled tasks + + All components use existing single-agent code without modification. + """ + + def __init__(self, agent_id: str, workspace_dir: str): + """Initialize agent instance. 
+ + Args: + agent_id: Unique agent identifier + workspace_dir: Path to agent's workspace directory + """ + self.agent_id = agent_id + self.workspace_dir = Path(workspace_dir).expanduser() + self.workspace_dir.mkdir(parents=True, exist_ok=True) + + # All components are None until start() is called (lazy loading) + self._runner: Optional[AgentRunner] = None + self._channel_manager: Optional["BaseChannel"] = None + self._memory_manager: Optional[MemoryManager] = None + self._mcp_manager: Optional[MCPClientManager] = None + self._cron_manager: Optional["CronManager"] = None + self._chat_manager = None + self._config = None + self._config_watcher = None + self._mcp_config_watcher = None + self._started = False + + logger.debug( + f"Created Workspace: {agent_id} at {self.workspace_dir}", + ) + + @property + def runner(self) -> Optional[AgentRunner]: + """Get runner instance.""" + return self._runner + + @property + def channel_manager(self) -> Optional["BaseChannel"]: + """Get channel manager instance.""" + return self._channel_manager + + @property + def memory_manager(self) -> Optional[MemoryManager]: + """Get memory manager instance.""" + return self._memory_manager + + @property + def mcp_manager(self) -> Optional[MCPClientManager]: + """Get MCP client manager instance.""" + return self._mcp_manager + + @property + def cron_manager(self) -> Optional["CronManager"]: + """Get cron manager instance.""" + return self._cron_manager + + @property + def chat_manager(self): + """Get chat manager instance.""" + return self._chat_manager + + @property + def config(self): + """Get agent configuration.""" + if self._config is None: + self._config = load_agent_config(self.agent_id) + return self._config + + async def start(self): # pylint: disable=too-many-statements + """Start workspace and initialize all components concurrently.""" + if self._started: + logger.debug(f"Workspace already started: {self.agent_id}") + return + + logger.info(f"Starting workspace: {self.agent_id}") + 
+ try: + # 1. Load agent configuration from workspace/agent.json + self._config = load_agent_config(self.agent_id) + agent_config = self._config + logger.debug(f"Loaded config for agent: {self.agent_id}") + + # 2. Create Runner + self._runner = AgentRunner( + agent_id=self.agent_id, + workspace_dir=self.workspace_dir, + ) + + # 3. Concurrently initialize MemoryManager and MCPManager + # IMPORTANT: Create MemoryManager BEFORE runner.start() to prevent + # init_handler from creating a duplicate MemoryManager + async def init_memory(): + # Get running config for memory manager + running_config = agent_config.running + + if running_config is None: + running_config = AgentsRunningConfig() + + self._memory_manager = MemoryManager( + working_dir=str(self.workspace_dir), + max_input_length=running_config.max_input_length, + memory_compact_ratio=running_config.memory_compact_ratio, + memory_reserve_ratio=running_config.memory_reserve_ratio, + language=agent_config.language, + ) + # Assign to runner BEFORE starting runner + self._runner.memory_manager = self._memory_manager + await self._memory_manager.start() + logger.debug( + f"MemoryManager started for agent: {self.agent_id}", + ) + + async def init_mcp(): + self._mcp_manager = MCPClientManager() + if agent_config.mcp: + try: + await self._mcp_manager.init_from_config( + agent_config.mcp, + ) + logger.debug( + f"MCP clients initialized for agent: " + f"{self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Failed to initialize MCP for agent " + f"{self.agent_id}: {e}", + ) + self._runner.set_mcp_manager(self._mcp_manager) + + async def init_chat(): + from .runner.manager import ChatManager + from .runner.repo.json_repo import JsonChatRepository + + chats_path = str(self.workspace_dir / "chats.json") + chat_repo = JsonChatRepository(chats_path) + self._chat_manager = ChatManager(repo=chat_repo) + self._runner.set_chat_manager(self._chat_manager) + logger.info( + f"ChatManager started for agent 
{self.agent_id}: " + f"chats.json={chats_path}", + ) + + # Run Memory, MCP, and Chat initialization concurrently + await asyncio.gather(init_memory(), init_mcp(), init_chat()) + + # Now start the runner (after MemoryManager is set) + await self._runner.start() + logger.debug(f"Runner started for agent: {self.agent_id}") + + # Set up restart callback for /daemon restart command + from .workspace_restart import create_restart_callback + + setattr( + self._runner, + "_restart_callback", + create_restart_callback(self), + ) + + # 4. Start ChannelManager (depends on Runner) + if agent_config.channels: + from ..config import Config, update_last_dispatch + from .channels.manager import ChannelManager + + temp_config = Config(channels=agent_config.channels) + + self._channel_manager = ChannelManager.from_config( + process=make_process_from_runner(self._runner), + config=temp_config, + on_last_dispatch=update_last_dispatch, + workspace_dir=self.workspace_dir, + ) + await self._channel_manager.start_all() + logger.debug( + f"ChannelManager started for agent: {self.agent_id}", + ) + + # 5. Start CronManager (always create for API access) + job_repo = JsonJobRepository( + str(self.workspace_dir / "jobs.json"), + ) + self._cron_manager = CronManager( + repo=job_repo, + runner=self._runner, + channel_manager=self._channel_manager, + timezone="UTC", + ) + # Only start background tasks if heartbeat is enabled + if agent_config.heartbeat and agent_config.heartbeat.enabled: + await self._cron_manager.start() + logger.debug( + f"CronManager started with heartbeat: {self.agent_id}", + ) + else: + logger.debug( + f"CronManager created (heartbeat disabled): " + f"{self.agent_id}", + ) + + # 6. 
Start config watchers for hot-reload (non-blocking) + await self._start_config_watchers() + + self._started = True + logger.info( + f"Workspace started successfully: {self.agent_id}", + ) + + except Exception as e: + logger.error( + f"Failed to start agent instance {self.agent_id}: {e}", + ) + # Clean up partially started components + await self.stop() + raise + + async def stop(self): + """Stop agent instance and clean up all resources.""" + if not self._started: + logger.debug(f"Workspace not started: {self.agent_id}") + return + + logger.info(f"Stopping agent instance: {self.agent_id}") + + # Stop components in reverse order + + # 0. Stop config watchers first + await self._stop_config_watchers() + + # 1. Stop CronManager + if self._cron_manager: + try: + await self._cron_manager.stop() + logger.debug( + f"CronManager stopped for agent: {self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Error stopping CronManager for agent " + f"{self.agent_id}: {e}", + ) + + # 2. Stop ChannelManager + if self._channel_manager: + try: + await self._channel_manager.stop_all() + logger.debug( + f"ChannelManager stopped for agent: {self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Error stopping ChannelManager for agent " + f"{self.agent_id}: {e}", + ) + + # 3. Stop MCPClientManager + if self._mcp_manager: + try: + await self._mcp_manager.close_all() + logger.debug( + f"MCPClientManager stopped for agent: " f"{self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Error stopping MCPClientManager for agent " + f"{self.agent_id}: {e}", + ) + + # 4. Stop MemoryManager + if self._memory_manager: + try: + await self._memory_manager.close() + logger.debug( + f"MemoryManager stopped for agent: " f"{self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Error stopping MemoryManager for agent " + f"{self.agent_id}: {e}", + ) + + # 5. 
Clear ChatManager reference (no stop method) + if self._chat_manager: + self._chat_manager = None + logger.debug( + f"ChatManager cleared for agent: {self.agent_id}", + ) + + # 6. Stop Runner + if self._runner: + try: + await self._runner.stop() + logger.debug(f"Runner stopped for agent: {self.agent_id}") + except Exception as e: + logger.warning( + f"Error stopping Runner for agent {self.agent_id}: {e}", + ) + + self._started = False + logger.info(f"Workspace stopped: {self.agent_id}") + + async def reload(self): + """Reload agent instance (stop and start with fresh configuration).""" + logger.info(f"Reloading agent instance: {self.agent_id}") + self._config = None # Clear cached config + await self.stop() + await self.start() + logger.info(f"Agent instance reloaded: {self.agent_id}") + + async def _start_config_watchers(self): + """Start config watchers for hot-reload of agent.json changes.""" + try: + # Start AgentConfigWatcher for channels and heartbeat + if self._channel_manager or self._cron_manager: + from .agent_config_watcher import AgentConfigWatcher + + self._config_watcher = AgentConfigWatcher( + agent_id=self.agent_id, + workspace_dir=self.workspace_dir, + channel_manager=self._channel_manager, + cron_manager=self._cron_manager, + ) + await self._config_watcher.start() + + # Start MCPConfigWatcher for MCP client hot-reload + if self._mcp_manager: + from .mcp.watcher import MCPConfigWatcher + + def mcp_config_loader(): + """Load MCP config from agent.json.""" + agent_config = load_agent_config(self.agent_id) + return agent_config.mcp + + self._mcp_config_watcher = MCPConfigWatcher( + mcp_manager=self._mcp_manager, + config_loader=mcp_config_loader, + config_path=self.workspace_dir / "agent.json", + ) + await self._mcp_config_watcher.start() + + except Exception as e: + logger.warning( + f"Failed to start config watchers for agent " + f"{self.agent_id}: {e}", + ) + + async def _stop_config_watchers(self): + """Stop config watchers.""" + if 
self._config_watcher: + try: + await self._config_watcher.stop() + except Exception as e: + logger.warning( + f"Error stopping AgentConfigWatcher for agent " + f"{self.agent_id}: {e}", + ) + self._config_watcher = None + + if self._mcp_config_watcher: + try: + await self._mcp_config_watcher.stop() + except Exception as e: + logger.warning( + f"Error stopping MCPConfigWatcher for agent " + f"{self.agent_id}: {e}", + ) + self._mcp_config_watcher = None + + def __repr__(self) -> str: + """String representation of workspace.""" + status = "started" if self._started else "stopped" + return ( + f"Workspace(id={self.agent_id}, " + f"workspace={self.workspace_dir}, " + f"status={status})" + ) diff --git a/src/copaw/app/workspace_restart.py b/src/copaw/app/workspace_restart.py new file mode 100644 index 000000000..5f2f50f2a --- /dev/null +++ b/src/copaw/app/workspace_restart.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +"""Workspace-level restart logic. + +This module provides workspace-scoped restart functionality for the +/daemon restart command. Each workspace can reload its own components +(channels, cron, MCP) without affecting other workspaces. +""" +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .workspace import Workspace + +logger = logging.getLogger(__name__) + + +async def restart_workspace(workspace: "Workspace") -> None: + """Restart a single workspace's components (channels, cron, MCP). + + This function performs an in-process reload of workspace components: + 1. Reloads agent configuration from agent.json + 2. 
Calls workspace.reload() to restart managers + + Args: + workspace: The workspace instance to restart + + Raises: + Exception: If restart fails + """ + logger.info(f"Restarting workspace: {workspace.agent_id}") + + try: + # Reload the workspace (hot reload all managers) + await workspace.reload() + + logger.info( + f"Workspace restart completed: {workspace.agent_id}", + ) + + except Exception as e: + logger.exception( + f"Failed to restart workspace {workspace.agent_id}: {e}", + ) + raise + + +def create_restart_callback(workspace: "Workspace"): + """Create a restart callback for a workspace's runner. + + This creates a closure that captures the workspace instance and + provides it as a callback for the /daemon restart command. + + Args: + workspace: The workspace instance + + Returns: + Async callable that restarts the workspace + """ + + async def _restart_callback() -> None: + """Restart callback for runner.""" + await restart_workspace(workspace) + + return _restart_callback diff --git a/src/copaw/cli/channels_cmd.py b/src/copaw/cli/channels_cmd.py index 7c8a040e1..d682a4ee5 100644 --- a/src/copaw/cli/channels_cmd.py +++ b/src/copaw/cli/channels_cmd.py @@ -3,6 +3,7 @@ from __future__ import annotations from types import SimpleNamespace +from pathlib import Path import click @@ -21,6 +22,8 @@ IMessageChannelConfig, QQConfig, VoiceChannelConfig, + load_agent_config, + save_agent_config, ) from .utils import prompt_confirm, prompt_path, prompt_select from ..config import get_available_channels @@ -30,6 +33,7 @@ get_channel_registry, ) + # Fields that contain secrets — display masked in ``list`` _SECRET_FIELDS = { "bot_token", @@ -816,41 +820,51 @@ def _channel_enabled(ch) -> bool: @channels_group.command("list") -def list_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def list_cmd(agent_id: str) -> None: """Show current channel configuration.""" - config_path = get_config_path() + try: + 
agent_config = load_agent_config(agent_id) + click.echo(f"Channels for agent: {agent_id}\n") - if not config_path.is_file(): - click.echo(f"Config not found: {config_path}") - click.echo("Will load default config.") - click.echo("Run `copaw channels config` to create one.") - cfg = load_config() - else: - cfg = load_config(config_path) - - extra = getattr(cfg.channels, "__pydantic_extra__", None) or {} - for key, name in _get_channel_names().items(): - ch = getattr(cfg.channels, key, None) - if ch is None: - ch = extra.get(key) - if ch is None: - continue - status = ( - click.style("enabled", fg="green") - if _channel_enabled(ch) - else click.style("disabled", fg="red") - ) - click.echo(f"\n{'─' * 40}") - click.echo(f" {name} [{status}]") - click.echo(f"{'─' * 40}") + if not agent_config.channels: + click.echo("No channels configured for this agent.") + return - for field_name, value in _channel_config_fields(ch): - display = ( - _mask(str(value)) if field_name in _SECRET_FIELDS else value + extra = ( + getattr(agent_config.channels, "__pydantic_extra__", None) or {} + ) + for key, name in _get_channel_names().items(): + ch = getattr(agent_config.channels, key, None) + if ch is None: + ch = extra.get(key) + if ch is None: + continue + status = ( + click.style("enabled", fg="green") + if _channel_enabled(ch) + else click.style("disabled", fg="red") ) - click.echo(f" {field_name:20s}: {display}") + click.echo(f"\n{'─' * 40}") + click.echo(f" {name} [{status}]") + click.echo(f"{'─' * 40}") + + for field_name, value in _channel_config_fields(ch): + display = ( + _mask(str(value)) + if field_name in _SECRET_FIELDS + else value + ) + click.echo(f" {field_name:20s}: {display}") - click.echo() + click.echo() + except ValueError as e: + click.echo(f"Error: {e}", err=True) + raise SystemExit(1) from e def _install_channel_to_dir( @@ -872,8 +886,6 @@ def _install_channel_to_dir( dest_dir = CUSTOM_CHANNELS_DIR / key if from_path: - from pathlib import Path - src = 
Path(from_path).resolve() if not src.exists(): click.echo(f"Path not found: {src}", err=True) @@ -1051,17 +1063,31 @@ def remove_cmd(key: str, keep_config: bool) -> None: @channels_group.command("config") -def configure_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def configure_cmd(agent_id: str) -> None: """Interactively configure channels.""" - config_path = get_config_path() - working_dir = config_path.parent - - click.echo(f"Working dir: {working_dir}") - working_dir.mkdir(parents=True, exist_ok=True) - - existing = load_config(config_path) if config_path.is_file() else Config() + try: + agent_config = load_agent_config(agent_id) + click.echo(f"Configuring channels for agent: {agent_id}\n") + + # Create a temporary Config object for the interactive configurator + temp_config = Config() + temp_config.channels = ( + agent_config.channels + if agent_config.channels + else temp_config.channels + ) - configure_channels_interactive(existing) + configure_channels_interactive(temp_config) - save_config(existing, config_path) - click.echo(f"\n✓ Configuration saved to {config_path}") + # Save back to agent config + agent_config.channels = temp_config.channels + save_agent_config(agent_id, agent_config) + click.echo(f"\n✓ Configuration saved for agent {agent_id}") + except ValueError as e: + click.echo(f"Error: {e}", err=True) + raise SystemExit(1) from e diff --git a/src/copaw/cli/chats_cmd.py b/src/copaw/cli/chats_cmd.py index 2b1eea7bd..aa6b835dd 100644 --- a/src/copaw/cli/chats_cmd.py +++ b/src/copaw/cli/chats_cmd.py @@ -55,12 +55,18 @@ def chats_group() -> None: default=None, help="Override API base URL, e.g. 
http://127.0.0.1:8088", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def list_chats( ctx: click.Context, user_id: Optional[str], channel: Optional[str], base_url: Optional[str], + agent_id: str, ) -> None: """List all chats, optionally filtered by user_id or channel. @@ -78,7 +84,8 @@ def list_chats( if channel: params["channel"] = channel with client(base_url) as c: - r = c.get("/chats", params=params) + headers = {"X-Agent-Id": agent_id} + r = c.get("/chats", params=params, headers=headers) r.raise_for_status() print_json(r.json()) @@ -86,11 +93,17 @@ def list_chats( @chats_group.command("get") @click.argument("chat_id") @click.option("--base-url", default=None, help="Override API base URL") +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def get_chat( ctx: click.Context, chat_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """View details of a specific chat (including message history). @@ -103,7 +116,8 @@ def get_chat( """ base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.get(f"/chats/{chat_id}") + headers = {"X-Agent-Id": agent_id} + r = c.get(f"/chats/{chat_id}", headers=headers) if r.status_code == 404: raise click.ClickException(f"chat not found: {chat_id}") r.raise_for_status() @@ -143,6 +157,11 @@ def get_chat( ), ) @click.option("--base-url", default=None, help="Override API base URL") +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def create_chat( ctx: click.Context, @@ -152,6 +171,7 @@ def create_chat( user_id: Optional[str], channel: str, base_url: Optional[str], + agent_id: str, ) -> None: """Create a new chat. 
@@ -189,7 +209,8 @@ def create_chat( "meta": {}, } with client(base_url) as c: - r = c.post("/chats", json=payload) + headers = {"X-Agent-Id": agent_id} + r = c.post("/chats", json=payload, headers=headers) r.raise_for_status() print_json(r.json()) @@ -198,12 +219,18 @@ def create_chat( @click.argument("chat_id") @click.option("--name", required=True, help="New chat name") @click.option("--base-url", default=None, help="Override API base URL") +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def update_chat( ctx: click.Context, chat_id: str, name: str, base_url: Optional[str], + agent_id: str, ) -> None: """Update chat name. @@ -215,10 +242,10 @@ def update_chat( copaw chats update --name "Renamed Chat" """ base_url = _base_url(ctx, base_url) - + headers = {"X-Agent-Id": agent_id} # Fetch existing spec, then patch name with client(base_url) as c: - r = c.get("/chats") + r = c.get("/chats", headers=headers) r.raise_for_status() specs = r.json() @@ -229,7 +256,7 @@ def update_chat( payload["name"] = name with client(base_url) as c: - r = c.put(f"/chats/{chat_id}", json=payload) + r = c.put(f"/chats/{chat_id}", json=payload, headers=headers) if r.status_code == 404: raise click.ClickException(f"chat not found: {chat_id}") r.raise_for_status() @@ -239,11 +266,17 @@ def update_chat( @chats_group.command("delete") @click.argument("chat_id") @click.option("--base-url", default=None, help="Override API base URL") +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def delete_chat( ctx: click.Context, chat_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Delete a specific chat. 
@@ -258,7 +291,8 @@ def delete_chat( """ base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.delete(f"/chats/{chat_id}") + headers = {"X-Agent-Id": agent_id} + r = c.delete(f"/chats/{chat_id}", headers=headers) if r.status_code == 404: raise click.ClickException(f"chat not found: {chat_id}") r.raise_for_status() diff --git a/src/copaw/cli/cron_cmd.py b/src/copaw/cli/cron_cmd.py index 8b6ea22af..6dd230ec5 100644 --- a/src/copaw/cli/cron_cmd.py +++ b/src/copaw/cli/cron_cmd.py @@ -42,12 +42,22 @@ def cron_group() -> None: "If omitted, uses global --host and --port from config." ), ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context -def list_jobs(ctx: click.Context, base_url: Optional[str]) -> None: +def list_jobs( + ctx: click.Context, + base_url: Optional[str], + agent_id: str, +) -> None: """List all cron jobs. Output is JSON from GET /cron/jobs.""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.get("/cron/jobs") + headers = {"X-Agent-Id": agent_id} + r = c.get("/cron/jobs", headers=headers) r.raise_for_status() print_json(r.json()) @@ -59,12 +69,23 @@ def list_jobs(ctx: click.Context, base_url: Optional[str]) -> None: default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context -def get_job(ctx: click.Context, job_id: str, base_url: Optional[str]) -> None: +def get_job( + ctx: click.Context, + job_id: str, + base_url: Optional[str], + agent_id: str, +) -> None: """Fetch a cron job by ID. 
Returns JSON from GET /cron/jobs/.""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.get(f"/cron/jobs/{job_id}") + headers = {"X-Agent-Id": agent_id} + r = c.get(f"/cron/jobs/{job_id}", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -78,16 +99,23 @@ def get_job(ctx: click.Context, job_id: str, base_url: Optional[str]) -> None: default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def job_state( ctx: click.Context, job_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Get the runtime state of a cron job (e.g. next run time, paused).""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.get(f"/cron/jobs/{job_id}/state") + headers = {"X-Agent-Id": agent_id} + r = c.get(f"/cron/jobs/{job_id}/state", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -262,6 +290,11 @@ def _build_spec_from_cli( default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def create_job( ctx: click.Context, @@ -277,6 +310,7 @@ def create_job( enabled: bool, mode: str, base_url: Optional[str], + agent_id: str, ) -> None: """Create a cron job. @@ -317,7 +351,8 @@ def create_job( mode=mode, ) with client(base_url) as c: - r = c.post("/cron/jobs", json=payload) + headers = {"X-Agent-Id": agent_id} + r = c.post("/cron/jobs", json=payload, headers=headers) r.raise_for_status() print_json(r.json()) @@ -329,16 +364,23 @@ def create_job( default=None, help="Override the API base URL. 
Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def delete_job( ctx: click.Context, job_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Permanently delete a cron job. The job is removed from the server.""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.delete(f"/cron/jobs/{job_id}") + headers = {"X-Agent-Id": agent_id} + r = c.delete(f"/cron/jobs/{job_id}", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -352,18 +394,25 @@ def delete_job( default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def pause_job( ctx: click.Context, job_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Pause a cron job so it no longer runs on schedule. Use 'resume' to re-enable. """ base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.post(f"/cron/jobs/{job_id}/pause") + headers = {"X-Agent-Id": agent_id} + r = c.post(f"/cron/jobs/{job_id}/pause", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -377,16 +426,23 @@ def pause_job( default=None, help="Override the API base URL. 
Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def resume_job( ctx: click.Context, job_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Resume a paused cron job so it runs again on its schedule.""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.post(f"/cron/jobs/{job_id}/resume") + headers = {"X-Agent-Id": agent_id} + r = c.post(f"/cron/jobs/{job_id}/resume", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -400,12 +456,23 @@ def resume_job( default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context -def run_job(ctx: click.Context, job_id: str, base_url: Optional[str]) -> None: +def run_job( + ctx: click.Context, + job_id: str, + base_url: Optional[str], + agent_id: str, +) -> None: """Trigger a one-off run of a cron job immediately (ignores schedule).""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.post(f"/cron/jobs/{job_id}/run") + headers = {"X-Agent-Id": agent_id} + r = c.post(f"/cron/jobs/{job_id}/run", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() diff --git a/src/copaw/cli/daemon_cmd.py b/src/copaw/cli/daemon_cmd.py index 446b3e4cd..aa621f283 100644 --- a/src/copaw/cli/daemon_cmd.py +++ b/src/copaw/cli/daemon_cmd.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +from pathlib import Path import click @@ -18,11 +19,26 @@ run_daemon_version, ) from ..constant import WORKING_DIR +from ..config import load_config -def _context() -> DaemonContext: +def _get_agent_workspace(agent_id: str) -> Path: + """Get agent workspace directory.""" + try: + config = load_config() + if agent_id in config.agents.profiles: 
+ ref = config.agents.profiles[agent_id] + workspace_dir = Path(ref.workspace_dir).expanduser() + return workspace_dir + except Exception: + pass + return WORKING_DIR + + +def _context(agent_id: str) -> DaemonContext: + working_dir = _get_agent_workspace(agent_id) return DaemonContext( - working_dir=WORKING_DIR, + working_dir=working_dir, memory_manager=None, restart_callback=None, ) @@ -34,30 +50,54 @@ def daemon_group() -> None: @daemon_group.command("status") -def status_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def status_cmd(agent_id: str) -> None: """Show daemon status (config, working dir, memory manager).""" - ctx = _context() + ctx = _context(agent_id) + click.echo(f"Agent: {agent_id}\n") click.echo(run_daemon_status(ctx)) @daemon_group.command("restart") -def restart_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def restart_cmd(agent_id: str) -> None: """Print restart instructions (CLI has no process to restart).""" - ctx = _context() + ctx = _context(agent_id) + click.echo(f"Agent: {agent_id}\n") click.echo(asyncio.run(run_daemon_restart(ctx))) @daemon_group.command("reload-config") -def reload_config_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def reload_config_cmd(agent_id: str) -> None: """Reload config (re-read from file).""" - ctx = _context() + ctx = _context(agent_id) + click.echo(f"Agent: {agent_id}\n") click.echo(run_daemon_reload_config(ctx)) @daemon_group.command("version") -def version_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def version_cmd(agent_id: str) -> None: """Show version and paths.""" - ctx = _context() + ctx = _context(agent_id) + click.echo(f"Agent: {agent_id}\n") click.echo(run_daemon_version(ctx)) @@ -69,8 +109,14 @@ def version_cmd() -> None: 
type=int, help="Number of last lines to show (default 100).", ) -def logs_cmd(lines: int) -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def logs_cmd(lines: int, agent_id: str) -> None: """Tail last N lines of WORKING_DIR/copaw.log.""" - ctx = _context() + ctx = _context(agent_id) lines = min(max(1, lines), 2000) + click.echo(f"Agent: {agent_id}\n") click.echo(run_daemon_logs(ctx, lines=lines)) diff --git a/src/copaw/cli/init_cmd.py b/src/copaw/cli/init_cmd.py index ace9dca3a..90fc512f3 100644 --- a/src/copaw/cli/init_cmd.py +++ b/src/copaw/cli/init_cmd.py @@ -140,6 +140,9 @@ def init_cmd( accept_security: bool, ) -> None: """Create working dir with config.json and HEARTBEAT.md (interactive).""" + from pathlib import Path + from ..app.migration import ensure_default_agent_exists + config_path = get_config_path() working_dir = config_path.parent heartbeat_path = get_heartbeat_query_path() @@ -184,6 +187,14 @@ def init_cmd( else: mark_telemetry_collected(working_dir) + # --- Ensure default agent workspace exists --- + click.echo("\n=== Default Workspace Initialization ===") + ensure_default_agent_exists() + click.echo("✓ Default workspace initialized") + + # Get default workspace path for subsequent operations + default_workspace = Path("~/.copaw/workspaces/default").expanduser() + # --- config.json --- write_config = True if config_path.is_file() and not force and not use_defaults: @@ -242,6 +253,11 @@ def init_cmd( existing = ( load_config(config_path) if config_path.is_file() else Config() ) + # Ensure agents.defaults exists + if existing.agents.defaults is None: + from ..config.config import AgentsDefaultsConfig + + existing.agents.defaults = AgentsDefaultsConfig() existing.agents.defaults.heartbeat = hb # --- show_tool_details --- @@ -306,6 +322,7 @@ def init_cmd( click.echo("Enabling all skills by default (skip existing)...") synced, skipped = sync_skills_to_working_dir( + 
workspace_dir=default_workspace, skill_names=None, force=False, ) @@ -328,6 +345,7 @@ def init_cmd( click.echo("Enabling all skills...") synced, skipped = sync_skills_to_working_dir( + workspace_dir=default_workspace, skill_names=None, force=False, ) @@ -351,14 +369,20 @@ def init_cmd( from ..agents.utils import copy_md_files config = load_config(config_path) if config_path.is_file() else Config() - current_language = config.agents.language + current_language = ( + config.agents.language or "zh" + ) # Default to "zh" if None installed_language = config.agents.installed_md_files_language if use_defaults: # --defaults: always attempt copy, skip files that already exist - # in WORKING_DIR (handles freshly mounted empty volumes). + # in default workspace (handles freshly mounted empty volumes). click.echo(f"\nChecking MD files [language: {current_language}]...") - copied = copy_md_files(current_language, skip_existing=True) + copied = copy_md_files( + current_language, + skip_existing=True, + workspace_dir=default_workspace, + ) if copied: config.agents.installed_md_files_language = current_language save_config(config, config_path) @@ -373,7 +397,10 @@ def init_cmd( click.echo( f"Language changed: {installed_language} → {current_language}", ) - copied = copy_md_files(current_language) + copied = copy_md_files( + current_language, + workspace_dir=default_workspace, + ) if copied: config.agents.installed_md_files_language = current_language save_config(config, config_path) diff --git a/src/copaw/cli/skills_cmd.py b/src/copaw/cli/skills_cmd.py index 4c4e15eb1..5a6221302 100644 --- a/src/copaw/cli/skills_cmd.py +++ b/src/copaw/cli/skills_cmd.py @@ -2,21 +2,47 @@ """CLI skill: list and interactively enable/disable skills.""" from __future__ import annotations +from pathlib import Path + import click from ..agents.skills_manager import SkillService, list_available_skills +from ..constant import WORKING_DIR +from ..config import load_config from .utils import prompt_checkbox, 
prompt_confirm +def _get_agent_workspace(agent_id: str) -> Path: + """Get agent workspace directory.""" + try: + config = load_config() + if agent_id in config.agents.profiles: + ref = config.agents.profiles[agent_id] + workspace_dir = Path(ref.workspace_dir).expanduser() + return workspace_dir + except Exception: + pass + return WORKING_DIR + + # pylint: disable=too-many-branches -def configure_skills_interactive() -> None: +def configure_skills_interactive( + agent_id: str = "default", + working_dir: Path | None = None, +) -> None: """Interactively select which skills to enable (multi-select).""" - all_skills = SkillService.list_all_skills() + if working_dir is None: + working_dir = _get_agent_workspace(agent_id) + + click.echo(f"Configuring skills for agent: {agent_id}\n") + + skill_service = SkillService(working_dir) + all_skills = skill_service.list_all_skills() if not all_skills: click.echo("No skills found. Nothing to configure.") return - available = set(list_available_skills()) + available = set(list_available_skills(working_dir)) all_names = {s.name for s in all_skills} # Default to all skills if nothing is currently active (first time) @@ -78,7 +104,7 @@ def configure_skills_interactive() -> None: # Apply changes for name in to_enable: - result = SkillService.enable_skill(name) + result = skill_service.enable_skill(name) if result: click.echo(f" ✓ Enabled: {name}") else: @@ -87,7 +113,7 @@ def configure_skills_interactive() -> None: ) for name in to_disable: - result = SkillService.disable_skill(name) + result = skill_service.disable_skill(name) if result: click.echo(f" ✓ Disabled: {name}") else: @@ -104,10 +130,20 @@ def skills_group() -> None: @skills_group.command("list") -def list_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def list_cmd(agent_id: str) -> None: """Show all skills and their enabled/disabled status.""" - all_skills = SkillService.list_all_skills() - available = 
set(list_available_skills()) + working_dir = _get_agent_workspace(agent_id) + + click.echo(f"Skills for agent: {agent_id}\n") + + skill_service = SkillService(working_dir) + all_skills = skill_service.list_all_skills() + available = set(list_available_skills(working_dir)) if not all_skills: click.echo("No skills found.") @@ -135,5 +171,11 @@ def list_cmd() -> None: @skills_group.command("config") -def configure_cmd() -> None: - configure_skills_interactive() +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def configure_cmd(agent_id: str) -> None: + """Interactively configure skills.""" + configure_skills_interactive(agent_id=agent_id) diff --git a/src/copaw/config/__init__.py b/src/copaw/config/__init__.py index d8a8add6b..0dad53827 100644 --- a/src/copaw/config/__init__.py +++ b/src/copaw/config/__init__.py @@ -22,15 +22,15 @@ update_last_dispatch, ) -# ConfigWatcher is provided by __getattr__ (lazy-loaded). -# pylint: disable=undefined-all-variable __all__ = [ "AgentsRunningConfig", "Config", "ChannelConfig", "ChannelConfigUnion", "HeartbeatConfig", - "ConfigWatcher", + "SecurityConfig", + "ToolGuardConfig", + "ToolGuardRuleConfig", "get_available_channels", "get_config_path", "get_heartbeat_config", @@ -42,12 +42,3 @@ "save_config", "update_last_dispatch", ] - - -def __getattr__(name: str): - """Lazy-load ConfigWatcher to avoid pulling app.channels/lark_oapi.""" - if name == "ConfigWatcher": - from .watcher import ConfigWatcher - - return ConfigWatcher - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/copaw/config/config.py b/src/copaw/config/config.py index 0d62c6e26..a98a1728e 100644 --- a/src/copaw/config/config.py +++ b/src/copaw/config/config.py @@ -1,7 +1,10 @@ # -*- coding: utf-8 -*- import os +import json +from pathlib import Path from typing import Optional, Union, Dict, List, Literal from pydantic import BaseModel, Field, ConfigDict, model_validator +import 
shortuuid from ..providers.models import ModelSlotConfig from ..constant import ( @@ -11,6 +14,15 @@ from .timezone import detect_system_timezone +def generate_short_agent_id() -> str: + """Generate a 6-character short UUID for agent identification. + + Returns: + 6-character short UUID string + """ + return shortuuid.ShortUUID().random(length=6) + + class BaseChannelConfig(BaseModel): """Base for channel config (read from config.json, no env).""" @@ -28,7 +40,7 @@ class BaseChannelConfig(BaseModel): class IMessageChannelConfig(BaseChannelConfig): db_path: str = "~/Library/Messages/chat.db" poll_sec: float = 1.0 - media_dir: str = "~/.copaw/media" + media_dir: Optional[str] = None max_decoded_size: int = ( 10 * 1024 * 1024 ) # 10MB default limit for Base64 data @@ -43,7 +55,7 @@ class DiscordConfig(BaseChannelConfig): class DingTalkConfig(BaseChannelConfig): client_id: str = "" client_secret: str = "" - media_dir: str = "~/.copaw/media" + media_dir: Optional[str] = None class FeishuConfig(BaseChannelConfig): @@ -55,7 +67,7 @@ class FeishuConfig(BaseChannelConfig): app_secret: str = "" encrypt_key: str = "" verification_token: str = "" - media_dir: str = "~/.copaw/media" + media_dir: Optional[str] = None class QQConfig(BaseChannelConfig): @@ -92,7 +104,7 @@ class MattermostConfig(BaseChannelConfig): url: str = "" bot_token: str = "" - media_dir: str = "~/.copaw/media/mattermost" + media_dir: Optional[str] = None show_typing: Optional[bool] = None thread_follow_without_mention: bool = False @@ -108,7 +120,7 @@ class WecomConfig(BaseChannelConfig): bot_id: str = "" secret: str = "" - media_dir: str = "~/.copaw/media" + media_dir: Optional[str] = None welcome_text: str = "" max_reconnect_attempts: int = -1 @@ -277,28 +289,107 @@ class AgentsLLMRoutingConfig(BaseModel): ) -class AgentsConfig(BaseModel): - defaults: AgentsDefaultsConfig = Field( - default_factory=AgentsDefaultsConfig, +class AgentProfileRef(BaseModel): + """Agent Profile reference (stored in root 
config.json). + + Only contains ID and workspace directory reference. + Full agent configuration is stored in workspace/agent.json. + """ + + id: str = Field(..., description="Unique agent ID") + workspace_dir: str = Field( + ..., + description="Path to agent's workspace directory", + ) + + +class AgentProfileConfig(BaseModel): + """Complete Agent Profile configuration (stored in workspace/agent.json). + + Each agent has its own configuration file with all settings. + """ + + id: str = Field(..., description="Unique agent ID") + name: str = Field(..., description="Human-readable agent name") + description: str = Field(default="", description="Agent description") + workspace_dir: str = Field( + default="", + description="Path to agent's workspace (optional, for reference)", + ) + + # Agent-specific configurations + channels: Optional["ChannelConfig"] = Field( + default=None, + description="Channel configurations for this agent", + ) + mcp: Optional["MCPConfig"] = Field( + default=None, + description="MCP clients for this agent", + ) + heartbeat: Optional[HeartbeatConfig] = Field( + default=None, + description="Heartbeat configuration for this agent", ) running: AgentsRunningConfig = Field( default_factory=AgentsRunningConfig, + description="Runtime configuration", ) llm_routing: AgentsLLMRoutingConfig = Field( default_factory=AgentsLLMRoutingConfig, - description="LLM routing settings (local/cloud).", + description="LLM routing settings", + ) + active_model: Optional["ModelSlotConfig"] = Field( + default=None, + description="Active model for this agent (provider_id + model)", ) language: str = Field( default="zh", - description="Language for agent MD files (zh/en/ru)", + description="Language setting for this agent", + ) + system_prompt_files: List[str] = Field( + default_factory=lambda: ["AGENTS.md", "SOUL.md", "PROFILE.md"], + description="System prompt markdown files", + ) + tools: Optional["ToolsConfig"] = Field( + default=None, + description="Tools 
configuration for this agent", ) - installed_md_files_language: Optional[str] = Field( + security: Optional["SecurityConfig"] = Field( default=None, - description="Language of currently installed md files", + description="Security configuration for this agent", + ) + + +class AgentsConfig(BaseModel): + """Agents configuration (root config.json only contains references).""" + + active_agent: str = Field( + default="default", + description="Currently active agent ID", + ) + profiles: Dict[str, AgentProfileRef] = Field( + default_factory=lambda: { + "default": AgentProfileRef( + id="default", + workspace_dir="~/.copaw/workspaces/default", + ), + }, + description="Agent profile references (ID and workspace path only)", + ) + + # Legacy fields for backward compatibility (deprecated) + # These fields MUST have default values (not None) to support downgrade + defaults: Optional[AgentsDefaultsConfig] = None + running: AgentsRunningConfig = Field( + default_factory=AgentsRunningConfig, + ) + llm_routing: AgentsLLMRoutingConfig = Field( + default_factory=AgentsLLMRoutingConfig, ) + language: str = Field(default="zh") + installed_md_files_language: Optional[str] = None system_prompt_files: List[str] = Field( default_factory=lambda: ["AGENTS.md", "SOUL.md", "PROFILE.md"], - description="List of markdown files to load into system prompt", ) @@ -561,3 +652,253 @@ class Config(BaseModel): WecomConfig, XiaoYiConfig, ] + + +# Agent configuration utility functions + + +def load_agent_config(agent_id: str) -> AgentProfileConfig: + """Load agent's complete configuration from workspace/agent.json. 
+ + Args: + agent_id: Agent ID to load + + Returns: + AgentProfileConfig: Complete agent configuration + + Raises: + ValueError: If agent ID not found in root config + """ + from .utils import load_config + + config = load_config() + + if agent_id not in config.agents.profiles: + raise ValueError(f"Agent '{agent_id}' not found in config") + + agent_ref = config.agents.profiles[agent_id] + workspace_dir = Path(agent_ref.workspace_dir).expanduser() + agent_config_path = workspace_dir / "agent.json" + + if not agent_config_path.exists(): + # Fallback: Try to use root config fields for backward compatibility + # This allows downgrade scenarios where agent.json doesn't exist yet + fallback_config = AgentProfileConfig( + id=agent_id, + name=agent_id.title(), + description=f"{agent_id} agent", + workspace_dir=str(workspace_dir), + # Inherit from root config if available (for backward compat) + channels=( + config.channels + if hasattr(config, "channels") and config.channels + else None + ), + mcp=config.mcp if hasattr(config, "mcp") and config.mcp else None, + tools=( + config.tools + if hasattr(config, "tools") and config.tools + else None + ), + security=( + config.security + if hasattr(config, "security") and config.security + else None + ), + # Use agent-specific configs with proper defaults + running=( + config.agents.running + if hasattr(config.agents, "running") and config.agents.running + else AgentsRunningConfig() + ), + llm_routing=( + config.agents.llm_routing + if hasattr(config.agents, "llm_routing") + and config.agents.llm_routing + else AgentsLLMRoutingConfig() + ), + system_prompt_files=( + config.agents.system_prompt_files + if hasattr(config.agents, "system_prompt_files") + and config.agents.system_prompt_files + else ["AGENTS.md", "SOUL.md", "PROFILE.md"] + ), + ) + # Save for future use + save_agent_config(agent_id, fallback_config) + return fallback_config + + with open(agent_config_path, "r", encoding="utf-8") as f: + data = json.load(f) + + return 
AgentProfileConfig(**data) + + +def save_agent_config( + agent_id: str, + agent_config: AgentProfileConfig, +) -> None: + """Save agent configuration to workspace/agent.json. + + Args: + agent_id: Agent ID + agent_config: Complete agent configuration to save + + Raises: + ValueError: If agent ID not found in root config + """ + from .utils import load_config + + config = load_config() + + if agent_id not in config.agents.profiles: + raise ValueError(f"Agent '{agent_id}' not found in config") + + agent_ref = config.agents.profiles[agent_id] + workspace_dir = Path(agent_ref.workspace_dir).expanduser() + workspace_dir.mkdir(parents=True, exist_ok=True) + + agent_config_path = workspace_dir / "agent.json" + + with open(agent_config_path, "w", encoding="utf-8") as f: + json.dump( + agent_config.model_dump(exclude_none=True), + f, + ensure_ascii=False, + indent=2, + ) + + +def migrate_legacy_config_to_multi_agent() -> bool: + """Migrate legacy single-agent config to new multi-agent structure. 
+ + Returns: + bool: True if migration was performed, False if already migrated + """ + from .utils import load_config, save_config + + config = load_config() + + # Check if already migrated (new structure has only AgentProfileRef) + if "default" in config.agents.profiles: + agent_ref = config.agents.profiles["default"] + # If it's already a AgentProfileRef, migration done + if isinstance(agent_ref, AgentProfileRef): + # Check if default agent config exists + workspace_dir = Path(agent_ref.workspace_dir).expanduser() + agent_config_path = workspace_dir / "agent.json" + if agent_config_path.exists(): + return False # Already migrated + + # Perform migration + print("Migrating legacy config to multi-agent structure...") + + # Extract legacy agent configuration + legacy_agents = config.agents + + # Create default agent workspace + default_workspace = Path("~/.copaw/workspaces/default").expanduser() + default_workspace.mkdir(parents=True, exist_ok=True) + + # Create default agent configuration from legacy settings + default_agent_config = AgentProfileConfig( + id="default", + name="Default Agent", + description="Default CoPaw agent", + workspace_dir=str(default_workspace), + channels=config.channels if config.channels else None, + mcp=config.mcp if config.mcp else None, + heartbeat=( + legacy_agents.defaults.heartbeat + if legacy_agents.defaults + else None + ), + running=( + legacy_agents.running + if legacy_agents.running + else AgentsRunningConfig() + ), + llm_routing=( + legacy_agents.llm_routing + if legacy_agents.llm_routing + else AgentsLLMRoutingConfig() + ), + system_prompt_files=( + legacy_agents.system_prompt_files + if legacy_agents.system_prompt_files + else ["AGENTS.md", "SOUL.md", "PROFILE.md"] + ), + tools=config.tools if config.tools else None, + security=config.security if config.security else None, + ) + + # Save default agent configuration to workspace + agent_config_path = default_workspace / "agent.json" + with open(agent_config_path, "w", 
encoding="utf-8") as f: + json.dump( + default_agent_config.model_dump(exclude_none=True), + f, + ensure_ascii=False, + indent=2, + ) + + # Migrate existing workspace files to default agent workspace + old_workspace = Path("~/.copaw").expanduser() + + # Move sessions, memory, and other workspace files + for item_name in ["sessions", "memory", "jobs.json"]: + old_path = old_workspace / item_name + if old_path.exists(): + new_path = default_workspace / item_name + if not new_path.exists(): + import shutil + + if old_path.is_dir(): + shutil.copytree(old_path, new_path) + else: + shutil.copy2(old_path, new_path) + print(f" Migrated {item_name} to default workspace") + + # Copy markdown files (AGENTS.md, SOUL.md, PROFILE.md) + for md_file in ["AGENTS.md", "SOUL.md", "PROFILE.md"]: + old_md = old_workspace / md_file + if old_md.exists(): + new_md = default_workspace / md_file + if not new_md.exists(): + import shutil + + shutil.copy2(old_md, new_md) + print(f" Migrated {md_file} to default workspace") + + # Update root config.json to new structure + # CRITICAL: Preserve legacy agent fields for downgrade compatibility + config.agents = AgentsConfig( + active_agent="default", + profiles={ + "default": AgentProfileRef( + id="default", + workspace_dir=str(default_workspace), + ), + }, + # Preserve legacy fields with values from migrated agent config + running=default_agent_config.running, + llm_routing=default_agent_config.llm_routing, + language=( + default_agent_config.language + if hasattr(default_agent_config, "language") + else "zh" + ), + system_prompt_files=default_agent_config.system_prompt_files, + ) + + # IMPORTANT: Keep channels, mcp, tools, security in root config for + # backward compatibility. Do NOT clear these fields. + # Old versions expect these fields to exist with valid values. 
+ + save_config(config) + + print("Migration completed successfully!") + print(f" Default agent workspace: {default_workspace}") + print(f" Default agent config: {agent_config_path}") + + return True diff --git a/src/copaw/config/context.py b/src/copaw/config/context.py new file mode 100644 index 000000000..e10368127 --- /dev/null +++ b/src/copaw/config/context.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +"""Context variable for agent workspace directory. + +This module provides a context variable to pass the agent's workspace +directory to tool functions, allowing them to resolve relative paths +correctly in a multi-agent environment. +""" +from contextvars import ContextVar +from pathlib import Path + +# Context variable to store the current agent's workspace directory +current_workspace_dir: ContextVar[Path | None] = ContextVar( + "current_workspace_dir", + default=None, +) + + +def get_current_workspace_dir() -> Path | None: + """Get the current agent's workspace directory from context. + + Returns: + Path to the current agent's workspace directory, or None if not set. + """ + return current_workspace_dir.get() + + +def set_current_workspace_dir(workspace_dir: Path | None) -> None: + """Set the current agent's workspace directory in context. + + Args: + workspace_dir: Path to the agent's workspace directory. 
+ """ + current_workspace_dir.set(workspace_dir) diff --git a/src/copaw/constant.py b/src/copaw/constant.py index f0c566cbe..61524862f 100644 --- a/src/copaw/constant.py +++ b/src/copaw/constant.py @@ -79,6 +79,9 @@ def get_str(env_var: str, default: str = "") -> str: .resolve() ) +# Default media directory for channels (cross-platform) +DEFAULT_MEDIA_DIR = WORKING_DIR / "media" + JOBS_FILE = EnvVarLoader.get_str("COPAW_JOBS_FILE", "jobs.json") CHATS_FILE = EnvVarLoader.get_str("COPAW_CHATS_FILE", "chats.json") diff --git a/src/copaw/providers/provider_manager.py b/src/copaw/providers/provider_manager.py index 34fcf81e9..3571aac47 100644 --- a/src/copaw/providers/provider_manager.py +++ b/src/copaw/providers/provider_manager.py @@ -9,7 +9,7 @@ import logging import json -from pydantic import BaseModel, Field +from pydantic import BaseModel from agentscope.model import ChatModelBase @@ -19,6 +19,7 @@ Provider, ProviderInfo, ) +from copaw.providers.models import ModelSlotConfig from copaw.providers.openai_provider import OpenAIProvider from copaw.providers.anthropic_provider import AnthropicProvider from copaw.providers.gemini_provider import GeminiProvider @@ -227,17 +228,6 @@ ) -class ModelSlotConfig(BaseModel): - provider_id: str = Field( - ..., - description="ID of the provider to use for this model slot", - ) - model: str = Field( - ..., - description="ID of the model to use for this model slot", - ) - - class ActiveModelsInfo(BaseModel): active_llm: ModelSlotConfig | None diff --git a/tests/unit/workspace/__init__.py b/tests/unit/workspace/__init__.py new file mode 100644 index 000000000..1b1d45590 --- /dev/null +++ b/tests/unit/workspace/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +"""Unit tests for workspace module.""" diff --git a/tests/unit/workspace/test_agent_creation.py b/tests/unit/workspace/test_agent_creation.py new file mode 100644 index 000000000..9909e5d13 --- /dev/null +++ b/tests/unit/workspace/test_agent_creation.py @@ -0,0 +1,86 @@ +# 
-*- coding: utf-8 -*- +"""Tests for agent creation with short UUID.""" +from unittest.mock import patch + +from copaw.config.config import ( + AgentProfileConfig, + generate_short_agent_id, +) + + +def test_agent_creation_auto_generates_short_id(): + """Test that agent creation validates short ID generation.""" + # Test that an empty ID triggers auto-generation logic + agent_config = AgentProfileConfig( + id="", # Empty ID should trigger auto-generation + name="Test Agent", + description="Test agent description", + ) + + # Verify the empty ID case + assert agent_config.id == "" + + # The actual auto-generation happens in the API endpoint + # This test verifies the precondition + + +def test_generate_short_id_collision_handling(): + """Test that agent creation can handle ID collisions.""" + # Generate some IDs + existing_ids = {generate_short_agent_id() for _ in range(5)} + + # Mock that first few attempts collide + collision_count = 0 + original_generate = generate_short_agent_id + + def mock_generate(): + nonlocal collision_count + if collision_count < 3: + collision_count += 1 + # Return an existing ID to simulate collision + return list(existing_ids)[0] + # Return a new unique ID + return original_generate() + + with patch( + "copaw.app.routers.agents.generate_short_agent_id", + side_effect=mock_generate, + ) as mock_fn: + # Generate IDs until we get a unique one + for _ in range(10): + new_id = mock_fn() + if new_id not in existing_ids: + break + + # Verify the mock was called + assert mock_fn.call_count > 0 + + +def test_default_agent_preserved(): + """Test that 'default' agent ID is preserved.""" + agent_config = AgentProfileConfig( + id="default", + name="Default Agent", + description="Default agent", + ) + + # 'default' should not be replaced with short UUID + assert agent_config.id == "default" + + +def test_short_uuid_properties(): + """Test properties of generated short UUIDs.""" + # Generate multiple IDs + ids = [generate_short_agent_id() for _ in 
range(20)] + + for agent_id in ids: + # Each should be 6 characters + assert len(agent_id) == 6 + # Should be alphanumeric + assert agent_id.isalnum() + # Should not contain ambiguous characters (shortuuid excludes them) + # This is a property of shortuuid library + assert "I" not in agent_id # Excluded by shortuuid + assert "l" not in agent_id # Excluded by shortuuid + assert "O" not in agent_id # Excluded by shortuuid + assert "0" not in agent_id # Excluded by shortuuid diff --git a/tests/unit/workspace/test_agent_id.py b/tests/unit/workspace/test_agent_id.py new file mode 100644 index 000000000..22ea7eb38 --- /dev/null +++ b/tests/unit/workspace/test_agent_id.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +"""Tests for agent ID generation and short UUID functionality.""" +from copaw.config.config import generate_short_agent_id + + +def test_generate_short_agent_id_length(): + """Test that generated agent ID has correct length.""" + agent_id = generate_short_agent_id() + assert len(agent_id) == 6 + assert isinstance(agent_id, str) + + +def test_generate_short_agent_id_unique(): + """Test that generated agent IDs are unique.""" + ids = {generate_short_agent_id() for _ in range(100)} + # With 100 generations, we should get at least 95 unique IDs + # (allowing for some collisions in the random space) + assert len(ids) >= 95 + + +def test_generate_short_agent_id_alphanumeric(): + """Test that generated agent ID contains only alphanumeric chars.""" + agent_id = generate_short_agent_id() + # shortuuid uses base57 alphabet by default + # (0-9, A-Z, a-z minus ambiguous chars like I, l, O, 0, etc.) 
+ assert agent_id.isalnum() diff --git a/tests/unit/workspace/test_agent_model.py b/tests/unit/workspace/test_agent_model.py new file mode 100644 index 000000000..fd6c6af2a --- /dev/null +++ b/tests/unit/workspace/test_agent_model.py @@ -0,0 +1,250 @@ +# -*- coding: utf-8 -*- +"""Tests for per-agent model configuration.""" +from pathlib import Path + +import pytest + +from copaw.config.config import ( + AgentProfileConfig, + load_agent_config, + save_agent_config, +) +from copaw.providers.models import ModelSlotConfig + + +@pytest.fixture +def mock_agent_workspace(tmp_path, monkeypatch): + """Create a temporary agent workspace for testing.""" + import json + from copaw.config.utils import get_config_path + from copaw.config.config import Config, AgentsConfig, AgentProfileRef + + # Setup workspace directory + workspace_dir = tmp_path / "workspaces" / "test_agent" + workspace_dir.mkdir(parents=True, exist_ok=True) + + # Patch config path FIRST before any config operations + monkeypatch.setenv( + "COPAW_CONFIG_PATH", + str(tmp_path / "config.json"), + ) + + # Create root config with this agent + root_config = Config( + agents=AgentsConfig( + active_agent="test_agent", + profiles={ + "test_agent": AgentProfileRef( + id="test_agent", + workspace_dir=str(workspace_dir), + ), + }, + ), + ) + + config_path = Path(get_config_path()) + config_path.parent.mkdir(parents=True, exist_ok=True) + with open(config_path, "w", encoding="utf-8") as f: + json.dump(root_config.model_dump(exclude_none=True), f) + + # Now create agent.json + agent_config = AgentProfileConfig( + id="test_agent", + name="Test Agent", + description="Test agent for model config", + ) + save_agent_config("test_agent", agent_config) + + return workspace_dir + + +def test_agent_model_config_defaults_to_none( + mock_agent_workspace, +): # pylint: disable=redefined-outer-name,unused-argument + """Test that agent model config defaults to None.""" + agent_config = load_agent_config("test_agent") + assert 
agent_config.active_model is None + + +def test_agent_model_config_can_be_set( + mock_agent_workspace, +): # pylint: disable=redefined-outer-name,unused-argument + """Test setting agent-specific model config.""" + agent_config = load_agent_config("test_agent") + + # Set active model + agent_config.active_model = ModelSlotConfig( + provider_id="openai", + model="gpt-4", + ) + save_agent_config("test_agent", agent_config) + + # Reload and verify + reloaded_config = load_agent_config("test_agent") + assert reloaded_config.active_model is not None + assert reloaded_config.active_model.provider_id == "openai" + assert reloaded_config.active_model.model == "gpt-4" + + +def test_agent_model_config_persists_across_reloads( + mock_agent_workspace, +): # pylint: disable=redefined-outer-name,unused-argument + """Test that model config persists across multiple save/load cycles.""" + agent_config = load_agent_config("test_agent") + + # Set model + agent_config.active_model = ModelSlotConfig( + provider_id="anthropic", + model="claude-3-5-sonnet-20241022", + ) + save_agent_config("test_agent", agent_config) + + # Reload multiple times + for _ in range(3): + reloaded = load_agent_config("test_agent") + assert reloaded.active_model is not None + assert reloaded.active_model.provider_id == "anthropic" + assert reloaded.active_model.model == "claude-3-5-sonnet-20241022" + + +def test_agent_model_config_can_be_cleared( + mock_agent_workspace, +): # pylint: disable=redefined-outer-name,unused-argument + """Test that model config can be set to None.""" + agent_config = load_agent_config("test_agent") + + # Set a model + agent_config.active_model = ModelSlotConfig( + provider_id="openai", + model="gpt-4", + ) + save_agent_config("test_agent", agent_config) + + # Clear it + agent_config.active_model = None + save_agent_config("test_agent", agent_config) + + # Verify it's cleared + reloaded = load_agent_config("test_agent") + assert reloaded.active_model is None + + +def 
test_different_agents_have_independent_models(tmp_path, monkeypatch): + """Test that different agents can have different model configs.""" + # Patch config path + monkeypatch.setenv( + "COPAW_CONFIG_PATH", + str(tmp_path / "config.json"), + ) + + # Create two agents + import json + from copaw.config.config import ( + Config, + AgentsConfig, + AgentProfileRef, + ) + from copaw.config.utils import get_config_path + + agent1_dir = tmp_path / "workspaces" / "agent1" + agent2_dir = tmp_path / "workspaces" / "agent2" + agent1_dir.mkdir(parents=True, exist_ok=True) + agent2_dir.mkdir(parents=True, exist_ok=True) + + # Create root config + root_config = Config( + agents=AgentsConfig( + active_agent="agent1", + profiles={ + "agent1": AgentProfileRef( + id="agent1", + workspace_dir=str(agent1_dir), + ), + "agent2": AgentProfileRef( + id="agent2", + workspace_dir=str(agent2_dir), + ), + }, + ), + ) + + config_path = Path(get_config_path()) + config_path.parent.mkdir(parents=True, exist_ok=True) + with open(config_path, "w", encoding="utf-8") as f: + json.dump(root_config.model_dump(exclude_none=True), f) + + # Create agent configs + config1 = AgentProfileConfig( + id="agent1", + name="Agent 1", + ) + config2 = AgentProfileConfig( + id="agent2", + name="Agent 2", + ) + + # Set different models + config1.active_model = ModelSlotConfig( + provider_id="openai", + model="gpt-4", + ) + config2.active_model = ModelSlotConfig( + provider_id="anthropic", + model="claude-3-5-sonnet-20241022", + ) + + save_agent_config("agent1", config1) + save_agent_config("agent2", config2) + + # Verify they're independent + reloaded1 = load_agent_config("agent1") + reloaded2 = load_agent_config("agent2") + + assert reloaded1.active_model.provider_id == "openai" + assert reloaded1.active_model.model == "gpt-4" + + assert reloaded2.active_model.provider_id == "anthropic" + assert reloaded2.active_model.model == "claude-3-5-sonnet-20241022" + + +def test_model_config_excluded_when_none( + 
mock_agent_workspace, +): # pylint: disable=redefined-outer-name + """Test that active_model is excluded from agent.json when None.""" + agent_config = load_agent_config("test_agent") + agent_config.active_model = None + save_agent_config("test_agent", agent_config) + + # Read the raw JSON file + import json + + agent_json_path = mock_agent_workspace / "agent.json" + with open(agent_json_path, "r", encoding="utf-8") as f: + raw_data = json.load(f) + + # active_model should not be in the JSON + assert "active_model" not in raw_data + + +def test_model_config_included_when_set( + mock_agent_workspace, +): # pylint: disable=redefined-outer-name + """Test that active_model is included in agent.json when set.""" + agent_config = load_agent_config("test_agent") + agent_config.active_model = ModelSlotConfig( + provider_id="openai", + model="gpt-4-turbo", + ) + save_agent_config("test_agent", agent_config) + + # Read the raw JSON file + import json + + agent_json_path = mock_agent_workspace / "agent.json" + with open(agent_json_path, "r", encoding="utf-8") as f: + raw_data = json.load(f) + + # active_model should be in the JSON + assert "active_model" in raw_data + assert raw_data["active_model"]["provider_id"] == "openai" + assert raw_data["active_model"]["model"] == "gpt-4-turbo" diff --git a/tests/unit/workspace/test_cli_agent_id.py b/tests/unit/workspace/test_cli_agent_id.py new file mode 100644 index 000000000..209f8ec2c --- /dev/null +++ b/tests/unit/workspace/test_cli_agent_id.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- +"""Tests for CLI --agent-id parameter support.""" +import json +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock +import pytest +from click.testing import CliRunner + +from copaw.cli.channels_cmd import channels_group +from copaw.cli.cron_cmd import cron_group +from copaw.cli.daemon_cmd import daemon_group +from copaw.cli.chats_cmd import chats_group +from copaw.cli.skills_cmd import skills_group +from 
copaw.config.config import AgentProfileConfig + + +@pytest.fixture +def temp_config_dir(): + """Create a temporary config directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + config_dir = Path(tmpdir) + workspaces_dir = config_dir / "workspaces" + workspaces_dir.mkdir() + + # Create default agent workspace + default_ws = workspaces_dir / "default" + default_ws.mkdir() + + # Create test agent workspace + test_ws = workspaces_dir / "abc123" + test_ws.mkdir() + + yield config_dir, default_ws, test_ws + + +def test_channels_list_default_agent( + temp_config_dir, +): # pylint: disable=W0621,W0613 + """Test copaw channels list uses default agent by default.""" + _, default_ws, _ = temp_config_dir + + # Create agent config for default + agent_config_path = default_ws / "agent.json" + agent_config = AgentProfileConfig( + id="default", + name="Default Agent", + workspace_dir=str(default_ws), + ) + agent_config_path.write_text( + json.dumps(agent_config.model_dump(exclude_none=True)), + encoding="utf-8", + ) + + runner = CliRunner() + + with patch("copaw.cli.channels_cmd.load_agent_config") as mock_load: + mock_load.return_value = agent_config + runner.invoke(channels_group, ["list"]) + + # Should call with 'default' agent + mock_load.assert_called_once_with("default") + + +def test_channels_list_custom_agent( + temp_config_dir, +): # pylint: disable=W0621,W0613 + """Test copaw channels list with custom agent_id.""" + _, _, test_ws = temp_config_dir + + # Create agent config for test agent + agent_config_path = test_ws / "agent.json" + agent_config = AgentProfileConfig( + id="abc123", + name="Test Agent", + workspace_dir=str(test_ws), + ) + agent_config_path.write_text( + json.dumps(agent_config.model_dump(exclude_none=True)), + encoding="utf-8", + ) + + runner = CliRunner() + + with patch("copaw.cli.channels_cmd.load_agent_config") as mock_load: + mock_load.return_value = agent_config + runner.invoke( + channels_group, + ["list", "--agent-id", "abc123"], + ) + + 
# Should call with custom agent + mock_load.assert_called_once_with("abc123") + + +def test_cron_list_with_agent_id(): + """Test copaw cron list with --agent-id.""" + runner = CliRunner() + + with patch("copaw.cli.cron_cmd.client") as mock_client: + mock_response = MagicMock() + mock_response.json.return_value = {"jobs": []} + mock_response.raise_for_status = MagicMock() + mock_client.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + runner.invoke( + cron_group, + ["list", "--agent-id", "test123"], + ) + + # Verify X-Agent-Id header was set + call_args = ( + mock_client.return_value.__enter__.return_value.get.call_args + ) + assert call_args is not None + if call_args[1].get("headers"): + assert call_args[1]["headers"]["X-Agent-Id"] == "test123" + + +def test_daemon_status_default_agent(): + """Test copaw daemon status defaults to 'default' agent.""" + runner = CliRunner() + + with patch("copaw.cli.daemon_cmd.run_daemon_status") as mock_status: + with patch("copaw.cli.daemon_cmd._get_agent_workspace") as mock_ws: + mock_ws.return_value = "/tmp/default" + mock_status.return_value = "Status: OK" + + runner.invoke(daemon_group, ["status"]) + + # Should use default agent + mock_ws.assert_called_once_with("default") + + +def test_daemon_status_custom_agent(): + """Test copaw daemon status with custom agent.""" + runner = CliRunner() + + with patch("copaw.cli.daemon_cmd.run_daemon_status") as mock_status: + with patch("copaw.cli.daemon_cmd._get_agent_workspace") as mock_ws: + mock_ws.return_value = "/tmp/xyz789" + mock_status.return_value = "Status: OK" + + runner.invoke( + daemon_group, + ["status", "--agent-id", "xyz789"], + ) + + # Should use custom agent + mock_ws.assert_called_once_with("xyz789") + + +def test_skills_list_default_agent(): + """Test copaw skills list defaults to 'default' agent.""" + runner = CliRunner() + + with patch( + "copaw.cli.skills_cmd._get_agent_workspace", + ) as mock_ws: + with 
patch("copaw.cli.skills_cmd.SkillService") as mock_service: + mock_ws.return_value = "/tmp/default" + mock_service_instance = MagicMock() + mock_service_instance.list_all_skills.return_value = [] + mock_service.return_value = mock_service_instance + + runner.invoke(skills_group, ["list"]) + + # Should use default agent + mock_ws.assert_called_once_with("default") + + +def test_skills_list_custom_agent(): + """Test copaw skills list with custom agent.""" + runner = CliRunner() + + with patch( + "copaw.cli.skills_cmd._get_agent_workspace", + ) as mock_ws: + with patch("copaw.cli.skills_cmd.SkillService") as mock_service: + mock_ws.return_value = "/tmp/abc123" + mock_service_instance = MagicMock() + mock_service_instance.list_all_skills.return_value = [] + mock_service.return_value = mock_service_instance + + runner.invoke( + skills_group, + ["list", "--agent-id", "abc123"], + ) + + # Should use custom agent + mock_ws.assert_called_once_with("abc123") + + +def test_chats_list_with_agent_id(): + """Test copaw chats list with --agent-id.""" + runner = CliRunner() + + with patch("copaw.cli.chats_cmd.client") as mock_client: + mock_response = MagicMock() + mock_response.json.return_value = [] + mock_response.raise_for_status = MagicMock() + mock_client.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + runner.invoke( + chats_group, + ["list", "--agent-id", "xyz789"], + ) + + # Verify X-Agent-Id header was set + call_args = ( + mock_client.return_value.__enter__.return_value.get.call_args + ) + assert call_args is not None + if call_args[1].get("headers"): + assert call_args[1]["headers"]["X-Agent-Id"] == "xyz789" diff --git a/tests/unit/workspace/test_prompt.py b/tests/unit/workspace/test_prompt.py new file mode 100644 index 000000000..f76a469a2 --- /dev/null +++ b/tests/unit/workspace/test_prompt.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +"""Tests for agent identity in system prompt.""" +import tempfile +from pathlib import Path +import 
pytest +from copaw.agents.prompt import build_system_prompt_from_working_dir + + +@pytest.fixture +def temp_workspace(): + """Create a temporary workspace directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + yield workspace + + +def test_prompt_without_agent_id(temp_workspace): # pylint: disable=W0621 + """Test system prompt without agent_id.""" + # Create a simple AGENTS.md + agents_md = temp_workspace / "AGENTS.md" + agents_md.write_text("You are a helpful assistant.", encoding="utf-8") + + prompt = build_system_prompt_from_working_dir( + working_dir=temp_workspace, + agent_id=None, + ) + + assert "You are a helpful assistant" in prompt + assert "Agent Identity" not in prompt + assert "You are agent" not in prompt + + +def test_prompt_with_default_agent_id( + temp_workspace, +): # pylint: disable=W0621 + """Test system prompt with 'default' agent_id.""" + agents_md = temp_workspace / "AGENTS.md" + agents_md.write_text("You are a helpful assistant.", encoding="utf-8") + + prompt = build_system_prompt_from_working_dir( + working_dir=temp_workspace, + agent_id="default", + ) + + # 'default' is special and should not add identity header + assert "You are a helpful assistant" in prompt + assert "Agent Identity" not in prompt + + +def test_prompt_with_custom_agent_id( + temp_workspace, +): # pylint: disable=W0621 + """Test system prompt with custom agent_id.""" + agents_md = temp_workspace / "AGENTS.md" + agents_md.write_text("You are a helpful assistant.", encoding="utf-8") + + prompt = build_system_prompt_from_working_dir( + working_dir=temp_workspace, + agent_id="abc123", + ) + + # Custom agent should have identity header + assert "Agent Identity" in prompt + assert "Your agent id is `abc123`" in prompt + assert "You are a helpful assistant" in prompt + # Identity should be at the beginning + assert prompt.index("Agent Identity") < prompt.index("helpful assistant") + + +def test_prompt_with_empty_workspace( + temp_workspace, +): 
# pylint: disable=W0621 + """Test system prompt with empty workspace.""" + prompt = build_system_prompt_from_working_dir( + working_dir=temp_workspace, + agent_id="xyz789", + ) + + # Should still add identity header even with no markdown files + assert "Agent Identity" in prompt + assert "Your agent id is `xyz789`" in prompt + + +def test_prompt_identity_format(temp_workspace): # pylint: disable=W0621 + """Test the exact format of identity header.""" + prompt = build_system_prompt_from_working_dir( + working_dir=temp_workspace, + agent_id="test99", + ) + + expected_header = ( + "# Agent Identity\n\n" + "Your agent id is `test99`. " + "This is your unique identifier in the multi-agent system.\n\n" + ) + assert expected_header in prompt diff --git a/tests/unit/workspace/test_workspace.py b/tests/unit/workspace/test_workspace.py new file mode 100644 index 000000000..5ed714454 --- /dev/null +++ b/tests/unit/workspace/test_workspace.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +"""Tests for Workspace class.""" +import tempfile +from pathlib import Path +import pytest + + +@pytest.mark.asyncio +async def test_workspace_creation(): + """Test workspace instance creation.""" + from copaw.app.workspace import Workspace + + with tempfile.TemporaryDirectory() as tmpdir: + workspace_dir = Path(tmpdir) / "test_agent" + workspace = Workspace( + agent_id="test123", + workspace_dir=str(workspace_dir), + ) + + assert workspace.agent_id == "test123" + assert workspace.workspace_dir == workspace_dir + assert workspace_dir.exists() + assert not workspace._started # pylint: disable=W0212 + + +@pytest.mark.asyncio +async def test_workspace_components_none_before_start(): + """Test that workspace components are None before start().""" + from copaw.app.workspace import Workspace + + with tempfile.TemporaryDirectory() as tmpdir: + workspace_dir = Path(tmpdir) / "test_agent" + workspace = Workspace( + agent_id="test123", + workspace_dir=str(workspace_dir), + ) + + assert workspace.runner is 
None + assert workspace.channel_manager is None + assert workspace.memory_manager is None + assert workspace.mcp_manager is None + assert workspace.cron_manager is None + assert workspace.chat_manager is None + + +@pytest.mark.asyncio +async def test_workspace_default_agent(): + """Test workspace with 'default' agent ID.""" + from copaw.app.workspace import Workspace + + with tempfile.TemporaryDirectory() as tmpdir: + workspace_dir = Path(tmpdir) / "default" + workspace = Workspace( + agent_id="default", + workspace_dir=str(workspace_dir), + ) + + assert workspace.agent_id == "default" + assert workspace.workspace_dir.name == "default" + + +@pytest.mark.asyncio +async def test_workspace_short_uuid_agent(): + """Test workspace with short UUID agent ID.""" + from copaw.app.workspace import Workspace + from copaw.config.config import generate_short_agent_id + + short_id = generate_short_agent_id() + + with tempfile.TemporaryDirectory() as tmpdir: + workspace_dir = Path(tmpdir) / short_id + workspace = Workspace( + agent_id=short_id, + workspace_dir=str(workspace_dir), + ) + + assert workspace.agent_id == short_id + assert len(workspace.agent_id) == 6 + assert workspace.workspace_dir.name == short_id + + +def test_workspace_repr(): + """Test workspace string representation.""" + from copaw.app.workspace import Workspace + + with tempfile.TemporaryDirectory() as tmpdir: + workspace_dir = Path(tmpdir) / "test_agent" + workspace = Workspace( + agent_id="test123", + workspace_dir=str(workspace_dir), + ) + + repr_str = repr(workspace) + assert "test123" in repr_str + assert "stopped" in repr_str + assert "Workspace" in repr_str diff --git a/website/public/docs/cli.en.md b/website/public/docs/cli.en.md index 082d5593c..7912de08f 100644 --- a/website/public/docs/cli.en.md +++ b/website/public/docs/cli.en.md @@ -84,8 +84,11 @@ the app is not running). 
| `copaw daemon version` | Version and paths | | `copaw daemon logs [-n N]` | Last N lines of log (default 100; from `copaw.log` in working dir) | +**Multi-Agent Support:** All commands support the `--agent-id` parameter (defaults to `default`). + ```bash -copaw daemon status +copaw daemon status # Default agent status +copaw daemon status --agent-id abc123 # Specific agent status copaw daemon version copaw daemon logs -n 50 ``` @@ -227,14 +230,18 @@ subcommand); use `remove` to uninstall custom channels (no `uninstall`). | `copaw channels remove ` | Remove a custom channel from `custom_channels/` (built-ins cannot be removed); `--keep-config` keeps config entry | | `copaw channels config` | Interactively enable/disable channels and fill in credentials | +**Multi-Agent Support:** All commands support the `--agent-id` parameter (defaults to `default`). + ```bash -copaw channels list # See current status +copaw channels list # See default agent's channels +copaw channels list --agent-id abc123 # See specific agent's channels copaw channels install my_channel # Create custom channel stub copaw channels install my_channel --path ./my_channel.py copaw channels add dingtalk # Add DingTalk to config copaw channels remove my_channel # Remove custom channel (and from config by default) copaw channels remove my_channel --keep-config # Remove module only, keep config entry -copaw channels config # Interactive configuration +copaw channels config # Configure default agent +copaw channels config --agent-id abc123 # Configure specific agent ``` The interactive `config` flow lets you pick a channel, enable/disable it, and enter credentials. It loops until you choose "Save and exit". @@ -270,6 +277,8 @@ ask CoPaw and send the reply". **Requires `copaw app` to be running.** | `copaw cron resume ` | Resume a paused job | | `copaw cron run ` | Run once immediately | +**Multi-Agent Support:** All commands support the `--agent-id` parameter (defaults to `default`). 
+ ### Creating jobs **Option 1 — CLI arguments (simple jobs)** @@ -280,7 +289,7 @@ Two task types: - **agent** — ask CoPaw a question on schedule and deliver the reply. ```bash -# Text: send "Good morning!" to DingTalk every day at 9:00 +# Text: send "Good morning!" to DingTalk every day at 9:00 (default agent) copaw cron create \ --type text \ --name "Daily 9am" \ @@ -290,8 +299,9 @@ copaw cron create \ --target-session "session_id" \ --text "Good morning!" -# Agent: every 2 hours, ask CoPaw and forward the reply +# Agent: create task for specific agent copaw cron create \ + --agent-id abc123 \ --type agent \ --name "Check todos" \ --cron "0 */2 * * *" \ @@ -349,12 +359,15 @@ Manage chat sessions via the API. **Requires `copaw app` to be running.** | `copaw chats update --name "..."` | Rename a session | | `copaw chats delete ` | Delete a session | +**Multi-Agent Support:** All commands support the `--agent-id` parameter (defaults to `default`). + ```bash -copaw chats list +copaw chats list # Default agent's chats +copaw chats list --agent-id abc123 # Specific agent's chats copaw chats list --user-id alice --channel dingtalk copaw chats get 823845fe-dd13-43c2-ab8b-d05870602fd8 copaw chats create --session-id "discord:alice" --user-id alice --name "My Chat" -copaw chats create -f chat.json +copaw chats create --agent-id abc123 -f chat.json copaw chats update --name "Renamed" copaw chats delete ``` @@ -372,9 +385,13 @@ Extend CoPaw's capabilities with skills (PDF reading, web search, etc.). | `copaw skills list` | Show all skills and their enabled/disabled status | | `copaw skills config` | Interactively enable/disable skills (checkbox UI) | +**Multi-Agent Support:** All commands support the `--agent-id` parameter (defaults to `default`). 
+ ```bash -copaw skills list # See what's available -copaw skills config # Toggle skills on/off interactively +copaw skills list # See default agent's skills +copaw skills list --agent-id abc123 # See specific agent's skills +copaw skills config # Configure default agent +copaw skills config --agent-id abc123 # Configure specific agent ``` In the interactive UI: ↑/↓ to navigate, Space to toggle, Enter to confirm. @@ -416,16 +433,31 @@ copaw --host 0.0.0.0 --port 9090 cron list ## Working directory -All config and data live in `~/.copaw` by default: `config.json`, -`HEARTBEAT.md`, `jobs.json`, `chats.json`, skills, memory, and agent persona -files. +All config and data live in `~/.copaw` by default: + +- **Global config**: `config.json` (providers, environment variables, agent list) +- **Agent workspaces**: `workspaces/{agent_id}/` (each agent's independent config and data) + +``` +~/.copaw/ +├── config.json # Global config +└── workspaces/ + ├── default/ # Default agent workspace + │ ├── agent.json # Agent config + │ ├── chats.json # Conversation history + │ ├── jobs.json # Cron jobs + │ ├── AGENTS.md # Persona files + │ └── memory/ # Memory files + └── abc123/ # Other agent workspace + └── ... +``` | Variable | Description | | ------------------- | ----------------------------------- | | `COPAW_WORKING_DIR` | Override the working directory path | | `COPAW_CONFIG_FILE` | Override the config file path | -See [Config & Working Directory](./config) for full details. +See [Config & Working Directory](./config) and [Multi-Agent Workspace](./multi-agent) for full details. --- @@ -453,3 +485,4 @@ See [Config & Working Directory](./config) for full details. 
- [Heartbeat](./heartbeat) — Scheduled check-in / digest - [Skills](./skills) — Built-in and custom skills - [Config & Working Directory](./config) — Working directory and config.json +- [Multi-Agent Workspace](./multi-agent) — Multi-agent setup and management diff --git a/website/public/docs/cli.zh.md b/website/public/docs/cli.zh.md index 954376d72..16f0257fe 100644 --- a/website/public/docs/cli.zh.md +++ b/website/public/docs/cli.zh.md @@ -76,8 +76,11 @@ Docker 镜像或 pip 安装包已内置控制台,无需单独构建。 | `copaw daemon version` | 版本与路径 | | `copaw daemon logs [-n N]` | 最近 N 行日志(默认 100,来自工作目录 `copaw.log`) | +**多智能体支持:** 所有命令都支持 `--agent-id` 参数(默认为 `default`)。 + ```bash -copaw daemon status +copaw daemon status # 默认智能体状态 +copaw daemon status --agent-id abc123 # 特定智能体状态 copaw daemon version copaw daemon logs -n 50 ``` @@ -215,14 +218,18 @@ copaw env delete TAVILY_API_KEY | `copaw channels remove ` | 从 `custom_channels/` 删除自定义频道(内置不可删);`--keep-config` 保留 config | | `copaw channels config` | 交互式启用/禁用频道并填写凭据 | +**多智能体支持:** 所有命令都支持 `--agent-id` 参数(默认为 `default`)。 + ```bash -copaw channels list # 看当前状态 +copaw channels list # 看默认智能体的频道状态 +copaw channels list --agent-id abc123 # 看特定智能体的频道状态 copaw channels install my_channel # 创建自定义频道模板 copaw channels install my_channel --path ./my_channel.py copaw channels add dingtalk # 把钉钉加入 config copaw channels remove my_channel # 删除自定义频道(并默认从 config 移除) copaw channels remove my_channel --keep-config # 只删模块,保留 config 条目 -copaw channels config # 交互式配置 +copaw channels config # 交互式配置默认智能体 +copaw channels config --agent-id abc123 # 交互式配置特定智能体 ``` 交互式 `config` 流程:依次选择频道、启用/禁用、填写凭据,循环直到选择「保存退出」。 @@ -258,6 +265,8 @@ copaw channels config # 交互式配置 | `copaw cron resume ` | 恢复暂停的任务 | | `copaw cron run ` | 立刻执行一次 | +**多智能体支持:** 所有命令都支持 `--agent-id` 参数(默认为 `default`)。 + ### 创建任务 **方式一——命令行参数(适合简单任务)** @@ -268,7 +277,7 @@ copaw channels config # 交互式配置 - **agent** —— 到点向 CoPaw 提问,把回复发到频道。 ```bash -# text:每天 9 点发「早上好!」到钉钉 +# text:每天 9 点发「早上好!」到钉钉(默认智能体) copaw cron 
create \ --type text \ --name "每日早安" \ @@ -278,8 +287,9 @@ copaw cron create \ --target-session "会话ID" \ --text "早上好!" -# agent:每 2 小时让 CoPaw 回答并转发 +# agent:为特定智能体创建任务 copaw cron create \ + --agent-id abc123 \ --type agent \ --name "检查待办" \ --cron "0 */2 * * *" \ @@ -337,12 +347,15 @@ JSON 结构见 `copaw cron get ` 的返回。 | `copaw chats update --name "..."` | 重命名会话 | | `copaw chats delete ` | 删除会话 | +**多智能体支持:** 所有命令都支持 `--agent-id` 参数(默认为 `default`)。 + ```bash -copaw chats list +copaw chats list # 默认智能体的会话 +copaw chats list --agent-id abc123 # 特定智能体的会话 copaw chats list --user-id alice --channel dingtalk copaw chats get 823845fe-dd13-43c2-ab8b-d05870602fd8 copaw chats create --session-id "discord:alice" --user-id alice --name "My Chat" -copaw chats create -f chat.json +copaw chats create --agent-id abc123 -f chat.json copaw chats update --name "新名称" copaw chats delete ``` @@ -360,9 +373,13 @@ copaw chats delete | `copaw skills list` | 列出所有技能及启用/禁用状态 | | `copaw skills config` | 交互式启用/禁用技能(复选框界面) | +**多智能体支持:** 所有命令都支持 `--agent-id` 参数(默认为 `default`)。 + ```bash -copaw skills list # 看有哪些技能 -copaw skills config # 交互式开关 +copaw skills list # 看默认智能体的技能 +copaw skills list --agent-id abc123 # 看特定智能体的技能 +copaw skills config # 交互式配置默认智能体 +copaw skills config --agent-id abc123 # 交互式配置特定智能体 ``` 交互界面中:↑/↓ 选择、空格 切换、回车 确认。确认前会预览变更。 @@ -403,15 +420,31 @@ copaw --host 0.0.0.0 --port 9090 cron list ## 工作目录 -配置和数据都在 `~/.copaw`(默认):`config.json`、`HEARTBEAT.md`、`jobs.json`、 -`chats.json`、技能文件、记忆文件和 Agent 人设 Markdown 文件。 +配置和数据都在 `~/.copaw`(默认): + +- **全局配置**: `config.json`(提供商、环境变量、智能体列表) +- **智能体工作区**: `workspaces/{agent_id}/`(每个智能体独立的配置和数据) + +``` +~/.copaw/ +├── config.json # 全局配置 +└── workspaces/ + ├── default/ # 默认智能体工作区 + │ ├── agent.json # 智能体配置 + │ ├── chats.json # 对话历史 + │ ├── jobs.json # 定时任务 + │ ├── AGENTS.md # 人设文件 + │ └── memory/ # 记忆文件 + └── abc123/ # 其他智能体工作区 + └── ... 
+``` | 变量 | 说明 | | ------------------- | ---------------- | | `COPAW_WORKING_DIR` | 覆盖工作目录路径 | | `COPAW_CONFIG_FILE` | 覆盖配置文件路径 | -详见 [配置与工作目录](./config)。 +详见 [配置与工作目录](./config) 和 [多智能体工作区](./multi-agent)。 --- @@ -439,3 +472,4 @@ copaw --host 0.0.0.0 --port 9090 cron list - [心跳](./heartbeat) —— 定时自检/摘要 - [技能](./skills) —— 内置技能与自定义技能 - [配置与工作目录](./config) —— 工作目录与 config.json +- [多智能体工作区](./multi-agent) —— 多智能体配置与管理 diff --git a/website/public/docs/config.en.md b/website/public/docs/config.en.md index 9593b562a..4f2f2ce61 100644 --- a/website/public/docs/config.en.md +++ b/website/public/docs/config.en.md @@ -16,21 +16,47 @@ By default, all config and data live in one folder — the **working directory** - **`~/.copaw`** (the `.copaw` folder under your home directory) -When you run `copaw init`, this directory is created automatically. Here's what -you'll find inside: - -| File / Directory | Purpose | -| -------------------- | ------------------------------------------------------------------ | -| `config.json` | Channel on/off and credentials, heartbeat settings, language, etc. | -| `HEARTBEAT.md` | Prompt content used each heartbeat run | -| `jobs.json` | Cron job list (managed via `copaw cron` or API) | -| `chats.json` | Chat/session list (file storage mode) | -| `token_usage.json` | LLM token usage records (by date and model) | -| `active_skills/` | Skills currently active and used by the agent | -| `customized_skills/` | User-created custom skills | -| `memory/` | Agent memory files (auto-managed) | -| `SOUL.md` | _(required)_ Core identity and behavioral principles | -| `AGENTS.md` | _(required)_ Detailed workflows, rules, and guidelines | +Starting from **v0.1.0**, CoPaw supports **multi-agent workspace**. 
When you run `copaw init`, the new structure looks like: + +``` +~/.copaw/ +├── config.json # Global config (providers, environment variables) +└── workspaces/ + ├── default/ # Default agent workspace + │ ├── agent.json # Agent config + │ ├── chats.json # Conversation history + │ ├── jobs.json # Cron jobs + │ ├── AGENTS.md # Detailed workflows, rules, and guidelines + │ ├── SOUL.md # Core identity and behavioral principles + │ ├── active_skills/ # Enabled skills + │ ├── customized_skills/ # Custom skills + │ └── memory/ # Memory files + └── abc123/ # Other agent workspace + └── ... +``` + +### Directory Explanation + +**Global Directory (`~/.copaw/`)** + +| File / Directory | Purpose | +| ---------------- | ----------------------------------------------------- | +| `config.json` | Global config (model providers, env vars, agent list) | +| `workspaces/` | All agent workspace directories | + +**Agent Workspace (`~/.copaw/workspaces/{agent_id}/`)** + +| File / Directory | Purpose | +| -------------------- | ------------------------------------------------------------ | +| `agent.json` | Agent config (channels, heartbeat, tools, skills, MCP, etc.) | +| `chats.json` | Conversation history | +| `jobs.json` | Cron job list | +| `token_usage.json` | Token usage records | +| `AGENTS.md` | _(required)_ Detailed workflows, rules, and guidelines | +| `SOUL.md` | _(required)_ Core identity and behavioral principles | +| `active_skills/` | Currently enabled skills | +| `customized_skills/` | User-created custom skills | +| `memory/` | Memory files (auto-managed) | > **Tip:** `SOUL.md` and `AGENTS.md` are the minimum required Markdown files > for the agent's system prompt. Without them, the agent falls back to a @@ -38,6 +64,8 @@ you'll find inside: > them based on your language choice (`zh` / `en` / `ru`). You can also > change the language later via the Console (Agent → Configuration). 
+> **Multi-Agent Workspace:** See the [Multi-Agent Workspace](./multi-agent) documentation for details. + --- ## Changing paths with environment variables (optional) @@ -214,14 +242,33 @@ Each channel has a common base and channel-specific fields. --- -#### `agents` — Agent behavior settings +#### `agents` — Multi-agent configuration + +From **v0.1.0**, the `agents` section now contains agent profiles: + +| Field | Type | Default | Description | +| --------------------- | ------ | ----------- | --------------------------------------------- | +| `agents.active_agent` | string | `"default"` | Currently active agent ID | +| `agents.profiles` | object | `{}` | Dictionary of agent profiles (key = agent ID) | + +**`agents.profiles[agent_id]`** — Agent profile reference + +| Field | Type | Required | Description | +| ------------- | ------ | -------- | ---------------------------- | +| `id` | string | Yes | Agent unique ID | +| `name` | string | Yes | Agent display name | +| `description` | string | No | Agent description | +| `enabled` | bool | Yes | Whether the agent is enabled | + +Each agent's detailed configuration is stored in `~/.copaw/workspaces/{agent_id}/agent.json`: -| Field | Type | Default | Description | -| ------------------------------------ | -------------- | --------- | ----------------------------------------------------------------------- | -| `agents.defaults.heartbeat` | object \| null | See below | Heartbeat configuration | -| `agents.running` | object | See below | Agent runtime behavior configuration | -| `agents.language` | string | `"zh"` | Language for agent MD files (`"zh"` / `"en"` / `"ru"`) | -| `agents.installed_md_files_language` | string \| null | `null` | Tracks which language's MD files are installed; managed by `copaw init` | +| Field | Type | Default | Description | +| ----------------------------- | -------------- | --------- | ----------------------------------------------------------------------- | +| `channels` | object | See 
below | Channel configurations | +| `heartbeat` | object \| null | See below | Heartbeat configuration | +| `running` | object | See below | Agent runtime behavior configuration | +| `language` | string | `"zh"` | Language for agent MD files (`"zh"` / `"en"` / `"ru"`) | +| `installed_md_files_language` | string \| null | `null` | Tracks which language's MD files are installed; managed by `copaw init` | **`agents.running`** — Agent runtime behavior @@ -418,15 +465,18 @@ Memory search relies on vector embeddings for semantic retrieval. Configure via - Everything lives under **`~/.copaw`** by default; override with `COPAW_WORKING_DIR` (and related env vars) if needed. -- Day-to-day you edit **config.json** (channels, heartbeat, language) and +- From **v0.1.0**, configuration is split into: + - **Global config** (`~/.copaw/config.json`) — providers, environment variables, agent list + - **Agent config** (`~/.copaw/workspaces/{agent_id}/agent.json`) — per-agent settings +- Day-to-day you edit agent-specific **agent.json** (channels, heartbeat, language) and **HEARTBEAT.md** (what to ask on each heartbeat tick); manage cron jobs - via CLI/API. -- Agent personality is defined by Markdown files in the working directory: + via CLI/API with `--agent-id` parameter. +- Each agent's personality is defined by Markdown files in its workspace directory: **SOUL.md** + **AGENTS.md** (required). -- LLM providers are configured via `copaw init` or the console UI. +- LLM providers are globally configured via `copaw init` or the console UI. - Config changes to channels are **auto-reloaded** without restart (polled every 2 seconds). -- Call the Agent API: **POST** `/agent/process`, JSON body, SSE streaming; +- Call the Agent API: **POST** `/agent/process` with `X-Agent-Id` header, JSON body, SSE streaming; see [Quick start — Verify install](./quickstart#verify-install-optional) for examples. @@ -437,6 +487,7 @@ Memory search relies on vector embeddings for semantic retrieval. 
Configure via - [Introduction](./intro) — What the project can do - [Channels](./channels) — How to fill in channels in config - [Heartbeat](./heartbeat) — How to fill in heartbeat in config +- [Multi-Agent Workspace](./multi-agent) — Multi-agent setup and management --- diff --git a/website/public/docs/config.zh.md b/website/public/docs/config.zh.md index 87389d357..497835452 100644 --- a/website/public/docs/config.zh.md +++ b/website/public/docs/config.zh.md @@ -16,26 +16,55 @@ CoPaw 所有配置和数据默认都在一个目录里,叫**工作目录**, - **`~/.copaw`**(即你当前用户下的 `.copaw` 文件夹) -运行 `copaw init` 后会自动创建这个目录,里面大致是这样的: - -| 文件/目录 | 作用 | -| -------------------- | --------------------------------------------- | -| `config.json` | 频道开关与鉴权、心跳设置、语言等 | -| `HEARTBEAT.md` | 心跳每次要问 CoPaw 的内容 | -| `jobs.json` | 定时任务列表(通过 `copaw cron` 或 API 管理) | -| `chats.json` | 会话列表(文件存储模式) | -| `token_usage.json` | LLM Token 消耗记录(按日期、模型统计) | -| `active_skills/` | 当前激活的技能(Agent 实际使用的) | -| `customized_skills/` | 用户自定义的技能 | -| `memory/` | Agent 记忆文件(自动管理) | -| `SOUL.md` | _(必需)_ 核心身份与行为原则 | -| `AGENTS.md` | _(必需)_ 详细的工作流程、规则和指南 | +从 **v0.1.0** 开始,CoPaw 支持**多智能体工作区**。运行 `copaw init` 后会自动创建这个目录,新的结构如下: + +``` +~/.copaw/ +├── config.json # 全局配置(提供商、环境变量) +└── workspaces/ + ├── default/ # 默认智能体工作区 + │ ├── agent.json # 智能体配置 + │ ├── chats.json # 对话历史 + │ ├── jobs.json # 定时任务 + │ ├── AGENTS.md # 详细工作流程、规则和指南 + │ ├── SOUL.md # 核心身份与行为原则 + │ ├── active_skills/ # 激活的技能 + │ ├── customized_skills/ # 自定义技能 + │ └── memory/ # 记忆文件 + └── abc123/ # 其他智能体工作区 + └── ... 
+``` + +### 目录说明 + +**全局目录(`~/.copaw/`)** + +| 文件/目录 | 作用 | +| ------------- | -------------------------------------------- | +| `config.json` | 全局配置(模型提供商、环境变量、智能体列表) | +| `workspaces/` | 所有智能体的工作区目录 | + +**智能体工作区(`~/.copaw/workspaces/{agent_id}/`)** + +| 文件/目录 | 作用 | +| -------------------- | -------------------------------------------- | +| `agent.json` | 智能体配置(频道、心跳、工具、技能、MCP 等) | +| `chats.json` | 对话历史 | +| `jobs.json` | 定时任务列表 | +| `token_usage.json` | Token 消耗记录 | +| `AGENTS.md` | _(必需)_ 详细的工作流程、规则和指南 | +| `SOUL.md` | _(必需)_ 核心身份与行为原则 | +| `active_skills/` | 当前激活的技能 | +| `customized_skills/` | 用户自定义的技能 | +| `memory/` | 记忆文件(自动管理) | > **提示:** `SOUL.md` 和 `AGENTS.md` 是 Agent 系统提示词的最低要求。如果它们不存在,Agent > 会退回到通用的 "You are a helpful assistant" 提示。运行 `copaw init` 时会根据你选择的 > 语言(`zh` / `en` / `ru`)自动复制这些文件。你也可以之后在控制台 > (Agent → 运行配置)中切换语言。 +> **多智能体工作区说明:** 详见 [多智能体工作区](./multi-agent) 文档。 + --- ## 用环境变量改路径(可选) @@ -71,12 +100,50 @@ copaw app ## config.json 完整结构 -下面是 **config.json 的完整字段说明**,包括类型、默认值和用途。你不需要填满所有字段——缺失的字段会自动用默认值。 +从 **v0.1.0** 开始,配置文件分为两层: + +1. **全局配置** - `~/.copaw/config.json`(提供商、环境变量、智能体列表) +2. 
**智能体配置** - `~/.copaw/workspaces/{agent_id}/agent.json`(每个智能体的独立配置) + +### 全局 config.json 示例 + +```json +{ + "agents": { + "active_agent": "default", + "profiles": { + "default": { + "id": "default", + "name": "默认智能体", + "description": "默认工作区智能体", + "enabled": true + }, + "abc123": { + "id": "abc123", + "name": "代码助手", + "description": "专注代码审查和开发", + "enabled": true + } + } + }, + "last_api": { + "host": "127.0.0.1", + "port": 7860 + }, + "show_tool_details": true +} +``` + +### 智能体配置 agent.json 示例 -### 完整示例 +每个智能体在其工作区目录下有独立的 `agent.json`: ```json { + "id": "default", + "name": "默认智能体", + "description": "默认工作区智能体", + "enabled": true, "channels": { "imessage": { "enabled": false, @@ -97,45 +164,22 @@ copaw app "client_id": "", "client_secret": "" }, - "feishu": { - "enabled": false, - "bot_prefix": "", - "app_id": "", - "app_secret": "", - "encrypt_key": "", - "verification_token": "", - "media_dir": "~/.copaw/media" - }, - "qq": { - "enabled": false, - "bot_prefix": "", - "app_id": "", - "client_secret": "" - }, "console": { "enabled": true, "bot_prefix": "" } }, - "agents": { - "defaults": { - "heartbeat": { - "every": "30m", - "target": "main", - "activeHours": null - } - }, - "running": { - "max_iters": 50, - "max_input_length": 131072 - }, - "language": "zh", - "installed_md_files_language": "zh" + "heartbeat": { + "every": "30m", + "target": "main", + "activeHours": null }, - "last_api": { - "host": "127.0.0.1", - "port": 8088 + "running": { + "max_iters": 50, + "max_input_length": 131072 }, + "language": "zh", + "installed_md_files_language": "zh", "user_timezone": "Asia/Shanghai", "last_dispatch": null, "show_tool_details": true diff --git a/website/public/docs/console.en.md b/website/public/docs/console.en.md index 8661ca00a..7c3b8f0c5 100644 --- a/website/public/docs/console.en.md +++ b/website/public/docs/console.en.md @@ -162,6 +162,13 @@ Click **Delete** → confirm. 
Edit files that define CoPaw's persona and behavior, such as `SOUL.md`, `AGENTS.md`, and `HEARTBEAT.md`, directly in the browser. +> **Multi-Agent Workspace:** Starting from **v0.1.0**, CoPaw supports +> **multi-agent workspace** functionality. You can run multiple independent +> agents in a single CoPaw instance, each with its own workspace, configuration, +> memory, and conversation history. Use the agent switcher at the top of the +> console to change the active agent. See [Multi-Agent Workspace](./multi-agent) +> for details. + ![Workspace](https://img.alicdn.com/imgextra/i3/O1CN01APrwdP1NqT9CKJMFt_!!6000000001621-2-tps-3822-2070.png) **Edit files:** @@ -457,3 +464,4 @@ Ask CoPaw directly, e.g. "How many tokens have I used recently?" or "Show me tok - [Skills](./skills) — Built-in skills and custom skills - [Heartbeat](./heartbeat) — Heartbeat configuration - [CLI](./cli) — Command-line reference +- [Multi-Agent Workspace](./multi-agent) — Multi-agent setup and management diff --git a/website/public/docs/console.zh.md b/website/public/docs/console.zh.md index d126e5e25..9ff39cd8a 100644 --- a/website/public/docs/console.zh.md +++ b/website/public/docs/console.zh.md @@ -155,6 +155,11 @@ 在这里编辑定义 CoPaw 人设和行为的文件——SOUL.md、AGENTS.md、 HEARTBEAT.md 等——全部在浏览器中完成。 +> **多智能体工作区:** 从 **v0.1.0** 开始,CoPaw 支持**多智能体工作区**功能。 +> 您可以在同一个 CoPaw 实例中运行多个独立的智能体,每个智能体拥有独立的 +> 工作区、配置、记忆和对话历史。在控制台顶部可以切换当前操作的智能体。 +> 详见 [多智能体工作区](./multi-agent)。 + ![工作区](https://img.alicdn.com/imgextra/i3/O1CN017nvwCe26Ucy89Kktq_!!6000000007665-2-tps-3822-2070.png) **编辑文件:** @@ -428,3 +433,4 @@ LM Studio 提供商连接 LM Studio 桌面应用内置的 OpenAI 兼容本地服 - [技能](./skills) —— 内置技能说明和自定义技能编写 - [心跳](./heartbeat) —— 心跳配置 - [CLI](./cli) —— 命令行参考 +- [多智能体工作区](./multi-agent) —— 多智能体配置与管理 diff --git a/website/public/docs/intro.en.md b/website/public/docs/intro.en.md index 5e4b2118b..f561fbcdb 100644 --- a/website/public/docs/intro.en.md +++ b/website/public/docs/intro.en.md @@ -54,6 +54,9 @@ what it actually does depends on 
which Skills you enable. [Heartbeat](./heartbeat). - **Cron jobs** — Scheduled tasks (send X at 9am, ask Y every 2h, etc.), managed via [CLI](./cli) or API. +- **Agent/Workspace** — Starting from **v0.1.0**, CoPaw supports multi-agent workspace, + allowing you to run multiple independent AI agents, each with its own configuration, + memory, skills, and conversation history. See [Multi-Agent Workspace](./multi-agent). Each term is explained in detail in its chapter. @@ -71,4 +74,5 @@ Each term is explained in detail in its chapter. - [Heartbeat](./heartbeat) — Set up scheduled check-in or digest (optional); - [CLI](./cli) — Init, cron jobs, clean working dir, etc.; - [Skills](./skills) — Understand and extend CoPaw’s capabilities; - - [Config & working dir](./config) — Working directory and config file. + - [Config & working dir](./config) — Working directory and config file; + - [Multi-Agent Workspace](./multi-agent) — Multi-agent setup and management (v0.1.0+ feature). diff --git a/website/public/docs/intro.zh.md b/website/public/docs/intro.zh.md index c8bef383e..0027ed4fd 100644 --- a/website/public/docs/intro.zh.md +++ b/website/public/docs/intro.zh.md @@ -48,6 +48,8 @@ CoPaw 由 [AgentScope 团队](https://github.com/agentscope-ai) 基于 频道。详见 [心跳](./heartbeat)。 - **定时任务** — 多条、各自独立配置时间的任务(每天几点发什么、每隔多久问 CoPaw 什么等), 通过 [CLI](./cli) 或 API 管理。 +- **智能体/工作区** — 从 **v0.1.0** 开始,CoPaw 支持多智能体工作区,允许运行多个独立的 AI + 智能体,每个智能体拥有独立的配置、记忆、技能和对话历史。详见 [多智能体工作区](./multi-agent)。 各概念的含义与配置方法,在对应章节中均有说明。 @@ -62,4 +64,5 @@ CoPaw 由 [AgentScope 团队](https://github.com/agentscope-ai) 基于 - [心跳](./heartbeat) — 配置定时自检或摘要(可选); - [CLI](./cli) — 初始化、定时任务、清空工作目录等命令; - [Skills](./skills) — 了解与扩展 CoPaw 能力; - - [配置与工作目录](./config) — 工作目录与配置文件说明。 + - [配置与工作目录](./config) — 工作目录与配置文件说明; + - [多智能体工作区](./multi-agent) — 多智能体配置与管理(v0.1.0+ 新功能)。 diff --git a/website/public/docs/multi-agent.en.md b/website/public/docs/multi-agent.en.md new file mode 100644 index 000000000..c868ca302 --- /dev/null +++ 
b/website/public/docs/multi-agent.en.md @@ -0,0 +1,360 @@ +# Multi-Agent Workspace + +CoPaw supports **multi-agent workspace**, allowing you to run multiple independent AI agents in a single CoPaw instance, each with its own configuration, memory, skills, and conversation history. + +> This feature was introduced in **v0.1.0**. + +--- + +## What is Multi-Agent? + +Simply put, **multi-agent** lets you run multiple "personas" in one CoPaw instance, where each persona: + +- Has its own **personality and specialization** (configured via different persona files) +- Remembers **its own conversations** (no cross-talk) +- Uses **different skills** (one good at code, another at writing) +- Connects to **different channels** (one for DingTalk, one for Discord) + +Think of it as having multiple assistants, each with their own specialty. + +--- + +## Why Use Multi-Agent? + +### Use Case 1: Functional Separation + +You might need: + +- A **daily assistant** - casual chat, look up info, manage todos +- A **code assistant** - focused on code review and development +- A **writing assistant** - focused on document writing and editing + +Each agent focuses on its domain without interference. + +### Use Case 2: Platform Separation + +You might use CoPaw across multiple platforms: + +- **DingTalk** - work-related conversations +- **Discord** - community discussions +- **Console** - personal use + +Different platforms' conversations and configs stay completely isolated. + +### Use Case 3: Testing vs Production + +You might need: + +- **Production agent** - stable config for daily work +- **Test agent** - experiment with new features without affecting production + +--- + +## How to Use? (Recommended Method) + +### Managing Agents in Console + +> This is the simplest way - **no command line required**. + +#### 1.
View and Switch Agents + +After starting CoPaw, you'll see the **Agent Selector** in the **top-right corner** of the console: + +``` +┌───────────────────────────────────┐ +│ Current Agent [Default ▼] (1) │ +└───────────────────────────────────┘ +``` + +Click the dropdown to: + +- View all agents' names and descriptions +- Switch to another agent +- See the current agent's ID + +After switching, the page auto-refreshes to show the new agent's config and data. + +#### 2. Create a New Agent + +Go to **Settings → Agent Management** page: + +1. Click "Create Agent" button +2. Fill in the information: + - **Name**: Give the agent a name (e.g., "Code Assistant") + - **Description**: Explain the agent's purpose (optional) + - **ID**: Leave empty for auto-generation, or customize (e.g., "coder") +3. Click "OK" + +After creation, the new agent appears in the list and you can immediately switch to it. + +#### 3. Configure Agent-Specific Settings + +After switching to an agent, you can configure it individually: + +- **Channels** - Go to "Control → Channels" page to enable/configure channels +- **Skills** - Go to "Agent → Skills" page to enable/disable skills +- **Tools** - Go to "Agent → Tools" page to toggle built-in tools +- **Persona** - Go to "Agent → Workspace" page to edit AGENTS.md and SOUL.md + +These settings **only affect the current agent** and won't impact other agents. + +#### 4. Edit and Delete Agents + +In **Settings → Agent Management** page: + +- Click "Edit" button to modify agent's name and description +- Click "Delete" button to remove agent (default agent cannot be deleted) + +--- + +## Example Scenarios + +### Example 1: Work-Life Separation + +**Scenario**: You want to separate work and personal conversations. + +**Setup**: + +1. Create two agents in console: + + - `work` - work assistant + - `personal` - personal assistant + +2. 
For `work` agent: + + - Enable DingTalk channel + - Enable code and document-related skills + - Configure formal persona (AGENTS.md) + +3. For `personal` agent: + - Enable Discord or console + - Enable entertainment and news skills + - Configure casual persona + +**Usage**: Automatically use `work` agent on DingTalk, `personal` agent on Discord. + +### Example 2: Specialized Assistant Team + +**Scenario**: You want assistants for different professional domains. + +**Setup**: + +1. Create three agents: + + - `coder` - code assistant (enable code review, file operation skills) + - `writer` - writing assistant (enable document processing, news digest skills) + - `planner` - task assistant (enable cron, email skills) + +2. Switch to the appropriate agent as needed. + +**Benefits**: Each agent focuses on its domain with precise persona and uncluttered conversation history. + +### Example 3: Multi-Language Support + +**Scenario**: You need both Chinese and English assistants. + +**Setup**: + +1. Create two agents: + + - `zh-assistant` - Chinese assistant (language: "zh") + - `en-assistant` - English assistant (language: "en") + +2. Edit their AGENTS.md and SOUL.md in corresponding languages. + +**Usage**: Switch to `zh-assistant` for Chinese conversations, `en-assistant` for English. + +--- + +## FAQ + +### Q: Do I need to create multiple agents? + +Not necessarily. If your use case is simple, **using only the default agent is perfectly fine**. + +Consider creating multiple agents when: + +- You need clear functional separation (work/life, dev/writing, etc.) +- Connecting to multiple platforms and want isolated conversation histories +- Need to test new configs without affecting your daily-use agent + +### Q: Will switching agents lose my conversations? + +No. Each agent's conversation history is saved independently; switching only changes which agent you're currently viewing. + +### Q: Do multiple agents increase costs? + +No. 
Agents only call the LLM when in use; idle agents don't incur any fees. + +### Q: Can I use multiple agents simultaneously? + +Yes. If you configure different agents for DingTalk and Discord, they can respond to their respective channels simultaneously. + +### Q: How to delete an agent? + +Click the delete button on the "Settings → Agent Management" page of the console. + +**Note**: After deletion, the workspace directory is retained (to prevent accidental data loss). To completely remove it, manually delete the `~/.copaw/workspaces/{agent_id}` directory. + +### Q: Can the default agent be deleted? + +No. As noted under "Edit and Delete Agents" above, the default agent cannot be deleted: the `default` agent is the system's fallback, and removing it would cause compatibility issues. + +### Q: What can agents share? + +**Globally Shared**: + +- Model provider configuration (API keys, model selection) +- Environment variables (TAVILY_API_KEY, etc.) + +**Independent Configuration**: + +- Channel settings +- Skill enablement +- Conversation history +- Cron jobs +- Persona files + +--- + +## Upgrading from Single-Agent + +If you previously used CoPaw **v0.0.x**, upgrading to **v0.1.0** will **automatically migrate** your existing configuration and data: + +1. **Automatic Migration on First Start** + + - Old configs and data are automatically moved to the `default` agent workspace + - No manual file operations required + +2. **Verify Migration** + + - After starting CoPaw, check the agent list in the console + - You should see an agent named "Default Agent" + - Your old conversations and configs should still be there + +3. **Backup Recommendation** + Back up your working directory before upgrading: + ```bash + cp -r ~/.copaw ~/.copaw.backup + ``` + +--- + +## Advanced: CLI and API + +> If you're not familiar with the command line or APIs, you can skip this section. All features are available in the console.
+ +### CLI Commands + +All multi-agent-aware CLI commands accept the `--agent-id` parameter (defaults to `default`): + +```bash +# View specific agent's configuration +copaw channels list --agent-id abc123 +copaw cron list --agent-id abc123 +copaw skills list --agent-id abc123 + +# Create cron job for specific agent +copaw cron create \ + --agent-id abc123 \ + --type agent \ + --name "Check Todos" \ + --cron "0 9 * * *" \ + --channel console \ + --target-user "user1" \ + --target-session "session1" \ + --text "What are my todos?" +``` + +**Commands Supporting `--agent-id`**: + +- `copaw channels` - channel management +- `copaw cron` - cron jobs +- `copaw daemon` - runtime status +- `copaw chats` - chat management +- `copaw skills` - skill management + +**Commands NOT Supporting `--agent-id`** (global operations): + +- `copaw init` - initialization +- `copaw providers` - model providers +- `copaw models` - model configuration +- `copaw env` - environment variables + +### REST API + +#### Agent Management API + +| Endpoint | Method | Description | +| ------------------------------- | ------ | --------------- | +| `/api/agents` | GET | List all agents | +| `/api/agents` | POST | Create agent | +| `/api/agents/{agent_id}` | GET | Get agent info | +| `/api/agents/{agent_id}` | PUT | Update agent | +| `/api/agents/{agent_id}` | DELETE | Delete agent | +| `/api/agents/{agent_id}/active` | POST | Activate agent | + +#### Agent-Scoped API + +All agent-specific APIs support the `X-Agent-Id` HTTP header: + +```bash +# Get specific agent's chat list +curl -H "X-Agent-Id: abc123" http://localhost:7860/api/chats + +# Create cron job for specific agent +curl -X POST http://localhost:7860/api/cron/jobs \ + -H "X-Agent-Id: abc123" \ + -H "Content-Type: application/json" \ + -d '{ ... 
}' +``` + +API endpoints supporting `X-Agent-Id`: + +- `/api/chats/*` - chat management +- `/api/cron/*` - cron jobs +- `/api/config/*` - channel and heartbeat config +- `/api/skills/*` - skill management +- `/api/tools/*` - tool management +- `/api/mcp/*` - MCP client management +- `/api/agent/*` - workspace files and memory + +### Configuration File Structure + +If you need to directly edit configuration files: + +#### Old Structure (v0.0.x) + +``` +~/.copaw/ +├── config.json # All config +├── chats.json +├── jobs.json +├── AGENTS.md +└── ... +``` + +#### New Structure (v0.1.0+) + +``` +~/.copaw/ +├── config.json # Global config (providers, agents.profiles) +└── workspaces/ + ├── default/ # Default agent workspace + │ ├── agent.json # Agent-specific config + │ ├── chats.json + │ ├── jobs.json + │ ├── AGENTS.md + │ └── ... + └── abc123/ # Other agent + └── ... +``` + +--- + +## Related Pages + +- [CLI Commands](./cli) - Detailed CLI reference +- [Configuration & Working Directory](./config) - Config file structure +- [Console](./console) - Web management interface +- [Skills](./skills) - Skill system diff --git a/website/public/docs/multi-agent.zh.md b/website/public/docs/multi-agent.zh.md new file mode 100644 index 000000000..9bdffd8a8 --- /dev/null +++ b/website/public/docs/multi-agent.zh.md @@ -0,0 +1,396 @@ +# 多智能体工作区 + +CoPaw 支持**多智能体工作区**,允许您在同一个 CoPaw 实例中运行多个独立的 AI 智能体,每个智能体拥有自己的配置、记忆、技能和对话历史。 + +> 本功能在 **v0.1.0** 中引入。 + +--- + +## 什么是多智能体? + +简单来说,**多智能体**就是让您可以在一个 CoPaw 中运行多个"分身",每个分身: + +- 有自己的**性格和专长**(通过不同的人设文件配置) +- 记住**各自的对话**(互不干扰) +- 使用**不同的技能**(一个擅长代码,一个擅长写作) +- 连接**不同的频道**(一个负责钉钉,一个负责 Discord) + +就像您有多个助手,每个助手各司其职。 + +--- + +## 为什么需要多智能体? 
+ +### 场景一:按用途分工 + +您可能需要: + +- 一个**日常助手** - 闲聊、查资料、记待办 +- 一个**代码助手** - 专注代码审查和开发 +- 一个**写作助手** - 专注文档撰写和润色 + +每个智能体专注自己的领域,互不干扰。 + +### 场景二:按平台分离 + +您可能在多个平台使用 CoPaw: + +- **钉钉** - 工作相关对话 +- **Discord** - 社区讨论 +- **控制台** - 私人使用 + +不同平台的对话和配置完全隔离,不会混在一起。 + +### 场景三:测试与生产隔离 + +您可能需要: + +- **生产智能体** - 稳定配置,用于日常工作 +- **测试智能体** - 实验新功能,不影响生产环境 + +--- + +## 如何使用?(推荐方式) + +### 在控制台中管理智能体 + +> 这是最简单的方式,**无需任何命令行操作**。 + +#### 1. 查看和切换智能体 + +启动 CoPaw 后,在控制台**右上角**可以看到**智能体切换器**: + +``` +┌───────────────────────────────────┐ +│ 当前智能体 [默认智能体 ▼] (1) │ +└───────────────────────────────────┘ +``` + +点击下拉框可以: + +- 查看所有智能体的名称和描述 +- 切换到其他智能体 +- 看到当前智能体的 ID + +切换后,页面会自动刷新,显示新智能体的配置和数据。 + +#### 2. 创建新智能体 + +进入**设置 → 智能体管理**页面: + +1. 点击"创建智能体"按钮 +2. 填写信息: + - **名称**:给智能体起个名字(如"代码助手") + - **描述**:说明这个智能体的用途(可选) + - **ID**:留空自动生成,或自定义(如"coder") +3. 点击"确定" + +创建后,新智能体会出现在列表中,您可以立即切换过去使用。 + +#### 3. 为智能体配置专属设置 + +切换到某个智能体后,您可以为它单独配置: + +- **频道** - 去"控制 → 频道"页面,启用/配置频道 +- **技能** - 去"智能体 → 技能"页面,启用/禁用技能 +- **工具** - 去"智能体 → 工具"页面,开关内置工具 +- **人设** - 去"智能体 → 工作区"页面,编辑 AGENTS.md 和 SOUL.md + +这些配置**只影响当前智能体**,不会影响其他智能体。 + +#### 4. 编辑和删除智能体 + +在**设置 → 智能体管理**页面: + +- 点击"编辑"按钮修改智能体的名称和描述 +- 点击"删除"按钮移除智能体(默认智能体不能删除) + +--- + +## 使用场景示例 + +### 示例一:工作与生活分离 + +**场景**:您希望工作对话和私人对话分开。 + +**配置**: + +1. 在控制台创建两个智能体: + + - `work` - 工作助手 + - `personal` - 私人助手 + +2. 为 `work` 智能体: + + - 启用钉钉频道 + - 启用代码、文档相关技能 + - 配置正式的人设(AGENTS.md) + +3. 为 `personal` 智能体: + - 启用 Discord 或控制台 + - 启用娱乐、新闻相关技能 + - 配置轻松的人设 + +**使用**:在钉钉聊天时自动使用 `work` 智能体,在 Discord 聊天时使用 `personal` 智能体。 + +### 示例二:专业助手团队 + +**场景**:您希望有多个专业领域的助手。 + +**配置**: + +1. 创建三个智能体: + + - `coder` - 代码助手(启用代码审查、文件操作技能) + - `writer` - 写作助手(启用文档处理、新闻摘要技能) + - `planner` - 任务助手(启用定时任务、邮件技能) + +2. 根据需要切换到对应的智能体使用。 + +**优点**:每个智能体专注自己的领域,人设更精准,对话历史不会混淆。 + +### 示例三:多语言支持 + +**场景**:您需要中英文两个助手。 + +**配置**: + +1. 创建两个智能体: + + - `zh-assistant` - 中文助手(language: "zh") + - `en-assistant` - 英文助手(language: "en") + +2. 
分别编辑它们的 AGENTS.md 和 SOUL.md 为对应语言。 + +**使用**:需要中文对话时切换到 `zh-assistant`,需要英文时切换到 `en-assistant`。 + +--- + +## 常见问题 + +### Q: 我需要创建多个智能体吗? + +不一定。如果您的使用场景简单,**只用默认智能体完全足够**。 + +建议创建多个智能体的情况: + +- 需要明确的功能分离(工作/生活、开发/写作等) +- 连接多个平台,希望每个平台有独立的对话历史 +- 需要测试新配置,不想影响日常使用的智能体 + +### Q: 智能体切换会丢失对话吗? + +不会。每个智能体的对话历史都是独立保存的,切换只是改变当前查看的智能体。 + +### Q: 多个智能体会增加成本吗? + +不会。智能体只在使用时才调用 LLM,闲置的智能体不会产生费用。 + +### Q: 可以同时使用多个智能体吗? + +可以。如果您在钉钉和 Discord 都配置了不同的智能体,它们可以同时响应各自频道的消息。 + +### Q: 如何删除智能体? + +在控制台的"设置 → 智能体管理"页面点击删除按钮。 + +**注意**:删除后工作区目录会保留(防止误删数据),如需彻底清理,请手动删除 `~/.copaw/workspaces/{agent_id}` 目录。 + +### Q: 默认智能体可以删除吗? + +不可以。如上文"编辑和删除智能体"所述,默认智能体不能删除:`default` 智能体是系统的默认后备,删除会导致兼容性问题。 + +### Q: 智能体之间可以共享什么? + +**全局共享**: + +- 模型提供商配置(API Key、模型选择) +- 环境变量(TAVILY_API_KEY 等) + +**独立配置**: + +- 频道配置 +- 技能启用状态 +- 对话历史 +- 定时任务 +- 人设文件 + +--- + +## 从单智能体升级 + +如果您之前使用 CoPaw **v0.0.x**,升级到 **v0.1.0** 时会**自动迁移**: + +1. **首次启动时自动迁移** + + - 旧的配置和数据会自动移动到 `default` 智能体工作区 + - 您无需手动操作任何文件 + +2. **验证迁移** + + - 启动 CoPaw 后,在控制台查看智能体列表 + - 应该能看到一个名为"默认智能体"的智能体 + - 您的旧对话和配置都应该还在 + +3. **备份建议** + 升级前备份工作目录: + ```bash + cp -r ~/.copaw ~/.copaw.backup + ``` + +--- + +## 进阶:CLI 和 API + +> 如果您不熟悉命令行或 API,可以跳过这部分。所有功能都可以在控制台中完成。 + +### CLI 命令 + +所有支持多智能体的 CLI 命令都接受 `--agent-id` 参数(默认为 `default`): + +```bash +# 查看特定智能体的配置 +copaw channels list --agent-id abc123 +copaw cron list --agent-id abc123 +copaw skills list --agent-id abc123 + +# 为特定智能体创建定时任务 +copaw cron create \ + --agent-id abc123 \ + --type agent \ + --name "检查待办" \ + --cron "0 9 * * *" \ + --channel console \ + --target-user "user1" \ + --target-session "session1" \ + --text "我有什么待办事项?"
+``` + +**支持 `--agent-id` 的命令**: + +- `copaw channels` - 频道管理 +- `copaw cron` - 定时任务 +- `copaw daemon` - 运行状态 +- `copaw chats` - 对话管理 +- `copaw skills` - 技能管理 + +**不支持 `--agent-id` 的命令**(全局操作): + +- `copaw init` - 初始化 +- `copaw providers` - 模型提供商 +- `copaw models` - 模型配置 +- `copaw env` - 环境变量 + +### REST API + +#### 智能体管理 API + +| 端点 | 方法 | 说明 | +| ------------------------------- | ------ | -------------- | +| `/api/agents` | GET | 列出所有智能体 | +| `/api/agents` | POST | 创建新智能体 | +| `/api/agents/{agent_id}` | GET | 获取智能体详情 | +| `/api/agents/{agent_id}` | PUT | 更新智能体配置 | +| `/api/agents/{agent_id}` | DELETE | 删除智能体 | +| `/api/agents/{agent_id}/active` | POST | 激活智能体 | + +#### 智能体专属 API + +所有智能体专属的 API 都支持 `X-Agent-Id` HTTP 头: + +```bash +# 获取特定智能体的对话列表 +curl -H "X-Agent-Id: abc123" http://localhost:7860/api/chats + +# 为特定智能体创建定时任务 +curl -X POST http://localhost:7860/api/cron/jobs \ + -H "X-Agent-Id: abc123" \ + -H "Content-Type: application/json" \ + -d '{ ... }' +``` + +支持 `X-Agent-Id` 的 API 端点: + +- `/api/chats/*` - 对话管理 +- `/api/cron/*` - 定时任务 +- `/api/config/*` - 频道和心跳配置 +- `/api/skills/*` - 技能管理 +- `/api/tools/*` - 工具管理 +- `/api/mcp/*` - MCP 客户端管理 +- `/api/agent/*` - 工作区文件和记忆 + +### 配置文件结构 + +如果您需要直接编辑配置文件: + +#### 旧结构(v0.0.x) + +``` +~/.copaw/ +├── config.json # 包含所有配置 +├── chats.json +├── jobs.json +├── AGENTS.md +└── ... +``` + +#### 新结构(v0.1.0+) + +``` +~/.copaw/ +├── config.json # 全局配置(providers, agents.profiles) +└── workspaces/ + ├── default/ # 默认智能体工作区 + │ ├── agent.json # 智能体专属配置 + │ ├── chats.json + │ ├── jobs.json + │ ├── AGENTS.md + │ └── ... + └── abc123/ # 其他智能体 + └── ... 
+``` + +--- + +## 最佳实践 + +### 合理规划智能体数量 + +✅ **推荐**:3-5 个智能体,按主要功能或平台分类 +❌ **不推荐**:为每个小功能都创建智能体 + +过多智能体会增加管理复杂度,得不偿失。 + +### 使用清晰的名称 + +✅ **好的命名**: + +- `default` - 默认智能体 +- `work-assistant` - 工作助手 +- `code-reviewer` - 代码审查助手 + +❌ **不好的命名**: + +- `abc123` - 无意义的随机字符 +- `test1`, `test2` - 不清楚用途 + +### 定期备份 + +重要智能体的工作区建议定期备份: + +```bash +# 备份特定智能体 +cp -r ~/.copaw/workspaces/abc123 ~/backups/agent-abc123-$(date +%Y%m%d) + +# 备份所有智能体 +cp -r ~/.copaw/workspaces ~/backups/workspaces-$(date +%Y%m%d) +``` + +--- + +## 相关页面 + +- [CLI 命令](./cli) - 命令行工具详细说明 +- [配置与工作目录](./config) - 配置文件结构 +- [控制台](./console) - Web 管理界面 +- [技能](./skills) - 技能系统 diff --git a/website/src/pages/Docs.tsx b/website/src/pages/Docs.tsx index e97aefe28..c7b9dedc6 100644 --- a/website/src/pages/Docs.tsx +++ b/website/src/pages/Docs.tsx @@ -180,6 +180,7 @@ const DOC_SLUG_ICONS: Record = { intro: Rocket, quickstart: Zap, console: Terminal, + "multi-agent": Users, models: Cpu, channels: MessageSquare, skills: Wrench, @@ -203,7 +204,11 @@ const DOC_SLUGS: DocEntry[] = [ titleKey: "docs.quickstart", children: [{ slug: "desktop", titleKey: "docs.desktop" }], }, - { slug: "console", titleKey: "docs.console" }, + { + slug: "console", + titleKey: "docs.console", + children: [{ slug: "multi-agent", titleKey: "docs.multiAgent" }], + }, { slug: "models", titleKey: "docs.models" }, { slug: "channels", titleKey: "docs.channels" }, { slug: "skills", titleKey: "docs.skills" }, @@ -235,6 +240,7 @@ const DOC_TITLES: Record> = { "docs.quickstart": "快速开始", "docs.desktop": "桌面应用", "docs.console": "控制台", + "docs.multiAgent": "多智能体工作区", "docs.models": "模型", "docs.channels": "频道配置", "docs.heartbeat": "心跳", @@ -255,6 +261,7 @@ const DOC_TITLES: Record> = { "docs.quickstart": "Quick start", "docs.desktop": "Desktop App", "docs.console": "Console", + "docs.multiAgent": "Multi-Agent Workspace", "docs.models": "Models", "docs.channels": "Channels", "docs.heartbeat": "Heartbeat", From 2592cce2ffdfdeaac19f414b9dd4175c678f0f06 Mon 
Sep 17 00:00:00 2001 From: Runlin Lei Date: Mon, 16 Mar 2026 21:35:52 +0800 Subject: [PATCH 20/68] feat(channels): generalize discord debounce key fix to all channels (#1583) --- src/copaw/app/channels/base.py | 13 ++--- src/copaw/app/channels/dingtalk/channel.py | 12 ----- src/copaw/app/channels/discord_/channel.py | 56 +--------------------- 3 files changed, 8 insertions(+), 73 deletions(-) diff --git a/src/copaw/app/channels/base.py b/src/copaw/app/channels/base.py index d3d15b5e2..f07cd9d03 100644 --- a/src/copaw/app/channels/base.py +++ b/src/copaw/app/channels/base.py @@ -130,15 +130,15 @@ def _is_native_payload(self, payload: Any) -> bool: def get_debounce_key(self, payload: Any) -> str: """ Key for time debounce (same key = same conversation). - Override for channel-specific keys (e.g. short conversation_id). + Delegates to ``resolve_session_id`` so every channel gets + session-scoped isolation automatically. """ if isinstance(payload, dict): + sender_id = payload.get("sender_id") or "" meta = payload.get("meta") or {} - return ( - payload.get("session_id") - or meta.get("conversation_id") - or payload.get("sender_id") - or "" + return payload.get("session_id") or self.resolve_session_id( + sender_id, + meta, ) return getattr(payload, "session_id", "") or "" @@ -162,6 +162,7 @@ def merge_native_items(self, items: List[Any]) -> Any: "reply_loop", "incoming_message", "conversation_id", + "message_id", ): if k in m: merged_meta[k] = m[k] diff --git a/src/copaw/app/channels/dingtalk/channel.py b/src/copaw/app/channels/dingtalk/channel.py index 7964c9224..6e7f68fa1 100644 --- a/src/copaw/app/channels/dingtalk/channel.py +++ b/src/copaw/app/channels/dingtalk/channel.py @@ -1212,10 +1212,6 @@ async def send_content_parts( else: await self.send(to_handle, body.strip() or prefix, meta) - def get_debounce_key(self, payload: Any) -> str: - """Use short conversation_id or channel:sender for time debounce.""" - return self._debounce_key(payload) - def 
merge_native_items(self, items: List[Any]) -> Any: """Merge payloads (content_parts + meta) for DingTalk.""" return self._merge_native(items) @@ -1472,14 +1468,6 @@ async def _process_one_request( request.session_id or f"{self.channel}:{request.user_id}", ) - def _debounce_key(self, native: Any) -> str: - payload = native if isinstance(native, dict) else {} - meta = payload.get("meta") or {} - cid = meta.get("conversation_id") or "" - if cid: - return short_session_id_from_conversation_id(str(cid)) - return f"{self.channel}:{payload.get('sender_id', '')}" - def _merge_native(self, items: list) -> dict: """Merge multiple native payloads into one (content_parts + meta).""" if not items: diff --git a/src/copaw/app/channels/discord_/channel.py b/src/copaw/app/channels/discord_/channel.py index bafe41eec..6400ff2fd 100644 --- a/src/copaw/app/channels/discord_/channel.py +++ b/src/copaw/app/channels/discord_/channel.py @@ -9,7 +9,7 @@ import tempfile from pathlib import Path from urllib.parse import urlparse -from typing import Any, Dict, List, Optional +from typing import Any, Optional import aiohttp from agentscope_runtime.engine.schemas.agent_schemas import ( @@ -515,60 +515,6 @@ async def stop(self) -> None: if self._client: await self._client.close() - # ------------------------------------------------------------------ - # Debounce: use per-channel keys so concurrent messages from the same - # user in different channels/threads are NOT merged together. - # ------------------------------------------------------------------ - - def get_debounce_key(self, payload: Any) -> str: - """Return a debounce key scoped to the Discord channel or DM. - - The base class falls back to ``sender_id``, which causes - ``ChannelManager._drain_same_key()`` to incorrectly merge - messages when the same user sends to multiple channels at the - same time. This override uses ``resolve_session_id`` so each - channel/thread gets its own isolated debounce bucket. 
- """ - if isinstance(payload, dict): - meta = payload.get("meta") or {} - sender_id = payload.get("sender_id") or "" - return self.resolve_session_id(sender_id, meta) - return getattr(payload, "session_id", "") or "" - - def merge_native_items(self, items: List[Any]) -> Any: - """Merge native payloads while preserving Discord metadata. - - Extends the base implementation to also carry over - Discord-specific meta keys (``channel_id``, ``message_id``, - ``guild_id``, ``is_dm``, ``is_group``) from the first item. - """ - if not items: - return None - first = items[0] if isinstance(items[0], dict) else {} - merged_parts: List[Any] = [] - merged_meta: Dict[str, Any] = dict(first.get("meta") or {}) - for it in items: - p = it if isinstance(it, dict) else {} - merged_parts.extend(p.get("content_parts") or []) - m = p.get("meta") or {} - for k in ( - "reply_future", - "reply_loop", - "incoming_message", - "conversation_id", - "message_id", - ): - if k in m: - merged_meta[k] = m[k] - return { - "channel_id": first.get("channel_id") or self.channel, - "sender_id": first.get("sender_id") or "", - "content_parts": merged_parts, - "meta": merged_meta, - } - - # ------------------------------------------------------------------ - def resolve_session_id( self, sender_id: str, From 267dd4472d655515ca1566eb8be4506771a9cfa2 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Mon, 16 Mar 2026 21:36:52 +0800 Subject: [PATCH 21/68] feat(console): preserve custom cron expressions in UI (#1257) --- console/src/App.tsx | 7 +- .../Control/CronJobs/components/parseCron.ts | 151 +++++++++++++----- src/copaw/app/_app.py | 19 ++- src/copaw/app/crons/models.py | 3 + 4 files changed, 139 insertions(+), 41 deletions(-) diff --git a/console/src/App.tsx b/console/src/App.tsx index ed39545a0..e764c3647 100644 --- a/console/src/App.tsx +++ b/console/src/App.tsx @@ -37,7 +37,12 @@ const GlobalStyle = createGlobalStyle` } `; +function 
getRouterBasename(pathname: string): string | undefined { + return /^\/console(?:\/|$)/.test(pathname) ? "/console" : undefined; +} + function App() { + const basename = getRouterBasename(window.location.pathname); const { i18n } = useTranslation(); const lang = i18n.resolvedLanguage || i18n.language || "en"; const [antdLocale, setAntdLocale] = useState( @@ -61,7 +66,7 @@ function App() { }, [i18n]); return ( - + = { +const ORDERED_DAYS = ["mon", "tue", "wed", "thu", "fri", "sat", "sun"] as const; +type DayName = (typeof ORDERED_DAYS)[number]; + +const NUM_TO_NAME: Record = { "0": "sun", "1": "mon", "2": "tue", @@ -34,7 +38,11 @@ const NUM_TO_NAME: Record = { "7": "sun", }; -const VALID_NAMES = new Set(["mon", "tue", "wed", "thu", "fri", "sat", "sun"]); +const VALID_NAMES = new Set(ORDERED_DAYS); + +function isDayName(value: string): value is DayName { + return VALID_NAMES.has(value as DayName); +} /** * Parse cron expression to CronParts @@ -57,32 +65,31 @@ export function parseCron(cron: string): CronParts { const [, minute, hour, dayOfMonth, month, dayOfWeek] = match; - // Hourly: "0 * * * *" or "*/N * * * *" where N > 1 + // Hourly: "0 * * * *" if ( hour === "*" && dayOfMonth === "*" && month === "*" && - dayOfWeek === "*" + dayOfWeek === "*" && + minute === "0" ) { - if (minute === "0") { - return { type: "hourly", minute: 0 }; - } + return { type: "hourly", minute: 0 }; } // Daily: "M H * * *" if (dayOfMonth === "*" && month === "*" && dayOfWeek === "*") { - const h = parseInt(hour, 10); - const m = parseInt(minute, 10); - if (!isNaN(h) && !isNaN(m) && h >= 0 && h < 24 && m >= 0 && m < 60) { + const h = parsePlainCronNumber(hour, 0, 23); + const m = parsePlainCronNumber(minute, 0, 59); + if (h !== null && m !== null) { return { type: "daily", hour: h, minute: m }; } } // Weekly: "M H * * D" where D is days if (dayOfMonth === "*" && month === "*" && dayOfWeek !== "*") { - const h = parseInt(hour, 10); - const m = parseInt(minute, 10); - if (!isNaN(h) && 
!isNaN(m) && h >= 0 && h < 24 && m >= 0 && m < 60) { + const h = parsePlainCronNumber(hour, 0, 23); + const m = parsePlainCronNumber(minute, 0, 59); + if (h !== null && m !== null) { const days = parseDaysOfWeek(dayOfWeek); if (days.length > 0) { return { type: "weekly", hour: h, minute: m, daysOfWeek: days }; @@ -94,6 +101,23 @@ export function parseCron(cron: string): CronParts { return { type: "custom", rawCron: trimmed }; } +function parsePlainCronNumber( + value: string, + min: number, + max: number, +): number | null { + if (!INTEGER_RE.test(value)) { + return null; + } + + const parsed = Number(value); + if (parsed < min || parsed > max) { + return null; + } + + return parsed; +} + /** * Serialize CronParts back to cron expression */ @@ -111,10 +135,7 @@ export function serializeCron(parts: CronParts): string { case "weekly": { const h = parts.hour ?? 9; const m = parts.minute ?? 0; - const days = - parts.daysOfWeek && parts.daysOfWeek.length > 0 - ? parts.daysOfWeek.join(",") - : "mon"; // default Monday + const days = serializeDaysOfWeek(parts.daysOfWeek); return `${m} ${h} * * ${days}`; } @@ -130,50 +151,108 @@ export function serializeCron(parts: CronParts): string { * Parse day of week field to string abbreviations. * * Accepts both numeric (crontab convention: 0=Sun … 6=Sat) and - * named values (mon, tue, …). Always returns abbreviation strings. + * named values (mon, tue, …). Always returns abbreviation strings. + * Invalid or lossy tokens return an empty array so callers can + * fall back to `custom`. 
*/ function parseDaysOfWeek(dayOfWeek: string): string[] { - const days: string[] = []; + const days: DayName[] = []; const parts = dayOfWeek.split(","); for (const part of parts) { const trimmed = part.trim().toLowerCase(); - // Try as a name first - if (VALID_NAMES.has(trimmed)) { + if (!trimmed) { + return []; + } + + if (isDayName(trimmed)) { if (!days.includes(trimmed)) { days.push(trimmed); } continue; } - // Handle ranges like "1-5" or "mon-fri" if (trimmed.includes("-")) { - const [startStr, endStr] = trimmed.split("-"); + const rangeParts = trimmed.split("-"); + if ( + rangeParts.length !== 2 || + rangeParts[0] === "" || + rangeParts[1] === "" + ) { + return []; + } + + const [startStr, endStr] = rangeParts; const startName = NUM_TO_NAME[startStr] || startStr; const endName = NUM_TO_NAME[endStr] || endStr; - if (VALID_NAMES.has(startName) && VALID_NAMES.has(endName)) { - // For ranges, expand to individual days - const ordered = ["mon", "tue", "wed", "thu", "fri", "sat", "sun"]; - const si = ordered.indexOf(startName); - const ei = ordered.indexOf(endName); - if (si !== -1 && ei !== -1) { - for (let i = si; i <= ei; i++) { - if (!days.includes(ordered[i])) { - days.push(ordered[i]); - } - } + + if (!isDayName(startName) || !isDayName(endName)) { + return []; + } + + const si = ORDERED_DAYS.indexOf(startName); + const ei = ORDERED_DAYS.indexOf(endName); + if (si === -1 || ei === -1 || si > ei) { + return []; + } + + for (let i = si; i <= ei; i++) { + if (!days.includes(ORDERED_DAYS[i])) { + days.push(ORDERED_DAYS[i]); } } continue; } - // Try as a number const name = NUM_TO_NAME[trimmed]; - if (name && !days.includes(name)) { + if (!name) { + return []; + } + + if (!days.includes(name)) { days.push(name); } } return days; } + +function serializeDaysOfWeek(daysOfWeek?: string[]): string { + if (!daysOfWeek || daysOfWeek.length === 0) { + return "mon"; + } + + const selectedDays = ORDERED_DAYS.filter((day) => daysOfWeek.includes(day)); + if 
(selectedDays.length === 0) { + return "mon"; + } + + const segments: string[] = []; + let rangeStart = selectedDays[0]; + let previousDay = selectedDays[0]; + + for (let i = 1; i <= selectedDays.length; i++) { + const currentDay = selectedDays[i]; + const isContiguous = + currentDay !== undefined && + ORDERED_DAYS.indexOf(currentDay) === + ORDERED_DAYS.indexOf(previousDay) + 1; + + if (isContiguous) { + previousDay = currentDay; + continue; + } + + if (rangeStart === previousDay) { + segments.push(rangeStart); + } else { + segments.push(`${rangeStart}-${previousDay}`); + } + + rangeStart = currentDay; + previousDay = currentDay ?? previousDay; + } + + return segments.join(","); +} diff --git a/src/copaw/app/_app.py b/src/copaw/app/_app.py index 25ad975cb..d160d8ff5 100644 --- a/src/copaw/app/_app.py +++ b/src/copaw/app/_app.py @@ -309,6 +309,12 @@ def get_version(): if os.path.isdir(_CONSOLE_STATIC_DIR): _console_path = Path(_CONSOLE_STATIC_DIR) + def _serve_console_index(): + if _CONSOLE_INDEX and _CONSOLE_INDEX.exists(): + return FileResponse(_CONSOLE_INDEX) + + raise HTTPException(status_code=404, detail="Not Found") + @app.get("/logo.png") def _console_logo(): f = _console_path / "logo.png" @@ -333,9 +339,14 @@ def _console_icon(): name="assets", ) + @app.get("/console") + @app.get("/console/") + @app.get("/console/{full_path:path}") + def _console_spa_alias(full_path: str = ""): + _ = full_path + return _serve_console_index() + @app.get("/{full_path:path}") def _console_spa(full_path: str): - if _CONSOLE_INDEX and _CONSOLE_INDEX.exists(): - return FileResponse(_CONSOLE_INDEX) - - raise HTTPException(status_code=404, detail="Not Found") + _ = full_path + return _serve_console_index() diff --git a/src/copaw/app/crons/models.py b/src/copaw/app/crons/models.py index f207e18dd..bf0169b11 100644 --- a/src/copaw/app/crons/models.py +++ b/src/copaw/app/crons/models.py @@ -44,6 +44,9 @@ def _crontab_dow_to_name(field: str) -> str: return field def _convert_token(tok: 
str) -> str: + if "/" in tok: + base, step = tok.rsplit("/", 1) + return f"{_convert_token(base)}/{step}" if "-" in tok: parts = tok.split("-", 1) return "-".join(_CRONTAB_NUM_TO_NAME.get(p, p) for p in parts) From 28f82947db3ded1b545d07bc19f81e48c0dd1be8 Mon Sep 17 00:00:00 2001 From: zhijianma Date: Mon, 16 Mar 2026 22:19:42 +0800 Subject: [PATCH 22/68] feat(console): implement console/chat endpoint instead of agent/process (#1571) --- console/src/pages/Chat/index.tsx | 2 +- src/copaw/app/channels/console/channel.py | 39 +++++++-- src/copaw/app/routers/agent_scoped.py | 2 + src/copaw/app/routers/console.py | 98 ++++++++++++++++++++++- 4 files changed, 132 insertions(+), 9 deletions(-) diff --git a/console/src/pages/Chat/index.tsx b/console/src/pages/Chat/index.tsx index d5ccf2b19..204a4614f 100644 --- a/console/src/pages/Chat/index.tsx +++ b/console/src/pages/Chat/index.tsx @@ -341,7 +341,7 @@ export default function ChatPage() { console.warn("Failed to get selected agent from storage:", error); } - return fetch(defaultConfig?.api?.baseURL || getApiUrl("/agent/process"), { + return fetch(defaultConfig?.api?.baseURL || getApiUrl("/console/chat"), { method: "POST", headers, body: JSON.stringify(requestBody), diff --git a/src/copaw/app/channels/console/channel.py b/src/copaw/app/channels/console/channel.py index 26d480c42..4498c1fef 100644 --- a/src/copaw/app/channels/console/channel.py +++ b/src/copaw/app/channels/console/channel.py @@ -5,17 +5,18 @@ A lightweight channel that prints all agent responses to stdout. Messages are sent to the agent via the standard AgentApp ``/agent/process`` -endpoint. This channel only handles the **output** side: whenever a -completed message event or a proactive send arrives, it is pretty-printed -to the terminal. +endpoint or via POST /console/chat. This channel handles the **output** side: +whenever a completed message event or a proactive send arrives, it is +pretty-printed to the terminal. 
""" from __future__ import annotations import logging import os import sys +import json from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from agentscope_runtime.engine.schemas.agent_schemas import RunStatus @@ -139,6 +140,19 @@ def from_config( filter_thinking=filter_thinking, ) + def resolve_session_id( + self, + sender_id: str, + channel_meta: Optional[dict] = None, + ) -> str: + """Resolve session_id: use explicit meta['session_id'] when provided + (e.g. from the HTTP /console/chat API), otherwise fall back to + 'console:'. + """ + if channel_meta and channel_meta.get("session_id"): + return channel_meta["session_id"] + return f"{self.channel}:{sender_id}" + def build_agent_request_from_native(self, native_payload: Any) -> Any: """ Build AgentRequest from console native payload (dict with @@ -161,8 +175,8 @@ def build_agent_request_from_native(self, native_payload: Any) -> Any: request.channel_meta = meta return request - async def consume_one(self, payload: Any) -> None: - """Process one payload (AgentRequest or native dict) from queue.""" + async def stream_one(self, payload: Any) -> AsyncGenerator[str, None]: + """Process one payload and yield SSE-formatted events""" if isinstance(payload, dict) and "content_parts" in payload: session_id = self.resolve_session_id( payload.get("sender_id") or "", @@ -212,6 +226,14 @@ async def consume_one(self, payload: Any) -> None: ev_type, ) + if hasattr(event, "model_dump_json"): + data = event.model_dump_json() + elif hasattr(event, "json"): + data = event.json() + else: + data = json.dumps({"text": str(event)}) + yield f"data: {data}\n\n" + if obj == "message" and status == RunStatus.Completed: parts = self._message_to_content_parts(event) self._print_parts(parts, ev_type) @@ -242,6 +264,11 @@ async def consume_one(self, payload: Any) -> None: err_msg = str(e).strip() or "An error occurred while processing." 
self._print_error(err_msg) + async def consume_one(self, payload: Any) -> None: + """Process one payload; drain stream_one (queue/terminal).""" + async for _ in self.stream_one(payload): + pass + # ── pretty-print helpers ──────────────────────────────────────── def _print_parts( diff --git a/src/copaw/app/routers/agent_scoped.py b/src/copaw/app/routers/agent_scoped.py index 4b127775d..f0437adb8 100644 --- a/src/copaw/app/routers/agent_scoped.py +++ b/src/copaw/app/routers/agent_scoped.py @@ -64,6 +64,7 @@ def create_agent_scoped_router() -> APIRouter: from .workspace import router as workspace_router from ..crons.api import router as cron_router from ..runner.api import router as chats_router + from .console import router as console_router # Create parent router with agentId parameter router = APIRouter(prefix="/agents/{agentId}", tags=["agent-scoped"]) @@ -85,5 +86,6 @@ def create_agent_scoped_router() -> APIRouter: router.include_router(skills_router) router.include_router(tools_router) router.include_router(workspace_router) + router.include_router(console_router) return router diff --git a/src/copaw/app/routers/console.py b/src/copaw/app/routers/console.py index c44883f97..0213a52c7 100644 --- a/src/copaw/app/routers/console.py +++ b/src/copaw/app/routers/console.py @@ -1,12 +1,106 @@ # -*- coding: utf-8 -*- -"""Console APIs for push messages.""" +"""Console APIs: push messages and chat.""" +from __future__ import annotations -from fastapi import APIRouter, Query +import json +import logging +from typing import AsyncGenerator, Union +from fastapi import APIRouter, HTTPException, Query, Request +from starlette.responses import StreamingResponse + +from agentscope_runtime.engine.schemas.agent_schemas import AgentRequest + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/console", tags=["console"]) +@router.post( + "/chat", + status_code=200, + summary="Chat with console (streaming response)", + description="Agent API Request Format. 
" + "See https://runtime.agentscope.io/en/protocol.html for " + "more details.", +) +async def post_console_chat( + request_data: Union[AgentRequest, dict], + request: Request, +) -> StreamingResponse: + """Accept a user message and stream the agent response. + + Accepts AgentRequest or dict, builds native payload, and streams events + via channel.stream_one(). + """ + + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + + # Extract channel info from request + if isinstance(request_data, AgentRequest): + channel_id = request_data.channel or "console" + sender_id = request_data.user_id or "default" + session_id = request_data.session_id or "default" + content_parts = ( + list(request_data.input[0].content) if request_data.input else [] + ) + else: + # Dict format - extract from request body + channel_id = request_data.get("channel", "console") + sender_id = request_data.get("user_id", "default") + session_id = request_data.get("session_id", "default") + input_data = request_data.get("input", []) + + # Extract content from input array + content_parts = [] + if input_data and len(input_data) > 0: + last_msg = input_data[-1] + if hasattr(last_msg, "content"): + content_parts = list(last_msg.content or []) + elif isinstance(last_msg, dict) and "content" in last_msg: + content_parts = last_msg["content"] or [] + + # + console_channel = await workspace.channel_manager.get_channel("console") + if console_channel is None: + raise HTTPException( + status_code=503, + detail="Channel Console not found", + ) + + # Build native payload + native_payload = { + "channel_id": channel_id, + "sender_id": sender_id, + "content_parts": content_parts, + "meta": { + "session_id": session_id, + "user_id": sender_id, + }, + } + + async def event_generator() -> AsyncGenerator[str, None]: + try: + async for event_data in console_channel.stream_one(native_payload): + yield event_data + except Exception as e: + logger.exception("Console chat 
stream error") + yield f"data: {json.dumps({'error': str(e)})}\n\n" + finally: + yield "data: [DONE]\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + @router.get("/push-messages") async def get_push_messages( session_id: str | None = Query(None, description="Optional session id"), From a2bd2df0e7f393831137845dac470a8c3f909963 Mon Sep 17 00:00:00 2001 From: Ping <49363458+gnipping@users.noreply.github.com> Date: Tue, 17 Mar 2026 08:33:02 +0800 Subject: [PATCH 23/68] feat(security): add skill scanner (#564) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 平砚 Co-authored-by: xieyxclack --- console/src/api/modules/security.ts | 92 ++++ console/src/locales/en.json | 56 +- console/src/locales/ja.json | 56 +- console/src/locales/ru.json | 56 +- console/src/locales/zh.json | 56 +- console/src/pages/Agent/Skills/useSkills.ts | 189 ++++++- .../components/SkillScannerSection.tsx | 439 +++++++++++++++ .../Settings/Security/components/index.ts | 1 + console/src/pages/Settings/Security/index.tsx | 21 +- .../Settings/Security/useSkillScanner.ts | 127 +++++ pyproject.toml | 3 + src/copaw/agents/skills_manager.py | 56 ++ src/copaw/app/routers/config.py | 128 +++++ src/copaw/app/routers/skills.py | 85 ++- src/copaw/config/config.py | 43 ++ src/copaw/security/__init__.py | 2 + src/copaw/security/skill_scanner/__init__.py | 507 ++++++++++++++++++ .../skill_scanner/analyzers/__init__.py | 89 +++ .../analyzers/pattern_analyzer.py | 392 ++++++++++++++ .../skill_scanner/data/default_policy.yaml | 244 +++++++++ src/copaw/security/skill_scanner/models.py | 234 ++++++++ .../rules/signatures/command_injection.yaml | 194 +++++++ .../rules/signatures/data_exfiltration.yaml | 141 +++++ .../rules/signatures/hardcoded_secrets.yaml | 149 +++++ .../rules/signatures/obfuscation.yaml | 46 ++ 
.../rules/signatures/prompt_injection.yaml | 58 ++ .../rules/signatures/resource_abuse.yaml | 39 ++ .../rules/signatures/social_engineering.yaml | 27 + .../rules/signatures/supply_chain.yaml | 11 + .../signatures/unauthorized_tool_use.yaml | 59 ++ .../security/skill_scanner/scan_policy.py | 475 ++++++++++++++++ src/copaw/security/skill_scanner/scanner.py | 318 +++++++++++ website/public/docs/security.en.md | 114 ++++ website/public/docs/security.zh.md | 114 ++++ website/src/pages/Docs.tsx | 5 + 35 files changed, 4601 insertions(+), 25 deletions(-) create mode 100644 console/src/pages/Settings/Security/components/SkillScannerSection.tsx create mode 100644 console/src/pages/Settings/Security/useSkillScanner.ts create mode 100644 src/copaw/security/skill_scanner/__init__.py create mode 100644 src/copaw/security/skill_scanner/analyzers/__init__.py create mode 100644 src/copaw/security/skill_scanner/analyzers/pattern_analyzer.py create mode 100644 src/copaw/security/skill_scanner/data/default_policy.yaml create mode 100644 src/copaw/security/skill_scanner/models.py create mode 100644 src/copaw/security/skill_scanner/rules/signatures/command_injection.yaml create mode 100644 src/copaw/security/skill_scanner/rules/signatures/data_exfiltration.yaml create mode 100644 src/copaw/security/skill_scanner/rules/signatures/hardcoded_secrets.yaml create mode 100644 src/copaw/security/skill_scanner/rules/signatures/obfuscation.yaml create mode 100644 src/copaw/security/skill_scanner/rules/signatures/prompt_injection.yaml create mode 100644 src/copaw/security/skill_scanner/rules/signatures/resource_abuse.yaml create mode 100644 src/copaw/security/skill_scanner/rules/signatures/social_engineering.yaml create mode 100644 src/copaw/security/skill_scanner/rules/signatures/supply_chain.yaml create mode 100644 src/copaw/security/skill_scanner/rules/signatures/unauthorized_tool_use.yaml create mode 100644 src/copaw/security/skill_scanner/scan_policy.py create mode 100644 
src/copaw/security/skill_scanner/scanner.py create mode 100644 website/public/docs/security.en.md create mode 100644 website/public/docs/security.zh.md diff --git a/console/src/api/modules/security.ts b/console/src/api/modules/security.ts index 4332fe37b..2a6580391 100644 --- a/console/src/api/modules/security.ts +++ b/console/src/api/modules/security.ts @@ -20,7 +20,51 @@ export interface ToolGuardConfig { disabled_rules: string[]; } +// ── Skill Scanner types ──────────────────────────────────────────── + +export interface SkillScannerWhitelistEntry { + skill_name: string; + content_hash: string; + added_at: string; +} + +export type SkillScannerMode = "block" | "warn" | "off"; + +export interface SkillScannerConfig { + mode: SkillScannerMode; + timeout: number; + whitelist: SkillScannerWhitelistEntry[]; +} + +export interface BlockedSkillFinding { + severity: string; + title: string; + description: string; + file_path: string; + line_number: number | null; + rule_id: string; +} + +export interface BlockedSkillRecord { + skill_name: string; + blocked_at: string; + max_severity: string; + findings: BlockedSkillFinding[]; + content_hash: string; + action: "blocked" | "warned"; +} + +export interface SecurityScanErrorResponse { + type: "security_scan_failed"; + detail: string; + skill_name: string; + max_severity: string; + findings: BlockedSkillFinding[]; +} + export const securityApi = { + // ── Tool Guard ────────────────────────────────────────────────── + getToolGuard: () => request("/config/security/tool-guard"), updateToolGuard: (body: ToolGuardConfig) => @@ -31,4 +75,52 @@ export const securityApi = { getBuiltinRules: () => request("/config/security/tool-guard/builtin-rules"), + + // ── Skill Scanner ─────────────────────────────────────────────── + + getSkillScanner: () => + request("/config/security/skill-scanner"), + + updateSkillScanner: (body: SkillScannerConfig) => + request("/config/security/skill-scanner", { + method: "PUT", + body: 
JSON.stringify(body), + }), + + getBlockedHistory: () => + request( + "/config/security/skill-scanner/blocked-history", + ), + + clearBlockedHistory: () => + request<{ cleared: boolean }>( + "/config/security/skill-scanner/blocked-history", + { method: "DELETE" }, + ), + + removeBlockedEntry: (index: number) => + request<{ removed: boolean }>( + `/config/security/skill-scanner/blocked-history/${index}`, + { method: "DELETE" }, + ), + + addToWhitelist: (skillName: string, contentHash: string = "") => + request<{ whitelisted: boolean; skill_name: string }>( + "/config/security/skill-scanner/whitelist", + { + method: "POST", + body: JSON.stringify({ + skill_name: skillName, + content_hash: contentHash, + }), + }, + ), + + removeFromWhitelist: (skillName: string) => + request<{ removed: boolean; skill_name: string }>( + `/config/security/skill-scanner/whitelist/${encodeURIComponent( + skillName, + )}`, + { method: "DELETE" }, + ), }; diff --git a/console/src/locales/en.json b/console/src/locales/en.json index 8ecb826bb..317b54d61 100644 --- a/console/src/locales/en.json +++ b/console/src/locales/en.json @@ -637,8 +637,10 @@ "prompt2": "Can you tell me what skills you have?" }, "security": { - "title": "Tool Guard", - "description": "Configure security scanning for tool calls. Dangerous operations will require your explicit approval before execution.", + "title": "Security", + "description": "Manage security features for tools and skills.", + "toolGuardTitle": "Tool Guard", + "toolGuardDescription": "Configure security scanning for tool calls. 
Dangerous operations will require your explicit approval before execution.", "enabled": "Enable Tool Guard", "enabledTooltip": "When enabled, tool calls are scanned for dangerous patterns before execution", "guardedTools": "Guarded Tools", @@ -695,6 +697,56 @@ "TOOL_CMD_DANGEROUS_RM": "Detects 'rm' command that may cause data loss", "TOOL_CMD_DANGEROUS_MV": "Detects 'mv' command that may move or overwrite files unexpectedly" } + }, + "skillScanner": { + "title": "Skill Scanner", + "description": "Automatically scan skills for security threats before enabling or installing. Unsafe skills can be blocked or whitelisted.", + "mode": "Scanner Mode", + "modeTooltip": "Controls how the scanner handles unsafe skills: block, warn only, or disabled", + "modeBlock": "Block", + "modeWarn": "Warn Only", + "modeOff": "Off", + "timeout": "Scan Timeout (seconds)", + "timeoutTooltip": "Maximum time to wait for a scan to complete (5-300 seconds)", + "saveSuccess": "Skill scanner settings saved", + "saveFailed": "Failed to save skill scanner settings", + "scanAlerts": { + "title": "Scan Alerts", + "empty": "No security alerts", + "clearAll": "Clear All", + "clearConfirm": "Clear all scan alerts?", + "skillName": "Skill", + "action": "Action", + "actionBlocked": "Blocked", + "actionWarned": "Warned", + "time": "Time", + "actions": "Actions", + "allowSkill": "Add to Whitelist", + "remove": "Remove", + "viewFindings": "View Details" + }, + "whitelist": { + "title": "Whitelist", + "empty": "No skills are whitelisted", + "skillName": "Skill", + "addedAt": "Added At", + "contentHash": "Content Hash", + "actions": "Actions", + "remove": "Remove", + "removeConfirm": "Remove this skill from the whitelist?", + "addSuccess": "Skill added to whitelist", + "removeSuccess": "Skill removed from whitelist", + "addFailed": "Failed to add skill to whitelist", + "removeFailed": "Failed to remove skill from whitelist", + "removeWillDisable": "The skill will also be disabled after removal.", + 
"removeAndDisabled": "Skill removed from whitelist and disabled" + }, + "scanError": { + "title": "Security Issues Detected", + "description": "The following security issues were found:", + "warnDescription": "Security issues were found but the skill was still enabled (warn mode):", + "goToWhitelist": "Go to Whitelist" + } } } } diff --git a/console/src/locales/ja.json b/console/src/locales/ja.json index 72b64bf1d..e412e43f7 100644 --- a/console/src/locales/ja.json +++ b/console/src/locales/ja.json @@ -586,8 +586,10 @@ "prompt2": "あなたのスキルを教えてください。" }, "security": { - "title": "ツールガード", - "description": "ツール呼び出しのセキュリティスキャンを設定します。危険な操作は実行前に明示的な承認が必要になります。", + "title": "セキュリティ", + "description": "ツールとスキルのセキュリティ機能を管理します。", + "toolGuardTitle": "ツールガード", + "toolGuardDescription": "ツール呼び出しのセキュリティスキャンを設定します。危険な操作は実行前に明示的な承認が必要になります。", "enabled": "ツールガードを有効にする", "enabledTooltip": "有効にすると、ツール呼び出しが実行前に危険なパターンをスキャンされます", "guardedTools": "保護対象ツール", @@ -644,6 +646,56 @@ "TOOL_CMD_DANGEROUS_RM": "データ損失を引き起こす可能性のある rm コマンドを検出", "TOOL_CMD_DANGEROUS_MV": "ファイルを意図せず移動・上書きする可能性のある mv コマンドを検出" } + }, + "skillScanner": { + "title": "スキルスキャナー", + "description": "スキルを有効化・インストールする前にセキュリティ脅威を自動スキャンします。", + "mode": "スキャンモード", + "modeTooltip": "スキャナーが安全でないスキルをどう処理するかを制御: ブロック、警告のみ、オフ", + "modeBlock": "ブロック", + "modeWarn": "警告のみ", + "modeOff": "オフ", + "timeout": "スキャンタイムアウト(秒)", + "timeoutTooltip": "スキャン完了を待つ最大時間(5-300秒)", + "saveSuccess": "スキルスキャナー設定を保存しました", + "saveFailed": "スキルスキャナー設定の保存に失敗しました", + "scanAlerts": { + "title": "スキャンアラート", + "empty": "セキュリティアラートはありません", + "clearAll": "すべてクリア", + "clearConfirm": "すべてのスキャンアラートをクリアしますか?", + "skillName": "スキル", + "action": "アクション", + "actionBlocked": "ブロック済み", + "actionWarned": "警告済み", + "time": "日時", + "actions": "操作", + "allowSkill": "ホワイトリストに追加", + "remove": "削除", + "viewFindings": "詳細を表示" + }, + "whitelist": { + "title": "ホワイトリスト", + "empty": "ホワイトリストにスキルがありません", + "skillName": "スキル", + "addedAt": "追加日時", + "contentHash": "コンテンツハッシュ", + 
"actions": "操作", + "remove": "削除", + "removeConfirm": "このスキルをホワイトリストから削除しますか?", + "addSuccess": "スキルをホワイトリストに追加しました", + "removeSuccess": "スキルをホワイトリストから削除しました", + "addFailed": "ホワイトリストへの追加に失敗しました", + "removeFailed": "ホワイトリストからの削除に失敗しました", + "removeWillDisable": "削除後、このスキルも無効化されます。", + "removeAndDisabled": "スキルをホワイトリストから削除し、無効化しました" + }, + "scanError": { + "title": "セキュリティ問題が検出されました", + "description": "以下のセキュリティ問題が見つかりました:", + "warnDescription": "セキュリティ問題が見つかりましたが、スキルは引き続き有効です(警告モード):", + "goToWhitelist": "ホワイトリストへ" + } } } } diff --git a/console/src/locales/ru.json b/console/src/locales/ru.json index 0c252d5fb..4c2b8779f 100644 --- a/console/src/locales/ru.json +++ b/console/src/locales/ru.json @@ -591,8 +591,10 @@ "prompt2": "Расскажи мне, какими навыками ты обладаешь?" }, "security": { - "title": "Защита инструментов", - "description": "Настройте сканирование безопасности для вызовов инструментов. Опасные операции потребуют вашего явного подтверждения перед выполнением.", + "title": "Безопасность", + "description": "Управление функциями безопасности инструментов и навыков.", + "toolGuardTitle": "Защита инструментов", + "toolGuardDescription": "Настройте сканирование безопасности для вызовов инструментов. 
Опасные операции потребуют вашего явного подтверждения перед выполнением.", "enabled": "Включить защиту инструментов", "enabledTooltip": "При включении вызовы инструментов сканируются на наличие опасных шаблонов перед выполнением", "guardedTools": "Защищённые инструменты", @@ -649,6 +651,56 @@ "TOOL_CMD_DANGEROUS_RM": "Обнаруживает команду rm, которая может привести к потере данных", "TOOL_CMD_DANGEROUS_MV": "Обнаруживает команду mv, которая может неожиданно переместить или перезаписать файлы" } + }, + "skillScanner": { + "title": "Сканер навыков", + "description": "Автоматическое сканирование навыков на наличие угроз безопасности перед активацией или установкой.", + "mode": "Режим сканирования", + "modeTooltip": "Управление поведением сканера: блокировка, только предупреждение или отключение", + "modeBlock": "Блокировать", + "modeWarn": "Только предупреждение", + "modeOff": "Выключен", + "timeout": "Таймаут сканирования (секунды)", + "timeoutTooltip": "Максимальное время ожидания завершения сканирования (5-300 секунд)", + "saveSuccess": "Настройки сканера навыков сохранены", + "saveFailed": "Не удалось сохранить настройки сканера навыков", + "scanAlerts": { + "title": "Оповещения сканирования", + "empty": "Нет оповещений безопасности", + "clearAll": "Очистить все", + "clearConfirm": "Очистить все оповещения сканирования?", + "skillName": "Навык", + "action": "Действие", + "actionBlocked": "Заблокировано", + "actionWarned": "Предупреждение", + "time": "Время", + "actions": "Действия", + "allowSkill": "Добавить в белый список", + "remove": "Удалить", + "viewFindings": "Подробности" + }, + "whitelist": { + "title": "Белый список", + "empty": "Белый список пуст", + "skillName": "Навык", + "addedAt": "Добавлено", + "contentHash": "Хеш содержимого", + "actions": "Действия", + "remove": "Удалить", + "removeConfirm": "Удалить этот навык из белого списка?", + "addSuccess": "Навык добавлен в белый список", + "removeSuccess": "Навык удалён из белого списка", + "addFailed": 
"Не удалось добавить навык в белый список", + "removeFailed": "Не удалось удалить навык из белого списка", + "removeWillDisable": "После удаления навык также будет отключён.", + "removeAndDisabled": "Навык удалён из белого списка и отключён" + }, + "scanError": { + "title": "Обнаружены проблемы безопасности", + "description": "Обнаружены следующие проблемы безопасности:", + "warnDescription": "Обнаружены проблемы безопасности, но навык был включён (режим предупреждения):", + "goToWhitelist": "Перейти к белому списку" + } } } } diff --git a/console/src/locales/zh.json b/console/src/locales/zh.json index f3d94b0f1..b894ab93b 100644 --- a/console/src/locales/zh.json +++ b/console/src/locales/zh.json @@ -637,8 +637,10 @@ "prompt2": "能告诉我你有哪些技能吗?" }, "security": { - "title": "工具防护", - "description": "配置工具调用的安全扫描。危险操作将在执行前需要你的明确批准。", + "title": "安全", + "description": "管理工具和技能的安全功能。", + "toolGuardTitle": "工具防护", + "toolGuardDescription": "配置工具调用的安全扫描。危险操作将在执行前需要你的明确批准。", "enabled": "启用工具防护", "enabledTooltip": "启用后,工具调用在执行前会被扫描是否包含危险模式", "guardedTools": "受保护的工具", @@ -695,6 +697,56 @@ "TOOL_CMD_DANGEROUS_RM": "检测可能导致数据丢失的 rm 命令", "TOOL_CMD_DANGEROUS_MV": "检测可能意外移动或覆盖文件的 mv 命令" } + }, + "skillScanner": { + "title": "技能扫描器", + "description": "在启用或安装技能前,自动扫描安全威胁。不安全的技能可以被拦截或加入白名单。", + "mode": "扫描模式", + "modeTooltip": "控制扫描器如何处理不安全的技能:拦截、仅提醒或关闭", + "modeBlock": "拦截", + "modeWarn": "仅提醒", + "modeOff": "关闭", + "timeout": "扫描超时(秒)", + "timeoutTooltip": "等待扫描完成的最长时间(5-300秒)", + "saveSuccess": "技能扫描器设置已保存", + "saveFailed": "保存技能扫描器设置失败", + "scanAlerts": { + "title": "扫描告警", + "empty": "暂无安全告警", + "clearAll": "清除全部", + "clearConfirm": "确定清除所有扫描告警吗?", + "skillName": "技能", + "action": "动作", + "actionBlocked": "已拦截", + "actionWarned": "已提醒", + "time": "时间", + "actions": "操作", + "allowSkill": "加入白名单", + "remove": "删除", + "viewFindings": "查看详情" + }, + "whitelist": { + "title": "白名单", + "empty": "暂无白名单技能", + "skillName": "技能", + "addedAt": "添加时间", + "contentHash": "内容哈希", + "actions": 
"操作", + "remove": "移除", + "removeConfirm": "确定将此技能从白名单中移除?", + "addSuccess": "技能已加入白名单", + "removeSuccess": "技能已从白名单移除", + "addFailed": "加入白名单失败", + "removeFailed": "从白名单移除失败", + "removeWillDisable": "移除后该技能将同时被禁用。", + "removeAndDisabled": "技能已从白名单移除并已禁用" + }, + "scanError": { + "title": "检测到安全问题", + "description": "发现以下安全问题:", + "warnDescription": "发现安全问题,但技能仍已启用(仅提醒模式):", + "goToWhitelist": "前往白名单" + } } } } diff --git a/console/src/pages/Agent/Skills/useSkills.ts b/console/src/pages/Agent/Skills/useSkills.ts index 17a3776bf..1f57749ba 100644 --- a/console/src/pages/Agent/Skills/useSkills.ts +++ b/console/src/pages/Agent/Skills/useSkills.ts @@ -1,15 +1,190 @@ -import { useState, useEffect } from "react"; +import { useState, useEffect, useCallback } from "react"; import { message, Modal } from "@agentscope-ai/design"; +import React from "react"; import api from "../../../api"; import type { SkillSpec } from "../../../api/types"; +import type { SecurityScanErrorResponse } from "../../../api/modules/security"; +import { useTranslation } from "react-i18next"; import { useAgentStore } from "../../../stores/agentStore"; +function tryParseScanError(error: unknown): SecurityScanErrorResponse | null { + if (!(error instanceof Error)) return null; + const msg = error.message || ""; + const jsonStart = msg.indexOf("{"); + if (jsonStart === -1) return null; + try { + const parsed = JSON.parse(msg.substring(jsonStart)); + if (parsed?.type === "security_scan_failed") { + return parsed as SecurityScanErrorResponse; + } + } catch { + // not JSON + } + return null; +} + export function useSkills() { + const { t } = useTranslation(); const { selectedAgent } = useAgentStore(); const [skills, setSkills] = useState([]); const [loading, setLoading] = useState(false); const [importing, setImporting] = useState(false); + const showScanErrorModal = useCallback( + (scanError: SecurityScanErrorResponse) => { + const findings = scanError.findings || []; + Modal.error({ + title: 
t("security.skillScanner.scanError.title"), + width: 640, + content: React.createElement( + "div", + null, + React.createElement( + "p", + null, + t("security.skillScanner.scanError.description"), + ), + React.createElement( + "div", + { + style: { + maxHeight: 300, + overflow: "auto", + marginTop: 8, + }, + }, + findings.map((f, i) => + React.createElement( + "div", + { + key: i, + style: { + padding: "8px 12px", + marginBottom: 4, + background: "#fafafa", + borderRadius: 6, + border: "1px solid #f0f0f0", + }, + }, + React.createElement( + "strong", + { style: { marginBottom: 4, display: "block" } }, + f.title, + ), + React.createElement( + "div", + { style: { fontSize: 12, color: "#666" } }, + f.file_path + (f.line_number ? `:${f.line_number}` : ""), + ), + f.description && + React.createElement( + "div", + { + style: { + fontSize: 12, + color: "#999", + marginTop: 2, + }, + }, + f.description, + ), + ), + ), + ), + ), + }); + }, + [t], + ); + + const handleError = useCallback( + (error: unknown, defaultMsg: string): boolean => { + const scanError = tryParseScanError(error); + if (scanError) { + showScanErrorModal(scanError); + return true; + } + console.error(defaultMsg, error); + message.error(defaultMsg); + return false; + }, + [showScanErrorModal], + ); + + const checkScanWarnings = useCallback( + async (skillName: string) => { + try { + const [alerts, scannerCfg] = await Promise.all([ + api.getBlockedHistory(), + api.getSkillScanner(), + ]); + if (!alerts.length) return; + if ( + scannerCfg?.whitelist?.some( + (w: { skill_name: string }) => w.skill_name === skillName, + ) + ) + return; + const latestForSkill = alerts + .filter((a) => a.skill_name === skillName && a.action === "warned") + .pop(); + if (!latestForSkill) return; + const findings = latestForSkill.findings || []; + Modal.warning({ + title: t("security.skillScanner.scanError.title"), + width: 640, + content: React.createElement( + "div", + null, + React.createElement( + "p", + null, + 
t("security.skillScanner.scanError.warnDescription"), + ), + React.createElement( + "div", + { style: { maxHeight: 300, overflow: "auto", marginTop: 8 } }, + findings.map((f, i) => + React.createElement( + "div", + { + key: i, + style: { + padding: "8px 12px", + marginBottom: 4, + background: "#fafafa", + borderRadius: 6, + border: "1px solid #f0f0f0", + }, + }, + React.createElement( + "strong", + { style: { marginBottom: 4, display: "block" } }, + f.title, + ), + React.createElement( + "div", + { style: { fontSize: 12, color: "#666" } }, + f.file_path + (f.line_number ? `:${f.line_number}` : ""), + ), + f.description && + React.createElement( + "div", + { style: { fontSize: 12, color: "#999", marginTop: 2 } }, + f.description, + ), + ), + ), + ), + ), + }); + } catch { + // non-critical + } + }, + [t], + ); + const fetchSkills = async () => { setLoading(true); try { @@ -46,10 +221,10 @@ export function useSkills() { await api.createSkill(name, content); message.success("Created successfully"); await fetchSkills(); + await checkScanWarnings(name); return true; } catch (error) { - console.error("Failed to save skill", error); - message.error("Failed to save"); + handleError(error, "Failed to save"); return false; } }; @@ -73,13 +248,13 @@ export function useSkills() { if (result?.installed) { message.success(`Imported skill: ${result.name}`); await fetchSkills(); + if (result.name) await checkScanWarnings(result.name); return true; } message.error("Import failed"); return false; } catch (error) { - console.error("Failed to import skill from hub", error); - message.error("Import failed"); + handleError(error, "Import failed"); return false; } finally { setImporting(false); @@ -104,11 +279,11 @@ export function useSkills() { ), ); message.success("Enabled successfully"); + await checkScanWarnings(skill.name); } return true; } catch (error) { - console.error("Failed to toggle skill", error); - message.error("Operation failed"); + handleError(error, "Operation 
failed"); return false; } }; diff --git a/console/src/pages/Settings/Security/components/SkillScannerSection.tsx b/console/src/pages/Settings/Security/components/SkillScannerSection.tsx new file mode 100644 index 000000000..ef78229bc --- /dev/null +++ b/console/src/pages/Settings/Security/components/SkillScannerSection.tsx @@ -0,0 +1,439 @@ +import { useState, useCallback } from "react"; +import { + Card, + InputNumber, + Table, + Tag, + Button, + Modal, + message, + Tooltip, + Empty, +} from "@agentscope-ai/design"; +import { Select, Space } from "antd"; +import { Trash2, ShieldCheck, Eye, ShieldOff } from "lucide-react"; +import { useTranslation } from "react-i18next"; +import { useSkillScanner } from "../useSkillScanner"; +import type { + BlockedSkillRecord, + BlockedSkillFinding, + SkillScannerWhitelistEntry, + SkillScannerMode, +} from "../../../../api/modules/security"; +import { skillApi } from "../../../../api/modules/skill"; +import styles from "../index.module.less"; + +function FindingsModal({ + findings, + skillName, + open, + onClose, +}: { + findings: BlockedSkillFinding[]; + skillName: string; + open: boolean; + onClose: () => void; +}) { + const { t } = useTranslation(); + + return ( + +
String(idx)} + pagination={false} + size="small" + columns={[ + { + title: "Title", + dataIndex: "title", + key: "title", + width: 200, + }, + { + title: "File", + key: "location", + width: 160, + render: (_: unknown, record: BlockedSkillFinding) => + record.line_number + ? `${record.file_path}:${record.line_number}` + : record.file_path, + }, + { + title: "Description", + dataIndex: "description", + key: "description", + ellipsis: true, + }, + ]} + /> + + ); +} + +export function SkillScannerSection() { + const { t } = useTranslation(); + const { + config, + blockedHistory, + whitelist, + loading, + updateConfig, + addToWhitelist, + removeFromWhitelist, + removeBlockedEntry, + clearBlockedHistory, + } = useSkillScanner(); + + const [saving, setSaving] = useState(false); + const [findingsModal, setFindingsModal] = useState<{ + open: boolean; + findings: BlockedSkillFinding[]; + skillName: string; + }>({ open: false, findings: [], skillName: "" }); + + const handleModeChange = useCallback( + async (mode: SkillScannerMode) => { + setSaving(true); + const ok = await updateConfig({ mode }); + if (ok) message.success(t("security.skillScanner.saveSuccess")); + else message.error(t("security.skillScanner.saveFailed")); + setSaving(false); + }, + [updateConfig, t], + ); + + const [pendingTimeout, setPendingTimeout] = useState(null); + + const handleTimeoutBlur = useCallback(async () => { + const value = pendingTimeout; + if (value === null || value < 5 || value > 300) { + setPendingTimeout(null); + return; + } + setSaving(true); + const ok = await updateConfig({ timeout: value }); + if (ok) message.success(t("security.skillScanner.saveSuccess")); + else message.error(t("security.skillScanner.saveFailed")); + setPendingTimeout(null); + setSaving(false); + }, [pendingTimeout, updateConfig, t]); + + const handleAllowSkill = useCallback( + async (record: BlockedSkillRecord, index: number) => { + const ok = await addToWhitelist(record.skill_name, record.content_hash); + if (ok) 
{ + message.success(t("security.skillScanner.whitelist.addSuccess")); + await removeBlockedEntry(index); + } else { + message.error(t("security.skillScanner.whitelist.addFailed")); + } + }, + [addToWhitelist, removeBlockedEntry, t], + ); + + const handleRemoveWhitelist = useCallback( + async (skillName: string) => { + Modal.confirm({ + title: t("security.skillScanner.whitelist.removeConfirm"), + content: t("security.skillScanner.whitelist.removeWillDisable"), + onOk: async () => { + const ok = await removeFromWhitelist(skillName); + if (!ok) { + message.error(t("security.skillScanner.whitelist.removeFailed")); + return; + } + try { + await skillApi.disableSkill(skillName); + message.success( + t("security.skillScanner.whitelist.removeAndDisabled"), + ); + } catch { + message.success(t("security.skillScanner.whitelist.removeSuccess")); + } + }, + }); + }, + [removeFromWhitelist, t], + ); + + const handleClearHistory = useCallback(() => { + Modal.confirm({ + title: t("security.skillScanner.scanAlerts.clearConfirm"), + onOk: async () => { + await clearBlockedHistory(); + }, + }); + }, [clearBlockedHistory, t]); + + if (loading || !config) return null; + + const enabled = config.mode !== "off"; + + const blockedColumns = [ + { + title: t("security.skillScanner.scanAlerts.skillName"), + dataIndex: "skill_name", + key: "skill_name", + width: 180, + }, + { + title: t("security.skillScanner.scanAlerts.action"), + dataIndex: "action", + key: "action", + width: 100, + render: (action: string) => ( + + {action === "blocked" + ? 
t("security.skillScanner.scanAlerts.actionBlocked") + : t("security.skillScanner.scanAlerts.actionWarned")} + + ), + }, + { + title: t("security.skillScanner.scanAlerts.time"), + dataIndex: "blocked_at", + key: "blocked_at", + width: 180, + render: (val: string) => { + try { + return new Date(val).toLocaleString(); + } catch { + return val; + } + }, + }, + { + title: t("security.skillScanner.scanAlerts.actions"), + key: "actions", + width: 200, + render: (_: unknown, record: BlockedSkillRecord, index: number) => ( + + +
String(idx)} + pagination={false} + size="small" + /> + )} + + + {/* Whitelist */} +
+

+ {t("security.skillScanner.whitelist.title")} +

+
+ + + {whitelist.length === 0 ? ( +
+ +
+ ) : ( +
+ )} + + + + setFindingsModal({ open: false, findings: [], skillName: "" }) + } + /> + + ); +} diff --git a/console/src/pages/Settings/Security/components/index.ts b/console/src/pages/Settings/Security/components/index.ts index ee4af737f..3262bffc0 100644 --- a/console/src/pages/Settings/Security/components/index.ts +++ b/console/src/pages/Settings/Security/components/index.ts @@ -2,3 +2,4 @@ export * from "./PageHeader"; export * from "./RuleTable"; export * from "./RuleModal"; export * from "./PreviewModal"; +export * from "./SkillScannerSection"; diff --git a/console/src/pages/Settings/Security/index.tsx b/console/src/pages/Settings/Security/index.tsx index 7a8e37703..058e6112b 100644 --- a/console/src/pages/Settings/Security/index.tsx +++ b/console/src/pages/Settings/Security/index.tsx @@ -11,7 +11,13 @@ import { PlusCircleOutlined } from "@ant-design/icons"; import { useTranslation } from "react-i18next"; import api from "../../../api"; import { useToolGuard, type MergedRule } from "./useToolGuard"; -import { PageHeader, RuleTable, RuleModal, PreviewModal } from "./components"; +import { + PageHeader, + RuleTable, + RuleModal, + PreviewModal, + SkillScannerSection, +} from "./components"; import styles from "./index.module.less"; const BUILTIN_TOOLS = [ @@ -204,6 +210,17 @@ function SecurityPage() {
+
+
+

+ {t("security.toolGuardTitle")} +

+

+ {t("security.toolGuardDescription")} +

+
+
+
+ + (null); + const [blockedHistory, setBlockedHistory] = useState( + [], + ); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + const fetchAll = useCallback(async () => { + setLoading(true); + setError(null); + try { + const [cfg, history] = await Promise.all([ + api.getSkillScanner(), + api.getBlockedHistory(), + ]); + setConfig(cfg); + setBlockedHistory(history); + } catch (err) { + const msg = + err instanceof Error + ? err.message + : "Failed to load skill scanner config"; + console.error("Failed to load skill scanner:", err); + setError(msg); + } finally { + setLoading(false); + } + }, []); + + useEffect(() => { + fetchAll(); + }, [fetchAll]); + + const updateConfig = useCallback( + async (updates: Partial) => { + if (!config) return; + const newConfig = { ...config, ...updates }; + try { + const saved = await api.updateSkillScanner(newConfig); + setConfig(saved); + return true; + } catch (err) { + console.error("Failed to update skill scanner config:", err); + return false; + } + }, + [config], + ); + + const addToWhitelist = useCallback( + async (skillName: string, contentHash: string = "") => { + try { + await api.addToWhitelist(skillName, contentHash); + await fetchAll(); + return true; + } catch (err) { + console.error("Failed to add to whitelist:", err); + return false; + } + }, + [fetchAll], + ); + + const removeFromWhitelist = useCallback( + async (skillName: string) => { + try { + await api.removeFromWhitelist(skillName); + await fetchAll(); + return true; + } catch (err) { + console.error("Failed to remove from whitelist:", err); + return false; + } + }, + [fetchAll], + ); + + const removeBlockedEntry = useCallback( + async (index: number) => { + try { + await api.removeBlockedEntry(index); + await fetchAll(); + return true; + } catch (err) { + console.error("Failed to remove blocked entry:", err); + return false; + } + }, + [fetchAll], + ); + + const clearBlockedHistory = useCallback(async () => { + try 
{ + await api.clearBlockedHistory(); + setBlockedHistory([]); + return true; + } catch (err) { + console.error("Failed to clear blocked history:", err); + return false; + } + }, []); + + const whitelist: SkillScannerWhitelistEntry[] = config?.whitelist ?? []; + + return { + config, + blockedHistory, + whitelist, + loading, + error, + fetchAll, + updateConfig, + addToWhitelist, + removeFromWhitelist, + removeBlockedEntry, + clearBlockedHistory, + }; +} diff --git a/pyproject.toml b/pyproject.toml index 8e124d0ac..d087a2487 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "shortuuid>=1.0.0", "google-genai>=1.67.0", "tzdata>=2024.1", + "pyyaml>=6.0", ] [tool.setuptools.dynamic] @@ -49,6 +50,8 @@ include-package-data = true "agents/skills/**", "tokenizer/**", "security/tool_guard/rules/**", + "security/skill_scanner/rules/**", + "security/skill_scanner/data/**", ] [build-system] diff --git a/src/copaw/agents/skills_manager.py b/src/copaw/agents/skills_manager.py index db8b9b67b..7fd4877ad 100644 --- a/src/copaw/agents/skills_manager.py +++ b/src/copaw/agents/skills_manager.py @@ -554,6 +554,11 @@ def __init__(self, workspace_dir: Path): """ self.workspace_dir = workspace_dir + def get_customized_skill_dir(self, name: str) -> Path | None: + """Return the Path to a skill inside customized_skills, or None.""" + skill_dir = get_customized_skills_dir(self.workspace_dir) / name + return skill_dir if skill_dir.exists() else None + def list_all_skills(self) -> list[SkillInfo]: """ List all skills from builtin and customized directories. 
@@ -734,9 +739,31 @@ def create_skill( name, ) + # --- Security scan (post-write) ---------------------------------- + try: + from ..security.skill_scanner import ( + SkillScanError, + scan_skill_directory, + ) + + scan_skill_directory(skill_dir, skill_name=name) + except SkillScanError: + raise + except Exception as scan_exc: + logger.warning( + "Security scan error for skill '%s' (non-fatal): %s", + name, + scan_exc, + ) + # --------------------------------------------------------------- + logger.debug("Created skill '%s' in customized_skills.", name) return True except Exception as e: + from ..security.skill_scanner import SkillScanError + + if isinstance(e, SkillScanError): + raise logger.error( "Failed to create skill '%s': %s", name, @@ -780,6 +807,10 @@ def enable_skill(self, name: str, force: bool = False) -> bool: """ Enable a skill by syncing it to active_skills directory. + Before syncing the skill runs through a security scan. + Blocking behaviour is controlled by the scanner mode in + config (``security.skill_scanner.mode``). + Args: name: Skill name to enable. force: If True, overwrite existing skill in active_skills. @@ -787,6 +818,31 @@ def enable_skill(self, name: str, force: bool = False) -> bool: Returns: True if skill was enabled successfully, False otherwise. 
""" + # --- Security scan (pre-activation) -------------------------------- + try: + from ..security.skill_scanner import ( + SkillScanError, + scan_skill_directory, + ) + + source_dir = self.get_customized_skill_dir(name) + if source_dir is None: + builtin = get_builtin_skills_dir() / name + if builtin.is_dir(): + source_dir = builtin + + if source_dir is not None: + scan_skill_directory(source_dir, skill_name=name) + except SkillScanError: + raise + except Exception as scan_exc: + logger.warning( + "Security scan error for skill '%s' (non-fatal): %s", + name, + scan_exc, + ) + # ------------------------------------------------------------------- + sync_skills_to_working_dir( self.workspace_dir, skill_names=[name], diff --git a/src/copaw/app/routers/config.py b/src/copaw/app/routers/config.py index cf86f460f..04208d89c 100644 --- a/src/copaw/app/routers/config.py +++ b/src/copaw/app/routers/config.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- +from datetime import datetime, timezone from typing import Any, List from fastapi import APIRouter, Body, HTTPException, Path, Request +from pydantic import BaseModel from ...config import ( load_config, @@ -26,6 +28,8 @@ MattermostConfig, MQTTConfig, QQConfig, + SkillScannerConfig, + SkillScannerWhitelistEntry, TelegramConfig, VoiceChannelConfig, ) @@ -422,3 +426,127 @@ async def get_builtin_rules() -> List[ToolGuardRuleConfig]: ) for r in rules ] + + +# ── Security / Skill Scanner ──────────────────────────────────────── + + +@router.get( + "/security/skill-scanner", + response_model=SkillScannerConfig, + summary="Get skill scanner settings", +) +async def get_skill_scanner() -> SkillScannerConfig: + config = load_config() + return config.security.skill_scanner + + +@router.put( + "/security/skill-scanner", + response_model=SkillScannerConfig, + summary="Update skill scanner settings", +) +async def put_skill_scanner( + body: SkillScannerConfig = Body(...), +) -> SkillScannerConfig: + config = load_config() + 
config.security.skill_scanner = body + save_config(config) + return body + + +@router.get( + "/security/skill-scanner/blocked-history", + summary="Get blocked skills history", +) +async def get_blocked_history() -> list: + from ...security.skill_scanner import get_blocked_history as _get_history + + records = _get_history() + return [r.to_dict() for r in records] + + +@router.delete( + "/security/skill-scanner/blocked-history", + summary="Clear all blocked skills history", +) +async def delete_blocked_history() -> dict: + from ...security.skill_scanner import clear_blocked_history + + clear_blocked_history() + return {"cleared": True} + + +@router.delete( + "/security/skill-scanner/blocked-history/{index}", + summary="Remove a single blocked history entry", +) +async def delete_blocked_entry( + index: int = Path(..., ge=0), +) -> dict: + from ...security.skill_scanner import remove_blocked_entry + + ok = remove_blocked_entry(index) + if not ok: + raise HTTPException(status_code=404, detail="Entry not found") + return {"removed": True} + + +class WhitelistAddRequest(BaseModel): + skill_name: str + content_hash: str = "" + + +@router.post( + "/security/skill-scanner/whitelist", + summary="Add a skill to the whitelist", +) +async def add_to_whitelist( + body: WhitelistAddRequest = Body(...), +) -> dict: + skill_name = body.skill_name.strip() + content_hash = body.content_hash + if not skill_name: + raise HTTPException(status_code=400, detail="skill_name is required") + + config = load_config() + scanner_cfg = config.security.skill_scanner + + for entry in scanner_cfg.whitelist: + if entry.skill_name == skill_name: + raise HTTPException( + status_code=409, + detail=f"Skill '{skill_name}' is already whitelisted", + ) + + scanner_cfg.whitelist.append( + SkillScannerWhitelistEntry( + skill_name=skill_name, + content_hash=content_hash, + added_at=datetime.now(timezone.utc).isoformat(), + ), + ) + save_config(config) + return {"whitelisted": True, "skill_name": skill_name} 
+ + +@router.delete( + "/security/skill-scanner/whitelist/{skill_name}", + summary="Remove a skill from the whitelist", +) +async def remove_from_whitelist( + skill_name: str = Path(..., min_length=1), +) -> dict: + config = load_config() + scanner_cfg = config.security.skill_scanner + original_len = len(scanner_cfg.whitelist) + scanner_cfg.whitelist = [ + e for e in scanner_cfg.whitelist if e.skill_name != skill_name + ] + if len(scanner_cfg.whitelist) == original_len: + raise HTTPException( + status_code=404, + detail=f"Skill '{skill_name}' not found in whitelist", + ) + save_config(config) + return {"removed": True, "skill_name": skill_name} diff --git a/src/copaw/app/routers/skills.py b/src/copaw/app/routers/skills.py index a4f27d4b5..1c6b7b141 100644 --- a/src/copaw/app/routers/skills.py +++ b/src/copaw/app/routers/skills.py @@ -3,6 +3,7 @@ from typing import Any from pathlib import Path from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse from pydantic import BaseModel, Field from ...agents.skills_manager import ( SkillService, @@ -12,11 +13,37 @@ search_hub_skills, install_skill_from_hub, ) +from ...security.skill_scanner import SkillScanError logger = logging.getLogger(__name__) +def _scan_error_response(exc: SkillScanError) -> JSONResponse: + """Build a 422 response with structured scan findings.""" + result = exc.result + return JSONResponse( + status_code=422, + content={ + "type": "security_scan_failed", + "detail": str(exc), + "skill_name": result.skill_name, + "max_severity": result.max_severity.value, + "findings": [ + { + "severity": f.severity.value, + "title": f.title, + "description": f.description, + "file_path": f.file_path, + "line_number": f.line_number, + "rule_id": f.rule_id, + } + for f in result.findings + ], + }, + ) + + class SkillSpec(SkillInfo): enabled: bool = False @@ -177,6 +204,8 @@ async def install_from_hub( enable=request_body.enable, overwrite=request_body.overwrite, ) + except 
SkillScanError as e: + return _scan_error_response(e) except ValueError as e: detail = str(e) logger.warning( @@ -186,7 +215,6 @@ async def install_from_hub( ) raise HTTPException(status_code=400, detail=detail) from e except RuntimeError as e: - # Upstream hub is flaky/rate-limited sometimes; surface as bad gateway. detail = str(e) + _github_token_hint(request_body.bundle_url) logger.exception( "Skill hub install failed (upstream/rate limit): %s", @@ -226,15 +254,36 @@ async def batch_disable_skills( async def batch_enable_skills( skill_name: list[str], request: Request, -) -> None: +): from ..agent_context import get_agent_for_request workspace = await get_agent_for_request(request) workspace_dir = Path(workspace.workspace_dir) skill_service = SkillService(workspace_dir) + blocked: list[dict] = [] for skill in skill_name: - skill_service.enable_skill(skill) + try: + skill_service.enable_skill(skill) + except SkillScanError as e: + blocked.append( + { + "skill_name": skill, + "max_severity": e.result.max_severity.value, + "detail": str(e), + }, + ) + if blocked: + return JSONResponse( + status_code=422, + content={ + "type": "security_scan_failed", + "detail": ( + f"{len(blocked)} skill(s) blocked by security scan" + ), + "blocked_skills": blocked, + }, + ) @router.post("") @@ -248,12 +297,15 @@ async def create_skill( workspace_dir = Path(workspace.workspace_dir) skill_service = SkillService(workspace_dir) - result = skill_service.create_skill( - name=request_body.name, - content=request_body.content, - references=request_body.references, - scripts=request_body.scripts, - ) + try: + result = skill_service.create_skill( + name=request_body.name, + content=request_body.content, + references=request_body.references, + scripts=request_body.scripts, + ) + except SkillScanError as e: + return _scan_error_response(e) return {"created": result} @@ -324,6 +376,21 @@ async def enable_skill( detail=f"Skill '{skill_name}' not found", ) + # --- Security scan (pre-activation) 
-------------------------------- + try: + from ...security.skill_scanner import scan_skill_directory + + scan_skill_directory(source_dir, skill_name=skill_name) + except SkillScanError as e: + return _scan_error_response(e) + except Exception as scan_exc: + logger.warning( + "Security scan error for skill '%s' (non-fatal): %s", + skill_name, + scan_exc, + ) + # ------------------------------------------------------------------- + # Copy to active_skills shutil.copytree(source_dir, active_skill_dir) diff --git a/src/copaw/config/config.py b/src/copaw/config/config.py index a98a1728e..77a4d2c82 100644 --- a/src/copaw/config/config.py +++ b/src/copaw/config/config.py @@ -613,10 +613,53 @@ class ToolGuardConfig(BaseModel): disabled_rules: List[str] = Field(default_factory=list) +class SkillScannerWhitelistEntry(BaseModel): + """A whitelisted skill (identified by name + content hash).""" + + skill_name: str + content_hash: str = Field( + default="", + description="SHA-256 of concatenated file contents at whitelist time. " + "Empty string means any content is allowed.", + ) + added_at: str = Field( + default="", + description="ISO 8601 timestamp when the entry was added.", + ) + + +class SkillScannerConfig(BaseModel): + """Skill scanner settings under ``security.skill_scanner``. + + ``mode`` controls the scanner behavior: + * ``"block"`` – scan and block unsafe skills. + * ``"warn"`` – scan but only log warnings, do not block (default). + * ``"off"`` – disable scanning entirely. 
+ """ + + mode: Literal["block", "warn", "off"] = Field( + default="warn", + description="Scanner mode: block, warn, or off.", + ) + timeout: int = Field( + default=30, + ge=5, + le=300, + description="Max seconds to wait for a scan to complete.", + ) + whitelist: List[SkillScannerWhitelistEntry] = Field( + default_factory=list, + description="Skills that bypass security scanning.", + ) + + class SecurityConfig(BaseModel): """Top-level ``security`` section in config.json.""" tool_guard: ToolGuardConfig = Field(default_factory=ToolGuardConfig) + skill_scanner: SkillScannerConfig = Field( + default_factory=SkillScannerConfig, + ) class Config(BaseModel): diff --git a/src/copaw/security/__init__.py b/src/copaw/security/__init__.py index d2367e231..61e1d0674 100644 --- a/src/copaw/security/__init__.py +++ b/src/copaw/security/__init__.py @@ -7,6 +7,8 @@ * **Tool-call guarding** (``copaw.security.tool_guard``) Pre-execution parameter scanning to detect dangerous tool usage patterns (command injection, data exfiltration, etc.). +* **Skill scanning** (``copaw.security.skill_scanner``) + Static analysis of skill directories before install / activation. Sub-modules are kept independent so each concern can evolve (or be disabled) without affecting the others. Import-time cost is near-zero diff --git a/src/copaw/security/skill_scanner/__init__.py b/src/copaw/security/skill_scanner/__init__.py new file mode 100644 index 000000000..f27752beb --- /dev/null +++ b/src/copaw/security/skill_scanner/__init__.py @@ -0,0 +1,507 @@ +# -*- coding: utf-8 -*- +""" +Skill security scanner for CoPaw. + +Scans skills for security threats before they are activated or installed. + +Architecture +~~~~~~~~~~~~ + +The scanner follows a lightweight, extensible design: + +* **BaseAnalyzer** - abstract interface every analyzer must implement. +* **PatternAnalyzer** - YAML regex-signature matching (fast, line-based). 
+* **SkillScanner** - orchestrator that runs registered analyzers and + aggregates findings into a :class:`ScanResult`. + +This branch intentionally ships the baseline pattern analyzer only. +Additional analyzers can be plugged in later without changing the +orchestrator. + +Quick start:: + + from copaw.security.skill_scanner import SkillScanner + + scanner = SkillScanner() + result = scanner.scan_skill("/path/to/skill_directory") + if not result.is_safe: + print(f"Blocked: {result.max_severity.value} findings detected") +""" +from __future__ import annotations + +import hashlib +import json +import logging +import os +import threading +from concurrent.futures import ( + ThreadPoolExecutor, + TimeoutError as FuturesTimeout, +) +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from .models import ( + Finding, + ScanResult, + Severity, + SkillFile, + ThreatCategory, +) +from .scan_policy import ScanPolicy +from .analyzers import BaseAnalyzer +from .analyzers.pattern_analyzer import PatternAnalyzer +from .scanner import SkillScanner + +logger = logging.getLogger(__name__) + +__all__ = [ + "BaseAnalyzer", + "BlockedSkillRecord", + "Finding", + "PatternAnalyzer", + "ScanPolicy", + "ScanResult", + "Severity", + "SkillFile", + "SkillScanner", + "SkillScanError", + "ThreatCategory", + "compute_skill_content_hash", + "get_blocked_history", + "clear_blocked_history", + "remove_blocked_entry", + "is_skill_whitelisted", + "scan_skill_directory", +] + +# --------------------------------------------------------------------------- +# Config helpers +# --------------------------------------------------------------------------- + + +_VALID_MODES = {"block", "warn", "off"} + + +def _load_scanner_config() -> Any: + """Load SkillScannerConfig from the app config (lazy import).""" + try: + from ...config import load_config + + return load_config().security.skill_scanner + except Exception: + return 
None + + +def _get_scan_mode(cfg: Any = None) -> str: + """Return the effective scan mode: ``block``, ``warn``, or ``off``. + + Priority: env ``COPAW_SKILL_SCAN_MODE`` > config > default ``warn``. + """ + env = os.environ.get("COPAW_SKILL_SCAN_MODE") + if env is not None: + val = env.lower().strip() + if val in _VALID_MODES: + return val + if cfg is None: + cfg = _load_scanner_config() + return cfg.mode if cfg is not None else "block" + + +def _scan_timeout(cfg: Any = None) -> float: + if cfg is None: + cfg = _load_scanner_config() + return float(cfg.timeout) if cfg is not None else 30.0 + + +# --------------------------------------------------------------------------- +# Content hash +# --------------------------------------------------------------------------- + + +def compute_skill_content_hash(skill_dir: Path) -> str: + """SHA-256 hash of all regular file contents in *skill_dir* (sorted).""" + h = hashlib.sha256() + try: + for p in sorted(skill_dir.rglob("*")): + if p.is_file() and not p.is_symlink(): + try: + h.update(p.read_bytes()) + except OSError: + pass + except OSError: + pass + return h.hexdigest() + + +# --------------------------------------------------------------------------- +# Whitelist helpers +# --------------------------------------------------------------------------- + + +def is_skill_whitelisted( + skill_name: str, + skill_dir: Path | None = None, + *, + cfg: Any = None, +) -> bool: + """Return True if *skill_name* is on the whitelist. + + When a whitelist entry has a non-empty ``content_hash``, the hash must + match the current directory contents for the entry to apply. 
+ """ + if cfg is None: + cfg = _load_scanner_config() + if cfg is None: + return False + for entry in cfg.whitelist: + if entry.skill_name != skill_name: + continue + if not entry.content_hash: + return True + if skill_dir is not None: + current_hash = compute_skill_content_hash(skill_dir) + if current_hash == entry.content_hash: + return True + else: + return True + return False + + +# --------------------------------------------------------------------------- +# Blocked history persistence +# --------------------------------------------------------------------------- + +_BLOCKED_HISTORY_FILE = "skill_scanner_blocked.json" +_history_lock = threading.Lock() + + +def _get_blocked_history_path() -> Path: + try: + from ...constant import WORKING_DIR + + return WORKING_DIR / _BLOCKED_HISTORY_FILE + except Exception: + return Path.home() / ".copaw" / _BLOCKED_HISTORY_FILE + + +@dataclass +class BlockedSkillRecord: + """A record of a scan alert (blocked or warned).""" + + skill_name: str + blocked_at: str + max_severity: str + findings: list[dict[str, Any]] = field(default_factory=list) + content_hash: str = "" + action: str = "blocked" + + def to_dict(self) -> dict[str, Any]: + return { + "skill_name": self.skill_name, + "blocked_at": self.blocked_at, + "max_severity": self.max_severity, + "findings": self.findings, + "content_hash": self.content_hash, + "action": self.action, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> BlockedSkillRecord: + return cls( + skill_name=data.get("skill_name", ""), + blocked_at=data.get("blocked_at", ""), + max_severity=data.get("max_severity", ""), + findings=data.get("findings", []), + content_hash=data.get("content_hash", ""), + action=data.get("action", "blocked"), + ) + + +def _finding_to_dict(f: Finding) -> dict[str, Any]: + return { + "severity": f.severity.value, + "title": f.title, + "description": f.description, + "file_path": f.file_path, + "line_number": f.line_number, + "rule_id": f.rule_id, + } + + +def 
_record_blocked_skill( + result: ScanResult, + skill_dir: Path, + *, + action: str = "blocked", +) -> None: + """Append a scan alert to the history file.""" + record = BlockedSkillRecord( + skill_name=result.skill_name, + blocked_at=datetime.now(timezone.utc).isoformat(), + max_severity=result.max_severity.value, + findings=[_finding_to_dict(f) for f in result.findings], + content_hash=compute_skill_content_hash(skill_dir), + action=action, + ) + path = _get_blocked_history_path() + with _history_lock: + try: + existing: list[dict[str, Any]] = [] + if path.is_file(): + existing = json.loads(path.read_text(encoding="utf-8")) + existing.append(record.to_dict()) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(existing, indent=2, ensure_ascii=False), + encoding="utf-8", + ) + except Exception as exc: + logger.warning("Failed to record blocked skill: %s", exc) + + +def get_blocked_history() -> list[BlockedSkillRecord]: + """Load all blocked skill records from disk.""" + path = _get_blocked_history_path() + if not path.is_file(): + return [] + try: + data = json.loads(path.read_text(encoding="utf-8")) + return [BlockedSkillRecord.from_dict(d) for d in data] + except Exception as exc: + logger.warning("Failed to load blocked history: %s", exc) + return [] + + +def clear_blocked_history() -> None: + """Delete all blocked skill records.""" + path = _get_blocked_history_path() + try: + if path.is_file(): + path.unlink() + except OSError as exc: + logger.warning("Failed to clear blocked history: %s", exc) + + +def remove_blocked_entry(index: int) -> bool: + """Remove a single blocked record by index. 
Returns True on success.""" + path = _get_blocked_history_path() + if not path.is_file(): + return False + try: + data = json.loads(path.read_text(encoding="utf-8")) + if 0 <= index < len(data): + data.pop(index) + path.write_text( + json.dumps(data, indent=2, ensure_ascii=False), + encoding="utf-8", + ) + return True + return False + except Exception as exc: + logger.warning("Failed to remove blocked entry: %s", exc) + return False + + +# --------------------------------------------------------------------------- +# Lazy singleton (thread-safe) +# --------------------------------------------------------------------------- + +_scanner_instance: SkillScanner | None = None +_scanner_lock = threading.Lock() + + +def _get_scanner() -> SkillScanner: + """Return a lazily-initialised :class:`SkillScanner` singleton.""" + global _scanner_instance + if _scanner_instance is None: + with _scanner_lock: + if _scanner_instance is None: + _scanner_instance = SkillScanner() + return _scanner_instance + + +# --------------------------------------------------------------------------- +# Scan result cache (mtime-based) +# --------------------------------------------------------------------------- + +_MAX_CACHE_ENTRIES = 64 +_scan_cache: dict[str, tuple[float, ScanResult]] = {} +_cache_lock = threading.Lock() + + +def _get_dir_mtime(skill_dir: Path) -> float: + """Return the latest mtime among the directory and its immediate files.""" + try: + latest = skill_dir.stat().st_mtime + except OSError: + return 0.0 + try: + for p in skill_dir.iterdir(): + if p.is_file() and not p.is_symlink(): + latest = max(latest, p.stat().st_mtime) + except OSError: + pass + return latest + + +def _get_cached_result( + skill_dir: Path, +) -> ScanResult | None: + """Return a cached ScanResult if the directory hasn't changed.""" + key = str(skill_dir) + with _cache_lock: + entry = _scan_cache.get(key) + if entry is None: + return None + cached_mtime, cached_result = entry + current_mtime = 
_get_dir_mtime(skill_dir) + if current_mtime == cached_mtime: + logger.debug( + "Returning cached scan result for '%s'", + cached_result.skill_name, + ) + return cached_result + return None + + +def _store_cached_result( + skill_dir: Path, + result: ScanResult, +) -> None: + """Store a scan result in the cache (LRU eviction).""" + key = str(skill_dir) + mtime = _get_dir_mtime(skill_dir) + with _cache_lock: + _scan_cache.pop(key, None) + _scan_cache[key] = (mtime, result) + while len(_scan_cache) > _MAX_CACHE_ENTRIES: + oldest = next(iter(_scan_cache)) + del _scan_cache[oldest] + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def _format_finding_location(f: Finding) -> str: + if f.line_number is not None: + return f"({f.file_path}:{f.line_number})" + return f"({f.file_path})" + + +class SkillScanError(Exception): + """Raised when a skill fails a security scan and blocking is enabled.""" + + def __init__(self, result: ScanResult) -> None: + self.result = result + findings_summary = "; ".join( + f"[{f.severity.value}] {f.title} " f"{_format_finding_location(f)}" + for f in result.findings[:5] + ) + truncated = ( + f" (and {len(result.findings) - 5} more)" + if len(result.findings) > 5 + else "" + ) + super().__init__( + f"Security scan of skill '{result.skill_name}' found " + f"{len(result.findings)} issue(s) " + f"(max severity: {result.max_severity.value}): " + f"{findings_summary}{truncated}", + ) + + +def scan_skill_directory( + skill_dir: str | Path, + *, + skill_name: str | None = None, + block: bool | None = None, + timeout: float | None = None, +) -> ScanResult | None: + """Scan a skill directory and optionally block on unsafe results. + + Parameters + ---------- + skill_dir: + Path to the skill directory to scan. + skill_name: + Human-readable name (falls back to directory name). 
+ block: + Whether to raise :class:`SkillScanError` when the scan finds + CRITICAL/HIGH issues. *None* means use the configured mode + (``block`` mode → True, ``warn`` mode → False). + timeout: + Maximum seconds to wait for the scan to complete before + giving up and returning ``None``. *None* reads from config. + + Returns + ------- + ScanResult or None + ``None`` when scanning is disabled, whitelisted, or timed out. + + Raises + ------ + SkillScanError + When blocking is enabled and the skill is deemed unsafe. + """ + cfg = _load_scanner_config() + mode = _get_scan_mode(cfg) + if mode == "off": + return None + + resolved = Path(skill_dir).resolve() + effective_name = skill_name or resolved.name + + if is_skill_whitelisted(effective_name, resolved, cfg=cfg): + logger.debug( + "Skill '%s' is whitelisted, skipping scan", + effective_name, + ) + return None + + effective_timeout = timeout if timeout is not None else _scan_timeout(cfg) + + cached = _get_cached_result(resolved) + if cached is not None: + result = cached + else: + scanner = _get_scanner() + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit( + scanner.scan_skill, + resolved, + skill_name=skill_name, + ) + try: + result = future.result(timeout=effective_timeout) + except FuturesTimeout: + logger.warning( + "Security scan of skill '%s' timed out after %.0fs", + effective_name, + effective_timeout, + ) + future.cancel() + return None + + _store_cached_result(resolved, result) + + if not result.is_safe: + should_block = block if block is not None else (mode == "block") + if should_block: + _record_blocked_skill(result, resolved, action="blocked") + raise SkillScanError(result) + _record_blocked_skill(result, resolved, action="warned") + logger.warning( + "Skill '%s' has %d security finding(s) (max severity: %s) " + "but blocking is disabled – proceeding anyway.", + result.skill_name, + len(result.findings), + result.max_severity.value, + ) + + return result diff --git 
a/src/copaw/security/skill_scanner/analyzers/__init__.py b/src/copaw/security/skill_scanner/analyzers/__init__.py new file mode 100644 index 000000000..8096c3712 --- /dev/null +++ b/src/copaw/security/skill_scanner/analyzers/__init__.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +"""Abstract base class for all security analyzers. + +Every analyzer must subclass :class:`BaseAnalyzer` and implement +:meth:`analyze`. The interface is intentionally minimal so that +new detection engines (e.g. LLM-based, behavioural dataflow) can be +added as drop-in plugins without touching the scanner orchestrator. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING + +from ..models import Finding, SkillFile + +if TYPE_CHECKING: + from ..scan_policy import ScanPolicy + + +class BaseAnalyzer(ABC): + """Abstract base class for all security analyzers. + + Parameters + ---------- + name: + Human-readable analyzer name (used in :attr:`Finding.analyzer`). + policy: + Optional :class:`ScanPolicy` for org-specific rule scoping, + severity overrides, and allowlists. When *None*, analysers + use their own built-in defaults. + """ + + def __init__( + self, + name: str, + *, + policy: "ScanPolicy | None" = None, + ) -> None: + self.name = name + # Lazily import to avoid circular dependencies + if policy is None: + from ..scan_policy import ScanPolicy + + policy = ScanPolicy.default() + self._policy = policy + + @property + def policy(self) -> "ScanPolicy": + """The active scan policy.""" + return self._policy + + # ------------------------------------------------------------------ + # Abstract interface + # ------------------------------------------------------------------ + + @abstractmethod + def analyze( + self, + skill_dir: Path, + files: list[SkillFile], + *, + skill_name: str | None = None, + ) -> list[Finding]: + """Analyze a skill package for security issues. 
+ + Parameters + ---------- + skill_dir: + Root directory of the skill. + files: + Pre-discovered list of :class:`SkillFile` objects belonging + to the skill. + skill_name: + Optional skill name for richer finding messages. + + Returns + ------- + list[Finding] + Findings discovered by this analyzer. + """ + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def get_name(self) -> str: # noqa: D401 + """Analyzer name.""" + return self.name diff --git a/src/copaw/security/skill_scanner/analyzers/pattern_analyzer.py b/src/copaw/security/skill_scanner/analyzers/pattern_analyzer.py new file mode 100644 index 000000000..8df88af54 --- /dev/null +++ b/src/copaw/security/skill_scanner/analyzers/pattern_analyzer.py @@ -0,0 +1,392 @@ +# -*- coding: utf-8 -*- +"""YAML-signature pattern matching analyzer. + +Loads security rules from YAML files (see ``rules/signatures/``) and +performs fast, line-based regex matching with a multiline fallback for +patterns that intentionally span newlines. +""" +from __future__ import annotations + +import logging +import re +from pathlib import Path +from typing import Any + +import yaml + +from ..models import Finding, Severity, SkillFile, ThreatCategory +from ..scan_policy import ScanPolicy +from . import BaseAnalyzer + +logger = logging.getLogger(__name__) + +# Matches character-class contents so we can tell whether ``\n`` in a +# pattern is genuinely multiline vs. ``[^\n]``. +_CHAR_CLASS_RE = re.compile(r"\[[^\]]*\]") + +# Default signatures directory (shipped with the package). 
+_DEFAULT_SIGNATURES_DIR = ( + Path(__file__).resolve().parent.parent / "rules" / "signatures" +) + + +# --------------------------------------------------------------------------- +# SecurityRule – one YAML rule entry +# --------------------------------------------------------------------------- + + +class SecurityRule: + """A single regex-based security detection rule.""" + + __slots__ = ( + "id", + "category", + "severity", + "patterns", + "exclude_patterns", + "file_types", + "description", + "remediation", + "compiled_patterns", + "compiled_exclude_patterns", + ) + + def __init__(self, rule_data: dict[str, Any]) -> None: + self.id: str = rule_data["id"] + self.category = ThreatCategory(rule_data["category"]) + self.severity = Severity(rule_data["severity"]) + self.patterns: list[str] = rule_data["patterns"] + self.exclude_patterns: list[str] = rule_data.get( + "exclude_patterns", + [], + ) + self.file_types: list[str] = rule_data.get("file_types", []) + self.description: str = rule_data["description"] + self.remediation: str = rule_data.get("remediation", "") + + self.compiled_patterns: list[re.Pattern[str]] = [] + for pat in self.patterns: + try: + self.compiled_patterns.append(re.compile(pat)) + except re.error as exc: + logger.warning("Bad regex in rule %s: %s", self.id, exc) + + self.compiled_exclude_patterns: list[re.Pattern[str]] = [] + for pat in self.exclude_patterns: + try: + self.compiled_exclude_patterns.append(re.compile(pat)) + except re.error as exc: + logger.warning( + "Bad exclude regex in rule %s: %s", + self.id, + exc, + ) + + # ------------------------------------------------------------------ + + def matches_file_type(self, file_type: str) -> bool: + """Return *True* if this rule applies to *file_type*.""" + if not self.file_types: + return True + return file_type in self.file_types + + def scan_content( + self, + content: str, + file_path: str | None = None, + ) -> list[dict[str, Any]]: + """Scan *content* for rule violations. 
class RuleLoader:
    """Load :class:`SecurityRule` objects from YAML signature files.

    ``rules_path`` may name a single YAML file or a directory.  For a
    directory, every ``*.yaml`` file directly inside it is read in sorted
    order.  Each file must hold a top-level list of rule entries.
    """

    def __init__(self, rules_path: Path | None = None) -> None:
        # Fall back to the module's default signatures directory.
        self.rules_path = rules_path or _DEFAULT_SIGNATURES_DIR
        self.rules: list[SecurityRule] = []
        self.rules_by_id: dict[str, SecurityRule] = {}
        self.rules_by_category: dict[ThreatCategory, list[SecurityRule]] = {}

    @staticmethod
    def _read_yaml_list(yaml_path: Path) -> list[Any]:
        """Parse *yaml_path* and return its top-level list.

        Raises:
            RuntimeError: if the file cannot be read or parsed, or if its
                top-level value is not a list.
        """
        try:
            with open(yaml_path, encoding="utf-8") as fh:
                data = yaml.safe_load(fh)
        except Exception as exc:
            raise RuntimeError(f"Failed to load {yaml_path}: {exc}") from exc
        if not isinstance(data, list):
            raise RuntimeError(f"Expected list in {yaml_path}")
        return data

    def load_rules(self) -> list[SecurityRule]:
        """Load and index all rules from the configured path.

        Entries that cannot be turned into a :class:`SecurityRule` are
        logged and skipped; unreadable or malformed *files* abort the load.

        Returns:
            The list of successfully loaded rules (also kept on ``rules``).

        Raises:
            RuntimeError: propagated from reading any rule file.
        """
        path = Path(self.rules_path)
        raw: list[Any] = []
        if path.is_dir():
            for yaml_file in sorted(path.glob("*.yaml")):
                raw.extend(self._read_yaml_list(yaml_file))
        else:
            raw = self._read_yaml_list(path)

        self.rules = []
        self.rules_by_id = {}
        self.rules_by_category = {}

        for entry in raw:
            # Resolve the label up front: a non-mapping entry (e.g. a bare
            # string in the YAML list) has no ``.get`` and used to crash the
            # skip handler itself instead of being skipped.
            rule_id = entry.get("id", "?") if isinstance(entry, dict) else "?"
            try:
                rule = SecurityRule(entry)
                # Index by id inside the guard so an unusable id skips the
                # rule entirely rather than leaving it half-registered.
                self.rules_by_id[rule.id] = rule
            except Exception as exc:
                logger.warning("Skipping rule %s: %s", rule_id, exc)
                continue
            self.rules.append(rule)
            self.rules_by_category.setdefault(rule.category, []).append(rule)

        return self.rules

    def get_rule(self, rule_id: str) -> SecurityRule | None:
        """Return the rule registered under *rule_id*, or ``None``."""
        return self.rules_by_id.get(rule_id)

    def get_rules_for_file_type(self, file_type: str) -> list[SecurityRule]:
        """Return the loaded rules that apply to *file_type*."""
        return [r for r in self.rules if r.matches_file_type(file_type)]

    def get_rules_for_category(
        self,
        category: ThreatCategory,
    ) -> list[SecurityRule]:
        """Return the loaded rules in *category* (empty list if none)."""
        return self.rules_by_category.get(category, [])
Defaults to + the ``rules/signatures/`` directory shipped with the package. + policy: + Optional :class:`ScanPolicy` for rule disabling, severity + overrides, and doc-path skipping. + """ + + def __init__( + self, + rules_path: Path | None = None, + *, + policy: ScanPolicy | None = None, + ) -> None: + super().__init__(name="pattern", policy=policy) + loader = RuleLoader(rules_path) + self._rules = loader.load_rules() + self._rules_by_file_type: dict[str, list[SecurityRule]] = {} + logger.debug("PatternAnalyzer loaded %d rules", len(self._rules)) + + # ------------------------------------------------------------------ + # BaseAnalyzer interface + # ------------------------------------------------------------------ + + def analyze( + self, + skill_dir: Path, + files: list[SkillFile], + *, + skill_name: str | None = None, + ) -> list[Finding]: + findings: list[Finding] = [] + skip_in_docs = self.policy.rule_scoping.skip_in_docs + + for sf in files: + content = sf.read_content() + if not content: + continue + + is_doc = self.policy.is_doc_path(sf.relative_path) + + applicable = self._get_rules(sf.file_type) + for rule in applicable: + # --- Policy-based rule filtering --- + # Skip disabled rules early + if self.policy.is_rule_disabled(rule.id): + continue + # Skip doc-only exclusions + if is_doc and rule.id in skip_in_docs: + continue + # Code-only rules should not fire on non-code files + if rule.id in self.policy.rule_scoping.code_only: + if sf.file_type not in ( + "python", + "bash", + "javascript", + "typescript", + ): + continue + + matches = rule.scan_content( + content, + file_path=sf.relative_path, + ) + for match in matches: + # Apply severity override if configured + severity = rule.severity + override = self.policy.get_severity_override(rule.id) + if override: + try: + severity = Severity(override) + except ValueError: + pass + + findings.append( + Finding( + id=( + f"{rule.id}:{sf.relative_path}" + f":{match['line_number']}" + ), + rule_id=rule.id, + 
category=rule.category, + severity=severity, + title=rule.description, + description=rule.description, + file_path=sf.relative_path, + line_number=match["line_number"], + snippet=match["line_content"], + remediation=rule.remediation, + analyzer=self.name, + metadata={ + "matched_pattern": match["matched_pattern"], + "matched_text": match["matched_text"], + }, + ), + ) + + # Filter well-known test credentials + findings = [ + f for f in findings if not self._is_known_test_credential(f) + ] + + # De-duplicate if enabled in policy + if self.policy.rule_scoping.dedupe_duplicate_findings: + findings = self._dedupe_findings(findings) + + return findings + + # ------------------------------------------------------------------ + # Credential filtering + # ------------------------------------------------------------------ + + def _is_known_test_credential(self, finding: Finding) -> bool: + """Suppress findings that match known test/placeholder credentials.""" + if finding.category != ThreatCategory.HARDCODED_SECRETS: + return False + snippet = (finding.snippet or "").lower() + for cred in self.policy.credentials.known_test_values: + if cred.lower() in snippet: + return True + for marker in self.policy.credentials.placeholder_markers: + if marker.lower() in snippet: + return True + return False + + # ------------------------------------------------------------------ + # De-duplication + # ------------------------------------------------------------------ + + @staticmethod + def _dedupe_findings(findings: list[Finding]) -> list[Finding]: + """Remove exact duplicate findings (same rule + file + line).""" + seen: set[str] = set() + unique: list[Finding] = [] + for f in findings: + key = f"{f.rule_id}:{f.file_path}:{f.line_number}" + if key not in seen: + seen.add(key) + unique.append(f) + return unique + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def 
_get_rules(self, file_type: str) -> list[SecurityRule]: + """Return rules applicable to *file_type* (cached).""" + if file_type not in self._rules_by_file_type: + self._rules_by_file_type[file_type] = [ + r for r in self._rules if r.matches_file_type(file_type) + ] + return self._rules_by_file_type[file_type] diff --git a/src/copaw/security/skill_scanner/data/default_policy.yaml b/src/copaw/security/skill_scanner/data/default_policy.yaml new file mode 100644 index 000000000..c2b571d10 --- /dev/null +++ b/src/copaw/security/skill_scanner/data/default_policy.yaml @@ -0,0 +1,244 @@ +# CoPaw Skill Scanner – Default Scan Policy +# ============================================ +# This file defines the built-in security policy. Every setting here can be +# overridden in an organisation-specific policy file passed via --policy. +# +# To use a custom policy: +# scanner = SkillScanner(policy=ScanPolicy.from_yaml("my_org_policy.yaml")) + +policy_name: default +policy_version: "1.0" +preset_base: balanced + +# --------------------------------------------------------------------------- +# Hidden files – which dotfiles and dotdirs are considered benign +# --------------------------------------------------------------------------- +hidden_files: + benign_dotfiles: + - ".gitignore" + - ".gitattributes" + - ".gitmodules" + - ".gitkeep" + - ".editorconfig" + - ".prettierrc" + - ".prettierignore" + - ".eslintrc" + - ".eslintignore" + - ".eslintrc.json" + - ".eslintrc.js" + - ".eslintrc.yml" + - ".npmrc" + - ".npmignore" + - ".nvmrc" + - ".node-version" + - ".python-version" + - ".ruby-version" + - ".tool-versions" + - ".flake8" + - ".pylintrc" + - ".isort.cfg" + - ".mypy.ini" + - ".babelrc" + - ".browserslistrc" + - ".dockerignore" + - ".env.example" + - ".env.sample" + - ".env.template" + - ".markdownlint.json" + - ".markdownlintignore" + - ".yamllint" + - ".yamllint.yml" + - ".cursorrules" + - ".cursorignore" + - ".clang-format" + - ".clang-tidy" + - ".mcp.json" + - ".envrc" + - 
".version" + + benign_dotdirs: + - ".github" + - ".vscode" + - ".idea" + - ".cursor" + - ".husky" + - ".circleci" + - ".gitlab" + - ".cache" + - ".tmp" + - ".data" + - ".next" + - ".nuxt" + - ".claude" + - ".devcontainer" + - ".vitepress" + - ".docusaurus" + - ".storybook" + +# --------------------------------------------------------------------------- +# Rule scoping – which rules fire where +# --------------------------------------------------------------------------- +rule_scoping: + skip_in_docs: + - "COMMAND_INJECTION_EVAL" + - "COMMAND_INJECTION_OS_SYSTEM" + - "COMMAND_INJECTION_SHELL_TRUE" + - "RESOURCE_ABUSE_INFINITE_LOOP" + + code_only: + - "COMMAND_INJECTION_EVAL" + - "COMMAND_INJECTION_OS_SYSTEM" + - "COMMAND_INJECTION_SHELL_TRUE" + - "RESOURCE_ABUSE_INFINITE_LOOP" + + doc_path_indicators: + - "docs" + - "doc" + - "examples" + - "example" + - "tutorials" + - "tutorial" + - "guides" + - "guide" + - "samples" + - "sample" + - "demo" + - "demos" + - "tests" + - "test" + + doc_filename_patterns: + - "readme" + - "example" + - "tutorial" + - "sample" + - "demo" + - "howto" + - "guide" + + dedupe_duplicate_findings: true + +# --------------------------------------------------------------------------- +# Credentials – known test/placeholder values to auto-suppress +# --------------------------------------------------------------------------- +credentials: + known_test_values: + - "test" + - "test123" + - "password" + - "changeme" + - "example" + - "dummy" + - "placeholder" + - "your-api-key" + - "your_api_key" + - "sk-test" + - "pk-test" + - "xxx" + - "abc123" + - "secret" + - "foobar" + + placeholder_markers: + - "your-" + - "your_" + - "your " + - "example" + - "sample" + - "dummy" + - "placeholder" + - "replace" + - "changeme" + - "change_me" + - " str: + """Read file content if not already loaded.""" + if self.content is None and self.path.exists(): + try: + with open(self.path, encoding="utf-8") as f: + self.content = f.read() + except (OSError, 
@property
def is_hidden(self) -> bool:
    """Report whether any component of ``relative_path`` starts with a dot.

    A component that is exactly ``"."`` does not count.
    """
    for component in Path(self.relative_path).parts:
        if component != "." and component.startswith("."):
            return True
    return False
@property
def max_severity(self) -> Severity:
    """Highest severity present among the findings, else ``SAFE``.

    Levels are checked from ``CRITICAL`` down to ``INFO``; the first level
    carried by at least one finding is returned.
    """
    ranking = (
        Severity.CRITICAL,
        Severity.HIGH,
        Severity.MEDIUM,
        Severity.LOW,
        Severity.INFO,
    )
    for level in ranking:
        for finding in self.findings:
            if finding.severity == level:
                return level
    # No findings at all, or none at a ranked level.
    return Severity.SAFE
+ } + if self.analyzers_failed: + result["analyzers_failed"] = self.analyzers_failed + return result diff --git a/src/copaw/security/skill_scanner/rules/signatures/command_injection.yaml b/src/copaw/security/skill_scanner/rules/signatures/command_injection.yaml new file mode 100644 index 000000000..1d41736b0 --- /dev/null +++ b/src/copaw/security/skill_scanner/rules/signatures/command_injection.yaml @@ -0,0 +1,194 @@ +# Command & Code Injection Signatures +# Detects dangerous code execution, shell injection, path traversal, and SQL injection + +- id: COMMAND_INJECTION_EVAL + category: command_injection + severity: CRITICAL + patterns: + - "(?]*>[^<]*" + - "\\bon\\w+\\s*=\\s*[\"'][^\"']*[\"']" + - "javascript\\s*:" + file_types: [other] + description: "SVG file contains embedded script tags or event handlers that can execute JavaScript" + remediation: "Remove all ", " ", content, flags=re.S | re.I) + content = re.sub(r"", " ", content, flags=re.S | re.I) + content = re.sub(r"<[^>]+>", " ", content) + return { + "path": url, + "title": title, + "text": KnowledgeManager._normalize_text(unescape(content)), + } + + @staticmethod + def _normalize_text(text: str) -> str: + compact = text.replace("\r", "\n") + # Remove all blank lines (lines containing only whitespace) + compact = re.sub(r"\n[ \t]*\n+", "\n", compact) + compact = re.sub(r"[ \t]+", " ", compact) + return compact.strip() + + def _extract_visible_chat_text(self, state: dict[str, Any]) -> str: + messages = self._messages_from_session_state(state) + snippets = [ + self._extract_text_from_runtime_message(message) + for message in messages + ] + return self._normalize_text("\n\n".join(item for item in snippets if item)) + + @staticmethod + def _messages_from_session_state(state: dict[str, Any]) -> list[dict[str, Any]]: + memory_state = state.get("agent", {}).get("memory", {}) + if isinstance(memory_state, dict): + raw_entries = memory_state.get("content", []) + elif isinstance(memory_state, list): + raw_entries = 
memory_state + else: + raw_entries = [] + + messages: list[dict[str, Any]] = [] + for entry in raw_entries: + raw_message = None + if isinstance(entry, list) and entry: + raw_message = entry[0] + elif isinstance(entry, dict): + raw_message = entry + if not isinstance(raw_message, dict): + continue + if raw_message.get("type") in {"plugin_call", "plugin_call_output"}: + continue + messages.append(raw_message) + return messages + + @staticmethod + def _extract_text_from_state(value: Any) -> str: + snippets: list[str] = [] + + def walk(node: Any) -> None: + if isinstance(node, str): + cleaned = node.strip() + if cleaned: + snippets.append(cleaned) + return + if isinstance(node, list): + for item in node: + walk(item) + return + if isinstance(node, dict): + for key in ("text", "thinking", "output", "name", "role"): + if key in node and isinstance(node[key], str): + walk(node[key]) + content = node.get("content") + if content is not None: + walk(content) + data = node.get("data") + if data is not None: + walk(data) + for nested_key, nested_value in node.items(): + if nested_key in {"text", "thinking", "output", "name", "role", "content", "data"}: + continue + if isinstance(nested_value, (dict, list)): + walk(nested_value) + + walk(value) + return "\n".join(snippets) + + def _build_file_sources_from_messages( + self, + messages: list[Any], + config: KnowledgeConfig, + session_id: str, + ) -> list[KnowledgeSourceSpec]: + sources: list[KnowledgeSourceSpec] = [] + seen_ids: set[str] = set() + for block in self._iter_message_blocks(messages): + if block.get("type") != "file": + continue + file_ref = self._file_reference_from_block(block) + if not file_ref: + continue + parsed_ref = urlparse(file_ref) + remote_hash = None + if parsed_ref.scheme in {"http", "https"}: + remote_hash = hashlib.sha1(file_ref.encode("utf-8")).hexdigest() + stored_path = self._materialize_file_reference( + file_ref, + block.get("name") or Path(file_ref).name or "chat-file", + config, + ) + if 
stored_path is None: + continue + digest = hashlib.sha1(str(stored_path).encode("utf-8")).hexdigest()[:12] + source_id = f"auto-file-{digest}" + if source_id in seen_ids: + continue + seen_ids.add(source_id) + tags = ["auto", "origin:auto", "source:chat", "auto:file"] + if remote_hash: + tags.extend(["remote:http", f"remote:url_hash:{remote_hash}"]) + sources.append( + KnowledgeSourceSpec( + id=source_id, + name=f"Auto File: {stored_path.name}", + type="file", + location=str(stored_path), + enabled=True, + recursive=False, + tags=tags, + description=f"Auto-collected from chat session {session_id}", + ), + ) + return sources + + def _build_text_sources_from_messages( + self, + messages: list[Any], + config: KnowledgeConfig, + session_id: str, + user_id: str, + long_text_min_chars: int, + ) -> list[KnowledgeSourceSpec]: + sources: list[KnowledgeSourceSpec] = [] + seen_ids: set[str] = set() + for role, text in self._iter_message_texts(messages): + normalized = self._normalize_text(text) + if len(normalized) < long_text_min_chars: + continue + digest = hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:12] + source_id = f"auto-text-{digest}" + if source_id in seen_ids: + continue + seen_ids.add(source_id) + title = normalized.splitlines()[0][:48] or "Long chat text" + sources.append( + KnowledgeSourceSpec( + id=source_id, + name=f"Auto Text: {title}", + type="text", + content=normalized, + enabled=True, + recursive=False, + tags=[ + "auto", + "origin:auto", + "source:chat", + "auto:text", + f"role:{role}", + ], + description=( + f"Auto-saved from {role} message in {session_id}" + + (f" for {user_id}" if user_id else "") + ), + ), + ) + return sources + + def _build_text_sources_from_turn_pair( + self, + request_messages: list[Any], + response_messages: list[Any], + session_id: str, + user_id: str, + long_text_min_chars: int, + ) -> list[KnowledgeSourceSpec]: + user_text = self._normalize_text( + "\n".join( + text + for role, text in 
self._iter_message_texts(request_messages) + if str(role).lower() == "user" + ) + ) + assistant_text = self._normalize_text( + "\n".join( + text + for role, text in self._iter_message_texts(response_messages) + if str(role).lower() == "assistant" + ) + ) + if not user_text or not assistant_text: + return [] + + merged = self._normalize_text(f"用户: {user_text}\n\n智能体: {assistant_text}") + if len(merged) < long_text_min_chars: + return [] + + digest = hashlib.sha1(merged.encode("utf-8")).hexdigest()[:12] + source_id = f"auto-text-{digest}" + title = merged.splitlines()[0][:48] or "Long chat text" + return [ + KnowledgeSourceSpec( + id=source_id, + name=f"Auto Text: {title}", + type="text", + content=merged, + enabled=True, + recursive=False, + tags=[ + "auto", + "origin:auto", + "source:chat", + "auto:text", + "role:turn_pair", + ], + description=( + f"Auto-saved from user-assistant turn in {session_id}" + + (f" for {user_id}" if user_id else "") + ), + ) + ] + + def _build_url_sources_from_messages( + self, + messages: list[Any], + session_id: str, + user_id: str, + automation_config: Any | None = None, + min_content_chars: int | None = None, + ) -> list[KnowledgeSourceSpec]: + sources: list[KnowledgeSourceSpec] = [] + seen_ids: set[str] = set() + for role, text in self._iter_message_texts(messages): + for url in self._extract_urls_from_text(text): + if self._should_exclude_url(url, automation_config): + logger.debug("Skipping excluded URL: %s", url) + continue + digest = hashlib.sha1(url.encode("utf-8")).hexdigest()[:12] + source_id = f"auto-url-{digest}" + if source_id in seen_ids: + continue + seen_ids.add(source_id) + label = url if len(url) <= 80 else f"{url[:77]}..." 
+ fetched_text = "" + if min_content_chars is not None: + try: + doc = self._read_url_document(url) + except Exception: + continue + fetched_text = self._normalize_text(doc.get("text", "")) + if len(fetched_text) < max(0, min_content_chars): + continue + # Capture surrounding text context from the conversation message + # so title generation can use it without fetching the URL. + context_snippet = self._extract_url_context(text, url, max_chars=400) + description = ( + f"Auto-collected URL from {role} message in {session_id}" + + (f" for {user_id}" if user_id else "") + ) + if context_snippet: + description = f"{description}\n来源上下文: {context_snippet}" + sources.append( + KnowledgeSourceSpec( + id=source_id, + name=f"Auto URL: {label}", + type="url", + location=url, + content=fetched_text, + enabled=True, + recursive=False, + tags=[ + "auto", + "origin:auto", + "source:chat", + "auto:url", + f"role:{role}", + ], + description=description, + ), + ) + return sources + + @staticmethod + def _extract_urls_from_text(text: str) -> list[str]: + found: list[str] = [] + seen: set[str] = set() + for match in _CHAT_URL_RE.findall(text or ""): + cleaned = match.rstrip(_URL_TRAILING_STRIP_CHARS) + if not cleaned: + continue + + # Defensive normalization: a previously merged token may contain + # additional URLs separated by CJK words or punctuation. 
+ normalized_urls: list[str] = [] + cjk_chunks = re.split(r"[\u4e00-\u9fff]+", cleaned) + for chunk in cjk_chunks: + chunk = chunk.strip() + if not chunk: + continue + nested = _CHAT_URL_RE.findall(chunk) + if nested: + normalized_urls.extend(nested) + else: + normalized_urls.append(chunk) + + for candidate in normalized_urls: + candidate = candidate.rstrip(_URL_TRAILING_STRIP_CHARS) + if not candidate or candidate in seen: + continue + seen.add(candidate) + found.append(candidate) + return found + + @staticmethod + def _should_exclude_url(url: str, automation_config: Any | None = None) -> bool: + """Return True if the URL should be excluded from auto-collection. + + Exclusion criteria (all can be toggled via automation_config): + - Private/intranet addresses (localhost, 127.x, 192.168.x, etc.) + - URLs containing credential/token query parameters + - User-defined exclusion prefix patterns + """ + try: + parsed = urlparse(url) + except Exception: + return False + + # Private-address exclusion + exclude_private = True + if automation_config is not None: + exclude_private = bool( + getattr(automation_config, "url_exclude_private_addresses", True) + ) + if exclude_private: + host = parsed.hostname or "" + if _PRIVATE_HOST_RE.match(host): + return True + + # Token/credential query-param exclusion + exclude_tokens = True + if automation_config is not None: + exclude_tokens = bool( + getattr(automation_config, "url_exclude_token_params", True) + ) + if exclude_tokens and parsed.query: + try: + params = parse_qs(parsed.query, keep_blank_values=True) + if any(k.lower() in _URL_SENSITIVE_PARAMS for k in params): + return True + except Exception: + pass + + # User-defined pattern exclusion (prefix match) + exclude_patterns: list[str] = [] + if automation_config is not None: + raw = getattr(automation_config, "url_exclude_patterns", None) + if isinstance(raw, list): + exclude_patterns = raw + for pattern in exclude_patterns: + if isinstance(pattern, str) and 
url.startswith(pattern): + return True + + return False + + @staticmethod + def _extract_url_context(text: str, url: str, max_chars: int = 400) -> str: + """Extract a snippet of surrounding text around the given URL. + + Returns up to max_chars/2 chars before and max_chars/2 after + the URL occurrence, stripped and compacted. + """ + idx = text.find(url) + if idx == -1: + return "" + half = max_chars // 2 + before = text[max(0, idx - half): idx].strip() + after = text[idx + len(url): idx + len(url) + half].strip() + parts = [p for p in (before, after) if p] + snippet = " ... ".join(parts) + # Compact whitespace + snippet = re.sub(r"\s+", " ", snippet).strip() + return snippet[:max_chars] + + @staticmethod + def _block_to_dict(block: Any) -> dict[str, Any] | None: + if isinstance(block, dict): + return block + if hasattr(block, "model_dump"): + return block.model_dump() + if hasattr(block, "dict"): + return block.dict() + return None + + def _iter_message_blocks(self, messages: list[Any]) -> list[dict[str, Any]]: + blocks: list[dict[str, Any]] = [] + for message in messages: + content = getattr(message, "content", None) + if isinstance(message, dict): + content = message.get("content") + if isinstance(content, str): + blocks.append({"type": "text", "text": content}) + continue + if not isinstance(content, list): + continue + for block in content: + payload = self._block_to_dict(block) + if payload is not None: + blocks.append(payload) + return blocks + + def _iter_message_texts(self, messages: list[Any]) -> list[tuple[str, str]]: + texts: list[tuple[str, str]] = [] + for message in messages: + role = getattr(message, "role", None) + if isinstance(message, dict): + role = message.get("role") + role = role or "assistant" + if isinstance(message, dict) and message.get("type") in { + "plugin_call", + "plugin_call_output", + }: + continue + content = getattr(message, "content", None) + if isinstance(message, dict): + content = message.get("content") + if 
isinstance(content, str): + texts.append((role, content)) + continue + if not isinstance(content, list): + continue + joined: list[str] = [] + for block in content: + payload = self._block_to_dict(block) + if not payload or payload.get("type") != "text": + continue + text = payload.get("text") + if isinstance(text, str) and text.strip(): + joined.append(text) + if joined: + texts.append((role, "\n".join(joined))) + return texts + + def _extract_text_from_runtime_message(self, message: dict[str, Any]) -> str: + content = message.get("content") + if isinstance(content, str): + return self._normalize_text(content) + if not isinstance(content, list): + return "" + snippets: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") == "text": + text = block.get("text") + if isinstance(text, str) and text.strip(): + snippets.append(text) + return self._normalize_text("\n".join(snippets)) + + @staticmethod + def _file_reference_from_block(block: dict[str, Any]) -> str: + for key in ("file_url", "path", "url"): + value = block.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + source = block.get("source") + if isinstance(source, dict): + value = source.get("url") or source.get("path") + if isinstance(value, str) and value.strip(): + return value.strip() + return "" + + def _materialize_file_reference( + self, + file_ref: str, + filename: str, + config: KnowledgeConfig, + ) -> Path | None: + parsed = urlparse(file_ref) + if parsed.scheme in {"", "file"}: + local_value = parsed.path if parsed.scheme == "file" else file_ref + path = Path(local_value).expanduser() + if path.exists() and path.is_file(): + return path.resolve() + return None + + if parsed.scheme in {"http", "https"}: + downloaded_name = Path(parsed.path).name or filename + return self._download_remote_file_with_cache( + file_ref, + downloaded_name, + config, + ) + return None + + def _download_remote_file_with_cache( + self, + url: 
str, + filename: str, + config: KnowledgeConfig, + ) -> Path | None: + url_key = hashlib.sha1(url.encode("utf-8")).hexdigest() + meta_path = self.remote_meta_dir / f"{url_key}.json" + now = datetime.now(UTC) + metadata: dict[str, Any] = {} + + if meta_path.exists(): + metadata = self._load_json(meta_path) + cached_path = metadata.get("file_path") + if isinstance(cached_path, str) and cached_path: + cached_file = Path(cached_path) + if cached_file.exists() and cached_file.is_file(): + return cached_file + next_retry_at = self._parse_iso_utc(metadata.get("next_retry_at")) + if next_retry_at is not None and next_retry_at > now: + return None + + try: + response = httpx.get(url, timeout=15.0, follow_redirects=True) + response.raise_for_status() + content = response.content + if len(content) > config.index.max_file_size: + self._save_remote_meta( + meta_path, + { + "url": url, + "status": "failed", + "last_error": "file too large", + "fail_count": 1, + "next_retry_at": ( + now + timedelta(seconds=30) + ).isoformat(), + "updated_at": now.isoformat(), + }, + ) + return None + + content_hash = hashlib.sha1(content).hexdigest() + blob_dir = self.remote_blob_dir / content_hash + blob_dir.mkdir(parents=True, exist_ok=True) + safe_name = self._safe_name(Path(filename).name) + blob_path = blob_dir / safe_name + if not blob_path.exists(): + blob_path.write_bytes(content) + + self._save_remote_meta( + meta_path, + { + "url": url, + "status": "ok", + "content_hash": content_hash, + "file_path": str(blob_path), + "file_name": safe_name, + "fail_count": 0, + "next_retry_at": None, + "updated_at": now.isoformat(), + }, + ) + return blob_path + except Exception as exc: + fail_count = int(metadata.get("fail_count", 0)) + 1 + backoff_seconds = min(300, 5 * (2 ** (fail_count - 1))) + self._save_remote_meta( + meta_path, + { + "url": url, + "status": "failed", + "last_error": str(exc), + "fail_count": fail_count, + "next_retry_at": ( + now + timedelta(seconds=backoff_seconds) + 
).isoformat(), + "updated_at": now.isoformat(), + }, + ) + return None + + @staticmethod + def _save_remote_meta(path: Path, payload: dict[str, Any]) -> None: + path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + @staticmethod + def _parse_iso_utc(value: Any) -> datetime | None: + if not isinstance(value, str) or not value: + return None + text = value.strip() + if text.endswith("Z"): + text = text[:-1] + "+00:00" + try: + parsed = datetime.fromisoformat(text) + except ValueError: + return None + if parsed.tzinfo is None: + return parsed.replace(tzinfo=UTC) + return parsed.astimezone(UTC) + + def _history_backfill_signature(self, running_config: Any | None) -> str: + payload = { + "auto_collect_chat_files": bool( + getattr(running_config, "auto_collect_chat_files", False), + ), + "auto_collect_chat_urls": bool( + getattr(running_config, "auto_collect_chat_urls", True), + ), + "auto_collect_long_text": bool( + getattr(running_config, "auto_collect_long_text", False), + ), + "long_text_min_chars": int( + getattr(running_config, "long_text_min_chars", 2000), + ), + "knowledge_chunk_size": int( + getattr(running_config, "knowledge_chunk_size", 1200), + ), + "version": 2, + } + return hashlib.sha1( + json.dumps(payload, sort_keys=True, ensure_ascii=False).encode("utf-8"), + ).hexdigest() + + @staticmethod + def _resolve_chunk_size( + config: KnowledgeConfig, + running_config: Any | None, + ) -> int: + chunk_size = getattr(running_config, "knowledge_chunk_size", None) + if isinstance(chunk_size, int): + return chunk_size + return config.index.chunk_size + + def history_backfill_status(self) -> dict[str, Any]: + """Return whether historical chat data still needs knowledge backfill.""" + state = self._load_backfill_state() + has_backfill_record = bool(state) + backfill_completed = bool(state.get("completed")) + + chats_path = self.working_dir / CHATS_FILE + history_chat_count = 0 + if chats_path.exists(): + try: + payload = 
self._load_json(chats_path) + chats = payload.get("chats", []) + history_chat_count = sum( + 1 + for chat in chats + if str(chat.get("session_id", "") or "").strip() + ) + except Exception: + history_chat_count = 0 + + marked_unbackfilled = not backfill_completed + has_pending_history = marked_unbackfilled and history_chat_count > 0 + return { + "has_backfill_record": has_backfill_record, + "backfill_completed": backfill_completed, + "marked_unbackfilled": marked_unbackfilled, + "history_chat_count": history_chat_count, + "has_pending_history": has_pending_history, + "progress": self.get_history_backfill_progress(), + } + + def get_history_backfill_progress(self) -> dict[str, Any]: + payload = self._load_backfill_progress_state() + return { + "running": bool(payload.get("running")), + "completed": bool(payload.get("completed")), + "failed": bool(payload.get("failed")), + "total_sessions": int(payload.get("total_sessions", 0) or 0), + "traversed_sessions": int(payload.get("traversed_sessions", 0) or 0), + "processed_sessions": int(payload.get("processed_sessions", 0) or 0), + "current_session_id": payload.get("current_session_id"), + "error": payload.get("error"), + "updated_at": payload.get("updated_at"), + "reason": payload.get("reason"), + } + + def _load_backfill_state(self) -> dict[str, Any]: + if not self.backfill_state_path.exists(): + return {} + try: + return self._load_json(self.backfill_state_path) + except Exception: + return {} + + def _save_backfill_state(self, payload: dict[str, Any]) -> None: + self.backfill_state_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + def _load_backfill_progress_state(self) -> dict[str, Any]: + if not self.backfill_progress_path.exists(): + return {} + try: + return self._load_json(self.backfill_progress_path) + except Exception: + return {} + + def _save_backfill_progress(self, payload: dict[str, Any]) -> None: + self.backfill_progress_path.write_text( + json.dumps(payload, 
ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + def _remote_source_status(self, source: KnowledgeSourceSpec) -> dict[str, Any]: + remote_hash = "" + for tag in source.tags or []: + if tag.startswith("remote:url_hash:"): + remote_hash = tag.split(":", 2)[-1] + break + if not remote_hash: + return {} + + meta_path = self.remote_meta_dir / f"{remote_hash}.json" + if not meta_path.exists(): + return { + "remote_status": "unknown", + "remote_cache_state": "missing", + "remote_fail_count": 0, + "remote_next_retry_at": None, + "remote_last_error": None, + "remote_updated_at": None, + } + + payload = self._load_json(meta_path) + remote_status = payload.get("status", "unknown") + fail_count = int(payload.get("fail_count", 0) or 0) + next_retry_at = payload.get("next_retry_at") + next_retry_dt = self._parse_iso_utc(next_retry_at) + now = datetime.now(UTC) + + if remote_status == "ok": + if source.location and Path(source.location).exists(): + cache_state = "cached" + else: + cache_state = "missing" + elif remote_status == "failed": + if next_retry_dt is not None and next_retry_dt > now: + cache_state = "waiting_retry" + else: + cache_state = "ready_retry" + else: + cache_state = "unknown" + + return { + "remote_status": remote_status, + "remote_cache_state": cache_state, + "remote_fail_count": fail_count, + "remote_next_retry_at": next_retry_at, + "remote_last_error": payload.get("last_error"), + "remote_updated_at": payload.get("updated_at"), + } + + def _upsert_source( + self, + config: KnowledgeConfig, + source: KnowledgeSourceSpec, + ) -> bool: + normalized = self._source_with_auto_name(source, config) + for index, existing in enumerate(config.sources): + if existing.id != normalized.id: + continue + existing_normalized = self._source_with_auto_name(existing, config) + if existing_normalized.model_dump(mode="json") == normalized.model_dump(mode="json"): + return False + config.sources[index] = normalized + return True + config.sources.append(normalized) + 
return True + + def _source_with_auto_name( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> KnowledgeSourceSpec: + updates: dict[str, Any] = {} + source_for_title = source + + if not (source.description or "").strip(): + generated_description = self._generate_source_description(source, config) + if generated_description: + updates["description"] = generated_description + source_for_title = source.model_copy( + update={"description": generated_description} + ) + + generated = self._generate_source_name(source_for_title, config) + if source.name != generated: + updates["name"] = generated + + if not updates: + return source + return source.model_copy(update=updates) + + def _generate_source_description( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + semantic = self._semantic_description_for_source(source, config) + if semantic: + return self._truncate_description(semantic) + + if source.type == "url": + url = (source.location or "").strip() + if url: + parsed = urlparse(url) + host = parsed.netloc or url + path = parsed.path.strip("/") + tail = path.split("/")[-1] if path else "" + if tail: + return self._truncate_description(f"{host}/{tail}") + return self._truncate_description(host) + + if source.type in {"file", "directory"} and source.location: + location = (source.location or "").strip() + if location: + return self._truncate_description(Path(location).name or location) + + if source.name: + return self._truncate_description(source.name) + return "" + + def _semantic_description_for_source( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + candidates: list[str] = [] + + full_text = self._collect_full_text_for_local_title(source, config) + if full_text: + candidates.append(full_text) + + indexed_payload = self._load_index_payload_safe(source.id) + if indexed_payload: + chunk_texts = [ + chunk.get("text", "") + for chunk in 
indexed_payload.get("chunks", []) + if isinstance(chunk, dict) and isinstance(chunk.get("text"), str) + ] + if chunk_texts: + candidates.append("\n".join(chunk_texts)) + + for candidate in candidates: + sentence = self._semantic_title_from_text(candidate) + if sentence: + return sentence + return "" + + def _generate_source_name( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + semantic = self._semantic_title_for_source(source, config) + if semantic: + return self._truncate_title(semantic) + + if source.type == "url": + url = (source.location or "").strip() + if url: + parsed = urlparse(url) + host = parsed.netloc or url + path = parsed.path.strip("/") + tail = path.split("/")[-1] if path else "" + if tail: + return self._truncate_title(f"{host}/{tail}") + return self._truncate_title(host) + + if source.type in {"file", "directory"} and source.location: + location = (source.location or "").strip() + if location: + return self._truncate_title(Path(location).name or location) + + return self._truncate_title(source.id) + + def _semantic_title_for_source( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + candidates: list[str] = [] + + if source.description: + candidates.append(source.description) + + indexed_payload = self._load_index_payload_safe(source.id) + if indexed_payload: + chunk_titles: list[str] = [] + chunk_texts: list[str] = [] + for chunk in indexed_payload.get("chunks", []): + if not isinstance(chunk, dict): + continue + chunk_title = chunk.get("document_title") + if isinstance(chunk_title, str) and chunk_title.strip(): + chunk_titles.append(chunk_title) + chunk_text = chunk.get("text") + if isinstance(chunk_text, str) and chunk_text.strip(): + chunk_texts.append(chunk_text) + if chunk_titles: + candidates.append("\n".join(chunk_titles)) + if chunk_texts: + candidates.append("\n".join(chunk_texts)) + + for candidate in candidates: + title = 
self._semantic_title_from_text(candidate) + if title: + return title + return "" + + def _collect_full_text_for_local_title( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + parts: list[str] = [] + + if source.content and source.content.strip(): + parts.append(source.content) + + indexed_payload = self._load_index_payload_safe(source.id) + if indexed_payload: + chunk_texts = [ + chunk.get("text", "") + for chunk in indexed_payload.get("chunks", []) + if isinstance(chunk, dict) and isinstance(chunk.get("text"), str) + ] + if chunk_texts: + parts.append("\n".join(chunk_texts)) + + location = (source.location or "").strip() + if source.type == "file" and location: + full_text = self._read_local_text(Path(location)) + if full_text: + parts.append(full_text) + elif source.type == "directory" and location: + full_text = self._read_directory_text(Path(location), config) + if full_text: + parts.append(full_text) + + merged = self._normalize_text("\n".join(part for part in parts if part)) + return merged + + def _read_local_text(self, path: Path) -> str: + try: + resolved = path.expanduser().resolve() + if not resolved.exists() or not resolved.is_file(): + return "" + raw = resolved.read_text(encoding="utf-8", errors="ignore") + return self._normalize_text(raw) + except Exception: + return "" + + def _read_directory_text( + self, + directory: Path, + config: KnowledgeConfig | None = None, + ) -> str: + try: + root = directory.expanduser().resolve() + if not root.exists() or not root.is_dir(): + return "" + parts: list[str] = [] + for path in root.rglob("*"): + if not path.is_file(): + continue + if config is not None: + relative = path.relative_to(root).as_posix() + if not self._is_allowed_path(relative, config): + continue + text = self._read_local_text(path) + if text: + parts.append(text) + return self._normalize_text("\n".join(parts)) + except Exception: + return "" + + def _load_index_payload_safe(self, source_id: str) -> 
dict[str, Any] | None: + index_path = self._index_path(source_id) + if not index_path.exists(): + return None + try: + return self._load_json(index_path) + except Exception: + return None + + def _read_local_text_snippet(self, path: Path, max_chars: int = 2400) -> str: + try: + resolved = path.expanduser().resolve() + if not resolved.exists() or not resolved.is_file(): + return "" + raw = resolved.read_text(encoding="utf-8", errors="ignore") + return self._normalize_text(raw[:max_chars]) + except Exception: + return "" + + def _read_directory_text_snippet(self, directory: Path, max_chars: int = 2400) -> str: + try: + root = directory.expanduser().resolve() + if not root.exists() or not root.is_dir(): + return "" + for path in root.rglob("*"): + if not path.is_file(): + continue + snippet = self._read_local_text_snippet(path, max_chars=max_chars) + if snippet: + return snippet + except Exception: + return "" + return "" + + def _semantic_title_from_text(self, text: str) -> str: + normalized = self._normalize_text(text or "") + if not normalized: + return "" + + sentences = [ + s.strip(" \t-::;;,.。!?!?") + for s in _TITLE_SENTENCE_SPLIT_RE.split(normalized) + if s.strip() + ] + if not sentences: + return "" + + token_freq: dict[str, int] = {} + sentence_tokens: list[list[str]] = [] + for sentence in sentences: + tokens = [ + tok.lower() + for tok in _TITLE_WORD_RE.findall(sentence) + if tok.lower() not in _TITLE_STOP_WORDS + ] + sentence_tokens.append(tokens) + for token in tokens: + token_freq[token] = token_freq.get(token, 0) + 1 + + best_sentence = "" + best_score = -1.0 + for sentence, tokens in zip(sentences, sentence_tokens): + if not tokens: + score = 0.0 + else: + unique_score = sum(token_freq.get(token, 0) for token in set(tokens)) + score = unique_score / (len(tokens) ** 0.5) + if score > best_score: + best_score = score + best_sentence = sentence + + if not best_sentence: + best_sentence = sentences[0] + return self._normalize_text(best_sentence) + + 
@staticmethod + def _truncate_title(value: str, max_len: int = 120) -> str: + compact = re.sub(r"\s+", " ", (value or "").strip()) + if not compact: + compact = "knowledge" + if len(compact) <= max_len: + return compact + return compact[: max_len - 3].rstrip() + "..." + + @staticmethod + def _truncate_description(value: str, max_len: int = 180) -> str: + compact = re.sub(r"\s+", " ", (value or "").strip()) + if not compact: + return "" + if len(compact) <= max_len: + return compact + return compact[: max_len - 3].rstrip() + "..." + + @staticmethod + def _chunk_documents( + documents: list[dict[str, str]], + chunk_size: int, + ) -> list[dict[str, Any]]: + chunks: list[dict[str, Any]] = [] + for document in documents: + text = document["text"] + if not text: + continue + for index, start in enumerate(range(0, len(text), chunk_size)): + chunk_text = text[start : start + chunk_size] + if not chunk_text.strip(): + continue + chunks.append( + { + "chunk_id": f"{document['path']}::{index}", + "document_path": document["path"], + "document_title": document["title"], + "text": chunk_text, + }, + ) + return chunks + + @staticmethod + def _score_chunk(text: str, terms: list[str]) -> int: + lowered = text.lower() + score = 0 + phrase = " ".join(terms) + if phrase and phrase in lowered: + score += len(terms) + 2 + for term in terms: + score += lowered.count(term) + return score + + @staticmethod + def _build_snippet(text: str, terms: list[str], length: int = 240) -> str: + lowered = text.lower() + position = 0 + for term in terms: + found = lowered.find(term) + if found >= 0: + position = found + break + start = max(position - 60, 0) + end = min(start + length, len(text)) + return text[start:end].strip() + + @staticmethod + def _is_allowed_path(relative_path: str, config: KnowledgeConfig) -> bool: + normalized = relative_path.strip("/") + if any( + fnmatch.fnmatch(normalized, pattern) + for pattern in config.index.exclude_globs + ): + return False + if not 
config.index.include_globs: + return True + return any( + fnmatch.fnmatch(normalized, pattern) + or fnmatch.fnmatch(f"./{normalized}", pattern) + for pattern in config.index.include_globs + ) + + @staticmethod + def _safe_name(value: str) -> str: + safe = re.sub(r"[^A-Za-z0-9._-]+", "-", value).strip("-.") + return safe or "knowledge" + + @staticmethod + def _session_filename(session_id: str, user_id: str) -> str: + safe_sid = _sanitize_filename(session_id) + safe_uid = _sanitize_filename(user_id) if user_id else "" + if safe_uid: + return f"{safe_uid}_{safe_sid}.json" + return f"{safe_sid}.json" \ No newline at end of file diff --git a/tests/unit/app/crons/test_manager.py b/tests/unit/app/crons/test_manager.py new file mode 100644 index 000000000..88f00167b --- /dev/null +++ b/tests/unit/app/crons/test_manager.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from copaw.app.crons.manager import CronManager +from copaw.app.crons.models import CronJobSpec, JobsFile +from copaw.app.crons.repo.base import BaseJobRepository + + +class _MemoryRepo(BaseJobRepository): + def __init__(self, jobs_file: JobsFile): + self._jobs_file = jobs_file + + async def load(self) -> JobsFile: + return self._jobs_file + + async def save(self, jobs_file: JobsFile) -> None: + self._jobs_file = jobs_file + + +async def test_start_skips_invalid_persisted_cron_job() -> None: + invalid_job = CronJobSpec.model_validate( + { + "id": "bad-job", + "name": "Bad Job", + "enabled": True, + "schedule": {"type": "cron", "cron": "0 */30 * * *", "timezone": "UTC"}, + "task_type": "text", + "text": "hello", + "dispatch": { + "type": "channel", + "channel": "console", + "target": {"user_id": "u1", "session_id": "s1"}, + }, + } + ) + valid_job = CronJobSpec.model_validate( + { + "id": "good-job", + "name": "Good Job", + "enabled": True, + "schedule": {"type": "cron", "cron": "0 * * * *", "timezone": "UTC"}, + "task_type": "text", + "text": "hello", + "dispatch": { + "type": 
"channel", + "channel": "console", + "target": {"user_id": "u1", "session_id": "s1"}, + }, + } + ) + repo = _MemoryRepo(JobsFile(version=1, jobs=[invalid_job, valid_job])) + manager = CronManager(repo=repo, runner=object(), channel_manager=object()) + + await manager.start() + + bad_state = manager.get_state("bad-job") + good_state = manager.get_state("good-job") + assert bad_state.last_status == "error" + assert bad_state.last_error is not None + assert good_state.next_run_at is not None + + await manager.stop() diff --git a/tests/unit/app/routers/test_knowledge.py b/tests/unit/app/routers/test_knowledge.py new file mode 100644 index 000000000..985b76df7 --- /dev/null +++ b/tests/unit/app/routers/test_knowledge.py @@ -0,0 +1,227 @@ +# -*- coding: utf-8 -*- + +from pathlib import Path + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from copaw.app.routers import knowledge as knowledge_router_module +from copaw.config.config import Config +from copaw.knowledge import GraphOpsManager, KnowledgeManager + + +@pytest.fixture +def knowledge_api_client(tmp_path: Path, monkeypatch) -> TestClient: + config = Config() + state = {"config": config} + + def fake_load_config(): + return state["config"] + + def fake_save_config(new_config): + state["config"] = new_config + + monkeypatch.setattr(knowledge_router_module, "load_config", fake_load_config) + monkeypatch.setattr(knowledge_router_module, "save_config", fake_save_config) + monkeypatch.setattr(knowledge_router_module, "WORKING_DIR", tmp_path) + + app = FastAPI() + app.include_router(knowledge_router_module.router) + return TestClient(app) + + +def test_upsert_source_auto_generates_description_when_empty( + knowledge_api_client: TestClient, +): + response = knowledge_api_client.put( + "/knowledge/sources", + json={ + "id": "text-auto-description", + "name": "Manual Name", + "type": "text", + "content": "Quarterly planning checklist and milestone review for the release train.", + 
"enabled": True, + "recursive": False, + "tags": [], + "description": "", + }, + ) + + assert response.status_code == 200 + generated = response.json()["description"] + assert generated + + +def test_upsert_source_title_prefers_description_over_content( + knowledge_api_client: TestClient, +): + response = knowledge_api_client.put( + "/knowledge/sources", + json={ + "id": "text-title-from-description", + "name": "Manual Name", + "type": "text", + "content": "Very long internal content that should not be the direct title source.", + "enabled": True, + "recursive": False, + "tags": [], + "description": "Release checklist summary for sprint handoff", + }, + ) + + assert response.status_code == 200 + assert response.json()["name"] == "Release checklist summary for sprint handoff" + + +def test_list_sources_uses_auto_generated_titles_for_existing_config( + knowledge_api_client: TestClient, +): + put_response = knowledge_api_client.put( + "/knowledge/sources", + json={ + "id": "url-auto-title", + "name": "Old URL Name", + "type": "url", + "location": "https://example.com/path/to/guide.html", + "enabled": True, + "recursive": False, + "tags": ["remote"], + "description": "", + }, + ) + assert put_response.status_code == 200 + + listing = knowledge_api_client.get("/knowledge/sources") + assert listing.status_code == 200 + source = listing.json()["sources"][0] + assert source["name"] == "example.com/guide.html" + + +def test_history_backfill_status_includes_progress( + knowledge_api_client: TestClient, +): + response = knowledge_api_client.get("/knowledge/history-backfill/status") + assert response.status_code == 200 + payload = response.json() + assert "progress" in payload + assert isinstance(payload["progress"], dict) + assert "running" in payload["progress"] + + +def test_history_backfill_progress_ws_snapshot( + knowledge_api_client: TestClient, +): + with knowledge_api_client.websocket_connect( + "/knowledge/history-backfill/progress/ws?interval_ms=300", + ) as ws: + 
payload = ws.receive_json() + + assert payload["type"] == "snapshot" + assert isinstance(payload["progress"], dict) + assert "running" in payload["progress"] + + +def test_clear_knowledge_requires_confirmation(knowledge_api_client: TestClient): + response = knowledge_api_client.delete("/knowledge/clear") + assert response.status_code == 400 + assert response.json()["detail"] == "KNOWLEDGE_CLEAR_CONFIRM_REQUIRED" + + +def test_clear_knowledge_removes_sources_and_indexes( + knowledge_api_client: TestClient, + tmp_path: Path, +): + config_payload = Config().knowledge.model_dump(mode="json") + config_payload["sources"] = [ + { + "id": "clear-1", + "name": "to-clear", + "type": "text", + "location": "", + "content": "clear me", + "enabled": True, + "recursive": False, + "tags": [], + "description": "", + } + ] + saved = knowledge_api_client.put("/knowledge/config", json=config_payload) + assert saved.status_code == 200 + + index_root = tmp_path / "knowledge" / "indexes" + index_root.mkdir(parents=True, exist_ok=True) + (index_root / "clear-1.json").write_text("{}", encoding="utf-8") + + response = knowledge_api_client.delete( + "/knowledge/clear?confirm=true&remove_sources=true" + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["cleared"] is True + assert payload["removed_source_configs"] is True + assert payload["cleared_sources"] == 1 + assert payload["cleared_indexes"] >= 1 + + +def test_read_url_document_skips_binary_content(monkeypatch): + class _Resp: + headers = {"content-type": "image/png"} + text = "binary" + + @staticmethod + def raise_for_status(): + return None + + monkeypatch.setattr("copaw.knowledge.manager.httpx.get", lambda *args, **kwargs: _Resp()) + + doc = KnowledgeManager._read_url_document("https://example.com/a.png") + + assert doc["path"] == "https://example.com/a.png" + assert doc["text"] == "" + + +def test_get_memify_job_status_requires_memify_enabled( + knowledge_api_client: TestClient, +): + 
config_payload = Config().knowledge.model_dump(mode="json") + config_payload["enabled"] = True + config_payload["memify_enabled"] = False + saved = knowledge_api_client.put("/knowledge/config", json=config_payload) + assert saved.status_code == 200 + + response = knowledge_api_client.get("/knowledge/memify/jobs/job-1") + assert response.status_code == 400 + assert response.json()["detail"] == "MEMIFY_DISABLED" + + +def test_get_memify_job_status_success( + knowledge_api_client: TestClient, + tmp_path: Path, +): + knowledge_config = Config().knowledge + config_payload = knowledge_config.model_dump(mode="json") + config_payload["enabled"] = True + config_payload["memify_enabled"] = True + saved = knowledge_api_client.put("/knowledge/config", json=config_payload) + assert saved.status_code == 200 + + knowledge_config.enabled = True + knowledge_config.memify_enabled = True + + graph_ops = GraphOpsManager(tmp_path) + job = graph_ops.run_memify( + config=knowledge_config, + pipeline_type="default", + dataset_scope=[], + idempotency_key="route-status-job", + dry_run=False, + ) + job_id = job["job_id"] + + response = knowledge_api_client.get(f"/knowledge/memify/jobs/{job_id}") + assert response.status_code == 200 + payload = response.json() + assert payload["job_id"] == job_id + assert payload["status"] in {"succeeded", "failed"} diff --git a/tests/unit/app/runner/test_knowledge_context_injection.py b/tests/unit/app/runner/test_knowledge_context_injection.py new file mode 100644 index 000000000..6353576f9 --- /dev/null +++ b/tests/unit/app/runner/test_knowledge_context_injection.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, AsyncIterator, cast + +from agentscope.message import Msg, TextBlock +from agentscope_runtime.engine.schemas.agent_schemas import AgentRequest + +from copaw.app.runner.runner import AgentRunner +from copaw.app.runner.session import SafeJSONSession + + 
+class _DummyAgent: + captured_input_msgs = None + + def __init__(self, *args, **kwargs) -> None: + _ = args, kwargs + + async def register_mcp_clients(self) -> None: + return None + + def set_console_output_enabled(self, enabled: bool) -> None: + _ = enabled + + def rebuild_sys_prompt(self) -> None: + return None + + def __call__(self, msgs): + type(self).captured_input_msgs = msgs + return object() + + +class _DummySession(SafeJSONSession): + def __init__(self) -> None: + super().__init__(save_dir=".") + + async def load_session_state( + self, + session_id: str, + user_id: str = "", + allow_not_exist: bool = True, + **state_modules_mapping, + ) -> None: + _ = session_id, user_id, allow_not_exist, state_modules_mapping + + async def save_session_state( + self, + session_id: str, + user_id: str = "", + **state_modules_mapping, + ) -> None: + _ = session_id, user_id, state_modules_mapping + + +async def test_query_handler_does_not_inject_knowledge_context(monkeypatch) -> None: + from copaw.app.runner import runner as runner_module + + async def _no_approval(session_id: str, query: str | None): + _ = session_id, query + return None, False + + async def _stream_printing_messages(*args, **kwargs): + _ = args, kwargs + yield ( + Msg( + name="assistant", + role="assistant", + content=[TextBlock(type="text", text="ok")], + ), + True, + ) + + runner = AgentRunner() + runner.session = _DummySession() + cast(Any, runner)._resolve_pending_approval = _no_approval + + monkeypatch.setattr(runner_module, "CoPawAgent", _DummyAgent) + monkeypatch.setattr(runner_module, "build_env_context", lambda **kwargs: kwargs) + monkeypatch.setattr( + runner_module, + "load_config", + lambda: SimpleNamespace( + agents=SimpleNamespace( + running=SimpleNamespace( + max_iters=8, + max_input_length=8192, + auto_collect_chat_files=False, + auto_collect_chat_urls=False, + auto_collect_long_text=False, + ), + ), + knowledge=SimpleNamespace(enabled=True), + ), + ) + monkeypatch.setattr( + 
runner_module, + "stream_printing_messages", + _stream_printing_messages, + ) + + msgs = [ + Msg( + name="user", + role="user", + content=[TextBlock(type="text", text="如何接入知识库")], + ), + ] + request = cast( + AgentRequest, + SimpleNamespace( + session_id="session-1", + user_id="user-1", + channel="console", + ), + ) + + stream = cast( + AsyncIterator[tuple[Msg, bool]], + cast(Any, runner).query_handler(msgs, request=request), + ) + async for _msg, _last in stream: + pass + + captured = cast(list[Msg], _DummyAgent.captured_input_msgs) + assert captured is not None + assert len(captured) == 1 + assert captured[0].role == "user" + + +async def test_query_handler_skips_knowledge_context_when_disabled(monkeypatch) -> None: + from copaw.app.runner import runner as runner_module + + async def _no_approval(session_id: str, query: str | None): + _ = session_id, query + return None, False + + async def _stream_printing_messages(*args, **kwargs): + _ = args, kwargs + yield ( + Msg( + name="assistant", + role="assistant", + content=[TextBlock(type="text", text="ok")], + ), + True, + ) + + runner = AgentRunner() + runner.session = _DummySession() + cast(Any, runner)._resolve_pending_approval = _no_approval + + monkeypatch.setattr(runner_module, "CoPawAgent", _DummyAgent) + monkeypatch.setattr(runner_module, "build_env_context", lambda **kwargs: kwargs) + monkeypatch.setattr( + runner_module, + "load_config", + lambda: SimpleNamespace( + agents=SimpleNamespace( + running=SimpleNamespace( + max_iters=8, + max_input_length=8192, + auto_collect_chat_files=False, + auto_collect_long_text=False, + knowledge_retrieval_enabled=False, + ), + ), + knowledge=SimpleNamespace(enabled=True), + ), + ) + monkeypatch.setattr( + runner_module, + "stream_printing_messages", + _stream_printing_messages, + ) + + msgs = [ + Msg( + name="user", + role="user", + content=[TextBlock(type="text", text="如何接入知识库")], + ), + ] + request = cast( + AgentRequest, + SimpleNamespace( + session_id="session-2", + 
user_id="user-1", + channel="console", + ), + ) + + stream = cast( + AsyncIterator[tuple[Msg, bool]], + cast(Any, runner).query_handler(msgs, request=request), + ) + async for _msg, _last in stream: + pass + + captured = cast(list[Msg], _DummyAgent.captured_input_msgs) + assert captured is not None + assert len(captured) == 1 + assert captured[0].role == "user" diff --git a/tests/unit/app/test_graph_tools.py b/tests/unit/app/test_graph_tools.py new file mode 100644 index 000000000..27f7b2a8f --- /dev/null +++ b/tests/unit/app/test_graph_tools.py @@ -0,0 +1,257 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import importlib +import json +from types import SimpleNamespace + +from copaw.config.config import Config, KnowledgeSourceSpec +from copaw.knowledge.manager import KnowledgeManager + + +async def test_graph_query_requires_graph_enabled(monkeypatch) -> None: + module = importlib.import_module("copaw.agents.tools.graph_query") + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace(enabled=True, graph_query_enabled=False), + ), + ) + + result = await module.graph_query("find relation") + text = result.content[0]["text"] + assert "graph query is disabled" in text + + +async def test_graph_query_formats_payload(monkeypatch) -> None: + module = importlib.import_module("copaw.agents.tools.graph_query") + + class _FakeGraphOpsManager: + def __init__(self, working_dir) -> None: + _ = working_dir + + def graph_query(self, **kwargs): + _ = kwargs + return SimpleNamespace( + records=[{"subject": "A", "predicate": "rel", "object": "B"}], + summary="ok", + provenance={"engine": "local_lexical"}, + warnings=[], + ) + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace( + enabled=True, + graph_query_enabled=True, + allow_cypher_query=False, + ), + ), + ) + monkeypatch.setattr(module, "GraphOpsManager", _FakeGraphOpsManager) + + result = await 
module.graph_query("find relation", query_mode="template") + payload = json.loads(result.content[0]["text"]) + assert payload["summary"] == "ok" + assert payload["records"][0]["subject"] == "A" + + +async def test_memify_run_requires_memify_enabled(monkeypatch) -> None: + module = importlib.import_module("copaw.agents.tools.memify_run") + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace(enabled=True, memify_enabled=False), + ), + ) + result = await module.memify_run() + text = result.content[0]["text"] + assert "memify is disabled" in text + + +async def test_memify_run_returns_job_payload(monkeypatch) -> None: + module = importlib.import_module("copaw.agents.tools.memify_run") + + class _FakeGraphOpsManager: + def __init__(self, working_dir) -> None: + _ = working_dir + + def run_memify(self, **kwargs): + _ = kwargs + return { + "accepted": True, + "job_id": "job123", + "estimated_steps": 1, + "status_url": "/knowledge/memify/jobs/job123", + } + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace(enabled=True, memify_enabled=True), + ), + ) + monkeypatch.setattr(module, "GraphOpsManager", _FakeGraphOpsManager) + + result = await module.memify_run(pipeline_type="default") + payload = json.loads(result.content[0]["text"]) + assert payload["accepted"] is True + assert payload["job_id"] == "job123" + + +async def test_memify_status_handles_not_found(monkeypatch) -> None: + module = importlib.import_module("copaw.agents.tools.memify_status") + + class _FakeGraphOpsManager: + def __init__(self, working_dir) -> None: + _ = working_dir + + def get_memify_status(self, job_id: str): + _ = job_id + return None + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace(enabled=True, memify_enabled=True), + ), + ) + monkeypatch.setattr(module, "GraphOpsManager", _FakeGraphOpsManager) + + result = await 
module.memify_status("missing-job") + text = result.content[0]["text"] + assert "memify job not found" in text + + +async def test_triplet_focus_search_requires_enabled(monkeypatch) -> None: + module = importlib.import_module("copaw.agents.tools.triplet_focus_search") + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace(enabled=True, triplet_search_enabled=False), + ), + ) + + result = await module.triplet_focus_search(query_text="entity relation") + text = result.content[0]["text"] + assert "triplet-focused search is disabled" in text + + +async def test_triplet_focus_search_formats_payload(monkeypatch) -> None: + module = importlib.import_module("copaw.agents.tools.triplet_focus_search") + + class _FakeGraphOpsManager: + def __init__(self, working_dir) -> None: + _ = working_dir + + def graph_query(self, **kwargs): + _ = kwargs + return SimpleNamespace( + records=[ + { + "subject": "Agent", + "predicate": "uses", + "object": "Tool", + "score": 2.0, + "source_id": "s1", + "source_type": "text", + "document_path": "docs/a.md", + "document_title": "A", + } + ], + warnings=[], + ) + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace(enabled=True, triplet_search_enabled=True), + ), + ) + monkeypatch.setattr(module, "GraphOpsManager", _FakeGraphOpsManager) + + result = await module.triplet_focus_search(query_text="Agent uses Tool") + payload = json.loads(result.content[0]["text"]) + assert payload["triplets"][0]["subject"] == "Agent" + assert payload["triplets"][0]["predicate"] == "uses" + assert payload["triplets"][0]["object"] == "Tool" + + +async def test_graph_tool_chain_smoke_local_engine( + monkeypatch, + tmp_path, +) -> None: + graph_query_module = importlib.import_module("copaw.agents.tools.graph_query") + memify_run_module = importlib.import_module("copaw.agents.tools.memify_run") + memify_status_module = 
importlib.import_module("copaw.agents.tools.memify_status") + triplet_module = importlib.import_module("copaw.agents.tools.triplet_focus_search") + + knowledge_config = Config().knowledge + knowledge_config.enabled = True + knowledge_config.graph_query_enabled = True + knowledge_config.triplet_search_enabled = True + knowledge_config.memify_enabled = True + + manager = KnowledgeManager(tmp_path) + source = KnowledgeSourceSpec( + id="smoke-text-source", + name="Smoke Source", + type="text", + content="Agent uses tool for graph data processing.", + enabled=True, + recursive=False, + tags=["smoke"], + description="", + ) + knowledge_config.sources.append(source) + manager.index_source( + source, + knowledge_config, + SimpleNamespace(knowledge_chunk_size=knowledge_config.index.chunk_size), + ) + + load_config_stub = lambda: SimpleNamespace(knowledge=knowledge_config) + for module in [ + graph_query_module, + memify_run_module, + memify_status_module, + triplet_module, + ]: + monkeypatch.setattr(module, "load_config", load_config_stub) + monkeypatch.setattr(module, "WORKING_DIR", tmp_path) + + memify_result = await memify_run_module.memify_run( + pipeline_type="default", + idempotency_key="graph-tool-chain-smoke", + ) + memify_payload = json.loads(memify_result.content[0]["text"]) + assert memify_payload["accepted"] is True + + status_result = await memify_status_module.memify_status(memify_payload["job_id"]) + status_payload = json.loads(status_result.content[0]["text"]) + assert status_payload["job_id"] == memify_payload["job_id"] + assert status_payload["status"] == "succeeded" + + graph_result = await graph_query_module.graph_query( + query_text="Agent uses tool", + query_mode="template", + ) + graph_payload = json.loads(graph_result.content[0]["text"]) + assert len(graph_payload["records"]) >= 1 + + triplet_result = await triplet_module.triplet_focus_search( + query_text="Agent uses tool", + ) + triplet_payload = json.loads(triplet_result.content[0]["text"]) + 
assert isinstance(triplet_payload["triplets"], list) \ No newline at end of file diff --git a/tests/unit/app/test_knowledge_search_tool.py b/tests/unit/app/test_knowledge_search_tool.py new file mode 100644 index 000000000..c66a19dce --- /dev/null +++ b/tests/unit/app/test_knowledge_search_tool.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from types import SimpleNamespace +import importlib + +from copaw.agents.tools.knowledge_search import knowledge_search + + +async def test_knowledge_search_rejects_empty_query() -> None: + result = await knowledge_search(" ") + text = result.content[0]["text"] + assert "query cannot be empty" in text + + +async def test_knowledge_search_returns_disabled_message(monkeypatch) -> None: + module = importlib.import_module("copaw.agents.tools.knowledge_search") + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace(enabled=False), + agents=SimpleNamespace( + running=SimpleNamespace(knowledge_retrieval_enabled=True), + ), + ), + ) + + result = await module.knowledge_search("how to index docs") + text = result.content[0]["text"] + assert "Knowledge is disabled" in text + + +async def test_knowledge_search_returns_runtime_disabled_message( + monkeypatch, +) -> None: + module = importlib.import_module("copaw.agents.tools.knowledge_search") + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace(enabled=True), + agents=SimpleNamespace( + running=SimpleNamespace(knowledge_retrieval_enabled=False), + ), + ), + ) + + result = await module.knowledge_search("how to index docs") + text = result.content[0]["text"] + assert "Knowledge retrieval is disabled" in text + + +async def test_knowledge_search_formats_hits(monkeypatch) -> None: + module = importlib.import_module("copaw.agents.tools.knowledge_search") + + class _FakeManager: + def __init__(self, working_dir) -> None: + _ = working_dir + + def search( 
+ self, + query: str, + config, + limit: int = 10, + source_ids=None, + source_types=None, + ): + _ = config, source_ids + assert query == "knowledge index" + assert limit == 2 + assert source_types == ["file"] + return { + "query": query, + "hits": [ + { + "source_name": "Project Docs", + "source_type": "file", + "document_title": "Index Guide", + "document_path": "docs/index.md", + "score": 2.5, + "snippet": "Run index after adding sources.", + }, + { + "source_name": "Low Score", + "source_type": "file", + "document_title": "Ignore", + "document_path": "docs/low.md", + "score": 0.2, + "snippet": "too low", + }, + ], + } + + monkeypatch.setattr( + module, + "load_config", + lambda: SimpleNamespace( + knowledge=SimpleNamespace(enabled=True), + agents=SimpleNamespace( + running=SimpleNamespace(knowledge_retrieval_enabled=True), + ), + ), + ) + monkeypatch.setattr(module, "KnowledgeManager", _FakeManager) + + result = await module.knowledge_search( + query="knowledge index", + max_results=2, + min_score=1.0, + source_types=["file"], + ) + text = result.content[0]["text"] + assert "Knowledge search results for: knowledge index" in text + assert "[1] Project Docs (file) score=2.50" in text + assert "title: Index Guide" in text + assert "path: docs/index.md" in text + assert "snippet: Run index after adding sources." 
in text + assert "Low Score" not in text diff --git a/tests/unit/config/test_config_utils.py b/tests/unit/config/test_config_utils.py new file mode 100644 index 000000000..15fcd9471 --- /dev/null +++ b/tests/unit/config/test_config_utils.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +import json + +from copaw.config.utils import load_config + + +def test_load_config_migrates_legacy_knowledge_engine_object(tmp_path): + config_path = tmp_path / "config.json" + payload = { + "knowledge": { + "enabled": True, + "engine": { + "provider": "default", + "fallback_to_default": True, + }, + } + } + config_path.write_text( + json.dumps(payload, ensure_ascii=False), + encoding="utf-8", + ) + + config = load_config(config_path) + + assert config.knowledge.engine == "local_lexical" From eb154577f1f53d4ebaeb38ccdf284ea177dc80fb Mon Sep 17 00:00:00 2001 From: Future Meng Date: Tue, 17 Mar 2026 15:50:12 +0800 Subject: [PATCH 57/68] feat(console): refine knowledge note cards --- console/src/api/modules/knowledge.ts | 55 ++ console/src/api/types/knowledge.ts | 6 + console/src/locales/en.json | 17 +- console/src/locales/zh.json | 17 +- .../pages/Agent/Knowledge/index.module.less | 632 +++++++++------ console/src/pages/Agent/Knowledge/index.tsx | 748 ++++++++++++------ src/copaw/app/routers/knowledge.py | 151 +++- src/copaw/knowledge/manager.py | 245 +++++- tests/unit/app/routers/test_knowledge.py | 143 +++- 9 files changed, 1495 insertions(+), 519 deletions(-) diff --git a/console/src/api/modules/knowledge.ts b/console/src/api/modules/knowledge.ts index 42c511932..65bc64836 100644 --- a/console/src/api/modules/knowledge.ts +++ b/console/src/api/modules/knowledge.ts @@ -5,6 +5,7 @@ import type { KnowledgeConfig, KnowledgeHistoryBackfillRunResponse, KnowledgeHistoryBackfillStatus, + KnowledgeRestoreResponse, KnowledgeIndexResult, KnowledgeClearResponse, KnowledgeSearchResponse, @@ -141,4 +142,58 @@ export const knowledgeApi = { `/knowledge/search?${searchParams.toString()}`, ); }, + + 
downloadKnowledgeBackup: async (): Promise => { + const response = await fetch(getApiUrl("/knowledge/backup"), { + method: "GET", + }); + if (!response.ok) { + throw new Error( + `Knowledge backup failed: ${response.status} ${response.statusText}`, + ); + } + return await response.blob(); + }, + + downloadKnowledgeSourceBackup: async (sourceId: string): Promise => { + const response = await fetch( + getApiUrl(`/knowledge/backup/${encodeURIComponent(sourceId)}`), + { + method: "GET", + }, + ); + if (!response.ok) { + throw new Error( + `Knowledge source backup failed: ${response.status} ${response.statusText}`, + ); + } + return await response.blob(); + }, + + restoreKnowledgeBackup: async ( + file: File, + replaceExisting = true, + ): Promise => { + const formData = new FormData(); + formData.append("file", file); + + const response = await fetch( + getApiUrl( + `/knowledge/restore?replace_existing=${replaceExisting ? "true" : "false"}`, + ), + { + method: "POST", + body: formData, + }, + ); + if (!response.ok) { + const text = await response.text().catch(() => ""); + throw new Error( + `Knowledge restore failed: ${response.status} ${response.statusText}${ + text ? 
` - ${text}` : "" + }`, + ); + } + return (await response.json()) as KnowledgeRestoreResponse; + }, }; \ No newline at end of file diff --git a/console/src/api/types/knowledge.ts b/console/src/api/types/knowledge.ts index c7497d807..4f3c71071 100644 --- a/console/src/api/types/knowledge.ts +++ b/console/src/api/types/knowledge.ts @@ -144,4 +144,10 @@ export interface KnowledgeClearResponse { cleared_indexes: number; cleared_sources: number; removed_source_configs: boolean; +} + +export interface KnowledgeRestoreResponse { + success: boolean; + replace_existing: boolean; + restored_sources: number; } \ No newline at end of file diff --git a/console/src/locales/en.json b/console/src/locales/en.json index ae6913aed..f96f4cdbb 100644 --- a/console/src/locales/en.json +++ b/console/src/locales/en.json @@ -94,13 +94,19 @@ "backfillFailed": "Failed to run history backfill", "goToRuntimeConfig": "Go to Runtime Config", "automationTitle": "Automation", + "noteStyle": "Note style", + "noteStyleNotion": "Notion", + "noteStyleObsidian": "Obsidian", "enabled": "Enabled", "addSource": "Add Source", - "indexAll": "Index All", + "indexAll": "Rebuild Index", + "backupAll": "Backup", + "restore": "Restore", + "backupSource": "Backup", "indexNow": "Index Now", "reindex": "Reindex", "indexSuccess": "Knowledge source indexed", - "indexAllSuccess": "Indexed all enabled knowledge sources", + "indexAllSuccess": "Rebuilt index for all enabled knowledge sources", "indexFailed": "Failed to index knowledge sources", "clearKnowledge": "Clear Knowledge", "clearConfirmTitle": "Confirm Clear Knowledge", @@ -108,7 +114,12 @@ "clearConfirmOk": "Clear", "clearSuccess": "Knowledge cleared: removed {{sources}} source configs and {{indexes}} index files", "clearFailed": "Failed to clear knowledge", - "unifiedProgressIndexAll": "Batch indexing in progress", + "backupAllSuccess": "Knowledge backup completed", + "backupSourceSuccess": "Knowledge source backup completed", + "backupFailed": "Failed to 
backup knowledge", + "restoreSuccess": "Restore completed, restored {{count}} sources", + "restoreFailed": "Failed to restore backup", + "unifiedProgressIndexAll": "Rebuilding index in progress", "unifiedProgressBackfill": "History backfill: traversed {{traversed}} / {{total}}", "unifiedProgressBackfillStarting": "Preparing history backfill", "unifiedProgressClearing": "Knowledge clearing in progress", diff --git a/console/src/locales/zh.json b/console/src/locales/zh.json index c24c16d7a..953318801 100644 --- a/console/src/locales/zh.json +++ b/console/src/locales/zh.json @@ -94,13 +94,19 @@ "backfillFailed": "历史追溯执行失败", "goToRuntimeConfig": "前往运行配置", "automationTitle": "自动沉淀", + "noteStyle": "笔记风格", + "noteStyleNotion": "Notion", + "noteStyleObsidian": "Obsidian", "enabled": "启用", "addSource": "新增来源", - "indexAll": "全部索引", + "indexAll": "重建索引", + "backupAll": "备份", + "restore": "恢复", + "backupSource": "备份", "indexNow": "立即索引", "reindex": "重新索引", "indexSuccess": "知识源索引完成", - "indexAllSuccess": "已完成全部启用知识源的索引", + "indexAllSuccess": "已完成全部启用知识源的重建索引", "indexFailed": "知识源索引失败", "clearKnowledge": "清空知识库", "clearConfirmTitle": "确认清空知识库", @@ -108,7 +114,12 @@ "clearConfirmOk": "确认清空", "clearSuccess": "知识库已清空:删除 {{sources}} 个来源配置,清理 {{indexes}} 个索引文件", "clearFailed": "清空知识库失败", - "unifiedProgressIndexAll": "批量索引执行中", + "backupAllSuccess": "知识库备份完成", + "backupSourceSuccess": "知识源备份完成", + "backupFailed": "知识库备份失败", + "restoreSuccess": "恢复完成,共恢复 {{count}} 个知识源", + "restoreFailed": "恢复失败", + "unifiedProgressIndexAll": "重建索引执行中", "unifiedProgressBackfill": "历史对话追溯中:已遍历 {{traversed}} / {{total}}", "unifiedProgressBackfillStarting": "历史对话追溯准备中", "unifiedProgressClearing": "知识库清空执行中", diff --git a/console/src/pages/Agent/Knowledge/index.module.less b/console/src/pages/Agent/Knowledge/index.module.less index 5c7455fb2..391787548 100644 --- a/console/src/pages/Agent/Knowledge/index.module.less +++ b/console/src/pages/Agent/Knowledge/index.module.less @@ -1,5 +1,53 @@ .knowledgePage 
{ + --note-page-bg: radial-gradient(circle at top left, #f9fbff 0%, #fefefe 40%, #f6f7fb 100%); + --note-card-bg: #ffffff; + --note-card-border: rgba(26, 31, 44, 0.08); + --note-card-shadow: 0 10px 28px rgba(17, 24, 39, 0.06); + --note-hover-border: #4c6fff; + --note-hover-shadow: 0 16px 36px rgba(76, 111, 255, 0.2); + --note-block-bg: #f8f9fd; + --note-block-border: #e8ecf5; + --note-title-color: #1f2430; + --note-body-color: #4f566b; + --note-label-color: #7f8699; + --note-font: "Avenir Next", "Segoe UI", sans-serif; + --note-title-font: "Iowan Old Style", "Georgia", serif; + --note-radius: 18px; padding: 24px; + background: var(--note-page-bg); + border-radius: 18px; + font-family: var(--note-font); +} + +.noteStyleNotion { + --note-page-bg: linear-gradient(160deg, #f8fafc 0%, #fdfefe 55%, #f3f7ff 100%); + --note-card-bg: #ffffff; + --note-card-border: rgba(17, 24, 39, 0.09); + --note-card-shadow: 0 10px 26px rgba(15, 23, 42, 0.07); + --note-hover-border: #4c6fff; + --note-hover-shadow: 0 16px 34px rgba(76, 111, 255, 0.2); + --note-block-bg: #f8f9fd; + --note-block-border: #e8ecf5; +} + +.noteStyleObsidian { + --note-page-bg: radial-gradient(circle at top right, #202734 0%, #1a202b 60%, #151b25 100%); + --note-card-bg: #1f2633; + --note-card-border: #2d3748; + --note-card-shadow: 0 12px 28px rgba(0, 0, 0, 0.3); + --note-hover-border: #63b3ed; + --note-hover-shadow: 0 18px 36px rgba(49, 130, 206, 0.26); + --note-block-bg: #273042; + --note-block-border: #344259; + --note-title-color: #edf2f7; + --note-body-color: #cbd5e1; + --note-label-color: #93a4ba; + --note-font: "IBM Plex Sans", "Segoe UI", sans-serif; + --note-title-font: "IBM Plex Serif", "Georgia", serif; +} + +.knowledgeListThemeScope { + width: 100%; } .header { @@ -16,14 +64,14 @@ } .title { - margin-bottom: 4px !important; - font-size: 24px !important; - font-weight: 600 !important; + margin-bottom: 4px; + font-size: 24px; + font-weight: 600; } .description { - margin: 0 !important; - color: #999; + 
margin: 0; + color: var(--note-label-color); font-size: 14px; } @@ -32,6 +80,37 @@ align-items: center; gap: 8px; flex-wrap: wrap; + justify-content: flex-end; +} + +.headerControlGroup, +.headerButtonGroup { + display: flex; + align-items: center; + gap: 8px; + flex-wrap: wrap; +} + +.headerButtonGroup { + justify-content: flex-end; +} + +.noteStyleSegment { + min-width: 220px; +} + +.noteStyleLabel { + color: var(--note-label-color); +} + +.noteStyleOptionLabel { + display: inline-flex; + align-items: center; + gap: 6px; +} + +.noteStyleOptionText { + white-space: nowrap; } .unifiedProgressRow { @@ -49,169 +128,149 @@ width: 100%; } -.rowBetween { +.fullWidth { width: 100%; - justify-content: space-between; } -.queueBanner { +.contentLoadingWrap { display: flex; - align-items: center; - justify-content: space-between; - gap: 12px; - padding: 12px 16px; - border-radius: 10px; - background: #e6f4ff; - border: 1px solid #91caff; + justify-content: center; + padding: 16px 0; } -.queueBannerWaiting { - background: #fffbe6; - border-color: #ffe58f; +.contentEmpty { + font-size: 12px; + color: var(--note-label-color); } -.queueBannerDone { - background: #f6ffed; - border-color: #b7eb8f; +.documentList { + display: flex; + flex-direction: column; + gap: 12px; } -.queueBannerLeft { +.documentItem { display: flex; - align-items: center; - gap: 10px; + flex-direction: column; + gap: 8px; } -.queueBannerIcon { - font-size: 18px; - color: #1677ff; +.documentTitle { + font-size: 12px; + color: var(--note-body-color); +} - .queueBannerWaiting & { - color: #faad14; - } +.documentMarkdown { + border: 1px solid var(--note-block-border); + border-radius: 10px; + background: var(--note-block-bg); + padding: 8px; +} - .queueBannerDone & { - color: #52c41a; - } +.documentMarkdownViewer { + border: none; + background: transparent; } -.queueBannerText { - display: flex; - flex-direction: column; - gap: 2px; +.documentMarkdownViewer :global(table) { + width: 100%; + border-collapse: 
collapse; + margin: 12px 0; } -.queueBannerSub { - color: #667085; - font-size: 12px; +.documentMarkdownViewer :global(th), +.documentMarkdownViewer :global(td) { + border: 1px solid var(--note-block-border); + padding: 10px 12px; + text-align: left; + vertical-align: top; } -.queueBannerMeta { - display: flex; - align-items: center; - flex-wrap: wrap; - gap: 4px; - margin-top: 4px; +.documentMarkdownViewer :global(th) { + background: rgba(76, 111, 255, 0.08); + color: var(--note-title-color); + font-weight: 600; } -.queueBannerMetaAction { - display: inline-flex; - align-items: center; - flex-wrap: wrap; - gap: 6px; +.documentMarkdownViewer :global(code) { + background: rgba(76, 111, 255, 0.08); + color: #274690; + border-radius: 6px; + padding: 0.15em 0.4em; } -.queueBannerActionButton { - padding: 0; - height: auto; - line-height: 1.2; +.documentMarkdownViewer :global(pre) { + margin: 12px 0; + border-radius: 12px; + border: 1px solid var(--note-block-border); + background: rgba(248, 249, 253, 0.92); + overflow: auto; } -.queueLogsBox { - font-family: "SF Mono", "Menlo", monospace; - font-size: 12px; +.documentMarkdownViewer :global(pre code) { + display: block; + background: transparent; + color: var(--note-body-color); + padding: 14px 16px; } -.policyCard { - display: flex; - flex-direction: column; - gap: 12px; +.documentMarkdownViewer :global(blockquote) { + margin: 12px 0; + padding: 10px 14px; + border-left: 3px solid rgba(76, 111, 255, 0.35); + background: rgba(76, 111, 255, 0.06); + color: var(--note-body-color); } -.policyHeader { - display: flex; - align-items: flex-start; - justify-content: space-between; - gap: 12px; +.documentMarkdownViewer :global(a) { + color: #2f67d8; } -.policyDescription { - margin: 4px 0 0 !important; - color: #667085; +.noteStyleObsidian .documentMarkdownViewer :global(table) { + background: #1b2432; } -.policyGrid { - display: grid; - grid-template-columns: repeat(2, minmax(0, 1fr)); - gap: 10px 12px; +.noteStyleObsidian 
.documentMarkdownViewer :global(th), +.noteStyleObsidian .documentMarkdownViewer :global(td) { + border-color: #42526b; } -.policyItem { - display: flex; - align-items: center; - justify-content: space-between; - gap: 8px; - padding: 10px 12px; - border: 1px solid #eceff6; - border-radius: 10px; - background: #fafbff; +.noteStyleObsidian .documentMarkdownViewer :global(th) { + background: #243247; + color: #f8fafc; } -.fullWidth { - width: 100%; +.noteStyleObsidian .documentMarkdownViewer :global(tr:nth-child(even) td) { + background: rgba(59, 130, 246, 0.08); } -.contentLoadingWrap { - display: flex; - justify-content: center; - padding: 16px 0; +.noteStyleObsidian .documentMarkdownViewer :global(tr:nth-child(odd) td) { + background: rgba(15, 23, 42, 0.32); } -.contentEmpty { - font-size: 12px; - color: #999; +.noteStyleObsidian .documentMarkdownViewer :global(code) { + background: #283548; + color: #ffd580; } -.documentList { - display: flex; - flex-direction: column; - gap: 12px; +.noteStyleObsidian .documentMarkdownViewer :global(pre) { + border-color: #42526b; + background: linear-gradient(180deg, #101924 0%, #182334 100%); + box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.04); } -.documentItem { - display: flex; - flex-direction: column; - gap: 4px; +.noteStyleObsidian .documentMarkdownViewer :global(pre code) { + color: #e5edf8; } -.documentTitle { - font-size: 12px; - color: #525866; +.noteStyleObsidian .documentMarkdownViewer :global(blockquote) { + border-left-color: #63b3ed; + background: rgba(99, 179, 237, 0.08); + color: #d7e4f5; } -.documentText { - margin: 0; - font-size: 12px; - line-height: 1.6; - color: #525866; - background: #f5f6fa; - border: 1px solid #eceff6; - border-radius: 8px; - padding: 10px 12px; - white-space: pre-wrap; - word-break: break-word; - max-height: 400px; - overflow-y: auto; - font-family: inherit; +.noteStyleObsidian .documentMarkdownViewer :global(a) { + color: #8ec5ff; } .searchHeader { @@ -225,10 +284,42 @@ width: 180px; 
} +.searchControls { + display: flex; + width: 100%; + gap: 10px; + align-items: center; +} + +.searchCompact { + flex: 1; +} + +.searchButtons { + display: flex; + gap: 8px; + flex-shrink: 0; +} + .searchResultsWrap { margin-top: 16px; } +.searchStatusText { + color: var(--note-label-color); + font-size: 13px; + min-height: 24px; + display: flex; + align-items: center; + gap: 8px; +} + +.searchResultList { + max-height: 320px; + overflow: auto; + padding-right: 4px; +} + .searchHitCard { border-radius: 12px; border: 1px solid #eceff6; @@ -249,111 +340,103 @@ } .searchHitSnippet { - margin: 8px 0 0 !important; + margin: 8px 0 0; color: #475467; } -.sourceToolbar { +.filterBar { width: 100%; margin-bottom: 12px; + display: flex; + align-items: flex-end; + gap: 12px; + flex-wrap: wrap; } -.sourceCards { - display: grid; - grid-template-columns: repeat(auto-fill, minmax(360px, 1fr)); - gap: 14px; - align-items: stretch; +.filterGroup { + display: flex; + flex-direction: column; + gap: 6px; + min-width: 0; + flex: 1 1 240px; } -.copawCard { - height: 100%; +.filterLabel { + color: var(--note-label-color); + font-size: 12px; } -.copawCardBody { - height: 100%; +.filterSelect { + width: 100%; } -.copawSparkCardWrapper { - height: 100%; +.cardsGrid { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); + gap: 24px; + align-items: stretch; } -.copawSparkContent { - background: #fff; - border: 1px solid rgba(0, 0, 0, 0.04); - border-radius: 16px; +.cardContainer { + background: var(--note-card-bg); + border: 1px solid var(--note-card-border); + border-radius: var(--note-radius); transition: all 0.2s ease-in-out; overflow: hidden; height: 100%; display: flex; flex-direction: column; + box-shadow: var(--note-card-shadow); + animation: cardFadeIn 0.35s ease; + padding: 14px 14px 0; &:hover { - border-color: #615ced; - box-shadow: 0 12px 32px rgba(0, 0, 0, 0.08); + border-color: var(--note-hover-border); + box-shadow: var(--note-hover-shadow); } } 
-.copawSparkContentActive { - border-color: #615ced; - box-shadow: - 0 0 0 2px rgba(97, 92, 237, 0.18), - 0 12px 32px rgba(97, 92, 237, 0.18); +.cardContainerActive { + border-color: var(--note-hover-border); + box-shadow: var(--note-hover-shadow); } -.sourceCardBody { +.cardHeader { display: flex; flex-direction: column; - gap: 10px; - min-height: 0; - height: 100%; - flex: 1 1 auto; - padding: 14px 14px 10px; -} - -.sourceCardHeader { - display: flex; - flex-direction: column; - align-items: stretch; gap: 6px; - margin-bottom: 0; } -.sourceOriginTag, -.sourceTypeTag { +.originTag, +.typeTag { font-size: 12px; border-radius: 4px; padding: 1px 6px; white-space: nowrap; line-height: 18px; + flex: 0 0 auto; } -.sourceOriginTag { +.originTag { color: #615ced; background: rgba(97, 92, 237, 0.1); } -.sourceTypeTag { +.typeTag { color: #1677ff; background: rgba(22, 119, 255, 0.1); } -.sourceMainInfo { - display: flex; - flex-direction: column; - gap: 2px; - width: 100%; - min-width: 0; -} - -.sourceHeaderTopRow { +.cardHeaderRow { display: flex; width: 100%; align-items: center; flex-wrap: nowrap; gap: 5px; + min-width: 0; } -.sourceHeaderId { +.cardHeaderId { flex: 1 1 auto; min-width: 0; line-height: 20px; @@ -362,84 +445,68 @@ text-overflow: ellipsis; } -.sourceHeaderTypeTag { - flex: 0 0 auto; -} - -.sourceRunningTag { +.indexedTag, +.notIndexedTag { font-size: 12px; - color: #f59e0b; - background: rgba(245, 158, 11, 0.12); border-radius: 4px; padding: 1px 6px; white-space: nowrap; line-height: 18px; - flex: 0 0 auto; } -.sourceIndexedTag { - font-size: 12px; +.indexedTag { color: #389e0d; background: rgba(82, 196, 26, 0.12); - border-radius: 4px; - padding: 1px 6px; - white-space: nowrap; - line-height: 18px; } -.sourceNotIndexedTag { - font-size: 12px; +.notIndexedTag { color: #667085; background: rgba(102, 112, 133, 0.12); - border-radius: 4px; - padding: 1px 6px; - white-space: nowrap; - line-height: 18px; } -.sourceTitle { +.cardTitle { display: block; - width: 
100%; font-size: 14px; line-height: 1.4; - color: #1a1a1a; + color: var(--note-title-color); white-space: normal; word-break: break-word; + font-family: var(--note-title-font); } -.sourceMeta { +.cardMeta { display: flex; min-height: 0; flex-direction: column; gap: 10px; flex: 1 1 auto; + margin-top: 10px; } -.sourceInfoSection { +.infoSection { display: flex; flex-direction: column; gap: 5px; min-width: 0; } -.sourceInfoLabel { +.infoLabel { font-size: 12px; - color: #999; - margin-bottom: 0; + color: var(--note-label-color); line-height: 1.2; } -.sourceInfoBlock { +.infoBlock { font-size: 12px; - color: #525866; - background-color: #f5f6fa; - border: 1px solid #eceff6; + color: var(--note-body-color); + background-color: var(--note-block-bg); + border: 1px solid var(--note-block-border); border-radius: 8px; padding: 7px 9px; line-height: 1.4; } -.sourceLocationButton { +.clickableBlock { cursor: pointer; transition: border-color 0.2s ease, box-shadow 0.2s ease; @@ -454,19 +521,13 @@ } } -.sourceSingleLineValue { +.singleLineValue { white-space: nowrap; overflow: hidden; text-overflow: ellipsis; } -.sourceInfoTagWrap { - display: flex; - align-items: center; - min-height: 20px; -} - -.sourceStatusStatsRow { +.statusRow { display: flex; align-items: center; justify-content: space-between; @@ -475,57 +536,56 @@ min-height: 20px; } -.sourceStatusStatsSubRow { +.statusRow > :first-child { display: flex; - justify-content: space-between; align-items: center; - gap: 6px; - margin-top: 6px; min-height: 20px; } -.sourceStatsText { +.statusRow > :last-child { flex: 1 1 auto; min-width: 0; text-align: right; line-height: 20px; } -.detailTagRow { +.statusSubRow { display: flex; + justify-content: space-between; align-items: center; - gap: 8px; + gap: 6px; + margin-top: 6px; + min-height: 20px; } -.sourceMetaRow { +.detailTagRow { display: flex; - justify-content: space-between; align-items: center; gap: 8px; } -.sourceCardFooter { +.cardFooter { display: flex; 
justify-content: flex-end; align-items: center; gap: 8px; border-top: 1px solid #f1f2f6; padding: 8px 14px 10px; - margin-top: 0; + margin: 0 -14px; } -.sourceActions { +.actionRow { display: flex; justify-content: flex-end; align-items: center; gap: 8px; } -.sourceActionButton { +.actionButton { padding: 0; } -.sourceDeleteButton { +.deleteButton { padding: 4px; display: flex; align-items: center; @@ -538,12 +598,29 @@ } } -.remoteError { - color: #d4380d; +.noteStyleObsidian .actionButton { + color: #9ecbff; +} + +.noteStyleObsidian .actionButton:hover { + color: #d4e8ff; } -.originSelect { - width: 220px; +.noteStyleObsidian .deleteButton { + color: #ffb3b8; + border: 1px solid rgba(255, 99, 110, 0.38); + border-radius: 8px; + background: rgba(255, 99, 110, 0.08); +} + +.noteStyleObsidian .deleteButton:hover { + background-color: #d64550; + color: #fff; + border-color: #d64550; +} + +.remoteError { + color: #d4380d; } .dragZone { @@ -569,44 +646,123 @@ } .enableBackfillHint { - margin-bottom: 0 !important; + margin-bottom: 0; color: #667085; } +@keyframes cardFadeIn { + from { + opacity: 0; + transform: translateY(6px); + } + + to { + opacity: 1; + transform: translateY(0); + } +} + @media (max-width: 900px) { .header { - flex-direction: column; + flex-wrap: wrap; align-items: stretch; } + .headerInfo { + flex: 0 0 100%; + width: 100%; + } + + .headerActions { + width: 100%; + justify-content: space-between; + align-items: flex-start; + row-gap: 10px; + } + + .headerControlGroup { + flex: 1 1 260px; + } + + .headerButtonGroup { + flex: 1 1 320px; + justify-content: flex-start; + } + + .knowledgePage { + padding: 18px; + } + .searchTypeSelect { width: 140px; } - .originSelect { - width: 180px; + .searchControls { + flex-direction: column; + align-items: stretch; } - .sourceCards { - grid-template-columns: 1fr; + .searchButtons { + width: 100%; + justify-content: flex-end; } - .sourceCardHeader { + .filterBar { flex-direction: column; align-items: stretch; + gap: 
10px; + } + + .filterGroup { + flex-basis: auto; + width: 100%; + } + + .noteStyleSegment { + min-width: auto; + } + + .noteStyleLabel, + .noteStyleOptionText { + display: none; + } + + .noteStyleOptionLabel { + gap: 0; } - .sourceActions { + .actionRow { justify-content: flex-start; flex-wrap: wrap; } +} - .policyHeader { +@media (max-width: 640px) { + .headerActions { flex-direction: column; align-items: stretch; } - .policyGrid { - grid-template-columns: 1fr; + .headerControlGroup, + .headerButtonGroup { + width: 100%; + } + + .headerButtonGroup { + justify-content: flex-start; + } + + .knowledgePage { + padding: 12px; + border-radius: 0; + } + + .cardContainer { + padding: 12px 12px 0; + } + + .cardFooter { + margin: 0 -12px; + padding: 8px 12px 10px; } } diff --git a/console/src/pages/Agent/Knowledge/index.tsx b/console/src/pages/Agent/Knowledge/index.tsx index 14e3ff649..023ed1efd 100644 --- a/console/src/pages/Agent/Knowledge/index.tsx +++ b/console/src/pages/Agent/Knowledge/index.tsx @@ -13,14 +13,18 @@ import { Tag, message, } from "@agentscope-ai/design"; -import { Divider, Progress, Space, Spin, Typography } from "antd"; +import { Divider, Progress, Segmented, Space, Spin, Typography } from "antd"; import { useNavigate } from "react-router-dom"; import { + BookOutlined, DatabaseOutlined, + DownloadOutlined, DeleteOutlined, + MoonOutlined, PlusOutlined, ReloadOutlined, SearchOutlined, + UploadOutlined, } from "@ant-design/icons"; import { useTranslation } from "react-i18next"; import api from "../../../api"; @@ -36,8 +40,13 @@ import type { KnowledgeSourceSpec, KnowledgeSourceType, } from "../../../api/types"; +import { MarkdownCopy } from "../../../components/MarkdownCopy/MarkdownCopy"; import styles from "./index.module.less"; +const KNOWLEDGE_NOTE_STYLE_STORAGE_KEY = "copaw_knowledge_note_style"; + +type KnowledgeNoteStyle = "notion" | "obsidian"; + const SOURCE_TYPE_OPTIONS: Array<{ label: string; value: KnowledgeSourceType; @@ -132,6 +141,9 @@ 
function KnowledgePage() { const [saving, setSaving] = useState(false); const [indexingAll, setIndexingAll] = useState(false); const [clearingKnowledge, setClearingKnowledge] = useState(false); + const [exportingAll, setExportingAll] = useState(false); + const [exportingSourceId, setExportingSourceId] = useState(null); + const [importingBackup, setImportingBackup] = useState(false); const [indexingId, setIndexingId] = useState(null); const [detailDrawerOpen, setDetailDrawerOpen] = useState(false); const [selectedSource, setSelectedSource] = @@ -148,8 +160,19 @@ function KnowledgePage() { Array<{ file: File; relativePath: string }> >([]); const [selectedDirectorySummary, setSelectedDirectorySummary] = useState(""); + const [noteStyle, setNoteStyle] = useState(() => { + if (typeof window === "undefined") { + return "notion"; + } + const saved = window.localStorage.getItem(KNOWLEDGE_NOTE_STYLE_STORAGE_KEY); + if (saved === "obsidian" || saved === "notion") { + return saved; + } + return "notion"; + }); const singleFileInputRef = useRef(null); const directoryInputRef = useRef(null); + const backupImportInputRef = useRef(null); const remoteStateRef = useRef>({}); const hasLoadedOnceRef = useRef(false); const backfillProgressWsRef = useRef(null); @@ -222,6 +245,13 @@ function KnowledgePage() { loadData(); }, [loadData]); + useEffect(() => { + if (typeof window === "undefined") { + return; + } + window.localStorage.setItem(KNOWLEDGE_NOTE_STYLE_STORAGE_KEY, noteStyle); + }, [noteStyle]); + // While history backfill is running, poll existing sources API to refresh cards. 
useEffect(() => { if (!backfillProgress?.running) { @@ -549,6 +579,90 @@ function KnowledgePage() { } }; + const triggerBlobDownload = useCallback((blob: Blob, filename: string) => { + const objectUrl = window.URL.createObjectURL(blob); + const link = document.createElement("a"); + link.href = objectUrl; + link.download = filename; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + window.URL.revokeObjectURL(objectUrl); + }, []); + + const handleBackupAll = useCallback(async () => { + try { + setExportingAll(true); + const blob = await api.downloadKnowledgeBackup(); + const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); + triggerBlobDownload(blob, `copaw_knowledge_${timestamp}.zip`); + message.success(t("knowledge.backupAllSuccess")); + } catch (error) { + console.error("Failed to backup knowledge", error); + message.error(t("knowledge.backupFailed")); + } finally { + setExportingAll(false); + } + }, [t, triggerBlobDownload]); + + const handleBackupSource = useCallback( + async (sourceId: string, sourceName: string) => { + try { + setExportingSourceId(sourceId); + const blob = await api.downloadKnowledgeSourceBackup(sourceId); + const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); + const normalizedName = (sourceName || sourceId) + .replace(/[^A-Za-z0-9._-]+/g, "-") + .replace(/^-+|-+$/g, "") || sourceId; + triggerBlobDownload( + blob, + `copaw_knowledge_${normalizedName}_${timestamp}.zip`, + ); + message.success(t("knowledge.backupSourceSuccess")); + } catch (error) { + console.error("Failed to backup knowledge source", error); + message.error(t("knowledge.backupFailed")); + } finally { + setExportingSourceId(null); + } + }, + [t, triggerBlobDownload], + ); + + const handleRestoreBackupPicked = useCallback( + async (event: React.ChangeEvent) => { + const file = event.target.files?.[0] ?? 
null; + event.target.value = ""; + if (!file) { + return; + } + + try { + setImportingBackup(true); + const result = await api.restoreKnowledgeBackup(file, true); + message.success( + t("knowledge.restoreSuccess", { + count: result.restored_sources, + }), + ); + await loadData(); + } catch (error) { + console.error("Failed to restore knowledge backup", error); + message.error(t("knowledge.restoreFailed")); + } finally { + setImportingBackup(false); + } + }, + [loadData, t], + ); + + const handleRestoreBackup = useCallback(() => { + if (importingBackup) { + return; + } + backupImportInputRef.current?.click(); + }, [importingBackup]); + const handleClearKnowledge = useCallback(() => { Modal.confirm({ title: t("knowledge.clearConfirmTitle"), @@ -602,6 +716,15 @@ function KnowledgePage() { } }; + const handleResetSearch = useCallback(() => { + setSearchQuery(""); + setSearchTypeFilter("all"); + setHits([]); + }, []); + + const hasSearchQuery = searchQuery.trim().length > 0; + const showSearchPanel = searching || hasSearchQuery || hits.length > 0; + const handleDeleteSource = useCallback(async (sourceId: string) => { try { await api.deleteKnowledgeSource(sourceId); @@ -822,6 +945,41 @@ function KnowledgePage() { t, ]); + const noteStyleOptions = useMemo( + () => [ + { + label: ( + + + + {t("knowledge.noteStyleNotion")} + + + ), + value: "notion", + }, + { + label: ( + + + + {t("knowledge.noteStyleObsidian")} + + + ), + value: "obsidian", + }, + ], + [t], + ); + + const noteStyleClassName = useMemo(() => { + if (noteStyle === "obsidian") { + return styles.noteStyleObsidian; + } + return styles.noteStyleNotion; + }, [noteStyle]); + return (
@@ -837,41 +995,63 @@ function KnowledgePage() {
- {t("knowledge.enabled")} - - - - - - {showBackfillNowButton ? ( +
+ {t("knowledge.enabled")} + +
+
+ + + + - ) : null} + {showBackfillNowButton ? ( + + ) : null} +
@@ -900,103 +1080,159 @@ function KnowledgePage() { ) : null} - +
+ + { + const value = event.target.value; + setSearchQuery(value); + if (!value.trim() && hits.length > 0) { + setHits([]); + } + }} + placeholder={t("knowledge.searchPlaceholder")} + onPressEnter={handleSearch} + /> + +
+ + +
+
+ + {showSearchPanel ? ( +
+ {searching ? ( +
+ +
+ ) : hits.length === 0 ? ( +
{t("knowledge.searchEmpty")}
+ ) : ( +
+ + {hits.map((hit) => ( + +
+ + {hit.source_name} + {hit.source_type} + + + {t("knowledge.scoreLabel", { + score: Number(hit.score).toFixed(2), + })} + +
+ {hit.document_title} + + {hit.document_path} + + + {hit.snippet} + +
+ ))} +
+
+ )} +
+ ) : null} + + +
+ +
+ + {t("knowledge.noteStyle")} + + setNoteStyle(value as KnowledgeNoteStyle)} + className={styles.noteStyleSegment} + /> +
+
+ +
+ +
+
+ + {t("knowledge.sourceOriginFilter")} + + - setSearchTypeFilter(value as KnowledgeSourceType | "all") - } + value={sourceTypeFilter} + onChange={(value) => setSourceTypeFilter(value as KnowledgeSourceType | "all")} options={[ { label: t("knowledge.allTypes"), value: "all" }, ...SOURCE_TYPE_OPTIONS, ]} - className={styles.searchTypeSelect} + className={styles.filterSelect} /> - setSearchQuery(event.target.value)} - placeholder={t("knowledge.searchPlaceholder")} - onPressEnter={handleSearch} - /> - - - -
- {hits.length === 0 ? ( - - ) : ( - - {hits.map((hit) => ( - -
- - {hit.source_name} - {hit.source_type} - - - {t("knowledge.scoreLabel", { - score: Number(hit.score).toFixed(2), - })} - -
- {hit.document_title} - - {hit.document_path} - - - {hit.snippet} - -
- ))} -
- )} +
- - - - - {t("knowledge.sourceOriginFilter")} - setSourceTypeFilter(value as KnowledgeSourceType | "all")} - options={[ - { label: t("knowledge.allTypes"), value: "all" }, - ...SOURCE_TYPE_OPTIONS, - ]} - className={styles.originSelect} - /> - {filteredSources.length === 0 ? ( ) : ( -
+
{filteredSources.map((record) => { const originText = getSourceOriginText(record, t); const remoteLine = formatRemoteStatus(record, t); const isActiveCard = indexingId === record.id; const cardTitle = record.name?.trim() || ""; const descriptionText = record.description?.trim() || ""; + const hideDescriptionBlock = + cardTitle.length > 0 && + descriptionText.length > 0 && + cardTitle === descriptionText; const indexedCountText = record.status.indexed ? t("knowledge.indexedCount", { documents: record.status.document_count, @@ -1004,56 +1240,50 @@ function KnowledgePage() { }) : "-"; return ( -
-
-
-
-
-
-
-
- - {record.id} - - - {record.type} - - {originText} -
-
+
+
+
+
+ + {record.id} + + {record.type} + {originText}
+
-
- {cardTitle ? ( -
-
- {t("knowledge.table.title")} -
-
openDetailDrawer(record)} - onKeyDown={(event) => - handleDetailDrawerValueKeyDown(event, record) - } - className={`${styles.sourceInfoBlock} ${styles.sourceLocationButton}`} +
+ {cardTitle ? ( +
+
+ {t("knowledge.table.title")} +
+
openDetailDrawer(record)} + onKeyDown={(event) => + handleDetailDrawerValueKeyDown(event, record) + } + className={`${styles.infoBlock} ${styles.clickableBlock}`} + title={cardTitle} + > + - - {cardTitle} - -
+ {cardTitle} +
- ) : null} +
+ ) : null} -
-
+ {!hideDescriptionBlock ? ( +
+
{t("knowledge.table.source")}
handleDetailDrawerValueKeyDown(event, record) } - className={`${styles.sourceInfoBlock} ${styles.sourceLocationButton}`} + className={`${styles.infoBlock} ${styles.clickableBlock}`} title={descriptionText || t("knowledge.inlineText")} > {descriptionText || t("knowledge.inlineText")}
+ ) : null} - {record.location ? ( -
-
- {t("knowledge.table.location")} -
-
openDetailDrawer(record)} - onKeyDown={(event) => - handleDetailDrawerValueKeyDown(event, record) - } - className={`${styles.sourceInfoBlock} ${styles.sourceSingleLineValue} ${styles.sourceLocationButton}`} - title={record.location} - > - {record.location} -
-
- ) : null} - -
-
- {t("knowledge.statusAndStats")} + {record.location ? ( +
+
+ {t("knowledge.table.location")}
handleDetailDrawerValueKeyDown(event, record) } - className={`${styles.sourceInfoBlock} ${styles.sourceLocationButton}`} + className={`${styles.infoBlock} ${styles.singleLineValue} ${styles.clickableBlock}`} + title={record.location} > -
-
- - {record.status.indexed - ? t("knowledge.indexed") - : t("knowledge.notIndexed")} - -
- - {indexedCountText} - + {record.location} +
+
+ ) : null} + +
+
+ {t("knowledge.statusAndStats")} +
+
openDetailDrawer(record)} + onKeyDown={(event) => + handleDetailDrawerValueKeyDown(event, record) + } + className={`${styles.infoBlock} ${styles.clickableBlock}`} + > +
+
+ + {record.status.indexed + ? t("knowledge.indexed") + : t("knowledge.notIndexed")} +
- {remoteLine ? ( -
- Remote - {remoteLine} -
- ) : null} - {record.status.remote_last_error ? ( - - {t("knowledge.remoteLastError", { - error: record.status.remote_last_error, - })} - - ) : null} + + {indexedCountText} +
+ {remoteLine ? ( +
+ Remote + {remoteLine} +
+ ) : null} + {record.status.remote_last_error ? ( + + {t("knowledge.remoteLastError", { + error: record.status.remote_last_error, + })} + + ) : null}
-
-
+
+
+
-
-
); @@ -1180,6 +1418,7 @@ function KnowledgePage() {
)} +
@@ -1466,55 +1705,55 @@ function KnowledgePage() { {selectedSource ? ( {selectedSource.name?.trim() ? ( -
-
{t("knowledge.table.title")}
-
+
+
{t("knowledge.table.title")}
+
{selectedSource.name}
) : null} -
-
{t("knowledge.table.source")}
-
+
+
{t("knowledge.table.source")}
+
{selectedSource.description || t("knowledge.inlineText")}
- {selectedSourceOriginText} - {selectedSource.type} + {selectedSourceOriginText} + {selectedSource.type}
-
-
{t("knowledge.form.id")}
-
+
+
{t("knowledge.form.id")}
+
{selectedSource.id}
{selectedSource.location ? ( -
-
{t("knowledge.table.location")}
-
{selectedSource.location}
+
+
{t("knowledge.table.location")}
+
{selectedSource.location}
) : null} -
-
{t("knowledge.table.chunkStats")}
-
+
+
{t("knowledge.table.chunkStats")}
+
{selectedSourceIndexedCountText}
-
-
{t("knowledge.table.status")}
-
+
+
{t("knowledge.table.status")}
+
{selectedSource.status.indexed @@ -1525,9 +1764,9 @@ function KnowledgePage() {
{selectedSourceRemoteLine ? ( -
-
Remote
-
{selectedSourceRemoteLine}
+
+
Remote
+
{selectedSourceRemoteLine}
) : null} @@ -1541,8 +1780,8 @@ function KnowledgePage() { -
-
{t("knowledge.documentContent")}
+
+
{t("knowledge.documentContent")}
{sourceContentLoading ? (
@@ -1562,7 +1801,22 @@ function KnowledgePage() { {doc.title || doc.path} )} -
{doc.text}
+
+ +
))}
diff --git a/src/copaw/app/routers/knowledge.py b/src/copaw/app/routers/knowledge.py index ea100e218..c19d8a4b4 100644 --- a/src/copaw/app/routers/knowledge.py +++ b/src/copaw/app/routers/knowledge.py @@ -3,11 +3,18 @@ from __future__ import annotations import asyncio +import io import json +import shutil +import tempfile +import zipfile +from datetime import datetime, timezone +from pathlib import Path from typing import Optional from types import SimpleNamespace from fastapi import APIRouter, Body, File, Form, HTTPException, Query, UploadFile, WebSocket, WebSocketDisconnect +from fastapi.responses import StreamingResponse from ...config import load_config, save_config from ...config.config import KnowledgeConfig, KnowledgeSourceSpec @@ -17,6 +24,49 @@ router = APIRouter(prefix="/knowledge", tags=["knowledge"]) +def _zip_path(path) -> io.BytesIO: + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: + for entry in sorted(path.rglob("*")): + arcname = entry.relative_to(path).as_posix() + if entry.is_file(): + zf.write(entry, arcname) + elif entry.is_dir(): + zf.write(entry, arcname + "/") + buf.seek(0) + return buf + + +def _validate_zip_data(data: bytes) -> None: + if not zipfile.is_zipfile(io.BytesIO(data)): + raise HTTPException( + status_code=400, + detail="Uploaded file is not a valid zip archive", + ) + with zipfile.ZipFile(io.BytesIO(data)) as zf: + for name in zf.namelist(): + p = Path(name) + if p.is_absolute() or ".." 
in p.parts: + raise HTTPException( + status_code=400, + detail=f"Zip contains unsafe path: {name}", + ) + + +def _extract_zip_to_temp(data: bytes) -> Path: + tmp_dir = Path(tempfile.mkdtemp(prefix="copaw_knowledge_import_")) + with zipfile.ZipFile(io.BytesIO(data)) as zf: + zf.extractall(tmp_dir) + return tmp_dir + + +def _detect_extract_root(tmp_dir: Path) -> Path: + entries = [entry for entry in tmp_dir.iterdir() if not entry.name.startswith(".__")] + if len(entries) == 1 and entries[0].is_dir() and (entries[0] / "sources").exists(): + return entries[0] + return tmp_dir + + def _clamp_int(value: str | None, default: int, minimum: int, maximum: int) -> int: try: parsed = int((value or "").strip()) @@ -290,4 +340,103 @@ async def stream_history_backfill_progress(websocket: WebSocket): last_fingerprint = fingerprint await asyncio.sleep(interval_ms / 1000) except WebSocketDisconnect: - return \ No newline at end of file + return + + +@router.get("/backup") +async def backup_knowledge(): + manager = _manager() + if not manager.root_dir.exists(): + raise HTTPException(status_code=404, detail="KNOWLEDGE_NOT_FOUND") + + buf = await asyncio.to_thread(_zip_path, manager.root_dir) + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + filename = f"copaw_knowledge_{timestamp}.zip" + return StreamingResponse( + buf, + media_type="application/zip", + headers={ + "Content-Disposition": f'attachment; filename="{filename}"', + }, + ) + + +@router.get("/backup/{source_id}") +async def backup_knowledge_source(source_id: str): + manager = _manager() + source_dir = manager.get_source_storage_dir(source_id) + if not source_dir.exists() or not source_dir.is_dir(): + raise HTTPException(status_code=404, detail="KNOWLEDGE_SOURCE_NOT_FOUND") + + buf = await asyncio.to_thread(_zip_path, source_dir) + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + safe_name = manager._safe_name(source_id) + filename = f"copaw_knowledge_{safe_name}_{timestamp}.zip" + return 
StreamingResponse( + buf, + media_type="application/zip", + headers={ + "Content-Disposition": f'attachment; filename="{filename}"', + }, + ) + + +@router.post("/restore") +async def restore_knowledge_backup( + file: UploadFile = File(...), + replace_existing: bool = Query(default=True), +): + if file.content_type and file.content_type not in { + "application/zip", + "application/x-zip-compressed", + "application/octet-stream", + }: + raise HTTPException( + status_code=400, + detail=f"Expected a zip file, got content-type: {file.content_type}", + ) + + data = await file.read() + _validate_zip_data(data) + + manager = _manager() + tmp_dir: Path | None = None + try: + tmp_dir = await asyncio.to_thread(_extract_zip_to_temp, data) + extract_root = _detect_extract_root(tmp_dir) + if not (extract_root / "sources").is_dir(): + raise HTTPException( + status_code=400, + detail="Invalid knowledge backup: missing sources directory", + ) + + if replace_existing and manager.root_dir.exists(): + shutil.rmtree(manager.root_dir, ignore_errors=True) + + manager.root_dir.mkdir(parents=True, exist_ok=True) + for item in extract_root.iterdir(): + dest = manager.root_dir / item.name + if item.is_file(): + shutil.copy2(item, dest) + else: + if dest.exists() and dest.is_file(): + dest.unlink() + shutil.copytree(item, dest, dirs_exist_ok=True) + + manager.sources_dir.mkdir(parents=True, exist_ok=True) + manager.uploads_dir.mkdir(parents=True, exist_ok=True) + manager.remote_blob_dir.mkdir(parents=True, exist_ok=True) + manager.remote_meta_dir.mkdir(parents=True, exist_ok=True) + + config = load_config() + config.knowledge.sources = manager.list_sources_from_storage() + save_config(config) + + return { + "success": True, + "replace_existing": replace_existing, + "restored_sources": len(config.knowledge.sources), + } + finally: + if tmp_dir and tmp_dir.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) \ No newline at end of file diff --git a/src/copaw/knowledge/manager.py 
b/src/copaw/knowledge/manager.py index fdbe56226..422089fc3 100644 --- a/src/copaw/knowledge/manager.py +++ b/src/copaw/knowledge/manager.py @@ -99,14 +99,18 @@ class KnowledgeManager: def __init__(self, working_dir: str | Path): self.working_dir = Path(working_dir).expanduser().resolve() self.root_dir = self.working_dir / "knowledge" - self.index_dir = self.root_dir / "indexes" + self.sources_dir = self.root_dir / "sources" + self.catalog_path = self.root_dir / "catalog.json" self.uploads_dir = self.root_dir / "uploads" self.backfill_state_path = self.root_dir / "history-backfill-state.json" self.backfill_progress_path = self.root_dir / "history-backfill-progress.json" self.remote_dir = self.uploads_dir / "remote" self.remote_blob_dir = self.remote_dir / "blobs" self.remote_meta_dir = self.remote_dir / "url-meta" - self.index_dir.mkdir(parents=True, exist_ok=True) + legacy_index_dir = self.root_dir / "indexes" + if legacy_index_dir.exists(): + shutil.rmtree(legacy_index_dir, ignore_errors=True) + self.sources_dir.mkdir(parents=True, exist_ok=True) self.uploads_dir.mkdir(parents=True, exist_ok=True) self.remote_blob_dir.mkdir(parents=True, exist_ok=True) self.remote_meta_dir.mkdir(parents=True, exist_ok=True) @@ -134,8 +138,8 @@ def get_source_status( source: KnowledgeSourceSpec | None = None, ) -> dict[str, Any]: """Return persisted index metadata for a source.""" - index_path = self._index_path(source_id) - if not index_path.exists(): + source_index_path = self._source_index_path(source_id) + if not source_index_path.exists(): status = { "indexed": False, "indexed_at": None, @@ -147,7 +151,7 @@ def get_source_status( status.update(self._remote_source_status(source)) return status - payload = self._load_json(index_path) + payload = self._load_json(source_index_path) status = { "indexed": True, "indexed_at": payload.get("indexed_at"), @@ -179,10 +183,7 @@ def index_source( "error": None, "chunks": chunks, } - self._index_path(source.id).write_text( - 
json.dumps(payload, ensure_ascii=False, indent=2), - encoding="utf-8", - ) + self._write_source_storage(source, payload, documents) return { "source_id": source.id, "document_count": len(documents), @@ -208,22 +209,22 @@ def index_all( def delete_index(self, source_id: str) -> None: """Delete persisted index for a source.""" - index_path = self._index_path(source_id) - if index_path.exists(): - index_path.unlink() + source_dir = self._source_dir(source_id) + if source_dir.exists(): + shutil.rmtree(source_dir, ignore_errors=True) def clear_knowledge(self, config: KnowledgeConfig, *, remove_sources: bool = True) -> dict[str, Any]: """Clear persisted knowledge data and optionally reset configured sources.""" source_count = len(config.sources) cleared_indexes = 0 - if self.index_dir.exists(): - cleared_indexes = len(list(self.index_dir.glob("*.json"))) + if self.sources_dir.exists(): + cleared_indexes = len(list(self.sources_dir.glob("*/index.json"))) if self.root_dir.exists(): shutil.rmtree(self.root_dir, ignore_errors=True) # Recreate expected directory structure after cleanup. 
- self.index_dir.mkdir(parents=True, exist_ok=True) + self.sources_dir.mkdir(parents=True, exist_ok=True) self.uploads_dir.mkdir(parents=True, exist_ok=True) self.remote_blob_dir.mkdir(parents=True, exist_ok=True) self.remote_meta_dir.mkdir(parents=True, exist_ok=True) @@ -258,10 +259,9 @@ def search( continue if source_types and source.type not in source_types: continue - index_path = self._index_path(source.id) - if not index_path.exists(): + payload = self._load_index_payload(source.id) + if payload is None: continue - payload = self._load_json(index_path) for chunk in payload.get("chunks", []): score = self._score_chunk(chunk.get("text", ""), terms) if score <= 0: @@ -286,10 +286,9 @@ def search( def get_source_documents(self, source_id: str) -> dict[str, Any]: """Return the indexed documents for a source, merged by document path.""" - index_path = self._index_path(source_id) - if not index_path.exists(): + payload = self._load_index_payload(source_id) + if payload is None: return {"indexed": False, "documents": []} - payload = self._load_json(index_path) chunks = payload.get("chunks", []) # Merge chunks back into per-document text blocks docs: dict[str, dict[str, Any]] = {} @@ -318,8 +317,206 @@ def get_source_documents(self, source_id: str) -> dict[str, Any]: "documents": documents, } - def _index_path(self, source_id: str) -> Path: - return self.index_dir / f"{source_id}.json" + def _source_dir(self, source_id: str) -> Path: + return self.sources_dir / self._safe_name(source_id) + + def _source_index_path(self, source_id: str) -> Path: + return self._source_dir(source_id) / "index.json" + + def _source_content_md_path(self, source_id: str) -> Path: + return self._source_dir(source_id) / "content.md" + + def get_source_storage_dir(self, source_id: str) -> Path: + return self._source_dir(source_id) + + def list_sources_from_storage(self) -> list[KnowledgeSourceSpec]: + """Rebuild source specs from persisted v2 storage layout.""" + sources: 
list[KnowledgeSourceSpec] = [] + for index_path in sorted(self.sources_dir.glob("*/index.json")): + try: + payload = self._load_json(index_path) + source_payload = payload.get("source") + if not isinstance(source_payload, dict): + continue + source = KnowledgeSourceSpec.model_validate(source_payload) + sources.append(source) + except Exception: + logger.warning( + "Failed to read source spec from storage index: %s", + index_path, + ) + return sources + + def _load_index_payload(self, source_id: str) -> dict[str, Any] | None: + source_index_path = self._source_index_path(source_id) + if source_index_path.exists(): + return self._load_json(source_index_path) + return None + + def _write_source_storage( + self, + source: KnowledgeSourceSpec, + payload: dict[str, Any], + documents: list[dict[str, str]], + ) -> None: + source_dir = self._source_dir(source.id) + source_dir.mkdir(parents=True, exist_ok=True) + (source_dir / "raw").mkdir(parents=True, exist_ok=True) + (source_dir / "media").mkdir(parents=True, exist_ok=True) + + self._source_index_path(source.id).write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + self._source_content_md_path(source.id).write_text( + self._build_source_markdown(source, documents), + encoding="utf-8", + ) + self._sync_raw_source_assets(source) + self._update_catalog_entry(source, payload) + + def _update_catalog_entry( + self, + source: KnowledgeSourceSpec, + payload: dict[str, Any], + ) -> None: + catalog: dict[str, Any] = { + "version": 2, + "updated_at": datetime.now(UTC).isoformat(), + "sources": {}, + } + if self.catalog_path.exists(): + try: + current = self._load_json(self.catalog_path) + if isinstance(current, dict): + catalog.update(current) + if not isinstance(catalog.get("sources"), dict): + catalog["sources"] = {} + except Exception: + logger.warning("Failed to read knowledge catalog, recreating") + + catalog["updated_at"] = datetime.now(UTC).isoformat() + catalog["sources"][source.id] = { 
+ "id": source.id, + "name": source.name, + "type": source.type, + "indexed_at": payload.get("indexed_at"), + "document_count": payload.get("document_count", 0), + "chunk_count": payload.get("chunk_count", 0), + "path": str(self._source_dir(source.id)), + } + + self.catalog_path.write_text( + json.dumps(catalog, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + def _build_source_markdown( + self, + source: KnowledgeSourceSpec, + documents: list[dict[str, str]], + ) -> str: + lines = [ + f"# {source.name}", + "", + "## Metadata", + "", + f"- id: {source.id}", + f"- type: {source.type}", + f"- location: {source.location or '-'}", + f"- updated_at: {datetime.now(UTC).isoformat()}", + "", + "## Documents", + "", + ] + if not documents: + lines.append("(no documents)") + lines.append("") + return "\n".join(lines) + + for doc in documents: + title = self._truncate_title(doc.get("title", "document"), max_len=200) + path = doc.get("path", "") + text = doc.get("text", "").strip() + lines.extend( + [ + f"### {title}", + "", + f"- path: {path}", + "", + text if text else "(empty)", + "", + ], + ) + return "\n".join(lines) + + def _sync_raw_source_assets(self, source: KnowledgeSourceSpec) -> None: + raw_root = self._source_dir(source.id) / "raw" + media_root = self._source_dir(source.id) / "media" + + if source.type not in {"file", "directory"}: + return + if not source.location: + return + + source_path = Path(source.location).expanduser() + if not source_path.exists(): + return + + if source.type == "file" and source_path.is_file(): + target_file = raw_root / source_path.name + try: + shutil.copy2(source_path, target_file) + self._write_media_semantic_if_needed(target_file, media_root) + except Exception: + logger.warning("Failed to sync raw file for source %s", source.id) + return + + if source.type == "directory" and source_path.is_dir(): + target_dir = raw_root / source_path.name + if target_dir.exists(): + shutil.rmtree(target_dir, ignore_errors=True) + try: + 
shutil.copytree(source_path, target_dir, dirs_exist_ok=True) + for file_path in target_dir.rglob("*"): + if file_path.is_file(): + self._write_media_semantic_if_needed(file_path, media_root) + except Exception: + logger.warning("Failed to sync raw directory for source %s", source.id) + + def _write_media_semantic_if_needed(self, file_path: Path, media_root: Path) -> None: + suffix = file_path.suffix.lower() + media_kind = None + if suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg"}: + media_kind = "image" + elif suffix in {".mp3", ".wav", ".m4a", ".flac", ".ogg"}: + media_kind = "audio" + elif suffix in {".mp4", ".mov", ".avi", ".mkv", ".webm"}: + media_kind = "video" + if media_kind is None: + return + + semantic_name = f"{self._safe_name(file_path.stem)}.semantic.md" + semantic_path = media_root / semantic_name + size = file_path.stat().st_size if file_path.exists() else 0 + semantic_path.write_text( + "\n".join( + [ + f"# {file_path.name}", + "", + "## Semantic Summary", + "", + "(placeholder) Semantic extraction is not generated yet.", + "", + "## Metadata", + "", + f"- kind: {media_kind}", + f"- original_file: {file_path.as_posix()}", + f"- size_bytes: {size}", + ], + ), + encoding="utf-8", + ) @staticmethod def _load_json(path: Path) -> dict[str, Any]: @@ -1964,7 +2161,7 @@ def _read_directory_text( return "" def _load_index_payload_safe(self, source_id: str) -> dict[str, Any] | None: - index_path = self._index_path(source_id) + index_path = self._source_index_path(source_id) if not index_path.exists(): return None try: diff --git a/tests/unit/app/routers/test_knowledge.py b/tests/unit/app/routers/test_knowledge.py index 985b76df7..ad8dd331e 100644 --- a/tests/unit/app/routers/test_knowledge.py +++ b/tests/unit/app/routers/test_knowledge.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- +import io +import json from pathlib import Path +import zipfile import pytest from fastapi import FastAPI @@ -149,9 +152,9 @@ def 
test_clear_knowledge_removes_sources_and_indexes( saved = knowledge_api_client.put("/knowledge/config", json=config_payload) assert saved.status_code == 200 - index_root = tmp_path / "knowledge" / "indexes" - index_root.mkdir(parents=True, exist_ok=True) - (index_root / "clear-1.json").write_text("{}", encoding="utf-8") + source_root = tmp_path / "knowledge" / "sources" / "clear-1" + source_root.mkdir(parents=True, exist_ok=True) + (source_root / "index.json").write_text("{}", encoding="utf-8") response = knowledge_api_client.delete( "/knowledge/clear?confirm=true&remove_sources=true" @@ -225,3 +228,137 @@ def test_get_memify_job_status_success( payload = response.json() assert payload["job_id"] == job_id assert payload["status"] in {"succeeded", "failed"} + + +def _build_knowledge_zip(entries: dict[str, str]) -> bytes: + buf = io.BytesIO() + with zipfile.ZipFile(buf, mode="w", compression=zipfile.ZIP_DEFLATED) as zf: + for path, content in entries.items(): + zf.writestr(path, content) + return buf.getvalue() + + +def _source_index_payload(source_id: str, content: str = "knowledge text") -> str: + payload = { + "source": { + "id": source_id, + "name": source_id, + "type": "text", + "location": "", + "content": content, + "enabled": True, + "recursive": False, + "tags": [], + "description": "", + }, + "documents": [ + { + "path": f"{source_id}.md", + "title": source_id, + "text": content, + } + ], + "chunks": [], + } + return json.dumps(payload, ensure_ascii=False) + + +def test_restore_knowledge_backup_replace_existing( + knowledge_api_client: TestClient, + tmp_path: Path, +): + old_source_dir = tmp_path / "knowledge" / "sources" / "old-source" + old_source_dir.mkdir(parents=True, exist_ok=True) + (old_source_dir / "index.json").write_text( + _source_index_payload("old-source", "old content"), + encoding="utf-8", + ) + + zip_data = _build_knowledge_zip( + { + "sources/new-source/index.json": _source_index_payload( + "new-source", + "new content", + ), + 
"sources/new-source/content.md": "# new-source\n\nnew content\n", + "catalog.json": json.dumps({"version": 2}), + } + ) + + response = knowledge_api_client.post( + "/knowledge/restore", + files={"file": ("knowledge.zip", zip_data, "application/zip")}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["success"] is True + assert payload["replace_existing"] is True + assert payload["restored_sources"] == 1 + assert not old_source_dir.exists() + assert (tmp_path / "knowledge" / "sources" / "new-source" / "index.json").exists() + + listing = knowledge_api_client.get("/knowledge/sources") + assert listing.status_code == 200 + ids = {item["id"] for item in listing.json()["sources"]} + assert ids == {"new-source"} + + +def test_restore_knowledge_backup_merge_existing( + knowledge_api_client: TestClient, + tmp_path: Path, +): + existing_source_dir = tmp_path / "knowledge" / "sources" / "local-source" + existing_source_dir.mkdir(parents=True, exist_ok=True) + (existing_source_dir / "index.json").write_text( + _source_index_payload("local-source", "local content"), + encoding="utf-8", + ) + + zip_data = _build_knowledge_zip( + { + "sources/imported-source/index.json": _source_index_payload( + "imported-source", + "imported content", + ), + "sources/imported-source/content.md": "# imported-source\n\nimported content\n", + } + ) + + response = knowledge_api_client.post( + "/knowledge/restore?replace_existing=false", + files={"file": ("knowledge.zip", zip_data, "application/zip")}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["success"] is True + assert payload["replace_existing"] is False + assert payload["restored_sources"] == 2 + assert (tmp_path / "knowledge" / "sources" / "local-source" / "index.json").exists() + assert ( + tmp_path / "knowledge" / "sources" / "imported-source" / "index.json" + ).exists() + + listing = knowledge_api_client.get("/knowledge/sources") + assert listing.status_code 
== 200 + ids = {item["id"] for item in listing.json()["sources"]} + assert ids == {"local-source", "imported-source"} + + +def test_restore_knowledge_backup_rejects_unsafe_zip_path( + knowledge_api_client: TestClient, +): + zip_data = _build_knowledge_zip( + { + "../escape/index.json": _source_index_payload("escape"), + } + ) + + response = knowledge_api_client.post( + "/knowledge/restore", + files={"file": ("knowledge.zip", zip_data, "application/zip")}, + ) + + assert response.status_code == 400 + assert "unsafe path" in response.json()["detail"] From a4c129d7665c1d04a0c46b4fce880fe5f49fd89a Mon Sep 17 00:00:00 2001 From: Future Meng Date: Tue, 17 Mar 2026 19:52:46 +0800 Subject: [PATCH 58/68] feat(knowledge): unify summary/subject model and nlp dependency layering --- console/src/api/types/agent.ts | 9 +- console/src/api/types/knowledge.ts | 12 +- console/src/locales/en.json | 13 +- console/src/locales/zh.json | 15 +- .../components/KnowledgeMaintenanceCard.tsx | 29 +- .../pages/Agent/Knowledge/index.module.less | 22 ++ console/src/pages/Agent/Knowledge/index.tsx | 187 +++++++-- pyproject.toml | 6 +- src/copaw/agents/react_agent.py | 6 +- src/copaw/agents/skills_manager.py | 53 ++- src/copaw/agents/tools/graph_query.py | 6 + src/copaw/agents/tools/knowledge_search.py | 9 + src/copaw/agents/tools/memify_run.py | 4 + src/copaw/agents/tools/memify_status.py | 4 + .../agents/tools/triplet_focus_search.py | 4 + src/copaw/app/routers/agent.py | 44 ++- src/copaw/app/routers/knowledge.py | 37 +- src/copaw/app/runner/runner.py | 10 +- src/copaw/config/config.py | 23 +- src/copaw/knowledge/manager.py | 368 ++++++++++++------ src/copaw/knowledge/module_skills.py | 23 ++ .../knowledge_search_assistant/SKILL.md | 42 ++ tests/unit/app/routers/test_knowledge.py | 126 +++++- .../test_knowledge_context_injection.py | 10 +- .../unit/app/test_agent_skill_registration.py | 64 +++ tests/unit/app/test_graph_tools.py | 2 +- .../unit/app/test_knowledge_module_skills.py | 31 ++ 27 files 
changed, 901 insertions(+), 258 deletions(-) create mode 100644 src/copaw/knowledge/module_skills.py create mode 100644 src/copaw/knowledge/skills/knowledge_search_assistant/SKILL.md create mode 100644 tests/unit/app/test_agent_skill_registration.py create mode 100644 tests/unit/app/test_knowledge_module_skills.py diff --git a/console/src/api/types/agent.ts b/console/src/api/types/agent.ts index 81eedb8c2..8406fba86 100644 --- a/console/src/api/types/agent.ts +++ b/console/src/api/types/agent.ts @@ -13,9 +13,10 @@ export interface AgentsRunningConfig { memory_reserve_ratio: number; enable_tool_result_compact: boolean; tool_result_compact_keep_n: number; - auto_collect_chat_files: boolean; - auto_collect_chat_urls: boolean; - auto_collect_long_text: boolean; - long_text_min_chars: number; + knowledge_enabled: boolean; + knowledge_auto_collect_chat_files: boolean; + knowledge_auto_collect_chat_urls: boolean; + knowledge_auto_collect_long_text: boolean; + knowledge_long_text_min_chars: number; knowledge_chunk_size: number; } diff --git a/console/src/api/types/knowledge.ts b/console/src/api/types/knowledge.ts index 4f3c71071..0c9a9ea06 100644 --- a/console/src/api/types/knowledge.ts +++ b/console/src/api/types/knowledge.ts @@ -14,7 +14,7 @@ export interface KnowledgeSourceSpec { enabled: boolean; recursive: boolean; tags: string[]; - description: string; + summary: string; } export interface KnowledgeIndexConfig { @@ -26,10 +26,10 @@ export interface KnowledgeIndexConfig { } export interface KnowledgeAutomationConfig { - auto_collect_chat_files: boolean; - auto_collect_chat_urls: boolean; - auto_collect_long_text: boolean; - long_text_min_chars: number; + knowledge_auto_collect_chat_files: boolean; + knowledge_auto_collect_chat_urls: boolean; + knowledge_auto_collect_long_text: boolean; + knowledge_long_text_min_chars: number; } export interface KnowledgeConfig { @@ -55,6 +55,8 @@ export interface KnowledgeSourceStatus { } export interface KnowledgeSourceItem extends 
KnowledgeSourceSpec { + subject?: string; + keywords?: string[]; status: KnowledgeSourceStatus; } diff --git a/console/src/locales/en.json b/console/src/locales/en.json index f96f4cdbb..c69400603 100644 --- a/console/src/locales/en.json +++ b/console/src/locales/en.json @@ -110,7 +110,7 @@ "indexFailed": "Failed to index knowledge sources", "clearKnowledge": "Clear Knowledge", "clearConfirmTitle": "Confirm Clear Knowledge", - "clearConfirmContent": "This will delete all knowledge source configs and indexed data. This action cannot be undone. Continue?", + "clearConfirmContent": "This will delete all knowledge source configs and indexed data. This action cannot be undone. Please back up first before clearing. Continue?", "clearConfirmOk": "Clear", "clearSuccess": "Knowledge cleared: removed {{sources}} source configs and {{indexes}} index files", "clearFailed": "Failed to clear knowledge", @@ -178,8 +178,9 @@ "documentContentEmpty": "No content retrieved", "documentContentNotIndexed": "Not indexed yet — run Index first", "table": { - "title": "Title", - "source": "Description", + "subject": "Subject", + "source": "Summary", + "keywords": "Keywords", "origin": "Origin", "type": "Type", "location": "Location", @@ -200,8 +201,8 @@ "directoryUpload": "Folder Upload", "content": "Content", "contentPlaceholder": "Paste source text here", - "description": "Description", - "descriptionPlaceholder": "What this source is for", + "summary": "Summary", + "summaryPlaceholder": "Add a summary for this source (optional)", "enabled": "Enabled", "disabled": "Disabled" } @@ -671,6 +672,8 @@ "toolResultCompactKeepNMin": "Tool result keep count must be at least 1", "autoCollectChatFiles": "Auto-collect chat files", "autoCollectChatFilesTooltip": "After each chat turn, automatically register file references from the current turn as file knowledge sources.", + "knowledgeEnabled": "Enable knowledge", + "knowledgeEnabledTooltip": "Master switch for knowledge features. 
When off, retrieval, auto-collection, and maintenance are all disabled.", "autoCollectChatUrls": "Auto-collect chat URLs", "autoCollectChatUrlsTooltip": "After each chat turn, automatically register URLs found in chat text as URL knowledge sources.", "autoCollectLongText": "Auto-save long text", diff --git a/console/src/locales/zh.json b/console/src/locales/zh.json index 953318801..92182a418 100644 --- a/console/src/locales/zh.json +++ b/console/src/locales/zh.json @@ -110,7 +110,7 @@ "indexFailed": "知识源索引失败", "clearKnowledge": "清空知识库", "clearConfirmTitle": "确认清空知识库", - "clearConfirmContent": "将删除所有知识来源配置与已索引数据,此操作不可恢复。是否继续?", + "clearConfirmContent": "将删除所有知识来源配置与已索引数据,此操作不可恢复。请先执行备份,再继续清空。是否继续?", "clearConfirmOk": "确认清空", "clearSuccess": "知识库已清空:删除 {{sources}} 个来源配置,清理 {{indexes}} 个索引文件", "clearFailed": "清空知识库失败", @@ -178,8 +178,9 @@ "documentContentEmpty": "未读取到内容", "documentContentNotIndexed": "尚未索引,建诮先执行“索引”操作", "table": { - "title": "标题", - "source": "描述", + "subject": "主题", + "source": "摘要", + "keywords": "关键词", "origin": "来源归属", "type": "类型", "location": "位置", @@ -200,8 +201,8 @@ "directoryUpload": "文件夹上传", "content": "内容", "contentPlaceholder": "在这里粘贴来源文本", - "description": "描述", - "descriptionPlaceholder": "说明该来源的用途", + "summary": "摘要", + "summaryPlaceholder": "补充该来源的摘要(可选)", "enabled": "启用", "disabled": "停用" } @@ -639,7 +640,7 @@ "description": "配置智能体运行参数", "reactAgentTitle": "ReAct 智能体", "contextManagementTitle": "上下文管理", - "knowledgeMaintenanceTitle": "知识库自动维护", + "knowledgeMaintenanceTitle": "知识库", "maxIters": "最大迭代次数", "maxItersTooltip": "ReAct 智能体的最大推理-行动迭代次数", "maxItersPlaceholder": "请输入最大迭代次数", @@ -671,6 +672,8 @@ "toolResultCompactKeepNMin": "工具结果保留数量必须大于等于 1", "autoCollectChatFiles": "自动收集对话文件", "autoCollectChatFilesTooltip": "在每轮对话结束后,将本轮文件引用自动登记为知识库 file 来源。", + "knowledgeEnabled": "启用知识库", + "knowledgeEnabledTooltip": "知识库总开关。关闭后,知识检索、自动沉淀与相关维护能力将全部停用。", "autoCollectChatUrls": "自动收集对话 URL", "autoCollectChatUrlsTooltip": "在每轮对话结束后,将对话文本中出现的 
URL 自动登记为知识库 url 来源。", "autoCollectLongText": "自动沉淀长文本", diff --git a/console/src/pages/Agent/Config/components/KnowledgeMaintenanceCard.tsx b/console/src/pages/Agent/Config/components/KnowledgeMaintenanceCard.tsx index c475fd6ab..b2cc801c3 100644 --- a/console/src/pages/Agent/Config/components/KnowledgeMaintenanceCard.tsx +++ b/console/src/pages/Agent/Config/components/KnowledgeMaintenanceCard.tsx @@ -4,7 +4,8 @@ import styles from "../index.module.less"; export function KnowledgeMaintenanceCard() { const { t } = useTranslation(); - const autoCollectLongText = Form.useWatch("auto_collect_long_text"); + const knowledgeEnabled = Form.useWatch("knowledge_enabled"); + const autoCollectLongText = Form.useWatch("knowledge_auto_collect_long_text"); return ( + + + + - + - + - + @@ -86,6 +96,7 @@ export function KnowledgeMaintenanceCard() { min={200} max={8000} step={100} + disabled={!knowledgeEnabled} placeholder={t("agentConfig.knowledgeChunkSizePlaceholder")} /> diff --git a/console/src/pages/Agent/Knowledge/index.module.less b/console/src/pages/Agent/Knowledge/index.module.less index 391787548..7546ca17e 100644 --- a/console/src/pages/Agent/Knowledge/index.module.less +++ b/console/src/pages/Agent/Knowledge/index.module.less @@ -128,6 +128,12 @@ width: 100%; } +.disabledPanel { + opacity: 0.55; + pointer-events: none; + filter: saturate(0.8); +} + .fullWidth { width: 100%; } @@ -506,6 +512,16 @@ line-height: 1.4; } +.keywordList { + display: flex; + flex-wrap: wrap; + gap: 6px; +} + +.keywordTag { + margin: 0; +} + .clickableBlock { cursor: pointer; transition: border-color 0.2s ease, box-shadow 0.2s ease; @@ -619,6 +635,12 @@ border-color: #d64550; } +.noteStyleObsidian .keywordTag { + background: rgba(99, 179, 237, 0.16); + border-color: rgba(99, 179, 237, 0.28); + color: #d8ecff; +} + .remoteError { color: #d4380d; } diff --git a/console/src/pages/Agent/Knowledge/index.tsx b/console/src/pages/Agent/Knowledge/index.tsx index 023ed1efd..d3eff4ee9 100644 --- 
a/console/src/pages/Agent/Knowledge/index.tsx +++ b/console/src/pages/Agent/Knowledge/index.tsx @@ -116,6 +116,8 @@ function KnowledgePage() { const [runningConfig, setRunningConfig] = useState( null, ); + const knowledgeRuntimeEnabled = runningConfig?.knowledge_enabled ?? true; + const knowledgePageDisabled = !knowledgeRuntimeEnabled; const [backfillStatus, setBackfillStatus] = useState(null); const [backfillProgress, setBackfillProgress] = @@ -412,14 +414,18 @@ function KnowledgePage() { const handleConfirmEnable = useCallback( async (runBackfillNow: boolean) => { - if (!config) { + if (!config || !runningConfig) { return; } try { - const nextRunningConfig = await enableConfigForm.validateFields(); + const nextRunningConfig = { + ...runningConfig, + ...(await enableConfigForm.validateFields()), + knowledge_enabled: true, + } as AgentsRunningConfig; setEnableModalSubmitting(true); const updatedRunningConfig = await api.updateAgentRunningConfig( - nextRunningConfig as AgentsRunningConfig, + nextRunningConfig, ); setRunningConfig(updatedRunningConfig); const updatedKnowledge = await persistKnowledgeConfig({ @@ -454,20 +460,43 @@ function KnowledgePage() { handleRunHistoryBackfillNow, loadData, persistKnowledgeConfig, + runningConfig, t, ], ); const handleToggleEnabled = async (checked: boolean) => { - if (!config) { + if (!config || !runningConfig) { return; } if (!checked) { - await persistKnowledgeConfig({ - ...config, - enabled: false, - }); + try { + const updatedRunningConfig = await api.updateAgentRunningConfig({ + ...runningConfig, + knowledge_enabled: false, + }); + setRunningConfig(updatedRunningConfig); + await loadData(); + } catch (error) { + console.error("Failed to disable knowledge runtime", error); + message.error(t("knowledge.configSaveFailed")); + } + return; + } + + if (config.enabled) { + try { + const updatedRunningConfig = await api.updateAgentRunningConfig({ + ...runningConfig, + knowledge_enabled: true, + }); + 
setRunningConfig(updatedRunningConfig); + await loadData(); + } catch (error) { + console.error("Failed to enable knowledge runtime", error); + message.error(t("knowledge.configSaveFailed")); + } return; } @@ -485,6 +514,9 @@ function KnowledgePage() { }; const handleAddSource = async () => { + if (knowledgePageDisabled) { + return; + } try { const values = await form.validateFields(); setSaving(true); @@ -523,7 +555,7 @@ function KnowledgePage() { content, recursive: values.recursive ?? true, tags: values.tags ?? [], - description: values.description ?? "", + summary: values.summary ?? "", }; await api.upsertKnowledgeSource(payload); @@ -552,6 +584,9 @@ function KnowledgePage() { }; const handleIndexSource = useCallback(async (sourceId: string) => { + if (knowledgePageDisabled) { + return; + } try { setIndexingId(sourceId); await api.indexKnowledgeSource(sourceId); @@ -563,9 +598,12 @@ function KnowledgePage() { } finally { setIndexingId(null); } - }, [loadData, t]); + }, [knowledgePageDisabled, loadData, t]); const handleIndexAll = async () => { + if (knowledgePageDisabled) { + return; + } try { setIndexingAll(true); await api.indexAllKnowledgeSources(); @@ -591,6 +629,9 @@ function KnowledgePage() { }, []); const handleBackupAll = useCallback(async () => { + if (knowledgePageDisabled) { + return; + } try { setExportingAll(true); const blob = await api.downloadKnowledgeBackup(); @@ -603,7 +644,7 @@ function KnowledgePage() { } finally { setExportingAll(false); } - }, [t, triggerBlobDownload]); + }, [knowledgePageDisabled, t, triggerBlobDownload]); const handleBackupSource = useCallback( async (sourceId: string, sourceName: string) => { @@ -657,13 +698,19 @@ function KnowledgePage() { ); const handleRestoreBackup = useCallback(() => { + if (knowledgePageDisabled) { + return; + } if (importingBackup) { return; } backupImportInputRef.current?.click(); - }, [importingBackup]); + }, [importingBackup, knowledgePageDisabled]); const handleClearKnowledge = useCallback(() 
=> { + if (knowledgePageDisabled) { + return; + } Modal.confirm({ title: t("knowledge.clearConfirmTitle"), content: t("knowledge.clearConfirmContent"), @@ -691,9 +738,12 @@ function KnowledgePage() { } }, }); - }, [loadData, t]); + }, [knowledgePageDisabled, loadData, t]); const handleSearch = async () => { + if (knowledgePageDisabled) { + return; + } const query = searchQuery.trim(); if (!query) { setHits([]); @@ -850,6 +900,19 @@ function KnowledgePage() { : "-"; }, [selectedSource, t]); + const selectedSourceSummaryData = useMemo(() => { + if (!selectedSource) { + return { + summary: "", + keywords: [] as string[], + }; + } + return { + summary: (selectedSource.summary || "").trim(), + keywords: selectedSource.keywords || [], + }; + }, [selectedSource]); + const openDetailDrawer = useCallback((record: KnowledgeSourceItem) => { setSelectedSource(record); setSourceContent(null); @@ -878,7 +941,7 @@ function KnowledgePage() { isFileDragActive ? styles.dragZoneActive : "" }`; const enableModalAutoCollectLongText = Form.useWatch( - "auto_collect_long_text", + "knowledge_auto_collect_long_text", enableConfigForm, ); const showBackfillNowButton = Boolean( @@ -1005,7 +1068,7 @@ function KnowledgePage() {
{t("knowledge.enabled")}
@@ -1014,6 +1077,7 @@ function KnowledgePage() { icon={} onClick={handleIndexAll} loading={indexingAll} + disabled={knowledgePageDisabled} > {t("knowledge.indexAll")} @@ -1021,6 +1085,7 @@ function KnowledgePage() { icon={} onClick={handleBackupAll} loading={exportingAll} + disabled={knowledgePageDisabled} > {t("knowledge.backupAll")} @@ -1028,6 +1093,7 @@ function KnowledgePage() { icon={} onClick={handleRestoreBackup} loading={importingBackup} + disabled={knowledgePageDisabled} > {t("knowledge.restore")} @@ -1039,6 +1105,7 @@ function KnowledgePage() { icon={} onClick={handleClearKnowledge} loading={clearingKnowledge} + disabled={knowledgePageDisabled} > {t("knowledge.clearKnowledge")} @@ -1047,6 +1114,7 @@ function KnowledgePage() { icon={} onClick={handleRunHistoryBackfillNow} loading={backfillingHistory} + disabled={knowledgePageDisabled} > {t("knowledge.backfillNowButton")} @@ -1068,7 +1136,11 @@ function KnowledgePage() {
) : null} - + @@ -1227,12 +1299,13 @@ function KnowledgePage() { const originText = getSourceOriginText(record, t); const remoteLine = formatRemoteStatus(record, t); const isActiveCard = indexingId === record.id; - const cardTitle = record.name?.trim() || ""; - const descriptionText = record.description?.trim() || ""; - const hideDescriptionBlock = - cardTitle.length > 0 && - descriptionText.length > 0 && - cardTitle === descriptionText; + const cardSubject = (record.subject || record.name || "").trim(); + const summaryText = (record.summary || "").trim(); + const summaryKeywords = record.keywords || []; + const hideSummaryBlock = + cardSubject.length > 0 && + summaryText.length > 0 && + cardSubject === summaryText; const indexedCountText = record.status.indexed ? t("knowledge.indexedCount", { documents: record.status.document_count, @@ -1256,10 +1329,10 @@ function KnowledgePage() {
- {cardTitle ? ( + {cardSubject ? (
- {t("knowledge.table.title")} + {t("knowledge.table.subject")}
- {cardTitle} + {cardSubject}
) : null} - {!hideDescriptionBlock ? ( + {!hideSummaryBlock ? (
{t("knowledge.table.source")} @@ -1294,18 +1367,35 @@ function KnowledgePage() { handleDetailDrawerValueKeyDown(event, record) } className={`${styles.infoBlock} ${styles.clickableBlock}`} - title={descriptionText || t("knowledge.inlineText")} + title={summaryText || t("knowledge.inlineText")} > - {descriptionText || t("knowledge.inlineText")} + {summaryText || t("knowledge.inlineText")}
) : null} + {summaryKeywords.length > 0 ? ( +
+
+ {t("knowledge.table.keywords")} +
+
+
+ {summaryKeywords.map((keyword) => ( + + {keyword} + + ))} +
+
+
+ ) : null} + {record.location ? (
@@ -1444,7 +1534,7 @@ function KnowledgePage() { location: "", content: "", tags: [], - description: "", + summary: "", }} > )} - - + + @@ -1609,7 +1699,7 @@ function KnowledgePage() { > @@ -1617,7 +1707,7 @@ function KnowledgePage() { @@ -1625,7 +1715,7 @@ function KnowledgePage() { @@ -1633,7 +1723,7 @@ function KnowledgePage() { {selectedSource.name?.trim() ? (
-
{t("knowledge.table.title")}
+
{t("knowledge.table.subject")}
- {selectedSource.name} + {selectedSource.subject || selectedSource.name}
) : null} @@ -1716,10 +1806,25 @@ function KnowledgePage() {
{t("knowledge.table.source")}
- {selectedSource.description || t("knowledge.inlineText")} + {selectedSourceSummaryData.summary || t("knowledge.inlineText")}
+ {selectedSourceSummaryData.keywords.length > 0 ? ( +
+
{t("knowledge.table.keywords")}
+
+
+ {selectedSourceSummaryData.keywords.map((keyword) => ( + + {keyword} + + ))} +
+
+
+ ) : null} +
{selectedSourceOriginText} {selectedSource.type} diff --git a/pyproject.toml b/pyproject.toml index 2ab5bf624..25ed5cbcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "aiofiles>=24.1.0", "paho-mqtt>=2.0.0", "matrix-nio>=0.24.0", + "jieba>=0.42.1", ] [tool.setuptools.dynamic] @@ -74,8 +75,11 @@ mlx = [ ollama = [ "ollama>=0.6.1", ] +nlp = [ + "hanlp>=2.1.0b50", +] full = [ - "copaw[local,ollama,llamacpp]", + "copaw[local,ollama,llamacpp,nlp]", "mlx-lm>=0.10.0; sys_platform == 'darwin'", ] diff --git a/src/copaw/agents/react_agent.py b/src/copaw/agents/react_agent.py index f302ac143..2fbcd4645 100644 --- a/src/copaw/agents/react_agent.py +++ b/src/copaw/agents/react_agent.py @@ -46,7 +46,6 @@ ) from .utils import process_file_and_media_blocks_in_message from ..agents.memory import MemoryManager -from ..config import load_config from ..constant import ( MEMORY_COMPACT_RATIO, WORKING_DIR, @@ -208,6 +207,7 @@ def _create_toolkit( tool_enabled = ( tool_enabled and bool(getattr(config.knowledge, "enabled", False)) + and bool(getattr(config.agents.running, "knowledge_enabled", True)) and bool( getattr( config.agents.running, @@ -220,18 +220,21 @@ def _create_toolkit( tool_enabled = ( tool_enabled and bool(getattr(config.knowledge, "enabled", False)) + and bool(getattr(config.agents.running, "knowledge_enabled", True)) and bool(getattr(config.knowledge, "graph_query_enabled", False)) ) elif tool_name in {"memify_run", "memify_status"}: tool_enabled = ( tool_enabled and bool(getattr(config.knowledge, "enabled", False)) + and bool(getattr(config.agents.running, "knowledge_enabled", True)) and bool(getattr(config.knowledge, "memify_enabled", False)) ) elif tool_name == "triplet_focus_search": tool_enabled = ( tool_enabled and bool(getattr(config.knowledge, "enabled", False)) + and bool(getattr(config.agents.running, "knowledge_enabled", True)) and bool(getattr(config.knowledge, "triplet_search_enabled", False)) ) @@ -253,7 +256,6 @@ def 
_register_skills(self, toolkit: Toolkit) -> None: Args: toolkit: Toolkit to register skills to """ - # Check skills initialization ensure_skills_initialized() working_skills_dir = get_working_skills_dir() diff --git a/src/copaw/agents/skills_manager.py b/src/copaw/agents/skills_manager.py index 76ba47d81..f3293ac55 100644 --- a/src/copaw/agents/skills_manager.py +++ b/src/copaw/agents/skills_manager.py @@ -217,6 +217,44 @@ def _collect_skills_from_dir(directory: Path) -> dict[str, Path]: return skills +def sync_skill_dir_to_active( + skill_dir: Path, + force: bool = False, +) -> bool: + """Sync a single skill directory into active_skills.""" + skill_md = skill_dir / "SKILL.md" + if not skill_dir.is_dir() or not skill_md.exists(): + logger.warning("Skill directory is invalid or missing SKILL.md: %s", skill_dir) + return False + + active_skills = get_active_skills_dir() + active_skills.mkdir(parents=True, exist_ok=True) + + target_dir = active_skills / skill_dir.name + if target_dir.exists(): + if _directories_match_ignoring_runtime_artifacts(skill_dir, target_dir): + return True + if not force: + logger.debug( + "Skill '%s' already exists in active_skills with different content, skipping.", + skill_dir.name, + ) + return False + shutil.rmtree(target_dir) + + try: + shutil.copytree(skill_dir, target_dir) + logger.debug("Synced skill '%s' to active_skills.", skill_dir.name) + return True + except Exception as e: + logger.error( + "Failed to sync skill directory '%s' to active_skills: %s", + skill_dir, + e, + ) + return False + + def sync_skills_to_working_dir( skill_names: list[str] | None = None, force: bool = False, @@ -280,17 +318,12 @@ def sync_skills_to_working_dir( # Copy skill directory try: - if target_dir.exists(): - shutil.rmtree(target_dir) - shutil.copytree(skill_dir, target_dir) - logger.debug("Synced skill '%s' to active_skills.", skill_name) - synced_count += 1 + if sync_skill_dir_to_active(skill_dir, force=True): + synced_count += 1 + else: + 
skipped_count += 1 except Exception as e: - logger.error( - "Failed to sync skill '%s': %s", - skill_name, - e, - ) + logger.error("Failed to sync skill '%s': %s", skill_name, e) return synced_count, skipped_count diff --git a/src/copaw/agents/tools/graph_query.py b/src/copaw/agents/tools/graph_query.py index 84c07c81c..0b8ac3d1e 100644 --- a/src/copaw/agents/tools/graph_query.py +++ b/src/copaw/agents/tools/graph_query.py @@ -53,6 +53,12 @@ async def graph_query( TextBlock(type="text", text="Error: knowledge is disabled in configuration."), ], ) + if not bool(getattr(config.agents.running, "knowledge_enabled", True)): + return ToolResponse( + content=[ + TextBlock(type="text", text="Error: knowledge is disabled in agent runtime configuration."), + ], + ) if not bool(getattr(config.knowledge, "graph_query_enabled", False)): return ToolResponse( content=[ diff --git a/src/copaw/agents/tools/knowledge_search.py b/src/copaw/agents/tools/knowledge_search.py index a73e81c16..4288eb86d 100644 --- a/src/copaw/agents/tools/knowledge_search.py +++ b/src/copaw/agents/tools/knowledge_search.py @@ -51,6 +51,15 @@ async def knowledge_search( ), ], ) + if not bool(getattr(config.agents.running, "knowledge_enabled", True)): + return ToolResponse( + content=[ + TextBlock( + type="text", + text="Knowledge is disabled in agent runtime configuration.", + ), + ], + ) if not bool( getattr( getattr(config, "agents", None), diff --git a/src/copaw/agents/tools/memify_run.py b/src/copaw/agents/tools/memify_run.py index 3c14ef157..7f4a4d86b 100644 --- a/src/copaw/agents/tools/memify_run.py +++ b/src/copaw/agents/tools/memify_run.py @@ -40,6 +40,10 @@ async def memify_run( return ToolResponse( content=[TextBlock(type="text", text="Error: knowledge is disabled in configuration.")], ) + if not bool(getattr(config.agents.running, "knowledge_enabled", True)): + return ToolResponse( + content=[TextBlock(type="text", text="Error: knowledge is disabled in agent runtime configuration.")], + ) if not 
bool(getattr(config.knowledge, "memify_enabled", False)): return ToolResponse( content=[TextBlock(type="text", text="Error: memify is disabled in configuration.")], diff --git a/src/copaw/agents/tools/memify_status.py b/src/copaw/agents/tools/memify_status.py index e6abb5c50..77e0b9ec6 100644 --- a/src/copaw/agents/tools/memify_status.py +++ b/src/copaw/agents/tools/memify_status.py @@ -26,6 +26,10 @@ async def memify_status(job_id: str) -> ToolResponse: return ToolResponse( content=[TextBlock(type="text", text="Error: knowledge is disabled in configuration.")], ) + if not bool(getattr(config.agents.running, "knowledge_enabled", True)): + return ToolResponse( + content=[TextBlock(type="text", text="Error: knowledge is disabled in agent runtime configuration.")], + ) if not bool(getattr(config.knowledge, "memify_enabled", False)): return ToolResponse( content=[TextBlock(type="text", text="Error: memify is disabled in configuration.")], diff --git a/src/copaw/agents/tools/triplet_focus_search.py b/src/copaw/agents/tools/triplet_focus_search.py index e1cd95691..4f17a1cc3 100644 --- a/src/copaw/agents/tools/triplet_focus_search.py +++ b/src/copaw/agents/tools/triplet_focus_search.py @@ -43,6 +43,10 @@ async def triplet_focus_search( return ToolResponse( content=[TextBlock(type="text", text="Error: knowledge is disabled in configuration.")], ) + if not bool(getattr(config.agents.running, "knowledge_enabled", True)): + return ToolResponse( + content=[TextBlock(type="text", text="Error: knowledge is disabled in agent runtime configuration.")], + ) if not bool(getattr(config.knowledge, "triplet_search_enabled", False)): return ToolResponse( content=[ diff --git a/src/copaw/app/routers/agent.py b/src/copaw/app/routers/agent.py index 79a116ba0..5dbdb16f9 100644 --- a/src/copaw/app/routers/agent.py +++ b/src/copaw/app/routers/agent.py @@ -9,6 +9,7 @@ save_config, AgentsRunningConfig, ) +from ...knowledge.module_skills import sync_knowledge_module_skills from 
...agents.memory.agent_md_manager import AGENT_MD_MANAGER @@ -25,31 +26,38 @@ def _migrate_knowledge_automation_to_running(config) -> bool: return False if ( - running.auto_collect_chat_files == defaults.auto_collect_chat_files - and legacy.auto_collect_chat_files != defaults.auto_collect_chat_files + running.knowledge_enabled == defaults.knowledge_enabled + and config.knowledge.enabled != defaults.knowledge_enabled ): - running.auto_collect_chat_files = legacy.auto_collect_chat_files + running.knowledge_enabled = config.knowledge.enabled changed = True if ( - running.auto_collect_chat_urls == defaults.auto_collect_chat_urls - and legacy.auto_collect_chat_urls != defaults.auto_collect_chat_urls + running.knowledge_auto_collect_chat_files == defaults.knowledge_auto_collect_chat_files + and legacy.knowledge_auto_collect_chat_files != defaults.knowledge_auto_collect_chat_files ): - running.auto_collect_chat_urls = legacy.auto_collect_chat_urls + running.knowledge_auto_collect_chat_files = legacy.knowledge_auto_collect_chat_files changed = True if ( - running.auto_collect_long_text == defaults.auto_collect_long_text - and legacy.auto_collect_long_text != defaults.auto_collect_long_text + running.knowledge_auto_collect_chat_urls == defaults.knowledge_auto_collect_chat_urls + and legacy.knowledge_auto_collect_chat_urls != defaults.knowledge_auto_collect_chat_urls ): - running.auto_collect_long_text = legacy.auto_collect_long_text + running.knowledge_auto_collect_chat_urls = legacy.knowledge_auto_collect_chat_urls changed = True if ( - running.long_text_min_chars == defaults.long_text_min_chars - and legacy.long_text_min_chars != defaults.long_text_min_chars + running.knowledge_auto_collect_long_text == defaults.knowledge_auto_collect_long_text + and legacy.knowledge_auto_collect_long_text != defaults.knowledge_auto_collect_long_text ): - running.long_text_min_chars = legacy.long_text_min_chars + running.knowledge_auto_collect_long_text = 
legacy.knowledge_auto_collect_long_text + changed = True + + if ( + running.knowledge_long_text_min_chars == defaults.knowledge_long_text_min_chars + and legacy.knowledge_long_text_min_chars != defaults.knowledge_long_text_min_chars + ): + running.knowledge_long_text_min_chars = legacy.knowledge_long_text_min_chars changed = True knowledge_index = getattr(config.knowledge, "index", None) @@ -70,10 +78,11 @@ def _sync_running_to_knowledge_automation(config) -> None: if legacy is None: return running = config.agents.running - legacy.auto_collect_chat_files = running.auto_collect_chat_files - legacy.auto_collect_chat_urls = running.auto_collect_chat_urls - legacy.auto_collect_long_text = running.auto_collect_long_text - legacy.long_text_min_chars = running.long_text_min_chars + config.knowledge.enabled = running.knowledge_enabled + legacy.knowledge_auto_collect_chat_files = running.knowledge_auto_collect_chat_files + legacy.knowledge_auto_collect_chat_urls = running.knowledge_auto_collect_chat_urls + legacy.knowledge_auto_collect_long_text = running.knowledge_auto_collect_long_text + legacy.knowledge_long_text_min_chars = running.knowledge_long_text_min_chars config.knowledge.index.chunk_size = running.knowledge_chunk_size @@ -289,8 +298,11 @@ async def put_agents_running_config( ) -> AgentsRunningConfig: """Update agent running configuration.""" config = load_config() + previous_enabled = bool(getattr(config.agents.running, "knowledge_enabled", True)) config.agents.running = running_config _sync_running_to_knowledge_automation(config) + if previous_enabled != running_config.knowledge_enabled: + sync_knowledge_module_skills(running_config.knowledge_enabled) save_config(config) return running_config diff --git a/src/copaw/app/routers/knowledge.py b/src/copaw/app/routers/knowledge.py index c19d8a4b4..246073caf 100644 --- a/src/copaw/app/routers/knowledge.py +++ b/src/copaw/app/routers/knowledge.py @@ -20,10 +20,25 @@ from ...config.config import KnowledgeConfig, 
KnowledgeSourceSpec from ...constant import WORKING_DIR from ...knowledge import GraphOpsManager, KnowledgeManager +from ...knowledge.module_skills import sync_knowledge_module_skills router = APIRouter(prefix="/knowledge", tags=["knowledge"]) +def _knowledge_runtime_enabled(config) -> bool: + running = getattr(getattr(config, "agents", None), "running", None) + return bool(getattr(running, "knowledge_enabled", True)) + + +def _knowledge_effective_enabled(config) -> bool: + return _knowledge_runtime_enabled(config) + + +def _ensure_knowledge_enabled(config) -> None: + if not _knowledge_effective_enabled(config): + raise HTTPException(status_code=400, detail="KNOWLEDGE_DISABLED") + + def _zip_path(path) -> io.BytesIO: buf = io.BytesIO() with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: @@ -96,7 +111,11 @@ async def put_knowledge_config( knowledge_config: KnowledgeConfig = Body(...), ) -> KnowledgeConfig: config = load_config() + previous_enabled = bool(getattr(config.agents.running, "knowledge_enabled", True)) config.knowledge = knowledge_config + config.agents.running.knowledge_enabled = knowledge_config.enabled + if previous_enabled != knowledge_config.enabled: + sync_knowledge_module_skills(knowledge_config.enabled) save_config(config) return config.knowledge @@ -105,7 +124,7 @@ async def put_knowledge_config( async def list_sources(): config = load_config() return { - "enabled": config.knowledge.enabled, + "enabled": _knowledge_effective_enabled(config), "sources": _manager().list_sources(config.knowledge), } @@ -115,6 +134,7 @@ async def upsert_source( source: KnowledgeSourceSpec = Body(...), ) -> KnowledgeSourceSpec: config = load_config() + _ensure_knowledge_enabled(config) manager = _manager() source = manager.normalize_source_name(source, config.knowledge) existing = _find_source(config.knowledge, source.id) @@ -170,6 +190,7 @@ async def upload_knowledge_directory( @router.delete("/sources/{source_id}") async def delete_source(source_id: str): 
config = load_config() + _ensure_knowledge_enabled(config) source = _find_source(config.knowledge, source_id) if source is None: raise HTTPException(status_code=404, detail="KNOWLEDGE_SOURCE_NOT_FOUND") @@ -191,6 +212,7 @@ async def clear_knowledge( raise HTTPException(status_code=400, detail="KNOWLEDGE_CLEAR_CONFIRM_REQUIRED") config = load_config() + _ensure_knowledge_enabled(config) result = _manager().clear_knowledge( config.knowledge, remove_sources=remove_sources, @@ -202,6 +224,7 @@ async def clear_knowledge( @router.post("/sources/{source_id}/index") async def index_source(source_id: str): config = load_config() + _ensure_knowledge_enabled(config) source = _find_source(config.knowledge, source_id) if source is None: raise HTTPException(status_code=404, detail="KNOWLEDGE_SOURCE_NOT_FOUND") @@ -230,6 +253,7 @@ async def get_source_content(source_id: str): @router.post("/index") async def index_all_sources(): config = load_config() + _ensure_knowledge_enabled(config) try: return _manager().index_all(config.knowledge, config.agents.running) except (FileNotFoundError, ValueError, OSError) as exc: @@ -246,6 +270,7 @@ async def search_knowledge( source_types: Optional[str] = Query(default=None), ): config = load_config() + _ensure_knowledge_enabled(config) ids = [item for item in (source_ids or "").split(",") if item] types = [item for item in (source_types or "").split(",") if item] return _manager().search( @@ -267,13 +292,14 @@ async def get_history_backfill_status(): async def run_history_backfill_now(): """Run history backfill immediately regardless of runtime auto-backfill toggle.""" config = load_config() + _ensure_knowledge_enabled(config) manager = _manager() running = config.agents.running force_running = SimpleNamespace( - auto_collect_chat_files=running.auto_collect_chat_files, - auto_collect_chat_urls=running.auto_collect_chat_urls, - auto_collect_long_text=running.auto_collect_long_text, - long_text_min_chars=running.long_text_min_chars, + 
knowledge_auto_collect_chat_files=running.knowledge_auto_collect_chat_files, + knowledge_auto_collect_chat_urls=running.knowledge_auto_collect_chat_urls, + knowledge_auto_collect_long_text=running.knowledge_auto_collect_long_text, + knowledge_long_text_min_chars=running.knowledge_long_text_min_chars, knowledge_chunk_size=running.knowledge_chunk_size, ) result = await asyncio.to_thread( @@ -297,6 +323,7 @@ async def get_memify_job_status(job_id: str): raise HTTPException(status_code=400, detail="MEMIFY_JOB_ID_REQUIRED") config = load_config() + _ensure_knowledge_enabled(config) if not config.knowledge.enabled: raise HTTPException(status_code=400, detail="KNOWLEDGE_DISABLED") if not bool(getattr(config.knowledge, "memify_enabled", False)): diff --git a/src/copaw/app/runner/runner.py b/src/copaw/app/runner/runner.py index 89ad8072b..0928b5a03 100644 --- a/src/copaw/app/runner/runner.py +++ b/src/copaw/app/runner/runner.py @@ -215,8 +215,8 @@ async def query_handler( try: should_collect_user_assets = bool( - getattr(running, "auto_collect_chat_files", False) - or getattr(running, "auto_collect_chat_urls", True) + getattr(running, "knowledge_auto_collect_chat_files", False) + or getattr(running, "knowledge_auto_collect_chat_urls", True) ) if should_collect_user_assets: from ...knowledge import KnowledgeManager @@ -340,8 +340,8 @@ async def query_handler( try: running = config.agents.running should_auto_collect = bool( - getattr(running, "auto_collect_chat_files", False) - or getattr(running, "auto_collect_long_text", False) + getattr(running, "knowledge_auto_collect_chat_files", False) + or getattr(running, "knowledge_auto_collect_long_text", False) ) if should_auto_collect: @@ -350,7 +350,7 @@ async def query_handler( manager = knowledge_manager or KnowledgeManager(WORKING_DIR) should_auto_collect_text = bool( - getattr(running, "auto_collect_long_text", False), + getattr(running, "knowledge_auto_collect_long_text", False), ) if should_auto_collect_text: diff --git 
a/src/copaw/config/config.py b/src/copaw/config/config.py index d14bb3c5a..474e9f81c 100644 --- a/src/copaw/config/config.py +++ b/src/copaw/config/config.py @@ -218,28 +218,33 @@ class AgentsRunningConfig(BaseModel): ), ) - auto_collect_chat_files: bool = Field( + knowledge_enabled: bool = Field( + default=True, + description="Master switch for knowledge features and operations", + ) + + knowledge_auto_collect_chat_files: bool = Field( default=True, description=( "Automatically collect file references in chat turns into knowledge sources" ), ) - auto_collect_chat_urls: bool = Field( + knowledge_auto_collect_chat_urls: bool = Field( default=True, description=( "Automatically collect URLs mentioned in chat turns into knowledge sources" ), ) - auto_collect_long_text: bool = Field( + knowledge_auto_collect_long_text: bool = Field( default=True, description=( "Automatically save long chat passages into text knowledge sources" ), ) - long_text_min_chars: int = Field( + knowledge_long_text_min_chars: int = Field( default=2000, ge=200, le=20000, @@ -623,7 +628,7 @@ class KnowledgeSourceSpec(BaseModel): enabled: bool = Field(default=True) recursive: bool = Field(default=True) tags: List[str] = Field(default_factory=list) - description: str = Field(default="") + summary: str = Field(default="") @model_validator(mode="after") def validate_source(self): @@ -673,10 +678,10 @@ class KnowledgeIndexConfig(BaseModel): class KnowledgeAutomationConfig(BaseModel): """Passive knowledge collection during chat turns.""" - auto_collect_chat_files: bool = Field(default=True) - auto_collect_chat_urls: bool = Field(default=True) - auto_collect_long_text: bool = Field(default=True) - long_text_min_chars: int = Field(default=2000, ge=200, le=20000) + knowledge_auto_collect_chat_files: bool = Field(default=True) + knowledge_auto_collect_chat_urls: bool = Field(default=True) + knowledge_auto_collect_long_text: bool = Field(default=True) + knowledge_long_text_min_chars: int = Field(default=2000, 
ge=200, le=20000) url_exclude_private_addresses: bool = Field( default=True, diff --git a/src/copaw/knowledge/manager.py b/src/copaw/knowledge/manager.py index 422089fc3..b3e7c6586 100644 --- a/src/copaw/knowledge/manager.py +++ b/src/copaw/knowledge/manager.py @@ -9,6 +9,7 @@ import logging import re import shutil +from collections import Counter from datetime import UTC, datetime, timedelta from html import unescape from pathlib import Path @@ -17,6 +18,16 @@ import httpx +try: + import jieba # type: ignore[import-not-found] +except Exception: # pragma: no cover - optional dependency + jieba = None + +try: + import hanlp # type: ignore[import-not-found] +except Exception: # pragma: no cover - optional dependency + hanlp = None + from ..constant import CHATS_FILE from ..config.config import KnowledgeConfig, KnowledgeSourceSpec @@ -75,6 +86,30 @@ "知识", "数据", } +_SEMANTIC_TOKEN_RE = re.compile(r"[A-Za-z][A-Za-z0-9_-]{2,}|[\u4e00-\u9fff]{2,}") +_SEMANTIC_STOP_WORDS = { + *_TITLE_STOP_WORDS, + "is", + "are", + "was", + "were", + "be", + "to", + "of", + "in", + "on", + "at", + "by", + "or", + "as", + "it", + "an", + "a", + "关键词", + "关键", + "词", +} +_KEYWORD_DEFAULT_TOP_N = 3 _TEXTUAL_CONTENT_TYPE_MARKERS = ( "text/", "application/json", @@ -120,6 +155,10 @@ def list_sources(self, config: KnowledgeConfig) -> list[dict[str, Any]]: results: list[dict[str, Any]] = [] for source in config.sources: payload = source.model_dump(mode="json") + processed = self._process_source_knowledge(source, config) + payload["subject"] = processed.get("subject") or source.name + payload["summary"] = processed.get("summary") or source.summary + payload["keywords"] = processed.get("keywords") or [] payload["status"] = self.get_source_status(source.id, source) results.append(payload) return results @@ -756,7 +795,9 @@ def auto_collect_user_message_assets( running_config: Any | None = None, ) -> dict[str, Any]: """Collect file/url knowledge immediately from user-sent content.""" - if not 
config.enabled: + if not config.enabled or not bool( + getattr(running_config, "knowledge_enabled", True) + ): return { "changed": False, "file_sources": 0, @@ -770,13 +811,13 @@ def auto_collect_user_message_assets( errors: list[dict[str, str]] = [] user_messages = list(request_messages or []) - auto_collect_chat_files = getattr(running_config, "auto_collect_chat_files", None) - if auto_collect_chat_files is None: - auto_collect_chat_files = config.automation.auto_collect_chat_files + knowledge_auto_collect_chat_files = getattr(running_config, "knowledge_auto_collect_chat_files", None) + if knowledge_auto_collect_chat_files is None: + knowledge_auto_collect_chat_files = config.automation.knowledge_auto_collect_chat_files - auto_collect_chat_urls = getattr(running_config, "auto_collect_chat_urls", None) - if auto_collect_chat_urls is None: - auto_collect_chat_urls = config.automation.auto_collect_chat_urls + knowledge_auto_collect_chat_urls = getattr(running_config, "knowledge_auto_collect_chat_urls", None) + if knowledge_auto_collect_chat_urls is None: + knowledge_auto_collect_chat_urls = config.automation.knowledge_auto_collect_chat_urls auto_collect_url_min_chars = int( getattr( running_config, @@ -786,7 +827,7 @@ def auto_collect_user_message_assets( or _AUTO_COLLECT_URL_MIN_CONTENT_CHARS ) - if auto_collect_chat_files: + if knowledge_auto_collect_chat_files: for source in self._build_file_sources_from_messages( user_messages, config, @@ -802,7 +843,7 @@ def auto_collect_user_message_assets( ) file_sources += 1 - if auto_collect_chat_urls: + if knowledge_auto_collect_chat_urls: for source in self._build_url_sources_from_messages( user_messages, session_id, @@ -841,7 +882,9 @@ def auto_collect_turn_text_pair( running_config: Any | None = None, ) -> dict[str, Any]: """Collect text knowledge after response, based on one user-assistant turn pair.""" - if not config.enabled: + if not config.enabled or not bool( + getattr(running_config, "knowledge_enabled", True) + 
): return { "changed": False, "file_sources": 0, @@ -849,10 +892,10 @@ def auto_collect_turn_text_pair( "text_sources": 0, } - auto_collect_long_text = getattr(running_config, "auto_collect_long_text", None) - if auto_collect_long_text is None: - auto_collect_long_text = config.automation.auto_collect_long_text - if not auto_collect_long_text: + knowledge_auto_collect_long_text = getattr(running_config, "knowledge_auto_collect_long_text", None) + if knowledge_auto_collect_long_text is None: + knowledge_auto_collect_long_text = config.automation.knowledge_auto_collect_long_text + if not knowledge_auto_collect_long_text: return { "changed": False, "file_sources": 0, @@ -860,9 +903,9 @@ def auto_collect_turn_text_pair( "text_sources": 0, } - long_text_min_chars = getattr(running_config, "long_text_min_chars", None) - if not isinstance(long_text_min_chars, int): - long_text_min_chars = config.automation.long_text_min_chars + knowledge_long_text_min_chars = getattr(running_config, "knowledge_long_text_min_chars", None) + if not isinstance(knowledge_long_text_min_chars, int): + knowledge_long_text_min_chars = config.automation.knowledge_long_text_min_chars errors: list[dict[str, str]] = [] changed = False @@ -872,7 +915,7 @@ def auto_collect_turn_text_pair( response_messages=list(response_messages or []), session_id=session_id, user_id=user_id, - long_text_min_chars=long_text_min_chars, + knowledge_long_text_min_chars=knowledge_long_text_min_chars, ): if self._upsert_source(config, source): changed = True @@ -901,7 +944,9 @@ def auto_backfill_history_data( running_config: Any | None = None, ) -> dict[str, Any]: """Backfill historical chat-session data into knowledge sources once.""" - if not config.enabled: + if not config.enabled or not bool( + getattr(running_config, "knowledge_enabled", True) + ): self._save_backfill_progress( { "running": False, @@ -962,12 +1007,12 @@ def auto_backfill_history_data( } ) - long_text_min_chars = getattr(running_config, 
"long_text_min_chars", None) - if not isinstance(long_text_min_chars, int): - long_text_min_chars = config.automation.long_text_min_chars - auto_collect_chat_files = getattr(running_config, "auto_collect_chat_files", False) - auto_collect_chat_urls = getattr(running_config, "auto_collect_chat_urls", True) - auto_collect_long_text = getattr(running_config, "auto_collect_long_text", False) + knowledge_long_text_min_chars = getattr(running_config, "knowledge_long_text_min_chars", None) + if not isinstance(knowledge_long_text_min_chars, int): + knowledge_long_text_min_chars = config.automation.knowledge_long_text_min_chars + knowledge_auto_collect_chat_files = getattr(running_config, "knowledge_auto_collect_chat_files", False) + knowledge_auto_collect_chat_urls = getattr(running_config, "knowledge_auto_collect_chat_urls", True) + knowledge_auto_collect_long_text = getattr(running_config, "knowledge_auto_collect_long_text", False) auto_collect_url_min_chars = int( getattr( running_config, @@ -1007,7 +1052,7 @@ def auto_backfill_history_data( continue processed_sessions += 1 - if auto_collect_chat_files: + if knowledge_auto_collect_chat_files: for source in self._build_file_sources_from_messages( messages, config, @@ -1024,13 +1069,13 @@ def auto_backfill_history_data( ) file_sources += 1 - if auto_collect_long_text: + if knowledge_auto_collect_long_text: for source in self._build_text_sources_from_messages( messages, config, session_id, user_id, - long_text_min_chars, + knowledge_long_text_min_chars, ): upserted = self._upsert_source(config, source) if upserted: @@ -1043,7 +1088,7 @@ def auto_backfill_history_data( ) text_sources += 1 - if auto_collect_chat_urls: + if knowledge_auto_collect_chat_urls: for source in self._build_url_sources_from_messages( messages, session_id, @@ -1301,7 +1346,7 @@ def _build_file_sources_from_messages( enabled=True, recursive=False, tags=tags, - description=f"Auto-collected from chat session {session_id}", + summary=f"Auto-collected from 
chat session {session_id}", ), ) return sources @@ -1312,13 +1357,13 @@ def _build_text_sources_from_messages( config: KnowledgeConfig, session_id: str, user_id: str, - long_text_min_chars: int, + knowledge_long_text_min_chars: int, ) -> list[KnowledgeSourceSpec]: sources: list[KnowledgeSourceSpec] = [] seen_ids: set[str] = set() for role, text in self._iter_message_texts(messages): normalized = self._normalize_text(text) - if len(normalized) < long_text_min_chars: + if len(normalized) < knowledge_long_text_min_chars: continue digest = hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:12] source_id = f"auto-text-{digest}" @@ -1341,7 +1386,7 @@ def _build_text_sources_from_messages( "auto:text", f"role:{role}", ], - description=( + summary=( f"Auto-saved from {role} message in {session_id}" + (f" for {user_id}" if user_id else "") ), @@ -1355,7 +1400,7 @@ def _build_text_sources_from_turn_pair( response_messages: list[Any], session_id: str, user_id: str, - long_text_min_chars: int, + knowledge_long_text_min_chars: int, ) -> list[KnowledgeSourceSpec]: user_text = self._normalize_text( "\n".join( @@ -1375,7 +1420,7 @@ def _build_text_sources_from_turn_pair( return [] merged = self._normalize_text(f"用户: {user_text}\n\n智能体: {assistant_text}") - if len(merged) < long_text_min_chars: + if len(merged) < knowledge_long_text_min_chars: return [] digest = hashlib.sha1(merged.encode("utf-8")).hexdigest()[:12] @@ -1396,7 +1441,7 @@ def _build_text_sources_from_turn_pair( "auto:text", "role:turn_pair", ], - description=( + summary=( f"Auto-saved from user-assistant turn in {session_id}" + (f" for {user_id}" if user_id else "") ), @@ -1436,12 +1481,12 @@ def _build_url_sources_from_messages( # Capture surrounding text context from the conversation message # so title generation can use it without fetching the URL. 
context_snippet = self._extract_url_context(text, url, max_chars=400) - description = ( + summary = ( f"Auto-collected URL from {role} message in {session_id}" + (f" for {user_id}" if user_id else "") ) if context_snippet: - description = f"{description}\n来源上下文: {context_snippet}" + summary = f"{summary}\n来源上下文: {context_snippet}" sources.append( KnowledgeSourceSpec( id=source_id, @@ -1458,7 +1503,7 @@ def _build_url_sources_from_messages( "auto:url", f"role:{role}", ], - description=description, + summary=summary, ), ) return sources @@ -1781,17 +1826,17 @@ def _parse_iso_utc(value: Any) -> datetime | None: def _history_backfill_signature(self, running_config: Any | None) -> str: payload = { - "auto_collect_chat_files": bool( - getattr(running_config, "auto_collect_chat_files", False), + "knowledge_auto_collect_chat_files": bool( + getattr(running_config, "knowledge_auto_collect_chat_files", False), ), - "auto_collect_chat_urls": bool( - getattr(running_config, "auto_collect_chat_urls", True), + "knowledge_auto_collect_chat_urls": bool( + getattr(running_config, "knowledge_auto_collect_chat_urls", True), ), - "auto_collect_long_text": bool( - getattr(running_config, "auto_collect_long_text", False), + "knowledge_auto_collect_long_text": bool( + getattr(running_config, "knowledge_auto_collect_long_text", False), ), - "long_text_min_chars": int( - getattr(running_config, "long_text_min_chars", 2000), + "knowledge_long_text_min_chars": int( + getattr(running_config, "knowledge_long_text_min_chars", 2000), ), "knowledge_chunk_size": int( getattr(running_config, "knowledge_chunk_size", 1200), @@ -1960,12 +2005,12 @@ def _source_with_auto_name( updates: dict[str, Any] = {} source_for_title = source - if not (source.description or "").strip(): - generated_description = self._generate_source_description(source, config) - if generated_description: - updates["description"] = generated_description + if not (source.summary or "").strip(): + generated_summary = 
self._generate_source_summary(source, config) + if generated_summary: + updates["summary"] = generated_summary source_for_title = source.model_copy( - update={"description": generated_description} + update={"summary": generated_summary} ) generated = self._generate_source_name(source_for_title, config) @@ -1976,14 +2021,20 @@ def _source_with_auto_name( return source return source.model_copy(update=updates) - def _generate_source_description( + def _generate_source_summary( self, source: KnowledgeSourceSpec, config: KnowledgeConfig | None = None, ) -> str: - semantic = self._semantic_description_for_source(source, config) + semantic = self._semantic_summary_for_source(source, config) if semantic: - return self._truncate_description(semantic) + keywords = self._semantic_keywords_for_source(source, config) + if keywords: + summary_with_keywords = ( + f"{semantic} 关键词: {', '.join(keywords)}" + ) + return self._truncate_summary(summary_with_keywords) + return self._truncate_summary(semantic) if source.type == "url": url = (source.location or "").strip() @@ -1993,51 +2044,32 @@ def _generate_source_description( path = parsed.path.strip("/") tail = path.split("/")[-1] if path else "" if tail: - return self._truncate_description(f"{host}/{tail}") - return self._truncate_description(host) + return self._truncate_summary(f"{host}/{tail}") + return self._truncate_summary(host) if source.type in {"file", "directory"} and source.location: location = (source.location or "").strip() if location: - return self._truncate_description(Path(location).name or location) + return self._truncate_summary(Path(location).name or location) if source.name: - return self._truncate_description(source.name) + return self._truncate_summary(source.name) return "" - def _semantic_description_for_source( + def _semantic_summary_for_source( self, source: KnowledgeSourceSpec, config: KnowledgeConfig | None = None, ) -> str: - candidates: list[str] = [] - - full_text = 
self._collect_full_text_for_local_title(source, config) - if full_text: - candidates.append(full_text) - - indexed_payload = self._load_index_payload_safe(source.id) - if indexed_payload: - chunk_texts = [ - chunk.get("text", "") - for chunk in indexed_payload.get("chunks", []) - if isinstance(chunk, dict) and isinstance(chunk.get("text"), str) - ] - if chunk_texts: - candidates.append("\n".join(chunk_texts)) - - for candidate in candidates: - sentence = self._semantic_title_from_text(candidate) - if sentence: - return sentence - return "" + processed = self._process_source_knowledge(source, config) + return processed.get("summary", "") def _generate_source_name( self, source: KnowledgeSourceSpec, config: KnowledgeConfig | None = None, ) -> str: - semantic = self._semantic_title_for_source(source, config) + semantic = self._semantic_subject_for_source(source, config) if semantic: return self._truncate_title(semantic) @@ -2059,15 +2091,79 @@ def _generate_source_name( return self._truncate_title(source.id) - def _semantic_title_for_source( + def _semantic_subject_for_source( self, source: KnowledgeSourceSpec, config: KnowledgeConfig | None = None, ) -> str: + processed = self._process_source_knowledge(source, config) + return processed.get("subject", "") + + def _semantic_keywords_for_source( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + top_n: int = _KEYWORD_DEFAULT_TOP_N, + ) -> list[str]: + processed = self._process_source_knowledge(source, config, top_n=top_n) + return processed.get("keywords", []) + + def _process_source_knowledge( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + top_n: int = _KEYWORD_DEFAULT_TOP_N, + ) -> dict[str, Any]: + candidates = self._collect_source_processing_candidates(source, config) + merged = self._normalize_text("\n".join(part for part in candidates if part)) + processed = self._process_knowledge_text(merged, top_n=top_n) + + # Keep deterministic priority for 
subjects: summary > content > index/title. + for candidate in candidates: + subject = self._extract_subject_from_text(candidate) + if subject: + processed["subject"] = subject + break + return processed + + def _process_knowledge_text( + self, + text: str, + top_n: int = _KEYWORD_DEFAULT_TOP_N, + ) -> dict[str, Any]: + normalized = self._normalize_text(text or "") + if not normalized: + return { + "subject": "", + "summary": "", + "keywords": [], + } + + return { + "subject": self._extract_subject_from_text(normalized), + "summary": self._extract_summary_from_text(normalized), + "keywords": self._extract_keywords_from_text(normalized, top_n=top_n), + } + + def _collect_source_processing_text( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + candidates = self._collect_source_processing_candidates(source, config) + return self._normalize_text("\n".join(part for part in candidates if part)) + + def _collect_source_processing_candidates( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> list[str]: candidates: list[str] = [] - if source.description: - candidates.append(source.description) + if source.summary and source.summary.strip(): + candidates.append(source.summary) + if source.content and source.content.strip(): + candidates.append(source.content) indexed_payload = self._load_index_payload_safe(source.id) if indexed_payload: @@ -2082,49 +2178,23 @@ def _semantic_title_for_source( chunk_text = chunk.get("text") if isinstance(chunk_text, str) and chunk_text.strip(): chunk_texts.append(chunk_text) + if chunk_titles: candidates.append("\n".join(chunk_titles)) if chunk_texts: candidates.append("\n".join(chunk_texts)) - for candidate in candidates: - title = self._semantic_title_from_text(candidate) - if title: - return title - return "" - - def _collect_full_text_for_local_title( - self, - source: KnowledgeSourceSpec, - config: KnowledgeConfig | None = None, - ) -> str: - parts: list[str] 
= [] - - if source.content and source.content.strip(): - parts.append(source.content) - - indexed_payload = self._load_index_payload_safe(source.id) - if indexed_payload: - chunk_texts = [ - chunk.get("text", "") - for chunk in indexed_payload.get("chunks", []) - if isinstance(chunk, dict) and isinstance(chunk.get("text"), str) - ] - if chunk_texts: - parts.append("\n".join(chunk_texts)) - location = (source.location or "").strip() if source.type == "file" and location: full_text = self._read_local_text(Path(location)) if full_text: - parts.append(full_text) + candidates.append(full_text) elif source.type == "directory" and location: full_text = self._read_directory_text(Path(location), config) if full_text: - parts.append(full_text) + candidates.append(full_text) - merged = self._normalize_text("\n".join(part for part in parts if part)) - return merged + return candidates def _read_local_text(self, path: Path) -> str: try: @@ -2210,11 +2280,7 @@ def _semantic_title_from_text(self, text: str) -> str: token_freq: dict[str, int] = {} sentence_tokens: list[list[str]] = [] for sentence in sentences: - tokens = [ - tok.lower() - for tok in _TITLE_WORD_RE.findall(sentence) - if tok.lower() not in _TITLE_STOP_WORDS - ] + tokens = self._tokenize_text(sentence) sentence_tokens.append(tokens) for token in tokens: token_freq[token] = token_freq.get(token, 0) + 1 @@ -2235,6 +2301,66 @@ def _semantic_title_from_text(self, text: str) -> str: best_sentence = sentences[0] return self._normalize_text(best_sentence) + def _extract_subject_from_text(self, text: str) -> str: + return self._semantic_title_from_text(text) + + def _extract_summary_from_text(self, text: str) -> str: + # Keep the summary extractor independent for future tuning. 
+ return self._semantic_title_from_text(text) + + @staticmethod + def _tokenize_text(text: str) -> list[str]: + normalized = re.sub(r"\s+", " ", (text or "").strip()) + if not normalized: + return [] + + raw_tokens: list[str] = [] + + if jieba is not None: + try: + raw_tokens = [str(tok) for tok in jieba.lcut(normalized)] + except Exception: + raw_tokens = [] + elif hanlp is not None: + for attr in ("tokenize", "tok"): + fn = getattr(hanlp, attr, None) + if not callable(fn): + continue + try: + result = fn(normalized) + if isinstance(result, list): + raw_tokens = [str(tok) for tok in result] + elif isinstance(result, tuple): + raw_tokens = [str(tok) for tok in result] + if raw_tokens: + break + except Exception: + continue + + if not raw_tokens: + raw_tokens = _SEMANTIC_TOKEN_RE.findall(normalized) + + tokens: list[str] = [] + for raw in raw_tokens: + token = str(raw).strip().lower() + if not token: + continue + if not _SEMANTIC_TOKEN_RE.fullmatch(token): + continue + if token in _SEMANTIC_STOP_WORDS: + continue + tokens.append(token) + return tokens + + def _extract_keywords_from_text(self, text: str, top_n: int = 3) -> list[str]: + tokens = self._tokenize_text(text) + if not tokens or top_n <= 0: + return [] + + freq = Counter(tokens) + ranked = sorted(freq.items(), key=lambda item: (-item[1], item[0])) + return [token for token, _ in ranked[:top_n]] + @staticmethod def _truncate_title(value: str, max_len: int = 120) -> str: compact = re.sub(r"\s+", " ", (value or "").strip()) @@ -2245,7 +2371,7 @@ def _truncate_title(value: str, max_len: int = 120) -> str: return compact[: max_len - 3].rstrip() + "..." 
@staticmethod - def _truncate_description(value: str, max_len: int = 180) -> str: + def _truncate_summary(value: str, max_len: int = 180) -> str: compact = re.sub(r"\s+", " ", (value or "").strip()) if not compact: return "" diff --git a/src/copaw/knowledge/module_skills.py b/src/copaw/knowledge/module_skills.py new file mode 100644 index 000000000..3d8a487f9 --- /dev/null +++ b/src/copaw/knowledge/module_skills.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- + +from pathlib import Path + +from ..agents.skills_manager import SkillService, sync_skill_dir_to_active + + +KNOWLEDGE_MODULE_SKILLS_DIR = Path(__file__).parent / "skills" +KNOWLEDGE_MODULE_SKILL_NAMES = ("knowledge_search_assistant",) + + +def sync_knowledge_module_skills(enabled: bool) -> None: + """Keep knowledge module skills aligned with the runtime enabled state.""" + for skill_name in KNOWLEDGE_MODULE_SKILL_NAMES: + if enabled: + skill_dir = KNOWLEDGE_MODULE_SKILLS_DIR / skill_name + if not sync_skill_dir_to_active(skill_dir, force=True): + raise RuntimeError( + f"Failed to enable knowledge module skill: {skill_name}" + ) + continue + + SkillService.disable_skill(skill_name) \ No newline at end of file diff --git a/src/copaw/knowledge/skills/knowledge_search_assistant/SKILL.md b/src/copaw/knowledge/skills/knowledge_search_assistant/SKILL.md new file mode 100644 index 000000000..448d481ab --- /dev/null +++ b/src/copaw/knowledge/skills/knowledge_search_assistant/SKILL.md @@ -0,0 +1,42 @@ +--- +name: knowledge_search_assistant +description: "Use knowledge_search proactively when the user is asking about existing project facts, process notes, prior decisions, archived materials, or whether the knowledge base already contains something relevant." +metadata: + { + "copaw": + { + "emoji": ":books:", + "requires": {} + } + } +--- + +# Knowledge Search Assistant + +Use this skill when the user's question is likely answered by existing knowledge base content. 
Prefer checking knowledge_search before answering from memory. + +## Trigger Signals + +- The user asks whether the knowledge base, docs, or prior notes already contain something. +- The user is asking for established project facts, conventions, workflows, or historical decisions. +- The user is requesting grounded recall rather than fresh synthesis. + +## Suggested Flow + +1. Extract a short search phrase from the user's request. +2. Call knowledge_search with the original question or a shorter query. +3. If the first search is weak, retry once with fewer keywords. +4. Answer from the retrieved evidence when available. +5. If nothing useful is found, say that no relevant knowledge was found. + +## Response Rules + +- Treat search hits as evidence and summarize them accurately. +- Do not present guesses as stored facts. +- If the user asks "do we already have this", answer the retrieval result first. + +## Do Not Use + +- Pure code editing, debugging, testing, or build tasks that need direct workspace inspection instead. +- General conversation that clearly does not depend on stored knowledge. +- Cases where the user explicitly says not to use the knowledge base. 
diff --git a/tests/unit/app/routers/test_knowledge.py b/tests/unit/app/routers/test_knowledge.py index ad8dd331e..294851090 100644 --- a/tests/unit/app/routers/test_knowledge.py +++ b/tests/unit/app/routers/test_knowledge.py @@ -34,42 +34,69 @@ def fake_save_config(new_config): return TestClient(app) -def test_upsert_source_auto_generates_description_when_empty( +def test_upsert_source_auto_generates_summary_when_empty( knowledge_api_client: TestClient, ): response = knowledge_api_client.put( "/knowledge/sources", json={ - "id": "text-auto-description", + "id": "text-auto-summary", "name": "Manual Name", "type": "text", "content": "Quarterly planning checklist and milestone review for the release train.", "enabled": True, "recursive": False, "tags": [], - "description": "", + "summary": "", }, ) assert response.status_code == 200 - generated = response.json()["description"] + generated = response.json()["summary"] assert generated -def test_upsert_source_title_prefers_description_over_content( +def test_upsert_source_auto_summary_includes_keywords( knowledge_api_client: TestClient, ): response = knowledge_api_client.put( "/knowledge/sources", json={ - "id": "text-title-from-description", + "id": "text-auto-keywords", "name": "Manual Name", "type": "text", - "content": "Very long internal content that should not be the direct title source.", + "content": ( + "支付系统 风控规则 更新。" + "支付系统 对账流程 优化。" + "支付系统 异常告警 升级。" + "风控规则 每日巡检。" + ), + "enabled": True, + "recursive": False, + "tags": [], + "summary": "", + }, + ) + + assert response.status_code == 200 + generated = response.json()["summary"] + assert "关键词:" in generated + + +def test_upsert_source_subject_prefers_summary_over_content( + knowledge_api_client: TestClient, +): + response = knowledge_api_client.put( + "/knowledge/sources", + json={ + "id": "text-subject-from-summary", + "name": "Manual Name", + "type": "text", + "content": "Very long internal content that should not be the direct subject source.", "enabled": 
True, "recursive": False, "tags": [], - "description": "Release checklist summary for sprint handoff", + "summary": "Release checklist summary for sprint handoff", }, ) @@ -77,20 +104,20 @@ def test_upsert_source_title_prefers_description_over_content( assert response.json()["name"] == "Release checklist summary for sprint handoff" -def test_list_sources_uses_auto_generated_titles_for_existing_config( +def test_list_sources_uses_auto_generated_subjects_for_existing_config( knowledge_api_client: TestClient, ): put_response = knowledge_api_client.put( "/knowledge/sources", json={ - "id": "url-auto-title", + "id": "url-auto-subject", "name": "Old URL Name", "type": "url", "location": "https://example.com/path/to/guide.html", "enabled": True, "recursive": False, "tags": ["remote"], - "description": "", + "summary": "", }, ) assert put_response.status_code == 200 @@ -101,6 +128,36 @@ def test_list_sources_uses_auto_generated_titles_for_existing_config( assert source["name"] == "example.com/guide.html" +def test_list_sources_returns_structured_summary_keywords( + knowledge_api_client: TestClient, +): + put_response = knowledge_api_client.put( + "/knowledge/sources", + json={ + "id": "text-structured-fields", + "name": "Manual Name", + "type": "text", + "content": ( + "支付系统 风控规则 更新。" + "支付系统 对账流程 优化。" + "支付系统 异常告警 升级。" + ), + "enabled": True, + "recursive": False, + "tags": [], + "summary": "", + }, + ) + assert put_response.status_code == 200 + + listing = knowledge_api_client.get("/knowledge/sources") + assert listing.status_code == 200 + source = listing.json()["sources"][0] + assert isinstance(source.get("subject"), str) + assert isinstance(source.get("summary"), str) + assert isinstance(source.get("keywords"), list) + + def test_history_backfill_status_includes_progress( knowledge_api_client: TestClient, ): @@ -136,6 +193,7 @@ def test_clear_knowledge_removes_sources_and_indexes( tmp_path: Path, ): config_payload = Config().knowledge.model_dump(mode="json") + 
config_payload["enabled"] = True config_payload["sources"] = [ { "id": "clear-1", @@ -146,7 +204,7 @@ def test_clear_knowledge_removes_sources_and_indexes( "enabled": True, "recursive": False, "tags": [], - "description": "", + "summary": "", } ] saved = knowledge_api_client.put("/knowledge/config", json=config_payload) @@ -230,6 +288,48 @@ def test_get_memify_job_status_success( assert payload["status"] in {"succeeded", "failed"} +def test_put_knowledge_config_syncs_running_toggle_and_module_skill( + knowledge_api_client: TestClient, + monkeypatch, +): + sync_calls: list[bool] = [] + monkeypatch.setattr( + knowledge_router_module, + "sync_knowledge_module_skills", + lambda enabled: sync_calls.append(enabled), + ) + + config_payload = Config().knowledge.model_dump(mode="json") + config_payload["enabled"] = False + + response = knowledge_api_client.put("/knowledge/config", json=config_payload) + + assert response.status_code == 200 + assert response.json()["enabled"] is False + assert sync_calls == [False] + + +def test_put_knowledge_config_does_not_resync_module_skill_when_toggle_unchanged( + knowledge_api_client: TestClient, + monkeypatch, +): + sync_calls: list[bool] = [] + monkeypatch.setattr( + knowledge_router_module, + "sync_knowledge_module_skills", + lambda enabled: sync_calls.append(enabled), + ) + + config_payload = Config().knowledge.model_dump(mode="json") + config_payload["enabled"] = True + + response = knowledge_api_client.put("/knowledge/config", json=config_payload) + + assert response.status_code == 200 + assert response.json()["enabled"] is True + assert sync_calls == [] + + def _build_knowledge_zip(entries: dict[str, str]) -> bytes: buf = io.BytesIO() with zipfile.ZipFile(buf, mode="w", compression=zipfile.ZIP_DEFLATED) as zf: @@ -249,7 +349,7 @@ def _source_index_payload(source_id: str, content: str = "knowledge text") -> st "enabled": True, "recursive": False, "tags": [], - "description": "", + "summary": "", }, "documents": [ { diff --git 
a/tests/unit/app/runner/test_knowledge_context_injection.py b/tests/unit/app/runner/test_knowledge_context_injection.py index 6353576f9..7c8759c64 100644 --- a/tests/unit/app/runner/test_knowledge_context_injection.py +++ b/tests/unit/app/runner/test_knowledge_context_injection.py @@ -85,9 +85,9 @@ async def _stream_printing_messages(*args, **kwargs): running=SimpleNamespace( max_iters=8, max_input_length=8192, - auto_collect_chat_files=False, - auto_collect_chat_urls=False, - auto_collect_long_text=False, + knowledge_auto_collect_chat_files=False, + knowledge_auto_collect_chat_urls=False, + knowledge_auto_collect_long_text=False, ), ), knowledge=SimpleNamespace(enabled=True), @@ -160,8 +160,8 @@ async def _stream_printing_messages(*args, **kwargs): running=SimpleNamespace( max_iters=8, max_input_length=8192, - auto_collect_chat_files=False, - auto_collect_long_text=False, + knowledge_auto_collect_chat_files=False, + knowledge_auto_collect_long_text=False, knowledge_retrieval_enabled=False, ), ), diff --git a/tests/unit/app/test_agent_skill_registration.py b/tests/unit/app/test_agent_skill_registration.py new file mode 100644 index 000000000..3c1cb9966 --- /dev/null +++ b/tests/unit/app/test_agent_skill_registration.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import importlib +from pathlib import Path + + +class _FakeToolkit: + def __init__(self) -> None: + self.registered_skills: list[str] = [] + + def register_agent_skill(self, skill_dir: str) -> None: + self.registered_skills.append(skill_dir) +def test_register_skills_includes_builtin_knowledge_skill_when_enabled( + monkeypatch, + tmp_path: Path, +) -> None: + module = importlib.import_module("copaw.agents.react_agent") + + working_skills_dir = tmp_path / "working" + manual_skill_dir = working_skills_dir / "manual_skill" + knowledge_skill_dir = working_skills_dir / "knowledge_search_assistant" + manual_skill_dir.mkdir(parents=True) + 
knowledge_skill_dir.mkdir(parents=True) + + monkeypatch.setattr(module, "ensure_skills_initialized", lambda: None) + monkeypatch.setattr(module, "get_working_skills_dir", lambda: working_skills_dir) + monkeypatch.setattr( + module, + "list_available_skills", + lambda: ["knowledge_search_assistant", "manual_skill"], + ) + + toolkit = _FakeToolkit() + agent = object.__new__(module.CoPawAgent) + + agent._register_skills(toolkit) + + assert toolkit.registered_skills == [ + str(knowledge_skill_dir), + str(manual_skill_dir), + ] + + +def test_register_skills_skips_builtin_knowledge_skill_when_disabled( + monkeypatch, + tmp_path: Path, +) -> None: + module = importlib.import_module("copaw.agents.react_agent") + + working_skills_dir = tmp_path / "working" + manual_skill_dir = working_skills_dir / "manual_skill" + manual_skill_dir.mkdir(parents=True) + + monkeypatch.setattr(module, "ensure_skills_initialized", lambda: None) + monkeypatch.setattr(module, "get_working_skills_dir", lambda: working_skills_dir) + monkeypatch.setattr(module, "list_available_skills", lambda: ["manual_skill"]) + + toolkit = _FakeToolkit() + agent = object.__new__(module.CoPawAgent) + + agent._register_skills(toolkit) + + assert toolkit.registered_skills == [str(manual_skill_dir)] \ No newline at end of file diff --git a/tests/unit/app/test_graph_tools.py b/tests/unit/app/test_graph_tools.py index 27f7b2a8f..fa1089622 100644 --- a/tests/unit/app/test_graph_tools.py +++ b/tests/unit/app/test_graph_tools.py @@ -212,7 +212,7 @@ async def test_graph_tool_chain_smoke_local_engine( enabled=True, recursive=False, tags=["smoke"], - description="", + summary="", ) knowledge_config.sources.append(source) manager.index_source( diff --git a/tests/unit/app/test_knowledge_module_skills.py b/tests/unit/app/test_knowledge_module_skills.py new file mode 100644 index 000000000..620b2a6e0 --- /dev/null +++ b/tests/unit/app/test_knowledge_module_skills.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +from __future__ 
import annotations + +import importlib +from pathlib import Path + + +def test_sync_knowledge_module_skills_toggles_active_skill( + monkeypatch, + tmp_path: Path, +) -> None: + module = importlib.import_module("copaw.knowledge.module_skills") + skills_manager = importlib.import_module("copaw.agents.skills_manager") + + module_skills_dir = tmp_path / "module_skills" + knowledge_skill_dir = module_skills_dir / "knowledge_search_assistant" + knowledge_skill_dir.mkdir(parents=True) + (knowledge_skill_dir / "SKILL.md").write_text( + "---\nname: knowledge_search_assistant\ndescription: test\n---\n", + encoding="utf-8", + ) + + active_skills_dir = tmp_path / "active_skills" + monkeypatch.setattr(module, "KNOWLEDGE_MODULE_SKILLS_DIR", module_skills_dir) + monkeypatch.setattr(skills_manager, "ACTIVE_SKILLS_DIR", active_skills_dir) + + module.sync_knowledge_module_skills(True) + assert (active_skills_dir / "knowledge_search_assistant" / "SKILL.md").exists() + + module.sync_knowledge_module_skills(False) + assert not (active_skills_dir / "knowledge_search_assistant").exists() \ No newline at end of file From 84090bf40acfbc18862d3867eef74d800a66a3f7 Mon Sep 17 00:00:00 2001 From: Future Meng Date: Tue, 17 Mar 2026 21:48:05 +0800 Subject: [PATCH 59/68] refactor(knowledge): remove cognee compatibility from mvp branch --- src/copaw/config/config.py | 2 +- src/copaw/config/utils.py | 4 +--- src/copaw/knowledge/graph_ops.py | 21 +++++------------ tests/unit/app/test_graph_tools.py | 37 +++++++++++++++++++----------- 4 files changed, 31 insertions(+), 33 deletions(-) diff --git a/src/copaw/config/config.py b/src/copaw/config/config.py index 474e9f81c..4ef227e6f 100644 --- a/src/copaw/config/config.py +++ b/src/copaw/config/config.py @@ -702,7 +702,7 @@ class KnowledgeConfig(BaseModel): version: int = Field(default=1, ge=1) enabled: bool = Field(default=False) - engine: Literal["local_lexical", "cognee"] = Field(default="local_lexical") + engine: Literal["local_lexical"] = 
Field(default="local_lexical") graph_query_enabled: bool = Field(default=False) triplet_search_enabled: bool = Field(default=False) memify_enabled: bool = Field(default=False) diff --git a/src/copaw/config/utils.py b/src/copaw/config/utils.py index d615ae835..4dc638863 100644 --- a/src/copaw/config/utils.py +++ b/src/copaw/config/utils.py @@ -347,9 +347,7 @@ def load_config(config_path: Optional[Path] = None) -> Config: # Backward compat: knowledge.engine object -> literal enum string knowledge = data.get("knowledge") if isinstance(knowledge, dict) and isinstance(knowledge.get("engine"), dict): - legacy_engine = knowledge.get("engine") or {} - provider = str(legacy_engine.get("provider", "")).strip().lower() - knowledge["engine"] = "cognee" if provider == "cognee" else "local_lexical" + knowledge["engine"] = "local_lexical" return Config.model_validate(data) diff --git a/src/copaw/knowledge/graph_ops.py b/src/copaw/knowledge/graph_ops.py index 51ec60020..2d11517d2 100644 --- a/src/copaw/knowledge/graph_ops.py +++ b/src/copaw/knowledge/graph_ops.py @@ -2,8 +2,7 @@ """Bridge layer for graph-oriented knowledge operations. This module provides a lightweight manager used by graph tools. It keeps -current MVP behavior compatible while reserving integration points for -Cognee-backed implementations. +current MVP behavior compatible with the local lexical engine. 
""" from __future__ import annotations @@ -55,7 +54,7 @@ def graph_query( engine = getattr(config, "engine", "local_lexical") warnings: list[str] = [] - if query_mode == "cypher" and engine != "cognee": + if query_mode == "cypher": return GraphOpsResult( records=[], summary="Cypher mode is not available on local_lexical engine.", @@ -63,9 +62,6 @@ def graph_query( warnings=["CYPHER_UNAVAILABLE_ON_LOCAL_ENGINE"], ) - if engine == "cognee": - raise RuntimeError("Cognee graph provider is not wired yet.") - manager = KnowledgeManager(self.working_dir) search_result = manager.search( query=query_text, @@ -116,7 +112,7 @@ def run_memify( """Create a memify job record. The local lexical engine stores a no-op success job so tool contracts - and job observability can be validated before real Cognee wiring. + and job observability can be validated for MVP flows. """ jobs = self._load_memify_jobs() @@ -142,14 +138,9 @@ def run_memify( now = datetime.now(UTC).isoformat() engine = getattr(config, "engine", "local_lexical") - if engine == "cognee": - status = "failed" - error = "Cognee memify provider is not wired yet." 
- warnings = ["COGNEE_PROVIDER_NOT_READY"] - else: - status = "succeeded" - error = None - warnings = ["LOCAL_ENGINE_MEMIFY_NOOP"] + status = "succeeded" + error = None + warnings = ["LOCAL_ENGINE_MEMIFY_NOOP"] job_payload = { "job_id": job_id, diff --git a/tests/unit/app/test_graph_tools.py b/tests/unit/app/test_graph_tools.py index fa1089622..c64195f2d 100644 --- a/tests/unit/app/test_graph_tools.py +++ b/tests/unit/app/test_graph_tools.py @@ -9,14 +9,23 @@ from copaw.knowledge.manager import KnowledgeManager +def _mock_config(knowledge: SimpleNamespace) -> SimpleNamespace: + return SimpleNamespace( + knowledge=knowledge, + agents=SimpleNamespace( + running=SimpleNamespace(knowledge_enabled=True), + ), + ) + + async def test_graph_query_requires_graph_enabled(monkeypatch) -> None: module = importlib.import_module("copaw.agents.tools.graph_query") monkeypatch.setattr( module, "load_config", - lambda: SimpleNamespace( - knowledge=SimpleNamespace(enabled=True, graph_query_enabled=False), + lambda: _mock_config( + SimpleNamespace(enabled=True, graph_query_enabled=False), ), ) @@ -44,7 +53,7 @@ def graph_query(self, **kwargs): monkeypatch.setattr( module, "load_config", - lambda: SimpleNamespace( + lambda: _mock_config( knowledge=SimpleNamespace( enabled=True, graph_query_enabled=True, @@ -66,8 +75,8 @@ async def test_memify_run_requires_memify_enabled(monkeypatch) -> None: monkeypatch.setattr( module, "load_config", - lambda: SimpleNamespace( - knowledge=SimpleNamespace(enabled=True, memify_enabled=False), + lambda: _mock_config( + SimpleNamespace(enabled=True, memify_enabled=False), ), ) result = await module.memify_run() @@ -94,8 +103,8 @@ def run_memify(self, **kwargs): monkeypatch.setattr( module, "load_config", - lambda: SimpleNamespace( - knowledge=SimpleNamespace(enabled=True, memify_enabled=True), + lambda: _mock_config( + SimpleNamespace(enabled=True, memify_enabled=True), ), ) monkeypatch.setattr(module, "GraphOpsManager", _FakeGraphOpsManager) @@ -120,8 
+129,8 @@ def get_memify_status(self, job_id: str): monkeypatch.setattr( module, "load_config", - lambda: SimpleNamespace( - knowledge=SimpleNamespace(enabled=True, memify_enabled=True), + lambda: _mock_config( + SimpleNamespace(enabled=True, memify_enabled=True), ), ) monkeypatch.setattr(module, "GraphOpsManager", _FakeGraphOpsManager) @@ -137,8 +146,8 @@ async def test_triplet_focus_search_requires_enabled(monkeypatch) -> None: monkeypatch.setattr( module, "load_config", - lambda: SimpleNamespace( - knowledge=SimpleNamespace(enabled=True, triplet_search_enabled=False), + lambda: _mock_config( + SimpleNamespace(enabled=True, triplet_search_enabled=False), ), ) @@ -175,8 +184,8 @@ def graph_query(self, **kwargs): monkeypatch.setattr( module, "load_config", - lambda: SimpleNamespace( - knowledge=SimpleNamespace(enabled=True, triplet_search_enabled=True), + lambda: _mock_config( + SimpleNamespace(enabled=True, triplet_search_enabled=True), ), ) monkeypatch.setattr(module, "GraphOpsManager", _FakeGraphOpsManager) @@ -221,7 +230,7 @@ async def test_graph_tool_chain_smoke_local_engine( SimpleNamespace(knowledge_chunk_size=knowledge_config.index.chunk_size), ) - load_config_stub = lambda: SimpleNamespace(knowledge=knowledge_config) + load_config_stub = lambda: _mock_config(knowledge_config) for module in [ graph_query_module, memify_run_module, From c1434f80b61b92e9ca069cdc47cb626e8120efb6 Mon Sep 17 00:00:00 2001 From: Future Meng Date: Tue, 17 Mar 2026 21:59:21 +0800 Subject: [PATCH 60/68] fix(knowledge): require runtime and config enable switches --- src/copaw/app/routers/knowledge.py | 4 ++- tests/unit/app/routers/test_knowledge.py | 31 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/copaw/app/routers/knowledge.py b/src/copaw/app/routers/knowledge.py index 246073caf..0fdb7f728 100644 --- a/src/copaw/app/routers/knowledge.py +++ b/src/copaw/app/routers/knowledge.py @@ -31,7 +31,9 @@ def _knowledge_runtime_enabled(config) -> 
bool: def _knowledge_effective_enabled(config) -> bool: - return _knowledge_runtime_enabled(config) + return _knowledge_runtime_enabled(config) and bool( + getattr(getattr(config, "knowledge", None), "enabled", False) + ) def _ensure_knowledge_enabled(config) -> None: diff --git a/tests/unit/app/routers/test_knowledge.py b/tests/unit/app/routers/test_knowledge.py index 294851090..6c199ba1d 100644 --- a/tests/unit/app/routers/test_knowledge.py +++ b/tests/unit/app/routers/test_knowledge.py @@ -17,6 +17,8 @@ @pytest.fixture def knowledge_api_client(tmp_path: Path, monkeypatch) -> TestClient: config = Config() + config.knowledge.enabled = True + config.agents.running.knowledge_enabled = True state = {"config": config} def fake_load_config(): @@ -188,6 +190,35 @@ def test_clear_knowledge_requires_confirmation(knowledge_api_client: TestClient) assert response.json()["detail"] == "KNOWLEDGE_CLEAR_CONFIRM_REQUIRED" +def test_list_sources_requires_effective_knowledge_enabled( + knowledge_api_client: TestClient, +): + config_payload = Config().knowledge.model_dump(mode="json") + config_payload["enabled"] = False + saved = knowledge_api_client.put("/knowledge/config", json=config_payload) + assert saved.status_code == 200 + + response = knowledge_api_client.get("/knowledge/sources") + assert response.status_code == 200 + assert response.json()["enabled"] is False + + blocked = knowledge_api_client.put( + "/knowledge/sources", + json={ + "id": "blocked-source", + "name": "blocked-source", + "type": "text", + "content": "should be blocked", + "enabled": True, + "recursive": False, + "tags": [], + "summary": "", + }, + ) + assert blocked.status_code == 400 + assert blocked.json()["detail"] == "KNOWLEDGE_DISABLED" + + def test_clear_knowledge_removes_sources_and_indexes( knowledge_api_client: TestClient, tmp_path: Path, From 088af4522cbbba9f2d75ae1008eb049043481177 Mon Sep 17 00:00:00 2001 From: zhijianma Date: Tue, 17 Mar 2026 22:21:52 +0800 Subject: [PATCH 61/68] feat(chat): 
add chat status management and reconnect functionality (#1672) --- console/src/api/modules/chat.ts | 7 + console/src/api/types/chat.ts | 4 + console/src/layouts/MainLayout/index.tsx | 63 +++----- console/src/pages/Chat/index.tsx | 167 +++++++++++++++---- console/src/pages/Chat/sessionApi/index.ts | 44 +++++- src/copaw/app/routers/console.py | 129 ++++++++++----- src/copaw/app/runner/api.py | 32 ++-- src/copaw/app/runner/manager.py | 1 + src/copaw/app/runner/models.py | 9 ++ src/copaw/app/runner/task_tracker.py | 176 +++++++++++++++++++++ src/copaw/app/workspace.py | 7 + 11 files changed, 520 insertions(+), 119 deletions(-) create mode 100644 src/copaw/app/runner/task_tracker.py diff --git a/console/src/api/modules/chat.ts b/console/src/api/modules/chat.ts index 422f12fca..7d354a29a 100644 --- a/console/src/api/modules/chat.ts +++ b/console/src/api/modules/chat.ts @@ -43,6 +43,13 @@ export const chatApi = { body: JSON.stringify(chatIds), }, ), + + /** Stop a running console chat (only stop when user clicks stop). 
chat_id = ChatSpec.id */ + stopConsoleChat: (chatId: string) => + request<{ stopped: boolean }>( + `/console/chat/stop?chat_id=${encodeURIComponent(chatId)}`, + { method: "POST" }, + ), }; export const sessionApi = { diff --git a/console/src/api/types/chat.ts b/console/src/api/types/chat.ts index 5846b0686..9b7610269 100644 --- a/console/src/api/types/chat.ts +++ b/console/src/api/types/chat.ts @@ -1,3 +1,5 @@ +export type ChatStatus = "idle" | "running"; + export interface ChatSpec { id: string; // Chat UUID identifier session_id: string; // Session identifier (channel:user_id format) @@ -6,6 +8,7 @@ export interface ChatSpec { created_at: string | null; // Chat creation timestamp (ISO 8601) updated_at: string | null; // Chat last update timestamp (ISO 8601) meta?: Record; // Additional metadata + status?: ChatStatus; // Conversation status: idle or running } export interface Message { @@ -16,6 +19,7 @@ export interface Message { export interface ChatHistory { messages: Message[]; + status?: ChatStatus; // Conversation status: idle or running } export interface ChatDeleteResponse { diff --git a/console/src/layouts/MainLayout/index.tsx b/console/src/layouts/MainLayout/index.tsx index 4198cf1ed..01e8188ae 100644 --- a/console/src/layouts/MainLayout/index.tsx +++ b/console/src/layouts/MainLayout/index.tsx @@ -1,6 +1,5 @@ import { Layout } from "antd"; -import { useEffect } from "react"; -import { Routes, Route, useLocation, useNavigate } from "react-router-dom"; +import { Routes, Route, useLocation, Navigate } from "react-router-dom"; import Sidebar from "../Sidebar"; import Header from "../Header"; import ConsoleCronBubble from "../../components/ConsoleCronBubble"; @@ -45,16 +44,8 @@ const pathToKey: Record = { export default function MainLayout() { const location = useLocation(); - const navigate = useNavigate(); const currentPath = location.pathname; const selectedKey = pathToKey[currentPath] || "chat"; - const isChatPage = currentPath === "/" || 
currentPath.startsWith("/chat"); - - useEffect(() => { - if (currentPath === "/") { - navigate("/chat", { replace: true }); - } - }, [currentPath, navigate]); return ( @@ -64,36 +55,28 @@ export default function MainLayout() {
-
- -
- {!isChatPage && ( - - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } - /> - - )} + + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } + /> +
diff --git a/console/src/pages/Chat/index.tsx b/console/src/pages/Chat/index.tsx index 8ecdd5c3d..b0d274fa7 100644 --- a/console/src/pages/Chat/index.tsx +++ b/console/src/pages/Chat/index.tsx @@ -1,6 +1,7 @@ import { AgentScopeRuntimeWebUI, IAgentScopeRuntimeWebUIOptions, + IAgentScopeRuntimeWebUIRef, } from "@agentscope-ai/chat"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { Button, Modal, Result, message } from "antd"; @@ -13,6 +14,8 @@ import defaultConfig, { getDefaultConfig } from "./OptionsPanel/defaultConfig"; import Weather from "./Weather"; import { getApiToken, getApiUrl } from "../../api/config"; import { providerApi } from "../../api/modules/provider"; +import { chatApi } from "../../api/modules/chat"; +import api from "../../api"; import ModelSelector from "./ModelSelector"; import { useTheme } from "../../contexts/ThemeContext"; import { useAgentStore } from "../../stores/agentStore"; @@ -122,6 +125,10 @@ export default function ChatPage() { const [showModelPrompt, setShowModelPrompt] = useState(false); const { selectedAgent } = useAgentStore(); const [refreshKey, setRefreshKey] = useState(0); + const [chatStatus, setChatStatus] = useState<"idle" | "running">("idle"); + const [, setReconnectStreaming] = useState(false); + const reconnectTriggeredForRef = useRef(null); + const prevChatIdRef = useRef(undefined); const isComposingRef = useRef(false); const isChatActiveRef = useRef(false); @@ -131,9 +138,15 @@ export default function ChatPage() { const lastSessionIdRef = useRef(null); const chatIdRef = useRef(chatId); const navigateRef = useRef(navigate); + const chatRef = useRef(null); chatIdRef.current = chatId; navigateRef.current = navigate; + useEffect(() => { + sessionApi.setChatRef(chatRef); + return () => sessionApi.setChatRef(null); + }, []); + useEffect(() => { const handleCompositionStart = () => { if (!isChatActiveRef.current) return; @@ -201,6 +214,34 @@ export default function ChatPage() { }; }, []); + 
// Fetch chat status when viewing a chat (for running indicator and reconnect) + useEffect(() => { + if (!chatId || chatId === "undefined" || chatId === "null") { + setChatStatus("idle"); + return; + } + const realId = sessionApi.getRealIdForSession(chatId) ?? chatId; + api.getChat(realId).then( + (res) => setChatStatus((res.status as "idle" | "running") ?? "idle"), + () => setChatStatus("idle"), + ); + }, [chatId]); + + // Trigger reconnect when session status becomes "running" so the library + // consumes the SSE stream. Done here (not in sessionApi.getSession) so we + // run after React has updated and the chat input ref is ready, avoiding + // a fixed timeout and race conditions. + useEffect(() => { + if (prevChatIdRef.current !== chatId) { + prevChatIdRef.current = chatId; + reconnectTriggeredForRef.current = null; + } + if (!chatId || chatStatus !== "running") return; + if (reconnectTriggeredForRef.current === chatId) return; + reconnectTriggeredForRef.current = chatId; + sessionApi.triggerReconnectSubmit(); + }, [chatId, chatStatus]); + // Refresh chat when selectedAgent changes const prevSelectedAgentRef = useRef(selectedAgent); useEffect(() => { @@ -285,10 +326,76 @@ export default function ChatPage() { const customFetch = useCallback( async (data: { - input: any[]; + input?: any[]; biz_params?: any; signal?: AbortSignal; + reconnect?: boolean; + session_id?: string; + user_id?: string; + channel?: string; }): Promise => { + const headers: Record = { + "Content-Type": "application/json", + }; + const token = getApiToken(); + if (token) headers.Authorization = `Bearer ${token}`; + try { + const agentStorage = localStorage.getItem("copaw-agent-storage"); + if (agentStorage) { + const parsed = JSON.parse(agentStorage); + const selectedAgent = parsed?.state?.selectedAgent; + if (selectedAgent) { + headers["X-Agent-Id"] = selectedAgent; + } + } + } catch (error) { + console.warn("Failed to get selected agent from storage:", error); + } + + const shouldReconnect 
= + data.reconnect || data.biz_params?.reconnect === true; + const reconnectSessionId = + data.session_id ?? window.currentSessionId ?? ""; + if (shouldReconnect && reconnectSessionId) { + const res = await fetch(getApiUrl("/console/chat"), { + method: "POST", + headers, + body: JSON.stringify({ + reconnect: true, + session_id: reconnectSessionId, + user_id: data.user_id ?? window.currentUserId ?? "default", + channel: data.channel ?? window.currentChannel ?? "console", + }), + }); + if (!res.ok || !res.body) return res; + const onStreamEnd = () => { + setChatStatus("idle"); + setReconnectStreaming(false); + }; + const stream = res.body; + const transformed = new ReadableStream({ + start(controller) { + const reader = stream.getReader(); + function pump() { + reader.read().then(({ done, value }) => { + if (done) { + controller.close(); + onStreamEnd(); + return; + } + controller.enqueue(value); + return pump(); + }); + } + pump(); + }, + }); + return new Response(transformed, { + headers: res.headers, + status: res.status, + }); + } + try { const activeModels = await providerApi.getActiveModels(); if ( @@ -303,46 +410,27 @@ export default function ChatPage() { return buildModelError(); } - const { input, biz_params } = data; + const { input = [], biz_params } = data; const session = input[input.length - 1]?.session || {}; + const sessionId = window.currentSessionId || session?.session_id || ""; const requestBody = { input: input.slice(-1), - session_id: window.currentSessionId || session?.session_id || "", + session_id: sessionId, user_id: window.currentUserId || session?.user_id || "default", channel: window.currentChannel || session?.channel || "console", stream: true, ...biz_params, }; - const headers: Record = { - "Content-Type": "application/json", - }; - const token = getApiToken(); - if (token) headers.Authorization = `Bearer ${token}`; - - // Add selected agent ID for multi-agent support - try { - const agentStorage = 
localStorage.getItem("copaw-agent-storage"); - if (agentStorage) { - const parsed = JSON.parse(agentStorage); - const selectedAgent = parsed?.state?.selectedAgent; - if (selectedAgent) { - headers["X-Agent-Id"] = selectedAgent; - } - } - } catch (error) { - console.warn("Failed to get selected agent from storage:", error); - } - - return fetch(defaultConfig?.api?.baseURL || getApiUrl("/console/chat"), { + return fetch(getApiUrl("/console/chat"), { method: "POST", headers, body: JSON.stringify(requestBody), signal: data.signal, }); }, - [], + [setChatStatus, setReconnectStreaming], ); const options = useMemo(() => { @@ -378,7 +466,17 @@ export default function ChatPage() { ...defaultConfig.api, fetch: customFetch, cancel(data: { session_id: string }) { - console.log(data); + const chatIdForStop = data?.session_id + ? sessionApi.getRealIdForSession(data.session_id) ?? data.session_id + : ""; + if (chatIdForStop) { + chatApi.stopConsoleChat(chatIdForStop).then( + () => setChatStatus("idle"), + (err) => { + console.error("stopConsoleChat failed:", err); + }, + ); + } }, }, actions: { @@ -403,8 +501,21 @@ export default function ChatPage() { }, [wrappedSessionApi, customFetch, copyResponse, t, isDark]); return ( -
- +
+
+ +
; /** Real backend UUID, used when id is overridden with a local timestamp. */ realId?: string; + /** Conversation status: idle or running (for reconnect). */ + status?: "idle" | "running"; } // --------------------------------------------------------------------------- @@ -190,6 +195,7 @@ const chatSpecToSession = (chat: ChatSpec): ExtendedSession => channel: chat.channel, messages: [], meta: chat.meta || {}, + status: chat.status ?? "idle", }) as ExtendedSession; /** Returns true when id is a pure numeric local timestamp (not a backend UUID). */ @@ -257,6 +263,35 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { */ onSessionRemoved: ((removedId: string) => void) | null = null; + /** + * Ref to the chat component so we can trigger submit with reconnect flag + * (library will call customFetch with biz_params.reconnect and consume the SSE stream). + */ + private chatRef: RefObject | null = null; + + setChatRef(ref: RefObject | null): void { + this.chatRef = ref; + } + + /** + * Programmatically trigger the library's submit with biz_params.reconnect so + * customFetch does POST /console/chat with reconnect:true and the library + * consumes the SSE stream (replay + live tail). 
+ */ + triggerReconnectSubmit(): void { + const ref = this.chatRef?.current; + if (!ref?.input?.submit) { + console.warn("triggerReconnectSubmit: chatRef not available"); + return; + } + ref.input.submit({ + query: "", + biz_params: { + reconnect: true, + } as IAgentScopeRuntimeWebUIInputData["biz_params"], + }); + } + private createEmptySession(sessionId: string): ExtendedSession { window.currentSessionId = sessionId; window.currentUserId = DEFAULT_USER_ID; @@ -342,7 +377,11 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { this.sessionRequests.set(sessionId, requestPromise); try { - return await requestPromise; + const result = await requestPromise; + // Reconnect for running sessions is triggered by ChatPage when session + // status becomes "running" (useEffect on chatStatus), avoiding a fixed + // timeout and race conditions with the chat input ref. + return result; } finally { this.sessionRequests.delete(sessionId); } @@ -369,6 +408,7 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { messages: convertMessages(chatHistory.messages || []), meta: fromList.meta || {}, realId: fromList.realId, + status: chatHistory.status ?? "idle", }; this.updateWindowVariables(session); return session; @@ -404,6 +444,7 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { messages: convertMessages(chatHistory.messages || []), meta: refreshed.meta || {}, realId: refreshed.realId, + status: chatHistory.status ?? "idle", }; this.updateWindowVariables(session); return session; @@ -434,6 +475,7 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { channel: fromList?.channel || DEFAULT_CHANNEL, messages: convertMessages(chatHistory.messages || []), meta: fromList?.meta || {}, + status: chatHistory.status ?? 
"idle", }; this.updateWindowVariables(session); diff --git a/src/copaw/app/routers/console.py b/src/copaw/app/routers/console.py index 8adc04d2e..fc3c02748 100644 --- a/src/copaw/app/routers/console.py +++ b/src/copaw/app/routers/console.py @@ -10,35 +10,18 @@ from starlette.responses import StreamingResponse from agentscope_runtime.engine.schemas.agent_schemas import AgentRequest +from ..agent_context import get_agent_for_request logger = logging.getLogger(__name__) router = APIRouter(prefix="/console", tags=["console"]) -@router.post( - "/chat", - status_code=200, - summary="Chat with console (streaming response)", - description="Agent API Request Format. " - "See https://runtime.agentscope.io/en/protocol.html for " - "more details.", -) -async def post_console_chat( - request_data: Union[AgentRequest, dict], - request: Request, -) -> StreamingResponse: - """Accept a user message and stream the agent response. +def _extract_session_and_payload(request_data: Union[AgentRequest, dict]): + """Extract run_key (ChatSpec.id), session_id, and native payload. - Accepts AgentRequest or dict, builds native payload, and streams events - via channel.stream_one(). + run_key must be ChatSpec.id (chat_id) so it matches list_chats/get_chat. 
""" - - from ..agent_context import get_agent_for_request - - workspace = await get_agent_for_request(request) - - # Extract channel info from request if isinstance(request_data, AgentRequest): channel_id = request_data.channel or "console" sender_id = request_data.user_id or "default" @@ -47,30 +30,17 @@ async def post_console_chat( list(request_data.input[0].content) if request_data.input else [] ) else: - # Dict format - extract from request body channel_id = request_data.get("channel", "console") sender_id = request_data.get("user_id", "default") session_id = request_data.get("session_id", "default") input_data = request_data.get("input", []) - - # Extract content from input array content_parts = [] - if input_data and len(input_data) > 0: - last_msg = input_data[-1] - if hasattr(last_msg, "content"): - content_parts = list(last_msg.content or []) - elif isinstance(last_msg, dict) and "content" in last_msg: - content_parts = last_msg["content"] or [] - - # - console_channel = await workspace.channel_manager.get_channel("console") - if console_channel is None: - raise HTTPException( - status_code=503, - detail="Channel Console not found", - ) + for content_part in input_data: + if hasattr(content_part, "content"): + content_parts.extend(list(content_part.content or [])) + elif isinstance(content_part, dict) and "content" in content_part: + content_parts.extend(content_part["content"] or []) - # Build native payload native_payload = { "channel_id": channel_id, "sender_id": sender_id, @@ -80,10 +50,74 @@ async def post_console_chat( "user_id": sender_id, }, } + return native_payload + + +@router.post( + "/chat", + status_code=200, + summary="Chat with console (streaming response)", + description="Agent API Request Format. See runtime.agentscope.io. " + "Use body.reconnect=true to attach to a running stream.", +) +async def post_console_chat( + request_data: Union[AgentRequest, dict], + request: Request, +) -> StreamingResponse: + """Stream agent response. 
Run continues in background after disconnect. + Stop via POST /console/chat/stop. Reconnect with body.reconnect=true. + """ + workspace = await get_agent_for_request(request) + console_channel = await workspace.channel_manager.get_channel("console") + if console_channel is None: + raise HTTPException( + status_code=503, + detail="Channel Console not found", + ) + try: + native_payload = _extract_session_and_payload(request_data) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + session_id = console_channel.resolve_session_id( + sender_id=native_payload["sender_id"], + channel_meta=native_payload["meta"], + ) + name = "New Chat" + if len(native_payload["content_parts"]) > 0: + content = native_payload["content_parts"][0] + if content: + name = content.text[:10] + else: + name = "Media Message" + chat = await workspace.chat_manager.get_or_create_chat( + session_id, + native_payload["sender_id"], + native_payload["channel_id"], + name=name, + ) + tracker = workspace.task_tracker + + is_reconnect = False + if isinstance(request_data, dict): + is_reconnect = request_data.get("reconnect") is True + + if is_reconnect: + queue = await tracker.attach(chat.id) + if queue is None: + raise HTTPException( + status_code=404, + detail="No running chat for this session", + ) + else: + queue, _ = await tracker.attach_or_start( + chat.id, + native_payload, + console_channel.stream_one, + ) async def event_generator() -> AsyncGenerator[str, None]: try: - async for event_data in console_channel.stream_one(native_payload): + async for event_data in tracker.stream_from_queue(queue): yield event_data except Exception as e: logger.exception("Console chat stream error") @@ -99,6 +133,21 @@ async def event_generator() -> AsyncGenerator[str, None]: ) +@router.post( + "/chat/stop", + status_code=200, + summary="Stop running console chat", +) +async def post_console_chat_stop( + request: Request, + chat_id: str = Query(..., description="Chat id 
(ChatSpec.id) to stop"), +) -> dict: + """Stop the running chat. Only stops when called.""" + workspace = await get_agent_for_request(request) + stopped = await workspace.task_tracker.request_stop(chat_id) + return {"stopped": stopped} + + @router.get("/push-messages") async def get_push_messages( session_id: str | None = Query(None, description="Optional session id"), diff --git a/src/copaw/app/runner/api.py b/src/copaw/app/runner/api.py index 753cd1750..25761a9d5 100644 --- a/src/copaw/app/runner/api.py +++ b/src/copaw/app/runner/api.py @@ -18,6 +18,13 @@ router = APIRouter(prefix="/chats", tags=["chats"]) +async def get_workspace(request: Request): + """Get the workspace for the active agent.""" + from ..agent_context import get_agent_for_request + + return await get_agent_for_request(request) + + async def get_chat_manager( request: Request, ) -> ChatManager: @@ -32,9 +39,7 @@ async def get_chat_manager( Raises: HTTPException: If manager is not initialized """ - from ..agent_context import get_agent_for_request - - workspace = await get_agent_for_request(request) + workspace = await get_workspace(request) return workspace.chat_manager @@ -52,9 +57,7 @@ async def get_session( Raises: HTTPException: If session is not initialized """ - from ..agent_context import get_agent_for_request - - workspace = await get_agent_for_request(request) + workspace = await get_workspace(request) return workspace.runner.session @@ -63,6 +66,7 @@ async def list_chats( user_id: Optional[str] = Query(None, description="Filter by user ID"), channel: Optional[str] = Query(None, description="Filter by channel"), mgr: ChatManager = Depends(get_chat_manager), + workspace=Depends(get_workspace), ): """List all chats with optional filters. 
@@ -71,7 +75,13 @@ async def list_chats( channel: Optional channel name to filter chats mgr: Chat manager dependency """ - return await mgr.list_chats(user_id=user_id, channel=channel) + chats = await mgr.list_chats(user_id=user_id, channel=channel) + tracker = workspace.task_tracker + result = [] + for spec in chats: + status = await tracker.get_status(spec.id) + result.append(spec.model_copy(update={"status": status})) + return result @router.post("", response_model=ChatSpec) @@ -125,6 +135,7 @@ async def get_chat( chat_id: str, mgr: ChatManager = Depends(get_chat_manager), session: SafeJSONSession = Depends(get_session), + workspace=Depends(get_workspace), ): """Get detailed information about a specific chat by UUID. @@ -134,7 +145,7 @@ async def get_chat( session: SafeJSONSession dependency Returns: - ChatHistory with messages + ChatHistory with messages and status (idle/running) Raises: HTTPException: If chat not found (404) @@ -150,15 +161,16 @@ async def get_chat( chat_spec.session_id, chat_spec.user_id, ) + status = await workspace.task_tracker.get_status(chat_id) if not state: - return ChatHistory(messages=[]) + return ChatHistory(messages=[], status=status) memories = state.get("agent", {}).get("memory", []) memory = InMemoryMemory() memory.load_state_dict(memories) memories = await memory.get_memory() messages = agentscope_msg_to_message(memories) - return ChatHistory(messages=messages) + return ChatHistory(messages=messages, status=status) @router.put("/{chat_id}", response_model=ChatSpec) diff --git a/src/copaw/app/runner/manager.py b/src/copaw/app/runner/manager.py index 5415ab5dd..f7a51c9ca 100644 --- a/src/copaw/app/runner/manager.py +++ b/src/copaw/app/runner/manager.py @@ -126,6 +126,7 @@ async def get_or_create_chat( channel=channel, name=name, ) + logger.debug(f"get_or_create_chat: created spec={spec.id}") # Call internal create without lock (already locked) await self._repo.upsert_chat(spec) logger.debug( diff --git 
a/src/copaw/app/runner/models.py b/src/copaw/app/runner/models.py index ac39e3a0a..d00105fd7 100644 --- a/src/copaw/app/runner/models.py +++ b/src/copaw/app/runner/models.py @@ -41,12 +41,21 @@ class ChatSpec(BaseModel): default_factory=dict, description="Additional metadata", ) + status: str = Field( + default="idle", + description="Conversation status: idle or running", + exclude=True, + ) class ChatHistory(BaseModel): """Complete chat view with spec and state.""" messages: list[Message] = Field(default_factory=list) + status: str = Field( + default="idle", + description="Conversation status: idle or running", + ) class ChatsFile(BaseModel): diff --git a/src/copaw/app/runner/task_tracker.py b/src/copaw/app/runner/task_tracker.py new file mode 100644 index 000000000..140bcd2f5 --- /dev/null +++ b/src/copaw/app/runner/task_tracker.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- +"""Task tracker for background runs: streaming, reconnect, multi-subscriber. + +run_key is ChatSpec.id (chat_id). Per run: task, queues, event buffer. +Reconnects get buffer replay + new events. Cleanup when task completes. +""" +from __future__ import annotations + +import asyncio +import json +import logging +import weakref +from dataclasses import dataclass, field +from typing import Any, AsyncGenerator, Callable, Coroutine + +logger = logging.getLogger(__name__) + +_QUEUE_MAX_SIZE = 4096 +_SENTINEL = None + + +@dataclass +class _RunState: + """Per-run state (task, queues, buffer), guarded by tracker lock.""" + + task: asyncio.Future + queues: list[asyncio.Queue] = field(default_factory=list) + buffer: list[str] = field(default_factory=list) + + +class TaskTracker: + """Per-workspace tracker: run_key -> RunState. + + All mutations to _runs under _lock. Producer broadcasts under lock. + Dead queues pruned when full. 
+ """ + + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._runs: dict[str, _RunState] = {} + + @property + def lock(self) -> asyncio.Lock: + return self._lock + + async def get_status(self, run_key: str) -> str: + """Return ``'idle'`` or ``'running'``.""" + async with self._lock: + state = self._runs.get(run_key) + if state is None or state.task.done(): + return "idle" + return "running" + + async def attach(self, run_key: str) -> asyncio.Queue | None: + """Attach to an existing run. + + Returns a new queue pre-filled with the event buffer, or ``None`` + if no run is active for *run_key*. + """ + async with self._lock: + state = self._runs.get(run_key) + if state is None or state.task.done(): + return None + q: asyncio.Queue = asyncio.Queue(maxsize=_QUEUE_MAX_SIZE) + for sse in state.buffer: + q.put_nowait(sse) + state.queues.append(q) + return q + + async def request_stop(self, run_key: str) -> bool: + """Cancel the run. Returns ``True`` if it was running.""" + async with self._lock: + state = self._runs.get(run_key) + if state is None or state.task.done(): + return False + state.task.cancel() + return True + + async def attach_or_start( + self, + run_key: str, + payload: Any, + stream_fn: Callable[..., Coroutine], + ) -> tuple[asyncio.Queue, bool]: + """Attach to an existing run or start a new one. + + Returns ``(queue, is_new_run)``. 
+ """ + async with self._lock: + state = self._runs.get(run_key) + if state is not None and not state.task.done(): + q: asyncio.Queue = asyncio.Queue(maxsize=_QUEUE_MAX_SIZE) + for sse in state.buffer: + q.put_nowait(sse) + state.queues.append(q) + return q, False + + my_queue: asyncio.Queue = asyncio.Queue(maxsize=_QUEUE_MAX_SIZE) + run = _RunState( + task=asyncio.Future(), # placeholder, replaced below + queues=[my_queue], + buffer=[], + ) + self._runs[run_key] = run + + tracker_ref = weakref.ref(self) + + async def _producer() -> None: + try: + async for sse in stream_fn(payload): + tracker = tracker_ref() + if tracker is None: + return + async with tracker.lock: + run.buffer.append(sse) + alive: list[asyncio.Queue] = [] + for q in run.queues: + try: + q.put_nowait(sse) + alive.append(q) + except asyncio.QueueFull: + logger.warning( + "dropping subscriber queue (full) " + "run_key=%s", + run_key, + ) + # Prune dead queues (full = client not reading) + run.queues = alive + except asyncio.CancelledError: + logger.debug("run cancelled run_key=%s", run_key) + except Exception: + logger.exception("run error run_key=%s", run_key) + err_sse = ( + "data: " + f"{json.dumps({'error': 'internal server error'})}\n\n" + ) + tracker = tracker_ref() + if tracker is not None: + async with tracker.lock: + run.buffer.append(err_sse) + for q in run.queues: + try: + q.put_nowait(err_sse) + except asyncio.QueueFull: + pass + finally: + tracker = tracker_ref() + if tracker is not None: + async with tracker.lock: + for q in run.queues: + try: + q.put_nowait(_SENTINEL) + except asyncio.QueueFull: + pass + # pylint: disable=protected-access + tracker._runs.pop( + run_key, + None, + ) + + run.task = asyncio.create_task(_producer()) + return my_queue, True + + @staticmethod + async def stream_from_queue( + queue: asyncio.Queue, + ) -> AsyncGenerator[str, None]: + """Yield SSE strings from *queue* until the sentinel ``None``.""" + while True: + try: + event = await queue.get() + if event 
is _SENTINEL: + break + yield event + except asyncio.CancelledError: + break diff --git a/src/copaw/app/workspace.py b/src/copaw/app/workspace.py index 09e6d6aaf..832763f24 100644 --- a/src/copaw/app/workspace.py +++ b/src/copaw/app/workspace.py @@ -16,6 +16,7 @@ from typing import Optional, TYPE_CHECKING from .runner import AgentRunner +from .runner.task_tracker import TaskTracker from .channels.utils import make_process_from_runner from .mcp import MCPClientManager from .crons.manager import CronManager @@ -63,6 +64,7 @@ def __init__(self, agent_id: str, workspace_dir: str): self._config = None self._config_watcher = None self._mcp_config_watcher = None + self._task_tracker = TaskTracker() self._started = False logger.debug( @@ -99,6 +101,11 @@ def chat_manager(self): """Get chat manager instance.""" return self._chat_manager + @property + def task_tracker(self) -> TaskTracker: + """Get task tracker for background chat and reconnect.""" + return self._task_tracker + @property def config(self): """Get agent configuration.""" From 397743bb061f56a750c03bffe786d4cbf6129350 Mon Sep 17 00:00:00 2001 From: Weirui Kuang <39145382+rayrayraykk@users.noreply.github.com> Date: Tue, 17 Mar 2026 22:45:49 +0800 Subject: [PATCH 62/68] feat: fix wecom version (#1681) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 385a8ea33..fd016cd52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "pywebview>=4.0", "aiofiles>=24.1.0", "paho-mqtt>=2.0.0", - "wecom-aibot-sdk @ https://agentscope.oss-cn-zhangjiakou.aliyuncs.com/pre_whl/wecom_aibot_sdk-1.0.0-py3-none-any.whl", + "wecom-aibot-sdk>=1.0.0", "matrix-nio>=0.24.0", "shortuuid>=1.0.0", "google-genai>=1.67.0", From 94d8d03ecbcb06c3352cd053f753d4e0f789e4d7 Mon Sep 17 00:00:00 2001 From: Weirui Kuang <39145382+rayrayraykk@users.noreply.github.com> Date: Wed, 18 Mar 2026 00:23:41 +0800 Subject: [PATCH 63/68] fix: upgrade 
agentscope-runtime to 1.1.1 and bump version to 0.1.0b2 (#1684) --- pyproject.toml | 2 +- src/copaw/__version__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fd016cd52..44418928e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" requires-python = ">=3.10,<3.14" dependencies = [ "agentscope==1.0.17", - "agentscope-runtime==1.1.0", + "agentscope-runtime==1.1.1", "httpx>=0.27.0", "packaging>=24.0", "discord-py>=2.3", diff --git a/src/copaw/__version__.py b/src/copaw/__version__.py index c4ecba7a3..ef3ebc569 100644 --- a/src/copaw/__version__.py +++ b/src/copaw/__version__.py @@ -1,2 +1,2 @@ # -*- coding: utf-8 -*- -__version__ = "0.1.0b1" +__version__ = "0.1.0b2" From 1635acdda4d1174b1fa7486ce52f53ea1b8f1820 Mon Sep 17 00:00:00 2001 From: Yuexiang XIE Date: Wed, 18 Mar 2026 10:28:59 +0800 Subject: [PATCH 64/68] chore: bumping version to 0.1.0b3 (#1688) --- src/copaw/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/copaw/__version__.py b/src/copaw/__version__.py index ef3ebc569..30316b628 100644 --- a/src/copaw/__version__.py +++ b/src/copaw/__version__.py @@ -1,2 +1,2 @@ # -*- coding: utf-8 -*- -__version__ = "0.1.0b2" +__version__ = "0.1.0b3" From 3dc706d46c347672d2f23ec8876916ec330f0dd3 Mon Sep 17 00:00:00 2001 From: Yuexiang XIE Date: Wed, 18 Mar 2026 10:34:42 +0800 Subject: [PATCH 65/68] feat(console): update multi language in console (#1686) --- console/src/locales/en.json | 9 ++++++++- console/src/locales/ja.json | 9 ++++++++- console/src/locales/ru.json | 9 ++++++++- console/src/locales/zh.json | 9 ++++++++- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/console/src/locales/en.json b/console/src/locales/en.json index b00bdb3fb..5448b3b83 100644 --- a/console/src/locales/en.json +++ b/console/src/locales/en.json @@ -739,7 +739,14 @@ "allParams": "All parameters", "descriptions": { "TOOL_CMD_DANGEROUS_RM": 
"Detects 'rm' command that may cause data loss", - "TOOL_CMD_DANGEROUS_MV": "Detects 'mv' command that may move or overwrite files unexpectedly" + "TOOL_CMD_DANGEROUS_MV": "Detects 'mv' command that may move or overwrite files unexpectedly", + "TOOL_CMD_FS_DESTRUCTION": "Detects low-level disk formatting or wiping commands", + "TOOL_CMD_DOS_FORK_BOMB": "Detects classic Bash fork bombs and mass process termination", + "TOOL_CMD_PIPE_TO_SHELL": "Detects 'curl | bash' patterns used to download and immediately execute remote payloads", + "TOOL_CMD_REVERSE_SHELL": "Detects attempts to establish reverse shells or unauthorized network tunnels", + "TOOL_CMD_SYSTEM_TAMPERING": "Detects access to cron jobs, SSH keys, or sudo permissions (including reads and modifications)", + "TOOL_CMD_UNSAFE_PERMISSIONS": "Detects global permission downgrades (chmod 777) or setting immutable flags", + "TOOL_CMD_OBFUSCATED_EXEC": "Detects execution of base64 encoded strings passed directly to a shell interpreter" } }, "skillScanner": { diff --git a/console/src/locales/ja.json b/console/src/locales/ja.json index aae806e9b..9913c8044 100644 --- a/console/src/locales/ja.json +++ b/console/src/locales/ja.json @@ -691,7 +691,14 @@ "allParams": "すべてのパラメータ", "descriptions": { "TOOL_CMD_DANGEROUS_RM": "データ損失を引き起こす可能性のある rm コマンドを検出", - "TOOL_CMD_DANGEROUS_MV": "ファイルを意図せず移動・上書きする可能性のある mv コマンドを検出" + "TOOL_CMD_DANGEROUS_MV": "ファイルを意図せず移動・上書きする可能性のある mv コマンドを検出", + "TOOL_CMD_FS_DESTRUCTION": "低レベルのディスクフォーマットまたはワイプコマンドを検出", + "TOOL_CMD_DOS_FORK_BOMB": "Bash フォーク爆弾および大量プロセス終了を検出", + "TOOL_CMD_PIPE_TO_SHELL": "リモートペイロードをダウンロードして即座に実行する 'curl | bash' パターンを検出", + "TOOL_CMD_REVERSE_SHELL": "リバースシェルまたは不正なネットワークトンネルの確立を検出", + "TOOL_CMD_SYSTEM_TAMPERING": "cron ジョブ、SSH キー、sudo 権限へのアクセス(読み取り・変更を含む)を検出", + "TOOL_CMD_UNSAFE_PERMISSIONS": "グローバル権限の引き下げ(chmod 777)や不変フラグの設定を検出", + "TOOL_CMD_OBFUSCATED_EXEC": "base64 エンコード文字列をシェルインタプリタに直接渡して実行する操作を検出" } }, "skillScanner": { diff --git a/console/src/locales/ru.json 
b/console/src/locales/ru.json index ae5a89a01..e7750c1a2 100644 --- a/console/src/locales/ru.json +++ b/console/src/locales/ru.json @@ -696,7 +696,14 @@ "allParams": "Все параметры", "descriptions": { "TOOL_CMD_DANGEROUS_RM": "Обнаруживает команду rm, которая может привести к потере данных", - "TOOL_CMD_DANGEROUS_MV": "Обнаруживает команду mv, которая может неожиданно переместить или перезаписать файлы" + "TOOL_CMD_DANGEROUS_MV": "Обнаруживает команду mv, которая может неожиданно переместить или перезаписать файлы", + "TOOL_CMD_FS_DESTRUCTION": "Обнаруживает команды низкоуровневого форматирования или очистки дисков", + "TOOL_CMD_DOS_FORK_BOMB": "Обнаруживает классические Bash fork-бомбы и массовое завершение процессов", + "TOOL_CMD_PIPE_TO_SHELL": "Обнаруживает паттерны 'curl | bash' для загрузки и немедленного выполнения удалённых скриптов", + "TOOL_CMD_REVERSE_SHELL": "Обнаруживает попытки создания обратных оболочек или несанкционированных сетевых туннелей", + "TOOL_CMD_SYSTEM_TAMPERING": "Обнаруживает доступ к заданиям cron, SSH-ключам или правам sudo (включая чтение и изменение)", + "TOOL_CMD_UNSAFE_PERMISSIONS": "Обнаруживает глобальное понижение прав доступа (chmod 777) или установку неизменяемых флагов", + "TOOL_CMD_OBFUSCATED_EXEC": "Обнаруживает выполнение строк в кодировке base64, передаваемых напрямую в интерпретатор оболочки" } }, "skillScanner": { diff --git a/console/src/locales/zh.json b/console/src/locales/zh.json index 74a5ada6d..33eafe2b1 100644 --- a/console/src/locales/zh.json +++ b/console/src/locales/zh.json @@ -705,7 +705,14 @@ "allParams": "所有参数", "descriptions": { "TOOL_CMD_DANGEROUS_RM": "检测可能导致数据丢失的 rm 命令", - "TOOL_CMD_DANGEROUS_MV": "检测可能意外移动或覆盖文件的 mv 命令" + "TOOL_CMD_DANGEROUS_MV": "检测可能意外移动或覆盖文件的 mv 命令", + "TOOL_CMD_FS_DESTRUCTION": "检测低级别磁盘格式化或擦除命令", + "TOOL_CMD_DOS_FORK_BOMB": "检测经典 Bash Fork 炸弹和批量进程终止", + "TOOL_CMD_PIPE_TO_SHELL": "检测通过 'curl | bash' 模式下载并立即执行远程载荷的行为", + "TOOL_CMD_REVERSE_SHELL": "检测建立反向 Shell 或未授权网络隧道的行为", + 
"TOOL_CMD_SYSTEM_TAMPERING": "检测对定时任务、SSH 密钥或 sudo 权限的访问(包括读取和修改)", + "TOOL_CMD_UNSAFE_PERMISSIONS": "检测全局权限降级(chmod 777)或设置不可变标志的操作", + "TOOL_CMD_OBFUSCATED_EXEC": "检测将 base64 编码字符串直接传递给 Shell 解释器执行的行为" } }, "skillScanner": { From 053307c302445ebe0eacebf4b8d73c2cf2730ee5 Mon Sep 17 00:00:00 2001 From: zhaozhuang521 <71918264+zhaozhuang521@users.noreply.github.com> Date: Wed, 18 Mar 2026 10:53:22 +0800 Subject: [PATCH 66/68] feat(console): Optimize document navigation anchors and their language. (#1707) Co-authored-by: zhaozhuang --- console/src/layouts/Sidebar.tsx | 19 +++-- .../Channels/components/ChannelDrawer.tsx | 83 +++++++++++++------ 2 files changed, 72 insertions(+), 30 deletions(-) diff --git a/console/src/layouts/Sidebar.tsx b/console/src/layouts/Sidebar.tsx index 4c8cd1f6e..fae4ea04c 100644 --- a/console/src/layouts/Sidebar.tsx +++ b/console/src/layouts/Sidebar.tsx @@ -235,9 +235,19 @@ export default function Sidebar({ selectedKey }: SidebarProps) { .then((res) => res.json()) .then((data) => { const releases = data?.releases ?? {}; - // Sort versions by upload_time (newest first) - const versionsWithTime = Object.entries(releases).map( - ([version, files]) => { + + // Filter out pre-release versions (alpha, beta, rc, dev, etc.) 
+ const isStableVersion = (version: string) => { + // Pre-release indicators: a, alpha, b, beta, rc, c, candidate, dev, post + const preReleasePattern = /(a|alpha|b|beta|rc|c|candidate|dev)\d*/i; + // Also check for prerelease field in package info + return !preReleasePattern.test(version); + }; + + // Sort versions by upload_time (newest first), only include stable versions + const versionsWithTime = Object.entries(releases) + .filter(([version]) => isStableVersion(version)) + .map(([version, files]) => { const fileList = files as Array<{ upload_time_iso_8601?: string }>; // Get the latest upload time among all files for this version const latestUpload = fileList @@ -246,8 +256,7 @@ export default function Sidebar({ selectedKey }: SidebarProps) { .sort() .pop(); return { version, uploadTime: latestUpload || "" }; - }, - ); + }); versionsWithTime.sort( (a, b) => new Date(b.uploadTime).getTime() - new Date(a.uploadTime).getTime(), diff --git a/console/src/pages/Control/Channels/components/ChannelDrawer.tsx b/console/src/pages/Control/Channels/components/ChannelDrawer.tsx index a64fcc998..dab930490 100644 --- a/console/src/pages/Control/Channels/components/ChannelDrawer.tsx +++ b/console/src/pages/Control/Channels/components/ChannelDrawer.tsx @@ -37,24 +37,41 @@ interface ChannelDrawerProps { onSubmit: (values: Record) => void; } -// Doc URLs per channel (anchors on https://copaw.agentscope.io/docs/channels) -const CHANNEL_DOC_URLS: Partial> = { +// Doc EN URLs per channel (anchors on https://copaw.agentscope.io/docs/channels) +const CHANNEL_DOC_EN_URLS: Partial> = { dingtalk: - "https://copaw.agentscope.io/docs/channels/#%E9%92%89%E9%92%89%E6%8E%A8%E8%8D%90", - feishu: "https://copaw.agentscope.io/docs/channels/#%E9%A3%9E%E4%B9%A6", + "https://copaw.agentscope.io/docs/channels/?lang=en#DingTalk-recommended", + feishu: "https://copaw.agentscope.io/docs/channels/?lang=en#Feishu-Lark", imessage: - "https://copaw.agentscope.io/docs/channels/#iMessage%E4%BB%85-macOS", - 
discord: "https://copaw.agentscope.io/docs/channels/#Discord", - qq: "https://copaw.agentscope.io/docs/channels/#QQ", - telegram: "https://copaw.agentscope.io/docs/channels/#Telegram", - mqtt: "https://copaw.agentscope.io/docs/channels/#MQTT", - mattermost: "https://copaw.agentscope.io/docs/channels/#Mattermost", - matrix: "https://copaw.agentscope.io/docs/channels/#Matrix", - wecom: - "https://copaw.agentscope.io/docs/channels/#%E4%BC%81%E4%B8%9A%E5%BE%AE%E4%BF%A1", + "https://copaw.agentscope.io/docs/channels/?lang=en#iMessage-macOS-only", + discord: "https://copaw.agentscope.io/docs/channels/?lang=en#Discord", + qq: "https://copaw.agentscope.io/docs/channels/?lang=en#QQ", + telegram: "https://copaw.agentscope.io/docs/channels/?lang=en#Telegram", + mqtt: "https://copaw.agentscope.io/docs/channels/?lang=en#MQTT", + mattermost: "https://copaw.agentscope.io/docs/channels/?lang=en#Mattermost", + matrix: "https://copaw.agentscope.io/docs/channels/?lang=en#Matrix", + wecom: "https://copaw.agentscope.io/docs/channels/?lang=en#WeCom-WeChat-Work", xiaoyi: "https://developer.huawei.com/consumer/cn/doc/service/openclaw-0000002518410344", }; + +// Doc ZH URLs per channel (anchors on https://copaw.agentscope.io/docs/channels) +const CHANNEL_DOC_ZH_URLS: Partial> = { + dingtalk: "https://copaw.agentscope.io/docs/channels/?lang=zh#钉钉推荐", + feishu: "https://copaw.agentscope.io/docs/channels/?lang=zh#飞书", + imessage: + "https://copaw.agentscope.io/docs/channels/?lang=zh#iMessage仅-macOS", + discord: "https://copaw.agentscope.io/docs/channels/?lang=zh#Discord", + qq: "https://copaw.agentscope.io/docs/channels/?lang=zh#QQ", + telegram: "https://copaw.agentscope.io/docs/channels/?lang=zh#Telegram", + mqtt: "https://copaw.agentscope.io/docs/channels/?lang=zh#MQTT", + mattermost: "https://copaw.agentscope.io/docs/channels/?lang=zh#Mattermost", + matrix: "https://copaw.agentscope.io/docs/channels/?lang=zh#Matrix", + wecom: "https://copaw.agentscope.io/docs/channels/?lang=zh#企业微信", + 
xiaoyi: + "https://developer.huawei.com/consumer/cn/doc/service/openclaw-0000002518410344", +}; + const twilioConsoleUrl = "https://console.twilio.com"; export function ChannelDrawer({ @@ -68,7 +85,8 @@ export function ChannelDrawer({ onClose, onSubmit, }: ChannelDrawerProps) { - const { t } = useTranslation(); + const { t, i18n } = useTranslation(); + const currentLang = i18n.language?.startsWith("zh") ? "zh" : "en"; const label = activeKey ? getChannelLabel(activeKey) : activeLabel; const renderAccessControlFields = () => ( @@ -618,17 +636,32 @@ export function ChannelDrawer({ ? `${label} ${t("channels.settings")}` : t("channels.channelSettings")} - {activeKey && CHANNEL_DOC_URLS[activeKey] && ( - - )} + {activeKey && + CHANNEL_DOC_EN_URLS[activeKey] && + CHANNEL_DOC_ZH_URLS[activeKey] && ( + + )} {activeKey === "voice" && (