diff --git a/src/bot/core.py b/src/bot/core.py index 9fef0670..192bddd3 100644 --- a/src/bot/core.py +++ b/src/bot/core.py @@ -54,6 +54,10 @@ async def initialize(self) -> None: builder.defaults(Defaults(do_quote=self.settings.reply_quote)) builder.rate_limiter(AIORateLimiter(max_retries=1)) + from .update_processor import StopAwareUpdateProcessor + + builder.concurrent_updates(StopAwareUpdateProcessor()) + # Configure connection settings builder.connect_timeout(30) builder.read_timeout(30) diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index ac1d5304..a18248b0 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -8,6 +8,7 @@ import asyncio import re import time +from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable, Dict, List, Optional @@ -109,12 +110,23 @@ def _tool_icon(name: str) -> str: return _TOOL_ICONS.get(name, "\U0001f527") +@dataclass +class ActiveRequest: + """Tracks an in-flight Claude request so it can be interrupted.""" + + user_id: int + interrupt_event: asyncio.Event = field(default_factory=asyncio.Event) + interrupted: bool = False + progress_msg: Any = None # telegram Message object + + class MessageOrchestrator: """Routes messages based on mode. Single entry point for all Telegram updates.""" def __init__(self, settings: Settings, deps: Dict[str, Any]): self.settings = settings self.deps = deps + self._active_requests: Dict[int, ActiveRequest] = {} def _inject_deps(self, handler: Callable) -> Callable: # type: ignore[type-arg] """Wrap handler to inject dependencies into context.bot_data.""" @@ -344,6 +356,14 @@ def _register_agentic_handlers(self, app: Application) -> None: group=10, ) + # Stop button callback (must be before cd: handler) + app.add_handler( + CallbackQueryHandler( + self._inject_deps(self._handle_stop_callback), + pattern=r"^stop:", + ) + ) + # Only cd: callbacks (for project selection), scoped by pattern app.add_handler( CallbackQueryHandler( @@ -675,9 +695,11 @@ def _make_stream_callback( progress_msg: Any, tool_log: List[Dict[str, Any]], start_time: float, + reply_markup: Optional[InlineKeyboardMarkup] = None, mcp_images: Optional[List[ImageAttachment]] = None, approved_directory: Optional[Path] = None, draft_streamer: Optional[DraftStreamer] = None, + interrupt_event: Optional[asyncio.Event] = None, ) -> Optional[Callable[[StreamUpdate], Any]]: """Create a stream callback for verbose progress updates. @@ -701,6 +723,10 @@ def _make_stream_callback( last_edit_time = [0.0] # mutable container for closure async def _on_stream(update_obj: StreamUpdate) -> None: + # Stop all streaming activity after interrupt + if interrupt_event is not None and interrupt_event.is_set(): + return + # Intercept send_image_to_user MCP tool calls. # The SDK namespaces MCP tools as "mcp____", # so match both the bare name and the namespaced variant. @@ -765,7 +791,9 @@ async def _on_stream(update_obj: StreamUpdate) -> None: tool_log, verbose_level, start_time ) try: - await progress_msg.edit_text(new_text) + await progress_msg.edit_text( + new_text, reply_markup=reply_markup + ) except Exception: pass @@ -885,12 +913,30 @@ async def agentic_text( await chat.send_action("typing") verbose_level = self._get_verbose_level(context) - progress_msg = await update.message.reply_text("Working...") + + # Create Stop button and interrupt event + interrupt_event = asyncio.Event() + stop_kb = InlineKeyboardMarkup( + [[InlineKeyboardButton("Stop", callback_data=f"stop:{user_id}")]] + ) + progress_msg = await update.message.reply_text( + "Working...", reply_markup=stop_kb + ) + + # Register active request for stop callback + active_request = ActiveRequest( + user_id=user_id, + interrupt_event=interrupt_event, + progress_msg=progress_msg, + ) + self._active_requests[user_id] = active_request claude_integration = context.bot_data.get("claude_integration") if not claude_integration: + self._active_requests.pop(user_id, None) await progress_msg.edit_text( - "Claude integration not available. Check configuration." + "Claude integration not available. Check configuration.", + reply_markup=None, ) return @@ -924,9 +970,11 @@ async def agentic_text( progress_msg, tool_log, start_time, + reply_markup=stop_kb, mcp_images=mcp_images, approved_directory=self.settings.approved_directory, draft_streamer=draft_streamer, + interrupt_event=interrupt_event, ) # Independent typing heartbeat — stays alive even with no stream events @@ -941,6 +989,7 @@ async def agentic_text( session_id=session_id, on_stream=on_stream, force_new=force_new, + interrupt_event=interrupt_event, ) # New session created successfully — clear the one-shot flag @@ -974,9 +1023,14 @@ async def agentic_text( from .utils.formatting import ResponseFormatter formatter = ResponseFormatter(self.settings) - formatted_messages = formatter.format_claude_response( - claude_response.content - ) + + response_content = claude_response.content + if claude_response.interrupted: + response_content = ( + response_content or "" + ) + "\n\n_(Interrupted by user)_" + + formatted_messages = formatter.format_claude_response(response_content) except Exception as e: success = False @@ -989,6 +1043,7 @@ async def agentic_text( ] finally: heartbeat.cancel() + self._active_requests.pop(user_id, None) if draft_streamer: try: await draft_streamer.flush() @@ -1555,6 +1610,37 @@ async def agentic_repo( reply_markup=reply_markup, ) + async def _handle_stop_callback( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Handle stop: callbacks — interrupt a running Claude request.""" + query = update.callback_query + target_user_id = int(query.data.split(":", 1)[1]) + + # Only the requesting user can stop their own request + if query.from_user.id != target_user_id: + await query.answer( + "Only the requesting user can stop this.", show_alert=True + ) + return + + active = self._active_requests.get(target_user_id) + if not active: + await query.answer("Already completed.", show_alert=False) + return + if active.interrupted: + await query.answer("Already stopping...", show_alert=False) + return + + active.interrupt_event.set() + active.interrupted = True + await query.answer("Stopping...", show_alert=False) + + try: + await active.progress_msg.edit_text("Stopping...", reply_markup=None) + except Exception: + pass + async def _agentic_callback( self, update: Update, context: ContextTypes.DEFAULT_TYPE ) -> None: diff --git a/src/bot/update_processor.py b/src/bot/update_processor.py new file mode 100644 index 00000000..14cb71f7 --- /dev/null +++ b/src/bot/update_processor.py @@ -0,0 +1,70 @@ +"""Selective-concurrency update processor for PTB. + +Regular updates (messages, commands) process sequentially -- one at a time. +Priority callbacks (stop:*) bypass the queue and run immediately so they can +interrupt the currently-running handler. +""" + +import asyncio +from typing import Any, Awaitable + +from telegram import Update +from telegram.ext._baseupdateprocessor import BaseUpdateProcessor + + +class StopAwareUpdateProcessor(BaseUpdateProcessor): + """Update processor that lets priority callbacks bypass sequential processing. + + PTB calls ``process_update(update, coroutine)`` for every incoming update. + The base class holds a semaphore (max 256) then calls our + ``do_process_update()``. + + For priority callbacks (``stop:*``): we just ``await coroutine`` -- runs + immediately. + For everything else: we acquire ``_sequential_lock`` first -- only one + runs at a time. + + A stop callback arrives while a text handler holds the lock -> stop + callback runs concurrently -> fires the ``asyncio.Event`` -> the watcher + task inside ``execute_command()`` calls ``client.interrupt()`` -> Claude + stops -> ``run_command()`` returns -> handler finishes -> lock released. + """ + + _PRIORITY_PREFIXES = ("stop:",) + + def __init__(self) -> None: + # High limit so priority callbacks are never blocked by semaphore + super().__init__(max_concurrent_updates=256) + self._sequential_lock = asyncio.Lock() + + @classmethod + def _is_priority_callback(cls, update: object) -> bool: + """Return True if the update is a priority callback query.""" + if not isinstance(update, Update): + return False + cb = update.callback_query + return ( + cb is not None + and cb.data is not None + and cb.data.startswith(cls._PRIORITY_PREFIXES) + ) + + async def do_process_update( + self, + update: object, + coroutine: Awaitable[Any], + ) -> None: + """Process an update, applying sequential lock for non-priority updates.""" + if self._is_priority_callback(update): + # Run immediately -- no sequential lock + await coroutine + else: + # One at a time for everything else + async with self._sequential_lock: + await coroutine + + async def initialize(self) -> None: + """Initialize the processor (no-op).""" + + async def shutdown(self) -> None: + """Shutdown the processor (no-op).""" diff --git a/src/claude/facade.py b/src/claude/facade.py index fcb2ada6..5c7276eb 100644 --- a/src/claude/facade.py +++ b/src/claude/facade.py @@ -3,6 +3,7 @@ Provides simple interface for bot handlers. """ +import asyncio from pathlib import Path from typing import Any, Callable, Dict, List, Optional @@ -37,6 +38,7 @@ async def run_command( session_id: Optional[str] = None, on_stream: Optional[Callable[[StreamUpdate], None]] = None, force_new: bool = False, + interrupt_event: Optional["asyncio.Event"] = None, ) -> ClaudeResponse: """Run Claude Code command with full integration.""" logger.info( @@ -85,6 +87,7 @@ async def run_command( session_id=claude_session_id, continue_session=should_continue, stream_callback=on_stream, + interrupt_event=interrupt_event, ) except Exception as resume_error: # If resume failed (e.g., session expired/missing on Claude's side), @@ -109,6 +112,7 @@ async def run_command( session_id=None, continue_session=False, stream_callback=on_stream, + interrupt_event=interrupt_event, ) else: raise @@ -152,6 +156,7 @@ async def _execute( session_id: Optional[str] = None, continue_session: bool = False, stream_callback: Optional[Callable] = None, + interrupt_event: Optional[asyncio.Event] = None, ) -> ClaudeResponse: """Execute command via SDK.""" return await self.sdk_manager.execute_command( @@ -160,6 +165,7 @@ async def _execute( session_id=session_id, continue_session=continue_session, stream_callback=stream_callback, + interrupt_event=interrupt_event, ) async def _find_resumable_session( diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index adf553f4..ab9c4046 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -53,6 +53,7 @@ class ClaudeResponse: is_error: bool = False error_type: Optional[str] = None tools_used: List[Dict[str, Any]] = field(default_factory=list) + interrupted: bool = False @dataclass @@ -153,6 +154,7 @@ async def execute_command( session_id: Optional[str] = None, continue_session: bool = False, stream_callback: Optional[Callable[[StreamUpdate], None]] = None, + interrupt_event: Optional[asyncio.Event] = None, ) -> ClaudeResponse: """Execute Claude Code command via SDK.""" start_time = asyncio.get_event_loop().time() @@ -240,24 +242,14 @@ def _stderr_callback(line: str) -> None: # Collect messages via ClaudeSDKClient messages: List[Message] = [] + interrupted = False async def _run_client() -> None: - # Use connect(None) + query(prompt) pattern because - # can_use_tool requires the prompt as AsyncIterable, not - # a plain string. connect(None) uses an empty async - # iterable internally, satisfying the requirement. client = ClaudeSDKClient(options) try: await client.connect() await client.query(prompt) - # Iterate over raw messages and parse them ourselves - # so that MessageParseError (e.g. from rate_limit_event) - # doesn't kill the underlying async generator. When - # parse_message raises inside the SDK's receive_messages() - # generator, Python terminates that generator permanently, - # causing us to lose all subsequent messages including - # the ResultMessage. async for raw_data in client._query.receive_messages(): try: message = parse_message(raw_data) @@ -288,11 +280,43 @@ async def _run_client() -> None: finally: await client.disconnect() - # Execute with timeout - await asyncio.wait_for( - _run_client(), - timeout=self.config.claude_timeout_seconds, - ) + # Execute: race client against timeout and optional interrupt + run_task = asyncio.create_task(_run_client()) + + interrupt_watcher: Optional["asyncio.Task[None]"] = None + if interrupt_event is not None: + + async def _cancel_on_interrupt() -> None: + nonlocal interrupted + await interrupt_event.wait() + interrupted = True + run_task.cancel() + + interrupt_watcher = asyncio.create_task(_cancel_on_interrupt()) + + try: + await asyncio.wait_for( + asyncio.shield(run_task), + timeout=self.config.claude_timeout_seconds, + ) + except asyncio.CancelledError: + if not interrupted: + raise + # Interrupt cancelled the task — wait for cleanup + try: + await run_task + except asyncio.CancelledError: + pass + except asyncio.TimeoutError: + run_task.cancel() + try: + await run_task + except asyncio.CancelledError: + pass + raise + finally: + if interrupt_watcher is not None: + interrupt_watcher.cancel() # Extract cost, tools, and session_id from result message cost = 0.0 @@ -377,6 +401,7 @@ async def _run_client() -> None: ] ), tools_used=tools_used, + interrupted=interrupted, ) except asyncio.TimeoutError: diff --git a/tests/unit/test_bot/test_stop_button.py b/tests/unit/test_bot/test_stop_button.py new file mode 100644 index 00000000..bb167d67 --- /dev/null +++ b/tests/unit/test_bot/test_stop_button.py @@ -0,0 +1,451 @@ +"""Tests for the Stop button (interrupt) feature. + +Covers: +- Stop button appears on progress message +- Stop callback fires interrupt event +- Non-owner cannot stop another user's request +- Stop after completion (graceful handling) +- Double-stop prevention +- SDK execute_command with interrupt_event triggers client.interrupt() +- Partial response preserved after interrupt +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock +from telegram import InlineKeyboardMarkup + +from src.bot.orchestrator import ActiveRequest, MessageOrchestrator +from src.claude.sdk_integration import ClaudeResponse, ClaudeSDKManager +from src.config.settings import Settings + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def settings(tmp_path): + return Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + agentic_mode=True, + ) + + +@pytest.fixture +def orchestrator(settings): + deps: dict = {} + return MessageOrchestrator(settings, deps) + + +@pytest.fixture +def sdk_manager(tmp_path): + config = Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + claude_timeout_seconds=5, + ) + return ClaudeSDKManager(config) + + +# --------------------------------------------------------------------------- +# ActiveRequest / orchestrator unit tests +# --------------------------------------------------------------------------- + + +class TestActiveRequest: + """Basic ActiveRequest dataclass behaviour.""" + + def test_defaults(self): + req = ActiveRequest(user_id=42) + assert req.user_id == 42 + assert isinstance(req.interrupt_event, asyncio.Event) + assert not req.interrupt_event.is_set() + assert req.interrupted is False + assert req.progress_msg is None + + +class TestStopCallback: + """_handle_stop_callback routing logic.""" + + async def test_owner_can_stop(self, orchestrator): + """Clicking Stop fires the interrupt event.""" + event = asyncio.Event() + progress_msg = AsyncMock() + active = ActiveRequest( + user_id=100, interrupt_event=event, progress_msg=progress_msg + ) + orchestrator._active_requests[100] = active + + query = AsyncMock() + query.data = "stop:100" + query.from_user = MagicMock() + query.from_user.id = 100 + + update = MagicMock() + update.callback_query = query + + context = MagicMock() + context.bot_data = {} + + await orchestrator._handle_stop_callback(update, context) + + assert event.is_set() + assert active.interrupted is True + query.answer.assert_awaited_once_with("Stopping...", show_alert=False) + progress_msg.edit_text.assert_awaited_once_with( + "Stopping...", reply_markup=None + ) + + async def test_non_owner_blocked(self, orchestrator): + """A different user cannot stop someone else's request.""" + event = asyncio.Event() + active = ActiveRequest( + user_id=100, interrupt_event=event, progress_msg=AsyncMock() + ) + orchestrator._active_requests[100] = active + + query = AsyncMock() + query.data = "stop:100" + query.from_user = MagicMock() + query.from_user.id = 999 # different user + + update = MagicMock() + update.callback_query = query + context = MagicMock() + context.bot_data = {} + + await orchestrator._handle_stop_callback(update, context) + + assert not event.is_set() + assert not active.interrupted + query.answer.assert_awaited_once_with( + "Only the requesting user can stop this.", show_alert=True + ) + + async def test_stop_after_completion(self, orchestrator): + """Clicking Stop after request completed is handled gracefully.""" + query = AsyncMock() + query.data = "stop:100" + query.from_user = MagicMock() + query.from_user.id = 100 + + update = MagicMock() + update.callback_query = query + context = MagicMock() + context.bot_data = {} + + # No active request registered + await orchestrator._handle_stop_callback(update, context) + + query.answer.assert_awaited_once_with("Already completed.", show_alert=False) + + async def test_double_stop_prevention(self, orchestrator): + """Second click shows 'Already stopping...' instead of re-firing.""" + event = asyncio.Event() + active = ActiveRequest( + user_id=100, interrupt_event=event, progress_msg=AsyncMock() + ) + active.interrupted = True # already stopped once + orchestrator._active_requests[100] = active + + query = AsyncMock() + query.data = "stop:100" + query.from_user = MagicMock() + query.from_user.id = 100 + + update = MagicMock() + update.callback_query = query + context = MagicMock() + context.bot_data = {} + + await orchestrator._handle_stop_callback(update, context) + + query.answer.assert_awaited_once_with("Already stopping...", show_alert=False) + + +class TestStopButtonOnProgress: + """Verify the Stop button is attached to progress messages.""" + + async def test_progress_message_has_stop_button(self, orchestrator, settings): + """agentic_text sends progress_msg with Stop keyboard.""" + user_id = 42 + mock_response = ClaudeResponse( + content="Done", + session_id="s1", + cost=0.01, + duration_ms=100, + num_turns=1, + ) + + update = MagicMock() + update.effective_user = MagicMock() + update.effective_user.id = user_id + update.message = AsyncMock() + update.message.message_id = 1 + update.message.text = "test" + update.message.chat = AsyncMock() + update.message.chat.send_action = AsyncMock() + + progress_msg = AsyncMock() + progress_msg.delete = AsyncMock() + update.message.reply_text = AsyncMock(return_value=progress_msg) + update.effective_message = update.message + + context = MagicMock() + context.user_data = {"current_directory": settings.approved_directory} + context.bot_data = { + "claude_integration": AsyncMock(), + "rate_limiter": None, + "audit_logger": None, + "storage": None, + } + context.bot_data["claude_integration"].run_command = AsyncMock( + return_value=mock_response + ) + + with patch( + "src.bot.orchestrator.MessageOrchestrator._start_typing_heartbeat" + ) as mock_hb: + mock_task = AsyncMock() + mock_task.cancel = MagicMock() + mock_hb.return_value = mock_task + + with patch( + "src.bot.handlers.message._update_working_directory_from_claude_response" + ): + with patch("src.bot.utils.formatting.ResponseFormatter") as MockFmt: + MockFmt.return_value.format_claude_response.return_value = [] + await orchestrator.agentic_text(update, context) + + # First reply_text call should be the progress message with Stop button + first_call = update.message.reply_text.call_args_list[0] + assert first_call.args[0] == "Working..." + reply_markup = first_call.kwargs.get("reply_markup") + assert reply_markup is not None + assert isinstance(reply_markup, InlineKeyboardMarkup) + button = reply_markup.inline_keyboard[0][0] + assert button.text == "Stop" + assert button.callback_data == f"stop:{user_id}" + + async def test_active_request_cleaned_up_after_success( + self, orchestrator, settings + ): + """_active_requests is cleared in the finally block.""" + user_id = 42 + mock_response = ClaudeResponse( + content="Done", + session_id="s1", + cost=0.01, + duration_ms=100, + num_turns=1, + ) + + update = MagicMock() + update.effective_user = MagicMock() + update.effective_user.id = user_id + update.message = AsyncMock() + update.message.message_id = 1 + update.message.text = "test" + update.message.chat = AsyncMock() + update.message.chat.send_action = AsyncMock() + + progress_msg = AsyncMock() + progress_msg.delete = AsyncMock() + update.message.reply_text = AsyncMock(return_value=progress_msg) + update.effective_message = update.message + + context = MagicMock() + context.user_data = {"current_directory": settings.approved_directory} + context.bot_data = { + "claude_integration": AsyncMock(), + "rate_limiter": None, + "audit_logger": None, + "storage": None, + } + context.bot_data["claude_integration"].run_command = AsyncMock( + return_value=mock_response + ) + + with patch( + "src.bot.orchestrator.MessageOrchestrator._start_typing_heartbeat" + ) as mock_hb: + mock_task = AsyncMock() + mock_task.cancel = MagicMock() + mock_hb.return_value = mock_task + with patch( + "src.bot.handlers.message._update_working_directory_from_claude_response" + ): + with patch("src.bot.utils.formatting.ResponseFormatter") as MockFmt: + MockFmt.return_value.format_claude_response.return_value = [] + await orchestrator.agentic_text(update, context) + + assert user_id not in orchestrator._active_requests + + async def test_active_request_cleaned_up_after_error(self, orchestrator, settings): + """_active_requests is cleared even when run_command raises.""" + user_id = 42 + + update = MagicMock() + update.effective_user = MagicMock() + update.effective_user.id = user_id + update.message = AsyncMock() + update.message.message_id = 1 + update.message.text = "test" + update.message.chat = AsyncMock() + update.message.chat.send_action = AsyncMock() + + progress_msg = AsyncMock() + progress_msg.delete = AsyncMock() + update.message.reply_text = AsyncMock(return_value=progress_msg) + update.effective_message = update.message + + context = MagicMock() + context.user_data = {"current_directory": settings.approved_directory} + context.bot_data = { + "claude_integration": AsyncMock(), + "rate_limiter": None, + "audit_logger": None, + "storage": None, + } + context.bot_data["claude_integration"].run_command = AsyncMock( + side_effect=RuntimeError("boom") + ) + + with patch( + "src.bot.orchestrator.MessageOrchestrator._start_typing_heartbeat" + ) as mock_hb: + mock_task = AsyncMock() + mock_task.cancel = MagicMock() + mock_hb.return_value = mock_task + with patch( + "src.bot.handlers.message._format_error_message", return_value="err" + ): + await orchestrator.agentic_text(update, context) + + assert user_id not in orchestrator._active_requests + + +# --------------------------------------------------------------------------- +# SDK-level interrupt tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _patch_parse_message(): + """Patch parse_message as identity so mocks can yield typed Message objects.""" + with patch("src.claude.sdk_integration.parse_message", side_effect=lambda x: x): + yield + + +def _mock_client(*messages, delay: float = 0.0): + """Create a mock ClaudeSDKClient that yields messages with optional delay.""" + client = AsyncMock() + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=False) + client.query = AsyncMock() + client.interrupt = AsyncMock() + + async def receive_raw_messages(): + for msg in messages: + if delay: + await asyncio.sleep(delay) + yield msg + + query_mock = AsyncMock() + query_mock.receive_messages = receive_raw_messages + client._query = query_mock + + return client + + +class TestSDKInterrupt: + """Test interrupt_event cancels the run task in execute_command.""" + + async def test_interrupt_event_cancels_task(self, sdk_manager, tmp_path): + """Setting the interrupt_event should cancel the client task.""" + assistant_msg = AssistantMessage( + content=[TextBlock(text="partial")], + model="claude-sonnet-4-20250514", + ) + result_msg = ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=80, + is_error=False, + num_turns=1, + session_id="s1", + total_cost_usd=0.01, + result="partial", + ) + + # First message arrives at t=0.05, second at t=0.5 + # Interrupt fires at t=0.15 (after first msg, during wait for second) + client = _mock_client(assistant_msg, result_msg, delay=0.05) + + interrupt_event = asyncio.Event() + + async def set_interrupt_soon(): + await asyncio.sleep(0.08) + interrupt_event.set() + + with patch("src.claude.sdk_integration.ClaudeSDKClient", return_value=client): + asyncio.create_task(set_interrupt_soon()) + response = await sdk_manager.execute_command( + prompt="test", + working_directory=tmp_path, + interrupt_event=interrupt_event, + ) + + assert response.interrupted is True + # Partial content from assistant message (ResultMessage never arrived) + assert response.content == "partial" + + async def test_no_interrupt_event_normal_flow(self, sdk_manager, tmp_path): + """Without interrupt_event, response.interrupted should be False.""" + result_msg = ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=80, + is_error=False, + num_turns=1, + session_id="s1", + total_cost_usd=0.01, + result="done", + ) + client = _mock_client(result_msg) + + with patch("src.claude.sdk_integration.ClaudeSDKClient", return_value=client): + response = await sdk_manager.execute_command( + prompt="test", + working_directory=tmp_path, + ) + + assert response.interrupted is False + assert response.content == "done" + + +class TestClaudeResponseInterruptedField: + """Test the interrupted field on ClaudeResponse.""" + + def test_default_false(self): + resp = ClaudeResponse( + content="x", session_id="s", cost=0.0, duration_ms=0, num_turns=1 + ) + assert resp.interrupted is False + + def test_explicit_true(self): + resp = ClaudeResponse( + content="x", + session_id="s", + cost=0.0, + duration_ms=0, + num_turns=1, + interrupted=True, + ) + assert resp.interrupted is True diff --git a/tests/unit/test_bot/test_update_processor.py b/tests/unit/test_bot/test_update_processor.py new file mode 100644 index 00000000..ebeb755c --- /dev/null +++ b/tests/unit/test_bot/test_update_processor.py @@ -0,0 +1,195 @@ +"""Tests for StopAwareUpdateProcessor. + +Covers: +- Stop callbacks bypass the sequential lock (run immediately) +- Regular updates are serialized (only one at a time) +- Non-stop callbacks (e.g. cd:) go through the sequential lock +""" + +import asyncio +from unittest.mock import MagicMock + +from telegram import CallbackQuery, Update + +from src.bot.update_processor import StopAwareUpdateProcessor + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_update(callback_data: str | None = None) -> Update: + """Build a minimal Update mock with optional callback_query data.""" + update = MagicMock(spec=Update) + if callback_data is not None: + cb = MagicMock(spec=CallbackQuery) + cb.data = callback_data + update.callback_query = cb + else: + update.callback_query = None + return update + + +# --------------------------------------------------------------------------- +# _is_priority_callback +# --------------------------------------------------------------------------- + + +class TestIsPriorityCallback: + def test_stop_callback_detected(self): + update = _make_update("stop:123") + assert StopAwareUpdateProcessor._is_priority_callback(update) is True + + def test_cd_callback_not_priority(self): + update = _make_update("cd:my_project") + assert StopAwareUpdateProcessor._is_priority_callback(update) is False + + def test_no_callback_query(self): + update = _make_update(None) + assert StopAwareUpdateProcessor._is_priority_callback(update) is False + + def test_non_update_object(self): + assert StopAwareUpdateProcessor._is_priority_callback("not an update") is False + + def test_callback_with_none_data(self): + update = MagicMock(spec=Update) + cb = MagicMock(spec=CallbackQuery) + cb.data = None + update.callback_query = cb + assert StopAwareUpdateProcessor._is_priority_callback(update) is False + + +# --------------------------------------------------------------------------- +# do_process_update — concurrency tests +# --------------------------------------------------------------------------- + + +class TestStopCallbackBypassesLock: + async def test_stop_callback_runs_while_lock_held(self): + """A stop callback runs immediately even when sequential lock is held.""" + processor = StopAwareUpdateProcessor() + + execution_order: list[str] = [] + lock_acquired = asyncio.Event() + stop_done = asyncio.Event() + + async def slow_coroutine(): + execution_order.append("regular_start") + lock_acquired.set() + # Wait for the stop callback to finish + await stop_done.wait() + execution_order.append("regular_end") + + async def stop_coroutine(): + execution_order.append("stop_start") + execution_order.append("stop_end") + stop_done.set() + + regular_update = _make_update(None) + stop_update = _make_update("stop:42") + + # Start the regular update (acquires lock) + regular_task = asyncio.create_task( + processor.do_process_update(regular_update, slow_coroutine()) + ) + + # Wait for the regular update to hold the lock + await lock_acquired.wait() + + # Now fire the stop callback — should run immediately + stop_task = asyncio.create_task( + processor.do_process_update(stop_update, stop_coroutine()) + ) + + await asyncio.gather(regular_task, stop_task) + + # Stop ran WHILE regular was still in progress + assert execution_order == [ + "regular_start", + "stop_start", + "stop_end", + "regular_end", + ] + + +class TestRegularUpdatesSequential: + async def test_two_regular_updates_do_not_overlap(self): + """Two regular updates are serialized by the sequential lock.""" + processor = StopAwareUpdateProcessor() + + execution_log: list[str] = [] + + async def coroutine_a(): + execution_log.append("a_start") + await asyncio.sleep(0.05) + execution_log.append("a_end") + + async def coroutine_b(): + execution_log.append("b_start") + await asyncio.sleep(0.05) + execution_log.append("b_end") + + update_a = _make_update(None) + update_b = _make_update(None) + + task_a = asyncio.create_task( + processor.do_process_update(update_a, coroutine_a()) + ) + # Yield so task_a starts and acquires the lock + await asyncio.sleep(0) + + task_b = asyncio.create_task( + processor.do_process_update(update_b, coroutine_b()) + ) + + await asyncio.gather(task_a, task_b) + + # b should not start until a has finished + assert execution_log == ["a_start", "a_end", "b_start", "b_end"] + + +class TestNonStopCallbackSequential: + async def test_cd_callback_goes_through_sequential_lock(self): + """Non-stop callbacks (cd:*) are treated as regular updates.""" + processor = StopAwareUpdateProcessor() + + execution_log: list[str] = [] + + async def regular_coroutine(): + execution_log.append("regular_start") + await asyncio.sleep(0.05) + execution_log.append("regular_end") + + async def cd_coroutine(): + execution_log.append("cd_start") + execution_log.append("cd_end") + + regular_update = _make_update(None) + cd_update = _make_update("cd:my_project") + + task_regular = asyncio.create_task( + processor.do_process_update(regular_update, regular_coroutine()) + ) + await asyncio.sleep(0) + + task_cd = asyncio.create_task( + processor.do_process_update(cd_update, cd_coroutine()) + ) + + await asyncio.gather(task_regular, task_cd) + + # cd callback waited for regular to finish + assert execution_log == [ + "regular_start", + "regular_end", + "cd_start", + "cd_end", + ] + + +class TestInitializeShutdown: + async def test_initialize_and_shutdown_are_noop(self): + """initialize() and shutdown() should not raise.""" + processor = StopAwareUpdateProcessor() + await processor.initialize() + await processor.shutdown() diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index cc02b7c0..320f54ae 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -151,8 +151,8 @@ def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): # 4 message handlers (text, document, photo, voice) assert len(msg_handlers) == 4 - # 1 callback handler (for cd: only) - assert len(cb_handlers) == 1 + # 2 callback handlers (stop: + cd:) + assert len(cb_handlers) == 2 async def test_agentic_bot_commands(agentic_settings, deps): @@ -338,10 +338,14 @@ async def test_agentic_callback_scoped_to_cd_pattern(agentic_settings, deps): if isinstance(call[0][0], CallbackQueryHandler) ] - assert len(cb_handlers) == 1 - # The pattern attribute should match cd: prefixed data - assert cb_handlers[0].pattern is not None - assert cb_handlers[0].pattern.match("cd:my_project") + assert len(cb_handlers) == 2 + # Find the cd: handler by pattern + cd_handler = [h for h in cb_handlers if h.pattern and h.pattern.match("cd:x")] + assert len(cd_handler) == 1 + assert cd_handler[0].pattern.match("cd:my_project") + # Also has a stop: handler + stop_handler = [h for h in cb_handlers if h.pattern and h.pattern.match("stop:1")] + assert len(stop_handler) == 1 async def test_agentic_document_rejects_large_files(agentic_settings, deps):