Skip to content

Commit

Permalink
Merge branch 'main' into agentspace
Browse files Browse the repository at this point in the history
  • Loading branch information
holtskinner authored Jan 31, 2025
2 parents 748d6a9 + dc0c8bc commit b2e8075
Show file tree
Hide file tree
Showing 9 changed files with 2,572 additions and 440 deletions.
2 changes: 2 additions & 0 deletions .github/actions/spelling/allow.txt
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,7 @@ imshow
imwrite
inbox
informati
inpaint
iostream
ipd
iphoneos
Expand Down Expand Up @@ -807,6 +808,7 @@ lakecolor
landcolor
landuse
langgraph
latte
lego
lenzing
levelname
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
retriever = vector_store.as_retriever()


def retrieve_docs(query: str) -> Dict[str, str]:
async def retrieve_docs(query: str) -> Dict[str, str]:
"""
Retrieves pre-formatted documents about MLOps (Machine Learning Operations),
Gen AI lifecycle, and production deployment best practices.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from google.cloud import logging as google_cloud_logging
from google.genai import types
from google.genai.types import LiveServerToolCall
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from websockets.exceptions import ConnectionClosedError

app = FastAPI()
Expand Down Expand Up @@ -93,35 +93,42 @@ def _get_func(self, action_label: str) -> Optional[Callable]:
async def _handle_tool_call(
self, session: Any, tool_call: LiveServerToolCall
) -> None:
"""Process tool calls from Gemini and send back responses.
Args:
session: The Gemini session
tool_call: Tool call request from Gemini
"""
for fc in tool_call.function_calls:
logging.debug(f"Calling tool function: {fc.name} with args: {fc.args}")
response = self._get_func(fc.name)(**fc.args)
tool_response = types.LiveClientToolResponse(
function_responses=[
types.FunctionResponse(name=fc.name, id=fc.id, response=response)
]
)
logging.debug(f"Tool response: {tool_response}")
await session.send(tool_response)
"""Process tool calls from Gemini and send back responses."""
try:
for fc in tool_call.function_calls:
logging.debug(f"Calling tool function: {fc.name} with args: {fc.args}")
response = await self._get_func(fc.name)(**fc.args)
tool_response = types.LiveClientToolResponse(
function_responses=[
types.FunctionResponse(
name=fc.name, id=fc.id, response=response
)
]
)
logging.debug(f"Tool response: {tool_response}")
await session.send(tool_response)
except Exception as e:
logging.error(f"Error processing tool call: {str(e)}")

async def receive_from_gemini(self) -> None:
"""Listen for and process messages from Gemini.
Continuously receives messages from Gemini, forwards them to the client,
and handles any tool calls. Handles connection errors gracefully.
"""
while result := await self.session._ws.recv(decode=False):
await self.websocket.send_bytes(result)
message = types.LiveServerMessage.model_validate(json.loads(result))
if message.tool_call:
tool_call = LiveServerToolCall.model_validate(message.tool_call)
await self._handle_tool_call(self.session, tool_call)
"""Listen for and process messages from Gemini."""
try:
while result := await self.session._ws.recv(decode=False):
# Send the message to the client immediately
await self.websocket.send_bytes(result)

# Process any tool calls asynchronously
try:
message = types.LiveServerMessage.model_validate(json.loads(result))
except ValidationError:
continue

if message.tool_call:
tool_call = LiveServerToolCall.model_validate(message.tool_call)
# Create task for handling tool call
asyncio.create_task(self._handle_tool_call(self.session, tool_call))
except Exception as e:
logging.error(f"Error receiving from Gemini: {str(e)}")


def get_connect_and_run_callable(websocket: WebSocket) -> Callable:
Expand Down
Loading

0 comments on commit b2e8075

Please sign in to comment.