diff --git a/src/bot/core.py b/src/bot/core.py index 9fef0670..3c149e45 100644 --- a/src/bot/core.py +++ b/src/bot/core.py @@ -53,6 +53,7 @@ async def initialize(self) -> None: builder.token(self.settings.telegram_token_str) builder.defaults(Defaults(do_quote=self.settings.reply_quote)) builder.rate_limiter(AIORateLimiter(max_retries=1)) + builder.concurrent_updates(True) # Configure connection settings builder.connect_timeout(30) diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index ac1d5304..906ed12a 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -308,6 +308,7 @@ def _register_agentic_handlers(self, app: Application) -> None: ("verbose", self.agentic_verbose), ("repo", self.agentic_repo), ("restart", command.restart_command), + ("stop", self.agentic_stop), ] if self.settings.enable_project_threads: handlers.append(("sync_threads", command.sync_threads)) @@ -417,6 +418,7 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] BotCommand("verbose", "Set output verbosity (0/1/2)"), BotCommand("repo", "List repos / switch workspace"), BotCommand("restart", "Restart the bot"), + BotCommand("stop", "Cancel running task"), ] if self.settings.enable_project_threads: commands.append(BotCommand("sync_threads", "Sync project topics")) @@ -860,6 +862,62 @@ async def _send_images( return caption_sent + async def _interrupt_running_task( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Interrupt a running Claude task if one exists for this user.""" + running_task: Optional[asyncio.Task] = context.user_data.get( # type: ignore[assignment] + "running_claude_task" + ) + if running_task is None or running_task.done(): + return + + claude_integration = context.bot_data.get("claude_integration") + if not claude_integration: + return + + sdk_manager = getattr(claude_integration, "sdk_manager", None) + if sdk_manager is None: + return + + await update.message.reply_text("\U0001f4e8 Interrupting...") + logger.info( + "Interrupting running Claude task for follow-up", + user_id=update.effective_user.id, + ) + + # Graceful interrupt via SDK control protocol + await sdk_manager.interrupt() + + # Wait up to 3 seconds for graceful stop + try: + await asyncio.wait_for(asyncio.shield(running_task), timeout=3.0) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception): + # Graceful interrupt didn't finish in time — force kill + logger.warning("Graceful interrupt timed out, aborting") + await sdk_manager.abort() + running_task.cancel() + try: + await running_task + except (asyncio.CancelledError, Exception): + pass + + context.user_data["running_claude_task"] = None + + async def agentic_stop( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + """Manually cancel a running Claude task.""" + running_task: Optional[asyncio.Task] = context.user_data.get( # type: ignore[assignment] + "running_claude_task" + ) + if running_task is None or running_task.done(): + await update.message.reply_text("No task is currently running.") + return + + await self._interrupt_running_task(update, context) + await update.message.reply_text("Task cancelled.") + async def agentic_text( self, update: Update, context: ContextTypes.DEFAULT_TYPE ) -> None: @@ -873,6 +931,9 @@ async def agentic_text( message_length=len(message_text), ) + # Interrupt any running task before processing follow-up + await self._interrupt_running_task(update, context) + # Rate limit check rate_limiter = context.bot_data.get("rate_limiter") if rate_limiter: @@ -933,67 +994,94 @@ async def agentic_text( heartbeat = self._start_typing_heartbeat(chat) success = True - try: - claude_response = await claude_integration.run_command( - prompt=message_text, - working_directory=current_dir, - user_id=user_id, - session_id=session_id, - on_stream=on_stream, - force_new=force_new, - ) - # New session created successfully — clear the one-shot flag - if force_new: - context.user_data["force_new_session"] = False + async def _run() -> None: + nonlocal success, formatted_messages + try: + claude_response = await claude_integration.run_command( + prompt=message_text, + working_directory=current_dir, + user_id=user_id, + session_id=session_id, + on_stream=on_stream, + force_new=force_new, + ) - context.user_data["claude_session_id"] = claude_response.session_id + # New session created successfully — clear the one-shot flag + if force_new: + context.user_data["force_new_session"] = False - # Track directory changes - from .handlers.message import _update_working_directory_from_claude_response + context.user_data["claude_session_id"] = claude_response.session_id - _update_working_directory_from_claude_response( - claude_response, context, self.settings, user_id - ) + # Track directory changes + from .handlers.message import ( + _update_working_directory_from_claude_response, + ) - # Store interaction - storage = context.bot_data.get("storage") - if storage: - try: - await storage.save_claude_interaction( - user_id=user_id, - session_id=claude_response.session_id, - prompt=message_text, - response=claude_response, - ip_address=None, - ) - except Exception as e: - logger.warning("Failed to log interaction", error=str(e)) + _update_working_directory_from_claude_response( + claude_response, context, self.settings, user_id + ) - # Format response (no reply_markup — strip keyboards) - from .utils.formatting import ResponseFormatter + # Store interaction + storage = context.bot_data.get("storage") + if storage: + try: + await storage.save_claude_interaction( + user_id=user_id, + session_id=claude_response.session_id, + prompt=message_text, + response=claude_response, + ip_address=None, + ) + except Exception as e: + logger.warning("Failed to log interaction", error=str(e)) - formatter = ResponseFormatter(self.settings) - formatted_messages = formatter.format_claude_response( - claude_response.content - ) + # Format response (no reply_markup — strip keyboards) + from .utils.formatting import ResponseFormatter - except Exception as e: - success = False - logger.error("Claude integration failed", error=str(e), user_id=user_id) - from .handlers.message import _format_error_message - from .utils.formatting import FormattedMessage + formatter = ResponseFormatter(self.settings) + formatted_messages = formatter.format_claude_response( + claude_response.content + ) - formatted_messages = [ - FormattedMessage(_format_error_message(e), parse_mode="HTML") - ] - finally: - heartbeat.cancel() - if draft_streamer: - try: - await draft_streamer.flush() - except Exception: - logger.debug("Draft flush failed in finally block", user_id=user_id) + except asyncio.CancelledError: + success = False + logger.info("Claude task cancelled (interrupted)", user_id=user_id) + from .utils.formatting import FormattedMessage + + formatted_messages = [ + FormattedMessage( + "\u26a0\ufe0f Task interrupted.", parse_mode="HTML" + ) + ] + except Exception as e: + success = False + logger.error( + "Claude integration failed", error=str(e), user_id=user_id + ) + from .handlers.message import _format_error_message + from .utils.formatting import FormattedMessage + + formatted_messages = [ + FormattedMessage( + _format_error_message(e), parse_mode="HTML" + ) + ] + finally: + heartbeat.cancel() + context.user_data["running_claude_task"] = None + if draft_streamer: + try: + await draft_streamer.flush() + except Exception: + logger.debug( + "Draft flush failed in finally block", user_id=user_id + ) + + formatted_messages: list = [] # type: ignore[assignment] + task = asyncio.ensure_future(_run()) + context.user_data["running_claude_task"] = task + await task try: await progress_msg.delete() @@ -1177,87 +1265,102 @@ async def agentic_document( ) heartbeat = self._start_typing_heartbeat(chat) - try: - claude_response = await claude_integration.run_command( - prompt=prompt, - working_directory=current_dir, - user_id=user_id, - session_id=session_id, - on_stream=on_stream, - force_new=force_new, - ) - if force_new: - context.user_data["force_new_session"] = False + async def _run_doc() -> None: + try: + claude_response = await claude_integration.run_command( + prompt=prompt, + working_directory=current_dir, + user_id=user_id, + session_id=session_id, + on_stream=on_stream, + force_new=force_new, + ) - context.user_data["claude_session_id"] = claude_response.session_id + if force_new: + context.user_data["force_new_session"] = False - from .handlers.message import _update_working_directory_from_claude_response + context.user_data["claude_session_id"] = claude_response.session_id - _update_working_directory_from_claude_response( - claude_response, context, self.settings, user_id - ) + from .handlers.message import ( + _update_working_directory_from_claude_response, + ) - from .utils.formatting import ResponseFormatter + _update_working_directory_from_claude_response( + claude_response, context, self.settings, user_id + ) - formatter = ResponseFormatter(self.settings) - formatted_messages = formatter.format_claude_response( - claude_response.content - ) + from .utils.formatting import ResponseFormatter - try: - await progress_msg.delete() - except Exception: - logger.debug("Failed to delete progress message, ignoring") + formatter = ResponseFormatter(self.settings) + formatted_messages = formatter.format_claude_response( + claude_response.content + ) - # Use MCP-collected images (from send_image_to_user tool calls) - images: List[ImageAttachment] = mcp_images_doc + try: + await progress_msg.delete() + except Exception: + logger.debug("Failed to delete progress message, ignoring") + + # Use MCP-collected images (from send_image_to_user tool calls) + images: List[ImageAttachment] = mcp_images_doc + + caption_sent = False + if images and len(formatted_messages) == 1: + msg = formatted_messages[0] + if msg.text and len(msg.text) <= 1024: + try: + caption_sent = await self._send_images( + update, + images, + reply_to_message_id=update.message.message_id, + caption=msg.text, + caption_parse_mode=msg.parse_mode, + ) + except Exception as img_err: + logger.warning( + "Image+caption send failed", error=str(img_err) + ) - caption_sent = False - if images and len(formatted_messages) == 1: - msg = formatted_messages[0] - if msg.text and len(msg.text) <= 1024: - try: - caption_sent = await self._send_images( - update, - images, - reply_to_message_id=update.message.message_id, - caption=msg.text, - caption_parse_mode=msg.parse_mode, + if not caption_sent: + for i, message in enumerate(formatted_messages): + await update.message.reply_text( + message.text, + parse_mode=message.parse_mode, + reply_markup=None, + reply_to_message_id=( + update.message.message_id if i == 0 else None + ), ) - except Exception as img_err: - logger.warning("Image+caption send failed", error=str(img_err)) - - if not caption_sent: - for i, message in enumerate(formatted_messages): - await update.message.reply_text( - message.text, - parse_mode=message.parse_mode, - reply_markup=None, - reply_to_message_id=( - update.message.message_id if i == 0 else None - ), - ) - if i < len(formatted_messages) - 1: - await asyncio.sleep(0.5) + if i < len(formatted_messages) - 1: + await asyncio.sleep(0.5) + + if images: + try: + await self._send_images( + update, + images, + reply_to_message_id=update.message.message_id, + ) + except Exception as img_err: + logger.warning("Image send failed", error=str(img_err)) - if images: - try: - await self._send_images( - update, - images, - reply_to_message_id=update.message.message_id, - ) - except Exception as img_err: - logger.warning("Image send failed", error=str(img_err)) + except Exception as e: + from .handlers.message import _format_error_message - except Exception as e: - from .handlers.message import _format_error_message + await progress_msg.edit_text( + _format_error_message(e), parse_mode="HTML" + ) + logger.error( + "Claude file processing failed", error=str(e), user_id=user_id + ) + finally: + heartbeat.cancel() + context.user_data["running_claude_task"] = None - await progress_msg.edit_text(_format_error_message(e), parse_mode="HTML") - logger.error("Claude file processing failed", error=str(e), user_id=user_id) - finally: - heartbeat.cancel() + task = asyncio.ensure_future(_run_doc()) + context.user_data["running_claude_task"] = task + await task async def agentic_photo( self, update: Update, context: ContextTypes.DEFAULT_TYPE @@ -1376,79 +1479,94 @@ async def _handle_agentic_media_message( ) heartbeat = self._start_typing_heartbeat(chat) - try: - claude_response = await claude_integration.run_command( - prompt=prompt, - working_directory=current_dir, - user_id=user_id, - session_id=session_id, - on_stream=on_stream, - force_new=force_new, - ) - finally: - heartbeat.cancel() - if force_new: - context.user_data["force_new_session"] = False + async def _run_media() -> None: + try: + claude_response = await claude_integration.run_command( + prompt=prompt, + working_directory=current_dir, + user_id=user_id, + session_id=session_id, + on_stream=on_stream, + force_new=force_new, + ) + finally: + heartbeat.cancel() + context.user_data["running_claude_task"] = None - context.user_data["claude_session_id"] = claude_response.session_id + if force_new: + context.user_data["force_new_session"] = False - from .handlers.message import _update_working_directory_from_claude_response + context.user_data["claude_session_id"] = claude_response.session_id - _update_working_directory_from_claude_response( - claude_response, context, self.settings, user_id - ) + from .handlers.message import ( + _update_working_directory_from_claude_response, + ) - from .utils.formatting import ResponseFormatter + _update_working_directory_from_claude_response( + claude_response, context, self.settings, user_id + ) - formatter = ResponseFormatter(self.settings) - formatted_messages = formatter.format_claude_response(claude_response.content) + from .utils.formatting import ResponseFormatter - try: - await progress_msg.delete() - except Exception: - logger.debug("Failed to delete progress message, ignoring") + formatter = ResponseFormatter(self.settings) + formatted_messages = formatter.format_claude_response( + claude_response.content + ) - # Use MCP-collected images (from send_image_to_user tool calls). - images: List[ImageAttachment] = mcp_images_media + try: + await progress_msg.delete() + except Exception: + logger.debug("Failed to delete progress message, ignoring") - caption_sent = False - if images and len(formatted_messages) == 1: - msg = formatted_messages[0] - if msg.text and len(msg.text) <= 1024: - try: - caption_sent = await self._send_images( - update, - images, - reply_to_message_id=update.message.message_id, - caption=msg.text, - caption_parse_mode=msg.parse_mode, - ) - except Exception as img_err: - logger.warning("Image+caption send failed", error=str(img_err)) + # Use MCP-collected images (from send_image_to_user tool calls). + images: List[ImageAttachment] = mcp_images_media - if not caption_sent: - for i, message in enumerate(formatted_messages): - if not message.text or not message.text.strip(): - continue - await update.message.reply_text( - message.text, - parse_mode=message.parse_mode, - reply_markup=None, - reply_to_message_id=(update.message.message_id if i == 0 else None), - ) - if i < len(formatted_messages) - 1: - await asyncio.sleep(0.5) + caption_sent = False + if images and len(formatted_messages) == 1: + msg = formatted_messages[0] + if msg.text and len(msg.text) <= 1024: + try: + caption_sent = await self._send_images( + update, + images, + reply_to_message_id=update.message.message_id, + caption=msg.text, + caption_parse_mode=msg.parse_mode, + ) + except Exception as img_err: + logger.warning( + "Image+caption send failed", error=str(img_err) + ) - if images: - try: - await self._send_images( - update, - images, - reply_to_message_id=update.message.message_id, + if not caption_sent: + for i, message in enumerate(formatted_messages): + if not message.text or not message.text.strip(): + continue + await update.message.reply_text( + message.text, + parse_mode=message.parse_mode, + reply_markup=None, + reply_to_message_id=( + update.message.message_id if i == 0 else None + ), ) - except Exception as img_err: - logger.warning("Image send failed", error=str(img_err)) + if i < len(formatted_messages) - 1: + await asyncio.sleep(0.5) + + if images: + try: + await self._send_images( + update, + images, + reply_to_message_id=update.message.message_id, + ) + except Exception as img_err: + logger.warning("Image send failed", error=str(img_err)) + + task = asyncio.ensure_future(_run_media()) + context.user_data["running_claude_task"] = task + await task def _voice_unavailable_message(self) -> str: """Return provider-aware guidance when voice feature is unavailable.""" diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index adf553f4..ad9e34fb 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -137,6 +137,8 @@ def __init__( """Initialize SDK manager with configuration.""" self.config = config self.security_validator = security_validator + self._active_client: Optional[ClaudeSDKClient] = None + self._is_processing: bool = False # Set up environment for Claude Code SDK if API key is provided # If no API key is provided, the SDK will use existing CLI authentication @@ -146,6 +148,32 @@ def __init__( else: logger.info("No API key provided, using existing Claude CLI authentication") + @property + def is_processing(self) -> bool: + """Return True if a command is currently being processed.""" + return self._is_processing + + async def interrupt(self) -> None: + """Send interrupt signal to the active Claude client (like Ctrl+C).""" + client = self._active_client + if client is not None: + try: + await client.interrupt() + logger.info("Sent interrupt to active Claude client") + except Exception as e: + logger.warning("Failed to interrupt Claude client", error=str(e)) + + async def abort(self) -> None: + """Interrupt and then forcefully disconnect the active client.""" + await self.interrupt() + client = self._active_client + if client is not None: + try: + await client.disconnect() + logger.info("Force-disconnected active Claude client") + except Exception as e: + logger.warning("Failed to disconnect Claude client", error=str(e)) + async def execute_command( self, prompt: str, @@ -155,6 +183,7 @@ async def execute_command( stream_callback: Optional[Callable[[StreamUpdate], None]] = None, ) -> ClaudeResponse: """Execute Claude Code command via SDK.""" + self._is_processing = True start_time = asyncio.get_event_loop().time() logger.info( @@ -247,6 +276,7 @@ async def _run_client() -> None: # a plain string. connect(None) uses an empty async # iterable internally, satisfying the requirement. client = ClaudeSDKClient(options) + self._active_client = client try: await client.connect() await client.query(prompt) @@ -286,6 +316,7 @@ async def _run_client() -> None: error_type=type(callback_error).__name__, ) finally: + self._active_client = None await client.disconnect() # Execute with timeout @@ -455,6 +486,9 @@ async def _run_client() -> None: ) raise ClaudeProcessError(f"Unexpected error: {str(e)}") + finally: + self._is_processing = False + async def _handle_stream_message( self, message: Message, stream_callback: Callable[[StreamUpdate], None] ) -> None: