diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py
index ac1d5304..9a99c7b2 100644
--- a/src/bot/orchestrator.py
+++ b/src/bot/orchestrator.py
@@ -115,6 +115,7 @@ class MessageOrchestrator:
def __init__(self, settings: Settings, deps: Dict[str, Any]):
self.settings = settings
self.deps = deps
+ self._known_commands: frozenset[str] = frozenset()
def _inject_deps(self, handler: Callable) -> Callable: # type: ignore[type-arg]
"""Wrap handler to inject dependencies into context.bot_data."""
@@ -306,12 +307,16 @@ def _register_agentic_handlers(self, app: Application) -> None:
("new", self.agentic_new),
("status", self.agentic_status),
("verbose", self.agentic_verbose),
+ ("model", self.agentic_model),
("repo", self.agentic_repo),
("restart", command.restart_command),
]
if self.settings.enable_project_threads:
handlers.append(("sync_threads", command.sync_threads))
+ # Derive known commands dynamically — avoids drift when new commands are added
+ self._known_commands: frozenset[str] = frozenset(cmd for cmd, _ in handlers)
+
for cmd, handler in handlers:
app.add_handler(CommandHandler(cmd, self._inject_deps(handler)))
@@ -324,6 +329,19 @@ def _register_agentic_handlers(self, app: Application) -> None:
group=10,
)
+ # Unknown slash commands -> Claude (passthrough in agentic mode).
+ # Registered commands are handled by CommandHandlers in group 0
+ # (higher priority). This catches any /command not matched there
+ # and forwards it to Claude, while skipping known commands to
+ # avoid double-firing.
+ app.add_handler(
+ MessageHandler(
+ filters.COMMAND,
+ self._inject_deps(self._handle_unknown_command),
+ ),
+ group=10,
+ )
+
# File uploads -> Claude
app.add_handler(
MessageHandler(
@@ -415,6 +433,7 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg]
BotCommand("new", "Start a fresh session"),
BotCommand("status", "Show session status"),
BotCommand("verbose", "Set output verbosity (0/1/2)"),
+ BotCommand("model", "Switch Claude model"),
BotCommand("repo", "List repos / switch workspace"),
BotCommand("restart", "Restart the bot"),
]
@@ -578,6 +597,81 @@ async def agentic_verbose(
parse_mode="HTML",
)
+ def _get_model_override(self, context: ContextTypes.DEFAULT_TYPE) -> Optional[str]:
+ """Return per-user model override, or None to use the default."""
+ return context.user_data.get("model_override")
+
+ @staticmethod
+ def _resolve_model_display(
+ user_override: Optional[str],
+ config_model: Optional[str],
+ last_model: Optional[str] = None,
+ ) -> str:
+ """Return a human-readable model string showing what will actually be used."""
+ if user_override:
+ return user_override
+ if config_model:
+ return config_model
+ if last_model:
+ return last_model
+ return "unknown (send a message first to detect)"
+
+ async def agentic_model(
+ self, update: Update, context: ContextTypes.DEFAULT_TYPE
+ ) -> None:
+ """Set Claude model: /model [model_name]."""
+ args = update.message.text.split()[1:] if update.message.text else []
+ user_override = self._get_model_override(context)
+ last_model = context.user_data.get("last_model")
+ current = self._resolve_model_display(
+ user_override, self.settings.claude_model, last_model
+ )
+
+ if not args:
+ source = "user override" if user_override else (
+ "server config" if self.settings.claude_model else "Claude Code default"
+ )
+ await update.message.reply_text(
+ f"Model: {escape_html(current)} ({source})\n\n"
+ "Usage: /model model_name\n"
+ "Reset: /model default",
+ parse_mode="HTML",
+ )
+ return
+
+ model_name = args[0].strip()
+ if not model_name or len(model_name) > 100:
+ await update.message.reply_text("Invalid model name.")
+ return
+ audit_logger = context.bot_data.get("audit_logger")
+ if model_name == "default":
+ context.user_data.pop("model_override", None)
+ default = self._resolve_model_display(None, self.settings.claude_model)
+ await update.message.reply_text(
+ f"Model reset to {escape_html(default)}",
+ parse_mode="HTML",
+ )
+ if audit_logger:
+ await audit_logger.log_command(
+ user_id=update.effective_user.id,
+ command="model_reset",
+ args=[],
+ success=True,
+ )
+ else:
+ context.user_data["model_override"] = model_name
+ await update.message.reply_text(
+ f"Model set to {escape_html(model_name)}",
+ parse_mode="HTML",
+ )
+ if audit_logger:
+ await audit_logger.log_command(
+ user_id=update.effective_user.id,
+ command="model",
+ args=[model_name],
+ success=True,
+ )
+
def _format_verbose_progress(
self,
activity_log: List[Dict[str, Any]],
@@ -941,6 +1035,7 @@ async def agentic_text(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
+ model_override=self._get_model_override(context),
)
# New session created successfully — clear the one-shot flag
@@ -948,6 +1043,8 @@ async def agentic_text(
context.user_data["force_new_session"] = False
context.user_data["claude_session_id"] = claude_response.session_id
+ if claude_response.model:
+ context.user_data["last_model"] = claude_response.model
# Track directory changes
from .handlers.message import _update_working_directory_from_claude_response
@@ -1185,12 +1282,15 @@ async def agentic_document(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
+ model_override=self._get_model_override(context),
)
if force_new:
context.user_data["force_new_session"] = False
context.user_data["claude_session_id"] = claude_response.session_id
+ if claude_response.model:
+ context.user_data["last_model"] = claude_response.model
from .handlers.message import _update_working_directory_from_claude_response
@@ -1384,6 +1484,7 @@ async def _handle_agentic_media_message(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
+ model_override=self._get_model_override(context),
)
finally:
heartbeat.cancel()
@@ -1392,6 +1493,7 @@ async def _handle_agentic_media_message(
context.user_data["force_new_session"] = False
context.user_data["claude_session_id"] = claude_response.session_id
+ context.user_data["last_model"] = claude_response.model
from .handlers.message import _update_working_directory_from_claude_response
@@ -1450,6 +1552,25 @@ async def _handle_agentic_media_message(
except Exception as img_err:
logger.warning("Image send failed", error=str(img_err))
+ async def _handle_unknown_command(
+ self, update: Update, context: ContextTypes.DEFAULT_TYPE
+ ) -> None:
+ """Forward unknown slash commands to Claude in agentic mode.
+
+ Known commands are handled by their own CommandHandlers (group 0);
+ this handler fires for *every* COMMAND message in group 10 but
+ returns immediately when the command is registered, preventing
+ double execution.
+ """
+ msg = update.effective_message
+ if not msg or not msg.text:
+ return
+ cmd = msg.text.split()[0].lstrip("/").split("@")[0].lower()
+ if cmd in self._known_commands:
+ return # let the registered CommandHandler take care of it
+ # Forward unrecognised /commands to Claude as natural language
+ await self.agentic_text(update, context)
+
def _voice_unavailable_message(self) -> str:
"""Return provider-aware guidance when voice feature is unavailable."""
return (
diff --git a/src/bot/utils/draft_streamer.py b/src/bot/utils/draft_streamer.py
index 4fe1e709..6f263ba6 100644
--- a/src/bot/utils/draft_streamer.py
+++ b/src/bot/utils/draft_streamer.py
@@ -1,4 +1,9 @@
-"""Stream partial responses to Telegram via sendMessageDraft."""
+"""Stream partial responses to Telegram via sendMessageDraft.
+
+Uses Telegram Bot API 9.3+ sendMessageDraft for smooth token-by-token
+streaming in private chats. Falls back to editMessageText for group chats
+where sendMessageDraft is unavailable.
+"""
import secrets
import time
@@ -14,6 +19,17 @@
# Max tool lines shown in the draft header
_MAX_TOOL_LINES = 10
+# Minimum characters before sending the first draft (avoids triggering
+# push notifications with just a few characters)
+_MIN_INITIAL_CHARS = 20
+
+# Error messages that indicate the draft transport is unavailable
+_DRAFT_UNAVAILABLE_ERRORS = frozenset({
+ "TEXTDRAFT_PEER_INVALID",
+ "Bad Request: draft can't be sent",
+ "Bad Request: peer doesn't support drafts",
+})
+
def generate_draft_id() -> int:
"""Generate a non-zero positive draft ID.
@@ -30,18 +46,21 @@ class DraftStreamer:
The draft is composed of two sections:
1. **Tool header** — compact lines showing tool calls and reasoning
- snippets as they arrive, e.g. ``"📖 Read | 🔍 Grep | 🐚 Bash"``.
+ snippets as they arrive.
2. **Response body** — the actual assistant response text, streamed
token-by-token.
Both sections are combined into a single draft message and sent via
- ``sendMessageDraft``.
+ ``sendMessageDraft`` (private chats) or ``editMessageText`` (groups).
- Key design decisions:
+ Key design decisions (inspired by OpenClaw):
- Plain text drafts (no parse_mode) to avoid partial HTML/markdown errors.
- - Tail-truncation for messages >4096 chars: shows ``"\\u2026" + last 4093 chars``.
- - Self-disabling: any API error silently disables the streamer so the
- request continues with normal (non-streaming) delivery.
+ - Tail-truncation for messages >4096 chars.
+ - Min initial chars: waits for ~20 chars before first send.
+ - Anti-regressive: skips updates where text got shorter.
+ - Error classification: distinguishes draft-unavailable (fall back to edit)
+ from other errors (disable entirely).
+ - Self-disabling: persistent errors silently disable the streamer.
"""
def __init__(
@@ -50,7 +69,8 @@ def __init__(
chat_id: int,
draft_id: int,
message_thread_id: Optional[int] = None,
- throttle_interval: float = 0.3,
+ throttle_interval: float = 0.4,
+ is_private_chat: bool = True,
) -> None:
self.bot = bot
self.chat_id = chat_id
@@ -61,7 +81,18 @@ def __init__(
self._tool_lines: List[str] = []
self._accumulated_text = ""
self._last_send_time = 0.0
+ self._last_sent_length = 0 # anti-regressive tracking
self._enabled = True
+ self._error_count = 0
+ self._max_errors = 3
+
+ # Transport mode: "draft" for private chats, "edit" for groups
+ self._use_draft = is_private_chat
+ self._edit_message_id: Optional[int] = None # for edit-based transport
+
+ @property
+ def enabled(self) -> bool:
+ return self._enabled
async def append_tool(self, line: str) -> None:
"""Append a tool activity line and send a draft if throttled."""
@@ -87,10 +118,14 @@ async def flush(self) -> None:
return
if not self._accumulated_text and not self._tool_lines:
return
- await self._send_draft()
+ await self._send_draft(force=True)
+
+ def _compose_draft(self, is_final: bool = False) -> str:
+ """Combine tool header and response body into a single draft.
- def _compose_draft(self) -> str:
- """Combine tool header and response body into a single draft."""
+ Appends a blinking cursor ▌ during streaming (like OpenClaw)
+ to indicate the response is still being generated.
+ """
parts: List[str] = []
if self._tool_lines:
@@ -103,33 +138,157 @@ def _compose_draft(self) -> str:
if self._accumulated_text:
if parts:
parts.append("") # blank separator line
- parts.append(self._accumulated_text)
+ text = self._accumulated_text
+ if not is_final:
+ text += " ▌"
+ parts.append(text)
return "\n".join(parts)
- async def _send_draft(self) -> None:
- """Send the composed draft (tools + text) as a message draft."""
+ async def _send_draft(self, force: bool = False) -> None:
+ """Send the composed draft via the appropriate transport."""
draft_text = self._compose_draft()
if not draft_text.strip():
return
+ # Min initial chars gate (skip if force-flushing)
+ if not force and self._last_sent_length == 0:
+ if len(self._accumulated_text) < _MIN_INITIAL_CHARS and not self._tool_lines:
+ return
+
+ # Anti-regressive: skip if text got shorter (can happen with
+ # tool header rotation)
+ current_len = len(draft_text)
+ if not force and current_len < self._last_sent_length:
+ return
+
# Tail-truncate if over Telegram limit
if len(draft_text) > TELEGRAM_MAX_MESSAGE_LENGTH:
- draft_text = "\u2026" + draft_text[-(TELEGRAM_MAX_MESSAGE_LENGTH - 1) :]
+ draft_text = "\u2026" + draft_text[-(TELEGRAM_MAX_MESSAGE_LENGTH - 1):]
try:
+ if self._use_draft:
+ await self._send_via_draft(draft_text)
+ else:
+ await self._send_via_edit(draft_text)
+ self._last_send_time = time.time()
+ self._last_sent_length = current_len
+ self._error_count = 0 # reset on success
+ except telegram.error.BadRequest as e:
+ error_str = str(e)
+ if any(err in error_str for err in _DRAFT_UNAVAILABLE_ERRORS):
+ # Draft transport unavailable — fall back to edit
+ logger.info(
+ "Draft transport unavailable, falling back to edit",
+ chat_id=self.chat_id,
+ error=error_str,
+ )
+ self._use_draft = False
+ # Retry immediately with edit transport
+ try:
+ await self._send_via_edit(draft_text)
+ self._last_send_time = time.time()
+ self._last_sent_length = current_len
+ except Exception:
+ self._handle_error()
+ elif "Message is not modified" in error_str:
+ # Same content — not an error, just skip
+ self._last_send_time = time.time()
+ elif "Message to edit not found" in error_str:
+ # Message was deleted — re-create
+ self._edit_message_id = None
+ try:
+ await self._send_via_edit(draft_text)
+ self._last_send_time = time.time()
+ self._last_sent_length = current_len
+ except Exception:
+ self._handle_error()
+ else:
+ self._handle_error()
+ except Exception:
+ self._handle_error()
+
+ def _handle_error(self) -> None:
+ """Track errors and disable after too many."""
+ self._error_count += 1
+ if self._error_count >= self._max_errors:
+ logger.debug(
+ "Draft streamer disabled after repeated errors",
+ chat_id=self.chat_id,
+ error_count=self._error_count,
+ )
+ self._enabled = False
+
+ async def _send_via_draft(self, text: str) -> None:
+ """Send via sendMessageDraft (private chats)."""
+ kwargs = {
+ "chat_id": self.chat_id,
+ "text": text,
+ "draft_id": self.draft_id,
+ }
+ if self.message_thread_id is not None:
+ kwargs["message_thread_id"] = self.message_thread_id
+ logger.debug(
+ "Sending draft",
+ transport="draft",
+ text_len=len(text),
+ preview=text[:80],
+ )
+ await self.bot.send_message_draft(**kwargs)
+
+ async def _send_via_edit(self, text: str) -> None:
+ """Send via editMessageText (group chat fallback).
+
+ Creates a message on first call, then edits it on subsequent calls.
+ """
+ if self._edit_message_id is None:
+ # Send initial message
kwargs = {
"chat_id": self.chat_id,
- "text": draft_text,
- "draft_id": self.draft_id,
+ "text": text,
}
if self.message_thread_id is not None:
kwargs["message_thread_id"] = self.message_thread_id
- await self.bot.send_message_draft(**kwargs)
- self._last_send_time = time.time()
- except Exception:
- logger.debug(
- "Draft send failed, disabling streamer",
+ msg = await self.bot.send_message(**kwargs)
+ self._edit_message_id = msg.message_id
+ else:
+ await self.bot.edit_message_text(
+ text,
chat_id=self.chat_id,
+ message_id=self._edit_message_id,
)
- self._enabled = False
+
+ async def clear(self) -> None:
+ """Clear the draft bubble by sending an empty draft.
+
+ Call this before sending the final response message so the draft
+ bubble disappears cleanly instead of overlapping with the real message.
+ """
+ if not self._enabled:
+ return
+ try:
+ if self._use_draft:
+ # Send empty draft to dismiss the typing bubble
+ await self.bot.send_message_draft(
+ chat_id=self.chat_id,
+ text="",
+ draft_id=self.draft_id,
+ )
+ elif self._edit_message_id is not None:
+ # For edit-based transport, delete the preview message
+ try:
+ await self.bot.delete_message(
+ chat_id=self.chat_id,
+ message_id=self._edit_message_id,
+ )
+ except Exception:
+ pass
+ self._edit_message_id = None
+ except Exception:
+ pass
+ self._enabled = False
+
+ @property
+ def edit_message_id(self) -> Optional[int]:
+ """Return the message ID used by edit transport (for cleanup)."""
+ return self._edit_message_id
diff --git a/src/bot/utils/html_format.py b/src/bot/utils/html_format.py
index 2799a4ee..b84bbd58 100644
--- a/src/bot/utils/html_format.py
+++ b/src/bot/utils/html_format.py
@@ -1,40 +1,99 @@
"""HTML formatting utilities for Telegram messages.
-Telegram's HTML mode only requires escaping 3 characters (<, >, &) vs the many
-ambiguous Markdown v1 metacharacters, making it far more robust for rendering
-Claude's output which contains underscores, asterisks, brackets, etc.
+Telegram's HTML mode supports: , , , , is rejected. This walks the tag stack
+ and closes/reopens tags when it detects a mismatch.
"""
+ result = []
+ stack: List[str] = []
+ last_end = 0
+
+ for m in _TAG_RE.finditer(html):
+ # Append text before this tag
+ result.append(html[last_end:m.start()])
+ last_end = m.end()
+
+ is_close = m.group(1) == "/"
+ tag = m.group(2).lower()
+
+ # Only repair inline tags; skip , ,
+
, , ,
+
,
,, , etc. + if tag not in _INLINE_TAGS: + result.append(m.group(0)) + continue + + if not is_close: + stack.append(tag) + result.append(m.group(0)) + else: + if tag in stack: + # Close tags in reverse order up to the matching opener + idx = len(stack) - 1 - stack[::-1].index(tag) + tags_to_reopen = stack[idx + 1:] + # Close everything from top to idx + for t in reversed(stack[idx:]): + result.append(f"{t}>") + stack = stack[:idx] + # Reopen tags that were above the matched one + for t in tags_to_reopen: + result.append(f"<{t}>") + stack.append(t) + else: + # Orphan close tag — skip it + pass + + # Append remaining text + result.append(html[last_end:]) + + # Close any unclosed tags + for t in reversed(stack): + result.append(f"{t}>") + + return "".join(result) + + +def escape_html(text: str) -> str: + """Escape the 3 HTML-special characters for Telegram.""" return text.replace("&", "&").replace("<", "<").replace(">", ">") def markdown_to_telegram_html(text: str) -> str: """Convert Claude's markdown output to Telegram-compatible HTML. - Telegram supports a narrow HTML subset: , ,,", text) - # --- 6. Links: [text](url) --- + # --- 8. Links: [text](url) --- text = re.sub( r"\[([^\]]+)\]\(([^)]+)\)", r'\1', text, ) - # --- 7. Headers: # Header -> Header --- + # --- 9. Headers: # Header → Header --- text = re.sub(r"^#{1,6}\s+(.+)$", r"\1", text, flags=re.MULTILINE) - # --- 8. Strikethrough: ~~text~~ --- + # --- 10. Strikethrough: ~~text~~ --- text = re.sub(r"~~(.+?)~~", r", - ,, . This function converts common markdown patterns - to that subset while preserving code blocks verbatim. - - Order of operations: - 1. Extract fenced code blocks -> placeholders - 2. Extract inline code -> placeholders - 3. HTML-escape remaining text - 4. Convert bold (**text** / __text__) - 5. Convert italic (*text*, _text_ with word boundaries) - 6. Convert links [text](url) - 7. Convert headers (# Header -> Header) - 8. Convert strikethrough (~~text~~) - 9. Restore placeholders + Order of operations (early steps extract content into placeholders + to protect it from later regex passes): + + 0. Markdown tables → alignedblocks + 1. Fenced code blocks →+ 2. Inline code →+ 3. Blockquotes (> text) →+ 4. HTML-escape remaining text + 5. Horizontal rules (--- / ***) → ── separator + 6. Bold (**text** / __text__) + 7. Italic (*text* / _text_) + 8. Links [text](url) + 9. Headers (# Header → Header) + 10. Strikethrough (~~text~~) + 11. Unordered lists (- item / * item) + 12. Ordered lists (1. item) + 13. Restore placeholders """ placeholders: List[Tuple[str, str]] = [] placeholder_counter = 0 @@ -46,6 +105,52 @@ def _make_placeholder(html_content: str) -> str: placeholders.append((key, html_content)) return key + # --- 0. Extract markdown tables → monospaceblocks --- + def _replace_table(m: re.Match) -> str: # type: ignore[type-arg] + table_text = m.group(0) + lines = table_text.strip().split("\n") + rows = [] + for line in lines: + stripped = line.strip() + if not stripped.startswith("|"): + continue + if re.match(r"^\|[\s\-:|]+\|$", stripped): + continue + cells = [c.strip() for c in stripped.split("|")[1:-1]] + if cells: + rows.append(cells) + + if not rows: + return table_text + + num_cols = max(len(r) for r in rows) + col_widths = [0] * num_cols + for row in rows: + for i, cell in enumerate(row): + if i < num_cols: + col_widths[i] = max(col_widths[i], len(cell)) + + formatted_lines = [] + for row in rows: + parts = [] + for i in range(num_cols): + cell = row[i] if i < len(row) else "" + parts.append(cell.ljust(col_widths[i])) + formatted_lines.append(" │ ".join(parts)) + if len(formatted_lines) == 1: + sep_parts = ["─" * w for w in col_widths] + formatted_lines.append("─┼─".join(sep_parts)) + + pre_content = "\n".join(formatted_lines) + return _make_placeholder(f"{escape_html(pre_content)}") + + text = re.sub( + r"(?:^\|.+\|$\n?){2,}", + _replace_table, + text, + flags=re.MULTILINE, + ) + # --- 1. Extract fenced code blocks --- def _replace_fenced(m: re.Match) -> str: # type: ignore[type-arg] lang = m.group(1) or "" @@ -72,33 +177,72 @@ def _replace_inline_code(m: re.Match) -> str: # type: ignore[type-arg] text = re.sub(r"`([^`\n]+)`", _replace_inline_code, text) - # --- 3. HTML-escape remaining text --- + # --- 3. Blockquotes: > text →--- + def _replace_blockquote(m: re.Match) -> str: # type: ignore[type-arg] + block = m.group(0) + # Strip the leading > (and optional space) from each line + lines = [] + for line in block.split("\n"): + stripped = re.sub(r"^>\s?", "", line) + lines.append(stripped) + inner = "\n".join(lines) + # Recursively format the blockquote content + inner_html = escape_html(inner) + return _make_placeholder(f"{inner_html}") + + text = re.sub( + r"(?:^>.*$\n?)+", + _replace_blockquote, + text, + flags=re.MULTILINE, + ) + + # --- 4. HTML-escape remaining text --- text = escape_html(text) - # --- 4. Bold: **text** or __text__ --- + # --- 5. Horizontal rules: --- or *** or ___ → visual separator --- + text = re.sub( + r"^(?:---+|\*\*\*+|___+)\s*$", + "──────────", + text, + flags=re.MULTILINE, + ) + + # --- 6. Bold: **text** or __text__ --- text = re.sub(r"\*\*(.+?)\*\*", r"\1", text) text = re.sub(r"__(.+?)__", r"\1", text) - # --- 5. Italic: *text* (require non-space after/before) --- + # --- 7. Italic: *text* (require non-space after/before) --- text = re.sub(r"\*(\S.*?\S|\S)\*", r"\1", text) - # _text_ only at word boundaries (avoid my_var_name) text = re.sub(r"(?\1\1", text) - # --- 9. Restore placeholders --- + # --- 11. Unordered lists: - item / * item → bullet --- + text = re.sub(r"^[\-\*]\s+", "• ", text, flags=re.MULTILINE) + + # --- 12. Ordered lists: 1. item → keep number with period --- + # (Telegram has no, so just clean up the formatting) + text = re.sub(r"^(\d+)\.\s+", r"\1. ", text, flags=re.MULTILINE) + + # --- 13. Restore placeholders --- for key, html_content in placeholders: text = text.replace(key, html_content) + # --- 14. Repair HTML tag nesting --- + # Telegram is strict about nesting: ... is OK, + # but ...
is rejected. Fix any mismatches. + text = _repair_html_nesting(text) + return text diff --git a/src/claude/facade.py b/src/claude/facade.py index fcb2ada6..09545ff6 100644 --- a/src/claude/facade.py +++ b/src/claude/facade.py @@ -37,6 +37,7 @@ async def run_command( session_id: Optional[str] = None, on_stream: Optional[Callable[[StreamUpdate], None]] = None, force_new: bool = False, + model_override: Optional[str] = None, ) -> ClaudeResponse: """Run Claude Code command with full integration.""" logger.info( @@ -85,6 +86,7 @@ async def run_command( session_id=claude_session_id, continue_session=should_continue, stream_callback=on_stream, + model_override=model_override, ) except Exception as resume_error: # If resume failed (e.g., session expired/missing on Claude's side), @@ -109,6 +111,7 @@ async def run_command( session_id=None, continue_session=False, stream_callback=on_stream, + model_override=model_override, ) else: raise @@ -152,6 +155,7 @@ async def _execute( session_id: Optional[str] = None, continue_session: bool = False, stream_callback: Optional[Callable] = None, + model_override: Optional[str] = None, ) -> ClaudeResponse: """Execute command via SDK.""" return await self.sdk_manager.execute_command( @@ -160,6 +164,7 @@ async def _execute( session_id=session_id, continue_session=continue_session, stream_callback=stream_callback, + model_override=model_override, ) async def _find_resumable_session( diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index adf553f4..c6b5e591 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -53,6 +53,7 @@ class ClaudeResponse: is_error: bool = False error_type: Optional[str] = None tools_used: List[Dict[str, Any]] = field(default_factory=list) + model: Optional[str] = None @dataclass @@ -153,6 +154,7 @@ async def execute_command( session_id: Optional[str] = None, continue_session: bool = False, stream_callback: Optional[Callable[[StreamUpdate], None]] = None, + model_override: Optional[str] = None, ) -> ClaudeResponse: """Execute Claude Code command via SDK.""" start_time = asyncio.get_event_loop().time() @@ -197,7 +199,7 @@ def _stderr_callback(line: str) -> None: # Build Claude Agent options options = ClaudeAgentOptions( max_turns=self.config.claude_max_turns, - model=self.config.claude_model or None, + model=model_override or self.config.claude_model or None, max_budget_usd=self.config.claude_max_cost_per_request, cwd=str(working_directory), allowed_tools=sdk_allowed_tools, @@ -294,11 +296,12 @@ async def _run_client() -> None: timeout=self.config.claude_timeout_seconds, ) - # Extract cost, tools, and session_id from result message + # Extract cost, tools, session_id, and model from result message cost = 0.0 tools_used: List[Dict[str, Any]] = [] claude_session_id = None result_content = None + response_model: Optional[str] = None for message in messages: if isinstance(message, ResultMessage): cost = getattr(message, "total_cost_usd", 0.0) or 0.0 @@ -307,6 +310,8 @@ async def _run_client() -> None: current_time = asyncio.get_event_loop().time() for msg in messages: if isinstance(msg, AssistantMessage): + if not response_model: + response_model = getattr(msg, "model", None) msg_content = getattr(msg, "content", []) if msg_content and isinstance(msg_content, list): for block in msg_content: @@ -377,6 +382,7 @@ async def _run_client() -> None: ] ), tools_used=tools_used, + model=response_model, ) except asyncio.TimeoutError: diff --git a/src/config/settings.py b/src/config/settings.py index 77c34ea4..9ba77b69 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -210,22 +210,23 @@ class Settings(BaseSettings): ), ) - # Output verbosity (0=quiet, 1=normal, 2=detailed) + # Output verbosity (0=quiet, 1=normal, 2=detailed, 3=full) verbose_level: int = Field( 1, description=( "Bot output verbosity: 0=quiet (final response only), " "1=normal (tool names + reasoning), " - "2=detailed (tool inputs + longer reasoning)" + "2=detailed (tool inputs + longer reasoning), " + "3=full (tool results + complete commands)" ), ge=0, - le=2, + le=3, ) - # Streaming drafts (Telegram sendMessageDraft) + # Streaming drafts (Telegram sendMessageDraft / editMessageText) enable_stream_drafts: bool = Field( False, - description="Stream partial responses via sendMessageDraft (private chats only)", + description="Stream partial responses to Telegram in real-time", ) stream_draft_interval: float = Field( 0.3, diff --git a/tests/unit/test_claude/test_facade.py b/tests/unit/test_claude/test_facade.py index 666a2246..814522e2 100644 --- a/tests/unit/test_claude/test_facade.py +++ b/tests/unit/test_claude/test_facade.py @@ -269,6 +269,85 @@ async def test_retry_after_failure_still_skips_auto_resume( assert user_data["force_new_session"] is False +class TestModelOverride: + """Verify model_override is passed through to _execute.""" + + async def test_model_override_forwarded_to_execute(self, facade, session_manager): + """run_command passes model_override through to _execute.""" + project = Path("/test/project") + user_id = 123 + + with patch.object( + facade, + "_execute", + return_value=_make_mock_response(), + ) as mock_execute: + await facade.run_command( + prompt="hello", + working_directory=project, + user_id=user_id, + model_override="opus", + ) + + mock_execute.assert_called_once() + assert mock_execute.call_args.kwargs["model_override"] == "opus" + + async def test_model_override_none_by_default(self, facade, session_manager): + """run_command passes model_override=None when not specified.""" + project = Path("/test/project") + user_id = 123 + + with patch.object( + facade, + "_execute", + return_value=_make_mock_response(), + ) as mock_execute: + await facade.run_command( + prompt="hello", + working_directory=project, + user_id=user_id, + ) + + mock_execute.assert_called_once() + assert mock_execute.call_args.kwargs["model_override"] is None + + async def test_model_override_survives_session_retry(self, facade, session_manager): + """model_override is preserved when session resume fails and retries.""" + project = Path("/test/project") + user_id = 123 + + # Seed an existing session so resume is attempted + existing = ClaudeSession( + session_id="old-session", + user_id=user_id, + project_path=project, + created_at=datetime.utcnow(), + last_used=datetime.utcnow(), + ) + await session_manager.storage.save_session(existing) + session_manager.active_sessions[existing.session_id] = existing + + call_count = [0] + + async def _execute_side_effect(**kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("session expired") + return _make_mock_response() + + with patch.object(facade, "_execute", side_effect=_execute_side_effect): + await facade.run_command( + prompt="hello", + working_directory=project, + user_id=user_id, + session_id="old-session", + model_override="haiku", + ) + + # Both the initial call and retry should have model_override="haiku" + assert call_count[0] == 2 + + class TestEmptySessionIdWarning: """Verify facade warns when final session_id is empty.""" diff --git a/tests/unit/test_claude/test_sdk_integration.py b/tests/unit/test_claude/test_sdk_integration.py index 17ba58ab..0af40e46 100644 --- a/tests/unit/test_claude/test_sdk_integration.py +++ b/tests/unit/test_claude/test_sdk_integration.py @@ -696,6 +696,66 @@ async def test_claude_model_none_when_unset(self, tmp_path): assert len(captured_options) == 1 assert captured_options[0].model is None + async def test_model_override_takes_priority(self, tmp_path): + """Test that model_override overrides claude_model from config.""" + config = Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + claude_timeout_seconds=2, + claude_model="claude-sonnet-4-6", + ) + manager = ClaudeSDKManager(config) + + captured_options = [] + mock_factory = _mock_client_factory( + _make_assistant_message("Test response"), + _make_result_message(total_cost_usd=0.01), + capture_options=captured_options, + ) + + with patch( + "src.claude.sdk_integration.ClaudeSDKClient", side_effect=mock_factory + ): + await manager.execute_command( + prompt="Test prompt", + working_directory=tmp_path, + model_override="claude-opus-4-6", + ) + + assert len(captured_options) == 1 + assert captured_options[0].model == "claude-opus-4-6" + + async def test_model_override_none_uses_config(self, tmp_path): + """Test that model_override=None falls back to config model.""" + config = Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + claude_timeout_seconds=2, + claude_model="claude-haiku-4-5-20251001", + ) + manager = ClaudeSDKManager(config) + + captured_options = [] + mock_factory = _mock_client_factory( + _make_assistant_message("Test response"), + _make_result_message(total_cost_usd=0.01), + capture_options=captured_options, + ) + + with patch( + "src.claude.sdk_integration.ClaudeSDKClient", side_effect=mock_factory + ): + await manager.execute_command( + prompt="Test prompt", + working_directory=tmp_path, + model_override=None, + ) + + assert len(captured_options) == 1 + assert captured_options[0].model == "claude-haiku-4-5-20251001" + class TestClaudeMCPErrors: """Test MCP-specific error handling.""" diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index cc02b7c0..2e06eea2 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -82,8 +82,8 @@ def deps(): } -def test_agentic_registers_6_commands(agentic_settings, deps): - """Agentic mode registers start, new, status, verbose, repo, restart commands.""" +def test_agentic_registers_7_commands(agentic_settings, deps): + """Agentic mode registers start, new, status, verbose, model, repo, restart.""" orchestrator = MessageOrchestrator(agentic_settings, deps) app = MagicMock() app.add_handler = MagicMock() @@ -100,11 +100,12 @@ def test_agentic_registers_6_commands(agentic_settings, deps): ] commands = [h[0][0].commands for h in cmd_handlers] - assert len(cmd_handlers) == 6 + assert len(cmd_handlers) == 7 assert frozenset({"start"}) in commands assert frozenset({"new"}) in commands assert frozenset({"status"}) in commands assert frozenset({"verbose"}) in commands + assert frozenset({"model"}) in commands assert frozenset({"repo"}) in commands assert frozenset({"restart"}) in commands @@ -149,20 +150,20 @@ def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): if isinstance(call[0][0], CallbackQueryHandler) ] - # 4 message handlers (text, document, photo, voice) - assert len(msg_handlers) == 4 + # 5 message handlers (text, document, photo, voice, unknown commands passthrough) + assert len(msg_handlers) == 5 # 1 callback handler (for cd: only) assert len(cb_handlers) == 1 async def test_agentic_bot_commands(agentic_settings, deps): - """Agentic mode returns 6 bot commands.""" + """Agentic mode returns 7 bot commands.""" orchestrator = MessageOrchestrator(agentic_settings, deps) commands = await orchestrator.get_bot_commands() - assert len(commands) == 6 + assert len(commands) == 7 cmd_names = [c.command for c in commands] - assert cmd_names == ["start", "new", "status", "verbose", "repo", "restart"] + assert cmd_names == ["start", "new", "status", "verbose", "model", "repo", "restart"] async def test_classic_bot_commands(classic_settings, deps): @@ -926,3 +927,325 @@ async def help_command(update, context): assert called["value"] is False update.effective_message.reply_text.assert_called_once() + + +async def test_known_command_not_forwarded_to_claude(agentic_settings, deps): + """Known commands must NOT be forwarded to agentic_text.""" + from unittest.mock import AsyncMock, MagicMock, patch + + orchestrator = MessageOrchestrator(agentic_settings, deps) + app = MagicMock() + app.add_handler = MagicMock() + orchestrator.register_handlers(app) + + update = MagicMock() + update.effective_message.text = "/start" + context = MagicMock() + + with patch.object( + orchestrator, "agentic_text", new_callable=AsyncMock + ) as mock_claude: + await orchestrator._handle_unknown_command(update, context) + mock_claude.assert_not_called() + + +async def test_unknown_command_forwarded_to_claude(agentic_settings, deps): + """Unknown slash commands must be forwarded to agentic_text.""" + from unittest.mock import AsyncMock, MagicMock, patch + + orchestrator = MessageOrchestrator(agentic_settings, deps) + app = MagicMock() + app.add_handler = MagicMock() + orchestrator.register_handlers(app) + + update = MagicMock() + update.effective_message.text = "/workflow activate job-hunter" + context = MagicMock() + + with patch.object( + orchestrator, "agentic_text", new_callable=AsyncMock + ) as mock_claude: + await orchestrator._handle_unknown_command(update, context) + mock_claude.assert_called_once_with(update, context) + + +async def test_bot_suffixed_command_not_forwarded(agentic_settings, deps): + """Bot-suffixed known commands like /start@mybot must not reach Claude.""" + from unittest.mock import AsyncMock, MagicMock, patch + + orchestrator = MessageOrchestrator(agentic_settings, deps) + app = MagicMock() + app.add_handler = MagicMock() + orchestrator.register_handlers(app) + + update = MagicMock() + update.effective_message.text = "/start@mybot" + context = MagicMock() + + with patch.object( + orchestrator, "agentic_text", new_callable=AsyncMock + ) as mock_claude: + await orchestrator._handle_unknown_command(update, context) + mock_claude.assert_not_called() + + +# --- /model command tests --- + + +async def test_agentic_model_shows_last_model_when_unset(agentic_settings, deps): + """/model with no override shows the model from the last response.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {"last_model": "claude-opus-4-6"} + + await orchestrator.agentic_model(update, context) + + call_args = update.message.reply_text.call_args + text = call_args.args[0] + assert "claude-opus-4-6" in text + assert "Claude Code default" in text + + +async def test_agentic_model_shows_unknown_before_first_message(agentic_settings, deps): + """/model before any message shows unknown.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + call_args = update.message.reply_text.call_args + text = call_args.args[0] + assert "unknown" in text.lower() + assert call_args.kwargs.get("parse_mode") == "HTML" + + +async def test_agentic_model_shows_config_model(tmp_dir, deps): + """/model shows the server-configured model when CLAUDE_MODEL is set.""" + settings = create_test_config( + approved_directory=str(tmp_dir), + agentic_mode=True, + claude_model="claude-opus-4-6", + ) + orchestrator = MessageOrchestrator(settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + text = update.message.reply_text.call_args.args[0] + assert "claude-opus-4-6" in text + assert "server config" in text + + +async def test_agentic_model_shows_user_override(agentic_settings, deps): + """/model shows the user's override when one is set.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model" + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {"model_override": "haiku"} + + await orchestrator.agentic_model(update, context) + + text = update.message.reply_text.call_args.args[0] + assert "haiku" in text + assert "user override" in text + + +async def test_agentic_model_sets_override(agentic_settings, deps): + """/model sonnet sets the user's model override.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model sonnet" + update.message.reply_text = AsyncMock() + update.effective_user.id = 123 + + context = MagicMock() + context.user_data = {} + context.bot_data = {"audit_logger": AsyncMock()} + + await orchestrator.agentic_model(update, context) + + assert context.user_data["model_override"] == "sonnet" + text = update.message.reply_text.call_args.args[0] + assert "sonnet" in text + + +async def test_agentic_model_reset_to_default(agentic_settings, deps): + """/model default clears the user's model override.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model default" + update.message.reply_text = AsyncMock() + update.effective_user.id = 123 + + context = MagicMock() + context.user_data = {"model_override": "opus"} + context.bot_data = {"audit_logger": AsyncMock()} + + await orchestrator.agentic_model(update, context) + + assert "model_override" not in context.user_data + text = update.message.reply_text.call_args.args[0] + assert "reset" in text.lower() + + +async def test_agentic_model_audit_logged(agentic_settings, deps): + """/model sonnet logs the action to audit logger.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model sonnet" + update.message.reply_text = AsyncMock() + update.effective_user.id = 42 + + audit_logger = AsyncMock() + context = MagicMock() + context.user_data = {} + context.bot_data = {"audit_logger": audit_logger} + + await orchestrator.agentic_model(update, context) + + audit_logger.log_command.assert_called_once_with( + user_id=42, command="model", args=["sonnet"], success=True, + ) + + +async def test_agentic_model_reset_audit_logged(agentic_settings, deps): + """/model default logs as model_reset with empty args.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model default" + update.message.reply_text = AsyncMock() + update.effective_user.id = 42 + + audit_logger = AsyncMock() + context = MagicMock() + context.user_data = {"model_override": "opus"} + context.bot_data = {"audit_logger": audit_logger} + + await orchestrator.agentic_model(update, context) + + audit_logger.log_command.assert_called_once_with( + user_id=42, command="model_reset", args=[], success=True, + ) + + + +async def test_agentic_model_rejects_long_name(agentic_settings, deps): + """/model with overly long name is rejected.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + update = MagicMock() + update.message.text = "/model " + "a" * 101 + update.message.reply_text = AsyncMock() + + context = MagicMock() + context.user_data = {} + + await orchestrator.agentic_model(update, context) + + assert "model_override" not in context.user_data + text = update.message.reply_text.call_args.args[0] + assert "Invalid" in text + + +async def test_model_override_passed_to_run_command(agentic_settings, deps): + """User model override is passed through to claude_integration.run_command.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + mock_response = MagicMock() + mock_response.session_id = "session-abc" + mock_response.content = "Hello!" + mock_response.tools_used = [] + + claude_integration = AsyncMock() + claude_integration.run_command = AsyncMock(return_value=mock_response) + + update = MagicMock() + update.effective_user.id = 123 + update.message.text = "Help me" + update.message.message_id = 1 + update.message.chat.send_action = AsyncMock() + update.message.reply_text = AsyncMock() + + progress_msg = AsyncMock() + progress_msg.delete = AsyncMock() + update.message.reply_text.return_value = progress_msg + + context = MagicMock() + context.user_data = {"model_override": "haiku"} + context.bot_data = { + "settings": agentic_settings, + "claude_integration": claude_integration, + "storage": None, + "rate_limiter": None, + "audit_logger": None, + } + + await orchestrator.agentic_text(update, context) + + claude_integration.run_command.assert_called_once() + call_kwargs = claude_integration.run_command.call_args.kwargs + assert call_kwargs["model_override"] == "haiku" + + +async def test_model_override_none_when_not_set(agentic_settings, deps): + """model_override is None when user hasn't set one.""" + orchestrator = MessageOrchestrator(agentic_settings, deps) + + mock_response = MagicMock() + mock_response.session_id = "session-abc" + mock_response.content = "Hello!" + mock_response.tools_used = [] + + claude_integration = AsyncMock() + claude_integration.run_command = AsyncMock(return_value=mock_response) + + update = MagicMock() + update.effective_user.id = 123 + update.message.text = "Help me" + update.message.message_id = 1 + update.message.chat.send_action = AsyncMock() + update.message.reply_text = AsyncMock() + + progress_msg = AsyncMock() + progress_msg.delete = AsyncMock() + update.message.reply_text.return_value = progress_msg + + context = MagicMock() + context.user_data = {} + context.bot_data = { + "settings": agentic_settings, + "claude_integration": claude_integration, + "storage": None, + "rate_limiter": None, + "audit_logger": None, + } + + await orchestrator.agentic_text(update, context) + + call_kwargs = claude_integration.run_command.call_args.kwargs + assert call_kwargs["model_override"] is None