From 4400d91aea113758b285510eb29871d129c37a7d Mon Sep 17 00:00:00 2001 From: Meng Junxing Date: Thu, 22 Jan 2026 16:49:06 +0800 Subject: [PATCH 1/8] Feature/literature mcp (#192) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: literature-MCP 完整功能 * refactor: improve boolean parsing and logging in literature search functions * feat: enhance literature search functionality with improved query validation and detailed results formatting * refactor: rename oa_url to access_url in LiteratureWork model and related tests --- service/app/mcp/literature.py | 463 +++++++++++++ service/app/utils/literature/__init__.py | 17 + service/app/utils/literature/base_client.py | 32 + service/app/utils/literature/doi_cleaner.py | 116 ++++ service/app/utils/literature/models.py | 82 +++ .../app/utils/literature/openalex_client.py | 611 ++++++++++++++++++ .../app/utils/literature/work_distributor.py | 164 +++++ .../__init__.py | 0 .../unit/test_literature/test_base_client.py | 300 +++++++++ .../unit/test_literature/test_doi_cleaner.py | 403 ++++++++++++ .../test_literature/test_openalex_client.py | 461 +++++++++++++ .../test_literature/test_work_distributor.py | 428 ++++++++++++ .../unit/test_utils/test_built_in_tools.py | 227 ------- 13 files changed, 3077 insertions(+), 227 deletions(-) create mode 100644 service/app/mcp/literature.py create mode 100644 service/app/utils/literature/__init__.py create mode 100644 service/app/utils/literature/base_client.py create mode 100644 service/app/utils/literature/doi_cleaner.py create mode 100644 service/app/utils/literature/models.py create mode 100644 service/app/utils/literature/openalex_client.py create mode 100644 service/app/utils/literature/work_distributor.py rename service/tests/unit/{test_utils => test_literature}/__init__.py (100%) create mode 100644 service/tests/unit/test_literature/test_base_client.py create mode 100644 service/tests/unit/test_literature/test_doi_cleaner.py create mode 100644 service/tests/unit/test_literature/test_openalex_client.py create mode 100644 service/tests/unit/test_literature/test_work_distributor.py delete mode 100644 service/tests/unit/test_utils/test_built_in_tools.py diff --git a/service/app/mcp/literature.py b/service/app/mcp/literature.py new file mode 100644 index 00000000..5c5e55ca --- /dev/null +++ b/service/app/mcp/literature.py @@ -0,0 +1,463 @@ +""" +Literature MCP Server - Multi-source academic literature search + +Provides tools for searching academic literature from multiple data sources +(OpenAlex, Semantic Scholar, PubMed, etc.) with unified interface. +""" + +import json +import logging +from datetime import datetime +from typing import Any + +import httpx +from fastmcp import FastMCP + +from app.utils.literature import SearchRequest, WorkDistributor + +logger = logging.getLogger(__name__) + +TRUE_VALUES = frozenset({"true", "1", "yes"}) +FALSE_VALUES = frozenset({"false", "0", "no"}) + +# Create FastMCP instance +mcp = FastMCP("literature") + +# Metadata for MCP server +__mcp_metadata__ = { + "name": "Literature Search", + "description": "Search academic literature from multiple sources with advanced filtering", + "version": "1.0.0", +} + + +@mcp.tool() +async def search_literature( + query: str, + mailto: str | None = None, + author: str | None = None, + institution: str | None = None, + source: str | None = None, + year_from: str | None = None, + year_to: str | None = None, + is_oa: str | None = None, + work_type: str | None = None, + language: str | None = None, + is_retracted: str | None = None, + has_abstract: str | None = None, + has_fulltext: str | None = None, + sort_by: str = "relevance", + max_results: str | int = 50, + data_sources: list[str] | None = None, + include_abstract: str | bool = False, +) -> str: + """ + Search academic literature from multiple data sources (OpenAlex, Semantic Scholar, PubMed, etc.) + + 🔑 STRONGLY RECOMMENDED: Always provide a valid email address (mailto parameter) + ═════════════════════════════════════════════════════════════════════════════════ + + 📊 Performance Difference: + - WITH email (mailto): 10 requests/second (fast, ideal for large searches) + - WITHOUT email (mailto): 1 request/second (slow, sequential processing) + + ⚠️ Impact: Omitting email can cause 10x slowdown or timeouts for large result sets. + Production research should ALWAYS include email. Example: "researcher@university.edu" + + Response Format Overview + ════════════════════════ + The tool returns TWO sections automatically: + + 1️⃣ EXECUTIVE SUMMARY + - Key statistics (total found, unique count, sources) + - Average citations and open access rate + - Publication year range + - Warning/issue resolution status + + 2️⃣ DETAILED RESULTS (Complete JSON with URLs) + - Each paper includes: + • ✅ Valid URLs (access_url; doi is a raw identifier) + • Title, Authors (first 5), Publication Year + • Citation Count, Journal, Open Access Status + • Abstract (only if include_abstract=True) + - Format: JSON array for easy parsing/import + - All URLs are validated and functional + + Args: + query: Search keywords (e.g., "machine learning", "CRISPR", "cancer immunotherapy") + [REQUIRED] Most important parameter for accurate results + + mailto: Email address to enable fast API pool at OpenAlex + [⭐ STRONGLY RECOMMENDED - includes your email] + Examples: "researcher@mit.edu", "student@university.edu", "name@company.com" + Impact: 10x faster searches. Production users MUST provide this. + Note: Email is private, only used for API identification. + + author: OPTIONAL - Filter by author name (e.g., "Albert Einstein", "Jennifer Doudna") + Will auto-correct common misspellings if not found exactly + + institution: OPTIONAL - Filter by affiliation (e.g., "MIT", "Harvard", "Stanford University") + Partial name matching supported + + source: OPTIONAL - Filter by journal/venue (e.g., "Nature", "Science", "JAMA") + Matches both journal names and abbreviated titles + + year_from: OPTIONAL - Start year (e.g., "2020" or 2020) + Accepts string or integer, will auto-clamp to valid range (1700-2026) + + year_to: OPTIONAL - End year (e.g., "2024" or 2024) + Accepts string or integer, will auto-clamp to valid range (1700-2026) + If year_from > year_to, they will be automatically swapped + + is_oa: OPTIONAL - Open access filter ("true"/"false"/"yes"/"no") + "true" returns ONLY open access papers with direct links + + work_type: OPTIONAL - Filter by publication type + Options: "article", "review", "preprint", "book", "dissertation", "dataset", etc. + + language: OPTIONAL - Filter by publication language (e.g., "en", "zh", "ja", "fr", "de") + "en" = English only, "zh" = Chinese only, etc. + + is_retracted: OPTIONAL - Retracted paper filter ("true"/"false") + "false" excludes retracted works (recommended for research) + "true" shows ONLY retracted papers (for auditing) + + has_abstract: OPTIONAL - Require abstract ("true"/"false") + "true" returns only papers with abstracts + + has_fulltext: OPTIONAL - Require full text access ("true"/"false") + "true" returns only papers with available full text + + sort_by: Sort results - "relevance" (default), "cited_by_count", "publication_date" + "cited_by_count" useful for influential papers + "publication_date" shows most recent first + + max_results: Result limit (default: 50, range: 1-1000, accepts string or int) + More results = slower query. Recommended: 50-200 for research + + data_sources: Advanced - Sources to query (default: ["openalex"]) + Can include: ["openalex", "semantic_scholar", "pubmed"] + + include_abstract: Include full abstracts in JSON output? (default: False) + True = include full abstracts for detailed review + False = save token budget by excluding abstracts + + Returns: + Markdown report with two sections: + + 📋 Section 1: EXECUTIVE SUMMARY + └─ Search conditions recap + └─ Total results found & unique count + └─ Statistics: avg citations, OA rate, year range + └─ ⚠️ Any warnings/filter issues & resolutions + + 📊 Section 2: COMPLETE RESULTS (JSON Array) + └─ Each paper object contains: + • "doi": Raw DOI string (not a URL) + • "title": Paper title + • "authors": Author names [first 5 only to save tokens] + • "publication_year": Publication date + • "cited_by_count": Citation impact metric + • "journal": Journal/venue name + • "description": Short description about the paper + └─ access_url is validated and immediately accessible + └─ Copy JSON directly into spreadsheet, database, or reference manager + + Usage Tips (READ THIS!) + ══════════════════════ + ✅ DO: + - Always provide mailto (10x faster searches) + - Start simple: query + mailto first + - Review results before refining search + - Use filters incrementally to narrow down + - Set include_abstract=True only for final review (saves API calls) + + ❌ DON'T: + - Make multiple searches without reviewing first results + - Use vague keywords like "research" or "analysis" + - Search without mailto unless doing quick test + - Ignore the "Next Steps Guide" section + - Omit email for production/important research + """ + try: + # Validate query early to avoid accidental broad searches + if not query or not str(query).strip(): + return "❌ Invalid input: query cannot be empty." + if len(str(query).strip()) < 3: + return "❌ Invalid input: query is too short (minimum 3 characters)." + + # Convert string parameters to proper types + year_from_int = int(year_from) if year_from and str(year_from).strip() else None + year_to_int = int(year_to) if year_to and str(year_to).strip() else None + + # Clamp year ranges (warn but don't block search) + max_year = datetime.now().year + 1 + year_warning = "" + if year_from_int is not None and year_from_int > max_year: + year_warning += f"year_from {year_from_int}→{max_year}. " + year_from_int = max_year + if year_to_int is not None and year_to_int < 1700: + year_warning += f"year_to {year_to_int}→1700. " + year_to_int = 1700 + + # Ensure year_from <= year_to when both are set + if year_from_int is not None and year_to_int is not None and year_from_int > year_to_int: + year_warning += f"year_from {year_from_int} and year_to {year_to_int} swapped to maintain a valid range. " + year_from_int, year_to_int = year_to_int, year_from_int + + # Convert is_oa to boolean + bool_warning_parts: list[str] = [] + + def _parse_bool_field(raw: str | bool | None, field_name: str) -> bool | None: + if raw is None: + return None + if isinstance(raw, bool): + return raw + val = str(raw).strip().lower() + if val in TRUE_VALUES: + return True + if val in FALSE_VALUES: + return False + bool_warning_parts.append(f"{field_name}={raw!r} not recognized; ignoring this filter.") + return None + + # Convert bool-like fields + is_oa_bool = _parse_bool_field(is_oa, "is_oa") + is_retracted_bool = _parse_bool_field(is_retracted, "is_retracted") + has_abstract_bool = _parse_bool_field(has_abstract, "has_abstract") + has_fulltext_bool = _parse_bool_field(has_fulltext, "has_fulltext") + + # Convert max_results to int with early clamping + max_results_warning = "" + try: + max_results_int = int(max_results) if max_results else 50 + except (TypeError, ValueError): + max_results_warning = "⚠️ max_results is not a valid integer; using default 50. " + max_results_int = 50 + + if max_results_int < 1: + max_results_warning += f"max_results {max_results_int}→50 (minimum is 1). " + max_results_int = 50 + elif max_results_int > 1000: + max_results_warning += f"max_results {max_results_int}→1000 (maximum is 1000). " + max_results_int = 1000 + + # Convert include_abstract to bool + include_abstract_bool = str(include_abstract).lower() in {"true", "1", "yes"} if include_abstract else False + + openalex_email = mailto.strip() if mailto and str(mailto).strip() else None + + logger.info( + "Literature search requested: query=%r, mailto=%s, max_results=%d", + query, + "" if openalex_email else None, + max_results_int, + ) + + # Create search request with converted types + request = SearchRequest( + query=query, + author=author, + institution=institution, + source=source, + year_from=year_from_int, + year_to=year_to_int, + is_oa=is_oa_bool, + work_type=work_type, + language=language, + is_retracted=is_retracted_bool, + has_abstract=has_abstract_bool, + has_fulltext=has_fulltext_bool, + sort_by=sort_by, + max_results=max_results_int, + data_sources=data_sources, + ) + + # Execute search + async with WorkDistributor(openalex_email=openalex_email) as distributor: + result = await distributor.search(request) + + if year_warning: + result.setdefault("warnings", []).append(f"⚠️ Year adjusted: {year_warning.strip()}") + if bool_warning_parts: + result.setdefault("warnings", []).append("⚠️ Boolean filter issues: " + " ".join(bool_warning_parts)) + if max_results_warning: + result.setdefault("warnings", []).append(max_results_warning.strip()) + + # Format output + return _format_search_result(request, result, include_abstract_bool) + + except ValueError as e: + logger.warning(f"Literature search validation error: {e}") + return f"❌ Invalid input: {str(e)}" + except httpx.HTTPError as e: + logger.error(f"Literature search network error: {e}", exc_info=True) + return "❌ Network error while contacting literature sources. Please try again later." + except Exception as e: + logger.error(f"Literature search failed: {e}", exc_info=True) + return "❌ Unexpected error during search. Please retry or contact support." + + +def _format_search_result(request: SearchRequest, result: dict[str, Any], include_abstract: bool = False) -> str: + """ + Format search results into human-readable report + JSON data + + Args: + request: Original search request + result: Search result from WorkDistributor + include_abstract: Whether to include abstracts in JSON (default: False to save tokens) + + Returns: + Formatted markdown report with embedded JSON + """ + works = result["works"] + + # Build report sections + sections: list[str] = ["# Literature Search Report\n"] + + # Warnings and resolution status (if any) + if warnings := result.get("warnings", []): + sections.extend(["## ⚠️ Warnings and Resolution Status\n", *warnings, ""]) + + # Search conditions + conditions: list[str] = [ + f"- **Query**: {request.query}", + *([f"- **Author**: {request.author}"] if request.author else []), + *([f"- **Institution**: {request.institution}"] if request.institution else []), + *([f"- **Source**: {request.source}"] if request.source else []), + *( + [f"- **Year Range**: {request.year_from or '...'} - {request.year_to or '...'}"] + if request.year_from or request.year_to + else [] + ), + *([f"- **Open Access Only**: {'Yes' if request.is_oa else 'No'}"] if request.is_oa is not None else []), + *([f"- **Work Type**: {request.work_type}"] if request.work_type else []), + *([f"- **Language**: {request.language}"] if request.language else []), + *( + [f"- **Exclude Retracted**: {'No' if request.is_retracted else 'Yes'}"] + if request.is_retracted is not None + else [] + ), + *( + [f"- **Require Abstract**: {'Yes' if request.has_abstract else 'No'}"] + if request.has_abstract is not None + else [] + ), + *( + [f"- **Require Full Text**: {'Yes' if request.has_fulltext else 'No'}"] + if request.has_fulltext is not None + else [] + ), + f"- **Sort By**: {request.sort_by}", + f"- **Max Results**: {request.max_results}", + ] + sections.extend(["## Search Conditions\n", "\n".join(conditions), ""]) + + # Check if no results + if not works: + sections.extend(["## ❌ No Results Found\n", "**Suggestions to improve your search:**\n"]) + suggestions: list[str] = [ + "1. **Simplify keywords**: Try broader or different terms", + *(["2. **Remove author filter**: Author name may not be recognized"] if request.author else []), + *(["3. **Remove institution filter**: Try without institution constraint"] if request.institution else []), + *(["4. **Remove source filter**: Try without journal constraint"] if request.source else []), + *( + ["5. **Expand year range**: Current range may be too narrow"] + if request.year_from or request.year_to + else [] + ), + *(["6. **Remove open access filter**: Include non-OA papers"] if request.is_oa else []), + "7. **Check spelling**: Verify all terms are spelled correctly", + ] + sections.extend(["\n".join(suggestions), ""]) + return "\n".join(sections) + + # Statistics and overall insights + total_count = result["total_count"] + unique_count = result["unique_count"] + sources = result["sources"] + + stats: list[str] = [ + f"- **Total Found**: {total_count} works", + f"- **After Deduplication**: {unique_count} works", + ] + source_info = ", ".join(f"{name}: {count}" for name, count in sources.items()) + stats.append(f"- **Data Sources**: {source_info}") + + # Add insights + avg_citations = sum(w.cited_by_count for w in works) / len(works) + stats.append(f"- **Average Citations**: {avg_citations:.1f}") + + oa_count = sum(w.is_oa for w in works) + oa_ratio = (oa_count / len(works)) * 100 + stats.append(f"- **Open Access Rate**: {oa_ratio:.1f}% ({oa_count}/{len(works)})") + + if years := [w.publication_year for w in works if w.publication_year]: + stats.append(f"- **Year Range**: {min(years)} - {max(years)}") + + sections.extend(["## Search Statistics\n", "\n".join(stats), ""]) + + # Complete JSON list + sections.extend( + [ + "## Complete Works List (JSON)\n", + "The following JSON contains all works with full abstracts:\n" + if include_abstract + else "The following JSON contains all works (abstracts excluded to save tokens):\n", + "```json", + ] + ) + + # Convert works to dict for JSON serialization + works_dict = [] + for work in works: + work_data = { + "id": work.id, + "doi": work.doi, + "title": work.title, + "authors": work.authors[:5], # Limit to first 5 authors + "publication_year": work.publication_year, + "cited_by_count": work.cited_by_count, + "journal": work.journal, + "primary_institution": work.primary_institution, + "is_oa": work.is_oa, + "access_url": work.access_url, + "source": work.source, + } + # Only include abstract if requested + if include_abstract and work.abstract: + work_data["abstract"] = work.abstract + works_dict.append(work_data) + + sections.extend([json.dumps(works_dict, indent=2, ensure_ascii=False), "```", ""]) + + # Next steps guidance - prevent infinite loops + sections.extend(["---", "## 🎯 Next Steps Guide\n", "**Before making another search, consider:**\n"]) + next_steps: list[str] = [ + *(["✓ **Results found** - Review the JSON data above for your analysis"] if unique_count > 0 else []), + *( + [ + f"⚠️ **Result limit reached** ({request.max_results}) - " + "Consider narrowing filters (author, year, journal) for more targeted results" + ] + if unique_count >= request.max_results + else [] + ), + *( + ["💡 **Few results** - Consider broadening your search by removing some filters"] + if 0 < unique_count < 10 + else [] + ), + "", + "**To refine your search:**", + "- If too many results → Add more specific filters (author, institution, journal, year)", + "- If too few results → Remove filters or use broader keywords", + "- If wrong results → Check filter spelling and try variations", + "", + "⚠️ **Important**: Avoid making multiple similar searches without reviewing results first!", + "Each search consumes API quota and context window. Make targeted, deliberate queries.", + ] + + sections.append("\n".join(next_steps)) + + return "\n".join(sections) diff --git a/service/app/utils/literature/__init__.py b/service/app/utils/literature/__init__.py new file mode 100644 index 00000000..c4dd14ba --- /dev/null +++ b/service/app/utils/literature/__init__.py @@ -0,0 +1,17 @@ +""" +Literature search utilities for multi-source academic literature retrieval +""" + +from .base_client import BaseLiteratureClient +from .doi_cleaner import deduplicate_by_doi, normalize_doi +from .models import LiteratureWork, SearchRequest +from .work_distributor import WorkDistributor + +__all__ = [ + "BaseLiteratureClient", + "normalize_doi", + "deduplicate_by_doi", + "SearchRequest", + "LiteratureWork", + "WorkDistributor", +] diff --git a/service/app/utils/literature/base_client.py b/service/app/utils/literature/base_client.py new file mode 100644 index 00000000..ba8a3db6 --- /dev/null +++ b/service/app/utils/literature/base_client.py @@ -0,0 +1,32 @@ +""" +Abstract base class for literature data source clients +""" + +from abc import ABC, abstractmethod + +from .models import LiteratureWork, SearchRequest + + +class BaseLiteratureClient(ABC): + """ + Base class for literature data source clients + + All data source implementations (OpenAlex, Semantic Scholar, PubMed, etc.) + should inherit from this class and implement the required methods. + """ + + @abstractmethod + async def search(self, request: SearchRequest) -> tuple[list[LiteratureWork], list[str]]: + """ + Execute search and return results in standard format + + Args: + request: Standardized search request + + Returns: + Tuple of (works, warnings) where warnings is a list of messages for LLM feedback + + Raises: + Exception: If search fails after retries + """ + pass diff --git a/service/app/utils/literature/doi_cleaner.py b/service/app/utils/literature/doi_cleaner.py new file mode 100644 index 00000000..0af35d6d --- /dev/null +++ b/service/app/utils/literature/doi_cleaner.py @@ -0,0 +1,116 @@ +""" +DOI normalization and deduplication utilities +""" + +import re +from typing import Protocol, TypeVar + + +class WorkWithDOI(Protocol): + """Protocol for objects with DOI and citation information""" + + doi: str | None + cited_by_count: int + publication_year: int | None + + +T = TypeVar("T", bound=WorkWithDOI) + + +def normalize_doi(doi: str | None) -> str | None: + """ + Normalize DOI format to standard form + + Removes common prefixes, validates format, and converts to lowercase. + DOI specification (ISO 26324) defines DOI matching as case-insensitive, + so lowercase conversion is safe and improves consistency. + + Args: + doi: DOI string in any common format + + Returns: + Normalized DOI (e.g., "10.1038/nature12345") or None if invalid + + Examples: + >>> normalize_doi("https://doi.org/10.1038/nature12345") + "10.1038/nature12345" + >>> normalize_doi("DOI: 10.1038/nature12345") + "10.1038/nature12345" + >>> normalize_doi("doi:10.1038/nature12345") + "10.1038/nature12345" + """ + if not doi: + return None + + doi = doi.strip().lower() + + # Remove common prefixes + doi = re.sub(r"^(https?://)?(dx\.)?doi\.org/", "", doi) + doi = re.sub(r"^doi:\s*", "", doi) + + # Validate format (10.xxxx/yyyy) + return doi if re.match(r"^10\.\d+/.+", doi) else None + + +def deduplicate_by_doi(works: list[T]) -> list[T]: + """ + Deduplicate works by DOI, keeping the highest priority version + + Priority rules: + 1. Works with DOI take priority over those without + 2. For same DOI, keep the one with higher citation count + 3. If citation count is equal, keep the most recently published + + Args: + works: List of LiteratureWork objects + + Returns: + Deduplicated list of works + + Examples: + >>> works = [ + ... LiteratureWork(doi="10.1038/1", cited_by_count=100, ...), + ... LiteratureWork(doi="10.1038/1", cited_by_count=50, ...), + ... LiteratureWork(doi=None, ...), + ... ] + >>> unique = deduplicate_by_doi(works) + >>> len(unique) + 2 + >>> unique[0].cited_by_count + 100 + """ + # Group by: with DOI vs without DOI + with_doi: dict[str, T] = {} + without_doi: list[T] = [] + + for work in works: + # Check if work has doi attribute + if not work.doi: + without_doi.append(work) + continue + + doi = normalize_doi(work.doi) + if not doi: + without_doi.append(work) + continue + + # If DOI already exists, compare priority + if doi in with_doi: + existing = with_doi[doi] + + # Higher citation count? + if work.cited_by_count > existing.cited_by_count: + with_doi[doi] = work + # Same citation count, more recent publication? + elif ( + work.cited_by_count == existing.cited_by_count + and work.publication_year + and existing.publication_year + and work.publication_year > existing.publication_year + ): + with_doi[doi] = work + else: + with_doi[doi] = work + + # Combine results: DOI works first, then non-DOI works + return list(with_doi.values()) + without_doi diff --git a/service/app/utils/literature/models.py b/service/app/utils/literature/models.py new file mode 100644 index 00000000..48aca79c --- /dev/null +++ b/service/app/utils/literature/models.py @@ -0,0 +1,82 @@ +""" +Shared data models for literature utilities +""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class SearchRequest: + """ + Standardized search request format for all data sources + + Attributes: + query: Search keywords (searches title, abstract, full text) + author: Author name (will be converted to author ID) + institution: Institution name (will be converted to institution ID) + source: Journal or conference name + year_from: Start year (inclusive) + year_to: End year (inclusive) + is_oa: Filter for open access only + work_type: Work type filter ("article", "review", "preprint", etc.) + language: Language code filter (e.g., "en", "zh", "fr") + is_retracted: Filter for retracted works (True to include only retracted, False to exclude) + has_abstract: Filter for works with abstracts + has_fulltext: Filter for works with full text available + sort_by: Sort method - "relevance", "cited_by_count", "publication_date" + max_results: Maximum number of results to return + data_sources: List of data sources to query (default: ["openalex"]) + """ + + query: str + author: str | None = None + institution: str | None = None + source: str | None = None + year_from: int | None = None + year_to: int | None = None + is_oa: bool | None = None + work_type: str | None = None + language: str | None = None + is_retracted: bool | None = None + has_abstract: bool | None = None + has_fulltext: bool | None = None + sort_by: str = "relevance" + max_results: int = 50 + data_sources: list[str] | None = None + + +@dataclass +class LiteratureWork: + """ + Standardized literature work format across all data sources + + Attributes: + id: Internal ID from the data source + doi: Digital Object Identifier (normalized format) + title: Work title + authors: List of author information [{"name": "...", "id": "..."}] + publication_year: Year of publication + cited_by_count: Number of citations + abstract: Abstract text + journal: Journal or venue name + is_oa: Whether open access + access_url: Best available access link (OA, landing page, or DOI) + primary_institution: First affiliated institution (if available) + source: Data source name ("openalex", "semantic_scholar", etc.) + raw_data: Original data from the source (for debugging) + """ + + id: str + doi: str | None + title: str + authors: list[dict[str, str | None]] + publication_year: int | None + cited_by_count: int + abstract: str | None + journal: str | None + is_oa: bool + source: str + access_url: str | None = None + primary_institution: str | None = None + raw_data: dict[str, Any] = field(default_factory=dict) diff --git a/service/app/utils/literature/openalex_client.py b/service/app/utils/literature/openalex_client.py new file mode 100644 index 00000000..0b08d01a --- /dev/null +++ b/service/app/utils/literature/openalex_client.py @@ -0,0 +1,611 @@ +""" +OpenAlex API client for literature search + +Implements the best practices from OpenAlex API guide: +- Two-step lookup for names (author/institution/source -> ID -> filter) +- Rate limiting with mailto parameter (10 req/s) +- Exponential backoff retry for errors +- Batch queries with pipe separator (up to 50 IDs) +- Maximum page size (200 per page) +- Abstract reconstruction from inverted index +""" + +import asyncio +import logging +import random +from typing import Any + +import httpx + +from .base_client import BaseLiteratureClient +from .doi_cleaner import normalize_doi +from .models import LiteratureWork, SearchRequest + +logger = logging.getLogger(__name__) + + +class _RateLimiter: + """ + Simple global rate limiter with optional concurrency guard. + + Enforces a minimum interval between request starts across all callers. + """ + + def __init__(self, rate_per_second: float, max_concurrency: int) -> None: + self._min_interval = 1.0 / rate_per_second if rate_per_second > 0 else 0.0 + self._lock = asyncio.Lock() + self._last_request = 0.0 + self._semaphore = asyncio.Semaphore(max_concurrency) + + async def __aenter__(self) -> None: + await self._semaphore.acquire() + await self._throttle() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> None: + self._semaphore.release() + + async def _throttle(self) -> None: + if self._min_interval <= 0: + return + + async with self._lock: + now = asyncio.get_running_loop().time() + wait_time = self._last_request + self._min_interval - now + if wait_time > 0: + await asyncio.sleep(wait_time) + self._last_request = asyncio.get_running_loop().time() + + +class OpenAlexClient(BaseLiteratureClient): + """ + OpenAlex API client + + Implements best practices from official API guide for LLMs: + https://docs.openalex.org/api-guide-for-llms + """ + + BASE_URL = "https://api.openalex.org" + MAX_PER_PAGE = 200 + MAX_RETRIES = 5 + TIMEOUT = 30.0 + + def __init__(self, email: str | None, rate_limit: int | None = None, timeout: float = 30.0) -> None: + """ + Initialize OpenAlex client + + Args: + email: Email for polite pool (10x rate limit increase). If None, use default pool. + rate_limit: Requests per second (default: 10 with email, 1 without email) + timeout: Request timeout in seconds (default: 30.0) + """ + self.email = email + self.rate_limit = rate_limit or (10 if self.email else 1) + max_concurrency = 10 if self.email else 1 + self.rate_limiter = _RateLimiter(rate_per_second=self.rate_limit, max_concurrency=max_concurrency) + self.client = httpx.AsyncClient(timeout=timeout) + pool_type = "polite" if self.email else "default" + logger.info( + "OpenAlex client initialized with pool=%s, email=%s, rate_limit=%s/s", + pool_type, + "" if self.email else None, + self.rate_limit, + ) + + @property + def pool_type(self) -> str: + """Return pool type string.""" + return "polite" if self.email else "default" + + async def search(self, request: SearchRequest) -> tuple[list[LiteratureWork], list[str]]: + """ + Execute search and return results in standard format + + Implementation steps: + 1. Convert author name -> author ID (if specified) + 2. Convert institution name -> institution ID (if specified) + 3. Convert journal name -> source ID (if specified) + 4. Build filter query + 5. Paginate through results + 6. Transform to standard format + + Args: + request: Standardized search request + + Returns: + Tuple of (works, warnings) + - works: List of literature works in standard format + - warnings: List of warning/info messages for LLM feedback + """ + logger.info( + "OpenAlex search [%s @ %s/s]: query=%r, max_results=%d", + self.pool_type, + self.rate_limit, + request.query, + request.max_results, + ) + + warnings: list[str] = [] + + # Step 1-3: Resolve IDs for names (two-step lookup pattern) + author_id = None + if request.author: + author_id, _success, msg = await self._resolve_author_id(request.author) + warnings.append(msg) + + institution_id = None + if request.institution: + institution_id, _success, msg = await self._resolve_institution_id(request.institution) + warnings.append(msg) + + source_id = None + if request.source: + source_id, _success, msg = await self._resolve_source_id(request.source) + warnings.append(msg) + + # Step 4: Build query parameters + params = self._build_query_params(request, author_id, institution_id, source_id) + + # Step 5: Fetch all pages + works = await self._fetch_all_pages(params, request.max_results) + + # Step 6: Transform to standard format + return [self._transform_work(w) for w in works], warnings + + def _build_query_params( + self, + request: SearchRequest, + author_id: str | None, + institution_id: str | None, + source_id: str | None, + ) -> dict[str, str]: + """ + Build OpenAlex query parameters + + Args: + request: Search request + author_id: Resolved author ID (if any) + institution_id: Resolved institution ID (if any) + source_id: Resolved source ID (if any) + + Returns: + Dictionary of query parameters + """ + params: dict[str, str] = { + "per-page": str(self.MAX_PER_PAGE), + } + + if self.email: + params["mailto"] = self.email + + # Search keywords + if request.query: + params["search"] = request.query + + # Build filters + filters: list[str] = [] + + if author_id: + filters.append(f"authorships.author.id:{author_id}") + + if institution_id: + filters.append(f"authorships.institutions.id:{institution_id}") + + if source_id: + filters.append(f"primary_location.source.id:{source_id}") + + # Year range + if request.year_from and request.year_to: + filters.append(f"publication_year:{request.year_from}-{request.year_to}") + elif request.year_from: + filters.append(f"publication_year:>{request.year_from - 1}") + elif request.year_to: + filters.append(f"publication_year:<{request.year_to + 1}") + + # Open access filter + if request.is_oa is not None: + filters.append(f"is_oa:{str(request.is_oa).lower()}") + + # Work type filter + if request.work_type: + filters.append(f"type:{request.work_type}") + + # Language filter + if request.language: + filters.append(f"language:{request.language}") + + # Retracted filter + if request.is_retracted is not None: + filters.append(f"is_retracted:{str(request.is_retracted).lower()}") + + # Abstract filter + if request.has_abstract is not None: + filters.append(f"has_abstract:{str(request.has_abstract).lower()}") + + # Fulltext filter + if request.has_fulltext is not None: + filters.append(f"has_fulltext:{str(request.has_fulltext).lower()}") + + if filters: + params["filter"] = ",".join(filters) + + # Sorting + sort_map = { + "relevance": None, # Default sorting by relevance + "cited_by_count": "cited_by_count:desc", + "publication_date": "publication_date:desc", + } + if sort := sort_map.get(request.sort_by): + params["sort"] = sort + + return params + + async def _resolve_author_id(self, author_name: str) -> tuple[str | None, bool, str]: + """ + Two-step lookup: author name -> author ID + + Args: + author_name: Author name to search + + Returns: + Tuple of (author_id, success, message) + - author_id: Author ID (e.g., "A5023888391") or None if not found + - success: Whether resolution was successful + - message: Status message for LLM feedback + """ + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/authors" + params: dict[str, str] = {"search": author_name} + if self.email: + params["mailto"] = self.email + response = await self._request_with_retry(url, params) + + if results := response.get("results", []): + # Return first result's ID in short format + author_id = results[0]["id"].split("/")[-1] + author_display = results[0].get("display_name", author_name) + logger.info("Resolved author %r -> %s", author_name, author_id) + return author_id, True, f"✓ Author resolved: '{author_name}' -> '{author_display}'" + else: + msg = ( + f"⚠️ Author '{author_name}' not found. " + f"Suggestions: (1) Try full name format like 'Smith, John' or 'John Smith', " + f"(2) Check spelling, (3) Try removing middle name/initial." + ) + logger.warning(msg) + return None, False, msg + except Exception as e: + msg = f"⚠️ Failed to resolve author '{author_name}': {e}" + logger.warning(msg) + return None, False, msg + + async def _resolve_institution_id(self, institution_name: str) -> tuple[str | None, bool, str]: + """ + Two-step lookup: institution name -> institution ID + + Args: + institution_name: Institution name to search + + Returns: + Tuple of (institution_id, success, message) + - institution_id: Institution ID (e.g., "I136199984") or None if not found + - success: Whether resolution was successful + - message: Status message for LLM feedback + """ + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/institutions" + params: dict[str, str] = {"search": institution_name} + if self.email: + params["mailto"] = self.email + response = await self._request_with_retry(url, params) + + if results := response.get("results", []): + institution_id = results[0]["id"].split("/")[-1] + inst_display = results[0].get("display_name", institution_name) + logger.info("Resolved institution %r -> %s", institution_name, institution_id) + return institution_id, True, f"✓ Institution resolved: '{institution_name}' -> '{inst_display}'" + else: + msg = ( + f"⚠️ Institution '{institution_name}' not found. " + f"Suggestions: (1) Use full official name (e.g., 'Harvard University' not 'Harvard'), " + f"(2) Try variations (e.g., 'MIT' vs 'Massachusetts Institute of Technology'), " + f"(3) Check spelling." + ) + logger.warning(msg) + return None, False, msg + except Exception as e: + msg = f"⚠️ Failed to resolve institution '{institution_name}': {e}" + logger.warning(msg) + return None, False, msg + + async def _resolve_source_id(self, source_name: str) -> tuple[str | None, bool, str]: + """ + Two-step lookup: source name -> source ID + + Args: + source_name: Journal/conference name to search + + Returns: + Tuple of (source_id, success, message) + - source_id: Source ID (e.g., "S137773608") or None if not found + - success: Whether resolution was successful + - message: Status message for LLM feedback + """ + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/sources" + params: dict[str, str] = {"search": source_name} + if self.email: + params["mailto"] = self.email + response = await self._request_with_retry(url, params) + + if results := response.get("results", []): + source_id = results[0]["id"].split("/")[-1] + source_display = results[0].get("display_name", source_name) + logger.info("Resolved source %r -> %s", source_name, source_id) + return source_id, True, f"✓ Source resolved: '{source_name}' -> '{source_display}'" + else: + msg = ( + f"⚠️ Source/Journal '{source_name}' not found. " + f"Suggestions: (1) Use full journal name (e.g., 'Nature' or 'Science'), " + f"(2) Try alternative names (e.g., 'JAMA' vs 'Journal of the American Medical Association'), " + f"(3) Check spelling." + ) + logger.warning(msg) + return None, False, msg + except Exception as e: + msg = f"⚠️ Failed to resolve source '{source_name}': {e}" + logger.warning(msg) + return None, False, msg + + async def _fetch_all_pages(self, params: dict[str, str], max_results: int) -> list[dict[str, Any]]: + """ + Paginate through all results up to max_results + + Args: + params: Base query parameters + max_results: Maximum number of results to fetch + + Returns: + List of work objects from API + """ + all_works: list[dict[str, Any]] = [] + page = 1 + + while len(all_works) < max_results: + async with self.rate_limiter: + try: + url = f"{self.BASE_URL}/works" + page_params = {**params, "page": str(page)} + response = await self._request_with_retry(url, page_params) + + works = response.get("results", []) + if not works: + break + + all_works.extend(works) + logger.info("Fetched page %d: %d works", page, len(works)) + + # Check if there are more pages + meta = response.get("meta", {}) + total_count = meta.get("count", 0) + if len(all_works) >= total_count: + break + + page += 1 + + except Exception as e: + logger.error(f"Error fetching page {page}: {e}") + break + + return all_works[:max_results] + + async def _request_with_retry(self, url: str, params: dict[str, str]) -> dict[str, Any]: + """ + HTTP request with exponential backoff retry + + Implements best practices: + - Retry on 403 (rate limit) with exponential backoff + - Retry on 5xx (server error) with exponential backoff + - Don't retry on 4xx (except 403) + - Retry on timeout + + Args: + url: Request URL + params: Query parameters + + Returns: + JSON response + + Raises: + Exception: If all retries fail + """ + for attempt in range(self.MAX_RETRIES): + try: + response = await self.client.get(url, params=params) + + if response.status_code == 200: + return response.json() + elif response.status_code == 429: + retry_after = self._parse_retry_after(response.headers.get("Retry-After")) + wait_time = retry_after if retry_after is not None else 2**attempt + wait_time = self._apply_jitter(wait_time) + logger.warning( + "Rate limited (429), waiting %.2fs... (attempt %d)", + wait_time, + attempt + 1, + ) + await asyncio.sleep(wait_time) + elif response.status_code == 403: + # Rate limited + wait_time = self._apply_jitter(2**attempt) + logger.warning( + "Rate limited (403), waiting %.2fs... (attempt %d)", + wait_time, + attempt + 1, + ) + await asyncio.sleep(wait_time) + elif response.status_code >= 500: + # Server error + wait_time = self._apply_jitter(2**attempt) + logger.warning( + "Server error (%d), waiting %.2fs... (attempt %d)", + response.status_code, + wait_time, + attempt + 1, + ) + await asyncio.sleep(wait_time) + else: + # Other error, don't retry + response.raise_for_status() + + except httpx.TimeoutException: + if attempt >= self.MAX_RETRIES - 1: + raise + wait_time = self._apply_jitter(2**attempt) + logger.warning("Timeout, retrying in %.2fs... (attempt %d)", wait_time, attempt + 1) + await asyncio.sleep(wait_time) + except Exception as e: + logger.error(f"Request failed: {e}") + if attempt >= self.MAX_RETRIES - 1: + raise + wait_time = self._apply_jitter(2**attempt) + await asyncio.sleep(wait_time) + + raise Exception(f"Failed after {self.MAX_RETRIES} retries") + + @staticmethod + def _apply_jitter(wait_time: float) -> float: + return wait_time + random.uniform(0.1, 0.9) + + @staticmethod + def _parse_retry_after(retry_after: str | None) -> float | None: + if not retry_after: + return None + try: + return float(retry_after) + except ValueError: + return None + + def _transform_work(self, work: dict[str, Any]) -> LiteratureWork: + """ + Transform OpenAlex work data to standard format + + Args: + work: Raw work object from OpenAlex API + + Returns: + Standardized LiteratureWork object + """ + # Extract authors + authors: list[dict[str, str | None]] = [] + for authorship in work.get("authorships", []): + author = authorship.get("author", {}) + authors.append( + { + "name": author.get("display_name", "Unknown"), + "id": author.get("id", "").split("/")[-1] if author.get("id") else None, + } + ) + + # Extract journal/source + journal = None + primary_location = work.get("primary_location") or {} + if source := primary_location.get("source"): + journal = source.get("display_name") + + # Extract open access info + oa_info = work.get("open_access", {}) + is_oa = oa_info.get("is_oa", False) + oa_url = oa_info.get("oa_url") + + # Extract abstract (reconstruct from inverted index) + abstract = self._reconstruct_abstract(work.get("abstract_inverted_index")) + + # Extract DOI (remove prefix) + doi = None + if doi_raw := work.get("doi"): + doi = normalize_doi(doi_raw) + + # Extract primary institution (first available) + primary_institution = None + for authorship in work.get("authorships", []): + institutions = authorship.get("institutions", []) + if institutions: + primary_institution = institutions[0].get("display_name") + if primary_institution: + break + + # Build best access URL (OA first, then landing page, then DOI) + access_url = oa_url + if not access_url: + access_url = primary_location.get("landing_page_url") or primary_location.get("pdf_url") + if not access_url and doi: + access_url = f"https://doi.org/{doi}" + + return LiteratureWork( + id=work["id"].split("/")[-1], + doi=doi, + title=work.get("title", "Untitled"), + authors=authors, + publication_year=work.get("publication_year"), + cited_by_count=work.get("cited_by_count", 0), + abstract=abstract, + journal=journal, + is_oa=is_oa, + source="openalex", + access_url=access_url, + primary_institution=primary_institution, + raw_data=work, + ) + + def _reconstruct_abstract(self, inverted_index: dict[str, list[int]] | None) -> str | None: + """ + Reconstruct abstract from inverted index + + OpenAlex stores abstracts as inverted index for efficiency. + Format: {"word": [position1, position2, ...], ...} + + Args: + inverted_index: Inverted index from OpenAlex + + Returns: + Reconstructed abstract text or None + + Examples: + >>> index = {"Hello": [0], "world": [1], "!": [2]} + >>> _reconstruct_abstract(index) + "Hello world !" + """ + if not inverted_index: + return None + + # Expand inverted index to (position, word) pairs + word_positions: list[tuple[int, str]] = [ + (pos, word) for word, positions in inverted_index.items() for pos in positions + ] + + # Sort by position and join + word_positions.sort() + return " ".join(word for _, word in word_positions) + + async def close(self) -> None: + """Close the HTTP client""" + await self.client.aclose() + + async def __aenter__(self) -> "OpenAlexClient": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> None: + await self.close() diff --git a/service/app/utils/literature/work_distributor.py b/service/app/utils/literature/work_distributor.py new file mode 100644 index 00000000..8cdaac59 --- /dev/null +++ b/service/app/utils/literature/work_distributor.py @@ -0,0 +1,164 @@ +""" +Work distributor for coordinating multiple literature data sources +""" + +import inspect +import logging +from typing import Any + +from .doi_cleaner import deduplicate_by_doi +from .models import LiteratureWork, SearchRequest + +logger = logging.getLogger(__name__) + + +class WorkDistributor: + """ + Distribute search requests to multiple literature data sources + and aggregate results + """ + + def __init__(self, openalex_email: str | None = None) -> None: + """ + Initialize distributor with available clients + + Args: + openalex_email: Email for OpenAlex polite pool (required for OpenAlex) + """ + self.clients: dict[str, Any] = {} + self.openalex_email = openalex_email + self._register_clients() + + def _register_clients(self) -> None: + """Register available data source clients""" + # Import here to avoid circular dependencies + try: + from .openalex_client import OpenAlexClient + + self.clients["openalex"] = OpenAlexClient(email=self.openalex_email) + logger.info("Registered OpenAlex client") + except ImportError as e: + logger.warning(f"Failed to register OpenAlex client: {e}") + + # Future: Add more clients + # from .semantic_scholar_client import SemanticScholarClient + # self.clients["semantic_scholar"] = SemanticScholarClient() + + async def close(self) -> None: + """Close any underlying HTTP clients""" + for client in self.clients.values(): + close_method = getattr(client, "close", None) + if callable(close_method): + result = close_method() + if inspect.isawaitable(result): + await result + + async def __aenter__(self) -> "WorkDistributor": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> None: + await self.close() + + async def search(self, request: SearchRequest) -> dict[str, Any]: + """ + Execute search across multiple data sources and aggregate results + + Args: + request: Standardized search request + + Returns: + Dictionary containing: + - total_count: Total number of works fetched (before dedup) + - unique_count: Number of unique works (after dedup) + - sources: Dict of source name -> count + - works: List of deduplicated LiteratureWork objects + - warnings: List of warning/info messages for LLM feedback + + Examples: + >>> distributor = WorkDistributor() + >>> request = SearchRequest(query="machine learning", max_results=50) + >>> result = await distributor.search(request) + >>> print(f"Found {result['unique_count']} unique works") + """ + # Clamp max_results to 50/1000 with warnings + all_warnings: list[str] = [] + if request.max_results < 1: + all_warnings.append("⚠️ max_results < 1; using default 50") + request.max_results = 50 + elif request.max_results > 1000: + all_warnings.append("⚠️ max_results > 1000; using 1000") + request.max_results = 1000 + + # Determine which data sources to use + sources = request.data_sources or ["openalex"] + unknown_sources = [source_name for source_name in sources if source_name not in self.clients] + if unknown_sources: + all_warnings.append("⚠️ Unknown data_sources ignored: " + ", ".join(sorted(set(unknown_sources)))) + + # Collect works and warnings from all sources + all_works: list[LiteratureWork] = [] + source_counts: dict[str, int] = {} + + for source_name in sources: + if client := self.clients.get(source_name): + try: + logger.info("Fetching from %s...", source_name) + works, warnings_data = await client.search(request) + all_warnings.extend(warnings_data) + + all_works.extend(works) + source_counts[source_name] = len(works) + logger.info("Fetched %d works from %s", len(works), source_name) + except Exception as e: + logger.error(f"Error fetching from {source_name}: {e}", exc_info=True) + source_counts[source_name] = 0 + all_warnings.append(f"⚠️ Error fetching from {source_name}: {str(e)}") + else: + logger.warning(f"Data source '{source_name}' not available") + + # Deduplicate by DOI + logger.info("Deduplicating %d works...", len(all_works)) + unique_works = deduplicate_by_doi(all_works) + logger.info("After deduplication: %d unique works", len(unique_works)) + + # Sort results + unique_works = self._sort_works(unique_works, request.sort_by) + + # Limit to max_results + unique_works = unique_works[: request.max_results] + + return { + "total_count": len(all_works), + "unique_count": len(unique_works), + "sources": source_counts, + "works": unique_works, + "warnings": all_warnings, + } + + def _sort_works(self, works: list[LiteratureWork], sort_by: str) -> list[LiteratureWork]: + """ + Sort works by specified criteria + + Args: + works: List of works to sort + sort_by: Sort method - "relevance", "cited_by_count", "publication_date" + + Returns: + Sorted list of works + """ + if sort_by == "cited_by_count": + return sorted(works, key=lambda w: w.cited_by_count, reverse=True) + elif sort_by == "publication_date": + return sorted( + works, + key=lambda w: w.publication_year or float("-inf"), + reverse=True, + ) + else: # relevance or default + # For relevance, keep original order (API returns by relevance) + return works diff --git a/service/tests/unit/test_utils/__init__.py b/service/tests/unit/test_literature/__init__.py similarity index 100% rename from service/tests/unit/test_utils/__init__.py rename to service/tests/unit/test_literature/__init__.py diff --git a/service/tests/unit/test_literature/test_base_client.py b/service/tests/unit/test_literature/test_base_client.py new file mode 100644 index 00000000..a692737a --- /dev/null +++ b/service/tests/unit/test_literature/test_base_client.py @@ -0,0 +1,300 @@ +"""Tests for base literature client.""" + +import pytest + +from app.utils.literature.base_client import BaseLiteratureClient +from app.utils.literature.models import LiteratureWork, SearchRequest + + +class ConcreteClient(BaseLiteratureClient): + """Concrete implementation of BaseLiteratureClient for testing.""" + + async def search(self, request: SearchRequest) -> tuple[list[LiteratureWork], list[str]]: + """Dummy search implementation.""" + return [], [] + + +class TestBaseLiteratureClientProtocol: + """Test BaseLiteratureClient protocol and abstract methods.""" + + def test_cannot_instantiate_abstract_class(self) -> None: + """Test that BaseLiteratureClient cannot be instantiated directly.""" + with pytest.raises(TypeError): + BaseLiteratureClient() # type: ignore + + def test_concrete_implementation(self) -> None: + """Test that concrete implementation can be instantiated.""" + client = ConcreteClient() + assert client is not None + assert isinstance(client, BaseLiteratureClient) + + @pytest.mark.asyncio + async def test_search_method_required(self) -> None: + """Test that search method is required.""" + request = SearchRequest(query="test") + result = await ConcreteClient().search(request) + assert result == ([], []) + + +class TestSearchRequestDataclass: + """Test SearchRequest data model.""" + + def test_search_request_required_field(self) -> None: + """Test SearchRequest with required query field.""" + request = SearchRequest(query="machine learning") + assert request.query == "machine learning" + + def test_search_request_default_values(self) -> None: + """Test SearchRequest default values.""" + request = SearchRequest(query="test") + assert request.query == "test" + assert request.author is None + assert request.institution is None + assert request.source is None + assert request.year_from is None + assert request.year_to is None + assert request.is_oa is None + assert request.work_type is None + assert request.language is None + assert request.is_retracted is None + assert request.has_abstract is None + assert request.has_fulltext is None + assert request.sort_by == "relevance" + assert request.max_results == 50 + assert request.data_sources is None + + def test_search_request_all_fields(self) -> None: + """Test SearchRequest with all fields specified.""" + request = SearchRequest( + query="machine learning", + author="John Doe", + institution="MIT", + source="Nature", + year_from=2015, + year_to=2021, + is_oa=True, + work_type="journal-article", + language="en", + is_retracted=False, + has_abstract=True, + has_fulltext=True, + sort_by="cited_by_count", + max_results=100, + data_sources=["openalex", "semantic_scholar"], + ) + + assert request.query == "machine learning" + assert request.author == "John Doe" + assert request.institution == "MIT" + assert request.source == "Nature" + assert request.year_from == 2015 + assert request.year_to == 2021 + assert request.is_oa is True + assert request.work_type == "journal-article" + assert request.language == "en" + assert request.is_retracted is False + assert request.has_abstract is True + assert request.has_fulltext is True + assert request.sort_by == "cited_by_count" + assert request.max_results == 100 + assert request.data_sources == ["openalex", "semantic_scholar"] + + def test_search_request_partial_year_range(self) -> None: + """Test SearchRequest with only year_from.""" + request = SearchRequest(query="test", year_from=2015) + assert request.year_from == 2015 + assert request.year_to is None + + def test_search_request_partial_year_range_to_only(self) -> None: + """Test SearchRequest with only year_to.""" + request = SearchRequest(query="test", year_to=2021) + assert request.year_from is None + assert request.year_to == 2021 + + +class TestLiteratureWorkDataclass: + """Test LiteratureWork data model.""" + + def test_literature_work_minimal(self) -> None: + """Test LiteratureWork with minimal required fields.""" + work = LiteratureWork( + id="W123", + doi=None, + title="Test Paper", + authors=[], + publication_year=None, + cited_by_count=0, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + assert work.id == "W123" + assert work.title == "Test Paper" + assert work.cited_by_count == 0 + assert work.source == "openalex" + + def test_literature_work_complete(self) -> None: + """Test LiteratureWork with all fields.""" + authors = [ + {"name": "John Doe", "id": "A1"}, + {"name": "Jane Smith", "id": "A2"}, + ] + + work = LiteratureWork( + id="W2741809807", + doi="10.1038/nature12345", + title="Machine Learning Fundamentals", + authors=authors, + publication_year=2020, + cited_by_count=150, + abstract="This is a comprehensive review of machine learning concepts.", + journal="Nature", + is_oa=True, + access_url="https://example.com/paper.pdf", + source="openalex", + ) + + assert work.id == "W2741809807" + assert work.doi == "10.1038/nature12345" + assert work.title == "Machine Learning Fundamentals" + assert len(work.authors) == 2 + assert work.authors[0]["name"] == "John Doe" + assert work.publication_year == 2020 + assert work.cited_by_count == 150 + assert work.abstract is not None + assert work.journal == "Nature" + assert work.is_oa is True + assert work.access_url is not None + + def test_literature_work_raw_data_default(self) -> None: + """Test LiteratureWork raw_data defaults to empty dict.""" + work = LiteratureWork( + id="W123", + doi=None, + title="Test", + authors=[], + publication_year=None, + cited_by_count=0, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + assert work.raw_data == {} + + def test_literature_work_raw_data_custom(self) -> None: + """Test LiteratureWork with custom raw_data.""" + raw_data = {"custom_field": "value", "api_response": {"status": "ok"}} + + work = LiteratureWork( + id="W123", + doi=None, + title="Test", + authors=[], + publication_year=None, + cited_by_count=0, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + raw_data=raw_data, + ) + + assert work.raw_data == raw_data + assert work.raw_data["custom_field"] == "value" + + def test_literature_work_multiple_authors(self) -> None: + """Test LiteratureWork with multiple authors.""" + authors = [ + {"name": "Author 1", "id": "A1"}, + {"name": "Author 2", "id": None}, # Author without ID + {"name": "Author 3", "id": "A3"}, + ] + + work = LiteratureWork( + id="W123", + doi=None, + title="Test", + authors=authors, + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + assert len(work.authors) == 3 + assert work.authors[1]["id"] is None + + def test_literature_work_comparison(self) -> None: + """Test LiteratureWork equality comparison.""" + work1 = LiteratureWork( + id="W123", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal="Nature", + is_oa=True, + access_url=None, + source="openalex", + ) + + work2 = LiteratureWork( + id="W123", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal="Nature", + is_oa=True, + access_url=None, + source="openalex", + ) + + # DataclassesObjects with same values should be equal + assert work1 == work2 + + def test_literature_work_inequality(self) -> None: + """Test LiteratureWork inequality.""" + work1 = LiteratureWork( + id="W123", + doi="10.1038/nature12345", + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + work2 = LiteratureWork( + id="W456", + doi="10.1038/nature67890", + title="Paper 2", + authors=[], + publication_year=2021, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + assert work1 != work2 diff --git a/service/tests/unit/test_literature/test_doi_cleaner.py b/service/tests/unit/test_literature/test_doi_cleaner.py new file mode 100644 index 00000000..2cfb6082 --- /dev/null +++ b/service/tests/unit/test_literature/test_doi_cleaner.py @@ -0,0 +1,403 @@ +"""Tests for DOI normalization and deduplication utilities.""" + +import pytest + +from app.utils.literature.doi_cleaner import deduplicate_by_doi, normalize_doi +from app.utils.literature.models import LiteratureWork + + +class TestNormalizeDOI: + """Test DOI normalization functionality.""" + + def test_normalize_doi_with_https_prefix(self) -> None: + """Test normalizing DOI with https:// prefix.""" + result = normalize_doi("https://doi.org/10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_http_prefix(self) -> None: + """Test normalizing DOI with http:// prefix.""" + result = normalize_doi("http://doi.org/10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_dx_prefix(self) -> None: + """Test normalizing DOI with dx.doi.org prefix.""" + result = normalize_doi("https://dx.doi.org/10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_doi_colon_prefix(self) -> None: + """Test normalizing DOI with 'doi:' prefix.""" + result = normalize_doi("doi:10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_doi_prefix_uppercase(self) -> None: + """Test normalizing DOI with 'DOI:' prefix (uppercase).""" + result = normalize_doi("DOI: 10.1038/nature12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_with_whitespace(self) -> None: + """Test normalizing DOI with leading/trailing whitespace.""" + result = normalize_doi(" 10.1038/nature12345 ") + assert result == "10.1038/nature12345" + + def test_normalize_doi_case_insensitive(self) -> None: + """Test that DOI normalization converts to lowercase.""" + result = normalize_doi("10.1038/NATURE12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_mixed_case_with_prefix(self) -> None: + """Test normalizing DOI with mixed case and prefix.""" + result = normalize_doi("https://DOI.ORG/10.1038/NATURE12345") + assert result == "10.1038/nature12345" + + def test_normalize_doi_none_input(self) -> None: + """Test normalizing None DOI returns None.""" + result = normalize_doi(None) + assert result is None + + def test_normalize_doi_empty_string(self) -> None: + """Test normalizing empty string returns None.""" + result = normalize_doi("") + assert result is None + + def test_normalize_doi_whitespace_only(self) -> None: + """Test normalizing whitespace-only string returns None.""" + result = normalize_doi(" ") + assert result is None + + def test_normalize_doi_invalid_format(self) -> None: + """Test normalizing invalid DOI format returns None.""" + result = normalize_doi("not-a-valid-doi") + assert result is None + + def test_normalize_doi_missing_prefix(self) -> None: + """Test normalizing DOI missing the '10.' prefix returns None.""" + result = normalize_doi("1038/nature12345") + assert result is None + + def test_normalize_doi_missing_suffix(self) -> None: + """Test normalizing DOI missing the suffix returns None.""" + result = normalize_doi("10.1038/") + assert result is None + + def test_normalize_doi_complex_suffix(self) -> None: + """Test normalizing DOI with complex suffix.""" + result = normalize_doi("10.1145/3580305.3599315") + assert result == "10.1145/3580305.3599315" + + def test_normalize_doi_with_version(self) -> None: + """Test normalizing DOI with version suffix.""" + result = normalize_doi("https://doi.org/10.1038/nature.2020.27710") + assert result == "10.1038/nature.2020.27710" + + +class TestDeduplicateByDOI: + """Test DOI-based deduplication functionality.""" + + @pytest.fixture + def sample_work(self) -> LiteratureWork: + """Create a sample literature work.""" + return LiteratureWork( + id="W2741809807", + doi="10.1038/nature12345", + title="Test Paper", + authors=[{"name": "John Doe", "id": "A1"}], + publication_year=2020, + cited_by_count=100, + abstract="Test abstract", + journal="Nature", + is_oa=True, + access_url="https://example.com/paper.pdf", + source="openalex", + ) + + def test_deduplicate_empty_list(self) -> None: + """Test deduplicating empty list returns empty list.""" + result = deduplicate_by_doi([]) + assert result == [] + + def test_deduplicate_single_work(self, sample_work: LiteratureWork) -> None: + """Test deduplicating single work returns same work.""" + result = deduplicate_by_doi([sample_work]) + assert len(result) == 1 + assert result[0].id == sample_work.id + + def test_deduplicate_duplicate_doi_keeps_higher_citations(self, sample_work: LiteratureWork) -> None: + """Test deduplication keeps work with higher citation count.""" + work1 = LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 1 + assert result[0].id == "W1" # Higher citation count + + def test_deduplicate_duplicate_doi_equal_citations_keeps_newer(self) -> None: + """Test deduplication keeps more recently published work when citation count is equal.""" + work1 = LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2019, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Test Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 1 + assert result[0].id == "W2" # More recent publication + + def test_deduplicate_without_doi(self) -> None: + """Test deduplicating works without DOI.""" + work1 = LiteratureWork( + id="W1", + doi=None, + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi=None, + title="Paper 2", + authors=[], + publication_year=2020, + cited_by_count=20, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 2 # Both kept since no DOI + + def test_deduplicate_invalid_doi_treated_as_no_doi(self) -> None: + """Test deduplicating works with invalid DOI treats them as without DOI.""" + work1 = LiteratureWork( + id="W1", + doi="invalid-doi-format", + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Paper 2", + authors=[], + publication_year=2020, + cited_by_count=20, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 2 + # Invalid DOI work should be in the results + assert any(w.id == "W1" for w in result) + assert any(w.id == "W2" for w in result) + + def test_deduplicate_doi_with_versions_deduplicated(self) -> None: + """Test deduplicating DOIs with version info.""" + work1 = LiteratureWork( + id="W1", + doi="https://doi.org/10.1038/nature.2020.27710", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature.2020.27710", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work1, work2]) + assert len(result) == 1 + assert result[0].id == "W2" # Higher citation count + + def test_deduplicate_preserves_order_with_doi(self) -> None: + """Test that deduplication preserves order: DOI works first, then non-DOI.""" + work_no_doi = LiteratureWork( + id="W_no_doi", + doi=None, + title="No DOI", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work_with_doi = LiteratureWork( + id="W_with_doi", + doi="10.1038/nature12345", + title="With DOI", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + result = deduplicate_by_doi([work_no_doi, work_with_doi]) + assert len(result) == 2 + assert result[0].id == "W_with_doi" # DOI works come first + assert result[1].id == "W_no_doi" + + def test_deduplicate_complex_scenario(self) -> None: + """Test deduplication with complex mix of works.""" + works = [ + # Duplicate pair with same DOI + LiteratureWork( + id="W1", + doi="10.1038/A", + title="A", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + LiteratureWork( + id="W2", + doi="10.1038/A", + title="A", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + # Another unique DOI + LiteratureWork( + id="W3", + doi="10.1038/B", + title="B", + authors=[], + publication_year=2021, + cited_by_count=75, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + # No DOI works + LiteratureWork( + id="W4", + doi=None, + title="C", + authors=[], + publication_year=2022, + cited_by_count=30, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + LiteratureWork( + id="W5", + doi=None, + title="D", + authors=[], + publication_year=2022, + cited_by_count=40, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + ] + + result = deduplicate_by_doi(works) + assert len(result) == 4 # W1 removed (duplicate), others kept + result_ids = {w.id for w in result} + assert result_ids == {"W2", "W3", "W4", "W5"} + # Verify W2 (higher citations) was kept over W1 + assert "W2" in result_ids + assert "W1" not in result_ids diff --git a/service/tests/unit/test_literature/test_openalex_client.py b/service/tests/unit/test_literature/test_openalex_client.py new file mode 100644 index 00000000..fadbc62e --- /dev/null +++ b/service/tests/unit/test_literature/test_openalex_client.py @@ -0,0 +1,461 @@ +"""Tests for OpenAlex API client.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from app.utils.literature.models import SearchRequest +from app.utils.literature.openalex_client import OpenAlexClient + + +class TestOpenAlexClientInit: + """Test OpenAlex client initialization.""" + + def test_client_initialization(self) -> None: + """Test client initializes with correct parameters.""" + email = "test@example.com" + rate_limit = 5 + timeout = 15.0 + + client = OpenAlexClient(email=email, rate_limit=rate_limit, timeout=timeout) + + assert client.email == email + assert client.rate_limit == rate_limit + assert client.pool_type == "polite" + assert pytest.approx(client.rate_limiter._min_interval, rel=0.01) == 1 / rate_limit + + def test_client_initialization_defaults(self) -> None: + """Test client initializes with default parameters.""" + email = "test@example.com" + client = OpenAlexClient(email=email) + + assert client.email == email + assert client.rate_limit == 10 + assert client.pool_type == "polite" + # Verify timeout was set (httpx Timeout object) + assert client.client.timeout is not None + + def test_client_initialization_default_pool(self) -> None: + """Test client initializes default pool when email is missing.""" + client = OpenAlexClient(email=None) + + assert client.email is None + assert client.rate_limit == 1 + assert client.pool_type == "default" + assert pytest.approx(client.rate_limiter._min_interval, rel=0.01) == 1.0 + + +class TestOpenAlexClientSearch: + """Test OpenAlex search functionality.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + @pytest.fixture + def mock_response(self) -> dict: + """Create a mock OpenAlex API response.""" + return { + "meta": {"count": 1, "page": 1}, + "results": [ + { + "id": "https://openalex.org/W2741809807", + "title": "Machine Learning Fundamentals", + "doi": "https://doi.org/10.1038/nature12345", + "publication_year": 2020, + "cited_by_count": 150, + "abstract_inverted_index": { + "Machine": [0], + "learning": [1], + "is": [2], + "fundamental": [3], + }, + "authorships": [ + { + "author": { + "id": "https://openalex.org/A5023888391", + "display_name": "Jane Smith", + } + } + ], + "primary_location": { + "source": { + "id": "https://openalex.org/S137773608", + "display_name": "Nature", + } + }, + "open_access": { + "is_oa": True, + "oa_url": "https://example.com/paper.pdf", + }, + } + ], + } + + @pytest.mark.asyncio + async def test_search_basic_query(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test basic search with simple query.""" + request = SearchRequest(query="machine learning", max_results=10) + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + assert works[0].title == "Machine Learning Fundamentals" + assert works[0].doi == "10.1038/nature12345" + assert isinstance(warnings, list) + + @pytest.mark.asyncio + async def test_search_with_author_filter(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with author filter.""" + request = SearchRequest(query="machine learning", author="Jane Smith", max_results=10) + + with patch.object(client, "_resolve_author_id", new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = ("A5023888391", True, "✓ Author resolved") + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + mock_resolve.assert_called_once_with("Jane Smith") + assert any("Author resolved" in msg for msg in warnings) + + @pytest.mark.asyncio + async def test_search_with_institution_filter(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with institution filter.""" + request = SearchRequest(query="machine learning", institution="Harvard University", max_results=10) + + with patch.object(client, "_resolve_institution_id", new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = ("I136199984", True, "✓ Institution resolved") + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + mock_resolve.assert_called_once_with("Harvard University") + assert any("Institution resolved" in msg for msg in warnings) + + @pytest.mark.asyncio + async def test_search_with_source_filter(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with source (journal) filter.""" + request = SearchRequest(query="machine learning", source="Nature", max_results=10) + + with patch.object(client, "_resolve_source_id", new_callable=AsyncMock) as mock_resolve: + mock_resolve.return_value = ("S137773608", True, "✓ Source resolved") + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + mock_resolve.assert_called_once_with("Nature") + assert any("Source resolved" in msg for msg in warnings) + + @pytest.mark.asyncio + async def test_search_with_year_range(self, client: OpenAlexClient, mock_response: dict) -> None: + """Test search with year range filter.""" + request = SearchRequest(query="machine learning", year_from=2015, year_to=2021, max_results=10) + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_response + + works, warnings = await client.search(request) + + assert len(works) == 1 + # Verify year filter was applied + call_args = mock_request.call_args + params = call_args[0][1] if call_args else {} + assert "2015-2021" in params.get("filter", "") + + @pytest.mark.asyncio + async def test_search_max_results_clamping_low(self, client: OpenAlexClient) -> None: + """Test that search handles low max_results correctly.""" + request = SearchRequest(query="test", max_results=0) + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = {"meta": {"count": 0}, "results": []} + + # Should not raise an error even with 0 max_results + works, warnings = await client.search(request) + assert isinstance(works, list) + assert isinstance(warnings, list) + + @pytest.mark.asyncio + async def test_search_max_results_clamping_high(self, client: OpenAlexClient) -> None: + """Test that search handles high max_results correctly.""" + request = SearchRequest(query="test", max_results=5000) + + with patch.object(client, "_request_with_retry", new_callable=AsyncMock) as mock_request: + mock_request.return_value = {"meta": {"count": 0}, "results": []} + + # Should not raise an error even with high max_results + works, warnings = await client.search(request) + assert isinstance(works, list) + assert isinstance(warnings, list) + + +class TestOpenAlexClientPrivateMethods: + """Test OpenAlex client private methods.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + def test_build_query_params_basic(self, client: OpenAlexClient) -> None: + """Test building basic query parameters.""" + request = SearchRequest(query="machine learning", max_results=50) + params = client._build_query_params(request, None, None, None) + + assert params["search"] == "machine learning" + assert params["per-page"] == "200" + assert params["mailto"] == "test@example.com" + + def test_build_query_params_with_filters(self, client: OpenAlexClient) -> None: + """Test building query parameters with filters.""" + request = SearchRequest( + query="machine learning", + year_from=2015, + year_to=2021, + is_oa=True, + work_type="journal-article", + ) + params = client._build_query_params(request, None, None, None) + + assert "filter" in params + assert "publication_year:2015-2021" in params["filter"] + assert "is_oa:true" in params["filter"] + assert "type:journal-article" in params["filter"] + + def test_build_query_params_with_resolved_ids(self, client: OpenAlexClient) -> None: + """Test building query parameters with resolved author/institution/source IDs.""" + request = SearchRequest(query="test") + params = client._build_query_params(request, "A123", "I456", "S789") + + assert "filter" in params + assert "authorships.author.id:A123" in params["filter"] + assert "authorships.institutions.id:I456" in params["filter"] + assert "primary_location.source.id:S789" in params["filter"] + + def test_build_query_params_sorting_by_citations(self, client: OpenAlexClient) -> None: + """Test building query parameters with citation sorting.""" + request = SearchRequest(query="test", sort_by="cited_by_count") + params = client._build_query_params(request, None, None, None) + + assert params.get("sort") == "cited_by_count:desc" + + def test_build_query_params_sorting_by_date(self, client: OpenAlexClient) -> None: + """Test building query parameters with date sorting.""" + request = SearchRequest(query="test", sort_by="publication_date") + params = client._build_query_params(request, None, None, None) + + assert params.get("sort") == "publication_date:desc" + + def test_reconstruct_abstract_normal(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction from inverted index.""" + inverted_index = { + "Machine": [0], + "learning": [1], + "is": [2], + "fundamental": [3], + } + + result = client._reconstruct_abstract(inverted_index) + + assert result == "Machine learning is fundamental" + + def test_reconstruct_abstract_with_duplicates(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction with duplicate words.""" + inverted_index = { + "The": [0, 5], + "quick": [1], + "brown": [2], + "fox": [3], + "jumps": [4], + } + + result = client._reconstruct_abstract(inverted_index) + + assert result == "The quick brown fox jumps The" + + def test_reconstruct_abstract_none(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction returns None for empty input.""" + result = client._reconstruct_abstract(None) + + assert result is None + + def test_reconstruct_abstract_empty(self, client: OpenAlexClient) -> None: + """Test abstract reconstruction returns None for empty dict.""" + result = client._reconstruct_abstract({}) + + assert result is None + + def test_transform_work_complete(self, client: OpenAlexClient) -> None: + """Test transforming complete OpenAlex work object.""" + work_data = { + "id": "https://openalex.org/W2741809807", + "title": "Machine Learning Fundamentals", + "doi": "https://doi.org/10.1038/nature12345", + "publication_year": 2020, + "cited_by_count": 150, + "abstract_inverted_index": {"Machine": [0], "learning": [1]}, + "authorships": [ + { + "author": { + "id": "https://openalex.org/A5023888391", + "display_name": "Jane Smith", + } + }, + { + "author": { + "id": "https://openalex.org/A5023888392", + "display_name": "John Doe", + } + }, + ], + "primary_location": { + "source": { + "id": "https://openalex.org/S137773608", + "display_name": "Nature", + } + }, + "open_access": { + "is_oa": True, + "oa_url": "https://example.com/paper.pdf", + }, + } + + result = client._transform_work(work_data) + + assert result.id == "W2741809807" + assert result.title == "Machine Learning Fundamentals" + assert result.doi == "10.1038/nature12345" + assert result.publication_year == 2020 + assert result.cited_by_count == 150 + assert len(result.authors) == 2 + assert result.authors[0]["name"] == "Jane Smith" + assert result.journal == "Nature" + assert result.is_oa is True + assert result.access_url == "https://example.com/paper.pdf" + assert result.source == "openalex" + + def test_transform_work_minimal(self, client: OpenAlexClient) -> None: + """Test transforming minimal OpenAlex work object.""" + work_data = { + "id": "https://openalex.org/W123", + "title": "Minimal Paper", + "authorships": [], + } + + result = client._transform_work(work_data) + + assert result.id == "W123" + assert result.title == "Minimal Paper" + assert result.doi is None + assert result.authors == [] + assert result.journal is None + assert result.is_oa is False + + +class TestOpenAlexClientRequestWithRetry: + """Test OpenAlex client request retry logic.""" + + @pytest.fixture + def client(self) -> OpenAlexClient: + """Create an OpenAlex client for testing.""" + return OpenAlexClient(email="test@example.com") + + @pytest.mark.asyncio + async def test_request_with_retry_success(self, client: OpenAlexClient) -> None: + """Test successful request without retry.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + mock_get.return_value = mock_response + + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + mock_get.assert_called_once() + + @pytest.mark.asyncio + async def test_request_with_retry_timeout(self, client: OpenAlexClient) -> None: + """Test request retry on timeout.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + # First call timeout, second call success + mock_get.side_effect = [httpx.TimeoutException("timeout"), mock_response] + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + async def test_request_with_retry_rate_limit(self, client: OpenAlexClient) -> None: + """Test request retry on rate limit (403).""" + mock_response_403 = MagicMock() + mock_response_403.status_code = 403 + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + mock_response_200.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + mock_get.side_effect = [mock_response_403, mock_response_200] + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + async def test_request_with_retry_server_error(self, client: OpenAlexClient) -> None: + """Test request retry on server error (5xx).""" + mock_response_500 = MagicMock() + mock_response_500.status_code = 500 + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + mock_response_200.json.return_value = {"success": True} + + with patch.object(client.client, "get", new_callable=AsyncMock) as mock_get: + mock_get.side_effect = [mock_response_500, mock_response_200] + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await client._request_with_retry("http://test.com", {}) + + assert result == {"success": True} + assert mock_get.call_count == 2 + + +class TestOpenAlexClientContextManager: + """Test OpenAlex client context manager.""" + + @pytest.mark.asyncio + async def test_context_manager_enter_exit(self) -> None: + """Test client works as async context manager.""" + async with OpenAlexClient(email="test@example.com") as client: + assert client is not None + assert client.email == "test@example.com" + + @pytest.mark.asyncio + async def test_close_method(self) -> None: + """Test client close method.""" + client = OpenAlexClient(email="test@example.com") + with patch.object(client.client, "aclose", new_callable=AsyncMock) as mock_close: + await client.close() + mock_close.assert_called_once() diff --git a/service/tests/unit/test_literature/test_work_distributor.py b/service/tests/unit/test_literature/test_work_distributor.py new file mode 100644 index 00000000..41e54dd5 --- /dev/null +++ b/service/tests/unit/test_literature/test_work_distributor.py @@ -0,0 +1,428 @@ +"""Tests for work distributor.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.utils.literature.models import LiteratureWork, SearchRequest +from app.utils.literature.work_distributor import WorkDistributor + + +class TestWorkDistributorInit: + """Test WorkDistributor initialization.""" + + def test_init_with_openalex_email(self) -> None: + """Test initialization with OpenAlex email.""" + distributor = WorkDistributor(openalex_email="test@example.com") + + assert distributor.openalex_email == "test@example.com" + # OpenAlex client should be registered (polite pool) + assert "openalex" in distributor.clients + + def test_init_without_openalex_email(self) -> None: + """Test initialization without OpenAlex email.""" + distributor = WorkDistributor() + + assert distributor.openalex_email is None + # OpenAlex client should still be registered (default pool) + assert "openalex" in distributor.clients + + def test_init_with_import_error(self) -> None: + """Test initialization when OpenAlex client import fails.""" + # This test would require mocking the import, which is complex + # Instead, just verify initialization works without email + distributor = WorkDistributor() + + assert distributor.openalex_email is None + assert "openalex" in distributor.clients + + +class TestWorkDistributorSearch: + """Test WorkDistributor search functionality.""" + + @pytest.fixture + def sample_work(self) -> LiteratureWork: + """Create a sample literature work.""" + return LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Test Paper", + authors=[{"name": "John Doe", "id": "A1"}], + publication_year=2020, + cited_by_count=100, + abstract="Test abstract", + journal="Nature", + is_oa=True, + access_url="https://example.com/paper.pdf", + source="openalex", + ) + + @pytest.fixture + def mock_openalex_client(self, sample_work: LiteratureWork) -> MagicMock: + """Create a mock OpenAlex client.""" + client = AsyncMock() + client.search = AsyncMock(return_value=([sample_work], ["✓ Search completed"])) + return client + + @pytest.mark.asyncio + async def test_search_basic(self, sample_work: LiteratureWork, mock_openalex_client: MagicMock) -> None: + """Test basic search with default source.""" + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_openalex_client} + distributor.openalex_email = "test@example.com" + + result = await distributor.search(request) + + assert result["total_count"] == 1 + assert result["unique_count"] == 1 + assert "openalex" in result["sources"] + assert len(result["works"]) == 1 + assert result["works"][0].id == "W1" + + @pytest.mark.asyncio + async def test_search_multiple_sources(self, sample_work: LiteratureWork) -> None: + """Test search with multiple data sources.""" + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature67890", + title="Another Paper", + authors=[], + publication_year=2021, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="semantic_scholar", + ) + + mock_client1 = AsyncMock() + mock_client1.search = AsyncMock(return_value=([sample_work], [])) + + mock_client2 = AsyncMock() + mock_client2.search = AsyncMock(return_value=([work2], [])) + + request = SearchRequest(query="test", max_results=50, data_sources=["openalex", "semantic_scholar"]) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client1, "semantic_scholar": mock_client2} + + result = await distributor.search(request) + + assert result["total_count"] == 2 + assert result["unique_count"] == 2 + assert "openalex" in result["sources"] + assert "semantic_scholar" in result["sources"] + + @pytest.mark.asyncio + async def test_search_deduplication(self) -> None: + """Test search deduplicates results by DOI.""" + work1 = LiteratureWork( + id="W1", + doi="10.1038/nature12345", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + work2 = LiteratureWork( + id="W2", + doi="10.1038/nature12345", + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="other", + ) + + mock_client = AsyncMock() + mock_client.search = AsyncMock(return_value=([work1, work2], [])) + + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert result["total_count"] == 2 + assert result["unique_count"] == 1 # Deduplicated + assert result["works"][0].id == "W1" # Higher citation count + + @pytest.mark.asyncio + async def test_search_with_client_error(self, sample_work: LiteratureWork) -> None: + """Test search handles client errors gracefully.""" + mock_client = AsyncMock() + mock_client.search = AsyncMock(side_effect=Exception("API Error")) + + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert result["total_count"] == 0 + assert result["unique_count"] == 0 + assert result["sources"]["openalex"] == 0 + assert any("Error" in w for w in result["warnings"]) + + @pytest.mark.asyncio + async def test_search_unavailable_source(self) -> None: + """Test search with unavailable data source.""" + request = SearchRequest(query="test", max_results=50, data_sources=["unavailable_source"]) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {} + + result = await distributor.search(request) + + assert result["total_count"] == 0 + assert result["unique_count"] == 0 + assert result["works"] == [] + + @pytest.mark.asyncio + async def test_search_max_results_clamping_low(self) -> None: + """Test search clamps max_results to minimum.""" + request = SearchRequest(query="test", max_results=0) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {} + + result = await distributor.search(request) + + assert any("max_results < 1" in w for w in result["warnings"]) + assert request.max_results == 50 + + @pytest.mark.asyncio + async def test_search_max_results_clamping_high(self) -> None: + """Test search clamps max_results to maximum.""" + request = SearchRequest(query="test", max_results=5000) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {} + + result = await distributor.search(request) + + assert any("max_results > 1000" in w for w in result["warnings"]) + assert request.max_results == 1000 + + @pytest.mark.asyncio + async def test_search_result_limiting(self) -> None: + """Test search limits results to max_results.""" + works = [ + LiteratureWork( + id=f"W{i}", + doi=f"10.1038/paper{i}", + title=f"Paper {i}", + authors=[], + publication_year=2020, + cited_by_count=100 - i, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + for i in range(20) + ] + + mock_client = AsyncMock() + mock_client.search = AsyncMock(return_value=(works, [])) + + request = SearchRequest(query="test", max_results=10) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert len(result["works"]) == 10 + + @pytest.mark.asyncio + async def test_search_with_warnings(self) -> None: + """Test search collects warnings from clients.""" + work = LiteratureWork( + id="W1", + doi=None, + title="Paper", + authors=[], + publication_year=2020, + cited_by_count=10, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ) + + mock_client = AsyncMock() + mock_client.search = AsyncMock( + return_value=( + [work], + ["⚠️ Author not found", "✓ Search completed"], + ) + ) + + request = SearchRequest(query="test", max_results=50) + + distributor = WorkDistributor.__new__(WorkDistributor) + distributor.clients = {"openalex": mock_client} + + result = await distributor.search(request) + + assert "⚠️ Author not found" in result["warnings"] + assert "✓ Search completed" in result["warnings"] + + +class TestWorkDistributorSorting: + """Test WorkDistributor sorting functionality.""" + + @pytest.fixture + def sample_works(self) -> list[LiteratureWork]: + """Create sample works for sorting tests.""" + return [ + LiteratureWork( + id="W1", + doi=None, + title="Paper 1", + authors=[], + publication_year=2020, + cited_by_count=50, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + LiteratureWork( + id="W2", + doi=None, + title="Paper 2", + authors=[], + publication_year=2021, + cited_by_count=100, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + LiteratureWork( + id="W3", + doi=None, + title="Paper 3", + authors=[], + publication_year=2019, + cited_by_count=75, + abstract=None, + journal=None, + is_oa=False, + access_url=None, + source="openalex", + ), + ] + + def test_sort_by_relevance(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by relevance (default, maintains order).""" + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "relevance") + + # Should maintain original order for relevance + assert result[0].id == "W1" + assert result[1].id == "W2" + assert result[2].id == "W3" + + def test_sort_by_cited_by_count(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by citation count.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "cited_by_count") + + assert result[0].id == "W2" # 100 citations + assert result[1].id == "W3" # 75 citations + assert result[2].id == "W1" # 50 citations + + def test_sort_by_publication_date(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by publication date.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "publication_date") + + assert result[0].id == "W2" # 2021 + assert result[1].id == "W1" # 2020 + assert result[2].id == "W3" # 2019 + + def test_sort_with_missing_year(self, sample_works: list[LiteratureWork]) -> None: + """Test sorting by publication date with missing years.""" + sample_works[1].publication_year = None + + distributor = WorkDistributor.__new__(WorkDistributor) + + result = distributor._sort_works(sample_works, "publication_date") + + # Works with missing year should go to the end + assert result[0].id == "W1" # 2020 + assert result[1].id == "W3" # 2019 + assert result[2].publication_year is None + + +class TestWorkDistributorContextManager: + """Test WorkDistributor context manager.""" + + @pytest.mark.asyncio + async def test_context_manager_enter_exit(self) -> None: + """Test context manager functionality.""" + async with WorkDistributor(openalex_email="test@example.com") as distributor: + assert distributor is not None + + @pytest.mark.asyncio + async def test_close_method(self) -> None: + """Test close method.""" + distributor = WorkDistributor(openalex_email="test@example.com") + + # Replace the actual client with a mock + mock_client = MagicMock() + mock_client.close = AsyncMock() + distributor.clients["openalex"] = mock_client + + await distributor.close() + + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_sync_close(self) -> None: + """Test close method with synchronous close.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + mock_client = MagicMock() + # Synchronous close (returns None, not awaitable) + mock_client.close = MagicMock(return_value=None) + distributor.clients = {"openalex": mock_client} + + await distributor.close() + + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_with_no_close_method(self) -> None: + """Test close method with client that has no close method.""" + distributor = WorkDistributor.__new__(WorkDistributor) + + mock_client = MagicMock(spec=[]) # No close method + distributor.clients = {"openalex": mock_client} + + # Should not raise an error + await distributor.close() diff --git a/service/tests/unit/test_utils/test_built_in_tools.py b/service/tests/unit/test_utils/test_built_in_tools.py deleted file mode 100644 index 61ca5ebb..00000000 --- a/service/tests/unit/test_utils/test_built_in_tools.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Tests for built-in tools utilities.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from fastmcp import FastMCP - -from app.mcp.builtin_tools import register_built_in_tools - - -class TestBuiltInTools: - """Test built-in tools registration and functionality.""" - - @pytest.fixture - def mock_mcp(self): - """Create a mock FastMCP instance.""" - mcp = MagicMock(spec=FastMCP) - mcp.tool = MagicMock() - mcp.resource = MagicMock() - return mcp - - def test_register_built_in_tools(self, mock_mcp: MagicMock) -> None: - """Test that built-in tools are registered properly.""" - register_built_in_tools(mock_mcp) - - # Verify that the decorators were called (tools were registered) - assert mock_mcp.tool.call_count >= 4 # We have at least 4 tools - assert mock_mcp.resource.call_count >= 1 # We have at least 1 resource - - @patch("app.mcp.builtin_tools.request.urlopen") - def test_search_github_success(self, mock_urlopen: MagicMock, mock_mcp: MagicMock) -> None: - """Test GitHub search tool with successful response.""" - # Mock response data - mock_response_data = { - "items": [ - { - "full_name": "test/repo1", - "html_url": "https://github.com/test/repo1", - "description": "Test repository 1", - "stargazers_count": 100, - "forks_count": 20, - "language": "Python", - "updated_at": "2024-01-01T00:00:00Z", - "topics": ["test", "demo"], - }, - { - "full_name": "test/repo2", - "html_url": "https://github.com/test/repo2", - "description": "Test repository 2", - "stargazers_count": 50, - "forks_count": 10, - "language": "Python", - "updated_at": "2024-01-02T00:00:00Z", - "topics": [], - }, - ] - } - - # Mock the context manager and JSON loading - mock_response = MagicMock() - mock_response.__enter__ = MagicMock(return_value=mock_response) - mock_response.__exit__ = MagicMock(return_value=None) - mock_urlopen.return_value = mock_response - - with patch("app.mcp.builtin_tools.json.load") as mock_json_load: - mock_json_load.return_value = mock_response_data - - register_built_in_tools(mock_mcp) - - # Get the search_github function from the registered tools - # Since we can't easily extract it, we'll test the logic directly - # by calling the function that would be registered - - # For this test, we'll verify the mock was set up correctly - assert mock_mcp.tool.called - - @patch("app.mcp.builtin_tools.request.urlopen") - def test_search_github_empty_query(self, mock_urlopen: MagicMock, mock_mcp: MagicMock) -> None: - """Test GitHub search with empty query.""" - register_built_in_tools(mock_mcp) - - # The actual test would need access to the registered function - # For now, we verify the registration happened - assert mock_mcp.tool.called - - @patch("app.mcp.builtin_tools.request.urlopen") - def test_search_github_api_error(self, mock_urlopen: MagicMock, mock_mcp: MagicMock) -> None: - """Test GitHub search with API error.""" - # Mock URL open to raise an exception - mock_urlopen.side_effect = Exception("API Error") - - register_built_in_tools(mock_mcp) - - # Verify registration still happened despite the error not occurring yet - assert mock_mcp.tool.called - - def test_search_github_parameters(self, mock_mcp: MagicMock) -> None: - """Test GitHub search with different parameters.""" - register_built_in_tools(mock_mcp) - - # Verify the tool was registered with proper signature - assert mock_mcp.tool.called - - # The actual function would accept parameters like query, max_results, sort_by - # Since we can't easily test the registered function directly, - # we verify the registration process - - async def test_llm_web_search_no_auth(self, mock_mcp: MagicMock) -> None: - """Test LLM web search without authentication.""" - with patch("app.mcp.builtin_tools.get_access_token") as mock_get_token: - mock_get_token.return_value = None - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - async def test_llm_web_search_with_auth(self, mock_mcp: MagicMock) -> None: - """Test LLM web search with authentication.""" - with ( - patch("fastmcp.server.dependencies.get_access_token") as mock_get_token, - patch("app.middleware.auth.AuthProvider") as mock_auth_provider, - patch("app.core.providers.get_user_provider_manager") as mock_get_manager, - patch("app.infra.database.connection.AsyncSessionLocal") as mock_session, - ): - # Mock authentication - mock_token = MagicMock() - mock_token.claims = {"user_id": "test-user"} - mock_get_token.return_value = mock_token - - mock_user_info = MagicMock() - mock_user_info.id = "test-user" - mock_auth_provider.parse_user_info.return_value = mock_user_info - - # Mock database session - mock_db = AsyncMock() - mock_session.return_value.__aenter__.return_value = mock_db - - # Mock provider manager - mock_provider_manager = AsyncMock() - mock_get_manager.return_value = mock_provider_manager - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - async def test_refresh_tools_success(self, mock_mcp: MagicMock) -> None: - """Test refresh tools functionality.""" - with ( - patch("app.mcp.builtin_tools.get_access_token") as mock_get_token, - patch("app.mcp.builtin_tools.AuthProvider") as mock_auth_provider, - patch("app.mcp.builtin_tools.tool_loader") as mock_tool_loader, - ): - # Mock authentication - mock_token = MagicMock() - mock_token.claims = {"user_id": "test-user"} - mock_get_token.return_value = mock_token - - mock_user_info = MagicMock() - mock_user_info.id = "test-user" - mock_auth_provider.parse_user_info.return_value = mock_user_info - - # Mock tool loader - mock_tool_loader.refresh_tools.return_value = { - "added": ["tool1", "tool2"], - "removed": ["old_tool"], - "updated": ["updated_tool"], - } - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - async def test_refresh_tools_no_auth(self, mock_mcp: MagicMock) -> None: - """Test refresh tools without authentication.""" - with patch("app.mcp.builtin_tools.get_access_token") as mock_get_token: - mock_get_token.return_value = None - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - def test_get_server_status(self, mock_mcp: MagicMock) -> None: - """Test get server status tool.""" - with patch("app.mcp.builtin_tools.tool_loader") as mock_tool_loader: - mock_proxy_manager = MagicMock() - mock_proxy_manager.list_proxies.return_value = ["proxy1", "proxy2"] - mock_tool_loader.proxy_manager = mock_proxy_manager - - register_built_in_tools(mock_mcp) - - # Verify the tool was registered - assert mock_mcp.tool.called - - @pytest.mark.parametrize("sort_by", ["stars", "forks", "updated"]) - def test_search_github_sort_options(self, mock_mcp: MagicMock, sort_by: str) -> None: - """Test GitHub search with different sort options.""" - register_built_in_tools(mock_mcp) - - # Verify the tool registration happened - assert mock_mcp.tool.called - - def test_tools_registration_count(self, mock_mcp: MagicMock) -> None: - """Test that the expected number of tools are registered.""" - register_built_in_tools(mock_mcp) - - # We expect at least these tools: - # - search_github - # - llm_web_search - # - refresh_tools - # - get_server_status - expected_min_tools = 4 - - assert mock_mcp.tool.call_count >= expected_min_tools - - def test_resource_registration_count(self, mock_mcp: MagicMock) -> None: - """Test that the expected number of resources are registered.""" - register_built_in_tools(mock_mcp) - - # We expect at least these resources: - # - config://server - expected_min_resources = 1 - - assert mock_mcp.resource.call_count >= expected_min_resources From f5a0f5d9fb320b7581bbe78ef1531f54d736929b Mon Sep 17 00:00:00 2001 From: Harvey Date: Thu, 22 Jan 2026 17:16:38 +0800 Subject: [PATCH 2/8] feat: remove test-build workflow and update README for development setup --- .../workflows/{test-build.yaml => beta.yaml} | 124 +++++++++++------- README.md | 85 ++++-------- README_zh.md | 71 ++++++---- 3 files changed, 152 insertions(+), 128 deletions(-) rename .github/workflows/{test-build.yaml => beta.yaml} (60%) diff --git a/.github/workflows/test-build.yaml b/.github/workflows/beta.yaml similarity index 60% rename from .github/workflows/test-build.yaml rename to .github/workflows/beta.yaml index f4cac61a..b78aed1b 100644 --- a/.github/workflows/test-build.yaml +++ b/.github/workflows/beta.yaml @@ -1,4 +1,4 @@ -name: Test Deploy +name: Beta on: push: @@ -6,27 +6,28 @@ on: - test jobs: - # 准备阶段:设置构建变量 + # Setup: Build variables setup: + name: Setup runs-on: ubuntu-latest outputs: - build_start: ${{ steps.build_setup.outputs.build_start }} - beta_tag: ${{ steps.build_setup.outputs.beta_tag }} - commit_author: ${{ steps.build_setup.outputs.commit_author }} - commit_email: ${{ steps.build_setup.outputs.commit_email }} - commit_message: ${{ steps.build_setup.outputs.commit_message }} - commit_sha: ${{ steps.build_setup.outputs.commit_sha }} - commit_sha_short: ${{ steps.build_setup.outputs.commit_sha_short }} - commit_date: ${{ steps.build_setup.outputs.commit_date }} + build_start: ${{ steps.setup.outputs.build_start }} + beta_version: ${{ steps.setup.outputs.beta_version }} + commit_author: ${{ steps.setup.outputs.commit_author }} + commit_email: ${{ steps.setup.outputs.commit_email }} + commit_message: ${{ steps.setup.outputs.commit_message }} + commit_sha: ${{ steps.setup.outputs.commit_sha }} + commit_sha_short: ${{ steps.setup.outputs.commit_sha_short }} + commit_date: ${{ steps.setup.outputs.commit_date }} steps: - - name: Checkout code + - name: Checkout uses: actions/checkout@v4 - name: Setup build variables - id: build_setup + id: setup run: | echo "build_start=$(date '+%Y-%m-%d %H:%M:%S')" >> $GITHUB_OUTPUT - echo "beta_tag=BETA.$(date -u '+%Y-%m-%dT%H-%M-%SZ')" >> $GITHUB_OUTPUT + echo "beta_version=beta-$(date -u '+%Y%m%d%H%M%S')-$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT echo "commit_author=$(git log -1 --pretty=format:'%an')" >> $GITHUB_OUTPUT echo "commit_email=$(git log -1 --pretty=format:'%ae')" >> $GITHUB_OUTPUT echo "commit_message=$(git log -1 --pretty=format:'%s')" >> $GITHUB_OUTPUT @@ -34,12 +35,15 @@ jobs: echo "commit_sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT echo "commit_date=$(git log -1 --pretty=format:'%cd' --date=format:'%Y-%m-%d %H:%M:%S')" >> $GITHUB_OUTPUT - # 并行构建:Service 镜像 + # Parallel build: Service image build-service: + name: Build Service runs-on: ubuntu-latest needs: setup + env: + BETA_VERSION: ${{ needs.setup.outputs.beta_version }} steps: - - name: Checkout code + - name: Checkout uses: actions/checkout@v4 - name: Set up Docker Buildx @@ -52,19 +56,28 @@ jobs: username: ${{ secrets.SCIENCEOL_REGISTRY_USERNAME }} password: ${{ secrets.SCIENCEOL_REGISTRY_PASSWORD }} - - name: Build and push Service Docker image - run: | - docker buildx build service \ - -t registry.sciol.ac.cn/sciol/xyzen-service:test \ - -t registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_tag }} \ - --push - - # 并行构建:Web 镜像 + - name: Build and push Service image + uses: docker/build-push-action@v6 + with: + context: ./service + push: true + build-args: | + XYZEN_VERSION=${{ env.BETA_VERSION }} + XYZEN_COMMIT_SHA=${{ github.sha }} + XYZEN_BUILD_TIME=${{ needs.setup.outputs.commit_date }} + tags: | + registry.sciol.ac.cn/sciol/xyzen-service:beta + registry.sciol.ac.cn/sciol/xyzen-service:${{ env.BETA_VERSION }} + + # Parallel build: Web image build-web: + name: Build Web runs-on: ubuntu-latest needs: setup + env: + BETA_VERSION: ${{ needs.setup.outputs.beta_version }} steps: - - name: Checkout code + - name: Checkout uses: actions/checkout@v4 - name: Set up Docker Buildx @@ -77,57 +90,70 @@ jobs: username: ${{ secrets.SCIENCEOL_REGISTRY_USERNAME }} password: ${{ secrets.SCIENCEOL_REGISTRY_PASSWORD }} - - name: Build and push Web Docker image - run: | - docker buildx build web \ - -t registry.sciol.ac.cn/sciol/xyzen-web:test \ - -t registry.sciol.ac.cn/sciol/xyzen-web:${{ needs.setup.outputs.beta_tag }} \ - --push + - name: Build and push Web image + uses: docker/build-push-action@v6 + with: + context: ./web + push: true + tags: | + registry.sciol.ac.cn/sciol/xyzen-web:beta + registry.sciol.ac.cn/sciol/xyzen-web:${{ env.BETA_VERSION }} - # 部署阶段:等待所有构建完成后统一部署 + # Deploy: Wait for all builds to complete deploy: + name: Deploy runs-on: ubuntu-latest needs: [setup, build-service, build-web] + env: + BETA_VERSION: ${{ needs.setup.outputs.beta_version }} steps: - name: Download Let's Encrypt CA run: curl -o ca.crt https://letsencrypt.org/certs/isrgrootx1.pem - - name: Rolling update deployments + - name: Deploy to Kubernetes run: | kubectl \ --server=${{ secrets.SCIENCEOL_K8S_SERVER_URL }} \ --token=${{ secrets.SCIENCEOL_K8S_ADMIN_TOKEN }} \ --certificate-authority=ca.crt \ - set image deployment/xyzen -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_tag }} + set image deployment/xyzen -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ env.BETA_VERSION }} kubectl \ --server=${{ secrets.SCIENCEOL_K8S_SERVER_URL }} \ --token=${{ secrets.SCIENCEOL_K8S_ADMIN_TOKEN }} \ --certificate-authority=ca.crt \ - set image deployment/xyzen-web -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-web:${{ needs.setup.outputs.beta_tag }} + set image deployment/xyzen-web -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-web:${{ env.BETA_VERSION }} kubectl \ --server=${{ secrets.SCIENCEOL_K8S_SERVER_URL }} \ --token=${{ secrets.SCIENCEOL_K8S_ADMIN_TOKEN }} \ --certificate-authority=ca.crt \ - set image deployment/xyzen-celery -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_tag }} + set image deployment/xyzen-celery -n bohrium *=registry.sciol.ac.cn/sciol/xyzen-service:${{ env.BETA_VERSION }} - # 通知阶段:发送构建结果通知 + - name: Deployment Summary + run: | + echo "## 🧪 Beta Deployment Complete" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Component | Image |" >> $GITHUB_STEP_SUMMARY + echo "|-----------|-------|" >> $GITHUB_STEP_SUMMARY + echo "| Service | \`registry.sciol.ac.cn/sciol/xyzen-service:${{ env.BETA_VERSION }}\` |" >> $GITHUB_STEP_SUMMARY + echo "| Web | \`registry.sciol.ac.cn/sciol/xyzen-web:${{ env.BETA_VERSION }}\` |" >> $GITHUB_STEP_SUMMARY + + # Notify: Send build result notification notify: + name: Notify runs-on: ubuntu-latest needs: [setup, build-service, build-web, deploy] if: always() steps: - - name: Checkout code + - name: Checkout uses: actions/checkout@v4 - name: Calculate build duration - id: build_duration - shell: bash + id: duration run: | BUILD_START="${{ needs.setup.outputs.build_start }}" if [ -z "$BUILD_START" ]; then - echo "Warning: build_start is empty, using current time as fallback" BUILD_START=$(date '+%Y-%m-%d %H:%M:%S') fi @@ -140,10 +166,8 @@ jobs: MINUTES=$(((DURATION_SEC % 3600) / 60)) SECONDS=$((DURATION_SEC % 60)) - DURATION="${HOURS}h ${MINUTES}m ${SECONDS}s" - - echo "build_end=$BUILD_END" >> $GITHUB_ENV - echo "build_duration=$DURATION" >> $GITHUB_ENV + echo "build_end=$BUILD_END" >> $GITHUB_OUTPUT + echo "build_duration=${HOURS}h ${MINUTES}m ${SECONDS}s" >> $GITHUB_OUTPUT - name: Determine overall status id: status @@ -154,7 +178,7 @@ jobs: echo "status=failure" >> $GITHUB_OUTPUT fi - - name: Send build notification + - name: Send notification uses: ./.github/actions/email-notification with: status: ${{ steps.status.outputs.status }} @@ -165,20 +189,20 @@ jobs: recipient: ${{ secrets.SMTP_RECEIVER }} architecture: 'amd64' pr_number: 'N/A' - pr_title: 'Push to test' + pr_title: 'Beta Deploy ${{ needs.setup.outputs.beta_version }}' pr_url: '${{ github.server_url }}/${{ github.repository }}/commit/${{ github.sha }}' head_ref: ${{ github.ref_name }} base_ref: 'test' repo: ${{ github.repository }} run_id: ${{ github.run_id }} build_start: ${{ needs.setup.outputs.build_start }} - build_end: ${{ env.build_end }} - build_duration: ${{ env.build_duration }} + build_end: ${{ steps.duration.outputs.build_end }} + build_duration: ${{ steps.duration.outputs.build_duration }} commit_author: ${{ needs.setup.outputs.commit_author }} commit_email: ${{ needs.setup.outputs.commit_email }} commit_message: ${{ needs.setup.outputs.commit_message }} commit_sha: ${{ needs.setup.outputs.commit_sha }} commit_sha_short: ${{ needs.setup.outputs.commit_sha_short }} commit_date: ${{ needs.setup.outputs.commit_date }} - service_image: 'registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_tag }}' - web_image: 'registry.sciol.ac.cn/sciol/xyzen-web:${{ needs.setup.outputs.beta_tag }}' + service_image: 'registry.sciol.ac.cn/sciol/xyzen-service:${{ needs.setup.outputs.beta_version }}' + web_image: 'registry.sciol.ac.cn/sciol/xyzen-web:${{ needs.setup.outputs.beta_version }}' diff --git a/README.md b/README.md index 3416ab51..cdd0ad25 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Your next agent platform for multi-agent orchestration, real-time chat, and docu [![React](https://img.shields.io/badge/react-%2320232a.svg?style=flat&logo=react&logoColor=%2361DAFB)](https://reactjs.org/) [![npm version](https://img.shields.io/npm/v/@sciol/xyzen.svg)](https://www.npmjs.com/package/@sciol/xyzen) [![Pre-commit CI](https://github.com/ScienceOL/Xyzen/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/pre-commit.yaml) -[![Prod Build](https://github.com/ScienceOL/Xyzen/actions/workflows/prod-build.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/prod-build.yaml) +[![Release](https://github.com/ScienceOL/Xyzen/actions/workflows/release.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/release.yaml) [![Test Suite](https://github.com/ScienceOL/Xyzen/actions/workflows/test.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/test.yaml) [![codecov](https://codecov.io/github/ScienceOL/Xyzen/graph/badge.svg?token=91W3GO7CRI)](https://codecov.io/github/ScienceOL/Xyzen) @@ -31,17 +31,9 @@ Xyzen is an AI lab server built with FastAPI + LangGraph on the backend and Reac ## Getting Started -Xyzen uses Docker for all development to ensure consistency across environments and to manage required infrastructure services (PostgreSQL, Redis, Mosquitto, Casdoor). - ### Prerequisites - Docker and Docker Compose -- [uv](https://docs.astral.sh/uv/) for pre-commit hooks (Python tools) -- Node.js with Yarn (via [Corepack](https://nodejs.org/api/corepack.html)) for pre-commit hooks (Frontend tools) - -## Development Setup - -The easiest way to get started with Xyzen is using the containerized development environment. This automatically sets up all services (PostgreSQL, Mosquitto, Casdoor) and development tools. ### Quick Start @@ -52,69 +44,50 @@ The easiest way to get started with Xyzen is using the containerized development cd Xyzen ``` -2. Start the development environment: - - **On Unix/Linux/macOS:** +2. Create environment configuration: ```bash - ./launch/dev.sh - ``` - - **On Windows (PowerShell):** - - ```powershell - .\launch\dev.ps1 + cp docker/.env.example docker/.env.dev ``` - Or use the Makefile: +3. Configure your LLM provider in `docker/.env.dev`: ```bash - make dev # Start in foreground (shows logs) - make dev ARGS="-d" # Start in background (daemon mode) - make dev ARGS="-s" # Stop containers (without removal) - make dev ARGS="-e" # Stop and remove containers - ``` - -The script will automatically: - -- Check Docker and validate `.env.dev` file -- Set up global Sciol virtual environment at `~/.sciol/venv` -- Install and configure pre-commit hooks -- Create VS Code workspace configuration -- Start infrastructure services (PostgreSQL, Mosquitto, Casdoor) -- Launch development containers with hot reloading + # Enable providers (comma-separated): azure_openai,openai,google,qwen + XYZEN_LLM_providers=openai -### Container Development Options - -**Start in foreground (see logs):** + # OpenAI example + XYZEN_LLM_OpenAI_key=sk-your-api-key + XYZEN_LLM_OpenAI_endpoint=https://api.openai.com/v1 + XYZEN_LLM_OpenAI_deployment=gpt-4o + ``` -```bash -./launch/dev.sh -``` + See `docker/.env.example` for all available configuration options. -**Start in background:** +4. Start the development environment: -```bash -./launch/dev.sh -d -``` + ```bash + ./launch/dev.sh # Start in foreground (shows logs) + ./launch/dev.sh -d # Start in background (daemon mode) + ./launch/dev.sh -s # Stop containers + ./launch/dev.sh -e # Stop and remove containers + ``` -**Stop containers:** + Or use the Makefile: -```bash -./launch/dev.sh -s -``` + ```bash + make dev # Start in foreground + make dev ARGS="-d" # Start in background + ``` -**Stop and remove containers:** +The script will automatically set up all infrastructure services (PostgreSQL, Redis, Mosquitto, Casdoor) and launch development containers with hot reloading. -```bash -./launch/dev.sh -e -``` +## Development -**Show help:** +### Prerequisites for Contributing -```bash -./launch/dev.sh -h -``` +- [uv](https://docs.astral.sh/uv/) for Python tools and pre-commit hooks +- Node.js with Yarn (via [Corepack](https://nodejs.org/api/corepack.html)) for frontend tools ## AI Assistant Rules diff --git a/README_zh.md b/README_zh.md index 7b7e182b..db86f237 100644 --- a/README_zh.md +++ b/README_zh.md @@ -12,7 +12,7 @@ [![React](https://img.shields.io/badge/react-%2320232a.svg?style=flat&logo=react&logoColor=%2361DAFB)](https://reactjs.org/) [![npm version](https://img.shields.io/npm/v/@sciol/xyzen.svg)](https://www.npmjs.com/package/@sciol/xyzen) [![Pre-commit CI](https://github.com/ScienceOL/Xyzen/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/pre-commit.yaml) -[![Prod Build](https://github.com/ScienceOL/Xyzen/actions/workflows/prod-build.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/prod-build.yaml) +[![Release](https://github.com/ScienceOL/Xyzen/actions/workflows/release.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/release.yaml) [![Test Suite](https://github.com/ScienceOL/Xyzen/actions/workflows/test.yaml/badge.svg)](https://github.com/ScienceOL/Xyzen/actions/workflows/test.yaml) [![codecov](https://codecov.io/github/ScienceOL/Xyzen/graph/badge.svg?token=91W3GO7CRI)](https://codecov.io/github/ScienceOL/Xyzen) @@ -31,36 +31,63 @@ Xyzen 由 FastAPI + LangGraph 后端与 React + Zustand 前端构建,支持多 ## 快速开始 -开发依赖 Docker(PostgreSQL、Redis、Mosquitto、Casdoor)。 - ### 前置条件 - Docker 和 Docker Compose -- `uv`(用于 Python 工具链) -- Node.js + Yarn(Corepack) -### 启动开发环境 +### 启动步骤 -```bash -git clone https://github.com/ScienceOL/Xyzen.git -cd Xyzen -./launch/dev.sh -``` +1. 克隆仓库: -Windows(PowerShell): + ```bash + git clone https://github.com/ScienceOL/Xyzen.git + cd Xyzen + ``` -```powershell -.\launch\dev.ps1 -``` +2. 创建环境配置文件: -常用命令: + ```bash + cp docker/.env.example docker/.env.dev + ``` -```bash -make dev # 前台启动 -make dev ARGS="-d" # 后台启动 -make dev ARGS="-s" # 停止容器 -make dev ARGS="-e" # 停止并移除容器 -``` +3. 在 `docker/.env.dev` 中配置 LLM 模型: + + ```bash + # 启用的模型供应商(逗号分隔):azure_openai,openai,google,qwen + XYZEN_LLM_providers=openai + + # OpenAI 示例 + XYZEN_LLM_OpenAI_key=sk-your-api-key + XYZEN_LLM_OpenAI_endpoint=https://api.openai.com/v1 + XYZEN_LLM_OpenAI_deployment=gpt-4o + ``` + + 完整配置项请参考 `docker/.env.example`。 + +4. 启动开发环境: + + ```bash + ./launch/dev.sh # 前台启动(显示日志) + ./launch/dev.sh -d # 后台启动 + ./launch/dev.sh -s # 停止容器 + ./launch/dev.sh -e # 停止并移除容器 + ``` + + 或使用 Makefile: + + ```bash + make dev # 前台启动 + make dev ARGS="-d" # 后台启动 + ``` + +脚本会自动配置所有基础服务(PostgreSQL、Redis、Mosquitto、Casdoor)并启动带热重载的开发容器。 + +## 开发 + +### 贡献代码的前置条件 + +- [uv](https://docs.astral.sh/uv/)(Python 工具链和 pre-commit hooks) +- Node.js + Yarn(通过 [Corepack](https://nodejs.org/api/corepack.html),用于前端工具) ## AI 助手规则 From 1c4d901f62664b551ff7902786ebdf84d9707b96 Mon Sep 17 00:00:00 2001 From: "xinquiry(SII)" <100398322+xinquiry@users.noreply.github.com> Date: Thu, 22 Jan 2026 18:29:32 +0800 Subject: [PATCH 3/8] feat: tool cost system and PPTX image handling fixes (#193) * fix: prompt, factory * feat: enhanced ppt generation with image slides mode - Add image_slides mode for PPTX with full-bleed AI-generated images - Add ImageBlock.image_id field for referencing generated images - Add ImageSlideSpec for image-only slides - Add ImageFetcher service for fetching images from various sources - Reorganize knowledge module from single file to module structure - Move document utilities from app/mcp/ to app/tools/utils/documents/ - Resolve image_ids to storage URLs in async layer (operations.py) - Fix type errors and move tests to proper location Co-Authored-By: Claude * feat: implement the tool cost --------- Co-authored-by: Claude --- service/app/agents/factory.py | 17 +- service/app/agents/graph_builder.py | 28 +- service/app/core/chat/langchain.py | 6 +- service/app/core/chat/stream_handlers.py | 10 +- service/app/core/consume_strategy.py | 14 +- service/app/schemas/chat_event_payloads.py | 3 +- service/app/tasks/chat.py | 61 +- service/app/tools/builtin/image.py | 6 +- service/app/tools/builtin/knowledge.py | 491 ---------- .../app/tools/builtin/knowledge/__init__.py | 30 + .../tools/builtin/knowledge/help_content.py | 526 ++++++++++ .../app/tools/builtin/knowledge/operations.py | 344 +++++++ .../app/tools/builtin/knowledge/schemas.py | 95 ++ service/app/tools/builtin/knowledge/tools.py | 232 +++++ service/app/tools/cost.py | 57 ++ service/app/tools/registry.py | 35 +- service/app/tools/utils/documents/__init__.py | 83 ++ .../utils/documents/handlers.py} | 405 +++++++- .../tools/utils/documents/image_fetcher.py | 271 ++++++ .../utils/documents/spec.py} | 52 +- service/tests/unit/agents/test_factory.py | 132 +++ .../tests/unit/agents/test_graph_builder.py | 174 +++- .../unit/handler/mcp/test_file_handlers.py | 374 -------- .../unit/test_core/test_consume_strategy.py | 26 +- service/tests/unit/tools/__init__.py | 0 service/tests/unit/tools/test_cost.py | 173 ++++ service/tests/unit/tools/utils/__init__.py | 0 .../unit/tools/utils/documents/__init__.py | 0 .../tools/utils/documents/test_handlers.py | 903 ++++++++++++++++++ .../utils/documents/test_image_fetcher.py | 378 ++++++++ 30 files changed, 3987 insertions(+), 939 deletions(-) delete mode 100644 service/app/tools/builtin/knowledge.py create mode 100644 service/app/tools/builtin/knowledge/__init__.py create mode 100644 service/app/tools/builtin/knowledge/help_content.py create mode 100644 service/app/tools/builtin/knowledge/operations.py create mode 100644 service/app/tools/builtin/knowledge/schemas.py create mode 100644 service/app/tools/builtin/knowledge/tools.py create mode 100644 service/app/tools/cost.py create mode 100644 service/app/tools/utils/documents/__init__.py rename service/app/{mcp/file_handlers.py => tools/utils/documents/handlers.py} (71%) create mode 100644 service/app/tools/utils/documents/image_fetcher.py rename service/app/{mcp/document_spec.py => tools/utils/documents/spec.py} (79%) create mode 100644 service/tests/unit/agents/test_factory.py delete mode 100644 service/tests/unit/handler/mcp/test_file_handlers.py create mode 100644 service/tests/unit/tools/__init__.py create mode 100644 service/tests/unit/tools/test_cost.py create mode 100644 service/tests/unit/tools/utils/__init__.py create mode 100644 service/tests/unit/tools/utils/documents/__init__.py create mode 100644 service/tests/unit/tools/utils/documents/test_handlers.py create mode 100644 service/tests/unit/tools/utils/documents/test_image_fetcher.py diff --git a/service/app/agents/factory.py b/service/app/agents/factory.py index 6f38d1cf..53dcc82c 100644 --- a/service/app/agents/factory.py +++ b/service/app/agents/factory.py @@ -207,10 +207,11 @@ def _resolve_agent_config( def _inject_system_prompt(config_dict: dict[str, Any], system_prompt: str) -> dict[str, Any]: """ - Inject system_prompt into a react-style config. + Inject system_prompt into a graph config. - For configs using stdlib:react component, updates the config_overrides - to include the system_prompt. + Handles both: + 1. Component nodes with stdlib:react - updates config_overrides + 2. LLM nodes - updates prompt_template Args: config_dict: GraphConfig as dict @@ -224,8 +225,9 @@ def _inject_system_prompt(config_dict: dict[str, Any], system_prompt: str) -> di config = copy.deepcopy(config_dict) - # Find component nodes and inject system_prompt + # Find nodes and inject system_prompt (first matching node only) for node in config.get("nodes", []): + # Handle component nodes (existing behavior) if node.get("type") == "component": comp_config = node.get("component_config", {}) comp_ref = comp_config.get("component_ref", {}) @@ -234,6 +236,13 @@ def _inject_system_prompt(config_dict: dict[str, Any], system_prompt: str) -> di if comp_ref.get("key") == "react": overrides = comp_config.setdefault("config_overrides", {}) overrides["system_prompt"] = system_prompt + break + + # Handle LLM nodes + elif node.get("type") == "llm": + llm_config = node.get("llm_config", {}) + llm_config["prompt_template"] = system_prompt + break return config diff --git a/service/app/agents/graph_builder.py b/service/app/agents/graph_builder.py index c0726c15..cc70906c 100644 --- a/service/app/agents/graph_builder.py +++ b/service/app/agents/graph_builder.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Annotated, Any from jinja2 import Template -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages import AIMessage, BaseMessage, SystemMessage from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from langgraph.prebuilt import ToolNode, tools_condition @@ -344,11 +344,15 @@ async def llm_node(state: StateDict | BaseModel) -> StateDict: # Convert state to dict for template rendering (but we already have messages) state_dict = self._state_to_dict(state) - # Render prompt template - prompt = self._render_template(llm_config.prompt_template, state_dict) + # Build messages for LLM - start with conversation messages + llm_messages = list(messages) - # Build messages for LLM - llm_messages = list(messages) + [HumanMessage(content=prompt)] + # Prepend system prompt if configured (uses Jinja2 template rendering) + if llm_config.prompt_template: + rendered_prompt = self._render_template(llm_config.prompt_template, state_dict) + # Filter any existing SystemMessage and prepend ours + llm_messages = [m for m in llm_messages if not isinstance(m, SystemMessage)] + llm_messages = [SystemMessage(content=rendered_prompt)] + llm_messages # Invoke LLM (using pre-created configured_llm) response = await configured_llm.ainvoke(llm_messages) @@ -382,11 +386,15 @@ async def llm_node(state: StateDict | BaseModel) -> StateDict: logger.info(f"[LLM Node: {config.id}] Text output completed, tool_calls: {len(tool_calls)}") - # Build AIMessage preserving tool_calls - ai_message = AIMessage( - content=content_str, - tool_calls=tool_calls, - ) + # Preserve the original response message to retain provider-specific metadata + # (e.g., Gemini thought signatures needed for tool calling). + if isinstance(response, BaseMessage): + ai_message = response + else: + ai_message = AIMessage( + content=content_str, + tool_calls=tool_calls, + ) return { llm_config.output_key: content_str, diff --git a/service/app/core/chat/langchain.py b/service/app/core/chat/langchain.py index 31395d5d..5826cfc4 100644 --- a/service/app/core/chat/langchain.py +++ b/service/app/core/chat/langchain.py @@ -492,9 +492,11 @@ async def _handle_updates_mode( logger.info(f"[ToolEvent] Skipping historical tool response: {tool_call_id}") continue ctx.emitted_tool_result_ids.add(tool_call_id) - result = format_tool_result(msg.content, tool_name) + # Get raw content before formatting (for cost calculation) + raw_content = msg.content + result = format_tool_result(raw_content, tool_name) logger.info(f"[ToolEvent] >>> Emitting tool_call_response for {tool_call_id}") - yield ToolEventHandler.create_tool_response_event(tool_call_id, result) + yield ToolEventHandler.create_tool_response_event(tool_call_id, result, raw_result=raw_content) last_message = messages[-1] diff --git a/service/app/core/chat/stream_handlers.py b/service/app/core/chat/stream_handlers.py index af23f2a7..e4ab358e 100644 --- a/service/app/core/chat/stream_handlers.py +++ b/service/app/core/chat/stream_handlers.py @@ -107,15 +107,19 @@ def create_tool_request_event(tool_call: dict[str, Any]) -> StreamingEvent: @staticmethod def create_tool_response_event( - tool_call_id: str, result: str, status: str = ToolCallStatus.COMPLETED + tool_call_id: str, + result: str, + status: str = ToolCallStatus.COMPLETED, + raw_result: str | dict | list | None = None, ) -> StreamingEvent: """ Create a tool call response event. Args: tool_call_id: ID of the tool call - result: Formatted result string + result: Formatted result string for display status: Tool call status + raw_result: Raw result for cost calculation (optional, unformatted) Returns: StreamingEvent for tool call response @@ -125,6 +129,8 @@ def create_tool_response_event( "status": status, "result": result, } + if raw_result is not None: + data["raw_result"] = raw_result return {"type": ChatEventType.TOOL_CALL_RESPONSE, "data": data} diff --git a/service/app/core/consume_strategy.py b/service/app/core/consume_strategy.py index 74369400..a7cd275b 100644 --- a/service/app/core/consume_strategy.py +++ b/service/app/core/consume_strategy.py @@ -24,7 +24,7 @@ class ConsumptionContext: output_tokens: int = 0 total_tokens: int = 0 content_length: int = 0 - generated_files_count: int = 0 + tool_costs: int = 0 @dataclass @@ -64,13 +64,12 @@ class TierBasedConsumptionStrategy(ConsumptionStrategy): Design decisions: - LITE tier (rate 0.0) = completely free - - Tier rate multiplies ALL costs (base + tokens + files) + - Tier rate multiplies ALL costs (base + tokens + tool costs) """ BASE_COST = 1 INPUT_TOKEN_RATE = 0.2 / 1000 # per token OUTPUT_TOKEN_RATE = 1 / 1000 # per token - FILE_GENERATION_COST = 10 def calculate(self, context: ConsumptionContext) -> ConsumptionResult: """Calculate consumption with tier-based multiplier. @@ -90,7 +89,7 @@ def calculate(self, context: ConsumptionContext) -> ConsumptionResult: breakdown={ "base_cost": 0, "token_cost": 0, - "file_cost": 0, + "tool_costs": 0, "tier_rate": 0.0, "tier": context.model_tier.value if context.model_tier else "lite", "note": "LITE tier - free usage", @@ -99,10 +98,9 @@ def calculate(self, context: ConsumptionContext) -> ConsumptionResult: # Calculate base token cost token_cost = context.input_tokens * self.INPUT_TOKEN_RATE + context.output_tokens * self.OUTPUT_TOKEN_RATE - file_cost = context.generated_files_count * self.FILE_GENERATION_COST - # Tier rate multiplies ALL costs - base_amount = self.BASE_COST + token_cost + file_cost + # Tier rate multiplies ALL costs (including tool costs) + base_amount = self.BASE_COST + token_cost + context.tool_costs final_amount = int(base_amount * tier_rate) return ConsumptionResult( @@ -110,7 +108,7 @@ def calculate(self, context: ConsumptionContext) -> ConsumptionResult: breakdown={ "base_cost": self.BASE_COST, "token_cost": token_cost, - "file_cost": file_cost, + "tool_costs": context.tool_costs, "pre_multiplier_total": base_amount, "tier_rate": tier_rate, "tier": context.model_tier.value if context.model_tier else "default", diff --git a/service/app/schemas/chat_event_payloads.py b/service/app/schemas/chat_event_payloads.py index 7ad6b913..4f13d652 100644 --- a/service/app/schemas/chat_event_payloads.py +++ b/service/app/schemas/chat_event_payloads.py @@ -92,7 +92,8 @@ class ToolCallResponseData(TypedDict): toolCallId: str status: str - result: str + result: str # Formatted result for display + raw_result: NotRequired[str | dict | list] # Raw result for cost calculation error: NotRequired[str] diff --git a/service/app/tasks/chat.py b/service/app/tasks/chat.py index d6c2993a..57143ce2 100644 --- a/service/app/tasks/chat.py +++ b/service/app/tasks/chat.py @@ -23,6 +23,7 @@ from app.repos.session import SessionRepository from app.schemas.chat_event_payloads import CitationData from app.schemas.chat_event_types import ChatEventType +from app.tools.cost import calculate_tool_cost logger = logging.getLogger(__name__) @@ -177,6 +178,10 @@ async def _process_chat_message_async( output_tokens: int = 0 total_tokens: int = 0 + # Tool cost tracking + tool_costs_total = 0 + tool_call_data: dict[str, dict[str, Any]] = {} # tool_call_id -> {name, args} + # Agent run tracking (for new timeline-based persistence) agent_run_id: UUID | None = None agent_run_start_time: float | None = None @@ -305,9 +310,24 @@ async def _process_chat_message_async( await publisher.publish(json.dumps(stream_event)) elif stream_event["type"] == ChatEventType.TOOL_CALL_REQUEST: + # Store tool call data for cost calculation + req = stream_event["data"] + tool_call_id = req.get("id") + tool_name = req.get("name", "") + if tool_call_id: + # Parse arguments (may be JSON string) + raw_args = req.get("arguments", {}) + if isinstance(raw_args, str): + try: + parsed_args = json.loads(raw_args) + except json.JSONDecodeError: + parsed_args = {} + else: + parsed_args = raw_args or {} + tool_call_data[tool_call_id] = {"name": tool_name, "args": parsed_args} + # Persist tool call request try: - req = stream_event["data"] tool_message = MessageCreate( role="tool", content=json.dumps( @@ -331,11 +351,42 @@ async def _process_chat_message_async( await publisher.publish(json.dumps(stream_event)) elif stream_event["type"] == ChatEventType.TOOL_CALL_RESPONSE: + resp = stream_event["data"] + tool_call_id = resp.get("toolCallId") + + # Calculate tool cost using stored data from TOOL_CALL_REQUEST + if tool_call_id and tool_call_id in tool_call_data: + stored = tool_call_data[tool_call_id] + tool_name = stored.get("name", "") + args = stored.get("args", {}) + # Use raw_result for cost calculation (unformatted) + result = resp.get("raw_result") + # Parse result if it's a JSON string + if isinstance(result, str): + try: + result = json.loads(result) + except json.JSONDecodeError: + result = None + # Only dict results are supported for cost calculation + if not isinstance(result, dict): + result = None + + # Only charge for successful tool executions + tool_failed = ( + resp.get("status") == "error" + or resp.get("error") is not None + or (isinstance(result, dict) and result.get("success") is False) + ) + if tool_failed: + logger.info(f"Tool {tool_name} failed, not charging") + else: + cost = calculate_tool_cost(tool_name, args, result) + if cost > 0: + tool_costs_total += cost + logger.info(f"Tool {tool_name} cost: {cost} (total: {tool_costs_total})") + # Persist tool call response try: - resp = stream_event["data"] - tool_call_id = resp.get("toolCallId") - # Only persist if toolCallId is valid - skip otherwise if not tool_call_id or not isinstance(tool_call_id, str): logger.warning( @@ -601,7 +652,7 @@ async def _process_chat_message_async( output_tokens=output_tokens, total_tokens=total_tokens, content_length=len(full_content), - generated_files_count=generated_files_count, + tool_costs=tool_costs_total, ) result = ConsumptionCalculator.calculate(consume_context) total_cost = result.amount diff --git a/service/app/tools/builtin/image.py b/service/app/tools/builtin/image.py index 94d23c81..c5985808 100644 --- a/service/app/tools/builtin/image.py +++ b/service/app/tools/builtin/image.py @@ -521,7 +521,8 @@ async def generate_image_placeholder( "Generate an image based on a text description. " "Provide a detailed prompt describing the desired image. " "To modify or generate based on a previous image, pass the 'image_id' from a previous generate_image result. " - "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image." + "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image. " + "TIP: You can use 'image_id' values when creating PPTX presentations with knowledge_write - see knowledge_help(topic='image_slides') for details." ), args_schema=GenerateImageInput, coroutine=generate_image_placeholder, @@ -574,7 +575,8 @@ async def generate_image_bound( "Generate an image based on a text description. " "Provide a detailed prompt describing the desired image including style, colors, composition, and subject. " "To modify or generate based on a previous image, pass the 'image_id' from a previous generate_image result. " - "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image to the user." + "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image to the user. " + "TIP: You can use 'image_id' values when creating beautiful PPTX presentations with knowledge_write in image_slides mode - call knowledge_help(topic='image_slides') for the full workflow." ), args_schema=GenerateImageInput, coroutine=generate_image_bound, diff --git a/service/app/tools/builtin/knowledge.py b/service/app/tools/builtin/knowledge.py deleted file mode 100644 index 4e656dc9..00000000 --- a/service/app/tools/builtin/knowledge.py +++ /dev/null @@ -1,491 +0,0 @@ -""" -Knowledge Base Tools - -LangChain tools for knowledge base file operations. -These tools require runtime context (user_id, knowledge_set_id) to function. - -Unlike web search which works context-free, knowledge tools are created per-agent -with the agent's knowledge_set_id bound at creation time. -""" - -from __future__ import annotations - -import io -import logging -import mimetypes -from datetime import datetime, timezone -from typing import Any -from uuid import UUID - -from langchain_core.tools import BaseTool, StructuredTool -from pydantic import BaseModel, Field -from sqlmodel.ext.asyncio.session import AsyncSession - -from app.core.storage import FileCategory, FileScope, generate_storage_key, get_storage_service -from app.infra.database import AsyncSessionLocal -from app.models.file import FileCreate -from app.repos.file import FileRepository -from app.repos.knowledge_set import KnowledgeSetRepository - -logger = logging.getLogger(__name__) - - -# --- Input Schemas --- - - -class KnowledgeListFilesInput(BaseModel): - """Input schema for list_files tool - no parameters needed.""" - - pass - - -class KnowledgeReadFileInput(BaseModel): - """Input schema for read_file tool.""" - - filename: str = Field( - description=( - "The name of the file to read from the knowledge base. " - "Supported formats: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, " - "images (PNG/JPG/GIF/WEBP with OCR), and plain text files." - ) - ) - - -class KnowledgeWriteFileInput(BaseModel): - """Input schema for write_file tool.""" - - filename: str = Field( - description=( - "The name of the file to create or update. Use appropriate extensions: " - ".txt, .md (plain text), .pdf (PDF document), .docx (Word), " - ".xlsx (Excel), .pptx (PowerPoint), .json, .yaml, .xml, .html." - ) - ) - content: str = Field( - description=( - "The content to write. Can be plain text (creates simple documents) or " - "a JSON specification for production-quality documents:\n\n" - "**For PDF/DOCX (DocumentSpec JSON):**\n" - '{"title": "My Report", "author": "Name", "content": [\n' - ' {"type": "heading", "content": "Section 1", "level": 1},\n' - ' {"type": "text", "content": "Paragraph text here"},\n' - ' {"type": "list", "items": ["Item 1", "Item 2"], "ordered": false},\n' - ' {"type": "table", "headers": ["Col1", "Col2"], "rows": [["A", "B"]]},\n' - ' {"type": "page_break"}\n' - "]}\n\n" - "**For XLSX (SpreadsheetSpec JSON):**\n" - '{"sheets": [{"name": "Data", "headers": ["Name", "Value"], ' - '"data": [["A", 1], ["B", 2]], "freeze_header": true}]}\n\n' - "**For PPTX (PresentationSpec JSON):**\n" - '{"title": "My Presentation", "slides": [\n' - ' {"layout": "title", "title": "Welcome", "subtitle": "Intro"},\n' - ' {"layout": "title_content", "title": "Slide 2", ' - '"content": [{"type": "list", "items": ["Point 1", "Point 2"]}], ' - '"notes": "Speaker notes here"}\n' - "]}" - ) - ) - - -class KnowledgeSearchFilesInput(BaseModel): - """Input schema for search_files tool.""" - - query: str = Field(description="Search term to find files by name.") - - -# --- Helper Functions --- - - -async def _get_files_in_knowledge_set(db: AsyncSession, user_id: str, knowledge_set_id: UUID) -> list[UUID]: - """Get all file IDs in a knowledge set.""" - knowledge_set_repo = KnowledgeSetRepository(db) - - # Validate access - try: - await knowledge_set_repo.validate_access(user_id, knowledge_set_id) - except ValueError as e: - raise ValueError(f"Access denied: {e}") - - # Get file IDs - file_ids = await knowledge_set_repo.get_files_in_knowledge_set(knowledge_set_id) - return file_ids - - -# --- Tool Implementation Functions --- - - -async def _list_files(user_id: str, knowledge_set_id: UUID) -> dict[str, Any]: - """List all files in the knowledge set.""" - try: - async with AsyncSessionLocal() as db: - file_repo = FileRepository(db) - - try: - file_ids = await _get_files_in_knowledge_set(db, user_id, knowledge_set_id) - except ValueError as e: - return {"error": str(e), "success": False} - - # Fetch file objects - files = [] - for file_id in file_ids: - file = await file_repo.get_file_by_id(file_id) - if file and not file.is_deleted: - files.append(file) - - # Format output - entries: list[str] = [] - for f in files: - entries.append(f"[FILE] {f.original_filename} (ID: {f.id})") - - return { - "success": True, - "knowledge_set_id": str(knowledge_set_id), - "entries": entries, - "count": len(entries), - } - - except Exception as e: - logger.error(f"Error listing files: {e}") - return {"error": f"Internal error: {e!s}", "success": False} - - -async def _read_file(user_id: str, knowledge_set_id: UUID, filename: str) -> dict[str, Any]: - """Read content of a file from the knowledge set.""" - from app.mcp.file_handlers import FileHandlerFactory - - try: - # Normalize filename - filename = filename.strip("/").split("/")[-1] - - async with AsyncSessionLocal() as db: - file_repo = FileRepository(db) - target_file = None - - try: - file_ids = await _get_files_in_knowledge_set(db, user_id, knowledge_set_id) - except ValueError as e: - return {"error": str(e), "success": False} - - # Find file by name - for file_id in file_ids: - file = await file_repo.get_file_by_id(file_id) - if file and file.original_filename == filename and not file.is_deleted: - target_file = file - break - - if not target_file: - return {"error": f"File '{filename}' not found in knowledge set.", "success": False} - - # Download content - storage = get_storage_service() - buffer = io.BytesIO() - await storage.download_file(target_file.storage_key, buffer) - file_bytes = buffer.getvalue() - - # Use handler to process content (text mode only for LangChain tools) - handler = FileHandlerFactory.get_handler(target_file.original_filename) - - try: - result = handler.read_content(file_bytes, mode="text") - return { - "success": True, - "filename": target_file.original_filename, - "content": result, - "size_bytes": target_file.file_size, - } - except Exception as e: - return {"error": f"Error parsing file: {e!s}", "success": False} - - except Exception as e: - logger.error(f"Error reading file: {e}") - return {"error": f"Internal error: {e!s}", "success": False} - - -async def _write_file(user_id: str, knowledge_set_id: UUID, filename: str, content: str) -> dict[str, Any]: - """Create or update a file in the knowledge set.""" - from app.mcp.file_handlers import FileHandlerFactory - - try: - filename = filename.strip("/").split("/")[-1] - - async with AsyncSessionLocal() as db: - file_repo = FileRepository(db) - knowledge_set_repo = KnowledgeSetRepository(db) - storage = get_storage_service() - - try: - file_ids = await _get_files_in_knowledge_set(db, user_id, knowledge_set_id) - except ValueError as e: - return {"error": str(e), "success": False} - - # Check if file exists - existing_file = None - for file_id in file_ids: - file = await file_repo.get_file_by_id(file_id) - if file and file.original_filename == filename and not file.is_deleted: - existing_file = file - break - - # Determine content type - content_type, _ = mimetypes.guess_type(filename) - if not content_type: - if filename.endswith(".docx"): - content_type = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - elif filename.endswith(".xlsx"): - content_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" - elif filename.endswith(".pptx"): - content_type = "application/vnd.openxmlformats-officedocument.presentationml.presentation" - elif filename.endswith(".pdf"): - content_type = "application/pdf" - else: - content_type = "text/plain" - - # Use handler to create content bytes - handler = FileHandlerFactory.get_handler(filename) - encoded_content = handler.create_content(content) - - new_key = generate_storage_key(user_id, filename, FileScope.PRIVATE) - data = io.BytesIO(encoded_content) - file_size_bytes = len(encoded_content) - - await storage.upload_file(data, new_key, content_type=content_type) - - if existing_file: - # Update existing - existing_file.storage_key = new_key - existing_file.file_size = file_size_bytes - existing_file.content_type = content_type - existing_file.updated_at = datetime.now(timezone.utc) - db.add(existing_file) - await db.commit() - return {"success": True, "message": f"Updated file: {filename}"} - else: - # Create new and link - new_file = FileCreate( - user_id=user_id, - folder_id=None, - original_filename=filename, - storage_key=new_key, - file_size=file_size_bytes, - content_type=content_type, - scope=FileScope.PRIVATE, - category=FileCategory.DOCUMENT, - ) - created_file = await file_repo.create_file(new_file) - await knowledge_set_repo.link_file_to_knowledge_set(created_file.id, knowledge_set_id) - await db.commit() - return {"success": True, "message": f"Created file: {filename}"} - - except Exception as e: - logger.error(f"Error writing file: {e}") - return {"error": f"Internal error: {e!s}", "success": False} - - -async def _search_files(user_id: str, knowledge_set_id: UUID, query: str) -> dict[str, Any]: - """Search for files by name in the knowledge set.""" - try: - async with AsyncSessionLocal() as db: - file_repo = FileRepository(db) - matches: list[str] = [] - - try: - file_ids = await _get_files_in_knowledge_set(db, user_id, knowledge_set_id) - except ValueError as e: - return {"error": str(e), "success": False} - - for file_id in file_ids: - file = await file_repo.get_file_by_id(file_id) - if file and not file.is_deleted and query.lower() in file.original_filename.lower(): - matches.append(f"{file.original_filename} (ID: {file.id})") - - return { - "success": True, - "query": query, - "matches": matches, - "count": len(matches), - } - - except Exception as e: - logger.error(f"Error searching files: {e}") - return {"error": f"Internal error: {e!s}", "success": False} - - -# --- Tool Factory --- - - -def create_knowledge_tools() -> dict[str, BaseTool]: - """ - Create knowledge tools with placeholder implementations. - - Note: Knowledge tools require runtime context (user_id, knowledge_set_id). - The actual tool instances are created per-agent with context bound. - This function returns template tools for the registry. - - Returns: - Dict mapping tool_id to BaseTool placeholder instances. - """ - # These are placeholder tools - actual execution requires context binding - # See create_knowledge_tools_for_agent() for runtime creation - - tools: dict[str, BaseTool] = {} - - # List files tool - async def list_files_placeholder() -> dict[str, Any]: - return {"error": "Knowledge tools require agent context binding", "success": False} - - tools["knowledge_list"] = StructuredTool( - name="knowledge_list", - description=( - "List all files in the agent's knowledge base. Returns a list of filenames " - "that can be read or searched. Use this first to discover available files." - ), - args_schema=KnowledgeListFilesInput, - coroutine=list_files_placeholder, - ) - - # Read file tool - async def read_file_placeholder(filename: str) -> dict[str, Any]: - return {"error": "Knowledge tools require agent context binding", "success": False} - - tools["knowledge_read"] = StructuredTool( - name="knowledge_read", - description=( - "Read the content of a file from the agent's knowledge base. " - "Supports: PDF (text + tables), DOCX (text + tables), XLSX (all sheets), " - "PPTX (text + speaker notes), HTML (text extraction), JSON/YAML/XML (formatted), " - "images (OCR text extraction from PNG/JPG/GIF/WEBP), and plain text files. " - "Use knowledge_list first to see available files." - ), - args_schema=KnowledgeReadFileInput, - coroutine=read_file_placeholder, - ) - - # Write file tool - async def write_file_placeholder(filename: str, content: str) -> dict[str, Any]: - return {"error": "Knowledge tools require agent context binding", "success": False} - - tools["knowledge_write"] = StructuredTool( - name="knowledge_write", - description=( - "Create or update a file in the agent's knowledge base. " - "Supports: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, and plain text. " - "For production-quality documents (PDF/DOCX/XLSX/PPTX), provide a JSON " - "specification with structured content (headings, lists, tables, etc.) " - "instead of plain text. See content field description for JSON schema examples." - ), - args_schema=KnowledgeWriteFileInput, - coroutine=write_file_placeholder, - ) - - # Search files tool - async def search_files_placeholder(query: str) -> dict[str, Any]: - return {"error": "Knowledge tools require agent context binding", "success": False} - - tools["knowledge_search"] = StructuredTool( - name="knowledge_search", - description=( - "Search for files by name in the agent's knowledge base. Returns matching filenames that can then be read." - ), - args_schema=KnowledgeSearchFilesInput, - coroutine=search_files_placeholder, - ) - - return tools - - -def create_knowledge_tools_for_agent(user_id: str, knowledge_set_id: UUID) -> list[BaseTool]: - """ - Create knowledge tools bound to a specific agent's context. - - This creates actual working tools with user_id and knowledge_set_id - captured in closures. - - Args: - user_id: The user ID for access control - knowledge_set_id: The knowledge set ID to operate on - - Returns: - List of BaseTool instances with context bound - """ - tools: list[BaseTool] = [] - - # List files tool - async def list_files_bound() -> dict[str, Any]: - return await _list_files(user_id, knowledge_set_id) - - tools.append( - StructuredTool( - name="knowledge_list", - description=( - "List all files in your knowledge base. Returns filenames that can be read or searched. " - "Use this first to discover available files." - ), - args_schema=KnowledgeListFilesInput, - coroutine=list_files_bound, - ) - ) - - # Read file tool - async def read_file_bound(filename: str) -> dict[str, Any]: - return await _read_file(user_id, knowledge_set_id, filename) - - tools.append( - StructuredTool( - name="knowledge_read", - description=( - "Read the content of a file from your knowledge base. " - "Supports: PDF (text + tables), DOCX (text + tables), XLSX (all sheets), " - "PPTX (text + speaker notes), HTML (text extraction), JSON/YAML/XML (formatted), " - "images (OCR text extraction from PNG/JPG/GIF/WEBP), and plain text files. " - "Use knowledge_list first to see available files." - ), - args_schema=KnowledgeReadFileInput, - coroutine=read_file_bound, - ) - ) - - # Write file tool - async def write_file_bound(filename: str, content: str) -> dict[str, Any]: - return await _write_file(user_id, knowledge_set_id, filename, content) - - tools.append( - StructuredTool( - name="knowledge_write", - description=( - "Create or update a file in your knowledge base. " - "Supports: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, and plain text. " - "For production-quality documents (PDF/DOCX/XLSX/PPTX), provide a JSON " - "specification with structured content (headings, lists, tables, etc.) " - "instead of plain text. See content field description for JSON schema examples." - ), - args_schema=KnowledgeWriteFileInput, - coroutine=write_file_bound, - ) - ) - - # Search files tool - async def search_files_bound(query: str) -> dict[str, Any]: - return await _search_files(user_id, knowledge_set_id, query) - - tools.append( - StructuredTool( - name="knowledge_search", - description=( - "Search for files by name in your knowledge base. Returns matching filenames that can then be read." - ), - args_schema=KnowledgeSearchFilesInput, - coroutine=search_files_bound, - ) - ) - - return tools - - -__all__ = [ - "create_knowledge_tools", - "create_knowledge_tools_for_agent", - "KnowledgeListFilesInput", - "KnowledgeReadFileInput", - "KnowledgeWriteFileInput", - "KnowledgeSearchFilesInput", -] diff --git a/service/app/tools/builtin/knowledge/__init__.py b/service/app/tools/builtin/knowledge/__init__.py new file mode 100644 index 00000000..b2187d03 --- /dev/null +++ b/service/app/tools/builtin/knowledge/__init__.py @@ -0,0 +1,30 @@ +""" +Knowledge Base Tools for LangChain Agents. + +This module provides tools for knowledge base file operations. +These tools require runtime context (user_id, knowledge_set_id) to function. + +Unlike web search which works context-free, knowledge tools are created per-agent +with the agent's knowledge_set_id bound at creation time. +""" + +from __future__ import annotations + +from .schemas import ( + KnowledgeHelpInput, + KnowledgeListFilesInput, + KnowledgeReadFileInput, + KnowledgeSearchFilesInput, + KnowledgeWriteFileInput, +) +from .tools import create_knowledge_tools, create_knowledge_tools_for_agent + +__all__ = [ + "create_knowledge_tools", + "create_knowledge_tools_for_agent", + "KnowledgeListFilesInput", + "KnowledgeReadFileInput", + "KnowledgeWriteFileInput", + "KnowledgeSearchFilesInput", + "KnowledgeHelpInput", +] diff --git a/service/app/tools/builtin/knowledge/help_content.py b/service/app/tools/builtin/knowledge/help_content.py new file mode 100644 index 00000000..bb026be4 --- /dev/null +++ b/service/app/tools/builtin/knowledge/help_content.py @@ -0,0 +1,526 @@ +""" +Help content constants for knowledge tools. + +Contains all help text and documentation for knowledge base operations. +""" + +from __future__ import annotations + +from typing import Any + +KNOWLEDGE_HELP_OVERVIEW = """ +# Knowledge Base Tools - Quick Reference + +## Available Tools +- **knowledge_list**: List all files in your knowledge base +- **knowledge_read**: Read content from a file +- **knowledge_write**: Create or update files (supports rich documents) +- **knowledge_search**: Search files by name +- **knowledge_help**: Get detailed usage guides (this tool) + +## Supported File Types +- **Documents**: PDF, DOCX, PPTX, XLSX +- **Data**: JSON, YAML, XML +- **Web**: HTML +- **Text**: TXT, MD, CSV + +## Quick Start +1. Use `knowledge_list` to see available files +2. Use `knowledge_write` with plain text for simple files +3. Use `knowledge_write` with JSON spec for rich documents (call `knowledge_help` with topic='pptx' for examples) + +## Creating Beautiful Presentations with AI Images +For stunning presentations with AI-generated slides: +1. Use `generate_image` to create each slide as an image (use 16:9 aspect ratio) +2. Collect the `image_id` values from each generation +3. Use `knowledge_write` with `mode: "image_slides"` to assemble into PPTX + +Call `knowledge_help(topic='image_slides')` for detailed workflow and examples. + +For detailed help on a specific topic, call knowledge_help with topic='pptx', 'pdf', 'xlsx', 'images', 'tables', 'image_slides', or 'all'. +""" + +KNOWLEDGE_HELP_PPTX = """ +# PPTX (PowerPoint) Generation Guide + +## Basic Structure +```json +{ + "title": "Presentation Title", + "author": "Author Name", + "slides": [ + { "layout": "...", "title": "...", "content": [...] } + ] +} +``` + +## Slide Layouts +- `title` - Title slide with title and subtitle +- `title_content` - Title with content area (most common) +- `section` - Section header +- `two_column` - Two column layout +- `comparison` - Side-by-side comparison +- `title_only` - Title without content placeholder +- `blank` - Empty slide + +## Content Block Types + +### Text +```json +{"type": "text", "content": "Your paragraph text here", "style": "normal"} +``` +Styles: `normal`, `bold`, `italic`, `code` + +### Heading +```json +{"type": "heading", "content": "Section Title", "level": 2} +``` +Levels 1-6 (1 is largest) + +### List +```json +{"type": "list", "items": ["Point 1", "Point 2", "Point 3"], "ordered": false} +``` +Set `ordered: true` for numbered lists + +### Table +```json +{ + "type": "table", + "headers": ["Column 1", "Column 2", "Column 3"], + "rows": [ + ["Row 1 A", "Row 1 B", "Row 1 C"], + ["Row 2 A", "Row 2 B", "Row 2 C"] + ] +} +``` + +### Image +```json +{ + "type": "image", + "url": "https://example.com/chart.png", + "caption": "Figure 1: Sales Chart", + "width": 400 +} +``` +- `url`: HTTP URL, base64 data URL, or storage:// path +- `image_id`: UUID from generate_image tool (alternative to url) +- `caption`: Optional text below image +- `width`: Optional width in points (72 points = 1 inch) + +## Complete Example +```json +{ + "title": "Q4 Business Review", + "slides": [ + { + "layout": "title", + "title": "Q4 2024 Review", + "subtitle": "Sales Department" + }, + { + "layout": "title_content", + "title": "Revenue Summary", + "content": [ + {"type": "heading", "content": "Key Metrics", "level": 2}, + {"type": "table", "headers": ["Region", "Q3", "Q4", "Growth"], "rows": [ + ["North America", "$1.2M", "$1.5M", "+25%"], + ["Europe", "$800K", "$1.1M", "+37%"] + ]}, + {"type": "text", "content": "All regions exceeded targets.", "style": "bold"} + ], + "notes": "Emphasize the European growth story" + }, + { + "layout": "title_content", + "title": "Visual Analysis", + "content": [ + {"type": "image", "url": "https://example.com/chart.png", "caption": "Revenue by Region"} + ] + } + ] +} +``` +""" + +KNOWLEDGE_HELP_PDF_DOCX = """ +# PDF/DOCX Generation Guide + +## Basic Structure +```json +{ + "title": "Document Title", + "author": "Author Name", + "subject": "Document Subject", + "page_size": "letter", + "content": [...] +} +``` + +Page sizes: `letter`, `A4`, `legal` + +## Content Block Types + +### Heading +```json +{"type": "heading", "content": "Chapter Title", "level": 1} +``` + +### Text +```json +{"type": "text", "content": "Paragraph text here", "style": "normal"} +``` +Styles: `normal`, `bold`, `italic`, `code` + +### List +```json +{"type": "list", "items": ["Item 1", "Item 2"], "ordered": false} +``` + +### Table +```json +{ + "type": "table", + "headers": ["Name", "Value"], + "rows": [["Item A", "100"], ["Item B", "200"]] +} +``` + +### Page Break +```json +{"type": "page_break"} +``` + +## Example +```json +{ + "title": "Monthly Report", + "author": "Analytics Team", + "content": [ + {"type": "heading", "content": "Executive Summary", "level": 1}, + {"type": "text", "content": "This report covers..."}, + {"type": "heading", "content": "Key Findings", "level": 2}, + {"type": "list", "items": ["Revenue up 15%", "Costs down 8%"], "ordered": false}, + {"type": "page_break"}, + {"type": "heading", "content": "Detailed Analysis", "level": 1}, + {"type": "table", "headers": ["Metric", "Value"], "rows": [["Sales", "$1.5M"]]} + ] +} +``` +""" + +KNOWLEDGE_HELP_XLSX = """ +# XLSX (Excel) Generation Guide + +## Basic Structure +```json +{ + "sheets": [ + { + "name": "Sheet Name", + "headers": ["Col1", "Col2"], + "data": [[...], [...]], + "freeze_header": true + } + ] +} +``` + +## Sheet Properties +- `name`: Sheet tab name +- `headers`: Optional column headers (styled with blue background) +- `data`: 2D array of cell values (strings, numbers, null) +- `freeze_header`: Freeze the header row for scrolling (default: true) + +## Example: Multi-Sheet Workbook +```json +{ + "sheets": [ + { + "name": "Sales Data", + "headers": ["Product", "Q1", "Q2", "Q3", "Q4", "Total"], + "data": [ + ["Widget A", 100, 150, 200, 250, 700], + ["Widget B", 80, 90, 110, 130, 410], + ["Widget C", 50, 60, 70, 80, 260] + ], + "freeze_header": true + }, + { + "name": "Summary", + "headers": ["Metric", "Value"], + "data": [ + ["Total Revenue", 1370], + ["Average per Product", 456.67], + ["Best Performer", "Widget A"] + ] + } + ] +} +``` +""" + +KNOWLEDGE_HELP_IMAGES = """ +# Image Embedding Guide + +## Supported in PPTX Content Blocks + +### Image Block Structure +```json +{ + "type": "image", + "url": "...", + "caption": "Optional caption text", + "width": 400 +} +``` + +## URL Formats + +### HTTP/HTTPS URLs +```json +{"type": "image", "url": "https://example.com/chart.png"} +``` + +### Base64 Data URLs +```json +{"type": "image", "url": "..."} +``` + +### Storage URLs (internal files) +```json +{"type": "image", "url": "storage://path/to/uploaded/image.png"} +``` + +### Generated Images (from generate_image tool) +```json +{"type": "image", "image_id": "abc-123-456-def"} +``` + +## Size Handling +- **Max file size**: 10MB +- **Max dimension**: 4096px (larger images auto-resized) +- **Width parameter**: Specify in points (72pt = 1 inch) +- **Aspect ratio**: Always preserved + +## Example with Caption +```json +{ + "type": "image", + "url": "https://example.com/quarterly-chart.png", + "caption": "Figure 1: Quarterly Revenue Comparison", + "width": 500 +} +``` + +## Error Handling +If an image fails to load, a placeholder text will appear: +`[Image failed to load: ]` +""" + +KNOWLEDGE_HELP_TABLES = """ +# Table Generation Guide + +## Table Block Structure +```json +{ + "type": "table", + "headers": ["Column 1", "Column 2", "Column 3"], + "rows": [ + ["Row 1 Col 1", "Row 1 Col 2", "Row 1 Col 3"], + ["Row 2 Col 1", "Row 2 Col 2", "Row 2 Col 3"] + ] +} +``` + +## Supported In +- **PPTX**: Styled tables with blue headers +- **PDF**: Formatted tables with borders +- **DOCX**: Word tables with grid style + +## Styling (Automatic) +- Header row: Blue background (#4472C4), white bold text, centered +- Data rows: Standard formatting, left-aligned +- Borders: Thin black borders on all cells + +## Example: Data Table +```json +{ + "type": "table", + "headers": ["Product", "Price", "Stock", "Status"], + "rows": [ + ["Laptop Pro", "$1,299", "45", "In Stock"], + ["Tablet Air", "$799", "120", "In Stock"], + ["Phone Max", "$999", "0", "Out of Stock"], + ["Watch SE", "$249", "200", "In Stock"] + ] +} +``` + +## Tips +- Keep tables simple (avoid merged cells - not supported) +- Use consistent data types per column +- Header count must match row column count +- Empty cells: use empty string "" +""" + +KNOWLEDGE_HELP_IMAGE_SLIDES = """ +# Creating Beautiful Presentations with AI-Generated Slides + +## Overview +Instead of using structured content blocks, you can create stunning presentations +by generating each slide as an AI image. This gives full creative control over +typography, layout, colors, and visual effects. + +## Step-by-Step Workflow + +### Step 1: Generate Slide Images +Use the `generate_image` tool for each slide: + +``` +generate_image( + prompt="Professional presentation slide with title 'Q4 Revenue Summary' showing a blue gradient background, large white bold text, and a subtle upward trending graph icon. Clean corporate style, 16:9 aspect ratio.", + aspect_ratio="16:9" +) +``` + +### Step 2: Collect Image IDs +Each `generate_image` call returns an `image_id`. Save these: +- Slide 1: "abc-123-..." +- Slide 2: "def-456-..." +- etc. + +### Step 3: Create PPTX +Use `knowledge_write` with image_slides mode: + +```json +{ + "mode": "image_slides", + "title": "Q4 Business Review", + "author": "Sales Team", + "image_slides": [ + {"image_id": "abc-123-...", "notes": "Opening remarks"}, + {"image_id": "def-456-...", "notes": "Highlight 25% growth"}, + {"image_id": "ghi-789-...", "notes": "Thank the team"} + ] +} +``` + +## Prompting Tips for Consistent Style + +1. **Define a style template** and reference it in each prompt: + - "Corporate blue theme (#1a73e8), white text, clean minimal layout" + +2. **Specify slide type** in prompts: + - "Title slide" / "Content slide" / "Section divider" / "Closing slide" + +3. **Include aspect ratio**: + - Always use "16:9 aspect ratio presentation slide" + +4. **Maintain visual consistency**: + - "Matching the style of previous slides in this presentation" + +## Complete Example + +```python +# Agent generates beautiful slides +slide1 = await generate_image( + prompt="Title slide: 'Q4 Business Review 2024' with dark blue gradient, + large white text centered, subtle geometric patterns, + professional corporate style, 16:9 presentation slide" +) + +slide2 = await generate_image( + prompt="Content slide: 'Revenue Growth +25%' with bar chart visualization, + blue color scheme matching previous slide, clean data presentation, + 16:9 presentation slide" +) + +slide3 = await generate_image( + prompt="Closing slide: 'Thank You' with contact information, + matching corporate blue theme, 16:9 presentation slide" +) + +# Agent assembles into PPTX +await knowledge_write( + filename="Q4-Review.pptx", + content=json.dumps({ + "mode": "image_slides", + "title": "Q4 Business Review", + "image_slides": [ + {"image_id": slide1["image_id"], "notes": "Welcome everyone"}, + {"image_id": slide2["image_id"], "notes": "Emphasize growth"}, + {"image_id": slide3["image_id"], "notes": "Q&A time"} + ] + }) +) +``` + +## Limitations +- Text in images is NOT editable in PowerPoint +- Best for final presentations, not drafts requiring edits +- Larger file sizes than structured content +""" + + +def get_help_content(topic: str | None) -> dict[str, Any]: + """Get help content for the specified topic.""" + topic_map = { + "pptx": KNOWLEDGE_HELP_PPTX, + "powerpoint": KNOWLEDGE_HELP_PPTX, + "pdf": KNOWLEDGE_HELP_PDF_DOCX, + "docx": KNOWLEDGE_HELP_PDF_DOCX, + "word": KNOWLEDGE_HELP_PDF_DOCX, + "xlsx": KNOWLEDGE_HELP_XLSX, + "excel": KNOWLEDGE_HELP_XLSX, + "images": KNOWLEDGE_HELP_IMAGES, + "image": KNOWLEDGE_HELP_IMAGES, + "tables": KNOWLEDGE_HELP_TABLES, + "table": KNOWLEDGE_HELP_TABLES, + "image_slides": KNOWLEDGE_HELP_IMAGE_SLIDES, + "imageslides": KNOWLEDGE_HELP_IMAGE_SLIDES, + } + + if topic is None: + return {"success": True, "content": KNOWLEDGE_HELP_OVERVIEW} + + topic_lower = topic.lower().strip() + + if topic_lower == "all": + all_content = ( + KNOWLEDGE_HELP_OVERVIEW + + "\n\n---\n\n" + + KNOWLEDGE_HELP_PPTX + + "\n\n---\n\n" + + KNOWLEDGE_HELP_PDF_DOCX + + "\n\n---\n\n" + + KNOWLEDGE_HELP_XLSX + + "\n\n---\n\n" + + KNOWLEDGE_HELP_IMAGES + + "\n\n---\n\n" + + KNOWLEDGE_HELP_TABLES + + "\n\n---\n\n" + + KNOWLEDGE_HELP_IMAGE_SLIDES + ) + return {"success": True, "content": all_content} + + if topic_lower in topic_map: + return {"success": True, "content": topic_map[topic_lower]} + + return { + "success": False, + "error": f"Unknown topic: {topic}. Available topics: pptx, pdf, docx, xlsx, images, tables, image_slides, all", + } + + +__all__ = [ + "KNOWLEDGE_HELP_OVERVIEW", + "KNOWLEDGE_HELP_PPTX", + "KNOWLEDGE_HELP_PDF_DOCX", + "KNOWLEDGE_HELP_XLSX", + "KNOWLEDGE_HELP_IMAGES", + "KNOWLEDGE_HELP_TABLES", + "KNOWLEDGE_HELP_IMAGE_SLIDES", + "get_help_content", +] diff --git a/service/app/tools/builtin/knowledge/operations.py b/service/app/tools/builtin/knowledge/operations.py new file mode 100644 index 00000000..eb4f386d --- /dev/null +++ b/service/app/tools/builtin/knowledge/operations.py @@ -0,0 +1,344 @@ +""" +Knowledge tool implementation functions. + +Core operations for knowledge base file management. +""" + +from __future__ import annotations + +import io +import json +import logging +import mimetypes +from datetime import datetime, timezone +from typing import Any +from uuid import UUID + +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.storage import FileCategory, FileScope, generate_storage_key, get_storage_service +from app.infra.database import AsyncSessionLocal +from app.models.file import FileCreate +from app.repos.file import FileRepository +from app.repos.knowledge_set import KnowledgeSetRepository + +logger = logging.getLogger(__name__) + + +async def _resolve_image_ids_to_storage_urls( + content: str, + file_repo: FileRepository, + user_id: str, +) -> str: + """ + Resolve image_ids in document specs to storage:// URLs. + + This function handles the async database lookup in the async layer, + so sync document handlers don't need to do async operations. + + Supports: + - PresentationSpec with image_slides mode (image_slides[].image_id) + - PresentationSpec with ImageBlocks in slides (slides[].content[].image_id) + + Args: + content: JSON content to process + file_repo: File repository for database lookups + user_id: User ID for ownership verification (security check) + """ + try: + data = json.loads(content) + except json.JSONDecodeError: + # Not JSON, return as-is + return content + + if not isinstance(data, dict): + return content + + modified = False + + # Collect all image_ids that need resolution + image_ids_to_resolve: set[str] = set() + + # Check for image_slides mode + if data.get("mode") == "image_slides" and "image_slides" in data: + for slide in data.get("image_slides", []): + if isinstance(slide, dict) and slide.get("image_id"): + image_ids_to_resolve.add(slide["image_id"]) + + # Check for ImageBlocks in structured slides + for slide in data.get("slides", []): + if isinstance(slide, dict): + for block in slide.get("content", []): + if isinstance(block, dict) and block.get("type") == "image" and block.get("image_id"): + image_ids_to_resolve.add(block["image_id"]) + + if not image_ids_to_resolve: + return content + + # Resolve image_ids to storage URLs + id_to_storage_url: dict[str, str] = {} + for image_id in image_ids_to_resolve: + try: + file_uuid = UUID(image_id) + file_record = await file_repo.get_file_by_id(file_uuid) + if file_record and not file_record.is_deleted: + # Security check: verify the file belongs to the current user + if file_record.user_id != user_id: + logger.warning( + f"Image ownership mismatch: {image_id} belongs to {file_record.user_id}, not {user_id}" + ) + continue + id_to_storage_url[image_id] = f"storage://{file_record.storage_key}" + else: + logger.warning(f"Image not found or deleted: {image_id}") + except ValueError: + logger.warning(f"Invalid image_id format: {image_id}") + + # Replace image_ids with storage URLs in image_slides + if data.get("mode") == "image_slides" and "image_slides" in data: + for slide in data.get("image_slides", []): + if isinstance(slide, dict) and slide.get("image_id"): + image_id = slide["image_id"] + if image_id in id_to_storage_url: + # Add storage_url field, keep image_id for reference + slide["storage_url"] = id_to_storage_url[image_id] + modified = True + + # Replace image_ids with storage URLs in structured slides + for slide in data.get("slides", []): + if isinstance(slide, dict): + for block in slide.get("content", []): + if isinstance(block, dict) and block.get("type") == "image" and block.get("image_id"): + image_id = block["image_id"] + if image_id in id_to_storage_url: + # Set url to storage URL, keep image_id for reference + block["url"] = id_to_storage_url[image_id] + modified = True + + if modified: + return json.dumps(data) + return content + + +async def get_files_in_knowledge_set(db: AsyncSession, user_id: str, knowledge_set_id: UUID) -> list[UUID]: + """Get all file IDs in a knowledge set.""" + knowledge_set_repo = KnowledgeSetRepository(db) + + # Validate access + try: + await knowledge_set_repo.validate_access(user_id, knowledge_set_id) + except ValueError as e: + raise ValueError(f"Access denied: {e}") + + # Get file IDs + file_ids = await knowledge_set_repo.get_files_in_knowledge_set(knowledge_set_id) + return file_ids + + +async def list_files(user_id: str, knowledge_set_id: UUID) -> dict[str, Any]: + """List all files in the knowledge set.""" + try: + async with AsyncSessionLocal() as db: + file_repo = FileRepository(db) + + try: + file_ids = await get_files_in_knowledge_set(db, user_id, knowledge_set_id) + except ValueError as e: + return {"error": str(e), "success": False} + + # Fetch file objects + files = [] + for file_id in file_ids: + file = await file_repo.get_file_by_id(file_id) + if file and not file.is_deleted: + files.append(file) + + # Format output + entries: list[str] = [] + for f in files: + entries.append(f"[FILE] {f.original_filename} (ID: {f.id})") + + return { + "success": True, + "knowledge_set_id": str(knowledge_set_id), + "entries": entries, + "count": len(entries), + } + + except Exception as e: + logger.error(f"Error listing files: {e}") + return {"error": f"Internal error: {e!s}", "success": False} + + +async def read_file(user_id: str, knowledge_set_id: UUID, filename: str) -> dict[str, Any]: + """Read content of a file from the knowledge set.""" + from app.tools.utils.documents.handlers import FileHandlerFactory + + try: + # Normalize filename + filename = filename.strip("/").split("/")[-1] + + async with AsyncSessionLocal() as db: + file_repo = FileRepository(db) + target_file = None + + try: + file_ids = await get_files_in_knowledge_set(db, user_id, knowledge_set_id) + except ValueError as e: + return {"error": str(e), "success": False} + + # Find file by name + for file_id in file_ids: + file = await file_repo.get_file_by_id(file_id) + if file and file.original_filename == filename and not file.is_deleted: + target_file = file + break + + if not target_file: + return {"error": f"File '{filename}' not found in knowledge set.", "success": False} + + # Download content + storage = get_storage_service() + buffer = io.BytesIO() + await storage.download_file(target_file.storage_key, buffer) + file_bytes = buffer.getvalue() + + # Use handler to process content (text mode only for LangChain tools) + handler = FileHandlerFactory.get_handler(target_file.original_filename) + + try: + result = handler.read_content(file_bytes, mode="text") + return { + "success": True, + "filename": target_file.original_filename, + "content": result, + "size_bytes": target_file.file_size, + } + except Exception as e: + return {"error": f"Error parsing file: {e!s}", "success": False} + + except Exception as e: + logger.error(f"Error reading file: {e}") + return {"error": f"Internal error: {e!s}", "success": False} + + +async def write_file(user_id: str, knowledge_set_id: UUID, filename: str, content: str) -> dict[str, Any]: + """Create or update a file in the knowledge set.""" + from app.tools.utils.documents.handlers import FileHandlerFactory + + try: + filename = filename.strip("/").split("/")[-1] + + async with AsyncSessionLocal() as db: + file_repo = FileRepository(db) + knowledge_set_repo = KnowledgeSetRepository(db) + storage = get_storage_service() + + try: + file_ids = await get_files_in_knowledge_set(db, user_id, knowledge_set_id) + except ValueError as e: + return {"error": str(e), "success": False} + + # Check if file exists + existing_file = None + for file_id in file_ids: + file = await file_repo.get_file_by_id(file_id) + if file and file.original_filename == filename and not file.is_deleted: + existing_file = file + break + + # Determine content type + content_type, _ = mimetypes.guess_type(filename) + if not content_type: + if filename.endswith(".docx"): + content_type = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + elif filename.endswith(".xlsx"): + content_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + elif filename.endswith(".pptx"): + content_type = "application/vnd.openxmlformats-officedocument.presentationml.presentation" + elif filename.endswith(".pdf"): + content_type = "application/pdf" + else: + content_type = "text/plain" + + # Resolve image_ids to storage URLs for PPTX files (async DB lookup here) + if filename.endswith(".pptx"): + content = await _resolve_image_ids_to_storage_urls(content, file_repo, user_id) + + # Use handler to create content bytes + handler = FileHandlerFactory.get_handler(filename) + encoded_content = handler.create_content(content) + + new_key = generate_storage_key(user_id, filename, FileScope.PRIVATE) + data = io.BytesIO(encoded_content) + file_size_bytes = len(encoded_content) + + await storage.upload_file(data, new_key, content_type=content_type) + + if existing_file: + # Update existing + existing_file.storage_key = new_key + existing_file.file_size = file_size_bytes + existing_file.content_type = content_type + existing_file.updated_at = datetime.now(timezone.utc) + db.add(existing_file) + await db.commit() + return {"success": True, "message": f"Updated file: {filename}"} + else: + # Create new and link + new_file = FileCreate( + user_id=user_id, + folder_id=None, + original_filename=filename, + storage_key=new_key, + file_size=file_size_bytes, + content_type=content_type, + scope=FileScope.PRIVATE, + category=FileCategory.DOCUMENT, + ) + created_file = await file_repo.create_file(new_file) + await knowledge_set_repo.link_file_to_knowledge_set(created_file.id, knowledge_set_id) + await db.commit() + return {"success": True, "message": f"Created file: {filename}"} + + except Exception as e: + logger.error(f"Error writing file: {e}") + return {"error": f"Internal error: {e!s}", "success": False} + + +async def search_files(user_id: str, knowledge_set_id: UUID, query: str) -> dict[str, Any]: + """Search for files by name in the knowledge set.""" + try: + async with AsyncSessionLocal() as db: + file_repo = FileRepository(db) + matches: list[str] = [] + + try: + file_ids = await get_files_in_knowledge_set(db, user_id, knowledge_set_id) + except ValueError as e: + return {"error": str(e), "success": False} + + for file_id in file_ids: + file = await file_repo.get_file_by_id(file_id) + if file and not file.is_deleted and query.lower() in file.original_filename.lower(): + matches.append(f"{file.original_filename} (ID: {file.id})") + + return { + "success": True, + "query": query, + "matches": matches, + "count": len(matches), + } + + except Exception as e: + logger.error(f"Error searching files: {e}") + return {"error": f"Internal error: {e!s}", "success": False} + + +__all__ = [ + "get_files_in_knowledge_set", + "list_files", + "read_file", + "write_file", + "search_files", +] diff --git a/service/app/tools/builtin/knowledge/schemas.py b/service/app/tools/builtin/knowledge/schemas.py new file mode 100644 index 00000000..dcc0dac7 --- /dev/null +++ b/service/app/tools/builtin/knowledge/schemas.py @@ -0,0 +1,95 @@ +""" +Input schemas for knowledge tools. + +Pydantic models defining the input parameters for each knowledge tool. +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class KnowledgeListFilesInput(BaseModel): + """Input schema for list_files tool - no parameters needed.""" + + pass + + +class KnowledgeReadFileInput(BaseModel): + """Input schema for read_file tool.""" + + filename: str = Field( + description=( + "The name of the file to read from the knowledge base. " + "Supported formats: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, " + "images (PNG/JPG/GIF/WEBP with OCR), and plain text files." + ) + ) + + +class KnowledgeWriteFileInput(BaseModel): + """Input schema for write_file tool.""" + + filename: str = Field( + description=( + "The name of the file to create or update. Use appropriate extensions: " + ".txt, .md (plain text), .pdf (PDF document), .docx (Word), " + ".xlsx (Excel), .pptx (PowerPoint), .json, .yaml, .xml, .html." + ) + ) + content: str = Field( + description=( + "The content to write. Can be plain text (creates simple documents) or " + "a JSON specification for production-quality documents:\n\n" + "**For PDF/DOCX (DocumentSpec JSON):**\n" + '{"title": "My Report", "author": "Name", "content": [\n' + ' {"type": "heading", "content": "Section 1", "level": 1},\n' + ' {"type": "text", "content": "Paragraph text here"},\n' + ' {"type": "list", "items": ["Item 1", "Item 2"], "ordered": false},\n' + ' {"type": "table", "headers": ["Col1", "Col2"], "rows": [["A", "B"]]},\n' + ' {"type": "page_break"}\n' + "]}\n\n" + "**For XLSX (SpreadsheetSpec JSON):**\n" + '{"sheets": [{"name": "Data", "headers": ["Name", "Value"], ' + '"data": [["A", 1], ["B", 2]], "freeze_header": true}]}\n\n' + "**For PPTX (PresentationSpec JSON) - Structured mode:**\n" + '{"title": "My Presentation", "slides": [\n' + ' {"layout": "title", "title": "Welcome", "subtitle": "Intro"},\n' + ' {"layout": "title_content", "title": "Slide 2", ' + '"content": [{"type": "list", "items": ["Point 1", "Point 2"]}], ' + '"notes": "Speaker notes here"}\n' + "]}\n\n" + "**For PPTX - AI-generated image slides mode:**\n" + '{"mode": "image_slides", "title": "My Presentation", ' + '"image_slides": [\n' + ' {"image_id": "", "notes": "Speaker notes"}\n' + "]}" + ) + ) + + +class KnowledgeSearchFilesInput(BaseModel): + """Input schema for search_files tool.""" + + query: str = Field(description="Search term to find files by name.") + + +class KnowledgeHelpInput(BaseModel): + """Input schema for knowledge_help tool.""" + + topic: str | None = Field( + default=None, + description=( + "Optional topic to get help for. Options: 'pptx', 'pdf', 'docx', 'xlsx', " + "'images', 'tables', 'image_slides', 'all'. If not specified, returns overview." + ), + ) + + +__all__ = [ + "KnowledgeListFilesInput", + "KnowledgeReadFileInput", + "KnowledgeWriteFileInput", + "KnowledgeSearchFilesInput", + "KnowledgeHelpInput", +] diff --git a/service/app/tools/builtin/knowledge/tools.py b/service/app/tools/builtin/knowledge/tools.py new file mode 100644 index 00000000..0cfc2862 --- /dev/null +++ b/service/app/tools/builtin/knowledge/tools.py @@ -0,0 +1,232 @@ +""" +Knowledge tool factory functions. + +Creates LangChain tools for knowledge base operations. +""" + +from __future__ import annotations + +from typing import Any +from uuid import UUID + +from langchain_core.tools import BaseTool, StructuredTool + +from .help_content import get_help_content +from .operations import list_files, read_file, search_files, write_file +from .schemas import ( + KnowledgeHelpInput, + KnowledgeListFilesInput, + KnowledgeReadFileInput, + KnowledgeSearchFilesInput, + KnowledgeWriteFileInput, +) + + +def create_knowledge_tools() -> dict[str, BaseTool]: + """ + Create knowledge tools with placeholder implementations. + + Note: Knowledge tools require runtime context (user_id, knowledge_set_id). + The actual tool instances are created per-agent with context bound. + This function returns template tools for the registry. + + Returns: + Dict mapping tool_id to BaseTool placeholder instances. + """ + # These are placeholder tools - actual execution requires context binding + # See create_knowledge_tools_for_agent() for runtime creation + + tools: dict[str, BaseTool] = {} + + # List files tool + async def list_files_placeholder() -> dict[str, Any]: + return {"error": "Knowledge tools require agent context binding", "success": False} + + tools["knowledge_list"] = StructuredTool( + name="knowledge_list", + description=( + "List all files in the agent's knowledge base. Returns a list of filenames " + "that can be read or searched. Use this first to discover available files." + ), + args_schema=KnowledgeListFilesInput, + coroutine=list_files_placeholder, + ) + + # Read file tool + async def read_file_placeholder(filename: str) -> dict[str, Any]: + return {"error": "Knowledge tools require agent context binding", "success": False} + + tools["knowledge_read"] = StructuredTool( + name="knowledge_read", + description=( + "Read the content of a file from the agent's knowledge base. " + "Supports: PDF (text + tables), DOCX (text + tables), XLSX (all sheets), " + "PPTX (text + speaker notes), HTML (text extraction), JSON/YAML/XML (formatted), " + "images (OCR text extraction from PNG/JPG/GIF/WEBP), and plain text files. " + "Use knowledge_list first to see available files." + ), + args_schema=KnowledgeReadFileInput, + coroutine=read_file_placeholder, + ) + + # Write file tool + async def write_file_placeholder(filename: str, content: str) -> dict[str, Any]: + return {"error": "Knowledge tools require agent context binding", "success": False} + + tools["knowledge_write"] = StructuredTool( + name="knowledge_write", + description=( + "Create or update a file in the agent's knowledge base. " + "Supports: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, and plain text. " + "For rich documents with images, tables, and formatting, provide a JSON specification. " + "Use knowledge_help with topic='pptx' or 'pdf' for detailed examples. " + "For beautiful AI-generated presentations: use generate_image to create slide images, " + "then use knowledge_write with mode='image_slides' and the image_id values. " + "Call knowledge_help(topic='image_slides') for the complete workflow." + ), + args_schema=KnowledgeWriteFileInput, + coroutine=write_file_placeholder, + ) + + # Search files tool + async def search_files_placeholder(query: str) -> dict[str, Any]: + return {"error": "Knowledge tools require agent context binding", "success": False} + + tools["knowledge_search"] = StructuredTool( + name="knowledge_search", + description=( + "Search for files by name in the agent's knowledge base. Returns matching filenames that can then be read." + ), + args_schema=KnowledgeSearchFilesInput, + coroutine=search_files_placeholder, + ) + + # Help tool + async def help_placeholder(topic: str | None = None) -> dict[str, Any]: + return get_help_content(topic) + + tools["knowledge_help"] = StructuredTool( + name="knowledge_help", + description=( + "Get detailed help and examples for using knowledge tools. " + "Call without arguments for overview, or with topic='pptx', 'pdf', 'xlsx', " + "'images', 'tables', 'image_slides', or 'all' for specific guides with JSON examples." + ), + args_schema=KnowledgeHelpInput, + coroutine=help_placeholder, + ) + + return tools + + +def create_knowledge_tools_for_agent(user_id: str, knowledge_set_id: UUID) -> list[BaseTool]: + """ + Create knowledge tools bound to a specific agent's context. + + This creates actual working tools with user_id and knowledge_set_id + captured in closures. + + Args: + user_id: The user ID for access control + knowledge_set_id: The knowledge set ID to operate on + + Returns: + List of BaseTool instances with context bound + """ + tools: list[BaseTool] = [] + + # List files tool + async def list_files_bound() -> dict[str, Any]: + return await list_files(user_id, knowledge_set_id) + + tools.append( + StructuredTool( + name="knowledge_list", + description=( + "List all files in your knowledge base. Returns filenames that can be read or searched. " + "Use this first to discover available files." + ), + args_schema=KnowledgeListFilesInput, + coroutine=list_files_bound, + ) + ) + + # Read file tool + async def read_file_bound(filename: str) -> dict[str, Any]: + return await read_file(user_id, knowledge_set_id, filename) + + tools.append( + StructuredTool( + name="knowledge_read", + description=( + "Read the content of a file from your knowledge base. " + "Supports: PDF (text + tables), DOCX (text + tables), XLSX (all sheets), " + "PPTX (text + speaker notes), HTML (text extraction), JSON/YAML/XML (formatted), " + "images (OCR text extraction from PNG/JPG/GIF/WEBP), and plain text files. " + "Use knowledge_list first to see available files." + ), + args_schema=KnowledgeReadFileInput, + coroutine=read_file_bound, + ) + ) + + # Write file tool + async def write_file_bound(filename: str, content: str) -> dict[str, Any]: + return await write_file(user_id, knowledge_set_id, filename, content) + + tools.append( + StructuredTool( + name="knowledge_write", + description=( + "Create or update a file in your knowledge base. " + "Supports: PDF, DOCX, XLSX, PPTX, HTML, JSON, YAML, XML, and plain text. " + "For rich documents with images, tables, and formatting, provide a JSON specification. " + "Use knowledge_help with topic='pptx' or 'pdf' for detailed examples. " + "For beautiful AI-generated presentations: use generate_image to create slide images, " + "then use knowledge_write with mode='image_slides' and the image_id values. " + "Call knowledge_help(topic='image_slides') for the complete workflow." + ), + args_schema=KnowledgeWriteFileInput, + coroutine=write_file_bound, + ) + ) + + # Search files tool + async def search_files_bound(query: str) -> dict[str, Any]: + return await search_files(user_id, knowledge_set_id, query) + + tools.append( + StructuredTool( + name="knowledge_search", + description=( + "Search for files by name in your knowledge base. Returns matching filenames that can then be read." + ), + args_schema=KnowledgeSearchFilesInput, + coroutine=search_files_bound, + ) + ) + + # Help tool (no context needed - static content) + async def help_bound(topic: str | None = None) -> dict[str, Any]: + return get_help_content(topic) + + tools.append( + StructuredTool( + name="knowledge_help", + description=( + "Get detailed help and examples for using knowledge tools. " + "Call without arguments for overview, or with topic='pptx', 'pdf', 'xlsx', " + "'images', 'tables', 'image_slides', or 'all' for specific guides with JSON examples." + ), + args_schema=KnowledgeHelpInput, + coroutine=help_bound, + ) + ) + + return tools + + +__all__ = [ + "create_knowledge_tools", + "create_knowledge_tools_for_agent", +] diff --git a/service/app/tools/cost.py b/service/app/tools/cost.py new file mode 100644 index 00000000..16921a94 --- /dev/null +++ b/service/app/tools/cost.py @@ -0,0 +1,57 @@ +"""Tool cost calculation utilities.""" + +from __future__ import annotations + +import logging +from typing import Any + +from app.tools.registry import BuiltinToolRegistry + +logger = logging.getLogger(__name__) + + +def calculate_tool_cost( + tool_name: str, + tool_args: dict[str, Any] | None = None, + tool_result: dict[str, Any] | None = None, +) -> int: + """ + Calculate cost for a tool execution. + + Args: + tool_name: Name of the tool + tool_args: Tool input arguments + tool_result: Tool execution result + + Returns: + Cost in points + """ + # Get tool cost config from registry + tool_info = BuiltinToolRegistry.get_info(tool_name) + if not tool_info or not tool_info.cost: + return 0 + + config = tool_info.cost + cost = config.base_cost + + # Add input image cost (for generate_image with reference) + if config.input_image_cost and tool_args: + if tool_args.get("image_id"): # Has reference image + cost += config.input_image_cost + + # Add output file cost (for knowledge_write creating new files) + if config.output_file_cost and tool_result: + if isinstance(tool_result, dict): + # Check if tool created a new file (not updated) + # knowledge_write returns message like "Created file: filename" + message = tool_result.get("message", "") + if tool_result.get("success") and "Created" in message: + cost += config.output_file_cost + + if cost > 0: + logger.debug(f"Tool {tool_name} cost: {cost} (base={config.base_cost})") + + return cost + + +__all__ = ["calculate_tool_cost"] diff --git a/service/app/tools/registry.py b/service/app/tools/registry.py index ad58d534..607bae1f 100644 --- a/service/app/tools/registry.py +++ b/service/app/tools/registry.py @@ -22,6 +22,14 @@ logger = logging.getLogger(__name__) +class ToolCostConfig(BaseModel): + """Cost configuration for a tool.""" + + base_cost: int = Field(default=0, description="Base cost per execution") + input_image_cost: int = Field(default=0, description="Additional cost per input image") + output_file_cost: int = Field(default=0, description="Additional cost per output file") + + class ToolInfo(BaseModel): """Metadata about a builtin tool for API responses.""" @@ -42,6 +50,10 @@ class ToolInfo(BaseModel): default_factory=list, description="Runtime context requirements (e.g., ['user_id', 'knowledge_set_id'])", ) + cost: ToolCostConfig = Field( + default_factory=ToolCostConfig, + description="Cost configuration for this tool", + ) class BuiltinToolRegistry: @@ -65,6 +77,7 @@ def register( ui_toggleable: bool = True, default_enabled: bool = False, requires_context: list[str] | None = None, + cost: ToolCostConfig | None = None, ) -> None: """ Register a builtin tool. @@ -77,6 +90,7 @@ def register( ui_toggleable: Whether to show as toggle in UI (default: True) default_enabled: Whether enabled by default for new agents (default: False) requires_context: List of required context keys (e.g., ["user_id"]) + cost: Cost configuration for the tool (default: no cost) """ cls._tools[tool_id] = tool cls._metadata[tool_id] = ToolInfo( @@ -87,6 +101,7 @@ def register( ui_toggleable=ui_toggleable, default_enabled=default_enabled, requires_context=requires_context or [], + cost=cost or ToolCostConfig(), ) logger.debug(f"Registered builtin tool: {tool_id} ({category})") @@ -172,8 +187,18 @@ def register_builtin_tools() -> None: ui_toggleable=True, default_enabled=True, # Web search enabled by default requires_context=[], + cost=ToolCostConfig(base_cost=1), ) + # Tool cost configs for knowledge tools + knowledge_tool_costs = { + "knowledge_list": ToolCostConfig(), # Free + "knowledge_read": ToolCostConfig(), # Free + "knowledge_write": ToolCostConfig(output_file_cost=5), # Charge for new files + "knowledge_search": ToolCostConfig(), # Free + "knowledge_help": ToolCostConfig(), # Free + } + # Register knowledge tools (auto-enabled when knowledge_set exists, not UI toggleable) knowledge_tools = create_knowledge_tools() for tool_id, tool in knowledge_tools.items(): @@ -184,6 +209,7 @@ def register_builtin_tools() -> None: ui_toggleable=False, # Auto-enabled based on context default_enabled=False, requires_context=["user_id", "knowledge_set_id"], + cost=knowledge_tool_costs.get(tool_id, ToolCostConfig()), ) # Register memory tools (disabled due to performance issues) @@ -200,6 +226,12 @@ def register_builtin_tools() -> None: # requires_context=["user_id", "agent_id"], # ) + # Tool cost configs for image tools + image_tool_costs = { + "generate_image": ToolCostConfig(base_cost=10, input_image_cost=5), # 10 base, +5 if using reference + "read_image": ToolCostConfig(base_cost=2), # Vision model inference + } + # Register image tools from app.tools.builtin.image import create_image_tools @@ -213,9 +245,10 @@ def register_builtin_tools() -> None: ui_toggleable=True, default_enabled=False, requires_context=["user_id"], + cost=image_tool_costs.get(tool_id, ToolCostConfig()), ) logger.info(f"Registered {BuiltinToolRegistry.count()} builtin tools") -__all__ = ["BuiltinToolRegistry", "ToolInfo", "register_builtin_tools"] +__all__ = ["BuiltinToolRegistry", "ToolCostConfig", "ToolInfo", "register_builtin_tools"] diff --git a/service/app/tools/utils/documents/__init__.py b/service/app/tools/utils/documents/__init__.py new file mode 100644 index 00000000..686e0dfa --- /dev/null +++ b/service/app/tools/utils/documents/__init__.py @@ -0,0 +1,83 @@ +""" +Document generation utilities for PDF, DOCX, XLSX, and PPTX files. + +This module provides: +- Document specification schemas (spec.py) +- Image fetching from various sources (image_fetcher.py) +- File handlers for reading and creating documents (handlers.py) +""" + +from app.tools.utils.documents.handlers import ( + BaseFileHandler, + DocxFileHandler, + ExcelFileHandler, + FileHandlerFactory, + HtmlFileHandler, + ImageFileHandler, + JsonFileHandler, + PdfFileHandler, + PptxFileHandler, + ReadMode, + TextFileHandler, + XmlFileHandler, + YamlFileHandler, +) +from app.tools.utils.documents.image_fetcher import ( + DEFAULT_TIMEOUT, + MAX_IMAGE_DIMENSION, + MAX_IMAGE_SIZE_BYTES, + FetchedImage, + ImageFetcher, +) +from app.tools.utils.documents.spec import ( + ContentBlock, + DocumentSpec, + HeadingBlock, + ImageBlock, + ImageSlideSpec, + ListBlock, + PageBreakBlock, + PresentationSpec, + SheetSpec, + SlideSpec, + SpreadsheetSpec, + TableBlock, + TextBlock, +) + +__all__ = [ + # Spec classes + "TextBlock", + "HeadingBlock", + "ListBlock", + "TableBlock", + "ImageBlock", + "PageBreakBlock", + "ContentBlock", + "DocumentSpec", + "SheetSpec", + "SpreadsheetSpec", + "SlideSpec", + "ImageSlideSpec", + "PresentationSpec", + # Image fetcher + "ImageFetcher", + "FetchedImage", + "MAX_IMAGE_SIZE_BYTES", + "MAX_IMAGE_DIMENSION", + "DEFAULT_TIMEOUT", + # File handlers + "BaseFileHandler", + "TextFileHandler", + "HtmlFileHandler", + "JsonFileHandler", + "YamlFileHandler", + "XmlFileHandler", + "ImageFileHandler", + "PdfFileHandler", + "DocxFileHandler", + "ExcelFileHandler", + "PptxFileHandler", + "FileHandlerFactory", + "ReadMode", +] diff --git a/service/app/mcp/file_handlers.py b/service/app/tools/utils/documents/handlers.py similarity index 71% rename from service/app/mcp/file_handlers.py rename to service/app/tools/utils/documents/handlers.py index 9e27733c..52deeed6 100644 --- a/service/app/mcp/file_handlers.py +++ b/service/app/tools/utils/documents/handlers.py @@ -20,7 +20,7 @@ from pydantic import BaseModel, ValidationError if TYPE_CHECKING: - from app.mcp.document_spec import DocumentSpec, PresentationSpec, SpreadsheetSpec + from app.tools.utils.documents.spec import DocumentSpec, PresentationSpec, SpreadsheetSpec logger = logging.getLogger(__name__) @@ -384,7 +384,7 @@ def _format_table(self, table: Any) -> str: def create_content(self, text_content: str) -> bytes: """Create PDF from text or DocumentSpec JSON.""" - from app.mcp.document_spec import DocumentSpec + from app.tools.utils.documents.spec import DocumentSpec spec = self._try_parse_spec(text_content, DocumentSpec) if spec: @@ -554,7 +554,7 @@ def _extract_table(self, tbl_element: Any, doc: Any) -> str: def create_content(self, text_content: str) -> bytes: """Create DOCX from text or DocumentSpec JSON.""" - from app.mcp.document_spec import DocumentSpec + from app.tools.utils.documents.spec import DocumentSpec spec = self._try_parse_spec(text_content, DocumentSpec) if spec: @@ -660,7 +660,7 @@ def read_content(self, file_bytes: bytes, mode: ReadMode = "text") -> Union[str, def create_content(self, text_content: str) -> bytes: """Create XLSX from text or SpreadsheetSpec JSON.""" - from app.mcp.document_spec import SpreadsheetSpec + from app.tools.utils.documents.spec import SpreadsheetSpec spec = self._try_parse_spec(text_content, SpreadsheetSpec) if spec: @@ -785,7 +785,7 @@ def read_content(self, file_bytes: bytes, mode: ReadMode = "text") -> Union[str, def create_content(self, text_content: str) -> bytes: """Create PPTX from text or PresentationSpec JSON.""" - from app.mcp.document_spec import PresentationSpec + from app.tools.utils.documents.spec import PresentationSpec spec = self._try_parse_spec(text_content, PresentationSpec) if spec: @@ -816,12 +816,23 @@ def _create_pptx_from_text(self, text_content: str) -> bytes: def _create_pptx_from_spec(self, spec: PresentationSpec) -> bytes: """Create production PPTX from PresentationSpec.""" + # Route based on mode + if spec.mode == "image_slides": + return self._create_pptx_image_slides(spec) + else: + return self._create_pptx_structured(spec) + + def _create_pptx_structured(self, spec: PresentationSpec) -> bytes: + """Create PPTX with structured DSL slides (traditional mode).""" try: from pptx import Presentation except ImportError: raise ImportError("python-pptx is required for PPTX handling. Please install 'python-pptx'.") + from app.tools.utils.documents.image_fetcher import ImageFetcher + prs = Presentation() + image_fetcher = ImageFetcher() # Layout mapping LAYOUTS = { @@ -849,20 +860,9 @@ def _create_pptx_from_spec(self, spec: PresentationSpec) -> bytes: if slide_spec.subtitle and len(slide.placeholders) > 1: slide.placeholders[1].text = slide_spec.subtitle # type: ignore[union-attr] - # Add content to body placeholder - if slide_spec.content and len(slide.placeholders) > 1: - body = slide.placeholders[1] - if hasattr(body, "text_frame"): - tf = body.text_frame # type: ignore[union-attr] - for i, block in enumerate(slide_spec.content): - if block.type == "text": - p = tf.paragraphs[0] if i == 0 else tf.add_paragraph() - p.text = block.content # type: ignore[union-attr] - elif block.type == "list": - for item in block.items: # type: ignore[union-attr] - p = tf.add_paragraph() - p.text = item - p.level = 0 + # Render content blocks + if slide_spec.content: + self._render_content_blocks(slide, slide_spec.content, image_fetcher) # Add speaker notes if slide_spec.notes: @@ -878,6 +878,373 @@ def _create_pptx_from_spec(self, spec: PresentationSpec) -> bytes: prs.save(buffer) return buffer.getvalue() + def _create_pptx_image_slides(self, spec: PresentationSpec) -> bytes: + """Create PPTX with full-bleed images as slides.""" + try: + from pptx import Presentation + from pptx.util import Inches, Pt + except ImportError: + raise ImportError("python-pptx is required for PPTX handling. Please install 'python-pptx'.") + + from app.tools.utils.documents.image_fetcher import ImageFetcher + + prs = Presentation() + # Set slide dimensions (16:9 widescreen) + prs.slide_width = Inches(13.333) + prs.slide_height = Inches(7.5) + + image_fetcher = ImageFetcher() + blank_layout = prs.slide_layouts[6] # Blank layout + + for slide_spec in spec.image_slides: + slide = prs.slides.add_slide(blank_layout) + + # Use storage_url if available (resolved by async layer), otherwise fall back to image_id + if slide_spec.storage_url: + result = image_fetcher.fetch(url=slide_spec.storage_url) + else: + result = image_fetcher.fetch(image_id=slide_spec.image_id) + + if result.success and result.data: + # Add full-bleed image (0,0 to full slide dimensions) + image_stream = io.BytesIO(result.data) + slide.shapes.add_picture( + image_stream, + Inches(0), + Inches(0), + prs.slide_width, + prs.slide_height, + ) + else: + # Add error text for failed images + text_box = slide.shapes.add_textbox(Inches(1), Inches(3), Inches(11), Inches(1)) + tf = text_box.text_frame + tf.paragraphs[0].text = f"[Slide image failed: {result.error}]" + tf.paragraphs[0].font.size = Pt(24) + tf.paragraphs[0].font.italic = True + + # Add speaker notes + if slide_spec.notes: + notes_slide = slide.notes_slide + if notes_slide.notes_text_frame: + notes_slide.notes_text_frame.text = slide_spec.notes + + # Ensure at least one slide exists + if not prs.slides: + slide = prs.slides.add_slide(blank_layout) + + buffer = io.BytesIO() + prs.save(buffer) + return buffer.getvalue() + + def _render_content_blocks( + self, + slide: Any, + content_blocks: list[Any], + image_fetcher: Any, + ) -> None: + """Render all content blocks on a slide with vertical stacking.""" + + # Content area dimensions (below title) + CONTENT_LEFT = 0.5 # inches + CONTENT_TOP = 1.8 # inches + CONTENT_WIDTH = 9.0 # inches + CONTENT_BOTTOM = 7.0 # inches + + current_y = CONTENT_TOP + + for block in content_blocks: + if current_y >= CONTENT_BOTTOM: + logger.warning("Slide content area full, skipping remaining blocks") + break + + remaining_height = CONTENT_BOTTOM - current_y + + if block.type == "text": + height = self._render_text_block(slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH) + elif block.type == "list": + height = self._render_list_block(slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH) + elif block.type == "image": + height = self._render_image_block( + slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH, remaining_height, image_fetcher + ) + elif block.type == "table": + height = self._render_table_block( + slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH, remaining_height + ) + elif block.type == "heading": + height = self._render_heading_block(slide, block, CONTENT_LEFT, current_y, CONTENT_WIDTH) + else: + # Unknown block type, skip + height = 0.0 + + current_y += height + + def _render_text_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + ) -> float: + """Render a text block. Returns height in inches.""" + from pptx.util import Inches, Pt + + # Estimate height based on text length + chars_per_line = int(max_width * 12) # ~12 chars per inch at 12pt + num_lines = max(1, len(block.content) // chars_per_line + 1) + box_height = num_lines * 0.25 # ~0.25 inches per line + + text_box = slide.shapes.add_textbox( + Inches(left), + Inches(top), + Inches(max_width), + Inches(box_height), + ) + + tf = text_box.text_frame + tf.word_wrap = True + p = tf.paragraphs[0] + p.text = block.content + p.font.size = Pt(12) + + # Apply style + if hasattr(block, "style"): + if block.style == "bold": + p.font.bold = True + elif block.style == "italic": + p.font.italic = True + elif block.style == "code": + p.font.name = "Courier New" + p.font.size = Pt(10) + + return box_height + 0.1 # Add margin + + def _render_list_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + ) -> float: + """Render a list block. Returns height in inches.""" + from pptx.util import Inches, Pt + + num_items = len(block.items) + item_height = 0.3 # inches per item + box_height = num_items * item_height + + text_box = slide.shapes.add_textbox( + Inches(left), + Inches(top), + Inches(max_width), + Inches(box_height), + ) + + tf = text_box.text_frame + tf.word_wrap = True + + for i, item in enumerate(block.items): + p = tf.paragraphs[0] if i == 0 else tf.add_paragraph() + prefix = f"{i + 1}. " if block.ordered else "• " + p.text = prefix + item + p.font.size = Pt(12) + p.level = 0 + + return box_height + 0.1 + + def _render_heading_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + ) -> float: + """Render a heading block. Returns height in inches.""" + from pptx.util import Inches, Pt + + # Font sizes by heading level + HEADING_SIZES = { + 1: 24, + 2: 20, + 3: 18, + 4: 16, + 5: 14, + 6: 12, + } + + level = getattr(block, "level", 1) + font_size = HEADING_SIZES.get(level, 14) + box_height = font_size / 72.0 * 1.5 # Convert to inches with padding + + text_box = slide.shapes.add_textbox( + Inches(left), + Inches(top), + Inches(max_width), + Inches(box_height), + ) + + tf = text_box.text_frame + p = tf.paragraphs[0] + p.text = block.content + p.font.size = Pt(font_size) + p.font.bold = True + + return box_height + 0.1 + + def _render_table_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + max_height: float, + ) -> float: + """Render a table block. Returns height in inches.""" + from pptx.dml.color import RGBColor + from pptx.enum.text import PP_ALIGN + from pptx.util import Inches, Pt + + num_cols = len(block.headers) + if num_cols == 0: + return 0.0 + + num_rows = 1 + len(block.rows) # Header + data rows + + # Calculate row height + row_height = 0.4 # inches + table_height = num_rows * row_height + + # Cap table height to available space + max_table_height = min(4.0, max_height - 0.2) + if table_height > max_table_height: + table_height = max_table_height + row_height = table_height / num_rows + + # Create table + table_shape = slide.shapes.add_table( + num_rows, + num_cols, + Inches(left), + Inches(top), + Inches(max_width), + Inches(table_height), + ) + table = table_shape.table + + # Style header row + for i, header in enumerate(block.headers): + cell = table.cell(0, i) + cell.text = str(header) + cell.fill.solid() + cell.fill.fore_color.rgb = RGBColor(0x44, 0x72, 0xC4) + + # Set text properties + if cell.text_frame.paragraphs: + para = cell.text_frame.paragraphs[0] + para.font.bold = True + para.font.size = Pt(11) + para.font.color.rgb = RGBColor(0xFF, 0xFF, 0xFF) + para.alignment = PP_ALIGN.CENTER + + # Fill data rows + for row_idx, row_data in enumerate(block.rows): + for col_idx, cell_val in enumerate(row_data): + if col_idx < num_cols: + cell = table.cell(row_idx + 1, col_idx) + cell.text = str(cell_val) + if cell.text_frame.paragraphs: + para = cell.text_frame.paragraphs[0] + para.font.size = Pt(10) + + return table_height + 0.2 + + def _render_image_block( + self, + slide: Any, + block: Any, + left: float, + top: float, + max_width: float, + max_height: float, + image_fetcher: Any, + ) -> float: + """Render an image block. Returns height in inches.""" + from pptx.enum.text import PP_ALIGN + from pptx.util import Inches, Pt + + # Fetch image by url or image_id + url = getattr(block, "url", None) + image_id = getattr(block, "image_id", None) + result = image_fetcher.fetch(url=url, image_id=image_id) + + if not result.success: + # Add error placeholder + text_box = slide.shapes.add_textbox(Inches(left), Inches(top), Inches(max_width), Inches(0.5)) + tf = text_box.text_frame + tf.paragraphs[0].text = f"[Image failed to load: {result.error}]" + tf.paragraphs[0].font.italic = True + tf.paragraphs[0].font.size = Pt(10) + return 0.6 + + # Calculate image dimensions + if block.width: + # Use specified width (in points, convert to inches) + img_width = block.width / 72.0 + elif result.width and result.height: + # Scale to fit max_width while maintaining aspect ratio + img_width = min(max_width * 0.8, result.width / 96.0) # 96 DPI assumption, 80% max width + else: + img_width = min(max_width * 0.6, 4.0) # Default 4 inches or 60% width + + # Calculate height maintaining aspect ratio + if result.width and result.height: + aspect = result.height / result.width + img_height = img_width * aspect + else: + img_height = img_width * 0.75 # Default 4:3 aspect + + # Cap height to available space (leave room for caption) + caption_space = 0.5 if block.caption else 0.1 + available_height = max_height - caption_space + if img_height > available_height: + scale = available_height / img_height + img_height = available_height + img_width = img_width * scale + + # Center image horizontally + img_left = left + (max_width - img_width) / 2 + + # Add image to slide + image_stream = io.BytesIO(result.data) + slide.shapes.add_picture( + image_stream, + Inches(img_left), + Inches(top), + Inches(img_width), + Inches(img_height), + ) + + total_height = img_height + 0.1 + + # Add caption if present + if block.caption: + caption_top = top + img_height + 0.1 + caption_box = slide.shapes.add_textbox(Inches(left), Inches(caption_top), Inches(max_width), Inches(0.3)) + tf = caption_box.text_frame + p = tf.paragraphs[0] + p.text = block.caption + p.alignment = PP_ALIGN.CENTER + p.font.size = Pt(10) + p.font.italic = True + total_height += 0.4 + + return total_height + class FileHandlerFactory: """Factory to get the appropriate file handler based on filename.""" diff --git a/service/app/tools/utils/documents/image_fetcher.py b/service/app/tools/utils/documents/image_fetcher.py new file mode 100644 index 00000000..116f9734 --- /dev/null +++ b/service/app/tools/utils/documents/image_fetcher.py @@ -0,0 +1,271 @@ +""" +Image fetching service for document generation. + +Handles HTTP URLs, base64 data URLs, and storage:// protocol. +Designed for synchronous use in document handlers. +""" + +from __future__ import annotations + +import base64 +import io +import logging +import re +from dataclasses import dataclass +from typing import Any, Coroutine, TypeVar + +import httpx +from PIL import Image as PILImage + +logger = logging.getLogger(__name__) + +# Constants +MAX_IMAGE_SIZE_BYTES = 10 * 1024 * 1024 # 10MB +MAX_IMAGE_DIMENSION = 4096 # pixels +DEFAULT_TIMEOUT = 30.0 + +T = TypeVar("T") + + +def _run_async(coro: Coroutine[Any, Any, T]) -> T: + """ + Run an async coroutine from sync code, handling existing event loops. + + When called from within an already-running event loop (e.g., Celery worker), + asyncio.run() fails. This helper uses a thread pool to safely execute + async code in such cases. + """ + import asyncio + import concurrent.futures + + try: + # Check if there's already a running event loop + asyncio.get_running_loop() + # We're in an async context - run in a thread pool + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(asyncio.run, coro) + return future.result() + except RuntimeError: + # No running loop - safe to use asyncio.run + return asyncio.run(coro) + + +@dataclass +class FetchedImage: + """Result of an image fetch operation.""" + + success: bool + data: bytes | None = None + format: str | None = None # "png", "jpeg", etc. + width: int | None = None + height: int | None = None + error: str | None = None + + +class ImageFetcher: + """ + Fetches images from various sources for document embedding. + + Supports: + - HTTP/HTTPS URLs + - Base64 data URLs (data:image/png;base64,...) + - storage:// protocol (internal file storage) + """ + + def __init__( + self, + timeout: float = DEFAULT_TIMEOUT, + max_size_bytes: int = MAX_IMAGE_SIZE_BYTES, + max_dimension: int = MAX_IMAGE_DIMENSION, + ): + self.timeout = timeout + self.max_size_bytes = max_size_bytes + self.max_dimension = max_dimension + + def fetch(self, url: str | None = None, image_id: str | None = None) -> FetchedImage: + """ + Fetch an image from the given URL or resolve image_id to storage. + + Args: + url: HTTP URL, base64 data URL, or storage:// URL (takes precedence if provided) + image_id: UUID of a generated image from generate_image tool (fallback if no url) + + Returns: + FetchedImage with data or error information + """ + try: + # Prefer URL over image_id when both are present + # This allows the async layer to resolve image_ids to URLs beforehand + if url: + if url.startswith("data:"): + return self._fetch_base64(url) + elif url.startswith("storage://"): + return self._fetch_from_storage(url) + elif url.startswith(("http://", "https://")): + return self._fetch_http(url) + else: + return FetchedImage(success=False, error=f"Unsupported URL scheme: {url[:50]}") + elif image_id: + # Fallback to image_id if no URL provided + return self._fetch_by_image_id(image_id) + else: + return FetchedImage(success=False, error="Either url or image_id must be provided") + except Exception as e: + logger.error(f"Image fetch failed: {e}") + return FetchedImage(success=False, error=str(e)) + + def _fetch_http(self, url: str) -> FetchedImage: + """Fetch image from HTTP/HTTPS URL.""" + try: + with httpx.Client(timeout=self.timeout, follow_redirects=True) as client: + response = client.get(url) + response.raise_for_status() + + # Check size from header + content_length = response.headers.get("content-length") + if content_length and int(content_length) > self.max_size_bytes: + return FetchedImage( + success=False, + error=f"Image too large: {int(content_length)} bytes (max {self.max_size_bytes})", + ) + + data = response.content + if len(data) > self.max_size_bytes: + return FetchedImage( + success=False, + error=f"Image too large: {len(data)} bytes (max {self.max_size_bytes})", + ) + + return self._process_image_data(data) + + except httpx.TimeoutException: + return FetchedImage(success=False, error=f"Timeout fetching image: {url[:100]}") + except httpx.HTTPStatusError as e: + return FetchedImage(success=False, error=f"HTTP error {e.response.status_code}: {url[:100]}") + except httpx.RequestError as e: + return FetchedImage(success=False, error=f"Request error: {e}") + + def _fetch_base64(self, data_url: str) -> FetchedImage: + """Decode base64 data URL.""" + # Format: data:image/png;base64, + match = re.match(r"data:image/(\w+);base64,(.+)", data_url, re.DOTALL) + if not match: + return FetchedImage(success=False, error="Invalid base64 data URL format") + + format_hint = match.group(1).lower() + b64_data = match.group(2) + + try: + data = base64.b64decode(b64_data) + except Exception as e: + return FetchedImage(success=False, error=f"Base64 decode failed: {e}") + + if len(data) > self.max_size_bytes: + return FetchedImage( + success=False, + error=f"Image too large: {len(data)} bytes (max {self.max_size_bytes})", + ) + + return self._process_image_data(data, format_hint) + + def _fetch_from_storage(self, storage_url: str) -> FetchedImage: + """ + Fetch image from internal storage. + + Uses _run_async to execute async storage download in sync context. + """ + from app.core.storage import get_storage_service + + # Extract storage key: storage://path/to/file.png -> path/to/file.png + storage_key = storage_url.replace("storage://", "") + + try: + storage = get_storage_service() + buffer = io.BytesIO() + + # Run async download in sync context + _run_async(storage.download_file(storage_key, buffer)) + + data = buffer.getvalue() + if len(data) > self.max_size_bytes: + return FetchedImage( + success=False, + error=f"Image too large: {len(data)} bytes (max {self.max_size_bytes})", + ) + + return self._process_image_data(data) + + except Exception as e: + return FetchedImage(success=False, error=f"Storage fetch failed: {e}") + + def _fetch_by_image_id(self, image_id: str) -> FetchedImage: + """ + Handle image_id parameter. + + Image IDs should be resolved to storage URLs in the async layer (operations.py) + before reaching this sync code. If this method is called, it means the proper + flow wasn't followed. + + For backward compatibility, we return a clear error message. + """ + from uuid import UUID + + # Validate UUID format first + try: + UUID(image_id) + except ValueError: + return FetchedImage(success=False, error=f"Invalid image_id format: {image_id}") + + # Return error explaining the proper flow + return FetchedImage( + success=False, + error=( + f"image_id '{image_id}' was not resolved to a storage URL. " + "Image IDs must be resolved in the async layer before document generation. " + "Use knowledge_write tool which handles this automatically." + ), + ) + + def _process_image_data(self, data: bytes, format_hint: str | None = None) -> FetchedImage: + """ + Process raw image data: validate, get dimensions, optionally resize. + """ + try: + img = PILImage.open(io.BytesIO(data)) + + # Get actual format + img_format = (img.format or format_hint or "PNG").lower() + if img_format == "jpeg": + img_format = "jpg" + + width, height = img.size + + # Resize if too large + if width > self.max_dimension or height > self.max_dimension: + ratio = min(self.max_dimension / width, self.max_dimension / height) + new_width = int(width * ratio) + new_height = int(height * ratio) + img = img.resize((new_width, new_height), PILImage.Resampling.LANCZOS) + width, height = new_width, new_height + + # Re-encode + output = io.BytesIO() + save_format = "PNG" if img_format == "png" else "JPEG" + if img.mode in ("RGBA", "P") and save_format == "JPEG": + img = img.convert("RGB") + img.save(output, format=save_format) + data = output.getvalue() + + return FetchedImage( + success=True, + data=data, + format=img_format, + width=width, + height=height, + ) + + except Exception as e: + return FetchedImage(success=False, error=f"Image processing failed: {e}") + + +__all__ = ["ImageFetcher", "FetchedImage", "MAX_IMAGE_SIZE_BYTES", "MAX_IMAGE_DIMENSION", "DEFAULT_TIMEOUT"] diff --git a/service/app/mcp/document_spec.py b/service/app/tools/utils/documents/spec.py similarity index 79% rename from service/app/mcp/document_spec.py rename to service/app/tools/utils/documents/spec.py index 62fb5279..c87d84ef 100644 --- a/service/app/mcp/document_spec.py +++ b/service/app/tools/utils/documents/spec.py @@ -62,7 +62,11 @@ class ImageBlock(BaseModel): """An image block.""" type: Literal["image"] = "image" - url: str = Field(description="Image URL or base64 data URL") + url: str | None = Field(default=None, description="Image URL or base64 data URL") + image_id: str | None = Field( + default=None, + description="UUID of generated image from generate_image tool", + ) caption: str | None = Field(default=None, description="Optional image caption") width: int | None = Field(default=None, description="Optional width in points/pixels") @@ -186,11 +190,26 @@ class SlideSpec(BaseModel): notes: str | None = Field(default=None, description="Speaker notes") +class ImageSlideSpec(BaseModel): + """Specification for an image-only slide (full-bleed generated image).""" + + image_id: str = Field(description="UUID of the generated slide image from generate_image tool") + storage_url: str | None = Field( + default=None, + description="Resolved storage URL (set by async layer, not by user)", + ) + notes: str | None = Field(default=None, description="Speaker notes for this slide") + + class PresentationSpec(BaseModel): """ Production-ready presentation specification for PPTX generation. - Example: + Supports two modes: + - structured: Traditional slides with DSL content blocks (default) + - image_slides: Full-bleed AI-generated image slides + + Example (structured mode): ```json { "title": "Q4 Review", @@ -212,13 +231,39 @@ class PresentationSpec(BaseModel): ] } ``` + + Example (image_slides mode): + ```json + { + "mode": "image_slides", + "title": "Q4 Review", + "image_slides": [ + {"image_id": "abc-123-...", "notes": "Welcome everyone"}, + {"image_id": "def-456-...", "notes": "Revenue summary"} + ] + } + ``` """ title: str | None = Field(default=None, description="Presentation title") author: str | None = Field(default=None, description="Presentation author") + + # Mode selection + mode: Literal["structured", "image_slides"] = Field( + default="structured", + description="'structured' for DSL slides, 'image_slides' for full-bleed generated images", + ) + + # For structured mode slides: list[SlideSpec] = Field( default_factory=list, - description="List of slide specifications", + description="List of slide specifications (for structured mode)", + ) + + # For image_slides mode + image_slides: list[ImageSlideSpec] = Field( + default_factory=list, + description="List of image slide specifications (for image_slides mode)", ) @@ -234,5 +279,6 @@ class PresentationSpec(BaseModel): "SheetSpec", "SpreadsheetSpec", "SlideSpec", + "ImageSlideSpec", "PresentationSpec", ] diff --git a/service/tests/unit/agents/test_factory.py b/service/tests/unit/agents/test_factory.py new file mode 100644 index 00000000..85bfc6da --- /dev/null +++ b/service/tests/unit/agents/test_factory.py @@ -0,0 +1,132 @@ +"""Tests for agent factory module.""" + +from app.agents.factory import _inject_system_prompt + + +class TestInjectSystemPrompt: + """Test _inject_system_prompt function.""" + + def test_inject_into_llm_node(self) -> None: + """Test system prompt injection into LLM node.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "agent", + "type": "llm", + "llm_config": { + "prompt_template": "Default prompt", + "tools_enabled": True, + }, + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # Original should not be mutated + assert config_dict["nodes"][0]["llm_config"]["prompt_template"] == "Default prompt" + + # Result should have the new prompt + assert result["nodes"][0]["llm_config"]["prompt_template"] == "Custom system prompt" + + def test_inject_into_component_node(self) -> None: + """Test system prompt injection into react component node.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "agent", + "type": "component", + "component_config": { + "component_ref": {"key": "react"}, + }, + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # Original should not be mutated + assert "config_overrides" not in config_dict["nodes"][0]["component_config"] + + # Result should have config_overrides with system_prompt + assert result["nodes"][0]["component_config"]["config_overrides"]["system_prompt"] == "Custom system prompt" + + def test_inject_only_into_first_matching_node(self) -> None: + """Test that system prompt is only injected into the first matching node.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "agent1", + "type": "llm", + "llm_config": { + "prompt_template": "Prompt 1", + }, + }, + { + "id": "agent2", + "type": "llm", + "llm_config": { + "prompt_template": "Prompt 2", + }, + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # First node should be updated + assert result["nodes"][0]["llm_config"]["prompt_template"] == "Custom system prompt" + # Second node should remain unchanged + assert result["nodes"][1]["llm_config"]["prompt_template"] == "Prompt 2" + + def test_llm_node_takes_precedence_over_non_react_component(self) -> None: + """Test that LLM nodes are preferred over non-react components.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "other", + "type": "component", + "component_config": { + "component_ref": {"key": "other_component"}, + }, + }, + { + "id": "agent", + "type": "llm", + "llm_config": { + "prompt_template": "Default", + }, + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # LLM node should be updated (other component is not react) + assert result["nodes"][1]["llm_config"]["prompt_template"] == "Custom system prompt" + + def test_no_matching_nodes(self) -> None: + """Test graceful handling when no matching nodes exist.""" + config_dict = { + "version": "2.0", + "nodes": [ + { + "id": "transform", + "type": "transform", + }, + ], + "edges": [], + } + + result = _inject_system_prompt(config_dict, "Custom system prompt") + + # Should return config unchanged + assert result == config_dict diff --git a/service/tests/unit/agents/test_graph_builder.py b/service/tests/unit/agents/test_graph_builder.py index cec6f1c3..9af67348 100644 --- a/service/tests/unit/agents/test_graph_builder.py +++ b/service/tests/unit/agents/test_graph_builder.py @@ -1,8 +1,10 @@ """Tests for graph_builder module.""" -from unittest.mock import AsyncMock +from typing import Any +from unittest.mock import AsyncMock, MagicMock import pytest +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from pydantic import BaseModel from app.agents.graph_builder import GraphBuilder, build_state_class @@ -112,3 +114,173 @@ def test_build_simple_graph(self) -> None: ) graph = builder.build() assert graph is not None + + +class TestPromptTemplateAsSystemMessage: + """Test that prompt_template is prepended as SystemMessage.""" + + @pytest.mark.asyncio + async def test_prompt_template_prepends_system_message(self) -> None: + """Test prompt_template is prepended as SystemMessage, not appended as HumanMessage.""" + config = GraphConfig( + nodes=[ + GraphNodeConfig( + id="agent", + name="Agent", + type=NodeType.LLM, + llm_config=LLMNodeConfig( + prompt_template="You are a helpful assistant.", + tools_enabled=False, + ), + ), + ], + edges=[ + GraphEdgeConfig(from_node="START", to_node="agent"), + GraphEdgeConfig(from_node="agent", to_node="END"), + ], + entry_point="agent", + ) + + # Create mock LLM that captures the messages it receives + captured_messages: list[BaseMessage] = [] + + async def mock_llm_factory(model: str | None = None, temperature: float | None = None) -> Any: + mock_llm = MagicMock() + + async def capture_invoke(messages: list[BaseMessage]) -> AIMessage: + captured_messages.clear() + captured_messages.extend(messages) + return AIMessage(content="Response") + + mock_llm.ainvoke = capture_invoke + return mock_llm + + builder = GraphBuilder( + config=config, + llm_factory=mock_llm_factory, + tool_registry={}, + ) + graph = await builder.build() + + # Invoke with a user message + initial_state: dict[str, list[BaseMessage]] = { + "messages": [HumanMessage(content="Hello")], + } + await graph.ainvoke(initial_state) # type: ignore[arg-type] # type: ignore[arg-type] + + # Verify SystemMessage is first, not HumanMessage at the end + assert len(captured_messages) >= 2 + assert isinstance(captured_messages[0], SystemMessage) + assert captured_messages[0].content == "You are a helpful assistant." + assert isinstance(captured_messages[1], HumanMessage) + assert captured_messages[1].content == "Hello" + + @pytest.mark.asyncio + async def test_no_prompt_template_passes_messages_unchanged(self) -> None: + """Test that empty prompt_template doesn't add any system message.""" + config = GraphConfig( + nodes=[ + GraphNodeConfig( + id="agent", + name="Agent", + type=NodeType.LLM, + llm_config=LLMNodeConfig( + prompt_template="", # Empty + tools_enabled=False, + ), + ), + ], + edges=[ + GraphEdgeConfig(from_node="START", to_node="agent"), + GraphEdgeConfig(from_node="agent", to_node="END"), + ], + entry_point="agent", + ) + + captured_messages: list[BaseMessage] = [] + + async def mock_llm_factory(model: str | None = None, temperature: float | None = None) -> Any: + mock_llm = MagicMock() + + async def capture_invoke(messages: list[BaseMessage]) -> AIMessage: + captured_messages.clear() + captured_messages.extend(messages) + return AIMessage(content="Response") + + mock_llm.ainvoke = capture_invoke + return mock_llm + + builder = GraphBuilder( + config=config, + llm_factory=mock_llm_factory, + tool_registry={}, + ) + graph = await builder.build() + + initial_state: dict[str, list[BaseMessage]] = { + "messages": [HumanMessage(content="Hello")], + } + await graph.ainvoke(initial_state) # type: ignore[arg-type] + + # Should only have the original HumanMessage + assert len(captured_messages) == 1 + assert isinstance(captured_messages[0], HumanMessage) + assert captured_messages[0].content == "Hello" + + @pytest.mark.asyncio + async def test_existing_system_message_replaced(self) -> None: + """Test that existing SystemMessage in messages is replaced.""" + config = GraphConfig( + nodes=[ + GraphNodeConfig( + id="agent", + name="Agent", + type=NodeType.LLM, + llm_config=LLMNodeConfig( + prompt_template="New system prompt", + tools_enabled=False, + ), + ), + ], + edges=[ + GraphEdgeConfig(from_node="START", to_node="agent"), + GraphEdgeConfig(from_node="agent", to_node="END"), + ], + entry_point="agent", + ) + + captured_messages: list[BaseMessage] = [] + + async def mock_llm_factory(model: str | None = None, temperature: float | None = None) -> Any: + mock_llm = MagicMock() + + async def capture_invoke(messages: list[BaseMessage]) -> AIMessage: + captured_messages.clear() + captured_messages.extend(messages) + return AIMessage(content="Response") + + mock_llm.ainvoke = capture_invoke + return mock_llm + + builder = GraphBuilder( + config=config, + llm_factory=mock_llm_factory, + tool_registry={}, + ) + graph = await builder.build() + + # Input with existing SystemMessage + initial_state: dict[str, list[BaseMessage]] = { + "messages": [ + SystemMessage(content="Old system prompt"), + HumanMessage(content="Hello"), + ], + } + await graph.ainvoke(initial_state) # type: ignore[arg-type] + + # Should have new SystemMessage first, no duplicates + system_messages = [m for m in captured_messages if isinstance(m, SystemMessage)] + assert len(system_messages) == 1 + assert system_messages[0].content == "New system prompt" + assert isinstance(captured_messages[0], SystemMessage) + assert isinstance(captured_messages[1], HumanMessage) diff --git a/service/tests/unit/handler/mcp/test_file_handlers.py b/service/tests/unit/handler/mcp/test_file_handlers.py deleted file mode 100644 index 940382f1..00000000 --- a/service/tests/unit/handler/mcp/test_file_handlers.py +++ /dev/null @@ -1,374 +0,0 @@ -""" -Tests for file handlers. -""" - -import json -from unittest.mock import MagicMock, patch - -import pytest - -from app.mcp.document_spec import ( - DocumentSpec, - HeadingBlock, - ListBlock, - PresentationSpec, - SheetSpec, - SlideSpec, - SpreadsheetSpec, - TableBlock, - TextBlock, -) -from app.mcp.file_handlers import ( - DocxFileHandler, - ExcelFileHandler, - FileHandlerFactory, - HtmlFileHandler, - ImageFileHandler, - JsonFileHandler, - PdfFileHandler, - PptxFileHandler, - TextFileHandler, - XmlFileHandler, - YamlFileHandler, -) - - -class TestFileHandlerFactory: - def test_get_handler(self) -> None: - # Existing handlers - assert isinstance(FileHandlerFactory.get_handler("test.pdf"), PdfFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.docx"), DocxFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.doc"), DocxFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.xlsx"), ExcelFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.xls"), ExcelFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.pptx"), PptxFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.ppt"), PptxFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.txt"), TextFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.csv"), TextFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.py"), TextFileHandler) - - def test_get_handler_new_types(self) -> None: - # New handlers - assert isinstance(FileHandlerFactory.get_handler("test.html"), HtmlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.htm"), HtmlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.json"), JsonFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.yaml"), YamlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.yml"), YamlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.xml"), XmlFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.png"), ImageFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.jpg"), ImageFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.jpeg"), ImageFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.gif"), ImageFileHandler) - assert isinstance(FileHandlerFactory.get_handler("test.webp"), ImageFileHandler) - - -class TestTextFileHandler: - def test_read_write(self) -> None: - handler = TextFileHandler() - content = "Hello, World!" - - # Write - bytes_content = handler.create_content(content) - assert isinstance(bytes_content, bytes) - assert bytes_content == b"Hello, World!" - - # Read - read_content = handler.read_content(bytes_content) - assert read_content == content - - def test_read_image_fail(self) -> None: - handler = TextFileHandler() - with pytest.raises(ValueError): - handler.read_content(b"test", mode="image") - - -class TestHtmlFileHandler: - def test_read_html(self) -> None: - handler = HtmlFileHandler() - html = b"

Title

Content

" - content = handler.read_content(html) - assert isinstance(content, str) - # Should extract text from HTML - assert "Title" in content or "Content" in content - - def test_read_html_strips_scripts(self) -> None: - handler = HtmlFileHandler() - html = b"Safe content" - content = handler.read_content(html) - assert "alert" not in content - assert "Safe content" in content - - def test_create_html(self) -> None: - handler = HtmlFileHandler() - content = handler.create_content("Hello\n\nWorld") - assert b"" in content - assert b"" in content - assert b"Hello" in content - - def test_read_image_fail(self) -> None: - handler = HtmlFileHandler() - with pytest.raises(ValueError): - handler.read_content(b"", mode="image") - - -class TestJsonFileHandler: - def test_read_json(self) -> None: - handler = JsonFileHandler() - data = {"key": "value", "nested": {"a": 1}} - json_bytes = json.dumps(data).encode() - - content = handler.read_content(json_bytes) - assert "key" in content - assert "value" in content - - def test_read_invalid_json(self) -> None: - handler = JsonFileHandler() - content = handler.read_content(b"not valid json") - assert content == "not valid json" - - def test_create_json_valid(self) -> None: - handler = JsonFileHandler() - data = '{"key": "value"}' - result = handler.create_content(data) - parsed = json.loads(result) - assert parsed["key"] == "value" - - def test_create_json_invalid_wraps(self) -> None: - handler = JsonFileHandler() - result = handler.create_content("plain text") - parsed = json.loads(result) - assert "content" in parsed - assert parsed["content"] == "plain text" - - -class TestYamlFileHandler: - def test_read_yaml(self) -> None: - handler = YamlFileHandler() - yaml_content = b"key: value\nnested:\n a: 1" - content = handler.read_content(yaml_content) - assert "key" in content - assert "value" in content - - def test_create_yaml_from_json(self) -> None: - handler = YamlFileHandler() - json_input = '{"key": "value"}' - result = handler.create_content(json_input) - assert b"key: value" in result - - -class TestXmlFileHandler: - def test_read_xml(self) -> None: - handler = XmlFileHandler() - xml = b"Hello" - content = handler.read_content(xml) - assert "Hello" in content - assert "item" in content - - def test_create_xml(self) -> None: - handler = XmlFileHandler() - result = handler.create_content("Test content") - assert b"" in result - assert b"Test content" in result - - def test_read_image_fail(self) -> None: - handler = XmlFileHandler() - with pytest.raises(ValueError): - handler.read_content(b"", mode="image") - - -class TestImageFileHandler: - def test_detect_format_png(self) -> None: - handler = ImageFileHandler() - png_magic = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 - assert handler._detect_format(png_magic) == "png" - - def test_detect_format_jpeg(self) -> None: - handler = ImageFileHandler() - jpeg_magic = b"\xff\xd8" + b"\x00" * 100 - assert handler._detect_format(jpeg_magic) == "jpeg" - - def test_detect_format_gif(self) -> None: - handler = ImageFileHandler() - gif_magic = b"GIF89a" + b"\x00" * 100 - assert handler._detect_format(gif_magic) == "gif" - - def test_create_raises_error(self) -> None: - handler = ImageFileHandler() - with pytest.raises(ValueError, match="Cannot create image"): - handler.create_content("text") - - -# Only mock external deps for complex handlers if they are not installed in test env -# But for now we assume we might need to mock them to run this in strict CI envs -# where deps might be missing during dev. - - -@patch("fitz.open") -@patch("fitz.Matrix") -class TestPdfFileHandler: - def test_read_text(self, mock_matrix: MagicMock, mock_open: MagicMock) -> None: - handler = PdfFileHandler() - mock_doc = MagicMock() - mock_page = MagicMock() - mock_page.get_text.return_value = "Page text" - mock_page.find_tables.return_value = MagicMock(tables=[]) - mock_doc.__iter__.return_value = [mock_page] - mock_open.return_value = mock_doc - - content = handler.read_content(b"pdf_bytes", mode="text") - assert content == "Page text" - mock_open.assert_called_with(stream=b"pdf_bytes", filetype="pdf") - - def test_write_plain_text(self, mock_matrix: MagicMock, mock_open: MagicMock) -> None: - handler = PdfFileHandler() - # For plain text, it uses reportlab, not fitz - result = handler.create_content("Some text") - assert isinstance(result, bytes) - # PDF magic bytes - assert result[:4] == b"%PDF" - - -@patch("docx.Document") -class TestDocxFileHandler: - def test_read(self, mock_document_cls: MagicMock) -> None: - handler = DocxFileHandler() - mock_doc = MagicMock() - mock_element = MagicMock() - mock_element.tag = "p" - mock_element.iter.return_value = [MagicMock(text="Para 1")] - mock_doc.element.body = [mock_element] - mock_document_cls.return_value = mock_doc - - content = handler.read_content(b"docx_bytes") - assert "Para 1" in content - - def test_write_plain_text(self, mock_document_cls: MagicMock) -> None: - handler = DocxFileHandler() - mock_doc = MagicMock() - mock_document_cls.return_value = mock_doc - - handler.create_content("Line 1\nLine 2") - - assert mock_doc.add_paragraph.call_count == 2 - mock_doc.save.assert_called() - - -@patch("openpyxl.Workbook") -@patch("openpyxl.load_workbook") -class TestExcelFileHandler: - def test_read(self, mock_load_workbook: MagicMock, mock_workbook: MagicMock) -> None: - handler = ExcelFileHandler() - mock_wb = MagicMock() - mock_ws = MagicMock() - mock_wb.sheetnames = ["Sheet1"] - mock_wb.__getitem__.return_value = mock_ws - mock_ws.iter_rows.return_value = [("A", "B")] - mock_load_workbook.return_value = mock_wb - - content = handler.read_content(b"xlsx_bytes") - assert "Sheet1" in content - assert "A\tB" in content - - def test_write_csv(self, mock_load_workbook: MagicMock, mock_workbook: MagicMock) -> None: - handler = ExcelFileHandler() - mock_wb = MagicMock() - mock_ws = MagicMock() - mock_wb.active = mock_ws - mock_workbook.return_value = mock_wb - - handler.create_content("A,B\nC,D") - - assert mock_ws.append.call_count == 2 - mock_wb.save.assert_called() - - -@patch("pptx.Presentation") -class TestPptxFileHandler: - def test_read(self, mock_presentation: MagicMock) -> None: - handler = PptxFileHandler() - mock_prs = MagicMock() - mock_slide = MagicMock() - mock_shape = MagicMock() - mock_shape.text = "Slide Text" - mock_slide.shapes = [mock_shape] - mock_slide.has_notes_slide = False - mock_prs.slides = [mock_slide] - mock_presentation.return_value = mock_prs - - content = handler.read_content(b"pptx_bytes") - assert "Slide Text" in content - - def test_write_plain_text(self, mock_presentation: MagicMock) -> None: - handler = PptxFileHandler() - mock_prs = MagicMock() - mock_slide = MagicMock() - mock_prs.slides.add_slide.return_value = mock_slide - mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] - mock_presentation.return_value = mock_prs - - handler.create_content("Title\nBody") - - mock_prs.slides.add_slide.assert_called() - mock_prs.save.assert_called() - - -# Document Spec Tests - - -class TestDocumentSpec: - def test_create_document_spec(self) -> None: - spec = DocumentSpec( - title="Test Doc", - author="Test Author", - content=[ - HeadingBlock(content="Chapter 1", level=1), - TextBlock(content="Some text here"), - ListBlock(items=["Item 1", "Item 2"], ordered=False), - TableBlock(headers=["A", "B"], rows=[["1", "2"]]), - ], - ) - assert spec.title == "Test Doc" - assert len(spec.content) == 4 - assert spec.content[0].type == "heading" - - def test_document_spec_json_roundtrip(self) -> None: - spec = DocumentSpec( - title="Test", - content=[TextBlock(content="Hello")], - ) - json_str = spec.model_dump_json() - parsed = DocumentSpec.model_validate_json(json_str) - assert parsed.title == spec.title - - -class TestSpreadsheetSpec: - def test_create_spreadsheet_spec(self) -> None: - spec = SpreadsheetSpec( - sheets=[ - SheetSpec( - name="Data", - headers=["Name", "Value"], - data=[["A", 1], ["B", 2]], - ) - ] - ) - assert len(spec.sheets) == 1 - assert spec.sheets[0].name == "Data" - - -class TestPresentationSpec: - def test_create_presentation_spec(self) -> None: - spec = PresentationSpec( - title="My Presentation", - slides=[ - SlideSpec(layout="title", title="Welcome", subtitle="Intro"), - SlideSpec( - layout="title_content", - title="Slide 2", - content=[ListBlock(items=["Point 1", "Point 2"])], - ), - ], - ) - assert len(spec.slides) == 2 - assert spec.slides[0].layout == "title" diff --git a/service/tests/unit/test_core/test_consume_strategy.py b/service/tests/unit/test_core/test_consume_strategy.py index 41a43ddd..b2c1c321 100644 --- a/service/tests/unit/test_core/test_consume_strategy.py +++ b/service/tests/unit/test_core/test_consume_strategy.py @@ -21,7 +21,7 @@ def test_default_values(self) -> None: assert context.output_tokens == 0 assert context.total_tokens == 0 assert context.content_length == 0 - assert context.generated_files_count == 0 + assert context.tool_costs == 0 def test_with_values(self) -> None: """Test ConsumptionContext with custom values.""" @@ -31,14 +31,14 @@ def test_with_values(self) -> None: output_tokens=500, total_tokens=1500, content_length=5000, - generated_files_count=2, + tool_costs=15, ) assert context.model_tier == ModelTier.PRO assert context.input_tokens == 1000 assert context.output_tokens == 500 assert context.total_tokens == 1500 assert context.content_length == 5000 - assert context.generated_files_count == 2 + assert context.tool_costs == 15 class TestTierBasedConsumptionStrategy: @@ -53,7 +53,6 @@ def test_lite_tier_is_free(self) -> None: output_tokens=5000, total_tokens=15000, content_length=50000, - generated_files_count=5, ) result = strategy.calculate(context) @@ -71,7 +70,6 @@ def test_standard_tier_base_multiplier(self) -> None: output_tokens=1000, total_tokens=2000, content_length=1000, - generated_files_count=0, ) result = strategy.calculate(context) @@ -92,7 +90,6 @@ def test_pro_tier_multiplier(self) -> None: output_tokens=1000, total_tokens=2000, content_length=1000, - generated_files_count=0, ) result = strategy.calculate(context) @@ -113,7 +110,6 @@ def test_ultra_tier_multiplier(self) -> None: output_tokens=1000, total_tokens=2000, content_length=1000, - generated_files_count=0, ) result = strategy.calculate(context) @@ -125,8 +121,8 @@ def test_ultra_tier_multiplier(self) -> None: assert result.amount == expected assert result.breakdown["tier_rate"] == 6.8 - def test_file_generation_cost(self) -> None: - """Test that file generation cost is included.""" + def test_tool_costs(self) -> None: + """Test that tool costs are included in calculation.""" strategy = TierBasedConsumptionStrategy() context = ConsumptionContext( model_tier=ModelTier.STANDARD, @@ -134,13 +130,14 @@ def test_file_generation_cost(self) -> None: output_tokens=0, total_tokens=0, content_length=0, - generated_files_count=2, + tool_costs=20, ) result = strategy.calculate(context) + # base_cost(1) + tool_costs(20) = 21 expected = int((1 + 20) * 1.0) assert result.amount == expected - assert result.breakdown["file_cost"] == 20 + assert result.breakdown["tool_costs"] == 20 def test_no_tier_defaults_to_1(self) -> None: """Test that None tier defaults to rate 1.0.""" @@ -151,7 +148,6 @@ def test_no_tier_defaults_to_1(self) -> None: output_tokens=1000, total_tokens=2000, content_length=1000, - generated_files_count=0, ) result = strategy.calculate(context) @@ -171,13 +167,13 @@ def test_breakdown_contains_all_fields(self) -> None: output_tokens=500, total_tokens=1500, content_length=1000, - generated_files_count=1, + tool_costs=10, ) result = strategy.calculate(context) assert "base_cost" in result.breakdown assert "token_cost" in result.breakdown - assert "file_cost" in result.breakdown + assert "tool_costs" in result.breakdown assert "tier_rate" in result.breakdown assert "tier" in result.breakdown assert "pre_multiplier_total" in result.breakdown @@ -205,7 +201,6 @@ def test_calculate_pro_tier(self) -> None: input_tokens=1000, output_tokens=1000, total_tokens=2000, - generated_files_count=0, ) result = ConsumptionCalculator.calculate(context) @@ -222,7 +217,6 @@ def test_breakdown_is_json_serializable(self) -> None: input_tokens=1000, output_tokens=500, total_tokens=1500, - generated_files_count=1, ) result = ConsumptionCalculator.calculate(context) diff --git a/service/tests/unit/tools/__init__.py b/service/tests/unit/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/service/tests/unit/tools/test_cost.py b/service/tests/unit/tools/test_cost.py new file mode 100644 index 00000000..32b1a662 --- /dev/null +++ b/service/tests/unit/tools/test_cost.py @@ -0,0 +1,173 @@ +"""Unit tests for tool cost calculation.""" + +import pytest + +from app.tools.cost import calculate_tool_cost +from app.tools.registry import BuiltinToolRegistry, ToolCostConfig, ToolInfo + + +class TestCalculateToolCost: + """Tests for calculate_tool_cost function.""" + + @pytest.fixture(autouse=True) + def setup_registry(self) -> None: + """Set up test registry before each test.""" + BuiltinToolRegistry.clear() + + # Register mock tools for testing + from unittest.mock import MagicMock + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test tool" + + # Tool with base cost only + BuiltinToolRegistry._metadata["generate_image"] = ToolInfo( + id="generate_image", + name="Generate Image", + description="Generate images", + category="image", + cost=ToolCostConfig(base_cost=10, input_image_cost=5), + ) + + # Tool with input_image_cost + BuiltinToolRegistry._metadata["read_image"] = ToolInfo( + id="read_image", + name="Read Image", + description="Read images", + category="image", + cost=ToolCostConfig(base_cost=2), + ) + + # Tool with output_file_cost + BuiltinToolRegistry._metadata["knowledge_write"] = ToolInfo( + id="knowledge_write", + name="Knowledge Write", + description="Write files", + category="knowledge", + cost=ToolCostConfig(output_file_cost=5), + ) + + # Tool with no cost + BuiltinToolRegistry._metadata["knowledge_read"] = ToolInfo( + id="knowledge_read", + name="Knowledge Read", + description="Read files", + category="knowledge", + cost=ToolCostConfig(), + ) + + # Web search tool + BuiltinToolRegistry._metadata["web_search"] = ToolInfo( + id="web_search", + name="Web Search", + description="Search the web", + category="search", + cost=ToolCostConfig(base_cost=1), + ) + + def test_generate_image_without_reference(self) -> None: + """Test generate_image cost without reference image.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args={"prompt": "a beautiful sunset"}, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Base cost only (10), no input_image_cost + assert cost == 10 + + def test_generate_image_with_reference(self) -> None: + """Test generate_image cost with reference image.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args={"prompt": "a beautiful sunset", "image_id": "ref123"}, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Base cost (10) + input_image_cost (5) = 15 + assert cost == 15 + + def test_read_image_cost(self) -> None: + """Test read_image cost.""" + cost = calculate_tool_cost( + tool_name="read_image", + tool_args={"image_id": "abc123", "question": "What is in this image?"}, + tool_result={"success": True, "analysis": "A beautiful sunset"}, + ) + assert cost == 2 + + def test_knowledge_write_creating_file(self) -> None: + """Test knowledge_write cost when creating a new file.""" + cost = calculate_tool_cost( + tool_name="knowledge_write", + tool_args={"filename": "report.txt", "content": "Hello world"}, + tool_result={"success": True, "message": "Created file: report.txt"}, + ) + # output_file_cost (5) for creating new file + assert cost == 5 + + def test_knowledge_write_updating_file(self) -> None: + """Test knowledge_write cost when updating an existing file.""" + cost = calculate_tool_cost( + tool_name="knowledge_write", + tool_args={"filename": "report.txt", "content": "Updated content"}, + tool_result={"success": True, "message": "Updated file: report.txt"}, + ) + # No cost for updating (message doesn't contain "Created") + assert cost == 0 + + def test_knowledge_read_is_free(self) -> None: + """Test that knowledge_read is free.""" + cost = calculate_tool_cost( + tool_name="knowledge_read", + tool_args={"filename": "report.txt"}, + tool_result={"success": True, "content": "Hello world"}, + ) + assert cost == 0 + + def test_web_search_cost(self) -> None: + """Test web_search cost.""" + cost = calculate_tool_cost( + tool_name="web_search", + tool_args={"query": "Python programming"}, + tool_result={"results": [{"title": "Result 1"}]}, + ) + assert cost == 1 + + def test_unknown_tool_is_free(self) -> None: + """Test that unknown tools have zero cost.""" + cost = calculate_tool_cost( + tool_name="unknown_tool", + tool_args={"some": "args"}, + tool_result={"some": "result"}, + ) + assert cost == 0 + + def test_no_args_provided(self) -> None: + """Test cost calculation when no args are provided.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args=None, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Should only return base cost + assert cost == 10 + + def test_no_result_provided(self) -> None: + """Test cost calculation when no result is provided.""" + cost = calculate_tool_cost( + tool_name="knowledge_write", + tool_args={"filename": "test.txt", "content": "Hello"}, + tool_result=None, + ) + # Should only return base cost (0 for knowledge_write) + assert cost == 0 + + def test_failed_tool_execution(self) -> None: + """Test that failed tool executions don't charge output_file_cost.""" + cost = calculate_tool_cost( + tool_name="knowledge_write", + tool_args={"filename": "test.txt", "content": "Hello"}, + tool_result={"success": False, "error": "Permission denied"}, + ) + # No charge for failed execution + assert cost == 0 diff --git a/service/tests/unit/tools/utils/__init__.py b/service/tests/unit/tools/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/service/tests/unit/tools/utils/documents/__init__.py b/service/tests/unit/tools/utils/documents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/service/tests/unit/tools/utils/documents/test_handlers.py b/service/tests/unit/tools/utils/documents/test_handlers.py new file mode 100644 index 00000000..c7a0da34 --- /dev/null +++ b/service/tests/unit/tools/utils/documents/test_handlers.py @@ -0,0 +1,903 @@ +""" +Tests for file handlers. +""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from app.tools.utils.documents.handlers import ( + DocxFileHandler, + ExcelFileHandler, + FileHandlerFactory, + HtmlFileHandler, + ImageFileHandler, + JsonFileHandler, + PdfFileHandler, + PptxFileHandler, + TextFileHandler, + XmlFileHandler, + YamlFileHandler, +) +from app.tools.utils.documents.image_fetcher import FetchedImage +from app.tools.utils.documents.spec import ( + DocumentSpec, + HeadingBlock, + ImageBlock, + ImageSlideSpec, + ListBlock, + PresentationSpec, + SheetSpec, + SlideSpec, + SpreadsheetSpec, + TableBlock, + TextBlock, +) + + +class TestFileHandlerFactory: + def test_get_handler(self) -> None: + # Existing handlers + assert isinstance(FileHandlerFactory.get_handler("test.pdf"), PdfFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.docx"), DocxFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.doc"), DocxFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.xlsx"), ExcelFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.xls"), ExcelFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.pptx"), PptxFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.ppt"), PptxFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.txt"), TextFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.csv"), TextFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.py"), TextFileHandler) + + def test_get_handler_new_types(self) -> None: + # New handlers + assert isinstance(FileHandlerFactory.get_handler("test.html"), HtmlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.htm"), HtmlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.json"), JsonFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.yaml"), YamlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.yml"), YamlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.xml"), XmlFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.png"), ImageFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.jpg"), ImageFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.jpeg"), ImageFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.gif"), ImageFileHandler) + assert isinstance(FileHandlerFactory.get_handler("test.webp"), ImageFileHandler) + + +class TestTextFileHandler: + def test_read_write(self) -> None: + handler = TextFileHandler() + content = "Hello, World!" + + # Write + bytes_content = handler.create_content(content) + assert isinstance(bytes_content, bytes) + assert bytes_content == b"Hello, World!" + + # Read + read_content = handler.read_content(bytes_content) + assert read_content == content + + def test_read_image_fail(self) -> None: + handler = TextFileHandler() + with pytest.raises(ValueError): + handler.read_content(b"test", mode="image") + + +class TestHtmlFileHandler: + def test_read_html(self) -> None: + handler = HtmlFileHandler() + html = b"

Title

Content

" + content = handler.read_content(html) + assert isinstance(content, str) + # Should extract text from HTML + assert "Title" in content or "Content" in content + + def test_read_html_strips_scripts(self) -> None: + handler = HtmlFileHandler() + html = b"Safe content" + content = handler.read_content(html) + assert "alert" not in content + assert "Safe content" in content + + def test_create_html(self) -> None: + handler = HtmlFileHandler() + content = handler.create_content("Hello\n\nWorld") + assert b"" in content + assert b"" in content + assert b"Hello" in content + + def test_read_image_fail(self) -> None: + handler = HtmlFileHandler() + with pytest.raises(ValueError): + handler.read_content(b"", mode="image") + + +class TestJsonFileHandler: + def test_read_json(self) -> None: + handler = JsonFileHandler() + data = {"key": "value", "nested": {"a": 1}} + json_bytes = json.dumps(data).encode() + + content = handler.read_content(json_bytes) + assert "key" in content + assert "value" in content + + def test_read_invalid_json(self) -> None: + handler = JsonFileHandler() + content = handler.read_content(b"not valid json") + assert content == "not valid json" + + def test_create_json_valid(self) -> None: + handler = JsonFileHandler() + data = '{"key": "value"}' + result = handler.create_content(data) + parsed = json.loads(result) + assert parsed["key"] == "value" + + def test_create_json_invalid_wraps(self) -> None: + handler = JsonFileHandler() + result = handler.create_content("plain text") + parsed = json.loads(result) + assert "content" in parsed + assert parsed["content"] == "plain text" + + +class TestYamlFileHandler: + def test_read_yaml(self) -> None: + handler = YamlFileHandler() + yaml_content = b"key: value\nnested:\n a: 1" + content = handler.read_content(yaml_content) + assert "key" in content + assert "value" in content + + def test_create_yaml_from_json(self) -> None: + handler = YamlFileHandler() + json_input = '{"key": "value"}' + result = handler.create_content(json_input) + assert b"key: value" in result + + +class TestXmlFileHandler: + def test_read_xml(self) -> None: + handler = XmlFileHandler() + xml = b"Hello" + content = handler.read_content(xml) + assert "Hello" in content + assert "item" in content + + def test_create_xml(self) -> None: + handler = XmlFileHandler() + result = handler.create_content("Test content") + assert b"" in result + assert b"Test content" in result + + def test_read_image_fail(self) -> None: + handler = XmlFileHandler() + with pytest.raises(ValueError): + handler.read_content(b"", mode="image") + + +class TestImageFileHandler: + def test_detect_format_png(self) -> None: + handler = ImageFileHandler() + png_magic = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + assert handler._detect_format(png_magic) == "png" + + def test_detect_format_jpeg(self) -> None: + handler = ImageFileHandler() + jpeg_magic = b"\xff\xd8" + b"\x00" * 100 + assert handler._detect_format(jpeg_magic) == "jpeg" + + def test_detect_format_gif(self) -> None: + handler = ImageFileHandler() + gif_magic = b"GIF89a" + b"\x00" * 100 + assert handler._detect_format(gif_magic) == "gif" + + def test_create_raises_error(self) -> None: + handler = ImageFileHandler() + with pytest.raises(ValueError, match="Cannot create image"): + handler.create_content("text") + + +# Only mock external deps for complex handlers if they are not installed in test env +# But for now we assume we might need to mock them to run this in strict CI envs +# where deps might be missing during dev. + + +@patch("fitz.open") +@patch("fitz.Matrix") +class TestPdfFileHandler: + def test_read_text(self, mock_matrix: MagicMock, mock_open: MagicMock) -> None: + handler = PdfFileHandler() + mock_doc = MagicMock() + mock_page = MagicMock() + mock_page.get_text.return_value = "Page text" + mock_page.find_tables.return_value = MagicMock(tables=[]) + mock_doc.__iter__.return_value = [mock_page] + mock_open.return_value = mock_doc + + content = handler.read_content(b"pdf_bytes", mode="text") + assert content == "Page text" + mock_open.assert_called_with(stream=b"pdf_bytes", filetype="pdf") + + def test_write_plain_text(self, mock_matrix: MagicMock, mock_open: MagicMock) -> None: + handler = PdfFileHandler() + # For plain text, it uses reportlab, not fitz + result = handler.create_content("Some text") + assert isinstance(result, bytes) + # PDF magic bytes + assert result[:4] == b"%PDF" + + +@patch("docx.Document") +class TestDocxFileHandler: + def test_read(self, mock_document_cls: MagicMock) -> None: + handler = DocxFileHandler() + mock_doc = MagicMock() + mock_element = MagicMock() + mock_element.tag = "p" + mock_element.iter.return_value = [MagicMock(text="Para 1")] + mock_doc.element.body = [mock_element] + mock_document_cls.return_value = mock_doc + + content = handler.read_content(b"docx_bytes") + assert "Para 1" in content + + def test_write_plain_text(self, mock_document_cls: MagicMock) -> None: + handler = DocxFileHandler() + mock_doc = MagicMock() + mock_document_cls.return_value = mock_doc + + handler.create_content("Line 1\nLine 2") + + assert mock_doc.add_paragraph.call_count == 2 + mock_doc.save.assert_called() + + +@patch("openpyxl.Workbook") +@patch("openpyxl.load_workbook") +class TestExcelFileHandler: + def test_read(self, mock_load_workbook: MagicMock, mock_workbook: MagicMock) -> None: + handler = ExcelFileHandler() + mock_wb = MagicMock() + mock_ws = MagicMock() + mock_wb.sheetnames = ["Sheet1"] + mock_wb.__getitem__.return_value = mock_ws + mock_ws.iter_rows.return_value = [("A", "B")] + mock_load_workbook.return_value = mock_wb + + content = handler.read_content(b"xlsx_bytes") + assert "Sheet1" in content + assert "A\tB" in content + + def test_write_csv(self, mock_load_workbook: MagicMock, mock_workbook: MagicMock) -> None: + handler = ExcelFileHandler() + mock_wb = MagicMock() + mock_ws = MagicMock() + mock_wb.active = mock_ws + mock_workbook.return_value = mock_wb + + handler.create_content("A,B\nC,D") + + assert mock_ws.append.call_count == 2 + mock_wb.save.assert_called() + + +@patch("pptx.Presentation") +class TestPptxFileHandler: + def test_read(self, mock_presentation: MagicMock) -> None: + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_shape = MagicMock() + mock_shape.text = "Slide Text" + mock_slide.shapes = [mock_shape] + mock_slide.has_notes_slide = False + mock_prs.slides = [mock_slide] + mock_presentation.return_value = mock_prs + + content = handler.read_content(b"pptx_bytes") + assert "Slide Text" in content + + def test_write_plain_text(self, mock_presentation: MagicMock) -> None: + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + handler.create_content("Title\nBody") + + mock_prs.slides.add_slide.assert_called() + mock_prs.save.assert_called() + + +# Document Spec Tests + + +class TestDocumentSpec: + def test_create_document_spec(self) -> None: + spec = DocumentSpec( + title="Test Doc", + author="Test Author", + content=[ + HeadingBlock(content="Chapter 1", level=1), + TextBlock(content="Some text here"), + ListBlock(items=["Item 1", "Item 2"], ordered=False), + TableBlock(headers=["A", "B"], rows=[["1", "2"]]), + ], + ) + assert spec.title == "Test Doc" + assert len(spec.content) == 4 + assert spec.content[0].type == "heading" + + def test_document_spec_json_roundtrip(self) -> None: + spec = DocumentSpec( + title="Test", + content=[TextBlock(content="Hello")], + ) + json_str = spec.model_dump_json() + parsed = DocumentSpec.model_validate_json(json_str) + assert parsed.title == spec.title + + +class TestSpreadsheetSpec: + def test_create_spreadsheet_spec(self) -> None: + spec = SpreadsheetSpec( + sheets=[ + SheetSpec( + name="Data", + headers=["Name", "Value"], + data=[["A", 1], ["B", 2]], + ) + ] + ) + assert len(spec.sheets) == 1 + assert spec.sheets[0].name == "Data" + + +class TestPresentationSpec: + def test_create_presentation_spec(self) -> None: + spec = PresentationSpec( + title="My Presentation", + slides=[ + SlideSpec(layout="title", title="Welcome", subtitle="Intro"), + SlideSpec( + layout="title_content", + title="Slide 2", + content=[ListBlock(items=["Point 1", "Point 2"])], + ), + ], + ) + assert len(spec.slides) == 2 + assert spec.slides[0].layout == "title" + + +class TestPptxFileHandlerEnhanced: + """Tests for enhanced PPTX generation with images, tables, and headings.""" + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_table_block(self, mock_presentation: MagicMock, mock_image_fetcher: MagicMock) -> None: + """Test table rendering in PPTX.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_table_shape = MagicMock() + mock_table = MagicMock() + mock_table_shape.table = mock_table + + # Mock table cells + mock_cells = {} + for row in range(3): # 1 header + 2 data rows + for col in range(2): + cell = MagicMock() + cell.text_frame.paragraphs = [MagicMock()] + mock_cells[(row, col)] = cell + mock_table.cell = lambda r, c: mock_cells[(r, c)] # type: ignore[misc] + + mock_slide.shapes.add_table.return_value = mock_table_shape + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_prs.slides.__iter__ = lambda self: iter([mock_slide]) # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Data Slide", + content=[ + TableBlock( + headers=["Name", "Value"], + rows=[["Item A", "100"], ["Item B", "200"]], + ) + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify table was created + mock_slide.shapes.add_table.assert_called_once() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_heading_block(self, mock_presentation: MagicMock, mock_image_fetcher: MagicMock) -> None: + """Test heading rendering in PPTX.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_textbox.text_frame.paragraphs = [MagicMock()] + + mock_slide.shapes.add_textbox.return_value = mock_textbox + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Content Slide", + content=[ + HeadingBlock(content="Section Header", level=2), + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify textbox was created for heading + mock_slide.shapes.add_textbox.assert_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_image_block_success(self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock) -> None: + """Test successful image rendering in PPTX.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image_data", + format="png", + width=200, + height=150, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Image Slide", + content=[ + ImageBlock( + url="https://example.com/image.png", + caption="Test Image", + ) + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify image was fetched with keyword arguments (new signature) + mock_fetcher.fetch.assert_called_once_with(url="https://example.com/image.png", image_id=None) + # Verify add_picture was called + mock_slide.shapes.add_picture.assert_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_image_block_failure_placeholder( + self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock + ) -> None: + """Test image failure shows placeholder text.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_textbox.text_frame.paragraphs = [MagicMock()] + + mock_slide.shapes.add_textbox.return_value = mock_textbox + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher failure + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=False, + error="Connection timeout", + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Image Slide", + content=[ImageBlock(url="https://example.com/fail.png")], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify placeholder textbox was created (not add_picture) + mock_slide.shapes.add_textbox.assert_called() + mock_slide.shapes.add_picture.assert_not_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_mixed_content(self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock) -> None: + """Test slide with multiple content block types.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_textbox.text_frame.paragraphs = [MagicMock()] + mock_table_shape = MagicMock() + mock_table = MagicMock() + mock_table_shape.table = mock_table + + # Mock table cells + mock_cells = {} + for row in range(2): + for col in range(2): + cell = MagicMock() + cell.text_frame.paragraphs = [MagicMock()] + mock_cells[(row, col)] = cell + mock_table.cell = lambda r, c: mock_cells[(r, c)] # type: ignore[misc] + + mock_slide.shapes.add_textbox.return_value = mock_textbox + mock_slide.shapes.add_table.return_value = mock_table_shape + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image", + format="png", + width=100, + height=100, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Mixed Content", + content=[ + HeadingBlock(content="Introduction", level=2), + TextBlock(content="Some intro text here."), + ListBlock(items=["Point 1", "Point 2"], ordered=True), + TableBlock(headers=["A", "B"], rows=[["1", "2"]]), + ImageBlock(url="https://example.com/chart.png"), + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify all content types were rendered + # Multiple textbox calls for heading, text, list + assert mock_slide.shapes.add_textbox.call_count >= 3 + # One table call + mock_slide.shapes.add_table.assert_called_once() + # One picture call + mock_slide.shapes.add_picture.assert_called_once() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_text_block_with_style(self, mock_presentation: MagicMock, mock_image_fetcher: MagicMock) -> None: + """Test text block with style attribute.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_paragraph = MagicMock() + mock_textbox.text_frame.paragraphs = [mock_paragraph] + mock_textbox.text_frame.word_wrap = True + + mock_slide.shapes.add_textbox.return_value = mock_textbox + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Styled Text", + content=[ + TextBlock(content="Bold text", style="bold"), + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + mock_slide.shapes.add_textbox.assert_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_image_with_specified_width(self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock) -> None: + """Test image block with width parameter.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image_data", + format="png", + width=800, + height=600, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Image with Width", + content=[ + ImageBlock( + url="https://example.com/image.png", + width=288, # 4 inches at 72 DPI + ) + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + mock_slide.shapes.add_picture.assert_called() + + +class TestPresentationSpecImageSlides: + """Tests for PresentationSpec with image_slides mode.""" + + def test_create_presentation_spec_image_slides_mode(self) -> None: + """Test creating PresentationSpec with image_slides mode.""" + spec = PresentationSpec( + mode="image_slides", + title="AI Generated Presentation", + author="AI Agent", + image_slides=[ + ImageSlideSpec(image_id="abc-123-456-def", notes="Welcome slide"), + ImageSlideSpec(image_id="ghi-789-012-jkl", notes="Content slide"), + ImageSlideSpec(image_id="mno-345-678-pqr"), # No notes + ], + ) + assert spec.mode == "image_slides" + assert len(spec.image_slides) == 3 + assert spec.image_slides[0].image_id == "abc-123-456-def" + assert spec.image_slides[0].notes == "Welcome slide" + assert spec.image_slides[2].notes is None + + def test_create_presentation_spec_structured_mode_default(self) -> None: + """Test that structured mode is the default.""" + spec = PresentationSpec( + title="Traditional Presentation", + slides=[ + SlideSpec(layout="title", title="Welcome"), + ], + ) + assert spec.mode == "structured" + assert len(spec.slides) == 1 + + def test_image_slide_spec_json_roundtrip(self) -> None: + """Test ImageSlideSpec JSON serialization.""" + spec = PresentationSpec( + mode="image_slides", + image_slides=[ + ImageSlideSpec(image_id="test-uuid", notes="Test notes"), + ], + ) + json_str = spec.model_dump_json() + parsed = PresentationSpec.model_validate_json(json_str) + assert parsed.mode == "image_slides" + assert parsed.image_slides[0].image_id == "test-uuid" + + +class TestImageBlockWithImageId: + """Tests for ImageBlock with image_id field.""" + + def test_image_block_with_url(self) -> None: + """Test ImageBlock with URL (traditional).""" + block = ImageBlock(url="https://example.com/image.png", caption="Test") + assert block.url == "https://example.com/image.png" + assert block.image_id is None + assert block.caption == "Test" + + def test_image_block_with_image_id(self) -> None: + """Test ImageBlock with image_id (new feature).""" + block = ImageBlock(image_id="abc-123-uuid", caption="Generated image") + assert block.url is None + assert block.image_id == "abc-123-uuid" + assert block.caption == "Generated image" + + def test_image_block_both_url_and_image_id(self) -> None: + """Test ImageBlock can have both (though one is preferred).""" + block = ImageBlock( + url="https://example.com/fallback.png", + image_id="abc-123-uuid", + ) + assert block.url is not None + assert block.image_id is not None + + +class TestPptxImageSlidesMode: + """Tests for PPTX generation with image_slides mode.""" + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_create_pptx_image_slides_success( + self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock + ) -> None: + """Test PPTX generation with image_slides mode.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + + # Mock slide dimensions + mock_prs.slide_width = MagicMock() + mock_prs.slide_height = MagicMock() + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher returning success + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image_data", + format="png", + width=1920, + height=1080, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + mode="image_slides", + title="Generated Presentation", + image_slides=[ + ImageSlideSpec(image_id="slide-1-uuid", notes="Speaker notes 1"), + ImageSlideSpec(image_id="slide-2-uuid", notes="Speaker notes 2"), + ], + ) + + handler.create_content(spec.model_dump_json()) + + # Verify image fetcher was called with image_id + assert mock_fetcher.fetch.call_count == 2 + mock_fetcher.fetch.assert_any_call(image_id="slide-1-uuid") + mock_fetcher.fetch.assert_any_call(image_id="slide-2-uuid") + + # Verify slides were added with full-bleed images + assert mock_prs.slides.add_slide.call_count == 2 + assert mock_slide.shapes.add_picture.call_count == 2 + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_create_pptx_image_slides_failure_shows_placeholder( + self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock + ) -> None: + """Test PPTX generation shows placeholder when image fails.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + mock_textbox = MagicMock() + mock_textbox.text_frame.paragraphs = [MagicMock()] + + mock_prs.slide_width = MagicMock() + mock_prs.slide_height = MagicMock() + mock_slide.shapes.add_textbox.return_value = mock_textbox + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher returning failure + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=False, + error="Image not found", + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + mode="image_slides", + image_slides=[ + ImageSlideSpec(image_id="missing-uuid"), + ], + ) + + handler.create_content(spec.model_dump_json()) + + # Verify placeholder textbox was added instead of picture + mock_slide.shapes.add_textbox.assert_called() + mock_slide.shapes.add_picture.assert_not_called() + + @patch("app.tools.utils.documents.image_fetcher.ImageFetcher") + @patch("pptx.Presentation") + def test_render_image_block_with_image_id( + self, mock_presentation: MagicMock, mock_image_fetcher_cls: MagicMock + ) -> None: + """Test rendering ImageBlock with image_id in structured slides.""" + handler = PptxFileHandler() + mock_prs = MagicMock() + mock_slide = MagicMock() + + mock_prs.slides.add_slide.return_value = mock_slide + mock_prs.slides.__bool__ = lambda self: True # type: ignore[method-assign] + mock_presentation.return_value = mock_prs + + # Mock image fetcher + mock_fetcher = MagicMock() + mock_fetcher.fetch.return_value = FetchedImage( + success=True, + data=b"fake_image_data", + format="png", + width=400, + height=300, + ) + mock_image_fetcher_cls.return_value = mock_fetcher + + spec = PresentationSpec( + slides=[ + SlideSpec( + layout="title_content", + title="Image from generate_image", + content=[ + ImageBlock( + image_id="generated-image-uuid", + caption="AI Generated Chart", + ) + ], + ) + ] + ) + + handler.create_content(spec.model_dump_json()) + + # Verify image fetcher was called with image_id (not url) + mock_fetcher.fetch.assert_called_once() + call_kwargs = mock_fetcher.fetch.call_args + # The call should have image_id set + assert call_kwargs[1].get("image_id") == "generated-image-uuid" or ( + call_kwargs[0] == () and call_kwargs[1].get("url") is None + ) diff --git a/service/tests/unit/tools/utils/documents/test_image_fetcher.py b/service/tests/unit/tools/utils/documents/test_image_fetcher.py new file mode 100644 index 00000000..bc974672 --- /dev/null +++ b/service/tests/unit/tools/utils/documents/test_image_fetcher.py @@ -0,0 +1,378 @@ +""" +Tests for ImageFetcher service. +""" + +import base64 +import io +from unittest.mock import MagicMock, patch + + +from app.tools.utils.documents.image_fetcher import ( + DEFAULT_TIMEOUT, + MAX_IMAGE_DIMENSION, + MAX_IMAGE_SIZE_BYTES, + FetchedImage, + ImageFetcher, +) + + +class TestFetchedImage: + def test_success_result(self) -> None: + result = FetchedImage( + success=True, + data=b"image_data", + format="png", + width=100, + height=100, + ) + assert result.success + assert result.data == b"image_data" + assert result.error is None + + def test_failure_result(self) -> None: + result = FetchedImage(success=False, error="Connection failed") + assert not result.success + assert result.error == "Connection failed" + assert result.data is None + + +class TestImageFetcher: + def test_init_defaults(self) -> None: + fetcher = ImageFetcher() + assert fetcher.timeout == DEFAULT_TIMEOUT + assert fetcher.max_size_bytes == MAX_IMAGE_SIZE_BYTES + assert fetcher.max_dimension == MAX_IMAGE_DIMENSION + + def test_init_custom_values(self) -> None: + fetcher = ImageFetcher(timeout=10.0, max_size_bytes=1000, max_dimension=500) + assert fetcher.timeout == 10.0 + assert fetcher.max_size_bytes == 1000 + assert fetcher.max_dimension == 500 + + def test_unsupported_scheme(self) -> None: + fetcher = ImageFetcher() + result = fetcher.fetch("ftp://example.com/image.png") + assert not result.success + assert "Unsupported URL scheme" in (result.error or "") + + +class TestImageFetcherHTTP: + @patch("httpx.Client") + def test_fetch_http_success(self, mock_client_cls: MagicMock) -> None: + # Create a minimal valid PNG (1x1 pixel) + png_data = self._create_minimal_png() + + mock_response = MagicMock() + mock_response.content = png_data + mock_response.headers = {"content-length": str(len(png_data))} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/image.png") + + assert result.success + assert result.data is not None + assert result.format == "png" + mock_client.get.assert_called_once_with("https://example.com/image.png") + + @patch("httpx.Client") + def test_fetch_http_timeout(self, mock_client_cls: MagicMock) -> None: + import httpx + + mock_client = MagicMock() + mock_client.get.side_effect = httpx.TimeoutException("timeout") + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/image.png") + + assert not result.success + assert "Timeout" in (result.error or "") + + @patch("httpx.Client") + def test_fetch_http_status_error(self, mock_client_cls: MagicMock) -> None: + import httpx + + mock_response = MagicMock() + mock_response.status_code = 404 + + mock_client = MagicMock() + mock_client.get.side_effect = httpx.HTTPStatusError("Not found", request=MagicMock(), response=mock_response) + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/image.png") + + assert not result.success + assert "404" in (result.error or "") + + @patch("httpx.Client") + def test_fetch_http_too_large_header(self, mock_client_cls: MagicMock) -> None: + mock_response = MagicMock() + mock_response.headers = {"content-length": str(MAX_IMAGE_SIZE_BYTES + 1)} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/large.png") + + assert not result.success + assert "too large" in (result.error or "").lower() + + @patch("httpx.Client") + def test_fetch_http_too_large_content(self, mock_client_cls: MagicMock) -> None: + mock_response = MagicMock() + mock_response.headers = {} # No content-length header + mock_response.content = b"x" * (MAX_IMAGE_SIZE_BYTES + 1) + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value = mock_client + + fetcher = ImageFetcher() + result = fetcher.fetch("https://example.com/large.png") + + assert not result.success + assert "too large" in (result.error or "").lower() + + def _create_minimal_png(self) -> bytes: + """Create a minimal valid 1x1 PNG image.""" + from PIL import Image + + img = Image.new("RGB", (1, 1), color="red") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + return buffer.getvalue() + + +class TestImageFetcherBase64: + def test_fetch_base64_valid_png(self) -> None: + # Create a minimal valid PNG + from PIL import Image + + img = Image.new("RGB", (10, 10), color="blue") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + fetcher = ImageFetcher() + result = fetcher.fetch(data_url) + + assert result.success + assert result.data is not None + assert result.format == "png" + assert result.width == 10 + assert result.height == 10 + + def test_fetch_base64_valid_jpeg(self) -> None: + from PIL import Image + + img = Image.new("RGB", (20, 15), color="green") + buffer = io.BytesIO() + img.save(buffer, format="JPEG") + jpeg_bytes = buffer.getvalue() + + b64_data = base64.b64encode(jpeg_bytes).decode("utf-8") + data_url = f"data:image/jpeg;base64,{b64_data}" + + fetcher = ImageFetcher() + result = fetcher.fetch(data_url) + + assert result.success + assert result.format in ("jpg", "jpeg") + assert result.width == 20 + assert result.height == 15 + + def test_fetch_base64_invalid_format(self) -> None: + fetcher = ImageFetcher() + result = fetcher.fetch("data:text/plain;base64,SGVsbG8=") + assert not result.success + assert "Invalid" in (result.error or "") + + def test_fetch_base64_invalid_data(self) -> None: + fetcher = ImageFetcher() + result = fetcher.fetch("!!data") + assert not result.success + # Either decode error or image processing error + assert result.error is not None + + def test_fetch_base64_too_large(self) -> None: + from PIL import Image + + # Create an image that exceeds the limit when decoded + img = Image.new("RGB", (100, 100), color="red") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + # Use a small limit to trigger the check + fetcher = ImageFetcher(max_size_bytes=100) + result = fetcher.fetch(data_url) + + assert not result.success + assert "too large" in (result.error or "").lower() + + +class TestImageFetcherResize: + def test_resize_large_image(self) -> None: + from PIL import Image + + # Create an image larger than max dimension + large_size = MAX_IMAGE_DIMENSION + 500 + img = Image.new("RGB", (large_size, large_size // 2), color="purple") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + fetcher = ImageFetcher() + result = fetcher.fetch(data_url) + + assert result.success + # Image should be resized + assert result.width is not None + assert result.height is not None + assert result.width <= MAX_IMAGE_DIMENSION + assert result.height <= MAX_IMAGE_DIMENSION + + def test_no_resize_small_image(self) -> None: + from PIL import Image + + small_size = 100 + img = Image.new("RGB", (small_size, small_size), color="yellow") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + fetcher = ImageFetcher() + result = fetcher.fetch(data_url) + + assert result.success + assert result.width == small_size + assert result.height == small_size + + +class TestImageFetcherStorage: + @patch("app.core.storage.get_storage_service") + def test_fetch_from_storage_success(self, mock_get_storage: MagicMock) -> None: + from PIL import Image + + # Create test image + img = Image.new("RGB", (50, 50), color="cyan") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + # Mock storage service + mock_storage = MagicMock() + + async def mock_download(key: str, output_buffer: io.BytesIO) -> None: + output_buffer.write(png_bytes) + + mock_storage.download_file = mock_download + mock_get_storage.return_value = mock_storage + + fetcher = ImageFetcher() + result = fetcher.fetch("storage://path/to/image.png") + + assert result.success + assert result.width == 50 + assert result.height == 50 + + @patch("app.core.storage.get_storage_service") + def test_fetch_from_storage_failure(self, mock_get_storage: MagicMock) -> None: + mock_storage = MagicMock() + + async def mock_download(key: str, output_buffer: io.BytesIO) -> None: + raise FileNotFoundError("File not found") + + mock_storage.download_file = mock_download + mock_get_storage.return_value = mock_storage + + fetcher = ImageFetcher() + result = fetcher.fetch("storage://path/to/missing.png") + + assert not result.success + assert "Storage fetch failed" in (result.error or "") + + +class TestImageFetcherByImageId: + """Tests for fetching images by image_id (UUID).""" + + def test_fetch_invalid_uuid_format(self) -> None: + """Test that invalid UUID format returns error.""" + fetcher = ImageFetcher() + result = fetcher.fetch(image_id="not-a-valid-uuid") + + assert not result.success + assert "Invalid image_id format" in (result.error or "") + + def test_fetch_requires_url_or_image_id(self) -> None: + """Test that fetch fails when neither url nor image_id is provided.""" + fetcher = ImageFetcher() + result = fetcher.fetch() + + assert not result.success + assert "Either url or image_id must be provided" in (result.error or "") + + def test_fetch_with_url_still_works(self) -> None: + """Test that fetch with url parameter still works (backward compat).""" + from PIL import Image + + # Create a minimal valid PNG + img = Image.new("RGB", (10, 10), color="blue") + buffer = io.BytesIO() + img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + b64_data = base64.b64encode(png_bytes).decode("utf-8") + data_url = f"data:image/png;base64,{b64_data}" + + fetcher = ImageFetcher() + # Test with keyword argument url= + result = fetcher.fetch(url=data_url) + + assert result.success + assert result.format == "png" + assert result.width == 10 + + def test_fetch_by_image_id_returns_resolution_error(self) -> None: + """Test that image_id returns error about needing resolution in async layer.""" + from uuid import uuid4 + + fetcher = ImageFetcher() + test_uuid = str(uuid4()) + result = fetcher.fetch(image_id=test_uuid) + + assert not result.success + # Should explain that image_id needs to be resolved in async layer + assert "not resolved" in (result.error or "").lower() or "async layer" in (result.error or "").lower() From fc4f6f491f3a62de11b01a439c975a9ad9da9bdc Mon Sep 17 00:00:00 2001 From: "xinquiry(SII)" <100398322+xinquiry@users.noreply.github.com> Date: Thu, 22 Jan 2026 20:18:59 +0800 Subject: [PATCH 4/8] fix: fix the first time calling knowledge tool error (#194) --- service/app/core/checkin.py | 25 +++++------ service/app/infra/database/__init__.py | 2 + service/app/infra/database/connection.py | 44 +++++++++++++++++++ .../app/tools/builtin/knowledge/operations.py | 10 ++--- .../unit/test_literature/test_base_client.py | 2 +- 5 files changed, 62 insertions(+), 21 deletions(-) diff --git a/service/app/core/checkin.py b/service/app/core/checkin.py index 4837fc98..7ab99e6a 100644 --- a/service/app/core/checkin.py +++ b/service/app/core/checkin.py @@ -2,6 +2,7 @@ import logging from datetime import datetime, timedelta, timezone +from typing import TypedDict from sqlmodel.ext.asyncio.session import AsyncSession @@ -16,6 +17,13 @@ CHECKIN_TZ = timezone(timedelta(hours=8)) +class CheckInStatus(TypedDict): + checked_in_today: bool + consecutive_days: int + next_points: int + total_check_ins: int + + class CheckInService: """Service layer for check-in operations.""" @@ -58,19 +66,7 @@ def calculate_points(consecutive_days: int) -> int: Returns: Points to award. """ - if consecutive_days <= 0: - return 10 - elif consecutive_days == 1: - return 10 - elif consecutive_days == 2: - return 20 - elif consecutive_days == 3: - return 30 - elif consecutive_days == 4: - return 40 - else: - # Day 5 and beyond: 50 points - return 50 + return 10 * max(1, min(consecutive_days, 5)) async def check_in(self, user_id: str) -> tuple[CheckIn, int]: """ @@ -101,7 +97,6 @@ async def check_in(self, user_id: str) -> tuple[CheckIn, int]: existing_check_in = await self.check_in_repo.get_check_in_by_user_and_date(user_id, today) if existing_check_in: logger.warning(f"User {user_id} has already checked in today") - # raise ErrCodeError(ErrCode.ALREADY_CHECKED_IN_TODAY, "您今天已经签到过了哦~") raise ErrCodeError(ErrCode.ALREADY_CHECKED_IN_TODAY) # Get latest check-in to calculate consecutive days @@ -147,7 +142,7 @@ async def check_in(self, user_id: str) -> tuple[CheckIn, int]: return check_in, wallet.virtual_balance - async def get_check_in_status(self, user_id: str) -> dict: + async def get_check_in_status(self, user_id: str) -> CheckInStatus: """ Get check-in status for a user. diff --git a/service/app/infra/database/__init__.py b/service/app/infra/database/__init__.py index 4eac26fe..b50a9dc6 100644 --- a/service/app/infra/database/__init__.py +++ b/service/app/infra/database/__init__.py @@ -5,6 +5,7 @@ create_task_session_factory, engine, get_session, + get_task_db_session, ) __all__ = [ @@ -14,4 +15,5 @@ "AsyncSessionLocal", "ASYNC_DATABASE_URL", "create_task_session_factory", + "get_task_db_session", ] diff --git a/service/app/infra/database/connection.py b/service/app/infra/database/connection.py index a8afa96a..56314693 100644 --- a/service/app/infra/database/connection.py +++ b/service/app/infra/database/connection.py @@ -1,4 +1,6 @@ +import os from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine @@ -138,3 +140,45 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]: """ async with AsyncSessionLocal() as session: yield session + + +_worker_engines: dict[int, async_sessionmaker[AsyncSession]] = {} + + +@asynccontextmanager +async def get_task_db_session(): + """ + Get a database session suitable for Celery Worker / tool execution contexts. + + Each call creates a new engine bound to the current event loop to avoid cross-process issues. + + Usage: + async with get_task_db_session() as db: + result = await db.execute(...) + """ + # task_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False, future=True) + # TaskSessionLocal = async_sessionmaker( + # bind=task_engine, + # class_=AsyncSession, + # expire_on_commit=False, + # ) + + # async with TaskSessionLocal() as session: + # try: + # yield session + # finally: + # await task_engine.dispose() + + # Better version that reuses engines per process + pid = os.getpid() + + if pid not in _worker_engines: + task_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False, future=True) + _worker_engines[pid] = async_sessionmaker( + bind=task_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + async with _worker_engines[pid]() as session: + yield session diff --git a/service/app/tools/builtin/knowledge/operations.py b/service/app/tools/builtin/knowledge/operations.py index eb4f386d..b1a8d59d 100644 --- a/service/app/tools/builtin/knowledge/operations.py +++ b/service/app/tools/builtin/knowledge/operations.py @@ -17,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from app.core.storage import FileCategory, FileScope, generate_storage_key, get_storage_service -from app.infra.database import AsyncSessionLocal +from app.infra.database import get_task_db_session from app.models.file import FileCreate from app.repos.file import FileRepository from app.repos.knowledge_set import KnowledgeSetRepository @@ -138,7 +138,7 @@ async def get_files_in_knowledge_set(db: AsyncSession, user_id: str, knowledge_s async def list_files(user_id: str, knowledge_set_id: UUID) -> dict[str, Any]: """List all files in the knowledge set.""" try: - async with AsyncSessionLocal() as db: + async with get_task_db_session() as db: file_repo = FileRepository(db) try: @@ -178,7 +178,7 @@ async def read_file(user_id: str, knowledge_set_id: UUID, filename: str) -> dict # Normalize filename filename = filename.strip("/").split("/")[-1] - async with AsyncSessionLocal() as db: + async with get_task_db_session() as db: file_repo = FileRepository(db) target_file = None @@ -229,7 +229,7 @@ async def write_file(user_id: str, knowledge_set_id: UUID, filename: str, conten try: filename = filename.strip("/").split("/")[-1] - async with AsyncSessionLocal() as db: + async with get_task_db_session() as db: file_repo = FileRepository(db) knowledge_set_repo = KnowledgeSetRepository(db) storage = get_storage_service() @@ -309,7 +309,7 @@ async def write_file(user_id: str, knowledge_set_id: UUID, filename: str, conten async def search_files(user_id: str, knowledge_set_id: UUID, query: str) -> dict[str, Any]: """Search for files by name in the knowledge set.""" try: - async with AsyncSessionLocal() as db: + async with get_task_db_session() as db: file_repo = FileRepository(db) matches: list[str] = [] diff --git a/service/tests/unit/test_literature/test_base_client.py b/service/tests/unit/test_literature/test_base_client.py index a692737a..95d03b2a 100644 --- a/service/tests/unit/test_literature/test_base_client.py +++ b/service/tests/unit/test_literature/test_base_client.py @@ -138,7 +138,7 @@ def test_literature_work_minimal(self) -> None: def test_literature_work_complete(self) -> None: """Test LiteratureWork with all fields.""" - authors = [ + authors: list[dict[str, str | None]] = [ {"name": "John Doe", "id": "A1"}, {"name": "Jane Smith", "id": "A2"}, ] From 6997cd8f88beb0218415fe9b77472277e24dd7e5 Mon Sep 17 00:00:00 2001 From: "xinquiry(SII)" <100398322+xinquiry@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:15:22 +0800 Subject: [PATCH 5/8] fix: fix the wrong cache for second call of agent tools (#195) --- service/app/infra/database/connection.py | 37 +++++++++--------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/service/app/infra/database/connection.py b/service/app/infra/database/connection.py index 56314693..a59a7ff6 100644 --- a/service/app/infra/database/connection.py +++ b/service/app/infra/database/connection.py @@ -1,3 +1,4 @@ +import asyncio import os from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -142,7 +143,7 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]: yield session -_worker_engines: dict[int, async_sessionmaker[AsyncSession]] = {} +_worker_engines: dict[tuple[int, int], async_sessionmaker[AsyncSession]] = {} @asynccontextmanager @@ -150,35 +151,25 @@ async def get_task_db_session(): """ Get a database session suitable for Celery Worker / tool execution contexts. - Each call creates a new engine bound to the current event loop to avoid cross-process issues. - - Usage: - async with get_task_db_session() as db: - result = await db.execute(...) + Each call creates a new engine bound to the current event loop to avoid cross-loop issues. """ - # task_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False, future=True) - # TaskSessionLocal = async_sessionmaker( - # bind=task_engine, - # class_=AsyncSession, - # expire_on_commit=False, - # ) - - # async with TaskSessionLocal() as session: - # try: - # yield session - # finally: - # await task_engine.dispose() - - # Better version that reuses engines per process pid = os.getpid() + loop_id = id(asyncio.get_running_loop()) + cache_key = (pid, loop_id) + + if cache_key not in _worker_engines: + # Clean up old engines from this process but different loops + old_keys = [k for k in _worker_engines if k[0] == pid and k[1] != loop_id] + for old_key in old_keys: + # Optionally dispose old engines (fire and forget) + del _worker_engines[old_key] - if pid not in _worker_engines: task_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False, future=True) - _worker_engines[pid] = async_sessionmaker( + _worker_engines[cache_key] = async_sessionmaker( bind=task_engine, class_=AsyncSession, expire_on_commit=False, ) - async with _worker_engines[pid]() as session: + async with _worker_engines[cache_key]() as session: yield session From af94399f58ddf618ab1ba3b87a74b52f893a38ab Mon Sep 17 00:00:00 2001 From: "xinquiry(SII)" <100398322+xinquiry@users.noreply.github.com> Date: Thu, 22 Jan 2026 22:40:44 +0800 Subject: [PATCH 6/8] feat: several improvements (#196) * fix: jump to latest topic when click agent * feat: allow more than one image for generate image * feat: allow user directly edit mcp in the chat-toolbar * feat: improve the frontend perf --- service/app/api/v1/sessions.py | 11 +- service/app/core/session/service.py | 14 + service/app/tools/builtin/image.py | 167 ++++++---- service/app/tools/cost.py | 7 +- service/tests/unit/tools/test_cost.py | 24 +- .../components/features/FileUploadPreview.tsx | 10 +- .../features/FileUploadThumbnail.tsx | 58 +++- .../layouts/components/ChatToolbar.tsx | 13 +- .../components/ChatToolbar/McpToolsButton.tsx | 284 ++++++++++++------ .../components/ChatToolbar/MobileMoreMenu.tsx | 247 +++++++++++++-- web/src/i18n/locales/en/app.json | 6 + web/src/i18n/locales/zh/app.json | 6 + web/src/lib/Markdown.tsx | 19 +- web/src/service/fileService.ts | 54 +++- web/src/store/slices/chatSlice.ts | 26 +- 15 files changed, 712 insertions(+), 234 deletions(-) diff --git a/service/app/api/v1/sessions.py b/service/app/api/v1/sessions.py index f4ef323a..94db7476 100644 --- a/service/app/api/v1/sessions.py +++ b/service/app/api/v1/sessions.py @@ -72,16 +72,17 @@ async def create_session_with_default_topic( raise handle_auth_error(e) -@router.get("/by-agent/{agent_id}", response_model=SessionRead) +@router.get("/by-agent/{agent_id}", response_model=SessionReadWithTopics) async def get_session_by_agent( agent_id: str, user: str = Depends(get_current_user), db: AsyncSession = Depends(get_session) -) -> SessionRead: +) -> SessionReadWithTopics: """ - Retrieve a session for the current user with a specific agent. + Retrieve a session for the current user with a specific agent, including topics. Finds a session associated with the given agent ID for the authenticated user. The agent_id can be "default" for sessions without an agent, a UUID string for sessions with a specific agent, or a builtin agent string ID. + Topics are ordered by updated_at descending (most recent first). Args: agent_id: Agent identifier ("default", UUID string, or builtin agent ID) @@ -89,13 +90,13 @@ async def get_session_by_agent( db: Database session (injected by dependency) Returns: - SessionRead: The session associated with the user and agent + SessionReadWithTopics: The session with topics associated with the user and agent Raises: HTTPException: 404 if no session found for this user-agent combination """ try: - return await SessionService(db).get_session_by_agent(user, agent_id) + return await SessionService(db).get_session_by_agent_with_topics(user, agent_id) except ErrCodeError as e: raise handle_auth_error(e) diff --git a/service/app/core/session/service.py b/service/app/core/session/service.py index 53d0e729..fd9bfc14 100644 --- a/service/app/core/session/service.py +++ b/service/app/core/session/service.py @@ -49,6 +49,20 @@ async def get_session_by_agent(self, user_id: str, agent_id: str) -> SessionRead raise ErrCode.SESSION_NOT_FOUND.with_messages("No session found for this user-agent combination") return SessionRead(**session.model_dump()) + async def get_session_by_agent_with_topics(self, user_id: str, agent_id: str) -> SessionReadWithTopics: + agent_uuid = await self._resolve_agent_uuid_for_lookup(agent_id) + session = await self.session_repo.get_session_by_user_and_agent(user_id, agent_uuid) + if not session: + raise ErrCode.SESSION_NOT_FOUND.with_messages("No session found for this user-agent combination") + + # Fetch topics ordered by updated_at descending (most recent first) + topics = await self.topic_repo.get_topics_by_session(session.id, order_by_updated=True) + topic_reads = [TopicRead(**topic.model_dump()) for topic in topics] + + session_dict = session.model_dump() + session_dict["topics"] = topic_reads + return SessionReadWithTopics(**session_dict) + async def get_sessions_with_topics(self, user_id: str) -> list[SessionReadWithTopics]: sessions = await self.session_repo.get_sessions_by_user_ordered_by_activity(user_id) diff --git a/service/app/tools/builtin/image.py b/service/app/tools/builtin/image.py index c5985808..f9c10061 100644 --- a/service/app/tools/builtin/image.py +++ b/service/app/tools/builtin/image.py @@ -14,7 +14,7 @@ from uuid import UUID, uuid4 from langchain_core.tools import BaseTool, StructuredTool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from app.configs import configs from app.core.storage import FileScope, generate_storage_key, get_storage_service @@ -24,6 +24,9 @@ # --- Input Schemas --- +# Maximum number of reference images allowed for generation +MAX_INPUT_IMAGES = 4 + class GenerateImageInput(BaseModel): """Input schema for generate_image tool.""" @@ -35,14 +38,24 @@ class GenerateImageInput(BaseModel): default="1:1", description="Aspect ratio of the generated image.", ) - image_id: str | None = Field( + image_ids: list[str] | None = Field( default=None, description=( - "Optional image UUID to use as a reference input. " - "Use the 'image_id' value returned from generate_image or upload tools." + f"Optional list of image UUIDs (max {MAX_INPUT_IMAGES}) to use as reference inputs. " + "Use the 'image_id' values returned from generate_image or upload tools." ), ) + @model_validator(mode="after") + def validate_image_inputs(self) -> "GenerateImageInput": + """Validate image_ids field.""" + if self.image_ids: + if len(self.image_ids) > MAX_INPUT_IMAGES: + raise ValueError(f"Maximum {MAX_INPUT_IMAGES} input images allowed, got {len(self.image_ids)}") + if len(self.image_ids) == 0: + self.image_ids = None # Normalize empty list to None + return self + class ReadImageInput(BaseModel): """Input schema for read_image tool.""" @@ -62,8 +75,7 @@ class ReadImageInput(BaseModel): async def _generate_image_with_langchain( prompt: str, aspect_ratio: str = "1:1", - image_bytes: bytes | None = None, - image_mime_type: str | None = None, + images: list[tuple[bytes, str]] | None = None, ) -> tuple[bytes, str]: """ Generate an image using LangChain ChatGoogleGenerativeAI via ProviderManager. @@ -76,6 +88,7 @@ async def _generate_image_with_langchain( Args: prompt: Text description of the image to generate aspect_ratio: Aspect ratio for the generated image + images: Optional list of (image_bytes, mime_type) tuples to use as references Returns: Tuple of (image_bytes, mime_type) @@ -102,25 +115,34 @@ async def _generate_image_with_langchain( ) # Request image generation via LangChain - if image_bytes and image_mime_type: - b64_data = base64.b64encode(image_bytes).decode("utf-8") - message = HumanMessage( - content=[ + if images: + # Build content array with multiple image_url blocks + content: list[dict[str, Any]] = [] + for image_bytes, image_mime_type in images: + b64_data = base64.b64encode(image_bytes).decode("utf-8") + content.append( { "type": "image_url", "image_url": { "url": f"data:{image_mime_type};base64,{b64_data}", }, - }, - { - "type": "text", - "text": ( - "Use the provided image as a reference. " - f"Generate a new image with aspect ratio {aspect_ratio}: {prompt}" - ), - }, - ] + } + ) + + # Add text prompt with appropriate phrasing for single vs multiple images + image_count = len(images) + if image_count == 1: + reference_text = "Use the provided image as a reference." + else: + reference_text = f"Use these {image_count} provided images as references." + + content.append( + { + "type": "text", + "text": f"{reference_text} Generate a new image with aspect ratio {aspect_ratio}: {prompt}", + } ) + message = HumanMessage(content=content) # type: ignore[arg-type] else: message = HumanMessage(content=f"Generate an image with aspect ratio {aspect_ratio}: {prompt}") response = await llm.ainvoke([message]) @@ -172,46 +194,67 @@ async def _generate_image_with_langchain( raise ValueError("No image data in response. Model may not support image generation.") -async def _load_image_for_generation(user_id: str, image_id: str) -> tuple[bytes, str, str]: +async def _load_images_for_generation(user_id: str, image_ids: list[str]) -> list[tuple[bytes, str, str]]: + """ + Load multiple images for generation from the database. + + Args: + user_id: User ID for permission check + image_ids: List of image UUIDs to load + + Returns: + List of tuples: (image_bytes, mime_type, storage_key) + + Raises: + ValueError: If any image_id is invalid, not found, deleted, or inaccessible + """ from app.infra.database import create_task_session_factory from app.repos.file import FileRepository - try: - file_uuid = UUID(image_id) - except ValueError as exc: - raise ValueError(f"Invalid image_id format: {image_id}") from exc + results: list[tuple[bytes, str, str]] = [] # Create a fresh session factory for the current event loop (Celery worker) TaskSessionLocal = create_task_session_factory() async with TaskSessionLocal() as db: file_repo = FileRepository(db) - file_record = await file_repo.get_file_by_id(file_uuid) + storage = get_storage_service() - if file_record is None: - raise ValueError(f"Image not found: {image_id}") + for image_id in image_ids: + try: + file_uuid = UUID(image_id) + except ValueError as exc: + raise ValueError(f"Invalid image_id format: {image_id}") from exc - if file_record.is_deleted: - raise ValueError(f"Image has been deleted: {image_id}") + file_record = await file_repo.get_file_by_id(file_uuid) - if file_record.user_id != user_id and file_record.scope != "public": - raise ValueError("Permission denied: you don't have access to this image") + if file_record is None: + raise ValueError(f"Image not found: {image_id}") - storage_key = file_record.storage_key - content_type = file_record.content_type or "image/png" + if file_record.is_deleted: + raise ValueError(f"Image has been deleted: {image_id}") - storage = get_storage_service() - buffer = io.BytesIO() - await storage.download_file(storage_key, buffer) - image_bytes = buffer.getvalue() - return image_bytes, content_type, storage_key + if file_record.user_id != user_id and file_record.scope != "public": + raise ValueError(f"Permission denied: you don't have access to image {image_id}") + + storage_key = file_record.storage_key + content_type = file_record.content_type or "image/png" + + # Download from storage + buffer = io.BytesIO() + await storage.download_file(storage_key, buffer) + image_bytes = buffer.getvalue() + + results.append((image_bytes, content_type, storage_key)) + + return results async def _generate_image( user_id: str, prompt: str, aspect_ratio: str = "1:1", - image_id: str | None = None, + image_ids: list[str] | None = None, ) -> dict[str, Any]: """ Generate an image and store it to OSS, then register in database. @@ -220,28 +263,27 @@ async def _generate_image( user_id: User ID for storage organization prompt: Image description aspect_ratio: Aspect ratio for the image + image_ids: Optional list of image UUIDs to use as reference inputs Returns: Dictionary with success status, path, URL, and metadata """ try: - # Load optional reference image - source_image_bytes = None - source_mime_type = None - source_storage_key = None - source_image_id = image_id - if source_image_id: - source_image_bytes, source_mime_type, source_storage_key = await _load_image_for_generation( - user_id, - source_image_id, - ) + # Load optional reference images + images_for_generation: list[tuple[bytes, str]] | None = None + source_storage_keys: list[str] = [] + source_image_ids: list[str] = image_ids or [] + + if source_image_ids: + loaded_images = await _load_images_for_generation(user_id, source_image_ids) + images_for_generation = [(img[0], img[1]) for img in loaded_images] + source_storage_keys = [img[2] for img in loaded_images] # Generate image using LangChain via ProviderManager image_bytes, mime_type = await _generate_image_with_langchain( prompt, aspect_ratio, - image_bytes=source_image_bytes, - image_mime_type=source_mime_type, + images=images_for_generation, ) # Determine file extension from mime type @@ -290,27 +332,27 @@ async def _generate_image( metainfo={ "prompt": prompt, "aspect_ratio": aspect_ratio, - "source_image_id": source_image_id, - "source_storage_key": source_storage_key, + "source_image_ids": source_image_ids, + "source_storage_keys": source_storage_keys, }, ) file_record = await file_repo.create_file(file_data) await db.commit() # Refresh to get the generated UUID await db.refresh(file_record) - image_id = str(file_record.id) + generated_image_id = str(file_record.id) - logger.info(f"Generated image for user {user_id}: {storage_key} (id={image_id})") + logger.info(f"Generated image for user {user_id}: {storage_key} (id={generated_image_id})") return { "success": True, - "image_id": image_id, + "image_id": generated_image_id, "path": storage_key, "url": url, "markdown": f"![Generated Image]({url})", "prompt": prompt, "aspect_ratio": aspect_ratio, - "source_image_id": source_image_id, + "source_image_ids": source_image_ids, "mime_type": mime_type, "size_bytes": len(image_bytes), } @@ -511,7 +553,7 @@ def create_image_tools() -> dict[str, BaseTool]: async def generate_image_placeholder( prompt: str, aspect_ratio: str = "1:1", - image_id: str | None = None, + image_ids: list[str] | None = None, ) -> dict[str, Any]: return {"error": "Image tools require agent context binding", "success": False} @@ -520,7 +562,7 @@ async def generate_image_placeholder( description=( "Generate an image based on a text description. " "Provide a detailed prompt describing the desired image. " - "To modify or generate based on a previous image, pass the 'image_id' from a previous generate_image result. " + f"To generate based on previous images, pass 'image_ids' with up to {MAX_INPUT_IMAGES} reference image UUIDs. " "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image. " "TIP: You can use 'image_id' values when creating PPTX presentations with knowledge_write - see knowledge_help(topic='image_slides') for details." ), @@ -564,9 +606,9 @@ def create_image_tools_for_agent(user_id: str) -> list[BaseTool]: async def generate_image_bound( prompt: str, aspect_ratio: str = "1:1", - image_id: str | None = None, + image_ids: list[str] | None = None, ) -> dict[str, Any]: - return await _generate_image(user_id, prompt, aspect_ratio, image_id) + return await _generate_image(user_id, prompt, aspect_ratio, image_ids) tools.append( StructuredTool( @@ -574,7 +616,7 @@ async def generate_image_bound( description=( "Generate an image based on a text description. " "Provide a detailed prompt describing the desired image including style, colors, composition, and subject. " - "To modify or generate based on a previous image, pass the 'image_id' from a previous generate_image result. " + f"To generate based on previous images, pass 'image_ids' with up to {MAX_INPUT_IMAGES} reference image UUIDs. " "Returns a JSON result containing 'image_id' (for future reference), 'url', and 'markdown' - use the 'markdown' field directly in your response to display the image to the user. " "TIP: You can use 'image_id' values when creating beautiful PPTX presentations with knowledge_write in image_slides mode - call knowledge_help(topic='image_slides') for the full workflow." ), @@ -611,4 +653,5 @@ async def read_image_bound( "create_image_tools_for_agent", "GenerateImageInput", "ReadImageInput", + "MAX_INPUT_IMAGES", ] diff --git a/service/app/tools/cost.py b/service/app/tools/cost.py index 16921a94..f3dc0c95 100644 --- a/service/app/tools/cost.py +++ b/service/app/tools/cost.py @@ -34,10 +34,11 @@ def calculate_tool_cost( config = tool_info.cost cost = config.base_cost - # Add input image cost (for generate_image with reference) + # Add input image cost (for generate_image with reference images) if config.input_image_cost and tool_args: - if tool_args.get("image_id"): # Has reference image - cost += config.input_image_cost + image_ids = tool_args.get("image_ids") + if image_ids: + cost += config.input_image_cost * len(image_ids) # Add output file cost (for knowledge_write creating new files) if config.output_file_cost and tool_result: diff --git a/service/tests/unit/tools/test_cost.py b/service/tests/unit/tools/test_cost.py index 32b1a662..69ad41c7 100644 --- a/service/tests/unit/tools/test_cost.py +++ b/service/tests/unit/tools/test_cost.py @@ -77,15 +77,35 @@ def test_generate_image_without_reference(self) -> None: assert cost == 10 def test_generate_image_with_reference(self) -> None: - """Test generate_image cost with reference image.""" + """Test generate_image cost with single reference image.""" cost = calculate_tool_cost( tool_name="generate_image", - tool_args={"prompt": "a beautiful sunset", "image_id": "ref123"}, + tool_args={"prompt": "a beautiful sunset", "image_ids": ["ref123"]}, tool_result={"success": True, "image_id": "abc123"}, ) # Base cost (10) + input_image_cost (5) = 15 assert cost == 15 + def test_generate_image_with_multiple_references(self) -> None: + """Test generate_image cost with multiple reference images.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args={"prompt": "combine these images", "image_ids": ["ref1", "ref2", "ref3"]}, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Base cost (10) + input_image_cost (5) * 3 = 25 + assert cost == 25 + + def test_generate_image_with_empty_image_ids(self) -> None: + """Test generate_image cost with empty image_ids list.""" + cost = calculate_tool_cost( + tool_name="generate_image", + tool_args={"prompt": "a sunset", "image_ids": []}, + tool_result={"success": True, "image_id": "abc123"}, + ) + # Base cost only (10), empty list means no input images + assert cost == 10 + def test_read_image_cost(self) -> None: """Test read_image cost.""" cost = calculate_tool_cost( diff --git a/web/src/components/features/FileUploadPreview.tsx b/web/src/components/features/FileUploadPreview.tsx index 3d0d4641..d76a0cd2 100644 --- a/web/src/components/features/FileUploadPreview.tsx +++ b/web/src/components/features/FileUploadPreview.tsx @@ -1,3 +1,4 @@ +import React from "react"; import { useXyzen } from "@/store"; import { FileUploadThumbnail } from "./FileUploadThumbnail"; import clsx from "clsx"; @@ -6,8 +7,11 @@ export interface FileUploadPreviewProps { className?: string; } -export function FileUploadPreview({ className }: FileUploadPreviewProps) { - const { uploadedFiles, isUploading, uploadError } = useXyzen(); +function FileUploadPreviewComponent({ className }: FileUploadPreviewProps) { + // Use selective subscriptions to avoid re-renders from unrelated store changes + const uploadedFiles = useXyzen((state) => state.uploadedFiles); + const isUploading = useXyzen((state) => state.isUploading); + const uploadError = useXyzen((state) => state.uploadError); if (uploadedFiles.length === 0) { return null; @@ -57,3 +61,5 @@ export function FileUploadPreview({ className }: FileUploadPreviewProps) { ); } + +export const FileUploadPreview = React.memo(FileUploadPreviewComponent); diff --git a/web/src/components/features/FileUploadThumbnail.tsx b/web/src/components/features/FileUploadThumbnail.tsx index d31f2e82..44e7a675 100644 --- a/web/src/components/features/FileUploadThumbnail.tsx +++ b/web/src/components/features/FileUploadThumbnail.tsx @@ -1,3 +1,4 @@ +import React, { useCallback } from "react"; import { XMarkIcon, DocumentIcon, @@ -12,20 +13,28 @@ export interface FileUploadThumbnailProps { file: UploadedFile; } -export function FileUploadThumbnail({ file }: FileUploadThumbnailProps) { - const { removeFile, retryUpload } = useXyzen(); +function FileUploadThumbnailComponent({ file }: FileUploadThumbnailProps) { + // Use selective subscriptions to avoid re-renders from unrelated store changes + const removeFile = useXyzen((state) => state.removeFile); + const retryUpload = useXyzen((state) => state.retryUpload); - const handleRemove = (e: React.MouseEvent) => { - e.preventDefault(); - e.stopPropagation(); - removeFile(file.id); - }; + const handleRemove = useCallback( + (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + removeFile(file.id); + }, + [removeFile, file.id], + ); - const handleRetry = (e: React.MouseEvent) => { - e.preventDefault(); - e.stopPropagation(); - retryUpload(file.id); - }; + const handleRetry = useCallback( + (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + retryUpload(file.id); + }, + [retryUpload, file.id], + ); const getFileIcon = () => { if (file.category === "images") { @@ -74,6 +83,7 @@ export function FileUploadThumbnail({ file }: FileUploadThumbnailProps) { {file.name} ) : ( @@ -165,3 +175,27 @@ export function FileUploadThumbnail({ file }: FileUploadThumbnailProps) { ); } + +// Custom comparison function for React.memo +// Only re-render when relevant file properties change +function arePropsEqual( + prevProps: FileUploadThumbnailProps, + nextProps: FileUploadThumbnailProps, +): boolean { + const prevFile = prevProps.file; + const nextFile = nextProps.file; + + return ( + prevFile.id === nextFile.id && + prevFile.status === nextFile.status && + prevFile.progress === nextFile.progress && + prevFile.thumbnailUrl === nextFile.thumbnailUrl && + prevFile.name === nextFile.name && + prevFile.category === nextFile.category + ); +} + +export const FileUploadThumbnail = React.memo( + FileUploadThumbnailComponent, + arePropsEqual, +); diff --git a/web/src/components/layouts/components/ChatToolbar.tsx b/web/src/components/layouts/components/ChatToolbar.tsx index cb47bcc6..03a15624 100644 --- a/web/src/components/layouts/components/ChatToolbar.tsx +++ b/web/src/components/layouts/components/ChatToolbar.tsx @@ -86,6 +86,7 @@ export default function ChatToolbar({ uploadedFiles, isUploading, updateAgent, + openSettingsModal, } = useXyzen(); // All user agents for lookup @@ -288,6 +289,8 @@ export default function ChatToolbar({ agent={currentAgent} onUpdateAgent={updateAgent} mcpInfo={currentMcpInfo} + allMcpServers={mcpServers} + onOpenSettings={() => openSettingsModal("mcp")} sessionKnowledgeSetId={currentChannel?.knowledge_set_id} onUpdateSessionKnowledge={handleKnowledgeSetChange} /> @@ -328,9 +331,15 @@ export default function ChatToolbar({ )} {/* MCP Tool Button */} - {currentMcpInfo && ( + {currentAgent && ( openSettingsModal("mcp")} buttonClassName={cn( toolbarButtonClass, "w-auto px-2 gap-1.5", diff --git a/web/src/components/layouts/components/ChatToolbar/McpToolsButton.tsx b/web/src/components/layouts/components/ChatToolbar/McpToolsButton.tsx index 0ba7fdac..2dd2d549 100644 --- a/web/src/components/layouts/components/ChatToolbar/McpToolsButton.tsx +++ b/web/src/components/layouts/components/ChatToolbar/McpToolsButton.tsx @@ -1,21 +1,22 @@ /** - * MCP Tools Button with hover tooltip + * MCP Tools Button with interactive Popover * - * Displays connected MCP servers and their available tools. + * Allows users to toggle MCP servers on/off for the current agent. */ import McpIcon from "@/assets/McpIcon"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; import { cn } from "@/lib/utils"; import type { Agent } from "@/types/agents"; +import type { McpServer } from "@/types/mcp"; +import { CheckIcon, Cog6ToothIcon } from "@heroicons/react/24/outline"; +import { useState } from "react"; import { useTranslation } from "react-i18next"; -interface McpServer { - id: string; - name: string; - status?: string; // "online" | "offline" or other statuses - tools?: Array<{ name: string }>; -} - interface McpInfo { agent: Agent; servers: McpServer[]; @@ -23,114 +24,217 @@ interface McpInfo { interface McpToolsButtonProps { mcpInfo: McpInfo; + allMcpServers: McpServer[]; + agent: Agent; + onUpdateAgent: (agent: Agent) => Promise; + onOpenSettings?: () => void; buttonClassName?: string; } export function McpToolsButton({ mcpInfo, + allMcpServers, + agent, + onUpdateAgent, + onOpenSettings, buttonClassName, }: McpToolsButtonProps) { const { t } = useTranslation(); + const [isOpen, setIsOpen] = useState(false); + const [isUpdating, setIsUpdating] = useState(null); + const totalTools = mcpInfo.servers.reduce( (total, server) => total + (server.tools?.length || 0), 0, ); - return ( -
- - - {/* MCP Tooltip */} -
- - {/* Arrow */} -
-
-
+ // Get connected server IDs from agent + const connectedServerIds = new Set( + agent.mcp_server_ids || agent.mcp_servers?.map((s) => s.id) || [], ); -} -/** - * MCP Tooltip content component - */ -function McpTooltipContent({ mcpInfo }: { mcpInfo: McpInfo }) { - const { t } = useTranslation(); + // Separate servers into connected and available + const connectedServers = allMcpServers.filter((server) => + connectedServerIds.has(server.id), + ); + const availableServers = allMcpServers.filter( + (server) => !connectedServerIds.has(server.id), + ); + + const handleMcpServerToggle = async (serverId: string, connect: boolean) => { + if (!agent || isUpdating) return; + + setIsUpdating(serverId); + try { + const currentIds = + agent.mcp_server_ids || agent.mcp_servers?.map((s) => s.id) || []; + const newIds = connect + ? [...currentIds, serverId] + : currentIds.filter((id) => id !== serverId); + + await onUpdateAgent({ + ...agent, + mcp_server_ids: newIds, + }); + } catch (error) { + console.error("Failed to update MCP server:", error); + } finally { + setIsUpdating(null); + } + }; return ( - <> -
-
- - - {t("app.toolbar.mcpTools")} - -
-
- {t("app.chat.assistantsTitle")}: {mcpInfo.agent.name} -
-
+ + + + + +
+ {/* Header */} +
+
+ + + {t("app.toolbar.mcpTools")} + +
+
+ {t("app.chat.assistantsTitle")}: {agent.name} +
+
-
- {mcpInfo.servers.map((server) => ( - - ))} -
- + {/* Connected Servers Section */} + {connectedServers.length > 0 && ( +
+

+ {t("app.toolbar.mcpConnected", "Connected")} +

+
+ {connectedServers.map((server) => ( + handleMcpServerToggle(server.id, false)} + /> + ))} +
+
+ )} + + {/* Available Servers Section */} + {availableServers.length > 0 && ( +
+

+ {t("app.toolbar.mcpAvailable", "Available")} +

+
+ {availableServers.map((server) => ( + handleMcpServerToggle(server.id, true)} + /> + ))} +
+
+ )} + + {/* Empty State */} + {allMcpServers.length === 0 && ( +
+
+ {t("app.toolbar.mcpNoServers", "No MCP servers configured")} +
+ +
+ )} +
+
+
); } /** - * Individual MCP server card + * Individual MCP server toggle item */ -function McpServerCard({ server }: { server: McpServer }) { +interface McpServerToggleItemProps { + server: McpServer; + isConnected: boolean; + isUpdating: boolean; + onToggle: () => void; +} + +function McpServerToggleItem({ + server, + isConnected, + isUpdating, + onToggle, +}: McpServerToggleItemProps) { + const { t } = useTranslation(); + const isOnline = server.status === "online"; + const isDisabled = !isOnline || isUpdating; + return ( -
-
-
-
- + ); } diff --git a/web/src/components/layouts/components/ChatToolbar/MobileMoreMenu.tsx b/web/src/components/layouts/components/ChatToolbar/MobileMoreMenu.tsx index 86793050..6b7b3afa 100644 --- a/web/src/components/layouts/components/ChatToolbar/MobileMoreMenu.tsx +++ b/web/src/components/layouts/components/ChatToolbar/MobileMoreMenu.tsx @@ -1,19 +1,26 @@ /** * Mobile More Menu * - * A popup menu shown on mobile with tool selector and MCP info. + * A popup menu shown on mobile with tool selector and MCP management. */ import McpIcon from "@/assets/McpIcon"; +import { cn } from "@/lib/utils"; import type { Agent } from "@/types/agents"; +import type { McpServer } from "@/types/mcp"; +import { + CheckIcon, + ChevronDownIcon, + Cog6ToothIcon, +} from "@heroicons/react/24/outline"; import { AnimatePresence, motion } from "motion/react"; +import { useState } from "react"; +import { useTranslation } from "react-i18next"; import { ToolSelector } from "./ToolSelector"; interface McpInfo { - servers: Array<{ - id: string; - tools?: Array<{ name: string }>; - }>; + agent: Agent; + servers: McpServer[]; } interface MobileMoreMenuProps { @@ -21,6 +28,8 @@ interface MobileMoreMenuProps { agent: Agent | null; onUpdateAgent: (agent: Agent) => Promise; mcpInfo: McpInfo | null; + allMcpServers?: McpServer[]; + onOpenSettings?: () => void; sessionKnowledgeSetId?: string | null; onUpdateSessionKnowledge?: (knowledgeSetId: string | null) => Promise; } @@ -30,14 +39,61 @@ export function MobileMoreMenu({ agent, onUpdateAgent, mcpInfo, + allMcpServers = [], + onOpenSettings, sessionKnowledgeSetId, onUpdateSessionKnowledge, }: MobileMoreMenuProps) { + const { t } = useTranslation(); + const [showMcpList, setShowMcpList] = useState(false); + const [isUpdating, setIsUpdating] = useState(null); + const handleUpdateAgent = async (updatedAgent: Agent) => { await onUpdateAgent(updatedAgent); // Don't close on toggle - let user configure multiple tools }; + // Get connected server IDs from agent + const connectedServerIds = new Set( + agent?.mcp_server_ids || agent?.mcp_servers?.map((s) => s.id) || [], + ); + + // Separate servers into connected and available + const connectedServers = allMcpServers.filter((server) => + connectedServerIds.has(server.id), + ); + const availableServers = allMcpServers.filter( + (server) => !connectedServerIds.has(server.id), + ); + + const totalTools = + mcpInfo?.servers.reduce( + (total, server) => total + (server.tools?.length || 0), + 0, + ) || 0; + + const handleMcpServerToggle = async (serverId: string, connect: boolean) => { + if (!agent || isUpdating) return; + + setIsUpdating(serverId); + try { + const currentIds = + agent.mcp_server_ids || agent.mcp_servers?.map((s) => s.id) || []; + const newIds = connect + ? [...currentIds, serverId] + : currentIds.filter((id) => id !== serverId); + + await onUpdateAgent({ + ...agent, + mcp_server_ids: newIds, + }); + } catch (error) { + console.error("Failed to update MCP server:", error); + } finally { + setIsUpdating(null); + } + }; + return ( {isOpen && ( @@ -64,23 +120,114 @@ export function MobileMoreMenu({
)} - {/* MCP Tool Info */} - {mcpInfo && ( -
-
-
- - MCP Tools -
- {mcpInfo.servers.length > 0 && ( - - {mcpInfo.servers.reduce( - (total, server) => total + (server.tools?.length || 0), - 0, + {/* MCP Tool Section - Expandable */} + {agent && ( +
+ + + {/* Expandable MCP Server List */} + + {showMcpList && ( + +
+ {/* Empty State */} + {allMcpServers.length === 0 && ( +
+
+ {t( + "app.toolbar.mcpNoServers", + "No MCP servers configured", + )} +
+ +
+ )} + + {/* Connected Servers */} + {connectedServers.length > 0 && ( +
+
+ {t("app.toolbar.mcpConnected", "Connected")} +
+
+ {connectedServers.map((server) => ( + + handleMcpServerToggle(server.id, false) + } + /> + ))} +
+
+ )} + + {/* Available Servers */} + {availableServers.length > 0 && ( +
+
+ {t("app.toolbar.mcpAvailable", "Available")} +
+
+ {availableServers.map((server) => ( + + handleMcpServerToggle(server.id, true) + } + /> + ))} +
+
+ )} +
+
)} -
+
)}
@@ -90,4 +237,64 @@ export function MobileMoreMenu({ ); } +/** + * Mobile MCP Server toggle item + */ +interface MobileMcpServerItemProps { + server: McpServer; + isConnected: boolean; + isUpdating: boolean; + onToggle: () => void; +} + +function MobileMcpServerItem({ + server, + isConnected, + isUpdating, + onToggle, +}: MobileMcpServerItemProps) { + const { t } = useTranslation(); + const isOnline = server.status === "online"; + const isDisabled = !isOnline || isUpdating; + + return ( + + ); +} + export default MobileMoreMenu; diff --git a/web/src/i18n/locales/en/app.json b/web/src/i18n/locales/en/app.json index 7435f459..beea8f01 100644 --- a/web/src/i18n/locales/en/app.json +++ b/web/src/i18n/locales/en/app.json @@ -33,6 +33,12 @@ "knowledgeConnect": "Connect Knowledge Base", "knowledgeDisconnect": "Disconnect", "mcpTools": "MCP Tools Connected", + "mcpConnected": "Connected", + "mcpAvailable": "Available", + "mcpToolsCount": "tools", + "mcpOffline": "offline", + "mcpNoServers": "No MCP servers configured", + "mcpOpenSettings": "Open Settings", "searchOff": "Off", "searchOffDesc": "Do not use search", "searchBuiltinDesc": "Use model's native search capability", diff --git a/web/src/i18n/locales/zh/app.json b/web/src/i18n/locales/zh/app.json index 2ec06249..9e689381 100644 --- a/web/src/i18n/locales/zh/app.json +++ b/web/src/i18n/locales/zh/app.json @@ -33,6 +33,12 @@ "knowledgeConnect": "连接知识库", "knowledgeDisconnect": "断开连接", "mcpTools": "MCP 工具已连接", + "mcpConnected": "已连接", + "mcpAvailable": "可用", + "mcpToolsCount": "个工具", + "mcpOffline": "离线", + "mcpNoServers": "未配置 MCP 服务器", + "mcpOpenSettings": "打开设置", "searchOff": "关闭", "searchOffDesc": "不使用搜索功能", "searchBuiltinDesc": "使用模型原生搜索能力", diff --git a/web/src/lib/Markdown.tsx b/web/src/lib/Markdown.tsx index 9f1c8f02..f3555bbe 100644 --- a/web/src/lib/Markdown.tsx +++ b/web/src/lib/Markdown.tsx @@ -443,9 +443,9 @@ const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms)); const isXyzenDownloadUrl = (src: string) => src.includes("/xyzen/api/v1/files/") && src.includes("/download"); -const MarkdownImage: React.FC> = ( - props, -) => { +const MarkdownImageComponent: React.FC< + React.ImgHTMLAttributes +> = (props) => { const { src, alt, ...rest } = props; const backendUrl = useXyzen((state) => state.backendUrl); const token = useXyzen((state) => state.token); @@ -581,6 +581,7 @@ const MarkdownImage: React.FC> = ( {alt} @@ -638,6 +639,14 @@ const MarkdownImage: React.FC> = ( ); }; +// Memoize MarkdownImage to prevent re-renders during streaming +// Only re-render when src or alt changes +const MarkdownImage = React.memo( + MarkdownImageComponent, + (prevProps, nextProps) => + prevProps.src === nextProps.src && prevProps.alt === nextProps.alt, +); + // Helper component to catch Escape key for image lightbox function ImageLightboxEscapeCatcher({ onEscape }: { onEscape: () => void }) { useEffect(() => { @@ -761,9 +770,7 @@ const Markdown: React.FC = function Markdown(props) { ); }, - img(props: React.ComponentPropsWithoutRef<"img">) { - return ; - }, + img: MarkdownImage, }), [isDark], ); diff --git a/web/src/service/fileService.ts b/web/src/service/fileService.ts index 8fcb0333..296d9264 100644 --- a/web/src/service/fileService.ts +++ b/web/src/service/fileService.ts @@ -388,7 +388,8 @@ class FileService { } /** - * Generate thumbnail URL for preview + * Generate thumbnail URL for preview using Canvas API + * Resizes image to max 160px dimension and outputs as JPEG for small file size */ generateThumbnail(file: File): Promise { return new Promise((resolve, reject) => { @@ -397,16 +398,53 @@ class FileService { return; } - const reader = new FileReader(); - reader.onload = (e) => { - if (e.target?.result) { - resolve(e.target.result as string); + const MAX_SIZE = 160; + const objectUrl = URL.createObjectURL(file); + const img = new Image(); + + img.onload = () => { + URL.revokeObjectURL(objectUrl); + + // Calculate scaled dimensions maintaining aspect ratio + let width = img.width; + let height = img.height; + + if (width > height) { + if (width > MAX_SIZE) { + height = Math.round((height * MAX_SIZE) / width); + width = MAX_SIZE; + } } else { - reject(new Error("Failed to read file")); + if (height > MAX_SIZE) { + width = Math.round((width * MAX_SIZE) / height); + height = MAX_SIZE; + } + } + + // Create canvas and draw scaled image + const canvas = document.createElement("canvas"); + canvas.width = width; + canvas.height = height; + + const ctx = canvas.getContext("2d"); + if (!ctx) { + reject(new Error("Failed to get canvas context")); + return; } + + ctx.drawImage(img, 0, 0, width, height); + + // Export as JPEG with 0.8 quality for small file size + const thumbnailUrl = canvas.toDataURL("image/jpeg", 0.8); + resolve(thumbnailUrl); }; - reader.onerror = () => reject(new Error("Failed to read file")); - reader.readAsDataURL(file); + + img.onerror = () => { + URL.revokeObjectURL(objectUrl); + reject(new Error("Failed to load image")); + }; + + img.src = objectUrl; }); } } diff --git a/web/src/store/slices/chatSlice.ts b/web/src/store/slices/chatSlice.ts index e59b7df4..058949c0 100644 --- a/web/src/store/slices/chatSlice.ts +++ b/web/src/store/slices/chatSlice.ts @@ -402,27 +402,9 @@ export const createChatSlice: StateCreator< * - If no session exists, creates one with a default topic */ activateChannelForAgent: async (agentId: string) => { - const { channels, chatHistory, backendUrl } = get(); + const { backendUrl } = get(); - // First, check if we already have a channel for this agent - const existingChannel = Object.values(channels).find( - (ch) => ch.agentId === agentId, - ); - - if (existingChannel) { - // Already have a channel, activate it - await get().activateChannel(existingChannel.id); - return; - } - - // Check chat history for existing topics with this agent - const existingHistory = chatHistory.find((h) => h.sessionId === agentId); - if (existingHistory) { - await get().activateChannel(existingHistory.id); - return; - } - - // No existing channel, try to find or create a session for this agent + // Always fetch from backend to get the most recent topic const token = authService.getToken(); if (!token) { console.error("No authentication token available"); @@ -446,8 +428,8 @@ export const createChatSlice: StateCreator< // Get the most recent topic for this session, or create one if (session.topics && session.topics.length > 0) { - // Activate the most recent topic - const latestTopic = session.topics[session.topics.length - 1]; + // Activate the most recent topic (backend returns topics ordered by updated_at descending) + const latestTopic = session.topics[0]; // Create channel if doesn't exist const channel: ChatChannel = { From 1c1fb7fbfeb7cffffab6e26dbc90210dc77e0678 Mon Sep 17 00:00:00 2001 From: "xinquiry(SII)" <100398322+xinquiry@users.noreply.github.com> Date: Fri, 23 Jan 2026 00:31:56 +0800 Subject: [PATCH 7/8] feat: multiple UI improvements and fixes (#198) * fix: jump to latest topic when click agent * feat: allow more than one image for generate image * feat: allow user directly edit mcp in the chat-toolbar * feat: improve the frontend perf * fix: restore previous active topic when clicking agent Instead of always jumping to the latest topic, now tracks and restores the previously active topic for each agent when switching between them. Co-Authored-By: Claude * feat: add context menu to FocusedView agents and download button to lightbox - Add right-click context menu (edit/delete) to compact AgentListItem variant - Render context menu via portal to escape overflow:hidden containers - Add edit/delete handlers to FocusedView with AgentSettingsModal and ConfirmationModal - Add download button to image lightbox with smart filename detection Co-Authored-By: Claude * feat: add web_fetch tool bundled with web_search - Add web_fetch tool using Trafilatura for content extraction - Bundle web_fetch with web_search in frontend toolConfig - Group WEB_SEARCH_TOOLS for unified toggle behavior - Only load web_fetch when web_search is available (SearXNG enabled) - Update tool capabilities mapping for web_fetch Co-Authored-By: Claude --------- Co-authored-by: Claude --- service/app/tools/__init__.py | 2 +- service/app/tools/builtin/__init__.py | 3 + service/app/tools/builtin/fetch.py | 165 +++++++ service/app/tools/capabilities.py | 1 + service/app/tools/prepare.py | 8 +- service/app/tools/registry.py | 14 + service/pyproject.toml | 1 + service/uv.lock | 121 +++++ web/src/app/chat/SpatialWorkspace.tsx | 86 ++++ web/src/app/chat/spatial/FocusedView.tsx | 122 +++-- web/src/components/agents/AgentList.tsx | 115 +++++ web/src/components/agents/AgentListItem.tsx | 474 ++++++++++++++++++++ web/src/components/agents/index.ts | 2 + web/src/components/layouts/XyzenAgent.tsx | 364 +-------------- web/src/core/agent/toolConfig.ts | 70 ++- web/src/lib/Markdown.tsx | 78 +++- web/src/store/slices/chatSlice.ts | 25 +- 17 files changed, 1238 insertions(+), 413 deletions(-) create mode 100644 service/app/tools/builtin/fetch.py create mode 100644 web/src/components/agents/AgentList.tsx create mode 100644 web/src/components/agents/AgentListItem.tsx create mode 100644 web/src/components/agents/index.ts diff --git a/service/app/tools/__init__.py b/service/app/tools/__init__.py index b5e85ebe..108793ec 100644 --- a/service/app/tools/__init__.py +++ b/service/app/tools/__init__.py @@ -17,7 +17,7 @@ Tool Categories: | Category | Tools | UI Toggle | Auto-enabled | |------------|---------------------------|-----------|--------------| -| search | web_search | Yes | - | +| search | web_search, web_fetch | Yes | - | | knowledge | knowledge_* | No | Yes (with knowledge_set) | | image | generate_image, read_image| Yes | - | | research | think, ConductResearch | No | Component-internal | diff --git a/service/app/tools/builtin/__init__.py b/service/app/tools/builtin/__init__.py index e83066c1..0b2c48e0 100644 --- a/service/app/tools/builtin/__init__.py +++ b/service/app/tools/builtin/__init__.py @@ -13,6 +13,7 @@ - research: Deep research workflow tools (component-internal, not exported here) """ +from app.tools.builtin.fetch import create_web_fetch_tool from app.tools.builtin.image import create_image_tools, create_image_tools_for_agent from app.tools.builtin.knowledge import create_knowledge_tools, create_knowledge_tools_for_agent from app.tools.builtin.memory import create_memory_tools, create_memory_tools_for_agent @@ -21,6 +22,8 @@ __all__ = [ # Search "create_web_search_tool", + # Fetch + "create_web_fetch_tool", # Knowledge "create_knowledge_tools", "create_knowledge_tools_for_agent", diff --git a/service/app/tools/builtin/fetch.py b/service/app/tools/builtin/fetch.py new file mode 100644 index 00000000..ff5c69c2 --- /dev/null +++ b/service/app/tools/builtin/fetch.py @@ -0,0 +1,165 @@ +""" +Web Fetch Tool + +LangChain tool for fetching and extracting content from web pages using Trafilatura. +Extracts clean text/markdown content from HTML pages with metadata extraction. +""" + +from __future__ import annotations + +import logging +from typing import Any, Literal + +import trafilatura +from langchain_core.tools import BaseTool, StructuredTool +from pydantic import BaseModel, Field +from trafilatura.settings import use_config + +logger = logging.getLogger(__name__) + + +class WebFetchInput(BaseModel): + """Input schema for web fetch tool.""" + + url: str = Field(description="The URL of the web page to fetch and extract content from.") + output_format: Literal["markdown", "text"] = Field( + default="markdown", + description="Output format: 'markdown' for structured content, 'text' for plain text.", + ) + include_links: bool = Field( + default=True, + description="Whether to include hyperlinks in the extracted content.", + ) + include_images: bool = Field( + default=False, + description="Whether to include image references in the output.", + ) + timeout: int = Field( + default=30, + ge=5, + le=120, + description="Request timeout in seconds.", + ) + + +async def _web_fetch( + url: str, + output_format: Literal["markdown", "text"] = "markdown", + include_links: bool = True, + include_images: bool = False, + timeout: int = 30, +) -> dict[str, Any]: + """ + Fetch and extract content from a web page. + + Uses Trafilatura for robust HTML content extraction and conversion + to clean markdown or plain text. + + Returns: + A dictionary containing: + - success: Boolean indicating success + - url: The original URL + - title: Page title if available + - author: Author if available + - date: Publication date if available + - content: Extracted markdown/text content + - error: Error message if failed + """ + if not url.strip(): + return { + "success": False, + "error": "URL cannot be empty", + "url": url, + "title": None, + "author": None, + "date": None, + "content": None, + } + + # Configure trafilatura + config = use_config() + config.set("DEFAULT", "EXTRACTION_TIMEOUT", str(timeout)) + + try: + # Fetch the page + downloaded = trafilatura.fetch_url(url) + if downloaded is None: + return { + "success": False, + "error": "Failed to fetch URL - the page may be unavailable or blocked", + "url": url, + "title": None, + "author": None, + "date": None, + "content": None, + } + + # Extract content + content = trafilatura.extract( + downloaded, + output_format="markdown" if output_format == "markdown" else "txt", + include_links=include_links, + include_images=include_images, + include_comments=False, + ) + + if content is None: + return { + "success": False, + "error": "Failed to extract content from page - the page may have no readable content", + "url": url, + "title": None, + "author": None, + "date": None, + "content": None, + } + + # Extract metadata + metadata = trafilatura.extract_metadata(downloaded) + + logger.info(f"Web fetch completed: '{url}' extracted {len(content)} characters") + + return { + "success": True, + "url": url, + "title": metadata.title if metadata else None, + "author": metadata.author if metadata else None, + "date": metadata.date if metadata else None, + "content": content, + } + + except Exception as e: + error_msg = f"Fetch failed: {e!s}" + logger.error(f"Web fetch error for '{url}': {error_msg}") + return { + "success": False, + "error": error_msg, + "url": url, + "title": None, + "author": None, + "date": None, + "content": None, + } + + +def create_web_fetch_tool() -> BaseTool: + """ + Create the web fetch tool. + + Returns: + StructuredTool for web page content extraction. + """ + return StructuredTool( + name="web_fetch", + description=( + "Fetch and extract content from a web page. " + "Converts HTML to clean markdown or plain text, removing ads, navigation, and boilerplate. " + "Also extracts metadata like title, author, and publication date when available. " + "Use this when you need to read the full content of a specific web page." + ), + args_schema=WebFetchInput, + coroutine=_web_fetch, + ) + + +__all__ = ["create_web_fetch_tool", "WebFetchInput"] diff --git a/service/app/tools/capabilities.py b/service/app/tools/capabilities.py index fc88b836..1b4deecf 100644 --- a/service/app/tools/capabilities.py +++ b/service/app/tools/capabilities.py @@ -53,6 +53,7 @@ class ToolCapability(StrEnum): "google_search": [ToolCapability.WEB_SEARCH], "bing_search": [ToolCapability.WEB_SEARCH], "tavily_search": [ToolCapability.WEB_SEARCH], + "web_fetch": [ToolCapability.WEB_SEARCH], # Knowledge tools "knowledge_list": [ToolCapability.KNOWLEDGE_RETRIEVAL], "knowledge_read": [ToolCapability.KNOWLEDGE_RETRIEVAL, ToolCapability.FILE_OPERATIONS], diff --git a/service/app/tools/prepare.py b/service/app/tools/prepare.py index bb1106ba..81b4a6df 100644 --- a/service/app/tools/prepare.py +++ b/service/app/tools/prepare.py @@ -81,7 +81,7 @@ def _load_all_builtin_tools( """ Load all available builtin tools. - - Web search: loaded if SearXNG is enabled + - Web search + fetch: loaded if SearXNG is enabled - Knowledge tools: loaded if effective knowledge_set_id exists and user_id is available - Image tools: loaded if image generation is enabled and user_id is available - Memory tools: loaded if agent and user_id are available (currently disabled) @@ -101,10 +101,14 @@ def _load_all_builtin_tools( tools: list[BaseTool] = [] - # Load web_search if available in registry (registered at startup if SearXNG enabled) + # Load web search tools if available in registry (registered at startup if SearXNG enabled) web_search = BuiltinToolRegistry.get("web_search") if web_search: tools.append(web_search) + # Load web fetch tool (bundled with web_search) + web_fetch = BuiltinToolRegistry.get("web_fetch") + if web_fetch: + tools.append(web_fetch) # Determine effective knowledge_set_id # Priority: session override > agent config diff --git a/service/app/tools/registry.py b/service/app/tools/registry.py index 607bae1f..66d27420 100644 --- a/service/app/tools/registry.py +++ b/service/app/tools/registry.py @@ -173,6 +173,7 @@ def register_builtin_tools() -> None: Called at app startup to populate the registry. """ + from app.tools.builtin.fetch import create_web_fetch_tool from app.tools.builtin.knowledge import create_knowledge_tools from app.tools.builtin.search import create_web_search_tool @@ -190,6 +191,19 @@ def register_builtin_tools() -> None: cost=ToolCostConfig(base_cost=1), ) + # Register web fetch tool (bundled with web_search, not separate toggle) + fetch_tool = create_web_fetch_tool() + BuiltinToolRegistry.register( + tool_id="web_fetch", + tool=fetch_tool, + category="search", + display_name="Web Fetch", + ui_toggleable=False, # Bundled with web_search + default_enabled=True, + requires_context=[], + cost=ToolCostConfig(base_cost=1), + ) + # Tool cost configs for knowledge tools knowledge_tool_costs = { "knowledge_list": ToolCostConfig(), # Free diff --git a/service/pyproject.toml b/service/pyproject.toml index a36f8b25..61d80424 100644 --- a/service/pyproject.toml +++ b/service/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "pytesseract>=0.3.13", "pillow>=12.0.0", "celery-types>=0.24.0", + "trafilatura>=1.12.0", ] [dependency-groups] diff --git a/service/uv.lock b/service/uv.lock index 3b04ea30..3ae07986 100644 --- a/service/uv.lock +++ b/service/uv.lock @@ -234,6 +234,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/51/321e821856452f7386c4e9df866f196720b1ad0c5ea1623ea7399969ae3b/authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd", size = 244005, upload-time = "2025-12-12T08:01:40.209Z" }, ] +[[package]] +name = "babel" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852, upload-time = "2025-02-01T15:17:41.026Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, +] + [[package]] name = "beautifulsoup4" version = "4.14.3" @@ -485,6 +494,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "courlan" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "babel" }, + { name = "tld" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6f/54/6d6ceeff4bed42e7a10d6064d35ee43a810e7b3e8beb4abeae8cff4713ae/courlan-1.3.2.tar.gz", hash = "sha256:0b66f4db3a9c39a6e22dd247c72cfaa57d68ea660e94bb2c84ec7db8712af190", size = 206382, upload-time = "2024-10-29T16:40:20.994Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/ca/6a667ccbe649856dcd3458bab80b016681b274399d6211187c6ab969fc50/courlan-1.3.2-py3-none-any.whl", hash = "sha256:d0dab52cf5b5b1000ee2839fbc2837e93b2514d3cb5bb61ae158a55b7a04c6be", size = 33848, upload-time = "2024-10-29T16:40:18.325Z" }, +] + [[package]] name = "coverage" version = "7.13.0" @@ -583,6 +606,21 @@ sqlite = [ { name = "aiosqlite" }, ] +[[package]] +name = "dateparser" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "regex" }, + { name = "tzlocal" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/30/064144f0df1749e7bb5faaa7f52b007d7c2d08ec08fed8411aba87207f68/dateparser-1.2.2.tar.gz", hash = "sha256:986316f17cb8cdc23ea8ce563027c5ef12fc725b6fb1d137c14ca08777c5ecf7", size = 329840, upload-time = "2025-06-26T09:29:23.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/22/f020c047ae1346613db9322638186468238bcfa8849b4668a22b97faad65/dateparser-1.2.2-py3-none-any.whl", hash = "sha256:5a5d7211a09013499867547023a2a0c91d5a27d15dd4dbcea676ea9fe66f2482", size = 315453, upload-time = "2025-06-26T09:29:21.412Z" }, +] + [[package]] name = "distlib" version = "0.4.0" @@ -1121,6 +1159,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "htmldate" +version = "1.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "charset-normalizer" }, + { name = "dateparser" }, + { name = "lxml" }, + { name = "python-dateutil" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/10/ead9dabc999f353c3aa5d0dc0835b1e355215a5ecb489a7f4ef2ddad5e33/htmldate-1.9.4.tar.gz", hash = "sha256:1129063e02dd0354b74264de71e950c0c3fcee191178321418ccad2074cc8ed0", size = 44690, upload-time = "2025-11-04T17:46:44.983Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/bd/adfcdaaad5805c0c5156aeefd64c1e868c05e9c1cd6fd21751f168cd88c7/htmldate-1.9.4-py3-none-any.whl", hash = "sha256:1b94bcc4e08232a5b692159903acf95548b6a7492dddca5bb123d89d6325921c", size = 31558, upload-time = "2025-11-04T17:46:43.258Z" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -1327,6 +1381,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, ] +[[package]] +name = "justext" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lxml", extra = ["html-clean"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/f3/45890c1b314f0d04e19c1c83d534e611513150939a7cf039664d9ab1e649/justext-3.0.2.tar.gz", hash = "sha256:13496a450c44c4cd5b5a75a5efcd9996066d2a189794ea99a49949685a0beb05", size = 828521, upload-time = "2025-02-25T20:21:49.934Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/ac/52f4e86d1924a7fc05af3aeb34488570eccc39b4af90530dd6acecdf16b5/justext-3.0.2-py2.py3-none-any.whl", hash = "sha256:62b1c562b15c3c6265e121cc070874243a443bfd53060e869393f09d6b6cc9a7", size = 837940, upload-time = "2025-02-25T20:21:44.179Z" }, +] + [[package]] name = "kombu" version = "5.6.1" @@ -1652,6 +1718,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ea/7b/93c73c67db235931527301ed3785f849c78991e2e34f3fd9a6663ffda4c5/lxml-6.0.2-cp312-cp312-win_arm64.whl", hash = "sha256:61cb10eeb95570153e0c0e554f58df92ecf5109f75eacad4a95baa709e26c3d6", size = 3672836, upload-time = "2025-09-22T04:01:52.145Z" }, ] +[package.optional-dependencies] +html-clean = [ + { name = "lxml-html-clean" }, +] + +[[package]] +name = "lxml-html-clean" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lxml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/cb/c9c5bb2a9c47292e236a808dd233a03531f53b626f36259dcd32b49c76da/lxml_html_clean-0.4.3.tar.gz", hash = "sha256:c9df91925b00f836c807beab127aac82575110eacff54d0a75187914f1bd9d8c", size = 21498, upload-time = "2025-10-02T20:49:24.895Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/4a/63a9540e3ca73709f4200564a737d63a4c8c9c4dd032bab8535f507c190a/lxml_html_clean-0.4.3-py3-none-any.whl", hash = "sha256:63fd7b0b9c3a2e4176611c2ca5d61c4c07ffca2de76c14059a81a2825833731e", size = 14177, upload-time = "2025-10-02T20:49:23.749Z" }, +] + [[package]] name = "mako" version = "1.3.10" @@ -2478,6 +2561,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/4f/00be2196329ebbff56ce564aa94efb0fbc828d00de250b1980de1a34ab49/python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba", size = 472788, upload-time = "2024-08-07T17:33:28.192Z" }, ] +[[package]] +name = "pytz" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, +] + [[package]] name = "pywin32" version = "311" @@ -2811,6 +2903,7 @@ dependencies = [ { name = "reportlab" }, { name = "requests" }, { name = "sqlmodel" }, + { name = "trafilatura" }, { name = "websockets" }, ] @@ -2877,6 +2970,7 @@ requires-dist = [ { name = "reportlab", specifier = ">=4.4.7" }, { name = "requests", specifier = ">=2.32.4" }, { name = "sqlmodel", specifier = ">=0.0.24" }, + { name = "trafilatura", specifier = ">=1.12.0" }, { name = "websockets", specifier = ">=13.0,<14.0" }, ] @@ -3041,6 +3135,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/10/b8523105c590c5b8349f2587e2fdfe51a69544bd5a76295fc20f2374f470/tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd", size = 878694, upload-time = "2025-10-06T20:21:59.876Z" }, ] +[[package]] +name = "tld" +version = "0.13.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/df/a1/5723b07a70c1841a80afc9ac572fdf53488306848d844cd70519391b0d26/tld-0.13.1.tar.gz", hash = "sha256:75ec00936cbcf564f67361c41713363440b6c4ef0f0c1592b5b0fbe72c17a350", size = 462000, upload-time = "2025-05-21T22:18:29.341Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/70/b2f38360c3fc4bc9b5e8ef429e1fde63749144ac583c2dbdf7e21e27a9ad/tld-0.13.1-py2.py3-none-any.whl", hash = "sha256:a2d35109433ac83486ddf87e3c4539ab2c5c2478230e5d9c060a18af4b03aa7c", size = 274718, upload-time = "2025-05-21T22:18:25.811Z" }, +] + [[package]] name = "tqdm" version = "4.67.1" @@ -3053,6 +3156,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, ] +[[package]] +name = "trafilatura" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "courlan" }, + { name = "htmldate" }, + { name = "justext" }, + { name = "lxml" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/25/e3ebeefdebfdfae8c4a4396f5a6ea51fc6fa0831d63ce338e5090a8003dc/trafilatura-2.0.0.tar.gz", hash = "sha256:ceb7094a6ecc97e72fea73c7dba36714c5c5b577b6470e4520dca893706d6247", size = 253404, upload-time = "2024-12-03T15:23:24.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/b6/097367f180b6383a3581ca1b86fcae284e52075fa941d1232df35293363c/trafilatura-2.0.0-py3-none-any.whl", hash = "sha256:77eb5d1e993747f6f20938e1de2d840020719735690c840b9a1024803a4cd51d", size = 132557, upload-time = "2024-12-03T15:23:21.41Z" }, +] + [[package]] name = "typer" version = "0.20.0" diff --git a/web/src/app/chat/SpatialWorkspace.tsx b/web/src/app/chat/SpatialWorkspace.tsx index 7f3ae57a..8d9bd1d5 100644 --- a/web/src/app/chat/SpatialWorkspace.tsx +++ b/web/src/app/chat/SpatialWorkspace.tsx @@ -11,9 +11,13 @@ import "@xyflow/react/dist/style.css"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import AddAgentModal from "@/components/modals/AddAgentModal"; +import AgentSettingsModal from "@/components/modals/AgentSettingsModal"; +import ConfirmationModal from "@/components/modals/ConfirmationModal"; import { useMyMarketplaceListings } from "@/hooks/useMarketplace"; import { useXyzen } from "@/store"; +import type { Agent } from "@/types/agents"; import { AnimatePresence } from "framer-motion"; +import { useTranslation } from "react-i18next"; import { AddAgentButton, @@ -34,6 +38,7 @@ import { } from "./spatial"; function InnerWorkspace() { + const { t } = useTranslation(); const { agents, updateAgentLayout, @@ -101,6 +106,10 @@ function InnerWorkspace() { }); const [saveStatus, setSaveStatus] = useState("idle"); const [isAddModalOpen, setAddModalOpen] = useState(false); + const [editingAgent, setEditingAgent] = useState(null); + const [isEditModalOpen, setEditModalOpen] = useState(false); + const [agentToDelete, setAgentToDelete] = useState(null); + const [isConfirmModalOpen, setConfirmModalOpen] = useState(false); const [prevViewport, setPrevViewport] = useState(null); const [newlyCreatedAgentId, setNewlyCreatedAgentId] = useState( null, @@ -516,6 +525,29 @@ function InnerWorkspace() { }, 1000); }, [setViewport, getViewport, fitView]); + // Agent edit/delete handlers for FocusedView (with confirmation modal) + const handleEditAgentFromFocus = useCallback( + (agentId: string) => { + const agent = agents.find((a) => a.id === agentId); + if (agent) { + setEditingAgent(agent); + setEditModalOpen(true); + } + }, + [agents], + ); + + const handleDeleteAgentFromFocus = useCallback( + (agentId: string) => { + const agent = agents.find((a) => a.id === agentId); + if (agent) { + setAgentToDelete(agent); + setConfirmModalOpen(true); + } + }, + [agents], + ); + // Viewport change handler const handleViewportChange = useCallback((_: unknown, viewport: Viewport) => { if (focusedAgentIdRef.current) return; @@ -622,6 +654,8 @@ function InnerWorkspace() { onClose={handleCloseFocus} onSwitchAgent={(id) => handleFocus(id)} onCanvasClick={handleCloseFocus} + onEditAgent={handleEditAgentFromFocus} + onDeleteAgent={handleDeleteAgentFromFocus} /> )} @@ -630,6 +664,58 @@ function InnerWorkspace() { isOpen={isAddModalOpen} onClose={() => setAddModalOpen(false)} /> + + {/* Edit Agent Modal */} + {editingAgent && ( + { + setEditModalOpen(false); + setEditingAgent(null); + }} + sessionId="" + agentId={editingAgent.id} + agentName={editingAgent.name} + agent={editingAgent} + currentAvatar={editingAgent.avatar ?? undefined} + onAvatarChange={(avatarUrl) => { + setEditingAgent({ ...editingAgent, avatar: avatarUrl }); + updateAgentAvatar(editingAgent.id, avatarUrl); + }} + onGridSizeChange={() => {}} + onDelete={ + publishedAgentIds.has(editingAgent.id) + ? undefined + : () => { + deleteAgent(editingAgent.id); + setEditModalOpen(false); + setEditingAgent(null); + } + } + /> + )} + + {/* Delete Confirmation Modal */} + {agentToDelete && ( + { + setConfirmModalOpen(false); + setAgentToDelete(null); + }} + onConfirm={() => { + if (publishedAgentIds.has(agentToDelete.id)) return; + deleteAgent(agentToDelete.id); + setConfirmModalOpen(false); + setAgentToDelete(null); + }} + title={t("agents.deleteAgent")} + message={t("agents.deleteConfirmation", { name: agentToDelete.name })} + confirmLabel={t("common.delete")} + cancelLabel={t("common.cancel")} + /> + )}
); } diff --git a/web/src/app/chat/spatial/FocusedView.tsx b/web/src/app/chat/spatial/FocusedView.tsx index 4885b120..23cd23e2 100644 --- a/web/src/app/chat/spatial/FocusedView.tsx +++ b/web/src/app/chat/spatial/FocusedView.tsx @@ -1,7 +1,9 @@ +import { AgentList } from "@/components/agents"; import XyzenChat from "@/components/layouts/XyzenChat"; import { useXyzen } from "@/store"; +import type { Agent } from "@/types/agents"; import { motion } from "framer-motion"; -import { useEffect, useRef } from "react"; +import { useCallback, useEffect, useMemo, useRef } from "react"; import { AgentData } from "./types"; interface FocusedViewProps { @@ -10,6 +12,9 @@ interface FocusedViewProps { onClose: () => void; onSwitchAgent: (id: string) => void; onCanvasClick?: () => void; // Callback specifically for canvas clicks + // Agent edit/delete handlers + onEditAgent?: (agentId: string) => void; + onDeleteAgent?: (agentId: string) => void; } export function FocusedView({ @@ -18,12 +23,82 @@ export function FocusedView({ onClose, onSwitchAgent, onCanvasClick, + onEditAgent, + onDeleteAgent, }: FocusedViewProps) { const switcherRef = useRef(null); const chatRef = useRef(null); const { activateChannelForAgent } = useXyzen(); + // Convert AgentData to Agent type for AgentList component + const agentsForList: Agent[] = useMemo( + () => + agents.map((a) => ({ + id: a.id, // Use node ID for switching + name: a.name, + description: a.desc, + avatar: a.avatar, + user_id: "", + created_at: "", + updated_at: "", + })), + [agents], + ); + + // Create a map for quick lookup of original AgentData + const agentDataMap = useMemo( + () => new Map(agents.map((a) => [a.id, a])), + [agents], + ); + + // Get selected agent's node ID + const selectedAgentId = useMemo( + () => agents.find((a) => a.name === agent.name)?.id, + [agents, agent.name], + ); + + // Callbacks to get status and role from original AgentData + const getAgentStatus = useCallback( + (a: Agent) => { + const status = agentDataMap.get(a.id)?.status; + // Map "offline" to "idle" since compact variant only supports "idle" | "busy" + return status === "busy" ? "busy" : "idle"; + }, + [agentDataMap], + ); + + const getAgentRole = useCallback( + (a: Agent) => agentDataMap.get(a.id)?.role, + [agentDataMap], + ); + + const handleAgentClick = useCallback( + (a: Agent) => onSwitchAgent(a.id), + [onSwitchAgent], + ); + + // Map node id back to real agentId for edit/delete + const handleEditClick = useCallback( + (a: Agent) => { + const agentData = agentDataMap.get(a.id); + if (agentData?.agentId && onEditAgent) { + onEditAgent(agentData.agentId); + } + }, + [agentDataMap, onEditAgent], + ); + + const handleDeleteClick = useCallback( + (a: Agent) => { + const agentData = agentDataMap.get(a.id); + if (agentData?.agentId && onDeleteAgent) { + onDeleteAgent(agentData.agentId); + } + }, + [agentDataMap, onDeleteAgent], + ); + // Activate the channel for the selected agent useEffect(() => { if (agent.agentId) { @@ -126,40 +201,17 @@ export function FocusedView({ Active Agents
-
- {agents.map((a) => ( - - ))} +
+
diff --git a/web/src/components/agents/AgentList.tsx b/web/src/components/agents/AgentList.tsx new file mode 100644 index 00000000..3fff480c --- /dev/null +++ b/web/src/components/agents/AgentList.tsx @@ -0,0 +1,115 @@ +"use client"; + +import type { Agent } from "@/types/agents"; +import { motion, type Variants } from "framer-motion"; +import React from "react"; +import { AgentListItem } from "./AgentListItem"; + +// Container animation variants for detailed variant +const containerVariants: Variants = { + hidden: { opacity: 0 }, + visible: { + opacity: 1, + transition: { + staggerChildren: 0.08, + delayChildren: 0.1, + }, + }, +}; + +// Base props for both variants +interface AgentListBaseProps { + agents: Agent[]; + onAgentClick?: (agent: Agent) => void; +} + +// Props for detailed variant +interface DetailedAgentListProps extends AgentListBaseProps { + variant: "detailed"; + publishedAgentIds?: Set; + lastConversationTimeByAgent?: Record; + onEdit?: (agent: Agent) => void; + onDelete?: (agent: Agent) => void; + // Compact variant props not used + selectedAgentId?: never; + getAgentStatus?: never; + getAgentRole?: never; +} + +// Props for compact variant +interface CompactAgentListProps extends AgentListBaseProps { + variant: "compact"; + selectedAgentId?: string; + getAgentStatus?: (agent: Agent) => "idle" | "busy"; + getAgentRole?: (agent: Agent) => string | undefined; + // Right-click menu support (shared with detailed) + publishedAgentIds?: Set; + onEdit?: (agent: Agent) => void; + onDelete?: (agent: Agent) => void; + // Detailed variant props not used + lastConversationTimeByAgent?: never; +} + +export type AgentListProps = DetailedAgentListProps | CompactAgentListProps; + +export const AgentList: React.FC = (props) => { + const { agents, variant, onAgentClick } = props; + + if (variant === "detailed") { + const { publishedAgentIds, lastConversationTimeByAgent, onEdit, onDelete } = + props as DetailedAgentListProps; + + return ( + + {agents.map((agent) => ( + + ))} + + ); + } + + // Compact variant + const { + selectedAgentId, + getAgentStatus, + getAgentRole, + publishedAgentIds, + onEdit, + onDelete, + } = props as CompactAgentListProps; + + return ( +
+ {agents.map((agent) => ( + + ))} +
+ ); +}; + +export default AgentList; diff --git a/web/src/components/agents/AgentListItem.tsx b/web/src/components/agents/AgentListItem.tsx new file mode 100644 index 00000000..e7a78e0b --- /dev/null +++ b/web/src/components/agents/AgentListItem.tsx @@ -0,0 +1,474 @@ +"use client"; + +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/animate-ui/components/animate/tooltip"; +import { Badge } from "@/components/base/Badge"; +import { formatTime } from "@/lib/formatDate"; +import type { Agent } from "@/types/agents"; +import { + PencilIcon, + ShoppingBagIcon, + TrashIcon, +} from "@heroicons/react/24/outline"; +import { motion, type Variants } from "framer-motion"; +import React, { useEffect, useRef, useState } from "react"; +import { createPortal } from "react-dom"; +import { useTranslation } from "react-i18next"; + +// Animation variants for detailed variant +const itemVariants: Variants = { + hidden: { y: 20, opacity: 0 }, + visible: { + y: 0, + opacity: 1, + transition: { + type: "spring", + stiffness: 100, + damping: 12, + }, + }, +}; + +// Context menu component +interface ContextMenuProps { + x: number; + y: number; + onEdit: () => void; + onDelete: () => void; + onClose: () => void; + isDefaultAgent?: boolean; + isMarketplacePublished?: boolean; +} + +const ContextMenu: React.FC = ({ + x, + y, + onEdit, + onDelete, + onClose, + isDefaultAgent = false, + isMarketplacePublished = false, +}) => { + const { t } = useTranslation(); + const menuRef = useRef(null); + + useEffect(() => { + const handleClickOutside = (event: MouseEvent) => { + if (menuRef.current && !menuRef.current.contains(event.target as Node)) { + onClose(); + } + }; + + const handleEscape = (event: KeyboardEvent) => { + if (event.key === "Escape") { + onClose(); + } + }; + + document.addEventListener("mousedown", handleClickOutside); + document.addEventListener("keydown", handleEscape); + + return () => { + document.removeEventListener("mousedown", handleClickOutside); + document.removeEventListener("keydown", handleEscape); + }; + }, [onClose]); + + return ( + + + {isMarketplacePublished ? ( + + + + + + + + {t("agents.deleteBlockedMessage", { + defaultValue: + "This agent is published to Agent Market. Please unpublish it first, then delete it.", + })} + + + ) : ( + + )} + + ); +}; + +// Shared props for both variants +interface AgentListItemBaseProps { + agent: Agent; + onClick?: (agent: Agent) => void; +} + +// Props specific to detailed variant +interface DetailedVariantProps extends AgentListItemBaseProps { + variant: "detailed"; + isMarketplacePublished?: boolean; + lastConversationTime?: string; + onEdit?: (agent: Agent) => void; + onDelete?: (agent: Agent) => void; + // Compact variant props not used + isSelected?: never; + status?: never; + role?: never; +} + +// Props specific to compact variant +interface CompactVariantProps extends AgentListItemBaseProps { + variant: "compact"; + isSelected?: boolean; + status?: "idle" | "busy"; + role?: string; + // Right-click menu support (shared with detailed) + isMarketplacePublished?: boolean; + onEdit?: (agent: Agent) => void; + onDelete?: (agent: Agent) => void; + // Detailed variant props not used + lastConversationTime?: never; +} + +export type AgentListItemProps = DetailedVariantProps | CompactVariantProps; + +// Detailed variant component (for sidebar) +const DetailedAgentListItem: React.FC = ({ + agent, + isMarketplacePublished = false, + lastConversationTime, + onClick, + onEdit, + onDelete, +}) => { + const { t } = useTranslation(); + const [contextMenu, setContextMenu] = useState<{ + x: number; + y: number; + } | null>(null); + + const longPressTimer = useRef | null>(null); + const isLongPress = useRef(false); + + const handleTouchStart = (e: React.TouchEvent) => { + isLongPress.current = false; + const touch = e.touches[0]; + const { clientX, clientY } = touch; + + longPressTimer.current = setTimeout(() => { + setContextMenu({ x: clientX, y: clientY }); + // Haptic feedback (best-effort) + try { + if ("vibrate" in navigator) { + navigator.vibrate(10); + } + } catch { + // ignore + } + }, 500); + }; + + const handleTouchEnd = () => { + if (longPressTimer.current) { + clearTimeout(longPressTimer.current); + } + }; + + const handleTouchMove = () => { + if (longPressTimer.current) { + clearTimeout(longPressTimer.current); + longPressTimer.current = null; + } + }; + + // Check if it's a default agent based on tags + const isDefaultAgent = agent.tags?.some((tag) => tag.startsWith("default_")); + + const handleContextMenu = (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + + setContextMenu({ + x: e.clientX, + y: e.clientY, + }); + }; + + return ( + <> + { + if (isLongPress.current) return; + onClick?.(agent); + }} + onContextMenu={handleContextMenu} + onTouchStart={handleTouchStart} + onTouchEnd={handleTouchEnd} + onTouchMove={handleTouchMove} + className={` + group relative flex cursor-pointer items-start gap-4 rounded-sm border p-3 + border-neutral-200 bg-white hover:bg-neutral-50 dark:border-neutral-800 dark:bg-neutral-900 dark:hover:bg-neutral-800/60 + ${agent.id === "default-chat" ? "select-none" : ""} + `} + > + {/* Avatar */} +
+ {agent.name} +
+ + {/* Content */} +
+
+

+ {agent.name} +

+ + {/* Marketplace published badge */} + {isMarketplacePublished && ( + + + + + + + + + + + + {t("agents.badges.marketplace", { + defaultValue: "Published to Marketplace", + })} + + + )} +
+ +

+ {agent.description} +

+ + {/* Last conversation time */} + {lastConversationTime && ( +

+ {formatTime(lastConversationTime)} +

+ )} +
+
+ + {/* Context menu - rendered via portal to escape overflow:hidden containers */} + {contextMenu && + createPortal( + onEdit?.(agent)} + onDelete={() => onDelete?.(agent)} + onClose={() => setContextMenu(null)} + isDefaultAgent={isDefaultAgent} + isMarketplacePublished={isMarketplacePublished} + />, + document.body, + )} + + ); +}; + +// Compact variant component (for spatial workspace switcher) +const CompactAgentListItem: React.FC = ({ + agent, + isSelected = false, + status = "idle", + role, + isMarketplacePublished = false, + onClick, + onEdit, + onDelete, +}) => { + const [contextMenu, setContextMenu] = useState<{ + x: number; + y: number; + } | null>(null); + + const longPressTimer = useRef | null>(null); + const isLongPress = useRef(false); + + // Check if it's a default agent based on tags + const isDefaultAgent = agent.tags?.some((tag) => tag.startsWith("default_")); + + const handleContextMenu = (e: React.MouseEvent) => { + if (!onEdit && !onDelete) return; // No context menu if no handlers + e.preventDefault(); + e.stopPropagation(); + setContextMenu({ x: e.clientX, y: e.clientY }); + }; + + const handleTouchStart = (e: React.TouchEvent) => { + if (!onEdit && !onDelete) return; + isLongPress.current = false; + const touch = e.touches[0]; + const { clientX, clientY } = touch; + + longPressTimer.current = setTimeout(() => { + isLongPress.current = true; + setContextMenu({ x: clientX, y: clientY }); + try { + if ("vibrate" in navigator) { + navigator.vibrate(10); + } + } catch { + // ignore + } + }, 500); + }; + + const handleTouchEnd = () => { + if (longPressTimer.current) { + clearTimeout(longPressTimer.current); + } + }; + + const handleTouchMove = () => { + if (longPressTimer.current) { + clearTimeout(longPressTimer.current); + longPressTimer.current = null; + } + }; + + return ( + <> + + + {/* Context menu - rendered via portal to escape overflow:hidden containers */} + {contextMenu && + (onEdit || onDelete) && + createPortal( + onEdit?.(agent)} + onDelete={() => onDelete?.(agent)} + onClose={() => setContextMenu(null)} + isDefaultAgent={isDefaultAgent} + isMarketplacePublished={isMarketplacePublished} + />, + document.body, + )} + + ); +}; + +// Main component that switches between variants +export const AgentListItem: React.FC = (props) => { + if (props.variant === "detailed") { + return ; + } + return ; +}; + +export default AgentListItem; diff --git a/web/src/components/agents/index.ts b/web/src/components/agents/index.ts new file mode 100644 index 00000000..e71b8a02 --- /dev/null +++ b/web/src/components/agents/index.ts @@ -0,0 +1,2 @@ +export { AgentList, type AgentListProps } from "./AgentList"; +export { AgentListItem, type AgentListItemProps } from "./AgentListItem"; diff --git a/web/src/components/layouts/XyzenAgent.tsx b/web/src/components/layouts/XyzenAgent.tsx index 4acfe7d6..94a536b8 100644 --- a/web/src/components/layouts/XyzenAgent.tsx +++ b/web/src/components/layouts/XyzenAgent.tsx @@ -1,20 +1,10 @@ "use client"; -import { - Tooltip, - TooltipContent, - TooltipProvider, - TooltipTrigger, -} from "@/components/animate-ui/components/animate/tooltip"; -import { Badge } from "@/components/base/Badge"; + +import { TooltipProvider } from "@/components/animate-ui/components/animate/tooltip"; +import { AgentList } from "@/components/agents"; import { useAuth } from "@/hooks/useAuth"; -import { formatTime } from "@/lib/formatDate"; -import { - PencilIcon, - ShoppingBagIcon, - TrashIcon, -} from "@heroicons/react/24/outline"; -import { motion, type Variants } from "framer-motion"; -import React, { useEffect, useMemo, useRef, useState } from "react"; +import { motion } from "framer-motion"; +import { useEffect, useMemo, useState } from "react"; import { useTranslation } from "react-i18next"; import AddAgentModal from "@/components/modals/AddAgentModal"; @@ -26,325 +16,6 @@ import { useXyzen } from "@/store"; // Import types from separate file import type { Agent } from "@/types/agents"; -interface AgentCardProps { - agent: Agent; - isMarketplacePublished?: boolean; - lastConversationTime?: string; - onClick?: (agent: Agent) => void; - onEdit?: (agent: Agent) => void; - onDelete?: (agent: Agent) => void; -} - -// 定义动画变体 -const itemVariants: Variants = { - hidden: { y: 20, opacity: 0 }, - visible: { - y: 0, - opacity: 1, - transition: { - type: "spring", - stiffness: 100, - damping: 12, - }, - }, -}; - -// 右键菜单组件 -interface ContextMenuProps { - x: number; - y: number; - onEdit: () => void; - onDelete: () => void; - onClose: () => void; - isDefaultAgent?: boolean; - isMarketplacePublished?: boolean; - agent?: Agent; -} - -const ContextMenu: React.FC = ({ - x, - y, - onEdit, - onDelete, - onClose, - isDefaultAgent = false, - isMarketplacePublished = false, -}) => { - const { t } = useTranslation(); - const menuRef = useRef(null); - - useEffect(() => { - const handleClickOutside = (event: MouseEvent) => { - if (menuRef.current && !menuRef.current.contains(event.target as Node)) { - onClose(); - } - }; - - const handleEscape = (event: KeyboardEvent) => { - if (event.key === "Escape") { - onClose(); - } - }; - - document.addEventListener("mousedown", handleClickOutside); - document.addEventListener("keydown", handleEscape); - - return () => { - document.removeEventListener("mousedown", handleClickOutside); - document.removeEventListener("keydown", handleEscape); - }; - }, [onClose]); - - return ( - - - {isMarketplacePublished ? ( - - - - - - - - {t("agents.deleteBlockedMessage", { - defaultValue: - "This agent is published to Agent Market. Please unpublish it first, then delete it.", - })} - - - ) : ( - - )} - - ); -}; - -// 详细版本-包括名字,描述,头像,标签以及GPT模型 -const AgentCard: React.FC = ({ - agent, - isMarketplacePublished = false, - lastConversationTime, - onClick, - onEdit, - onDelete, -}) => { - const { t } = useTranslation(); - const [contextMenu, setContextMenu] = useState<{ - x: number; - y: number; - } | null>(null); - - const longPressTimer = useRef | null>(null); - const isLongPress = useRef(false); - - const handleTouchStart = (e: React.TouchEvent) => { - isLongPress.current = false; - const touch = e.touches[0]; - const { clientX, clientY } = touch; - - longPressTimer.current = setTimeout(() => { - setContextMenu({ x: clientX, y: clientY }); - // Haptic feedback (best-effort) - try { - if ("vibrate" in navigator) { - navigator.vibrate(10); - } - } catch { - // ignore - } - }, 500); - }; - - const handleTouchEnd = () => { - if (longPressTimer.current) { - clearTimeout(longPressTimer.current); - } - }; - - const handleTouchMove = () => { - if (longPressTimer.current) { - clearTimeout(longPressTimer.current); - longPressTimer.current = null; - } - }; - - // Check if it's a default agent based on tags - const isDefaultAgent = agent.tags?.some((tag) => tag.startsWith("default_")); - - const handleContextMenu = (e: React.MouseEvent) => { - e.preventDefault(); - e.stopPropagation(); - - setContextMenu({ - x: e.clientX, - y: e.clientY, - }); - }; - - return ( - <> - { - if (isLongPress.current) return; - onClick?.(agent); - }} - onContextMenu={handleContextMenu} - onTouchStart={handleTouchStart} - onTouchEnd={handleTouchEnd} - onTouchMove={handleTouchMove} - className={` - group relative flex cursor-pointer items-start gap-4 rounded-sm border p-3 - border-neutral-200 bg-white hover:bg-neutral-50 dark:border-neutral-800 dark:bg-neutral-900 dark:hover:bg-neutral-800/60 - ${agent.id === "default-chat" ? "select-none" : ""} - `} - > - {/* 头像 */} -
- {agent.name} -
- - {/* 内容 */} -
-
-

- {agent.name} -

- - {/* Marketplace published badge */} - {isMarketplacePublished && ( - - - - - - - - - - - - {t("agents.badges.marketplace", { - defaultValue: "Published to Marketplace", - })} - - - )} - - {/* Knowledge set badge */} - {/* {knowledgeSetName && ( -
- - 📚 {knowledgeSetName} - -
- )} */} -
- -

- {agent.description} -

- - {/* Last conversation time */} - {lastConversationTime && ( -

- {formatTime(lastConversationTime)} -

- )} -
-
- - {/* 右键菜单 */} - {contextMenu && ( - onEdit?.(agent)} - onDelete={() => onDelete?.(agent)} - onClose={() => setContextMenu(null)} - isDefaultAgent={isDefaultAgent} - isMarketplacePublished={isMarketplacePublished} - agent={agent} - /> - )} - - ); -}; - -const containerVariants: Variants = { - hidden: { opacity: 0 }, - visible: { - opacity: 1, - transition: { - staggerChildren: 0.08, - delayChildren: 0.1, - }, - }, -}; - interface XyzenAgentProps { systemAgentType?: "chat" | "all"; } @@ -502,21 +173,18 @@ export default function XyzenAgent({ - {allAgents.map((agent) => ( - - ))} +
)}
-

- {listing.name} -

+
+

+ {listing.name} +

+ {listing.fork_mode === "locked" && ( + + + {t("marketplace.forkMode.locked")} + + )} +

{t("marketplace.detail.publishedBy")}{" "} @@ -330,80 +339,119 @@ export default function AgentMarketplaceDetail({ {/* Configuration Tab */} {activeTab === "config" && (

- {listing.snapshot ? ( - <> -
- - v{listing.snapshot.version} - - - {getAgentType( - listing.snapshot.configuration.graph_config, - )} - - - {listing.snapshot.commit_message} - + {/* Locked agent - hide config for non-owners */} + {listing.fork_mode === "locked" && !isOwner ? ( +
+
+
- - {/* Model */} - {listing.snapshot.configuration.model && ( -
-

- {t("marketplace.detail.config.model")} -

-

- {listing.snapshot.configuration.model} -

+

+ {t("marketplace.fork.lockedAgent")} +

+

+ {t("marketplace.detail.config.hidden")} +

+
+ ) : ( + <> + {/* Locked agent warning for owner */} + {listing.fork_mode === "locked" && isOwner && ( +
+
+ +
+

+ {t("marketplace.fork.lockedAgent")} +

+

+ {t( + "marketplace.detail.config.lockedOwnerNote", + )} +

+
+
)} - - {/* System Prompt */} - {getDisplayPrompt(listing.snapshot.configuration) && ( -
-

- {t("marketplace.detail.config.systemPrompt")} -

-
-
-                                {getDisplayPrompt(
-                                  listing.snapshot.configuration,
+                        {listing.snapshot ? (
+                          <>
+                            
+ + v{listing.snapshot.version} + + + {getAgentType( + listing.snapshot.configuration.graph_config, )} -
+ + + {listing.snapshot.commit_message} +
-
- )} - {/* MCP Servers in Configuration */} - {listing.snapshot.mcp_server_configs && - listing.snapshot.mcp_server_configs.length > 0 && ( -
-

- {t("marketplace.detail.config.mcpServers", { - count: - listing.snapshot.mcp_server_configs.length, - })} -

-
- {listing.snapshot.mcp_server_configs.map( - (mcp, index) => ( - - {mcp.name} - - ), - )} + {/* Model */} + {listing.snapshot.configuration.model && ( +
+

+ {t("marketplace.detail.config.model")} +

+

+ {listing.snapshot.configuration.model} +

-
- )} + )} + + {/* System Prompt */} + {getDisplayPrompt( + listing.snapshot.configuration, + ) && ( +
+

+ {t("marketplace.detail.config.systemPrompt")} +

+
+
+                                    {getDisplayPrompt(
+                                      listing.snapshot.configuration,
+                                    )}
+                                  
+
+
+ )} + + {/* MCP Servers in Configuration */} + {listing.snapshot.mcp_server_configs && + listing.snapshot.mcp_server_configs.length > + 0 && ( +
+

+ {t("marketplace.detail.config.mcpServers", { + count: + listing.snapshot.mcp_server_configs + .length, + })} +

+
+ {listing.snapshot.mcp_server_configs.map( + (mcp, index) => ( + + {mcp.name} + + ), + )} +
+
+ )} + + ) : ( +
+ +

{t("marketplace.detail.config.empty")}

+
+ )} - ) : ( -
- -

{t("marketplace.detail.config.empty")}

-
)}
)} @@ -653,6 +701,7 @@ export default function AgentMarketplaceDetail({ agentName={listing.name} agentDescription={listing.description || undefined} requirements={requirements} + forkMode={listing.fork_mode} onForkSuccess={handleForkSuccess} /> )} diff --git a/web/src/app/marketplace/AgentMarketplaceManage.tsx b/web/src/app/marketplace/AgentMarketplaceManage.tsx index d8bf0ada..b6a58bcd 100644 --- a/web/src/app/marketplace/AgentMarketplaceManage.tsx +++ b/web/src/app/marketplace/AgentMarketplaceManage.tsx @@ -5,6 +5,7 @@ import { PlateReadmeViewer } from "@/components/editor/PlateReadmeViewer"; import { AgentGraphEditor } from "@/components/editors/AgentGraphEditor"; import { JsonEditor } from "@/components/editors/JsonEditor"; import ConfirmationModal from "@/components/modals/ConfirmationModal"; +import { toast } from "sonner"; import { useListingHistory, useMarketplaceListing, @@ -12,7 +13,10 @@ import { useUnpublishAgent, } from "@/hooks/useMarketplace"; import type { AgentSnapshot } from "@/service/marketplaceService"; -import { marketplaceService } from "@/service/marketplaceService"; +import { + marketplaceService, + type ForkMode, +} from "@/service/marketplaceService"; import type { GraphConfig } from "@/types/graphConfig"; import { ArrowLeftIcon, @@ -26,6 +30,8 @@ import { EyeIcon, GlobeAltIcon, HeartIcon, + LockClosedIcon, + LockOpenIcon, PencilIcon, TrashIcon, } from "@heroicons/react/24/outline"; @@ -33,7 +39,6 @@ import { Tab, TabGroup, TabList, TabPanel, TabPanels } from "@headlessui/react"; import { useQueryClient } from "@tanstack/react-query"; import { useCallback, useState } from "react"; import { useTranslation } from "react-i18next"; -import { toast } from "sonner"; interface AgentMarketplaceManageProps { marketplaceId: string; @@ -67,6 +72,7 @@ export default function AgentMarketplaceManage({ const [graphConfigError, setGraphConfigError] = useState(null); const [activeEditorTab, setActiveEditorTab] = useState(0); const [isSavingConfig, setIsSavingConfig] = useState(false); + const [isSavingForkMode, setIsSavingForkMode] = useState(false); const queryClient = useQueryClient(); @@ -177,6 +183,26 @@ export default function AgentMarketplaceManage({ ); }; + const handleForkModeChange = async (newForkMode: ForkMode) => { + if (!listing) return; + try { + setIsSavingForkMode(true); + await marketplaceService.updateListing(listing.id, { + fork_mode: newForkMode, + }); + // Invalidate queries to refresh data + queryClient.invalidateQueries({ + queryKey: ["marketplace", "listing", listing.id], + }); + toast.success(t("marketplace.manage.forkMode.success")); + } catch (error) { + console.error("Failed to update fork mode:", error); + toast.error(t("marketplace.manage.forkMode.error")); + } finally { + setIsSavingForkMode(false); + } + }; + // Configuration editing handlers const handleGraphConfigChange = useCallback((config: GraphConfig) => { setGraphConfig(config); @@ -246,13 +272,17 @@ export default function AgentMarketplaceManage({ pattern: "react", display_name: "ReAct Agent", }, + // Prompt stored in prompt_config.custom_instructions (not llm_config.prompt_template) + prompt_config: { + custom_instructions: "You are a helpful assistant.", + }, nodes: [ { id: "agent", name: "ReAct Agent", type: "llm", llm_config: { - prompt_template: "You are a helpful assistant.", + prompt_template: "", // Backend will inject from prompt_config tools_enabled: true, output_key: "response", }, @@ -814,6 +844,86 @@ export default function AgentMarketplaceManage({
+ {/* Fork Mode Settings Card */} +
+

+ {t("marketplace.manage.forkMode.title")} +

+
+ + + {isSavingForkMode && ( +
+ + {t("marketplace.manage.forkMode.saving")} +
+ )} +

+ {t("marketplace.manage.forkMode.help")} +

+
+
+ {/* Metadata Card */}

diff --git a/web/src/components/features/ForkAgentModal.tsx b/web/src/components/features/ForkAgentModal.tsx index 6c781dd9..b7b7d06b 100644 --- a/web/src/components/features/ForkAgentModal.tsx +++ b/web/src/components/features/ForkAgentModal.tsx @@ -9,8 +9,11 @@ import { CheckCircleIcon, ExclamationTriangleIcon, InformationCircleIcon, + LockClosedIcon, } from "@heroicons/react/24/outline"; import { useState } from "react"; +import { useTranslation } from "react-i18next"; +import type { ForkMode } from "@/service/marketplaceService"; interface ForkAgentModalProps { open: boolean; @@ -23,6 +26,7 @@ interface ForkAgentModalProps { knowledge_base: { name: string; file_count: number } | null; provider_needed: boolean; }; + forkMode: ForkMode; onForkSuccess?: (agentId: string) => void; } @@ -39,8 +43,10 @@ export default function ForkAgentModal({ agentName, agentDescription, requirements, + forkMode, onForkSuccess, }: ForkAgentModalProps) { + const { t } = useTranslation(); const [customName, setCustomName] = useState(`${agentName} (Fork)`); const [currentStep, setCurrentStep] = useState< "name" | "requirements" | "confirm" @@ -160,15 +166,30 @@ export default function ForkAgentModal({ {/* Step 1: Name */} {currentStep === "name" && (
-
-
- -
- Your forked agent will be completely independent. Changes - won't affect the original. + {forkMode === "locked" ? ( +
+
+ +
+

+ {t("marketplace.fork.lockedAgent")} +

+

+ {t("marketplace.fork.lockedAgentDescription")} +

+
-
+ ) : ( +
+
+ +
+ {t("marketplace.fork.editableDescription")} +
+
+
+ )} + {/* Fork Mode Selector */} + + +
+ + +
+

+ {t("marketplace.publish.forkMode.help")} +

+
+ {/* Preview Toggle */}