Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions samples/agent/adk/contact_lookup/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import os
from collections import OrderedDict
from collections.abc import AsyncIterable
from dataclasses import dataclass
from typing import Any, Dict, Optional
Expand Down Expand Up @@ -65,7 +66,8 @@ def __init__(self, base_url: str):

self._schema_managers: Dict[str, A2uiSchemaManager] = {}
self._ui_runners: Dict[str, Runner] = {}
self._parsers: Dict[str, A2uiStreamParser] = {}
self._parsers: OrderedDict[str, A2uiStreamParser] = OrderedDict()
self._max_parsers = 1000 # Max active sessions to keep in memory

for version in [VERSION_0_8, VERSION_0_9]:
schema_manager = self._build_schema_manager(version)
Expand Down Expand Up @@ -259,8 +261,12 @@ async def token_stream():
if selected_catalog:
from a2ui.core.parser.streaming import A2uiStreamParser

if session_id not in self._parsers:
if session_id in self._parsers:
self._parsers.move_to_end(session_id)
else:
self._parsers[session_id] = A2uiStreamParser(catalog=selected_catalog)
if len(self._parsers) > self._max_parsers:
self._parsers.popitem(last=False)

async for part in stream_response_to_parts(
self._parsers[session_id],
Expand Down
60 changes: 43 additions & 17 deletions samples/agent/adk/custom-components-example/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
import json
import logging
import os
from collections import OrderedDict
from collections.abc import AsyncIterable
from dataclasses import dataclass
from typing import Any, Dict, Optional

import jsonschema

from a2ui_examples import load_floor_plan_example
from a2ui.core.parser.streaming import A2uiStreamParser
from google.adk.agents import run_config
from google.adk.agents.llm_agent import LlmAgent
from google.adk.artifacts import InMemoryArtifactService
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
Expand All @@ -45,7 +48,7 @@
from a2ui.core.schema.manager import A2uiSchemaManager
from a2ui.core.parser.parser import parse_response, ResponsePart
from a2ui.basic_catalog.provider import BasicCatalog
from a2ui.a2a import create_a2ui_part, get_a2ui_agent_extension, parse_response_to_parts
from a2ui.a2a import create_a2ui_part, get_a2ui_agent_extension, parse_response_to_parts, stream_response_to_parts

logger = logging.getLogger(__name__)

Expand All @@ -63,6 +66,8 @@ def __init__(self, base_url: str):

self._schema_managers: Dict[str, A2uiSchemaManager] = {}
self._ui_runners: Dict[str, Runner] = {}
self._parsers: OrderedDict[str, A2uiStreamParser] = OrderedDict()
self._max_parsers = 1000 # Max active sessions to keep in memory

for version in [VERSION_0_8, VERSION_0_9]:
schema_manager = self._build_schema_manager(version)
Expand Down Expand Up @@ -382,27 +387,48 @@ async def stream(
current_message = types.Content(
role="user", parts=[types.Part.from_text(text=current_query_text)]
)
final_response_content = None

async for event in runner.run_async(
user_id=self._user_id,
session_id=session.id,
new_message=current_message,
):
logger.info(f"Event from runner: {event}")
if event.is_final_response():
if event.content and event.content.parts and event.content.parts[0].text:
final_response_content = "\n".join(
[p.text for p in event.content.parts if p.text]
)
break # Got the final response, stop consuming events
full_content_list = []

async def token_stream():
async for event in runner.run_async(
user_id=self._user_id,
session_id=session.id,
run_config=run_config.RunConfig(
streaming_mode=run_config.StreamingMode.SSE
),
new_message=current_message,
):
if event.content and event.content.parts:
for p in event.content.parts:
if p.text:
full_content_list.append(p.text)
yield p.text

if selected_catalog:
if session_id in self._parsers:
self._parsers.move_to_end(session_id)
else:
logger.info(f"Intermediate event: {event}")
# Yield intermediate updates on every attempt
self._parsers[session_id] = A2uiStreamParser(catalog=selected_catalog)
if len(self._parsers) > self._max_parsers:
self._parsers.popitem(last=False)

async for part in stream_response_to_parts(
self._parsers[session_id],
token_stream(),
):
yield {
"is_task_complete": False,
"updates": self.get_processing_message(),
"parts": [part],
}
else:
async for token in token_stream():
yield {
"is_task_complete": False,
"updates": token,
}

final_response_content = "".join(full_content_list) if full_content_list else None

if final_response_content is None:
logger.warning(
Expand Down
12 changes: 8 additions & 4 deletions samples/agent/adk/custom-components-example/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,14 @@ async def execute(
):
is_task_complete = item["is_task_complete"]
if not is_task_complete:
await updater.update_status(
TaskState.working,
new_agent_text_message(item["updates"], task.context_id, task.id),
)
message = None
if "parts" in item:
message = new_agent_parts_message(item["parts"], task.context_id, task.id)
elif "updates" in item:
message = new_agent_text_message(item["updates"], task.context_id, task.id)

if message:
await updater.update_status(TaskState.working, message)
continue

final_state = TaskState.input_required # Default
Expand Down
10 changes: 8 additions & 2 deletions samples/agent/adk/restaurant_finder/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import os
from collections import OrderedDict
from collections.abc import AsyncIterable
from typing import Any, Optional, Dict

Expand Down Expand Up @@ -67,7 +68,8 @@ def __init__(self, base_url: str):

self._schema_managers: Dict[str, A2uiSchemaManager] = {}
self._ui_runners: Dict[str, Runner] = {}
self._parsers = {}
self._parsers = OrderedDict()
self._max_parsers = 1000 # Max active sessions to keep in memory

for version in [VERSION_0_8, VERSION_0_9]:
schema_manager = self._build_schema_manager(version)
Expand Down Expand Up @@ -255,8 +257,12 @@ async def token_stream():
if selected_catalog:
from a2ui.core.parser.streaming import A2uiStreamParser

if session_id not in self._parsers:
if session_id in self._parsers:
self._parsers.move_to_end(session_id)
else:
self._parsers[session_id] = A2uiStreamParser(catalog=selected_catalog)
if len(self._parsers) > self._max_parsers:
self._parsers.popitem(last=False)

async for part in stream_response_to_parts(
self._parsers[session_id],
Expand Down
93 changes: 81 additions & 12 deletions samples/client/lit/custom-components-example/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export class A2UIClient {
}

async send(
message: v0_8.Types.A2UIClientEventMessage
message: v0_8.Types.A2UIClientEventMessage,
onChunk?: (messages: v0_8.Types.ServerToClientMessage[]) => void
): Promise<v0_8.Types.ServerToClientMessage[]> {
const catalog = componentRegistry.getInlineCatalog();
const finalMessage = {
Expand All @@ -56,22 +57,90 @@ export class A2UIClient {
method: "POST",
});

if (response.ok) {
const data = (await response.json()) as A2AServerPayload;
const messages: v0_8.Types.ServerToClientMessage[] = [];
if ("error" in data) {
throw new Error(data.error);
} else {
for (const item of data) {
if (item.kind === "text") continue;
messages.push(item.data);
if (!response.ok) {
const error = (await response.json()) as { error: string };
throw new Error(error.error);
}

const contentType = response.headers.get("content-type");
const messages: v0_8.Types.ServerToClientMessage[] = [];

if (contentType?.includes("text/event-stream")) {
const reader = response.body?.getReader();
if (!reader) throw new Error("No response body");
const decoder = new TextDecoder();
let buffer = "";

while (true) {
const { done, value } = await reader.read();
if (done) break;

buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n\n");
buffer = lines.pop() || "";

for (const line of lines) {
if (line.startsWith("data: ")) {
const jsonStr = line.replace(/^data:\s*/, "");
try {
const parsed = JSON.parse(jsonStr);
if ("error" in parsed) {
throw new Error(parsed.error);
} else {
const chunkMessages = this.#extractMessages(parsed);
if (chunkMessages.length > 0) {
messages.push(...chunkMessages);
onChunk?.(chunkMessages);
}
}
} catch (e) {
console.error("Error parsing SSE data:", e, jsonStr);
}
}
}
}
return messages;
}

const error = (await response.json()) as { error: string };
throw new Error(error.error);
const data = (await response.json()) as any;
if (data && typeof data === 'object' && "error" in data) {
throw new Error(data.error);
} else {
const extracted = this.#extractMessages(data);
messages.push(...extracted);
if (messages.length > 0) {
onChunk?.(messages);
}
}
return messages;
}

#extractMessages(data: any): v0_8.Types.ServerToClientMessage[] {
let items: any[] = [];
if (data.messages && Array.isArray(data.messages)) {
items = data.messages;
} else {
items = Array.isArray(data)
? data
: (data.kind === "message" && Array.isArray(data.parts) ? data.parts : [data]);
}

const messages: v0_8.Types.ServerToClientMessage[] = [];
for (const item of items) {
if (item.kind === "message" && Array.isArray(item.parts)) {
for (const part of item.parts) {
if (part.data) {
messages.push(part.data);
}
}
} else {
if (item.kind === "text") continue;
if (item.data) {
messages.push(item.data);
}
}
}
return messages;
}
}
registerContactComponents();
Loading
Loading