From 2ad194bfb9441ac064a81b0a0abe863696e20645 Mon Sep 17 00:00:00 2001 From: eliasecchig Date: Fri, 31 Jan 2025 19:06:11 +0100 Subject: [PATCH] fix: starter pack live api async tool calls --- .../multimodal_live_agent/app/agent.py | 2 +- .../multimodal_live_agent/app/server.py | 20 +++++++++---------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-pack/app/patterns/multimodal_live_agent/app/agent.py b/gemini/sample-apps/e2e-gen-ai-app-starter-pack/app/patterns/multimodal_live_agent/app/agent.py index 1720ad739b8..0f88e9c3e8c 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-pack/app/patterns/multimodal_live_agent/app/agent.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-pack/app/patterns/multimodal_live_agent/app/agent.py @@ -53,7 +53,7 @@ async def retrieve_docs(query: str) -> Dict[str, str]: """ Retrieves pre-formatted documents about MLOps (Machine Learning Operations), Gen AI lifecycle, and production deployment best practices. - + When using this you should warn the user it might take few seconds. Args: query: Search query string related to MLOps, Gen AI, or production deployment. diff --git a/gemini/sample-apps/e2e-gen-ai-app-starter-pack/app/patterns/multimodal_live_agent/app/server.py b/gemini/sample-apps/e2e-gen-ai-app-starter-pack/app/patterns/multimodal_live_agent/app/server.py index 7adb8c736f8..01fc0956cea 100644 --- a/gemini/sample-apps/e2e-gen-ai-app-starter-pack/app/patterns/multimodal_live_agent/app/server.py +++ b/gemini/sample-apps/e2e-gen-ai-app-starter-pack/app/patterns/multimodal_live_agent/app/server.py @@ -25,7 +25,7 @@ from google.cloud import logging as google_cloud_logging from google.genai import types from google.genai.types import LiveServerToolCall -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from websockets.exceptions import ConnectionClosedError @@ -90,18 +90,11 @@ async def receive_from_client(self) -> None: def _get_func(self, action_label: str) -> Optional[Callable]: """Get the tool function for a given action label.""" return None if action_label == "" else self.tool_functions.get(action_label) - + async def _handle_tool_call( self, session: Any, tool_call: LiveServerToolCall ) -> None: """Process tool calls from Gemini and send back responses.""" - # Create a task for handling the tool call - asyncio.create_task(self._process_tool_call(session, tool_call)) - - async def _process_tool_call( - self, session: Any, tool_call: LiveServerToolCall - ) -> None: - """Process tool calls in a separate task.""" try: for fc in tool_call.function_calls: logging.debug(f"Calling tool function: {fc.name} with args: {fc.args}") @@ -124,10 +117,15 @@ async def receive_from_gemini(self) -> None: await self.websocket.send_bytes(result) # Process any tool calls asynchronously - message = types.LiveServerMessage.model_validate(json.loads(result)) + try: + message = types.LiveServerMessage.model_validate(json.loads(result)) + except ValidationError: + continue + if message.tool_call: tool_call = LiveServerToolCall.model_validate(message.tool_call) - await self._handle_tool_call(self.session, tool_call) + # Create task for handling tool call + asyncio.create_task(self._handle_tool_call(self.session, tool_call)) except Exception as e: logging.error(f"Error receiving from Gemini: {str(e)}")