Skip to content

Commit

Permalink
backend: Add death loop detection (#815)
Browse files Browse the repository at this point in the history
* Import death loop code - pending testing

* wip
  • Loading branch information
tianjing-li authored Oct 22, 2024
1 parent 33e85f6 commit ce5158b
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 4 deletions.
20 changes: 17 additions & 3 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from backend.config.tools import AVAILABLE_TOOLS
from backend.database_models.file import File
from backend.model_deployments.base import BaseDeployment
from backend.schemas.chat import ChatMessage, ChatRole
from backend.schemas.chat import ChatMessage, ChatRole, EventState
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context
from backend.schemas.tool import Category, Tool
from backend.services.chat import check_death_loop
from backend.services.file import get_file_service
from backend.tools.utils.tools_checkers import tool_has_category

Expand All @@ -22,6 +23,13 @@
class CustomChat(BaseChat):
"""Custom chat flow not using integrations for models."""

event_state = EventState(
distances_plans=[],
distances_actions=[],
previous_plan="",
previous_action="",
)

async def chat(
self,
chat_request: CohereChatRequest,
Expand Down Expand Up @@ -62,7 +70,7 @@ async def chat(
stream = self.call_chat(self.chat_request, deployment_model, ctx, **kwargs)

async for event in stream:
result = self.handle_event(event, chat_request)
result = self.handle_event(event, chat_request, ctx)

if result:
yield result
Expand Down Expand Up @@ -102,15 +110,21 @@ def is_final_event(
)

def handle_event(
self, event: Dict[str, Any], chat_request: CohereChatRequest
self, event: Dict[str, Any], chat_request: CohereChatRequest, ctx: Context
) -> Dict[str, Any]:
# All events other than stream start, stream end and tool calls generation are returned unchanged
if (
event["event_type"] != StreamEvent.STREAM_START
and event["event_type"] != StreamEvent.STREAM_END
and event["event_type"] != StreamEvent.TOOL_CALLS_GENERATION
):
return event

# If the event is a tool call generation, we need to check if the tool call is similar to previous tool calls
if event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION:
self.event_state = check_death_loop(event, self.event_state, ctx)
return event

# Only the first occurrence of stream start is returned
if event["event_type"] == StreamEvent.STREAM_START:
if self.is_first_start:
Expand Down
8 changes: 8 additions & 0 deletions src/backend/schemas/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, ClassVar, Dict, List, Union
from uuid import uuid4
Expand All @@ -11,6 +12,13 @@
from backend.schemas.tool import Tool, ToolCall, ToolCallDelta


@dataclass
class EventState:
    """Rolling state carried between events for death-loop detection."""

    # Similarity scores between consecutive plans (event "text"), oldest first.
    distances_plans: list
    # Similarity scores between consecutive actions (JSON-serialised tool calls), oldest first.
    distances_actions: list
    # Plan text of the most recently checked event; empty until one has been seen.
    previous_plan: str
    # JSON-serialised tool calls of the most recently checked event; empty until one has been seen.
    previous_action: str

class ChatRole(StrEnum):
"""One of CHATBOT|USER|SYSTEM to identify who the message is coming from."""

Expand Down
70 changes: 69 additions & 1 deletion src/backend/services/chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from typing import Any, AsyncGenerator, Generator, List, Union
from typing import Any, AsyncGenerator, Dict, Generator, List, Union
from uuid import uuid4

import nltk
from cohere.types import StreamedChatResponse
from fastapi import HTTPException, Request
from fastapi.encoders import jsonable_encoder
Expand Down Expand Up @@ -30,6 +31,7 @@
ChatMessage,
ChatResponseEvent,
ChatRole,
EventState,
NonStreamedChatResponse,
StreamCitationGeneration,
StreamEnd,
Expand All @@ -47,6 +49,8 @@
from backend.schemas.tool import Tool, ToolCall, ToolCallDelta
from backend.services.agent import validate_agent_exists

# Window sizes (count of most recent similarity scores) tried by check_similarity;
# the smallest one is also the minimum history needed before any check runs.
LOOKBACKS = [3, 5, 7]
# Cut-offs tried by check_similarity; a score counts as "similar" only when it is
# strictly greater than the cut-off.
DEATHLOOP_SIMILARITY_THRESHOLDS = [0.5, 0.7, 0.9]

def process_chat(
session: DBSessionDep,
Expand Down Expand Up @@ -962,3 +966,67 @@ def handle_stream_end(
stream_end = StreamEnd.model_validate(event | stream_end_data)
stream_event = stream_end
return stream_event, stream_end_data, response_message, document_ids_to_document

def are_previous_actions_similar(
    distances: List[float], threshold: float, lookback: int
) -> bool:
    """Report whether the last ``lookback`` scores all exceed ``threshold``.

    Args:
        distances: Similarity scores, oldest first.
        threshold: Exclusive lower bound every inspected score must beat.
        lookback: Number of trailing scores to inspect.

    Returns:
        True only when at least ``lookback`` scores exist and each of the most
        recent ``lookback`` scores is strictly greater than ``threshold``;
        False otherwise.
    """
    # A non-positive window carries no evidence; ``distances[-0:]`` would also
    # select the whole list instead of nothing.
    if lookback <= 0:
        return False

    # Without a full window, ``all()`` would be vacuously true on an empty
    # slice or would judge a shorter history than the caller asked for.
    if len(distances) < lookback:
        return False

    return all(dist > threshold for dist in distances[-lookback:])


def check_similarity(distances: list[float], ctx: Context) -> bool:
    """
    Decide whether the recorded scores point to a potential death loop.

    Args:
        distances (list[float]): Similarity scores between previous actions.
        ctx (Context): Request context; only its logger is used here.

    Returns:
        bool: True when some threshold/lookback combination flags the history
        (a warning is logged for the first such combination), otherwise False.
    """
    logger = ctx.get_logger()

    enough_history = len(distances) >= min(LOOKBACKS)
    if not enough_history:
        return False

    # EXPERIMENTAL: every threshold is paired with every lookback, thresholds outermost.
    combinations = (
        (threshold, lookback)
        for threshold in DEATHLOOP_SIMILARITY_THRESHOLDS
        for lookback in LOOKBACKS
    )
    for threshold, lookback in combinations:
        if not are_previous_actions_similar(distances, threshold, lookback):
            continue

        logger.warning(
            event="[Chat] Potential death loop detected",
            distances=distances,
            threshold=threshold,
            lookback=lookback,
        )
        return True

    return False


def _similarity_ratio(previous: str, current: str) -> float:
    """Return 1 minus the edit distance normalised by the longer string's length."""
    longest = max(len(previous), len(current))
    if not longest:
        # Two empty strings are identical; also avoids dividing by zero.
        return 1.0
    return 1 - nltk.edit_distance(previous, current) / longest


def check_death_loop(
    event: Dict[str, Any], event_state: EventState, ctx: Context
) -> EventState:
    """Record how similar this event is to the previous one and run loop checks.

    The event's ``"text"`` is treated as the plan and its JSON-serialised
    ``"tool_calls"`` as the action. For each of the two, when a previous value
    exists, a similarity score is appended to the matching history on
    ``event_state`` and ``check_similarity`` is run on that history.

    Args:
        event: Event payload; ``"text"`` and ``"tool_calls"`` are read.
        event_state: State from earlier events. It is mutated in place.
        ctx: Context handed through to ``check_similarity``.

    Returns:
        The same ``event_state`` object, with both histories extended where a
        previous value existed and ``previous_plan`` / ``previous_action`` set
        to this event's values.
    """
    # A key may be present with an explicit None, which .get()'s default does
    # not cover; normalise so len() and edit_distance always receive strings.
    plan: str = event.get("text") or ""
    tool_calls: List = event.get("tool_calls") or []
    action: str = json.dumps(tool_calls)

    # Actions are processed before plans, matching the original order of logging.
    tracks = (
        (event_state.previous_action, action, event_state.distances_actions),
        (event_state.previous_plan, plan, event_state.distances_plans),
    )
    for previous, current, history in tracks:
        if not previous:
            # Nothing to compare against yet.
            continue
        history.append(_similarity_ratio(previous, current))
        # check_similarity logs a warning on detection; its boolean result is
        # not acted upon here.
        check_similarity(history, ctx)

    event_state.previous_plan = plan
    event_state.previous_action = action

    return event_state
163 changes: 163 additions & 0 deletions src/backend/tests/unit/services/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@

import pytest

from backend.schemas.chat import EventState
from backend.schemas.context import Context
from backend.services.chat import (
DEATHLOOP_SIMILARITY_THRESHOLDS,
are_previous_actions_similar,
check_death_loop,
check_similarity,
)


def test_are_previous_actions_similar():
    """The three most recent scores all exceed 0.7, so the window is flagged."""
    scores = [0.5, 0.6, 0.8, 0.9, 1.0]

    flagged = are_previous_actions_similar(scores, 0.7, 3)

    assert flagged


def test_are_previous_actions_not_similar():
    """Uniformly low scores never exceed 0.7, so the window is not flagged."""
    scores = [0.1, 0.1, 0.1, 0.1]

    flagged = are_previous_actions_similar(scores, 0.7, 3)

    assert not flagged


def test_check_similarity():
    """A history ending in high scores is reported as a potential death loop."""
    scores = [0.5, 0.6, 0.8, 0.9, 1.0]
    ctx = Context()

    detected = check_similarity(scores, ctx)

    assert detected


def test_check_similarity_no_death_loop():
    """A history of low scores is not reported as a death loop."""
    scores = [0.1, 0.1, 0.1, 0.1]
    ctx = Context()

    detected = check_similarity(scores, ctx)

    assert not detected


def test_check_similarity_not_enough_data():
    """Two scores do not trigger a detection."""
    scores = [0.1, 0.1]
    ctx = Context()

    detected = check_similarity(scores, ctx)

    assert not detected


@pytest.mark.skip(reason="We are supressing the exception while experimenting")
def test_check_death_loop_raises_on_plan():
    """Skipped: a near-identical plan on top of a similar plan history should raise."""
    ctx = Context()
    state = EventState(
        distances_plans=[0.1, 0.8, 0.9],
        distances_actions=[0.1, 0.1, 0.1],
        previous_plan="This is a plan",
        previous_action="[]",
    )
    incoming = {"text": "This is also a plan", "tool_calls": []}

    with pytest.raises(Exception):
        check_death_loop(incoming, state, ctx)


@pytest.mark.skip(reason="We are supressing the exception while experimenting")
def test_check_death_loop_raises_on_action():
    """Skipped: a repeated tool call on top of a similar action history should raise."""
    ctx = Context()
    state = EventState(
        distances_plans=[0.1, 0.1, 0.1],
        distances_actions=[0.1, 0.8, 0.9],
        previous_plan="Nothing like the previous plan",
        previous_action="[{'tool': 'tool1', 'args': ['This is an argument']}]",
    )
    repeated_call = {"tool": "tool1", "args": ["This is an argument"]}
    incoming = {"text": "This is a plan", "tool_calls": [repeated_call]}

    with pytest.raises(Exception):
        check_death_loop(incoming, state, ctx)


def test_check_no_death_loop():
    """A dissimilar plan and action extend both histories and stay below the top threshold."""
    ctx = Context()
    state = EventState(
        distances_plans=[0.1, 0.1, 0.1],
        distances_actions=[0.1, 0.1, 0.1],
        previous_plan="This is a plan",
        previous_action='[{"tool": "tool1", "args": ["This is an argument"]}]',
    )
    new_call = {"tool": "different_tool", "args": ["Nothing like the previous action"]}
    incoming = {"text": "Nothing like the previous plan", "tool_calls": [new_call]}

    updated = check_death_loop(incoming, state, ctx)

    # The incoming plan and serialised action become the new "previous" values.
    expected_action = (
        '[{"tool": "different_tool", "args": ["Nothing like the previous action"]}]'
    )
    assert updated.previous_plan == "Nothing like the previous plan"
    assert updated.previous_action == expected_action

    # Exactly one new score is appended to each history.
    assert len(updated.distances_plans) == 4
    assert len(updated.distances_actions) == 4

    ceiling = max(DEATHLOOP_SIMILARITY_THRESHOLDS)
    assert updated.distances_plans[-1] < ceiling
    assert updated.distances_actions[-1] < ceiling

0 comments on commit ce5158b

Please sign in to comment.