diff --git a/src/copaw/agents/tools/__init__.py b/src/copaw/agents/tools/__init__.py index 12a83c9f4..fefba39f7 100644 --- a/src/copaw/agents/tools/__init__.py +++ b/src/copaw/agents/tools/__init__.py @@ -21,6 +21,7 @@ from .desktop_screenshot import desktop_screenshot from .memory_search import create_memory_search_tool from .get_current_time import get_current_time +from .read_media import read_media __all__ = [ "execute_python_code", @@ -38,4 +39,5 @@ "browser_use", "create_memory_search_tool", "get_current_time", + "read_media", ] diff --git a/src/copaw/agents/tools/read_media.py b/src/copaw/agents/tools/read_media.py new file mode 100644 index 000000000..3328daefc --- /dev/null +++ b/src/copaw/agents/tools/read_media.py @@ -0,0 +1,798 @@ +# -*- coding: utf-8 -*- +"""Read media file (image, video, audio) and return appropriate Block. + +Supports: +- Local file paths (any location accessible by the system) +- file:// URLs +- http(s):// URLs + +Media Types: +- Images: PNG, JPG, GIF, WEBP, BMP +- Videos: MP4, AVI, MOV, MKV, WEBM, FLV, WMV +- Audio: MP3, WAV, AAC, OGG, M4A, FLAC, WMA + +Features: +- Image compression (using Pillow) +- Video compression with frame extraction (using FFmpeg) +- Automatic media type detection and appropriate Block return + +Security: +- Maximum file size: 20MB (before compression) +- File content validation via magic numbers +""" +# flake8: noqa: E501 +# pylint: disable=line-too-long,too-many-return-statements,too-many-branches +import base64 +import logging +import os +import tempfile +import asyncio +from pathlib import Path +from typing import Optional +from urllib.parse import unquote + +import httpx + +from agentscope.message import TextBlock, ImageBlock, AudioBlock, VideoBlock +from agentscope.tool import ToolResponse + +logger = logging.getLogger(__name__) + + +# Supported media formats and their MIME types +SUPPORTED_FORMATS = { + # Images + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": 
"image/gif", + ".webp": "image/webp", + ".bmp": "image/bmp", + # Videos + ".mp4": "video/mp4", + ".avi": "video/x-msvideo", + ".mov": "video/quicktime", + ".mkv": "video/x-matroska", + ".webm": "video/webm", + ".flv": "video/x-flv", + ".wmv": "video/x-ms-wmv", + # Audio + ".mp3": "audio/mpeg", + ".wav": "audio/wav", + ".aac": "audio/aac", + ".ogg": "audio/ogg", + ".m4a": "audio/mp4", + ".flac": "audio/flac", + ".wma": "audio/x-ms-wma", +} + +# File extension categories +IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"} +VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv"} +AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac", ".ogg", ".m4a", ".flac", ".wma"} + +# Image format magic numbers (file signatures) for validation +# Each entry: (offset, signature bytes) +IMAGE_MAGIC_SIGNATURES = { + ".png": (0, b"\x89PNG\r\n\x1a\n"), + ".jpg": (0, b"\xff\xd8\xff"), + ".jpeg": (0, b"\xff\xd8\xff"), + ".gif": (0, b"GIF87a"), # Also matches GIF89a (first 6 bytes same) + ".webp": (8, b"WEBP"), # RIFF header at 0, WEBP at offset 8 + ".bmp": (0, b"BM"), +} + +# Video format magic numbers +VIDEO_MAGIC_SIGNATURES = { + ".mp4": (4, b"ftyp"), # ftyp box at offset 4 + ".avi": (0, b"RIFF"), # RIFF header + ".mov": (4, b"ftyp"), # QuickTime uses ftyp + ".mkv": (0, b"\x1a\x45\xdf\xa3"), # EBML header + ".webm": (0, b"\x1a\x45\xdf\xa3"), # Same as MKV + ".flv": (0, b"FLV"), + ".wmv": (0, b"\x30\x26\xb2\x75"), # ASF header +} + +# Audio format magic numbers +AUDIO_MAGIC_SIGNATURES = { + ".mp3": (0, b"\xff\xfb"), # MPEG-1 Layer 3 + ".wav": (0, b"RIFF"), # RIFF/WAVE + ".aac": (0, b"\xff\xf1"), # ADTS + ".ogg": (0, b"OggS"), + ".m4a": (4, b"ftyp"), # Same as MP4 + ".flac": (0, b"fLaC"), + ".wma": (0, b"\x30\x26\xb2\x75"), # Same as WMV (ASF) +} + +# Maximum file size: 20MB +MAX_FILE_SIZE = 20 * 1024 * 1024 + + +def _get_media_type(file_path: str) -> Optional[str]: + """Get MIME type from file extension. + + Args: + file_path: Path to the file. 
+ + Returns: + MIME type string or None if unsupported. + """ + ext = Path(file_path).suffix.lower() + return SUPPORTED_FORMATS.get(ext) + + +def _check_special_format(ext: str, header: bytes) -> bool: + """Check special format signatures for files with multiple variants. + + Args: + ext: File extension. + header: File header bytes. + + Returns: + True if format is valid, False otherwise. + """ + if ext == ".gif": + return header[0:6] in (b"GIF87a", b"GIF89a") + if ext == ".webp": + return header[0:4] == b"RIFF" and header[8:12] == b"WEBP" + if ext in (".mp4", ".mov", ".m4a"): + return b"ftyp" in header[4:12] + if ext == ".wav": + return header[0:4] == b"RIFF" and header[8:12] == b"WAVE" + if ext == ".avi": + return header[0:4] == b"RIFF" and header[8:12] == b"AVI " + return False + + +def _validate_media_magic(file_path: str) -> tuple[bool, str]: + """Validate that file content matches expected format. + + Args: + file_path: Path to the file. + + Returns: + Tuple of (is_valid, error_message). 
+ """ + ext = Path(file_path).suffix.lower() + + # Get appropriate magic signatures based on extension + if ext in IMAGE_EXTENSIONS: + signatures = IMAGE_MAGIC_SIGNATURES + elif ext in VIDEO_EXTENSIONS: + signatures = VIDEO_MAGIC_SIGNATURES + elif ext in AUDIO_EXTENSIONS: + signatures = AUDIO_MAGIC_SIGNATURES + else: + return (False, f"Unsupported media format: {ext}") + + if ext not in signatures: + # No magic signature validation for this format + return (True, "") + + offset, signature = signatures[ext] + + try: + with open(file_path, "rb") as f: + # Read enough bytes to check signature + header = f.read(offset + len(signature) + 16) + + if len(header) < offset + len(signature): + return (False, "File too small to validate format") + + # Check signature at expected offset + actual_signature = header[offset : offset + len(signature)] + + # Special handling for formats with multiple variants + if _check_special_format(ext, header): + return (True, "") + + if actual_signature == signature: + return (True, "") + + return ( + False, + f"File format mismatch: file extension is {ext}, but content is not valid {ext} format", + ) + + except Exception as e: + return (False, f"File validation failed: {e}") + + +def _parse_source(source: str) -> tuple[str, Optional[str], str]: + """Parse media source into type and path/URL. + + Args: + source: Media source (local path, file:// URL, or http(s):// URL). + + Returns: + Tuple of (source_type, parsed_path_or_url, error_message). + source_type is "local", "file_url", "http_url", or "unknown". 
+ """ + source = source.strip() + + # HTTP(S) URL + if source.startswith(("http://", "https://")): + return ("http_url", source, "") + + # file:// URL + if source.startswith("file://"): + # Remove file:// prefix and decode URL encoding + path = source[7:] + # Handle URL-encoded characters + path = unquote(path) + return ("file_url", path, "") + + # Local path (check if it looks like a path) + return ("local", source, "") + + +async def _fetch_http_media(url: str) -> tuple[bytes, str, str]: + """Fetch media from HTTP URL. + + Args: + url: HTTP(S) URL to fetch. + + Returns: + Tuple of (media_data, media_type, error_message). + """ + try: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url, follow_redirects=True) + response.raise_for_status() + + # Check content type + content_type = response.headers.get("content-type", "") + + # Check size + content_length = len(response.content) + if content_length > MAX_FILE_SIZE: + size_mb = content_length / (1024 * 1024) + return ( + b"", + "", + f"File too large: {size_mb:.2f}MB, maximum allowed is 20MB", + ) + + # Determine media type from content-type header + media_type = content_type.split(";")[0].strip() + + return (response.content, media_type, "") + + except httpx.TimeoutException: + return (b"", "", f"Request timeout: {url}") + except httpx.HTTPStatusError as e: + return (b"", "", f"HTTP error: {e.response.status_code}") + except Exception as e: + return (b"", "", f"Request failed: {e}") + + +def _compress_image( + input_path: str, + output_path: str, + target_size_mb: float, +) -> bool: + """Compress image to target size using Pillow. + + Args: + input_path: Path to input image. + output_path: Path to save compressed image. + target_size_mb: Target file size in MB. + + Returns: + True if compression succeeded and file is within target size. 
+ """ + try: + from PIL import Image + except ImportError: + return False + + target_bytes = target_size_mb * 1024 * 1024 + quality = 95 + + try: + with Image.open(input_path) as img: + # Convert to RGB if necessary (handle RGBA, P, LA modes) + if img.mode in ("RGBA", "LA", "P"): + background = Image.new("RGB", img.size, (255, 255, 255)) + if img.mode == "P": + img = img.convert("RGBA") + if img.mode in ("RGBA", "LA"): + background.paste( + img, + mask=img.split()[-1] + if img.mode in ("RGBA", "LA") + else None, + ) + img = background + else: + img = img.convert("RGB") + elif img.mode != "RGB": + img = img.convert("RGB") + + # Try reducing quality first + while quality >= 20: + img.save(output_path, "JPEG", optimize=True, quality=quality) + if os.path.getsize(output_path) <= target_bytes: + return True + quality -= 5 + + # If quality reduction isn't enough, also resize + if os.path.getsize(output_path) > target_bytes: + ratio = 0.8 + while ratio > 0.3: + new_size = ( + int(img.width * ratio), + int(img.height * ratio), + ) + resized = img.resize(new_size, Image.Resampling.LANCZOS) + resized.save( + output_path, + "JPEG", + optimize=True, + quality=75, + ) + if os.path.getsize(output_path) <= target_bytes: + return True + ratio -= 0.1 + + return os.path.getsize(output_path) <= target_bytes + except Exception: + return False + + +async def _compress_video( + input_path: str, + output_path: str, + target_size_mb: float, + fps: int = 1, +) -> bool: + """Compress video using FFmpeg with optional frame extraction. + + Args: + input_path: Path to input video. + output_path: Path to save compressed video. + target_size_mb: Target file size in MB. + fps: Frames per second to extract (1 = 1 frame per second). + Use 0 to keep original frame rate. + + Returns: + True if compression succeeded. + """ + # Calculate CRF based on target size (higher = more compression) + # 5MB -> CRF 28, 10MB -> CRF 26, etc. 
+ crf = max(18, min(28, int(28 - (target_size_mb - 5) / 5 * 2))) + + # Build FFmpeg command + cmd = [ + "ffmpeg", + "-y", # Overwrite output + "-i", + input_path, + "-c:v", + "libx264", + "-crf", + str(crf), + "-preset", + "slow", # Better compression ratio + "-c:a", + "aac", # Audio codec + "-b:a", + "64k", # Low audio bitrate + "-movflags", + "+faststart", + ] + + # Add frame rate filter if specified + if fps > 0: + cmd.extend(["-vf", f"fps={fps}"]) + cmd.extend(["-r", str(fps)]) # Output frame rate + + cmd.append(output_path) + + try: + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + _, _ = await process.communicate() + + if process.returncode != 0: + return False + + # Check if output file exists and is smaller + if not os.path.exists(output_path): + return False + + return True + except Exception: + return False + + +def _get_file_category(file_path: str) -> str: + """Get the category of file (image, video, audio). + + Args: + file_path: Path to the file. + + Returns: + Category string: "image", "video", "audio", or "unknown". + """ + ext = Path(file_path).suffix.lower() + if ext in IMAGE_EXTENSIONS: + return "image" + elif ext in VIDEO_EXTENSIONS: + return "video" + elif ext in AUDIO_EXTENSIONS: + return "audio" + return "unknown" + + +def _create_media_block( + category: str, + media_type: str, + base64_data: str, +) -> ImageBlock | VideoBlock | AudioBlock: + """Create appropriate media block based on category. + + Args: + category: File category ("image", "video", "audio"). + media_type: MIME type of the media. + base64_data: Base64-encoded media data. + + Returns: + ImageBlock, VideoBlock, or AudioBlock. 
+ """ + source = { + "type": "base64", + "media_type": media_type, + "data": base64_data, + } + if category == "image": + return ImageBlock(type="image", source=source) + elif category == "video": + return VideoBlock(type="video", source=source) + else: # audio + return AudioBlock(type="audio", source=source) + + +async def _handle_http_media(url: str) -> ToolResponse: + """Handle media fetching from HTTP URL. + + Args: + url: HTTP(S) URL to fetch media from. + + Returns: + ToolResponse with media block or error. + """ + media_data, media_type, error = await _fetch_http_media(url) + if error: + return ToolResponse( + content=[TextBlock(type="text", text=f"Error: {error}")], + ) + + base64_data = base64.b64encode(media_data).decode("utf-8") + + if media_type.startswith("image/"): + block = _create_media_block("image", media_type, base64_data) + elif media_type.startswith("video/"): + block = _create_media_block("video", media_type, base64_data) + elif media_type.startswith("audio/"): + block = _create_media_block("audio", media_type, base64_data) + else: + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Error: Unsupported media type: {media_type}", + ), + ], + ) + + return ToolResponse(content=[block]) + + +def _validate_local_file(file_path: str) -> tuple[bool, str]: + """Validate local file path and return resolved path or error. + + Args: + file_path: Path to validate. + + Returns: + Tuple of (is_valid, resolved_path_or_error). 
+ """ + if not os.path.isabs(file_path): + file_path = os.path.abspath(file_path) + + if not os.path.lexists(file_path): + return False, f"Error: File does not exist: {file_path}" + + if os.path.islink(file_path) and not os.path.exists(file_path): + return ( + False, + f"Error: Symbolic link points to non-existent location: {file_path}", + ) + + file_path = os.path.realpath(file_path) + + if not os.path.isfile(file_path): + return False, f"Error: Path is not a file: {file_path}" + + return True, file_path + + +def _check_file_validity(file_path: str) -> tuple[bool, str]: + """Check file format and size validity. + + Args: + file_path: Path to the file. + + Returns: + Tuple of (is_valid, media_type_or_error). + """ + media_type = _get_media_type(file_path) + if not media_type: + supported = ", ".join(SUPPORTED_FORMATS.keys()) + return ( + False, + f"Error: Unsupported media format. Supported formats: {supported}", + ) + + is_valid, error = _validate_media_magic(file_path) + if not is_valid: + return False, f"Error: {error}" + + file_size = os.path.getsize(file_path) + if file_size > MAX_FILE_SIZE: + size_mb = file_size / (1024 * 1024) + return ( + False, + f"Error: File too large ({size_mb:.2f}MB), maximum allowed is 20MB.", + ) + + return True, media_type + + +def _compress_image_file( + file_path: str, + max_size_mb: float, +) -> tuple[str | None, bool]: + """Compress image file and return temp path. + + Args: + file_path: Path to image file. + max_size_mb: Target max size in MB. + + Returns: + Tuple of (temp_file_path, was_compressed). + """ + fd, temp_file = tempfile.mkstemp(suffix=".jpg") + os.close(fd) + + if _compress_image(file_path, temp_file, max_size_mb): + return temp_file, True + os.unlink(temp_file) + return None, False + + +async def _compress_video_file( + file_path: str, + max_size_mb: float, + video_fps: int, +) -> tuple[str | None, bool]: + """Compress video file and return temp path. + + Args: + file_path: Path to video file. 
+ max_size_mb: Target max size in MB. + video_fps: Frame rate for compression. + + Returns: + Tuple of (temp_file_path, was_compressed). + """ + fd, temp_file = tempfile.mkstemp(suffix=".mp4") + os.close(fd) + + if await _compress_video(file_path, temp_file, max_size_mb, video_fps): + return temp_file, True + os.unlink(temp_file) + return None, False + + +async def _compress_file_if_needed( + file_path: str, + category: str, + compress: bool, + max_size_mb: float, + video_fps: int, +) -> tuple[str, bool, str | None]: + """Compress file if needed and return path to use. + + Args: + file_path: Original file path. + category: File category. + compress: Whether compression is enabled. + max_size_mb: Target max size in MB. + video_fps: Video frame rate for compression. + + Returns: + Tuple of (file_to_read, was_compressed, temp_file_path). + """ + file_size = os.path.getsize(file_path) + if not compress or file_size <= max_size_mb * 1024 * 1024: + return file_path, False, None + + temp_file: str | None = None + was_compressed = False + + if category == "image": + temp_file, was_compressed = _compress_image_file( + file_path, + max_size_mb, + ) + elif category == "video": + temp_file, was_compressed = await _compress_video_file( + file_path, + max_size_mb, + video_fps, + ) + + if was_compressed and temp_file: + return temp_file, True, temp_file + return file_path, False, None + + +async def read_media( + source: str, + compress: bool = True, + max_size_mb: float = 5.0, + video_fps: int = 1, +) -> ToolResponse: + """Read media file (image, video, audio) and return appropriate Block. + + Supports image, video and audio formats with automatic compression + to fit model input limits. + + Args: + source (`str`): + Media file source, can be: + - Local file path (e.g., /Users/xxx/video.mp4) + - file:// URL (e.g., file:///Users/xxx/audio.mp3) + - http(s):// URL (e.g., https://example.com/image.png) + + compress (`bool`): + Whether to enable compression (default True). 
For large files, + compression can reduce size to fit model input limits. + + max_size_mb (`float`): + Target file size limit after compression (MB), default 5MB. + Returns error if original file exceeds 20MB. + + video_fps (`int`): + Video frame extraction parameter, frames per second to keep (default 1). + - 1 = 1 frame per second (suitable for video content analysis) + - 5 = 5 frames per second (smoother) + - 0 = No frame extraction, keep original frame rate + Frame extraction can significantly reduce video file size. + + Returns: + `ToolResponse`: Contains appropriate Block (ImageBlock, VideoBlock, AudioBlock) + or error message. + + Examples: + >>> # Read local image + >>> await read_media("/path/to/photo.png") + + >>> # Read video with frame extraction (2 fps) + >>> await read_media("/path/to/video.mp4", video_fps=2) + + >>> # Read audio from URL + >>> await read_media("https://example.com/audio.mp3") + + >>> # Disable compression + >>> await read_media("/path/to/small.gif", compress=False) + """ + if not source: + return ToolResponse( + content=[ + TextBlock( + type="text", + text="Error: No media file source provided.", + ), + ], + ) + + source_type, parsed_source, error = _parse_source(source) + if error: + return ToolResponse( + content=[TextBlock(type="text", text=f"Error: {error}")], + ) + + # Handle HTTP URLs + if source_type == "http_url": + if parsed_source is None: + return ToolResponse( + content=[TextBlock(type="text", text="Error: Invalid URL")], + ) + return await _handle_http_media(parsed_source) + + # Handle local files and file:// URLs + file_path = parsed_source + if file_path is None: + return ToolResponse( + content=[ + TextBlock( + type="text", + text="Error: Cannot parse media file source", + ), + ], + ) + + # Validate local file + is_valid, result = _validate_local_file(file_path) + if not is_valid: + return ToolResponse(content=[TextBlock(type="text", text=result)]) + file_path = result + + # Check file format and size + is_valid, 
result = _check_file_validity(file_path) + if not is_valid: + return ToolResponse(content=[TextBlock(type="text", text=result)]) + media_type = result + + # Determine file category + category = _get_file_category(file_path) + + # Compress if needed + file_to_read, was_compressed, temp_file = await _compress_file_if_needed( + file_path, + category, + compress, + max_size_mb, + video_fps, + ) + if was_compressed and category == "video": + media_type = "video/mp4" + + try: + # Read and encode file + with open(file_to_read, "rb") as f: + media_data = f.read() + + base64_data = base64.b64encode(media_data).decode("utf-8") + + # Build info text + final_size_mb = len(media_data) / (1024 * 1024) + info_text = f"Media file loaded: {os.path.basename(source)} ({final_size_mb:.2f}MB)" + if was_compressed: + info_text += " [compressed]" + if category == "video" and video_fps != 1 and video_fps > 0: + info_text += f" [frame extraction: {video_fps}fps]" + + # Create block + block = _create_media_block(category, media_type, base64_data) + + return ToolResponse( + content=[TextBlock(type="text", text=info_text), block], + ) + + except Exception as e: + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Error: Failed to read file: {e}", + ), + ], + ) + finally: + # Clean up temp file if created + if temp_file and os.path.exists(temp_file): + try: + os.unlink(temp_file) + except Exception: + logger.warning( + "Failed to remove temporary file: %s", + temp_file, + exc_info=True, + ) diff --git a/src/copaw/app/auth_middleware.py b/src/copaw/app/auth_middleware.py new file mode 100644 index 000000000..a3019f8cf --- /dev/null +++ b/src/copaw/app/auth_middleware.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +"""HTTP Basic Auth middleware for FastAPI.""" + +import secrets +from typing import Optional + +from fastapi import Request, status +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from starlette.middleware.base import BaseHTTPMiddleware +from 
starlette.responses import Response + + +class BasicAuthMiddleware(BaseHTTPMiddleware): + """Middleware to enforce HTTP Basic Auth on all routes + except excluded paths.""" + + def __init__( + self, + app, + username: str, + password: str, + excluded_paths: Optional[list] = None, + ): + super().__init__(app) + self.username = username + self.password = password + self.excluded_paths = excluded_paths or [] + self.security = HTTPBasic(auto_error=False) + + def _is_excluded(self, path: str) -> bool: + """Check if path is excluded from auth.""" + for excluded in self.excluded_paths: + if path == excluded or path.startswith(excluded + "/"): + return True + return False + + def _verify_credentials(self, credentials: HTTPBasicCredentials) -> bool: + """Verify username and password using constant-time comparison.""" + if not credentials: + return False + is_username_ok = secrets.compare_digest( + credentials.username, + self.username, + ) + is_password_ok = secrets.compare_digest( + credentials.password, + self.password, + ) + return is_username_ok and is_password_ok + + async def dispatch(self, request: Request, call_next) -> Response: + # Skip auth if password is not set + if not self.password: + return await call_next(request) + + path = request.url.path + + # Skip excluded paths + if self._is_excluded(path): + return await call_next(request) + + # Check for basic auth credentials + credentials = await self.security(request) + + if not credentials or not self._verify_credentials(credentials): + return Response( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Basic"}, + content="Unauthorized", + ) + + return await call_next(request) diff --git a/src/copaw/app/channels/feishu/channel.py b/src/copaw/app/channels/feishu/channel.py index 106287757..e561c4ccd 100644 --- a/src/copaw/app/channels/feishu/channel.py +++ b/src/copaw/app/channels/feishu/channel.py @@ -17,12 +17,12 @@ import json import logging import mimetypes -import re import sys 
import threading import time import types from collections import OrderedDict +from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -53,7 +53,6 @@ FEISHU_USER_NAME_FETCH_TIMEOUT, ) from .utils import ( - build_interactive_content, extract_json_key, extract_post_image_keys, extract_post_media_file_keys, @@ -170,6 +169,7 @@ def __init__( group_policy: str = "open", allow_from: Optional[List[str]] = None, deny_message: str = "", + domain: str = "https://open.feishu.cn", require_mention: bool = False, ): super().__init__( @@ -190,6 +190,7 @@ def __init__( self.bot_prefix = bot_prefix self.encrypt_key = encrypt_key or "" self.verification_token = verification_token or "" + self.domain = domain or "https://open.feishu.cn" self._media_dir = Path(media_dir).expanduser() self._client: Any = None @@ -241,6 +242,7 @@ def from_env( group_policy=os.getenv("FEISHU_GROUP_POLICY", "open"), allow_from=allow_from, deny_message=os.getenv("FEISHU_DENY_MESSAGE", ""), + domain=os.getenv("FEISHU_DOMAIN", "https://open.feishu.cn"), require_mention=os.getenv("FEISHU_REQUIRE_MENTION", "0") == "1", ) @@ -271,6 +273,7 @@ def from_config( group_policy=config.group_policy or "open", allow_from=config.allow_from or [], deny_message=config.deny_message or "", + domain=config.domain or "https://open.feishu.cn", require_mention=config.require_mention, ) @@ -401,7 +404,7 @@ async def _get_tenant_access_token(self) -> str: return self._tenant_access_token url = ( - "https://open.feishu.cn/open-apis/auth/v3/" + f"{self.domain}/open-apis/auth/v3/" "tenant_access_token/internal" ) payload = { @@ -460,7 +463,7 @@ async def _get_user_name_by_open_id(self, open_id: str) -> Optional[str]: if open_id in self._nickname_cache: return self._nickname_cache[open_id] url = ( - "https://open.feishu.cn/open-apis/contact/v3/users/" + f"{self.domain}/open-apis/contact/v3/users/" f"{open_id}?user_id_type=open_id" ) try: @@ -497,7 +500,8 @@ async 
def _get_user_name_by_open_id(self, open_id: str) -> Optional[str]: # Response per Feishu doc: GET contact/v3/users/{user_id} # https://open.feishu.cn/document/server-docs/contact-v3/user/get # Body: { "code": 0, "data": { "user": { "name": ... } } } - # "name" can be string or i18n object { "zh_cn": "中文", "en": "en" } + # "name" can be string or i18n object + # { "zh_cn": "中文", "en": "en" } user = data.get("data") or {} inner = user.get("user") or {} name = None @@ -566,6 +570,96 @@ async def _get_user_name_by_open_id(self, open_id: str) -> Optional[str]: ) return None + async def _parse_post_content( + self, + message_id: str, + content_raw: str, + ) -> Dict[str, Any]: + """Parse Feishu post (rich text) content. + + Post content format: + { + "title": "...", + "content": [ + [ + {"tag": "text", "text": "..."}, + {"tag": "img", "image_key": "..."} + ], + [{"tag": "md", "text": "..."}] + ] + } + + Returns dict with "text" and "image_urls" keys. + """ + result: Dict[str, Any] = {"text": "", "image_urls": []} + try: + data = json.loads(content_raw) if content_raw else {} + except json.JSONDecodeError: + return result + + # Extract title if present + title = data.get("title", "") + text_parts: List[str] = [] + if title: + text_parts.append(title) + + # Parse content rows + content_rows = data.get("content", []) + if isinstance(content_rows, list): + for row in content_rows: + row_text = await self._parse_content_row( + row, + message_id, + result["image_urls"], + ) + if row_text: + text_parts.append(row_text) + + result["text"] = "\n".join(text_parts) + return result + + async def _parse_content_row( + self, + row: Any, + message_id: str, + image_urls: List[str], + ) -> str: + """Parse a single content row, extracting text and image URLs. + + Args: + row: Content row (list of items). + message_id: Message ID for image download. + image_urls: List to append found image URLs to. + + Returns: + Concatenated text from the row, or empty string. 
+ """ + if not isinstance(row, list): + return "" + row_texts: List[str] = [] + for item in row: + if not isinstance(item, dict): + continue + tag = item.get("tag", "") + if tag == "text": + text = item.get("text", "") + if text: + row_texts.append(text) + elif tag == "md": + text = item.get("text", "") + if text: + row_texts.append(text) + elif tag == "img": + image_key = item.get("image_key", "") + if image_key: + url_or_path = await self._download_image_resource( + message_id, + image_key, + ) + if url_or_path: + image_urls.append(url_or_path) + return "".join(row_texts) + def _emit_request_threadsafe(self, request: Any) -> None: """Enqueue request via manager (thread-safe).""" if self._enqueue is not None: @@ -765,6 +859,21 @@ async def _on_message(self, data: "P2ImMessageReceiveV1") -> None: text_parts.append("[audio: download failed]") else: text_parts.append("[audio: missing key]") + elif msg_type == "post": + # Parse rich post content (text + images mixed) + post_content = await self._parse_post_content( + message_id, + content_raw, + ) + if post_content.get("text"): + text_parts.append(post_content["text"]) + for img_url in post_content.get("image_urls", []): + content_parts.append( + ImageContent( + type=ContentType.IMAGE, + image_url=img_url, + ), + ) else: text_parts.append(f"[{msg_type}]") @@ -836,6 +945,63 @@ async def _on_message(self, data: "P2ImMessageReceiveV1") -> None: except Exception: logger.exception("feishu _on_message failed") + async def handle_webhook_event(self, payload: Dict[str, Any]) -> None: + """Handle webhook event from Feishu HTTP callback. + + Converts webhook event format (2.0 schema) to internal format + and processes it like WebSocket events. 
+ + Args: + payload: The webhook event payload from Feishu + """ + try: + # Extract event data from the 2.0 schema + event = payload.get("event", payload) + if not event: + logger.warning("Feishu webhook: no event data in payload") + return + + # Get message and sender from event + message = event.get("message", {}) + sender = event.get("sender", {}) + + if not message or not sender: + logger.debug("Feishu webhook: missing message or sender") + return + + # Build a compatible data structure for _on_message + # Convert webhook format to internal format + event_obj = types.SimpleNamespace( + message=types.SimpleNamespace( + message_id=message.get("message_id", ""), + chat_id=message.get("chat_id", ""), + chat_type=message.get("chat_type", "p2p"), + message_type=message.get("message_type", "text"), + content=message.get("content", ""), + ), + sender=types.SimpleNamespace( + sender_type=sender.get("sender_type", ""), + sender_id=types.SimpleNamespace( + open_id=sender.get("sender_id", {}).get("open_id", ""), + ), + name=sender.get("name", ""), + nickname=sender.get("nickname", ""), + ), + ) + + # Create wrapper like WebSocket event structure + class WebhookData: + def __init__(self, event): + self.event = event + + data = WebhookData(event_obj) + + # Process using existing _on_message logic + await self._on_message(data) + + except Exception: + logger.exception("Feishu webhook: handle_webhook_event failed") + async def _add_reaction( self, message_id: str, @@ -892,7 +1058,7 @@ async def _download_image_resource( """Download image to media_dir; return local path or None.""" token = await self._get_tenant_access_token() url = ( - f"https://open.feishu.cn/open-apis/im/v1/messages/{message_id}" + f"{self.domain}/open-apis/im/v1/messages/{message_id}" f"/resources/{image_key}" ) headers = {"Authorization": f"Bearer {token}"} @@ -939,7 +1105,7 @@ async def _download_file_resource( """ token = await self._get_tenant_access_token() url = ( - 
f"https://open.feishu.cn/open-apis/im/v1/messages/" + f"{self.domain}/open-apis/im/v1/messages/" f"{message_id}/resources/{file_key}?type=file" ) headers = {"Authorization": f"Bearer {token}"} @@ -1110,6 +1276,73 @@ def _build_post_content( }, } + def _build_card_v2_content( + self, + text: str, + image_keys: List[str], + header_title: Optional[str] = None, + template: str = "blue", + ) -> Dict[str, Any]: + """Build Feishu Card V2 message content (Schema 2.0). + + Args: + text: Markdown text content + image_keys: List of image keys + (obtained via _upload_image_sync) + header_title: Card header title (optional) + template: Header color template, options: + green, red, blue, orange, indigo, grey + """ + # Build body elements + elements: List[Dict[str, Any]] = [] + + # Add text content + if text: + elements.append( + { + "tag": "div", + "text": { + "tag": "lark_md", + "content": normalize_feishu_md(text), + }, + }, + ) + + # Add image elements + for image_key in image_keys: + elements.append( + { + "tag": "img", + "img_key": image_key, + "alt": {"tag": "plain_text", "content": "Image"}, + }, + ) + + # If no content, show placeholder + if not elements: + elements.append( + { + "tag": "div", + "text": {"tag": "lark_md", "content": "[empty]"}, + }, + ) + + # Build base card structure + card: Dict[str, Any] = { + "schema": "2.0", + "config": {"update_multi": True}, + "body": {"elements": elements}, + } + + # Add header (if title provided) + if header_title: + card["header"] = { + "template": template, + "title": {"tag": "plain_text", "content": header_title}, + } + + return card + def _upload_image_sync(self, data: bytes, filename: str) -> Optional[str]: """Upload image via lark client; return image_key.""" if not self._client: @@ -1184,7 +1417,7 @@ async def _upload_file(self, path_or_url: str) -> Optional[str]: file_type = "xls" if ext == "xlsx" else file_type file_type = "ppt" if ext == "pptx" else file_type mime = mimetypes.guess_type(str(path))[0] or 
"application/octet-stream" - url = "https://open.feishu.cn/open-apis/im/v1/files" + url = f"{self.domain}/open-apis/im/v1/files" form = aiohttp.FormData() form.add_field("file_type", file_type) form.add_field("file_name", path.name) @@ -1297,38 +1530,57 @@ def _send_message_sync( logger.exception("feishu _send_message_sync failed") return None - async def _send_text( + async def send_text( self, receive_id_type: str, receive_id: str, body: str, ) -> Optional[str]: - """Send text as post (md) or interactive card (when body has tables). + """Send text message (using Card V2 format). Returns the message_id on success, None on failure. Body already has bot_prefix if needed. """ - has_table = bool(re.search(r"^\s*\|", body, re.MULTILINE)) + card = self._build_card_v2_content(body, [], header_title=None) + content = json.dumps(card, ensure_ascii=False) + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + lambda: self._send_message_sync( + receive_id_type, + receive_id, + "interactive", + content, + ), + ) + + async def _send_card_v2( + self, + receive_id_type: str, + receive_id: str, + text: str, + image_keys: List[str], + header_title: Optional[str] = None, + template: str = "blue", + ) -> Optional[str]: + """Send Card V2 message. + + Returns the message_id on success, None on failure. 
+ """ + card = self._build_card_v2_content( + text, + image_keys, + header_title, + template, + ) + content = json.dumps(card, ensure_ascii=False) loop = asyncio.get_running_loop() - if has_table: - content = build_interactive_content(body) - return await loop.run_in_executor( - None, - lambda: self._send_message_sync( - receive_id_type, - receive_id, - "interactive", - content, - ), - ) - post = self._build_post_content(body, []) - content = json.dumps(post, ensure_ascii=False) return await loop.run_in_executor( None, lambda: self._send_message_sync( receive_id_type, receive_id, - "post", + "interactive", content, ), ) @@ -1606,7 +1858,7 @@ async def send_content_parts( # type: ignore[override] parts: List[OutgoingContentPart], meta: Optional[Dict[str, Any]] = None, ) -> Optional[str]: - """Send text as post (md), then images, then files. + """Send content parts (using Card V2 format, text and images combined). Returns the message_id of the last successfully sent message, or None if nothing was sent. 
@@ -1630,8 +1882,12 @@ async def send_content_parts( # type: ignore[override] (receive_id or "")[:20], ) prefix = (meta or {}).get("bot_prefix", "") or self.bot_prefix or "" + + # Collect text and images text_parts: List[str] = [] - media_parts: List[OutgoingContentPart] = [] + image_parts: List[OutgoingContentPart] = [] + file_parts: List[OutgoingContentPart] = [] + for p in parts: t = getattr(p, "type", None) or ( p.get("type") if isinstance(p, dict) else None @@ -1642,66 +1898,70 @@ async def send_content_parts( # type: ignore[override] refusal_val = getattr(p, "refusal", None) or ( p.get("refusal") if isinstance(p, dict) else None ) + if t == ContentType.TEXT and text_val: text_parts.append(text_val or "") elif t == ContentType.REFUSAL and refusal_val: text_parts.append(refusal_val or "") + elif t == ContentType.IMAGE: + image_parts.append(p) elif t in ( - ContentType.IMAGE, ContentType.FILE, ContentType.VIDEO, ContentType.AUDIO, ): - media_parts.append(p) - body = "\n".join(text_parts).strip() + file_parts.append(p) + logger.info( "feishu send_content_parts: to_handle=%s text_parts=%s " - "media_count=%s media_types=%s", + "image_count=%s file_count=%s", to_handle[:40] if to_handle else "", len(text_parts), - len(media_parts), - [getattr(m, "type", None) for m in media_parts], + len(image_parts), + len(file_parts), ) + + # Upload all images to get image_keys + image_keys: List[str] = [] + for part in image_parts: + data, filename = await self._part_to_image_bytes(part) + if data: + loop = asyncio.get_running_loop() + image_key = await loop.run_in_executor( + None, + partial(self._upload_image_sync, data, filename), + ) + if image_key: + image_keys.append(image_key) + + # Send Card V2 message (including text and images) + body = "\n".join(text_parts).strip() if prefix and body: body = prefix + body last_message_id: Optional[str] = None - if body: - last_message_id = await self._send_text( + if body or image_keys: + last_message_id = await self._send_card_v2( 
receive_id_type, receive_id, body, + image_keys, ) - for part in media_parts: - pt = getattr(part, "type", None) - if pt == ContentType.IMAGE: - msg_id = await self._send_image( - receive_id_type, - receive_id, - part, - ) - logger.info( - "feishu send_content_parts: image sent ok=%s", - bool(msg_id), - ) - if msg_id: - last_message_id = msg_id - elif pt in ( - ContentType.FILE, - ContentType.VIDEO, - ContentType.AUDIO, - ): - msg_id = await self._send_file( - receive_id_type, - receive_id, - part, - ) - logger.info( - "feishu send_content_parts: file sent ok=%s type=%s", - bool(msg_id), - pt, - ) - if msg_id: - last_message_id = msg_id + + # Files are still sent separately + for part in file_parts: + msg_id = await self._send_file( + receive_id_type, + receive_id, + part, + ) + logger.info( + "feishu send_content_parts: file sent ok=%s type=%s", + bool(msg_id), + getattr(part, "type", None), + ) + if msg_id: + last_message_id = msg_id + return last_message_id async def _run_process_loop( @@ -1758,7 +2018,7 @@ async def send( text: str, meta: Optional[Dict[str, Any]] = None, ) -> None: - """Proactive send: resolve receive_id and send text as post.""" + """Proactively send text message (using Card V2 format).""" if not self.enabled: return recv = await self._get_receive_for_send(to_handle, meta) @@ -1772,7 +2032,7 @@ async def send( prefix = (meta or {}).get("bot_prefix", "") or self.bot_prefix or "" body = (prefix + text) if text else prefix if body: - await self._send_text(receive_id_type, receive_id, body) + await self._send_card_v2(receive_id_type, receive_id, body, []) def get_to_handle_from_request(self, request: Any) -> str: """Feishu sends by session_id; return feishu:sw: or feishu:open_id: @@ -1867,6 +2127,7 @@ async def start(self) -> None: self.app_secret, event_handler=event_handler, log_level=lark.LogLevel.INFO, + domain=self.domain, ) self._stop_event.clear() self._ws_thread = threading.Thread( diff --git a/src/copaw/app/routers/feishu_notify.py 
b/src/copaw/app/routers/feishu_notify.py new file mode 100644 index 000000000..29bcb207e --- /dev/null +++ b/src/copaw/app/routers/feishu_notify.py @@ -0,0 +1,336 @@ +# -*- coding: utf-8 -*- +"""Feishu (Lark) Simple Notification Router. + +Provides a simple HTTP endpoint to send messages to Feishu. +Uses environment variables for target configuration (chat_id or open_id). + +Example usage: + curl -X POST \ + "http://localhost:8000/api/v1/notify/feishu?message=Server+alert" + +Environment variables: + FEISHU_NOTIFY_CHAT_ID: Target chat ID (group chat) + FEISHU_NOTIFY_OPEN_ID: Target user open ID (private message) +""" + +import json +import logging +import os +import time +import uuid +from typing import Optional, Tuple + +from fastapi import APIRouter, Request, status +from fastapi.responses import JSONResponse + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +def _get_feishu_channel(request: Request): + """Get FeishuChannel instance from channel manager.""" + cm = getattr(request.app.state, "channel_manager", None) + if cm is None: + return None + + if hasattr(cm, "channels"): + channels = cm.channels + if isinstance(channels, dict): + channel_iter = channels.values() + else: + channel_iter = channels + for ch in channel_iter: + if getattr(ch, "channel", None) == "feishu": + return ch + return None + + +def _get_target_id() -> Tuple[Optional[str], Optional[str]]: + """Get target ID from environment variables. + + Returns: + Tuple of (receive_id_type, receive_id) or (None, None) if not + configured. 
+ """ + chat_id = os.environ.get("FEISHU_NOTIFY_CHAT_ID") + open_id = os.environ.get("FEISHU_NOTIFY_OPEN_ID") + + if not chat_id and not open_id: + logger.warning( + "Feishu notify: FEISHU_NOTIFY_CHAT_ID or " + "FEISHU_NOTIFY_OPEN_ID not set", + ) + return None, None + + if chat_id: + return "chat_id", chat_id + return "open_id", open_id + + +async def _parse_message_from_request( + request: Request, + message: Optional[str], + source: Optional[str], +) -> Tuple[Optional[str], Optional[str]]: + """Parse message and source from request. + + Tries to read from query params first, then from JSON body or raw body. + + Returns: + Tuple of (message, source) with updated values. + """ + if message is not None and source is not None: + return message, source + + try: + body = await request.body() + body_str = body.decode("utf-8").strip() + + if not body_str: + return message, source + + # Try JSON parsing + try: + json_data = json.loads(body_str) + if isinstance(json_data, dict): + if message is None and "message" in json_data: + message = json_data["message"] + if source is None and "source" in json_data: + source = json_data["source"] + if message is None: + message = body_str + except json.JSONDecodeError: + # Not JSON, use raw body as message + if message is None: + message = body_str + except Exception as e: + logger.warning(f"Failed to read request body: {e}") + + return message, source + + +def _validate_request( + receive_id_type: Optional[str], + message: Optional[str], +) -> Optional[JSONResponse]: + """Validate notification request parameters. + + Returns: + JSONResponse with error if validation fails, None if valid. 
+ """ + if not receive_id_type: + return JSONResponse( + content={ + "code": 400, + "message": "FEISHU_NOTIFY_CHAT_ID or " + "FEISHU_NOTIFY_OPEN_ID not set", + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + if not message or not message.strip(): + return JSONResponse( + content={"code": 400, "message": "Message is required"}, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + return None + + +def _build_simulated_event( + formatted_message: str, + source_name: str, + receive_id_type: Optional[str], + chat_id: Optional[str], + open_id: Optional[str], +) -> dict: + """Build simulated webhook event for agent processing.""" + if receive_id_type is None: + receive_id_type = "open_id" + chat_type = "group" if receive_id_type == "chat_id" else "p2p" + simulated_sender_id = open_id or f"virtual_notify_{uuid.uuid4().hex[:8]}" + + return { + "event": { + "message": { + "message_id": f"simulated_{uuid.uuid4().hex}_" + f"{int(time.time())}", + "chat_id": chat_id or open_id, + "chat_type": chat_type, + "message_type": "text", + "content": json.dumps({"text": formatted_message}), + }, + "sender": { + "sender_type": "user", + "sender_id": {"open_id": simulated_sender_id}, + "name": source_name, + "nickname": source_name, + }, + }, + } + + +async def _send_direct_message( + feishu_channel, + receive_id_type: Optional[str], + receive_id: Optional[str], + formatted_message: str, +) -> Tuple[bool, Optional[JSONResponse]]: + """Send direct message via Feishu channel. + + Returns: + Tuple of (success, error_response). 
+ """ + try: + direct_result = await feishu_channel.send_text( + receive_id_type=receive_id_type, + receive_id=receive_id, + body=formatted_message, + ) + + if not direct_result: + logger.error("Feishu notify: send_text returned False") + return False, JSONResponse( + content={ + "code": 500, + "message": "Failed to send direct message", + }, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + return True, None + except Exception as e: + logger.exception(f"Feishu notify: failed to send message: {e}") + return False, JSONResponse( + content={"code": 500, "message": f"Internal error: {str(e)}"}, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +async def _queue_for_agent_processing( + feishu_channel, + formatted_message: str, + source_name: str, + receive_id_type: Optional[str], + chat_id: Optional[str], + open_id: Optional[str], +): + """Queue message for agent processing via simulated webhook event.""" + simulated_event = _build_simulated_event( + formatted_message, + source_name, + receive_id_type, + chat_id, + open_id, + ) + + if hasattr(feishu_channel, "handle_webhook_event"): + await feishu_channel.handle_webhook_event(simulated_event) + logger.info( + "Feishu notify: queued for agent processing via webhook event", + ) + else: + logger.warning( + "Feishu notify: handle_webhook_event not available, " + "skipping agent processing", + ) + + +@router.post("/v1/notify/feishu") +async def notify_feishu( + request: Request, + message: Optional[str] = None, + source: Optional[str] = None, +) -> JSONResponse: + """Send a simple text message to Feishu. 
+ + Args: + message: The message content to send (from query param or body) + source: Source identifier for the message (default: "System") + + Environment: + FEISHU_NOTIFY_CHAT_ID: Target chat ID for group messages + FEISHU_NOTIFY_OPEN_ID: Target user ID for private messages + + Returns: + JSONResponse with code and message + + Examples: + # Query parameter with source + curl -X POST "http://localhost:8000/api/v1/notify/feishu\ +?message=Test message&source=Zabbix" + + # JSON body with source + curl -X POST http://localhost:8000/api/v1/notify/feishu \ + -H "Content-Type: application/json" \ + -d '{"message": "Test message", "source": "Zabbix"}' + + # Pipe input + echo "Server alert" | curl -X POST -d @- \ + http://localhost:8000/api/v1/notify/feishu + """ + # 1. Get target configuration + receive_id_type, receive_id = _get_target_id() + + # 2. Parse message from request + message, source = await _parse_message_from_request( + request, + message, + source, + ) + + # 3. Validate request + error_response = _validate_request(receive_id_type, message) + if error_response: + return error_response + + message = message.strip() + source_name = source or "System" + formatted_message = f"[{source_name}] {message}" + + # 4. Get FeishuChannel instance + feishu_channel = _get_feishu_channel(request) + if feishu_channel is None: + logger.error("Feishu notify: Feishu channel not found") + return JSONResponse( + content={"code": 503, "message": "Feishu channel not available"}, + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + ) + + # 5. Send direct message + display_id = receive_id[:20] if receive_id else "unknown" + logger.info( + "Feishu notify: sending message to " + f"{receive_id_type}={display_id}... " + f"message_len={len(formatted_message)}", + ) + + success, error_response = await _send_direct_message( + feishu_channel, + receive_id_type, + receive_id, + formatted_message, + ) + if not success: + return error_response + + # 6. 
Queue for agent processing + chat_id = os.environ.get("FEISHU_NOTIFY_CHAT_ID") + open_id = os.environ.get("FEISHU_NOTIFY_OPEN_ID") + await _queue_for_agent_processing( + feishu_channel, + formatted_message, + source_name, + receive_id_type, + chat_id, + open_id, + ) + + return JSONResponse( + content={ + "code": 0, + "message": "Direct message sent and queued " + "for agent processing", + }, + status_code=status.HTTP_200_OK, + ) diff --git a/src/copaw/app/routers/feishu_webhook.py b/src/copaw/app/routers/feishu_webhook.py new file mode 100644 index 000000000..0fbda9534 --- /dev/null +++ b/src/copaw/app/routers/feishu_webhook.py @@ -0,0 +1,384 @@ +# -*- coding: utf-8 -*- +"""Feishu (Lark) Webhook Router. + +Handles Feishu event subscriptions via HTTP webhook. +Supports challenge verification, signature verification, and event dispatching. +Reference: https://open.feishu.cn/document/ukTMukTMukTM/ +uYDNxYjL2QTM24iN0EjN/event-subscription-guide +""" + +import base64 +import hashlib +import hmac +import json +import logging +from typing import Any, Dict, Tuple + +from fastapi import APIRouter, HTTPException, Request, status +from fastapi.responses import JSONResponse + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +def verify_signature( + encrypt_key: str, + timestamp: str, + nonce: str, + body: str, + expected_signature: str, +) -> bool: + """Verify Feishu webhook request signature. 
+ + Args: + encrypt_key: The encryption key configured in Feishu app + timestamp: Request timestamp from header + nonce: Request nonce from header + body: Raw request body + expected_signature: Expected signature from header + + Returns: + True if signature is valid, False otherwise + + Reference: https://open.larksuite.com/document/server-docs/ + event-subscription/event-subscription-configure-/ + encrypt-key-encryption-configuration-case + Algorithm: SHA256(timestamp + nonce + encrypt_key + body), output as hex + """ + if not encrypt_key: + logger.warning( + "No encrypt_key configured, skipping signature verification", + ) + return True + + # Lark signature algorithm: + # SHA256(timestamp + nonce + encrypt_key + body) + # Note: This is NOT HMAC, just a simple SHA256 hash + content = f"{timestamp}{nonce}{encrypt_key}{body}" + computed = hashlib.sha256(content.encode("utf-8")).hexdigest() + + # Debug logging - use info level for troubleshooting + is_valid = hmac.compare_digest(computed, expected_signature) + if is_valid: + logger.info(f"Signature verification PASSED for timestamp={timestamp}") + else: + logger.warning( + "Signature verification FAILED: " + f"timestamp={timestamp}, nonce={nonce}, " + f"key_prefix={encrypt_key[:8]}..., " + f"body_len={len(body)}, " + f"computed={computed[:20]}..., " + f"expected={expected_signature[:20]}...", + ) + + return is_valid + + +def decrypt_body(encrypt_key: str, encrypted_body: str) -> str: + """Decrypt Feishu/Lark webhook payload using AES-256-CBC. 
+ + Args: + encrypt_key: The encryption key from Lark developer console + encrypted_body: Base64-encoded encrypted payload + + Returns: + Decrypted JSON string + + Reference: https://open.larksuite.com/document/uAjLw4CM/ukTMukTMukTM/ + event-subscription-guide/event-subscriptions/encrypt-keys + """ + if not encrypted_body: + return "" + + try: + # Try to import cryptography for AES decryption + from cryptography.hazmat.primitives.ciphers import ( + Cipher, + algorithms, + modes, + ) + from cryptography.hazmat.backends import default_backend + except ImportError: + logger.error( + "cryptography package is required for webhook decryption. " + "Install with: pip install cryptography", + ) + raise RuntimeError( + "cryptography package required for Lark webhook decryption", + ) from None + + # Decode the base64 encrypted body + encrypted_bytes = base64.b64decode(encrypted_body) + + # Derive AES key from encrypt_key using SHA-256 + # Lark uses the first 32 bytes of SHA256(encrypt_key) as the AES key + key = hashlib.sha256(encrypt_key.encode("utf-8")).digest() + + # Extract IV (first 16 bytes) and ciphertext + # Lark format: IV (16 bytes) + ciphertext + padding + iv = encrypted_bytes[:16] + ciphertext = encrypted_bytes[16:] + + # Create AES-256-CBC cipher + cipher = Cipher( + algorithms.AES(key), + modes.CBC(iv), + backend=default_backend(), + ) + decryptor = cipher.decryptor() + + # Decrypt + padded_plaintext = decryptor.update(ciphertext) + decryptor.finalize() + + # Remove PKCS7 padding + padding_len = padded_plaintext[-1] + plaintext = padded_plaintext[:-padding_len] + + return plaintext.decode("utf-8") + + +def _get_feishu_config(): + """Load and return Feishu configuration.""" + from ...config.utils import load_config + + config = load_config() + return config.channels.feishu + + +def _get_signature_key(feishu_config) -> str: + """Get signature key from config.""" + return ( + feishu_config.webhook_encrypt_key + or feishu_config.encrypt_key + or 
feishu_config.webhook_verification_token + or feishu_config.verification_token + ) + + +def _decrypt_payload_if_needed( + payload: Dict[str, Any], + feishu_config, +) -> Tuple[Dict[str, Any], bool]: + """Decrypt payload if encrypted. + + Returns: + Tuple of (decrypted_payload, is_url_verification). + """ + is_url_verification = payload.get("type") == "url_verification" + + if "encrypt" not in payload: + return payload, is_url_verification + + encrypt_key = ( + getattr(feishu_config, "webhook_encrypt_key", None) + or getattr(feishu_config, "encrypt_key", None) + or getattr(feishu_config, "verification_token", None) + ) + + if not encrypt_key: + return payload, is_url_verification + + try: + decrypted = decrypt_body(encrypt_key, payload["encrypt"]) + payload = json.loads(decrypted) + logger.info("Successfully decrypted webhook payload") + is_url_verification = payload.get("type") == "url_verification" + return payload, is_url_verification + except Exception as e: + logger.error(f"Failed to decrypt webhook payload: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Decryption failed", + ) from e + + +def _verify_webhook_signature( + feishu_config, + timestamp: str, + nonce: str, + body_str: str, + signature: str, +) -> None: + """Verify webhook signature, raise HTTPException if invalid.""" + skip_sig = getattr(feishu_config, "webhook_skip_signature_verify", False) + if skip_sig: + logger.warning( + "Skipping signature verification " + "(webhook_skip_signature_verify is enabled)", + ) + return + + signature_key = _get_signature_key(feishu_config) + if not signature_key or not signature: + return + + is_valid = verify_signature( + signature_key, + timestamp, + nonce, + body_str, + signature, + ) + + if is_valid: + return + + # Try using verification_token as fallback + verification_key = ( + feishu_config.webhook_verification_token + or feishu_config.verification_token + ) + if verification_key and verification_key != signature_key: + 
is_valid = verify_signature( + verification_key, + timestamp, + nonce, + body_str, + signature, + ) + if is_valid: + logger.info("Signature verified using verification_token") + return + + logger.error( + f"Webhook signature verification failed. " + f"Timestamp: {timestamp}, Nonce: {nonce}, " + f"Signature key prefix: {signature_key[:8]}..., " + f"Body length: {len(body_str)}", + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid signature", + ) + + +def _find_feishu_channel(request: Request): + """Find FeishuChannel instance from channel manager.""" + cm = getattr(request.app.state, "channel_manager", None) + if cm is None: + logger.error("Channel manager not initialized") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Channel manager not ready", + ) + + if hasattr(cm, "channels"): + channels = cm.channels + if isinstance(channels, dict): + channel_iter = channels.values() + else: + channel_iter = channels + for ch in channel_iter: + if getattr(ch, "channel", None) == "feishu": + return ch + + logger.error("Feishu channel not found") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Feishu channel not available", + ) + + +@router.post("/webhook/feishu") +async def handle_feishu_webhook(request: Request) -> JSONResponse: + """Handle Feishu webhook events. + + Handles: + 1. URL verification (challenge response) + 2. Event callbacks with signature verification + 3. 
Message dispatching to FeishuChannel + """ + # Get request headers for verification + timestamp = request.headers.get("X-Lark-Request-Timestamp", "") + nonce = request.headers.get("X-Lark-Request-Nonce", "") + signature = request.headers.get("X-Lark-Signature", "") + + # Read raw body + body = await request.body() + body_str = body.decode("utf-8") + + # Debug logging for troubleshooting + logger.info( + f"Feishu webhook request: timestamp={timestamp}, nonce={nonce}, " + f"signature={signature[:30] if signature else 'None'}..., " + f"body_len={len(body_str)}", + ) + logger.info("Feishu webhook full body for debug: %s", body_str) + + # Parse JSON payload + try: + payload: Dict[str, Any] = json.loads(body_str) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse webhook payload: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid JSON payload", + ) from e + + # Get config for decryption and verification + feishu_config = _get_feishu_config() + + # Handle decryption if needed + payload, is_url_verification = _decrypt_payload_if_needed( + payload, + feishu_config, + ) + + # Handle URL verification + if is_url_verification: + challenge = payload.get("challenge") + logger.info(f"Feishu webhook URL verification, challenge: {challenge}") + return JSONResponse( + content={"challenge": challenge}, + status_code=status.HTTP_200_OK, + ) + + # Check if webhook is enabled + if not feishu_config.webhook_enabled: + logger.warning("Feishu webhook is disabled in config") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Webhook not enabled", + ) + + # Verify signature + _verify_webhook_signature( + feishu_config, + timestamp, + nonce, + body_str, + signature, + ) + + # Log event info + header = payload.get("header", {}) + event_id = header.get("event_id", "") + logger.debug(f"Received Feishu webhook event: {event_id}") + + # Find FeishuChannel and dispatch event + feishu_channel = 
_find_feishu_channel(request) + + try: + await feishu_channel.handle_webhook_event(payload) + except Exception as e: + logger.exception(f"Error handling webhook event: {e}") + # Return 200 to prevent Feishu from retrying + + return JSONResponse( + content={"code": 0, "msg": "success"}, + status_code=status.HTTP_200_OK, + ) + + +@router.get("/webhook/feishu/health") +async def feishu_webhook_health(request: Request) -> JSONResponse: + """Health check endpoint for Feishu webhook.""" + cm = getattr(request.app.state, "channel_manager", None) + return JSONResponse( + content={ + "status": "ok", + "webhook_enabled": cm is not None, + }, + status_code=status.HTTP_200_OK, + ) diff --git a/tests/channels/test_feishu.py b/tests/channels/test_feishu.py new file mode 100644 index 000000000..fba128c3e --- /dev/null +++ b/tests/channels/test_feishu.py @@ -0,0 +1,446 @@ +# -*- coding: utf-8 -*- +"""Tests for Feishu (Lark) channel. + +This module tests: +1. Event handling logic (_on_message) +2. Token retrieval and caching (_get_tenant_access_token) +3. Message routing (_route_from_handle) +4. Session ID generation logic +5. 
Receive ID storage and loading +""" +# pylint: disable=protected-access +# protected-access: tests need to access internal methods + +import sys +from types import ModuleType +from unittest.mock import Mock, patch + +import pytest + +# Create mock modules to avoid import dependencies +_mock_agentscope = ModuleType("agentscope_runtime") +_mock_agentscope.engine = ModuleType("agentscope_runtime.engine") +_mock_agentscope.engine.schemas = ModuleType( + "agentscope_runtime.engine.schemas", +) +_mock_agentscope.engine.schemas.agent_schemas = Mock() +_mock_agentscope.engine.schemas.agent_schemas.FileContent = Mock +_mock_agentscope.engine.schemas.agent_schemas.ImageContent = Mock +_mock_agentscope.engine.schemas.agent_schemas.TextContent = Mock +sys.modules["agentscope_runtime"] = _mock_agentscope +sys.modules["agentscope_runtime.engine"] = _mock_agentscope.engine +sys.modules[ + "agentscope_runtime.engine.schemas" +] = _mock_agentscope.engine.schemas +sys.modules[ + "agentscope_runtime.engine.schemas.agent_schemas" +] = _mock_agentscope.engine.schemas.agent_schemas + +# Mock other problematic imports +sys.modules["lark_oapi"] = Mock() +sys.modules["lark_oapi.ws"] = Mock() +sys.modules["lark_oapi.ws.client"] = Mock() + +# pylint: disable=wrong-import-position +# Import the modules we're testing +from copaw.app.channels.feishu.constants import ( # noqa: E402 + FEISHU_PROCESSED_IDS_MAX, + FEISHU_SESSION_ID_SUFFIX_LEN, + FEISHU_TOKEN_REFRESH_BEFORE_SECONDS, +) +from copaw.app.channels.feishu.utils import ( # noqa: E402 + extract_json_key, + normalize_feishu_md, + sender_display_string, + short_session_id_from_full_id, +) + +# pylint: enable=wrong-import-position + + +class TestFeishuUtils: + """Test utility functions.""" + + def test_short_session_id_from_full_id(self) -> None: + """Test session ID shortening.""" + full_id = "oc_1234567890abcdef" + result = short_session_id_from_full_id(full_id) + assert len(result) == FEISHU_SESSION_ID_SUFFIX_LEN + assert result == 
full_id[-FEISHU_SESSION_ID_SUFFIX_LEN:] + + def test_short_session_id_from_full_id_short_input(self) -> None: + """Test session ID with short input.""" + short_id = "abc" + result = short_session_id_from_full_id(short_id) + assert result == short_id + + def test_sender_display_string_with_nickname(self) -> None: + """Test sender display with nickname.""" + result = sender_display_string("张三", "ou_1234567890") + assert result == "张三#7890" + + def test_sender_display_string_without_nickname(self) -> None: + """Test sender display without nickname.""" + result = sender_display_string(None, "ou_1234567890") + assert result == "unknown#7890" + + def test_sender_display_string_short_sender_id(self) -> None: + """Test sender display with short sender ID.""" + result = sender_display_string("李四", "ab") + assert result == "李四#ab" + + def test_extract_json_key_found(self) -> None: + """Test extracting key from JSON.""" + content = '{"text": "hello world"}' + result = extract_json_key(content, "text") + assert result == "hello world" + + def test_extract_json_key_not_found(self) -> None: + """Test extracting missing key.""" + content = '{"other": "value"}' + result = extract_json_key(content, "text") + assert result is None + + def test_extract_json_key_invalid_json(self) -> None: + """Test extracting from invalid JSON.""" + content = "not json" + result = extract_json_key(content, "text") + assert result is None + + def test_extract_json_key_multiple_keys(self) -> None: + """Test extracting first present key from multiple options.""" + content = '{"imageKey": "img_123"}' + result = extract_json_key(content, "image_key", "imageKey") + assert result == "img_123" + + def test_normalize_feishu_md_code_fence(self) -> None: + """Test markdown normalization adds newline before code fence.""" + text = "代码如下:```python\nprint(1)\n```" + result = normalize_feishu_md(text) + assert "\n```" in result + + def test_normalize_feishu_md_empty(self) -> None: + """Test markdown normalization 
with empty input.""" + assert normalize_feishu_md("") == "" + assert normalize_feishu_md(None) is None + + +class TestFeishuConstants: + """Test constants are properly defined.""" + + def test_session_id_suffix_length(self) -> None: + """Test session ID suffix length constant.""" + assert FEISHU_SESSION_ID_SUFFIX_LEN == 8 + + def test_processed_ids_max(self) -> None: + """Test processed IDs max constant.""" + assert FEISHU_PROCESSED_IDS_MAX == 1000 + + def test_token_refresh_before_seconds(self) -> None: + """Test token refresh buffer constant.""" + assert FEISHU_TOKEN_REFRESH_BEFORE_SECONDS == 60 + + +# Import channel class with mocked dependencies +with patch.dict( + "sys.modules", + { + "lark_oapi": Mock(), + "lark_oapi.ws": Mock(), + "lark_oapi.ws.client": Mock(), + }, +): + from copaw.app.channels.feishu.channel import FeishuChannel + + +class TestFeishuChannelRoute: + """Test message routing logic.""" + + @pytest.fixture + def channel(self): + """Create a mock Feishu channel.""" + process_mock = Mock() + return FeishuChannel( + process=process_mock, + enabled=True, + app_id="test_app_id", + app_secret="test_app_secret", + bot_prefix="[BOT] ", + ) + + def test_route_from_handle_session_key(self, channel) -> None: + """Test routing from session key handle.""" + result = channel._route_from_handle("feishu:sw:abc123") + assert result == {"session_key": "abc123"} + + def test_route_from_handle_chat_id(self, channel) -> None: + """Test routing from chat_id handle.""" + result = channel._route_from_handle("feishu:chat_id:oc_123") + assert result == { + "receive_id_type": "chat_id", + "receive_id": "oc_123", + } + + def test_route_from_handle_open_id(self, channel) -> None: + """Test routing from open_id handle.""" + result = channel._route_from_handle("feishu:open_id:ou_123") + assert result == { + "receive_id_type": "open_id", + "receive_id": "ou_123", + } + + def test_route_from_handle_raw_chat_id(self, channel) -> None: + """Test routing from raw chat_id (oc_ 
prefix).""" + result = channel._route_from_handle("oc_123456") + assert result == { + "receive_id_type": "chat_id", + "receive_id": "oc_123456", + } + + def test_route_from_handle_raw_open_id(self, channel) -> None: + """Test routing from raw open_id (ou_ prefix).""" + result = channel._route_from_handle("ou_123456") + assert result == { + "receive_id_type": "open_id", + "receive_id": "ou_123456", + } + + def test_route_from_handle_unknown(self, channel) -> None: + """Test routing from unknown handle defaults to open_id.""" + result = channel._route_from_handle("random_id") + assert result == { + "receive_id_type": "open_id", + "receive_id": "random_id", + } + + def test_resolve_session_id_group_chat(self, channel) -> None: + """Test session ID resolution for group chat.""" + meta = { + "feishu_chat_id": "oc_1234567890abcdef", + "feishu_chat_type": "group", + } + result = channel.resolve_session_id("ou_sender", meta) + # Should use chat_id suffix for group chats (8 chars) + assert result == "90abcdef" + + def test_resolve_session_id_p2p(self, channel) -> None: + """Test session ID resolution for p2p chat.""" + meta = { + "feishu_chat_id": "oc_123", + "feishu_chat_type": "p2p", + } + result = channel.resolve_session_id("ou_abcdef123456", meta) + # Should use sender_id suffix for p2p (8 chars) + assert result == "ef123456" + + def test_to_handle_from_target_session(self, channel) -> None: + """Test to_handle generation with session_id.""" + result = channel.to_handle_from_target( + user_id="ou_123", + session_id="abc123", + ) + assert result == "feishu:sw:abc123" + + def test_to_handle_from_target_no_session(self, channel) -> None: + """Test to_handle generation without session_id.""" + result = channel.to_handle_from_target( + user_id="ou_123", + session_id="", + ) + assert result == "feishu:open_id:ou_123" + + +class TestFeishuChannelDeduplication: + """Test message deduplication logic.""" + + @pytest.fixture + def channel(self): + """Create a mock Feishu 
channel.""" + process_mock = Mock() + channel = FeishuChannel( + process=process_mock, + enabled=True, + app_id="test_app_id", + app_secret="test_app_secret", + bot_prefix="[BOT] ", + ) + return channel + + def test_processed_message_ids_deduplication(self, channel) -> None: + """Test message ID deduplication with LRU behavior. + + Simulates the trimming logic in _on_message: + while len(self._processed_message_ids) > FEISHU_PROCESSED_IDS_MAX: + self._processed_message_ids.popitem(last=False) + """ + # Add messages up to limit + for i in range(FEISHU_PROCESSED_IDS_MAX + 10): + msg_id = f"msg_{i:04d}" + channel._processed_message_ids[msg_id] = None + # Simulate the trimming logic from _on_message + while ( + len(channel._processed_message_ids) > FEISHU_PROCESSED_IDS_MAX + ): + channel._processed_message_ids.popitem(last=False) + + # Should have trimmed to max size + assert len(channel._processed_message_ids) == FEISHU_PROCESSED_IDS_MAX + # Oldest items should be removed (msg_0010 to msg_0000 are trimmed) + assert "msg_0000" not in channel._processed_message_ids + assert "msg_0009" not in channel._processed_message_ids + # Newest items should remain + assert "msg_1009" in channel._processed_message_ids + assert "msg_1000" in channel._processed_message_ids + + def test_is_message_processed(self, channel) -> None: + """Test checking if message was already processed.""" + msg_id = "om_1234567890" + channel._processed_message_ids[msg_id] = None + + # Should be able to check existence + assert msg_id in channel._processed_message_ids + + +class TestFeishuChannelBuildPostContent: + """Test post content building.""" + + @pytest.fixture + def channel(self): + """Create a mock Feishu channel.""" + process_mock = Mock() + return FeishuChannel( + process=process_mock, + enabled=True, + app_id="test_app_id", + app_secret="test_app_secret", + bot_prefix="[BOT] ", + ) + + def test_build_post_content_text_only(self, channel) -> None: + """Test building post content with text 
only.""" + result = channel._build_post_content("Hello world", []) + + assert "zh_cn" in result + assert result["zh_cn"]["content"][0][0]["tag"] == "md" + assert result["zh_cn"]["content"][0][0]["text"] == "Hello world" + + def test_build_post_content_with_images(self, channel) -> None: + """Test building post content with images.""" + result = channel._build_post_content("See image:", ["img_key_123"]) + + assert len(result["zh_cn"]["content"]) == 2 + assert result["zh_cn"]["content"][1][0]["tag"] == "img" + assert result["zh_cn"]["content"][1][0]["image_key"] == "img_key_123" + + def test_build_post_content_empty(self, channel) -> None: + """Test building post content with empty input.""" + result = channel._build_post_content("", []) + + # Should have at least one row with [empty] + assert result["zh_cn"]["content"][0][0]["text"] == "[empty]" + + +class TestFeishuChannelParsePostContent: + """Test parsing incoming post (rich text) content.""" + + @pytest.fixture + def channel(self): + """Create a mock Feishu channel.""" + process_mock = Mock() + return FeishuChannel( + process=process_mock, + enabled=True, + app_id="test_app_id", + app_secret="test_app_secret", + bot_prefix="[BOT] ", + ) + + @pytest.mark.asyncio + async def test_parse_post_content_text_only(self, channel) -> None: + """Test parsing post content with text only.""" + content_raw = '{"content": [[{"tag": "text", "text": "Hello world"}]]}' + result = await channel._parse_post_content("msg_123", content_raw) + + assert result["text"] == "Hello world" + assert result["image_urls"] == [] + + @pytest.mark.asyncio + async def test_parse_post_content_with_title(self, channel) -> None: + """Test parsing post content with title.""" + content_raw = ( + '{"title": "My Title", "content": ' + '[[{"tag": "text", "text": "Body text"}]]}' + ) + result = await channel._parse_post_content("msg_123", content_raw) + + assert result["text"] == "My Title\nBody text" + assert result["image_urls"] == [] + + 
@pytest.mark.asyncio + async def test_parse_post_content_markdown(self, channel) -> None: + """Test parsing post content with markdown tag.""" + content_raw = '{"content": [[{"tag": "md", "text": "**Bold** text"}]]}' + result = await channel._parse_post_content("msg_123", content_raw) + + assert result["text"] == "**Bold** text" + assert result["image_urls"] == [] + + @pytest.mark.asyncio + async def test_parse_post_content_multiple_rows(self, channel) -> None: + """Test parsing post content with multiple rows.""" + content_raw = ( + '{"content": [[{"tag": "text", "text": "Line 1"}],' + ' [{"tag": "text", "text": "Line 2"}]]}' + ) + result = await channel._parse_post_content("msg_123", content_raw) + + assert result["text"] == "Line 1\nLine 2" + assert result["image_urls"] == [] + + @pytest.mark.asyncio + async def test_parse_post_content_mixed_items_in_row( + self, + channel, + ) -> None: + """Test parsing post content with mixed items in a row.""" + content_raw = ( + '{"content": [[{"tag": "text", "text": "Hello "}, ' + '{"tag": "text", "text": "world"}]]}' + ) + result = await channel._parse_post_content("msg_123", content_raw) + + assert result["text"] == "Hello world" + assert result["image_urls"] == [] + + @pytest.mark.asyncio + async def test_parse_post_content_invalid_json(self, channel) -> None: + """Test parsing invalid JSON content.""" + result = await channel._parse_post_content("msg_123", "not valid json") + + assert result["text"] == "" + assert result["image_urls"] == [] + + @pytest.mark.asyncio + async def test_parse_post_content_empty_content(self, channel) -> None: + """Test parsing empty content.""" + result = await channel._parse_post_content("msg_123", "{}") + + assert result["text"] == "" + assert result["image_urls"] == [] + + +class TestFeishuChannelConfiguration: + """Test channel configuration.""" + + def test_channel_disabled_by_default(self) -> None: + """Test channel is disabled by default from env.""" + import os + + # Ensure env var 
is not set + if "FEISHU_CHANNEL_ENABLED" in os.environ: + del os.environ["FEISHU_CHANNEL_ENABLED"] + + process_mock = Mock() + channel = FeishuChannel.from_env(process_mock) + + assert channel.enabled is False diff --git a/tests/test_auth_middleware.py b/tests/test_auth_middleware.py new file mode 100644 index 000000000..b0f025deb --- /dev/null +++ b/tests/test_auth_middleware.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +"""Tests for BasicAuthMiddleware.""" + +import base64 +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from copaw.app.auth_middleware import BasicAuthMiddleware + + +def create_test_app( + username="admin", + password="secret", + excluded_paths=None, + enabled=True, +): + """Create a test app with auth middleware.""" + app = FastAPI() + + @app.get("/test") + def test_endpoint(): + return {"message": "ok"} + + @app.get("/webhook/feishu") + def feishu_webhook(): + return {"message": "webhook ok"} + + @app.get("/webhook/feishu/health") + def feishu_health(): + return {"status": "healthy"} + + if enabled and password: + app.add_middleware( + BasicAuthMiddleware, + username=username, + password=password, + excluded_paths=excluded_paths or ["/webhook/feishu"], + ) + + return app + + +class TestBasicAuthMiddleware: + """Test cases for BasicAuthMiddleware.""" + + def test_no_auth_returns_401(self): + """Test that requests without auth return 401.""" + app = create_test_app() + client = TestClient(app) + + response = client.get("/test") + + assert response.status_code == 401 + assert "WWW-Authenticate" in response.headers + assert response.headers["WWW-Authenticate"] == "Basic" + + def test_valid_auth_returns_200(self): + """Test that requests with valid auth return 200.""" + app = create_test_app() + client = TestClient(app) + + credentials = base64.b64encode(b"admin:secret").decode("utf-8") + response = client.get( + "/test", + headers={"Authorization": f"Basic {credentials}"}, + ) + + assert response.status_code == 200 + 
assert response.json() == {"message": "ok"} + + def test_invalid_auth_returns_401(self): + """Test that requests with invalid auth return 401.""" + app = create_test_app() + client = TestClient(app) + + credentials = base64.b64encode(b"admin:wrongpassword").decode("utf-8") + response = client.get( + "/test", + headers={"Authorization": f"Basic {credentials}"}, + ) + + assert response.status_code == 401 + + def test_excluded_path_no_auth_required(self): + """Test that excluded paths don't require auth.""" + app = create_test_app() + client = TestClient(app) + + response = client.get("/webhook/feishu") + + assert response.status_code == 200 + assert response.json() == {"message": "webhook ok"} + + def test_excluded_path_subpath_no_auth_required(self): + """Test that subpaths of excluded paths don't require auth.""" + app = create_test_app() + client = TestClient(app) + + response = client.get("/webhook/feishu/health") + + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + def test_disabled_auth_allows_all(self): + """Test that when password is empty, auth is disabled.""" + app = create_test_app(password="", enabled=True) + client = TestClient(app) + + response = client.get("/test") + + assert response.status_code == 200 + assert response.json() == {"message": "ok"} + + def test_wrong_username_returns_401(self): + """Test that wrong username returns 401.""" + app = create_test_app() + client = TestClient(app) + + credentials = base64.b64encode(b"wronguser:secret").decode("utf-8") + response = client.get( + "/test", + headers={"Authorization": f"Basic {credentials}"}, + ) + + assert response.status_code == 401 + + def test_custom_credentials_work(self): + """Test that custom username/password work.""" + app = create_test_app(username="customuser", password="custompass") + client = TestClient(app) + + credentials = base64.b64encode(b"customuser:custompass").decode( + "utf-8", + ) + response = client.get( + "/test", + 
headers={"Authorization": f"Basic {credentials}"}, + ) + + assert response.status_code == 200 diff --git a/tests/test_feishu_webhook.py b/tests/test_feishu_webhook.py new file mode 100644 index 000000000..b91ea16dc --- /dev/null +++ b/tests/test_feishu_webhook.py @@ -0,0 +1,435 @@ +# -*- coding: utf-8 -*- +"""Tests for Feishu webhook functionality.""" +# pylint: disable=redefined-outer-name,protected-access +# redefined-outer-name: pytest fixtures are reused across tests +# protected-access: tests need to access internal methods + +import base64 +import hashlib +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI, status +from fastapi.testclient import TestClient + +from copaw.app.routers.feishu_webhook import ( + decrypt_body, + router, + verify_signature, +) + + +@pytest.fixture +def app(): + """Create test FastAPI app with webhook router.""" + test_app = FastAPI() + test_app.include_router(router) + + # Mock app.state + test_app.state.channel_manager = MagicMock() + + return test_app + + +@pytest.fixture +def client(app): + """Create test client.""" + return TestClient(app) + + +@pytest.fixture +def mock_config(): + """Mock FeishuConfig for testing.""" + config = MagicMock() + feishu_config = config.channels.feishu + feishu_config.webhook_enabled = True + feishu_config.webhook_verification_token = "test_token" + feishu_config.verification_token = "" + feishu_config.webhook_encrypt_key = "" + feishu_config.encrypt_key = "" + feishu_config.webhook_skip_signature_verify = False + return config + + +class TestVerifySignature: + """Test signature verification.""" + + def test_verify_signature_valid(self): + """Test signature verification with valid signature.""" + encrypt_key = "test_key" + timestamp = "1234567890" + nonce = "test_nonce" + body = '{"test": "data"}' + + # Generate expected signature using Lark algorithm: + # SHA256(timestamp + nonce + encrypt_key + body) as hex + content = 
f"{timestamp}{nonce}{encrypt_key}{body}" + expected_signature = hashlib.sha256( + content.encode("utf-8"), + ).hexdigest() + + # Verify + result = verify_signature( + encrypt_key, + timestamp, + nonce, + body, + expected_signature, + ) + assert result is True + + def test_verify_signature_invalid(self): + """Test signature verification with invalid signature.""" + encrypt_key = "test_key" + timestamp = "1234567890" + nonce = "test_nonce" + body = '{"test": "data"}' + wrong_signature = "wrong_signature" + + result = verify_signature( + encrypt_key, + timestamp, + nonce, + body, + wrong_signature, + ) + assert result is False + + def test_verify_signature_no_key(self): + """Test signature verification with no key (should pass).""" + result = verify_signature( + "", + "timestamp", + "nonce", + "body", + "signature", + ) + assert result is True + + +class TestChallengeVerification: + """Test challenge verification endpoint.""" + + def test_challenge_response(self, client): + """Test URL verification challenge response.""" + challenge = "test_challenge_123" + payload = { + "type": "url_verification", + "challenge": challenge, + } + + response = client.post( + "/webhook/feishu", + json=payload, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.json() == {"challenge": challenge} + + +class TestWebhookEventHandling: + """Test webhook event handling.""" + + def test_webhook_disabled(self, client, mock_config): + """Test webhook returns 503 when disabled.""" + mock_config.channels.feishu.webhook_enabled = False + + with patch( + "copaw.app.routers.feishu_webhook._get_feishu_config", + return_value=mock_config.channels.feishu, + ): + payload = { + "schema": "2.0", + "header": {"event_id": "test_event"}, + "event": {"message": {}}, + } + + response = client.post("/webhook/feishu", json=payload) + + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + assert "not enabled" in response.json()["detail"].lower() + + def 
test_invalid_signature(self, client, mock_config): + """Test webhook returns 403 for invalid signature.""" + mock_config.channels.feishu.webhook_verification_token = ( + "correct_token" + ) + + with patch( + "copaw.app.routers.feishu_webhook._get_feishu_config", + return_value=mock_config.channels.feishu, + ): + payload = { + "schema": "2.0", + "header": {"event_id": "test_event"}, + "event": {"message": {}}, + } + + response = client.post( + "/webhook/feishu", + json=payload, + headers={ + "X-Lark-Request-Timestamp": "1234567890", + "X-Lark-Request-Nonce": "nonce", + "X-Lark-Signature": "invalid_signature", + }, + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_invalid_json(self, client): + """Test webhook returns 400 for invalid JSON.""" + response = client.post( + "/webhook/feishu", + data="invalid json", + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_channel_not_found(self, client, mock_config): + """Test webhook returns 503 when Feishu channel not found.""" + # Mock empty channels + app = client.app + app.state.channel_manager = MagicMock() + app.state.channel_manager.channels = {} + + with patch( + "copaw.app.routers.feishu_webhook._get_feishu_config", + return_value=mock_config.channels.feishu, + ): + payload = { + "schema": "2.0", + "header": {"event_id": "test_event"}, + "event": { + "message": { + "message_id": "test_msg", + "chat_type": "p2p", + "message_type": "text", + "content": '{"text": "hello"}', + }, + "sender": { + "sender_id": {"open_id": "test_user"}, + "sender_type": "user", + }, + }, + } + + response = client.post("/webhook/feishu", json=payload) + + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + assert "not available" in response.json()["detail"].lower() + + @pytest.mark.asyncio + async def test_successful_event_dispatch(self, client, mock_config): + """Test successful event dispatch to channel.""" + # Setup mock 
channel + mock_channel = MagicMock() + mock_channel.channel = "feishu" + mock_channel.handle_webhook_event = AsyncMock() + + app = client.app + app.state.channel_manager = MagicMock() + app.state.channel_manager.channels = {"feishu": mock_channel} + + with patch( + "copaw.app.routers.feishu_webhook._get_feishu_config", + return_value=mock_config.channels.feishu, + ): + payload = { + "schema": "2.0", + "header": {"event_id": "test_event_123"}, + "event": { + "message": { + "message_id": "om_test_123", + "chat_id": "oc_test_chat", + "chat_type": "p2p", + "message_type": "text", + "content": '{"text": "hello from webhook"}', + }, + "sender": { + "sender_id": {"open_id": "ou_test_user"}, + "sender_type": "user", + "name": "Test User", + }, + }, + } + + response = client.post("/webhook/feishu", json=payload) + + assert response.status_code == status.HTTP_200_OK + assert response.json() == {"code": 0, "msg": "success"} + mock_channel.handle_webhook_event.assert_called_once() + + +class TestHealthEndpoint: + """Test health check endpoint.""" + + def test_health_check(self, client): + """Test health check returns correct status.""" + response = client.get("/webhook/feishu/health") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["status"] == "ok" + assert "webhook_enabled" in data + + +class TestWebhookEventFormat: + """Test webhook event format conversion.""" + + @pytest.mark.asyncio + async def test_text_message_format(self): + """Test text message format conversion.""" + from copaw.app.channels.feishu.channel import FeishuChannel + + channel = FeishuChannel( + process=MagicMock(), + enabled=True, + app_id="test_app_id", + app_secret="test_app_secret", + bot_prefix="test", + ) + channel._enqueue = MagicMock() + + # Mock _add_reaction to avoid API calls + channel._add_reaction = AsyncMock() + + payload = { + "event": { + "message": { + "message_id": "om_test_msg", + "chat_id": "oc_test_chat", + "chat_type": "p2p", + 
"message_type": "text", + "content": '{"text": "Test message"}', + }, + "sender": { + "sender_id": {"open_id": "ou_test_user"}, + "sender_type": "user", + "name": "Test User", + }, + }, + } + + await channel.handle_webhook_event(payload) + + # Verify the event was processed + channel._enqueue.assert_called_once() + native = channel._enqueue.call_args[0][0] + assert native["channel_id"] == "feishu" + assert "Test User" in native["sender_id"] + assert native["meta"]["feishu_message_id"] == "om_test_msg" + assert native["meta"]["feishu_chat_type"] == "p2p" + + @pytest.mark.asyncio + async def test_group_message_format(self): + """Test group message format conversion.""" + from copaw.app.channels.feishu.channel import FeishuChannel + + channel = FeishuChannel( + process=MagicMock(), + enabled=True, + app_id="test_app_id", + app_secret="test_app_secret", + bot_prefix="test", + ) + channel._enqueue = MagicMock() + channel._add_reaction = AsyncMock() + + payload = { + "event": { + "message": { + "message_id": "om_test_msg", + "chat_id": "oc_group_chat", + "chat_type": "group", + "message_type": "text", + "content": '{"text": "Group message"}', + }, + "sender": { + "sender_id": {"open_id": "ou_test_user"}, + "sender_type": "user", + "name": "Test User", + }, + }, + } + + await channel.handle_webhook_event(payload) + + native = channel._enqueue.call_args[0][0] + assert native["meta"]["feishu_chat_type"] == "group" + assert native["meta"]["feishu_receive_id_type"] == "chat_id" + + @pytest.mark.asyncio + async def test_empty_event(self): + """Test handling of empty event.""" + from copaw.app.channels.feishu.channel import FeishuChannel + + channel = FeishuChannel( + process=MagicMock(), + enabled=True, + app_id="test_app_id", + app_secret="test_app_secret", + bot_prefix="test", + ) + channel._enqueue = MagicMock() + + # Empty event + payload = {"event": {}} + + await channel.handle_webhook_event(payload) + + # Should not call enqueue + channel._enqueue.assert_not_called() + + 
@pytest.mark.asyncio + async def test_bot_message_filtered(self): + """Test that bot messages are filtered out.""" + from copaw.app.channels.feishu.channel import FeishuChannel + + channel = FeishuChannel( + process=MagicMock(), + enabled=True, + app_id="test_app_id", + app_secret="test_app_secret", + bot_prefix="test", + ) + channel._enqueue = MagicMock() + channel._add_reaction = AsyncMock() + + payload = { + "event": { + "message": { + "message_id": "om_test_msg", + "chat_id": "oc_test_chat", + "chat_type": "p2p", + "message_type": "text", + "content": '{"text": "Bot message"}', + }, + "sender": { + "sender_id": {"open_id": "ou_bot_user"}, + "sender_type": "bot", # Bot sender + "name": "Bot", + }, + }, + } + + await channel.handle_webhook_event(payload) + + # Should not call enqueue for bot messages + channel._enqueue.assert_not_called() + + +class TestDecryptBody: + """Test body decryption.""" + + def test_decrypt_empty(self): + """Test decrypting empty body.""" + result = decrypt_body("key", "") + assert result == "" + + def test_decrypt_invalid_data(self): + """Test decrypt with invalid data raises error.""" + # Invalid encrypted data (too short for IV) should raise ValueError + with pytest.raises(ValueError): + decrypt_body("key", base64.b64encode(b"test").decode()) diff --git a/tests/tools/test_read_media.py b/tests/tools/test_read_media.py new file mode 100644 index 000000000..404545bbd --- /dev/null +++ b/tests/tools/test_read_media.py @@ -0,0 +1,470 @@ +# -*- coding: utf-8 -*- +# pylint: disable=redefined-outer-name,unused-import,line-too-long +"""Unit tests for read_media tool.""" +import base64 +import os +from pathlib import Path +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest + +from copaw.agents.tools.read_media import ( + read_media, + _parse_source, + _get_media_type, + _get_file_category, + MAX_FILE_SIZE, + SUPPORTED_FORMATS, + IMAGE_EXTENSIONS, + VIDEO_EXTENSIONS, + AUDIO_EXTENSIONS, +) + + +# Test fixtures 
@pytest.fixture
def temp_dir(tmp_path: Path):
    """Provide a per-test temporary directory."""
    return tmp_path


@pytest.fixture
def sample_png(temp_dir: Path):
    """Write a minimal valid PNG (1x1 transparent pixel) to disk."""
    png_data = base64.b64decode(
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",  # noqa: E501
    )
    png_path = temp_dir / "test.png"
    png_path.write_bytes(png_data)
    return png_path


@pytest.fixture
def sample_jpg(temp_dir: Path):
    """Write a minimal valid JPEG to disk."""
    jpg_data = base64.b64decode(
        "/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAP///////////////////////////////////"
        "////////////////////////////////////////////////////////////////"
        "////////////////////////////////////////////////wAALCA"
        "ACAgBARE"
        "A/8QAFQAAAQUBAQEAAAAAAAAAAAAAAAIDAQUFBgcICf/aAAgBAQABBQKb"
        "pqD/2gAIAQAAAQUFpmmo/9oACAEBAAEFAlD/2gAIAQIBAQFwUf/aAAgBAwEB"
        "AXBR/9oADAMBEQCEAaEAAX//2Q==",
    )
    jpg_path = temp_dir / "test.jpg"
    jpg_path.write_bytes(jpg_data)
    return jpg_path


@pytest.fixture
def sample_mp4(temp_dir: Path):
    """Write a minimal MP4-like file (just an ftyp box) to disk."""
    mp4_data = b"\x00\x00\x00\x20ftypisomisommp41"
    mp4_path = temp_dir / "test.mp4"
    mp4_path.write_bytes(mp4_data)
    return mp4_path


@pytest.fixture
def sample_mp3(temp_dir: Path):
    """Write a minimal MP3-like file (MPEG sync word + padding) to disk."""
    mp3_data = b"\xff\xfb\x90\x00" + b"\x00" * 100
    mp3_path = temp_dir / "test.mp3"
    mp3_path.write_bytes(mp3_data)
    return mp3_path


class TestGetMediaType:
    """Tests for _get_media_type function."""

    def test_image_extensions(self):
        """Image suffixes map to their MIME types."""
        assert _get_media_type("test.png") == "image/png"
        assert _get_media_type("test.jpg") == "image/jpeg"
        assert _get_media_type("test.jpeg") == "image/jpeg"
        assert _get_media_type("test.gif") == "image/gif"
        assert _get_media_type("test.webp") == "image/webp"
        assert _get_media_type("test.bmp") == "image/bmp"

    def test_video_extensions(self):
        """Video suffixes map to their MIME types."""
        assert _get_media_type("test.mp4") == "video/mp4"
        assert _get_media_type("test.avi") == "video/x-msvideo"
        assert _get_media_type("test.mov") == "video/quicktime"
        assert _get_media_type("test.mkv") == "video/x-matroska"
        assert _get_media_type("test.webm") == "video/webm"

    def test_audio_extensions(self):
        """Audio suffixes map to their MIME types."""
        assert _get_media_type("test.mp3") == "audio/mpeg"
        assert _get_media_type("test.wav") == "audio/wav"
        assert _get_media_type("test.aac") == "audio/aac"
        assert _get_media_type("test.ogg") == "audio/ogg"
        assert _get_media_type("test.m4a") == "audio/mp4"
        assert _get_media_type("test.flac") == "audio/flac"

    def test_unsupported_format(self):
        """Non-media suffixes yield None."""
        assert _get_media_type("test.txt") is None
        assert _get_media_type("test.pdf") is None
        assert _get_media_type("test.svg") is None

    def test_no_extension(self):
        """A bare filename without a suffix yields None."""
        assert _get_media_type("testfile") is None


class TestGetFileCategory:
    """Tests for _get_file_category function."""

    def test_image_category(self):
        """Image suffixes categorize as "image", case-insensitively."""
        assert _get_file_category("test.png") == "image"
        assert _get_file_category("/path/to/photo.jpg") == "image"
        assert _get_file_category("image.GIF") == "image"

    def test_video_category(self):
        """Video suffixes categorize as "video", case-insensitively."""
        assert _get_file_category("test.mp4") == "video"
        assert _get_file_category("/path/to/movie.mov") == "video"
        assert _get_file_category("video.AVI") == "video"

    def test_audio_category(self):
        """Audio suffixes categorize as "audio", case-insensitively."""
        assert _get_file_category("test.mp3") == "audio"
        assert _get_file_category("/path/to/song.wav") == "audio"
        assert _get_file_category("audio.MP3") == "audio"

    def test_unknown_category(self):
        """Anything else categorizes as "unknown"."""
        assert _get_file_category("test.txt") == "unknown"
        assert _get_file_category("test.pdf") == "unknown"


class TestParseSource:
    """Tests for _parse_source function."""

    def test_http_url(self):
        """HTTP URLs pass through unchanged."""
        source_type, parsed, error = _parse_source(
            "http://example.com/image.png",
        )
        assert source_type == "http_url"
        assert parsed == "http://example.com/image.png"
        assert error == ""

    def test_https_url(self):
        """HTTPS URLs pass through unchanged."""
        source_type, parsed, error = _parse_source(
            "https://example.com/video.mp4",
        )
        assert source_type == "http_url"
        assert parsed == "https://example.com/video.mp4"
        assert error == ""

    def test_file_url(self):
        """file:// URLs are reduced to plain paths."""
        source_type, parsed, error = _parse_source(
            "file:///Users/test/media.mp3",
        )
        assert source_type == "file_url"
        assert parsed == "/Users/test/media.mp3"
        assert error == ""

    def test_file_url_encoded(self):
        """Percent-encoded characters in file:// URLs are decoded."""
        source_type, parsed, error = _parse_source(
            "file:///Users/test%20folder/video.mp4",
        )
        assert source_type == "file_url"
        assert parsed == "/Users/test folder/video.mp4"
        assert error == ""

    def test_local_path(self):
        """Absolute paths are treated as local files."""
        source_type, parsed, error = _parse_source("/Users/test/audio.mp3")
        assert source_type == "local"
        assert parsed == "/Users/test/audio.mp3"
        assert error == ""

    def test_relative_path(self):
        """Relative paths are treated as local files."""
        source_type, parsed, error = _parse_source("video.mp4")
        assert source_type == "local"
        assert parsed == "video.mp4"
        assert error == ""


class TestReadMedia:
    """Tests for read_media async function."""

    @pytest.mark.asyncio
    async def test_empty_source(self):
        """An empty source string yields an error text block."""
        response = await read_media("")
        assert len(response.content) == 1
        assert response.content[0]["type"] == "text"
        assert "No media file source provided" in response.content[0]["text"]

    @pytest.mark.asyncio
    async def test_read_png_file(self, sample_png: Path):
        """A valid PNG yields a text block plus a base64 image block."""
        response = await read_media(str(sample_png))
        assert len(response.content) == 2  # Text + ImageBlock
        assert response.content[0]["type"] == "text"
        assert response.content[1]["type"] == "image"
        assert response.content[1]["source"]["type"] == "base64"
        assert response.content[1]["source"]["media_type"] == "image/png"

    @pytest.mark.asyncio
    async def test_read_jpg_file(self, sample_jpg: Path):
        """A valid JPEG yields an image block with the JPEG MIME type."""
        response = await read_media(str(sample_jpg))
        assert len(response.content) == 2
        assert response.content[1]["type"] == "image"
        assert response.content[1]["source"]["media_type"] == "image/jpeg"

    @pytest.mark.asyncio
    async def test_read_mp4_file(self, sample_mp4: Path):
        """A valid MP4 yields a video block."""
        response = await read_media(str(sample_mp4))
        assert len(response.content) == 2
        assert response.content[0]["type"] == "text"
        assert response.content[1]["type"] == "video"
        assert response.content[1]["source"]["media_type"] == "video/mp4"

    @pytest.mark.asyncio
    async def test_read_mp3_file(self, sample_mp3: Path):
        """A valid MP3 yields an audio block."""
        response = await read_media(str(sample_mp3))
        assert len(response.content) == 2
        assert response.content[0]["type"] == "text"
        assert response.content[1]["type"] == "audio"
        assert response.content[1]["source"]["media_type"] == "audio/mpeg"

    @pytest.mark.asyncio
    async def test_file_url(self, sample_png: Path):
        """file:// URLs resolve to the same result as local paths."""
        file_url = f"file://{sample_png}"
        response = await read_media(file_url)
        assert len(response.content) == 2
        assert response.content[1]["type"] == "image"

    @pytest.mark.asyncio
    async def test_nonexistent_file(self):
        """A missing path yields an error text block."""
        response = await read_media("/nonexistent/path/file.png")
        assert response.content[0]["type"] == "text"
        assert "File does not exist" in response.content[0]["text"]

    @pytest.mark.asyncio
    async def test_unsupported_format(self, temp_dir: Path):
        """A non-media suffix is rejected with an error."""
        txt_file = temp_dir / "test.txt"
        txt_file.write_text("not a media file")
        response = await read_media(str(txt_file))
        assert response.content[0]["type"] == "text"
        assert "Unsupported media format" in response.content[0]["text"]

    @pytest.mark.asyncio
    async def test_file_too_large(self, temp_dir: Path):
        """A file just over the 20MB limit is rejected."""
        # Valid PNG header so only the size check can fail.
        png_header = b"\x89PNG\r\n\x1a\n"
        large_file = temp_dir / "large.png"
        large_file.write_bytes(png_header + b"\x00" * (MAX_FILE_SIZE + 1))

        response = await read_media(str(large_file))
        assert response.content[0]["type"] == "text"
        assert "File too large" in response.content[0]["text"]

    @pytest.mark.asyncio
    async def test_directory_instead_of_file(self, temp_dir: Path):
        """A directory path is rejected with an error."""
        subdir = temp_dir / "subdir"
        subdir.mkdir()
        response = await read_media(str(subdir))
        assert response.content[0]["type"] == "text"
        assert "Path is not a file" in response.content[0]["text"]

    @pytest.mark.asyncio
    async def test_relative_path(self, temp_dir: Path):
        """Relative paths are resolved against the working directory."""
        # Create a file
        png_data = base64.b64decode(
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",  # noqa: E501
        )
        (temp_dir / "relative.png").write_bytes(png_data)

        # chdir into the temp dir and restore afterwards.
        original_dir = os.getcwd()
        try:
            os.chdir(temp_dir)
            response = await read_media("relative.png")
            assert response.content[1]["type"] == "image"
        finally:
            os.chdir(original_dir)
@pytest.mark.asyncio + async def test_http_url_success(self, sample_png: Path): + """Test fetching media from HTTP URL.""" + png_data = sample_png.read_bytes() + mock_response = AsyncMock() + mock_response.content = png_data + mock_response.headers = {"content-type": "image/png"} + mock_response.raise_for_status = AsyncMock() + + with patch("httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.get = AsyncMock( + return_value=mock_response, + ) + + response = await read_media("https://example.com/test.png") + assert response.content[0]["type"] == "image" + assert response.content[0]["source"]["media_type"] == "image/png" + + @pytest.mark.asyncio + async def test_http_url_video(self, sample_mp4: Path): + """Test fetching video from HTTP URL.""" + mp4_data = sample_mp4.read_bytes() + mock_response = AsyncMock() + mock_response.content = mp4_data + mock_response.headers = {"content-type": "video/mp4"} + mock_response.raise_for_status = AsyncMock() + + with patch("httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.get = AsyncMock( + return_value=mock_response, + ) + + response = await read_media("https://example.com/test.mp4") + assert response.content[0]["type"] == "video" + assert response.content[0]["source"]["media_type"] == "video/mp4" + + @pytest.mark.asyncio + async def test_http_url_too_large(self): + """Test HTTP URL with file too large returns error.""" + mock_response = AsyncMock() + mock_response.content = b"\x00" * (MAX_FILE_SIZE + 1) + mock_response.headers = {"content-type": "image/png"} + mock_response.raise_for_status = AsyncMock() + + with patch("httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.get = AsyncMock( + return_value=mock_response, + ) + + response = await read_media("https://example.com/large.png") + assert response.content[0]["type"] == "text" + assert "File too large" in response.content[0]["text"] + + @pytest.mark.asyncio + async def 
test_broken_symlink(self, temp_dir: Path): + """Test broken symlink returns error.""" + # Create symlink pointing to non-existent file + symlink_path = temp_dir / "broken_link.png" + if symlink_path.exists() or symlink_path.is_symlink(): + symlink_path.unlink() + symlink_path.symlink_to(temp_dir / "nonexistent.png") + + response = await read_media(str(symlink_path)) + assert response.content[0]["type"] == "text" + assert ( + "Symbolic link" in response.content[0]["text"] + or "does not exist" in response.content[0]["text"] + ) + + @pytest.mark.asyncio + async def test_magic_number_mismatch(self, temp_dir: Path): + """Test file with wrong magic number is rejected.""" + # Create a file with .png extension but invalid PNG content + fake_png = temp_dir / "fake.png" + fake_png.write_bytes(b"This is not a PNG file at all!") + + response = await read_media(str(fake_png)) + assert response.content[0]["type"] == "text" + assert ( + "format" in response.content[0]["text"].lower() + or "file" in response.content[0]["text"].lower() + ) + + @pytest.mark.asyncio + async def test_image_compression(self, temp_dir: Path): + """Test image compression for large images.""" + # Create a mock for PIL + mock_img = MagicMock() + mock_img.mode = "RGB" + mock_img.width = 100 + mock_img.height = 100 + + # Create a large file with valid PNG header + png_header = b"\x89PNG\r\n\x1a\n" + large_png = temp_dir / "large.png" + # Create file larger than 5MB (default max_size_mb) + large_png.write_bytes(png_header + b"\x00" * (6 * 1024 * 1024)) + + with patch("PIL.Image.open") as mock_open: + mock_open.return_value.__enter__ = MagicMock(return_value=mock_img) + mock_open.return_value.__exit__ = MagicMock(return_value=False) + + response = await read_media( + str(large_png), + compress=True, + max_size_mb=5.0, + ) + + # Should return blocks (either compressed or original) + assert len(response.content) >= 1 + + @pytest.mark.asyncio + async def test_compress_false(self, temp_dir: Path): + """Test 
disabling compression.""" + png_data = base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==", # noqa: E501 + ) + png_path = temp_dir / "test.png" + png_path.write_bytes(png_data) + + response = await read_media(str(png_path), compress=False) + assert response.content[1]["type"] == "image" + + @pytest.mark.asyncio + async def test_video_with_fps_parameter(self, sample_mp4: Path): + """Test reading video with custom FPS parameter.""" + response = await read_media(str(sample_mp4), video_fps=5) + assert response.content[0]["type"] == "text" + assert response.content[1]["type"] == "video" + + +class TestConstants: + """Tests for module constants.""" + + def test_max_file_size(self): + """Test max file size is 20MB.""" + assert MAX_FILE_SIZE == 20 * 1024 * 1024 + + def test_supported_formats(self): + """Test all required formats are supported.""" + required_images = [".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"] + required_videos = [".mp4", ".avi", ".mov", ".mkv", ".webm"] + required_audio = [".mp3", ".wav", ".aac", ".ogg", ".m4a", ".flac"] + + for fmt in required_images + required_videos + required_audio: + assert fmt in SUPPORTED_FORMATS, f"Missing format: {fmt}" + + def test_extension_categories(self): + """Test extension categories are correctly defined.""" + assert IMAGE_EXTENSIONS == { + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ".bmp", + } + assert ".mp4" in VIDEO_EXTENSIONS + assert ".mp3" in AUDIO_EXTENSIONS