diff --git a/backend/apps/slack/MANIFEST.yaml b/backend/apps/slack/MANIFEST.yaml index d6617e94b3..87c08a2a15 100644 --- a/backend/apps/slack/MANIFEST.yaml +++ b/backend/apps/slack/MANIFEST.yaml @@ -133,6 +133,7 @@ settings: - app_mention - member_joined_channel - message.channels + - message.im - team_join interactivity: is_enabled: true diff --git a/backend/apps/slack/common/handlers/ai.py b/backend/apps/slack/common/handlers/ai.py index ef0452e7b8..2056295f68 100644 --- a/backend/apps/slack/common/handlers/ai.py +++ b/backend/apps/slack/common/handlers/ai.py @@ -6,6 +6,8 @@ from apps.ai.agent.tools.rag.rag_tool import RagTool from apps.slack.blocks import markdown +from apps.slack.constants import CONVERSATION_CONTEXT_LIMIT +from apps.slack.models import Conversation, Workspace logger = logging.getLogger(__name__) @@ -46,6 +48,62 @@ def process_ai_query(query: str) -> str | None: return rag_tool.query(question=query) +def get_dm_blocks(query: str, workspace_id: str, channel_id: str) -> list[dict]: + """Get AI response blocks for DM with conversation context. + + Args: + query (str): The user's question. + workspace_id (str): Slack workspace ID. + channel_id (str): Slack channel ID for the DM. + + Returns: + list: A list of Slack blocks representing the AI response. + + """ + ai_response = process_dm_ai_query(query.strip(), workspace_id, channel_id) + + if ai_response: + return [markdown(ai_response)] + return get_error_blocks() + + +def process_dm_ai_query(query: str, workspace_id: str, channel_id: str) -> str | None: + """Process the AI query with DM conversation context. + + Args: + query (str): The user's question. + workspace_id (str): Slack workspace ID. + channel_id (str): Slack channel ID for the DM. + + Returns: + str | None: The AI response or None if error occurred. + + """ + try: + workspace = Workspace.objects.get(slack_workspace_id=workspace_id) + conversation = Conversation.objects.get(slack_channel_id=channel_id, workspace=workspace) + except (Workspace.DoesNotExist, Conversation.DoesNotExist): + logger.exception("Workspace or conversation not found for DM processing") + return None + + context = conversation.get_context(conversation_context_limit=CONVERSATION_CONTEXT_LIMIT) + + rag_tool = RagTool( + chat_model="gpt-4o", + embedding_model="text-embedding-3-small", + ) + + if context: + enhanced_query = f"Conversation context:\n{context}\n\nCurrent question: {query}" + else: + enhanced_query = query + + response = rag_tool.query(question=enhanced_query) + conversation.add_to_context(query, response) + + return response + + def get_error_blocks() -> list[dict]: """Get error response blocks. diff --git a/backend/apps/slack/constants.py b/backend/apps/slack/constants.py index ef9bb7b9bb..3bed72b5a4 100644 --- a/backend/apps/slack/constants.py +++ b/backend/apps/slack/constants.py @@ -2,6 +2,7 @@ from apps.common.constants import NL +CONVERSATION_CONTEXT_LIMIT = 20 NEST_BOT_NAME = "NestBot" OWASP_APPSEC_CHANNEL_ID = "#C0F7D6DFH" diff --git a/backend/apps/slack/events/message_posted.py b/backend/apps/slack/events/message_posted.py index 5b38c3077b..4558777350 100644 --- a/backend/apps/slack/events/message_posted.py +++ b/backend/apps/slack/events/message_posted.py @@ -1,4 +1,4 @@ -"""Slack message event template.""" +"""Slack message event handler for OWASP NestBot.""" import logging from datetime import timedelta @@ -6,16 +6,17 @@ import django_rq from apps.ai.common.constants import QUEUE_RESPONSE_TIME_MINUTES +from apps.slack.common.handlers.ai import get_dm_blocks from apps.slack.common.question_detector import QuestionDetector from apps.slack.events.event import EventBase -from apps.slack.models import Conversation, Member, Message +from apps.slack.models import Conversation, Member, Message, Workspace from apps.slack.services.message_auto_reply import generate_ai_reply_if_unanswered logger = logging.getLogger(__name__) class MessagePosted(EventBase): - """Handles new messages posted in channels.""" + """Handles new messages posted in channels or direct messages.""" event_type = "message" @@ -24,25 +25,30 @@ def __init__(self): self.question_detector = QuestionDetector() def handle_event(self, event, client): - """Handle an incoming message event.""" + """Handle incoming Slack message events.""" if event.get("subtype") or event.get("bot_id"): - logger.info("Ignored message due to subtype, bot_id, or thread_ts.") + logger.info("Ignored message due to subtype or bot_id.") + return + + channel_id = event.get("channel") + user_id = event.get("user") + text = event.get("text", "") + channel_type = event.get("channel_type") + + if channel_type == "im": + self.handle_dm(event, client, channel_id, user_id, text) return if event.get("thread_ts"): try: Message.objects.filter( slack_message_id=event.get("thread_ts"), - conversation__slack_channel_id=event.get("channel"), + conversation__slack_channel_id=channel_id, ).update(has_replies=True) except Message.DoesNotExist: logger.warning("Thread message not found.") return - channel_id = event.get("channel") - user_id = event.get("user") - text = event.get("text", "") - try: conversation = Conversation.objects.get( slack_channel_id=channel_id, @@ -71,3 +77,46 @@ def handle_event(self, event, client): generate_ai_reply_if_unanswered, message.id, ) + + def handle_dm(self, event, client, channel_id, user_id, text): + """Handle direct messages with NestBot (DMs).""" + workspace_id = event.get("team") + channel_info = client.conversations_info(channel=channel_id) + + try: + workspace = Workspace.objects.get(slack_workspace_id=workspace_id) + except Workspace.DoesNotExist: + logger.exception("Workspace not found for DM.") + return + + Conversation.update_data(channel_info["channel"], workspace) + + try: + Member.objects.get(slack_user_id=user_id, workspace=workspace) + except Member.DoesNotExist: + user_info = client.users_info(user=user_id) + Member.update_data(user_info["user"], workspace, save=True) + logger.info("Created new member for DM") + + thread_ts = event.get("thread_ts") + + try: + response_blocks = get_dm_blocks(text, workspace_id, channel_id) + if response_blocks: + client.chat_postMessage( + channel=channel_id, + blocks=response_blocks, + text=text, + thread_ts=thread_ts, + ) + + except Exception: + logger.exception("Error processing DM") + client.chat_postMessage( + channel=channel_id, + text=( + "I'm sorry, I'm having trouble processing your message right now. " + "Please try again later." + ), + thread_ts=thread_ts, + ) diff --git a/backend/apps/slack/migrations/0020_conversation_conversation_context.py b/backend/apps/slack/migrations/0020_conversation_conversation_context.py new file mode 100644 index 0000000000..b39fe8e246 --- /dev/null +++ b/backend/apps/slack/migrations/0020_conversation_conversation_context.py @@ -0,0 +1,17 @@ +# Generated by Django 5.2.6 on 2025-10-08 07:21 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("slack", "0019_conversation_is_nest_bot_assistant_enabled"), + ] + + operations = [ + migrations.AddField( + model_name="conversation", + name="conversation_context", + field=models.TextField(blank=True, verbose_name="Conversation context"), + ), + ] diff --git a/backend/apps/slack/models/conversation.py b/backend/apps/slack/models/conversation.py index 9735786c24..c0899c079f 100644 --- a/backend/apps/slack/models/conversation.py +++ b/backend/apps/slack/models/conversation.py @@ -42,6 +42,7 @@ class Meta: # Additional attributes. sync_messages = models.BooleanField(verbose_name="Sync messages", default=False) + conversation_context = models.TextField(blank=True, verbose_name="Conversation context") def __str__(self): """Channel human readable representation.""" @@ -105,3 +106,43 @@ def update_data(conversation_data, workspace, *, save=True): conversation.save() return conversation + + def add_to_context(self, user_message: str, bot_response: str | None = None) -> None: + """Add messages to the conversation context. + + Args: + user_message: The user's message to add to context. + bot_response: The bot's response to add to context. + + """ + if not self.conversation_context: + self.conversation_context = "" + + self.conversation_context = f"{self.conversation_context}{f'User: {user_message}\n'}" + + if bot_response: + self.conversation_context = f"{self.conversation_context}{f'Bot: {bot_response}\n'}" + + self.save(update_fields=["conversation_context"]) + + def get_context(self, conversation_context_limit: int | None = None) -> str: + """Get the conversation context. + + Args: + conversation_context_limit: Optional limit on number of exchanges to return. + + Returns: + The conversation context, potentially limited to recent exchanges. + + """ + if not self.conversation_context: + return "" + + if conversation_context_limit is None: + return self.conversation_context + + lines = self.conversation_context.strip().split("\n") + if len(lines) <= conversation_context_limit * 2: + return self.conversation_context + + return "\n".join(lines[-(conversation_context_limit * 2) :]) diff --git a/backend/tests/apps/slack/events/message_posted_test.py b/backend/tests/apps/slack/events/message_posted_test.py index 2f431416f9..2df5bf5e0a 100644 --- a/backend/tests/apps/slack/events/message_posted_test.py +++ b/backend/tests/apps/slack/events/message_posted_test.py @@ -88,11 +88,17 @@ def test_handle_event_ignores_thread_messages(self, message_handler): client = Mock() with patch("apps.slack.events.message_posted.Message") as mock_message: - mock_message.DoesNotExist = Exception - mock_message.objects.get.side_effect = Exception("Message not found") + mock_filter = Mock() + mock_message.objects.filter.return_value = mock_filter message_handler.handle_event(event, client) + mock_message.objects.filter.assert_called_once_with( + slack_message_id=event.get("thread_ts"), + conversation__slack_channel_id=event.get("channel"), + ) + mock_filter.update.assert_called_once_with(has_replies=True) + client.chat_postMessage.assert_not_called() def test_handle_event_conversation_not_found(self, message_handler): @@ -260,6 +266,7 @@ def test_handle_event_member_not_found(self, message_handler, conversation_mock) patch("apps.slack.events.message_posted.Conversation") as mock_conversation, patch("apps.slack.events.message_posted.Member") as mock_member, patch("apps.slack.events.message_posted.Message") as mock_message_model, + patch("apps.slack.events.message_posted.django_rq") as mock_django_rq, ): mock_conversation.objects.get.return_value = conversation_mock @@ -273,19 +280,17 @@ def test_handle_event_member_not_found(self, message_handler, conversation_mock) mock_message.id = 1 mock_message_model.update_data.return_value = mock_message - with ( - patch.object( - message_handler.question_detector, - "is_owasp_question", - return_value=True, - ), - patch("apps.slack.events.message_posted.django_rq") as mock_django_rq, - ): - mock_queue = Mock() - mock_django_rq.get_queue.return_value = mock_queue + mock_queue = Mock() + mock_django_rq.get_queue.return_value = mock_queue + with patch.object( + message_handler.question_detector, + "is_owasp_question", + return_value=True, + ): message_handler.handle_event(event, client) + mock_member.update_data.assert_called_once() mock_django_rq.get_queue.assert_called_once() def test_handle_event_empty_text(self, message_handler): diff --git a/backend/tests/apps/slack/models/conversation_test.py b/backend/tests/apps/slack/models/conversation_test.py index cdf9732cd0..324ba0dca3 100644 --- a/backend/tests/apps/slack/models/conversation_test.py +++ b/backend/tests/apps/slack/models/conversation_test.py @@ -137,3 +137,122 @@ def test_str_method(self): # Check __str__ returns the name assert str(conversation) == "test-workspace #test-channel" + + def test_add_to_context_first_message(self, mocker): + """Test adding the first message to an empty conversation context.""" + conversation = Conversation(slack_channel_id="C12345") + conversation.conversation_context = "" + + save_mock = mocker.patch.object(Conversation, "save") + conversation.add_to_context("Hello, how are you?") + + assert conversation.conversation_context == "User: Hello, how are you?\n" + save_mock.assert_called_once_with(update_fields=["conversation_context"]) + + def test_add_to_context_with_bot_response(self, mocker): + """Test adding a user message and bot response to context.""" + conversation = Conversation(slack_channel_id="C12345") + conversation.conversation_context = "User: Hello\nBot: Hi there!\n" + + save_mock = mocker.patch.object(Conversation, "save") + + conversation.add_to_context( + "What is OWASP?", "OWASP stands for Open Web Application Security Project." + ) + + expected_context = ( + "User: Hello\n" + "Bot: Hi there!\n" + "User: What is OWASP?\n" + "Bot: OWASP stands for Open Web Application Security Project.\n" + ) + assert conversation.conversation_context == expected_context + save_mock.assert_called_once_with(update_fields=["conversation_context"]) + + def test_add_to_context_empty_initial_context(self, mocker): + """Test adding context when conversation_context is None.""" + conversation = Conversation(slack_channel_id="C12345") + conversation.conversation_context = None + + save_mock = mocker.patch.object(Conversation, "save") + + conversation.add_to_context("First message", "First response") + + expected_context = "User: First message\nBot: First response\n" + assert conversation.conversation_context == expected_context + save_mock.assert_called_once_with(update_fields=["conversation_context"]) + + def test_get_context_empty(self): + """Test getting context when conversation_context is empty.""" + conversation = Conversation(slack_channel_id="C12345") + conversation.conversation_context = "" + + result = conversation.get_context() + + assert result == "" + + def test_get_context_no_limit(self): + """Test getting full context without limit.""" + conversation = Conversation(slack_channel_id="C12345") + conversation.conversation_context = ( + "User: Message 1\n" + "Bot: Response 1\n" + "User: Message 2\n" + "Bot: Response 2\n" + "User: Message 3\n" + "Bot: Response 3\n" + ) + + result = conversation.get_context() + + assert result == conversation.conversation_context + + def test_get_context_with_limit_below_threshold(self): + """Test getting context with limit when total exchanges are below limit.""" + conversation = Conversation(slack_channel_id="C12345") + conversation.conversation_context = ( + "User: Message 1\nBot: Response 1\nUser: Message 2\nBot: Response 2\n" + ) + + result = conversation.get_context(conversation_context_limit=3) + + assert result == conversation.conversation_context + + def test_get_context_with_limit_above_threshold(self): + """Test getting context with limit when total exchanges exceed limit.""" + conversation = Conversation(slack_channel_id="C12345") + conversation.conversation_context = ( + "User: Message 1\n" + "Bot: Response 1\n" + "User: Message 2\n" + "Bot: Response 2\n" + "User: Message 3\n" + "Bot: Response 3\n" + "User: Message 4\n" + "Bot: Response 4\n" + ) + + result = conversation.get_context(conversation_context_limit=2) + + expected = "User: Message 3\nBot: Response 3\nUser: Message 4\nBot: Response 4" + assert result == expected + + def test_get_context_none_context(self): + """Test getting context when conversation_context is None.""" + conversation = Conversation(slack_channel_id="C12345") + conversation.conversation_context = None + + result = conversation.get_context() + + assert result == "" + + def test_get_context_with_limit_exact_threshold(self): + """Test getting context when exchanges exactly match the limit.""" + conversation = Conversation(slack_channel_id="C12345") + conversation.conversation_context = ( + "User: Message 1\nBot: Response 1\nUser: Message 2\nBot: Response 2\n" + ) + + result = conversation.get_context(conversation_context_limit=2) + + assert result == conversation.conversation_context