diff --git a/.github/workflows/test-build.yaml b/.github/workflows/beta.yaml similarity index 60% rename from .github/workflows/test-build.yaml rename to .github/workflows/beta.yaml index f4cac61a..b78aed1b 100644 --- a/.github/workflows/test-build.yaml +++ b/.github/workflows/beta.yaml @@ -1,4 +1,4 @@ -name: Test Deploy +name: Beta on: push: @@ -6,27 +6,28 @@ on: - test jobs: - # 准备阶段:设置构建变量 + # Setup: Build variables setup: + name: Setup runs-on: ubuntu-latest outputs: - build_start: ${{ steps.build_setup.outputs.build_start }} - beta_tag: ${{ steps.build_setup.outputs.beta_tag }} - commit_author: ${{ steps.build_setup.outputs.commit_author }} - commit_email: ${{ steps.build_setup.outputs.commit_email }} - commit_message: ${{ steps.build_setup.outputs.commit_message }} - commit_sha: ${{ steps.build_setup.outputs.commit_sha }} - commit_sha_short: ${{ steps.build_setup.outputs.commit_sha_short }} - commit_date: ${{ steps.build_setup.outputs.commit_date }} + build_start: ${{ steps.setup.outputs.build_start }} + beta_version: ${{ steps.setup.outputs.beta_version }} + commit_author: ${{ steps.setup.outputs.commit_author }} + commit_email: ${{ steps.setup.outputs.commit_email }} + commit_message: ${{ steps.setup.outputs.commit_message }} + commit_sha: ${{ steps.setup.outputs.commit_sha }} + commit_sha_short: ${{ steps.setup.outputs.commit_sha_short }} + commit_date: ${{ steps.setup.outputs.commit_date }} steps: - - name: Checkout code + - name: Checkout uses: actions/checkout@v4 - name: Setup build variables - id: build_setup + id: setup run: | echo "build_start=$(date '+%Y-%m-%d %H:%M:%S')" >> $GITHUB_OUTPUT - echo "beta_tag=BETA.$(date -u '+%Y-%m-%dT%H-%M-%SZ')" >> $GITHUB_OUTPUT + echo "beta_version=beta-$(date -u '+%Y%m%d%H%M%S')-$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT echo "commit_author=$(git log -1 --pretty=format:'%an')" >> $GITHUB_OUTPUT echo "commit_email=$(git log -1 --pretty=format:'%ae')" >> $GITHUB_OUTPUT echo "commit_message=$(git log -1 --pretty=format:'%s')" >> $GITHUB_OUTPUT @@ -34,12 +35,15 @@ jobs: echo "commit_sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT echo "commit_date=$(git log -1 --pretty=format:'%cd' --date=format:'%Y-%m-%d %H:%M:%S')" >> $GITHUB_OUTPUT - # 并行构建:Service 镜像 + # Parallel build: Service image build-service: + name: Build Service runs-on: ubuntu-latest needs: setup + env: + BETA_VERSION: ${{ needs.setup.outputs.beta_version }} steps: - - name: Checkout code + - name: Checkout uses: actions/checkout@v4 - name: Set up Docker Buildx @@ -52,19 +56,28 @@ jobs: username: ${{ secrets.SCIENCEOL_REGISTRY_USERNAME }} password: ${{ secrets.SCIENCEOL_REGISTRY_PASSWORD }} - - name: Build and push Service Docker image - run: | - docker buildx build service \ - -t registry.sciol.ac.cn/sciol/xyzen-service:test \ - -t registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_tag }} \ - --push - - # 并行构建:Web 镜像 + - name: Build and push Service image + uses: docker/build-push-action@v6 + with: + context: ./service + push: true + build-args: | + XYZEN_VERSION=${{ env.BETA_VERSION }} + XYZEN_COMMIT_SHA=${{ github.sha }} + XYZEN_BUILD_TIME=${{ needs.setup.outputs.commit_date }} + tags: | + registry.sciol.ac.cn/sciol/xyzen-service:beta + registry.sciol.ac.cn/sciol/xyzen-service:${{ env.BETA_VERSION }} + + # Parallel build: Web image build-web: + name: Build Web runs-on: ubuntu-latest needs: setup + env: + BETA_VERSION: ${{ needs.setup.outputs.beta_version }} steps: - - name: Checkout code + - name: Checkout uses: actions/checkout@v4 - name: Set up Docker Buildx @@ -77,57 +90,70 @@ jobs: username: ${{ secrets.SCIENCEOL_REGISTRY_USERNAME }} password: ${{ secrets.SCIENCEOL_REGISTRY_PASSWORD }} - - name: Build and push Web Docker image - run: | - docker buildx build web \ - -t registry.sciol.ac.cn/sciol/xyzen-web:test \ - -t registry.sciol.ac.cn/sciol/xyzen-web:${{ needs.setup.outputs.beta_tag }} \ - --push + - name: Build and push Web image + uses: docker/build-push-action@v6 + with: + context: ./web + push: true + tags: | + registry.sciol.ac.cn/sciol/xyzen-web:beta + registry.sciol.ac.cn/sciol/xyzen-web:${{ env.BETA_VERSION }} - # 部署阶段:等待所有构建完成后统一部署 + # Deploy: Wait for all builds to complete deploy: + name: Deploy runs-on: ubuntu-latest needs: [setup, build-service, build-web] + env: + BETA_VERSION: ${{ needs.setup.outputs.beta_version }} steps: - name: Download Let's Encrypt CA run: curl -o ca.crt https://letsencrypt.org/certs/isrgrootx1.pem - - name: Rolling update deployments + - name: Deploy to Kubernetes run: | kubectl \ --server=${{ secrets.SCIENCEOL_K8S_SERVER_URL }} \ --token=${{ secrets.SCIENCEOL_K8S_ADMIN_TOKEN }} \ --certificate-authority=ca.crt \ - set image deployment/xyzen -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_tag }} + set image deployment/xyzen -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ env.BETA_VERSION }} kubectl \ --server=${{ secrets.SCIENCEOL_K8S_SERVER_URL }} \ --token=${{ secrets.SCIENCEOL_K8S_ADMIN_TOKEN }} \ --certificate-authority=ca.crt \ - set image deployment/xyzen-web -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-web:${{ needs.setup.outputs.beta_tag }} + set image deployment/xyzen-web -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-web:${{ env.BETA_VERSION }} kubectl \ --server=${{ secrets.SCIENCEOL_K8S_SERVER_URL }} \ --token=${{ secrets.SCIENCEOL_K8S_ADMIN_TOKEN }} \ --certificate-authority=ca.crt \ - set image deployment/xyzen-celery -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_tag }} + set image deployment/xyzen-celery -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ env.BETA_VERSION }} - # 通知阶段:发送构建结果通知 + - name: Deployment Summary + run: | + echo "## 🧪 Beta Deployment Complete" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Component | Image |" >> $GITHUB_STEP_SUMMARY + echo "|-----------|-------|" >> $GITHUB_STEP_SUMMARY + echo "| Service | \`registry.sciol.ac.cn/sciol/xyzen-service:${{ env.BETA_VERSION }}\` |" >> $GITHUB_STEP_SUMMARY + echo "| Web | \`registry.sciol.ac.cn/sciol/xyzen-web:${{ env.BETA_VERSION }}\` |" >> $GITHUB_STEP_SUMMARY + + # Notify: Send build result notification notify: + name: Notify runs-on: ubuntu-latest needs: [setup, build-service, build-web, deploy] if: always() steps: - - name: Checkout code + - name: Checkout uses: actions/checkout@v4 - name: Calculate build duration - id: build_duration - shell: bash + id: duration run: | BUILD_START="${{ needs.setup.outputs.build_start }}" if [ -z "$BUILD_START" ]; then - echo "Warning: build_start is empty, using current time as fallback" BUILD_START=$(date '+%Y-%m-%d %H:%M:%S') fi @@ -140,10 +166,8 @@ jobs: MINUTES=$(((DURATION_SEC % 3600) / 60)) SECONDS=$((DURATION_SEC % 60)) - DURATION="${HOURS}h ${MINUTES}m ${SECONDS}s" - - echo "build_end=$BUILD_END" >> $GITHUB_ENV - echo "build_duration=$DURATION" >> $GITHUB_ENV + echo "build_end=$BUILD_END" >> $GITHUB_OUTPUT + echo "build_duration=${HOURS}h ${MINUTES}m ${SECONDS}s" >> $GITHUB_OUTPUT - name: Determine overall status id: status @@ -154,7 +178,7 @@ jobs: echo "status=failure" >> $GITHUB_OUTPUT fi - - name: Send build notification + - name: Send notification uses: ./.github/actions/email-notification with: status: ${{ steps.status.outputs.status }} @@ -165,20 +189,20 @@ jobs: recipient: ${{ secrets.SMTP_RECEIVER }} architecture: 'amd64' pr_number: 'N/A' - pr_title: 'Push to test' + pr_title: 'Beta Deploy ${{ needs.setup.outputs.beta_version }}' pr_url: '${{ github.server_url }}/${{ github.repository }}/commit/${{ github.sha }}' head_ref: ${{ github.ref_name }} base_ref: 'test' repo: ${{ github.repository }} run_id: ${{ github.run_id }} build_start: ${{ needs.setup.outputs.build_start }} - build_end: ${{ env.build_end }} - build_duration: ${{ env.build_duration }} + build_end: ${{ steps.duration.outputs.build_end }} + build_duration: ${{ steps.duration.outputs.build_duration }} commit_author: ${{ needs.setup.outputs.commit_author }} commit_email: ${{ needs.setup.outputs.commit_email }} commit_message: ${{ needs.setup.outputs.commit_message }} commit_sha: ${{ needs.setup.outputs.commit_sha }} commit_sha_short: ${{ needs.setup.outputs.commit_sha_short }} commit_date: ${{ needs.setup.outputs.commit_date }} - service_image: 'registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_tag }}' - web_image: 'registry.sciol.ac.cn/sciol/xyzen-web:${{ needs.setup.outputs.beta_tag }}' + service_image: 'registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_version }}' + web_image: 'registry.sciol.ac.cn/sciol/xyzen-web:${{ needs.setup.outputs.beta_version }}' diff --git a/README.md b/README.md index 3416ab51..cdd0ad25 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Your next agent platform for multi-agent orchestration, real-time chat, and docu [![React](https://img.shields.io/badge/react-%2320232a.svg?style=flat&logo=react&logoColor=%2361DAFB)](https://reactjs.org/) [![npm version](https://img.shields.io/npm/v/@sciol/xyzen.svg)](https://www.npmjs.com/package/@sciol/xyzen) [![Pre-commit CI](https://github.com/ScienceOL/Xyzen/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/pre-commit.yaml) -[![Prod Build](https://github.com/ScienceOL/Xyzen/actions/workflows/prod-build.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/prod-build.yaml) +[![Release](https://github.com/ScienceOL/Xyzen/actions/workflows/release.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/release.yaml) [![Test Suite](https://github.com/ScienceOL/Xyzen/actions/workflows/test.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/test.yaml) [![codecov](https://codecov.io/github/ScienceOL/Xyzen/graph/badge.svg?token=91W3GO7CRI)](https://codecov.io/github/ScienceOL/Xyzen) @@ -31,17 +31,9 @@ Xyzen is an AI lab server built with FastAPI + LangGraph on the backend and Reac ## Getting Started -Xyzen uses Docker for all development to ensure consistency across environments and to manage required infrastructure services (PostgreSQL, Redis, Mosquitto, Casdoor). - ### Prerequisites - Docker and Docker Compose -- [uv](https://docs.astral.sh/uv/) for pre-commit hooks (Python tools) -- Node.js with Yarn (via [Corepack](https://nodejs.org/api/corepack.html)) for pre-commit hooks (Frontend tools) - -## Development Setup - -The easiest way to get started with Xyzen is using the containerized development environment. This automatically sets up all services (PostgreSQL, Mosquitto, Casdoor) and development tools. ### Quick Start @@ -52,69 +44,50 @@ The easiest way to get started with Xyzen is using the containerized development cd Xyzen ``` -2. Start the development environment: - - **On Unix/Linux/macOS:** +2. Create environment configuration: ```bash - ./launch/dev.sh - ``` - - **On Windows (PowerShell):** - - ```powershell - .\launch\dev.ps1 + cp docker/.env.example docker/.env.dev ``` - Or use the Makefile: +3. Configure your LLM provider in `docker/.env.dev`: ```bash - make dev # Start in foreground (shows logs) - make dev ARGS="-d" # Start in background (daemon mode) - make dev ARGS="-s" # Stop containers (without removal) - make dev ARGS="-e" # Stop and remove containers - ``` - -The script will automatically: - -- Check Docker and validate `.env.dev` file -- Set up global Sciol virtual environment at `~/.sciol/venv` -- Install and configure pre-commit hooks -- Create VS Code workspace configuration -- Start infrastructure services (PostgreSQL, Mosquitto, Casdoor) -- Launch development containers with hot reloading + # Enable providers (comma-separated): azure_openai,openai,google,qwen + XYZEN_LLM_providers=openai -### Container Development Options - -**Start in foreground (see logs):** + # OpenAI example + XYZEN_LLM_OpenAI_key=sk-your-api-key + XYZEN_LLM_OpenAI_endpoint=https://api.openai.com/v1 + XYZEN_LLM_OpenAI_deployment=gpt-4o + ``` -```bash -./launch/dev.sh -``` + See `docker/.env.example` for all available configuration options. -**Start in background:** +4. Start the development environment: -```bash -./launch/dev.sh -d -``` + ```bash + ./launch/dev.sh # Start in foreground (shows logs) + ./launch/dev.sh -d # Start in background (daemon mode) + ./launch/dev.sh -s # Stop containers + ./launch/dev.sh -e # Stop and remove containers + ``` -**Stop containers:** + Or use the Makefile: -```bash -./launch/dev.sh -s -``` + ```bash + make dev # Start in foreground + make dev ARGS="-d" # Start in background + ``` -**Stop and remove containers:** +The script will automatically set up all infrastructure services (PostgreSQL, Redis, Mosquitto, Casdoor) and launch development containers with hot reloading. -```bash -./launch/dev.sh -e -``` +## Development -**Show help:** +### Prerequisites for Contributing -```bash -./launch/dev.sh -h -``` +- [uv](https://docs.astral.sh/uv/) for Python tools and pre-commit hooks +- Node.js with Yarn (via [Corepack](https://nodejs.org/api/corepack.html)) for frontend tools ## AI Assistant Rules diff --git a/README_zh.md b/README_zh.md index 7b7e182b..db86f237 100644 --- a/README_zh.md +++ b/README_zh.md @@ -12,7 +12,7 @@ [![React](https://img.shields.io/badge/react-%2320232a.svg?style=flat&logo=react&logoColor=%2361DAFB)](https://reactjs.org/) [![npm version](https://img.shields.io/npm/v/@sciol/xyzen.svg)](https://www.npmjs.com/package/@sciol/xyzen) [![Pre-commit CI](https://github.com/ScienceOL/Xyzen/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/pre-commit.yaml) -[![Prod Build](https://github.com/ScienceOL/Xyzen/actions/workflows/prod-build.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/prod-build.yaml) +[![Release](https://github.com/ScienceOL/Xyzen/actions/workflows/release.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/release.yaml) [![Test Suite](https://github.com/ScienceOL/Xyzen/actions/workflows/test.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/test.yaml) [![codecov](https://codecov.io/github/ScienceOL/Xyzen/graph/badge.svg?token=91W3GO7CRI)](https://codecov.io/github/ScienceOL/Xyzen) @@ -31,36 +31,63 @@ Xyzen 由 FastAPI + LangGraph 后端与 React + Zustand 前端构建,支持多 ## 快速开始 -开发依赖 Docker(PostgreSQL、Redis、Mosquitto、Casdoor)。 - ### 前置条件 - Docker 和 Docker Compose -- `uv`(用于 Python 工具链) -- Node.js + Yarn(Corepack) -### 启动开发环境 +### 启动步骤 -```bash -git clone https://github.com/ScienceOL/Xyzen.git -cd Xyzen -./launch/dev.sh -``` +1. 克隆仓库: -Windows(PowerShell): + ```bash + git clone https://github.com/ScienceOL/Xyzen.git + cd Xyzen + ``` -```powershell -.\launch\dev.ps1 -``` +2. 创建环境配置文件: -常用命令: + ```bash + cp docker/.env.example docker/.env.dev + ``` -```bash -make dev # 前台启动 -make dev ARGS="-d" # 后台启动 -make dev ARGS="-s" # 停止容器 -make dev ARGS="-e" # 停止并移除容器 -``` +3. 在 `docker/.env.dev` 中配置 LLM 模型: + + ```bash + # 启用的模型供应商(逗号分隔):azure_openai,openai,google,qwen + XYZEN_LLM_providers=openai + + # OpenAI 示例 + XYZEN_LLM_OpenAI_key=sk-your-api-key + XYZEN_LLM_OpenAI_endpoint=https://api.openai.com/v1 + XYZEN_LLM_OpenAI_deployment=gpt-4o + ``` + + 完整配置项请参考 `docker/.env.example`。 + +4. 启动开发环境: + + ```bash + ./launch/dev.sh # 前台启动(显示日志) + ./launch/dev.sh -d # 后台启动 + ./launch/dev.sh -s # 停止容器 + ./launch/dev.sh -e # 停止并移除容器 + ``` + + 或使用 Makefile: + + ```bash + make dev # 前台启动 + make dev ARGS="-d" # 后台启动 + ``` + +脚本会自动配置所有基础服务(PostgreSQL、Redis、Mosquitto、Casdoor)并启动带热重载的开发容器。 + +## 开发 + +### 贡献代码的前置条件 + +- [uv](https://docs.astral.sh/uv/)(Python 工具链和 pre-commit hooks) +- Node.js + Yarn(通过 [Corepack](https://nodejs.org/api/corepack.html),用于前端工具) ## AI 助手规则 diff --git a/service/app/agents/factory.py b/service/app/agents/factory.py index 6f38d1cf..53dcc82c 100644 --- a/service/app/agents/factory.py +++ b/service/app/agents/factory.py @@ -207,10 +207,11 @@ def _resolve_agent_config( def _inject_system_prompt(config_dict: dict[str, Any], system_prompt: str) -> dict[str, Any]: """ - Inject system_prompt into a react-style config. + Inject system_prompt into a graph config. - For configs using stdlib:react component, updates the config_overrides - to include the system_prompt. + Handles both: + 1. Component nodes with stdlib:react - updates config_overrides + 2. LLM nodes - updates prompt_template Args: config_dict: GraphConfig as dict @@ -224,8 +225,9 @@ def _inject_system_prompt(config_dict: dict[str, Any], system_prompt: str) -> di config = copy.deepcopy(config_dict) - # Find component nodes and inject system_prompt + # Find nodes and inject system_prompt (first matching node only) for node in config.get("nodes", []): + # Handle component nodes (existing behavior) if node.get("type") == "component": comp_config = node.get("component_config", {}) comp_ref = comp_config.get("component_ref", {}) @@ -234,6 +236,13 @@ def _inject_system_prompt(config_dict: dict[str, Any], system_prompt: str) -> di if comp_ref.get("key") == "react": overrides = comp_config.setdefault("config_overrides", {}) overrides["system_prompt"] = system_prompt + break + + # Handle LLM nodes + elif node.get("type") == "llm": + llm_config = node.get("llm_config", {}) + llm_config["prompt_template"] = system_prompt + break return config diff --git a/service/app/agents/graph_builder.py b/service/app/agents/graph_builder.py index c0726c15..cc70906c 100644 --- a/service/app/agents/graph_builder.py +++ b/service/app/agents/graph_builder.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Annotated, Any from jinja2 import Template -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages import AIMessage, BaseMessage, SystemMessage from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from langgraph.prebuilt import ToolNode, tools_condition @@ -344,11 +344,15 @@ async def llm_node(state: StateDict | BaseModel) -> StateDict: # Convert state to dict for template rendering (but we already have messages) state_dict = self._state_to_dict(state) - # Render prompt template - prompt = self._render_template(llm_config.prompt_template, state_dict) + # Build messages for LLM - start with conversation messages + llm_messages = list(messages) - # Build messages for LLM - llm_messages = list(messages) + [HumanMessage(content=prompt)] + # Prepend system prompt if configured (uses Jinja2 template rendering) + if llm_config.prompt_template: + rendered_prompt = self._render_template(llm_config.prompt_template, state_dict) + # Filter any existing SystemMessage and prepend ours + llm_messages = [m for m in llm_messages if not isinstance(m, SystemMessage)] + llm_messages = [SystemMessage(content=rendered_prompt)] + llm_messages # Invoke LLM (using pre-created configured_llm) response = await configured_llm.ainvoke(llm_messages) @@ -382,11 +386,15 @@ async def llm_node(state: StateDict | BaseModel) -> StateDict: logger.info(f"[LLM Node: {config.id}] Text output completed, tool_calls: {len(tool_calls)}") - # Build AIMessage preserving tool_calls - ai_message = AIMessage( - content=content_str, - tool_calls=tool_calls, - ) + # Preserve the original response message to retain provider-specific metadata + # (e.g., Gemini thought signatures needed for tool calling). + if isinstance(response, BaseMessage): + ai_message = response + else: + ai_message = AIMessage( + content=content_str, + tool_calls=tool_calls, + ) return { llm_config.output_key: content_str, diff --git a/service/app/api/v1/sessions.py b/service/app/api/v1/sessions.py index f4ef323a..94db7476 100644 --- a/service/app/api/v1/sessions.py +++ b/service/app/api/v1/sessions.py @@ -72,16 +72,17 @@ async def create_session_with_default_topic( raise handle_auth_error(e) -@router.get("/by-agent/{agent_id}", response_model=SessionRead) +@router.get("/by-agent/{agent_id}", response_model=SessionReadWithTopics) async def get_session_by_agent( agent_id: str, user: str = Depends(get_current_user), db: AsyncSession = Depends(get_session) -) -> SessionRead: +) -> SessionReadWithTopics: """ - Retrieve a session for the current user with a specific agent. + Retrieve a session for the current user with a specific agent, including topics. Finds a session associated with the given agent ID for the authenticated user. The agent_id can be "default" for sessions without an agent, a UUID string for sessions with a specific agent, or a builtin agent string ID. + Topics are ordered by updated_at descending (most recent first). Args: agent_id: Agent identifier ("default", UUID string, or builtin agent ID) @@ -89,13 +90,13 @@ async def get_session_by_agent( db: Database session (injected by dependency) Returns: - SessionRead: The session associated with the user and agent + SessionReadWithTopics: The session with topics associated with the user and agent Raises: HTTPException: 404 if no session found for this user-agent combination """ try: - return await SessionService(db).get_session_by_agent(user, agent_id) + return await SessionService(db).get_session_by_agent_with_topics(user, agent_id) except ErrCodeError as e: raise handle_auth_error(e) diff --git a/service/app/core/chat/langchain.py b/service/app/core/chat/langchain.py index 31395d5d..5826cfc4 100644 --- a/service/app/core/chat/langchain.py +++ b/service/app/core/chat/langchain.py @@ -492,9 +492,11 @@ async def _handle_updates_mode( logger.info(f"[ToolEvent] Skipping historical tool response: {tool_call_id}") continue ctx.emitted_tool_result_ids.add(tool_call_id) - result = format_tool_result(msg.content, tool_name) + # Get raw content before formatting (for cost calculation) + raw_content = msg.content + result = format_tool_result(raw_content, tool_name) logger.info(f"[ToolEvent] >>> Emitting tool_call_response for {tool_call_id}") - yield ToolEventHandler.create_tool_response_event(tool_call_id, result) + yield ToolEventHandler.create_tool_response_event(tool_call_id, result, raw_result=raw_content) last_message = messages[-1] diff --git a/service/app/core/chat/stream_handlers.py b/service/app/core/chat/stream_handlers.py index af23f2a7..e4ab358e 100644 --- a/service/app/core/chat/stream_handlers.py +++ b/service/app/core/chat/stream_handlers.py @@ -107,15 +107,19 @@ def create_tool_request_event(tool_call: dict[str, Any]) -> StreamingEvent: @staticmethod def create_tool_response_event( - tool_call_id: str, result: str, status: str = ToolCallStatus.COMPLETED + tool_call_id: str, + result: str, + status: str = ToolCallStatus.COMPLETED, + raw_result: str | dict | list | None = None, ) -> StreamingEvent: """ Create a tool call response event. Args: tool_call_id: ID of the tool call - result: Formatted result string + result: Formatted result string for display status: Tool call status + raw_result: Raw result for cost calculation (optional, unformatted) Returns: StreamingEvent for tool call response @@ -125,6 +129,8 @@ def create_tool_response_event( "status": status, "result": result, } + if raw_result is not None: + data["raw_result"] = raw_result return {"type": ChatEventType.TOOL_CALL_RESPONSE, "data": data} diff --git a/service/app/core/checkin.py b/service/app/core/checkin.py index 4837fc98..7ab99e6a 100644 --- a/service/app/core/checkin.py +++ b/service/app/core/checkin.py @@ -2,6 +2,7 @@ import logging from datetime import datetime, timedelta, timezone +from typing import TypedDict from sqlmodel.ext.asyncio.session import AsyncSession @@ -16,6 +17,13 @@ CHECKIN_TZ = timezone(timedelta(hours=8)) +class CheckInStatus(TypedDict): + checked_in_today: bool + consecutive_days: int + next_points: int + total_check_ins: int + + class CheckInService: """Service layer for check-in operations.""" @@ -58,19 +66,7 @@ def calculate_points(consecutive_days: int) -> int: Returns: Points to award. """ - if consecutive_days <= 0: - return 10 - elif consecutive_days == 1: - return 10 - elif consecutive_days == 2: - return 20 - elif consecutive_days == 3: - return 30 - elif consecutive_days == 4: - return 40 - else: - # Day 5 and beyond: 50 points - return 50 + return 10 * max(1, min(consecutive_days, 5)) async def check_in(self, user_id: str) -> tuple[CheckIn, int]: """ @@ -101,7 +97,6 @@ async def check_in(self, user_id: str) -> tuple[CheckIn, int]: existing_check_in = await self.check_in_repo.get_check_in_by_user_and_date(user_id, today) if existing_check_in: logger.warning(f"User {user_id} has already checked in today") - # raise ErrCodeError(ErrCode.ALREADY_CHECKED_IN_TODAY, "您今天已经签到过了哦~") raise ErrCodeError(ErrCode.ALREADY_CHECKED_IN_TODAY) # Get latest check-in to calculate consecutive days @@ -147,7 +142,7 @@ async def check_in(self, user_id: str) -> tuple[CheckIn, int]: return check_in, wallet.virtual_balance - async def get_check_in_status(self, user_id: str) -> dict: + async def get_check_in_status(self, user_id: str) -> CheckInStatus: """ Get check-in status for a user. diff --git a/service/app/core/consume_strategy.py b/service/app/core/consume_strategy.py index 74369400..a7cd275b 100644 --- a/service/app/core/consume_strategy.py +++ b/service/app/core/consume_strategy.py @@ -24,7 +24,7 @@ class ConsumptionContext: output_tokens: int = 0 total_tokens: int = 0 content_length: int = 0 - generated_files_count: int = 0 + tool_costs: int = 0 @dataclass @@ -64,13 +64,12 @@ class TierBasedConsumptionStrategy(ConsumptionStrategy): Design decisions: - LITE tier (rate 0.0) = completely free - - Tier rate multiplies ALL costs (base + tokens + files) + - Tier rate multiplies ALL costs (base + tokens + tool costs) """ BASE_COST = 1 INPUT_TOKEN_RATE = 0.2 / 1000 # per token OUTPUT_TOKEN_RATE = 1 / 1000 # per token - FILE_GENERATION_COST = 10 def calculate(self, context: ConsumptionContext) -> ConsumptionResult: """Calculate consumption with tier-based multiplier. @@ -90,7 +89,7 @@ def calculate(self, context: ConsumptionContext) -> ConsumptionResult: breakdown={ "base_cost": 0, "token_cost": 0, - "file_cost": 0, + "tool_costs": 0, "tier_rate": 0.0, "tier": context.model_tier.value if context.model_tier else "lite", "note": "LITE tier - free usage", @@ -99,10 +98,9 @@ def calculate(self, context: ConsumptionContext) -> ConsumptionResult: # Calculate base token cost token_cost = context.input_tokens * self.INPUT_TOKEN_RATE + context.output_tokens * self.OUTPUT_TOKEN_RATE - file_cost = context.generated_files_count * self.FILE_GENERATION_COST - # Tier rate multiplies ALL costs - base_amount = self.BASE_COST + token_cost + file_cost + # Tier rate multiplies ALL costs (including tool costs) + base_amount = self.BASE_COST + token_cost + context.tool_costs final_amount = int(base_amount * tier_rate) return ConsumptionResult( @@ -110,7 +108,7 @@ def calculate(self, context: ConsumptionContext) -> ConsumptionResult: breakdown={ "base_cost": self.BASE_COST, "token_cost": token_cost, - "file_cost": file_cost, + "tool_costs": context.tool_costs, "pre_multiplier_total": base_amount, "tier_rate": tier_rate, "tier": context.model_tier.value if context.model_tier else "default", diff --git a/service/app/core/session/service.py b/service/app/core/session/service.py index 53d0e729..fd9bfc14 100644 --- a/service/app/core/session/service.py +++ b/service/app/core/session/service.py @@ -49,6 +49,20 @@ async def get_session_by_agent(self, user_id: str, agent_id: str) -> SessionRead raise ErrCode.SESSION_NOT_FOUND.with_messages("No session found for this user-agent combination") return SessionRead(**session.model_dump()) + async def get_session_by_agent_with_topics(self, user_id: str, agent_id: str) -> SessionReadWithTopics: + agent_uuid = await self._resolve_agent_uuid_for_lookup(agent_id) + session = await self.session_repo.get_session_by_user_and_agent(user_id, agent_uuid) + if not session: + raise ErrCode.SESSION_NOT_FOUND.with_messages("No session found for this user-agent combination") + + # Fetch topics ordered by updated_at descending (most recent first) + topics = await self.topic_repo.get_topics_by_session(session.id, order_by_updated=True) + topic_reads = [TopicRead(**topic.model_dump()) for topic in topics] + + session_dict = session.model_dump() + session_dict["topics"] = topic_reads + return SessionReadWithTopics(**session_dict) + async def get_sessions_with_topics(self, user_id: str) -> list[SessionReadWithTopics]: sessions = await self.session_repo.get_sessions_by_user_ordered_by_activity(user_id) diff --git a/service/app/infra/database/__init__.py b/service/app/infra/database/__init__.py index 4eac26fe..b50a9dc6 100644 --- a/service/app/infra/database/__init__.py +++ b/service/app/infra/database/__init__.py @@ -5,6 +5,7 @@ create_task_session_factory, engine, get_session, + get_task_db_session, ) __all__ = [ @@ -14,4 +15,5 @@ "AsyncSessionLocal", "ASYNC_DATABASE_URL", "create_task_session_factory", + "get_task_db_session", ] diff --git a/service/app/infra/database/connection.py b/service/app/infra/database/connection.py index a8afa96a..a59a7ff6 100644 --- a/service/app/infra/database/connection.py +++ b/service/app/infra/database/connection.py @@ -1,4 +1,7 @@ +import asyncio +import os from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine @@ -138,3 +141,35 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]: """ async with AsyncSessionLocal() as session: yield session + + +_worker_engines: dict[tuple[int, int], async_sessionmaker[AsyncSession]] = {} + + +@asynccontextmanager +async def get_task_db_session(): + """ + Get a database session suitable for Celery Worker / tool execution contexts. + + Each call creates a new engine bound to the current event loop to avoid cross-loop issues. + """ + pid = os.getpid() + loop_id = id(asyncio.get_running_loop()) + cache_key = (pid, loop_id) + + if cache_key not in _worker_engines: + # Clean up old engines from this process but different loops + old_keys = [k for k in _worker_engines if k[0] == pid and k[1] != loop_id] + for old_key in old_keys: + # Optionally dispose old engines (fire and forget) + del _worker_engines[old_key] + + task_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False, future=True) + _worker_engines[cache_key] = async_sessionmaker( + bind=task_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + async with _worker_engines[cache_key]() as session: + yield session diff --git a/service/app/mcp/literature.py b/service/app/mcp/literature.py new file mode 100644 index 00000000..5c5e55ca --- /dev/null +++ b/service/app/mcp/literature.py @@ -0,0 +1,463 @@ +""" +Literature MCP Server - Multi-source academic literature search + +Provides tools for searching academic literature from multiple data sources +(OpenAlex, Semantic Scholar, PubMed, etc.) with unified interface. +""" + +import json +import logging +from datetime import datetime +from typing import Any + +import httpx +from fastmcp import FastMCP + +from app.utils.literature import SearchRequest, WorkDistributor + +logger = logging.getLogger(__name__) + +TRUE_VALUES = frozenset({"true", "1", "yes"}) +FALSE_VALUES = frozenset({"false", "0", "no"}) + +# Create FastMCP instance +mcp = FastMCP("literature") + +# Metadata for MCP server +__mcp_metadata__ = { + "name": "Literature Search", + "description": "Search academic literature from multiple sources with advanced filtering", + "version": "1.0.0", +} + + +@mcp.tool() +async def search_literature( + query: str, + mailto: str | None = None, + author: str | None = None, + institution: str | None = None, + source: str | None = None, + year_from: str | None = None, + year_to: str | None = None, + is_oa: str | None = None, + work_type: str | None = None, + language: str | None = None, + is_retracted: str | None = None, + has_abstract: str | None = None, + has_fulltext: str | None = None, + sort_by: str = "relevance", + max_results: str | int = 50, + data_sources: list[str] | None = None, + include_abstract: str | bool = False, +) -> str: + """ + Search academic literature from multiple data sources (OpenAlex, Semantic Scholar, PubMed, etc.) + + 🔑 STRONGLY RECOMMENDED: Always provide a valid email address (mailto parameter) + ═════════════════════════════════════════════════════════════════════════════════ + + 📊 Performance Difference: + - WITH email (mailto): 10 requests/second (fast, ideal for large searches) + - WITHOUT email (mailto): 1 request/second (slow, sequential processing) + + ⚠️ Impact: Omitting email can cause 10x slowdown or timeouts for large result sets. + Production research should ALWAYS include email. Example: "researcher@university.edu" + + Response Format Overview + ════════════════════════ + The tool returns TWO sections automatically: + + 1️⃣ EXECUTIVE SUMMARY + - Key statistics (total found, unique count, sources) + - Average citations and open access rate + - Publication year range + - Warning/issue resolution status + + 2️⃣ DETAILED RESULTS (Complete JSON with URLs) + - Each paper includes: + • ✅ Valid URLs (access_url; doi is a raw identifier) + • Title, Authors (first 5), Publication Year + • Citation Count, Journal, Open Access Status + • Abstract (only if include_abstract=True) + - Format: JSON array for easy parsing/import + - All URLs are validated and functional + + Args: + query: Search keywords (e.g., "machine learning", "CRISPR", "cancer immunotherapy") + [REQUIRED] Most important parameter for accurate results + + mailto: Email address to enable fast API pool at OpenAlex + [⭐ STRONGLY RECOMMENDED - includes your email] + Examples: "researcher@mit.edu", "student@university.edu", "name@company.com" + Impact: 10x faster searches. Production users MUST provide this. + Note: Email is private, only used for API identification. + + author: OPTIONAL - Filter by author name (e.g., "Albert Einstein", "Jennifer Doudna") + Will auto-correct common misspellings if not found exactly + + institution: OPTIONAL - Filter by affiliation (e.g., "MIT", "Harvard", "Stanford University") + Partial name matching supported + + source: OPTIONAL - Filter by journal/venue (e.g., "Nature", "Science", "JAMA") + Matches both journal names and abbreviated titles + + year_from: OPTIONAL - Start year (e.g., "2020" or 2020) + Accepts string or integer, will auto-clamp to valid range (1700-2026) + + year_to: OPTIONAL - End year (e.g., "2024" or 2024) + Accepts string or integer, will auto-clamp to valid range (1700-2026) + If year_from > year_to, they will be automatically swapped + + is_oa: OPTIONAL - Open access filter ("true"/"false"/"yes"/"no") + "true" returns ONLY open access papers with direct links + + work_type: OPTIONAL - Filter by publication type + Options: "article", "review", "preprint", "book", "dissertation", "dataset", etc. + + language: OPTIONAL - Filter by publication language (e.g., "en", "zh", "ja", "fr", "de") + "en" = English only, "zh" = Chinese only, etc. + + is_retracted: OPTIONAL - Retracted paper filter ("true"/"false") + "false" excludes retracted works (recommended for research) + "true" shows ONLY retracted papers (for auditing) + + has_abstract: OPTIONAL - Require abstract ("true"/"false") + "true" returns only papers with abstracts + + has_fulltext: OPTIONAL - Require full text access ("true"/"false") + "true" returns only papers with available full text + + sort_by: Sort results - "relevance" (default), "cited_by_count", "publication_date" + "cited_by_count" useful for influential papers + "publication_date" shows most recent first + + max_results: Result limit (default: 50, range: 1-1000, accepts string or int) + More results = slower query. Recommended: 50-200 for research + + data_sources: Advanced - Sources to query (default: ["openalex"]) + Can include: ["openalex", "semantic_scholar", "pubmed"] + + include_abstract: Include full abstracts in JSON output? (default: False) + True = include full abstracts for detailed review + False = save token budget by excluding abstracts + + Returns: + Markdown report with two sections: + + 📋 Section 1: EXECUTIVE SUMMARY + └─ Search conditions recap + └─ Total results found & unique count + └─ Statistics: avg citations, OA rate, year range + └─ ⚠️ Any warnings/filter issues & resolutions + + 📊 Section 2: COMPLETE RESULTS (JSON Array) + └─ Each paper object contains: + • "doi": Raw DOI string (not a URL) + • "title": Paper title + • "authors": Author names [first 5 only to save tokens] + • "publication_year": Publication date + • "cited_by_count": Citation impact metric + • "journal": Journal/venue name + • "description": Short description about the paper + └─ access_url is validated and immediately accessible + └─ Copy JSON directly into spreadsheet, database, or reference manager + + Usage Tips (READ THIS!) + ══════════════════════ + ✅ DO: + - Always provide mailto (10x faster searches) + - Start simple: query + mailto first + - Review results before refining search + - Use filters incrementally to narrow down + - Set include_abstract=True only for final review (saves API calls) + + ❌ DON'T: + - Make multiple searches without reviewing first results + - Use vague keywords like "research" or "analysis" + - Search without mailto unless doing quick test + - Ignore the "Next Steps Guide" section + - Omit email for production/important research + """ + try: + # Validate query early to avoid accidental broad searches + if not query or not str(query).strip(): + return "❌ Invalid input: query cannot be empty." + if len(str(query).strip()) < 3: + return "❌ Invalid input: query is too short (minimum 3 characters)." + + # Convert string parameters to proper types + year_from_int = int(year_from) if year_from and str(year_from).strip() else None + year_to_int = int(year_to) if year_to and str(year_to).strip() else None + + # Clamp year ranges (warn but don't block search) + max_year = datetime.now().year + 1 + year_warning = "" + if year_from_int is not None and year_from_int > max_year: + year_warning += f"year_from {year_from_int}→{max_year}. " + year_from_int = max_year + if year_to_int is not None and year_to_int < 1700: + year_warning += f"year_to {year_to_int}→1700. " + year_to_int = 1700 + + # Ensure year_from <= year_to when both are set + if year_from_int is not None and year_to_int is not None and year_from_int > year_to_int: + year_warning += f"year_from {year_from_int} and year_to {year_to_int} swapped to maintain a valid range. " + year_from_int, year_to_int = year_to_int, year_from_int + + # Convert is_oa to boolean + bool_warning_parts: list[str] = [] + + def _parse_bool_field(raw: str | bool | None, field_name: str) -> bool | None: + if raw is None: + return None + if isinstance(raw, bool): + return raw + val = str(raw).strip().lower() + if val in TRUE_VALUES: + return True + if val in FALSE_VALUES: + return False + bool_warning_parts.append(f"{field_name}={raw!r} not recognized; ignoring this filter.") + return None + + # Convert bool-like fields + is_oa_bool = _parse_bool_field(is_oa, "is_oa") + is_retracted_bool = _parse_bool_field(is_retracted, "is_retracted") + has_abstract_bool = _parse_bool_field(has_abstract, "has_abstract") + has_fulltext_bool = _parse_bool_field(has_fulltext, "has_fulltext") + + # Convert max_results to int with early clamping + max_results_warning = "" + try: + max_results_int = int(max_results) if max_results else 50 + except (TypeError, ValueError): + max_results_warning = "⚠️ max_results is not a valid integer; using default 50. " + max_results_int = 50 + + if max_results_int < 1: + max_results_warning += f"max_results {max_results_int}→50 (minimum is 1). " + max_results_int = 50 + elif max_results_int > 1000: + max_results_warning += f"max_results {max_results_int}→1000 (maximum is 1000). " + max_results_int = 1000 + + # Convert include_abstract to bool + include_abstract_bool = str(include_abstract).lower() in {"true", "1", "yes"} if include_abstract else False + + openalex_email = mailto.strip() if mailto and str(mailto).strip() else None + + logger.info( + "Literature search requested: query=%r, mailto=%s, max_results=%d", + query, + "" if openalex_email else None, + max_results_int, + ) + + # Create search request with converted types + request = SearchRequest( + query=query, + author=author, + institution=institution, + source=source, + year_from=year_from_int, + year_to=year_to_int, + is_oa=is_oa_bool, + work_type=work_type, + language=language, + is_retracted=is_retracted_bool, + has_abstract=has_abstract_bool, + has_fulltext=has_fulltext_bool, + sort_by=sort_by, + max_results=max_results_int, + data_sources=data_sources, + ) + + # Execute search + async with WorkDistributor(openalex_email=openalex_email) as distributor: + result = await distributor.search(request) + + if year_warning: + result.setdefault("warnings", []).append(f"⚠️ Year adjusted: {year_warning.strip()}") + if bool_warning_parts: + result.setdefault("warnings", []).append("⚠️ Boolean filter issues: " + " ".join(bool_warning_parts)) + if max_results_warning: + result.setdefault("warnings", []).append(max_results_warning.strip()) + + # Format output + return _format_search_result(request, result, include_abstract_bool) + + except ValueError as e: + logger.warning(f"Literature search validation error: {e}") + return f"❌ Invalid input: {str(e)}" + except httpx.HTTPError as e: + logger.error(f"Literature search network error: {e}", exc_info=True) + return "❌ Network error while contacting literature sources. Please try again later." + except Exception as e: + logger.error(f"Literature search failed: {e}", exc_info=True) + return "❌ Unexpected error during search. Please retry or contact support." + + +def _format_search_result(request: SearchRequest, result: dict[str, Any], include_abstract: bool = False) -> str: + """ + Format search results into human-readable report + JSON data + + Args: + request: Original search request + result: Search result from WorkDistributor + include_abstract: Whether to include abstracts in JSON (default: False to save tokens) + + Returns: + Formatted markdown report with embedded JSON + """ + works = result["works"] + + # Build report sections + sections: list[str] = ["# Literature Search Report\n"] + + # Warnings and resolution status (if any) + if warnings := result.get("warnings", []): + sections.extend(["## ⚠️ Warnings and Resolution Status\n", *warnings, ""]) + + # Search conditions + conditions: list[str] = [ + f"- **Query**: {request.query}", + *([f"- **Author**: {request.author}"] if request.author else []), + *([f"- **Institution**: {request.institution}"] if request.institution else []), + *([f"- **Source**: {request.source}"] if request.source else []), + *( + [f"- **Year Range**: {request.year_from or '...'} - {request.year_to or '...'}"] + if request.year_from or request.year_to + else [] + ), + *([f"- **Open Access Only**: {'Yes' if request.is_oa else 'No'}"] if request.is_oa is not None else []), + *([f"- **Work Type**: {request.work_type}"] if request.work_type else []), + *([f"- **Language**: {request.language}"] if request.language else []), + *( + [f"- **Exclude Retracted**: {'No' if request.is_retracted else 'Yes'}"] + if request.is_retracted is not None + else [] + ), + *( + [f"- **Require Abstract**: {'Yes' if request.has_abstract else 'No'}"] + if request.has_abstract is not None + else [] + ), + *( + [f"- **Require Full Text**: {'Yes' if request.has_fulltext else 'No'}"] + if request.has_fulltext is not None + else [] + ), + f"- **Sort By**: {request.sort_by}", + f"- **Max Results**: {request.max_results}", + ] + sections.extend(["## Search Conditions\n", "\n".join(conditions), ""]) + + # Check if no results + if not works: + sections.extend(["## ❌ No Results Found\n", "**Suggestions to improve your search:**\n"]) + suggestions: list[str] = [ + "1. **Simplify keywords**: Try broader or different terms", + *(["2. **Remove author filter**: Author name may not be recognized"] if request.author else []), + *(["3. **Remove institution filter**: Try without institution constraint"] if request.institution else []), + *(["4. **Remove source filter**: Try without journal constraint"] if request.source else []), + *( + ["5. **Expand year range**: Current range may be too narrow"] + if request.year_from or request.year_to + else [] + ), + *(["6. **Remove open access filter**: Include non-OA papers"] if request.is_oa else []), + "7. **Check spelling**: Verify all terms are spelled correctly", + ] + sections.extend(["\n".join(suggestions), ""]) + return "\n".join(sections) + + # Statistics and overall insights + total_count = result["total_count"] + unique_count = result["unique_count"] + sources = result["sources"] + + stats: list[str] = [ + f"- **Total Found**: {total_count} works", + f"- **After Deduplication**: {unique_count} works", + ] + source_info = ", ".join(f"{name}: {count}" for name, count in sources.items()) + stats.append(f"- **Data Sources**: {source_info}") + + # Add insights + avg_citations = sum(w.cited_by_count for w in works) / len(works) + stats.append(f"- **Average Citations**: {avg_citations:.1f}") + + oa_count = sum(w.is_oa for w in works) + oa_ratio = (oa_count / len(works)) * 100 + stats.append(f"- **Open Access Rate**: {oa_ratio:.1f}% ({oa_count}/{len(works)})") + + if years := [w.publication_year for w in works if w.publication_year]: + stats.append(f"- **Year Range**: {min(years)} - {max(years)}") + + sections.extend(["## Search Statistics\n", "\n".join(stats), ""]) + + # Complete JSON list + sections.extend( + [ + "## Complete Works List (JSON)\n", + "The following JSON contains all works with full abstracts:\n" + if include_abstract + else "The following JSON contains all works (abstracts excluded to save tokens):\n", + "```json", + ] + ) + + # Convert works to dict for JSON serialization + works_dict = [] + for work in works: + work_data = { + "id": work.id, + "doi": work.doi, + "title": work.title, + "authors": work.authors[:5], # Limit to first 5 authors + "publication_year": work.publication_year, + "cited_by_count": work.cited_by_count, + "journal": work.journal, + "primary_institution": work.primary_institution, + "is_oa": work.is_oa, + "access_url": work.access_url, + "source": work.source, + } + # Only include abstract if requested + if include_abstract and work.abstract: + work_data["abstract"] = work.abstract + works_dict.append(work_data) + + sections.extend([json.dumps(works_dict, indent=2, ensure_ascii=False), "```", ""]) + + # Next steps guidance - prevent infinite loops + sections.extend(["---", "## 🎯 Next Steps Guide\n", "**Before making another search, consider:**\n"]) + next_steps: list[str] = [ + *(["✓ **Results found** - Review the JSON data above for your analysis"] if unique_count > 0 else []), + *( + [ + f"⚠️ **Result limit reached** ({request.max_results}) - " + "Consider narrowing filters (author, year, journal) for more targeted results" + ] + if unique_count >= request.max_results + else [] + ), + *( + ["💡 **Few results** - Consider broadening your search by removing some filters"] + if 0 < unique_count < 10 + else [] + ), + "", + "**To refine your search:**", + "- If too many results → Add more specific filters (author, institution, journal, year)", + "- If too few results → Remove filters or use broader keywords", + "- If wrong results → Check filter spelling and try variations", + "", + "⚠️ **Important**: Avoid making multiple similar searches without reviewing results first!", + "Each search consumes API quota and context window. Make targeted, deliberate queries.", + ] + + sections.append("\n".join(next_steps)) + + return "\n".join(sections) diff --git a/service/app/schemas/chat_event_payloads.py b/service/app/schemas/chat_event_payloads.py index 7ad6b913..4f13d652 100644 --- a/service/app/schemas/chat_event_payloads.py +++ b/service/app/schemas/chat_event_payloads.py @@ -92,7 +92,8 @@ class ToolCallResponseData(TypedDict): toolCallId: str status: str - result: str + result: str # Formatted result for display + raw_result: NotRequired[str | dict | list] # Raw result for cost calculation error: NotRequired[str] diff --git a/service/app/tasks/chat.py b/service/app/tasks/chat.py index d6c2993a..57143ce2 100644 --- a/service/app/tasks/chat.py +++ b/service/app/tasks/chat.py @@ -23,6 +23,7 @@ from app.repos.session import SessionRepository from app.schemas.chat_event_payloads import CitationData from app.schemas.chat_event_types import ChatEventType +from app.tools.cost import calculate_tool_cost logger = logging.getLogger(__name__) @@ -177,6 +178,10 @@ async def _process_chat_message_async( output_tokens: int = 0 total_tokens: int = 0 + # Tool cost tracking + tool_costs_total = 0 + tool_call_data: dict[str, dict[str, Any]] = {} # tool_call_id -> {name, args} + # Agent run tracking (for new timeline-based persistence) agent_run_id: UUID | None = None agent_run_start_time: float | None = None @@ -305,9 +310,24 @@ async def _process_chat_message_async( await publisher.publish(json.dumps(stream_event)) elif stream_event["type"] == ChatEventType.TOOL_CALL_REQUEST: + # Store tool call data for cost calculation + req = stream_event["data"] + tool_call_id = req.get("id") + tool_name = req.get("name", "") + if tool_call_id: + # Parse arguments (may be JSON string) + raw_args = req.get("arguments", {}) + if isinstance(raw_args, str): + try: + parsed_args = json.loads(raw_args) + except json.JSONDecodeError: + parsed_args = {} + else: + parsed_args = raw_args or {} + tool_call_data[tool_call_id] = {"name": tool_name, "args": parsed_args} + # Persist tool call request try: - req = stream_event["data"] tool_message = MessageCreate( role="tool", content=json.dumps( @@ -331,11 +351,42 @@ async def _process_chat_message_async( await publisher.publish(json.dumps(stream_event)) elif stream_event["type"] == ChatEventType.TOOL_CALL_RESPONSE: + resp = stream_event["data"] + tool_call_id = resp.get("toolCallId") + + # Calculate tool cost using stored data from TOOL_CALL_REQUEST + if tool_call_id and tool_call_id in tool_call_data: + stored = tool_call_data[tool_call_id] + tool_name = stored.get("name", "") + args = stored.get("args", {}) + # Use raw_result for cost calculation (unformatted) + result = resp.get("raw_result") + # Parse result if it's a JSON string + if isinstance(result, str): + try: + result = json.loads(result) + except json.JSONDecodeError: + result = None + # Only dict results are supported for cost calculation + if not isinstance(result, dict): + result = None + + # Only charge for successful tool executions + tool_failed = ( + resp.get("status") == "error" + or resp.get("error") is not None + or (isinstance(result, dict) and result.get("success") is False) + ) + if tool_failed: + logger.info(f"Tool {tool_name} failed, not charging") + else: + cost = calculate_tool_cost(tool_name, args, result) + if cost > 0: + tool_costs_total += cost + logger.info(f"Tool {tool_name} cost: {cost} (total: {tool_costs_total})") + # Persist tool call response try: - resp = stream_event["data"] - tool_call_id = resp.get("toolCallId") - # Only persist if toolCallId is valid - skip otherwise if not tool_call_id or not isinstance(tool_call_id, str): logger.warning( @@ -601,7 +652,7 @@ async def _process_chat_message_async( output_tokens=output_tokens, total_tokens=total_tokens, content_length=len(full_content), - generated_files_count=generated_files_count, + tool_costs=tool_costs_total, ) result = ConsumptionCalculator.calculate(consume_context) total_cost = result.amount diff --git a/service/app/tools/__init__.py b/service/app/tools/__init__.py index b5e85ebe..108793ec 100644 --- a/service/app/tools/__init__.py +++ b/service/app/tools/__init__.py @@ -17,7 +17,7 @@ Tool Categories: | Category | Tools | UI Toggle | Auto-enabled | |------------|---------------------------|-----------|--------------| -| search | web_search | Yes | - | +| search | web_search, web_fetch | Yes | - | | knowledge | knowledge_* | No | Yes (with knowledge_set) | | image | generate_image, read_image| Yes | - | | research | think, ConductResearch | No | Component-internal | diff --git a/service/app/tools/builtin/__init__.py b/service/app/tools/builtin/__init__.py index e83066c1..0b2c48e0 100644 --- a/service/app/tools/builtin/__init__.py +++ b/service/app/tools/builtin/__init__.py @@ -13,6 +13,7 @@ - research: Deep research workflow tools (component-internal, not exported here) """ +from app.tools.builtin.fetch import create_web_fetch_tool from app.tools.builtin.image import create_image_tools, create_image_tools_for_agent from app.tools.builtin.knowledge import create_knowledge_tools, create_knowledge_tools_for_agent from app.tools.builtin.memory import create_memory_tools, create_memory_tools_for_agent @@ -21,6 +22,8 @@ __all__ = [ # Search "create_web_search_tool", + # Fetch + "create_web_fetch_tool", # Knowledge "create_knowledge_tools", "create_knowledge_tools_for_agent", diff --git a/service/app/tools/builtin/fetch.py b/service/app/tools/builtin/fetch.py new file mode 100644 index 00000000..ff5c69c2 --- /dev/null +++ b/service/app/tools/builtin/fetch.py @@ -0,0 +1,165 @@ +""" +Web Fetch Tool + +LangChain tool for fetching and extracting content from web pages using Trafilatura. +Extracts clean text/markdown content from HTML pages with metadata extraction. +""" + +from __future__ import annotations + +import logging +from typing import Any, Literal + +import trafilatura +from langchain_core.tools import BaseTool, StructuredTool +from pydantic import BaseModel, Field +from trafilatura.settings import use_config + +logger = logging.getLogger(__name__) + + +class WebFetchInput(BaseModel): + """Input schema for web fetch tool.""" + + url: str = Field(description="The URL of the web page to fetch and extract content from.") + output_format: Literal["markdown", "text"] = Field( + default="markdown", + description="Output format: 'markdown' for structured content, 'text' for plain text.", + ) + include_links: bool = Field( + default=True, + description="Whether to include hyperlinks in the extracted content.", + ) + include_images: bool = Field( + default=False, + description="Whether to include image references in the output.", + ) + timeout: int = Field( + default=30, + ge=5, + le=120, + description="Request timeout in seconds.", + ) + + +async def _web_fetch( + url: str, + output_format: Literal["markdown", "text"] = "markdown", + include_links: bool = True, + include_images: bool = False, + timeout: int = 30, +) -> dict[str, Any]: + """ + Fetch and extract content from a web page. + + Uses Trafilatura for robust HTML content extraction and conversion + to clean markdown or plain text. + + Returns: + A dictionary containing: + - success: Boolean indicating success + - url: The original URL + - title: Page title if available + - author: Author if available + - date: Publication date if available + - content: Extracted markdown/text content + - error: Error message if failed + """ + if not url.strip(): + return { + "success": False, + "error": "URL cannot be empty", + "url": url, + "title": None, + "author": None, + "date": None, + "content": None, + } + + # Configure trafilatura + config = use_config() + config.set("DEFAULT", "EXTRACTION_TIMEOUT", str(timeout)) + + try: + # Fetch the page + downloaded = trafilatura.fetch_url(url) + if downloaded is None: + return { + "success": False, + "error": "Failed to fetch URL - the page may be unavailable or blocked", + "url": url, + "title": None, + "author": None, + "date": None, + "content": None, + } + + # Extract content + content = trafilatura.extract( + downloaded, + output_format="markdown" if output_format == "markdown" else "txt", + include_links=include_links, + include_images=include_images, + include_comments=False, + ) + + if content is None: + return { + "success": False, + "error": "Failed to extract content from page - the page may have no readable content", + "url": url, + "title": None, + "author": None, + "date": None, + "content": None, + } + + # Extract metadata + metadata = trafilatura.extract_metadata(downloaded) + + logger.info(f"Web fetch completed: '{url}' extracted {len(content)} characters") + + return { + "success": True, + "url": url, + "title": metadata.title if metadata else None, + "author": metadata.author if metadata else None, + "date": metadata.date if metadata else None, + "content": content, + } + + except Exception as e: + error_msg = f"Fetch failed: {e!s}" + logger.error(f"Web fetch error for '{url}': {error_msg}") + return { + "success": False, + "error": error_msg, + "url": url, + "title": None, + "author": None, + "date": None, + "content": None, + } + + +def create_web_fetch_tool() -> BaseTool: + """ + Create the web fetch tool. + + Returns: + StructuredTool for web page content extraction. + """ + return StructuredTool( + name="web_fetch", + description=( + "Fetch and extract content from a web page. " + "Converts HTML to clean markdown or plain text, removing ads, navigation, and boilerplate. " + "Also extracts metadata like title, author, and publication date when available. " + "Use this when you need to read the full content of a specific web page." + ), + args_schema=WebFetchInput, + coroutine=_web_fetch, + ) + + +__all__ = ["create_web_fetch_tool", "WebFetchInput"] diff --git a/service/app/tools/builtin/image.py b/service/app/tools/builtin/image.py index 94d23c81..f9c10061 100644 --- a/service/app/tools/builtin/image.py +++ b/service/app/tools/builtin/image.py @@ -14,7 +14,7 @@ from uuid import UUID, uuid4 from langchain_core.tools import BaseTool, StructuredTool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from app.configs import configs from app.core.storage import FileScope, generate_storage_key, get_storage_service @@ -24,6 +24,9 @@ # --- Input Schemas --- +# Maximum number of reference images allowed for generation +MAX_INPUT_IMAGES = 4 + class GenerateImageInput(BaseModel): """Input schema for generate_image tool.""" @@ -35,14 +38,24 @@ class GenerateImageInput(BaseModel): default="1:1", description="Aspect ratio of the generated image.", ) - image_id: str | None = Field( + image_ids: list[str] | None = Field( default=None, description=( - "Optional image UUID to use as a reference input. " - "Use the 'image_id' value returned from generate_image or upload tools." + f"Optional list of image UUIDs (max {MAX_INPUT_IMAGES}) to use as reference inputs. " + "Use the 'image_id' values returned from generate_image or upload tools." ), ) + @model_validator(mode="after") + def validate_image_inputs(self) -> "GenerateImageInput": + """Validate image_ids field.""" + if self.image_ids: + if len(self.image_ids) > MAX_INPUT_IMAGES: + raise ValueError(f"Maximum {MAX_INPUT_IMAGES} input images allowed, got {len(self.image_ids)}") + if len(self.image_ids) == 0: + self.image_ids = None # Normalize empty list to None + return self + class ReadImageInput(BaseModel): """Input schema for read_image tool.""" @@ -62,8 +75,7 @@ class ReadImageInput(BaseModel): async def _generate_image_with_langchain( prompt: str, aspect_ratio: str = "1:1", - image_bytes: bytes | None = None, - image_mime_type: str | None = None, + images: list[tuple[bytes, str]] | None = None, ) -> tuple[bytes, str]: """ Generate an image using LangChain ChatGoogleGenerativeAI via ProviderManager. @@ -76,6 +88,7 @@ async def _generate_image_with_langchain( Args: prompt: Text description of the image to generate aspect_ratio: Aspect ratio for the generated image + images: Optional list of (image_bytes, mime_type) tuples to use as references Returns: Tuple of (image_bytes, mime_type) @@ -102,25 +115,34 @@ async def _generate_image_with_langchain( ) # Request image generation via LangChain - if image_bytes and image_mime_type: - b64_data = base64.b64encode(image_bytes).decode("utf-8") - message = HumanMessage( - content=[ + if images: + # Build content array with multiple image_url blocks + content: list[dict[str, Any]] = [] + for image_bytes, image_mime_type in images: + b64_data = base64.b64encode(image_bytes).decode("utf-8") + content.append( { "type": "image_url", "image_url": { "url": f"data:{image_mime_type};base64,{b64_data}", }, - }, - { - "type": "text", - "text": ( - "Use the provided image as a reference. " - f"Generate a new image with aspect ratio {aspect_ratio}: {prompt}" - ), - }, - ] + } + ) + + # Add text prompt with appropriate phrasing for single vs multiple images + image_count = len(images) + if image_count == 1: + reference_text = "Use the provided image as a reference." + else: + reference_text = f"Use these {image_count} provided images as references." + + content.append( + { + "type": "text", + "text": f"{reference_text} Generate a new image with aspect ratio {aspect_ratio}: {prompt}", + } ) + message = HumanMessage(content=content) # type: ignore[arg-type] else: message = HumanMessage(content=f"Generate an image with aspect ratio {aspect_ratio}: {prompt}") response = await llm.ainvoke([message]) @@ -172,46 +194,67 @@ async def _generate_image_with_langchain( raise ValueError("No image data in response. Model may not support image generation.") -async def _load_image_for_generation(user_id: str, image_id: str) -> tuple[bytes, str, str]: +async def _load_images_for_generation(user_id: str, image_ids: list[str]) -> list[tuple[bytes, str, str]]: + """ + Load multiple images for generation from the database. + + Args: + user_id: User ID for permission check + image_ids: List of image UUIDs to load + + Returns: + List of tuples: (image_bytes, mime_type, storage_key) + + Raises: + ValueError: If any image_id is invalid, not found, deleted, or inaccessible + """ from app.infra.database import create_task_session_factory from app.repos.file import FileRepository - try: - file_uuid = UUID(image_id) - except ValueError as exc: - raise ValueError(f"Invalid image_id format: {image_id}") from exc + results: list[tuple[bytes, str, str]] = [] # Create a fresh session factory for the current event loop (Celery worker) TaskSessionLocal = create_task_session_factory() async with TaskSessionLocal() as db: file_repo = FileRepository(db) - file_record = await file_repo.get_file_by_id(file_uuid) + storage = get_storage_service() - if file_record is None: - raise ValueError(f"Image not found: {image_id}") + for image_id in image_ids: + try: + file_uuid = UUID(image_id) + except ValueError as exc: + raise ValueError(f"Invalid image_id format: {image_id}") from exc - if file_record.is_deleted: - raise ValueError(f"Image has been deleted: {image_id}") + file_record = await file_repo.get_file_by_id(file_uuid) - if file_record.user_id != user_id and file_record.scope != "public": - raise ValueError("Permission denied: you don't have access to this image") + if file_record is None: + raise ValueError(f"Image not found: {image_id}") - storage_key = file_record.storage_key - content_type = file_record.content_type or "image/png" + if file_record.is_deleted: + raise ValueError(f"Image has been deleted: {image_id}") - storage = get_storage_service() - buffer = io.BytesIO() - await storage.download_file(storage_key, buffer) - image_bytes = buffer.getvalue() - return image_bytes, content_type, storage_key + if file_record.user_id != user_id and file_record.scope != "public": + raise ValueError(f"Permission denied: you don't have access to image {image_id}") + + storage_key = file_record.storage_key + content_type = file_record.content_type or "image/png" + + # Download from storage + buffer = io.BytesIO() + await storage.download_file(storage_key, buffer) + image_bytes = buffer.getvalue() + + results.append((image_bytes, content_type, storage_key)) + + return results async def _generate_image( user_id: str, prompt: str, aspect_ratio: str = "1:1", - image_id: str | None = None, + image_ids: list[str] | None = None, ) -> dict[str, Any]: """ Generate an image and store it to OSS, then register in database. @@ -220,28 +263,27 @@ async def _generate_image( user_id: User ID for storage organization prompt: Image description aspect_ratio: Aspect ratio for the image + image_ids: Optional list of image UUIDs to use as reference inputs Returns: Dictionary with success status, path, URL, and metadata """ try: - # Load optional reference image - source_image_bytes = None - source_mime_type = None - source_storage_key = None - source_image_id = image_id - if source_image_id: - source_image_bytes, source_mime_type, source_storage_key = await _load_image_for_generation( - user_id, - source_image_id, - ) + # Load optional reference images + images_for_generation: list[tuple[bytes, str]] | None = None + source_storage_keys: list[str] = [] + source_image_ids: list[str] = image_ids or [] + + if source_image_ids: + loaded_images = await _load_images_for_generation(user_id, source_image_ids) + images_for_generation = [(img[0], img[1]) for img in loaded_images] + source_storage_keys = [img[2] for img in loaded_images] # Generate image using LangChain via ProviderManager image_bytes, mime_type = await _generate_image_with_langchain( prompt, aspect_ratio, - image_bytes=source_image_bytes, - image_mime_type=source_mime_type, + images=images_for_generation, ) # Determine file extension from mime type @@ -290,27 +332,27 @@ async def _generate_image( metainfo={ "prompt": prompt, "aspect_ratio": aspect_ratio, - "source_image_id": source_image_id, - "source_storage_key": source_storage_key, + "source_image_ids": source_image_ids, + "source_storage_keys": source_storage_keys, }, ) file_record = await file_repo.create_file(file_data) await db.commit() # Refresh to get the generated UUID await db.refresh(file_record) - image_id = str(file_record.id) + generated_image_id = str(file_record.id) - logger.info(f"Generated image for user {user_id}: {storage_key} (id={image_id})") + logger.info(f"Generated image for user {user_id}: {storage_key} (id={generated_image_id})") return { "success": True, - "image_id": image_id, + "image_id": generated_image_id, "path": storage_key, "url": url, "markdown": f"![Generated Image]({url})", "prompt": prompt, "aspect_ratio": aspect_ratio, - "source_image_id": source_image_id, + "source_image_ids": source_image_ids, "mime_type": mime_type, "size_bytes": len(image_bytes), } @@ -511,7 +553,7 @@ def create_image_tools() -> dict[str, BaseTool]: async def generate_image_placeholder( prompt: str, aspect_ratio: str = "1:1", - image_id: str | None = None, + image_ids: list[str] | None = None, ) -> dict[str, Any]: return {"error": "Image tools require agent context binding", "success": False} @@ -520,8 +562,9 @@ async def generate_image_placeholder( description=( "Generate an image based on a text description. " "Provide a detailed prompt describing the desired image. " - "To modify or generate based on a previous image, pass the 'image_id' from a previous generate_image result. " - "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image." + f"To generate based on previous images, pass 'image_ids' with up to {MAX_INPUT_IMAGES} reference image UUIDs. " + "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image. " + "TIP: You can use 'image_id' values when creating PPTX presentations with knowledge_write - see knowledge_help(topic='image_slides') for details." ), args_schema=GenerateImageInput, coroutine=generate_image_placeholder, @@ -563,9 +606,9 @@ def create_image_tools_for_agent(user_id: str) -> list[BaseTool]: async def generate_image_bound( prompt: str, aspect_ratio: str = "1:1", - image_id: str | None = None, + image_ids: list[str] | None = None, ) -> dict[str, Any]: - return await _generate_image(user_id, prompt, aspect_ratio, image_id) + return await _generate_image(user_id, prompt, aspect_ratio, image_ids) tools.append( StructuredTool( @@ -573,8 +616,9 @@ async def generate_image_bound( description=( "Generate an image based on a text description. " "Provide a detailed prompt describing the desired image including style, colors, composition, and subject. " - "To modify or generate based on a previous image, pass the 'image_id' from a previous generate_image result. " - "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image to the user." + f"To generate based on previous images, pass 'image_ids' with up to {MAX_INPUT_IMAGES} reference image UUIDs. " + "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image to the user. " + "TIP: You can use 'image_id' values when creating beautiful PPTX presentations with knowledge_write in image_slides mode - call knowledge_help(topic='image_slides') for the full workflow." ), args_schema=GenerateImageInput, coroutine=generate_image_bound, @@ -609,4 +653,5 @@ async def read_image_bound( "create_image_tools_for_agent", "GenerateImageInput", "ReadImageInput", + "MAX_INPUT_IMAGES", ] diff --git a/service/app/tools/builtin/knowledge.py b/service/app/tools/builtin/knowledge.py deleted file mode 100644 index 4e656dc9..00000000 --- a/service/app/tools/builtin/knowledge.py +++ /dev/null @@ -1,491 +0,0 @@ -""" -Knowledge Base Tools - -LangChain tools for knowledge base file operations. -These tools require runtime context (user_id, knowledge_set_id) to function. - -Unlike web search which works context-free, knowledge tools are created per-agent -with the agent's knowledge_set_id bound at creation time. -""" - -from __future__ import annotations - -import io -import logging -import mimetypes -from datetime import datetime, timezone -from typing import Any -from uuid import UUID - -from langchain_core.tools import BaseTool, StructuredTool -from pydantic import BaseModel, Field -from sqlmodel.ext.asyncio.session import AsyncSession - -from app.core.storage import FileCategory, FileScope, generate_storage_key, get_storage_service -from app.infra.database import AsyncSessionLocal -from app.models.file import FileCreate -from app.repos.file import FileRepository -from app.repos.knowledge_set import KnowledgeSetRepository - -logger = logging.getLogger(__name__) - - -# --- Input Schemas --- - - -class KnowledgeListFilesInput(BaseModel): - """Input schema for list_files tool - no parameters needed.""" - - pass - - -class KnowledgeReadFileInput(BaseModel): - """Input schema for read_file tool.""" - - filename: str = Field( - description=( - "The name of the file to read from the knowledge base. " - "Supported formats: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, " - "images (PNG/JPG/GIF/WEBP with OCR), and plain text files." - ) - ) - - -class KnowledgeWriteFileInput(BaseModel): - """Input schema for write_file tool.""" - - filename: str = Field( - description=( - "The name of the file to create or update. Use appropriate extensions: " - ".txt, .md (plain text), .pdf (PDF document), .docx (Word), " - ".xlsx (Excel), .pptx (PowerPoint), .json, .yaml, .xml, .html." - ) - ) - content: str = Field( - description=( - "The content to write. Can be plain text (creates simple documents) or " - "a JSON specification for production-quality documents:\n\n" - "**For PDF/DOCX (DocumentSpec JSON):**\n" - '{"title": "My Report", "author": "Name", "content": [\n' - ' {"type": "heading", "content": "Section 1", "level": 1},\n' - ' {"type": "text", "content": "Paragraph text here"},\n' - ' {"type": "list", "items": ["Item 1", "Item 2"], "ordered": false},\n' - ' {"type": "table", "headers": ["Col1", "Col2"], "rows": [["A", "B"]]},\n' - ' {"type": "page_break"}\n' - "]}\n\n" - "**For XLSX (SpreadsheetSpec JSON):**\n" - '{"sheets": [{"name": "Data", "headers": ["Name", "Value"], ' - '"data": [["A", 1], ["B", 2]], "freeze_header": true}]}\n\n' - "**For PPTX (PresentationSpec JSON):**\n" - '{"title": "My Presentation", "slides": [\n' - ' {"layout": "title", "title": "Welcome", "subtitle": "Intro"},\n' - ' {"layout": "title_content", "title": "Slide 2", ' - '"content": [{"type": "list", "items": ["Point 1", "Point 2"]}], ' - '"notes": "Speaker notes here"}\n' - "]}" - ) - ) - - -class KnowledgeSearchFilesInput(BaseModel): - """Input schema for search_files tool.""" - - query: str = Field(description="Search term to find files by name.") - - -# --- Helper Functions --- - - -async def _get_files_in_knowledge_set(db: AsyncSession, user_id: str, knowledge_set_id: UUID) -> list[UUID]: - """Get all file IDs in a knowledge set.""" - knowledge_set_repo = KnowledgeSetRepository(db) - - # Validate access - try: - await knowledge_set_repo.validate_access(user_id, knowledge_set_id) - except ValueError as e: - raise ValueError(f"Access denied: {e}") - - # Get file IDs - file_ids = await knowledge_set_repo.get_files_in_knowledge_set(knowledge_set_id) - return file_ids - - -# --- Tool Implementation Functions --- - - -async def _list_files(user_id: str, knowledge_set_id: UUID) -> dict[str, Any]: - """List all files in the knowledge set.""" - try: - async with AsyncSessionLocal() as db: - file_repo = FileRepository(db) - - try: - file_ids = await _get_files_in_knowledge_set(db, user_id, knowledge_set_id) - except ValueError as e: - return {"error": str(e), "success": False} - - # Fetch file objects - files = [] - for file_id in file_ids: - file = await file_repo.get_file_by_id(file_id) - if file and not file.is_deleted: - files.append(file) - - # Format output - entries: list[str] = [] - for f in files: - entries.append(f"[FILE] {f.original_filename} (ID: {f.id})") - - return { - "success": True, - "knowledge_set_id": str(knowledge_set_id), - "entries": entries, - "count": len(entries), - } - - except Exception as e: - logger.error(f"Error listing files: {e}") - return {"error": f"Internal error: {e!s}", "success": False} - - -async def _read_file(user_id: str, knowledge_set_id: UUID, filename: str) -> dict[str, Any]: - """Read content of a file from the knowledge set.""" - from app.mcp.file_handlers import FileHandlerFactory - - try: - # Normalize filename - filename = filename.strip("/").split("/")[-1] - - async with AsyncSessionLocal() as db: - file_repo = FileRepository(db) - target_file = None - - try: - file_ids = await _get_files_in_knowledge_set(db, user_id, knowledge_set_id) - except ValueError as e: - return {"error": str(e), "success": False} - - # Find file by name - for file_id in file_ids: - file = await file_repo.get_file_by_id(file_id) - if file and file.original_filename == filename and not file.is_deleted: - target_file = file - break - - if not target_file: - return {"error": f"File '{filename}' not found in knowledge set.", "success": False} - - # Download content - storage = get_storage_service() - buffer = io.BytesIO() - await storage.download_file(target_file.storage_key, buffer) - file_bytes = buffer.getvalue() - - # Use handler to process content (text mode only for LangChain tools) - handler = FileHandlerFactory.get_handler(target_file.original_filename) - - try: - result = handler.read_content(file_bytes, mode="text") - return { - "success": True, - "filename": target_file.original_filename, - "content": result, - "size_bytes": target_file.file_size, - } - except Exception as e: - return {"error": f"Error parsing file: {e!s}", "success": False} - - except Exception as e: - logger.error(f"Error reading file: {e}") - return {"error": f"Internal error: {e!s}", "success": False} - - -async def _write_file(user_id: str, knowledge_set_id: UUID, filename: str, content: str) -> dict[str, Any]: - """Create or update a file in the knowledge set.""" - from app.mcp.file_handlers import FileHandlerFactory - - try: - filename = filename.strip("/").split("/")[-1] - - async with AsyncSessionLocal() as db: - file_repo = FileRepository(db) - knowledge_set_repo = KnowledgeSetRepository(db) - storage = get_storage_service() - - try: - file_ids = await _get_files_in_knowledge_set(db, user_id, knowledge_set_id) - except ValueError as e: - return {"error": str(e), "success": False} - - # Check if file exists - existing_file = None - for file_id in file_ids: - file = await file_repo.get_file_by_id(file_id) - if file and file.original_filename == filename and not file.is_deleted: - existing_file = file - break - - # Determine content type - content_type, _ = mimetypes.guess_type(filename) - if not content_type: - if filename.endswith(".docx"): - content_type = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - elif filename.endswith(".xlsx"): - content_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" - elif filename.endswith(".pptx"): - content_type = "application/vnd.openxmlformats-officedocument.presentationml.presentation" - elif filename.endswith(".pdf"): - content_type = "application/pdf" - else: - content_type = "text/plain" - - # Use handler to create content bytes - handler = FileHandlerFactory.get_handler(filename) - encoded_content = handler.create_content(content) - - new_key = generate_storage_key(user_id, filename, FileScope.PRIVATE) - data = io.BytesIO(encoded_content) - file_size_bytes = len(encoded_content) - - await storage.upload_file(data, new_key, content_type=content_type) - - if existing_file: - # Update existing - existing_file.storage_key = new_key - existing_file.file_size = file_size_bytes - existing_file.content_type = content_type - existing_file.updated_at = datetime.now(timezone.utc) - db.add(existing_file) - await db.commit() - return {"success": True, "message": f"Updated file: {filename}"} - else: - # Create new and link - new_file = FileCreate( - user_id=user_id, - folder_id=None, - original_filename=filename, - storage_key=new_key, - file_size=file_size_bytes, - content_type=content_type, - scope=FileScope.PRIVATE, - category=FileCategory.DOCUMENT, - ) - created_file = await file_repo.create_file(new_file) - await knowledge_set_repo.link_file_to_knowledge_set(created_file.id, knowledge_set_id) - await db.commit() - return {"success": True, "message": f"Created file: {filename}"} - - except Exception as e: - logger.error(f"Error writing file: {e}") - return {"error": f"Internal error: {e!s}", "success": False} - - -async def _search_files(user_id: str, knowledge_set_id: UUID, query: str) -> dict[str, Any]: - """Search for files by name in the knowledge set.""" - try: - async with AsyncSessionLocal() as db: - file_repo = FileRepository(db) - matches: list[str] = [] - - try: - file_ids = await _get_files_in_knowledge_set(db, user_id, knowledge_set_id) - except ValueError as e: - return {"error": str(e), "success": False} - - for file_id in file_ids: - file = await file_repo.get_file_by_id(file_id) - if file and not file.is_deleted and query.lower() in file.original_filename.lower(): - matches.append(f"{file.original_filename} (ID: {file.id})") - - return { - "success": True, - "query": query, - "matches": matches, - "count": len(matches), - } - - except Exception as e: - logger.error(f"Error searching files: {e}") - return {"error": f"Internal error: {e!s}", "success": False} - - -# --- Tool Factory --- - - -def create_knowledge_tools() -> dict[str, BaseTool]: - """ - Create knowledge tools with placeholder implementations. - - Note: Knowledge tools require runtime context (user_id, knowledge_set_id). - The actual tool instances are created per-agent with context bound. - This function returns template tools for the registry. - - Returns: - Dict mapping tool_id to BaseTool placeholder instances. - """ - # These are placeholder tools - actual execution requires context binding - # See create_knowledge_tools_for_agent() for runtime creation - - tools: dict[str, BaseTool] = {} - - # List files tool - async def list_files_placeholder() -> dict[str, Any]: - return {"error": "Knowledge tools require agent context binding", "success": False} - - tools["knowledge_list"] = StructuredTool( - name="knowledge_list", - description=( - "List all files in the agent's knowledge base. Returns a list of filenames " - "that can be read or searched. Use this first to discover available files." - ), - args_schema=KnowledgeListFilesInput, - coroutine=list_files_placeholder, - ) - - # Read file tool - async def read_file_placeholder(filename: str) -> dict[str, Any]: - return {"error": "Knowledge tools require agent context binding", "success": False} - - tools["knowledge_read"] = StructuredTool( - name="knowledge_read", - description=( - "Read the content of a file from the agent's knowledge base. " - "Supports: PDF (text + tables), DOCX (text + tables), XLSX (all sheets), " - "PPTX (text + speaker notes), HTML (text extraction), JSON/YAML/XML (formatted), " - "images (OCR text extraction from PNG/JPG/GIF/WEBP), and plain text files. " - "Use knowledge_list first to see available files." - ), - args_schema=KnowledgeReadFileInput, - coroutine=read_file_placeholder, - ) - - # Write file tool - async def write_file_placeholder(filename: str, content: str) -> dict[str, Any]: - return {"error": "Knowledge tools require agent context binding", "success": False} - - tools["knowledge_write"] = StructuredTool( - name="knowledge_write", - description=( - "Create or update a file in the agent's knowledge base. " - "Supports: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, and plain text. " - "For production-quality documents (PDF/DOCX/XLSX/PPTX), provide a JSON " - "specification with structured content (headings, lists, tables, etc.) " - "instead of plain text. See content field description for JSON schema examples." - ), - args_schema=KnowledgeWriteFileInput, - coroutine=write_file_placeholder, - ) - - # Search files tool - async def search_files_placeholder(query: str) -> dict[str, Any]: - return {"error": "Knowledge tools require agent context binding", "success": False} - - tools["knowledge_search"] = StructuredTool( - name="knowledge_search", - description=( - "Search for files by name in the agent's knowledge base. Returns matching filenames that can then be read." - ), - args_schema=KnowledgeSearchFilesInput, - coroutine=search_files_placeholder, - ) - - return tools - - -def create_knowledge_tools_for_agent(user_id: str, knowledge_set_id: UUID) -> list[BaseTool]: - """ - Create knowledge tools bound to a specific agent's context. - - This creates actual working tools with user_id and knowledge_set_id - captured in closures. - - Args: - user_id: The user ID for access control - knowledge_set_id: The knowledge set ID to operate on - - Returns: - List of BaseTool instances with context bound - """ - tools: list[BaseTool] = [] - - # List files tool - async def list_files_bound() -> dict[str, Any]: - return await _list_files(user_id, knowledge_set_id) - - tools.append( - StructuredTool( - name="knowledge_list", - description=( - "List all files in your knowledge base. Returns filenames that can be read or searched. " - "Use this first to discover available files." - ), - args_schema=KnowledgeListFilesInput, - coroutine=list_files_bound, - ) - ) - - # Read file tool - async def read_file_bound(filename: str) -> dict[str, Any]: - return await _read_file(user_id, knowledge_set_id, filename) - - tools.append( - StructuredTool( - name="knowledge_read", - description=( - "Read the content of a file from your knowledge base. " - "Supports: PDF (text + tables), DOCX (text + tables), XLSX (all sheets), " - "PPTX (text + speaker notes), HTML (text extraction), JSON/YAML/XML (formatted), " - "images (OCR text extraction from PNG/JPG/GIF/WEBP), and plain text files. " - "Use knowledge_list first to see available files." - ), - args_schema=KnowledgeReadFileInput, - coroutine=read_file_bound, - ) - ) - - # Write file tool - async def write_file_bound(filename: str, content: str) -> dict[str, Any]: - return await _write_file(user_id, knowledge_set_id, filename, content) - - tools.append( - StructuredTool( - name="knowledge_write", - description=( - "Create or update a file in your knowledge base. " - "Supports: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, and plain text. " - "For production-quality documents (PDF/DOCX/XLSX/PPTX), provide a JSON " - "specification with structured content (headings, lists, tables, etc.) " - "instead of plain text. See content field description for JSON schema examples." - ), - args_schema=KnowledgeWriteFileInput, - coroutine=write_file_bound, - ) - ) - - # Search files tool - async def search_files_bound(query: str) -> dict[str, Any]: - return await _search_files(user_id, knowledge_set_id, query) - - tools.append( - StructuredTool( - name="knowledge_search", - description=( - "Search for files by name in your knowledge base. Returns matching filenames that can then be read." - ), - args_schema=KnowledgeSearchFilesInput, - coroutine=search_files_bound, - ) - ) - - return tools - - -__all__ = [ - "create_knowledge_tools", - "create_knowledge_tools_for_agent", - "KnowledgeListFilesInput", - "KnowledgeReadFileInput", - "KnowledgeWriteFileInput", - "KnowledgeSearchFilesInput", -] diff --git a/service/app/tools/builtin/knowledge/__init__.py b/service/app/tools/builtin/knowledge/__init__.py new file mode 100644 index 00000000..b2187d03 --- /dev/null +++ b/service/app/tools/builtin/knowledge/__init__.py @@ -0,0 +1,30 @@ +""" +Knowledge Base Tools for LangChain Agents. + +This module provides tools for knowledge base file operations. +These tools require runtime context (user_id, knowledge_set_id) to function. + +Unlike web search which works context-free, knowledge tools are created per-agent +with the agent's knowledge_set_id bound at creation time. +""" + +from __future__ import annotations + +from .schemas import ( + KnowledgeHelpInput, + KnowledgeListFilesInput, + KnowledgeReadFileInput, + KnowledgeSearchFilesInput, + KnowledgeWriteFileInput, +) +from .tools import create_knowledge_tools, create_knowledge_tools_for_agent + +__all__ = [ + "create_knowledge_tools", + "create_knowledge_tools_for_agent", + "KnowledgeListFilesInput", + "KnowledgeReadFileInput", + "KnowledgeWriteFileInput", + "KnowledgeSearchFilesInput", + "KnowledgeHelpInput", +] diff --git a/service/app/tools/builtin/knowledge/help_content.py b/service/app/tools/builtin/knowledge/help_content.py new file mode 100644 index 00000000..bb026be4 --- /dev/null +++ b/service/app/tools/builtin/knowledge/help_content.py @@ -0,0 +1,526 @@ +""" +Help content constants for knowledge tools. + +Contains all help text and documentation for knowledge base operations. +""" + +from __future__ import annotations + +from typing import Any + +KNOWLEDGE_HELP_OVERVIEW = """ +# Knowledge Base Tools - Quick Reference + +## Available Tools +- **knowledge_list**: List all files in your knowledge base +- **knowledge_read**: Read content from a file +- **knowledge_write**: Create or update files (supports rich documents) +- **knowledge_search**: Search files by name +- **knowledge_help**: Get detailed usage guides (this tool) + +## Supported File Types +- **Documents**: PDF, DOCX, PPTX, XLSX +- **Data**: JSON, YAML, XML +- **Web**: HTML +- **Text**: TXT, MD, CSV + +## Quick Start +1. Use `knowledge_list` to see available files +2. Use `knowledge_write` with plain text for simple files +3. Use `knowledge_write` with JSON spec for rich documents (call `knowledge_help` with topic='pptx' for examples) + +## Creating Beautiful Presentations with AI Images +For stunning presentations with AI-generated slides: +1. Use `generate_image` to create each slide as an image (use 16:9 aspect ratio) +2. Collect the `image_id` values from each generation +3. Use `knowledge_write` with `mode: "image_slides"` to assemble into PPTX + +Call `knowledge_help(topic='image_slides')` for detailed workflow and examples. + +For detailed help on a specific topic, call knowledge_help with topic='pptx', 'pdf', 'xlsx', 'images', 'tables', 'image_slides', or 'all'. +""" + +KNOWLEDGE_HELP_PPTX = """ +# PPTX (PowerPoint) Generation Guide + +## Basic Structure +```json +{ + "title": "Presentation Title", + "author": "Author Name", + "slides": [ + { "layout": "...", "title": "...", "content": [...] } + ] +} +``` + +## Slide Layouts +- `title` - Title slide with title and subtitle +- `title_content` - Title with content area (most common) +- `section` - Section header +- `two_column` - Two column layout +- `comparison` - Side-by-side comparison +- `title_only` - Title without content placeholder +- `blank` - Empty slide + +## Content Block Types + +### Text +```json +{"type": "text", "content": "Your paragraph text here", "style": "normal"} +``` +Styles: `normal`, `bold`, `italic`, `code` + +### Heading +```json +{"type": "heading", "content": "Section Title", "level": 2} +``` +Levels 1-6 (1 is largest) + +### List +```json +{"type": "list", "items": ["Point 1", "Point 2", "Point 3"], "ordered": false} +``` +Set `ordered: true` for numbered lists + +### Table +```json +{ + "type": "table", + "headers": ["Column 1", "Column 2", "Column 3"], + "rows": [ + ["Row 1 A", "Row 1 B", "Row 1 C"], + ["Row 2 A", "Row 2 B", "Row 2 C"] + ] +} +``` + +### Image +```json +{ + "type": "image", + "url": "https://example.com/chart.png", + "caption": "Figure 1: Sales Chart", + "width": 400 +} +``` +- `url`: HTTP URL, base64 data URL, or storage:// path +- `image_id`: UUID from generate_image tool (alternative to url) +- `caption`: Optional text below image +- `width`: Optional width in points (72 points = 1 inch) + +## Complete Example +```json +{ + "title": "Q4 Business Review", + "slides": [ + { + "layout": "title", + "title": "Q4 2024 Review", + "subtitle": "Sales Department" + }, + { + "layout": "title_content", + "title": "Revenue Summary", + "content": [ + {"type": "heading", "content": "Key Metrics", "level": 2}, + {"type": "table", "headers": ["Region", "Q3", "Q4", "Growth"], "rows": [ + ["North America", "$1.2M", "$1.5M", "+25%"], + ["Europe", "$800K", "$1.1M", "+37%"] + ]}, + {"type": "text", "content": "All regions exceeded targets.", "style": "bold"} + ], + "notes": "Emphasize the European growth story" + }, + { + "layout": "title_content", + "title": "Visual Analysis", + "content": [ + {"type": "image", "url": "https://example.com/chart.png", "caption": "Revenue by Region"} + ] + } + ] +} +``` +""" + +KNOWLEDGE_HELP_PDF_DOCX = """ +# PDF/DOCX Generation Guide + +## Basic Structure +```json +{ + "title": "Document Title", + "author": "Author Name", + "subject": "Document Subject", + "page_size": "letter", + "content": [...] +} +``` + +Page sizes: `letter`, `A4`, `legal` + +## Content Block Types + +### Heading +```json +{"type": "heading", "content": "Chapter Title", "level": 1} +``` + +### Text +```json +{"type": "text", "content": "Paragraph text here", "style": "normal"} +``` +Styles: `normal`, `bold`, `italic`, `code` + +### List +```json +{"type": "list", "items": ["Item 1", "Item 2"], "ordered": false} +``` + +### Table +```json +{ + "type": "table", + "headers": ["Name", "Value"], + "rows": [["Item A", "100"], ["Item B", "200"]] +} +``` + +### Page Break +```json +{"type": "page_break"} +``` + +## Example +```json +{ + "title": "Monthly Report", + "author": "Analytics Team", + "content": [ + {"type": "heading", "content": "Executive Summary", "level": 1}, + {"type": "text", "content": "This report covers..."}, + {"type": "heading", "content": "Key Findings", "level": 2}, + {"type": "list", "items": ["Revenue up 15%", "Costs down 8%"], "ordered": false}, + {"type": "page_break"}, + {"type": "heading", "content": "Detailed Analysis", "level": 1}, + {"type": "table", "headers": ["Metric", "Value"], "rows": [["Sales", "$1.5M"]]} + ] +} +``` +""" + +KNOWLEDGE_HELP_XLSX = """ +# XLSX (Excel) Generation Guide + +## Basic Structure +```json +{ + "sheets": [ + { + "name": "Sheet Name", + "headers": ["Col1", "Col2"], + "data": [[...], [...]], + "freeze_header": true + } + ] +} +``` + +## Sheet Properties +- `name`: Sheet tab name +- `headers`: Optional column headers (styled with blue background) +- `data`: 2D array of cell values (strings, numbers, null) +- `freeze_header`: Freeze the header row for scrolling (default: true) + +## Example: Multi-Sheet Workbook +```json +{ + "sheets": [ + { + "name": "Sales Data", + "headers": ["Product", "Q1", "Q2", "Q3", "Q4", "Total"], + "data": [ + ["Widget A", 100, 150, 200, 250, 700], + ["Widget B", 80, 90, 110, 130, 410], + ["Widget C", 50, 60, 70, 80, 260] + ], + "freeze_header": true + }, + { + "name": "Summary", + "headers": ["Metric", "Value"], + "data": [ + ["Total Revenue", 1370], + ["Average per Product", 456.67], + ["Best Performer", "Widget A"] + ] + } + ] +} +``` +""" + +KNOWLEDGE_HELP_IMAGES = """ +# Image Embedding Guide + +## Supported in PPTX Content Blocks + +### Image Block Structure +```json +{ + "type": "image", + "url": "...", + "caption": "Optional caption text", + "width": 400 +} +``` + +## URL Formats + +### HTTP/HTTPS URLs +```json +{"type": "image", "url": "https://example.com/chart.png"} +``` + +### Base64 Data URLs +```json +{"type": "image", "url": "..."} +``` + +### Storage URLs (internal files) +```json +{"type": "image", "url": "storage://path/to/uploaded/image.png"} +``` + +### Generated Images (from generate_image tool) +```json +{"type": "image", "image_id": "abc-123-456-def"} +``` + +## Size Handling +- **Max file size**: 10MB +- **Max dimension**: 4096px (larger images auto-resized) +- **Width parameter**: Specify in points (72pt = 1 inch) +- **Aspect ratio**: Always preserved + +## Example with Caption +```json +{ + "type": "image", + "url": "https://example.com/quarterly-chart.png", + "caption": "Figure 1: Quarterly Revenue Comparison", + "width": 500 +} +``` + +## Error Handling +If an image fails to load, a placeholder text will appear: +`[Image failed to load: ]` +""" + +KNOWLEDGE_HELP_TABLES = """ +# Table Generation Guide + +## Table Block Structure +```json +{ + "type": "table", + "headers": ["Column 1", "Column 2", "Column 3"], + "rows": [ + ["Row 1 Col 1", "Row 1 Col 2", "Row 1 Col 3"], + ["Row 2 Col 1", "Row 2 Col 2", "Row 2 Col 3"] + ] +} +``` + +## Supported In +- **PPTX**: Styled tables with blue headers +- **PDF**: Formatted tables with borders +- **DOCX**: Word tables with grid style + +## Styling (Automatic) +- Header row: Blue background (#4472C4), white bold text, centered +- Data rows: Standard formatting, left-aligned +- Borders: Thin black borders on all cells + +## Example: Data Table +```json +{ + "type": "table", + "headers": ["Product", "Price", "Stock", "Status"], + "rows": [ + ["Laptop Pro", "$1,299", "45", "In Stock"], + ["Tablet Air", "$799", "120", "In Stock"], + ["Phone Max", "$999", "0", "Out of Stock"], + ["Watch SE", "$249", "200", "In Stock"] + ] +} +``` + +## Tips +- Keep tables simple (avoid merged cells - not supported) +- Use consistent data types per column +- Header count must match row column count +- Empty cells: use empty string "" +""" + +KNOWLEDGE_HELP_IMAGE_SLIDES = """ +# Creating Beautiful Presentations with AI-Generated Slides + +## Overview +Instead of using structured content blocks, you can create stunning presentations +by generating each slide as an AI image. This gives full creative control over +typography, layout, colors, and visual effects. + +## Step-by-Step Workflow + +### Step 1: Generate Slide Images +Use the `generate_image` tool for each slide: + +``` +generate_image( + prompt="Professional presentation slide with title 'Q4 Revenue Summary' showing a blue gradient background, large white bold text, and a subtle upward trending graph icon. Clean corporate style, 16:9 aspect ratio.", + aspect_ratio="16:9" +) +``` + +### Step 2: Collect Image IDs +Each `generate_image` call returns an `image_id`. Save these: +- Slide 1: "abc-123-..." +- Slide 2: "def-456-..." +- etc. + +### Step 3: Create PPTX +Use `knowledge_write` with image_slides mode: + +```json +{ + "mode": "image_slides", + "title": "Q4 Business Review", + "author": "Sales Team", + "image_slides": [ + {"image_id": "abc-123-...", "notes": "Opening remarks"}, + {"image_id": "def-456-...", "notes": "Highlight 25% growth"}, + {"image_id": "ghi-789-...", "notes": "Thank the team"} + ] +} +``` + +## Prompting Tips for Consistent Style + +1. **Define a style template** and reference it in each prompt: + - "Corporate blue theme (#1a73e8), white text, clean minimal layout" + +2. **Specify slide type** in prompts: + - "Title slide" / "Content slide" / "Section divider" / "Closing slide" + +3. **Include aspect ratio**: + - Always use "16:9 aspect ratio presentation slide" + +4. **Maintain visual consistency**: + - "Matching the style of previous slides in this presentation" + +## Complete Example + +```python +# Agent generates beautiful slides +slide1 = await generate_image( + prompt="Title slide: 'Q4 Business Review 2024' with dark blue gradient, + large white text centered, subtle geometric patterns, + professional corporate style, 16:9 presentation slide" +) + +slide2 = await generate_image( + prompt="Content slide: 'Revenue Growth +25%' with bar chart visualization, + blue color scheme matching previous slide, clean data presentation, + 16:9 presentation slide" +) + +slide3 = await generate_image( + prompt="Closing slide: 'Thank You' with contact information, + matching corporate blue theme, 16:9 presentation slide" +) + +# Agent assembles into PPTX +await knowledge_write( + filename="Q4-Review.pptx", + content=json.dumps({ + "mode": "image_slides", + "title": "Q4 Business Review", + "image_slides": [ + {"image_id": slide1["image_id"], "notes": "Welcome everyone"}, + {"image_id": slide2["image_id"], "notes": "Emphasize growth"}, + {"image_id": slide3["image_id"], "notes": "Q&A time"} + ] + }) +) +``` + +## Limitations +- Text in images is NOT editable in PowerPoint +- Best for final presentations, not drafts requiring edits +- Larger file sizes than structured content +""" + + +def get_help_content(topic: str | None) -> dict[str, Any]: + """Get help content for the specified topic.""" + topic_map = { + "pptx": KNOWLEDGE_HELP_PPTX, + "powerpoint": KNOWLEDGE_HELP_PPTX, + "pdf": KNOWLEDGE_HELP_PDF_DOCX, + "docx": KNOWLEDGE_HELP_PDF_DOCX, + "word": KNOWLEDGE_HELP_PDF_DOCX, + "xlsx": KNOWLEDGE_HELP_XLSX, + "excel": KNOWLEDGE_HELP_XLSX, + "images": KNOWLEDGE_HELP_IMAGES, + "image": KNOWLEDGE_HELP_IMAGES, + "tables": KNOWLEDGE_HELP_TABLES, + "table": KNOWLEDGE_HELP_TABLES, + "image_slides": KNOWLEDGE_HELP_IMAGE_SLIDES, + "imageslides": KNOWLEDGE_HELP_IMAGE_SLIDES, + } + + if topic is None: + return {"success": True, "content": KNOWLEDGE_HELP_OVERVIEW} + + topic_lower = topic.lower().strip() + + if topic_lower == "all": + all_content = ( + KNOWLEDGE_HELP_OVERVIEW + + "\n\n---\n\n" + + KNOWLEDGE_HELP_PPTX + + "\n\n---\n\n" + + KNOWLEDGE_HELP_PDF_DOCX + + "\n\n---\n\n" + + KNOWLEDGE_HELP_XLSX + + "\n\n---\n\n" + + KNOWLEDGE_HELP_IMAGES + + "\n\n---\n\n" + + KNOWLEDGE_HELP_TABLES + + "\n\n---\n\n" + + KNOWLEDGE_HELP_IMAGE_SLIDES + ) + return {"success": True, "content": all_content} + + if topic_lower in topic_map: + return {"success": True, "content": topic_map[topic_lower]} + + return { + "success": False, + "error": f"Unknown topic: {topic}. Available topics: pptx, pdf, docx, xlsx, images, tables, image_slides, all", + } + + +__all__ = [ + "KNOWLEDGE_HELP_OVERVIEW", + "KNOWLEDGE_HELP_PPTX", + "KNOWLEDGE_HELP_PDF_DOCX", + "KNOWLEDGE_HELP_XLSX", + "KNOWLEDGE_HELP_IMAGES", + "KNOWLEDGE_HELP_TABLES", + "KNOWLEDGE_HELP_IMAGE_SLIDES", + "get_help_content", +] diff --git a/service/app/tools/builtin/knowledge/operations.py b/service/app/tools/builtin/knowledge/operations.py new file mode 100644 index 00000000..b1a8d59d --- /dev/null +++ b/service/app/tools/builtin/knowledge/operations.py @@ -0,0 +1,344 @@ +""" +Knowledge tool implementation functions. + +Core operations for knowledge base file management. +""" + +from __future__ import annotations + +import io +import json +import logging +import mimetypes +from datetime import datetime, timezone +from typing import Any +from uuid import UUID + +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.storage import FileCategory, FileScope, generate_storage_key, get_storage_service +from app.infra.database import get_task_db_session +from app.models.file import FileCreate +from app.repos.file import FileRepository +from app.repos.knowledge_set import KnowledgeSetRepository + +logger = logging.getLogger(__name__) + + +async def _resolve_image_ids_to_storage_urls( + content: str, + file_repo: FileRepository, + user_id: str, +) -> str: + """ + Resolve image_ids in document specs to storage:// URLs. + + This function handles the async database lookup in the async layer, + so sync document handlers don't need to do async operations. + + Supports: + - PresentationSpec with image_slides mode (image_slides[].image_id) + - PresentationSpec with ImageBlocks in slides (slides[].content[].image_id) + + Args: + content: JSON content to process + file_repo: File repository for database lookups + user_id: User ID for ownership verification (security check) + """ + try: + data = json.loads(content) + except json.JSONDecodeError: + # Not JSON, return as-is + return content + + if not isinstance(data, dict): + return content + + modified = False + + # Collect all image_ids that need resolution + image_ids_to_resolve: set[str] = set() + + # Check for image_slides mode + if data.get("mode") == "image_slides" and "image_slides" in data: + for slide in data.get("image_slides", []): + if isinstance(slide, dict) and slide.get("image_id"): + image_ids_to_resolve.add(slide["image_id"]) + + # Check for ImageBlocks in structured slides + for slide in data.get("slides", []): + if isinstance(slide, dict): + for block in slide.get("content", []): + if isinstance(block, dict) and block.get("type") == "image" and block.get("image_id"): + image_ids_to_resolve.add(block["image_id"]) + + if not image_ids_to_resolve: + return content + + # Resolve image_ids to storage URLs + id_to_storage_url: dict[str, str] = {} + for image_id in image_ids_to_resolve: + try: + file_uuid = UUID(image_id) + file_record = await file_repo.get_file_by_id(file_uuid) + if file_record and not file_record.is_deleted: + # Security check: verify the file belongs to the current user + if file_record.user_id != user_id: + logger.warning( + f"Image ownership mismatch: {image_id} belongs to {file_record.user_id}, not {user_id}" + ) + continue + id_to_storage_url[image_id] = f"storage://{file_record.storage_key}" + else: + logger.warning(f"Image not found or deleted: {image_id}") + except ValueError: + logger.warning(f"Invalid image_id format: {image_id}") + + # Replace image_ids with storage URLs in image_slides + if data.get("mode") == "image_slides" and "image_slides" in data: + for slide in data.get("image_slides", []): + if isinstance(slide, dict) and slide.get("image_id"): + image_id = slide["image_id"] + if image_id in id_to_storage_url: + # Add storage_url field, keep image_id for reference + slide["storage_url"] = id_to_storage_url[image_id] + modified = True + + # Replace image_ids with storage URLs in structured slides + for slide in data.get("slides", []): + if isinstance(slide, dict): + for block in slide.get("content", []): + if isinstance(block, dict) and block.get("type") == "image" and block.get("image_id"): + image_id = block["image_id"] + if image_id in id_to_storage_url: + # Set url to storage URL, keep image_id for reference + block["url"] = id_to_storage_url[image_id] + modified = True + + if modified: + return json.dumps(data) + return content + + +async def get_files_in_knowledge_set(db: AsyncSession, user_id: str, knowledge_set_id: UUID) -> list[UUID]: + """Get all file IDs in a knowledge set.""" + knowledge_set_repo = KnowledgeSetRepository(db) + + # Validate access + try: + await knowledge_set_repo.validate_access(user_id, knowledge_set_id) + except ValueError as e: + raise ValueError(f"Access denied: {e}") + + # Get file IDs + file_ids = await knowledge_set_repo.get_files_in_knowledge_set(knowledge_set_id) + return file_ids + + +async def list_files(user_id: str, knowledge_set_id: UUID) -> dict[str, Any]: + """List all files in the knowledge set.""" + try: + async with get_task_db_session() as db: + file_repo = FileRepository(db) + + try: + file_ids = await get_files_in_knowledge_set(db, user_id, knowledge_set_id) + except ValueError as e: + return {"error": str(e), "success": False} + + # Fetch file objects + files = [] + for file_id in file_ids: + file = await file_repo.get_file_by_id(file_id) + if file and not file.is_deleted: + files.append(file) + + # Format output + entries: list[str] = [] + for f in files: + entries.append(f"[FILE] {f.original_filename} (ID: {f.id})") + + return { + "success": True, + "knowledge_set_id": str(knowledge_set_id), + "entries": entries, + "count": len(entries), + } + + except Exception as e: + logger.error(f"Error listing files: {e}") + return {"error": f"Internal error: {e!s}", "success": False} + + +async def read_file(user_id: str, knowledge_set_id: UUID, filename: str) -> dict[str, Any]: + """Read content of a file from the knowledge set.""" + from app.tools.utils.documents.handlers import FileHandlerFactory + + try: + # Normalize filename + filename = filename.strip("/").split("/")[-1] + + async with get_task_db_session() as db: + file_repo = FileRepository(db) + target_file = None + + try: + file_ids = await get_files_in_knowledge_set(db, user_id, knowledge_set_id) + except ValueError as e: + return {"error": str(e), "success": False} + + # Find file by name + for file_id in file_ids: + file = await file_repo.get_file_by_id(file_id) + if file and file.original_filename == filename and not file.is_deleted: + target_file = file + break + + if not target_file: + return {"error": f"File '{filename}' not found in knowledge set.", "success": False} + + # Download content + storage = get_storage_service() + buffer = io.BytesIO() + await storage.download_file(target_file.storage_key, buffer) + file_bytes = buffer.getvalue() + + # Use handler to process content (text mode only for LangChain tools) + handler = FileHandlerFactory.get_handler(target_file.original_filename) + + try: + result = handler.read_content(file_bytes, mode="text") + return { + "success": True, + "filename": target_file.original_filename, + "content": result, + "size_bytes": target_file.file_size, + } + except Exception as e: + return {"error": f"Error parsing file: {e!s}", "success": False} + + except Exception as e: + logger.error(f"Error reading file: {e}") + return {"error": f"Internal error: {e!s}", "success": False} + + +async def write_file(user_id: str, knowledge_set_id: UUID, filename: str, content: str) -> dict[str, Any]: + """Create or update a file in the knowledge set.""" + from app.tools.utils.documents.handlers import FileHandlerFactory + + try: + filename = filename.strip("/").split("/")[-1] + + async with get_task_db_session() as db: + file_repo = FileRepository(db) + knowledge_set_repo = KnowledgeSetRepository(db) + storage = get_storage_service() + + try: + file_ids = await get_files_in_knowledge_set(db, user_id, knowledge_set_id) + except ValueError as e: + return {"error": str(e), "success": False} + + # Check if file exists + existing_file = None + for file_id in file_ids: + file = await file_repo.get_file_by_id(file_id) + if file and file.original_filename == filename and not file.is_deleted: + existing_file = file + break + + # Determine content type + content_type, _ = mimetypes.guess_type(filename) + if not content_type: + if filename.endswith(".docx"): + content_type = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + elif filename.endswith(".xlsx"): + content_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + elif filename.endswith(".pptx"): + content_type = "application/vnd.openxmlformats-officedocument.presentationml.presentation" + elif filename.endswith(".pdf"): + content_type = "application/pdf" + else: + content_type = "text/plain" + + # Resolve image_ids to storage URLs for PPTX files (async DB lookup here) + if filename.endswith(".pptx"): + content = await _resolve_image_ids_to_storage_urls(content, file_repo, user_id) + + # Use handler to create content bytes + handler = FileHandlerFactory.get_handler(filename) + encoded_content = handler.create_content(content) + + new_key = generate_storage_key(user_id, filename, FileScope.PRIVATE) + data = io.BytesIO(encoded_content) + file_size_bytes = len(encoded_content) + + await storage.upload_file(data, new_key, content_type=content_type) + + if existing_file: + # Update existing + existing_file.storage_key = new_key + existing_file.file_size = file_size_bytes + existing_file.content_type = content_type + existing_file.updated_at = datetime.now(timezone.utc) + db.add(existing_file) + await db.commit() + return {"success": True, "message": f"Updated file: {filename}"} + else: + # Create new and link + new_file = FileCreate( + user_id=user_id, + folder_id=None, + original_filename=filename, + storage_key=new_key, + file_size=file_size_bytes, + content_type=content_type, + scope=FileScope.PRIVATE, + category=FileCategory.DOCUMENT, + ) + created_file = await file_repo.create_file(new_file) + await knowledge_set_repo.link_file_to_knowledge_set(created_file.id, knowledge_set_id) + await db.commit() + return {"success": True, "message": f"Created file: {filename}"} + + except Exception as e: + logger.error(f"Error writing file: {e}") + return {"error": f"Internal error: {e!s}", "success": False} + + +async def search_files(user_id: str, knowledge_set_id: UUID, query: str) -> dict[str, Any]: + """Search for files by name in the knowledge set.""" + try: + async with get_task_db_session() as db: + file_repo = FileRepository(db) + matches: list[str] = [] + + try: + file_ids = await get_files_in_knowledge_set(db, user_id, knowledge_set_id) + except ValueError as e: + return {"error": str(e), "success": False} + + for file_id in file_ids: + file = await file_repo.get_file_by_id(file_id) + if file and not file.is_deleted and query.lower() in file.original_filename.lower(): + matches.append(f"{file.original_filename} (ID: {file.id})") + + return { + "success": True, + "query": query, + "matches": matches, + "count": len(matches), + } + + except Exception as e: + logger.error(f"Error searching files: {e}") + return {"error": f"Internal error: {e!s}", "success": False} + + +__all__ = [ + "get_files_in_knowledge_set", + "list_files", + "read_file", + "write_file", + "search_files", +] diff --git a/service/app/tools/builtin/knowledge/schemas.py b/service/app/tools/builtin/knowledge/schemas.py new file mode 100644 index 00000000..dcc0dac7 --- /dev/null +++ b/service/app/tools/builtin/knowledge/schemas.py @@ -0,0 +1,95 @@ +""" +Input schemas for knowledge tools. + +Pydantic models defining the input parameters for each knowledge tool. +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class KnowledgeListFilesInput(BaseModel): + """Input schema for list_files tool - no parameters needed.""" + + pass + + +class KnowledgeReadFileInput(BaseModel): + """Input schema for read_file tool.""" + + filename: str = Field( + description=( + "The name of the file to read from the knowledge base. " + "Supported formats: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, " + "images (PNG/JPG/GIF/WEBP with OCR), and plain text files." + ) + ) + + +class KnowledgeWriteFileInput(BaseModel): + """Input schema for write_file tool.""" + + filename: str = Field( + description=( + "The name of the file to create or update. Use appropriate extensions: " + ".txt, .md (plain text), .pdf (PDF document), .docx (Word), " + ".xlsx (Excel), .pptx (PowerPoint), .json, .yaml, .xml, .html." + ) + ) + content: str = Field( + description=( + "The content to write. Can be plain text (creates simple documents) or " + "a JSON specification for production-quality documents:\n\n" + "**For PDF/DOCX (DocumentSpec JSON):**\n" + '{"title": "My Report", "author": "Name", "content": [\n' + ' {"type": "heading", "content": "Section 1", "level": 1},\n' + ' {"type": "text", "content": "Paragraph text here"},\n' + ' {"type": "list", "items": ["Item 1", "Item 2"], "ordered": false},\n' + ' {"type": "table", "headers": ["Col1", "Col2"], "rows": [["A", "B"]]},\n' + ' {"type": "page_break"}\n' + "]}\n\n" + "**For XLSX (SpreadsheetSpec JSON):**\n" + '{"sheets": [{"name": "Data", "headers": ["Name", "Value"], ' + '"data": [["A", 1], ["B", 2]], "freeze_header": true}]}\n\n' + "**For PPTX (PresentationSpec JSON) - Structured mode:**\n" + '{"title": "My Presentation", "slides": [\n' + ' {"layout": "title", "title": "Welcome", "subtitle": "Intro"},\n' + ' {"layout": "title_content", "title": "Slide 2", ' + '"content": [{"type": "list", "items": ["Point 1", "Point 2"]}], ' + '"notes": "Speaker notes here"}\n' + "]}\n\n" + "**For PPTX - AI-generated image slides mode:**\n" + '{"mode": "image_slides", "title": "My Presentation", ' + '"image_slides": [\n' + ' {"image_id": "", "notes": "Speaker notes"}\n' + "]}" + ) + ) + + +class KnowledgeSearchFilesInput(BaseModel): + """Input schema for search_files tool.""" + + query: str = Field(description="Search term to find files by name.") + + +class KnowledgeHelpInput(BaseModel): + """Input schema for knowledge_help tool.""" + + topic: str | None = Field( + default=None, + description=( + "Optional topic to get help for. Options: 'pptx', 'pdf', 'docx', 'xlsx', " + "'images', 'tables', 'image_slides', 'all'. If not specified, returns overview." + ), + ) + + +__all__ = [ + "KnowledgeListFilesInput", + "KnowledgeReadFileInput", + "KnowledgeWriteFileInput", + "KnowledgeSearchFilesInput", + "KnowledgeHelpInput", +] diff --git a/service/app/tools/builtin/knowledge/tools.py b/service/app/tools/builtin/knowledge/tools.py new file mode 100644 index 00000000..0cfc2862 --- /dev/null +++ b/service/app/tools/builtin/knowledge/tools.py @@ -0,0 +1,232 @@ +""" +Knowledge tool factory functions. + +Creates LangChain tools for knowledge base operations. +""" + +from __future__ import annotations + +from typing import Any +from uuid import UUID + +from langchain_core.tools import BaseTool, StructuredTool + +from .help_content import get_help_content +from .operations import list_files, read_file, search_files, write_file +from .schemas import ( + KnowledgeHelpInput, + KnowledgeListFilesInput, + KnowledgeReadFileInput, + KnowledgeSearchFilesInput, + KnowledgeWriteFileInput, +) + + +def create_knowledge_tools() -> dict[str, BaseTool]: + """ + Create knowledge tools with placeholder implementations. + + Note: Knowledge tools require runtime context (user_id, knowledge_set_id). + The actual tool instances are created per-agent with context bound. + This function returns template tools for the registry. + + Returns: + Dict mapping tool_id to BaseTool placeholder instances. + """ + # These are placeholder tools - actual execution requires context binding + # See create_knowledge_tools_for_agent() for runtime creation + + tools: dict[str, BaseTool] = {} + + # List files tool + async def list_files_placeholder() -> dict[str, Any]: + return {"error": "Knowledge tools require agent context binding", "success": False} + + tools["knowledge_list"] = StructuredTool( + name="knowledge_list", + description=( + "List all files in the agent's knowledge base. Returns a list of filenames " + "that can be read or searched. Use this first to discover available files." + ), + args_schema=KnowledgeListFilesInput, + coroutine=list_files_placeholder, + ) + + # Read file tool + async def read_file_placeholder(filename: str) -> dict[str, Any]: + return {"error": "Knowledge tools require agent context binding", "success": False} + + tools["knowledge_read"] = StructuredTool( + name="knowledge_read", + description=( + "Read the content of a file from the agent's knowledge base. " + "Supports: PDF (text + tables), DOCX (text + tables), XLSX (all sheets), " + "PPTX (text + speaker notes), HTML (text extraction), JSON/YAML/XML (formatted), " + "images (OCR text extraction from PNG/JPG/GIF/WEBP), and plain text files. " + "Use knowledge_list first to see available files." + ), + args_schema=KnowledgeReadFileInput, + coroutine=read_file_placeholder, + ) + + # Write file tool + async def write_file_placeholder(filename: str, content: str) -> dict[str, Any]: + return {"error": "Knowledge tools require agent context binding", "success": False} + + tools["knowledge_write"] = StructuredTool( + name="knowledge_write", + description=( + "Create or update a file in the agent's knowledge base. " + "Supports: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, and plain text. " + "For rich documents with images, tables, and formatting, provide a JSON specification. " + "Use knowledge_help with topic='pptx' or 'pdf' for detailed examples. " + "For beautiful AI-generated presentations: use generate_image to create slide images, " + "then use knowledge_write with mode='image_slides' and the image_id values. " + "Call knowledge_help(topic='image_slides') for the complete workflow." + ), + args_schema=KnowledgeWriteFileInput, + coroutine=write_file_placeholder, + ) + + # Search files tool + async def search_files_placeholder(query: str) -> dict[str, Any]: + return {"error": "Knowledge tools require agent context binding", "success": False} + + tools["knowledge_search"] = StructuredTool( + name="knowledge_search", + description=( + "Search for files by name in the agent's knowledge base. Returns matching filenames that can then be read." + ), + args_schema=KnowledgeSearchFilesInput, + coroutine=search_files_placeholder, + ) + + # Help tool + async def help_placeholder(topic: str | None = None) -> dict[str, Any]: + return get_help_content(topic) + + tools["knowledge_help"] = StructuredTool( + name="knowledge_help", + description=( + "Get detailed help and examples for using knowledge tools. " + "Call without arguments for overview, or with topic='pptx', 'pdf', 'xlsx', " + "'images', 'tables', 'image_slides', or 'all' for specific guides with JSON examples." + ), + args_schema=KnowledgeHelpInput, + coroutine=help_placeholder, + ) + + return tools + + +def create_knowledge_tools_for_agent(user_id: str, knowledge_set_id: UUID) -> list[BaseTool]: + """ + Create knowledge tools bound to a specific agent's context. + + This creates actual working tools with user_id and knowledge_set_id + captured in closures. + + Args: + user_id: The user ID for access control + knowledge_set_id: The knowledge set ID to operate on + + Returns: + List of BaseTool instances with context bound + """ + tools: list[BaseTool] = [] + + # List files tool + async def list_files_bound() -> dict[str, Any]: + return await list_files(user_id, knowledge_set_id) + + tools.append( + StructuredTool( + name="knowledge_list", + description=( + "List all files in your knowledge base. Returns filenames that can be read or searched. " + "Use this first to discover available files." + ), + args_schema=KnowledgeListFilesInput, + coroutine=list_files_bound, + ) + ) + + # Read file tool + async def read_file_bound(filename: str) -> dict[str, Any]: + return await read_file(user_id, knowledge_set_id, filename) + + tools.append( + StructuredTool( + name="knowledge_read", + description=( + "Read the content of a file from your knowledge base. " + "Supports: PDF (text + tables), DOCX (text + tables), XLSX (all sheets), " + "PPTX (text + speaker notes), HTML (text extraction), JSON/YAML/XML (formatted), " + "images (OCR text extraction from PNG/JPG/GIF/WEBP), and plain text files. " + "Use knowledge_list first to see available files." + ), + args_schema=KnowledgeReadFileInput, + coroutine=read_file_bound, + ) + ) + + # Write file tool + async def write_file_bound(filename: str, content: str) -> dict[str, Any]: + return await write_file(user_id, knowledge_set_id, filename, content) + + tools.append( + StructuredTool( + name="knowledge_write", + description=( + "Create or update a file in your knowledge base. " + "Supports: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, and plain text. " + "For rich documents with images, tables, and formatting, provide a JSON specification. " + "Use knowledge_help with topic='pptx' or 'pdf' for detailed examples. " + "For beautiful AI-generated presentations: use generate_image to create slide images, " + "then use knowledge_write with mode='image_slides' and the image_id values. " + "Call knowledge_help(topic='image_slides') for the complete workflow." + ), + args_schema=KnowledgeWriteFileInput, + coroutine=write_file_bound, + ) + ) + + # Search files tool + async def search_files_bound(query: str) -> dict[str, Any]: + return await search_files(user_id, knowledge_set_id, query) + + tools.append( + StructuredTool( + name="knowledge_search", + description=( + "Search for files by name in your knowledge base. Returns matching filenames that can then be read." + ), + args_schema=KnowledgeSearchFilesInput, + coroutine=search_files_bound, + ) + ) + + # Help tool (no context needed - static content) + async def help_bound(topic: str | None = None) -> dict[str, Any]: + return get_help_content(topic) + + tools.append( + StructuredTool( + name="knowledge_help", + description=( + "Get detailed help and examples for using knowledge tools. " + "Call without arguments for overview, or with topic='pptx', 'pdf', 'xlsx', " + "'images', 'tables', 'image_slides', or 'all' for specific guides with JSON examples." + ), + args_schema=KnowledgeHelpInput, + coroutine=help_bound, + ) + ) + + return tools + + +__all__ = [ + "create_knowledge_tools", + "create_knowledge_tools_for_agent", +] diff --git a/service/app/tools/capabilities.py b/service/app/tools/capabilities.py index fc88b836..1b4deecf 100644 --- a/service/app/tools/capabilities.py +++ b/service/app/tools/capabilities.py @@ -53,6 +53,7 @@ class ToolCapability(StrEnum): "google_search": [ToolCapability.WEB_SEARCH], "bing_search": [ToolCapability.WEB_SEARCH], "tavily_search": [ToolCapability.WEB_SEARCH], + "web_fetch": [ToolCapability.WEB_SEARCH], # Knowledge tools "knowledge_list": [ToolCapability.KNOWLEDGE_RETRIEVAL], "knowledge_read": [ToolCapability.KNOWLEDGE_RETRIEVAL, ToolCapability.FILE_OPERATIONS], diff --git a/service/app/tools/cost.py b/service/app/tools/cost.py new file mode 100644 index 00000000..f3dc0c95 --- /dev/null +++ b/service/app/tools/cost.py @@ -0,0 +1,58 @@ +"""Tool cost calculation utilities.""" + +from __future__ import annotations + +import logging +from typing import Any + +from app.tools.registry import BuiltinToolRegistry + +logger = logging.getLogger(__name__) + + +def calculate_tool_cost( + tool_name: str, + tool_args: dict[str, Any] | None = None, + tool_result: dict[str, Any] | None = None, +) -> int: + """ + Calculate cost for a tool execution. + + Args: + tool_name: Name of the tool + tool_args: Tool input arguments + tool_result: Tool execution result + + Returns: + Cost in points + """ + # Get tool cost config from registry + tool_info = BuiltinToolRegistry.get_info(tool_name) + if not tool_info or not tool_info.cost: + return 0 + + config = tool_info.cost + cost = config.base_cost + + # Add input image cost (for generate_image with reference images) + if config.input_image_cost and tool_args: + image_ids = tool_args.get("image_ids") + if image_ids: + cost += config.input_image_cost * len(image_ids) + + # Add output file cost (for knowledge_write creating new files) + if config.output_file_cost and tool_result: + if isinstance(tool_result, dict): + # Check if tool created a new file (not updated) + # knowledge_write returns message like "Created file: filename" + message = tool_result.get("message", "") + if tool_result.get("success") and "Created" in message: + cost += config.output_file_cost + + if cost > 0: + logger.debug(f"Tool {tool_name} cost: {cost} (base={config.base_cost})") + + return cost + + +__all__ = ["calculate_tool_cost"] diff --git a/service/app/tools/prepare.py b/service/app/tools/prepare.py index bb1106ba..81b4a6df 100644 --- a/service/app/tools/prepare.py +++ b/service/app/tools/prepare.py @@ -81,7 +81,7 @@ def _load_all_builtin_tools( """ Load all available builtin tools. - - Web search: loaded if SearXNG is enabled + - Web search + fetch: loaded if SearXNG is enabled - Knowledge tools: loaded if effective knowledge_set_id exists and user_id is available - Image tools: loaded if image generation is enabled and user_id is available - Memory tools: loaded if agent and user_id are available (currently disabled) @@ -101,10 +101,14 @@ def _load_all_builtin_tools( tools: list[BaseTool] = [] - # Load web_search if available in registry (registered at startup if SearXNG enabled) + # Load web search tools if available in registry (registered at startup if SearXNG enabled) web_search = BuiltinToolRegistry.get("web_search") if web_search: tools.append(web_search) + # Load web fetch tool (bundled with web_search) + web_fetch = BuiltinToolRegistry.get("web_fetch") + if web_fetch: + tools.append(web_fetch) # Determine effective knowledge_set_id # Priority: session override > agent config diff --git a/service/app/tools/registry.py b/service/app/tools/registry.py index ad58d534..66d27420 100644 --- a/service/app/tools/registry.py +++ b/service/app/tools/registry.py @@ -22,6 +22,14 @@ logger = logging.getLogger(__name__) +class ToolCostConfig(BaseModel): + """Cost configuration for a tool.""" + + base_cost: int = Field(default=0, description="Base cost per execution") + input_image_cost: int = Field(default=0, description="Additional cost per input image") + output_file_cost: int = Field(default=0, description="Additional cost per output file") + + class ToolInfo(BaseModel): """Metadata about a builtin tool for API responses.""" @@ -42,6 +50,10 @@ class ToolInfo(BaseModel): default_factory=list, description="Runtime context requirements (e.g., ['user_id', 'knowledge_set_id'])", ) + cost: ToolCostConfig = Field( + default_factory=ToolCostConfig, + description="Cost configuration for this tool", + ) class BuiltinToolRegistry: @@ -65,6 +77,7 @@ def register( ui_toggleable: bool = True, default_enabled: bool = False, requires_context: list[str] | None = None, + cost: ToolCostConfig | None = None, ) -> None: """ Register a builtin tool. @@ -77,6 +90,7 @@ def register( ui_toggleable: Whether to show as toggle in UI (default: True) default_enabled: Whether enabled by default for new agents (default: False) requires_context: List of required context keys (e.g., ["user_id"]) + cost: Cost configuration for the tool (default: no cost) """ cls._tools[tool_id] = tool cls._metadata[tool_id] = ToolInfo( @@ -87,6 +101,7 @@ def register( ui_toggleable=ui_toggleable, default_enabled=default_enabled, requires_context=requires_context or [], + cost=cost or ToolCostConfig(), ) logger.debug(f"Registered builtin tool: {tool_id} ({category})") @@ -158,6 +173,7 @@ def register_builtin_tools() -> None: Called at app startup to populate the registry. """ + from app.tools.builtin.fetch import create_web_fetch_tool from app.tools.builtin.knowledge import create_knowledge_tools from app.tools.builtin.search import create_web_search_tool @@ -172,8 +188,31 @@ def register_builtin_tools() -> None: ui_toggleable=True, default_enabled=True, # Web search enabled by default requires_context=[], + cost=ToolCostConfig(base_cost=1), ) + # Register web fetch tool (bundled with web_search, not separate toggle) + fetch_tool = create_web_fetch_tool() + BuiltinToolRegistry.register( + tool_id="web_fetch", + tool=fetch_tool, + category="search", + display_name="Web Fetch", + ui_toggleable=False, # Bundled with web_search + default_enabled=True, + requires_context=[], + cost=ToolCostConfig(base_cost=1), + ) + + # Tool cost configs for knowledge tools + knowledge_tool_costs = { + "knowledge_list": ToolCostConfig(), # Free + "knowledge_read": ToolCostConfig(), # Free + "knowledge_write": ToolCostConfig(output_file_cost=5), # Charge for new files + "knowledge_search": ToolCostConfig(), # Free + "knowledge_help": ToolCostConfig(), # Free + } + # Register knowledge tools (auto-enabled when knowledge_set exists, not UI toggleable) knowledge_tools = create_knowledge_tools() for tool_id, tool in knowledge_tools.items(): @@ -184,6 +223,7 @@ def register_builtin_tools() -> None: ui_toggleable=False, # Auto-enabled based on context default_enabled=False, requires_context=["user_id", "knowledge_set_id"], + cost=knowledge_tool_costs.get(tool_id, ToolCostConfig()), ) # Register memory tools (disabled due to performance issues) @@ -200,6 +240,12 @@ def register_builtin_tools() -> None: # requires_context=["user_id", "agent_id"], # ) + # Tool cost configs for image tools + image_tool_costs = { + "generate_image": ToolCostConfig(base_cost=10, input_image_cost=5), # 10 base, +5 if using reference + "read_image": ToolCostConfig(base_cost=2), # Vision model inference + } + # Register image tools from app.tools.builtin.image import create_image_tools @@ -213,9 +259,10 @@ def register_builtin_tools() -> None: ui_toggleable=True, default_enabled=False, requires_context=["user_id"], + cost=image_tool_costs.get(tool_id, ToolCostConfig()), ) logger.info(f"Registered {BuiltinToolRegistry.count()} builtin tools") -__all__ = ["BuiltinToolRegistry", "ToolInfo", "register_builtin_tools"] +__all__ = ["BuiltinToolRegistry", "ToolCostConfig", "ToolInfo", "register_builtin_tools"] diff --git a/service/app/tools/utils/documents/__init__.py b/service/app/tools/utils/documents/__init__.py new file mode 100644 index 00000000..686e0dfa --- /dev/null +++ b/service/app/tools/utils/documents/__init__.py @@ -0,0 +1,83 @@ +""" +Document generation utilities for PDF, DOCX, XLSX, and PPTX files. + +This module provides: +- Document specification schemas (spec.py) +- Image fetching from various sources (image_fetcher.py) +- File handlers for reading and creating documents (handlers.py) +""" + +from app.tools.utils.documents.handlers import ( + BaseFileHandler, + DocxFileHandler, + ExcelFileHandler, + FileHandlerFactory, + HtmlFileHandler, + ImageFileHandler, + JsonFileHandler, + PdfFileHandler, + PptxFileHandler, + ReadMode, + TextFileHandler, + XmlFileHandler, + YamlFileHandler, +) +from app.tools.utils.documents.image_fetcher import ( + DEFAULT_TIMEOUT, + MAX_IMAGE_DIMENSION, + MAX_IMAGE_SIZE_BYTES, + FetchedImage, + ImageFetcher, +) +from app.tools.utils.documents.spec import ( + ContentBlock, + DocumentSpec, + HeadingBlock, + ImageBlock, + ImageSlideSpec, + ListBlock, + PageBreakBlock, + PresentationSpec, + SheetSpec, + SlideSpec, + SpreadsheetSpec, + TableBlock, + TextBlock, +) + +__all__ = [ + # Spec classes + "TextBlock", + "HeadingBlock", + "ListBlock", + "TableBlock", + "ImageBlock", + "PageBreakBlock", + "ContentBlock", + "DocumentSpec", + "SheetSpec", + "SpreadsheetSpec", + "SlideSpec", + "ImageSlideSpec", + "PresentationSpec", + # Image fetcher + "ImageFetcher", + "FetchedImage", + "MAX_IMAGE_SIZE_BYTES", + "MAX_IMAGE_DIMENSION", + "DEFAULT_TIMEOUT", + # File handlers + "BaseFileHandler", + "TextFileHandler", + "HtmlFileHandler", + "JsonFileHandler", + "YamlFileHandler", + "XmlFileHandler", + "ImageFileHandler", + "PdfFileHandler", + "DocxFileHandler", + "ExcelFileHandler", + "PptxFileHandler", + "FileHandlerFactory", + "ReadMode", +] diff --git a/service/app/mcp/file_handlers.py b/service/app/tools/utils/documents/handlers.py similarity index 71% rename from service/app/mcp/file_handlers.py rename to service/app/tools/utils/documents/handlers.py index 9e27733c..52deeed6 100644 --- a/service/app/mcp/file_handlers.py +++ b/service/app/tools/utils/documents/handlers.py @@ -20,7 +20,7 @@ from pydantic import BaseModel, ValidationError if TYPE_CHECKING: - from app.mcp.document_spec import DocumentSpec, PresentationSpec, SpreadsheetSpec + from app.tools.utils.documents.spec import DocumentSpec, PresentationSpec, SpreadsheetSpec logger = logging.getLogger(__name__) @@ -384,7 +384,7 @@ def _format_table(self, table: Any) -> str: def create_content(self, text_content: str) -> bytes: """Create PDF from text or DocumentSpec JSON.""" - from app.mcp.document_spec import DocumentSpec + from app.tools.utils.documents.spec import DocumentSpec spec = self._try_parse_spec(text_content, DocumentSpec) if spec: @@ -554,7 +554,7 @@ def _extract_table(self, tbl_element: Any, doc: Any) -> str: def create_content(self, text_content: str) -> bytes: """Create DOCX from text or DocumentSpec JSON.""" - from app.mcp.document_spec import DocumentSpec + from app.tools.utils.documents.spec import DocumentSpec spec = self._try_parse_spec(text_content, DocumentSpec) if spec: @@ -660,7 +660,7 @@ def read_content(self, file_bytes: bytes, mode: ReadMode = "text") -> Union[str, def create_content(self, text_content: str) -> bytes: """Create XLSX from text or SpreadsheetSpec JSON.""" - from app.mcp.document_spec import SpreadsheetSpec + from app.tools.utils.documents.spec import SpreadsheetSpec spec = self._try_parse_spec(text_content, SpreadsheetSpec) if spec: @@ -785,7 +785,7 @@ def read_content(self, file_bytes: bytes, mode: ReadMode = "text") -> Union[str, def create_content(self, text_content: str) -> bytes: """Create PPTX from text or PresentationSpec JSON.""" - from app.mcp.document_spec import PresentationSpec + from app.tools.utils.documents.spec import PresentationSpec spec = self._try_parse_spec(text_content, PresentationSpec) if spec: @@ -816,12 +816,23 @@ def _create_pptx_from_text(self, text_content: str) -> bytes: def _create_pptx_from_spec(self, spec: PresentationSpec) -> bytes: """Create production PPTX from PresentationSpec.""" + # Route based on mode + if spec.mode == "image_slides": + return self._create_pptx_image_slides(spec) + else: + return self._create_pptx_structured(spec) + + def _create_pptx_structured(self, spec: PresentationSpec) -> bytes: + """Create PPTX with structured DSL slides (traditional mode).""" try: from pptx import Presentation except ImportError: raise ImportError("python-pptx is required for PPTX handling. Please install 'python-pptx'.") + from app.tools.utils.documents.image_fetcher import ImageFetcher + prs = Presentation() + image_fetcher = ImageFetcher() # Layout mapping LAYOUTS = { @@ -849,20 +860,9 @@ def _create_pptx_from_spec(self, spec: PresentationSpec) -> bytes: if slide_spec.subtitle and len(slide.placeholders) > 1: slide.placeholders[1].text = slide_spec.subtitle # type: ignore[union-attr] - # Add content to body placeholder - if slide_spec.content and len(slide.placeholders) > 1: - body = slide.placeholders[1] - if hasattr(body, "text_frame"): - tf = body.text_frame # type: ignore[union-attr] - for i, block in enumerate(slide_spec.content): - if block.type == "text": - p = tf.paragraphs[0] if i == 0 else tf.add_paragraph() - p.text = block.content # type: ignore[union-attr] - elif block.type == "list": - for item in block.items: # type: ignore[union-attr] - p = tf.add_paragraph() - p.text = item - p.level = 0 + # Render content blocks + if slide_spec.content: + self._render_content_blocks(slide, slide_spec.content, image_fetcher) # Add speaker notes if slide_spec.notes: @@ -878,6 +878,373 @@ def _create_pptx_from_spec(self, spec: PresentationSpec) -> bytes: prs.save(buffer) return buffer.getvalue() + def _create_pptx_image_slides(self, spec: PresentationSpec) -> bytes: + """Create PPTX with full-bleed images as slides.""" + try: + from pptx import Presentation + from pptx.util import Inches, Pt + except ImportError: + raise ImportError("python-pptx is required for PPTX handling. Please install 'python-pptx'.") + + from app.tools.utils.documents.image_fetcher import ImageFetcher + + prs = Presentation() + # Set slide dimensions (16:9 widescreen) + prs.slide_width = Inches(13.333) + prs.slide_height = Inches(7.5) + + image_fetcher = ImageFetcher() + blank_layout = prs.slide_layouts[6] # Blank layout + + for slide_spec in spec.image_slides: + slide = prs.slides.add_slide(blank_layout) + + # Use storage_url if available (resolved by async layer), otherwise fall back to image_id + if slide_spec.storage_url: + result = image_fetcher.fetch(url=slide_spec.storage_url) + else: + result = image_fetcher.fetch(image_id=slide_spec.image_id) + + if result.success and result.data: + # Add full-bleed image (0,0 to full slide dimensions) + image_stream = io.BytesIO(result.data) + slide.shapes.add_picture( + image_stream, + Inches(0), + Inches(0), + prs.slide_width, + prs.slide_height, + ) + else: + # Add error text for failed images + text_box = slide.shapes.add_textbox(Inches(1), Inches(3), Inches(11), Inches(1)) + tf = text_box.text_frame + tf.paragraphs[0].text = f"[Slide image failed: {result.error}]" + tf.paragraphs[0].font.size = Pt(24) + tf.paragraphs[0].font.italic = True + + # Add speaker notes + if slide_spec.notes: + notes_slide = slide.notes_slide + if notes_slide.notes_text_frame: + notes_slide.notes_text_frame.text = slide_spec.notes + + # Ensure at least one slide exists + if not prs.slides: + slide = prs.slides.add_slide(blank_layout) + + buffer = io.BytesIO() + prs.save(buffer) + return buffer.getvalue() + + def _render_content_blocks( + self, + slide: Any, + content_blocks: list[Any], + image_fetcher: Any, + ) -> None: + """Render all content blocks on a slide with vertical stacking.""" + + # Content area dimensions (below title) + CONTENT_LEFT = 0.5 # inches + CONTENT_TOP = 1.8 # inches + CONTENT_WIDTH = 9.0 # inches + CONTENT_BOTTOM = 7.0 # inches + + current_y = CONTENT_TOP + + for block in content_blocks: + if current_y >= CONTENT_BOTTOM: + logger.warning("Slide content area full, skipping remaining blocks") + break + + remaining_height = CONTENT_BOTTOM - current_y + + if block.type == "text": + height = self._render_text_block(slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH) + elif block.type == "list": + height = self._render_list_block(slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH) + elif block.type == "image": + height = self._render_image_block( + slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH, remaining_height, image_fetcher + ) + elif block.type == "table": + height = self._render_table_block( + slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH, remaining_height + ) + elif block.type == "heading": + height = self._render_heading_block(slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH) + else: + # Unknown block type, skip + height = 0.0 + + current_y += height + + def _render_text_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + ) -> float: + """Render a text block. Returns height in inches.""" + from pptx.util import Inches, Pt + + # Estimate height based on text length + chars_per_line = int(max_width * 12) # ~12 chars per inch at 12pt + num_lines = max(1, len(block.content) // chars_per_line + 1) + box_height = num_lines * 0.25 # ~0.25 inches per line + + text_box = slide.shapes.add_textbox( + Inches(left), + Inches(top), + Inches(max_width), + Inches(box_height), + ) + + tf = text_box.text_frame + tf.word_wrap = True + p = tf.paragraphs[0] + p.text = block.content + p.font.size = Pt(12) + + # Apply style + if hasattr(block, "style"): + if block.style == "bold": + p.font.bold = True + elif block.style == "italic": + p.font.italic = True + elif block.style == "code": + p.font.name = "Courier New" + p.font.size = Pt(10) + + return box_height + 0.1 # Add margin + + def _render_list_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + ) -> float: + """Render a list block. Returns height in inches.""" + from pptx.util import Inches, Pt + + num_items = len(block.items) + item_height = 0.3 # inches per item + box_height = num_items * item_height + + text_box = slide.shapes.add_textbox( + Inches(left), + Inches(top), + Inches(max_width), + Inches(box_height), + ) + + tf = text_box.text_frame + tf.word_wrap = True + + for i, item in enumerate(block.items): + p = tf.paragraphs[0] if i == 0 else tf.add_paragraph() + prefix = f"{i + 1}. " if block.ordered else "• " + p.text = prefix + item + p.font.size = Pt(12) + p.level = 0 + + return box_height + 0.1 + + def _render_heading_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + ) -> float: + """Render a heading block. Returns height in inches.""" + from pptx.util import Inches, Pt + + # Font sizes by heading level + HEADING_SIZES = { + 1: 24, + 2: 20, + 3: 18, + 4: 16, + 5: 14, + 6: 12, + } + + level = getattr(block, "level", 1) + font_size = HEADING_SIZES.get(level, 14) + box_height = font_size / 72.0 * 1.5 # Convert to inches with padding + + text_box = slide.shapes.add_textbox( + Inches(left), + Inches(top), + Inches(max_width), + Inches(box_height), + ) + + tf = text_box.text_frame + p = tf.paragraphs[0] + p.text = block.content + p.font.size = Pt(font_size) + p.font.bold = True + + return box_height + 0.1 + + def _render_table_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + max_height: float, + ) -> float: + """Render a table block. Returns height in inches.""" + from pptx.dml.color import RGBColor + from pptx.enum.text import PP_ALIGN + from pptx.util import Inches, Pt + + num_cols = len(block.headers) + if num_cols == 0: + return 0.0 + + num_rows = 1 + len(block.rows) # Header + data rows + + # Calculate row height + row_height = 0.4 # inches + table_height = num_rows * row_height + + # Cap table height to available space + max_table_height = min(4.0, max_height - 0.2) + if table_height > max_table_height: + table_height = max_table_height + row_height = table_height / num_rows + + # Create table + table_shape = slide.shapes.add_table( + num_rows, + num_cols, + Inches(left), + Inches(top), + Inches(max_width), + Inches(table_height), + ) + table = table_shape.table + + # Style header row + for i, header in enumerate(block.headers): + cell = table.cell(0, i) + cell.text = str(header) + cell.fill.solid() + cell.fill.fore_color.rgb = RGBColor(0x44, 0x72, 0xC4) + + # Set text properties + if cell.text_frame.paragraphs: + para = cell.text_frame.paragraphs[0] + para.font.bold = True + para.font.size = Pt(11) + para.font.color.rgb = RGBColor(0xFF, 0xFF, 0xFF) + para.alignment = PP_ALIGN.CENTER + + # Fill data rows + for row_idx, row_data in enumerate(block.rows): + for col_idx, cell_val in enumerate(row_data): + if col_idx < num_cols: + cell = table.cell(row_idx + 1, col_idx) + cell.text = str(cell_val) + if cell.text_frame.paragraphs: + para = cell.text_frame.paragraphs[0] + para.font.size = Pt(10) + + return table_height + 0.2 + + def _render_image_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + max_height: float, + image_fetcher: Any, + ) -> float: + """Render an image block. Returns height in inches.""" + from pptx.enum.text import PP_ALIGN + from pptx.util import Inches, Pt + + # Fetch image by url or image_id + url = getattr(block, "url", None) + image_id = getattr(block, "image_id", None) + result = image_fetcher.fetch(url=url, image_id=image_id) + + if not result.success: + # Add error placeholder + text_box = slide.shapes.add_textbox(Inches(left), Inches(top), Inches(max_width), Inches(0.5)) + tf = text_box.text_frame + tf.paragraphs[0].text = f"[Image failed to load: {result.error}]" + tf.paragraphs[0].font.italic = True + tf.paragraphs[0].font.size = Pt(10) + return 0.6 + + # Calculate image dimensions + if block.width: + # Use specified width (in points, convert to inches) + img_width = block.width / 72.0 + elif result.width and result.height: + # Scale to fit max_width while maintaining aspect ratio + img_width = min(max_width * 0.8, result.width / 96.0) # 96 DPI assumption, 80% max width + else: + img_width = min(max_width * 0.6, 4.0) # Default 4 inches or 60% width + + # Calculate height maintaining aspect ratio + if result.width and result.height: + aspect = result.height / result.width + img_height = img_width * aspect + else: + img_height = img_width * 0.75 # Default 4:3 aspect + + # Cap height to available space (leave room for caption) + caption_space = 0.5 if block.caption else 0.1 + available_height = max_height - caption_space + if img_height > available_height: + scale = available_height / img_height + img_height = available_height + img_width = img_width * scale + + # Center image horizontally + img_left = left + (max_width - img_width) / 2 + + # Add image to slide + image_stream = io.BytesIO(result.data) + slide.shapes.add_picture( + image_stream, + Inches(img_left), + Inches(top), + Inches(img_width), + Inches(img_height), + ) + + total_height = img_height + 0.1 + + # Add caption if present + if block.caption: + caption_top = top + img_height + 0.1 + caption_box = slide.shapes.add_textbox(Inches(left), Inches(caption_top), Inches(max_width), Inches(0.3)) + tf = caption_box.text_frame + p = tf.paragraphs[0] + p.text = block.caption + p.alignment = PP_ALIGN.CENTER + p.font.size = Pt(10) + p.font.italic = True + total_height += 0.4 + + return total_height + class FileHandlerFactory: """Factory to get the appropriate file handler based on filename.""" diff --git a/service/app/tools/utils/documents/image_fetcher.py b/service/app/tools/utils/documents/image_fetcher.py new file mode 100644 index 00000000..116f9734 --- /dev/null +++ b/service/app/tools/utils/documents/image_fetcher.py @@ -0,0 +1,271 @@ +""" +Image fetching service for document generation. + +Handles HTTP URLs, base64 data URLs, and storage:// protocol. +Designed for synchronous use in document handlers. +""" + +from __future__ import annotations + +import base64 +import io +import logging +import re +from dataclasses import dataclass +from typing import Any, Coroutine, TypeVar + +import httpx +from PIL import Image as PILImage + +logger = logging.getLogger(__name__) + +# Constants +MAX_IMAGE_SIZE_BYTES = 10 * 1024 * 1024 # 10MB +MAX_IMAGE_DIMENSION = 4096 # pixels +DEFAULT_TIMEOUT = 30.0 + +T = TypeVar("T") + + +def _run_async(coro: Coroutine[Any, Any, T]) -> T: + """ + Run an async coroutine from sync code, handling existing event loops. + + When called from within an already-running event loop (e.g., Celery worker), + asyncio.run() fails. This helper uses a thread pool to safely execute + async code in such cases. + """ + import asyncio + import concurrent.futures + + try: + # Check if there's already a running event loop + asyncio.get_running_loop() + # We're in an async context - run in a thread pool + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(asyncio.run, coro) + return future.result() + except RuntimeError: + # No running loop - safe to use asyncio.run + return asyncio.run(coro) + + +@dataclass +class FetchedImage: + """Result of an image fetch operation.""" + + success: bool + data: bytes | None = None + format: str | None = None # "png", "jpeg", etc. + width: int | None = None + height: int | None = None + error: str | None = None + + +class ImageFetcher: + """ + Fetches images from various sources for document embedding. + + Supports: + - HTTP/HTTPS URLs + - Base64 data URLs (data:image/png;base64,...) + - storage:// protocol (internal file storage) + """ + + def __init__( + self, + timeout: float = DEFAULT_TIMEOUT, + max_size_bytes: int = MAX_IMAGE_SIZE_BYTES, + max_dimension: int = MAX_IMAGE_DIMENSION, + ): + self.timeout = timeout + self.max_size_bytes = max_size_bytes + self.max_dimension = max_dimension + + def fetch(self, url: str | None = None, image_id: str | None = None) -> FetchedImage: + """ + Fetch an image from the given URL or resolve image_id to storage. + + Args: + url: HTTP URL, base64 data URL, or storage:// URL (takes precedence if provided) + image_id: UUID of a generated image from generate_image tool (fallback if no url) + + Returns: + FetchedImage with data or error information + """ + try: + # Prefer URL over image_id when both are present + # This allows the async layer to resolve image_ids to URLs beforehand + if url: + if url.startswith("data:"): + return self._fetch_base64(url) + elif url.startswith("storage://"): + return self._fetch_from_storage(url) + elif url.startswith(("http://", "https://")): + return self._fetch_http(url) + else: + return FetchedImage(success=False, error=f"Unsupported URL scheme: {url[:50]}") + elif image_id: + # Fallback to image_id if no URL provided + return self._fetch_by_image_id(image_id) + else: + return FetchedImage(success=False, error="Either url or image_id must be provided") + except Exception as e: + logger.error(f"Image fetch failed: {e}") + return FetchedImage(success=False, error=str(e)) + + def _fetch_http(self, url: str) -> FetchedImage: + """Fetch image from HTTP/HTTPS URL.""" + try: + with httpx.Client(timeout=self.timeout, follow_redirects=True) as client: + response = client.get(url) + response.raise_for_status() + + # Check size from header + content_length = response.headers.get("content-length") + if content_length and int(content_length) > self.max_size_bytes: + return FetchedImage( + success=False, + error=f"Image too large: {int(content_length)} bytes (max {self.max_size_bytes})", + ) + + data = response.content + if len(data) > self.max_size_bytes: + return FetchedImage( + success=False, + error=f"Image too large: {len(data)} bytes (max {self.max_size_bytes})", + ) + + return self._process_image_data(data) + + except httpx.TimeoutException: + return FetchedImage(success=False, error=f"Timeout fetching image: {url[:100]}") + except httpx.HTTPStatusError as e: + return FetchedImage(success=False, error=f"HTTP error {e.response.status_code}: {url[:100]}") + except httpx.RequestError as e: + return FetchedImage(success=False, error=f"Request error: {e}") + + def _fetch_base64(self, data_url: str) -> FetchedImage: + """Decode base64 data URL.""" + # Format: data:image/png;base64, + match = re.match(r"data:image/(\w+);base64,(.+)", data_url, re.DOTALL) + if not match: + return FetchedImage(success=False, error="Invalid base64 data URL format") + + format_hint = match.group(1).lower() + b64_data = match.group(2) + + try: + data = base64.b64decode(b64_data) + except Exception as e: + return FetchedImage(success=False, error=f"Base64 decode failed: {e}") + + if len(data) > self.max_size_bytes: + return FetchedImage( + success=False, + error=f"Image too large: {len(data)} bytes (max {self.max_size_bytes})", + ) + + return self._process_image_data(data, format_hint) + + def _fetch_from_storage(self, storage_url: str) -> FetchedImage: + """ + Fetch image from internal storage. + + Uses _run_async to execute async storage download in sync context. + """ + from app.core.storage import get_storage_service + + # Extract storage key: storage://path/to/file.png -> path/to/file.png + storage_key = storage_url.replace("storage://", "") + + try: + storage = get_storage_service() + buffer = io.BytesIO() + + # Run async download in sync context + _run_async(storage.download_file(storage_key, buffer)) + + data = buffer.getvalue() + if len(data) > self.max_size_bytes: + return FetchedImage( + success=False, + error=f"Image too large: {len(data)} bytes (max {self.max_size_bytes})", + ) + + return self._process_image_data(data) + + except Exception as e: + return FetchedImage(success=False, error=f"Storage fetch failed: {e}") + + def _fetch_by_image_id(self, image_id: str) -> FetchedImage: + """ + Handle image_id parameter. + + Image IDs should be resolved to storage URLs in the async layer (operations.py) + before reaching this sync code. If this method is called, it means the proper + flow wasn't followed. + + For backward compatibility, we return a clear error message. + """ + from uuid import UUID + + # Validate UUID format first + try: + UUID(image_id) + except ValueError: + return FetchedImage(success=False, error=f"Invalid image_id format: {image_id}") + + # Return error explaining the proper flow + return FetchedImage( + success=False, + error=( + f"image_id '{image_id}' was not resolved to a storage URL. " + "Image IDs must be resolved in the async layer before document generation. " + "Use knowledge_write tool which handles this automatically." + ), + ) + + def _process_image_data(self, data: bytes, format_hint: str | None = None) -> FetchedImage: + """ + Process raw image data: validate, get dimensions, optionally resize. + """ + try: + img = PILImage.open(io.BytesIO(data)) + + # Get actual format + img_format = (img.format or format_hint or "PNG").lower() + if img_format == "jpeg": + img_format = "jpg" + + width, height = img.size + + # Resize if too large + if width > self.max_dimension or height > self.max_dimension: + ratio = min(self.max_dimension / width, self.max_dimension / height) + new_width = int(width * ratio) + new_height = int(height * ratio) + img = img.resize((new_width, new_height), PILImage.Resampling.LANCZOS) + width, height = new_width, new_height + + # Re-encode + output = io.BytesIO() + save_format = "PNG" if img_format == "png" else "JPEG" + if img.mode in ("RGBA", "P") and save_format == "JPEG": + img = img.convert("RGB") + img.save(output, format=save_format) + data = output.getvalue() + + return FetchedImage( + success=True, + data=data, + format=img_format, + width=width, + height=height, + ) + + except Exception as e: + return FetchedImage(success=False, error=f"Image processing failed: {e}") + + +__all__ = ["ImageFetcher", "FetchedImage", "MAX_IMAGE_SIZE_BYTES", "MAX_IMAGE_DIMENSION", "DEFAULT_TIMEOUT"] diff --git a/service/app/mcp/document_spec.py b/service/app/tools/utils/documents/spec.py similarity index 79% rename from service/app/mcp/document_spec.py rename to service/app/tools/utils/documents/spec.py index 62fb5279..c87d84ef 100644 --- a/service/app/mcp/document_spec.py +++ b/service/app/tools/utils/documents/spec.py @@ -62,7 +62,11 @@ class ImageBlock(BaseModel): """An image block.""" type: Literal["image"] = "image" - url: str = Field(description="Image URL or base64 data URL") + url: str | None = Field(default=None, description="Image URL or base64 data URL") + image_id: str | None = Field( + default=None, + description="UUID of generated image from generate_image tool", + ) caption: str | None = Field(default=None, description="Optional image caption") width: int | None = Field(default=None, description="Optional width in points/pixels") @@ -186,11 +190,26 @@ class SlideSpec(BaseModel): notes: str | None = Field(default=None, description="Speaker notes") +class ImageSlideSpec(BaseModel): + """Specification for an image-only slide (full-bleed generated image).""" + + image_id: str = Field(description="UUID of the generated slide image from generate_image tool") + storage_url: str | None = Field( + default=None, + description="Resolved storage URL (set by async layer, not by user)", + ) + notes: str | None = Field(default=None, description="Speaker notes for this slide") + + class PresentationSpec(BaseModel): """ Production-ready presentation specification for PPTX generation. - Example: + Supports two modes: + - structured: Traditional slides with DSL content blocks (default) + - image_slides: Full-bleed AI-generated image slides + + Example (structured mode): ```json { "title": "Q4 Review", @@ -212,13 +231,39 @@ class PresentationSpec(BaseModel): ] } ``` + + Example (image_slides mode): + ```json + { + "mode": "image_slides", + "title": "Q4 Review", + "image_slides": [ + {"image_id": "abc-123-...", "notes": "Welcome everyone"}, + {"image_id": "def-456-...", "notes": "Revenue summary"} + ] + } + ``` """ title: str | None = Field(default=None, description="Presentation title") author: str | None = Field(default=None, description="Presentation author") + + # Mode selection + mode: Literal["structured", "image_slides"] = Field( + default="structured", + description="'structured' for DSL slides, 'image_slides' for full-bleed generated images", + ) + + # For structured mode slides: list[SlideSpec] = Field( default_factory=list, - description="List of slide specifications", + description="List of slide specifications (for structured mode)", + ) + + # For image_slides mode + image_slides: list[ImageSlideSpec] = Field( + default_factory=list, + description="List of image slide specifications (for image_slides mode)", ) @@ -234,5 +279,6 @@ class PresentationSpec(BaseModel): "SheetSpec", "SpreadsheetSpec", "SlideSpec", + "ImageSlideSpec", "PresentationSpec", ] diff --git a/service/app/utils/literature/__init__.py b/service/app/utils/literature/__init__.py new file mode 100644 index 00000000..c4dd14ba --- /dev/null +++ b/service/app/utils/literature/__init__.py @@ -0,0 +1,17 @@ +""" +Literature search utilities for multi-source academic literature retrieval +""" + +from .base_client import BaseLiteratureClient +from .doi_cleaner import deduplicate_by_doi, normalize_doi +from .models import LiteratureWork, SearchRequest +from .work_distributor import WorkDistributor + +__all__ = [ + "BaseLiteratureClient", + "normalize_doi", + "deduplicate_by_doi", + "SearchRequest", + "LiteratureWork", + "WorkDistributor", +] diff --git a/service/app/utils/literature/base_client.py b/service/app/utils/literature/base_client.py new file mode 100644 index 00000000..ba8a3db6 --- /dev/null +++ b/service/app/utils/literature/base_client.py @@ -0,0 +1,32 @@ +""" +Abstract base class for literature data source clients +""" + +from abc import ABC, abstractmethod + +from .models import LiteratureWork, SearchRequest + + +class BaseLiteratureClient(ABC): + """ + Base class for literature data source clients + + All data source implementations (OpenAlex, Semantic Scholar, PubMed, etc.) + should inherit from this class and implement the required methods. + """ + + @abstractmethod + async def search(self, request: SearchRequest) -> tuple[list[LiteratureWork], list[str]]: + """ + Execute search and return results in standard format + + Args: + request: Standardized search request + + Returns: + Tuple of (works, warnings) where warnings is a list of messages for LLM feedback + + Raises: + Exception: If search fails after retries + """ + pass diff --git a/service/app/utils/literature/doi_cleaner.py b/service/app/utils/literature/doi_cleaner.py new file mode 100644 index 00000000..0af35d6d --- /dev/null +++ b/service/app/utils/literature/doi_cleaner.py @@ -0,0 +1,116 @@ +""" +DOI normalization and deduplication utilities +""" + +import re +from typing import Protocol, TypeVar + + +class WorkWithDOI(Protocol): + """Protocol for objects with DOI and citation information""" + + doi: str | None + cited_by_count: int + publication_year: int | None + + +T = TypeVar("T", bound=WorkWithDOI) + + +def normalize_doi(doi: str | None) -> str | None: + """ + Normalize DOI format to standard form + + Removes common prefixes, validates format, and converts to lowercase. + DOI specification (ISO 26324) defines DOI matching as case-insensitive, + so lowercase conversion is safe and improves consistency. + + Args: + doi: DOI string in any common format + + Returns: + Normalized DOI (e.g., "10.1038/nature12345") or None if invalid + + Examples: + >>> normalize_doi("https://doi.org/10.1038/nature12345") + "10.1038/nature12345" + >>> normalize_doi("DOI: 10.1038/nature12345") + "10.1038/nature12345" + >>> normalize_doi("doi:10.1038/nature12345") + "10.1038/nature12345" + """ + if not doi: + return None + + doi = doi.strip().lower() + + # Remove common prefixes + doi = re.sub(r"^(https?://)?(dx\.)?doi\.org/", "", doi) + doi = re.sub(r"^doi:\s*", "", doi) + + # Validate format (10.xxxx/yyyy) + return doi if re.match(r"^10\.\d+/.+", doi) else None + + +def deduplicate_by_doi(works: list[T]) -> list[T]: + """ + Deduplicate works by DOI, keeping the highest priority version + + Priority rules: + 1. Works with DOI take priority over those without + 2. For same DOI, keep the one with higher citation count + 3. If citation count is equal, keep the most recently published + + Args: + works: List of LiteratureWork objects + + Returns: + Deduplicated list of works + + Examples: + >>> works = [ + ... LiteratureWork(doi="10.1038/1", cited_by_count=100, ...), + ... LiteratureWork(doi="10.1038/1", cited_by_count=50, ...), + ... LiteratureWork(doi=None, ...), + ... ] + >>> unique = deduplicate_by_doi(works) + >>> len(unique) + 2 + >>> unique[0].cited_by_count + 100 + """ + # Group by: with DOI vs without DOI + with_doi: dict[str, T] = {} + without_doi: list[T] = [] + + for work in works: + # Check if work has doi attribute + if not work.doi: + without_doi.append(work) + continue + + doi = normalize_doi(work.doi) + if not doi: + without_doi.append(work) + continue + + # If DOI already exists, compare priority + if doi in with_doi: + existing = with_doi[doi] + + # Higher citation count? + if work.cited_by_count > existing.cited_by_count: + with_doi[doi] = work + # Same citation count, more recent publication? + elif ( + work.cited_by_count == existing.cited_by_count + and work.publication_year + and existing.publication_year + and work.publication_year > existing.publication_year + ): + with_doi[doi] = work + else: + with_doi[doi] = work + + # Combine results: DOI works first, then non-DOI works + return list(with_doi.values()) + without_doi diff --git a/service/app/utils/literature/models.py b/service/app/utils/literature/models.py new file mode 100644 index 00000000..48aca79c --- /dev/null +++ b/service/app/utils/literature/models.py @@ -0,0 +1,82 @@ +""" +Shared data models for literature utilities +""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class SearchRequest: + """ + Standardized search request format for all data sources + + Attributes: + query: Search keywords (searches title, abstract, full text) + author: Author name (will be converted to author ID) + institution: Institution name (will be converted to institution ID) + source: Journal or conference name + year_from: Start year (inclusive) + year_to: End year (inclusive) + is_oa: Filter for open access only + work_type: Work type filter ("article", "review", "preprint", etc.) + language: Language code filter (e.g., "en", "zh", "fr") + is_retracted: Filter for retracted works (True to include only retracted, False to exclude) + has_abstract: Filter for works with abstracts + has_fulltext: Filter for works with full text available + sort_by: Sort method - "relevance", "cited_by_count", "publication_date" + max_results: Maximum number of results to return + data_sources: List of data sources to query (default: ["openalex"]) + """ + + query: str + author: str | None = None + institution: str | None = None + source: str | None = None + year_from: int | None = None + year_to: int | None = None + is_oa: bool | None = None + work_type: str | None = None + language: str | None = None + is_retracted: bool | None = None + has_abstract: bool | None = None + has_fulltext: bool | None = None + sort_by: str = "relevance" + max_results: int = 50 + data_sources: list[str] | None = None + + +@dataclass +class LiteratureWork: + """ + Standardized literature work format across all data sources + + Attributes: + id: Internal ID from the data source + doi: Digital Object Identifier (normalized format) + title: Work title + authors: List of author information [{"name": "...", "id": "..."}] + publication_year: Year of publication + cited_by_count: Number of citations + abstract: Abstract text + journal: Journal or venue name + is_oa: Whether open access + access_url: Best available access link (OA, landing page, or DOI) + primary_institution: First affiliated institution (if available) + source: Data source name ("openalex", "semantic_scholar", etc.) + raw_data: Original data from the source (for debugging) + """ + + id: str + doi: str | None + title: str + authors: list[dict[str, str | None]] + publication_year: int | None + cited_by_count: int + abstract: str | None + journal: str | None + is_oa: bool + source: str + access_url: str | None = None + primary_institution: str | None = None + raw_data: dict[str, Any] = field(default_factory=dict) diff --git a/service/app/utils/literature/openalex_client.py b/service/app/utils/literature/openalex_client.py new file mode 100644 index 00000000..0b08d01a --- /dev/null +++ b/service/app/utils/literature/openalex_client.py @@ -0,0 +1,611 @@ +""" +OpenAlex API client for literature search + +Implements the best practices from OpenAlex API guide: +- Two-step lookup for names (author/institution/source -> ID -> filter) +- Rate limiting with mailto parameter (10 req/s) +- Exponential backoff retry for errors +- Batch queries with pipe separator (up to 50 IDs) +- Maximum page size (200 per page) +- Abstract reconstruction from inverted index +""" + +import asyncio +import logging +import random +from typing import Any + +import httpx + +from .base_client import BaseLiteratureClient +from .doi_cleaner import normalize_doi +from .models import LiteratureWork, SearchRequest + +logger = logging.getLogger(__name__) + + +class _RateLimiter: + """ + Simple global rate limiter with optional concurrency guard. + + Enforces a minimum interval between request starts across all callers. + """ + + def __init__(self, rate_per_second: float, max_concurrency: int) -> None: + self._min_interval = 1.0 / rate_per_second if rate_per_second > 0 else 0.0 + self._lock = asyncio.Lock() + self._last_request = 0.0 + self._semaphore = asyncio.Semaphore(max_concurrency) + + async def __aenter__(self) -> None: + await self._semaphore.acquire() + await self._throttle() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> None: + self._semaphore.release() + + async def _throttle(self) -> None: + if self._min_interval <= 0: + return + + async with self._lock: + now = asyncio.get_running_loop().time() + wait_time = self._last_request + self._min_interval - now + if wait_time > 0: + await asyncio.sleep(wait_time) + self._last_request = asyncio.get_running_loop().time() + + +class OpenAlexClient(BaseLiteratureClient): + """ + OpenAlex API client + + Implements best practices from official API guide for LLMs: + https://docs.openalex.org/api-guide-for-llms + """ + + BASE_URL = "https://api.openalex.org" + MAX_PER_PAGE = 200 + MAX_RETRIES = 5 + TIMEOUT = 30.0 + + def __init__(self, email: str | None, rate_limit: int | None = None, timeout: float = 30.0) -> None: + """ + Initialize OpenAlex client + + Args: + email: Email for polite pool (10x rate limit increase). If None, use default pool. + rate_limit: Requests per second (default: 10 with email, 1 without email) + timeout: Request timeout in seconds (default: 30.0) + """ + self.email = email + self.rate_limit = rate_limit or (10 if self.email else 1) + max_concurrency = 10 if self.email else 1 + self.rate_limiter = _RateLimiter(rate_per_second=self.rate_limit, max_concurrency=max_concurrency) + self.client = httpx.AsyncClient(timeout=timeout) + pool_type = "polite" if self.email else "default" + logger.info( + "OpenAlex client initialized with pool=%s, email=%s, rate_limit=%s/s", + pool_type, + "" if self.email else None, + self.rate_limit, + ) + + @property + def pool_type(self) -> str: + """Return pool type string.""" + return "polite" if self.email else "default" + + async def search(self, request: SearchRequest) -> tuple[list[LiteratureWork], list[str]]: + """ + Execute search and return results in standard format + + Implementation steps: + 1. Convert author name -> author ID (if specified) + 2. Convert institution name -> institution ID (if specified) + 3. Convert journal name -> source ID (if specified) + 4. Build filter query + 5. Paginate through results + 6. Transform to standard format + + Args: + request: Standardized search request + + Returns: + Tuple of (works, warnings) + - works: List of literature works in standard format + - warnings: List of warning/info messages for LLM feedback + """ + logger.info( + "OpenAlex search [%s @ %s/s]: query=%r, max_results=%d", + self.pool_type, + self.rate_limit, + request.query, + request.max_results, + ) + + warnings: list[str] = [] + + # Step 1-3: Resolve IDs for names (two-step lookup pattern) + author_id = None + if request.author: + author_id, _success, msg = await self._resolve_author_id(request.author) + warnings.append(msg) + + institution_id = None + if request.institution: + institution_id, _success, msg = await self._resolve_institution_id(request.institution) + warnings.append(msg) + + source_id = None + if request.source: + source_id, _success, msg = await self._resolve_source_id(request.source) + warnings.append(msg) + + # Step 4: Build query parameters + params = self._build_query_params(request, author_id, institution_id, source_id) + + # Step 5: Fetch all pages + works = await self._fetch_all_pages(params, request.max_results) + + # Step 6: Transform to standard format + return [self._transform_work(w) for w in works], warnings + + def _build_query_params( + self, + request: SearchRequest, + author_id: str | None, + institution_id: str | None, + source_id: str | None, + ) -> dict[str, str]: + """ + Build OpenAlex query parameters + + Args: + request: Search request + author_id: Resolved author ID (if any) + institution_id: Resolved institution ID (if any) + source_id: Resolved source ID (if any) + + Returns: + Dictionary of query parameters + """ + params: dict[str, str] = { + "per-page": str(self.MAX_PER_PAGE), + } + + if self.email: + params["mailto"] = self.email + + # Search keywords + if request.query: + params["search"] = request.query + + # Build filters + filters: list[str] = [] + + if author_id: + filters.append(f"authorships.author.id:{author_id}") + + if institution_id: + filters.append(f"authorships.institutions.id:{institution_id}") + + if source_id: + filters.append(f"primary_location.source.id:{source_id}") + + # Year range + if request.year_from and request.year_to: + filters.append(f"publication_year:{request.year_from}-{request.year_to}") + elif request.year_from: + filters.append(f"publication_year:>{request.year_from - 1}") + elif request.year_to: + filters.append(f"publication_year:<{request.year_to + 1}") + + # Open access filter + if request.is_oa is not None: + filters.append(f"is_oa:{str(request.is_oa).lower()}") + + # Work type filter + if request.work_type: + filters.append(f"type:{request.work_type}") + + # Language filter + if request.language: + filters.append(f"language:{request.language}") + + # Retracted filter + if request.is_retracted is not None: + filters.append(f"is_retracted:{str(request.is_retracted).lower()}") + + # Abstract filter + if request.has_abstract is not None: + filters.append(f"has_abstract:{str(request.has_abstract).lower()}") + + # Fulltext filter + if request.has_fulltext is not None: + filters.append(f"has_fulltext:{str(request.has_fulltext).lower()}") + + if filters: + params["filter"] = ",".join(filters) + + # Sorting + sort_map = { + "relevance": None, # Default sorting by relevance + "cited_by_count": "cited_by_count:desc", + "publication_date": "publication_date:desc", + } + if sort := sort_map.get(request.sort_by): + params["sort"] = sort + + return params + + async def _resolve_author_id(self, author_name: str) -> tuple[str | None, bool, str]: + """ + Two-step lookup: author name -> author ID + + Args: + author_name: Author name to search + + Returns: + Tuple of (author_id, success, message) + - author_id: Author ID (e.g., "A5023888391") or None if not found + - success: Whether resolution was successful + - message: Status message for LLM feedback + """ + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/authors" + params: dict[str, str] = {"search": author_name} + if self.email: + params["mailto"] = self.email + response = await self._request_with_retry(url, params) + + if results := response.get("results", []): + # Return first result's ID in short format + author_id = results[0]["id"].split("/")[-1] + author_display = results[0].get("display_name", author_name) + logger.info("Resolved author %r -> %s", author_name, author_id) + return author_id, True, f"✓ Author resolved: '{author_name}' -> '{author_display}'" + else: + msg = ( + f"⚠️ Author '{author_name}' not found. " + f"Suggestions: (1) Try full name format like 'Smith, John' or 'John Smith', " + f"(2) Check spelling, (3) Try removing middle name/initial." + ) + logger.warning(msg) + return None, False, msg + except Exception as e: + msg = f"⚠️ Failed to resolve author '{author_name}': {e}" + logger.warning(msg) + return None, False, msg + + async def _resolve_institution_id(self, institution_name: str) -> tuple[str | None, bool, str]: + """ + Two-step lookup: institution name -> institution ID + + Args: + institution_name: Institution name to search + + Returns: + Tuple of (institution_id, success, message) + - institution_id: Institution ID (e.g., "I136199984") or None if not found + - success: Whether resolution was successful + - message: Status message for LLM feedback + """ + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/institutions" + params: dict[str, str] = {"search": institution_name} + if self.email: + params["mailto"] = self.email + response = await self._request_with_retry(url, params) + + if results := response.get("results", []): + institution_id = results[0]["id"].split("/")[-1] + inst_display = results[0].get("display_name", institution_name) + logger.info("Resolved institution %r -> %s", institution_name, institution_id) + return institution_id, True, f"✓ Institution resolved: '{institution_name}' -> '{inst_display}'" + else: + msg = ( + f"⚠️ Institution '{institution_name}' not found. " + f"Suggestions: (1) Use full official name (e.g., 'Harvard University' not 'Harvard'), " + f"(2) Try variations (e.g., 'MIT' vs 'Massachusetts Institute of Technology'), " + f"(3) Check spelling." + ) + logger.warning(msg) + return None, False, msg + except Exception as e: + msg = f"⚠️ Failed to resolve institution '{institution_name}': {e}" + logger.warning(msg) + return None, False, msg + + async def _resolve_source_id(self, source_name: str) -> tuple[str | None, bool, str]: + """ + Two-step lookup: source name -> source ID + + Args: + source_name: Journal/conference name to search + + Returns: + Tuple of (source_id, success, message) + - source_id: Source ID (e.g., "S137773608") or None if not found + - success: Whether resolution was successful + - message: Status message for LLM feedback + """ + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/sources" + params: dict[str, str] = {"search": source_name} + if self.email: + params["mailto"] = self.email + response = await self._request_with_retry(url, params) + + if results := response.get("results", []): + source_id = results[0]["id"].split("/")[-1] + source_display = results[0].get("display_name", source_name) + logger.info("Resolved source %r -> %s", source_name, source_id) + return source_id, True, f"✓ Source resolved: '{source_name}' -> '{source_display}'" + else: + msg = ( + f"⚠️ Source/Journal '{source_name}' not found. " + f"Suggestions: (1) Use full journal name (e.g., 'Nature' or 'Science'), " + f"(2) Try alternative names (e.g., 'JAMA' vs 'Journal of the American Medical Association'), " + f"(3) Check spelling." + ) + logger.warning(msg) + return None, False, msg + except Exception as e: + msg = f"⚠️ Failed to resolve source '{source_name}': {e}" + logger.warning(msg) + return None, False, msg + + async def _fetch_all_pages(self, params: dict[str, str], max_results: int) -> list[dict[str, Any]]: + """ + Paginate through all results up to max_results + + Args: + params: Base query parameters + max_results: Maximum number of results to fetch + + Returns: + List of work objects from API + """ + all_works: list[dict[str, Any]] = [] + page = 1 + + while len(all_works) < max_results: + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/works" + page_params = {**params, "page": str(page)} + response = await self._request_with_retry(url, page_params) + + works = response.get("results", []) + if not works: + break + + all_works.extend(works) + logger.info("Fetched page %d: %d works", page, len(works)) + + # Check if there are more pages + meta = response.get("meta", {}) + total_count = meta.get("count", 0) + if len(all_works) >= total_count: + break + + page += 1 + + except Exception as e: + logger.error(f"Error fetching page {page}: {e}") + break + + return all_works[:max_results] + + async def _request_with_retry(self, url: str, params: dict[str, str]) -> dict[str, Any]: + """ + HTTP request with exponential backoff retry + + Implements best practices: + - Retry on 403 (rate limit) with exponential backoff + - Retry on 5xx (server error) with exponential backoff + - Don't retry on 4xx (except 403) + - Retry on timeout + + Args: + url: Request URL + params: Query parameters + + Returns: + JSON response + + Raises: + Exception: If all retries fail + """ + for attempt in range(self.MAX_RETRIES): + try: + response = await self.client.get(url, params=params) + + if response.status_code == 200: + return response.json() + elif response.status_code == 429: + retry_after = self._parse_retry_after(response.headers.get("Retry-After")) + wait_time = retry_after if retry_after is not None else 2**attempt + wait_time = self._apply_jitter(wait_time) + logger.warning( + "Rate limited (429), waiting %.2fs... (attempt %d)", + wait_time, + attempt + 1, + ) + await asyncio.sleep(wait_time) + elif response.status_code == 403: + # Rate limited + wait_time = self._apply_jitter(2**attempt) + logger.warning( + "Rate limited (403), waiting %.2fs... (attempt %d)", + wait_time, + attempt + 1, + ) + await asyncio.sleep(wait_time) + elif response.status_code >= 500: + # Server error + wait_time = self._apply_jitter(2**attempt) + logger.warning( + "Server error (%d), waiting %.2fs... (attempt %d)", + response.status_code, + wait_time, + attempt + 1, + ) + await asyncio.sleep(wait_time) + else: + # Other error, don't retry + response.raise_for_status() + + except httpx.TimeoutException: + if attempt >= self.MAX_RETRIES - 1: + raise + wait_time = self._apply_jitter(2**attempt) + logger.warning("Timeout, retrying in %.2fs... (attempt %d)", wait_time, attempt + 1) + await asyncio.sleep(wait_time) + except Exception as e: + logger.error(f"Request failed: {e}") + if attempt >= self.MAX_RETRIES - 1: + raise + wait_time = self._apply_jitter(2**attempt) + await asyncio.sleep(wait_time) + + raise Exception(f"Failed after {self.MAX_RETRIES} retries") + + @staticmethod + def _apply_jitter(wait_time: float) -> float: + return wait_time + random.uniform(0.1, 0.9) + + @staticmethod + def _parse_retry_after(retry_after: str | None) -> float | None: + if not retry_after: + return None + try: + return float(retry_after) + except ValueError: + return None + + def _transform_work(self, work: dict[str, Any]) -> LiteratureWork: + """ + Transform OpenAlex work data to standard format + + Args: + work: Raw work object from OpenAlex API + + Returns: + Standardized LiteratureWork object + """ + # Extract authors + authors: list[dict[str, str | None]] = [] + for authorship in work.get("authorships", []): + author = authorship.get("author", {}) + authors.append( + { + "name": author.get("display_name", "Unknown"), + "id": author.get("id", "").split("/")[-1] if author.get("id") else None, + } + ) + + # Extract journal/source + journal = None + primary_location = work.get("primary_location") or {} + if source := primary_location.get("source"): + journal = source.get("display_name") + + # Extract open access info + oa_info = work.get("open_access", {}) + is_oa = oa_info.get("is_oa", False) + oa_url = oa_info.get("oa_url") + + # Extract abstract (reconstruct from inverted index) + abstract = self._reconstruct_abstract(work.get("abstract_inverted_index")) + + # Extract DOI (remove prefix) + doi = None + if doi_raw := work.get("doi"): + doi = normalize_doi(doi_raw) + + # Extract primary institution (first available) + primary_institution = None + for authorship in work.get("authorships", []): + institutions = authorship.get("institutions", []) + if institutions: + primary_institution = institutions[0].get("display_name") + if primary_institution: + break + + # Build best access URL (OA first, then landing page, then DOI) + access_url = oa_url + if not access_url: + access_url = primary_location.get("landing_page_url") or primary_location.get("pdf_url") + if not access_url and doi: + access_url = f"https://doi.org/{doi}" + + return LiteratureWork( + id=work["id"].split("/")[-1], + doi=doi, + title=work.get("title", "Untitled"), + authors=authors, + publication_year=work.get("publication_year"), + cited_by_count=work.get("cited_by_count", 0), + abstract=abstract, + journal=journal, + is_oa=is_oa, + source="openalex", + access_url=access_url, + primary_institution=primary_institution, + raw_data=work, + ) + + def _reconstruct_abstract(self, inverted_index: dict[str, list[int]] | None) -> str | None: + """ + Reconstruct abstract from inverted index + + OpenAlex stores abstracts as inverted index for efficiency. + Format: {"word": [position1, position2, ...], ...} + + Args: + inverted_index: Inverted index from OpenAlex + + Returns: + Reconstructed abstract text or None + + Examples: + >>> index = {"Hello": [0], "world": [1], "!": [2]} + >>> _reconstruct_abstract(index) + "Hello world !" + """ + if not inverted_index: + return None + + # Expand inverted index to (position, word) pairs + word_positions: list[tuple[int, str]] = [ + (pos, word) for word, positions in inverted_index.items() for pos in positions + ] + + # Sort by position and join + word_positions.sort() + return " ".join(word for _, word in word_positions) + + async def close(self) -> None: + """Close the HTTP client""" + await self.client.aclose() + + async def __aenter__(self) -> "OpenAlexClient": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> None: + await self.close() diff --git a/service/app/utils/literature/work_distributor.py b/service/app/utils/literature/work_distributor.py new file mode 100644 index 00000000..8cdaac59 --- /dev/null +++ b/service/app/utils/literature/work_distributor.py @@ -0,0 +1,164 @@ +""" +Work distributor for coordinating multiple literature data sources +""" + +import inspect +import logging +from typing import Any + +from .doi_cleaner import deduplicate_by_doi +from .models import LiteratureWork, SearchRequest + +logger = logging.getLogger(__name__) + + +class WorkDistributor: + """ + Distribute search requests to multiple literature data sources + and aggregate results + """ + + def __init__(self, openalex_email: str | None = None) -> None: + """ + Initialize distributor with available clients + + Args: + openalex_email: Email for OpenAlex polite pool (required for OpenAlex) + """ + self.clients: dict[str, Any] = {} + self.openalex_email = openalex_email + self._register_clients() + + def _register_clients(self) -> None: + """Register available data source clients""" + # Import here to avoid circular dependencies + try: + from .openalex_client import OpenAlexClient + + self.clients["openalex"] = OpenAlexClient(email=self.openalex_email) + logger.info("Registered OpenAlex client") + except ImportError as e: + logger.warning(f"Failed to register OpenAlex client: {e}") + + # Future: Add more clients + # from .semantic_scholar_client import SemanticScholarClient + # self.clients["semantic_scholar"] = SemanticScholarClient() + + async def close(self) -> None: + """Close any underlying HTTP clients""" + for client in self.clients.values(): + close_method = getattr(client, "close", None) + if callable(close_method): + result = close_method() + if inspect.isawaitable(result): + await result + + async def __aenter__(self) -> "WorkDistributor": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> None: + await self.close() + + async def search(self, request: SearchRequest) -> dict[str, Any]: + """ + Execute search across multiple data sources and aggregate results + + Args: + request: Standardized search request + + Returns: + Dictionary containing: + - total_count: Total number of works fetched (before dedup) + - unique_count: Number of unique works (after dedup) + - sources: Dict of source name -> count + - works: List of deduplicated LiteratureWork objects + - warnings: List of warning/info messages for LLM feedback + + Examples: + >>> distributor = WorkDistributor() + >>> request = SearchRequest(query="machine learning", max_results=50) + >>> result = await distributor.search(request) + >>> print(f"Found {result['unique_count']} unique works") + """ + # Clamp max_results to 50/1000 with warnings + all_warnings: list[str] = [] + if request.max_results < 1: + all_warnings.append("⚠️ max_results < 1; using default 50") + request.max_results = 50 + elif request.max_results > 1000: + all_warnings.append("⚠️ max_results > 1000; using 1000") + request.max_results = 1000 + + # Determine which data sources to use + sources = request.data_sources or ["openalex"] + unknown_sources = [source_name for source_name in sources if source_name not in self.clients] + if unknown_sources: + all_warnings.append("⚠️ Unknown data_sources ignored: " + ", ".join(sorted(set(unknown_sources)))) + + # Collect works and warnings from all sources + all_works: list[LiteratureWork] = [] + source_counts: dict[str, int] = {} + + for source_name in sources: + if client := self.clients.get(source_name): + try: + logger.info("Fetching from %s...", source_name) + works, warnings_data = await client.search(request) + all_warnings.extend(warnings_data) + + all_works.extend(works) + source_counts[source_name] = len(works) + logger.info("Fetched %d works from %s", len(works), source_name) + except Exception as e: + logger.error(f"Error fetching from {source_name}: {e}", exc_info=True) + source_counts[source_name] = 0 + all_warnings.append(f"⚠️ Error fetching from {source_name}: {str(e)}") + else: + logger.warning(f"Data source '{source_name}' not available") + + # Deduplicate by DOI + logger.info("Deduplicating %d works...", len(all_works)) + unique_works = deduplicate_by_doi(all_works) + logger.info("After deduplication: %d unique works", len(unique_works)) + + # Sort results + unique_works = self._sort_works(unique_works, request.sort_by) + + # Limit to max_results + unique_works = unique_works[: request.max_results] + + return { + "total_count": len(all_works), + "unique_count": len(unique_works), + "sources": source_counts, + "works": unique_works, + "warnings": all_warnings, + } + + def _sort_works(self, works: list[LiteratureWork], sort_by: str) -> list[LiteratureWork]: + """ + Sort works by specified criteria + + Args: + works: List of works to sort + sort_by: Sort method - "relevance", "cited_by_count", "publication_date" + + Returns: + Sorted list of works + """ + if sort_by == "cited_by_count": + return sorted(works, key=lambda w: w.cited_by_count, reverse=True) + elif sort_by == "publication_date": + return sorted( + works, + key=lambda w: w.publication_year or float("-inf"), + reverse=True, + ) + else: # relevance or default + # For relevance, keep original order (API returns by relevance) + return works diff --git a/service/pyproject.toml b/service/pyproject.toml index a36f8b25..61d80424 100644 --- a/service/pyproject.toml +++ b/service/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "pytesseract>=0.3.13", "pillow>=12.0.0", "celery-types>=0.24.0", + "trafilatura>=1.12.0", ] [dependency-groups] diff --git a/service/tests/unit/agents/test_factory.py b/service/tests/unit/agents/test_factory.py new file mode 100644 index 00000000..85bfc6da --- /dev/null +++ b/service/tests/unit/agents/test_factory.py @@ -0,0 +1,132 @@ +"""Tests for agent factory module.""" + +from app.agents.factory import _inject_system_prompt + + +class TestInjectSystemPrompt: + """Test _inject_system_prompt function.""" + + def test_inject_into_llm_node(self) -> None: + """Test system prompt injection into LLM node.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "agent", + "type": "llm", + "llm_config": { + "prompt_template": "Default prompt", + "tools_enabled": True, + }, + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # Original should not be mutated + assert config_dict["nodes"][0]["llm_config"]["prompt_template"] == "Default prompt" + + # Result should have the new prompt + assert result["nodes"][0]["llm_config"]["prompt_template"] == "Custom system prompt" + + def test_inject_into_component_node(self) -> None: + """Test system prompt injection into react component node.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "agent", + "type": "component", + "component_config": { + "component_ref": {"key": "react"}, + }, + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # Original should not be mutated + assert "config_overrides" not in config_dict["nodes"][0]["component_config"] + + # Result should have config_overrides with system_prompt + assert result["nodes"][0]["component_config"]["config_overrides"]["system_prompt"] == "Custom system prompt" + + def test_inject_only_into_first_matching_node(self) -> None: + """Test that system prompt is only injected into the first matching node.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "agent1", + "type": "llm", + "llm_config": { + "prompt_template": "Prompt 1", + }, + }, + { + "id": "agent2", + "type": "llm", + "llm_config": { + "prompt_template": "Prompt 2", + }, + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # First node should be updated + assert result["nodes"][0]["llm_config"]["prompt_template"] == "Custom system prompt" + # Second node should remain unchanged + assert result["nodes"][1]["llm_config"]["prompt_template"] == "Prompt 2" + + def test_llm_node_takes_precedence_over_non_react_component(self) -> None: + """Test that LLM nodes are preferred over non-react components.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "other", + "type": "component", + "component_config": { + "component_ref": {"key": "other_component"}, + }, + }, + { + "id": "agent", + "type": "llm", + "llm_config": { + "prompt_template": "Default", + }, + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # LLM node should be updated (other component is not react) + assert result["nodes"][1]["llm_config"]["prompt_template"] == "Custom system prompt" + + def test_no_matching_nodes(self) -> None: + """Test graceful handling when no matching nodes exist.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "transform", + "type": "transform", + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # Should return config unchanged + assert result == config_dict diff --git a/service/tests/unit/agents/test_graph_builder.py b/service/tests/unit/agents/test_graph_builder.py index cec6f1c3..9af67348 100644 --- a/service/tests/unit/agents/test_graph_builder.py +++ b/service/tests/unit/agents/test_graph_builder.py @@ -1,8 +1,10 @@ """Tests for graph_builder module.""" -from unittest.mock import AsyncMock +from typing import Any +from unittest.mock import AsyncMock, MagicMock import pytest +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from pydantic import BaseModel from app.agents.graph_builder import GraphBuilder, build_state_class @@ -112,3 +114,173 @@ def test_build_simple_graph(self) -> None: ) graph = builder.build() assert graph is not None + + +class TestPromptTemplateAsSystemMessage: + """Test that prompt_template is prepended as SystemMessage.""" + + @pytest.mark.asyncio + async def test_prompt_template_prepends_system_message(self) -> None: + """Test prompt_template is prepended as SystemMessage, not appended as HumanMessage.""" + config = GraphConfig( + nodes=[ + GraphNodeConfig( + id="agent", + name="Agent", + type=NodeType.LLM, + llm_config=LLMNodeConfig( + prompt_template="You are a helpful assistant.", + tools_enabled=False, + ), + ), + ], + edges=[ + GraphEdgeConfig(from_node="START", to_node="agent"), + GraphEdgeConfig(from_node="agent", to_node="END"), + ], + entry_point="agent", + ) + + # Create mock LLM that captures the messages it receives + captured_messages: list[BaseMessage] = [] + + async def mock_llm_factory(model: str | None = None, temperature: float | None = None) -> Any: + mock_llm = MagicMock() + + async def capture_invoke(messages: list[BaseMessage]) -> AIMessage: + captured_messages.clear() + captured_messages.extend(messages) + return AIMessage(content="Response") + + mock_llm.ainvoke = capture_invoke + return mock_llm + + builder = GraphBuilder( + config=config, + llm_factory=mock_llm_factory, + tool_registry={}, + ) + graph = await builder.build() + + # Invoke with a user message + initial_state: dict[str, list[BaseMessage]] = { + "messages": [HumanMessage(content="Hello")], + } + await graph.ainvoke(initial_state) # type: ignore[arg-type] # type: ignore[arg-type] + + # Verify SystemMessage is first, not HumanMessage at the end + assert len(captured_messages) >= 2 + assert isinstance(captured_messages[0], SystemMessage) + assert captured_messages[0].content == "You are a helpful assistant." + assert isinstance(captured_messages[1], HumanMessage) + assert captured_messages[1].content == "Hello" + + @pytest.mark.asyncio + async def test_no_prompt_template_passes_messages_unchanged(self) -> None: + """Test that empty prompt_template doesn't add any system message.""" + config = GraphConfig( + nodes=[ + GraphNodeConfig( + id="agent", + name="Agent", + type=NodeType.LLM, + llm_config=LLMNodeConfig( + prompt_template="", # Empty + tools_enabled=False, + ), + ), + ], + edges=[ + GraphEdgeConfig(from_node="START", to_node="agent"), + GraphEdgeConfig(from_node="agent", to_node="END"), + ], + entry_point="agent", + ) + + captured_messages: list[BaseMessage] = [] + + async def mock_llm_factory(model: str | None = None, temperature: float | None = None) -> Any: + mock_llm = MagicMock() + + async def capture_invoke(messages: list[BaseMessage]) -> AIMessage: + captured_messages.clear() + captured_messages.extend(messages) + return AIMessage(content="Response") + + mock_llm.ainvoke = capture_invoke + return mock_llm + + builder = GraphBuilder( + config=config, + llm_factory=mock_llm_factory, + tool_registry={}, + ) + graph = await builder.build() + + initial_state: dict[str, list[BaseMessage]] = { + "messages": [HumanMessage(content="Hello")], + } + await graph.ainvoke(initial_state) # type: ignore[arg-type] + + # Should only have the original HumanMessage + assert len(captured_messages) == 1 + assert isinstance(captured_messages[0], HumanMessage) + assert captured_messages[0].content == "Hello" + + @pytest.mark.asyncio + async def test_existing_system_message_replaced(self) -> None: + """Test that existing SystemMessage in messages is replaced.""" + config = GraphConfig( + nodes=[ + GraphNodeConfig( + id="agent", + name="Agent", + type=NodeType.LLM, + llm_config=LLMNodeConfig( + prompt_template="New system prompt", + tools_enabled=False, + ), + ), + ], + edges=[ + GraphEdgeConfig(from_node="START", to_node="agent"), + GraphEdgeConfig(from_node="agent", to_node="END"), + ], + entry_point="agent", + ) + + captured_messages: list[BaseMessage] = [] + + async def mock_llm_factory(model: str | None = None, temperature: float | None = None) -> Any: + mock_llm = MagicMock() + + async def capture_invoke(messages: list[BaseMessage]) -> AIMessage: + captured_messages.clear() + captured_messages.extend(messages) + return AIMessage(content="Response") + + mock_llm.ainvoke = capture_invoke + return mock_llm + + builder = GraphBuilder( + config=config, + llm_factory=mock_llm_factory, + tool_registry={}, + ) + graph = await builder.build() + + # Input with existing SystemMessage + initial_state: dict[str, list[BaseMessage]] = { + "messages": [ + SystemMessage(content="Old system prompt"), + HumanMessage(content="Hello"), + ], + } + await graph.ainvoke(initial_state) # type: ignore[arg-type] + + # Should have new SystemMessage first, no duplicates + system_messages = [m for m in captured_messages if isinstance(m, SystemMessage)] + assert len(system_messages) == 1 + assert system_messages[0].content == "New system prompt" + assert isinstance(captured_messages[0], SystemMessage) + assert isinstance(captured_messages[1], HumanMessage) diff --git a/service/tests/unit/handler/mcp/test_file_handlers.py b/service/tests/unit/handler/mcp/test_file_handlers.py deleted file mode 100644 index 940382f1..00000000 --- a/service/tests/unit/handler/mcp/test_file_handlers.py +++ /dev/null @@ -1,374 +0,0 @@ -""" -Tests for file handlers. -""" - -import json -from unittest.mock import MagicMock, patch - -import pytest - -from app.mcp.document_spec import ( - DocumentSpec, - HeadingBlock, - ListBlock, - PresentationSpec, - SheetSpec, - SlideSpec, - SpreadsheetSpec, - TableBlock, - TextBlock, -) -from app.mcp.file_handlers import ( - DocxFileHandler, - ExcelFileHandler, - FileHandlerFactory, - HtmlFileHandler, - ImageFileHandler, - JsonFileHandler, - PdfFileHandler, - PptxFileHandler, - TextFileHandler, - XmlFileHandler, - YamlFileHandler, -) - - -class TestFileHandlerFactory: - def test_get_handler(self) -> None: - # Existing handlers - assert isinstance(FileHandlerFactory.get_handler("test.pdf"), PdfFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.docx"), DocxFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.doc"), DocxFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.xlsx"), ExcelFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.xls"), ExcelFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.pptx"), PptxFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.ppt"), PptxFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.txt"), TextFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.csv"), TextFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.py"), TextFileHandler) - - def test_get_handler_new_types(self) -> None: - # New handlers - assert isinstance(FileHandlerFactory.get_handler("test.html"), HtmlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.htm"), HtmlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.json"), JsonFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.yaml"), YamlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.yml"), YamlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.xml"), XmlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.png"), ImageFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.jpg"), ImageFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.jpeg"), ImageFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.gif"), ImageFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.webp"), ImageFileHandler) - - -class TestTextFileHandler: - def test_read_write(self) -> None: - handler = TextFileHandler() - content = "Hello, World!" - - # Write - bytes_content = handler.create_content(content) - assert isinstance(bytes_content, bytes) - assert bytes_content == b"Hello, World!" - - # Read - read_content = handler.read_content(bytes_content) - assert read_content == content - - def test_read_image_fail(self) -> None: - handler = TextFileHandler() - with pytest.raises(ValueError): - handler.read_content(b"test", mode="image") - - -class TestHtmlFileHandler: - def test_read_html(self) -> None: - handler = HtmlFileHandler() - html = b"

Title

Content

" - content = handler.read_content(html) - assert isinstance(content, str) - # Should extract text from HTML - assert "Title" in content or "Content" in content - - def test_read_html_strips_scripts(self) -> None: - handler = HtmlFileHandler() - html = b"Safe content" - content = handler.read_content(html) - assert "alert" not in content - assert "Safe content" in content - - def test_create_html(self) -> None: - handler = HtmlFileHandler() - content = handler.create_content("Hello\n\nWorld") - assert b"" in content - assert b"" in content - assert b"Hello" in content - - def test_read_image_fail(self) -> None: - handler = HtmlFileHandler() - with pytest.raises(ValueError): - handler.read_content(b"", mode="image") - - -class TestJsonFileHandler: - def test_read_json(self) -> None: - handler = JsonFileHandler() - data = {"key": "value", "nested": {"a": 1}} - json_bytes = json.dumps(data).encode() - - content = handler.read_content(json_bytes) - assert "key" in content - assert "value" in content - - def test_read_invalid_json(self) -> None: - handler = JsonFileHandler() - content = handler.read_content(b"not valid json") - assert content == "not valid json" - - def test_create_json_valid(self) -> None: - handler = JsonFileHandler() - data = '{"key": "value"}' - result = handler.create_content(data) - parsed = json.loads(result) - assert parsed["key"] == "value" - - def test_create_json_invalid_wraps(self) -> None: - handler = JsonFileHandler() - result = handler.create_content("plain text") - parsed = json.loads(result) - assert "content" in parsed - assert parsed["content"] == "plain text" - - -class TestYamlFileHandler: - def test_read_yaml(self) -> None: - handler = YamlFileHandler() - yaml_content = b"key: value\nnested:\n a: 1" - content = handler.read_content(yaml_content) - assert "key" in content - assert "value" in content - - def test_create_yaml_from_json(self) -> None: - handler = YamlFileHandler() - json_input = '{"key": "value"}' - result = handler.create_content(json_input) - assert b"key: value" in result - - -class TestXmlFileHandler: - def test_read_xml(self) -> None: - handler = XmlFileHandler() - xml = b"Hello" - content = handler.read_content(xml) - assert "Hello" in content - assert "item" in content - - def test_create_xml(self) -> None: - handler = XmlFileHandler() - result = handler.create_content("Test content") - assert b"" in result - assert b"Test content" in result - - def test_read_image_fail(self) -> None: - handler = XmlFileHandler() - with pytest.raises(ValueError): - handler.read_content(b"", mode="image") - - -class TestImageFileHandler: - def test_detect_format_png(self) -> None: - handler = ImageFileHandler() - png_magic = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 - assert handler._detect_format(png_magic) == "png" - - def test_detect_format_jpeg(self) -> None: - handler = ImageFileHandler() - jpeg_magic = b"\xff\xd8" + b"\x00" * 100 - assert handler._detect_format(jpeg_magic) == "jpeg" - - def test_detect_format_gif(self) -> None: - handler = ImageFileHandler() - gif_magic = b"GIF89a" + b"\x00" * 100 - assert handler._detect_format(gif_magic) == "gif" - - def test_create_raises_error(self) -> None: - handler = ImageFileHandler() - with pytest.raises(ValueError, match="Cannot create image"): - handler.create_content("text") - - -# Only mock external deps for complex handlers if they are not installed in test env -# But for now we assume we might need to mock them to run this in strict CI envs -# where deps might be missing during dev. - - -@patch("fitz.open") -@patch("fitz.Matrix") -class TestPdfFileHandler: - def test_read_text(self, mock_matrix: MagicMock, mock_open: MagicMock) -> None: - handler = PdfFileHandler() - mock_doc = MagicMock() - mock_page = MagicMock() - mock_page.get_text.return_value = "Page text" - mock_page.find_tables.return_value = MagicMock(tables=[]) - mock_doc.__iter__.return_value = [mock_page] - mock_open.return_value = mock_doc - - content = handler.read_content(b"pdf_bytes", mode="text") - assert content == "Page text" - mock_open.assert_called_with(stream=b"pdf_bytes", filetype="pdf") - - def test_write_plain_text(self, mock_matrix: MagicMock, mock_open: MagicMock) -> None: - handler = PdfFileHandler() - # For plain text, it uses reportlab, not fitz - result = handler.create_content("Some text") - assert isinstance(result, bytes) - # PDF magic bytes - assert result[:4] == b"%PDF" - - -@patch("docx.Document") -class TestDocxFileHandler: - def test_read(self, mock_document_cls: MagicMock) -> None: - handler = DocxFileHandler() - mock_doc = MagicMock() - mock_element = MagicMock() - mock_element.tag = "p" - mock_element.iter.return_value = [MagicMock(text="Para 1")] - mock_doc.element.body = [mock_element] - mock_document_cls.return_value = mock_doc - - content = handler.read_content(b"docx_bytes") - assert "Para 1" in content - - def test_write_plain_text(self, mock_document_cls: MagicMock) -> None: - handler = DocxFileHandler() - mock_doc = MagicMock() - mock_document_cls.return_value = mock_doc - - handler.create_content("Line 1\nLine 2") - - assert mock_doc.add_paragraph.call_count == 2 - mock_doc.save.assert_called() - - -@patch("openpyxl.Workbook") -@patch("openpyxl.load_workbook") -class TestExcelFileHandler: - def test_read(self, mock_load_workbook: MagicMock, mock_workbook: MagicMock) -> None: - handler = ExcelFileHandler() - mock_wb = MagicMock() - mock_ws = MagicMock() - mock_wb.sheetnames = ["Sheet1"] - mock_wb.__getitem__.return_value = mock_ws - mock_ws.iter_rows.return_value = [("A", "B")] - mock_load_workbook.return_value = mock_wb - - content = handler.read_content(b"xlsx_bytes") - assert "Sheet1" in content - assert "A\tB" in content - - def test_write_csv(self, mock_load_workbook: MagicMock, mock_workbook: MagicMock) -> None: - handler = ExcelFileHandler() - mock_wb = MagicMock() - mock_ws = MagicMock() - mock_wb.active = mock_ws - mock_workbook.return_value = mock_wb - - handler.create_content("A,B\nC,D") - - assert mock_ws.append.call_count == 2 - mock_wb.save.assert_called() - - -@patch("pptx.Presentation") -class TestPptxFileHandler: - def test_read(self, mock_presentation: MagicMock) -> None: - handler = PptxFileHandler() - mock_prs = MagicMock() - mock_slide = MagicMock() - mock_shape = MagicMock() - mock_shape.text = "Slide Text" - mock_slide.shapes = [mock_shape] - mock_slide.has_notes_slide = False - mock_prs.slides = [mock_slide] - mock_presentation.return_value = mock_prs - - content = handler.read_content(b"pptx_bytes") - assert "Slide Text" in content - - def test_write_plain_text(self, mock_presentation: MagicMock) -> None: - handler = PptxFileHandler() - mock_prs = MagicMock() - mock_slide = MagicMock() - mock_prs.slides.add_slide.return_value = mock_slide - mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] - mock_presentation.return_value = mock_prs - - handler.create_content("Title\nBody") - - mock_prs.slides.add_slide.assert_called() - mock_prs.save.assert_called() - - -# Document Spec Tests - - -class TestDocumentSpec: - def test_create_document_spec(self) -> None: - spec = DocumentSpec( - title="Test Doc", - author="Test Author", - content=[ - HeadingBlock(content="Chapter 1", level=1), - TextBlock(content="Some text here"), - ListBlock(items=["Item 1", "Item 2"], ordered=False), - TableBlock(headers=["A", "B"], rows=[["1", "2"]]), - ], - ) - assert spec.title == "Test Doc" - assert len(spec.content) == 4 - assert spec.content[0].type == "heading" - - def test_document_spec_json_roundtrip(self) -> None: - spec = DocumentSpec( - title="Test", - content=[TextBlock(content="Hello")], - ) - json_str = spec.model_dump_json() - parsed = DocumentSpec.model_validate_json(json_str) - assert parsed.title == spec.title - - -class TestSpreadsheetSpec: - def test_create_spreadsheet_spec(self) -> None: - spec = SpreadsheetSpec( - sheets=[ - SheetSpec( - name="Data", - headers=["Name", "Value"], - data=[["A", 1], ["B", 2]], - ) - ] - ) - assert len(spec.sheets) == 1 - assert spec.sheets[0].name == "Data" - - -class TestPresentationSpec: - def test_create_presentation_spec(self) -> None: - spec = PresentationSpec( - title="My Presentation", - slides=[ - SlideSpec(layout="title", title="Welcome", subtitle="Intro"), - SlideSpec( - layout="title_content", - title="Slide 2", - content=[ListBlock(items=["Point 1", "Point 2"])], - ), - ], - ) - assert len(spec.slides) == 2 - assert spec.slides[0].layout == "title" diff --git a/service/tests/unit/test_core/test_consume_strategy.py b/service/tests/unit/test_core/test_consume_strategy.py index 41a43ddd..b2c1c321 100644 --- a/service/tests/unit/test_core/test_consume_strategy.py +++ b/service/tests/unit/test_core/test_consume_strategy.py @@ -21,7 +21,7 @@ def test_default_values(self) -> None: assert context.output_tokens == 0 assert context.total_tokens == 0 assert context.content_length == 0 - assert context.generated_files_count == 0 + assert context.tool_costs == 0 def test_with_values(self) -> None: """Test ConsumptionContext with custom values.""" @@ -31,14 +31,14 @@ def test_with_values(self) -> None: output_tokens=500, total_tokens=1500, content_length=5000, - generated_files_count=2, + tool_costs=15, ) assert context.model_tier == ModelTier.PRO assert context.input_tokens == 1000 assert context.output_tokens == 500 assert context.total_tokens == 1500 assert context.content_length == 5000 - assert context.generated_files_count == 2 + assert context.tool_costs == 15 class TestTierBasedConsumptionStrategy: @@ -53,7 +53,6 @@ def test_lite_tier_is_free(self) -> None: output_tokens=5000, total_tokens=15000, content_length=50000, - generated_files_count=5, ) result = strategy.calculate(context) @@ -71,7 +70,6 @@ def test_standard_tier_base_multiplier(self) -> None: output_tokens=1000, total_tokens=2000, content_length=1000, - generated_files_count=0, ) result = strategy.calculate(context) @@ -92,7 +90,6 @@ def test_pro_tier_multiplier(self) -> None: output_tokens=1000, total_tokens=2000, content_length=1000, - generated_files_count=0, ) result = strategy.calculate(context) @@ -113,7 +110,6 @@ def test_ultra_tier_multiplier(self) -> None: output_tokens=1000, total_tokens=2000, content_length=1000, - generated_files_count=0, ) result = strategy.calculate(context) @@ -125,8 +121,8 @@ def test_ultra_tier_multiplier(self) -> None: assert result.amount == expected assert result.breakdown["tier_rate"] == 6.8 - def test_file_generation_cost(self) -> None: - """Test that file generation cost is included.""" + def test_tool_costs(self) -> None: + """Test that tool costs are included in calculation.""" strategy = TierBasedConsumptionStrategy() context = ConsumptionContext( model_tier=ModelTier.STANDARD, @@ -134,13 +130,14 @@ def test_file_generation_cost(self) -> None: output_tokens=0, total_tokens=0, content_length=0, - generated_files_count=2, + tool_costs=20, ) result = strategy.calculate(context) + # base_cost(1) + tool_costs(20) = 21 expected = int((1 + 20) * 1.0) assert result.amount == expected - assert result.breakdown["file_cost"] == 20 + assert result.breakdown["tool_costs"] == 20 def test_no_tier_defaults_to_1(self) -> None: """Test that None tier defaults to rate 1.0.""" @@ -151,7 +148,6 @@ def test_no_tier_defaults_to_1(self) -> None: output_tokens=1000, total_tokens=2000, content_length=1000, - generated_files_count=0, ) result = strategy.calculate(context) @@ -171,13 +167,13 @@ def test_breakdown_contains_all_fields(self) -> None: output_tokens=500, total_tokens=1500, content_length=1000, - generated_files_count=1, + tool_costs=10, ) result = strategy.calculate(context) assert "base_cost" in result.breakdown assert "token_cost" in result.breakdown - assert "file_cost" in result.breakdown + assert "tool_costs" in result.breakdown assert "tier_rate" in result.breakdown assert "tier" in result.breakdown assert "pre_multiplier_total" in result.breakdown @@ -205,7 +201,6 @@ def test_calculate_pro_tier(self) -> None: input_tokens=1000, output_tokens=1000, total_tokens=2000, - generated_files_count=0, ) result = ConsumptionCalculator.calculate(context) @@ -222,7 +217,6 @@ def test_breakdown_is_json_serializable(self) -> None: input_tokens=1000, output_tokens=500, total_tokens=1500, - generated_files_count=1, ) result = ConsumptionCalculator.calculate(context) diff --git a/service/tests/unit/test_utils/__init__.py b/service/tests/unit/test_literature/__init__.py similarity index 100% rename from service/tests/unit/test_utils/__init__.py rename to service/tests/unit/test_literature/__init__.py diff --git a/service/tests/unit/test_literature/test_base_client.py b/service/tests/unit/test_literature/test_base_client.py new file mode 100644 index 00000000..95d03b2a --- /dev/null +++ b/service/tests/unit/test_literature/test_base_client.py @@ -0,0 +1,300 @@ +"""Tests for base literature client.""" + +import pytest + +from app.utils.literature.base_client import BaseLiteratureClient +from app.utils.literature.models import LiteratureWork, SearchRequest + + +class ConcreteClient(BaseLiteratureClient): + """Concrete implementation of BaseLiteratureClient for testing.""" + + async def search(self, request: SearchRequest) -> tuple[list[LiteratureWork], list[str]]: + """Dummy search implementation.""" + return [], [] + + +class TestBaseLiteratureClientProtocol: + """Test BaseLiteratureClient protocol and abstract methods.""" + + def test_cannot_instantiate_abstract_class(self) -> None: + """Test that BaseLiteratureClient cannot be instantiated directly.""" + with pytest.raises(TypeError): + BaseLiteratureClient() # type: ignore + + def test_concrete_implementation(self) -> None: + """Test that concrete implementation can be instantiated.""" + client = ConcreteClient() + assert client is not None + assert isinstance(client, BaseLiteratureClient) + + @pytest.mark.asyncio + async def test_search_method_required(self) -> None: + """Test that search method is required.""" + request = SearchRequest(query="test") + result = await ConcreteClient().search(request) + assert result == ([], []) + + +class TestSearchRequestDataclass: + """Test SearchRequest data model.""" + + def test_search_request_required_field(self) -> None: + """Test SearchRequest with required query field.""" + request = SearchRequest(query="machine learning") + assert request.query == "machine learning" + + def test_search_request_default_values(self) -> None: + """Test SearchRequest default values.""" + request = SearchRequest(query="test") + assert request.query == "test" + assert request.author is None + assert request.institution is None + assert request.source is None + assert request.year_from is None + assert request.year_to is None + assert request.is_oa is None + assert request.work_type is None + assert request.language is None + assert request.is_retracted is None + assert request.has_abstract is None + assert request.has_fulltext is None + assert request.sort_by == "relevance" + assert request.max_results == 50 + assert request.data_sources is None + + def test_search_request_all_fields(self) -> None: + """Test SearchRequest with all fields specified.""" + request = SearchRequest( + query="machine learning", + author="John Doe", + institution="MIT", + source="Nature", + year_from=2015, + year_to=2021, + is_oa=True, + work_type="journal-article", + language="en", + is_retracted=False, + has_abstract=True, + has_fulltext=True, + sort_by="cited_by_count", + max_results=100, + data_sources=["openalex", "semantic_scholar"], + ) + + assert request.query == "machine learning" + assert request.author == "John Doe" + assert request.institution == "MIT" + assert request.source == "Nature" + assert request.year_from == 2015 + assert request.year_to == 2021 + assert request.is_oa is True + assert request.work_type == "journal-article" + assert request.language == "en" + assert request.is_retracted is False + assert request.has_abstract is True + assert request.has_fulltext is True + assert request.sort_by == "cited_by_count" + assert request.max_results == 100 + assert request.data_sources == ["openalex", "semantic_scholar"] + + def test_search_request_partial_year_range(self) -> None: + """Test SearchRequest with only year_from.""" + request = SearchRequest(query="test", year_from=2015) + assert request.year_from == 2015 + assert request.year_to is None + + def test_search_request_partial_year_range_to_only(self) -> None: + """Test SearchRequest with only year_to.""" + request = SearchRequest(query="test", year_to=2021) + assert request.year_from is None + assert request.year_to == 2021 + + +class TestLiteratureWorkDataclass: + """Test LiteratureWork data model.""" + + def test_literature_work_minimal(self) -> None: + """Test LiteratureWork with minimal required fields.""" + work = LiteratureWork( + id="W123", + doi=None, + title="Test Paper", + authors=[], + publication_year=None, + cited_by_count=0, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + assert work.id == "W123" + assert work.title == "Test Paper" + assert work.cited_by_count == 0 + assert work.source == "openalex" + + def test_literature_work_complete(self) -> None: + """Test LiteratureWork with all fields.""" + authors: list[dict[str, str | None]] = [ + {"name": "John Doe", "id": "A1"}, + {"name": "Jane Smith", "id": "A2"}, + ] + + work = LiteratureWork( + id="W2741809807", + doi="10.1038/nature12345", + title="Machine Learning Fundamentals", + authors=authors, + publication_year=2020, + cited_by_count=150, + abstract="This is a comprehensive review of machine learning concepts.", + journal="Nature", + is_oa=True, + access_url="https://example.com/paper.pdf", + source="openalex", + ) + + assert work.id == "W2741809807" + assert work.doi == "10.1038/nature12345" + assert work.title == "Machine Learning Fundamentals" + assert len(work.authors) == 2 + assert work.authors[0]["name"] == "John Doe" + assert work.publication_year == 2020 + assert work.cited_by_count == 150 + assert work.abstract is not None + assert work.journal == "Nature" + assert work.is_oa is True + assert work.access_url is not None + + def test_literature_work_raw_data_default(self) -> None: + """Test LiteratureWork raw_data defaults to empty dict.""" + work = LiteratureWork( + id="W123", + doi=None, + title="Test", + authors=[], + publication_year=None, + cited_by_count=0, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + assert work.raw_data == {} + + def test_literature_work_raw_data_custom(self) -> None: + """Test LiteratureWork with custom raw_data.""" + raw_data = {"custom_field": "value", "api_response": {"status": "ok"}} + + work = LiteratureWork( + id="W123", + doi=None, + title="Test", + authors=[], + publication_year=None, + cited_by_count=0, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + raw_data=raw_data, + ) + + assert work.raw_data == raw_data + assert work.raw_data["custom_field"] == "value" + + def test_literature_work_multiple_authors(self) -> None: + """Test LiteratureWork with multiple authors.""" + authors = [ + {"name": "Author 1", "id": "A1"}, + {"name": "Author 2", "id": None}, # Author without ID + {"name": "Author 3", "id": "A3"}, + ] + + work = LiteratureWork( + id="W123", + doi=None, + title="Test", + authors=authors, + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + assert len(work.authors) == 3 + assert work.authors[1]["id"] is None + + def test_literature_work_comparison(self) -> None: + """Test LiteratureWork equality comparison.""" + work1 = LiteratureWork( + id="W123", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal="Nature", + is_oa=True, + access_url=None, + source="openalex", + ) + + work2 = LiteratureWork( + id="W123", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal="Nature", + is_oa=True, + access_url=None, + source="openalex", + ) + + # DataclassesObjects with same values should be equal + assert work1 == work2 + + def test_literature_work_inequality(self) -> None: + """Test LiteratureWork inequality.""" + work1 = LiteratureWork( + id="W123", + doi="10.1038/nature12345", + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + work2 = LiteratureWork( + id="W456", + doi="10.1038/nature67890", + title="Paper 2", + authors=[], + publication_year=2021, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + assert work1 != work2 diff --git a/service/tests/unit/test_literature/test_doi_cleaner.py b/service/tests/unit/test_literature/test_doi_cleaner.py new file mode 100644 index 00000000..2cfb6082 --- /dev/null +++ b/service/tests/unit/test_literature/test_doi_cleaner.py @@ -0,0 +1,403 @@ +"""Tests for DOI normalization and deduplication utilities.""" + +import pytest + +from app.utils.literature.doi_cleaner import deduplicate_by_doi, normalize_doi +from app.utils.literature.models import LiteratureWork + + +class TestNormalizeDOI: + """Test DOI normalization functionality.""" + + def test_normalize_doi_with_https_prefix(self) -> None: + """Test normalizing DOI with https:// prefix.""" + result = normalize_doi("https://doi.org/10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_http_prefix(self) -> None: + """Test normalizing DOI with http:// prefix.""" + result = normalize_doi("http://doi.org/10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_dx_prefix(self) -> None: + """Test normalizing DOI with dx.doi.org prefix.""" + result = normalize_doi("https://dx.doi.org/10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_doi_colon_prefix(self) -> None: + """Test normalizing DOI with 'doi:' prefix.""" + result = normalize_doi("doi:10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_doi_prefix_uppercase(self) -> None: + """Test normalizing DOI with 'DOI:' prefix (uppercase).""" + result = normalize_doi("DOI: 10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_whitespace(self) -> None: + """Test normalizing DOI with leading/trailing whitespace.""" + result = normalize_doi(" 10.1038/nature12345 ") + assert result == "10.1038/nature12345" + + def test_normalize_doi_case_insensitive(self) -> None: + """Test that DOI normalization converts to lowercase.""" + result = normalize_doi("10.1038/NATURE12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_mixed_case_with_prefix(self) -> None: + """Test normalizing DOI with mixed case and prefix.""" + result = normalize_doi("https://DOI.ORG/10.1038/NATURE12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_none_input(self) -> None: + """Test normalizing None DOI returns None.""" + result = normalize_doi(None) + assert result is None + + def test_normalize_doi_empty_string(self) -> None: + """Test normalizing empty string returns None.""" + result = normalize_doi("") + assert result is None + + def test_normalize_doi_whitespace_only(self) -> None: + """Test normalizing whitespace-only string returns None.""" + result = normalize_doi(" ") + assert result is None + + def test_normalize_doi_invalid_format(self) -> None: + """Test normalizing invalid DOI format returns None.""" + result = normalize_doi("not-a-valid-doi") + assert result is None + + def test_normalize_doi_missing_prefix(self) -> None: + """Test normalizing DOI missing the '10.' prefix returns None.""" + result = normalize_doi("1038/nature12345") + assert result is None + + def test_normalize_doi_missing_suffix(self) -> None: + """Test normalizing DOI missing the suffix returns None.""" + result = normalize_doi("10.1038/") + assert result is None + + def test_normalize_doi_complex_suffix(self) -> None: + """Test normalizing DOI with complex suffix.""" + result = normalize_doi("10.1145/3580305.3599315") + assert result == "10.1145/3580305.3599315" + + def test_normalize_doi_with_version(self) -> None: + """Test normalizing DOI with version suffix.""" + result = normalize_doi("https://doi.org/10.1038/nature.2020.27710") + assert result == "10.1038/nature.2020.27710" + + +class TestDeduplicateByDOI: + """Test DOI-based deduplication functionality.""" + + @pytest.fixture + def sample_work(self) -> LiteratureWork: + """Create a sample literature work.""" + return LiteratureWork( + id="W2741809807", + doi="10.1038/nature12345", + title="Test Paper", + authors=[{"name": "John Doe", "id": "A1"}], + publication_year=2020, + cited_by_count=100, + abstract="Test abstract", + journal="Nature", + is_oa=True, + access_url="https://example.com/paper.pdf", + source="openalex", + ) + + def test_deduplicate_empty_list(self) -> None: + """Test deduplicating empty list returns empty list.""" + result = deduplicate_by_doi([]) + assert result == [] + + def test_deduplicate_single_work(self, sample_work: LiteratureWork) -> None: + """Test deduplicating single work returns same work.""" + result = deduplicate_by_doi([sample_work]) + assert len(result) == 1 + assert result[0].id == sample_work.id + + def test_deduplicate_duplicate_doi_keeps_higher_citations(self, sample_work: LiteratureWork) -> None: + """Test deduplication keeps work with higher citation count.""" + work1 = LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 1 + assert result[0].id == "W1" # Higher citation count + + def test_deduplicate_duplicate_doi_equal_citations_keeps_newer(self) -> None: + """Test deduplication keeps more recently published work when citation count is equal.""" + work1 = LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2019, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 1 + assert result[0].id == "W2" # More recent publication + + def test_deduplicate_without_doi(self) -> None: + """Test deduplicating works without DOI.""" + work1 = LiteratureWork( + id="W1", + doi=None, + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi=None, + title="Paper 2", + authors=[], + publication_year=2020, + cited_by_count=20, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 2 # Both kept since no DOI + + def test_deduplicate_invalid_doi_treated_as_no_doi(self) -> None: + """Test deduplicating works with invalid DOI treats them as without DOI.""" + work1 = LiteratureWork( + id="W1", + doi="invalid-doi-format", + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Paper 2", + authors=[], + publication_year=2020, + cited_by_count=20, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 2 + # Invalid DOI work should be in the results + assert any(w.id == "W1" for w in result) + assert any(w.id == "W2" for w in result) + + def test_deduplicate_doi_with_versions_deduplicated(self) -> None: + """Test deduplicating DOIs with version info.""" + work1 = LiteratureWork( + id="W1", + doi="https://doi.org/10.1038/nature.2020.27710", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature.2020.27710", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 1 + assert result[0].id == "W2" # Higher citation count + + def test_deduplicate_preserves_order_with_doi(self) -> None: + """Test that deduplication preserves order: DOI works first, then non-DOI.""" + work_no_doi = LiteratureWork( + id="W_no_doi", + doi=None, + title="No DOI", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work_with_doi = LiteratureWork( + id="W_with_doi", + doi="10.1038/nature12345", + title="With DOI", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work_no_doi, work_with_doi]) + assert len(result) == 2 + assert result[0].id == "W_with_doi" # DOI works come first + assert result[1].id == "W_no_doi" + + def test_deduplicate_complex_scenario(self) -> None: + """Test deduplication with complex mix of works.""" + works = [ + # Duplicate pair with same DOI + LiteratureWork( + id="W1", + doi="10.1038/A", + title="A", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + LiteratureWork( + id="W2", + doi="10.1038/A", + title="A", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + # Another unique DOI + LiteratureWork( + id="W3", + doi="10.1038/B", + title="B", + authors=[], + publication_year=2021, + cited_by_count=75, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + # No DOI works + LiteratureWork( + id="W4", + doi=None, + title="C", + authors=[], + publication_year=2022, + cited_by_count=30, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + LiteratureWork( + id="W5", + doi=None, + title="D", + authors=[], + publication_year=2022, + cited_by_count=40, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + ] + + result = deduplicate_by_doi(works) + assert len(result) == 4 # W1 removed (duplicate), others kept + result_ids = {w.id for w in result} + assert result_ids == {"W2", "W3", "W4", "W5"} + # Verify W2 (higher citations) was kept over W1 + assert "W2" in result_ids + assert "W1" not in result_ids diff --git a/service/tests/unit/test_literature/test_openalex_client.py b/service/tests/unit/test_literature/test_openalex_client.py new file mode 100644 index 00000000..fadbc62e --- /dev/null +++ b/service/tests/unit/test_literature/test_openalex_client.py @@ -0,0 +1,461 @@ +"""Tests for OpenAlex API client.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from app.utils.literature.models import SearchRequest +from app.utils.literature.openalex_client import OpenAlexClient + + +class TestOpenAlexClientInit: + """Test OpenAlex client initialization.""" + + def test_client_initialization(self) -> None: + """Test client initializes with correct parameters.""" + email = "test@example.com" + rate_limit = 5 + timeout = 15.0 + + client = OpenAlexClient(email=email, rate_limit=rate_limit, timeout=timeout) + + assert client.email == email + assert client.rate_limit == rate_limit + assert client.pool_type == "polite" + assert pytest.approx(client.rate_limiter._min_interval, rel=0.01) == 1 / rate_limit + + def test_client_initialization_defaults(self) -> None: + """Test client initializes with default parameters.""" + email = "test@example.com" + client = OpenAlexClient(email=email) + + assert client.email == email + assert client.rate_limit == 10 + assert client.pool_type == "polite" + # Verify timeout was set (httpx Timeout object) + assert client.client.timeout is not None + + def test_client_initialization_default_pool(self) -> None: + """Test client initializes default pool when email is missing.""" + client = OpenAlexClient(email=None) + + assert client.email is None + assert client.rate_limit == 1 + assert client.pool_type == "default" + assert pytest.approx(client.rate_limiter._min_interval, rel=0.01) == 1.0 + + +class TestOpenAlexClientSearch: + """Test OpenAlex search functionality.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + @pytest.fixture + def mock_response(self) -> dict: + """Create a mock OpenAlex API response.""" + return { + "meta": {"count": 1, "page": 1}, + "results": [ + { + "id": "https://openalex.org/W2741809807", + "title": "Machine Learning Fundamentals", + "doi": "https://doi.org/10.1038/nature12345", + "publication_year": 2020, + "cited_by_count": 150, + "abstract_inverted_index": { + "Machine": [0], + "learning": [1], + "is": [2], + "fundamental": [3], + }, + "authorships": [ + { + "author": { + "id": "https://openalex.org/A5023888391", + "display_name": "Jane Smith", + } + } + ], + "primary_location": { + "source": { + "id": "https://openalex.org/S137773608", + "display_name": "Nature", + } + }, + "open_access": { + "is_oa": True, + "oa_url": "https://example.com/paper.pdf", + }, + } + ], + } + + @pytest.mark.asyncio + async def test_search_basic_query(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test basic search with simple query.""" + request = SearchRequest(query="machine learning", max_results=10) + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + assert works[0].title == "Machine Learning Fundamentals" + assert works[0].doi == "10.1038/nature12345" + assert isinstance(warnings, list) + + @pytest.mark.asyncio + async def test_search_with_author_filter(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with author filter.""" + request = SearchRequest(query="machine learning", author="Jane Smith", max_results=10) + + with patch.object(client, "_resolve_author_id", new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = ("A5023888391", True, "✓ Author resolved") + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + mock_resolve.assert_called_once_with("Jane Smith") + assert any("Author resolved" in msg for msg in warnings) + + @pytest.mark.asyncio + async def test_search_with_institution_filter(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with institution filter.""" + request = SearchRequest(query="machine learning", institution="Harvard University", max_results=10) + + with patch.object(client, "_resolve_institution_id", new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = ("I136199984", True, "✓ Institution resolved") + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + mock_resolve.assert_called_once_with("Harvard University") + assert any("Institution resolved" in msg for msg in warnings) + + @pytest.mark.asyncio + async def test_search_with_source_filter(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with source (journal) filter.""" + request = SearchRequest(query="machine learning", source="Nature", max_results=10) + + with patch.object(client, "_resolve_source_id", new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = ("S137773608", True, "✓ Source resolved") + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + mock_resolve.assert_called_once_with("Nature") + assert any("Source resolved" in msg for msg in warnings) + + @pytest.mark.asyncio + async def test_search_with_year_range(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with year range filter.""" + request = SearchRequest(query="machine learning", year_from=2015, year_to=2021, max_results=10) + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + # Verify year filter was applied + call_args = mock_request.call_args + params = call_args[0][1] if call_args else {} + assert "2015-2021" in params.get("filter", "") + + @pytest.mark.asyncio + async def test_search_max_results_clamping_low(self, client: OpenAlexClient) -> None: + """Test that search handles low max_results correctly.""" + request = SearchRequest(query="test", max_results=0) + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = {"meta": {"count": 0}, "results": []} + + # Should not raise an error even with 0 max_results + works, warnings = await client.search(request) + assert isinstance(works, list) + assert isinstance(warnings, list) + + @pytest.mark.asyncio + async def test_search_max_results_clamping_high(self, client: OpenAlexClient) -> None: + """Test that search handles high max_results correctly.""" + request = SearchRequest(query="test", max_results=5000) + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = {"meta": {"count": 0}, "results": []} + + # Should not raise an error even with high max_results + works, warnings = await client.search(request) + assert isinstance(works, list) + assert isinstance(warnings, list) + + +class TestOpenAlexClientPrivateMethods: + """Test OpenAlex client private methods.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + def test_build_query_params_basic(self, client: OpenAlexClient) -> None: + """Test building basic query parameters.""" + request = SearchRequest(query="machine learning", max_results=50) + params = client._build_query_params(request, None, None, None) + + assert params["search"] == "machine learning" + assert params["per-page"] == "200" + assert params["mailto"] == "test@example.com" + + def test_build_query_params_with_filters(self, client: OpenAlexClient) -> None: + """Test building query parameters with filters.""" + request = SearchRequest( + query="machine learning", + year_from=2015, + year_to=2021, + is_oa=True, + work_type="journal-article", + ) + params = client._build_query_params(request, None, None, None) + + assert "filter" in params + assert "publication_year:2015-2021" in params["filter"] + assert "is_oa:true" in params["filter"] + assert "type:journal-article" in params["filter"] + + def test_build_query_params_with_resolved_ids(self, client: OpenAlexClient) -> None: + """Test building query parameters with resolved author/institution/source IDs.""" + request = SearchRequest(query="test") + params = client._build_query_params(request, "A123", "I456", "S789") + + assert "filter" in params + assert "authorships.author.id:A123" in params["filter"] + assert "authorships.institutions.id:I456" in params["filter"] + assert "primary_location.source.id:S789" in params["filter"] + + def test_build_query_params_sorting_by_citations(self, client: OpenAlexClient) -> None: + """Test building query parameters with citation sorting.""" + request = SearchRequest(query="test", sort_by="cited_by_count") + params = client._build_query_params(request, None, None, None) + + assert params.get("sort") == "cited_by_count:desc" + + def test_build_query_params_sorting_by_date(self, client: OpenAlexClient) -> None: + """Test building query parameters with date sorting.""" + request = SearchRequest(query="test", sort_by="publication_date") + params = client._build_query_params(request, None, None, None) + + assert params.get("sort") == "publication_date:desc" + + def test_reconstruct_abstract_normal(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction from inverted index.""" + inverted_index = { + "Machine": [0], + "learning": [1], + "is": [2], + "fundamental": [3], + } + + result = client._reconstruct_abstract(inverted_index) + + assert result == "Machine learning is fundamental" + + def test_reconstruct_abstract_with_duplicates(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction with duplicate words.""" + inverted_index = { + "The": [0, 5], + "quick": [1], + "brown": [2], + "fox": [3], + "jumps": [4], + } + + result = client._reconstruct_abstract(inverted_index) + + assert result == "The quick brown fox jumps The" + + def test_reconstruct_abstract_none(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction returns None for empty input.""" + result = client._reconstruct_abstract(None) + + assert result is None + + def test_reconstruct_abstract_empty(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction returns None for empty dict.""" + result = client._reconstruct_abstract({}) + + assert result is None + + def test_transform_work_complete(self, client: OpenAlexClient) -> None: + """Test transforming complete OpenAlex work object.""" + work_data = { + "id": "https://openalex.org/W2741809807", + "title": "Machine Learning Fundamentals", + "doi": "https://doi.org/10.1038/nature12345", + "publication_year": 2020, + "cited_by_count": 150, + "abstract_inverted_index": {"Machine": [0], "learning": [1]}, + "authorships": [ + { + "author": { + "id": "https://openalex.org/A5023888391", + "display_name": "Jane Smith", + } + }, + { + "author": { + "id": "https://openalex.org/A5023888392", + "display_name": "John Doe", + } + }, + ], + "primary_location": { + "source": { + "id": "https://openalex.org/S137773608", + "display_name": "Nature", + } + }, + "open_access": { + "is_oa": True, + "oa_url": "https://example.com/paper.pdf", + }, + } + + result = client._transform_work(work_data) + + assert result.id == "W2741809807" + assert result.title == "Machine Learning Fundamentals" + assert result.doi == "10.1038/nature12345" + assert result.publication_year == 2020 + assert result.cited_by_count == 150 + assert len(result.authors) == 2 + assert result.authors[0]["name"] == "Jane Smith" + assert result.journal == "Nature" + assert result.is_oa is True + assert result.access_url == "https://example.com/paper.pdf" + assert result.source == "openalex" + + def test_transform_work_minimal(self, client: OpenAlexClient) -> None: + """Test transforming minimal OpenAlex work object.""" + work_data = { + "id": "https://openalex.org/W123", + "title": "Minimal Paper", + "authorships": [], + } + + result = client._transform_work(work_data) + + assert result.id == "W123" + assert result.title == "Minimal Paper" + assert result.doi is None + assert result.authors == [] + assert result.journal is None + assert result.is_oa is False + + +class TestOpenAlexClientRequestWithRetry: + """Test OpenAlex client request retry logic.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + @pytest.mark.asyncio + async def test_request_with_retry_success(self, client: OpenAlexClient) -> None: + """Test successful request without retry.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + mock_get.return_value = mock_response + + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + mock_get.assert_called_once() + + @pytest.mark.asyncio + async def test_request_with_retry_timeout(self, client: OpenAlexClient) -> None: + """Test request retry on timeout.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + # First call timeout, second call success + mock_get.side_effect = [httpx.TimeoutException("timeout"), mock_response] + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + async def test_request_with_retry_rate_limit(self, client: OpenAlexClient) -> None: + """Test request retry on rate limit (403).""" + mock_response_403 = MagicMock() + mock_response_403.status_code = 403 + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + mock_response_200.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + mock_get.side_effect = [mock_response_403, mock_response_200] + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + async def test_request_with_retry_server_error(self, client: OpenAlexClient) -> None: + """Test request retry on server error (5xx).""" + mock_response_500 = MagicMock() + mock_response_500.status_code = 500 + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + mock_response_200.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + mock_get.side_effect = [mock_response_500, mock_response_200] + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + assert mock_get.call_count == 2 + + +class TestOpenAlexClientContextManager: + """Test OpenAlex client context manager.""" + + @pytest.mark.asyncio + async def test_context_manager_enter_exit(self) -> None: + """Test client works as async context manager.""" + async with OpenAlexClient(email="test@example.com") as client: + assert client is not None + assert client.email == "test@example.com" + + @pytest.mark.asyncio + async def test_close_method(self) -> None: + """Test client close method.""" + client = OpenAlexClient(email="test@example.com") + with patch.object(client.client, "aclose", new_callable=AsyncMock) as mock_close: + await client.close() + mock_close.assert_called_once() diff --git a/service/tests/unit/test_literature/test_work_distributor.py b/service/tests/unit/test_literature/test_work_distributor.py new file mode 100644 index 00000000..41e54dd5 --- /dev/null +++ b/service/tests/unit/test_literature/test_work_distributor.py @@ -0,0 +1,428 @@ +"""Tests for work distributor.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.utils.literature.models import LiteratureWork, SearchRequest +from app.utils.literature.work_distributor import WorkDistributor + + +class TestWorkDistributorInit: + """Test WorkDistributor initialization.""" + + def test_init_with_openalex_email(self) -> None: + """Test initialization with OpenAlex email.""" + distributor = WorkDistributor(openalex_email="test@example.com") + + assert distributor.openalex_email == "test@example.com" + # OpenAlex client should be registered (polite pool) + assert "openalex" in distributor.clients + + def test_init_without_openalex_email(self) -> None: + """Test initialization without OpenAlex email.""" + distributor = WorkDistributor() + + assert distributor.openalex_email is None + # OpenAlex client should still be registered (default pool) + assert "openalex" in distributor.clients + + def test_init_with_import_error(self) -> None: + """Test initialization when OpenAlex client import fails.""" + # This test would require mocking the import, which is complex + # Instead, just verify initialization works without email + distributor = WorkDistributor() + + assert distributor.openalex_email is None + assert "openalex" in distributor.clients + + +class TestWorkDistributorSearch: + """Test WorkDistributor search functionality.""" + + @pytest.fixture + def sample_work(self) -> LiteratureWork: + """Create a sample literature work.""" + return LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Test Paper", + authors=[{"name": "John Doe", "id": "A1"}], + publication_year=2020, + cited_by_count=100, + abstract="Test abstract", + journal="Nature", + is_oa=True, + access_url="https://example.com/paper.pdf", + source="openalex", + ) + + @pytest.fixture + def mock_openalex_client(self, sample_work: LiteratureWork) -> MagicMock: + """Create a mock OpenAlex client.""" + client = AsyncMock() + client.search = AsyncMock(return_value=([sample_work], ["✓ Search completed"])) + return client + + @pytest.mark.asyncio + async def test_search_basic(self, sample_work: LiteratureWork, mock_openalex_client: MagicMock) -> None: + """Test basic search with default source.""" + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_openalex_client} + distributor.openalex_email = "test@example.com" + + result = await distributor.search(request) + + assert result["total_count"] == 1 + assert result["unique_count"] == 1 + assert "openalex" in result["sources"] + assert len(result["works"]) == 1 + assert result["works"][0].id == "W1" + + @pytest.mark.asyncio + async def test_search_multiple_sources(self, sample_work: LiteratureWork) -> None: + """Test search with multiple data sources.""" + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature67890", + title="Another Paper", + authors=[], + publication_year=2021, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="semantic_scholar", + ) + + mock_client1 = AsyncMock() + mock_client1.search = AsyncMock(return_value=([sample_work], [])) + + mock_client2 = AsyncMock() + mock_client2.search = AsyncMock(return_value=([work2], [])) + + request = SearchRequest(query="test", max_results=50, data_sources=["openalex", "semantic_scholar"]) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client1, "semantic_scholar": mock_client2} + + result = await distributor.search(request) + + assert result["total_count"] == 2 + assert result["unique_count"] == 2 + assert "openalex" in result["sources"] + assert "semantic_scholar" in result["sources"] + + @pytest.mark.asyncio + async def test_search_deduplication(self) -> None: + """Test search deduplicates results by DOI.""" + work1 = LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="other", + ) + + mock_client = AsyncMock() + mock_client.search = AsyncMock(return_value=([work1, work2], [])) + + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert result["total_count"] == 2 + assert result["unique_count"] == 1 # Deduplicated + assert result["works"][0].id == "W1" # Higher citation count + + @pytest.mark.asyncio + async def test_search_with_client_error(self, sample_work: LiteratureWork) -> None: + """Test search handles client errors gracefully.""" + mock_client = AsyncMock() + mock_client.search = AsyncMock(side_effect=Exception("API Error")) + + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert result["total_count"] == 0 + assert result["unique_count"] == 0 + assert result["sources"]["openalex"] == 0 + assert any("Error" in w for w in result["warnings"]) + + @pytest.mark.asyncio + async def test_search_unavailable_source(self) -> None: + """Test search with unavailable data source.""" + request = SearchRequest(query="test", max_results=50, data_sources=["unavailable_source"]) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {} + + result = await distributor.search(request) + + assert result["total_count"] == 0 + assert result["unique_count"] == 0 + assert result["works"] == [] + + @pytest.mark.asyncio + async def test_search_max_results_clamping_low(self) -> None: + """Test search clamps max_results to minimum.""" + request = SearchRequest(query="test", max_results=0) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {} + + result = await distributor.search(request) + + assert any("max_results < 1" in w for w in result["warnings"]) + assert request.max_results == 50 + + @pytest.mark.asyncio + async def test_search_max_results_clamping_high(self) -> None: + """Test search clamps max_results to maximum.""" + request = SearchRequest(query="test", max_results=5000) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {} + + result = await distributor.search(request) + + assert any("max_results > 1000" in w for w in result["warnings"]) + assert request.max_results == 1000 + + @pytest.mark.asyncio + async def test_search_result_limiting(self) -> None: + """Test search limits results to max_results.""" + works = [ + LiteratureWork( + id=f"W{i}", + doi=f"10.1038/paper{i}", + title=f"Paper {i}", + authors=[], + publication_year=2020, + cited_by_count=100 - i, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + for i in range(20) + ] + + mock_client = AsyncMock() + mock_client.search = AsyncMock(return_value=(works, [])) + + request = SearchRequest(query="test", max_results=10) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert len(result["works"]) == 10 + + @pytest.mark.asyncio + async def test_search_with_warnings(self) -> None: + """Test search collects warnings from clients.""" + work = LiteratureWork( + id="W1", + doi=None, + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + mock_client = AsyncMock() + mock_client.search = AsyncMock( + return_value=( + [work], + ["⚠️ Author not found", "✓ Search completed"], + ) + ) + + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert "⚠️ Author not found" in result["warnings"] + assert "✓ Search completed" in result["warnings"] + + +class TestWorkDistributorSorting: + """Test WorkDistributor sorting functionality.""" + + @pytest.fixture + def sample_works(self) -> list[LiteratureWork]: + """Create sample works for sorting tests.""" + return [ + LiteratureWork( + id="W1", + doi=None, + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + LiteratureWork( + id="W2", + doi=None, + title="Paper 2", + authors=[], + publication_year=2021, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + LiteratureWork( + id="W3", + doi=None, + title="Paper 3", + authors=[], + publication_year=2019, + cited_by_count=75, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + ] + + def test_sort_by_relevance(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by relevance (default, maintains order).""" + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "relevance") + + # Should maintain original order for relevance + assert result[0].id == "W1" + assert result[1].id == "W2" + assert result[2].id == "W3" + + def test_sort_by_cited_by_count(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by citation count.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "cited_by_count") + + assert result[0].id == "W2" # 100 citations + assert result[1].id == "W3" # 75 citations + assert result[2].id == "W1" # 50 citations + + def test_sort_by_publication_date(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by publication date.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "publication_date") + + assert result[0].id == "W2" # 2021 + assert result[1].id == "W1" # 2020 + assert result[2].id == "W3" # 2019 + + def test_sort_with_missing_year(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by publication date with missing years.""" + sample_works[1].publication_year = None + + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "publication_date") + + # Works with missing year should go to the end + assert result[0].id == "W1" # 2020 + assert result[1].id == "W3" # 2019 + assert result[2].publication_year is None + + +class TestWorkDistributorContextManager: + """Test WorkDistributor context manager.""" + + @pytest.mark.asyncio + async def test_context_manager_enter_exit(self) -> None: + """Test context manager functionality.""" + async with WorkDistributor(openalex_email="test@example.com") as distributor: + assert distributor is not None + + @pytest.mark.asyncio + async def test_close_method(self) -> None: + """Test close method.""" + distributor = WorkDistributor(openalex_email="test@example.com") + + # Replace the actual client with a mock + mock_client = MagicMock() + mock_client.close = AsyncMock() + distributor.clients["openalex"] = mock_client + + await distributor.close() + + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_sync_close(self) -> None: + """Test close method with synchronous close.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + mock_client = MagicMock() + # Synchronous close (returns None, not awaitable) + mock_client.close = MagicMock(return_value=None) + distributor.clients = {"openalex": mock_client} + + await distributor.close() + + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_no_close_method(self) -> None: + """Test close method with client that has no close method.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + mock_client = MagicMock(spec=[]) # No close method + distributor.clients = {"openalex": mock_client} + + # Should not raise an error + await distributor.close() diff --git a/service/tests/unit/test_utils/test_built_in_tools.py b/service/tests/unit/test_utils/test_built_in_tools.py deleted file mode 100644 index 61ca5ebb..00000000 --- a/service/tests/unit/test_utils/test_built_in_tools.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Tests for built-in tools utilities.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from fastmcp import FastMCP - -from app.mcp.builtin_tools import register_built_in_tools - - -class TestBuiltInTools: - """Test built-in tools registration and functionality.""" - - @pytest.fixture - def mock_mcp(self): - """Create a mock FastMCP instance.""" - mcp = MagicMock(spec=FastMCP) - mcp.tool = MagicMock() - mcp.resource = MagicMock() - return mcp - - def test_register_built_in_tools(self, mock_mcp: MagicMock) -> None: - """Test that built-in tools are registered properly.""" - register_built_in_tools(mock_mcp) - - # Verify that the decorators were called (tools were registered) - assert mock_mcp.tool.call_count >= 4 # We have at least 4 tools - assert mock_mcp.resource.call_count >= 1 # We have at least 1 resource - - @patch("app.mcp.builtin_tools.request.urlopen") - def test_search_github_success(self, mock_urlopen: MagicMock, mock_mcp: MagicMock) -> None: - """Test GitHub search tool with successful response.""" - # Mock response data - mock_response_data = { - "items": [ - { - "full_name": "test/repo1", - "html_url": "https://github.com/test/repo1", - "description": "Test repository 1", - "stargazers_count": 100, - "forks_count": 20, - "language": "Python", - "updated_at": "2024-01-01T00:00:00Z", - "topics": ["test", "demo"], - }, - { - "full_name": "test/repo2", - "html_url": "https://github.com/test/repo2", - "description": "Test repository 2", - "stargazers_count": 50, - "forks_count": 10, - "language": "Python", - "updated_at": "2024-01-02T00:00:00Z", - "topics": [], - }, - ] - } - - # Mock the context manager and JSON loading - mock_response = MagicMock() - mock_response.__enter__ = MagicMock(return_value=mock_response) - mock_response.__exit__ = MagicMock(return_value=None) - mock_urlopen.return_value = mock_response - - with patch("app.mcp.builtin_tools.json.load") as mock_json_load: - mock_json_load.return_value = mock_response_data - - register_built_in_tools(mock_mcp) - - # Get the search_github function from the registered tools - # Since we can't easily extract it, we'll test the logic directly - # by calling the function that would be registered - - # For this test, we'll verify the mock was set up correctly - assert mock_mcp.tool.called - - @patch("app.mcp.builtin_tools.request.urlopen") - def test_search_github_empty_query(self, mock_urlopen: MagicMock, mock_mcp: MagicMock) -> None: - """Test GitHub search with empty query.""" - register_built_in_tools(mock_mcp) - - # The actual test would need access to the registered function - # For now, we verify the registration happened - assert mock_mcp.tool.called - - @patch("app.mcp.builtin_tools.request.urlopen") - def test_search_github_api_error(self, mock_urlopen: MagicMock, mock_mcp: MagicMock) -> None: - """Test GitHub search with API error.""" - # Mock URL open to raise an exception - mock_urlopen.side_effect = Exception("API Error") - - register_built_in_tools(mock_mcp) - - # Verify registration still happened despite the error not occurring yet - assert mock_mcp.tool.called - - def test_search_github_parameters(self, mock_mcp: MagicMock) -> None: - """Test GitHub search with different parameters.""" - register_built_in_tools(mock_mcp) - - # Verify the tool was registered with proper signature - assert mock_mcp.tool.called - - # The actual function would accept parameters like query, max_results, sort_by - # Since we can't easily test the registered function directly, - # we verify the registration process - - async def test_llm_web_search_no_auth(self, mock_mcp: MagicMock) -> None: - """Test LLM web search without authentication.""" - with patch("app.mcp.builtin_tools.get_access_token") as mock_get_token: - mock_get_token.return_value = None - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - async def test_llm_web_search_with_auth(self, mock_mcp: MagicMock) -> None: - """Test LLM web search with authentication.""" - with ( - patch("fastmcp.server.dependencies.get_access_token") as mock_get_token, - patch("app.middleware.auth.AuthProvider") as mock_auth_provider, - patch("app.core.providers.get_user_provider_manager") as mock_get_manager, - patch("app.infra.database.connection.AsyncSessionLocal") as mock_session, - ): - # Mock authentication - mock_token = MagicMock() - mock_token.claims = {"user_id": "test-user"} - mock_get_token.return_value = mock_token - - mock_user_info = MagicMock() - mock_user_info.id = "test-user" - mock_auth_provider.parse_user_info.return_value = mock_user_info - - # Mock database session - mock_db = AsyncMock() - mock_session.return_value.__aenter__.return_value = mock_db - - # Mock provider manager - mock_provider_manager = AsyncMock() - mock_get_manager.return_value = mock_provider_manager - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - async def test_refresh_tools_success(self, mock_mcp: MagicMock) -> None: - """Test refresh tools functionality.""" - with ( - patch("app.mcp.builtin_tools.get_access_token") as mock_get_token, - patch("app.mcp.builtin_tools.AuthProvider") as mock_auth_provider, - patch("app.mcp.builtin_tools.tool_loader") as mock_tool_loader, - ): - # Mock authentication - mock_token = MagicMock() - mock_token.claims = {"user_id": "test-user"} - mock_get_token.return_value = mock_token - - mock_user_info = MagicMock() - mock_user_info.id = "test-user" - mock_auth_provider.parse_user_info.return_value = mock_user_info - - # Mock tool loader - mock_tool_loader.refresh_tools.return_value = { - "added": ["tool1", "tool2"], - "removed": ["old_tool"], - "updated": ["updated_tool"], - } - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - async def test_refresh_tools_no_auth(self, mock_mcp: MagicMock) -> None: - """Test refresh tools without authentication.""" - with patch("app.mcp.builtin_tools.get_access_token") as mock_get_token: - mock_get_token.return_value = None - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - def test_get_server_status(self, mock_mcp: MagicMock) -> None: - """Test get server status tool.""" - with patch("app.mcp.builtin_tools.tool_loader") as mock_tool_loader: - mock_proxy_manager = MagicMock() - mock_proxy_manager.list_proxies.return_value = ["proxy1", "proxy2"] - mock_tool_loader.proxy_manager = mock_proxy_manager - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - @pytest.mark.parametrize("sort_by", ["stars", "forks", "updated"]) - def test_search_github_sort_options(self, mock_mcp: MagicMock, sort_by: str) -> None: - """Test GitHub search with different sort options.""" - register_built_in_tools(mock_mcp) - - # Verify the tool registration happened - assert mock_mcp.tool.called - - def test_tools_registration_count(self, mock_mcp: MagicMock) -> None: - """Test that the expected number of tools are registered.""" - register_built_in_tools(mock_mcp) - - # We expect at least these tools: - # - search_github - # - llm_web_search - # - refresh_tools - # - get_server_status - expected_min_tools = 4 - - assert mock_mcp.tool.call_count >= expected_min_tools - - def test_resource_registration_count(self, mock_mcp: MagicMock) -> None: - """Test that the expected number of resources are registered.""" - register_built_in_tools(mock_mcp) - - # We expect at least these resources: - # - config://server - expected_min_resources = 1 - - assert mock_mcp.resource.call_count >= expected_min_resources diff --git a/service/tests/unit/tools/__init__.py b/service/tests/unit/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/service/tests/unit/tools/test_cost.py b/service/tests/unit/tools/test_cost.py new file mode 100644 index 00000000..69ad41c7 --- /dev/null +++ b/service/tests/unit/tools/test_cost.py @@ -0,0 +1,193 @@ +"""Unit tests for tool cost calculation.""" + +import pytest + +from app.tools.cost import calculate_tool_cost +from app.tools.registry import BuiltinToolRegistry, ToolCostConfig, ToolInfo + + +class TestCalculateToolCost: + """Tests for calculate_tool_cost function.""" + + @pytest.fixture(autouse=True) + def setup_registry(self) -> None: + """Set up test registry before each test.""" + BuiltinToolRegistry.clear() + + # Register mock tools for testing + from unittest.mock import MagicMock + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test tool" + + # Tool with base cost only + BuiltinToolRegistry._metadata["generate_image"] = ToolInfo( + id="generate_image", + name="Generate Image", + description="Generate images", + category="image", + cost=ToolCostConfig(base_cost=10, input_image_cost=5), + ) + + # Tool with input_image_cost + BuiltinToolRegistry._metadata["read_image"] = ToolInfo( + id="read_image", + name="Read Image", + description="Read images", + category="image", + cost=ToolCostConfig(base_cost=2), + ) + + # Tool with output_file_cost + BuiltinToolRegistry._metadata["knowledge_write"] = ToolInfo( + id="knowledge_write", + name="Knowledge Write", + description="Write files", + category="knowledge", + cost=ToolCostConfig(output_file_cost=5), + ) + + # Tool with no cost + BuiltinToolRegistry._metadata["knowledge_read"] = ToolInfo( + id="knowledge_read", + name="Knowledge Read", + description="Read files", + category="knowledge", + cost=ToolCostConfig(), + ) + + # Web search tool + BuiltinToolRegistry._metadata["web_search"] = ToolInfo( + id="web_search", + name="Web Search", + description="Search the web", + category="search", + cost=ToolCostConfig(base_cost=1), + ) + + def test_generate_image_without_reference(self) -> None: + """Test generate_image cost without reference image.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args={"prompt": "a beautiful sunset"}, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Base cost only (10), no input_image_cost + assert cost == 10 + + def test_generate_image_with_reference(self) -> None: + """Test generate_image cost with single reference image.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args={"prompt": "a beautiful sunset", "image_ids": ["ref123"]}, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Base cost (10) + input_image_cost (5) = 15 + assert cost == 15 + + def test_generate_image_with_multiple_references(self) -> None: + """Test generate_image cost with multiple reference images.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args={"prompt": "combine these images", "image_ids": ["ref1", "ref2", "ref3"]}, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Base cost (10) + input_image_cost (5) * 3 = 25 + assert cost == 25 + + def test_generate_image_with_empty_image_ids(self) -> None: + """Test generate_image cost with empty image_ids list.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args={"prompt": "a sunset", "image_ids": []}, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Base cost only (10), empty list means no input images + assert cost == 10 + + def test_read_image_cost(self) -> None: + """Test read_image cost.""" + cost = calculate_tool_cost( + tool_name="read_image", + tool_args={"image_id": "abc123", "question": "What is in this image?"}, + tool_result={"success": True, "analysis": "A beautiful sunset"}, + ) + assert cost == 2 + + def test_knowledge_write_creating_file(self) -> None: + """Test knowledge_write cost when creating a new file.""" + cost = calculate_tool_cost( + tool_name="knowledge_write", + tool_args={"filename": "report.txt", "content": "Hello world"}, + tool_result={"success": True, "message": "Created file: report.txt"}, + ) + # output_file_cost (5) for creating new file + assert cost == 5 + + def test_knowledge_write_updating_file(self) -> None: + """Test knowledge_write cost when updating an existing file.""" + cost = calculate_tool_cost( + tool_name="knowledge_write", + tool_args={"filename": "report.txt", "content": "Updated content"}, + tool_result={"success": True, "message": "Updated file: report.txt"}, + ) + # No cost for updating (message doesn't contain "Created") + assert cost == 0 + + def test_knowledge_read_is_free(self) -> None: + """Test that knowledge_read is free.""" + cost = calculate_tool_cost( + tool_name="knowledge_read", + tool_args={"filename": "report.txt"}, + tool_result={"success": True, "content": "Hello world"}, + ) + assert cost == 0 + + def test_web_search_cost(self) -> None: + """Test web_search cost.""" + cost = calculate_tool_cost( + tool_name="web_search", + tool_args={"query": "Python programming"}, + tool_result={"results": [{"title": "Result 1"}]}, + ) + assert cost == 1 + + def test_unknown_tool_is_free(self) -> None: + """Test that unknown tools have zero cost.""" + cost = calculate_tool_cost( + tool_name="unknown_tool", + tool_args={"some": "args"}, + tool_result={"some": "result"}, + ) + assert cost == 0 + + def test_no_args_provided(self) -> None: + """Test cost calculation when no args are provided.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args=None, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Should only return base cost + assert cost == 10 + + def test_no_result_provided(self) -> None: + """Test cost calculation when no result is provided.""" + cost = calculate_tool_cost( + tool_name="knowledge_write", + tool_args={"filename": "test.txt", "content": "Hello"}, + tool_result=None, + ) + # Should only return base cost (0 for knowledge_write) + assert cost == 0 + + def test_failed_tool_execution(self) -> None: + """Test that failed tool executions don't charge output_file_cost.""" + cost = calculate_tool_cost( + tool_name="knowledge_write", + tool_args={"filename": "test.txt", "content": "Hello"}, + tool_result={"success": False, "error": "Permission denied"}, + ) + # No charge for failed execution + assert cost == 0 diff --git a/service/tests/unit/tools/utils/__init__.py b/service/tests/unit/tools/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/service/tests/unit/tools/utils/documents/__init__.py b/service/tests/unit/tools/utils/documents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/service/tests/unit/tools/utils/documents/test_handlers.py b/service/tests/unit/tools/utils/documents/test_handlers.py new file mode 100644 index 00000000..c7a0da34 --- /dev/null +++ b/service/tests/unit/tools/utils/documents/test_handlers.py @@ -0,0 +1,903 @@ +""" +Tests for file handlers. +""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from app.tools.utils.documents.handlers import ( + DocxFileHandler, + ExcelFileHandler, + FileHandlerFactory, + HtmlFileHandler, + ImageFileHandler, + JsonFileHandler, + PdfFileHandler, + PptxFileHandler, + TextFileHandler, + XmlFileHandler, + YamlFileHandler, +) +from app.tools.utils.documents.image_fetcher import FetchedImage +from app.tools.utils.documents.spec import ( + DocumentSpec, + HeadingBlock, + ImageBlock, + ImageSlideSpec, + ListBlock, + PresentationSpec, + SheetSpec, + SlideSpec, + SpreadsheetSpec, + TableBlock, + TextBlock, +) + + +class TestFileHandlerFactory: + def test_get_handler(self) -> None: + # Existing handlers + assert isinstance(FileHandlerFactory.get_handler("test.pdf"), PdfFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.docx"), DocxFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.doc"), DocxFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.xlsx"), ExcelFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.xls"), ExcelFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.pptx"), PptxFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.ppt"), PptxFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.txt"), TextFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.csv"), TextFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.py"), TextFileHandler) + + def test_get_handler_new_types(self) -> None: + # New handlers + assert isinstance(FileHandlerFactory.get_handler("test.html"), HtmlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.htm"), HtmlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.json"), JsonFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.yaml"), YamlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.yml"), YamlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.xml"), XmlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.png"), ImageFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.jpg"), ImageFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.jpeg"), ImageFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.gif"), ImageFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.webp"), ImageFileHandler) + + +class TestTextFileHandler: + def test_read_write(self) -> None: + handler = TextFileHandler() + content = "Hello, World!" + + # Write + bytes_content = handler.create_content(content) + assert isinstance(bytes_content, bytes) + assert bytes_content == b"Hello, World!" + + # Read + read_content = handler.read_content(bytes_content) + assert read_content == content + + def test_read_image_fail(self) -> None: + handler = TextFileHandler() + with pytest.raises(ValueError): + handler.read_content(b"test", mode="image") + + +class TestHtmlFileHandler: + def test_read_html(self) -> None: + handler = HtmlFileHandler() + html = b"

Title

Content

" + content = handler.read_content(html) + assert isinstance(content, str) + # Should extract text from HTML + assert "Title" in content or "Content" in content + + def test_read_html_strips_scripts(self) -> None: + handler = HtmlFileHandler() + html = b"Safe content" + content = handler.read_content(html) + assert "alert" not in content + assert "Safe content" in content + + def test_create_html(self) -> None: + handler = HtmlFileHandler() + content = handler.create_content("Hello\n\nWorld") + assert b"" in content + assert b"" in content + assert b"Hello" in content + + def test_read_image_fail(self) -> None: + handler = HtmlFileHandler() + with pytest.raises(ValueError): + handler.read_content(b"", mode="image") + + +class TestJsonFileHandler: + def test_read_json(self) -> None: + handler = JsonFileHandler() + data = {"key": "value", "nested": {"a": 1}} + json_bytes = json.dumps(data).encode() + + content = handler.read_content(json_bytes) + assert "key" in content + assert "value" in content + + def test_read_invalid_json(self) -> None: + handler = JsonFileHandler() + content = handler.read_content(b"not valid json") + assert content == "not valid json" + + def test_create_json_valid(self) -> None: + handler = JsonFileHandler() + data = '{"key": "value"}' + result = handler.create_content(data) + parsed = json.loads(result) + assert parsed["key"] == "value" + + def test_create_json_invalid_wraps(self) -> None: + handler = JsonFileHandler() + result = handler.create_content("plain text") + parsed = json.loads(result) + assert "content" in parsed + assert parsed["content"] == "plain text" + + +class TestYamlFileHandler: + def test_read_yaml(self) -> None: + handler = YamlFileHandler() + yaml_content = b"key: value\nnested:\n a: 1" + content = handler.read_content(yaml_content) + assert "key" in content + assert "value" in content + + def test_create_yaml_from_json(self) -> None: + handler = YamlFileHandler() + json_input = '{"key": "value"}' + result = handler.create_content(json_input) + assert b"key: value" in result + + +class TestXmlFileHandler: + def test_read_xml(self) -> None: + handler = XmlFileHandler() + xml = b"Hello" + content = handler.read_content(xml) + assert "Hello" in content + assert "item" in content + + def test_create_xml(self) -> None: + handler = XmlFileHandler() + result = handler.create_content("Test content") + assert b"" in result + assert b"Test content" in result + + def test_read_image_fail(self) -> None: + handler = XmlFileHandler() + with pytest.raises(ValueError): + handler.read_content(b"", mode="image") + + +class TestImageFileHandler: + def test_detect_format_png(self) -> None: + handler = ImageFileHandler() + png_magic = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + assert handler._detect_format(png_magic) == "png" + + def test_detect_format_jpeg(self) -> None: + handler = ImageFileHandler() + jpeg_magic = b"\xff\xd8" + b"\x00" * 100 + assert handler._detect_format(jpeg_magic) == "jpeg" + + def test_detect_format_gif(self) -> None: + handler = ImageFileHandler() + gif_magic = b"GIF89a" + b"\x00" * 100 + assert handler._detect_format(gif_magic) == "gif" + + def test_create_raises_error(self) -> None: + handler = ImageFileHandler() + with pytest.raises(ValueError, match="Cannot create image"): + handler.create_content("text") + + +# Only mock external deps for complex handlers if they are not installed in test env +# But for now we assume we might need to mock them to run this in strict CI envs +# where deps might be missing during dev. + + +@patch("fitz.open") +@patch("fitz.Matrix") +class TestPdfFileHandler: + def test_read_text(self, mock_matrix: MagicMock, mock_open: MagicMock) -> None: + handler = PdfFileHandler() + mock_doc = MagicMock() + mock_page = MagicMock() + mock_page.get_text.return_value = "Page text" + mock_page.find_tables.return_value = MagicMock(tables=[]) + mock_doc.__iter__.return_value = [mock_page] + mock_open.return_value = mock_doc + + content = handler.read_content(b"pdf_bytes", mode="text") + assert content == "Page text" + mock_open.assert_called_with(stream=b"pdf_bytes", filetype="pdf") + + def test_write_plain_text(self, mock_matrix: MagicMock, mock_open: MagicMock) -> None: + handler = PdfFileHandler() + # For plain text, it uses reportlab, not fitz + result = handler.create_content("Some text") + assert isinstance(result, bytes) + # PDF magic bytes + assert result[:4] == b"%PDF" + + +@patch("docx.Document") +class TestDocxFileHandler: + def test_read(self, mock_document_cls: MagicMock) -> None: + handler = DocxFileHandler() + mock_doc = MagicMock() + mock_element = MagicMock() + mock_element.tag = "p" + mock_element.iter.return_value = [MagicMock(text="Para 1")] + mock_doc.element.body = [mock_element] + mock_document_cls.return_value = mock_doc + + content = handler.read_content(b"docx_bytes") + assert "Para 1" in content + + def test_write_plain_text(self, mock_document_cls: MagicMock) -> None: + handler = DocxFileHandler() + mock_doc = MagicMock() + mock_document_cls.return_value = mock_doc + + handler.create_content("Line 1\nLine 2") + + assert mock_doc.add_paragraph.call_count == 2 + mock_doc.save.assert_called() + + +@patch("openpyxl.Workbook") +@patch("openpyxl.load_workbook") +class TestExcelFileHandler: + def test_read(self, mock_load_workbook: MagicMock, mock_workbook: MagicMock) -> None: + handler = ExcelFileHandler() + mock_wb = MagicMock() + mock_ws = MagicMock() + mock_wb.sheetnames = ["Sheet1"] + mock_wb.__getitem__.return_value = mock_ws + mock_ws.iter_rows.return_value = [("A", "B")] + mock_load_workbook.return_value = mock_wb + + content = handler.read_content(b"xlsx_bytes") + assert "Sheet1" in content + assert "A\tB" in content + + def test_write_csv(self, mock_load_workbook: MagicMock, mock_workbook: MagicMock) -> None: + handler = ExcelFileHandler() + mock_wb = MagicMock() + mock_ws = MagicMock() + mock_wb.active = mock_ws + mock_workbook.return_value = mock_wb + + handler.create_content("A,B\nC,D") + + assert mock_ws.append.call_count == 2 + mock_wb.save.assert_called() + + +@patch("pptx.Presentation") +class TestPptxFileHandler: + def test_read(self, mock_presentation: MagicMock) -> None: + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_shape = MagicMock() + mock_shape.text = "Slide Text" + mock_slide.shapes = [mock_shape] + mock_slide.has_notes_slide = False + mock_prs.slides = [mock_slide] + mock_presentation.return_value = mock_prs + + content = handler.read_content(b"pptx_bytes") + assert "Slide Text" in content + + def test_write_plain_text(self, mock_presentation: MagicMock) -> None: + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + handler.create_content("Title\nBody") + + mock_prs.slides.add_slide.assert_called() + mock_prs.save.assert_called() + + +# Document Spec Tests + + +class TestDocumentSpec: + def test_create_document_spec(self) -> None: + spec = DocumentSpec( + title="Test Doc", + author="Test Author", + content=[ + HeadingBlock(content="Chapter 1", level=1), + TextBlock(content="Some text here"), + ListBlock(items=["Item 1", "Item 2"], ordered=False), + TableBlock(headers=["A", "B"], rows=[["1", "2"]]), + ], + ) + assert spec.title == "Test Doc" + assert len(spec.content) == 4 + assert spec.content[0].type == "heading" + + def test_document_spec_json_roundtrip(self) -> None: + spec = DocumentSpec( + title="Test", + content=[TextBlock(content="Hello")], + ) + json_str = spec.model_dump_json() + parsed = DocumentSpec.model_validate_json(json_str) + assert parsed.title == spec.title + + +class TestSpreadsheetSpec: + def test_create_spreadsheet_spec(self) -> None: + spec = SpreadsheetSpec( + sheets=[ + SheetSpec( + name="Data", + headers=["Name", "Value"], + data=[["A", 1], ["B", 2]], + ) + ] + ) + assert len(spec.sheets) == 1 + assert spec.sheets[0].name == "Data" + + +class TestPresentationSpec: + def test_create_presentation_spec(self) -> None: + spec = PresentationSpec( + title="My Presentation", + slides=[ + SlideSpec(layout="title", title="Welcome", subtitle="Intro"), + SlideSpec( + layout="title_content", + title="Slide 2", + content=[ListBlock(items=["Point 1", "Point 2"])], + ), + ], + ) + assert len(spec.slides) == 2 + assert spec.slides[0].layout == "title" + + +class TestPptxFileHandlerEnhanced: + """Tests for enhanced PPTX generation with images, tables, and headings.""" + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_table_block(self, mock_presentation: MagicMock, mock_image_fetcher: MagicMock) -> None: + """Test table rendering in PPTX.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_table_shape = MagicMock() + mock_table = MagicMock() + mock_table_shape.table = mock_table + + # Mock table cells + mock_cells = {} + for row in range(3): # 1 header + 2 data rows + for col in range(2): + cell = MagicMock() + cell.text_frame.paragraphs = [MagicMock()] + mock_cells[(row, col)] = cell + mock_table.cell = lambda r, c: mock_cells[(r, c)] # type: ignore[misc] + + mock_slide.shapes.add_table.return_value = mock_table_shape + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_prs.slides.__iter__ = lambda self: iter([mock_slide]) # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Data Slide", + content=[ + TableBlock( + headers=["Name", "Value"], + rows=[["Item A", "100"], ["Item B", "200"]], + ) + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify table was created + mock_slide.shapes.add_table.assert_called_once() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_heading_block(self, mock_presentation: MagicMock, mock_image_fetcher: MagicMock) -> None: + """Test heading rendering in PPTX.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_textbox.text_frame.paragraphs = [MagicMock()] + + mock_slide.shapes.add_textbox.return_value = mock_textbox + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Content Slide", + content=[ + HeadingBlock(content="Section Header", level=2), + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify textbox was created for heading + mock_slide.shapes.add_textbox.assert_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_image_block_success(self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock) -> None: + """Test successful image rendering in PPTX.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image_data", + format="png", + width=200, + height=150, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Image Slide", + content=[ + ImageBlock( + url="https://example.com/image.png", + caption="Test Image", + ) + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify image was fetched with keyword arguments (new signature) + mock_fetcher.fetch.assert_called_once_with(url="https://example.com/image.png", image_id=None) + # Verify add_picture was called + mock_slide.shapes.add_picture.assert_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_image_block_failure_placeholder( + self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock + ) -> None: + """Test image failure shows placeholder text.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_textbox.text_frame.paragraphs = [MagicMock()] + + mock_slide.shapes.add_textbox.return_value = mock_textbox + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher failure + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=False, + error="Connection timeout", + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Image Slide", + content=[ImageBlock(url="https://example.com/fail.png")], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify placeholder textbox was created (not add_picture) + mock_slide.shapes.add_textbox.assert_called() + mock_slide.shapes.add_picture.assert_not_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_mixed_content(self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock) -> None: + """Test slide with multiple content block types.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_textbox.text_frame.paragraphs = [MagicMock()] + mock_table_shape = MagicMock() + mock_table = MagicMock() + mock_table_shape.table = mock_table + + # Mock table cells + mock_cells = {} + for row in range(2): + for col in range(2): + cell = MagicMock() + cell.text_frame.paragraphs = [MagicMock()] + mock_cells[(row, col)] = cell + mock_table.cell = lambda r, c: mock_cells[(r, c)] # type: ignore[misc] + + mock_slide.shapes.add_textbox.return_value = mock_textbox + mock_slide.shapes.add_table.return_value = mock_table_shape + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image", + format="png", + width=100, + height=100, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Mixed Content", + content=[ + HeadingBlock(content="Introduction", level=2), + TextBlock(content="Some intro text here."), + ListBlock(items=["Point 1", "Point 2"], ordered=True), + TableBlock(headers=["A", "B"], rows=[["1", "2"]]), + ImageBlock(url="https://example.com/chart.png"), + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify all content types were rendered + # Multiple textbox calls for heading, text, list + assert mock_slide.shapes.add_textbox.call_count >= 3 + # One table call + mock_slide.shapes.add_table.assert_called_once() + # One picture call + mock_slide.shapes.add_picture.assert_called_once() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_text_block_with_style(self, mock_presentation: MagicMock, mock_image_fetcher: MagicMock) -> None: + """Test text block with style attribute.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_paragraph = MagicMock() + mock_textbox.text_frame.paragraphs = [mock_paragraph] + mock_textbox.text_frame.word_wrap = True + + mock_slide.shapes.add_textbox.return_value = mock_textbox + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Styled Text", + content=[ + TextBlock(content="Bold text", style="bold"), + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + mock_slide.shapes.add_textbox.assert_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_image_with_specified_width(self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock) -> None: + """Test image block with width parameter.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image_data", + format="png", + width=800, + height=600, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Image with Width", + content=[ + ImageBlock( + url="https://example.com/image.png", + width=288, # 4 inches at 72 DPI + ) + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + mock_slide.shapes.add_picture.assert_called() + + +class TestPresentationSpecImageSlides: + """Tests for PresentationSpec with image_slides mode.""" + + def test_create_presentation_spec_image_slides_mode(self) -> None: + """Test creating PresentationSpec with image_slides mode.""" + spec = PresentationSpec( + mode="image_slides", + title="AI Generated Presentation", + author="AI Agent", + image_slides=[ + ImageSlideSpec(image_id="abc-123-456-def", notes="Welcome slide"), + ImageSlideSpec(image_id="ghi-789-012-jkl", notes="Content slide"), + ImageSlideSpec(image_id="mno-345-678-pqr"), # No notes + ], + ) + assert spec.mode == "image_slides" + assert len(spec.image_slides) == 3 + assert spec.image_slides[0].image_id == "abc-123-456-def" + assert spec.image_slides[0].notes == "Welcome slide" + assert spec.image_slides[2].notes is None + + def test_create_presentation_spec_structured_mode_default(self) -> None: + """Test that structured mode is the default.""" + spec = PresentationSpec( + title="Traditional Presentation", + slides=[ + SlideSpec(layout="title", title="Welcome"), + ], + ) + assert spec.mode == "structured" + assert len(spec.slides) == 1 + + def test_image_slide_spec_json_roundtrip(self) -> None: + """Test ImageSlideSpec JSON serialization.""" + spec = PresentationSpec( + mode="image_slides", + image_slides=[ + ImageSlideSpec(image_id="test-uuid", notes="Test notes"), + ], + ) + json_str = spec.model_dump_json() + parsed = PresentationSpec.model_validate_json(json_str) + assert parsed.mode == "image_slides" + assert parsed.image_slides[0].image_id == "test-uuid" + + +class TestImageBlockWithImageId: + """Tests for ImageBlock with image_id field.""" + + def test_image_block_with_url(self) -> None: + """Test ImageBlock with URL (traditional).""" + block = ImageBlock(url="https://example.com/image.png", caption="Test") + assert block.url == "https://example.com/image.png" + assert block.image_id is None + assert block.caption == "Test" + + def test_image_block_with_image_id(self) -> None: + """Test ImageBlock with image_id (new feature).""" + block = ImageBlock(image_id="abc-123-uuid", caption="Generated image") + assert block.url is None + assert block.image_id == "abc-123-uuid" + assert block.caption == "Generated image" + + def test_image_block_both_url_and_image_id(self) -> None: + """Test ImageBlock can have both (though one is preferred).""" + block = ImageBlock( + url="https://example.com/fallback.png", + image_id="abc-123-uuid", + ) + assert block.url is not None + assert block.image_id is not None + + +class TestPptxImageSlidesMode: + """Tests for PPTX generation with image_slides mode.""" + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_create_pptx_image_slides_success( + self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock + ) -> None: + """Test PPTX generation with image_slides mode.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + + # Mock slide dimensions + mock_prs.slide_width = MagicMock() + mock_prs.slide_height = MagicMock() + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher returning success + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image_data", + format="png", + width=1920, + height=1080, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + mode="image_slides", + title="Generated Presentation", + image_slides=[ + ImageSlideSpec(image_id="slide-1-uuid", notes="Speaker notes 1"), + ImageSlideSpec(image_id="slide-2-uuid", notes="Speaker notes 2"), + ], + ) + + handler.create_content(spec.model_dump_json()) + + # Verify image fetcher was called with image_id + assert mock_fetcher.fetch.call_count == 2 + mock_fetcher.fetch.assert_any_call(image_id="slide-1-uuid") + mock_fetcher.fetch.assert_any_call(image_id="slide-2-uuid") + + # Verify slides were added with full-bleed images + assert mock_prs.slides.add_slide.call_count == 2 + assert mock_slide.shapes.add_picture.call_count == 2 + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_create_pptx_image_slides_failure_shows_placeholder( + self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock + ) -> None: + """Test PPTX generation shows placeholder when image fails.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_textbox.text_frame.paragraphs = [MagicMock()] + + mock_prs.slide_width = MagicMock() + mock_prs.slide_height = MagicMock() + mock_slide.shapes.add_textbox.return_value = mock_textbox + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher returning failure + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=False, + error="Image not found", + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + mode="image_slides", + image_slides=[ + ImageSlideSpec(image_id="missing-uuid"), + ], + ) + + handler.create_content(spec.model_dump_json()) + + # Verify placeholder textbox was added instead of picture + mock_slide.shapes.add_textbox.assert_called() + mock_slide.shapes.add_picture.assert_not_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_image_block_with_image_id( + self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock + ) -> None: + """Test rendering ImageBlock with image_id in structured slides.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image_data", + format="png", + width=400, + height=300, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Image from generate_image", + content=[ + ImageBlock( + image_id="generated-image-uuid", + caption="AI Generated Chart", + ) + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify image fetcher was called with image_id (not url) + mock_fetcher.fetch.assert_called_once() + call_kwargs = mock_fetcher.fetch.call_args + # The call should have image_id set + assert call_kwargs[1].get("image_id") == "generated-image-uuid" or ( + call_kwargs[0] == () and call_kwargs[1].get("url") is None + ) diff --git a/service/tests/unit/tools/utils/documents/test_image_fetcher.py b/service/tests/unit/tools/utils/documents/test_image_fetcher.py new file mode 100644 index 00000000..bc974672 --- /dev/null +++ b/service/tests/unit/tools/utils/documents/test_image_fetcher.py @@ -0,0 +1,378 @@ +""" +Tests for ImageFetcher service. +""" + +import base64 +import io +from unittest.mock import MagicMock, patch + + +from app.tools.utils.documents.image_fetcher import ( + DEFAULT_TIMEOUT, + MAX_IMAGE_DIMENSION, + MAX_IMAGE_SIZE_BYTES, + FetchedImage, + ImageFetcher, +) + + +class TestFetchedImage: + def test_success_result(self) -> None: + result = FetchedImage( + success=True, + data=b"image_data", + format="png", + width=100, + height=100, + ) + assert result.success + assert result.data == b"image_data" + assert result.error is None + + def test_failure_result(self) -> None: + result = FetchedImage(success=False, error="Connection failed") + assert not result.success + assert result.error == "Connection failed" + assert result.data is None + + +class TestImageFetcher: + def test_init_defaults(self) -> None: + fetcher = ImageFetcher() + assert fetcher.timeout == DEFAULT_TIMEOUT + assert fetcher.max_size_bytes == MAX_IMAGE_SIZE_BYTES + assert fetcher.max_dimension == MAX_IMAGE_DIMENSION + + def test_init_custom_values(self) -> None: + fetcher = ImageFetcher(timeout=10.0, max_size_bytes=1000, max_dimension=500) + assert fetcher.timeout == 10.0 + assert fetcher.max_size_bytes == 1000 + assert fetcher.max_dimension == 500 + + def test_unsupported_scheme(self) -> None: + fetcher = ImageFetcher() + result = fetcher.fetch("ftp://example.com/image.png") + assert not result.success + assert "Unsupported URL scheme" in (result.error or "") + + +class TestImageFetcherHTTP: + @patch("httpx.Client") + def test_fetch_http_success(self, mock_client_cls: MagicMock) -> None: + # Create a minimal valid PNG (1x1 pixel) + png_data = self._create_minimal_png() + + mock_response = MagicMock() + mock_response.content = png_data + mock_response.headers = {"content-length": str(len(png_data))} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/image.png") + + assert result.success + assert result.data is not None + assert result.format == "png" + mock_client.get.assert_called_once_with("https://example.com/image.png") + + @patch("httpx.Client") + def test_fetch_http_timeout(self, mock_client_cls: MagicMock) -> None: + import httpx + + mock_client = MagicMock() + mock_client.get.side_effect = httpx.TimeoutException("timeout") + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/image.png") + + assert not result.success + assert "Timeout" in (result.error or "") + + @patch("httpx.Client") + def test_fetch_http_status_error(self, mock_client_cls: MagicMock) -> None: + import httpx + + mock_response = MagicMock() + mock_response.status_code = 404 + + mock_client = MagicMock() + mock_client.get.side_effect = httpx.HTTPStatusError("Not found", request=MagicMock(), response=mock_response) + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/image.png") + + assert not result.success + assert "404" in (result.error or "") + + @patch("httpx.Client") + def test_fetch_http_too_large_header(self, mock_client_cls: MagicMock) -> None: + mock_response = MagicMock() + mock_response.headers = {"content-length": str(MAX_IMAGE_SIZE_BYTES + 1)} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/large.png") + + assert not result.success + assert "too large" in (result.error or "").lower() + + @patch("httpx.Client") + def test_fetch_http_too_large_content(self, mock_client_cls: MagicMock) -> None: + mock_response = MagicMock() + mock_response.headers = {} # No content-length header + mock_response.content = b"x" * (MAX_IMAGE_SIZE_BYTES + 1) + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/large.png") + + assert not result.success + assert "too large" in (result.error or "").lower() + + def _create_minimal_png(self) -> bytes: + """Create a minimal valid 1x1 PNG image.""" + from PIL import Image + + img = Image.new("RGB", (1, 1), color="red") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + return buffer.getvalue() + + +class TestImageFetcherBase64: + def test_fetch_base64_valid_png(self) -> None: + # Create a minimal valid PNG + from PIL import Image + + img = Image.new("RGB", (10, 10), color="blue") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + fetcher = ImageFetcher() + result = fetcher.fetch(data_url) + + assert result.success + assert result.data is not None + assert result.format == "png" + assert result.width == 10 + assert result.height == 10 + + def test_fetch_base64_valid_jpeg(self) -> None: + from PIL import Image + + img = Image.new("RGB", (20, 15), color="green") + buffer = io.BytesIO() + img.save(buffer, format="JPEG") + jpeg_bytes = buffer.getvalue() + + b64_data = base64.b64encode(jpeg_bytes).decode("utf-8") + data_url = f"data:image/jpeg;base64,{b64_data}" + + fetcher = ImageFetcher() + result = fetcher.fetch(data_url) + + assert result.success + assert result.format in ("jpg", "jpeg") + assert result.width == 20 + assert result.height == 15 + + def test_fetch_base64_invalid_format(self) -> None: + fetcher = ImageFetcher() + result = fetcher.fetch("data:text/plain;base64,SGVsbG8=") + assert not result.success + assert "Invalid" in (result.error or "") + + def test_fetch_base64_invalid_data(self) -> None: + fetcher = ImageFetcher() + result = fetcher.fetch("!!data") + assert not result.success + # Either decode error or image processing error + assert result.error is not None + + def test_fetch_base64_too_large(self) -> None: + from PIL import Image + + # Create an image that exceeds the limit when decoded + img = Image.new("RGB", (100, 100), color="red") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + # Use a small limit to trigger the check + fetcher = ImageFetcher(max_size_bytes=100) + result = fetcher.fetch(data_url) + + assert not result.success + assert "too large" in (result.error or "").lower() + + +class TestImageFetcherResize: + def test_resize_large_image(self) -> None: + from PIL import Image + + # Create an image larger than max dimension + large_size = MAX_IMAGE_DIMENSION + 500 + img = Image.new("RGB", (large_size, large_size // 2), color="purple") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + fetcher = ImageFetcher() + result = fetcher.fetch(data_url) + + assert result.success + # Image should be resized + assert result.width is not None + assert result.height is not None + assert result.width <= MAX_IMAGE_DIMENSION + assert result.height <= MAX_IMAGE_DIMENSION + + def test_no_resize_small_image(self) -> None: + from PIL import Image + + small_size = 100 + img = Image.new("RGB", (small_size, small_size), color="yellow") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + fetcher = ImageFetcher() + result = fetcher.fetch(data_url) + + assert result.success + assert result.width == small_size + assert result.height == small_size + + +class TestImageFetcherStorage: + @patch("app.core.storage.get_storage_service") + def test_fetch_from_storage_success(self, mock_get_storage: MagicMock) -> None: + from PIL import Image + + # Create test image + img = Image.new("RGB", (50, 50), color="cyan") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + # Mock storage service + mock_storage = MagicMock() + + async def mock_download(key: str, output_buffer: io.BytesIO) -> None: + output_buffer.write(png_bytes) + + mock_storage.download_file = mock_download + mock_get_storage.return_value = mock_storage + + fetcher = ImageFetcher() + result = fetcher.fetch("storage://path/to/image.png") + + assert result.success + assert result.width == 50 + assert result.height == 50 + + @patch("app.core.storage.get_storage_service") + def test_fetch_from_storage_failure(self, mock_get_storage: MagicMock) -> None: + mock_storage = MagicMock() + + async def mock_download(key: str, output_buffer: io.BytesIO) -> None: + raise FileNotFoundError("File not found") + + mock_storage.download_file = mock_download + mock_get_storage.return_value = mock_storage + + fetcher = ImageFetcher() + result = fetcher.fetch("storage://path/to/missing.png") + + assert not result.success + assert "Storage fetch failed" in (result.error or "") + + +class TestImageFetcherByImageId: + """Tests for fetching images by image_id (UUID).""" + + def test_fetch_invalid_uuid_format(self) -> None: + """Test that invalid UUID format returns error.""" + fetcher = ImageFetcher() + result = fetcher.fetch(image_id="not-a-valid-uuid") + + assert not result.success + assert "Invalid image_id format" in (result.error or "") + + def test_fetch_requires_url_or_image_id(self) -> None: + """Test that fetch fails when neither url nor image_id is provided.""" + fetcher = ImageFetcher() + result = fetcher.fetch() + + assert not result.success + assert "Either url or image_id must be provided" in (result.error or "") + + def test_fetch_with_url_still_works(self) -> None: + """Test that fetch with url parameter still works (backward compat).""" + from PIL import Image + + # Create a minimal valid PNG + img = Image.new("RGB", (10, 10), color="blue") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + fetcher = ImageFetcher() + # Test with keyword argument url= + result = fetcher.fetch(url=data_url) + + assert result.success + assert result.format == "png" + assert result.width == 10 + + def test_fetch_by_image_id_returns_resolution_error(self) -> None: + """Test that image_id returns error about needing resolution in async layer.""" + from uuid import uuid4 + + fetcher = ImageFetcher() + test_uuid = str(uuid4()) + result = fetcher.fetch(image_id=test_uuid) + + assert not result.success + # Should explain that image_id needs to be resolved in async layer + assert "not resolved" in (result.error or "").lower() or "async layer" in (result.error or "").lower() diff --git a/service/uv.lock b/service/uv.lock index 3b04ea30..3ae07986 100644 --- a/service/uv.lock +++ b/service/uv.lock @@ -234,6 +234,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/51/321e821856452f7386c4e9df866f196720b1ad0c5ea1623ea7399969ae3b/authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd", size = 244005, upload-time = "2025-12-12T08:01:40.209Z" }, ] +[[package]] +name = "babel" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852, upload-time = "2025-02-01T15:17:41.026Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, +] + [[package]] name = "beautifulsoup4" version = "4.14.3" @@ -485,6 +494,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "courlan" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "babel" }, + { name = "tld" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6f/54/6d6ceeff4bed42e7a10d6064d35ee43a810e7b3e8beb4abeae8cff4713ae/courlan-1.3.2.tar.gz", hash = "sha256:0b66f4db3a9c39a6e22dd247c72cfaa57d68ea660e94bb2c84ec7db8712af190", size = 206382, upload-time = "2024-10-29T16:40:20.994Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/ca/6a667ccbe649856dcd3458bab80b016681b274399d6211187c6ab969fc50/courlan-1.3.2-py3-none-any.whl", hash = "sha256:d0dab52cf5b5b1000ee2839fbc2837e93b2514d3cb5bb61ae158a55b7a04c6be", size = 33848, upload-time = "2024-10-29T16:40:18.325Z" }, +] + [[package]] name = "coverage" version = "7.13.0" @@ -583,6 +606,21 @@ sqlite = [ { name = "aiosqlite" }, ] +[[package]] +name = "dateparser" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "regex" }, + { name = "tzlocal" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/30/064144f0df1749e7bb5faaa7f52b007d7c2d08ec08fed8411aba87207f68/dateparser-1.2.2.tar.gz", hash = "sha256:986316f17cb8cdc23ea8ce563027c5ef12fc725b6fb1d137c14ca08777c5ecf7", size = 329840, upload-time = "2025-06-26T09:29:23.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/22/f020c047ae1346613db9322638186468238bcfa8849b4668a22b97faad65/dateparser-1.2.2-py3-none-any.whl", hash = "sha256:5a5d7211a09013499867547023a2a0c91d5a27d15dd4dbcea676ea9fe66f2482", size = 315453, upload-time = "2025-06-26T09:29:21.412Z" }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -1121,6 +1159,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "htmldate" +version = "1.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "charset-normalizer" }, + { name = "dateparser" }, + { name = "lxml" }, + { name = "python-dateutil" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/10/ead9dabc999f353c3aa5d0dc0835b1e355215a5ecb489a7f4ef2ddad5e33/htmldate-1.9.4.tar.gz", hash = "sha256:1129063e02dd0354b74264de71e950c0c3fcee191178321418ccad2074cc8ed0", size = 44690, upload-time = "2025-11-04T17:46:44.983Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/bd/adfcdaaad5805c0c5156aeefd64c1e868c05e9c1cd6fd21751f168cd88c7/htmldate-1.9.4-py3-none-any.whl", hash = "sha256:1b94bcc4e08232a5b692159903acf95548b6a7492dddca5bb123d89d6325921c", size = 31558, upload-time = "2025-11-04T17:46:43.258Z" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -1327,6 +1381,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, ] +[[package]] +name = "justext" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lxml", extra = ["html-clean"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/f3/45890c1b314f0d04e19c1c83d534e611513150939a7cf039664d9ab1e649/justext-3.0.2.tar.gz", hash = "sha256:13496a450c44c4cd5b5a75a5efcd9996066d2a189794ea99a49949685a0beb05", size = 828521, upload-time = "2025-02-25T20:21:49.934Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/ac/52f4e86d1924a7fc05af3aeb34488570eccc39b4af90530dd6acecdf16b5/justext-3.0.2-py2.py3-none-any.whl", hash = "sha256:62b1c562b15c3c6265e121cc070874243a443bfd53060e869393f09d6b6cc9a7", size = 837940, upload-time = "2025-02-25T20:21:44.179Z" }, +] + [[package]] name = "kombu" version = "5.6.1" @@ -1652,6 +1718,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ea/7b/93c73c67db235931527301ed3785f849c78991e2e34f3fd9a6663ffda4c5/lxml-6.0.2-cp312-cp312-win_arm64.whl", hash = "sha256:61cb10eeb95570153e0c0e554f58df92ecf5109f75eacad4a95baa709e26c3d6", size = 3672836, upload-time = "2025-09-22T04:01:52.145Z" }, ] +[package.optional-dependencies] +html-clean = [ + { name = "lxml-html-clean" }, +] + +[[package]] +name = "lxml-html-clean" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lxml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/cb/c9c5bb2a9c47292e236a808dd233a03531f53b626f36259dcd32b49c76da/lxml_html_clean-0.4.3.tar.gz", hash = "sha256:c9df91925b00f836c807beab127aac82575110eacff54d0a75187914f1bd9d8c", size = 21498, upload-time = "2025-10-02T20:49:24.895Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/4a/63a9540e3ca73709f4200564a737d63a4c8c9c4dd032bab8535f507c190a/lxml_html_clean-0.4.3-py3-none-any.whl", hash = "sha256:63fd7b0b9c3a2e4176611c2ca5d61c4c07ffca2de76c14059a81a2825833731e", size = 14177, upload-time = "2025-10-02T20:49:23.749Z" }, +] + [[package]] name = "mako" version = "1.3.10" @@ -2478,6 +2561,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/4f/00be2196329ebbff56ce564aa94efb0fbc828d00de250b1980de1a34ab49/python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba", size = 472788, upload-time = "2024-08-07T17:33:28.192Z" }, ] +[[package]] +name = "pytz" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, +] + [[package]] name = "pywin32" version = "311" @@ -2811,6 +2903,7 @@ dependencies = [ { name = "reportlab" }, { name = "requests" }, { name = "sqlmodel" }, + { name = "trafilatura" }, { name = "websockets" }, ] @@ -2877,6 +2970,7 @@ requires-dist = [ { name = "reportlab", specifier = ">=4.4.7" }, { name = "requests", specifier = ">=2.32.4" }, { name = "sqlmodel", specifier = ">=0.0.24" }, + { name = "trafilatura", specifier = ">=1.12.0" }, { name = "websockets", specifier = ">=13.0,<14.0" }, ] @@ -3041,6 +3135,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/10/b8523105c590c5b8349f2587e2fdfe51a69544bd5a76295fc20f2374f470/tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd", size = 878694, upload-time = "2025-10-06T20:21:59.876Z" }, ] +[[package]] +name = "tld" +version = "0.13.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/df/a1/5723b07a70c1841a80afc9ac572fdf53488306848d844cd70519391b0d26/tld-0.13.1.tar.gz", hash = "sha256:75ec00936cbcf564f67361c41713363440b6c4ef0f0c1592b5b0fbe72c17a350", size = 462000, upload-time = "2025-05-21T22:18:29.341Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/70/b2f38360c3fc4bc9b5e8ef429e1fde63749144ac583c2dbdf7e21e27a9ad/tld-0.13.1-py2.py3-none-any.whl", hash = "sha256:a2d35109433ac83486ddf87e3c4539ab2c5c2478230e5d9c060a18af4b03aa7c", size = 274718, upload-time = "2025-05-21T22:18:25.811Z" }, +] + [[package]] name = "tqdm" version = "4.67.1" @@ -3053,6 +3156,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, ] +[[package]] +name = "trafilatura" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "courlan" }, + { name = "htmldate" }, + { name = "justext" }, + { name = "lxml" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/25/e3ebeefdebfdfae8c4a4396f5a6ea51fc6fa0831d63ce338e5090a8003dc/trafilatura-2.0.0.tar.gz", hash = "sha256:ceb7094a6ecc97e72fea73c7dba36714c5c5b577b6470e4520dca893706d6247", size = 253404, upload-time = "2024-12-03T15:23:24.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/b6/097367f180b6383a3581ca1b86fcae284e52075fa941d1232df35293363c/trafilatura-2.0.0-py3-none-any.whl", hash = "sha256:77eb5d1e993747f6f20938e1de2d840020719735690c840b9a1024803a4cd51d", size = 132557, upload-time = "2024-12-03T15:23:21.41Z" }, +] + [[package]] name = "typer" version = "0.20.0" diff --git a/web/src/app/chat/SpatialWorkspace.tsx b/web/src/app/chat/SpatialWorkspace.tsx index 7f3ae57a..8d9bd1d5 100644 --- a/web/src/app/chat/SpatialWorkspace.tsx +++ b/web/src/app/chat/SpatialWorkspace.tsx @@ -11,9 +11,13 @@ import "@xyflow/react/dist/style.css"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import AddAgentModal from "@/components/modals/AddAgentModal"; +import AgentSettingsModal from "@/components/modals/AgentSettingsModal"; +import ConfirmationModal from "@/components/modals/ConfirmationModal"; import { useMyMarketplaceListings } from "@/hooks/useMarketplace"; import { useXyzen } from "@/store"; +import type { Agent } from "@/types/agents"; import { AnimatePresence } from "framer-motion"; +import { useTranslation } from "react-i18next"; import { AddAgentButton, @@ -34,6 +38,7 @@ import { } from "./spatial"; function InnerWorkspace() { + const { t } = useTranslation(); const { agents, updateAgentLayout, @@ -101,6 +106,10 @@ function InnerWorkspace() { }); const [saveStatus, setSaveStatus] = useState("idle"); const [isAddModalOpen, setAddModalOpen] = useState(false); + const [editingAgent, setEditingAgent] = useState(null); + const [isEditModalOpen, setEditModalOpen] = useState(false); + const [agentToDelete, setAgentToDelete] = useState(null); + const [isConfirmModalOpen, setConfirmModalOpen] = useState(false); const [prevViewport, setPrevViewport] = useState(null); const [newlyCreatedAgentId, setNewlyCreatedAgentId] = useState( null, @@ -516,6 +525,29 @@ function InnerWorkspace() { }, 1000); }, [setViewport, getViewport, fitView]); + // Agent edit/delete handlers for FocusedView (with confirmation modal) + const handleEditAgentFromFocus = useCallback( + (agentId: string) => { + const agent = agents.find((a) => a.id === agentId); + if (agent) { + setEditingAgent(agent); + setEditModalOpen(true); + } + }, + [agents], + ); + + const handleDeleteAgentFromFocus = useCallback( + (agentId: string) => { + const agent = agents.find((a) => a.id === agentId); + if (agent) { + setAgentToDelete(agent); + setConfirmModalOpen(true); + } + }, + [agents], + ); + // Viewport change handler const handleViewportChange = useCallback((_: unknown, viewport: Viewport) => { if (focusedAgentIdRef.current) return; @@ -622,6 +654,8 @@ function InnerWorkspace() { onClose={handleCloseFocus} onSwitchAgent={(id) => handleFocus(id)} onCanvasClick={handleCloseFocus} + onEditAgent={handleEditAgentFromFocus} + onDeleteAgent={handleDeleteAgentFromFocus} /> )} @@ -630,6 +664,58 @@ function InnerWorkspace() { isOpen={isAddModalOpen} onClose={() => setAddModalOpen(false)} /> + + {/* Edit Agent Modal */} + {editingAgent && ( + { + setEditModalOpen(false); + setEditingAgent(null); + }} + sessionId="" + agentId={editingAgent.id} + agentName={editingAgent.name} + agent={editingAgent} + currentAvatar={editingAgent.avatar ?? undefined} + onAvatarChange={(avatarUrl) => { + setEditingAgent({ ...editingAgent, avatar: avatarUrl }); + updateAgentAvatar(editingAgent.id, avatarUrl); + }} + onGridSizeChange={() => {}} + onDelete={ + publishedAgentIds.has(editingAgent.id) + ? undefined + : () => { + deleteAgent(editingAgent.id); + setEditModalOpen(false); + setEditingAgent(null); + } + } + /> + )} + + {/* Delete Confirmation Modal */} + {agentToDelete && ( + { + setConfirmModalOpen(false); + setAgentToDelete(null); + }} + onConfirm={() => { + if (publishedAgentIds.has(agentToDelete.id)) return; + deleteAgent(agentToDelete.id); + setConfirmModalOpen(false); + setAgentToDelete(null); + }} + title={t("agents.deleteAgent")} + message={t("agents.deleteConfirmation", { name: agentToDelete.name })} + confirmLabel={t("common.delete")} + cancelLabel={t("common.cancel")} + /> + )} ); } diff --git a/web/src/app/chat/spatial/FocusedView.tsx b/web/src/app/chat/spatial/FocusedView.tsx index 4885b120..23cd23e2 100644 --- a/web/src/app/chat/spatial/FocusedView.tsx +++ b/web/src/app/chat/spatial/FocusedView.tsx @@ -1,7 +1,9 @@ +import { AgentList } from "@/components/agents"; import XyzenChat from "@/components/layouts/XyzenChat"; import { useXyzen } from "@/store"; +import type { Agent } from "@/types/agents"; import { motion } from "framer-motion"; -import { useEffect, useRef } from "react"; +import { useCallback, useEffect, useMemo, useRef } from "react"; import { AgentData } from "./types"; interface FocusedViewProps { @@ -10,6 +12,9 @@ interface FocusedViewProps { onClose: () => void; onSwitchAgent: (id: string) => void; onCanvasClick?: () => void; // Callback specifically for canvas clicks + // Agent edit/delete handlers + onEditAgent?: (agentId: string) => void; + onDeleteAgent?: (agentId: string) => void; } export function FocusedView({ @@ -18,12 +23,82 @@ export function FocusedView({ onClose, onSwitchAgent, onCanvasClick, + onEditAgent, + onDeleteAgent, }: FocusedViewProps) { const switcherRef = useRef(null); const chatRef = useRef(null); const { activateChannelForAgent } = useXyzen(); + // Convert AgentData to Agent type for AgentList component + const agentsForList: Agent[] = useMemo( + () => + agents.map((a) => ({ + id: a.id, // Use node ID for switching + name: a.name, + description: a.desc, + avatar: a.avatar, + user_id: "", + created_at: "", + updated_at: "", + })), + [agents], + ); + + // Create a map for quick lookup of original AgentData + const agentDataMap = useMemo( + () => new Map(agents.map((a) => [a.id, a])), + [agents], + ); + + // Get selected agent's node ID + const selectedAgentId = useMemo( + () => agents.find((a) => a.name === agent.name)?.id, + [agents, agent.name], + ); + + // Callbacks to get status and role from original AgentData + const getAgentStatus = useCallback( + (a: Agent) => { + const status = agentDataMap.get(a.id)?.status; + // Map "offline" to "idle" since compact variant only supports "idle" | "busy" + return status === "busy" ? "busy" : "idle"; + }, + [agentDataMap], + ); + + const getAgentRole = useCallback( + (a: Agent) => agentDataMap.get(a.id)?.role, + [agentDataMap], + ); + + const handleAgentClick = useCallback( + (a: Agent) => onSwitchAgent(a.id), + [onSwitchAgent], + ); + + // Map node id back to real agentId for edit/delete + const handleEditClick = useCallback( + (a: Agent) => { + const agentData = agentDataMap.get(a.id); + if (agentData?.agentId && onEditAgent) { + onEditAgent(agentData.agentId); + } + }, + [agentDataMap, onEditAgent], + ); + + const handleDeleteClick = useCallback( + (a: Agent) => { + const agentData = agentDataMap.get(a.id); + if (agentData?.agentId && onDeleteAgent) { + onDeleteAgent(agentData.agentId); + } + }, + [agentDataMap, onDeleteAgent], + ); + // Activate the channel for the selected agent useEffect(() => { if (agent.agentId) { @@ -126,40 +201,17 @@ export function FocusedView({ Active Agents -
- {agents.map((a) => ( - - ))} +
+
diff --git a/web/src/components/agents/AgentList.tsx b/web/src/components/agents/AgentList.tsx new file mode 100644 index 00000000..3fff480c --- /dev/null +++ b/web/src/components/agents/AgentList.tsx @@ -0,0 +1,115 @@ +"use client"; + +import type { Agent } from "@/types/agents"; +import { motion, type Variants } from "framer-motion"; +import React from "react"; +import { AgentListItem } from "./AgentListItem"; + +// Container animation variants for detailed variant +const containerVariants: Variants = { + hidden: { opacity: 0 }, + visible: { + opacity: 1, + transition: { + staggerChildren: 0.08, + delayChildren: 0.1, + }, + }, +}; + +// Base props for both variants +interface AgentListBaseProps { + agents: Agent[]; + onAgentClick?: (agent: Agent) => void; +} + +// Props for detailed variant +interface DetailedAgentListProps extends AgentListBaseProps { + variant: "detailed"; + publishedAgentIds?: Set; + lastConversationTimeByAgent?: Record; + onEdit?: (agent: Agent) => void; + onDelete?: (agent: Agent) => void; + // Compact variant props not used + selectedAgentId?: never; + getAgentStatus?: never; + getAgentRole?: never; +} + +// Props for compact variant +interface CompactAgentListProps extends AgentListBaseProps { + variant: "compact"; + selectedAgentId?: string; + getAgentStatus?: (agent: Agent) => "idle" | "busy"; + getAgentRole?: (agent: Agent) => string | undefined; + // Right-click menu support (shared with detailed) + publishedAgentIds?: Set; + onEdit?: (agent: Agent) => void; + onDelete?: (agent: Agent) => void; + // Detailed variant props not used + lastConversationTimeByAgent?: never; +} + +export type AgentListProps = DetailedAgentListProps | CompactAgentListProps; + +export const AgentList: React.FC = (props) => { + const { agents, variant, onAgentClick } = props; + + if (variant === "detailed") { + const { publishedAgentIds, lastConversationTimeByAgent, onEdit, onDelete } = + props as DetailedAgentListProps; + + return ( + + {agents.map((agent) => ( + + ))} + + ); + } + + // Compact variant + const { + selectedAgentId, + getAgentStatus, + getAgentRole, + publishedAgentIds, + onEdit, + onDelete, + } = props as CompactAgentListProps; + + return ( +
+ {agents.map((agent) => ( + + ))} +
+ ); +}; + +export default AgentList; diff --git a/web/src/components/agents/AgentListItem.tsx b/web/src/components/agents/AgentListItem.tsx new file mode 100644 index 00000000..e7a78e0b --- /dev/null +++ b/web/src/components/agents/AgentListItem.tsx @@ -0,0 +1,474 @@ +"use client"; + +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/animate-ui/components/animate/tooltip"; +import { Badge } from "@/components/base/Badge"; +import { formatTime } from "@/lib/formatDate"; +import type { Agent } from "@/types/agents"; +import { + PencilIcon, + ShoppingBagIcon, + TrashIcon, +} from "@heroicons/react/24/outline"; +import { motion, type Variants } from "framer-motion"; +import React, { useEffect, useRef, useState } from "react"; +import { createPortal } from "react-dom"; +import { useTranslation } from "react-i18next"; + +// Animation variants for detailed variant +const itemVariants: Variants = { + hidden: { y: 20, opacity: 0 }, + visible: { + y: 0, + opacity: 1, + transition: { + type: "spring", + stiffness: 100, + damping: 12, + }, + }, +}; + +// Context menu component +interface ContextMenuProps { + x: number; + y: number; + onEdit: () => void; + onDelete: () => void; + onClose: () => void; + isDefaultAgent?: boolean; + isMarketplacePublished?: boolean; +} + +const ContextMenu: React.FC = ({ + x, + y, + onEdit, + onDelete, + onClose, + isDefaultAgent = false, + isMarketplacePublished = false, +}) => { + const { t } = useTranslation(); + const menuRef = useRef(null); + + useEffect(() => { + const handleClickOutside = (event: MouseEvent) => { + if (menuRef.current && !menuRef.current.contains(event.target as Node)) { + onClose(); + } + }; + + const handleEscape = (event: KeyboardEvent) => { + if (event.key === "Escape") { + onClose(); + } + }; + + document.addEventListener("mousedown", handleClickOutside); + document.addEventListener("keydown", handleEscape); + + return () => { + document.removeEventListener("mousedown", handleClickOutside); + document.removeEventListener("keydown", handleEscape); + }; + }, [onClose]); + + return ( + + + {isMarketplacePublished ? ( + + + + + + + + {t("agents.deleteBlockedMessage", { + defaultValue: + "This agent is published to Agent Market. Please unpublish it first, then delete it.", + })} + + + ) : ( + + )} + + ); +}; + +// Shared props for both variants +interface AgentListItemBaseProps { + agent: Agent; + onClick?: (agent: Agent) => void; +} + +// Props specific to detailed variant +interface DetailedVariantProps extends AgentListItemBaseProps { + variant: "detailed"; + isMarketplacePublished?: boolean; + lastConversationTime?: string; + onEdit?: (agent: Agent) => void; + onDelete?: (agent: Agent) => void; + // Compact variant props not used + isSelected?: never; + status?: never; + role?: never; +} + +// Props specific to compact variant +interface CompactVariantProps extends AgentListItemBaseProps { + variant: "compact"; + isSelected?: boolean; + status?: "idle" | "busy"; + role?: string; + // Right-click menu support (shared with detailed) + isMarketplacePublished?: boolean; + onEdit?: (agent: Agent) => void; + onDelete?: (agent: Agent) => void; + // Detailed variant props not used + lastConversationTime?: never; +} + +export type AgentListItemProps = DetailedVariantProps | CompactVariantProps; + +// Detailed variant component (for sidebar) +const DetailedAgentListItem: React.FC = ({ + agent, + isMarketplacePublished = false, + lastConversationTime, + onClick, + onEdit, + onDelete, +}) => { + const { t } = useTranslation(); + const [contextMenu, setContextMenu] = useState<{ + x: number; + y: number; + } | null>(null); + + const longPressTimer = useRef | null>(null); + const isLongPress = useRef(false); + + const handleTouchStart = (e: React.TouchEvent) => { + isLongPress.current = false; + const touch = e.touches[0]; + const { clientX, clientY } = touch; + + longPressTimer.current = setTimeout(() => { + setContextMenu({ x: clientX, y: clientY }); + // Haptic feedback (best-effort) + try { + if ("vibrate" in navigator) { + navigator.vibrate(10); + } + } catch { + // ignore + } + }, 500); + }; + + const handleTouchEnd = () => { + if (longPressTimer.current) { + clearTimeout(longPressTimer.current); + } + }; + + const handleTouchMove = () => { + if (longPressTimer.current) { + clearTimeout(longPressTimer.current); + longPressTimer.current = null; + } + }; + + // Check if it's a default agent based on tags + const isDefaultAgent = agent.tags?.some((tag) => tag.startsWith("default_")); + + const handleContextMenu = (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + + setContextMenu({ + x: e.clientX, + y: e.clientY, + }); + }; + + return ( + <> + { + if (isLongPress.current) return; + onClick?.(agent); + }} + onContextMenu={handleContextMenu} + onTouchStart={handleTouchStart} + onTouchEnd={handleTouchEnd} + onTouchMove={handleTouchMove} + className={` + group relative flex cursor-pointer items-start gap-4 rounded-sm border p-3 + border-neutral-200 bg-white hover:bg-neutral-50 dark:border-neutral-800 dark:bg-neutral-900 dark:hover:bg-neutral-800/60 + ${agent.id === "default-chat" ? "select-none" : ""} + `} + > + {/* Avatar */} +
+ {agent.name} +
+ + {/* Content */} +
+
+

+ {agent.name} +

+ + {/* Marketplace published badge */} + {isMarketplacePublished && ( + + + + + + + + + + + + {t("agents.badges.marketplace", { + defaultValue: "Published to Marketplace", + })} + + + )} +
+ +

+ {agent.description} +

+ + {/* Last conversation time */} + {lastConversationTime && ( +

+ {formatTime(lastConversationTime)} +

+ )} +
+
+ + {/* Context menu - rendered via portal to escape overflow:hidden containers */} + {contextMenu && + createPortal( + onEdit?.(agent)} + onDelete={() => onDelete?.(agent)} + onClose={() => setContextMenu(null)} + isDefaultAgent={isDefaultAgent} + isMarketplacePublished={isMarketplacePublished} + />, + document.body, + )} + + ); +}; + +// Compact variant component (for spatial workspace switcher) +const CompactAgentListItem: React.FC = ({ + agent, + isSelected = false, + status = "idle", + role, + isMarketplacePublished = false, + onClick, + onEdit, + onDelete, +}) => { + const [contextMenu, setContextMenu] = useState<{ + x: number; + y: number; + } | null>(null); + + const longPressTimer = useRef | null>(null); + const isLongPress = useRef(false); + + // Check if it's a default agent based on tags + const isDefaultAgent = agent.tags?.some((tag) => tag.startsWith("default_")); + + const handleContextMenu = (e: React.MouseEvent) => { + if (!onEdit && !onDelete) return; // No context menu if no handlers + e.preventDefault(); + e.stopPropagation(); + setContextMenu({ x: e.clientX, y: e.clientY }); + }; + + const handleTouchStart = (e: React.TouchEvent) => { + if (!onEdit && !onDelete) return; + isLongPress.current = false; + const touch = e.touches[0]; + const { clientX, clientY } = touch; + + longPressTimer.current = setTimeout(() => { + isLongPress.current = true; + setContextMenu({ x: clientX, y: clientY }); + try { + if ("vibrate" in navigator) { + navigator.vibrate(10); + } + } catch { + // ignore + } + }, 500); + }; + + const handleTouchEnd = () => { + if (longPressTimer.current) { + clearTimeout(longPressTimer.current); + } + }; + + const handleTouchMove = () => { + if (longPressTimer.current) { + clearTimeout(longPressTimer.current); + longPressTimer.current = null; + } + }; + + return ( + <> + + + {/* Context menu - rendered via portal to escape overflow:hidden containers */} + {contextMenu && + (onEdit || onDelete) && + createPortal( + onEdit?.(agent)} + onDelete={() => onDelete?.(agent)} + onClose={() => setContextMenu(null)} + isDefaultAgent={isDefaultAgent} + isMarketplacePublished={isMarketplacePublished} + />, + document.body, + )} + + ); +}; + +// Main component that switches between variants +export const AgentListItem: React.FC = (props) => { + if (props.variant === "detailed") { + return ; + } + return ; +}; + +export default AgentListItem; diff --git a/web/src/components/agents/index.ts b/web/src/components/agents/index.ts new file mode 100644 index 00000000..e71b8a02 --- /dev/null +++ b/web/src/components/agents/index.ts @@ -0,0 +1,2 @@ +export { AgentList, type AgentListProps } from "./AgentList"; +export { AgentListItem, type AgentListItemProps } from "./AgentListItem"; diff --git a/web/src/components/features/FileUploadPreview.tsx b/web/src/components/features/FileUploadPreview.tsx index 3d0d4641..d76a0cd2 100644 --- a/web/src/components/features/FileUploadPreview.tsx +++ b/web/src/components/features/FileUploadPreview.tsx @@ -1,3 +1,4 @@ +import React from "react"; import { useXyzen } from "@/store"; import { FileUploadThumbnail } from "./FileUploadThumbnail"; import clsx from "clsx"; @@ -6,8 +7,11 @@ export interface FileUploadPreviewProps { className?: string; } -export function FileUploadPreview({ className }: FileUploadPreviewProps) { - const { uploadedFiles, isUploading, uploadError } = useXyzen(); +function FileUploadPreviewComponent({ className }: FileUploadPreviewProps) { + // Use selective subscriptions to avoid re-renders from unrelated store changes + const uploadedFiles = useXyzen((state) => state.uploadedFiles); + const isUploading = useXyzen((state) => state.isUploading); + const uploadError = useXyzen((state) => state.uploadError); if (uploadedFiles.length === 0) { return null; @@ -57,3 +61,5 @@ export function FileUploadPreview({ className }: FileUploadPreviewProps) { ); } + +export const FileUploadPreview = React.memo(FileUploadPreviewComponent); diff --git a/web/src/components/features/FileUploadThumbnail.tsx b/web/src/components/features/FileUploadThumbnail.tsx index d31f2e82..44e7a675 100644 --- a/web/src/components/features/FileUploadThumbnail.tsx +++ b/web/src/components/features/FileUploadThumbnail.tsx @@ -1,3 +1,4 @@ +import React, { useCallback } from "react"; import { XMarkIcon, DocumentIcon, @@ -12,20 +13,28 @@ export interface FileUploadThumbnailProps { file: UploadedFile; } -export function FileUploadThumbnail({ file }: FileUploadThumbnailProps) { - const { removeFile, retryUpload } = useXyzen(); +function FileUploadThumbnailComponent({ file }: FileUploadThumbnailProps) { + // Use selective subscriptions to avoid re-renders from unrelated store changes + const removeFile = useXyzen((state) => state.removeFile); + const retryUpload = useXyzen((state) => state.retryUpload); - const handleRemove = (e: React.MouseEvent) => { - e.preventDefault(); - e.stopPropagation(); - removeFile(file.id); - }; + const handleRemove = useCallback( + (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + removeFile(file.id); + }, + [removeFile, file.id], + ); - const handleRetry = (e: React.MouseEvent) => { - e.preventDefault(); - e.stopPropagation(); - retryUpload(file.id); - }; + const handleRetry = useCallback( + (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + retryUpload(file.id); + }, + [retryUpload, file.id], + ); const getFileIcon = () => { if (file.category === "images") { @@ -74,6 +83,7 @@ export function FileUploadThumbnail({ file }: FileUploadThumbnailProps) { {file.name} ) : ( @@ -165,3 +175,27 @@ export function FileUploadThumbnail({ file }: FileUploadThumbnailProps) { ); } + +// Custom comparison function for React.memo +// Only re-render when relevant file properties change +function arePropsEqual( + prevProps: FileUploadThumbnailProps, + nextProps: FileUploadThumbnailProps, +): boolean { + const prevFile = prevProps.file; + const nextFile = nextProps.file; + + return ( + prevFile.id === nextFile.id && + prevFile.status === nextFile.status && + prevFile.progress === nextFile.progress && + prevFile.thumbnailUrl === nextFile.thumbnailUrl && + prevFile.name === nextFile.name && + prevFile.category === nextFile.category + ); +} + +export const FileUploadThumbnail = React.memo( + FileUploadThumbnailComponent, + arePropsEqual, +); diff --git a/web/src/components/layouts/XyzenAgent.tsx b/web/src/components/layouts/XyzenAgent.tsx index 4acfe7d6..94a536b8 100644 --- a/web/src/components/layouts/XyzenAgent.tsx +++ b/web/src/components/layouts/XyzenAgent.tsx @@ -1,20 +1,10 @@ "use client"; -import { - Tooltip, - TooltipContent, - TooltipProvider, - TooltipTrigger, -} from "@/components/animate-ui/components/animate/tooltip"; -import { Badge } from "@/components/base/Badge"; + +import { TooltipProvider } from "@/components/animate-ui/components/animate/tooltip"; +import { AgentList } from "@/components/agents"; import { useAuth } from "@/hooks/useAuth"; -import { formatTime } from "@/lib/formatDate"; -import { - PencilIcon, - ShoppingBagIcon, - TrashIcon, -} from "@heroicons/react/24/outline"; -import { motion, type Variants } from "framer-motion"; -import React, { useEffect, useMemo, useRef, useState } from "react"; +import { motion } from "framer-motion"; +import { useEffect, useMemo, useState } from "react"; import { useTranslation } from "react-i18next"; import AddAgentModal from "@/components/modals/AddAgentModal"; @@ -26,325 +16,6 @@ import { useXyzen } from "@/store"; // Import types from separate file import type { Agent } from "@/types/agents"; -interface AgentCardProps { - agent: Agent; - isMarketplacePublished?: boolean; - lastConversationTime?: string; - onClick?: (agent: Agent) => void; - onEdit?: (agent: Agent) => void; - onDelete?: (agent: Agent) => void; -} - -// 定义动画变体 -const itemVariants: Variants = { - hidden: { y: 20, opacity: 0 }, - visible: { - y: 0, - opacity: 1, - transition: { - type: "spring", - stiffness: 100, - damping: 12, - }, - }, -}; - -// 右键菜单组件 -interface ContextMenuProps { - x: number; - y: number; - onEdit: () => void; - onDelete: () => void; - onClose: () => void; - isDefaultAgent?: boolean; - isMarketplacePublished?: boolean; - agent?: Agent; -} - -const ContextMenu: React.FC = ({ - x, - y, - onEdit, - onDelete, - onClose, - isDefaultAgent = false, - isMarketplacePublished = false, -}) => { - const { t } = useTranslation(); - const menuRef = useRef(null); - - useEffect(() => { - const handleClickOutside = (event: MouseEvent) => { - if (menuRef.current && !menuRef.current.contains(event.target as Node)) { - onClose(); - } - }; - - const handleEscape = (event: KeyboardEvent) => { - if (event.key === "Escape") { - onClose(); - } - }; - - document.addEventListener("mousedown", handleClickOutside); - document.addEventListener("keydown", handleEscape); - - return () => { - document.removeEventListener("mousedown", handleClickOutside); - document.removeEventListener("keydown", handleEscape); - }; - }, [onClose]); - - return ( - - - {isMarketplacePublished ? ( - - - - - - - - {t("agents.deleteBlockedMessage", { - defaultValue: - "This agent is published to Agent Market. Please unpublish it first, then delete it.", - })} - - - ) : ( - - )} - - ); -}; - -// 详细版本-包括名字,描述,头像,标签以及GPT模型 -const AgentCard: React.FC = ({ - agent, - isMarketplacePublished = false, - lastConversationTime, - onClick, - onEdit, - onDelete, -}) => { - const { t } = useTranslation(); - const [contextMenu, setContextMenu] = useState<{ - x: number; - y: number; - } | null>(null); - - const longPressTimer = useRef | null>(null); - const isLongPress = useRef(false); - - const handleTouchStart = (e: React.TouchEvent) => { - isLongPress.current = false; - const touch = e.touches[0]; - const { clientX, clientY } = touch; - - longPressTimer.current = setTimeout(() => { - setContextMenu({ x: clientX, y: clientY }); - // Haptic feedback (best-effort) - try { - if ("vibrate" in navigator) { - navigator.vibrate(10); - } - } catch { - // ignore - } - }, 500); - }; - - const handleTouchEnd = () => { - if (longPressTimer.current) { - clearTimeout(longPressTimer.current); - } - }; - - const handleTouchMove = () => { - if (longPressTimer.current) { - clearTimeout(longPressTimer.current); - longPressTimer.current = null; - } - }; - - // Check if it's a default agent based on tags - const isDefaultAgent = agent.tags?.some((tag) => tag.startsWith("default_")); - - const handleContextMenu = (e: React.MouseEvent) => { - e.preventDefault(); - e.stopPropagation(); - - setContextMenu({ - x: e.clientX, - y: e.clientY, - }); - }; - - return ( - <> - { - if (isLongPress.current) return; - onClick?.(agent); - }} - onContextMenu={handleContextMenu} - onTouchStart={handleTouchStart} - onTouchEnd={handleTouchEnd} - onTouchMove={handleTouchMove} - className={` - group relative flex cursor-pointer items-start gap-4 rounded-sm border p-3 - border-neutral-200 bg-white hover:bg-neutral-50 dark:border-neutral-800 dark:bg-neutral-900 dark:hover:bg-neutral-800/60 - ${agent.id === "default-chat" ? "select-none" : ""} - `} - > - {/* 头像 */} -
- {agent.name} -
- - {/* 内容 */} -
-
-

- {agent.name} -

- - {/* Marketplace published badge */} - {isMarketplacePublished && ( - - - - - - - - - - - - {t("agents.badges.marketplace", { - defaultValue: "Published to Marketplace", - })} - - - )} - - {/* Knowledge set badge */} - {/* {knowledgeSetName && ( -
- - 📚 {knowledgeSetName} - -
- )} */} -
- -

- {agent.description} -

- - {/* Last conversation time */} - {lastConversationTime && ( -

- {formatTime(lastConversationTime)} -

- )} -
-
- - {/* 右键菜单 */} - {contextMenu && ( - onEdit?.(agent)} - onDelete={() => onDelete?.(agent)} - onClose={() => setContextMenu(null)} - isDefaultAgent={isDefaultAgent} - isMarketplacePublished={isMarketplacePublished} - agent={agent} - /> - )} - - ); -}; - -const containerVariants: Variants = { - hidden: { opacity: 0 }, - visible: { - opacity: 1, - transition: { - staggerChildren: 0.08, - delayChildren: 0.1, - }, - }, -}; - interface XyzenAgentProps { systemAgentType?: "chat" | "all"; } @@ -502,21 +173,18 @@ export default function XyzenAgent({ - {allAgents.map((agent) => ( - - ))} + - - {/* MCP Tooltip */} -
- - {/* Arrow */} -
-
-
+ // Get connected server IDs from agent + const connectedServerIds = new Set( + agent.mcp_server_ids || agent.mcp_servers?.map((s) => s.id) || [], ); -} -/** - * MCP Tooltip content component - */ -function McpTooltipContent({ mcpInfo }: { mcpInfo: McpInfo }) { - const { t } = useTranslation(); + // Separate servers into connected and available + const connectedServers = allMcpServers.filter((server) => + connectedServerIds.has(server.id), + ); + const availableServers = allMcpServers.filter( + (server) => !connectedServerIds.has(server.id), + ); + + const handleMcpServerToggle = async (serverId: string, connect: boolean) => { + if (!agent || isUpdating) return; + + setIsUpdating(serverId); + try { + const currentIds = + agent.mcp_server_ids || agent.mcp_servers?.map((s) => s.id) || []; + const newIds = connect + ? [...currentIds, serverId] + : currentIds.filter((id) => id !== serverId); + + await onUpdateAgent({ + ...agent, + mcp_server_ids: newIds, + }); + } catch (error) { + console.error("Failed to update MCP server:", error); + } finally { + setIsUpdating(null); + } + }; return ( - <> -
-
- - - {t("app.toolbar.mcpTools")} - -
-
- {t("app.chat.assistantsTitle")}: {mcpInfo.agent.name} -
-
+ + + + + +
+ {/* Header */} +
+
+ + + {t("app.toolbar.mcpTools")} + +
+
+ {t("app.chat.assistantsTitle")}: {agent.name} +
+
-
- {mcpInfo.servers.map((server) => ( - - ))} -
- + {/* Connected Servers Section */} + {connectedServers.length > 0 && ( +
+

+ {t("app.toolbar.mcpConnected", "Connected")} +

+
+ {connectedServers.map((server) => ( + handleMcpServerToggle(server.id, false)} + /> + ))} +
+
+ )} + + {/* Available Servers Section */} + {availableServers.length > 0 && ( +
+

+ {t("app.toolbar.mcpAvailable", "Available")} +

+
+ {availableServers.map((server) => ( + handleMcpServerToggle(server.id, true)} + /> + ))} +
+
+ )} + + {/* Empty State */} + {allMcpServers.length === 0 && ( +
+
+ {t("app.toolbar.mcpNoServers", "No MCP servers configured")} +
+ +
+ )} +
+
+
); } /** - * Individual MCP server card + * Individual MCP server toggle item */ -function McpServerCard({ server }: { server: McpServer }) { +interface McpServerToggleItemProps { + server: McpServer; + isConnected: boolean; + isUpdating: boolean; + onToggle: () => void; +} + +function McpServerToggleItem({ + server, + isConnected, + isUpdating, + onToggle, +}: McpServerToggleItemProps) { + const { t } = useTranslation(); + const isOnline = server.status === "online"; + const isDisabled = !isOnline || isUpdating; + return ( -
-
-
-
- + ); } diff --git a/web/src/components/layouts/components/ChatToolbar/MobileMoreMenu.tsx b/web/src/components/layouts/components/ChatToolbar/MobileMoreMenu.tsx index 86793050..6b7b3afa 100644 --- a/web/src/components/layouts/components/ChatToolbar/MobileMoreMenu.tsx +++ b/web/src/components/layouts/components/ChatToolbar/MobileMoreMenu.tsx @@ -1,19 +1,26 @@ /** * Mobile More Menu * - * A popup menu shown on mobile with tool selector and MCP info. + * A popup menu shown on mobile with tool selector and MCP management. */ import McpIcon from "@/assets/McpIcon"; +import { cn } from "@/lib/utils"; import type { Agent } from "@/types/agents"; +import type { McpServer } from "@/types/mcp"; +import { + CheckIcon, + ChevronDownIcon, + Cog6ToothIcon, +} from "@heroicons/react/24/outline"; import { AnimatePresence, motion } from "motion/react"; +import { useState } from "react"; +import { useTranslation } from "react-i18next"; import { ToolSelector } from "./ToolSelector"; interface McpInfo { - servers: Array<{ - id: string; - tools?: Array<{ name: string }>; - }>; + agent: Agent; + servers: McpServer[]; } interface MobileMoreMenuProps { @@ -21,6 +28,8 @@ interface MobileMoreMenuProps { agent: Agent | null; onUpdateAgent: (agent: Agent) => Promise; mcpInfo: McpInfo | null; + allMcpServers?: McpServer[]; + onOpenSettings?: () => void; sessionKnowledgeSetId?: string | null; onUpdateSessionKnowledge?: (knowledgeSetId: string | null) => Promise; } @@ -30,14 +39,61 @@ export function MobileMoreMenu({ agent, onUpdateAgent, mcpInfo, + allMcpServers = [], + onOpenSettings, sessionKnowledgeSetId, onUpdateSessionKnowledge, }: MobileMoreMenuProps) { + const { t } = useTranslation(); + const [showMcpList, setShowMcpList] = useState(false); + const [isUpdating, setIsUpdating] = useState(null); + const handleUpdateAgent = async (updatedAgent: Agent) => { await onUpdateAgent(updatedAgent); // Don't close on toggle - let user configure multiple tools }; + // Get connected server IDs from agent + const connectedServerIds = new Set( + agent?.mcp_server_ids || agent?.mcp_servers?.map((s) => s.id) || [], + ); + + // Separate servers into connected and available + const connectedServers = allMcpServers.filter((server) => + connectedServerIds.has(server.id), + ); + const availableServers = allMcpServers.filter( + (server) => !connectedServerIds.has(server.id), + ); + + const totalTools = + mcpInfo?.servers.reduce( + (total, server) => total + (server.tools?.length || 0), + 0, + ) || 0; + + const handleMcpServerToggle = async (serverId: string, connect: boolean) => { + if (!agent || isUpdating) return; + + setIsUpdating(serverId); + try { + const currentIds = + agent.mcp_server_ids || agent.mcp_servers?.map((s) => s.id) || []; + const newIds = connect + ? [...currentIds, serverId] + : currentIds.filter((id) => id !== serverId); + + await onUpdateAgent({ + ...agent, + mcp_server_ids: newIds, + }); + } catch (error) { + console.error("Failed to update MCP server:", error); + } finally { + setIsUpdating(null); + } + }; + return ( {isOpen && ( @@ -64,23 +120,114 @@ export function MobileMoreMenu({
)} - {/* MCP Tool Info */} - {mcpInfo && ( -
-
-
- - MCP Tools -
- {mcpInfo.servers.length > 0 && ( - - {mcpInfo.servers.reduce( - (total, server) => total + (server.tools?.length || 0), - 0, + {/* MCP Tool Section - Expandable */} + {agent && ( +
+ + + {/* Expandable MCP Server List */} + + {showMcpList && ( + +
+ {/* Empty State */} + {allMcpServers.length === 0 && ( +
+
+ {t( + "app.toolbar.mcpNoServers", + "No MCP servers configured", + )} +
+ +
+ )} + + {/* Connected Servers */} + {connectedServers.length > 0 && ( +
+
+ {t("app.toolbar.mcpConnected", "Connected")} +
+
+ {connectedServers.map((server) => ( + + handleMcpServerToggle(server.id, false) + } + /> + ))} +
+
+ )} + + {/* Available Servers */} + {availableServers.length > 0 && ( +
+
+ {t("app.toolbar.mcpAvailable", "Available")} +
+
+ {availableServers.map((server) => ( + + handleMcpServerToggle(server.id, true) + } + /> + ))} +
+
+ )} +
+
)} -
+
)}
@@ -90,4 +237,64 @@ export function MobileMoreMenu({ ); } +/** + * Mobile MCP Server toggle item + */ +interface MobileMcpServerItemProps { + server: McpServer; + isConnected: boolean; + isUpdating: boolean; + onToggle: () => void; +} + +function MobileMcpServerItem({ + server, + isConnected, + isUpdating, + onToggle, +}: MobileMcpServerItemProps) { + const { t } = useTranslation(); + const isOnline = server.status === "online"; + const isDisabled = !isOnline || isUpdating; + + return ( + + ); +} + export default MobileMoreMenu; diff --git a/web/src/core/agent/toolConfig.ts b/web/src/core/agent/toolConfig.ts index 7df6e3a0..d36e1da9 100644 --- a/web/src/core/agent/toolConfig.ts +++ b/web/src/core/agent/toolConfig.ts @@ -4,7 +4,7 @@ * Tool filter semantics: * - null: All available tools are enabled * - []: No tools enabled - * - ["web_search"]: Only web_search enabled + * - ["web_search", "web_fetch"]: Only web search tools enabled (always bundled) */ import type { Agent } from "@/types/agents"; @@ -12,6 +12,7 @@ import type { Agent } from "@/types/agents"; // Available builtin tool IDs export const BUILTIN_TOOLS = { WEB_SEARCH: "web_search", + WEB_FETCH: "web_fetch", KNOWLEDGE_LIST: "knowledge_list", KNOWLEDGE_READ: "knowledge_read", KNOWLEDGE_WRITE: "knowledge_write", @@ -21,14 +22,11 @@ export const BUILTIN_TOOLS = { MEMORY_SEARCH: "memory_search", } as const; -// All builtin tool IDs as array -export const ALL_BUILTIN_TOOL_IDS = [ +// Web search tools as a group (search + fetch always together) +export const WEB_SEARCH_TOOLS = [ BUILTIN_TOOLS.WEB_SEARCH, - ...Object.values(BUILTIN_TOOLS).filter((id) => id.startsWith("knowledge_")), - BUILTIN_TOOLS.GENERATE_IMAGE, - BUILTIN_TOOLS.READ_IMAGE, - BUILTIN_TOOLS.MEMORY_SEARCH, -]; + BUILTIN_TOOLS.WEB_FETCH, +] as const; // Knowledge tools as a group export const KNOWLEDGE_TOOLS = [ @@ -38,6 +36,15 @@ export const KNOWLEDGE_TOOLS = [ BUILTIN_TOOLS.KNOWLEDGE_SEARCH, ] as const; +// All builtin tool IDs as array +export const ALL_BUILTIN_TOOL_IDS = [ + ...WEB_SEARCH_TOOLS, + ...KNOWLEDGE_TOOLS, + BUILTIN_TOOLS.GENERATE_IMAGE, + BUILTIN_TOOLS.READ_IMAGE, + BUILTIN_TOOLS.MEMORY_SEARCH, +]; + // Image tools as a group export const IMAGE_TOOLS = [ BUILTIN_TOOLS.GENERATE_IMAGE, @@ -74,10 +81,12 @@ export function isToolEnabled(agent: Agent | null, toolId: string): boolean { } /** - * Check if web search is enabled + * Check if web search tools are enabled (web_search + web_fetch) */ export function isWebSearchEnabled(agent: Agent | null): boolean { - return isToolEnabled(agent, BUILTIN_TOOLS.WEB_SEARCH); + const filter = getToolFilter(agent); + if (filter === null) return true; + return WEB_SEARCH_TOOLS.some((toolId) => filter.includes(toolId)); } /** @@ -185,13 +194,50 @@ export function updateKnowledgeEnabled( } /** - * Enable/disable web search + * Enable/disable all web search tools at once (web_search + web_fetch) */ export function updateWebSearchEnabled( agent: Agent, enabled: boolean, ): Record { - return updateToolFilter(agent, BUILTIN_TOOLS.WEB_SEARCH, enabled); + const currentConfig = (agent.graph_config ?? {}) as GraphConfig; + const currentFilter = currentConfig.tool_config?.tool_filter; + + let newFilter: string[] | null; + + if (currentFilter === null || currentFilter === undefined) { + // Currently all enabled + if (enabled) { + // Already enabled, keep null + newFilter = null; + } else { + // Disable web search: list all tools EXCEPT web search ones + newFilter = ALL_BUILTIN_TOOL_IDS.filter( + (id) => + !WEB_SEARCH_TOOLS.includes(id as (typeof WEB_SEARCH_TOOLS)[number]), + ); + } + } else { + // Working with explicit filter + if (enabled) { + const existing = new Set(currentFilter); + WEB_SEARCH_TOOLS.forEach((toolId) => existing.add(toolId)); + newFilter = Array.from(existing); + } else { + newFilter = currentFilter.filter( + (id) => + !WEB_SEARCH_TOOLS.includes(id as (typeof WEB_SEARCH_TOOLS)[number]), + ); + } + } + + return { + ...currentConfig, + tool_config: { + ...currentConfig.tool_config, + tool_filter: newFilter, + }, + }; } /** diff --git a/web/src/i18n/locales/en/app.json b/web/src/i18n/locales/en/app.json index 7435f459..beea8f01 100644 --- a/web/src/i18n/locales/en/app.json +++ b/web/src/i18n/locales/en/app.json @@ -33,6 +33,12 @@ "knowledgeConnect": "Connect Knowledge Base", "knowledgeDisconnect": "Disconnect", "mcpTools": "MCP Tools Connected", + "mcpConnected": "Connected", + "mcpAvailable": "Available", + "mcpToolsCount": "tools", + "mcpOffline": "offline", + "mcpNoServers": "No MCP servers configured", + "mcpOpenSettings": "Open Settings", "searchOff": "Off", "searchOffDesc": "Do not use search", "searchBuiltinDesc": "Use model's native search capability", diff --git a/web/src/i18n/locales/zh/app.json b/web/src/i18n/locales/zh/app.json index 2ec06249..9e689381 100644 --- a/web/src/i18n/locales/zh/app.json +++ b/web/src/i18n/locales/zh/app.json @@ -33,6 +33,12 @@ "knowledgeConnect": "连接知识库", "knowledgeDisconnect": "断开连接", "mcpTools": "MCP 工具已连接", + "mcpConnected": "已连接", + "mcpAvailable": "可用", + "mcpToolsCount": "个工具", + "mcpOffline": "离线", + "mcpNoServers": "未配置 MCP 服务器", + "mcpOpenSettings": "打开设置", "searchOff": "关闭", "searchOffDesc": "不使用搜索功能", "searchBuiltinDesc": "使用模型原生搜索能力", diff --git a/web/src/lib/Markdown.tsx b/web/src/lib/Markdown.tsx index 9f1c8f02..2aeea918 100644 --- a/web/src/lib/Markdown.tsx +++ b/web/src/lib/Markdown.tsx @@ -1,6 +1,7 @@ import { zIndexClasses } from "@/constants/zIndex"; import { Dialog, DialogPanel } from "@headlessui/react"; import { + ArrowDownTrayIcon, ArrowsPointingOutIcon, CheckIcon, ClipboardIcon, @@ -443,9 +444,30 @@ const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms)); const isXyzenDownloadUrl = (src: string) => src.includes("/xyzen/api/v1/files/") && src.includes("/download"); -const MarkdownImage: React.FC> = ( - props, -) => { +const getDownloadFilename = (alt?: string, src?: string): string => { + // Use alt text if available and sensible + if (alt && alt.length > 0 && alt.length < 100 && !alt.includes("/")) { + const sanitized = alt.replace(/[^a-zA-Z0-9-_. ]/g, "").trim(); + if (sanitized) { + return /\.(png|jpg|jpeg|gif|webp|svg)$/i.test(sanitized) + ? sanitized + : `${sanitized}.png`; + } + } + // Try to extract from URL + if (src) { + const urlFilename = src.split("/").pop()?.split("?")[0]; + if (urlFilename && /\.(png|jpg|jpeg|gif|webp|svg)$/i.test(urlFilename)) { + return urlFilename; + } + } + // Fallback + return `image-${Date.now()}.png`; +}; + +const MarkdownImageComponent: React.FC< + React.ImgHTMLAttributes +> = (props) => { const { src, alt, ...rest } = props; const backendUrl = useXyzen((state) => state.backendUrl); const token = useXyzen((state) => state.token); @@ -581,6 +603,7 @@ const MarkdownImage: React.FC> = ( {alt} @@ -607,18 +630,50 @@ const MarkdownImage: React.FC> = ( setIsLightboxOpen(false)} /> - setIsLightboxOpen(false)} - className="absolute top-6 right-6 z-10 rounded-full bg-white/10 p-3 text-white hover:bg-white/20 transition-colors backdrop-blur-sm border border-white/20" - aria-label="Close" - > - - + {/* Download and Close buttons */} +
+ { + e.stopPropagation(); + try { + // Fetch image as blob to handle cross-origin downloads + const response = await fetch(imageSrc); + const blob = await response.blob(); + const blobUrl = URL.createObjectURL(blob); + const link = document.createElement("a"); + link.href = blobUrl; + link.download = getDownloadFilename(alt, fullSrc); + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + URL.revokeObjectURL(blobUrl); + } catch (err) { + console.error("Failed to download image:", err); + } + }} + className="rounded-full bg-white/10 p-3 text-white hover:bg-white/20 transition-colors backdrop-blur-sm border border-white/20" + aria-label="Download" + > + + + setIsLightboxOpen(false)} + className="rounded-full bg-white/10 p-3 text-white hover:bg-white/20 transition-colors backdrop-blur-sm border border-white/20" + aria-label="Close" + > + + +
> = ( ); }; +// Memoize MarkdownImage to prevent re-renders during streaming +// Only re-render when src or alt changes +const MarkdownImage = React.memo( + MarkdownImageComponent, + (prevProps, nextProps) => + prevProps.src === nextProps.src && prevProps.alt === nextProps.alt, +); + // Helper component to catch Escape key for image lightbox function ImageLightboxEscapeCatcher({ onEscape }: { onEscape: () => void }) { useEffect(() => { @@ -761,9 +824,7 @@ const Markdown: React.FC = function Markdown(props) { ); }, - img(props: React.ComponentPropsWithoutRef<"img">) { - return ; - }, + img: MarkdownImage, }), [isDark], ); diff --git a/web/src/service/fileService.ts b/web/src/service/fileService.ts index 8fcb0333..296d9264 100644 --- a/web/src/service/fileService.ts +++ b/web/src/service/fileService.ts @@ -388,7 +388,8 @@ class FileService { } /** - * Generate thumbnail URL for preview + * Generate thumbnail URL for preview using Canvas API + * Resizes image to max 160px dimension and outputs as JPEG for small file size */ generateThumbnail(file: File): Promise { return new Promise((resolve, reject) => { @@ -397,16 +398,53 @@ class FileService { return; } - const reader = new FileReader(); - reader.onload = (e) => { - if (e.target?.result) { - resolve(e.target.result as string); + const MAX_SIZE = 160; + const objectUrl = URL.createObjectURL(file); + const img = new Image(); + + img.onload = () => { + URL.revokeObjectURL(objectUrl); + + // Calculate scaled dimensions maintaining aspect ratio + let width = img.width; + let height = img.height; + + if (width > height) { + if (width > MAX_SIZE) { + height = Math.round((height * MAX_SIZE) / width); + width = MAX_SIZE; + } } else { - reject(new Error("Failed to read file")); + if (height > MAX_SIZE) { + width = Math.round((width * MAX_SIZE) / height); + height = MAX_SIZE; + } + } + + // Create canvas and draw scaled image + const canvas = document.createElement("canvas"); + canvas.width = width; + canvas.height = height; + + const ctx = canvas.getContext("2d"); + if (!ctx) { + reject(new Error("Failed to get canvas context")); + return; } + + ctx.drawImage(img, 0, 0, width, height); + + // Export as JPEG with 0.8 quality for small file size + const thumbnailUrl = canvas.toDataURL("image/jpeg", 0.8); + resolve(thumbnailUrl); }; - reader.onerror = () => reject(new Error("Failed to read file")); - reader.readAsDataURL(file); + + img.onerror = () => { + URL.revokeObjectURL(objectUrl); + reject(new Error("Failed to load image")); + }; + + img.src = objectUrl; }); } } diff --git a/web/src/store/slices/chatSlice.ts b/web/src/store/slices/chatSlice.ts index e59b7df4..16acb8b8 100644 --- a/web/src/store/slices/chatSlice.ts +++ b/web/src/store/slices/chatSlice.ts @@ -29,6 +29,7 @@ import type { export interface ChatSlice { // Chat panel state activeChatChannel: string | null; + activeTopicByAgent: Record; // agentId -> topicId mapping chatHistory: ChatHistoryItem[]; chatHistoryLoading: boolean; channels: Record; @@ -117,6 +118,7 @@ export const createChatSlice: StateCreator< return { // Chat panel state activeChatChannel: null, + activeTopicByAgent: {}, chatHistory: [], chatHistoryLoading: true, channels: {}, @@ -240,6 +242,15 @@ export const createChatSlice: StateCreator< } set({ activeChatChannel: topicId }); + + // Track active topic per agent + const existingChannel = channels[topicId]; + if (existingChannel?.agentId) { + set((state: ChatSlice) => { + state.activeTopicByAgent[existingChannel.agentId!] = topicId; + }); + } + let channel = channels[topicId]; if (!channel) { @@ -398,31 +409,21 @@ export const createChatSlice: StateCreator< /** * Activate or create a chat channel for a specific agent. * This is used by the spatial workspace to open chat with an agent. - * - If a session exists for the agent, activates the most recent topic + * - If user previously had an active topic for this agent, restore it + * - If no previous topic, activates the most recent topic * - If no session exists, creates one with a default topic */ activateChannelForAgent: async (agentId: string) => { - const { channels, chatHistory, backendUrl } = get(); - - // First, check if we already have a channel for this agent - const existingChannel = Object.values(channels).find( - (ch) => ch.agentId === agentId, - ); - - if (existingChannel) { - // Already have a channel, activate it - await get().activateChannel(existingChannel.id); - return; - } + const { backendUrl, activeTopicByAgent, channels } = get(); - // Check chat history for existing topics with this agent - const existingHistory = chatHistory.find((h) => h.sessionId === agentId); - if (existingHistory) { - await get().activateChannel(existingHistory.id); + // First check if there's a previously active topic for this agent + const previousTopicId = activeTopicByAgent[agentId]; + if (previousTopicId && channels[previousTopicId]) { + await get().activateChannel(previousTopicId); return; } - // No existing channel, try to find or create a session for this agent + // No previous topic, fetch from backend to get the most recent topic const token = authService.getToken(); if (!token) { console.error("No authentication token available"); @@ -446,8 +447,8 @@ export const createChatSlice: StateCreator< // Get the most recent topic for this session, or create one if (session.topics && session.topics.length > 0) { - // Activate the most recent topic - const latestTopic = session.topics[session.topics.length - 1]; + // Activate the most recent topic (backend returns topics ordered by updated_at descending) + const latestTopic = session.topics[0]; // Create channel if doesn't exist const channel: ChatChannel = {