Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions app/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
)

INTERRUPT_CANCEL_REPLY = "Previous tool canceled by the user."

BEDROCK_TOOL_ERROR_MESSAGE = "toolConfig field must be defined"
9 changes: 7 additions & 2 deletions app/routers/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ..services.auth import get_user_id
from ..constants import CONTEXT_PARAMETERS_SUFFIX
from ..constants import BEDROCK_TOOL_ERROR_MESSAGE

router = APIRouter()

Expand Down Expand Up @@ -104,6 +105,7 @@ async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: B
base_config["callbacks"] = [langfuse_handler]

while True:
is_bedrock_tool_error = False
try:
request = await websocket.receive_text()
request_id = str(uuid.uuid4())
Expand All @@ -127,13 +129,16 @@ async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: B
break
except Exception as e:
logging.error(f"An error occurred: {e}", exc_info=True)
is_bedrock_tool_error = BEDROCK_TOOL_ERROR_MESSAGE in str(e)
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_text(f'<error>{json.dumps({"message": str(e)})}</error>')
if not is_bedrock_tool_error:
await websocket.send_text(f'<error>{json.dumps({"message": str(e)})}</error>')
else:
break
finally:
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_text("</message>")
if not is_bedrock_tool_error:
await websocket.send_text("</message>")

async def _call_agent(
agent: CompiledStateGraph,
Expand Down
12 changes: 12 additions & 0 deletions app/services/agent/middleware/ui_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import os
from typing import Any

from langchain.agents.middleware import AgentState, after_agent
Expand All @@ -12,6 +13,11 @@
from langgraph.runtime import Runtime

from ...ui_tools.selector import create_ui_tools_selector
from ...llm import (
get_llm,
get_active_llm,
get_llm_model,
)


def ui_tools_middleware(llm: BaseChatModel, only_when_direct: bool = False):
Expand Down Expand Up @@ -81,6 +87,12 @@ async def _dispatch_ui_tools_event(
List of selected UI tools, or empty list if dispatch was skipped.
"""
try:
activeLlm = get_active_llm()

if activeLlm.lower() == "bedrock":
logging.debug("Skipping UI tools dispatch for Bedrock to avoid toolConfig errors")
return []

request_metadata = config.get("configurable", {}).get("request_metadata", {})
ui_tools_config = request_metadata.get("ui_tools", {})

Expand Down
Loading