From ce5158b68d7dc427c7453719f6ca8d34a52291e6 Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Tue, 22 Oct 2024 12:10:06 -0400 Subject: [PATCH] backend: Add death loop detection (#815) * Import death loop code - pending testing * wip --- src/backend/chat/custom/custom.py | 20 ++- src/backend/schemas/chat.py | 8 + src/backend/services/chat.py | 70 +++++++- src/backend/tests/unit/services/test_chat.py | 163 +++++++++++++++++++ 4 files changed, 257 insertions(+), 4 deletions(-) create mode 100644 src/backend/tests/unit/services/test_chat.py diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index 5b29bb2bc..16ce04e3b 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -9,10 +9,11 @@ from backend.config.tools import AVAILABLE_TOOLS from backend.database_models.file import File from backend.model_deployments.base import BaseDeployment -from backend.schemas.chat import ChatMessage, ChatRole +from backend.schemas.chat import ChatMessage, ChatRole, EventState from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context from backend.schemas.tool import Category, Tool +from backend.services.chat import check_death_loop from backend.services.file import get_file_service from backend.tools.utils.tools_checkers import tool_has_category @@ -22,6 +23,13 @@ class CustomChat(BaseChat): """Custom chat flow not using integrations for models.""" + event_state = EventState( + distances_plans=[], + distances_actions=[], + previous_plan="", + previous_action="", + ) + async def chat( self, chat_request: CohereChatRequest, @@ -62,7 +70,7 @@ async def chat( stream = self.call_chat(self.chat_request, deployment_model, ctx, **kwargs) async for event in stream: - result = self.handle_event(event, chat_request) + result = self.handle_event(event, chat_request, ctx) if result: yield result @@ -102,15 +110,21 @@ def is_final_event( ) def handle_event( - self, event: Dict[str, Any], 
chat_request: CohereChatRequest + self, event: Dict[str, Any], chat_request: CohereChatRequest, ctx: Context ) -> Dict[str, Any]: # All events other than stream start and stream end are returned if ( event["event_type"] != StreamEvent.STREAM_START and event["event_type"] != StreamEvent.STREAM_END + and event["event_type"] != StreamEvent.TOOL_CALLS_GENERATION ): return event + # If the event is a tool call generation, we need to check if the tool call is similar to previous tool calls + if event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION: + self.event_state = check_death_loop(event, self.event_state, ctx) + return event + # Only the first occurrence of stream start is returned if event["event_type"] == StreamEvent.STREAM_START: if self.is_first_start: diff --git a/src/backend/schemas/chat.py b/src/backend/schemas/chat.py index 185d1ddf8..669decaf9 100644 --- a/src/backend/schemas/chat.py +++ b/src/backend/schemas/chat.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from enum import StrEnum from typing import Any, ClassVar, Dict, List, Union from uuid import uuid4 @@ -11,6 +12,13 @@ from backend.schemas.tool import Tool, ToolCall, ToolCallDelta +@dataclass +class EventState: + distances_plans: list + distances_actions: list + previous_plan: str + previous_action: str + class ChatRole(StrEnum): """One of CHATBOT|USER|SYSTEM to identify who the message is coming from.""" diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index 276f190c1..bf1d560a7 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -1,7 +1,8 @@ import json -from typing import Any, AsyncGenerator, Generator, List, Union +from typing import Any, AsyncGenerator, Dict, Generator, List, Union from uuid import uuid4 +import nltk from cohere.types import StreamedChatResponse from fastapi import HTTPException, Request from fastapi.encoders import jsonable_encoder @@ -30,6 +31,7 @@ ChatMessage, ChatResponseEvent, ChatRole, + EventState, 
NonStreamedChatResponse, StreamCitationGeneration, StreamEnd, @@ -47,6 +49,8 @@ from backend.schemas.tool import Tool, ToolCall, ToolCallDelta from backend.services.agent import validate_agent_exists +LOOKBACKS = [3, 5, 7] +DEATHLOOP_SIMILARITY_THRESHOLDS = [0.5, 0.7, 0.9] def process_chat( session: DBSessionDep, @@ -962,3 +966,67 @@ def handle_stream_end( stream_end = StreamEnd.model_validate(event | stream_end_data) stream_event = stream_end return stream_event, stream_end_data, response_message, document_ids_to_document + +def are_previous_actions_similar( + distances: List[float], threshold: float, lookback: int +) -> bool: + return all(dist > threshold for dist in distances[-lookback:]) + + +def check_similarity(distances: list[float], ctx: Context) -> bool: + """ + Check if the previous actions are similar to detect a potential death loop. + + Args: + distances (list[float]): Similarity scores between consecutive previous actions (higher means more similar). + + Returns: + bool: True if a potential death loop is detected, False otherwise.
+ """ + logger = ctx.get_logger() + + if len(distances) < min(LOOKBACKS): + return False + + # EXPERIMENTAL: Check for potential death loops with different thresholds and lookbacks + for threshold in DEATHLOOP_SIMILARITY_THRESHOLDS: + for lookback in LOOKBACKS: + if are_previous_actions_similar(distances, threshold, lookback): + logger.warning( + event="[Chat] Potential death loop detected", + distances=distances, + threshold=threshold, + lookback=lookback, + ) + return True + + return False + + +def check_death_loop( + event: Dict[str, Any], event_state: EventState, ctx: Context +) -> EventState: + plan: str = event.get("text", "") + tool_calls: List = event.get("tool_calls", []) + action: str = json.dumps(tool_calls) + + if event_state.previous_action: + event_state.distances_actions.append( + 1 + - nltk.edit_distance(event_state.previous_action, action) + / max(len(event_state.previous_action), len(action)) + ) + check_similarity(event_state.distances_actions, ctx) + + if event_state.previous_plan: + event_state.distances_plans.append( + 1 + - nltk.edit_distance(event_state.previous_plan, plan) + / max(len(event_state.previous_plan), len(plan)) + ) + check_similarity(event_state.distances_plans, ctx) + + event_state.previous_plan = plan + event_state.previous_action = action + + return event_state diff --git a/src/backend/tests/unit/services/test_chat.py b/src/backend/tests/unit/services/test_chat.py new file mode 100644 index 000000000..20f5c42e0 --- /dev/null +++ b/src/backend/tests/unit/services/test_chat.py @@ -0,0 +1,163 @@ + +import pytest + +from backend.schemas.chat import EventState +from backend.schemas.context import Context +from backend.services.chat import ( + DEATHLOOP_SIMILARITY_THRESHOLDS, + are_previous_actions_similar, + check_death_loop, + check_similarity, +) + + +def test_are_previous_actions_similar(): + distances = [ + 0.5, + 0.6, + 0.8, + 0.9, + 1.0, + ] + + assert are_previous_actions_similar(distances, 0.7, 3) + + +def 
test_are_previous_actions_not_similar(): + distances = [ + 0.1, + 0.1, + 0.1, + 0.1, + ] + assert not are_previous_actions_similar(distances, 0.7, 3) + + +def test_check_similarity(): + ctx = Context() + distances = [ + 0.5, + 0.6, + 0.8, + 0.9, + 1.0, + ] + + response = check_similarity(distances, ctx) + assert response + + +def test_check_similarity_no_death_loop(): + ctx = Context() + distances = [ + 0.1, + 0.1, + 0.1, + 0.1, + ] + + response = check_similarity(distances, ctx) + assert not response + + +def test_check_similarity_not_enough_data(): + ctx = Context() + distances = [ + 0.1, + 0.1, + ] + + response = check_similarity(distances, ctx) + assert not response + + +@pytest.mark.skip(reason="We are supressing the exception while experimenting") +def test_check_death_loop_raises_on_plan(): + ctx = Context() + event = { + "text": "This is also a plan", + "tool_calls": [], + } + + event_state = EventState( + distances_plans=[ + 0.1, + 0.8, + 0.9, + ], + distances_actions=[ + 0.1, + 0.1, + 0.1, + ], + previous_plan="This is a plan", + previous_action="[]", + ) + + with pytest.raises(Exception): + check_death_loop(event, event_state, ctx) + + +@pytest.mark.skip(reason="We are supressing the exception while experimenting") +def test_check_death_loop_raises_on_action(): + ctx = Context() + event = { + "text": "This is a plan", + "tool_calls": [{"tool": "tool1", "args": ["This is an argument"]}], + } + + event_state = EventState( + distances_plans=[ + 0.1, + 0.1, + 0.1, + ], + distances_actions=[ + 0.1, + 0.8, + 0.9, + ], + previous_plan="Nothing like the previous plan", + previous_action="[{'tool': 'tool1', 'args': ['This is an argument']}]", + ) + + with pytest.raises(Exception): + check_death_loop(event, event_state, ctx) + + +def test_check_no_death_loop(): + ctx = Context() + event = { + "text": "Nothing like the previous plan", + "tool_calls": [ + {"tool": "different_tool", "args": ["Nothing like the previous action"]} + ], + } + + event_state = EventState( 
+ distances_plans=[ + 0.1, + 0.1, + 0.1, + ], + distances_actions=[ + 0.1, + 0.1, + 0.1, + ], + previous_plan="This is a plan", + previous_action='[{"tool": "tool1", "args": ["This is an argument"]}]', + ) + + new_event_state = check_death_loop(event, event_state, ctx) + assert new_event_state.previous_plan == "Nothing like the previous plan" + assert ( + new_event_state.previous_action + == '[{"tool": "different_tool", "args": ["Nothing like the previous action"]}]' + ) + + assert len(new_event_state.distances_plans) == 4 + assert len(new_event_state.distances_actions) == 4 + + assert new_event_state.distances_plans[-1] < max(DEATHLOOP_SIMILARITY_THRESHOLDS) + assert new_event_state.distances_actions[-1] < max(DEATHLOOP_SIMILARITY_THRESHOLDS)