diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index ac1d5304..9a99c7b2 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -115,6 +115,7 @@ class MessageOrchestrator: def __init__(self, settings: Settings, deps: Dict[str, Any]): self.settings = settings self.deps = deps + self._known_commands: frozenset[str] = frozenset() def _inject_deps(self, handler: Callable) -> Callable: # type: ignore[type-arg] """Wrap handler to inject dependencies into context.bot_data.""" @@ -306,12 +307,16 @@ def _register_agentic_handlers(self, app: Application) -> None: ("new", self.agentic_new), ("status", self.agentic_status), ("verbose", self.agentic_verbose), + ("model", self.agentic_model), ("repo", self.agentic_repo), ("restart", command.restart_command), ] if self.settings.enable_project_threads: handlers.append(("sync_threads", command.sync_threads)) + # Derive known commands dynamically — avoids drift when new commands are added + self._known_commands: frozenset[str] = frozenset(cmd for cmd, _ in handlers) + for cmd, handler in handlers: app.add_handler(CommandHandler(cmd, self._inject_deps(handler))) @@ -324,6 +329,19 @@ def _register_agentic_handlers(self, app: Application) -> None: group=10, ) + # Unknown slash commands -> Claude (passthrough in agentic mode). + # Registered commands are handled by CommandHandlers in group 0 + # (higher priority). This catches any /command not matched there + # and forwards it to Claude, while skipping known commands to + # avoid double-firing. + app.add_handler( + MessageHandler( + filters.COMMAND, + self._inject_deps(self._handle_unknown_command), + ), + group=10, + ) + # File uploads -> Claude app.add_handler( MessageHandler( @@ -415,6 +433,7 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] BotCommand("new", "Start a fresh session"), BotCommand("status", "Show session status"), BotCommand("verbose", "Set output verbosity (0/1/2)"), + BotCommand("model", "Switch Claude model"), BotCommand("repo", "List repos / switch workspace"), BotCommand("restart", "Restart the bot"), ] @@ -578,6 +597,81 @@ async def agentic_verbose( parse_mode="HTML", ) + def _get_model_override(self, context: ContextTypes.DEFAULT_TYPE) -> Optional[str]: + """Return per-user model override, or None to use the default.""" + return context.user_data.get("model_override") + + @staticmethod + def _resolve_model_display( + user_override: Optional[str], + config_model: Optional[str], + last_model: Optional[str] = None, + ) -> str: + """Return a human-readable model string showing what will actually be used.""" + if user_override: + return user_override + if config_model: + return config_model + if last_model: + return last_model + return "unknown (send a message first to detect)" + + async def agentic_model( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Set Claude model: /model [model_name].""" + args = update.message.text.split()[1:] if update.message.text else [] + user_override = self._get_model_override(context) + last_model = context.user_data.get("last_model") + current = self._resolve_model_display( + user_override, self.settings.claude_model, last_model + ) + + if not args: + source = "user override" if user_override else ( + "server config" if self.settings.claude_model else "Claude Code default" + ) + await update.message.reply_text( + f"Model: {escape_html(current)} ({source})\n\n" + "Usage: /model model_name\n" + "Reset: /model default", + parse_mode="HTML", + ) + return + + model_name = args[0].strip() + if not model_name or len(model_name) > 100: + await update.message.reply_text("Invalid model name.") + return + audit_logger = context.bot_data.get("audit_logger") + if model_name == "default": + context.user_data.pop("model_override", None) + default = self._resolve_model_display(None, self.settings.claude_model) + await update.message.reply_text( + f"Model reset to {escape_html(default)}", + parse_mode="HTML", + ) + if audit_logger: + await audit_logger.log_command( + user_id=update.effective_user.id, + command="model_reset", + args=[], + success=True, + ) + else: + context.user_data["model_override"] = model_name + await update.message.reply_text( + f"Model set to {escape_html(model_name)}", + parse_mode="HTML", + ) + if audit_logger: + await audit_logger.log_command( + user_id=update.effective_user.id, + command="model", + args=[model_name], + success=True, + ) + def _format_verbose_progress( self, activity_log: List[Dict[str, Any]], @@ -941,6 +1035,7 @@ async def agentic_text( session_id=session_id, on_stream=on_stream, force_new=force_new, + model_override=self._get_model_override(context), ) # New session created successfully — clear the one-shot flag @@ -948,6 +1043,8 @@ async def agentic_text( context.user_data["force_new_session"] = False context.user_data["claude_session_id"] = claude_response.session_id + if claude_response.model: + context.user_data["last_model"] = claude_response.model # Track directory changes from .handlers.message import _update_working_directory_from_claude_response @@ -1185,12 +1282,15 @@ async def agentic_document( session_id=session_id, on_stream=on_stream, force_new=force_new, + model_override=self._get_model_override(context), ) if force_new: context.user_data["force_new_session"] = False context.user_data["claude_session_id"] = claude_response.session_id + if claude_response.model: + context.user_data["last_model"] = claude_response.model from .handlers.message import _update_working_directory_from_claude_response @@ -1384,6 +1484,7 @@ async def _handle_agentic_media_message( session_id=session_id, on_stream=on_stream, force_new=force_new, + model_override=self._get_model_override(context), ) finally: heartbeat.cancel() @@ -1392,6 +1493,7 @@ async def _handle_agentic_media_message( context.user_data["force_new_session"] = False context.user_data["claude_session_id"] = claude_response.session_id + context.user_data["last_model"] = claude_response.model from .handlers.message import _update_working_directory_from_claude_response @@ -1450,6 +1552,25 @@ async def _handle_agentic_media_message( except Exception as img_err: logger.warning("Image send failed", error=str(img_err)) + async def _handle_unknown_command( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Forward unknown slash commands to Claude in agentic mode. + + Known commands are handled by their own CommandHandlers (group 0); + this handler fires for *every* COMMAND message in group 10 but + returns immediately when the command is registered, preventing + double execution. + """ + msg = update.effective_message + if not msg or not msg.text: + return + cmd = msg.text.split()[0].lstrip("/").split("@")[0].lower() + if cmd in self._known_commands: + return # let the registered CommandHandler take care of it + # Forward unrecognised /commands to Claude as natural language + await self.agentic_text(update, context) + def _voice_unavailable_message(self) -> str: """Return provider-aware guidance when voice feature is unavailable.""" return ( diff --git a/src/bot/utils/draft_streamer.py b/src/bot/utils/draft_streamer.py index 4fe1e709..6f263ba6 100644 --- a/src/bot/utils/draft_streamer.py +++ b/src/bot/utils/draft_streamer.py @@ -1,4 +1,9 @@ -"""Stream partial responses to Telegram via sendMessageDraft.""" +"""Stream partial responses to Telegram via sendMessageDraft. + +Uses Telegram Bot API 9.3+ sendMessageDraft for smooth token-by-token +streaming in private chats. Falls back to editMessageText for group chats +where sendMessageDraft is unavailable. +""" import secrets import time @@ -14,6 +19,17 @@ # Max tool lines shown in the draft header _MAX_TOOL_LINES = 10 +# Minimum characters before sending the first draft (avoids triggering +# push notifications with just a few characters) +_MIN_INITIAL_CHARS = 20 + +# Error messages that indicate the draft transport is unavailable +_DRAFT_UNAVAILABLE_ERRORS = frozenset({ + "TEXTDRAFT_PEER_INVALID", + "Bad Request: draft can't be sent", + "Bad Request: peer doesn't support drafts", +}) + def generate_draft_id() -> int: """Generate a non-zero positive draft ID. @@ -30,18 +46,21 @@ class DraftStreamer: The draft is composed of two sections: 1. **Tool header** — compact lines showing tool calls and reasoning - snippets as they arrive, e.g. ``"📖 Read | 🔍 Grep | 🐚 Bash"``. + snippets as they arrive. 2. **Response body** — the actual assistant response text, streamed token-by-token. Both sections are combined into a single draft message and sent via - ``sendMessageDraft``. + ``sendMessageDraft`` (private chats) or ``editMessageText`` (groups). - Key design decisions: + Key design decisions (inspired by OpenClaw): - Plain text drafts (no parse_mode) to avoid partial HTML/markdown errors. - - Tail-truncation for messages >4096 chars: shows ``"\\u2026" + last 4093 chars``. - - Self-disabling: any API error silently disables the streamer so the - request continues with normal (non-streaming) delivery. + - Tail-truncation for messages >4096 chars. + - Min initial chars: waits for ~20 chars before first send. + - Anti-regressive: skips updates where text got shorter. + - Error classification: distinguishes draft-unavailable (fall back to edit) + from other errors (disable entirely). + - Self-disabling: persistent errors silently disable the streamer. """ def __init__( @@ -50,7 +69,8 @@ def __init__( chat_id: int, draft_id: int, message_thread_id: Optional[int] = None, - throttle_interval: float = 0.3, + throttle_interval: float = 0.4, + is_private_chat: bool = True, ) -> None: self.bot = bot self.chat_id = chat_id @@ -61,7 +81,18 @@ def __init__( self._tool_lines: List[str] = [] self._accumulated_text = "" self._last_send_time = 0.0 + self._last_sent_length = 0 # anti-regressive tracking self._enabled = True + self._error_count = 0 + self._max_errors = 3 + + # Transport mode: "draft" for private chats, "edit" for groups + self._use_draft = is_private_chat + self._edit_message_id: Optional[int] = None # for edit-based transport + + @property + def enabled(self) -> bool: + return self._enabled async def append_tool(self, line: str) -> None: """Append a tool activity line and send a draft if throttled.""" @@ -87,10 +118,14 @@ async def flush(self) -> None: return if not self._accumulated_text and not self._tool_lines: return - await self._send_draft() + await self._send_draft(force=True) + + def _compose_draft(self, is_final: bool = False) -> str: + """Combine tool header and response body into a single draft. - def _compose_draft(self) -> str: - """Combine tool header and response body into a single draft.""" + Appends a blinking cursor ▌ during streaming (like OpenClaw) + to indicate the response is still being generated. + """ parts: List[str] = [] if self._tool_lines: @@ -103,33 +138,157 @@ def _compose_draft(self) -> str: if self._accumulated_text: if parts: parts.append("") # blank separator line - parts.append(self._accumulated_text) + text = self._accumulated_text + if not is_final: + text += " ▌" + parts.append(text) return "\n".join(parts) - async def _send_draft(self) -> None: - """Send the composed draft (tools + text) as a message draft.""" + async def _send_draft(self, force: bool = False) -> None: + """Send the composed draft via the appropriate transport.""" draft_text = self._compose_draft() if not draft_text.strip(): return + # Min initial chars gate (skip if force-flushing) + if not force and self._last_sent_length == 0: + if len(self._accumulated_text) < _MIN_INITIAL_CHARS and not self._tool_lines: + return + + # Anti-regressive: skip if text got shorter (can happen with + # tool header rotation) + current_len = len(draft_text) + if not force and current_len < self._last_sent_length: + return + # Tail-truncate if over Telegram limit if len(draft_text) > TELEGRAM_MAX_MESSAGE_LENGTH: - draft_text = "\u2026" + draft_text[-(TELEGRAM_MAX_MESSAGE_LENGTH - 1) :] + draft_text = "\u2026" + draft_text[-(TELEGRAM_MAX_MESSAGE_LENGTH - 1):] try: + if self._use_draft: + await self._send_via_draft(draft_text) + else: + await self._send_via_edit(draft_text) + self._last_send_time = time.time() + self._last_sent_length = current_len + self._error_count = 0 # reset on success + except telegram.error.BadRequest as e: + error_str = str(e) + if any(err in error_str for err in _DRAFT_UNAVAILABLE_ERRORS): + # Draft transport unavailable — fall back to edit + logger.info( + "Draft transport unavailable, falling back to edit", + chat_id=self.chat_id, + error=error_str, + ) + self._use_draft = False + # Retry immediately with edit transport + try: + await self._send_via_edit(draft_text) + self._last_send_time = time.time() + self._last_sent_length = current_len + except Exception: + self._handle_error() + elif "Message is not modified" in error_str: + # Same content — not an error, just skip + self._last_send_time = time.time() + elif "Message to edit not found" in error_str: + # Message was deleted — re-create + self._edit_message_id = None + try: + await self._send_via_edit(draft_text) + self._last_send_time = time.time() + self._last_sent_length = current_len + except Exception: + self._handle_error() + else: + self._handle_error() + except Exception: + self._handle_error() + + def _handle_error(self) -> None: + """Track errors and disable after too many.""" + self._error_count += 1 + if self._error_count >= self._max_errors: + logger.debug( + "Draft streamer disabled after repeated errors", + chat_id=self.chat_id, + error_count=self._error_count, + ) + self._enabled = False + + async def _send_via_draft(self, text: str) -> None: + """Send via sendMessageDraft (private chats).""" + kwargs = { + "chat_id": self.chat_id, + "text": text, + "draft_id": self.draft_id, + } + if self.message_thread_id is not None: + kwargs["message_thread_id"] = self.message_thread_id + logger.debug( + "Sending draft", + transport="draft", + text_len=len(text), + preview=text[:80], + ) + await self.bot.send_message_draft(**kwargs) + + async def _send_via_edit(self, text: str) -> None: + """Send via editMessageText (group chat fallback). + + Creates a message on first call, then edits it on subsequent calls. + """ + if self._edit_message_id is None: + # Send initial message kwargs = { "chat_id": self.chat_id, - "text": draft_text, - "draft_id": self.draft_id, + "text": text, } if self.message_thread_id is not None: kwargs["message_thread_id"] = self.message_thread_id - await self.bot.send_message_draft(**kwargs) - self._last_send_time = time.time() - except Exception: - logger.debug( - "Draft send failed, disabling streamer", + msg = await self.bot.send_message(**kwargs) + self._edit_message_id = msg.message_id + else: + await self.bot.edit_message_text( + text, chat_id=self.chat_id, + message_id=self._edit_message_id, ) - self._enabled = False + + async def clear(self) -> None: + """Clear the draft bubble by sending an empty draft. + + Call this before sending the final response message so the draft + bubble disappears cleanly instead of overlapping with the real message. + """ + if not self._enabled: + return + try: + if self._use_draft: + # Send empty draft to dismiss the typing bubble + await self.bot.send_message_draft( + chat_id=self.chat_id, + text="", + draft_id=self.draft_id, + ) + elif self._edit_message_id is not None: + # For edit-based transport, delete the preview message + try: + await self.bot.delete_message( + chat_id=self.chat_id, + message_id=self._edit_message_id, + ) + except Exception: + pass + self._edit_message_id = None + except Exception: + pass + self._enabled = False + + @property + def edit_message_id(self) -> Optional[int]: + """Return the message ID used by edit transport (for cleanup).""" + return self._edit_message_id diff --git a/src/bot/utils/html_format.py b/src/bot/utils/html_format.py index 2799a4ee..b84bbd58 100644 --- a/src/bot/utils/html_format.py +++ b/src/bot/utils/html_format.py @@ -1,40 +1,99 @@ """HTML formatting utilities for Telegram messages. -Telegram's HTML mode only requires escaping 3 characters (<, >, &) vs the many -ambiguous Markdown v1 metacharacters, making it far more robust for rendering -Claude's output which contains underscores, asterisks, brackets, etc. +Telegram's HTML mode supports: , , , , ,
,
+
, , 
, +
, . + +This module converts Claude's markdown output into that subset. """ import re from typing import List, Tuple -def escape_html(text: str) -> str: - """Escape the 3 HTML-special characters for Telegram. +_INLINE_TAGS = {"b", "i", "s", "u", "code"} +_TAG_RE = re.compile(r"<(/?)(\w+)(?:\s[^>]*)?>") + + +def _repair_html_nesting(html: str) -> str: + """Fix misnested inline HTML tags that Telegram would reject. - This replaces all 3 _escape_markdown functions previously scattered - across the codebase. + Telegram requires strict nesting: ... is OK, + but ... is rejected. This walks the tag stack + and closes/reopens tags when it detects a mismatch. """ + result = [] + stack: List[str] = [] + last_end = 0 + + for m in _TAG_RE.finditer(html): + # Append text before this tag + result.append(html[last_end:m.start()]) + last_end = m.end() + + is_close = m.group(1) == "/" + tag = m.group(2).lower() + + # Only repair inline tags; skip
, 
, , etc. + if tag not in _INLINE_TAGS: + result.append(m.group(0)) + continue + + if not is_close: + stack.append(tag) + result.append(m.group(0)) + else: + if tag in stack: + # Close tags in reverse order up to the matching opener + idx = len(stack) - 1 - stack[::-1].index(tag) + tags_to_reopen = stack[idx + 1:] + # Close everything from top to idx + for t in reversed(stack[idx:]): + result.append(f"") + stack = stack[:idx] + # Reopen tags that were above the matched one + for t in tags_to_reopen: + result.append(f"<{t}>") + stack.append(t) + else: + # Orphan close tag — skip it + pass + + # Append remaining text + result.append(html[last_end:]) + + # Close any unclosed tags + for t in reversed(stack): + result.append(f"") + + return "".join(result) + + +def escape_html(text: str) -> str: + """Escape the 3 HTML-special characters for Telegram.""" return text.replace("&", "&").replace("<", "<").replace(">", ">") def markdown_to_telegram_html(text: str) -> str: """Convert Claude's markdown output to Telegram-compatible HTML. - Telegram supports a narrow HTML subset: , , ,
,
-    , , . This function converts common markdown patterns
-    to that subset while preserving code blocks verbatim.
-
-    Order of operations:
-    1. Extract fenced code blocks -> placeholders
-    2. Extract inline code -> placeholders
-    3. HTML-escape remaining text
-    4. Convert bold (**text** / __text__)
-    5. Convert italic (*text*, _text_ with word boundaries)
-    6. Convert links [text](url)
-    7. Convert headers (# Header -> Header)
-    8. Convert strikethrough (~~text~~)
-    9. Restore placeholders
+    Order of operations (early steps extract content into placeholders
+    to protect it from later regex passes):
+
+    0.  Markdown tables → aligned 
 blocks
+    1.  Fenced code blocks → 

+    2.  Inline code → 
+    3.  Blockquotes (> text) → 
+ 4. HTML-escape remaining text + 5. Horizontal rules (--- / ***) → ── separator + 6. Bold (**text** / __text__) + 7. Italic (*text* / _text_) + 8. Links [text](url) + 9. Headers (# Header → Header) + 10. Strikethrough (~~text~~) + 11. Unordered lists (- item / * item) + 12. Ordered lists (1. item) + 13. Restore placeholders """ placeholders: List[Tuple[str, str]] = [] placeholder_counter = 0 @@ -46,6 +105,52 @@ def _make_placeholder(html_content: str) -> str: placeholders.append((key, html_content)) return key + # --- 0. Extract markdown tables → monospace
 blocks ---
+    def _replace_table(m: re.Match) -> str:  # type: ignore[type-arg]
+        table_text = m.group(0)
+        lines = table_text.strip().split("\n")
+        rows = []
+        for line in lines:
+            stripped = line.strip()
+            if not stripped.startswith("|"):
+                continue
+            if re.match(r"^\|[\s\-:|]+\|$", stripped):
+                continue
+            cells = [c.strip() for c in stripped.split("|")[1:-1]]
+            if cells:
+                rows.append(cells)
+
+        if not rows:
+            return table_text
+
+        num_cols = max(len(r) for r in rows)
+        col_widths = [0] * num_cols
+        for row in rows:
+            for i, cell in enumerate(row):
+                if i < num_cols:
+                    col_widths[i] = max(col_widths[i], len(cell))
+
+        formatted_lines = []
+        for row in rows:
+            parts = []
+            for i in range(num_cols):
+                cell = row[i] if i < len(row) else ""
+                parts.append(cell.ljust(col_widths[i]))
+            formatted_lines.append(" │ ".join(parts))
+            if len(formatted_lines) == 1:
+                sep_parts = ["─" * w for w in col_widths]
+                formatted_lines.append("─┼─".join(sep_parts))
+
+        pre_content = "\n".join(formatted_lines)
+        return _make_placeholder(f"
{escape_html(pre_content)}
") + + text = re.sub( + r"(?:^\|.+\|$\n?){2,}", + _replace_table, + text, + flags=re.MULTILINE, + ) + # --- 1. Extract fenced code blocks --- def _replace_fenced(m: re.Match) -> str: # type: ignore[type-arg] lang = m.group(1) or "" @@ -72,33 +177,72 @@ def _replace_inline_code(m: re.Match) -> str: # type: ignore[type-arg] text = re.sub(r"`([^`\n]+)`", _replace_inline_code, text) - # --- 3. HTML-escape remaining text --- + # --- 3. Blockquotes: > text →
--- + def _replace_blockquote(m: re.Match) -> str: # type: ignore[type-arg] + block = m.group(0) + # Strip the leading > (and optional space) from each line + lines = [] + for line in block.split("\n"): + stripped = re.sub(r"^>\s?", "", line) + lines.append(stripped) + inner = "\n".join(lines) + # Recursively format the blockquote content + inner_html = escape_html(inner) + return _make_placeholder(f"
{inner_html}
") + + text = re.sub( + r"(?:^>.*$\n?)+", + _replace_blockquote, + text, + flags=re.MULTILINE, + ) + + # --- 4. HTML-escape remaining text --- text = escape_html(text) - # --- 4. Bold: **text** or __text__ --- + # --- 5. Horizontal rules: --- or *** or ___ → visual separator --- + text = re.sub( + r"^(?:---+|\*\*\*+|___+)\s*$", + "──────────", + text, + flags=re.MULTILINE, + ) + + # --- 6. Bold: **text** or __text__ --- text = re.sub(r"\*\*(.+?)\*\*", r"\1", text) text = re.sub(r"__(.+?)__", r"\1", text) - # --- 5. Italic: *text* (require non-space after/before) --- + # --- 7. Italic: *text* (require non-space after/before) --- text = re.sub(r"\*(\S.*?\S|\S)\*", r"\1", text) - # _text_ only at word boundaries (avoid my_var_name) text = re.sub(r"(?\1
", text) - # --- 6. Links: [text](url) --- + # --- 8. Links: [text](url) --- text = re.sub( r"\[([^\]]+)\]\(([^)]+)\)", r'
\1', text, ) - # --- 7. Headers: # Header -> Header --- + # --- 9. Headers: # Header → Header --- text = re.sub(r"^#{1,6}\s+(.+)$", r"\1", text, flags=re.MULTILINE) - # --- 8. Strikethrough: ~~text~~ --- + # --- 10. Strikethrough: ~~text~~ --- text = re.sub(r"~~(.+?)~~", r"\1", text) - # --- 9. Restore placeholders --- + # --- 11. Unordered lists: - item / * item → bullet --- + text = re.sub(r"^[\-\*]\s+", "• ", text, flags=re.MULTILINE) + + # --- 12. Ordered lists: 1. item → keep number with period --- + # (Telegram has no
    , so just clean up the formatting) + text = re.sub(r"^(\d+)\.\s+", r"\1. ", text, flags=re.MULTILINE) + + # --- 13. Restore placeholders --- for key, html_content in placeholders: text = text.replace(key, html_content) + # --- 14. Repair HTML tag nesting --- + # Telegram is strict about nesting: ... is OK, + # but ... is rejected. Fix any mismatches. + text = _repair_html_nesting(text) + return text diff --git a/src/claude/facade.py b/src/claude/facade.py index fcb2ada6..09545ff6 100644 --- a/src/claude/facade.py +++ b/src/claude/facade.py @@ -37,6 +37,7 @@ async def run_command( session_id: Optional[str] = None, on_stream: Optional[Callable[[StreamUpdate], None]] = None, force_new: bool = False, + model_override: Optional[str] = None, ) -> ClaudeResponse: """Run Claude Code command with full integration.""" logger.info( @@ -85,6 +86,7 @@ async def run_command( session_id=claude_session_id, continue_session=should_continue, stream_callback=on_stream, + model_override=model_override, ) except Exception as resume_error: # If resume failed (e.g., session expired/missing on Claude's side), @@ -109,6 +111,7 @@ async def run_command( session_id=None, continue_session=False, stream_callback=on_stream, + model_override=model_override, ) else: raise @@ -152,6 +155,7 @@ async def _execute( session_id: Optional[str] = None, continue_session: bool = False, stream_callback: Optional[Callable] = None, + model_override: Optional[str] = None, ) -> ClaudeResponse: """Execute command via SDK.""" return await self.sdk_manager.execute_command( @@ -160,6 +164,7 @@ async def _execute( session_id=session_id, continue_session=continue_session, stream_callback=stream_callback, + model_override=model_override, ) async def _find_resumable_session( diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index adf553f4..c6b5e591 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -53,6 +53,7 @@ class ClaudeResponse: is_error: bool = False error_type: Optional[str] = None tools_used: List[Dict[str, Any]] = field(default_factory=list) + model: Optional[str] = None @dataclass @@ -153,6 +154,7 @@ async def execute_command( session_id: Optional[str] = None, continue_session: bool = False, stream_callback: Optional[Callable[[StreamUpdate], None]] = None, + model_override: Optional[str] = None, ) -> ClaudeResponse: """Execute Claude Code command via SDK.""" start_time = asyncio.get_event_loop().time() @@ -197,7 +199,7 @@ def _stderr_callback(line: str) -> None: # Build Claude Agent options options = ClaudeAgentOptions( max_turns=self.config.claude_max_turns, - model=self.config.claude_model or None, + model=model_override or self.config.claude_model or None, max_budget_usd=self.config.claude_max_cost_per_request, cwd=str(working_directory), allowed_tools=sdk_allowed_tools, @@ -294,11 +296,12 @@ async def _run_client() -> None: timeout=self.config.claude_timeout_seconds, ) - # Extract cost, tools, and session_id from result message + # Extract cost, tools, session_id, and model from result message cost = 0.0 tools_used: List[Dict[str, Any]] = [] claude_session_id = None result_content = None + response_model: Optional[str] = None for message in messages: if isinstance(message, ResultMessage): cost = getattr(message, "total_cost_usd", 0.0) or 0.0 @@ -307,6 +310,8 @@ async def _run_client() -> None: current_time = asyncio.get_event_loop().time() for msg in messages: if isinstance(msg, AssistantMessage): + if not response_model: + response_model = getattr(msg, "model", None) msg_content = getattr(msg, "content", []) if msg_content and isinstance(msg_content, list): for block in msg_content: @@ -377,6 +382,7 @@ async def _run_client() -> None: ] ), tools_used=tools_used, + model=response_model, ) except asyncio.TimeoutError: diff --git a/src/config/settings.py b/src/config/settings.py index 77c34ea4..9ba77b69 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -210,22 +210,23 @@ class Settings(BaseSettings): ), ) - # Output verbosity (0=quiet, 1=normal, 2=detailed) + # Output verbosity (0=quiet, 1=normal, 2=detailed, 3=full) verbose_level: int = Field( 1, description=( "Bot output verbosity: 0=quiet (final response only), " "1=normal (tool names + reasoning), " - "2=detailed (tool inputs + longer reasoning)" + "2=detailed (tool inputs + longer reasoning), " + "3=full (tool results + complete commands)" ), ge=0, - le=2, + le=3, ) - # Streaming drafts (Telegram sendMessageDraft) + # Streaming drafts (Telegram sendMessageDraft / editMessageText) enable_stream_drafts: bool = Field( False, - description="Stream partial responses via sendMessageDraft (private chats only)", + description="Stream partial responses to Telegram in real-time", ) stream_draft_interval: float = Field( 0.3, diff --git a/tests/unit/test_claude/test_facade.py b/tests/unit/test_claude/test_facade.py index 666a2246..814522e2 100644 --- a/tests/unit/test_claude/test_facade.py +++ b/tests/unit/test_claude/test_facade.py @@ -269,6 +269,85 @@ async def test_retry_after_failure_still_skips_auto_resume( assert user_data["force_new_session"] is False +class TestModelOverride: + """Verify model_override is passed through to _execute.""" + + async def test_model_override_forwarded_to_execute(self, facade, session_manager): + """run_command passes model_override through to _execute.""" + project = Path("/test/project") + user_id = 123 + + with patch.object( + facade, + "_execute", + return_value=_make_mock_response(), + ) as mock_execute: + await facade.run_command( + prompt="hello", + working_directory=project, + user_id=user_id, + model_override="opus", + ) + + mock_execute.assert_called_once() + assert mock_execute.call_args.kwargs["model_override"] == "opus" + + async def test_model_override_none_by_default(self, facade, session_manager): + """run_command passes model_override=None when not specified.""" + project = Path("/test/project") + user_id = 123 + + with patch.object( + facade, + "_execute", + return_value=_make_mock_response(), + ) as mock_execute: + await facade.run_command( + prompt="hello", + working_directory=project, + user_id=user_id, + ) + + mock_execute.assert_called_once() + assert mock_execute.call_args.kwargs["model_override"] is None + + async def test_model_override_survives_session_retry(self, facade, session_manager): + """model_override is preserved when session resume fails and retries.""" + project = Path("/test/project") + user_id = 123 + + # Seed an existing session so resume is attempted + existing = ClaudeSession( + session_id="old-session", + user_id=user_id, + project_path=project, + created_at=datetime.utcnow(), + last_used=datetime.utcnow(), + ) + await session_manager.storage.save_session(existing) + session_manager.active_sessions[existing.session_id] = existing + + call_count = [0] + + async def _execute_side_effect(**kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("session expired") + return _make_mock_response() + + with patch.object(facade, "_execute", side_effect=_execute_side_effect): + await facade.run_command( + prompt="hello", + working_directory=project, + user_id=user_id, + session_id="old-session", + model_override="haiku", + ) + + # Both the initial call and retry should have model_override="haiku" + assert call_count[0] == 2 + + class TestEmptySessionIdWarning: """Verify facade warns when final session_id is empty.""" diff --git a/tests/unit/test_claude/test_sdk_integration.py b/tests/unit/test_claude/test_sdk_integration.py index 17ba58ab..0af40e46 100644 --- a/tests/unit/test_claude/test_sdk_integration.py +++ b/tests/unit/test_claude/test_sdk_integration.py @@ -696,6 +696,66 @@ async def test_claude_model_none_when_unset(self, tmp_path): assert len(captured_options) == 1 assert captured_options[0].model is None + async def test_model_override_takes_priority(self, tmp_path): + """Test that model_override overrides claude_model from config.""" + config = Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + claude_timeout_seconds=2, + claude_model="claude-sonnet-4-6", + ) + manager = ClaudeSDKManager(config) + + captured_options = [] + mock_factory = _mock_client_factory( + _make_assistant_message("Test response"), + _make_result_message(total_cost_usd=0.01), + capture_options=captured_options, + ) + + with patch( + "src.claude.sdk_integration.ClaudeSDKClient", side_effect=mock_factory + ): + await manager.execute_command( + prompt="Test prompt", + working_directory=tmp_path, + model_override="claude-opus-4-6", + ) + + assert len(captured_options) == 1 + assert captured_options[0].model == "claude-opus-4-6" + + async def test_model_override_none_uses_config(self, tmp_path): + """Test that model_override=None falls back to config model.""" + config = Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + claude_timeout_seconds=2, + claude_model="claude-haiku-4-5-20251001", + ) + manager = ClaudeSDKManager(config) + + captured_options = [] + mock_factory = _mock_client_factory( + _make_assistant_message("Test response"), + _make_result_message(total_cost_usd=0.01), + capture_options=captured_options, + ) + + with patch( + "src.claude.sdk_integration.ClaudeSDKClient", side_effect=mock_factory + ): + await manager.execute_command( + prompt="Test prompt", + working_directory=tmp_path, + model_override=None, + ) + + assert len(captured_options) == 1 + assert captured_options[0].model == "claude-haiku-4-5-20251001" + class TestClaudeMCPErrors: """Test MCP-specific error handling.""" diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index cc02b7c0..2e06eea2 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -82,8 +82,8 @@ def deps(): } -def test_agentic_registers_6_commands(agentic_settings, deps): - """Agentic mode registers start, new, status, verbose, repo, restart commands.""" +def test_agentic_registers_7_commands(agentic_settings, deps): + """Agentic mode registers start, new, status, verbose, model, repo, restart.""" orchestrator = MessageOrchestrator(agentic_settings, deps) app = MagicMock() app.add_handler = MagicMock() @@ -100,11 +100,12 @@ def test_agentic_registers_6_commands(agentic_settings, deps): ] commands = [h[0][0].commands for h in cmd_handlers] - assert len(cmd_handlers) == 6 + assert len(cmd_handlers) == 7 assert frozenset({"start"}) in commands assert frozenset({"new"}) in commands assert frozenset({"status"}) in commands assert frozenset({"verbose"}) in commands + assert frozenset({"model"}) in commands assert frozenset({"repo"}) in commands assert frozenset({"restart"}) in commands @@ -149,20 +150,20 @@ def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): if isinstance(call[0][0], CallbackQueryHandler) ] - # 4 message handlers (text, document, photo, voice) - assert len(msg_handlers) == 4 + # 5 message handlers (text, document, photo, voice, unknown commands passthrough) + assert len(msg_handlers) == 5 # 1 callback handler (for cd: only) assert len(cb_handlers) == 1 async def test_agentic_bot_commands(agentic_settings, deps): - """Agentic mode returns 6 bot commands.""" + """Agentic mode returns 7 bot commands.""" orchestrator = MessageOrchestrator(agentic_settings, deps) commands = await orchestrator.get_bot_commands() - assert len(commands) == 6 + assert len(commands) == 7 cmd_names = [c.command for c in commands] - assert cmd_names == ["start", "new", "status", "verbose", "repo", "restart"] + assert cmd_names == ["start", "new", "status", "verbose", "model", "repo", "restart"] async def test_classic_bot_commands(classic_settings, deps): @@ -926,3 +927,325 @@ async def help_command(update, context): assert called["value"] is False update.effective_message.reply_text.assert_called_once() + + +async def test_known_command_not_forwarded_to_claude(agentic_settings, deps): + """Known commands must NOT be forwarded to agentic_text.""" + from unittest.mock import AsyncMock, MagicMock, patch + + orchestrator = MessageOrchestrator(agentic_settings, deps) + app = MagicMock() + app.add_handler = MagicMock() + orchestrator.register_handlers(app) + + update = MagicMock() + update.effective_message.text = "/start" + context = MagicMock() + + with patch.object( + orchestrator, "agentic_text", new_callable=AsyncMock + ) as mock_claude: + await orchestrator._handle_unknown_command(update, context) + mock_claude.assert_not_called() + + +async def test_unknown_command_forwarded_to_claude(agentic_settings, deps): + """Unknown slash commands must be forwarded to agentic_text.""" + from unittest.mock import AsyncMock, MagicMock, patch + + orchestrator = MessageOrchestrator(agentic_settings, deps) + app = MagicMock() + app.add_handler = MagicMock() + orchestrator.register_handlers(app) + + update = MagicMock() + update.effective_message.text = "/workflow activate job-hunter" + context = MagicMock() + + with patch.object( + orchestrator, "agentic_text", new_callable=AsyncMock + ) as mock_claude: + await orchestrator._handle_unknown_command(update, context) + mock_claude.assert_called_once_with(update, context) + + +async def test_bot_suffixed_command_not_forwarded(agentic_settings, deps): + """Bot-suffixed known commands like /start@mybot must not reach Claude.""" + from unittest.mock import AsyncMock, MagicMock, patch + + orchestrator = MessageOrchestrator(agentic_settings, deps) + app = MagicMock() + app.add_handler = MagicMock() + orchestrator.register_handlers(app) + + update = MagicMock() + update.effective_message.text = "/start@mybot" + context = MagicMock() + + with patch.object( + orchestrator, "agentic_text", new_callable=AsyncMock + ) as mock_claude: + await orchestrator._handle_unknown_command(update, context) + mock_claude.assert_not_called() + + +# --- /model command tests --- + + +async def test_agentic_model_shows_last_model_when_unset(agentic_settings, deps): + """/model with no override shows the model from the last response.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {"last_model": "claude-opus-4-6"} + + await orchestrator.agentic_model(update, context) + + call_args = update.message.reply_text.call_args + text = call_args.args[0] + assert "claude-opus-4-6" in text + assert "Claude Code default" in text + + +async def test_agentic_model_shows_unknown_before_first_message(agentic_settings, deps): + """/model before any message shows unknown.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + call_args = update.message.reply_text.call_args + text = call_args.args[0] + assert "unknown" in text.lower() + assert call_args.kwargs.get("parse_mode") == "HTML" + + +async def test_agentic_model_shows_config_model(tmp_dir, deps): + """/model shows the server-configured model when CLAUDE_MODEL is set.""" + settings = create_test_config( + approved_directory=str(tmp_dir), + agentic_mode=True, + claude_model="claude-opus-4-6", + ) + orchestrator = MessageOrchestrator(settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + text = update.message.reply_text.call_args.args[0] + assert "claude-opus-4-6" in text + assert "server config" in text + + +async def test_agentic_model_shows_user_override(agentic_settings, deps): + """/model shows the user's override when one is set.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {"model_override": "haiku"} + + await orchestrator.agentic_model(update, context) + + text = update.message.reply_text.call_args.args[0] + assert "haiku" in text + assert "user override" in text + + +async def test_agentic_model_sets_override(agentic_settings, deps): + """/model sonnet sets the user's model override.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model sonnet" + update.message.reply_text = AsyncMock() + update.effective_user.id = 123 + + context = MagicMock() + context.user_data = {} + context.bot_data = {"audit_logger": AsyncMock()} + + await orchestrator.agentic_model(update, context) + + assert context.user_data["model_override"] == "sonnet" + text = update.message.reply_text.call_args.args[0] + assert "sonnet" in text + + +async def test_agentic_model_reset_to_default(agentic_settings, deps): + """/model default clears the user's model override.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model default" + update.message.reply_text = AsyncMock() + update.effective_user.id = 123 + + context = MagicMock() + context.user_data = {"model_override": "opus"} + context.bot_data = {"audit_logger": AsyncMock()} + + await orchestrator.agentic_model(update, context) + + assert "model_override" not in context.user_data + text = update.message.reply_text.call_args.args[0] + assert "reset" in text.lower() + + +async def test_agentic_model_audit_logged(agentic_settings, deps): + """/model sonnet logs the action to audit logger.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model sonnet" + update.message.reply_text = AsyncMock() + update.effective_user.id = 42 + + audit_logger = AsyncMock() + context = MagicMock() + context.user_data = {} + context.bot_data = {"audit_logger": audit_logger} + + await orchestrator.agentic_model(update, context) + + audit_logger.log_command.assert_called_once_with( + user_id=42, command="model", args=["sonnet"], success=True, + ) + + +async def test_agentic_model_reset_audit_logged(agentic_settings, deps): + """/model default logs as model_reset with empty args.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model default" + update.message.reply_text = AsyncMock() + update.effective_user.id = 42 + + audit_logger = AsyncMock() + context = MagicMock() + context.user_data = {"model_override": "opus"} + context.bot_data = {"audit_logger": audit_logger} + + await orchestrator.agentic_model(update, context) + + audit_logger.log_command.assert_called_once_with( + user_id=42, command="model_reset", args=[], success=True, + ) + + + +async def test_agentic_model_rejects_long_name(agentic_settings, deps): + """/model with overly long name is rejected.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model " + "a" * 101 + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + assert "model_override" not in context.user_data + text = update.message.reply_text.call_args.args[0] + assert "Invalid" in text + + +async def test_model_override_passed_to_run_command(agentic_settings, deps): + """User model override is passed through to claude_integration.run_command.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + mock_response = MagicMock() + mock_response.session_id = "session-abc" + mock_response.content = "Hello!" + mock_response.tools_used = [] + + claude_integration = AsyncMock() + claude_integration.run_command = AsyncMock(return_value=mock_response) + + update = MagicMock() + update.effective_user.id = 123 + update.message.text = "Help me" + update.message.message_id = 1 + update.message.chat.send_action = AsyncMock() + update.message.reply_text = AsyncMock() + + progress_msg = AsyncMock() + progress_msg.delete = AsyncMock() + update.message.reply_text.return_value = progress_msg + + context = MagicMock() + context.user_data = {"model_override": "haiku"} + context.bot_data = { + "settings": agentic_settings, + "claude_integration": claude_integration, + "storage": None, + "rate_limiter": None, + "audit_logger": None, + } + + await orchestrator.agentic_text(update, context) + + claude_integration.run_command.assert_called_once() + call_kwargs = claude_integration.run_command.call_args.kwargs + assert call_kwargs["model_override"] == "haiku" + + +async def test_model_override_none_when_not_set(agentic_settings, deps): + """model_override is None when user hasn't set one.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + mock_response = MagicMock() + mock_response.session_id = "session-abc" + mock_response.content = "Hello!" + mock_response.tools_used = [] + + claude_integration = AsyncMock() + claude_integration.run_command = AsyncMock(return_value=mock_response) + + update = MagicMock() + update.effective_user.id = 123 + update.message.text = "Help me" + update.message.message_id = 1 + update.message.chat.send_action = AsyncMock() + update.message.reply_text = AsyncMock() + + progress_msg = AsyncMock() + progress_msg.delete = AsyncMock() + update.message.reply_text.return_value = progress_msg + + context = MagicMock() + context.user_data = {} + context.bot_data = { + "settings": agentic_settings, + "claude_integration": claude_integration, + "storage": None, + "rate_limiter": None, + "audit_logger": None, + } + + await orchestrator.agentic_text(update, context) + + call_kwargs = claude_integration.run_command.call_args.kwargs + assert call_kwargs["model_override"] is None