Skip to content

Commit

Permalink
fix: starter pack live api async tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
eliasecchig committed Jan 31, 2025
1 parent d9d2276 commit 2ad194b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def retrieve_docs(query: str) -> Dict[str, str]:
"""
Retrieves pre-formatted documents about MLOps (Machine Learning Operations),
Gen AI lifecycle, and production deployment best practices.
When using this you should warn the user it might take few seconds.
Args:
query: Search query string related to MLOps, Gen AI, or production deployment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from google.cloud import logging as google_cloud_logging
from google.genai import types
from google.genai.types import LiveServerToolCall
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from websockets.exceptions import ConnectionClosedError


Expand Down Expand Up @@ -90,18 +90,11 @@ async def receive_from_client(self) -> None:
def _get_func(self, action_label: str) -> Optional[Callable]:
"""Get the tool function for a given action label."""
return None if action_label == "" else self.tool_functions.get(action_label)

async def _handle_tool_call(
self, session: Any, tool_call: LiveServerToolCall
) -> None:
"""Process tool calls from Gemini and send back responses."""
# Create a task for handling the tool call
asyncio.create_task(self._process_tool_call(session, tool_call))

async def _process_tool_call(
self, session: Any, tool_call: LiveServerToolCall
) -> None:
"""Process tool calls in a separate task."""
try:
for fc in tool_call.function_calls:
logging.debug(f"Calling tool function: {fc.name} with args: {fc.args}")
Expand All @@ -124,10 +117,15 @@ async def receive_from_gemini(self) -> None:
await self.websocket.send_bytes(result)

# Process any tool calls asynchronously
message = types.LiveServerMessage.model_validate(json.loads(result))
try:
message = types.LiveServerMessage.model_validate(json.loads(result))
except ValidationError:
continue

if message.tool_call:
tool_call = LiveServerToolCall.model_validate(message.tool_call)
await self._handle_tool_call(self.session, tool_call)
# Create task for handling tool call
asyncio.create_task(self._handle_tool_call(self.session, tool_call))
except Exception as e:
logging.error(f"Error receiving from Gemini: {str(e)}")

Expand Down

0 comments on commit 2ad194b

Please sign in to comment.