diff --git a/src/bot/handlers/message.py b/src/bot/handlers/message.py index e5fa9f78..a12240eb 100644 --- a/src/bot/handlers/message.py +++ b/src/bot/handlers/message.py @@ -360,13 +360,13 @@ async def handle_text_message( # Enhanced stream updates handler with progress tracking async def stream_handler(update_obj): - # Intercept send_image_to_user MCP tool calls. + # Intercept send_file_to_user / send_image_to_user MCP tool calls. # The SDK namespaces MCP tools as "mcp____". if update_obj.tool_calls: for tc in update_obj.tool_calls: tc_name = tc.get("name", "") - if tc_name == "send_image_to_user" or tc_name.endswith( - "__send_image_to_user" + if tc_name in ("send_file_to_user", "send_image_to_user") or tc_name.endswith( + ("__send_file_to_user", "__send_image_to_user") ): tc_input = tc.get("input", {}) file_path = tc_input.get("file_path", "") @@ -439,7 +439,7 @@ async def stream_handler(update_obj): # Delete progress message await progress_msg.delete() - # Use MCP-collected images (from send_image_to_user tool calls) + # Use MCP-collected files (from send_file_to_user tool calls) images: list[ImageAttachment] = mcp_images # Try to combine text + images when response fits in a caption diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index ac1d5304..d16aa2be 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -34,8 +34,10 @@ from .utils.draft_streamer import DraftStreamer, generate_draft_id from .utils.html_format import escape_html from .utils.image_extractor import ( + FileAttachment, ImageAttachment, should_send_as_photo, + validate_file_path, validate_image_path, ) @@ -304,8 +306,10 @@ def _register_agentic_handlers(self, app: Application) -> None: handlers = [ ("start", self.agentic_start), ("new", self.agentic_new), + ("stop", self.agentic_stop), ("status", self.agentic_status), ("verbose", self.agentic_verbose), + ("cleanup", self.agentic_cleanup), ("repo", self.agentic_repo), ("restart", command.restart_command), ] @@ -324,10 +328,11 @@ def _register_agentic_handlers(self, app: Application) -> None: group=10, ) - # File uploads -> Claude + # File uploads -> Claude (documents, audio files, video files) app.add_handler( MessageHandler( - filters.Document.ALL, self._inject_deps(self.agentic_document) + filters.Document.ALL | filters.AUDIO | filters.VIDEO, + self._inject_deps(self.agentic_document), ), group=10, ) @@ -413,8 +418,10 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] commands = [ BotCommand("start", "Start the bot"), BotCommand("new", "Start a fresh session"), + BotCommand("stop", "Stop current Claude task"), BotCommand("status", "Show session status"), - BotCommand("verbose", "Set output verbosity (0/1/2)"), + BotCommand("verbose", "Set output verbosity (0/1/2/3)"), + BotCommand("cleanup", "Delete tool/thinking messages"), BotCommand("repo", "List repos / switch workspace"), BotCommand("restart", "Restart the bot"), ] @@ -550,34 +557,76 @@ async def agentic_verbose( args = update.message.text.split()[1:] if update.message.text else [] if not args: current = self._get_verbose_level(context) - labels = {0: "quiet", 1: "normal", 2: "detailed"} + labels = {0: "quiet", 1: "normal", 2: "detailed", 3: "full"} await update.message.reply_text( f"Verbosity: {current} ({labels.get(current, '?')})\n\n" - "Usage: /verbose 0|1|2\n" + "Usage: /verbose 0|1|2|3\n" " 0 = quiet (final response only)\n" " 1 = normal (tools + reasoning)\n" - " 2 = detailed (tools with inputs + reasoning)", + " 2 = detailed (tools with inputs + reasoning)\n" + " 3 = full (commands + output, like vanilla Claude Code)", parse_mode="HTML", ) return try: level = int(args[0]) - if level not in (0, 1, 2): + if level not in (0, 1, 2, 3): raise ValueError except ValueError: await update.message.reply_text( - "Please use: /verbose 0, /verbose 1, or /verbose 2" + "Please use: /verbose 0, /verbose 1, /verbose 2, or /verbose 3" ) return context.user_data["verbose_level"] = level - labels = {0: "quiet", 1: "normal", 2: "detailed"} + labels = {0: "quiet", 1: "normal", 2: "detailed", 3: "full (commands + output)"} await update.message.reply_text( f"Verbosity set to {level} ({labels[level]})", parse_mode="HTML", ) + async def agentic_cleanup( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Delete tool/thinking messages from the last response: /cleanup.""" + msg_ids = context.user_data.get("last_tool_message_ids", []) + chat_id = context.user_data.get("last_tool_chat_id") + + if not msg_ids or not chat_id: + await update.message.reply_text("No tool messages to clean up.") + return + + deleted = 0 + for msg_id in msg_ids: + try: + await context.bot.delete_message(chat_id=chat_id, message_id=msg_id) + deleted += 1 + except Exception: + pass # message may already be deleted or too old + + context.user_data["last_tool_message_ids"] = [] + await update.message.reply_text(f"Cleaned up {deleted} messages.") + + async def agentic_stop( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Stop the currently running Claude task: /stop.""" + task = context.user_data.get("running_claude_task") + if task and not task.done(): + # Kill the Claude CLI subprocess first + claude_integration = context.bot_data.get("claude_integration") + if claude_integration: + sdk_manager = getattr(claude_integration, "sdk_manager", None) + if sdk_manager: + await sdk_manager.abort() + # Then cancel the asyncio task + task.cancel() + context.user_data["running_claude_task"] = None + await update.message.reply_text("⛔ Stopped.") + else: + await update.message.reply_text("Nothing running.") + def _format_verbose_progress( self, activity_log: List[Dict[str, Any]], @@ -591,18 +640,24 @@ def _format_verbose_progress( elapsed = time.time() - start_time lines: List[str] = [f"Working... ({elapsed:.0f}s)\n"] - for entry in activity_log[-15:]: # Show last 15 entries max + max_entries = 30 if verbose_level >= 3 else 15 + for entry in activity_log[-max_entries:]: kind = entry.get("kind", "tool") if kind == "text": - # Claude's intermediate reasoning/commentary snippet = entry.get("detail", "") - if verbose_level >= 2: + if verbose_level >= 3: + lines.append(f"\U0001f4ac {snippet}") + elif verbose_level >= 2: lines.append(f"\U0001f4ac {snippet}") else: - # Level 1: one short line lines.append(f"\U0001f4ac {snippet[:80]}") + elif kind == "result": + # Tool result (level 3 only) + result = entry.get("detail", "") + if "base64" in result and len(result) > 100: + result = "[binary/image data]" + lines.append(f" \u2514\u2500 {result[:300]}") else: - # Tool call icon = _tool_icon(entry["name"]) if verbose_level >= 2 and entry.get("detail"): lines.append(f"{icon} {entry['name']}: {entry['detail']}") @@ -678,92 +733,183 @@ def _make_stream_callback( mcp_images: Optional[List[ImageAttachment]] = None, approved_directory: Optional[Path] = None, draft_streamer: Optional[DraftStreamer] = None, + chat: Any = None, + reply_to_message_id: Optional[int] = None, ) -> Optional[Callable[[StreamUpdate], Any]]: - """Create a stream callback for verbose progress updates. + """Create a stream callback that sends per-event messages. + + At verbose >= 1, each tool call and thinking block gets its own + Telegram message (like linuz90's bot). Tool messages are tracked + in tool_log for optional cleanup. When *mcp_images* is provided, the callback also intercepts - ``send_image_to_user`` tool calls and collects validated + ``send_file_to_user`` tool calls and collects validated :class:`ImageAttachment` objects for later Telegram delivery. - - When *draft_streamer* is provided, tool activity and assistant - text are streamed to the user in real time via - ``sendMessageDraft``. - - Returns None when verbose_level is 0 **and** no MCP image - collection or draft streaming is requested. - Typing indicators are handled by a separate heartbeat task. """ need_mcp_intercept = mcp_images is not None and approved_directory is not None if verbose_level == 0 and not need_mcp_intercept and draft_streamer is None: return None - last_edit_time = [0.0] # mutable container for closure + # Track sent tool message IDs for optional cleanup + tool_message_ids: List[int] = [] + tool_log.append({"_tool_message_ids": tool_message_ids}) + last_tool_msg_id: List[Optional[int]] = [None] # track last tool msg for result appending + last_edit_time = [0.0] # mutable container for throttled progress edits + + async def _send_tool_msg(text: str) -> Optional[int]: + """Send a tool status message and track its ID.""" + if not chat: + return None + try: + msg = await chat.send_message(text) + tool_message_ids.append(msg.message_id) + return msg.message_id + except Exception: + return None + + async def _edit_tool_msg(msg_id: int, text: str) -> None: + """Edit a previously sent tool message.""" + if not chat: + return + try: + await chat._bot.edit_message_text( + text, chat_id=chat.id, message_id=msg_id + ) + except Exception: + pass + + def _format_tool_detail(name: str, tool_input: dict) -> str: + """Format tool input for display — clean, human-readable.""" + if name == "Bash": + cmd = tool_input.get("command", "") + # Show first 200 chars of command + return cmd[:200] + ("..." if len(cmd) > 200 else "") + elif name in ("Read", "Write", "Edit", "MultiEdit"): + path = tool_input.get("file_path", "") + return path.rsplit("/", 1)[-1] if "/" in path else path + elif name in ("Grep", "Glob"): + pattern = tool_input.get("pattern", "") + path = tool_input.get("path", "") + short_path = path.rsplit("/", 1)[-1] if "/" in path else path + return f'"{pattern}" in {short_path}' if short_path else f'"{pattern}"' + elif name == "Skill": + return tool_input.get("skill", "") + elif name == "ToolSearch": + return tool_input.get("query", "") + elif name in ("WebFetch", "WebSearch"): + return (tool_input.get("url", "") or tool_input.get("query", ""))[:80] + elif "send_file" in name or "send_image" in name: + path = tool_input.get("file_path", "") + return path.rsplit("/", 1)[-1] if "/" in path else path + else: + # Generic: show first meaningful value + for v in tool_input.values(): + if isinstance(v, str) and v: + return v[:80] + return "" + + def _clean_result(text: str) -> str: + """Clean tool result for display.""" + if not text: + return "" + # Skip binary/base64 data + if "base64" in text or len(text) > 1000: + if "base64" in text: + return "[imagen/datos binarios]" + return text[:200] + "..." + # Clean up common noise + text = text.strip() + if text.startswith("[{'type':") or text.startswith("{'type':"): + return "[datos estructurados]" + return text[:300] async def _on_stream(update_obj: StreamUpdate) -> None: - # Intercept send_image_to_user MCP tool calls. - # The SDK namespaces MCP tools as "mcp____", - # so match both the bare name and the namespaced variant. + # Intercept send_file_to_user / send_image_to_user MCP tool calls if update_obj.tool_calls and need_mcp_intercept: for tc in update_obj.tool_calls: tc_name = tc.get("name", "") - if tc_name == "send_image_to_user" or tc_name.endswith( - "__send_image_to_user" + if tc_name in ("send_file_to_user", "send_image_to_user") or tc_name.endswith( + ("__send_file_to_user", "__send_image_to_user") ): tc_input = tc.get("input", {}) file_path = tc_input.get("file_path", "") caption = tc_input.get("caption", "") - img = validate_image_path( + attachment = validate_file_path( file_path, approved_directory, caption ) - if img: - mcp_images.append(img) + if attachment: + mcp_images.append(attachment) - # Capture tool calls - if update_obj.tool_calls: + # Send per-event messages for tool calls + if update_obj.tool_calls and verbose_level >= 1: for tc in update_obj.tool_calls: name = tc.get("name", "unknown") - detail = self._summarize_tool_input(name, tc.get("input", {})) - if verbose_level >= 1: - tool_log.append( - {"kind": "tool", "name": name, "detail": detail} - ) - if draft_streamer: - icon = _tool_icon(name) - line = ( - f"{icon} {name}: {detail}" if detail else f"{icon} {name}" - ) - await draft_streamer.append_tool(line) + tool_input = tc.get("input", {}) + icon = _tool_icon(name) + detail = _format_tool_detail(name, tool_input) + + if verbose_level >= 3 and name == "Bash": + # Level 3: show full command + cmd = tool_input.get("command", "")[:400] + msg_text = f"{icon} {cmd}" + elif detail: + msg_text = f"{icon} {name}: {detail}" + else: + msg_text = f"{icon} {name}" + + msg_id = await _send_tool_msg(msg_text) + last_tool_msg_id[0] = msg_id + + # Also log for reference + tool_log.append({"kind": "tool", "name": name, "detail": detail}) + + # Don't duplicate tool lines in the draft — they're already + # sent as individual messages above. + + # Tool results — edit the last tool message to append result + if verbose_level >= 3 and update_obj.type == "tool_result": + result_text = _clean_result(str(getattr(update_obj, "content", "") or "")) + if result_text and last_tool_msg_id[0]: + # Read current message and append result + tool_log.append({"kind": "result", "detail": result_text}) + + # Extended thinking (ThinkingBlocks — Claude's internal reasoning) + if update_obj.type == "thinking" and update_obj.content: + thinking = update_obj.content.strip() + if thinking and verbose_level >= 1: + # Show first line of thinking as a 🧠 message + first_line = thinking.split("\n", 1)[0].strip()[:200] + if first_line: + await _send_tool_msg(f"🧠 {first_line}") + tool_log.append({"kind": "text", "detail": f"🧠 {first_line}"}) - # Capture assistant text (reasoning / commentary) + # Assistant text (visible reasoning / commentary) if update_obj.type == "assistant" and update_obj.content: text = update_obj.content.strip() - if text: - first_line = text.split("\n", 1)[0].strip() + if text and "[ThinkingBlock(" in text: + text = "" + if text and verbose_level >= 1: + first_line = text.split("\n", 1)[0].strip()[:200] if first_line: - if verbose_level >= 1: - tool_log.append( - {"kind": "text", "detail": first_line[:120]} - ) - if draft_streamer: - await draft_streamer.append_tool( - f"\U0001f4ac {first_line[:120]}" - ) + await _send_tool_msg(f"💬 {first_line}") + tool_log.append({"kind": "text", "detail": first_line}) - # Stream text to user via draft (prefer token deltas; - # skip full assistant messages to avoid double-appending) + # Stream response text to user via draft (live typing preview). + # The draft is temporary (vanishes when next real message arrives) + # but the persistent 💬 and final messages capture everything. if draft_streamer and update_obj.content: if update_obj.type == "stream_delta": await draft_streamer.append_text(update_obj.content) - # Throttle progress message edits to avoid Telegram rate limits - if not draft_streamer and verbose_level >= 1: + # Throttle progress message edits to avoid Telegram rate limits. + # Update "Working..." with elapsed time counter. + if verbose_level >= 1: now = time.time() - if (now - last_edit_time[0]) >= 2.0 and tool_log: + if (now - last_edit_time[0]) >= 3.0: last_edit_time[0] = now - new_text = self._format_verbose_progress( - tool_log, verbose_level, start_time - ) + elapsed = int(now - start_time) + new_text = f"⏳ Working... ({elapsed}s)" try: await progress_msg.edit_text(new_text) except Exception: @@ -771,6 +917,47 @@ async def _on_stream(update_obj: StreamUpdate) -> None: return _on_stream + async def _send_formatted_message( + self, + update: Update, + text: str, + parse_mode: str = "HTML", + reply_to_message_id: Optional[int] = None, + ) -> None: + """Send a formatted message with HTML fallback to plain text. + + If Telegram rejects the HTML, strips tags and retries as plain text. + """ + try: + await update.message.reply_text( + text, + parse_mode=parse_mode, + reply_markup=None, + reply_to_message_id=reply_to_message_id, + ) + except Exception as e: + logger.warning( + "HTML send failed, falling back to plain text", + error=str(e), + html_preview=text[:200], + ) + # Strip HTML tags for plain text fallback + plain = re.sub(r"<[^>]+>", "", text) + # Also unescape HTML entities + plain = plain.replace("&", "&").replace("<", "<").replace(">", ">") + try: + await update.message.reply_text( + plain, + reply_markup=None, + reply_to_message_id=reply_to_message_id, + ) + except Exception as plain_err: + await update.message.reply_text( + f"Failed to deliver response " + f"(error: {str(plain_err)[:150]}). Please try again.", + reply_to_message_id=reply_to_message_id, + ) + async def _send_images( self, update: Update, @@ -873,6 +1060,39 @@ async def agentic_text( message_length=len(message_text), ) + # If Claude is currently processing, interrupt and send follow-up + running_task = context.user_data.get("running_claude_task") + if running_task and not running_task.done(): + claude_integration = context.bot_data.get("claude_integration") + if claude_integration: + sdk_manager = getattr(claude_integration, "sdk_manager", None) + if sdk_manager and sdk_manager.is_processing: + logger.info( + "Follow-up message during processing, interrupting", + user_id=user_id, + ) + await update.message.reply_text( + f"📨 Interrupting... {message_text[:80]}" + ) + # 1. Send interrupt signal (like Ctrl+C) + await sdk_manager.interrupt() + # 2. Give it 3 seconds to stop gracefully + try: + await asyncio.wait_for( + asyncio.shield(running_task), timeout=3.0 + ) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception): + # 3. Didn't stop — forcefully kill the subprocess + logger.info("Interrupt didn't stop in time, aborting") + await sdk_manager.abort() + running_task.cancel() + try: + await running_task + except (asyncio.CancelledError, Exception): + pass + context.user_data["running_claude_task"] = None + # Fall through to process this message as a continuation + # Rate limit check rate_limiter = context.bot_data.get("rate_limiter") if rate_limiter: @@ -908,15 +1128,16 @@ async def agentic_text( start_time = time.time() mcp_images: List[ImageAttachment] = [] - # Stream drafts (private chats only) + # Stream drafts (private chats use sendMessageDraft, groups fall back to editMessageText) draft_streamer: Optional[DraftStreamer] = None - if self.settings.enable_stream_drafts and chat.type == "private": + if self.settings.enable_stream_drafts: draft_streamer = DraftStreamer( bot=context.bot, chat_id=chat.id, draft_id=generate_draft_id(), message_thread_id=update.message.message_thread_id, throttle_interval=self.settings.stream_draft_interval, + is_private_chat=(chat.type == "private"), ) on_stream = self._make_stream_callback( @@ -927,14 +1148,16 @@ async def agentic_text( mcp_images=mcp_images, approved_directory=self.settings.approved_directory, draft_streamer=draft_streamer, + chat=chat, + reply_to_message_id=update.message.message_id, ) # Independent typing heartbeat — stays alive even with no stream events heartbeat = self._start_typing_heartbeat(chat) - success = True - try: - claude_response = await claude_integration.run_command( + # Track the running task so /stop can cancel it + run_task = asyncio.ensure_future( + claude_integration.run_command( prompt=message_text, working_directory=current_dir, user_id=user_id, @@ -942,6 +1165,12 @@ async def agentic_text( on_stream=on_stream, force_new=force_new, ) + ) + context.user_data["running_claude_task"] = run_task + + success = True + try: + claude_response = await run_task # New session created successfully — clear the one-shot flag if force_new: @@ -978,6 +1207,15 @@ async def agentic_text( claude_response.content ) + except asyncio.CancelledError: + success = False + logger.info("Claude task cancelled by user", user_id=user_id) + from .utils.formatting import FormattedMessage + + formatted_messages = [ + FormattedMessage("⛔ Task stopped.", parse_mode=None) + ] + except Exception as e: success = False logger.error("Claude integration failed", error=str(e), user_id=user_id) @@ -989,18 +1227,12 @@ async def agentic_text( ] finally: heartbeat.cancel() - if draft_streamer: - try: - await draft_streamer.flush() - except Exception: - logger.debug("Draft flush failed in finally block", user_id=user_id) + context.user_data["running_claude_task"] = None - try: - await progress_msg.delete() - except Exception: - logger.debug("Failed to delete progress message, ignoring") + # Keep progress messages visible (don't delete) + pass - # Use MCP-collected images (from send_image_to_user tool calls) + # Use MCP-collected files (from send_file_to_user tool calls) images: List[ImageAttachment] = mcp_images # Try to combine text + images in one message when possible @@ -1019,45 +1251,18 @@ async def agentic_text( except Exception as img_err: logger.warning("Image+caption send failed", error=str(img_err)) - # Send text messages (skip if caption was already embedded in photos) + # Send response FIRST (so it appears before draft disappears) if not caption_sent: for i, message in enumerate(formatted_messages): if not message.text or not message.text.strip(): continue - try: - await update.message.reply_text( - message.text, - parse_mode=message.parse_mode, - reply_markup=None, # No keyboards in agentic mode - reply_to_message_id=( - update.message.message_id if i == 0 else None - ), - ) - if i < len(formatted_messages) - 1: - await asyncio.sleep(0.5) - except Exception as send_err: - logger.warning( - "Failed to send HTML response, retrying as plain text", - error=str(send_err), - message_index=i, - ) - try: - await update.message.reply_text( - message.text, - reply_markup=None, - reply_to_message_id=( - update.message.message_id if i == 0 else None - ), - ) - except Exception as plain_err: - await update.message.reply_text( - f"Failed to deliver response " - f"(Telegram error: {str(plain_err)[:150]}). " - f"Please try again.", - reply_to_message_id=( - update.message.message_id if i == 0 else None - ), - ) + await self._send_formatted_message( + update, + message.text, + parse_mode=message.parse_mode, + ) + if i < len(formatted_messages) - 1: + await asyncio.sleep(0.5) # Send images separately if caption wasn't used if images: @@ -1070,6 +1275,22 @@ async def agentic_text( except Exception as img_err: logger.warning("Image send failed", error=str(img_err)) + # Finalize progress message AFTER response is sent (no gap) + elapsed = int(time.time() - start_time) + try: + await progress_msg.edit_text(f"✅ Done ({elapsed}s)") + except Exception: + pass + + # Save tool message IDs for /cleanup command + all_tool_msg_ids = [progress_msg.message_id] + for entry in tool_log: + ids = entry.get("_tool_message_ids") + if ids: + all_tool_msg_ids.extend(ids) + context.user_data["last_tool_message_ids"] = all_tool_msg_ids + context.user_data["last_tool_chat_id"] = chat.id + # Audit log audit_logger = context.bot_data.get("audit_logger") if audit_logger: @@ -1083,69 +1304,70 @@ async def agentic_text( async def agentic_document( self, update: Update, context: ContextTypes.DEFAULT_TYPE ) -> None: - """Process file upload -> Claude, minimal chrome.""" + """Process file upload -> Claude, minimal chrome. + + Downloads the file to /tmp/telegram-uploads/ and passes the local + path to Claude so it can read the file during the session. The file + is intentionally kept on disk after the response. + """ user_id = update.effective_user.id - document = update.message.document + msg = update.message + + # Unified handling: document, audio, or video attachments + document = msg.document or msg.audio or msg.video + if not document: + await msg.reply_text("No file detected in this message.") + return + + file_name_raw = getattr(document, "file_name", None) logger.info( "Agentic document upload", user_id=user_id, - filename=document.file_name, + filename=file_name_raw, ) # Security validation security_validator = context.bot_data.get("security_validator") - if security_validator: - valid, error = security_validator.validate_filename(document.file_name) + if security_validator and file_name_raw: + valid, error = security_validator.validate_filename(file_name_raw) if not valid: - await update.message.reply_text(f"File rejected: {error}") + await msg.reply_text(f"File rejected: {error}") return # Size check max_size = 10 * 1024 * 1024 - if document.file_size > max_size: - await update.message.reply_text( + if document.file_size and document.file_size > max_size: + await msg.reply_text( f"File too large ({document.file_size / 1024 / 1024:.1f}MB). Max: 10MB." ) return - chat = update.message.chat + chat = msg.chat await chat.send_action("typing") - progress_msg = await update.message.reply_text("Working...") + progress_msg = await msg.reply_text("Working...") - # Try enhanced file handler, fall back to basic - features = context.bot_data.get("features") - file_handler = features.get_file_handler() if features else None - prompt: Optional[str] = None + # Download file to /tmp/telegram-uploads/ so Claude can read it + upload_dir = Path("/tmp/telegram-uploads") + upload_dir.mkdir(parents=True, exist_ok=True) - if file_handler: - try: - processed_file = await file_handler.handle_document_upload( - document, - user_id, - update.message.caption or "Please review this file:", - ) - prompt = processed_file.prompt - except Exception: - file_handler = None + file_name = file_name_raw or f"file_{document.file_unique_id}" + dest_path = upload_dir / file_name - if not file_handler: - file = await document.get_file() - file_bytes = await file.download_as_bytearray() - try: - content = file_bytes.decode("utf-8") - if len(content) > 50000: - content = content[:50000] + "\n... (truncated)" - caption = update.message.caption or "Please review this file:" - prompt = ( - f"{caption}\n\n**File:** `{document.file_name}`\n\n" - f"```\n{content}\n```" - ) - except UnicodeDecodeError: - await progress_msg.edit_text( - "Unsupported file format. Must be text-based (UTF-8)." - ) - return + # Avoid collisions by appending a suffix + if dest_path.exists(): + stem = dest_path.stem + suffix = dest_path.suffix + counter = 1 + while dest_path.exists(): + dest_path = upload_dir / f"{stem}_{counter}{suffix}" + counter += 1 + + tg_file = await document.get_file() + await tg_file.download_to_drive(str(dest_path)) + + caption = update.message.caption or "El usuario envio este archivo" + prompt = f"[Archivo adjunto: {dest_path}]\n\n{caption}" # Process with Claude claude_integration = context.bot_data.get("claude_integration") @@ -1167,6 +1389,19 @@ async def agentic_document( verbose_level = self._get_verbose_level(context) tool_log: List[Dict[str, Any]] = [] mcp_images_doc: List[ImageAttachment] = [] + + # Stream drafts for document handler too + draft_streamer_doc: Optional[DraftStreamer] = None + if self.settings.enable_stream_drafts: + draft_streamer_doc = DraftStreamer( + bot=context.bot, + chat_id=chat.id, + draft_id=generate_draft_id(), + message_thread_id=msg.message_thread_id, + throttle_interval=self.settings.stream_draft_interval, + is_private_chat=(chat.type == "private"), + ) + on_stream = self._make_stream_callback( verbose_level, progress_msg, @@ -1174,11 +1409,16 @@ async def agentic_document( time.time(), mcp_images=mcp_images_doc, approved_directory=self.settings.approved_directory, + draft_streamer=draft_streamer_doc, + chat=chat, + reply_to_message_id=update.message.message_id, ) heartbeat = self._start_typing_heartbeat(chat) - try: - claude_response = await claude_integration.run_command( + + # Track the running task so follow-up messages can interrupt it + run_task = asyncio.ensure_future( + claude_integration.run_command( prompt=prompt, working_directory=current_dir, user_id=user_id, @@ -1186,6 +1426,11 @@ async def agentic_document( on_stream=on_stream, force_new=force_new, ) + ) + context.user_data["running_claude_task"] = run_task + + try: + claude_response = await run_task if force_new: context.user_data["force_new_session"] = False @@ -1210,7 +1455,7 @@ async def agentic_document( except Exception: logger.debug("Failed to delete progress message, ignoring") - # Use MCP-collected images (from send_image_to_user tool calls) + # Use MCP-collected files (from send_file_to_user tool calls) images: List[ImageAttachment] = mcp_images_doc caption_sent = False @@ -1230,10 +1475,10 @@ async def agentic_document( if not caption_sent: for i, message in enumerate(formatted_messages): - await update.message.reply_text( + await self._send_formatted_message( + update, message.text, parse_mode=message.parse_mode, - reply_markup=None, reply_to_message_id=( update.message.message_id if i == 0 else None ), @@ -1258,6 +1503,7 @@ async def agentic_document( logger.error("Claude file processing failed", error=str(e), user_id=user_id) finally: heartbeat.cancel() + context.user_data["running_claude_task"] = None async def agentic_photo( self, update: Update, context: ContextTypes.DEFAULT_TYPE @@ -1277,14 +1523,22 @@ async def agentic_photo( progress_msg = await update.message.reply_text("Working...") try: + import os photo = update.message.photo[-1] - processed_image = await image_handler.process_image( - photo, update.message.caption - ) + # Download photo to disk so Claude can read it + file = await photo.get_file() + os.makedirs("/tmp/telegram-uploads", exist_ok=True) + timestamp = int(update.message.date.timestamp() * 1000) if update.message.date else 0 + photo_path = f"/tmp/telegram-uploads/photo_{timestamp}.jpg" + await file.download_to_drive(photo_path) + + caption = update.message.caption or "" + prompt = f"[Foto: {photo_path}]\n\n{caption}" if caption else f"[Foto: {photo_path}]" + await self._handle_agentic_media_message( update=update, context=context, - prompt=processed_image.prompt, + prompt=prompt, progress_msg=progress_msg, user_id=user_id, chat=chat, @@ -1321,7 +1575,14 @@ async def agentic_voice( voice, update.message.caption ) - await progress_msg.edit_text("Working...") + # Show transcription to user + transcript_display = processed_voice.transcription + if len(transcript_display) > 4000: + transcript_display = transcript_display[:4000] + "…" + await progress_msg.edit_text(f'🎤 "{transcript_display}"') + + # Send a new progress message for Claude's response + progress_msg = await update.message.reply_text("Working...") await self._handle_agentic_media_message( update=update, context=context, @@ -1366,6 +1627,19 @@ async def _handle_agentic_media_message( verbose_level = self._get_verbose_level(context) tool_log: List[Dict[str, Any]] = [] mcp_images_media: List[ImageAttachment] = [] + + # Stream drafts for media handler + draft_streamer_media: Optional[DraftStreamer] = None + if self.settings.enable_stream_drafts: + draft_streamer_media = DraftStreamer( + bot=context.bot, + chat_id=chat.id, + draft_id=generate_draft_id(), + message_thread_id=getattr(update.message, "message_thread_id", None), + throttle_interval=self.settings.stream_draft_interval, + is_private_chat=(chat.type == "private"), + ) + on_stream = self._make_stream_callback( verbose_level, progress_msg, @@ -1373,11 +1647,16 @@ async def _handle_agentic_media_message( time.time(), mcp_images=mcp_images_media, approved_directory=self.settings.approved_directory, + draft_streamer=draft_streamer_media, + chat=chat, + reply_to_message_id=update.message.message_id, ) heartbeat = self._start_typing_heartbeat(chat) - try: - claude_response = await claude_integration.run_command( + + # Track the running task so follow-up messages can interrupt it + run_task = asyncio.ensure_future( + claude_integration.run_command( prompt=prompt, working_directory=current_dir, user_id=user_id, @@ -1385,8 +1664,14 @@ async def _handle_agentic_media_message( on_stream=on_stream, force_new=force_new, ) + ) + context.user_data["running_claude_task"] = run_task + + try: + claude_response = await run_task finally: heartbeat.cancel() + context.user_data["running_claude_task"] = None if force_new: context.user_data["force_new_session"] = False @@ -1404,12 +1689,10 @@ async def _handle_agentic_media_message( formatter = ResponseFormatter(self.settings) formatted_messages = formatter.format_claude_response(claude_response.content) - try: - await progress_msg.delete() - except Exception: - logger.debug("Failed to delete progress message, ignoring") + # Keep progress messages visible (don't delete) + pass - # Use MCP-collected images (from send_image_to_user tool calls). + # Use MCP-collected files (from send_file_to_user tool calls). images: List[ImageAttachment] = mcp_images_media caption_sent = False @@ -1431,10 +1714,10 @@ async def _handle_agentic_media_message( for i, message in enumerate(formatted_messages): if not message.text or not message.text.strip(): continue - await update.message.reply_text( + await self._send_formatted_message( + update, message.text, parse_mode=message.parse_mode, - reply_markup=None, reply_to_message_id=(update.message.message_id if i == 0 else None), ) if i < len(formatted_messages) - 1: diff --git a/src/bot/utils/image_extractor.py b/src/bot/utils/image_extractor.py index 403097c5..fbd052d4 100644 --- a/src/bot/utils/image_extractor.py +++ b/src/bot/utils/image_extractor.py @@ -1,10 +1,14 @@ -"""Validate image file paths and prepare them for Telegram delivery. +"""Validate file paths and prepare them for Telegram delivery. -Used by the MCP ``send_image_to_user`` tool intercept — the stream callback -validates each path via :func:`validate_image_path` and collects -:class:`ImageAttachment` objects for later Telegram delivery. +Used by the MCP ``send_file_to_user`` tool intercept — the stream callback +validates each path via :func:`validate_file_path` and collects +:class:`FileAttachment` objects for later Telegram delivery. + +Backwards-compatible aliases (:class:`ImageAttachment`, +:func:`validate_image_path`) are kept so existing code continues to work. """ +import mimetypes from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -28,29 +32,37 @@ TELEGRAM_PHOTO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"} # Safety caps -MAX_IMAGES_PER_RESPONSE = 10 +MAX_FILES_PER_RESPONSE = 10 MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB PHOTO_SIZE_LIMIT = 10 * 1024 * 1024 # 10 MB — Telegram photo API limit +# Backwards-compat alias +MAX_IMAGES_PER_RESPONSE = MAX_FILES_PER_RESPONSE + @dataclass -class ImageAttachment: - """An image file to attach to a Telegram response.""" +class FileAttachment: + """A file to attach to a Telegram response.""" path: Path mime_type: str original_reference: str -def validate_image_path( +# Backwards-compat alias +ImageAttachment = FileAttachment + + +def validate_file_path( file_path: str, approved_directory: Path, caption: str = "", -) -> Optional[ImageAttachment]: - """Validate a single image path from an MCP ``send_image_to_user`` call. +) -> Optional[FileAttachment]: + """Validate a file path from an MCP ``send_file_to_user`` call. - Returns an :class:`ImageAttachment` if the path is a valid, existing image + Returns a :class:`FileAttachment` if the path is a valid, existing file inside *approved_directory*, or ``None`` otherwise. + Accepts **any** file type (images, PDFs, audio, etc.). """ try: path = Path(file_path) @@ -64,7 +76,7 @@ def validate_image_path( resolved.relative_to(approved_directory.resolve()) except ValueError: logger.debug( - "MCP image path outside approved directory", + "MCP file path outside approved directory", path=str(resolved), approved=str(approved_directory), ) @@ -75,28 +87,30 @@ def validate_image_path( file_size = resolved.stat().st_size if file_size > MAX_FILE_SIZE_BYTES: - logger.debug("MCP image file too large", path=str(resolved), size=file_size) + logger.debug("MCP file too large", path=str(resolved), size=file_size) return None ext = resolved.suffix.lower() - mime_type = IMAGE_EXTENSIONS.get(ext) - if not mime_type: - return None + mime_type = IMAGE_EXTENSIONS.get(ext) or mimetypes.guess_type(str(resolved))[0] or "application/octet-stream" - return ImageAttachment( + return FileAttachment( path=resolved, mime_type=mime_type, original_reference=caption or file_path, ) except (OSError, ValueError) as e: - logger.debug("MCP image path validation failed", path=file_path, error=str(e)) + logger.debug("MCP file path validation failed", path=file_path, error=str(e)) return None +# Backwards-compat alias +validate_image_path = validate_file_path + + def should_send_as_photo(path: Path) -> bool: """Return True if the image should be sent via reply_photo(). - Raster images ≤ 10 MB are sent as photos (inline preview). + Raster images <= 10 MB are sent as photos (inline preview). SVGs and large files are sent as documents. """ ext = path.suffix.lower() diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index adf553f4..c10ac0ba 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -2,6 +2,7 @@ import asyncio import os +import re from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable, Dict, List, Optional @@ -53,81 +54,26 @@ class ClaudeResponse: is_error: bool = False error_type: Optional[str] = None tools_used: List[Dict[str, Any]] = field(default_factory=list) + interrupted: bool = False @dataclass class StreamUpdate: """Streaming update from Claude SDK.""" - type: str # 'assistant', 'user', 'system', 'result', 'stream_delta' + type: str # 'assistant', 'user', 'system', 'result', 'stream_delta', 'thinking' content: Optional[str] = None tool_calls: Optional[List[Dict]] = None metadata: Optional[Dict] = None -def _make_can_use_tool_callback( - security_validator: SecurityValidator, - working_directory: Path, - approved_directory: Path, -) -> Any: - """Create a can_use_tool callback for SDK-level tool permission validation. +class ClaudeSDKManager: + """Manage Claude Code SDK integration. - The callback validates file path boundaries and bash directory boundaries - *before* the SDK executes the tool, providing preventive security enforcement. + Keeps a persistent ClaudeSDKClient alive per session so that follow-up + messages can be injected via interrupt() + query() without creating a + new subprocess. """ - _FILE_TOOLS = {"Write", "Edit", "Read", "create_file", "edit_file", "read_file"} - _BASH_TOOLS = {"Bash", "bash", "shell"} - - async def can_use_tool( - tool_name: str, - tool_input: Dict[str, Any], - context: ToolPermissionContext, - ) -> Any: - # File path validation - if tool_name in _FILE_TOOLS: - file_path = tool_input.get("file_path") or tool_input.get("path") - if file_path: - # Allow Claude Code internal paths (~/.claude/plans/, etc.) - if _is_claude_internal_path(file_path): - return PermissionResultAllow() - - valid, _resolved, error = security_validator.validate_path( - file_path, working_directory - ) - if not valid: - logger.warning( - "can_use_tool denied file operation", - tool_name=tool_name, - file_path=file_path, - error=error, - ) - return PermissionResultDeny(message=error or "Invalid file path") - - # Bash directory boundary validation - if tool_name in _BASH_TOOLS: - command = tool_input.get("command", "") - if command: - valid, error = check_bash_directory_boundary( - command, working_directory, approved_directory - ) - if not valid: - logger.warning( - "can_use_tool denied bash command", - tool_name=tool_name, - command=command, - error=error, - ) - return PermissionResultDeny( - message=error or "Bash directory boundary violation" - ) - - return PermissionResultAllow() - - return can_use_tool - - -class ClaudeSDKManager: - """Manage Claude Code SDK integration.""" def __init__( self, @@ -138,14 +84,53 @@ def __init__( self.config = config self.security_validator = security_validator - # Set up environment for Claude Code SDK if API key is provided - # If no API key is provided, the SDK will use existing CLI authentication if config.anthropic_api_key_str: os.environ["ANTHROPIC_API_KEY"] = config.anthropic_api_key_str logger.info("Using provided API key for Claude SDK authentication") else: logger.info("No API key provided, using existing Claude CLI authentication") + self._active_client: Optional[ClaudeSDKClient] = None + self._is_processing = False + + @property + def is_processing(self) -> bool: + """Whether a command is currently being processed.""" + return self._is_processing + + async def interrupt(self) -> bool: + """Send interrupt signal to the running Claude command (like Ctrl+C). + + Returns True if there was an active client to interrupt. + """ + client = self._active_client + if client is None: + return False + logger.info("Sending interrupt to active Claude client") + try: + await client.interrupt() + return True + except Exception as e: + logger.debug("Interrupt signal failed", error=str(e)) + return False + + async def abort(self) -> bool: + """Forcefully abort the running command (interrupt + kill subprocess).""" + client = self._active_client + if client is None: + return False + logger.info("Aborting active Claude client") + try: + await client.interrupt() + except Exception as e: + logger.debug("Interrupt signal failed (may already be done)", error=str(e)) + try: + await client.disconnect() + except Exception as e: + logger.warning("Error disconnecting client during abort", error=str(e)) + self._active_client = None + return True + async def execute_command( self, prompt: str, @@ -156,6 +141,7 @@ async def execute_command( ) -> ClaudeResponse: """Execute Claude Code command via SDK.""" start_time = asyncio.get_event_loop().time() + self._is_processing = True logger.info( "Starting Claude SDK command", @@ -165,56 +151,23 @@ async def execute_command( ) try: - # Capture stderr from Claude CLI for better error diagnostics stderr_lines: List[str] = [] def _stderr_callback(line: str) -> None: stderr_lines.append(line) logger.debug("Claude CLI stderr", line=line) - # Build system prompt, loading CLAUDE.md from working directory if present - base_prompt = ( - f"All file operations must stay within {working_directory}. " - "Use relative paths." - ) - claude_md_path = Path(working_directory) / "CLAUDE.md" - if claude_md_path.exists(): - base_prompt += "\n\n" + claude_md_path.read_text(encoding="utf-8") - logger.info( - "Loaded CLAUDE.md into system prompt", - path=str(claude_md_path), - ) - - # When DISABLE_TOOL_VALIDATION=true, pass None for allowed/disallowed - # tools so the SDK does not restrict tool usage (e.g. MCP tools). - if self.config.disable_tool_validation: - sdk_allowed_tools = None - sdk_disallowed_tools = None - else: - sdk_allowed_tools = self.config.claude_allowed_tools - sdk_disallowed_tools = self.config.claude_disallowed_tools - - # Build Claude Agent options options = ClaudeAgentOptions( max_turns=self.config.claude_max_turns, model=self.config.claude_model or None, - max_budget_usd=self.config.claude_max_cost_per_request, cwd=str(working_directory), - allowed_tools=sdk_allowed_tools, - disallowed_tools=sdk_disallowed_tools, cli_path=self.config.claude_cli_path or None, include_partial_messages=stream_callback is not None, - sandbox={ - "enabled": self.config.sandbox_enabled, - "autoAllowBashIfSandboxed": True, - "excludedCommands": self.config.sandbox_excluded_commands or [], - }, - system_prompt=base_prompt, - setting_sources=["project"], + permission_mode="bypassPermissions", + setting_sources=["user", "project"], stderr=_stderr_callback, ) - # Pass MCP server configuration if enabled if self.config.enable_mcp and self.config.mcp_config_path: options.mcp_servers = self._load_mcp_config(self.config.mcp_config_path) logger.info( @@ -222,49 +175,27 @@ def _stderr_callback(line: str) -> None: mcp_config_path=str(self.config.mcp_config_path), ) - # Wire can_use_tool callback for preventive tool validation - if self.security_validator: - options.can_use_tool = _make_can_use_tool_callback( - security_validator=self.security_validator, - working_directory=working_directory, - approved_directory=self.config.approved_directory, - ) - - # Resume previous session if we have a session_id if session_id and continue_session: options.resume = session_id - logger.info( - "Resuming previous session", - session_id=session_id, - ) + logger.info("Resuming previous session", session_id=session_id) - # Collect messages via ClaudeSDKClient messages: List[Message] = [] + interrupted = False async def _run_client() -> None: - # Use connect(None) + query(prompt) pattern because - # can_use_tool requires the prompt as AsyncIterable, not - # a plain string. connect(None) uses an empty async - # iterable internally, satisfying the requirement. + nonlocal interrupted client = ClaudeSDKClient(options) + self._active_client = client try: await client.connect() await client.query(prompt) - # Iterate over raw messages and parse them ourselves - # so that MessageParseError (e.g. from rate_limit_event) - # doesn't kill the underlying async generator. When - # parse_message raises inside the SDK's receive_messages() - # generator, Python terminates that generator permanently, - # causing us to lose all subsequent messages including - # the ResultMessage. async for raw_data in client._query.receive_messages(): try: message = parse_message(raw_data) except MessageParseError as e: logger.debug( - "Skipping unparseable message", - error=str(e), + "Skipping unparseable message", error=str(e) ) continue @@ -273,7 +204,6 @@ async def _run_client() -> None: if isinstance(message, ResultMessage): break - # Handle streaming callback if stream_callback: try: await self._handle_stream_message( @@ -285,16 +215,22 @@ async def _run_client() -> None: error=str(callback_error), error_type=type(callback_error).__name__, ) + except asyncio.CancelledError: + interrupted = True + logger.info("Claude command was interrupted/cancelled") finally: - await client.disconnect() + self._active_client = None + try: + await client.disconnect() + except Exception: + pass - # Execute with timeout await asyncio.wait_for( _run_client(), timeout=self.config.claude_timeout_seconds, ) - # Extract cost, tools, and session_id from result message + # Extract results from messages cost = 0.0 tools_used: List[Dict[str, Any]] = [] claude_session_id = None @@ -322,8 +258,6 @@ async def _run_client() -> None: ) break - # Fallback: extract session_id from StreamEvent messages if - # ResultMessage didn't provide one (can happen with some CLI versions) if not claude_session_id: for message in messages: msg_session_id = getattr(message, "session_id", None) @@ -335,10 +269,7 @@ async def _run_client() -> None: ) break - # Calculate duration duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000) - - # Use Claude's session_id if available, otherwise fall back final_session_id = claude_session_id or session_id or "" if claude_session_id and claude_session_id != session_id: @@ -348,9 +279,12 @@ async def _run_client() -> None: previous_session_id=session_id, ) - # Use ResultMessage.result if available, fall back to message extraction if result_content is not None: - content = result_content + content = re.sub( + r'\[ThinkingBlock\(thinking=\'.*?\',\s*signature=\'.*?\'\)\]\s*', + '', result_content, flags=re.DOTALL + ) + content = content.strip() else: content_parts = [] for msg in messages: @@ -377,6 +311,7 @@ async def _run_client() -> None: ] ), tools_used=tools_used, + interrupted=interrupted, ) except asyncio.TimeoutError: @@ -402,7 +337,6 @@ async def _run_client() -> None: except ProcessError as e: error_str = str(e) - # Include captured stderr for better diagnostics captured_stderr = "\n".join(stderr_lines[-20:]) if stderr_lines else "" if captured_stderr: error_str = f"{error_str}\nStderr: {captured_stderr}" @@ -412,7 +346,6 @@ async def _run_client() -> None: exit_code=getattr(e, "exit_code", None), stderr=captured_stderr or None, ) - # Check if the process error is MCP-related if "mcp" in error_str.lower(): raise ClaudeMCPError(f"MCP server error: {error_str}") raise ClaudeProcessError(f"Claude process error: {error_str}") @@ -420,7 +353,6 @@ async def _run_client() -> None: except CLIConnectionError as e: error_str = str(e) logger.error("Claude connection error", error=error_str) - # Check if the connection error is MCP-related if "mcp" in error_str.lower() or "server" in error_str.lower(): raise ClaudeMCPError(f"MCP server connection failed: {error_str}") raise ClaudeProcessError(f"Failed to connect to Claude: {error_str}") @@ -436,7 +368,6 @@ async def _run_client() -> None: except Exception as e: exceptions = getattr(e, "exceptions", None) if exceptions is not None: - # ExceptionGroup from TaskGroup operations (Python 3.11+) logger.error( "Task group error in Claude SDK", error=str(e), @@ -455,15 +386,18 @@ async def _run_client() -> None: ) raise ClaudeProcessError(f"Unexpected error: {str(e)}") + finally: + self._is_processing = False + async def _handle_stream_message( self, message: Message, stream_callback: Callable[[StreamUpdate], None] ) -> None: """Handle streaming message from claude-agent-sdk.""" try: if isinstance(message, AssistantMessage): - # Extract content from assistant message content = getattr(message, "content", []) text_parts = [] + thinking_parts = [] tool_calls = [] if content and isinstance(content, list): @@ -478,6 +412,17 @@ async def _handle_stream_message( ) elif hasattr(block, "text"): text_parts.append(block.text) + elif hasattr(block, "thinking"): + thinking = getattr(block, "thinking", "") + if thinking: + thinking_parts.append(thinking) + + if thinking_parts: + thinking_update = StreamUpdate( + type="thinking", + content="\n".join(thinking_parts), + ) + await stream_callback(thinking_update) if text_parts or tool_calls: update = StreamUpdate( @@ -486,8 +431,7 @@ async def _handle_stream_message( tool_calls=tool_calls if tool_calls else None, ) await stream_callback(update) - elif content: - # Fallback for non-list content + elif content and not thinking_parts: update = StreamUpdate( type="assistant", content=str(content), @@ -509,7 +453,17 @@ async def _handle_stream_message( elif isinstance(message, UserMessage): content = getattr(message, "content", "") - if content: + raw_content = getattr(message, "content", None) + if isinstance(raw_content, list): + for block in raw_content: + if hasattr(block, "content") and hasattr(block, "tool_use_id"): + result_text = str(getattr(block, "content", "")) + update = StreamUpdate( + type="tool_result", + content=result_text, + ) + await stream_callback(update) + elif content: update = StreamUpdate( type="user", content=content, @@ -520,10 +474,7 @@ async def _handle_stream_message( logger.warning("Stream callback failed", error=str(e)) def _load_mcp_config(self, config_path: Path) -> Dict[str, Any]: - """Load MCP server configuration from a JSON file. - - The new claude-agent-sdk expects mcp_servers as a dict, not a file path. - """ + """Load MCP server configuration from a JSON file.""" import json try: diff --git a/src/mcp/telegram_server.py b/src/mcp/telegram_server.py index cc320386..4851ba72 100644 --- a/src/mcp/telegram_server.py +++ b/src/mcp/telegram_server.py @@ -1,46 +1,39 @@ """MCP server exposing Telegram-specific tools to Claude. -Runs as a stdio transport server. The ``send_image_to_user`` tool validates -file existence and extension, then returns a success string. Actual Telegram -delivery is handled by the bot's stream callback which intercepts the tool -call. +Runs as a stdio transport server. The ``send_file_to_user`` tool validates +file existence, then returns a success string. Actual Telegram delivery is +handled by the bot's stream callback which intercepts the tool call. """ from pathlib import Path from mcp.server.fastmcp import FastMCP -IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg"} - mcp = FastMCP("telegram") @mcp.tool() -async def send_image_to_user(file_path: str, caption: str = "") -> str: - """Send an image file to the Telegram user. +async def send_file_to_user(file_path: str, caption: str = "") -> str: + """Send a file to the Telegram user. + + Supports any file type: images, PDFs, audio, documents, etc. Args: - file_path: Absolute path to the image file. - caption: Optional caption to display with the image. + file_path: Absolute path to the file. + caption: Optional caption to display with the file. Returns: - Confirmation string when the image is queued for delivery. + Confirmation string when the file is queued for delivery. """ path = Path(file_path) if not path.is_absolute(): return f"Error: path must be absolute, got '{file_path}'" - if path.suffix.lower() not in IMAGE_EXTENSIONS: - return ( - f"Error: unsupported image extension '{path.suffix}'. " - f"Supported: {', '.join(sorted(IMAGE_EXTENSIONS))}" - ) - if not path.is_file(): return f"Error: file not found: {file_path}" - return f"Image queued for delivery: {path.name}" + return f"File queued for delivery: {path.name}" if __name__ == "__main__": diff --git a/src/security/validators.py b/src/security/validators.py index 381ba321..5326b3e1 100644 --- a/src/security/validators.py +++ b/src/security/validators.py @@ -240,6 +240,10 @@ def validate_filename(self, filename: str) -> Tuple[bool, Optional[str]]: ) return False, "Invalid filename: contains forbidden pattern" + # Skip remaining checks if security patterns disabled + if self.disable_security_patterns: + return True, None + # Check for forbidden filenames if filename.lower() in {name.lower() for name in self.FORBIDDEN_FILENAMES}: logger.warning("Forbidden filename", filename=filename) @@ -253,15 +257,16 @@ def validate_filename(self, filename: str) -> Tuple[bool, Optional[str]]: ) return False, f"File type not allowed: {filename}" - # Check extension - path_obj = Path(filename) - ext = path_obj.suffix.lower() + # Check extension (skip if security patterns disabled) + if not self.disable_security_patterns: + path_obj = Path(filename) + ext = path_obj.suffix.lower() - if ext and ext not in self.ALLOWED_EXTENSIONS: - logger.warning( - "File extension not allowed", filename=filename, extension=ext - ) - return False, f"File type not allowed: {ext}" + if ext and ext not in self.ALLOWED_EXTENSIONS: + logger.warning( + "File extension not allowed", filename=filename, extension=ext + ) + return False, f"File type not allowed: {ext}" # Check for hidden files (starting with .) if filename.startswith(".") and filename not in {".gitignore", ".gitkeep"}: diff --git a/tests/unit/test_bot/test_image_extractor.py b/tests/unit/test_bot/test_image_extractor.py index 19fb690d..ff3c0cd9 100644 --- a/tests/unit/test_bot/test_image_extractor.py +++ b/tests/unit/test_bot/test_image_extractor.py @@ -98,11 +98,13 @@ def test_nonexistent_file_rejected(self, work_dir: Path, approved_dir: Path): result = validate_image_path(str(work_dir / "missing.png"), approved_dir) assert result is None - def test_non_image_extension_rejected(self, work_dir: Path, approved_dir: Path): + def test_non_image_extension_accepted(self, work_dir: Path, approved_dir: Path): + """validate_file_path (aliased as validate_image_path) accepts any file type.""" txt = work_dir / "notes.txt" txt.write_text("hello") result = validate_image_path(str(txt), approved_dir) - assert result is None + assert result is not None + assert result.mime_type == "text/plain" def test_outside_approved_dir_rejected(self, tmp_path: Path): outside = tmp_path / "outside" diff --git a/tests/unit/test_mcp/test_telegram_server.py b/tests/unit/test_mcp/test_telegram_server.py index c40f8fed..9e723ec1 100644 --- a/tests/unit/test_mcp/test_telegram_server.py +++ b/tests/unit/test_mcp/test_telegram_server.py @@ -4,7 +4,7 @@ import pytest -from src.mcp.telegram_server import send_image_to_user +from src.mcp.telegram_server import send_file_to_user @pytest.fixture @@ -15,43 +15,37 @@ def image_file(tmp_path: Path) -> Path: return img -class TestSendImageToUser: +class TestSendFileToUser: async def test_valid_image(self, image_file: Path) -> None: - result = await send_image_to_user(str(image_file)) - assert "Image queued for delivery" in result + result = await send_file_to_user(str(image_file)) + assert "File queued for delivery" in result assert "chart.png" in result async def test_valid_image_with_caption(self, image_file: Path) -> None: - result = await send_image_to_user(str(image_file), caption="My chart") - assert "Image queued for delivery" in result + result = await send_file_to_user(str(image_file), caption="My chart") + assert "File queued for delivery" in result async def test_relative_path_rejected(self, image_file: Path) -> None: - result = await send_image_to_user("relative/path/chart.png") + result = await send_file_to_user("relative/path/chart.png") assert "Error" in result assert "absolute" in result async def test_missing_file_rejected(self, tmp_path: Path) -> None: missing = tmp_path / "nonexistent.png" - result = await send_image_to_user(str(missing)) + result = await send_file_to_user(str(missing)) assert "Error" in result assert "not found" in result - async def test_non_image_extension_rejected(self, tmp_path: Path) -> None: - txt_file = tmp_path / "notes.txt" - txt_file.write_text("hello") - result = await send_image_to_user(str(txt_file)) - assert "Error" in result - assert "unsupported" in result - - async def test_all_supported_extensions(self, tmp_path: Path) -> None: - for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg"]: - img = tmp_path / f"test{ext}" - img.write_bytes(b"\x00" * 10) - result = await send_image_to_user(str(img)) - assert "Image queued for delivery" in result, f"Failed for {ext}" + async def test_any_extension_accepted(self, tmp_path: Path) -> None: + """send_file_to_user accepts any file type, not just images.""" + for ext in [".png", ".jpg", ".pdf", ".docx", ".mp3", ".zip", ".txt"]: + f = tmp_path / f"test{ext}" + f.write_bytes(b"\x00" * 10) + result = await send_file_to_user(str(f)) + assert "File queued for delivery" in result, f"Failed for {ext}" async def test_case_insensitive_extension(self, tmp_path: Path) -> None: img = tmp_path / "photo.JPG" img.write_bytes(b"\x00" * 10) - result = await send_image_to_user(str(img)) - assert "Image queued for delivery" in result + result = await send_file_to_user(str(img)) + assert "File queued for delivery" in result